Repository: patrick-kidger/lineax Branch: main Commit: 2bf9824f7df9 Files: 75 Total size: 416.0 KB Directory structure: gitextract_at99w4wk/ ├── .github/ │ └── workflows/ │ ├── build_docs.yml │ ├── release.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks/ │ ├── gmres_fails_safely.py │ ├── lstsq_gradients.py │ └── solver_speeds.py ├── docs/ │ ├── .htaccess │ ├── _overrides/ │ │ └── partials/ │ │ └── source.html │ ├── _static/ │ │ ├── custom_css.css │ │ └── mathjax.js │ ├── api/ │ │ ├── functions.md │ │ ├── linear_solve.md │ │ ├── operators.md │ │ ├── solution.md │ │ ├── solvers.md │ │ └── tags.md │ ├── examples/ │ │ ├── classical_solve.ipynb │ │ ├── complex_solve.ipynb │ │ ├── least_squares.ipynb │ │ ├── no_materialisation.ipynb │ │ ├── operators.ipynb │ │ └── structured_matrices.ipynb │ ├── faq.md │ └── index.md ├── lineax/ │ ├── __init__.py │ ├── _custom_types.py │ ├── _misc.py │ ├── _norm.py │ ├── _operator.py │ ├── _solution.py │ ├── _solve.py │ ├── _solver/ │ │ ├── __init__.py │ │ ├── bicgstab.py │ │ ├── cg.py │ │ ├── cholesky.py │ │ ├── diagonal.py │ │ ├── gmres.py │ │ ├── lsmr.py │ │ ├── lu.py │ │ ├── misc.py │ │ ├── normal.py │ │ ├── qr.py │ │ ├── svd.py │ │ ├── triangular.py │ │ └── tridiagonal.py │ ├── _tags.py │ └── internal/ │ └── __init__.py ├── mkdocs.yml ├── pyproject.toml └── tests/ ├── README.md ├── __init__.py ├── __main__.py ├── conftest.py ├── helpers.py ├── test_adjoint.py ├── test_invert.py ├── test_jvp.py ├── test_jvp_jvp1.py ├── test_jvp_jvp2.py ├── test_lsmr.py ├── test_misc.py ├── test_norm.py ├── test_operator.py ├── test_singular.py ├── test_solve.py ├── test_transpose.py ├── test_vmap.py ├── test_vmap_jvp.py ├── test_vmap_vmap.py └── test_well_posed.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/build_docs.yml ================================================ name: Build docs on: push: branches: - main jobs: build: strategy: matrix: python-version: [ 3.11 ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v2 - name: Install the latest version of uv uses: astral-sh/setup-uv@v7 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | uv run echo done - name: Build docs run: | uv run mkdocs build - name: Upload docs uses: actions/upload-artifact@v4 with: name: docs path: site # where `mkdocs build` puts the built site ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: branches: - main jobs: build: runs-on: ubuntu-latest steps: - name: Release uses: patrick-kidger/action_update_python_project@v8 with: python-version: "3.11" # Uninstall and reinstall pytest to work around the fact that it doesn't get put into `bin` otherwise. test-script: | cp -r ${{ github.workspace }}/tests ./tests cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml uv pip uninstall pytest uv sync --no-install-project --inexact uv run --no-sync pytest pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger github-token: ${{ github.token }} ================================================ FILE: .github/workflows/run_tests.yml ================================================ name: Run tests on: pull_request: jobs: run-test: strategy: matrix: python-version: [ 3.11 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v2 - name: Install the latest version of uv uses: astral-sh/setup-uv@v7 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | uv run echo done - name: Checks with pre-commit run: | uv run prek run --all-files - name: Test with pytest run: | uv run python -m tests - name: Check that documentation can be built. run: | uv run mkdocs build ================================================ FILE: .gitignore ================================================ **/__pycache__ **/.ipynb_checkpoints *.egg-info/ build/ dist/ site/ examples/data .all_objects.cache .pymon .idea .venv uv.lock ================================================ FILE: .pre-commit-config.yaml ================================================ fail_fast: true repos: - repo: meta hooks: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: trailing-whitespace exclude: \.md$ - id: check-toml - id: mixed-line-ending - repo: local hooks: - id: sort-pyproject name: sort pyproject files: ^pyproject\.toml$ language: system entry: uv run -- toml-sort -i --sort-table-keys --sort-inline-tables - id: ruff-format name: ruff format types_or: [python, pyi, jupyter, toml] language: system entry: uv run -- ruff format -- require_serial: true - id: ruff-lint name: ruff lint types_or: [python, pyi, jupyter, toml] language: system entry: uv run -- ruff check --fix -- require_serial: true - id: pyright name: pyright types_or: [python] language: system entry: uv run -- pyright require_serial: true ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Contributions (pull requests) are very welcome! Here's how to get started. --- ### Getting started [We assume that you have `uv` installed.](https://docs.astral.sh/uv/) Now fork the library on GitHub. Then clone and install the library: ```bash git clone https://github.com/your-username-here/lineax.git cd lineax uv run prek install # Creates a local venv + installs dependencies + installs pre-commit hooks. ``` --- ### If you're making changes to the code Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass: ```bash uv run python -m tests ``` Then push your changes back to your fork of the repository: ```bash git push ``` Finally, open a pull request on GitHub! --- ### If you're making changes to the documentation Make your changes. You can then build the documentation by doing ```bash uv run mkdocs serve ``` You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Lineax

Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.) Features include: - PyTree-valued matrices and vectors; - General linear operators for Jacobians, transposes, etc.; - Efficient linear least squares (e.g. QR solvers); - Numerically stable gradients through linear least squares; - Support for structured (e.g. symmetric) matrices; - Improved compilation times; - Improved runtime of some algorithms; - Support for both real-valued and complex-valued inputs; - All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc. ## Installation ```bash pip install lineax ``` Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+. ## Documentation Available at [https://docs.kidger.site/lineax](https://docs.kidger.site/lineax). ## Quick examples Lineax can solve a least squares problem with an explicit matrix operator: ```python import jax.random as jr import lineax as lx matrix_key, vector_key = jr.split(jr.PRNGKey(0)) matrix = jr.normal(matrix_key, (10, 8)) vector = jr.normal(vector_key, (10,)) operator = lx.MatrixLinearOperator(matrix) solution = lx.linear_solve(operator, vector, solver=lx.QR()) ``` or Lineax can solve a problem without ever materializing a matrix, as done in this quadratic solve: ```python import jax import lineax as lx key = jax.random.PRNGKey(0) y = jax.random.normal(key, (10,)) def quadratic_fn(y, args): return jax.numpy.sum((y - 1)**2) gradient_fn = jax.grad(quadratic_fn) hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag) solver = lx.CG(rtol=1e-6, atol=1e-6) out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver) minimum = y - out.value ``` ## Citation If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2311.17283)) ```bibtex @article{lineax2023, title={Lineax: unified linear solves and linear least-squares in JAX and Equinox}, author={Jason Rader and Terry Lyons and Patrick Kidger}, journal={ AI for science workshop at Neural Information Processing Systems 2023, arXiv:2311.17283 }, year={2023}, } ``` (Also consider starring the project on GitHub.) ## See also: other libraries in the JAX ecosystem **Always useful** [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. **Deep learning** [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. **Scientific computing** [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) **Awesome JAX** [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. ================================================ FILE: benchmarks/gmres_fails_safely.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import equinox as eqx import equinox.internal as eqxi import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp import lineax as lx getkey = eqxi.GetKey() def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8): return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) jax.config.update("jax_enable_x64", True) def make_problem(mat_size: int, *, key): mat = jr.normal(key, (mat_size, mat_size)) true_x = jr.normal(key, (mat_size,)) b = mat @ true_x op = lx.MatrixLinearOperator(mat) return mat, op, b, true_x def benchmark_jax(mat_size: int, *, key): mat, _, b, true_x = make_problem(mat_size, key=key) solve_with_jax = ft.partial( jsp.sparse.linalg.gmres, tol=1e-5, solve_method="batched" ) gmres_jit = jax.jit(solve_with_jax) jax_soln, info = gmres_jit(mat, b) # info == 0.0 implies that the solve has succeeded. returned_failed = jnp.all(info != 0.0) actually_failed = not tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4) assert actually_failed captured_failure = returned_failed & actually_failed return captured_failure def benchmark_lx(mat_size: int, *, key): _, op, b, true_x = make_problem(mat_size, key=key) lx_soln = lx.linear_solve(op, b, lx.GMRES(atol=1e-5, rtol=1e-5), throw=False) returned_failed = jnp.all(lx_soln.result != lx.RESULTS.successful) actually_failed = not tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4) assert actually_failed captured_failure = returned_failed & actually_failed return captured_failure lx_failed_safely = 0 jax_failed_safely = 0 for _ in range(100): key = getkey() jax_captured_failure = benchmark_jax(100, key=key) lx_captured_failure = benchmark_lx(100, key=key) jax_failed_safely = jax_failed_safely + jax_captured_failure lx_failed_safely = lx_failed_safely + lx_captured_failure print(f"JAX failed safely {jax_failed_safely} out of 100 times") print(f"Lineax failed safely {lx_failed_safely} out of 100 times") ================================================ FILE: benchmarks/lstsq_gradients.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Core JAX has some numerical issues with their lstsq gradients. # See https://github.com/google/jax/issues/14868 # This demonstrates that we don't have the same issue! import sys import jax import jax.numpy as jnp import lineax as lx sys.path.append("../tests") from helpers import finite_difference_jvp # pyright: ignore a_primal = (jnp.eye(3),) a_tangent = (jnp.zeros((3, 3)),) def jax_solve(a): sol, _, _, _ = jnp.linalg.lstsq(a, jnp.arange(3)) # pyright: ignore return sol def lx_solve(a): op = lx.MatrixLinearOperator(a) return lx.linear_solve(op, jnp.arange(3)).value _, true_jvp = finite_difference_jvp(jax_solve, a_primal, a_tangent) _, jax_jvp = jax.jvp(jax_solve, a_primal, a_tangent) _, lx_jvp = jax.jvp(lx_solve, a_primal, a_tangent) assert jnp.isnan(jax_jvp).all() assert jnp.allclose(true_jvp, lx_jvp) ================================================ FILE: benchmarks/solver_speeds.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import sys import timeit import equinox as eqx import equinox.internal as eqxi import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp import lineax as lx sys.path.append("../tests") from helpers import construct_matrix, has_tag # pyright: ignore[reportMissingImports] getkey = eqxi.GetKey() def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8): return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) jax.config.update("jax_enable_x64", True) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-12 else: tol = 1e-6 def base_wrapper(a, b, solver): op = lx.MatrixLinearOperator( a, ( lx.positive_semidefinite_tag, lx.symmetric_tag, lx.diagonal_tag, lx.tridiagonal_tag, ), ) out = lx.linear_solve(op, b, solver, throw=False) return out.value def jax_svd(a, b): out, _, _, _ = jnp.linalg.lstsq(a, b) # pyright: ignore return out def jax_gmres(a, b): out, _ = jsp.sparse.linalg.gmres(a, b, tol=tol) return out def jax_bicgstab(a, b): out, _ = jsp.sparse.linalg.bicgstab(a, b, tol=tol) return out def jax_cg(a, b): out, _ = jsp.sparse.linalg.cg(a, b, tol=tol) return out def jax_lu(matrix, vector): return jsp.linalg.lu_solve(jsp.linalg.lu_factor(matrix), vector) def jax_cholesky(matrix, vector): return jsp.linalg.cho_solve(jsp.linalg.cho_factor(matrix), vector) def jax_tridiagonal(matrix, vector): dl = jnp.append(0.0, matrix.diagonal(-1)) d = matrix.diagonal(0) du = jnp.append(matrix.diagonal(1), 0.0) return jax.lax.linalg.tridiagonal_solve(dl, d, du, vector[:, None])[:, 0] named_solvers = [ ("LU", "LU", lx.LU(), jax_lu, ()), ("QR", "SVD", lx.QR(), jax_svd, ()), ("SVD", "SVD", lx.SVD(), jax_svd, ()), ( "Cholesky", "Cholesky", lx.Cholesky(), jax_cholesky, lx.positive_semidefinite_tag, ), ("Diagonal", "None", lx.Diagonal(), None, lx.diagonal_tag), ( "Tridiagonal", "Tridiagonal", lx.Tridiagonal(), jax_tridiagonal, lx.tridiagonal_tag, ), ( "CG", "CG", lx.CG(atol=tol, rtol=tol, stabilise_every=None), jax_cg, lx.positive_semidefinite_tag, ), ( "GMRES", "GMRES", lx.GMRES(atol=1, rtol=1), jax_gmres, (), ), ( "BiCGStab", "BiCGStab", lx.BiCGStab(atol=tol, rtol=tol), jax_bicgstab, (), ), ] def create_problem(solver, tags, size=3): (matrix,) = construct_matrix(getkey, solver, tags, size=size) true_x = jr.normal(getkey(), (size,)) b = matrix @ true_x return matrix, true_x, b def create_easy_iterative_problem(size, tags): matrix = jr.normal(getkey(), (size, size)) / size + 2 * jnp.eye(size) true_x = jr.normal(getkey(), (size,)) if has_tag(tags, lx.positive_semidefinite_tag): matrix = matrix.T @ matrix b = matrix @ true_x return matrix, true_x, b def test_solvers(vmap_size, mat_size): for lx_name, jax_name, _lx_solver, jax_solver, tags in named_solvers: lx_solver = ft.partial(base_wrapper, solver=_lx_solver) if vmap_size == 1: if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)): matrix, true_x, b = create_easy_iterative_problem(mat_size, tags) else: matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size) else: if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)): matrix, true_x, b = eqx.filter_vmap( create_easy_iterative_problem, axis_size=vmap_size, out_axes=eqx.if_array(0), )(mat_size, tags) else: matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size) _create_problem = ft.partial(create_problem, size=mat_size) matrix, true_x, b = eqx.filter_vmap( _create_problem, axis_size=vmap_size, out_axes=eqx.if_array(0) )(lx_solver, tags) lx_solver = jax.vmap(lx_solver) if jax_solver is not None: jax_solver = jax.vmap(jax_solver) lx_solver = jax.jit(lx_solver) bench_lx = ft.partial(lx_solver, matrix, b) if vmap_size == 1: batch_msg = "problem" else: batch_msg = f"batch of {vmap_size} problems" lx_soln = bench_lx() if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4): lx_solve_time = timeit.timeit(bench_lx, number=1) print( f"Lineax's {lx_name} solved {batch_msg} of " f"size {mat_size} in {lx_solve_time} seconds." ) else: fail_time = timeit.timeit(bench_lx, number=1) err = jnp.abs(lx_soln - true_x).max() print( f"Lineax's {lx_name} failed to solve {batch_msg} of " f"size {mat_size} with error {err} in {fail_time} seconds" ) if jax_solver is None: print("JAX has no equivalent solver. \n") else: jax_solver = jax.jit(jax_solver) bench_jax = ft.partial(jax_solver, matrix, b) jax_soln = bench_jax() if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4): jax_solve_time = timeit.timeit(bench_jax, number=1) print( f"JAX's {jax_name} solved {batch_msg} of " f"size {mat_size} in {jax_solve_time} seconds. \n" ) else: fail_time = timeit.timeit(bench_jax, number=1) err = jnp.abs(jax_soln - true_x).max() print( f"JAX's {jax_name} failed to solve {batch_msg} of " f"size {mat_size} with error {err} in {fail_time} seconds. \n" ) for vmap_size, mat_size in [(1, 50), (1000, 50)]: test_solvers(vmap_size, mat_size) ================================================ FILE: docs/.htaccess ================================================ ErrorDocument 404 /jaxtyping/404.html ================================================ FILE: docs/_overrides/partials/source.html ================================================ {% import "partials/language.html" as lang with context %}
{% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %} {% include ".icons/" ~ icon ~ ".svg" %}
{{ config.repo_name }}
{% include ".icons/fontawesome/brands/twitter.svg" %}
{% include "bluesky.svg" %}
{{ config.theme.twitter_bluesky_name }}
================================================ FILE: docs/_static/custom_css.css ================================================ /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ html { scroll-padding-top: 50px; } /* Fit the Twitter handle alongside the GitHub one in the top right. */ div.md-header__source { width: revert; max-width: revert; } a.md-source { display: inline-block; } .md-source__repository { max-width: 100%; } /* Emphasise sections of nav on left hand side */ nav.md-nav { padding-left: 5px; } nav.md-nav--secondary { border-left: revert !important; } .md-nav__title { font-size: 0.9rem; } .md-nav__item--section > .md-nav__link { font-size: 0.9rem; } /* Indent autogenerated documentation */ div.doc-contents { padding-left: 25px; border-left: 4px solid rgba(230, 230, 230); } /* Increase visibility of splitters "---" */ [data-md-color-scheme="default"] .md-typeset hr { border-bottom-color: rgb(0, 0, 0); border-bottom-width: 1pt; } [data-md-color-scheme="slate"] .md-typeset hr { border-bottom-color: rgb(230, 230, 230); } /* More space at the bottom of the page */ .md-main__inner { margin-bottom: 1.5rem; } /* Remove prev/next footer buttons */ .md-footer__inner { display: none; } /* Change font sizes */ html { /* Decrease font size for overall webpage Down from 137.5% which is the Material default */ font-size: 110%; } .md-typeset .admonition { /* Increase font size in admonitions */ font-size: 100% !important; } .md-typeset details { /* Increase font size in details */ font-size: 100% !important; } .md-typeset h1 { font-size: 1.6rem; } .md-typeset h2 { font-size: 1.5rem; } .md-typeset h3 { font-size: 1.3rem; } .md-typeset h4 { font-size: 1.1rem; } .md-typeset h5 { font-size: 0.9rem; } .md-typeset h6 { font-size: 0.8rem; } /* Bugfix: remove the superfluous parts generated when doing: ??? Blah ::: library.something */ .md-typeset details .mkdocstrings > h4 { display: none; } .md-typeset details .mkdocstrings > h5 { display: none; } /* Change default colours for tags */ [data-md-color-scheme="default"] { --md-typeset-a-color: rgb(0, 189, 164) !important; } [data-md-color-scheme="slate"] { --md-typeset-a-color: rgb(0, 189, 164) !important; } /* Highlight functions, classes etc. type signatures. Really helps to make clear where one item ends and another begins. */ [data-md-color-scheme="default"] { --doc-heading-color: #DDD; --doc-heading-border-color: #CCC; --doc-heading-color-alt: #F0F0F0; } [data-md-color-scheme="slate"] { --doc-heading-color: rgb(25,25,33); --doc-heading-border-color: rgb(25,25,33); --doc-heading-color-alt: rgb(33,33,44); --md-code-bg-color: rgb(38,38,50); } h4.doc-heading { /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ background-color: var(--doc-heading-color); border: solid var(--doc-heading-border-color); border-width: 1.5pt; border-radius: 2pt; padding: 0pt 5pt 2pt 5pt; } h5.doc-heading, h6.heading { background-color: var(--doc-heading-color-alt); border-radius: 2pt; padding: 0pt 5pt 2pt 5pt; } ================================================ FILE: docs/_static/mathjax.js ================================================ window.MathJax = { tex: { inlineMath: [["\\(", "\\)"]], displayMath: [["\\[", "\\]"]], processEscapes: true, processEnvironments: true }, options: { ignoreHtmlClass: ".*|", processHtmlClass: "arithmatex" } }; document$.subscribe(() => { MathJax.typesetPromise() }) ================================================ FILE: docs/api/functions.md ================================================ # Functions on linear operators We define a number of functions on [linear operators](./operators.md). ## Computational changes These do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.) ::: lineax.linearise --- ::: lineax.materialise ## Extract information from the operator ::: lineax.diagonal --- ::: lineax.tridiagonal ## Test the operator to see if it exhibits a certain property Note that these do *not* inspect the values of the operator -- instead, they use typically use [tags](./tags.md). (Or in some cases, just the type of the operator: e.g. `is_diagonal(DiagonalLinearOperator(...)) == True`.) ::: lineax.has_unit_diagonal --- ::: lineax.is_diagonal --- ::: lineax.is_tridiagonal --- ::: lineax.is_lower_triangular --- ::: lineax.is_upper_triangular --- ::: lineax.is_positive_semidefinite --- ::: lineax.is_negative_semidefinite --- ::: lineax.is_symmetric ================================================ FILE: docs/api/linear_solve.md ================================================ # linear_solve This is the main entry point. ::: lineax.linear_solve ## invert A convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][]. ::: lineax.invert ================================================ FILE: docs/api/operators.md ================================================ # Linear operators We often talk about solving a linear system $Ax = b$, where $A \in \mathbb{R}^{n \times m}$ is a matrix, $b \in \mathbb{R}^n$ is a vector, and $x \in \mathbb{R}^m$ is our desired solution. The linear operators described on this page are ways of describing the matrix $A$. The simplest is [`lineax.MatrixLinearOperator`][], which simply holds the matrix $A$ directly. Meanwhile if $A$ is diagonal, then there is also [`lineax.DiagonalLinearOperator`][]: for efficiency this only stores the diagonal of $A$. Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such that $F(x) = Ax$. Whilst we could use $F$ to materialise the whole matrix $A$ and then store it in a [`lineax.MatrixLinearOperator`][], that may be very memory intensive. Instead, we may prefer to use [`lineax.FunctionLinearOperator`][]. Many linear solvers (e.g. [`lineax.CG`][]) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $A$. ??? abstract "`lineax.AbstractLinearOperator`" ::: lineax.AbstractLinearOperator options: members: - mv - as_matrix - transpose - in_structure - out_structure - in_size - out_size ::: lineax.MatrixLinearOperator options: members: - __init__ --- ::: lineax.DiagonalLinearOperator options: members: - __init__ --- ::: lineax.TridiagonalLinearOperator options: members: - __init__ --- ::: lineax.PyTreeLinearOperator options: members: - __init__ --- ::: lineax.JacobianLinearOperator options: members: - __init__ --- ::: lineax.FunctionLinearOperator options: members: - __init__ --- ::: lineax.IdentityLinearOperator options: members: - __init__ --- ::: lineax.TaggedLinearOperator options: members: - __init__ ================================================ FILE: docs/api/solution.md ================================================ # Solution ::: lineax.Solution options: members: [] --- ::: lineax.RESULTS options: members: [] ================================================ FILE: docs/api/solvers.md ================================================ # Solvers If you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the [tags](./tags.md) page.) ??? abstract "`lineax.AbstractLinearSolver`" ::: lineax.AbstractLinearSolver options: members: - init - compute - transpose - conj - assume_full_rank ::: lineax.AutoLinearSolver options: members: - __init__ - select_solver --- ::: lineax.LU options: members: - __init__ ## Least squares solvers These are capable of solving ill-posed linear problems. ::: lineax.QR options: members: - __init__ --- ::: lineax.SVD options: members: - __init__ --- ::: lineax.Normal options: members: - __init__ --- ::: lineax.LSMR options: members: - __init__ #### Diagonal In addition to these, [`lineax.Diagonal`][] with `well_posed=False` (below) also supports ill-posed problems. ## Iterative solvers These solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. [`lineax.JacobianLinearOperator`][] or [`lineax.FunctionLinearOperator`][], which only provide matrix-vector products. !!! warning Note that [`lineax.BiCGStab`][] and [`lineax.GMRES`][] may fail to converge on some (typically non-sparse) problems. ::: lineax.CG options: members: - __init__ --- ::: lineax.BiCGStab options: members: - __init__ --- ::: lineax.GMRES options: members: - __init__ #### LSMR In addition to these, [`lineax.LSMR`][] (above) is also an iterative method. ## Structure-exploiting solvers These require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently. ::: lineax.Cholesky options: members: - __init__ --- ::: lineax.Diagonal options: members: - __init__ --- ::: lineax.Triangular options: members: - __init__ --- ::: lineax.Tridiagonal options: members: - __init__ #### CG In addition to these, [`lineax.CG`][] also requires special structure (positive or negative definiteness). ================================================ FILE: docs/api/tags.md ================================================ # Tags Lineax offers a way to "tag" linear operators as exhibiting certain properties, e.g. that they are positive semidefinite. If a linear operator is known to have a particular property, then this can be used to dispatch to a more efficient implementation, e.g. when solving a linear system. Generally speaking, tags are an *optional* tool that can be used to improve your run time and/or compile time, by statically telling the linear solvers what properties they may assume about your system. However, if misused then you may find that the wrong result is silently returned. In this way they are analogous to flags like `scipy.linalg.solve(..., assume_a="pos")`. !!! Example ```python # Some rank-2 JAX array. matrix = ... # Some rank-1 JAX array. vector = ... # Declare that this matrix is positive semidefinite. operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) # This tag is used to dispatch to a maximally-efficient linear solver. # In this case, a Cholesky solver is used: solution = lx.linear_solve(operator, vector) # Whether operators are tagged can be checked: assert lx.is_positive_semidefinite(operator) ``` !!! Warning Be careful, only the tag is actually checked, not the actual value of the matrix: ```python # Not a positive semidefinite matrix matrix = jax.numpy.array([[1, 2], [3, 4]]) operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) lx.is_positive_semidefinite(operator) # True lx.linear_solve(operator, vector) # Returns the wrong solution! ``` Of the built-in operators: [`lineax.MatrixLinearOperator`][], [`lineax.PyTreeLinearOperator`][], [`lineax.JacobianLinearOperator`][], [`lineax.FunctionLinearOperator`][], [`lineax.TaggedLinearOperator`][] directly support a `tags` argument that mark them as having certain characteristics: ```python operator = lx.MatrixLinearOperator(matrix, lx.symmetric_tag) ``` You can pass multiple tags at once: ```python operator = lx.MatrixLinearOperator(matrix, (lx.symmetric_tag, lx.unit_diagonal_tag)) ``` Other linear operators can be wrapped into a [`lineax.TaggedLinearOperator`][] if necessary: ```python operator = lx.MatrixLinearOperator(...) symmetric_operator = operator + operator.T lx.is_symmetric(symmetric_operator) # False symmetric_operator = lx.TaggedLinearOperator(symmetric_operator, lx.symmetric_tag) lx.is_symmetric(symmetric_operator) # True ``` Some linear operators are known to exhibit certain properties by construction, and need no additional tags: ```python lx.is_symmetric(lx.DiagonalLinearOperator(...)) # True lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True ``` ## List of available tags ::: lineax.symmetric_tag Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) --- ::: lineax.diagonal_tag Marks than an operator is diagonal. (As a matrix, it must have zeros in the off-diagonal entries.) For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Diagonal`][] as the solver. --- ::: lineax.unit_diagonal_tag Marks than an operator has $1$ for every diagonal element. (As a matrix $A$, then it must have $A_{ii} = 1$ for all $i$.) Note that the whole matrix need not be diagonal. For example, [`lineax.Triangular`][] uses this to cheapen its solve. --- ::: lineax.lower_triangular_tag Marks that an operator is lower triangular. (As a matrix $A$, then it must have $A_{ij} = 0 for all $i < j$.) Note that the diagonal may still have nonzero entries. For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver. --- ::: lineax.upper_triangular_tag Marks that an operator is upper triangular. (As a matrix $A$, then it must have $A_{ij} = 0 for all $i > j$.) Note that the diagonal may still have nonzero entries. For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver. --- ::: lineax.positive_semidefinite_tag Marks than operator is positive **semidefinite**. For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver. --- ::: lineax.negative_semidefinite_tag Marks than operator is negative **semidefinite**. For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver. ================================================ FILE: docs/examples/classical_solve.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1", "metadata": {}, "source": [ "# Classical solve\n", "\n", "We wish to solve the linear system $Ax = b$. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory." ] }, { "cell_type": "code", "execution_count": 1, "id": "cb3a7781-2358-40c4-82f3-e908bddeb578", "metadata": { "tags": [], "ExecuteTime": { "end_time": "2024-04-02T05:26:05.556701Z", "start_time": "2024-04-02T05:26:03.814599Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A=\n", "[[-0.3721109 0.26423115 -0.18252768]\n", " [-0.7368197 0.44973662 -0.1521442 ]\n", " [-0.67135346 -0.5908641 0.73168886]]\n", "b=[ 0.17269018 -0.64765567 1.2229712 ]\n", "x=[-2.7321298 -8.52878 -7.7226872]\n" ] } ], "source": [ "import jax.random as jr\n", "import lineax as lx\n", "\n", "\n", "matrix = jr.normal(jr.PRNGKey(0), (3, 3))\n", "vector = jr.normal(jr.PRNGKey(1), (3,))\n", "operator = lx.MatrixLinearOperator(matrix)\n", "solution = lx.linear_solve(operator, vector)\n", "print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/examples/complex_solve.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1", "metadata": {}, "source": [ "# Complex solve\n", "\n", "We can also solve a system with complex entries. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory." ] }, { "cell_type": "code", "execution_count": 1, "id": "cb3a7781-2358-40c4-82f3-e908bddeb578", "metadata": { "tags": [], "ExecuteTime": { "end_time": "2024-04-02T05:29:04.909894Z", "start_time": "2024-04-02T05:29:04.103141Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A=\n", "[[-1.8459436 -0.2744466j 0.02393756-0.03172905j 0.76815367-1.4444253j ]\n", " [-1.0467293 +0.05608991j 1.0891742 -0.03264743j 0.7513123 +0.56285536j]\n", " [ 0.38307396-1.0190808j 0.01203694-1.1971304j 0.19252291-0.26424018j]]\n", "b=[0.23162952+0.3614433j 0.05800135+1.6094692j 0.8979094 +0.16941352j]\n", "x=[-0.07652722-0.34397143j -0.22629777+1.0359733j 0.22135164-0.00880566j]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "import jax.random as jr\n", "import lineax as lx\n", "\n", "\n", "matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\n", "vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\n", "operator = lx.MatrixLinearOperator(matrix)\n", "solution = lx.linear_solve(operator, vector)\n", "print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/examples/least_squares.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "44bff903-0e4d-4f3e-a75c-d3cfe8ab4dea", "metadata": {}, "source": [ "# Linear least squares\n", "\n", "The solution to a well-posed linear system $Ax = b$ is given by $x = A^{-1}b$. If the matrix is rectangular or not invertible, then we may generalise the notion of solution to $x = A^{\\dagger}b$, where $A^{\\dagger}$ denotes the [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse).\n", "\n", "Lineax can handle problems of this type too.\n", "\n", "!!! info\n", "\n", " For reference: in core JAX, problems of this type are handled using [`jax.numpy.linalg.lstsq`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html#jax.numpy.linalg.lstsq).\n", " \n", "---\n", "\n", "## Picking a solver\n", "\n", "By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:" ] }, { "cell_type": "code", "execution_count": 2, "id": "a956c3f2-a70c-472f-9fa9-3dbc16293e1d", "metadata": { "tags": [] }, "outputs": [ { "ename": "ValueError", "evalue": "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.", "output_type": "error", "traceback": [ "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.\n" ] } ], "source": [ "import jax.random as jr\n", "import lineax as lx\n", "\n", "\n", "vector = jr.normal(jr.PRNGKey(1), (3,))\n", "\n", "rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))\n", "rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)\n", "lx.linear_solve(rectangular_operator, vector)" ] }, { "cell_type": "markdown", "id": "ba55c0dd-b696-497a-8b13-896c3a95d5fd", "metadata": {}, "source": [ "Or it will happen at run time if using a rank-deficient matrix:" ] }, { "cell_type": "code", "execution_count": 3, "id": "f0e7ffe6-1e3d-46dc-9dbd-d5ed4c2dedf4", "metadata": { "tags": [] }, "outputs": [ { "ename": "XlaRuntimeError", "evalue": "INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n", "output_type": "error", "traceback": [ "\u001b[0;31mXlaRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n" ] } ], "source": [ "deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)\n", "deficient_operator = lx.MatrixLinearOperator(deficient_matrix)\n", "lx.linear_solve(deficient_operator, vector)" ] }, { "cell_type": "markdown", "id": "4b5cedab-75e5-4b52-88d9-b9d574be7e19", "metadata": {}, "source": [ "Whilst linear least squares and pseudoinverse are a strict generalisation of linear solves and inverses (respectively), Lineax will *not* attempt to handle the ill-posed case automatically. This is because the algorithms for handling this case are much more computationally expensive.!\n", "\n", "If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:" ] }, { "cell_type": "code", "execution_count": 4, "id": "45abc8bf-4fcf-46be-a91a-58f4e04ac10e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rectangular_solution: [-0.3214848 -0.75565964 -0.6034579 -0.01326615]\n" ] } ], "source": [ "rectangular_solution = lx.linear_solve(\n", " rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)\n", ")\n", "print(\"rectangular_solution: \", rectangular_solution.value)" ] }, { "cell_type": "markdown", "id": "86dc9e2f-fe2e-48c8-86ca-bc57f8137246", "metadata": {}, "source": [ "If your matrix may be either rectangular or rank-deficient, then you can set the solver to all this case like so:" ] }, { "cell_type": "code", "execution_count": 5, "id": "a9a2d92c-3676-471e-bb4a-5fd3b4748fd4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deficient_solution: [ 0.06046088 -1.0412765 0.8860444 ]\n" ] } ], "source": [ "deficient_solution = lx.linear_solve(\n", " deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)\n", ")\n", "print(\"deficient_solution: \", deficient_solution.value)" ] }, { "cell_type": "markdown", "id": "7b870311-de0f-434c-a9e7-2d8ebf9f0b38", "metadata": {}, "source": [ "Most users will want to use [`lineax.AutoLinearSolver`][], and not think about the details of which algorithm is selected.\n", "\n", "If you want to pick a particular algorithm, then that can be done too. [`lineax.QR`][] is capable of handling rectangular full-rank operators, and [`lineax.SVD`][] is capable of handling rank-deficient operators. (And in fact these are the algorithms that `AutoLinearSolver` is selecting in the examples above.)" ] }, { "cell_type": "markdown", "id": "c9649746-b0ef-495b-9ea1-eb5f6ca2e7e5", "metadata": {}, "source": [ "---\n", "\n", "## Differences from `jax.numpy.linalg.lstsq`?\n", "\n", "Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)\n", "\n", "### Speed (forward)\n", "\n", "First, in the rectangular case, then the QR algorithm is much faster than the SVD algorithm:" ] }, { "cell_type": "code", "execution_count": 25, "id": "d46d0c9a-47e4-439d-9beb-c9aaf47faa5d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX solution: [-0.10002219 0.09477127 -0.10846332 ... -0.08007179 -0.01216239\n", " -0.030862 ]\n", "Lineax solution: [-0.1000222 0.0947713 -0.10846333 ... -0.08007187 -0.01216241\n", " -0.03086199]\n", "\n", "JAX time: 0.011344402999384329\n", "Lineax time: 0.0028611960005946457\n" ] } ], "source": [ "import timeit\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "\n", "matrix = jr.normal(jr.PRNGKey(0), (500, 200))\n", "vector = jr.normal(jr.PRNGKey(1), (500,))\n", "\n", "\n", "@jax.jit\n", "def solve_jax(matrix, vector):\n", " out, *_ = jnp.linalg.lstsq(matrix, vector)\n", " return out\n", "\n", "\n", "@jax.jit\n", "def solve_lineax(matrix, vector):\n", " operator = lx.MatrixLinearOperator(matrix)\n", " solver = lx.QR() # or lx.AutoLinearSolver(well_posed=None)\n", " solution = lx.linear_solve(operator, vector, solver)\n", " return solution.value\n", "\n", "\n", "solution_jax = solve_jax(matrix, vector)\n", "solution_lineax = solve_lineax(matrix, vector)\n", "with np.printoptions(threshold=10):\n", " print(\"JAX solution:\", solution_jax)\n", " print(\"Lineax solution:\", solution_lineax)\n", "print()\n", "time_jax = timeit.repeat(lambda: solve_jax(matrix, vector), number=1, repeat=10)\n", "time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector), number=1, repeat=10)\n", "print(\"JAX time:\", min(time_jax))\n", "print(\"Lineax time:\", min(time_lineax))" ] }, { "cell_type": "markdown", "id": "397773d7-f782-45e6-9934-11c62c741380", "metadata": {}, "source": [ "### Speed (gradients)\n", "\n", "Lineax also uses a slightly more efficient autodifferentiation implementation, which ensures it is faster, even when both are using the SVD algorithm." ] }, { "cell_type": "code", "execution_count": 24, "id": "1988d0f6-86f5-401a-9615-30cccf04d129", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX gradients: [[-1.75446249e-03 2.00700224e-03 ... -3.16517282e-04 -6.08515576e-04]\n", " [ 1.81865180e-04 4.51280124e-04 ... -1.64618701e-04 -6.53692259e-05]\n", " ...\n", " [-7.27269216e-04 1.27710134e-03 ... -2.64510425e-04 -3.38940619e-04]\n", " [ 6.55723223e-03 -3.18011409e-03 ... -1.10758876e-04 1.43246143e-03]]\n", "Lineax gradients: [[-1.7544631e-03 2.0070139e-03 ... -3.1653541e-04 -6.0847402e-04]\n", " [ 1.8186278e-04 4.5128341e-04 ... -1.6459504e-04 -6.5359738e-05]\n", " ...\n", " [-7.2721508e-04 1.2771402e-03 ... -2.6450949e-04 -3.3894143e-04]\n", " [ 6.5572355e-03 -3.1801097e-03 ... -1.1071599e-04 1.4324478e-03]]\n", "\n", "JAX time: 0.016591553001489956\n", "Lineax time: 0.012212782999995397\n" ] } ], "source": [ "@jax.jit\n", "@jax.grad\n", "def grad_jax(matrix):\n", " out, *_ = jnp.linalg.lstsq(matrix, vector)\n", " return out.sum()\n", "\n", "\n", "@jax.jit\n", "@jax.grad\n", "def grad_lineax(matrix):\n", " operator = lx.MatrixLinearOperator(matrix)\n", " solution = lx.linear_solve(operator, vector, lx.SVD())\n", " return solution.value.sum()\n", "\n", "\n", "gradients_jax = grad_jax(matrix)\n", "gradients_lineax = grad_lineax(matrix)\n", "with np.printoptions(threshold=10, edgeitems=2):\n", " print(\"JAX gradients:\", gradients_jax)\n", " print(\"Lineax gradients:\", gradients_lineax)\n", "print()\n", "time_jax = timeit.repeat(lambda: grad_jax(matrix), number=1, repeat=10)\n", "time_lineax = timeit.repeat(lambda: grad_lineax(matrix), number=1, repeat=10)\n", "print(\"JAX time:\", min(time_jax))\n", "print(\"Lineax time:\", min(time_lineax))" ] }, { "cell_type": "markdown", "id": "81a1da5a-3474-4613-926f-5c9d9cdcb4a7", "metadata": {}, "source": [ "### Correctness (gradients)\n", "\n", "Core JAX unfortunately has a bug that means it sometimes produces NaN gradients. Lineax does not:" ] }, { "cell_type": "code", "execution_count": 30, "id": "66b3a08e-92d0-4d0f-a5ea-9e8e5265d259", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX gradients: [[nan nan nan]\n", " [nan nan nan]\n", " [nan nan nan]]\n", "Lineax gradients: [[ 0. -1. -2.]\n", " [ 0. -1. -2.]\n", " [ 0. -1. -2.]]\n" ] } ], "source": [ "@jax.jit\n", "@jax.grad\n", "def grad_jax(matrix):\n", " out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))\n", " return out.sum()\n", "\n", "\n", "@jax.jit\n", "@jax.grad\n", "def grad_lineax(matrix):\n", " operator = lx.MatrixLinearOperator(matrix)\n", " solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())\n", " return solution.value.sum()\n", "\n", "\n", "print(\"JAX gradients:\", grad_jax(jnp.eye(3)))\n", "print(\"Lineax gradients:\", grad_lineax(jnp.eye(3)))" ] } ], "metadata": { "kernelspec": { "display_name": "py39", "language": "python", "name": "py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/examples/no_materialisation.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "a7299095-8906-4867-82ef-d6b84b161366", "metadata": {}, "source": [ "# Using only matrix-vector operations\n", "\n", "When solving a linear system $Ax = b$, it is relatively common not to have immediate access to the full matrix $A$, but only to a function $F(x) = Ax$ computing the matrix-vector product. (We could compute $A$ from $F$, but is the matrix is large then this may be very inefficient.)\n", "\n", "**Example: Newton's method**\n", "\n", "For example, this comes up when using [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#k_variables,_k_functions). In this case, we have a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^n$, and wish to find the $\\delta \\in \\mathbb{R}^n$ for which $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; \\delta = -f(y)$. (Where $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\in \\mathbb{R}^{n \\times n}$ is a matrix: it is the Jacobian of $f$.)\n", "\n", "In this case it is possible to use forward-mode autodifferentiation to evaluate $F(x) = \\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; x$, without ever instantiating the whole Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y)$. Indeed, JAX has a [Jacobian-vector product function](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) for exactly this purpose.\n", "```python\n", "f = ...\n", "y = ...\n", "\n", "def F(x):\n", " \"\"\"Computes (df/dy) @ x.\"\"\"\n", " _, out = jax.jvp(f, (y,), (x,))\n", " return out\n", "```\n", "\n", "**Solving a linear system using only matrix-vector operations**\n", "\n", "Lineax offers [iterative solvers](../api/solvers.md#iterative-solvers), which are capable of solving a linear system knowing only its matrix-vector products." ] }, { "cell_type": "code", "execution_count": 1, "id": "b221ee1f-bd6b-4cbf-b69b-ed2e388602e1", "metadata": { "tags": [] }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "import lineax as lx\n", "from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n", "\n", "\n", "def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n", " y0, y1, y2 = y\n", " f0 = 5 * y0 + y1**2\n", " f1 = y1 - y2 + 5\n", " f2 = y0 / (1 + 5 * y2**2)\n", " return jnp.stack([f0, f1, f2])\n", "\n", "\n", "y = jnp.array([1.0, 2.0, 3.0])\n", "operator = lx.JacobianLinearOperator(f, y, args=None)\n", "vector = f(y, args=None)\n", "solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n", "solution = lx.linear_solve(operator, vector, solver)" ] }, { "cell_type": "markdown", "id": "87568426-35ed-404b-bf78-425a6f519218", "metadata": {}, "source": [ "!!! warning\n", "\n", " Note that iterative solvers are something of a \"last resort\", and they are not suitable for all problems.\n", "\n", " - [CG](https://en.wikipedia.org/wiki/Conjugate_gradient_method) requires that the problem be positive or negative semidefinite.\n", " - Normalised CG (this is CG applied to the \"normal equations\" $(A^\\top A) x = (A^\\top b)$; note that $A^\\top A$ is always positive semidefinite) squares the condition number of $A$. In practice this means it may produce low-accuracy results if used with matrices with high condition number.\n", " - [BiCGStab](https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method) and [GMRES](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems." ] } ], "metadata": { "kernelspec": { "display_name": "py39", "language": "python", "name": "py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/examples/operators.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "2fe0b1e4-35cb-4c39-b324-65253aab005a", "metadata": {}, "source": [ "# Manipulating linear operators\n", "\n", "Lineax offers a sophisticated system of linear operators, supporting many operations.\n", "\n", "## Arithmetic\n", "\n", "To begin with, they support arithmetic, like addition and multiplication:" ] }, { "cell_type": "code", "execution_count": 1, "id": "552021d3-dadf-49f3-bd17-84a18513bfcc", "metadata": { "tags": [] }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "import lineax as lx\n", "import numpy as np\n", "\n", "\n", "np.set_printoptions(precision=3)\n", "\n", "matrix = jnp.zeros((5, 5))\n", "matrix = matrix.at[0, 4].set(3) # top left corner\n", "sparse_operator = lx.MatrixLinearOperator(matrix)\n", "\n", "key0, key1, key = jr.split(jr.PRNGKey(0), 3)\n", "diag = jr.normal(key0, (5,))\n", "lower_diag = jr.normal(key0, (4,))\n", "upper_diag = jr.normal(key0, (4,))\n", "tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n", "\n", "identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))" ] }, { "cell_type": "code", "execution_count": 2, "id": "a4bb9825-73cc-447e-bc4c-c3e1a121a0a3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-1.149 0.963 0. 0. 3. ]\n", " [ 0.963 -2.007 0.155 0. 0. ]\n", " [ 0. 0.155 0.988 -0.261 0. ]\n", " [ 0. 0. -0.261 0.931 0.899]\n", " [ 0. 0. 0. 0.899 -0.288]]\n" ] } ], "source": [ "print((sparse_operator + tridiag_operator).as_matrix())" ] }, { "cell_type": "code", "execution_count": 3, "id": "759c78a1-eee7-40e9-be6c-ea8c97c29e95", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-101.149 0.963 0. 0. 0. ]\n", " [ 0.963 -102.007 0.155 0. 0. ]\n", " [ 0. 0.155 -99.012 -0.261 0. ]\n", " [ 0. 0. -0.261 -99.069 0.899]\n", " [ 0. 0. 0. 0.899 -100.288]]\n" ] } ], "source": [ "print((tridiag_operator - 100 * identity_operator).as_matrix())" ] }, { "cell_type": "markdown", "id": "84412bfa-00ec-41d4-87d7-def781145a90", "metadata": {}, "source": [ "Or they can be composed together. (I.e. matrix multiplication.)" ] }, { "cell_type": "code", "execution_count": 4, "id": "8081d97f-5579-464f-8780-ffaa1d9c5f95", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 0. 0. 0. -3.447]\n", " [ 0. 0. 0. 0. 2.888]\n", " [ 0. 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. ]]\n" ] } ], "source": [ "print((tridiag_operator @ sparse_operator).as_matrix())" ] }, { "cell_type": "markdown", "id": "d2c2b580-616f-4abd-a732-7f4a9b13335f", "metadata": {}, "source": [ "Or they can be transposed:" ] }, { "cell_type": "code", "execution_count": 5, "id": "ae0393eb-3f43-490b-9842-bb374633633a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [3. 0. 0. 0. 0.]]\n" ] } ], "source": [ "print(sparse_operator.transpose().as_matrix()) # or sparse_operator.T will work" ] }, { "cell_type": "markdown", "id": "ddbbbb0f-7983-4e35-b92d-2512c9612d19", "metadata": {}, "source": [ "## Different operator types\n", "\n", "Lineax has many different operator types:\n", "\n", "- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n", "- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n", "- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\times m}$.\n", "- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n", "- etc!\n", "\n", "See the [operators](../api/operators.md) page for details on all supported operators.\n", "\n", "As above these can be freely combined:" ] }, { "cell_type": "code", "execution_count": 6, "id": "75ad4480-8ce0-4a88-9c76-bc054b1a0eaf", "metadata": { "tags": [] }, "outputs": [], "source": [ "from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n", "\n", "\n", "def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n", " y0, y1, y2 = y\n", " f0 = 5 * y0 + y1**2\n", " f1 = y1 - y2 + 5\n", " f2 = y0 / (1 + 5 * y2**2)\n", " return jnp.stack([f0, f1, f2])\n", "\n", "\n", "def g(y: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n", " # Must be linear!\n", " y0, y1, y2 = y\n", " f0 = y0 - y2\n", " f1 = 0.0\n", " f2 = 5 * y1\n", " return jnp.stack([f0, f1, f2])\n", "\n", "\n", "y = jnp.array([1.0, 2.0, 3.0])\n", "in_structure = jax.eval_shape(lambda: y)\n", "jac_operator = lx.JacobianLinearOperator(f, y, args=None)\n", "fn_operator = lx.FunctionLinearOperator(g, in_structure)\n", "identity_operator = lx.IdentityLinearOperator(in_structure)\n", "\n", "operator = jac_operator @ fn_operator + 0.9 * identity_operator" ] }, { "cell_type": "markdown", "id": "5e528057-29ff-468d-aa3d-7155dd57082d", "metadata": {}, "source": [ "This composition does not instantiate a matrix for them by default. (This is sometimes important for efficiency when working with many operators.) Instead, the composition is stored as another linear operator:" ] }, { "cell_type": "code", "execution_count": 7, "id": "5d15150d-955f-4006-bd36-58e2e6663307", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AddLinearOperator(\n", " operator1=ComposedLinearOperator(\n", " operator1=JacobianLinearOperator(...),\n", " operator2=FunctionLinearOperator(...)\n", " ),\n", " operator2=MulLinearOperator(\n", " operator=IdentityLinearOperator(...),\n", " scalar=f32[]\n", " )\n", ")\n" ] } ], "source": [ "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", "\n", "\n", "truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)\n", "eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)" ] }, { "cell_type": "markdown", "id": "ff7b0591-1203-4f5e-886e-399822c68a15", "metadata": { "tags": [] }, "source": [ "If you want to materialise them into a matrix, then this can be done:" ] }, { "cell_type": "code", "execution_count": 8, "id": "3713589f-1ac4-4e08-946b-ecc3fcf6a4c3", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Array([[ 5.9 , 0. , -5. ],\n", " [ 0. , -4.1 , 0. ],\n", " [ 0.022, -0.071, 0.878]], dtype=float32)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "operator.as_matrix()" ] }, { "cell_type": "markdown", "id": "a483517e-89d7-4e9e-ad89-1915d886c14c", "metadata": {}, "source": [ "Which can in turn be treated as another linear operator, if desired:" ] }, { "cell_type": "code", "execution_count": 9, "id": "fccddc81-d50e-4abe-a354-38402e462b1f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MatrixLinearOperator(\n", " matrix=Array([[ 5.9 , 0. , -5. ],\n", " [ 0. , -4.1 , 0. ],\n", " [ 0.022, -0.071, 0.878]], dtype=float32),\n", " tags=frozenset()\n", ")\n" ] } ], "source": [ "operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())\n", "eqx.tree_pprint(operator_fully_materialised, short_arrays=False)" ] } ], "metadata": { "kernelspec": { "display_name": "py39", "language": "python", "name": "py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/examples/structured_matrices.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "e2573d62-a505-4998-8796-b0f1bc889433", "metadata": {}, "source": [ "# Structured matrices\n", "\n", "Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.\n", "\n", "Typically, that means using a particular operator type:" ] }, { "cell_type": "code", "execution_count": 1, "id": "8e275652-dd80-4a9a-b3ac-b96dc16d3334", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 4. 2. 0. 0. ]\n", " [ 1. -0.5 -1. 0. ]\n", " [ 0. 3. 7. -5. ]\n", " [ 0. 0. -0.7 1. ]]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "import jax.random as jr\n", "import lineax as lx\n", "\n", "\n", "diag = jnp.array([4.0, -0.5, 7.0, 1.0])\n", "lower_diag = jnp.array([1.0, 3.0, -0.7])\n", "upper_diag = jnp.array([2.0, -1.0, -5.0])\n", "\n", "operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n", "print(operator.as_matrix())" ] }, { "cell_type": "code", "execution_count": 2, "id": "ba23ecc4-bdea-4293-a138-ce77bc83082c", "metadata": { "tags": [] }, "outputs": [], "source": [ "vector = jnp.array([1.0, -0.5, 2.0, 0.8])\n", "# Will automatically dispatch to a tridiagonal solver.\n", "solution = lx.linear_solve(operator, vector)" ] }, { "cell_type": "markdown", "id": "cd58979d-b619-4ddf-9a17-12e8babae3e8", "metadata": {}, "source": [ "If you're uncertain which solver is being dispatched to, then you can check:" ] }, { "cell_type": "code", "execution_count": 3, "id": "6984f62f-75fc-4d6e-ab42-fdade471be5b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tridiagonal()\n" ] } ], "source": [ "default_solver = lx.AutoLinearSolver(well_posed=True)\n", "print(default_solver.select_solver(operator))" ] }, { "cell_type": "markdown", "id": "164a5bd5-5d48-4b28-bcc5-d276ab49c780", "metadata": {}, "source": [ "If you want to enforce that a particular solver is used, then it can be passed manually:" ] }, { "cell_type": "code", "execution_count": 4, "id": "102ada9a-0533-40cf-9bad-02918fffb6b1", "metadata": { "tags": [] }, "outputs": [], "source": [ "solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())" ] }, { "cell_type": "markdown", "id": "1b4ebf09-e138-43f6-973c-c9f005ffb55e", "metadata": {}, "source": [ "Trying to use a solver with an unsupported operator will raise an error:" ] }, { "cell_type": "code", "execution_count": 6, "id": "d8f5bf66-53cd-4e81-a8d7-a19e86307ad3", "metadata": { "tags": [] }, "outputs": [ { "ename": "ValueError", "evalue": "`Tridiagonal` may only be used for linear solves with tridiagonal matrices", "output_type": "error", "traceback": [ "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\n" ] } ], "source": [ "not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n", "not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)\n", "solution = lx.linear_solve(not_tridiagonal_operator, vector, solver=lx.Tridiagonal())" ] }, { "cell_type": "markdown", "id": "03c4c531-58fa-4b56-8b0a-6e611c8c5912", "metadata": {}, "source": [ "---\n", "\n", "Besides using a particular operator type, the structure of the matrix can also be expressed by [adding particular tags](../api/tags.md). These tags act as a manual override mechanism, and the values of the matrix are not checked.\n", "\n", "For example, let's construct a positive definite matrix:" ] }, { "cell_type": "code", "execution_count": 7, "id": "b5add874-7a2c-4000-84c3-8c94a121a831", "metadata": { "tags": [] }, "outputs": [], "source": [ "matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n", "operator = lx.MatrixLinearOperator(matrix.T @ matrix)" ] }, { "cell_type": "markdown", "id": "5459b2d6-ddb9-4a37-bb51-3f5c204bab0d", "metadata": {}, "source": [ "Unfortunately, Lineax has no way of knowing that this matrix is positive definite. It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:" ] }, { "cell_type": "code", "execution_count": 8, "id": "78400416-e774-4f74-a530-e368db84af0e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LU()\n" ] } ], "source": [ "solution = lx.linear_solve(operator, vector)\n", "print(default_solver.select_solver(operator))" ] }, { "cell_type": "markdown", "id": "e108bdff-1cf1-4751-8c9d-3baae82ca9a7", "metadata": {}, "source": [ "But if we add a tag:" ] }, { "cell_type": "code", "execution_count": 9, "id": "f6dc2966-1dfa-4a3c-be6a-974926695547", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cholesky()\n" ] } ], "source": [ "operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\n", "solution2 = lx.linear_solve(operator, vector)\n", "print(default_solver.select_solver(operator))" ] }, { "cell_type": "markdown", "id": "7274d17b-a7d3-45bf-9042-785ac25e2d74", "metadata": {}, "source": [ "Then a more efficient solver can be selected. We can check that the solutions returned from these two approaches are equal:" ] }, { "cell_type": "code", "execution_count": 10, "id": "fdcde152-9ac1-4532-a174-3fc39d83d289", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1.400575 -0.41042092 0.5313305 0.28422552]\n", "[ 1.4005749 -0.41042086 0.53133047 0.2842255 ]\n" ] } ], "source": [ "print(solution.value)\n", "print(solution2.value)" ] } ], "metadata": { "kernelspec": { "display_name": "py39", "language": "python", "name": "py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/faq.md ================================================ # FAQ ## How does this differ from `jax.numpy.solve`, `jax.scipy.{...}` etc.? Lineax offers several improvements. Most notably: - Several new solvers. For example, [`lineax.QR`][] has no counterpart in core JAX. (And it is much faster than `jax.numpy.linalg.lstsq`, which is the closest equivalent, and uses an SVD decomposition instead.) - Several new operators. For example, [`lineax.JacobianLinearOperator`][] has no counterpart in core JAX. - A consistent API. The built-in JAX operations all differ from each other slightly, and are split across `jax.numpy`, `jax.scipy`, and `jax.scipy.sparse`. - Numerically stable gradients. The existing JAX implementations will sometimes return `NaN`s! - Some faster compile times and run times in a few places. Most of these are because JAX aims to mimic the existing NumPy/SciPy APIs. (I.e. it's not JAX's fault that it doesn't take the approach that Lineax does!) ## How do I represent a {lower, upper} triangular matrix? Typically: create a full matrix, with the {lower, upper} part containing your values, and the converse {upper, lower} part containing all zeros. Then use, e.g., `operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag)`. This is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model. ## What about other operations from linear algebra? (Determinants, eigenvalues, etc.) See [`jax.numpy.linalg`](https://jax.readthedocs.io/en/latest/jax.numpy.html#module-jax.numpy.linalg) and [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg). ## How do I solve multiple systems of equations (i.e. `AX = B`)? Solvers implemented in Lineax target single systems of linear equations (i.e., `Ax = b`), however, using `jax.vmap` or `equinox.filter_vmap`, it can solve multiple systems with minimal effort. ```python multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1)) # or multi_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1)) ``` ================================================ FILE: docs/index.md ================================================ # Getting started Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.) Features include: - PyTree-valued matrices and vectors; - General linear operators for Jacobians, transposes, etc.; - Efficient linear least squares (e.g. QR solvers); - Numerically stable gradients through linear least squares; - Support for structured (e.g. symmetric) matrices; - Improved compilation times; - Improved runtime of some algorithms; - Support for both real-valued and complex-valued inputs; - All the benefits of working with JAX: autodiff, autoparallism, GPU/TPU support, etc. ## Installation ```bash pip install lineax ``` Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+. ## Quick example Lineax can solve a least squares problem with an explicit matrix operator: ```python import jax.random as jr import lineax as lx matrix_key, vector_key = jr.split(jr.PRNGKey(0)) matrix = jr.normal(matrix_key, (10, 8)) vector = jr.normal(vector_key, (10,)) operator = lx.MatrixLinearOperator(matrix) solution = lx.linear_solve(operator, vector, solver=lx.QR()) ``` or Lineax can solve a problem without ever materializing a matrix, as done in this quadratic solve: ```python import jax import lineax as lx key = jax.random.PRNGKey(0) y = jax.random.normal(key, (10,)) def quadratic_fn(y, args): return jax.numpy.sum((y - 1)**2) gradient_fn = jax.grad(quadratic_fn) hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag) solver = lx.CG(rtol=1e-6, atol=1e-6) out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver) minimum = y - out.value ``` ## Next steps Check out the examples or the API reference on the left-hand bar. ## See also: other libraries in the JAX ecosystem **Always useful** [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. **Deep learning** [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees. **Scientific computing** [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers. [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) **Awesome JAX** [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects. ================================================ FILE: lineax/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib.metadata from . import internal as internal from ._operator import ( AbstractLinearOperator as AbstractLinearOperator, AddLinearOperator as AddLinearOperator, ComposedLinearOperator as ComposedLinearOperator, conj as conj, diagonal as diagonal, DiagonalLinearOperator as DiagonalLinearOperator, DivLinearOperator as DivLinearOperator, FunctionLinearOperator as FunctionLinearOperator, has_unit_diagonal as has_unit_diagonal, IdentityLinearOperator as IdentityLinearOperator, is_diagonal as is_diagonal, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, is_positive_semidefinite as is_positive_semidefinite, is_symmetric as is_symmetric, is_tridiagonal as is_tridiagonal, is_upper_triangular as is_upper_triangular, JacobianLinearOperator as JacobianLinearOperator, linearise as linearise, materialise as materialise, MatrixLinearOperator as MatrixLinearOperator, MulLinearOperator as MulLinearOperator, NegLinearOperator as NegLinearOperator, PyTreeLinearOperator as PyTreeLinearOperator, TaggedLinearOperator as TaggedLinearOperator, TangentLinearOperator as TangentLinearOperator, tridiagonal as tridiagonal, TridiagonalLinearOperator as TridiagonalLinearOperator, ) from ._solution import RESULTS as RESULTS, Solution as Solution from ._solve import ( AbstractLinearSolver as AbstractLinearSolver, AutoLinearSolver as AutoLinearSolver, invert as invert, linear_solve as linear_solve, ) from ._solver import ( BiCGStab as BiCGStab, CG as CG, Cholesky as Cholesky, Diagonal as Diagonal, GMRES as GMRES, LSMR as LSMR, LU as LU, Normal as Normal, NormalCG as NormalCG, QR as QR, SVD as SVD, Triangular as Triangular, Tridiagonal as Tridiagonal, ) from ._tags import ( diagonal_tag as diagonal_tag, lower_triangular_tag as lower_triangular_tag, negative_semidefinite_tag as negative_semidefinite_tag, positive_semidefinite_tag as positive_semidefinite_tag, symmetric_tag as symmetric_tag, transpose_tags as transpose_tags, transpose_tags_rules as transpose_tags_rules, tridiagonal_tag as tridiagonal_tag, unit_diagonal_tag as unit_diagonal_tag, upper_triangular_tag as upper_triangular_tag, ) __version__ = importlib.metadata.version("lineax") ================================================ FILE: lineax/_custom_types.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import equinox.internal as eqxi sentinel: Any = eqxi.doc_repr(object(), "sentinel") ================================================ FILE: lineax/_misc.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore def tree_where( pred: Bool[ArrayLike, ""], true: PyTree[ArrayLike], false: PyTree[ArrayLike] ) -> PyTree[Array]: keep = lambda a, b: jnp.where(pred, a, b) return jtu.tree_map(keep, true, false) def resolve_rcond(rcond, n, m, dtype): if rcond is None: # This `2 *` is a heuristic: I have seen very rare failures without it, in ways # that seem to depend on JAX compilation state. (E.g. running unrelated JAX # computations beforehand, in a completely different JIT-compiled region, can # result in differences in the success/failure of the solve.) return 2 * jnp.finfo(dtype).eps * max(n, m) else: return jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) def jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, jac=None): if jac is None: # Heuristic for which is better in each case # These could probably be tuned a lot more. jac_fwd = (in_size < 100) or (in_size <= 1.5 * out_size) elif jac == "fwd": jac_fwd = True elif jac == "bwd": jac_fwd = False else: raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.") if jac_fwd: return jax.jacfwd(fn, holomorphic=holomorphic, has_aux=has_aux) else: return jax.jacrev(fn, holomorphic=holomorphic, has_aux=has_aux) def _asarray(dtype, x): return jnp.asarray(x, dtype=dtype) # Work around JAX issue #15676 _asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,)) @_asarray.defjvp def _asarray_jvp(dtype, x, tx): (x,) = x (tx,) = tx return _asarray(dtype, x), _asarray(dtype, tx) def default_floating_dtype(): if jax.config.jax_enable_x64: # pyright: ignore return jnp.float64 else: return jnp.float32 def inexact_asarray(x): dtype = jnp.result_type(x) if not jnp.issubdtype(jnp.result_type(x), jnp.inexact): dtype = default_floating_dtype() return _asarray(dtype, x) def complex_to_real_dtype(dtype): return jnp.finfo(dtype).dtype def strip_weak_dtype(tree: PyTree) -> PyTree: return jtu.tree_map( lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding) if type(x) is jax.ShapeDtypeStruct else x, tree, ) def structure_equal(x, y) -> bool: x = strip_weak_dtype(jax.eval_shape(lambda: x)) y = strip_weak_dtype(jax.eval_shape(lambda: y)) return eqx.tree_equal(x, y) is True ================================================ FILE: lineax/_norm.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import math import jax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import Array, ArrayLike, Inexact, PyTree, Scalar from ._misc import complex_to_real_dtype, default_floating_dtype def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Array, ""]: """Compute the dot product of two pytrees of arrays with the same pytree structure.""" leaves1, treedef1 = jtu.tree_flatten(tree1) leaves2, treedef2 = jtu.tree_flatten(tree2) if treedef1 != treedef2: raise ValueError("trees must have the same structure") assert len(leaves1) == len(leaves2) dots = [] for leaf1, leaf2 in zip(leaves1, leaves2): dots.append( jnp.dot( jnp.conj(leaf1).reshape(-1), jnp.reshape(leaf2, -1), precision=jax.lax.Precision.HIGHEST, # pyright: ignore ) ) if len(dots) == 0: return jnp.array(0, default_floating_dtype()) else: return ft.reduce(jnp.add, dots) def sum_squares(x: PyTree[ArrayLike]) -> Scalar: """Computes the square of the L2 norm of a PyTree of arrays. Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes `Σ_i x_i^2` """ return tree_dot(x, x).real def two_norm(x: PyTree[ArrayLike]) -> Scalar: """Computes the L2 norm of a PyTree of arrays. Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes `sqrt(Σ_i x_i^2)` """ # Wrap the `custom_jvp` into a function so that our autogenerated documentation # displays the docstring correctly. return _two_norm(x) @jax.custom_jvp def _two_norm(x: PyTree[ArrayLike]) -> Scalar: leaves = jtu.tree_leaves(x) size = sum([jnp.size(xi) for xi in leaves]) if size == 1: # Avoid needless squaring-and-then-rooting. for leaf in leaves: if jnp.size(leaf) == 1: return jnp.abs(jnp.reshape(leaf, ())) else: assert False else: return jnp.sqrt(sum_squares(x)) @_two_norm.defjvp def _two_norm_jvp(x, tx): (x,) = x (tx,) = tx out = two_norm(x) # Get zero gradient, rather than NaN gradient, in these cases. pred = (out == 0) | jnp.isinf(out) denominator = jnp.where(pred, 1, out) # We could also switch the dot and the division. # This approach is a bit more expensive (more divisions), but should be more # numerically stable (`x` and `denominator` should be of the same scale; `tx` is of # unknown scale). with jax.numpy_dtype_promotion("standard"): div = (x**ω / denominator).ω t_out = tree_dot(div, tx).real t_out = jnp.where(pred, 0, t_out) return out, t_out def rms_norm(x: PyTree[ArrayLike]) -> Scalar: """Compute the RMS (root-mean-squared) norm of a PyTree of arrays. This is the same as the L2 norm, averaged by the size of the input `x`. Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes `sqrt((Σ_i x_i^2)/n)` """ leaves = jtu.tree_leaves(x) size = sum([jnp.size(xi) for xi in leaves]) if size == 0: if len(leaves) == 0: dtype = default_floating_dtype() else: dtype = complex_to_real_dtype(jnp.result_type(*leaves)) return jnp.array(0.0, dtype) else: return two_norm(x) / math.sqrt(size) def max_norm(x: PyTree[ArrayLike]) -> Scalar: """Compute the L-infinity norm of a PyTree of arrays. This is the largest absolute elementwise value. Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes `max_i |x_i|`. """ leaves = jtu.tree_leaves(x) leaf_maxes = [jnp.max(jnp.abs(xi)) for xi in leaves if jnp.size(xi) > 0] if len(leaf_maxes) == 0: if len(leaves) == 0: dtype = default_floating_dtype() else: dtype = complex_to_real_dtype(jnp.result_type(*leaves)) return jnp.array(0.0, dtype) else: out = ft.reduce(jnp.maximum, leaf_maxes) return _zero_grad_at_zero(out) @jax.custom_jvp def _zero_grad_at_zero(x): return x @_zero_grad_at_zero.defjvp def _zero_grad_at_zero_jvp(primals, tangents): (out,) = primals (t_out,) = tangents return out, jnp.where(out == 0, 0, t_out) ================================================ FILE: lineax/_operator.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import enum import functools as ft import math import warnings from collections.abc import Callable, Iterable from typing import Any, Literal, NoReturn, TypeVar import equinox as eqx import equinox.internal as eqxi import jax import jax.flatten_util as jfu import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu import numpy as np from equinox.internal import ω from jaxtyping import ( Array, ArrayLike, Inexact, PyTree, # pyright: ignore Scalar, Shaped, ) from ._custom_types import sentinel from ._misc import ( default_floating_dtype, inexact_asarray, jacobian, strip_weak_dtype, ) from ._tags import ( diagonal_tag, lower_triangular_tag, negative_semidefinite_tag, positive_semidefinite_tag, symmetric_tag, transpose_tags, tridiagonal_tag, unit_diagonal_tag, upper_triangular_tag, ) def _frozenset(x: object | Iterable[object]) -> frozenset[object]: try: iter_x = iter(x) # pyright: ignore except TypeError: return frozenset([x]) else: return frozenset(iter_x) class AbstractLinearOperator(eqx.Module): """Abstract base class for all linear operators. Linear operators can act between PyTrees. Each `AbstractLinearOperator` is thought of as a linear function `X -> Y`, where each element of `X` is as PyTree of floating-point JAX arrays, and each element of `Y` is a PyTree of floating-point JAX arrays. Abstract linear operators support some operations: ```python op1 + op2 # addition of two operators op1 @ op2 # composition of two operators. op1 * 3.2 # multiplication by a scalar op1 / 3.2 # division by a scalar ``` """ def __check_init__(self): if ( is_symmetric(self) or is_positive_semidefinite(self) or is_negative_semidefinite(self) ): # In particular, we check that dtypes match. in_structure = self.in_structure() out_structure = self.out_structure() # `is` check to handle the possibility of a tracer. if eqx.tree_equal(in_structure, out_structure) is not True: raise ValueError( "Symmetric/Hermitian matrices must have matching input and output " f"structures. Got input structure {in_structure} and output " f"structure {out_structure}." ) @abc.abstractmethod def mv( self, vector: PyTree[Inexact[Array, " _b"]] ) -> PyTree[Inexact[Array, " _a"]]: """Computes a matrix-vector product between this operator and a `vector`. **Arguments:** - `vector`: Should be some PyTree of floating-point arrays, whose structure should match `self.in_structure()`. **Returns:** A PyTree of floating-point arrays, with structure that matches `self.out_structure()`. """ @abc.abstractmethod def as_matrix(self) -> Inexact[Array, "a b"]: """Materialises this linear operator as a matrix. Note that this can be a computationally (time and/or memory) expensive operation, as many linear operators are defined implicitly, e.g. in terms of their action on a vector. **Arguments:** None. **Returns:** A 2-dimensional floating-point JAX array. """ @abc.abstractmethod def transpose(self) -> "AbstractLinearOperator": """Transposes this linear operator. This can be called as either `operator.T` or `operator.transpose()`. **Arguments:** None. **Returns:** Another [`lineax.AbstractLinearOperator`][]. """ @abc.abstractmethod def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]: """Returns the expected input structure of this linear operator. **Arguments:** None. **Returns:** A PyTree of `jax.ShapeDtypeStruct`. """ @abc.abstractmethod def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]: """Returns the expected output structure of this linear operator. **Arguments:** None. **Returns:** A PyTree of `jax.ShapeDtypeStruct`. """ def in_size(self) -> int: """Returns the total number of scalars in the input of this linear operator. That is, the dimensionality of its input space. **Arguments:** None. **Returns:** An integer. """ leaves = jtu.tree_leaves(self.in_structure()) return sum(math.prod(leaf.shape) for leaf in leaves) # pyright: ignore def out_size(self) -> int: """Returns the total number of scalars in the output of this linear operator. That is, the dimensionality of its output space. **Arguments:** None. **Returns:** An integer. """ leaves = jtu.tree_leaves(self.out_structure()) return sum(math.prod(leaf.shape) for leaf in leaves) # pyright: ignore @property def T(self) -> "AbstractLinearOperator": """Equivalent to [`lineax.AbstractLinearOperator.transpose`][]""" return self.transpose() def __add__(self, other) -> "AbstractLinearOperator": if not isinstance(other, AbstractLinearOperator): raise ValueError("Can only add AbstractLinearOperators together.") return AddLinearOperator(self, other) def __sub__(self, other) -> "AbstractLinearOperator": if not isinstance(other, AbstractLinearOperator): raise ValueError("Can only add AbstractLinearOperators together.") return AddLinearOperator(self, -other) def __mul__(self, other) -> "AbstractLinearOperator": other = jnp.asarray(other) if other.shape != (): raise ValueError("Can only multiply AbstractLinearOperators by scalars.") return MulLinearOperator(self, other) def __rmul__(self, other) -> "AbstractLinearOperator": return self * other def __matmul__(self, other) -> "AbstractLinearOperator": if not isinstance(other, AbstractLinearOperator): raise ValueError("Can only compose AbstractLinearOperators together.") return ComposedLinearOperator(self, other) def __truediv__(self, other) -> "AbstractLinearOperator": other = jnp.asarray(other) if other.shape != (): raise ValueError("Can only divide AbstractLinearOperators by scalars.") return DivLinearOperator(self, other) def __neg__(self) -> "AbstractLinearOperator": return NegLinearOperator(self) class MatrixLinearOperator(AbstractLinearOperator): """Wraps a 2-dimensional JAX array into a linear operator. If the matrix has shape `(a, b)` then matrix-vector multiplication (`self.mv`) is defined in the usual way: as performing a matrix-vector that accepts a vector of shape `(a,)` and returns a vector of shape `(b,)`. """ matrix: Inexact[Array, "a b"] tags: frozenset[object] = eqx.field(static=True) def __init__( self, matrix: Shaped[Array, "a b"], tags: object | frozenset[object] = () ): """**Arguments:** - `matrix`: a two-dimensional JAX array. For an array with shape `(a, b)` then this operator can perform matrix-vector products on a vector of shape `(b,)` to return a vector of shape `(a,)`. - `tags`: any tags indicating whether this matrix has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong. """ if jnp.ndim(matrix) != 2: raise ValueError( "`MatrixLinearOperator(matrix=...)` should be 2-dimensional." ) if not jnp.issubdtype(matrix.dtype, jnp.inexact): matrix = matrix.astype(jnp.float32) self.matrix = matrix self.tags = _frozenset(tags) def mv(self, vector): maybe_sparse_op = _try_sparse_materialise(self) if maybe_sparse_op is not self: return maybe_sparse_op.mv(vector) return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST) def as_matrix(self): return self.matrix def transpose(self): if is_symmetric(self): return self return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags)) def in_structure(self): _, in_size = jnp.shape(self.matrix) return jax.ShapeDtypeStruct(shape=(in_size,), dtype=self.matrix.dtype) def out_structure(self): out_size, _ = jnp.shape(self.matrix) return jax.ShapeDtypeStruct(shape=(out_size,), dtype=self.matrix.dtype) def _matmul(matrix: ArrayLike, vector: ArrayLike) -> Array: # matrix has structure [leaf(out), leaf(in)] # vector has structure [leaf(in)] # return has structure [leaf(out)] return jnp.tensordot( matrix, vector, axes=jnp.ndim(vector), precision=lax.Precision.HIGHEST ) def _tree_matmul(matrix: PyTree[ArrayLike], vector: PyTree[ArrayLike]) -> PyTree[Array]: # matrix has structure [tree(in), leaf(out), leaf(in)] # vector has structure [tree(in), leaf(in)] # return has structure [leaf(out)] matrix = jtu.tree_leaves(matrix) vector = jtu.tree_leaves(vector) assert len(matrix) == len(vector) return sum([_matmul(m, v) for m, v in zip(matrix, vector)]) # Needed as static fields must be hashable and eq-able, and custom pytrees might have # e.g. define custom __eq__ methods. _T = TypeVar("_T") _FlatPyTree = tuple[list[_T], jtu.PyTreeDef] def _inexact_structure_impl2(x): if jnp.issubdtype(x.dtype, jnp.inexact): return x else: return x.astype(default_floating_dtype()) def _inexact_structure_impl(x): return jtu.tree_map(_inexact_structure_impl2, x) def _inexact_structure(x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.ShapeDtypeStruct]: return strip_weak_dtype(jax.eval_shape(_inexact_structure_impl, x)) class _Leaf: # not a pytree def __init__(self, value): self.value = value # The `{input,output}_structure`s have to be static because otherwise abstract # evaluation rules will promote them to ShapedArrays. class PyTreeLinearOperator(AbstractLinearOperator): """Represents a PyTree of floating-point JAX arrays as a linear operator. This is basically a generalisation of [`lineax.MatrixLinearOperator`][], from taking just a single array to take a PyTree-of-arrays. (And likewise from returning a single array to returning a PyTree-of-arrays.) Specifically, suppose we want this to be a linear operator `X -> Y`, for which elements of `X` are PyTrees with structure `T` whose `i`th leaf is a floating-point JAX array of shape `x_shape_i`, and elements of `Y` are PyTrees with structure `S` whose `j`th leaf is a floating-point JAX array of has shape `y_shape_j`. Then the input PyTree should have structure `T`-compose-`S`, and its `(i, j)`-th leaf should be a floating-point JAX array of shape `(*x_shape_i, *y_shape_j)`. !!! Example ```python # Suppose `x` is a member of our input space, with the following pytree # structure: eqx.tree_pprint(x) # [f32[5, 9], f32[3]] # Suppose `y` is a member of our output space, with the following pytree # structure: eqx.tree_pprint(y) # {"a": f32[1, 2]} # then `pytree` should be a pytree with the following structure: eqx.tree_pprint(pytree) # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]} ``` """ pytree: PyTree[Inexact[Array, "..."]] output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) tags: frozenset[object] = eqx.field(static=True) input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) def __init__( self, pytree: PyTree[ArrayLike], output_structure: PyTree[jax.ShapeDtypeStruct], tags: object | frozenset[object] = (), ): """**Arguments:** - `pytree`: this should be a PyTree, with structure as specified in [`lineax.PyTreeLinearOperator`][]. - `output_structure`: the structure of the output space. This should be a PyTree of `jax.ShapeDtypeStruct`s. (The structure of the input space is then automatically derived from the structure of `pytree`.) - `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong. """ output_structure = _inexact_structure(output_structure) self.pytree = jtu.tree_map(inexact_asarray, pytree) self.output_structure = jtu.tree_flatten(output_structure) self.tags = _frozenset(tags) # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] def get_structure(struct, subpytree): # subpytree has structure [tree(in), leaf(out), leaf(in)] def sub_get_structure(leaf): shape = jnp.shape(leaf) # [leaf(out), leaf(in)] ndim = len(struct.shape) if shape[:ndim] != struct.shape: raise ValueError( "`pytree` and `output_structure` are not consistent" ) return jax.ShapeDtypeStruct( shape=shape[ndim:], dtype=jnp.result_type(leaf) ) return _Leaf(jtu.tree_map(sub_get_structure, subpytree)) if output_structure is None: # Implies that len(input_structures) > 0 raise ValueError("Cannot have trivial output_structure") input_structures = jtu.tree_map(get_structure, output_structure, self.pytree) input_structures = jtu.tree_leaves(input_structures) input_structure = input_structures[0].value for val in input_structures[1:]: if eqx.tree_equal(input_structure, val.value) is not True: raise ValueError( "`pytree` does not have a consistent `input_structure`" ) self.input_structure = jtu.tree_flatten(input_structure) def mv(self, vector): # vector has structure [tree(in), leaf(in)] # self.out_structure() has structure [tree(out)] # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)] # return has structure [tree(out), leaf(out)] maybe_sparse_op = _try_sparse_materialise(self) if maybe_sparse_op is not self: return maybe_sparse_op.mv(vector) def matmul(_, matrix): return _tree_matmul(matrix, vector) return jtu.tree_map(matmul, self.out_structure(), self.pytree) def as_matrix(self): with jax.numpy_dtype_promotion("standard"): dtype = jnp.result_type(*jtu.tree_leaves(self.pytree)) def concat_in(struct, subpytree): leaves = jtu.tree_leaves(subpytree) assert all(leaf.shape[: struct.ndim] == struct.shape for leaf in leaves) leaves = [ leaf.astype(dtype).reshape( struct.size, math.prod(leaf.shape[struct.ndim :]) ) for leaf in leaves ] return jnp.concatenate(leaves, axis=1) matrix = jtu.tree_map(concat_in, self.out_structure(), self.pytree) matrix = jtu.tree_leaves(matrix) return jnp.concatenate(matrix, axis=0) def transpose(self): if is_symmetric(self): return self def _transpose(struct, subtree): def _transpose_impl(leaf): return jnp.moveaxis(leaf, source, dest) source = list(range(struct.ndim)) dest = list(range(-struct.ndim, 0)) return jtu.tree_map(_transpose_impl, subtree) pytree_transpose = jtu.tree_map(_transpose, self.out_structure(), self.pytree) pytree_transpose = jtu.tree_transpose( jtu.tree_structure(self.out_structure()), jtu.tree_structure(self.in_structure()), pytree_transpose, ) return PyTreeLinearOperator( pytree_transpose, self.in_structure(), transpose_tags(self.tags) ) def in_structure(self): leaves, treedef = self.input_structure return jtu.tree_unflatten(treedef, leaves) def out_structure(self): leaves, treedef = self.output_structure return jtu.tree_unflatten(treedef, leaves) class DiagonalLinearOperator(AbstractLinearOperator): """A diagonal linear operator, e.g. for a diagonal matrix. Only the diagonal is stored (for memory efficiency). Matrix-vector products are computed by doing a pointwise diagonal * vector, rather than a full matrix @ vector (for speed). The diagonal may also be a PyTree, rather than a 1D array. When materialising the matrix, the diagonal is taken to be defined by the flattened PyTree (i.e. values show up in the same order.) """ diagonal: PyTree[Inexact[Array, "..."]] def __init__(self, diagonal: PyTree[ArrayLike]): """**Arguments:** - `diagonal`: an array or PyTree defining the diagonal of the matrix. """ self.diagonal = jtu.tree_map(inexact_asarray, diagonal) def mv(self, vector): return (ω(self.diagonal) * ω(vector)).ω def as_matrix(self): return jnp.diag(diagonal(self)) def transpose(self): return self def in_structure(self): return jax.eval_shape(lambda: self.diagonal) def out_structure(self): return jax.eval_shape(lambda: self.diagonal) class _NoAuxIn(eqx.Module): fn: Callable args: Any def __call__(self, x): return self.fn(x, self.args) class _Unwrap(eqx.Module): fn: Callable def __call__(self, x): (f,) = self.fn(x) return f class JacobianLinearOperator(AbstractLinearOperator): """Given a function `fn: X -> Y`, and a point `x in X`, then this defines the linear operator (also a function `X -> Y`) given by the Jacobian `(d(fn)/dx)(x)`. For example if the inputs and outputs are just arrays, then this is equivalent to `MatrixLinearOperator(jax.jacfwd(fn)(x))`. The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation. By default (or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to `jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with `jax.linear_transpose`, which works even with functions that only define a custom VJP (via `jax.custom_vjp`) and don't support forward-mode differentiation. See also [`lineax.materialise`][], which materialises the whole Jacobian in memory. !!! tip For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache the primal computation, e.g. for `jac="fwd"/None` it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)` """ fn: Callable[ [PyTree[Inexact[Array, "..."]], PyTree[Any]], PyTree[Inexact[Array, "..."]] ] x: PyTree[Inexact[Array, "..."]] args: PyTree[Any] tags: frozenset[object] = eqx.field(static=True) jac: Literal["fwd", "bwd"] | None @eqxi.doc_remove_args("closure_convert") def __init__( self, fn: Callable, x: PyTree[ArrayLike], args: PyTree[Any] = None, tags: object | Iterable[object] = (), jac: Literal["fwd", "bwd"] | None = None, closure_convert: bool = True, ): """**Arguments:** - `fn`: A function `(x, args) -> y`. The Jacobian `d(fn)/dx` is used as the linear operator, and `args` are just any other arguments that should not be differentiated. - `x`: The point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`. - `args`: As `x`; this is the point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`. - `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong. - `jac`: allows to use specific jacobian computation method. If `jac=fwd` forces `jax.jacfwd` to be used, similarly `jac=bwd` mandates the use of `jax.jacrev`. Otherwise, if not specified it will be chosen by default according to input and output shape. """ if jac not in [None, "fwd", "bwd"]: raise ValueError( "`jac` argument of `JacobianLinearOperator` should be either " "`'fwd'`, `'bwd'`, or `None`." ) # Flush out any closed-over values, so that we can safely pass `self` # across API boundaries. (In particular, across `linear_solve_p`.) # We don't use `jax.closure_convert` as that only flushes autodiffable # (=floating-point) constants. It probably doesn't matter, but if `fn` is a # PyTree capturing non-floating-point constants, we should probably continue # to respect that, and keep any non-floating-point constants as part of the # PyTree structure. x = jtu.tree_map(inexact_asarray, x) if closure_convert: fn = eqx.filter_closure_convert(fn, x, args) self.fn = fn self.x = x self.args = args self.tags = _frozenset(tags) self.jac = jac def mv(self, vector): fn = _NoAuxIn(self.fn, self.args) if self.jac == "fwd" or self.jac is None: _, out = jax.jvp(fn, (self.x,), (vector,)) elif self.jac == "bwd": # Use VJP + linear_transpose instead of materializing full Jacobian. # This works even for custom_vjp functions that don't have JVP rules. _, vjp_fn = jax.vjp(fn, self.x) if is_symmetric(self): # For symmetric operators, J = J.T, so vjp directly gives J @ v (out,) = vjp_fn(vector) else: # For non-symmetric, transpose the VJP to get J @ v from J.T @ v transpose_vjp = jax.linear_transpose( lambda g: vjp_fn(g)[0], self.out_structure() ) (out,) = transpose_vjp(vector) else: raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.") return out def as_matrix(self): return materialise(self).as_matrix() def transpose(self): if is_symmetric(self): return self fn = _NoAuxIn(self.fn, self.args) # Works because vjpfn is a PyTree _, vjpfn = jax.vjp(fn, self.x) vjpfn = _Unwrap(vjpfn) return FunctionLinearOperator( vjpfn, self.out_structure(), transpose_tags(self.tags) ) def in_structure(self): return strip_weak_dtype(jax.eval_shape(lambda: self.x)) def out_structure(self): fn = _NoAuxIn(self.fn, self.args) return strip_weak_dtype(eqxi.cached_filter_eval_shape(fn, self.x)) # `input_structure` must be static as with `JacobianLinearOperator` class FunctionLinearOperator(AbstractLinearOperator): """Wraps a *linear* function `fn: X -> Y` into a linear operator. (So that `self.mv(x)` is defined by `self.mv(x) == fn(x)`.) See also [`lineax.materialise`][], which materialises the whole linear operator in memory. (Similar to `.as_matrix()`.) """ fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]] input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) tags: frozenset[object] = eqx.field(static=True) @eqxi.doc_remove_args("closure_convert") def __init__( self, fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]], input_structure: PyTree[jax.ShapeDtypeStruct], tags: object | Iterable[object] = (), closure_convert: bool = True, ): """**Arguments:** - `fn`: a linear function. Should accept a PyTree of floating-point JAX arrays, and return a PyTree of floating-point JAX arrays. - `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the input to the function. (When later calling `self.mv(x)` then this should match the structure of `x`, i.e. `jax.eval_shape(lambda: x)`.) - `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong. """ # See matching comment in JacobianLinearOperator. input_structure = _inexact_structure(input_structure) if closure_convert: fn = eqx.filter_closure_convert(fn, input_structure) self.fn = fn self.input_structure = jtu.tree_flatten(input_structure) self.tags = _frozenset(tags) def mv(self, vector): return self.fn(vector) def as_matrix(self): return materialise(self).as_matrix() def transpose(self): if is_symmetric(self): return self transpose_fn = jax.linear_transpose(self.fn, self.in_structure()) def _transpose_fn(vector): (out,) = transpose_fn(vector) return out # Works because transpose_fn is a PyTree return FunctionLinearOperator( _transpose_fn, self.out_structure(), transpose_tags(self.tags) ) def in_structure(self): leaves, treedef = self.input_structure return jtu.tree_unflatten(treedef, leaves) def out_structure(self): return strip_weak_dtype( eqxi.cached_filter_eval_shape(self.fn, self.in_structure()) ) # `structure` must be static as with `JacobianLinearOperator` class IdentityLinearOperator(AbstractLinearOperator): """Represents the identity transformation `X -> X`, where each `x in X` is some PyTree of floating-point JAX arrays. """ input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) def __init__( self, input_structure: PyTree[jax.ShapeDtypeStruct], output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel, ): """**Arguments:** - `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the the input space. (When later calling `self.mv(x)` then this should match the structure of `x`, i.e. `jax.eval_shape(lambda: x)`.) - `output_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the structure of the the output space. If not passed then this defaults to the same as `input_structure`. If passed then it must have the same number of elements as `input_structure`, so that the operator is square. """ if output_structure is sentinel: output_structure = input_structure input_structure = _inexact_structure(input_structure) output_structure = _inexact_structure(output_structure) self.input_structure = jtu.tree_flatten(input_structure) self.output_structure = jtu.tree_flatten(output_structure) def mv(self, vector): if not eqx.tree_equal( strip_weak_dtype(jax.eval_shape(lambda: vector)), strip_weak_dtype(self.in_structure()), ): raise ValueError("Vector and operator structures do not match") elif self.input_structure == self.output_structure: return vector # fast-path for common special case else: # TODO(kidger): this could be done slightly more efficiently, by iterating # leaf-by-leaf. leaves = jtu.tree_leaves(vector) with jax.numpy_dtype_promotion("standard"): dtype = jnp.result_type(*leaves) vector = jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves]) out_size = self.out_size() if vector.size < out_size: vector = jnp.concatenate( [vector, jnp.zeros(out_size - vector.size, vector.dtype)] ) else: vector = vector[:out_size] leaves, treedef = jtu.tree_flatten(self.out_structure()) sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) split = jnp.split(vector, sizes) assert len(split) == len(leaves) with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore complex-to-real cast warning shaped = [ x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves) ] return jtu.tree_unflatten(treedef, shaped) def as_matrix(self): leaves = jtu.tree_leaves(self.in_structure()) with jax.numpy_dtype_promotion("standard"): dtype = ( default_floating_dtype() if len(leaves) == 0 else jnp.result_type(*leaves) ) return jnp.eye(self.out_size(), self.in_size(), dtype=dtype) def transpose(self): return IdentityLinearOperator(self.out_structure(), self.in_structure()) def in_structure(self): leaves, treedef = self.input_structure return jtu.tree_unflatten(treedef, leaves) def out_structure(self): leaves, treedef = self.output_structure return jtu.tree_unflatten(treedef, leaves) @property def tags(self): return frozenset() class TridiagonalLinearOperator(AbstractLinearOperator): """As [`lineax.MatrixLinearOperator`][], but for specifically a tridiagonal matrix. """ diagonal: Inexact[Array, " size"] lower_diagonal: Inexact[Array, " size-1"] upper_diagonal: Inexact[Array, " size-1"] def __init__( self, diagonal: Inexact[Array, " size"], lower_diagonal: Inexact[Array, " size-1"], upper_diagonal: Inexact[Array, " size-1"], ): """**Arguments:** - `diagonal`: A rank-one JAX array. This is the diagonal of the matrix. - `lower_diagonal`: A rank-one JAX array. This is the lower diagonal of the matrix. - `upper_diagonal`: A rank-one JAX array. This is the upper diagonal of the matrix. If `diagonal` has shape `(a,)` then `lower_diagonal` and `upper_diagonal` should both have shape `(a - 1,)`. """ self.diagonal = inexact_asarray(diagonal) self.lower_diagonal = inexact_asarray(lower_diagonal) self.upper_diagonal = inexact_asarray(upper_diagonal) (size,) = self.diagonal.shape if self.lower_diagonal.shape != (size - 1,): raise ValueError("lower_diagonal and diagonal do not have consistent size") if self.upper_diagonal.shape != (size - 1,): raise ValueError("upper_diagonal and diagonal do not have consistent size") def mv(self, vector): a = self.upper_diagonal * vector[1:] b = self.diagonal * vector c = self.lower_diagonal * vector[:-1] return b.at[:-1].add(a).at[1:].add(c) def as_matrix(self): (size,) = jnp.shape(self.diagonal) matrix = jnp.zeros((size, size), self.diagonal.dtype) arange = np.arange(size) matrix = matrix.at[arange, arange].set(self.diagonal) matrix = matrix.at[arange[1:], arange[:-1]].set(self.lower_diagonal) matrix = matrix.at[arange[:-1], arange[1:]].set(self.upper_diagonal) return matrix def transpose(self): return TridiagonalLinearOperator( self.diagonal, self.upper_diagonal, self.lower_diagonal ) def in_structure(self): (size,) = jnp.shape(self.diagonal) return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype) def out_structure(self): (size,) = jnp.shape(self.diagonal) return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype) class TaggedLinearOperator(AbstractLinearOperator): """Wraps another linear operator and specifies that it has certain tags, e.g. representing symmetry. !!! Example ```python # Some other operator. operator = lx.MatrixLinearOperator(some_jax_array) # Now symmetric! But the type system doesn't know this. sym_operator = operator + operator.T assert lx.is_symmetric(sym_operator) == False # We can declare that our operator has a particular property. sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag) assert lx.is_symmetric(sym_operator) == True ``` """ operator: AbstractLinearOperator tags: frozenset[object] = eqx.field(static=True) def __init__( self, operator: AbstractLinearOperator, tags: object | Iterable[object] ): """**Arguments:** - `operator`: some other linear operator to wrap. - `tags`: any tags indicating whether this operator has any particular properties, like symmetry or positive-definite-ness. Note that these properties are unchecked and you may get incorrect values elsewhere if these tags are wrong. """ self.operator = operator self.tags = _frozenset(tags) def mv(self, vector): return self.operator.mv(vector) def as_matrix(self): return self.operator.as_matrix() def transpose(self): return TaggedLinearOperator( self.operator.transpose(), transpose_tags(self.tags) ) def in_structure(self): return self.operator.in_structure() def out_structure(self): return self.operator.out_structure() # # All operators below here are private to lineax. # def _is_none(x): return x is None class TangentLinearOperator(AbstractLinearOperator): """Internal to lineax. Used to represent the tangent (jvp) computation with respect to the linear operator in a linear solve. """ primal: AbstractLinearOperator tangent: AbstractLinearOperator def __check_init__(self): assert type(self.primal) is type(self.tangent) # noqa: E721 def mv(self, vector): mv = lambda operator: operator.mv(vector) out, t_out = eqx.filter_jvp(mv, (self.primal,), (self.tangent,)) return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none) def as_matrix(self): as_matrix = lambda operator: operator.as_matrix() out, t_out = eqx.filter_jvp(as_matrix, (self.primal,), (self.tangent,)) return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none) def transpose(self): transpose = lambda operator: operator.transpose() primal_out, tangent_out = eqx.filter_jvp( transpose, (self.primal,), (self.tangent,) ) return TangentLinearOperator(primal_out, tangent_out) def in_structure(self): return self.primal.in_structure() def out_structure(self): return self.primal.out_structure() class AddLinearOperator(AbstractLinearOperator): """A linear operator formed by adding two other linear operators together. !!! Example ```python x = MatrixLinearOperator(...) y = MatrixLinearOperator(...) assert isinstance(x + y, AddLinearOperator) ``` """ operator1: AbstractLinearOperator operator2: AbstractLinearOperator def __check_init__(self): if self.operator1.in_structure() != self.operator2.in_structure(): raise ValueError("Incompatible linear operator structures") if self.operator1.out_structure() != self.operator2.out_structure(): raise ValueError("Incompatible linear operator structures") def mv(self, vector): maybe_sparse_op = _try_sparse_materialise(self) if maybe_sparse_op is not self: return maybe_sparse_op.mv(vector) mv1 = self.operator1.mv(vector) mv2 = self.operator2.mv(vector) return (mv1**ω + mv2**ω).ω def as_matrix(self): return self.operator1.as_matrix() + self.operator2.as_matrix() def transpose(self): return self.operator1.transpose() + self.operator2.transpose() def in_structure(self): return self.operator1.in_structure() def out_structure(self): return self.operator1.out_structure() class MulLinearOperator(AbstractLinearOperator): """A linear operator formed by multiplying a linear operator by a scalar. !!! Example ```python x = MatrixLinearOperator(...) y = 0.5 assert isinstance(x * y, MulLinearOperator) ``` """ operator: AbstractLinearOperator scalar: Scalar def mv(self, vector): return (self.operator.mv(vector) ** ω * self.scalar).ω def as_matrix(self): return self.operator.as_matrix() * self.scalar def transpose(self): return self.operator.transpose() * self.scalar def in_structure(self): return self.operator.in_structure() def out_structure(self): return self.operator.out_structure() # Not just `MulLinearOperator(..., -1)` for compatibility with # `jax_numpy_dtype_promotion=strict`. class NegLinearOperator(AbstractLinearOperator): """A linear operator formed by computing the negative of a linear operator. !!! Example ```python x = MatrixLinearOperator(...) assert isinstance(-x, NegLinearOperator) ``` """ operator: AbstractLinearOperator def mv(self, vector): return (-(self.operator.mv(vector) ** ω)).ω def as_matrix(self): return -self.operator.as_matrix() def transpose(self): return -self.operator.transpose() def in_structure(self): return self.operator.in_structure() def out_structure(self): return self.operator.out_structure() class DivLinearOperator(AbstractLinearOperator): """A linear operator formed by dividing a linear operator by a scalar. !!! Example ```python x = MatrixLinearOperator(...) y = 0.5 assert isinstance(x / y, DivLinearOperator) ``` """ operator: AbstractLinearOperator scalar: Scalar def mv(self, vector): with jax.numpy_dtype_promotion("standard"): return (self.operator.mv(vector) ** ω / self.scalar).ω def as_matrix(self): return self.operator.as_matrix() / self.scalar def transpose(self): return self.operator.transpose() / self.scalar def in_structure(self): return self.operator.in_structure() def out_structure(self): return self.operator.out_structure() class ComposedLinearOperator(AbstractLinearOperator): """A linear operator formed by composing (matrix-multiplying) two other linear operators together. !!! Example ```python x = MatrixLinearOperator(matrix1) y = MatrixLinearOperator(matrix2) composed = x @ y assert isinstance(composed, ComposedLinearOperator) assert jnp.allclose(composed.as_matrix(), matrix1 @ matrix2) ``` """ operator1: AbstractLinearOperator operator2: AbstractLinearOperator def __check_init__(self): if self.operator1.in_structure() != self.operator2.out_structure(): raise ValueError("Incompatible linear operator structures") def mv(self, vector): maybe_sparse_op = _try_sparse_materialise(self) if maybe_sparse_op is not self: return maybe_sparse_op.mv(vector) return self.operator1.mv(self.operator2.mv(vector)) def as_matrix(self): if isinstance(self.operator1, IdentityLinearOperator): return self.operator2.as_matrix() if isinstance(self.operator2, IdentityLinearOperator): return self.operator1.as_matrix() _, unravel = eqx.filter_eval_shape( jfu.ravel_pytree, self.operator1.in_structure() ) def mv_flat(v): out = self.operator1.mv(unravel(v)) return jfu.ravel_pytree(out)[0] return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix()) def transpose(self): return self.operator2.transpose() @ self.operator1.transpose() def in_structure(self): return self.operator2.in_structure() def out_structure(self): return self.operator1.out_structure() # # Operations on `AbstractLinearOperator`s. # These are done through `singledispatch` rather than as methods. # # If an end user ever wanted to add something analogous to # `diagonal: AbstractLinearOperator -> Array` # then of course they don't get to edit our base class and add overloads to all # subclasses. # They'd have to use `singledispatch` to get the desired behaviour. (Or maybe just # hardcode compatibility with only some `AbstractLinearOperator` subclasses, eurgh.) # So for consistency we do the same thing here, rather than adding privileged behaviour # for just the operations we happen to support. # # (Something something Julia something something orphan problem etc.) # def _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoReturn: msg = f"`lineax.{name}` has not been implemented for {type(operator)}" if type(operator).__module__.startswith("lineax"): assert False, msg + ". Please file a bug against Lineax." else: raise NotImplementedError(msg) # linearise @ft.singledispatch def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator: """Linearises a linear operator. This returns another linear operator. Mathematically speaking this is just the identity function. And indeed most linear operators will be returned unchanged. For specifically [`lineax.JacobianLinearOperator`][], then this will cache the primal pass, so that it does not need to be recomputed each time. That is, it uses some memory to improve speed. (This is the precisely same distinction as `jax.jvp` versus `jax.linearize`.) **Arguments:** - `operator`: a linear operator. **Returns:** Another linear operator. Mathematically it performs matrix-vector products (`operator.mv`) that produce the same results as the input `operator`. """ _default_not_implemented("linearise", operator) @linearise.register(MatrixLinearOperator) @linearise.register(PyTreeLinearOperator) @linearise.register(FunctionLinearOperator) @linearise.register(IdentityLinearOperator) @linearise.register(DiagonalLinearOperator) @linearise.register(TridiagonalLinearOperator) def _(operator): return operator @linearise.register(JacobianLinearOperator) def _(operator): fn = _NoAuxIn(operator.fn, operator.args) if operator.jac == "bwd": # For backward mode, use VJP + linear_transpose. # This works even with custom_vjp functions that don't support forward-mode AD. _, vjp_fn = jax.vjp(fn, operator.x) if is_symmetric(operator): # For symmetric: J = J.T, so vjp directly gives J @ v lin = _Unwrap(vjp_fn) else: # Transpose the VJP to get J @ v from J.T @ v lin = _Unwrap( jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure()) ) else: # "fwd" or None _, lin = jax.linearize(fn, operator.x) return FunctionLinearOperator(lin, operator.in_structure(), operator.tags) # materialise @ft.singledispatch def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: """Materialises a linear operator. This returns another linear operator. Mathematically speaking this is just the identity function. And indeed most linear operators will be returned unchanged. For specifically [`lineax.JacobianLinearOperator`][] and [`lineax.FunctionLinearOperator`][] then the linear operator is materialised in memory. That is, it becomes defined as a matrix (or pytree of arrays), rather than being defined only through its matrix-vector product ([`lineax.AbstractLinearOperator.mv`][]). Materialisation sometimes improves compile time or run time. It usually increases memory usage. For example: ```python large_function = ... operator = lx.FunctionLinearOperator(large_function, ...) # Option 1 out1 = operator.mv(vector1) # Traces and compiles `large_function` out2 = operator.mv(vector2) # Traces and compiles `large_function` again! out3 = operator.mv(vector3) # Traces and compiles `large_function` a third time! # All that compilation might lead to long compile times. # If `large_function` takes a long time to run, then this might also lead to long # run times. # Option 2 operator = lx.materialise(operator) # Traces and compiles `large_function` and # stores the result as a matrix. out1 = operator.mv(vector1) # Each of these just computes a matrix-vector product out2 = operator.mv(vector2) # against the stored matrix. out3 = operator.mv(vector3) # # Now, `large_function` is only compiled once, and only ran once. # However, storing the matrix might take a lot of memory, and the initial # computation may-or-may-not take a long time to run. ``` Generally speaking it is worth first setting up your problem without `lx.materialise`, and using it as an optional optimisation if you find that it helps your particular problem. **Arguments:** - `operator`: a linear operator. **Returns:** Another linear operator. Mathematically it performs matrix-vector products (`operator.mv`) that produce the same results as the input `operator`. """ _default_not_implemented("materialise", operator) def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: """Try to materialise to a sparse operator. Returns a (Tri)DiagonalLinearOperator if the operator is tagged as (tri)diagonal, otherwise returns the original operator unchanged. The resulting operator preserves the input/output structure of the original operator. """ if is_diagonal(operator): diag_flat = diagonal(operator) _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) diag_pytree = unravel(diag_flat) return DiagonalLinearOperator(diag_pytree) # TridiagonalLinearOperator only supports flat in and out structures if ( is_tridiagonal(operator) and isinstance(operator.in_structure(), jax.ShapeDtypeStruct) and isinstance(operator.out_structure(), jax.ShapeDtypeStruct) ): return TridiagonalLinearOperator(*tridiagonal(operator)) return operator @materialise.register(MatrixLinearOperator) @materialise.register(PyTreeLinearOperator) def _(operator): return _try_sparse_materialise(operator) @materialise.register(IdentityLinearOperator) @materialise.register(DiagonalLinearOperator) @materialise.register(TridiagonalLinearOperator) def _(operator): return operator @materialise.register(JacobianLinearOperator) def _(operator): maybe_sparse_op = _try_sparse_materialise(operator) if maybe_sparse_op is not operator: return maybe_sparse_op fn = _NoAuxIn(operator.fn, operator.args) jac = jacobian( fn, operator.in_size(), operator.out_size(), holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)), jac=operator.jac, )(operator.x) return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) @materialise.register(FunctionLinearOperator) def _(operator): maybe_sparse_op = _try_sparse_materialise(operator) if maybe_sparse_op is not operator: return maybe_sparse_op flat, unravel = strip_weak_dtype( eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) ) eye = jnp.eye(flat.size, dtype=flat.dtype) jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) def batch_unravel(x): assert x.ndim > 0 unravel_ = unravel for _ in range(x.ndim - 1): unravel_ = jax.vmap(unravel_) return unravel_(x) jac = jtu.tree_map(batch_unravel, jac) return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags) # diagonal @ft.singledispatch def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: """Extracts the diagonal from a linear operator, and returns a vector. **Arguments:** - `operator`: a linear operator. **Returns:** A rank-1 JAX array. (That is, it has shape `(a,)` for some integer `a`.) For most operators this is just `jnp.diag(operator.as_matrix())`. Some operators (e.g. [`lineax.DiagonalLinearOperator`][]) can have more efficient implementations. If you don't know what kind of operator you might have, then this function ensures that you always get the most efficient implementation. """ _default_not_implemented("diagonal", operator) def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array: """Extract the leaf from a pytree at the given keypath.""" for path, leaf in jtu.tree_leaves_with_path(pytree): if path == keypath: return leaf raise ValueError(f"Leaf not found at keypath {keypath}") @diagonal.register(MatrixLinearOperator) def _(operator): return jnp.diag(operator.as_matrix()) @diagonal.register(PyTreeLinearOperator) def _(operator): if is_diagonal(operator): def extract_diag(keypath, struct, subpytree): block = _leaf_from_keypath(subpytree, keypath) return jnp.diag(block.reshape(struct.size, struct.size)) diags = jtu.tree_map_with_path( extract_diag, operator.out_structure(), operator.pytree ) return jnp.concatenate(jtu.tree_leaves(diags)) else: return jnp.diag(operator.as_matrix()) @diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) def _(operator): if is_diagonal(operator): with jax.ensure_compile_time_eval(): basis = jtu.tree_map( lambda s: jnp.ones(s.shape, s.dtype), operator.in_structure() ) diag_as_pytree = operator.mv(basis) diag, _ = jfu.ravel_pytree(diag_as_pytree) return diag return diagonal(materialise(operator)) @diagonal.register(DiagonalLinearOperator) def _(operator): diagonal, _ = jfu.ravel_pytree(operator.diagonal) return diagonal @diagonal.register(IdentityLinearOperator) def _(operator): return jnp.ones(operator.in_size()) @diagonal.register(TridiagonalLinearOperator) def _(operator): return operator.diagonal # tridiagonal @ft.singledispatch def tridiagonal( operator: AbstractLinearOperator, ) -> tuple[Shaped[Array, " size"], Shaped[Array, " size-1"], Shaped[Array, " size-1"]]: """Extracts the diagonal, lower diagonal, and upper diagonal, from a linear operator. Returns three vectors. **Arguments:** - `operator`: a linear operator. **Returns:** A 3-tuple, consisting of: - The diagonal of the matrix, represented as a vector. - The lower diagonal of the matrix, represented as a vector. - The upper diagonal of the matrix, represented as a vector. If the diagonal has shape `(a,)` then the lower and upper diagonals will have shape `(a - 1,)`. For most operators these are computed by materialising the array and then extracting the relevant elements, e.g. getting the main diagonal via `jnp.diag(operator.as_matrix())`. Some operators (e.g. [`lineax.TridiagonalLinearOperator`][]) can have more efficient implementations. If you don't know what kind of operator you might have, then this function ensures that you always get the most efficient implementation. """ _default_not_implemented("tridiagonal", operator) @tridiagonal.register(MatrixLinearOperator) @tridiagonal.register(PyTreeLinearOperator) def _(operator): matrix = operator.as_matrix() assert matrix.ndim == 2 main_diagonal = jnp.diagonal(matrix, offset=0) upper_diagonal = jnp.diagonal(matrix, offset=1) lower_diagonal = jnp.diagonal(matrix, offset=-1) return main_diagonal, lower_diagonal, upper_diagonal @tridiagonal.register(JacobianLinearOperator) @tridiagonal.register(FunctionLinearOperator) def _(operator): if is_tridiagonal(operator): with jax.ensure_compile_time_eval(): flat, unravel = strip_weak_dtype( eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) ) basis = jnp.zeros((3, flat.size), dtype=flat.dtype) for i in range(3): basis = basis.at[i, i::3].set(1.0) basis = jax.vmap(unravel)(basis) coloring = jnp.arange(flat.size) % 3 compressed_as_pytree = jax.vmap(operator.mv)(basis) compressed_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])( compressed_as_pytree ) # unique_indices propagates through linear_transpose to set unique_indices=True # on the scatter, allowing assignment rather than accumulation. rows = jnp.arange(flat.size) diag = compressed_flat.at[coloring, rows].get( wrap_negative_indices=False, unique_indices=True ) lower_diag = compressed_flat.at[coloring[:-1], rows[1:]].get( wrap_negative_indices=False, unique_indices=True ) upper_diag = compressed_flat.at[coloring[1:], rows[:-1]].get( wrap_negative_indices=False, unique_indices=True ) return diag, lower_diag, upper_diag matrix = operator.as_matrix() assert matrix.ndim == 2 main_diagonal = jnp.diagonal(matrix, offset=0) upper_diagonal = jnp.diagonal(matrix, offset=1) lower_diagonal = jnp.diagonal(matrix, offset=-1) return main_diagonal, lower_diagonal, upper_diagonal @tridiagonal.register(DiagonalLinearOperator) def _(operator): diag = diagonal(operator) upper_diag = jnp.zeros(diag.size - 1) lower_diag = jnp.zeros(diag.size - 1) return diag, lower_diag, upper_diag @tridiagonal.register(IdentityLinearOperator) def _(operator): size = operator.in_size() main_diagonal = jnp.ones(size) off_diagonal = jnp.zeros(size - 1) return main_diagonal, off_diagonal, off_diagonal @tridiagonal.register(TridiagonalLinearOperator) def _(operator): return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal # is_symmetric @ft.singledispatch def is_symmetric(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as symmetric. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_symmetric", operator) def _has_real_dtype(operator) -> bool: """Check if all dtypes in an operator's structure are real (not complex).""" leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure())) dtype = jnp.result_type(*leaves) if jnp.issubdtype(dtype, jnp.complexfloating): return False elif jnp.issubdtype(dtype, jnp.floating): return True else: assert False, ( "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood." ) @is_symmetric.register(MatrixLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @is_symmetric.register(FunctionLinearOperator) def _(operator): # Symmetric (A = A^T) if explicitly tagged symmetric or diagonal if symmetric_tag in operator.tags or diagonal_tag in operator.tags: return True # PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian if ( positive_semidefinite_tag in operator.tags or negative_semidefinite_tag in operator.tags ): return _has_real_dtype(operator) return False @is_symmetric.register(IdentityLinearOperator) def _(operator): return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True @is_symmetric.register(DiagonalLinearOperator) def _(operator): return True @is_symmetric.register(TridiagonalLinearOperator) def _(operator): return False # is_diagonal @ft.singledispatch def is_diagonal(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as diagonal. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_diagonal", operator) @is_diagonal.register(MatrixLinearOperator) @is_diagonal.register(PyTreeLinearOperator) @is_diagonal.register(JacobianLinearOperator) @is_diagonal.register(FunctionLinearOperator) def _(operator): return diagonal_tag in operator.tags or ( operator.in_size() == 1 and operator.out_size() == 1 ) @is_diagonal.register(IdentityLinearOperator) @is_diagonal.register(DiagonalLinearOperator) def _(operator): return True @is_diagonal.register(TridiagonalLinearOperator) def _(operator): return operator.in_size() == 1 # is_tridiagonal @ft.singledispatch def is_tridiagonal(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as tridiagonal. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_tridiagonal", operator) @is_tridiagonal.register(MatrixLinearOperator) @is_tridiagonal.register(PyTreeLinearOperator) @is_tridiagonal.register(JacobianLinearOperator) @is_tridiagonal.register(FunctionLinearOperator) def _(operator): return tridiagonal_tag in operator.tags or diagonal_tag in operator.tags @is_tridiagonal.register(IdentityLinearOperator) @is_tridiagonal.register(DiagonalLinearOperator) @is_tridiagonal.register(TridiagonalLinearOperator) def _(operator): return True # has_unit_diagonal @ft.singledispatch def has_unit_diagonal(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as having unit diagonal. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("has_unit_diagonal", operator) @has_unit_diagonal.register(MatrixLinearOperator) @has_unit_diagonal.register(PyTreeLinearOperator) @has_unit_diagonal.register(JacobianLinearOperator) @has_unit_diagonal.register(FunctionLinearOperator) def _(operator): return unit_diagonal_tag in operator.tags @has_unit_diagonal.register(IdentityLinearOperator) def _(operator): return True @has_unit_diagonal.register(DiagonalLinearOperator) @has_unit_diagonal.register(TridiagonalLinearOperator) def _(operator): # TODO: refine this return False # is_lower_triangular @ft.singledispatch def is_lower_triangular(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as lower triangular. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_lower_triangular", operator) @is_lower_triangular.register(MatrixLinearOperator) @is_lower_triangular.register(PyTreeLinearOperator) @is_lower_triangular.register(JacobianLinearOperator) @is_lower_triangular.register(FunctionLinearOperator) def _(operator): return lower_triangular_tag in operator.tags @is_lower_triangular.register(IdentityLinearOperator) @is_lower_triangular.register(DiagonalLinearOperator) def _(operator): return True @is_lower_triangular.register(TridiagonalLinearOperator) def _(operator): return False # is_upper_triangular @ft.singledispatch def is_upper_triangular(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as upper triangular. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_upper_triangular", operator) @is_upper_triangular.register(MatrixLinearOperator) @is_upper_triangular.register(PyTreeLinearOperator) @is_upper_triangular.register(JacobianLinearOperator) @is_upper_triangular.register(FunctionLinearOperator) def _(operator): return upper_triangular_tag in operator.tags @is_upper_triangular.register(IdentityLinearOperator) @is_upper_triangular.register(DiagonalLinearOperator) def _(operator): return True @is_upper_triangular.register(TridiagonalLinearOperator) def _(operator): return False # is_positive_semidefinite @ft.singledispatch def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as positive semidefinite. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_positive_semidefinite", operator) @is_positive_semidefinite.register(MatrixLinearOperator) @is_positive_semidefinite.register(PyTreeLinearOperator) @is_positive_semidefinite.register(JacobianLinearOperator) @is_positive_semidefinite.register(FunctionLinearOperator) def _(operator): return positive_semidefinite_tag in operator.tags @is_positive_semidefinite.register(IdentityLinearOperator) def _(operator): return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True @is_positive_semidefinite.register(DiagonalLinearOperator) @is_positive_semidefinite.register(TridiagonalLinearOperator) def _(operator): # TODO: refine this return False # is_negative_semidefinite @ft.singledispatch def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as negative semidefinite. See [the documentation on linear operator tags](../api/tags.md) for more information. **Arguments:** - `operator`: a linear operator. **Returns:** Either `True` or `False.` """ _default_not_implemented("is_negative_semidefinite", operator) @is_negative_semidefinite.register(MatrixLinearOperator) @is_negative_semidefinite.register(PyTreeLinearOperator) @is_negative_semidefinite.register(JacobianLinearOperator) @is_negative_semidefinite.register(FunctionLinearOperator) def _(operator): return negative_semidefinite_tag in operator.tags @is_negative_semidefinite.register(IdentityLinearOperator) def _(operator): return False @is_negative_semidefinite.register(DiagonalLinearOperator) @is_negative_semidefinite.register(TridiagonalLinearOperator) def _(operator): # TODO: refine this return False # ops for wrapper operators @linearise.register(TaggedLinearOperator) def _(operator): return TaggedLinearOperator(linearise(operator.operator), operator.tags) @materialise.register(TaggedLinearOperator) def _(operator): return TaggedLinearOperator(materialise(operator.operator), operator.tags) @diagonal.register(TaggedLinearOperator) def _(operator): return diagonal(operator.operator) @tridiagonal.register(TaggedLinearOperator) def _(operator): return tridiagonal(operator.operator) for transform in (linearise, materialise, diagonal): @transform.register(MulLinearOperator) def _(operator, transform=transform): return transform(operator.operator) * operator.scalar @transform.register(NegLinearOperator) # pyright: ignore def _(operator, transform=transform): return -transform(operator.operator) @transform.register(DivLinearOperator) def _(operator, transform=transform): return transform(operator.operator) / operator.scalar for transform in (linearise, diagonal): @transform.register(AddLinearOperator) # pyright: ignore def _(operator, transform=transform): return transform(operator.operator1) + transform(operator.operator2) # pyright: ignore @materialise.register(AddLinearOperator) def _(operator): maybe_sparse_op = _try_sparse_materialise(operator) if maybe_sparse_op is not operator: return maybe_sparse_op return materialise(operator.operator1) + materialise(operator.operator2) @linearise.register(TangentLinearOperator) def _(operator): primal_out, tangent_out = eqx.filter_jvp( linearise, (operator.primal,), (operator.tangent,) ) return TangentLinearOperator(primal_out, tangent_out) @materialise.register(TangentLinearOperator) def _(operator): primal_out, tangent_out = eqx.filter_jvp( materialise, (operator.primal,), (operator.tangent,) ) return TangentLinearOperator(primal_out, tangent_out) @diagonal.register(TangentLinearOperator) def _(operator): # Should be unreachable: TangentLinearOperator is used for a narrow set of # operations only (mv; transpose) inside the JVP rule linear_solve_p. raise NotImplementedError( "Please open a GitHub issue: https://github.com/google/lineax" ) @tridiagonal.register(TangentLinearOperator) def _(operator): # Should be unreachable: TangentLinearOperator is used for a narrow set of # operations only (mv; transpose) inside the JVP rule linear_solve_p. raise NotImplementedError( "Please open a GitHub issue: https://github.com/google/lineax" ) @tridiagonal.register(AddLinearOperator) def _(operator): (diag1, lower1, upper1) = tridiagonal(operator.operator1) (diag2, lower2, upper2) = tridiagonal(operator.operator2) return (diag1 + diag2, lower1 + lower2, upper1 + upper2) @tridiagonal.register(MulLinearOperator) def _(operator): (diag, lower, upper) = tridiagonal(operator.operator) return (diag * operator.scalar, lower * operator.scalar, upper * operator.scalar) @tridiagonal.register(NegLinearOperator) def _(operator): (diag, lower, upper) = tridiagonal(operator.operator) return (-diag, -lower, -upper) @tridiagonal.register(DivLinearOperator) def _(operator): (diag, lower, upper) = tridiagonal(operator.operator) return (diag / operator.scalar, lower / operator.scalar, upper / operator.scalar) @linearise.register(ComposedLinearOperator) def _(operator): return linearise(operator.operator1) @ linearise(operator.operator2) @materialise.register(ComposedLinearOperator) def _(operator): if isinstance(operator.operator1, IdentityLinearOperator): return materialise(operator.operator2) if isinstance(operator.operator2, IdentityLinearOperator): return materialise(operator.operator1) maybe_sparse_op = _try_sparse_materialise(operator) if maybe_sparse_op is not operator: return maybe_sparse_op return materialise(operator.operator1) @ materialise(operator.operator2) @diagonal.register(ComposedLinearOperator) def _(operator): if is_diagonal(operator.operator1) and is_diagonal(operator.operator2): return diagonal(operator.operator1) * diagonal(operator.operator2) return jnp.diag(operator.as_matrix()) @tridiagonal.register(ComposedLinearOperator) def _(operator): if is_diagonal(operator.operator1) and is_tridiagonal(operator.operator2): d = diagonal(operator.operator1) main, lower, upper = tridiagonal(operator.operator2) # D @ T scales rows: row i multiplied by d[i] return d * main, d[1:] * lower, d[:-1] * upper if is_diagonal(operator.operator2) and is_tridiagonal(operator.operator1): d = diagonal(operator.operator2) main, lower, upper = tridiagonal(operator.operator1) # T @ D scales columns: column j multiplied by d[j] return d * main, d[:-1] * lower, d[1:] * upper matrix = operator.as_matrix() assert matrix.ndim == 2 main_diagonal = jnp.diagonal(matrix, offset=0) upper_diagonal = jnp.diagonal(matrix, offset=1) lower_diagonal = jnp.diagonal(matrix, offset=-1) return main_diagonal, lower_diagonal, upper_diagonal for check in ( is_symmetric, is_diagonal, has_unit_diagonal, is_lower_triangular, is_upper_triangular, is_tridiagonal, is_positive_semidefinite, is_negative_semidefinite, ): @check.register(TangentLinearOperator) def _(operator, check=check): return check(operator.primal) # Scaling/negating preserves these structural properties for check in ( is_symmetric, is_diagonal, is_lower_triangular, is_upper_triangular, is_tridiagonal, ): @check.register(MulLinearOperator) @check.register(NegLinearOperator) @check.register(DivLinearOperator) def _(operator, check=check): return check(operator.operator) # has_unit_diagonal is NOT preserved by scaling or negation @has_unit_diagonal.register(MulLinearOperator) @has_unit_diagonal.register(NegLinearOperator) @has_unit_diagonal.register(DivLinearOperator) def _(operator): return False class _ScalarSign(enum.Enum): positive = enum.auto() negative = enum.auto() zero = enum.auto() unknown = enum.auto() def _scalar_sign(scalar) -> _ScalarSign: """Returns the sign of a scalar, or unknown for JAX tracers.""" if isinstance(scalar, (int, float, np.ndarray, np.generic)): scalar = float(scalar) if scalar > 0: return _ScalarSign.positive elif scalar < 0: return _ScalarSign.negative else: return _ScalarSign.zero else: return _ScalarSign.unknown # PSD/NSD for MulLinearOperator: depends on sign of scalar # Zero scalar gives zero matrix which is both PSD and NSD @is_positive_semidefinite.register(MulLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) if sign is _ScalarSign.positive: return is_positive_semidefinite(operator.operator) elif sign is _ScalarSign.negative: return is_negative_semidefinite(operator.operator) elif sign is _ScalarSign.zero: return True # zero matrix is PSD return False @is_negative_semidefinite.register(MulLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) if sign is _ScalarSign.positive: return is_negative_semidefinite(operator.operator) elif sign is _ScalarSign.negative: return is_positive_semidefinite(operator.operator) elif sign is _ScalarSign.zero: return True # zero matrix is NSD return False # PSD/NSD for DivLinearOperator: depends on sign of scalar # Zero scalar is division by zero - return False (conservative) @is_positive_semidefinite.register(DivLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) if sign is _ScalarSign.positive: return is_positive_semidefinite(operator.operator) elif sign is _ScalarSign.negative: return is_negative_semidefinite(operator.operator) return False @is_negative_semidefinite.register(DivLinearOperator) def _(operator): sign = _scalar_sign(operator.scalar) if sign is _ScalarSign.positive: return is_negative_semidefinite(operator.operator) elif sign is _ScalarSign.negative: return is_positive_semidefinite(operator.operator) return False # PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD @is_positive_semidefinite.register(NegLinearOperator) def _(operator): return is_negative_semidefinite(operator.operator) @is_negative_semidefinite.register(NegLinearOperator) def _(operator): return is_positive_semidefinite(operator.operator) for check, tag in ( (is_symmetric, symmetric_tag), (is_diagonal, diagonal_tag), (has_unit_diagonal, unit_diagonal_tag), (is_lower_triangular, lower_triangular_tag), (is_upper_triangular, upper_triangular_tag), (is_positive_semidefinite, positive_semidefinite_tag), (is_negative_semidefinite, negative_semidefinite_tag), (is_tridiagonal, tridiagonal_tag), ): @check.register(TaggedLinearOperator) def _(operator, check=check, tag=tag): return (tag in operator.tags) or check(operator.operator) for check in ( is_symmetric, is_diagonal, is_lower_triangular, is_upper_triangular, is_positive_semidefinite, is_negative_semidefinite, is_tridiagonal, ): @check.register(AddLinearOperator) def _(operator, check=check): return check(operator.operator1) and check(operator.operator2) @has_unit_diagonal.register(AddLinearOperator) def _(operator): return False # These properties ARE preserved under composition for check in ( is_diagonal, is_lower_triangular, is_upper_triangular, ): @check.register(ComposedLinearOperator) def _(operator, check=check): return check(operator.operator1) and check(operator.operator2) # is_symmetric: A@B is symmetric only if A and B commute. Diagonal matrices commute. @is_symmetric.register(ComposedLinearOperator) def _(operator): return is_diagonal(operator.operator1) and is_diagonal(operator.operator2) # is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but # tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal @is_tridiagonal.register(ComposedLinearOperator) def _(operator): if is_diagonal(operator.operator1): return is_tridiagonal(operator.operator2) if is_diagonal(operator.operator2): return is_tridiagonal(operator.operator1) return False # PSD/NSD: not preserved under composition in general. @is_positive_semidefinite.register(ComposedLinearOperator) @is_negative_semidefinite.register(ComposedLinearOperator) def _(operator): return False @has_unit_diagonal.register(ComposedLinearOperator) def _(operator): a = is_diagonal(operator) b = is_lower_triangular(operator) c = is_upper_triangular(operator) d = has_unit_diagonal(operator.operator1) e = has_unit_diagonal(operator.operator2) return (a or b or c) and d and e # conj @ft.singledispatch def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator: """Elementwise conjugate of a linear operator. This returns another linear operator. **Arguments:** - `operator`: a linear operator. **Returns:** Another linear operator. """ _default_not_implemented("conj", operator) @conj.register(MatrixLinearOperator) def _(operator): return MatrixLinearOperator(operator.matrix.conj(), operator.tags) @conj.register(PyTreeLinearOperator) def _(operator): pytree_conj = jtu.tree_map(lambda x: x.conj(), operator.pytree) return PyTreeLinearOperator(pytree_conj, operator.out_structure(), operator.tags) @conj.register(DiagonalLinearOperator) def _(operator): diagonal_conj = jtu.tree_map(lambda x: x.conj(), operator.diagonal) return DiagonalLinearOperator(diagonal_conj) @conj.register(JacobianLinearOperator) def _(operator): return conj(linearise(operator)) @conj.register(FunctionLinearOperator) def _(operator): return FunctionLinearOperator( lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))), operator.in_structure(), operator.tags, ) @conj.register(IdentityLinearOperator) def _(operator): return operator @conj.register(TridiagonalLinearOperator) def _(operator): return TridiagonalLinearOperator( operator.diagonal.conj(), operator.lower_diagonal.conj(), operator.upper_diagonal.conj(), ) @conj.register(TaggedLinearOperator) def _(operator): return TaggedLinearOperator(conj(operator.operator), operator.tags) @conj.register(TangentLinearOperator) def _(operator): c = lambda operator: conj(operator) primal_out, tangent_out = eqx.filter_jvp(c, (operator.primal,), (operator.tangent,)) return TangentLinearOperator(primal_out, tangent_out) @conj.register(AddLinearOperator) def _(operator): return conj(operator.operator1) + conj(operator.operator2) @conj.register(MulLinearOperator) def _(operator): return conj(operator.operator) * operator.scalar.conj() @conj.register(NegLinearOperator) def _(operator): return -conj(operator.operator) @conj.register(DivLinearOperator) def _(operator): return conj(operator.operator) / operator.scalar.conj() @conj.register(ComposedLinearOperator) def _(operator): return conj(operator.operator1) @ conj(operator.operator2) ================================================ FILE: lineax/_solution.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import equinox as eqx import equinox.internal as eqxi from jaxtyping import Array, ArrayLike, PyTree _singular_msg = """ A linear solver returned non-finite (NaN or inf) output. This usually means that an operator was not well-posed, and that its solver does not support this. If you are trying solve a linear least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular. If you *were* expecting this solver to work with this operator, then it may be because: (a) the operator is singular, and your code has a bug; or (b) the operator was nearly singular (i.e. it had a high condition number: `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from numerical instability issues; or (c) the operator is declared to exhibit a certain property (e.g. positive definiteness) that is does not actually satisfy. """.strip() _nonfinite_msg = """ A linear solver received non-finite (NaN or inf) input and cannot determine a solution. This means that you have a bug upstream of Lineax and should check the inputs to `lineax.linear_solve` for non-finite values. """.strip() class RESULTS(eqxi.Enumeration): successful = "" max_steps_reached = ( "The maximum number of solver steps was reached. Try increasing `max_steps`." ) singular = _singular_msg breakdown = ( "A form of iterative breakdown has occured in a linear solve. " "Try using a different solver for this problem or increase `restart` " "if using GMRES." ) stagnation = ( "A stagnation in an iterative linear solve has occurred. Try increasing " "`stagnation_iters` or `restart`." ) conlim = "Condition number of A seems to be larger than `conlim`." nonfinite_input = _nonfinite_msg class Solution(eqx.Module): """The solution to a linear solve. **Attributes:** - `value`: The solution to the solve. - `result`: An integer representing whether the solve was successful or not. This can be converted into a human-readable error message via `lineax.RESULTS[result]`. - `stats`: Statistics about the solver, e.g. the number of steps that were required. - `state`: The internal state of the solver. The meaning of this is specific to each solver. """ value: PyTree[Array] result: RESULTS stats: dict[str, PyTree[ArrayLike]] state: PyTree[Any] ================================================ FILE: lineax/_solve.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import functools as ft from typing import Any, Generic, TypeAlias, TypeVar import equinox as eqx import equinox.internal as eqxi import jax import jax.core import jax.interpreters.ad as ad import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jax._src.ad_util import stop_gradient_p from jaxtyping import Array, ArrayLike, PyTree from ._custom_types import sentinel from ._misc import inexact_asarray, strip_weak_dtype from ._operator import ( AbstractLinearOperator, conj, FunctionLinearOperator, has_unit_diagonal, IdentityLinearOperator, is_diagonal, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, is_symmetric, is_tridiagonal, is_upper_triangular, linearise, TangentLinearOperator, ) from ._solution import RESULTS, Solution from ._tags import ( diagonal_tag, lower_triangular_tag, negative_semidefinite_tag, positive_semidefinite_tag, symmetric_tag, unit_diagonal_tag, upper_triangular_tag, ) # # _linear_solve_p # def _to_shapedarray(x): if isinstance(x, jax.ShapeDtypeStruct): return jax.core.ShapedArray(x.shape, x.dtype) else: return x def _to_struct(x): if isinstance(x, jax.core.ShapedArray): return jax.ShapeDtypeStruct(x.shape, x.dtype) elif isinstance(x, jax.core.AbstractValue): raise NotImplementedError( "`lineax.linear_solve` only supports working with JAX arrays; not " f"other abstract values. Got abstract value {x}." ) else: return x def _assert_false(x): assert False def _is_none(x): return x is None def _sum(*args): return sum(args) def _linear_solve_impl(_, state, vector, options, solver, throw, *, check_closure): out = solver.compute(state, vector, options) if check_closure: out = eqxi.nontraceable( out, name="lineax.linear_solve with respect to a closed-over value" ) solution, result, stats = out has_nonfinite_output = jnp.any( jnp.stack( [jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(solution)] ) ) result = RESULTS.where( (result == RESULTS.successful) & has_nonfinite_output, RESULTS.singular, result, ) has_nonfinite_input = jnp.any( jnp.stack( [jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(vector)] ) ) result = RESULTS.where( (result == RESULTS.singular) & has_nonfinite_input, RESULTS.nonfinite_input, result, ) if throw: solution, result, stats = result.error_if( (solution, result, stats), result != RESULTS.successful, ) return solution, result, stats @eqxi.filter_primitive_def def _linear_solve_abstract_eval(operator, state, vector, options, solver, throw): state, vector, options, solver = jtu.tree_map( _to_struct, (state, vector, options, solver) ) out = eqx.filter_eval_shape( _linear_solve_impl, operator, state, vector, options, solver, throw, check_closure=False, ) out = jtu.tree_map(_to_shapedarray, out) return out @eqxi.filter_primitive_jvp def _linear_solve_jvp(primals, tangents): operator, state, vector, options, solver, throw = primals t_operator, t_state, t_vector, t_options, t_solver, t_throw = tangents jtu.tree_map(_assert_false, (t_state, t_options, t_solver, t_throw)) del t_state, t_options, t_solver, t_throw # Note that we pass throw=True unconditionally to all the tangent solves, as there # is nowhere we can pipe their error to. # This is the primal solve so we can respect the original `throw`. solution, result, stats = eqxi.filter_primitive_bind( linear_solve_p, operator, state, vector, options, solver, throw ) # # Consider the primal problem of linearly solving for x in Ax=b. # Let ^ denote pseudoinverses, ᵀ denote transposes, and ' denote tangents. # The linear_solve routine returns specifically the pseudoinverse solution, i.e. # # x = A^b # # Therefore x' = A^'b + A^b' # # Now A^' = -A^A'A^ + A^A^ᵀAᵀ'(I - AA^) + (I - A^A)Aᵀ'A^ᵀA^ # # (Source: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) # # This results in: # # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) - Ay + b') + y # # where # # y = Aᵀ'A^ᵀx # # note that if A has linearly independent columns, then the y - A^Ay # term disappears and gives # # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) + b') # # and if A has linearly independent rows, then the A^A^ᵀAᵀ'(b - Ax) term # disappears giving: # # x' = A^(-A'x - Ay + b') + y # # if A has linearly independent rows and columns, then A is nonsingular and # # x' = A^(-A'x + b') vecs = [] sols = [] if any(t is not None for t in jtu.tree_leaves(t_vector, is_leaf=_is_none)): # b' term vecs.append( jtu.tree_map(eqxi.materialise_zeros, vector, t_vector, is_leaf=_is_none) ) if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): t_operator = TangentLinearOperator(operator, t_operator) t_operator = linearise(t_operator) # optimise for matvecs # -A'x term vec = (-(t_operator.mv(solution) ** ω)).ω vecs.append(vec) rows, columns = operator.out_size(), operator.in_size() assume_independent_rows = solver.assume_full_rank() and rows <= columns assume_independent_columns = solver.assume_full_rank() and columns <= rows if not assume_independent_rows or not assume_independent_columns: operator_conj_transpose = conj(operator).transpose() t_operator_conj_transpose = conj(t_operator).transpose() state_conj, options_conj = solver.conj(state, options) state_conj_transpose, options_conj_transpose = solver.transpose( state_conj, options_conj ) if not assume_independent_rows: lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω tmp = t_operator_conj_transpose.mv(lst_sqr_diff) # pyright: ignore tmp, _, _ = eqxi.filter_primitive_bind( linear_solve_p, operator_conj_transpose, # pyright: ignore state_conj_transpose, # pyright: ignore tmp, options_conj_transpose, # pyright: ignore solver, True, ) vecs.append(tmp) if not assume_independent_columns: tmp1, _, _ = eqxi.filter_primitive_bind( linear_solve_p, operator_conj_transpose, # pyright: ignore state_conj_transpose, # pyright:ignore solution, options_conj_transpose, # pyright: ignore solver, True, ) tmp2 = t_operator_conj_transpose.mv(tmp1) # pyright: ignore # tmp2 is the y term tmp3 = operator.mv(tmp2) tmp4 = (-(tmp3**ω)).ω # tmp4 is the Ay term vecs.append(tmp4) sols.append(tmp2) vecs = jtu.tree_map(_sum, *vecs) # the A^ term at the very beginning sol, _, _ = eqxi.filter_primitive_bind( linear_solve_p, operator, state, vecs, options, solver, True ) sols.append(sol) t_solution = jtu.tree_map(_sum, *sols) out = solution, result, stats t_out = ( t_solution, jtu.tree_map(lambda _: None, result), jtu.tree_map(lambda _: None, stats), ) return out, t_out def _is_undefined(x): return isinstance(x, ad.UndefinedPrimal) def _assert_defined(x): assert not _is_undefined(x) def _keep_undefined(v, ct): if _is_undefined(v): return ct else: return None @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _linear_solve_transpose(inputs, cts_out): cts_solution, _, _ = cts_out operator, state, vector, options, solver, _ = inputs jtu.tree_map( _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined ) cts_solution = jtu.tree_map( ft.partial(eqxi.materialise_zeros, allow_struct=True), operator.in_structure(), cts_solution, ) operator_transpose = operator.transpose() state_transpose, options_transpose = solver.transpose(state, options) cts_vector, _, _ = eqxi.filter_primitive_bind( linear_solve_p, operator_transpose, state_transpose, cts_solution, options_transpose, solver, True, # throw=True unconditionally: nowhere to pipe result to. ) cts_vector = jtu.tree_map( _keep_undefined, vector, cts_vector, is_leaf=_is_undefined ) operator_none = jtu.tree_map(lambda _: None, operator) state_none = jtu.tree_map(lambda _: None, state) options_none = jtu.tree_map(lambda _: None, options) solver_none = jtu.tree_map(lambda _: None, solver) throw_none = None return operator_none, state_none, cts_vector, options_none, solver_none, throw_none # Call with `check_closure=False` so that the autocreated vmap rule works. linear_solve_p = eqxi.create_vprim( "linear_solve", eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=False)), _linear_solve_abstract_eval, _linear_solve_jvp, _linear_solve_transpose, ) # Then rebind so that the impl rule catches leaked-in tracers. linear_solve_p.def_impl( eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=True)) ) eqxi.register_impl_finalisation(linear_solve_p) # # linear_solve # _SolverState = TypeVar("_SolverState") class AbstractLinearSolver(eqx.Module, Generic[_SolverState]): """Abstract base class for all linear solvers.""" @abc.abstractmethod def init( self, operator: AbstractLinearOperator, options: dict[str, Any] ) -> _SolverState: """Do any initial computation on just the `operator`. For example, an LU solver would compute the LU decomposition of the operator (and this does not require knowing the vector yet). It is common to need to solve the linear system `Ax=b` multiple times in succession, with the same operator `A` and multiple vectors `b`. This method improves efficiency by making it possible to re-use the computation performed on just the operator. !!! Example ```python operator = lx.MatrixLinearOperator(...) vector1 = ... vector2 = ... solver = lx.LU() state = solver.init(operator, options={}) solution1 = lx.linear_solve(operator, vector1, solver, state=state) solution2 = lx.linear_solve(operator, vector2, solver, state=state) ``` **Arguments:** - `operator`: a linear operator. - `options`: a dictionary of any extra options that the solver may wish to accept. **Returns:** A PyTree of arbitrary Python objects. """ @abc.abstractmethod def compute( self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: """Solves a linear system. **Arguments:** - `state`: as returned from [`lineax.AbstractLinearSolver.init`][]. - `vector`: the vector to solve against. - `options`: a dictionary of any extra options that the solver may wish to accept. For example, [`lineax.CG`][] accepts a `preconditioner` option. **Returns:** A 3-tuple of: - The solution to the linear system. - An integer indicating the success or failure of the solve. This is an integer which may be converted to a human-readable error message via `lx.RESULTS[...]`. - A dictionary of an extra statistics about the solve, e.g. the number of steps taken. """ @abc.abstractmethod def transpose( self, state: _SolverState, options: dict[str, Any] ) -> tuple[_SolverState, dict[str, Any]]: """Transposes the result of [`lineax.AbstractLinearSolver.init`][]. That is, it should be the case that ```python state_transpose, _ = solver.transpose(solver.init(operator, options), options) state_transpose2 = solver.init(operator.T, options) ``` must be identical to each other. It is relatively common (in particular when differentiating through a linear solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it possible to avoid computing both `solver.init(operator)` and `solver.init(operator.T)` if one can be cheaply computed from the other. **Arguments:** - `state`: as returned from `solver.init`. - `options`: any extra options that were passed to `solve.init`. **Returns:** A 2-tuple of: - The state of the transposed operator. - The options for the transposed operator. """ @abc.abstractmethod def conj( self, state: _SolverState, options: dict[str, Any] ) -> tuple[_SolverState, dict[str, Any]]: """Conjugate the result of [`lineax.AbstractLinearSolver.init`][]. That is, it should be the case that ```python state_conj, _ = solver.conj(solver.init(operator, options), options) state_conj2 = solver.init(conj(operator), options) ``` must be identical to each other. **Arguments:** - `state`: as returned from `solver.init`. - `options`: any extra options that were passed to `solve.init`. **Returns:** A 2-tuple of: - The state of the conjugated operator. - The options for the conjugated operator. """ @abc.abstractmethod def assume_full_rank(self) -> bool: """Does this solver assume that all operators are full rank? When `False`, a more expensive backward pass is needed to account for the extra generality. In a custom linear solver, it is always safe to return False. **Arguments:** Nothing. **Returns:** Either `True` or `False`. """ _qr_token = eqxi.str2jax("qr_token") _diagonal_token = eqxi.str2jax("diagonal_token") _well_posed_diagonal_token = eqxi.str2jax("well_posed_diagonal_token") _tridiagonal_token = eqxi.str2jax("tridiagonal_token") _triangular_token = eqxi.str2jax("triangular_token") _cholesky_token = eqxi.str2jax("cholesky_token") _lu_token = eqxi.str2jax("lu_token") _svd_token = eqxi.str2jax("svd_token") # Ugly delayed import because we have the dependency chain # linear_solve -> AutoLinearSolver -> {Cholesky,...} -> AbstractLinearSolver # but we want linear_solver and AbstractLinearSolver in the same file. def _lookup(token) -> AbstractLinearSolver: from . import _solver # pyright doesn't know that these keys are hashable _lookup_dict = { _qr_token: _solver.QR(), # pyright: ignore _diagonal_token: _solver.Diagonal(), # pyright: ignore _well_posed_diagonal_token: _solver.Diagonal( # pyright: ignore well_posed=True ), _tridiagonal_token: _solver.Tridiagonal(), # pyright: ignore _triangular_token: _solver.Triangular(), # pyright: ignore _cholesky_token: _solver.Cholesky(), # pyright: ignore _lu_token: _solver.LU(), # pyright: ignore _svd_token: _solver.SVD(), # pyright: ignore } return _lookup_dict[token] _AutoLinearSolverState: TypeAlias = tuple[Any, Any] class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): """Automatically determines a good linear solver based on the structure of the operator. - If `well_posed=True`: - If the operator is diagonal, then use [`lineax.Diagonal`][]. - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative (semi-)definite, then use [`lineax.Cholesky`][]. - Else use [`lineax.LU`][]. This is a good choice if you want to be certain that an error is raised for ill-posed systems. - If `well_posed=False`: - If the operator is diagonal, then use [`lineax.Diagonal`][]. - Else use [`lineax.SVD`][]. This is a good choice if you want to be certain that you can handle ill-posed systems. - If `well_posed=None`: - If the operator is non-square, then use [`lineax.QR`][]. - If the operator is diagonal, then use [`lineax.Diagonal`][]. - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative (semi-)definite, then use [`lineax.Cholesky`][]. - Else, use [`lineax.LU`][]. This is a good choice if your primary concern is computational efficiency. It will handle ill-posed systems as long as it is not computationally expensive to do so. """ well_posed: bool | None def _select_solver(self, operator: AbstractLinearOperator): if self.well_posed is True: if operator.in_size() != operator.out_size(): raise ValueError( "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square " "operator. If you are trying solve a least-squares problem then " "you should pass `solver=AutoLinearSolver(well_posed=False)`. By " "default `lineax.linear_solve` assumes that the operator is " "square and nonsingular." ) if is_diagonal(operator): token = _well_posed_diagonal_token elif is_tridiagonal(operator): token = _tridiagonal_token elif is_lower_triangular(operator) or is_upper_triangular(operator): token = _triangular_token elif is_positive_semidefinite(operator) or is_negative_semidefinite( operator ): token = _cholesky_token else: token = _lu_token elif self.well_posed is False: if is_diagonal(operator): token = _diagonal_token else: # TODO: use rank-revealing QR instead. token = _svd_token elif self.well_posed is None: if operator.in_size() != operator.out_size(): token = _qr_token elif is_diagonal(operator): token = _diagonal_token elif is_tridiagonal(operator): token = _tridiagonal_token elif is_lower_triangular(operator) or is_upper_triangular(operator): token = _triangular_token elif is_positive_semidefinite(operator) or is_negative_semidefinite( operator ): token = _cholesky_token else: token = _lu_token else: raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") return token def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. **Arguments:** - `operator`: a linear operator. **Returns:** The linear solver that will be used. """ return _lookup(self._select_solver(operator)) def init(self, operator, options) -> _AutoLinearSolverState: token = self._select_solver(operator) return token, _lookup(token).init(operator, options) def compute( self, state: _AutoLinearSolverState, vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: token, state = state solver = _lookup(token) solution, result, _ = solver.compute(state, vector, options) return solution, result, {} def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): token, state = state solver = _lookup(token) transpose_state, transpose_options = solver.transpose(state, options) transpose_state = (token, transpose_state) return transpose_state, transpose_options def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): token, state = state solver = _lookup(token) conj_state, conj_options = solver.conj(state, options) conj_state = (token, conj_state) return conj_state, conj_options def assume_full_rank(self): return self.well_posed is not False AutoLinearSolver.__init__.__doc__ = """**Arguments:** - `well_posed`: whether to only handle well-posed systems or not, as discussed above. """ # TODO(kidger): gmres, bicgstab # TODO(kidger): support auxiliary outputs @eqx.filter_jit def linear_solve( operator: AbstractLinearOperator, vector: PyTree[ArrayLike], solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True), *, options: dict[str, Any] | None = None, state: PyTree[Any] = sentinel, throw: bool = True, ) -> Solution: r"""Solves a linear system. Given an operator represented as a matrix $A$, and a vector $b$: if the operator is square and nonsingular (so that the problem is well-posed), then this returns the usual solution $x$ to $Ax = b$, defined as $A^{-1}b$. If the operator is overdetermined, then this either returns the least-squares solution $\min_x \| Ax - b \|_2$, or throws an error. (Depending on the choice of solver.) If the operator is underdetermined, then this either returns the minimum-norm solution $\min_x \|x\|_2 \text{ subject to } Ax = b$, or throws an error. (Depending on the choice of solver.) !!! info This function is equivalent to either `numpy.linalg.solve`, or to its generalisation `numpy.linalg.lstsq`, depending on the choice of solver. The default solver is `lineax.AutoLinearSolver(well_posed=True)`. This automatically selects a solver depending on the structure (e.g. triangular) of your problem, and will throw an error if your system is overdetermined or underdetermined. Use `lineax.AutoLinearSolver(well_posed=False)` if your system is known to be overdetermined or underdetermined (although handling this case implies greater computational cost). !!! tip These three kinds of solution to a linear system are collectively known as the "pseudoinverse solution" to a linear system. That is, given our matrix $A$, let $A^\dagger$ denote the [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) of $A$. Then the usual/least-squares/minimum-norm solution are all equal to $A^\dagger b$. **Arguments:** - `operator`: a linear operator. This is the '$A$' in '$Ax = b$'. Most frequently this operator is simply represented as a JAX matrix (i.e. a rank-2 JAX array), but any [`lineax.AbstractLinearOperator`][] is supported. Note that if it is a matrix, then it should be passed as an [`lineax.MatrixLinearOperator`][], e.g. ```python matrix = jax.random.normal(key, (5, 5)) # JAX array of shape (5, 5) operator = lx.MatrixLinearOperator(matrix) # Wrap into a linear operator solution = lx.linear_solve(operator, ...) ``` rather than being passed directly. - `vector`: the vector to solve against. This is the '$b$' in '$Ax = b$'. - `solver`: the solver to use. Should be any [`lineax.AbstractLinearSolver`][]. The default is [`lineax.AutoLinearSolver`][] which behaves as discussed above. If the operator is overdetermined or underdetermined , then passing [`lineax.SVD`][] is typical. - `options`: Individual solvers may accept additional runtime arguments; for example [`lineax.CG`][] allows for specifying a preconditioner. See each individual solver's documentation for more details. Keyword only argument. - `state`: If performing multiple linear solves with the same operator, then some computation can be saved by recording and reusing some information; for example the matrix factorisation of the operator. This value should be the result of calling [`lineax.AbstractLinearSolver.init`][] on the provided `operator`. If provided, then the underlying `operator` must still be passed to `linear_solve`. Keyword only argument. - `throw`: How to report any failures. (E.g. an iterative solver running out of steps, or a well-posed-only solver being run with a singular operator.) If `True` then a failure will raise an error. Note that errors are only reliably raised on CPUs. If on GPUs then the error may only be printed to stderr, whilst on TPUs then the behaviour is undefined. If `False` then the returned solution object will have a `result` field indicating whether any failures occured. (See [`lineax.Solution`][].) Keyword only argument. **Returns:** An [`lineax.Solution`][] object containing the solution to the linear system. """ # noqa: E501 if eqx.is_array(operator): raise ValueError( "`lineax.linear_solve(operator=...)` should be an " "`AbstractLinearOperator`, not a raw JAX array. If you are trying to pass " "a matrix then this should be passed as " "`lineax.MatrixLinearOperator(matrix)`." ) if options is None: options = {} vector = jtu.tree_map(inexact_asarray, vector) vector_struct = strip_weak_dtype(jax.eval_shape(lambda: vector)) operator_out_structure = strip_weak_dtype(operator.out_structure()) # `is` to handle tracers if eqx.tree_equal(vector_struct, operator_out_structure) is not True: raise ValueError( "Vector and operator structures do not match. Got a vector with structure " f"{vector_struct} and an operator with out-structure " f"{operator_out_structure}" ) if isinstance(operator, IdentityLinearOperator): return Solution( value=vector, result=RESULTS.successful, state=state, stats={}, ) if state == sentinel: dynamic_operator, static_operator = eqx.partition(operator, eqx.is_array) stopped_operator = eqx.combine( lax.stop_gradient(dynamic_operator), static_operator ) state = solver.init(stopped_operator, options) dynamic_state, static_state = eqx.partition(state, eqx.is_array) dynamic_state = lax.stop_gradient(dynamic_state) state = eqx.combine(dynamic_state, static_state) options = eqxi.nondifferentiable( options, name="`lineax.linear_solve(..., options=...)`" ) solver = eqxi.nondifferentiable( solver, name="`lineax.linear_solve(..., solver=...)`" ) solution, result, stats = eqxi.filter_primitive_bind( linear_solve_p, operator, state, vector, options, solver, throw ) # TODO: prevent forward-mode autodiff through stats stats = eqxi.nondifferentiable_backward(stats) return Solution(value=solution, result=result, state=state, stats=stats) def invert( operator: AbstractLinearOperator, solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True), *, options: dict[str, Any] | None = None, throw: bool = True, ) -> FunctionLinearOperator: r"""Returns a [`lineax.FunctionLinearOperator`][] representing the (pseudo)inverse of `operator`. `invert(A).mv(v)` is equivalent to `linear_solve(A, v, solver).value`. See [`lineax.linear_solve`][] for details on how the solution is defined for square, overdetermined, and underdetermined systems. The returned operator fully supports AD (both forward and reverse mode), `vmap`, and composition with other operators. **Arguments:** - `operator`: the linear operator to invert. - `solver`: the linear solver to use. Defaults to `AutoLinearSolver(well_posed=True)`. - `options`: additional options passed to the solver. Defaults to `None`. - `throw`: as [`lineax.linear_solve`][]. Defaults to `True`. **Returns:** A [`lineax.FunctionLinearOperator`][] whose `mv` solves `operator @ x = v`. """ if options is None: options = {} state = solver.init(operator, options) def solve_fn(vector): return linear_solve( operator, vector, solver, state=state, options=options, throw=throw, ).value tags = { tag for check, tag in [ (is_symmetric, symmetric_tag), (is_diagonal, diagonal_tag), (is_lower_triangular, lower_triangular_tag), (is_upper_triangular, upper_triangular_tag), (is_positive_semidefinite, positive_semidefinite_tag), (is_negative_semidefinite, negative_semidefinite_tag), ] if check(operator) } if has_unit_diagonal(operator) and ( is_diagonal(operator) or is_lower_triangular(operator) or is_upper_triangular(operator) ): tags.add(unit_diagonal_tag) return FunctionLinearOperator(solve_fn, operator.out_structure(), frozenset(tags)) # Work around JAX issue #22011, # as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365 def stop_gradient_transpose(ct, x): return (ct,) ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose ================================================ FILE: lineax/_solver/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .bicgstab import BiCGStab as BiCGStab from .cg import CG as CG, NormalCG as NormalCG from .cholesky import Cholesky as Cholesky from .diagonal import Diagonal as Diagonal from .gmres import GMRES as GMRES from .lsmr import LSMR as LSMR from .lu import LU as LU from .normal import Normal as Normal from .qr import QR as QR from .svd import SVD as SVD from .triangular import Triangular as Triangular from .tridiagonal import Tridiagonal as Tridiagonal ================================================ FILE: lineax/_solver/bicgstab.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable from typing import Any, TypeAlias import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import Array, PyTree from .._norm import max_norm, tree_dot from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import preconditioner_and_y0 _BiCGStabState: TypeAlias = AbstractLinearOperator class BiCGStab(AbstractLinearSolver[_BiCGStabState]): """Biconjugate gradient stabilised method for linear systems. The operator should be square. Equivalent to `jax.scipy.sparse.linalg.bicgstab`. This supports the following `options` (as passed to `lx.linear_solve(..., options=...)`). - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as a preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This method uses right preconditioning. - `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros. """ rtol: float atol: float norm: Callable = max_norm max_steps: int | None = None def __check_init__(self): if isinstance(self.rtol, (int, float)) and self.rtol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and self.atol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)): if self.atol == 0 and self.rtol == 0 and self.max_steps is None: raise ValueError( "Must specify `rtol`, `atol`, or `max_steps` (or some combination " "of all three)." ) def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): if operator.in_structure() != operator.out_structure(): raise ValueError( "`BiCGstab(..., normal=False)` may only be used for linear solves with " "square matrices." ) return linearise(operator) def compute( self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: operator = state preconditioner, y0 = preconditioner_and_y0(operator, vector, options) leaves, _ = jtu.tree_flatten(vector) if self.max_steps is None: size = sum(leaf.size for leaf in leaves) max_steps = 10 * size else: max_steps = self.max_steps has_scale = not ( isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)) and self.atol == 0 and self.rtol == 0 ) if has_scale: b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω # This implementation is the same a jax.scipy.sparse.linalg.bicgstab # but with AbstractLinearOperator. # We use the notation found on the wikipedia except with y instead of x: # https://en.wikipedia.org/wiki/ # Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB # preconditioner in this case is K2^(-1) (i.e., right preconditioning) r0 = (vector**ω - operator.mv(y0) ** ω).ω def breakdown_occurred(omega, alpha, rho): # Empirically, the tolerance checks for breakdown are very tight. # These specific tolerances are heuristic. if jax.config.jax_enable_x64: # pyright: ignore return (omega == 0.0) | (alpha == 0.0) | (rho == 0.0) else: return (omega < 1e-16) | (alpha < 1e-16) | (rho < 1e-16) def not_converged(r, diff, y): # The primary tolerance check. # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: with jax.numpy_dtype_promotion("standard"): y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore norm2 = self.norm((diff**ω / y_scale**ω).ω) return (norm1 > 1) | (norm2 > 1) else: return True def cond_fun(carry): y, r, alpha, omega, rho, _, _, diff, step = carry out = jnp.invert(breakdown_occurred(omega, alpha, rho)) out = out & not_converged(r, diff, y) out = out & (step < max_steps) return out def body_fun(carry): y, r, alpha, omega, rho, p, v, diff, step = carry rho_new = tree_dot(r0, r) beta = (rho_new / rho) * (alpha / omega) p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω # TODO(raderj): reduce this to a single operator.mv call # by using the scan trick. x = preconditioner.mv(p_new) v_new = operator.mv(x) alpha_new = rho_new / tree_dot(r0, v_new) s = (r**ω - alpha_new * v_new**ω).ω z = preconditioner.mv(s) t = operator.mv(z) omega_new = tree_dot(s, t) / tree_dot(t, t) diff = (alpha_new * x**ω + omega_new * z**ω).ω y_new = (y**ω + diff**ω).ω r_new = (s**ω - omega_new * t**ω).ω return ( y_new, r_new, alpha_new, omega_new, rho_new, p_new, v_new, diff, step + 1, ) p0 = v0 = jtu.tree_map(jnp.zeros_like, vector) alpha = omega = rho = jnp.array(1.0) init_carry = ( y0, r0, alpha, omega, rho, p0, v0, ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, 0, ) solution, residual, alpha, omega, rho, _, _, diff, num_steps = lax.while_loop( cond_fun, body_fun, init_carry ) if self.max_steps is None: result = RESULTS.where( num_steps == max_steps, RESULTS.singular, RESULTS.successful ) elif has_scale: result = RESULTS.where( num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) else: result = RESULTS.successful # breakdown is only an issue if we did not converge breakdown = breakdown_occurred(omega, alpha, rho) & not_converged( residual, diff, solution ) result = RESULTS.where(breakdown, RESULTS.breakdown, result) stats = {"num_steps": num_steps, "max_steps": self.max_steps} return solution, result, stats def transpose(self, state: _BiCGStabState, options: dict[str, Any]): transpose_options = {} if "preconditioner" in options: transpose_options["preconditioner"] = options["preconditioner"].transpose() operator = state return operator.transpose(), transpose_options def conj(self, state: _BiCGStabState, options: dict[str, Any]): conj_options = {} if "preconditioner" in options: conj_options["preconditioner"] = conj(options["preconditioner"]) operator = state return conj(operator), conj_options def assume_full_rank(self): return True BiCGStab.__init__.__doc__ = r"""**Arguments:** - `rtol`: Relative tolerance for terminating solve. - `atol`: Absolute tolerance for terminating solve. - `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm. - `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure. """ ================================================ FILE: lineax/_solver/cg.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from collections.abc import Callable from typing import Any, TypeAlias import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import Array, PyTree, Scalar from .._misc import resolve_rcond, structure_equal, tree_where from .._norm import max_norm, tree_dot from .._operator import ( AbstractLinearOperator, conj, is_negative_semidefinite, is_positive_semidefinite, linearise, ) from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import preconditioner_and_y0 from .normal import Normal _CGState: TypeAlias = tuple[AbstractLinearOperator, eqxi.Static] # TODO(kidger): this is pretty slow to compile. # - CG evaluates `operator.mv` three times. # Possibly this can be cheapened a bit somehow? class CG(AbstractLinearSolver[_CGState]): """Conjugate gradient solver for linear systems. The operator should be positive or negative definite. Equivalent to `scipy.sparse.linalg.cg`. This supports the following `options` (as passed to `lx.linear_solve(..., options=...)`). - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][] to be used as preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning, so it is the preconditioned residual that is minimized, though the actual termination criteria uses the un-preconditioned residual. - `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros. """ rtol: float atol: float norm: Callable[[PyTree], Scalar] = max_norm stabilise_every: int | None = 10 max_steps: int | None = None def __check_init__(self): if isinstance(self.rtol, (int, float)) and self.rtol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and self.atol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)): if self.atol == 0 and self.rtol == 0 and self.max_steps is None: raise ValueError( "Must specify `rtol`, `atol`, or `max_steps` (or some combination " "of all three)." ) def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options is_nsd = is_negative_semidefinite(operator) if not structure_equal(operator.in_structure(), operator.out_structure()): raise ValueError( "`CG()` may only be used for linear solves with square matrices." ) if not (is_positive_semidefinite(operator) | is_nsd): raise ValueError( "`CG()` may only be used for positive " "or negative definite linear operators" ) if is_nsd: operator = -operator operator = linearise(operator) return operator, eqxi.Static(is_nsd) # This differs from jax.scipy.sparse.linalg.cg in: # 1. Every few steps we calculate the residual directly, rather than by cheaply # using the existing quantities. This improves numerical stability. # 2. We use a more sophisticated termination condition. To begin with we have an # rtol and atol in the conventional way, inducing a vector-valued scale. This is # then checked in both the `y` and `b` domains (for `Ay = b`). # 3. We return the number of steps, and whether or not the solve succeeded, as # additional information. # 4. We don't try to support complex numbers. (Yet.) def compute( self, state: _CGState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: operator, is_nsd = state is_nsd = is_nsd.value preconditioner, y0 = preconditioner_and_y0(operator, vector, options) if not is_positive_semidefinite(preconditioner): raise ValueError("The preconditioner must be positive definite.") leaves, _ = jtu.tree_flatten(vector) size = sum(leaf.size for leaf in leaves) if self.max_steps is None: max_steps = 10 * size # Copied from SciPy! else: max_steps = self.max_steps r0 = (vector**ω - operator.mv(y0) ** ω).ω p0 = preconditioner.mv(r0) gamma0 = tree_dot(p0, r0) rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves)) initial_value = ( ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, y0, r0, p0, gamma0, 0, ) has_scale = not ( isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)) and self.atol == 0 and self.rtol == 0 ) if has_scale: b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω def not_converged(r, diff, y): # The primary tolerance check. # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: with jax.numpy_dtype_promotion("standard"): y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore norm2 = self.norm((diff**ω / y_scale**ω).ω) return (norm1 > 1) | (norm2 > 1) else: return True def cond_fun(value): diff, y, r, _, gamma, step = value out = gamma > 0 out = out & (step < max_steps) out = out & not_converged(r, diff, y) return out def body_fun(value): _, y, r, p, gamma, step = value mat_p = operator.mv(p) inner_prod = tree_dot(mat_p, p) alpha = gamma / inner_prod alpha = tree_where( jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), # pyright: ignore alpha, jnp.nan, # pyright: ignore ) diff = (alpha * p**ω).ω y = (y**ω + diff**ω).ω step = step + 1 # E.g. see B.2 of # https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf # We compute the residual the "expensive" way every now and again, so as to # correct numerical rounding errors. def stable_r(): return (vector**ω - operator.mv(y) ** ω).ω def cheap_r(): return (r**ω - alpha * mat_p**ω).ω if self.stabilise_every == 1: r = stable_r() elif self.stabilise_every is None: r = cheap_r() else: stable_step = (eqxi.unvmap_max(step) % self.stabilise_every) == 0 stable_step = eqxi.nonbatchable(stable_step) r = lax.cond(stable_step, stable_r, cheap_r) z = preconditioner.mv(r) gamma_prev = gamma gamma = tree_dot(z, r) beta = gamma / gamma_prev p = (z**ω + beta * p**ω).ω return diff, y, r, p, gamma, step _, solution, _, _, _, num_steps = lax.while_loop( cond_fun, body_fun, initial_value ) if self.max_steps is None: result = RESULTS.where( num_steps == max_steps, RESULTS.singular, RESULTS.successful ) elif has_scale: result = RESULTS.where( num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) else: result = RESULTS.successful if is_nsd: solution = -(solution**ω).ω stats = {"num_steps": num_steps, "max_steps": self.max_steps} return solution, result, stats def transpose( self, state: _CGState, options: dict[str, Any] ) -> tuple[_CGState, dict[str, Any]]: transpose_options = {} if "preconditioner" in options: transpose_options["preconditioner"] = options["preconditioner"].transpose() psd_op, is_nsd = state transpose_state = psd_op.transpose(), is_nsd return transpose_state, transpose_options def conj( self, state: _CGState, options: dict[str, Any] ) -> tuple[_CGState, dict[str, Any]]: conj_options = {} if "preconditioner" in options: conj_options["preconditioner"] = conj(options["preconditioner"]) psd_op, is_nsd = state conj_state = conj(psd_op), is_nsd return conj_state, conj_options def assume_full_rank(self): return True CG.__init__.__doc__ = r"""**Arguments:** - `rtol`: Relative tolerance for terminating solve. - `atol`: Absolute tolerance for terminating solve. - `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm. - `stabilise_every`: The conjugate gradient is an iterative method that produces candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$ is small enough. For computational efficiency, the values $r_i$ are computed using other internal quantities, and not by directly evaluating the formula above. However, this computation of $r_i$ is susceptible to drift due to limited floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed directly using the formula above, in order to stabilise the computation. - `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure. """ def NormalCG(*args, **kwargs): """Deprecated helper function. Use `lx.Normal(lx.CG(...))` instead. !!! warning "Deprecated" `NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`. This will be removed in some future version of Lineax. """ warnings.warn( "`NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`. " "This will be removed in some future version of Lineax.", DeprecationWarning, stacklevel=2, ) return Normal(CG(*args, **kwargs)) ================================================ FILE: lineax/_solver/cholesky.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import equinox.internal as eqxi import jax.flatten_util as jfu import jax.scipy as jsp from jaxtyping import Array, PyTree from .._operator import ( AbstractLinearOperator, is_negative_semidefinite, is_positive_semidefinite, ) from .._solution import RESULTS from .._solve import AbstractLinearSolver _CholeskyState: TypeAlias = tuple[Array, eqxi.Static] class Cholesky(AbstractLinearSolver[_CholeskyState]): """Cholesky solver for linear systems. This is generally the preferred solver for positive or negative definite systems. Equivalent to `scipy.linalg.solve(..., assume_a="pos")`. The operator must be square, nonsingular, and either positive or negative definite. """ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options is_nsd = is_negative_semidefinite(operator) if not (is_positive_semidefinite(operator) | is_nsd): raise ValueError( "`Cholesky(..., normal=False)` may only be used for positive " "or negative definite linear operators" ) matrix = operator.as_matrix() m, n = matrix.shape if m != n: raise ValueError( "`Cholesky(..., normal=False)` may only be used for linear solves " "with square matrices" ) if is_nsd: matrix = -matrix factor, lower = jsp.linalg.cho_factor(matrix) # Fix upper triangular for simplicity. assert lower is False return factor, eqxi.Static(is_nsd) def compute( self, state: _CholeskyState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: factor, is_nsd = state is_nsd = is_nsd.value del options # Cholesky => PSD => symmetric => (in_structure == out_structure) => # we don't need to use packed structures. vector, unflatten = jfu.ravel_pytree(vector) solution = jsp.linalg.cho_solve((factor, False), vector) if is_nsd: solution = -solution solution = unflatten(solution) return solution, RESULTS.successful, {} def transpose( self, state: _CholeskyState, options: dict[str, Any] ) -> tuple[_CholeskyState, dict[str, Any]]: # Matrix is self-adjoint factor, is_nsd = state return (factor.conj(), is_nsd), options def conj( self, state: _CholeskyState, options: dict[str, Any] ) -> tuple[_CholeskyState, dict[str, Any]]: # Matrix is self-adjoint factor, is_nsd = state return (factor.conj(), is_nsd), options def assume_full_rank(self): return True Cholesky.__init__.__doc__ = """**Arguments:** Nothing. """ ================================================ FILE: lineax/_solver/diagonal.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import jax.numpy as jnp from jaxtyping import Array, PyTree from .._misc import resolve_rcond from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _DiagonalState: TypeAlias = tuple[Array | None, PackedStructures] class Diagonal(AbstractLinearSolver[_DiagonalState]): """Diagonal solver for linear systems. Requires that the operator be diagonal. Then $Ax = b$, with $A = diag[a]$, is solved simply by doing an elementwise division $x = b / a$. This solver can handle singular operators (i.e. diagonal entries with value 0). """ well_posed: bool = False rcond: float | None = None def init( self, operator: AbstractLinearOperator, options: dict[str, Any] ) -> _DiagonalState: del options if operator.in_size() != operator.out_size(): raise ValueError( "`Diagonal` may only be used for linear solves with square matrices" ) if not is_diagonal(operator): raise ValueError( "`Diagonal` may only be used for linear solves with diagonal matrices" ) packed_structures = pack_structures(operator) if has_unit_diagonal(operator): return None, packed_structures else: return diagonal(operator), packed_structures def compute( self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: diag, packed_structures = state del state, options unit_diagonal = diag is None vector = ravel_vector(vector, packed_structures) if unit_diagonal: solution = vector else: if not self.well_posed: (size,) = diag.shape rcond = resolve_rcond(self.rcond, size, size, diag.dtype) abs_diag = jnp.abs(diag) diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) # pyright: ignore solution = vector / diag solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _DiagonalState, options: dict[str, Any]): del options diag, packed_structures = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_state = diag, transposed_packed_structures transpose_options = {} return transpose_state, transpose_options def conj(self, state: _DiagonalState, options: dict[str, Any]): del options diag, packed_structures = state if diag is None: conj_diag = None else: conj_diag = diag.conj() conj_options = {} conj_state = conj_diag, packed_structures return conj_state, conj_options def assume_full_rank(self): return self.well_posed Diagonal.__init__.__doc__ = """**Arguments**: - `well_posed`: if `False`, then singular operators are accepted, and the pseudoinverse solution is returned. If `True` then passing a singular operator will cause an error to be raised instead. - `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine precision times `N`, where `N` is the input (or output) size of the operator. Only used if `well_posed=False` """ ================================================ FILE: lineax/_solver/gmres.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft from collections.abc import Callable from typing import Any, cast, TypeAlias import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import Array, ArrayLike, Bool, Float, Inexact, PyTree from .._misc import structure_equal from .._norm import max_norm, two_norm from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator from .._solution import RESULTS from .._solve import AbstractLinearSolver, linear_solve from .misc import preconditioner_and_y0 from .qr import QR _GMRESState: TypeAlias = AbstractLinearOperator class GMRES(AbstractLinearSolver[_GMRESState]): """GMRES solver for linear systems. The operator should be square. Similar to `jax.scipy.sparse.linalg.gmres`. This supports the following `options` (as passed to `lx.linear_solve(..., options=...)`). - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning, so it is the preconditioned residual that is minimized, though the actual termination criteria uses the un-preconditioned residual. - `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros. """ rtol: float atol: float norm: Callable = max_norm max_steps: int | None = None restart: int = 20 stagnation_iters: int = 20 def __check_init__(self): if isinstance(self.rtol, (int, float)) and self.rtol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and self.atol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)): if self.atol == 0 and self.rtol == 0 and self.max_steps is None: raise ValueError( "Must specify `rtol`, `atol`, or `max_steps` (or some combination " "of all three)." ) def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options if not structure_equal(operator.in_structure(), operator.out_structure()): raise ValueError( "`GMRES(..., normal=False)` may only be used for linear solves with " "square matrices." ) return linearise(operator) # # This differs from `jax.scipy.sparse.linalg.gmres` in a few ways: # 1. We use a more sophisticated termination condition. To begin with we have an # rtol and atol in the conventional way, inducing a vector-valued scale. This is # then checked in both the `y` and `b` domains (for `Ay = b`). # 2. We handle in-place updates with buffers to avoid generating unnecessary # copies of arrays during the Gram-Schmidt procedure. # 3. We use a QR solve at the end of the batched Gram-Schmidt instead # of a Cholesky solve of the normal equations. This is both faster and more # numerically stable. # 4. We use tricks to compile `A y` fewer times throughout the code, including # passing a dummy initial residual. # 5. We return the number of steps, and whether or not the solve succeeded, as # additional information. # 6. We do not use the unnecessary loop within Gram-Schmidt, and simply compute # this in a single pass. # 7. We add better safety checks for breakdown, and a safety check for stagnation # of the iterates even when we don't explicitly get breakdown. # def compute( self, state: _GMRESState, vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: has_scale = not ( isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)) and self.atol == 0 and self.rtol == 0 ) if has_scale: b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω operator = state preconditioner, y0 = preconditioner_and_y0(operator, vector, options) leaves, _ = jtu.tree_flatten(vector) size = sum(leaf.size for leaf in leaves) if self.max_steps is None: max_steps = 10 * size # Copied from SciPy! else: max_steps = self.max_steps restart = min(self.restart, size) def not_converged(r, diff, y): # The primary tolerance check. # Given Ay=b, then we have to be doing better than `scale` in both # the `y` and the `b` spaces. if has_scale: with jax.numpy_dtype_promotion("standard"): y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω norm1 = self.norm((r**ω / b_scale**ω).ω) # pyright: ignore norm2 = self.norm((diff**ω / y_scale**ω).ω) return (norm1 > 1) | (norm2 > 1) else: return True def cond_fun(carry): y, r, _, deferred_breakdown, diff, _, step, stagnation_counter = carry # NOTE: we defer ending due to breakdown by one loop! This is nonstandard, # but lets us use a cauchy-like condition in the convergence criteria. # If we do not defer breakdown, breakdown may detect convergence when # the diff between two iterations is still quite large, and we only # consider convergence when the diff is small. out = jnp.invert(deferred_breakdown) & ( stagnation_counter < self.stagnation_iters ) out = out & not_converged(r, diff, y) out = out & (step < max_steps) # The first pass uses a dummy value for r0 in order to save on compiling # an extra matvec. The dummy step may raise a breakdown, and `step == 0` # avoids us from returning prematurely. return out | (step == 0) def body_fun(carry): # `breakdown` -> `deferred_breakdown` and `deferred_breakdown` -> `_` y, r, deferred_breakdown, _, diff, r_min, step, stagnation_counter = carry y_new, r_new, breakdown, diff_new = self._gmres_compute( operator, vector, y, r, restart, preconditioner, step == 0 ) # # If the minimum residual does not decrease for many iterations # ("many" is determined by self.stagnation_iters) then the iterative # solve has stagnated and we stop the loop. This bit keeps track of how # long it has been since the minimum has decreased, and updates the minimum # when a new minimum is encountered. As far as I (raderj) am # aware, this is custom to our implementation and not standard practice. # r_new_norm = self.norm(r_new) r_decreased = (r_new_norm - r_min) < 0 stagnation_counter = jnp.where(r_decreased, 0, stagnation_counter + 1) stagnation_counter = cast(Array, stagnation_counter) r_min = jnp.minimum(r_new_norm, r_min) return ( y_new, r_new, breakdown, deferred_breakdown, diff_new, r_min, step + 1, stagnation_counter, ) # Initialise the residual r0 to the dummy value of all 0s. This means # the first iteration of Gram-Schmidt will do nothing, but it saves # us from compiling an extra matvec here. r0 = ω(vector).call(jnp.zeros_like).ω init_carry = ( y0, # y r0, # residual False, # breakdown False, # deferred_breakdown ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, # diff jnp.inf, # r_min 0, # steps jnp.array(0), # stagnation counter ) ( solution, residual, _, # breakdown breakdown, # deferred_breakdown diff, _, num_steps, stagnation_counter, ) = lax.while_loop(cond_fun, body_fun, init_carry) if self.max_steps is None: result = RESULTS.where( num_steps == max_steps, RESULTS.singular, RESULTS.successful ) elif has_scale: result = RESULTS.where( num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful ) else: result = RESULTS.successful result = RESULTS.where( stagnation_counter >= self.stagnation_iters, RESULTS.stagnation, result ) # breakdown is only an issue if we broke down outside the tolerance # of the solution. If we get breakdown and are within the tolerance, # this is called convergence :) breakdown = breakdown & not_converged(residual, diff, solution) # breakdown is the most serious potential issue result = RESULTS.where(breakdown, RESULTS.breakdown, result) stats = {"num_steps": num_steps, "max_steps": self.max_steps} return solution, result, stats def _gmres_compute( self, operator, vector, y, r, restart, preconditioner, first_pass ): # # internal function for computing the bulk of the gmres. We seperate this out # for two reasons: # 1. avoid nested body and cond functions in the body and cond function of # `self.compute`. `self.compute` is primarily responsible for the restart # behavior of gmres. # 2. Like the jax.scipy implementation we may want to add an incremental # version at a later date. # def main_gmres(y): # see the comment at the end of `_arnoldi_gram_schmidt` for a discussion # of `initial_breakdown` r_normalised, r_norm, initial_breakdown = self._normalise(r, eps=None) basis_init = jtu.tree_map( lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)), r_normalised, ) coeff_mat_init = jnp.eye( restart, restart + 1, dtype=jnp.result_type(*jtu.tree_leaves(r_normalised)), ) def cond_fun(carry): _, _, breakdown, step = carry return (step < restart) & jnp.invert(breakdown) def body_fun(carry): basis, coeff_mat, breakdown, step = carry basis_new, coeff_mat_new, breakdown = self._arnoldi_gram_schmidt( operator, preconditioner, basis, coeff_mat, step, restart, vector, breakdown, ) return basis_new, coeff_mat_new, breakdown, step + 1 def buffers(carry): basis, coeff_mat, _, _ = carry return basis, coeff_mat init_carry = (basis_init, coeff_mat_init, initial_breakdown, 0) basis, coeff_mat, breakdown, steps = eqxi.while_loop( cond_fun, body_fun, init_carry, kind="lax", buffers=buffers ) beta_vec = jnp.concatenate( ( r_norm[None].astype(jnp.result_type(coeff_mat)), jnp.zeros_like(coeff_mat, shape=(restart,)), ) ) coeff_op_transpose = MatrixLinearOperator(coeff_mat.T) # TODO(raderj): move to a Hessenberg-specific solver z = linear_solve(coeff_op_transpose, beta_vec, QR(), throw=False).value diff = jtu.tree_map( lambda mat: jnp.tensordot( mat[..., :-1], z, axes=1, precision=lax.Precision.HIGHEST ), basis, ) y_new = (y**ω + diff**ω).ω return y_new, diff, breakdown def first_gmres(y): return y, ω(y).call(lambda x: jnp.full_like(x, jnp.inf)).ω, False first_pass = eqxi.unvmap_any(first_pass) y_new, diff, breakdown = lax.cond(first_pass, first_gmres, main_gmres, y) r_new = preconditioner.mv((vector**ω - operator.mv(y_new) ** ω).ω) return y_new, r_new, breakdown, diff # NOTE: in the jax implementation: # https://github.com/google/jax/blob/ # c662fd216dec10cdb2cff4138b4318bb98853134/jax/_src/scipy/sparse/linalg.py#L327 # _classical_iterative_gram_schmidt uses a while loop to call this. # However, max_iterations is set to 2 in all calls they make to the function, # and the condition function requires steps < (max_iterations - 1). # This means that in fact they only apply Gram-Schmidt once, and using a # while_loop is unnecessary. def _arnoldi_gram_schmidt( self, operator, preconditioner, basis, coeff_mat, step, restart, vector, initial_breakdown, ): # # compute `basis.T @ basis_step` for each leaf of pytree # and then compute the projected vector onto the basis # # `basis` is a pytree with buffers, meaning it can only be # indexed into. Through this section, there are terms like `lambda _, x: ...` # because`jtu.tree_map` only uses the first argument to determine the shape # of the pytree. Since _Buffer is considered part of the pytree # structure, we get leaves which are not buffers if we directly pass `basis`. # Instead, we make sure that the first argument of the tree map is something # with the correct pytree structure, such as `vector` in the dummy case and # basis_step when not, so that we correctly index into `basis`. # basis_step = preconditioner.mv( operator.mv(jtu.tree_map(lambda _, x: x[..., step], vector, basis)) ) step_norm = two_norm(basis_step) contract_matrix = lambda x, y: ft.partial( jnp.tensordot, axes=x.ndim, precision=lax.Precision.HIGHEST )(x, y[...].conj()) _proj = jtu.tree_map(contract_matrix, basis_step, basis) proj = jtu.tree_reduce(lambda x, y: x + y, _proj) proj_on_cols = jtu.tree_map(lambda _, x: x[...] @ proj, vector, basis) # now remove the component of the vector in that subspace basis_step_new = (basis_step**ω - proj_on_cols**ω).ω eps = step_norm * jnp.finfo(proj.dtype).eps basis_step_normalised, step_norm_new, breakdown = self._normalise( basis_step_new, eps=eps ) basis_new = jtu.tree_map( lambda y, mat: mat.at[..., step + 1].set(y), basis_step_normalised, basis, ) proj_new = proj.at[step + 1].set(step_norm_new.astype(jnp.result_type(proj))) # # NOTE: two somewhat complicated things are going on here: # # The `coeff_mat` in_place update has a batch tracer, so we need to be # careful and wrap it in a buffer, hence the use of eqxi.while_loop # instead of lax.while_loop throughout. # # `initial_breakdown` occurs when the previous loop returns a # residual which is small enough to be interpreted as 0 by self._normalise, # but which was passed through the solver anyway. This occurs when # the residual is small but the diff is not, or if the # correct solution was given to GMRES from the start. Both of these tend to # happen at the start of `gmres_compute`. # The latter may happen when using a sequence of iterative methods. # If `initial_breakdown` occurs, then we leave the `coeff_mat` as it was # at initialisation. Replacing it with the projection (which will be all 0s) # will mean `coeff_mat` is not full-rank, and `QR` can only handle nonsquare # matrices of full-rank. # coeff_mat_new = coeff_mat.at[step, :].set( proj_new, pred=jnp.invert(initial_breakdown) ) return basis_new, coeff_mat_new, breakdown def _normalise( self, x: PyTree[Array], eps: Float[ArrayLike, ""] | None ) -> tuple[PyTree[Array], Inexact[Array, ""], Bool[ArrayLike, ""]]: norm = two_norm(x) if eps is None: eps = jnp.finfo(norm.dtype).eps else: eps = jnp.astype(eps, norm.dtype) breakdown = norm < eps # pyright: ignore safe_norm = jnp.where(breakdown, jnp.inf, norm) with jax.numpy_dtype_promotion("standard"): x_normalised = (x**ω / safe_norm).ω return x_normalised, norm, breakdown def transpose(self, state: _GMRESState, options: dict[str, Any]): transpose_options = {} if "preconditioner" in options: transpose_options["preconditioner"] = options["preconditioner"].transpose() operator = state return operator.transpose(), transpose_options def conj(self, state: _GMRESState, options: dict[str, Any]): conj_options = {} if "preconditioner" in options: conj_options["preconditioner"] = conj(options["preconditioner"]) operator = state return conj(operator), conj_options def assume_full_rank(self): return True GMRES.__init__.__doc__ = r"""**Arguments:** - `rtol`: Relative tolerance for terminating solve. - `atol`: Absolute tolerance for terminating solve. - `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the max norm. - `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure. - `restart`: Size of the Krylov subspace built between restarts. The returned solution is the projection of the true solution onto this subpsace, so this direclty bounds the accuracy of the algorithm. Default is 20. - `stagnation_iters`: The maximum number of iterations for which the solver may not decrease. If more than `stagnation_iters` restarts are performed without sufficient decrease in the residual, the algorithm is halted. """ ================================================ FILE: lineax/_solver/lsmr.py ================================================ """Implementation adapted from SciPy, with BSD license: Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ from collections.abc import Callable from typing import Any, TypeAlias import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import Array, PyTree from .._misc import complex_to_real_dtype from .._norm import two_norm from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS from .._solve import AbstractLinearSolver _LSMRState: TypeAlias = AbstractLinearOperator class LSMR(AbstractLinearSolver[_LSMRState]): """LSMR solver for linear systems. This solver can handle any operator, even nonsquare or singular ones. In these cases it will return the pseudoinverse solution to the linear system. Similar to `scipy.sparse.linalg.lsmr`. This supports the following `options` (as passed to `lx.linear_solve(..., options=...)`). - `y0`: The initial estimate of the solution to the linear system. Defaults to all zeros. """ rtol: float atol: float norm: Callable = two_norm max_steps: int | None = None conlim: float = 1e8 def __check_init__(self): if isinstance(self.rtol, (int, float)) and self.rtol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and self.atol < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.conlim, (int, float)) and self.conlim < 0: raise ValueError("Tolerances must be non-negative.") if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)): if self.atol == 0 and self.rtol == 0 and self.max_steps is None: raise ValueError( "Must specify `atol`, `rtol`, or `max_steps` (or some combination " "of all three)." ) def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): return linearise(operator) def compute( self, state: _LSMRState, vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: operator = state x = options.get("y0", None) # damp is not supported at this time. # damp = options.get("damp", 0.0) damp = 0.0 has_scale = not ( isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)) and self.atol == 0 and self.rtol == 0 ) dtype = jnp.result_type( *jtu.tree_leaves(vector), *jtu.tree_leaves(x), *jtu.tree_leaves(operator.in_structure()), ) m, n = operator.out_size(), operator.in_size() # number of singular values min_dim = min([m, n]) if self.max_steps is None: # Set max_steps based on the minimum dimension + avoid numerical overflows # https://github.com/patrick-kidger/lineax/issues/175 # https://github.com/patrick-kidger/lineax/issues/177 int_dtype = jnp.dtype(f"int{complex_to_real_dtype(dtype).itemsize * 8}") if min_dim > (jnp.iinfo(int_dtype).max / 10): max_steps = jnp.iinfo(int_dtype).max else: max_steps = min_dim * 10 # for consistency with other iterative solvers else: max_steps = self.max_steps if x is None: x = jtu.tree_map(jnp.zeros_like, operator.in_structure()) b = vector u = (ω(b) - ω(operator.mv(x))).ω normb = self.norm(b) beta = self.norm(u) def beta_nonzero(beta, u): u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω v = conj(operator).T.mv(u) alpha = self.norm(v) return u, v, alpha def beta_zero(beta, u): v = jtu.tree_map(jnp.zeros_like, operator.in_structure()) alpha = 0.0 return u, v, alpha u, v, alpha = lax.cond(beta == 0.0, beta_zero, beta_nonzero, beta, u) v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω h = v hbar = jtu.tree_map(jnp.zeros_like, operator.in_structure()) # Initialize variables for 1st iteration. # generally, latin letters (b, x, u, v, h etc) are vectors that may be complex # greek letters (alpha, beta, rho, zeta etc) are scalars that are always real loop_state = dict( # vectors x=x, u=u, v=v, h=h, hbar=hbar, # main loop variables itn=0, alpha=alpha, beta=beta, zetabar=alpha * beta, alphabar=alpha, rho=1.0, rhobar=1.0, cbar=1.0, sbar=0.0, # loop variables for estimation of ||r||. betadd=beta, betad=0.0, rhodold=1.0, tautildeold=0.0, thetatilde=0.0, zeta=0.0, delta=0.0, # variables for estimation of ||A|| and cond(A) normA2=alpha**2, maxrbar=0.0, minrbar=jnp.finfo(dtype).max, condA=1.0, # variables for use in stopping rules istop=0, normr=beta, normAr=alpha * beta, ) # beta == 0 means x exactly solves the well posed problem # alpha == 0 means x exactly solves the least squares problem # we check this here to shortcut the loop to avoid division by zero loop_state["istop"] = lax.select(alpha == 0, 2, loop_state["istop"]) loop_state["istop"] = lax.select(beta == 0, 1, loop_state["istop"]) def condfun(loop_state): return loop_state["istop"] == 0 def bodyfun(loop_state): st = loop_state # to avoid writing out loop_state every time st["itn"] = st["itn"] + 1 # Perform the next step of the bidiagonalization to obtain the # next beta, u, alpha, v. These satisfy the relations # beta*u = A@v - alpha*u, # alpha*v = A'@u - beta*v. st["u"] = (ω(st["u"]) * -st["alpha"].astype(dtype)).ω st["u"] = (ω(st["u"]) + ω(operator.mv(st["v"]))).ω st["beta"] = self.norm(st["u"]) def beta_nonzero(alpha, beta, u, v): u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω v = (ω(v) * -beta.astype(dtype)).ω v = (ω(v) + ω(conj(operator).T.mv(u))).ω alpha = self.norm(v) v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω return alpha, beta, u, v def beta_zero(alpha, beta, u, v): return alpha, beta, u, v st["alpha"], st["beta"], st["u"], st["v"] = lax.cond( st["beta"] == 0, beta_zero, beta_nonzero, st["alpha"], st["beta"], st["u"], st["v"], ) # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}. # Construct rotation Qhat_{k,2k+1}. chat, shat, alphahat = self._givens(st["alphabar"], damp) # Use a plane rotation (Q_i) to turn B_i to R_i rhoold = st["rho"] c, s, st["rho"] = self._givens(alphahat, st["beta"]) thetanew = s * st["alpha"] st["alphabar"] = c * st["alpha"] # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar rhobarold = st["rhobar"] zetaold = st["zeta"] thetabar = st["sbar"] * st["rho"] rhotemp = st["cbar"] * st["rho"] st["cbar"], st["sbar"], st["rhobar"] = self._givens( st["cbar"] * st["rho"], thetanew ) st["zeta"] = st["cbar"] * st["zetabar"] st["zetabar"] = -st["sbar"] * st["zetabar"] # Update h, h_hat, x. st["hbar"] = ( ω(st["hbar"]) * -(thetabar * st["rho"] / (rhoold * rhobarold)).astype(dtype) ).ω st["hbar"] = (ω(st["hbar"]) + ω(st["h"])).ω st["x"] = ( ω(st["x"]) + (st["zeta"] / (st["rho"] * st["rhobar"])).astype(dtype) * ω(st["hbar"]) ).ω st["h"] = (ω(st["h"]) * -(thetanew / st["rho"]).astype(dtype)).ω st["h"] = (ω(st["h"]) + ω(st["v"])).ω # Estimate of ||r||. # Apply rotation Qhat_{k,2k+1}. betaacute = chat * st["betadd"] betacheck = -shat * st["betadd"] # Apply rotation Q_{k,k+1}. betahat = c * betaacute st["betadd"] = -s * betaacute # Apply rotation Qtilde_{k-1}. # betad = betad_{k-1} here. thetatildeold = st["thetatilde"] ctildeold, stildeold, rhotildeold = self._givens(st["rhodold"], thetabar) st["thetatilde"] = stildeold * loop_state["rhobar"] st["rhodold"] = ctildeold * st["rhobar"] st["betad"] = -stildeold * st["betad"] + ctildeold * betahat # betad = betad_k here. # rhodold = rhod_k here. loop_state["tautildeold"] = ( zetaold - thetatildeold * st["tautildeold"] ) / rhotildeold taud = (st["zeta"] - st["thetatilde"] * st["tautildeold"]) / st["rhodold"] st["delta"] = st["delta"] + betacheck**2 st["normr"] = jnp.sqrt( st["delta"] + (st["betad"] - taud) ** 2 + st["betadd"] ** 2 ) # Estimate ||A||. st["normA2"] = st["normA2"] + st["beta"] ** 2 normA = jnp.sqrt(st["normA2"]) st["normA2"] = st["normA2"] + st["alpha"] ** 2 # Estimate cond(A). st["maxrbar"] = jnp.maximum(st["maxrbar"], rhobarold) st["minrbar"] = lax.select( st["itn"] > 1, jnp.minimum(st["minrbar"], rhobarold), st["minrbar"] ) st["condA"] = jnp.maximum(st["maxrbar"], rhotemp) / jnp.minimum( st["minrbar"], rhotemp ) # Compute norms for convergence testing. st["normAr"] = jnp.abs(st["zetabar"]) normx = self.norm(st["x"]) well_posed_tol = self.atol + self.rtol * (normA * normx + normb) least_squares_tol = self.atol + self.rtol * (normA * st["normr"]) # maxiter exceeded st["istop"] = lax.select(st["itn"] >= max_steps, 4, st["istop"]) # cond(A) seems to be greater than conlim st["istop"] = lax.select(st["condA"] > self.conlim, 3, st["istop"]) # x solves the least-squares problem according to atol and rtol. st["istop"] = lax.select(st["normAr"] < least_squares_tol, 2, st["istop"]) # x is a solution to A@x = b, according to atol and rtol. st["istop"] = lax.select(st["normr"] < well_posed_tol, 1, st["istop"]) return st loop_state = lax.while_loop(condfun, bodyfun, loop_state) stats = { "num_steps": loop_state["itn"], "istop": loop_state["istop"], "norm_r": loop_state["normr"], "norm_Ar": loop_state["normAr"], "norm_A": jnp.sqrt(loop_state["normA2"]), "cond_A": loop_state["condA"], "norm_x": self.norm(loop_state["x"]), } if self.max_steps is None: result = RESULTS.where( loop_state["itn"] == max_steps, RESULTS.singular, RESULTS.successful ) elif has_scale: result = RESULTS.where( loop_state["itn"] == max_steps, RESULTS.max_steps_reached, RESULTS.successful, ) else: result = RESULTS.successful result = RESULTS.where(loop_state["istop"] < 3, RESULTS.successful, result) result = RESULTS.where(loop_state["istop"] == 3, RESULTS.conlim, result) return loop_state["x"], result, stats def _givens(self, a, b): """Stable implementation of Givens rotation, from [1]_ finds c, s, r such that |c -s|[a| = |r| [s c|[b| |0| r = sqrt(a^2 + b^2) Assumes a, b are real. References ---------- .. [1] S.-C. Choi, "Iterative Methods for Singular Linear Equations and Least-Squares Problems", Dissertation, http://www.stanford.edu/group/SOL/dissertations/sou-cheng-choi-thesis.pdf """ assert not jnp.iscomplexobj(a) assert not jnp.iscomplexobj(b) def bzero(a, b): return jnp.sign(a), 0.0, jnp.abs(a) def azero(a, b): return 0.0, jnp.sign(b), jnp.abs(b) def b_gt_a(a, b): tau = a / lax.select(b == 0.0, 1.0, b) s = jnp.sign(b) / jnp.sqrt(1.0 + tau**2) c = s * tau r = b / lax.select(s == 0.0, 1.0, s) return c, s, r def a_ge_b(a, b): tau = b / lax.select(a == 0.0, 1.0, a) c = jnp.sign(a) / jnp.sqrt(1.0 + tau**2) s = c * tau r = a / lax.select(c == 0.0, 1.0, c) return c, s, r def either_zero(a, b): return lax.cond(b == 0.0, bzero, azero, a, b) def both_nonzero(a, b): return lax.cond(jnp.abs(b) > jnp.abs(a), b_gt_a, a_ge_b, a, b) return lax.cond((a == 0.0) | (b == 0.0), either_zero, both_nonzero, a, b) def transpose(self, state: _LSMRState, options: dict[str, Any]): del options operator = state transpose_options = {} return operator.transpose(), transpose_options def conj(self, state: _LSMRState, options: dict[str, Any]): del options operator = state conj_options = {} return conj(operator), conj_options def assume_full_rank(self): return False LSMR.__init__.__doc__ = r"""**Arguments:** - `rtol`: Relative tolerance for terminating solve. - `atol`: Absolute tolerance for terminating solve. - `norm`: The norm to use when computing whether the error falls within the tolerance. Defaults to the two norm. - `max_steps`: The maximum number of iterations to run the solver for. If more steps than this are required, then the solve is halted with a failure. - `conlim`: The solver terminates if an estimate of cond(A) exceeds conlim. For compatible systems Ax = b, conlim could be as large as 1.0e+12 (say). For least-squares problems, conlim should be less than 1.0e+8. If conlim is None, the default value is 1e+8. Maximum precision can be obtained by setting atol = rtol = 0, conlim = np.inf, but the number of iterations may then be excessive. Default is 1e8. """ ================================================ FILE: lineax/_solver/lu.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import equinox.internal as eqxi import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, PyTree from .._operator import AbstractLinearOperator, is_diagonal from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, eqxi.Static] class LU(AbstractLinearSolver[_LUState]): """LU solver for linear systems. This solver can only handle square nonsingular operators. """ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options if operator.in_size() != operator.out_size(): raise ValueError( "`LU` may only be used for linear solves with square matrices" ) packed_structures = pack_structures(operator) if is_diagonal(operator): lu = operator.as_matrix(), jnp.arange(operator.in_size(), dtype=jnp.int32) else: lu = jsp.linalg.lu_factor(operator.as_matrix()) return lu, packed_structures, eqxi.Static(False) def compute( self, state: _LUState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: del options lu_and_piv, packed_structures, transpose = state transpose = transpose.value trans = 1 if transpose else 0 vector = ravel_vector(vector, packed_structures) solution = jsp.linalg.lu_solve(lu_and_piv, vector, trans=trans) solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose( self, state: _LUState, options: dict[str, Any], ): lu_and_piv, packed_structures, transpose = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_state = ( lu_and_piv, transposed_packed_structures, eqxi.Static(not transpose.value), ) transpose_options = {} return transpose_state, transpose_options def conj( self, state: _LUState, options: dict[str, Any], ): (lu, piv), packed_structures, transpose = state conj_state = ( (lu.conj(), piv), packed_structures, eqxi.Static(not transpose.value), ) conj_options = {} return conj_state, conj_options def assume_full_rank(self): return True LU.__init__.__doc__ = """**Arguments:** Nothing. """ ================================================ FILE: lineax/_solver/misc.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import typing import warnings from typing import Any, NewType, TYPE_CHECKING import equinox.internal as eqxi import jax.numpy as jnp import jax.tree_util as jtu import numpy as np from jaxtyping import Array, PyTree, Shaped from .._misc import strip_weak_dtype, structure_equal from .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise def preconditioner_and_y0( operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any] ): structure = operator.in_structure() try: preconditioner = linearise(options["preconditioner"]) except KeyError: preconditioner = IdentityLinearOperator(structure) else: if not isinstance(preconditioner, AbstractLinearOperator): raise ValueError("The preconditioner must be a linear operator.") if not structure_equal(preconditioner.in_structure(), structure): raise ValueError( "The preconditioner must have `in_structure` that matches the " "operator's `in_strucure`." ) if not structure_equal(preconditioner.out_structure(), structure): raise ValueError( "The preconditioner must have `out_structure` that matches the " "operator's `in_structure`." ) try: y0 = options["y0"] except KeyError: y0 = jtu.tree_map(jnp.zeros_like, vector) else: if not structure_equal(y0, vector): raise ValueError( "`y0` must have the same structure, shape, and dtype as `vector`" ) return preconditioner, y0 # This seems to introduce some spurious failure at docgen time. if hasattr(typing, "GENERATING_DOCUMENTATION") and not TYPE_CHECKING: PackedStructures = lambda x: x else: PackedStructures = NewType("PackedStructures", eqxi.Static) def pack_structures(operator: AbstractLinearOperator) -> PackedStructures: structures = ( strip_weak_dtype(operator.out_structure()), strip_weak_dtype(operator.in_structure()), ) leaves, treedef = jtu.tree_flatten(structures) # handle nonhashable pytrees return PackedStructures(eqxi.Static((leaves, treedef))) def ravel_vector( pytree: PyTree[Array], packed_structures: PackedStructures ) -> Shaped[Array, " size"]: leaves, treedef = packed_structures.value out_structure, _ = jtu.tree_unflatten(treedef, leaves) # `is` in case `tree_equal` returns a Tracer. if not structure_equal(pytree, out_structure): raise ValueError("pytree does not match out_structure") # not using `ravel_pytree` as that doesn't come with guarantees about order leaves = jtu.tree_leaves(pytree) dtype = jnp.result_type(*leaves) return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves]) def unravel_solution( solution: Shaped[Array, " size"], packed_structures: PackedStructures ) -> PyTree[Array]: leaves, treedef = packed_structures.value _, in_structure = jtu.tree_unflatten(treedef, leaves) leaves, treedef = jtu.tree_flatten(in_structure) sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) split = jnp.split(solution, sizes) assert len(split) == len(leaves) with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore complex-to-real cast warning shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)] return jtu.tree_unflatten(treedef, shaped) def transpose_packed_structures( packed_structures: PackedStructures, ) -> PackedStructures: leaves, treedef = packed_structures.value out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) leaves, treedef = jtu.tree_flatten((in_structure, out_structure)) return PackedStructures(eqxi.Static((leaves, treedef))) ================================================ FILE: lineax/_solver/normal.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import copy from typing import Any, TypeVar import equinox.internal as eqxi from jaxtyping import Array, PyTree from .._operator import conj, linearise, materialise, TaggedLinearOperator from .._solution import RESULTS from .._solve import AbstractLinearOperator, AbstractLinearSolver from .._tags import positive_semidefinite_tag from .cholesky import Cholesky _InnerSolverState = TypeVar("_InnerSolverState") def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool): preconditioner = options.get("preconditioner") y0 = options.get("y0") inner_options = copy(options) del options if preconditioner is not None: preconditioner = linearise(preconditioner) if tall: inner_options["preconditioner"] = TaggedLinearOperator( preconditioner @ conj(preconditioner.transpose()), positive_semidefinite_tag, ) else: inner_options["preconditioner"] = TaggedLinearOperator( conj(preconditioner.transpose()) @ preconditioner, positive_semidefinite_tag, ) if y0 is not None: inner_options["y0"] = conj(preconditioner.transpose()).mv(y0) return inner_options class Normal( AbstractLinearSolver[ tuple[_InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]] ] ): """Wrapper for an inner solver of positive (semi)definite systems. The wrapped solver handles possibly nonsquare systems $Ax = b$ by applying the inner solver to the normal equations $A^* A x = A^* b$ if $m \\ge n$, otherwise $A A^* y = b$, where $x = A^* y$. If the inner solver solves systems with positive definite $A$, the wrapped solver solves systems with full rank $A$. If the inner solver solves systems with positive semidefinite $A$, the wrapped solver solves systems with arbitrary, possibly rank deficient, $A$. Note that this squares the condition number, so applying this method to an iterative inner solver may result in slow convergence and high sensitivity to roundoff error. In this case it may be advantageous to choose an appropriate preconditioner or initial solution guess for the problem. This wrapper adjusts the following `options` before passing to the inner operator (as passed to `lx.linear_solve(..., options=...)`). - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This should be an approximation of the (pseudo)inverse of $A$. When passed to the inner solver, the preconditioner $M$ is replaced by $M M^*$ and $M^* M$ in the first and second versions of the normal equations, respectively. - `y0`: An initial estimate of the solution of the linear system $Ax = b$. Defaults to all zeros. In the second version of the normal equations, $y_0$ is replaced with $M^* y_0$, where $M$ is the given outer preconditioner. !!! Info Good choices of inner solvers are the direct [`lineax.Cholesky`][] and the iterative [`lineax.CG`][]. """ inner_solver: AbstractLinearSolver[_InnerSolverState] def init(self, operator, options): tall = operator.out_size() >= operator.in_size() # Cholesky materialises op twice when computing (op^H @ op).as_matrix() # Cheaper to materialise first and then conjugate-transpose. # For iterative solvers we only linearise to avoid eager materialisation. is_cholesky = isinstance(self.inner_solver, Cholesky) lin_op = materialise(operator) if is_cholesky else linearise(operator) if tall: inner_operator = conj(lin_op.transpose()) @ lin_op else: inner_operator = lin_op @ conj(lin_op.transpose()) inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag) inner_options = normal_preconditioner_and_y0(options, tall) inner_state = self.inner_solver.init(inner_operator, inner_options) operator_conj_transpose = conj(lin_op.transpose()) return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options def compute( self, state: tuple[ _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] ], vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: inner_state, tall, operator_conj_transpose, inner_options = state tall = tall.value del state, options if tall: vector = operator_conj_transpose.mv(vector) solution, result, extra_stats = self.inner_solver.compute( inner_state, vector, inner_options ) if not tall: solution = operator_conj_transpose.mv(solution) return solution, result, extra_stats def transpose( self, state: tuple[ _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] ], options: dict[str, Any], ): inner_state, tall, operator_conj_transpose, inner_options = state inner_state_conj, inner_options = self.inner_solver.conj( inner_state, inner_options ) state_transpose = ( inner_state_conj, eqxi.Static(not tall.value), operator_conj_transpose.transpose(), inner_options, ) return state_transpose, options def conj( self, state: tuple[ _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any] ], options: dict[str, Any], ): inner_state, tall, operator_conj_transpose, inner_options = state inner_state_conj, inner_options = self.inner_solver.conj( inner_state, inner_options ) state_conj = ( inner_state_conj, tall, conj(operator_conj_transpose), inner_options, ) return state_conj, options def assume_full_rank(self): return self.inner_solver.assume_full_rank() Normal.__init__.__doc__ = """**Arguments:** - `inner_solver`: The solver to wrap. It should support solving positive definite systems or positive semidefinite systems """ ================================================ FILE: lineax/_solver/qr.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import equinox.internal as eqxi import jax.lax.linalg as jll import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, PyTree from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _QRState: TypeAlias = tuple[tuple[Array, Array], eqxi.Static, PackedStructures] class QR(AbstractLinearSolver): """QR solver for linear systems. This solver can handle non-square operators. This is usually the preferred solver when dealing with non-square operators. !!! info Note that whilst this does handle non-square operators, it still can only handle full-rank operators. This is because JAX does not currently support a rank-revealing/pivoted QR decomposition, see [issue #12897](https://github.com/google/jax/issues/12897). For such use cases, switch to [`lineax.SVD`][] instead. """ def init(self, operator, options): del options matrix = operator.as_matrix() m, n = matrix.shape transpose = n > m if transpose: matrix = matrix.T h, taus = jnp.linalg.qr(matrix, mode="raw") # pyright: ignore a = h.mT packed_structures = pack_structures(operator) return (a, taus), eqxi.Static(transpose), packed_structures def compute( self, state: _QRState, vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: (a, taus), transpose, packed_structures = state transpose = transpose.value del state, options vector = ravel_vector(vector, packed_structures) n_full, n_min = a.shape r = a[:n_min] if transpose: # Minimal norm solution if underdetermined: x = Q.conj() @ R^{-T} @ b. # Use Q.conj() @ z = (z^T @ Q^H)^T to avoid explicit `conj` calls, # and pad `y` along the row axis to absorb the discarded columns of Q. y = jsp.linalg.solve_triangular(r, vector, trans="T", unit_diagonal=False) zeros = jnp.zeros((1, n_full - n_min), dtype=y.dtype) y_pad = jnp.concatenate([y[None, :], zeros], axis=1) solution = jll.ormqr(a, taus, y_pad, left=False, transpose=True)[0] else: # Least squares solution if overdetermined. qHv = jll.ormqr(a, taus, vector[:, None], transpose=True)[:n_min, 0] solution = jsp.linalg.solve_triangular( r, qHv, trans="N", unit_diagonal=False ) solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _QRState, options: dict[str, Any]): (a, taus), transpose, structures = state transposed_packed_structures = transpose_packed_structures(structures) transpose_state = ( (a, taus), eqxi.Static(not transpose.value), transposed_packed_structures, ) transpose_options = {} return transpose_state, transpose_options def conj(self, state: _QRState, options: dict[str, Any]): (a, taus), transpose, structures = state conj_state = ( (a.conj(), taus.conj()), transpose, structures, ) conj_options = {} return conj_state, conj_options def assume_full_rank(self): return True QR.__init__.__doc__ = """**Arguments:** Nothing. """ ================================================ FILE: lineax/_solver/svd.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import jax.lax as lax import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Array, PyTree from .._misc import resolve_rcond from .._operator import AbstractLinearOperator from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _SVDState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures] class SVD(AbstractLinearSolver[_SVDState]): """SVD solver for linear systems. This solver can handle any operator, even nonsquare or singular ones. In these cases it will return the pseudoinverse solution to the linear system. Equivalent to `scipy.linalg.lstsq`. """ rcond: float | None = None def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options svd = jsp.linalg.svd(operator.as_matrix(), full_matrices=False) packed_structures = pack_structures(operator) return svd, packed_structures def compute( self, state: _SVDState, vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: del options (u, s, vt), packed_structures = state vector = ravel_vector(vector, packed_structures) m, _ = u.shape _, n = vt.shape rcond = resolve_rcond(self.rcond, n, m, s.dtype) rcond = jnp.array(rcond, dtype=s.dtype) if s.size > 0: rcond = rcond * s[0] # Not >=, or this fails with a matrix of all-zeros. mask = s > rcond rank = mask.sum() safe_s = jnp.where(mask, s, 1) s_inv = jnp.where(mask, jnp.array(1.0) / safe_s, 0).astype(u.dtype) uTb = jnp.matmul(u.conj().T, vector, precision=lax.Precision.HIGHEST) solution = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {"rank": rank} def transpose(self, state: _SVDState, options: dict[str, Any]): del options (u, s, vt), packed_structures = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_state = (vt.T, s, u.T), transposed_packed_structures transpose_options = {} return transpose_state, transpose_options def conj(self, state: _SVDState, options: dict[str, Any]): del options (u, s, vt), packed_structures = state conj_state = (u.conj(), s, vt.conj()), packed_structures conj_options = {} return conj_state, conj_options def assume_full_rank(self): return False SVD.__init__.__doc__ = """**Arguments**: - `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine precision times `max(N, M)`, where `(N, M)` is the shape of the operator. (I.e. `N` is the output size and `M` is the input size.) """ ================================================ FILE: lineax/_solver/triangular.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import equinox.internal as eqxi import jax.scipy as jsp from jaxtyping import Array, PyTree from .._operator import ( AbstractLinearOperator, has_unit_diagonal, is_lower_triangular, is_upper_triangular, ) from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _TriangularState: TypeAlias = tuple[ Array, eqxi.Static, eqxi.Static, PackedStructures, eqxi.Static ] class Triangular(AbstractLinearSolver[_TriangularState]): """Triangular solver for linear systems. The operator should either be lower triangular or upper triangular. """ def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options if operator.in_size() != operator.out_size(): raise ValueError( "`Triangular` may only be used for linear solves with square matrices" ) if not (is_lower_triangular(operator) or is_upper_triangular(operator)): raise ValueError( "`Triangular` may only be used for linear solves with triangular " "matrices" ) return ( operator.as_matrix(), eqxi.Static(is_lower_triangular(operator)), eqxi.Static(has_unit_diagonal(operator)), pack_structures(operator), eqxi.Static(False), # transposed ) def compute( self, state: _TriangularState, vector: PyTree[Array], options: dict[str, Any] ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: matrix, lower, unit_diagonal, packed_structures, transpose = state lower = lower.value unit_diagonal = unit_diagonal.value transpose = transpose.value del state, options vector = ravel_vector(vector, packed_structures) if transpose: trans = "T" else: trans = "N" solution = jsp.linalg.solve_triangular( matrix, vector, trans=trans, lower=lower, unit_diagonal=unit_diagonal ) solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _TriangularState, options: dict[str, Any]): del options matrix, lower, unit_diagonal, packed_structures, transpose = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_state = ( matrix, lower, unit_diagonal, transposed_packed_structures, eqxi.Static(not transpose.value), ) transpose_options = {} return transpose_state, transpose_options def conj(self, state: _TriangularState, options: dict[str, Any]): del options matrix, lower, unit_diagonal, packed_structures, transpose = state conj_state = ( matrix.conj(), lower, unit_diagonal, packed_structures, transpose, ) conj_options = {} return conj_state, conj_options def assume_full_rank(self): return True Triangular.__init__.__doc__ = """**Arguments:** Nothing. """ ================================================ FILE: lineax/_solver/tridiagonal.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeAlias import jax.lax as lax import jax.numpy as jnp from jaxtyping import Array, PyTree from .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal from .._solution import RESULTS from .._solve import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, ravel_vector, transpose_packed_structures, unravel_solution, ) _TridiagonalState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures] class Tridiagonal(AbstractLinearSolver[_TridiagonalState]): """Tridiagonal solver for linear systems, uses the LAPACK/cusparse implementation of Gaussian elimination with partial pivotting (which increases stability). .""" def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): del options if operator.in_size() != operator.out_size(): raise ValueError( "`Tridiagonal` may only be used for linear solves with square matrices" ) if not is_tridiagonal(operator): raise ValueError( "`Tridiagonal` may only be used for linear solves with tridiagonal " "matrices" ) return tridiagonal(operator), pack_structures(operator) def compute( self, state: _TridiagonalState, vector, options, ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: (diagonal, lower_diagonal, upper_diagonal), packed_structures = state del state, options vector = ravel_vector(vector, packed_structures) solution = lax.linalg.tridiagonal_solve( jnp.append(0.0, lower_diagonal), diagonal, jnp.append(upper_diagonal, 0.0), vector[:, None], ).flatten() solution = unravel_solution(solution, packed_structures) return solution, RESULTS.successful, {} def transpose(self, state: _TridiagonalState, options: dict[str, Any]): (diagonal, lower_diagonal, upper_diagonal), packed_structures = state transposed_packed_structures = transpose_packed_structures(packed_structures) transpose_diagonals = (diagonal, upper_diagonal, lower_diagonal) transpose_state = (transpose_diagonals, transposed_packed_structures) return transpose_state, options def conj(self, state: _TridiagonalState, options: dict[str, Any]): (diagonal, lower_diagonal, upper_diagonal), packed_structures = state conj_diagonals = (diagonal.conj(), lower_diagonal.conj(), upper_diagonal.conj()) conj_state = (conj_diagonals, packed_structures) return conj_state, options def assume_full_rank(self): return True Tridiagonal.__init__.__doc__ = """**Arguments:** Nothing. """ ================================================ FILE: lineax/_tags.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class _HasRepr: def __init__(self, string: str): self.string = string def __repr__(self): return self.string symmetric_tag = _HasRepr("symmetric_tag") diagonal_tag = _HasRepr("diagonal_tag") tridiagonal_tag = _HasRepr("tridiagonal_tag") unit_diagonal_tag = _HasRepr("unit_diagonal_tag") lower_triangular_tag = _HasRepr("lower_triangular_tag") upper_triangular_tag = _HasRepr("upper_triangular_tag") positive_semidefinite_tag = _HasRepr("positive_semidefinite_tag") negative_semidefinite_tag = _HasRepr("negative_semidefinite_tag") transpose_tags_rules = [] for tag in ( symmetric_tag, unit_diagonal_tag, diagonal_tag, positive_semidefinite_tag, negative_semidefinite_tag, tridiagonal_tag, ): @transpose_tags_rules.append def _(tags: frozenset[object], tag=tag): if tag in tags: return tag @transpose_tags_rules.append def _(tags: frozenset[object]): if lower_triangular_tag in tags: return upper_triangular_tag @transpose_tags_rules.append def _(tags: frozenset[object]): if upper_triangular_tag in tags: return lower_triangular_tag def transpose_tags(tags: frozenset[object]): """Lineax uses "tags" to declare that a particular linear operator exhibits some property, e.g. symmetry. This function takes in a collection of tags representing a linear operator, and returns a collection of tags that should be associated with the transpose of that linear operator. **Arguments:** - `tags`: a `frozenset` of tags. **Returns:** A `frozenset` of tags. """ if symmetric_tag in tags: return tags new_tags = [] for rule in transpose_tags_rules: out = rule(tags) if out is not None: new_tags.append(out) return frozenset(new_tags) ================================================ FILE: lineax/internal/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .._misc import ( complex_to_real_dtype as complex_to_real_dtype, default_floating_dtype as default_floating_dtype, ) from .._norm import ( max_norm as max_norm, rms_norm as rms_norm, sum_squares as sum_squares, tree_dot as tree_dot, two_norm as two_norm, ) from .._solve import linear_solve_p as linear_solve_p from .._solver.misc import ( pack_structures as pack_structures, PackedStructures as PackedStructures, ravel_vector as ravel_vector, transpose_packed_structures as transpose_packed_structures, unravel_solution as unravel_solution, ) ================================================ FILE: mkdocs.yml ================================================ theme: name: material features: - navigation.sections # Sections are included in the navigation on the left. - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. - header.autohide # header disappears as you scroll palette: # Light mode / dark mode # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. - scheme: default primary: white accent: amber toggle: icon: material/weather-night name: Switch to dark mode - scheme: slate primary: black accent: amber toggle: icon: material/weather-sunny name: Switch to light mode icon: repo: fontawesome/brands/github # GitHub logo in top right logo: "material/matrix" # lineax logo in top left favicon: "_static/favicon.png" custom_dir: "docs/_overrides" # Overriding part of the HTML # These additions are my own custom ones, having overridden a partial. twitter_bluesky_name: "@PatrickKidger" twitter_url: "https://twitter.com/PatrickKidger" bluesky_url: "https://PatrickKidger.bsky.social" site_name: lineax site_description: The documentation for the Lineax software library. site_author: Patrick Kidger site_url: https://docs.kidger.site/lineax repo_url: https://github.com/patrick-kidger/lineax repo_name: patrick-kidger/lineax edit_uri: "" strict: true # Don't allow warnings during the build process extra_javascript: # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ - _static/mathjax.js - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js extra_css: - _static/custom_css.css markdown_extensions: - pymdownx.arithmatex: # Render LaTeX via MathJax generic: true - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. - pymdownx.details # Allowing hidden expandable regions denoted by ??? - pymdownx.snippets: # Include one Markdown file into another base_path: docs - admonition - toc: permalink: "¤" # Adds a clickable permalink to each section heading toc_depth: 4 plugins: - search: separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;' - include_exclude_files: include: - ".htaccess" exclude: - "_overrides" - "examples/.ipynb_checkpoints/" - ipynb - hippogriffe: extra_public_objects: - jax.ShapeDtypeStruct - mkdocstrings: handlers: python: options: force_inspection: true heading_level: 4 inherited_members: true members_order: source show_bases: false show_if_no_docstring: true show_overloads: false show_root_heading: true show_signature_annotations: true show_source: false show_symbol_type_heading: true show_symbol_type_toc: true nav: - 'index.md' - Examples: - 'examples/classical_solve.ipynb' - 'examples/least_squares.ipynb' - 'examples/structured_matrices.ipynb' - 'examples/no_materialisation.ipynb' - 'examples/operators.ipynb' - 'examples/complex_solve.ipynb' - API: - 'api/linear_solve.md' - 'api/solvers.md' - 'api/operators.md' - 'api/tags.md' - 'api/solution.md' - 'api/functions.md' - 'faq.md' ================================================ FILE: pyproject.toml ================================================ [build-system] build-backend = "hatchling.build" requires = ["hatchling"] [dependency-groups] dev = [ "prek==0.3.9", "pyright==1.1.406", "ruff==0.13.0", "toml-sort==0.23.1" ] docs = [ "hippogriffe==0.2.2", "griffe==1.7.3", "mkdocs==1.6.1", "mkdocs-include-exclude-files==0.1.0", "mkdocs-ipynb==0.1.1", "mkdocs-material==9.6.7", "mkdocstrings==0.28.3", "mkdocstrings-python==1.16.8", "pygments==2.20.0", "pymdown-extensions==10.21.2" ] tests = [ "beartype", "equinox", "pytest", "pytest-xdist", "jaxlib" ] [project] authors = [ {email = "raderjason@outlook.com", name = "Jason Rader"}, {email = "contact@kidger.site", name = "Patrick Kidger"} ] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Financial and Insurance Industry", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Mathematics" ] dependencies = ["jax>=0.10.0", "jaxtyping>=0.2.24", "equinox>=0.11.10", "typing_extensions>=4.5.0"] description = "Linear solvers in JAX and Equinox." keywords = ["jax", "neural-networks", "deep-learning", "equinox", "linear-solvers", "least-squares", "numerical-methods"] license = {file = "LICENSE"} name = "lineax" readme = "README.md" requires-python = "~=3.11" urls = {repository = "https://github.com/google/lineax"} version = "0.1.1" [tool.hatch.build] include = ["lineax/*"] [tool.pyright] include = ["lineax", "tests"] reportIncompatibleMethodOverride = true [tool.pytest.ini_options] addopts = "--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" [tool.ruff] extend-include = ["*.ipynb"] src = [] [tool.ruff.lint] fixable = ["I001", "F401", "UP"] ignore = ["E402", "E721", "E731", "E741", "F722"] select = ["E", "F", "I001", "UP"] [tool.ruff.lint.flake8-import-conventions.extend-aliases] "collections" = "co" "functools" = "ft" "itertools" = "it" [tool.ruff.lint.isort] combine-as-imports = true extra-standard-library = ["typing_extensions"] lines-after-imports = 2 order-by-type = false [tool.uv] default-groups = ["dev", "docs", "tests"] ================================================ FILE: tests/README.md ================================================ Each file is run separately to avoid JAX out-of-memory'ing. As such, run tests using `python -m tests`, *not* by just running `pytest`. ================================================ FILE: tests/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: tests/__main__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pathlib import subprocess import sys here = pathlib.Path(__file__).resolve().parent # Each file is ran separately to avoid out-of-memorying. running_out = 0 for file in here.iterdir(): if file.is_file() and file.name.startswith("test"): out = subprocess.run(f"pytest {file}", shell=True).returncode running_out = max(running_out, out) sys.exit(running_out) ================================================ FILE: tests/conftest.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox.internal as eqxi import jax import pytest jax.config.update("jax_enable_x64", True) jax.config.update("jax_numpy_dtype_promotion", "strict") jax.config.update("jax_numpy_rank_promotion", "raise") @pytest.fixture def getkey(): return eqxi.GetKey() ================================================ FILE: tests/helpers.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import math import equinox as eqx import equinox.internal as eqxi import jax import jax.lax as lax import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu import lineax as lx import numpy as np from equinox.internal import ω @ft.cache def _construct_matrix_impl( getkey, tags, size, dtype, cond_or_singular: int | float | str, i: int ): del i # used to break the cache while True: matrix = jr.normal(getkey(), (size, size), dtype=dtype) if isinstance(cond_or_singular, str): if cond_or_singular == "zero": matrix = matrix.at[0, :].set(0) elif cond_or_singular == "trim_row": matrix = matrix[1:, :] elif cond_or_singular == "trim_col": matrix = matrix[:, 1:] if tags != (): assert ( isinstance(cond_or_singular, (int, float)) or cond_or_singular == "zero" ) if has_tag(tags, lx.diagonal_tag): matrix = jnp.diag(jnp.diag(matrix)) if has_tag(tags, lx.symmetric_tag): matrix = matrix + matrix.T if has_tag(tags, lx.lower_triangular_tag): matrix = jnp.tril(matrix) if has_tag(tags, lx.upper_triangular_tag): matrix = jnp.triu(matrix) if has_tag(tags, lx.unit_diagonal_tag): matrix = matrix.at[jnp.arange(size), jnp.arange(size)].set(1) if has_tag(tags, lx.tridiagonal_tag): diagonal = jnp.diag(jnp.diag(matrix)) upper_diagonal = jnp.diag(jnp.diag(matrix, k=1), k=1) lower_diagonal = jnp.diag(jnp.diag(matrix, k=-1), k=-1) matrix = lower_diagonal + diagonal + upper_diagonal if has_tag(tags, lx.positive_semidefinite_tag): matrix = matrix @ matrix.T.conj() if has_tag(tags, lx.negative_semidefinite_tag): matrix = -matrix @ matrix.T.conj() if isinstance(cond_or_singular, str): break else: if eqxi.unvmap_all(jnp.linalg.cond(matrix) < cond_or_singular): # pyright: ignore break return matrix def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): if isinstance(solver, lx.Normal): cond_cutoff = math.sqrt(1000) else: cond_cutoff = 1000 return tuple( _construct_matrix_impl(getkey, tags, size, dtype, cond_cutoff, i) for i in range(num) ) def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64): if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)): singular_method = "zero" else: # Use `getkey()` rather than the stdlib `random.choice` for reproducibility singular_method = ["zero", "trim_row", "trim_col"][ jr.choice(getkey(), np.array([0, 1, 2])) ] size = 3 return tuple( _construct_matrix_impl(getkey, tags, size, dtype, singular_method, i) for i in range(num) ) def construct_poisson_matrix(size, dtype=jnp.float64): matrix = ( -2 * jnp.diag(jnp.ones(size, dtype=dtype)) + jnp.diag(jnp.ones(size - 1, dtype=dtype), 1) + jnp.diag(jnp.ones(size - 1, dtype=dtype), -1) ) return matrix if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-12 else: tol = 1e-6 solvers_tags_pseudoinverse = [ (lx.AutoLinearSolver(well_posed=True), (), False), (lx.AutoLinearSolver(well_posed=False), (), True), (lx.Triangular(), lx.lower_triangular_tag, False), (lx.Triangular(), lx.upper_triangular_tag, False), (lx.Triangular(), (lx.lower_triangular_tag, lx.unit_diagonal_tag), False), (lx.Triangular(), (lx.upper_triangular_tag, lx.unit_diagonal_tag), False), (lx.Diagonal(), lx.diagonal_tag, False), (lx.Diagonal(), (lx.diagonal_tag, lx.unit_diagonal_tag), False), (lx.Tridiagonal(), lx.tridiagonal_tag, False), (lx.LU(), (), False), (lx.QR(), (), False), (lx.SVD(), (), True), (lx.BiCGStab(rtol=tol, atol=tol), (), False), (lx.GMRES(rtol=tol, atol=tol), (), False), (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag, False), (lx.CG(rtol=tol, atol=tol), lx.negative_semidefinite_tag, False), (lx.Normal(lx.CG(rtol=tol, atol=tol)), (), False), (lx.LSMR(atol=tol, rtol=tol), (), True), (lx.Cholesky(), lx.positive_semidefinite_tag, False), (lx.Cholesky(), lx.negative_semidefinite_tag, False), (lx.Normal(lx.Cholesky()), (), False), ] solvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse] solvers = [a for a, _, _ in solvers_tags_pseudoinverse] pseudosolvers_tags = [(a, b) for a, b, c in solvers_tags_pseudoinverse if c] def _transpose(operator, matrix): return operator.T, matrix.T def _linearise(operator, matrix): return lx.linearise(operator), matrix def _materialise(operator, matrix): return lx.materialise(operator), matrix ops = (lambda x, y: (x, y), _transpose, _linearise, _materialise) def params(only_pseudo): for make_operator in make_operators: for solver, tags, pseudoinverse in solvers_tags_pseudoinverse: if only_pseudo and not pseudoinverse: continue if ( make_operator is make_trivial_diagonal_operator and tags != lx.diagonal_tag ): continue if make_operator is make_identity_operator and tags != lx.unit_diagonal_tag: continue if ( make_operator is make_tridiagonal_operator and tags != lx.tridiagonal_tag ): continue yield make_operator, solver, tags def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8): return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) def has_tag(tags, tag): return tag is tags or (isinstance(tags, tuple) and tag in tags) make_operators = [] def _operators_append(x): make_operators.append(x) return x @_operators_append def make_matrix_operator(getkey, matrix, tags): return lx.MatrixLinearOperator(matrix, tags) @_operators_append def make_trivial_pytree_operator(getkey, matrix, tags): out_size, _ = matrix.shape struct = jax.ShapeDtypeStruct((out_size,), matrix.dtype) return lx.PyTreeLinearOperator(matrix, struct, tags) @_operators_append def make_function_operator(getkey, matrix, tags): fn = lambda x: matrix @ x _, in_size = matrix.shape in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype) return lx.FunctionLinearOperator(fn, in_struct, tags) @_operators_append def make_jac_operator(getkey, matrix, tags): out_size, in_size = matrix.shape x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 return lx.JacobianLinearOperator(fn, x, None, tags) @_operators_append def make_jacfwd_operator(getkey, matrix, tags): out_size, in_size = matrix.shape x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac fn = lambda x, _: a + (b + diff) @ x + c @ x**2 return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd") @_operators_append def make_jacrev_operator(getkey, matrix, tags): """JacobianLinearOperator with jac='bwd' using a custom_vjp function. This uses custom_vjp so that forward-mode autodiff is NOT available, which tests that jac='bwd' works correctly without relying on JVP. """ out_size, in_size = matrix.shape x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None) diff = matrix - jac # Use custom_vjp to define a function that only has reverse-mode autodiff @jax.custom_vjp def custom_fn(x): return a + (b + diff) @ x + c @ x**2 def custom_fn_fwd(x): return custom_fn(x), x def custom_fn_bwd(x, g): # Jacobian is: (b + diff) + 2 * c * x # VJP is: g @ J = g @ ((b + diff) + 2 * c * x) # So J.T @ g = return ((b + diff).T @ g + 2 * (c.T @ g) * x,) custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd) fn = lambda x, _: custom_fn(x) return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd") @_operators_append def make_trivial_diagonal_operator(getkey, matrix, tags): assert tags == lx.diagonal_tag diag = jnp.diag(matrix) return lx.DiagonalLinearOperator(diag) @_operators_append def make_identity_operator(getkey, matrix, tags): in_struct = jax.ShapeDtypeStruct((matrix.shape[-1],), matrix.dtype) return lx.IdentityLinearOperator(input_structure=in_struct) @_operators_append def make_tridiagonal_operator(getkey, matrix, tags): diag1 = jnp.diag(matrix) if tags == lx.tridiagonal_tag: diag2 = jnp.diag(matrix, k=-1) diag3 = jnp.diag(matrix, k=1) return lx.TridiagonalLinearOperator(diag1, diag2, diag3) elif tags == lx.diagonal_tag: diag2 = diag3 = jnp.zeros(matrix.shape[0] - 1) return lx.TaggedLinearOperator( lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.diagonal_tag ) elif tags == lx.symmetric_tag: diag2 = diag3 = jnp.diag(matrix, k=1) return lx.TaggedLinearOperator( lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.symmetric_tag ) else: assert False, tags @_operators_append def make_add_operator(getkey, matrix, tags): matrix1 = 0.7 * matrix matrix2 = 0.3 * matrix operator = make_matrix_operator(getkey, matrix1, ()) + make_function_operator( getkey, matrix2, () ) return lx.TaggedLinearOperator(operator, tags) @_operators_append def make_mul_operator(getkey, matrix, tags): operator = make_jac_operator(getkey, 0.7 * matrix, ()) / 0.7 return lx.TaggedLinearOperator(operator, tags) @_operators_append def make_composed_operator(getkey, matrix, tags): _, size = matrix.shape diag = jr.normal(getkey(), (size,), dtype=matrix.dtype) diag = jnp.where(jnp.abs(diag) < 0.05, 0.8, diag) operator1 = make_trivial_pytree_operator(getkey, matrix / diag[None], ()) operator2 = lx.DiagonalLinearOperator(diag) return lx.TaggedLinearOperator(operator1 @ operator2, tags) # Slightly sketchy approach to finite differences, in that this is pulled out of # Numerical Recipes. # I also don't know of a handling of the JVP case off the top of my head -- although # I'm sure it exists somewhere -- so I'm improvising a little here. (In particular # removing the usual "(x + h) - x" denominator.) def finite_difference_jvp(fn, primals, tangents): out = fn(*primals) # Choose ε to trade-off truncation error and floating-point rounding error. max_leaves = [jnp.max(jnp.abs(p)) for p in jtu.tree_leaves(primals)] + [1] scale = jnp.max(jnp.stack(max_leaves)) ε = np.sqrt(np.finfo(np.float64).eps) * scale with jax.numpy_dtype_promotion("standard"): primals_ε = (ω(primals) + ε * ω(tangents)).ω out_ε = fn(*primals_ε) tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out) return out, tangents_out def jvp_jvp_impl( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None if (make_matrix is construct_matrix) or pseudoinverse: matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix( getkey, solver, tags, num=4, dtype=dtype ) make_op = ft.partial(make_operator, getkey) t_make_operator = lambda p, t_p: eqx.filter_jvp( make_op, (p, tags), (t_p, t_tags) ) tt_make_operator = lambda p, t_p, tt_p, tt_t_p: eqx.filter_jvp( t_make_operator, (p, t_p), (tt_p, tt_t_p) ) (operator, t_operator), (tt_operator, tt_t_operator) = tt_make_operator( matrix, t_matrix, tt_matrix, tt_t_matrix ) out_size, _ = matrix.shape vec = jr.normal(getkey(), (out_size,), dtype=dtype) t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype) tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if use_state: def linear_solve1(operator, vector): op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array) stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static) state = solver.init(stopped_operator, options={}) sol = lx.linear_solve(operator, vector, state=state, solver=solver) return sol.value else: def linear_solve1(operator, vector): sol = lx.linear_solve(operator, vector, solver=solver) return sol.value if pseudoinverse: jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0] # pyright: ignore else: jnp_solve1 = jnp.linalg.solve # pyright: ignore linear_solve2 = ft.partial(eqx.filter_jvp, linear_solve1) jnp_solve2 = ft.partial(eqx.filter_jvp, jnp_solve1) def _make_primal_tangents(mode): lx_args = ([], [], operator, t_operator, tt_operator, tt_t_operator) jnp_args = ([], [], matrix, t_matrix, tt_matrix, tt_t_matrix) for primals, ttangents, op, t_op, tt_op, tt_t_op in (lx_args, jnp_args): if "op" in mode: primals.append(op) ttangents.append(tt_op) if "vec" in mode: primals.append(vec) ttangents.append(tt_vec) if "t_op" in mode: primals.append(t_op) ttangents.append(tt_t_op) if "t_vec" in mode: primals.append(t_vec) ttangents.append(tt_t_vec) lx_out = tuple(lx_args[0]), tuple(lx_args[1]) jnp_out = tuple(jnp_args[0]), tuple(jnp_args[1]) return lx_out, jnp_out modes = ( {"op"}, {"vec"}, {"t_op"}, {"t_vec"}, {"op", "vec"}, {"op", "t_op"}, {"op", "t_vec"}, {"vec", "t_op"}, {"vec", "t_vec"}, {"op", "vec", "t_op"}, {"op", "vec", "t_vec"}, {"vec", "t_op", "t_vec"}, {"op", "vec", "t_op", "t_vec"}, ) for mode in modes: if mode == {"op"}: linear_solve3 = lambda op: linear_solve2((op, vec), (t_operator, t_vec)) jnp_solve3 = lambda mat: jnp_solve2((mat, vec), (t_matrix, t_vec)) elif mode == {"vec"}: linear_solve3 = lambda v: linear_solve2( (operator, v), (t_operator, t_vec) ) jnp_solve3 = lambda v: jnp_solve2((matrix, v), (t_matrix, t_vec)) elif mode == {"op", "vec"}: linear_solve3 = lambda op, v: linear_solve2( (op, v), (t_operator, t_vec) ) jnp_solve3 = lambda mat, v: jnp_solve2((mat, v), (t_matrix, t_vec)) elif mode == {"t_op"}: linear_solve3 = lambda t_op: linear_solve2( (operator, vec), (t_op, t_vec) ) jnp_solve3 = lambda t_mat: jnp_solve2((matrix, vec), (t_mat, t_vec)) elif mode == {"t_vec"}: linear_solve3 = lambda t_v: linear_solve2( (operator, vec), (t_operator, t_v) ) jnp_solve3 = lambda t_v: jnp_solve2((matrix, vec), (t_matrix, t_v)) elif mode == {"op", "vec"}: linear_solve3 = lambda op, v: linear_solve2( (op, v), (t_operator, t_vec) ) jnp_solve3 = lambda mat, v: jnp_solve2((mat, v), (t_matrix, t_vec)) elif mode == {"op", "t_op"}: linear_solve3 = lambda op, t_op: linear_solve2((op, vec), (t_op, t_vec)) jnp_solve3 = lambda mat, t_mat: jnp_solve2((mat, vec), (t_mat, t_vec)) elif mode == {"op", "t_vec"}: linear_solve3 = lambda op, t_v: linear_solve2( (op, vec), (t_operator, t_v) ) jnp_solve3 = lambda mat, t_v: jnp_solve2((mat, vec), (t_matrix, t_v)) elif mode == {"vec", "t_op"}: linear_solve3 = lambda v, t_op: linear_solve2( (operator, v), (t_op, t_vec) ) jnp_solve3 = lambda v, t_mat: jnp_solve2((matrix, v), (t_mat, t_vec)) elif mode == {"vec", "t_vec"}: linear_solve3 = lambda v, t_v: linear_solve2( (operator, v), (t_operator, t_v) ) jnp_solve3 = lambda v, t_v: jnp_solve2((matrix, v), (t_matrix, t_v)) elif mode == {"op", "vec", "t_op"}: linear_solve3 = lambda op, v, t_op: linear_solve2( (op, v), (t_op, t_vec) ) jnp_solve3 = lambda mat, v, t_mat: jnp_solve2((mat, v), (t_mat, t_vec)) elif mode == {"op", "vec", "t_vec"}: linear_solve3 = lambda op, v, t_v: linear_solve2( (op, v), (t_operator, t_v) ) jnp_solve3 = lambda mat, v, t_v: jnp_solve2((mat, v), (t_matrix, t_v)) elif mode == {"vec", "t_op", "t_vec"}: linear_solve3 = lambda v, t_op, t_v: linear_solve2( (operator, v), (t_op, t_v) ) jnp_solve3 = lambda v, t_mat, t_v: jnp_solve2((matrix, v), (t_mat, t_v)) elif mode == {"op", "vec", "t_op", "t_vec"}: linear_solve3 = lambda op, v, t_op, t_v: linear_solve2( (op, v), (t_op, t_v) ) jnp_solve3 = lambda mat, v, t_mat, t_v: jnp_solve2( (mat, v), (t_mat, t_v) ) else: assert False linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3) linear_solve3 = eqx.filter_jit(linear_solve3) jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve3) jnp_solve3 = eqx.filter_jit(jnp_solve3) (primal, tangent), (jnp_primal, jnp_tangent) = _make_primal_tangents(mode) (out, t_out), (minus_out, tt_out) = linear_solve3(primal, tangent) (true_out, true_t_out), (minus_true_out, true_tt_out) = jnp_solve3( jnp_primal, jnp_tangent ) assert tree_allclose(out, true_out, atol=1e-4) assert tree_allclose(t_out, true_t_out, atol=1e-4) assert tree_allclose(tt_out, true_tt_out, atol=1e-4) assert tree_allclose(minus_out, minus_true_out, atol=1e-4) ================================================ FILE: tests/test_adjoint.py ================================================ import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from lineax import FunctionLinearOperator from .helpers import ( make_identity_operator, make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, tree_allclose, ) @pytest.mark.parametrize("make_operator", make_operators) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_adjoint(make_operator, dtype, getkey): if ( make_operator is make_trivial_diagonal_operator or make_operator is make_identity_operator ): matrix = jnp.eye(4, dtype=dtype) tags = lx.diagonal_tag in_size = out_size = 4 elif make_operator is make_tridiagonal_operator: matrix = jnp.eye(4, dtype=dtype) tags = lx.tridiagonal_tag in_size = out_size = 4 else: matrix = jr.normal(getkey(), (3, 5), dtype=dtype) tags = () in_size = 5 out_size = 3 if make_operator is make_jacrev_operator and dtype is jnp.complex128: # JacobianLinearOperator does not support complex dtypes when jac="bwd" return operator = make_operator(getkey, matrix, tags) v1, v2 = ( jr.normal(getkey(), (in_size,), dtype=dtype), jr.normal(getkey(), (out_size,), dtype=dtype), ) inner1 = operator.mv(v1) @ v2.conj() adjoint_op1 = lx.conj(operator).transpose() ov2 = adjoint_op1.mv(v2) inner2 = v1 @ ov2.conj() assert tree_allclose(inner1, inner2) adjoint_op2 = lx.conj(operator.transpose()) ov2 = adjoint_op2.mv(v2) inner2 = v1 @ ov2.conj() assert tree_allclose(inner1, inner2) def test_functional_pytree_adjoint(): def fn(y): return {"b": y["a"]} y_struct = jax.eval_shape(lambda: {"a": 0.0}) operator = FunctionLinearOperator(fn, y_struct) conj_operator = lx.conj(operator) assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator)) def test_functional_pytree_adjoint_complex(): def fn(y): return {"b": y["a"]} y_struct = jax.eval_shape(lambda: {"a": 0.0j}) operator = FunctionLinearOperator(fn, y_struct) conj_operator = lx.conj(operator) assert tree_allclose(lx.materialise(conj_operator), lx.materialise(operator)) if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-12 else: tol = 1e-6 @pytest.mark.parametrize( "solver", [ # in theory only 1 iteration is needed, but stopping criteria are # complicated, see gh #160 lx.GMRES(tol, tol, max_steps=4, restart=1), lx.BiCGStab(tol, tol, max_steps=3), lx.Normal(lx.CG(tol, tol, max_steps=4)), lx.CG(tol, tol, max_steps=3), ], ) def test_preconditioner_adjoint(solver): """Test for fix to gh #160""" # Nonsymmetric poorly conditioned matrix. Without preconditioning, # this would take 20+ iterations (100s for GMRES) key = jax.random.key(123) key, subkey = jax.random.split(key) A = jax.random.uniform(key, (10, 10)) A += jnp.diag(jnp.arange(A.shape[0]) ** 6).astype(A.dtype) b = jax.random.uniform(subkey, (A.shape[0],)) if isinstance(solver, lx.CG): A = A.T @ A tags = (lx.positive_semidefinite_tag,) else: tags = () A = lx.MatrixLinearOperator(A, tags=tags) # exact inverse, should only take ~1 iteration M = lx.MatrixLinearOperator( jnp.linalg.inv(A.matrix), tags=tags, ) def solve(b): out = lx.linear_solve( A, b, solver=solver, options={"preconditioner": M}, throw=True ) return out.value # if they don't converge then this will throw an error _ = solve(b) A1 = jax.jacfwd(solve)(b) A2 = jax.jacrev(solve)(b) # we also do a sanity check, dx/db should give A^{-1} assert tree_allclose(A1, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol) assert tree_allclose(A2, jnp.linalg.inv(A.matrix), atol=tol, rtol=tol) ================================================ FILE: tests/test_invert.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax import jax.numpy as jnp import jax.random as jr import lineax as lx from .helpers import tree_allclose def _well_conditioned_matrix(getkey, size=3, dtype=jnp.float64): """Generate a well-conditioned random matrix.""" while True: matrix = jr.normal(getkey(), (size, size), dtype=dtype) if jnp.linalg.cond(matrix) < 100: return matrix def _well_conditioned_psd_matrix(getkey, size=3, dtype=jnp.float64): """Generate a well-conditioned PSD matrix.""" matrix = _well_conditioned_matrix(getkey, size, dtype) return matrix @ matrix.T.conj() # -- Core behaviour -- def test_mv(getkey): """invert(A).mv(v) solves A x = v.""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) inv_op = lx.invert(op) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = inv_op.mv(vec) expected = jnp.linalg.solve(matrix, vec) assert tree_allclose(result, expected, atol=1e-10) def test_composition_identity(getkey): """(invert(A) @ A).mv(v) ~ v.""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) inv_op = lx.invert(op) composed = inv_op @ op vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = composed.mv(vec) assert tree_allclose(result, vec, atol=1e-10) def test_double_inverse(getkey): """invert(invert(A)).mv(v) ~ A.mv(v).""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) double_inv = lx.invert(lx.invert(op)) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = double_inv.mv(vec) expected = matrix @ vec assert tree_allclose(result, expected, atol=1e-8) # -- Pseudoinverse (non-square) -- def test_pseudoinverse_overdetermined(getkey): """invert of a tall matrix gives the least-squares pseudoinverse.""" matrix = jr.normal(getkey(), (5, 3), dtype=jnp.float64) op = lx.MatrixLinearOperator(matrix) pinv_op = lx.invert(op, solver=lx.AutoLinearSolver(well_posed=False)) vec = jr.normal(getkey(), (5,), dtype=jnp.float64) result = pinv_op.mv(vec) expected = jnp.linalg.lstsq(matrix, vec)[0] assert tree_allclose(result, expected, atol=1e-8) def test_pseudoinverse_underdetermined(getkey): """invert of a wide matrix gives the minimum-norm pseudoinverse.""" matrix = jr.normal(getkey(), (3, 5), dtype=jnp.float64) op = lx.MatrixLinearOperator(matrix) pinv_op = lx.invert(op, solver=lx.AutoLinearSolver(well_posed=False)) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = pinv_op.mv(vec) expected = jnp.linalg.lstsq(matrix, vec)[0] assert tree_allclose(result, expected, atol=1e-8) # -- Explicit solver tests -- def test_solver_cholesky(getkey): """Works with Cholesky solver for PSD matrices.""" matrix = _well_conditioned_psd_matrix(getkey) op = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) inv_op = lx.invert(op, solver=lx.Cholesky()) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = inv_op.mv(vec) expected = jnp.linalg.solve(matrix, vec) assert tree_allclose(result, expected, atol=1e-10) def test_solver_cg(getkey): """Works with CG (iterative) solver for PSD matrices.""" matrix = _well_conditioned_psd_matrix(getkey) op = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag) inv_op = lx.invert(op, solver=lx.CG(rtol=1e-12, atol=1e-12)) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) result = inv_op.mv(vec) expected = jnp.linalg.solve(matrix, vec) assert tree_allclose(result, expected, atol=1e-8) # -- vmap -- def test_vmap(getkey): """vmap over invert(A).mv works correctly.""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) inv_op = lx.invert(op) vecs = jr.normal(getkey(), (5, 3), dtype=jnp.float64) result = jax.vmap(inv_op.mv)(vecs) expected = jax.vmap(lambda v: jnp.linalg.solve(matrix, v))(vecs) assert tree_allclose(result, expected, atol=1e-10) # -- AD -- def test_grad_wrt_vector(getkey): """VJP through invert(A).mv(v) wrt vector.""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) inv_op = lx.invert(op) def f(vec): return jnp.sum(inv_op.mv(vec)) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) grad = jax.grad(f)(vec) expected = jnp.linalg.solve(matrix.T, jnp.ones(3, dtype=jnp.float64)) assert tree_allclose(grad, expected, atol=1e-10) def test_jvp_wrt_vector(getkey): """JVP through invert(A).mv(v) wrt vector.""" matrix = _well_conditioned_matrix(getkey) op = lx.MatrixLinearOperator(matrix) inv_op = lx.invert(op) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) t_vec = jr.normal(getkey(), (3,), dtype=jnp.float64) primals, tangents = eqx.filter_jvp(inv_op.mv, (vec,), (t_vec,)) expected_primals = jnp.linalg.solve(matrix, vec) expected_tangents = jnp.linalg.solve(matrix, t_vec) assert tree_allclose(primals, expected_primals, atol=1e-10) assert tree_allclose(tangents, expected_tangents, atol=1e-10) def test_grad_wrt_operator(getkey): """VJP through invert(A).mv(v) wrt the inner matrix.""" matrix = _well_conditioned_matrix(getkey) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) def f_inv(mat): op = lx.MatrixLinearOperator(mat) inv_op = lx.invert(op) return jnp.sum(inv_op.mv(vec)) def f_jnp(mat): return jnp.sum(jnp.linalg.solve(mat, vec)) grad_inv = jax.grad(f_inv)(matrix) grad_jnp = jax.grad(f_jnp)(matrix) assert tree_allclose(grad_inv, grad_jnp, atol=1e-8) def test_jvp_wrt_operator(getkey): """JVP through invert(A).mv(v) wrt the inner matrix.""" matrix = _well_conditioned_matrix(getkey) t_matrix = jr.normal(getkey(), (3, 3), dtype=jnp.float64) vec = jr.normal(getkey(), (3,), dtype=jnp.float64) def f_inv(mat): op = lx.MatrixLinearOperator(mat) inv_op = lx.invert(op) return inv_op.mv(vec) def f_jnp(mat): return jnp.linalg.solve(mat, vec) out, t_out = eqx.filter_jvp(f_inv, (matrix,), (t_matrix,)) expected_out, expected_t_out = eqx.filter_jvp(f_jnp, (matrix,), (t_matrix,)) assert tree_allclose(out, expected_out, atol=1e-10) assert tree_allclose(t_out, expected_t_out, atol=1e-8) ================================================ FILE: tests/test_jvp.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import equinox as eqx import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_matrix, construct_singular_matrix, finite_difference_jvp, has_tag, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, tree_allclose, ) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( "make_matrix", ( construct_matrix, construct_singular_matrix, ), ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_jvp( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None if (make_matrix is construct_matrix) or pseudoinverse: matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) out_size, _ = matrix.shape vec = jr.normal(getkey(), (out_size,), dtype=dtype) t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if has_tag(tags, lx.unit_diagonal_tag): # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} # still satisfies the tag itself. # This is the exception. t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0) make_op = ft.partial(make_operator, getkey) operator, t_operator = eqx.filter_jvp( make_op, (matrix, tags), (t_matrix, t_tags) ) if use_state: state = solver.init(operator, options={}) linear_solve = ft.partial(lx.linear_solve, state=state) else: linear_solve = lx.linear_solve solve_vec_only = lambda v: linear_solve(operator, v, solver).value solve_op_only = lambda op: linear_solve(op, vec, solver).value solve_op_vec = lambda op, v: linear_solve(op, v, solver).value vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,)) op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,)) op_vec_out, t_op_vec_out = eqx.filter_jvp( solve_op_vec, (operator, vec), (t_operator, t_vec), ) (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp( lambda op: jnp.linalg.lstsq(op, vec), # pyright: ignore (matrix,), (t_matrix,), ) (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp( jnp.linalg.lstsq, (matrix, vec), (t_matrix, t_vec), # pyright: ignore ) # Work around JAX issue #14868. if jnp.any(jnp.isnan(t_expected_op_out)): _, (t_expected_op_out, *_) = finite_difference_jvp( lambda op: jnp.linalg.lstsq(op, vec), # pyright: ignore (matrix,), (t_matrix,), ) if jnp.any(jnp.isnan(t_expected_op_vec_out)): _, (t_expected_op_vec_out, *_) = finite_difference_jvp( jnp.linalg.lstsq, (matrix, vec), (t_matrix, t_vec), # pyright: ignore ) pinv_matrix = jnp.linalg.pinv(matrix) # pyright: ignore expected_vec_out = pinv_matrix @ vec assert tree_allclose(vec_out, expected_vec_out) assert tree_allclose(op_out, expected_op_out) assert tree_allclose(op_vec_out, expected_op_vec_out) t_expected_vec_out = pinv_matrix @ t_vec assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3) assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3) assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3) ================================================ FILE: tests/test_jvp_jvp1.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax.numpy as jnp import pytest from .helpers import ( construct_matrix, construct_singular_matrix, jvp_jvp_impl, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, ) # Workaround for https://github.com/jax-ml/jax/issues/27201 @pytest.fixture(autouse=True) def _clear_cache(): eqx.clear_caches() @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix)) @pytest.mark.parametrize("dtype", (jnp.float64,)) def test_jvp_jvp( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): jvp_jvp_impl( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype, ) ================================================ FILE: tests/test_jvp_jvp2.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax.numpy as jnp import pytest from .helpers import ( construct_matrix, construct_singular_matrix, jvp_jvp_impl, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, ) # Workaround for https://github.com/jax-ml/jax/issues/27201 @pytest.fixture(autouse=True) def _clear_cache(): eqx.clear_caches() @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix)) @pytest.mark.parametrize("dtype", (jnp.complex128,)) def test_jvp_jvp( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): jvp_jvp_impl( getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype, ) ================================================ FILE: tests/test_lsmr.py ================================================ import equinox as ex import jax.numpy as jnp import lineax as lx import pytest solver = lx.LSMR(1e-10, 1e-10) Aill = lx.DiagonalLinearOperator(jnp.array([1e8, 1e6, 1e4, 1e2, 1])) Awell = lx.DiagonalLinearOperator(jnp.array([2.0, 4.0, 5.0, 8.0, 10.0])) Asing = lx.DiagonalLinearOperator(jnp.array([0.0, 4.0, 5.0, 8.0, 10.0])) def test_ill_conditioned(): try: lx.linear_solve(Aill, jnp.ones(5), solver=solver) except ex.EquinoxRuntimeError as e: assert "Condition number" in str(e) def test_zero_rhs(): # b=0, so x=0 is solution sol = lx.linear_solve(Aill, jnp.zeros(5), solver=solver) assert (sol.value == 0).all() sol = lx.linear_solve(Awell, jnp.zeros(5), solver=solver) assert (sol.value == 0).all() sol = lx.linear_solve(Asing, jnp.zeros(5), solver=solver) assert (sol.value == 0).all() # b lies in null space of A, so x=0 is minimum norm solution sol = lx.linear_solve(Asing, jnp.zeros(5).at[0].set(1), solver=solver) assert (sol.value == 0).all() @pytest.mark.skip("Damp support is disabled.") def test_damp_regularizes(): solution_ill = lx.linear_solve(Aill, jnp.ones(5), solver=solver, options={}) assert solution_ill.stats["istop"] == 1 solution_damped = lx.linear_solve( Aill, jnp.ones(5), solver=solver, options={"damp": 100.0} ) assert solution_damped.stats["istop"] == 2 assert solution_damped.stats["num_steps"] < solution_ill.stats["num_steps"] @pytest.mark.skip("Damp support is disabled.") def test_damp(): solution_damped = lx.linear_solve( Awell, jnp.ones(5), solver=solver, options={"damp": 1.0} ) assert jnp.allclose( solution_damped.value, jnp.array([0.4, 0.23529412, 0.19230769, 0.12307692, 0.0990099]), ) solution_damped = lx.linear_solve( Awell, jnp.ones(5), solver=solver, options={"damp": 1000.0} ) assert jnp.allclose( solution_damped.value, jnp.array([2e-6, 4e-6, 5e-6, 8e-6, 10.0e-6]) ) ================================================ FILE: tests/test_misc.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp import lineax as lx import lineax._misc as lx_misc import pytest def test_inexact_asarray_no_copy(): x = jnp.array([1.0]) assert lx_misc.inexact_asarray(x) is x y = jnp.array([1.0, 2.0]) assert jax.vmap(lx_misc.inexact_asarray)(y) is y # See JAX issue #15676 def test_inexact_asarray_jvp(): p, t = jax.jvp(lx_misc.inexact_asarray, (1.0,), (2.0,)) assert type(p) is not float assert type(t) is not float @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_zero_matrix(dtype): A = lx.MatrixLinearOperator(jnp.zeros((2, 2), dtype=dtype)) b = jnp.array([1.0, 2.0], dtype=dtype) lx.linear_solve(A, b, lx.SVD()) ================================================ FILE: tests/test_norm.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.flatten_util as jfu import jax.numpy as jnp import lineax.internal as lxi from .helpers import tree_allclose def _square(x): return x * jnp.conj(x) def _two_norm(x): return jnp.sqrt(jnp.sum(_square(jfu.ravel_pytree(x)[0]))).real def _rms_norm(x): return jnp.sqrt(jnp.mean(_square(jfu.ravel_pytree(x)[0]))).real def _max_norm(x): return jnp.max(jnp.abs(jfu.ravel_pytree(x)[0])) def test_nonzero(): zero = [jnp.array(0.0), jnp.zeros((2, 2))] x = [jnp.array(1.0), jnp.arange(4.0).reshape(2, 2)] tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)] two = lxi.two_norm(x) rms = lxi.rms_norm(x) max = lxi.max_norm(x) true_two = _two_norm(x) true_rms = _rms_norm(x) true_max = _max_norm(x) assert jnp.allclose(two, true_two) assert jnp.allclose(rms, true_rms) assert jnp.allclose(max, true_max) two_jvp = jax.jvp(lxi.two_norm, (x,), (tx,)) true_two_jvp = jax.jvp(_two_norm, (x,), (tx,)) rms_jvp = jax.jvp(lxi.rms_norm, (x,), (tx,)) true_rms_jvp = jax.jvp(_rms_norm, (x,), (tx,)) max_jvp = jax.jvp(lxi.max_norm, (x,), (tx,)) true_max_jvp = jax.jvp(_max_norm, (x,), (tx,)) assert tree_allclose(two_jvp, true_two_jvp) assert tree_allclose(rms_jvp, true_rms_jvp) assert tree_allclose(max_jvp, true_max_jvp) two0_jvp = jax.jvp(lxi.two_norm, (x,), (zero,)) rms0_jvp = jax.jvp(lxi.rms_norm, (x,), (zero,)) max0_jvp = jax.jvp(lxi.max_norm, (x,), (zero,)) assert tree_allclose(two0_jvp, (true_two, jnp.array(0.0))) assert tree_allclose(rms0_jvp, (true_rms, jnp.array(0.0))) assert tree_allclose(max0_jvp, (true_max, jnp.array(0.0))) def test_zero(): zero = [jnp.array(0.0), jnp.zeros((2, 2))] tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)] for t in (zero, tx): two0 = jax.jvp(lxi.two_norm, (zero,), (t,)) rms0 = jax.jvp(lxi.rms_norm, (zero,), (t,)) max0 = jax.jvp(lxi.max_norm, (zero,), (t,)) true0 = (jnp.array(0.0), jnp.array(0.0)) assert tree_allclose(two0, true0) assert tree_allclose(rms0, true0) assert tree_allclose(max0, true0) def test_complex(): x = jnp.array([3 + 1.2j, -0.5 + 4.9j]) tx = jnp.array([2 - 0.3j, -0.7j]) two = jax.jvp(lxi.two_norm, (x,), (tx,)) true_two = jax.jvp(_two_norm, (x,), (tx,)) rms = jax.jvp(lxi.rms_norm, (x,), (tx,)) true_rms = jax.jvp(_rms_norm, (x,), (tx,)) max = jax.jvp(lxi.max_norm, (x,), (tx,)) true_max = jax.jvp(_max_norm, (x,), (tx,)) assert two[0].imag == 0 assert tree_allclose(two, true_two) assert rms[0].imag == 0 assert tree_allclose(rms, true_rms) assert max[0].imag == 0 assert tree_allclose(max, true_max) def test_size_zero(): zero = jnp.array(0.0) for x in (jnp.array([]), [jnp.array([]), jnp.array([])]): assert tree_allclose(lxi.two_norm(x), zero) assert tree_allclose(lxi.rms_norm(x), zero) assert tree_allclose(lxi.max_norm(x), zero) assert tree_allclose(jax.jvp(lxi.two_norm, (x,), (x,)), (zero, zero)) assert tree_allclose(jax.jvp(lxi.rms_norm, (x,), (x,)), (zero, zero)) assert tree_allclose(jax.jvp(lxi.max_norm, (x,), (x,)), (zero, zero)) ================================================ FILE: tests/test_operator.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import cast import equinox as eqx import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( make_identity_operator, make_jacrev_operator, make_operators, make_tridiagonal_operator, make_trivial_diagonal_operator, tree_allclose, ) @pytest.mark.parametrize("make_operator", make_operators) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_ops(make_operator, getkey, dtype): if ( make_operator is make_trivial_diagonal_operator or make_operator is make_identity_operator ): matrix = jnp.eye(3, dtype=dtype) tags = lx.diagonal_tag elif make_operator is make_tridiagonal_operator: matrix = jnp.eye(3, dtype=dtype) tags = lx.tridiagonal_tag else: matrix = jr.normal(getkey(), (3, 3), dtype=dtype) tags = () if make_operator is make_jacrev_operator and dtype is jnp.complex128: # JacobianLinearOperator does not support complex dtypes when jac="bwd" return matrix1 = make_operator(getkey, matrix, tags) matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) scalar = jr.normal(getkey(), (), dtype=dtype) add = matrix1 + matrix2 composed = matrix1 @ matrix2 mul = matrix1 * scalar rmul = cast(lx.AbstractLinearOperator, scalar * matrix1) div = matrix1 / scalar vec = jr.normal(getkey(), (3,), dtype=dtype) assert tree_allclose(matrix1.mv(vec) + matrix2.mv(vec), add.mv(vec)) assert tree_allclose(matrix1.mv(matrix2.mv(vec)), composed.mv(vec)) scalar_matvec = scalar * matrix1.mv(vec) assert tree_allclose(scalar_matvec, mul.mv(vec)) assert tree_allclose(scalar_matvec, rmul.mv(vec)) assert tree_allclose(matrix1.mv(vec) / scalar, div.mv(vec)) add_matrix = matrix1.as_matrix() + matrix2.as_matrix() composed_matrix = matrix1.as_matrix() @ matrix2.as_matrix() mul_matrix = scalar * matrix1.as_matrix() div_matrix = matrix1.as_matrix() / scalar assert tree_allclose(add_matrix, add.as_matrix()) assert tree_allclose(composed_matrix, composed.as_matrix()) assert tree_allclose(mul_matrix, mul.as_matrix()) assert tree_allclose(mul_matrix, rmul.as_matrix()) assert tree_allclose(div_matrix, div.as_matrix()) assert tree_allclose(add_matrix.T, add.T.as_matrix()) assert tree_allclose(composed_matrix.T, composed.T.as_matrix()) assert tree_allclose(mul_matrix.T, mul.T.as_matrix()) assert tree_allclose(mul_matrix.T, rmul.T.as_matrix()) assert tree_allclose(div_matrix.T, div.T.as_matrix()) @pytest.mark.parametrize("make_operator", make_operators) def test_structures_vector(make_operator, getkey): if ( make_operator is make_trivial_diagonal_operator or make_operator is make_identity_operator ): matrix = jnp.eye(4) tags = lx.diagonal_tag in_size = out_size = 4 elif make_operator is make_tridiagonal_operator: matrix = jnp.eye(4) tags = lx.tridiagonal_tag in_size = out_size = 4 else: matrix = jr.normal(getkey(), (3, 5)) tags = () in_size = 5 out_size = 3 operator = make_operator(getkey, matrix, tags) in_structure = jax.ShapeDtypeStruct((in_size,), jnp.float64) out_structure = jax.ShapeDtypeStruct((out_size,), jnp.float64) assert tree_allclose(in_structure, operator.in_structure()) assert tree_allclose(out_structure, operator.out_structure()) def _setup(getkey, matrix, tag: object | frozenset[object] = frozenset()): for make_operator in make_operators: if make_operator is make_trivial_diagonal_operator and tag != lx.diagonal_tag: continue if make_operator is make_tridiagonal_operator and tag not in ( lx.tridiagonal_tag, lx.diagonal_tag, lx.symmetric_tag, ): continue if make_operator is make_identity_operator and tag not in ( lx.tridiagonal_tag, lx.diagonal_tag, lx.symmetric_tag, ): continue operator = make_operator(getkey, matrix, tag) yield operator def _assert_except_diag(cond_fun, operators, flip_cond): if flip_cond: _cond_fun = cond_fun cond_fun = lambda x: not _cond_fun(x) for operator in operators: if isinstance(operator, lx.DiagonalLinearOperator): assert not cond_fun(operator) else: assert cond_fun(operator) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_linearise(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) operators = list(_setup(getkey, matrix)) vec = jr.normal(getkey(), (3,), dtype=dtype) for operator in operators: # Skip jacrev operators with complex dtype (jacrev doesn't support complex) if ( isinstance(operator, lx.JacobianLinearOperator) and operator.jac == "bwd" and dtype is jnp.complex128 ): continue linearised = lx.linearise(operator) # Actually evaluate the linearised operator to ensure it works result = linearised.mv(vec) expected = operator.mv(vec) assert tree_allclose(result, expected) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_materialise(dtype, getkey): operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype)) for operator in operators: lx.materialise(operator) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_materialise_large(dtype, getkey): operators = _setup(getkey, jr.normal(getkey(), (200, 500), dtype=dtype)) for operator in operators: lx.materialise(operator) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) matrix_diag = jnp.diag(matrix) # test we properly extract diagonal from a dense matrix when not tagged operators = _setup(getkey, matrix) for operator in operators: assert jnp.allclose(lx.diagonal(operator), matrix_diag) # test we properly extract diagonal from diagonal matrix when tagged operators = _setup(getkey, jnp.diag(matrix_diag), lx.diagonal_tag) for operator in operators: if isinstance(operator, lx.IdentityLinearOperator): assert jnp.allclose(lx.diagonal(operator), jnp.ones(3)) else: assert jnp.allclose(lx.diagonal(operator), matrix_diag) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_tridiagonal(dtype, getkey): matrix = jr.normal(getkey(), (5, 5), dtype=dtype) matrix_diag = jnp.diag(matrix) matrix_lower_diag = jnp.diag(matrix, k=-1) matrix_upper_diag = jnp.diag(matrix, k=1) tridiag_matrix = ( jnp.diag(matrix_diag) + jnp.diag(matrix_lower_diag, k=-1) + jnp.diag(matrix_upper_diag, k=1) ) operators = _setup(getkey, tridiag_matrix, lx.tridiagonal_tag) for operator in operators: diag, lower_diag, upper_diag = lx.tridiagonal(operator) if isinstance(operator, lx.IdentityLinearOperator): assert jnp.allclose(diag, jnp.ones(5)) assert jnp.allclose(lower_diag, jnp.zeros(4)) assert jnp.allclose(upper_diag, jnp.zeros(4)) else: assert jnp.allclose(diag, matrix_diag) assert jnp.allclose(lower_diag, matrix_lower_diag) assert jnp.allclose(upper_diag, matrix_upper_diag) # Test ComposedLinearOperator: diagonal @ tridiagonal and tridiagonal @ diagonal random_diag = jr.normal(getkey(), (5,), dtype=dtype) tridiag_op = lx.TridiagonalLinearOperator( matrix_diag, matrix_lower_diag, matrix_upper_diag ) diag_op = lx.DiagonalLinearOperator(random_diag) # diagonal @ tridiagonal (row scaling) dt_matrix = jnp.matmul(jnp.diag(random_diag), tridiag_matrix) diag, lower_diag, upper_diag = lx.tridiagonal(diag_op @ tridiag_op) assert jnp.allclose(diag, jnp.diagonal(dt_matrix, 0)) assert jnp.allclose(lower_diag, jnp.diagonal(dt_matrix, -1)) assert jnp.allclose(upper_diag, jnp.diagonal(dt_matrix, 1)) # tridiagonal @ diagonal (column scaling) td_matrix = jnp.matmul(tridiag_matrix, jnp.diag(random_diag)) diag, lower_diag, upper_diag = lx.tridiagonal(tridiag_op @ diag_op) assert jnp.allclose(diag, jnp.diagonal(td_matrix, 0)) assert jnp.allclose(lower_diag, jnp.diagonal(td_matrix, -1)) assert jnp.allclose(upper_diag, jnp.diagonal(td_matrix, 1)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_symmetric(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) symmetric_operators = _setup(getkey, matrix.T @ matrix, lx.symmetric_tag) for operator in symmetric_operators: assert lx.is_symmetric(operator) not_symmetric_operators = _setup(getkey, matrix) _assert_except_diag(lx.is_symmetric, not_symmetric_operators, flip_cond=True) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) diagonal_operators = _setup(getkey, jnp.diag(jnp.diag(matrix)), lx.diagonal_tag) for operator in diagonal_operators: assert lx.is_diagonal(operator) not_diagonal_operators = _setup(getkey, matrix) _assert_except_diag(lx.is_diagonal, not_diagonal_operators, flip_cond=True) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_diagonal_scalar(dtype, getkey): matrix = jr.normal(getkey(), (1, 1), dtype=dtype) diagonal_operators = _setup(getkey, matrix) for operator in diagonal_operators: assert lx.is_diagonal(operator) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_diagonal_tridiagonal(dtype, getkey): diag1 = jr.normal(getkey(), (1,), dtype=dtype) diag2 = jnp.zeros((0,), dtype=dtype) op1 = lx.TridiagonalLinearOperator(diag1, diag2, diag2) assert lx.is_diagonal(op1) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_has_unit_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) not_unit_diagonal = _setup(getkey, matrix) for operator in not_unit_diagonal: assert not lx.has_unit_diagonal(operator) matrix_unit_diag = matrix.at[jnp.arange(3), jnp.arange(3)].set(1) unit_diagonal = _setup(getkey, matrix_unit_diag, lx.unit_diagonal_tag) _assert_except_diag(lx.has_unit_diagonal, unit_diagonal, flip_cond=False) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_lower_triangular(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) lower_triangular = _setup(getkey, jnp.tril(matrix), lx.lower_triangular_tag) for operator in lower_triangular: assert lx.is_lower_triangular(operator) not_lower_triangular = _setup(getkey, matrix) _assert_except_diag(lx.is_lower_triangular, not_lower_triangular, flip_cond=True) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_upper_triangular(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) upper_triangular = _setup(getkey, jnp.triu(matrix), lx.upper_triangular_tag) for operator in upper_triangular: assert lx.is_upper_triangular(operator) not_upper_triangular = _setup(getkey, matrix) _assert_except_diag(lx.is_upper_triangular, not_upper_triangular, flip_cond=True) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_positive_semidefinite(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) not_positive_semidefinite = _setup(getkey, matrix) for operator in not_positive_semidefinite: assert not lx.is_positive_semidefinite(operator) positive_semidefinite = _setup( getkey, matrix.T.conj() @ matrix, lx.positive_semidefinite_tag ) _assert_except_diag( lx.is_positive_semidefinite, positive_semidefinite, flip_cond=False ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_negative_semidefinite(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) not_negative_semidefinite = _setup(getkey, matrix) for operator in not_negative_semidefinite: assert not lx.is_negative_semidefinite(operator) negative_semidefinite = _setup( getkey, -matrix.T.conj() @ matrix, lx.negative_semidefinite_tag ) _assert_except_diag( lx.is_negative_semidefinite, negative_semidefinite, flip_cond=False ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_tridiagonal(dtype, getkey): diag1 = jr.normal(getkey(), (5,), dtype=dtype) diag2 = jr.normal(getkey(), (4,), dtype=dtype) diag3 = jr.normal(getkey(), (4,), dtype=dtype) op1 = lx.TridiagonalLinearOperator(diag1, diag2, diag3) op2 = lx.IdentityLinearOperator(jax.eval_shape(lambda: diag1)) op3 = lx.MatrixLinearOperator(jnp.diag(diag1)) assert lx.is_tridiagonal(op1) assert lx.is_tridiagonal(op2) assert not lx.is_tridiagonal(op3) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_tangent_as_matrix(dtype, getkey): def _list_setup(matrix): # Exclude jacrev operator: jac="bwd" uses custom_vjp which doesn't support JVP return [ op for op in _setup(getkey, matrix) if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == "bwd") ] matrix = jr.normal(getkey(), (3, 3), dtype=dtype) t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype) operators, t_operators = eqx.filter_jvp(_list_setup, (matrix,), (t_matrix,)) for operator, t_operator in zip(operators, t_operators): t_operator = lx.TangentLinearOperator(operator, t_operator) if isinstance(operator, lx.DiagonalLinearOperator): assert jnp.allclose(operator.as_matrix(), jnp.diag(jnp.diag(matrix))) assert jnp.allclose(t_operator.as_matrix(), jnp.diag(jnp.diag(t_matrix))) else: assert jnp.allclose(operator.as_matrix(), matrix) assert jnp.allclose(t_operator.as_matrix(), t_matrix) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_materialise_function_linear_operator(dtype, getkey): x = ( jr.normal(getkey(), (5, 9), dtype=dtype), jr.normal(getkey(), (3,), dtype=dtype), ) input_structure = jax.eval_shape(lambda: x) fn = lambda x: {"a": jnp.broadcast_to(jnp.sum(x[0]), (1, 2))} output_structure = jax.eval_shape(fn, input_structure) operator = lx.FunctionLinearOperator(fn, input_structure) materialised_operator = lx.materialise(operator) assert materialised_operator.in_structure() == input_structure assert materialised_operator.out_structure() == output_structure assert isinstance(materialised_operator, lx.PyTreeLinearOperator) expected_struct = { "a": ( jax.ShapeDtypeStruct((1, 2, 5, 9), dtype), jax.ShapeDtypeStruct((1, 2, 3), dtype), ) } assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_pytree_transpose(dtype, getkey): out_struct = jax.eval_shape( lambda: ({"a": jnp.zeros((2, 3, 3), dtype=dtype)}, jnp.zeros((2,), dtype=dtype)) ) in_struct = jax.eval_shape(lambda: {"b": jnp.zeros((4,), dtype=dtype)}) leaf1 = jr.normal(getkey(), (2, 3, 3, 4), dtype=dtype) leaf2 = jr.normal(getkey(), (2, 4), dtype=dtype) pytree = ({"a": {"b": leaf1}}, {"b": leaf2}) operator = lx.PyTreeLinearOperator(pytree, out_struct) assert operator.in_structure() == in_struct assert operator.out_structure() == out_struct leaf1_T = jnp.moveaxis(leaf1, -1, 0) leaf2_T = jnp.moveaxis(leaf2, -1, 0) pytree_T = {"b": ({"a": leaf1_T}, leaf2_T)} operator_T = operator.T assert operator_T.in_structure() == out_struct assert operator_T.out_structure() == in_struct assert eqx.tree_equal(operator_T.pytree, pytree_T) # pyright: ignore def test_diagonal_tangent(): diag = jnp.array([1.0, 2.0, 3.0]) t_diag = jnp.array([4.0, 5.0, 6.0]) def run(diag): op = lx.DiagonalLinearOperator(diag) out = lx.linear_solve(op, jnp.array([1.0, 1.0, 1.0]), solver=lx.Diagonal()) return out.value jax.jvp(run, (diag,), (t_diag,)) def test_identity_with_different_structures(): structure1 = ( jax.ShapeDtypeStruct((), jnp.float32), jax.ShapeDtypeStruct((2, 3), jnp.float16), ) structure2 = {"a": jax.ShapeDtypeStruct((5,), jnp.float32)} # structure3 = (None, jax.ShapeDtypeStruct((2, 3), jnp.float16)) op1 = lx.IdentityLinearOperator(structure1, structure2) op2 = lx.IdentityLinearOperator(structure2, structure1) # op3 = lx.IdentityLinearOperator(structure3, structure2) assert op1.T == op2 # assert op2.transpose((True, False)) == op3 assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.float32)) assert op1.in_size() == 7 assert op1.out_size() == 5 vec1 = ( jnp.array(1.0, dtype=jnp.float32), jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16), ) vec2 = {"a": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)} vec1b = ( jnp.array(1.0, dtype=jnp.float32), jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16), ) assert tree_allclose(op1.mv(vec1), vec2) assert tree_allclose(op2.mv(vec2), vec1b) def test_identity_with_different_structures_complex(): structure1 = ( jax.ShapeDtypeStruct((), jnp.complex128), jax.ShapeDtypeStruct((2, 3), jnp.float16), ) structure2 = {"a": jax.ShapeDtypeStruct((5,), jnp.complex128)} # structure3 = (None, jax.ShapeDtypeStruct((2, 3), jnp.float16)) op1 = lx.IdentityLinearOperator(structure1, structure2) op2 = lx.IdentityLinearOperator(structure2, structure1) # op3 = lx.IdentityLinearOperator(structure3, structure2) assert op1.T == op2 # assert op2.transpose((True, False)) == op3 assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.complex128)) assert op1.in_size() == 7 assert op1.out_size() == 5 vec1 = ( jnp.array(1.0, dtype=jnp.complex128), jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16), ) vec2 = {"a": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.complex128)} vec1b = ( jnp.array(1.0, dtype=jnp.complex128), jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16), ) assert tree_allclose(op1.mv(vec1), vec2) assert tree_allclose(op2.mv(vec2), vec1b) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_zero_pytree_as_matrix(dtype): a = jnp.array([], dtype=dtype).reshape(2, 1, 0, 2, 1, 0) struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype) op = lx.PyTreeLinearOperator(a, struct) assert op.as_matrix().shape == (0, 0) def test_jacrev_operator(): # Test that custom_vjp is respected. The custom backward multiplies by 3 # instead of the true derivative (which would be 2). # This tests that lineax uses the custom_vjp, not the true derivative. @jax.custom_vjp def f(x, _): return dict(foo=x["bar"] * 2) # forward: multiply by 2 def f_fwd(x, _): return f(x, None), None def f_bwd(_, g): # Custom backward: multiply by 3 (not the true derivative 2) # This must be linear in g for linear_transpose to work correctly. return dict(bar=g["foo"] * 3), None f.defvjp(f_fwd, f_bwd) x = dict(bar=jnp.arange(2.0)) rev_op = lx.JacobianLinearOperator(f, x, jac="bwd") # Jacobian is 3*I (from custom backward, not 2*I from true derivative) as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]]) assert tree_allclose(rev_op.as_matrix(), as_matrix) y = dict(bar=jnp.arange(2.0) + 1) # y = [1, 2] true_out = dict(foo=jnp.array([3.0, 6.0])) # 3*I @ [1, 2] = [3, 6] for op in (rev_op, lx.materialise(rev_op)): out = op.mv(y) assert tree_allclose(out, true_out) fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd") with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): fwd_op.mv(y) with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): lx.materialise(fwd_op) ================================================ FILE: tests/test_singular.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import equinox as eqx import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_singular_matrix, finite_difference_jvp, make_jac_operator, make_matrix_operator, ops, params, tol, tree_allclose, ) @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True)) @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_singular(make_operator, solver, tags, ops, getkey, dtype): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(getkey, matrix, tags) operator, matrix = ops(operator, matrix) assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) out_size, in_size = matrix.shape true_x = jr.normal(getkey(), (in_size,), dtype=dtype) b = matrix @ true_x x = lx.linear_solve(operator, b, solver=solver, throw=False).value jax_x, *_ = jnp.linalg.lstsq(matrix, b) # pyright: ignore assert tree_allclose(x, jax_x, atol=tol, rtol=tol) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_bicgstab_breakdown(getkey, dtype): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 solver = lx.GMRES(atol=tol, rtol=tol, restart=2) matrix = jr.normal(jr.PRNGKey(0), (100, 100), dtype=dtype) true_x = jr.normal(jr.PRNGKey(0), (100,), dtype=dtype) b = matrix @ true_x operator = lx.MatrixLinearOperator(matrix) # result != 0 implies lineax reported failure lx_soln = lx.linear_solve(operator, b, solver, throw=False) assert jnp.all(lx_soln.result != lx.RESULTS.successful) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_gmres_stagnation_or_breakdown(getkey, dtype): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 solver = lx.GMRES(atol=tol, rtol=tol, restart=2) matrix = jnp.array( [ [0.15892892, 0.05884365, -0.60427412, 0.1891916], [-1.5484863, 0.93608822, 1.94888868, 1.37069667], [0.62687318, -0.13996738, -0.6824359, 0.30975754], [-0.67428635, 1.52372255, -0.88277754, 0.69633816], ], dtype=dtype, ) true_x = jnp.array([0.51383273, 1.72983427, -0.43251078, -1.11764668], dtype=dtype) b = matrix @ true_x operator = lx.MatrixLinearOperator(matrix) # result != 0 implies lineax reported failure lx_soln = lx.linear_solve(operator, b, solver, throw=False) assert jnp.all(lx_soln.result != lx.RESULTS.successful) @pytest.mark.parametrize( "solver", ( lx.AutoLinearSolver(well_posed=None), lx.QR(), lx.SVD(), lx.LSMR(atol=tol, rtol=tol), lx.Normal(lx.Cholesky()), lx.Normal(lx.SVD()), ), ) def test_nonsquare_pytree_operator1(solver): x = [[1, 5.0, jnp.array(-1.0)], [jnp.array(-2), jnp.array(-2.0), 3.0]] y = [3.0, 4] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y, solver=solver).value matrix = jnp.array([[1.0, 5.0, -1.0], [-2.0, -2.0, 3.0]]) true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y)) # pyright: ignore true_out = [true_out[0], true_out[1], true_out[2]] assert tree_allclose(out, true_out) @pytest.mark.parametrize( "solver", ( lx.AutoLinearSolver(well_posed=None), lx.QR(), lx.SVD(), lx.LSMR(atol=tol, rtol=tol), lx.Normal(lx.Cholesky()), lx.Normal(lx.SVD()), ), ) def test_nonsquare_pytree_operator2(solver): x = [[1, jnp.array(-2)], [5.0, jnp.array(-2.0)], [jnp.array(-1.0), 3.0]] y = [3.0, 4, 5.0] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y, solver=solver).value matrix = jnp.array([[1.0, -2.0], [5.0, -2.0], [-1.0, 3.0]]) true_out, _, _, _ = jnp.linalg.lstsq(matrix, jnp.array(y)) # pyright: ignore true_out = [true_out[0], true_out[1]] assert tree_allclose(out, true_out) @pytest.mark.parametrize( "solver", ( lx.AutoLinearSolver(well_posed=None), lx.QR(), lx.SVD(), lx.Normal(lx.Cholesky()), lx.Normal(lx.SVD()), ), ) @pytest.mark.parametrize("full_rank", (True, False)) @pytest.mark.parametrize("jvp", (False, True)) @pytest.mark.parametrize("wide", (False, True)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_nonsquare_mat_vec(solver, full_rank, jvp, wide, dtype, getkey): if wide: out_size = 3 in_size = 6 else: out_size = 6 in_size = 3 matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) if not full_rank: if solver.assume_full_rank(): # There is nothing to test. return # nontrivial rank 2 sparsity pattern matrix = matrix.at[1:, 1:].set(0) vector = jr.normal(getkey(), (out_size,), dtype=dtype) lx_solve = lambda mat, vec: lx.linear_solve( lx.MatrixLinearOperator(mat), vec, solver ).value jnp_solve = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0] # pyright: ignore if jvp: lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve)) jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve)) t_matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) if not full_rank: # t_matrix must be chosen tangent to the manifold of rank 2 # matrices at matrix. A simple way to achieve this is to make the # same restriction as we did to matrix t_matrix = t_matrix.at[1:, 1:].set(0) t_vector = jr.normal(getkey(), (out_size,), dtype=dtype) args = ((matrix, vector), (t_matrix, t_vector)) else: args = (matrix, vector) x = lx_solve(*args) # pyright: ignore true_x = jnp_solve(*args) assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize( "solver", ( lx.AutoLinearSolver(well_posed=None), lx.QR(), lx.SVD(), lx.Normal(lx.Cholesky()), lx.Normal(lx.SVD()), ), ) @pytest.mark.parametrize("full_rank", (True, False)) @pytest.mark.parametrize("jvp", (False, True)) @pytest.mark.parametrize("wide", (False, True)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_nonsquare_vec(solver, full_rank, jvp, wide, dtype, getkey): if wide: out_size = 3 in_size = 6 else: out_size = 6 in_size = 3 matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype) if not full_rank: if solver.assume_full_rank(): # There is nothing to test. return # nontrivial rank 2 sparsity pattern matrix = matrix.at[1:, 1:].set(0) vector = jr.normal(getkey(), (out_size,), dtype=dtype) lx_solve = lambda vec: lx.linear_solve( lx.MatrixLinearOperator(matrix), vec, solver ).value jnp_solve = lambda vec: jnp.linalg.lstsq(matrix, vec)[0] # pyright: ignore if jvp: lx_solve = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve)) jnp_solve = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve)) t_vector = jr.normal(getkey(), (out_size,), dtype=dtype) args = ((vector,), (t_vector,)) else: args = (vector,) x = lx_solve(*args) # pyright: ignore true_x = jnp_solve(*args) assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4) _iterative_solvers = ( (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag), (lx.CG(rtol=tol, atol=tol, max_steps=512), lx.negative_semidefinite_tag), (lx.GMRES(rtol=tol, atol=tol), ()), (lx.BiCGStab(rtol=tol, atol=tol), ()), ) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("solver, tags", _iterative_solvers) @pytest.mark.parametrize("use_state", (False, True)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_iterative_singular(getkey, solver, tags, use_state, make_operator, dtype): (matrix,) = construct_singular_matrix(getkey, solver, tags) operator = make_operator(getkey, matrix, tags) out_size, _ = matrix.shape vec = jr.normal(getkey(), (out_size,), dtype=dtype) if use_state: state = solver.init(operator, options={}) linear_solve = ft.partial(lx.linear_solve, state=state) else: linear_solve = lx.linear_solve with pytest.raises(Exception): linear_solve(operator, vec, solver) ================================================ FILE: tests/test_solve.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import construct_poisson_matrix, tree_allclose def test_gmres_large_dense(getkey): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 solver = lx.GMRES(atol=tol, rtol=tol, restart=100) matrix = jr.normal(getkey(), (100, 100)) operator = lx.MatrixLinearOperator(matrix) true_x = jr.normal(getkey(), (100,)) b = matrix @ true_x lx_soln = lx.linear_solve(operator, b, solver).value assert tree_allclose(lx_soln, true_x, atol=tol, rtol=tol) def test_nontrivial_pytree_operator(): x = [[1, 5.0], [jnp.array(-2), jnp.array(-2.0)]] y = [3, 4] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y).value true_out = [jnp.array(-3.25), jnp.array(1.25)] assert tree_allclose(out, true_out) def test_nontrivial_diagonal_operator(): x = (8.0, jnp.array([1, 2, 3]), {"a": jnp.array([4, 5]), "b": 6}) y = (4.0, jnp.array([7, 8, 9]), {"a": jnp.array([2, 10]), "b": 12}) operator = lx.DiagonalLinearOperator(x) out = lx.linear_solve(operator, y).value true_out = ( jnp.array(0.5), jnp.array([7.0, 4.0, 3.0]), {"a": jnp.array([0.5, 2.0]), "b": jnp.array(2.0)}, ) assert tree_allclose(out, true_out) @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) def test_mixed_dtypes(solver): f32 = lambda x: jnp.array(x, dtype=jnp.float32) f64 = lambda x: jnp.array(x, dtype=jnp.float64) x = [[f32(1), f64(5)], [f32(-2), f64(-2)]] y = [f64(3), f64(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y, solver=solver).value true_out = [f32(-3.25), f64(1.25)] assert tree_allclose(out, true_out) @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) def test_mixed_dtypes_complex(solver): c64 = lambda x: jnp.array(x, dtype=jnp.complex64) c128 = lambda x: jnp.array(x, dtype=jnp.complex128) x = [[c64(1), c128(5.0j)], [c64(2.0j), c128(-2)]] y = [c128(3), c128(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y, solver=solver).value true_out = [c64(-0.75 - 2.5j), c128(0.5 - 0.75j)] assert tree_allclose(out, true_out) @pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD())) def test_mixed_dtypes_complex_real(solver): f64 = lambda x: jnp.array(x, dtype=jnp.float64) c128 = lambda x: jnp.array(x, dtype=jnp.complex128) x = [[f64(1), c128(-5.0j)], [f64(2.0), c128(-2j)]] y = [c128(3), c128(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct) out = lx.linear_solve(operator, y, solver=solver).value true_out = [f64(1.75), c128(0.25j)] assert tree_allclose(out, true_out) def test_mixed_dtypes_triangular(): f32 = lambda x: jnp.array(x, dtype=jnp.float32) f64 = lambda x: jnp.array(x, dtype=jnp.float64) x = [[f32(1), f64(0)], [f32(-2), f64(-2)]] y = [f64(3), f64(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) out = lx.linear_solve(operator, y, solver=lx.Triangular()).value true_out = [f32(3), f64(-5)] assert tree_allclose(out, true_out) def test_mixed_dtypes_complex_triangular(): c64 = lambda x: jnp.array(x, dtype=jnp.complex64) c128 = lambda x: jnp.array(x, dtype=jnp.complex128) x = [[c64(1), c128(0)], [c64(2.0j), c128(-2)]] y = [c128(3), c128(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) out = lx.linear_solve(operator, y, solver=lx.Triangular()).value true_out = [c64(3), c128(-2 + 3.0j)] assert tree_allclose(out, true_out) def test_mixed_dtypes_complex_real_triangular(): f64 = lambda x: jnp.array(x, dtype=jnp.float64) c128 = lambda x: jnp.array(x, dtype=jnp.complex128) x = [[f64(1), c128(0)], [f64(2.0), c128(2j)]] y = [c128(3), c128(4)] struct = jax.eval_shape(lambda: y) operator = lx.PyTreeLinearOperator(x, struct, lx.lower_triangular_tag) out = lx.linear_solve(operator, y, solver=lx.Triangular()).value true_out = [f64(3), c128(1j)] assert tree_allclose(out, true_out) def test_ad_closure_function_linear_operator(getkey): def f(x, z): def fn(y): return x * y op = lx.FunctionLinearOperator(fn, jax.eval_shape(lambda: z)) sol = lx.linear_solve(op, z).value return jnp.sum(sol), sol x = jr.normal(getkey(), (3,)) x = jnp.where(jnp.abs(x) < 1e-6, 0.7, x) z = jr.normal(getkey(), (3,)) grad, sol = jax.grad(f, has_aux=True)(x, z) assert tree_allclose(grad, -z / (x**2)) assert tree_allclose(sol, z / x) def test_grad_vmap_symbolic_cotangent(): def f(x): return x[0], x[1] @jax.vmap def to_vmap(x): op = lx.FunctionLinearOperator(f, jax.eval_shape(lambda: x)) sol = lx.linear_solve(op, x) return sol.value[0] @jax.grad def to_grad(x): return jnp.sum(to_vmap(x)) x = (jnp.arange(3.0), jnp.arange(3.0)) to_grad(x) @pytest.mark.parametrize( "solver", ( lx.CG(0.0, 0.0, max_steps=2), lx.Normal(lx.CG(0.0, 0.0, max_steps=2)), lx.BiCGStab(0.0, 0.0, max_steps=2), lx.GMRES(0.0, 0.0, max_steps=2), lx.LSMR(0.0, 0.0, max_steps=2), ), ) def test_iterative_solver_max_steps_only(solver): """Iterative solvers should work with max_steps only (no Equinox errors).""" SIZE = 100 poisson_matrix = construct_poisson_matrix(SIZE) poisson_operator = lx.MatrixLinearOperator( poisson_matrix, tags=(lx.negative_semidefinite_tag, lx.symmetric_tag) ) rhs = jax.random.normal(jax.random.key(0), (SIZE,)) lx.linear_solve(poisson_operator, rhs, solver) def test_solver_init_not_differentiated(getkey): """stop_gradient should be applied before solver.init, not after. Also checks that dynamic arrays in options don't cause issues. """ class DisallowGradWrapper(lx._solve.AbstractLinearSolver): solver: lx._solve.AbstractLinearSolver def init(self, operator, options): @jax.custom_jvp def f(operator, dummy): del dummy return self.solver.init(operator, options) @f.defjvp def _(*args): raise NotImplementedError("solver.init should not be differentiated") return f(operator, options.get("dummy")) def compute(self, state, vector, options): return self.solver.compute(state, vector, options) def transpose(self, state, options): return self.solver.transpose(state, options) def conj(self, state, options): return self.solver.conj(state, options) def assume_full_rank(self): return self.solver.assume_full_rank() m = jax.random.normal(getkey(), (3, 3)) mt = jax.random.normal(getkey(), (3, 3)) v = jax.random.normal(getkey(), (3,)) dummy = jnp.array(1.0) def f(m): op = lx.MatrixLinearOperator(m) return lx.linear_solve( op, v, solver=DisallowGradWrapper(lx.QR()), options={"dummy": dummy} ).value # Differentiating through operator only, but options has a dynamic array. # solver.init should not be differentiated through. jax.jvp(f, (m,), (mt,)) _, f_vjp = jax.vjp(f, m) f_vjp(v) def test_nonfinite_input(): operator = lx.DiagonalLinearOperator((1.0, 1.0)) vector = (1.0, jnp.inf) sol = lx.linear_solve(operator, vector, throw=False) assert sol.result == lx.RESULTS.nonfinite_input vector = (1.0, jnp.nan) sol = lx.linear_solve(operator, vector, throw=False) assert sol.result == lx.RESULTS.nonfinite_input vector = (jnp.nan, jnp.inf) sol = lx.linear_solve(operator, vector, throw=False) assert sol.result == lx.RESULTS.nonfinite_input ================================================ FILE: tests/test_transpose.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import construct_matrix, params, tree_allclose class TestTranspose: @pytest.fixture(scope="class") def assert_transpose_fixture(_): @eqx.filter_jit def solve_transpose(operator, out_vec, in_vec, solver): return jax.linear_transpose( lambda v: lx.linear_solve(operator, v, solver).value, out_vec )(in_vec) def assert_transpose(operator, out_vec, in_vec, solver): (out,) = solve_transpose(operator, out_vec, in_vec, solver) true_out = lx.linear_solve(operator.T, in_vec, solver).value assert tree_allclose(out, true_out) return assert_transpose @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_transpose( _, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey ): (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(getkey, matrix, tags) out_size, in_size = matrix.shape out_vec = jr.normal(getkey(), (out_size,), dtype=dtype) in_vec = jr.normal(getkey(), (in_size,), dtype=dtype) solver = lx.AutoLinearSolver(well_posed=True) assert_transpose_fixture(operator, out_vec, in_vec, solver) def test_pytree_transpose(_, assert_transpose_fixture): # pyright: ignore a = jnp.array pytree = [[a(1), a(2), a(3)], [a(4), a(5), a(6)]] output_structure = jax.eval_shape(lambda: [1, 2]) operator = lx.PyTreeLinearOperator(pytree, output_structure) out_vec = [a(1.0), a(2.0)] in_vec = [a(1.0), 2.0, 3.0] solver = lx.AutoLinearSolver(well_posed=False) assert_transpose_fixture(operator, out_vec, in_vec, solver) ================================================ FILE: tests/test_vmap.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import equinox as eqx import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_matrix, construct_singular_matrix, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, tree_allclose, ) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( "make_matrix", ( construct_matrix, construct_singular_matrix, ), ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_vmap( getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: def wrap_solve(matrix, vector): operator = make_operator(getkey, matrix, tags) if use_state: state = solver.init(operator, options={}) return lx.linear_solve(operator, vector, solver, state=state).value else: return lx.linear_solve(operator, vector, solver).value for op_axis, vec_axis in ( (None, 0), (eqx.if_array(0), None), (eqx.if_array(0), 0), ): if op_axis is None: axis_size = None out_axes = None else: axis_size = 10 out_axes = eqx.if_array(0) (matrix,) = eqx.filter_vmap( lambda getkey, solver, tags: make_matrix( getkey, solver, tags, dtype=dtype ), axis_size=axis_size, out_axes=out_axes, )(getkey, solver, tags) out_dim = matrix.shape[-2] if vec_axis is None: vec = jr.normal(getkey(), (out_dim,), dtype=dtype) else: vec = jr.normal(getkey(), (10, out_dim), dtype=dtype) jax_result, _, _, _ = eqx.filter_vmap( jnp.linalg.lstsq, in_axes=(op_axis, vec_axis), # pyright: ignore )(matrix, vec) lx_result = eqx.filter_vmap(wrap_solve, in_axes=(op_axis, vec_axis))( matrix, vec ) assert tree_allclose(lx_result, jax_result) # https://github.com/patrick-kidger/lineax/issues/101 def test_grad_vmap_basic(getkey): A = jr.normal(getkey(), (16, 8)) B = jr.normal(getkey(), (128, 16)) @jax.jit @jax.grad def fn(A): op = lx.MatrixLinearOperator(A) return jax.vmap( lambda b: lx.linear_solve( op, b, lx.AutoLinearSolver(well_posed=False) ).value )(B).mean() fn(A) def test_grad_vmap_advanced(getkey): # this is a more complicated version of the above test, in which the batch axes and # the undefinedprimals do not necessarily line up in the same arguments. A = jr.normal(getkey(), (2, 8)), jr.normal(getkey(), (3, 8, 128)) B = jr.normal(getkey(), (2, 128)), jr.normal(getkey(), (3,)) output_structure = ( jax.ShapeDtypeStruct((2,), jnp.float64), jax.ShapeDtypeStruct((3,), jnp.float64), ) def to_vmap(A, B): op = lx.PyTreeLinearOperator(A, output_structure) return lx.linear_solve(op, B, lx.AutoLinearSolver(well_posed=False)).value @jax.jit @jax.grad def fn(A): return jax.vmap(to_vmap, in_axes=((None, 2), (1, None)))(A, B).mean() fn(A) ================================================ FILE: tests/test_vmap_jvp.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import equinox as eqx import jax.lax as lax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_matrix, construct_singular_matrix, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, tree_allclose, ) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( "make_matrix", ( construct_matrix, construct_singular_matrix, ), ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_vmap_jvp( getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None if pseudoinverse: jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0] # pyright: ignore else: jnp_solve1 = jnp.linalg.solve # pyright: ignore if use_state: def linear_solve1(operator, vector): op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array) stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static) state = solver.init(stopped_operator, options={}) return lx.linear_solve(operator, vector, state=state, solver=solver) else: linear_solve1 = ft.partial(lx.linear_solve, solver=solver) for mode in ("vec", "op", "op_vec"): if "op" in mode: axis_size = 10 out_axes = eqx.if_array(0) else: axis_size = None out_axes = None def _make(): matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) make_op = ft.partial(make_operator, getkey) operator, t_operator = eqx.filter_jvp( make_op, (matrix, tags), (t_matrix, t_tags) ) return matrix, t_matrix, operator, t_operator matrix, t_matrix, operator, t_operator = eqx.filter_vmap( _make, axis_size=axis_size, out_axes=out_axes )() if "op" in mode: _, out_size, _ = matrix.shape else: out_size, _ = matrix.shape if "vec" in mode: vec = jr.normal(getkey(), (10, out_size), dtype=dtype) t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype) else: vec = jr.normal(getkey(), (out_size,), dtype=dtype) t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if mode == "op": linear_solve2 = lambda op: linear_solve1(op, vector=vec) jnp_solve2 = lambda mat: jnp_solve1(mat, vec) elif mode == "vec": linear_solve2 = lambda vector: linear_solve1(operator, vector) jnp_solve2 = lambda vector: jnp_solve1(matrix, vector) elif mode == "op_vec": linear_solve2 = linear_solve1 jnp_solve2 = jnp_solve1 else: assert False for jvp_first in (True, False): if jvp_first: linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve2) else: linear_solve3 = linear_solve2 linear_solve3 = eqx.filter_vmap(linear_solve3) if not jvp_first: linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3) linear_solve3 = eqx.filter_jit(linear_solve3) jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve2) jnp_solve3 = eqx.filter_vmap(jnp_solve3) jnp_solve3 = eqx.filter_jit(jnp_solve3) if mode == "op": out, t_out = linear_solve3((operator,), (t_operator,)) true_out, true_t_out = jnp_solve3((matrix,), (t_matrix,)) elif mode == "vec": out, t_out = linear_solve3((vec,), (t_vec,)) true_out, true_t_out = jnp_solve3((vec,), (t_vec,)) elif mode == "op_vec": out, t_out = linear_solve3((operator, vec), (t_operator, t_vec)) true_out, true_t_out = jnp_solve3((matrix, vec), (t_matrix, t_vec)) else: assert False assert tree_allclose(out.value, true_out, atol=1e-4) assert tree_allclose(t_out.value, true_t_out, atol=1e-4) ================================================ FILE: tests/test_vmap_vmap.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools as ft import equinox as eqx import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_matrix, construct_singular_matrix, make_jac_operator, make_matrix_operator, solvers_tags_pseudoinverse, tree_allclose, ) @pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( "make_matrix", ( construct_matrix, construct_singular_matrix, ), ) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_vmap_vmap( getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: # combinations with nontrivial application across both vmaps axes = [ (eqx.if_array(0), eqx.if_array(0), None, None), (None, None, 0, 0), (eqx.if_array(0), eqx.if_array(0), None, 0), (eqx.if_array(0), eqx.if_array(0), 0, 0), (None, eqx.if_array(0), 0, 0), ] for vmap2_op, vmap1_op, vmap2_vec, vmap1_vec in axes: if vmap1_op is not None: axis_size1 = 10 out_axis1 = eqx.if_array(0) else: axis_size1 = None out_axis1 = None if vmap2_op is not None: axis_size2 = 10 out_axis2 = eqx.if_array(0) else: axis_size2 = None out_axis2 = None (matrix,) = eqx.filter_vmap( eqx.filter_vmap( lambda getkey, solver, tags: make_matrix( getkey, solver, tags, dtype=dtype ), axis_size=axis_size1, out_axes=out_axis1, ), axis_size=axis_size2, out_axes=out_axis2, )(getkey, solver, tags) if vmap1_op is not None: if vmap2_op is not None: _, _, out_size, _ = matrix.shape else: _, out_size, _ = matrix.shape else: out_size, _ = matrix.shape if vmap1_vec is None: vec = jr.normal(getkey(), (out_size,), dtype=dtype) elif (vmap1_vec is not None) and (vmap2_vec is None): vec = jr.normal(getkey(), (10, out_size), dtype=dtype) else: vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype) make_op = ft.partial(make_operator, getkey) operator = eqx.filter_vmap( eqx.filter_vmap( make_op, in_axes=vmap1_op, out_axes=out_axis1, ), in_axes=vmap2_op, out_axes=out_axis2, )(matrix, tags) if use_state: def linear_solve(operator, vector): state = solver.init(operator, options={}) return lx.linear_solve(operator, vector, state=state, solver=solver) else: def linear_solve(operator, vector): return lx.linear_solve(operator, vector, solver) as_matrix_vmapped = eqx.filter_vmap( eqx.filter_vmap( lambda x: x.as_matrix(), in_axes=vmap1_op, out_axes=None if vmap1_op is None else 0, ), in_axes=vmap2_op, out_axes=None if vmap2_op is None else 0, )(operator) vmap1_axes = (vmap1_op, vmap1_vec) vmap2_axes = (vmap2_op, vmap2_vec) result = eqx.filter_vmap( eqx.filter_vmap(linear_solve, in_axes=vmap1_axes), in_axes=vmap2_axes )(operator, vec).value solve_with = lambda x: eqx.filter_vmap( eqx.filter_vmap(x, in_axes=vmap1_axes), in_axes=vmap2_axes )(as_matrix_vmapped, vec) if make_matrix is construct_singular_matrix: true_result, _, _, _ = solve_with(jnp.linalg.lstsq) # pyright: ignore else: true_result = solve_with(jnp.linalg.solve) # pyright: ignore assert tree_allclose(result, true_result, rtol=1e-3) ================================================ FILE: tests/test_well_posed.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest from .helpers import ( construct_matrix, make_jacrev_operator, ops, params, solvers, tree_allclose, ) @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) @pytest.mark.parametrize("ops", ops) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): if make_operator is make_jacrev_operator and dtype is jnp.complex128: # JacobianLinearOperator does not support complex dtypes when jac="bwd" return if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(getkey, matrix, tags) operator, matrix = ops(operator, matrix) assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) out_size, _ = matrix.shape true_x = jr.normal(getkey(), (out_size,), dtype=dtype) b = matrix @ true_x x = lx.linear_solve(operator, b, solver=solver).value jax_x = jnp.linalg.solve(matrix, b) # pyright: ignore assert tree_allclose(x, true_x, atol=tol, rtol=tol) assert tree_allclose(x, jax_x, atol=tol, rtol=tol) @pytest.mark.parametrize("solver", solvers) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_pytree_wellposed(solver, getkey, dtype): if not isinstance( solver, (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG), ): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 true_x = [ jr.normal(getkey(), shape=(2, 4), dtype=dtype), jr.normal(getkey(), (3,), dtype=dtype), ] pytree = [ [ jr.normal(getkey(), shape=(2, 4, 2, 4), dtype=dtype), jr.normal(getkey(), shape=(2, 4, 3), dtype=dtype), ], [ jr.normal(getkey(), shape=(3, 2, 4), dtype=dtype), jr.normal(getkey(), shape=(3, 3), dtype=dtype), ], ] out_structure = jax.eval_shape(lambda: true_x) operator = lx.PyTreeLinearOperator(pytree, out_structure) b = operator.mv(true_x) lx_x = lx.linear_solve(operator, b, solver, throw=False) assert tree_allclose(lx_x.value, true_x, atol=tol, rtol=tol)