Repository: StanfordASL/hj_reachability Branch: main Commit: 9aebaf974406 Files: 30 Total size: 77.4 KB Directory structure: gitextract_6__9fzrr/ ├── .github/ │ └── workflows/ │ ├── ci.yml │ └── pypi-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples/ │ └── quickstart.ipynb ├── hj_reachability/ │ ├── __init__.py │ ├── artificial_dissipation.py │ ├── boundary_conditions.py │ ├── boundary_conditions_test.py │ ├── dynamics.py │ ├── finite_differences/ │ │ ├── __init__.py │ │ ├── upwind_first.py │ │ └── upwind_first_test.py │ ├── grid.py │ ├── grid_test.py │ ├── sets.py │ ├── sets_test.py │ ├── solver.py │ ├── solver_test.py │ ├── systems/ │ │ ├── __init__.py │ │ └── air3d.py │ ├── time_integration.py │ ├── utils.py │ └── utils_test.py ├── requirements-test.txt ├── requirements.txt ├── setup.cfg └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/ci.yml ================================================ name: ci on: [push, pull_request] jobs: test: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" runs-on: "${{ matrix.os }}" strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] os: [ubuntu-latest] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | set -xe python -m pip install --upgrade pip pip install flake8 pytest pytest-xdist yapf pip install -r requirements.txt pip install -r requirements-test.txt - name: Lint with flake8 run: | set -xe flake8 . --config=setup.cfg --count --statistics - name: Check formatting with yapf run: | set -xe yapf . 
--style=setup.cfg --recursive --diff - name: Test with pytest run: | set -xe pytest -n "$(grep -c ^processor /proc/cpuinfo)" hj_reachability ================================================ FILE: .github/workflows/pypi-publish.yml ================================================ name: pypi on: release: types: [published] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies run: | python -m pip install --upgrade pip pip install build setuptools - name: Check consistency between the package version and release tag run: | PACKAGE_VER="v`python setup.py --version`" if [ $PACKAGE_VER != ${{ github.event.release.tag_name }} ] then echo "Package version ($PACKAGE_VER) != release tag (${{ github.event.release.tag_name }})."; exit 1 fi - name: Build package run: python -m build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Ed Schmerling Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial 
portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include requirements.txt include requirements-test.txt ================================================ FILE: README.md ================================================ # hj_reachability: Hamilton-Jacobi reachability analysis in [JAX] This package implements numerical solvers for Hamilton-Jacobi (HJ) Partial Differential Equations (PDEs) which, in the context of optimal control, may be used to represent the continuous-time formulation of dynamic programming. Specifically, the focus of this package is on reachability analysis for zero-sum differential games modeled as Hamilton-Jacobi-Isaacs (HJI) PDEs, wherein an optimal controller and (optional) disturbance interact, and the set of reachable states at any time is represented as the zero sublevel set of a value function realized as the viscosity solution of the corresponding PDE. 
This package is inspired by a number of related projects, including: - [A Toolbox of Level Set Methods (`toolboxls`, MATLAB)](https://www.cs.ubc.ca/~mitchell/ToolboxLS/) - [An Optimal Control Toolbox for Hamilton-Jacobi Reachability Analysis (`helperOC`, MATLAB)](https://github.com/HJReachability/helperOC) - [Berkeley Efficient API in C++ for Level Set methods (`beacls`, C++/CUDA C++)](https://hjreachability.github.io/beacls/) - [Optimizing Dynamic Programming-Based Algorithms (`optimized_dp`, Python)](https://github.com/SFU-MARS/optimized_dp) ## Installation This package accommodates different [JAX] versions (e.g., CPU-only vs. JAX with GPU support); if accelerator support is desired, you should first install JAX according to the relevant [installation instructions](https://github.com/google/jax#installation). A minimum JAX version requirement is listed in [`requirements.txt`](https://github.com/StanfordASL/hj_reachability/blob/main/requirements.txt), but in general this package should be compatible with the latest JAX releases (please [file an issue](https://github.com/StanfordASL/hj_reachability/issues) if you find that this is no longer the case!). If you only want CPU computation or have already installed JAX with your preferred accelerator support, you may install this package using pip: ``` pip install --upgrade hj-reachability ``` ## TODOs Aside from the specific TODOs scattered throughout the codebase, a few general TODOs: - Single-line docstrings (at a bare minimum) for everything. Test coverage, book/paper references, and proper documentation to come... eventually. - Look into using `jax.pmap`/`jax.lax.ppermute` for multi-device parallelism; see, e.g., [jax demo notebooks](https://github.com/google/jax/tree/master/cloud_tpu_colabs). - Incorporate neural-network-based PDE solvers; see, e.g., [Bansal, S. and Tomlin, C. "DeepReach: A Deep Learning Approach to High-Dimensional Reachability." (2020)](https://arxiv.org/abs/2011.02082).
[JAX]: https://github.com/google/jax ================================================ FILE: examples/quickstart.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# hj_reachability quickstart\n", "\n", "Notebook dependencies:\n", "- System: python3, ffmpeg (for rendering animations)\n", "- Python: jupyter, jax, numpy, matplotlib, plotly, tqdm, hj_reachability\n", "\n", "Example setup for an Ubuntu system (Mac users, maybe `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):\n", "```\n", "sudo apt install ffmpeg\n", "/usr/bin/python3 -m pip install --upgrade pip\n", "pip install --upgrade jupyter jax[cpu] numpy matplotlib plotly tqdm hj-reachability\n", "jupyter notebook # from the directory of this notebook\n", "```\n", "Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/hj_reachability/blob/main/examples/quickstart.ipynb) and run a cell containing this command:\n", "```\n", "!pip install --upgrade hj-reachability\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "from IPython.display import HTML\n", "import matplotlib.animation as anim\n", "import matplotlib.pyplot as plt\n", "import plotly.graph_objects as go\n", "\n", "import hj_reachability as hj" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example system: `Air3d`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dynamics = hj.systems.Air3d()\n", "grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(np.array([-6., -10., 0.]),\n", " np.array([20., 10., 2 * np.pi])),\n", " (51, 40, 50),\n", " periodic_dims=2)\n", "values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5\n", "\n", "solver_settings
= hj.SolverSettings.with_accuracy(\"very_high\",\n", " hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `hj.step`: propagate the HJ PDE from `(time, values)` to `target_time`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "time = 0.\n", "target_time = -2.8\n", "target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.jet()\n", "plt.figure(figsize=(13, 8))\n", "plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values[:, :, 30].T)\n", "plt.colorbar()\n", "plt.contour(grid.coordinate_vectors[0],\n", " grid.coordinate_vectors[1],\n", " target_values[:, :, 30].T,\n", " levels=0,\n", " colors=\"black\",\n", " linewidths=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "go.Figure(data=go.Isosurface(x=grid.states[..., 0].ravel(),\n", " y=grid.states[..., 1].ravel(),\n", " z=grid.states[..., 2].ravel(),\n", " value=target_values.ravel(),\n", " colorscale=\"jet\",\n", " isomin=0,\n", " surface_count=1,\n", " isomax=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `hj.solve`: solve for `all_values` at a range of `times` (basically just iterating `hj.step`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "times = np.linspace(0, -2.8, 57)\n", "initial_values = values\n", "all_values = hj.solve(solver_settings, dynamics, grid, times, initial_values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vmin, vmax = all_values.min(), all_values.max()\n", "levels = np.linspace(round(vmin), round(vmax), round(vmax) - round(vmin) + 1)\n", "fig = plt.figure(figsize=(13, 8))\n", "\n", "\n", "def render_frame(i, colorbar=False):\n", " plt.contourf(grid.coordinate_vectors[0],\n", " grid.coordinate_vectors[1],\n", " all_values[i, :, :, 30].T,\n", " vmin=vmin,\n", " vmax=vmax,\n", " levels=levels)\n", " if colorbar:\n", " plt.colorbar()\n", " plt.contour(grid.coordinate_vectors[0],\n", " grid.coordinate_vectors[1],\n", " target_values[:, :, 30].T,\n", " levels=0,\n", " colors=\"black\",\n", " linewidths=3)\n", "\n", "\n", "render_frame(0, True)\n", "animation = HTML(anim.FuncAnimation(fig, render_frame, all_values.shape[0], interval=50).to_html5_video())\n", "plt.close(); animation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining your own dynamics: `AccelerationCurvatureCar`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AccelerationCurvatureCar(hj.ControlAndDisturbanceAffineDynamics):\n", "\n", " def __init__(self,\n", " max_acceleration=1.,\n", " max_curvature=1.,\n", " max_position_disturbance=0.25,\n", " control_mode=\"min\",\n", " disturbance_mode=\"max\",\n", " control_space=None,\n", " disturbance_space=None):\n", " if control_space is None:\n", " control_space = hj.sets.Box(jnp.array([-max_acceleration, -max_curvature]),\n", " jnp.array([max_acceleration, max_curvature]))\n", " if disturbance_space is None:\n", " disturbance_space = hj.sets.Ball(jnp.zeros(2), max_position_disturbance)\n", " super().__init__(control_mode, disturbance_mode, 
control_space, disturbance_space)\n", "\n", " def open_loop_dynamics(self, state, time):\n", " _, _, v, q = state\n", " return jnp.array([v * jnp.cos(q), v * jnp.sin(q), 0., 0.])\n", "\n", " def control_jacobian(self, state, time):\n", " v = state[2]\n", " return jnp.array([\n", " [0., 0.],\n", " [0., 0.],\n", " [1., 0.],\n", " [0., v],\n", " ])\n", "\n", " def disturbance_jacobian(self, state, time):\n", " return jnp.array([\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 0.],\n", " [0., 0.],\n", " ])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dynamics = AccelerationCurvatureCar()\n", "grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([-5., -5., -1., -np.pi]),\n", " hi=np.array([5., 5., 1., np.pi])),\n", " (40, 40, 50, 50),\n", " periodic_dims=3)\n", "values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 1\n", "\n", "solver_settings = hj.SolverSettings.with_accuracy(\"low\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "time = 0.\n", "target_time = -2.0\n", "target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "go.Figure(data=go.Isosurface(x=grid.states[:, :, -1, :, 0].ravel(),\n", " y=grid.states[:, :, -1, :, 1].ravel(),\n", " z=grid.states[:, :, -1, :, 3].ravel(),\n", " value=target_values[:, :, -1, :].ravel(),\n", " colorscale='jet',\n", " isomin=0,\n", " surface_count=1,\n", " isomax=0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", 
"version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }
================================================
FILE: hj_reachability/__init__.py
================================================
"""hj_reachability: Hamilton-Jacobi reachability analysis in JAX.

Re-exports the package submodules plus the most commonly used classes and functions at the top level.
"""
from hj_reachability import artificial_dissipation
from hj_reachability import boundary_conditions
from hj_reachability import finite_differences
from hj_reachability import sets
from hj_reachability import solver
from hj_reachability import systems
from hj_reachability import time_integration
from hj_reachability import utils
from hj_reachability.dynamics import ControlAndDisturbanceAffineDynamics, Dynamics
from hj_reachability.grid import Grid
from hj_reachability.solver import SolverSettings, solve, step

__version__ = "0.7.0"

# Public names exposed by `from hj_reachability import *`.
__all__ = ("ControlAndDisturbanceAffineDynamics", "Dynamics", "Grid", "SolverSettings", "artificial_dissipation",
           "boundary_conditions", "finite_differences", "sets", "solve", "solver", "step", "systems", "time_integration",
           "utils")
================================================
FILE: hj_reachability/artificial_dissipation.py
================================================
"""Lax-Friedrichs-type schemes for computing artificial dissipation coefficients.

All three schemes share one signature. They differ only in how tightly they bound the gradient values
(`grad_value_box`) handed to `partial_max_magnitudes` at each grid point: global, local, or local-local.
"""
import jax
import jax.numpy as jnp
import numpy as np

from hj_reachability import sets
from hj_reachability import utils


def global_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Implements the Global Lax-Friedrichs (GLF) scheme for computing dissipation coefficients.

    A single `sets.Box` is formed that bounds every left and right gradient approximation over the whole grid. The
    same box is then passed to `partial_max_magnitudes` at every grid point.

    Args:
        partial_max_magnitudes: Callable `(state, time, value, grad_value_box) -> Array`; for example
            `Dynamics.partial_max_magnitudes` has this signature.
        states: Grid states; mapped over the leading `values.ndim` axes.
        time: The current time, passed through unchanged.
        values: Value function samples on the grid.
        left_grad_values: Left gradient approximations. They are reduced over the grid axes, so presumably they have
            shape `values.shape + (values.ndim,)` (TODO confirm against callers).
        right_grad_values: Right gradient approximations, with the same layout as `left_grad_values`.

    Returns:
        The output of `partial_max_magnitudes`, evaluated via `utils.multivmap` over all grid axes.
    """
    grid_axes = np.arange(values.ndim)
    # Componentwise min/max over all grid points of both one-sided gradients -> one global box.
    grad_value_box = sets.Box(jnp.minimum(jnp.min(left_grad_values, grid_axes), jnp.min(right_grad_values, grid_axes)),
                              jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))
    # `grad_value_box` is closed over, so every grid point sees the same (global) bounds.
    return utils.multivmap(lambda state, value: partial_max_magnitudes(state, time, value, grad_value_box),
                           grid_axes)(states, values)


def local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Implements the Local Lax-Friedrichs (LLF) scheme for computing dissipation coefficients.

    Each grid point gets an `(ndim, ndim)`-shaped box.
    - Diagonal entries hold that point's own (local) left/right gradient bounds.
    - Off-diagonal entries hold the global bounds used by `global_lax_friedrichs`.
    Presumably row `i` is the box used for the dimension-`i` partial (TODO confirm against implementations of
    `partial_max_magnitudes` that actually read `grad_value_box`).

    Args: same as `global_lax_friedrichs`.

    Returns:
        The output of `partial_max_magnitudes` at every grid point, each evaluated with its own box.
    """
    grid_axes = np.arange(values.ndim)
    global_grad_value_box = sets.Box(
        jnp.minimum(jnp.min(left_grad_values, grid_axes), jnp.min(right_grad_values, grid_axes)),
        jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))
    # Pointwise bounds: just the min/max of the two one-sided approximations at each grid point.
    local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),
                                            jnp.maximum(left_grad_values, right_grad_values))
    # Applied separately to the box's lower and upper bounds (`jax.tree.map` over the two `Box`es).
    # 1. Broadcast the global bound to `values.shape + (ndim, ndim)`.
    # 2. Overwrite the diagonal `[..., i, i]` with the local-local bound.
    local_grad_value_boxes = jax.tree.map(
        lambda global_grad_value, local_local_grad_values:
        (jnp.broadcast_to(global_grad_value, values.shape +
                          (values.ndim,) * 2).at[..., grid_axes, grid_axes].set(local_local_grad_values)),
        global_grad_value_box, local_local_grad_value_boxes)
    return utils.multivmap(
        lambda state, value, grad_value_box: partial_max_magnitudes(state, time, value, grad_value_box),
        grid_axes)(states, values, local_grad_value_boxes)


