Repository: StanfordASL/hj_reachability
Branch: main
Commit: 9aebaf974406
Files: 30
Total size: 77.4 KB
Directory structure:
gitextract_6__9fzrr/
├── .github/
│ └── workflows/
│ ├── ci.yml
│ └── pypi-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples/
│ └── quickstart.ipynb
├── hj_reachability/
│ ├── __init__.py
│ ├── artificial_dissipation.py
│ ├── boundary_conditions.py
│ ├── boundary_conditions_test.py
│ ├── dynamics.py
│ ├── finite_differences/
│ │ ├── __init__.py
│ │ ├── upwind_first.py
│ │ └── upwind_first_test.py
│ ├── grid.py
│ ├── grid_test.py
│ ├── sets.py
│ ├── sets_test.py
│ ├── solver.py
│ ├── solver_test.py
│ ├── systems/
│ │ ├── __init__.py
│ │ └── air3d.py
│ ├── time_integration.py
│ ├── utils.py
│ └── utils_test.py
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/ci.yml
================================================
# Continuous integration: lint, format-check, and test on every push and pull request.
name: ci

on: [push, pull_request]

jobs:
  test:
    name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
    runs-on: "${{ matrix.os }}"
    strategy:
      matrix:
        # Versions are quoted so that "3.10" is not parsed as the float 3.1.
        python-version: ["3.9", "3.10", "3.11", "3.12"]
        os: [ubuntu-latest]
    steps:
      # Action versions are kept in step with pypi-publish.yml; the older v3/v4 releases run on the
      # deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "${{ matrix.python-version }}"
      - name: Install dependencies
        run: |
          set -xe
          python -m pip install --upgrade pip
          pip install flake8 pytest pytest-xdist yapf
          pip install -r requirements.txt
          pip install -r requirements-test.txt
      - name: Lint with flake8
        run: |
          set -xe
          flake8 . --config=setup.cfg --count --statistics
      - name: Check formatting with yapf
        run: |
          set -xe
          yapf . --style=setup.cfg --recursive --diff
      - name: Test with pytest
        run: |
          set -xe
          # One pytest-xdist worker per processor listed in /proc/cpuinfo.
          pytest -n "$(grep -c ^processor /proc/cpuinfo)" hj_reachability
================================================
FILE: .github/workflows/pypi-publish.yml
================================================
# Builds the package and uploads it to PyPI whenever a GitHub release is published.
name: pypi

