Repository: patrick-kidger/lineax
Branch: main
Commit: 2bf9824f7df9
Files: 75
Total size: 416.0 KB
Directory structure:
gitextract_at99w4wk/
├── .github/
│ └── workflows/
│ ├── build_docs.yml
│ ├── release.yml
│ └── run_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks/
│ ├── gmres_fails_safely.py
│ ├── lstsq_gradients.py
│ └── solver_speeds.py
├── docs/
│ ├── .htaccess
│ ├── _overrides/
│ │ └── partials/
│ │ └── source.html
│ ├── _static/
│ │ ├── custom_css.css
│ │ └── mathjax.js
│ ├── api/
│ │ ├── functions.md
│ │ ├── linear_solve.md
│ │ ├── operators.md
│ │ ├── solution.md
│ │ ├── solvers.md
│ │ └── tags.md
│ ├── examples/
│ │ ├── classical_solve.ipynb
│ │ ├── complex_solve.ipynb
│ │ ├── least_squares.ipynb
│ │ ├── no_materialisation.ipynb
│ │ ├── operators.ipynb
│ │ └── structured_matrices.ipynb
│ ├── faq.md
│ └── index.md
├── lineax/
│ ├── __init__.py
│ ├── _custom_types.py
│ ├── _misc.py
│ ├── _norm.py
│ ├── _operator.py
│ ├── _solution.py
│ ├── _solve.py
│ ├── _solver/
│ │ ├── __init__.py
│ │ ├── bicgstab.py
│ │ ├── cg.py
│ │ ├── cholesky.py
│ │ ├── diagonal.py
│ │ ├── gmres.py
│ │ ├── lsmr.py
│ │ ├── lu.py
│ │ ├── misc.py
│ │ ├── normal.py
│ │ ├── qr.py
│ │ ├── svd.py
│ │ ├── triangular.py
│ │ └── tridiagonal.py
│ ├── _tags.py
│ └── internal/
│ └── __init__.py
├── mkdocs.yml
├── pyproject.toml
└── tests/
├── README.md
├── __init__.py
├── __main__.py
├── conftest.py
├── helpers.py
├── test_adjoint.py
├── test_invert.py
├── test_jvp.py
├── test_jvp_jvp1.py
├── test_jvp_jvp2.py
├── test_lsmr.py
├── test_misc.py
├── test_norm.py
├── test_operator.py
├── test_singular.py
├── test_solve.py
├── test_transpose.py
├── test_vmap.py
├── test_vmap_jvp.py
├── test_vmap_vmap.py
└── test_well_posed.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/build_docs.yml
================================================
name: Build docs
on:
  push:
    branches:
      - main
jobs:
  build:
    strategy:
      matrix:
        python-version: [ 3.11 ]
        os: [ ubuntu-latest ]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v7
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          uv run echo done
      - name: Build docs
        run: |
          uv run mkdocs build
      - name: Upload docs
        uses: actions/upload-artifact@v4
        with:
          name: docs
          path: site  # where `mkdocs build` puts the built site
================================================
FILE: .github/workflows/release.yml
================================================
name: Release
on:
  push:
    branches:
      - main
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Release
        uses: patrick-kidger/action_update_python_project@v8
        with:
          python-version: "3.11"
          # Uninstall and reinstall pytest to work around the fact that it doesn't get put into `bin` otherwise.
          test-script: |
            cp -r ${{ github.workspace }}/tests ./tests
            cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
            uv pip uninstall pytest
            uv sync --no-install-project --inexact
            uv run --no-sync pytest
          pypi-token: ${{ secrets.pypi_token }}
          github-user: patrick-kidger
          github-token: ${{ github.token }}
================================================
FILE: .github/workflows/run_tests.yml
================================================
name: Run tests
on:
  pull_request:
jobs:
  run-test:
    strategy:
      matrix:
        python-version: [ 3.11 ]
        os: [ ubuntu-latest ]
      fail-fast: false
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v7
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          uv run echo done
      - name: Checks with pre-commit
        run: |
          uv run prek run --all-files
      - name: Test with pytest
        run: |
          uv run python -m tests
      - name: Check that documentation can be built.
        run: |
          uv run mkdocs build
================================================
FILE: .gitignore
================================================
**/__pycache__
**/.ipynb_checkpoints
*.egg-info/
build/
dist/
site/
examples/data
.all_objects.cache
.pymon
.idea
.venv
uv.lock
================================================
FILE: .pre-commit-config.yaml
================================================
fail_fast: true
repos:
  - repo: meta
    hooks:
      - id: check-hooks-apply
      - id: check-useless-excludes
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
        exclude: \.md$
      - id: check-toml
      - id: mixed-line-ending
  - repo: local
    hooks:
      - id: sort-pyproject
        name: sort pyproject
        files: ^pyproject\.toml$
        language: system
        entry: uv run -- toml-sort -i --sort-table-keys --sort-inline-tables
      - id: ruff-format
        name: ruff format
        types_or: [python, pyi, jupyter, toml]
        language: system
        entry: uv run -- ruff format --
        require_serial: true
      - id: ruff-lint
        name: ruff lint
        types_or: [python, pyi, jupyter, toml]
        language: system
        entry: uv run -- ruff check --fix --
        require_serial: true
      - id: pyright
        name: pyright
        types_or: [python]
        language: system
        entry: uv run -- pyright
        require_serial: true
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing
Contributions (pull requests) are very welcome! Here's how to get started.
---
### Getting started
[We assume that you have `uv` installed.](https://docs.astral.sh/uv/) Now fork the library on GitHub. Then clone and install the library:
```bash
git clone https://github.com/your-username-here/lineax.git
cd lineax
uv run prek install # Creates a local venv + installs dependencies + installs pre-commit hooks.
```
---
### If you're making changes to the code
Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass:
```bash
uv run python -m tests
```
Then push your changes back to your fork of the repository:
```bash
git push
```
Finally, open a pull request on GitHub!
---
### If you're making changes to the documentation
Make your changes. You can then build the documentation by running:
```bash
uv run mkdocs serve
```
You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<h1 align='center'>Lineax</h1>
Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, even when $A$ is ill-posed or rectangular.
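One common way to give $Ax = b$ a meaning when $A$ is rectangular or rank-deficient is the minimum-norm least-squares solution $x = A^{+}b$, where $A^{+}$ is the Moore-Penrose pseudoinverse. The sketch below illustrates that convention in plain NumPy rather than Lineax; the matrix and vector are made up for illustration, and which solution a particular Lineax solver returns should be checked in that solver's documentation.

```python
import numpy as np

# A rank-deficient 3x3 system (illustrative data): column 3 = column 1 + column 2.
A = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0],
              [1.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])

# Minimum-norm least-squares solution via the Moore-Penrose pseudoinverse.
x = np.linalg.pinv(A) @ b  # x is approximately [0, 1, 1]

# b lies in the column space of A, so the residual is (numerically) zero...
residual = np.linalg.norm(A @ x - b)

# ...but the solution is not unique: adding a null-space vector of A gives
# another exact solution, which has a larger norm than x.
null_vec = np.array([1.0, 1.0, -1.0])
x_other = x + null_vec
```

Among all exact solutions $[1, 2, 0] + t\,[1, 1, -1]$, the pseudoinverse picks the one with smallest Euclidean norm ($t = -1$).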
Features include:
- PyTree-valued matrices and vectors;
- General linear operators for Jacobians, transposes, etc.;
- Efficient linear least squares (e.g. QR solvers);
- Numerically stable gradients through linear least squares;
- Support for structured (e.g. symmetric) matrices;
- Improved compilation times;
- Improved runtime of some algorithms;
- Support for both real-valued and complex-valued inputs;
- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.
## Installation
```bash
pip install lineax
```
Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.
## Documentation
Available at [https://docs.kidger.site/lineax](https://docs.kidger.site/lineax).
## Quick examples
Lineax can solve a least squares problem with an explicit matrix operator:
```python
import jax.random as jr
import lineax as lx
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```
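For intuition about what a QR-based least-squares solve computes, here is a plain NumPy sketch of the textbook technique. It is not Lineax's implementation of `lx.QR`, and it assumes $A$ has full column rank. A tall matrix is factored as $A = QR$ (reduced QR), and the least-squares solution then satisfies $Rx = Q^\top b$. The data mirrors the `(10, 8)` shape above but uses NumPy's generator, so the values differ.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((10, 8))  # tall: more equations than unknowns
b = rng.standard_normal(10)

# Reduced QR: Q is 10x8 with orthonormal columns, R is 8x8 upper triangular.
Q, R = np.linalg.qr(A, mode="reduced")

# The minimiser of ||Ax - b|| satisfies R x = Q^T b.
# np.linalg.solve is used for brevity; a triangular solver would exploit R's structure.
x_qr = np.linalg.solve(R, Q.T @ b)

# Cross-check against NumPy's reference (SVD-based) least-squares routine.
x_ref, *_ = np.linalg.lstsq(A, b, rcond=None)

# At the optimum the residual is orthogonal to the columns of A (normal equations).
residual = A @ x_qr - b
```

Working with $Q$ and $R$ avoids forming $A^\top A$, whose condition number is the square of that of $A$.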
or Lineax can solve a problem without ever materializing a matrix, as done in this
quadratic solve:
```python
import jax
import lineax as lx
key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))
def quadratic_fn(y, args):
    return jax.numpy.sum((y - 1)**2)
gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
```
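Solving without materializing a matrix means the solver needs only matrix-vector products. The sketch below illustrates that idea with a hand-written textbook conjugate gradient loop in plain NumPy; it is not Lineax's `CG`. The helper `hessian_matvec` is hypothetical. It stands in for the Hessian of `quadratic_fn`, which is $2I$, so it returns `2 * v` without ever building a matrix. Because the objective is quadratic, a single Newton step $y - H^{-1}\nabla f(y)$ lands exactly on the minimiser, the all-ones vector.

```python
import numpy as np

def conjugate_gradient(matvec, b, rtol=1e-10, max_steps=100):
    """Solve A x = b for symmetric positive definite A, given only v -> A v."""
    x = np.zeros_like(b)
    r = b - matvec(x)  # residual
    p = r.copy()       # search direction
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_steps):
        if np.sqrt(rs) <= rtol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)  # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p  # next A-conjugate direction
        rs = rs_new
    return x

# Hypothetical stand-in for the Hessian of sum((y - 1)**2): H = 2I, applied matrix-free.
def hessian_matvec(v):
    return 2.0 * v

def gradient(y):
    return 2.0 * (y - 1.0)

rng = np.random.default_rng(0)
y = rng.standard_normal(10)
step = conjugate_gradient(hessian_matvec, gradient(y))
minimum = y - step  # one Newton step; exact for a quadratic objective
```

CG requires a symmetric positive (semi)definite operator, which is why the README example tags the Jacobian operator with `lx.positive_semidefinite_tag`.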
## Citation
If you found this library to be useful in academic work, then please cite: ([arXiv link](https://arxiv.org/abs/2311.17283))
```bibtex
@article{lineax2023,
title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
author={Jason Rader and Terry Lyons and Patrick Kidger},
journal={
AI for science workshop at Neural Information Processing Systems 2023,
arXiv:2311.17283
},
year={2023},
}
```
(Also consider starring the project on GitHub.)
## See also: other libraries in the JAX ecosystem
**Always useful**
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
**Deep learning**
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
**Scientific computing**
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
================================================
FILE: benchmarks/gmres_fails_safely.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import lineax as lx
getkey = eqxi.GetKey()
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
jax.config.update("jax_enable_x64", True)
def make_problem(mat_size: int, *, key):
    mat = jr.normal(key, (mat_size, mat_size))
    true_x = jr.normal(key, (mat_size,))
    b = mat @ true_x
    op = lx.MatrixLinearOperator(mat)
    return mat, op, b, true_x
def benchmark_jax(mat_size: int, *, key):
    mat, _, b, true_x = make_problem(mat_size, key=key)
    solve_with_jax = ft.partial(
        jsp.sparse.linalg.gmres, tol=1e-5, solve_method="batched"
    )
    gmres_jit = jax.jit(solve_with_jax)
    jax_soln, info = gmres_jit(mat, b)
    # info == 0.0 implies that the solve has succeeded.
    returned_failed = jnp.all(info != 0.0)
    actually_failed = not tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4)
    assert actually_failed
    captured_failure = returned_failed & actually_failed
    return captured_failure
def benchmark_lx(mat_size: int, *, key):
    _, op, b, true_x = make_problem(mat_size, key=key)
    lx_soln = lx.linear_solve(op, b, lx.GMRES(atol=1e-5, rtol=1e-5), throw=False)
    returned_failed = jnp.all(lx_soln.result != lx.RESULTS.successful)
    actually_failed = not tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4)
    assert actually_failed
    captured_failure = returned_failed & actually_failed
    return captured_failure
lx_failed_safely = 0
jax_failed_safely = 0
for _ in range(100):
    key = getkey()
    jax_captured_failure = benchmark_jax(100, key=key)
    lx_captured_failure = benchmark_lx(100, key=key)
    jax_failed_safely = jax_failed_safely + jax_captured_failure
    lx_failed_safely = lx_failed_safely + lx_captured_failure
print(f"JAX failed safely {jax_failed_safely} out of 100 times")
print(f"Lineax failed safely {lx_failed_safely} out of 100 times")
================================================
FILE: benchmarks/lstsq_gradients.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Core JAX has some numerical issues with their lstsq gradients.
# See https://github.com/google/jax/issues/14868
# This demonstrates that we don't have the same issue!
import sys
import jax
import jax.numpy as jnp
import lineax as lx
sys.path.append("../tests")
from helpers import finite_difference_jvp # pyright: ignore
a_primal = (jnp.eye(3),)
a_tangent = (jnp.zeros((3, 3)),)
def jax_solve(a):
    sol, _, _, _ = jnp.linalg.lstsq(a, jnp.arange(3)) # pyright: ignore
    return sol
def lx_solve(a):
    op = lx.MatrixLinearOperator(a)
    return lx.linear_solve(op, jnp.arange(3)).value
_, true_jvp = finite_difference_jvp(jax_solve, a_primal, a_tangent)
_, jax_jvp = jax.jvp(jax_solve, a_primal, a_tangent)
_, lx_jvp = jax.jvp(lx_solve, a_primal, a_tangent)
assert jnp.isnan(jax_jvp).all()
assert jnp.allclose(true_jvp, lx_jvp)
================================================
FILE: benchmarks/solver_speeds.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import sys
import timeit
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import lineax as lx
sys.path.append("../tests")
from helpers import construct_matrix, has_tag # pyright: ignore[reportMissingImports]
getkey = eqxi.GetKey()
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
jax.config.update("jax_enable_x64", True)
if jax.config.jax_enable_x64: # pyright: ignore
    tol = 1e-12
else:
    tol = 1e-6
def base_wrapper(a, b, solver):
    op = lx.MatrixLinearOperator(
        a,
        (
            lx.positive_semidefinite_tag,
            lx.symmetric_tag,
            lx.diagonal_tag,
            lx.tridiagonal_tag,
        ),
    )
    out = lx.linear_solve(op, b, solver, throw=False)
    return out.value
def jax_svd(a, b):
    out, _, _, _ = jnp.linalg.lstsq(a, b) # pyright: ignore
    return out
def jax_gmres(a, b):
    out, _ = jsp.sparse.linalg.gmres(a, b, tol=tol)
    return out
def jax_bicgstab(a, b):
    out, _ = jsp.sparse.linalg.bicgstab(a, b, tol=tol)
    return out
def jax_cg(a, b):
    out, _ = jsp.sparse.linalg.cg(a, b, tol=tol)
    return out
def jax_lu(matrix, vector):
    return jsp.linalg.lu_solve(jsp.linalg.lu_factor(matrix), vector)
def jax_cholesky(matrix, vector):
    return jsp.linalg.cho_solve(jsp.linalg.cho_factor(matrix), vector)
def jax_tridiagonal(matrix, vector):
    dl = jnp.append(0.0, matrix.diagonal(-1))
    d = matrix.diagonal(0)
    du = jnp.append(matrix.diagonal(1), 0.0)
    return jax.lax.linalg.tridiagonal_solve(dl, d, du, vector[:, None])[:, 0]
named_solvers = [
    ("LU", "LU", lx.LU(), jax_lu, ()),
    ("QR", "SVD", lx.QR(), jax_svd, ()),
    ("SVD", "SVD", lx.SVD(), jax_svd, ()),
    (
        "Cholesky",
        "Cholesky",
        lx.Cholesky(),
        jax_cholesky,
        lx.positive_semidefinite_tag,
    ),
    ("Diagonal", "None", lx.Diagonal(), None, lx.diagonal_tag),
    (
        "Tridiagonal",
        "Tridiagonal",
        lx.Tridiagonal(),
        jax_tridiagonal,
        lx.tridiagonal_tag,
    ),
    (
        "CG",
        "CG",
        lx.CG(atol=tol, rtol=tol, stabilise_every=None),
        jax_cg,
        lx.positive_semidefinite_tag,
    ),
    (
        "GMRES",
        "GMRES",
        lx.GMRES(atol=1, rtol=1),
        jax_gmres,
        (),
    ),
    (
        "BiCGStab",
        "BiCGStab",
        lx.BiCGStab(atol=tol, rtol=tol),
        jax_bicgstab,
        (),
    ),
]
def create_problem(solver, tags, size=3):
    (matrix,) = construct_matrix(getkey, solver, tags, size=size)
    true_x = jr.normal(getkey(), (size,))
    b = matrix @ true_x
    return matrix, true_x, b
def create_easy_iterative_problem(size, tags):
    matrix = jr.normal(getkey(), (size, size)) / size + 2 * jnp.eye(size)
    true_x = jr.normal(getkey(), (size,))
    if has_tag(tags, lx.positive_semidefinite_tag):
        matrix = matrix.T @ matrix
    b = matrix @ true_x
    return matrix, true_x, b
def test_solvers(vmap_size, mat_size):
    for lx_name, jax_name, _lx_solver, jax_solver, tags in named_solvers:
        lx_solver = ft.partial(base_wrapper, solver=_lx_solver)
        if vmap_size == 1:
            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):
                matrix, true_x, b = create_easy_iterative_problem(mat_size, tags)
            else:
                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)
        else:
            if isinstance(_lx_solver, (lx.CG, lx.GMRES, lx.BiCGStab)):
                matrix, true_x, b = eqx.filter_vmap(
                    create_easy_iterative_problem,
                    axis_size=vmap_size,
                    out_axes=eqx.if_array(0),
                )(mat_size, tags)
            else:
                matrix, true_x, b = create_problem(lx_solver, tags, size=mat_size)
                _create_problem = ft.partial(create_problem, size=mat_size)
                matrix, true_x, b = eqx.filter_vmap(
                    _create_problem, axis_size=vmap_size, out_axes=eqx.if_array(0)
                )(lx_solver, tags)
            lx_solver = jax.vmap(lx_solver)
            if jax_solver is not None:
                jax_solver = jax.vmap(jax_solver)
        lx_solver = jax.jit(lx_solver)
        bench_lx = ft.partial(lx_solver, matrix, b)
        if vmap_size == 1:
            batch_msg = "problem"
        else:
            batch_msg = f"batch of {vmap_size} problems"
        lx_soln = bench_lx()
        if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):
            lx_solve_time = timeit.timeit(bench_lx, number=1)
            print(
                f"Lineax's {lx_name} solved {batch_msg} of "
                f"size {mat_size} in {lx_solve_time} seconds."
            )
        else:
            fail_time = timeit.timeit(bench_lx, number=1)
            err = jnp.abs(lx_soln - true_x).max()
            print(
                f"Lineax's {lx_name} failed to solve {batch_msg} of "
                f"size {mat_size} with error {err} in {fail_time} seconds"
            )
        if jax_solver is None:
            print("JAX has no equivalent solver. \n")
        else:
            jax_solver = jax.jit(jax_solver)
            bench_jax = ft.partial(jax_solver, matrix, b)
            jax_soln = bench_jax()
            if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):
                jax_solve_time = timeit.timeit(bench_jax, number=1)
                print(
                    f"JAX's {jax_name} solved {batch_msg} of "
                    f"size {mat_size} in {jax_solve_time} seconds. \n"
                )
            else:
                fail_time = timeit.timeit(bench_jax, number=1)
                err = jnp.abs(jax_soln - true_x).max()
                print(
                    f"JAX's {jax_name} failed to solve {batch_msg} of "
                    f"size {mat_size} with error {err} in {fail_time} seconds. \n"
                )
for vmap_size, mat_size in [(1, 50), (1000, 50)]:
    test_solvers(vmap_size, mat_size)
================================================
FILE: docs/.htaccess
================================================
ErrorDocument 404 /jaxtyping/404.html
================================================
FILE: docs/_overrides/partials/source.html
================================================
{% import "partials/language.html" as lang with context %}
<a href="{{ config.repo_url }}" title="{{ lang.t('source.link.title') }}" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
{% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
{% include ".icons/" ~ icon ~ ".svg" %}
</div>
<div class="md-source__repository">
{{ config.repo_name }}
</div>
</a>
<a href="{{ config.theme.twitter_url }}" title="Go to Twitter" class="md-source">
<div class="md-source__icon md-icon">
{% include ".icons/fontawesome/brands/twitter.svg" %}
</div>
</a>
<a href="{{ config.theme.bluesky_url }}" title="Go to Bluesky" class="md-source">
<div class="md-source__icon md-icon">
{% include "bluesky.svg" %}
</div>
<div class="md-source__repository">
{{ config.theme.twitter_bluesky_name }}
</div>
</a>
================================================
FILE: docs/_static/custom_css.css
================================================
/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
html {
scroll-padding-top: 50px;
}
/* Fit the Twitter handle alongside the GitHub one in the top right. */
div.md-header__source {
width: revert;
max-width: revert;
}
a.md-source {
display: inline-block;
}
.md-source__repository {
max-width: 100%;
}
/* Emphasise sections of nav on left hand side */
nav.md-nav {
padding-left: 5px;
}
nav.md-nav--secondary {
border-left: revert !important;
}
.md-nav__title {
font-size: 0.9rem;
}
.md-nav__item--section > .md-nav__link {
font-size: 0.9rem;
}
/* Indent autogenerated documentation */
div.doc-contents {
padding-left: 25px;
border-left: 4px solid rgba(230, 230, 230);
}
/* Increase visibility of splitters "---" */
[data-md-color-scheme="default"] .md-typeset hr {
border-bottom-color: rgb(0, 0, 0);
border-bottom-width: 1pt;
}
[data-md-color-scheme="slate"] .md-typeset hr {
border-bottom-color: rgb(230, 230, 230);
}
/* More space at the bottom of the page */
.md-main__inner {
margin-bottom: 1.5rem;
}
/* Remove prev/next footer buttons */
.md-footer__inner {
display: none;
}
/* Change font sizes */
html {
/* Decrease font size for overall webpage (down from 137.5%, which is the Material default) */
font-size: 110%;
}
.md-typeset .admonition {
/* Increase font size in admonitions */
font-size: 100% !important;
}
.md-typeset details {
/* Increase font size in details */
font-size: 100% !important;
}
.md-typeset h1 {
font-size: 1.6rem;
}
.md-typeset h2 {
font-size: 1.5rem;
}
.md-typeset h3 {
font-size: 1.3rem;
}
.md-typeset h4 {
font-size: 1.1rem;
}
.md-typeset h5 {
font-size: 0.9rem;
}
.md-typeset h6 {
font-size: 0.8rem;
}
/* Bugfix: remove the superfluous parts generated when doing:
??? Blah
::: library.something
*/
.md-typeset details .mkdocstrings > h4 {
display: none;
}
.md-typeset details .mkdocstrings > h5 {
display: none;
}
/* Change default colours for <a> tags */
[data-md-color-scheme="default"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}
[data-md-color-scheme="slate"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}
/* Highlight functions, classes etc. type signatures. Really helps to make clear where
one item ends and another begins. */
[data-md-color-scheme="default"] {
--doc-heading-color: #DDD;
--doc-heading-border-color: #CCC;
--doc-heading-color-alt: #F0F0F0;
}
[data-md-color-scheme="slate"] {
--doc-heading-color: rgb(25,25,33);
--doc-heading-border-color: rgb(25,25,33);
--doc-heading-color-alt: rgb(33,33,44);
--md-code-bg-color: rgb(38,38,50);
}
h4.doc-heading {
/* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
background-color: var(--doc-heading-color);
border: solid var(--doc-heading-border-color);
border-width: 1.5pt;
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}
h5.doc-heading, h6.heading {
background-color: var(--doc-heading-color-alt);
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}
================================================
FILE: docs/_static/mathjax.js
================================================
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};
document$.subscribe(() => {
MathJax.typesetPromise()
})
================================================
FILE: docs/api/functions.md
================================================
# Functions on linear operators
We define a number of functions on [linear operators](./operators.md).
## Computational changes
These do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.)
::: lineax.linearise
---
::: lineax.materialise
## Extract information from the operator
::: lineax.diagonal
---
::: lineax.tridiagonal
## Test the operator to see if it exhibits a certain property
Note that these do *not* inspect the values of the operator -- instead, they typically use [tags](./tags.md). (Or in some cases, just the type of the operator: e.g. `is_diagonal(DiagonalLinearOperator(...)) == True`.)
::: lineax.has_unit_diagonal
---
::: lineax.is_diagonal
---
::: lineax.is_tridiagonal
---
::: lineax.is_lower_triangular
---
::: lineax.is_upper_triangular
---
::: lineax.is_positive_semidefinite
---
::: lineax.is_negative_semidefinite
---
::: lineax.is_symmetric
================================================
FILE: docs/api/linear_solve.md
================================================
# linear_solve
This is the main entry point.
::: lineax.linear_solve
## invert
A convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][].
::: lineax.invert
================================================
FILE: docs/api/operators.md
================================================
# Linear operators
We often talk about solving a linear system $Ax = b$, where $A \in \mathbb{R}^{n \times m}$ is a matrix, $b \in \mathbb{R}^n$ is a vector, and $x \in \mathbb{R}^m$ is our desired solution.
The linear operators described on this page are ways of describing the matrix $A$. The simplest is [`lineax.MatrixLinearOperator`][], which holds the matrix $A$ directly.
Meanwhile if $A$ is diagonal, then there is also [`lineax.DiagonalLinearOperator`][]: for efficiency this only stores the diagonal of $A$.
Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such that $F(x) = Ax$. Whilst we could use $F$ to materialise the whole matrix $A$ and then store it in a [`lineax.MatrixLinearOperator`][], that may be very memory intensive. Instead, we may prefer to use [`lineax.FunctionLinearOperator`][]. Many linear solvers (e.g. [`lineax.CG`][]) only use matrix-vector products, and this means we can avoid ever needing to materialise the whole matrix $A$.
??? abstract "`lineax.AbstractLinearOperator`"
::: lineax.AbstractLinearOperator
options:
members:
- mv
- as_matrix
- transpose
- in_structure
- out_structure
- in_size
- out_size
::: lineax.MatrixLinearOperator
options:
members:
- __init__
---
::: lineax.DiagonalLinearOperator
options:
members:
- __init__
---
::: lineax.TridiagonalLinearOperator
options:
members:
- __init__
---
::: lineax.PyTreeLinearOperator
options:
members:
- __init__
---
::: lineax.JacobianLinearOperator
options:
members:
- __init__
---
::: lineax.FunctionLinearOperator
options:
members:
- __init__
---
::: lineax.IdentityLinearOperator
options:
members:
- __init__
---
::: lineax.TaggedLinearOperator
options:
members:
- __init__
================================================
FILE: docs/api/solution.md
================================================
# Solution
::: lineax.Solution
options:
members: []
---
::: lineax.RESULTS
options:
members: []
================================================
FILE: docs/api/solvers.md
================================================
# Solvers
If you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to an efficient solver depending on what structure your linear operator is declared to exhibit. (See the [tags](./tags.md) page.)
??? abstract "`lineax.AbstractLinearSolver`"
::: lineax.AbstractLinearSolver
options:
members:
- init
- compute
- transpose
- conj
- assume_full_rank
::: lineax.AutoLinearSolver
options:
members:
- __init__
- select_solver
---
::: lineax.LU
options:
members:
- __init__
## Least squares solvers
These are capable of solving ill-posed linear problems.
::: lineax.QR
options:
members:
- __init__
---
::: lineax.SVD
options:
members:
- __init__
---
::: lineax.Normal
options:
members:
- __init__
---
::: lineax.LSMR
options:
members:
- __init__
#### Diagonal
In addition to these, [`lineax.Diagonal`][] with `well_posed=False` (below) also supports ill-posed problems.
## Iterative solvers
These solvers use only matrix-vector products, and do not require instantiating the whole matrix. This makes them good when used alongside e.g. [`lineax.JacobianLinearOperator`][] or [`lineax.FunctionLinearOperator`][], which only provide matrix-vector products.
!!! warning
Note that [`lineax.BiCGStab`][] and [`lineax.GMRES`][] may fail to converge on some (typically non-sparse) problems.
::: lineax.CG
options:
members:
- __init__
---
::: lineax.BiCGStab
options:
members:
- __init__
---
::: lineax.GMRES
options:
members:
- __init__
#### LSMR
In addition to these, [`lineax.LSMR`][] (above) is also an iterative method.
## Structure-exploiting solvers
These require special structure in the operator. (And will throw an error if passed an operator without that structure.) In return, they are able to solve the linear problem much more efficiently.
::: lineax.Cholesky
options:
members:
- __init__
---
::: lineax.Diagonal
options:
members:
- __init__
---
::: lineax.Triangular
options:
members:
- __init__
---
::: lineax.Tridiagonal
options:
members:
- __init__
#### CG
In addition to these, [`lineax.CG`][] also requires special structure (positive or negative definiteness).
================================================
FILE: docs/api/tags.md
================================================
# Tags
Lineax offers a way to "tag" linear operators as exhibiting certain properties, e.g. that they are positive semidefinite.
If a linear operator is known to have a particular property, then this can be used to dispatch to a more efficient implementation, e.g. when solving a linear system.
Generally speaking, tags are an *optional* tool that can be used to improve your run time and/or compile time, by statically telling the linear solvers what properties they may assume about your system. However, if misused then you may find that the wrong result is silently returned.
In this way they are analogous to flags like `scipy.linalg.solve(..., assume_a="pos")`.
!!! Example
```python
import jax.numpy as jnp
import lineax as lx

# Some rank-2 JAX array (here symmetric positive definite).
matrix = jnp.array([[2.0, 1.0], [1.0, 3.0]])
# Some rank-1 JAX array.
vector = jnp.array([1.0, 2.0])
# Declare that this matrix is positive semidefinite.
operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
# This tag is used to dispatch to a maximally-efficient linear solver.
# In this case, a Cholesky solver is used:
solution = lx.linear_solve(operator, vector)
# Whether operators are tagged can be checked:
assert lx.is_positive_semidefinite(operator)
```
!!! Warning
Be careful: only the tag is checked, not the actual values of the matrix:
```python
# Not a positive semidefinite matrix (not even symmetric).
matrix = jax.numpy.array([[1.0, 2.0], [3.0, 4.0]])
vector = jax.numpy.array([1.0, 2.0])
operator = lx.MatrixLinearOperator(matrix, lx.positive_semidefinite_tag)
lx.is_positive_semidefinite(operator)  # True
lx.linear_solve(operator, vector)  # Invalid: may return a wrong solution, or fail on NaN/inf output!
```
Of the built-in operators, [`lineax.MatrixLinearOperator`][], [`lineax.PyTreeLinearOperator`][], [`lineax.JacobianLinearOperator`][], [`lineax.FunctionLinearOperator`][] and [`lineax.TaggedLinearOperator`][] directly support a `tags` argument that marks them as having certain characteristics:
```python
operator = lx.MatrixLinearOperator(matrix, lx.symmetric_tag)
```
You can pass multiple tags at once:
```python
operator = lx.MatrixLinearOperator(matrix, (lx.symmetric_tag, lx.unit_diagonal_tag))
```
Other linear operators can be wrapped into a [`lineax.TaggedLinearOperator`][] if necessary:
```python
operator = lx.MatrixLinearOperator(...)
symmetric_operator = operator + operator.T
lx.is_symmetric(symmetric_operator) # False
symmetric_operator = lx.TaggedLinearOperator(symmetric_operator, lx.symmetric_tag)
lx.is_symmetric(symmetric_operator) # True
```
Some linear operators are known to exhibit certain properties by construction, and need no additional tags:
```python
lx.is_symmetric(lx.DiagonalLinearOperator(...)) # True
lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True
```
## List of available tags
::: lineax.symmetric_tag
Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.)
---
::: lineax.diagonal_tag
Marks that an operator is diagonal. (As a matrix, it must have zeros in the off-diagonal entries.)
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Diagonal`][] as the solver.
---
::: lineax.unit_diagonal_tag
Marks that an operator has $1$ for every diagonal element. (As a matrix $A$, it must have $A_{ii} = 1$ for all $i$.) Note that the whole matrix need not be diagonal.
For example, [`lineax.Triangular`][] uses this to cheapen its solve.
---
::: lineax.lower_triangular_tag
Marks that an operator is lower triangular. (As a matrix $A$, it must have $A_{ij} = 0$ for all $i < j$.) Note that the diagonal may still have nonzero entries.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.
---
::: lineax.upper_triangular_tag
Marks that an operator is upper triangular. (As a matrix $A$, it must have $A_{ij} = 0$ for all $i > j$.) Note that the diagonal may still have nonzero entries.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Triangular`][] as the solver.
---
::: lineax.positive_semidefinite_tag
Marks that an operator is positive **semidefinite**.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.
---
::: lineax.negative_semidefinite_tag
Marks that an operator is negative **semidefinite**.
For example, the default solver for [`lineax.linear_solve`][] uses this to dispatch to [`lineax.Cholesky`][] as the solver.
================================================
FILE: docs/examples/classical_solve.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1",
"metadata": {},
"source": [
"# Classical solve\n",
"\n",
"We wish to solve the linear system $Ax = b$. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb3a7781-2358-40c4-82f3-e908bddeb578",
"metadata": {
"tags": [],
"ExecuteTime": {
"end_time": "2024-04-02T05:26:05.556701Z",
"start_time": "2024-04-02T05:26:03.814599Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A=\n",
"[[-0.3721109 0.26423115 -0.18252768]\n",
" [-0.7368197 0.44973662 -0.1521442 ]\n",
" [-0.67135346 -0.5908641 0.73168886]]\n",
"b=[ 0.17269018 -0.64765567 1.2229712 ]\n",
"x=[-2.7321298 -8.52878 -7.7226872]\n"
]
}
],
"source": [
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3))\n",
"vector = jr.normal(jr.PRNGKey(1), (3,))\n",
"operator = lx.MatrixLinearOperator(matrix)\n",
"solution = lx.linear_solve(operator, vector)\n",
"print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/complex_solve.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "8d41e1dd-93da-4e81-bd4a-33e5df8915f1",
"metadata": {},
"source": [
"# Complex solve\n",
"\n",
"We can also solve a system with complex entries. Here we consider the classical case for which the full matrix $A$ is square, well-posed and materialised in memory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb3a7781-2358-40c4-82f3-e908bddeb578",
"metadata": {
"tags": [],
"ExecuteTime": {
"end_time": "2024-04-02T05:29:04.909894Z",
"start_time": "2024-04-02T05:29:04.103141Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A=\n",
"[[-1.8459436 -0.2744466j 0.02393756-0.03172905j 0.76815367-1.4444253j ]\n",
" [-1.0467293 +0.05608991j 1.0891742 -0.03264743j 0.7513123 +0.56285536j]\n",
" [ 0.38307396-1.0190808j 0.01203694-1.1971304j 0.19252291-0.26424018j]]\n",
"b=[0.23162952+0.3614433j 0.05800135+1.6094692j 0.8979094 +0.16941352j]\n",
"x=[-0.07652722-0.34397143j -0.22629777+1.0359733j 0.22135164-0.00880566j]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\n",
"vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\n",
"operator = lx.MatrixLinearOperator(matrix)\n",
"solution = lx.linear_solve(operator, vector)\n",
"print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/least_squares.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "44bff903-0e4d-4f3e-a75c-d3cfe8ab4dea",
"metadata": {},
"source": [
"# Linear least squares\n",
"\n",
"The solution to a well-posed linear system $Ax = b$ is given by $x = A^{-1}b$. If the matrix is rectangular or not invertible, then we may generalise the notion of solution to $x = A^{\\dagger}b$, where $A^{\\dagger}$ denotes the [Moore--Penrose pseudoinverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse).\n",
"\n",
"Lineax can handle problems of this type too.\n",
"\n",
"!!! info\n",
"\n",
" For reference: in core JAX, problems of this type are handled using [`jax.numpy.linalg.lstsq`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html#jax.numpy.linalg.lstsq).\n",
" \n",
"---\n",
"\n",
"## Picking a solver\n",
"\n",
"By default, the linear solve will fail. This will be a compile-time failure if using a rectangular matrix:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a956c3f2-a70c-472f-9fa9-3dbc16293e1d",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "ValueError",
"evalue": "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.",
"output_type": "error",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot use `AutoLinearSolver(well_posed=True)` with a non-square operator. If you are trying solve a least-squares problem then you should pass `solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve` assumes that the operator is square and nonsingular.\n"
]
}
],
"source": [
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"vector = jr.normal(jr.PRNGKey(1), (3,))\n",
"\n",
"rectangular_matrix = jr.normal(jr.PRNGKey(0), (3, 4))\n",
"rectangular_operator = lx.MatrixLinearOperator(rectangular_matrix)\n",
"lx.linear_solve(rectangular_operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "ba55c0dd-b696-497a-8b13-896c3a95d5fd",
"metadata": {},
"source": [
"Or it will happen at run time if using a rank-deficient matrix:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f0e7ffe6-1e3d-46dc-9dbd-d5ed4c2dedf4",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "XlaRuntimeError",
"evalue": "INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n",
"output_type": "error",
"traceback": [
"\u001b[0;31mXlaRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the\noperator was not well-posed, and that the solver does not support this.\n\nIf you are trying solve a linear least-squares problem then you should pass\n`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`\nassumes that the operator is square and nonsingular.\n\nIf you *were* expecting this solver to work with this operator, then it may be because:\n\n(a) the operator is singular, and your code has a bug; or\n\n(b) the operator was nearly singular (i.e. it had a high condition number:\n `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from\n numerical instability issues; or\n\n(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)\n that is does not actually satisfy.\n"
]
}
],
"source": [
"deficient_matrix = jr.normal(jr.PRNGKey(0), (3, 3)).at[0].set(0)\n",
"deficient_operator = lx.MatrixLinearOperator(deficient_matrix)\n",
"lx.linear_solve(deficient_operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "4b5cedab-75e5-4b52-88d9-b9d574be7e19",
"metadata": {},
"source": [
"Whilst linear least squares and the pseudoinverse are strict generalisations of linear solves and inverses (respectively), Lineax will *not* attempt to handle the ill-posed case automatically. This is because the algorithms for handling this case are much more computationally expensive.\n",
"\n",
"If your matrix may be rectangular, but is still known to be full rank, then you can set the solver to allow this case like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "45abc8bf-4fcf-46be-a91a-58f4e04ac10e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rectangular_solution: [-0.3214848 -0.75565964 -0.6034579 -0.01326615]\n"
]
}
],
"source": [
"rectangular_solution = lx.linear_solve(\n",
" rectangular_operator, vector, solver=lx.AutoLinearSolver(well_posed=None)\n",
")\n",
"print(\"rectangular_solution: \", rectangular_solution.value)"
]
},
{
"cell_type": "markdown",
"id": "86dc9e2f-fe2e-48c8-86ca-bc57f8137246",
"metadata": {},
"source": [
"If your matrix may be either rectangular or rank-deficient, then you can set the solver to allow this case like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a9a2d92c-3676-471e-bb4a-5fd3b4748fd4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"deficient_solution: [ 0.06046088 -1.0412765 0.8860444 ]\n"
]
}
],
"source": [
"deficient_solution = lx.linear_solve(\n",
" deficient_operator, vector, solver=lx.AutoLinearSolver(well_posed=False)\n",
")\n",
"print(\"deficient_solution: \", deficient_solution.value)"
]
},
{
"cell_type": "markdown",
"id": "7b870311-de0f-434c-a9e7-2d8ebf9f0b38",
"metadata": {},
"source": [
"Most users will want to use [`lineax.AutoLinearSolver`][], and not think about the details of which algorithm is selected.\n",
"\n",
"If you want to pick a particular algorithm, then that can be done too. [`lineax.QR`][] is capable of handling rectangular full-rank operators, and [`lineax.SVD`][] is capable of handling rank-deficient operators. (And in fact these are the algorithms that `AutoLinearSolver` is selecting in the examples above.)"
]
},
{
"cell_type": "markdown",
"id": "c9649746-b0ef-495b-9ea1-eb5f6ca2e7e5",
"metadata": {},
"source": [
"---\n",
"\n",
"## Differences from `jax.numpy.linalg.lstsq`?\n",
"\n",
"Lineax offers both speed and correctness advantages over the built-in algorithm. (This is partly because the built-in function has to have the same API as NumPy, so JAX is constrained in how it can be implemented.)\n",
"\n",
"### Speed (forward)\n",
"\n",
"First, in the rectangular case, the QR algorithm is much faster than the SVD algorithm (which `jax.numpy.linalg.lstsq` always uses):"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d46d0c9a-47e4-439d-9beb-c9aaf47faa5d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX solution: [-0.10002219 0.09477127 -0.10846332 ... -0.08007179 -0.01216239\n",
" -0.030862 ]\n",
"Lineax solution: [-0.1000222 0.0947713 -0.10846333 ... -0.08007187 -0.01216241\n",
" -0.03086199]\n",
"\n",
"JAX time: 0.011344402999384329\n",
"Lineax time: 0.0028611960005946457\n"
]
}
],
"source": [
"import timeit\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (500, 200))\n",
"vector = jr.normal(jr.PRNGKey(1), (500,))\n",
"\n",
"\n",
"@jax.jit\n",
"def solve_jax(matrix, vector):\n",
" out, *_ = jnp.linalg.lstsq(matrix, vector)\n",
" return out\n",
"\n",
"\n",
"@jax.jit\n",
"def solve_lineax(matrix, vector):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solver = lx.QR() # or lx.AutoLinearSolver(well_posed=None)\n",
" solution = lx.linear_solve(operator, vector, solver)\n",
" return solution.value\n",
"\n",
"\n",
"solution_jax = solve_jax(matrix, vector)\n",
"solution_lineax = solve_lineax(matrix, vector)\n",
"with np.printoptions(threshold=10):\n",
" print(\"JAX solution:\", solution_jax)\n",
" print(\"Lineax solution:\", solution_lineax)\n",
"print()\n",
"time_jax = timeit.repeat(lambda: solve_jax(matrix, vector), number=1, repeat=10)\n",
"time_lineax = timeit.repeat(lambda: solve_lineax(matrix, vector), number=1, repeat=10)\n",
"print(\"JAX time:\", min(time_jax))\n",
"print(\"Lineax time:\", min(time_lineax))"
]
},
{
"cell_type": "markdown",
"id": "397773d7-f782-45e6-9934-11c62c741380",
"metadata": {},
"source": [
"### Speed (gradients)\n",
"\n",
"Lineax also uses a slightly more efficient autodifferentiation implementation, which makes it faster even when both libraries use the SVD algorithm."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1988d0f6-86f5-401a-9615-30cccf04d129",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX gradients: [[-1.75446249e-03 2.00700224e-03 ... -3.16517282e-04 -6.08515576e-04]\n",
" [ 1.81865180e-04 4.51280124e-04 ... -1.64618701e-04 -6.53692259e-05]\n",
" ...\n",
" [-7.27269216e-04 1.27710134e-03 ... -2.64510425e-04 -3.38940619e-04]\n",
" [ 6.55723223e-03 -3.18011409e-03 ... -1.10758876e-04 1.43246143e-03]]\n",
"Lineax gradients: [[-1.7544631e-03 2.0070139e-03 ... -3.1653541e-04 -6.0847402e-04]\n",
" [ 1.8186278e-04 4.5128341e-04 ... -1.6459504e-04 -6.5359738e-05]\n",
" ...\n",
" [-7.2721508e-04 1.2771402e-03 ... -2.6450949e-04 -3.3894143e-04]\n",
" [ 6.5572355e-03 -3.1801097e-03 ... -1.1071599e-04 1.4324478e-03]]\n",
"\n",
"JAX time: 0.016591553001489956\n",
"Lineax time: 0.012212782999995397\n"
]
}
],
"source": [
"@jax.jit\n",
"@jax.grad\n",
"def grad_jax(matrix):\n",
" out, *_ = jnp.linalg.lstsq(matrix, vector)\n",
" return out.sum()\n",
"\n",
"\n",
"@jax.jit\n",
"@jax.grad\n",
"def grad_lineax(matrix):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solution = lx.linear_solve(operator, vector, lx.SVD())\n",
" return solution.value.sum()\n",
"\n",
"\n",
"gradients_jax = grad_jax(matrix)\n",
"gradients_lineax = grad_lineax(matrix)\n",
"with np.printoptions(threshold=10, edgeitems=2):\n",
" print(\"JAX gradients:\", gradients_jax)\n",
" print(\"Lineax gradients:\", gradients_lineax)\n",
"print()\n",
"time_jax = timeit.repeat(lambda: grad_jax(matrix), number=1, repeat=10)\n",
"time_lineax = timeit.repeat(lambda: grad_lineax(matrix), number=1, repeat=10)\n",
"print(\"JAX time:\", min(time_jax))\n",
"print(\"Lineax time:\", min(time_lineax))"
]
},
{
"cell_type": "markdown",
"id": "81a1da5a-3474-4613-926f-5c9d9cdcb4a7",
"metadata": {},
"source": [
"### Correctness (gradients)\n",
"\n",
"Core JAX unfortunately has a bug that means it sometimes produces NaN gradients. Lineax does not:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "66b3a08e-92d0-4d0f-a5ea-9e8e5265d259",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX gradients: [[nan nan nan]\n",
" [nan nan nan]\n",
" [nan nan nan]]\n",
"Lineax gradients: [[ 0. -1. -2.]\n",
" [ 0. -1. -2.]\n",
" [ 0. -1. -2.]]\n"
]
}
],
"source": [
"@jax.jit\n",
"@jax.grad\n",
"def grad_jax(matrix):\n",
" out, *_ = jnp.linalg.lstsq(matrix, jnp.arange(3.0))\n",
" return out.sum()\n",
"\n",
"\n",
"@jax.jit\n",
"@jax.grad\n",
"def grad_lineax(matrix):\n",
" operator = lx.MatrixLinearOperator(matrix)\n",
" solution = lx.linear_solve(operator, jnp.arange(3.0), lx.SVD())\n",
" return solution.value.sum()\n",
"\n",
"\n",
"print(\"JAX gradients:\", grad_jax(jnp.eye(3)))\n",
"print(\"Lineax gradients:\", grad_lineax(jnp.eye(3)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/no_materialisation.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "a7299095-8906-4867-82ef-d6b84b161366",
"metadata": {},
"source": [
"# Using only matrix-vector operations\n",
"\n",
"When solving a linear system $Ax = b$, it is relatively common not to have immediate access to the full matrix $A$, but only to a function $F(x) = Ax$ computing the matrix-vector product. (We could compute $A$ from $F$, but if the matrix is large then this may be very inefficient.)\n",
"\n",
"**Example: Newton's method**\n",
"\n",
"For example, this comes up when using [Newton's method](https://en.wikipedia.org/wiki/Newton%27s_method#k_variables,_k_functions). In this case, we have a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^n$, and wish to find the $\\delta \\in \\mathbb{R}^n$ for which $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; \\delta = -f(y)$. (Where $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\in \\mathbb{R}^{n \\times n}$ is a matrix: it is the Jacobian of $f$.)\n",
"\n",
"In this case it is possible to use forward-mode autodifferentiation to evaluate $F(x) = \\frac{\\mathrm{d}f}{\\mathrm{d}y}(y) \\; x$, without ever instantiating the whole Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}y}(y)$. Indeed, JAX has a [Jacobian-vector product function](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html#jax.jvp) for exactly this purpose.\n",
"```python\n",
"f = ...\n",
"y = ...\n",
"\n",
"def F(x):\n",
" \"\"\"Computes (df/dy) @ x.\"\"\"\n",
" _, out = jax.jvp(f, (y,), (x,))\n",
" return out\n",
"```\n",
"\n",
"**Solving a linear system using only matrix-vector operations**\n",
"\n",
"Lineax offers [iterative solvers](../api/solvers.md#iterative-solvers), which are capable of solving a linear system knowing only its matrix-vector products."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b221ee1f-bd6b-4cbf-b69b-ed2e388602e1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import lineax as lx\n",
"from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n",
"\n",
"\n",
"def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n",
" y0, y1, y2 = y\n",
" f0 = 5 * y0 + y1**2\n",
" f1 = y1 - y2 + 5\n",
" f2 = y0 / (1 + 5 * y2**2)\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"y = jnp.array([1.0, 2.0, 3.0])\n",
"operator = lx.JacobianLinearOperator(f, y, args=None)\n",
"vector = f(y, args=None)\n",
"solver = lx.Normal(lx.CG(rtol=1e-6, atol=1e-6))\n",
"solution = lx.linear_solve(operator, vector, solver)"
]
},
{
"cell_type": "markdown",
"id": "87568426-35ed-404b-bf78-425a6f519218",
"metadata": {},
"source": [
"!!! warning\n",
"\n",
" Note that iterative solvers are something of a \"last resort\", and they are not suitable for all problems.\n",
"\n",
" - [CG](https://en.wikipedia.org/wiki/Conjugate_gradient_method) requires that the problem be positive or negative semidefinite.\n",
" - Normalised CG (this is CG applied to the \"normal equations\" $(A^\\top A) x = (A^\\top b)$; note that $A^\\top A$ is always positive semidefinite) squares the condition number of $A$. In practice this means it may produce low-accuracy results if used with matrices with high condition number.\n",
" - [BiCGStab](https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method) and [GMRES](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) will fail on many problems. They are primarily meant as specialised tools for e.g. the matrices that arise when solving elliptic systems."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/operators.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "2fe0b1e4-35cb-4c39-b324-65253aab005a",
"metadata": {},
"source": [
"# Manipulating linear operators\n",
"\n",
"Lineax offers a sophisticated system of linear operators, supporting many operations.\n",
"\n",
"## Arithmetic\n",
"\n",
"To begin with, they support arithmetic, like addition and multiplication:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "552021d3-dadf-49f3-bd17-84a18513bfcc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"import numpy as np\n",
"\n",
"\n",
"np.set_printoptions(precision=3)\n",
"\n",
"matrix = jnp.zeros((5, 5))\n",
"matrix = matrix.at[0, 4].set(3)  # top right corner\n",
"sparse_operator = lx.MatrixLinearOperator(matrix)\n",
"\n",
"key0, key1, key = jr.split(jr.PRNGKey(0), 3)\n",
"diag = jr.normal(key0, (5,))\n",
"lower_diag = jr.normal(key0, (4,))\n",
"upper_diag = jr.normal(key0, (4,))\n",
"tridiag_operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n",
"\n",
"identity_operator = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((5,), jnp.float32))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4bb9825-73cc-447e-bc4c-c3e1a121a0a3",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1.149 0.963 0. 0. 3. ]\n",
" [ 0.963 -2.007 0.155 0. 0. ]\n",
" [ 0. 0.155 0.988 -0.261 0. ]\n",
" [ 0. 0. -0.261 0.931 0.899]\n",
" [ 0. 0. 0. 0.899 -0.288]]\n"
]
}
],
"source": [
"print((sparse_operator + tridiag_operator).as_matrix())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "759c78a1-eee7-40e9-be6c-ea8c97c29e95",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-101.149 0.963 0. 0. 0. ]\n",
" [ 0.963 -102.007 0.155 0. 0. ]\n",
" [ 0. 0.155 -99.012 -0.261 0. ]\n",
" [ 0. 0. -0.261 -99.069 0.899]\n",
" [ 0. 0. 0. 0.899 -100.288]]\n"
]
}
],
"source": [
"print((tridiag_operator - 100 * identity_operator).as_matrix())"
]
},
{
"cell_type": "markdown",
"id": "84412bfa-00ec-41d4-87d7-def781145a90",
"metadata": {},
"source": [
"Or they can be composed together. (I.e. matrix multiplication.)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8081d97f-5579-464f-8780-ffaa1d9c5f95",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 0. 0. 0. -3.447]\n",
" [ 0. 0. 0. 0. 2.888]\n",
" [ 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 0. 0. 0. ]]\n"
]
}
],
"source": [
"print((tridiag_operator @ sparse_operator).as_matrix())"
]
},
{
"cell_type": "markdown",
"id": "d2c2b580-616f-4abd-a732-7f4a9b13335f",
"metadata": {},
"source": [
"Or they can be transposed:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ae0393eb-3f43-490b-9842-bb374633633a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0.]\n",
" [3. 0. 0. 0. 0.]]\n"
]
}
],
"source": [
"print(sparse_operator.transpose().as_matrix()) # or sparse_operator.T will work"
]
},
{
"cell_type": "markdown",
"id": "ddbbbb0f-7983-4e35-b92d-2512c9612d19",
"metadata": {},
"source": [
"## Different operator types\n",
"\n",
"Lineax has many different operator types:\n",
"\n",
"- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n",
"- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{m \\times n}$.\n",
"- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n",
"- etc!\n",
"\n",
"See the [operators](../api/operators.md) page for details on all supported operators.\n",
"\n",
"As above, these can be freely combined:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "75ad4480-8ce0-4a88-9c76-bc054b1a0eaf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from jaxtyping import Array, Float # https://github.com/google/jaxtyping\n",
"\n",
"\n",
"def f(y: Float[Array, \"3\"], args) -> Float[Array, \"3\"]:\n",
" y0, y1, y2 = y\n",
" f0 = 5 * y0 + y1**2\n",
" f1 = y1 - y2 + 5\n",
" f2 = y0 / (1 + 5 * y2**2)\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"def g(y: Float[Array, \"3\"]) -> Float[Array, \"3\"]:\n",
" # Must be linear!\n",
" y0, y1, y2 = y\n",
" f0 = y0 - y2\n",
" f1 = 0.0\n",
" f2 = 5 * y1\n",
" return jnp.stack([f0, f1, f2])\n",
"\n",
"\n",
"y = jnp.array([1.0, 2.0, 3.0])\n",
"in_structure = jax.eval_shape(lambda: y)\n",
"jac_operator = lx.JacobianLinearOperator(f, y, args=None)\n",
"fn_operator = lx.FunctionLinearOperator(g, in_structure)\n",
"identity_operator = lx.IdentityLinearOperator(in_structure)\n",
"\n",
"operator = jac_operator @ fn_operator + 0.9 * identity_operator"
]
},
{
"cell_type": "markdown",
"id": "5e528057-29ff-468d-aa3d-7155dd57082d",
"metadata": {},
"source": [
"By default, this composition does not materialise a matrix. (This is sometimes important for efficiency when working with many operators.) Instead, the composition is stored as another linear operator:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5d15150d-955f-4006-bd36-58e2e6663307",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AddLinearOperator(\n",
" operator1=ComposedLinearOperator(\n",
" operator1=JacobianLinearOperator(...),\n",
" operator2=FunctionLinearOperator(...)\n",
" ),\n",
" operator2=MulLinearOperator(\n",
" operator=IdentityLinearOperator(...),\n",
" scalar=f32[]\n",
" )\n",
")\n"
]
}
],
"source": [
"import equinox as eqx # https://github.com/patrick-kidger/equinox\n",
"\n",
"\n",
"truncate_leaf = lambda x: x in (jac_operator, fn_operator, identity_operator)\n",
"eqx.tree_pprint(operator, truncate_leaf=truncate_leaf)"
]
},
{
"cell_type": "markdown",
"id": "ff7b0591-1203-4f5e-886e-399822c68a15",
"metadata": {
"tags": []
},
"source": [
"If you want to materialise the composition into a matrix, call `.as_matrix()`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3713589f-1ac4-4e08-946b-ecc3fcf6a4c3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[ 5.9 , 0. , -5. ],\n",
" [ 0. , -4.1 , 0. ],\n",
" [ 0.022, -0.071, 0.878]], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"operator.as_matrix()"
]
},
{
"cell_type": "markdown",
"id": "a483517e-89d7-4e9e-ad89-1915d886c14c",
"metadata": {},
"source": [
"The resulting matrix can in turn be wrapped as another linear operator, if desired:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fccddc81-d50e-4abe-a354-38402e462b1f",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MatrixLinearOperator(\n",
" matrix=Array([[ 5.9 , 0. , -5. ],\n",
" [ 0. , -4.1 , 0. ],\n",
" [ 0.022, -0.071, 0.878]], dtype=float32),\n",
" tags=frozenset()\n",
")\n"
]
}
],
"source": [
"operator_fully_materialised = lx.MatrixLinearOperator(operator.as_matrix())\n",
"eqx.tree_pprint(operator_fully_materialised, short_arrays=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/examples/structured_matrices.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "e2573d62-a505-4998-8796-b0f1bc889433",
"metadata": {},
"source": [
"# Structured matrices\n",
"\n",
"Lineax can also be used with matrices known to exhibit special structure, e.g. tridiagonal matrices or positive definite matrices.\n",
"\n",
"Typically, that means using a particular operator type:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8e275652-dd80-4a9a-b3ac-b96dc16d3334",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 4. 2. 0. 0. ]\n",
" [ 1. -0.5 -1. 0. ]\n",
" [ 0. 3. 7. -5. ]\n",
" [ 0. 0. -0.7 1. ]]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"diag = jnp.array([4.0, -0.5, 7.0, 1.0])\n",
"lower_diag = jnp.array([1.0, 3.0, -0.7])\n",
"upper_diag = jnp.array([2.0, -1.0, -5.0])\n",
"\n",
"operator = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)\n",
"print(operator.as_matrix())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ba23ecc4-bdea-4293-a138-ce77bc83082c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"vector = jnp.array([1.0, -0.5, 2.0, 0.8])\n",
"# Will automatically dispatch to a tridiagonal solver.\n",
"solution = lx.linear_solve(operator, vector)"
]
},
{
"cell_type": "markdown",
"id": "cd58979d-b619-4ddf-9a17-12e8babae3e8",
"metadata": {},
"source": [
"If you're unsure which solver is being dispatched to, you can check:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6984f62f-75fc-4d6e-ab42-fdade471be5b",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tridiagonal()\n"
]
}
],
"source": [
"default_solver = lx.AutoLinearSolver(well_posed=True)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "164a5bd5-5d48-4b28-bcc5-d276ab49c780",
"metadata": {},
"source": [
"If you want to enforce that a particular solver is used, then it can be passed manually:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "102ada9a-0533-40cf-9bad-02918fffb6b1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"solution = lx.linear_solve(operator, vector, solver=lx.Tridiagonal())"
]
},
{
"cell_type": "markdown",
"id": "1b4ebf09-e138-43f6-973c-c9f005ffb55e",
"metadata": {},
"source": [
"Trying to use a solver with an unsupported operator will raise an error:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d8f5bf66-53cd-4e81-a8d7-a19e86307ad3",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "ValueError",
"evalue": "`Tridiagonal` may only be used for linear solves with tridiagonal matrices",
"output_type": "error",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\n"
]
}
],
"source": [
"not_tridiagonal_matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n",
"not_tridiagonal_operator = lx.MatrixLinearOperator(not_tridiagonal_matrix)\n",
"solution = lx.linear_solve(not_tridiagonal_operator, vector, solver=lx.Tridiagonal())"
]
},
{
"cell_type": "markdown",
"id": "03c4c531-58fa-4b56-8b0a-6e611c8c5912",
"metadata": {},
"source": [
"---\n",
"\n",
"Besides using a particular operator type, the structure of the matrix can also be expressed by [adding particular tags](../api/tags.md). These tags act as a manual override mechanism, and the values of the matrix are not checked.\n",
"\n",
"For example, let's construct a positive definite matrix:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b5add874-7a2c-4000-84c3-8c94a121a831",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n",
"operator = lx.MatrixLinearOperator(matrix.T @ matrix)"
]
},
{
"cell_type": "markdown",
"id": "5459b2d6-ddb9-4a37-bb51-3f5c204bab0d",
"metadata": {},
"source": [
"Unfortunately, Lineax has no way of knowing that this matrix is positive definite. It can solve the system, but it will not use a solver that is adapted to exploit the extra structure:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "78400416-e774-4f74-a530-e368db84af0e",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LU()\n"
]
}
],
"source": [
"solution = lx.linear_solve(operator, vector)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "e108bdff-1cf1-4751-8c9d-3baae82ca9a7",
"metadata": {},
"source": [
"But if we add a tag:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f6dc2966-1dfa-4a3c-be6a-974926695547",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cholesky()\n"
]
}
],
"source": [
"operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\n",
"solution2 = lx.linear_solve(operator, vector)\n",
"print(default_solver.select_solver(operator))"
]
},
{
"cell_type": "markdown",
"id": "7274d17b-a7d3-45bf-9042-785ac25e2d74",
"metadata": {},
"source": [
"Then a more efficient solver can be selected. We can check that the solutions returned from these two approaches are equal, up to floating-point rounding:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fdcde152-9ac1-4532-a174-3fc39d83d289",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1.400575 -0.41042092 0.5313305 0.28422552]\n",
"[ 1.4005749 -0.41042086 0.53133047 0.2842255 ]\n"
]
}
],
"source": [
"print(solution.value)\n",
"print(solution2.value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"language": "python",
"name": "py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/faq.md
================================================
# FAQ

## How does this differ from `jax.numpy.linalg.solve`, `jax.scipy.{...}` etc.?

Lineax offers several improvements. Most notably:

- Several new solvers. For example, [`lineax.QR`][] has no counterpart in core JAX. (And it is much faster than `jax.numpy.linalg.lstsq`, which is the closest equivalent, and uses an SVD decomposition instead.)
- Several new operators. For example, [`lineax.JacobianLinearOperator`][] has no counterpart in core JAX.
- A consistent API. The built-in JAX operations all differ from each other slightly, and are split across `jax.numpy`, `jax.scipy`, and `jax.scipy.sparse`.
- Numerically stable gradients. The existing JAX implementations will sometimes return `NaN`s!
- Faster compile times and run times in a few places.

Most of these differences arise because JAX aims to mimic the existing NumPy/SciPy APIs. (I.e. it's not JAX's fault that it doesn't take the approach that Lineax does!)

## How do I represent a {lower, upper} triangular matrix?

Typically: create a full matrix, with the {lower, upper} part containing your values, and the converse {upper, lower} part containing all zeros. Then use, e.g., `operator = lx.MatrixLinearOperator(matrix, lx.lower_triangular_tag)`.

This is the most efficient way to store a triangular matrix in JAX's ndarray-based programming model.

## What about other operations from linear algebra? (Determinants, eigenvalues, etc.)

See [`jax.numpy.linalg`](https://jax.readthedocs.io/en/latest/jax.numpy.html#module-jax.numpy.linalg) and [`jax.scipy.linalg`](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.linalg).

## How do I solve multiple systems of equations (i.e. `AX = B`)?

Solvers implemented in Lineax target single systems of linear equations (i.e. `Ax = b`). However, using `jax.vmap` or `equinox.filter_vmap`, multiple systems can be solved with minimal effort:

```python
import equinox as eqx
import jax
import lineax as lx

# Share the operator (`None`) and map over the columns of `B` (axis 1).
multi_linear_solve = eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))
# or
multi_linear_solve = jax.vmap(lx.linear_solve, in_axes=(None, 1))
```
================================================
FILE: docs/index.md
================================================
# Getting started

Lineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)

Features include:

- PyTree-valued matrices and vectors;
- General linear operators for Jacobians, transposes, etc.;
- Efficient linear least squares (e.g. QR solvers);
- Numerically stable gradients through linear least squares;
- Support for structured (e.g. symmetric) matrices;
- Improved compilation times;
- Improved runtime of some algorithms;
- Support for both real-valued and complex-valued inputs;
- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.

## Installation

```bash
pip install lineax
```

Requires Python 3.10+, JAX 0.4.38+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.10+.

## Quick example

Lineax can solve a least squares problem with an explicit matrix operator:

```python
import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```

Alternatively, Lineax can solve a problem without ever materializing a matrix, as done in this quadratic solve:

```python
import jax
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))


def quadratic_fn(y, args):
    return jax.numpy.sum((y - 1) ** 2)


gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
```

## Next steps

Check out the examples or the API reference on the left-hand bar.

## See also: other libraries in the JAX ecosystem

**Always useful**

[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!

[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

**Deep learning**

[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.

[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).

[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).

[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

**Scientific computing**

[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.

[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.

[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.

[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.

[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**

[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
================================================
FILE: lineax/__init__.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
from . import internal as internal
from ._operator import (
AbstractLinearOperator as AbstractLinearOperator,
AddLinearOperator as AddLinearOperator,
ComposedLinearOperator as ComposedLinearOperator,
conj as conj,
diagonal as diagonal,
DiagonalLinearOperator as DiagonalLinearOperator,
DivLinearOperator as DivLinearOperator,
FunctionLinearOperator as FunctionLinearOperator,
has_unit_diagonal as has_unit_diagonal,
IdentityLinearOperator as IdentityLinearOperator,
is_diagonal as is_diagonal,
is_lower_triangular as is_lower_triangular,
is_negative_semidefinite as is_negative_semidefinite,
is_positive_semidefinite as is_positive_semidefinite,
is_symmetric as is_symmetric,
is_tridiagonal as is_tridiagonal,
is_upper_triangular as is_upper_triangular,
JacobianLinearOperator as JacobianLinearOperator,
linearise as linearise,
materialise as materialise,
MatrixLinearOperator as MatrixLinearOperator,
MulLinearOperator as MulLinearOperator,
NegLinearOperator as NegLinearOperator,
PyTreeLinearOperator as PyTreeLinearOperator,
TaggedLinearOperator as TaggedLinearOperator,
TangentLinearOperator as TangentLinearOperator,
tridiagonal as tridiagonal,
TridiagonalLinearOperator as TridiagonalLinearOperator,
)
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
AbstractLinearSolver as AbstractLinearSolver,
AutoLinearSolver as AutoLinearSolver,
invert as invert,
linear_solve as linear_solve,
)
from ._solver import (
BiCGStab as BiCGStab,
CG as CG,
Cholesky as Cholesky,
Diagonal as Diagonal,
GMRES as GMRES,
LSMR as LSMR,
LU as LU,
Normal as Normal,
NormalCG as NormalCG,
QR as QR,
SVD as SVD,
Triangular as Triangular,
Tridiagonal as Tridiagonal,
)
from ._tags import (
diagonal_tag as diagonal_tag,
lower_triangular_tag as lower_triangular_tag,
negative_semidefinite_tag as negative_semidefinite_tag,
positive_semidefinite_tag as positive_semidefinite_tag,
symmetric_tag as symmetric_tag,
transpose_tags as transpose_tags,
transpose_tags_rules as transpose_tags_rules,
tridiagonal_tag as tridiagonal_tag,
unit_diagonal_tag as unit_diagonal_tag,
upper_triangular_tag as upper_triangular_tag,
)
__version__ = importlib.metadata.version("lineax")
================================================
FILE: lineax/_custom_types.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import equinox.internal as eqxi
sentinel: Any = eqxi.doc_repr(object(), "sentinel")
================================================
FILE: lineax/_misc.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, Bool, PyTree  # pyright:ignore


def tree_where(
    pred: Bool[ArrayLike, ""], true: PyTree[ArrayLike], false: PyTree[ArrayLike]
) -> PyTree[Array]:
    keep = lambda a, b: jnp.where(pred, a, b)
    return jtu.tree_map(keep, true, false)


def resolve_rcond(rcond, n, m, dtype):
    if rcond is None:
        # This `2 *` is a heuristic: I have seen very rare failures without it, in ways
        # that seem to depend on JAX compilation state. (E.g. running unrelated JAX
        # computations beforehand, in a completely different JIT-compiled region, can
        # result in differences in the success/failure of the solve.)
        return 2 * jnp.finfo(dtype).eps * max(n, m)
    else:
        return jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)


def jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, jac=None):
    if jac is None:
        # Heuristic for which is better in each case
        # These could probably be tuned a lot more.
        jac_fwd = (in_size < 100) or (in_size <= 1.5 * out_size)
    elif jac == "fwd":
        jac_fwd = True
    elif jac == "bwd":
        jac_fwd = False
    else:
        raise ValueError("`jac` should either be None, 'fwd', or 'bwd'.")
    if jac_fwd:
        return jax.jacfwd(fn, holomorphic=holomorphic, has_aux=has_aux)
    else:
        return jax.jacrev(fn, holomorphic=holomorphic, has_aux=has_aux)


def _asarray(dtype, x):
    return jnp.asarray(x, dtype=dtype)


# Work around JAX issue #15676
_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))


@_asarray.defjvp
def _asarray_jvp(dtype, x, tx):
    (x,) = x
    (tx,) = tx
    return _asarray(dtype, x), _asarray(dtype, tx)


def default_floating_dtype():
    if jax.config.jax_enable_x64:  # pyright: ignore
        return jnp.float64
    else:
        return jnp.float32


def inexact_asarray(x):
    dtype = jnp.result_type(x)
    if not jnp.issubdtype(jnp.result_type(x), jnp.inexact):
        dtype = default_floating_dtype()
    return _asarray(dtype, x)


def complex_to_real_dtype(dtype):
    return jnp.finfo(dtype).dtype


def strip_weak_dtype(tree: PyTree) -> PyTree:
    return jtu.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding)
        if type(x) is jax.ShapeDtypeStruct
        else x,
        tree,
    )


def structure_equal(x, y) -> bool:
    x = strip_weak_dtype(jax.eval_shape(lambda: x))
    y = strip_weak_dtype(jax.eval_shape(lambda: y))
    return eqx.tree_equal(x, y) is True
================================================
FILE: lineax/_norm.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools as ft
import math
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Inexact, PyTree, Scalar
from ._misc import complex_to_real_dtype, default_floating_dtype


def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Array, ""]:
    """Compute the dot product of two pytrees of arrays with the same pytree
    structure."""
    leaves1, treedef1 = jtu.tree_flatten(tree1)
    leaves2, treedef2 = jtu.tree_flatten(tree2)
    if treedef1 != treedef2:
        raise ValueError("trees must have the same structure")
    assert len(leaves1) == len(leaves2)
    dots = []
    for leaf1, leaf2 in zip(leaves1, leaves2):
        dots.append(
            jnp.dot(
                jnp.conj(leaf1).reshape(-1),
                jnp.reshape(leaf2, -1),
                precision=jax.lax.Precision.HIGHEST,  # pyright: ignore
            )
        )
    if len(dots) == 0:
        return jnp.array(0, default_floating_dtype())
    else:
        return ft.reduce(jnp.add, dots)


def sum_squares(x: PyTree[ArrayLike]) -> Scalar:
    """Computes the square of the L2 norm of a PyTree of arrays.

    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `Σ_i |x_i|^2`
    """
    return tree_dot(x, x).real


def two_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Computes the L2 norm of a PyTree of arrays.

    Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `sqrt(Σ_i |x_i|^2)`
    """
    # Wrap the `custom_jvp` into a function so that our autogenerated documentation
    # displays the docstring correctly.
    return _two_norm(x)


@jax.custom_jvp
def _two_norm(x: PyTree[ArrayLike]) -> Scalar:
    leaves = jtu.tree_leaves(x)
    size = sum([jnp.size(xi) for xi in leaves])
    if size == 1:
        # Avoid needless squaring-and-then-rooting.
        for leaf in leaves:
            if jnp.size(leaf) == 1:
                return jnp.abs(jnp.reshape(leaf, ()))
        else:
            assert False
    else:
        return jnp.sqrt(sum_squares(x))


@_two_norm.defjvp
def _two_norm_jvp(x, tx):
    (x,) = x
    (tx,) = tx
    out = two_norm(x)
    # Get zero gradient, rather than NaN gradient, in these cases.
    pred = (out == 0) | jnp.isinf(out)
    denominator = jnp.where(pred, 1, out)
    # We could also switch the dot and the division.
    # This approach is a bit more expensive (more divisions), but should be more
    # numerically stable (`x` and `denominator` should be of the same scale; `tx` is of
    # unknown scale).
    with jax.numpy_dtype_promotion("standard"):
        div = (x**ω / denominator).ω
    t_out = tree_dot(div, tx).real
    t_out = jnp.where(pred, 0, t_out)
    return out, t_out


def rms_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Compute the RMS (root-mean-squared) norm of a PyTree of arrays.

    This is the same as the L2 norm, averaged by the size of the input `x`. Considering
    the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
    `sqrt((Σ_i |x_i|^2)/n)`
    """
    leaves = jtu.tree_leaves(x)
    size = sum([jnp.size(xi) for xi in leaves])
    if size == 0:
        if len(leaves) == 0:
            dtype = default_floating_dtype()
        else:
            dtype = complex_to_real_dtype(jnp.result_type(*leaves))
        return jnp.array(0.0, dtype)
    else:
        return two_norm(x) / math.sqrt(size)


def max_norm(x: PyTree[ArrayLike]) -> Scalar:
    """Compute the L-infinity norm of a PyTree of arrays.

    This is the largest absolute elementwise value. Considering the input `x` as a flat
    vector `(x_1, ..., x_n)`, then this computes `max_i |x_i|`.
    """
    leaves = jtu.tree_leaves(x)
    leaf_maxes = [jnp.max(jnp.abs(xi)) for xi in leaves if jnp.size(xi) > 0]
    if len(leaf_maxes) == 0:
        if len(leaves) == 0:
            dtype = default_floating_dtype()
        else:
            dtype = complex_to_real_dtype(jnp.result_type(*leaves))
        return jnp.array(0.0, dtype)
    else:
        out = ft.reduce(jnp.maximum, leaf_maxes)
        return _zero_grad_at_zero(out)


@jax.custom_jvp
def _zero_grad_at_zero(x):
    return x


@_zero_grad_at_zero.defjvp
def _zero_grad_at_zero_jvp(primals, tangents):
    (out,) = primals
    (t_out,) = tangents
    return out, jnp.where(out == 0, 0, t_out)
================================================
FILE: lineax/_operator.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import functools as ft
import math
import warnings
from collections.abc import Callable, Iterable
from typing import Any, Literal, NoReturn, TypeVar
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.flatten_util as jfu
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from equinox.internal import ω
from jaxtyping import (
Array,
ArrayLike,
Inexact,
PyTree, # pyright: ignore
Scalar,
Shaped,
)
from ._custom_types import sentinel
from ._misc import (
default_floating_dtype,
inexact_asarray,
jacobian,
strip_weak_dtype,
)
from ._tags import (
diagonal_tag,
lower_triangular_tag,
negative_semidefinite_tag,
positive_semidefinite_tag,
symmetric_tag,
transpose_tags,
tridiagonal_tag,
unit_diagonal_tag,
upper_triangular_tag,
)
def _frozenset(x: object | Iterable[object]) -> frozenset[object]:
try:
iter_x = iter(x) # pyright: ignore
except TypeError:
return frozenset([x])
else:
return frozenset(iter_x)
class AbstractLinearOperator(eqx.Module):
"""Abstract base class for all linear operators.
Linear operators can act between PyTrees. Each `AbstractLinearOperator` is thought
of as a linear function `X -> Y`, where each element of `X` is a PyTree of
floating-point JAX arrays, and each element of `Y` is a PyTree of floating-point
JAX arrays.
Abstract linear operators support the following operations:
```python
op1 + op2 # addition of two operators
op1 - op2 # subtraction of two operators
op1 @ op2 # composition of two operators
op1 * 3.2 # multiplication by a scalar
op1 / 3.2 # division by a scalar
-op1 # negation of an operator
```
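As a sketch of what these operations mean, the following uses dense matrices in plain JAX to stand in for `op1` and `op2`. It shows the identities that the resulting operators' matrix-vector products satisfy:

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # stands in for `op1`
B = jnp.array([[0.0, 1.0], [1.0, 0.0]])  # stands in for `op2`
v = jnp.array([1.0, -1.0])

# Addition: (op1 + op2).mv(v) == op1.mv(v) + op2.mv(v)
assert jnp.allclose((A + B) @ v, A @ v + B @ v)
# Composition: (op1 @ op2).mv(v) == op1.mv(op2.mv(v))
assert jnp.allclose((A @ B) @ v, A @ (B @ v))
# Scalar multiplication: (op1 * 3.2).mv(v) == 3.2 * op1.mv(v); division is analogous.
assert jnp.allclose((A * 3.2) @ v, 3.2 * (A @ v))
```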
"""
def __check_init__(self):
if (
is_symmetric(self)
or is_positive_semidefinite(self)
or is_negative_semidefinite(self)
):
# In particular, we check that dtypes match.
in_structure = self.in_structure()
out_structure = self.out_structure()
# `is` check to handle the possibility of a tracer.
if eqx.tree_equal(in_structure, out_structure) is not True:
raise ValueError(
"Symmetric/Hermitian matrices must have matching input and output "
f"structures. Got input structure {in_structure} and output "
f"structure {out_structure}."
)
@abc.abstractmethod
def mv(
self, vector: PyTree[Inexact[Array, " _b"]]
) -> PyTree[Inexact[Array, " _a"]]:
"""Computes a matrix-vector product between this operator and a `vector`.
**Arguments:**
- `vector`: Should be some PyTree of floating-point arrays, whose structure
should match `self.in_structure()`.
**Returns:**
A PyTree of floating-point arrays, with structure that matches
`self.out_structure()`.
"""
@abc.abstractmethod
def as_matrix(self) -> Inexact[Array, "a b"]:
"""Materialises this linear operator as a matrix.
Note that this can be a computationally (time and/or memory) expensive
operation, as many linear operators are defined implicitly, e.g. in terms of
their action on a vector.
**Arguments:** None.
**Returns:**
A 2-dimensional floating-point JAX array.
"""
@abc.abstractmethod
def transpose(self) -> "AbstractLinearOperator":
"""Transposes this linear operator.
This can be called as either `operator.T` or `operator.transpose()`.
**Arguments:** None.
**Returns:**
Another [`lineax.AbstractLinearOperator`][].
"""
@abc.abstractmethod
def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
"""Returns the expected input structure of this linear operator.
**Arguments:** None.
**Returns:**
A PyTree of `jax.ShapeDtypeStruct`.
"""
@abc.abstractmethod
def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
"""Returns the expected output structure of this linear operator.
**Arguments:** None.
**Returns:**
A PyTree of `jax.ShapeDtypeStruct`.
"""
def in_size(self) -> int:
"""Returns the total number of scalars in the input of this linear operator.
That is, the dimensionality of its input space.
**Arguments:** None.
**Returns:** An integer.
"""
leaves = jtu.tree_leaves(self.in_structure())
return sum(math.prod(leaf.shape) for leaf in leaves) # pyright: ignore
def out_size(self) -> int:
"""Returns the total number of scalars in the output of this linear operator.
That is, the dimensionality of its output space.
**Arguments:** None.
**Returns:** An integer.
"""
leaves = jtu.tree_leaves(self.out_structure())
return sum(math.prod(leaf.shape) for leaf in leaves) # pyright: ignore
@property
def T(self) -> "AbstractLinearOperator":
"""Equivalent to [`lineax.AbstractLinearOperator.transpose`][]"""
return self.transpose()
def __add__(self, other) -> "AbstractLinearOperator":
if not isinstance(other, AbstractLinearOperator):
raise ValueError("Can only add AbstractLinearOperators together.")
return AddLinearOperator(self, other)
def __sub__(self, other) -> "AbstractLinearOperator":
if not isinstance(other, AbstractLinearOperator):
raise ValueError("Can only subtract AbstractLinearOperators from each other.")
return AddLinearOperator(self, -other)
def __mul__(self, other) -> "AbstractLinearOperator":
other = jnp.asarray(other)
if other.shape != ():
raise ValueError("Can only multiply AbstractLinearOperators by scalars.")
return MulLinearOperator(self, other)
def __rmul__(self, other) -> "AbstractLinearOperator":
return self * other
def __matmul__(self, other) -> "AbstractLinearOperator":
if not isinstance(other, AbstractLinearOperator):
raise ValueError("Can only compose AbstractLinearOperators together.")
return ComposedLinearOperator(self, other)
def __truediv__(self, other) -> "AbstractLinearOperator":
other = jnp.asarray(other)
if other.shape != ():
raise ValueError("Can only divide AbstractLinearOperators by scalars.")
return DivLinearOperator(self, other)
def __neg__(self) -> "AbstractLinearOperator":
return NegLinearOperator(self)
class MatrixLinearOperator(AbstractLinearOperator):
"""Wraps a 2-dimensional JAX array into a linear operator.
If the matrix has shape `(a, b)` then matrix-vector multiplication (`self.mv`) is
defined in the usual way: as a matrix-vector product that accepts a vector of
shape `(b,)` and returns a vector of shape `(a,)`.
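Concretely, for a matrix of shape `(3, 2)`, a vector of shape `(2,)` maps to one of shape `(3,)`. The following is a minimal sketch in plain JAX of the dense product that `mv` performs, together with the input and output structures that this operator would report:

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

matrix = jnp.arange(6.0).reshape(3, 2)  # shape (a, b) = (3, 2)
vector = jnp.ones(2)  # shape (b,)
# The dense matrix-vector product, as computed by `mv`.
out = jnp.matmul(matrix, vector, precision=lax.Precision.HIGHEST)  # shape (a,)
# The structures described by `in_structure()` and `out_structure()`.
in_struct = jax.ShapeDtypeStruct((matrix.shape[1],), matrix.dtype)
out_struct = jax.ShapeDtypeStruct((matrix.shape[0],), matrix.dtype)
```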
"""
matrix: Inexact[Array, "a b"]
tags: frozenset[object] = eqx.field(static=True)
def __init__(
self, matrix: Shaped[Array, "a b"], tags: object | frozenset[object] = ()
):
"""**Arguments:**
- `matrix`: a two-dimensional JAX array. For an array with shape `(a, b)`, this
operator can perform matrix-vector products on a vector of shape `(b,)` to
return a vector of shape `(a,)`.
- `tags`: any tags indicating whether this matrix has any particular properties,
like symmetry or positive-definite-ness. Note that these properties are
unchecked and you may get incorrect values elsewhere if these tags are
wrong.
"""
if jnp.ndim(matrix) != 2:
raise ValueError(
"`MatrixLinearOperator(matrix=...)` should be 2-dimensional."
)
if not jnp.issubdtype(matrix.dtype, jnp.inexact):
matrix = matrix.astype(jnp.float32)
self.matrix = matrix
self.tags = _frozenset(tags)
def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
return jnp.matmul(self.matrix, vector, precision=lax.Precision.HIGHEST)
def as_matrix(self):
return self.matrix
def transpose(self):
if is_symmetric(self):
return self
return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))
def in_structure(self):
_, in_size = jnp.shape(self.matrix)
return jax.ShapeDtypeStruct(shape=(in_size,), dtype=self.matrix.dtype)
def out_structure(self):
out_size, _ = jnp.shape(self.matrix)
return jax.ShapeDtypeStruct(shape=(out_size,), dtype=self.matrix.dtype)
def _matmul(matrix: ArrayLike, vector: ArrayLike) -> Array:
# matrix has structure [leaf(out), leaf(in)]
# vector has structure [leaf(in)]
# return has structure [leaf(out)]
return jnp.tensordot(
matrix, vector, axes=jnp.ndim(vector), precision=lax.Precision.HIGHEST
)
def _tree_matmul(matrix: PyTree[ArrayLike], vector: PyTree[ArrayLike]) -> PyTree[Array]:
# matrix has structure [tree(in), leaf(out), leaf(in)]
# vector has structure [tree(in), leaf(in)]
# return has structure [leaf(out)]
matrix = jtu.tree_leaves(matrix)
vector = jtu.tree_leaves(vector)
assert len(matrix) == len(vector)
return sum([_matmul(m, v) for m, v in zip(matrix, vector)])
# Needed as static fields must be hashable and eq-able, and custom pytrees might
# e.g. define custom __eq__ methods.
_T = TypeVar("_T")
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]
def _inexact_structure_impl2(x):
if jnp.issubdtype(x.dtype, jnp.inexact):
return x
else:
return x.astype(default_floating_dtype())
def _inexact_structure_impl(x):
return jtu.tree_map(_inexact_structure_impl2, x)
def _inexact_structure(x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.ShapeDtypeStruct]:
return strip_weak_dtype(jax.eval_shape(_inexact_structure_impl, x))
class _Leaf: # not a pytree
def __init__(self, value):
self.value = value
# The `{input,output}_structure`s have to be static because otherwise abstract
# evaluation rules will promote them to ShapedArrays.
class PyTreeLinearOperator(AbstractLinearOperator):
"""Represents a PyTree of floating-point JAX arrays as a linear operator.
This is basically a generalisation of [`lineax.MatrixLinearOperator`][], from
taking just a single array to taking a PyTree-of-arrays. (And likewise from
returning a single array to returning a PyTree-of-arrays.)
Specifically, suppose we want this to be a linear operator `X -> Y`, for which
elements of `X` are PyTrees with structure `T` whose `i`th leaf is a floating-point
JAX array of shape `x_shape_i`, and elements of `Y` are PyTrees with structure `S`
whose `j`th leaf is a floating-point JAX array of shape `y_shape_j`. Then the
input PyTree should have outer structure `S`, with each of its leaves replaced by a
PyTree of structure `T`. The leaf at output position `j` and input position `i`
should be a floating-point JAX array of shape `(*y_shape_j, *x_shape_i)`.
!!! Example
```python
# Suppose `x` is a member of our input space, with the following pytree
# structure:
eqx.tree_pprint(x) # [f32[5, 9], f32[3]]
# Suppose `y` is a member of our output space, with the following pytree
# structure:
eqx.tree_pprint(y) # {"a": f32[1, 2]}
# then `pytree` should be a pytree with the following structure:
eqx.tree_pprint(pytree) # {"a": [f32[1, 2, 5, 9], f32[1, 2, 3]]}
```
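The matrix-vector product for a `pytree` laid out like this can be sketched in plain JAX as follows. This is an illustrative sketch of the tensor contraction, not a call into this class. Each output leaf receives its subtree of blocks, each block is contracted against the matching input leaf over its trailing (input) axes, and the results are summed:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# An input-space element `x`: [f32[5, 9], f32[3]]. Output space: {"a": f32[1, 2]}.
x = [jnp.ones((5, 9)), jnp.ones(3)]
out_struct = {"a": jax.ShapeDtypeStruct((1, 2), jnp.float32)}
# One block per (output leaf, input leaf) pair, of shape (*y_shape_j, *x_shape_i).
pytree = {"a": [jnp.ones((1, 2, 5, 9)), jnp.ones((1, 2, 3))]}


def contract(sub, x):
    # Contract each block over its trailing (input) axes, then sum the blocks.
    blocks = zip(jtu.tree_leaves(sub), jtu.tree_leaves(x))
    return sum(jnp.tensordot(m, v, axes=v.ndim) for m, v in blocks)


# `out_struct` is a prefix of `pytree`, so each output leaf receives its subtree.
y = jtu.tree_map(lambda _, sub: contract(sub, x), out_struct, pytree)
```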
"""
pytree: PyTree[Inexact[Array, "..."]]
output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
tags: frozenset[object] = eqx.field(static=True)
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
def __init__(
self,
pytree: PyTree[ArrayLike],
output_structure: PyTree[jax.ShapeDtypeStruct],
tags: object | frozenset[object] = (),
):
"""**Arguments:**
- `pytree`: this should be a PyTree, with structure as specified in
[`lineax.PyTreeLinearOperator`][].
- `output_structure`: the structure of the output space. This should be a PyTree
of `jax.ShapeDtypeStruct`s. (The structure of the input space is then
automatically derived from the structure of `pytree`.)
- `tags`: any tags indicating whether this operator has any particular
properties, like symmetry or positive-definite-ness. Note that these
properties are unchecked and you may get incorrect values elsewhere if these
tags are wrong.
"""
output_structure = _inexact_structure(output_structure)
self.pytree = jtu.tree_map(inexact_asarray, pytree)
self.output_structure = jtu.tree_flatten(output_structure)
self.tags = _frozenset(tags)
# self.out_structure() has structure [tree(out)]
# self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]
def get_structure(struct, subpytree):
# subpytree has structure [tree(in), leaf(out), leaf(in)]
def sub_get_structure(leaf):
shape = jnp.shape(leaf) # [leaf(out), leaf(in)]
ndim = len(struct.shape)
if shape[:ndim] != struct.shape:
raise ValueError(
"`pytree` and `output_structure` are not consistent"
)
return jax.ShapeDtypeStruct(
shape=shape[ndim:], dtype=jnp.result_type(leaf)
)
return _Leaf(jtu.tree_map(sub_get_structure, subpytree))
if output_structure is None:
# Implies that len(input_structures) > 0
raise ValueError("Cannot have trivial output_structure")
input_structures = jtu.tree_map(get_structure, output_structure, self.pytree)
input_structures = jtu.tree_leaves(input_structures)
input_structure = input_structures[0].value
for val in input_structures[1:]:
if eqx.tree_equal(input_structure, val.value) is not True:
raise ValueError(
"`pytree` does not have a consistent `input_structure`"
)
self.input_structure = jtu.tree_flatten(input_structure)
def mv(self, vector):
# vector has structure [tree(in), leaf(in)]
# self.out_structure() has structure [tree(out)]
# self.pytree has structure [tree(out), tree(in), leaf(out), leaf(in)]
# return has structure [tree(out), leaf(out)]
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
def matmul(_, matrix):
return _tree_matmul(matrix, vector)
return jtu.tree_map(matmul, self.out_structure(), self.pytree)
def as_matrix(self):
with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(*jtu.tree_leaves(self.pytree))
def concat_in(struct, subpytree):
leaves = jtu.tree_leaves(subpytree)
assert all(leaf.shape[: struct.ndim] == struct.shape for leaf in leaves)
leaves = [
leaf.astype(dtype).reshape(
struct.size, math.prod(leaf.shape[struct.ndim :])
)
for leaf in leaves
]
return jnp.concatenate(leaves, axis=1)
matrix = jtu.tree_map(concat_in, self.out_structure(), self.pytree)
matrix = jtu.tree_leaves(matrix)
return jnp.concatenate(matrix, axis=0)
def transpose(self):
if is_symmetric(self):
return self
def _transpose(struct, subtree):
def _transpose_impl(leaf):
return jnp.moveaxis(leaf, source, dest)
source = list(range(struct.ndim))
dest = list(range(-struct.ndim, 0))
return jtu.tree_map(_transpose_impl, subtree)
pytree_transpose = jtu.tree_map(_transpose, self.out_structure(), self.pytree)
pytree_transpose = jtu.tree_transpose(
jtu.tree_structure(self.out_structure()),
jtu.tree_structure(self.in_structure()),
pytree_transpose,
)
return PyTreeLinearOperator(
pytree_transpose, self.in_structure(), transpose_tags(self.tags)
)
def in_structure(self):
leaves, treedef = self.input_structure
return jtu.tree_unflatten(treedef, leaves)
def out_structure(self):
leaves, treedef = self.output_structure
return jtu.tree_unflatten(treedef, leaves)
class DiagonalLinearOperator(AbstractLinearOperator):
"""A diagonal linear operator, e.g. for a diagonal matrix. Only the diagonal is
stored (for memory efficiency). Matrix-vector products are computed by doing a
pointwise diagonal * vector, rather than a full matrix @ vector (for speed).
The diagonal may also be a PyTree, rather than a 1D array. When materialising the
matrix, the diagonal is taken to be defined by the flattened PyTree (i.e. values
show up in the same order.)
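For example, the following minimal sketch in plain JAX checks that the leafwise product agrees with multiplying by the materialised diagonal matrix. It assumes the PyTree is flattened in the order used by `jax.flatten_util.ravel_pytree`:

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.flatten_util import ravel_pytree

diag = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
vec = {"a": jnp.array([4.0, 5.0]), "b": jnp.array([6.0])}
# Matrix-vector product: a pointwise product, leaf by leaf.
out = jtu.tree_map(lambda d, v: d * v, diag, vec)
# Materialised: the flattened diagonal placed on the diagonal of a matrix.
flat_diag, _ = ravel_pytree(diag)
flat_vec, _ = ravel_pytree(vec)
flat_out, _ = ravel_pytree(out)
assert jnp.allclose(jnp.diag(flat_diag) @ flat_vec, flat_out)
```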
"""
diagonal: PyTree[Inexact[Array, "..."]]
def __init__(self, diagonal: PyTree[ArrayLike]):
"""**Arguments:**
- `diagonal`: an array or PyTree defining the diagonal of the matrix.
"""
self.diagonal = jtu.tree_map(inexact_asarray, diagonal)
def mv(self, vector):
return (ω(self.diagonal) * ω(vector)).ω
def as_matrix(self):
return jnp.diag(diagonal(self))
def transpose(self):
return self
def in_structure(self):
return jax.eval_shape(lambda: self.diagonal)
def out_structure(self):
return jax.eval_shape(lambda: self.diagonal)
class _NoAuxIn(eqx.Module):
fn: Callable
args: Any
def __call__(self, x):
return self.fn(x, self.args)
class _Unwrap(eqx.Module):
fn: Callable
def __call__(self, x):
(f,) = self.fn(x)
return f
class JacobianLinearOperator(AbstractLinearOperator):
"""Given a function `fn: X -> Y`, and a point `x in X`, then this defines the
linear operator (also a function `X -> Y`) given by the Jacobian `(d(fn)/dx)(x)`.
For example if the inputs and outputs are just arrays, then this is equivalent to
`MatrixLinearOperator(jax.jacfwd(fn)(x))`.
The Jacobian is not materialised; matrix-vector products, which are in fact
Jacobian-vector products, are computed using autodifferentiation. By default
(or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with
`jax.linear_transpose`, which works even with functions
that only define a custom VJP (via `jax.custom_vjp`) and don't support
forward-mode differentiation.
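The two approaches can be compared with the following minimal sketch in plain JAX, which uses a hypothetical `fn`. It shows that a forward-mode JVP and the transpose of the VJP both produce the same Jacobian-vector product `J @ v`:

```python
import jax
import jax.numpy as jnp


def fn(x, args):  # hypothetical example function `(x, args) -> y`
    return jnp.sin(x) * args


x = jnp.array([0.1, 0.2, 0.3])
args = 2.0
v = jnp.array([1.0, 0.0, -1.0])
f = lambda x: fn(x, args)

# Forward mode (as for `jac="fwd"` or `None`): a Jacobian-vector product.
_, jv_fwd = jax.jvp(f, (x,), (v,))
# Reverse mode (as for `jac="bwd"`): transpose the linear VJP function.
_, vjp_fn = jax.vjp(f, x)
out_struct = jax.eval_shape(f, x)
transpose_vjp = jax.linear_transpose(lambda g: vjp_fn(g)[0], out_struct)
(jv_bwd,) = transpose_vjp(v)

J = jax.jacfwd(f)(x)
assert jnp.allclose(jv_fwd, J @ v)
assert jnp.allclose(jv_bwd, J @ v)
```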
See also [`lineax.materialise`][], which materialises the whole Jacobian in
memory.
!!! tip
For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
the primal computation, e.g. for `jac="fwd"/None` it returns
`_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
"""
fn: Callable[
[PyTree[Inexact[Array, "..."]], PyTree[Any]], PyTree[Inexact[Array, "..."]]
]
x: PyTree[Inexact[Array, "..."]]
args: PyTree[Any]
tags: frozenset[object] = eqx.field(static=True)
jac: Literal["fwd", "bwd"] | None
@eqxi.doc_remove_args("closure_convert")
def __init__(
self,
fn: Callable,
x: PyTree[ArrayLike],
args: PyTree[Any] = None,
tags: object | Iterable[object] = (),
jac: Literal["fwd", "bwd"] | None = None,
closure_convert: bool = True,
):
"""**Arguments:**
- `fn`: A function `(x, args) -> y`. The Jacobian `d(fn)/dx` is used as the
linear operator, and `args` are just any other arguments that should not be
differentiated.
- `x`: The point to evaluate `d(fn)/dx` at: `(d(fn)/dx)(x, args)`.
- `args`: As `x`; this is the point to evaluate `d(fn)/dx` at:
`(d(fn)/dx)(x, args)`.
- `tags`: any tags indicating whether this operator has any particular
properties, like symmetry or positive-definite-ness. Note that these
properties are unchecked and you may get incorrect values elsewhere if these
tags are wrong.
- `jac`: selects a specific Jacobian computation method. `jac="fwd"` forces
`jax.jacfwd` to be used; similarly, `jac="bwd"` mandates the use of
`jax.jacrev`. If not specified, the method is chosen automatically according
to the input and output shapes.
"""
if jac not in [None, "fwd", "bwd"]:
raise ValueError(
"`jac` argument of `JacobianLinearOperator` should be either "
"`'fwd'`, `'bwd'`, or `None`."
)
# Flush out any closed-over values, so that we can safely pass `self`
# across API boundaries. (In particular, across `linear_solve_p`.)
# We don't use `jax.closure_convert` as that only flushes autodiffable
# (=floating-point) constants. It probably doesn't matter, but if `fn` is a
# PyTree capturing non-floating-point constants, we should probably continue
# to respect that, and keep any non-floating-point constants as part of the
# PyTree structure.
x = jtu.tree_map(inexact_asarray, x)
if closure_convert:
fn = eqx.filter_closure_convert(fn, x, args)
self.fn = fn
self.x = x
self.args = args
self.tags = _frozenset(tags)
self.jac = jac
def mv(self, vector):
fn = _NoAuxIn(self.fn, self.args)
if self.jac == "fwd" or self.jac is None:
_, out = jax.jvp(fn, (self.x,), (vector,))
elif self.jac == "bwd":
# Use VJP + linear_transpose instead of materialising the full Jacobian.
# This works even for custom_vjp functions that don't have JVP rules.
_, vjp_fn = jax.vjp(fn, self.x)
if is_symmetric(self):
# For symmetric operators, J = J.T, so vjp directly gives J @ v
(out,) = vjp_fn(vector)
else:
# For non-symmetric, transpose the VJP to get J @ v from J.T @ v
transpose_vjp = jax.linear_transpose(
lambda g: vjp_fn(g)[0], self.out_structure()
)
(out,) = transpose_vjp(vector)
else:
raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
return out
def as_matrix(self):
return materialise(self).as_matrix()
def transpose(self):
if is_symmetric(self):
return self
fn = _NoAuxIn(self.fn, self.args)
# Works because vjpfn is a PyTree
_, vjpfn = jax.vjp(fn, self.x)
vjpfn = _Unwrap(vjpfn)
return FunctionLinearOperator(
vjpfn, self.out_structure(), transpose_tags(self.tags)
)
def in_structure(self):
return strip_weak_dtype(jax.eval_shape(lambda: self.x))
def out_structure(self):
fn = _NoAuxIn(self.fn, self.args)
return strip_weak_dtype(eqxi.cached_filter_eval_shape(fn, self.x))
# `input_structure` must be static as with `JacobianLinearOperator`
class FunctionLinearOperator(AbstractLinearOperator):
"""Wraps a *linear* function `fn: X -> Y` into a linear operator. (So that
`self.mv(x)` is defined by `self.mv(x) == fn(x)`.)
See also [`lineax.materialise`][], which materialises the whole linear operator
in memory. (Similar to `.as_matrix()`.)
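As an illustration, here is a minimal sketch in plain JAX using a hypothetical linear `fn`. It shows that `mv` is simply `fn`, and that `jax.linear_transpose` yields the transposed map. It also materialises `fn` by applying it to each basis vector. That is one way to materialise a linear function, not necessarily how `lineax.materialise` is implemented:

```python
import jax
import jax.numpy as jnp

M = jnp.array([[1.0, 2.0, 0.0], [0.0, 1.0, 3.0]])
fn = lambda x: M @ x  # hypothetical linear function R^3 -> R^2
in_struct = jax.ShapeDtypeStruct((3,), jnp.float32)

# `mv(x)` is just `fn(x)`.
x = jnp.array([1.0, 1.0, 1.0])
y_out = fn(x)

# Transposing the linear function.
transpose_fn = jax.linear_transpose(fn, in_struct)
y = jnp.array([1.0, -1.0])
(xt,) = transpose_fn(y)
assert jnp.allclose(xt, M.T @ y)

# Materialising: column i of the matrix is fn(e_i).
matrix = jax.vmap(fn, out_axes=1)(jnp.eye(3))
```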
"""
fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]]
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
tags: frozenset[object] = eqx.field(static=True)
@eqxi.doc_remove_args("closure_convert")
def __init__(
self,
fn: Callable[[PyTree[Inexact[Array, "..."]]], PyTree[Inexact[Array, "..."]]],
input_structure: PyTree[jax.ShapeDtypeStruct],
tags: object | Iterable[object] = (),
closure_convert: bool = True,
):
"""**Arguments:**
- `fn`: a linear function. Should accept a PyTree of floating-point JAX arrays,
and return a PyTree of floating-point JAX arrays.
- `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
structure of the input to the function. (When later calling `self.mv(x)`
then this should match the structure of `x`, i.e.
`jax.eval_shape(lambda: x)`.)
- `tags`: any tags indicating whether this operator has any particular
properties, like symmetry or positive-definite-ness. Note that these
properties are unchecked and you may get incorrect values elsewhere if these
tags are wrong.
"""
# See matching comment in JacobianLinearOperator.
input_structure = _inexact_structure(input_structure)
if closure_convert:
fn = eqx.filter_closure_convert(fn, input_structure)
self.fn = fn
self.input_structure = jtu.tree_flatten(input_structure)
self.tags = _frozenset(tags)
def mv(self, vector):
return self.fn(vector)
def as_matrix(self):
return materialise(self).as_matrix()
def transpose(self):
if is_symmetric(self):
return self
transpose_fn = jax.linear_transpose(self.fn, self.in_structure())
def _transpose_fn(vector):
(out,) = transpose_fn(vector)
return out
# Works because transpose_fn is a PyTree
return FunctionLinearOperator(
_transpose_fn, self.out_structure(), transpose_tags(self.tags)
)
def in_structure(self):
leaves, treedef = self.input_structure
return jtu.tree_unflatten(treedef, leaves)
def out_structure(self):
return strip_weak_dtype(
eqxi.cached_filter_eval_shape(self.fn, self.in_structure())
)
# `structure` must be static as with `JacobianLinearOperator`
class IdentityLinearOperator(AbstractLinearOperator):
"""Represents the identity transformation `X -> X`, where each `x in X` is some
PyTree of floating-point JAX arrays.
"""
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
output_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
def __init__(
self,
input_structure: PyTree[jax.ShapeDtypeStruct],
output_structure: PyTree[jax.ShapeDtypeStruct] = sentinel,
):
"""**Arguments:**
- `input_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
structure of the input space. (When later calling `self.mv(x)`
then this should match the structure of `x`, i.e.
`jax.eval_shape(lambda: x)`.)
- `output_structure`: A PyTree of `jax.ShapeDtypeStruct`s specifying the
structure of the output space. If not passed then this defaults to the
same as `input_structure`. If passed then it must have the same number of
elements as `input_structure`, so that the operator is square.
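When `output_structure` differs from `input_structure`, `mv` does not simply return its input. It flattens the input leaves into a single vector, zero-pads or truncates that vector to the output size, and then splits and reshapes it into the output structure. The following is a minimal sketch of those steps in plain JAX, for structures that differ in layout but have the same total size:

```python
import math

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

# Input: {"a": f32[2], "b": f32[2]}. Output: (f32[1], f32[3]). Both have 4 elements.
x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0, 4.0])}
out_struct = (
    jax.ShapeDtypeStruct((1,), jnp.float32),
    jax.ShapeDtypeStruct((3,), jnp.float32),
)

# 1. Flatten every input leaf into a single vector.
flat = jnp.concatenate([leaf.reshape(-1) for leaf in jtu.tree_leaves(x)])
out_leaves, treedef = jtu.tree_flatten(out_struct)
out_size = sum(math.prod(s.shape) for s in out_leaves)
# 2. Zero-pad or truncate to the output size (a no-op here, as the sizes match).
pad = jnp.zeros(max(out_size - flat.size, 0), flat.dtype)
flat = jnp.concatenate([flat, pad])[:out_size]
# 3. Split at the leaf boundaries and reshape into the output structure.
boundaries = np.cumsum([math.prod(s.shape) for s in out_leaves[:-1]])
pieces = jnp.split(flat, boundaries)
y = jtu.tree_unflatten(
    treedef, [p.reshape(s.shape) for p, s in zip(pieces, out_leaves)]
)
```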
"""
if output_structure is sentinel:
output_structure = input_structure
input_structure = _inexact_structure(input_structure)
output_structure = _inexact_structure(output_structure)
self.input_structure = jtu.tree_flatten(input_structure)
self.output_structure = jtu.tree_flatten(output_structure)
def mv(self, vector):
if not eqx.tree_equal(
strip_weak_dtype(jax.eval_shape(lambda: vector)),
strip_weak_dtype(self.in_structure()),
):
raise ValueError("Vector and operator structures do not match")
elif self.input_structure == self.output_structure:
return vector # fast-path for common special case
else:
# TODO(kidger): this could be done slightly more efficiently, by iterating
# leaf-by-leaf.
leaves = jtu.tree_leaves(vector)
with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(*leaves)
vector = jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])
out_size = self.out_size()
if vector.size < out_size:
vector = jnp.concatenate(
[vector, jnp.zeros(out_size - vector.size, vector.dtype)]
)
else:
vector = vector[:out_size]
leaves, treedef = jtu.tree_flatten(self.out_structure())
sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])
split = jnp.split(vector, sizes)
assert len(split) == len(leaves)
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
shaped = [
x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)
]
return jtu.tree_unflatten(treedef, shaped)
def as_matrix(self):
leaves = jtu.tree_leaves(self.in_structure())
with jax.numpy_dtype_promotion("standard"):
dtype = (
default_floating_dtype()
if len(leaves) == 0
else jnp.result_type(*leaves)
)
return jnp.eye(self.out_size(), self.in_size(), dtype=dtype)
def transpose(self):
return IdentityLinearOperator(self.out_structure(), self.in_structure())
def in_structure(self):
leaves, treedef = self.input_structure
return jtu.tree_unflatten(treedef, leaves)
def out_structure(self):
leaves, treedef = self.output_structure
return jtu.tree_unflatten(treedef, leaves)
@property
def tags(self):
return frozenset()
class TridiagonalLinearOperator(AbstractLinearOperator):
"""As [`lineax.MatrixLinearOperator`][], but for specifically a tridiagonal
matrix.
"""
diagonal: Inexact[Array, " size"]
lower_diagonal: Inexact[Array, " size-1"]
upper_diagonal: Inexact[Array, " size-1"]
def __init__(
self,
diagonal: Inexact[Array, " size"],
lower_diagonal: Inexact[Array, " size-1"],
upper_diagonal: Inexact[Array, " size-1"],
):
"""**Arguments:**
- `diagonal`: A rank-one JAX array. This is the diagonal of the matrix.
- `lower_diagonal`: A rank-one JAX array. This is the lower diagonal of the
matrix.
- `upper_diagonal`: A rank-one JAX array. This is the upper diagonal of the
matrix.
If `diagonal` has shape `(a,)` then `lower_diagonal` and `upper_diagonal` should
both have shape `(a - 1,)`.
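For example, the following minimal sketch in plain JAX computes the matrix-free product from three shifted elementwise products, in the same way as `mv`. It then checks the result against the dense tridiagonal matrix, where `lower_diagonal` sits below the main diagonal and `upper_diagonal` above it:

```python
import jax.numpy as jnp

d = jnp.array([2.0, 2.0, 2.0])  # diagonal, shape (a,)
lower = jnp.array([-1.0, -1.0])  # lower diagonal, shape (a - 1,)
upper = jnp.array([0.5, 0.5])  # upper diagonal, shape (a - 1,)
v = jnp.array([1.0, 2.0, 3.0])

# Matrix-free product: three shifted elementwise products.
out = (d * v).at[:-1].add(upper * v[1:]).at[1:].add(lower * v[:-1])

# Dense equivalent, for comparison.
dense = jnp.diag(d) + jnp.diag(lower, k=-1) + jnp.diag(upper, k=1)
assert jnp.allclose(out, dense @ v)
```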
"""
self.diagonal = inexact_asarray(diagonal)
self.lower_diagonal = inexact_asarray(lower_diagonal)
self.upper_diagonal = inexact_asarray(upper_diagonal)
(size,) = self.diagonal.shape
if self.lower_diagonal.shape != (size - 1,):
raise ValueError("lower_diagonal and diagonal do not have consistent size")
if self.upper_diagonal.shape != (size - 1,):
raise ValueError("upper_diagonal and diagonal do not have consistent size")
def mv(self, vector):
a = self.upper_diagonal * vector[1:]
b = self.diagonal * vector
c = self.lower_diagonal * vector[:-1]
return b.at[:-1].add(a).at[1:].add(c)
def as_matrix(self):
(size,) = jnp.shape(self.diagonal)
matrix = jnp.zeros((size, size), self.diagonal.dtype)
arange = np.arange(size)
matrix = matrix.at[arange, arange].set(self.diagonal)
matrix = matrix.at[arange[1:], arange[:-1]].set(self.lower_diagonal)
matrix = matrix.at[arange[:-1], arange[1:]].set(self.upper_diagonal)
return matrix
def transpose(self):
return TridiagonalLinearOperator(
self.diagonal, self.upper_diagonal, self.lower_diagonal
)
def in_structure(self):
(size,) = jnp.shape(self.diagonal)
return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype)
def out_structure(self):
(size,) = jnp.shape(self.diagonal)
return jax.ShapeDtypeStruct(shape=(size,), dtype=self.diagonal.dtype)
class TaggedLinearOperator(AbstractLinearOperator):
"""Wraps another linear operator and specifies that it has certain tags, e.g.
representing symmetry.
!!! Example
```python
# Some other operator.
operator = lx.MatrixLinearOperator(some_jax_array)
# Now symmetric! But the type system doesn't know this.
sym_operator = operator + operator.T
assert lx.is_symmetric(sym_operator) == False
# We can declare that our operator has a particular property.
sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag)
assert lx.is_symmetric(sym_operator) == True
```
"""
operator: AbstractLinearOperator
tags: frozenset[object] = eqx.field(static=True)
def __init__(
self, operator: AbstractLinearOperator, tags: object | Iterable[object]
):
"""**Arguments:**
- `operator`: some other linear operator to wrap.
- `tags`: any tags indicating whether this operator has any particular
properties, like symmetry or positive-definite-ness. Note that these
properties are unchecked and you may get incorrect values elsewhere if these
tags are wrong.
"""
self.operator = operator
self.tags = _frozenset(tags)
def mv(self, vector):
return self.operator.mv(vector)
def as_matrix(self):
return self.operator.as_matrix()
def transpose(self):
return TaggedLinearOperator(
self.operator.transpose(), transpose_tags(self.tags)
)
def in_structure(self):
return self.operator.in_structure()
def out_structure(self):
return self.operator.out_structure()
#
# All operators below here are private to lineax.
#
def _is_none(x):
return x is None
class TangentLinearOperator(AbstractLinearOperator):
"""Internal to lineax. Used to represent the tangent (jvp) computation with
respect to the linear operator in a linear solve.
"""
primal: AbstractLinearOperator
tangent: AbstractLinearOperator
def __check_init__(self):
assert type(self.primal) is type(self.tangent) # noqa: E721
def mv(self, vector):
mv = lambda operator: operator.mv(vector)
out, t_out = eqx.filter_jvp(mv, (self.primal,), (self.tangent,))
return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none)
def as_matrix(self):
as_matrix = lambda operator: operator.as_matrix()
out, t_out = eqx.filter_jvp(as_matrix, (self.primal,), (self.tangent,))
return jtu.tree_map(eqxi.materialise_zeros, out, t_out, is_leaf=_is_none)
def transpose(self):
transpose = lambda operator: operator.transpose()
primal_out, tangent_out = eqx.filter_jvp(
transpose, (self.primal,), (self.tangent,)
)
return TangentLinearOperator(primal_out, tangent_out)
def in_structure(self):
return self.primal.in_structure()
def out_structure(self):
return self.primal.out_structure()
class AddLinearOperator(AbstractLinearOperator):
"""A linear operator formed by adding two other linear operators together.
!!! Example
```python
x = MatrixLinearOperator(...)
y = MatrixLinearOperator(...)
assert isinstance(x + y, AddLinearOperator)
```
"""
operator1: AbstractLinearOperator
operator2: AbstractLinearOperator
def __check_init__(self):
if self.operator1.in_structure() != self.operator2.in_structure():
raise ValueError("Incompatible linear operator structures")
if self.operator1.out_structure() != self.operator2.out_structure():
raise ValueError("Incompatible linear operator structures")
def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
mv1 = self.operator1.mv(vector)
mv2 = self.operator2.mv(vector)
return (mv1**ω + mv2**ω).ω
def as_matrix(self):
return self.operator1.as_matrix() + self.operator2.as_matrix()
def transpose(self):
return self.operator1.transpose() + self.operator2.transpose()
def in_structure(self):
return self.operator1.in_structure()
def out_structure(self):
return self.operator1.out_structure()
class MulLinearOperator(AbstractLinearOperator):
"""A linear operator formed by multiplying a linear operator by a scalar.
!!! Example
```python
x = MatrixLinearOperator(...)
y = 0.5
assert isinstance(x * y, MulLinearOperator)
```
"""
operator: AbstractLinearOperator
scalar: Scalar
def mv(self, vector):
return (self.operator.mv(vector) ** ω * self.scalar).ω
def as_matrix(self):
return self.operator.as_matrix() * self.scalar
def transpose(self):
return self.operator.transpose() * self.scalar
def in_structure(self):
return self.operator.in_structure()
def out_structure(self):
return self.operator.out_structure()
# Not just `MulLinearOperator(..., -1)` for compatibility with
# `jax_numpy_dtype_promotion=strict`.
class NegLinearOperator(AbstractLinearOperator):
"""A linear operator formed by computing the negative of a linear operator.
!!! Example
```python
x = MatrixLinearOperator(...)
assert isinstance(-x, NegLinearOperator)
```
"""
operator: AbstractLinearOperator
def mv(self, vector):
return (-(self.operator.mv(vector) ** ω)).ω
def as_matrix(self):
return -self.operator.as_matrix()
def transpose(self):
return -self.operator.transpose()
def in_structure(self):
return self.operator.in_structure()
def out_structure(self):
return self.operator.out_structure()
class DivLinearOperator(AbstractLinearOperator):
"""A linear operator formed by dividing a linear operator by a scalar.
!!! Example
```python
x = MatrixLinearOperator(...)
y = 0.5
assert isinstance(x / y, DivLinearOperator)
```
"""
operator: AbstractLinearOperator
scalar: Scalar
def mv(self, vector):
with jax.numpy_dtype_promotion("standard"):
return (self.operator.mv(vector) ** ω / self.scalar).ω
def as_matrix(self):
return self.operator.as_matrix() / self.scalar
def transpose(self):
return self.operator.transpose() / self.scalar
def in_structure(self):
return self.operator.in_structure()
def out_structure(self):
return self.operator.out_structure()
class ComposedLinearOperator(AbstractLinearOperator):
"""A linear operator formed by composing (matrix-multiplying) two other linear
operators together.
!!! Example
```python
x = MatrixLinearOperator(matrix1)
y = MatrixLinearOperator(matrix2)
composed = x @ y
assert isinstance(composed, ComposedLinearOperator)
assert jnp.allclose(composed.as_matrix(), matrix1 @ matrix2)
```
"""
operator1: AbstractLinearOperator
operator2: AbstractLinearOperator
def __check_init__(self):
if self.operator1.in_structure() != self.operator2.out_structure():
raise ValueError("Incompatible linear operator structures")
def mv(self, vector):
maybe_sparse_op = _try_sparse_materialise(self)
if maybe_sparse_op is not self:
return maybe_sparse_op.mv(vector)
return self.operator1.mv(self.operator2.mv(vector))
def as_matrix(self):
if isinstance(self.operator1, IdentityLinearOperator):
return self.operator2.as_matrix()
if isinstance(self.operator2, IdentityLinearOperator):
return self.operator1.as_matrix()
_, unravel = eqx.filter_eval_shape(
jfu.ravel_pytree, self.operator1.in_structure()
)
def mv_flat(v):
out = self.operator1.mv(unravel(v))
return jfu.ravel_pytree(out)[0]
return jax.vmap(mv_flat, in_axes=1, out_axes=1)(self.operator2.as_matrix())
def transpose(self):
return self.operator2.transpose() @ self.operator1.transpose()
def in_structure(self):
return self.operator2.in_structure()
def out_structure(self):
return self.operator1.out_structure()
#
# Operations on `AbstractLinearOperator`s.
# These are done through `singledispatch` rather than as methods.
#
# If an end user ever wanted to add something analogous to
# `diagonal: AbstractLinearOperator -> Array`
# then of course they don't get to edit our base class and add overloads to all
# subclasses.
# They'd have to use `singledispatch` to get the desired behaviour. (Or maybe just
# hardcode compatibility with only some `AbstractLinearOperator` subclasses, eurgh.)
# So for consistency we do the same thing here, rather than adding privileged behaviour
# for just the operations we happen to support.
#
# (Something something Julia something something orphan problem etc.)
#
def _default_not_implemented(name: str, operator: AbstractLinearOperator) -> NoReturn:
msg = f"`lineax.{name}` has not been implemented for {type(operator)}"
if type(operator).__module__.startswith("lineax"):
assert False, msg + ". Please file a bug against Lineax."
else:
raise NotImplementedError(msg)
# linearise
@ft.singledispatch
def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Linearises a linear operator. This returns another linear operator.
Mathematically speaking this is just the identity function. And indeed most linear
operators will be returned unchanged.
Specifically, for [`lineax.JacobianLinearOperator`][] this caches the primal pass,
so that it does not need to be recomputed for every matrix-vector product. That is,
it uses some memory to improve speed. (This is precisely the same distinction as
`jax.jvp` versus `jax.linearize`.)
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Another linear operator. Mathematically it performs matrix-vector products
(`operator.mv`) that produce the same results as the input `operator`.
"""
_default_not_implemented("linearise", operator)
@linearise.register(MatrixLinearOperator)
@linearise.register(PyTreeLinearOperator)
@linearise.register(FunctionLinearOperator)
@linearise.register(IdentityLinearOperator)
@linearise.register(DiagonalLinearOperator)
@linearise.register(TridiagonalLinearOperator)
def _(operator):
return operator
@linearise.register(JacobianLinearOperator)
def _(operator):
fn = _NoAuxIn(operator.fn, operator.args)
if operator.jac == "bwd":
# For backward mode, use VJP + linear_transpose.
# This works even with custom_vjp functions that don't support forward-mode AD.
_, vjp_fn = jax.vjp(fn, operator.x)
if is_symmetric(operator):
# For symmetric: J = J.T, so vjp directly gives J @ v
lin = _Unwrap(vjp_fn)
else:
# Transpose the VJP to get J @ v from J.T @ v
lin = _Unwrap(
jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())
)
else: # "fwd" or None
_, lin = jax.linearize(fn, operator.x)
return FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
# materialise
@ft.singledispatch
def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Materialises a linear operator. This returns another linear operator.
Mathematically speaking this is just the identity function. And indeed most linear
operators will be returned unchanged.
Specifically, for [`lineax.JacobianLinearOperator`][] and
[`lineax.FunctionLinearOperator`][], the linear operator is materialised in
memory. That is, it becomes defined as a matrix (or pytree of arrays), rather
than being defined only through its matrix-vector product
([`lineax.AbstractLinearOperator.mv`][]).
Materialisation sometimes improves compile time or run time. It usually increases
memory usage.
For example:
```python
large_function = ...
operator = lx.FunctionLinearOperator(large_function, ...)
# Option 1
out1 = operator.mv(vector1) # Traces and compiles `large_function`
out2 = operator.mv(vector2) # Traces and compiles `large_function` again!
out3 = operator.mv(vector3) # Traces and compiles `large_function` a third time!
# All that compilation might lead to long compile times.
# If `large_function` takes a long time to run, then this might also lead to long
# run times.
# Option 2
operator = lx.materialise(operator) # Traces and compiles `large_function` and
# stores the result as a matrix.
out1 = operator.mv(vector1) # Each of these just computes a matrix-vector product
out2 = operator.mv(vector2) # against the stored matrix.
out3 = operator.mv(vector3) #
# Now, `large_function` is only compiled once, and only run once.
# However, storing the matrix might take a lot of memory, and the initial
# computation may or may not take a long time to run.
```
Generally speaking it is worth first setting up your problem without
`lx.materialise`, and using it as an optional optimisation if you find that it
helps your particular problem.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Another linear operator. Mathematically it performs matrix-vector products
(`operator.mv`) that produce the same results as the input `operator`.
"""
_default_not_implemented("materialise", operator)
def _try_sparse_materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Try to materialise to a sparse operator.
Returns a DiagonalLinearOperator if the operator is diagonal, or a
TridiagonalLinearOperator if it is tridiagonal and has flat (single-array) input and
output structures. Otherwise returns the original operator unchanged. The resulting
operator preserves the input/output structure of the original operator.
"""
if is_diagonal(operator):
diag_flat = diagonal(operator)
_, unravel = eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
diag_pytree = unravel(diag_flat)
return DiagonalLinearOperator(diag_pytree)
# TridiagonalLinearOperator only supports flat in and out structures
if (
is_tridiagonal(operator)
and isinstance(operator.in_structure(), jax.ShapeDtypeStruct)
and isinstance(operator.out_structure(), jax.ShapeDtypeStruct)
):
return TridiagonalLinearOperator(*tridiagonal(operator))
return operator
@materialise.register(MatrixLinearOperator)
@materialise.register(PyTreeLinearOperator)
def _(operator):
return _try_sparse_materialise(operator)
@materialise.register(IdentityLinearOperator)
@materialise.register(DiagonalLinearOperator)
@materialise.register(TridiagonalLinearOperator)
def _(operator):
return operator
@materialise.register(JacobianLinearOperator)
def _(operator):
maybe_sparse_op = _try_sparse_materialise(operator)
if maybe_sparse_op is not operator:
return maybe_sparse_op
fn = _NoAuxIn(operator.fn, operator.args)
jac = jacobian(
fn,
operator.in_size(),
operator.out_size(),
holomorphic=any(jnp.iscomplexobj(xi) for xi in jtu.tree_leaves(operator.x)),
jac=operator.jac,
)(operator.x)
return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags)
@materialise.register(FunctionLinearOperator)
def _(operator):
maybe_sparse_op = _try_sparse_materialise(operator)
if maybe_sparse_op is not operator:
return maybe_sparse_op
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
)
eye = jnp.eye(flat.size, dtype=flat.dtype)
jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye)
def batch_unravel(x):
assert x.ndim > 0
unravel_ = unravel
for _ in range(x.ndim - 1):
unravel_ = jax.vmap(unravel_)
return unravel_(x)
jac = jtu.tree_map(batch_unravel, jac)
return PyTreeLinearOperator(jac, operator.out_structure(), operator.tags)
# diagonal
@ft.singledispatch
def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:
"""Extracts the diagonal from a linear operator, and returns a vector.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
A rank-1 JAX array. (That is, it has shape `(a,)` for some integer `a`.)
For most operators this is just `jnp.diag(operator.as_matrix())`. Some operators
(e.g. [`lineax.DiagonalLinearOperator`][]) can have more efficient
implementations. If you don't know what kind of operator you might have, then this
function ensures that you always get the most efficient implementation.
"""
_default_not_implemented("diagonal", operator)
def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array:
"""Extract the leaf from a pytree at the given keypath."""
for path, leaf in jtu.tree_leaves_with_path(pytree):
if path == keypath:
return leaf
raise ValueError(f"Leaf not found at keypath {keypath}")
@diagonal.register(MatrixLinearOperator)
def _(operator):
return jnp.diag(operator.as_matrix())
@diagonal.register(PyTreeLinearOperator)
def _(operator):
if is_diagonal(operator):
def extract_diag(keypath, struct, subpytree):
block = _leaf_from_keypath(subpytree, keypath)
return jnp.diag(block.reshape(struct.size, struct.size))
diags = jtu.tree_map_with_path(
extract_diag, operator.out_structure(), operator.pytree
)
return jnp.concatenate(jtu.tree_leaves(diags))
else:
return jnp.diag(operator.as_matrix())
@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
def _(operator):
if is_diagonal(operator):
with jax.ensure_compile_time_eval():
basis = jtu.tree_map(
lambda s: jnp.ones(s.shape, s.dtype), operator.in_structure()
)
diag_as_pytree = operator.mv(basis)
diag, _ = jfu.ravel_pytree(diag_as_pytree)
return diag
return diagonal(materialise(operator))
@diagonal.register(DiagonalLinearOperator)
def _(operator):
diagonal, _ = jfu.ravel_pytree(operator.diagonal)
return diagonal
@diagonal.register(IdentityLinearOperator)
def _(operator):
return jnp.ones(operator.in_size())
@diagonal.register(TridiagonalLinearOperator)
def _(operator):
return operator.diagonal
# tridiagonal
@ft.singledispatch
def tridiagonal(
operator: AbstractLinearOperator,
) -> tuple[Shaped[Array, " size"], Shaped[Array, " size-1"], Shaped[Array, " size-1"]]:
"""Extracts the diagonal, lower diagonal, and upper diagonal, from a linear
operator. Returns three vectors.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
A 3-tuple, consisting of:
- The diagonal of the matrix, represented as a vector.
- The lower diagonal of the matrix, represented as a vector.
- The upper diagonal of the matrix, represented as a vector.
If the diagonal has shape `(a,)` then the lower and upper diagonals will have shape
`(a - 1,)`.
For most operators these are computed by materialising the array and then extracting
the relevant elements, e.g. getting the main diagonal via
`jnp.diag(operator.as_matrix())`. Some operators (e.g.
[`lineax.TridiagonalLinearOperator`][]) can have more efficient implementations.
If you don't know what kind of operator you might have, then this function ensures
that you always get the most efficient implementation.
"""
_default_not_implemented("tridiagonal", operator)
@tridiagonal.register(MatrixLinearOperator)
@tridiagonal.register(PyTreeLinearOperator)
def _(operator):
matrix = operator.as_matrix()
assert matrix.ndim == 2
main_diagonal = jnp.diagonal(matrix, offset=0)
upper_diagonal = jnp.diagonal(matrix, offset=1)
lower_diagonal = jnp.diagonal(matrix, offset=-1)
return main_diagonal, lower_diagonal, upper_diagonal
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
def _(operator):
if is_tridiagonal(operator):
with jax.ensure_compile_time_eval():
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
)
basis = jnp.zeros((3, flat.size), dtype=flat.dtype)
for i in range(3):
basis = basis.at[i, i::3].set(1.0)
basis = jax.vmap(unravel)(basis)
coloring = jnp.arange(flat.size) % 3
compressed_as_pytree = jax.vmap(operator.mv)(basis)
compressed_flat = jax.vmap(lambda x: jfu.ravel_pytree(x)[0])(
compressed_as_pytree
)
# unique_indices propagates through linear_transpose to set unique_indices=True
# on the scatter, allowing assignment rather than accumulation.
rows = jnp.arange(flat.size)
diag = compressed_flat.at[coloring, rows].get(
wrap_negative_indices=False, unique_indices=True
)
lower_diag = compressed_flat.at[coloring[:-1], rows[1:]].get(
wrap_negative_indices=False, unique_indices=True
)
upper_diag = compressed_flat.at[coloring[1:], rows[:-1]].get(
wrap_negative_indices=False, unique_indices=True
)
return diag, lower_diag, upper_diag
matrix = operator.as_matrix()
assert matrix.ndim == 2
main_diagonal = jnp.diagonal(matrix, offset=0)
upper_diagonal = jnp.diagonal(matrix, offset=1)
lower_diagonal = jnp.diagonal(matrix, offset=-1)
return main_diagonal, lower_diagonal, upper_diagonal
@tridiagonal.register(DiagonalLinearOperator)
def _(operator):
diag = diagonal(operator)
upper_diag = jnp.zeros(diag.size - 1, dtype=diag.dtype)
lower_diag = jnp.zeros(diag.size - 1, dtype=diag.dtype)
return diag, lower_diag, upper_diag
@tridiagonal.register(IdentityLinearOperator)
def _(operator):
size = operator.in_size()
main_diagonal = jnp.ones(size)
off_diagonal = jnp.zeros(size - 1)
return main_diagonal, off_diagonal, off_diagonal
@tridiagonal.register(TridiagonalLinearOperator)
def _(operator):
return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal
# is_symmetric
@ft.singledispatch
def is_symmetric(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as symmetric.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_symmetric", operator)
def _has_real_dtype(operator) -> bool:
"""Check if all dtypes in an operator's structure are real (not complex)."""
leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure()))
dtype = jnp.result_type(*leaves)
if jnp.issubdtype(dtype, jnp.complexfloating):
return False
elif jnp.issubdtype(dtype, jnp.floating):
return True
else:
assert False, (
"Only `jnp.floating` and `jnp.complexfloating` dtypes are understood."
)
@is_symmetric.register(MatrixLinearOperator)
@is_symmetric.register(PyTreeLinearOperator)
@is_symmetric.register(JacobianLinearOperator)
@is_symmetric.register(FunctionLinearOperator)
def _(operator):
# Symmetric (A = A^T) if explicitly tagged symmetric or diagonal
if symmetric_tag in operator.tags or diagonal_tag in operator.tags:
return True
# PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian
if (
positive_semidefinite_tag in operator.tags
or negative_semidefinite_tag in operator.tags
):
return _has_real_dtype(operator)
return False
@is_symmetric.register(IdentityLinearOperator)
def _(operator):
return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True
@is_symmetric.register(DiagonalLinearOperator)
def _(operator):
return True
@is_symmetric.register(TridiagonalLinearOperator)
def _(operator):
return False
# is_diagonal
@ft.singledispatch
def is_diagonal(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as diagonal.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_diagonal", operator)
@is_diagonal.register(MatrixLinearOperator)
@is_diagonal.register(PyTreeLinearOperator)
@is_diagonal.register(JacobianLinearOperator)
@is_diagonal.register(FunctionLinearOperator)
def _(operator):
return diagonal_tag in operator.tags or (
operator.in_size() == 1 and operator.out_size() == 1
)
@is_diagonal.register(IdentityLinearOperator)
@is_diagonal.register(DiagonalLinearOperator)
def _(operator):
return True
@is_diagonal.register(TridiagonalLinearOperator)
def _(operator):
return operator.in_size() == 1
# is_tridiagonal
@ft.singledispatch
def is_tridiagonal(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as tridiagonal.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_tridiagonal", operator)
@is_tridiagonal.register(MatrixLinearOperator)
@is_tridiagonal.register(PyTreeLinearOperator)
@is_tridiagonal.register(JacobianLinearOperator)
@is_tridiagonal.register(FunctionLinearOperator)
def _(operator):
return tridiagonal_tag in operator.tags or diagonal_tag in operator.tags
@is_tridiagonal.register(IdentityLinearOperator)
@is_tridiagonal.register(DiagonalLinearOperator)
@is_tridiagonal.register(TridiagonalLinearOperator)
def _(operator):
return True
# has_unit_diagonal
@ft.singledispatch
def has_unit_diagonal(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as having unit diagonal.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("has_unit_diagonal", operator)
@has_unit_diagonal.register(MatrixLinearOperator)
@has_unit_diagonal.register(PyTreeLinearOperator)
@has_unit_diagonal.register(JacobianLinearOperator)
@has_unit_diagonal.register(FunctionLinearOperator)
def _(operator):
return unit_diagonal_tag in operator.tags
@has_unit_diagonal.register(IdentityLinearOperator)
def _(operator):
return True
@has_unit_diagonal.register(DiagonalLinearOperator)
@has_unit_diagonal.register(TridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
# is_lower_triangular
@ft.singledispatch
def is_lower_triangular(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as lower triangular.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_lower_triangular", operator)
@is_lower_triangular.register(MatrixLinearOperator)
@is_lower_triangular.register(PyTreeLinearOperator)
@is_lower_triangular.register(JacobianLinearOperator)
@is_lower_triangular.register(FunctionLinearOperator)
def _(operator):
return lower_triangular_tag in operator.tags
@is_lower_triangular.register(IdentityLinearOperator)
@is_lower_triangular.register(DiagonalLinearOperator)
def _(operator):
return True
@is_lower_triangular.register(TridiagonalLinearOperator)
def _(operator):
return False
# is_upper_triangular
@ft.singledispatch
def is_upper_triangular(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as upper triangular.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_upper_triangular", operator)
@is_upper_triangular.register(MatrixLinearOperator)
@is_upper_triangular.register(PyTreeLinearOperator)
@is_upper_triangular.register(JacobianLinearOperator)
@is_upper_triangular.register(FunctionLinearOperator)
def _(operator):
return upper_triangular_tag in operator.tags
@is_upper_triangular.register(IdentityLinearOperator)
@is_upper_triangular.register(DiagonalLinearOperator)
def _(operator):
return True
@is_upper_triangular.register(TridiagonalLinearOperator)
def _(operator):
return False
# is_positive_semidefinite
@ft.singledispatch
def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as positive semidefinite.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_positive_semidefinite", operator)
@is_positive_semidefinite.register(MatrixLinearOperator)
@is_positive_semidefinite.register(PyTreeLinearOperator)
@is_positive_semidefinite.register(JacobianLinearOperator)
@is_positive_semidefinite.register(FunctionLinearOperator)
def _(operator):
return positive_semidefinite_tag in operator.tags
@is_positive_semidefinite.register(IdentityLinearOperator)
def _(operator):
return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True
@is_positive_semidefinite.register(DiagonalLinearOperator)
@is_positive_semidefinite.register(TridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
# is_negative_semidefinite
@ft.singledispatch
def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as negative semidefinite.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False`.
"""
_default_not_implemented("is_negative_semidefinite", operator)
@is_negative_semidefinite.register(MatrixLinearOperator)
@is_negative_semidefinite.register(PyTreeLinearOperator)
@is_negative_semidefinite.register(JacobianLinearOperator)
@is_negative_semidefinite.register(FunctionLinearOperator)
def _(operator):
return negative_semidefinite_tag in operator.tags
@is_negative_semidefinite.register(IdentityLinearOperator)
def _(operator):
return False
@is_negative_semidefinite.register(DiagonalLinearOperator)
@is_negative_semidefinite.register(TridiagonalLinearOperator)
def _(operator):
# TODO: refine this
return False
# ops for wrapper operators
@linearise.register(TaggedLinearOperator)
def _(operator):
return TaggedLinearOperator(linearise(operator.operator), operator.tags)
@materialise.register(TaggedLinearOperator)
def _(operator):
return TaggedLinearOperator(materialise(operator.operator), operator.tags)
@diagonal.register(TaggedLinearOperator)
def _(operator):
return diagonal(operator.operator)
@tridiagonal.register(TaggedLinearOperator)
def _(operator):
return tridiagonal(operator.operator)
for transform in (linearise, materialise, diagonal):
@transform.register(MulLinearOperator)
def _(operator, transform=transform):
return transform(operator.operator) * operator.scalar
@transform.register(NegLinearOperator) # pyright: ignore
def _(operator, transform=transform):
return -transform(operator.operator)
@transform.register(DivLinearOperator)
def _(operator, transform=transform):
return transform(operator.operator) / operator.scalar
for transform in (linearise, diagonal):
@transform.register(AddLinearOperator) # pyright: ignore
def _(operator, transform=transform):
return transform(operator.operator1) + transform(operator.operator2) # pyright: ignore
@materialise.register(AddLinearOperator)
def _(operator):
maybe_sparse_op = _try_sparse_materialise(operator)
if maybe_sparse_op is not operator:
return maybe_sparse_op
return materialise(operator.operator1) + materialise(operator.operator2)
@linearise.register(TangentLinearOperator)
def _(operator):
primal_out, tangent_out = eqx.filter_jvp(
linearise, (operator.primal,), (operator.tangent,)
)
return TangentLinearOperator(primal_out, tangent_out)
@materialise.register(TangentLinearOperator)
def _(operator):
primal_out, tangent_out = eqx.filter_jvp(
materialise, (operator.primal,), (operator.tangent,)
)
return TangentLinearOperator(primal_out, tangent_out)
@diagonal.register(TangentLinearOperator)
def _(operator):
# Should be unreachable: TangentLinearOperator is used for a narrow set of
# operations only (mv; transpose) inside the JVP rule of linear_solve_p.
raise NotImplementedError(
"Please open a GitHub issue: https://github.com/google/lineax"
)
@tridiagonal.register(TangentLinearOperator)
def _(operator):
# Should be unreachable: TangentLinearOperator is used for a narrow set of
# operations only (mv; transpose) inside the JVP rule of linear_solve_p.
raise NotImplementedError(
"Please open a GitHub issue: https://github.com/google/lineax"
)
@tridiagonal.register(AddLinearOperator)
def _(operator):
(diag1, lower1, upper1) = tridiagonal(operator.operator1)
(diag2, lower2, upper2) = tridiagonal(operator.operator2)
return (diag1 + diag2, lower1 + lower2, upper1 + upper2)
@tridiagonal.register(MulLinearOperator)
def _(operator):
(diag, lower, upper) = tridiagonal(operator.operator)
return (diag * operator.scalar, lower * operator.scalar, upper * operator.scalar)
@tridiagonal.register(NegLinearOperator)
def _(operator):
(diag, lower, upper) = tridiagonal(operator.operator)
return (-diag, -lower, -upper)
@tridiagonal.register(DivLinearOperator)
def _(operator):
(diag, lower, upper) = tridiagonal(operator.operator)
return (diag / operator.scalar, lower / operator.scalar, upper / operator.scalar)
@linearise.register(ComposedLinearOperator)
def _(operator):
return linearise(operator.operator1) @ linearise(operator.operator2)
@materialise.register(ComposedLinearOperator)
def _(operator):
if isinstance(operator.operator1, IdentityLinearOperator):
return materialise(operator.operator2)
if isinstance(operator.operator2, IdentityLinearOperator):
return materialise(operator.operator1)
maybe_sparse_op = _try_sparse_materialise(operator)
if maybe_sparse_op is not operator:
return maybe_sparse_op
return materialise(operator.operator1) @ materialise(operator.operator2)
@diagonal.register(ComposedLinearOperator)
def _(operator):
if is_diagonal(operator.operator1) and is_diagonal(operator.operator2):
return diagonal(operator.operator1) * diagonal(operator.operator2)
return jnp.diag(operator.as_matrix())
@tridiagonal.register(ComposedLinearOperator)
def _(operator):
if is_diagonal(operator.operator1) and is_tridiagonal(operator.operator2):
d = diagonal(operator.operator1)
main, lower, upper = tridiagonal(operator.operator2)
# D @ T scales rows: row i multiplied by d[i]
return d * main, d[1:] * lower, d[:-1] * upper
if is_diagonal(operator.operator2) and is_tridiagonal(operator.operator1):
d = diagonal(operator.operator2)
main, lower, upper = tridiagonal(operator.operator1)
# T @ D scales columns: column j multiplied by d[j]
return d * main, d[:-1] * lower, d[1:] * upper
matrix = operator.as_matrix()
assert matrix.ndim == 2
main_diagonal = jnp.diagonal(matrix, offset=0)
upper_diagonal = jnp.diagonal(matrix, offset=1)
lower_diagonal = jnp.diagonal(matrix, offset=-1)
return main_diagonal, lower_diagonal, upper_diagonal
for check in (
is_symmetric,
is_diagonal,
has_unit_diagonal,
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
is_positive_semidefinite,
is_negative_semidefinite,
):
@check.register(TangentLinearOperator)
def _(operator, check=check):
return check(operator.primal)
# Scaling/negating preserves these structural properties
for check in (
is_symmetric,
is_diagonal,
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
):
@check.register(MulLinearOperator)
@check.register(NegLinearOperator)
@check.register(DivLinearOperator)
def _(operator, check=check):
return check(operator.operator)
# has_unit_diagonal is NOT preserved by scaling or negation
@has_unit_diagonal.register(MulLinearOperator)
@has_unit_diagonal.register(NegLinearOperator)
@has_unit_diagonal.register(DivLinearOperator)
def _(operator):
return False
class _ScalarSign(enum.Enum):
positive = enum.auto()
negative = enum.auto()
zero = enum.auto()
unknown = enum.auto()
def _scalar_sign(scalar) -> _ScalarSign:
"""Returns the sign of a scalar, or unknown for JAX tracers."""
if isinstance(scalar, (int, float, np.ndarray, np.generic)):
scalar = float(scalar)
if scalar > 0:
return _ScalarSign.positive
elif scalar < 0:
return _ScalarSign.negative
else:
return _ScalarSign.zero
else:
return _ScalarSign.unknown
# PSD/NSD for MulLinearOperator: depends on sign of scalar
# A zero scalar gives the zero matrix, which is both PSD and NSD
@is_positive_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is PSD
return False
@is_negative_semidefinite.register(MulLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.zero:
return True # zero matrix is NSD
return False
# PSD/NSD for DivLinearOperator: depends on sign of scalar
# Zero scalar is division by zero - return False (conservative)
@is_positive_semidefinite.register(DivLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_positive_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_negative_semidefinite(operator.operator)
return False
@is_negative_semidefinite.register(DivLinearOperator)
def _(operator):
sign = _scalar_sign(operator.scalar)
if sign is _ScalarSign.positive:
return is_negative_semidefinite(operator.operator)
elif sign is _ScalarSign.negative:
return is_positive_semidefinite(operator.operator)
return False
# PSD/NSD for NegLinearOperator: negation swaps PSD <-> NSD
@is_positive_semidefinite.register(NegLinearOperator)
def _(operator):
return is_negative_semidefinite(operator.operator)
@is_negative_semidefinite.register(NegLinearOperator)
def _(operator):
return is_positive_semidefinite(operator.operator)
for check, tag in (
(is_symmetric, symmetric_tag),
(is_diagonal, diagonal_tag),
(has_unit_diagonal, unit_diagonal_tag),
(is_lower_triangular, lower_triangular_tag),
(is_upper_triangular, upper_triangular_tag),
(is_positive_semidefinite, positive_semidefinite_tag),
(is_negative_semidefinite, negative_semidefinite_tag),
(is_tridiagonal, tridiagonal_tag),
):
@check.register(TaggedLinearOperator)
def _(operator, check=check, tag=tag):
return (tag in operator.tags) or check(operator.operator)
for check in (
is_symmetric,
is_diagonal,
is_lower_triangular,
is_upper_triangular,
is_positive_semidefinite,
is_negative_semidefinite,
is_tridiagonal,
):
@check.register(AddLinearOperator)
def _(operator, check=check):
return check(operator.operator1) and check(operator.operator2)
@has_unit_diagonal.register(AddLinearOperator)
def _(operator):
return False
# These properties ARE preserved under composition
for check in (
is_diagonal,
is_lower_triangular,
is_upper_triangular,
):
@check.register(ComposedLinearOperator)
def _(operator, check=check):
return check(operator.operator1) and check(operator.operator2)
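# is_symmetric: if A and B are symmetric, A@B is symmetric iff A and B commute.
# Diagonal matrices always commute, so conservatively return True only when both
# factors are diagonal.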
@is_symmetric.register(ComposedLinearOperator)
def _(operator):
return is_diagonal(operator.operator1) and is_diagonal(operator.operator2)
# is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but
# tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal
@is_tridiagonal.register(ComposedLinearOperator)
def _(operator):
if is_diagonal(operator.operator1):
return is_tridiagonal(operator.operator2)
if is_diagonal(operator.operator2):
return is_tridiagonal(operator.operator1)
return False
# PSD/NSD: not preserved under composition in general.
@is_positive_semidefinite.register(ComposedLinearOperator)
@is_negative_semidefinite.register(ComposedLinearOperator)
def _(operator):
return False
@has_unit_diagonal.register(ComposedLinearOperator)
def _(operator):
a = is_diagonal(operator)
b = is_lower_triangular(operator)
c = is_upper_triangular(operator)
d = has_unit_diagonal(operator.operator1)
e = has_unit_diagonal(operator.operator2)
return (a or b or c) and d and e
# conj
@ft.singledispatch
def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Elementwise conjugate of a linear operator. This returns another linear operator.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Another linear operator.
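Several of the per-operator `conj` rules rely on simple identities of the complex
conjugate. One example is that a function-defined operator can be conjugated using
only matrix-vector products, via conj(A) v = conj(A conj(v)). The following is a
NumPy-only sketch that checks these identities numerically. It does not call Lineax,
and `A`, `B`, `s` and `v` are arbitrary illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)
# Arbitrary complex test data (illustrative only).
A = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))
B = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))
s = 2.0 - 3.0j
v = rng.normal(size=3) + 1j * rng.normal(size=3)

# Function-operator rule: conj(A) v == conj(A conj(v)), so only A's matvec is needed.
fn_rule_ok = np.allclose(A.conj() @ v, np.conj(A @ np.conj(v)))

# Conjugation distributes over sums, scalar products/quotients and composition.
add_ok = np.allclose((A + B).conj(), A.conj() + B.conj())
mul_ok = np.allclose((s * A).conj(), np.conj(s) * A.conj())
div_ok = np.allclose((A / s).conj(), A.conj() / np.conj(s))
matmul_ok = np.allclose((A @ B).conj(), A.conj() @ B.conj())

print(fn_rule_ok, add_ok, mul_ok, div_ok, matmul_ok)  # True True True True True
```

The first identity is what allows `conj` to stay matrix-free for operators defined
only through a function.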
"""
_default_not_implemented("conj", operator)
@conj.register(MatrixLinearOperator)
def _(operator):
return MatrixLinearOperator(operator.matrix.conj(), operator.tags)
@conj.register(PyTreeLinearOperator)
def _(operator):
pytree_conj = jtu.tree_map(lambda x: x.conj(), operator.pytree)
return PyTreeLinearOperator(pytree_conj, operator.out_structure(), operator.tags)
@conj.register(DiagonalLinearOperator)
def _(operator):
diagonal_conj = jtu.tree_map(lambda x: x.conj(), operator.diagonal)
return DiagonalLinearOperator(diagonal_conj)
@conj.register(JacobianLinearOperator)
def _(operator):
return conj(linearise(operator))
@conj.register(FunctionLinearOperator)
def _(operator):
return FunctionLinearOperator(
lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))),
operator.in_structure(),
operator.tags,
)
@conj.register(IdentityLinearOperator)
def _(operator):
return operator
@conj.register(TridiagonalLinearOperator)
def _(operator):
return TridiagonalLinearOperator(
operator.diagonal.conj(),
operator.lower_diagonal.conj(),
operator.upper_diagonal.conj(),
)
@conj.register(TaggedLinearOperator)
def _(operator):
return TaggedLinearOperator(conj(operator.operator), operator.tags)
@conj.register(TangentLinearOperator)
def _(operator):
c = lambda operator: conj(operator)
primal_out, tangent_out = eqx.filter_jvp(c, (operator.primal,), (operator.tangent,))
return TangentLinearOperator(primal_out, tangent_out)
@conj.register(AddLinearOperator)
def _(operator):
return conj(operator.operator1) + conj(operator.operator2)
@conj.register(MulLinearOperator)
def _(operator):
return conj(operator.operator) * operator.scalar.conj()
@conj.register(NegLinearOperator)
def _(operator):
return -conj(operator.operator)
@conj.register(DivLinearOperator)
def _(operator):
return conj(operator.operator) / operator.scalar.conj()
@conj.register(ComposedLinearOperator)
def _(operator):
return conj(operator.operator1) @ conj(operator.operator2)
================================================
FILE: lineax/_solution.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import equinox as eqx
import equinox.internal as eqxi
from jaxtyping import Array, ArrayLike, PyTree
_singular_msg = """
A linear solver returned non-finite (NaN or inf) output. This usually means that an
operator was not well-posed, and that its solver does not support this.
If you are trying to solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
`jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
that it does not actually satisfy.
""".strip()
_nonfinite_msg = """
A linear solver received non-finite (NaN or inf) input and cannot determine a
solution.
This means that you have a bug upstream of Lineax and should check the inputs to
`lineax.linear_solve` for non-finite values.
""".strip()
class RESULTS(eqxi.Enumeration):
successful = ""
max_steps_reached = (
"The maximum number of solver steps was reached. Try increasing `max_steps`."
)
singular = _singular_msg
breakdown = (
"A form of iterative breakdown has occurred in a linear solve. "
"Try using a different solver for this problem or increase `restart` "
"if using GMRES."
)
stagnation = (
"A stagnation in an iterative linear solve has occurred. Try increasing "
"`stagnation_iters` or `restart`."
)
conlim = "Condition number of A seems to be larger than `conlim`."
nonfinite_input = _nonfinite_msg
class Solution(eqx.Module):
"""The solution to a linear solve.
**Attributes:**
- `value`: The solution to the solve.
- `result`: An integer representing whether the solve was successful or not. This
can be converted into a human-readable error message via
`lineax.RESULTS[result]`.
- `stats`: Statistics about the solver, e.g. the number of steps that were required.
- `state`: The internal state of the solver. The meaning of this is specific to each
solver.
"""
value: PyTree[Array]
result: RESULTS
stats: dict[str, PyTree[ArrayLike]]
state: PyTree[Any]
================================================
FILE: lineax/_solve.py
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import functools as ft
from typing import Any, Generic, TypeAlias, TypeVar
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.core
import jax.interpreters.ad as ad
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jax._src.ad_util import stop_gradient_p
from jaxtyping import Array, ArrayLike, PyTree
from ._custom_types import sentinel
from ._misc import inexact_asarray, strip_weak_dtype
from ._operator import (
AbstractLinearOperator,
conj,
FunctionLinearOperator,
has_unit_diagonal,
IdentityLinearOperator,
is_diagonal,
is_lower_triangular,
is_negative_semidefinite,
is_positive_semidefinite,
is_symmetric,
is_tridiagonal,
is_upper_triangular,
linearise,
TangentLinearOperator,
)
from ._solution import RESULTS, Solution
from ._tags import (
diagonal_tag,
lower_triangular_tag,
negative_semidefinite_tag,
positive_semidefinite_tag,
symmetric_tag,
unit_diagonal_tag,
upper_triangular_tag,
)
#
# _linear_solve_p
#
def _to_shapedarray(x):
if isinstance(x, jax.ShapeDtypeStruct):
return jax.core.ShapedArray(x.shape, x.dtype)
else:
return x
def _to_struct(x):
if isinstance(x, jax.core.ShapedArray):
return jax.ShapeDtypeStruct(x.shape, x.dtype)
elif isinstance(x, jax.core.AbstractValue):
raise NotImplementedError(
"`lineax.linear_solve` only supports working with JAX arrays; not "
f"other abstract values. Got abstract value {x}."
)
else:
return x
def _assert_false(x):
assert False
def _is_none(x):
return x is None
def _sum(*args):
return sum(args)
def _linear_solve_impl(_, state, vector, options, solver, throw, *, check_closure):
out = solver.compute(state, vector, options)
if check_closure:
out = eqxi.nontraceable(
out, name="lineax.linear_solve with respect to a closed-over value"
)
solution, result, stats = out
has_nonfinite_output = jnp.any(
jnp.stack(
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(solution)]
)
)
result = RESULTS.where(
(result == RESULTS.successful) & has_nonfinite_output,
RESULTS.singular,
result,
)
has_nonfinite_input = jnp.any(
jnp.stack(
[jnp.any(jnp.invert(jnp.isfinite(x))) for x in jtu.tree_leaves(vector)]
)
)
result = RESULTS.where(
(result == RESULTS.singular) & has_nonfinite_input,
RESULTS.nonfinite_input,
result,
)
if throw:
solution, result, stats = result.error_if(
(solution, result, stats),
result != RESULTS.successful,
)
return solution, result, stats
@eqxi.filter_primitive_def
def _linear_solve_abstract_eval(operator, state, vector, options, solver, throw):
state, vector, options, solver = jtu.tree_map(
_to_struct, (state, vector, options, solver)
)
out = eqx.filter_eval_shape(
_linear_solve_impl,
operator,
state,
vector,
options,
solver,
throw,
check_closure=False,
)
out = jtu.tree_map(_to_shapedarray, out)
return out
@eqxi.filter_primitive_jvp
def _linear_solve_jvp(primals, tangents):
operator, state, vector, options, solver, throw = primals
t_operator, t_state, t_vector, t_options, t_solver, t_throw = tangents
jtu.tree_map(_assert_false, (t_state, t_options, t_solver, t_throw))
del t_state, t_options, t_solver, t_throw
# Note that we pass throw=True unconditionally to all the tangent solves, as there
# is nowhere we can pipe their error to.
# This is the primal solve so we can respect the original `throw`.
solution, result, stats = eqxi.filter_primitive_bind(
linear_solve_p, operator, state, vector, options, solver, throw
)
#
# Consider the primal problem of linearly solving for x in Ax=b.
# Let ^ denote pseudoinverses, ᵀ denote conjugate transposes (plain transposes when
# A is real), and ' denote tangents.
# The linear_solve routine returns specifically the pseudoinverse solution, i.e.
#
# x = A^b
#
# Therefore x' = A^'b + A^b'
#
# Now A^' = -A^A'A^ + A^A^ᵀAᵀ'(I - AA^) + (I - A^A)Aᵀ'A^ᵀA^
#
# (Source: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
#
# This results in:
#
# x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) - Ay + b') + y
#
# where
#
# y = Aᵀ'A^ᵀx
#
# note that if A has linearly independent columns, then the y - A^Ay
# term disappears and gives
#
# x' = A^(-A'x + A^ᵀAᵀ'(b - Ax) + b')
#
# and if A has linearly independent rows, then the A^A^ᵀAᵀ'(b - Ax) term
# disappears giving:
#
# x' = A^(-A'x - Ay + b') + y
#
# if A has linearly independent rows and columns, then A is nonsingular and
#
# x' = A^(-A'x + b')
vecs = []
sols = []
if any(t is not None for t in jtu.tree_leaves(t_vector, is_leaf=_is_none)):
# b' term
vecs.append(
jtu.tree_map(eqxi.materialise_zeros, vector, t_vector, is_leaf=_is_none)
)
if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)):
t_operator = TangentLinearOperator(operator, t_operator)
t_operator = linearise(t_operator) # optimise for matvecs
# -A'x term
vec = (-(t_operator.mv(solution) ** ω)).ω
vecs.append(vec)
rows, columns = operator.out_size(), operator.in_size()
assume_independent_rows = solver.assume_full_rank() and rows <= columns
assume_independent_columns = solver.assume_full_rank() and columns <= rows
if not assume_independent_rows or not assume_independent_columns:
operator_conj_transpose = conj(operator).transpose()
t_operator_conj_transpose = conj(t_operator).transpose()
state_conj, options_conj = solver.conj(state, options)
state_conj_transpose, options_conj_transpose = solver.transpose(
state_conj, options_conj
)
if not assume_independent_rows:
lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω
tmp = t_operator_conj_transpose.mv(lst_sqr_diff) # pyright: ignore
tmp, _, _ = eqxi.filter_primitive_bind(
linear_solve_p,
operator_conj_transpose, # pyright: ignore
state_conj_transpose, # pyright: ignore
tmp,
options_conj_transpose, # pyright: ignore
solver,
True,
)
vecs.append(tmp)
if not assume_independent_columns:
tmp1, _, _ = eqxi.filter_primitive_bind(
linear_solve_p,
operator_conj_transpose, # pyright: ignore
state_conj_transpose, # pyright: ignore
solution,
options_conj_transpose, # pyright: ignore
solver,
True,
)
tmp2 = t_operator_conj_transpose.mv(tmp1) # pyright: ignore
# tmp2 is the y term
tmp3 = operator.mv(tmp2)
tmp4 = (-(tmp3**ω)).ω
# tmp4 is the Ay term
vecs.append(tmp4)
sols.append(tmp2)
vecs = jtu.tree_map(_sum, *vecs)
# the A^ term at the very beginning
sol, _, _ = eqxi.filter_primitive_bind(
linear_solve_p, operator, state, vecs, options, solver, True
)
sols.append(sol)
t_solution = jtu.tree_map(_sum, *sols)
out = solution, result, stats
t_out = (
t_solution,
jtu.tree_map(lambda _: None, result),
jtu.tree_map(lambda _: None, stats),
)
return out, t_out
def _is_undefined(x):
return isinstance(x, ad.UndefinedPrimal)
def _assert_defined(x):
assert not _is_undefined(x)
def _keep_undefined(v, ct):
if _is_undefined(v):
return ct
else:
return None
@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore
def _linear_solve_transpose(inputs, cts_out):
cts_solution, _, _ = cts_out
operator, state, vector, options, solver, _ = inputs
jtu.tree_map(
_assert_defined, (operator, state, options, solver), is_leaf=_is_undefined
)
cts_solution = jtu.tree_map(
ft.partial(eqxi.materialise_zeros, allow_struct=True),
operator.in_structure(),
cts_solution,
)
operator_transpose = operator.transpose()
state_transpose, options_transpose = solver.transpose(state, options)
cts_vector, _, _ = eqxi.filter_primitive_bind(
linear_solve_p,
operator_transpose,
state_transpose,
cts_solution,
options_transpose,
solver,
True, # throw=True unconditionally: nowhere to pipe result to.
)
cts_vector = jtu.tree_map(
_keep_undefined, vector, cts_vector, is_leaf=_is_undefined
)
operator_none = jtu.tree_map(lambda _: None, operator)
state_none = jtu.tree_map(lambda _: None, state)
options_none = jtu.tree_map(lambda _: None, options)
solver_none = jtu.tree_map(lambda _: None, solver)
throw_none = None
return operator_none, state_none, cts_vector, options_none, solver_none, throw_none
# Call with `check_closure=False` so that the autocreated vmap rule works.
linear_solve_p = eqxi.create_vprim(
"linear_solve",
eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=False)),
_linear_solve_abstract_eval,
_linear_solve_jvp,
_linear_solve_transpose,
)
# Then rebind so that the impl rule catches leaked-in tracers.
linear_solve_p.def_impl(
eqxi.filter_primitive_def(ft.partial(_linear_solve_impl, check_closure=True))
)
eqxi.register_impl_finalisation(linear_solve_p)
#
# linear_solve
#
_SolverState = TypeVar("_SolverState")
class AbstractLinearSolver(eqx.Module, Generic[_SolverState]):
"""Abstract base class for all linear solvers."""
@abc.abstractmethod
def init(
self, operator: AbstractLinearOperator, options: dict[str, Any]
) -> _SolverState:
"""Do any initial computation on just the `operator`.
For example, an LU solver would compute the LU decomposition of the operator
(and this does not require knowing the vector yet).
It is common to need to solve the linear system `Ax=b` multiple times in
succession, with the same operator `A` and multiple vectors `b`. This method
improves efficiency by making it possible to re-use the computation performed
on just the operator.
!!! Example
```python
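import jax.numpy as jnp
import lineax as lx

operator = lx.MatrixLinearOperator(jnp.array([[2.0, 1.0], [1.0, 3.0]]))
vector1 = jnp.array([1.0, 0.0])
vector2 = jnp.array([0.0, 1.0])
solver = lx.LU()
# Factorise the operator once...
state = solver.init(operator, options={})
# ...then re-use that factorisation for each right-hand side.
solution1 = lx.linear_solve(operator, vector1, solver, state=state)
solution2 = lx.linear_solve(operator, vector2, solver, state=state)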
```
**Arguments:**
- `operator`: a linear operator.
- `options`: a dictionary of any extra options that the solver may wish to
accept.
**Returns:**
A PyTree of arbitrary Python objects.
"""
@abc.abstractmethod
def compute(
self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any]
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
"""Solves a linear system.
**Arguments:**
- `state`: as returned from [`lineax.AbstractLinearSolver.init`][].
- `vector`: the vector to solve against.
- `options`: a dictionary of any extra options that the solver may wish to
accept. For example, [`lineax.CG`][] accepts a `preconditioner` option.
**Returns:**
A 3-tuple of:
- The solution to the linear system.
- An integer indicating the success or failure of the solve. This is an integer
which may be converted to a human-readable error message via
`lx.RESULTS[...]`.
- A dictionary of extra statistics about the solve, e.g. the number of steps
SYMBOL INDEX (525 symbols across 40 files)
FILE: benchmarks/gmres_fails_safely.py
function tree_allclose (line 29) | def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
function make_problem (line 36) | def make_problem(mat_size: int, *, key):
function benchmark_jax (line 44) | def benchmark_jax(mat_size: int, *, key):
function benchmark_lx (line 63) | def benchmark_lx(mat_size: int, *, key):
FILE: benchmarks/lstsq_gradients.py
function jax_solve (line 34) | def jax_solve(a):
function lx_solve (line 39) | def lx_solve(a):
FILE: benchmarks/solver_speeds.py
function tree_allclose (line 35) | def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
function base_wrapper (line 47) | def base_wrapper(a, b, solver):
function jax_svd (line 61) | def jax_svd(a, b):
function jax_gmres (line 66) | def jax_gmres(a, b):
function jax_bicgstab (line 71) | def jax_bicgstab(a, b):
function jax_cg (line 76) | def jax_cg(a, b):
function jax_lu (line 81) | def jax_lu(matrix, vector):
function jax_cholesky (line 85) | def jax_cholesky(matrix, vector):
function jax_tridiagonal (line 89) | def jax_tridiagonal(matrix, vector):
function create_problem (line 139) | def create_problem(solver, tags, size=3):
function create_easy_iterative_problem (line 146) | def create_easy_iterative_problem(size, tags):
function test_solvers (line 155) | def test_solvers(vmap_size, mat_size):
FILE: lineax/_misc.py
function tree_where (line 23) | def tree_where(
function resolve_rcond (line 30) | def resolve_rcond(rcond, n, m, dtype):
function jacobian (line 41) | def jacobian(fn, in_size, out_size, holomorphic=False, has_aux=False, ja...
function _asarray (line 58) | def _asarray(dtype, x):
function _asarray_jvp (line 67) | def _asarray_jvp(dtype, x, tx):
function default_floating_dtype (line 73) | def default_floating_dtype():
function inexact_asarray (line 80) | def inexact_asarray(x):
function complex_to_real_dtype (line 87) | def complex_to_real_dtype(dtype):
function strip_weak_dtype (line 91) | def strip_weak_dtype(tree: PyTree) -> PyTree:
function structure_equal (line 100) | def structure_equal(x, y) -> bool:
FILE: lineax/_norm.py
function tree_dot (line 27) | def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inex...
function sum_squares (line 50) | def sum_squares(x: PyTree[ArrayLike]) -> Scalar:
function two_norm (line 59) | def two_norm(x: PyTree[ArrayLike]) -> Scalar:
function _two_norm (line 71) | def _two_norm(x: PyTree[ArrayLike]) -> Scalar:
function _two_norm_jvp (line 86) | def _two_norm_jvp(x, tx):
function rms_norm (line 104) | def rms_norm(x: PyTree[ArrayLike]) -> Scalar:
function max_norm (line 123) | def max_norm(x: PyTree[ArrayLike]) -> Scalar:
function _zero_grad_at_zero (line 143) | def _zero_grad_at_zero(x):
function _zero_grad_at_zero_jvp (line 148) | def _zero_grad_at_zero_jvp(primals, tangents):
FILE: lineax/_operator.py
function _frozenset (line 61) | def _frozenset(x: object | Iterable[object]) -> frozenset[object]:
class AbstractLinearOperator (line 70) | class AbstractLinearOperator(eqx.Module):
method __check_init__ (line 87) | def __check_init__(self):
method mv (line 105) | def mv(
method as_matrix (line 122) | def as_matrix(self) -> Inexact[Array, "a b"]:
method transpose (line 137) | def transpose(self) -> "AbstractLinearOperator":
method in_structure (line 150) | def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
method out_structure (line 161) | def out_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
method in_size (line 171) | def in_size(self) -> int:
method out_size (line 183) | def out_size(self) -> int:
method T (line 196) | def T(self) -> "AbstractLinearOperator":
method __add__ (line 200) | def __add__(self, other) -> "AbstractLinearOperator":
method __sub__ (line 205) | def __sub__(self, other) -> "AbstractLinearOperator":
method __mul__ (line 210) | def __mul__(self, other) -> "AbstractLinearOperator":
method __rmul__ (line 216) | def __rmul__(self, other) -> "AbstractLinearOperator":
method __matmul__ (line 219) | def __matmul__(self, other) -> "AbstractLinearOperator":
method __truediv__ (line 224) | def __truediv__(self, other) -> "AbstractLinearOperator":
method __neg__ (line 230) | def __neg__(self) -> "AbstractLinearOperator":
class MatrixLinearOperator (line 234) | class MatrixLinearOperator(AbstractLinearOperator):
method __init__ (line 245) | def __init__(
method mv (line 267) | def mv(self, vector):
method as_matrix (line 273) | def as_matrix(self):
method transpose (line 276) | def transpose(self):
method in_structure (line 281) | def in_structure(self):
method out_structure (line 285) | def out_structure(self):
function _matmul (line 290) | def _matmul(matrix: ArrayLike, vector: ArrayLike) -> Array:
function _tree_matmul (line 299) | def _tree_matmul(matrix: PyTree[ArrayLike], vector: PyTree[ArrayLike]) -...
function _inexact_structure_impl2 (line 315) | def _inexact_structure_impl2(x):
function _inexact_structure_impl (line 322) | def _inexact_structure_impl(x):
function _inexact_structure (line 326) | def _inexact_structure(x: PyTree[jax.ShapeDtypeStruct]) -> PyTree[jax.Sh...
class _Leaf (line 330) | class _Leaf: # not a pytree
method __init__ (line 331) | def __init__(self, value):
class PyTreeLinearOperator (line 337) | class PyTreeLinearOperator(AbstractLinearOperator):
method __init__ (line 373) | def __init__(
method mv (line 426) | def mv(self, vector):
method as_matrix (line 440) | def as_matrix(self):
method transpose (line 459) | def transpose(self):
method in_structure (line 481) | def in_structure(self):
method out_structure (line 485) | def out_structure(self):
class DiagonalLinearOperator (line 490) | class DiagonalLinearOperator(AbstractLinearOperator):
method __init__ (line 502) | def __init__(self, diagonal: PyTree[ArrayLike]):
method mv (line 509) | def mv(self, vector):
method as_matrix (line 512) | def as_matrix(self):
method transpose (line 515) | def transpose(self):
method in_structure (line 518) | def in_structure(self):
method out_structure (line 521) | def out_structure(self):
class _NoAuxIn (line 525) | class _NoAuxIn(eqx.Module):
method __call__ (line 529) | def __call__(self, x):
class _Unwrap (line 533) | class _Unwrap(eqx.Module):
method __call__ (line 536) | def __call__(self, x):
class JacobianLinearOperator (line 541) | class JacobianLinearOperator(AbstractLinearOperator):
method __init__ (line 575) | def __init__(
method mv (line 622) | def mv(self, vector):
method as_matrix (line 643) | def as_matrix(self):
method transpose (line 646) | def transpose(self):
method in_structure (line 657) | def in_structure(self):
method out_structure (line 660) | def out_structure(self):
class FunctionLinearOperator (line 666) | class FunctionLinearOperator(AbstractLinearOperator):
method __init__ (line 679) | def __init__(
method mv (line 707) | def mv(self, vector):
method as_matrix (line 710) | def as_matrix(self):
method transpose (line 713) | def transpose(self):
method in_structure (line 727) | def in_structure(self):
method out_structure (line 731) | def out_structure(self):
class IdentityLinearOperator (line 738) | class IdentityLinearOperator(AbstractLinearOperator):
method __init__ (line 746) | def __init__(
method mv (line 769) | def mv(self, vector):
method as_matrix (line 802) | def as_matrix(self):
method transpose (line 812) | def transpose(self):
method in_structure (line 815) | def in_structure(self):
method out_structure (line 819) | def out_structure(self):
method tags (line 824) | def tags(self):
class TridiagonalLinearOperator (line 828) | class TridiagonalLinearOperator(AbstractLinearOperator):
method __init__ (line 837) | def __init__(
method mv (line 863) | def mv(self, vector):
method as_matrix (line 869) | def as_matrix(self):
method transpose (line 878) | def transpose(self):
method in_structure (line 883) | def in_structure(self):
method out_structure (line 887) | def out_structure(self):
class TaggedLinearOperator (line 892) | class TaggedLinearOperator(AbstractLinearOperator):
method __init__ (line 915) | def __init__(
method mv (line 929) | def mv(self, vector):
method as_matrix (line 932) | def as_matrix(self):
method transpose (line 935) | def transpose(self):
method in_structure (line 940) | def in_structure(self):
method out_structure (line 943) | def out_structure(self):
function _is_none (line 952) | def _is_none(x):
class TangentLinearOperator (line 956) | class TangentLinearOperator(AbstractLinearOperator):
method __check_init__ (line 964) | def __check_init__(self):
method mv (line 967) | def mv(self, vector):
method as_matrix (line 972) | def as_matrix(self):
method transpose (line 977) | def transpose(self):
method in_structure (line 984) | def in_structure(self):
method out_structure (line 987) | def out_structure(self):
class AddLinearOperator (line 991) | class AddLinearOperator(AbstractLinearOperator):
method __check_init__ (line 1006) | def __check_init__(self):
method mv (line 1012) | def mv(self, vector):
method as_matrix (line 1020) | def as_matrix(self):
method transpose (line 1023) | def transpose(self):
method in_structure (line 1026) | def in_structure(self):
method out_structure (line 1029) | def out_structure(self):
class MulLinearOperator (line 1033) | class MulLinearOperator(AbstractLinearOperator):
method mv (line 1048) | def mv(self, vector):
method as_matrix (line 1051) | def as_matrix(self):
method transpose (line 1054) | def transpose(self):
method in_structure (line 1057) | def in_structure(self):
method out_structure (line 1060) | def out_structure(self):
class NegLinearOperator (line 1066) | class NegLinearOperator(AbstractLinearOperator):
method mv (line 1079) | def mv(self, vector):
method as_matrix (line 1082) | def as_matrix(self):
method transpose (line 1085) | def transpose(self):
method in_structure (line 1088) | def in_structure(self):
method out_structure (line 1091) | def out_structure(self):
class DivLinearOperator (line 1095) | class DivLinearOperator(AbstractLinearOperator):
method mv (line 1110) | def mv(self, vector):
method as_matrix (line 1114) | def as_matrix(self):
method transpose (line 1117) | def transpose(self):
method in_structure (line 1120) | def in_structure(self):
method out_structure (line 1123) | def out_structure(self):
class ComposedLinearOperator (line 1127) | class ComposedLinearOperator(AbstractLinearOperator):
method __check_init__ (line 1145) | def __check_init__(self):
method mv (line 1149) | def mv(self, vector):
method as_matrix (line 1155) | def as_matrix(self):
method transpose (line 1170) | def transpose(self):
method in_structure (line 1173) | def in_structure(self):
method out_structure (line 1176) | def out_structure(self):
function _default_not_implemented (line 1197) | def _default_not_implemented(name: str, operator: AbstractLinearOperator...
function linearise (line 1209) | def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
function _ (line 1238) | def _(operator):
function _ (line 1243) | def _(operator):
function materialise (line 1266) | def materialise(operator: AbstractLinearOperator) -> AbstractLinearOpera...
function _try_sparse_materialise (line 1320) | def _try_sparse_materialise(operator: AbstractLinearOperator) -> Abstrac...
function _ (line 1344) | def _(operator):
function _ (line 1351) | def _(operator):
function _ (line 1356) | def _(operator):
function _ (line 1372) | def _(operator):
function diagonal (line 1397) | def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:
function _leaf_from_keypath (line 1416) | def _leaf_from_keypath(pytree: PyTree, keypath: jtu.KeyPath) -> Array:
function _ (line 1425) | def _(operator):
function _ (line 1430) | def _(operator):
function _ (line 1447) | def _(operator):
function _ (line 1460) | def _(operator):
function _ (line 1466) | def _(operator):
function _ (line 1471) | def _(operator):
function tridiagonal (line 1479) | def tridiagonal(
function _ (line 1512) | def _(operator):
function _ (line 1523) | def _(operator):
function _ (line 1566) | def _(operator):
function _ (line 1574) | def _(operator):
function _ (line 1582) | def _(operator):
function is_symmetric (line 1590) | def is_symmetric(operator: AbstractLinearOperator) -> bool:
function _has_real_dtype (line 1607) | def _has_real_dtype(operator) -> bool:
function _ (line 1625) | def _(operator):
function _ (line 1639) | def _(operator):
function _ (line 1644) | def _(operator):
function _ (line 1649) | def _(operator):
function is_diagonal (line 1657) | def is_diagonal(operator: AbstractLinearOperator) -> bool:
function _ (line 1678) | def _(operator):
function _ (line 1686) | def _(operator):
function _ (line 1691) | def _(operator):
function is_tridiagonal (line 1699) | def is_tridiagonal(operator: AbstractLinearOperator) -> bool:
function _ (line 1720) | def _(operator):
function _ (line 1727) | def _(operator):
function has_unit_diagonal (line 1735) | def has_unit_diagonal(operator: AbstractLinearOperator) -> bool:
function _ (line 1756) | def _(operator):
function _ (line 1761) | def _(operator):
function _ (line 1767) | def _(operator):
function is_lower_triangular (line 1776) | def is_lower_triangular(operator: AbstractLinearOperator) -> bool:
function _ (line 1797) | def _(operator):
function _ (line 1803) | def _(operator):
function _ (line 1808) | def _(operator):
function is_upper_triangular (line 1816) | def is_upper_triangular(operator: AbstractLinearOperator) -> bool:
function _ (line 1837) | def _(operator):
function _ (line 1843) | def _(operator):
function _ (line 1848) | def _(operator):
function is_positive_semidefinite (line 1856) | def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool:
function _ (line 1877) | def _(operator):
function _ (line 1882) | def _(operator):
function _ (line 1888) | def _(operator):
function is_negative_semidefinite (line 1897) | def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool:
function _ (line 1918) | def _(operator):
function _ (line 1923) | def _(operator):
function _ (line 1929) | def _(operator):
function _ (line 1938) | def _(operator):
function _ (line 1943) | def _(operator):
function _ (line 1948) | def _(operator):
function _ (line 1953) | def _(operator):
function _ (line 1960) | def _(operator, transform=transform):
function _ (line 1964) | def _(operator, transform=transform):
function _ (line 1968) | def _(operator, transform=transform):
function _ (line 1975) | def _(operator, transform=transform):
function _ (line 1980) | def _(operator):
function _ (line 1988) | def _(operator):
function _ (line 1996) | def _(operator):
function _ (line 2004) | def _(operator):
function _ (line 2013) | def _(operator):
function _ (line 2022) | def _(operator):
function _ (line 2029) | def _(operator):
function _ (line 2035) | def _(operator):
function _ (line 2041) | def _(operator):
function _ (line 2047) | def _(operator):
function _ (line 2052) | def _(operator):
function _ (line 2064) | def _(operator):
function _ (line 2071) | def _(operator):
function _ (line 2102) | def _(operator, check=check):
function _ (line 2118) | def _(operator, check=check):
function _ (line 2126) | def _(operator):
class _ScalarSign (line 2130) | class _ScalarSign(enum.Enum):
function _scalar_sign (line 2137) | def _scalar_sign(scalar) -> _ScalarSign:
function _ (line 2154) | def _(operator):
function _ (line 2166) | def _(operator):
function _ (line 2180) | def _(operator):
function _ (line 2190) | def _(operator):
function _ (line 2201) | def _(operator):
function _ (line 2206) | def _(operator):
function _ (line 2222) | def _(operator, check=check, tag=tag):
function _ (line 2237) | def _(operator, check=check):
function _ (line 2242) | def _(operator):
function _ (line 2254) | def _(operator, check=check):
function _ (line 2260) | def _(operator):
function _ (line 2267) | def _(operator):
function _ (line 2278) | def _(operator):
function _ (line 2283) | def _(operator):
function conj (line 2296) | def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator:
function _ (line 2311) | def _(operator):
function _ (line 2316) | def _(operator):
function _ (line 2322) | def _(operator):
function _ (line 2328) | def _(operator):
function _ (line 2333) | def _(operator):
function _ (line 2342) | def _(operator):
function _ (line 2347) | def _(operator):
function _ (line 2356) | def _(operator):
function _ (line 2361) | def _(operator):
function _ (line 2368) | def _(operator):
function _ (line 2373) | def _(operator):
function _ (line 2378) | def _(operator):
function _ (line 2383) | def _(operator):
function _ (line 2388) | def _(operator):
FILE: lineax/_solution.py
class RESULTS (line 52) | class RESULTS(eqxi.Enumeration):
class Solution (line 71) | class Solution(eqx.Module):
FILE: lineax/_solve.py
function _to_shapedarray (line 66) | def _to_shapedarray(x):
function _to_struct (line 73) | def _to_struct(x):
function _assert_false (line 85) | def _assert_false(x):
function _is_none (line 89) | def _is_none(x):
function _sum (line 93) | def _sum(*args):
function _linear_solve_impl (line 97) | def _linear_solve_impl(_, state, vector, options, solver, throw, *, chec...
function _linear_solve_abstract_eval (line 133) | def _linear_solve_abstract_eval(operator, state, vector, options, solver...
function _linear_solve_jvp (line 152) | def _linear_solve_jvp(primals, tangents):
function _is_undefined (line 271) | def _is_undefined(x):
function _assert_defined (line 275) | def _assert_defined(x):
function _keep_undefined (line 279) | def _keep_undefined(v, ct):
function _linear_solve_transpose (line 287) | def _linear_solve_transpose(inputs, cts_out):
class AbstractLinearSolver (line 343) | class AbstractLinearSolver(eqx.Module, Generic[_SolverState]):
method init (line 347) | def init(
method compute (line 384) | def compute(
method transpose (line 409) | def transpose(
method conj (line 440) | def conj(
method assume_full_rank (line 466) | def assume_full_rank(self) -> bool:
function _lookup (line 496) | def _lookup(token) -> AbstractLinearSolver:
class AutoLinearSolver (line 518) | class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]):
method _select_solver (line 555) | def _select_solver(self, operator: AbstractLinearOperator):
method select_solver (line 602) | def select_solver(self, operator: AbstractLinearOperator) -> AbstractL...
method init (line 615) | def init(self, operator, options) -> _AutoLinearSolverState:
method compute (line 619) | def compute(
method transpose (line 630) | def transpose(self, state: _AutoLinearSolverState, options: dict[str, ...
method conj (line 637) | def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]):
method assume_full_rank (line 644) | def assume_full_rank(self):
function linear_solve (line 657) | def linear_solve(
function invert (line 809) | def invert(
function stop_gradient_transpose (line 876) | def stop_gradient_transpose(ct, x):
FILE: lineax/_solver/bicgstab.py
class BiCGStab (line 35) | class BiCGStab(AbstractLinearSolver[_BiCGStabState]):
method __check_init__ (line 57) | def __check_init__(self):
method init (line 70) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 78) | def compute(
method transpose (line 207) | def transpose(self, state: _BiCGStabState, options: dict[str, Any]):
method conj (line 214) | def conj(self, state: _BiCGStabState, options: dict[str, Any]):
method assume_full_rank (line 221) | def assume_full_rank(self):
FILE: lineax/_solver/cg.py
class CG (line 48) | class CG(AbstractLinearSolver[_CGState]):
method __check_init__ (line 75) | def __check_init__(self):
method init (line 88) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 114) | def compute(
method transpose (line 229) | def transpose(
method conj (line 239) | def conj(
method assume_full_rank (line 249) | def assume_full_rank(self):
function NormalCG (line 271) | def NormalCG(*args, **kwargs):
FILE: lineax/_solver/cholesky.py
class Cholesky (line 34) | class Cholesky(AbstractLinearSolver[_CholeskyState]):
method init (line 43) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 65) | def compute(
method transpose (line 80) | def transpose(
method conj (line 87) | def conj(
method assume_full_rank (line 94) | def assume_full_rank(self):
FILE: lineax/_solver/diagonal.py
class Diagonal (line 36) | class Diagonal(AbstractLinearSolver[_DiagonalState]):
method init (line 48) | def init(
method compute (line 66) | def compute(
method transpose (line 85) | def transpose(self, state: _DiagonalState, options: dict[str, Any]):
method conj (line 93) | def conj(self, state: _DiagonalState, options: dict[str, Any]):
method assume_full_rank (line 104) | def assume_full_rank(self):
FILE: lineax/_solver/gmres.py
class GMRES (line 39) | class GMRES(AbstractLinearSolver[_GMRESState]):
method __check_init__ (line 65) | def __check_init__(self):
method init (line 78) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 106) | def compute(
method _gmres_compute (line 242) | def _gmres_compute(
method _arnoldi_gram_schmidt (line 331) | def _arnoldi_gram_schmidt(
method _normalise (line 401) | def _normalise(
method transpose (line 415) | def transpose(self, state: _GMRESState, options: dict[str, Any]):
method conj (line 422) | def conj(self, state: _GMRESState, options: dict[str, Any]):
method assume_full_rank (line 429) | def assume_full_rank(self):
FILE: lineax/_solver/lsmr.py
class LSMR (line 55) | class LSMR(AbstractLinearSolver[_LSMRState]):
method __check_init__ (line 76) | def __check_init__(self):
method init (line 91) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 94) | def compute(
method _givens (line 361) | def _givens(self, a, b):
method transpose (line 411) | def transpose(self, state: _LSMRState, options: dict[str, Any]):
method conj (line 417) | def conj(self, state: _LSMRState, options: dict[str, Any]):
method assume_full_rank (line 423) | def assume_full_rank(self):
FILE: lineax/_solver/lu.py
class LU (line 37) | class LU(AbstractLinearSolver[_LUState]):
method init (line 43) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 56) | def compute(
method transpose (line 68) | def transpose(
method conj (line 83) | def conj(
method assume_full_rank (line 97) | def assume_full_rank(self):
FILE: lineax/_solver/misc.py
function preconditioner_and_y0 (line 30) | def preconditioner_and_y0(
function pack_structures (line 70) | def pack_structures(operator: AbstractLinearOperator) -> PackedStructures:
function ravel_vector (line 79) | def ravel_vector(
function unravel_solution (line 93) | def unravel_solution(
function transpose_packed_structures (line 108) | def transpose_packed_structures(
FILE: lineax/_solver/normal.py
function normal_preconditioner_and_y0 (line 31) | def normal_preconditioner_and_y0(options: dict[str, Any], tall: bool):
class Normal (line 53) | class Normal(
method init (line 105) | def init(self, operator, options):
method compute (line 122) | def compute(
method transpose (line 142) | def transpose(
method conj (line 161) | def conj(
method assume_full_rank (line 180) | def assume_full_rank(self):
FILE: lineax/_solver/qr.py
class QR (line 37) | class QR(AbstractLinearSolver):
method init (line 55) | def init(self, operator, options):
method compute (line 67) | def compute(
method transpose (line 96) | def transpose(self, state: _QRState, options: dict[str, Any]):
method conj (line 107) | def conj(self, state: _QRState, options: dict[str, Any]):
method assume_full_rank (line 117) | def assume_full_rank(self):
FILE: lineax/_solver/svd.py
class SVD (line 38) | class SVD(AbstractLinearSolver[_SVDState]):
method init (line 49) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 55) | def compute(
method transpose (line 80) | def transpose(self, state: _SVDState, options: dict[str, Any]):
method conj (line 88) | def conj(self, state: _SVDState, options: dict[str, Any]):
method assume_full_rank (line 95) | def assume_full_rank(self):
FILE: lineax/_solver/triangular.py
class Triangular (line 43) | class Triangular(AbstractLinearSolver[_TriangularState]):
method init (line 49) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 68) | def compute(
method transpose (line 87) | def transpose(self, state: _TriangularState, options: dict[str, Any]):
method conj (line 101) | def conj(self, state: _TriangularState, options: dict[str, Any]):
method assume_full_rank (line 114) | def assume_full_rank(self):
FILE: lineax/_solver/tridiagonal.py
class Tridiagonal (line 36) | class Tridiagonal(AbstractLinearSolver[_TridiagonalState]):
method init (line 41) | def init(self, operator: AbstractLinearOperator, options: dict[str, An...
method compute (line 54) | def compute(
method transpose (line 74) | def transpose(self, state: _TridiagonalState, options: dict[str, Any]):
method conj (line 81) | def conj(self, state: _TridiagonalState, options: dict[str, Any]):
method assume_full_rank (line 87) | def assume_full_rank(self):
FILE: lineax/_tags.py
class _HasRepr (line 16) | class _HasRepr:
method __init__ (line 17) | def __init__(self, string: str):
method __repr__ (line 20) | def __repr__(self):
function _ (line 47) | def _(tags: frozenset[object], tag=tag):
function _ (line 53) | def _(tags: frozenset[object]):
function _ (line 59) | def _(tags: frozenset[object]):
function transpose_tags (line 64) | def transpose_tags(tags: frozenset[object]):
FILE: tests/conftest.py
function getkey (line 26) | def getkey():
FILE: tests/helpers.py
function _construct_matrix_impl (line 31) | def _construct_matrix_impl(
function construct_matrix (line 75) | def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.f...
function construct_singular_matrix (line 86) | def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.flo...
function construct_poisson_matrix (line 101) | def construct_poisson_matrix(size, dtype=jnp.float64):
function _transpose (line 142) | def _transpose(operator, matrix):
function _linearise (line 146) | def _linearise(operator, matrix):
function _materialise (line 150) | def _materialise(operator, matrix):
function params (line 157) | def params(only_pseudo):
function tree_allclose (line 177) | def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
function has_tag (line 181) | def has_tag(tags, tag):
function _operators_append (line 188) | def _operators_append(x):
function make_matrix_operator (line 194) | def make_matrix_operator(getkey, matrix, tags):
function make_trivial_pytree_operator (line 199) | def make_trivial_pytree_operator(getkey, matrix, tags):
function make_function_operator (line 206) | def make_function_operator(getkey, matrix, tags):
function make_jac_operator (line 214) | def make_jac_operator(getkey, matrix, tags):
function make_jacfwd_operator (line 228) | def make_jacfwd_operator(getkey, matrix, tags):
function make_jacrev_operator (line 242) | def make_jacrev_operator(getkey, matrix, tags):
function make_trivial_diagonal_operator (line 278) | def make_trivial_diagonal_operator(getkey, matrix, tags):
function make_identity_operator (line 285) | def make_identity_operator(getkey, matrix, tags):
function make_tridiagonal_operator (line 291) | def make_tridiagonal_operator(getkey, matrix, tags):
function make_add_operator (line 312) | def make_add_operator(getkey, matrix, tags):
function make_mul_operator (line 322) | def make_mul_operator(getkey, matrix, tags):
function make_composed_operator (line 328) | def make_composed_operator(getkey, matrix, tags):
function finite_difference_jvp (line 342) | def finite_difference_jvp(fn, primals, tangents):
function jvp_jvp_impl (line 355) | def jvp_jvp_impl(
FILE: tests/test_adjoint.py
function test_adjoint (line 20) | def test_adjoint(make_operator, dtype, getkey):
function test_functional_pytree_adjoint (line 58) | def test_functional_pytree_adjoint():
function test_functional_pytree_adjoint_complex (line 68) | def test_functional_pytree_adjoint_complex():
function test_preconditioner_adjoint (line 95) | def test_preconditioner_adjoint(solver):
FILE: tests/test_invert.py
function _well_conditioned_matrix (line 24) | def _well_conditioned_matrix(getkey, size=3, dtype=jnp.float64):
function _well_conditioned_psd_matrix (line 32) | def _well_conditioned_psd_matrix(getkey, size=3, dtype=jnp.float64):
function test_mv (line 41) | def test_mv(getkey):
function test_composition_identity (line 52) | def test_composition_identity(getkey):
function test_double_inverse (line 63) | def test_double_inverse(getkey):
function test_pseudoinverse_overdetermined (line 77) | def test_pseudoinverse_overdetermined(getkey):
function test_pseudoinverse_underdetermined (line 88) | def test_pseudoinverse_underdetermined(getkey):
function test_solver_cholesky (line 102) | def test_solver_cholesky(getkey):
function test_solver_cg (line 113) | def test_solver_cg(getkey):
function test_vmap (line 127) | def test_vmap(getkey):
function test_grad_wrt_vector (line 141) | def test_grad_wrt_vector(getkey):
function test_jvp_wrt_vector (line 156) | def test_jvp_wrt_vector(getkey):
function test_grad_wrt_operator (line 172) | def test_grad_wrt_operator(getkey):
function test_jvp_wrt_operator (line 190) | def test_jvp_wrt_operator(getkey):
FILE: tests/test_jvp.py
function test_jvp (line 46) | def test_jvp(
FILE: tests/test_jvp_jvp1.py
function _clear_cache (line 31) | def _clear_cache():
function test_jvp_jvp (line 40) | def test_jvp_jvp(
FILE: tests/test_jvp_jvp2.py
function _clear_cache (line 31) | def _clear_cache():
function test_jvp_jvp (line 40) | def test_jvp_jvp(
FILE: tests/test_lsmr.py
function test_ill_conditioned (line 13) | def test_ill_conditioned():
function test_zero_rhs (line 20) | def test_zero_rhs():
function test_damp_regularizes (line 34) | def test_damp_regularizes():
function test_damp (line 47) | def test_damp():
FILE: tests/test_misc.py
function test_inexact_asarray_no_copy (line 22) | def test_inexact_asarray_no_copy():
function test_inexact_asarray_jvp (line 30) | def test_inexact_asarray_jvp():
function test_zero_matrix (line 37) | def test_zero_matrix(dtype):
FILE: tests/test_norm.py
function _square (line 23) | def _square(x):
function _two_norm (line 27) | def _two_norm(x):
function _rms_norm (line 31) | def _rms_norm(x):
function _max_norm (line 35) | def _max_norm(x):
function test_nonzero (line 39) | def test_nonzero():
function test_zero (line 72) | def test_zero():
function test_complex (line 85) | def test_complex():
function test_size_zero (line 102) | def test_size_zero():
FILE: tests/test_operator.py
function test_ops (line 36) | def test_ops(make_operator, getkey, dtype):
function test_structures_vector (line 87) | def test_structures_vector(make_operator, getkey):
function _setup (line 111) | def _setup(getkey, matrix, tag: object | frozenset[object] = frozenset()):
function _assert_except_diag (line 131) | def _assert_except_diag(cond_fun, operators, flip_cond):
function test_linearise (line 143) | def test_linearise(dtype, getkey):
function test_materialise (line 163) | def test_materialise(dtype, getkey):
function test_materialise_large (line 170) | def test_materialise_large(dtype, getkey):
function test_diagonal (line 177) | def test_diagonal(dtype, getkey):
function test_tridiagonal (line 194) | def test_tridiagonal(dtype, getkey):
function test_is_symmetric (line 239) | def test_is_symmetric(dtype, getkey):
function test_is_diagonal (line 250) | def test_is_diagonal(dtype, getkey):
function test_is_diagonal_scalar (line 261) | def test_is_diagonal_scalar(dtype, getkey):
function test_is_diagonal_tridiagonal (line 269) | def test_is_diagonal_tridiagonal(dtype, getkey):
function test_has_unit_diagonal (line 277) | def test_has_unit_diagonal(dtype, getkey):
function test_is_lower_triangular (line 289) | def test_is_lower_triangular(dtype, getkey):
function test_is_upper_triangular (line 300) | def test_is_upper_triangular(dtype, getkey):
function test_is_positive_semidefinite (line 311) | def test_is_positive_semidefinite(dtype, getkey):
function test_is_negative_semidefinite (line 326) | def test_is_negative_semidefinite(dtype, getkey):
function test_is_tridiagonal (line 341) | def test_is_tridiagonal(dtype, getkey):
function test_tangent_as_matrix (line 354) | def test_tangent_as_matrix(dtype, getkey):
function test_materialise_function_linear_operator (line 377) | def test_materialise_function_linear_operator(dtype, getkey):
function test_pytree_transpose (line 400) | def test_pytree_transpose(dtype, getkey):
function test_diagonal_tangent (line 420) | def test_diagonal_tangent():
function test_identity_with_different_structures (line 432) | def test_identity_with_different_structures():
function test_identity_with_different_structures_complex (line 461) | def test_identity_with_different_structures_complex():
function test_zero_pytree_as_matrix (line 491) | def test_zero_pytree_as_matrix(dtype):
function test_jacrev_operator (line 498) | def test_jacrev_operator():
FILE: tests/test_singular.py
function test_small_singular (line 39) | def test_small_singular(make_operator, solver, tags, ops, getkey, dtype):
function test_bicgstab_breakdown (line 57) | def test_bicgstab_breakdown(getkey, dtype):
function test_gmres_stagnation_or_breakdown (line 76) | def test_gmres_stagnation_or_breakdown(getkey, dtype):
function test_nonsquare_pytree_operator1 (line 113) | def test_nonsquare_pytree_operator1(solver):
function test_nonsquare_pytree_operator2 (line 136) | def test_nonsquare_pytree_operator2(solver):
function test_nonsquare_mat_vec (line 162) | def test_nonsquare_mat_vec(solver, full_rank, jvp, wide, dtype, getkey):
function test_nonsquare_vec (line 213) | def test_nonsquare_vec(solver, full_rank, jvp, wide, dtype, getkey):
function test_iterative_singular (line 256) | def test_iterative_singular(getkey, solver, tags, use_state, make_operat...
FILE: tests/test_solve.py
function test_gmres_large_dense (line 24) | def test_gmres_large_dense(getkey):
function test_nontrivial_pytree_operator (line 41) | def test_nontrivial_pytree_operator():
function test_nontrivial_diagonal_operator (line 51) | def test_nontrivial_diagonal_operator():
function test_mixed_dtypes (line 65) | def test_mixed_dtypes(solver):
function test_mixed_dtypes_complex (line 78) | def test_mixed_dtypes_complex(solver):
function test_mixed_dtypes_complex_real (line 91) | def test_mixed_dtypes_complex_real(solver):
function test_mixed_dtypes_triangular (line 103) | def test_mixed_dtypes_triangular():
function test_mixed_dtypes_complex_triangular (line 115) | def test_mixed_dtypes_complex_triangular():
function test_mixed_dtypes_complex_real_triangular (line 127) | def test_mixed_dtypes_complex_real_triangular():
function test_ad_closure_function_linear_operator (line 139) | def test_ad_closure_function_linear_operator(getkey):
function test_grad_vmap_symbolic_cotangent (line 156) | def test_grad_vmap_symbolic_cotangent():
function test_iterative_solver_max_steps_only (line 184) | def test_iterative_solver_max_steps_only(solver):
function test_solver_init_not_differentiated (line 197) | def test_solver_init_not_differentiated(getkey):
function test_nonfinite_input (line 249) | def test_nonfinite_input():
FILE: tests/test_transpose.py
class TestTranspose (line 25) | class TestTranspose:
method assert_transpose_fixture (line 27) | def assert_transpose_fixture(_):
method test_transpose (line 43) | def test_transpose(
method test_pytree_transpose (line 54) | def test_pytree_transpose(_, assert_transpose_fixture): # pyright: ig...
FILE: tests/test_vmap.py
function test_vmap (line 43) | def test_vmap(
function test_grad_vmap_basic (line 93) | def test_grad_vmap_basic(getkey):
function test_grad_vmap_advanced (line 110) | def test_grad_vmap_advanced(getkey):
FILE: tests/test_vmap_jvp.py
function test_vmap_jvp (line 45) | def test_vmap_jvp(
FILE: tests/test_vmap_vmap.py
function test_vmap_vmap (line 44) | def test_vmap_vmap(
FILE: tests/test_well_posed.py
function test_small_wellposed (line 34) | def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
function test_pytree_wellposed (line 57) | def test_pytree_wellposed(solver, getkey, dtype):
Condensed preview: 75 files, each showing path, character count, and a content snippet (the full structured content is 448K chars).
[
{
"path": ".github/workflows/build_docs.yml",
"chars": 737,
"preview": "name: Build docs\n\non:\n push:\n branches:\n - main\n\njobs:\n build:\n strategy:\n matrix:\n python-vers"
},
{
"path": ".github/workflows/release.yml",
"chars": 788,
"preview": "name: Release\n\non:\n push:\n branches:\n - main\n\njobs:\n build:\n runs-on: ubuntu-latest\n steps:\n - name"
},
{
"path": ".github/workflows/run_tests.yml",
"chars": 774,
"preview": "name: Run tests\n\non:\n pull_request:\n\njobs:\n run-test:\n strategy:\n matrix:\n python-version: [ 3.11 ]\n "
},
{
"path": ".gitignore",
"chars": 128,
"preview": "**/__pycache__\n**/.ipynb_checkpoints\n*.egg-info/\nbuild/\ndist/\nsite/\nexamples/data\n.all_objects.cache\n.pymon\n.idea\n.venv\n"
},
{
"path": ".pre-commit-config.yaml",
"chars": 1051,
"preview": "fail_fast: true\nrepos:\n - repo: meta\n hooks:\n - id: check-hooks-apply\n - id: check-useless-excludes\n - repo: "
},
{
"path": "CONTRIBUTING.md",
"chars": 1012,
"preview": "# Contributing\n\nContributions (pull requests) are very welcome! Here's how to get started.\n\n---\n\n### Getting started\n\n[W"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 3824,
"preview": "<h1 align='center'>Lineax</h1>\n\nLineax is a [JAX](https://github.com/google/jax) library for linear solves and linear le"
},
{
"path": "benchmarks/gmres_fails_safely.py",
"chars": 2654,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "benchmarks/lstsq_gradients.py",
"chars": 1423,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "benchmarks/solver_speeds.py",
"chars": 6749,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "docs/.htaccess",
"chars": 38,
"preview": "ErrorDocument 404 /jaxtyping/404.html\n"
},
{
"path": "docs/_overrides/partials/source.html",
"chars": 880,
"preview": "{% import \"partials/language.html\" as lang with context %}\n<a href=\"{{ config.repo_url }}\" title=\"{{ lang.t('source.link"
},
{
"path": "docs/_static/custom_css.css",
"chars": 3162,
"preview": "/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */\nhtml {\n scroll-padding-top: 50px;"
},
{
"path": "docs/_static/mathjax.js",
"chars": 300,
"preview": "window.MathJax = {\n tex: {\n inlineMath: [[\"\\\\(\", \"\\\\)\"]],\n displayMath: [[\"\\\\[\", \"\\\\]\"]],\n processEscapes: tru"
},
{
"path": "docs/api/functions.md",
"chars": 996,
"preview": "# Functions on linear operators\n\nWe define a number of functions on [linear operators](./operators.md).\n\n## Computationa"
},
{
"path": "docs/api/linear_solve.md",
"chars": 207,
"preview": "# linear_solve\n\nThis is the main entry point.\n\n::: lineax.linear_solve\n\n## invert\n\nA convenience function for obtaining "
},
{
"path": "docs/api/operators.md",
"chars": 2054,
"preview": "# Linear operators\n\nWe often talk about solving a linear system $Ax = b$, where $A \\in \\mathbb{R}^{n \\times m}$ is a mat"
},
{
"path": "docs/api/solution.md",
"chars": 123,
"preview": "# Solution\n\n::: lineax.Solution\n options:\n members: []\n\n---\n\n::: lineax.RESULTS\n options:\n members: "
},
{
"path": "docs/api/solvers.md",
"chars": 2588,
"preview": "# Solvers\n\nIf you're not sure what to use, then pick [`lineax.AutoLinearSolver`][] and it will automatically dispatch to"
},
{
"path": "docs/api/tags.md",
"chars": 4484,
"preview": "# Tags\n\nLineax offers a way to \"tag\" linear operators as exhibiting certain properties, e.g. that they are positive semi"
},
{
"path": "docs/examples/classical_solve.ipynb",
"chars": 1799,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"8d41e1dd-93da-4e81-bd4a-33e5df8915f1\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/examples/complex_solve.ipynb",
"chars": 2056,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"8d41e1dd-93da-4e81-bd4a-33e5df8915f1\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/examples/least_squares.ipynb",
"chars": 13976,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"44bff903-0e4d-4f3e-a75c-d3cfe8ab4dea\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/examples/no_materialisation.ipynb",
"chars": 4196,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"a7299095-8906-4867-82ef-d6b84b161366\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/examples/operators.ipynb",
"chars": 9500,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"2fe0b1e4-35cb-4c39-b324-65253aab005a\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/examples/structured_matrices.ipynb",
"chars": 6845,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"e2573d62-a505-4998-8796-b0f1bc889433\",\n \"metadata\": {},\n \"so"
},
{
"path": "docs/faq.md",
"chars": 2030,
"preview": "# FAQ\n\n## How does this differ from `jax.numpy.solve`, `jax.scipy.{...}` etc.?\n\nLineax offers several improvements. Most"
},
{
"path": "docs/index.md",
"chars": 3276,
"preview": "# Getting started\n\nLineax is a [JAX](https://github.com/google/jax) library for linear solves and linear least squares. "
},
{
"path": "lineax/__init__.py",
"chars": 3006,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_custom_types.py",
"chars": 686,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_misc.py",
"chars": 3183,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_norm.py",
"chars": 4952,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_operator.py",
"chars": 78519,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solution.py",
"chars": 3112,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solve.py",
"chars": 30945,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/__init__.py",
"chars": 1033,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/bicgstab.py",
"chars": 8627,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/cg.py",
"chars": 11056,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/cholesky.py",
"chars": 3407,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/diagonal.py",
"chars": 4233,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/gmres.py",
"chars": 19201,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/lsmr.py",
"chars": 16637,
"preview": "\"\"\"Implementation adapted from SciPy, with BSD license:\n\nCopyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Develo"
},
{
"path": "lineax/_solver/lu.py",
"chars": 3300,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/misc.py",
"chars": 4422,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/normal.py",
"chars": 7005,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/qr.py",
"chars": 4232,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/svd.py",
"chars": 3662,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/triangular.py",
"chars": 3918,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_solver/tridiagonal.py",
"chars": 3374,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/_tags.py",
"chars": 2410,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "lineax/internal/__init__.py",
"chars": 1174,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "mkdocs.yml",
"chars": 3918,
"preview": "theme:\n name: material\n features:\n - navigation.sections # Sections are included in the navigation on the "
},
{
"path": "pyproject.toml",
"chars": 2459,
"preview": "[build-system]\nbuild-backend = \"hatchling.build\"\nrequires = [\"hatchling\"]\n\n[dependency-groups]\ndev = [\n \"prek==0.3.9\",\n"
},
{
"path": "tests/README.md",
"chars": 137,
"preview": "Each file is run separately to avoid JAX out-of-memory'ing.\n\nAs such, run tests using `python -m tests`, *not* by just r"
},
{
"path": "tests/__init__.py",
"chars": 575,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/__main__.py",
"chars": 964,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/conftest.py",
"chars": 846,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/helpers.py",
"chars": 20442,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_adjoint.py",
"chars": 4003,
"preview": "import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport lineax as lx\nimport pytest\nfrom lineax import Function"
},
{
"path": "tests/test_invert.py",
"chars": 7025,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_jvp.py",
"chars": 4505,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_jvp_jvp1.py",
"chars": 1588,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_jvp_jvp2.py",
"chars": 1591,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_lsmr.py",
"chars": 1996,
"preview": "import equinox as ex\nimport jax.numpy as jnp\nimport lineax as lx\nimport pytest\n\n\nsolver = lx.LSMR(1e-10, 1e-10)\nAill = l"
},
{
"path": "tests/test_misc.py",
"chars": 1285,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_norm.py",
"chars": 3827,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_operator.py",
"chars": 21197,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_singular.py",
"chars": 9444,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_solve.py",
"chars": 8761,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_transpose.py",
"chars": 2532,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_vmap.py",
"chars": 4187,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_vmap_jvp.py",
"chars": 5381,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_vmap_vmap.py",
"chars": 5106,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "tests/test_well_posed.py",
"chars": 3092,
"preview": "# Copyright 2023 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
}
]
About this extraction
This page contains the full source code of the patrick-kidger/lineax GitHub repository, extracted and formatted as plain text. The extraction includes 75 files (416.0 KB), approximately 112.4k tokens, and a symbol index with 525 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.