def local_local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):
    """Implements the Local Local Lax-Friedrichs (LLLF) scheme for computing dissipation coefficients.

    Each grid point's box spans only that point's own left/right gradient approximations; no global bounds are used.

    Args: same as `global_lax_friedrichs`.

    Returns:
        The output of `partial_max_magnitudes` at every grid point, each evaluated with its pointwise box.
    """
    grid_axes = np.arange(values.ndim)
    local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),
                                            jnp.maximum(left_grad_values, right_grad_values))
    return utils.multivmap(
        lambda state, value, grad_value_box: partial_max_magnitudes(state, time, value, grad_value_box),
        grid_axes)(states, values, local_local_grad_value_boxes)
================================================
FILE: hj_reachability/boundary_conditions.py
================================================
"""1D padding functions that implement grid boundary conditions.

Each function has the `BoundaryCondition` signature `(x, pad_width) -> padded`. It returns `x` with `pad_width`
extra entries on each side.
"""
import jax.numpy as jnp

from typing import Any, Callable

Array = Any
# Signature shared by every boundary condition: (1D array, pad width) -> array padded by `pad_width` on each side.
BoundaryCondition = Callable[[Array, int], Array]


def periodic(x: Array, pad_width: int) -> Array:
    """Pads a 1D array `x` by wrapping values, using the start values to pad the end and vice versa.

    Example: `periodic([0, 1, 2, 3, 4], 1) -> [4, 0, 1, 2, 3, 4, 0]`.
    """
    # `((pad_width, pad_width))` is a single (before, after) pair; the outer parentheses only group. `jnp.pad`
    # accepts a single pair for a 1D input.
    return jnp.pad(x, ((pad_width, pad_width)), "wrap")


def extrapolate(x: Array, pad_width: int) -> Array:
    """Pads a 1D array `x` by extrapolating using the slope at each end.

    Linear extrapolation with the first (last) difference at the start (end). This reads `x[1]` and `x[-2]`, so
    `x` needs at least two entries.

    Example: `extrapolate([0, 1, 2, 3, 4], 1) -> [-1, 0, 1, 2, 3, 4, 5]`.
    """
    return jnp.concatenate(
        [x[0] + (x[1] - x[0]) * jnp.arange(-pad_width, 0), x, x[-1] + (x[-1] - x[-2]) * jnp.arange(1, pad_width + 1)])


def extrapolate_away_from_zero(x: Array, pad_width: int) -> Array:
    """Pads a 1D array `x` by extrapolating away from zero using the (possibly negated) slope at each end.

    The padded values keep moving away from zero at the magnitude of the end slope, whatever the slope's sign.
    For example, `[1, 2, 3, 4]` padded by 1 gives `[2, 1, 2, 3, 4, 5]`. This reads `x[1]` and `x[-2]`, so `x`
    needs at least two entries.
    """
    # Start: `jnp.arange(-pad_width, 0)` is negative, so subtracting `sign(x[0]) * |slope| * arange` pushes the padded
    # values further from zero in the direction of `x[0]`'s sign. The end is handled symmetrically.
    return jnp.concatenate([
        x[0] - jnp.sign(x[0]) * jnp.abs(x[1] - x[0]) * jnp.arange(-pad_width, 0), x,
        x[-1] + jnp.sign(x[-1]) * jnp.abs(x[-1] - x[-2]) * jnp.arange(1, pad_width + 1)
    ])
