Repository: patrick-kidger/lineax
Branch: main
Commit: 2bf9824f7df9
Files: 75
Total size: 416.0 KB
Directory structure:
gitextract_at99w4wk/
├── .github/
│ └── workflows/
│ ├── build_docs.yml
│ ├── release.yml
│ └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks/
│ ├── gmres_fails_safely.py
│ ├── lstsq_gradients.py
│ └── solver_speeds.py
├── docs/
│ ├── .htaccess
│ ├── _overrides/
│ │ └── partials/
│ │ └── source.html
│ ├── _static/
│ │ ├── custom_css.css
│ │ └── mathjax.js
│ ├── api/
│ │ ├── functions.md
│ │ ├── linear_solve.md
│ │ ├── operators.md
│ │ ├── solution.md
│ │ ├── solvers.md
│ │ └── tags.md
│ ├── examples/
│ │ ├── classical_solve.ipynb
│ │ ├── complex_solve.ipynb
│ │ ├── least_squares.ipynb
│ │ ├── no_materialisation.ipynb
│ │ ├── operators.ipynb
│ │ └── structured_matrices.ipynb
│ ├── faq.md
│ └── index.md
├── lineax/
│ ├── __init__.py
│ ├── _custom_types.py
│ ├── _misc.py
│ ├── _norm.py
│ ├── _operator.py
│ ├── _solution.py
│ ├── _solve.py
│ ├── _solver/
│ │ ├── __init__.py
│ │ ├── bicgstab.py
│ │ ├── cg.py
│ │ ├── cholesky.py
│ │ ├── diagonal.py
│ │ ├── gmres.py
│ │ ├── lsmr.py
│ │ ├── lu.py
│ │ ├── misc.py
│ │ ├── normal.py
│ │ ├── qr.py
│ │ ├── svd.py
│ │ ├── triangular.py
│ │ └── tridiagonal.py
│ ├── _tags.py
│ └── internal/
│ └── __init__.py
├── mkdocs.yml
├── pyproject.toml
└── tests/
├── README.md
├── __init__.py
├── __main__.py
├── conftest.py
├── helpers.py
├── test_adjoint.py
├── test_invert.py
├── test_jvp.py
├── test_jvp_jvp1.py
├── test_jvp_jvp2.py
├── test_lsmr.py
├── test_misc.py
├── test_norm.py
├── test_operator.py
├── test_singular.py
├── test_solve.py
├── test_transpose.py
├── test_vmap.py
├── test_vmap_jvp.py
├── test_vmap_vmap.py
└── test_well_posed.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/build_docs.yml
================================================
name: Build docs
on:
push:
branches:
- main
jobs:
build:
strategy:
matrix:
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv run echo done
- name: Build docs
run: |
uv run mkdocs build
- name: Upload docs
uses: actions/upload-artifact@v4
with:
name: docs
path: site # where `mkdocs build` puts the built site
================================================
FILE: .github/workflows/release.yml
================================================
name: Release
on:
push:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Release
uses: patrick-kidger/action_update_python_project@v8
with:
python-version: "3.11"
# Uninstall and reinstall pytest to work around the fact that it doesn't get put into `bin` otherwise.
test-script: |
cp -r ${{ github.workspace }}/tests ./tests
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
uv pip uninstall pytest
uv sync --no-install-project --inexact
uv run --no-sync pytest
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
github-token: ${{ github.token }}
================================================
FILE: .github/workflows/run_tests.yml
================================================
name: Run tests
on:
pull_request:
jobs:
run-test:
strategy:
matrix:
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv run echo done
- name: Checks with pre-commit
run: |
uv run prek run --all-files
- name: Test with pytest
run: |
uv run python -m tests
- name: Check that documentation can be built.
run: |
uv run mkdocs build
================================================
FILE: .gitignore
================================================
**/__pycache__
**/.ipynb_checkpoints
*.egg-info/
build/
dist/
site/
examples/data
.all_objects.cache
.pymon
.idea
.venv
uv.lock
================================================
FILE: .pre-commit-config.yaml
================================================
fail_fast: true
repos:
- repo: meta
hooks:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
exclude: \.md$
- id: check-toml
- id: mixed-line-ending
- repo: local
hooks:
- id: sort-pyproject
name: sort pyproject
files: ^pyproject\.toml$
language: system
entry: uv run -- toml-sort -i --sort-table-keys --sort-inline-tables
- id: ruff-format
name: ruff format
types_or: [python, pyi, jupyter, toml]
language: system
entry: uv run -- ruff format --
require_serial: true
- id: ruff-lint
name: ruff lint
types_or: [python, pyi, jupyter, toml]
language: system
entry: uv run -- ruff check --fix --
require_serial: true
- id: pyright
name: pyright
types_or: [python]
language: system
entry: uv run -- pyright
require_serial: true
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing
Contributions (pull requests) are very welcome! Here's how to get started.
---
### Getting started
[We assume that you have `uv` installed.](https://docs.astral.sh/uv/) Now fork the library on GitHub. Then clone and install the library:
```bash
git clone https://github.com/your-username-here/lineax.git
cd lineax
uv run prek install # Creates a local venv + installs dependencies + installs pre-commit hooks.
```
---
### If you're making changes to the code
Now make your changes. Make sure to include additional tests if necessary. Next, verify that all the tests pass:
```bash
uv run python -m tests
```
Then push your changes back to your fork of the repository:
```bash
git push
```
Finally, open a pull request on GitHub!
---
### If you're making changes to the documentation
Make your changes. You can then build the documentation by running
```bash
uv run mkdocs serve
```
You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
Lineax
Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)
Features include:
- PyTree-valued matrices and vectors;
- General linear operators for Jacobians, transposes, etc.;
- Efficient linear least squares (e.g. QR solvers);
- Numerically stable gradients through linear least squares;
- Support for structured (e.g. symmetric) matrices;
- Improved compilation times;
- Improved runtime of some algorithms;
- Support for both real-valued and complex-valued inputs;
- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.
## Installation
```bash
pip install lineax
```
Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.
## Documentation
Available at [https://docs.kidger.site/lineax](https://docs.kidger.site/lineax).
## Quick examples
Lineax can solve a least squares problem with an explicit matrix operator:
```python
import jax.random as jr
import lineax as lx
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```
or Lineax can solve a problem without ever materializing a matrix, as done in this
quadratic solve:
```python
import jax
import lineax as lx
key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))
def quadratic_fn(y, args):
return jax.numpy.sum((y - 1)**2)
gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
```
## Citation
If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2311.17283))
```bibtex
@article{lineax2023,
title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
author={Jason Rader and Terry Lyons and Patrick Kidger},
journal={
AI for science workshop at Neural Information Processing Systems 2023,
arXiv:2311.17283
},
year={2023},
}
```
(Also consider starring the project on GitHub.)
## See also: other libraries in the JAX ecosystem
**Always useful**
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
**Deep learning**
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
**Scientific computing**
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
================================================
FILE: benchmarks/gmres_fails_safely.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import lineax as lx
# Callable source of PRNG keys; each benchmark trial below draws its key with
# `key = getkey()`.
getkey = eqxi.GetKey()
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
    """Check that two PyTrees agree in structure and leaf types, with close values.

    Thin wrapper over `eqx.tree_equal` with `typematch=True`, forwarding the
    relative/absolute tolerances `rtol` and `atol`.
    """
    same = eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
    return same
# Enable 64-bit (double precision) arrays in JAX. This runs at import time, before
# the benchmark loop at the bottom of the script calls any of the functions below.
jax.config.update("jax_enable_x64", True)
def make_problem(mat_size: int, *, key):
    """Build a random square linear system with a known solution.

    Arguments:
        mat_size: number of rows (and columns) of the matrix.
        key: a JAX PRNG key. It is split internally so that the matrix and the
            true solution come from independent random streams. The same `key`
            always produces the same problem.

    Returns:
        A 4-tuple `(mat, op, b, true_x)`: the dense matrix, that matrix wrapped as
        an `lx.MatrixLinearOperator`, the right-hand side `b = mat @ true_x`, and
        the true solution.
    """
    # Never feed the same JAX key to two random draws: the samples would not be
    # independent. Derive one subkey per draw instead.
    mat_key, x_key = jr.split(key)
    mat = jr.normal(mat_key, (mat_size, mat_size))
    true_x = jr.normal(x_key, (mat_size,))
    b = mat @ true_x
    op = lx.MatrixLinearOperator(mat)
    return mat, op, b, true_x
def benchmark_jax(mat_size: int, *, key):
    """Solve a random system with JAX's GMRES and report whether it failed safely.

    "Failed safely" means the returned solution is genuinely wrong *and* JAX's
    `info` output flags the failure. The function asserts that the solve really
    did fail, so the result only reflects whether that failure was reported.

    Arguments:
        mat_size: side length of the random square matrix.
        key: PRNG key forwarded to `make_problem`.

    Returns:
        A JAX boolean scalar that is true when the failure was both real and reported.
    """
    mat, _, b, true_x = make_problem(mat_size, key=key)
    jitted_gmres = jax.jit(
        ft.partial(jsp.sparse.linalg.gmres, tol=1e-5, solve_method="batched")
    )
    jax_soln, info = jitted_gmres(mat, b)
    # `info == 0.0` means JAX believes the solve succeeded; anything else is a
    # reported failure.
    reported_failure = jnp.all(info != 0.0)
    is_accurate = tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4)
    actually_failed = not is_accurate
    assert actually_failed
    return reported_failure & actually_failed
def benchmark_lx(mat_size: int, *, key):
    """Solve a random system with Lineax's GMRES and report whether it failed safely.

    Mirrors `benchmark_jax`. The solve runs with `throw=False`, so failure is
    read back from `result` instead of being raised. The function asserts that
    the solution really is wrong; the return value says whether `result` also
    flagged it.

    Arguments:
        mat_size: side length of the random square matrix.
        key: PRNG key forwarded to `make_problem`.

    Returns:
        A JAX boolean scalar that is true when the failure was both real and reported.
    """
    _, op, b, true_x = make_problem(mat_size, key=key)
    gmres = lx.GMRES(atol=1e-5, rtol=1e-5)
    lx_soln = lx.linear_solve(op, b, gmres, throw=False)
    reported_failure = jnp.all(lx_soln.result != lx.RESULTS.successful)
    is_accurate = tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4)
    actually_failed = not is_accurate
    assert actually_failed
    return reported_failure & actually_failed
# Number of random problems to try, and the side length of each random matrix.
# These are kept separate so that changing one cannot silently change the other,
# and the printed totals always match the number of trials actually run.
NUM_TRIALS = 100
MATRIX_SIZE = 100

# Count how often each library fails *and* reports that failure.
lx_failed_safely = 0
jax_failed_safely = 0
for _ in range(NUM_TRIALS):
    key = getkey()
    # Both benchmarks get the same key, so `make_problem` builds the identical
    # system for JAX and for Lineax.
    jax_captured_failure = benchmark_jax(MATRIX_SIZE, key=key)
    lx_captured_failure = benchmark_lx(MATRIX_SIZE, key=key)
    jax_failed_safely += jax_captured_failure
    lx_failed_safely += lx_captured_failure
print(f"JAX failed safely {jax_failed_safely} out of {NUM_TRIALS} times")
print(f"Lineax failed safely {lx_failed_safely} out of {NUM_TRIALS} times")
================================================
FILE: benchmarks/lstsq_gradients.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Core JAX has some numerical issues with its lstsq gradients.
# See https://github.com/google/jax/issues/14868
# This demonstrates that we don't have the same issue!
import sys
import jax
import jax.numpy as jnp
import lineax as lx
sys.path.append("../tests")
from helpers import finite_difference_jvp # pyright: ignore
# Differentiate at the 3x3 identity matrix, in the direction of the 3x3 zero matrix.
# Both are wrapped in 1-tuples because `jax.jvp` takes sequences of primals and
# tangents (one entry per argument of the solve functions below).
a_primal = (jnp.eye(3),)
a_tangent = (jnp.zeros((3, 3)),)
def jax_solve(a):
    """Least-squares solve of `a @ x = [0, 1, 2]` using `jnp.linalg.lstsq`.

    Only the solution is returned; the residuals, rank and singular values that
    `lstsq` also produces are discarded.
    """
    rhs = jnp.arange(3)
    solution = jnp.linalg.lstsq(a, rhs)[0]  # pyright: ignore
    return solution
def lx_solve(a):
    """Solve `a @ x = [0, 1, 2]` with `lx.linear_solve` (no solver specified).

    The matrix is wrapped as an `lx.MatrixLinearOperator`; only the solution
    value is returned.
    """
    operator = lx.MatrixLinearOperator(a)
    rhs = jnp.arange(3)
    solution = lx.linear_solve(operator, rhs)
    return solution.value
# Reference JVP, computed by the test suite's `finite_difference_jvp` helper.
_, true_jvp = finite_difference_jvp(jax_solve, a_primal, a_tangent)
# Autodiff JVPs through JAX's `lstsq` and through Lineax's `linear_solve`.
_, jax_jvp = jax.jvp(jax_solve, a_primal, a_tangent)
_, lx_jvp = jax.jvp(lx_solve, a_primal, a_tangent)
# The tangent is the zero matrix, so the exact JVP is zero, yet JAX's lstsq
# derivative comes out entirely NaN here...
assert jnp.isnan(jax_jvp).all()
# ...whereas Lineax's derivative agrees with the finite-difference reference.
assert jnp.allclose(true_jvp, lx_jvp)
================================================
FILE: benchmarks/solver_speeds.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import sys
import timeit
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import lineax as lx
sys.path.append("../tests")
from helpers import construct_matrix, has_tag # pyright: ignore[reportMissingImports]
# Callable source of PRNG keys; call `getkey()` whenever a fresh key is needed
# (e.g. in `create_problem` and `create_easy_iterative_problem`).
getkey = eqxi.GetKey()
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
    """Check that two PyTrees agree in structure and leaf types, with close values.

    Thin wrapper over `eqx.tree_equal` with `typematch=True`, forwarding the
    relative/absolute tolerances `rtol` and `atol`.
    """
    same = eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
    return same
# Benchmark in double precision.
jax.config.update("jax_enable_x64", True)
# Iterative-solver tolerance: tight when 64-bit floats are enabled, looser otherwise.
tol = 1e-12 if jax.config.jax_enable_x64 else 1e-6  # pyright: ignore
def base_wrapper(a, b, solver):
    """Solve `a @ x = b` with the given Lineax `solver` and return `x`.

    The matrix is wrapped as an `lx.MatrixLinearOperator` carrying the positive
    semidefinite, symmetric, diagonal and tridiagonal tags. These tags are
    attached unconditionally, whatever the real structure of `a` (presumably so
    that structure-specific solvers accept the operator). The solve runs with
    `throw=False`, so failures do not raise.
    """
    every_tag = (
        lx.positive_semidefinite_tag,
        lx.symmetric_tag,
        lx.diagonal_tag,
        lx.tridiagonal_tag,
    )
    operator = lx.MatrixLinearOperator(a, every_tag)
    solution = lx.linear_solve(operator, b, solver, throw=False)
    return solution.value
def jax_svd(a, b):
    """JAX baseline: least-squares solve of `a @ x = b` via `jnp.linalg.lstsq`.

    Only the solution is returned; residuals, rank and singular values are dropped.
    """
    return jnp.linalg.lstsq(a, b)[0]  # pyright: ignore
def jax_gmres(a, b):
    """JAX baseline: GMRES solve of `a @ x = b` at the module-level tolerance `tol`."""
    return jsp.sparse.linalg.gmres(a, b, tol=tol)[0]
def jax_bicgstab(a, b):
    """JAX baseline: BiCGStab solve of `a @ x = b` at the module-level tolerance `tol`."""
    return jsp.sparse.linalg.bicgstab(a, b, tol=tol)[0]
def jax_cg(a, b):
    """JAX baseline: conjugate-gradient solve of `a @ x = b` at the module-level `tol`."""
    return jsp.sparse.linalg.cg(a, b, tol=tol)[0]
def jax_lu(matrix, vector):
    """JAX baseline: solve `matrix @ x = vector` by LU factorisation then substitution."""
    lu_and_pivots = jsp.linalg.lu_factor(matrix)
    return jsp.linalg.lu_solve(lu_and_pivots, vector)
def jax_cholesky(matrix, vector):
    """JAX baseline: solve `matrix @ x = vector` via a Cholesky factorisation.

    `matrix` is expected to be symmetric positive definite, as Cholesky requires.
    """
    factor_and_lower = jsp.linalg.cho_factor(matrix)
    return jsp.linalg.cho_solve(factor_and_lower, vector)
def jax_tridiagonal(matrix, vector):
    """JAX baseline: solve `matrix @ x = vector` for a tridiagonal `matrix`.

    Only the three central diagonals of `matrix` are read. Each band is padded to
    full length: the sub-diagonal gets a leading zero and the super-diagonal a
    trailing zero, matching the band layout `jax.lax.linalg.tridiagonal_solve` uses.
    """
    main = matrix.diagonal(0)
    lower = jnp.append(0.0, matrix.diagonal(-1))
    upper = jnp.append(matrix.diagonal(1), 0.0)
    # The solver takes a 2-D right-hand side; solve for one column and unwrap it.
    column = vector[:, None]
    solved = jax.lax.linalg.tridiagonal_solve(lower, main, upper, column)
    return solved[:, 0]
# Benchmark table. Each entry is `(lineax_name, jax_name, lineax_solver, jax_solver,
# tags)`, as unpacked in `test_solvers`. `jax_solver` is `None` when JAX has no
# equivalent solver. `tags` is `()` or a single Lineax tag and is used when
# building the test matrix for that solver.
named_solvers = [
    ("LU", "LU", lx.LU(), jax_lu, ()),
    ("QR", "SVD", lx.QR(), jax_svd, ()),
    ("SVD", "SVD", lx.SVD(), jax_svd, ()),
    (
        "Cholesky",
        "Cholesky",
        lx.Cholesky(),
        jax_cholesky,
        lx.positive_semidefinite_tag,
    ),
    ("Diagonal", "None", lx.Diagonal(), None, lx.diagonal_tag),
    (
        "Tridiagonal",
        "Tridiagonal",
        lx.Tridiagonal(),
        jax_tridiagonal,
        lx.tridiagonal_tag,
    ),
    (
        "CG",
        "CG",
        lx.CG(atol=tol, rtol=tol, stabilise_every=None),
        jax_cg,
        lx.positive_semidefinite_tag,
    ),
    (
        "GMRES",
        "GMRES",
        # Use the same tolerance as the JAX baseline (`jax_gmres` uses `tol`) and
        # as the other iterative solvers here. The previous `atol=rtol=1` was far
        # looser, which made the timing comparison unfair and would likely fail
        # the 1e-4 accuracy check.
        lx.GMRES(atol=tol, rtol=tol),
        jax_gmres,
        (),
    ),
    (
        "BiCGStab",
        "BiCGStab",
        lx.BiCGStab(atol=tol, rtol=tol),
        jax_bicgstab,
        (),
    ),
]
def create_problem(solver, tags, size=3):
    """Build a random `size x size` test system using the tests' `construct_matrix`.

    Arguments:
        solver: forwarded to `construct_matrix`.
        tags: Lineax tag(s) forwarded to `construct_matrix`.
        size: side length of the matrix.

    Returns:
        `(matrix, true_x, b)` where `b = matrix @ true_x`.
    """
    (matrix,) = construct_matrix(getkey, solver, tags, size=size)
    true_x = jr.normal(getkey(), (size,))
    rhs = matrix @ true_x
    return matrix, true_x, rhs
def create_easy_iterative_problem(size, tags):
    """Build a random system intended to be easy for the iterative solvers.

    The matrix is `2 * I` plus Gaussian noise scaled by `1 / size`. If `tags`
    includes `lx.positive_semidefinite_tag`, the matrix is replaced by
    `matrix.T @ matrix`.

    Returns:
        `(matrix, true_x, b)` where `b = matrix @ true_x`.
    """
    noise = jr.normal(getkey(), (size, size)) / size
    matrix = noise + 2 * jnp.eye(size)
    true_x = jr.normal(getkey(), (size,))
    if has_tag(tags, lx.positive_semidefinite_tag):
        matrix = matrix.T @ matrix
    rhs = matrix @ true_x
    return matrix, true_x, rhs
def _block_and_time(fn):
    """Return the wall-clock time of one call to `fn`, including device execution.

    JAX dispatches work asynchronously: a jitted call can return before its result has
    been computed. Timing the bare call would therefore measure only dispatch overhead,
    so we wait on the output with `jax.block_until_ready` inside the timed region.
    """
    return timeit.timeit(lambda: jax.block_until_ready(fn()), number=1)


def test_solvers(vmap_size, mat_size):
    """Benchmark every entry of `named_solvers` against its JAX reference solver.

    Arguments:
        vmap_size: how many problems to solve at once. `1` solves a single problem.
            Anything larger builds a batch of `vmap_size` problems and `jax.vmap`s
            both solvers over it.
        mat_size: the size of each square problem matrix.

    Each solver is JIT-compiled and called once; that call triggers compilation and
    its output is compared against the known true solution. A second call is then
    timed, and a success or failure line is printed.
    """
    if vmap_size == 1:
        batch_msg = "problem"
    else:
        batch_msg = f"batch of {vmap_size} problems"

    for lx_name, jax_name, _lx_solver, jax_solver, tags in named_solvers:
        lx_solver = ft.partial(base_wrapper, solver=_lx_solver)
        # Iterative solvers get the problems from `create_easy_iterative_problem`.
        is_iterative = isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab))
        if vmap_size == 1:
            if is_iterative:
                matrix, true_x, b = create_easy_iterative_problem(mat_size, tags)
            else:
                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)
        else:
            if is_iterative:
                matrix, true_x, b = eqx.filter_vmap(
                    create_easy_iterative_problem,
                    axis_size=vmap_size,
                    out_axes=eqx.if_array(0),
                )(mat_size, tags)
            else:
                _create_problem = ft.partial(create_problem, size=mat_size)
                matrix, true_x, b = eqx.filter_vmap(
                    _create_problem, axis_size=vmap_size, out_axes=eqx.if_array(0)
                )(lx_solver, tags)
            lx_solver = jax.vmap(lx_solver)
            if jax_solver is not None:
                jax_solver = jax.vmap(jax_solver)

        lx_solver = jax.jit(lx_solver)
        bench_lx = ft.partial(lx_solver, matrix, b)
        # First call: compiles, and gives the answer to check.
        lx_soln = bench_lx()
        if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):
            lx_solve_time = _block_and_time(bench_lx)
            print(
                f"Lineax's {lx_name} solved {batch_msg} of "
                f"size {mat_size} in {lx_solve_time} seconds."
            )
        else:
            fail_time = _block_and_time(bench_lx)
            err = jnp.abs(lx_soln - true_x).max()
            print(
                f"Lineax's {lx_name} failed to solve {batch_msg} of "
                f"size {mat_size} with error {err} in {fail_time} seconds"
            )

        if jax_solver is None:
            print("JAX has no equivalent solver. \n")
            continue

        jax_solver = jax.jit(jax_solver)
        bench_jax = ft.partial(jax_solver, matrix, b)
        # First call: compiles, and gives the answer to check.
        jax_soln = bench_jax()
        if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):
            jax_solve_time = _block_and_time(bench_jax)
            print(
                f"JAX's {jax_name} solved {batch_msg} of "
                f"size {mat_size} in {jax_solve_time} seconds. \n"
            )
        else:
            fail_time = _block_and_time(bench_jax)
            err = jnp.abs(jax_soln - true_x).max()
            print(
                f"JAX's {jax_name} failed to solve {batch_msg} of "
                f"size {mat_size} with error {err} in {fail_time} seconds. \n"
            )
if __name__ == "__main__":
    # Benchmark a single problem, then a batch of 1000 problems (solved via `jax.vmap`).
    # Guarded so that importing this module does not launch the benchmark.
    for vmap_size, mat_size in [(1, 50), (1000, 50)]:
        test_solvers(vmap_size, mat_size)
================================================
FILE: docs/.htaccess
================================================
ErrorDocument 404 /lineax/404.html
================================================
FILE: docs/_overrides/partials/source.html
================================================
{% import "partials/language.html" as lang with context %}
{% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
{% include ".icons/" ~ icon ~ ".svg" %}
{{ config.repo_name }}
{% include ".icons/fontawesome/brands/twitter.svg" %}
{% include "bluesky.svg" %}
{{ config.theme.twitter_bluesky_name }}
================================================
FILE: docs/_static/custom_css.css
================================================
/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
html {
scroll-padding-top: 50px;
}
/* Fit the Twitter handle alongside the GitHub one in the top right. */
div.md-header__source {
width: revert;
max-width: revert;
}
a.md-source {
display: inline-block;
}
.md-source__repository {
max-width: 100%;
}
/* Emphasise sections of nav on left hand side */
nav.md-nav {
padding-left: 5px;
}
nav.md-nav--secondary {
border-left: revert !important;
}
.md-nav__title {
font-size: 0.9rem;
}
.md-nav__item--section > .md-nav__link {
font-size: 0.9rem;
}
/* Indent autogenerated documentation */
div.doc-contents {
padding-left: 25px;
border-left: 4px solid rgba(230, 230, 230);
}
/* Increase visibility of splitters "---" */
[data-md-color-scheme="default"] .md-typeset hr {
border-bottom-color: rgb(0, 0, 0);
border-bottom-width: 1pt;
}
[data-md-color-scheme="slate"] .md-typeset hr {
border-bottom-color: rgb(230, 230, 230);
}
/* More space at the bottom of the page */
.md-main__inner {
margin-bottom: 1.5rem;
}
/* Remove prev/next footer buttons */
.md-footer__inner {
display: none;
}
/* Change font sizes */
html {
/* Decrease font size for overall webpage
Down from 137.5% which is the Material default */
font-size: 110%;
}
.md-typeset .admonition {
/* Increase font size in admonitions */
font-size: 100% !important;
}
.md-typeset details {
/* Increase font size in details */
font-size: 100% !important;
}
.md-typeset h1 {
font-size: 1.6rem;
}
.md-typeset h2 {
font-size: 1.5rem;
}
.md-typeset h3 {
font-size: 1.3rem;
}
.md-typeset h4 {
font-size: 1.1rem;
}
.md-typeset h5 {
font-size: 0.9rem;
}
.md-typeset h6 {
font-size: 0.8rem;
}
/* Bugfix: remove the superfluous parts generated when doing:
??? Blah
::: library.something
*/
.md-typeset details .mkdocstrings > h4 {
display: none;
}
.md-typeset details .mkdocstrings > h5 {
display: none;
}
/* Change default colours for tags */
[data-md-color-scheme="default"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}
[data-md-color-scheme="slate"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}
/* Highlight functions, classes etc. type signatures. Really helps to make clear where
one item ends and another begins. */
[data-md-color-scheme="default"] {
--doc-heading-color: #DDD;
--doc-heading-border-color: #CCC;
--doc-heading-color-alt: #F0F0F0;
}
[data-md-color-scheme="slate"] {
--doc-heading-color: rgb(25,25,33);
--doc-heading-border-color: rgb(25,25,33);
--doc-heading-color-alt: rgb(33,33,44);
--md-code-bg-color: rgb(38,38,50);
}
h4.doc-heading {
/* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
background-color: var(--doc-heading-color);
border: solid var(--doc-heading-border-color);
border-width: 1.5pt;
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}
h5.doc-heading, h6.heading {
background-color: var(--doc-heading-color-alt);
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}
================================================
FILE: docs/_static/mathjax.js
================================================
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};
document$.subscribe(() => {
MathJax.typesetPromise()
})
================================================
FILE: docs/api/functions.md
================================================
# Functions on linear operators
We define a number of functions on [linear operators](./operators.md).
## Computational changes
These do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.)
::: lineax.linearise
---
::: lineax.materialise
## Extract information from the operator
::: lineax.diagonal
---
::: lineax.tridiagonal
## Test the operator to see if it exhibits a certain property
Note that these do *not* inspect the values of the operator -- instead, they typically use [tags](./tags.md). (Or in some cases, just the type of the operator: e.g. `is_diagonal(DiagonalLinearOperator(...)) == True`.)
::: lineax.has_unit_diagonal
---
::: lineax.is_diagonal
---
::: lineax.is_tridiagonal
---
::: lineax.is_lower_triangular
---
::: lineax.is_upper_triangular
---
::: lineax.is_positive_semidefinite
---
::: lineax.is_negative_semidefinite
---
::: lineax.is_symmetric
================================================
FILE: docs/api/linear_solve.md
================================================
# linear_solve
This is the main entry point.
::: lineax.linear_solve
## invert
A convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][].
::: lineax.invert
================================================
FILE: docs/api/operators.md
================================================
# Linear operators
We often talk about solving a linear system $Ax = b$, where $A \in \mathbb{R}^{n \times m}$ is a matrix, $b \in \mathbb{R}^n$ is a vector, and $x \in \mathbb{R}^m$ is our desired solution.
The linear operators described on this page are ways of describing the matrix $A$. The simplest is [`lineax.MatrixLinearOperator`][], which simply holds the matrix $A$ directly.
Meanwhile if $A$ is diagonal, then there is also [`lineax.DiagonalLinearOperator`][]: for efficiency this only stores the diagonal of $A$.
Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such that $F(x) = Ax$. Whilst we could use $F$ to materialise the whole matrix $A$ and then store it in a [`lineax.MatrixLinearOperator`][], that may be very memory intensive. Instead, we may prefer to use [`lineax.FunctionLinearOperator`][]. Many linear solvers (e.g. [`lineax.CG`][]) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $A$.
??? abstract "`lineax.AbstractLinearOperator`"
::: lineax.AbstractLinearOperator
options:
members:
- mv
- as_matrix
- transpose
- in_structure
- out_structure
- in_size
- out_size
::: lineax.MatrixLinearOperator
options:
members:
- __init__
---
::: lineax.DiagonalLinearOperator
options:
members:
- __init__
---
::: lineax.TridiagonalLinearOperator
options:
members:
- __init__
---
::: lineax.PyTreeLinearOperator
options:
members:
- __init__
---
::: lineax.JacobianLinearOperator
options:
members:
- __init__
---
::: lineax.FunctionLinearOperator
options:
members:
- __init__
---
::: lineax.IdentityLinearOperator
options:
members:
- __init__
---
::: lineax.TaggedLinearOperator
options:
members:
- __init__
================================================
FILE: docs/api/solution.md
================================================
# Solution
::: lineax.Solution
options:
members: []
---
::: lineax.RESULTS
options:
members: []
================================================
FILE: docs/api/solvers.md
================================================
# Solvers
If you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the [tags](./tags.md) page.)
??? abstract "`lineax.AbstractLinearSolver`"
::: lineax.AbstractLinearSolver
options:
members:
- init
- compute
- transpose
- conj
- assume_full_rank
::: lineax.AutoLinearSolver
options:
members:
- __init__
- select_solver
---
::: lineax.LU
options:
members:
- __init__
## Least squares solvers
These are capable of solving ill-posed linear problems.
::: lineax.QR
options:
members:
- __init__
---
::: lineax.SVD
options:
members:
- __init__
---
::: lineax.Normal
options:
members:
- __init__
---
::: lineax.LSMR
options:
members:
- __init__
#### Diagonal
In addition to these, [`lineax.Diagonal`][] with `well_posed=False` (below) also supports ill-posed problems.
## Iterative solvers
These solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. [`lineax.JacobianLinearOperator`][] or [`lineax.FunctionLinearOperator`][], which only provide matrix-vector products.
!!! warning
Note that [`lineax.BiCGStab`][] and [`lineax.GMRES`][] may fail to converge on some (typically non-sparse) problems.
::: lineax.CG
options:
members:
- __init__
---
::: lineax.BiCGStab
options:
members:
- __init__
---
::: lineax.GMRES
options:
members:
- __init__
#### LSMR
In addition to these, [`lineax.LSMR`][] (above) is also an iterative method.
## Structure-exploiting solvers
These require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently.
::: lineax.Cholesky
options:
members:
- __init__
---
::: lineax.Diagonal
options:
members:
- __init__
---
::: lineax.Triangular
options:
members:
- __init__
---
::: lineax.Tridiagonal
options:
members:
- __init__
#### CG
In addition to these, [`lineax.CG`][] also requires special structure (positive or negative definiteness).
================================================
FILE: docs/api/tags.md
================================================
# Tags
Lineax offers a way to "tag" linear operators as exhibiting certain properties, e.g. that they are positive semidefinite.
If a linear operator is known to have a particular property, then this can be used to dispatch to a more efficient implementation, e.g. when solving a linear system.
Generally speaking, tags are an *optional* tool that can be used to improve your run time and/or compile time, by statically telling the linear solvers what properties they may assume about your system. However, if misused then you may find that the wrong result is silently returned.
In this way they are analogous to flags like `scipy.linalg.solve(..., assume_a="pos")`.
!!! Example
```python
# Some rank-2 JAX array.
matrix = ...
# Some rank-1 JAX array.
vector = ...
# Declare that this matrix is positive semidefinite.
operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
# This tag is used to dispatch to a maximally-efficient linear solver.
# In this case, a Cholesky solver is used:
solution = lx.linear_solve(operator, vector)
# Whether operators are tagged can be checked:
assert lx.is_positive_semidefinite(operator)
```
!!! Warning
Be careful, only the tag is actually checked, not the actual value of the matrix:
```python
# Not a positive semidefinite matrix
matrix = jax.numpy.array([[1, 2], [3, 4]])
operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
lx.is_positive_semidefinite(operator) # True
lx.linear_solve(operator, vector) # Returns the wrong solution!
```
Of the built-in operators: [`lineax.MatrixLinearOperator`][], [`lineax.PyTreeLinearOperator`][], [`lineax.JacobianLinearOperator`][], [`lineax.FunctionLinearOperator`][], [`lineax.TaggedLinearOperator`][] directly support a `tags` argument that marks them as having certain characteristics:
```python
operator = lx.MatrixLinearOperator(matrix, lx.symmetric_tag)
```
You can pass multiple tags at once:
```python
operator = lx.MatrixLinearOperator(matrix, (lx.symmetric_tag, lx.unit_diagonal_tag))
```
Other linear operators can be wrapped into a [`lineax.TaggedLinearOperator`][] if necessary:
```python
operator = lx.MatrixLinearOperator(...)
symmetric_operator = operator + operator.T
lx.is_symmetric(symmetric_operator) # False
symmetric_operator = lx.TaggedLinearOperator(symmetric_operator, lx.symmetric_tag)
lx.is_symmetric(symmetric_operator) # True
```
Some linear operators are known to exhibit certain properties by construction, and need no additional tags:
```python
lx.is_symmetric(lx.DiagonalLinearOperator(...)) # True
lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True
```
## List of available tags
::: lineax.symmetric_tag
Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.)
---
::: lineax.diagonal_tag
Marks that an operator is diagonal. (As a matrix, it must have zeros in the off-diagonal entries.)
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Diagonal`][] as the solver.
---
::: lineax.unit_diagonal_tag
Marks that an operator has $1$ for every diagonal element. (As a matrix $A$, it must have $A_{ii} = 1$ for all $i$.) Note that the whole matrix need not be diagonal.
For example, [`lineax.Triangular`][] uses this to cheapen its solve.
---
::: lineax.lower_triangular_tag
Marks that an operator is lower triangular. (As a matrix $A$, it must have $A_{ij} = 0$ for all $i < j$.) Note that the diagonal may still have nonzero entries.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.
---
::: lineax.upper_triangular_tag
Marks that an operator is upper triangular. (As a matrix $A$, it must have $A_{ij} = 0$ for all $i > j$.) Note that the diagonal may still have nonzero entries.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.
---
::: lineax.positive_semidefinite_tag
Marks that an operator is positive **semidefinite**.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.
---
::: lineax.negative_semidefinite_tag
Marks that an operator is negative **semidefinite**.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.
================================================
FILE: docs/examples/classical_solve.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1",
"metadata": {},
"source": [
"# Classical solve\n",
"\n",
"We wish to solve the linear system $Ax = b$. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb3a7781-2358-40c4-82f3-e908bddeb578",
"metadata": {
"tags": [],
"ExecuteTime": {
"end_time": "2024-04-02T05:26:05.556701Z",
"start_time": "2024-04-02T05:26:03.814599Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A=\n",
"[[-0.3721109 0.26423115 -0.18252768]\n",
" [-0.7368197 0.44973662 -0.1521442 ]\n",
" [-0.67135346 -0.5908641 0.73168886]]\n",
"b=[ 0.17269018 -0.64765567 1.2229712 ]\n",
"x=[-2.7321298 -8.52878 -7.7226872]\n"
]
}
],
"source": [
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3))\n",
"vector = jr.normal(jr.PRNGKey(1), (3,))\n",
"operator = lx.MatrixLinearOperator(matrix)\n",
"solution = lx.linear_solve(operator, vector)\n",
"print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/complex_solve.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1",
"metadata": {},
"source": [
"# Complex solve\n",
"\n",
"We can also solve a system with complex entries. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb3a7781-2358-40c4-82f3-e908bddeb578",
"metadata": {
"tags": [],
"ExecuteTime": {
"end_time": "2024-04-02T05:29:04.909894Z",
"start_time": "2024-04-02T05:29:04.103141Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A=\n",
"[[-1.8459436 -0.2744466j 0.02393756-0.03172905j 0.76815367-1.4444253j ]\n",
" [-1.0467293 +0.05608991j 1.0891742 -0.03264743j 0.7513123 +0.56285536j]\n",
" [ 0.38307396-1.0190808j 0.01203694-1.1971304j 0.19252291-0.26424018j]]\n",
"b=[0.23162952+0.3614433j 0.05800135+1.6094692j 0.8979094 +0.16941352j]\n",
"x=[-0.07652722-0.34397143j -0.22629777+1.0359733j 0.22135164-0.00880566j]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\n",
"vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\n",
"operator = lx.MatrixLinearOperator(matrix)\n",
"solution = lx.linear_solve(operator, vector)\n",
"print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/least_squares.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "44bff903-0e4d-4f3e-a75c-d3cfe8ab4dea",
"metadata": {},
"source": [
"# Linear least squares\n",
"\n",
"The solution to a well-posed linear system $Ax = b$ is given by $x = A^{-1}b$. If the matrix is rectangular or not invertible, then we may generalise the notion of solution to $x = A^{\\dagger}b$, where $A^{\\dagger}$ denotes the [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse).\n",
"\n",
"Lineax can handle problems of this type too.\n",
"\n",
"!!! info\n",
"\n",
" For reference: in core JAX, problems of this type are handled using [`jax.numpy.linalg.lstsq`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html#jax.numpy.linalg.lstsq).\n",
" \n",
"---\n",
"\n",
"## Picking a solver\n",
"\n",
"By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a956c3f2-a70c-472f-9fa9-3dbc16293e1d",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "ValueError",
"evalue": "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.",
"output_type": "error",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.\n"
]
}
],
"source": [
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"vector = jr.normal(jr.PRNGKey(1), (3,))\n",
"\n",
"rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))\n",
"rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)\n",
"lx.linear_solve(rectangular_operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "ba55c0dd-b696-497a-8b13-896c3a95d5fd",
"metadata": {},
"source": [
"Or it will happen at run time if using a rank-deficient matrix:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f0e7ffe6-1e3d-46dc-9dbd-d5ed4c2dedf4",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "XlaRuntimeError",
"evalue": "INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n",
"output_type": "error",
"traceback": [
"\u001b[0;31mXlaRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n"
]
}
],
"source": [
"deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)\n",
"deficient_operator = lx.MatrixLinearOperator(deficient_matrix)\n",
"lx.linear_solve(deficient_operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "4b5cedab-75e5-4b52-88d9-b9d574be7e19",
"metadata": {},
"source": [
"Whilst linear least squares and pseudoinverses are a strict generalisation of linear solves and inverses (respectively), Lineax will *not* attempt to handle the ill-posed case automatically. This is because the algorithms for handling this case are much more computationally expensive.\n",
"\n",
"If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "45abc8bf-4fcf-46be-a91a-58f4e04ac10e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rectangular_solution: [-0.3214848 -0.75565964 -0.6034579 -0.01326615]\n"
]
}
],
"source": [
"rectangular_solution = lx.linear_solve(\n",
" rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)\n",
")\n",
"print(\"rectangular_solution: \", rectangular_solution.value)"
]
},
{
"cell_type": "markdown",
"id": "86dc9e2f-fe2e-48c8-86ca-bc57f8137246",
"metadata": {},
"source": [
"If your matrix may be either rectangular or rank-deficient, then you can set the solver to allow this case like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9a2d92c-3676-471e-bb4a-5fd3b4748fd4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deficient_solution: [ 0.06046088 -1.0412765 0.8860444 ]\n"
]
}
],
"source": [
"deficient_solution = lx.linear_solve(\n",
" deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)\n",
")\n",
"print(\"deficient_solution: \", deficient_solution.value)"
]
},
{
"cell_type": "markdown",
"id": "7b870311-de0f-434c-a9e7-2d8ebf9f0b38",
"metadata": {},
"source": [
"Most users will want to use [`lineax.AutoLinearSolver`][], and not think about the details of which algorithm is selected.\n",
"\n",
"If you want to pick a particular algorithm, then that can be done too. [`lineax.QR`][] is capable of handling rectangular full-rank operators, and [`lineax.SVD`][] is capable of handling rank-deficient operators. (And in fact these are the algorithms that `AutoLinearSolver` is selecting in the examples above.)"
]
},
{
"cell_type": "markdown",
"id": "c9649746-b0ef-495b-9ea1-eb5f6ca2e7e5",
"metadata": {},
"source": [
"---\n",
"\n",
"## Differences from `jax.numpy.linalg.lstsq`?\n",
"\n",
"Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)\n",
"\n",
"### Speed (forward)\n",
"\n",
"First, in the rectangular case, the QR algorithm is much faster than the SVD algorithm:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d46d0c9a-47e4-439d-9beb-c9aaf47faa5d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX solution: [-0.10002219 0.09477127 -0.10846332 ... -0.08007179 -0.01216239\n",
" -0.030862 ]\n",
"Lineax solution: [-0.1000222 0.0947713 -0.10846333 ... -0.08007187 -0.01216241\n",
" -0.03086199]\n",
"\n",
"JAX time: 0.011344402999384329\n",
"Lineax time: 0.0028611960005946457\n"
]
}
],
"source": [
"import timeit\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (500, 200))\n",
"vector = jr.normal(jr.PRNGKey(1), (500,))\n",
"\n",
"\n",
"@jax.jit\n",
"def solve_jax(matrix, vector):\n",
" out, *_ = jnp.linalg.lstsq(matrix, vector)\n",
" return out\n",
"\n",
"\n",
"@jax.jit\n",
"def solve_lineax(matrix, vector):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solver = lx.QR() # or lx.AutoLinearSolver(well_posed=None)\n",
" solution = lx.linear_solve(operator, vector, solver)\n",
" return solution.value\n",
"\n",
"\n",
"solution_jax = solve_jax(matrix, vector)\n",
"solution_lineax = solve_lineax(matrix, vector)\n",
"with np.printoptions(threshold=10):\n",
" print(\"JAX solution:\", solution_jax)\n",
" print(\"Lineax solution:\", solution_lineax)\n",
"print()\n",
"time_jax = timeit.repeat(lambda: solve_jax(matrix, vector), number=1, repeat=10)\n",
"time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector), number=1, repeat=10)\n",
"print(\"JAX time:\", min(time_jax))\n",
"print(\"Lineax time:\", min(time_lineax))"
]
},
{
"cell_type": "markdown",
"id": "397773d7-f782-45e6-9934-11c62c741380",
"metadata": {},
"source": [
"### Speed (gradients)\n",
"\n",
"Lineax also uses a slightly more efficient autodifferentiation implementation, which ensures it is faster, even when both are using the SVD algorithm."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1988d0f6-86f5-401a-9615-30cccf04d129",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX gradients: [[-1.75446249e-03 2.00700224e-03 ... -3.16517282e-04 -6.08515576e-04]\n",
" [ 1.81865180e-04 4.51280124e-04 ... -1.64618701e-04 -6.53692259e-05]\n",
" ...\n",
" [-7.27269216e-04 1.27710134e-03 ... -2.64510425e-04 -3.38940619e-04]\n",
" [ 6.55723223e-03 -3.18011409e-03 ... -1.10758876e-04 1.43246143e-03]]\n",
"Lineax gradients: [[-1.7544631e-03 2.0070139e-03 ... -3.1653541e-04 -6.0847402e-04]\n",
" [ 1.8186278e-04 4.5128341e-04 ... -1.6459504e-04 -6.5359738e-05]\n",
" ...\n",
" [-7.2721508e-04 1.2771402e-03 ... -2.6450949e-04 -3.3894143e-04]\n",
" [ 6.5572355e-03 -3.1801097e-03 ... -1.1071599e-04 1.4324478e-03]]\n",
"\n",
"JAX time: 0.016591553001489956\n",
"Lineax time: 0.012212782999995397\n"
]
}
],
"source": [
"@jax.jit\n",
"@jax.grad\n",
"def grad_jax(matrix):\n",
" out, *_ = jnp.linalg.lstsq(matrix, vector)\n",
" return out.sum()\n",
"\n",
"\n",
"@jax.jit\n",
"@jax.grad\n",
"def grad_lineax(matrix):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solution = lx.linear_solve(operator, vector, lx.SVD())\n",
" return solution.value.sum()\n",
"\n",
"\n",
"gradients_jax = grad_jax(matrix)\n",
"gradients_lineax = grad_lineax(matrix)\n",
"with np.printoptions(threshold=10, edgeitems=2):\n",
" print(\"JAX gradients:\", gradients_jax)\n",
" print(\"Lineax gradients:\", gradients_lineax)\n",
"print()\n",
"time_jax = timeit.repeat(lambda: grad_jax(matrix), number=1, repeat=10)\n",
"time_lineax = timeit.repeat(lambda: grad_lineax(matrix), number=1, repeat=10)\n",
"print(\"JAX time:\", min(time_jax))\n",
"print(\"Lineax time:\", min(time_lineax))"
]
},
{
"cell_type": "markdown",
"id": "81a1da5a-3474-4613-926f-5c9d9cdcb4a7",
"metadata": {},
"source": [
"### Correctness (gradients)\n",
"\n",
"Core JAX unfortunately has a bug that means it sometimes produces NaN gradients. Lineax does not:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "66b3a08e-92d0-4d0f-a5ea-9e8e5265d259",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX gradients: [[nan nan nan]\n",
" [nan nan nan]\n",
" [nan nan nan]]\n",
"Lineax gradients: [[ 0. -1. -2.]\n",
" [ 0. -1. -2.]\n",
" [ 0. -1. -2.]]\n"
]
}
],
"source": [
"@jax.jit\n",
"@jax.grad\n",
"def grad_jax(matrix):\n",
" out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))\n",
" return out.sum()\n",
"\n",
"\n",
"@jax.jit\n",
"@jax.grad\n",
"def grad_lineax(matrix):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())\n",
" return solution.value.sum()\n",
"\n",
"\n",
"print(\"JAX gradients:\", grad_jax(jnp.eye(3)))\n",
"print(\"Lineax gradients:\", grad_lineax(jnp.eye(3)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/no_materialisation.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "a7299095-8906-4867-82ef-d6b84b161366",
"metadata": {},
"source": [
"# Using only matrix-vector operations\n",
"\n",
"When solving a linear system $Ax = b$, it is relatively common not to have immediate access to the full matrix $A$, but only to a function $F(x) = Ax$ computing the matrix-vector product. (We could compute $A$ from $F$, but if the matrix is large then this may be very inefficient.)\n",
"\n",
"**Example: Newton's method**\n",
"\n",
"For example, this comes up when using [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#k_variables,_k_functions). In this case, we have a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^n$, and wish to find the $\\delta \\in \\mathbb{R}^n$ for which $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; \\delta = -f(y)$. (Where $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\in \\mathbb{R}^{n \\times n}$ is a matrix: it is the Jacobian of $f$.)\n",
"\n",
"In this case it is possible to use forward-mode autodifferentiation to evaluate $F(x) = \\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; x$, without ever instantiating the whole Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y)$. Indeed, JAX has a [Jacobian-vector product function](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) for exactly this purpose.\n",
"```python\n",
"f = ...\n",
"y = ...\n",
"\n",
"def F(x):\n",
" \"\"\"Computes (df/dy) @ x.\"\"\"\n",
" _, out = jax.jvp(f, (y,), (x,))\n",
" return out\n",
"```\n",
"\n",
"**Solving a linear system using only matrix-vector operations**\n",
"\n",
"Lineax offers [iterative solvers](../api/solvers.md#iterative-solvers), which are capable of solving a linear system knowing only its matrix-vector products."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b221ee1f-bd6b-4cbf-b69b-ed2e388602e1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import lineax as lx\n",
"from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n",
"\n",
"\n",
"def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n",
" y0, y1, y2 = y\n",
" f0 = 5 * y0 + y1**2\n",
" f1 = y1 - y2 + 5\n",
" f2 = y0 / (1 + 5 * y2**2)\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"y = jnp.array([1.0, 2.0, 3.0])\n",
"operator = lx.JacobianLinearOperator(f, y, args=None)\n",
"vector = f(y, args=None)\n",
"solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n",
"solution = lx.linear_solve(operator, vector, solver)"
]
},
{
"cell_type": "markdown",
"id": "87568426-35ed-404b-bf78-425a6f519218",
"metadata": {},
"source": [
"!!! warning\n",
"\n",
" Note that iterative solvers are something of a \"last resort\", and they are not suitable for all problems.\n",
"\n",
" - [CG](https://en.wikipedia.org/wiki/Conjugate_gradient_method) requires that the problem be positive or negative semidefinite.\n",
" - Normalised CG (this is CG applied to the \"normal equations\" $(A^\\top A) x = (A^\\top b)$; note that $A^\\top A$ is always positive semidefinite) squares the condition number of $A$. In practice this means it may produce low-accuracy results if used with matrices with high condition number.\n",
" - [BiCGStab](https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method) and [GMRES](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/operators.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "2fe0b1e4-35cb-4c39-b324-65253aab005a",
"metadata": {},
"source": [
"# Manipulating linear operators\n",
"\n",
"Lineax offers a sophisticated system of linear operators, supporting many operations.\n",
"\n",
"## Arithmetic\n",
"\n",
"To begin with, they support arithmetic, like addition and multiplication:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "552021d3-dadf-49f3-bd17-84a18513bfcc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"import numpy as np\n",
"\n",
"\n",
"np.set_printoptions(precision=3)\n",
"\n",
"matrix = jnp.zeros((5, 5))\n",
"matrix = matrix.at[0, 4].set(3) # top right corner\n",
"sparse_operator = lx.MatrixLinearOperator(matrix)\n",
"\n",
"key0, key1, key = jr.split(jr.PRNGKey(0), 3)\n",
"diag = jr.normal(key0, (5,))\n",
"lower_diag = jr.normal(key0, (4,))\n",
"upper_diag = jr.normal(key0, (4,))\n",
"tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n",
"\n",
"identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4bb9825-73cc-447e-bc4c-c3e1a121a0a3",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1.149 0.963 0. 0. 3. ]\n",
" [ 0.963 -2.007 0.155 0. 0. ]\n",
" [ 0. 0.155 0.988 -0.261 0. ]\n",
" [ 0. 0. -0.261 0.931 0.899]\n",
" [ 0. 0. 0. 0.899 -0.288]]\n"
]
}
],
"source": [
"print((sparse_operator + tridiag_operator).as_matrix())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "759c78a1-eee7-40e9-be6c-ea8c97c29e95",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-101.149 0.963 0. 0. 0. ]\n",
" [ 0.963 -102.007 0.155 0. 0. ]\n",
" [ 0. 0.155 -99.012 -0.261 0. ]\n",
" [ 0. 0. -0.261 -99.069 0.899]\n",
" [ 0. 0. 0. 0.899 -100.288]]\n"
]
}
],
"source": [
"print((tridiag_operator - 100 * identity_operator).as_matrix())"
]
},
{
"cell_type": "markdown",
"id": "84412bfa-00ec-41d4-87d7-def781145a90",
"metadata": {},
"source": [
"Or they can be composed together. (I.e. matrix multiplication.)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8081d97f-5579-464f-8780-ffaa1d9c5f95",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 0. 0. 0. -3.447]\n",
" [ 0. 0. 0. 0. 2.888]\n",
" [ 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 0. 0. 0. ]]\n"
]
}
],
"source": [
"print((tridiag_operator @ sparse_operator).as_matrix())"
]
},
{
"cell_type": "markdown",
"id": "d2c2b580-616f-4abd-a732-7f4a9b13335f",
"metadata": {},
"source": [
"Or they can be transposed:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ae0393eb-3f43-490b-9842-bb374633633a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [3. 0. 0. 0. 0.]]\n"
]
}
],
"source": [
"print(sparse_operator.transpose().as_matrix()) # or sparse_operator.T will work"
]
},
{
"cell_type": "markdown",
"id": "ddbbbb0f-7983-4e35-b92d-2512c9612d19",
"metadata": {},
"source": [
"## Different operator types\n",
"\n",
"Lineax has many different operator types:\n",
"\n",
"- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n",
"- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{m \\times n}$.\n",
"- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n",
"- etc!\n",
"\n",
"See the [operators](../api/operators.md) page for details on all supported operators.\n",
"\n",
"As above these can be freely combined:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "75ad4480-8ce0-4a88-9c76-bc054b1a0eaf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n",
"\n",
"\n",
"def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n",
" y0, y1, y2 = y\n",
" f0 = 5 * y0 + y1**2\n",
" f1 = y1 - y2 + 5\n",
" f2 = y0 / (1 + 5 * y2**2)\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"def g(y: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n",
" # Must be linear!\n",
" y0, y1, y2 = y\n",
" f0 = y0 - y2\n",
" f1 = 0.0\n",
" f2 = 5 * y1\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"y = jnp.array([1.0, 2.0, 3.0])\n",
"in_structure = jax.eval_shape(lambda: y)\n",
"jac_operator = lx.JacobianLinearOperator(f, y, args=None)\n",
"fn_operator = lx.FunctionLinearOperator(g, in_structure)\n",
"identity_operator = lx.IdentityLinearOperator(in_structure)\n",
"\n",
"operator = jac_operator @ fn_operator + 0.9 * identity_operator"
]
},
{
"cell_type": "markdown",
"id": "5e528057-29ff-468d-aa3d-7155dd57082d",
"metadata": {},
"source": [
"This composition does not instantiate a matrix for them by default. (This is sometimes important for efficiency when working with many operators.) Instead, the composition is stored as another linear operator:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5d15150d-955f-4006-bd36-58e2e6663307",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AddLinearOperator(\n",
" operator1=ComposedLinearOperator(\n",
" operator1=JacobianLinearOperator(...),\n",
" operator2=FunctionLinearOperator(...)\n",
" ),\n",
" operator2=MulLinearOperator(\n",
" operator=IdentityLinearOperator(...),\n",
" scalar=f32[]\n",
" )\n",
")\n"
]
}
],
"source": [
"import equinox as eqx # https://github.com/patrick-kidger/equinox\n",
"\n",
"\n",
"truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)\n",
"eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)"
]
},
{
"cell_type": "markdown",
"id": "ff7b0591-1203-4f5e-886e-399822c68a15",
"metadata": {
"tags": []
},
"source": [
"If you want to materialise them into a matrix, then this can be done:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3713589f-1ac4-4e08-946b-ecc3fcf6a4c3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[ 5.9 , 0. , -5. ],\n",
" [ 0. , -4.1 , 0. ],\n",
" [ 0.022, -0.071, 0.878]], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"operator.as_matrix()"
]
},
{
"cell_type": "markdown",
"id": "a483517e-89d7-4e9e-ad89-1915d886c14c",
"metadata": {},
"source": [
"Which can in turn be treated as another linear operator, if desired:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fccddc81-d50e-4abe-a354-38402e462b1f",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MatrixLinearOperator(\n",
" matrix=Array([[ 5.9 , 0. , -5. ],\n",
" [ 0. , -4.1 , 0. ],\n",
" [ 0.022, -0.071, 0.878]], dtype=float32),\n",
" tags=frozenset()\n",
")\n"
]
}
],
"source": [
"operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())\n",
"eqx.tree_pprint(operator_fully_materialised, short_arrays=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/structured_matrices.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "e2573d62-a505-4998-8796-b0f1bc889433",
"metadata": {},
"source": [
"# Structured matrices\n",
"\n",
"Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.\n",
"\n",
"Typically, that means using a particular operator type:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8e275652-dd80-4a9a-b3ac-b96dc16d3334",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 4. 2. 0. 0. ]\n",
" [ 1. -0.5 -1. 0. ]\n",
" [ 0. 3. 7. -5. ]\n",
" [ 0. 0. -0.7 1. ]]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"diag = jnp.array([4.0, -0.5, 7.0, 1.0])\n",
"lower_diag = jnp.array([1.0, 3.0, -0.7])\n",
"upper_diag = jnp.array([2.0, -1.0, -5.0])\n",
"\n",
"operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n",
"print(operator.as_matrix())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ba23ecc4-bdea-4293-a138-ce77bc83082c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"vector = jnp.array([1.0, -0.5, 2.0, 0.8])\n",
"# Will automatically dispatch to a tridiagonal solver.\n",
"solution = lx.linear_solve(operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "cd58979d-b619-4ddf-9a17-12e8babae3e8",
"metadata": {},
"source": [
"If you're uncertain which solver is being dispatched to, then you can check:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6984f62f-75fc-4d6e-ab42-fdade471be5b",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tridiagonal()\n"
]
}
],
"source": [
"default_solver = lx.AutoLinearSolver(well_posed=True)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "164a5bd5-5d48-4b28-bcc5-d276ab49c780",
"metadata": {},
"source": [
"If you want to enforce that a particular solver is used, then it can be passed manually:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "102ada9a-0533-40cf-9bad-02918fffb6b1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())"
]
},
{
"cell_type": "markdown",
"id": "1b4ebf09-e138-43f6-973c-c9f005ffb55e",
"metadata": {},
"source": [
"Trying to use a solver with an unsupported operator will raise an error:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d8f5bf66-53cd-4e81-a8d7-a19e86307ad3",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "ValueError",
"evalue": "`Tridiagonal` may only be used for linear solves with tridiagonal matrices",
"output_type": "error",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\n"
]
}
],
"source": [
"not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n",
"not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)\n",
"solution = lx.linear_solve(not_tridiagonal_operator, vector, solver=lx.Tridiagonal())"
]
},
{
"cell_type": "markdown",
"id": "03c4c531-58fa-4b56-8b0a-6e611c8c5912",
"metadata": {},
"source": [
"---\n",
"\n",
"Besides using a particular operator type, the structure of the matrix can also be expressed by [adding particular tags](../api/tags.md). These tags act as a manual override mechanism, and the values of the matrix are not checked.\n",
"\n",
"For example, let's construct a positive definite matrix:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b5add874-7a2c-4000-84c3-8c94a121a831",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n",
"operator = lx.MatrixLinearOperator(matrix.T @ matrix)"
]
},
{
"cell_type": "markdown",
"id": "5459b2d6-ddb9-4a37-bb51-3f5c204bab0d",
"metadata": {},
"source": [
"Unfortunately, Lineax has no way of knowing that this matrix is positive definite. It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "78400416-e774-4f74-a530-e368db84af0e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LU()\n"
]
}
],
"source": [
"solution = lx.linear_solve(operator, vector)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "e108bdff-1cf1-4751-8c9d-3baae82ca9a7",
"metadata": {},
"source": [
"But if we add a tag:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f6dc2966-1dfa-4a3c-be6a-974926695547",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cholesky()\n"
]
}
],
"source": [
"operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\n",
"solution2 = lx.linear_solve(operator, vector)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "7274d17b-a7d3-45bf-9042-785ac25e2d74",
"metadata": {},
"source": [
"Then a more efficient solver can be selected. We can check that the solutions returned from these two approaches are equal (up to floating-point rounding):"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fdcde152-9ac1-4532-a174-3fc39d83d289",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1.400575 -0.41042092 0.5313305 0.28422552]\n",
"[ 1.4005749 -0.41042086 0.53133047 0.2842255 ]\n"
]
}
],
"source": [
"print(solution.value)\n",
"print(solution2.value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/faq.md
================================================
# FAQ
## How does this differ from `jax.numpy.solve`, `jax.scipy.{...}` etc.?
Lineax offers several improvements. Most notably:
- Several new solvers. For example, [`lineax.QR`][] has no counterpart in core JAX. (And it is much faster than `jax.numpy.linalg.lstsq`, which is the closest equivalent, and uses an SVD decomposition instead.)
- Several new operators. For example, [`lineax.JacobianLinearOperator`][] has no counterpart in core JAX.
- A consistent API. The built-in JAX operations all differ from each other slightly, and are split across `jax.numpy`, `jax.scipy`, and `jax.scipy.sparse`.
- Numerically stable gradients. The existing JAX implementations will sometimes return `NaN`s!
- Some faster compile times and run times in a few places.
Most of these differences arise because JAX aims to mimic the existing NumPy/SciPy APIs. (I.e. it's not JAX's fault that it doesn't take the approach that Lineax does!)
## How do I represent a {lower, upper} triangular matrix?
Typically: create a full matrix, with the {lower, upper} part containing your values, and the converse {upper, lower} part containing all zeros. Then use, e.g., `operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag)`.
This is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model.
## What about other operations from linear algebra? (Determinants, eigenvalues, etc.)
See [`jax.numpy.linalg`](https://jax.readthedocs.io/en/latest/jax.numpy.html#module-jax.numpy.linalg) and [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg).
## How do I solve multiple systems of equations (i.e. `AX = B`)?
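Solvers implemented in Lineax target single systems of linear equations (i.e. `Ax = b`). However, by wrapping `lx.linear_solve` in `jax.vmap` or `equinox.filter_vmap`, multiple systems can be solved with minimal effort. The operator is shared (`None`) and the solve is mapped over the columns of `B` (axis `1`):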
```python
multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))
# or
multi_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1))
```
================================================
FILE: docs/index.md
================================================
# Getting started
Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)
Features include:
- PyTree-valued matrices and vectors;
- General linear operators for Jacobians, transposes, etc.;
- Efficient linear least squares (e.g. QR solvers);
- Numerically stable gradients through linear least squares;
- Support for structured (e.g. symmetric) matrices;
- Improved compilation times;
- Improved runtime of some algorithms;
- Support for both real-valued and complex-valued inputs;
- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.
## Installation
```bash
pip install lineax
```
Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.
## Quick example
Lineax can solve a least squares problem with an explicit matrix operator:
```python
import jax.random as jr
import lineax as lx
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```
or Lineax can solve a problem without ever materializing a matrix, as done in this
quadratic solve:
```python
import jax
import lineax as lx
key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))
def quadratic_fn(y, args):
return jax.numpy.sum((y - 1)**2)
gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
```
## Next steps
Check out the examples or the API reference on the left-hand bar.
## See also: other libraries in the JAX ecosystem
**Always useful**
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
**Deep learning**
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
**Scientific computing**
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
================================================
FILE: lineax/__init__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
from . import internal as internal
from ._operator import (
AbstractLinearOperator as AbstractLinearOperator,
AddLinearOperator as AddLinearOperator,
ComposedLinearOperator as ComposedLinearOperator,
conj as conj,
diagonal as diagonal,
DiagonalLinearOperator as DiagonalLinearOperator,
DivLinearOperator as DivLinearOperator,
FunctionLinearOperator as FunctionLinearOperator,
has_unit_diagonal as has_unit_diagonal,
IdentityLinearOperator as IdentityLinearOperator,
is_diagonal as is_diagonal,
is_lower_triangular as is_lower_triangular,
is_negative_semidefinite as is_negative_semidefinite,
is_positive_semidefinite as is_positive_semidefinite,
is_symmetric as is_symmetric,
is_tridiagonal as is_tridiagonal,
is_upper_triangular as is_upper_triangular,
JacobianLinearOperator as JacobianLinearOperator,
linearise as linearise,
materialise as materialise,
MatrixLinearOperator as MatrixLinearOperator,
MulLinearOperator as MulLinearOperator,
NegLinearOperator as NegLinearOperator,
PyTreeLinearOperator as PyTreeLinearOperator,
TaggedLinearOperator as TaggedLinearOperator,
TangentLinearOperator as TangentLinearOperator,
tridiagonal as tridiagonal,
TridiagonalLinearOperator as TridiagonalLinearOperator,
)
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
AbstractLinearSolver as AbstractLinearSolver,
AutoLinearSolver as AutoLinearSolver,
invert as invert,
linear_solve as linear_solve,
)
from ._solver import (
BiCGStab as BiCGStab,
CG as CG,
Cholesky as Cholesky,
Diagonal as Diagonal,
GMRES as GMRES,
LSMR as LSMR,
LU as LU,
Normal as Normal,
NormalCG as NormalCG,
QR as QR,
SVD as SVD,
Triangular as Triangular,
Tridiagonal as Tridiagonal,
)
from ._tags import (
diagonal_tag as diagonal_tag,
lower_triangular_tag as lower_triangular_tag,
negative_semidefinite_tag as negative_semidefinite_tag,
positive_semidefinite_tag as positive_semidefinite_tag,
symmetric_tag as symmetric_tag,
transpose_tags as transpose_tags,
transpose_tags_rules as transpose_tags_rules,
tridiagonal_tag as tridiagonal_tag,
unit_diagonal_tag as unit_diagonal_tag,
upper_triangular_tag as upper_triangular_tag,
)
# The installed distribution's version string, read from package metadata at import
# time. (`importlib.metadata.version` raises `PackageNotFoundError` if no `lineax`
# distribution is installed, e.g. when running from an uninstalled source checkout.)
__version__ = importlib.metadata.version("lineax")
================================================
FILE: lineax/_custom_types.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import equinox.internal as eqxi
# A unique module-level placeholder built from a fresh `object()`, so it compares
# identical only to itself.
# NOTE(review): presumably used as a "not provided" default where `None` is a
# meaningful value, and `eqxi.doc_repr` presumably only controls how it is displayed
# (as "sentinel", e.g. in generated docs). Confirm against callers / equinox.internal.
sentinel: Any = eqxi.doc_repr(object(), "sentinel")
================================================
FILE: lineax/_misc.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore
def tree_where(
    pred: Bool[ArrayLike, ""], true: PyTree[ArrayLike], false: PyTree[ArrayLike]
) -> PyTree[Array]:
    """Select leafwise between two PyTrees of the same structure.

    Each leaf of the result is `jnp.where(pred, t, f)`, where `t` and `f` are the
    corresponding leaves of `true` and `false`.
    """

    def _select(leaf_true, leaf_false):
        return jnp.where(pred, leaf_true, leaf_false)

    return jtu.tree_map(_select, true, false)
def resolve_rcond(rcond, n, m, dtype):
    """Resolve an `rcond` tolerance, filling in a default when none is given.

    - `rcond`: a user-specified tolerance, or `None`.
    - `n`, `m`: two sizes; only their maximum is used, for the default.
    - `dtype`: floating dtype whose machine epsilon is used.

    Returns `2 * eps * max(n, m)` when `rcond is None`. Otherwise returns `rcond`,
    with any negative value replaced by `eps`.
    """
    eps = jnp.finfo(dtype).eps
    if rcond is not None:
        return jnp.where(rcond < 0, eps, rcond)
    # This `2 *` is a heuristic: I have seen very rare failures without it, in ways
    # that seem to depend on JAX compilation state. (E.g. running unrelated JAX
    # computations beforehand, in a completely different JIT-compiled region, can
    # result in differences in the success/failure of the solve.)
    return 2 * eps * max(n, m)
def jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, jac=None):
    """Wrap `fn` with `jax.jacfwd` or `jax.jacrev`.

    - `fn`: the function to differentiate.
    - `in_size`, `out_size`: input/output sizes, consulted only when `jac is None`.
    - `holomorphic`, `has_aux`: forwarded unchanged to the chosen JAX transform.
    - `jac`: `"fwd"` forces `jax.jacfwd`, `"bwd"` forces `jax.jacrev`, and `None`
        chooses between them heuristically.

    Returns the transformed function. Raises `ValueError` for any other `jac`.
    """
    if jac is None:
        # Heuristic for which mode is cheaper; these thresholds could probably be
        # tuned a lot more.
        use_forward = in_size < 100 or in_size <= 1.5 * out_size
    elif jac == "fwd":
        use_forward = True
    elif jac == "bwd":
        use_forward = False
    else:
        raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.")
    transform = jax.jacfwd if use_forward else jax.jacrev
    return transform(fn, holomorphic=holomorphic, has_aux=has_aux)
def _asarray(dtype, x):
return jnp.asarray(x, dtype=dtype)
# Work around JAX issue #15676
_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))
@_asarray.defjvp
def _asarray_jvp(dtype, x, tx):
(x,) = x
(tx,) = tx
return _asarray(dtype, x), _asarray(dtype, tx)
def default_floating_dtype():
    """Return `jnp.float64` if JAX's 64-bit mode (`jax_enable_x64`) is enabled, and
    `jnp.float32` otherwise."""
    x64_enabled = jax.config.jax_enable_x64  # pyright: ignore
    return jnp.float64 if x64_enabled else jnp.float32
def inexact_asarray(x):
    """Convert `x` to an array with an inexact (floating or complex) dtype.

    If `x` already has an inexact dtype it is kept. Any other dtype (e.g. integer or
    boolean) is replaced by `default_floating_dtype()`.
    """
    dtype = jnp.result_type(x)
    if jnp.issubdtype(dtype, jnp.inexact):
        return _asarray(dtype, x)
    return _asarray(default_floating_dtype(), x)
def complex_to_real_dtype(dtype):
    """Return the real floating dtype with the same precision as `dtype`.

    For example `complex64 -> float32`. Real floating dtypes map to themselves.
    """
    info = jnp.finfo(dtype)
    return info.dtype
def strip_weak_dtype(tree: PyTree) -> PyTree:
return jtu.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding)
if type(x) is jax.ShapeDtypeStruct
else x,
tree,
)
def structure_equal(x, y) -> bool:
    """Return whether `x` and `y` have the same PyTree structure, shapes and dtypes.

    Both inputs are abstracted with `jax.eval_shape` and their weak-type flags are
    stripped before comparison, so weak typing does not affect the result.
    """

    def _abstract(tree):
        return strip_weak_dtype(jax.eval_shape(lambda: tree))

    # NOTE(review): `eqx.tree_equal` may return a non-`bool` (e.g. an array) in some
    # cases; `is True` accepts only a concrete Python `True`.
    return eqx.tree_equal(_abstract(x), _abstract(y)) is True
================================================
FILE: lineax/_norm.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import math
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Inexact, PyTree, Scalar
from ._misc import complex_to_real_dtype, default_floating_dtype
def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Array, ""]:
    """Compute the dot product of two pytrees of arrays with the same pytree
    structure.

    Each pair of corresponding leaves is flattened and contributes
    `Σ conj(a) * b`, and these contributions are summed over all leaves. The first
    argument is conjugated, so for complex inputs this is the usual inner product.

    If the trees have no leaves, a zero of dtype `default_floating_dtype()` is
    returned. Raises `ValueError` if the two trees have different structures.
    """
    leaves1, treedef1 = jtu.tree_flatten(tree1)
    leaves2, treedef2 = jtu.tree_flatten(tree2)
    if treedef1 != treedef2:
        raise ValueError("trees must have the same structure")
    assert len(leaves1) == len(leaves2)
    if not leaves1:
        return jnp.array(0, default_floating_dtype())
    per_leaf = [
        jnp.dot(
            jnp.conj(a).reshape(-1),
            jnp.reshape(b, -1),
            precision=jax.lax.Precision.HIGHEST,  # pyright: ignore
        )
        for a, b in zip(leaves1, leaves2)
    ]
    return ft.reduce(jnp.add, per_leaf)
def sum_squares(x: PyTree[ArrayLike]) -> Scalar:
    """Computes the square of the L2 norm of a PyTree of arrays.

    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `Σ_i |x_i|^2` (which is `Σ_i x_i^2` for real inputs).
    """
    self_inner = tree_dot(x, x)
    # `tree_dot` conjugates its first argument, so the value has no imaginary part
    # mathematically; keep only the real part so the result has a real dtype.
    return self_inner.real
def two_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Computes the L2 norm of a PyTree of arrays.

    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `sqrt(Σ_i |x_i|^2)` (i.e. `sqrt(Σ_i x_i^2)` for real inputs).

    The derivative is defined to be zero wherever the norm is zero or infinite.
    """
    # `_two_norm` carries the `custom_jvp`. This thin public wrapper exists only so
    # that our autogenerated documentation displays the docstring correctly.
    norm = _two_norm(x)
    return norm
@jax.custom_jvp
def _two_norm(x: PyTree[ArrayLike]) -> Scalar:
    # Primal computation behind `two_norm`. Its derivative is supplied separately via
    # `_two_norm.defjvp`.
    leaves = jtu.tree_leaves(x)
    total_size = sum(jnp.size(leaf) for leaf in leaves)
    if total_size != 1:
        return jnp.sqrt(sum_squares(x))
    # Exactly one scalar overall: avoid needless squaring-and-then-rooting by taking
    # the absolute value of the single size-1 leaf (all other leaves are empty).
    (scalar_leaf,) = [leaf for leaf in leaves if jnp.size(leaf) == 1]
    return jnp.abs(jnp.reshape(scalar_leaf, ()))
@_two_norm.defjvp
def _two_norm_jvp(x, tx):
    """JVP rule for `_two_norm`.

    With `out = two_norm(x)`, the tangent output is `Re(tree_dot(x / out, tx))`.
    (`tree_dot` conjugates its first argument.) Where `out` is zero or infinite, the
    tangent output is set to zero instead of NaN.
    """
    (x,) = x
    (tx,) = tx
    out = two_norm(x)
    # Get zero gradient, rather than NaN gradient, in these cases.
    pred = (out == 0) | jnp.isinf(out)
    # Use a harmless denominator of 1 where `pred` holds; those tangent entries are
    # overwritten with zero below anyway.
    denominator = jnp.where(pred, 1, out)
    # We could also switch the dot and the division.
    # This approach is a bit more expensive (more divisions), but should be more
    # numerically stable (`x` and `denominator` should be of the same scale; `tx` is of
    # unknown scale).
    # Standard promotion is allowed here since `x` may be complex while `denominator`
    # (a norm) is real.
    with jax.numpy_dtype_promotion("standard"):
        div = (x**ω / denominator).ω
    t_out = tree_dot(div, tx).real
    t_out = jnp.where(pred, 0, t_out)
    return out, t_out
def rms_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Compute the RMS (root-mean-squared) norm of a PyTree of arrays.

    This is the same as the L2 norm, averaged by the size of the input `x`. Considering
    the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `sqrt((Σ_i |x_i|^2)/n)`.

    If `x` contains no scalars at all, zero is returned. Its dtype is the real
    counterpart of the leaves' dtype, or `default_floating_dtype()` if there are no
    leaves.
    """
    leaves = jtu.tree_leaves(x)
    size = sum(jnp.size(leaf) for leaf in leaves)
    if size > 0:
        return two_norm(x) / math.sqrt(size)
    # Empty input: return zero rather than dividing by zero.
    if leaves:
        dtype = complex_to_real_dtype(jnp.result_type(*leaves))
    else:
        dtype = default_floating_dtype()
    return jnp.array(0.0, dtype)
def max_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Compute the L-infinity norm of a PyTree of arrays.

    This is the largest absolute elementwise value. Considering the input `x` as a flat
    vector `(x_1, ..., x_n)`, then this computes `max_i |x_i|`.

    If `x` contains no scalars, zero is returned. Its dtype is the real counterpart of
    the leaves' dtype, or `default_floating_dtype()` if there are no leaves. The
    derivative of the result is set to zero wherever the result itself is zero.
    """
    leaves = jtu.tree_leaves(x)
    nonempty = [leaf for leaf in leaves if jnp.size(leaf) > 0]
    if nonempty:
        per_leaf_max = [jnp.max(jnp.abs(leaf)) for leaf in nonempty]
        return _zero_grad_at_zero(ft.reduce(jnp.maximum, per_leaf_max))
    if leaves:
        dtype = complex_to_real_dtype(jnp.result_type(*leaves))
    else:
        dtype = default_floating_dtype()
    return jnp.array(0.0, dtype)
@jax.custom_jvp
def _zero_grad_at_zero(x):
    """Identity function whose derivative is forced to zero wherever `x == 0`."""
    return x


@_zero_grad_at_zero.defjvp
def _zero_grad_at_zero_jvp(primals, tangents):
    [value] = primals
    [tangent] = tangents
    at_zero = value == 0
    return value, jnp.where(at_zero, 0, tangent)
================================================
FILE: lineax/_operator.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import functools as ft
import math
import warnings
from collections.abc import Callable, Iterable
from typing import Any, Literal, NoReturn, TypeVar
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.flatten_util as jfu
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from equinox.internal import ω
from jaxtyping import (
Array,
ArrayLike,
Inexact,
PyTree, # pyright: ignore
Scalar,
Shaped,
)
from ._custom_types import sentinel
from ._misc import (
default_floating_dtype,
inexact_asarray,
jacobian,
strip_weak_dtype,
)
from ._tags import (
diagonal_tag,
lower_triangular_tag,
negative_semidefinite_tag,
positive_semidefinite_tag,
symmetric_tag,
transpose_tags,
tridiagonal_tag,
unit_diagonal_tag,
upper_triangular_tag,
)
def _frozenset(x: object | Iterable[object]) -> frozenset[object]:
try:
iter_x = iter(x) # pyright: ignore
except TypeError:
return frozenset([x])
else:
return frozenset(iter_x)
class AbstractLinearOperator(eqx.Module):
    """Abstract base class for all linear operators.

    Linear operators can act between PyTrees. Each `AbstractLinearOperator` is thought
    of as a linear function `X -> Y`, where each element of `X` is a PyTree of
    floating-point JAX arrays, and each element of `Y` is a PyTree of floating-point
    JAX arrays.

    Abstract linear operators support some operations:
    ```python
    op1 + op2  # addition of two operators
    op1 - op2  # subtraction of two operators
    op1 @ op2  # composition of two operators.
    op1 * 3.2  # multiplication by a scalar
    3.2 * op1  # multiplication by a scalar
    op1 / 3.2  # division by a scalar
    -op1  # negation
    ```
    """

    def __check_init__(self):
        # Operators tagged symmetric / positive semidefinite / negative semidefinite
        # must map a space to itself, so their structures have to agree.
        if (
            is_symmetric(self)
            or is_positive_semidefinite(self)
            or is_negative_semidefinite(self)
        ):
            # In particular, we check that dtypes match.
            in_structure = self.in_structure()
            out_structure = self.out_structure()
            # `is` check to handle the possibility of a tracer.
            if eqx.tree_equal(in_structure, out_structure) is not True:
                raise ValueError(
                    "Symmetric/Hermitian matrices must have matching input and output "
                    f"structures. Got input structure {in_structure} and output "
                    f"structure {out_structure}."
                )

    @abc.abstractmethod
    def mv(
        self, vector: PyTree[Inexact[Array, " _b"]]
    ) -> PyTree[Inexact[Array, " _a"]]:
        """Computes a matrix-vector product between this operator and a `vector`.

        **Arguments:**

        - `vector`: Should be some PyTree of floating-point arrays, whose structure
            should match `self.in_structure()`.

        **Returns:**

        A PyTree of floating-point arrays, with structure that matches
        `self.out_structure()`.
        """

    @abc.abstractmethod
    def as_matrix(self) -> Inexact[Array, "a b"]:
        """Materialises this linear operator as a matrix.

        Note that this can be a computationally (time and/or memory) expensive
        operation, as many linear operators are defined implicitly, e.g. in terms of
        their action on a vector.

        **Arguments:** None.

        **Returns:**

        A 2-dimensional floating-point JAX array.
        """

    @abc.abstractmethod
    def transpose(self) -> "AbstractLinearOperator":
        """Transposes this linear operator.

        This can be called as either `operator.T` or `operator.transpose()`.

        **Arguments:** None.

        **Returns:**

        Another [`lineax.AbstractLinearOperator`][].
        """

    @abc.abstractmethod
    def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        """Returns the expected input structure of this linear operator.

        **Arguments:** None.

        **Returns:**

        A PyTree of `jax.ShapeDtypeStruct`.
        """

    @abc.abstractmethod
    def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
        """Returns the expected output structure of this linear operator.

        **Arguments:** None.

        **Returns:**

        A PyTree of `jax.ShapeDtypeStruct`.
        """

    def in_size(self) -> int:
        """Returns the total number of scalars in the input of this linear operator.

        That is, the dimensionality of its input space.

        **Arguments:** None.

        **Returns:** An integer.
        """
        leaves = jtu.tree_leaves(self.in_structure())
        return sum(math.prod(leaf.shape) for leaf in leaves)  # pyright: ignore

    def out_size(self) -> int:
        """Returns the total number of scalars in the output of this linear operator.

        That is, the dimensionality of its output space.

        **Arguments:** None.

        **Returns:** An integer.
        """
        leaves = jtu.tree_leaves(self.out_structure())
        return sum(math.prod(leaf.shape) for leaf in leaves)  # pyright: ignore

    @property
    def T(self) -> "AbstractLinearOperator":
        """Equivalent to [`lineax.AbstractLinearOperator.transpose`][]"""
        return self.transpose()

    def __add__(self, other) -> "AbstractLinearOperator":
        if not isinstance(other, AbstractLinearOperator):
            raise ValueError("Can only add AbstractLinearOperators together.")
        return AddLinearOperator(self, other)

    def __sub__(self, other) -> "AbstractLinearOperator":
        # Implemented as `self + (-other)`.
        if not isinstance(other, AbstractLinearOperator):
            raise ValueError(
                "Can only subtract AbstractLinearOperators from one another."
            )
        return AddLinearOperator(self, -other)

    def __mul__(self, other) -> "AbstractLinearOperator":
        other = jnp.asarray(other)
        if other.shape != ():
            raise ValueError("Can only multiply AbstractLinearOperators by scalars.")
        return MulLinearOperator(self, other)

    def __rmul__(self, other) -> "AbstractLinearOperator":
        # Scalar multiplication commutes, so `scalar * op` is the same as `op * scalar`.
        return self * other

    def __matmul__(self, other) -> "AbstractLinearOperator":
        if not isinstance(other, AbstractLinearOperator):
            raise ValueError("Can only compose AbstractLinearOperators together.")
        return ComposedLinearOperator(self, other)

    def __truediv__(self, other) -> "AbstractLinearOperator":
        other = jnp.asarray(other)
        if other.shape != ():
            raise ValueError("Can only divide AbstractLinearOperators by scalars.")
        return DivLinearOperator(self, other)

    def __neg__(self) -> "AbstractLinearOperator":
        return NegLinearOperator(self)
class MatrixLinearOperator(AbstractLinearOperator):
    """Wraps a 2-dimensional JAX array into a linear operator.

    If the matrix has shape `(a, b)` then matrix-vector multiplication (`self.mv`) is
    defined in the usual way: as performing a matrix-vector product that accepts a
    vector of shape `(b,)` and returns a vector of shape `(a,)`.
    """

    matrix: Inexact[Array, "a b"]
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self, matrix: Shaped[Array, "a b"], tags: object | frozenset[object] = ()
    ):
        """**Arguments:**

        - `matrix`: a two-dimensional JAX array. For an array with shape `(a, b)` then
            this operator can perform matrix-vector products on a vector of shape
            `(b,)` to return a vector of shape `(a,)`. Matrices with a non-inexact
            dtype (e.g. integer or boolean) are converted to the default floating
            dtype.
        - `tags`: any tags indicating whether this matrix has any particular properties,
            like symmetry or positive-definite-ness. Note that these properties are
            unchecked and you may get incorrect values elsewhere if these tags are
            wrong.

        **Raises:**

        `ValueError` if `matrix` is not 2-dimensional.
        """
        if jnp.ndim(matrix) != 2:
            raise ValueError(
                "`MatrixLinearOperator(matrix=...)` should be 2-dimensional."
            )
        if not jnp.issubdtype(matrix.dtype, jnp.inexact):
            # Respect `jax_enable_x64` (float64 when enabled, else float32), in the
            # same way as the other operators, which convert via `inexact_asarray`.
            matrix = matrix.astype(default_floating_dtype())
        self.matrix = matrix
        self.tags = _frozenset(tags)

    def mv(self, vector):
        # `_try_sparse_materialise` is defined elsewhere in this module; when it
        # returns a different operator (presumably a structured one chosen from
        # `self.tags`), the product is delegated to that operator.
        maybe_sparse_op = _try_sparse_materialise(self)
        if maybe_sparse_op is not self:
            return maybe_sparse_op.mv(vector)
        return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST)

    def as_matrix(self):
        return self.matrix

    def transpose(self):
        if is_symmetric(self):
            return self
        return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))

    def in_structure(self):
        _, in_size = jnp.shape(self.matrix)
        return jax.ShapeDtypeStruct(shape=(in_size,), dtype=self.matrix.dtype)

    def out_structure(self):
        out_size, _ = jnp.shape(self.matrix)
        return jax.ShapeDtypeStruct(shape=(out_size,), dtype=self.matrix.dtype)
def _matmul(matrix: ArrayLike, vector: ArrayLike) -> Array:
    """Contract all dimensions of `vector` against the trailing dimensions of
    `matrix`.

    `matrix` has structure [leaf(out), leaf(in)], `vector` has structure [leaf(in)],
    and the result has structure [leaf(out)].
    """
    n_contracted = jnp.ndim(vector)
    return jnp.tensordot(
        matrix, vector, axes=n_contracted, precision=lax.Precision.HIGHEST
    )
def _tree_matmul(matrix: PyTree[ArrayLike], vector: PyTree[ArrayLike]) -> PyTree[Array]:
    # `matrix` has structure [tree(in), leaf(out), leaf(in)] and `vector` has
    # structure [tree(in), leaf(in)]; the result has structure [leaf(out)].
    # Each input leaf contributes one contraction, and the contributions are summed.
    matrix_leaves = jtu.tree_leaves(matrix)
    vector_leaves = jtu.tree_leaves(vector)
    assert len(matrix_leaves) == len(vector_leaves)
    return sum(_matmul(m, v) for m, v in zip(matrix_leaves, vector_leaves))
# Needed as static fields must be hashable and eq-able, and custom pytrees might e.g.
# define custom __eq__ methods. So structures stored in static fields are kept in
# flattened form: the `(leaves, treedef)` pair returned by `jtu.tree_flatten`.
_T = TypeVar("_T")
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]
def _inexact_structure_impl2(x):
    # Keep inexact (floating/complex) leaves as they are; cast anything else to the
    # default floating dtype.
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        return x.astype(default_floating_dtype())
    return x


def _inexact_structure_impl(x):
    return jtu.tree_map(_inexact_structure_impl2, x)


def _inexact_structure(x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.ShapeDtypeStruct]:
    """Return the structure `x` with every non-inexact dtype replaced by
    `default_floating_dtype()`, and with weak types stripped."""
    abstract = jax.eval_shape(_inexact_structure_impl, x)
    return strip_weak_dtype(abstract)
class _Leaf:  # not a pytree
    """Opaque box around `value`.

    The class is not registered as a pytree, so JAX tree utilities treat each
    instance as a single leaf instead of flattening into `value`.
    """

    def __init__(self, value):
        self.value = value
# The `{input,output}_structure`s have to be static because otherwise abstract
# evaluation rules will promote them to ShapedArrays.
class PyTreeLinearOperator(AbstractLinearOperator):
    """Represents a PyTree of floating-point JAX arrays as a linear operator.

    This is basically a generalisation of [`lineax.MatrixLinearOperator`][], from
    taking just a single array to take a PyTree-of-arrays. (And likewise from returning
    a single array to returning a PyTree-of-arrays.)

    Specifically, suppose we want this to be a linear operator `X -> Y`, for which
    elements of `X` are PyTrees with structure `T` whose `i`th leaf is a floating-point
    JAX array of shape `x_shape_i`, and elements of `Y` are PyTrees with structure `S`
    whose `j`th leaf is a floating-point JAX array of shape `y_shape_j`. Then the
    input PyTree should have structure `S`-compose-`T` (the output structure on the
    outside, the input structure on the inside), and its `(j, i)`-th leaf should be a
    floating-point JAX array of shape `(*y_shape_j, *x_shape_i)`.

    !!! Example

        ```python
        # Suppose `x` is a member of our input space, with the following pytree
        # structure:
        eqx.tree_pprint(x)  # [f32[5, 9], f32[3]]

        # Suppose `y` is a member of our output space, with the following pytree
        # structure:
        eqx.tree_pprint(y)  # {"a": f32[1, 2]}

        # then `pytree` should be a pytree with the following structure:
        eqx.tree_pprint(pytree)  # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}
        ```
    """

    pytree: PyTree[Inexact[Array, "..."]]
    output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)
    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)

    def __init__(
        self,
        pytree: PyTree[ArrayLike],
        output_structure: PyTree[jax.ShapeDtypeStruct],
        tags: object | frozenset[object] = (),
    ):
        """**Arguments:**

        - `pytree`: this should be a PyTree, with structure as specified in
            [`lineax.PyTreeLinearOperator`][].
        - `output_structure`: the structure of the output space. This should be a PyTree
            of `jax.ShapeDtypeStruct`s. (The structure of the input space is then
            automatically derived from the structure of `pytree`.)
        - `tags`: any tags indicating whether this operator has any particular
            properties, like symmetry or positive-definite-ness. Note that these
            properties are unchecked and you may get incorrect values elsewhere if these
            tags are wrong.

        **Raises:**

        `ValueError` in any of these cases:
        - `output_structure` has no leaves.
        - The leading dimensions of the leaves of `pytree` do not match
            `output_structure`.
        - The leaves of `pytree` do not all imply the same input structure.
        """
        output_structure = _inexact_structure(output_structure)
        self.pytree = jtu.tree_map(inexact_asarray, pytree)
        self.output_structure = jtu.tree_flatten(output_structure)
        self.tags = _frozenset(tags)

        # self.out_structure() has structure [tree(out)]
        # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]
        def get_structure(struct, subpytree):
            # subpytree has structure [tree(in), leaf(out), leaf(in)]
            def sub_get_structure(leaf):
                shape = jnp.shape(leaf)  # [leaf(out), leaf(in)]
                ndim = len(struct.shape)
                if shape[:ndim] != struct.shape:
                    raise ValueError(
                        "`pytree` and `output_structure` are not consistent"
                    )
                # The trailing dimensions give the shape of the input leaf.
                return jax.ShapeDtypeStruct(
                    shape=shape[ndim:], dtype=jnp.result_type(leaf)
                )

            # Box the result so that `tree_leaves` below yields exactly one input
            # structure per output leaf.
            return _Leaf(jtu.tree_map(sub_get_structure, subpytree))

        # With no output leaves there is nothing from which to infer the input
        # structure. `None` is rejected up front, because `tree_map` below would fail
        # on it with an unhelpful structure-mismatch error.
        if output_structure is None:
            raise ValueError("Cannot have trivial output_structure")
        input_structures = jtu.tree_map(get_structure, output_structure, self.pytree)
        input_structures = jtu.tree_leaves(input_structures)
        if len(input_structures) == 0:
            # E.g. `output_structure=()`; without this check, the indexing below would
            # raise an IndexError.
            raise ValueError("Cannot have trivial output_structure")
        input_structure = input_structures[0].value
        for val in input_structures[1:]:
            if eqx.tree_equal(input_structure, val.value) is not True:
                raise ValueError(
                    "`pytree` does not have a consistent `input_structure`"
                )
        self.input_structure = jtu.tree_flatten(input_structure)

    def mv(self, vector):
        # vector has structure [tree(in), leaf(in)]
        # self.out_structure() has structure [tree(out)]
        # self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]
        # return has structure [tree(out), leaf(out)]
        # `_try_sparse_materialise` is defined elsewhere in this module; if it returns
        # a different operator, the product is delegated to it.
        maybe_sparse_op = _try_sparse_materialise(self)
        if maybe_sparse_op is not self:
            return maybe_sparse_op.mv(vector)

        def matmul(_, matrix):
            return _tree_matmul(matrix, vector)

        return jtu.tree_map(matmul, self.out_structure(), self.pytree)

    def as_matrix(self):
        with jax.numpy_dtype_promotion("standard"):
            dtype = jnp.result_type(*jtu.tree_leaves(self.pytree))

        def concat_in(struct, subpytree):
            # Flatten each block to 2D (output size, input-leaf size) and join the
            # blocks for different input leaves side by side.
            leaves = jtu.tree_leaves(subpytree)
            assert all(leaf.shape[: struct.ndim] == struct.shape for leaf in leaves)
            leaves = [
                leaf.astype(dtype).reshape(
                    struct.size, math.prod(leaf.shape[struct.ndim :])
                )
                for leaf in leaves
            ]
            return jnp.concatenate(leaves, axis=1)

        matrix = jtu.tree_map(concat_in, self.out_structure(), self.pytree)
        matrix = jtu.tree_leaves(matrix)
        # Stack the row-blocks for the different output leaves vertically.
        return jnp.concatenate(matrix, axis=0)

    def transpose(self):
        if is_symmetric(self):
            return self

        def _transpose(struct, subtree):
            # Move the leading output dimensions of each leaf to the end, giving
            # leaves of shape (*leaf(in), *leaf(out)).
            def _transpose_impl(leaf):
                return jnp.moveaxis(leaf, source, dest)

            source = list(range(struct.ndim))
            dest = list(range(-struct.ndim, 0))
            return jtu.tree_map(_transpose_impl, subtree)

        pytree_transpose = jtu.tree_map(_transpose, self.out_structure(), self.pytree)
        # Swap the nesting: [tree(out), tree(in)] -> [tree(in), tree(out)].
        pytree_transpose = jtu.tree_transpose(
            jtu.tree_structure(self.out_structure()),
            jtu.tree_structure(self.in_structure()),
            pytree_transpose,
        )
        return PyTreeLinearOperator(
            pytree_transpose, self.in_structure(), transpose_tags(self.tags)
        )

    def in_structure(self):
        leaves, treedef = self.input_structure
        return jtu.tree_unflatten(treedef, leaves)

    def out_structure(self):
        leaves, treedef = self.output_structure
        return jtu.tree_unflatten(treedef, leaves)
class DiagonalLinearOperator(AbstractLinearOperator):
    """A diagonal linear operator, e.g. for a diagonal matrix. Only the diagonal is
    stored (for memory efficiency). Matrix-vector products are computed by doing a
    pointwise diagonal * vector, rather than a full matrix @ vector (for speed).

    The diagonal may also be a PyTree, rather than a 1D array. When materialising the
    matrix, the diagonal is taken to be defined by the flattened PyTree (i.e. values
    show up in the same order.)
    """

    diagonal: PyTree[Inexact[Array, "..."]]

    def __init__(self, diagonal: PyTree[ArrayLike]):
        """**Arguments:**

        - `diagonal`: an array or PyTree defining the diagonal of the matrix. Leaves
            without an inexact dtype are converted to the default floating dtype.
        """
        converted = jtu.tree_map(inexact_asarray, diagonal)
        self.diagonal = converted

    def mv(self, vector):
        # Leafwise elementwise product `diagonal * vector`, combined leaf by leaf
        # through `ω`.
        product = ω(self.diagonal) * ω(vector)
        return product.ω

    def as_matrix(self):
        # `diagonal` here is the module-level function of that name (not the field),
        # applied to this operator; its output is placed on the diagonal of a dense
        # matrix.
        return jnp.diag(diagonal(self))

    def transpose(self):
        # A diagonal operator is its own transpose.
        return self

    def in_structure(self):
        # The input space has the same structure as the stored diagonal.
        return jax.eval_shape(lambda: self.diagonal)

    def out_structure(self):
        # ...and so does the output space.
        return jax.eval_shape(lambda: self.diagonal)
class _NoAuxIn(eqx.Module):
    """Callable that binds `args` as the second argument of `fn`, turning
    `fn(x, args)` into a function of `x` alone."""

    fn: Callable
    args: Any

    def __call__(self, x):
        fn, args = self.fn, self.args
        return fn(x, args)
class _Unwrap(eqx.Module):
    """Callable wrapping `fn`, which must return a length-1 sequence; returns that
    sequence's only element."""

    fn: Callable

    def __call__(self, x):
        [only] = self.fn(x)
        return only
class JacobianLinearOperator(AbstractLinearOperator):
    """Given a function `fn: X -> Y`, and a point `x in X`, then this defines the
    linear operator (also a function `X -> Y`) given by the Jacobian `(d(fn)/dx)(x)`.

    For example if the inputs and outputs are just arrays, then this is equivalent to
    `MatrixLinearOperator(jax.jacfwd(fn)(x))`.

    The Jacobian is not materialised; matrix-vector products, which are in fact
    Jacobian-vector products, are computed using autodifferentiation. By default
    (or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
    `jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with
    `jax.linear_transpose`, which works even with functions
    that only define a custom VJP (via `jax.custom_vjp`) and don't support
    forward-mode differentiation.

    See also [`lineax.materialise`][], which materialises the whole Jacobian in
    memory.

    !!! tip

        For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
        the primal computation, e.g. for `jac="fwd"/None` it returns
        `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
    """

    # The function `(x, args) -> y`, after optional closure conversion in `__init__`.
    fn: Callable[
        [PyTree[Inexact[Array, "..."]], PyTree[Any]], PyTree[Inexact[Array, "..."]]
    ]
    # Point at which the Jacobian is taken; leaves converted to inexact dtypes.
    x: PyTree[Inexact[Array, "..."]]
    # Extra arguments passed to `fn`; not differentiated.
    args: PyTree[Any]
    tags: frozenset[object] = eqx.field(static=True)
    # `"fwd"`, `"bwd"` or `None`; `mv` treats `None` the same as `"fwd"`.
    jac: Literal["fwd", "bwd"] | None

    @eqxi.doc_remove_args("closure_convert")
    def __init__(
        self,
        fn: Callable,
        x: PyTree[ArrayLike],
        args: PyTree[Any] = None,
        tags: object | Iterable[object] = (),
        jac: Literal["fwd", "bwd"] | None = None,
        closure_convert: bool = True,
    ):
        """**Arguments:**

        - `fn`: A function `(x, args) -> y`. The Jacobian `d(fn)/dx` is used as the
            linear operator, and `args` are just any other arguments that should not be
            differentiated.
        - `x`: The point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`.
        - `args`: As `x`; this is the point to evaluate `d(fn)/dx` at:
            `(d(fn)/dx)(x, args)`.
        - `tags`: any tags indicating whether this operator has any particular
            properties, like symmetry or positive-definite-ness. Note that these
            properties are unchecked and you may get incorrect values elsewhere if these
            tags are wrong.
        - `jac`: selects the Jacobian computation method. If `jac=fwd`, `jax.jacfwd`
            is used; similarly `jac=bwd` mandates the use of `jax.jacrev`. Otherwise, if
            not specified, it will be chosen by default according to input and output
            shape. (For `mv` itself: `"fwd"`/`None` uses `jax.jvp`, while `"bwd"` uses
            `jax.vjp`, transposed with `jax.linear_transpose` unless the operator is
            tagged symmetric.)

        **Raises:**

        `ValueError` if `jac` is not one of `None`, `"fwd"` or `"bwd"`.
        """
        if jac not in [None, "fwd", "bwd"]:
            raise ValueError(
                "`jac` argument of `JacobianLinearOperator` should be either "
                "`'fwd'`, `'bwd'`, or `None`."
            )
        # Flush out any closed-over values, so that we can safely pass `self`
        # across API boundaries. (In particular, across `linear_solve_p`.)
        # We don't use `jax.closure_convert` as that only flushes autodiffable
        # (=floating-point) constants. It probably doesn't matter, but if `fn` is a
        # PyTree capturing non-floating-point constants, we should probably continue
        # to respect that, and keep any non-floating-point constants as part of the
        # PyTree structure.
        x = jtu.tree_map(inexact_asarray, x)
        if closure_convert:
            fn = eqx.filter_closure_convert(fn, x, args)
        self.fn = fn
        self.x = x
        self.args = args
        self.tags = _frozenset(tags)
        self.jac = jac

    def mv(self, vector):
        # Bind `args` so that `fn` is a function of `x` alone.
        fn = _NoAuxIn(self.fn, self.args)
        if self.jac == "fwd" or self.jac is None:
            # Forward mode: one Jacobian-vector product at `self.x`.
            _, out = jax.jvp(fn, (self.x,), (vector,))
        elif self.jac == "bwd":
            # Use VJP + linear_transpose instead of materializing full Jacobian.
            # This works even for custom_vjp functions that don't have JVP rules.
            _, vjp_fn = jax.vjp(fn, self.x)
            if is_symmetric(self):
                # For symmetric operators, J = J.T, so vjp directly gives J @ v
                (out,) = vjp_fn(vector)
            else:
                # For non-symmetric, transpose the VJP to get J @ v from J.T @ v
                transpose_vjp = jax.linear_transpose(
                    lambda g: vjp_fn(g)[0], self.out_structure()
                )
                (out,) = transpose_vjp(vector)
        else:
            # Unreachable in practice: `__init__` validates `jac`.
            raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
        return out

    def as_matrix(self):
        return materialise(self).as_matrix()

    def transpose(self):
        if is_symmetric(self):
            return self
        fn = _NoAuxIn(self.fn, self.args)
        # The transposed Jacobian at `self.x` is the VJP function.
        # Works because vjpfn is a PyTree
        _, vjpfn = jax.vjp(fn, self.x)
        # `jax.vjp`'s function returns a 1-tuple; unwrap it to return the cotangent.
        vjpfn = _Unwrap(vjpfn)
        return FunctionLinearOperator(
            vjpfn, self.out_structure(), transpose_tags(self.tags)
        )

    def in_structure(self):
        return strip_weak_dtype(jax.eval_shape(lambda: self.x))

    def out_structure(self):
        # Abstractly evaluate `fn` at `self.x` (cached) to obtain the output structure.
        fn = _NoAuxIn(self.fn, self.args)
        return strip_weak_dtype(eqxi.cached_filter_eval_shape(fn, self.x))
# `input_structure` must be static as with `PyTreeLinearOperator`
class FunctionLinearOperator(AbstractLinearOperator):
    """Wraps a *linear* function `fn: X -> Y` into a linear operator. (So that
    `self.mv(x)` is defined by `self.mv(x) == fn(x)`.)

    See also [`lineax.materialise`][], which materialises the whole linear operator
    in memory. (Similar to `.as_matrix()`.)
    """

    fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]]
    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    @eqxi.doc_remove_args("closure_convert")
    def __init__(
        self,
        fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]],
        input_structure: PyTree[jax.ShapeDtypeStruct],
        tags: object | Iterable[object] = (),
        closure_convert: bool = True,
    ):
        """**Arguments:**

        - `fn`: a linear function. Should accept a PyTree of floating-point JAX arrays,
            and return a PyTree of floating-point JAX arrays.
        - `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
            structure of the input to the function. (When later calling `self.mv(x)`
            then this should match the structure of `x`, i.e.
            `jax.eval_shape(lambda: x)`.) Non-inexact dtypes are replaced by the
            default floating dtype.
        - `tags`: any tags indicating whether this operator has any particular
            properties, like symmetry or positive-definite-ness. Note that these
            properties are unchecked and you may get incorrect values elsewhere if these
            tags are wrong.
        """
        # Normalise dtypes first; closure conversion below then uses the normalised
        # structure, which is also the one stored on `self`.
        structure = _inexact_structure(input_structure)
        if closure_convert:
            # As in `JacobianLinearOperator`: flush closed-over values out of `fn`.
            fn = eqx.filter_closure_convert(fn, structure)
        self.fn = fn
        self.input_structure = jtu.tree_flatten(structure)
        self.tags = _frozenset(tags)

    def mv(self, vector):
        # The operator *is* the function.
        return self.fn(vector)

    def as_matrix(self):
        return materialise(self).as_matrix()

    def transpose(self):
        if is_symmetric(self):
            return self
        linear_transposed = jax.linear_transpose(self.fn, self.in_structure())

        def _transposed_mv(vector):
            # The transposed function returns a 1-tuple; unwrap its single entry.
            (result,) = linear_transposed(vector)
            return result

        # Works because `linear_transposed` is a PyTree.
        return FunctionLinearOperator(
            _transposed_mv, self.out_structure(), transpose_tags(self.tags)
        )

    def in_structure(self):
        flat_leaves, structure_def = self.input_structure
        return jtu.tree_unflatten(structure_def, flat_leaves)

    def out_structure(self):
        # Abstractly evaluate `fn` on the input structure (cached), then drop weak types.
        abstract_out = eqxi.cached_filter_eval_shape(self.fn, self.in_structure())
        return strip_weak_dtype(abstract_out)
# `input_structure` and `output_structure` must be static as with `PyTreeLinearOperator`
class IdentityLinearOperator(AbstractLinearOperator):
    """Represents the identity transformation `X -> X`, where each `x in X` is some
    PyTree of floating-point JAX arrays.
    """

    # Flattened `(leaves, treedef)` forms of the input and output structures, stored
    # as static fields.
    input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
    output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)

    def __init__(
        self,
        input_structure: PyTree[jax.ShapeDtypeStruct],
        output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel,
    ):
        """**Arguments:**

        - `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
            structure of the input space. (When later calling `self.mv(x)`
            then this should match the structure of `x`, i.e.
            `jax.eval_shape(lambda: x)`.)
        - `output_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
            structure of the output space. If not passed then this defaults to the
            same as `input_structure`. If passed then it must have the same number of
            elements as `input_structure`, so that the operator is square.
        """
        if output_structure is sentinel:
            output_structure = input_structure
        # Replace non-inexact dtypes by the default floating dtype and strip weak types.
        input_structure = _inexact_structure(input_structure)
        output_structure = _inexact_structure(output_structure)
        self.input_structure = jtu.tree_flatten(input_structure)
        self.output_structure = jtu.tree_flatten(output_structure)

    def mv(self, vector):
        # The input must match `self.in_structure()` (weak types ignored).
        if not eqx.tree_equal(
            strip_weak_dtype(jax.eval_shape(lambda: vector)),
            strip_weak_dtype(self.in_structure()),
        ):
            raise ValueError("Vector and operator structures do not match")
        elif self.input_structure == self.output_structure:
            return vector  # fast-path for common special case
        else:
            # General case: ravel all input leaves into a single 1-D array, zero-pad
            # or truncate it to the output size, then split and reshape it into the
            # output structure, casting to each output leaf's dtype.
            # NOTE(review): the pad/truncate means input/output sizes may differ here,
            # although `__init__`'s docstring asks for matching sizes; confirm which
            # contract is intended.
            # NOTE(review): `jnp.result_type()` presumably raises when called with no
            # arguments, so an input structure with no leaves would fail in this
            # branch; confirm.
            # TODO(kidger): this could be done slightly more efficiently, by iterating
            # leaf-by-leaf.
            leaves = jtu.tree_leaves(vector)
            with jax.numpy_dtype_promotion("standard"):
                dtype = jnp.result_type(*leaves)
            vector = jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])
            out_size = self.out_size()
            if vector.size < out_size:
                vector = jnp.concatenate(
                    [vector, jnp.zeros(out_size - vector.size, vector.dtype)]
                )
            else:
                vector = vector[:out_size]
            leaves, treedef = jtu.tree_flatten(self.out_structure())
            # Split points between consecutive output leaves.
            sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])
            split = jnp.split(vector, sizes)
            assert len(split) == len(leaves)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
                shaped = [
                    x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)
                ]
            return jtu.tree_unflatten(treedef, shaped)

    def as_matrix(self):
        # The dtype comes from the input structure only (default floating dtype when
        # it has no leaves); the result is a possibly rectangular `eye`.
        leaves = jtu.tree_leaves(self.in_structure())
        with jax.numpy_dtype_promotion("standard"):
            dtype = (
                default_floating_dtype()
                if len(leaves) == 0
                else jnp.result_type(*leaves)
            )
        return jnp.eye(self.out_size(), self.in_size(), dtype=dtype)

    def transpose(self):
        # Transposing swaps the input and output structures.
        return IdentityLinearOperator(self.out_structure(), self.in_structure())

    def in_structure(self):
        leaves, treedef = self.input_structure
        return jtu.tree_unflatten(treedef, leaves)

    def out_structure(self):
        leaves, treedef = self.output_structure
        return jtu.tree_unflatten(treedef, leaves)

    @property
    def tags(self):
        # This operator carries no tags.
        return frozenset()
class TridiagonalLinearOperator(AbstractLinearOperator):
    """As [`lineax.MatrixLinearOperator`][], but for specifically a tridiagonal
    matrix.
    """

    diagonal: Inexact[Array, " size"]
    lower_diagonal: Inexact[Array, " size-1"]
    upper_diagonal: Inexact[Array, " size-1"]

    def __init__(
        self,
        diagonal: Inexact[Array, " size"],
        lower_diagonal: Inexact[Array, " size-1"],
        upper_diagonal: Inexact[Array, " size-1"],
    ):
        """**Arguments:**

        - `diagonal`: A rank-one JAX array. This is the diagonal of the matrix.
        - `lower_diagonal`: A rank-one JAX array. This is the lower diagonal of the
            matrix.
        - `upper_diagonal`: A rank-one JAX array. This is the upper diagonal of the
            matrix.

        If `diagonal` has shape `(a,)` then `lower_diagonal` and `upper_diagonal` should
        both have shape `(a - 1,)`. All three are converted to inexact dtypes;
        `ValueError` is raised if the shapes are inconsistent.
        """
        self.diagonal = inexact_asarray(diagonal)
        self.lower_diagonal = inexact_asarray(lower_diagonal)
        self.upper_diagonal = inexact_asarray(upper_diagonal)
        (size,) = self.diagonal.shape
        off_diagonals = (
            ("lower_diagonal", self.lower_diagonal),
            ("upper_diagonal", self.upper_diagonal),
        )
        for name, band in off_diagonals:
            if band.shape != (size - 1,):
                raise ValueError(f"{name} and diagonal do not have consistent size")

    def mv(self, vector):
        # Row i of the product is
        #   lower[i-1] * v[i-1] + diag[i] * v[i] + upper[i] * v[i+1].
        from_above = self.upper_diagonal * vector[1:]
        main = self.diagonal * vector
        from_below = self.lower_diagonal * vector[:-1]
        return main.at[:-1].add(from_above).at[1:].add(from_below)

    def as_matrix(self):
        (n,) = jnp.shape(self.diagonal)
        idx = np.arange(n)
        dense = jnp.zeros((n, n), self.diagonal.dtype)
        dense = dense.at[idx, idx].set(self.diagonal)
        dense = dense.at[idx[1:], idx[:-1]].set(self.lower_diagonal)
        return dense.at[idx[:-1], idx[1:]].set(self.upper_diagonal)

    def transpose(self):
        # Transposing swaps the lower and upper bands.
        return TridiagonalLinearOperator(
            self.diagonal, self.upper_diagonal, self.lower_diagonal
        )

    def in_structure(self):
        (n,) = jnp.shape(self.diagonal)
        return jax.ShapeDtypeStruct(shape=(n,), dtype=self.diagonal.dtype)

    def out_structure(self):
        (n,) = jnp.shape(self.diagonal)
        return jax.ShapeDtypeStruct(shape=(n,), dtype=self.diagonal.dtype)
class TaggedLinearOperator(AbstractLinearOperator):
    """Wraps another linear operator and specifies that it has certain tags, e.g.
    representing symmetry.

    !!! Example

        ```python
        # Some other operator.
        operator = lx.MatrixLinearOperator(some_jax_array)

        # Now symmetric! But the type system doesn't know this.
        sym_operator = operator + operator.T
        assert lx.is_symmetric(sym_operator) == False

        # We can declare that our operator has a particular property.
        sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
        assert lx.is_symmetric(sym_operator) == True
        ```
    """

    operator: AbstractLinearOperator
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self, operator: AbstractLinearOperator, tags: object | Iterable[object]
    ):
        """**Arguments:**

        - `operator`: some other linear operator to wrap.
        - `tags`: any tags indicating whether this operator has any particular
            properties, like symmetry or positive-definite-ness. Note that these
            properties are unchecked and you may get incorrect values elsewhere if these
            tags are wrong.
        """
        self.operator = operator
        self.tags = _frozenset(tags)

    # Apart from `transpose` (which also transposes the tags), every method simply
    # forwards to the wrapped operator.

    def mv(self, vector):
        inner = self.operator
        return inner.mv(vector)

    def as_matrix(self):
        inner = self.operator
        return inner.as_matrix()

    def transpose(self):
        transposed = self.operator.transpose()
        return TaggedLinearOperator(transposed, transpose_tags(self.tags))

    def in_structure(self):
        inner = self.operator
        return inner.in_structure()

    def out_structure(self):
        inner = self.operator
        return inner.out_structure()
#
# All operators below here are private to lineax.
#
def _is_none(x):
    # `is_leaf` predicate: makes `jtu.tree_map` hand `None` entries to the mapped
    # function as leaves rather than treating them as empty subtrees.
    return x is None


def _tangent_jvp_materialised(fn, primal, tangent):
    # Runs `fn(primal)` under a forward-mode JVP in direction `tangent`, then pairs
    # each primal output with its tangent output through `eqxi.materialise_zeros`
    # (presumably turning symbolic-zero `None` tangents into explicit zeros).
    primal_out, tangent_out = eqx.filter_jvp(fn, (primal,), (tangent,))
    return jtu.tree_map(
        eqxi.materialise_zeros, primal_out, tangent_out, is_leaf=_is_none
    )


class TangentLinearOperator(AbstractLinearOperator):
    """Internal to lineax. Used to represent the tangent (jvp) computation with
    respect to the linear operator in a linear solve.

    `mv` and `as_matrix` differentiate the corresponding method of `primal` in the
    direction `tangent`.
    """

    primal: AbstractLinearOperator
    tangent: AbstractLinearOperator

    def __check_init__(self):
        assert type(self.primal) is type(self.tangent)  # noqa: E721

    def mv(self, vector):
        return _tangent_jvp_materialised(
            lambda op: op.mv(vector), self.primal, self.tangent
        )

    def as_matrix(self):
        return _tangent_jvp_materialised(
            lambda op: op.as_matrix(), self.primal, self.tangent
        )

    def transpose(self):
        # Transposition is linear, so the JVP of `transpose` gives the transposed
        # primal together with the transposed tangent.
        primal_t, tangent_t = eqx.filter_jvp(
            lambda op: op.transpose(), (self.primal,), (self.tangent,)
        )
        return TangentLinearOperator(primal_t, tangent_t)

    def in_structure(self):
        return self.primal.in_structure()

    def out_structure(self):
        return self.primal.out_structure()
class AddLinearOperator(AbstractLinearOperator):
    """A linear operator formed by adding two other linear operators together.

    !!! Example

        ```python
        x = MatrixLinearOperator(...)
        y = MatrixLinearOperator(...)
        assert isinstance(x + y, AddLinearOperator)
        ```
    """

    operator1: AbstractLinearOperator
    operator2: AbstractLinearOperator

    def __check_init__(self):
        # Both summands must consume and produce identically-structured vectors.
        left, right = self.operator1, self.operator2
        if (
            left.in_structure() != right.in_structure()
            or left.out_structure() != right.out_structure()
        ):
            raise ValueError("Incompatible linear operator structures")

    def mv(self, vector):
        # Use a compact (tri)diagonal form when `_try_sparse_materialise` finds one.
        sparse = _try_sparse_materialise(self)
        if sparse is not self:
            return sparse.mv(vector)
        out1 = self.operator1.mv(vector)
        out2 = self.operator2.mv(vector)
        return (out1**ω + out2**ω).ω

    def as_matrix(self):
        matrix1 = self.operator1.as_matrix()
        matrix2 = self.operator2.as_matrix()
        return matrix1 + matrix2

    def transpose(self):
        # (A + B)^T = A^T + B^T
        transposed1 = self.operator1.transpose()
        transposed2 = self.operator2.transpose()
        return transposed1 + transposed2

    def in_structure(self):
        return self.operator1.in_structure()

    def out_structure(self):
        return self.operator1.out_structure()
class MulLinearOperator(AbstractLinearOperator):
    """A linear operator formed by multiplying a linear operator by a scalar.

    !!! Example

        ```python
        x = MatrixLinearOperator(...)
        y = 0.5
        assert isinstance(x * y, MulLinearOperator)
        ```
    """

    operator: AbstractLinearOperator
    scalar: Scalar

    def mv(self, vector):
        product = self.operator.mv(vector) ** ω * self.scalar
        return product.ω

    def as_matrix(self):
        matrix = self.operator.as_matrix()
        return matrix * self.scalar

    def transpose(self):
        # (cA)^T = c A^T
        transposed = self.operator.transpose()
        return transposed * self.scalar

    def in_structure(self):
        return self.operator.in_structure()

    def out_structure(self):
        return self.operator.out_structure()
# Not just `MulLinearOperator(..., -1)` for compatibility with
# `jax_numpy_dtype_promotion=strict`.
class NegLinearOperator(AbstractLinearOperator):
    """A linear operator formed by computing the negative of a linear operator.

    !!! Example

        ```python
        x = MatrixLinearOperator(...)
        assert isinstance(-x, NegLinearOperator)
        ```
    """

    operator: AbstractLinearOperator

    def mv(self, vector):
        negated = -(self.operator.mv(vector) ** ω)
        return negated.ω

    def as_matrix(self):
        matrix = self.operator.as_matrix()
        return -matrix

    def transpose(self):
        # (-A)^T = -(A^T)
        transposed = self.operator.transpose()
        return -transposed

    def in_structure(self):
        return self.operator.in_structure()

    def out_structure(self):
        return self.operator.out_structure()
class DivLinearOperator(AbstractLinearOperator):
    """A linear operator formed by dividing a linear operator by a scalar.

    !!! Example

        ```python
        x = MatrixLinearOperator(...)
        y = 0.5
        assert isinstance(x / y, DivLinearOperator)
        ```
    """

    operator: AbstractLinearOperator
    scalar: Scalar

    def mv(self, vector):
        # Division runs under "standard" dtype promotion, presumably so that it
        # still works when the global promotion mode is "strict".
        with jax.numpy_dtype_promotion("standard"):
            quotient = self.operator.mv(vector) ** ω / self.scalar
            return quotient.ω

    def as_matrix(self):
        matrix = self.operator.as_matrix()
        return matrix / self.scalar

    def transpose(self):
        # (A / c)^T = A^T / c
        transposed = self.operator.transpose()
        return transposed / self.scalar

    def in_structure(self):
        return self.operator.in_structure()

    def out_structure(self):
        return self.operator.out_structure()
class ComposedLinearOperator(AbstractLinearOperator):
    """A linear operator formed by composing (matrix-multiplying) two other linear
    operators together.

    !!! Example

        ```python
        x = MatrixLinearOperator(matrix1)
        y = MatrixLinearOperator(matrix2)
        composed = x @ y
        assert isinstance(composed, ComposedLinearOperator)
        assert jnp.allclose(composed.as_matrix(), matrix1 @ matrix2)
        ```
    """

    operator1: AbstractLinearOperator
    operator2: AbstractLinearOperator

    def __check_init__(self):
        # `operator2`'s output is fed into `operator1`, so the structures must match.
        if self.operator1.in_structure() != self.operator2.out_structure():
            raise ValueError("Incompatible linear operator structures")

    def mv(self, vector):
        # Use a compact (tri)diagonal form when `_try_sparse_materialise` finds one.
        sparse = _try_sparse_materialise(self)
        if sparse is not self:
            return sparse.mv(vector)
        intermediate = self.operator2.mv(vector)
        return self.operator1.mv(intermediate)

    def as_matrix(self):
        outer, inner = self.operator1, self.operator2
        # Composition with the identity needs no product.
        if isinstance(outer, IdentityLinearOperator):
            return inner.as_matrix()
        if isinstance(inner, IdentityLinearOperator):
            return outer.as_matrix()
        # Apply `outer` to each column of `inner`'s matrix, rather than materialising
        # `outer` as well.
        _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, outer.in_structure())

        def apply_outer(column):
            flat_out, _ = jfu.ravel_pytree(outer.mv(unravel(column)))
            return flat_out

        return jax.vmap(apply_outer, in_axes=1, out_axes=1)(inner.as_matrix())

    def transpose(self):
        # (AB)^T = B^T A^T
        return self.operator2.transpose() @ self.operator1.transpose()

    def in_structure(self):
        return self.operator2.in_structure()

    def out_structure(self):
        return self.operator1.out_structure()
#
# Operations on `AbstractLinearOperator`s.
# These are done through `singledispatch` rather than as methods.
#
# If an end user ever wanted to add something analogous to
# `diagonal: AbstractLinearOperator -> Array`
# then of course they don't get to edit our base class and add overloads to all
# subclasses.
# They'd have to use `singledispatch` to get the desired behaviour. (Or maybe just
# hardcode compatibility with only some `AbstractLinearOperator` subclasses, eurgh.)
# So for consistency we do the same thing here, rather than adding privileged behaviour
# for just the operations we happen to support.
#
# (Something something Julia something something orphan problem etc.)
#
def _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoReturn:
    # Fallback body of every `singledispatch` operation in this file.
    # - A missing overload on one of Lineax's own operators is an internal bug,
    #   hence `assert False`.
    # - On a third-party operator it is a genuine `NotImplementedError`.
    msg = f"`lineax.{name}` has not been implemented for {type(operator)}"
    is_lineax_operator = type(operator).__module__.startswith("lineax")
    if not is_lineax_operator:
        raise NotImplementedError(msg)
    assert False, msg + ". Please file a bug against Lineax."
# linearise
@ft.singledispatch
def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
    """Linearises a linear operator. This returns another linear operator.

    Mathematically speaking this is just the identity function. And indeed most linear
    operators will be returned unchanged.

    For [`lineax.JacobianLinearOperator`][] specifically, this will cache the primal
    pass, so that it does not need to be recomputed each time. That is, it uses some
    memory to improve speed. (This is precisely the same distinction as `jax.jvp`
    versus `jax.linearize`.)

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Another linear operator. Mathematically it performs matrix-vector products
    (`operator.mv`) that produce the same results as the input `operator`.
    """
    # Fallback for operator types with no registered overload.
    _default_not_implemented("linearise", operator)
# Linearisation is the identity on these operator types: they are returned as-is.
@linearise.register(TridiagonalLinearOperator)
@linearise.register(DiagonalLinearOperator)
@linearise.register(IdentityLinearOperator)
@linearise.register(FunctionLinearOperator)
@linearise.register(PyTreeLinearOperator)
@linearise.register(MatrixLinearOperator)
def _(operator):
    return operator
@linearise.register(JacobianLinearOperator)
def _(operator):
    # Replace the Jacobian operator with a FunctionLinearOperator whose function
    # reuses a single primal evaluation of `fn` at `operator.x`.
    # `_NoAuxIn` (defined elsewhere) presumably binds `operator.args` so that `fn`
    # is a function of `x` alone.
    fn = _NoAuxIn(operator.fn, operator.args)
    if operator.jac == "bwd":
        # For backward mode, use VJP + linear_transpose.
        # This works even with custom_vjp functions that don't support forward-mode AD.
        _, vjp_fn = jax.vjp(fn, operator.x)
        if is_symmetric(operator):
            # For symmetric: J = J.T, so vjp directly gives J @ v
            # (`_Unwrap` presumably strips the 1-tuple of cotangents `vjp_fn` returns.)
            lin = _Unwrap(vjp_fn)
        else:
            # Transpose the VJP to get J @ v from J.T @ v
            # (`jax.linear_transpose` also returns a tuple, one entry per primal.)
            lin = _Unwrap(
                jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())
            )
    else:  # "fwd" or None
        # `jax.linearize` evaluates the primal once and returns the JVP map v -> J @ v.
        _, lin = jax.linearize(fn, operator.x)
    return FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
# materialise
@ft.singledispatch
def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
    """Materialises a linear operator. This returns another linear operator.

    Mathematically speaking this is just the identity function. And indeed most linear
    operators will be returned unchanged.

    For [`lineax.JacobianLinearOperator`][] and [`lineax.FunctionLinearOperator`][]
    specifically, the linear operator is materialised in memory. That is, it becomes
    defined as a matrix (or pytree of arrays), rather than being defined only through
    its matrix-vector product ([`lineax.AbstractLinearOperator.mv`][]).

    Materialisation sometimes improves compile time or run time. It usually increases
    memory usage.

    For example:
    ```python
    large_function = ...
    operator = lx.FunctionLinearOperator(large_function, ...)

    # Option 1
    out1 = operator.mv(vector1)  # Traces and compiles `large_function`
    out2 = operator.mv(vector2)  # Traces and compiles `large_function` again!
    out3 = operator.mv(vector3)  # Traces and compiles `large_function` a third time!
    # All that compilation might lead to long compile times.
    # If `large_function` takes a long time to run, then this might also lead to long
    # run times.

    # Option 2
    operator = lx.materialise(operator)  # Traces and compiles `large_function` and
                                         # stores the result as a matrix.
    out1 = operator.mv(vector1)  # Each of these just computes a matrix-vector product
    out2 = operator.mv(vector2)  # against the stored matrix.
    out3 = operator.mv(vector3)  #
    # Now, `large_function` is only compiled once, and only run once.
    # However, storing the matrix might take a lot of memory, and the initial
    # computation may-or-may-not take a long time to run.
    ```

    Generally speaking it is worth first setting up your problem without
    `lx.materialise`, and using it as an optional optimisation if you find that it
    helps your particular problem.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Another linear operator. Mathematically it performs matrix-vector products
    (`operator.mv`) that produce the same results as the input `operator`.
    """
    # Fallback for operator types with no registered overload.
    _default_not_implemented("materialise", operator)
def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
    """Try to convert `operator` into a compact (tri)diagonal representation.

    Returns one of:
    - a `DiagonalLinearOperator`, if `is_diagonal(operator)`;
    - a `TridiagonalLinearOperator`, if `is_tridiagonal(operator)` and both the input
      and output structures are single arrays;
    - otherwise `operator` itself. Callers detect "no conversion" with an `is` check.

    The diagonal case rebuilds the input pytree structure for the stored diagonal.
    """
    if is_diagonal(operator):
        flat_diag = diagonal(operator)
        _, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
        return DiagonalLinearOperator(unravel(flat_diag))
    if not is_tridiagonal(operator):
        return operator
    # `TridiagonalLinearOperator` stores flat vectors, so it can only stand in for
    # operators whose input and output are each a single array.
    if isinstance(operator.in_structure(), jax.ShapeDtypeStruct) and isinstance(
        operator.out_structure(), jax.ShapeDtypeStruct
    ):
        return TridiagonalLinearOperator(*tridiagonal(operator))
    return operator
# Already-structured operators are their own materialisation.
@materialise.register(TridiagonalLinearOperator)
@materialise.register(DiagonalLinearOperator)
@materialise.register(IdentityLinearOperator)
def _(operator):
    return operator


# Dense operators are already in memory. They may still be swapped for a compact
# (tri)diagonal operator when `is_diagonal` / `is_tridiagonal` says so.
@materialise.register(PyTreeLinearOperator)
@materialise.register(MatrixLinearOperator)
def _(operator):
    return _try_sparse_materialise(operator)
@materialise.register(JacobianLinearOperator)
def _(operator):
    sparse = _try_sparse_materialise(operator)
    if sparse is not operator:
        return sparse
    # Complex-valued primal inputs are differentiated holomorphically.
    has_complex_input = any(
        jnp.iscomplexobj(leaf) for leaf in jtu.tree_leaves(operator.x)
    )
    jacobian_fn = jacobian(
        _NoAuxIn(operator.fn, operator.args),
        operator.in_size(),
        operator.out_size(),
        holomorphic=has_complex_input,
        jac=operator.jac,
    )
    return PyTreeLinearOperator(
        jacobian_fn(operator.x), operator.out_structure(), operator.tags
    )
@materialise.register(FunctionLinearOperator)
def _(operator):
    sparse = _try_sparse_materialise(operator)
    if sparse is not operator:
        return sparse
    flat, unravel = strip_weak_dtype(
        eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
    )
    # Push every standard basis vector through `fn`. Stacking the outputs along the
    # last axis gives, per output leaf, shape (*leaf_shape, in_size).
    basis = jnp.eye(flat.size, dtype=flat.dtype)
    columns = jax.vmap(lambda e: operator.fn(unravel(e)), out_axes=-1)(basis)

    def unravel_trailing_axis(leaf):
        # Unravel the trailing (flat input) axis back into the input pytree,
        # vmapping once per leading axis.
        assert leaf.ndim > 0
        batched = ft.reduce(lambda f, _: jax.vmap(f), range(leaf.ndim - 1), unravel)
        return batched(leaf)

    pytree = jtu.tree_map(unravel_trailing_axis, columns)
    return PyTreeLinearOperator(pytree, operator.out_structure(), operator.tags)
# diagonal
@ft.singledispatch
def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:
    """Extracts the diagonal from a linear operator, and returns a vector.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    A rank-1 JAX array. (That is, it has shape `(a,)` for some integer `a`.)

    For most operators this is just `jnp.diag(operator.as_matrix())`. Some operators
    (e.g. [`lineax.DiagonalLinearOperator`][]) can have more efficient
    implementations. If you don't know what kind of operator you might have, then this
    function ensures that you always get the most efficient implementation.
    """
    # Fallback for operator types with no registered overload.
    _default_not_implemented("diagonal", operator)
def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array:
    """Return the first leaf of `pytree` whose key path equals `keypath`.

    Raises `ValueError` if no leaf has exactly that key path.
    """
    missing = object()
    matches = (
        leaf for path, leaf in jtu.tree_leaves_with_path(pytree) if path == keypath
    )
    found = next(matches, missing)
    if found is missing:
        raise ValueError(f"Leaf not found at keypath {keypath}")
    return found
@diagonal.register(MatrixLinearOperator)
def _(operator):
    matrix = operator.as_matrix()
    return jnp.diag(matrix)


@diagonal.register(PyTreeLinearOperator)
def _(operator):
    if not is_diagonal(operator):
        return jnp.diag(operator.as_matrix())

    # `operator.pytree` has the output structure on the outside. Each output leaf
    # holds one block per input leaf, and the diagonal block is the one at the same
    # key path.
    def diagonal_of_block(path, out_struct, blocks_for_output):
        block = _leaf_from_keypath(blocks_for_output, path)
        n = out_struct.size
        return jnp.diag(block.reshape(n, n))

    per_leaf = jtu.tree_map_with_path(
        diagonal_of_block, operator.out_structure(), operator.pytree
    )
    return jnp.concatenate(jtu.tree_leaves(per_leaf))
@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
def _(operator):
    if not is_diagonal(operator):
        return diagonal(materialise(operator))
    # For a diagonal operator A, A @ ones == diag(A): one matvec is enough.
    with jax.ensure_compile_time_eval():
        ones = jtu.tree_map(
            lambda s: jnp.ones(s.shape, s.dtype), operator.in_structure()
        )
    flat_diag, _ = jfu.ravel_pytree(operator.mv(ones))
    return flat_diag
@diagonal.register(TridiagonalLinearOperator)
def _(operator):
    # Stored directly as a flat vector.
    return operator.diagonal


@diagonal.register(DiagonalLinearOperator)
def _(operator):
    # The stored diagonal may be a pytree; flatten it into a single vector.
    flat_diag, _ = jfu.ravel_pytree(operator.diagonal)
    return flat_diag


@diagonal.register(IdentityLinearOperator)
def _(operator):
    return jnp.ones(operator.in_size())
# tridiagonal
@ft.singledispatch
def tridiagonal(
    operator: AbstractLinearOperator,
) -> tuple[Shaped[Array, " size"], Shaped[Array, " size-1"], Shaped[Array, " size-1"]]:
    """Extracts the diagonal, lower diagonal, and upper diagonal from a linear
    operator. Returns three vectors.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    A 3-tuple, consisting of:

    - The diagonal of the matrix, represented as a vector.
    - The lower diagonal of the matrix, represented as a vector.
    - The upper diagonal of the matrix, represented as a vector.

    If the diagonal has shape `(a,)` then the lower and upper diagonals will have shape
    `(a - 1,)`.

    For most operators these are computed by materialising the array and then extracting
    the relevant elements, e.g. getting the main diagonal via
    `jnp.diag(operator.as_matrix())`. Some operators (e.g.
    [`lineax.TridiagonalLinearOperator`][]) can have more efficient implementations.
    If you don't know what kind of operator you might have, then this function ensures
    that you always get the most efficient implementation.
    """
    # Fallback for operator types with no registered overload.
    _default_not_implemented("tridiagonal", operator)
@tridiagonal.register(PyTreeLinearOperator)
@tridiagonal.register(MatrixLinearOperator)
def _(operator):
    dense = operator.as_matrix()
    assert dense.ndim == 2
    # Offsets 0, -1, +1 select the main, sub- and super-diagonal respectively.
    main, lower, upper = (jnp.diagonal(dense, offset=k) for k in (0, -1, 1))
    return main, lower, upper
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
def _(operator):
    # Operators tagged (tri)diagonal are probed with three matrix-vector products
    # instead of being materialised.
    # Columns are coloured by `j % 3`. Row `r` of a tridiagonal matrix only touches
    # columns `r - 1`, `r`, `r + 1`, whose colours all differ, so summing
    # same-coloured columns never mixes two nonzeros of one row.
    if is_tridiagonal(operator):
        with jax.ensure_compile_time_eval():
            flat, unravel = strip_weak_dtype(
                eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
            )
            # Row `i` of `basis` is 1 exactly at the columns of colour `i`.
            basis = jnp.zeros((3, flat.size), dtype=flat.dtype)
            for i in range(3):
                basis = basis.at[i, i::3].set(1.0)
            basis = jax.vmap(unravel)(basis)
            # coloring[j] is the colour of column j.
            coloring = jnp.arange(flat.size) % 3
        # compressed_flat[c, r] is A[r, j] for the unique j in {r-1, r, r+1} with
        # j % 3 == c, or 0 if there is no such j.
        compressed_as_pytree = jax.vmap(operator.mv)(basis)
        compressed_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(
            compressed_as_pytree
        )
        # unique_indices propagates through linear_transpose to set unique_indices=True
        # on the scatter, allowing assignment rather than accumulation.
        rows = jnp.arange(flat.size)
        # A[r, r] == compressed_flat[r % 3, r]
        diag = compressed_flat.at[coloring, rows].get(
            wrap_negative_indices=False, unique_indices=True
        )
        # A[r + 1, r] == compressed_flat[r % 3, r + 1]
        lower_diag = compressed_flat.at[coloring[:-1], rows[1:]].get(
            wrap_negative_indices=False, unique_indices=True
        )
        # A[r, r + 1] == compressed_flat[(r + 1) % 3, r]
        upper_diag = compressed_flat.at[coloring[1:], rows[:-1]].get(
            wrap_negative_indices=False, unique_indices=True
        )
        return diag, lower_diag, upper_diag
    # Not known to be tridiagonal: materialise densely and read off the bands.
    matrix = operator.as_matrix()
    assert matrix.ndim == 2
    main_diagonal = jnp.diagonal(matrix, offset=0)
    upper_diagonal = jnp.diagonal(matrix, offset=1)
    lower_diagonal = jnp.diagonal(matrix, offset=-1)
    return main_diagonal, lower_diagonal, upper_diagonal
@tridiagonal.register(DiagonalLinearOperator)
def _(operator):
    """Bands of a diagonal operator: its diagonal plus zero off-diagonals.

    The zero off-diagonals are created with the diagonal's dtype. With the default
    `jnp.zeros` dtype, a complex, half-precision or (under x64) float32 diagonal
    would be paired with off-diagonals of a different dtype. That breaks band
    arithmetic (e.g. summing the bands of an `AddLinearOperator`) under strict dtype
    promotion. It also produces a `TridiagonalLinearOperator` whose dtype, taken from
    its diagonal alone, disagrees with its stored off-diagonals.
    """
    diag = diagonal(operator)
    lower_diag = jnp.zeros(diag.size - 1, dtype=diag.dtype)
    upper_diag = jnp.zeros(diag.size - 1, dtype=diag.dtype)
    return diag, lower_diag, upper_diag
@tridiagonal.register(TridiagonalLinearOperator)
def _(operator):
    # Already stored in band form.
    return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal


@tridiagonal.register(IdentityLinearOperator)
def _(operator):
    n = operator.in_size()
    # The same zero vector serves as both off-diagonals.
    zeros = jnp.zeros(n - 1)
    return jnp.ones(n), zeros, zeros
# is_symmetric
@ft.singledispatch
def is_symmetric(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as symmetric.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    # Fallback for operator types with no registered overload.
    _default_not_implemented("is_symmetric", operator)
def _has_real_dtype(operator) -> bool:
    """Whether the promoted dtype of all input/output structure leaves is real.

    Returns `True` for real floating dtypes and `False` for complex ones. Any other
    dtype is an internal error.
    """
    structures = (operator.in_structure(), operator.out_structure())
    promoted = jnp.result_type(*jtu.tree_leaves(structures))
    if jnp.issubdtype(promoted, jnp.floating):
        return True
    if jnp.issubdtype(promoted, jnp.complexfloating):
        return False
    assert False, (
        "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood."
    )
@is_symmetric.register(FunctionLinearOperator)
@is_symmetric.register(JacobianLinearOperator)
@is_symmetric.register(PyTreeLinearOperator)
@is_symmetric.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    # Symmetric (A = A^T) if tagged symmetric, or tagged diagonal.
    if symmetric_tag in tags or diagonal_tag in tags:
        return True
    # A semidefinite tag implies A = A^T only for real dtypes. For complex dtypes
    # it implies Hermitian (A = A^H) instead.
    semidefinite = (
        positive_semidefinite_tag in tags or negative_semidefinite_tag in tags
    )
    return _has_real_dtype(operator) if semidefinite else False


@is_symmetric.register(IdentityLinearOperator)
def _(operator):
    # Symmetric only if it maps a structure to itself. Compare with `is True`
    # because `eqx.tree_equal` need not return a Python bool.
    same_structure = eqx.tree_equal(operator.in_structure(), operator.out_structure())
    return same_structure is True


@is_symmetric.register(DiagonalLinearOperator)
def _(operator):
    return True


@is_symmetric.register(TridiagonalLinearOperator)
def _(operator):
    # Conservative: the two off-diagonals are not compared.
    return False
# is_diagonal
@ft.singledispatch
def is_diagonal(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as diagonal.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_diagonal", operator)


@is_diagonal.register(FunctionLinearOperator)
@is_diagonal.register(JacobianLinearOperator)
@is_diagonal.register(PyTreeLinearOperator)
@is_diagonal.register(MatrixLinearOperator)
def _(operator):
    if diagonal_tag in operator.tags:
        return True
    # A 1x1 operator is diagonal whether or not it is tagged.
    return operator.in_size() == 1 and operator.out_size() == 1


@is_diagonal.register(DiagonalLinearOperator)
@is_diagonal.register(IdentityLinearOperator)
def _(operator):
    return True


@is_diagonal.register(TridiagonalLinearOperator)
def _(operator):
    # Only the 1x1 case is guaranteed to have no off-diagonal entries.
    return operator.in_size() == 1
# is_tridiagonal
@ft.singledispatch
def is_tridiagonal(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as tridiagonal.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_tridiagonal", operator)


@is_tridiagonal.register(FunctionLinearOperator)
@is_tridiagonal.register(JacobianLinearOperator)
@is_tridiagonal.register(PyTreeLinearOperator)
@is_tridiagonal.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    # Diagonal is a special case of tridiagonal.
    return tridiagonal_tag in tags or diagonal_tag in tags


@is_tridiagonal.register(TridiagonalLinearOperator)
@is_tridiagonal.register(DiagonalLinearOperator)
@is_tridiagonal.register(IdentityLinearOperator)
def _(operator):
    return True
# has_unit_diagonal
@ft.singledispatch
def has_unit_diagonal(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as having unit diagonal.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("has_unit_diagonal", operator)


@has_unit_diagonal.register(FunctionLinearOperator)
@has_unit_diagonal.register(JacobianLinearOperator)
@has_unit_diagonal.register(PyTreeLinearOperator)
@has_unit_diagonal.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    return unit_diagonal_tag in tags


@has_unit_diagonal.register(TridiagonalLinearOperator)
@has_unit_diagonal.register(DiagonalLinearOperator)
def _(operator):
    # TODO: refine this
    # Conservative: the stored diagonal values are not inspected.
    return False


@has_unit_diagonal.register(IdentityLinearOperator)
def _(operator):
    return True
# is_lower_triangular
@ft.singledispatch
def is_lower_triangular(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as lower triangular.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_lower_triangular", operator)


@is_lower_triangular.register(FunctionLinearOperator)
@is_lower_triangular.register(JacobianLinearOperator)
@is_lower_triangular.register(PyTreeLinearOperator)
@is_lower_triangular.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    return lower_triangular_tag in tags


@is_lower_triangular.register(TridiagonalLinearOperator)
def _(operator):
    # Conservative: the upper diagonal is not checked for being zero.
    return False


@is_lower_triangular.register(DiagonalLinearOperator)
@is_lower_triangular.register(IdentityLinearOperator)
def _(operator):
    return True
# is_upper_triangular
@ft.singledispatch
def is_upper_triangular(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as upper triangular.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_upper_triangular", operator)


@is_upper_triangular.register(FunctionLinearOperator)
@is_upper_triangular.register(JacobianLinearOperator)
@is_upper_triangular.register(PyTreeLinearOperator)
@is_upper_triangular.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    return upper_triangular_tag in tags


@is_upper_triangular.register(TridiagonalLinearOperator)
def _(operator):
    # Conservative: the lower diagonal is not checked for being zero.
    return False


@is_upper_triangular.register(DiagonalLinearOperator)
@is_upper_triangular.register(IdentityLinearOperator)
def _(operator):
    return True
# is_positive_semidefinite
@ft.singledispatch
def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as positive semidefinite.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_positive_semidefinite", operator)


@is_positive_semidefinite.register(FunctionLinearOperator)
@is_positive_semidefinite.register(JacobianLinearOperator)
@is_positive_semidefinite.register(PyTreeLinearOperator)
@is_positive_semidefinite.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    return positive_semidefinite_tag in tags


@is_positive_semidefinite.register(IdentityLinearOperator)
def _(operator):
    # PSD only if it maps a structure to itself. Compare with `is True` because
    # `eqx.tree_equal` need not return a Python bool.
    same_structure = eqx.tree_equal(operator.in_structure(), operator.out_structure())
    return same_structure is True


@is_positive_semidefinite.register(TridiagonalLinearOperator)
@is_positive_semidefinite.register(DiagonalLinearOperator)
def _(operator):
    # TODO: refine this
    # Conservative: the stored values are not inspected.
    return False
# is_negative_semidefinite
@ft.singledispatch
def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool:
    """Returns whether an operator is marked as negative semidefinite.

    See [the documentation on linear operator tags](../api/tags.md) for more
    information.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Either `True` or `False`.
    """
    _default_not_implemented("is_negative_semidefinite", operator)


@is_negative_semidefinite.register(FunctionLinearOperator)
@is_negative_semidefinite.register(JacobianLinearOperator)
@is_negative_semidefinite.register(PyTreeLinearOperator)
@is_negative_semidefinite.register(MatrixLinearOperator)
def _(operator):
    tags = operator.tags
    return negative_semidefinite_tag in tags


@is_negative_semidefinite.register(TridiagonalLinearOperator)
@is_negative_semidefinite.register(DiagonalLinearOperator)
def _(operator):
    # TODO: refine this
    # Conservative: the stored values are not inspected.
    return False


@is_negative_semidefinite.register(IdentityLinearOperator)
def _(operator):
    return False
# ops for wrapper operators
# These transformations act on the wrapped operator. The ones that return an
# operator re-attach the same tags.
@linearise.register(TaggedLinearOperator)
def _(operator):
    inner = linearise(operator.operator)
    return TaggedLinearOperator(inner, operator.tags)


@materialise.register(TaggedLinearOperator)
def _(operator):
    inner = materialise(operator.operator)
    return TaggedLinearOperator(inner, operator.tags)


@diagonal.register(TaggedLinearOperator)
def _(operator):
    inner = operator.operator
    return diagonal(inner)


@tridiagonal.register(TaggedLinearOperator)
def _(operator):
    inner = operator.operator
    return tridiagonal(inner)
# linearise / materialise / diagonal commute with scaling and negation, so apply
# them to the wrapped operator and re-apply the scalar operation afterwards.
# `transform=transform` pins the current loop value; a plain closure would
# late-bind and every overload would call the final `transform`.
for transform in (linearise, materialise, diagonal):

    @transform.register(MulLinearOperator)
    def _(operator, transform=transform):
        inner = transform(operator.operator)
        return inner * operator.scalar

    @transform.register(NegLinearOperator)  # pyright: ignore
    def _(operator, transform=transform):
        inner = transform(operator.operator)
        return -inner

    @transform.register(DivLinearOperator)
    def _(operator, transform=transform):
        inner = transform(operator.operator)
        return inner / operator.scalar


# linearise and diagonal also distribute over sums. (`materialise` is registered
# for `AddLinearOperator` separately.)
for transform in (linearise, diagonal):

    @transform.register(AddLinearOperator)  # pyright: ignore
    def _(operator, transform=transform):
        first = transform(operator.operator1)
        second = transform(operator.operator2)
        return first + second  # pyright: ignore
@materialise.register(AddLinearOperator)
def _(operator):
    sparse = _try_sparse_materialise(operator)
    if sparse is operator:
        # No compact (tri)diagonal form: materialise each summand separately.
        return materialise(operator.operator1) + materialise(operator.operator2)
    return sparse
@linearise.register(TangentLinearOperator)
def _(operator):
    # Differentiate `linearise` itself along the tangent, so the result is again a
    # (primal, tangent) pair of operators.
    lin_primal, lin_tangent = eqx.filter_jvp(
        linearise, (operator.primal,), (operator.tangent,)
    )
    return TangentLinearOperator(lin_primal, lin_tangent)


@materialise.register(TangentLinearOperator)
def _(operator):
    # As above, but pushing the pair through `materialise`.
    mat_primal, mat_tangent = eqx.filter_jvp(
        materialise, (operator.primal,), (operator.tangent,)
    )
    return TangentLinearOperator(mat_primal, mat_tangent)
@diagonal.register(TangentLinearOperator)
def _(operator):
    # Should be unreachable: TangentLinearOperator is only used for a narrow set of
    # operations (e.g. mv; transpose) inside the JVP rule of linear_solve_p.
    # The message names the operation and points at the project's current home.
    raise NotImplementedError(
        "`lineax.diagonal` is not supported for `TangentLinearOperator`. "
        "Please open a GitHub issue: https://github.com/patrick-kidger/lineax"
    )


@tridiagonal.register(TangentLinearOperator)
def _(operator):
    # Should be unreachable, for the same reason as the `diagonal` overload above.
    raise NotImplementedError(
        "`lineax.tridiagonal` is not supported for `TangentLinearOperator`. "
        "Please open a GitHub issue: https://github.com/patrick-kidger/lineax"
    )
@tridiagonal.register(AddLinearOperator)
def _(operator):
    # Bands add entrywise.
    main1, lower1, upper1 = tridiagonal(operator.operator1)
    main2, lower2, upper2 = tridiagonal(operator.operator2)
    return main1 + main2, lower1 + lower2, upper1 + upper2


@tridiagonal.register(MulLinearOperator)
def _(operator):
    main, lower, upper = tridiagonal(operator.operator)
    factor = operator.scalar
    return main * factor, lower * factor, upper * factor


@tridiagonal.register(NegLinearOperator)
def _(operator):
    main, lower, upper = tridiagonal(operator.operator)
    return -main, -lower, -upper


@tridiagonal.register(DivLinearOperator)
def _(operator):
    main, lower, upper = tridiagonal(operator.operator)
    divisor = operator.scalar
    return main / divisor, lower / divisor, upper / divisor
@linearise.register(ComposedLinearOperator)
def _(operator):
    outer = linearise(operator.operator1)
    inner = linearise(operator.operator2)
    return outer @ inner


@materialise.register(ComposedLinearOperator)
def _(operator):
    outer, inner = operator.operator1, operator.operator2
    # Composition with the identity is a no-op, so skip the product entirely.
    if isinstance(outer, IdentityLinearOperator):
        return materialise(inner)
    if isinstance(inner, IdentityLinearOperator):
        return materialise(outer)
    sparse = _try_sparse_materialise(operator)
    if sparse is not operator:
        return sparse
    return materialise(outer) @ materialise(inner)


@diagonal.register(ComposedLinearOperator)
def _(operator):
    outer, inner = operator.operator1, operator.operator2
    if not (is_diagonal(outer) and is_diagonal(inner)):
        return jnp.diag(operator.as_matrix())
    # The product of two diagonal matrices is diagonal, with entrywise products.
    return diagonal(outer) * diagonal(inner)
@tridiagonal.register(ComposedLinearOperator)
def _(operator):
    left, right = operator.operator1, operator.operator2
    if is_diagonal(left) and is_tridiagonal(right):
        # D @ T scales row i by d[i]. Lower-band entry k lies in row k + 1 and
        # upper-band entry k lies in row k.
        scale = diagonal(left)
        main, lower, upper = tridiagonal(right)
        return scale * main, scale[1:] * lower, scale[:-1] * upper
    if is_diagonal(right) and is_tridiagonal(left):
        # T @ D scales column j by d[j]. Lower-band entry k lies in column k and
        # upper-band entry k lies in column k + 1.
        scale = diagonal(right)
        main, lower, upper = tridiagonal(left)
        return scale * main, scale[:-1] * lower, scale[1:] * upper
    # General case: materialise densely and read off the three bands.
    dense = operator.as_matrix()
    assert dense.ndim == 2
    return (
        jnp.diagonal(dense, offset=0),
        jnp.diagonal(dense, offset=-1),
        jnp.diagonal(dense, offset=1),
    )
def _forward_check_to_primal(check):
    """Build a check that answers for a `TangentLinearOperator` via its primal."""

    def _(operator):
        return check(operator.primal)

    return _


# A tangent operator carries the same structure as the primal it belongs to.
for check in (
    is_symmetric,
    is_diagonal,
    has_unit_diagonal,
    is_lower_triangular,
    is_upper_triangular,
    is_tridiagonal,
    is_positive_semidefinite,
    is_negative_semidefinite,
):
    check.register(TangentLinearOperator)(_forward_check_to_primal(check))
# Multiplying, negating or dividing by a scalar maps zero entries to zero and keeps
# a symmetric matrix symmetric, so these properties come from the wrapped operator.
for check in (
    is_symmetric,
    is_diagonal,
    is_lower_triangular,
    is_upper_triangular,
    is_tridiagonal,
):

    def _(operator, check=check):
        return check(operator.operator)

    for scaled_cls in (MulLinearOperator, NegLinearOperator, DivLinearOperator):
        check.register(scaled_cls)(_)
# Scaling (including negation) generally changes the diagonal entries, so a unit
# diagonal is conservatively reported as not preserved.
@has_unit_diagonal.register(MulLinearOperator)
@has_unit_diagonal.register(NegLinearOperator)
@has_unit_diagonal.register(DivLinearOperator)
def _(operator):
    del operator
    return False
class _ScalarSign(enum.Enum):
positive = enum.auto()
negative = enum.auto()
zero = enum.auto()
unknown = enum.auto()
def _scalar_sign(scalar) -> _ScalarSign:
"""Returns the sign of a scalar, or unknown for JAX tracers."""
if isinstance(scalar, (int, float, np.ndarray, np.generic)):
scalar = float(scalar)
if scalar > 0:
return _ScalarSign.positive
elif scalar < 0:
return _ScalarSign.negative
else:
return _ScalarSign.zero
else:
return _ScalarSign.unknown
# PSD/NSD for MulLinearOperator: depends on sign of scalar
# Zero scalar gives zero matrix which is both PSD and NSD
@is_positive_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is PSD
return False
@is_negative_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is NSD
return False
# PSD/NSD of `A / scalar` is determined by the sign of the scalar. A zero (or
# unknown-sign) scalar conservatively gives False: division by zero is undefined.
@is_positive_semidefinite.register(DivLinearOperator)
def _(operator):
    sign = _scalar_sign(operator.scalar)
    if sign is _ScalarSign.negative:
        return is_negative_semidefinite(operator.operator)
    if sign is _ScalarSign.positive:
        return is_positive_semidefinite(operator.operator)
    return False


@is_negative_semidefinite.register(DivLinearOperator)
def _(operator):
    sign = _scalar_sign(operator.scalar)
    if sign is _ScalarSign.negative:
        return is_positive_semidefinite(operator.operator)
    if sign is _ScalarSign.positive:
        return is_negative_semidefinite(operator.operator)
    return False
# Negation maps PSD operators to NSD ones and vice versa.
@is_positive_semidefinite.register(NegLinearOperator)
def _(operator):
    inner = operator.operator
    return is_negative_semidefinite(inner)


@is_negative_semidefinite.register(NegLinearOperator)
def _(operator):
    inner = operator.operator
    return is_positive_semidefinite(inner)
# A tagged operator has a property if it is explicitly tagged with it, or if the
# wrapped operator can be shown to have it.
for check, tag in (
    (is_symmetric, symmetric_tag),
    (is_diagonal, diagonal_tag),
    (has_unit_diagonal, unit_diagonal_tag),
    (is_lower_triangular, lower_triangular_tag),
    (is_upper_triangular, upper_triangular_tag),
    (is_positive_semidefinite, positive_semidefinite_tag),
    (is_negative_semidefinite, negative_semidefinite_tag),
    (is_tridiagonal, tridiagonal_tag),
):

    def _(operator, check=check, tag=tag):
        if tag in operator.tags:
            return True
        return check(operator.operator)

    check.register(TaggedLinearOperator)(_)
# A sum has one of these properties whenever both summands do.
for check in (
    is_symmetric,
    is_diagonal,
    is_lower_triangular,
    is_upper_triangular,
    is_positive_semidefinite,
    is_negative_semidefinite,
    is_tridiagonal,
):

    def _(operator, check=check):
        return check(operator.operator1) and check(operator.operator2)

    check.register(AddLinearOperator)(_)


# Adding two unit diagonals gives a diagonal of twos, so this is never assumed.
@has_unit_diagonal.register(AddLinearOperator)
def _(operator):
    del operator
    return False
# Products of diagonal (resp. lower-/upper-triangular) matrices stay diagonal
# (resp. lower-/upper-triangular), so both factors having the property suffices.
for check in (
    is_diagonal,
    is_lower_triangular,
    is_upper_triangular,
):

    def _(operator, check=check):
        return check(operator.operator1) and check(operator.operator2)

    check.register(ComposedLinearOperator)(_)
# For symmetric A and B, A @ B is symmetric iff A and B commute. That cannot be
# decided structurally in general, so only diagonal @ diagonal (which always
# commute, and whose product is diagonal) is recognised.
@is_symmetric.register(ComposedLinearOperator)
def _(operator):
    left_diagonal = is_diagonal(operator.operator1)
    return left_diagonal and is_diagonal(operator.operator2)
# Tridiagonal @ tridiagonal is generally pentadiagonal, but scaling rows or columns
# of a tridiagonal by a diagonal matrix keeps it tridiagonal.
@is_tridiagonal.register(ComposedLinearOperator)
def _(operator):
    left, right = operator.operator1, operator.operator2
    if is_diagonal(left):
        return is_tridiagonal(right)
    elif is_diagonal(right):
        return is_tridiagonal(left)
    else:
        return False
# Definiteness is not preserved by composition in general, so conservatively
# report False for both.
@is_positive_semidefinite.register(ComposedLinearOperator)
@is_negative_semidefinite.register(ComposedLinearOperator)
def _(operator):
    del operator
    return False
@has_unit_diagonal.register(ComposedLinearOperator)
def _(operator):
    # When the product is diagonal or triangular, both factors have the same
    # structure (see the composed checks). The product's diagonal is then the
    # elementwise product of the factors' diagonals, so it is unit iff both are.
    # All five checks are evaluated, in the same order as before.
    product_is_structured = [
        is_diagonal(operator),
        is_lower_triangular(operator),
        is_upper_triangular(operator),
    ]
    factors_have_unit_diagonal = [
        has_unit_diagonal(operator.operator1),
        has_unit_diagonal(operator.operator2),
    ]
    return any(product_is_structured) and all(factors_have_unit_diagonal)
# conj


@ft.singledispatch
def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator:
    """Elementwise complex conjugate of a linear operator, returned as another
    linear operator.

    No transpose is taken: for a matrix $A$ the result represents $\\overline{A}$,
    not $A^H$.

    **Arguments:**

    - `operator`: a linear operator.

    **Returns:**

    Another linear operator.
    """
    _default_not_implemented("conj", operator)
@conj.register(MatrixLinearOperator)
def _(operator):
    conjugated = operator.matrix.conj()
    return MatrixLinearOperator(conjugated, operator.tags)


@conj.register(PyTreeLinearOperator)
def _(operator):
    conjugated = jtu.tree_map(lambda leaf: leaf.conj(), operator.pytree)
    return PyTreeLinearOperator(conjugated, operator.out_structure(), operator.tags)


@conj.register(DiagonalLinearOperator)
def _(operator):
    conjugated = jtu.tree_map(lambda leaf: leaf.conj(), operator.diagonal)
    return DiagonalLinearOperator(conjugated)


@conj.register(JacobianLinearOperator)
def _(operator):
    # Linearise first, then conjugate the resulting operator.
    linear = linearise(operator)
    return conj(linear)


@conj.register(FunctionLinearOperator)
def _(operator):
    # For a linear map A: conj(A) v == conj(A conj(v)).
    def conj_mv(vec):
        conj_vec = jtu.tree_map(jnp.conj, vec)
        return jtu.tree_map(jnp.conj, operator.mv(conj_vec))

    return FunctionLinearOperator(conj_mv, operator.in_structure(), operator.tags)


@conj.register(IdentityLinearOperator)
def _(operator):
    # The identity has real entries, so it is its own conjugate.
    return operator


@conj.register(TridiagonalLinearOperator)
def _(operator):
    main = operator.diagonal.conj()
    lower = operator.lower_diagonal.conj()
    upper = operator.upper_diagonal.conj()
    return TridiagonalLinearOperator(main, lower, upper)


@conj.register(TaggedLinearOperator)
def _(operator):
    inner = conj(operator.operator)
    return TaggedLinearOperator(inner, operator.tags)
@conj.register(TangentLinearOperator)
def _(operator):
    # Conjugate the primal, and push the tangent through the same map by JVP.
    def conj_primal(primal):
        return conj(primal)

    primal_out, tangent_out = eqx.filter_jvp(
        conj_primal, (operator.primal,), (operator.tangent,)
    )
    return TangentLinearOperator(primal_out, tangent_out)
def _conj_scalar(scalar):
    """Complex-conjugate a scalar coefficient of a scaled operator.

    Python numbers (`int`, `float`, `complex`) only provide `.conjugate()`, not
    `.conj()`, so they are handled separately. NumPy and JAX arrays keep using
    `.conj()`, as before.
    """
    if isinstance(scalar, (int, float, complex)):
        return scalar.conjugate()
    return scalar.conj()


@conj.register(AddLinearOperator)
def _(operator):
    # conj(A + B) = conj(A) + conj(B)
    return conj(operator.operator1) + conj(operator.operator2)


@conj.register(MulLinearOperator)
def _(operator):
    # conj(A * c) = conj(A) * conj(c)
    return conj(operator.operator) * _conj_scalar(operator.scalar)


@conj.register(NegLinearOperator)
def _(operator):
    return -conj(operator.operator)


@conj.register(DivLinearOperator)
def _(operator):
    # conj(A / c) = conj(A) / conj(c)
    return conj(operator.operator) / _conj_scalar(operator.scalar)


@conj.register(ComposedLinearOperator)
def _(operator):
    # Elementwise conjugation distributes over matrix products:
    # conj(A @ B) = conj(A) @ conj(B).
    return conj(operator.operator1) @ conj(operator.operator2)
================================================
FILE: lineax/_solution.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import equinox as eqx
import equinox.internal as eqxi
from jaxtyping import Array, ArrayLike, PyTree
_singular_msg = """
A linear solver returned non-finite (NaN or inf) output. This usually means that an
operator was not well-posed, and that its solver does not support this.
If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that is does not actually satisfy.
""".strip()
_nonfinite_msg = """
A linear solver received non-finite (NaN or inf) input and cannot determine a
solution.
This means that you have a bug upstream of Lineax and should check the inputs to
`lineax.linear_solve` for non-finite values.
""".strip()
class RESULTS(eqxi.Enumeration):
    # Each item's string is the human-readable message reported on failure; the
    # empty string marks success.
    successful = ""
    max_steps_reached = (
        "The maximum number of solver steps was reached. Try increasing `max_steps`."
    )
    singular = _singular_msg
    breakdown = (
        "A form of iterative breakdown has occurred in a linear solve. "
        "Try using a different solver for this problem or increase `restart` "
        "if using GMRES."
    )
    stagnation = (
        "A stagnation in an iterative linear solve has occurred. Try increasing "
        "`stagnation_iters` or `restart`."
    )
    conlim = "Condition number of A seems to be larger than `conlim`."
    nonfinite_input = _nonfinite_msg
class Solution(eqx.Module):
    """The solution to a linear solve.

    **Attributes:**

    - `value`: The solution to the solve.
    - `result`: An integer representing whether the solve was successful or not. This
        can be converted into a human-readable error message via
        `lineax.RESULTS[result]`.
    - `stats`: Statistics about the solver, e.g. the number of steps that were required.
    - `state`: The internal state of the solver. The meaning of this is specific to each
        solver.
    """

    # The solution PyTree.
    value: PyTree[Array]
    # Success, or the reason for failure, as a `RESULTS` item.
    result: RESULTS
    # Solver-specific statistics; may be an empty dict.
    stats: dict[str, PyTree[ArrayLike]]
    # Solver-specific internal state.
    state: PyTree[Any]
================================================
FILE: lineax/_solve.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import functools as ft
from typing import Any, Generic, TypeAlias, TypeVar
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.core
import jax.interpreters.ad as ad
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jax._src.ad_util import stop_gradient_p
from jaxtyping import Array, ArrayLike, PyTree
from ._custom_types import sentinel
from ._misc import inexact_asarray, strip_weak_dtype
from ._operator import (
AbstractLinearOperator,
conj,
FunctionLinearOperator,
has_unit_diagonal,
IdentityLinearOperator,
is_diagonal,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
is_symmetric,
is_tridiagonal,
is_upper_triangular,
linearise,
TangentLinearOperator,
)
from ._solution import RESULTS, Solution
from ._tags import (
diagonal_tag,
lower_triangular_tag,
negative_semidefinite_tag,
positive_semidefinite_tag,
symmetric_tag,
unit_diagonal_tag,
upper_triangular_tag,
)
#
# _linear_solve_p
#
def _to_shapedarray(x):
if isinstance(x, jax.ShapeDtypeStruct):
return jax.core.ShapedArray(x.shape, x.dtype)
else:
return x
def _to_struct(x):
if isinstance(x, jax.core.ShapedArray):
return jax.ShapeDtypeStruct(x.shape, x.dtype)
elif isinstance(x, jax.core.AbstractValue):
raise NotImplementedError(
"`lineax.linear_solve` only supports working with JAX arrays; not "
f"other abstract values. Got abstract value {x}."
)
else:
return x
def _assert_false(x):
assert False
def _is_none(x):
return x is None
def _sum(*args):
return sum(args)
def _linear_solve_impl(_, state, vector, options, solver, throw, *, check_closure):
out = solver.compute(state, vector, options)
if check_closure:
out = eqxi.nontraceable(
out, name="lineax.linear_solve with respect to a closed-over value"
)
solution, result, stats = out
has_nonfinite_output = jnp.any(
jnp.stack(
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(solution)]
)
)
result = RESULTS.where(
(result == RESULTS.successful) & has_nonfinite_output,
RESULTS.singular,
result,
)
has_nonfinite_input = jnp.any(
jnp.stack(
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(vector)]
)
)
result = RESULTS.where(
(result == RESULTS.singular) & has_nonfinite_input,
RESULTS.nonfinite_input,
result,
)
if throw:
solution, result, stats = result.error_if(
(solution, result, stats),
result != RESULTS.successful,
)
return solution, result, stats
@eqxi.filter_primitive_def
def _linear_solve_abstract_eval(operator, state, vector, options, solver, throw):
    """Abstract evaluation rule of `linear_solve_p`.

    Shape-evaluates the implementation: abstract inputs are converted to
    `jax.ShapeDtypeStruct`s, and the outputs back to `jax.core.ShapedArray`s.
    """
    structs = jtu.tree_map(_to_struct, (state, vector, options, solver))
    out_structs = eqx.filter_eval_shape(
        _linear_solve_impl, operator, *structs, throw, check_closure=False
    )
    return jtu.tree_map(_to_shapedarray, out_structs)
@eqxi.filter_primitive_jvp
def _linear_solve_jvp(primals, tangents):
    """JVP rule of `linear_solve_p`, differentiating the pseudoinverse solution.

    Only `operator` and `vector` may carry tangents. Tangents for `state`, `options`,
    `solver` and `throw` are asserted to be absent.

    Returns `(out, t_out)`: the primal `(solution, result, stats)` and its tangent.
    Only `solution` has a nonzero tangent.
    """
    operator, state, vector, options, solver, throw = primals
    t_operator, t_state, t_vector, t_options, t_solver, t_throw = tangents
    # These inputs are non-differentiable: their tangent PyTrees must have no leaves.
    jtu.tree_map(_assert_false, (t_state, t_options, t_solver, t_throw))
    del t_state, t_options, t_solver, t_throw

    # Note that we pass throw=True unconditionally to all the tangent solves, as there
    # is nowhere we can pipe their error to.
    # This is the primal solve so we can respect the original `throw`.
    solution, result, stats = eqxi.filter_primitive_bind(
        linear_solve_p, operator, state, vector, options, solver, throw
    )

    #
    # Consider the primal problem of linearly solving for x in Ax=b.
    # Let ^ denote pseudoinverses, ᵀ denote transposes, and ' denote tangents.
    # (For complex operators ᵀ is the conjugate transpose; hence the
    # `conj(...).transpose()` calls below.)
    # The linear_solve routine returns specifically the pseudoinverse solution, i.e.
    #
    # x = A^b
    #
    # Therefore x' = A^'b + A^b'
    #
    # Now A^' = -A^A'A^ + A^A^ᵀAᵀ'(I - AA^) + (I - A^A)Aᵀ'A^ᵀA^
    #
    # (Source: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
    #
    # This results in:
    #
    # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) - Ay + b') + y
    #
    # where
    #
    # y = Aᵀ'A^ᵀx
    #
    # note that if A has linearly independent columns, then A^A = I, so the
    # y - A^Ay term disappears and gives
    #
    # x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) + b')
    #
    # and if A has linearly independent rows, then Ax = b exactly, so the
    # A^A^ᵀAᵀ'(b - Ax) term disappears giving:
    #
    # x' = A^(-A'x - Ay + b') + y
    #
    # if A has linearly independent rows and columns, then A is nonsingular and
    #
    # x' = A^(-A'x + b')
    #
    # `vecs` collects the terms inside the outer A^(...); `sols` collects terms
    # added after that solve (i.e. y).
    vecs = []
    sols = []
    if any(t is not None for t in jtu.tree_leaves(t_vector, is_leaf=_is_none)):
        # b' term
        vecs.append(
            jtu.tree_map(eqxi.materialise_zeros, vector, t_vector, is_leaf=_is_none)
        )
    if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)):
        t_operator = TangentLinearOperator(operator, t_operator)
        t_operator = linearise(t_operator)  # optimise for matvecs
        # -A'x term
        vec = (-(t_operator.mv(solution) ** ω)).ω
        vecs.append(vec)
        # A full-rank operator has independent rows iff it is wide (or square),
        # and independent columns iff it is tall (or square).
        rows, columns = operator.out_size(), operator.in_size()
        assume_independent_rows = solver.assume_full_rank() and rows <= columns
        assume_independent_columns = solver.assume_full_rank() and columns <= rows
        if not assume_independent_rows or not assume_independent_columns:
            # Only needed for the extra least-squares / minimum-norm terms.
            operator_conj_transpose = conj(operator).transpose()
            t_operator_conj_transpose = conj(t_operator).transpose()
            state_conj, options_conj = solver.conj(state, options)
            state_conj_transpose, options_conj_transpose = solver.transpose(
                state_conj, options_conj
            )
        if not assume_independent_rows:
            # A^ᵀAᵀ'(b - Ax) term, with A^ᵀ applied via a solve against Aᵀ.
            lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω
            tmp = t_operator_conj_transpose.mv(lst_sqr_diff)  # pyright: ignore
            tmp, _, _ = eqxi.filter_primitive_bind(
                linear_solve_p,
                operator_conj_transpose,  # pyright: ignore
                state_conj_transpose,  # pyright: ignore
                tmp,
                options_conj_transpose,  # pyright: ignore
                solver,
                True,
            )
            vecs.append(tmp)

        if not assume_independent_columns:
            # tmp1 = A^ᵀx
            tmp1, _, _ = eqxi.filter_primitive_bind(
                linear_solve_p,
                operator_conj_transpose,  # pyright: ignore
                state_conj_transpose,  # pyright:ignore
                solution,
                options_conj_transpose,  # pyright: ignore
                solver,
                True,
            )
            tmp2 = t_operator_conj_transpose.mv(tmp1)  # pyright: ignore
            # tmp2 is the y term
            tmp3 = operator.mv(tmp2)
            tmp4 = (-(tmp3**ω)).ω
            # tmp4 is the Ay term
            vecs.append(tmp4)
            sols.append(tmp2)
    vecs = jtu.tree_map(_sum, *vecs)
    # the A^ term at the very beginning
    sol, _, _ = eqxi.filter_primitive_bind(
        linear_solve_p, operator, state, vecs, options, solver, True
    )
    sols.append(sol)
    t_solution = jtu.tree_map(_sum, *sols)

    out = solution, result, stats
    # `result` and `stats` are non-differentiable: their tangents are all None.
    t_out = (
        t_solution,
        jtu.tree_map(lambda _: None, result),
        jtu.tree_map(lambda _: None, stats),
    )
    return out, t_out
def _is_undefined(x):
return isinstance(x, ad.UndefinedPrimal)
def _assert_defined(x):
assert not _is_undefined(x)
def _keep_undefined(v, ct):
if _is_undefined(v):
return ct
else:
return None
@eqxi.filter_primitive_transpose(materialise_zeros=True)  # pyright: ignore
def _linear_solve_transpose(inputs, cts_out):
    """Transpose rule of `linear_solve_p`, which is linear in `vector` only.

    The cotangent of `vector` is obtained by solving against the transposed
    operator. All other inputs receive `None` cotangents.
    """
    cts_solution, _, _ = cts_out
    operator, state, vector, options, solver, _ = inputs
    # Everything except `vector` must be a concrete (non-linear) input.
    jtu.tree_map(
        _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined
    )
    cts_solution = jtu.tree_map(
        ft.partial(eqxi.materialise_zeros, allow_struct=True),
        operator.in_structure(),
        cts_solution,
    )
    transposed_operator = operator.transpose()
    transposed_state, transposed_options = solver.transpose(state, options)
    cts_vector, _, _ = eqxi.filter_primitive_bind(
        linear_solve_p,
        transposed_operator,
        transposed_state,
        cts_solution,
        transposed_options,
        solver,
        True,  # throw=True unconditionally: nowhere to pipe result to.
    )
    # Only hand back cotangents for the parts of `vector` that are undefined.
    cts_vector = jtu.tree_map(
        _keep_undefined, vector, cts_vector, is_leaf=_is_undefined
    )

    def nones_like(tree):
        return jtu.tree_map(lambda _: None, tree)

    return (
        nones_like(operator),
        nones_like(state),
        cts_vector,
        nones_like(options),
        nones_like(solver),
        None,
    )
# Define the `linear_solve` primitive from the impl/abstract-eval/JVP/transpose
# rules above.
# Call with `check_closure=False` so that the autocreated vmap rule works.
linear_solve_p = eqxi.create_vprim(
    "linear_solve",
    eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=False)),
    _linear_solve_abstract_eval,
    _linear_solve_jvp,
    _linear_solve_transpose,
)
# Then rebind so that the impl rule catches leaked-in tracers.
# This must come after `create_vprim` so that it overrides the impl set there.
linear_solve_p.def_impl(
    eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=True))
)
# NOTE(review): presumably lets equinox's jaxpr finalisation replace this primitive
# with its impl rule; confirm against `equinox.internal` documentation.
eqxi.register_impl_finalisation(linear_solve_p)
#
# linear_solve
#
# Type of the solver-specific state returned by `AbstractLinearSolver.init`.
_SolverState = TypeVar("_SolverState")


class AbstractLinearSolver(eqx.Module, Generic[_SolverState]):
    """Abstract base class for all linear solvers."""

    @abc.abstractmethod
    def init(
        self, operator: AbstractLinearOperator, options: dict[str, Any]
    ) -> _SolverState:
        """Do any initial computation on just the `operator`.

        For example, an LU solver would compute the LU decomposition of the operator
        (and this does not require knowing the vector yet).

        It is common to need to solve the linear system `Ax=b` multiple times in
        succession, with the same operator `A` and multiple vectors `b`. This method
        improves efficiency by making it possible to re-use the computation performed
        on just the operator.

        !!! Example

            ```python
            operator = lx.MatrixLinearOperator(...)
            vector1 = ...
            vector2 = ...
            solver = lx.LU()
            state = solver.init(operator, options={})
            solution1 = lx.linear_solve(operator, vector1, solver, state=state)
            solution2 = lx.linear_solve(operator, vector2, solver, state=state)
            ```

        **Arguments:**

        - `operator`: a linear operator.
        - `options`: a dictionary of any extra options that the solver may wish to
            accept.

        **Returns:**

        A PyTree of arbitrary Python objects.
        """

    @abc.abstractmethod
    def compute(
        self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Solves a linear system.

        **Arguments:**

        - `state`: as returned from [`lineax.AbstractLinearSolver.init`][].
        - `vector`: the vector to solve against.
        - `options`: a dictionary of any extra options that the solver may wish to
            accept. For example, [`lineax.CG`][] accepts a `preconditioner` option.

        **Returns:**

        A 3-tuple of:

        - The solution to the linear system.
        - An integer indicating the success or failure of the solve. This is an integer
            which may be converted to a human-readable error message via
            `lx.RESULTS[...]`.
        - A dictionary of extra statistics about the solve, e.g. the number of steps
            taken.
        """

    @abc.abstractmethod
    def transpose(
        self, state: _SolverState, options: dict[str, Any]
    ) -> tuple[_SolverState, dict[str, Any]]:
        """Transposes the result of [`lineax.AbstractLinearSolver.init`][].

        That is, it should be the case that
        ```python
        state_transpose, _ = solver.transpose(solver.init(operator, options), options)
        state_transpose2 = solver.init(operator.T, options)
        ```
        must be identical to each other.

        It is relatively common (in particular when differentiating through a linear
        solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it
        possible to avoid computing both `solver.init(operator)` and
        `solver.init(operator.T)` if one can be cheaply computed from the other.

        **Arguments:**

        - `state`: as returned from `solver.init`.
        - `options`: any extra options that were passed to `solver.init`.

        **Returns:**

        A 2-tuple of:

        - The state of the transposed operator.
        - The options for the transposed operator.
        """

    @abc.abstractmethod
    def conj(
        self, state: _SolverState, options: dict[str, Any]
    ) -> tuple[_SolverState, dict[str, Any]]:
        """Conjugate the result of [`lineax.AbstractLinearSolver.init`][].

        That is, it should be the case that
        ```python
        state_conj, _ = solver.conj(solver.init(operator, options), options)
        state_conj2 = solver.init(conj(operator), options)
        ```
        must be identical to each other.

        **Arguments:**

        - `state`: as returned from `solver.init`.
        - `options`: any extra options that were passed to `solver.init`.

        **Returns:**

        A 2-tuple of:

        - The state of the conjugated operator.
        - The options for the conjugated operator.
        """

    @abc.abstractmethod
    def assume_full_rank(self) -> bool:
        """Does this solver assume that all operators are full rank?

        When `False`, a more expensive backward pass is needed to account for
        the extra generality. In a custom linear solver, it is always safe to
        return False.

        **Arguments:**

        Nothing.

        **Returns:**

        Either `True` or `False`.
        """
# Tokens recording which concrete solver `AutoLinearSolver` selected. `init` stores
# the token alongside the selected solver's state, and `_lookup` maps a token back
# to a solver instance.
_qr_token = eqxi.str2jax("qr_token")
_diagonal_token = eqxi.str2jax("diagonal_token")
# Diagonal solver constructed with `well_posed=True` (see `_lookup`).
_well_posed_diagonal_token = eqxi.str2jax("well_posed_diagonal_token")
_tridiagonal_token = eqxi.str2jax("tridiagonal_token")
_triangular_token = eqxi.str2jax("triangular_token")
_cholesky_token = eqxi.str2jax("cholesky_token")
_lu_token = eqxi.str2jax("lu_token")
_svd_token = eqxi.str2jax("svd_token")
# Ugly delayed import because we have the dependency chain
# linear_solve -> AutoLinearSolver -> {Cholesky,...} -> AbstractLinearSolver
# but we want linear_solve and AbstractLinearSolver in the same file.
def _lookup(token) -> AbstractLinearSolver:
    """Map a solver token (as chosen by `AutoLinearSolver`) to a solver instance.

    Raises `KeyError` for an unrecognised token.
    """
    from . import _solver

    token_solver_pairs = (
        (_qr_token, _solver.QR()),
        (_diagonal_token, _solver.Diagonal()),
        (_well_posed_diagonal_token, _solver.Diagonal(well_posed=True)),
        (_tridiagonal_token, _solver.Tridiagonal()),
        (_triangular_token, _solver.Triangular()),
        (_cholesky_token, _solver.Cholesky()),
        (_lu_token, _solver.LU()),
        (_svd_token, _solver.SVD()),
    )
    # pyright doesn't know that these keys are hashable
    solvers_by_token = dict(token_solver_pairs)  # pyright: ignore
    return solvers_by_token[token]
# State of `AutoLinearSolver`: (solver token, state of the selected solver).
_AutoLinearSolverState: TypeAlias = tuple[Any, Any]


def _structured_square_token(operator: AbstractLinearOperator, diagonal_token):
    """Pick a solver token for a square operator based on its structure.

    Checks run from most to least specialised: diagonal (returns `diagonal_token`),
    tridiagonal, triangular, positive/negative semidefinite (Cholesky), and LU
    otherwise.
    """
    if is_diagonal(operator):
        return diagonal_token
    if is_tridiagonal(operator):
        return _tridiagonal_token
    if is_lower_triangular(operator) or is_upper_triangular(operator):
        return _triangular_token
    if is_positive_semidefinite(operator) or is_negative_semidefinite(operator):
        return _cholesky_token
    return _lu_token


class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]):
    """Automatically determines a good linear solver based on the structure of the
    operator.

    - If `well_posed=True`:
        - If the operator is diagonal, then use [`lineax.Diagonal`][].
        - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][].
        - If the operator is triangular, then use [`lineax.Triangular`][].
        - If the matrix is positive or negative (semi-)definite, then use
            [`lineax.Cholesky`][].
        - Else use [`lineax.LU`][].

        This is a good choice if you want to be certain that an error is raised for
        ill-posed systems.

    - If `well_posed=False`:
        - If the operator is diagonal, then use [`lineax.Diagonal`][].
        - Else use [`lineax.SVD`][].

        This is a good choice if you want to be certain that you can handle ill-posed
        systems.

    - If `well_posed=None`:
        - If the operator is non-square, then use [`lineax.QR`][].
        - If the operator is diagonal, then use [`lineax.Diagonal`][].
        - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][].
        - If the operator is triangular, then use [`lineax.Triangular`][].
        - If the matrix is positive or negative (semi-)definite, then use
            [`lineax.Cholesky`][].
        - Else, use [`lineax.LU`][].

        This is a good choice if your primary concern is computational efficiency. It will
        handle ill-posed systems as long as it is not computationally expensive to do so.
    """

    well_posed: bool | None

    def _select_solver(self, operator: AbstractLinearOperator):
        """Return the token of the solver to dispatch to for `operator`.

        Raises `ValueError` for a non-square operator when `well_posed=True`, or for
        an invalid `well_posed` value.
        """
        if self.well_posed is True:
            if operator.in_size() != operator.out_size():
                raise ValueError(
                    "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square "
                    "operator. If you are trying to solve a least-squares problem then "
                    "you should pass `solver=AutoLinearSolver(well_posed=False)`. By "
                    "default `lineax.linear_solve` assumes that the operator is "
                    "square and nonsingular."
                )
            return _structured_square_token(operator, _well_posed_diagonal_token)
        if self.well_posed is False:
            if is_diagonal(operator):
                return _diagonal_token
            # TODO: use rank-revealing QR instead.
            return _svd_token
        if self.well_posed is None:
            if operator.in_size() != operator.out_size():
                return _qr_token
            return _structured_square_token(operator, _diagonal_token)
        raise ValueError(f"Invalid value `well_posed={self.well_posed}`.")

    def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver:
        """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to.

        **Arguments:**

        - `operator`: a linear operator.

        **Returns:**

        The linear solver that will be used.
        """
        return _lookup(self._select_solver(operator))

    def init(self, operator, options) -> _AutoLinearSolverState:
        token = self._select_solver(operator)
        return token, _lookup(token).init(operator, options)

    def compute(
        self,
        state: _AutoLinearSolverState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        token, state = state
        solver = _lookup(token)
        # The selected solver's stats are deliberately not forwarded.
        solution, result, _ = solver.compute(state, vector, options)
        return solution, result, {}

    def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]):
        token, state = state
        solver = _lookup(token)
        transpose_state, transpose_options = solver.transpose(state, options)
        transpose_state = (token, transpose_state)
        return transpose_state, transpose_options

    def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]):
        token, state = state
        solver = _lookup(token)
        conj_state, conj_options = solver.conj(state, options)
        conj_state = (token, conj_state)
        return conj_state, conj_options

    def assume_full_rank(self):
        # Only `well_posed=False` may select a rank-deficient-capable solver.
        return self.well_posed is not False


AutoLinearSolver.__init__.__doc__ = """**Arguments:**
- `well_posed`: whether to only handle well-posed systems or not, as discussed above.
"""
# TODO(kidger): gmres, bicgstab
# TODO(kidger): support auxiliary outputs
@eqx.filter_jit
def linear_solve(
    operator: AbstractLinearOperator,
    vector: PyTree[ArrayLike],
    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),
    *,
    options: dict[str, Any] | None = None,
    state: PyTree[Any] = sentinel,
    throw: bool = True,
) -> Solution:
    r"""Solves a linear system.

    Given an operator represented as a matrix $A$, and a vector $b$: if the operator is
    square and nonsingular (so that the problem is well-posed), then this returns the
    usual solution $x$ to $Ax = b$, defined as $A^{-1}b$.

    If the operator is overdetermined, then this either returns the least-squares
    solution $\min_x \| Ax - b \|_2$, or throws an error. (Depending on the choice of
    solver.)

    If the operator is underdetermined, then this either returns the minimum-norm
    solution $\min_x \|x\|_2 \text{ subject to } Ax = b$, or throws an error. (Depending
    on the choice of solver.)

    !!! info

        This function is equivalent to either `numpy.linalg.solve`, or to its
        generalisation `numpy.linalg.lstsq`, depending on the choice of solver.

    The default solver is `lineax.AutoLinearSolver(well_posed=True)`. This
    automatically selects a solver depending on the structure (e.g. triangular) of your
    problem, and will throw an error if your system is overdetermined or
    underdetermined.

    Use `lineax.AutoLinearSolver(well_posed=False)` if your system is known to be
    overdetermined or underdetermined (although handling this case implies greater
    computational cost).

    !!! tip

        These three kinds of solution to a linear system are collectively known as the
        "pseudoinverse solution" to a linear system. That is, given our matrix $A$, let
        $A^\dagger$ denote the
        [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse)
        of $A$. Then the usual/least-squares/minimum-norm solutions are all equal to
        $A^\dagger b$.

    **Arguments:**

    - `operator`: a linear operator. This is the '$A$' in '$Ax = b$'.

        Most frequently this operator is simply represented as a JAX matrix (i.e. a
        rank-2 JAX array), but any [`lineax.AbstractLinearOperator`][] is supported.

        Note that if it is a matrix, then it should be passed as an
        [`lineax.MatrixLinearOperator`][], e.g.
        ```python
        matrix = jax.random.normal(key, (5, 5))  # JAX array of shape (5, 5)
        operator = lx.MatrixLinearOperator(matrix)  # Wrap into a linear operator
        solution = lx.linear_solve(operator, ...)
        ```
        rather than being passed directly.

    - `vector`: the vector to solve against. This is the '$b$' in '$Ax = b$'.
    - `solver`: the solver to use. Should be any [`lineax.AbstractLinearSolver`][].
        The default is [`lineax.AutoLinearSolver`][] which behaves as discussed
        above.

        If the operator is overdetermined or underdetermined, then passing
        [`lineax.SVD`][] is typical.

    - `options`: Individual solvers may accept additional runtime arguments; for example
        [`lineax.CG`][] allows for specifying a preconditioner. See each individual
        solver's documentation for more details. Keyword only argument.
    - `state`: If performing multiple linear solves with the same operator, then some
        computation can be saved by recording and reusing some information; for example
        the matrix factorisation of the operator. This value should be the result of
        calling [`lineax.AbstractLinearSolver.init`][] on the provided `operator`.
        If provided, then the underlying `operator` must still be passed to
        `linear_solve`.
        Keyword only argument.
    - `throw`: How to report any failures. (E.g. an iterative solver running out of
        steps, or a well-posed-only solver being run with a singular operator.)

        If `True` then a failure will raise an error. Note that errors are only reliably
        raised on CPUs. If on GPUs then the error may only be printed to stderr, whilst
        on TPUs then the behaviour is undefined.

        If `False` then the returned solution object will have a `result` field
        indicating whether any failures occurred. (See [`lineax.Solution`][].)

        Keyword only argument.

    **Returns:**

    A [`lineax.Solution`][] object containing the solution to the linear system.
    """  # noqa: E501
    if eqx.is_array(operator):
        raise ValueError(
            "`lineax.linear_solve(operator=...)` should be an "
            "`AbstractLinearOperator`, not a raw JAX array. If you are trying to pass "
            "a matrix then this should be passed as "
            "`lineax.MatrixLinearOperator(matrix)`."
        )
    if options is None:
        options = {}
    vector = jtu.tree_map(inexact_asarray, vector)
    vector_struct = strip_weak_dtype(jax.eval_shape(lambda: vector))
    operator_out_structure = strip_weak_dtype(operator.out_structure())
    # `is` to handle tracers
    if eqx.tree_equal(vector_struct, operator_out_structure) is not True:
        raise ValueError(
            "Vector and operator structures do not match. Got a vector with structure "
            f"{vector_struct} and an operator with out-structure "
            f"{operator_out_structure}"
        )
    # Solving against the identity is trivial: return the vector unchanged.
    if isinstance(operator, IdentityLinearOperator):
        return Solution(
            value=vector,
            result=RESULTS.successful,
            state=state,
            stats={},
        )
    # Identity check rather than `==`. `state` may be an arbitrary PyTree, including
    # a bare array or a tracer, for which `==` is elementwise or raises.
    if state is sentinel:
        # Build the solver state from a gradient-stopped copy of the operator, and
        # also stop gradients through the resulting state's array leaves.
        dynamic_operator, static_operator = eqx.partition(operator, eqx.is_array)
        stopped_operator = eqx.combine(
            lax.stop_gradient(dynamic_operator), static_operator
        )
        state = solver.init(stopped_operator, options)
        dynamic_state, static_state = eqx.partition(state, eqx.is_array)
        dynamic_state = lax.stop_gradient(dynamic_state)
        state = eqx.combine(dynamic_state, static_state)
    options = eqxi.nondifferentiable(
        options, name="`lineax.linear_solve(..., options=...)`"
    )
    solver = eqxi.nondifferentiable(
        solver, name="`lineax.linear_solve(..., solver=...)`"
    )
    solution, result, stats = eqxi.filter_primitive_bind(
        linear_solve_p, operator, state, vector, options, solver, throw
    )
    # TODO: prevent forward-mode autodiff through stats
    stats = eqxi.nondifferentiable_backward(stats)
    return Solution(value=solution, result=result, state=state, stats=stats)
def invert(
    operator: AbstractLinearOperator,
    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),
    *,
    options: dict[str, Any] | None = None,
    throw: bool = True,
) -> FunctionLinearOperator:
    r"""Returns a [`lineax.FunctionLinearOperator`][] representing the
    (pseudo)inverse of `operator`.

    `invert(A).mv(v)` is equivalent to `linear_solve(A, v, solver).value`.
    See [`lineax.linear_solve`][] for details on how the solution is defined
    for square, overdetermined, and underdetermined systems.

    The returned operator fully supports AD (both forward and reverse mode),
    `vmap`, and composition with other operators.

    **Arguments:**

    - `operator`: the linear operator to invert.
    - `solver`: the linear solver to use. Defaults to
        `AutoLinearSolver(well_posed=True)`.
    - `options`: additional options passed to the solver. Defaults to `None`.
    - `throw`: as [`lineax.linear_solve`][]. Defaults to `True`.

    **Returns:**

    A [`lineax.FunctionLinearOperator`][] whose `mv` solves `operator @ x = v`.
    """
    solve_options = {} if options is None else options
    # Initialise the solver once. Every later `mv` reuses this state instead of
    # re-initialising the solver.
    solver_state = solver.init(operator, solve_options)

    def solve_fn(vector):
        solution = linear_solve(
            operator,
            vector,
            solver,
            state=solver_state,
            options=solve_options,
            throw=throw,
        )
        return solution.value

    # These structural properties carry over from `operator` to its inverse.
    preserved_properties = (
        (is_symmetric, symmetric_tag),
        (is_diagonal, diagonal_tag),
        (is_lower_triangular, lower_triangular_tag),
        (is_upper_triangular, upper_triangular_tag),
        (is_positive_semidefinite, positive_semidefinite_tag),
        (is_negative_semidefinite, negative_semidefinite_tag),
    )
    inverse_tags = set()
    for predicate, tag in preserved_properties:
        if predicate(operator):
            inverse_tags.add(tag)
    # A unit diagonal is kept only for diagonal/triangular operators.
    if has_unit_diagonal(operator):
        if (
            is_diagonal(operator)
            or is_lower_triangular(operator)
            or is_upper_triangular(operator)
        ):
            inverse_tags.add(unit_diagonal_tag)
    # The inverse maps the operator's output space back to its input space.
    return FunctionLinearOperator(
        solve_fn, operator.out_structure(), frozenset(inverse_tags)
    )
# Work around JAX issue #22011,
# as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365
def stop_gradient_transpose(ct, x):
    """Transpose rule for `stop_gradient_p`.

    The incoming cotangent `ct` is passed straight through, unchanged, as the sole
    output cotangent. The primal input `x` is not used.
    """
    del x
    passthrough = (ct,)
    return passthrough
# Install the pass-through `stop_gradient_transpose` rule as the JAX transpose rule
# for `stop_gradient_p` (module-level side effect, runs at import time).
ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose
================================================
FILE: lineax/_solver/__init__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bicgstab import BiCGStab as BiCGStab
from .cg import CG as CG, NormalCG as NormalCG
from .cholesky import Cholesky as Cholesky
from .diagonal import Diagonal as Diagonal
from .gmres import GMRES as GMRES
from .lsmr import LSMR as LSMR
from .lu import LU as LU
from .normal import Normal as Normal
from .qr import QR as QR
from .svd import SVD as SVD
from .triangular import Triangular as Triangular
from .tridiagonal import Tridiagonal as Tridiagonal
================================================
FILE: lineax/_solver/bicgstab.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from typing import Any, TypeAlias
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, PyTree
from .._norm import max_norm, tree_dot
from .._operator import AbstractLinearOperator, conj, linearise
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import preconditioner_and_y0
_BiCGStabState: TypeAlias = AbstractLinearOperator


class BiCGStab(AbstractLinearSolver[_BiCGStabState]):
    """Biconjugate gradient stabilised method for linear systems.

    The operator should be square.

    Equivalent to `jax.scipy.sparse.linalg.bicgstab`.

    This supports the following `options` (as passed to
    `lx.linear_solve(..., options=...)`).

    - `preconditioner`: A [`lineax.AbstractLinearOperator`][]
        to be used as a preconditioner. Defaults to
        [`lineax.IdentityLinearOperator`][]. This method uses right preconditioning.
    - `y0`: The initial estimate of the solution to the linear system. Defaults to all
        zeros.
    """

    rtol: float
    atol: float
    norm: Callable = max_norm
    max_steps: int | None = None

    def __check_init__(self):
        # Only Python scalars are validated; traced/array tolerances pass through.
        if isinstance(self.rtol, (int, float)) and self.rtol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and self.atol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):
            # With no tolerance and no step cap the loop would have no stop criterion.
            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:
                raise ValueError(
                    "Must specify `rtol`, `atol`, or `max_steps` (or some combination "
                    "of all three)."
                )

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        # The solver state is simply the linearised operator.
        if operator.in_structure() != operator.out_structure():
            raise ValueError(
                "`BiCGstab(..., normal=False)` may only be used for linear solves with "
                "square matrices."
            )
        return linearise(operator)

    def compute(
        self, state: _BiCGStabState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        operator = state
        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
        leaves, _ = jtu.tree_flatten(vector)
        if self.max_steps is None:
            size = sum(leaf.size for leaf in leaves)
            max_steps = 10 * size
        else:
            max_steps = self.max_steps
        # `has_scale` is False only when both tolerances are literally zero, in which
        # case the solve runs for exactly `max_steps` iterations (barring breakdown).
        has_scale = not (
            isinstance(self.atol, (int, float))
            and isinstance(self.rtol, (int, float))
            and self.atol == 0
            and self.rtol == 0
        )
        if has_scale:
            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω

        # This implementation is the same a jax.scipy.sparse.linalg.bicgstab
        # but with AbstractLinearOperator.
        # We use the notation found on the wikipedia except with y instead of x:
        # https://en.wikipedia.org/wiki/
        # Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB
        # preconditioner in this case is K2^(-1) (i.e., right preconditioning)

        r0 = (vector**ω - operator.mv(y0) ** ω).ω

        def breakdown_occurred(omega, alpha, rho):
            # BiCGStab breaks down when any of these scalars vanishes. They can
            # legitimately be negative (e.g. for indefinite or nonsymmetric
            # operators), so compare magnitudes rather than signed values;
            # a signed `< 1e-16` would flag every negative value as breakdown.
            # Empirically, the tolerance checks for breakdown are very tight.
            # These specific tolerances are heuristic.
            if jax.config.jax_enable_x64:  # pyright: ignore
                return (omega == 0.0) | (alpha == 0.0) | (rho == 0.0)
            else:
                return (
                    (jnp.abs(omega) < 1e-16)
                    | (jnp.abs(alpha) < 1e-16)
                    | (jnp.abs(rho) < 1e-16)
                )

        def not_converged(r, diff, y):
            # The primary tolerance check.
            # Given Ay=b, then we have to be doing better than `scale` in both
            # the `y` and the `b` spaces.
            if has_scale:
                with jax.numpy_dtype_promotion("standard"):
                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
                    norm1 = self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore
                    norm2 = self.norm((diff**ω / y_scale**ω).ω)
                return (norm1 > 1) | (norm2 > 1)
            else:
                return True

        def cond_fun(carry):
            y, r, alpha, omega, rho, _, _, diff, step = carry
            out = jnp.invert(breakdown_occurred(omega, alpha, rho))
            out = out & not_converged(r, diff, y)
            out = out & (step < max_steps)
            return out

        def body_fun(carry):
            y, r, alpha, omega, rho, p, v, diff, step = carry

            rho_new = tree_dot(r0, r)
            beta = (rho_new / rho) * (alpha / omega)
            p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω

            # TODO(raderj): reduce this to a single operator.mv call
            # by using the scan trick.
            x = preconditioner.mv(p_new)
            v_new = operator.mv(x)

            alpha_new = rho_new / tree_dot(r0, v_new)
            s = (r**ω - alpha_new * v_new**ω).ω

            z = preconditioner.mv(s)
            t = operator.mv(z)

            omega_new = tree_dot(s, t) / tree_dot(t, t)

            diff = (alpha_new * x**ω + omega_new * z**ω).ω
            y_new = (y**ω + diff**ω).ω
            r_new = (s**ω - omega_new * t**ω).ω
            return (
                y_new,
                r_new,
                alpha_new,
                omega_new,
                rho_new,
                p_new,
                v_new,
                diff,
                step + 1,
            )

        p0 = v0 = jtu.tree_map(jnp.zeros_like, vector)
        alpha = omega = rho = jnp.array(1.0)
        init_carry = (
            y0,
            r0,
            alpha,
            omega,
            rho,
            p0,
            v0,
            # An all-inf initial step ensures the `y`-space check cannot pass before
            # the first iteration.
            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,
            0,
        )

        solution, residual, alpha, omega, rho, _, _, diff, num_steps = lax.while_loop(
            cond_fun, body_fun, init_carry
        )

        # Hitting the default (implicit) step cap is reported as `singular`; hitting
        # a user-specified cap is reported as `max_steps_reached`.
        if self.max_steps is None:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.singular, RESULTS.successful
            )
        elif has_scale:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
            )
        else:
            result = RESULTS.successful
        # breakdown is only an issue if we did not converge
        breakdown = breakdown_occurred(omega, alpha, rho) & not_converged(
            residual, diff, solution
        )
        result = RESULTS.where(breakdown, RESULTS.breakdown, result)

        stats = {"num_steps": num_steps, "max_steps": self.max_steps}
        return solution, result, stats

    def transpose(self, state: _BiCGStabState, options: dict[str, Any]):
        # Transpose both the operator and (if given) the preconditioner.
        transpose_options = {}
        if "preconditioner" in options:
            transpose_options["preconditioner"] = options["preconditioner"].transpose()
        operator = state
        return operator.transpose(), transpose_options

    def conj(self, state: _BiCGStabState, options: dict[str, Any]):
        # Conjugate both the operator and (if given) the preconditioner.
        conj_options = {}
        if "preconditioner" in options:
            conj_options["preconditioner"] = conj(options["preconditioner"])
        operator = state
        return conj(operator), conj_options

    def assume_full_rank(self):
        return True
# Documentation for the constructor's fields (rtol/atol/norm/max_steps), attached to
# `__init__` after the class body.
BiCGStab.__init__.__doc__ = r"""**Arguments:**
- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance.
Defaults to the max norm.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps
than this are required, then the solve is halted with a failure.
"""
================================================
FILE: lineax/_solver/cg.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections.abc import Callable
from typing import Any, TypeAlias
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, PyTree, Scalar
from .._misc import resolve_rcond, structure_equal, tree_where
from .._norm import max_norm, tree_dot
from .._operator import (
AbstractLinearOperator,
conj,
is_negative_semidefinite,
is_positive_semidefinite,
linearise,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import preconditioner_and_y0
from .normal import Normal
# State: the linearised (possibly negated) operator, plus a static flag recording
# whether it was negated because the original operator was negative semidefinite.
_CGState: TypeAlias = tuple[AbstractLinearOperator, eqxi.Static]


# TODO(kidger): this is pretty slow to compile.
# - CG evaluates `operator.mv` three times.
# Possibly this can be cheapened a bit somehow?
class CG(AbstractLinearSolver[_CGState]):
    """Conjugate gradient solver for linear systems.

    The operator should be positive or negative definite.

    Equivalent to `scipy.sparse.linalg.cg`.

    This supports the following `options` (as passed to
    `lx.linear_solve(..., options=...)`).

    - `preconditioner`: A positive definite [`lineax.AbstractLinearOperator`][]
        to be used as preconditioner. Defaults to
        [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
        so it is the preconditioned residual that is minimized, though the actual
        termination criteria uses the un-preconditioned residual.
    - `y0`: The initial estimate of the solution to the linear system. Defaults to all
        zeros.
    """

    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar] = max_norm
    stabilise_every: int | None = 10
    max_steps: int | None = None

    def __check_init__(self):
        # Only Python scalars are validated; traced/array tolerances pass through.
        if isinstance(self.rtol, (int, float)) and self.rtol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and self.atol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):
            # With no tolerance and no step cap the loop would have no stop criterion.
            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:
                raise ValueError(
                    "Must specify `rtol`, `atol`, or `max_steps` (or some combination "
                    "of all three)."
                )

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Validate the operator and return `(linearised_operator, Static(is_nsd))`.

        Raises `ValueError` if the operator is not square, or is tagged neither
        positive nor negative semidefinite.
        """
        del options
        is_nsd = is_negative_semidefinite(operator)
        if not structure_equal(operator.in_structure(), operator.out_structure()):
            raise ValueError(
                "`CG()` may only be used for linear solves with square matrices."
            )
        if not (is_positive_semidefinite(operator) | is_nsd):
            raise ValueError(
                "`CG()` may only be used for positive "
                "or negative definite linear operators"
            )
        # A negative definite system A y = b is solved as (-A) y' = b. `compute`
        # then negates the result, since y = -y'.
        if is_nsd:
            operator = -operator
        operator = linearise(operator)
        return operator, eqxi.Static(is_nsd)

    # This differs from jax.scipy.sparse.linalg.cg in:
    # 1. Every few steps we calculate the residual directly, rather than by cheaply
    # using the existing quantities. This improves numerical stability.
    # 2. We use a more sophisticated termination condition. To begin with we have an
    # rtol and atol in the conventional way, inducing a vector-valued scale. This is
    # then checked in both the `y` and `b` domains (for `Ay = b`).
    # 3. We return the number of steps, and whether or not the solve succeeded, as
    # additional information.
    # 4. We don't try to support complex numbers. (Yet.)
    def compute(
        self, state: _CGState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        operator, is_nsd = state
        is_nsd = is_nsd.value
        # NOTE(review): when `is_nsd`, `operator` here is already negated, so a
        # user-supplied `y0` (a guess for the original system) seems to have the
        # wrong sign for the negated system; confirm `preconditioner_and_y0` does
        # not already handle this.
        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
        if not is_positive_semidefinite(preconditioner):
            raise ValueError("The preconditioner must be positive definite.")
        leaves, _ = jtu.tree_flatten(vector)
        size = sum(leaf.size for leaf in leaves)
        if self.max_steps is None:
            max_steps = 10 * size  # Copied from SciPy!
        else:
            max_steps = self.max_steps
        r0 = (vector**ω - operator.mv(y0) ** ω).ω
        p0 = preconditioner.mv(r0)
        gamma0 = tree_dot(p0, r0)
        # Threshold for the `alpha` guard in `body_fun`.
        rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))
        # Loop carry: (diff, y, r, p, gamma, step). An all-inf initial `diff` keeps
        # the `y`-space convergence check from passing before the first step.
        initial_value = (
            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,
            y0,
            r0,
            p0,
            gamma0,
            0,
        )
        # False only when both tolerances are literally zero.
        has_scale = not (
            isinstance(self.atol, (int, float))
            and isinstance(self.rtol, (int, float))
            and self.atol == 0
            and self.rtol == 0
        )
        if has_scale:
            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω

        def not_converged(r, diff, y):
            # The primary tolerance check.
            # Given Ay=b, then we have to be doing better than `scale` in both
            # the `y` and the `b` spaces.
            if has_scale:
                with jax.numpy_dtype_promotion("standard"):
                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
                    norm1 = self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore
                    norm2 = self.norm((diff**ω / y_scale**ω).ω)
                return (norm1 > 1) | (norm2 > 1)
            else:
                return True

        def cond_fun(value):
            diff, y, r, _, gamma, step = value
            # `gamma` is <z, r>. The loop stops once it is not positive; a NaN gamma
            # (from the `alpha` guard below) also fails this test.
            out = gamma > 0
            out = out & (step < max_steps)
            out = out & not_converged(r, diff, y)
            return out

        def body_fun(value):
            _, y, r, p, gamma, step = value
            mat_p = operator.mv(p)
            inner_prod = tree_dot(mat_p, p)
            alpha = gamma / inner_prod
            # If <Ap, p> is negligible relative to gamma, replace alpha with NaN
            # rather than taking an enormous step. The NaN propagates into gamma
            # and ends the loop.
            alpha = tree_where(
                jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma),  # pyright: ignore
                alpha,
                jnp.nan,  # pyright: ignore
            )
            diff = (alpha * p**ω).ω
            y = (y**ω + diff**ω).ω
            step = step + 1

            # E.g. see B.2 of
            # https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf
            # We compute the residual the "expensive" way every now and again, so as to
            # correct numerical rounding errors.
            def stable_r():
                return (vector**ω - operator.mv(y) ** ω).ω

            def cheap_r():
                return (r**ω - alpha * mat_p**ω).ω

            if self.stabilise_every == 1:
                r = stable_r()
            elif self.stabilise_every is None:
                r = cheap_r()
            else:
                # Presumably `unvmap_max` + `nonbatchable` keep the `lax.cond`
                # predicate unbatched under `vmap`, so that it remains a real branch
                # rather than a select over both; confirm with equinox.internal docs.
                stable_step = (eqxi.unvmap_max(step) % self.stabilise_every) == 0
                stable_step = eqxi.nonbatchable(stable_step)
                r = lax.cond(stable_step, stable_r, cheap_r)

            z = preconditioner.mv(r)
            gamma_prev = gamma
            gamma = tree_dot(z, r)
            beta = gamma / gamma_prev
            p = (z**ω + beta * p**ω).ω
            return diff, y, r, p, gamma, step

        _, solution, _, _, _, num_steps = lax.while_loop(
            cond_fun, body_fun, initial_value
        )

        # Hitting the default (implicit) step cap is reported as `singular`; hitting
        # a user-specified cap is reported as `max_steps_reached`.
        if self.max_steps is None:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.singular, RESULTS.successful
            )
        elif has_scale:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
            )
        else:
            result = RESULTS.successful

        # Undo the negation applied in `init` for negative semidefinite operators.
        if is_nsd:
            solution = -(solution**ω).ω
        stats = {"num_steps": num_steps, "max_steps": self.max_steps}
        return solution, result, stats

    def transpose(
        self, state: _CGState, options: dict[str, Any]
    ) -> tuple[_CGState, dict[str, Any]]:
        # Transpose the operator and (if given) the preconditioner; `is_nsd` is kept.
        transpose_options = {}
        if "preconditioner" in options:
            transpose_options["preconditioner"] = options["preconditioner"].transpose()
        psd_op, is_nsd = state
        transpose_state = psd_op.transpose(), is_nsd
        return transpose_state, transpose_options

    def conj(
        self, state: _CGState, options: dict[str, Any]
    ) -> tuple[_CGState, dict[str, Any]]:
        # Conjugate the operator and (if given) the preconditioner; `is_nsd` is kept.
        conj_options = {}
        if "preconditioner" in options:
            conj_options["preconditioner"] = conj(options["preconditioner"])
        psd_op, is_nsd = state
        conj_state = conj(psd_op), is_nsd
        return conj_state, conj_options

    def assume_full_rank(self):
        return True
# Documentation for the constructor's fields (rtol/atol/norm/stabilise_every/
# max_steps), attached to `__init__` after the class body.
CG.__init__.__doc__ = r"""**Arguments:**
- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance.
Defaults to the max norm.
- `stabilise_every`: The conjugate gradient is an iterative method that produces
candidate solutions $x_1, x_2, \ldots$, and terminates once $r_i = \| Ax_i - b \|$
is small enough. For computational efficiency, the values $r_i$ are computed using
other internal quantities, and not by directly evaluating the formula above.
However, this computation of $r_i$ is susceptible to drift due to limited
floating-point precision. Every `stabilise_every` steps, then $r_i$ is computed
directly using the formula above, in order to stabilise the computation.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps
than this are required, then the solve is halted with a failure.
"""
def NormalCG(*args, **kwargs):
    """Deprecated helper function. Use `lx.Normal(lx.CG(...))` instead.

    !!! warning "Deprecated"

        `NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`.
        This will be removed in some future version of Lineax.
    """
    deprecation_message = (
        "`NormalCG(...)` is deprecated in favour of `lx.Normal(lx.CG(...))`. "
        "This will be removed in some future version of Lineax."
    )
    # stacklevel=2 attributes the warning to the caller of `NormalCG`.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    inner_solver = CG(*args, **kwargs)
    return Normal(inner_solver)
================================================
FILE: lineax/_solver/cholesky.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import equinox.internal as eqxi
import jax.flatten_util as jfu
import jax.scipy as jsp
from jaxtyping import Array, PyTree
from .._operator import (
AbstractLinearOperator,
is_negative_semidefinite,
is_positive_semidefinite,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
# State: the upper-triangular Cholesky factor, plus a static flag recording whether
# the matrix was negated (negative definite input).
_CholeskyState: TypeAlias = tuple[Array, eqxi.Static]


class Cholesky(AbstractLinearSolver[_CholeskyState]):
    """Cholesky solver for linear systems. This is generally the preferred solver for
    positive or negative definite systems.

    Equivalent to `scipy.linalg.solve(..., assume_a="pos")`.

    The operator must be square, nonsingular, and either positive or negative definite.
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        del options
        negated = is_negative_semidefinite(operator)
        if not (is_positive_semidefinite(operator) | negated):
            raise ValueError(
                "`Cholesky(..., normal=False)` may only be used for positive "
                "or negative definite linear operators"
            )
        mat = operator.as_matrix()
        num_rows, num_cols = mat.shape
        if num_rows != num_cols:
            raise ValueError(
                "`Cholesky(..., normal=False)` may only be used for linear solves "
                "with square matrices"
            )
        # Cholesky factorisation needs a positive definite matrix, so a negative
        # definite A is factorised as -A and the solution negated in `compute`.
        if negated:
            mat = -mat
        factor, lower = jsp.linalg.cho_factor(mat)
        # The upper-triangular form is the one used throughout this solver.
        assert lower is False
        return factor, eqxi.Static(negated)

    def compute(
        self, state: _CholeskyState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        del options
        factor, negated_flag = state
        # Cholesky => PSD => symmetric => (in_structure == out_structure) =>
        # no packed structures are needed; a plain ravel suffices.
        flat_vector, unflatten = jfu.ravel_pytree(vector)
        flat_solution = jsp.linalg.cho_solve((factor, False), flat_vector)
        if negated_flag.value:
            flat_solution = -flat_solution
        return unflatten(flat_solution), RESULTS.successful, {}

    def transpose(
        self, state: _CholeskyState, options: dict[str, Any]
    ) -> tuple[_CholeskyState, dict[str, Any]]:
        # The matrix is self-adjoint, so its transpose equals its conjugate: only the
        # factor needs conjugating.
        factor, negated_flag = state
        conjugated_factor = factor.conj()
        return (conjugated_factor, negated_flag), options

    def conj(
        self, state: _CholeskyState, options: dict[str, Any]
    ) -> tuple[_CholeskyState, dict[str, Any]]:
        # Conjugating the (self-adjoint) matrix conjugates its Cholesky factor.
        factor, negated_flag = state
        conjugated_factor = factor.conj()
        return (conjugated_factor, negated_flag), options

    def assume_full_rank(self):
        return True
# `Cholesky` takes no constructor arguments; document that explicitly on `__init__`.
Cholesky.__init__.__doc__ = """**Arguments:**
Nothing.
"""
================================================
FILE: lineax/_solver/diagonal.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import jax.numpy as jnp
from jaxtyping import Array, PyTree
from .._misc import resolve_rcond
from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
_DiagonalState: TypeAlias = tuple[Array | None, PackedStructures]
class Diagonal(AbstractLinearSolver[_DiagonalState]):
"""Diagonal solver for linear systems.
Requires that the operator be diagonal. Then $Ax = b$, with $A = diag[a]$, is
solved simply by doing an elementwise division $x = b / a$.
This solver can handle singular operators (i.e. diagonal entries with value 0).
"""
well_posed: bool = False
rcond: float | None = None
def init(
self, operator: AbstractLinearOperator, options: dict[str, Any]
) -> _DiagonalState:
del options
if operator.in_size() != operator.out_size():
raise ValueError(
"`Diagonal` may only be used for linear solves with square matrices"
)
if not is_diagonal(operator):
raise ValueError(
"`Diagonal` may only be used for linear solves with diagonal matrices"
)
packed_structures = pack_structures(operator)
if has_unit_diagonal(operator):
return None, packed_structures
else:
return diagonal(operator), packed_structures
def compute(
self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any]
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
diag, packed_structures = state
del state, options
unit_diagonal = diag is None
vector = ravel_vector(vector, packed_structures)
if unit_diagonal:
solution = vector
else:
if not self.well_posed:
(size,) = diag.shape
rcond = resolve_rcond(self.rcond, size, size, diag.dtype)
abs_diag = jnp.abs(diag)
diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf) # pyright: ignore
solution = vector / diag
solution = unravel_solution(solution, packed_structures)
return solution, RESULTS.successful, {}
def transpose(self, state: _DiagonalState, options: dict[str, Any]):
del options
diag, packed_structures = state
transposed_packed_structures = transpose_packed_structures(packed_structures)
transpose_state = diag, transposed_packed_structures
transpose_options = {}
return transpose_state, transpose_options
def conj(self, state: _DiagonalState, options: dict[str, Any]):
del options
diag, packed_structures = state
if diag is None:
conj_diag = None
else:
conj_diag = diag.conj()
conj_options = {}
conj_state = conj_diag, packed_structures
return conj_state, conj_options
def assume_full_rank(self):
    """Whether the operator is assumed to be nonsingular (`well_posed=True`)."""
    well_posed = self.well_posed
    return well_posed
# Constructor documentation for `Diagonal`; the heading matches the
# `**Arguments:**` style used by the other solvers.
Diagonal.__init__.__doc__ = """**Arguments:**

- `well_posed`: if `False`, then singular operators are accepted, and the pseudoinverse
    solution is returned. If `True` then passing a singular operator will cause an error
    to be raised instead.
- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine
    precision times `N`, where `N` is the input (or output) size of the operator.
    Only used if `well_posed=False`.
"""
================================================
FILE: lineax/_solver/gmres.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
from collections.abc import Callable
from typing import Any, cast, TypeAlias
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Bool, Float, Inexact, PyTree
from .._misc import structure_equal
from .._norm import max_norm, two_norm
from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearSolver, linear_solve
from .misc import preconditioner_and_y0
from .qr import QR
# Solver state for `GMRES`: the linearised operator returned by `GMRES.init`.
_GMRESState: TypeAlias = AbstractLinearOperator
class GMRES(AbstractLinearSolver[_GMRESState]):
    """GMRES solver for linear systems.

    The operator should be square.

    Similar to `jax.scipy.sparse.linalg.gmres`.

    This supports the following `options` (as passed to
    `lx.linear_solve(..., options=...)`).

    - `preconditioner`: A [`lineax.AbstractLinearOperator`][]
        to be used as preconditioner. Defaults to
        [`lineax.IdentityLinearOperator`][]. This method uses left preconditioning,
        so it is the preconditioned residual that is minimized. Note that the
        termination and stagnation checks are also evaluated on this preconditioned
        residual.
    - `y0`: The initial estimate of the solution to the linear system. Defaults to all
        zeros.
    """

    rtol: float
    atol: float
    norm: Callable = max_norm
    max_steps: int | None = None
    restart: int = 20
    stagnation_iters: int = 20

    def __check_init__(self):
        # Only Python scalars can be validated here; other tolerance values
        # (e.g. arrays) are skipped.
        if isinstance(self.rtol, (int, float)) and self.rtol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and self.atol < 0:
            raise ValueError("Tolerances must be non-negative.")

        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):
            # With both tolerances zero there is no convergence test (see
            # `not_converged` in `compute`), so an explicit `max_steps` is required.
            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:
                raise ValueError(
                    "Must specify `rtol`, `atol`, or `max_steps` (or some combination "
                    "of all three)."
                )

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Check that `operator` is square and return its linearised form as state.

        Raises `ValueError` if the input and output structures differ.
        """
        del options
        if not structure_equal(operator.in_structure(), operator.out_structure()):
            raise ValueError(
                "`GMRES` may only be used for linear solves with square matrices."
            )
        return linearise(operator)

    #
    # This differs from `jax.scipy.sparse.linalg.gmres` in a few ways:
    # 1. We use a more sophisticated termination condition. To begin with we have an
    #    rtol and atol in the conventional way, inducing a vector-valued scale. This is
    #    then checked in both the `y` and `b` domains (for `Ay = b`).
    # 2. We handle in-place updates with buffers to avoid generating unnecessary
    #    copies of arrays during the Gram-Schmidt procedure.
    # 3. We use a QR solve at the end of the batched Gram-Schmidt instead
    #    of a Cholesky solve of the normal equations. This is both faster and more
    #    numerically stable.
    # 4. We use tricks to compile `A y` fewer times throughout the code, including
    #    passing a dummy initial residual.
    # 5. We return the number of steps, and whether or not the solve succeeded, as
    #    additional information.
    # 6. We do not use the unnecessary loop within Gram-Schmidt, and simply compute
    #    this in a single pass.
    # 7. We add better safety checks for breakdown, and a safety check for stagnation
    #    of the iterates even when we don't explicitly get breakdown.
    #

    def compute(
        self,
        state: _GMRESState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Solve `operator y = vector` with restarted GMRES.

        Each outer loop iteration is one restart cycle (see `_gmres_compute`).
        Returns the solution, a `RESULTS` code, and a stats dict containing
        `num_steps` and `max_steps`.
        """
        # When both tolerances are statically zero there is no convergence test, and
        # the loop stops only on `max_steps`, breakdown or stagnation.
        has_scale = not (
            isinstance(self.atol, (int, float))
            and isinstance(self.rtol, (int, float))
            and self.atol == 0
            and self.rtol == 0
        )
        if has_scale:
            b_scale = (self.atol + self.rtol * ω(vector).call(jnp.abs)).ω
        operator = state
        preconditioner, y0 = preconditioner_and_y0(operator, vector, options)
        leaves, _ = jtu.tree_flatten(vector)
        size = sum(leaf.size for leaf in leaves)
        if self.max_steps is None:
            max_steps = 10 * size  # Copied from SciPy!
        else:
            max_steps = self.max_steps
        # A Krylov subspace cannot have dimension larger than the problem size.
        restart = min(self.restart, size)

        def not_converged(r, diff, y):
            # The primary tolerance check.
            # Given Ay=b, then we have to be doing better than `scale` in both
            # the `y` and the `b` spaces.
            # NOTE: `r` is the *preconditioned* residual (see `_gmres_compute`).
            if has_scale:
                with jax.numpy_dtype_promotion("standard"):
                    y_scale = (self.atol + self.rtol * ω(y).call(jnp.abs)).ω
                    norm1 = self.norm((r**ω / b_scale**ω).ω)  # pyright: ignore
                    norm2 = self.norm((diff**ω / y_scale**ω).ω)
                return (norm1 > 1) | (norm2 > 1)
            else:
                return True

        def cond_fun(carry):
            y, r, _, deferred_breakdown, diff, _, step, stagnation_counter = carry
            # NOTE: we defer ending due to breakdown by one loop! This is nonstandard,
            # but lets us use a Cauchy-like condition in the convergence criteria.
            # If we do not defer breakdown, breakdown may detect convergence when
            # the diff between two iterations is still quite large, and we only
            # consider convergence when the diff is small.
            out = jnp.invert(deferred_breakdown) & (
                stagnation_counter < self.stagnation_iters
            )
            out = out & not_converged(r, diff, y)
            out = out & (step < max_steps)
            # The first pass uses a dummy value for r0 in order to save on compiling
            # an extra matvec. The dummy step may raise a breakdown, and `step == 0`
            # prevents us from returning prematurely.
            return out | (step == 0)

        def body_fun(carry):
            # `breakdown` -> `deferred_breakdown` and `deferred_breakdown` -> `_`
            y, r, deferred_breakdown, _, diff, r_min, step, stagnation_counter = carry
            y_new, r_new, breakdown, diff_new = self._gmres_compute(
                operator, vector, y, r, restart, preconditioner, step == 0
            )

            #
            # If the minimum residual does not decrease for many iterations
            # ("many" is determined by self.stagnation_iters) then the iterative
            # solve has stagnated and we stop the loop. This bit keeps track of how
            # long it has been since the minimum has decreased, and updates the minimum
            # when a new minimum is encountered. As far as I (raderj) am
            # aware, this is custom to our implementation and not standard practice.
            #
            r_new_norm = self.norm(r_new)
            r_decreased = (r_new_norm - r_min) < 0
            stagnation_counter = jnp.where(r_decreased, 0, stagnation_counter + 1)
            stagnation_counter = cast(Array, stagnation_counter)
            r_min = jnp.minimum(r_new_norm, r_min)

            return (
                y_new,
                r_new,
                breakdown,
                deferred_breakdown,
                diff_new,
                r_min,
                step + 1,
                stagnation_counter,
            )

        # Initialise the residual r0 to the dummy value of all 0s. This means
        # the first iteration of Gram-Schmidt will do nothing, but it saves
        # us from compiling an extra matvec here.
        r0 = ω(vector).call(jnp.zeros_like).ω

        init_carry = (
            y0,  # y
            r0,  # residual
            False,  # breakdown
            False,  # deferred_breakdown
            ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,  # diff
            jnp.inf,  # r_min
            0,  # steps
            jnp.array(0),  # stagnation counter
        )
        (
            solution,
            residual,
            _,  # breakdown
            breakdown,  # deferred_breakdown
            diff,
            _,
            num_steps,
            stagnation_counter,
        ) = lax.while_loop(cond_fun, body_fun, init_carry)

        # Hitting the default step budget is reported as `singular`, and hitting a
        # user-chosen budget as `max_steps_reached`.
        if self.max_steps is None:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.singular, RESULTS.successful
            )
        elif has_scale:
            result = RESULTS.where(
                num_steps == max_steps, RESULTS.max_steps_reached, RESULTS.successful
            )
        else:
            result = RESULTS.successful
        result = RESULTS.where(
            stagnation_counter >= self.stagnation_iters, RESULTS.stagnation, result
        )

        # breakdown is only an issue if we broke down outside the tolerance
        # of the solution. If we get breakdown and are within the tolerance,
        # this is called convergence :)
        breakdown = breakdown & not_converged(residual, diff, solution)
        # breakdown is the most serious potential issue
        result = RESULTS.where(breakdown, RESULTS.breakdown, result)

        stats = {"num_steps": num_steps, "max_steps": self.max_steps}
        return solution, result, stats

    def _gmres_compute(
        self, operator, vector, y, r, restart, preconditioner, first_pass
    ):
        """Run one restart cycle of GMRES.

        On the first (dummy) pass `y` is returned unchanged with an all-`inf` diff.
        Otherwise an Arnoldi basis of up to `restart` vectors is built from the
        normalised residual `r`. The small least-squares problem is solved with `QR`
        to get the update `diff`, and `y_new = y + diff`.

        Returns `(y_new, r_new, breakdown, diff)`, where `r_new` is the
        preconditioned residual `M (b - A y_new)`.
        """
        #
        # internal function for computing the bulk of the gmres. We separate this out
        # for two reasons:
        # 1. avoid nested body and cond functions in the body and cond function of
        #    `self.compute`. `self.compute` is primarily responsible for the restart
        #    behavior of gmres.
        # 2. Like the jax.scipy implementation we may want to add an incremental
        #    version at a later date.
        #
        def main_gmres(y):
            # see the comment at the end of `_arnoldi_gram_schmidt` for a discussion
            # of `initial_breakdown`
            r_normalised, r_norm, initial_breakdown = self._normalise(r, eps=None)
            # The basis has `restart + 1` columns; column 0 is the normalised
            # residual and the rest are filled in by the Arnoldi iteration.
            basis_init = jtu.tree_map(
                lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
                r_normalised,
            )
            coeff_mat_init = jnp.eye(
                restart,
                restart + 1,
                dtype=jnp.result_type(*jtu.tree_leaves(r_normalised)),
            )

            def cond_fun(carry):
                _, _, breakdown, step = carry
                return (step < restart) & jnp.invert(breakdown)

            def body_fun(carry):
                basis, coeff_mat, breakdown, step = carry
                basis_new, coeff_mat_new, breakdown = self._arnoldi_gram_schmidt(
                    operator,
                    preconditioner,
                    basis,
                    coeff_mat,
                    step,
                    restart,
                    vector,
                    breakdown,
                )
                return basis_new, coeff_mat_new, breakdown, step + 1

            def buffers(carry):
                basis, coeff_mat, _, _ = carry
                return basis, coeff_mat

            init_carry = (basis_init, coeff_mat_init, initial_breakdown, 0)
            basis, coeff_mat, breakdown, steps = eqxi.while_loop(
                cond_fun, body_fun, init_carry, kind="lax", buffers=buffers
            )
            # Right-hand side `|r| e_1` of the small least-squares problem.
            beta_vec = jnp.concatenate(
                (
                    r_norm[None].astype(jnp.result_type(coeff_mat)),
                    jnp.zeros_like(coeff_mat, shape=(restart,)),
                )
            )
            coeff_op_transpose = MatrixLinearOperator(coeff_mat.T)
            # TODO(raderj): move to a Hessenberg-specific solver
            z = linear_solve(coeff_op_transpose, beta_vec, QR(), throw=False).value
            diff = jtu.tree_map(
                lambda mat: jnp.tensordot(
                    mat[..., :-1], z, axes=1, precision=lax.Precision.HIGHEST
                ),
                basis,
            )
            y_new = (y**ω + diff**ω).ω
            return y_new, diff, breakdown

        def first_gmres(y):
            return y, ω(y).call(lambda x: jnp.full_like(x, jnp.inf)).ω, False

        # NOTE(review): `unvmap_any` reduces the predicate across any vmapped batch
        # so `lax.cond` sees a single unbatched predicate - presumably so it is not
        # lowered to a `select` that evaluates both branches. Confirm.
        first_pass = eqxi.unvmap_any(first_pass)
        y_new, diff, breakdown = lax.cond(first_pass, first_gmres, main_gmres, y)
        r_new = preconditioner.mv((vector**ω - operator.mv(y_new) ** ω).ω)
        return y_new, r_new, breakdown, diff

    # NOTE: in the jax implementation:
    # https://github.com/google/jax/blob/
    # c662fd216dec10cdb2cff4138b4318bb98853134/jax/_src/scipy/sparse/linalg.py#L327
    # _classical_iterative_gram_schmidt uses a while loop to call this.
    # However, max_iterations is set to 2 in all calls they make to the function,
    # and the condition function requires steps < (max_iterations - 1).
    # This means that in fact they only apply Gram-Schmidt once, and using a
    # while_loop is unnecessary.
    def _arnoldi_gram_schmidt(
        self,
        operator,
        preconditioner,
        basis,
        coeff_mat,
        step,
        restart,
        vector,
        initial_breakdown,
    ):
        """Perform one Arnoldi step.

        Applies the preconditioned operator to basis column `step` and
        orthogonalises the result against the basis with one pass of classical
        Gram-Schmidt. The normalised result is stored as column `step + 1`, and the
        projection coefficients are written into row `step` of `coeff_mat` (unless
        `initial_breakdown`).

        Returns `(basis_new, coeff_mat_new, breakdown)`.
        """
        #
        # compute `basis.T @ basis_step` for each leaf of pytree
        # and then compute the projected vector onto the basis
        #
        # `basis` is a pytree with buffers, meaning it can only be
        # indexed into. Through this section, there are terms like `lambda _, x: ...`
        # because `jtu.tree_map` only uses the first argument to determine the shape
        # of the pytree. Since _Buffer is considered part of the pytree
        # structure, we get leaves which are not buffers if we directly pass `basis`.
        # Instead, we make sure that the first argument of the tree map is something
        # with the correct pytree structure, such as `vector` in the dummy case and
        # basis_step when not, so that we correctly index into `basis`.
        #
        basis_step = preconditioner.mv(
            operator.mv(jtu.tree_map(lambda _, x: x[..., step], vector, basis))
        )
        step_norm = two_norm(basis_step)
        contract_matrix = lambda x, y: ft.partial(
            jnp.tensordot, axes=x.ndim, precision=lax.Precision.HIGHEST
        )(x, y[...].conj())
        _proj = jtu.tree_map(contract_matrix, basis_step, basis)
        proj = jtu.tree_reduce(lambda x, y: x + y, _proj)
        proj_on_cols = jtu.tree_map(lambda _, x: x[...] @ proj, vector, basis)
        # now remove the component of the vector in that subspace
        basis_step_new = (basis_step**ω - proj_on_cols**ω).ω
        # Breakdown threshold is relative to the norm before orthogonalisation.
        eps = step_norm * jnp.finfo(proj.dtype).eps
        basis_step_normalised, step_norm_new, breakdown = self._normalise(
            basis_step_new, eps=eps
        )
        basis_new = jtu.tree_map(
            lambda y, mat: mat.at[..., step + 1].set(y),
            basis_step_normalised,
            basis,
        )
        proj_new = proj.at[step + 1].set(step_norm_new.astype(jnp.result_type(proj)))
        #
        # NOTE: two somewhat complicated things are going on here:
        #
        # The `coeff_mat` in_place update has a batch tracer, so we need to be
        # careful and wrap it in a buffer, hence the use of eqxi.while_loop
        # instead of lax.while_loop throughout.
        #
        # `initial_breakdown` occurs when the previous loop returns a
        # residual which is small enough to be interpreted as 0 by self._normalise,
        # but which was passed through the solver anyway. This occurs when
        # the residual is small but the diff is not, or if the
        # correct solution was given to GMRES from the start. Both of these tend to
        # happen at the start of `gmres_compute`.
        # The latter may happen when using a sequence of iterative methods.
        # If `initial_breakdown` occurs, then we leave the `coeff_mat` as it was
        # at initialisation. Replacing it with the projection (which will be all 0s)
        # will mean `coeff_mat` is not full-rank, and `QR` can only handle nonsquare
        # matrices of full-rank.
        #
        coeff_mat_new = coeff_mat.at[step, :].set(
            proj_new, pred=jnp.invert(initial_breakdown)
        )
        return basis_new, coeff_mat_new, breakdown

    def _normalise(
        self, x: PyTree[Array], eps: Float[ArrayLike, ""] | None
    ) -> tuple[PyTree[Array], Inexact[Array, ""], Bool[ArrayLike, ""]]:
        """Normalise `x` by its two-norm.

        Returns `(x_normalised, norm, breakdown)`. `breakdown` is `norm < eps`, where
        `eps` defaults to the machine epsilon of the norm's dtype. On breakdown, `x`
        is divided by `inf` instead of `norm`, so finite entries become zero.
        """
        norm = two_norm(x)
        if eps is None:
            eps = jnp.finfo(norm.dtype).eps
        else:
            eps = jnp.astype(eps, norm.dtype)
        breakdown = norm < eps  # pyright: ignore
        safe_norm = jnp.where(breakdown, jnp.inf, norm)
        with jax.numpy_dtype_promotion("standard"):
            x_normalised = (x**ω / safe_norm).ω
        return x_normalised, norm, breakdown

    def transpose(self, state: _GMRESState, options: dict[str, Any]):
        """Return the state and options for solving against the transposed operator.

        Only the `preconditioner` option is carried over (also transposed).
        """
        transpose_options = {}
        if "preconditioner" in options:
            transpose_options["preconditioner"] = options["preconditioner"].transpose()
        operator = state
        return operator.transpose(), transpose_options

    def conj(self, state: _GMRESState, options: dict[str, Any]):
        """Return the state and options for solving against the conjugated operator.

        Only the `preconditioner` option is carried over (also conjugated).
        """
        conj_options = {}
        if "preconditioner" in options:
            conj_options["preconditioner"] = conj(options["preconditioner"])
        operator = state
        return conj(operator), conj_options

    def assume_full_rank(self):
        return True
# Constructor documentation for `GMRES`.
GMRES.__init__.__doc__ = r"""**Arguments:**

- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance.
    Defaults to the max norm.
- `max_steps`: The maximum number of iterations (restart cycles) to run the solver
    for. If more steps than this are required, then the solve is halted with a
    failure. Defaults to `10 * N`, where `N` is the total size of the input vector.
- `restart`: Size of the Krylov subspace built between restarts. Each restart cycle
    minimises the residual over a Krylov subspace of at most this dimension, so larger
    values make more progress per cycle at the cost of more memory. Default is 20.
- `stagnation_iters`: The maximum number of consecutive iterations (restart cycles)
    for which the residual norm may fail to decrease below its smallest value so far.
    If this many iterations pass without a new minimum, the algorithm is halted.
    Default is 20.
"""
================================================
FILE: lineax/_solver/lsmr.py
================================================
"""Implementation adapted from SciPy, with BSD license:
Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from collections.abc import Callable
from typing import Any, TypeAlias
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, PyTree
from .._misc import complex_to_real_dtype
from .._norm import two_norm
from .._operator import AbstractLinearOperator, conj, linearise
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
# Solver state for `LSMR`: the linearised operator returned by `LSMR.init`.
_LSMRState: TypeAlias = AbstractLinearOperator
class LSMR(AbstractLinearSolver[_LSMRState]):
    """LSMR solver for linear systems.

    This solver can handle any operator, even nonsquare or singular ones. In these
    cases it will return the pseudoinverse solution to the linear system.

    Similar to `scipy.sparse.linalg.lsmr`.

    This supports the following `options` (as passed to
    `lx.linear_solve(..., options=...)`).

    - `y0`: The initial estimate of the solution to the linear system. Defaults to all
        zeros.
    """

    rtol: float
    atol: float
    norm: Callable = two_norm
    max_steps: int | None = None
    conlim: float = 1e8

    def __check_init__(self):
        # Only Python scalars are validated here; other values (e.g. arrays) are
        # skipped.
        if isinstance(self.rtol, (int, float)) and self.rtol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.atol, (int, float)) and self.atol < 0:
            raise ValueError("Tolerances must be non-negative.")
        if isinstance(self.conlim, (int, float)) and self.conlim < 0:
            raise ValueError("Tolerances must be non-negative.")

        if isinstance(self.atol, (int, float)) and isinstance(self.rtol, (int, float)):
            if self.atol == 0 and self.rtol == 0 and self.max_steps is None:
                raise ValueError(
                    "Must specify `atol`, `rtol`, or `max_steps` (or some combination "
                    "of all three)."
                )

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        # The state is just the linearised operator; `options` (e.g. `y0`) are
        # read in `compute`.
        return linearise(operator)

    def compute(
        self,
        state: _LSMRState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Run LSMR on `operator x = vector`.

        The loop stops when `istop` becomes nonzero:
        1: `x` solves `A x = b` to within `atol`/`rtol` (or `beta == 0` initially);
        2: `x` solves the least-squares problem to within `atol`/`rtol` (or
           `alpha == 0` initially);
        3: the estimate of cond(A) exceeds `conlim`;
        4: `max_steps` was reached.
        Returns the solution, a `RESULTS` code and a stats dict.
        """
        operator = state
        x = options.get("y0", None)
        # damp is not supported at this time.
        # damp = options.get("damp", 0.0)
        damp = 0.0

        has_scale = not (
            isinstance(self.atol, (int, float))
            and isinstance(self.rtol, (int, float))
            and self.atol == 0
            and self.rtol == 0
        )

        dtype = jnp.result_type(
            *jtu.tree_leaves(vector),
            *jtu.tree_leaves(x),
            *jtu.tree_leaves(operator.in_structure()),
        )

        m, n = operator.out_size(), operator.in_size()
        # number of singular values
        min_dim = min([m, n])
        if self.max_steps is None:
            # Set max_steps based on the minimum dimension + avoid numerical overflows
            # https://github.com/patrick-kidger/lineax/issues/175
            # https://github.com/patrick-kidger/lineax/issues/177
            int_dtype = jnp.dtype(f"int{complex_to_real_dtype(dtype).itemsize * 8}")
            if min_dim > (jnp.iinfo(int_dtype).max / 10):
                max_steps = jnp.iinfo(int_dtype).max
            else:
                max_steps = min_dim * 10  # for consistency with other iterative solvers
        else:
            max_steps = self.max_steps

        if x is None:
            x = jtu.tree_map(jnp.zeros_like, operator.in_structure())

        # Start the Golub-Kahan bidiagonalisation from the initial residual
        # u = b - A x, normalised by beta = |u|; then v = A^H u normalised by alpha.
        b = vector
        u = (ω(b) - ω(operator.mv(x))).ω
        normb = self.norm(b)
        beta = self.norm(u)

        def beta_nonzero(beta, u):
            u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω
            v = conj(operator).T.mv(u)
            alpha = self.norm(v)
            return u, v, alpha

        def beta_zero(beta, u):
            v = jtu.tree_map(jnp.zeros_like, operator.in_structure())
            alpha = 0.0
            return u, v, alpha

        u, v, alpha = lax.cond(beta == 0.0, beta_zero, beta_nonzero, beta, u)
        # `lax.select` guards against dividing by a zero norm.
        v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω
        h = v
        hbar = jtu.tree_map(jnp.zeros_like, operator.in_structure())

        # Initialize variables for 1st iteration.
        # generally, latin letters (b, x, u, v, h etc) are vectors that may be complex
        # greek letters (alpha, beta, rho, zeta etc) are scalars that are always real
        loop_state = dict(
            # vectors
            x=x,
            u=u,
            v=v,
            h=h,
            hbar=hbar,
            # main loop variables
            itn=0,
            alpha=alpha,
            beta=beta,
            zetabar=alpha * beta,
            alphabar=alpha,
            rho=1.0,
            rhobar=1.0,
            cbar=1.0,
            sbar=0.0,
            # loop variables for estimation of ||r||.
            betadd=beta,
            betad=0.0,
            rhodold=1.0,
            tautildeold=0.0,
            thetatilde=0.0,
            zeta=0.0,
            delta=0.0,
            # variables for estimation of ||A|| and cond(A)
            normA2=alpha**2,
            maxrbar=0.0,
            minrbar=jnp.finfo(dtype).max,
            condA=1.0,
            # variables for use in stopping rules
            istop=0,
            normr=beta,
            normAr=alpha * beta,
        )

        # beta == 0 means x exactly solves the well posed problem
        # alpha == 0 means x exactly solves the least squares problem
        # we check this here to shortcut the loop to avoid division by zero
        loop_state["istop"] = lax.select(alpha == 0, 2, loop_state["istop"])
        loop_state["istop"] = lax.select(beta == 0, 1, loop_state["istop"])

        def condfun(loop_state):
            return loop_state["istop"] == 0

        def bodyfun(loop_state):
            st = loop_state  # to avoid writing out loop_state every time
            st["itn"] = st["itn"] + 1

            # Perform the next step of the bidiagonalization to obtain the
            # next beta, u, alpha, v. These satisfy the relations
            #     beta*u = A@v - alpha*u,
            #     alpha*v = A'@u - beta*v.
            st["u"] = (ω(st["u"]) * -st["alpha"].astype(dtype)).ω
            st["u"] = (ω(st["u"]) + ω(operator.mv(st["v"]))).ω
            st["beta"] = self.norm(st["u"])

            def beta_nonzero(alpha, beta, u, v):
                u = (ω(u) / lax.select(beta == 0.0, 1.0, beta).astype(dtype)).ω
                v = (ω(v) * -beta.astype(dtype)).ω
                v = (ω(v) + ω(conj(operator).T.mv(u))).ω
                alpha = self.norm(v)
                v = (ω(v) / lax.select(alpha == 0.0, 1.0, alpha).astype(dtype)).ω
                return alpha, beta, u, v

            def beta_zero(alpha, beta, u, v):
                return alpha, beta, u, v

            st["alpha"], st["beta"], st["u"], st["v"] = lax.cond(
                st["beta"] == 0,
                beta_zero,
                beta_nonzero,
                st["alpha"],
                st["beta"],
                st["u"],
                st["v"],
            )

            # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

            # Construct rotation Qhat_{k,2k+1}.
            chat, shat, alphahat = self._givens(st["alphabar"], damp)

            # Use a plane rotation (Q_i) to turn B_i to R_i
            rhoold = st["rho"]
            c, s, st["rho"] = self._givens(alphahat, st["beta"])
            thetanew = s * st["alpha"]
            st["alphabar"] = c * st["alpha"]

            # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
            rhobarold = st["rhobar"]
            zetaold = st["zeta"]
            thetabar = st["sbar"] * st["rho"]
            rhotemp = st["cbar"] * st["rho"]
            st["cbar"], st["sbar"], st["rhobar"] = self._givens(
                st["cbar"] * st["rho"], thetanew
            )
            st["zeta"] = st["cbar"] * st["zetabar"]
            st["zetabar"] = -st["sbar"] * st["zetabar"]

            # Update h, h_hat, x.
            st["hbar"] = (
                ω(st["hbar"])
                * -(thetabar * st["rho"] / (rhoold * rhobarold)).astype(dtype)
            ).ω
            st["hbar"] = (ω(st["hbar"]) + ω(st["h"])).ω
            st["x"] = (
                ω(st["x"])
                + (st["zeta"] / (st["rho"] * st["rhobar"])).astype(dtype)
                * ω(st["hbar"])
            ).ω
            st["h"] = (ω(st["h"]) * -(thetanew / st["rho"]).astype(dtype)).ω
            st["h"] = (ω(st["h"]) + ω(st["v"])).ω

            # Estimate of ||r||.

            # Apply rotation Qhat_{k,2k+1}.
            betaacute = chat * st["betadd"]
            betacheck = -shat * st["betadd"]

            # Apply rotation Q_{k,k+1}.
            betahat = c * betaacute
            st["betadd"] = -s * betaacute

            # Apply rotation Qtilde_{k-1}.
            # betad = betad_{k-1} here.
            thetatildeold = st["thetatilde"]
            ctildeold, stildeold, rhotildeold = self._givens(st["rhodold"], thetabar)
            # NOTE: `loop_state` and `st` are the same dict object here.
            st["thetatilde"] = stildeold * loop_state["rhobar"]
            st["rhodold"] = ctildeold * st["rhobar"]
            st["betad"] = -stildeold * st["betad"] + ctildeold * betahat

            # betad = betad_k here.
            # rhodold = rhod_k here.
            loop_state["tautildeold"] = (
                zetaold - thetatildeold * st["tautildeold"]
            ) / rhotildeold
            taud = (st["zeta"] - st["thetatilde"] * st["tautildeold"]) / st["rhodold"]
            st["delta"] = st["delta"] + betacheck**2
            st["normr"] = jnp.sqrt(
                st["delta"] + (st["betad"] - taud) ** 2 + st["betadd"] ** 2
            )

            # Estimate ||A||.
            st["normA2"] = st["normA2"] + st["beta"] ** 2
            normA = jnp.sqrt(st["normA2"])
            st["normA2"] = st["normA2"] + st["alpha"] ** 2

            # Estimate cond(A).
            st["maxrbar"] = jnp.maximum(st["maxrbar"], rhobarold)
            st["minrbar"] = lax.select(
                st["itn"] > 1, jnp.minimum(st["minrbar"], rhobarold), st["minrbar"]
            )
            st["condA"] = jnp.maximum(st["maxrbar"], rhotemp) / jnp.minimum(
                st["minrbar"], rhotemp
            )

            # Compute norms for convergence testing.
            st["normAr"] = jnp.abs(st["zetabar"])
            normx = self.norm(st["x"])

            well_posed_tol = self.atol + self.rtol * (normA * normx + normb)
            least_squares_tol = self.atol + self.rtol * (normA * st["normr"])

            # Later checks override earlier ones, so the priority is
            # 1 > 2 > 3 > 4 when several conditions hold at once.
            # maxiter exceeded
            st["istop"] = lax.select(st["itn"] >= max_steps, 4, st["istop"])
            # cond(A) seems to be greater than conlim
            st["istop"] = lax.select(st["condA"] > self.conlim, 3, st["istop"])
            # x solves the least-squares problem according to atol and rtol.
            st["istop"] = lax.select(st["normAr"] < least_squares_tol, 2, st["istop"])
            # x is a solution to A@x = b, according to atol and rtol.
            st["istop"] = lax.select(st["normr"] < well_posed_tol, 1, st["istop"])
            return st

        loop_state = lax.while_loop(condfun, bodyfun, loop_state)

        stats = {
            "num_steps": loop_state["itn"],
            "istop": loop_state["istop"],
            "norm_r": loop_state["normr"],
            "norm_Ar": loop_state["normAr"],
            "norm_A": jnp.sqrt(loop_state["normA2"]),
            "cond_A": loop_state["condA"],
            "norm_x": self.norm(loop_state["x"]),
        }

        # Hitting the default step budget is reported as `singular`, and hitting a
        # user-chosen budget as `max_steps_reached`; a converged stop (istop 1 or 2)
        # or a conlim stop (istop 3) overrides this.
        if self.max_steps is None:
            result = RESULTS.where(
                loop_state["itn"] == max_steps, RESULTS.singular, RESULTS.successful
            )
        elif has_scale:
            result = RESULTS.where(
                loop_state["itn"] == max_steps,
                RESULTS.max_steps_reached,
                RESULTS.successful,
            )
        else:
            result = RESULTS.successful
        result = RESULTS.where(loop_state["istop"] < 3, RESULTS.successful, result)
        result = RESULTS.where(loop_state["istop"] == 3, RESULTS.conlim, result)
        return loop_state["x"], result, stats

    def _givens(self, a, b):
        """Stable implementation of Givens rotation, from [1]_.

        Finds c, s, r such that

            [ c  s] [a]   [r]
            [-s  c] [b] = [0]

        i.e. c = a / r and s = b / r, with r = sqrt(a^2 + b^2) >= 0. If a == b == 0
        then c = s = r = 0.

        Assumes a, b are real.

        References
        ----------
        .. [1] S.-C. Choi, "Iterative Methods for Singular Linear Equations
            and Least-Squares Problems", Dissertation,
            http://www.stanford.edu/group/SOL/dissertations/sou-cheng-choi-thesis.pdf
        """
        assert not jnp.iscomplexobj(a)
        assert not jnp.iscomplexobj(b)

        def bzero(a, b):
            return jnp.sign(a), 0.0, jnp.abs(a)

        def azero(a, b):
            return 0.0, jnp.sign(b), jnp.abs(b)

        # Divide by the larger of |a|, |b| to avoid overflow in a^2 + b^2.
        def b_gt_a(a, b):
            tau = a / lax.select(b == 0.0, 1.0, b)
            s = jnp.sign(b) / jnp.sqrt(1.0 + tau**2)
            c = s * tau
            r = b / lax.select(s == 0.0, 1.0, s)
            return c, s, r

        def a_ge_b(a, b):
            tau = b / lax.select(a == 0.0, 1.0, a)
            c = jnp.sign(a) / jnp.sqrt(1.0 + tau**2)
            s = c * tau
            r = a / lax.select(c == 0.0, 1.0, c)
            return c, s, r

        def either_zero(a, b):
            return lax.cond(b == 0.0, bzero, azero, a, b)

        def both_nonzero(a, b):
            return lax.cond(jnp.abs(b) > jnp.abs(a), b_gt_a, a_ge_b, a, b)

        return lax.cond((a == 0.0) | (b == 0.0), either_zero, both_nonzero, a, b)

    def transpose(self, state: _LSMRState, options: dict[str, Any]):
        # No options (including `y0`) are carried over to the transposed solve.
        del options
        operator = state
        transpose_options = {}
        return operator.transpose(), transpose_options

    def conj(self, state: _LSMRState, options: dict[str, Any]):
        # No options (including `y0`) are carried over to the conjugated solve.
        del options
        operator = state
        conj_options = {}
        return conj(operator), conj_options

    def assume_full_rank(self):
        return False
# Constructor documentation for `LSMR`.
LSMR.__init__.__doc__ = r"""**Arguments:**

- `rtol`: Relative tolerance for terminating solve.
- `atol`: Absolute tolerance for terminating solve.
- `norm`: The norm to use when computing whether the error falls within the tolerance.
    Defaults to the two norm.
- `max_steps`: The maximum number of iterations to run the solver for. If more steps
    than this are required, then the solve is halted with a failure. Defaults to
    `10 * min(M, N)`, where `M` and `N` are the output and input sizes of the
    operator.
- `conlim`: The solver terminates if an estimate of cond(A) exceeds conlim. For
    compatible systems Ax = b, conlim could be as large as 1.0e+12 (say). For
    least-squares problems, conlim should be less than 1.0e+8. Must be a
    non-negative number. Maximum precision can be obtained by setting
    atol = rtol = 0 and conlim = np.inf (together with an explicit `max_steps`, which
    is required when both tolerances are zero), but the number of iterations may then
    be excessive. Default is 1e8.
"""
================================================
FILE: lineax/_solver/lu.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import equinox.internal as eqxi
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, PyTree
from .._operator import AbstractLinearOperator, is_diagonal
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
# Solver state for `LU`: ((packed LU factors, pivot indices), packed input/output
# structures, static flag saying whether to solve against the transpose).
_LUState: TypeAlias = tuple[tuple[Array, Array], PackedStructures, eqxi.Static]
class LU(AbstractLinearSolver[_LUState]):
    """LU solver for linear systems.

    This solver can only handle square nonsingular operators.
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Factorise the (square) operator.

        Returns `((lu, piv), packed_structures, transpose)`. `lu` and `piv` are in
        the packed format used by `jax.scipy.linalg.lu_factor`, and `transpose` is a
        static flag (initially `False`) recording whether solves should be against
        the transpose of the factorised matrix.

        Raises `ValueError` if the operator is not square.
        """
        del options
        if operator.in_size() != operator.out_size():
            raise ValueError(
                "`LU` may only be used for linear solves with square matrices"
            )
        packed_structures = pack_structures(operator)
        if is_diagonal(operator):
            # A diagonal matrix is already in packed LU form (unit lower factor,
            # upper factor equal to the matrix itself) and needs no row swaps, so
            # skip the general factorisation.
            lu = operator.as_matrix(), jnp.arange(operator.in_size(), dtype=jnp.int32)
        else:
            lu = jsp.linalg.lu_factor(operator.as_matrix())
        return lu, packed_structures, eqxi.Static(False)

    def compute(
        self, state: _LUState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Solve using the stored factorisation.

        `trans=1` asks `lu_solve` to solve against the (non-conjugated) transpose of
        the factorised matrix.
        """
        del options
        lu_and_piv, packed_structures, transpose = state
        transpose = transpose.value
        trans = 1 if transpose else 0
        vector = ravel_vector(vector, packed_structures)
        solution = jsp.linalg.lu_solve(lu_and_piv, vector, trans=trans)
        solution = unravel_solution(solution, packed_structures)
        return solution, RESULTS.successful, {}

    def transpose(
        self,
        state: _LUState,
        options: dict[str, Any],
    ):
        """Reuse the factorisation to solve against the transposed operator.

        The transpose flag is flipped, and the packed structures are transposed.
        """
        lu_and_piv, packed_structures, transpose = state
        transposed_packed_structures = transpose_packed_structures(packed_structures)
        transpose_state = (
            lu_and_piv,
            transposed_packed_structures,
            eqxi.Static(not transpose.value),
        )
        transpose_options = {}
        return transpose_state, transpose_options

    def conj(
        self,
        state: _LUState,
        options: dict[str, Any],
    ):
        """Reuse the factorisation to solve against the conjugated operator.

        Since `conj(P L U) = P conj(L) conj(U)`, conjugating the packed factors is
        enough. Conjugation does not transpose the operator, so the transpose flag
        and the packed structures are kept unchanged.
        """
        (lu, piv), packed_structures, transpose = state
        conj_state = (
            (lu.conj(), piv),
            packed_structures,
            # Flipping this flag would solve against the conjugate *transpose*.
            eqxi.Static(transpose.value),
        )
        conj_options = {}
        return conj_state, conj_options

    def assume_full_rank(self):
        return True
# Constructor docstring for `LU`, which has no configurable fields.
LU.__init__.__doc__ = """**Arguments:**
Nothing.
"""
================================================
FILE: lineax/_solver/misc.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import typing
import warnings
from typing import Any, NewType, TYPE_CHECKING
import equinox.internal as eqxi
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Array, PyTree, Shaped
from .._misc import strip_weak_dtype, structure_equal
from .._operator import AbstractLinearOperator, IdentityLinearOperator, linearise
def preconditioner_and_y0(
    operator: AbstractLinearOperator, vector: PyTree[Array], options: dict[str, Any]
):
    """Read the preconditioner and initial guess out of a solver's `options`.

    **Arguments:**

    - `operator`: the operator being solved against. The preconditioner must map
        `operator.in_structure()` to itself.
    - `vector`: the right hand side of the linear solve.
    - `options`: the solver options. The optional keys `"preconditioner"` and
        `"y0"` are read.

    **Returns:**

    A 2-tuple `(preconditioner, y0)`. If no preconditioner is given, an
    `IdentityLinearOperator` over `operator.in_structure()` is used. If no `y0` is
    given, zeros shaped like `vector` are used.

    **Raises:**

    `ValueError` if the preconditioner is not an `AbstractLinearOperator`, if its
    input or output structure differs from `operator.in_structure()`, or if `y0`
    does not match `vector` in structure, shape and dtype.
    """
    structure = operator.in_structure()
    # Keep the `try` body to the lookup alone, so that a `KeyError` raised while
    # linearising is not mistaken for "no preconditioner given".
    try:
        preconditioner = options["preconditioner"]
    except KeyError:
        preconditioner = IdentityLinearOperator(structure)
    else:
        # Check the type *before* calling `linearise`, so that a non-operator is
        # reported with this error rather than whatever `linearise` does with it.
        if not isinstance(preconditioner, AbstractLinearOperator):
            raise ValueError("The preconditioner must be a linear operator.")
        preconditioner = linearise(preconditioner)
        if not structure_equal(preconditioner.in_structure(), structure):
            raise ValueError(
                "The preconditioner must have `in_structure` that matches the "
                "operator's `in_structure`."
            )
        if not structure_equal(preconditioner.out_structure(), structure):
            raise ValueError(
                "The preconditioner must have `out_structure` that matches the "
                "operator's `in_structure`."
            )
    try:
        y0 = options["y0"]
    except KeyError:
        y0 = jtu.tree_map(jnp.zeros_like, vector)
    else:
        if not structure_equal(y0, vector):
            raise ValueError(
                "`y0` must have the same structure, shape, and dtype as `vector`"
            )
    return preconditioner, y0
# `PackedStructures` is a typing-level wrapper around the `eqxi.Static` built by
# `pack_structures` (holding the flattened output and input structures of an
# operator).
# Using a `NewType` seems to introduce some spurious failure at docgen time, so when
# generating documentation it is replaced by an identity function.
if hasattr(typing, "GENERATING_DOCUMENTATION") and not TYPE_CHECKING:
    PackedStructures = lambda x: x
else:
    PackedStructures = NewType("PackedStructures", eqxi.Static)
def pack_structures(operator: AbstractLinearOperator) -> PackedStructures:
    """Record an operator's output and input structures as a static value.

    The pair `(out_structure, in_structure)`, with weak dtypes stripped, is
    flattened into `(leaves, treedef)` before being wrapped in `eqxi.Static`. This
    means the structures' pytree containers need not be hashable.

    **Arguments:**

    - `operator`: the linear operator whose structures should be recorded.

    **Returns:**

    A `PackedStructures` wrapping `eqxi.Static((leaves, treedef))`.
    """
    out_struct = strip_weak_dtype(operator.out_structure())
    in_struct = strip_weak_dtype(operator.in_structure())
    # Flatten so that nonhashable pytree containers can still be stored statically.
    leaf_list, tree_def = jtu.tree_flatten((out_struct, in_struct))
    return PackedStructures(eqxi.Static((leaf_list, tree_def)))
def ravel_vector(
    pytree: PyTree[Array], packed_structures: PackedStructures
) -> Shaped[Array, " size"]:
    """Flatten `pytree` into a single 1-dimensional array.

    Every leaf is cast to the leaves' common result dtype and raveled, and the
    results are concatenated in `jtu.tree_leaves` order.

    **Raises:**

    `ValueError` if `pytree` does not match the recorded output structure.
    """
    struct_leaves, struct_def = packed_structures.value
    expected, _ = jtu.tree_unflatten(struct_def, struct_leaves)
    if not structure_equal(pytree, expected):
        raise ValueError("pytree does not match out_structure")
    # Deliberately not `ravel_pytree`, which makes no guarantees about leaf order.
    arrays = jtu.tree_leaves(pytree)
    common = jnp.result_type(*arrays)
    return jnp.concatenate([arr.astype(common).reshape(-1) for arr in arrays])
def unravel_solution(
    solution: Shaped[Array, " size"], packed_structures: PackedStructures
) -> PyTree[Array]:
    """Inverse of flattening: split `solution` back into the recorded input structure.

    Each piece is reshaped and cast to the shape and dtype of the corresponding
    leaf of the recorded input structure.
    """
    packed_leaves, packed_def = packed_structures.value
    in_structure = jtu.tree_unflatten(packed_def, packed_leaves)[1]
    targets, target_def = jtu.tree_flatten(in_structure)
    # Split points are the running totals of every leaf size except the last.
    boundaries = np.cumsum([math.prod(t.shape) for t in targets[:-1]])
    pieces = jnp.split(solution, boundaries)
    assert len(pieces) == len(targets)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
        restored = [
            piece.reshape(target.shape).astype(target.dtype)
            for piece, target in zip(pieces, targets)
        ]
    return jtu.tree_unflatten(target_def, restored)
def transpose_packed_structures(
    packed_structures: PackedStructures,
) -> PackedStructures:
    """Swap the recorded output and input structures (for a transposed operator)."""
    flat_leaves, flat_def = packed_structures.value
    out_structure, in_structure = jtu.tree_unflatten(flat_def, flat_leaves)
    new_leaves, new_def = jtu.tree_flatten((in_structure, out_structure))
    return PackedStructures(eqxi.Static((new_leaves, new_def)))
================================================
FILE: lineax/_solver/normal.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from typing import Any, TypeVar
import equinox.internal as eqxi
from jaxtyping import Array, PyTree
from .._operator import conj, linearise, materialise, TaggedLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearOperator, AbstractLinearSolver
from .._tags import positive_semidefinite_tag
from .cholesky import Cholesky
# Type of the state produced by the wrapped (inner) solver of `Normal`.
_InnerSolverState = TypeVar("_InnerSolverState")
def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
    """Translate the outer solve's `options` into options for the inner solve.

    **Arguments:**

    - `options`: the options passed to `Normal`. Not mutated.
    - `tall`: whether the first formulation (`A^* A x = A^* b`) is used. Otherwise
        the second (`A A^* y = b`, `x = A^* y`) is.

    **Returns:**

    A shallow copy of `options`, with:

    - `"preconditioner"` (if given, as `M`) replaced by `M M^*` (tall) or `M^* M`
        (wide), tagged positive semidefinite;
    - `"y0"` (if given) replaced by `M^* y0` in the wide case only. In the tall
        case the inner solve is for `x` itself, so `y0` is passed through.

    **Raises:**

    `ValueError` if `y0` is given for a wide solve without a preconditioner, since
    there is then no `M` with which to map `y0` to an estimate of `y`.
    """
    preconditioner = options.get("preconditioner")
    y0 = options.get("y0")
    inner_options = copy(options)
    del options
    if preconditioner is not None:
        preconditioner = linearise(preconditioner)
        if tall:
            inner_options["preconditioner"] = TaggedLinearOperator(
                preconditioner @ conj(preconditioner.transpose()),
                positive_semidefinite_tag,
            )
        else:
            inner_options["preconditioner"] = TaggedLinearOperator(
                conj(preconditioner.transpose()) @ preconditioner,
                positive_semidefinite_tag,
            )
    if y0 is not None and not tall:
        # Only the wide formulation solves for `y` rather than `x`, so only there
        # does the initial guess for `x` need mapping to one for `y`.
        if preconditioner is None:
            raise ValueError(
                "`Normal` requires a `preconditioner` to be passed alongside `y0` "
                "when the operator has fewer rows than columns."
            )
        inner_options["y0"] = conj(preconditioner.transpose()).mv(y0)
    return inner_options
class Normal(
    AbstractLinearSolver[
        tuple[_InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]]
    ]
):
    """Wrapper for an inner solver of positive (semi)definite systems. The
    wrapped solver handles possibly nonsquare systems $Ax = b$, with $A$ of shape
    $m \\times n$, by applying the inner solver to the normal equations
    $A^* A x = A^* b$
    if $m \\ge n$, otherwise
    $A A^* y = b$,
    where $x = A^* y$.
    If the inner solver solves systems with positive definite $A$, the wrapped
    solver solves systems with full rank $A$.
    If the inner solver solves systems with positive semidefinite $A$, the
    wrapped solver solves systems with arbitrary, possibly rank deficient, $A$.
    Note that this squares the condition number, so applying this method to an
    iterative inner solver may result in slow convergence and high sensitivity
    to roundoff error. In this case it may be advantageous to choose an
    appropriate preconditioner or initial solution guess for the problem.
    This wrapper adjusts the following `options` before passing them to the inner
    solver (as passed to `lx.linear_solve(..., options=...)`).
    - `preconditioner`: A [`lineax.AbstractLinearOperator`][] to be used as
        preconditioner. Defaults to [`lineax.IdentityLinearOperator`][]. This
        should be an approximation of the (pseudo)inverse of $A$. When passed
        to the inner solver, the preconditioner $M$ is replaced by $M M^*$ and
        $M^* M$ in the first and second versions of the normal equations,
        respectively.
    - `y0`: An initial estimate of the solution of the linear system $Ax = b$.
        Defaults to all zeros. In the second version of the normal equations,
        $y_0$ is replaced with $M^* y_0$, where $M$ is the given outer
        preconditioner.
    !!! Info
        Good choices of inner solvers are the direct [`lineax.Cholesky`][] and
        the iterative [`lineax.CG`][].
    """
    # The wrapped solver, applied to A^* A (tall) or A A^* (wide).
    inner_solver: AbstractLinearSolver[_InnerSolverState]
    def init(self, operator, options):
        # "Tall": at least as many rows (outputs) as columns (inputs).
        tall = operator.out_size() >= operator.in_size()
        # Cholesky materialises op twice when computing (op^H @ op).as_matrix()
        # Cheaper to materialise first and then conjugate-transpose.
        # For iterative solvers we only linearise to avoid eager materialisation.
        is_cholesky = isinstance(self.inner_solver, Cholesky)
        lin_op = materialise(operator) if is_cholesky else linearise(operator)
        if tall:
            inner_operator = conj(lin_op.transpose()) @ lin_op
        else:
            inner_operator = lin_op @ conj(lin_op.transpose())
        # Both A^* A and A A^* are positive semidefinite; tag them so that the inner
        # solver can make use of this.
        inner_operator = TaggedLinearOperator(inner_operator, positive_semidefinite_tag)
        inner_options = normal_preconditioner_and_y0(options, tall)
        inner_state = self.inner_solver.init(inner_operator, inner_options)
        # A^* is needed in `compute`: to form A^* b (tall) or to map the inner
        # solution y back to x = A^* y (wide).
        operator_conj_transpose = conj(lin_op.transpose())
        return inner_state, eqxi.Static(tall), operator_conj_transpose, inner_options
    def compute(
        self,
        state: tuple[
            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]
        ],
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        # The per-call `options` are ignored; the inner options fixed at `init`
        # time are used instead.
        inner_state, tall, operator_conj_transpose, inner_options = state
        tall = tall.value
        del state, options
        if tall:
            # Right hand side of A^* A x = A^* b.
            vector = operator_conj_transpose.mv(vector)
        solution, result, extra_stats = self.inner_solver.compute(
            inner_state, vector, inner_options
        )
        if not tall:
            # The inner solve produced y; recover x = A^* y.
            solution = operator_conj_transpose.mv(solution)
        return solution, result, extra_stats
    def transpose(
        self,
        state: tuple[
            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]
        ],
        options: dict[str, Any],
    ):
        # For A^T: the stored (A^T)^* = conj(A) is the transpose of the stored A^*.
        # The normal operator of A^T in the swapped formulation is the conjugate of
        # the original one (e.g. A^T conj(A) = conj(A^* A)), so the inner state is
        # conjugated and the tall/wide flag is flipped.
        inner_state, tall, operator_conj_transpose, inner_options = state
        inner_state_conj, inner_options = self.inner_solver.conj(
            inner_state, inner_options
        )
        state_transpose = (
            inner_state_conj,
            eqxi.Static(not tall.value),
            operator_conj_transpose.transpose(),
            inner_options,
        )
        return state_transpose, options
    def conj(
        self,
        state: tuple[
            _InnerSolverState, eqxi.Static, AbstractLinearOperator, dict[str, Any]
        ],
        options: dict[str, Any],
    ):
        # For conj(A): conjugate the inner state and the stored A^*; the
        # formulation (tall/wide) is unchanged.
        inner_state, tall, operator_conj_transpose, inner_options = state
        inner_state_conj, inner_options = self.inner_solver.conj(
            inner_state, inner_options
        )
        state_conj = (
            inner_state_conj,
            tall,
            conj(operator_conj_transpose),
            inner_options,
        )
        return state_conj, options
    def assume_full_rank(self):
        # Rank handling is entirely determined by the inner solver.
        return self.inner_solver.assume_full_rank()
# Constructor docstring for `Normal`, whose only field is `inner_solver`.
Normal.__init__.__doc__ = """**Arguments:**
- `inner_solver`: The solver to wrap. It should support solving positive
definite systems or positive semidefinite systems
"""
================================================
FILE: lineax/_solver/qr.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import equinox.internal as eqxi
import jax.lax.linalg as jll
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, PyTree
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
# State carried between `QR.init` and `QR.compute`:
# - `(a, taus)`: the Householder QR data of the (possibly transposed) matrix, in the
#   layout consumed by `jax.lax.linalg.ormqr` in `QR.compute`;
# - a static flag: whether the matrix was transposed before factorising (i.e. it
#   had more columns than rows), or later flipped by `QR.transpose`;
# - the packed output/input structures of the operator.
_QRState: TypeAlias = tuple[tuple[Array, Array], eqxi.Static, PackedStructures]
class QR(AbstractLinearSolver[_QRState]):
    """QR solver for linear systems.

    This solver can handle non-square operators.

    This is usually the preferred solver when dealing with non-square operators.

    !!! info

        Note that whilst this does handle non-square operators, it still can only
        handle full-rank operators.

        This is because JAX does not currently support a rank-revealing/pivoted QR
        decomposition, see [issue #12897](https://github.com/google/jax/issues/12897).

        For such use cases, switch to [`lineax.SVD`][] instead.
    """

    def init(self, operator, options):
        """QR-factorise the operator's matrix, or its transpose if it is wide."""
        del options
        matrix = operator.as_matrix()
        m, n = matrix.shape
        # Factorise whichever of A or A^T has at least as many rows as columns, so
        # that the leading block of the factor is a square R.
        transpose = n > m
        if transpose:
            matrix = matrix.T
        h, taus = jnp.linalg.qr(matrix, mode="raw")  # pyright: ignore
        # `.mT` puts the Householder data into the layout consumed by `ormqr` below.
        a = h.mT
        packed_structures = pack_structures(operator)
        return (a, taus), eqxi.Static(transpose), packed_structures

    def compute(
        self,
        state: _QRState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Return the minimum-norm (wide) or least-squares (tall) solution."""
        (a, taus), transpose, packed_structures = state
        transpose = transpose.value
        del state, options
        vector = ravel_vector(vector, packed_structures)
        n_full, n_min = a.shape
        r = a[:n_min]
        if transpose:
            # Minimal norm solution if underdetermined: x = Q.conj() @ R^{-T} @ b.
            # Use Q.conj() @ z = (z^T @ Q^H)^T to avoid explicit `conj` calls,
            # and pad `y` along the row axis to absorb the discarded columns of Q.
            y = jsp.linalg.solve_triangular(r, vector, trans="T", unit_diagonal=False)
            zeros = jnp.zeros((1, n_full - n_min), dtype=y.dtype)
            y_pad = jnp.concatenate([y[None, :], zeros], axis=1)
            solution = jll.ormqr(a, taus, y_pad, left=False, transpose=True)[0]
        else:
            # Least squares solution if overdetermined: x = R^{-1} (Q^H b)[:n].
            qHv = jll.ormqr(a, taus, vector[:, None], transpose=True)[:n_min, 0]
            solution = jsp.linalg.solve_triangular(
                r, qHv, trans="N", unit_diagonal=False
            )
        solution = unravel_solution(solution, packed_structures)
        return solution, RESULTS.successful, {}

    def transpose(self, state: _QRState, options: dict[str, Any]):
        """Reuse the factorisation for A^T: swap structures and flip the flag."""
        (a, taus), transpose, structures = state
        transposed_packed_structures = transpose_packed_structures(structures)
        transpose_state = (
            (a, taus),
            eqxi.Static(not transpose.value),
            transposed_packed_structures,
        )
        transpose_options = {}
        return transpose_state, transpose_options

    def conj(self, state: _QRState, options: dict[str, Any]):
        """Reuse the factorisation for conj(A) by conjugating the Householder data."""
        (a, taus), transpose, structures = state
        conj_state = (
            (a.conj(), taus.conj()),
            transpose,
            structures,
        )
        conj_options = {}
        return conj_state, conj_options

    def assume_full_rank(self):
        return True
# Constructor docstring for `QR`, which has no configurable fields.
QR.__init__.__doc__ = """**Arguments:**
Nothing.
"""
================================================
FILE: lineax/_solver/svd.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import jax.lax as lax
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, PyTree
from .._misc import resolve_rcond
from .._operator import AbstractLinearOperator
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
# State carried between `SVD.init` and `SVD.compute`: the thin SVD factors
# `(u, s, vt)` of the operator's matrix, as returned by
# `jax.scipy.linalg.svd(..., full_matrices=False)`, and the packed output/input
# structures of the operator.
_SVDState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]
class SVD(AbstractLinearSolver[_SVDState]):
    """SVD solver for linear systems.

    This solver can handle any operator, even nonsquare or singular ones. In these
    cases it will return the pseudoinverse solution to the linear system.

    Equivalent to `scipy.linalg.lstsq`.
    """

    rcond: float | None = None

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Compute the thin SVD of the operator's matrix."""
        del options
        factors = jsp.linalg.svd(operator.as_matrix(), full_matrices=False)
        return factors, pack_structures(operator)

    def compute(
        self,
        state: _SVDState,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Apply the pseudoinverse, discarding singular values below the cutoff.

        The numerical rank is reported in the stats under `"rank"`.
        """
        del options
        factors, packed_structures = state
        u, s, vt = factors
        b = ravel_vector(vector, packed_structures)
        n_rows = u.shape[0]
        n_cols = vt.shape[1]
        cutoff = jnp.array(
            resolve_rcond(self.rcond, n_cols, n_rows, s.dtype), dtype=s.dtype
        )
        if s.size > 0:
            # Make the cutoff relative to the first singular value.
            cutoff = cutoff * s[0]
        # Strict `>` rather than `>=`: with `>=`, an all-zero matrix would fail.
        keep = s > cutoff
        rank = keep.sum()
        denom = jnp.where(keep, s, 1)
        s_inv = jnp.where(keep, jnp.array(1.0) / denom, 0).astype(u.dtype)
        coeffs = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
        x = jnp.matmul(vt.conj().T, s_inv * coeffs, precision=lax.Precision.HIGHEST)
        x = unravel_solution(x, packed_structures)
        return x, RESULTS.successful, {"rank": rank}

    def transpose(self, state: _SVDState, options: dict[str, Any]):
        """A^T = conj(V) S U^T, so swap and transpose the singular vector factors."""
        del options
        (u, s, vt), packed_structures = state
        swapped = (vt.T, s, u.T)
        return (swapped, transpose_packed_structures(packed_structures)), {}

    def conj(self, state: _SVDState, options: dict[str, Any]):
        """conj(A) = conj(U) S conj(V^H); singular values are real and unchanged."""
        del options
        (u, s, vt), packed_structures = state
        conjugated = (u.conj(), s, vt.conj())
        return (conjugated, packed_structures), {}

    def assume_full_rank(self):
        return False
# Constructor docstring for `SVD`, documenting its single field `rcond`.
SVD.__init__.__doc__ = """**Arguments**:
- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine
precision times `max(N, M)`, where `(N, M)` is the shape of the operator. (I.e.
`N` is the output size and `M` is the input size.)
"""
================================================
FILE: lineax/_solver/triangular.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import equinox.internal as eqxi
import jax.scipy as jsp
from jaxtyping import Array, PyTree
from .._operator import (
AbstractLinearOperator,
has_unit_diagonal,
is_lower_triangular,
is_upper_triangular,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
# State carried between `Triangular.init` and `Triangular.compute`:
# (matrix, is-lower flag, has-unit-diagonal flag, packed structures,
#  transposed flag).
_TriangularState: TypeAlias = tuple[
    Array, eqxi.Static, eqxi.Static, PackedStructures, eqxi.Static
]
class Triangular(AbstractLinearSolver[_TriangularState]):
    """Triangular solver for linear systems.

    The operator should either be lower triangular or upper triangular.
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Validate the operator and store its matrix with its triangularity flags.

        **Raises:**

        `ValueError` if the operator is not square, or is tagged neither lower nor
        upper triangular.
        """
        del options
        if operator.in_size() != operator.out_size():
            raise ValueError(
                "`Triangular` may only be used for linear solves with square matrices"
            )
        is_lower = is_lower_triangular(operator)
        if not is_lower and not is_upper_triangular(operator):
            raise ValueError(
                "`Triangular` may only be used for linear solves with triangular "
                "matrices"
            )
        matrix = operator.as_matrix()
        return (
            matrix,
            eqxi.Static(is_lower),
            eqxi.Static(has_unit_diagonal(operator)),
            pack_structures(operator),
            eqxi.Static(False),  # not (yet) transposed
        )

    def compute(
        self, state: _TriangularState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Back/forward substitute, using the plain transpose if so flagged."""
        del options
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        rhs = ravel_vector(vector, packed_structures)
        x = jsp.linalg.solve_triangular(
            matrix,
            rhs,
            trans="T" if transposed.value else "N",
            lower=lower.value,
            unit_diagonal=unit_diagonal.value,
        )
        return unravel_solution(x, packed_structures), RESULTS.successful, {}

    def transpose(self, state: _TriangularState, options: dict[str, Any]):
        """Reuse the matrix for A^T: swap structures and flip the transposed flag."""
        del options
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        new_state = (
            matrix,
            lower,
            unit_diagonal,
            transpose_packed_structures(packed_structures),
            eqxi.Static(not transposed.value),
        )
        return new_state, {}

    def conj(self, state: _TriangularState, options: dict[str, Any]):
        """Conjugate the stored matrix; all flags and structures are unchanged."""
        del options
        matrix, lower, unit_diagonal, packed_structures, transposed = state
        new_state = (
            matrix.conj(),
            lower,
            unit_diagonal,
            packed_structures,
            transposed,
        )
        return new_state, {}

    def assume_full_rank(self):
        return True
# Constructor docstring for `Triangular`, which has no configurable fields.
Triangular.__init__.__doc__ = """**Arguments:**
Nothing.
"""
================================================
FILE: lineax/_solver/tridiagonal.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TypeAlias
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, PyTree
from .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)
# State carried between `Tridiagonal.init` and `Tridiagonal.compute`: the
# `(diagonal, lower_diagonal, upper_diagonal)` triple returned by `tridiagonal(...)`
# and the packed output/input structures of the operator.
_TridiagonalState: TypeAlias = tuple[tuple[Array, Array, Array], PackedStructures]
class Tridiagonal(AbstractLinearSolver[_TridiagonalState]):
    """Tridiagonal solver for linear systems.

    Uses the LAPACK/cuSPARSE implementation of Gaussian elimination with partial
    pivoting (which increases stability).
    """

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
        """Validate the operator and extract its three diagonals.

        **Raises:**

        `ValueError` if the operator is not square or not tagged tridiagonal.
        """
        del options
        if operator.in_size() != operator.out_size():
            raise ValueError(
                "`Tridiagonal` may only be used for linear solves with square matrices"
            )
        if not is_tridiagonal(operator):
            raise ValueError(
                "`Tridiagonal` may only be used for linear solves with tridiagonal "
                "matrices"
            )
        diagonals = tridiagonal(operator)
        structures = pack_structures(operator)
        return diagonals, structures

    def compute(
        self,
        state: _TridiagonalState,
        vector,
        options,
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        """Solve the tridiagonal system with `lax.linalg.tridiagonal_solve`."""
        diagonals, packed_structures = state
        del state, options
        main, below, above = diagonals
        rhs = ravel_vector(vector, packed_structures)[:, None]
        # `tridiagonal_solve` expects all three diagonals at the main diagonal's
        # length: pad the sub-diagonal at the front and the super-diagonal at the
        # back with a zero.
        dl = jnp.append(0.0, below)
        du = jnp.append(above, 0.0)
        x = lax.linalg.tridiagonal_solve(dl, main, du, rhs).flatten()
        return unravel_solution(x, packed_structures), RESULTS.successful, {}

    def transpose(self, state: _TridiagonalState, options: dict[str, Any]):
        """Transposing swaps the sub- and super-diagonals."""
        (main, below, above), packed_structures = state
        swapped = (main, above, below)
        return (swapped, transpose_packed_structures(packed_structures)), options

    def conj(self, state: _TridiagonalState, options: dict[str, Any]):
        """Conjugate each diagonal; the structures are unchanged."""
        (main, below, above), packed_structures = state
        conjugated = (main.conj(), below.conj(), above.conj())
        return (conjugated, packed_structures), options

    def assume_full_rank(self):
        return True
# Constructor docstring for `Tridiagonal`, which has no configurable fields.
Tridiagonal.__init__.__doc__ = """**Arguments:**
Nothing.
"""
================================================
FILE: lineax/_tags.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class _HasRepr:
def __init__(self, string: str):
self.string = string
def __repr__(self):
return self.string
# Tags are singleton marker objects. `_HasRepr` defines no `__eq__`/`__hash__`, so
# membership tests against a set of tags are by identity; the string is only for
# display.
symmetric_tag = _HasRepr("symmetric_tag")
diagonal_tag = _HasRepr("diagonal_tag")
tridiagonal_tag = _HasRepr("tridiagonal_tag")
unit_diagonal_tag = _HasRepr("unit_diagonal_tag")
lower_triangular_tag = _HasRepr("lower_triangular_tag")
upper_triangular_tag = _HasRepr("upper_triangular_tag")
positive_semidefinite_tag = _HasRepr("positive_semidefinite_tag")
negative_semidefinite_tag = _HasRepr("negative_semidefinite_tag")
# Each rule maps the tags of an operator to at most one tag of its transpose,
# returning `None` when it does not apply. `transpose_tags` collects the non-`None`
# results.
transpose_tags_rules = []
# Tags that are preserved by transposition map to themselves.
for tag in (
    symmetric_tag,
    unit_diagonal_tag,
    diagonal_tag,
    positive_semidefinite_tag,
    negative_semidefinite_tag,
    tridiagonal_tag,
):
    # `list.append` returns `None`, so this decorator registers the function and
    # then rebinds `_` to `None`; only the registered function matters.
    @transpose_tags_rules.append
    def _(tags: frozenset[object], tag=tag):
        # `tag=tag` captures the current loop value (avoids late binding).
        if tag in tags:
            return tag
# Lower and upper triangularity swap under transposition.
@transpose_tags_rules.append
def _(tags: frozenset[object]):
    if lower_triangular_tag in tags:
        return upper_triangular_tag
@transpose_tags_rules.append
def _(tags: frozenset[object]):
    if upper_triangular_tag in tags:
        return lower_triangular_tag
def transpose_tags(tags: frozenset[object]):
    """Lineax uses "tags" to declare that a particular linear operator exhibits some
    property, e.g. symmetry.

    Given the collection of tags attached to a linear operator, this function
    returns the collection of tags that should be attached to the transpose of that
    linear operator.

    **Arguments:**

    - `tags`: a `frozenset` of tags.

    **Returns:**

    A `frozenset` of tags.
    """
    # A symmetric operator is its own transpose, so it keeps every one of its tags.
    if symmetric_tag in tags:
        return tags
    # Each rule yields the transposed counterpart of one tag, or `None`.
    candidates = (rule(tags) for rule in transpose_tags_rules)
    return frozenset(found for found in candidates if found is not None)
================================================
FILE: lineax/internal/__init__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .._misc import (
complex_to_real_dtype as complex_to_real_dtype,
default_floating_dtype as default_floating_dtype,
)
from .._norm import (
max_norm as max_norm,
rms_norm as rms_norm,
sum_squares as sum_squares,
tree_dot as tree_dot,
two_norm as two_norm,
)
from .._solve import linear_solve_p as linear_solve_p
from .._solver.misc import (
pack_structures as pack_structures,
PackedStructures as PackedStructures,
ravel_vector as ravel_vector,
transpose_packed_structures as transpose_packed_structures,
unravel_solution as unravel_solution,
)
================================================
FILE: mkdocs.yml
================================================
theme:
name: material
features:
- navigation.sections # Sections are included in the navigation on the left.
- toc.integrate # Table of contents is integrated on the left; does not appear separately on the right.
- header.autohide # header disappears as you scroll
palette:
# Light mode / dark mode
# We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
# (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
- scheme: default
primary: white
accent: amber
toggle:
icon: material/weather-night
name: Switch to dark mode
- scheme: slate
primary: black
accent: amber
toggle:
icon: material/weather-sunny
name: Switch to light mode
icon:
repo: fontawesome/brands/github # GitHub logo in top right
logo: "material/matrix" # lineax logo in top left
favicon: "_static/favicon.png"
custom_dir: "docs/_overrides" # Overriding part of the HTML
# These additions are my own custom ones, having overridden a partial.
twitter_bluesky_name: "@PatrickKidger"
twitter_url: "https://twitter.com/PatrickKidger"
bluesky_url: "https://PatrickKidger.bsky.social"
site_name: lineax
site_description: The documentation for the Lineax software library.
site_author: Patrick Kidger
site_url: https://docs.kidger.site/lineax
repo_url: https://github.com/patrick-kidger/lineax
repo_name: patrick-kidger/lineax
edit_uri: ""
strict: true # Don't allow warnings during the build process
extra_javascript:
# The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
- _static/mathjax.js
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
extra_css:
- _static/custom_css.css
markdown_extensions:
- pymdownx.arithmatex: # Render LaTeX via MathJax
generic: true
- pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
- pymdownx.details # Allowing hidden expandable regions denoted by ???
- pymdownx.snippets: # Include one Markdown file into another
base_path: docs
- admonition
- toc:
permalink: "¤" # Adds a clickable permalink to each section heading
toc_depth: 4
plugins:
- search:
separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;'
- include_exclude_files:
include:
- ".htaccess"
exclude:
- "_overrides"
- "examples/.ipynb_checkpoints/"
- ipynb
- hippogriffe:
extra_public_objects:
- jax.ShapeDtypeStruct
- mkdocstrings:
handlers:
python:
options:
force_inspection: true
heading_level: 4
inherited_members: true
members_order: source
show_bases: false
show_if_no_docstring: true
show_overloads: false
show_root_heading: true
show_signature_annotations: true
show_source: false
show_symbol_type_heading: true
show_symbol_type_toc: true
nav:
- 'index.md'
- Examples:
- 'examples/classical_solve.ipynb'
- 'examples/least_squares.ipynb'
- 'examples/structured_matrices.ipynb'
- 'examples/no_materialisation.ipynb'
- 'examples/operators.ipynb'
- 'examples/complex_solve.ipynb'
- API:
- 'api/linear_solve.md'
- 'api/solvers.md'
- 'api/operators.md'
- 'api/tags.md'
- 'api/solution.md'
- 'api/functions.md'
- 'faq.md'
================================================
FILE: pyproject.toml
================================================
[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]
[dependency-groups]
dev = [
"prek==0.3.9",
"pyright==1.1.406",
"ruff==0.13.0",
"toml-sort==0.23.1"
]
docs = [
"hippogriffe==0.2.2",
"griffe==1.7.3",
"mkdocs==1.6.1",
"mkdocs-include-exclude-files==0.1.0",
"mkdocs-ipynb==0.1.1",
"mkdocs-material==9.6.7",
"mkdocstrings==0.28.3",
"mkdocstrings-python==1.16.8",
"pygments==2.20.0",
"pymdown-extensions==10.21.2"
]
tests = [
"beartype",
"equinox",
"pytest",
"pytest-xdist",
"jaxlib"
]
[project]
authors = [
{email = "raderjason@outlook.com", name = "Jason Rader"},
{email = "contact@kidger.site", name = "Patrick Kidger"}
]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Financial and Insurance Industry",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics"
]
dependencies = ["jax>=0.10.0", "jaxtyping>=0.2.24", "equinox>=0.11.10", "typing_extensions>=4.5.0"]
description = "Linear solvers in JAX and Equinox."
keywords = ["jax", "neural-networks", "deep-learning", "equinox", "linear-solvers", "least-squares", "numerical-methods"]
license = {file = "LICENSE"}
name = "lineax"
readme = "README.md"
requires-python = "~=3.11"
urls = {repository = "https://github.com/google/lineax"}
version = "0.1.1"
[tool.hatch.build]
include = ["lineax/*"]
[tool.pyright]
include = ["lineax", "tests"]
reportIncompatibleMethodOverride = true
[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=lineax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
[tool.ruff]
extend-include = ["*.ipynb"]
src = []
[tool.ruff.lint]
fixable = ["I001", "F401", "UP"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
select = ["E", "F", "I001", "UP"]
[tool.ruff.lint.flake8-import-conventions.extend-aliases]
"collections" = "co"
"functools" = "ft"
"itertools" = "it"
[tool.ruff.lint.isort]
combine-as-imports = true
extra-standard-library = ["typing_extensions"]
lines-after-imports = 2
order-by-type = false
[tool.uv]
default-groups = ["dev", "docs", "tests"]
================================================
FILE: tests/README.md
================================================
Each file is run separately to avoid JAX running out of memory.
As such, run tests using `python -m tests`, *not* by just running `pytest`.
================================================
FILE: tests/__init__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: tests/__main__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import subprocess
import sys
here = pathlib.Path(__file__).resolve().parent


def main() -> int:
    """Run each ``test*.py`` file in this directory in its own pytest process.

    Each file is run separately to avoid running out of memory.

    Returns:
        The largest exit code of any pytest run, so that any failure (non-zero
        code) propagates to the caller.
    """
    running_out = 0
    # Sort for a deterministic run order; `iterdir` order depends on the filesystem.
    for file in sorted(here.iterdir()):
        if file.is_file() and file.name.startswith("test") and file.suffix == ".py":
            # Invoke pytest via the current interpreter with an argument list (no
            # shell), so paths containing spaces or shell metacharacters are safe.
            cmd = [sys.executable, "-m", "pytest", str(file)]
            out = subprocess.run(cmd).returncode
            running_out = max(running_out, out)
    return running_out


if __name__ == "__main__":
    sys.exit(main())
================================================
FILE: tests/conftest.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox.internal as eqxi
import jax
import pytest
# Enable 64-bit dtypes globally; the tests construct float64/complex128 arrays.
jax.config.update("jax_enable_x64", True)
# Make implicit dtype promotion and rank promotion raise, so accidental mixing
# of dtypes or shapes in lineax is caught by the tests.
jax.config.update("jax_numpy_dtype_promotion", "strict")
jax.config.update("jax_numpy_rank_promotion", "raise")


@pytest.fixture
def getkey():
    # Returns an `eqxi.GetKey()` object; tests call `getkey()` to obtain PRNG keys
    # (presumably a fresh key per call -- see `equinox.internal.GetKey`).
    return eqxi.GetKey()
================================================
FILE: tests/helpers.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import math
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax as lx
import numpy as np
from equinox.internal import ω
@ft.cache
def _construct_matrix_impl(
    getkey, tags, size, dtype, cond_or_singular: int | float | str, i: int
):
    """Build one random test matrix with the structure described by `tags`.

    Results are memoised with `ft.cache`, keyed on all arguments (including the
    `getkey` object itself).

    Arguments:
        getkey: callable returning a PRNG key each time it is called.
        tags: a single lineax tag, or a tuple of tags, to impose on the matrix.
        size: the number of rows/columns of the initially sampled square matrix.
        dtype: the dtype passed to `jr.normal`.
        cond_or_singular: either a numeric condition-number cutoff -- the matrix is
            resampled until `jnp.linalg.cond(matrix)` is below it -- or one of the
            strings "zero" (first row set to zero), "trim_row" (first row removed)
            or "trim_col" (first column removed).
        i: unused; only varies the cache key so that several distinct matrices can
            be requested with otherwise-identical arguments.

    Returns:
        A JAX array.
    """
    del i  # used to break the cache
    while True:
        matrix = jr.normal(getkey(), (size, size), dtype=dtype)
        if isinstance(cond_or_singular, str):
            if cond_or_singular == "zero":
                matrix = matrix.at[0, :].set(0)
            elif cond_or_singular == "trim_row":
                matrix = matrix[1:, :]
            elif cond_or_singular == "trim_col":
                matrix = matrix[:, 1:]
        if tags != ():
            # Structured tags are only combined with square matrices: a numeric
            # cutoff or the "zero" variant, never a trimmed matrix.
            assert (
                isinstance(cond_or_singular, (int, float)) or cond_or_singular == "zero"
            )
            # Tags are imposed in this fixed order. NOTE(review): later steps can undo
            # earlier ones (e.g. `matrix + matrix.T` repopulates a zeroed first row),
            # so "zero" + some tags may not yield a singular matrix -- confirm intended.
            if has_tag(tags, lx.diagonal_tag):
                matrix = jnp.diag(jnp.diag(matrix))
            if has_tag(tags, lx.symmetric_tag):
                matrix = matrix + matrix.T
            if has_tag(tags, lx.lower_triangular_tag):
                matrix = jnp.tril(matrix)
            if has_tag(tags, lx.upper_triangular_tag):
                matrix = jnp.triu(matrix)
            if has_tag(tags, lx.unit_diagonal_tag):
                matrix = matrix.at[jnp.arange(size), jnp.arange(size)].set(1)
            if has_tag(tags, lx.tridiagonal_tag):
                # Keep only the main, super- and sub-diagonals.
                diagonal = jnp.diag(jnp.diag(matrix))
                upper_diagonal = jnp.diag(jnp.diag(matrix, k=1), k=1)
                lower_diagonal = jnp.diag(jnp.diag(matrix, k=-1), k=-1)
                matrix = lower_diagonal + diagonal + upper_diagonal
            if has_tag(tags, lx.positive_semidefinite_tag):
                matrix = matrix @ matrix.T.conj()
            if has_tag(tags, lx.negative_semidefinite_tag):
                matrix = -matrix @ matrix.T.conj()
        if isinstance(cond_or_singular, str):
            # Singular/non-square variants need no conditioning check.
            break
        else:
            # Resample until well-conditioned. `eqxi.unvmap_all` presumably reduces a
            # possibly-batched predicate to one boolean -- see equinox.internal.
            if eqxi.unvmap_all(jnp.linalg.cond(matrix) < cond_or_singular):  # pyright: ignore
                break
    return matrix
def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64):
    """Return a tuple of `num` random, well-conditioned `size` x `size` matrices.

    Every matrix satisfies `tags`. `lx.Normal` solves via the normal equations,
    which square the condition number, so it gets the cutoff `sqrt(1000)` rather
    than `1000`.
    """
    cutoff = math.sqrt(1000) if isinstance(solver, lx.Normal) else 1000
    matrices = [
        _construct_matrix_impl(getkey, tags, size, dtype, cutoff, index)
        for index in range(num)
    ]
    return tuple(matrices)
def construct_singular_matrix(
    getkey, solver, tags, num=1, dtype=jnp.float64, *, size=3
):
    """Return a tuple of `num` random singular (or non-square) matrices.

    `lx.Diagonal`, `lx.CG`, `lx.BiCGStab` and `lx.GMRES` always get the square
    "zero" variant (first row set to zero). Every other solver gets one of "zero",
    "trim_row" or "trim_col", chosen at random from `getkey()`.

    Arguments:
        size: size of the square matrix sampled before any row/column trimming.
            Keyword-only and defaulting to 3 (the previously hard-coded value),
            mirroring the `size` argument of `construct_matrix`.
    """
    if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):
        singular_method = "zero"
    else:
        # Use `getkey()` rather than the stdlib `random.choice` for reproducibility
        singular_method = ["zero", "trim_row", "trim_col"][
            jr.choice(getkey(), np.array([0, 1, 2]))
        ]
    return tuple(
        _construct_matrix_impl(getkey, tags, size, dtype, singular_method, i)
        for i in range(num)
    )
def construct_poisson_matrix(size, dtype=jnp.float64):
    """Return the `size` x `size` 1D Poisson (second-difference) matrix.

    The matrix has -2 on the main diagonal and 1 on the super- and sub-diagonals.
    """
    main = jnp.full(size, -2, dtype=dtype)
    off = jnp.ones(size - 1, dtype=dtype)
    return jnp.diag(main) + jnp.diag(off, k=1) + jnp.diag(off, k=-1)
# Solver tolerance for the iterative test solvers: tight when 64-bit dtypes are
# enabled, looser in single precision.
tol = 1e-12 if jax.config.jax_enable_x64 else 1e-6  # pyright: ignore
# Each entry is `(solver, tags, pseudoinverse)`: `tags` are the operator tags the
# solver is tested with, and `pseudoinverse` marks solvers that are additionally
# exercised on singular / non-square matrices (see `construct_singular_matrix`).
solvers_tags_pseudoinverse = [
    (lx.AutoLinearSolver(well_posed=True), (), False),
    (lx.AutoLinearSolver(well_posed=False), (), True),
    (lx.Triangular(), lx.lower_triangular_tag, False),
    (lx.Triangular(), lx.upper_triangular_tag, False),
    (lx.Triangular(), (lx.lower_triangular_tag, lx.unit_diagonal_tag), False),
    (lx.Triangular(), (lx.upper_triangular_tag, lx.unit_diagonal_tag), False),
    (lx.Diagonal(), lx.diagonal_tag, False),
    (lx.Diagonal(), (lx.diagonal_tag, lx.unit_diagonal_tag), False),
    (lx.Tridiagonal(), lx.tridiagonal_tag, False),
    (lx.LU(), (), False),
    (lx.QR(), (), False),
    (lx.SVD(), (), True),
    (lx.BiCGStab(rtol=tol, atol=tol), (), False),
    (lx.GMRES(rtol=tol, atol=tol), (), False),
    (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag, False),
    (lx.CG(rtol=tol, atol=tol), lx.negative_semidefinite_tag, False),
    (lx.Normal(lx.CG(rtol=tol, atol=tol)), (), False),
    (lx.LSMR(atol=tol, rtol=tol), (), True),
    (lx.Cholesky(), lx.positive_semidefinite_tag, False),
    (lx.Cholesky(), lx.negative_semidefinite_tag, False),
    (lx.Normal(lx.Cholesky()), (), False),
]
# Projections of the table above, for tests that need only some of the columns.
solvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse]
solvers = [a for a, _, _ in solvers_tags_pseudoinverse]
# Only the (solver, tags) pairs flagged as supporting pseudoinverses.
pseudosolvers_tags = [(a, b) for a, b, c in solvers_tags_pseudoinverse if c]
def _identity(operator, matrix):
    """Return the pair unchanged."""
    return operator, matrix


def _transpose(operator, matrix):
    """Transpose both the operator and its reference matrix."""
    return operator.T, matrix.T


def _linearise(operator, matrix):
    """Linearise the operator; the reference matrix is unaffected."""
    return lx.linearise(operator), matrix


def _materialise(operator, matrix):
    """Materialise the operator; the reference matrix is unaffected."""
    return lx.materialise(operator), matrix


# Transformations of an `(operator, matrix)` pair that keep the two consistent.
ops = (_identity, _transpose, _linearise, _materialise)
def params(only_pseudo):
    """Yield `(make_operator, solver, tags)` combinations for parametrised tests.

    Arguments:
        only_pseudo: if true, restrict to solvers flagged as supporting
            pseudoinverses in `solvers_tags_pseudoinverse`.

    Some operator constructors only make sense for a single tag value; they are
    paired only with table entries whose `tags` equals that value.
    """
    # NOTE(review): no entry of `solvers_tags_pseudoinverse` has `tags` equal to
    # exactly `lx.unit_diagonal_tag` (unit diagonals only appear inside tuples),
    # so `make_identity_operator` is currently never yielded -- confirm intended.
    required_tags = {
        make_trivial_diagonal_operator: lx.diagonal_tag,
        make_identity_operator: lx.unit_diagonal_tag,
        make_tridiagonal_operator: lx.tridiagonal_tag,
    }
    for make_operator in make_operators:
        for solver, tags, pseudoinverse in solvers_tags_pseudoinverse:
            if only_pseudo and not pseudoinverse:
                continue
            needs_specific_tag = make_operator in required_tags
            if needs_specific_tag and tags != required_tags[make_operator]:
                continue
            yield make_operator, solver, tags
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
    """Return whether pytrees `x` and `y` agree in structure, leaf types and values.

    Array leaves are compared approximately using `rtol` and `atol`, via
    `eqx.tree_equal(..., typematch=True)`.
    """
    tolerances = dict(rtol=rtol, atol=atol)
    return eqx.tree_equal(x, y, typematch=True, **tolerances)
def has_tag(tags, tag):
    """Return whether `tag` is `tags` itself, or is contained in a tuple `tags`."""
    if tag is tags:
        return True
    if isinstance(tags, tuple):
        return tag in tags
    return False
make_operators = []
def _operators_append(x):
make_operators.append(x)
return x
@_operators_append
def make_matrix_operator(getkey, matrix, tags):
    """Wrap `matrix` directly in an `lx.MatrixLinearOperator` with `tags`."""
    del getkey  # no randomness needed
    return lx.MatrixLinearOperator(matrix, tags)
@_operators_append
def make_trivial_pytree_operator(getkey, matrix, tags):
    """Wrap a 2D `matrix` as an `lx.PyTreeLinearOperator` with a single-array output.

    The output structure is a vector of length `matrix.shape[0]` and dtype
    `matrix.dtype`.
    """
    rows, _cols = matrix.shape
    output_structure = jax.ShapeDtypeStruct((rows,), matrix.dtype)
    return lx.PyTreeLinearOperator(matrix, output_structure, tags)
@_operators_append
def make_function_operator(getkey, matrix, tags):
    """Wrap `v -> matrix @ v` as an `lx.FunctionLinearOperator` with `tags`."""
    _, in_size = matrix.shape
    input_structure = jax.ShapeDtypeStruct((in_size,), matrix.dtype)

    def matvec(vector):
        return matrix @ vector

    return lx.FunctionLinearOperator(matvec, input_structure, tags)
def _jacobian_test_problem(getkey, matrix):
    """Sample a quadratic function whose Jacobian at a random point is `matrix`.

    Draws `x`, `a`, `b`, `c` from `getkey()` in that order (so the PRNG stream is
    consumed exactly as before this helper was factored out).

    Returns:
        `(x, quadratic, linear, c)` where `quadratic(y) = a + linear @ y + c @ y**2`
        and `linear` is chosen so that the Jacobian of `quadratic` at `x` equals
        `matrix`.
    """
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    fn_tmp = lambda y, _: a + b @ y + c @ y**2
    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
    # Shifting the linear coefficient by `matrix - jac` moves the Jacobian at `x`
    # from `jac` to exactly `matrix`; the quadratic term is unchanged.
    linear = b + (matrix - jac)

    def quadratic(y):
        return a + linear @ y + c @ y**2

    return x, quadratic, linear, c


@_operators_append
def make_jac_operator(getkey, matrix, tags):
    """`JacobianLinearOperator` (default `jac`) whose Jacobian at `x` is `matrix`."""
    x, quadratic, _, _ = _jacobian_test_problem(getkey, matrix)
    return lx.JacobianLinearOperator(lambda y, _: quadratic(y), x, None, tags)


@_operators_append
def make_jacfwd_operator(getkey, matrix, tags):
    """As `make_jac_operator`, but explicitly using `jac="fwd"`."""
    x, quadratic, _, _ = _jacobian_test_problem(getkey, matrix)
    return lx.JacobianLinearOperator(
        lambda y, _: quadratic(y), x, None, tags, jac="fwd"
    )


@_operators_append
def make_jacrev_operator(getkey, matrix, tags):
    """JacobianLinearOperator with jac='bwd' using a custom_vjp function.

    This uses custom_vjp so that forward-mode autodiff is NOT available,
    which tests that jac='bwd' works correctly without relying on JVP.
    """
    x, quadratic, linear, c = _jacobian_test_problem(getkey, matrix)

    # Use custom_vjp to define a function that only has reverse-mode autodiff
    @jax.custom_vjp
    def custom_fn(y):
        return quadratic(y)

    def custom_fn_fwd(y):
        # The input is the only residual the backward pass needs.
        return custom_fn(y), y

    def custom_fn_bwd(y, g):
        # Jacobian at y: J = linear + 2 * c * y (y broadcast along columns), so
        # the VJP is g @ J = J.T @ g = linear.T @ g + 2 * (c.T @ g) * y.
        return (linear.T @ g + 2 * (c.T @ g) * y,)

    custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)
    return lx.JacobianLinearOperator(
        lambda y, _: custom_fn(y), x, None, tags, jac="bwd"
    )
@_operators_append
def make_trivial_diagonal_operator(getkey, matrix, tags):
    """`lx.DiagonalLinearOperator` built from the main diagonal of `matrix`.

    Only valid for `tags == lx.diagonal_tag` (asserted); `tags` is not forwarded.
    """
    assert tags == lx.diagonal_tag
    return lx.DiagonalLinearOperator(jnp.diag(matrix))
@_operators_append
def make_identity_operator(getkey, matrix, tags):
    """`lx.IdentityLinearOperator` on vectors of length `matrix.shape[-1]`.

    Only the trailing dimension and dtype of `matrix` are used; its entries and
    `tags` are ignored.
    """
    size = matrix.shape[-1]
    return lx.IdentityLinearOperator(
        input_structure=jax.ShapeDtypeStruct((size,), matrix.dtype)
    )
@_operators_append
def make_tridiagonal_operator(getkey, matrix, tags):
    """Build an `lx.TridiagonalLinearOperator` from the central diagonals of `matrix`.

    Supported `tags`:
        - `lx.tridiagonal_tag`: main, sub- and super-diagonals of `matrix`.
        - `lx.diagonal_tag`: main diagonal, zero off-diagonals.
        - `lx.symmetric_tag`: the super-diagonal is used for both off-diagonals.
    Any other `tags` fails an assertion.
    """
    diag1 = jnp.diag(matrix)
    if tags == lx.tridiagonal_tag:
        diag2 = jnp.diag(matrix, k=-1)
        diag3 = jnp.diag(matrix, k=1)
        return lx.TridiagonalLinearOperator(diag1, diag2, diag3)
    elif tags == lx.diagonal_tag:
        # Match `matrix.dtype`: default-dtype zeros would mix dtypes with complex
        # (or float32) diagonals, which strict dtype promotion may reject.
        diag2 = diag3 = jnp.zeros(matrix.shape[0] - 1, dtype=matrix.dtype)
        return lx.TaggedLinearOperator(
            lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.diagonal_tag
        )
    elif tags == lx.symmetric_tag:
        diag2 = diag3 = jnp.diag(matrix, k=1)
        return lx.TaggedLinearOperator(
            lx.TridiagonalLinearOperator(diag1, diag2, diag3), lx.symmetric_tag
        )
    else:
        assert False, tags
@_operators_append
def make_add_operator(getkey, matrix, tags):
    """Represent `matrix` as the sum of two operators, tagged with `tags`.

    The first summand is a matrix operator for `0.7 * matrix`, the second a
    function operator for `0.3 * matrix`.
    """
    matrix_part = make_matrix_operator(getkey, 0.7 * matrix, ())
    function_part = make_function_operator(getkey, 0.3 * matrix, ())
    return lx.TaggedLinearOperator(matrix_part + function_part, tags)
@_operators_append
def make_mul_operator(getkey, matrix, tags):
    """Represent `matrix` as a Jacobian operator for `0.7 * matrix`, divided by 0.7."""
    scaled = make_jac_operator(getkey, 0.7 * matrix, ())
    return lx.TaggedLinearOperator(scaled / 0.7, tags)
@_operators_append
def make_composed_operator(getkey, matrix, tags):
    """Represent `matrix` as `(matrix / d) @ Diagonal(d)` for a random vector `d`.

    Entries of `d` with magnitude below 0.05 are replaced by 0.8 to avoid dividing
    by (near-)zero. The composition is tagged with `tags`.
    """
    _, size = matrix.shape
    scales = jr.normal(getkey(), (size,), dtype=matrix.dtype)
    scales = jnp.where(jnp.abs(scales) < 0.05, 0.8, scales)
    # `scales[None]` broadcasts across rows, dividing column j by scales[j].
    left = make_trivial_pytree_operator(getkey, matrix / scales[None], ())
    right = lx.DiagonalLinearOperator(scales)
    return lx.TaggedLinearOperator(left @ right, tags)
# Slightly sketchy approach to finite differences, in that this is pulled out of
# Numerical Recipes.
# I also don't know of a handling of the JVP case off the top of my head -- although
# I'm sure it exists somewhere -- so I'm improvising a little here. (In particular
# removing the usual "(x + h) - x" denominator.)
def finite_difference_jvp(fn, primals, tangents):
    """Approximate `jax.jvp(fn, primals, tangents)` with a forward difference.

    Arguments:
        fn: function called as `fn(*primals)`; may return a pytree.
        primals: tuple of primal inputs.
        tangents: pytree of tangents with the same structure as `primals`.

    Returns:
        `(fn(*primals), tangents_out)`, where leafwise
        `tangents_out = (fn(*(primals + ε * tangents)) - fn(*primals)) / ε`.
    """
    out = fn(*primals)
    leaves = jtu.tree_leaves(primals)
    # Choose ε to trade-off truncation error and floating-point rounding error:
    # ε = sqrt(machine epsilon) * scale. The machine epsilon is taken from the
    # primals' own (least precise) floating dtype; a hard-coded float64 epsilon is
    # far too small for float32 inputs, where `x + ε` then rounds back to `x`.
    max_leaves = [jnp.max(jnp.abs(p)) for p in leaves] + [1]
    scale = jnp.max(jnp.stack(max_leaves))
    machine_eps = max(
        (
            float(jnp.finfo(jnp.result_type(p)).eps)
            for p in leaves
            if jnp.issubdtype(jnp.result_type(p), jnp.inexact)
        ),
        default=float(jnp.finfo(jnp.float64).eps),
    )
    ε = math.sqrt(machine_eps) * scale
    with jax.numpy_dtype_promotion("standard"):
        primals_ε = jtu.tree_map(lambda p, t: p + ε * t, primals, tangents)
        out_ε = fn(*primals_ε)
        tangents_out = jtu.tree_map(lambda x, y: (x - y) / ε, out_ε, out)
    return out, tangents_out
def jvp_jvp_impl(
    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
    """Check second-order forward-mode derivatives of `lx.linear_solve`.

    `linear_solve` is differentiated with `eqx.filter_jvp` once with respect to
    `(operator, vector)`, and then again with respect to a subset ("mode") of the
    primals `op`/`vec` and first-order tangents `t_op`/`t_vec`. Every output is
    compared with the same construction applied to `jnp.linalg.solve`, or to
    `jnp.linalg.lstsq` when `pseudoinverse` is true.

    Returns without checking anything unless `make_matrix is construct_matrix` or
    `pseudoinverse` is true.
    """
    if not ((make_matrix is construct_matrix) or pseudoinverse):
        return
    t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None

    # NOTE(review): this always calls `construct_matrix`, even when `make_matrix` is
    # `construct_singular_matrix` (with a pseudoinverse solver), so singular
    # matrices are never exercised here -- confirm this is intended.
    matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix(
        getkey, solver, tags, num=4, dtype=dtype
    )

    # Build the operator together with its first- and second-order tangents.
    make_op = ft.partial(make_operator, getkey)
    t_make_operator = lambda p, t_p: eqx.filter_jvp(
        make_op, (p, tags), (t_p, t_tags)
    )
    tt_make_operator = lambda p, t_p, tt_p, tt_t_p: eqx.filter_jvp(
        t_make_operator, (p, t_p), (tt_p, tt_t_p)
    )
    (operator, t_operator), (tt_operator, tt_t_operator) = tt_make_operator(
        matrix, t_matrix, tt_matrix, tt_t_matrix
    )

    out_size, _ = matrix.shape
    vec = jr.normal(getkey(), (out_size,), dtype=dtype)
    t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
    tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
    tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

    if use_state:

        def linear_solve1(operator, vector):
            # Build the solver state from a stop-gradient'd copy of the operator.
            op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)
            stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)
            state = solver.init(stopped_operator, options={})
            sol = lx.linear_solve(operator, vector, state=state, solver=solver)
            return sol.value

    else:

        def linear_solve1(operator, vector):
            sol = lx.linear_solve(operator, vector, solver=solver)
            return sol.value

    if pseudoinverse:
        jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore
    else:
        jnp_solve1 = jnp.linalg.solve  # pyright: ignore

    linear_solve2 = ft.partial(eqx.filter_jvp, linear_solve1)
    jnp_solve2 = ft.partial(eqx.filter_jvp, jnp_solve1)

    # Canonical order of the quantities the outer JVP may differentiate; both the
    # differentiated functions and their (primal, tangent) tuples use this order.
    arg_names = ("op", "vec", "t_op", "t_vec")
    lx_values = dict(op=operator, vec=vec, t_op=t_operator, t_vec=t_vec)
    lx_ttangents = dict(op=tt_operator, vec=tt_vec, t_op=tt_t_operator, t_vec=tt_t_vec)
    jnp_values = dict(op=matrix, vec=vec, t_op=t_matrix, t_vec=t_vec)
    jnp_ttangents = dict(op=tt_matrix, vec=tt_vec, t_op=tt_t_matrix, t_vec=tt_t_vec)

    def _make_solve3(solve2, values, mode):
        """Expose `solve2` as a function of just the `mode` quantities.

        Arguments are taken in `arg_names` order; quantities not in `mode` are
        fixed to their entries in `values`.
        """
        selected = [name for name in arg_names if name in mode]

        def solve3(*args):
            full = {**values, **dict(zip(selected, args))}
            return solve2((full["op"], full["vec"]), (full["t_op"], full["t_vec"]))

        return solve3

    def _make_primal_tangents(mode, values, ttangents):
        """Return the `mode` primals and their second-order tangents, in order."""
        selected = [name for name in arg_names if name in mode]
        primals = tuple(values[name] for name in selected)
        tangents = tuple(ttangents[name] for name in selected)
        return primals, tangents

    # NOTE(review): {"t_op", "t_vec"} and {"op", "t_op", "t_vec"} are not exercised.
    modes = (
        {"op"},
        {"vec"},
        {"t_op"},
        {"t_vec"},
        {"op", "vec"},
        {"op", "t_op"},
        {"op", "t_vec"},
        {"vec", "t_op"},
        {"vec", "t_vec"},
        {"op", "vec", "t_op"},
        {"op", "vec", "t_vec"},
        {"vec", "t_op", "t_vec"},
        {"op", "vec", "t_op", "t_vec"},
    )
    for mode in modes:
        linear_solve3 = eqx.filter_jit(
            ft.partial(eqx.filter_jvp, _make_solve3(linear_solve2, lx_values, mode))
        )
        jnp_solve3 = eqx.filter_jit(
            ft.partial(eqx.filter_jvp, _make_solve3(jnp_solve2, jnp_values, mode))
        )
        primal, tangent = _make_primal_tangents(mode, lx_values, lx_ttangents)
        jnp_primal, jnp_tangent = _make_primal_tangents(
            mode, jnp_values, jnp_ttangents
        )
        # Outputs: ((solution, first-order tangent), (their second-order tangents)).
        (out, t_out), (minus_out, tt_out) = linear_solve3(primal, tangent)
        (true_out, true_t_out), (minus_true_out, true_tt_out) = jnp_solve3(
            jnp_primal, jnp_tangent
        )
        assert tree_allclose(out, true_out, atol=1e-4)
        assert tree_allclose(t_out, true_t_out, atol=1e-4)
        assert tree_allclose(tt_out, true_tt_out, atol=1e-4)
        assert tree_allclose(minus_out, minus_true_out, atol=1e-4)
================================================
FILE: tests/test_adjoint.py
================================================
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from lineax import FunctionLinearOperator
from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
tree_allclose,
)
@pytest.mark.parametrize("make_operator", make_operators)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_adjoint(make_operator, dtype, getkey):
    """Check `<A v1, v2> == <v1, A^H v2>` for two ways of forming the adjoint.

    The adjoint is built both as `conj(A).transpose()` and `conj(A.transpose())`.
    """
    if make_operator is make_jacrev_operator and dtype is jnp.complex128:
        # Report this combination as skipped rather than silently passing.
        pytest.skip(
            'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
        )
    if (
        make_operator is make_trivial_diagonal_operator
        or make_operator is make_identity_operator
    ):
        matrix = jnp.eye(4, dtype=dtype)
        tags = lx.diagonal_tag
        in_size = out_size = 4
    elif make_operator is make_tridiagonal_operator:
        matrix = jnp.eye(4, dtype=dtype)
        tags = lx.tridiagonal_tag
        in_size = out_size = 4
    else:
        matrix = jr.normal(getkey(), (3, 5), dtype=dtype)
        tags = ()
        in_size = 5
        out_size = 3
    operator = make_operator(getkey, matrix, tags)
    v1, v2 = (
        jr.normal(getkey(), (in_size,), dtype=dtype),
        jr.normal(getkey(), (out_size,), dtype=dtype),
    )

    inner1 = operator.mv(v1) @ v2.conj()
    adjoint_op1 = lx.conj(operator).transpose()
    ov2 = adjoint_op1.mv(v2)
    inner2 = v1 @ ov2.conj()
    assert tree_allclose(inner1, inner2)

    adjoint_op2 = lx.conj(operator.transpose())
    ov2 = adjoint_op2.mv(v2)
    inner2 = v1 @ ov2.conj()
    assert tree_allclose(inner1, inner2)
def _check_conj_of_pytree_map(example_input):
    """Assert `lx.conj` leaves the materialised `{"a": y} -> {"b": y}` map unchanged.

    `example_input` is a `{"a": scalar}` dict whose shape/dtype defines the
    operator's input structure.
    """

    def fn(y):
        return {"b": y["a"]}

    operator = FunctionLinearOperator(fn, jax.eval_shape(lambda: example_input))
    conjugated = lx.conj(operator)
    assert tree_allclose(lx.materialise(conjugated), lx.materialise(operator))


def test_functional_pytree_adjoint():
    _check_conj_of_pytree_map({"a": 0.0})


def test_functional_pytree_adjoint_complex():
    _check_conj_of_pytree_map({"a": 0.0j})
# Solver tolerance: tight when 64-bit dtypes are enabled, looser otherwise.
tol = 1e-12 if jax.config.jax_enable_x64 else 1e-6  # pyright: ignore
@pytest.mark.parametrize(
    "solver",
    [
        # in theory only 1 iteration is needed, but stopping criteria are
        # complicated, see gh #160
        lx.GMRES(tol, tol, max_steps=4, restart=1),
        lx.BiCGStab(tol, tol, max_steps=3),
        lx.Normal(lx.CG(tol, tol, max_steps=4)),
        lx.CG(tol, tol, max_steps=3),
    ],
)
def test_preconditioner_adjoint(solver):
    """Test for fix to gh #160"""
    # Nonsymmetric, poorly conditioned system. Without preconditioning this would
    # take 20+ iterations (100s for GMRES).
    key_matrix, key_rhs = jax.random.split(jax.random.key(123))
    n = 10
    A = jax.random.uniform(key_matrix, (n, n))
    A += jnp.diag(jnp.arange(n) ** 6).astype(A.dtype)
    b = jax.random.uniform(key_rhs, (n,))
    if isinstance(solver, lx.CG):
        A = A.T @ A
        tags = (lx.positive_semidefinite_tag,)
    else:
        tags = ()
    operator = lx.MatrixLinearOperator(A, tags=tags)
    A_inv = jnp.linalg.inv(operator.matrix)
    # The exact inverse as preconditioner, so each solve needs only ~1 iteration.
    preconditioner = lx.MatrixLinearOperator(A_inv, tags=tags)

    def solve(rhs):
        solution = lx.linear_solve(
            operator,
            rhs,
            solver=solver,
            options={"preconditioner": preconditioner},
            throw=True,
        )
        return solution.value

    # With `throw=True`, this errors if the solver does not converge.
    solve(b)
    jac_fwd = jax.jacfwd(solve)(b)
    jac_rev = jax.jacrev(solve)(b)
    # Sanity check: d(solution)/d(b) is A^{-1} in both AD modes.
    assert tree_allclose(jac_fwd, A_inv, atol=tol, rtol=tol)
    assert tree_allclose(jac_rev, A_inv, atol=tol, rtol=tol)
================================================
FILE: tests/test_invert.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
from .helpers import tree_allclose
def _well_conditioned_matrix(getkey, size=3, dtype=jnp.float64):
"""Generate a well-conditioned random matrix."""
while True:
matrix = jr.normal(getkey(), (size, size), dtype=dtype)
if jnp.linalg.cond(matrix) < 100:
return matrix
def _well_conditioned_psd_matrix(getkey, size=3, dtype=jnp.float64):
    """Return `M @ M^H` for a random well-conditioned `M`; the result is PSD."""
    factor = _well_conditioned_matrix(getkey, size, dtype)
    return factor @ jnp.conj(factor).T
# -- Core behaviour --


def test_mv(getkey):
    """invert(A).mv(v) solves A x = v."""
    matrix = _well_conditioned_matrix(getkey)
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    solution = lx.invert(lx.MatrixLinearOperator(matrix)).mv(vec)
    assert tree_allclose(solution, jnp.linalg.solve(matrix, vec), atol=1e-10)


def test_composition_identity(getkey):
    """(invert(A) @ A).mv(v) ~ v."""
    operator = lx.MatrixLinearOperator(_well_conditioned_matrix(getkey))
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    round_trip = (lx.invert(operator) @ operator).mv(vec)
    assert tree_allclose(round_trip, vec, atol=1e-10)


def test_double_inverse(getkey):
    """invert(invert(A)).mv(v) ~ A.mv(v)."""
    matrix = _well_conditioned_matrix(getkey)
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    twice_inverted = lx.invert(lx.invert(lx.MatrixLinearOperator(matrix)))
    assert tree_allclose(twice_inverted.mv(vec), matrix @ vec, atol=1e-8)
# -- Pseudoinverse (non-square) --


def _assert_invert_matches_lstsq(getkey, rows, cols):
    """Compare `lx.invert` (ill-posed auto solver) with `jnp.linalg.lstsq`."""
    matrix = jr.normal(getkey(), (rows, cols), dtype=jnp.float64)
    vec = jr.normal(getkey(), (rows,), dtype=jnp.float64)
    pinv_op = lx.invert(
        lx.MatrixLinearOperator(matrix), solver=lx.AutoLinearSolver(well_posed=False)
    )
    expected = jnp.linalg.lstsq(matrix, vec)[0]
    assert tree_allclose(pinv_op.mv(vec), expected, atol=1e-8)


def test_pseudoinverse_overdetermined(getkey):
    """invert of a tall matrix gives the least-squares pseudoinverse."""
    _assert_invert_matches_lstsq(getkey, 5, 3)


def test_pseudoinverse_underdetermined(getkey):
    """invert of a wide matrix gives the minimum-norm pseudoinverse."""
    _assert_invert_matches_lstsq(getkey, 3, 5)
# -- Explicit solver tests --


def _assert_psd_inverse(getkey, solver, atol):
    """Check `lx.invert(A, solver)` on a random PSD-tagged `A` against jnp."""
    matrix = _well_conditioned_psd_matrix(getkey)
    operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    result = lx.invert(operator, solver=solver).mv(vec)
    assert tree_allclose(result, jnp.linalg.solve(matrix, vec), atol=atol)


def test_solver_cholesky(getkey):
    """Works with Cholesky solver for PSD matrices."""
    _assert_psd_inverse(getkey, lx.Cholesky(), atol=1e-10)


def test_solver_cg(getkey):
    """Works with CG (iterative) solver for PSD matrices."""
    _assert_psd_inverse(getkey, lx.CG(rtol=1e-12, atol=1e-12), atol=1e-8)
# -- vmap --


def test_vmap(getkey):
    """vmap over invert(A).mv works correctly."""
    matrix = _well_conditioned_matrix(getkey)
    vecs = jr.normal(getkey(), (5, 3), dtype=jnp.float64)
    inverse = lx.invert(lx.MatrixLinearOperator(matrix))
    batched = jax.vmap(inverse.mv)(vecs)
    reference = jax.vmap(lambda v: jnp.linalg.solve(matrix, v))(vecs)
    assert tree_allclose(batched, reference, atol=1e-10)
# -- AD --


def test_grad_wrt_vector(getkey):
    """VJP through invert(A).mv(v) wrt vector."""
    matrix = _well_conditioned_matrix(getkey)
    inverse = lx.invert(lx.MatrixLinearOperator(matrix))
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    grad = jax.grad(lambda v: jnp.sum(inverse.mv(v)))(vec)
    # d/dv sum(A^{-1} v) = A^{-T} 1.
    expected = jnp.linalg.solve(matrix.T, jnp.ones(3, dtype=jnp.float64))
    assert tree_allclose(grad, expected, atol=1e-10)


def test_jvp_wrt_vector(getkey):
    """JVP through invert(A).mv(v) wrt vector."""
    matrix = _well_conditioned_matrix(getkey)
    inverse = lx.invert(lx.MatrixLinearOperator(matrix))
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    t_vec = jr.normal(getkey(), (3,), dtype=jnp.float64)
    primals, tangents = eqx.filter_jvp(inverse.mv, (vec,), (t_vec,))
    # v -> A^{-1} v is linear, so its tangent is A^{-1} t_vec.
    assert tree_allclose(primals, jnp.linalg.solve(matrix, vec), atol=1e-10)
    assert tree_allclose(tangents, jnp.linalg.solve(matrix, t_vec), atol=1e-10)


def test_grad_wrt_operator(getkey):
    """VJP through invert(A).mv(v) wrt the inner matrix."""
    matrix = _well_conditioned_matrix(getkey)
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)

    def loss_lineax(mat):
        return jnp.sum(lx.invert(lx.MatrixLinearOperator(mat)).mv(vec))

    def loss_jnp(mat):
        return jnp.sum(jnp.linalg.solve(mat, vec))

    grad_lineax = jax.grad(loss_lineax)(matrix)
    grad_reference = jax.grad(loss_jnp)(matrix)
    assert tree_allclose(grad_lineax, grad_reference, atol=1e-8)


def test_jvp_wrt_operator(getkey):
    """JVP through invert(A).mv(v) wrt the inner matrix."""
    matrix = _well_conditioned_matrix(getkey)
    t_matrix = jr.normal(getkey(), (3, 3), dtype=jnp.float64)
    vec = jr.normal(getkey(), (3,), dtype=jnp.float64)

    def apply_lineax(mat):
        return lx.invert(lx.MatrixLinearOperator(mat)).mv(vec)

    def apply_jnp(mat):
        return jnp.linalg.solve(mat, vec)

    out, t_out = eqx.filter_jvp(apply_lineax, (matrix,), (t_matrix,))
    expected_out, expected_t_out = eqx.filter_jvp(apply_jnp, (matrix,), (t_matrix,))
    assert tree_allclose(out, expected_out, atol=1e-10)
    assert tree_allclose(t_out, expected_t_out, atol=1e-8)
================================================
FILE: tests/test_jvp.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
finite_difference_jvp,
has_tag,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
tree_allclose,
)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_jvp(
    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
    """Check first-order JVPs of `lx.linear_solve` against `jnp.linalg.lstsq`.

    Differentiates with respect to the vector only, the operator only, and both.
    """
    t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
    if not ((make_matrix is construct_matrix) or pseudoinverse):
        # Singular matrices are only tested with solvers flagged `pseudoinverse`.
        return
    matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)

    out_size, _ = matrix.shape
    vec = jr.normal(getkey(), (out_size,), dtype=dtype)
    t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

    if has_tag(tags, lx.unit_diagonal_tag):
        # For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
        # still satisfies the tag itself.
        # This is the exception, so zero the tangent's diagonal. JAX arrays are
        # immutable: the result of `.at[...].set` must be assigned back.
        diag = jnp.arange(min(t_matrix.shape))
        t_matrix = t_matrix.at[diag, diag].set(0)

    make_op = ft.partial(make_operator, getkey)
    operator, t_operator = eqx.filter_jvp(
        make_op, (matrix, tags), (t_matrix, t_tags)
    )

    if use_state:
        state = solver.init(operator, options={})
        linear_solve = ft.partial(lx.linear_solve, state=state)
    else:
        linear_solve = lx.linear_solve

    solve_vec_only = lambda v: linear_solve(operator, v, solver).value
    solve_op_only = lambda op: linear_solve(op, vec, solver).value
    solve_op_vec = lambda op, v: linear_solve(op, v, solver).value

    vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))
    op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))
    op_vec_out, t_op_vec_out = eqx.filter_jvp(
        solve_op_vec,
        (operator, vec),
        (t_operator, t_vec),
    )
    (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(
        lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore
        (matrix,),
        (t_matrix,),
    )
    (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(
        jnp.linalg.lstsq,
        (matrix, vec),
        (t_matrix, t_vec),  # pyright: ignore
    )

    # Work around JAX issue #14868.
    if jnp.any(jnp.isnan(t_expected_op_out)):
        _, (t_expected_op_out, *_) = finite_difference_jvp(
            lambda op: jnp.linalg.lstsq(op, vec),  # pyright: ignore
            (matrix,),
            (t_matrix,),
        )
    if jnp.any(jnp.isnan(t_expected_op_vec_out)):
        _, (t_expected_op_vec_out, *_) = finite_difference_jvp(
            jnp.linalg.lstsq,
            (matrix, vec),
            (t_matrix, t_vec),  # pyright: ignore
        )

    pinv_matrix = jnp.linalg.pinv(matrix)  # pyright: ignore
    expected_vec_out = pinv_matrix @ vec
    assert tree_allclose(vec_out, expected_vec_out)
    assert tree_allclose(op_out, expected_op_out)
    assert tree_allclose(op_vec_out, expected_op_vec_out)

    t_expected_vec_out = pinv_matrix @ t_vec
    # Compared after multiplying by `matrix`, presumably so that any component in
    # the null space of `matrix` (singular case) does not affect the comparison.
    assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3)
    assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3)
    assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3)
================================================
FILE: tests/test_jvp_jvp1.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax.numpy as jnp
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
jvp_jvp_impl,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
)
# Workaround for https://github.com/jax-ml/jax/issues/27201
@pytest.fixture(autouse=True)
def _clear_cache():
    """Call ``eqx.clear_caches()`` before every test in this module (autouse)."""
    eqx.clear_caches()
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix))
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_jvp_jvp(
    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
    """Run the shared ``jvp_jvp_impl`` check with ``dtype`` pinned to float64.

    All of the logic lives in ``jvp_jvp_impl``; this module only fixes the dtype
    parametrisation.
    """
    impl_args = (
        getkey,
        solver,
        tags,
        pseudoinverse,
        make_operator,
        use_state,
        make_matrix,
        dtype,
    )
    jvp_jvp_impl(*impl_args)
================================================
FILE: tests/test_jvp_jvp2.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax.numpy as jnp
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
jvp_jvp_impl,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
)
# Workaround for https://github.com/jax-ml/jax/issues/27201
@pytest.fixture(autouse=True)
def _clear_cache():
    """Call ``eqx.clear_caches()`` before every test in this module (autouse)."""
    eqx.clear_caches()
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize("make_matrix", (construct_matrix, construct_singular_matrix))
@pytest.mark.parametrize("dtype", (jnp.complex128,))
def test_jvp_jvp(
    getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
    """Run the shared ``jvp_jvp_impl`` check with ``dtype`` pinned to complex128.

    All of the logic lives in ``jvp_jvp_impl``; this module only fixes the dtype
    parametrisation.
    """
    impl_args = (
        getkey,
        solver,
        tags,
        pseudoinverse,
        make_operator,
        use_state,
        make_matrix,
        dtype,
    )
    jvp_jvp_impl(*impl_args)
================================================
FILE: tests/test_lsmr.py
================================================
import equinox as ex
import jax.numpy as jnp
import lineax as lx
import pytest
# LSMR solver shared by every test in this module.
solver = lx.LSMR(1e-10, 1e-10)

# Diagonal test operators.
# Well-conditioned: condition number 10 / 2 = 5.
Awell = lx.DiagonalLinearOperator(jnp.array([2.0, 4.0, 5.0, 8.0, 10.0]))
# Singular: the first diagonal entry is zero.
Asing = lx.DiagonalLinearOperator(jnp.array([0.0, 4.0, 5.0, 8.0, 10.0]))
# Ill-conditioned: diagonal spans eight orders of magnitude (1e8 down to 1).
Aill = lx.DiagonalLinearOperator(jnp.array([1e8, 1e6, 1e4, 1e2, 1]))
def test_ill_conditioned():
    """Solving with ``Aill`` either succeeds or fails with a condition-number error.

    Both outcomes are accepted. The only requirement is that, if an
    ``EquinoxRuntimeError`` is raised, its message mentions "Condition number".
    """
    rhs = jnp.ones(5)
    try:
        lx.linear_solve(Aill, rhs, solver=solver)
    except ex.EquinoxRuntimeError as err:
        message = str(err)
        assert "Condition number" in message
def test_zero_rhs():
    """LSMR returns exactly zero whenever zero is the minimum-norm solution."""
    zeros = jnp.zeros(5)
    # b = 0, so x = 0 is a solution for every operator.
    for operator in (Aill, Awell, Asing):
        sol = lx.linear_solve(operator, zeros, solver=solver)
        assert (sol.value == 0).all()
    # b lies in the null space of Asing, so x = 0 is the minimum-norm solution.
    rhs_in_null_space = zeros.at[0].set(1)
    sol = lx.linear_solve(Asing, rhs_in_null_space, solver=solver)
    assert (sol.value == 0).all()
@pytest.mark.skip("Damp support is disabled.")
def test_damp_regularizes():
    """Damping changes LSMR's stop reason and cuts its step count on ``Aill``."""
    rhs = jnp.ones(5)
    undamped = lx.linear_solve(Aill, rhs, solver=solver, options={})
    assert undamped.stats["istop"] == 1
    damped = lx.linear_solve(Aill, rhs, solver=solver, options={"damp": 100.0})
    assert damped.stats["istop"] == 2
    assert damped.stats["num_steps"] < undamped.stats["num_steps"]
@pytest.mark.skip("Damp support is disabled.")
def test_damp():
    """Damped LSMR solutions on ``Awell`` match precomputed reference values."""
    rhs = jnp.ones(5)
    cases = (
        (1.0, jnp.array([0.4, 0.23529412, 0.19230769, 0.12307692, 0.0990099])),
        (1000.0, jnp.array([2e-6, 4e-6, 5e-6, 8e-6, 10.0e-6])),
    )
    for damp, expected in cases:
        solution = lx.linear_solve(
            Awell, rhs, solver=solver, options={"damp": damp}
        )
        assert jnp.allclose(solution.value, expected)
================================================
FILE: tests/test_misc.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import lineax as lx
import lineax._misc as lx_misc
import pytest
def test_inexact_asarray_no_copy():
    """An inexact JAX array is returned as the identical object, also under vmap."""
    single = jnp.array([1.0])
    assert lx_misc.inexact_asarray(single) is single
    batched = jnp.array([1.0, 2.0])
    mapped = jax.vmap(lx_misc.inexact_asarray)
    assert mapped(batched) is batched
# See JAX issue #15676
def test_inexact_asarray_jvp():
    """Under ``jax.jvp``, neither the primal nor the tangent is a plain float."""
    outputs = jax.jvp(lx_misc.inexact_asarray, (1.0,), (2.0,))
    primal_out, tangent_out = outputs
    for value in (primal_out, tangent_out):
        assert type(value) is not float
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_zero_matrix(dtype):
    """SVD on an all-zero matrix must return the pseudoinverse solution, zero.

    The pseudoinverse of the zero matrix is the zero matrix, so the minimum-norm
    least-squares solution is zero for any right-hand side. Previously this test
    only checked that the solve did not raise; it now also checks the value, which
    in particular rules out NaNs from dividing by zero singular values.
    """
    A = lx.MatrixLinearOperator(jnp.zeros((2, 2), dtype=dtype))
    b = jnp.array([1.0, 2.0], dtype=dtype)
    solution = lx.linear_solve(A, b, lx.SVD())
    assert jnp.allclose(solution.value, jnp.zeros_like(solution.value))
================================================
FILE: tests/test_norm.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.flatten_util as jfu
import jax.numpy as jnp
import lineax.internal as lxi
from .helpers import tree_allclose
def _square(x):
    """Elementwise ``x * conj(x)``, i.e. ``|x|**2`` in the dtype of ``x``."""
    conj_x = jnp.conj(x)
    return x * conj_x
def _two_norm(x):
    """Reference 2-norm of a pytree: flatten, sum ``|x_i|**2``, sqrt, real part."""
    flat, _ = jfu.ravel_pytree(x)
    return jnp.sqrt(_square(flat).sum()).real
def _rms_norm(x):
    """Reference RMS norm of a pytree: flatten, mean ``|x_i|**2``, sqrt, real part."""
    flat, _ = jfu.ravel_pytree(x)
    return jnp.sqrt(_square(flat).mean()).real
def _max_norm(x):
    """Reference max norm of a pytree: the largest absolute entry after flattening."""
    flat, _ = jfu.ravel_pytree(x)
    return jnp.abs(flat).max()
def test_nonzero():
    """At a nonzero point, lineax's norms and their JVPs match naive references.

    Also checks that a zero tangent produces a zero JVP.
    """
    zero = [jnp.array(0.0), jnp.zeros((2, 2))]
    x = [jnp.array(1.0), jnp.arange(4.0).reshape(2, 2)]
    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]

    # Primal values. ``max_`` carries a trailing underscore so that it does not
    # shadow the builtin ``max``.
    two = lxi.two_norm(x)
    rms = lxi.rms_norm(x)
    max_ = lxi.max_norm(x)
    true_two = _two_norm(x)
    true_rms = _rms_norm(x)
    true_max = _max_norm(x)
    assert jnp.allclose(two, true_two)
    assert jnp.allclose(rms, true_rms)
    assert jnp.allclose(max_, true_max)

    # JVPs along a nonzero tangent.
    two_jvp = jax.jvp(lxi.two_norm, (x,), (tx,))
    true_two_jvp = jax.jvp(_two_norm, (x,), (tx,))
    rms_jvp = jax.jvp(lxi.rms_norm, (x,), (tx,))
    true_rms_jvp = jax.jvp(_rms_norm, (x,), (tx,))
    max_jvp = jax.jvp(lxi.max_norm, (x,), (tx,))
    true_max_jvp = jax.jvp(_max_norm, (x,), (tx,))
    assert tree_allclose(two_jvp, true_two_jvp)
    assert tree_allclose(rms_jvp, true_rms_jvp)
    assert tree_allclose(max_jvp, true_max_jvp)

    # JVPs along a zero tangent must be exactly zero.
    two0_jvp = jax.jvp(lxi.two_norm, (x,), (zero,))
    rms0_jvp = jax.jvp(lxi.rms_norm, (x,), (zero,))
    max0_jvp = jax.jvp(lxi.max_norm, (x,), (zero,))
    assert tree_allclose(two0_jvp, (true_two, jnp.array(0.0)))
    assert tree_allclose(rms0_jvp, (true_rms, jnp.array(0.0)))
    assert tree_allclose(max0_jvp, (true_max, jnp.array(0.0)))
def test_zero():
    """At x = 0 every norm and its JVP is zero, for zero and nonzero tangents."""
    zero = [jnp.array(0.0), jnp.zeros((2, 2))]
    tx = [jnp.array(0.5), jnp.arange(1.0, 5.0).reshape(2, 2)]
    expected = (jnp.array(0.0), jnp.array(0.0))
    norms = (lxi.two_norm, lxi.rms_norm, lxi.max_norm)
    for tangent in (zero, tx):
        for norm in norms:
            assert tree_allclose(jax.jvp(norm, (zero,), (tangent,)), expected)
def test_complex():
    """For complex inputs, norms are real and their JVPs match naive references."""
    x = jnp.array([3 + 1.2j, -0.5 + 4.9j])
    tx = jnp.array([2 - 0.3j, -0.7j])
    two = jax.jvp(lxi.two_norm, (x,), (tx,))
    true_two = jax.jvp(_two_norm, (x,), (tx,))
    rms = jax.jvp(lxi.rms_norm, (x,), (tx,))
    true_rms = jax.jvp(_rms_norm, (x,), (tx,))
    # Trailing underscore avoids shadowing the builtin ``max``.
    max_ = jax.jvp(lxi.max_norm, (x,), (tx,))
    true_max = jax.jvp(_max_norm, (x,), (tx,))
    # Each result is a (primal, tangent) pair; the primal norm must be real.
    assert two[0].imag == 0
    assert tree_allclose(two, true_two)
    assert rms[0].imag == 0
    assert tree_allclose(rms, true_rms)
    assert max_[0].imag == 0
    assert tree_allclose(max_, true_max)
def test_size_zero():
    """Norms of empty inputs (a bare array or a list of arrays) and their JVPs are 0."""
    zero = jnp.array(0.0)
    norms = (lxi.two_norm, lxi.rms_norm, lxi.max_norm)
    empty_inputs = (jnp.array([]), [jnp.array([]), jnp.array([])])
    for x in empty_inputs:
        for norm in norms:
            assert tree_allclose(norm(x), zero)
        for norm in norms:
            assert tree_allclose(jax.jvp(norm, (x,), (x,)), (zero, zero))
================================================
FILE: tests/test_operator.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
tree_allclose,
)
@pytest.mark.parametrize("make_operator", make_operators)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_ops(make_operator, getkey, dtype):
    """Operator ``+``, ``@``, scalar ``*`` (both sides) and ``/`` match dense algebra.

    Each combined operator is compared through ``mv``, ``as_matrix`` and the
    transpose's ``as_matrix``.
    """
    if make_operator in (make_trivial_diagonal_operator, make_identity_operator):
        base = jnp.eye(3, dtype=dtype)
        tags = lx.diagonal_tag
    elif make_operator is make_tridiagonal_operator:
        base = jnp.eye(3, dtype=dtype)
        tags = lx.tridiagonal_tag
    else:
        base = jr.normal(getkey(), (3, 3), dtype=dtype)
        tags = ()
    if make_operator is make_jacrev_operator and dtype is jnp.complex128:
        # JacobianLinearOperator does not support complex dtypes when jac="bwd"
        return

    # NB: getkey() must be consumed in this order to keep the random draws fixed.
    op1 = make_operator(getkey, base, tags)
    op2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
    scalar = jr.normal(getkey(), (), dtype=dtype)
    add = op1 + op2
    composed = op1 @ op2
    mul = op1 * scalar
    rmul = cast(lx.AbstractLinearOperator, scalar * op1)
    div = op1 / scalar
    vec = jr.normal(getkey(), (3,), dtype=dtype)

    # Matrix-vector products.
    op1_vec = op1.mv(vec)
    assert tree_allclose(op1_vec + op2.mv(vec), add.mv(vec))
    assert tree_allclose(op1.mv(op2.mv(vec)), composed.mv(vec))
    scaled_vec = scalar * op1_vec
    assert tree_allclose(scaled_vec, mul.mv(vec))
    assert tree_allclose(scaled_vec, rmul.mv(vec))
    assert tree_allclose(op1_vec / scalar, div.mv(vec))

    # Dense materialisations, then their transposes.
    dense1 = op1.as_matrix()
    dense2 = op2.as_matrix()
    scaled_dense = scalar * dense1
    pairs = (
        (dense1 + dense2, add),
        (dense1 @ dense2, composed),
        (scaled_dense, mul),
        (scaled_dense, rmul),
        (dense1 / scalar, div),
    )
    for expected, operator in pairs:
        assert tree_allclose(expected, operator.as_matrix())
    for expected, operator in pairs:
        assert tree_allclose(expected.T, operator.T.as_matrix())
@pytest.mark.parametrize("make_operator", make_operators)
def test_structures_vector(make_operator, getkey):
    """``in_structure``/``out_structure`` report float64 vectors of the right size."""
    if make_operator in (make_trivial_diagonal_operator, make_identity_operator):
        matrix, tags = jnp.eye(4), lx.diagonal_tag
        in_size = out_size = 4
    elif make_operator is make_tridiagonal_operator:
        matrix, tags = jnp.eye(4), lx.tridiagonal_tag
        in_size = out_size = 4
    else:
        matrix, tags = jr.normal(getkey(), (3, 5)), ()
        out_size, in_size = matrix.shape
    operator = make_operator(getkey, matrix, tags)
    expected_in = jax.ShapeDtypeStruct((in_size,), jnp.float64)
    expected_out = jax.ShapeDtypeStruct((out_size,), jnp.float64)
    assert tree_allclose(expected_in, operator.in_structure())
    assert tree_allclose(expected_out, operator.out_structure())
def _setup(getkey, matrix, tag: object | frozenset[object] = frozenset()):
    """Lazily yield one operator per ``make_operators`` entry, built from ``matrix``.

    The trivial-diagonal constructor is used only when ``tag`` is the diagonal
    tag; the tridiagonal and identity constructors only when ``tag`` is the
    tridiagonal, diagonal or symmetric tag. All other constructors are always used.
    """
    structured_tags = (lx.tridiagonal_tag, lx.diagonal_tag, lx.symmetric_tag)
    for make_operator in make_operators:
        if make_operator is make_trivial_diagonal_operator:
            if tag != lx.diagonal_tag:
                continue
        elif make_operator in (make_tridiagonal_operator, make_identity_operator):
            if tag not in structured_tags:
                continue
        yield make_operator(getkey, matrix, tag)
def _assert_except_diag(cond_fun, operators, flip_cond):
    """Assert ``cond_fun`` holds for non-diagonal operators and fails for diagonal ones.

    "Diagonal" means an instance of ``lx.DiagonalLinearOperator``. With
    ``flip_cond=True`` the condition is negated first, so ``cond_fun`` must then be
    false for non-diagonal operators and true for diagonal ones.
    """
    for operator in operators:
        holds = bool(cond_fun(operator))
        if flip_cond:
            holds = not holds
        is_diagonal_operator = isinstance(operator, lx.DiagonalLinearOperator)
        assert holds != is_diagonal_operator
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_linearise(dtype, getkey):
    """``lx.linearise`` yields an operator with the same matrix-vector product."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    operators = list(_setup(getkey, matrix))
    vec = jr.normal(getkey(), (3,), dtype=dtype)
    for operator in operators:
        uses_jacrev = (
            isinstance(operator, lx.JacobianLinearOperator) and operator.jac == "bwd"
        )
        if uses_jacrev and dtype is jnp.complex128:
            # jacrev doesn't support complex dtypes.
            continue
        linearised = lx.linearise(operator)
        # Actually evaluate the linearised operator to ensure it works.
        assert tree_allclose(linearised.mv(vec), operator.mv(vec))
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_materialise(dtype, getkey):
    """``lx.materialise`` runs without raising on every untagged 3x3 operator."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, matrix):
        lx.materialise(operator)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_materialise_large(dtype, getkey):
    """``lx.materialise`` runs on every untagged operator built from a 200x500 matrix."""
    matrix = jr.normal(getkey(), (200, 500), dtype=dtype)
    for operator in _setup(getkey, matrix):
        lx.materialise(operator)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_diagonal(dtype, getkey):
    """``lx.diagonal`` extracts the main diagonal, with or without a diagonal tag."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    main_diag = jnp.diag(matrix)

    # Untagged dense operators: the diagonal comes from the full matrix.
    for operator in _setup(getkey, matrix):
        assert jnp.allclose(lx.diagonal(operator), main_diag)

    # Diagonal-tagged operators built from diag(main_diag). The identity operator
    # ignores the matrix values, so its diagonal is all ones.
    for operator in _setup(getkey, jnp.diag(main_diag), lx.diagonal_tag):
        if isinstance(operator, lx.IdentityLinearOperator):
            expected = jnp.ones(3)
        else:
            expected = main_diag
        assert jnp.allclose(lx.diagonal(operator), expected)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_tridiagonal(dtype, getkey):
    """``lx.tridiagonal`` returns (diag, lower, upper), incl. diagonal compositions."""
    dense = jr.normal(getkey(), (5, 5), dtype=dtype)
    diag = jnp.diag(dense)
    lower = jnp.diag(dense, k=-1)
    upper = jnp.diag(dense, k=1)
    tridiag_matrix = jnp.diag(diag) + jnp.diag(lower, k=-1) + jnp.diag(upper, k=1)

    def assert_bands(actual, expected):
        (got_d, got_l, got_u), (want_d, want_l, want_u) = actual, expected
        assert jnp.allclose(got_d, want_d)
        assert jnp.allclose(got_l, want_l)
        assert jnp.allclose(got_u, want_u)

    def bands_of(matrix):
        # (main, sub, super) diagonals, matching the order lx.tridiagonal returns.
        return tuple(jnp.diagonal(matrix, offset) for offset in (0, -1, 1))

    for operator in _setup(getkey, tridiag_matrix, lx.tridiagonal_tag):
        if isinstance(operator, lx.IdentityLinearOperator):
            expected = (jnp.ones(5), jnp.zeros(4), jnp.zeros(4))
        else:
            expected = (diag, lower, upper)
        assert_bands(lx.tridiagonal(operator), expected)

    # Test ComposedLinearOperator: diagonal @ tridiagonal and tridiagonal @ diagonal
    scaling = jr.normal(getkey(), (5,), dtype=dtype)
    tridiag_op = lx.TridiagonalLinearOperator(diag, lower, upper)
    diag_op = lx.DiagonalLinearOperator(scaling)
    scaling_matrix = jnp.diag(scaling)
    # diagonal @ tridiagonal (row scaling)
    assert_bands(
        lx.tridiagonal(diag_op @ tridiag_op), bands_of(scaling_matrix @ tridiag_matrix)
    )
    # tridiagonal @ diagonal (column scaling)
    assert_bands(
        lx.tridiagonal(tridiag_op @ diag_op), bands_of(tridiag_matrix @ scaling_matrix)
    )
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_symmetric(dtype, getkey):
    """Symmetric-tagged operators are symmetric; untagged ones only if diagonal ops."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    sym_matrix = matrix.T @ matrix
    for operator in _setup(getkey, sym_matrix, lx.symmetric_tag):
        assert lx.is_symmetric(operator)
    _assert_except_diag(lx.is_symmetric, _setup(getkey, matrix), flip_cond=True)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_diagonal(dtype, getkey):
    """Diagonal-tagged operators are diagonal; untagged ones only if diagonal ops."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    diag_matrix = jnp.diag(jnp.diag(matrix))
    for operator in _setup(getkey, diag_matrix, lx.diagonal_tag):
        assert lx.is_diagonal(operator)
    _assert_except_diag(lx.is_diagonal, _setup(getkey, matrix), flip_cond=True)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_diagonal_scalar(dtype, getkey):
    """Every operator built from an untagged 1x1 matrix is diagonal."""
    one_by_one = jr.normal(getkey(), (1, 1), dtype=dtype)
    for operator in _setup(getkey, one_by_one):
        assert lx.is_diagonal(operator)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_diagonal_tridiagonal(dtype, getkey):
    """A 1x1 tridiagonal operator (empty off-diagonals) counts as diagonal."""
    main = jr.normal(getkey(), (1,), dtype=dtype)
    empty = jnp.zeros((0,), dtype=dtype)
    assert lx.is_diagonal(lx.TridiagonalLinearOperator(main, empty, empty))
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_has_unit_diagonal(dtype, getkey):
    """Unit-diagonal-tagged operators report it, except DiagonalLinearOperators.

    Untagged random operators never report a unit diagonal.
    """
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, matrix):
        assert not lx.has_unit_diagonal(operator)
    idx = jnp.arange(3)
    unit_diag_matrix = matrix.at[idx, idx].set(1)
    tagged = _setup(getkey, unit_diag_matrix, lx.unit_diagonal_tag)
    _assert_except_diag(lx.has_unit_diagonal, tagged, flip_cond=False)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_lower_triangular(dtype, getkey):
    """Lower-triangular-tagged operators report it; untagged ones only if diagonal ops."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, jnp.tril(matrix), lx.lower_triangular_tag):
        assert lx.is_lower_triangular(operator)
    untagged = _setup(getkey, matrix)
    _assert_except_diag(lx.is_lower_triangular, untagged, flip_cond=True)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_upper_triangular(dtype, getkey):
    """Upper-triangular-tagged operators report it; untagged ones only if diagonal ops."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, jnp.triu(matrix), lx.upper_triangular_tag):
        assert lx.is_upper_triangular(operator)
    untagged = _setup(getkey, matrix)
    _assert_except_diag(lx.is_upper_triangular, untagged, flip_cond=True)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_positive_semidefinite(dtype, getkey):
    """PSD-tagged operators report PSD (except DiagonalLinearOperators); untagged don't."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, matrix):
        assert not lx.is_positive_semidefinite(operator)
    # A^H A is Hermitian positive semidefinite.
    psd_matrix = matrix.T.conj() @ matrix
    tagged = _setup(getkey, psd_matrix, lx.positive_semidefinite_tag)
    _assert_except_diag(lx.is_positive_semidefinite, tagged, flip_cond=False)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_negative_semidefinite(dtype, getkey):
    """NSD-tagged operators report NSD (except DiagonalLinearOperators); untagged don't."""
    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    for operator in _setup(getkey, matrix):
        assert not lx.is_negative_semidefinite(operator)
    # -(A^H A) is Hermitian negative semidefinite.
    nsd_matrix = -matrix.T.conj() @ matrix
    tagged = _setup(getkey, nsd_matrix, lx.negative_semidefinite_tag)
    _assert_except_diag(lx.is_negative_semidefinite, tagged, flip_cond=False)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_is_tridiagonal(dtype, getkey):
    """Tridiagonal and identity operators are tridiagonal; an untagged matrix is not."""
    main = jr.normal(getkey(), (5,), dtype=dtype)
    lower = jr.normal(getkey(), (4,), dtype=dtype)
    upper = jr.normal(getkey(), (4,), dtype=dtype)
    tridiag_op = lx.TridiagonalLinearOperator(main, lower, upper)
    identity_op = lx.IdentityLinearOperator(jax.eval_shape(lambda: main))
    # Diagonal content, but untagged, so it must not be reported tridiagonal.
    dense_op = lx.MatrixLinearOperator(jnp.diag(main))
    assert lx.is_tridiagonal(tridiag_op)
    assert lx.is_tridiagonal(identity_op)
    assert not lx.is_tridiagonal(dense_op)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_tangent_as_matrix(dtype, getkey):
    """``TangentLinearOperator.as_matrix`` materialises the JVP tangent of an operator.

    Diagonal operators only see the diagonal of the input, so their expected
    primal and tangent matrices are diagonalised.
    """

    def build_operators(matrix):
        # Exclude jacrev operator: jac="bwd" uses custom_vjp which doesn't support JVP
        return [
            operator
            for operator in _setup(getkey, matrix)
            if not (
                isinstance(operator, lx.JacobianLinearOperator)
                and operator.jac == "bwd"
            )
        ]

    matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
    operators, t_operators = eqx.filter_jvp(build_operators, (matrix,), (t_matrix,))
    for primal_op, tangent_part in zip(operators, t_operators):
        tangent_op = lx.TangentLinearOperator(primal_op, tangent_part)
        if isinstance(primal_op, lx.DiagonalLinearOperator):
            expected = jnp.diag(jnp.diag(matrix))
            t_expected = jnp.diag(jnp.diag(t_matrix))
        else:
            expected, t_expected = matrix, t_matrix
        assert jnp.allclose(primal_op.as_matrix(), expected)
        assert jnp.allclose(tangent_op.as_matrix(), t_expected)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_materialise_function_linear_operator(dtype, getkey):
    """Materialising a FunctionLinearOperator yields a PyTreeLinearOperator.

    The result keeps the input/output structures. Its pytree holds one leaf per
    (output leaf, input leaf) pair, shaped ``out_shape + in_shape``.
    """
    x = (
        jr.normal(getkey(), (5, 9), dtype=dtype),
        jr.normal(getkey(), (3,), dtype=dtype),
    )
    in_struct = jax.eval_shape(lambda: x)

    def fn(inputs):
        return {"a": jnp.broadcast_to(jnp.sum(inputs[0]), (1, 2))}

    out_struct = jax.eval_shape(fn, in_struct)
    materialised = lx.materialise(lx.FunctionLinearOperator(fn, in_struct))
    assert materialised.in_structure() == in_struct
    assert materialised.out_structure() == out_struct
    assert isinstance(materialised, lx.PyTreeLinearOperator)
    expected_struct = {
        "a": (
            jax.ShapeDtypeStruct((1, 2, 5, 9), dtype),
            jax.ShapeDtypeStruct((1, 2, 3), dtype),
        )
    }
    assert jax.eval_shape(lambda: materialised.pytree) == expected_struct
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_pytree_transpose(dtype, getkey):
    """Transposing a PyTreeLinearOperator swaps structures and nests input-first.

    Each leaf's trailing (input) axis moves to the front.
    """
    in_struct = jax.eval_shape(lambda: {"b": jnp.zeros((4,), dtype=dtype)})
    out_struct = jax.eval_shape(
        lambda: ({"a": jnp.zeros((2, 3, 3), dtype=dtype)}, jnp.zeros((2,), dtype=dtype))
    )
    leaf1 = jr.normal(getkey(), (2, 3, 3, 4), dtype=dtype)
    leaf2 = jr.normal(getkey(), (2, 4), dtype=dtype)
    operator = lx.PyTreeLinearOperator(({"a": {"b": leaf1}}, {"b": leaf2}), out_struct)
    assert operator.in_structure() == in_struct
    assert operator.out_structure() == out_struct

    expected_pytree_T = {
        "b": ({"a": jnp.moveaxis(leaf1, -1, 0)}, jnp.moveaxis(leaf2, -1, 0))
    }
    transposed = operator.T
    assert transposed.in_structure() == out_struct
    assert transposed.out_structure() == in_struct
    assert eqx.tree_equal(transposed.pytree, expected_pytree_T)  # pyright: ignore
def test_diagonal_tangent():
    """JVP through a Diagonal solve must match the closed-form derivative.

    With x = b / d elementwise, differentiating with respect to d gives
    dx = -b * dd / d**2. Previously the JVP result was discarded, so only
    "does not crash" was tested.
    """
    diag = jnp.array([1.0, 2.0, 3.0])
    t_diag = jnp.array([4.0, 5.0, 6.0])
    rhs = jnp.array([1.0, 1.0, 1.0])

    def run(diag):
        op = lx.DiagonalLinearOperator(diag)
        out = lx.linear_solve(op, rhs, solver=lx.Diagonal())
        return out.value

    out, t_out = jax.jvp(run, (diag,), (t_diag,))
    assert tree_allclose(out, rhs / diag)
    assert tree_allclose(t_out, -rhs * t_diag / diag**2)
def _check_identity_with_different_structures(dtype):
    """Shared checks for an identity operator between mismatched structures.

    ``structure1`` flattens to 1 + 2*3 = 7 entries and ``structure2`` to 5. The
    identity from ``structure1`` to ``structure2`` therefore acts like
    ``eye(5, 7)``: it truncates going one way and zero-pads going back. ``dtype``
    is used for the scalar leaf and for ``structure2``; the (2, 3) leaf is
    always float16.
    """
    structure1 = (
        jax.ShapeDtypeStruct((), dtype),
        jax.ShapeDtypeStruct((2, 3), jnp.float16),
    )
    structure2 = {"a": jax.ShapeDtypeStruct((5,), dtype)}
    op1 = lx.IdentityLinearOperator(structure1, structure2)
    op2 = lx.IdentityLinearOperator(structure2, structure1)
    # NOTE(review): a partial-transpose check (``op2.transpose((True, False))``
    # against an identity whose input structure has a ``None`` leaf) is left
    # disabled; presumably not yet supported. TODO confirm.
    assert op1.T == op2
    assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=dtype))
    assert op1.in_size() == 7
    assert op1.out_size() == 5
    vec1 = (
        jnp.array(1.0, dtype=dtype),
        jnp.array([[2, 3, 4], [5, 6, 7]], dtype=jnp.float16),
    )
    vec2 = {"a": jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dtype)}
    # Mapping back zero-pads the entries that were truncated away.
    vec1b = (
        jnp.array(1.0, dtype=dtype),
        jnp.array([[2, 3, 4], [5, 0, 0]], dtype=jnp.float16),
    )
    assert tree_allclose(op1.mv(vec1), vec2)
    assert tree_allclose(op2.mv(vec2), vec1b)


def test_identity_with_different_structures():
    """Identity between mismatched real (float32 / float16) structures."""
    _check_identity_with_different_structures(jnp.float32)


def test_identity_with_different_structures_complex():
    """Identity between mismatched structures whose main leaves are complex128."""
    _check_identity_with_different_structures(jnp.complex128)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_zero_pytree_as_matrix(dtype):
    """A PyTreeLinearOperator with zero-size input and output materialises to 0x0."""
    leaf = jnp.array([], dtype=dtype).reshape(2, 1, 0, 2, 1, 0)
    out_struct = jax.ShapeDtypeStruct((2, 1, 0), leaf.dtype)
    operator = lx.PyTreeLinearOperator(leaf, out_struct)
    assert operator.as_matrix().shape == (0, 0)
def test_jacrev_operator():
    """``JacobianLinearOperator(jac="bwd")`` must respect a function's custom VJP.

    ``f`` multiplies by 2 in its forward pass but registers a backward rule that
    multiplies by 3. The reverse-mode operator must therefore behave as ``3 * I``,
    both directly and after ``lx.materialise``. With ``jac="fwd"`` the same function
    cannot be differentiated in forward mode, so both ``mv`` and ``lx.materialise``
    are expected to raise ``TypeError``.
    """
    # Test that custom_vjp is respected. The custom backward multiplies by 3
    # instead of the true derivative (which would be 2).
    # This tests that lineax uses the custom_vjp, not the true derivative.
    @jax.custom_vjp
    def f(x, _):
        return dict(foo=x["bar"] * 2)  # forward: multiply by 2

    def f_fwd(x, _):
        # Returns (primal output, residuals); no residuals are needed here.
        return f(x, None), None

    def f_bwd(_, g):
        # Custom backward: multiply by 3 (not the true derivative 2)
        # This must be linear in g for linear_transpose to work correctly.
        # One cotangent per primal input: the dict for ``x``, None for the second.
        return dict(bar=g["foo"] * 3), None

    f.defvjp(f_fwd, f_bwd)
    x = dict(bar=jnp.arange(2.0))
    rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
    # Jacobian is 3*I (from custom backward, not 2*I from true derivative)
    as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])
    assert tree_allclose(rev_op.as_matrix(), as_matrix)
    y = dict(bar=jnp.arange(2.0) + 1)  # y = [1, 2]
    true_out = dict(foo=jnp.array([3.0, 6.0]))  # 3*I @ [1, 2] = [3, 6]
    # The matvec must agree both lazily and after materialisation.
    for op in (rev_op, lx.materialise(rev_op)):
        out = op.mv(y)
        assert tree_allclose(out, true_out)
    # Forward mode cannot go through a custom_vjp function.
    fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd")
    with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
        fwd_op.mv(y)
    with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
        lx.materialise(fwd_op)
================================================
FILE: tests/test_singular.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_singular_matrix,
finite_difference_jvp,
make_jac_operator,
make_matrix_operator,
ops,
params,
tol,
tree_allclose,
)
@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True))
@pytest.mark.parametrize("ops", ops)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_small_singular(make_operator, solver, tags, ops, getkey, dtype):
    """Pseudoinverse solvers agree with ``jnp.linalg.lstsq`` on small singular systems."""
    # Named ``tolerance`` so it does not shadow the module-level ``tol`` import.
    tolerance = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore
    (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype)
    operator = make_operator(getkey, matrix, tags)
    operator, matrix = ops(operator, matrix)
    assert tree_allclose(operator.as_matrix(), matrix, rtol=tolerance, atol=tolerance)
    _, in_size = matrix.shape
    true_x = jr.normal(getkey(), (in_size,), dtype=dtype)
    b = matrix @ true_x
    x = lx.linear_solve(operator, b, solver=solver, throw=False).value
    jax_x = jnp.linalg.lstsq(matrix, b)[0]  # pyright: ignore
    assert tree_allclose(x, jax_x, atol=tolerance, rtol=tolerance)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_bicgstab_breakdown(getkey, dtype):
    """On a fixed random 100x100 system, the tight-tolerance solve must report failure.

    NOTE(review): despite the test name, the solver used is ``lx.GMRES``
    (restart=2), not ``lx.BiCGStab``. Confirm which solver was intended before
    changing it.
    """
    tolerance = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore
    solver = lx.GMRES(atol=tolerance, rtol=tolerance, restart=2)
    # A fixed key (not ``getkey``) is deliberately reused for both draws, so the
    # system is identical on every run.
    key = jr.PRNGKey(0)
    matrix = jr.normal(key, (100, 100), dtype=dtype)
    true_x = jr.normal(key, (100,), dtype=dtype)
    b = matrix @ true_x
    operator = lx.MatrixLinearOperator(matrix)
    # result != 0 implies lineax reported failure
    solution = lx.linear_solve(operator, b, solver, throw=False)
    assert jnp.all(solution.result != lx.RESULTS.successful)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_gmres_stagnation_or_breakdown(getkey, dtype):
    """GMRES(restart=2) at tight tolerance must report failure on this fixed 4x4 system."""
    tolerance = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore
    solver = lx.GMRES(atol=tolerance, rtol=tolerance, restart=2)
    rows = [
        [0.15892892, 0.05884365, -0.60427412, 0.1891916],
        [-1.5484863, 0.93608822, 1.94888868, 1.37069667],
        [0.62687318, -0.13996738, -0.6824359, 0.30975754],
        [-0.67428635, 1.52372255, -0.88277754, 0.69633816],
    ]
    matrix = jnp.array(rows, dtype=dtype)
    true_x = jnp.array(
        [0.51383273, 1.72983427, -0.43251078, -1.11764668], dtype=dtype
    )
    b = matrix @ true_x
    operator = lx.MatrixLinearOperator(matrix)
    # result != 0 implies lineax reported failure
    solution = lx.linear_solve(operator, b, solver, throw=False)
    assert jnp.all(solution.result != lx.RESULTS.successful)
@pytest.mark.parametrize(
    "solver",
    (
        lx.AutoLinearSolver(well_posed=None),
        lx.QR(),
        lx.SVD(),
        lx.LSMR(atol=tol, rtol=tol),
        lx.Normal(lx.Cholesky()),
        lx.Normal(lx.SVD()),
    ),
)
def test_nonsquare_pytree_operator1(solver):
    """Solve a wide 2x3 least-squares system given as a nested-list pytree operator."""
    x = [[1, 5.0, jnp.array(-1.0)], [jnp.array(-2), jnp.array(-2.0), 3.0]]
    y = [3.0, 4]
    operator = lx.PyTreeLinearOperator(x, jax.eval_shape(lambda: y))
    out = lx.linear_solve(operator, y, solver=solver).value
    # Dense equivalent of ``x``, solved with JAX as the reference.
    matrix = jnp.array([[1.0, 5.0, -1.0], [-2.0, -2.0, 3.0]])
    reference = jnp.linalg.lstsq(matrix, jnp.array(y))[0]  # pyright: ignore
    true_out = [reference[i] for i in range(3)]
    assert tree_allclose(out, true_out)
@pytest.mark.parametrize(
    "solver",
    (
        lx.AutoLinearSolver(well_posed=None),
        lx.QR(),
        lx.SVD(),
        lx.LSMR(atol=tol, rtol=tol),
        lx.Normal(lx.Cholesky()),
        lx.Normal(lx.SVD()),
    ),
)
def test_nonsquare_pytree_operator2(solver):
    """A tall (3x2) PyTree operator is solved like jnp.linalg.lstsq would."""
    pytree = [[1, jnp.array(-2)], [5.0, jnp.array(-2.0)], [jnp.array(-1.0), 3.0]]
    rhs = [3.0, 4, 5.0]
    operator = lx.PyTreeLinearOperator(pytree, jax.eval_shape(lambda: rhs))
    out = lx.linear_solve(operator, rhs, solver=solver).value

    # Dense equivalent of `pytree`, used to build the reference solution.
    dense = jnp.array([[1.0, -2.0], [5.0, -2.0], [-1.0, 3.0]])
    reference = jnp.linalg.lstsq(dense, jnp.array(rhs))[0]  # pyright: ignore
    # The operator's input structure is two scalars, so split the reference.
    assert tree_allclose(out, list(reference))
@pytest.mark.parametrize(
    "solver",
    (
        lx.AutoLinearSolver(well_posed=None),
        lx.QR(),
        lx.SVD(),
        lx.Normal(lx.Cholesky()),
        lx.Normal(lx.SVD()),
    ),
)
@pytest.mark.parametrize("full_rank", (True, False))
@pytest.mark.parametrize("jvp", (False, True))
@pytest.mark.parametrize("wide", (False, True))
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_nonsquare_mat_vec(solver, full_rank, jvp, wide, dtype, getkey):
    """Compare lineax with jnp.linalg.lstsq on nonsquare matrices.

    Covers wide (3x6) and tall (6x3) matrices, full-rank and rank-2 matrices,
    and optionally forward-mode derivatives with respect to both the matrix and
    the vector. In that case the reference derivative comes from finite
    differences.
    """
    out_size, in_size = (3, 6) if wide else (6, 3)
    matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)
    if not full_rank:
        if solver.assume_full_rank():
            # A solver that assumes full rank has nothing to check here.
            return
        # Zeroing the lower-right block leaves a nontrivial rank-2 pattern.
        matrix = matrix.at[1:, 1:].set(0)
    vector = jr.normal(getkey(), (out_size,), dtype=dtype)

    def lx_solve(mat, vec):
        return lx.linear_solve(lx.MatrixLinearOperator(mat), vec, solver).value

    def jnp_solve(mat, vec):
        return jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore

    if jvp:
        t_matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)
        if not full_rank:
            # The tangent must lie in the tangent space of rank-2 matrices at
            # `matrix`; imposing the same sparsity pattern guarantees that.
            t_matrix = t_matrix.at[1:, 1:].set(0)
        t_vector = jr.normal(getkey(), (out_size,), dtype=dtype)
        args = ((matrix, vector), (t_matrix, t_vector))
        lx_fn = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve))
        jnp_fn = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve))
    else:
        args = (matrix, vector)
        lx_fn, jnp_fn = lx_solve, jnp_solve

    x = lx_fn(*args)  # pyright: ignore
    true_x = jnp_fn(*args)
    assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4)
@pytest.mark.parametrize(
    "solver",
    (
        lx.AutoLinearSolver(well_posed=None),
        lx.QR(),
        lx.SVD(),
        lx.Normal(lx.Cholesky()),
        lx.Normal(lx.SVD()),
    ),
)
@pytest.mark.parametrize("full_rank", (True, False))
@pytest.mark.parametrize("jvp", (False, True))
@pytest.mark.parametrize("wide", (False, True))
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_nonsquare_vec(solver, full_rank, jvp, wide, dtype, getkey):
    """Compare lineax with jnp.linalg.lstsq for a fixed nonsquare matrix.

    Unlike `test_nonsquare_mat_vec`, the matrix is closed over, so only the
    right-hand side is an input. In JVP mode, only the vector carries a tangent.
    """
    out_size, in_size = (3, 6) if wide else (6, 3)
    matrix = jr.normal(getkey(), (out_size, in_size), dtype=dtype)
    if not full_rank:
        if solver.assume_full_rank():
            # A solver that assumes full rank has nothing to check here.
            return
        # Zeroing the lower-right block leaves a nontrivial rank-2 pattern.
        matrix = matrix.at[1:, 1:].set(0)
    vector = jr.normal(getkey(), (out_size,), dtype=dtype)

    def lx_solve(vec):
        return lx.linear_solve(lx.MatrixLinearOperator(matrix), vec, solver).value

    def jnp_solve(vec):
        return jnp.linalg.lstsq(matrix, vec)[0]  # pyright: ignore

    if jvp:
        t_vector = jr.normal(getkey(), (out_size,), dtype=dtype)
        args = ((vector,), (t_vector,))
        lx_fn = eqx.filter_jit(ft.partial(eqx.filter_jvp, lx_solve))
        jnp_fn = eqx.filter_jit(ft.partial(finite_difference_jvp, jnp_solve))
    else:
        args = (vector,)
        lx_fn, jnp_fn = lx_solve, jnp_solve

    x = lx_fn(*args)  # pyright: ignore
    true_x = jnp_fn(*args)
    assert tree_allclose(x, true_x, atol=1e-4, rtol=1e-4)
# (solver, tags) pairs used to parametrise `test_iterative_singular`. Each tag
# value is passed to the singular-matrix and operator constructors alongside
# its solver. Entry order determines the generated pytest IDs.
_iterative_solvers = (
    (lx.CG(rtol=tol, atol=tol), lx.positive_semidefinite_tag),
    (lx.CG(rtol=tol, atol=tol, max_steps=512), lx.negative_semidefinite_tag),
    (lx.GMRES(rtol=tol, atol=tol), ()),
    (lx.BiCGStab(rtol=tol, atol=tol), ()),
)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags", _iterative_solvers)
@pytest.mark.parametrize("use_state", (False, True))
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_iterative_singular(getkey, solver, tags, use_state, make_operator, dtype):
    """Iterative solvers must raise when asked to solve a singular system.

    The singular matrix is built with the parametrised ``dtype``, the same dtype
    as the right-hand side. This makes the ``complex128`` case exercise a
    complex singular system rather than a real matrix paired with a complex
    vector. When ``use_state`` is set, the solver state is computed up front
    with ``solver.init`` and passed in via ``state=``.
    """
    (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype)
    operator = make_operator(getkey, matrix, tags)

    out_size, _ = matrix.shape
    vec = jr.normal(getkey(), (out_size,), dtype=dtype)

    if use_state:
        state = solver.init(operator, options={})
        linear_solve = ft.partial(lx.linear_solve, state=state)
    else:
        linear_solve = lx.linear_solve

    # NOTE(review): `Exception` is deliberately broad; the specific error type
    # raised on failure is not visible from this file.
    with pytest.raises(Exception):
        linear_solve(operator, vec, solver)
================================================
FILE: tests/test_solve.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import construct_poisson_matrix, tree_allclose
def test_gmres_large_dense(getkey):
    """GMRES recovers the solution of a random dense 100x100 system.

    The restart length equals the system size, so GMRES can build a full
    Krylov basis before it would restart.
    """
    tol = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore
    gmres = lx.GMRES(atol=tol, rtol=tol, restart=100)

    dense = jr.normal(getkey(), (100, 100))
    expected = jr.normal(getkey(), (100,))
    solution = lx.linear_solve(
        lx.MatrixLinearOperator(dense), dense @ expected, gmres
    )
    assert tree_allclose(solution.value, expected, atol=tol, rtol=tol)
def test_nontrivial_pytree_operator():
    """Solve a 2x2 system whose operator is a nested list of mixed leaf types."""
    rhs = [3, 4]
    operator = lx.PyTreeLinearOperator(
        [[1, 5.0], [jnp.array(-2), jnp.array(-2.0)]],
        jax.eval_shape(lambda: rhs),
    )
    solution = lx.linear_solve(operator, rhs).value
    # x0 + 5 x1 = 3 and -2 x0 - 2 x1 = 4  =>  x = (-3.25, 1.25).
    assert tree_allclose(solution, [jnp.array(-3.25), jnp.array(1.25)])
def test_nontrivial_diagonal_operator():
    """A diagonal operator over a nested pytree divides the vector leafwise."""
    diagonal = (8.0, jnp.array([1, 2, 3]), {"a": jnp.array([4, 5]), "b": 6})
    vector = (4.0, jnp.array([7, 8, 9]), {"a": jnp.array([2, 10]), "b": 12})
    expected = (
        jnp.array(0.5),
        jnp.array([7.0, 4.0, 3.0]),
        {"a": jnp.array([0.5, 2.0]), "b": jnp.array(2.0)},
    )
    solution = lx.linear_solve(lx.DiagonalLinearOperator(diagonal), vector)
    assert tree_allclose(solution.value, expected)
@pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD()))
def test_mixed_dtypes(solver):
    """Dense solvers handle a PyTree operator whose leaves mix float32/float64."""

    def as_f32(value):
        return jnp.array(value, dtype=jnp.float32)

    def as_f64(value):
        return jnp.array(value, dtype=jnp.float64)

    rhs = [as_f64(3), as_f64(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_f32(1), as_f64(5)], [as_f32(-2), as_f64(-2)]],
        jax.eval_shape(lambda: rhs),
    )
    solution = lx.linear_solve(operator, rhs, solver=solver).value
    assert tree_allclose(solution, [as_f32(-3.25), as_f64(1.25)])
@pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD()))
def test_mixed_dtypes_complex(solver):
    """Dense solvers handle a PyTree operator mixing complex64/complex128 leaves."""

    def as_c64(value):
        return jnp.array(value, dtype=jnp.complex64)

    def as_c128(value):
        return jnp.array(value, dtype=jnp.complex128)

    rhs = [as_c128(3), as_c128(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_c64(1), as_c128(5.0j)], [as_c64(2.0j), as_c128(-2)]],
        jax.eval_shape(lambda: rhs),
    )
    solution = lx.linear_solve(operator, rhs, solver=solver).value
    assert tree_allclose(solution, [as_c64(-0.75 - 2.5j), as_c128(0.5 - 0.75j)])
@pytest.mark.parametrize("solver", (lx.LU(), lx.QR(), lx.SVD()))
def test_mixed_dtypes_complex_real(solver):
    """Dense solvers handle a PyTree operator mixing float64/complex128 leaves."""

    def as_f64(value):
        return jnp.array(value, dtype=jnp.float64)

    def as_c128(value):
        return jnp.array(value, dtype=jnp.complex128)

    rhs = [as_c128(3), as_c128(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_f64(1), as_c128(-5.0j)], [as_f64(2.0), as_c128(-2j)]],
        jax.eval_shape(lambda: rhs),
    )
    solution = lx.linear_solve(operator, rhs, solver=solver).value
    assert tree_allclose(solution, [as_f64(1.75), as_c128(0.25j)])
def test_mixed_dtypes_triangular():
    """Triangular solve of a lower-triangular operator with float32/float64 leaves."""

    def as_f32(value):
        return jnp.array(value, dtype=jnp.float32)

    def as_f64(value):
        return jnp.array(value, dtype=jnp.float64)

    rhs = [as_f64(3), as_f64(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_f32(1), as_f64(0)], [as_f32(-2), as_f64(-2)]],
        jax.eval_shape(lambda: rhs),
        lx.lower_triangular_tag,
    )
    solution = lx.linear_solve(operator, rhs, solver=lx.Triangular()).value
    # Forward substitution: x0 = 3, then -2*3 - 2 x1 = 4  =>  x1 = -5.
    assert tree_allclose(solution, [as_f32(3), as_f64(-5)])
def test_mixed_dtypes_complex_triangular():
    """Triangular solve of a lower-triangular operator with complex64/complex128 leaves."""

    def as_c64(value):
        return jnp.array(value, dtype=jnp.complex64)

    def as_c128(value):
        return jnp.array(value, dtype=jnp.complex128)

    rhs = [as_c128(3), as_c128(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_c64(1), as_c128(0)], [as_c64(2.0j), as_c128(-2)]],
        jax.eval_shape(lambda: rhs),
        lx.lower_triangular_tag,
    )
    solution = lx.linear_solve(operator, rhs, solver=lx.Triangular()).value
    # Forward substitution: x0 = 3, then 6j - 2 x1 = 4  =>  x1 = -2 + 3j.
    assert tree_allclose(solution, [as_c64(3), as_c128(-2 + 3.0j)])
def test_mixed_dtypes_complex_real_triangular():
    """Triangular solve of a lower-triangular operator with float64/complex128 leaves."""

    def as_f64(value):
        return jnp.array(value, dtype=jnp.float64)

    def as_c128(value):
        return jnp.array(value, dtype=jnp.complex128)

    rhs = [as_c128(3), as_c128(4)]
    operator = lx.PyTreeLinearOperator(
        [[as_f64(1), as_c128(0)], [as_f64(2.0), as_c128(2j)]],
        jax.eval_shape(lambda: rhs),
        lx.lower_triangular_tag,
    )
    solution = lx.linear_solve(operator, rhs, solver=lx.Triangular()).value
    # Forward substitution: x0 = 3, then 6 + 2j x1 = 4  =>  x1 = 1j.
    assert tree_allclose(solution, [as_f64(3), as_c128(1j)])
def test_ad_closure_function_linear_operator(getkey):
    """Gradients flow through a FunctionLinearOperator that closes over `x`.

    The operator is ``y -> x * y``, so the solution is ``z / x``. The gradient
    of ``sum(z / x)`` with respect to ``x`` is ``-z / x**2``.
    """

    def loss(scale, rhs):
        operator = lx.FunctionLinearOperator(
            lambda y: scale * y, jax.eval_shape(lambda: rhs)
        )
        solution = lx.linear_solve(operator, rhs).value
        return jnp.sum(solution), solution

    scale = jr.normal(getkey(), (3,))
    # Keep the entries away from zero so that z / x stays finite.
    scale = jnp.where(jnp.abs(scale) < 1e-6, 0.7, scale)
    rhs = jr.normal(getkey(), (3,))

    grad, solution = jax.grad(loss, has_aux=True)(scale, rhs)
    assert tree_allclose(grad, -rhs / (scale**2))
    assert tree_allclose(solution, rhs / scale)
def test_grad_vmap_symbolic_cotangent():
    """Smoke test: grad of a vmapped solve where only one output is used.

    Only the first component of the solution feeds the loss, so the second
    component presumably receives a symbolic-zero cotangent. The test only
    checks that this does not crash.
    """

    def identity(pair):
        return pair[0], pair[1]

    def solve_first(pair):
        operator = lx.FunctionLinearOperator(identity, jax.eval_shape(lambda: pair))
        return lx.linear_solve(operator, pair).value[0]

    def loss(pair):
        return jnp.sum(jax.vmap(solve_first)(pair))

    jax.grad(loss)((jnp.arange(3.0), jnp.arange(3.0)))
@pytest.mark.parametrize(
    "solver",
    (
        lx.CG(0.0, 0.0, max_steps=2),
        lx.Normal(lx.CG(0.0, 0.0, max_steps=2)),
        lx.BiCGStab(0.0, 0.0, max_steps=2),
        lx.GMRES(0.0, 0.0, max_steps=2),
        lx.LSMR(0.0, 0.0, max_steps=2),
    ),
)
def test_iterative_solver_max_steps_only(solver):
    """Iterative solvers configured only via `max_steps` run without Equinox errors.

    Both tolerances are zero, so `max_steps` is the only stopping setting.
    """
    size = 100
    operator = lx.MatrixLinearOperator(
        construct_poisson_matrix(size),
        tags=(lx.negative_semidefinite_tag, lx.symmetric_tag),
    )
    rhs = jax.random.normal(jax.random.key(0), (size,))
    lx.linear_solve(operator, rhs, solver)
def test_solver_init_not_differentiated(getkey):
    """stop_gradient should be applied before solver.init, not after.

    Also checks that dynamic arrays in options don't cause issues.

    A wrapper solver routes ``init`` through a ``jax.custom_jvp`` whose JVP
    rule raises ``NotImplementedError``. Any attempt to differentiate through
    ``init`` therefore fails the test. Both ``jax.jvp`` and ``jax.vjp`` of a
    solve are then taken with respect to the matrix.
    """

    class DisallowGradWrapper(lx._solve.AbstractLinearSolver):
        # Forwards every method to `solver`; only `init` is instrumented.
        solver: lx._solve.AbstractLinearSolver

        def init(self, operator, options):
            # `dummy` is routed through the custom_jvp as an explicit argument,
            # so a dynamic array in `options` is an input of this function.
            @jax.custom_jvp
            def f(operator, dummy):
                del dummy
                return self.solver.init(operator, options)

            # Any JVP through `f` (i.e. through `init`) is an error.
            @f.defjvp
            def _(*args):
                raise NotImplementedError("solver.init should not be differentiated")

            return f(operator, options.get("dummy"))

        def compute(self, state, vector, options):
            return self.solver.compute(state, vector, options)

        def transpose(self, state, options):
            return self.solver.transpose(state, options)

        def conj(self, state, options):
            return self.solver.conj(state, options)

        def assume_full_rank(self):
            return self.solver.assume_full_rank()

    m = jax.random.normal(getkey(), (3, 3))
    mt = jax.random.normal(getkey(), (3, 3))
    v = jax.random.normal(getkey(), (3,))
    dummy = jnp.array(1.0)

    def f(m):
        op = lx.MatrixLinearOperator(m)
        return lx.linear_solve(
            op, v, solver=DisallowGradWrapper(lx.QR()), options={"dummy": dummy}
        ).value

    # Differentiating through operator only, but options has a dynamic array.
    # solver.init should not be differentiated through.
    jax.jvp(f, (m,), (mt,))
    # Reverse mode too; `v` doubles as a cotangent of the solution's shape (3,).
    _, f_vjp = jax.vjp(f, m)
    f_vjp(v)
def test_nonfinite_input():
    """A vector containing inf and/or nan is reported as `nonfinite_input`."""
    operator = lx.DiagonalLinearOperator((1.0, 1.0))
    for vector in ((1.0, jnp.inf), (1.0, jnp.nan), (jnp.nan, jnp.inf)):
        solution = lx.linear_solve(operator, vector, throw=False)
        assert solution.result == lx.RESULTS.nonfinite_input
================================================
FILE: tests/test_transpose.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import construct_matrix, params, tree_allclose
class TestTranspose:
    """Check that transposing a linear solve matches solving with `operator.T`."""

    @pytest.fixture(scope="class")
    def assert_transpose_fixture(_):
        """Return a helper that asserts transpose(solve) == solve with `operator.T`.

        The first argument is the test-class instance, named `_` because it is
        unused.
        """

        @eqx.filter_jit
        def solve_transpose(operator, out_vec, in_vec, solver):
            # Linear transpose of `v -> solve(operator, v)`. `out_vec` is the
            # example primal input that fixes the transposed map's
            # structure; the transposed map is then applied to `in_vec`.
            return jax.linear_transpose(
                lambda v: lx.linear_solve(operator, v, solver).value, out_vec
            )(in_vec)

        def assert_transpose(operator, out_vec, in_vec, solver):
            # `jax.linear_transpose` returns one entry per primal input; there is one.
            (out,) = solve_transpose(operator, out_vec, in_vec, solver)
            true_out = lx.linear_solve(operator.T, in_vec, solver).value
            assert tree_allclose(out, true_out)

        return assert_transpose

    @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
    @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
    def test_transpose(
        _, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey
    ):
        (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
        operator = make_operator(getkey, matrix, tags)
        out_size, in_size = matrix.shape
        out_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
        in_vec = jr.normal(getkey(), (in_size,), dtype=dtype)
        # NOTE(review): the parametrised `solver` is only used to construct the
        # matrix; the solve itself always uses AutoLinearSolver(well_posed=True).
        # Confirm this is intentional.
        solver = lx.AutoLinearSolver(well_posed=True)
        assert_transpose_fixture(operator, out_vec, in_vec, solver)

    def test_pytree_transpose(_, assert_transpose_fixture):  # pyright: ignore
        # A 2x3 (wide) PyTree operator, so a least-squares solver is required.
        a = jnp.array
        pytree = [[a(1), a(2), a(3)], [a(4), a(5), a(6)]]
        output_structure = jax.eval_shape(lambda: [1, 2])
        operator = lx.PyTreeLinearOperator(pytree, output_structure)
        out_vec = [a(1.0), a(2.0)]
        in_vec = [a(1.0), 2.0, 3.0]
        solver = lx.AutoLinearSolver(well_posed=False)
        assert_transpose_fixture(operator, out_vec, in_vec, solver)
================================================
FILE: tests/test_vmap.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
tree_allclose,
)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_vmap(
    getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
    """Compare vmapped `lx.linear_solve` with vmapped `jnp.linalg.lstsq`.

    Three batching patterns are tried: batch only the vector, only the
    operator, or both (batch size 10).
    """
    # Singular matrices are only tested when `pseudoinverse` is set (presumably
    # marking solvers that can return a least-squares solution -- confirm in
    # helpers). Otherwise this parametrisation is a no-op.
    if (make_matrix is construct_matrix) or pseudoinverse:

        # Build the operator for one batch element and solve. With `use_state`,
        # the solver state is precomputed and passed in via `state=`.
        def wrap_solve(matrix, vector):
            operator = make_operator(getkey, matrix, tags)
            if use_state:
                state = solver.init(operator, options={})
                return lx.linear_solve(operator, vector, solver, state=state).value
            else:
                return lx.linear_solve(operator, vector, solver).value

        # `None` means unbatched; `eqx.if_array(0)` batches array leaves along axis 0.
        for op_axis, vec_axis in (
            (None, 0),
            (eqx.if_array(0), None),
            (eqx.if_array(0), 0),
        ):
            if op_axis is None:
                axis_size = None
                out_axes = None
            else:
                axis_size = 10
                out_axes = eqx.if_array(0)

            # A batch of 10 matrices when the operator is batched, else a single one.
            (matrix,) = eqx.filter_vmap(
                lambda getkey, solver, tags: make_matrix(
                    getkey, solver, tags, dtype=dtype
                ),
                axis_size=axis_size,
                out_axes=out_axes,
            )(getkey, solver, tags)
            out_dim = matrix.shape[-2]

            if vec_axis is None:
                vec = jr.normal(getkey(), (out_dim,), dtype=dtype)
            else:
                vec = jr.normal(getkey(), (10, out_dim), dtype=dtype)

            # Reference: lstsq is used for both regular and singular matrices.
            jax_result, _, _, _ = eqx.filter_vmap(
                jnp.linalg.lstsq,
                in_axes=(op_axis, vec_axis),  # pyright: ignore
            )(matrix, vec)
            lx_result = eqx.filter_vmap(wrap_solve, in_axes=(op_axis, vec_axis))(
                matrix, vec
            )
            assert tree_allclose(lx_result, jax_result)
# https://github.com/patrick-kidger/lineax/issues/101
def test_grad_vmap_basic(getkey):
    """Regression test: jit(grad(...)) of a vmapped least-squares solve runs."""
    A = jr.normal(getkey(), (16, 8))
    B = jr.normal(getkey(), (128, 16))

    def loss(A):
        operator = lx.MatrixLinearOperator(A)
        solver = lx.AutoLinearSolver(well_posed=False)

        def solve_one(b):
            return lx.linear_solve(operator, b, solver).value

        return jax.vmap(solve_one)(B).mean()

    jax.jit(jax.grad(loss))(A)
def test_grad_vmap_advanced(getkey):
    """Harder variant of `test_grad_vmap_basic`.

    The batch axes and the undefined primals do not necessarily line up in the
    same arguments. The second operator leaf is batched along axis 2, while
    the first right-hand-side leaf is batched along axis 1.
    """
    operator_leaves = jr.normal(getkey(), (2, 8)), jr.normal(getkey(), (3, 8, 128))
    rhs_leaves = jr.normal(getkey(), (2, 128)), jr.normal(getkey(), (3,))
    out_struct = (
        jax.ShapeDtypeStruct((2,), jnp.float64),
        jax.ShapeDtypeStruct((3,), jnp.float64),
    )

    def solve(A, B):
        operator = lx.PyTreeLinearOperator(A, out_struct)
        return lx.linear_solve(operator, B, lx.AutoLinearSolver(well_posed=False)).value

    batched_solve = jax.vmap(solve, in_axes=((None, 2), (1, None)))

    def loss(A):
        return batched_solve(A, rhs_leaves).mean()

    jax.jit(jax.grad(loss))(operator_leaves)
================================================
FILE: tests/test_vmap_jvp.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
tree_allclose,
)
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_vmap_jvp(
    getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype
):
    """Check that vmap and JVP compose correctly with `lx.linear_solve`.

    The reference is `jnp.linalg.lstsq` for pseudoinverse solvers and
    `jnp.linalg.solve` otherwise. Three batching modes are covered ("vec",
    "op", "op_vec"). For each, both orders are tried: vmap of jvp and jvp of
    vmap.
    """
    # Singular matrices are only tested when `pseudoinverse` is set; otherwise
    # this parametrisation is a no-op.
    if (make_matrix is construct_matrix) or pseudoinverse:
        # Tags carry no tangent: one `None` per tag when `tags` is a tuple.
        t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
        if pseudoinverse:
            jnp_solve1 = lambda mat, vec: jnp.linalg.lstsq(mat, vec)[0]  # pyright: ignore
        else:
            jnp_solve1 = jnp.linalg.solve  # pyright: ignore
        if use_state:
            # The state is computed from a stop-gradient copy of the operator's
            # inexact-array leaves, so `solver.init` sees no tangents. The
            # original operator is still passed to `linear_solve`.
            def linear_solve1(operator, vector):
                op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)
                stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)
                state = solver.init(stopped_operator, options={})
                return lx.linear_solve(operator, vector, state=state, solver=solver)

        else:
            linear_solve1 = ft.partial(lx.linear_solve, solver=solver)

        for mode in ("vec", "op", "op_vec"):
            # The operator (and matrix) gets a batch of 10 only when "op" is batched.
            if "op" in mode:
                axis_size = 10
                out_axes = eqx.if_array(0)
            else:
                axis_size = None
                out_axes = None

            # Build a matrix and its tangent, then push both through the
            # operator constructor with a JVP to get the operator and its tangent.
            def _make():
                matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)
                make_op = ft.partial(make_operator, getkey)
                operator, t_operator = eqx.filter_jvp(
                    make_op, (matrix, tags), (t_matrix, t_tags)
                )
                return matrix, t_matrix, operator, t_operator

            matrix, t_matrix, operator, t_operator = eqx.filter_vmap(
                _make, axis_size=axis_size, out_axes=out_axes
            )()

            # A batched matrix has a leading batch dimension to skip.
            if "op" in mode:
                _, out_size, _ = matrix.shape
            else:
                out_size, _ = matrix.shape

            if "vec" in mode:
                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
                t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
            else:
                vec = jr.normal(getkey(), (out_size,), dtype=dtype)
                t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

            # Close over whichever argument is not batched in this mode.
            if mode == "op":
                linear_solve2 = lambda op: linear_solve1(op, vector=vec)
                jnp_solve2 = lambda mat: jnp_solve1(mat, vec)
            elif mode == "vec":
                linear_solve2 = lambda vector: linear_solve1(operator, vector)
                jnp_solve2 = lambda vector: jnp_solve1(matrix, vector)
            elif mode == "op_vec":
                linear_solve2 = linear_solve1
                jnp_solve2 = jnp_solve1
            else:
                assert False
            # jvp_first=True gives vmap(jvp(solve)); False gives jvp(vmap(solve)).
            # The jnp reference is always vmap(jvp(solve)).
            for jvp_first in (True, False):
                if jvp_first:
                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve2)
                else:
                    linear_solve3 = linear_solve2
                linear_solve3 = eqx.filter_vmap(linear_solve3)
                if not jvp_first:
                    linear_solve3 = ft.partial(eqx.filter_jvp, linear_solve3)
                linear_solve3 = eqx.filter_jit(linear_solve3)
                jnp_solve3 = ft.partial(eqx.filter_jvp, jnp_solve2)
                jnp_solve3 = eqx.filter_vmap(jnp_solve3)
                jnp_solve3 = eqx.filter_jit(jnp_solve3)
                if mode == "op":
                    out, t_out = linear_solve3((operator,), (t_operator,))
                    true_out, true_t_out = jnp_solve3((matrix,), (t_matrix,))
                elif mode == "vec":
                    out, t_out = linear_solve3((vec,), (t_vec,))
                    true_out, true_t_out = jnp_solve3((vec,), (t_vec,))
                elif mode == "op_vec":
                    out, t_out = linear_solve3((operator, vec), (t_operator, t_vec))
                    true_out, true_t_out = jnp_solve3((matrix, vec), (t_matrix, t_vec))
                else:
                    assert False
                # Compare both primal outputs and tangents.
                assert tree_allclose(out.value, true_out, atol=1e-4)
                assert tree_allclose(t_out.value, true_t_out, atol=1e-4)
================================================
FILE: tests/test_vmap_vmap.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_matrix,
construct_singular_matrix,
make_jac_operator,
make_matrix_operator,
solvers_tags_pseudoinverse,
tree_allclose,
)
@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator))
@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
    "make_matrix",
    (
        construct_matrix,
        construct_singular_matrix,
    ),
)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_vmap_vmap(
    getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
    """Compare doubly-vmapped `lx.linear_solve` with doubly-vmapped jnp solves.

    `vmap1` is the inner vmap and `vmap2` the outer one. Each axis combination
    chooses which of the operator and the vector are batched (size 10) by each
    vmap. The reference is `jnp.linalg.lstsq` for singular matrices and
    `jnp.linalg.solve` otherwise.
    """
    # Singular matrices are only tested when `pseudoinverse` is set; otherwise
    # this parametrisation is a no-op.
    if (make_matrix is construct_matrix) or pseudoinverse:
        # combinations with nontrivial application across both vmaps
        # Tuple order: (vmap2_op, vmap1_op, vmap2_vec, vmap1_vec).
        axes = [
            (eqx.if_array(0), eqx.if_array(0), None, None),
            (None, None, 0, 0),
            (eqx.if_array(0), eqx.if_array(0), None, 0),
            (eqx.if_array(0), eqx.if_array(0), 0, 0),
            (None, eqx.if_array(0), 0, 0),
        ]
        for vmap2_op, vmap1_op, vmap2_vec, vmap1_vec in axes:
            if vmap1_op is not None:
                axis_size1 = 10
                out_axis1 = eqx.if_array(0)
            else:
                axis_size1 = None
                out_axis1 = None

            if vmap2_op is not None:
                axis_size2 = 10
                out_axis2 = eqx.if_array(0)
            else:
                axis_size2 = None
                out_axis2 = None

            # Matrices get one batch dimension per vmap that batches the operator.
            (matrix,) = eqx.filter_vmap(
                eqx.filter_vmap(
                    lambda getkey, solver, tags: make_matrix(
                        getkey, solver, tags, dtype=dtype
                    ),
                    axis_size=axis_size1,
                    out_axes=out_axis1,
                ),
                axis_size=axis_size2,
                out_axes=out_axis2,
            )(getkey, solver, tags)

            # Skip the leading batch dimensions to recover the output size.
            if vmap1_op is not None:
                if vmap2_op is not None:
                    _, _, out_size, _ = matrix.shape
                else:
                    _, out_size, _ = matrix.shape
            else:
                out_size, _ = matrix.shape

            # The vector gets one batch dimension per vmap that batches it.
            # (In the `elif`, the first condition is always true, since the
            # `if` already handled `vmap1_vec is None`.)
            if vmap1_vec is None:
                vec = jr.normal(getkey(), (out_size,), dtype=dtype)
            elif (vmap1_vec is not None) and (vmap2_vec is None):
                vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
            else:
                vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype)

            make_op = ft.partial(make_operator, getkey)
            operator = eqx.filter_vmap(
                eqx.filter_vmap(
                    make_op,
                    in_axes=vmap1_op,
                    out_axes=out_axis1,
                ),
                in_axes=vmap2_op,
                out_axes=out_axis2,
            )(matrix, tags)

            if use_state:

                def linear_solve(operator, vector):
                    state = solver.init(operator, options={})
                    return lx.linear_solve(operator, vector, state=state, solver=solver)

            else:

                def linear_solve(operator, vector):
                    return lx.linear_solve(operator, vector, solver)

            # Materialise each operator, with the same batching, for the jnp reference.
            as_matrix_vmapped = eqx.filter_vmap(
                eqx.filter_vmap(
                    lambda x: x.as_matrix(),
                    in_axes=vmap1_op,
                    out_axes=None if vmap1_op is None else 0,
                ),
                in_axes=vmap2_op,
                out_axes=None if vmap2_op is None else 0,
            )(operator)

            vmap1_axes = (vmap1_op, vmap1_vec)
            vmap2_axes = (vmap2_op, vmap2_vec)

            result = eqx.filter_vmap(
                eqx.filter_vmap(linear_solve, in_axes=vmap1_axes), in_axes=vmap2_axes
            )(operator, vec).value

            # Apply a jnp solver under the same nested vmaps.
            solve_with = lambda x: eqx.filter_vmap(
                eqx.filter_vmap(x, in_axes=vmap1_axes), in_axes=vmap2_axes
            )(as_matrix_vmapped, vec)

            if make_matrix is construct_singular_matrix:
                true_result, _, _, _ = solve_with(jnp.linalg.lstsq)  # pyright: ignore
            else:
                true_result = solve_with(jnp.linalg.solve)  # pyright: ignore
            assert tree_allclose(result, true_result, rtol=1e-3)
================================================
FILE: tests/test_well_posed.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from .helpers import (
construct_matrix,
make_jacrev_operator,
ops,
params,
solvers,
tree_allclose,
)
@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
@pytest.mark.parametrize("ops", ops)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
    """Small well-posed solves agree with the true solution and jnp.linalg.solve.

    `ops` transforms the operator and the dense matrix in lockstep. The test
    first checks that the transformed operator still materialises to the
    transformed matrix.
    """
    if dtype is jnp.complex128 and make_operator is make_jacrev_operator:
        # JacobianLinearOperator does not support complex dtypes when jac="bwd"
        return
    tol = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore

    (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
    operator, matrix = ops(make_operator(getkey, matrix, tags), matrix)
    assert tree_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)

    size, _ = matrix.shape
    true_x = jr.normal(getkey(), (size,), dtype=dtype)
    b = matrix @ true_x
    x = lx.linear_solve(operator, b, solver=solver).value
    jax_x = jnp.linalg.solve(matrix, b)  # pyright: ignore
    for reference in (true_x, jax_x):
        assert tree_allclose(x, reference, atol=tol, rtol=tol)
@pytest.mark.parametrize("solver", solvers)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_pytree_wellposed(solver, getkey, dtype):
    """Solve a dense, well-posed system whose operator is a nested PyTree.

    The unknown is a list of a (2, 4) array and a (3,) array, so each operator
    block has shape `output_leaf_shape + input_leaf_shape`.
    """
    # NOTE(review): these solvers are skipped, presumably because they require
    # structure (diagonal/triangular/tridiagonal/positive-definite) that this
    # random dense operator lacks -- confirm.
    skipped = (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG)
    if isinstance(solver, skipped):
        return
    tol = 1e-10 if jax.config.jax_enable_x64 else 1e-4  # pyright: ignore

    true_x = [
        jr.normal(getkey(), shape=(2, 4), dtype=dtype),
        jr.normal(getkey(), (3,), dtype=dtype),
    ]
    pytree = [
        [
            jr.normal(getkey(), shape=(2, 4, 2, 4), dtype=dtype),
            jr.normal(getkey(), shape=(2, 4, 3), dtype=dtype),
        ],
        [
            jr.normal(getkey(), shape=(3, 2, 4), dtype=dtype),
            jr.normal(getkey(), shape=(3, 3), dtype=dtype),
        ],
    ]
    operator = lx.PyTreeLinearOperator(pytree, jax.eval_shape(lambda: true_x))
    solution = lx.linear_solve(operator, operator.mv(true_x), solver, throw=False)
    assert tree_allclose(solution.value, true_x, atol=tol, rtol=tol)