on:
  release:
    types: [published]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build setuptools
      - name: Check consistency between the package version and release tag
        # The tag name is user-controlled text: pass it through the environment instead of interpolating
        # it into the script, so it can never be executed as shell code.
        env:
          RELEASE_TAG: "${{ github.event.release.tag_name }}"
        run: |
          PACKAGE_VER="v$(python setup.py --version)"
          # Both operands are quoted so empty values or values with spaces cannot break the test.
          if [ "$PACKAGE_VER" != "$RELEASE_TAG" ]
          then
            echo "Package version ($PACKAGE_VER) != release tag ($RELEASE_TAG)."; exit 1
          fi
      - name: Build package
        run: python -m build
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Ed Schmerling
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include requirements.txt
include requirements-test.txt
================================================
FILE: README.md
================================================
# hj_reachability: Hamilton-Jacobi reachability analysis in [JAX]
This package implements numerical solvers for Hamilton-Jacobi (HJ) Partial Differential Equations (PDEs) which, in the context of optimal control, may be used to represent the continuous-time formulation of dynamic programming. Specifically, the focus of this package is on reachability analysis for zero-sum differential games modeled as Hamilton-Jacobi-Isaacs (HJI) PDEs, wherein an optimal controller and (optional) disturbance interact, and the set of reachable states at any time is represented as the zero sublevel set of a value function realized as the viscosity solution of the corresponding PDE.
This package is inspired by a number of related projects, including:
- [A Toolbox of Level Set Methods (`toolboxls`, MATLAB)](https://www.cs.ubc.ca/~mitchell/ToolboxLS/)
- [An Optimal Control Toolbox for Hamilton-Jacobi Reachability Analysis (`helperOC`, MATLAB)](https://github.com/HJReachability/helperOC)
- [Berkeley Efficient API in C++ for Level Set methods (`beacls`, C++/CUDA C++)](https://hjreachability.github.io/beacls/)
- [Optimizing Dynamic Programming-Based Algorithms (`optimized_dp`, python)](https://github.com/SFU-MARS/optimized_dp)
## Installation
This package accommodates different [JAX] versions (i.e., CPU-only vs. JAX with GPU support); if accelerator support is desired you should first install JAX according to the relevant [installation instructions](https://github.com/google/jax#installation). A minimum JAX version requirement is listed in [`requirements.txt`](https://github.com/StanfordASL/hj_reachability/blob/main/requirements.txt), but in general this package should be compatible with the latest JAX releases (please [file an issue](https://github.com/StanfordASL/hj_reachability/issues) if you find that this is no longer the case!).
If you only want CPU computation or have already installed JAX with your preferred accelerator support, you may install this package using pip:
```
pip install --upgrade hj-reachability
```
## TODOs
Aside from the specific TODOs scattered throughout the codebase, a few general TODOs:
- Single-line docstrings (at a bare minimum) for everything. Test coverage, book/paper references, and proper documentation to come... eventually.
- Look into using `jax.pmap`/`jax.lax.ppermute` for multi-device parallelism; see, e.g., [JAX demo notebooks](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
- Incorporate neural-network-based PDE solvers; see, e.g., [Bansal, S. and Tomlin, C. "DeepReach: A Deep Learning Approach to High-Dimensional Reachability." (2020)](https://arxiv.org/abs/2011.02082).
[JAX]: https://github.com/google/jax
================================================
FILE: examples/quickstart.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# hj_reachability quickstart\n",
"\n",
"Notebook dependencies:\n",
"- System: python3, ffmpeg (for rendering animations)\n",
"- Python: jupyter, jax, numpy, matplotlib, plotly, tqdm, hj_reachability\n",
"\n",
"Example setup for a Ubuntu system (Mac users, maybe `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):\n",
"```\n",
"sudo apt install ffmpeg\n",
"/usr/bin/python3 -m pip install --upgrade pip\n",
"pip install --upgrade jupyter jax[cpu] numpy matplotlib plotly tqdm hj-reachability\n",
"jupyter notebook # from the directory of this notebook\n",
"```\n",
"Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/hj_reachability/blob/main/examples/quickstart.ipynb) and run a cell containing this command:\n",
"```\n",
"!pip install --upgrade hj-reachability\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"from IPython.display import HTML\n",
"import matplotlib.animation as anim\n",
"import matplotlib.pyplot as plt\n",
"import plotly.graph_objects as go\n",
"\n",
"import hj_reachability as hj"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example system: `Air3d`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dynamics = hj.systems.Air3d()\n",
"grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(np.array([-6., -10., 0.]),\n",
" np.array([20., 10., 2 * np.pi])),\n",
" (51, 40, 50),\n",
" periodic_dims=2)\n",
"values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5\n",
"\n",
"solver_settings = hj.SolverSettings.with_accuracy(\"very_high\",\n",
" hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `hj.step`: propagate the HJ PDE from `(time, values)` to `target_time`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time = 0.\n",
"target_time = -2.8\n",
"target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.jet()\n",
"plt.figure(figsize=(13, 8))\n",
"plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values[:, :, 30].T)\n",
"plt.colorbar()\n",
"plt.contour(grid.coordinate_vectors[0],\n",
" grid.coordinate_vectors[1],\n",
" target_values[:, :, 30].T,\n",
" levels=0,\n",
" colors=\"black\",\n",
" linewidths=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"go.Figure(data=go.Isosurface(x=grid.states[..., 0].ravel(),\n",
" y=grid.states[..., 1].ravel(),\n",
" z=grid.states[..., 2].ravel(),\n",
" value=target_values.ravel(),\n",
" colorscale=\"jet\",\n",
" isomin=0,\n",
" surface_count=1,\n",
" isomax=0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `hj.solve`: solve for `all_values` at a range of `times` (basically just iterating `hj.step`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"times = np.linspace(0, -2.8, 57)\n",
"initial_values = values\n",
"all_values = hj.solve(solver_settings, dynamics, grid, times, initial_values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vmin, vmax = all_values.min(), all_values.max()\n",
"levels = np.linspace(round(vmin), round(vmax), round(vmax) - round(vmin) + 1)\n",
"fig = plt.figure(figsize=(13, 8))\n",
"\n",
"\n",
"def render_frame(i, colorbar=False):\n",
" plt.contourf(grid.coordinate_vectors[0],\n",
" grid.coordinate_vectors[1],\n",
" all_values[i, :, :, 30].T,\n",
" vmin=vmin,\n",
" vmax=vmax,\n",
" levels=levels)\n",
" if colorbar:\n",
" plt.colorbar()\n",
" plt.contour(grid.coordinate_vectors[0],\n",
" grid.coordinate_vectors[1],\n",
" target_values[:, :, 30].T,\n",
" levels=0,\n",
" colors=\"black\",\n",
" linewidths=3)\n",
"\n",
"\n",
"render_frame(0, True)\n",
"animation = HTML(anim.FuncAnimation(fig, render_frame, all_values.shape[0], interval=50).to_html5_video())\n",
"plt.close(); animation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining your own dynamics: `AccelerationCurvatureCar`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AccelerationCurvatureCar(hj.ControlAndDisturbanceAffineDynamics):\n",
"\n",
" def __init__(self,\n",
" max_acceleration=1.,\n",
" max_curvature=1.,\n",
" max_position_disturbance=0.25,\n",
" control_mode=\"min\",\n",
" disturbance_mode=\"max\",\n",
" control_space=None,\n",
" disturbance_space=None):\n",
" if control_space is None:\n",
" control_space = hj.sets.Box(jnp.array([-max_acceleration, -max_curvature]),\n",
" jnp.array([max_acceleration, max_curvature]))\n",
" if disturbance_space is None:\n",
" disturbance_space = hj.sets.Ball(jnp.zeros(2), max_position_disturbance)\n",
" super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)\n",
"\n",
" def open_loop_dynamics(self, state, time):\n",
" _, _, v, q = state\n",
" return jnp.array([v * jnp.cos(q), v * jnp.sin(q), 0., 0.])\n",
"\n",
" def control_jacobian(self, state, time):\n",
" v = state[2]\n",
" return jnp.array([\n",
" [0., 0.],\n",
" [0., 0.],\n",
" [1., 0.],\n",
" [0., v],\n",
" ])\n",
"\n",
" def disturbance_jacobian(self, state, time):\n",
" return jnp.array([\n",
" [1., 0.],\n",
" [0., 1.],\n",
" [0., 0.],\n",
" [0., 0.],\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dynamics = AccelerationCurvatureCar()\n",
"grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([-5., -5., -1., -np.pi]),\n",
" hi=np.array([5., 5., 1., np.pi])),\n",
" (40, 40, 50, 50),\n",
" periodic_dims=3)\n",
"values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 1\n",
"\n",
"solver_settings = hj.SolverSettings.with_accuracy(\"low\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time = 0.\n",
"target_time = -2.0\n",
"target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"go.Figure(data=go.Isosurface(x=grid.states[:, :, -1, :, 0].ravel(),\n",
" y=grid.states[:, :, -1, :, 1].ravel(),\n",
" z=grid.states[:, :, -1, :, 3].ravel(),\n",
" value=target_values[:, :, -1, :].ravel(),\n",
" colorscale='jet',\n",
" isomin=0,\n",
" surface_count=1,\n",
" isomax=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: hj_reachability/__init__.py
================================================
from hj_reachability import artificial_dissipation
from hj_reachability import boundary_conditions
from hj_reachability import finite_differences
from hj_reachability import sets
from hj_reachability import solver
from hj_reachability import systems
from hj_reachability import time_integration
from hj_reachability import utils
from hj_reachability.dynamics import ControlAndDisturbanceAffineDynamics, Dynamics
from hj_reachability.grid import Grid
from hj_reachability.solver import SolverSettings, solve, step
__version__ = "0.7.0"
__all__ = ("ControlAndDisturbanceAffineDynamics", "Dynamics", "Grid", "SolverSettings", "artificial_dissipation",
"boundary_conditions", "finite_differences", "sets", "solve", "solver", "step", "systems",
"time_integration", "utils")
================================================
FILE: hj_reachability/artificial_dissipation.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
from hj_reachability import sets
from hj_reachability import utils
def global_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Computes dissipation coefficients with the Global Lax-Friedrichs (GLF) scheme.

    A single box bounding all left and right gradient values (reduced over every grid axis) is built once and
    handed to `partial_max_magnitudes` for every `(state, value)` pair mapped over by `utils.multivmap`.
    """
    axes = np.arange(values.ndim)
    lower = jnp.minimum(jnp.min(left_grad_values, axes), jnp.min(right_grad_values, axes))
    upper = jnp.maximum(jnp.max(left_grad_values, axes), jnp.max(right_grad_values, axes))
    bounding_box = sets.Box(lower, upper)

    def magnitudes_at(state, value):
        return partial_max_magnitudes(state, time, value, bounding_box)

    return utils.multivmap(magnitudes_at, axes)(states, values)
def local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Implements the Local Lax-Friedrichs (LLF) scheme for computing dissipation coefficients.

    For every grid point a `(values.ndim, values.ndim)` array of gradient bounds is assembled: entries start out as
    the grid-wide (global) bounds, and the diagonal is overwritten with the bounds local to that grid point. The
    resulting boxes are passed, one per grid point, to `partial_max_magnitudes`.

    Args:
        partial_max_magnitudes: Callable invoked as `partial_max_magnitudes(state, time, value, grad_value_box)`.
        states: Grid states, mapped over the grid axes together with `values`.
        time: Time forwarded unchanged to `partial_max_magnitudes`.
        values: Grid values; `values.ndim` determines the grid axes.
        left_grad_values: Left gradient approximations (presumably shaped `values.shape + (values.ndim,)`, as
            required by the diagonal assignment below -- confirm against the caller).
        right_grad_values: Right gradient approximations, same shape as `left_grad_values`.

    Returns:
        The result of mapping `partial_max_magnitudes` over the grid axes.
    """
    grid_axes = np.arange(values.ndim)
    # Bounds on each gradient component taken over the whole grid (both left and right approximations).
    global_grad_value_box = sets.Box(
        jnp.minimum(jnp.min(left_grad_values, grid_axes), jnp.min(right_grad_values, grid_axes)),
        jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))
    # Bounds on each gradient component at each individual grid point (elementwise min/max of left and right).
    local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),
                                            jnp.maximum(left_grad_values, right_grad_values))
    # Applied leaf-by-leaf to the two boxes: broadcast the global bound to `values.shape + (ndim, ndim)` and replace
    # the `[..., i, i]` entries with the pointwise bound for component `i`.
    local_grad_value_boxes = jax.tree.map(
        lambda global_grad_value, local_local_grad_values:
        (jnp.broadcast_to(global_grad_value, values.shape +
                          (values.ndim,) * 2).at[..., grid_axes, grid_axes].set(local_local_grad_values)),
        global_grad_value_box, local_local_grad_value_boxes)
    return utils.multivmap(
        lambda state, value, grad_value_box: partial_max_magnitudes(state, time, value, grad_value_box),
        grid_axes)(states, values, local_grad_value_boxes)
def local_local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Computes dissipation coefficients with the Local Local Lax-Friedrichs (LLLF) scheme.

    Each `(state, value)` pair is paired with its own box, spanned by the elementwise minimum and maximum of the
    left and right gradient values, before `partial_max_magnitudes` is evaluated.
    """
    axes = np.arange(values.ndim)
    pointwise_lower = jnp.minimum(left_grad_values, right_grad_values)
    pointwise_upper = jnp.maximum(left_grad_values, right_grad_values)
    pointwise_boxes = sets.Box(pointwise_lower, pointwise_upper)

    def magnitudes_at(state, value, grad_value_box):
        return partial_max_magnitudes(state, time, value, grad_value_box)

    return utils.multivmap(magnitudes_at, axes)(states, values, pointwise_boxes)
================================================
FILE: hj_reachability/boundary_conditions.py
================================================
import jax.numpy as jnp
from typing import Any, Callable
Array = Any
BoundaryCondition = Callable[[Array, int], Array]
def periodic(x: Array, pad_width: int) -> Array:
    """Wrap-pads a 1D array `x`: the leading pad is copied from the end of `x` and the trailing pad from its start."""
    before_and_after = (pad_width, pad_width)
    return jnp.pad(x, before_and_after, mode="wrap")
def extrapolate(x: Array, pad_width: int) -> Array:
    """Linearly extends a 1D array `x` by `pad_width` entries per side, continuing the slope found at each end."""
    leading_slope = x[1] - x[0]
    trailing_slope = x[-1] - x[-2]
    leading_pad = x[0] + leading_slope * jnp.arange(-pad_width, 0)
    trailing_pad = x[-1] + trailing_slope * jnp.arange(1, pad_width + 1)
    return jnp.concatenate([leading_pad, x, trailing_pad])
def extrapolate_away_from_zero(x: Array, pad_width: int) -> Array:
    """Extends a 1D array `x` by `pad_width` entries per side so that the padding grows in magnitude.

    The absolute value of the end slope is used, signed by the end value itself, so padded entries move away from
    zero regardless of the direction in which `x` was heading.
    """
    leading_step = jnp.sign(x[0]) * jnp.abs(x[1] - x[0])
    trailing_step = jnp.sign(x[-1]) * jnp.abs(x[-1] - x[-2])
    leading_pad = x[0] - leading_step * jnp.arange(-pad_width, 0)
    trailing_pad = x[-1] + trailing_step * jnp.arange(1, pad_width + 1)
    return jnp.concatenate([leading_pad, x, trailing_pad])
================================================
FILE: hj_reachability/boundary_conditions_test.py
================================================
from absl.testing import absltest
import numpy as np
from hj_reachability import boundary_conditions
class BoundaryConditionsTest(absltest.TestCase):
    """Checks the padding produced by each boundary condition against hand-written expectations."""

    def setUp(self):
        np.random.seed(0)

    def test_periodic(self):
        base = np.arange(5)
        cases = [
            (0, base),
            (1, [4, 0, 1, 2, 3, 4, 0]),
            (2, [3, 4, 0, 1, 2, 3, 4, 0, 1]),
        ]
        for pad_width, expected in cases:
            np.testing.assert_array_equal(boundary_conditions.periodic(base, pad_width), expected)

    def test_extrapolate(self):
        base = np.arange(5)
        cases = [
            (0, base),
            (1, np.arange(-1, 6)),
            (2, np.arange(-2, 7)),
        ]
        for pad_width, expected in cases:
            np.testing.assert_array_equal(boundary_conditions.extrapolate(base, pad_width), expected)

    def test_extrapolate_away_from_zero(self):
        increasing = np.arange(1, 5)
        decreasing = increasing[::-1]
        cases = [
            (increasing, 0, increasing),
            (increasing, 1, [2, 1, 2, 3, 4, 5]),
            (increasing, 2, [3, 2, 1, 2, 3, 4, 5, 6]),
            (decreasing, 0, decreasing),
            (decreasing, 1, [5, 4, 3, 2, 1, 2]),
            (decreasing, 2, [6, 5, 4, 3, 2, 1, 2, 3]),
        ]
        for x, pad_width, expected in cases:
            np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, pad_width), expected)
if __name__ == "__main__":
absltest.main()
================================================
FILE: hj_reachability/dynamics.py
================================================
import abc
import jax.numpy as jnp
class Dynamics(metaclass=abc.ABCMeta):
    """Abstract interface for continuous-time dynamics as used in Hamilton-Jacobi reachability.

    TODO: Consider allowing for state/time-dependent control/disturbance spaces.

    Attributes:
        control_mode: Whether the controller is trying to "max"imize or "min"imize the value.
        disturbance_mode: Whether the disturbance is trying to "max"imize or "min"imize the value.
        control_space: A `BoundedSet` defining the (time-invariant) set of possible controls.
        disturbance_space: A `BoundedSet` defining the (time-invariant) set of possible disturbances.
    """

    def __init__(self, control_mode, disturbance_mode, control_space, disturbance_space):
        self.control_mode, self.disturbance_mode = control_mode, disturbance_mode
        self.control_space, self.disturbance_space = control_space, disturbance_space

    @abc.abstractmethod
    def __call__(self, state, control, disturbance, time):
        """Evaluates the right-hand side of the continuous-time dynamics ODE."""

    @abc.abstractmethod
    def optimal_control_and_disturbance(self, state, time, grad_value):
        """Returns the `(control, disturbance)` pair that realizes the HJ PDE Hamiltonian."""

    def optimal_control(self, state, time, grad_value):
        """Returns only the control part of `optimal_control_and_disturbance`."""
        optimal_pair = self.optimal_control_and_disturbance(state, time, grad_value)
        return optimal_pair[0]

    def optimal_disturbance(self, state, time, grad_value):
        """Returns only the disturbance part of `optimal_control_and_disturbance`."""
        optimal_pair = self.optimal_control_and_disturbance(state, time, grad_value)
        return optimal_pair[1]

    def hamiltonian(self, state, time, value, grad_value):
        """Evaluates the HJ PDE Hamiltonian: `grad_value` dotted with the dynamics under the optimal inputs."""
        del value  # unused
        control, disturbance = self.optimal_control_and_disturbance(state, time, grad_value)
        state_derivative = self(state, control, disturbance, time)
        return grad_value @ state_derivative

    @abc.abstractmethod
    def partial_max_magnitudes(self, state, time, value, grad_value_box):
        """Returns, per dimension, the max magnitude of the Hamiltonian partials over the `grad_value_box`."""
class ControlAndDisturbanceAffineDynamics(Dynamics):
    """Abstract base class for dynamics that are affine in both the control and the disturbance."""

    def __call__(self, state, control, disturbance, time):
        """Evaluates the affine dynamics `dx_dt = f(x, t) + G_u(x, t) @ u + G_d(x, t) @ d`."""
        drift = self.open_loop_dynamics(state, time)
        control_effect = self.control_jacobian(state, time) @ control
        disturbance_effect = self.disturbance_jacobian(state, time) @ disturbance
        return drift + control_effect + disturbance_effect

    @abc.abstractmethod
    def open_loop_dynamics(self, state, time):
        """Returns the open loop term `f(x, t)`."""

    @abc.abstractmethod
    def control_jacobian(self, state, time):
        """Returns the control Jacobian `G_u(x, t)`."""

    @abc.abstractmethod
    def disturbance_jacobian(self, state, time):
        """Returns the disturbance Jacobian `G_d(x, t)`."""

    def optimal_control_and_disturbance(self, state, time, grad_value):
        """Returns the `(control, disturbance)` pair that realizes the HJ PDE Hamiltonian.

        Each input is the extreme point of its space along `grad_value @ G`, with the direction negated when the
        corresponding mode is "min".
        """
        control_direction = grad_value @ self.control_jacobian(state, time)
        disturbance_direction = grad_value @ self.disturbance_jacobian(state, time)
        if self.control_mode == "min":
            control_direction = -control_direction
        if self.disturbance_mode == "min":
            disturbance_direction = -disturbance_direction
        optimal_control = self.control_space.extreme_point(control_direction)
        optimal_disturbance = self.disturbance_space.extreme_point(disturbance_direction)
        return (optimal_control, optimal_disturbance)

    def partial_max_magnitudes(self, state, time, value, grad_value_box):
        """Returns, per dimension, a bound on the magnitude of the Hamiltonian partials over the `grad_value_box`."""
        del value, grad_value_box  # unused
        # An overestimation; see Eq. (25) from https://www.cs.ubc.ca/~mitchell/ToolboxLS/toolboxLS-1.1.pdf.
        drift_bound = jnp.abs(self.open_loop_dynamics(state, time))
        control_bound = jnp.abs(self.control_jacobian(state, time)) @ self.control_space.max_magnitudes
        disturbance_bound = jnp.abs(self.disturbance_jacobian(state, time)) @ self.disturbance_space.max_magnitudes
        return drift_bound + control_bound + disturbance_bound
================================================
FILE: hj_reachability/finite_differences/__init__.py
================================================
from hj_reachability.finite_differences.upwind_first import (ENO1, ENO2, ENO3, WENO1, WENO3, WENO5,
essentially_non_oscillatory, first_order,
weighted_essentially_non_oscillatory)
__all__ = ("ENO1", "ENO2", "ENO3", "WENO1", "WENO3", "WENO5", "essentially_non_oscillatory", "first_order",
"weighted_essentially_non_oscillatory")
================================================
FILE: hj_reachability/finite_differences/upwind_first.py
================================================
import functools
import jax
import jax.numpy as jnp
import numpy as np
import numpy.polynomial.polynomial as poly
from types import ModuleType
from typing import Any, Callable, Optional, Tuple
Array = Any
WENO_EPS = 1e-6
def weighted_essentially_non_oscillatory(eno_order: int, values: Array, spacing: float,
                                         boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:
    """Implements an upwind weighted essentially non-oscillatory (WENO) scheme for first derivative approximation.

    Args:
        eno_order: The order of the underlying essentially non-oscillatory (ENO) scheme; the resulting WENO scheme is
            `(2 * eno_order - 1)`th-order accurate.
        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.
        spacing: Grid spacing of the `values`.
        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).

    Returns:
        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,
        respectively, left and right approximations of the first derivative at the grid points of `values`.

    Raises:
        ValueError: If `eno_order` is less than 1.
    """
    if eno_order < 1:
        raise ValueError(f"`eno_order` must be at least 1; got {eno_order}.")

    # Pad by `eno_order` on each side, then form first divided differences between neighboring padded values.
    values = boundary_condition(values, eno_order)
    diffs = (values[1:] - values[:-1]) / spacing

    if eno_order == 1:
        # First-order case: the neighboring divided differences are themselves the left/right derivatives.
        return (diffs[:-1], diffs[1:])

    # One candidate derivative approximation per substencil: a sliding weighted sum of `eno_order` consecutive first
    # differences, with the window start shifted by the substencil index `i`.
    substencil_approximations = tuple(
        _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)
        for (i, c) in enumerate(_diff_coefficients(eno_order)))

    # Smoothness indicators are quadratic forms in the second differences. With the Cholesky factor `L` of each
    # quadratic form matrix (`Q = L @ L.T`), `d.T @ Q @ d` is evaluated as the sum of squares of the entries of
    # `L.T @ d`; column `j` of `L` is nonzero only from row `j` onward, hence `L[j:, j]` and the extra offset `j`.
    # NOTE(review): `_smoothness_indicator_quad_form` is defined outside this view; it presumably returns one
    # symmetric positive-definite matrix per substencil -- confirm.
    diffs2 = diffs[1:] - diffs[:-1]
    smoothness_indicators = [
        sum(
            _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2
            for j in range(eno_order - 1))
        for (i, L) in enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))
    ]

    # Unnormalized weights `c / (s + eps)**2` for the left (`i == 0`, using `s[:-1]`) and right (`i == 1`, using
    # `s[1:]`) approximations; `WENO_EPS` keeps the denominator away from zero.
    left_and_right_unnormalized_weights = [[
        c / (s[i:len(s) + i - 1] + WENO_EPS)**2 for (c, s) in zip(coefficients, smoothness_indicators)
    ] for (i, coefficients) in enumerate(_substencil_coefficients(eno_order))]

    # Normalized weighted average of `eno_order` substencil approximations: the first `eno_order` substencils for the
    # left derivative and the last `eno_order` for the right derivative.
    return tuple(
        sum(w * a for (w, a) in zip(unnormalized_weights, substencil_approximations[i:eno_order + i])) /
        sum(unnormalized_weights) for (i, unnormalized_weights) in enumerate(left_and_right_unnormalized_weights))
def essentially_non_oscillatory(order: int, values: Array, spacing: float,
                                boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:
    """Implements an upwind essentially non-oscillatory (ENO) scheme for first derivative approximation.

    Args:
        order: The desired order of accuracy for the ENO scheme.
        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.
        spacing: Grid spacing of the `values`.
        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).

    Returns:
        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,
        respectively, left and right approximations of the first derivative at the grid points of `values`.

    Raises:
        ValueError: If `order` is less than 1.
    """
    if order < 1:
        raise ValueError(f"`order` must be at least 1; got {order}.")

    # Pad by `order` on each side, then form first divided differences between neighboring padded values.
    values = boundary_condition(values, order)
    diffs = (values[1:] - values[:-1]) / spacing

    if order == 1:
        # First-order case: the neighboring divided differences are themselves the left/right derivatives.
        return (diffs[:-1], diffs[1:])

    # One candidate derivative approximation per substencil (`order + 1` in total), each a sliding weighted sum of
    # `order` consecutive first differences.
    substencil_approximations = tuple(
        _unrolled_correlate(diffs[i:len(diffs) - order + i], c) for (i, c) in enumerate(_diff_coefficients(order)))

    # Repeatedly difference `diffs` (without dividing by `spacing` again), keeping a trimmed copy of each
    # intermediate level. After the loop `diffs` holds the highest level stored (or the first differences when
    # `order == 2` and the loop body never runs).
    undivided_differences = []
    for i in range(2, order):
        diffs = diffs[1:] - diffs[:-1]
        undivided_differences.append(diffs[order - i:i - order])

    # Stencil selection, starting from one further level of differencing: the index starts as a boolean that is True
    # where the right-hand neighbor has the smaller magnitude.
    abs_diffs = jnp.abs(diffs[1:] - diffs[:-1])
    stencil_indices = abs_diffs[1:] < abs_diffs[:-1]
    # Walk back down through the stored levels (note: this rebinds `diffs`). Where the right-hand neighbor is smaller
    # in magnitude, take the right neighbor's index shifted by one; otherwise keep the left neighbor's index.
    for diffs in reversed(undivided_differences):
        abs_diffs = jnp.abs(diffs)
        stencil_indices = jnp.where(abs_diffs[1:] < abs_diffs[:-1], stencil_indices[1:] + 1, stencil_indices[:-1])

    # Left derivatives choose among the first `order` substencil approximations and right derivatives among the last
    # `order`; the final candidate in each group is the `jnp.select` default (index `order - 1`).
    return (jnp.select([stencil_indices[:-1] == i for i in range(order - 1)], substencil_approximations[:-2],
                       substencil_approximations[-2]),
            jnp.select([stencil_indices[1:] == i for i in range(order - 1)], substencil_approximations[1:-1],
                       substencil_approximations[-1]))
# Convenience schemes with the order argument pre-bound; each is called as `(values, spacing, boundary_condition)`.
# WENO names follow the resulting accuracy `(2 * eno_order - 1)`, so `eno_order` 1, 2, 3 give WENO1, WENO3, WENO5.
first_order = WENO1 = functools.partial(weighted_essentially_non_oscillatory, 1)
WENO3 = functools.partial(weighted_essentially_non_oscillatory, 2)
WENO5 = functools.partial(weighted_essentially_non_oscillatory, 3)
# ENO names are the `order` argument itself.
ENO1 = functools.partial(essentially_non_oscillatory, 1)
ENO2 = functools.partial(essentially_non_oscillatory, 2)
ENO3 = functools.partial(essentially_non_oscillatory, 3)
def _weighted_essentially_non_oscillatory_vectorized(
        eno_order: int, values: Array, spacing: float, boundary_condition: Callable[[Array, int],
                                                                                    Array]) -> Tuple[Array, Array]:
    """Implements a more "vectorized" but ultimately slower version of `weighted_essentially_non_oscillatory`.

    Takes the same arguments and returns the same `(left_derivatives, right_derivatives)` tuple as
    `weighted_essentially_non_oscillatory`, but stacks per-substencil quantities into arrays (via `jax.vmap`)
    instead of building Python lists of slices.

    Raises:
        ValueError: If `eno_order` is less than 1.
    """
    if eno_order < 1:
        raise ValueError(f"`eno_order` must be at least 1; got {eno_order}.")

    # Pad by `eno_order` on each side, then form first divided differences between neighboring padded values.
    values = boundary_condition(values, eno_order)
    diffs = (values[1:] - values[:-1]) / spacing

    if eno_order == 1:
        return (diffs[:-1], diffs[1:])

    # Correlate `diffs` with every row of substencil coefficients at once, then shift row `i` left by `i` columns so
    # that each column refers to the same grid location across substencils.
    substencil_approximations = _align_substencil_values(
        jax.vmap(jnp.correlate, (None, 0), 0)(diffs, _diff_coefficients(eno_order)), jnp)

    # Smoothness indicators as sums of squares of correlations with the rows of the transposed Cholesky factors.
    # NOTE(review): `_smoothness_indicator_quad_form` is defined outside this view; it presumably returns one
    # symmetric positive-definite matrix per substencil -- confirm.
    diffs2 = diffs[1:] - diffs[:-1]
    chol_T = jnp.asarray(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)).swapaxes(-1, -2))
    smoothness_indicators = _align_substencil_values(
        jnp.sum(jnp.square(jax.vmap(jax.vmap(jnp.correlate, (None, 0), 1), (None, 0), 0)(diffs2, chol_T)), -1), jnp)

    # `WENO_EPS` keeps the denominator away from zero. The leading axis of the stacked arrays indexes left (0) vs.
    # right (1); axis 1 indexes substencils and is the axis normalized and summed over.
    unscaled_weights = 1 / jnp.square(smoothness_indicators + WENO_EPS)
    unnormalized_weights = (jnp.asarray(_substencil_coefficients(eno_order)[..., np.newaxis]) *
                            jnp.stack([unscaled_weights[:, :-1], unscaled_weights[:, 1:]]))
    weights = unnormalized_weights / jnp.sum(unnormalized_weights, 1, keepdims=True)

    return tuple(jnp.sum(jnp.stack([substencil_approximations[:-1], substencil_approximations[1:]]) * weights, 1))
def _unrolled_correlate(a: Array, v: Array) -> Array:
"""An unrolled equivalent of `np.correlate`."""
return sum(a[i:len(a) - len(v) + i + 1] * x for (i, x) in enumerate(v))
def _substencils(k: int) -> Array:
"""Returns the `k + 1` subranges of length `k + 1` from the full stencil range `[-k, k + 1)`."""
return np.arange(k + 1) + np.arange(k + 1)[:, np.newaxis] - k
def _spread_substencil_values(x: Array, np: ModuleType = np) -> Array:
"""Offsets each successive row of a matrix `x` by one additional column."""
return np.reshape(np.reshape(np.pad(x, ((0, 0), (0, x.shape[0]))), -1)[:-x.shape[0]], (x.shape[0], -1))
def _align_substencil_values(x: Array, np: ModuleType = np) -> Array:
"""Slices and stacks windows, each offset by one column from the previous, from rows of a matrix `x`."""
return np.reshape(np.pad(np.reshape(x, -1), (0, x.shape[0])), (x.shape[0], -1))[:, :-x.shape[0]]
def _diff_coefficients(k: Optional[int] = None, stencil: Optional[Array] = None) -> Array:
"""Returns first derivative approximation finite difference coefficients for function value first differences."""
if k is None:
if stencil is None:
raise ValueError("One of `k` or `stencil` must be provided.")
k = stencil.shape[-1] - 1
else:
if stencil is None:
stencil = _substencils(k)
elif k != stencil.shape[-1] - 1:
raise ValueError("`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got "
f"{(k, stencil.shape[-1] - 1)}.")
return np.linalg.solve(
np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),
np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0, ..., np.newaxis)])[..., 0]
def _substencil_coefficients(k: int) -> Array:
    """Returns coefficients for combining substencil approximations to yield higher order left/right approximations.

    Returns:
        An array of shape `(2, k)`: row 0 holds the coefficients for the left approximation and row 1 is the same
        coefficients in reverse order, for the right approximation.
    """
    # Each substencil's difference coefficients, placed at its own offset within the full stencil.
    spread_coefficients = _spread_substencil_values(_diff_coefficients(k))
    # Difference coefficients of the single wide stencil `[-k, k)`.
    wide_stencil_coefficients = _diff_coefficients(stencil=np.arange(-k, k))
    left_coefficients = np.linalg.solve(spread_coefficients[:-1, :k].T, wide_stencil_coefficients[:k])
    return np.array([left_coefficients, left_coefficients[::-1]])
def _polyder_operator(k: int, d: int) -> Array:
"""Returns a matrix `D` such that `D @ p == poly.polyder(p, d)` for polynomials `p` of degree `k`."""
return np.concatenate([np.zeros((k + 1 - d, d)), np.diag(poly.polyder(np.ones(k + 1), d))], 1)
def _smoothness_indicator_quad_form(k: int) -> Array:
    """Returns quadratic forms for computing substencil smoothness indicators as functions of second differences.

    Returns:
        An array of shape `(k, k - 1, k - 1)`: one symmetric `(k - 1) x (k - 1)` matrix per substencil, to be applied
        to that substencil's `k - 1` second differences.
    """
    # Maps second differences (over all but the first substencil from `_substencils`) to the coefficients of the
    # second derivative of the interpolating polynomial.
    second_diff_of_monomials = np.diff(poly.polyvander(_substencils(k)[1:], k), 2, axis=-2)[..., 2:]
    second_der_scales = poly.polyder(np.ones(k + 1), 2)[:, np.newaxis]
    interp_poly_second_der = second_der_scales * np.linalg.inv(second_diff_of_monomials)

    quad_form = np.zeros((k, k - 1, k - 1))
    for m in range(k - 1):
        num_terms = k - 1 - m
        powers = np.arange(num_terms)
        # Entry `(i, j)` is `1 / (i + j + 1)`, i.e., the integral of `x**(i + j)` over a unit-length interval.
        integrator_matrix = 1 / (powers + powers[:, np.newaxis] + 1)
        interp_poly_m_plus_2_der = _polyder_operator(k - 2, m) @ interp_poly_second_der
        quad_form += interp_poly_m_plus_2_der.swapaxes(-1, -2) @ integrator_matrix @ interp_poly_m_plus_2_der
    return quad_form
================================================
FILE: hj_reachability/finite_differences/upwind_first_test.py
================================================
import math
from absl.testing import absltest
import jax
import numpy as np
from hj_reachability import boundary_conditions
from hj_reachability.finite_differences import upwind_first
class UpwindFirstTest(absltest.TestCase):
    """Checks the upwind finite difference schemes against reference implementations and known coefficients."""

    def setUp(self):
        np.random.seed(0)

    def _assert_trees_allclose(self, actual, expected):
        """Asserts leaf-wise closeness (`atol=1e-5`) of two pytrees with matching structure."""
        jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), actual, expected)

    def test_weighted_essentially_non_oscillatory(self):

        def _WENO5(values, spacing, boundary_condition):
            """Reference WENO5 written out with explicit formulas."""
            values = boundary_condition(values, 3)
            diffs = (values[1:] - values[:-1]) / spacing

            def compute_weno(v):
                candidates = [
                    v[0] / 3 - 7 * v[1] / 6 + 11 * v[2] / 6,
                    -v[1] / 6 + 5 * v[2] / 6 + v[3] / 3,
                    v[2] / 3 + 5 * v[3] / 6 - v[4] / 6,
                ]
                smoothness = [
                    (13 / 12) * (v[0] - 2 * v[1] + v[2])**2 + (1 / 4) * (v[0] - 4 * v[1] + 3 * v[2])**2,
                    (13 / 12) * (v[1] - 2 * v[2] + v[3])**2 + (1 / 4) * (v[1] - v[3])**2,
                    (13 / 12) * (v[2] - 2 * v[3] + v[4])**2 + (1 / 4) * (3 * v[2] - 4 * v[3] + v[4])**2,
                ]
                alphas = [
                    ideal / (indicator + upwind_first.WENO_EPS)**2
                    for (ideal, indicator) in zip([0.1, 0.6, 0.3], smoothness)
                ]
                alpha_total = sum(alphas)
                weights = [alpha / alpha_total for alpha in alphas]
                return sum(candidate * weight for (candidate, weight) in zip(candidates, weights))

            # The right-biased approximation uses the same formula with the five differences in reverse order.
            left_windows = [diffs[i:-5 + i] for i in range(5)]
            right_windows = [diffs[5 - i:None if i == 0 else -i] for i in range(5)]
            return (compute_weno(left_windows), compute_weno(right_windows))

        values = np.random.rand(1000)
        spacing = 0.1
        self._assert_trees_allclose(upwind_first.WENO5(values, spacing, boundary_conditions.periodic),
                                    _WENO5(values, spacing, boundary_conditions.periodic))

    def test_essentially_non_oscillatory(self):

        def _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_condition):
            """Reference ENO that builds each Newton-form interpolant explicitly and differentiates it."""

            def _divided_difference(x, i, spacing=1):
                if isinstance(i, int):
                    return x[i]
                order = len(i) - 1
                return np.diff(x[i], order)[0] / (math.factorial(order) * spacing**order)

            padded = np.array(boundary_condition(values, order))
            nodes = np.arange(len(padded)) * spacing
            interpolants = []
            ks = []  # Final stencil offsets; recorded but not otherwise used by this test.
            for j in range(order - 1, len(padded) - order):
                interpolant = np.poly1d(padded[j])
                interpolant = interpolant + _divided_difference(padded, [j, j + 1], spacing) * np.poly1d([nodes[j]],
                                                                                                        True)
                k = j
                for d in range(2, order + 1):
                    forward = _divided_difference(padded, np.arange(k, k + d + 1), spacing)
                    backward = _divided_difference(padded, np.arange(k - 1, k + d), spacing)
                    # Extend the stencil in the direction of the smaller divided difference (ties go backward).
                    if np.abs(forward) >= np.abs(backward):
                        chosen, k_next = backward, k - 1
                    else:
                        chosen, k_next = forward, k
                    interpolant = interpolant + chosen * np.poly1d(nodes[k:k + d], True)
                    k = k_next
                ks.append(k - j)
                interpolants.append(interpolant)
            derivatives = [np.polyder(f) for f in interpolants]
            eval_nodes = nodes[order:-order]
            return (np.array([np.polyval(f, node) for (f, node) in zip(derivatives[:-1], eval_nodes)]),
                    np.array([np.polyval(f, node) for (f, node) in zip(derivatives[1:], eval_nodes)]))

        values = np.random.rand(1000)
        spacing = 0.1
        for order in range(1, 5):
            self._assert_trees_allclose(
                upwind_first.essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic),
                _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic))

    def test_weighted_essentially_non_oscillatory_vectorized(self):
        values = np.random.rand(1000)
        spacing = 0.1
        for eno_order in range(1, 5):
            self._assert_trees_allclose(
                upwind_first.weighted_essentially_non_oscillatory(eno_order, values, spacing,
                                                                  boundary_conditions.periodic),
                upwind_first._weighted_essentially_non_oscillatory_vectorized(eno_order, values, spacing,
                                                                              boundary_conditions.periodic))

    def test_diff_coefficients(self):
        expected_by_k = [
            (1, np.ones((2, 1))),
            (2, np.array([[-1, 3], [1, 1], [3, -1]]) / 2),
            (3, np.array([[2, -7, 11], [-1, 5, 2], [2, 5, -1], [11, -7, 2]]) / 6),
        ]
        for k, expected in expected_by_k:
            np.testing.assert_allclose(upwind_first._diff_coefficients(k), expected)

    def test_substencil_coefficients(self):
        expected_by_k = [
            (1, np.ones((2, 1))),
            (2, np.array([[1, 2], [2, 1]]) / 3),
            (3, np.array([[1, 6, 3], [3, 6, 1]]) / 10),
        ]
        for k, expected in expected_by_k:
            np.testing.assert_allclose(upwind_first._substencil_coefficients(k), expected)

    def test_smoothness_indicator_quad_form(self):

        def diff_operator(k):
            """Returns the `(k - 1) x k` first-difference matrix."""
            return np.eye(k - 1, k, 1) - np.eye(k - 1, k, 0)

        def square_outer(v):
            """Returns the outer product of each row of `v` with itself."""
            return v[..., np.newaxis] * v[..., np.newaxis, :]

        def quad_form_in_first_diffs(k):
            """Expresses the quadratic form (in second differences) in terms of first differences."""
            return diff_operator(k).T @ upwind_first._smoothness_indicator_quad_form(k) @ diff_operator(k)

        # k = 1
        np.testing.assert_allclose(quad_form_in_first_diffs(1), [[[0]]])
        # k = 2
        np.testing.assert_allclose(quad_form_in_first_diffs(2), square_outer(np.array([[1, -1], [1, -1]])))
        # k = 3
        np.testing.assert_allclose(
            quad_form_in_first_diffs(3),
            (13 / 12) * square_outer(np.array([[1, -2, 1], [1, -2, 1], [1, -2, 1]])) +
            (1 / 4) * square_outer(np.array([[1, -4, 3], [1, 0, -1], [3, -4, 1]])))


if __name__ == "__main__":
    absltest.main()
================================================
FILE: hj_reachability/grid.py
================================================
import functools
from flax import struct
import jax.numpy as jnp
import numpy as np
from hj_reachability import boundary_conditions as _boundary_conditions
from hj_reachability.finite_differences import upwind_first
from hj_reachability import sets
from hj_reachability import utils
from typing import Any, Callable, Optional, Tuple, Union
from hj_reachability.boundary_conditions import BoundaryCondition
Array = Any
@struct.dataclass
class Grid:
    """Class for representing Cartesian state grids with uniform spacing in each dimension.

    Attributes:
        states: An `(N + 1)` dimensional array containing the state values at each grid location. The first `N`
            dimensions correspond to the location in the grid, while the last dimension (itself of size `N`) contains
            the state vector.
        domain: A `Box` representing the domain of grid.
        coordinate_vectors: A tuple of `N` arrays containing the discrete state values in each dimension. The `states`
            attribute is produced by `stack`ing a `meshgrid` of these coordinate vectors.
        spacings: A tuple of `N` scalars containing the grid spacing (the difference between successive elements of the
            corresponding coordinate vector) in each dimension.
        boundary_conditions: A tuple of `N` boundary conditions for each dimension. These boundary conditions are
            functions used to pad values (notably not stored in this `Grid` data structure) to implement a boundary
            condition (e.g., periodic).
    """
    states: Array
    domain: sets.Box
    coordinate_vectors: Tuple[Array, ...]
    spacings: Tuple[Array, ...]
    boundary_conditions: Tuple[BoundaryCondition, ...] = struct.field(pytree_node=False)

    @classmethod
    def from_lattice_parameters_and_boundary_conditions(
            cls,
            domain: sets.Box,
            shape: Tuple[int, ...],
            boundary_conditions: Optional[Tuple[BoundaryCondition, ...]] = None,
            periodic_dims: Optional[Union[int, Tuple[int, ...]]] = None) -> "Grid":
        """Constructs a `Grid` from a domain, shape, and boundary conditions.

        Args:
            domain: A `Box` representing the domain of grid.
            shape: A tuple of `N` integers denoting the number of discretization nodes in each dimension.
            boundary_conditions: A tuple of `N` boundary conditions for each dimension. If not provided, defaults to
                `extrapolate_away_from_zero` in each dimension, with the exception of those dimensions that appear in
                `periodic_dims` where the `periodic` boundary condition is used instead.
            periodic_dims: A single integer or a tuple (or any other iterable) of integers denoting which dimensions
                are periodic in the case that the `boundary_conditions` are not explicitly provided as input to this
                factory method.

        Returns:
            A `Grid` constructed according to the provided specifications.

        Raises:
            ValueError: If `boundary_conditions` is provided but its length differs from `len(shape)`.
        """
        ndim = len(shape)
        if boundary_conditions is None:
            if periodic_dims is None:
                periodic_dims = ()
            elif not isinstance(periodic_dims, tuple):
                # Accept any iterable of dimensions (e.g., a `list`); previously a non-`tuple` iterable was wrapped
                # whole into a 1-tuple and therefore silently matched no dimension at all.
                try:
                    periodic_dims = tuple(periodic_dims)
                except TypeError:
                    periodic_dims = (periodic_dims,)
            boundary_conditions = tuple(
                _boundary_conditions.periodic if i in periodic_dims else _boundary_conditions.extrapolate_away_from_zero
                for i in range(ndim))
        else:
            # Store a `tuple` regardless of the sequence type passed in, and reject mismatched lengths up front
            # since the `zip` below would otherwise silently truncate to the shortest input.
            boundary_conditions = tuple(boundary_conditions)
            if len(boundary_conditions) != ndim:
                raise ValueError("`boundary_conditions` must contain one entry per grid dimension; got "
                                 f"{len(boundary_conditions)} boundary conditions for a grid of shape {tuple(shape)}.")

        # Periodic dimensions omit the upper endpoint, as it coincides with the lower endpoint after wrapping.
        coordinate_vectors, spacings = zip(
            *(jnp.linspace(l, h, n, endpoint=bc is not _boundary_conditions.periodic, retstep=True)
              for l, h, n, bc in zip(domain.lo, domain.hi, shape, boundary_conditions)))
        states = jnp.stack(jnp.meshgrid(*coordinate_vectors, indexing="ij"), -1)
        return cls(states, domain, coordinate_vectors, spacings, boundary_conditions)

    @property
    def ndim(self) -> int:
        """Returns the dimension `N` of the grid."""
        return self.states.ndim - 1

    @property
    def shape(self) -> Tuple[int, ...]:
        """Returns the shape of the grid, a tuple of `N` integers."""
        return self.states.shape[:-1]

    def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:
        """Returns `(left_grad_values, right_grad_values)`.

        Args:
            upwind_scheme: A function called as `upwind_scheme(values_1d, spacing, boundary_condition)` that returns
                a `(left_derivatives, right_derivatives)` pair along a single dimension.
            values: An array of values defined over the grid.

        Returns:
            A tuple of two arrays, each with the per-dimension derivatives stacked along a new last axis.
        """
        # For each dimension `i`, the 1D scheme is mapped over every other grid axis.
        left_derivatives, right_derivatives = zip(*[
            utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),
                            np.array([j
                                      for j in range(self.ndim)
                                      if j != i]))(values)
            for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))
        ])
        return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))

    def grad_values(self, values: Array, upwind_scheme: Optional[Callable] = None) -> Array:
        """Returns a central difference-based approximation of `grad_values`.

        The approximation is the average of the left and right upwind derivatives produced by `upwind_scheme`
        (`upwind_first.first_order` if not provided).
        """
        # TODO: Implement central difference schemes in `hj_reachability.finite_differences`.
        if upwind_scheme is None:
            upwind_scheme = upwind_first.first_order
        return sum(self.upwind_grad_values(upwind_scheme, values)) / 2

    def position(self, state: Array) -> Array:
        """Returns an array of `float`s corresponding to the position of `state` in the grid.

        Positions are measured in units of grid spacing from `domain.lo`; along periodic dimensions they are wrapped
        modulo the number of grid nodes.
        """
        position = (state - self.domain.lo) / jnp.array(self.spacings)
        return jnp.where(self._is_periodic_dim, position % np.array(self.shape), position)

    def nearest_index(self, state: Array) -> Array:
        """Returns the result of rounding `self.position(state)` to the nearest grid index."""
        return jnp.round(self.position(state)).astype(jnp.int32)

    def interpolate(self, values, state):
        """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`.

        The result is a weighted sum of `values` at the `2**N` grid nodes surrounding `state`. Indices wrap around
        along periodic dimensions and are clipped to the grid along the others. If `state` lies outside of the domain
        along any non-periodic dimension, the result is `nan`.
        """
        position = (state - self.domain.lo) / jnp.array(self.spacings)
        index_lo = jnp.floor(position).astype(jnp.int32)
        index_hi = index_lo + 1
        weight_hi = position - index_lo
        weight_lo = 1 - weight_hi
        index_lo, index_hi = tuple(
            jnp.where(self._is_periodic_dim, index % np.array(self.shape), jnp.clip(index, 0,
                                                                                    np.array(self.shape) - 1))
            for index in (index_lo, index_hi))
        # Outer product of the per-dimension `(weight_lo, weight_hi)` pairs: one weight per surrounding node.
        weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
        # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.
        result = jnp.sum(
            weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *
            values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))
        return jnp.where(jnp.any(~self._is_periodic_dim & ((state < self.domain.lo) | (state > self.domain.hi))),
                         jnp.nan, result)

    @property
    def _is_periodic_dim(self) -> Array:
        """Returns a boolean vector indicating which dimensions (if any) are periodic."""
        return np.array([bc is _boundary_conditions.periodic for bc in self.boundary_conditions])
================================================
FILE: hj_reachability/grid_test.py
================================================
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
from hj_reachability import grid as _grid
from hj_reachability import sets
class BoundaryConditionsTest(absltest.TestCase):
    """Tests of `Grid.interpolate` on, off, and outside of the grid, with and without a periodic dimension."""

    def setUp(self):
        np.random.seed(0)

    def test_grid_interpolate(self):
        domain = sets.Box(np.zeros(2), np.ones(2))
        shape = (3, 2)
        grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape, periodic_dims=1)

        # Scalar values per node.
        values = np.random.random((3, 2))
        np.testing.assert_allclose(grid.interpolate(values, np.array([0.25, 2.75])), np.mean(values[0:2, 0:2]))
        np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0])
        np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0])

        # Multidimensional values per node.
        values = np.random.random((3, 2, 3, 4))
        np.testing.assert_allclose(grid.interpolate(values, np.array([0.75, 2.75])), np.mean(values[1:3, 0:2], (0, 1)))
        np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0])
        np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0])

    def test_grid_interpolate_on_grid(self):
        domain = sets.Box(jnp.zeros(2), jnp.ones(2))
        shape = (3, 4)
        for value_shape in ((), (5,)):
            values = jnp.array(np.random.random(shape + value_shape))

            # Interpolating at the grid nodes themselves must reproduce the node values.
            plain_grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape)
            at_nodes = jax.vmap(jax.vmap(lambda x: plain_grid.interpolate(values, x)))(plain_grid.states)
            np.testing.assert_allclose(at_nodes, values, atol=1e-6)

            # The same must hold for nodes shifted by whole periods along the periodic dimension.
            periodic_grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape, periodic_dims=0)
            period_offsets = (periodic_grid._is_periodic_dim * np.arange(-3, 4)[:, None, None, None] *
                              (periodic_grid.domain.hi - periodic_grid.domain.lo))
            states = periodic_grid.states + period_offsets
            at_shifted_nodes = jax.vmap(jax.vmap(jax.vmap(lambda x: periodic_grid.interpolate(values, x))))(states)
            np.testing.assert_allclose(at_shifted_nodes,
                                       np.broadcast_to(values, states.shape[:1] + values.shape),
                                       atol=1e-6)

    def test_grid_interpolate_off_grid(self):
        domain = sets.Box(jnp.zeros(2), jnp.ones(2))
        shape = (3, 4)
        for value_shape in ((), (5,)):
            # Values that are linear in the state are reproduced exactly at arbitrary states within the domain.
            a = np.random.random((2,) + value_shape)
            plain_grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape)
            linear_values = plain_grid.states @ a
            states = plain_grid.domain.lo + np.random.random((100, 2)) * (plain_grid.domain.hi - plain_grid.domain.lo)
            np.testing.assert_allclose(jax.vmap(lambda x: plain_grid.interpolate(linear_values, x))(states),
                                       states @ a,
                                       atol=1e-6)

            # A periodic grid must agree with a non-periodic grid that has the wrapped node appended explicitly.
            periodic_grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape, periodic_dims=0)
            values = jnp.array(np.random.random(shape + value_shape))
            unwrapped_shape = tuple(
                d + 1 if p else d for d, p in zip(periodic_grid.shape, periodic_grid._is_periodic_dim))
            grid_unwrapped = _grid.Grid.from_lattice_parameters_and_boundary_conditions(
                periodic_grid.domain, unwrapped_shape)
            values_unwrapped = jnp.concatenate([values, values[:1]])
            extent = periodic_grid.domain.hi - periodic_grid.domain.lo
            shifted_states = states + (periodic_grid._is_periodic_dim * np.arange(-3, 4)[:, None, None] * extent)
            wrapped_states = (shifted_states - periodic_grid.domain.lo) % extent + periodic_grid.domain.lo
            np.testing.assert_allclose(
                jax.vmap(jax.vmap(lambda x: periodic_grid.interpolate(values, x)))(shifted_states),
                jax.vmap(jax.vmap(lambda x: grid_unwrapped.interpolate(values_unwrapped, x)))(wrapped_states),
                atol=1e-6)

    def test_grid_interpolate_extrapolate_nan(self):
        domain = sets.Box(jnp.zeros(2), jnp.ones(2))
        shape = (3, 4)
        for value_shape in ((), (5,)):
            values = jnp.array(np.random.random(shape + value_shape))
            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(domain, shape)
            # The eight states one domain-width away from the center in at least one dimension.
            relative_states = np.array(
                [[0.5 + dx, 0.5 + dy] for dx in [-1, 0, 1] for dy in [-1, 0, 1] if dx or dy])
            states = grid.domain.lo + (grid.domain.hi - grid.domain.lo) * relative_states
            result = jax.vmap(lambda x: grid.interpolate(values, x))(states)
            self.assertEqual(result.shape, (8,) + value_shape)
            self.assertTrue(np.all(np.isnan(result)))


if __name__ == "__main__":
    absltest.main()
================================================
FILE: hj_reachability/sets.py
================================================
import abc
from flax import struct
import jax.numpy as jnp
from hj_reachability import utils
from typing import Any
Array = Any
@struct.dataclass
class BoundedSet(metaclass=abc.ABCMeta):
    """Abstract base class for representing bounded subsets of Euclidean space.

    Subclasses implement `extreme_point` and `bounding_box`; `max_magnitudes` and `ndim` are derived from the
    bounding box.
    """

    @abc.abstractmethod
    def extreme_point(self, direction: Array) -> Array:
        """Computes the point `x` in the set such that the dot product `x @ direction` is greatest."""

    @property
    @abc.abstractmethod
    def bounding_box(self) -> "Box":
        """Returns an axis-aligned bounding box for the set."""

    @property
    def max_magnitudes(self) -> Array:
        """Returns the maximum magnitude (per dimension) of points in the set."""
        box = self.bounding_box
        # Along each dimension the largest magnitude is attained at one of the two bounding box faces.
        return jnp.maximum(jnp.abs(box.lo), jnp.abs(box.hi))

    @property
    def ndim(self) -> int:
        """Returns the dimension of the Euclidean space the set lies within."""
        box = self.bounding_box
        return box.ndim
@struct.dataclass
class Box(BoundedSet):
    """Class for representing axis-aligned boxes.

    Attributes:
        lo: Array of per-dimension lower bounds.
        hi: Array of per-dimension upper bounds.
    """
    lo: Array
    hi: Array

    def extreme_point(self, direction: Array) -> Array:
        """Computes the point `x` in the set such that the dot product `x @ direction` is greatest."""
        # Per dimension: the lower bound where `direction` is negative, the upper bound otherwise.
        points_downward = direction < 0
        return jnp.where(points_downward, self.lo, self.hi)

    @property
    def bounding_box(self) -> "Box":
        """Returns an axis-aligned bounding box for the set."""
        # A box is its own bounding box.
        return self

    @property
    def ndim(self) -> int:
        """Returns the dimension of the Euclidean space the set lies within."""
        lo_shape = self.lo.shape
        return lo_shape[-1]
@struct.dataclass
class Ball(BoundedSet):
    """Class for representing Euclidean (L2) balls.

    Attributes:
        center: Array containing the center of the ball.
        radius: The radius of the ball.
    """
    center: Array
    radius: Array

    def extreme_point(self, direction: Array) -> Array:
        """Computes the point `x` in the set such that the dot product `x @ direction` is greatest."""
        # Move from the center by `radius` along the normalized `direction`.
        offset = self.radius * utils.unit_vector(direction)
        return self.center + offset

    @property
    def bounding_box(self) -> "Box":
        """Returns an axis-aligned bounding box for the set."""
        # A trailing axis is added to the radius so it broadcasts against the components of `center`.
        half_width = jnp.expand_dims(self.radius, -1)
        return Box(self.center - half_width, self.center + half_width)
================================================
FILE: hj_reachability/sets_test.py
================================================
from absl.testing import absltest
import jax
import numpy as np
from hj_reachability import sets
class SetsTest(absltest.TestCase):
    """Tests for the `Box` and `Ball` set representations."""

    def setUp(self):
        np.random.seed(0)

    def test_box(self):
        lo, hi = np.ones(3), 2 * np.ones(3)
        box = sets.Box(lo, hi)

        direction = np.array([1, -1, 1])
        np.testing.assert_allclose(box.extreme_point(direction), np.array([2, 1, 2]))
        # A zero direction must still yield a finite point.
        self.assertTrue(np.all(np.isfinite(box.extreme_point(np.zeros(3)))))

        self.assertEqual(box.bounding_box, box)
        np.testing.assert_allclose(box.max_magnitudes, 2 * np.ones(3))
        self.assertEqual(box.ndim, 3)

    def test_ball(self):
        radius = np.sqrt(3)
        ball = sets.Ball(np.ones(3), radius)

        direction = np.array([1, -1, 1])
        np.testing.assert_allclose(ball.extreme_point(direction), np.array([2, 0, 2]), atol=1e-6)
        # A zero direction must still yield a finite point.
        self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3)))))

        expected_box = sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3))
        jax.tree.map(np.testing.assert_allclose, ball.bounding_box, expected_box)
        np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3))
        self.assertEqual(ball.ndim, 3)


if __name__ == "__main__":
    absltest.main()
================================================
FILE: hj_reachability/solver.py
================================================
import contextlib
import functools
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
from hj_reachability import artificial_dissipation
from hj_reachability import time_integration
from hj_reachability.finite_differences import upwind_first
from typing import Callable, Text
# Hamiltonian postprocessors.
def identity(*x):
    """Returns the last argument so that this may also be used as a value postprocessor."""
    return x[-1]


def backwards_reachable_tube(x):
    """Returns the elementwise minimum of `x` and 0."""
    return jnp.minimum(x, 0)


# Value postprocessors.
def static_obstacle(obstacle):
    """Returns a value postprocessor `(t, v) -> jnp.maximum(v, obstacle)`; the time argument `t` is ignored."""

    def postprocess(t, v):
        return jnp.maximum(v, obstacle)

    return postprocess
@struct.dataclass
class SolverSettings:
    """Collection of the numerical schemes and parameters used by the solver.

    Attributes:
        upwind_scheme: Function used to compute left/right derivative approximations along a single dimension.
        artificial_dissipation_scheme: Function used to compute the dissipation coefficients of the numerical
            Hamiltonian.
        hamiltonian_postprocessor: Function applied to the numerical Hamiltonian values computed at each Euler step.
        time_integrator: Function used to advance `(time, values)` towards a target time.
        value_postprocessor: Function called as `value_postprocessor(time, values)` on the values produced by the
            time integrator.
        CFL_number: Scaling factor applied to the time step bound when selecting a time step.
    """
    upwind_scheme: Callable = struct.field(
        default=upwind_first.WENO5,
        pytree_node=False,
    )
    artificial_dissipation_scheme: Callable = struct.field(
        default=artificial_dissipation.global_lax_friedrichs,
        pytree_node=False,
    )
    hamiltonian_postprocessor: Callable = struct.field(
        default=identity,
        pytree_node=False,
    )
    time_integrator: Callable = struct.field(
        default=time_integration.third_order_total_variation_diminishing_runge_kutta,
        pytree_node=False,
    )
    value_postprocessor: Callable = struct.field(
        default=identity,
        pytree_node=False,
    )
    CFL_number: float = 0.75

    @classmethod
    def with_accuracy(cls, accuracy: Text, **kwargs) -> "SolverSettings":
        """Constructs `SolverSettings` with a preset pairing of upwind scheme and time integrator.

        Args:
            accuracy: One of `"low"`, `"medium"`, `"high"`, or `"very_high"`.
            **kwargs: Additional fields forwarded to the `SolverSettings` constructor.

        Returns:
            A `SolverSettings` using the upwind scheme and time integrator associated with `accuracy`.

        Raises:
            ValueError: If `accuracy` is not one of the supported presets.
        """
        presets = {
            "low": (upwind_first.first_order,
                    time_integration.first_order_total_variation_diminishing_runge_kutta),
            "medium": (upwind_first.ENO2,
                       time_integration.second_order_total_variation_diminishing_runge_kutta),
            "high": (upwind_first.WENO3,
                     time_integration.third_order_total_variation_diminishing_runge_kutta),
            "very_high": (upwind_first.WENO5,
                          time_integration.third_order_total_variation_diminishing_runge_kutta),
        }
        try:
            upwind_scheme, time_integrator = presets[accuracy]
        except (KeyError, TypeError):
            # An unrecognized preset previously surfaced as an `UnboundLocalError` from the return statement.
            raise ValueError(f"`accuracy` must be one of {tuple(presets)}; got {accuracy!r}.") from None
        return cls(upwind_scheme=upwind_scheme, time_integrator=time_integrator, **kwargs)
@functools.partial(jax.jit, static_argnames=("dynamics", "progress_bar"))
def step(solver_settings, dynamics, grid, time, values, target_time, progress_bar=True):
    """Advances `values` from `time` to `target_time`.

    The time integrator from `solver_settings` is applied repeatedly until the current time equals `target_time`.

    Args:
        solver_settings: Settings providing the `time_integrator` (and everything it in turn requires).
        dynamics: The system dynamics (a static argument under `jax.jit`).
        grid: The grid over which `values` are defined.
        time: The time corresponding to `values`.
        values: The values at `time`.
        target_time: The time to integrate to.
        progress_bar: `True` to create and display a new progress bar (requires `tqdm`), `False` for no progress
            reporting, or an existing progress bar object (providing `update_to` and `reference_time`) to update. An
            existing progress bar is not closed by this function.

    Returns:
        The values at `target_time`.
    """
    if progress_bar is True:
        bar_context = _try_get_progress_bar(time, target_time)
    else:
        bar_context = contextlib.nullcontext(progress_bar)

    with bar_context as bar:

        def target_not_reached(time_values):
            return jnp.abs(target_time - time_values[0]) > 0

        def sub_step(time_values):
            t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time)
            if bar is not False:
                bar.update_to(jnp.abs(t - bar.reference_time))
            return t, v

        _, target_values = jax.lax.while_loop(target_not_reached, sub_step, (time, values))
        return target_values
@functools.partial(jax.jit, static_argnames=("dynamics", "progress_bar"))
def solve(solver_settings, dynamics, grid, times, initial_values, progress_bar=True):
    """Computes the values at each of `times`, starting from `initial_values` at `times[0]`.

    Args:
        solver_settings: Settings forwarded to `step`.
        dynamics: The system dynamics (a static argument under `jax.jit`).
        grid: The grid over which the values are defined.
        times: Array of times; `initial_values` correspond to `times[0]`.
        initial_values: The values at `times[0]`.
        progress_bar: `True` to create and display a new progress bar spanning `times[0]` to `times[-1]` (requires
            `tqdm`); any other value is forwarded as-is as the `progress_bar` argument of `step`.

    Returns:
        An array stacking, along a new leading axis, `initial_values` followed by the values at each of `times[1:]`.
    """
    if progress_bar is True:
        bar_context = _try_get_progress_bar(times[0], times[-1])
    else:
        bar_context = contextlib.nullcontext(progress_bar)

    with bar_context as bar:

        def advance(time_values, target_time):
            # The carry holds the `(time, values)` reached so far; each output slice is the values at `target_time`.
            target_values = step(solver_settings, dynamics, grid, *time_values, target_time, bar)
            return (target_time, target_values), target_values

        _, subsequent_values = jax.lax.scan(advance, (times[0], initial_values), times[1:])
        return jnp.concatenate([initial_values[np.newaxis], subsequent_values])
def _try_get_progress_bar(reference_time, target_time):
    """Creates a `TqdmWrapper` tracking progress from `reference_time` to `target_time`.

    Args:
        reference_time: The time at which progress is zero.
        target_time: The time at which progress is complete.

    Returns:
        A `TqdmWrapper` whose total is `abs(target_time - reference_time)`.

    Raises:
        ImportError: If the `tqdm` package is not installed.
    """
    try:
        import tqdm
    except ImportError as err:
        # Chain explicitly so the traceback presents the original failure as the cause rather than as an unrelated
        # error raised during exception handling.
        raise ImportError("The option `progress_bar=True` requires the 'tqdm' package to be installed.") from err
    return TqdmWrapper(tqdm,
                       reference_time,
                       total=jnp.abs(target_time - reference_time),
                       unit="sim_s",
                       bar_format="{l_bar}{bar}| {n:7.4f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
                       ascii=True)
class TqdmWrapper:
    """Context manager driving a `tqdm` progress bar through `jax.experimental.io_callback`.

    Creation, updates, and closing of the underlying bar are each issued as ordered IO callbacks.

    Attributes:
        reference_time: The time relative to which callers compute the progress passed to `update_to`.
    """

    def __init__(self, tqdm, reference_time, total, *args, **kwargs):
        """Schedules creation of the bar.

        Args:
            tqdm: The imported `tqdm` module.
            reference_time: Stored as `self.reference_time`.
            total: The total of the bar; converted to `float` inside the callback.
            *args: Forwarded to `tqdm.tqdm`.
            **kwargs: Forwarded to `tqdm.tqdm`.
        """
        self.reference_time = reference_time

        def _create(total):
            return self._create_tqdm(tqdm, float(total), *args, **kwargs)

        jax.experimental.io_callback(_create, None, total, ordered=True)

    def _create_tqdm(self, tqdm, total, *args, **kwargs):
        self._tqdm = tqdm.tqdm(total=total, *args, **kwargs)

    def update_to(self, n):
        """Schedules setting the bar's progress to `n`."""

        def _advance(n):
            # `tqdm.update` takes an increment, hence the difference with the current count; `and None` keeps the
            # callback from returning a truthy result.
            return self._tqdm.update(float(n) - self._tqdm.n) and None

        jax.experimental.io_callback(_advance, None, n, ordered=True)

    def close(self):
        """Schedules closing the bar."""

        def _close():
            # `self._tqdm` is looked up only when the callback runs, i.e., after the creation callback.
            return self._tqdm.close()

        jax.experimental.io_callback(_close, None, ordered=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
================================================
FILE: hj_reachability/solver_test.py
================================================
from absl.testing import absltest
import numpy as np
import hj_reachability as hj
class SolverTest(absltest.TestCase):
    """Smoke tests for `hj.step` and `hj.solve` on a small `Air3d` problem."""

    def setUp(self):
        np.random.seed(0)
        domain = hj.sets.Box(np.array([-6., -10., 0.]), np.array([20., 10., 2 * np.pi]))
        self.problem_definition = {
            "solver_settings": hj.SolverSettings.with_accuracy("low"),
            "dynamics": hj.systems.Air3d(),
            "grid": hj.Grid.from_lattice_parameters_and_boundary_conditions(domain, (11, 10, 10), periodic_dims=2),
        }

    def test_step(self):
        grid = self.problem_definition["grid"]
        values = np.linalg.norm(grid.states[..., :2], axis=-1) - 5

        def run_step(progress_bar):
            return hj.step(**self.problem_definition,
                           time=0.,
                           values=values,
                           target_time=-0.1,
                           progress_bar=progress_bar)

        target_values = run_step(False)
        self.assertEqual(target_values.shape, values.shape)
        # Displaying a progress bar must not change the result.
        np.testing.assert_allclose(target_values, run_step(True))

    def test_solve(self):
        grid = self.problem_definition["grid"]
        times = np.linspace(0, -0.1, 3)
        initial_values = np.linalg.norm(grid.states[..., :2], axis=-1) - 5

        def run_solve(progress_bar):
            return hj.solve(**self.problem_definition,
                            times=times,
                            initial_values=initial_values,
                            progress_bar=progress_bar)

        all_values = run_solve(False)
        self.assertEqual(all_values.shape, (len(times),) + initial_values.shape)
        np.testing.assert_allclose(all_values[0], initial_values)
        # The final slice must (approximately) match a single `step` straight to the final time.
        direct_values = hj.step(**self.problem_definition,
                                time=0.,
                                values=initial_values,
                                target_time=-0.1,
                                progress_bar=False)
        np.testing.assert_allclose(all_values[-1], direct_values, atol=1e-2)
        # Displaying a progress bar must not change the result.
        np.testing.assert_allclose(all_values, run_solve(True))
================================================
FILE: hj_reachability/systems/__init__.py
================================================
from hj_reachability.systems.air3d import Air3d, DubinsCarCAvoid
__all__ = ("Air3d", "DubinsCarCAvoid")
================================================
FILE: hj_reachability/systems/air3d.py
================================================
import jax.numpy as jnp
from hj_reachability import dynamics
from hj_reachability import sets
class Air3d(dynamics.ControlAndDisturbanceAffineDynamics):
    """Three-dimensional control- and disturbance-affine dynamics with state `(x, y, psi)`.

    The control and the disturbance are both one-dimensional. By default they are bounded in magnitude by
    `evader_max_turn_rate` and `pursuer_max_turn_rate`, respectively.
    """

    def __init__(self,
                 evader_speed=5.,
                 pursuer_speed=5.,
                 evader_max_turn_rate=1.,
                 pursuer_max_turn_rate=1.,
                 control_mode="max",
                 disturbance_mode="min",
                 control_space=None,
                 disturbance_space=None):
        """Initializes the dynamics.

        Args:
            evader_speed: Speed `v_a` appearing in the open loop dynamics.
            pursuer_speed: Speed `v_b` appearing in the open loop dynamics.
            evader_max_turn_rate: Bound on the control magnitude, used only if `control_space` is not provided.
            pursuer_max_turn_rate: Bound on the disturbance magnitude, used only if `disturbance_space` is not
                provided.
            control_mode: Forwarded to the base class.
            disturbance_mode: Forwarded to the base class.
            control_space: Set of admissible controls; defaults to `[-evader_max_turn_rate, evader_max_turn_rate]`.
            disturbance_space: Set of admissible disturbances; defaults to
                `[-pursuer_max_turn_rate, pursuer_max_turn_rate]`.
        """
        self.evader_speed = evader_speed
        self.pursuer_speed = pursuer_speed

        def symmetric_interval(bound):
            return sets.Box(jnp.array([-bound]), jnp.array([bound]))

        if control_space is None:
            control_space = symmetric_interval(evader_max_turn_rate)
        if disturbance_space is None:
            disturbance_space = symmetric_interval(pursuer_max_turn_rate)
        super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)

    def open_loop_dynamics(self, state, time):
        """Returns the state derivative under zero control and zero disturbance."""
        _, _, heading = state
        x_rate = -self.evader_speed + self.pursuer_speed * jnp.cos(heading)
        y_rate = self.pursuer_speed * jnp.sin(heading)
        return jnp.array([x_rate, y_rate, 0.])

    def control_jacobian(self, state, time):
        """Returns the `3 x 1` matrix multiplying the control in the dynamics."""
        x, y, _ = state
        rows = [
            [y],
            [-x],
            [-1.],
        ]
        return jnp.array(rows)

    def disturbance_jacobian(self, state, time):
        """Returns the `3 x 1` matrix multiplying the disturbance in the dynamics."""
        rows = [
            [0.],
            [0.],
            [1.],
        ]
        return jnp.array(rows)


# Alternative name for the same dynamics class.
DubinsCarCAvoid = Air3d
================================================
FILE: hj_reachability/time_integration.py
================================================
import functools
import jax
import jax.numpy as jnp
import numpy as np
from hj_reachability import utils
def lax_friedrichs_numerical_hamiltonian(hamiltonian, state, time, value, left_grad_value, right_grad_value,
                                         dissipation_coefficients):
    """Evaluates the Lax-Friedrichs numerical Hamiltonian at a single state.

    Args:
        hamiltonian: Function called as `hamiltonian(state, time, value, grad_value)`.
        state: The state at which to evaluate.
        time: The time at which to evaluate.
        value: The value at `state`.
        left_grad_value: Left approximation of the value gradient at `state`.
        right_grad_value: Right approximation of the value gradient at `state`.
        dissipation_coefficients: Per-dimension artificial dissipation coefficients.

    Returns:
        The Hamiltonian evaluated at the average of the left and right gradients, minus the dissipation term (the
        dissipation coefficients applied to half of the difference between the right and left gradients).
    """
    averaged_grad_value = (left_grad_value + right_grad_value) / 2
    hamiltonian_value = hamiltonian(state, time, value, averaged_grad_value)
    grad_value_gap = right_grad_value - left_grad_value
    dissipation_value = dissipation_coefficients @ grad_value_gap / 2
    return hamiltonian_value - dissipation_value
@functools.partial(jax.jit, static_argnames="dynamics")
def euler_step(solver_settings, dynamics, grid, time, values, time_step=None, max_time_step=None):
    """Takes a single explicit Euler step of the value function.

    Args:
        solver_settings: Settings providing the upwind scheme, artificial dissipation scheme, Hamiltonian
            postprocessor, and CFL number.
        dynamics: The system dynamics (a static argument under `jax.jit`).
        grid: The grid over which `values` are defined.
        time: The time corresponding to `values`.
        values: The values at `time`.
        time_step: The signed time step to take. If `None`, the step is chosen as the CFL-scaled time step bound,
            limited in magnitude by `max_time_step` and taken in the direction of `max_time_step`.
        max_time_step: The signed maximum time step; required if `time_step` is `None`.

    Returns:
        A tuple `(time + time_step, values_after_step)`.
    """
    time_direction = jnp.sign(max_time_step if time_step is None else time_step)

    def signed_hamiltonian(*args, **kwargs):
        return time_direction * dynamics.hamiltonian(*args, **kwargs)

    left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values)
    dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes,
                                                                             grid.states, time, values,
                                                                             left_grad_values, right_grad_values)

    def numerical_hamiltonian_at_node(state, value, left_grad_value, right_grad_value, dissipation_coefficients):
        return lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value, left_grad_value,
                                                    right_grad_value, dissipation_coefficients)

    # Evaluate the numerical Hamiltonian independently at every grid node.
    signed_numerical_hamiltonian = utils.multivmap(numerical_hamiltonian_at_node, np.arange(grid.ndim))(
        grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients)
    # The sign applied inside `signed_hamiltonian` is undone here, before postprocessing.
    dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * signed_numerical_hamiltonian)

    if time_step is None:
        time_step_bound = 1 / jnp.max(jnp.sum(dissipation_coefficients / jnp.array(grid.spacings), -1))
        step_magnitude = jnp.minimum(solver_settings.CFL_number * time_step_bound, jnp.abs(max_time_step))
        time_step = time_direction * step_magnitude
    # TODO: Think carefully about whether `solver_settings.value_postprocessor` should be applied here instead.
    return time + time_step, values + time_step * dvalues_dt
def first_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
    """Advances `(time, values)` by one Euler step that does not overshoot `target_time`.

    Returns:
        A tuple `(new_time, new_values)`, where `new_values` have been passed through
        `solver_settings.value_postprocessor`.
    """
    remaining_time = target_time - time
    new_time, euler_values = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=remaining_time)
    return new_time, solver_settings.value_postprocessor(new_time, euler_values)
def second_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
    """Advances `(time, values)` by one two-stage step that does not overshoot `target_time`.

    The step size is selected by the first Euler step and reused for the second; the result averages the initial
    values with the values after two Euler steps.

    Returns:
        A tuple `(new_time, new_values)`, where `new_values` have been passed through
        `solver_settings.value_postprocessor`.
    """
    remaining_time = target_time - time
    new_time, one_step_values = euler_step(solver_settings, dynamics, grid, time, values,
                                           max_time_step=remaining_time)
    time_step = new_time - time
    _, two_step_values = euler_step(solver_settings, dynamics, grid, new_time, one_step_values, time_step)
    averaged_values = (values + two_step_values) / 2
    return new_time, solver_settings.value_postprocessor(new_time, averaged_values)
def third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):
    """Advances `values` toward `target_time` using three Euler steps blended with fixed weights.

    Stages (all Euler steps after the first reuse the first step's size):
      1. Euler step from `(time, values)`, bounded by `target_time - time`.
      2. Euler step from the result of stage 1.
      3. Blend `3/4 * values + 1/4 * stage 2` at the half-step time, then Euler step from that blend.
    The returned values are `1/3 * values + 2/3 * stage 3`.

    Returns:
        A tuple of the time reached by the first Euler step and the blended values after
        `solver_settings.value_postprocessor`.
    """
    reached_time, stage_1_values = euler_step(solver_settings,
                                              dynamics,
                                              grid,
                                              time,
                                              values,
                                              max_time_step=target_time - time)
    step = reached_time - time
    _, stage_2_values = euler_step(solver_settings, dynamics, grid, reached_time, stage_1_values, step)
    half_step_time = time + step / 2
    half_step_values = (3 / 4) * values + (1 / 4) * stage_2_values
    _, stage_3_values = euler_step(solver_settings, dynamics, grid, half_step_time, half_step_values, step)
    blended_values = (1 / 3) * values + (2 / 3) * stage_3_values
    return reached_time, solver_settings.value_postprocessor(reached_time, blended_values)
================================================
FILE: hj_reachability/utils.py
================================================
import functools
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any, Callable, Iterable, List, Mapping, Optional, TypeVar, Union
T = TypeVar("T")
# Recursive tree type: either a leaf `T`, or an iterable/mapping whose elements are themselves `Tree[T]`.
Tree = Union[T, Iterable["Tree[T]"], Mapping[Any, "Tree[T]"]]


def multivmap(fun: Callable,
              in_axes: Tree[Optional[np.ndarray]],
              out_axes: Tree[Optional[np.ndarray]] = None) -> Callable:
    """Applies `jax.vmap` over multiple axes (equivalent to multiple nested `jax.vmap`s).

    Args:
        fun: Function to be mapped over additional axes (see `jax.vmap` for more details).
        in_axes: Similar to the specification of `in_axes` for `jax.vmap`, with the main difference being that instead
            of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding input of
            `fun`, the `np.ndarray` specifies a sequence of axes to `jax.vmap` over; note that these axes are not
            specified directly as a `list` so as not to conflict with the possible structure of `in_axes`. All
            non-`None` leaves of `in_axes` (there must be at least one) must have the same length. This length is the
            number of times `jax.vmap` will be applied to `fun`.
        out_axes: Similar to the specification of `out_axes` for `jax.vmap`, with the main difference being that instead
            of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding output of
            `fun`, the `np.ndarray` specifies a sequence of additional mapped axes to appear in the output. The length
            of non-`None` leaves of `out_axes` must be the same as the length of non-`None` leaves of `in_axes`; the
            order of both axes specifications corresponds to successive nested `jax.vmap` applications. If not provided,
            `out_axes` defaults to `in_axes`.

    Returns:
        A batched/vectorized version of `fun` with arguments that correspond to those of `fun`, but with (possibly
        multiple per input) extra array axes at positions indicated by `in_axes`, and a return value that corresponds
        to that of `fun`, but with (possibly multiple per output) extra array axes at positions indicated by `out_axes`.

    Raises:
        ValueError: if any specified axes are negative or repeated.
    """

    def get_axis_sequence(axis_array: np.ndarray) -> List:
        """Converts axes given with respect to the full array into the axis each nested `jax.vmap` must use."""
        axis_list = axis_array.tolist()
        if any(axis < 0 for axis in axis_list):
            raise ValueError(f"All `multivmap` axes must be nonnegative; got {axis_list}.")
        if len(axis_list) != len(set(axis_list)):
            raise ValueError(f"All `multivmap` axes must be distinct; got {axis_list}.")
        # The `jax.vmap` for a given entry is nested inside the `jax.vmap`s for all later entries, so by the time it
        # runs each of those later axes has already been removed from the array. Every removed axis that originally
        # preceded this one shifts it down by one. The count must be taken against the original axis value: comparing
        # against a partially-decremented value misses axes (e.g., `[2, 0, 1]` must become `[0, 0, 1]`).
        return [
            axis - sum(outer_axis < axis for outer_axis in axis_list[position + 1:])
            for position, axis in enumerate(axis_list)
        ]

    multivmap_kwargs = {"in_axes": in_axes, "out_axes": in_axes if out_axes is None else out_axes}
    # Tree structure of a single axis sequence (a flat list with one leaf per nested `jax.vmap`).
    axis_sequence_structure = jax.tree.structure(next(a for a in jax.tree.leaves(in_axes) if a is not None).tolist())
    # Turn {"in_axes": [...], "out_axes": [...]} into one {"in_axes": ..., "out_axes": ...} dict per `jax.vmap`.
    vmap_kwargs = jax.tree.transpose(jax.tree.structure(multivmap_kwargs), axis_sequence_structure,
                                     jax.tree.map(get_axis_sequence, multivmap_kwargs))
    # The first entry is applied first (innermost); the last entry ends up as the outermost `jax.vmap`.
    return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun)
def unit_vector(x):
    """Returns `x` scaled to unit length, or an all-zero array of the same shape when `x` is (numerically) zero.

    The square root is never evaluated at a (near-)zero squared norm: in that case its argument is replaced by 1,
    and the corresponding quotient is discarded in favor of zeros.
    """
    squared_norm = jnp.sum(jnp.square(x))
    # Threshold is the squared machine epsilon of the default floating point dtype.
    squared_eps = jnp.finfo(jnp.zeros(()).dtype).eps**2
    is_degenerate = squared_norm < squared_eps
    safe_norm = jnp.sqrt(jnp.where(is_degenerate, 1, squared_norm))
    return jnp.where(is_degenerate, jnp.zeros_like(x), x / safe_norm)
================================================
FILE: hj_reachability/utils_test.py
================================================
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
from hj_reachability import utils
class UtilsTest(absltest.TestCase):
    """Tests for `utils.multivmap` and `utils.unit_vector`."""

    def setUp(self):
        # Seed NumPy's global generator so the random arrays drawn in each test are reproducible.
        np.random.seed(0)

    def test_multivmap(self):
        """Compares `multivmap` of `jnp.max` against `np.max` reductions over the axes left unmapped."""
        array = np.random.random((3, 4, 5, 6))
        # Each case is (in_axes, out_axes, expected); `None` for out_axes exercises the default (same as in_axes).
        cases = (
            (np.array([0, 1]), None, np.max(array, (2, 3))),
            (np.array([0, 1, 2]), None, np.max(array, -1)),
            (np.array([0, 1, 3]), np.array([0, 1, 2]), np.max(array, 2)),
            (np.array([1, 0, 2]), np.array([0, 1, 2]), np.max(array, 3).swapaxes(0, 1)),
            (np.array([3, 2]), np.array([0, 1]), np.max(array, (0, 1)).swapaxes(0, 1)),
        )
        for in_axes, out_axes, expected in cases:
            mapped_max = utils.multivmap(jnp.max, in_axes, out_axes)
            np.testing.assert_allclose(mapped_max(array), expected)

    def test_unit_vector(self):
        """Checks values and Jacobians of `unit_vector`, including finiteness of the Jacobian at zero."""

        def naive_unit_vector(x):
            # Reference implementation without any guard against a zero norm.
            return x / jnp.linalg.norm(x, axis=-1, keepdims=True)

        safe_jacobian = jax.jacobian(utils.unit_vector)
        naive_jacobian = jax.jacobian(naive_unit_vector)
        for dim in (1, 2, 3):
            np.testing.assert_array_equal(utils.unit_vector(np.zeros(dim)), np.zeros(dim))
            # The guarded version has a finite Jacobian at zero, while the naive one yields NaNs there.
            self.assertTrue(np.all(np.isfinite(safe_jacobian(np.zeros(dim)))))
            self.assertTrue(np.all(np.isnan(naive_jacobian(np.zeros(dim)))))

            samples = np.random.random((100, dim))
            np.testing.assert_allclose(jax.vmap(utils.unit_vector)(samples), naive_unit_vector(samples), atol=1e-6)
            np.testing.assert_allclose(jax.vmap(safe_jacobian)(samples),
                                       jax.vmap(naive_jacobian)(samples),
                                       atol=1e-6)
# Run this module's tests via absl's test runner when the file is executed directly.
if __name__ == "__main__":
    absltest.main()
================================================
FILE: requirements-test.txt
================================================
absl-py>=0.12.0
tqdm>=4.60.0
================================================
FILE: requirements.txt
================================================
flax>=0.6.6
jax>=0.4.25
numpy>=1.22
================================================
FILE: setup.cfg
================================================
[yapf]
based_on_style = google
column_limit = 120
[flake8]
max-line-length = 120
ignore =
# E731: do not assign a lambda expression, use a def
E731
# E741: do not use variables named 'I', 'O', or 'l'
E741
# W504: line break occurred after a binary operator
W504
================================================
FILE: setup.py
================================================
import os
import setuptools
# Absolute path of the directory containing this file; the helpers below join it with the paths of files they read.
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def _get_version():
    """Extracts the package version from the `__version__` assignment in `hj_reachability/__init__.py`.

    Returns:
        The text after the first `=` on the first matching line, with surrounding spaces, quotes, and newlines
        stripped.

    Raises:
        ValueError: if no line starting with `__version__` assigns a non-empty value.
    """
    init_path = os.path.join(_CURRENT_DIR, "hj_reachability", "__init__.py")
    with open(init_path) as init_file:
        for line in init_file:
            if not line.startswith("__version__"):
                continue
            _, separator, assigned_text = line.partition("=")
            if not separator:
                # No `=` on this line, so it is not an assignment.
                continue
            candidate = assigned_text.strip(" '\"\n")
            if candidate:
                return candidate
    raise ValueError("`__version__` not defined in `hj_reachability/__init__.py`")
def _parse_requirements(file):
    """Reads requirement specifiers from a requirements file located next to this script.

    Args:
        file: Path of the requirements file, relative to the directory containing this script.

    Returns:
        A list with one entry per line of the file that is neither blank nor a `#` comment, with surrounding
        whitespace removed.
    """
    with open(os.path.join(_CURRENT_DIR, file)) as f:
        stripped_lines = [line.strip() for line in f]
    # Filter on the stripped text so that indented comments and whitespace-only lines are skipped as well.
    return [line for line in stripped_lines if line and not line.startswith("#")]
# Read the README relative to this script (rather than the current working directory), with an explicit encoding,
# and close the file as soon as its contents have been read.
with open(os.path.join(_CURRENT_DIR, "README.md"), encoding="utf-8") as _readme_file:
    _LONG_DESCRIPTION = _readme_file.read()

setuptools.setup(name="hj_reachability",
                 version=_get_version(),
                 description="Hamilton-Jacobi reachability analysis in JAX.",
                 long_description=_LONG_DESCRIPTION,
                 long_description_content_type="text/markdown",
                 author="Ed Schmerling",
                 author_email="ednerd@gmail.com",
                 url="https://github.com/StanfordASL/hj_reachability",
                 license="MIT",
                 packages=setuptools.find_packages(),
                 install_requires=_parse_requirements("requirements.txt"),
                 tests_require=_parse_requirements("requirements-test.txt"),
                 python_requires="~=3.8")
gitextract_6__9fzrr/ ├── .github/ │ └── workflows/ │ ├── ci.yml │ └── pypi-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples/ │ └── quickstart.ipynb ├── hj_reachability/ │ ├── __init__.py │ ├── artificial_dissipation.py │ ├── boundary_conditions.py │ ├── boundary_conditions_test.py │ ├── dynamics.py │ ├── finite_differences/ │ │ ├── __init__.py │ │ ├── upwind_first.py │ │ └── upwind_first_test.py │ ├── grid.py │ ├── grid_test.py │ ├── sets.py │ ├── sets_test.py │ ├── solver.py │ ├── solver_test.py │ ├── systems/ │ │ ├── __init__.py │ │ └── air3d.py │ ├── time_integration.py │ ├── utils.py │ └── utils_test.py ├── requirements-test.txt ├── requirements.txt ├── setup.cfg └── setup.py
SYMBOL INDEX (111 symbols across 17 files)
FILE: hj_reachability/artificial_dissipation.py
function global_lax_friedrichs (line 9) | def global_lax_friedrichs(partial_max_magnitudes, states, time, values, ...
function local_lax_friedrichs (line 18) | def local_lax_friedrichs(partial_max_magnitudes, states, time, values, l...
function local_local_lax_friedrichs (line 36) | def local_local_lax_friedrichs(partial_max_magnitudes, states, time, val...
FILE: hj_reachability/boundary_conditions.py
function periodic (line 10) | def periodic(x: Array, pad_width: int) -> Array:
function extrapolate (line 15) | def extrapolate(x: Array, pad_width: int) -> Array:
function extrapolate_away_from_zero (line 21) | def extrapolate_away_from_zero(x: Array, pad_width: int) -> Array:
FILE: hj_reachability/boundary_conditions_test.py
class BoundaryConditionsTest (line 7) | class BoundaryConditionsTest(absltest.TestCase):
method setUp (line 9) | def setUp(self):
method test_periodic (line 12) | def test_periodic(self):
method test_extrapolate (line 18) | def test_extrapolate(self):
method test_extrapolate_away_from_zero (line 24) | def test_extrapolate_away_from_zero(self):
FILE: hj_reachability/dynamics.py
class Dynamics (line 6) | class Dynamics(metaclass=abc.ABCMeta):
method __init__ (line 18) | def __init__(self, control_mode, disturbance_mode, control_space, dist...
method __call__ (line 25) | def __call__(self, state, control, disturbance, time):
method optimal_control_and_disturbance (line 29) | def optimal_control_and_disturbance(self, state, time, grad_value):
method optimal_control (line 32) | def optimal_control(self, state, time, grad_value):
method optimal_disturbance (line 36) | def optimal_disturbance(self, state, time, grad_value):
method hamiltonian (line 40) | def hamiltonian(self, state, time, value, grad_value):
method partial_max_magnitudes (line 47) | def partial_max_magnitudes(self, state, time, value, grad_value_box):
class ControlAndDisturbanceAffineDynamics (line 51) | class ControlAndDisturbanceAffineDynamics(Dynamics):
method __call__ (line 54) | def __call__(self, state, control, disturbance, time):
method open_loop_dynamics (line 60) | def open_loop_dynamics(self, state, time):
method control_jacobian (line 64) | def control_jacobian(self, state, time):
method disturbance_jacobian (line 68) | def disturbance_jacobian(self, state, time):
method optimal_control_and_disturbance (line 71) | def optimal_control_and_disturbance(self, state, time, grad_value):
method partial_max_magnitudes (line 82) | def partial_max_magnitudes(self, state, time, value, grad_value_box):
FILE: hj_reachability/finite_differences/upwind_first.py
function weighted_essentially_non_oscillatory (line 16) | def weighted_essentially_non_oscillatory(eno_order: int, values: Array, ...
function essentially_non_oscillatory (line 58) | def essentially_non_oscillatory(order: int, values: Array, spacing: float,
function _weighted_essentially_non_oscillatory_vectorized (line 109) | def _weighted_essentially_non_oscillatory_vectorized(
function _unrolled_correlate (line 135) | def _unrolled_correlate(a: Array, v: Array) -> Array:
function _substencils (line 140) | def _substencils(k: int) -> Array:
function _spread_substencil_values (line 145) | def _spread_substencil_values(x: Array, np: ModuleType = np) -> Array:
function _align_substencil_values (line 150) | def _align_substencil_values(x: Array, np: ModuleType = np) -> Array:
function _diff_coefficients (line 155) | def _diff_coefficients(k: Optional[int] = None, stencil: Optional[Array]...
function _substencil_coefficients (line 172) | def _substencil_coefficients(k: int) -> Array:
function _polyder_operator (line 180) | def _polyder_operator(k: int, d: int) -> Array:
function _smoothness_indicator_quad_form (line 185) | def _smoothness_indicator_quad_form(k: int) -> Array:
FILE: hj_reachability/finite_differences/upwind_first_test.py
class UpwindFirstTest (line 11) | class UpwindFirstTest(absltest.TestCase):
method setUp (line 13) | def setUp(self):
method test_weighted_essentially_non_oscillatory (line 16) | def test_weighted_essentially_non_oscillatory(self):
method test_essentially_non_oscillatory (line 44) | def test_essentially_non_oscillatory(self):
method test_weighted_essentially_non_oscillatory_vectorized (line 86) | def test_weighted_essentially_non_oscillatory_vectorized(self):
method test_diff_coefficients (line 97) | def test_diff_coefficients(self):
method test_substencil_coefficients (line 108) | def test_substencil_coefficients(self):
method test_smoothness_indicator_quad_form (line 118) | def test_smoothness_indicator_quad_form(self):
FILE: hj_reachability/grid.py
class Grid (line 19) | class Grid:
method from_lattice_parameters_and_boundary_conditions (line 42) | def from_lattice_parameters_and_boundary_conditions(
method ndim (line 78) | def ndim(self) -> int:
method shape (line 83) | def shape(self) -> Tuple[int, ...]:
method upwind_grad_values (line 87) | def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -...
method grad_values (line 98) | def grad_values(self, values: Array, upwind_scheme: Optional[Callable]...
method position (line 105) | def position(self, state: Array) -> Array:
method nearest_index (line 110) | def nearest_index(self, state: Array) -> Array:
method interpolate (line 114) | def interpolate(self, values, state):
method _is_periodic_dim (line 134) | def _is_periodic_dim(self) -> Array:
FILE: hj_reachability/grid_test.py
class BoundaryConditionsTest (line 10) | class BoundaryConditionsTest(absltest.TestCase):
method setUp (line 12) | def setUp(self):
method test_grid_interpolate (line 15) | def test_grid_interpolate(self):
method test_grid_interpolate_on_grid (line 28) | def test_grid_interpolate_on_grid(self):
method test_grid_interpolate_off_grid (line 45) | def test_grid_interpolate_off_grid(self):
method test_grid_interpolate_extrapolate_nan (line 67) | def test_grid_interpolate_extrapolate_nan(self):
FILE: hj_reachability/sets.py
class BoundedSet (line 14) | class BoundedSet(metaclass=abc.ABCMeta):
method extreme_point (line 18) | def extreme_point(self, direction: Array) -> Array:
method bounding_box (line 23) | def bounding_box(self) -> "Box":
method max_magnitudes (line 27) | def max_magnitudes(self) -> Array:
method ndim (line 32) | def ndim(self) -> int:
class Box (line 38) | class Box(BoundedSet):
method extreme_point (line 43) | def extreme_point(self, direction: Array) -> Array:
method bounding_box (line 48) | def bounding_box(self) -> "Box":
method ndim (line 53) | def ndim(self) -> int:
class Ball (line 59) | class Ball(BoundedSet):
method extreme_point (line 64) | def extreme_point(self, direction: Array) -> Array:
method bounding_box (line 69) | def bounding_box(self) -> "Box":
FILE: hj_reachability/sets_test.py
class SetsTest (line 8) | class SetsTest(absltest.TestCase):
method setUp (line 10) | def setUp(self):
method test_box (line 13) | def test_box(self):
method test_ball (line 21) | def test_ball(self):
FILE: hj_reachability/solver.py
class SolverSettings (line 24) | class SolverSettings:
method with_accuracy (line 48) | def with_accuracy(cls, accuracy: Text, **kwargs) -> "SolverSettings":
function step (line 65) | def step(solver_settings, dynamics, grid, time, values, target_time, pro...
function solve (line 80) | def solve(solver_settings, dynamics, grid, times, initial_values, progre...
function _try_get_progress_bar (line 93) | def _try_get_progress_bar(reference_time, target_time):
class TqdmWrapper (line 106) | class TqdmWrapper:
method __init__ (line 108) | def __init__(self, tqdm, reference_time, total, *args, **kwargs):
method _create_tqdm (line 117) | def _create_tqdm(self, tqdm, total, *args, **kwargs):
method update_to (line 120) | def update_to(self, n):
method close (line 128) | def close(self):
method __enter__ (line 135) | def __enter__(self):
method __exit__ (line 138) | def __exit__(self, exc_type, exc_value, traceback):
FILE: hj_reachability/solver_test.py
class SolverTest (line 7) | class SolverTest(absltest.TestCase):
method setUp (line 9) | def setUp(self):
method test_step (line 23) | def test_step(self):
method test_solve (line 31) | def test_solve(self):
FILE: hj_reachability/systems/air3d.py
class Air3d (line 7) | class Air3d(dynamics.ControlAndDisturbanceAffineDynamics):
method __init__ (line 9) | def __init__(self,
method open_loop_dynamics (line 26) | def open_loop_dynamics(self, state, time):
method control_jacobian (line 31) | def control_jacobian(self, state, time):
method disturbance_jacobian (line 39) | def disturbance_jacobian(self, state, time):
FILE: hj_reachability/time_integration.py
function lax_friedrichs_numerical_hamiltonian (line 10) | def lax_friedrichs_numerical_hamiltonian(hamiltonian, state, time, value...
function euler_step (line 18) | def euler_step(solver_settings, dynamics, grid, time, values, time_step=...
function first_order_total_variation_diminishing_runge_kutta (line 37) | def first_order_total_variation_diminishing_runge_kutta(solver_settings,...
function second_order_total_variation_diminishing_runge_kutta (line 42) | def second_order_total_variation_diminishing_runge_kutta(solver_settings...
function third_order_total_variation_diminishing_runge_kutta (line 49) | def third_order_total_variation_diminishing_runge_kutta(solver_settings,...
FILE: hj_reachability/utils.py
function multivmap (line 13) | def multivmap(fun: Callable,
function unit_vector (line 61) | def unit_vector(x):
FILE: hj_reachability/utils_test.py
class UtilsTest (line 9) | class UtilsTest(absltest.TestCase):
method setUp (line 11) | def setUp(self):
method test_multivmap (line 14) | def test_multivmap(self):
method test_unit_vector (line 26) | def test_unit_vector(self):
FILE: setup.py
function _get_version (line 7) | def _get_version():
function _parse_requirements (line 17) | def _parse_requirements(file):
Condensed preview — 30 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (84K chars).
[
{
"path": ".github/workflows/ci.yml",
"chars": 1016,
"preview": "name: ci\n\non: [push, pull_request]\n\njobs:\n test:\n name: \"Python ${{ matrix.python-version }} on ${{ matrix.os }}\"\n "
},
{
"path": ".github/workflows/pypi-publish.yml",
"chars": 926,
"preview": "name: pypi\n\non:\n release:\n types: [published]\n\njobs:\n deploy:\n runs-on: ubuntu-latest\n\n steps:\n - uses: ac"
},
{
"path": ".gitignore",
"chars": 1799,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "LICENSE",
"chars": 1070,
"preview": "MIT License\n\nCopyright (c) 2021 Ed Schmerling\n\nPermission is hereby granted, free of charge, to any person obtaining a c"
},
{
"path": "MANIFEST.in",
"chars": 55,
"preview": "include requirements.txt\ninclude requirements-test.txt\n"
},
{
"path": "README.md",
"chars": 2689,
"preview": "# hj_reachability: Hamilton-Jacobi reachability analysis in [JAX]\nThis package implements numerical solvers for Hamilton"
},
{
"path": "examples/quickstart.ipynb",
"chars": 9683,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# hj_reachability quickstart\\n\",\n "
},
{
"path": "hj_reachability/__init__.py",
"chars": 793,
"preview": "from hj_reachability import artificial_dissipation\nfrom hj_reachability import boundary_conditions\nfrom hj_reachability "
},
{
"path": "hj_reachability/artificial_dissipation.py",
"chars": 2616,
"preview": "import jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import sets\nfrom hj_reachability import util"
},
{
"path": "hj_reachability/boundary_conditions.py",
"chars": 998,
"preview": "import jax.numpy as jnp\n\nfrom typing import Any, Callable\n\nArray = Any\n\nBoundaryCondition = Callable[[Array, int], Array"
},
{
"path": "hj_reachability/boundary_conditions_test.py",
"chars": 1657,
"preview": "from absl.testing import absltest\nimport numpy as np\n\nfrom hj_reachability import boundary_conditions\n\n\nclass BoundaryCo"
},
{
"path": "hj_reachability/dynamics.py",
"chars": 4395,
"preview": "import abc\n\nimport jax.numpy as jnp\n\n\nclass Dynamics(metaclass=abc.ABCMeta):\n \"\"\"Abstract base class for representing"
},
{
"path": "hj_reachability/finite_differences/__init__.py",
"chars": 462,
"preview": "from hj_reachability.finite_differences.upwind_first import (ENO1, ENO2, ENO3, WENO1, WENO3, WENO5,\n "
},
{
"path": "hj_reachability/finite_differences/upwind_first.py",
"chars": 9686,
"preview": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport numpy.polynomial.polynomial as poly\n\nfrom"
},
{
"path": "hj_reachability/finite_differences/upwind_first_test.py",
"chars": 6201,
"preview": "import math\n\nfrom absl.testing import absltest\nimport jax\nimport numpy as np\n\nfrom hj_reachability import boundary_condi"
},
{
"path": "hj_reachability/grid.py",
"chars": 7166,
"preview": "import functools\n\nfrom flax import struct\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import bounda"
},
{
"path": "hj_reachability/grid_test.py",
"chars": 4659,
"preview": "from absl.testing import absltest\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import gri"
},
{
"path": "hj_reachability/sets.py",
"chars": 2216,
"preview": "import abc\n\nfrom flax import struct\nimport jax.numpy as jnp\n\nfrom hj_reachability import utils\n\nfrom typing import Any\n\n"
},
{
"path": "hj_reachability/sets_test.py",
"chars": 1188,
"preview": "from absl.testing import absltest\nimport jax\nimport numpy as np\n\nfrom hj_reachability import sets\n\n\nclass SetsTest(abslt"
},
{
"path": "hj_reachability/solver.py",
"chars": 5165,
"preview": "import contextlib\nimport functools\n\nfrom flax import struct\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom "
},
{
"path": "hj_reachability/solver_test.py",
"chars": 2319,
"preview": "from absl.testing import absltest\nimport numpy as np\n\nimport hj_reachability as hj\n\n\nclass SolverTest(absltest.TestCase)"
},
{
"path": "hj_reachability/systems/__init__.py",
"chars": 105,
"preview": "from hj_reachability.systems.air3d import Air3d, DubinsCarCAvoid\n\n__all__ = (\"Air3d\", \"DubinsCarCAvoid\")\n"
},
{
"path": "hj_reachability/systems/air3d.py",
"chars": 1499,
"preview": "import jax.numpy as jnp\n\nfrom hj_reachability import dynamics\nfrom hj_reachability import sets\n\n\nclass Air3d(dynamics.Co"
},
{
"path": "hj_reachability/time_integration.py",
"chars": 3620,
"preview": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import utils\n\n\ndef lax_fri"
},
{
"path": "hj_reachability/utils.py",
"chars": 3720,
"preview": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom typing import Any, Callable, Iterable, Lis"
},
{
"path": "hj_reachability/utils_test.py",
"chars": 1765,
"preview": "from absl.testing import absltest\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import uti"
},
{
"path": "requirements-test.txt",
"chars": 29,
"preview": "absl-py>=0.12.0\ntqdm>=4.60.0\n"
},
{
"path": "requirements.txt",
"chars": 36,
"preview": "flax>=0.6.6\njax>=0.4.25\nnumpy>=1.22\n"
},
{
"path": "setup.cfg",
"chars": 287,
"preview": "[yapf]\nbased_on_style = google\ncolumn_limit = 120\n\n[flake8]\nmax-line-length = 120\nignore =\n # E731: do not assign a l"
},
{
"path": "setup.py",
"chars": 1403,
"preview": "import os\nimport setuptools\n\n_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))\n\n\ndef _get_version():\n with op"
}
]
About this extraction
This page contains the full source code of the StanfordASL/hj_reachability GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (77.4 KB), approximately 20.9k tokens, and a symbol index with 111 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.