================================================
FILE: hj_reachability/boundary_conditions_test.py
================================================
from absl.testing import absltest
import numpy as np

from hj_reachability import boundary_conditions


class BoundaryConditionsTest(absltest.TestCase):
    """Checks each padding function on small integer arrays with pad widths 0, 1 and 2."""

    def setUp(self):
        # Seeds NumPy's global RNG. None of the tests below draw random numbers.
        np.random.seed(0)

    def test_periodic(self):
        """Wrapping pads the start with trailing values and the end with leading values."""
        x = np.arange(5)
        np.testing.assert_array_equal(boundary_conditions.periodic(x, 0), x)
        np.testing.assert_array_equal(boundary_conditions.periodic(x, 1), [4, 0, 1, 2, 3, 4, 0])
        np.testing.assert_array_equal(boundary_conditions.periodic(x, 2), [3, 4, 0, 1, 2, 3, 4, 0, 1])

    def test_extrapolate(self):
        """Linear data is extended linearly in both directions."""
        x = np.arange(5)
        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 0), x)
        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 1), np.arange(-1, 6))
        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 2), np.arange(-2, 7))

    def test_extrapolate_away_from_zero(self):
        """Padding grows away from zero at both ends, for increasing and decreasing inputs."""
        x = np.arange(1, 5)
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 0), x)
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 1), [2, 1, 2, 3, 4, 5])
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 2), [3, 2, 1, 2, 3, 4, 5, 6])
        # Same data reversed: the slope at each end flips sign but padding must still move away from zero.
        x = x[::-1]
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 0), x)
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 1), [5, 4, 3, 2, 1, 2])
        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 2), [6, 5, 4, 3, 2, 1, 2, 3])


if __name__ == "__main__":
    absltest.main()
================================================
FILE: hj_reachability/dynamics.py
================================================
"""Abstract dynamics interfaces used to pose Hamilton-Jacobi reachability problems."""
import abc

import jax.numpy as jnp


class Dynamics(metaclass=abc.ABCMeta):
    """Abstract base class for representing continuous-time dynamics in the context of Hamilton-Jacobi reachability.

    TODO: Consider allowing for state/time-dependent control/disturbance spaces.

    Attributes:
        control_mode: Whether the controller is trying to "max"imize or "min"imize the value.
        disturbance_mode: Whether the disturbance is trying to "max"imize or "min"imize the value.
        control_space: A `BoundedSet` defining the (time-invariant) set of possible controls.
        disturbance_space: A `BoundedSet` defining the (time-invariant) set of possible disturbances.
    """

    def __init__(self, control_mode, disturbance_mode, control_space, disturbance_space):
        # Stored as-is; the mode strings are not validated here.
        self.control_mode = control_mode
        self.disturbance_mode = disturbance_mode
        self.control_space = control_space
        self.disturbance_space = disturbance_space

    @abc.abstractmethod
    def __call__(self, state, control, disturbance, time):
        """Implements the continuous-time dynamics ODE (returns the state time-derivative)."""

    @abc.abstractmethod
    def optimal_control_and_disturbance(self, state, time, grad_value):
        """Computes the optimal control and disturbance realized by the HJ PDE Hamiltonian.

        Returns:
            A `(control, disturbance)` pair; `optimal_control` / `optimal_disturbance` index into it.
        """

    def optimal_control(self, state, time, grad_value):
        """Computes the optimal control realized by the HJ PDE Hamiltonian."""
        return self.optimal_control_and_disturbance(state, time, grad_value)[0]

    def optimal_disturbance(self, state, time, grad_value):
        """Computes the optimal disturbance realized by the HJ PDE Hamiltonian."""
        return self.optimal_control_and_disturbance(state, time, grad_value)[1]

    def hamiltonian(self, state, time, value, grad_value):
        """Evaluates the HJ PDE Hamiltonian.

        Computed as `grad_value @ self(state, u*, d*, time)`, where `(u*, d*)` comes from
        `optimal_control_and_disturbance`. `value` is accepted but not used by this default implementation.
        """
        del value  # unused
        control, disturbance = self.optimal_control_and_disturbance(state, time, grad_value)
        return grad_value @ self(state, control, disturbance, time)

    @abc.abstractmethod
    def partial_max_magnitudes(self, state, time, value, grad_value_box):
        """Computes the max magnitudes of the Hamiltonian partials over the `grad_value_box` in each dimension."""


class ControlAndDisturbanceAffineDynamics(Dynamics):
    """Abstract base class for representing control- and disturbance-affine dynamics.

    Subclasses implement `open_loop_dynamics`, `control_jacobian` and `disturbance_jacobian`. The ODE, the optimal
    inputs and the dissipation bound are derived from these three methods.
    """

    def __call__(self, state, control, disturbance, time):
        """Implements the affine dynamics `dx_dt = f(x, t) + G_u(x, t) @ u + G_d(x, t) @ d`."""
        return (self.open_loop_dynamics(state, time) + self.control_jacobian(state, time) @ control +
                self.disturbance_jacobian(state, time) @ disturbance)

    @abc.abstractmethod
    def open_loop_dynamics(self, state, time):
        """Implements the open loop dynamics `f(x, t)`."""

    @abc.abstractmethod
    def control_jacobian(self, state, time):
        """Implements the control Jacobian `G_u(x, t)`."""

    @abc.abstractmethod
    def disturbance_jacobian(self, state, time):
        """Implements the disturbance Jacobian `G_d(x, t)`."""

    def optimal_control_and_disturbance(self, state, time, grad_value):
        """Computes the optimal control and disturbance realized by the HJ PDE Hamiltonian.

        In affine dynamics, an input enters the Hamiltonian only through `(grad_value @ G) @ input`. The optimum is
        therefore an extreme point of the input set in direction `grad_value @ G`; the direction is negated when the
        player minimizes.
        """
        control_direction = grad_value @ self.control_jacobian(state, time)
        # Any mode other than "min" (e.g. "max") keeps the direction, i.e. is treated as maximizing.
        if self.control_mode == "min":
            control_direction = -control_direction
        disturbance_direction = grad_value @ self.disturbance_jacobian(state, time)
        if self.disturbance_mode == "min":
            disturbance_direction = -disturbance_direction
        # NOTE(review): presumably `extreme_point(d)` returns the set element maximizing `d @ point`; confirm in `sets`.
        return (self.control_space.extreme_point(control_direction),
                self.disturbance_space.extreme_point(disturbance_direction))

    def partial_max_magnitudes(self, state, time, value, grad_value_box):
        """Computes the max magnitudes of the Hamiltonian partials over the `grad_value_box` in each dimension.

        Returns the componentwise bound `|f| + |G_u| @ max|u| + |G_d| @ max|d|` on the dynamics. It depends only on
        `state` and `time`; `value` and `grad_value_box` are ignored.
        """
        del value, grad_value_box  # unused
        # An overestimation; see Eq. (25) from https://www.cs.ubc.ca/~mitchell/ToolboxLS/toolboxLS-1.1.pdf.
        # NOTE(review): assumes `max_magnitudes` is the per-component max |input_i| over the set; confirm in `sets`.
        return (jnp.abs(self.open_loop_dynamics(state, time)) +
                jnp.abs(self.control_jacobian(state, time)) @ self.control_space.max_magnitudes +
                jnp.abs(self.disturbance_jacobian(state, time)) @ self.disturbance_space.max_magnitudes)
================================================
FILE: hj_reachability/finite_differences/__init__.py
================================================
"""Finite-difference schemes; currently only upwind first-derivative (ENO/WENO) approximations."""
from hj_reachability.finite_differences.upwind_first import (ENO1, ENO2, ENO3, WENO1, WENO3, WENO5,
                                                             essentially_non_oscillatory, first_order,
                                                             weighted_essentially_non_oscillatory)

__all__ = ("ENO1", "ENO2", "ENO3", "WENO1", "WENO3", "WENO5", "essentially_non_oscillatory", "first_order",
           "weighted_essentially_non_oscillatory")
================================================
FILE: hj_reachability/finite_differences/upwind_first.py
================================================
"""Upwind (left/right) first-derivative approximations on uniform 1D grids: ENO and WENO schemes.

The stencil coefficients and smoothness-indicator quadratic forms are derived numerically from polynomial
interpolation (see the private helpers at the bottom of this module) rather than hard-coded.
"""
import functools

import jax
import jax.numpy as jnp
import numpy as np
import numpy.polynomial.polynomial as poly

from types import ModuleType
from typing import Any, Callable, Optional, Tuple

Array = Any

# Regularizer added to the smoothness indicators before squaring in the WENO weight denominators (avoids 1 / 0).
WENO_EPS = 1e-6


