[
  {
    "path": ".github/workflows/build_docs.yml",
    "content": "name: Build docs\n\non:\n  push:\n    branches:\n      - main\n\njobs:\n  build:\n    strategy:\n      matrix:\n        python-version: [ 3.11 ]\n        os: [ ubuntu-latest ]\n    runs-on: ${{ matrix.os }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install the latest version of uv\n        uses: astral-sh/setup-uv@v7\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install dependencies\n        run: |\n          uv run echo done\n\n      - name: Build docs\n        run: |\n          uv run mkdocs build\n\n      - name: Upload docs\n        uses: actions/upload-artifact@v4\n        with:\n          name: docs\n          path: site  # where `mkdocs build` puts the built site\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\n\non:\n  push:\n    branches:\n      - main\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Release\n        uses: patrick-kidger/action_update_python_project@v8\n        with:\n            python-version: \"3.11\"\n            # Uninstall and reinstall pytest to work around the fact that it doesn't get put into `bin` otherwise.\n            test-script: |\n                cp -r ${{ github.workspace }}/tests ./tests\n                cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml\n                uv pip uninstall pytest\n                uv sync --no-install-project --inexact\n                uv run --no-sync pytest\n            pypi-token: ${{ secrets.pypi_token }}\n            github-user: patrick-kidger\n            github-token: ${{ github.token }}\n"
  },
  {
    "path": ".github/workflows/run_tests.yml",
    "content": "name: Run tests\n\non:\n  pull_request:\n\njobs:\n  run-test:\n    strategy:\n      matrix:\n        python-version: [ 3.11 ]\n        os: [ ubuntu-latest ]\n      fail-fast: false\n    runs-on: ${{ matrix.os }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install the latest version of uv\n        uses: astral-sh/setup-uv@v7\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install dependencies\n        run: |\n          uv run echo done\n\n      - name: Checks with pre-commit\n        run: |\n          uv run prek run --all-files\n\n      - name: Test with pytest\n        run: |\n          uv run python -m tests\n\n      - name: Check that documentation can be built.\n        run: |\n          uv run mkdocs build\n"
  },
  {
    "path": ".gitignore",
    "content": "**/__pycache__\n**/.ipynb_checkpoints\n*.egg-info/\nbuild/\ndist/\nsite/\nexamples/data\n.all_objects.cache\n.pymon\n.idea\n.venv\nuv.lock\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "fail_fast: true\nrepos:\n  - repo: meta\n    hooks:\n    - id: check-hooks-apply\n    - id: check-useless-excludes\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.4.0\n    hooks:\n    - id: trailing-whitespace\n      exclude: \\.md$\n    - id: check-toml\n    - id: mixed-line-ending\n  - repo: local\n    hooks:\n      - id: sort-pyproject\n        name: sort pyproject\n        files: ^pyproject\\.toml$\n        language: system\n        entry: uv run -- toml-sort -i --sort-table-keys --sort-inline-tables\n      - id: ruff-format\n        name: ruff format\n        types_or: [python, pyi, jupyter, toml]\n        language: system\n        entry: uv run -- ruff format --\n        require_serial: true\n      - id: ruff-lint\n        name: ruff lint\n        types_or: [python, pyi, jupyter, toml]\n        language: system\n        entry: uv run -- ruff check --fix --\n        require_serial: true\n      - id: pyright\n        name: pyright\n        types_or: [python]\n        language: system\n        entry: uv run -- pyright\n        require_serial: true\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing\n\nContributions (pull requests) are very welcome! Here's how to get started.\n\n---\n\n### Getting started\n\n[We assume that you have `uv` installed.](https://docs.astral.sh/uv/) Now fork the library on GitHub. Then clone and install the library:\n\n```bash\ngit clone https://github.com/your-username-here/lineax.git\ncd lineax\nuv run prek install  # Creates a local venv + installs dependencies + installs pre-commit hooks.\n```\n\n---\n\n### If you're making changes to the code\n\nNow make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass:\n\n```bash\nuv run python -m tests\n```\n\nThen push your changes back to your fork of the repository:\n\n```bash\ngit push\n```\n\nFinally, open a pull request on GitHub!\n\n---\n\n### If you're making changes to the documentation\n\nMake your changes. You can then build the documentation by doing\n\n```bash\nuv run mkdocs serve\n```\n\nYou can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align='center'>Lineax</h1>\n\nLineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)\n\nFeatures include:\n- PyTree-valued matrices and vectors;\n- General linear operators for Jacobians, transposes, etc.;\n- Efficient linear least squares (e.g. QR solvers);\n- Numerically stable gradients through linear least squares;\n- Support for structured (e.g. symmetric) matrices;\n- Improved compilation times;\n- Improved runtime of some algorithms;\n- Support for both real-valued and complex-valued inputs;\n- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.\n\n## Installation\n\n```bash\npip install lineax\n```\n\nRequires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.\n\n## Documentation\n\nAvailable at [https://docs.kidger.site/lineax](https://docs.kidger.site/lineax).\n\n## Quick examples\n\nLineax can solve a least squares problem with an explicit matrix operator:\n\n```python\nimport jax.random as jr\nimport lineax as lx\n\nmatrix_key, vector_key = jr.split(jr.PRNGKey(0))\nmatrix = jr.normal(matrix_key, (10, 8))\nvector = jr.normal(vector_key, (10,))\noperator = lx.MatrixLinearOperator(matrix)\nsolution = lx.linear_solve(operator, vector, solver=lx.QR())\n```\n\nor Lineax can solve a problem without ever materializing a matrix, as done in this\nquadratic solve:\n\n```python\nimport jax\nimport lineax as lx\n\nkey = jax.random.PRNGKey(0)\ny = jax.random.normal(key, (10,))\n\ndef quadratic_fn(y, args):\n  return jax.numpy.sum((y - 1)**2)\n\ngradient_fn = jax.grad(quadratic_fn)\nhessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)\nsolver = lx.CG(rtol=1e-6, atol=1e-6)\nout = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)\nminimum = y - out.value\n```\n\n## 
Citation\n\nIf you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2311.17283))\n\n```bibtex\n@article{lineax2023,\n    title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},\n    author={Jason Rader and Terry Lyons and Patrick Kidger},\n    journal={\n        AI for science workshop at Neural Information Processing Systems 2023,\n        arXiv:2311.17283\n    },\n    year={2023},\n}\n```\n\n(Also consider starring the project on GitHub.)\n\n## See also: other libraries in the JAX ecosystem\n\n**Always useful**  \n[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!  \n[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.  \n\n**Deep learning**  \n[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.  \n[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).  \n[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).  \n[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.  \n\n**Scientific computing**  \n[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.  \n[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.  \n[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.  \n[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.  \n[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)  \n\n**Awesome JAX**  \n[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.  \n"
  },
  {
    "path": "benchmarks/gmres_fails_safely.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport jax.scipy as jsp\nimport lineax as lx\n\n\ngetkey = eqxi.GetKey()\n\n\ndef tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):\n    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)\n\n\njax.config.update(\"jax_enable_x64\", True)\n\n\ndef make_problem(mat_size: int, *, key):\n    mat = jr.normal(key, (mat_size, mat_size))\n    true_x = jr.normal(key, (mat_size,))\n    b = mat @ true_x\n    op = lx.MatrixLinearOperator(mat)\n    return mat, op, b, true_x\n\n\ndef benchmark_jax(mat_size: int, *, key):\n    mat, _, b, true_x = make_problem(mat_size, key=key)\n\n    solve_with_jax = ft.partial(\n        jsp.sparse.linalg.gmres, tol=1e-5, solve_method=\"batched\"\n    )\n    gmres_jit = jax.jit(solve_with_jax)\n    jax_soln, info = gmres_jit(mat, b)\n\n    # info == 0.0 implies that the solve has succeeded.\n    returned_failed = jnp.all(info != 0.0)\n    actually_failed = not tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4)\n\n    assert actually_failed\n\n    captured_failure = returned_failed & actually_failed\n    return captured_failure\n\n\ndef benchmark_lx(mat_size: int, *, key):\n    _, op, b, true_x = make_problem(mat_size, key=key)\n\n    lx_soln = lx.linear_solve(op, b, lx.GMRES(atol=1e-5, 
rtol=1e-5), throw=False)\n\n    returned_failed = jnp.all(lx_soln.result != lx.RESULTS.successful)\n    actually_failed = not tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4)\n\n    assert actually_failed\n\n    captured_failure = returned_failed & actually_failed\n    return captured_failure\n\n\nlx_failed_safely = 0\njax_failed_safely = 0\n\nfor _ in range(100):\n    key = getkey()\n    jax_captured_failure = benchmark_jax(100, key=key)\n    lx_captured_failure = benchmark_lx(100, key=key)\n\n    jax_failed_safely = jax_failed_safely + jax_captured_failure\n    lx_failed_safely = lx_failed_safely + lx_captured_failure\n\nprint(f\"JAX failed safely {jax_failed_safely} out of 100 times\")\nprint(f\"Lineax failed safely {lx_failed_safely} out of 100 times\")\n"
  },
  {
    "path": "benchmarks/lstsq_gradients.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Core JAX has some numerical issues with their lstsq gradients.\n# See https://github.com/google/jax/issues/14868\n# This demonstrates that we don't have the same issue!\n\nimport sys\n\nimport jax\nimport jax.numpy as jnp\nimport lineax as lx\n\n\nsys.path.append(\"../tests\")\nfrom helpers import finite_difference_jvp  # pyright: ignore\n\n\na_primal = (jnp.eye(3),)\na_tangent = (jnp.zeros((3, 3)),)\n\n\ndef jax_solve(a):\n    sol, _, _, _ = jnp.linalg.lstsq(a, jnp.arange(3))  # pyright: ignore\n    return sol\n\n\ndef lx_solve(a):\n    op = lx.MatrixLinearOperator(a)\n    return lx.linear_solve(op, jnp.arange(3)).value\n\n\n_, true_jvp = finite_difference_jvp(jax_solve, a_primal, a_tangent)\n_, jax_jvp = jax.jvp(jax_solve, a_primal, a_tangent)\n_, lx_jvp = jax.jvp(lx_solve, a_primal, a_tangent)\nassert jnp.isnan(jax_jvp).all()\nassert jnp.allclose(true_jvp, lx_jvp)\n"
  },
  {
    "path": "benchmarks/solver_speeds.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\nimport sys\nimport timeit\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport jax.scipy as jsp\nimport lineax as lx\n\n\nsys.path.append(\"../tests\")\nfrom helpers import construct_matrix, has_tag  # pyright: ignore[reportMissingImports]\n\n\ngetkey = eqxi.GetKey()\n\n\ndef tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):\n    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)\n\n\njax.config.update(\"jax_enable_x64\", True)\n\nif jax.config.jax_enable_x64:  # pyright: ignore\n    tol = 1e-12\nelse:\n    tol = 1e-6\n\n\ndef base_wrapper(a, b, solver):\n    op = lx.MatrixLinearOperator(\n        a,\n        (\n            lx.positive_semidefinite_tag,\n            lx.symmetric_tag,\n            lx.diagonal_tag,\n            lx.tridiagonal_tag,\n        ),\n    )\n    out = lx.linear_solve(op, b, solver, throw=False)\n    return out.value\n\n\ndef jax_svd(a, b):\n    out, _, _, _ = jnp.linalg.lstsq(a, b)  # pyright: ignore\n    return out\n\n\ndef jax_gmres(a, b):\n    out, _ = jsp.sparse.linalg.gmres(a, b, tol=tol)\n    return out\n\n\ndef jax_bicgstab(a, b):\n    out, _ = jsp.sparse.linalg.bicgstab(a, b, tol=tol)\n    return out\n\n\ndef jax_cg(a, b):\n    out, _ = jsp.sparse.linalg.cg(a, b, tol=tol)\n    return out\n\n\ndef jax_lu(matrix, vector):\n    
return jsp.linalg.lu_solve(jsp.linalg.lu_factor(matrix), vector)\n\n\ndef jax_cholesky(matrix, vector):\n    return jsp.linalg.cho_solve(jsp.linalg.cho_factor(matrix), vector)\n\n\ndef jax_tridiagonal(matrix, vector):\n    dl = jnp.append(0.0, matrix.diagonal(-1))\n    d = matrix.diagonal(0)\n    du = jnp.append(matrix.diagonal(1), 0.0)\n    return jax.lax.linalg.tridiagonal_solve(dl, d, du, vector[:, None])[:, 0]\n\n\nnamed_solvers = [\n    (\"LU\", \"LU\", lx.LU(), jax_lu, ()),\n    (\"QR\", \"SVD\", lx.QR(), jax_svd, ()),\n    (\"SVD\", \"SVD\", lx.SVD(), jax_svd, ()),\n    (\n        \"Cholesky\",\n        \"Cholesky\",\n        lx.Cholesky(),\n        jax_cholesky,\n        lx.positive_semidefinite_tag,\n    ),\n    (\"Diagonal\", \"None\", lx.Diagonal(), None, lx.diagonal_tag),\n    (\n        \"Tridiagonal\",\n        \"Tridiagonal\",\n        lx.Tridiagonal(),\n        jax_tridiagonal,\n        lx.tridiagonal_tag,\n    ),\n    (\n        \"CG\",\n        \"CG\",\n        lx.CG(atol=tol, rtol=tol, stabilise_every=None),\n        jax_cg,\n        lx.positive_semidefinite_tag,\n    ),\n    (\n        \"GMRES\",\n        \"GMRES\",\n        lx.GMRES(atol=1, rtol=1),\n        jax_gmres,\n        (),\n    ),\n    (\n        \"BiCGStab\",\n        \"BiCGStab\",\n        lx.BiCGStab(atol=tol, rtol=tol),\n        jax_bicgstab,\n        (),\n    ),\n]\n\n\ndef create_problem(solver, tags, size=3):\n    (matrix,) = construct_matrix(getkey, solver, tags, size=size)\n    true_x = jr.normal(getkey(), (size,))\n    b = matrix @ true_x\n    return matrix, true_x, b\n\n\ndef create_easy_iterative_problem(size, tags):\n    matrix = jr.normal(getkey(), (size, size)) / size + 2 * jnp.eye(size)\n    true_x = jr.normal(getkey(), (size,))\n    if has_tag(tags, lx.positive_semidefinite_tag):\n        matrix = matrix.T @ matrix\n    b = matrix @ true_x\n    return matrix, true_x, b\n\n\ndef test_solvers(vmap_size, mat_size):\n    for lx_name, jax_name, _lx_solver, jax_solver, tags 
in named_solvers:\n        lx_solver = ft.partial(base_wrapper, solver=_lx_solver)\n        if vmap_size == 1:\n            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):\n                matrix, true_x, b = create_easy_iterative_problem(mat_size, tags)\n            else:\n                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)\n        else:\n            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):\n                matrix, true_x, b = eqx.filter_vmap(\n                    create_easy_iterative_problem,\n                    axis_size=vmap_size,\n                    out_axes=eqx.if_array(0),\n                )(mat_size, tags)\n            else:\n                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)\n                _create_problem = ft.partial(create_problem, size=mat_size)\n                matrix, true_x, b = eqx.filter_vmap(\n                    _create_problem, axis_size=vmap_size, out_axes=eqx.if_array(0)\n                )(lx_solver, tags)\n\n            lx_solver = jax.vmap(lx_solver)\n            if jax_solver is not None:\n                jax_solver = jax.vmap(jax_solver)\n\n        lx_solver = jax.jit(lx_solver)\n        bench_lx = ft.partial(lx_solver, matrix, b)\n\n        if vmap_size == 1:\n            batch_msg = \"problem\"\n        else:\n            batch_msg = f\"batch of {vmap_size} problems\"\n\n        lx_soln = bench_lx()\n        if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):\n            lx_solve_time = timeit.timeit(bench_lx, number=1)\n\n            print(\n                f\"Lineax's {lx_name} solved {batch_msg} of \"\n                f\"size {mat_size} in {lx_solve_time} seconds.\"\n            )\n        else:\n            fail_time = timeit.timeit(bench_lx, number=1)\n            err = jnp.abs(lx_soln - true_x).max()\n            print(\n                f\"Lineax's {lx_name} failed to solve {batch_msg} of \"\n                f\"size {mat_size} 
with error {err} in {fail_time} seconds\"\n            )\n        if jax_solver is None:\n            print(\"JAX has no equivalent solver. \\n\")\n\n        else:\n            jax_solver = jax.jit(jax_solver)\n            bench_jax = ft.partial(jax_solver, matrix, b)\n            jax_soln = bench_jax()\n            if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):\n                jax_solve_time = timeit.timeit(bench_jax, number=1)\n                print(\n                    f\"JAX's {jax_name} solved {batch_msg} of \"\n                    f\"size {mat_size} in {jax_solve_time} seconds. \\n\"\n                )\n            else:\n                fail_time = timeit.timeit(bench_jax, number=1)\n                err = jnp.abs(jax_soln - true_x).max()\n                print(\n                    f\"JAX's {jax_name} failed to solve {batch_msg} of \"\n                    f\"size {mat_size} with error {err} in {fail_time} seconds. \\n\"\n                )\n\n\nfor vmap_size, mat_size in [(1, 50), (1000, 50)]:\n    test_solvers(vmap_size, mat_size)\n"
  },
  {
    "path": "docs/.htaccess",
    "content": "ErrorDocument 404 /jaxtyping/404.html\n"
  },
  {
    "path": "docs/_overrides/partials/source.html",
    "content": "{% import \"partials/language.html\" as lang with context %}\n<a href=\"{{ config.repo_url }}\" title=\"{{ lang.t('source.link.title') }}\" class=\"md-source\" data-md-component=\"source\">\n  <div class=\"md-source__icon md-icon\">\n    {% set icon = config.theme.icon.repo or \"fontawesome/brands/git-alt\" %}\n    {% include \".icons/\" ~ icon ~ \".svg\" %}\n  </div>\n  <div class=\"md-source__repository\">\n    {{ config.repo_name }}\n  </div>\n</a>\n<a href=\"{{ config.theme.twitter_url }}\" title=\"Go to Twitter\" class=\"md-source\">\n  <div class=\"md-source__icon md-icon\">\n    {% include \".icons/fontawesome/brands/twitter.svg\" %}\n  </div>\n</a>\n<a href=\"{{ config.theme.bluesky_url }}\" title=\"Go to Bluesky\" class=\"md-source\">\n  <div class=\"md-source__icon md-icon\">\n    {% include \"bluesky.svg\" %}\n  </div>\n  <div class=\"md-source__repository\">\n    {{ config.theme.twitter_bluesky_name }}\n  </div>\n</a>\n"
  },
  {
    "path": "docs/_static/custom_css.css",
    "content": "/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */\nhtml {\n    scroll-padding-top: 50px;\n}\n\n/* Fit the Twitter handle alongside the GitHub one in the top right. */\n\ndiv.md-header__source {\n    width: revert;\n    max-width: revert;\n}\n\na.md-source {\n    display: inline-block;\n}\n\n.md-source__repository {\n    max-width: 100%;\n}\n\n/* Emphasise sections of nav on left hand side */\n\nnav.md-nav {\n  padding-left: 5px;\n}\n\nnav.md-nav--secondary {\n    border-left: revert !important;\n}\n\n.md-nav__title {\n  font-size: 0.9rem;\n}\n\n.md-nav__item--section > .md-nav__link {\n  font-size: 0.9rem;\n}\n\n/* Indent autogenerated documentation */\n\ndiv.doc-contents {\n  padding-left: 25px;\n  border-left: 4px solid rgba(230, 230, 230);\n}\n\n/* Increase visibility of splitters \"---\" */\n\n[data-md-color-scheme=\"default\"] .md-typeset hr {\n    border-bottom-color: rgb(0, 0, 0);\n    border-bottom-width: 1pt;\n}\n\n[data-md-color-scheme=\"slate\"] .md-typeset hr {\n    border-bottom-color: rgb(230, 230, 230);\n}\n\n/* More space at the bottom of the page */\n\n.md-main__inner {\n  margin-bottom: 1.5rem;\n}\n\n/* Remove prev/next footer buttons */\n\n.md-footer__inner {\n    display: none;\n}\n\n/* Change font sizes */\n\nhtml {\n    /* Decrease font size for overall webpage\n       Down from 137.5% which is the Material default */\n    font-size: 110%;\n}\n\n.md-typeset .admonition {\n    /* Increase font size in admonitions */\n    font-size: 100% !important;\n}\n\n.md-typeset details {\n    /* Increase font size in details */\n    font-size: 100% !important;\n}\n\n.md-typeset h1 {\n    font-size: 1.6rem;\n}\n\n.md-typeset h2 {\n    font-size: 1.5rem;\n}\n\n.md-typeset h3 {\n    font-size: 1.3rem;\n}\n\n.md-typeset h4 {\n    font-size: 1.1rem;\n}\n\n.md-typeset h5 {\n    font-size: 0.9rem;\n}\n\n.md-typeset h6 {\n    font-size: 0.8rem;\n}\n\n/* Bugfix: remove the superfluous parts generated when 
doing:\n\n??? Blah\n\n    ::: library.something\n*/\n\n.md-typeset details .mkdocstrings > h4 {\n    display: none;\n}\n\n.md-typeset details .mkdocstrings > h5 {\n    display: none;\n}\n\n/* Change default colours for <a> tags */\n\n[data-md-color-scheme=\"default\"] {\n    --md-typeset-a-color: rgb(0, 189, 164) !important;\n}\n[data-md-color-scheme=\"slate\"] {\n    --md-typeset-a-color: rgb(0, 189, 164) !important;\n}\n\n/* Highlight functions, classes etc. type signatures. Really helps to make clear where\n   one item ends and another begins. */\n\n[data-md-color-scheme=\"default\"] {\n    --doc-heading-color: #DDD;\n    --doc-heading-border-color: #CCC;\n    --doc-heading-color-alt: #F0F0F0;\n}\n[data-md-color-scheme=\"slate\"] {\n    --doc-heading-color: rgb(25,25,33);\n    --doc-heading-border-color: rgb(25,25,33);\n    --doc-heading-color-alt: rgb(33,33,44);\n    --md-code-bg-color: rgb(38,38,50);\n}\n\nh4.doc-heading {\n    /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/\n    background-color: var(--doc-heading-color);\n    border: solid var(--doc-heading-border-color);\n    border-width: 1.5pt;\n    border-radius: 2pt;\n    padding: 0pt 5pt 2pt 5pt;\n}\nh5.doc-heading, h6.heading {\n    background-color: var(--doc-heading-color-alt);\n    border-radius: 2pt;\n    padding: 0pt 5pt 2pt 5pt;\n}\n"
  },
  {
    "path": "docs/_static/mathjax.js",
    "content": "window.MathJax = {\n  tex: {\n    inlineMath: [[\"\\\\(\", \"\\\\)\"]],\n    displayMath: [[\"\\\\[\", \"\\\\]\"]],\n    processEscapes: true,\n    processEnvironments: true\n  },\n  options: {\n    ignoreHtmlClass: \".*|\",\n    processHtmlClass: \"arithmatex\"\n  }\n};\n\ndocument$.subscribe(() => {\n  MathJax.typesetPromise()\n})\n"
  },
  {
    "path": "docs/api/functions.md",
    "content": "# Functions on linear operators\n\nWe define a number of functions on [linear operators](./operators.md).\n\n## Computational changes\n\nThese do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.)\n\n::: lineax.linearise\n\n---\n\n::: lineax.materialise\n\n## Extract information from the operator\n\n::: lineax.diagonal\n\n---\n\n::: lineax.tridiagonal\n\n## Test the operator to see if it exhibits a certain property\n\nNote that these do *not* inspect the values of the operator -- instead, they use typically use [tags](./tags.md). (Or in some cases, just the type of the operator: e.g. `is_diagonal(DiagonalLinearOperator(...)) == True`.)\n\n::: lineax.has_unit_diagonal\n\n---\n\n::: lineax.is_diagonal\n\n---\n\n::: lineax.is_tridiagonal\n\n---\n\n::: lineax.is_lower_triangular\n\n---\n\n::: lineax.is_upper_triangular\n\n---\n\n::: lineax.is_positive_semidefinite\n\n---\n\n::: lineax.is_negative_semidefinite\n\n---\n\n::: lineax.is_symmetric\n"
  },
  {
    "path": "docs/api/linear_solve.md",
    "content": "# linear_solve\n\nThis is the main entry point.\n\n::: lineax.linear_solve\n\n## invert\n\nA convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][].\n\n::: lineax.invert"
  },
  {
    "path": "docs/api/operators.md",
    "content": "# Linear operators\n\nWe often talk about solving a linear system $Ax = b$, where $A \\in \\mathbb{R}^{n \\times m}$ is a matrix, $b \\in \\mathbb{R}^n$ is a vector, and $x \\in \\mathbb{R}^m$ is our desired solution.\n\nThe linear operators described on this page are ways of describing the matrix $A$. The simplest is [`lineax.MatrixLinearOperator`][], which simply holds the matrix $A$ directly.\n\nMeanwhile if $A$ is diagonal, then there is also [`lineax.DiagonalLinearOperator`][]: for efficiency this only stores the diagonal of $A$.\n\nOr, perhaps we only have a function $F : \\mathbb{R}^m \\to \\mathbb{R}^n$ such that $F(x) = Ax$. Whilst we could use $F$ to materialise the whole matrix $A$ and then store it in a [`lineax.MatrixLinearOperator`][], that may be very memory intensive. Instead, we may prefer to use [`lineax.FunctionLinearOperator`][]. Many linear solvers (e.g. [`lineax.CG`][]) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $A$.\n\n??? 
abstract \"`lineax.AbstractLinearOperator`\"\n\n    ::: lineax.AbstractLinearOperator\n        options:\n            members:\n                - mv\n                - as_matrix\n                - transpose\n                - in_structure\n                - out_structure\n                - in_size\n                - out_size\n\n::: lineax.MatrixLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.DiagonalLinearOperator\n    options: \n        members: \n            - __init__\n\n---\n\n::: lineax.TridiagonalLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.PyTreeLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.JacobianLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.FunctionLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.IdentityLinearOperator\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.TaggedLinearOperator\n    options:\n        members:\n            - __init__\n"
  },
  {
    "path": "docs/api/solution.md",
    "content": "# Solution\n\n::: lineax.Solution\n    options:\n        members: []\n\n---\n\n::: lineax.RESULTS\n    options:\n        members: []\n"
  },
  {
    "path": "docs/api/solvers.md",
    "content": "# Solvers\n\nIf you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the [tags](./tags.md) page.)\n\n??? abstract \"`lineax.AbstractLinearSolver`\"\n\n    ::: lineax.AbstractLinearSolver\n        options:\n            members:\n                - init\n                - compute\n                - transpose\n                - conj\n                - assume_full_rank\n\n::: lineax.AutoLinearSolver\n    options:\n        members:\n            - __init__\n            - select_solver\n\n---\n\n::: lineax.LU\n    options:\n        members:\n            - __init__\n\n## Least squares solvers\n\nThese are capable of solving ill-posed linear problems.\n\n::: lineax.QR\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.SVD\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.Normal\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.LSMR\n    options:\n        members:\n            - __init__\n\n\n#### Diagonal\n\nIn addition to these, [`lineax.Diagonal`][] with `well_posed=False` (below) also supports ill-posed problems.\n\n## Iterative solvers\n\nThese solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. [`lineax.JacobianLinearOperator`][] or [`lineax.FunctionLinearOperator`][], which only provide matrix-vector products.\n\n!!! 
warning\n\n    Note that [`lineax.BiCGStab`][] and [`lineax.GMRES`][] may fail to converge on some (typically non-sparse) problems.\n\n::: lineax.CG\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.BiCGStab\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.GMRES\n    options:\n        members:\n            - __init__\n\n#### LSMR\n\nIn addition to these, [`lineax.LSMR`][] (above) is also an iterative method.\n\n## Structure-exploiting solvers\n\nThese require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently.\n\n::: lineax.Cholesky\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.Diagonal\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.Triangular\n    options:\n        members:\n            - __init__\n\n---\n\n::: lineax.Tridiagonal\n    options:\n        members:\n            - __init__\n\n#### CG\n\nIn addition to these, [`lineax.CG`][] also requires special structure (positive or negative definiteness).\n"
  },
  {
    "path": "docs/api/tags.md",
    "content": "# Tags\n\nLineax offers a way to \"tag\" linear operators as exhibiting certain properties, e.g. that they are positive semidefinite.\n\nIf a linear operator is known to have a particular property, then this can be used to dispatch to a more efficient implementation, e.g. when solving a linear system.\n\nGenerally speaking, tags are an *optional* tool that can be used to improve your run time and/or compile time, by statically telling the linear solvers what properties they may assume about your system. However, if misused then you may find that the wrong result is silently returned.\n\nIn this way they are analogous to flags like `scipy.linalg.solve(..., assume_a=\"pos\")`.\n\n!!! Example\n\n    ```python\n    # Some rank-2 JAX array.\n    matrix = ...\n    # Some rank-1 JAX array.\n    vector = ...\n\n    # Declare that this matrix is positive semidefinite.\n    operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)\n\n    # This tag is used to dispatch to a maximally-efficient linear solver.\n    # In this case, a Cholesky solver is used:\n    solution = lx.linear_solve(operator, vector)\n\n    # Whether operators are tagged can be checked:\n    assert lx.is_positive_semidefinite(operator)\n    ```\n\n!!! 
Warning\n\n    Be careful, only the tag is actually checked, not the actual value of the matrix:\n    ```python\n    # Not a positive semidefinite matrix\n    matrix = jax.numpy.array([[1, 2], [3, 4]])\n\n    operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)\n    lx.is_positive_semidefinite(operator)  # True\n    lx.linear_solve(operator, vector)  # Returns the wrong solution!\n    ```\n\nOf the built-in operators: [`lineax.MatrixLinearOperator`][], [`lineax.PyTreeLinearOperator`][], [`lineax.JacobianLinearOperator`][], [`lineax.FunctionLinearOperator`][], [`lineax.TaggedLinearOperator`][] directly support a `tags` argument that mark them as having certain characteristics:\n```python\noperator = lx.MatrixLinearOperator(matrix, lx.symmetric_tag)\n```\n\nYou can pass multiple tags at once:\n```python\noperator = lx.MatrixLinearOperator(matrix, (lx.symmetric_tag, lx.unit_diagonal_tag))\n```\n\nOther linear operators can be wrapped into a [`lineax.TaggedLinearOperator`][] if necessary:\n```python\noperator = lx.MatrixLinearOperator(...)\nsymmetric_operator = operator + operator.T\nlx.is_symmetric(symmetric_operator)  # False\nsymmetric_operator = lx.TaggedLinearOperator(symmetric_operator, lx.symmetric_tag)\nlx.is_symmetric(symmetric_operator)  # True\n```\n\nSome linear operators are known to exhibit certain properties by construction, and need no additional tags:\n```python\nlx.is_symmetric(lx.DiagonalLinearOperator(...))  # True\nlx.is_positive_semidefinite(lx.IdentityLinearOperator(...))  # True\n```\n\n## List of available tags\n\n::: lineax.symmetric_tag\n\nMarks that an operator is symmetric. (As a matrix, $A = A^\\intercal$.)\n\n---\n\n::: lineax.diagonal_tag\n\nMarks than an operator is diagonal. 
(As a matrix, it must have zeros in the off-diagonal entries.)\n\nFor example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Diagonal`][] as the solver.\n\n---\n\n::: lineax.unit_diagonal_tag\n\nMarks than an operator has $1$ for every diagonal element. (As a matrix $A$, then it must have $A_{ii} = 1$ for all $i$.) Note that the whole matrix need not be diagonal.\n\nFor example, [`lineax.Triangular`][] uses this to cheapen its solve.\n\n---\n\n::: lineax.lower_triangular_tag\n\nMarks that an operator is lower triangular. (As a matrix $A$, then it must have $A_{ij} = 0 for all $i < j$.) Note that the diagonal may still have nonzero entries.\n\nFor example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.\n\n---\n\n::: lineax.upper_triangular_tag\n\nMarks that an operator is upper triangular. (As a matrix $A$, then it must have $A_{ij} = 0 for all $i > j$.) Note that the diagonal may still have nonzero entries.\n\nFor example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.\n\n---\n\n::: lineax.positive_semidefinite_tag\n\nMarks than operator is positive **semidefinite**.\n\nFor example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.\n\n---\n\n::: lineax.negative_semidefinite_tag\n\nMarks than operator is negative **semidefinite**.\n\nFor example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.\n"
  },
  {
    "path": "docs/examples/classical_solve.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8d41e1dd-93da-4e81-bd4a-33e5df8915f1\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Classical solve\\n\",\n    \"\\n\",\n    \"We wish to solve the linear system $Ax = b$. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"cb3a7781-2358-40c4-82f3-e908bddeb578\",\n   \"metadata\": {\n    \"tags\": [],\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-02T05:26:05.556701Z\",\n     \"start_time\": \"2024-04-02T05:26:03.814599Z\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"A=\\n\",\n      \"[[-0.3721109   0.26423115 -0.18252768]\\n\",\n      \" [-0.7368197   0.44973662 -0.1521442 ]\\n\",\n      \" [-0.67135346 -0.5908641   0.73168886]]\\n\",\n      \"b=[ 0.17269018 -0.64765567  1.2229712 ]\\n\",\n      \"x=[-2.7321298 -8.52878   -7.7226872]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import jax.random as jr\\n\",\n    \"import lineax as lx\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"matrix = jr.normal(jr.PRNGKey(0), (3, 3))\\n\",\n    \"vector = jr.normal(jr.PRNGKey(1), (3,))\\n\",\n    \"operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"solution = lx.linear_solve(operator, vector)\\n\",\n    \"print(f\\\"A=\\\\n{matrix}\\\\nb={vector}\\\\nx={solution.value}\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n 
 }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/examples/complex_solve.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8d41e1dd-93da-4e81-bd4a-33e5df8915f1\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Complex solve\\n\",\n    \"\\n\",\n    \"We can also solve a system with complex entries. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"cb3a7781-2358-40c4-82f3-e908bddeb578\",\n   \"metadata\": {\n    \"tags\": [],\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-02T05:29:04.909894Z\",\n     \"start_time\": \"2024-04-02T05:29:04.103141Z\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"A=\\n\",\n      \"[[-1.8459436 -0.2744466j   0.02393756-0.03172905j  0.76815367-1.4444253j ]\\n\",\n      \" [-1.0467293 +0.05608991j  1.0891742 -0.03264743j  0.7513123 +0.56285536j]\\n\",\n      \" [ 0.38307396-1.0190808j   0.01203694-1.1971304j   0.19252291-0.26424018j]]\\n\",\n      \"b=[0.23162952+0.3614433j  0.05800135+1.6094692j  0.8979094 +0.16941352j]\\n\",\n      \"x=[-0.07652722-0.34397143j -0.22629777+1.0359733j   0.22135164-0.00880566j]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import jax.numpy as jnp\\n\",\n    \"import jax.random as jr\\n\",\n    \"import lineax as lx\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\\n\",\n    \"vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\\n\",\n    \"operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"solution = lx.linear_solve(operator, vector)\\n\",\n    \"print(f\\\"A=\\\\n{matrix}\\\\nb={vector}\\\\nx={solution.value}\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": 
{\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/examples/least_squares.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"44bff903-0e4d-4f3e-a75c-d3cfe8ab4dea\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Linear least squares\\n\",\n    \"\\n\",\n    \"The solution to a well-posed linear system $Ax = b$ is given by $x = A^{-1}b$. If the matrix is rectangular or not invertible, then we may generalise the notion of solution to $x = A^{\\\\dagger}b$, where $A^{\\\\dagger}$ denotes the [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse).\\n\",\n    \"\\n\",\n    \"Lineax can handle problems of this type too.\\n\",\n    \"\\n\",\n    \"!!! info\\n\",\n    \"\\n\",\n    \"    For reference: in core JAX, problems of this type are handled using [`jax.numpy.linalg.lstsq`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html#jax.numpy.linalg.lstsq).\\n\",\n    \"    \\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"## Picking a solver\\n\",\n    \"\\n\",\n    \"By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"a956c3f2-a70c-472f-9fa9-3dbc16293e1d\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"ename\": \"ValueError\",\n     \"evalue\": \"Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[0;31mValueError\\u001b[0m\\u001b[0;31m:\\u001b[0m Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. 
By default `lineax.linear_solve` assumes that the operator is square and nonsingular.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import jax.random as jr\\n\",\n    \"import lineax as lx\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"vector = jr.normal(jr.PRNGKey(1), (3,))\\n\",\n    \"\\n\",\n    \"rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))\\n\",\n    \"rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)\\n\",\n    \"lx.linear_solve(rectangular_operator, vector)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ba55c0dd-b696-497a-8b13-896c3a95d5fd\",\n   \"metadata\": {},\n   \"source\": [\n    \"Or it will happen at run time if using a rank-deficient matrix:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"f0e7ffe6-1e3d-46dc-9dbd-d5ed4c2dedf4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"ename\": \"XlaRuntimeError\",\n     \"evalue\": \"INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\\noperator was not well-posed, and that the solver does not support this.\\n\\nIf you are trying solve a linear least-squares problem then you should pass\\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\\nassumes that the operator is square and nonsingular.\\n\\nIf you *were* expecting this solver to work with this operator, then it may be because:\\n\\n(a) the operator is singular, and your code has a bug; or\\n\\n(b) the operator was nearly singular (i.e. it had a high condition number:\\n    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\\n    numerical instability issues; or\\n\\n(c) the operator is declared to exhibit a certain property (e.g. 
positive definiteness)\\n    that is does not actually satisfy.\\n\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[0;31mXlaRuntimeError\\u001b[0m\\u001b[0;31m:\\u001b[0m INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\\noperator was not well-posed, and that the solver does not support this.\\n\\nIf you are trying solve a linear least-squares problem then you should pass\\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\\nassumes that the operator is square and nonsingular.\\n\\nIf you *were* expecting this solver to work with this operator, then it may be because:\\n\\n(a) the operator is singular, and your code has a bug; or\\n\\n(b) the operator was nearly singular (i.e. it had a high condition number:\\n    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\\n    numerical instability issues; or\\n\\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\\n    that is does not actually satisfy.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)\\n\",\n    \"deficient_operator = lx.MatrixLinearOperator(deficient_matrix)\\n\",\n    \"lx.linear_solve(deficient_operator, vector)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"4b5cedab-75e5-4b52-88d9-b9d574be7e19\",\n   \"metadata\": {},\n   \"source\": [\n    \"Whilst linear least squares and pseudoinverse are a strict generalisation of linear solves and inverses (respectively), Lineax will *not* attempt to handle the ill-posed case automatically. 
This is because the algorithms for handling this case are much more computationally expensive.!\\n\",\n    \"\\n\",\n    \"If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"45abc8bf-4fcf-46be-a91a-58f4e04ac10e\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"rectangular_solution:  [-0.3214848  -0.75565964 -0.6034579  -0.01326615]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"rectangular_solution = lx.linear_solve(\\n\",\n    \"    rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)\\n\",\n    \")\\n\",\n    \"print(\\\"rectangular_solution: \\\", rectangular_solution.value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"86dc9e2f-fe2e-48c8-86ca-bc57f8137246\",\n   \"metadata\": {},\n   \"source\": [\n    \"If your matrix may be either rectangular or rank-deficient, then you can set the solver to all this case like so:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"a9a2d92c-3676-471e-bb4a-5fd3b4748fd4\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"deficient_solution:  [ 0.06046088 -1.0412765   0.8860444 ]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"deficient_solution = lx.linear_solve(\\n\",\n    \"    deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)\\n\",\n    \")\\n\",\n    \"print(\\\"deficient_solution: \\\", deficient_solution.value)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7b870311-de0f-434c-a9e7-2d8ebf9f0b38\",\n   \"metadata\": {},\n   \"source\": [\n    \"Most users will want to use [`lineax.AutoLinearSolver`][], and not 
think about the details of which algorithm is selected.\\n\",\n    \"\\n\",\n    \"If you want to pick a particular algorithm, then that can be done too. [`lineax.QR`][] is capable of handling rectangular full-rank operators, and [`lineax.SVD`][] is capable of handling rank-deficient operators. (And in fact these are the algorithms that `AutoLinearSolver` is selecting in the examples above.)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c9649746-b0ef-495b-9ea1-eb5f6ca2e7e5\",\n   \"metadata\": {},\n   \"source\": [\n    \"---\\n\",\n    \"\\n\",\n    \"## Differences from `jax.numpy.linalg.lstsq`?\\n\",\n    \"\\n\",\n    \"Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)\\n\",\n    \"\\n\",\n    \"### Speed (forward)\\n\",\n    \"\\n\",\n    \"First, in the rectangular case, then the QR algorithm is much faster than the SVD algorithm:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"d46d0c9a-47e4-439d-9beb-c9aaf47faa5d\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"JAX solution: [-0.10002219  0.09477127 -0.10846332 ... -0.08007179 -0.01216239\\n\",\n      \" -0.030862  ]\\n\",\n      \"Lineax solution: [-0.1000222   0.0947713  -0.10846333 ... 
-0.08007187 -0.01216241\\n\",\n      \" -0.03086199]\\n\",\n      \"\\n\",\n      \"JAX time: 0.011344402999384329\\n\",\n      \"Lineax time: 0.0028611960005946457\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import timeit\\n\",\n    \"\\n\",\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"matrix = jr.normal(jr.PRNGKey(0), (500, 200))\\n\",\n    \"vector = jr.normal(jr.PRNGKey(1), (500,))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"def solve_jax(matrix, vector):\\n\",\n    \"    out, *_ = jnp.linalg.lstsq(matrix, vector)\\n\",\n    \"    return out\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"def solve_lineax(matrix, vector):\\n\",\n    \"    operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"    solver = lx.QR()  # or lx.AutoLinearSolver(well_posed=None)\\n\",\n    \"    solution = lx.linear_solve(operator, vector, solver)\\n\",\n    \"    return solution.value\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"solution_jax = solve_jax(matrix, vector)\\n\",\n    \"solution_lineax = solve_lineax(matrix, vector)\\n\",\n    \"with np.printoptions(threshold=10):\\n\",\n    \"    print(\\\"JAX solution:\\\", solution_jax)\\n\",\n    \"    print(\\\"Lineax solution:\\\", solution_lineax)\\n\",\n    \"print()\\n\",\n    \"time_jax = timeit.repeat(lambda: solve_jax(matrix, vector), number=1, repeat=10)\\n\",\n    \"time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector), number=1, repeat=10)\\n\",\n    \"print(\\\"JAX time:\\\", min(time_jax))\\n\",\n    \"print(\\\"Lineax time:\\\", min(time_lineax))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"397773d7-f782-45e6-9934-11c62c741380\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Speed (gradients)\\n\",\n    \"\\n\",\n    \"Lineax also uses a slightly more efficient autodifferentiation implementation, which ensures it is faster, even when both are using the SVD 
algorithm.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"id\": \"1988d0f6-86f5-401a-9615-30cccf04d129\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"JAX gradients: [[-1.75446249e-03  2.00700224e-03 ... -3.16517282e-04 -6.08515576e-04]\\n\",\n      \" [ 1.81865180e-04  4.51280124e-04 ... -1.64618701e-04 -6.53692259e-05]\\n\",\n      \" ...\\n\",\n      \" [-7.27269216e-04  1.27710134e-03 ... -2.64510425e-04 -3.38940619e-04]\\n\",\n      \" [ 6.55723223e-03 -3.18011409e-03 ... -1.10758876e-04  1.43246143e-03]]\\n\",\n      \"Lineax gradients: [[-1.7544631e-03  2.0070139e-03 ... -3.1653541e-04 -6.0847402e-04]\\n\",\n      \" [ 1.8186278e-04  4.5128341e-04 ... -1.6459504e-04 -6.5359738e-05]\\n\",\n      \" ...\\n\",\n      \" [-7.2721508e-04  1.2771402e-03 ... -2.6450949e-04 -3.3894143e-04]\\n\",\n      \" [ 6.5572355e-03 -3.1801097e-03 ... 
-1.1071599e-04  1.4324478e-03]]\\n\",\n      \"\\n\",\n      \"JAX time: 0.016591553001489956\\n\",\n      \"Lineax time: 0.012212782999995397\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"@jax.jit\\n\",\n    \"@jax.grad\\n\",\n    \"def grad_jax(matrix):\\n\",\n    \"    out, *_ = jnp.linalg.lstsq(matrix, vector)\\n\",\n    \"    return out.sum()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"@jax.grad\\n\",\n    \"def grad_lineax(matrix):\\n\",\n    \"    operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"    solution = lx.linear_solve(operator, vector, lx.SVD())\\n\",\n    \"    return solution.value.sum()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"gradients_jax = grad_jax(matrix)\\n\",\n    \"gradients_lineax = grad_lineax(matrix)\\n\",\n    \"with np.printoptions(threshold=10, edgeitems=2):\\n\",\n    \"    print(\\\"JAX gradients:\\\", gradients_jax)\\n\",\n    \"    print(\\\"Lineax gradients:\\\", gradients_lineax)\\n\",\n    \"print()\\n\",\n    \"time_jax = timeit.repeat(lambda: grad_jax(matrix), number=1, repeat=10)\\n\",\n    \"time_lineax = timeit.repeat(lambda: grad_lineax(matrix), number=1, repeat=10)\\n\",\n    \"print(\\\"JAX time:\\\", min(time_jax))\\n\",\n    \"print(\\\"Lineax time:\\\", min(time_lineax))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"81a1da5a-3474-4613-926f-5c9d9cdcb4a7\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Correctness (gradients)\\n\",\n    \"\\n\",\n    \"Core JAX unfortunately has a bug that means it sometimes produces NaN gradients. Lineax does not:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"id\": \"66b3a08e-92d0-4d0f-a5ea-9e8e5265d259\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"JAX gradients: [[nan nan nan]\\n\",\n      \" [nan nan nan]\\n\",\n      \" [nan nan nan]]\\n\",\n      \"Lineax gradients: [[ 0. 
-1. -2.]\\n\",\n      \" [ 0. -1. -2.]\\n\",\n      \" [ 0. -1. -2.]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"@jax.jit\\n\",\n    \"@jax.grad\\n\",\n    \"def grad_jax(matrix):\\n\",\n    \"    out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))\\n\",\n    \"    return out.sum()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"@jax.grad\\n\",\n    \"def grad_lineax(matrix):\\n\",\n    \"    operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"    solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())\\n\",\n    \"    return solution.value.sum()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"print(\\\"JAX gradients:\\\", grad_jax(jnp.eye(3)))\\n\",\n    \"print(\\\"Lineax gradients:\\\", grad_lineax(jnp.eye(3)))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"py39\",\n   \"language\": \"python\",\n   \"name\": \"py39\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/examples/no_materialisation.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a7299095-8906-4867-82ef-d6b84b161366\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Using only matrix-vector operations\\n\",\n    \"\\n\",\n    \"When solving a linear system $Ax = b$, it is relatively common not to have immediate access to the full matrix $A$, but only to a function $F(x) = Ax$ computing the matrix-vector product. (We could compute $A$ from $F$, but is the matrix is large then this may be very inefficient.)\\n\",\n    \"\\n\",\n    \"**Example: Newton's method**\\n\",\n    \"\\n\",\n    \"For example, this comes up when using [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#k_variables,_k_functions). In this case, we have a function $f \\\\colon \\\\mathbb{R}^n \\\\to \\\\mathbb{R}^n$, and wish to find the $\\\\delta \\\\in \\\\mathbb{R}^n$ for which $\\\\frac{\\\\mathrm{d}f}{\\\\mathrm{d}y}(y) \\\\; \\\\delta = -f(y)$. (Where $\\\\frac{\\\\mathrm{d}f}{\\\\mathrm{d}y}(y) \\\\in \\\\mathbb{R}^{n \\\\times n}$ is a matrix: it is the Jacobian of $f$.)\\n\",\n    \"\\n\",\n    \"In this case it is possible to use forward-mode autodifferentiation to evaluate $F(x) = \\\\frac{\\\\mathrm{d}f}{\\\\mathrm{d}y}(y) \\\\; x$, without ever instantiating the whole Jacobian $\\\\frac{\\\\mathrm{d}f}{\\\\mathrm{d}y}(y)$. 
Indeed, JAX has a [Jacobian-vector product function](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) for exactly this purpose.\\n\",\n    \"```python\\n\",\n    \"f = ...\\n\",\n    \"y = ...\\n\",\n    \"\\n\",\n    \"def F(x):\\n\",\n    \"    \\\"\\\"\\\"Computes (df/dy) @ x.\\\"\\\"\\\"\\n\",\n    \"    _, out = jax.jvp(f, (y,), (x,))\\n\",\n    \"    return out\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"**Solving a linear system using only matrix-vector operations**\\n\",\n    \"\\n\",\n    \"Lineax offers [iterative solvers](../api/solvers.md#iterative-solvers), which are capable of solving a linear system knowing only its matrix-vector products.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"b221ee1f-bd6b-4cbf-b69b-ed2e388602e1\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import jax.numpy as jnp\\n\",\n    \"import lineax as lx\\n\",\n    \"from jaxtyping import Array, Float  # https://github.com/google/jaxtyping\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def f(y: Float[Array, \\\"3\\\"], args) -> Float[Array, \\\"3\\\"]:\\n\",\n    \"    y0, y1, y2 = y\\n\",\n    \"    f0 = 5 * y0 + y1**2\\n\",\n    \"    f1 = y1 - y2 + 5\\n\",\n    \"    f2 = y0 / (1 + 5 * y2**2)\\n\",\n    \"    return jnp.stack([f0, f1, f2])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"y = jnp.array([1.0, 2.0, 3.0])\\n\",\n    \"operator = lx.JacobianLinearOperator(f, y, args=None)\\n\",\n    \"vector = f(y, args=None)\\n\",\n    \"solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\\n\",\n    \"solution = lx.linear_solve(operator, vector, solver)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"87568426-35ed-404b-bf78-425a6f519218\",\n   \"metadata\": {},\n   \"source\": [\n    \"!!! 
warning\\n\",\n    \"\\n\",\n    \"    Note that iterative solvers are something of a \\\"last resort\\\", and they are not suitable for all problems.\\n\",\n    \"\\n\",\n    \"    - [CG](https://en.wikipedia.org/wiki/Conjugate_gradient_method) requires that the problem be positive or negative semidefinite.\\n\",\n    \"    - Normalised CG (this is CG applied to the \\\"normal equations\\\" $(A^\\\\top A) x = (A^\\\\top b)$; note that $A^\\\\top A$ is always positive semidefinite) squares the condition number of $A$. In practice this means it may produce low-accuracy results if used with matrices with high condition number.\\n\",\n    \"    - [BiCGStab](https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method) and [GMRES](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"py39\",\n   \"language\": \"python\",\n   \"name\": \"py39\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/examples/operators.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2fe0b1e4-35cb-4c39-b324-65253aab005a\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Manipulating linear operators\\n\",\n    \"\\n\",\n    \"Lineax offers a sophisticated system of linear operators, supporting many operations.\\n\",\n    \"\\n\",\n    \"## Arithmetic\\n\",\n    \"\\n\",\n    \"To begin with, they support arithmetic, like addition and multiplication:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"552021d3-dadf-49f3-bd17-84a18513bfcc\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import jax.random as jr\\n\",\n    \"import lineax as lx\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"np.set_printoptions(precision=3)\\n\",\n    \"\\n\",\n    \"matrix = jnp.zeros((5, 5))\\n\",\n    \"matrix = matrix.at[0, 4].set(3)  # top left corner\\n\",\n    \"sparse_operator = lx.MatrixLinearOperator(matrix)\\n\",\n    \"\\n\",\n    \"key0, key1, key = jr.split(jr.PRNGKey(0), 3)\\n\",\n    \"diag = jr.normal(key0, (5,))\\n\",\n    \"lower_diag = jr.normal(key0, (4,))\\n\",\n    \"upper_diag = jr.normal(key0, (4,))\\n\",\n    \"tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\\n\",\n    \"\\n\",\n    \"identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"a4bb9825-73cc-447e-bc4c-c3e1a121a0a3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[[-1.149  0.963  0.     0.     3.   ]\\n\",\n      \" [ 0.963 -2.007  0.155  0.     0.   ]\\n\",\n      \" [ 0.     0.155  0.988 -0.261  0.   ]\\n\",\n      \" [ 0.     0.    
-0.261  0.931  0.899]\\n\",\n      \" [ 0.     0.     0.     0.899 -0.288]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print((sparse_operator + tridiag_operator).as_matrix())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"759c78a1-eee7-40e9-be6c-ea8c97c29e95\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[[-101.149    0.963    0.       0.       0.   ]\\n\",\n      \" [   0.963 -102.007    0.155    0.       0.   ]\\n\",\n      \" [   0.       0.155  -99.012   -0.261    0.   ]\\n\",\n      \" [   0.       0.      -0.261  -99.069    0.899]\\n\",\n      \" [   0.       0.       0.       0.899 -100.288]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print((tridiag_operator - 100 * identity_operator).as_matrix())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"84412bfa-00ec-41d4-87d7-def781145a90\",\n   \"metadata\": {},\n   \"source\": [\n    \"Or they can be composed together. (I.e. matrix multiplication.)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"8081d97f-5579-464f-8780-ffaa1d9c5f95\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[[ 0.     0.     0.     0.    -3.447]\\n\",\n      \" [ 0.     0.     0.     0.     2.888]\\n\",\n      \" [ 0.     0.     0.     0.     0.   ]\\n\",\n      \" [ 0.     0.     0.     0.     0.   ]\\n\",\n      \" [ 0.     0.     0.     0.     0.   
]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print((tridiag_operator @ sparse_operator).as_matrix())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d2c2b580-616f-4abd-a732-7f4a9b13335f\",\n   \"metadata\": {},\n   \"source\": [\n    \"Or they can be transposed:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"ae0393eb-3f43-490b-9842-bb374633633a\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[[0. 0. 0. 0. 0.]\\n\",\n      \" [0. 0. 0. 0. 0.]\\n\",\n      \" [0. 0. 0. 0. 0.]\\n\",\n      \" [0. 0. 0. 0. 0.]\\n\",\n      \" [3. 0. 0. 0. 0.]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(sparse_operator.transpose().as_matrix())  # or sparse_operator.T will work\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ddbbbb0f-7983-4e35-b92d-2512c9612d19\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Different operator types\\n\",\n    \"\\n\",\n    \"Lineax has many different operator types:\\n\",\n    \"\\n\",\n    \"- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\\n\",\n    \"- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\\n\",\n    \"- Given a function $f \\\\colon \\\\mathbb{R}^n \\\\to \\\\mathbb{R}^m$ and a point $x \\\\in \\\\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\\\frac{\\\\mathrm{d}f}{\\\\mathrm{d}x}(x) \\\\in \\\\mathbb{R}^{n \\\\times m}$.\\n\",\n    \"- Given a linear function $g \\\\colon \\\\mathbb{R}^n \\\\to \\\\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. 
the unique matrix $A$ for which $g(x) = Ax$.\\n\",\n    \"- etc!\\n\",\n    \"\\n\",\n    \"See the [operators](../api/operators.md) page for details on all supported operators.\\n\",\n    \"\\n\",\n    \"As above these can be freely combined:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"75ad4480-8ce0-4a88-9c76-bc054b1a0eaf\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from jaxtyping import Array, Float  # https://github.com/google/jaxtyping\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def f(y: Float[Array, \\\"3\\\"], args) -> Float[Array, \\\"3\\\"]:\\n\",\n    \"    y0, y1, y2 = y\\n\",\n    \"    f0 = 5 * y0 + y1**2\\n\",\n    \"    f1 = y1 - y2 + 5\\n\",\n    \"    f2 = y0 / (1 + 5 * y2**2)\\n\",\n    \"    return jnp.stack([f0, f1, f2])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def g(y: Float[Array, \\\"3\\\"]) -> Float[Array, \\\"3\\\"]:\\n\",\n    \"    # Must be linear!\\n\",\n    \"    y0, y1, y2 = y\\n\",\n    \"    f0 = y0 - y2\\n\",\n    \"    f1 = 0.0\\n\",\n    \"    f2 = 5 * y1\\n\",\n    \"    return jnp.stack([f0, f1, f2])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"y = jnp.array([1.0, 2.0, 3.0])\\n\",\n    \"in_structure = jax.eval_shape(lambda: y)\\n\",\n    \"jac_operator = lx.JacobianLinearOperator(f, y, args=None)\\n\",\n    \"fn_operator = lx.FunctionLinearOperator(g, in_structure)\\n\",\n    \"identity_operator = lx.IdentityLinearOperator(in_structure)\\n\",\n    \"\\n\",\n    \"operator = jac_operator @ fn_operator + 0.9 * identity_operator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5e528057-29ff-468d-aa3d-7155dd57082d\",\n   \"metadata\": {},\n   \"source\": [\n    \"This composition does not instantiate a matrix for them by default. (This is sometimes important for efficiency when working with many operators.) 
Instead, the composition is stored as another linear operator:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"5d15150d-955f-4006-bd36-58e2e6663307\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"AddLinearOperator(\\n\",\n      \"  operator1=ComposedLinearOperator(\\n\",\n      \"    operator1=JacobianLinearOperator(...),\\n\",\n      \"    operator2=FunctionLinearOperator(...)\\n\",\n      \"  ),\\n\",\n      \"  operator2=MulLinearOperator(\\n\",\n      \"    operator=IdentityLinearOperator(...),\\n\",\n      \"    scalar=f32[]\\n\",\n      \"  )\\n\",\n      \")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import equinox as eqx  # https://github.com/patrick-kidger/equinox\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)\\n\",\n    \"eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ff7b0591-1203-4f5e-886e-399822c68a15\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"source\": [\n    \"If you want to materialise them into a matrix, then this can be done:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"3713589f-1ac4-4e08-946b-ecc3fcf6a4c3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Array([[ 5.9  ,  0.   , -5.   ],\\n\",\n       \"       [ 0.   , -4.1  ,  0.   
],\\n\",\n       \"       [ 0.022, -0.071,  0.878]], dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"operator.as_matrix()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a483517e-89d7-4e9e-ad89-1915d886c14c\",\n   \"metadata\": {},\n   \"source\": [\n    \"Which can in turn be treated as another linear operator, if desired:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"fccddc81-d50e-4abe-a354-38402e462b1f\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"MatrixLinearOperator(\\n\",\n      \"  matrix=Array([[ 5.9  ,  0.   , -5.   ],\\n\",\n      \"       [ 0.   , -4.1  ,  0.   ],\\n\",\n      \"       [ 0.022, -0.071,  0.878]], dtype=float32),\\n\",\n      \"  tags=frozenset()\\n\",\n      \")\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())\\n\",\n    \"eqx.tree_pprint(operator_fully_materialised, short_arrays=False)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"py39\",\n   \"language\": \"python\",\n   \"name\": \"py39\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/examples/structured_matrices.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e2573d62-a505-4998-8796-b0f1bc889433\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Structured matrices\\n\",\n    \"\\n\",\n    \"Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.\\n\",\n    \"\\n\",\n    \"Typically, that means using a particular operator type:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"8e275652-dd80-4a9a-b3ac-b96dc16d3334\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[[ 4.   2.   0.   0. ]\\n\",\n      \" [ 1.  -0.5 -1.   0. ]\\n\",\n      \" [ 0.   3.   7.  -5. ]\\n\",\n      \" [ 0.   0.  -0.7  1. ]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import jax.numpy as jnp\\n\",\n    \"import jax.random as jr\\n\",\n    \"import lineax as lx\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"diag = jnp.array([4.0, -0.5, 7.0, 1.0])\\n\",\n    \"lower_diag = jnp.array([1.0, 3.0, -0.7])\\n\",\n    \"upper_diag = jnp.array([2.0, -1.0, -5.0])\\n\",\n    \"\\n\",\n    \"operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\\n\",\n    \"print(operator.as_matrix())\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"ba23ecc4-bdea-4293-a138-ce77bc83082c\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"vector = jnp.array([1.0, -0.5, 2.0, 0.8])\\n\",\n    \"# Will automatically dispatch to a tridiagonal solver.\\n\",\n    \"solution = lx.linear_solve(operator, vector)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"cd58979d-b619-4ddf-9a17-12e8babae3e8\",\n   \"metadata\": {},\n   \"source\": [\n    \"If you're uncertain which solver is being dispatched to, then you can check:\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 3,\n   \"id\": \"6984f62f-75fc-4d6e-ab42-fdade471be5b\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tridiagonal()\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"default_solver = lx.AutoLinearSolver(well_posed=True)\\n\",\n    \"print(default_solver.select_solver(operator))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"164a5bd5-5d48-4b28-bcc5-d276ab49c780\",\n   \"metadata\": {},\n   \"source\": [\n    \"If you want to enforce that a particular solver is used, then it can be passed manually:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"102ada9a-0533-40cf-9bad-02918fffb6b1\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1b4ebf09-e138-43f6-973c-c9f005ffb55e\",\n   \"metadata\": {},\n   \"source\": [\n    \"Trying to use a solver with an unsupported operator will raise an error:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"d8f5bf66-53cd-4e81-a8d7-a19e86307ad3\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"ename\": \"ValueError\",\n     \"evalue\": \"`Tridiagonal` may only be used for linear solves with tridiagonal matrices\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[0;31mValueError\\u001b[0m\\u001b[0;31m:\\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))\\n\",\n    \"not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)\\n\",\n    \"solution = lx.linear_solve(not_tridiagonal_operator, vector, 
solver=lx.Tridiagonal())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"03c4c531-58fa-4b56-8b0a-6e611c8c5912\",\n   \"metadata\": {},\n   \"source\": [\n    \"---\\n\",\n    \"\\n\",\n    \"Besides using a particular operator type, the structure of the matrix can also be expressed by [adding particular tags](../api/tags.md). These tags act as a manual override mechanism, and the values of the matrix are not checked.\\n\",\n    \"\\n\",\n    \"For example, let's construct a positive definite matrix:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"b5add874-7a2c-4000-84c3-8c94a121a831\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"matrix = jr.normal(jr.PRNGKey(0), (4, 4))\\n\",\n    \"operator = lx.MatrixLinearOperator(matrix.T @ matrix)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5459b2d6-ddb9-4a37-bb51-3f5c204bab0d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Unfortunately, Lineax has no way of knowing that this matrix is positive definite. 
It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"78400416-e774-4f74-a530-e368db84af0e\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"LU()\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"solution = lx.linear_solve(operator, vector)\\n\",\n    \"print(default_solver.select_solver(operator))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e108bdff-1cf1-4751-8c9d-3baae82ca9a7\",\n   \"metadata\": {},\n   \"source\": [\n    \"But if we add a tag:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"f6dc2966-1dfa-4a3c-be6a-974926695547\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Cholesky()\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\\n\",\n    \"solution2 = lx.linear_solve(operator, vector)\\n\",\n    \"print(default_solver.select_solver(operator))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7274d17b-a7d3-45bf-9042-785ac25e2d74\",\n   \"metadata\": {},\n   \"source\": [\n    \"Then a more efficient solver can be selected. 
We can check that the solutions returned from these two approaches are equal:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"fdcde152-9ac1-4532-a174-3fc39d83d289\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[ 1.400575   -0.41042092  0.5313305   0.28422552]\\n\",\n      \"[ 1.4005749  -0.41042086  0.53133047  0.2842255 ]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(solution.value)\\n\",\n    \"print(solution2.value)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"py39\",\n   \"language\": \"python\",\n   \"name\": \"py39\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/faq.md",
    "content": "# FAQ\n\n## How does this differ from `jax.numpy.solve`, `jax.scipy.{...}` etc.?\n\nLineax offers several improvements. Most notably:\n\n- Several new solvers. For example, [`lineax.QR`][] has no counterpart in core JAX. (And it is much faster than `jax.numpy.linalg.lstsq`, which is the closest equivalent, and uses an SVD decomposition instead.)\n\n- Several new operators. For example, [`lineax.JacobianLinearOperator`][] has no counterpart in core JAX.\n\n- A consistent API. The built-in JAX operations all differ from each other slightly, and are split across `jax.numpy`, `jax.scipy`, and `jax.scipy.sparse`.\n\n- Numerically stable gradients. The existing JAX implementations will sometimes return `NaN`s!\n\n- Some faster compile times and run times in a few places.\n\nMost of these are because JAX aims to mimic the existing NumPy/SciPy APIs. (I.e. it's not JAX's fault that it doesn't take the approach that Lineax does!)\n\n## How do I represent a {lower, upper} triangular matrix?\n\nTypically: create a full matrix, with the {lower, upper} part containing your values, and the converse {upper, lower} part containing all zeros. Then use, e.g., `operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag)`.\n\nThis is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model.\n\n## What about other operations from linear algebra? (Determinants, eigenvalues, etc.)\n\nSee [`jax.numpy.linalg`](https://jax.readthedocs.io/en/latest/jax.numpy.html#module-jax.numpy.linalg) and [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg).\n\n## How do I solve multiple systems of equations (i.e. 
`AX = B`)?\n\nSolvers implemented in Lineax target single systems of linear equations (i.e., `Ax = b`). However, using `jax.vmap` or `equinox.filter_vmap`, you can solve multiple systems with minimal effort.\n\n```python\nmulti_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))\n#  or    \nmulti_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1))\n```\n"
  },
  {
    "path": "docs/index.md",
    "content": "# Getting started\n\nLineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)\n\nFeatures include:\n\n- PyTree-valued matrices and vectors;\n- General linear operators for Jacobians, transposes, etc.;\n- Efficient linear least squares (e.g. QR solvers);\n- Numerically stable gradients through linear least squares;\n- Support for structured (e.g. symmetric) matrices;\n- Improved compilation times;\n- Improved runtime of some algorithms;\n- Support for both real-valued and complex-valued inputs;\n- All the benefits of working with JAX: autodiff, autoparallism, GPU/TPU support, etc.\n\n## Installation\n\n```bash\npip install lineax\n```\n\nRequires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.\n\n## Quick example\n\nLineax can solve a least squares problem with an explicit matrix operator:\n\n```python\nimport jax.random as jr\nimport lineax as lx\n\nmatrix_key, vector_key = jr.split(jr.PRNGKey(0))\nmatrix = jr.normal(matrix_key, (10, 8))\nvector = jr.normal(vector_key, (10,))\noperator = lx.MatrixLinearOperator(matrix)\nsolution = lx.linear_solve(operator, vector, solver=lx.QR())\n```\n\nor Lineax can solve a problem without ever materializing a matrix, as done in this\nquadratic solve:\n\n```python\nimport jax\nimport lineax as lx\n\nkey = jax.random.PRNGKey(0)\ny = jax.random.normal(key, (10,))\n\ndef quadratic_fn(y, args):\n  return jax.numpy.sum((y - 1)**2)\n\ngradient_fn = jax.grad(quadratic_fn)\nhessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)\nsolver = lx.CG(rtol=1e-6, atol=1e-6)\nout = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)\nminimum = y - out.value\n```\n\n## Next steps\n\nCheck out the examples or the API reference on the left-hand bar.\n\n## See also: other libraries in the JAX 
ecosystem\n\n**Always useful**  \n[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!  \n[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.  \n\n**Deep learning**  \n[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.  \n[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).  \n[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).  \n[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.  \n\n**Scientific computing**  \n[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.  \n[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.  \n[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.  \n[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.  \n[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)  \n\n**Awesome JAX**  \n[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.  \n"
  },
  {
    "path": "lineax/__init__.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.metadata\n\nfrom . import internal as internal\nfrom ._operator import (\n    AbstractLinearOperator as AbstractLinearOperator,\n    AddLinearOperator as AddLinearOperator,\n    ComposedLinearOperator as ComposedLinearOperator,\n    conj as conj,\n    diagonal as diagonal,\n    DiagonalLinearOperator as DiagonalLinearOperator,\n    DivLinearOperator as DivLinearOperator,\n    FunctionLinearOperator as FunctionLinearOperator,\n    has_unit_diagonal as has_unit_diagonal,\n    IdentityLinearOperator as IdentityLinearOperator,\n    is_diagonal as is_diagonal,\n    is_lower_triangular as is_lower_triangular,\n    is_negative_semidefinite as is_negative_semidefinite,\n    is_positive_semidefinite as is_positive_semidefinite,\n    is_symmetric as is_symmetric,\n    is_tridiagonal as is_tridiagonal,\n    is_upper_triangular as is_upper_triangular,\n    JacobianLinearOperator as JacobianLinearOperator,\n    linearise as linearise,\n    materialise as materialise,\n    MatrixLinearOperator as MatrixLinearOperator,\n    MulLinearOperator as MulLinearOperator,\n    NegLinearOperator as NegLinearOperator,\n    PyTreeLinearOperator as PyTreeLinearOperator,\n    TaggedLinearOperator as TaggedLinearOperator,\n    TangentLinearOperator as TangentLinearOperator,\n    tridiagonal as tridiagonal,\n    TridiagonalLinearOperator as TridiagonalLinearOperator,\n)\nfrom 
._solution import RESULTS as RESULTS, Solution as Solution\nfrom ._solve import (\n    AbstractLinearSolver as AbstractLinearSolver,\n    AutoLinearSolver as AutoLinearSolver,\n    invert as invert,\n    linear_solve as linear_solve,\n)\nfrom ._solver import (\n    BiCGStab as BiCGStab,\n    CG as CG,\n    Cholesky as Cholesky,\n    Diagonal as Diagonal,\n    GMRES as GMRES,\n    LSMR as LSMR,\n    LU as LU,\n    Normal as Normal,\n    NormalCG as NormalCG,\n    QR as QR,\n    SVD as SVD,\n    Triangular as Triangular,\n    Tridiagonal as Tridiagonal,\n)\nfrom ._tags import (\n    diagonal_tag as diagonal_tag,\n    lower_triangular_tag as lower_triangular_tag,\n    negative_semidefinite_tag as negative_semidefinite_tag,\n    positive_semidefinite_tag as positive_semidefinite_tag,\n    symmetric_tag as symmetric_tag,\n    transpose_tags as transpose_tags,\n    transpose_tags_rules as transpose_tags_rules,\n    tridiagonal_tag as tridiagonal_tag,\n    unit_diagonal_tag as unit_diagonal_tag,\n    upper_triangular_tag as upper_triangular_tag,\n)\n\n\n__version__ = importlib.metadata.version(\"lineax\")\n"
  },
  {
    "path": "lineax/_custom_types.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any\n\nimport equinox.internal as eqxi\n\n\nsentinel: Any = eqxi.doc_repr(object(), \"sentinel\")\n"
  },
  {
    "path": "lineax/_misc.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom jaxtyping import Array, ArrayLike, Bool, PyTree  # pyright:ignore\n\n\ndef tree_where(\n    pred: Bool[ArrayLike, \"\"], true: PyTree[ArrayLike], false: PyTree[ArrayLike]\n) -> PyTree[Array]:\n    keep = lambda a, b: jnp.where(pred, a, b)\n    return jtu.tree_map(keep, true, false)\n\n\ndef resolve_rcond(rcond, n, m, dtype):\n    if rcond is None:\n        # This `2 *` is a heuristic: I have seen very rare failures without it, in ways\n        # that seem to depend on JAX compilation state. (E.g. 
running unrelated JAX\n        # computations beforehand, in a completely different JIT-compiled region, can\n        # result in differences in the success/failure of the solve.)\n        return 2 * jnp.finfo(dtype).eps * max(n, m)\n    else:\n        return jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)\n\n\ndef jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, jac=None):\n    if jac is None:\n        # Heuristic for which is better in each case\n        # These could probably be tuned a lot more.\n        jac_fwd = (in_size < 100) or (in_size <= 1.5 * out_size)\n    elif jac == \"fwd\":\n        jac_fwd = True\n    elif jac == \"bwd\":\n        jac_fwd = False\n    else:\n        raise ValueError(\"`jac` should either be None, 'fwd', or 'bwd'.\")\n    if jac_fwd:\n        return jax.jacfwd(fn, holomorphic=holomorphic, has_aux=has_aux)\n    else:\n        return jax.jacrev(fn, holomorphic=holomorphic, has_aux=has_aux)\n\n\ndef _asarray(dtype, x):\n    return jnp.asarray(x, dtype=dtype)\n\n\n# Work around JAX issue #15676\n_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))\n\n\n@_asarray.defjvp\ndef _asarray_jvp(dtype, x, tx):\n    (x,) = x\n    (tx,) = tx\n    return _asarray(dtype, x), _asarray(dtype, tx)\n\n\ndef default_floating_dtype():\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        return jnp.float64\n    else:\n        return jnp.float32\n\n\ndef inexact_asarray(x):\n    dtype = jnp.result_type(x)\n    if not jnp.issubdtype(jnp.result_type(x), jnp.inexact):\n        dtype = default_floating_dtype()\n    return _asarray(dtype, x)\n\n\ndef complex_to_real_dtype(dtype):\n    return jnp.finfo(dtype).dtype\n\n\ndef strip_weak_dtype(tree: PyTree) -> PyTree:\n    return jtu.tree_map(\n        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding)\n        if type(x) is jax.ShapeDtypeStruct\n        else x,\n        tree,\n    )\n\n\ndef structure_equal(x, y) -> bool:\n    x = 
strip_weak_dtype(jax.eval_shape(lambda: x))\n    y = strip_weak_dtype(jax.eval_shape(lambda: y))\n    return eqx.tree_equal(x, y) is True\n"
  },
  {
    "path": "lineax/_norm.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\nimport math\n\nimport jax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jaxtyping import Array, ArrayLike, Inexact, PyTree, Scalar\n\nfrom ._misc import complex_to_real_dtype, default_floating_dtype\n\n\ndef tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Array, \"\"]:\n    \"\"\"Compute the dot product of two pytrees of arrays with the same pytree\n    structure.\"\"\"\n    leaves1, treedef1 = jtu.tree_flatten(tree1)\n    leaves2, treedef2 = jtu.tree_flatten(tree2)\n    if treedef1 != treedef2:\n        raise ValueError(\"trees must have the same structure\")\n    assert len(leaves1) == len(leaves2)\n    dots = []\n    for leaf1, leaf2 in zip(leaves1, leaves2):\n        dots.append(\n            jnp.dot(\n                jnp.conj(leaf1).reshape(-1),\n                jnp.reshape(leaf2, -1),\n                precision=jax.lax.Precision.HIGHEST,  # pyright: ignore\n            )\n        )\n    if len(dots) == 0:\n        return jnp.array(0, default_floating_dtype())\n    else:\n        return ft.reduce(jnp.add, dots)\n\n\ndef sum_squares(x: PyTree[ArrayLike]) -> Scalar:\n    \"\"\"Computes the square of the L2 norm of a PyTree of arrays.\n\n    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes\n    `Σ_i x_i^2`\n    \"\"\"\n    return tree_dot(x, 
x).real\n\n\ndef two_norm(x: PyTree[ArrayLike]) -> Scalar:\n    \"\"\"Computes the L2 norm of a PyTree of arrays.\n\n    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes\n    `sqrt(Σ_i x_i^2)`\n    \"\"\"\n    # Wrap the `custom_jvp` into a function so that our autogenerated documentation\n    # displays the docstring correctly.\n    return _two_norm(x)\n\n\n@jax.custom_jvp\ndef _two_norm(x: PyTree[ArrayLike]) -> Scalar:\n    leaves = jtu.tree_leaves(x)\n    size = sum([jnp.size(xi) for xi in leaves])\n    if size == 1:\n        # Avoid needless squaring-and-then-rooting.\n        for leaf in leaves:\n            if jnp.size(leaf) == 1:\n                return jnp.abs(jnp.reshape(leaf, ()))\n        else:\n            assert False\n    else:\n        return jnp.sqrt(sum_squares(x))\n\n\n@_two_norm.defjvp\ndef _two_norm_jvp(x, tx):\n    (x,) = x\n    (tx,) = tx\n    out = two_norm(x)\n    # Get zero gradient, rather than NaN gradient, in these cases.\n    pred = (out == 0) | jnp.isinf(out)\n    denominator = jnp.where(pred, 1, out)\n    # We could also switch the dot and the division.\n    # This approach is a bit more expensive (more divisions), but should be more\n    # numerically stable (`x` and `denominator` should be of the same scale; `tx` is of\n    # unknown scale).\n    with jax.numpy_dtype_promotion(\"standard\"):\n        div = (x**ω / denominator).ω\n    t_out = tree_dot(div, tx).real\n    t_out = jnp.where(pred, 0, t_out)\n    return out, t_out\n\n\ndef rms_norm(x: PyTree[ArrayLike]) -> Scalar:\n    \"\"\"Compute the RMS (root-mean-squared) norm of a PyTree of arrays.\n\n    This is the same as the L2 norm, averaged by the size of the input `x`. 
Considering\n    the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes\n    `sqrt((Σ_i x_i^2)/n)`\n    \"\"\"\n    leaves = jtu.tree_leaves(x)\n    size = sum([jnp.size(xi) for xi in leaves])\n    if size == 0:\n        if len(leaves) == 0:\n            dtype = default_floating_dtype()\n        else:\n            dtype = complex_to_real_dtype(jnp.result_type(*leaves))\n        return jnp.array(0.0, dtype)\n    else:\n        return two_norm(x) / math.sqrt(size)\n\n\ndef max_norm(x: PyTree[ArrayLike]) -> Scalar:\n    \"\"\"Compute the L-infinity norm of a PyTree of arrays.\n\n    This is the largest absolute elementwise value. Considering the input `x` as a flat\n    vector `(x_1, ..., x_n)`, then this computes `max_i |x_i|`.\n    \"\"\"\n    leaves = jtu.tree_leaves(x)\n    leaf_maxes = [jnp.max(jnp.abs(xi)) for xi in leaves if jnp.size(xi) > 0]\n    if len(leaf_maxes) == 0:\n        if len(leaves) == 0:\n            dtype = default_floating_dtype()\n        else:\n            dtype = complex_to_real_dtype(jnp.result_type(*leaves))\n        return jnp.array(0.0, dtype)\n    else:\n        out = ft.reduce(jnp.maximum, leaf_maxes)\n        return _zero_grad_at_zero(out)\n\n\n@jax.custom_jvp\ndef _zero_grad_at_zero(x):\n    return x\n\n\n@_zero_grad_at_zero.defjvp\ndef _zero_grad_at_zero_jvp(primals, tangents):\n    (out,) = primals\n    (t_out,) = tangents\n    return out, jnp.where(out == 0, 0, t_out)\n"
  },
  {
    "path": "lineax/_operator.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport abc\nimport enum\nimport functools as ft\nimport math\nimport warnings\nfrom collections.abc import Callable, Iterable\nfrom typing import Any, Literal, NoReturn, TypeVar\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nimport jax\nimport jax.flatten_util as jfu\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nimport numpy as np\nfrom equinox.internal import ω\nfrom jaxtyping import (\n    Array,\n    ArrayLike,\n    Inexact,\n    PyTree,  # pyright: ignore\n    Scalar,\n    Shaped,\n)\n\nfrom ._custom_types import sentinel\nfrom ._misc import (\n    default_floating_dtype,\n    inexact_asarray,\n    jacobian,\n    strip_weak_dtype,\n)\nfrom ._tags import (\n    diagonal_tag,\n    lower_triangular_tag,\n    negative_semidefinite_tag,\n    positive_semidefinite_tag,\n    symmetric_tag,\n    transpose_tags,\n    tridiagonal_tag,\n    unit_diagonal_tag,\n    upper_triangular_tag,\n)\n\n\ndef _frozenset(x: object | Iterable[object]) -> frozenset[object]:\n    try:\n        iter_x = iter(x)  # pyright: ignore\n    except TypeError:\n        return frozenset([x])\n    else:\n        return frozenset(iter_x)\n\n\nclass AbstractLinearOperator(eqx.Module):\n    \"\"\"Abstract base class for all linear operators.\n\n    Linear operators can act between PyTrees. 
Each `AbstractLinearOperator` is thought\n    of as a linear function `X -> Y`, where each element of `X` is a PyTree of\n    floating-point JAX arrays, and each element of `Y` is a PyTree of floating-point\n    JAX arrays.\n\n    Abstract linear operators support some operations:\n    ```python\n    op1 + op2  # addition of two operators\n    op1 @ op2  # composition of two operators.\n    op1 * 3.2  # multiplication by a scalar\n    op1 / 3.2  # division by a scalar\n    ```\n    \"\"\"\n\n    def __check_init__(self):\n        if (\n            is_symmetric(self)\n            or is_positive_semidefinite(self)\n            or is_negative_semidefinite(self)\n        ):\n            # In particular, we check that dtypes match.\n            in_structure = self.in_structure()\n            out_structure = self.out_structure()\n            # `is` check to handle the possibility of a tracer.\n            if eqx.tree_equal(in_structure, out_structure) is not True:\n                raise ValueError(\n                    \"Symmetric/Hermitian matrices must have matching input and output \"\n                    f\"structures. 
Got input structure {in_structure} and output \"\n                    f\"structure {out_structure}.\"\n                )\n\n    @abc.abstractmethod\n    def mv(\n        self, vector: PyTree[Inexact[Array, \" _b\"]]\n    ) -> PyTree[Inexact[Array, \" _a\"]]:\n        \"\"\"Computes a matrix-vector product between this operator and a `vector`.\n\n        **Arguments:**\n\n        - `vector`: Should be some PyTree of floating-point arrays, whose structure\n            should match `self.in_structure()`.\n\n        **Returns:**\n\n        A PyTree of floating-point arrays, with structure that matches\n        `self.out_structure()`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def as_matrix(self) -> Inexact[Array, \"a b\"]:\n        \"\"\"Materialises this linear operator as a matrix.\n\n        Note that this can be a computationally (time and/or memory) expensive\n        operation, as many linear operators are defined implicitly, e.g. in terms of\n        their action on a vector.\n\n        **Arguments:** None.\n\n        **Returns:**\n\n        A 2-dimensional floating-point JAX array.\n        \"\"\"\n\n    @abc.abstractmethod\n    def transpose(self) -> \"AbstractLinearOperator\":\n        \"\"\"Transposes this linear operator.\n\n        This can be called as either `operator.T` or `operator.transpose()`.\n\n        **Arguments:** None.\n\n        **Returns:**\n\n        Another [`lineax.AbstractLinearOperator`][].\n        \"\"\"\n\n    @abc.abstractmethod\n    def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]:\n        \"\"\"Returns the expected input structure of this linear operator.\n\n        **Arguments:** None.\n\n        **Returns:**\n\n        A PyTree of `jax.ShapeDtypeStruct`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:\n        \"\"\"Returns the expected output structure of this linear operator.\n\n        **Arguments:** None.\n\n        **Returns:**\n\n        A PyTree of 
`jax.ShapeDtypeStruct`.\n        \"\"\"\n\n    def in_size(self) -> int:\n        \"\"\"Returns the total number of scalars in the input of this linear operator.\n\n        That is, the dimensionality of its input space.\n\n        **Arguments:** None.\n\n        **Returns:** An integer.\n        \"\"\"\n        leaves = jtu.tree_leaves(self.in_structure())\n        return sum(math.prod(leaf.shape) for leaf in leaves)  # pyright: ignore\n\n    def out_size(self) -> int:\n        \"\"\"Returns the total number of scalars in the output of this linear operator.\n\n        That is, the dimensionality of its output space.\n\n        **Arguments:** None.\n\n        **Returns:** An integer.\n        \"\"\"\n        leaves = jtu.tree_leaves(self.out_structure())\n        return sum(math.prod(leaf.shape) for leaf in leaves)  # pyright: ignore\n\n    @property\n    def T(self) -> \"AbstractLinearOperator\":\n        \"\"\"Equivalent to [`lineax.AbstractLinearOperator.transpose`][]\"\"\"\n        return self.transpose()\n\n    def __add__(self, other) -> \"AbstractLinearOperator\":\n        if not isinstance(other, AbstractLinearOperator):\n            raise ValueError(\"Can only add AbstractLinearOperators together.\")\n        return AddLinearOperator(self, other)\n\n    def __sub__(self, other) -> \"AbstractLinearOperator\":\n        if not isinstance(other, AbstractLinearOperator):\n            raise ValueError(\"Can only add AbstractLinearOperators together.\")\n        return AddLinearOperator(self, -other)\n\n    def __mul__(self, other) -> \"AbstractLinearOperator\":\n        other = jnp.asarray(other)\n        if other.shape != ():\n            raise ValueError(\"Can only multiply AbstractLinearOperators by scalars.\")\n        return MulLinearOperator(self, other)\n\n    def __rmul__(self, other) -> \"AbstractLinearOperator\":\n        return self * other\n\n    def __matmul__(self, other) -> \"AbstractLinearOperator\":\n        if not isinstance(other, 
AbstractLinearOperator):\n            raise ValueError(\"Can only compose AbstractLinearOperators together.\")\n        return ComposedLinearOperator(self, other)\n\n    def __truediv__(self, other) -> \"AbstractLinearOperator\":\n        other = jnp.asarray(other)\n        if other.shape != ():\n            raise ValueError(\"Can only divide AbstractLinearOperators by scalars.\")\n        return DivLinearOperator(self, other)\n\n    def __neg__(self) -> \"AbstractLinearOperator\":\n        return NegLinearOperator(self)\n\n\nclass MatrixLinearOperator(AbstractLinearOperator):\n    \"\"\"Wraps a 2-dimensional JAX array into a linear operator.\n\n    If the matrix has shape `(a, b)` then matrix-vector multiplication (`self.mv`) is\n    defined in the usual way: as performing a matrix-vector that accepts a vector of\n    shape `(b,)` and returns a vector of shape `(a,)`.\n    \"\"\"\n\n    matrix: Inexact[Array, \"a b\"]\n    tags: frozenset[object] = eqx.field(static=True)\n\n    def __init__(\n        self, matrix: Shaped[Array, \"a b\"], tags: object | frozenset[object] = ()\n    ):\n        \"\"\"**Arguments:**\n\n        - `matrix`: a two-dimensional JAX array. For an array with shape `(a, b)` then\n            this operator can perform matrix-vector products on a vector of shape\n            `(b,)` to return a vector of shape `(a,)`.\n        - `tags`: any tags indicating whether this matrix has any particular properties,\n            like symmetry or positive-definite-ness. 
Note that these properties are\n            unchecked and you may get incorrect values elsewhere if these tags are\n            wrong.\n        \"\"\"\n        if jnp.ndim(matrix) != 2:\n            raise ValueError(\n                \"`MatrixLinearOperator(matrix=...)` should be 2-dimensional.\"\n            )\n        if not jnp.issubdtype(matrix.dtype, jnp.inexact):\n            matrix = matrix.astype(jnp.float32)\n        self.matrix = matrix\n        self.tags = _frozenset(tags)\n\n    def mv(self, vector):\n        maybe_sparse_op = _try_sparse_materialise(self)\n        if maybe_sparse_op is not self:\n            return maybe_sparse_op.mv(vector)\n        return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST)\n\n    def as_matrix(self):\n        return self.matrix\n\n    def transpose(self):\n        if is_symmetric(self):\n            return self\n        return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))\n\n    def in_structure(self):\n        _, in_size = jnp.shape(self.matrix)\n        return jax.ShapeDtypeStruct(shape=(in_size,), dtype=self.matrix.dtype)\n\n    def out_structure(self):\n        out_size, _ = jnp.shape(self.matrix)\n        return jax.ShapeDtypeStruct(shape=(out_size,), dtype=self.matrix.dtype)\n\n\ndef _matmul(matrix: ArrayLike, vector: ArrayLike) -> Array:\n    # matrix has structure [leaf(out), leaf(in)]\n    # vector has structure [leaf(in)]\n    # return has structure [leaf(out)]\n    return jnp.tensordot(\n        matrix, vector, axes=jnp.ndim(vector), precision=lax.Precision.HIGHEST\n    )\n\n\ndef _tree_matmul(matrix: PyTree[ArrayLike], vector: PyTree[ArrayLike]) -> PyTree[Array]:\n    # matrix has structure [tree(in), leaf(out), leaf(in)]\n    # vector has structure [tree(in), leaf(in)]\n    # return has structure [leaf(out)]\n    matrix = jtu.tree_leaves(matrix)\n    vector = jtu.tree_leaves(vector)\n    assert len(matrix) == len(vector)\n    return sum([_matmul(m, v) for m, v in zip(matrix, 
vector)])\n\n\n# Needed as static fields must be hashable and eq-able, and custom pytrees might have\n# e.g. defined custom __eq__ methods.\n_T = TypeVar(\"_T\")\n_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]\n\n\ndef _inexact_structure_impl2(x):\n    if jnp.issubdtype(x.dtype, jnp.inexact):\n        return x\n    else:\n        return x.astype(default_floating_dtype())\n\n\ndef _inexact_structure_impl(x):\n    return jtu.tree_map(_inexact_structure_impl2, x)\n\n\ndef _inexact_structure(x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.ShapeDtypeStruct]:\n    return strip_weak_dtype(jax.eval_shape(_inexact_structure_impl, x))\n\n\nclass _Leaf:  # not a pytree\n    def __init__(self, value):\n        self.value = value\n\n\n# The `{input,output}_structure`s have to be static because otherwise abstract\n# evaluation rules will promote them to ShapedArrays.\nclass PyTreeLinearOperator(AbstractLinearOperator):\n    \"\"\"Represents a PyTree of floating-point JAX arrays as a linear operator.\n\n    This is basically a generalisation of [`lineax.MatrixLinearOperator`][], from\n    taking just a single array to take a PyTree-of-arrays. (And likewise from returning\n    a single array to returning a PyTree-of-arrays.)\n\n    Specifically, suppose we want this to be a linear operator `X -> Y`, for which\n    elements of `X` are PyTrees with structure `T` whose `i`th leaf is a floating-point\n    JAX array of shape `x_shape_i`, and elements of `Y` are PyTrees with structure `S`\n    whose `j`th leaf is a floating-point JAX array of shape `y_shape_j`. Then the\n    input PyTree should have structure `T`-compose-`S`, and its `(i, j)`-th leaf should\n    be a floating-point JAX array of shape `(*x_shape_i, *y_shape_j)`.\n\n    !!! 
Example\n\n        ```python\n        # Suppose `x` is a member of our input space, with the following pytree\n        # structure:\n        eqx.tree_pprint(x)  # [f32[5, 9], f32[3]]\n\n        # Suppose `y` is a member of our output space, with the following pytree\n        # structure:\n        eqx.tree_pprint(y)\n        # {\"a\": f32[1, 2]}\n\n        # then `pytree` should be a pytree with the following structure:\n        eqx.tree_pprint(pytree)  # {\"a\": [f32[1, 2, 5, 9], f32[1, 2, 3]]}\n        ```\n    \"\"\"\n\n    pytree: PyTree[Inexact[Array, \"...\"]]\n    output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)\n    tags: frozenset[object] = eqx.field(static=True)\n    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)\n\n    def __init__(\n        self,\n        pytree: PyTree[ArrayLike],\n        output_structure: PyTree[jax.ShapeDtypeStruct],\n        tags: object | frozenset[object] = (),\n    ):\n        \"\"\"**Arguments:**\n\n        - `pytree`: this should be a PyTree, with structure as specified in\n            [`lineax.PyTreeLinearOperator`][].\n        - `output_structure`: the structure of the output space. This should be a PyTree\n            of `jax.ShapeDtypeStruct`s. (The structure of the input space is then\n            automatically derived from the structure of `pytree`.)\n        - `tags`: any tags indicating whether this operator has any particular\n            properties, like symmetry or positive-definite-ness. 
Note that these\n            properties are unchecked and you may get incorrect values elsewhere if these\n            tags are wrong.\n        \"\"\"\n        output_structure = _inexact_structure(output_structure)\n        self.pytree = jtu.tree_map(inexact_asarray, pytree)\n        self.output_structure = jtu.tree_flatten(output_structure)\n        self.tags = _frozenset(tags)\n\n        # self.out_structure() has structure [tree(out)]\n        # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]\n        def get_structure(struct, subpytree):\n            # subpytree has structure [tree(in), leaf(out), leaf(in)]\n            def sub_get_structure(leaf):\n                shape = jnp.shape(leaf)  # [leaf(out), leaf(in)]\n                ndim = len(struct.shape)\n                if shape[:ndim] != struct.shape:\n                    raise ValueError(\n                        \"`pytree` and `output_structure` are not consistent\"\n                    )\n                return jax.ShapeDtypeStruct(\n                    shape=shape[ndim:], dtype=jnp.result_type(leaf)\n                )\n\n            return _Leaf(jtu.tree_map(sub_get_structure, subpytree))\n\n        if output_structure is None:\n            # Implies that len(input_structures) > 0\n            raise ValueError(\"Cannot have trivial output_structure\")\n        input_structures = jtu.tree_map(get_structure, output_structure, self.pytree)\n        input_structures = jtu.tree_leaves(input_structures)\n        input_structure = input_structures[0].value\n        for val in input_structures[1:]:\n            if eqx.tree_equal(input_structure, val.value) is not True:\n                raise ValueError(\n                    \"`pytree` does not have a consistent `input_structure`\"\n                )\n        self.input_structure = jtu.tree_flatten(input_structure)\n\n    def mv(self, vector):\n        # vector has structure [tree(in), leaf(in)]\n        # self.out_structure() has structure 
[tree(out)]\n        # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]\n        # return has structure [tree(out), leaf(out)]\n        maybe_sparse_op = _try_sparse_materialise(self)\n        if maybe_sparse_op is not self:\n            return maybe_sparse_op.mv(vector)\n\n        def matmul(_, matrix):\n            return _tree_matmul(matrix, vector)\n\n        return jtu.tree_map(matmul, self.out_structure(), self.pytree)\n\n    def as_matrix(self):\n        with jax.numpy_dtype_promotion(\"standard\"):\n            dtype = jnp.result_type(*jtu.tree_leaves(self.pytree))\n\n        def concat_in(struct, subpytree):\n            leaves = jtu.tree_leaves(subpytree)\n            assert all(leaf.shape[: struct.ndim] == struct.shape for leaf in leaves)\n            leaves = [\n                leaf.astype(dtype).reshape(\n                    struct.size, math.prod(leaf.shape[struct.ndim :])\n                )\n                for leaf in leaves\n            ]\n            return jnp.concatenate(leaves, axis=1)\n\n        matrix = jtu.tree_map(concat_in, self.out_structure(), self.pytree)\n        matrix = jtu.tree_leaves(matrix)\n        return jnp.concatenate(matrix, axis=0)\n\n    def transpose(self):\n        if is_symmetric(self):\n            return self\n\n        def _transpose(struct, subtree):\n            def _transpose_impl(leaf):\n                return jnp.moveaxis(leaf, source, dest)\n\n            source = list(range(struct.ndim))\n            dest = list(range(-struct.ndim, 0))\n            return jtu.tree_map(_transpose_impl, subtree)\n\n        pytree_transpose = jtu.tree_map(_transpose, self.out_structure(), self.pytree)\n        pytree_transpose = jtu.tree_transpose(\n            jtu.tree_structure(self.out_structure()),\n            jtu.tree_structure(self.in_structure()),\n            pytree_transpose,\n        )\n        return PyTreeLinearOperator(\n            pytree_transpose, self.in_structure(), 
transpose_tags(self.tags)\n        )\n\n    def in_structure(self):\n        leaves, treedef = self.input_structure\n        return jtu.tree_unflatten(treedef, leaves)\n\n    def out_structure(self):\n        leaves, treedef = self.output_structure\n        return jtu.tree_unflatten(treedef, leaves)\n\n\nclass DiagonalLinearOperator(AbstractLinearOperator):\n    \"\"\"A diagonal linear operator, e.g. for a diagonal matrix. Only the diagonal is\n    stored (for memory efficiency). Matrix-vector products are computed by doing a\n    pointwise diagonal * vector, rather than a full matrix @ vector (for speed).\n\n    The diagonal may also be a PyTree, rather than a 1D array. When materialising the\n    matrix, the diagonal is taken to be defined by the flattened PyTree (i.e. values\n    show up in the same order.)\n    \"\"\"\n\n    diagonal: PyTree[Inexact[Array, \"...\"]]\n\n    def __init__(self, diagonal: PyTree[ArrayLike]):\n        \"\"\"**Arguments:**\n\n        - `diagonal`: an array or PyTree defining the diagonal of the matrix.\n        \"\"\"\n        self.diagonal = jtu.tree_map(inexact_asarray, diagonal)\n\n    def mv(self, vector):\n        return (ω(self.diagonal) * ω(vector)).ω\n\n    def as_matrix(self):\n        return jnp.diag(diagonal(self))\n\n    def transpose(self):\n        return self\n\n    def in_structure(self):\n        return jax.eval_shape(lambda: self.diagonal)\n\n    def out_structure(self):\n        return jax.eval_shape(lambda: self.diagonal)\n\n\nclass _NoAuxIn(eqx.Module):\n    fn: Callable\n    args: Any\n\n    def __call__(self, x):\n        return self.fn(x, self.args)\n\n\nclass _Unwrap(eqx.Module):\n    fn: Callable\n\n    def __call__(self, x):\n        (f,) = self.fn(x)\n        return f\n\n\nclass JacobianLinearOperator(AbstractLinearOperator):\n    \"\"\"Given a function `fn: X -> Y`, and a point `x in X`, then this defines the\n    linear operator (also a function `X -> Y`) given by the Jacobian `(d(fn)/dx)(x)`.\n\n    For 
example if the inputs and outputs are just arrays, then this is equivalent to\n    `MatrixLinearOperator(jax.jacfwd(fn)(x))`.\n\n    The Jacobian is not materialised; matrix-vector products, which are in fact\n    Jacobian-vector products, are computed using autodifferentiation. By default\n    (or with `jac=\"fwd\"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to\n    `jax.jvp(fn, (x,), (v,))`. For `jac=\"bwd\"`, `jax.vjp` is combined with\n    `jax.linear_transpose`, which works even with functions\n    that only define a custom VJP (via `jax.custom_vjp`) and don't support\n    forward-mode differentiation.\n\n    See also [`lineax.materialise`][], which materialises the whole Jacobian in\n    memory.\n\n    !!! tip\n\n        For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache\n        the primal computation,  e.g. for `jac=\"fwd\"/None` it returns\n        `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`\n    \"\"\"\n\n    fn: Callable[\n        [PyTree[Inexact[Array, \"...\"]], PyTree[Any]], PyTree[Inexact[Array, \"...\"]]\n    ]\n    x: PyTree[Inexact[Array, \"...\"]]\n    args: PyTree[Any]\n    tags: frozenset[object] = eqx.field(static=True)\n    jac: Literal[\"fwd\", \"bwd\"] | None\n\n    @eqxi.doc_remove_args(\"closure_convert\")\n    def __init__(\n        self,\n        fn: Callable,\n        x: PyTree[ArrayLike],\n        args: PyTree[Any] = None,\n        tags: object | Iterable[object] = (),\n        jac: Literal[\"fwd\", \"bwd\"] | None = None,\n        closure_convert: bool = True,\n    ):\n        \"\"\"**Arguments:**\n\n        - `fn`: A function `(x, args) -> y`. 
The Jacobian `d(fn)/dx` is used as the\n            linear operator, and `args` are just any other arguments that should not be\n            differentiated.\n        - `x`: The point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`.\n        - `args`: As `x`; this is the point to evaluate `d(fn)/dx` at:\n            `(d(fn)/dx)(x, args)`.\n        - `tags`: any tags indicating whether this operator has any particular\n            properties, like symmetry or positive-definite-ness. Note that these\n            properties are unchecked and you may get incorrect values elsewhere if these\n            tags are wrong.\n        - `jac`: allows to use specific jacobian computation method. If `jac=fwd`\n           forces `jax.jacfwd` to be used, similarly `jac=bwd` mandates the use of\n           `jax.jacrev`. Otherwise, if not specified it will be chosen\n           by default according to input and output shape.\n        \"\"\"\n        if jac not in [None, \"fwd\", \"bwd\"]:\n            raise ValueError(\n                \"`jac` argument of `JacobianLinearOperator` should be either \"\n                \"`'fwd'`, `'bwd'`, or `None`.\"\n            )\n        # Flush out any closed-over values, so that we can safely pass `self`\n        # across API boundaries. (In particular, across `linear_solve_p`.)\n        # We don't use `jax.closure_convert` as that only flushes autodiffable\n        # (=floating-point) constants. 
It probably doesn't matter, but if `fn` is a\n        # PyTree capturing non-floating-point constants, we should probably continue\n        # to respect that, and keep any non-floating-point constants as part of the\n        # PyTree structure.\n        x = jtu.tree_map(inexact_asarray, x)\n        if closure_convert:\n            fn = eqx.filter_closure_convert(fn, x, args)\n        self.fn = fn\n        self.x = x\n        self.args = args\n        self.tags = _frozenset(tags)\n        self.jac = jac\n\n    def mv(self, vector):\n        fn = _NoAuxIn(self.fn, self.args)\n        if self.jac == \"fwd\" or self.jac is None:\n            _, out = jax.jvp(fn, (self.x,), (vector,))\n        elif self.jac == \"bwd\":\n            # Use VJP + linear_transpose instead of materializing full Jacobian.\n            # This works even for custom_vjp functions that don't have JVP rules.\n            _, vjp_fn = jax.vjp(fn, self.x)\n            if is_symmetric(self):\n                # For symmetric operators, J = J.T, so vjp directly gives J @ v\n                (out,) = vjp_fn(vector)\n            else:\n                # For non-symmetric, transpose the VJP to get J @ v from J.T @ v\n                transpose_vjp = jax.linear_transpose(\n                    lambda g: vjp_fn(g)[0], self.out_structure()\n                )\n                (out,) = transpose_vjp(vector)\n        else:\n            raise ValueError(\"`jac` should be either `'fwd'`, `'bwd'`, or `None`.\")\n        return out\n\n    def as_matrix(self):\n        return materialise(self).as_matrix()\n\n    def transpose(self):\n        if is_symmetric(self):\n            return self\n        fn = _NoAuxIn(self.fn, self.args)\n        # Works because vjpfn is a PyTree\n        _, vjpfn = jax.vjp(fn, self.x)\n        vjpfn = _Unwrap(vjpfn)\n        return FunctionLinearOperator(\n            vjpfn, self.out_structure(), transpose_tags(self.tags)\n        )\n\n    def in_structure(self):\n        return 
strip_weak_dtype(jax.eval_shape(lambda: self.x))\n\n    def out_structure(self):\n        fn = _NoAuxIn(self.fn, self.args)\n        return strip_weak_dtype(eqxi.cached_filter_eval_shape(fn, self.x))\n\n\n# `input_structure` must be static as with `JacobianLinearOperator`\nclass FunctionLinearOperator(AbstractLinearOperator):\n    \"\"\"Wraps a *linear* function `fn: X -> Y` into a linear operator. (So that\n    `self.mv(x)` is defined by `self.mv(x) == fn(x)`.)\n\n    See also [`lineax.materialise`][], which materialises the whole linear operator\n    in memory. (Similar to `.as_matrix()`.)\n    \"\"\"\n\n    fn: Callable[[PyTree[Inexact[Array, \"...\"]]], PyTree[Inexact[Array, \"...\"]]]\n    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)\n    tags: frozenset[object] = eqx.field(static=True)\n\n    @eqxi.doc_remove_args(\"closure_convert\")\n    def __init__(\n        self,\n        fn: Callable[[PyTree[Inexact[Array, \"...\"]]], PyTree[Inexact[Array, \"...\"]]],\n        input_structure: PyTree[jax.ShapeDtypeStruct],\n        tags: object | Iterable[object] = (),\n        closure_convert: bool = True,\n    ):\n        \"\"\"**Arguments:**\n\n        - `fn`: a linear function. Should accept a PyTree of floating-point JAX arrays,\n            and return a PyTree of floating-point JAX arrays.\n        - `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the\n            structure of the input to the function. (When later calling `self.mv(x)`\n            then this should match the structure of `x`, i.e.\n            `jax.eval_shape(lambda: x)`.)\n        - `tags`: any tags indicating whether this operator has any particular\n            properties, like symmetry or positive-definite-ness. 
Note that these\n            properties are unchecked and you may get incorrect values elsewhere if these\n            tags are wrong.\n        \"\"\"\n        # See matching comment in JacobianLinearOperator.\n        input_structure = _inexact_structure(input_structure)\n        if closure_convert:\n            fn = eqx.filter_closure_convert(fn, input_structure)\n        self.fn = fn\n        self.input_structure = jtu.tree_flatten(input_structure)\n        self.tags = _frozenset(tags)\n\n    def mv(self, vector):\n        return self.fn(vector)\n\n    def as_matrix(self):\n        return materialise(self).as_matrix()\n\n    def transpose(self):\n        if is_symmetric(self):\n            return self\n        transpose_fn = jax.linear_transpose(self.fn, self.in_structure())\n\n        def _transpose_fn(vector):\n            (out,) = transpose_fn(vector)\n            return out\n\n        # Works because transpose_fn is a PyTree\n        return FunctionLinearOperator(\n            _transpose_fn, self.out_structure(), transpose_tags(self.tags)\n        )\n\n    def in_structure(self):\n        leaves, treedef = self.input_structure\n        return jtu.tree_unflatten(treedef, leaves)\n\n    def out_structure(self):\n        return strip_weak_dtype(\n            eqxi.cached_filter_eval_shape(self.fn, self.in_structure())\n        )\n\n\n# `structure` must be static as with `JacobianLinearOperator`\nclass IdentityLinearOperator(AbstractLinearOperator):\n    \"\"\"Represents the identity transformation `X -> X`, where each `x in X` is some\n    PyTree of floating-point JAX arrays.\n    \"\"\"\n\n    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)\n    output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)\n\n    def __init__(\n        self,\n        input_structure: PyTree[jax.ShapeDtypeStruct],\n        output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel,\n    ):\n        \"\"\"**Arguments:**\n\n        - 
`input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the\n            structure of the input space. (When later calling `self.mv(x)`\n            then this should match the structure of `x`, i.e.\n            `jax.eval_shape(lambda: x)`.)\n        - `output_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the\n            structure of the output space. If not passed then this defaults to the\n            same as `input_structure`. If passed then it must have the same number of\n            elements as `input_structure`, so that the operator is square.\n        \"\"\"\n        if output_structure is sentinel:\n            output_structure = input_structure\n        input_structure = _inexact_structure(input_structure)\n        output_structure = _inexact_structure(output_structure)\n        self.input_structure = jtu.tree_flatten(input_structure)\n        self.output_structure = jtu.tree_flatten(output_structure)\n\n    def mv(self, vector):\n        if not eqx.tree_equal(\n            strip_weak_dtype(jax.eval_shape(lambda: vector)),\n            strip_weak_dtype(self.in_structure()),\n        ):\n            raise ValueError(\"Vector and operator structures do not match\")\n        elif self.input_structure == self.output_structure:\n            return vector  # fast-path for common special case\n        else:\n            # TODO(kidger): this could be done slightly more efficiently, by iterating\n            #     leaf-by-leaf.\n            leaves = jtu.tree_leaves(vector)\n            with jax.numpy_dtype_promotion(\"standard\"):\n                dtype = jnp.result_type(*leaves)\n            vector = jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])\n            out_size = self.out_size()\n            if vector.size < out_size:\n                vector = jnp.concatenate(\n                    [vector, jnp.zeros(out_size - vector.size, vector.dtype)]\n                )\n            else:\n                vector = 
vector[:out_size]\n            leaves, treedef = jtu.tree_flatten(self.out_structure())\n            sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])\n            split = jnp.split(vector, sizes)\n            assert len(split) == len(leaves)\n            with warnings.catch_warnings():\n                warnings.simplefilter(\"ignore\")  # ignore complex-to-real cast warning\n                shaped = [\n                    x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)\n                ]\n            return jtu.tree_unflatten(treedef, shaped)\n\n    def as_matrix(self):\n        leaves = jtu.tree_leaves(self.in_structure())\n        with jax.numpy_dtype_promotion(\"standard\"):\n            dtype = (\n                default_floating_dtype()\n                if len(leaves) == 0\n                else jnp.result_type(*leaves)\n            )\n        return jnp.eye(self.out_size(), self.in_size(), dtype=dtype)\n\n    def transpose(self):\n        return IdentityLinearOperator(self.out_structure(), self.in_structure())\n\n    def in_structure(self):\n        leaves, treedef = self.input_structure\n        return jtu.tree_unflatten(treedef, leaves)\n\n    def out_structure(self):\n        leaves, treedef = self.output_structure\n        return jtu.tree_unflatten(treedef, leaves)\n\n    @property\n    def tags(self):\n        return frozenset()\n\n\nclass TridiagonalLinearOperator(AbstractLinearOperator):\n    \"\"\"As [`lineax.MatrixLinearOperator`][], but for specifically a tridiagonal\n    matrix.\n    \"\"\"\n\n    diagonal: Inexact[Array, \" size\"]\n    lower_diagonal: Inexact[Array, \" size-1\"]\n    upper_diagonal: Inexact[Array, \" size-1\"]\n\n    def __init__(\n        self,\n        diagonal: Inexact[Array, \" size\"],\n        lower_diagonal: Inexact[Array, \" size-1\"],\n        upper_diagonal: Inexact[Array, \" size-1\"],\n    ):\n        \"\"\"**Arguments:**\n\n        - `diagonal`: A rank-one JAX array. 
This is the diagonal of the matrix.\n        - `lower_diagonal`: A rank-one JAX array. This is the lower diagonal of the\n            matrix.\n        - `upper_diagonal`: A rank-one JAX array. This is the upper diagonal of the\n            matrix.\n\n        If `diagonal` has shape `(a,)` then `lower_diagonal` and `upper_diagonal` should\n        both have shape `(a - 1,)`.\n        \"\"\"\n        self.diagonal = inexact_asarray(diagonal)\n        self.lower_diagonal = inexact_asarray(lower_diagonal)\n        self.upper_diagonal = inexact_asarray(upper_diagonal)\n        (size,) = self.diagonal.shape\n        if self.lower_diagonal.shape != (size - 1,):\n            raise ValueError(\"lower_diagonal and diagonal do not have consistent size\")\n        if self.upper_diagonal.shape != (size - 1,):\n            raise ValueError(\"upper_diagonal and diagonal do not have consistent size\")\n\n    def mv(self, vector):\n        a = self.upper_diagonal * vector[1:]\n        b = self.diagonal * vector\n        c = self.lower_diagonal * vector[:-1]\n        return b.at[:-1].add(a).at[1:].add(c)\n\n    def as_matrix(self):\n        (size,) = jnp.shape(self.diagonal)\n        matrix = jnp.zeros((size, size), self.diagonal.dtype)\n        arange = np.arange(size)\n        matrix = matrix.at[arange, arange].set(self.diagonal)\n        matrix = matrix.at[arange[1:], arange[:-1]].set(self.lower_diagonal)\n        matrix = matrix.at[arange[:-1], arange[1:]].set(self.upper_diagonal)\n        return matrix\n\n    def transpose(self):\n        return TridiagonalLinearOperator(\n            self.diagonal, self.upper_diagonal, self.lower_diagonal\n        )\n\n    def in_structure(self):\n        (size,) = jnp.shape(self.diagonal)\n        return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype)\n\n    def out_structure(self):\n        (size,) = jnp.shape(self.diagonal)\n        return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype)\n\n\nclass 
TaggedLinearOperator(AbstractLinearOperator):\n    \"\"\"Wraps another linear operator and specifies that it has certain tags, e.g.\n    representing symmetry.\n\n    !!! Example\n\n        ```python\n        # Some other operator.\n        operator = lx.MatrixLinearOperator(some_jax_array)\n\n        # Now symmetric! But the type system doesn't know this.\n        sym_operator = operator + operator.T\n        assert lx.is_symmetric(sym_operator) == False\n\n        # We can declare that our operator has a particular property.\n        sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)\n        assert lx.is_symmetric(sym_operator) == True\n        ```\n    \"\"\"\n\n    operator: AbstractLinearOperator\n    tags: frozenset[object] = eqx.field(static=True)\n\n    def __init__(\n        self, operator: AbstractLinearOperator, tags: object | Iterable[object]\n    ):\n        \"\"\"**Arguments:**\n\n        - `operator`: some other linear operator to wrap.\n        - `tags`: any tags indicating whether this operator has any particular\n            properties, like symmetry or positive-definite-ness. Note that these\n            properties are unchecked and you may get incorrect values elsewhere if these\n            tags are wrong.\n        \"\"\"\n        self.operator = operator\n        self.tags = _frozenset(tags)\n\n    def mv(self, vector):\n        return self.operator.mv(vector)\n\n    def as_matrix(self):\n        return self.operator.as_matrix()\n\n    def transpose(self):\n        return TaggedLinearOperator(\n            self.operator.transpose(), transpose_tags(self.tags)\n        )\n\n    def in_structure(self):\n        return self.operator.in_structure()\n\n    def out_structure(self):\n        return self.operator.out_structure()\n\n\n#\n# All operators below here are private to lineax.\n#\n\n\ndef _is_none(x):\n    return x is None\n\n\nclass TangentLinearOperator(AbstractLinearOperator):\n    \"\"\"Internal to lineax. 
Used to represent the tangent (jvp) computation with\n    respect to the linear operator in a linear solve.\n    \"\"\"\n\n    primal: AbstractLinearOperator\n    tangent: AbstractLinearOperator\n\n    def __check_init__(self):\n        assert type(self.primal) is type(self.tangent)  # noqa: E721\n\n    def mv(self, vector):\n        mv = lambda operator: operator.mv(vector)\n        out, t_out = eqx.filter_jvp(mv, (self.primal,), (self.tangent,))\n        return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none)\n\n    def as_matrix(self):\n        as_matrix = lambda operator: operator.as_matrix()\n        out, t_out = eqx.filter_jvp(as_matrix, (self.primal,), (self.tangent,))\n        return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none)\n\n    def transpose(self):\n        transpose = lambda operator: operator.transpose()\n        primal_out, tangent_out = eqx.filter_jvp(\n            transpose, (self.primal,), (self.tangent,)\n        )\n        return TangentLinearOperator(primal_out, tangent_out)\n\n    def in_structure(self):\n        return self.primal.in_structure()\n\n    def out_structure(self):\n        return self.primal.out_structure()\n\n\nclass AddLinearOperator(AbstractLinearOperator):\n    \"\"\"A linear operator formed by adding two other linear operators together.\n\n    !!! 
Example\n\n        ```python\n        x = MatrixLinearOperator(...)\n        y = MatrixLinearOperator(...)\n        assert isinstance(x + y, AddLinearOperator)\n        ```\n    \"\"\"\n\n    operator1: AbstractLinearOperator\n    operator2: AbstractLinearOperator\n\n    def __check_init__(self):\n        if self.operator1.in_structure() != self.operator2.in_structure():\n            raise ValueError(\"Incompatible linear operator structures\")\n        if self.operator1.out_structure() != self.operator2.out_structure():\n            raise ValueError(\"Incompatible linear operator structures\")\n\n    def mv(self, vector):\n        maybe_sparse_op = _try_sparse_materialise(self)\n        if maybe_sparse_op is not self:\n            return maybe_sparse_op.mv(vector)\n        mv1 = self.operator1.mv(vector)\n        mv2 = self.operator2.mv(vector)\n        return (mv1**ω + mv2**ω).ω\n\n    def as_matrix(self):\n        return self.operator1.as_matrix() + self.operator2.as_matrix()\n\n    def transpose(self):\n        return self.operator1.transpose() + self.operator2.transpose()\n\n    def in_structure(self):\n        return self.operator1.in_structure()\n\n    def out_structure(self):\n        return self.operator1.out_structure()\n\n\nclass MulLinearOperator(AbstractLinearOperator):\n    \"\"\"A linear operator formed by multiplying a linear operator by a scalar.\n\n    !!! 
Example\n\n        ```python\n        x = MatrixLinearOperator(...)\n        y = 0.5\n        assert isinstance(x * y, MulLinearOperator)\n        ```\n    \"\"\"\n\n    operator: AbstractLinearOperator\n    scalar: Scalar\n\n    def mv(self, vector):\n        return (self.operator.mv(vector) ** ω * self.scalar).ω\n\n    def as_matrix(self):\n        return self.operator.as_matrix() * self.scalar\n\n    def transpose(self):\n        return self.operator.transpose() * self.scalar\n\n    def in_structure(self):\n        return self.operator.in_structure()\n\n    def out_structure(self):\n        return self.operator.out_structure()\n\n\n# Not just `MulLinearOperator(..., -1)` for compatibility with\n# `jax_numpy_dtype_promotion=strict`.\nclass NegLinearOperator(AbstractLinearOperator):\n    \"\"\"A linear operator formed by computing the negative of a linear operator.\n\n    !!! Example\n\n        ```python\n        x = MatrixLinearOperator(...)\n        assert isinstance(-x, NegLinearOperator)\n        ```\n    \"\"\"\n\n    operator: AbstractLinearOperator\n\n    def mv(self, vector):\n        return (-(self.operator.mv(vector) ** ω)).ω\n\n    def as_matrix(self):\n        return -self.operator.as_matrix()\n\n    def transpose(self):\n        return -self.operator.transpose()\n\n    def in_structure(self):\n        return self.operator.in_structure()\n\n    def out_structure(self):\n        return self.operator.out_structure()\n\n\nclass DivLinearOperator(AbstractLinearOperator):\n    \"\"\"A linear operator formed by dividing a linear operator by a scalar.\n\n    !!! 
Example\n\n        ```python\n        x = MatrixLinearOperator(...)\n        y = 0.5\n        assert isinstance(x / y, DivLinearOperator)\n        ```\n    \"\"\"\n\n    operator: AbstractLinearOperator\n    scalar: Scalar\n\n    def mv(self, vector):\n        with jax.numpy_dtype_promotion(\"standard\"):\n            return (self.operator.mv(vector) ** ω / self.scalar).ω\n\n    def as_matrix(self):\n        return self.operator.as_matrix() / self.scalar\n\n    def transpose(self):\n        return self.operator.transpose() / self.scalar\n\n    def in_structure(self):\n        return self.operator.in_structure()\n\n    def out_structure(self):\n        return self.operator.out_structure()\n\n\nclass ComposedLinearOperator(AbstractLinearOperator):\n    \"\"\"A linear operator formed by composing (matrix-multiplying) two other linear\n    operators together.\n\n    !!! Example\n\n        ```python\n        x = MatrixLinearOperator(matrix1)\n        y = MatrixLinearOperator(matrix2)\n        composed = x @ y\n        assert isinstance(composed, ComposedLinearOperator)\n        assert jnp.allclose(composed.as_matrix(), matrix1 @ matrix2)\n        ```\n    \"\"\"\n\n    operator1: AbstractLinearOperator\n    operator2: AbstractLinearOperator\n\n    def __check_init__(self):\n        if self.operator1.in_structure() != self.operator2.out_structure():\n            raise ValueError(\"Incompatible linear operator structures\")\n\n    def mv(self, vector):\n        maybe_sparse_op = _try_sparse_materialise(self)\n        if maybe_sparse_op is not self:\n            return maybe_sparse_op.mv(vector)\n        return self.operator1.mv(self.operator2.mv(vector))\n\n    def as_matrix(self):\n        if isinstance(self.operator1, IdentityLinearOperator):\n            return self.operator2.as_matrix()\n        if isinstance(self.operator2, IdentityLinearOperator):\n            return self.operator1.as_matrix()\n        _, unravel = eqx.filter_eval_shape(\n            
jfu.ravel_pytree, self.operator1.in_structure()\n        )\n\n        def mv_flat(v):\n            out = self.operator1.mv(unravel(v))\n            return jfu.ravel_pytree(out)[0]\n\n        return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix())\n\n    def transpose(self):\n        return self.operator2.transpose() @ self.operator1.transpose()\n\n    def in_structure(self):\n        return self.operator2.in_structure()\n\n    def out_structure(self):\n        return self.operator1.out_structure()\n\n\n#\n# Operations on `AbstractLinearOperator`s.\n# These are done through `singledispatch` rather than as methods.\n#\n# If an end user ever wanted to add something analogous to\n# `diagonal: AbstractLinearOperator -> Array`\n# then of course they don't get to edit our base class and add overloads to all\n# subclasses.\n# They'd have to use `singledispatch` to get the desired behaviour. (Or maybe just\n# hardcode compatibility with only some `AbstractLinearOperator` subclasses, eurgh.)\n# So for consistency we do the same thing here, rather than adding privileged behaviour\n# for just the operations we happen to support.\n#\n# (Something something Julia something something orphan problem etc.)\n#\n\n\ndef _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoReturn:\n    msg = f\"`lineax.{name}` has not been implemented for {type(operator)}\"\n    if type(operator).__module__.startswith(\"lineax\"):\n        assert False, msg + \". Please file a bug against Lineax.\"\n    else:\n        raise NotImplementedError(msg)\n\n\n# linearise\n\n\n@ft.singledispatch\ndef linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:\n    \"\"\"Linearises a linear operator. This returns another linear operator.\n\n    Mathematically speaking this is just the identity function. 
And indeed most linear\n    operators will be returned unchanged.\n\n    For specifically [`lineax.JacobianLinearOperator`][], then this will cache the\n    primal pass, so that it does not need to be recomputed each time. That is, it uses\n    some memory to improve speed. (This is precisely the same distinction as `jax.jvp`\n    versus `jax.linearize`.)\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Another linear operator. Mathematically it performs matrix-vector products\n    (`operator.mv`) that produce the same results as the input `operator`.\n    \"\"\"\n    _default_not_implemented(\"linearise\", operator)\n\n\n@linearise.register(MatrixLinearOperator)\n@linearise.register(PyTreeLinearOperator)\n@linearise.register(FunctionLinearOperator)\n@linearise.register(IdentityLinearOperator)\n@linearise.register(DiagonalLinearOperator)\n@linearise.register(TridiagonalLinearOperator)\ndef _(operator):\n    return operator\n\n\n@linearise.register(JacobianLinearOperator)\ndef _(operator):\n    fn = _NoAuxIn(operator.fn, operator.args)\n    if operator.jac == \"bwd\":\n        # For backward mode, use VJP + linear_transpose.\n        # This works even with custom_vjp functions that don't support forward-mode AD.\n        _, vjp_fn = jax.vjp(fn, operator.x)\n        if is_symmetric(operator):\n            # For symmetric: J = J.T, so vjp directly gives J @ v\n            lin = _Unwrap(vjp_fn)\n        else:\n            # Transpose the VJP to get J @ v from J.T @ v\n            lin = _Unwrap(\n                jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())\n            )\n    else:  # \"fwd\" or None\n        _, lin = jax.linearize(fn, operator.x)\n    return FunctionLinearOperator(lin, operator.in_structure(), operator.tags)\n\n\n# materialise\n\n\n@ft.singledispatch\ndef materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:\n    \"\"\"Materialises a linear operator. 
This returns another linear operator.\n\n    Mathematically speaking this is just the identity function. And indeed most linear\n    operators will be returned unchanged.\n\n    For specifically [`lineax.JacobianLinearOperator`][] and\n    [`lineax.FunctionLinearOperator`][] then the linear operator is materialised in\n    memory. That is, it becomes defined as a matrix (or pytree of arrays), rather\n    than being defined only through its matrix-vector product\n    ([`lineax.AbstractLinearOperator.mv`][]).\n\n    Materialisation sometimes improves compile time or run time. It usually increases\n    memory usage.\n\n    For example:\n    ```python\n    large_function = ...\n    operator = lx.FunctionLinearOperator(large_function, ...)\n\n    # Option 1\n    out1 = operator.mv(vector1)  # Traces and compiles `large_function`\n    out2 = operator.mv(vector2)  # Traces and compiles `large_function` again!\n    out3 = operator.mv(vector3)  # Traces and compiles `large_function` a third time!\n    # All that compilation might lead to long compile times.\n    # If `large_function` takes a long time to run, then this might also lead to long\n    # run times.\n\n    # Option 2\n    operator = lx.materialise(operator)  # Traces and compiles `large_function` and\n                                           # stores the result as a matrix.\n    out1 = operator.mv(vector1)  # Each of these just computes a matrix-vector product\n    out2 = operator.mv(vector2)  # against the stored matrix.\n    out3 = operator.mv(vector3)  #\n    # Now, `large_function` is only compiled once, and only ran once.\n    # However, storing the matrix might take a lot of memory, and the initial\n    # computation may-or-may-not take a long time to run.\n    ```\n    Generally speaking it is worth first setting up your problem without\n    `lx.materialise`, and using it as an optional optimisation if you find that it\n    helps your particular problem.\n\n    **Arguments:**\n\n    - `operator`: a 
linear operator.\n\n    **Returns:**\n\n    Another linear operator. Mathematically it performs matrix-vector products\n    (`operator.mv`) that produce the same results as the input `operator`.\n    \"\"\"\n    _default_not_implemented(\"materialise\", operator)\n\n\ndef _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:\n    \"\"\"Try to materialise to a sparse operator.\n\n    Returns a (Tri)DiagonalLinearOperator if the operator is tagged as (tri)diagonal,\n    otherwise returns the original operator unchanged. The resulting operator\n    preserves the input/output structure of the original operator.\n    \"\"\"\n    if is_diagonal(operator):\n        diag_flat = diagonal(operator)\n        _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())\n        diag_pytree = unravel(diag_flat)\n        return DiagonalLinearOperator(diag_pytree)\n    # TridiagonalLinearOperator only supports flat in and out structures\n    if (\n        is_tridiagonal(operator)\n        and isinstance(operator.in_structure(), jax.ShapeDtypeStruct)\n        and isinstance(operator.out_structure(), jax.ShapeDtypeStruct)\n    ):\n        return TridiagonalLinearOperator(*tridiagonal(operator))\n    return operator\n\n\n@materialise.register(MatrixLinearOperator)\n@materialise.register(PyTreeLinearOperator)\ndef _(operator):\n    return _try_sparse_materialise(operator)\n\n\n@materialise.register(IdentityLinearOperator)\n@materialise.register(DiagonalLinearOperator)\n@materialise.register(TridiagonalLinearOperator)\ndef _(operator):\n    return operator\n\n\n@materialise.register(JacobianLinearOperator)\ndef _(operator):\n    maybe_sparse_op = _try_sparse_materialise(operator)\n    if maybe_sparse_op is not operator:\n        return maybe_sparse_op\n    fn = _NoAuxIn(operator.fn, operator.args)\n    jac = jacobian(\n        fn,\n        operator.in_size(),\n        operator.out_size(),\n        holomorphic=any(jnp.iscomplexobj(xi) for 
xi in jtu.tree_leaves(operator.x)),\n        jac=operator.jac,\n    )(operator.x)\n    return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags)\n\n\n@materialise.register(FunctionLinearOperator)\ndef _(operator):\n    maybe_sparse_op = _try_sparse_materialise(operator)\n    if maybe_sparse_op is not operator:\n        return maybe_sparse_op\n    flat, unravel = strip_weak_dtype(\n        eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())\n    )\n    eye = jnp.eye(flat.size, dtype=flat.dtype)\n    jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye)\n\n    def batch_unravel(x):\n        assert x.ndim > 0\n        unravel_ = unravel\n        for _ in range(x.ndim - 1):\n            unravel_ = jax.vmap(unravel_)\n        return unravel_(x)\n\n    jac = jtu.tree_map(batch_unravel, jac)\n    return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags)\n\n\n# diagonal\n\n\n@ft.singledispatch\ndef diagonal(operator: AbstractLinearOperator) -> Shaped[Array, \" size\"]:\n    \"\"\"Extracts the diagonal from a linear operator, and returns a vector.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    A rank-1 JAX array. (That is, it has shape `(a,)` for some integer `a`.)\n\n    For most operators this is just `jnp.diag(operator.as_matrix())`. Some operators\n    (e.g. [`lineax.DiagonalLinearOperator`][]) can have more efficient\n    implementations. 
If you don't know what kind of operator you might have, then this\n    function ensures that you always get the most efficient implementation.\n    \"\"\"\n    _default_not_implemented(\"diagonal\", operator)\n\n\ndef _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array:\n    \"\"\"Extract the leaf from a pytree at the given keypath.\"\"\"\n    for path, leaf in jtu.tree_leaves_with_path(pytree):\n        if path == keypath:\n            return leaf\n    raise ValueError(f\"Leaf not found at keypath {keypath}\")\n\n\n@diagonal.register(MatrixLinearOperator)\ndef _(operator):\n    return jnp.diag(operator.as_matrix())\n\n\n@diagonal.register(PyTreeLinearOperator)\ndef _(operator):\n    if is_diagonal(operator):\n\n        def extract_diag(keypath, struct, subpytree):\n            block = _leaf_from_keypath(subpytree, keypath)\n            return jnp.diag(block.reshape(struct.size, struct.size))\n\n        diags = jtu.tree_map_with_path(\n            extract_diag, operator.out_structure(), operator.pytree\n        )\n        return jnp.concatenate(jtu.tree_leaves(diags))\n    else:\n        return jnp.diag(operator.as_matrix())\n\n\n@diagonal.register(JacobianLinearOperator)\n@diagonal.register(FunctionLinearOperator)\ndef _(operator):\n    if is_diagonal(operator):\n        with jax.ensure_compile_time_eval():\n            basis = jtu.tree_map(\n                lambda s: jnp.ones(s.shape, s.dtype), operator.in_structure()\n            )\n        diag_as_pytree = operator.mv(basis)\n        diag, _ = jfu.ravel_pytree(diag_as_pytree)\n        return diag\n    return diagonal(materialise(operator))\n\n\n@diagonal.register(DiagonalLinearOperator)\ndef _(operator):\n    diagonal, _ = jfu.ravel_pytree(operator.diagonal)\n    return diagonal\n\n\n@diagonal.register(IdentityLinearOperator)\ndef _(operator):\n    return jnp.ones(operator.in_size())\n\n\n@diagonal.register(TridiagonalLinearOperator)\ndef _(operator):\n    return operator.diagonal\n\n\n# 
tridiagonal\n\n\n@ft.singledispatch\ndef tridiagonal(\n    operator: AbstractLinearOperator,\n) -> tuple[Shaped[Array, \" size\"], Shaped[Array, \" size-1\"], Shaped[Array, \" size-1\"]]:\n    \"\"\"Extracts the diagonal, lower diagonal, and upper diagonal, from a linear\n    operator. Returns three vectors.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    A 3-tuple, consisting of:\n\n    - The diagonal of the matrix, represented as a vector.\n    - The lower diagonal of the matrix, represented as a vector.\n    - The upper diagonal of the matrix, represented as a vector.\n\n    If the diagonal has shape `(a,)` then the lower and upper diagonals will have shape\n    `(a - 1,)`.\n\n    For most operators these are computed by materialising the array and then extracting\n    the relevant elements, e.g. getting the main diagonal via\n    `jnp.diag(operator.as_matrix())`. Some operators (e.g.\n    [`lineax.TridiagonalLinearOperator`][]) can have more efficient implementations.\n    If you don't know what kind of operator you might have, then this function ensures\n    that you always get the most efficient implementation.\n    \"\"\"\n    _default_not_implemented(\"tridiagonal\", operator)\n\n\n@tridiagonal.register(MatrixLinearOperator)\n@tridiagonal.register(PyTreeLinearOperator)\ndef _(operator):\n    matrix = operator.as_matrix()\n    assert matrix.ndim == 2\n    main_diagonal = jnp.diagonal(matrix, offset=0)\n    upper_diagonal = jnp.diagonal(matrix, offset=1)\n    lower_diagonal = jnp.diagonal(matrix, offset=-1)\n    return main_diagonal, lower_diagonal, upper_diagonal\n\n\n@tridiagonal.register(JacobianLinearOperator)\n@tridiagonal.register(FunctionLinearOperator)\ndef _(operator):\n    if is_tridiagonal(operator):\n        with jax.ensure_compile_time_eval():\n            flat, unravel = strip_weak_dtype(\n                eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())\n            )\n\n            basis 
= jnp.zeros((3, flat.size), dtype=flat.dtype)\n            for i in range(3):\n                basis = basis.at[i, i::3].set(1.0)\n\n            basis = jax.vmap(unravel)(basis)\n\n            coloring = jnp.arange(flat.size) % 3\n\n        compressed_as_pytree = jax.vmap(operator.mv)(basis)\n        compressed_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(\n            compressed_as_pytree\n        )\n\n        # unique_indices propagates through linear_transpose to set unique_indices=True\n        # on the scatter, allowing assignment rather than accumulation.\n        rows = jnp.arange(flat.size)\n        diag = compressed_flat.at[coloring, rows].get(\n            wrap_negative_indices=False, unique_indices=True\n        )\n        lower_diag = compressed_flat.at[coloring[:-1], rows[1:]].get(\n            wrap_negative_indices=False, unique_indices=True\n        )\n        upper_diag = compressed_flat.at[coloring[1:], rows[:-1]].get(\n            wrap_negative_indices=False, unique_indices=True\n        )\n\n        return diag, lower_diag, upper_diag\n    matrix = operator.as_matrix()\n    assert matrix.ndim == 2\n    main_diagonal = jnp.diagonal(matrix, offset=0)\n    upper_diagonal = jnp.diagonal(matrix, offset=1)\n    lower_diagonal = jnp.diagonal(matrix, offset=-1)\n    return main_diagonal, lower_diagonal, upper_diagonal\n\n\n@tridiagonal.register(DiagonalLinearOperator)\ndef _(operator):\n    diag = diagonal(operator)\n    upper_diag = jnp.zeros(diag.size - 1)\n    lower_diag = jnp.zeros(diag.size - 1)\n    return diag, lower_diag, upper_diag\n\n\n@tridiagonal.register(IdentityLinearOperator)\ndef _(operator):\n    size = operator.in_size()\n    main_diagonal = jnp.ones(size)\n    off_diagonal = jnp.zeros(size - 1)\n    return main_diagonal, off_diagonal, off_diagonal\n\n\n@tridiagonal.register(TridiagonalLinearOperator)\ndef _(operator):\n    return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal\n\n\n# 
is_symmetric\n\n\n@ft.singledispatch\ndef is_symmetric(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as symmetric.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_symmetric\", operator)\n\n\ndef _has_real_dtype(operator) -> bool:\n    \"\"\"Check if all dtypes in an operator's structure are real (not complex).\"\"\"\n    leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure()))\n    dtype = jnp.result_type(*leaves)\n    if jnp.issubdtype(dtype, jnp.complexfloating):\n        return False\n    elif jnp.issubdtype(dtype, jnp.floating):\n        return True\n    else:\n        assert False, (\n            \"Only `jnp.floating` and `jnp.complexfloating` dtypes are understood.\"\n        )\n\n\n@is_symmetric.register(MatrixLinearOperator)\n@is_symmetric.register(PyTreeLinearOperator)\n@is_symmetric.register(JacobianLinearOperator)\n@is_symmetric.register(FunctionLinearOperator)\ndef _(operator):\n    # Symmetric (A = A^T) if explicitly tagged symmetric or diagonal\n    if symmetric_tag in operator.tags or diagonal_tag in operator.tags:\n        return True\n    # PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian\n    if (\n        positive_semidefinite_tag in operator.tags\n        or negative_semidefinite_tag in operator.tags\n    ):\n        return _has_real_dtype(operator)\n    return False\n\n\n@is_symmetric.register(IdentityLinearOperator)\ndef _(operator):\n    return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True\n\n\n@is_symmetric.register(DiagonalLinearOperator)\ndef _(operator):\n    return True\n\n\n@is_symmetric.register(TridiagonalLinearOperator)\ndef _(operator):\n    return False\n\n\n# is_diagonal\n\n\n@ft.singledispatch\ndef 
is_diagonal(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as diagonal.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_diagonal\", operator)\n\n\n@is_diagonal.register(MatrixLinearOperator)\n@is_diagonal.register(PyTreeLinearOperator)\n@is_diagonal.register(JacobianLinearOperator)\n@is_diagonal.register(FunctionLinearOperator)\ndef _(operator):\n    return diagonal_tag in operator.tags or (\n        operator.in_size() == 1 and operator.out_size() == 1\n    )\n\n\n@is_diagonal.register(IdentityLinearOperator)\n@is_diagonal.register(DiagonalLinearOperator)\ndef _(operator):\n    return True\n\n\n@is_diagonal.register(TridiagonalLinearOperator)\ndef _(operator):\n    return operator.in_size() == 1\n\n\n# is_tridiagonal\n\n\n@ft.singledispatch\ndef is_tridiagonal(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as tridiagonal.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_tridiagonal\", operator)\n\n\n@is_tridiagonal.register(MatrixLinearOperator)\n@is_tridiagonal.register(PyTreeLinearOperator)\n@is_tridiagonal.register(JacobianLinearOperator)\n@is_tridiagonal.register(FunctionLinearOperator)\ndef _(operator):\n    return tridiagonal_tag in operator.tags or diagonal_tag in operator.tags\n\n\n@is_tridiagonal.register(IdentityLinearOperator)\n@is_tridiagonal.register(DiagonalLinearOperator)\n@is_tridiagonal.register(TridiagonalLinearOperator)\ndef _(operator):\n    return True\n\n\n# has_unit_diagonal\n\n\n@ft.singledispatch\ndef has_unit_diagonal(operator: 
AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as having unit diagonal.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"has_unit_diagonal\", operator)\n\n\n@has_unit_diagonal.register(MatrixLinearOperator)\n@has_unit_diagonal.register(PyTreeLinearOperator)\n@has_unit_diagonal.register(JacobianLinearOperator)\n@has_unit_diagonal.register(FunctionLinearOperator)\ndef _(operator):\n    return unit_diagonal_tag in operator.tags\n\n\n@has_unit_diagonal.register(IdentityLinearOperator)\ndef _(operator):\n    return True\n\n\n@has_unit_diagonal.register(DiagonalLinearOperator)\n@has_unit_diagonal.register(TridiagonalLinearOperator)\ndef _(operator):\n    # TODO: refine this\n    return False\n\n\n# is_lower_triangular\n\n\n@ft.singledispatch\ndef is_lower_triangular(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as lower triangular.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_lower_triangular\", operator)\n\n\n@is_lower_triangular.register(MatrixLinearOperator)\n@is_lower_triangular.register(PyTreeLinearOperator)\n@is_lower_triangular.register(JacobianLinearOperator)\n@is_lower_triangular.register(FunctionLinearOperator)\ndef _(operator):\n    return lower_triangular_tag in operator.tags\n\n\n@is_lower_triangular.register(IdentityLinearOperator)\n@is_lower_triangular.register(DiagonalLinearOperator)\ndef _(operator):\n    return True\n\n\n@is_lower_triangular.register(TridiagonalLinearOperator)\ndef _(operator):\n    return False\n\n\n# 
is_upper_triangular\n\n\n@ft.singledispatch\ndef is_upper_triangular(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as upper triangular.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_upper_triangular\", operator)\n\n\n@is_upper_triangular.register(MatrixLinearOperator)\n@is_upper_triangular.register(PyTreeLinearOperator)\n@is_upper_triangular.register(JacobianLinearOperator)\n@is_upper_triangular.register(FunctionLinearOperator)\ndef _(operator):\n    return upper_triangular_tag in operator.tags\n\n\n@is_upper_triangular.register(IdentityLinearOperator)\n@is_upper_triangular.register(DiagonalLinearOperator)\ndef _(operator):\n    return True\n\n\n@is_upper_triangular.register(TridiagonalLinearOperator)\ndef _(operator):\n    return False\n\n\n# is_positive_semidefinite\n\n\n@ft.singledispatch\ndef is_positive_semidefinite(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as positive semidefinite.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_positive_semidefinite\", operator)\n\n\n@is_positive_semidefinite.register(MatrixLinearOperator)\n@is_positive_semidefinite.register(PyTreeLinearOperator)\n@is_positive_semidefinite.register(JacobianLinearOperator)\n@is_positive_semidefinite.register(FunctionLinearOperator)\ndef _(operator):\n    return positive_semidefinite_tag in operator.tags\n\n\n@is_positive_semidefinite.register(IdentityLinearOperator)\ndef _(operator):\n    return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is 
True\n\n\n@is_positive_semidefinite.register(DiagonalLinearOperator)\n@is_positive_semidefinite.register(TridiagonalLinearOperator)\ndef _(operator):\n    # TODO: refine this\n    return False\n\n\n# is_negative_semidefinite\n\n\n@ft.singledispatch\ndef is_negative_semidefinite(operator: AbstractLinearOperator) -> bool:\n    \"\"\"Returns whether an operator is marked as negative semidefinite.\n\n    See [the documentation on linear operator tags](../api/tags.md) for more\n    information.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Either `True` or `False.`\n    \"\"\"\n    _default_not_implemented(\"is_negative_semidefinite\", operator)\n\n\n@is_negative_semidefinite.register(MatrixLinearOperator)\n@is_negative_semidefinite.register(PyTreeLinearOperator)\n@is_negative_semidefinite.register(JacobianLinearOperator)\n@is_negative_semidefinite.register(FunctionLinearOperator)\ndef _(operator):\n    return negative_semidefinite_tag in operator.tags\n\n\n@is_negative_semidefinite.register(IdentityLinearOperator)\ndef _(operator):\n    return False\n\n\n@is_negative_semidefinite.register(DiagonalLinearOperator)\n@is_negative_semidefinite.register(TridiagonalLinearOperator)\ndef _(operator):\n    # TODO: refine this\n    return False\n\n\n# ops for wrapper operators\n\n\n@linearise.register(TaggedLinearOperator)\ndef _(operator):\n    return TaggedLinearOperator(linearise(operator.operator), operator.tags)\n\n\n@materialise.register(TaggedLinearOperator)\ndef _(operator):\n    return TaggedLinearOperator(materialise(operator.operator), operator.tags)\n\n\n@diagonal.register(TaggedLinearOperator)\ndef _(operator):\n    return diagonal(operator.operator)\n\n\n@tridiagonal.register(TaggedLinearOperator)\ndef _(operator):\n    return tridiagonal(operator.operator)\n\n\nfor transform in (linearise, materialise, diagonal):\n\n    @transform.register(MulLinearOperator)\n    def _(operator, transform=transform):\n        return 
transform(operator.operator) * operator.scalar\n\n    @transform.register(NegLinearOperator)  # pyright: ignore\n    def _(operator, transform=transform):\n        return -transform(operator.operator)\n\n    @transform.register(DivLinearOperator)\n    def _(operator, transform=transform):\n        return transform(operator.operator) / operator.scalar\n\n\nfor transform in (linearise, diagonal):\n\n    @transform.register(AddLinearOperator)  # pyright: ignore\n    def _(operator, transform=transform):\n        return transform(operator.operator1) + transform(operator.operator2)  # pyright: ignore\n\n\n@materialise.register(AddLinearOperator)\ndef _(operator):\n    maybe_sparse_op = _try_sparse_materialise(operator)\n    if maybe_sparse_op is not operator:\n        return maybe_sparse_op\n    return materialise(operator.operator1) + materialise(operator.operator2)\n\n\n@linearise.register(TangentLinearOperator)\ndef _(operator):\n    primal_out, tangent_out = eqx.filter_jvp(\n        linearise, (operator.primal,), (operator.tangent,)\n    )\n    return TangentLinearOperator(primal_out, tangent_out)\n\n\n@materialise.register(TangentLinearOperator)\ndef _(operator):\n    primal_out, tangent_out = eqx.filter_jvp(\n        materialise, (operator.primal,), (operator.tangent,)\n    )\n    return TangentLinearOperator(primal_out, tangent_out)\n\n\n@diagonal.register(TangentLinearOperator)\ndef _(operator):\n    # Should be unreachable: TangentLinearOperator is used for a narrow set of\n    # operations only (mv; transpose) inside the JVP rule linear_solve_p.\n    raise NotImplementedError(\n        \"Please open a GitHub issue: https://github.com/google/lineax\"\n    )\n\n\n@tridiagonal.register(TangentLinearOperator)\ndef _(operator):\n    # Should be unreachable: TangentLinearOperator is used for a narrow set of\n    # operations only (mv; transpose) inside the JVP rule linear_solve_p.\n    raise NotImplementedError(\n        \"Please open a GitHub issue: 
https://github.com/google/lineax\"\n    )\n\n\n@tridiagonal.register(AddLinearOperator)\ndef _(operator):\n    (diag1, lower1, upper1) = tridiagonal(operator.operator1)\n    (diag2, lower2, upper2) = tridiagonal(operator.operator2)\n    return (diag1 + diag2, lower1 + lower2, upper1 + upper2)\n\n\n@tridiagonal.register(MulLinearOperator)\ndef _(operator):\n    (diag, lower, upper) = tridiagonal(operator.operator)\n    return (diag * operator.scalar, lower * operator.scalar, upper * operator.scalar)\n\n\n@tridiagonal.register(NegLinearOperator)\ndef _(operator):\n    (diag, lower, upper) = tridiagonal(operator.operator)\n    return (-diag, -lower, -upper)\n\n\n@tridiagonal.register(DivLinearOperator)\ndef _(operator):\n    (diag, lower, upper) = tridiagonal(operator.operator)\n    return (diag / operator.scalar, lower / operator.scalar, upper / operator.scalar)\n\n\n@linearise.register(ComposedLinearOperator)\ndef _(operator):\n    return linearise(operator.operator1) @ linearise(operator.operator2)\n\n\n@materialise.register(ComposedLinearOperator)\ndef _(operator):\n    if isinstance(operator.operator1, IdentityLinearOperator):\n        return materialise(operator.operator2)\n    if isinstance(operator.operator2, IdentityLinearOperator):\n        return materialise(operator.operator1)\n    maybe_sparse_op = _try_sparse_materialise(operator)\n    if maybe_sparse_op is not operator:\n        return maybe_sparse_op\n    return materialise(operator.operator1) @ materialise(operator.operator2)\n\n\n@diagonal.register(ComposedLinearOperator)\ndef _(operator):\n    if is_diagonal(operator.operator1) and is_diagonal(operator.operator2):\n        return diagonal(operator.operator1) * diagonal(operator.operator2)\n    return jnp.diag(operator.as_matrix())\n\n\n@tridiagonal.register(ComposedLinearOperator)\ndef _(operator):\n    if is_diagonal(operator.operator1) and is_tridiagonal(operator.operator2):\n        d = diagonal(operator.operator1)\n        main, lower, upper = 
tridiagonal(operator.operator2)\n        # D @ T scales rows: row i multiplied by d[i]\n        return d * main, d[1:] * lower, d[:-1] * upper\n    if is_diagonal(operator.operator2) and is_tridiagonal(operator.operator1):\n        d = diagonal(operator.operator2)\n        main, lower, upper = tridiagonal(operator.operator1)\n        # T @ D scales columns: column j multiplied by d[j]\n        return d * main, d[:-1] * lower, d[1:] * upper\n    matrix = operator.as_matrix()\n    assert matrix.ndim == 2\n    main_diagonal = jnp.diagonal(matrix, offset=0)\n    upper_diagonal = jnp.diagonal(matrix, offset=1)\n    lower_diagonal = jnp.diagonal(matrix, offset=-1)\n    return main_diagonal, lower_diagonal, upper_diagonal\n\n\nfor check in (\n    is_symmetric,\n    is_diagonal,\n    has_unit_diagonal,\n    is_lower_triangular,\n    is_upper_triangular,\n    is_tridiagonal,\n    is_positive_semidefinite,\n    is_negative_semidefinite,\n):\n\n    @check.register(TangentLinearOperator)\n    def _(operator, check=check):\n        return check(operator.primal)\n\n\n# Scaling/negating preserves these structural properties\nfor check in (\n    is_symmetric,\n    is_diagonal,\n    is_lower_triangular,\n    is_upper_triangular,\n    is_tridiagonal,\n):\n\n    @check.register(MulLinearOperator)\n    @check.register(NegLinearOperator)\n    @check.register(DivLinearOperator)\n    def _(operator, check=check):\n        return check(operator.operator)\n\n\n# has_unit_diagonal is NOT preserved by scaling or negation\n@has_unit_diagonal.register(MulLinearOperator)\n@has_unit_diagonal.register(NegLinearOperator)\n@has_unit_diagonal.register(DivLinearOperator)\ndef _(operator):\n    return False\n\n\nclass _ScalarSign(enum.Enum):\n    positive = enum.auto()\n    negative = enum.auto()\n    zero = enum.auto()\n    unknown = enum.auto()\n\n\ndef _scalar_sign(scalar) -> _ScalarSign:\n    \"\"\"Returns the sign of a scalar, or unknown for JAX tracers.\"\"\"\n    if isinstance(scalar, (int, 
float, np.ndarray, np.generic)):\n        scalar = float(scalar)\n        if scalar > 0:\n            return _ScalarSign.positive\n        elif scalar < 0:\n            return _ScalarSign.negative\n        else:\n            return _ScalarSign.zero\n    else:\n        return _ScalarSign.unknown\n\n\n# PSD/NSD for MulLinearOperator: depends on sign of scalar\n# Zero scalar gives zero matrix which is both PSD and NSD\n@is_positive_semidefinite.register(MulLinearOperator)\ndef _(operator):\n    sign = _scalar_sign(operator.scalar)\n    if sign is _ScalarSign.positive:\n        return is_positive_semidefinite(operator.operator)\n    elif sign is _ScalarSign.negative:\n        return is_negative_semidefinite(operator.operator)\n    elif sign is _ScalarSign.zero:\n        return True  # zero matrix is PSD\n    return False\n\n\n@is_negative_semidefinite.register(MulLinearOperator)\ndef _(operator):\n    sign = _scalar_sign(operator.scalar)\n    if sign is _ScalarSign.positive:\n        return is_negative_semidefinite(operator.operator)\n    elif sign is _ScalarSign.negative:\n        return is_positive_semidefinite(operator.operator)\n    elif sign is _ScalarSign.zero:\n        return True  # zero matrix is NSD\n    return False\n\n\n# PSD/NSD for DivLinearOperator: depends on sign of scalar\n# Zero scalar is division by zero - return False (conservative)\n@is_positive_semidefinite.register(DivLinearOperator)\ndef _(operator):\n    sign = _scalar_sign(operator.scalar)\n    if sign is _ScalarSign.positive:\n        return is_positive_semidefinite(operator.operator)\n    elif sign is _ScalarSign.negative:\n        return is_negative_semidefinite(operator.operator)\n    return False\n\n\n@is_negative_semidefinite.register(DivLinearOperator)\ndef _(operator):\n    sign = _scalar_sign(operator.scalar)\n    if sign is _ScalarSign.positive:\n        return is_negative_semidefinite(operator.operator)\n    elif sign is _ScalarSign.negative:\n        return 
is_positive_semidefinite(operator.operator)\n    return False\n\n\n# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD\n@is_positive_semidefinite.register(NegLinearOperator)\ndef _(operator):\n    return is_negative_semidefinite(operator.operator)\n\n\n@is_negative_semidefinite.register(NegLinearOperator)\ndef _(operator):\n    return is_positive_semidefinite(operator.operator)\n\n\nfor check, tag in (\n    (is_symmetric, symmetric_tag),\n    (is_diagonal, diagonal_tag),\n    (has_unit_diagonal, unit_diagonal_tag),\n    (is_lower_triangular, lower_triangular_tag),\n    (is_upper_triangular, upper_triangular_tag),\n    (is_positive_semidefinite, positive_semidefinite_tag),\n    (is_negative_semidefinite, negative_semidefinite_tag),\n    (is_tridiagonal, tridiagonal_tag),\n):\n\n    @check.register(TaggedLinearOperator)\n    def _(operator, check=check, tag=tag):\n        return (tag in operator.tags) or check(operator.operator)\n\n\nfor check in (\n    is_symmetric,\n    is_diagonal,\n    is_lower_triangular,\n    is_upper_triangular,\n    is_positive_semidefinite,\n    is_negative_semidefinite,\n    is_tridiagonal,\n):\n\n    @check.register(AddLinearOperator)\n    def _(operator, check=check):\n        return check(operator.operator1) and check(operator.operator2)\n\n\n@has_unit_diagonal.register(AddLinearOperator)\ndef _(operator):\n    return False\n\n\n# These properties ARE preserved under composition\nfor check in (\n    is_diagonal,\n    is_lower_triangular,\n    is_upper_triangular,\n):\n\n    @check.register(ComposedLinearOperator)\n    def _(operator, check=check):\n        return check(operator.operator1) and check(operator.operator2)\n\n\n# is_symmetric: A@B is symmetric only if A and B commute. 
Diagonal matrices commute.\n@is_symmetric.register(ComposedLinearOperator)\ndef _(operator):\n    return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)\n\n\n# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but\n# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal\n@is_tridiagonal.register(ComposedLinearOperator)\ndef _(operator):\n    if is_diagonal(operator.operator1):\n        return is_tridiagonal(operator.operator2)\n    if is_diagonal(operator.operator2):\n        return is_tridiagonal(operator.operator1)\n    return False\n\n\n# PSD/NSD: not preserved under composition in general.\n@is_positive_semidefinite.register(ComposedLinearOperator)\n@is_negative_semidefinite.register(ComposedLinearOperator)\ndef _(operator):\n    return False\n\n\n@has_unit_diagonal.register(ComposedLinearOperator)\ndef _(operator):\n    a = is_diagonal(operator)\n    b = is_lower_triangular(operator)\n    c = is_upper_triangular(operator)\n    d = has_unit_diagonal(operator.operator1)\n    e = has_unit_diagonal(operator.operator2)\n    return (a or b or c) and d and e\n\n\n# conj\n\n\n@ft.singledispatch\ndef conj(operator: AbstractLinearOperator) -> AbstractLinearOperator:\n    \"\"\"Elementwise conjugate of a linear operator. 
This returns another linear operator.\n\n    **Arguments:**\n\n    - `operator`: a linear operator.\n\n    **Returns:**\n\n    Another linear operator.\n    \"\"\"\n    _default_not_implemented(\"conj\", operator)\n\n\n@conj.register(MatrixLinearOperator)\ndef _(operator):\n    return MatrixLinearOperator(operator.matrix.conj(), operator.tags)\n\n\n@conj.register(PyTreeLinearOperator)\ndef _(operator):\n    pytree_conj = jtu.tree_map(lambda x: x.conj(), operator.pytree)\n    return PyTreeLinearOperator(pytree_conj, operator.out_structure(), operator.tags)\n\n\n@conj.register(DiagonalLinearOperator)\ndef _(operator):\n    diagonal_conj = jtu.tree_map(lambda x: x.conj(), operator.diagonal)\n    return DiagonalLinearOperator(diagonal_conj)\n\n\n@conj.register(JacobianLinearOperator)\ndef _(operator):\n    return conj(linearise(operator))\n\n\n@conj.register(FunctionLinearOperator)\ndef _(operator):\n    return FunctionLinearOperator(\n        lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))),\n        operator.in_structure(),\n        operator.tags,\n    )\n\n\n@conj.register(IdentityLinearOperator)\ndef _(operator):\n    return operator\n\n\n@conj.register(TridiagonalLinearOperator)\ndef _(operator):\n    return TridiagonalLinearOperator(\n        operator.diagonal.conj(),\n        operator.lower_diagonal.conj(),\n        operator.upper_diagonal.conj(),\n    )\n\n\n@conj.register(TaggedLinearOperator)\ndef _(operator):\n    return TaggedLinearOperator(conj(operator.operator), operator.tags)\n\n\n@conj.register(TangentLinearOperator)\ndef _(operator):\n    c = lambda operator: conj(operator)\n    primal_out, tangent_out = eqx.filter_jvp(c, (operator.primal,), (operator.tangent,))\n    return TangentLinearOperator(primal_out, tangent_out)\n\n\n@conj.register(AddLinearOperator)\ndef _(operator):\n    return conj(operator.operator1) + conj(operator.operator2)\n\n\n@conj.register(MulLinearOperator)\ndef _(operator):\n    return 
conj(operator.operator) * operator.scalar.conj()\n\n\n@conj.register(NegLinearOperator)\ndef _(operator):\n    return -conj(operator.operator)\n\n\n@conj.register(DivLinearOperator)\ndef _(operator):\n    return conj(operator.operator) / operator.scalar.conj()\n\n\n@conj.register(ComposedLinearOperator)\ndef _(operator):\n    return conj(operator.operator1) @ conj(operator.operator2)\n"
  },
  {
    "path": "lineax/_solution.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nfrom jaxtyping import Array, ArrayLike, PyTree\n\n\n_singular_msg = \"\"\"\nA linear solver returned non-finite (NaN or inf) output. This usually means that an\noperator was not well-posed, and that its solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n    numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. 
positive definiteness)\n    that is does not actually satisfy.\n\"\"\".strip()\n\n\n_nonfinite_msg = \"\"\"\nA linear solver received non-finite (NaN or inf) input and cannot determine a\nsolution.\n\nThis means that you have a bug upstream of Lineax and should check the inputs to\n`lineax.linear_solve` for non-finite values.\n\"\"\".strip()\n\n\nclass RESULTS(eqxi.Enumeration):\n    successful = \"\"\n    max_steps_reached = (\n        \"The maximum number of solver steps was reached. Try increasing `max_steps`.\"\n    )\n    singular = _singular_msg\n    breakdown = (\n        \"A form of iterative breakdown has occured in a linear solve. \"\n        \"Try using a different solver for this problem or increase `restart` \"\n        \"if using GMRES.\"\n    )\n    stagnation = (\n        \"A stagnation in an iterative linear solve has occurred. Try increasing \"\n        \"`stagnation_iters` or `restart`.\"\n    )\n    conlim = \"Condition number of A seems to be larger than `conlim`.\"\n    nonfinite_input = _nonfinite_msg\n\n\nclass Solution(eqx.Module):\n    \"\"\"The solution to a linear solve.\n\n    **Attributes:**\n\n    - `value`: The solution to the solve.\n    - `result`: An integer representing whether the solve was successful or not. This\n        can be converted into a human-readable error message via\n        `lineax.RESULTS[result]`.\n    - `stats`: Statistics about the solver, e.g. the number of steps that were required.\n    - `state`: The internal state of the solver. The meaning of this is specific to each\n        solver.\n    \"\"\"\n\n    value: PyTree[Array]\n    result: RESULTS\n    stats: dict[str, PyTree[ArrayLike]]\n    state: PyTree[Any]\n"
  },
  {
    "path": "lineax/_solve.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport abc\nimport functools as ft\nfrom typing import Any, Generic, TypeAlias, TypeVar\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nimport jax\nimport jax.core\nimport jax.interpreters.ad as ad\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jax._src.ad_util import stop_gradient_p\nfrom jaxtyping import Array, ArrayLike, PyTree\n\nfrom ._custom_types import sentinel\nfrom ._misc import inexact_asarray, strip_weak_dtype\nfrom ._operator import (\n    AbstractLinearOperator,\n    conj,\n    FunctionLinearOperator,\n    has_unit_diagonal,\n    IdentityLinearOperator,\n    is_diagonal,\n    is_lower_triangular,\n    is_negative_semidefinite,\n    is_positive_semidefinite,\n    is_symmetric,\n    is_tridiagonal,\n    is_upper_triangular,\n    linearise,\n    TangentLinearOperator,\n)\nfrom ._solution import RESULTS, Solution\nfrom ._tags import (\n    diagonal_tag,\n    lower_triangular_tag,\n    negative_semidefinite_tag,\n    positive_semidefinite_tag,\n    symmetric_tag,\n    unit_diagonal_tag,\n    upper_triangular_tag,\n)\n\n\n#\n# _linear_solve_p\n#\n\n\ndef _to_shapedarray(x):\n    if isinstance(x, jax.ShapeDtypeStruct):\n        return jax.core.ShapedArray(x.shape, x.dtype)\n    else:\n        return x\n\n\ndef _to_struct(x):\n    if isinstance(x, jax.core.ShapedArray):\n        
return jax.ShapeDtypeStruct(x.shape, x.dtype)\n    elif isinstance(x, jax.core.AbstractValue):\n        raise NotImplementedError(\n            \"`lineax.linear_solve` only supports working with JAX arrays; not \"\n            f\"other abstract values. Got abstract value {x}.\"\n        )\n    else:\n        return x\n\n\ndef _assert_false(x):\n    assert False\n\n\ndef _is_none(x):\n    return x is None\n\n\ndef _sum(*args):\n    return sum(args)\n\n\ndef _linear_solve_impl(_, state, vector, options, solver, throw, *, check_closure):\n    out = solver.compute(state, vector, options)\n    if check_closure:\n        out = eqxi.nontraceable(\n            out, name=\"lineax.linear_solve with respect to a closed-over value\"\n        )\n    solution, result, stats = out\n    has_nonfinite_output = jnp.any(\n        jnp.stack(\n            [jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(solution)]\n        )\n    )\n    result = RESULTS.where(\n        (result == RESULTS.successful) & has_nonfinite_output,\n        RESULTS.singular,\n        result,\n    )\n    has_nonfinite_input = jnp.any(\n        jnp.stack(\n            [jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(vector)]\n        )\n    )\n    result = RESULTS.where(\n        (result == RESULTS.singular) & has_nonfinite_input,\n        RESULTS.nonfinite_input,\n        result,\n    )\n    if throw:\n        solution, result, stats = result.error_if(\n            (solution, result, stats),\n            result != RESULTS.successful,\n        )\n    return solution, result, stats\n\n\n@eqxi.filter_primitive_def\ndef _linear_solve_abstract_eval(operator, state, vector, options, solver, throw):\n    state, vector, options, solver = jtu.tree_map(\n        _to_struct, (state, vector, options, solver)\n    )\n    out = eqx.filter_eval_shape(\n        _linear_solve_impl,\n        operator,\n        state,\n        vector,\n        options,\n        solver,\n        throw,\n        
check_closure=False,\n    )\n    out = jtu.tree_map(_to_shapedarray, out)\n    return out\n\n\n@eqxi.filter_primitive_jvp\ndef _linear_solve_jvp(primals, tangents):\n    operator, state, vector, options, solver, throw = primals\n    t_operator, t_state, t_vector, t_options, t_solver, t_throw = tangents\n    jtu.tree_map(_assert_false, (t_state, t_options, t_solver, t_throw))\n    del t_state, t_options, t_solver, t_throw\n\n    # Note that we pass throw=True unconditionally to all the tangent solves, as there\n    # is nowhere we can pipe their error to.\n    # This is the primal solve so we can respect the original `throw`.\n    solution, result, stats = eqxi.filter_primitive_bind(\n        linear_solve_p, operator, state, vector, options, solver, throw\n    )\n\n    #\n    # Consider the primal problem of linearly solving for x in Ax=b.\n    # Let ^ denote pseudoinverses, ᵀ denote transposes, and ' denote tangents.\n    # The linear_solve routine returns specifically the pseudoinverse solution, i.e.\n    #\n    # x = A^b\n    #\n    # Therefore x' = A^'b + A^b'\n    #\n    # Now A^' = -A^A'A^ + A^A^ᵀAᵀ'(I - AA^) + (I - A^A)Aᵀ'A^ᵀA^\n    #\n    # (Source: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)\n    #\n    # This results in:\n    #\n    # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) - Ay + b') + y\n    #\n    # where\n    #\n    # y = Aᵀ'A^ᵀx\n    #\n    # note that if A has linearly independent columns, then the y - A^Ay\n    # term disappears and gives\n    #\n    # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) + b')\n    #\n    # and if A has linearly independent rows, then the A^A^ᵀAᵀ'(b - Ax) term\n    # disappears giving:\n    #\n    # x' = A^(-A'x - Ay + b') + y\n    #\n    # if A has linearly independent rows and columns, then A is nonsingular and\n    #\n    # x' = A^(-A'x + b')\n\n    vecs = []\n    sols = []\n    if any(t is not None for t in jtu.tree_leaves(t_vector, is_leaf=_is_none)):\n        # b' term\n        vecs.append(\n            
jtu.tree_map(eqxi.materialise_zeros, vector, t_vector, is_leaf=_is_none)\n        )\n    if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)):\n        t_operator = TangentLinearOperator(operator, t_operator)\n        t_operator = linearise(t_operator)  # optimise for matvecs\n        # -A'x term\n        vec = (-(t_operator.mv(solution) ** ω)).ω\n        vecs.append(vec)\n        rows, columns = operator.out_size(), operator.in_size()\n        assume_independent_rows = solver.assume_full_rank() and rows <= columns\n        assume_independent_columns = solver.assume_full_rank() and columns <= rows\n        if not assume_independent_rows or not assume_independent_columns:\n            operator_conj_transpose = conj(operator).transpose()\n            t_operator_conj_transpose = conj(t_operator).transpose()\n            state_conj, options_conj = solver.conj(state, options)\n            state_conj_transpose, options_conj_transpose = solver.transpose(\n                state_conj, options_conj\n            )\n        if not assume_independent_rows:\n            lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω\n            tmp = t_operator_conj_transpose.mv(lst_sqr_diff)  # pyright: ignore\n            tmp, _, _ = eqxi.filter_primitive_bind(\n                linear_solve_p,\n                operator_conj_transpose,  # pyright: ignore\n                state_conj_transpose,  # pyright: ignore\n                tmp,\n                options_conj_transpose,  # pyright: ignore\n                solver,\n                True,\n            )\n            vecs.append(tmp)\n\n        if not assume_independent_columns:\n            tmp1, _, _ = eqxi.filter_primitive_bind(\n                linear_solve_p,\n                operator_conj_transpose,  # pyright: ignore\n                state_conj_transpose,  # pyright:ignore\n                solution,\n                options_conj_transpose,  # pyright: ignore\n                solver,\n                
True,\n            )\n            tmp2 = t_operator_conj_transpose.mv(tmp1)  # pyright: ignore\n            # tmp2 is the y term\n            tmp3 = operator.mv(tmp2)\n            tmp4 = (-(tmp3**ω)).ω\n            # tmp4 is the Ay term\n            vecs.append(tmp4)\n            sols.append(tmp2)\n    vecs = jtu.tree_map(_sum, *vecs)\n    # the A^ term at the very beginning\n    sol, _, _ = eqxi.filter_primitive_bind(\n        linear_solve_p, operator, state, vecs, options, solver, True\n    )\n    sols.append(sol)\n    t_solution = jtu.tree_map(_sum, *sols)\n\n    out = solution, result, stats\n    t_out = (\n        t_solution,\n        jtu.tree_map(lambda _: None, result),\n        jtu.tree_map(lambda _: None, stats),\n    )\n    return out, t_out\n\n\ndef _is_undefined(x):\n    return isinstance(x, ad.UndefinedPrimal)\n\n\ndef _assert_defined(x):\n    assert not _is_undefined(x)\n\n\ndef _keep_undefined(v, ct):\n    if _is_undefined(v):\n        return ct\n    else:\n        return None\n\n\n@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore\ndef _linear_solve_transpose(inputs, cts_out):\n    cts_solution, _, _ = cts_out\n    operator, state, vector, options, solver, _ = inputs\n    jtu.tree_map(\n        _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined\n    )\n    cts_solution = jtu.tree_map(\n        ft.partial(eqxi.materialise_zeros, allow_struct=True),\n        operator.in_structure(),\n        cts_solution,\n    )\n    operator_transpose = operator.transpose()\n    state_transpose, options_transpose = solver.transpose(state, options)\n    cts_vector, _, _ = eqxi.filter_primitive_bind(\n        linear_solve_p,\n        operator_transpose,\n        state_transpose,\n        cts_solution,\n        options_transpose,\n        solver,\n        True,  # throw=True unconditionally: nowhere to pipe result to.\n    )\n    cts_vector = jtu.tree_map(\n        _keep_undefined, vector, cts_vector, 
is_leaf=_is_undefined\n    )\n    operator_none = jtu.tree_map(lambda _: None, operator)\n    state_none = jtu.tree_map(lambda _: None, state)\n    options_none = jtu.tree_map(lambda _: None, options)\n    solver_none = jtu.tree_map(lambda _: None, solver)\n    throw_none = None\n    return operator_none, state_none, cts_vector, options_none, solver_none, throw_none\n\n\n# Call with `check_closure=False` so that the autocreated vmap rule works.\nlinear_solve_p = eqxi.create_vprim(\n    \"linear_solve\",\n    eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=False)),\n    _linear_solve_abstract_eval,\n    _linear_solve_jvp,\n    _linear_solve_transpose,\n)\n# Then rebind so that the impl rule catches leaked-in tracers.\nlinear_solve_p.def_impl(\n    eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=True))\n)\neqxi.register_impl_finalisation(linear_solve_p)\n\n\n#\n# linear_solve\n#\n\n\n_SolverState = TypeVar(\"_SolverState\")\n\n\nclass AbstractLinearSolver(eqx.Module, Generic[_SolverState]):\n    \"\"\"Abstract base class for all linear solvers.\"\"\"\n\n    @abc.abstractmethod\n    def init(\n        self, operator: AbstractLinearOperator, options: dict[str, Any]\n    ) -> _SolverState:\n        \"\"\"Do any initial computation on just the `operator`.\n\n        For example, an LU solver would compute the LU decomposition of the operator\n        (and this does not require knowing the vector yet).\n\n        It is common to need to solve the linear system `Ax=b` multiple times in\n        succession, with the same operator `A` and multiple vectors `b`. This method\n        improves efficiency by making it possible to re-use the computation performed\n        on just the operator.\n\n        !!! 
Example\n\n            ```python\n            operator = lx.MatrixLinearOperator(...)\n            vector1 = ...\n            vector2 = ...\n            solver = lx.LU()\n            state = solver.init(operator, options={})\n            solution1 = lx.linear_solve(operator, vector1, solver, state=state)\n            solution2 = lx.linear_solve(operator, vector2, solver, state=state)\n            ```\n\n        **Arguments:**\n\n        - `operator`: a linear operator.\n        - `options`: a dictionary of any extra options that the solver may wish to\n            accept.\n\n        **Returns:**\n\n        A PyTree of arbitrary Python objects.\n        \"\"\"\n\n    @abc.abstractmethod\n    def compute(\n        self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        \"\"\"Solves a linear system.\n\n        **Arguments:**\n\n        - `state`: as returned from [`lineax.AbstractLinearSolver.init`][].\n        - `vector`: the vector to solve against.\n        - `options`: a dictionary of any extra options that the solver may wish to\n            accept. For example, [`lineax.CG`][] accepts a `preconditioner` option.\n\n        **Returns:**\n\n        A 3-tuple of:\n\n        - The solution to the linear system.\n        - An integer indicating the success or failure of the solve. This is an integer\n            which may be converted to a human-readable error message via\n            `lx.RESULTS[...]`.\n        - A dictionary of an extra statistics about the solve, e.g. 
the number of steps\n            taken.\n        \"\"\"\n\n    @abc.abstractmethod\n    def transpose(\n        self, state: _SolverState, options: dict[str, Any]\n    ) -> tuple[_SolverState, dict[str, Any]]:\n        \"\"\"Transposes the result of [`lineax.AbstractLinearSolver.init`][].\n\n        That is, it should be the case that\n        ```python\n        state_transpose, _ = solver.transpose(solver.init(operator, options), options)\n        state_transpose2 = solver.init(operator.T, options)\n        ```\n        must be identical to each other.\n\n        It is relatively common (in particular when differentiating through a linear\n        solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it\n        possible to avoid computing both `solver.init(operator)` and\n        `solver.init(operator.T)` if one can be cheaply computed from the other.\n\n        **Arguments:**\n\n        - `state`: as returned from `solver.init`.\n        - `options`: any extra options that were passed to `solve.init`.\n\n        **Returns:**\n\n        A 2-tuple of:\n\n        - The state of the transposed operator.\n        - The options for the transposed operator.\n        \"\"\"\n\n    @abc.abstractmethod\n    def conj(\n        self, state: _SolverState, options: dict[str, Any]\n    ) -> tuple[_SolverState, dict[str, Any]]:\n        \"\"\"Conjugate the result of [`lineax.AbstractLinearSolver.init`][].\n\n        That is, it should be the case that\n        ```python\n        state_conj, _ = solver.conj(solver.init(operator, options), options)\n        state_conj2 = solver.init(conj(operator), options)\n        ```\n        must be identical to each other.\n\n        **Arguments:**\n\n        - `state`: as returned from `solver.init`.\n        - `options`: any extra options that were passed to `solve.init`.\n\n        **Returns:**\n\n        A 2-tuple of:\n\n        - The state of the conjugated operator.\n        - The options for the conjugated 
operator.\n        \"\"\"\n\n    @abc.abstractmethod\n    def assume_full_rank(self) -> bool:\n        \"\"\"Does this solver assume that all operators are full rank?\n\n        When `False`, a more expensive backward pass is needed to account for\n        the extra generality. In a custom linear solver, it is always safe to\n        return False.\n\n        **Arguments:**\n\n        Nothing.\n\n        **Returns:**\n\n        Either `True` or `False`.\n        \"\"\"\n\n\n_qr_token = eqxi.str2jax(\"qr_token\")\n_diagonal_token = eqxi.str2jax(\"diagonal_token\")\n_well_posed_diagonal_token = eqxi.str2jax(\"well_posed_diagonal_token\")\n_tridiagonal_token = eqxi.str2jax(\"tridiagonal_token\")\n_triangular_token = eqxi.str2jax(\"triangular_token\")\n_cholesky_token = eqxi.str2jax(\"cholesky_token\")\n_lu_token = eqxi.str2jax(\"lu_token\")\n_svd_token = eqxi.str2jax(\"svd_token\")\n\n\n# Ugly delayed import because we have the dependency chain\n# linear_solve -> AutoLinearSolver -> {Cholesky,...} -> AbstractLinearSolver\n# but we want linear_solver and AbstractLinearSolver in the same file.\ndef _lookup(token) -> AbstractLinearSolver:\n    from . 
import _solver\n\n    # pyright doesn't know that these keys are hashable\n    _lookup_dict = {\n        _qr_token: _solver.QR(),  # pyright: ignore\n        _diagonal_token: _solver.Diagonal(),  # pyright: ignore\n        _well_posed_diagonal_token: _solver.Diagonal(  # pyright: ignore\n            well_posed=True\n        ),\n        _tridiagonal_token: _solver.Tridiagonal(),  # pyright: ignore\n        _triangular_token: _solver.Triangular(),  # pyright: ignore\n        _cholesky_token: _solver.Cholesky(),  # pyright: ignore\n        _lu_token: _solver.LU(),  # pyright: ignore\n        _svd_token: _solver.SVD(),  # pyright: ignore\n    }\n    return _lookup_dict[token]\n\n\n_AutoLinearSolverState: TypeAlias = tuple[Any, Any]\n\n\nclass AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]):\n    \"\"\"Automatically determines a good linear solver based on the structure of the\n    operator.\n\n    - If `well_posed=True`:\n        - If the operator is diagonal, then use [`lineax.Diagonal`][].\n        - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][].\n        - If the operator is triangular, then use [`lineax.Triangular`][].\n        - If the matrix is positive or negative (semi-)definite, then use\n            [`lineax.Cholesky`][].\n        - Else use [`lineax.LU`][].\n\n    This is a good choice if you want to be certain that an error is raised for\n    ill-posed systems.\n\n    - If `well_posed=False`:\n        - If the operator is diagonal, then use [`lineax.Diagonal`][].\n        - Else use [`lineax.SVD`][].\n\n    This is a good choice if you want to be certain that you can handle ill-posed\n    systems.\n\n    - If `well_posed=None`:\n        - If the operator is non-square, then use [`lineax.QR`][].\n        - If the operator is diagonal, then use [`lineax.Diagonal`][].\n        - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][].\n        - If the operator is triangular, then use [`lineax.Triangular`][].\n      
  - If the matrix is positive or negative (semi-)definite, then use\n            [`lineax.Cholesky`][].\n        - Else, use [`lineax.LU`][].\n\n    This is a good choice if your primary concern is computational efficiency. It will\n    handle ill-posed systems as long as it is not computationally expensive to do so.\n    \"\"\"\n\n    well_posed: bool | None\n\n    def _select_solver(self, operator: AbstractLinearOperator):\n        if self.well_posed is True:\n            if operator.in_size() != operator.out_size():\n                raise ValueError(\n                    \"Cannot use `AutoLinearSolver(well_posed=True)` with a non-square \"\n                    \"operator. If you are trying solve a least-squares problem then \"\n                    \"you should pass `solver=AutoLinearSolver(well_posed=False)`. By \"\n                    \"default `lineax.linear_solve` assumes that the operator is \"\n                    \"square and nonsingular.\"\n                )\n            if is_diagonal(operator):\n                token = _well_posed_diagonal_token\n            elif is_tridiagonal(operator):\n                token = _tridiagonal_token\n            elif is_lower_triangular(operator) or is_upper_triangular(operator):\n                token = _triangular_token\n            elif is_positive_semidefinite(operator) or is_negative_semidefinite(\n                operator\n            ):\n                token = _cholesky_token\n            else:\n                token = _lu_token\n        elif self.well_posed is False:\n            if is_diagonal(operator):\n                token = _diagonal_token\n            else:\n                # TODO: use rank-revealing QR instead.\n                token = _svd_token\n        elif self.well_posed is None:\n            if operator.in_size() != operator.out_size():\n                token = _qr_token\n            elif is_diagonal(operator):\n                token = _diagonal_token\n            elif is_tridiagonal(operator):\n   
             token = _tridiagonal_token\n            elif is_lower_triangular(operator) or is_upper_triangular(operator):\n                token = _triangular_token\n            elif is_positive_semidefinite(operator) or is_negative_semidefinite(\n                operator\n            ):\n                token = _cholesky_token\n            else:\n                token = _lu_token\n        else:\n            raise ValueError(f\"Invalid value `well_posed={self.well_posed}`.\")\n        return token\n\n    def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver:\n        \"\"\"Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to.\n\n        **Arguments:**\n\n        - `operator`: a linear operator.\n\n        **Returns:**\n\n        The linear solver that will be used.\n        \"\"\"\n        return _lookup(self._select_solver(operator))\n\n    def init(self, operator, options) -> _AutoLinearSolverState:\n        token = self._select_solver(operator)\n        return token, _lookup(token).init(operator, options)\n\n    def compute(\n        self,\n        state: _AutoLinearSolverState,\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        token, state = state\n        solver = _lookup(token)\n        solution, result, _ = solver.compute(state, vector, options)\n        return solution, result, {}\n\n    def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]):\n        token, state = state\n        solver = _lookup(token)\n        transpose_state, transpose_options = solver.transpose(state, options)\n        transpose_state = (token, transpose_state)\n        return transpose_state, transpose_options\n\n    def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]):\n        token, state = state\n        solver = _lookup(token)\n        conj_state, conj_options = solver.conj(state, options)\n        conj_state = 
(token, conj_state)\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return self.well_posed is not False\n\n\nAutoLinearSolver.__init__.__doc__ = \"\"\"**Arguments:**\n\n- `well_posed`: whether to only handle well-posed systems or not, as discussed above.\n\"\"\"\n\n\n# TODO(kidger): gmres, bicgstab\n# TODO(kidger): support auxiliary outputs\n@eqx.filter_jit\ndef linear_solve(\n    operator: AbstractLinearOperator,\n    vector: PyTree[ArrayLike],\n    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),\n    *,\n    options: dict[str, Any] | None = None,\n    state: PyTree[Any] = sentinel,\n    throw: bool = True,\n) -> Solution:\n    r\"\"\"Solves a linear system.\n\n    Given an operator represented as a matrix $A$, and a vector $b$: if the operator is\n    square and nonsingular (so that the problem is well-posed), then this returns the\n    usual solution $x$ to $Ax = b$, defined as $A^{-1}b$.\n\n    If the operator is overdetermined, then this either returns the least-squares\n    solution $\\min_x \\| Ax - b \\|_2$, or throws an error. (Depending on the choice of\n    solver.)\n\n    If the operator is underdetermined, then this either returns the minimum-norm\n    solution $\\min_x \\|x\\|_2 \\text{ subject to } Ax = b$, or throws an error. (Depending\n    on the choice of solver.)\n\n    !!! info\n\n        This function is equivalent to either `numpy.linalg.solve`, or to its\n        generalisation `numpy.linalg.lstsq`, depending on the choice of solver.\n\n    The default solver is `lineax.AutoLinearSolver(well_posed=True)`. This\n    automatically selects a solver depending on the structure (e.g. 
triangular) of your\n    problem, and will throw an error if your system is overdetermined or\n    underdetermined.\n\n    Use `lineax.AutoLinearSolver(well_posed=False)` if your system is known to be\n    overdetermined or underdetermined (although handling this case implies greater\n    computational cost).\n\n    !!! tip\n\n        These three kinds of solution to a linear system are collectively known as the\n        \"pseudoinverse solution\" to a linear system. That is, given our matrix $A$, let\n        $A^\\dagger$ denote the\n        [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse)\n        of $A$. Then the usual/least-squares/minimum-norm solution are all equal to\n        $A^\\dagger b$.\n\n    **Arguments:**\n\n    - `operator`: a linear operator. This is the '$A$' in '$Ax = b$'.\n\n        Most frequently this operator is simply represented as a JAX matrix (i.e. a\n        rank-2 JAX array), but any [`lineax.AbstractLinearOperator`][] is supported.\n\n        Note that if it is a matrix, then it should be passed as an\n        [`lineax.MatrixLinearOperator`][], e.g.\n        ```python\n        matrix = jax.random.normal(key, (5, 5))  # JAX array of shape (5, 5)\n        operator = lx.MatrixLinearOperator(matrix)  # Wrap into a linear operator\n        solution = lx.linear_solve(operator, ...)\n        ```\n        rather than being passed directly.\n\n    - `vector`: the vector to solve against. This is the '$b$' in '$Ax = b$'.\n\n    - `solver`: the solver to use. Should be any [`lineax.AbstractLinearSolver`][].\n        The default is [`lineax.AutoLinearSolver`][] which behaves as discussed\n        above.\n\n        If the operator is overdetermined or underdetermined , then passing\n        [`lineax.SVD`][] is typical.\n\n    - `options`: Individual solvers may accept additional runtime arguments; for example\n        [`lineax.CG`][] allows for specifying a preconditioner. 
See each individual\n        solver's documentation for more details. Keyword only argument.\n\n    - `state`: If performing multiple linear solves with the same operator, then some\n        computation can be saved by recording and reusing some information; for example\n        the matrix factorisation of the operator. This value should be the result of\n        calling [`lineax.AbstractLinearSolver.init`][] on the provided `operator`.\n\n        If provided, then the underlying `operator` must still be passed to\n        `linear_solve`.\n\n        Keyword only argument.\n\n    - `throw`: How to report any failures. (E.g. an iterative solver running out of\n        steps, or a well-posed-only solver being run with a singular operator.)\n\n        If `True` then a failure will raise an error. Note that errors are only reliably\n        raised on CPUs. If on GPUs then the error may only be printed to stderr, whilst\n        on TPUs then the behaviour is undefined.\n\n        If `False` then the returned solution object will have a `result` field\n        indicating whether any failures occured. (See [`lineax.Solution`][].)\n\n        Keyword only argument.\n\n    **Returns:**\n\n    An [`lineax.Solution`][] object containing the solution to the linear system.\n    \"\"\"  # noqa: E501\n\n    if eqx.is_array(operator):\n        raise ValueError(\n            \"`lineax.linear_solve(operator=...)` should be an \"\n            \"`AbstractLinearOperator`, not a raw JAX array. 
If you are trying to pass \"\n            \"a matrix then this should be passed as \"\n            \"`lineax.MatrixLinearOperator(matrix)`.\"\n        )\n    if options is None:\n        options = {}\n    vector = jtu.tree_map(inexact_asarray, vector)\n    vector_struct = strip_weak_dtype(jax.eval_shape(lambda: vector))\n    operator_out_structure = strip_weak_dtype(operator.out_structure())\n    # `is` to handle tracers\n    if eqx.tree_equal(vector_struct, operator_out_structure) is not True:\n        raise ValueError(\n            \"Vector and operator structures do not match. Got a vector with structure \"\n            f\"{vector_struct} and an operator with out-structure \"\n            f\"{operator_out_structure}\"\n        )\n    if isinstance(operator, IdentityLinearOperator):\n        return Solution(\n            value=vector,\n            result=RESULTS.successful,\n            state=state,\n            stats={},\n        )\n    if state == sentinel:\n        dynamic_operator, static_operator = eqx.partition(operator, eqx.is_array)\n        stopped_operator = eqx.combine(\n            lax.stop_gradient(dynamic_operator), static_operator\n        )\n        state = solver.init(stopped_operator, options)\n\n    dynamic_state, static_state = eqx.partition(state, eqx.is_array)\n    dynamic_state = lax.stop_gradient(dynamic_state)\n    state = eqx.combine(dynamic_state, static_state)\n    options = eqxi.nondifferentiable(\n        options, name=\"`lineax.linear_solve(..., options=...)`\"\n    )\n    solver = eqxi.nondifferentiable(\n        solver, name=\"`lineax.linear_solve(..., solver=...)`\"\n    )\n    solution, result, stats = eqxi.filter_primitive_bind(\n        linear_solve_p, operator, state, vector, options, solver, throw\n    )\n    # TODO: prevent forward-mode autodiff through stats\n    stats = eqxi.nondifferentiable_backward(stats)\n    return Solution(value=solution, result=result, state=state, stats=stats)\n\n\ndef invert(\n    operator: 
AbstractLinearOperator,\n    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),\n    *,\n    options: dict[str, Any] | None = None,\n    throw: bool = True,\n) -> FunctionLinearOperator:\n    r\"\"\"Returns a [`lineax.FunctionLinearOperator`][] representing the\n    (pseudo)inverse of `operator`.\n\n    `invert(A).mv(v)` is equivalent to `linear_solve(A, v, solver).value`.\n    See [`lineax.linear_solve`][] for details on how the solution is defined\n    for square, overdetermined, and underdetermined systems.\n\n    The returned operator fully supports AD (both forward and reverse mode),\n    `vmap`, and composition with other operators.\n\n    **Arguments:**\n\n    - `operator`: the linear operator to invert.\n    - `solver`: the linear solver to use. Defaults to\n        `AutoLinearSolver(well_posed=True)`.\n    - `options`: additional options passed to the solver. Defaults to `None`.\n    - `throw`: as [`lineax.linear_solve`][]. Defaults to `True`.\n\n    **Returns:**\n\n    A [`lineax.FunctionLinearOperator`][] whose `mv` solves `operator @ x = v`.\n    \"\"\"\n    if options is None:\n        options = {}\n\n    state = solver.init(operator, options)\n\n    def solve_fn(vector):\n        return linear_solve(\n            operator,\n            vector,\n            solver,\n            state=state,\n            options=options,\n            throw=throw,\n        ).value\n\n    tags = {\n        tag\n        for check, tag in [\n            (is_symmetric, symmetric_tag),\n            (is_diagonal, diagonal_tag),\n            (is_lower_triangular, lower_triangular_tag),\n            (is_upper_triangular, upper_triangular_tag),\n            (is_positive_semidefinite, positive_semidefinite_tag),\n            (is_negative_semidefinite, negative_semidefinite_tag),\n        ]\n        if check(operator)\n    }\n    if has_unit_diagonal(operator) and (\n        is_diagonal(operator)\n        or is_lower_triangular(operator)\n        or 
is_upper_triangular(operator)\n    ):\n        tags.add(unit_diagonal_tag)\n    return FunctionLinearOperator(solve_fn, operator.out_structure(), frozenset(tags))\n\n\n# Work around JAX issue #22011,\n# as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365\ndef stop_gradient_transpose(ct, x):\n    return (ct,)\n\n\nad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose\n"
  },
  {
    "path": "lineax/_solver/__init__.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .bicgstab import BiCGStab as BiCGStab\nfrom .cg import CG as CG, NormalCG as NormalCG\nfrom .cholesky import Cholesky as Cholesky\nfrom .diagonal import Diagonal as Diagonal\nfrom .gmres import GMRES as GMRES\nfrom .lsmr import LSMR as LSMR\nfrom .lu import LU as LU\nfrom .normal import Normal as Normal\nfrom .qr import QR as QR\nfrom .svd import SVD as SVD\nfrom .triangular import Triangular as Triangular\nfrom .tridiagonal import Tridiagonal as Tridiagonal\n"
  },
  {
    "path": "lineax/_solver/bicgstab.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections.abc import Callable\nfrom typing import Any, TypeAlias\n\nimport jax\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jaxtyping import Array, PyTree\n\nfrom .._norm import max_norm, tree_dot\nfrom .._operator import AbstractLinearOperator, conj, linearise\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import preconditioner_and_y0\n\n\n_BiCGStabState: TypeAlias = AbstractLinearOperator\n\n\nclass BiCGStab(AbstractLinearSolver[_BiCGStabState]):\n    \"\"\"Biconjugate gradient stabilised method for linear systems.\n\n    The operator should be square.\n\n    Equivalent to `jax.scipy.sparse.linalg.bicgstab`.\n\n    This supports the following `options` (as passed to\n    `lx.linear_solve(..., options=...)`).\n\n    - `preconditioner`: A [`lineax.AbstractLinearOperator`][]\n        to be used as a preconditioner. Defaults to\n        [`lineax.IdentityLinearOperator`][]. This method uses right preconditioning.\n    - `y0`: The initial estimate of the solution to the linear system. 
Defaults to all\n        zeros.\n    \"\"\"\n\n    rtol: float\n    atol: float\n    norm: Callable = max_norm\n    max_steps: int | None = None\n\n    def __check_init__(self):\n        if isinstance(self.rtol, (int, float)) and self.rtol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n        if isinstance(self.atol, (int, float)) and self.atol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n\n        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):\n            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:\n                raise ValueError(\n                    \"Must specify `rtol`, `atol`, or `max_steps` (or some combination \"\n                    \"of all three).\"\n                )\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        if operator.in_structure() != operator.out_structure():\n            raise ValueError(\n                \"`BiCGstab(..., normal=False)` may only be used for linear solves with \"\n                \"square matrices.\"\n            )\n        return linearise(operator)\n\n    def compute(\n        self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        operator = state\n        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)\n        leaves, _ = jtu.tree_flatten(vector)\n        if self.max_steps is None:\n            size = sum(leaf.size for leaf in leaves)\n            max_steps = 10 * size\n        else:\n            max_steps = self.max_steps\n        has_scale = not (\n            isinstance(self.atol, (int, float))\n            and isinstance(self.rtol, (int, float))\n            and self.atol == 0\n            and self.rtol == 0\n        )\n        if has_scale:\n            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω\n\n        # This implementation is the 
same a jax.scipy.sparse.linalg.bicgstab\n        # but with AbstractLinearOperator.\n        # We use the notation found on the wikipedia except with y instead of x:\n        # https://en.wikipedia.org/wiki/\n        # Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB\n        # preconditioner in this case is K2^(-1) (i.e., right preconditioning)\n\n        r0 = (vector**ω - operator.mv(y0) ** ω).ω\n\n        def breakdown_occurred(omega, alpha, rho):\n            # Empirically, the tolerance checks for breakdown are very tight.\n            # These specific tolerances are heuristic.\n            if jax.config.jax_enable_x64:  # pyright: ignore\n                return (omega == 0.0) | (alpha == 0.0) | (rho == 0.0)\n            else:\n                return (omega < 1e-16) | (alpha < 1e-16) | (rho < 1e-16)\n\n        def not_converged(r, diff, y):\n            # The primary tolerance check.\n            # Given Ay=b, then we have to be doing better than `scale` in both\n            # the `y` and the `b` spaces.\n            if has_scale:\n                with jax.numpy_dtype_promotion(\"standard\"):\n                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω\n                    norm1 = self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore\n                    norm2 = self.norm((diff**ω / y_scale**ω).ω)\n                return (norm1 > 1) | (norm2 > 1)\n            else:\n                return True\n\n        def cond_fun(carry):\n            y, r, alpha, omega, rho, _, _, diff, step = carry\n            out = jnp.invert(breakdown_occurred(omega, alpha, rho))\n            out = out & not_converged(r, diff, y)\n            out = out & (step < max_steps)\n            return out\n\n        def body_fun(carry):\n            y, r, alpha, omega, rho, p, v, diff, step = carry\n\n            rho_new = tree_dot(r0, r)\n            beta = (rho_new / rho) * (alpha / omega)\n            p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω\n\n            
# TODO(raderj): reduce this to a single operator.mv call\n            # by using the scan trick.\n            x = preconditioner.mv(p_new)\n            v_new = operator.mv(x)\n\n            alpha_new = rho_new / tree_dot(r0, v_new)\n            s = (r**ω - alpha_new * v_new**ω).ω\n\n            z = preconditioner.mv(s)\n            t = operator.mv(z)\n\n            omega_new = tree_dot(s, t) / tree_dot(t, t)\n\n            diff = (alpha_new * x**ω + omega_new * z**ω).ω\n            y_new = (y**ω + diff**ω).ω\n            r_new = (s**ω - omega_new * t**ω).ω\n            return (\n                y_new,\n                r_new,\n                alpha_new,\n                omega_new,\n                rho_new,\n                p_new,\n                v_new,\n                diff,\n                step + 1,\n            )\n\n        p0 = v0 = jtu.tree_map(jnp.zeros_like, vector)\n        alpha = omega = rho = jnp.array(1.0)\n\n        init_carry = (\n            y0,\n            r0,\n            alpha,\n            omega,\n            rho,\n            p0,\n            v0,\n            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,\n            0,\n        )\n        solution, residual, alpha, omega, rho, _, _, diff, num_steps = lax.while_loop(\n            cond_fun, body_fun, init_carry\n        )\n\n        if self.max_steps is None:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.singular, RESULTS.successful\n            )\n        elif has_scale:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful\n            )\n        else:\n            result = RESULTS.successful\n        # breakdown is only an issue if we did not converge\n        breakdown = breakdown_occurred(omega, alpha, rho) & not_converged(\n            residual, diff, solution\n        )\n        result = RESULTS.where(breakdown, RESULTS.breakdown, result)\n\n        stats = {\"num_steps\": 
num_steps, \"max_steps\": self.max_steps}\n        return solution, result, stats\n\n    def transpose(self, state: _BiCGStabState, options: dict[str, Any]):\n        transpose_options = {}\n        if \"preconditioner\" in options:\n            transpose_options[\"preconditioner\"] = options[\"preconditioner\"].transpose()\n        operator = state\n        return operator.transpose(), transpose_options\n\n    def conj(self, state: _BiCGStabState, options: dict[str, Any]):\n        conj_options = {}\n        if \"preconditioner\" in options:\n            conj_options[\"preconditioner\"] = conj(options[\"preconditioner\"])\n        operator = state\n        return conj(operator), conj_options\n\n    def assume_full_rank(self):\n        return True\n\n\nBiCGStab.__init__.__doc__ = r\"\"\"**Arguments:**\n\n- `rtol`: Relative tolerance for terminating solve.\n- `atol`: Absolute tolerance for terminating solve.\n- `norm`: The norm to use when computing whether the error falls within the tolerance.\n    Defaults to the max norm.\n- `max_steps`: The maximum number of iterations to run the solver for. If more steps\n    than this are required, then the solve is halted with a failure.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/cg.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom collections.abc import Callable\nfrom typing import Any, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jaxtyping import Array, PyTree, Scalar\n\nfrom .._misc import resolve_rcond, structure_equal, tree_where\nfrom .._norm import max_norm, tree_dot\nfrom .._operator import (\n    AbstractLinearOperator,\n    conj,\n    is_negative_semidefinite,\n    is_positive_semidefinite,\n    linearise,\n)\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import preconditioner_and_y0\nfrom .normal import Normal\n\n\n_CGState: TypeAlias = tuple[AbstractLinearOperator, eqxi.Static]\n\n\n# TODO(kidger): this is pretty slow to compile.\n# - CG evaluates `operator.mv` three times.\n# Possibly this can be cheapened a bit somehow?\nclass CG(AbstractLinearSolver[_CGState]):\n    \"\"\"Conjugate gradient solver for linear systems.\n\n    The operator should be positive or negative definite.\n\n    Equivalent to `scipy.sparse.linalg.cg`.\n\n    This supports the following `options` (as passed to\n    `lx.linear_solve(..., options=...)`).\n\n    - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]\n        to be used as preconditioner. 
Defaults to\n        [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,\n        so it is the preconditioned residual that is minimized, though the actual\n        termination criteria uses the un-preconditioned residual.\n\n    - `y0`: The initial estimate of the solution to the linear system. Defaults to all\n        zeros.\n\n    \"\"\"\n\n    rtol: float\n    atol: float\n    norm: Callable[[PyTree], Scalar] = max_norm\n    stabilise_every: int | None = 10\n    max_steps: int | None = None\n\n    def __check_init__(self):\n        if isinstance(self.rtol, (int, float)) and self.rtol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n        if isinstance(self.atol, (int, float)) and self.atol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n\n        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):\n            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:\n                raise ValueError(\n                    \"Must specify `rtol`, `atol`, or `max_steps` (or some combination \"\n                    \"of all three).\"\n                )\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        is_nsd = is_negative_semidefinite(operator)\n        if not structure_equal(operator.in_structure(), operator.out_structure()):\n            raise ValueError(\n                \"`CG()` may only be used for linear solves with square matrices.\"\n            )\n        if not (is_positive_semidefinite(operator) | is_nsd):\n            raise ValueError(\n                \"`CG()` may only be used for positive \"\n                \"or negative definite linear operators\"\n            )\n        if is_nsd:\n            operator = -operator\n        operator = linearise(operator)\n        return operator, eqxi.Static(is_nsd)\n\n    # This differs from jax.scipy.sparse.linalg.cg in:\n    # 1. 
Every few steps we calculate the residual directly, rather than by cheaply\n    #    using the existing quantities. This improves numerical stability.\n    # 2. We use a more sophisticated termination condition. To begin with we have an\n    #    rtol and atol in the conventional way, inducing a vector-valued scale. This is\n    #    then checked in both the `y` and `b` domains (for `Ay = b`).\n    # 3. We return the number of steps, and whether or not the solve succeeded, as\n    #    additional information.\n    # 4. We don't try to support complex numbers. (Yet.)\n    def compute(\n        self, state: _CGState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        operator, is_nsd = state\n        is_nsd = is_nsd.value\n        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)\n        if not is_positive_semidefinite(preconditioner):\n            raise ValueError(\"The preconditioner must be positive definite.\")\n        leaves, _ = jtu.tree_flatten(vector)\n        size = sum(leaf.size for leaf in leaves)\n        if self.max_steps is None:\n            max_steps = 10 * size  # Copied from SciPy!\n        else:\n            max_steps = self.max_steps\n        r0 = (vector**ω - operator.mv(y0) ** ω).ω\n        p0 = preconditioner.mv(r0)\n        gamma0 = tree_dot(p0, r0)\n        rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))\n        initial_value = (\n            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,\n            y0,\n            r0,\n            p0,\n            gamma0,\n            0,\n        )\n        has_scale = not (\n            isinstance(self.atol, (int, float))\n            and isinstance(self.rtol, (int, float))\n            and self.atol == 0\n            and self.rtol == 0\n        )\n        if has_scale:\n            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω\n\n        def not_converged(r, diff, y):\n            # 
The primary tolerance check.\n            # Given Ay=b, then we have to be doing better than `scale` in both\n            # the `y` and the `b` spaces.\n            if has_scale:\n                with jax.numpy_dtype_promotion(\"standard\"):\n                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω\n                    norm1 = self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore\n                    norm2 = self.norm((diff**ω / y_scale**ω).ω)\n                return (norm1 > 1) | (norm2 > 1)\n            else:\n                return True\n\n        def cond_fun(value):\n            diff, y, r, _, gamma, step = value\n            out = gamma > 0\n            out = out & (step < max_steps)\n            out = out & not_converged(r, diff, y)\n            return out\n\n        def body_fun(value):\n            _, y, r, p, gamma, step = value\n            mat_p = operator.mv(p)\n            inner_prod = tree_dot(mat_p, p)\n            alpha = gamma / inner_prod\n            alpha = tree_where(\n                jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma),  # pyright: ignore\n                alpha,\n                jnp.nan,  # pyright: ignore\n            )\n            diff = (alpha * p**ω).ω\n            y = (y**ω + diff**ω).ω\n            step = step + 1\n\n            # E.g. 
see B.2 of\n            # https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf\n            # We compute the residual the \"expensive\" way every now and again, so as to\n            # correct numerical rounding errors.\n            def stable_r():\n                return (vector**ω - operator.mv(y) ** ω).ω\n\n            def cheap_r():\n                return (r**ω - alpha * mat_p**ω).ω\n\n            if self.stabilise_every == 1:\n                r = stable_r()\n            elif self.stabilise_every is None:\n                r = cheap_r()\n            else:\n                stable_step = (eqxi.unvmap_max(step) % self.stabilise_every) == 0\n                stable_step = eqxi.nonbatchable(stable_step)\n                r = lax.cond(stable_step, stable_r, cheap_r)\n\n            z = preconditioner.mv(r)\n            gamma_prev = gamma\n            gamma = tree_dot(z, r)\n            beta = gamma / gamma_prev\n            p = (z**ω + beta * p**ω).ω\n            return diff, y, r, p, gamma, step\n\n        _, solution, _, _, _, num_steps = lax.while_loop(\n            cond_fun, body_fun, initial_value\n        )\n\n        if self.max_steps is None:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.singular, RESULTS.successful\n            )\n        elif has_scale:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful\n            )\n        else:\n            result = RESULTS.successful\n\n        if is_nsd:\n            solution = -(solution**ω).ω\n        stats = {\"num_steps\": num_steps, \"max_steps\": self.max_steps}\n        return solution, result, stats\n\n    def transpose(\n        self, state: _CGState, options: dict[str, Any]\n    ) -> tuple[_CGState, dict[str, Any]]:\n        transpose_options = {}\n        if \"preconditioner\" in options:\n            transpose_options[\"preconditioner\"] = options[\"preconditioner\"].transpose()\n  
      psd_op, is_nsd = state\n        transpose_state = psd_op.transpose(), is_nsd\n        return transpose_state, transpose_options\n\n    def conj(\n        self, state: _CGState, options: dict[str, Any]\n    ) -> tuple[_CGState, dict[str, Any]]:\n        conj_options = {}\n        if \"preconditioner\" in options:\n            conj_options[\"preconditioner\"] = conj(options[\"preconditioner\"])\n        psd_op, is_nsd = state\n        conj_state = conj(psd_op), is_nsd\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return True\n\n\nCG.__init__.__doc__ = r\"\"\"**Arguments:**\n\n- `rtol`: Relative tolerance for terminating solve.\n- `atol`: Absolute tolerance for terminating solve.\n- `norm`: The norm to use when computing whether the error falls within the tolerance.\n    Defaults to the max norm.\n- `stabilise_every`: The conjugate gradient is an iterative method that produces\n    candidate solutions $x_1, x_2, \\ldots$, and terminates once $r_i = \\| Ax_i - b \\|$\n    is small enough. For computational efficiency, the values $r_i$ are computed using\n    other internal quantities, and not by directly evaluating the formula above.\n    However, this computation of $r_i$ is susceptible to drift due to limited\n    floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed\n    directly using the formula above, in order to stabilise the computation.\n- `max_steps`: The maximum number of iterations to run the solver for. If more steps\n    than this are required, then the solve is halted with a failure.\n\"\"\"\n\n\ndef NormalCG(*args, **kwargs):\n    \"\"\"Deprecated helper function. Use `lx.Normal(lx.CG(...))` instead.\n\n    !!! warning \"Deprecated\"\n        `NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`.\n        This will be removed in some future version of Lineax.\n    \"\"\"\n    warnings.warn(\n        \"`NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`. 
\"\n        \"This will be removed in some future version of Lineax.\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n    return Normal(CG(*args, **kwargs))\n"
  },
  {
    "path": "lineax/_solver/cholesky.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax.flatten_util as jfu\nimport jax.scipy as jsp\nfrom jaxtyping import Array, PyTree\n\nfrom .._operator import (\n    AbstractLinearOperator,\n    is_negative_semidefinite,\n    is_positive_semidefinite,\n)\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\n\n\n_CholeskyState: TypeAlias = tuple[Array, eqxi.Static]\n\n\nclass Cholesky(AbstractLinearSolver[_CholeskyState]):\n    \"\"\"Cholesky solver for linear systems. 
This is generally the preferred solver for\n    positive or negative definite systems.\n\n    Equivalent to `scipy.linalg.solve(..., assume_a=\"pos\")`.\n\n    The operator must be square, nonsingular, and either positive or negative definite.\n    \"\"\"\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        is_nsd = is_negative_semidefinite(operator)\n        if not (is_positive_semidefinite(operator) | is_nsd):\n            raise ValueError(\n                \"`Cholesky(..., normal=False)` may only be used for positive \"\n                \"or negative definite linear operators\"\n            )\n        matrix = operator.as_matrix()\n        m, n = matrix.shape\n        if m != n:\n            raise ValueError(\n                \"`Cholesky(..., normal=False)` may only be used for linear solves \"\n                \"with square matrices\"\n            )\n        if is_nsd:\n            matrix = -matrix\n        factor, lower = jsp.linalg.cho_factor(matrix)\n        # Fix upper triangular for simplicity.\n        assert lower is False\n        return factor, eqxi.Static(is_nsd)\n\n    def compute(\n        self, state: _CholeskyState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        factor, is_nsd = state\n        is_nsd = is_nsd.value\n        del options\n        # Cholesky => PSD => symmetric => (in_structure == out_structure) =>\n        # we don't need to use packed structures.\n        vector, unflatten = jfu.ravel_pytree(vector)\n        solution = jsp.linalg.cho_solve((factor, False), vector)\n        if is_nsd:\n            solution = -solution\n        solution = unflatten(solution)\n        return solution, RESULTS.successful, {}\n\n    def transpose(\n        self, state: _CholeskyState, options: dict[str, Any]\n    ) -> tuple[_CholeskyState, dict[str, Any]]:\n        # Matrix is self-adjoint\n        factor, is_nsd = state\n        
return (factor.conj(), is_nsd), options\n\n    def conj(\n        self, state: _CholeskyState, options: dict[str, Any]\n    ) -> tuple[_CholeskyState, dict[str, Any]]:\n        # Matrix is self-adjoint\n        factor, is_nsd = state\n        return (factor.conj(), is_nsd), options\n\n    def assume_full_rank(self):\n        return True\n\n\nCholesky.__init__.__doc__ = \"\"\"**Arguments:**\n\nNothing.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/diagonal.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport jax.numpy as jnp\nfrom jaxtyping import Array, PyTree\n\nfrom .._misc import resolve_rcond\nfrom .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_DiagonalState: TypeAlias = tuple[Array | None, PackedStructures]\n\n\nclass Diagonal(AbstractLinearSolver[_DiagonalState]):\n    \"\"\"Diagonal solver for linear systems.\n\n    Requires that the operator be diagonal. Then $Ax = b$, with $A = diag[a]$, is\n    solved simply by doing an elementwise division $x = b / a$.\n\n    This solver can handle singular operators (i.e. 
diagonal entries with value 0).\n    \"\"\"\n\n    well_posed: bool = False\n    rcond: float | None = None\n\n    def init(\n        self, operator: AbstractLinearOperator, options: dict[str, Any]\n    ) -> _DiagonalState:\n        del options\n        if operator.in_size() != operator.out_size():\n            raise ValueError(\n                \"`Diagonal` may only be used for linear solves with square matrices\"\n            )\n        if not is_diagonal(operator):\n            raise ValueError(\n                \"`Diagonal` may only be used for linear solves with diagonal matrices\"\n            )\n        packed_structures = pack_structures(operator)\n        if has_unit_diagonal(operator):\n            return None, packed_structures\n        else:\n            return diagonal(operator), packed_structures\n\n    def compute(\n        self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        diag, packed_structures = state\n        del state, options\n        unit_diagonal = diag is None\n        vector = ravel_vector(vector, packed_structures)\n        if unit_diagonal:\n            solution = vector\n        else:\n            if not self.well_posed:\n                (size,) = diag.shape\n                rcond = resolve_rcond(self.rcond, size, size, diag.dtype)\n                abs_diag = jnp.abs(diag)\n                diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf)  # pyright: ignore\n            solution = vector / diag\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {}\n\n    def transpose(self, state: _DiagonalState, options: dict[str, Any]):\n        del options\n        diag, packed_structures = state\n        transposed_packed_structures = transpose_packed_structures(packed_structures)\n        transpose_state = diag, transposed_packed_structures\n        transpose_options = {}\n       
 return transpose_state, transpose_options\n\n    def conj(self, state: _DiagonalState, options: dict[str, Any]):\n        del options\n        diag, packed_structures = state\n        if diag is None:\n            conj_diag = None\n        else:\n            conj_diag = diag.conj()\n        conj_options = {}\n        conj_state = conj_diag, packed_structures\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return self.well_posed\n\n\nDiagonal.__init__.__doc__ = \"\"\"**Arguments**:\n\n- `well_posed`: if `False`, then singular operators are accepted, and the pseudoinverse\n    solution is returned. If `True` then passing a singular operator will cause an error\n    to be raised instead.\n- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine\n    precision times `N`, where `N` is the input (or output) size of the operator.\n    Only used if `well_posed=False`\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/gmres.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\nfrom collections.abc import Callable\nfrom typing import Any, cast, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jaxtyping import Array, ArrayLike, Bool, Float, Inexact, PyTree\n\nfrom .._misc import structure_equal\nfrom .._norm import max_norm, two_norm\nfrom .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver, linear_solve\nfrom .misc import preconditioner_and_y0\nfrom .qr import QR\n\n\n_GMRESState: TypeAlias = AbstractLinearOperator\n\n\nclass GMRES(AbstractLinearSolver[_GMRESState]):\n    \"\"\"GMRES solver for linear systems.\n\n    The operator should be square.\n\n    Similar to `jax.scipy.sparse.linalg.gmres`.\n\n    This supports the following `options` (as passed to\n    `lx.linear_solve(..., options=...)`).\n\n    - `preconditioner`: A [`lineax.AbstractLinearOperator`][]\n        to be used as preconditioner. Defaults to\n        [`lineax.IdentityLinearOperator`][]. 
This method uses left preconditioning,\n        so it is the preconditioned residual that is minimized, though the actual\n        termination criteria uses the un-preconditioned residual.\n    - `y0`: The initial estimate of the solution to the linear system. Defaults to all\n        zeros.\n    \"\"\"\n\n    rtol: float\n    atol: float\n    norm: Callable = max_norm\n    max_steps: int | None = None\n    restart: int = 20\n    stagnation_iters: int = 20\n\n    def __check_init__(self):\n        if isinstance(self.rtol, (int, float)) and self.rtol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n        if isinstance(self.atol, (int, float)) and self.atol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n\n        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):\n            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:\n                raise ValueError(\n                    \"Must specify `rtol`, `atol`, or `max_steps` (or some combination \"\n                    \"of all three).\"\n                )\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        if not structure_equal(operator.in_structure(), operator.out_structure()):\n            raise ValueError(\n                \"`GMRES(..., normal=False)` may only be used for linear solves with \"\n                \"square matrices.\"\n            )\n        return linearise(operator)\n\n    #\n    # This differs from `jax.scipy.sparse.linalg.gmres` in a few ways:\n    # 1. We use a more sophisticated termination condition. To begin with we have an\n    #    rtol and atol in the conventional way, inducing a vector-valued scale. This is\n    #    then checked in both the `y` and `b` domains (for `Ay = b`).\n    # 2. We handle in-place updates with buffers to avoid generating unnecessary\n    #    copies of arrays during the Gram-Schmidt procedure.\n    # 3. 
We use a QR solve at the end of the batched Gram-Schmidt instead\n    #    of a Cholesky solve of the normal equations. This is both faster and more\n    #    numerically stable.\n    # 4. We use tricks to compile `A y` fewer times throughout the code, including\n    #    passing a dummy initial residual.\n    # 5. We return the number of steps, and whether or not the solve succeeded, as\n    #    additional information.\n    # 6. We do not use the unnecessary loop within Gram-Schmidt, and simply compute\n    #    this in a single pass.\n    # 7. We add better safety checks for breakdown, and a safety check for stagnation\n    #    of the iterates even when we don't explicitly get breakdown.\n    #\n    def compute(\n        self,\n        state: _GMRESState,\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        has_scale = not (\n            isinstance(self.atol, (int, float))\n            and isinstance(self.rtol, (int, float))\n            and self.atol == 0\n            and self.rtol == 0\n        )\n        if has_scale:\n            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω\n        operator = state\n        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)\n        leaves, _ = jtu.tree_flatten(vector)\n        size = sum(leaf.size for leaf in leaves)\n        if self.max_steps is None:\n            max_steps = 10 * size  # Copied from SciPy!\n        else:\n            max_steps = self.max_steps\n        restart = min(self.restart, size)\n\n        def not_converged(r, diff, y):\n            # The primary tolerance check.\n            # Given Ay=b, then we have to be doing better than `scale` in both\n            # the `y` and the `b` spaces.\n            if has_scale:\n                with jax.numpy_dtype_promotion(\"standard\"):\n                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω\n                    norm1 = 
self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore\n                    norm2 = self.norm((diff**ω / y_scale**ω).ω)\n                return (norm1 > 1) | (norm2 > 1)\n            else:\n                return True\n\n        def cond_fun(carry):\n            y, r, _, deferred_breakdown, diff, _, step, stagnation_counter = carry\n            # NOTE: we defer ending due to breakdown by one loop! This is nonstandard,\n            # but lets us use a cauchy-like condition in the convergence criteria.\n            # If we do not defer breakdown, breakdown may detect convergence when\n            # the diff between two iterations is still quite large, and we only\n            # consider convergence when the diff is small.\n            out = jnp.invert(deferred_breakdown) & (\n                stagnation_counter < self.stagnation_iters\n            )\n            out = out & not_converged(r, diff, y)\n            out = out & (step < max_steps)\n            # The first pass uses a dummy value for r0 in order to save on compiling\n            # an extra matvec. The dummy step may raise a breakdown, and `step == 0`\n            # avoids us from returning prematurely.\n            return out | (step == 0)\n\n        def body_fun(carry):\n            # `breakdown` -> `deferred_breakdown` and `deferred_breakdown` -> `_`\n            y, r, deferred_breakdown, _, diff, r_min, step, stagnation_counter = carry\n            y_new, r_new, breakdown, diff_new = self._gmres_compute(\n                operator, vector, y, r, restart, preconditioner, step == 0\n            )\n\n            #\n            # If the minimum residual does not decrease for many iterations\n            # (\"many\" is determined by self.stagnation_iters) then the iterative\n            # solve has stagnated and we stop the loop. This bit keeps track of how\n            # long it has been since the minimum has decreased, and updates the minimum\n            # when a new minimum is encountered. 
As far as I (raderj) am\n            # aware, this is custom to our implementation and not standard practice.\n            #\n            r_new_norm = self.norm(r_new)\n            r_decreased = (r_new_norm - r_min) < 0\n            stagnation_counter = jnp.where(r_decreased, 0, stagnation_counter + 1)\n            stagnation_counter = cast(Array, stagnation_counter)\n            r_min = jnp.minimum(r_new_norm, r_min)\n\n            return (\n                y_new,\n                r_new,\n                breakdown,\n                deferred_breakdown,\n                diff_new,\n                r_min,\n                step + 1,\n                stagnation_counter,\n            )\n\n        # Initialise the residual r0 to the dummy value of all 0s. This means\n        # the first iteration of Gram-Schmidt will do nothing, but it saves\n        # us from compiling an extra matvec here.\n        r0 = ω(vector).call(jnp.zeros_like).ω\n        init_carry = (\n            y0,  # y\n            r0,  # residual\n            False,  # breakdown\n            False,  # deferred_breakdown\n            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,  # diff\n            jnp.inf,  # r_min\n            0,  # steps\n            jnp.array(0),  # stagnation counter\n        )\n        (\n            solution,\n            residual,\n            _,  # breakdown\n            breakdown,  # deferred_breakdown\n            diff,\n            _,\n            num_steps,\n            stagnation_counter,\n        ) = lax.while_loop(cond_fun, body_fun, init_carry)\n\n        if self.max_steps is None:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.singular, RESULTS.successful\n            )\n        elif has_scale:\n            result = RESULTS.where(\n                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful\n            )\n        else:\n            result = RESULTS.successful\n\n        result = RESULTS.where(\n         
   stagnation_counter >= self.stagnation_iters, RESULTS.stagnation, result\n        )\n\n        # breakdown is only an issue if we broke down outside the tolerance\n        # of the solution. If we get breakdown and are within the tolerance,\n        # this is called convergence :)\n        breakdown = breakdown & not_converged(residual, diff, solution)\n        # breakdown is the most serious potential issue\n        result = RESULTS.where(breakdown, RESULTS.breakdown, result)\n\n        stats = {\"num_steps\": num_steps, \"max_steps\": self.max_steps}\n        return solution, result, stats\n\n    def _gmres_compute(\n        self, operator, vector, y, r, restart, preconditioner, first_pass\n    ):\n        #\n        # internal function for computing the bulk of the gmres. We seperate this out\n        # for two reasons:\n        # 1. avoid nested body and cond functions in the body and cond function of\n        # `self.compute`. `self.compute` is primarily responsible for the restart\n        # behavior of gmres.\n        # 2. 
Like the jax.scipy implementation we may want to add an incremental\n        # version at a later date.\n        #\n\n        def main_gmres(y):\n            # see the comment at the end of `_arnoldi_gram_schmidt` for a discussion\n            # of `initial_breakdown`\n            r_normalised, r_norm, initial_breakdown = self._normalise(r, eps=None)\n            basis_init = jtu.tree_map(\n                lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),\n                r_normalised,\n            )\n            coeff_mat_init = jnp.eye(\n                restart,\n                restart + 1,\n                dtype=jnp.result_type(*jtu.tree_leaves(r_normalised)),\n            )\n\n            def cond_fun(carry):\n                _, _, breakdown, step = carry\n                return (step < restart) & jnp.invert(breakdown)\n\n            def body_fun(carry):\n                basis, coeff_mat, breakdown, step = carry\n                basis_new, coeff_mat_new, breakdown = self._arnoldi_gram_schmidt(\n                    operator,\n                    preconditioner,\n                    basis,\n                    coeff_mat,\n                    step,\n                    restart,\n                    vector,\n                    breakdown,\n                )\n                return basis_new, coeff_mat_new, breakdown, step + 1\n\n            def buffers(carry):\n                basis, coeff_mat, _, _ = carry\n                return basis, coeff_mat\n\n            init_carry = (basis_init, coeff_mat_init, initial_breakdown, 0)\n            basis, coeff_mat, breakdown, steps = eqxi.while_loop(\n                cond_fun, body_fun, init_carry, kind=\"lax\", buffers=buffers\n            )\n            beta_vec = jnp.concatenate(\n                (\n                    r_norm[None].astype(jnp.result_type(coeff_mat)),\n                    jnp.zeros_like(coeff_mat, shape=(restart,)),\n                )\n            )\n            coeff_op_transpose = 
MatrixLinearOperator(coeff_mat.T)\n            # TODO(raderj): move to a Hessenberg-specific solver\n            z = linear_solve(coeff_op_transpose, beta_vec, QR(), throw=False).value\n            diff = jtu.tree_map(\n                lambda mat: jnp.tensordot(\n                    mat[..., :-1], z, axes=1, precision=lax.Precision.HIGHEST\n                ),\n                basis,\n            )\n            y_new = (y**ω + diff**ω).ω\n            return y_new, diff, breakdown\n\n        def first_gmres(y):\n            return y, ω(y).call(lambda x: jnp.full_like(x, jnp.inf)).ω, False\n\n        first_pass = eqxi.unvmap_any(first_pass)\n        y_new, diff, breakdown = lax.cond(first_pass, first_gmres, main_gmres, y)\n        r_new = preconditioner.mv((vector**ω - operator.mv(y_new) ** ω).ω)\n\n        return y_new, r_new, breakdown, diff\n\n        # NOTE: in the jax implementation:\n        # https://github.com/google/jax/blob/\n        # c662fd216dec10cdb2cff4138b4318bb98853134/jax/_src/scipy/sparse/linalg.py#L327\n        # _classical_iterative_gram_schmidt uses a while loop to call this.\n        # However, max_iterations is set to 2 in all calls they make to the function,\n        # and the condition function requires steps < (max_iterations - 1).\n        # This means that in fact they only apply Gram-Schmidt once, and using a\n        # while_loop is unnecessary.\n\n    def _arnoldi_gram_schmidt(\n        self,\n        operator,\n        preconditioner,\n        basis,\n        coeff_mat,\n        step,\n        restart,\n        vector,\n        initial_breakdown,\n    ):\n        #\n        # compute `basis.T @ basis_step` for each leaf of pytree\n        # and then compute the projected vector onto the basis\n        #\n        # `basis` is a pytree with buffers, meaning it can only be\n        # indexed into. 
Through this section, there are terms like `lambda _, x: ...`\n        # because`jtu.tree_map` only uses the first argument to determine the shape\n        # of the pytree. Since _Buffer is considered part of the pytree\n        # structure, we get leaves which are not buffers if we directly pass `basis`.\n        # Instead, we make sure that the first argument of the tree map is something\n        # with the correct pytree structure, such as `vector` in the dummy case and\n        # basis_step when not, so that we correctly index into `basis`.\n        #\n        basis_step = preconditioner.mv(\n            operator.mv(jtu.tree_map(lambda _, x: x[..., step], vector, basis))\n        )\n        step_norm = two_norm(basis_step)\n        contract_matrix = lambda x, y: ft.partial(\n            jnp.tensordot, axes=x.ndim, precision=lax.Precision.HIGHEST\n        )(x, y[...].conj())\n        _proj = jtu.tree_map(contract_matrix, basis_step, basis)\n        proj = jtu.tree_reduce(lambda x, y: x + y, _proj)\n        proj_on_cols = jtu.tree_map(lambda _, x: x[...] 
@ proj, vector, basis)\n        # now remove the component of the vector in that subspace\n        basis_step_new = (basis_step**ω - proj_on_cols**ω).ω\n        eps = step_norm * jnp.finfo(proj.dtype).eps\n        basis_step_normalised, step_norm_new, breakdown = self._normalise(\n            basis_step_new, eps=eps\n        )\n        basis_new = jtu.tree_map(\n            lambda y, mat: mat.at[..., step + 1].set(y),\n            basis_step_normalised,\n            basis,\n        )\n        proj_new = proj.at[step + 1].set(step_norm_new.astype(jnp.result_type(proj)))\n        #\n        # NOTE: two somewhat complicated things are going on here:\n        #\n        # The `coeff_mat` in_place update has a batch tracer, so we need to be\n        # careful and wrap it in a buffer, hence the use of eqxi.while_loop\n        # instead of lax.while_loop throughout.\n        #\n        # `initial_breakdown` occurs when the previous loop returns a\n        # residual which is small enough to be interpreted as 0 by self._normalise,\n        # but which was passed through the solver anyway. This occurs when\n        # the residual is small but the diff is not, or if the\n        # correct solution was given to GMRES from the start. Both of these tend to\n        # happen at the start of `gmres_compute`.\n        # The latter may happen when using a sequence of iterative methods.\n        # If `initial_breakdown` occurs, then we leave the `coeff_mat` as it was\n        # at initialisation. 
Replacing it with the projection (which will be all 0s)\n        # will mean `coeff_mat` is not full-rank, and `QR` can only handle nonsquare\n        # matrices of full-rank.\n        #\n        coeff_mat_new = coeff_mat.at[step, :].set(\n            proj_new, pred=jnp.invert(initial_breakdown)\n        )\n        return basis_new, coeff_mat_new, breakdown\n\n    def _normalise(\n        self, x: PyTree[Array], eps: Float[ArrayLike, \"\"] | None\n    ) -> tuple[PyTree[Array], Inexact[Array, \"\"], Bool[ArrayLike, \"\"]]:\n        norm = two_norm(x)\n        if eps is None:\n            eps = jnp.finfo(norm.dtype).eps\n        else:\n            eps = jnp.astype(eps, norm.dtype)\n        breakdown = norm < eps  # pyright: ignore\n        safe_norm = jnp.where(breakdown, jnp.inf, norm)\n        with jax.numpy_dtype_promotion(\"standard\"):\n            x_normalised = (x**ω / safe_norm).ω\n        return x_normalised, norm, breakdown\n\n    def transpose(self, state: _GMRESState, options: dict[str, Any]):\n        transpose_options = {}\n        if \"preconditioner\" in options:\n            transpose_options[\"preconditioner\"] = options[\"preconditioner\"].transpose()\n        operator = state\n        return operator.transpose(), transpose_options\n\n    def conj(self, state: _GMRESState, options: dict[str, Any]):\n        conj_options = {}\n        if \"preconditioner\" in options:\n            conj_options[\"preconditioner\"] = conj(options[\"preconditioner\"])\n        operator = state\n        return conj(operator), conj_options\n\n    def assume_full_rank(self):\n        return True\n\n\nGMRES.__init__.__doc__ = r\"\"\"**Arguments:**\n\n- `rtol`: Relative tolerance for terminating solve.\n- `atol`: Absolute tolerance for terminating solve.\n- `norm`: The norm to use when computing whether the error falls within the tolerance.\n    Defaults to the max norm.\n- `max_steps`: The maximum number of iterations to run the solver for. 
If more steps\n    than this are required, then the solve is halted with a failure.\n- `restart`: Size of the Krylov subspace built between restarts. The returned solution\n    is the projection of the true solution onto this subpsace, so this direclty\n    bounds the accuracy of the algorithm. Default is 20.\n- `stagnation_iters`: The maximum number of iterations for which the solver may not\n    decrease. If more than `stagnation_iters` restarts are performed without\n    sufficient decrease in the residual, the algorithm is halted.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/lsmr.py",
    "content": "\"\"\"Implementation adapted from SciPy, with BSD license:\n\nCopyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions\nare met:\n\n1. Redistributions of source code must retain the above copyright\n   notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above\n   copyright notice, this list of conditions and the following\n   disclaimer in the documentation and/or other materials provided\n   with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived\n   from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\"\"\"\n\nfrom collections.abc import Callable\nfrom typing import Any, TypeAlias\n\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nfrom equinox.internal import ω\nfrom jaxtyping import Array, PyTree\n\nfrom .._misc import complex_to_real_dtype\nfrom .._norm import two_norm\nfrom .._operator import AbstractLinearOperator, conj, linearise\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\n\n\n_LSMRState: TypeAlias = AbstractLinearOperator\n\n\nclass LSMR(AbstractLinearSolver[_LSMRState]):\n    \"\"\"LSMR solver for linear systems.\n\n    This solver can handle any operator, even nonsquare or singular ones. In these\n    cases it will return the pseudoinverse solution to the linear system.\n\n    Similar to `scipy.sparse.linalg.lsmr`.\n\n    This supports the following `options` (as passed to\n    `lx.linear_solve(..., options=...)`).\n\n    - `y0`: The initial estimate of the solution to the linear system. 
Defaults to all\n        zeros.\n    \"\"\"\n\n    rtol: float\n    atol: float\n    norm: Callable = two_norm\n    max_steps: int | None = None\n    conlim: float = 1e8\n\n    def __check_init__(self):\n        if isinstance(self.rtol, (int, float)) and self.rtol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n        if isinstance(self.atol, (int, float)) and self.atol < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n        if isinstance(self.conlim, (int, float)) and self.conlim < 0:\n            raise ValueError(\"Tolerances must be non-negative.\")\n\n        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):\n            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:\n                raise ValueError(\n                    \"Must specify `atol`, `rtol`, or `max_steps` (or some combination \"\n                    \"of all three).\"\n                )\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        return linearise(operator)\n\n    def compute(\n        self,\n        state: _LSMRState,\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        operator = state\n        x = options.get(\"y0\", None)\n        # damp is not supported at this time.\n        #  damp = options.get(\"damp\", 0.0)\n        damp = 0.0\n        has_scale = not (\n            isinstance(self.atol, (int, float))\n            and isinstance(self.rtol, (int, float))\n            and self.atol == 0\n            and self.rtol == 0\n        )\n\n        dtype = jnp.result_type(\n            *jtu.tree_leaves(vector),\n            *jtu.tree_leaves(x),\n            *jtu.tree_leaves(operator.in_structure()),\n        )\n\n        m, n = operator.out_size(), operator.in_size()\n        # number of singular values\n        min_dim = min([m, n])\n        if self.max_steps is None:\n         
   # Set max_steps based on the minimum dimension + avoid numerical overflows\n            # https://github.com/patrick-kidger/lineax/issues/175\n            # https://github.com/patrick-kidger/lineax/issues/177\n            int_dtype = jnp.dtype(f\"int{complex_to_real_dtype(dtype).itemsize * 8}\")\n            if min_dim > (jnp.iinfo(int_dtype).max / 10):\n                max_steps = jnp.iinfo(int_dtype).max\n            else:\n                max_steps = min_dim * 10  # for consistency with other iterative solvers\n        else:\n            max_steps = self.max_steps\n\n        if x is None:\n            x = jtu.tree_map(jnp.zeros_like, operator.in_structure())\n\n        b = vector\n        u = (ω(b) - ω(operator.mv(x))).ω\n        normb = self.norm(b)\n        beta = self.norm(u)\n\n        def beta_nonzero(beta, u):\n            u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω\n            v = conj(operator).T.mv(u)\n            alpha = self.norm(v)\n            return u, v, alpha\n\n        def beta_zero(beta, u):\n            v = jtu.tree_map(jnp.zeros_like, operator.in_structure())\n            alpha = 0.0\n            return u, v, alpha\n\n        u, v, alpha = lax.cond(beta == 0.0, beta_zero, beta_nonzero, beta, u)\n        v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω\n\n        h = v\n        hbar = jtu.tree_map(jnp.zeros_like, operator.in_structure())\n\n        # Initialize variables for 1st iteration.\n        # generally, latin letters (b, x, u, v, h etc) are vectors that may be complex\n        # greek letters (alpha, beta, rho, zeta etc) are scalars that are always real\n        loop_state = dict(\n            # vectors\n            x=x,\n            u=u,\n            v=v,\n            h=h,\n            hbar=hbar,\n            # main loop variables\n            itn=0,\n            alpha=alpha,\n            beta=beta,\n            zetabar=alpha * beta,\n            alphabar=alpha,\n            rho=1.0,\n          
  rhobar=1.0,\n            cbar=1.0,\n            sbar=0.0,\n            # loop variables for estimation of ||r||.\n            betadd=beta,\n            betad=0.0,\n            rhodold=1.0,\n            tautildeold=0.0,\n            thetatilde=0.0,\n            zeta=0.0,\n            delta=0.0,\n            # variables for estimation of ||A|| and cond(A)\n            normA2=alpha**2,\n            maxrbar=0.0,\n            minrbar=jnp.finfo(dtype).max,\n            condA=1.0,\n            # variables for use in stopping rules\n            istop=0,\n            normr=beta,\n            normAr=alpha * beta,\n        )\n        # beta == 0 means x exactly solves the well posed problem\n        # alpha == 0 means x exactly solves the least squares problem\n        # we check this here to shortcut the loop to avoid division by zero\n        loop_state[\"istop\"] = lax.select(alpha == 0, 2, loop_state[\"istop\"])\n        loop_state[\"istop\"] = lax.select(beta == 0, 1, loop_state[\"istop\"])\n\n        def condfun(loop_state):\n            return loop_state[\"istop\"] == 0\n\n        def bodyfun(loop_state):\n            st = loop_state  # to avoid writing out loop_state every time\n            st[\"itn\"] = st[\"itn\"] + 1\n\n            # Perform the next step of the bidiagonalization to obtain the\n            # next  beta, u, alpha, v.  
These satisfy the relations\n            #         beta*u  =  A@v   -  alpha*u,\n            #        alpha*v  =  A'@u  -  beta*v.\n\n            st[\"u\"] = (ω(st[\"u\"]) * -st[\"alpha\"].astype(dtype)).ω\n            st[\"u\"] = (ω(st[\"u\"]) + ω(operator.mv(st[\"v\"]))).ω\n            st[\"beta\"] = self.norm(st[\"u\"])\n\n            def beta_nonzero(alpha, beta, u, v):\n                u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω\n                v = (ω(v) * -beta.astype(dtype)).ω\n                v = (ω(v) + ω(conj(operator).T.mv(u))).ω\n                alpha = self.norm(v)\n                v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω\n                return alpha, beta, u, v\n\n            def beta_zero(alpha, beta, u, v):\n                return alpha, beta, u, v\n\n            st[\"alpha\"], st[\"beta\"], st[\"u\"], st[\"v\"] = lax.cond(\n                st[\"beta\"] == 0,\n                beta_zero,\n                beta_nonzero,\n                st[\"alpha\"],\n                st[\"beta\"],\n                st[\"u\"],\n                st[\"v\"],\n            )\n            # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.\n\n            # Construct rotation Qhat_{k,2k+1}.\n            chat, shat, alphahat = self._givens(st[\"alphabar\"], damp)\n\n            # Use a plane rotation (Q_i) to turn B_i to R_i\n            rhoold = st[\"rho\"]\n            c, s, st[\"rho\"] = self._givens(alphahat, st[\"beta\"])\n            thetanew = s * st[\"alpha\"]\n            st[\"alphabar\"] = c * st[\"alpha\"]\n\n            # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar\n            rhobarold = st[\"rhobar\"]\n            zetaold = st[\"zeta\"]\n            thetabar = st[\"sbar\"] * st[\"rho\"]\n            rhotemp = st[\"cbar\"] * st[\"rho\"]\n            st[\"cbar\"], st[\"sbar\"], st[\"rhobar\"] = self._givens(\n                st[\"cbar\"] * st[\"rho\"], thetanew\n            )\n            st[\"zeta\"] = 
st[\"cbar\"] * st[\"zetabar\"]\n            st[\"zetabar\"] = -st[\"sbar\"] * st[\"zetabar\"]\n\n            # Update h, h_hat, x.\n            st[\"hbar\"] = (\n                ω(st[\"hbar\"])\n                * -(thetabar * st[\"rho\"] / (rhoold * rhobarold)).astype(dtype)\n            ).ω\n            st[\"hbar\"] = (ω(st[\"hbar\"]) + ω(st[\"h\"])).ω\n            st[\"x\"] = (\n                ω(st[\"x\"])\n                + (st[\"zeta\"] / (st[\"rho\"] * st[\"rhobar\"])).astype(dtype)\n                * ω(st[\"hbar\"])\n            ).ω\n            st[\"h\"] = (ω(st[\"h\"]) * -(thetanew / st[\"rho\"]).astype(dtype)).ω\n            st[\"h\"] = (ω(st[\"h\"]) + ω(st[\"v\"])).ω\n\n            # Estimate of ||r||.\n            # Apply rotation Qhat_{k,2k+1}.\n            betaacute = chat * st[\"betadd\"]\n            betacheck = -shat * st[\"betadd\"]\n            # Apply rotation Q_{k,k+1}.\n            betahat = c * betaacute\n            st[\"betadd\"] = -s * betaacute\n\n            # Apply rotation Qtilde_{k-1}.\n            # betad = betad_{k-1} here.\n            thetatildeold = st[\"thetatilde\"]\n            ctildeold, stildeold, rhotildeold = self._givens(st[\"rhodold\"], thetabar)\n            st[\"thetatilde\"] = stildeold * loop_state[\"rhobar\"]\n            st[\"rhodold\"] = ctildeold * st[\"rhobar\"]\n            st[\"betad\"] = -stildeold * st[\"betad\"] + ctildeold * betahat\n\n            # betad   = betad_k here.\n            # rhodold = rhod_k  here.\n\n            loop_state[\"tautildeold\"] = (\n                zetaold - thetatildeold * st[\"tautildeold\"]\n            ) / rhotildeold\n            taud = (st[\"zeta\"] - st[\"thetatilde\"] * st[\"tautildeold\"]) / st[\"rhodold\"]\n            st[\"delta\"] = st[\"delta\"] + betacheck**2\n            st[\"normr\"] = jnp.sqrt(\n                st[\"delta\"] + (st[\"betad\"] - taud) ** 2 + st[\"betadd\"] ** 2\n            )\n\n            # Estimate ||A||.\n            st[\"normA2\"] = 
st[\"normA2\"] + st[\"beta\"] ** 2\n            normA = jnp.sqrt(st[\"normA2\"])\n            st[\"normA2\"] = st[\"normA2\"] + st[\"alpha\"] ** 2\n\n            # Estimate cond(A).\n            st[\"maxrbar\"] = jnp.maximum(st[\"maxrbar\"], rhobarold)\n            st[\"minrbar\"] = lax.select(\n                st[\"itn\"] > 1, jnp.minimum(st[\"minrbar\"], rhobarold), st[\"minrbar\"]\n            )\n            st[\"condA\"] = jnp.maximum(st[\"maxrbar\"], rhotemp) / jnp.minimum(\n                st[\"minrbar\"], rhotemp\n            )\n\n            # Compute norms for convergence testing.\n            st[\"normAr\"] = jnp.abs(st[\"zetabar\"])\n            normx = self.norm(st[\"x\"])\n\n            well_posed_tol = self.atol + self.rtol * (normA * normx + normb)\n            least_squares_tol = self.atol + self.rtol * (normA * st[\"normr\"])\n            # maxiter exceeded\n            st[\"istop\"] = lax.select(st[\"itn\"] >= max_steps, 4, st[\"istop\"])\n            # cond(A) seems to be greater than conlim\n            st[\"istop\"] = lax.select(st[\"condA\"] > self.conlim, 3, st[\"istop\"])\n            # x solves the least-squares problem according to atol and rtol.\n            st[\"istop\"] = lax.select(st[\"normAr\"] < least_squares_tol, 2, st[\"istop\"])\n            # x is a solution to A@x = b, according to atol and rtol.\n            st[\"istop\"] = lax.select(st[\"normr\"] < well_posed_tol, 1, st[\"istop\"])\n            return st\n\n        loop_state = lax.while_loop(condfun, bodyfun, loop_state)\n\n        stats = {\n            \"num_steps\": loop_state[\"itn\"],\n            \"istop\": loop_state[\"istop\"],\n            \"norm_r\": loop_state[\"normr\"],\n            \"norm_Ar\": loop_state[\"normAr\"],\n            \"norm_A\": jnp.sqrt(loop_state[\"normA2\"]),\n            \"cond_A\": loop_state[\"condA\"],\n            \"norm_x\": self.norm(loop_state[\"x\"]),\n        }\n\n        if self.max_steps is None:\n            result = 
RESULTS.where(\n                loop_state[\"itn\"] == max_steps, RESULTS.singular, RESULTS.successful\n            )\n        elif has_scale:\n            result = RESULTS.where(\n                loop_state[\"itn\"] == max_steps,\n                RESULTS.max_steps_reached,\n                RESULTS.successful,\n            )\n        else:\n            result = RESULTS.successful\n        result = RESULTS.where(loop_state[\"istop\"] < 3, RESULTS.successful, result)\n        result = RESULTS.where(loop_state[\"istop\"] == 3, RESULTS.conlim, result)\n\n        return loop_state[\"x\"], result, stats\n\n    def _givens(self, a, b):\n        \"\"\"Stable implementation of Givens rotation, from [1]_\n\n        finds c, s, r such that\n\n        |c  -s|[a| = |r|\n        [s   c|[b|   |0|\n\n        r = sqrt(a^2 + b^2)\n\n        Assumes a, b are real.\n\n        References\n        ----------\n        .. [1] S.-C. Choi, \"Iterative Methods for Singular Linear Equations\n            and Least-Squares Problems\", Dissertation,\n            http://www.stanford.edu/group/SOL/dissertations/sou-cheng-choi-thesis.pdf\n\n        \"\"\"\n        assert not jnp.iscomplexobj(a)\n        assert not jnp.iscomplexobj(b)\n\n        def bzero(a, b):\n            return jnp.sign(a), 0.0, jnp.abs(a)\n\n        def azero(a, b):\n            return 0.0, jnp.sign(b), jnp.abs(b)\n\n        def b_gt_a(a, b):\n            tau = a / lax.select(b == 0.0, 1.0, b)\n            s = jnp.sign(b) / jnp.sqrt(1.0 + tau**2)\n            c = s * tau\n            r = b / lax.select(s == 0.0, 1.0, s)\n            return c, s, r\n\n        def a_ge_b(a, b):\n            tau = b / lax.select(a == 0.0, 1.0, a)\n            c = jnp.sign(a) / jnp.sqrt(1.0 + tau**2)\n            s = c * tau\n            r = a / lax.select(c == 0.0, 1.0, c)\n            return c, s, r\n\n        def either_zero(a, b):\n            return lax.cond(b == 0.0, bzero, azero, a, b)\n\n        def both_nonzero(a, b):\n            return 
lax.cond(jnp.abs(b) > jnp.abs(a), b_gt_a, a_ge_b, a, b)\n\n        return lax.cond((a == 0.0) | (b == 0.0), either_zero, both_nonzero, a, b)\n\n    def transpose(self, state: _LSMRState, options: dict[str, Any]):\n        del options\n        operator = state\n        transpose_options = {}\n        return operator.transpose(), transpose_options\n\n    def conj(self, state: _LSMRState, options: dict[str, Any]):\n        del options\n        operator = state\n        conj_options = {}\n        return conj(operator), conj_options\n\n    def assume_full_rank(self):\n        return False\n\n\nLSMR.__init__.__doc__ = r\"\"\"**Arguments:**\n\n- `rtol`: Relative tolerance for terminating solve.\n- `atol`: Absolute tolerance for terminating solve.\n- `norm`: The norm to use when computing whether the error falls within the tolerance.\n    Defaults to the two norm.\n- `max_steps`: The maximum number of iterations to run the solver for. If more steps\n    than this are required, then the solve is halted with a failure.\n- `conlim`: The solver terminates if an estimate of cond(A) exceeds conlim. For\n    compatible systems Ax = b, conlim could be as large as 1.0e+12 (say). For\n    least-squares problems, conlim should be less than 1.0e+8. If conlim is None,\n    the default value is 1e+8. Maximum precision can be obtained by setting\n    atol = rtol = 0, conlim = np.inf, but the number of iterations may then be\n    excessive. Default is 1e8.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/lu.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax.numpy as jnp\nimport jax.scipy as jsp\nfrom jaxtyping import Array, PyTree\n\nfrom .._operator import AbstractLinearOperator, is_diagonal\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, eqxi.Static]\n\n\nclass LU(AbstractLinearSolver[_LUState]):\n    \"\"\"LU solver for linear systems.\n\n    This solver can only handle square nonsingular operators.\n    \"\"\"\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        if operator.in_size() != operator.out_size():\n            raise ValueError(\n                \"`LU` may only be used for linear solves with square matrices\"\n            )\n        packed_structures = pack_structures(operator)\n        if is_diagonal(operator):\n            lu = operator.as_matrix(), jnp.arange(operator.in_size(), dtype=jnp.int32)\n        else:\n            lu = jsp.linalg.lu_factor(operator.as_matrix())\n        return lu, packed_structures, eqxi.Static(False)\n\n    def compute(\n        self, state: _LUState, vector: PyTree[Array], options: dict[str, Any]\n  
  ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        del options\n        lu_and_piv, packed_structures, transpose = state\n        transpose = transpose.value\n        trans = 1 if transpose else 0\n        vector = ravel_vector(vector, packed_structures)\n        solution = jsp.linalg.lu_solve(lu_and_piv, vector, trans=trans)\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {}\n\n    def transpose(\n        self,\n        state: _LUState,\n        options: dict[str, Any],\n    ):\n        lu_and_piv, packed_structures, transpose = state\n        transposed_packed_structures = transpose_packed_structures(packed_structures)\n        transpose_state = (\n            lu_and_piv,\n            transposed_packed_structures,\n            eqxi.Static(not transpose.value),\n        )\n        transpose_options = {}\n        return transpose_state, transpose_options\n\n    def conj(\n        self,\n        state: _LUState,\n        options: dict[str, Any],\n    ):\n        (lu, piv), packed_structures, transpose = state\n        conj_state = (\n            (lu.conj(), piv),\n            packed_structures,\n            eqxi.Static(not transpose.value),\n        )\n        conj_options = {}\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return True\n\n\nLU.__init__.__doc__ = \"\"\"**Arguments:**\n\nNothing.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/misc.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nimport typing\nimport warnings\nfrom typing import Any, NewType, TYPE_CHECKING\n\nimport equinox.internal as eqxi\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\nimport numpy as np\nfrom jaxtyping import Array, PyTree, Shaped\n\nfrom .._misc import strip_weak_dtype, structure_equal\nfrom .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise\n\n\ndef preconditioner_and_y0(\n    operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any]\n):\n    structure = operator.in_structure()\n    try:\n        preconditioner = linearise(options[\"preconditioner\"])\n    except KeyError:\n        preconditioner = IdentityLinearOperator(structure)\n    else:\n        if not isinstance(preconditioner, AbstractLinearOperator):\n            raise ValueError(\"The preconditioner must be a linear operator.\")\n        if not structure_equal(preconditioner.in_structure(), structure):\n            raise ValueError(\n                \"The preconditioner must have `in_structure` that matches the \"\n                \"operator's `in_structure`.\"\n            )\n        if not structure_equal(preconditioner.out_structure(), structure):\n            raise ValueError(\n                \"The preconditioner must have `out_structure` that matches the \"\n                \"operator's `in_structure`.\"\n            )\n    try:\n        
y0 = options[\"y0\"]\n    except KeyError:\n        y0 = jtu.tree_map(jnp.zeros_like, vector)\n    else:\n        if not structure_equal(y0, vector):\n            raise ValueError(\n                \"`y0` must have the same structure, shape, and dtype as `vector`\"\n            )\n    return preconditioner, y0\n\n\n# This seems to introduce some spurious failure at docgen time.\nif hasattr(typing, \"GENERATING_DOCUMENTATION\") and not TYPE_CHECKING:\n    PackedStructures = lambda x: x\nelse:\n    PackedStructures = NewType(\"PackedStructures\", eqxi.Static)\n\n\ndef pack_structures(operator: AbstractLinearOperator) -> PackedStructures:\n    structures = (\n        strip_weak_dtype(operator.out_structure()),\n        strip_weak_dtype(operator.in_structure()),\n    )\n    leaves, treedef = jtu.tree_flatten(structures)  # handle nonhashable pytrees\n    return PackedStructures(eqxi.Static((leaves, treedef)))\n\n\ndef ravel_vector(\n    pytree: PyTree[Array], packed_structures: PackedStructures\n) -> Shaped[Array, \" size\"]:\n    leaves, treedef = packed_structures.value\n    out_structure, _ = jtu.tree_unflatten(treedef, leaves)\n    # `is` in case `tree_equal` returns a Tracer.\n    if not structure_equal(pytree, out_structure):\n        raise ValueError(\"pytree does not match out_structure\")\n    # not using `ravel_pytree` as that doesn't come with guarantees about order\n    leaves = jtu.tree_leaves(pytree)\n    dtype = jnp.result_type(*leaves)\n    return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])\n\n\ndef unravel_solution(\n    solution: Shaped[Array, \" size\"], packed_structures: PackedStructures\n) -> PyTree[Array]:\n    leaves, treedef = packed_structures.value\n    _, in_structure = jtu.tree_unflatten(treedef, leaves)\n    leaves, treedef = jtu.tree_flatten(in_structure)\n    sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])\n    split = jnp.split(solution, sizes)\n    assert len(split) == len(leaves)\n    with 
warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")  # ignore complex-to-real cast warning\n        shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)]\n    return jtu.tree_unflatten(treedef, shaped)\n\n\ndef transpose_packed_structures(\n    packed_structures: PackedStructures,\n) -> PackedStructures:\n    leaves, treedef = packed_structures.value\n    out_structure, in_structure = jtu.tree_unflatten(treedef, leaves)\n    leaves, treedef = jtu.tree_flatten((in_structure, out_structure))\n    return PackedStructures(eqxi.Static((leaves, treedef)))\n"
  },
  {
    "path": "lineax/_solver/normal.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom copy import copy\nfrom typing import Any, TypeVar\n\nimport equinox.internal as eqxi\nfrom jaxtyping import Array, PyTree\n\nfrom .._operator import conj, linearise, materialise, TaggedLinearOperator\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearOperator, AbstractLinearSolver\nfrom .._tags import positive_semidefinite_tag\nfrom .cholesky import Cholesky\n\n\n_InnerSolverState = TypeVar(\"_InnerSolverState\")\n\n\ndef normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):\n    preconditioner = options.get(\"preconditioner\")\n    y0 = options.get(\"y0\")\n    inner_options = copy(options)\n    del options\n    if preconditioner is not None:\n        preconditioner = linearise(preconditioner)\n        if tall:\n            inner_options[\"preconditioner\"] = TaggedLinearOperator(\n                preconditioner @ conj(preconditioner.transpose()),\n                positive_semidefinite_tag,\n            )\n        else:\n            inner_options[\"preconditioner\"] = TaggedLinearOperator(\n                conj(preconditioner.transpose()) @ preconditioner,\n                positive_semidefinite_tag,\n            )\n            if y0 is not None:\n                inner_options[\"y0\"] = conj(preconditioner.transpose()).mv(y0)\n    return inner_options\n\n\nclass Normal(\n    AbstractLinearSolver[\n        tuple[_InnerSolverState, 
eqxi.Static, AbstractLinearOperator, dict[str, Any]]\n    ]\n):\n    \"\"\"Wrapper for an inner solver of positive (semi)definite systems. The\n    wrapped solver handles possibly nonsquare systems $Ax = b$ by applying the\n    inner solver to the normal equations\n\n    $A^* A x = A^* b$\n\n    if $m \\\\ge n$, otherwise\n\n    $A A^* y = b$,\n\n    where $x = A^* y$.\n\n    If the inner solver solves systems with positive definite $A$, the wrapped\n    solver solves systems with full rank $A$.\n\n    If the inner solver solves systems with positive semidefinite $A$, the\n    wrapped solver solves systems with arbitrary, possibly rank deficient, $A$.\n\n    Note that this squares the condition number, so applying this method to an\n    iterative inner solver may result in slow convergence and high sensitivity\n    to roundoff error. In this case it may be advantageous to choose an\n    appropriate preconditioner or initial solution guess for the problem.\n\n    This wrapper adjusts the following `options` before passing to the inner\n    operator (as passed to `lx.linear_solve(..., options=...)`).\n\n    - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as\n        preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This\n        should be an approximation of the (pseudo)inverse of $A$. When passed\n        to the inner solver, the preconditioner $M$ is replaced by $M M^*$ and\n        $M^* M$ in the first and second versions of the normal equations,\n        respectively.\n\n    - `y0`: An initial estimate of the solution of the linear system $Ax = b$.\n        Defaults to all zeros. In the second version of the normal equations,\n        $y_0$ is replaced with $M^* y_0$, where $M$ is the given outer\n        preconditioner.\n\n    !!! 
Info\n\n        Good choices of inner solvers are the direct [`lineax.Cholesky`][] and\n        the iterative [`lineax.CG`][].\n\n    \"\"\"\n\n    inner_solver: AbstractLinearSolver[_InnerSolverState]\n\n    def init(self, operator, options):\n        tall = operator.out_size() >= operator.in_size()\n        # Cholesky materialises op twice when computing (op^H @ op).as_matrix()\n        # Cheaper to materialise first and then conjugate-transpose.\n        # For iterative solvers we only linearise to avoid eager materialisation.\n        is_cholesky = isinstance(self.inner_solver, Cholesky)\n        lin_op = materialise(operator) if is_cholesky else linearise(operator)\n        if tall:\n            inner_operator = conj(lin_op.transpose()) @ lin_op\n        else:\n            inner_operator = lin_op @ conj(lin_op.transpose())\n        inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag)\n        inner_options = normal_preconditioner_and_y0(options, tall)\n        inner_state = self.inner_solver.init(inner_operator, inner_options)\n        operator_conj_transpose = conj(lin_op.transpose())\n        return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options\n\n    def compute(\n        self,\n        state: tuple[\n            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]\n        ],\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        inner_state, tall, operator_conj_transpose, inner_options = state\n        tall = tall.value\n        del state, options\n        if tall:\n            vector = operator_conj_transpose.mv(vector)\n        solution, result, extra_stats = self.inner_solver.compute(\n            inner_state, vector, inner_options\n        )\n        if not tall:\n            solution = operator_conj_transpose.mv(solution)\n        return solution, result, extra_stats\n\n    def transpose(\n        self,\n 
       state: tuple[\n            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]\n        ],\n        options: dict[str, Any],\n    ):\n        inner_state, tall, operator_conj_transpose, inner_options = state\n        inner_state_conj, inner_options = self.inner_solver.conj(\n            inner_state, inner_options\n        )\n        state_transpose = (\n            inner_state_conj,\n            eqxi.Static(not tall.value),\n            operator_conj_transpose.transpose(),\n            inner_options,\n        )\n        return state_transpose, options\n\n    def conj(\n        self,\n        state: tuple[\n            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]\n        ],\n        options: dict[str, Any],\n    ):\n        inner_state, tall, operator_conj_transpose, inner_options = state\n        inner_state_conj, inner_options = self.inner_solver.conj(\n            inner_state, inner_options\n        )\n        state_conj = (\n            inner_state_conj,\n            tall,\n            conj(operator_conj_transpose),\n            inner_options,\n        )\n        return state_conj, options\n\n    def assume_full_rank(self):\n        return self.inner_solver.assume_full_rank()\n\n\nNormal.__init__.__doc__ = \"\"\"**Arguments:**\n\n- `inner_solver`: The solver to wrap. It should support solving positive\n  definite systems or positive semidefinite systems\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/qr.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax.lax.linalg as jll\nimport jax.numpy as jnp\nimport jax.scipy as jsp\nfrom jaxtyping import Array, PyTree\n\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_QRState: TypeAlias = tuple[tuple[Array, Array], eqxi.Static, PackedStructures]\n\n\nclass QR(AbstractLinearSolver):\n    \"\"\"QR solver for linear systems.\n\n    This solver can handle non-square operators.\n\n    This is usually the preferred solver when dealing with non-square operators.\n\n    !!! 
info\n\n        Note that whilst this does handle non-square operators, it still can only\n        handle full-rank operators.\n\n        This is because JAX does not currently support a rank-revealing/pivoted QR\n        decomposition, see [issue #12897](https://github.com/google/jax/issues/12897).\n\n        For such use cases, switch to [`lineax.SVD`][] instead.\n    \"\"\"\n\n    def init(self, operator, options):\n        del options\n        matrix = operator.as_matrix()\n        m, n = matrix.shape\n        transpose = n > m\n        if transpose:\n            matrix = matrix.T\n        h, taus = jnp.linalg.qr(matrix, mode=\"raw\")  # pyright: ignore\n        a = h.mT\n        packed_structures = pack_structures(operator)\n        return (a, taus), eqxi.Static(transpose), packed_structures\n\n    def compute(\n        self,\n        state: _QRState,\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        (a, taus), transpose, packed_structures = state\n        transpose = transpose.value\n        del state, options\n        vector = ravel_vector(vector, packed_structures)\n        n_full, n_min = a.shape\n        r = a[:n_min]\n        if transpose:\n            # Minimal norm solution if underdetermined: x = Q.conj() @ R^{-T} @ b.\n            # Use Q.conj() @ z = (z^T @ Q^H)^T to avoid explicit `conj` calls,\n            # and pad `y` along the row axis to absorb the discarded columns of Q.\n            y = jsp.linalg.solve_triangular(r, vector, trans=\"T\", unit_diagonal=False)\n            zeros = jnp.zeros((1, n_full - n_min), dtype=y.dtype)\n            y_pad = jnp.concatenate([y[None, :], zeros], axis=1)\n            solution = jll.ormqr(a, taus, y_pad, left=False, transpose=True)[0]\n        else:\n            # Least squares solution if overdetermined.\n            qHv = jll.ormqr(a, taus, vector[:, None], transpose=True)[:n_min, 0]\n            solution = 
jsp.linalg.solve_triangular(\n                r, qHv, trans=\"N\", unit_diagonal=False\n            )\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {}\n\n    def transpose(self, state: _QRState, options: dict[str, Any]):\n        (a, taus), transpose, structures = state\n        transposed_packed_structures = transpose_packed_structures(structures)\n        transpose_state = (\n            (a, taus),\n            eqxi.Static(not transpose.value),\n            transposed_packed_structures,\n        )\n        transpose_options = {}\n        return transpose_state, transpose_options\n\n    def conj(self, state: _QRState, options: dict[str, Any]):\n        (a, taus), transpose, structures = state\n        conj_state = (\n            (a.conj(), taus.conj()),\n            transpose,\n            structures,\n        )\n        conj_options = {}\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return True\n\n\nQR.__init__.__doc__ = \"\"\"**Arguments:**\n\nNothing.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/svd.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.scipy as jsp\nfrom jaxtyping import Array, PyTree\n\nfrom .._misc import resolve_rcond\nfrom .._operator import AbstractLinearOperator\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_SVDState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]\n\n\nclass SVD(AbstractLinearSolver[_SVDState]):\n    \"\"\"SVD solver for linear systems.\n\n    This solver can handle any operator, even nonsquare or singular ones. 
In these\n    cases it will return the pseudoinverse solution to the linear system.\n\n    Equivalent to `scipy.linalg.lstsq`.\n    \"\"\"\n\n    rcond: float | None = None\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        svd = jsp.linalg.svd(operator.as_matrix(), full_matrices=False)\n        packed_structures = pack_structures(operator)\n        return svd, packed_structures\n\n    def compute(\n        self,\n        state: _SVDState,\n        vector: PyTree[Array],\n        options: dict[str, Any],\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        del options\n        (u, s, vt), packed_structures = state\n        vector = ravel_vector(vector, packed_structures)\n        m, _ = u.shape\n        _, n = vt.shape\n        rcond = resolve_rcond(self.rcond, n, m, s.dtype)\n        rcond = jnp.array(rcond, dtype=s.dtype)\n        if s.size > 0:\n            rcond = rcond * s[0]\n        # Not >=, or this fails with a matrix of all-zeros.\n        mask = s > rcond\n        rank = mask.sum()\n        safe_s = jnp.where(mask, s, 1)\n        s_inv = jnp.where(mask, jnp.array(1.0) / safe_s, 0).astype(u.dtype)\n        uTb = jnp.matmul(u.conj().T, vector, precision=lax.Precision.HIGHEST)\n        solution = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {\"rank\": rank}\n\n    def transpose(self, state: _SVDState, options: dict[str, Any]):\n        del options\n        (u, s, vt), packed_structures = state\n        transposed_packed_structures = transpose_packed_structures(packed_structures)\n        transpose_state = (vt.T, s, u.T), transposed_packed_structures\n        transpose_options = {}\n        return transpose_state, transpose_options\n\n    def conj(self, state: _SVDState, options: dict[str, Any]):\n        del options\n        (u, s, vt), 
packed_structures = state\n        conj_state = (u.conj(), s, vt.conj()), packed_structures\n        conj_options = {}\n        return conj_state, conj_options\n\n    def assume_full_rank(self):\n        return False\n\n\nSVD.__init__.__doc__ = \"\"\"**Arguments:**\n\n- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine\n    precision times `max(N, M)`, where `(N, M)` is the shape of the operator. (I.e.\n    `N` is the output size and `M` is the input size.)\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/triangular.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport equinox.internal as eqxi\nimport jax.scipy as jsp\nfrom jaxtyping import Array, PyTree\n\nfrom .._operator import (\n    AbstractLinearOperator,\n    has_unit_diagonal,\n    is_lower_triangular,\n    is_upper_triangular,\n)\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_TriangularState: TypeAlias = tuple[\n    Array, eqxi.Static, eqxi.Static, PackedStructures, eqxi.Static\n]\n\n\nclass Triangular(AbstractLinearSolver[_TriangularState]):\n    \"\"\"Triangular solver for linear systems.\n\n    The operator should either be lower triangular or upper triangular.\n    \"\"\"\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        if operator.in_size() != operator.out_size():\n            raise ValueError(\n                \"`Triangular` may only be used for linear solves with square matrices\"\n            )\n        if not (is_lower_triangular(operator) or is_upper_triangular(operator)):\n            raise ValueError(\n                \"`Triangular` may only be used for linear solves with triangular \"\n                \"matrices\"\n            )\n        return (\n            
operator.as_matrix(),\n            eqxi.Static(is_lower_triangular(operator)),\n            eqxi.Static(has_unit_diagonal(operator)),\n            pack_structures(operator),\n            eqxi.Static(False),  # transposed\n        )\n\n    def compute(\n        self, state: _TriangularState, vector: PyTree[Array], options: dict[str, Any]\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        matrix, lower, unit_diagonal, packed_structures, transpose = state\n        lower = lower.value\n        unit_diagonal = unit_diagonal.value\n        transpose = transpose.value\n        del state, options\n        vector = ravel_vector(vector, packed_structures)\n        if transpose:\n            trans = \"T\"\n        else:\n            trans = \"N\"\n        solution = jsp.linalg.solve_triangular(\n            matrix, vector, trans=trans, lower=lower, unit_diagonal=unit_diagonal\n        )\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {}\n\n    def transpose(self, state: _TriangularState, options: dict[str, Any]):\n        del options\n        matrix, lower, unit_diagonal, packed_structures, transpose = state\n        transposed_packed_structures = transpose_packed_structures(packed_structures)\n        transpose_state = (\n            matrix,\n            lower,\n            unit_diagonal,\n            transposed_packed_structures,\n            eqxi.Static(not transpose.value),\n        )\n        transpose_options = {}\n        return transpose_state, transpose_options\n\n    def conj(self, state: _TriangularState, options: dict[str, Any]):\n        del options\n        matrix, lower, unit_diagonal, packed_structures, transpose = state\n        conj_state = (\n            matrix.conj(),\n            lower,\n            unit_diagonal,\n            packed_structures,\n            transpose,\n        )\n        conj_options = {}\n        return conj_state, conj_options\n\n    def 
assume_full_rank(self):\n        return True\n\n\nTriangular.__init__.__doc__ = \"\"\"**Arguments:**\n\nNothing.\n\"\"\"\n"
  },
  {
    "path": "lineax/_solver/tridiagonal.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, TypeAlias\n\nimport jax.lax as lax\nimport jax.numpy as jnp\nfrom jaxtyping import Array, PyTree\n\nfrom .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal\nfrom .._solution import RESULTS\nfrom .._solve import AbstractLinearSolver\nfrom .misc import (\n    pack_structures,\n    PackedStructures,\n    ravel_vector,\n    transpose_packed_structures,\n    unravel_solution,\n)\n\n\n_TridiagonalState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]\n\n\nclass Tridiagonal(AbstractLinearSolver[_TridiagonalState]):\n    \"\"\"Tridiagonal solver for linear systems, uses the LAPACK/cusparse implementation\n    of Gaussian elimination with partial pivoting (which increases stability).\n    \"\"\"\n\n    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):\n        del options\n        if operator.in_size() != operator.out_size():\n            raise ValueError(\n                \"`Tridiagonal` may only be used for linear solves with square matrices\"\n            )\n        if not is_tridiagonal(operator):\n            raise ValueError(\n                \"`Tridiagonal` may only be used for linear solves with tridiagonal \"\n                \"matrices\"\n            )\n        return tridiagonal(operator), pack_structures(operator)\n\n    def compute(\n        self,\n        state: 
_TridiagonalState,\n        vector,\n        options,\n    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:\n        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state\n        del state, options\n        vector = ravel_vector(vector, packed_structures)\n\n        solution = lax.linalg.tridiagonal_solve(\n            jnp.append(0.0, lower_diagonal),\n            diagonal,\n            jnp.append(upper_diagonal, 0.0),\n            vector[:, None],\n        ).flatten()\n\n        solution = unravel_solution(solution, packed_structures)\n        return solution, RESULTS.successful, {}\n\n    def transpose(self, state: _TridiagonalState, options: dict[str, Any]):\n        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state\n        transposed_packed_structures = transpose_packed_structures(packed_structures)\n        transpose_diagonals = (diagonal, upper_diagonal, lower_diagonal)\n        transpose_state = (transpose_diagonals, transposed_packed_structures)\n        return transpose_state, options\n\n    def conj(self, state: _TridiagonalState, options: dict[str, Any]):\n        (diagonal, lower_diagonal, upper_diagonal), packed_structures = state\n        conj_diagonals = (diagonal.conj(), lower_diagonal.conj(), upper_diagonal.conj())\n        conj_state = (conj_diagonals, packed_structures)\n        return conj_state, options\n\n    def assume_full_rank(self):\n        return True\n\n\nTridiagonal.__init__.__doc__ = \"\"\"**Arguments:**\n\nNothing.\n\"\"\"\n"
  },
  {
    "path": "lineax/_tags.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nclass _HasRepr:\n    def __init__(self, string: str):\n        self.string = string\n\n    def __repr__(self):\n        return self.string\n\n\nsymmetric_tag = _HasRepr(\"symmetric_tag\")\ndiagonal_tag = _HasRepr(\"diagonal_tag\")\ntridiagonal_tag = _HasRepr(\"tridiagonal_tag\")\nunit_diagonal_tag = _HasRepr(\"unit_diagonal_tag\")\nlower_triangular_tag = _HasRepr(\"lower_triangular_tag\")\nupper_triangular_tag = _HasRepr(\"upper_triangular_tag\")\npositive_semidefinite_tag = _HasRepr(\"positive_semidefinite_tag\")\nnegative_semidefinite_tag = _HasRepr(\"negative_semidefinite_tag\")\n\n\ntranspose_tags_rules = []\n\n\nfor tag in (\n    symmetric_tag,\n    unit_diagonal_tag,\n    diagonal_tag,\n    positive_semidefinite_tag,\n    negative_semidefinite_tag,\n    tridiagonal_tag,\n):\n\n    @transpose_tags_rules.append\n    def _(tags: frozenset[object], tag=tag):\n        if tag in tags:\n            return tag\n\n\n@transpose_tags_rules.append\ndef _(tags: frozenset[object]):\n    if lower_triangular_tag in tags:\n        return upper_triangular_tag\n\n\n@transpose_tags_rules.append\ndef _(tags: frozenset[object]):\n    if upper_triangular_tag in tags:\n        return lower_triangular_tag\n\n\ndef transpose_tags(tags: frozenset[object]):\n    \"\"\"Lineax uses \"tags\" to declare that a particular linear operator exhibits some\n    property, e.g. 
symmetry.\n\n    This function takes in a collection of tags representing a linear operator, and\n    returns a collection of tags that should be associated with the transpose of that\n    linear operator.\n\n    **Arguments:**\n\n    - `tags`: a `frozenset` of tags.\n\n    **Returns:**\n\n    A `frozenset` of tags.\n    \"\"\"\n    if symmetric_tag in tags:\n        return tags\n    new_tags = []\n    for rule in transpose_tags_rules:\n        out = rule(tags)\n        if out is not None:\n            new_tags.append(out)\n    return frozenset(new_tags)\n"
  },
  {
    "path": "lineax/internal/__init__.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom .._misc import (\n    complex_to_real_dtype as complex_to_real_dtype,\n    default_floating_dtype as default_floating_dtype,\n)\nfrom .._norm import (\n    max_norm as max_norm,\n    rms_norm as rms_norm,\n    sum_squares as sum_squares,\n    tree_dot as tree_dot,\n    two_norm as two_norm,\n)\nfrom .._solve import linear_solve_p as linear_solve_p\nfrom .._solver.misc import (\n    pack_structures as pack_structures,\n    PackedStructures as PackedStructures,\n    ravel_vector as ravel_vector,\n    transpose_packed_structures as transpose_packed_structures,\n    unravel_solution as unravel_solution,\n)\n"
  },
  {
    "path": "mkdocs.yml",
    "content": "theme:\n    name: material\n    features:\n        - navigation.sections  # Sections are included in the navigation on the left.\n        - toc.integrate  # Table of contents is integrated on the left; does not appear separately on the right.\n        - header.autohide  # header disappears as you scroll\n    palette:\n        # Light mode / dark mode\n        # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as\n        # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.\n        - scheme: default\n          primary: white\n          accent: amber\n          toggle:\n             icon: material/weather-night\n             name: Switch to dark mode\n        - scheme: slate\n          primary: black\n          accent: amber\n          toggle:\n             icon: material/weather-sunny\n             name: Switch to light mode\n    icon:\n        repo: fontawesome/brands/github  # GitHub logo in top right\n        logo: \"material/matrix\"  # lineax logo in top left\n    favicon: \"_static/favicon.png\"\n    custom_dir: \"docs/_overrides\"  # Overriding part of the HTML\n\n    # These additions are my own custom ones, having overridden a partial.\n    twitter_bluesky_name: \"@PatrickKidger\"\n    twitter_url: \"https://twitter.com/PatrickKidger\"\n    bluesky_url: \"https://PatrickKidger.bsky.social\"\n\nsite_name: lineax\nsite_description: The documentation for the Lineax software library.\nsite_author: Patrick Kidger\nsite_url: https://docs.kidger.site/lineax\n\nrepo_url: https://github.com/patrick-kidger/lineax\nrepo_name: patrick-kidger/lineax\nedit_uri: \"\"\n\nstrict: true  # Don't allow warnings during the build process\n\nextra_javascript:\n    # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/\n    - _static/mathjax.js\n    - 
https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js\n\nextra_css:\n    - _static/custom_css.css\n\nmarkdown_extensions:\n    - pymdownx.arithmatex:  # Render LaTeX via MathJax\n        generic: true\n    - pymdownx.superfences  # Seems to enable syntax highlighting when used with the Material theme.\n    - pymdownx.details  # Allowing hidden expandable regions denoted by ???\n    - pymdownx.snippets:  # Include one Markdown file into another\n        base_path: docs\n    - admonition\n    - toc:\n        permalink: \"¤\"  # Adds a clickable permalink to each section heading\n        toc_depth: 4\n\nplugins:\n    - search:\n        separator: '[\\s\\-,:!=\\[\\]()\"/]+|(?!\\b)(?=[A-Z][a-z])|\\.(?!\\d)|&[lg]t;'\n    - include_exclude_files:\n        include:\n            - \".htaccess\"\n        exclude:\n            - \"_overrides\"\n            - \"examples/.ipynb_checkpoints/\"\n    - ipynb\n    - hippogriffe:\n        extra_public_objects:\n            - jax.ShapeDtypeStruct\n    - mkdocstrings:\n        handlers:\n            python:\n                options:\n                    force_inspection: true\n                    heading_level: 4\n                    inherited_members: true\n                    members_order: source\n                    show_bases: false\n                    show_if_no_docstring: true\n                    show_overloads: false\n                    show_root_heading: true\n                    show_signature_annotations: true\n                    show_source: false\n                    show_symbol_type_heading: true\n                    show_symbol_type_toc: true\n\nnav:\n    - 'index.md'\n    - Examples:\n        - 'examples/classical_solve.ipynb'\n        - 'examples/least_squares.ipynb'\n        - 'examples/structured_matrices.ipynb'\n        - 'examples/no_materialisation.ipynb'\n        - 'examples/operators.ipynb'\n        - 'examples/complex_solve.ipynb'\n    - API:\n        - 'api/linear_solve.md'\n        - 
'api/solvers.md'\n        - 'api/operators.md'\n        - 'api/tags.md'\n        - 'api/solution.md'\n        - 'api/functions.md'\n    - 'faq.md'\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nbuild-backend = \"hatchling.build\"\nrequires = [\"hatchling\"]\n\n[dependency-groups]\ndev = [\n  \"prek==0.3.9\",\n  \"pyright==1.1.406\",\n  \"ruff==0.13.0\",\n  \"toml-sort==0.23.1\"\n]\ndocs = [\n  \"hippogriffe==0.2.2\",\n  \"griffe==1.7.3\",\n  \"mkdocs==1.6.1\",\n  \"mkdocs-include-exclude-files==0.1.0\",\n  \"mkdocs-ipynb==0.1.1\",\n  \"mkdocs-material==9.6.7\",\n  \"mkdocstrings==0.28.3\",\n  \"mkdocstrings-python==1.16.8\",\n  \"pygments==2.20.0\",\n  \"pymdown-extensions==10.21.2\"\n]\ntests = [\n  \"beartype\",\n  \"equinox\",\n  \"pytest\",\n  \"pytest-xdist\",\n  \"jaxlib\"\n]\n\n[project]\nauthors = [\n  {email = \"raderjason@outlook.com\", name = \"Jason Rader\"},\n  {email = \"contact@kidger.site\", name = \"Patrick Kidger\"}\n]\nclassifiers = [\n  \"Development Status :: 3 - Alpha\",\n  \"Intended Audience :: Developers\",\n  \"Intended Audience :: Financial and Insurance Industry\",\n  \"Intended Audience :: Information Technology\",\n  \"Intended Audience :: Science/Research\",\n  \"License :: OSI Approved :: Apache Software License\",\n  \"Natural Language :: English\",\n  \"Programming Language :: Python :: 3\",\n  \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n  \"Topic :: Scientific/Engineering :: Information Analysis\",\n  \"Topic :: Scientific/Engineering :: Mathematics\"\n]\ndependencies = [\"jax>=0.10.0\", \"jaxtyping>=0.2.24\", \"equinox>=0.11.10\", \"typing_extensions>=4.5.0\"]\ndescription = \"Linear solvers in JAX and Equinox.\"\nkeywords = [\"jax\", \"neural-networks\", \"deep-learning\", \"equinox\", \"linear-solvers\", \"least-squares\", \"numerical-methods\"]\nlicense = {file = \"LICENSE\"}\nname = \"lineax\"\nreadme = \"README.md\"\nrequires-python = \"~=3.11\"\nurls = {repository = \"https://github.com/google/lineax\"}\nversion = \"0.1.1\"\n\n[tool.hatch.build]\ninclude = [\"lineax/*\"]\n\n[tool.pyright]\ninclude = [\"lineax\", \"tests\"]\nreportIncompatibleMethodOverride = 
true\n\n[tool.pytest.ini_options]\naddopts = \"--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))\"\n\n[tool.ruff]\nextend-include = [\"*.ipynb\"]\nsrc = []\n\n[tool.ruff.lint]\nfixable = [\"I001\", \"F401\", \"UP\"]\nignore = [\"E402\", \"E721\", \"E731\", \"E741\", \"F722\"]\nselect = [\"E\", \"F\", \"I001\", \"UP\"]\n\n[tool.ruff.lint.flake8-import-conventions.extend-aliases]\n\"collections\" = \"co\"\n\"functools\" = \"ft\"\n\"itertools\" = \"it\"\n\n[tool.ruff.lint.isort]\ncombine-as-imports = true\nextra-standard-library = [\"typing_extensions\"]\nlines-after-imports = 2\norder-by-type = false\n\n[tool.uv]\ndefault-groups = [\"dev\", \"docs\", \"tests\"]\n"
  },
  {
    "path": "tests/README.md",
    "content": "Each file is run separately to avoid JAX out-of-memory'ing.\n\nAs such, run tests using `python -m tests`, *not* by just running `pytest`.\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/__main__.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pathlib\nimport subprocess\nimport sys\n\n\nhere = pathlib.Path(__file__).resolve().parent\n\n\n# Each file is ran separately to avoid out-of-memorying.\nrunning_out = 0\nfor file in here.iterdir():\n    if file.is_file() and file.name.startswith(\"test\"):\n        out = subprocess.run(f\"pytest {file}\", shell=True).returncode\n        running_out = max(running_out, out)\nsys.exit(running_out)\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox.internal as eqxi\nimport jax\nimport pytest\n\n\njax.config.update(\"jax_enable_x64\", True)\njax.config.update(\"jax_numpy_dtype_promotion\", \"strict\")\njax.config.update(\"jax_numpy_rank_promotion\", \"raise\")\n\n\n@pytest.fixture\ndef getkey():\n    return eqxi.GetKey()\n"
  },
  {
    "path": "tests/helpers.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\nimport math\n\nimport equinox as eqx\nimport equinox.internal as eqxi\nimport jax\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport jax.tree_util as jtu\nimport lineax as lx\nimport numpy as np\nfrom equinox.internal import ω\n\n\n@ft.cache\ndef _construct_matrix_impl(\n    getkey, tags, size, dtype, cond_or_singular: int | float | str, i: int\n):\n    del i  # used to break the cache\n    while True:\n        matrix = jr.normal(getkey(), (size, size), dtype=dtype)\n        if isinstance(cond_or_singular, str):\n            if cond_or_singular == \"zero\":\n                matrix = matrix.at[0, :].set(0)\n            elif cond_or_singular == \"trim_row\":\n                matrix = matrix[1:, :]\n            elif cond_or_singular == \"trim_col\":\n                matrix = matrix[:, 1:]\n        if tags != ():\n            assert (\n                isinstance(cond_or_singular, (int, float)) or cond_or_singular == \"zero\"\n            )\n        if has_tag(tags, lx.diagonal_tag):\n            matrix = jnp.diag(jnp.diag(matrix))\n        if has_tag(tags, lx.symmetric_tag):\n            matrix = matrix + matrix.T\n        if has_tag(tags, lx.lower_triangular_tag):\n            matrix = jnp.tril(matrix)\n        if has_tag(tags, lx.upper_triangular_tag):\n            matrix = jnp.triu(matrix)\n        if has_tag(tags, 
lx.unit_diagonal_tag):\n            matrix = matrix.at[jnp.arange(size), jnp.arange(size)].set(1)\n        if has_tag(tags, lx.tridiagonal_tag):\n            diagonal = jnp.diag(jnp.diag(matrix))\n            upper_diagonal = jnp.diag(jnp.diag(matrix, k=1), k=1)\n            lower_diagonal = jnp.diag(jnp.diag(matrix, k=-1), k=-1)\n            matrix = lower_diagonal + diagonal + upper_diagonal\n        if has_tag(tags, lx.positive_semidefinite_tag):\n            matrix = matrix @ matrix.T.conj()\n        if has_tag(tags, lx.negative_semidefinite_tag):\n            matrix = -matrix @ matrix.T.conj()\n        if isinstance(cond_or_singular, str):\n            break\n        else:\n            if eqxi.unvmap_all(jnp.linalg.cond(matrix) < cond_or_singular):  # pyright: ignore\n                break\n    return matrix\n\n\ndef construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64):\n    if isinstance(solver, lx.Normal):\n        cond_cutoff = math.sqrt(1000)\n    else:\n        cond_cutoff = 1000\n    return tuple(\n        _construct_matrix_impl(getkey, tags, size, dtype, cond_cutoff, i)\n        for i in range(num)\n    )\n\n\ndef construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64):\n    if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):\n        singular_method = \"zero\"\n    else:\n        # Use `getkey()` rather than the stdlib `random.choice` for reproducibility\n        singular_method = [\"zero\", \"trim_row\", \"trim_col\"][\n            jr.choice(getkey(), np.array([0, 1, 2]))\n        ]\n    size = 3\n    return tuple(\n        _construct_matrix_impl(getkey, tags, size, dtype, singular_method, i)\n        for i in range(num)\n    )\n\n\ndef construct_poisson_matrix(size, dtype=jnp.float64):\n    matrix = (\n        -2 * jnp.diag(jnp.ones(size, dtype=dtype))\n        + jnp.diag(jnp.ones(size - 1, dtype=dtype), 1)\n        + jnp.diag(jnp.ones(size - 1, dtype=dtype), -1)\n    )\n    return 
matrix\n\n\nif jax.config.jax_enable_x64:  # pyright: ignore\n    tol = 1e-12\nelse:\n    tol = 1e-6\nsolvers_tags_pseudoinverse = [\n    (lx.AutoLinearSolver(well_posed=True), (), False),\n    (lx.AutoLinearSolver(well_posed=False), (), True),\n    (lx.Triangular(), lx.lower_triangular_tag, False),\n    (lx.Triangular(), lx.upper_triangular_tag, False),\n    (lx.Triangular(), (lx.lower_triangular_tag, lx.unit_diagonal_tag), False),\n    (lx.Triangular(), (lx.upper_triangular_tag, lx.unit_diagonal_tag), False),\n    (lx.Diagonal(), lx.diagonal_tag, False),\n    (lx.Diagonal(), (lx.diagonal_tag, lx.unit_diagonal_tag), False),\n    (lx.Tridiagonal(), lx.tridiagonal_tag, False),\n    (lx.LU(), (), False),\n    (lx.QR(), (), False),\n    (lx.SVD(), (), True),\n    (lx.BiCGStab(rtol=tol, atol=tol), (), False),\n    (lx.GMRES(rtol=tol, atol=tol), (), False),\n    (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag, False),\n    (lx.CG(rtol=tol, atol=tol), lx.negative_semidefinite_tag, False),\n    (lx.Normal(lx.CG(rtol=tol, atol=tol)), (), False),\n    (lx.LSMR(atol=tol, rtol=tol), (), True),\n    (lx.Cholesky(), lx.positive_semidefinite_tag, False),\n    (lx.Cholesky(), lx.negative_semidefinite_tag, False),\n    (lx.Normal(lx.Cholesky()), (), False),\n]\nsolvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse]\nsolvers = [a for a, _, _ in solvers_tags_pseudoinverse]\npseudosolvers_tags = [(a, b) for a, b, c in solvers_tags_pseudoinverse if c]\n\n\ndef _transpose(operator, matrix):\n    return operator.T, matrix.T\n\n\ndef _linearise(operator, matrix):\n    return lx.linearise(operator), matrix\n\n\ndef _materialise(operator, matrix):\n    return lx.materialise(operator), matrix\n\n\nops = (lambda x, y: (x, y), _transpose, _linearise, _materialise)\n\n\ndef params(only_pseudo):\n    for make_operator in make_operators:\n        for solver, tags, pseudoinverse in solvers_tags_pseudoinverse:\n            if only_pseudo and not pseudoinverse:\n                
continue\n            if (\n                make_operator is make_trivial_diagonal_operator\n                and tags != lx.diagonal_tag\n            ):\n                continue\n            if make_operator is make_identity_operator and tags != lx.unit_diagonal_tag:\n                continue\n            if (\n                make_operator is make_tridiagonal_operator\n                and tags != lx.tridiagonal_tag\n            ):\n                continue\n            yield make_operator, solver, tags\n\n\ndef tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):\n    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)\n\n\ndef has_tag(tags, tag):\n    return tag is tags or (isinstance(tags, tuple) and tag in tags)\n\n\nmake_operators = []\n\n\ndef _operators_append(x):\n    make_operators.append(x)\n    return x\n\n\n@_operators_append\ndef make_matrix_operator(getkey, matrix, tags):\n    return lx.MatrixLinearOperator(matrix, tags)\n\n\n@_operators_append\ndef make_trivial_pytree_operator(getkey, matrix, tags):\n    out_size, _ = matrix.shape\n    struct = jax.ShapeDtypeStruct((out_size,), matrix.dtype)\n    return lx.PyTreeLinearOperator(matrix, struct, tags)\n\n\n@_operators_append\ndef make_function_operator(getkey, matrix, tags):\n    fn = lambda x: matrix @ x\n    _, in_size = matrix.shape\n    in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype)\n    return lx.FunctionLinearOperator(fn, in_struct, tags)\n\n\n@_operators_append\ndef make_jac_operator(getkey, matrix, tags):\n    out_size, in_size = matrix.shape\n    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)\n    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)\n    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    fn_tmp = lambda x, _: a + b @ x + c @ x**2\n    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)\n    diff = matrix - jac\n    fn = lambda x, _: a + (b + 
diff) @ x + c @ x**2\n    return lx.JacobianLinearOperator(fn, x, None, tags)\n\n\n@_operators_append\ndef make_jacfwd_operator(getkey, matrix, tags):\n    out_size, in_size = matrix.shape\n    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)\n    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)\n    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    fn_tmp = lambda x, _: a + b @ x + c @ x**2\n    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)\n    diff = matrix - jac\n    fn = lambda x, _: a + (b + diff) @ x + c @ x**2\n    return lx.JacobianLinearOperator(fn, x, None, tags, jac=\"fwd\")\n\n\n@_operators_append\ndef make_jacrev_operator(getkey, matrix, tags):\n    \"\"\"JacobianLinearOperator with jac='bwd' using a custom_vjp function.\n\n    This uses custom_vjp so that forward-mode autodiff is NOT available,\n    which tests that jac='bwd' works correctly without relying on JVP.\n    \"\"\"\n    out_size, in_size = matrix.shape\n    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)\n    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)\n    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)\n    fn_tmp = lambda x, _: a + b @ x + c @ x**2\n    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)\n    diff = matrix - jac\n\n    # Use custom_vjp to define a function that only has reverse-mode autodiff\n    @jax.custom_vjp\n    def custom_fn(x):\n        return a + (b + diff) @ x + c @ x**2\n\n    def custom_fn_fwd(x):\n        return custom_fn(x), x\n\n    def custom_fn_bwd(x, g):\n        # Jacobian is: (b + diff) + 2 * c * x\n        # VJP is: g @ J = g @ ((b + diff) + 2 * c * x)\n        # So J.T @ g =\n        return ((b + diff).T @ g + 2 * (c.T @ g) * x,)\n\n    custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)\n\n    fn = lambda x, 
_: custom_fn(x)\n    return lx.JacobianLinearOperator(fn, x, None, tags, jac=\"bwd\")\n\n\n@_operators_append\ndef make_trivial_diagonal_operator(getkey, matrix, tags):\n    assert tags == lx.diagonal_tag\n    diag = jnp.diag(matrix)\n    return lx.DiagonalLinearOperator(diag)\n\n\n@_operators_append\ndef make_identity_operator(getkey, matrix, tags):\n    in_struct = jax.ShapeDtypeStruct((matrix.shape[-1],), matrix.dtype)\n    return lx.IdentityLinearOperator(input_structure=in_struct)\n\n\n@_operators_append\ndef make_tridiagonal_operator(getkey, matrix, tags):\n    diag1 = jnp.diag(matrix)\n    if tags == lx.tridiagonal_tag:\n        diag2 = jnp.diag(matrix, k=-1)\n        diag3 = jnp.diag(matrix, k=1)\n        return lx.TridiagonalLinearOperator(diag1, diag2, diag3)\n    elif tags == lx.diagonal_tag:\n        diag2 = diag3 = jnp.zeros(matrix.shape[0] - 1)\n        return lx.TaggedLinearOperator(\n            lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.diagonal_tag\n        )\n    elif tags == lx.symmetric_tag:\n        diag2 = diag3 = jnp.diag(matrix, k=1)\n        return lx.TaggedLinearOperator(\n            lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.symmetric_tag\n        )\n    else:\n        assert False, tags\n\n\n@_operators_append\ndef make_add_operator(getkey, matrix, tags):\n    matrix1 = 0.7 * matrix\n    matrix2 = 0.3 * matrix\n    operator = make_matrix_operator(getkey, matrix1, ()) + make_function_operator(\n        getkey, matrix2, ()\n    )\n    return lx.TaggedLinearOperator(operator, tags)\n\n\n@_operators_append\ndef make_mul_operator(getkey, matrix, tags):\n    operator = make_jac_operator(getkey, 0.7 * matrix, ()) / 0.7\n    return lx.TaggedLinearOperator(operator, tags)\n\n\n@_operators_append\ndef make_composed_operator(getkey, matrix, tags):\n    _, size = matrix.shape\n    diag = jr.normal(getkey(), (size,), dtype=matrix.dtype)\n    diag = jnp.where(jnp.abs(diag) < 0.05, 0.8, diag)\n    operator1 = 
make_trivial_pytree_operator(getkey, matrix / diag[None], ())\n    operator2 = lx.DiagonalLinearOperator(diag)\n    return lx.TaggedLinearOperator(operator1 @ operator2, tags)\n\n\n# Slightly sketchy approach to finite differences, in that this is pulled out of\n# Numerical Recipes.\n# I also don't know of a handling of the JVP case off the top of my head -- although\n# I'm sure it exists somewhere -- so I'm improvising a little here. (In particular\n# removing the usual \"(x + h) - x\" denominator.)\ndef finite_difference_jvp(fn, primals, tangents):\n    out = fn(*primals)\n    # Choose ε to trade-off truncation error and floating-point rounding error.\n    max_leaves = [jnp.max(jnp.abs(p)) for p in jtu.tree_leaves(primals)] + [1]\n    scale = jnp.max(jnp.stack(max_leaves))\n    ε = np.sqrt(np.finfo(np.float64).eps) * scale\n    with jax.numpy_dtype_promotion(\"standard\"):\n        primals_ε = (ω(primals) + ε * ω(tangents)).ω\n        out_ε = fn(*primals_ε)\n        tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)\n    return out, tangents_out\n\n\ndef jvp_jvp_impl(\n    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype\n):\n    t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None\n    if (make_matrix is construct_matrix) or pseudoinverse:\n        matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix(\n            getkey, solver, tags, num=4, dtype=dtype\n        )\n\n        make_op = ft.partial(make_operator, getkey)\n        t_make_operator = lambda p, t_p: eqx.filter_jvp(\n            make_op, (p, tags), (t_p, t_tags)\n        )\n        tt_make_operator = lambda p, t_p, tt_p, tt_t_p: eqx.filter_jvp(\n            t_make_operator, (p, t_p), (tt_p, tt_t_p)\n        )\n        (operator, t_operator), (tt_operator, tt_t_operator) = tt_make_operator(\n            matrix, t_matrix, tt_matrix, tt_t_matrix\n        )\n\n        out_size, _ = matrix.shape\n        vec = jr.normal(getkey(), 
(out_size,), dtype=dtype)\n        t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n        tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n        tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n\n        if use_state:\n\n            def linear_solve1(operator, vector):\n                op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)\n                stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)\n                state = solver.init(stopped_operator, options={})\n\n                sol = lx.linear_solve(operator, vector, state=state, solver=solver)\n                return sol.value\n\n        else:\n\n            def linear_solve1(operator, vector):\n                sol = lx.linear_solve(operator, vector, solver=solver)\n                return sol.value\n\n        if pseudoinverse:\n            jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore\n        else:\n            jnp_solve1 = jnp.linalg.solve  # pyright: ignore\n\n        linear_solve2 = ft.partial(eqx.filter_jvp, linear_solve1)\n        jnp_solve2 = ft.partial(eqx.filter_jvp, jnp_solve1)\n\n        def _make_primal_tangents(mode):\n            lx_args = ([], [], operator, t_operator, tt_operator, tt_t_operator)\n            jnp_args = ([], [], matrix, t_matrix, tt_matrix, tt_t_matrix)\n            for primals, ttangents, op, t_op, tt_op, tt_t_op in (lx_args, jnp_args):\n                if \"op\" in mode:\n                    primals.append(op)\n                    ttangents.append(tt_op)\n                if \"vec\" in mode:\n                    primals.append(vec)\n                    ttangents.append(tt_vec)\n                if \"t_op\" in mode:\n                    primals.append(t_op)\n                    ttangents.append(tt_t_op)\n                if \"t_vec\" in mode:\n                    primals.append(t_vec)\n                    ttangents.append(tt_t_vec)\n            lx_out = tuple(lx_args[0]), 
tuple(lx_args[1])\n            jnp_out = tuple(jnp_args[0]), tuple(jnp_args[1])\n            return lx_out, jnp_out\n\n        modes = (\n            {\"op\"},\n            {\"vec\"},\n            {\"t_op\"},\n            {\"t_vec\"},\n            {\"op\", \"vec\"},\n            {\"op\", \"t_op\"},\n            {\"op\", \"t_vec\"},\n            {\"vec\", \"t_op\"},\n            {\"vec\", \"t_vec\"},\n            {\"op\", \"vec\", \"t_op\"},\n            {\"op\", \"vec\", \"t_vec\"},\n            {\"vec\", \"t_op\", \"t_vec\"},\n            {\"op\", \"vec\", \"t_op\", \"t_vec\"},\n        )\n        for mode in modes:\n            if mode == {\"op\"}:\n                linear_solve3 = lambda op: linear_solve2((op, vec), (t_operator, t_vec))\n                jnp_solve3 = lambda mat: jnp_solve2((mat, vec), (t_matrix, t_vec))\n            elif mode == {\"vec\"}:\n                linear_solve3 = lambda v: linear_solve2(\n                    (operator, v), (t_operator, t_vec)\n                )\n                jnp_solve3 = lambda v: jnp_solve2((matrix, v), (t_matrix, t_vec))\n            elif mode == {\"op\", \"vec\"}:\n                linear_solve3 = lambda op, v: linear_solve2(\n                    (op, v), (t_operator, t_vec)\n                )\n                jnp_solve3 = lambda mat, v: jnp_solve2((mat, v), (t_matrix, t_vec))\n            elif mode == {\"t_op\"}:\n                linear_solve3 = lambda t_op: linear_solve2(\n                    (operator, vec), (t_op, t_vec)\n                )\n                jnp_solve3 = lambda t_mat: jnp_solve2((matrix, vec), (t_mat, t_vec))\n            elif mode == {\"t_vec\"}:\n                linear_solve3 = lambda t_v: linear_solve2(\n                    (operator, vec), (t_operator, t_v)\n                )\n                jnp_solve3 = lambda t_v: jnp_solve2((matrix, vec), (t_matrix, t_v))\n            elif mode == {\"op\", \"vec\"}:\n                linear_solve3 = lambda op, v: linear_solve2(\n                    (op, v), 
(t_operator, t_vec)\n                )\n                jnp_solve3 = lambda mat, v: jnp_solve2((mat, v), (t_matrix, t_vec))\n            elif mode == {\"op\", \"t_op\"}:\n                linear_solve3 = lambda op, t_op: linear_solve2((op, vec), (t_op, t_vec))\n                jnp_solve3 = lambda mat, t_mat: jnp_solve2((mat, vec), (t_mat, t_vec))\n            elif mode == {\"op\", \"t_vec\"}:\n                linear_solve3 = lambda op, t_v: linear_solve2(\n                    (op, vec), (t_operator, t_v)\n                )\n                jnp_solve3 = lambda mat, t_v: jnp_solve2((mat, vec), (t_matrix, t_v))\n            elif mode == {\"vec\", \"t_op\"}:\n                linear_solve3 = lambda v, t_op: linear_solve2(\n                    (operator, v), (t_op, t_vec)\n                )\n                jnp_solve3 = lambda v, t_mat: jnp_solve2((matrix, v), (t_mat, t_vec))\n            elif mode == {\"vec\", \"t_vec\"}:\n                linear_solve3 = lambda v, t_v: linear_solve2(\n                    (operator, v), (t_operator, t_v)\n                )\n                jnp_solve3 = lambda v, t_v: jnp_solve2((matrix, v), (t_matrix, t_v))\n            elif mode == {\"op\", \"vec\", \"t_op\"}:\n                linear_solve3 = lambda op, v, t_op: linear_solve2(\n                    (op, v), (t_op, t_vec)\n                )\n                jnp_solve3 = lambda mat, v, t_mat: jnp_solve2((mat, v), (t_mat, t_vec))\n            elif mode == {\"op\", \"vec\", \"t_vec\"}:\n                linear_solve3 = lambda op, v, t_v: linear_solve2(\n                    (op, v), (t_operator, t_v)\n                )\n                jnp_solve3 = lambda mat, v, t_v: jnp_solve2((mat, v), (t_matrix, t_v))\n            elif mode == {\"vec\", \"t_op\", \"t_vec\"}:\n                linear_solve3 = lambda v, t_op, t_v: linear_solve2(\n                    (operator, v), (t_op, t_v)\n                )\n                jnp_solve3 = lambda v, t_mat, t_v: jnp_solve2((matrix, v), (t_mat, t_v))\n          
  elif mode == {\"op\", \"vec\", \"t_op\", \"t_vec\"}:\n                linear_solve3 = lambda op, v, t_op, t_v: linear_solve2(\n                    (op, v), (t_op, t_v)\n                )\n                jnp_solve3 = lambda mat, v, t_mat, t_v: jnp_solve2(\n                    (mat, v), (t_mat, t_v)\n                )\n            else:\n                assert False\n\n            linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3)\n            linear_solve3 = eqx.filter_jit(linear_solve3)\n            jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve3)\n            jnp_solve3 = eqx.filter_jit(jnp_solve3)\n\n            (primal, tangent), (jnp_primal, jnp_tangent) = _make_primal_tangents(mode)\n            (out, t_out), (minus_out, tt_out) = linear_solve3(primal, tangent)\n            (true_out, true_t_out), (minus_true_out, true_tt_out) = jnp_solve3(\n                jnp_primal, jnp_tangent\n            )\n\n            assert tree_allclose(out, true_out, atol=1e-4)\n            assert tree_allclose(t_out, true_t_out, atol=1e-4)\n            assert tree_allclose(tt_out, true_tt_out, atol=1e-4)\n            assert tree_allclose(minus_out, minus_true_out, atol=1e-4)\n"
  },
  {
    "path": "tests/test_adjoint.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\nfrom lineax import FunctionLinearOperator\n\nfrom .helpers import (\n    make_identity_operator,\n    make_jacrev_operator,\n    make_operators,\n    make_tridiagonal_operator,\n    make_trivial_diagonal_operator,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", make_operators)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_adjoint(make_operator, dtype, getkey):\n    if (\n        make_operator is make_trivial_diagonal_operator\n        or make_operator is make_identity_operator\n    ):\n        matrix = jnp.eye(4, dtype=dtype)\n        tags = lx.diagonal_tag\n        in_size = out_size = 4\n    elif make_operator is make_tridiagonal_operator:\n        matrix = jnp.eye(4, dtype=dtype)\n        tags = lx.tridiagonal_tag\n        in_size = out_size = 4\n    else:\n        matrix = jr.normal(getkey(), (3, 5), dtype=dtype)\n        tags = ()\n        in_size = 5\n        out_size = 3\n    if make_operator is make_jacrev_operator and dtype is jnp.complex128:\n        # JacobianLinearOperator does not support complex dtypes when jac=\"bwd\"\n        return\n    operator = make_operator(getkey, matrix, tags)\n    v1, v2 = (\n        jr.normal(getkey(), (in_size,), dtype=dtype),\n        jr.normal(getkey(), (out_size,), dtype=dtype),\n    )\n\n    inner1 = operator.mv(v1) @ v2.conj()\n    adjoint_op1 = lx.conj(operator).transpose()\n    ov2 = adjoint_op1.mv(v2)\n    inner2 = v1 @ ov2.conj()\n    assert tree_allclose(inner1, inner2)\n\n    adjoint_op2 = lx.conj(operator.transpose())\n    ov2 = adjoint_op2.mv(v2)\n    inner2 = v1 @ ov2.conj()\n    assert tree_allclose(inner1, inner2)\n\n\ndef test_functional_pytree_adjoint():\n    def fn(y):\n        return {\"b\": y[\"a\"]}\n\n    y_struct = jax.eval_shape(lambda: {\"a\": 0.0})\n    operator = FunctionLinearOperator(fn, y_struct)\n    conj_operator = 
lx.conj(operator)\n    assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator))\n\n\ndef test_functional_pytree_adjoint_complex():\n    def fn(y):\n        return {\"b\": y[\"a\"]}\n\n    y_struct = jax.eval_shape(lambda: {\"a\": 0.0j})\n    operator = FunctionLinearOperator(fn, y_struct)\n    conj_operator = lx.conj(operator)\n    assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator))\n\n\nif jax.config.jax_enable_x64:  # pyright: ignore\n    tol = 1e-12\nelse:\n    tol = 1e-6\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    [\n        # in theory only 1 iteration is needed, but stopping criteria are\n        # complicated, see gh #160\n        lx.GMRES(tol, tol, max_steps=4, restart=1),\n        lx.BiCGStab(tol, tol, max_steps=3),\n        lx.Normal(lx.CG(tol, tol, max_steps=4)),\n        lx.CG(tol, tol, max_steps=3),\n    ],\n)\ndef test_preconditioner_adjoint(solver):\n    \"\"\"Test for fix to gh #160\"\"\"\n    # Nonsymmetric poorly conditioned matrix. 
Without preconditioning,\n    # this would take 20+ iterations (100s for GMRES)\n    key = jax.random.key(123)\n    key, subkey = jax.random.split(key)\n    A = jax.random.uniform(key, (10, 10))\n    A += jnp.diag(jnp.arange(A.shape[0]) ** 6).astype(A.dtype)\n    b = jax.random.uniform(subkey, (A.shape[0],))\n    if isinstance(solver, lx.CG):\n        A = A.T @ A\n        tags = (lx.positive_semidefinite_tag,)\n    else:\n        tags = ()\n\n    A = lx.MatrixLinearOperator(A, tags=tags)\n    # exact inverse, should only take ~1 iteration\n    M = lx.MatrixLinearOperator(\n        jnp.linalg.inv(A.matrix),\n        tags=tags,\n    )\n\n    def solve(b):\n        out = lx.linear_solve(\n            A, b, solver=solver, options={\"preconditioner\": M}, throw=True\n        )\n        return out.value\n\n    # if they don't converge then this will throw an error\n    _ = solve(b)\n    A1 = jax.jacfwd(solve)(b)\n    A2 = jax.jacrev(solve)(b)\n\n    # we also do a sanity check, dx/db should give A^{-1}\n    assert tree_allclose(A1, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol)\n    assert tree_allclose(A2, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol)\n"
  },
  {
    "path": "tests/test_invert.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\n\nfrom .helpers import tree_allclose\n\n\ndef _well_conditioned_matrix(getkey, size=3, dtype=jnp.float64):\n    \"\"\"Generate a well-conditioned random matrix.\"\"\"\n    while True:\n        matrix = jr.normal(getkey(), (size, size), dtype=dtype)\n        if jnp.linalg.cond(matrix) < 100:\n            return matrix\n\n\ndef _well_conditioned_psd_matrix(getkey, size=3, dtype=jnp.float64):\n    \"\"\"Generate a well-conditioned PSD matrix.\"\"\"\n    matrix = _well_conditioned_matrix(getkey, size, dtype)\n    return matrix @ matrix.T.conj()\n\n\n# -- Core behaviour --\n\n\ndef test_mv(getkey):\n    \"\"\"invert(A).mv(v) solves A x = v.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    inv_op = lx.invert(op)\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = inv_op.mv(vec)\n    expected = jnp.linalg.solve(matrix, vec)\n    assert tree_allclose(result, expected, atol=1e-10)\n\n\ndef test_composition_identity(getkey):\n    \"\"\"(invert(A) @ A).mv(v) ~ v.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    inv_op = lx.invert(op)\n    composed = inv_op @ op\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = composed.mv(vec)\n    
assert tree_allclose(result, vec, atol=1e-10)\n\n\ndef test_double_inverse(getkey):\n    \"\"\"invert(invert(A)).mv(v) ~ A.mv(v).\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    double_inv = lx.invert(lx.invert(op))\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = double_inv.mv(vec)\n    expected = matrix @ vec\n    assert tree_allclose(result, expected, atol=1e-8)\n\n\n# -- Pseudoinverse (non-square) --\n\n\ndef test_pseudoinverse_overdetermined(getkey):\n    \"\"\"invert of a tall matrix gives the least-squares pseudoinverse.\"\"\"\n    matrix = jr.normal(getkey(), (5, 3), dtype=jnp.float64)\n    op = lx.MatrixLinearOperator(matrix)\n    pinv_op = lx.invert(op, solver=lx.AutoLinearSolver(well_posed=False))\n    vec = jr.normal(getkey(), (5,), dtype=jnp.float64)\n    result = pinv_op.mv(vec)\n    expected = jnp.linalg.lstsq(matrix, vec)[0]\n    assert tree_allclose(result, expected, atol=1e-8)\n\n\ndef test_pseudoinverse_underdetermined(getkey):\n    \"\"\"invert of a wide matrix gives the minimum-norm pseudoinverse.\"\"\"\n    matrix = jr.normal(getkey(), (3, 5), dtype=jnp.float64)\n    op = lx.MatrixLinearOperator(matrix)\n    pinv_op = lx.invert(op, solver=lx.AutoLinearSolver(well_posed=False))\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = pinv_op.mv(vec)\n    expected = jnp.linalg.lstsq(matrix, vec)[0]\n    assert tree_allclose(result, expected, atol=1e-8)\n\n\n# -- Explicit solver tests --\n\n\ndef test_solver_cholesky(getkey):\n    \"\"\"Works with Cholesky solver for PSD matrices.\"\"\"\n    matrix = _well_conditioned_psd_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)\n    inv_op = lx.invert(op, solver=lx.Cholesky())\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = inv_op.mv(vec)\n    expected = jnp.linalg.solve(matrix, vec)\n    assert tree_allclose(result, expected, atol=1e-10)\n\n\ndef 
test_solver_cg(getkey):\n    \"\"\"Works with CG (iterative) solver for PSD matrices.\"\"\"\n    matrix = _well_conditioned_psd_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)\n    inv_op = lx.invert(op, solver=lx.CG(rtol=1e-12, atol=1e-12))\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    result = inv_op.mv(vec)\n    expected = jnp.linalg.solve(matrix, vec)\n    assert tree_allclose(result, expected, atol=1e-8)\n\n\n# -- vmap --\n\n\ndef test_vmap(getkey):\n    \"\"\"vmap over invert(A).mv works correctly.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    inv_op = lx.invert(op)\n    vecs = jr.normal(getkey(), (5, 3), dtype=jnp.float64)\n    result = jax.vmap(inv_op.mv)(vecs)\n    expected = jax.vmap(lambda v: jnp.linalg.solve(matrix, v))(vecs)\n    assert tree_allclose(result, expected, atol=1e-10)\n\n\n# -- AD --\n\n\ndef test_grad_wrt_vector(getkey):\n    \"\"\"VJP through invert(A).mv(v) wrt vector.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    inv_op = lx.invert(op)\n\n    def f(vec):\n        return jnp.sum(inv_op.mv(vec))\n\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    grad = jax.grad(f)(vec)\n    expected = jnp.linalg.solve(matrix.T, jnp.ones(3, dtype=jnp.float64))\n    assert tree_allclose(grad, expected, atol=1e-10)\n\n\ndef test_jvp_wrt_vector(getkey):\n    \"\"\"JVP through invert(A).mv(v) wrt vector.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    op = lx.MatrixLinearOperator(matrix)\n    inv_op = lx.invert(op)\n\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n    t_vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n\n    primals, tangents = eqx.filter_jvp(inv_op.mv, (vec,), (t_vec,))\n    expected_primals = jnp.linalg.solve(matrix, vec)\n    expected_tangents = jnp.linalg.solve(matrix, t_vec)\n    assert tree_allclose(primals, expected_primals, atol=1e-10)\n    
assert tree_allclose(tangents, expected_tangents, atol=1e-10)\n\n\ndef test_grad_wrt_operator(getkey):\n    \"\"\"VJP through invert(A).mv(v) wrt the inner matrix.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n\n    def f_inv(mat):\n        op = lx.MatrixLinearOperator(mat)\n        inv_op = lx.invert(op)\n        return jnp.sum(inv_op.mv(vec))\n\n    def f_jnp(mat):\n        return jnp.sum(jnp.linalg.solve(mat, vec))\n\n    grad_inv = jax.grad(f_inv)(matrix)\n    grad_jnp = jax.grad(f_jnp)(matrix)\n    assert tree_allclose(grad_inv, grad_jnp, atol=1e-8)\n\n\ndef test_jvp_wrt_operator(getkey):\n    \"\"\"JVP through invert(A).mv(v) wrt the inner matrix.\"\"\"\n    matrix = _well_conditioned_matrix(getkey)\n    t_matrix = jr.normal(getkey(), (3, 3), dtype=jnp.float64)\n    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)\n\n    def f_inv(mat):\n        op = lx.MatrixLinearOperator(mat)\n        inv_op = lx.invert(op)\n        return inv_op.mv(vec)\n\n    def f_jnp(mat):\n        return jnp.linalg.solve(mat, vec)\n\n    out, t_out = eqx.filter_jvp(f_inv, (matrix,), (t_matrix,))\n    expected_out, expected_t_out = eqx.filter_jvp(f_jnp, (matrix,), (t_matrix,))\n    assert tree_allclose(out, expected_out, atol=1e-10)\n    assert tree_allclose(t_out, expected_t_out, atol=1e-8)\n"
  },
  {
    "path": "tests/test_jvp.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\n\nimport equinox as eqx\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    finite_difference_jvp,\n    has_tag,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\n    \"make_matrix\",\n    (\n        construct_matrix,\n        construct_singular_matrix,\n    ),\n)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_jvp(\n    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype\n):\n    t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None\n\n    if (make_matrix is construct_matrix) or pseudoinverse:\n        matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)\n\n        out_size, _ = matrix.shape\n        vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n        t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n\n        if has_tag(tags, lx.unit_diagonal_tag):\n            # For all the other tags, A + εB with A, B 
\\in {matrices satisfying the tag}\n            # still satisfies the tag itself.\n            # This is the exception.\n            t_matrix = t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0)\n\n        make_op = ft.partial(make_operator, getkey)\n        operator, t_operator = eqx.filter_jvp(\n            make_op, (matrix, tags), (t_matrix, t_tags)\n        )\n\n        if use_state:\n            state = solver.init(operator, options={})\n            linear_solve = ft.partial(lx.linear_solve, state=state)\n        else:\n            linear_solve = lx.linear_solve\n\n        solve_vec_only = lambda v: linear_solve(operator, v, solver).value\n        solve_op_only = lambda op: linear_solve(op, vec, solver).value\n        solve_op_vec = lambda op, v: linear_solve(op, v, solver).value\n\n        vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))\n        op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))\n        op_vec_out, t_op_vec_out = eqx.filter_jvp(\n            solve_op_vec,\n            (operator, vec),\n            (t_operator, t_vec),\n        )\n        (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(\n            lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore\n            (matrix,),\n            (t_matrix,),\n        )\n        (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(\n            jnp.linalg.lstsq,\n            (matrix, vec),\n            (t_matrix, t_vec),  # pyright: ignore\n        )\n\n        # Work around JAX issue #14868.\n        if jnp.any(jnp.isnan(t_expected_op_out)):\n            _, (t_expected_op_out, *_) = finite_difference_jvp(\n                lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore\n                (matrix,),\n                (t_matrix,),\n            )\n        if jnp.any(jnp.isnan(t_expected_op_vec_out)):\n            _, (t_expected_op_vec_out, *_) = finite_difference_jvp(\n                jnp.linalg.lstsq,\n                
(matrix, vec),\n                (t_matrix, t_vec),  # pyright: ignore\n            )\n\n        pinv_matrix = jnp.linalg.pinv(matrix)  # pyright: ignore\n        expected_vec_out = pinv_matrix @ vec\n        assert tree_allclose(vec_out, expected_vec_out)\n        assert tree_allclose(op_out, expected_op_out)\n        assert tree_allclose(op_vec_out, expected_op_vec_out)\n\n        t_expected_vec_out = pinv_matrix @ t_vec\n        assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3)\n        assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3)\n        assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3)\n"
  },
  {
    "path": "tests/test_jvp_jvp1.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox as eqx\nimport jax.numpy as jnp\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    jvp_jvp_impl,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n)\n\n\n# Workaround for https://github.com/jax-ml/jax/issues/27201\n@pytest.fixture(autouse=True)\ndef _clear_cache():\n    eqx.clear_caches()\n\n\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\"make_matrix\", (construct_matrix, construct_singular_matrix))\n@pytest.mark.parametrize(\"dtype\", (jnp.float64,))\ndef test_jvp_jvp(\n    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype\n):\n    jvp_jvp_impl(\n        getkey,\n        solver,\n        tags,\n        pseudoinverse,\n        make_operator,\n        use_state,\n        make_matrix,\n        dtype,\n    )\n"
  },
  {
    "path": "tests/test_jvp_jvp2.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox as eqx\nimport jax.numpy as jnp\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    jvp_jvp_impl,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n)\n\n\n# Workaround for https://github.com/jax-ml/jax/issues/27201\n@pytest.fixture(autouse=True)\ndef _clear_cache():\n    eqx.clear_caches()\n\n\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\"make_matrix\", (construct_matrix, construct_singular_matrix))\n@pytest.mark.parametrize(\"dtype\", (jnp.complex128,))\ndef test_jvp_jvp(\n    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype\n):\n    jvp_jvp_impl(\n        getkey,\n        solver,\n        tags,\n        pseudoinverse,\n        make_operator,\n        use_state,\n        make_matrix,\n        dtype,\n    )\n"
  },
  {
    "path": "tests/test_lsmr.py",
    "content": "import equinox as ex\nimport jax.numpy as jnp\nimport lineax as lx\nimport pytest\n\n\nsolver = lx.LSMR(1e-10, 1e-10)\nAill = lx.DiagonalLinearOperator(jnp.array([1e8, 1e6, 1e4, 1e2, 1]))\nAwell = lx.DiagonalLinearOperator(jnp.array([2.0, 4.0, 5.0, 8.0, 10.0]))\nAsing = lx.DiagonalLinearOperator(jnp.array([0.0, 4.0, 5.0, 8.0, 10.0]))\n\n\ndef test_ill_conditioned():\n    try:\n        lx.linear_solve(Aill, jnp.ones(5), solver=solver)\n    except ex.EquinoxRuntimeError as e:\n        assert \"Condition number\" in str(e)\n\n\ndef test_zero_rhs():\n    # b=0, so x=0 is solution\n    sol = lx.linear_solve(Aill, jnp.zeros(5), solver=solver)\n    assert (sol.value == 0).all()\n    sol = lx.linear_solve(Awell, jnp.zeros(5), solver=solver)\n    assert (sol.value == 0).all()\n    sol = lx.linear_solve(Asing, jnp.zeros(5), solver=solver)\n    assert (sol.value == 0).all()\n    # b lies in null space of A, so x=0 is minimum norm solution\n    sol = lx.linear_solve(Asing, jnp.zeros(5).at[0].set(1), solver=solver)\n    assert (sol.value == 0).all()\n\n\n@pytest.mark.skip(\"Damp support is disabled.\")\ndef test_damp_regularizes():\n    solution_ill = lx.linear_solve(Aill, jnp.ones(5), solver=solver, options={})\n    assert solution_ill.stats[\"istop\"] == 1\n\n    solution_damped = lx.linear_solve(\n        Aill, jnp.ones(5), solver=solver, options={\"damp\": 100.0}\n    )\n    assert solution_damped.stats[\"istop\"] == 2\n\n    assert solution_damped.stats[\"num_steps\"] < solution_ill.stats[\"num_steps\"]\n\n\n@pytest.mark.skip(\"Damp support is disabled.\")\ndef test_damp():\n    solution_damped = lx.linear_solve(\n        Awell, jnp.ones(5), solver=solver, options={\"damp\": 1.0}\n    )\n    assert jnp.allclose(\n        solution_damped.value,\n        jnp.array([0.4, 0.23529412, 0.19230769, 0.12307692, 0.0990099]),\n    )\n    solution_damped = lx.linear_solve(\n        Awell, jnp.ones(5), solver=solver, options={\"damp\": 1000.0}\n    )\n    assert 
jnp.allclose(\n        solution_damped.value, jnp.array([2e-6, 4e-6, 5e-6, 8e-6, 10.0e-6])\n    )\n"
  },
  {
    "path": "tests/test_misc.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport jax\nimport jax.numpy as jnp\nimport lineax as lx\nimport lineax._misc as lx_misc\nimport pytest\n\n\ndef test_inexact_asarray_no_copy():\n    x = jnp.array([1.0])\n    assert lx_misc.inexact_asarray(x) is x\n    y = jnp.array([1.0, 2.0])\n    assert jax.vmap(lx_misc.inexact_asarray)(y) is y\n\n\n# See JAX issue #15676\ndef test_inexact_asarray_jvp():\n    p, t = jax.jvp(lx_misc.inexact_asarray, (1.0,), (2.0,))\n    assert type(p) is not float\n    assert type(t) is not float\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_zero_matrix(dtype):\n    A = lx.MatrixLinearOperator(jnp.zeros((2, 2), dtype=dtype))\n    b = jnp.array([1.0, 2.0], dtype=dtype)\n    lx.linear_solve(A, b, lx.SVD())\n"
  },
  {
    "path": "tests/test_norm.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport jax\nimport jax.flatten_util as jfu\nimport jax.numpy as jnp\nimport lineax.internal as lxi\n\nfrom .helpers import tree_allclose\n\n\ndef _square(x):\n    return x * jnp.conj(x)\n\n\ndef _two_norm(x):\n    return jnp.sqrt(jnp.sum(_square(jfu.ravel_pytree(x)[0]))).real\n\n\ndef _rms_norm(x):\n    return jnp.sqrt(jnp.mean(_square(jfu.ravel_pytree(x)[0]))).real\n\n\ndef _max_norm(x):\n    return jnp.max(jnp.abs(jfu.ravel_pytree(x)[0]))\n\n\ndef test_nonzero():\n    zero = [jnp.array(0.0), jnp.zeros((2, 2))]\n    x = [jnp.array(1.0), jnp.arange(4.0).reshape(2, 2)]\n    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]\n\n    two = lxi.two_norm(x)\n    rms = lxi.rms_norm(x)\n    max = lxi.max_norm(x)\n    true_two = _two_norm(x)\n    true_rms = _rms_norm(x)\n    true_max = _max_norm(x)\n    assert jnp.allclose(two, true_two)\n    assert jnp.allclose(rms, true_rms)\n    assert jnp.allclose(max, true_max)\n\n    two_jvp = jax.jvp(lxi.two_norm, (x,), (tx,))\n    true_two_jvp = jax.jvp(_two_norm, (x,), (tx,))\n    rms_jvp = jax.jvp(lxi.rms_norm, (x,), (tx,))\n    true_rms_jvp = jax.jvp(_rms_norm, (x,), (tx,))\n    max_jvp = jax.jvp(lxi.max_norm, (x,), (tx,))\n    true_max_jvp = jax.jvp(_max_norm, (x,), (tx,))\n    assert tree_allclose(two_jvp, true_two_jvp)\n    assert tree_allclose(rms_jvp, true_rms_jvp)\n    assert tree_allclose(max_jvp, true_max_jvp)\n\n    
two0_jvp = jax.jvp(lxi.two_norm, (x,), (zero,))\n    rms0_jvp = jax.jvp(lxi.rms_norm, (x,), (zero,))\n    max0_jvp = jax.jvp(lxi.max_norm, (x,), (zero,))\n    assert tree_allclose(two0_jvp, (true_two, jnp.array(0.0)))\n    assert tree_allclose(rms0_jvp, (true_rms, jnp.array(0.0)))\n    assert tree_allclose(max0_jvp, (true_max, jnp.array(0.0)))\n\n\ndef test_zero():\n    zero = [jnp.array(0.0), jnp.zeros((2, 2))]\n    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]\n    for t in (zero, tx):\n        two0 = jax.jvp(lxi.two_norm, (zero,), (t,))\n        rms0 = jax.jvp(lxi.rms_norm, (zero,), (t,))\n        max0 = jax.jvp(lxi.max_norm, (zero,), (t,))\n        true0 = (jnp.array(0.0), jnp.array(0.0))\n        assert tree_allclose(two0, true0)\n        assert tree_allclose(rms0, true0)\n        assert tree_allclose(max0, true0)\n\n\ndef test_complex():\n    x = jnp.array([3 + 1.2j, -0.5 + 4.9j])\n    tx = jnp.array([2 - 0.3j, -0.7j])\n    two = jax.jvp(lxi.two_norm, (x,), (tx,))\n    true_two = jax.jvp(_two_norm, (x,), (tx,))\n    rms = jax.jvp(lxi.rms_norm, (x,), (tx,))\n    true_rms = jax.jvp(_rms_norm, (x,), (tx,))\n    max = jax.jvp(lxi.max_norm, (x,), (tx,))\n    true_max = jax.jvp(_max_norm, (x,), (tx,))\n    assert two[0].imag == 0\n    assert tree_allclose(two, true_two)\n    assert rms[0].imag == 0\n    assert tree_allclose(rms, true_rms)\n    assert max[0].imag == 0\n    assert tree_allclose(max, true_max)\n\n\ndef test_size_zero():\n    zero = jnp.array(0.0)\n    for x in (jnp.array([]), [jnp.array([]), jnp.array([])]):\n        assert tree_allclose(lxi.two_norm(x), zero)\n        assert tree_allclose(lxi.rms_norm(x), zero)\n        assert tree_allclose(lxi.max_norm(x), zero)\n        assert tree_allclose(jax.jvp(lxi.two_norm, (x,), (x,)), (zero, zero))\n        assert tree_allclose(jax.jvp(lxi.rms_norm, (x,), (x,)), (zero, zero))\n        assert tree_allclose(jax.jvp(lxi.max_norm, (x,), (x,)), (zero, zero))\n"
  },
  {
    "path": "tests/test_operator.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import cast\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    make_identity_operator,\n    make_jacrev_operator,\n    make_operators,\n    make_tridiagonal_operator,\n    make_trivial_diagonal_operator,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", make_operators)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_ops(make_operator, getkey, dtype):\n    if (\n        make_operator is make_trivial_diagonal_operator\n        or make_operator is make_identity_operator\n    ):\n        matrix = jnp.eye(3, dtype=dtype)\n        tags = lx.diagonal_tag\n    elif make_operator is make_tridiagonal_operator:\n        matrix = jnp.eye(3, dtype=dtype)\n        tags = lx.tridiagonal_tag\n    else:\n        matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n        tags = ()\n    if make_operator is make_jacrev_operator and dtype is jnp.complex128:\n        # JacobianLinearOperator does not support complex dtypes when jac=\"bwd\"\n        return\n    matrix1 = make_operator(getkey, matrix, tags)\n    matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))\n    scalar = jr.normal(getkey(), (), dtype=dtype)\n    add = matrix1 + matrix2\n    composed = matrix1 @ matrix2\n    mul = matrix1 * 
scalar\n    rmul = cast(lx.AbstractLinearOperator, scalar * matrix1)\n    div = matrix1 / scalar\n    vec = jr.normal(getkey(), (3,), dtype=dtype)\n\n    assert tree_allclose(matrix1.mv(vec) + matrix2.mv(vec), add.mv(vec))\n    assert tree_allclose(matrix1.mv(matrix2.mv(vec)), composed.mv(vec))\n    scalar_matvec = scalar * matrix1.mv(vec)\n    assert tree_allclose(scalar_matvec, mul.mv(vec))\n    assert tree_allclose(scalar_matvec, rmul.mv(vec))\n    assert tree_allclose(matrix1.mv(vec) / scalar, div.mv(vec))\n\n    add_matrix = matrix1.as_matrix() + matrix2.as_matrix()\n    composed_matrix = matrix1.as_matrix() @ matrix2.as_matrix()\n    mul_matrix = scalar * matrix1.as_matrix()\n    div_matrix = matrix1.as_matrix() / scalar\n    assert tree_allclose(add_matrix, add.as_matrix())\n    assert tree_allclose(composed_matrix, composed.as_matrix())\n    assert tree_allclose(mul_matrix, mul.as_matrix())\n    assert tree_allclose(mul_matrix, rmul.as_matrix())\n    assert tree_allclose(div_matrix, div.as_matrix())\n\n    assert tree_allclose(add_matrix.T, add.T.as_matrix())\n    assert tree_allclose(composed_matrix.T, composed.T.as_matrix())\n    assert tree_allclose(mul_matrix.T, mul.T.as_matrix())\n    assert tree_allclose(mul_matrix.T, rmul.T.as_matrix())\n    assert tree_allclose(div_matrix.T, div.T.as_matrix())\n\n\n@pytest.mark.parametrize(\"make_operator\", make_operators)\ndef test_structures_vector(make_operator, getkey):\n    if (\n        make_operator is make_trivial_diagonal_operator\n        or make_operator is make_identity_operator\n    ):\n        matrix = jnp.eye(4)\n        tags = lx.diagonal_tag\n        in_size = out_size = 4\n    elif make_operator is make_tridiagonal_operator:\n        matrix = jnp.eye(4)\n        tags = lx.tridiagonal_tag\n        in_size = out_size = 4\n    else:\n        matrix = jr.normal(getkey(), (3, 5))\n        tags = ()\n        in_size = 5\n        out_size = 3\n    operator = make_operator(getkey, matrix, tags)\n    
in_structure = jax.ShapeDtypeStruct((in_size,), jnp.float64)\n    out_structure = jax.ShapeDtypeStruct((out_size,), jnp.float64)\n    assert tree_allclose(in_structure, operator.in_structure())\n    assert tree_allclose(out_structure, operator.out_structure())\n\n\ndef _setup(getkey, matrix, tag: object | frozenset[object] = frozenset()):\n    for make_operator in make_operators:\n        if make_operator is make_trivial_diagonal_operator and tag != lx.diagonal_tag:\n            continue\n        if make_operator is make_tridiagonal_operator and tag not in (\n            lx.tridiagonal_tag,\n            lx.diagonal_tag,\n            lx.symmetric_tag,\n        ):\n            continue\n        if make_operator is make_identity_operator and tag not in (\n            lx.tridiagonal_tag,\n            lx.diagonal_tag,\n            lx.symmetric_tag,\n        ):\n            continue\n        operator = make_operator(getkey, matrix, tag)\n        yield operator\n\n\ndef _assert_except_diag(cond_fun, operators, flip_cond):\n    if flip_cond:\n        _cond_fun = cond_fun\n        cond_fun = lambda x: not _cond_fun(x)\n    for operator in operators:\n        if isinstance(operator, lx.DiagonalLinearOperator):\n            assert not cond_fun(operator)\n        else:\n            assert cond_fun(operator)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_linearise(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    operators = list(_setup(getkey, matrix))\n    vec = jr.normal(getkey(), (3,), dtype=dtype)\n    for operator in operators:\n        # Skip jacrev operators with complex dtype (jacrev doesn't support complex)\n        if (\n            isinstance(operator, lx.JacobianLinearOperator)\n            and operator.jac == \"bwd\"\n            and dtype is jnp.complex128\n        ):\n            continue\n        linearised = lx.linearise(operator)\n        # Actually evaluate the linearised operator to ensure it 
works\n        result = linearised.mv(vec)\n        expected = operator.mv(vec)\n        assert tree_allclose(result, expected)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_materialise(dtype, getkey):\n    operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype))\n    for operator in operators:\n        lx.materialise(operator)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_materialise_large(dtype, getkey):\n    operators = _setup(getkey, jr.normal(getkey(), (200, 500), dtype=dtype))\n    for operator in operators:\n        lx.materialise(operator)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_diagonal(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    matrix_diag = jnp.diag(matrix)\n    # test we properly extract diagonal from a dense matrix when not tagged\n    operators = _setup(getkey, matrix)\n    for operator in operators:\n        assert jnp.allclose(lx.diagonal(operator), matrix_diag)\n    # test we properly extract diagonal from diagonal matrix when tagged\n    operators = _setup(getkey, jnp.diag(matrix_diag), lx.diagonal_tag)\n    for operator in operators:\n        if isinstance(operator, lx.IdentityLinearOperator):\n            assert jnp.allclose(lx.diagonal(operator), jnp.ones(3))\n        else:\n            assert jnp.allclose(lx.diagonal(operator), matrix_diag)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_tridiagonal(dtype, getkey):\n    matrix = jr.normal(getkey(), (5, 5), dtype=dtype)\n    matrix_diag = jnp.diag(matrix)\n    matrix_lower_diag = jnp.diag(matrix, k=-1)\n    matrix_upper_diag = jnp.diag(matrix, k=1)\n    tridiag_matrix = (\n        jnp.diag(matrix_diag)\n        + jnp.diag(matrix_lower_diag, k=-1)\n        + jnp.diag(matrix_upper_diag, k=1)\n    )\n    operators = _setup(getkey, tridiag_matrix, lx.tridiagonal_tag)\n    for operator in operators:\n        
diag, lower_diag, upper_diag = lx.tridiagonal(operator)\n        if isinstance(operator, lx.IdentityLinearOperator):\n            assert jnp.allclose(diag, jnp.ones(5))\n            assert jnp.allclose(lower_diag, jnp.zeros(4))\n            assert jnp.allclose(upper_diag, jnp.zeros(4))\n        else:\n            assert jnp.allclose(diag, matrix_diag)\n            assert jnp.allclose(lower_diag, matrix_lower_diag)\n            assert jnp.allclose(upper_diag, matrix_upper_diag)\n\n    # Test ComposedLinearOperator: diagonal @ tridiagonal and tridiagonal @ diagonal\n    random_diag = jr.normal(getkey(), (5,), dtype=dtype)\n    tridiag_op = lx.TridiagonalLinearOperator(\n        matrix_diag, matrix_lower_diag, matrix_upper_diag\n    )\n    diag_op = lx.DiagonalLinearOperator(random_diag)\n\n    # diagonal @ tridiagonal (row scaling)\n    dt_matrix = jnp.matmul(jnp.diag(random_diag), tridiag_matrix)\n    diag, lower_diag, upper_diag = lx.tridiagonal(diag_op @ tridiag_op)\n    assert jnp.allclose(diag, jnp.diagonal(dt_matrix, 0))\n    assert jnp.allclose(lower_diag, jnp.diagonal(dt_matrix, -1))\n    assert jnp.allclose(upper_diag, jnp.diagonal(dt_matrix, 1))\n\n    # tridiagonal @ diagonal (column scaling)\n    td_matrix = jnp.matmul(tridiag_matrix, jnp.diag(random_diag))\n    diag, lower_diag, upper_diag = lx.tridiagonal(tridiag_op @ diag_op)\n    assert jnp.allclose(diag, jnp.diagonal(td_matrix, 0))\n    assert jnp.allclose(lower_diag, jnp.diagonal(td_matrix, -1))\n    assert jnp.allclose(upper_diag, jnp.diagonal(td_matrix, 1))\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_symmetric(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    symmetric_operators = _setup(getkey, matrix.T @ matrix, lx.symmetric_tag)\n    for operator in symmetric_operators:\n        assert lx.is_symmetric(operator)\n\n    not_symmetric_operators = _setup(getkey, matrix)\n    _assert_except_diag(lx.is_symmetric, 
not_symmetric_operators, flip_cond=True)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_diagonal(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    diagonal_operators = _setup(getkey, jnp.diag(jnp.diag(matrix)), lx.diagonal_tag)\n    for operator in diagonal_operators:\n        assert lx.is_diagonal(operator)\n\n    not_diagonal_operators = _setup(getkey, matrix)\n    _assert_except_diag(lx.is_diagonal, not_diagonal_operators, flip_cond=True)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_diagonal_scalar(dtype, getkey):\n    matrix = jr.normal(getkey(), (1, 1), dtype=dtype)\n    diagonal_operators = _setup(getkey, matrix)\n    for operator in diagonal_operators:\n        assert lx.is_diagonal(operator)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_diagonal_tridiagonal(dtype, getkey):\n    diag1 = jr.normal(getkey(), (1,), dtype=dtype)\n    diag2 = jnp.zeros((0,), dtype=dtype)\n    op1 = lx.TridiagonalLinearOperator(diag1, diag2, diag2)\n    assert lx.is_diagonal(op1)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_has_unit_diagonal(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    not_unit_diagonal = _setup(getkey, matrix)\n    for operator in not_unit_diagonal:\n        assert not lx.has_unit_diagonal(operator)\n\n    matrix_unit_diag = matrix.at[jnp.arange(3), jnp.arange(3)].set(1)\n    unit_diagonal = _setup(getkey, matrix_unit_diag, lx.unit_diagonal_tag)\n    _assert_except_diag(lx.has_unit_diagonal, unit_diagonal, flip_cond=False)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_lower_triangular(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    lower_triangular = _setup(getkey, jnp.tril(matrix), lx.lower_triangular_tag)\n    for operator in lower_triangular:\n        assert 
lx.is_lower_triangular(operator)\n\n    not_lower_triangular = _setup(getkey, matrix)\n    _assert_except_diag(lx.is_lower_triangular, not_lower_triangular, flip_cond=True)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_upper_triangular(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    upper_triangular = _setup(getkey, jnp.triu(matrix), lx.upper_triangular_tag)\n    for operator in upper_triangular:\n        assert lx.is_upper_triangular(operator)\n\n    not_upper_triangular = _setup(getkey, matrix)\n    _assert_except_diag(lx.is_upper_triangular, not_upper_triangular, flip_cond=True)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_positive_semidefinite(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    not_positive_semidefinite = _setup(getkey, matrix)\n    for operator in not_positive_semidefinite:\n        assert not lx.is_positive_semidefinite(operator)\n\n    positive_semidefinite = _setup(\n        getkey, matrix.T.conj() @ matrix, lx.positive_semidefinite_tag\n    )\n    _assert_except_diag(\n        lx.is_positive_semidefinite, positive_semidefinite, flip_cond=False\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_negative_semidefinite(dtype, getkey):\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    not_negative_semidefinite = _setup(getkey, matrix)\n    for operator in not_negative_semidefinite:\n        assert not lx.is_negative_semidefinite(operator)\n\n    negative_semidefinite = _setup(\n        getkey, -matrix.T.conj() @ matrix, lx.negative_semidefinite_tag\n    )\n    _assert_except_diag(\n        lx.is_negative_semidefinite, negative_semidefinite, flip_cond=False\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_is_tridiagonal(dtype, getkey):\n    diag1 = jr.normal(getkey(), (5,), dtype=dtype)\n    diag2 = jr.normal(getkey(), (4,), 
dtype=dtype)\n    diag3 = jr.normal(getkey(), (4,), dtype=dtype)\n    op1 = lx.TridiagonalLinearOperator(diag1, diag2, diag3)\n    op2 = lx.IdentityLinearOperator(jax.eval_shape(lambda: diag1))\n    op3 = lx.MatrixLinearOperator(jnp.diag(diag1))\n    assert lx.is_tridiagonal(op1)\n    assert lx.is_tridiagonal(op2)\n    assert not lx.is_tridiagonal(op3)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_tangent_as_matrix(dtype, getkey):\n    def _list_setup(matrix):\n        # Exclude jacrev operator: jac=\"bwd\" uses custom_vjp which doesn't support JVP\n        return [\n            op\n            for op in _setup(getkey, matrix)\n            if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == \"bwd\")\n        ]\n\n    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)\n    operators, t_operators = eqx.filter_jvp(_list_setup, (matrix,), (t_matrix,))\n    for operator, t_operator in zip(operators, t_operators):\n        t_operator = lx.TangentLinearOperator(operator, t_operator)\n        if isinstance(operator, lx.DiagonalLinearOperator):\n            assert jnp.allclose(operator.as_matrix(), jnp.diag(jnp.diag(matrix)))\n            assert jnp.allclose(t_operator.as_matrix(), jnp.diag(jnp.diag(t_matrix)))\n        else:\n            assert jnp.allclose(operator.as_matrix(), matrix)\n            assert jnp.allclose(t_operator.as_matrix(), t_matrix)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_materialise_function_linear_operator(dtype, getkey):\n    x = (\n        jr.normal(getkey(), (5, 9), dtype=dtype),\n        jr.normal(getkey(), (3,), dtype=dtype),\n    )\n    input_structure = jax.eval_shape(lambda: x)\n    fn = lambda x: {\"a\": jnp.broadcast_to(jnp.sum(x[0]), (1, 2))}\n    output_structure = jax.eval_shape(fn, input_structure)\n    operator = lx.FunctionLinearOperator(fn, input_structure)\n    materialised_operator = 
lx.materialise(operator)\n    assert materialised_operator.in_structure() == input_structure\n    assert materialised_operator.out_structure() == output_structure\n    assert isinstance(materialised_operator, lx.PyTreeLinearOperator)\n    expected_struct = {\n        \"a\": (\n            jax.ShapeDtypeStruct((1, 2, 5, 9), dtype),\n            jax.ShapeDtypeStruct((1, 2, 3), dtype),\n        )\n    }\n    assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_pytree_transpose(dtype, getkey):\n    out_struct = jax.eval_shape(\n        lambda: ({\"a\": jnp.zeros((2, 3, 3), dtype=dtype)}, jnp.zeros((2,), dtype=dtype))\n    )\n    in_struct = jax.eval_shape(lambda: {\"b\": jnp.zeros((4,), dtype=dtype)})\n    leaf1 = jr.normal(getkey(), (2, 3, 3, 4), dtype=dtype)\n    leaf2 = jr.normal(getkey(), (2, 4), dtype=dtype)\n    pytree = ({\"a\": {\"b\": leaf1}}, {\"b\": leaf2})\n    operator = lx.PyTreeLinearOperator(pytree, out_struct)\n    assert operator.in_structure() == in_struct\n    assert operator.out_structure() == out_struct\n    leaf1_T = jnp.moveaxis(leaf1, -1, 0)\n    leaf2_T = jnp.moveaxis(leaf2, -1, 0)\n    pytree_T = {\"b\": ({\"a\": leaf1_T}, leaf2_T)}\n    operator_T = operator.T\n    assert operator_T.in_structure() == out_struct\n    assert operator_T.out_structure() == in_struct\n    assert eqx.tree_equal(operator_T.pytree, pytree_T)  # pyright: ignore\n\n\ndef test_diagonal_tangent():\n    diag = jnp.array([1.0, 2.0, 3.0])\n    t_diag = jnp.array([4.0, 5.0, 6.0])\n\n    def run(diag):\n        op = lx.DiagonalLinearOperator(diag)\n        out = lx.linear_solve(op, jnp.array([1.0, 1.0, 1.0]), solver=lx.Diagonal())\n        return out.value\n\n    jax.jvp(run, (diag,), (t_diag,))\n\n\ndef test_identity_with_different_structures():\n    structure1 = (\n        jax.ShapeDtypeStruct((), jnp.float32),\n        jax.ShapeDtypeStruct((2, 3), jnp.float16),\n    
)\n    structure2 = {\"a\": jax.ShapeDtypeStruct((5,), jnp.float32)}\n    # structure3 = (None, jax.ShapeDtypeStruct((2, 3), jnp.float16))\n    op1 = lx.IdentityLinearOperator(structure1, structure2)\n    op2 = lx.IdentityLinearOperator(structure2, structure1)\n    # op3 = lx.IdentityLinearOperator(structure3, structure2)\n\n    assert op1.T == op2\n    # assert op2.transpose((True, False)) == op3\n    assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.float32))\n    assert op1.in_size() == 7\n    assert op1.out_size() == 5\n    vec1 = (\n        jnp.array(1.0, dtype=jnp.float32),\n        jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16),\n    )\n    vec2 = {\"a\": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)}\n    vec1b = (\n        jnp.array(1.0, dtype=jnp.float32),\n        jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16),\n    )\n    assert tree_allclose(op1.mv(vec1), vec2)\n    assert tree_allclose(op2.mv(vec2), vec1b)\n\n\ndef test_identity_with_different_structures_complex():\n    structure1 = (\n        jax.ShapeDtypeStruct((), jnp.complex128),\n        jax.ShapeDtypeStruct((2, 3), jnp.float16),\n    )\n    structure2 = {\"a\": jax.ShapeDtypeStruct((5,), jnp.complex128)}\n    # structure3 = (None, jax.ShapeDtypeStruct((2, 3), jnp.float16))\n    op1 = lx.IdentityLinearOperator(structure1, structure2)\n    op2 = lx.IdentityLinearOperator(structure2, structure1)\n    # op3 = lx.IdentityLinearOperator(structure3, structure2)\n\n    assert op1.T == op2\n    # assert op2.transpose((True, False)) == op3\n    assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.complex128))\n    assert op1.in_size() == 7\n    assert op1.out_size() == 5\n    vec1 = (\n        jnp.array(1.0, dtype=jnp.complex128),\n        jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16),\n    )\n    vec2 = {\"a\": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.complex128)}\n    vec1b = (\n        jnp.array(1.0, dtype=jnp.complex128),\n        
jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16),\n    )\n    assert tree_allclose(op1.mv(vec1), vec2)\n    assert tree_allclose(op2.mv(vec2), vec1b)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_zero_pytree_as_matrix(dtype):\n    a = jnp.array([], dtype=dtype).reshape(2, 1, 0, 2, 1, 0)\n    struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype)\n    op = lx.PyTreeLinearOperator(a, struct)\n    assert op.as_matrix().shape == (0, 0)\n\n\ndef test_jacrev_operator():\n    # Test that custom_vjp is respected. The custom backward multiplies by 3\n    # instead of the true derivative (which would be 2).\n    # This tests that lineax uses the custom_vjp, not the true derivative.\n    @jax.custom_vjp\n    def f(x, _):\n        return dict(foo=x[\"bar\"] * 2)  # forward: multiply by 2\n\n    def f_fwd(x, _):\n        return f(x, None), None\n\n    def f_bwd(_, g):\n        # Custom backward: multiply by 3 (not the true derivative 2)\n        # This must be linear in g for linear_transpose to work correctly.\n        return dict(bar=g[\"foo\"] * 3), None\n\n    f.defvjp(f_fwd, f_bwd)\n\n    x = dict(bar=jnp.arange(2.0))\n    rev_op = lx.JacobianLinearOperator(f, x, jac=\"bwd\")\n    # Jacobian is 3*I (from custom backward, not 2*I from true derivative)\n    as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])\n    assert tree_allclose(rev_op.as_matrix(), as_matrix)\n\n    y = dict(bar=jnp.arange(2.0) + 1)  # y = [1, 2]\n    true_out = dict(foo=jnp.array([3.0, 6.0]))  # 3*I @ [1, 2] = [3, 6]\n    for op in (rev_op, lx.materialise(rev_op)):\n        out = op.mv(y)\n        assert tree_allclose(out, true_out)\n\n    fwd_op = lx.JacobianLinearOperator(f, x, jac=\"fwd\")\n    with pytest.raises(TypeError, match=\"can't apply forward-mode autodiff\"):\n        fwd_op.mv(y)\n    with pytest.raises(TypeError, match=\"can't apply forward-mode autodiff\"):\n        lx.materialise(fwd_op)\n"
  },
  {
    "path": "tests/test_singular.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_singular_matrix,\n    finite_difference_jvp,\n    make_jac_operator,\n    make_matrix_operator,\n    ops,\n    params,\n    tol,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator,solver,tags\", params(only_pseudo=True))\n@pytest.mark.parametrize(\"ops\", ops)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_small_singular(make_operator, solver, tags, ops, getkey, dtype):\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        tol = 1e-10\n    else:\n        tol = 1e-4\n    (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype)\n    operator = make_operator(getkey, matrix, tags)\n    operator, matrix = ops(operator, matrix)\n    assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)\n    out_size, in_size = matrix.shape\n    true_x = jr.normal(getkey(), (in_size,), dtype=dtype)\n    b = matrix @ true_x\n    x = lx.linear_solve(operator, b, solver=solver, throw=False).value\n    jax_x, *_ = jnp.linalg.lstsq(matrix, b)  # pyright: ignore\n    assert tree_allclose(x, jax_x, atol=tol, rtol=tol)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_bicgstab_breakdown(getkey, 
dtype):\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        tol = 1e-10\n    else:\n        tol = 1e-4\n    solver = lx.GMRES(atol=tol, rtol=tol, restart=2)\n\n    matrix = jr.normal(jr.PRNGKey(0), (100, 100), dtype=dtype)\n    true_x = jr.normal(jr.PRNGKey(0), (100,), dtype=dtype)\n    b = matrix @ true_x\n    operator = lx.MatrixLinearOperator(matrix)\n\n    # result != 0 implies lineax reported failure\n    lx_soln = lx.linear_solve(operator, b, solver, throw=False)\n\n    assert jnp.all(lx_soln.result != lx.RESULTS.successful)\n\n\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_gmres_stagnation_or_breakdown(getkey, dtype):\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        tol = 1e-10\n    else:\n        tol = 1e-4\n    solver = lx.GMRES(atol=tol, rtol=tol, restart=2)\n\n    matrix = jnp.array(\n        [\n            [0.15892892, 0.05884365, -0.60427412, 0.1891916],\n            [-1.5484863, 0.93608822, 1.94888868, 1.37069667],\n            [0.62687318, -0.13996738, -0.6824359, 0.30975754],\n            [-0.67428635, 1.52372255, -0.88277754, 0.69633816],\n        ],\n        dtype=dtype,\n    )\n    true_x = jnp.array([0.51383273, 1.72983427, -0.43251078, -1.11764668], dtype=dtype)\n    b = matrix @ true_x\n    operator = lx.MatrixLinearOperator(matrix)\n\n    # result != 0 implies lineax reported failure\n    lx_soln = lx.linear_solve(operator, b, solver, throw=False)\n\n    assert jnp.all(lx_soln.result != lx.RESULTS.successful)\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    (\n        lx.AutoLinearSolver(well_posed=None),\n        lx.QR(),\n        lx.SVD(),\n        lx.LSMR(atol=tol, rtol=tol),\n        lx.Normal(lx.Cholesky()),\n        lx.Normal(lx.SVD()),\n    ),\n)\ndef test_nonsquare_pytree_operator1(solver):\n    x = [[1, 5.0, jnp.array(-1.0)], [jnp.array(-2), jnp.array(-2.0), 3.0]]\n    y = [3.0, 4]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n  
  out = lx.linear_solve(operator, y, solver=solver).value\n    matrix = jnp.array([[1.0, 5.0, -1.0], [-2.0, -2.0, 3.0]])\n    true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y))  # pyright: ignore\n    true_out = [true_out[0], true_out[1], true_out[2]]\n    assert tree_allclose(out, true_out)\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    (\n        lx.AutoLinearSolver(well_posed=None),\n        lx.QR(),\n        lx.SVD(),\n        lx.LSMR(atol=tol, rtol=tol),\n        lx.Normal(lx.Cholesky()),\n        lx.Normal(lx.SVD()),\n    ),\n)\ndef test_nonsquare_pytree_operator2(solver):\n    x = [[1, jnp.array(-2)], [5.0, jnp.array(-2.0)], [jnp.array(-1.0), 3.0]]\n    y = [3.0, 4, 5.0]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n    out = lx.linear_solve(operator, y, solver=solver).value\n    matrix = jnp.array([[1.0, -2.0], [5.0, -2.0], [-1.0, 3.0]])\n    true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y))  # pyright: ignore\n    true_out = [true_out[0], true_out[1]]\n    assert tree_allclose(out, true_out)\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    (\n        lx.AutoLinearSolver(well_posed=None),\n        lx.QR(),\n        lx.SVD(),\n        lx.Normal(lx.Cholesky()),\n        lx.Normal(lx.SVD()),\n    ),\n)\n@pytest.mark.parametrize(\"full_rank\", (True, False))\n@pytest.mark.parametrize(\"jvp\", (False, True))\n@pytest.mark.parametrize(\"wide\", (False, True))\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_nonsquare_mat_vec(solver, full_rank, jvp, wide, dtype, getkey):\n    if wide:\n        out_size = 3\n        in_size = 6\n    else:\n        out_size = 6\n        in_size = 3\n    matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)\n    if not full_rank:\n        if solver.assume_full_rank():\n            # There is nothing to test.\n            return\n        # nontrivial rank 2 sparsity pattern\n        matrix = matrix.at[1:, 1:].set(0)\n    
vector = jr.normal(getkey(), (out_size,), dtype=dtype)\n    lx_solve = lambda mat, vec: lx.linear_solve(\n        lx.MatrixLinearOperator(mat), vec, solver\n    ).value\n    jnp_solve = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore\n    if jvp:\n        lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve))\n        jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve))\n        t_matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)\n        if not full_rank:\n            # t_matrix must be chosen tangent to the manifold of rank 2\n            # matrices at matrix. A simple way to achieve this is to make the\n            # same restriction as we did to matrix\n            t_matrix = t_matrix.at[1:, 1:].set(0)\n        t_vector = jr.normal(getkey(), (out_size,), dtype=dtype)\n        args = ((matrix, vector), (t_matrix, t_vector))\n    else:\n        args = (matrix, vector)\n    x = lx_solve(*args)  # pyright: ignore\n    true_x = jnp_solve(*args)\n    assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4)\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    (\n        lx.AutoLinearSolver(well_posed=None),\n        lx.QR(),\n        lx.SVD(),\n        lx.Normal(lx.Cholesky()),\n        lx.Normal(lx.SVD()),\n    ),\n)\n@pytest.mark.parametrize(\"full_rank\", (True, False))\n@pytest.mark.parametrize(\"jvp\", (False, True))\n@pytest.mark.parametrize(\"wide\", (False, True))\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_nonsquare_vec(solver, full_rank, jvp, wide, dtype, getkey):\n    if wide:\n        out_size = 3\n        in_size = 6\n    else:\n        out_size = 6\n        in_size = 3\n    matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)\n    if not full_rank:\n        if solver.assume_full_rank():\n            # There is nothing to test.\n            return\n        # nontrivial rank 2 sparsity pattern\n        matrix = matrix.at[1:, 1:].set(0)\n    vector = 
jr.normal(getkey(), (out_size,), dtype=dtype)\n    lx_solve = lambda vec: lx.linear_solve(\n        lx.MatrixLinearOperator(matrix), vec, solver\n    ).value\n    jnp_solve = lambda vec: jnp.linalg.lstsq(matrix, vec)[0]  # pyright: ignore\n    if jvp:\n        lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve))\n        jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve))\n        t_vector = jr.normal(getkey(), (out_size,), dtype=dtype)\n        args = ((vector,), (t_vector,))\n    else:\n        args = (vector,)\n    x = lx_solve(*args)  # pyright: ignore\n    true_x = jnp_solve(*args)\n    assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4)\n\n\n_iterative_solvers = (\n    (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag),\n    (lx.CG(rtol=tol, atol=tol, max_steps=512), lx.negative_semidefinite_tag),\n    (lx.GMRES(rtol=tol, atol=tol), ()),\n    (lx.BiCGStab(rtol=tol, atol=tol), ()),\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"solver, tags\", _iterative_solvers)\n@pytest.mark.parametrize(\"use_state\", (False, True))\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_iterative_singular(getkey, solver, tags, use_state, make_operator, dtype):\n    (matrix,) = construct_singular_matrix(getkey, solver, tags)\n    operator = make_operator(getkey, matrix, tags)\n\n    out_size, _ = matrix.shape\n    vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n\n    if use_state:\n        state = solver.init(operator, options={})\n        linear_solve = ft.partial(lx.linear_solve, state=state)\n    else:\n        linear_solve = lx.linear_solve\n\n    with pytest.raises(Exception):\n        linear_solve(operator, vec, solver)\n"
  },
  {
    "path": "tests/test_solve.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import construct_poisson_matrix, tree_allclose\n\n\ndef test_gmres_large_dense(getkey):\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        tol = 1e-10\n    else:\n        tol = 1e-4\n    solver = lx.GMRES(atol=tol, rtol=tol, restart=100)\n\n    matrix = jr.normal(getkey(), (100, 100))\n    operator = lx.MatrixLinearOperator(matrix)\n    true_x = jr.normal(getkey(), (100,))\n    b = matrix @ true_x\n\n    lx_soln = lx.linear_solve(operator, b, solver).value\n\n    assert tree_allclose(lx_soln, true_x, atol=tol, rtol=tol)\n\n\ndef test_nontrivial_pytree_operator():\n    x = [[1, 5.0], [jnp.array(-2), jnp.array(-2.0)]]\n    y = [3, 4]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n    out = lx.linear_solve(operator, y).value\n    true_out = [jnp.array(-3.25), jnp.array(1.25)]\n    assert tree_allclose(out, true_out)\n\n\ndef test_nontrivial_diagonal_operator():\n    x = (8.0, jnp.array([1, 2, 3]), {\"a\": jnp.array([4, 5]), \"b\": 6})\n    y = (4.0, jnp.array([7, 8, 9]), {\"a\": jnp.array([2, 10]), \"b\": 12})\n    operator = lx.DiagonalLinearOperator(x)\n    out = lx.linear_solve(operator, y).value\n    true_out = (\n        jnp.array(0.5),\n        jnp.array([7.0, 4.0, 3.0]),\n        {\"a\": 
jnp.array([0.5, 2.0]), \"b\": jnp.array(2.0)},\n    )\n    assert tree_allclose(out, true_out)\n\n\n@pytest.mark.parametrize(\"solver\", (lx.LU(), lx.QR(), lx.SVD()))\ndef test_mixed_dtypes(solver):\n    f32 = lambda x: jnp.array(x, dtype=jnp.float32)\n    f64 = lambda x: jnp.array(x, dtype=jnp.float64)\n    x = [[f32(1), f64(5)], [f32(-2), f64(-2)]]\n    y = [f64(3), f64(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n    out = lx.linear_solve(operator, y, solver=solver).value\n    true_out = [f32(-3.25), f64(1.25)]\n    assert tree_allclose(out, true_out)\n\n\n@pytest.mark.parametrize(\"solver\", (lx.LU(), lx.QR(), lx.SVD()))\ndef test_mixed_dtypes_complex(solver):\n    c64 = lambda x: jnp.array(x, dtype=jnp.complex64)\n    c128 = lambda x: jnp.array(x, dtype=jnp.complex128)\n    x = [[c64(1), c128(5.0j)], [c64(2.0j), c128(-2)]]\n    y = [c128(3), c128(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n    out = lx.linear_solve(operator, y, solver=solver).value\n    true_out = [c64(-0.75 - 2.5j), c128(0.5 - 0.75j)]\n    assert tree_allclose(out, true_out)\n\n\n@pytest.mark.parametrize(\"solver\", (lx.LU(), lx.QR(), lx.SVD()))\ndef test_mixed_dtypes_complex_real(solver):\n    f64 = lambda x: jnp.array(x, dtype=jnp.float64)\n    c128 = lambda x: jnp.array(x, dtype=jnp.complex128)\n    x = [[f64(1), c128(-5.0j)], [f64(2.0), c128(-2j)]]\n    y = [c128(3), c128(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct)\n    out = lx.linear_solve(operator, y, solver=solver).value\n    true_out = [f64(1.75), c128(0.25j)]\n    assert tree_allclose(out, true_out)\n\n\ndef test_mixed_dtypes_triangular():\n    f32 = lambda x: jnp.array(x, dtype=jnp.float32)\n    f64 = lambda x: jnp.array(x, dtype=jnp.float64)\n    x = [[f32(1), f64(0)], [f32(-2), f64(-2)]]\n    y = [f64(3), f64(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = 
lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag)\n    out = lx.linear_solve(operator, y, solver=lx.Triangular()).value\n    true_out = [f32(3), f64(-5)]\n    assert tree_allclose(out, true_out)\n\n\ndef test_mixed_dtypes_complex_triangular():\n    c64 = lambda x: jnp.array(x, dtype=jnp.complex64)\n    c128 = lambda x: jnp.array(x, dtype=jnp.complex128)\n    x = [[c64(1), c128(0)], [c64(2.0j), c128(-2)]]\n    y = [c128(3), c128(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag)\n    out = lx.linear_solve(operator, y, solver=lx.Triangular()).value\n    true_out = [c64(3), c128(-2 + 3.0j)]\n    assert tree_allclose(out, true_out)\n\n\ndef test_mixed_dtypes_complex_real_triangular():\n    f64 = lambda x: jnp.array(x, dtype=jnp.float64)\n    c128 = lambda x: jnp.array(x, dtype=jnp.complex128)\n    x = [[f64(1), c128(0)], [f64(2.0), c128(2j)]]\n    y = [c128(3), c128(4)]\n    struct = jax.eval_shape(lambda: y)\n    operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag)\n    out = lx.linear_solve(operator, y, solver=lx.Triangular()).value\n    true_out = [f64(3), c128(1j)]\n    assert tree_allclose(out, true_out)\n\n\ndef test_ad_closure_function_linear_operator(getkey):\n    def f(x, z):\n        def fn(y):\n            return x * y\n\n        op = lx.FunctionLinearOperator(fn, jax.eval_shape(lambda: z))\n        sol = lx.linear_solve(op, z).value\n        return jnp.sum(sol), sol\n\n    x = jr.normal(getkey(), (3,))\n    x = jnp.where(jnp.abs(x) < 1e-6, 0.7, x)\n    z = jr.normal(getkey(), (3,))\n    grad, sol = jax.grad(f, has_aux=True)(x, z)\n    assert tree_allclose(grad, -z / (x**2))\n    assert tree_allclose(sol, z / x)\n\n\ndef test_grad_vmap_symbolic_cotangent():\n    def f(x):\n        return x[0], x[1]\n\n    @jax.vmap\n    def to_vmap(x):\n        op = lx.FunctionLinearOperator(f, jax.eval_shape(lambda: x))\n        sol = lx.linear_solve(op, x)\n        return 
sol.value[0]\n\n    @jax.grad\n    def to_grad(x):\n        return jnp.sum(to_vmap(x))\n\n    x = (jnp.arange(3.0), jnp.arange(3.0))\n    to_grad(x)\n\n\n@pytest.mark.parametrize(\n    \"solver\",\n    (\n        lx.CG(0.0, 0.0, max_steps=2),\n        lx.Normal(lx.CG(0.0, 0.0, max_steps=2)),\n        lx.BiCGStab(0.0, 0.0, max_steps=2),\n        lx.GMRES(0.0, 0.0, max_steps=2),\n        lx.LSMR(0.0, 0.0, max_steps=2),\n    ),\n)\ndef test_iterative_solver_max_steps_only(solver):\n    \"\"\"Iterative solvers should work with max_steps only (no Equinox errors).\"\"\"\n    SIZE = 100\n\n    poisson_matrix = construct_poisson_matrix(SIZE)\n    poisson_operator = lx.MatrixLinearOperator(\n        poisson_matrix, tags=(lx.negative_semidefinite_tag, lx.symmetric_tag)\n    )\n    rhs = jax.random.normal(jax.random.key(0), (SIZE,))\n\n    lx.linear_solve(poisson_operator, rhs, solver)\n\n\ndef test_solver_init_not_differentiated(getkey):\n    \"\"\"stop_gradient should be applied before solver.init, not after.\n\n    Also checks that dynamic arrays in options don't cause issues.\n    \"\"\"\n\n    class DisallowGradWrapper(lx._solve.AbstractLinearSolver):\n        solver: lx._solve.AbstractLinearSolver\n\n        def init(self, operator, options):\n            @jax.custom_jvp\n            def f(operator, dummy):\n                del dummy\n                return self.solver.init(operator, options)\n\n            @f.defjvp\n            def _(*args):\n                raise NotImplementedError(\"solver.init should not be differentiated\")\n\n            return f(operator, options.get(\"dummy\"))\n\n        def compute(self, state, vector, options):\n            return self.solver.compute(state, vector, options)\n\n        def transpose(self, state, options):\n            return self.solver.transpose(state, options)\n\n        def conj(self, state, options):\n            return self.solver.conj(state, options)\n\n        def assume_full_rank(self):\n            return 
self.solver.assume_full_rank()\n\n    m = jax.random.normal(getkey(), (3, 3))\n    mt = jax.random.normal(getkey(), (3, 3))\n    v = jax.random.normal(getkey(), (3,))\n    dummy = jnp.array(1.0)\n\n    def f(m):\n        op = lx.MatrixLinearOperator(m)\n        return lx.linear_solve(\n            op, v, solver=DisallowGradWrapper(lx.QR()), options={\"dummy\": dummy}\n        ).value\n\n    # Differentiating through operator only, but options has a dynamic array.\n    # solver.init should not be differentiated through.\n    jax.jvp(f, (m,), (mt,))\n\n    _, f_vjp = jax.vjp(f, m)\n    f_vjp(v)\n\n\ndef test_nonfinite_input():\n    operator = lx.DiagonalLinearOperator((1.0, 1.0))\n    vector = (1.0, jnp.inf)\n    sol = lx.linear_solve(operator, vector, throw=False)\n    assert sol.result == lx.RESULTS.nonfinite_input\n\n    vector = (1.0, jnp.nan)\n    sol = lx.linear_solve(operator, vector, throw=False)\n    assert sol.result == lx.RESULTS.nonfinite_input\n\n    vector = (jnp.nan, jnp.inf)\n    sol = lx.linear_solve(operator, vector, throw=False)\n    assert sol.result == lx.RESULTS.nonfinite_input\n"
  },
  {
    "path": "tests/test_transpose.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import construct_matrix, params, tree_allclose\n\n\nclass TestTranspose:\n    @pytest.fixture(scope=\"class\")\n    def assert_transpose_fixture(_):\n        @eqx.filter_jit\n        def solve_transpose(operator, out_vec, in_vec, solver):\n            return jax.linear_transpose(\n                lambda v: lx.linear_solve(operator, v, solver).value, out_vec\n            )(in_vec)\n\n        def assert_transpose(operator, out_vec, in_vec, solver):\n            (out,) = solve_transpose(operator, out_vec, in_vec, solver)\n            true_out = lx.linear_solve(operator.T, in_vec, solver).value\n            assert tree_allclose(out, true_out)\n\n        return assert_transpose\n\n    @pytest.mark.parametrize(\"make_operator,solver,tags\", params(only_pseudo=False))\n    @pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\n    def test_transpose(\n        _, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey\n    ):\n        (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)\n        operator = make_operator(getkey, matrix, tags)\n        out_size, in_size = matrix.shape\n        out_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n        in_vec = jr.normal(getkey(), (in_size,), 
dtype=dtype)\n        solver = lx.AutoLinearSolver(well_posed=True)\n        assert_transpose_fixture(operator, out_vec, in_vec, solver)\n\n    def test_pytree_transpose(_, assert_transpose_fixture):  # pyright: ignore\n        a = jnp.array\n        pytree = [[a(1), a(2), a(3)], [a(4), a(5), a(6)]]\n        output_structure = jax.eval_shape(lambda: [1, 2])\n        operator = lx.PyTreeLinearOperator(pytree, output_structure)\n        out_vec = [a(1.0), a(2.0)]\n        in_vec = [a(1.0), 2.0, 3.0]\n        solver = lx.AutoLinearSolver(well_posed=False)\n        assert_transpose_fixture(operator, out_vec, in_vec, solver)\n"
  },
  {
    "path": "tests/test_vmap.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\n    \"make_matrix\",\n    (\n        construct_matrix,\n        construct_singular_matrix,\n    ),\n)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_vmap(\n    getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype\n):\n    if (make_matrix is construct_matrix) or pseudoinverse:\n\n        def wrap_solve(matrix, vector):\n            operator = make_operator(getkey, matrix, tags)\n            if use_state:\n                state = solver.init(operator, options={})\n                return lx.linear_solve(operator, vector, solver, state=state).value\n            else:\n                return lx.linear_solve(operator, vector, solver).value\n\n        for op_axis, vec_axis in (\n            (None, 0),\n            (eqx.if_array(0), None),\n            
(eqx.if_array(0), 0),\n        ):\n            if op_axis is None:\n                axis_size = None\n                out_axes = None\n            else:\n                axis_size = 10\n                out_axes = eqx.if_array(0)\n\n            (matrix,) = eqx.filter_vmap(\n                lambda getkey, solver, tags: make_matrix(\n                    getkey, solver, tags, dtype=dtype\n                ),\n                axis_size=axis_size,\n                out_axes=out_axes,\n            )(getkey, solver, tags)\n            out_dim = matrix.shape[-2]\n\n            if vec_axis is None:\n                vec = jr.normal(getkey(), (out_dim,), dtype=dtype)\n            else:\n                vec = jr.normal(getkey(), (10, out_dim), dtype=dtype)\n\n            jax_result, _, _, _ = eqx.filter_vmap(\n                jnp.linalg.lstsq,\n                in_axes=(op_axis, vec_axis),  # pyright: ignore\n            )(matrix, vec)\n            lx_result = eqx.filter_vmap(wrap_solve, in_axes=(op_axis, vec_axis))(\n                matrix, vec\n            )\n            assert tree_allclose(lx_result, jax_result)\n\n\n# https://github.com/patrick-kidger/lineax/issues/101\ndef test_grad_vmap_basic(getkey):\n    A = jr.normal(getkey(), (16, 8))\n    B = jr.normal(getkey(), (128, 16))\n\n    @jax.jit\n    @jax.grad\n    def fn(A):\n        op = lx.MatrixLinearOperator(A)\n        return jax.vmap(\n            lambda b: lx.linear_solve(\n                op, b, lx.AutoLinearSolver(well_posed=False)\n            ).value\n        )(B).mean()\n\n    fn(A)\n\n\ndef test_grad_vmap_advanced(getkey):\n    # this is a more complicated version of the above test, in which the batch axes and\n    # the undefinedprimals do not necessarily line up in the same arguments.\n    A = jr.normal(getkey(), (2, 8)), jr.normal(getkey(), (3, 8, 128))\n    B = jr.normal(getkey(), (2, 128)), jr.normal(getkey(), (3,))\n\n    output_structure = (\n        jax.ShapeDtypeStruct((2,), jnp.float64),\n        
jax.ShapeDtypeStruct((3,), jnp.float64),\n    )\n\n    def to_vmap(A, B):\n        op = lx.PyTreeLinearOperator(A, output_structure)\n        return lx.linear_solve(op, B, lx.AutoLinearSolver(well_posed=False)).value\n\n    @jax.jit\n    @jax.grad\n    def fn(A):\n        return jax.vmap(to_vmap, in_axes=((None, 2), (1, None)))(A, B).mean()\n\n    fn(A)\n"
  },
  {
    "path": "tests/test_vmap_jvp.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\n\nimport equinox as eqx\nimport jax.lax as lax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\n    \"make_matrix\",\n    (\n        construct_matrix,\n        construct_singular_matrix,\n    ),\n)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_vmap_jvp(\n    getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype\n):\n    if (make_matrix is construct_matrix) or pseudoinverse:\n        t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None\n        if pseudoinverse:\n            jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore\n        else:\n            jnp_solve1 = jnp.linalg.solve  # pyright: ignore\n        if use_state:\n\n            def linear_solve1(operator, vector):\n                op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)\n       
         stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)\n                state = solver.init(stopped_operator, options={})\n\n                return lx.linear_solve(operator, vector, state=state, solver=solver)\n\n        else:\n            linear_solve1 = ft.partial(lx.linear_solve, solver=solver)\n\n        for mode in (\"vec\", \"op\", \"op_vec\"):\n            if \"op\" in mode:\n                axis_size = 10\n                out_axes = eqx.if_array(0)\n            else:\n                axis_size = None\n                out_axes = None\n\n            def _make():\n                matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)\n                make_op = ft.partial(make_operator, getkey)\n                operator, t_operator = eqx.filter_jvp(\n                    make_op, (matrix, tags), (t_matrix, t_tags)\n                )\n                return matrix, t_matrix, operator, t_operator\n\n            matrix, t_matrix, operator, t_operator = eqx.filter_vmap(\n                _make, axis_size=axis_size, out_axes=out_axes\n            )()\n\n            if \"op\" in mode:\n                _, out_size, _ = matrix.shape\n            else:\n                out_size, _ = matrix.shape\n\n            if \"vec\" in mode:\n                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)\n                t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype)\n            else:\n                vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n                t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n\n            if mode == \"op\":\n                linear_solve2 = lambda op: linear_solve1(op, vector=vec)\n                jnp_solve2 = lambda mat: jnp_solve1(mat, vec)\n            elif mode == \"vec\":\n                linear_solve2 = lambda vector: linear_solve1(operator, vector)\n                jnp_solve2 = lambda vector: jnp_solve1(matrix, vector)\n            elif mode == \"op_vec\":\n          
      linear_solve2 = linear_solve1\n                jnp_solve2 = jnp_solve1\n            else:\n                assert False\n            for jvp_first in (True, False):\n                if jvp_first:\n                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve2)\n                else:\n                    linear_solve3 = linear_solve2\n                linear_solve3 = eqx.filter_vmap(linear_solve3)\n                if not jvp_first:\n                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3)\n                linear_solve3 = eqx.filter_jit(linear_solve3)\n                jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve2)\n                jnp_solve3 = eqx.filter_vmap(jnp_solve3)\n                jnp_solve3 = eqx.filter_jit(jnp_solve3)\n                if mode == \"op\":\n                    out, t_out = linear_solve3((operator,), (t_operator,))\n                    true_out, true_t_out = jnp_solve3((matrix,), (t_matrix,))\n                elif mode == \"vec\":\n                    out, t_out = linear_solve3((vec,), (t_vec,))\n                    true_out, true_t_out = jnp_solve3((vec,), (t_vec,))\n                elif mode == \"op_vec\":\n                    out, t_out = linear_solve3((operator, vec), (t_operator, t_vec))\n                    true_out, true_t_out = jnp_solve3((matrix, vec), (t_matrix, t_vec))\n                else:\n                    assert False\n                assert tree_allclose(out.value, true_out, atol=1e-4)\n                assert tree_allclose(t_out.value, true_t_out, atol=1e-4)\n"
  },
  {
    "path": "tests/test_vmap_vmap.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools as ft\n\nimport equinox as eqx\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    construct_singular_matrix,\n    make_jac_operator,\n    make_matrix_operator,\n    solvers_tags_pseudoinverse,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator\", (make_matrix_operator, make_jac_operator))\n@pytest.mark.parametrize(\"solver, tags, pseudoinverse\", solvers_tags_pseudoinverse)\n@pytest.mark.parametrize(\"use_state\", (True, False))\n@pytest.mark.parametrize(\n    \"make_matrix\",\n    (\n        construct_matrix,\n        construct_singular_matrix,\n    ),\n)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_vmap_vmap(\n    getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype\n):\n    if (make_matrix is construct_matrix) or pseudoinverse:\n        # combinations with nontrivial application across both vmaps\n        axes = [\n            (eqx.if_array(0), eqx.if_array(0), None, None),\n            (None, None, 0, 0),\n            (eqx.if_array(0), eqx.if_array(0), None, 0),\n            (eqx.if_array(0), eqx.if_array(0), 0, 0),\n            (None, eqx.if_array(0), 0, 0),\n        ]\n\n        for vmap2_op, vmap1_op, vmap2_vec, vmap1_vec in axes:\n            if vmap1_op is not None:\n         
       axis_size1 = 10\n                out_axis1 = eqx.if_array(0)\n            else:\n                axis_size1 = None\n                out_axis1 = None\n\n            if vmap2_op is not None:\n                axis_size2 = 10\n                out_axis2 = eqx.if_array(0)\n            else:\n                axis_size2 = None\n                out_axis2 = None\n\n            (matrix,) = eqx.filter_vmap(\n                eqx.filter_vmap(\n                    lambda getkey, solver, tags: make_matrix(\n                        getkey, solver, tags, dtype=dtype\n                    ),\n                    axis_size=axis_size1,\n                    out_axes=out_axis1,\n                ),\n                axis_size=axis_size2,\n                out_axes=out_axis2,\n            )(getkey, solver, tags)\n\n            if vmap1_op is not None:\n                if vmap2_op is not None:\n                    _, _, out_size, _ = matrix.shape\n                else:\n                    _, out_size, _ = matrix.shape\n            else:\n                out_size, _ = matrix.shape\n\n            if vmap1_vec is None:\n                vec = jr.normal(getkey(), (out_size,), dtype=dtype)\n            elif (vmap1_vec is not None) and (vmap2_vec is None):\n                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)\n            else:\n                vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype)\n\n            make_op = ft.partial(make_operator, getkey)\n            operator = eqx.filter_vmap(\n                eqx.filter_vmap(\n                    make_op,\n                    in_axes=vmap1_op,\n                    out_axes=out_axis1,\n                ),\n                in_axes=vmap2_op,\n                out_axes=out_axis2,\n            )(matrix, tags)\n\n            if use_state:\n\n                def linear_solve(operator, vector):\n                    state = solver.init(operator, options={})\n                    return lx.linear_solve(operator, vector, 
state=state, solver=solver)\n\n            else:\n\n                def linear_solve(operator, vector):\n                    return lx.linear_solve(operator, vector, solver)\n\n            as_matrix_vmapped = eqx.filter_vmap(\n                eqx.filter_vmap(\n                    lambda x: x.as_matrix(),\n                    in_axes=vmap1_op,\n                    out_axes=None if vmap1_op is None else 0,\n                ),\n                in_axes=vmap2_op,\n                out_axes=None if vmap2_op is None else 0,\n            )(operator)\n\n            vmap1_axes = (vmap1_op, vmap1_vec)\n            vmap2_axes = (vmap2_op, vmap2_vec)\n\n            result = eqx.filter_vmap(\n                eqx.filter_vmap(linear_solve, in_axes=vmap1_axes), in_axes=vmap2_axes\n            )(operator, vec).value\n\n            solve_with = lambda x: eqx.filter_vmap(\n                eqx.filter_vmap(x, in_axes=vmap1_axes), in_axes=vmap2_axes\n            )(as_matrix_vmapped, vec)\n\n            if make_matrix is construct_singular_matrix:\n                true_result, _, _, _ = solve_with(jnp.linalg.lstsq)  # pyright: ignore\n            else:\n                true_result = solve_with(jnp.linalg.solve)  # pyright: ignore\n            assert tree_allclose(result, true_result, rtol=1e-3)\n"
  },
  {
    "path": "tests/test_well_posed.py",
    "content": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\n\nfrom .helpers import (\n    construct_matrix,\n    make_jacrev_operator,\n    ops,\n    params,\n    solvers,\n    tree_allclose,\n)\n\n\n@pytest.mark.parametrize(\"make_operator,solver,tags\", params(only_pseudo=False))\n@pytest.mark.parametrize(\"ops\", ops)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):\n    if make_operator is make_jacrev_operator and dtype is jnp.complex128:\n        # JacobianLinearOperator does not support complex dtypes when jac=\"bwd\"\n        return\n    if jax.config.jax_enable_x64:  # pyright: ignore\n        tol = 1e-10\n    else:\n        tol = 1e-4\n    (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)\n    operator = make_operator(getkey, matrix, tags)\n    operator, matrix = ops(operator, matrix)\n    assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)\n    out_size, _ = matrix.shape\n    true_x = jr.normal(getkey(), (out_size,), dtype=dtype)\n    b = matrix @ true_x\n    x = lx.linear_solve(operator, b, solver=solver).value\n    jax_x = jnp.linalg.solve(matrix, b)  # pyright: ignore\n    assert tree_allclose(x, true_x, atol=tol, rtol=tol)\n    assert tree_allclose(x, jax_x, atol=tol, 
rtol=tol)\n\n\n@pytest.mark.parametrize(\"solver\", solvers)\n@pytest.mark.parametrize(\"dtype\", (jnp.float64, jnp.complex128))\ndef test_pytree_wellposed(solver, getkey, dtype):\n    if not isinstance(\n        solver,\n        (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG),\n    ):\n        if jax.config.jax_enable_x64:  # pyright: ignore\n            tol = 1e-10\n        else:\n            tol = 1e-4\n\n        true_x = [\n            jr.normal(getkey(), shape=(2, 4), dtype=dtype),\n            jr.normal(getkey(), (3,), dtype=dtype),\n        ]\n        pytree = [\n            [\n                jr.normal(getkey(), shape=(2, 4, 2, 4), dtype=dtype),\n                jr.normal(getkey(), shape=(2, 4, 3), dtype=dtype),\n            ],\n            [\n                jr.normal(getkey(), shape=(3, 2, 4), dtype=dtype),\n                jr.normal(getkey(), shape=(3, 3), dtype=dtype),\n            ],\n        ]\n        out_structure = jax.eval_shape(lambda: true_x)\n\n        operator = lx.PyTreeLinearOperator(pytree, out_structure)\n        b = operator.mv(true_x)\n        lx_x = lx.linear_solve(operator, b, solver, throw=False)\n        assert tree_allclose(lx_x.value, true_x, atol=tol, rtol=tol)\n"
  }
]