def weighted_essentially_non_oscillatory(eno_order: int, values: Array, spacing: float,
                                         boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:
    """Implements an upwind weighted essentially non-oscillatory (WENO) scheme for first derivative approximation.

    Args:
        eno_order: The order of the underlying essentially non-oscillatory (ENO) scheme; the resulting WENO scheme is
            `(2 * eno_order - 1)`th-order accurate.
        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.
        spacing: Grid spacing of the `values`.
        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).

    Returns:
        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,
        respectively, left and right approximations of the first derivative at the grid points of `values`.

    Raises:
        ValueError: If `eno_order < 1`.
    """
    if eno_order < 1:
        raise ValueError(f"`eno_order` must be at least 1; got {eno_order}.")
    # Pad `eno_order` ghost values on each side so every output point has a full stencil.
    values = boundary_condition(values, eno_order)
    # First divided differences between neighboring (padded) samples.
    diffs = (values[1:] - values[:-1]) / spacing
    if eno_order == 1:
        # First order: backward difference (left) and forward difference (right).
        return (diffs[:-1], diffs[1:])
    # One candidate derivative per substencil (`eno_order + 1` of them; see `_substencils`), each a correlation of
    # `diffs` with that substencil's coefficients, offset so all candidates are aligned.
    substencil_approximations = tuple(
        _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)
        for (i, c) in enumerate(_diff_coefficients(eno_order)))
    diffs2 = diffs[1:] - diffs[:-1]
    # Smoothness indicators are quadratic forms `d.T @ Q @ d` in the second differences `d`. With the Cholesky
    # factorization `Q = L @ L.T` (L lower triangular), this equals `sum_j (L[j:, j] @ d[j:])**2`, computed here as a
    # sum of squared correlations. Indicator `i` uses the quadratic form of (shifted) substencil `i`.
    smoothness_indicators = [
        sum(
            _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2
            for j in range(eno_order - 1))
        for (i, L) in enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))
    ]
    # Nonlinear weights `c / (indicator + WENO_EPS)**2`: the linear coefficients `c` are damped on non-smooth
    # substencils. `_substencil_coefficients` has two rows: row 0 gives the left weights and row 1 the right weights.
    # The slice `s[i:len(s) + i - 1]` shifts the indicators by one point for the right-biased case.
    left_and_right_unnormalized_weights = [[
        c / (s[i:len(s) + i - 1] + WENO_EPS)**2 for (c, s) in zip(coefficients, smoothness_indicators)
    ] for (i, coefficients) in enumerate(_substencil_coefficients(eno_order))]
    # Normalized weighted averages. The left derivative uses substencils `0..eno_order-1`; the right derivative uses
    # substencils `1..eno_order`.
    return tuple(
        sum(w * a for (w, a) in zip(unnormalized_weights, substencil_approximations[i:eno_order + i])) /
        sum(unnormalized_weights)
        for (i, unnormalized_weights) in enumerate(left_and_right_unnormalized_weights))


def essentially_non_oscillatory(order: int, values: Array, spacing: float,
                                boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:
    """Implements an upwind essentially non-oscillatory (ENO) scheme for first derivative approximation.

    Args:
        order: The desired order of accuracy for the ENO scheme.
        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.
        spacing: Grid spacing of the `values`.
        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).

    Returns:
        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,
        respectively, left and right approximations of the first derivative at the grid points of `values`.

    Raises:
        ValueError: If `order < 1`.
    """
    if order < 1:
        raise ValueError(f"`order` must be at least 1; got {order}.")
    values = boundary_condition(values, order)
    diffs = (values[1:] - values[:-1]) / spacing
    if order == 1:
        return (diffs[:-1], diffs[1:])
    # Candidate approximations, one per substencil (same construction as in the WENO scheme).
    substencil_approximations = tuple(
        _unrolled_correlate(diffs[i:len(diffs) - order + i], c) for (i, c) in enumerate(_diff_coefficients(order)))
    # Higher-order differences of orders 2..order-1, each trimmed so all of them line up with one another.
    undivided_differences = []
    for i in range(2, order):
        diffs = diffs[1:] - diffs[:-1]
        undivided_differences.append(diffs[order - i:i - order])
    # Start from the highest-order (`order`-th) differences: pick the neighbor with the smaller magnitude.
    abs_diffs = jnp.abs(diffs[1:] - diffs[:-1])
    stencil_indices = abs_diffs[1:] < abs_diffs[:-1]
    # Walk back down the orders. Where the right-hand lower-order difference is smaller, shift the choice one
    # substencil to the right. The result is, per point, the offset of the selected substencil within
    # `substencil_approximations`. (This loop rebinds `diffs`; the first differences are not needed afterwards.)
    for diffs in reversed(undivided_differences):
        abs_diffs = jnp.abs(diffs)
        stencil_indices = jnp.where(abs_diffs[1:] < abs_diffs[:-1], stencil_indices[1:] + 1, stencil_indices[:-1])
    # Gather the selected candidate. Offsets `0..order-2` are matched explicitly; the `jnp.select` default covers
    # offset `order - 1`. The right derivative uses the same selection shifted by one point and one substencil.
    return (jnp.select([stencil_indices[:-1] == i for i in range(order - 1)], substencil_approximations[:-2],
                       substencil_approximations[-2]),
            jnp.select([stencil_indices[1:] == i for i in range(order - 1)], substencil_approximations[1:-1],
                       substencil_approximations[-1]))


# Ready-made schemes. WENO(2k - 1) is built on ENO order k, per the `eno_order` docstring above.
first_order = WENO1 = functools.partial(weighted_essentially_non_oscillatory, 1)
WENO3 = functools.partial(weighted_essentially_non_oscillatory, 2)
WENO5 = functools.partial(weighted_essentially_non_oscillatory, 3)
ENO1 = functools.partial(essentially_non_oscillatory, 1)
ENO2 = functools.partial(essentially_non_oscillatory, 2)
ENO3 = functools.partial(essentially_non_oscillatory, 3)


def _weighted_essentially_non_oscillatory_vectorized(
        eno_order: int, values: Array, spacing: float,
        boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:
    """Implements a more "vectorized" but ultimately slower version of `weighted_essentially_non_oscillatory`.

    Same arguments and return value as `weighted_essentially_non_oscillatory`. The per-substencil Python loops are
    replaced by `jax.vmap`-ed `jnp.correlate` calls, and the outputs are re-aligned with `_align_substencil_values`.
    """
    if eno_order < 1:
        raise ValueError(f"`eno_order` must be at least 1; got {eno_order}.")
    values = boundary_condition(values, eno_order)
    diffs = (values[1:] - values[:-1]) / spacing
    if eno_order == 1:
        return (diffs[:-1], diffs[1:])
    # Correlate `diffs` with every substencil's coefficients at once, then shift row i by i to align them.
    substencil_approximations = _align_substencil_values(
        jax.vmap(jnp.correlate, (None, 0), 0)(diffs, _diff_coefficients(eno_order)), jnp)
    diffs2 = diffs[1:] - diffs[:-1]
    # Transposed Cholesky factors, so each smoothness indicator is a sum of squared correlations (see the loop version).
    chol_T = jnp.asarray(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)).swapaxes(-1, -2))
    smoothness_indicators = _align_substencil_values(
        jnp.sum(jnp.square(jax.vmap(jax.vmap(jnp.correlate, (None, 0), 1), (None, 0), 0)(diffs2, chol_T)), -1), jnp)
    unscaled_weights = 1 / jnp.square(smoothness_indicators + WENO_EPS)
    # Leading axis of size 2: left weights (unshifted indicators) and right weights (indicators shifted by one).
    unnormalized_weights = (jnp.asarray(_substencil_coefficients(eno_order)[..., np.newaxis]) *
                            jnp.stack([unscaled_weights[:, :-1], unscaled_weights[:, 1:]]))
    weights = unnormalized_weights / jnp.sum(unnormalized_weights, 1, keepdims=True)
    return tuple(jnp.sum(jnp.stack([substencil_approximations[:-1], substencil_approximations[1:]]) * weights, 1))


def _unrolled_correlate(a: Array, v: Array) -> Array:
    """An unrolled equivalent of `np.correlate` (mode "valid"): a sum of `len(v)` shifted, scaled slices of `a`."""
    return sum(a[i:len(a) - len(v) + i + 1] * x for (i, x) in enumerate(v))


def _substencils(k: int) -> Array:
    """Returns the `k + 1` subranges of length `k + 1` from the full stencil range `[-k, k + 1)`.

    Example: `_substencils(1) -> [[-1, 0], [0, 1]]`.
    """
    return np.arange(k + 1) + np.arange(k + 1)[:, np.newaxis] - k


# In the two helpers below, the `np` parameter shadows the NumPy module. This lets callers pass `jnp` instead (as
# `_weighted_essentially_non_oscillatory_vectorized` does).
def _spread_substencil_values(x: Array, np: ModuleType = np) -> Array:
    """Offsets each successive row of a matrix `x` by one additional column.

    Example: `[[a, b], [c, d]] -> [[a, b, 0], [0, c, d]]`.
    """
    return np.reshape(np.reshape(np.pad(x, ((0, 0), (0, x.shape[0]))), -1)[:-x.shape[0]], (x.shape[0], -1))


def _align_substencil_values(x: Array, np: ModuleType = np) -> Array:
    """Slices and stacks windows, each offset by one column from the previous, from rows of a matrix `x`.

    Row `i` of the result is `x[i, i:i + m + 1 - r]` for `x` of shape `(r, m)`; e.g.
    `[[a, b, c], [d, e, f]] -> [[a, b], [e, f]]`.
    """
    return np.reshape(np.pad(np.reshape(x, -1), (0, x.shape[0])), (x.shape[0], -1))[:, :-x.shape[0]]


def _diff_coefficients(k: Optional[int] = None, stencil: Optional[Array] = None) -> Array:
    """Returns first derivative approximation finite difference coefficients for function value first differences.

    For each stencil (its last axis holds the points, in units of grid spacing), this returns weights `w` of length
    `k` such that `sum_j w_j * (p(x_{j+1}) - p(x_j)) == p'(0)` for every polynomial `p` of degree `<= k`.

    Args:
        k: Polynomial degree, i.e. `stencil.shape[-1] - 1`. Inferred from `stencil` if omitted.
        stencil: Stencil point offsets, possibly batched over leading axes. Defaults to `_substencils(k)`.

    Raises:
        ValueError: If neither argument is given, or if both are given and are inconsistent.
    """
    if k is None:
        if stencil is None:
            raise ValueError("One of `k` or `stencil` must be provided.")
        k = stencil.shape[-1] - 1
    else:
        if stencil is None:
            stencil = _substencils(k)
        elif k != stencil.shape[-1] - 1:
            raise ValueError("`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got "
                             f"{(k, stencil.shape[-1] - 1)}.")
    # Rows of `np.diff(polyvander(...), axis=-2)` are first differences of the monomials between consecutive stencil
    # points. The constant monomial differences to zero and is dropped. Solving against the first unit vector
    # matches `d/dx x^m |_{x=0}` (1 for m = 1, 0 otherwise) for monomials `x^1..x^k`.
    return np.linalg.solve(
        np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),
        np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0, ..., np.newaxis)])[..., 0]


def _substencil_coefficients(k: int) -> Array:
    """Returns coefficients for combining substencil approximations to yield higher order left/right approximations.

    Returns:
        An array of shape `(2, k)`. Row 0 holds the left coefficients; row 1 holds the right coefficients, which are
        the left ones reversed.
    """
    # Spread each of the first `k` substencils' coefficients onto common columns. Then solve for the combination
    # weights by matching the first `k` columns against the wide-stencil (`np.arange(-k, k)`) coefficients.
    left_coefficients = np.linalg.solve(
        _spread_substencil_values(_diff_coefficients(k))[:-1, :k].T,
        _diff_coefficients(stencil=np.arange(-k, k))[:k])
    return np.array([left_coefficients, left_coefficients[::-1]])


def _polyder_operator(k: int, d: int) -> Array:
    """Returns a matrix `D` such that `D @ p == poly.polyder(p, d)` for polynomials `p` of degree `k`.

    Example: `_polyder_operator(2, 1) -> [[0, 1, 0], [0, 0, 2]]`.
    """
    return np.concatenate([np.zeros((k + 1 - d, d)), np.diag(poly.polyder(np.ones(k + 1), d))], 1)


def _smoothness_indicator_quad_form(k: int) -> Array:
    """Returns quadratic forms for computing substencil smoothness indicators as functions of second differences.

    Returns:
        An array of shape `(k, k - 1, k - 1)`: one positive (semi)definite quadratic form per substencil, applied to
        that substencil's `k - 1` second differences. Presumably this is the Jiang-Shu-style indicator for the
        HJ-WENO setting (TODO confirm against a reference).
    """
    # Maps each substencil's `k - 1` second differences to the coefficients of the interpolating polynomial's second
    # derivative. The inverse recovers coefficients of x^2..x^k, and `polyder(ones, 2)` supplies the m(m-1) factors.
    interp_poly_second_der = (poly.polyder(np.ones(k + 1), 2)[:, np.newaxis] *
                              np.linalg.inv(np.diff(poly.polyvander(_substencils(k)[1:], k), 2, axis=-2)[..., 2:]))
    quad_form = np.zeros((k, k - 1, k - 1))
    for m in range(k - 1):
        # Hilbert matrix: entry (i, j) is 1 / (i + j + 1) = integral_0^1 x^i x^j dx.
        integrator_matrix = 1 / (np.arange(k - 1 - m) + np.arange(k - 1 - m)[:, np.newaxis] + 1)
        # Coefficients of the (m + 2)-th derivative of the interpolating polynomial.
        interp_poly_m_plus_2_der = _polyder_operator(k - 2, m) @ interp_poly_second_der
        # Accumulate integral_0^1 (p^(m + 2)(x))^2 dx as a quadratic form in the second differences.
        quad_form += interp_poly_m_plus_2_der.swapaxes(-1, -2) @ integrator_matrix @ interp_poly_m_plus_2_der
    return quad_form
================================================
FILE: hj_reachability/finite_differences/upwind_first_test.py
================================================
import math from absl.testing import absltest import jax import numpy as np from hj_reachability import boundary_conditions from
hj_reachability.finite_differences import upwind_first class UpwindFirstTest(absltest.TestCase): def setUp(self): np.random.seed(0) def test_weighted_essentially_non_oscillatory(self): def _WENO5(values, spacing, boundary_condition): values = boundary_condition(values, 3) diffs = (values[1:] - values[:-1]) / spacing def compute_weno(v): phi = [ v[0] / 3 - 7 * v[1] / 6 + 11 * v[2] / 6, -v[1] / 6 + 5 * v[2] / 6 + v[3] / 3, v[2] / 3 + 5 * v[3] / 6 - v[4] / 6, ] s = [(13 / 12) * (v[0] - 2 * v[1] + v[2])**2 + (1 / 4) * (v[0] - 4 * v[1] + 3 * v[2])**2, (13 / 12) * (v[1] - 2 * v[2] + v[3])**2 + (1 / 4) * (v[1] - v[3])**2, (13 / 12) * (v[2] - 2 * v[3] + v[4])**2 + (1 / 4) * (3 * v[2] - 4 * v[3] + v[4])**2] a = [w / (x + upwind_first.WENO_EPS)**2 for (w, x) in zip([0.1, 0.6, 0.3], s)] w = [x / sum(a) for x in a] return sum(p * w for (p, w) in zip(phi, w)) return (compute_weno([diffs[i:-5 + i] for i in range(5)]), compute_weno([diffs[5 - i:None if i == 0 else -i] for i in range(5)])) values = np.random.rand(1000) spacing = 0.1 jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.WENO5(values, spacing, boundary_conditions.periodic), _WENO5(values, spacing, boundary_conditions.periodic)) def test_essentially_non_oscillatory(self): def _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_condition): def _divided_difference(x, i, spacing=1): if isinstance(i, int): return x[i] order = len(i) - 1 return np.diff(x[i], order)[0] / (math.factorial(order) * spacing**order) v = np.array(boundary_condition(values, order)) x = np.arange(len(v)) * spacing p = [np.poly1d(v[i]) for i in range(order - 1, len(v) - order)] ks = [] for i in range(len(p)): j = i + order - 1 p[i] += _divided_difference(v, [j, j + 1], spacing) * np.poly1d([x[j]], True) k = j for d in range(2, order + 1): a = _divided_difference(v, np.arange(k, k + d + 1), spacing) b = _divided_difference(v, np.arange(k - 1, k + d), spacing) if np.abs(a) >= np.abs(b): c = b k_next 
= k - 1 else: c = a k_next = k p[i] += c * np.poly1d(x[k:k + d], True) k = k_next ks.append(k - j) p_x = [np.polyder(f) for f in p] return (np.array([np.polyval(f, x) for (f, x) in zip(p_x[:-1], x[order:-order])]), np.array([np.polyval(f, x) for (f, x) in zip(p_x[1:], x[order:-order])])) values = np.random.rand(1000) spacing = 0.1 for order in range(1, 5): jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic), _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic)) def test_weighted_essentially_non_oscillatory_vectorized(self): values = np.random.rand(1000) spacing = 0.1 for eno_order in range(1, 5): jax.tree.map( lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5), upwind_first.weighted_essentially_non_oscillatory(eno_order, values, spacing, boundary_conditions.periodic), upwind_first._weighted_essentially_non_oscillatory_vectorized(eno_order, values, spacing, boundary_conditions.periodic)) def test_diff_coefficients(self): # k = 1 np.testing.assert_allclose(upwind_first._diff_coefficients(1), np.ones((2, 1))) # k = 2 np.testing.assert_allclose(upwind_first._diff_coefficients(2), np.array([[-1, 3], [1, 1], [3, -1]]) / 2) # k = 3 np.testing.assert_allclose(upwind_first._diff_coefficients(3), np.array([[2, -7, 11], [-1, 5, 2], [2, 5, -1], [11, -7, 2]]) / 6) def test_substencil_coefficients(self): # k = 1 np.testing.assert_allclose(upwind_first._substencil_coefficients(1), np.ones((2, 1))) # k = 2 np.testing.assert_allclose(upwind_first._substencil_coefficients(2), np.array([[1, 2], [2, 1]]) / 3) # k = 3 np.testing.assert_allclose(upwind_first._substencil_coefficients(3), np.array([[1, 6, 3], [3, 6, 1]]) / 10) def test_smoothness_indicator_quad_form(self): diff_operator = lambda k: np.eye(k - 1, k, 1) - np.eye(k - 1, k, 0) square_outer = lambda v: v[..., np.newaxis] * v[..., np.newaxis, :] # k = 1 
np.testing.assert_allclose( diff_operator(1).T @ upwind_first._smoothness_indicator_quad_form(1) @ diff_operator(1), [[[0]]]) # k = 2 np.testing.assert_allclose( diff_operator(2).T @ upwind_first._smoothness_indicator_quad_form(2) @ diff_operator(2), square_outer(np.array([[1, -1], [1, -1]]))) # k = 3 np.testing.assert_allclose( diff_operator(3).T @ upwind_first._smoothness_indicator_quad_form(3) @ diff_operator(3), (13 / 12) * square_outer(np.array([[1, -2, 1], [1, -2, 1], [1, -2, 1]])) + (1 / 4) * square_outer(np.array([[1, -4, 3], [1, 0, -1], [3, -4, 1]]))) if __name__ == "__main__": absltest.main() ================================================ FILE: hj_reachability/grid.py ================================================ import functools from flax import struct import jax.numpy as jnp import numpy as np from hj_reachability import boundary_conditions as _boundary_conditions from hj_reachability.finite_differences import upwind_first from hj_reachability import sets from hj_reachability import utils from typing import Any, Callable, Optional, Tuple, Union from hj_reachability.boundary_conditions import BoundaryCondition Array = Any @struct.dataclass class Grid: """Class for representing Cartesian state grids with uniform spacing in each dimension. Attributes: states: An `(N + 1)` dimensional array containing the state values at each grid location. The first `N` dimensions correspond to the location in the grid, while the last dimension (itself of size `N`) contains the state vector. domain: A `Box` representing the domain of grid. coordinate_vectors: A tuple of `N` arrays containing the discrete state values in each dimension. The `states` attribute is produced by `stack`ing a `meshgrid` of these coordinate vectors. spacings: A tuple of `N` scalars containing the grid spacing (the difference between successive elements of the corresponding coordinate vector) in each dimension. boundary_conditions: A tuple of `N` boundary conditions for each dimension. 
These boundary conditions are functions used to pad values (notably not stored in this `Grid` data structure) to implement a boundary condition (e.g., periodic). """ states: Array domain: sets.Box coordinate_vectors: Tuple[Array, ...] spacings: Tuple[Array, ...] boundary_conditions: Tuple[BoundaryCondition, ...] = struct.field(pytree_node=False) @classmethod def from_lattice_parameters_and_boundary_conditions( cls, domain: sets.Box, shape: Tuple[int, ...], boundary_conditions: Optional[Tuple[BoundaryCondition, ...]] = None, periodic_dims: Optional[Union[int, Tuple[int, ...]]] = None) -> "Grid": """Constructs a `Grid` from a domain, shape, and boundary conditions. Args: domain: A `Box` representing the domain of grid. shape: A tuple of `N` integers denoting the number of discretization nodes in each dimension. boundary_conditions: A tuple of `N` boundary conditions for each dimension. If not provided, defaults to `extrapolate_away_from_zero` in each dimension, with the exception of those dimensions that appear in `periodic_dims` where the `periodic` boundary condition is used instead. periodic_dims: A single integer or tuple of integers denoting which dimensions are periodic in the case that the `boundary_conditions` are not explicitly provided as input to this factory method. Returns: A `Grid` constructed according to the provided specifications. 
""" ndim = len(shape) if boundary_conditions is None: if not isinstance(periodic_dims, tuple): periodic_dims = (periodic_dims,) boundary_conditions = tuple( _boundary_conditions.periodic if i in periodic_dims else _boundary_conditions.extrapolate_away_from_zero for i in range(ndim)) coordinate_vectors, spacings = zip( *(jnp.linspace(l, h, n, endpoint=bc is not _boundary_conditions.periodic, retstep=True) for l, h, n, bc in zip(domain.lo, domain.hi, shape, boundary_conditions))) states = jnp.stack(jnp.meshgrid(*coordinate_vectors, indexing="ij"), -1) return cls(states, domain, coordinate_vectors, spacings, boundary_conditions) @property def ndim(self) -> int: """Returns the dimension `N` of the grid.""" return self.states.ndim - 1 @property def shape(self) -> Tuple[int, ...]: """Returns the shape of the grid, a tuple of `N` integers.""" return self.states.shape[:-1] def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]: """Returns `(left_grad_values, right_grad_values)`.""" left_derivatives, right_derivatives = zip(*[ utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition), np.array([j for j in range(self.ndim) if j != i]))(values) for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions)) ]) return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1)) def grad_values(self, values: Array, upwind_scheme: Optional[Callable] = None) -> Array: """Returns a central difference-based approximation of `grad_values`.""" # TODO: Implement central difference schemes in `hj_reachability.finite_differences`. 
if upwind_scheme is None: upwind_scheme = upwind_first.first_order return sum(self.upwind_grad_values(upwind_scheme, values)) / 2 def position(self, state: Array) -> Array: """Returns an array of `float`s corresponding to the position of `state` in the grid.""" position = (state - self.domain.lo) / jnp.array(self.spacings) return jnp.where(self._is_periodic_dim, position % np.array(self.shape), position) def nearest_index(self, state: Array) -> Array: """Returns the result of rounding `self.position(state)` to the nearest grid index.""" return jnp.round(self.position(state)).astype(jnp.int32) def interpolate(self, values, state): """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`.""" position = (state - self.domain.lo) / jnp.array(self.spacings) index_lo = jnp.floor(position).astype(jnp.int32) index_hi = index_lo + 1 weight_hi = position - index_lo weight_lo = 1 - weight_hi index_lo, index_hi = tuple( jnp.where(self._is_periodic_dim, index % np.array(self.shape), jnp.clip(index, 0, np.array(self.shape) - 1)) for index in (index_lo, index_hi)) weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1))) # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity. 
result = jnp.sum( weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] * values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim))) return jnp.where(jnp.any(~self._is_periodic_dim & ((state < self.domain.lo) | (state > self.domain.hi))), jnp.nan, result) @property def _is_periodic_dim(self) -> Array: """Returns a boolean vector indicating which dimensions (if any) are periodic.""" return np.array([bc is _boundary_conditions.periodic for bc in self.boundary_conditions]) ================================================ FILE: hj_reachability/grid_test.py ================================================ from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np from hj_reachability import grid as _grid from hj_reachability import sets class BoundaryConditionsTest(absltest.TestCase): def setUp(self): np.random.seed(0) def test_grid_interpolate(self): grid_domain = sets.Box(np.zeros(2), np.ones(2)) grid_shape = (3, 2) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=1) values = np.random.random((3, 2)) np.testing.assert_allclose(grid.interpolate(values, np.array([0.25, 2.75])), np.mean(values[0:2, 0:2])) np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0]) np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0]) values = np.random.random((3, 2, 3, 4)) np.testing.assert_allclose(grid.interpolate(values, np.array([0.75, 2.75])), np.mean(values[1:3, 0:2], (0, 1))) np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0]) np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0]) def test_grid_interpolate_on_grid(self): grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2)) grid_shape = (3, 4) for value_shape in ((), (5,)): values = jnp.array(np.random.random(grid_shape + value_shape)) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape) 
np.testing.assert_allclose(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x)))(grid.states), values, atol=1e-6) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=0) states = grid.states + (grid._is_periodic_dim * np.arange(-3, 4)[:, None, None, None] * (grid.domain.hi - grid.domain.lo)) np.testing.assert_allclose(jax.vmap(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x))))(states), np.broadcast_to(values, states.shape[:1] + values.shape), atol=1e-6) def test_grid_interpolate_off_grid(self): grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2)) grid_shape = (3, 4) for value_shape in ((), (5,)): a = np.random.random((2,) + value_shape) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape) values = grid.states @ a states = grid.domain.lo + np.random.random((100, 2)) * (grid.domain.hi - grid.domain.lo) np.testing.assert_allclose(jax.vmap(lambda x: grid.interpolate(values, x))(states), states @ a, atol=1e-6) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=0) values = jnp.array(np.random.random(grid_shape + value_shape)) grid_unwrapped = _grid.Grid.from_lattice_parameters_and_boundary_conditions( grid.domain, tuple(d + 1 if p else d for d, p in zip(grid.shape, grid._is_periodic_dim))) values_unwrapped = jnp.concatenate([values, values[:1]]) states = states + (grid._is_periodic_dim * np.arange(-3, 4)[:, None, None] * (grid.domain.hi - grid.domain.lo)) np.testing.assert_allclose(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x)))(states), jax.vmap(jax.vmap(lambda x: grid_unwrapped.interpolate(values_unwrapped, x))) ((states - grid.domain.lo) % (grid.domain.hi - grid.domain.lo) + grid.domain.lo), atol=1e-6) def test_grid_interpolate_extrapolate_nan(self): grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2)) grid_shape = (3, 4) for value_shape in ((), (5,)): values = jnp.array(np.random.random(grid_shape + 
value_shape)) grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape) states = grid.domain.lo + (grid.domain.hi - grid.domain.lo) * np.array( [[0.5 + dx, 0.5 + dy] for dx in [-1, 0, 1] for dy in [-1, 0, 1] if dx or dy]) result = jax.vmap(lambda x: grid.interpolate(values, x))(states) self.assertEqual(result.shape, (8,) + value_shape) self.assertTrue(np.all(np.isnan(result))) if __name__ == "__main__": absltest.main() ================================================ FILE: hj_reachability/sets.py ================================================ import abc from flax import struct import jax.numpy as jnp from hj_reachability import utils from typing import Any Array = Any @struct.dataclass class BoundedSet(metaclass=abc.ABCMeta): """Abstract base class for representing bounded subsets of Euclidean space.""" @abc.abstractmethod def extreme_point(self, direction: Array) -> Array: """Computes the point `x` in the set such that the dot product `x @ direction` is greatest.""" @property @abc.abstractmethod def bounding_box(self) -> "Box": """Returns an axis-aligned bounding box for the set.""" @property def max_magnitudes(self) -> Array: """Returns the maximum magnitude (per dimension) of points in the set.""" return jnp.maximum(jnp.abs(self.bounding_box.lo), jnp.abs(self.bounding_box.hi)) @property def ndim(self) -> int: """Returns the dimension of the Euclidean space the set lies within.""" return self.bounding_box.ndim @struct.dataclass class Box(BoundedSet): """Class for representing axis-aligned boxes.""" lo: Array hi: Array def extreme_point(self, direction: Array) -> Array: """Computes the point `x` in the set such that the dot product `x @ direction` is greatest.""" return jnp.where(direction < 0, self.lo, self.hi) @property def bounding_box(self) -> "Box": """Returns an axis-aligned bounding box for the set.""" return self @property def ndim(self) -> int: """Returns the dimension of the Euclidean space the set lies within.""" 
return self.lo.shape[-1] @struct.dataclass class Ball(BoundedSet): """Class for representing Euclidean (L2) balls.""" center: Array radius: Array def extreme_point(self, direction: Array) -> Array: """Computes the point `x` in the set such that the dot product `x @ direction` is greatest.""" return self.center + self.radius * utils.unit_vector(direction) @property def bounding_box(self) -> "Box": """Returns an axis-aligned bounding box for the set.""" return Box(self.center - jnp.expand_dims(self.radius, -1), self.center + jnp.expand_dims(self.radius, -1)) ================================================ FILE: hj_reachability/sets_test.py ================================================ from absl.testing import absltest import jax import numpy as np from hj_reachability import sets class SetsTest(absltest.TestCase): def setUp(self): np.random.seed(0) def test_box(self): box = sets.Box(np.ones(3), 2 * np.ones(3)) np.testing.assert_allclose(box.extreme_point(np.array([1, -1, 1])), np.array([2, 1, 2])) self.assertTrue(np.all(np.isfinite(box.extreme_point(np.zeros(3))))) self.assertEqual(box.bounding_box, box) np.testing.assert_allclose(box.max_magnitudes, 2 * np.ones(3)) self.assertEqual(box.ndim, 3) def test_ball(self): ball = sets.Ball(np.ones(3), np.sqrt(3)) np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6) self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3))))) jax.tree.map(np.testing.assert_allclose, ball.bounding_box, sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3))) np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3)) self.assertEqual(ball.ndim, 3) if __name__ == "__main__": absltest.main() ================================================ FILE: hj_reachability/solver.py ================================================ import contextlib import functools from flax import struct import jax import jax.numpy as jnp import numpy as np from hj_reachability 
import artificial_dissipation from hj_reachability import time_integration from hj_reachability.finite_differences import upwind_first from typing import Callable, Text # Hamiltonian postprocessors. identity = lambda *x: x[-1] # Returns the last argument so that this may also be used as a value postprocessor. backwards_reachable_tube = lambda x: jnp.minimum(x, 0) # Value postprocessors. static_obstacle = lambda obstacle: (lambda t, v: jnp.maximum(v, obstacle)) @struct.dataclass class SolverSettings: upwind_scheme: Callable = struct.field( default=upwind_first.WENO5, pytree_node=False, ) artificial_dissipation_scheme: Callable = struct.field( default=artificial_dissipation.global_lax_friedrichs, pytree_node=False, ) hamiltonian_postprocessor: Callable = struct.field( default=identity, pytree_node=False, ) time_integrator: Callable = struct.field( default=time_integration.third_order_total_variation_diminishing_runge_kutta, pytree_node=False, ) value_postprocessor: Callable = struct.field( default=identity, pytree_node=False, ) CFL_number: float = 0.75 @classmethod def with_accuracy(cls, accuracy: Text, **kwargs) -> "SolverSettings": if accuracy == "low": upwind_scheme = upwind_first.first_order time_integrator = time_integration.first_order_total_variation_diminishing_runge_kutta elif accuracy == "medium": upwind_scheme = upwind_first.ENO2 time_integrator = time_integration.second_order_total_variation_diminishing_runge_kutta elif accuracy == "high": upwind_scheme = upwind_first.WENO3 time_integrator = time_integration.third_order_total_variation_diminishing_runge_kutta elif accuracy == "very_high": upwind_scheme = upwind_first.WENO5 time_integrator = time_integration.third_order_total_variation_diminishing_runge_kutta return cls(upwind_scheme=upwind_scheme, time_integrator=time_integrator, **kwargs) @functools.partial(jax.jit, static_argnames=("dynamics", "progress_bar")) def step(solver_settings, dynamics, grid, time, values, target_time, progress_bar=True): with 
(_try_get_progress_bar(time, target_time) if progress_bar is True else contextlib.nullcontext(progress_bar)) as bar: def sub_step(time_values): t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time) if bar is not False: bar.update_to(jnp.abs(t - bar.reference_time)) return t, v return jax.lax.while_loop(lambda time_values: jnp.abs(target_time - time_values[0]) > 0, sub_step, (time, values))[1] @functools.partial(jax.jit, static_argnames=("dynamics", "progress_bar")) def solve(solver_settings, dynamics, grid, times, initial_values, progress_bar=True): with (_try_get_progress_bar(times[0], times[-1]) if progress_bar is True else contextlib.nullcontext(progress_bar)) as bar: make_carry_and_output_slice = lambda t, v: ((t, v), v) return jnp.concatenate([ initial_values[np.newaxis], jax.lax.scan( lambda time_values, target_time: make_carry_and_output_slice( target_time, step(solver_settings, dynamics, grid, *time_values, target_time, bar)), (times[0], initial_values), times[1:])[1] ]) def _try_get_progress_bar(reference_time, target_time): try: import tqdm except ImportError: raise ImportError("The option `progress_bar=True` requires the 'tqdm' package to be installed.") return TqdmWrapper(tqdm, reference_time, total=jnp.abs(target_time - reference_time), unit="sim_s", bar_format="{l_bar}{bar}| {n:7.4f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]", ascii=True) class TqdmWrapper: def __init__(self, tqdm, reference_time, total, *args, **kwargs): self.reference_time = reference_time jax.experimental.io_callback( lambda total: self._create_tqdm(tqdm, float(total), *args, **kwargs), None, total, ordered=True, ) def _create_tqdm(self, tqdm, total, *args, **kwargs): self._tqdm = tqdm.tqdm(total=total, *args, **kwargs) def update_to(self, n): jax.experimental.io_callback( lambda n: self._tqdm.update(float(n) - self._tqdm.n) and None, None, n, ordered=True, ) def close(self): jax.experimental.io_callback( lambda: 
self._tqdm.close(), None, ordered=True, ) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() ================================================ FILE: hj_reachability/solver_test.py ================================================ from absl.testing import absltest import numpy as np import hj_reachability as hj class SolverTest(absltest.TestCase): def setUp(self): np.random.seed(0) solver_settings = hj.SolverSettings.with_accuracy("low") dynamics = hj.systems.Air3d() grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(np.array([-6., -10., 0.]), np.array([20., 10., 2 * np.pi])), (11, 10, 10), periodic_dims=2) self.problem_definition = { "solver_settings": solver_settings, "dynamics": dynamics, "grid": grid, } def test_step(self): values = np.linalg.norm(self.problem_definition["grid"].states[..., :2], axis=-1) - 5 target_values = hj.step(**self.problem_definition, time=0., values=values, target_time=-0.1, progress_bar=False) self.assertEqual(target_values.shape, values.shape) np.testing.assert_allclose( target_values, hj.step(**self.problem_definition, time=0., values=values, target_time=-0.1, progress_bar=True)) def test_solve(self): times = np.linspace(0, -0.1, 3) initial_values = np.linalg.norm(self.problem_definition["grid"].states[..., :2], axis=-1) - 5 all_values = hj.solve(**self.problem_definition, times=times, initial_values=initial_values, progress_bar=False) self.assertEqual(all_values.shape, (len(times),) + initial_values.shape) np.testing.assert_allclose(all_values[0], initial_values) np.testing.assert_allclose(all_values[-1], hj.step(**self.problem_definition, time=0., values=initial_values, target_time=-0.1, progress_bar=False), atol=1e-2) np.testing.assert_allclose( all_values, hj.solve(**self.problem_definition, times=times, initial_values=initial_values, progress_bar=True)) ================================================ FILE: hj_reachability/systems/__init__.py 
================================================ from hj_reachability.systems.air3d import Air3d, DubinsCarCAvoid __all__ = ("Air3d", "DubinsCarCAvoid") ================================================ FILE: hj_reachability/systems/air3d.py ================================================ import jax.numpy as jnp from hj_reachability import dynamics from hj_reachability import sets class Air3d(dynamics.ControlAndDisturbanceAffineDynamics): def __init__(self, evader_speed=5., pursuer_speed=5., evader_max_turn_rate=1., pursuer_max_turn_rate=1., control_mode="max", disturbance_mode="min", control_space=None, disturbance_space=None): self.evader_speed = evader_speed self.pursuer_speed = pursuer_speed if control_space is None: control_space = sets.Box(jnp.array([-evader_max_turn_rate]), jnp.array([evader_max_turn_rate])) if disturbance_space is None: disturbance_space = sets.Box(jnp.array([-pursuer_max_turn_rate]), jnp.array([pursuer_max_turn_rate])) super().__init__(control_mode, disturbance_mode, control_space, disturbance_space) def open_loop_dynamics(self, state, time): _, _, psi = state v_a, v_b = self.evader_speed, self.pursuer_speed return jnp.array([-v_a + v_b * jnp.cos(psi), v_b * jnp.sin(psi), 0.]) def control_jacobian(self, state, time): x, y, _ = state return jnp.array([ [y], [-x], [-1.], ]) def disturbance_jacobian(self, state, time): return jnp.array([ [0.], [0.], [1.], ]) DubinsCarCAvoid = Air3d ================================================ FILE: hj_reachability/time_integration.py ================================================ import functools import jax import jax.numpy as jnp import numpy as np from hj_reachability import utils def lax_friedrichs_numerical_hamiltonian(hamiltonian, state, time, value, left_grad_value, right_grad_value, dissipation_coefficients): hamiltonian_value = hamiltonian(state, time, value, (left_grad_value + right_grad_value) / 2) dissipation_value = dissipation_coefficients @ (right_grad_value - left_grad_value) / 2 return 
hamiltonian_value - dissipation_value @functools.partial(jax.jit, static_argnames="dynamics") def euler_step(solver_settings, dynamics, grid, time, values, time_step=None, max_time_step=None): time_direction = jnp.sign(max_time_step) if time_step is None else jnp.sign(time_step) signed_hamiltonian = lambda *args, **kwargs: time_direction * dynamics.hamiltonian(*args, **kwargs) left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values) dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes, grid.states, time, values, left_grad_values, right_grad_values) dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * utils.multivmap( lambda state, value, left_grad_value, right_grad_value, dissipation_coefficients: (lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value, left_grad_value, right_grad_value, dissipation_coefficients)), np.arange(grid.ndim))(grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients)) if time_step is None: time_step_bound = 1 / jnp.max(jnp.sum(dissipation_coefficients / jnp.array(grid.spacings), -1)) time_step = time_direction * jnp.minimum(solver_settings.CFL_number * time_step_bound, jnp.abs(max_time_step)) # TODO: Think carefully about whether `solver_settings.value_postprocessor` should be applied here instead. 
return time + time_step, values + time_step * dvalues_dt def first_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time): time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time) return time_1, solver_settings.value_postprocessor(time_1, values_1) def second_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time): time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time) time_step = time_1 - time _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step) return time_1, solver_settings.value_postprocessor(time_1, (values + values_2) / 2) def third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time): time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time) time_step = time_1 - time _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step) time_0_5, values_0_5 = time + time_step / 2, (3 / 4) * values + (1 / 4) * values_2 _, values_1_5 = euler_step(solver_settings, dynamics, grid, time_0_5, values_0_5, time_step) return time_1, solver_settings.value_postprocessor(time_1, (1 / 3) * values + (2 / 3) * values_1_5) ================================================ FILE: hj_reachability/utils.py ================================================ import functools import jax import jax.numpy as jnp import numpy as np from typing import Any, Callable, Iterable, List, Mapping, Optional, TypeVar, Union T = TypeVar("T") Tree = Union[T, Iterable["Tree[T]"], Mapping[Any, "Tree[T]"]] def multivmap(fun: Callable, in_axes: Tree[Optional[np.ndarray]], out_axes: Tree[Optional[np.ndarray]] = None) -> Callable: """Applies `jax.vmap` over multiple axes (equivalent to multiple nested `jax.vmap`s). 
Args: fun: Function to be mapped over additional axes (see `jax.vmap` for more details). in_axes: Similar to the specification of `in_axes` for `jax.vmap`, with the main difference being that instead of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding input of `fun`, the `np.ndarray` specifies a sequence of axes to `jax.vmap` over; note that these axes are not specified directly as a `list` so as not to conflict with the possible structure of `in_axes`. All non-`None` leaves of `in_axes` (there must be at least one) must have the same length. This length is the number of times `jax.vmap` will be applied to `fun`. out_axes: Similar to the specification of `out_axes` for `jax.vmap`, with the main difference being that instead of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding output of `fun`, the `np.ndarray` specifies a sequence of additional mapped axes to appear in the output. The length of non-`None` leaves of `out_axes` must be the same as the length of non-`None` leaves of `in_axes`; the order of both axes specifications corresponds to successive nested `jax.vmap` applications. If not provided, `out_axes` defaults to `in_axes`. Returns: A batched/vectorized version of `fun` with arguments that correspond to those of `fun`, but with (possibly multiple per input) extra array axes at positions indicated by `in_axes`, and a return value that corresponds to that of `fun`, but with (possibly multiple per output) extra array axes at positions indicated by `out_axes`. Raises: ValueError: if any specified axes are negative or repeated. 
""" def get_axis_sequence(axis_array: np.ndarray) -> List: axis_list = axis_array.tolist() if any(axis < 0 for axis in axis_list): raise ValueError(f"All `multivmap` axes must be nonnegative; got {axis_list}.") if len(axis_list) != len(set(axis_list)): raise ValueError(f"All `multivmap` axes must be distinct; got {axis_list}.") for i in range(len(axis_list)): for j in range(i + 1, len(axis_list)): if axis_list[i] > axis_list[j]: axis_list[i] -= 1 return axis_list multivmap_kwargs = {"in_axes": in_axes, "out_axes": in_axes if out_axes is None else out_axes} axis_sequence_structure = jax.tree.structure(next(a for a in jax.tree.leaves(in_axes) if a is not None).tolist()) vmap_kwargs = jax.tree.transpose(jax.tree.structure(multivmap_kwargs), axis_sequence_structure, jax.tree.map(get_axis_sequence, multivmap_kwargs)) return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun) def unit_vector(x): """Normalizes a vector `x`, returning a unit vector in the same direction, or a zero vector if `x` is zero.""" norm2 = jnp.sum(jnp.square(x)) iszero = norm2 < jnp.finfo(jnp.zeros(()).dtype).eps**2 return jnp.where(iszero, jnp.zeros_like(x), x / jnp.sqrt(jnp.where(iszero, 1, norm2))) ================================================ FILE: hj_reachability/utils_test.py ================================================ from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np from hj_reachability import utils class UtilsTest(absltest.TestCase): def setUp(self): np.random.seed(0) def test_multivmap(self): a = np.random.random((3, 4, 5, 6)) np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1]))(a), np.max(a, (2, 3))) np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 2]))(a), np.max(a, -1)) np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 3]), np.array([0, 1, 2]))(a), np.max(a, 2)) np.testing.assert_allclose( utils.multivmap(jnp.max, np.array([1, 0, 2]), np.array([0, 1, 2]))(a), 
np.max(a, 3).swapaxes(0, 1)) np.testing.assert_allclose( utils.multivmap(jnp.max, np.array([3, 2]), np.array([0, 1]))(a), np.max(a, (0, 1)).swapaxes(0, 1)) def test_unit_vector(self): unsafe_unit_vector = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True) for d in range(1, 4): np.testing.assert_array_equal(utils.unit_vector(np.zeros(d)), np.zeros(d)) self.assertTrue(np.all(np.isfinite(jax.jacobian(utils.unit_vector)(np.zeros(d))))) self.assertTrue(np.all(np.isnan(jax.jacobian(unsafe_unit_vector)(np.zeros(d))))) a = np.random.random((100, d)) np.testing.assert_allclose(jax.vmap(utils.unit_vector)(a), unsafe_unit_vector(a), atol=1e-6) np.testing.assert_allclose(jax.vmap(jax.jacobian(utils.unit_vector))(a), jax.vmap(jax.jacobian(unsafe_unit_vector))(a), atol=1e-6) if __name__ == "__main__": absltest.main() ================================================ FILE: requirements-test.txt ================================================ absl-py>=0.12.0 tqdm>=4.60.0 ================================================ FILE: requirements.txt ================================================ flax>=0.6.6 jax>=0.4.25 numpy>=1.22 ================================================ FILE: setup.cfg ================================================ [yapf] based_on_style = google column_limit = 120 [flake8] max-line-length = 120 ignore = # E731: do not assign a lambda expression, use a def E731 # E741: do not use variables named 'I', 'O', or 'l' E741 # W504: line break occurred after a binary operator W504 ================================================ FILE: setup.py ================================================ import os import setuptools _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) def _get_version(): with open(os.path.join(_CURRENT_DIR, "hj_reachability", "__init__.py")) as f: for line in f: if line.startswith("__version__") and "=" in line: version = line[line.find("=") + 1:].strip(" '\"\n") if version: return version raise ValueError("`__version__` not defined in 
`hj_reachability/__init__.py`") def _parse_requirements(file): with open(os.path.join(_CURRENT_DIR, file)) as f: return [line.rstrip() for line in f if not (line.isspace() or line.startswith("#"))] setuptools.setup(name="hj_reachability", version=_get_version(), description="Hamilton-Jacobi reachability analysis in JAX.", long_description=open("README.md").read(), long_description_content_type="text/markdown", author="Ed Schmerling", author_email="ednerd@gmail.com", url="https://github.com/StanfordASL/hj_reachability", license="MIT", packages=setuptools.find_packages(), install_requires=_parse_requirements("requirements.txt"), tests_require=_parse_requirements("requirements-test.txt"), python_requires="~=3.8")