[
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: ci\n\non: [push, pull_request]\n\njobs:\n  test:\n    name: \"Python ${{ matrix.python-version }} on ${{ matrix.os }}\"\n    runs-on: \"${{ matrix.os }}\"\n\n    strategy:\n      matrix:\n        python-version: [\"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n        os: [ubuntu-latest]\n\n    steps:\n    - uses: actions/checkout@v3\n    - uses: actions/setup-python@v4\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        set -xe\n        python -m pip install --upgrade pip\n        pip install flake8 pytest pytest-xdist yapf\n        pip install -r requirements.txt\n        pip install -r requirements-test.txt\n    - name: Lint with flake8\n      run: |\n        set -xe\n        flake8 . --config=setup.cfg --count --statistics\n    - name: Check formatting with yapf\n      run: |\n        set -xe\n        yapf . --style=setup.cfg --recursive --diff\n    - name: Test with pytest\n      run: |\n        set -xe\n        pytest -n \"$(grep -c ^processor /proc/cpuinfo)\" hj_reachability\n"
  },
  {
    "path": ".github/workflows/pypi-publish.yml",
    "content": "name: pypi\n\non:\n  release:\n    types: [published]\n\njobs:\n  deploy:\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python\n      uses: actions/setup-python@v5\n      with:\n        python-version: '3.x'\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install build setuptools\n    - name: Check consistency between the package version and release tag\n      run: |\n        PACKAGE_VER=\"v`python setup.py --version`\"\n        if [ $PACKAGE_VER != ${{ github.event.release.tag_name }} ]\n        then\n          echo \"Package version ($PACKAGE_VER) != release tag (${{ github.event.release.tag_name }}).\"; exit 1\n        fi\n    - name: Build package\n      run: python -m build\n    - name: Publish to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1\n      with:\n        user: __token__\n        password: ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Ed Schmerling\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include requirements.txt\ninclude requirements-test.txt\n"
  },
  {
    "path": "README.md",
    "content": "# hj_reachability: Hamilton-Jacobi reachability analysis in [JAX]\nThis package implements numerical solvers for Hamilton-Jacobi (HJ) Partial Differential Equations (PDEs) which, in the context of optimal control, may be used to represent the continuous-time formulation of dynamic programming. Specifically, the focus of this package is on reachability analysis for zero-sum differential games modeled as Hamilton-Jacobi-Isaacs (HJI) PDEs, wherein an optimal controller and an (optional) disturbance interact, and the set of reachable states at any time is represented as the zero sublevel set of a value function realized as the viscosity solution of the corresponding PDE.\n\nThis package is inspired by a number of related projects, including:\n\n- [A Toolbox of Level Set Methods (`toolboxls`, MATLAB)](https://www.cs.ubc.ca/~mitchell/ToolboxLS/)\n- [An Optimal Control Toolbox for Hamilton-Jacobi Reachability Analysis (`helperOC`, MATLAB)](https://github.com/HJReachability/helperOC)\n- [Berkeley Efficient API in C++ for Level Set methods (`beacls`, C++/CUDA C++)](https://hjreachability.github.io/beacls/)\n- [Optimizing Dynamic Programming-Based Algorithms (`optimized_dp`, Python)](https://github.com/SFU-MARS/optimized_dp)\n\n## Installation\nThis package accommodates different [JAX] versions (i.e., CPU-only vs. JAX with GPU support); if accelerator support is desired, you should first install JAX according to the relevant [installation instructions](https://github.com/google/jax#installation). A minimum JAX version requirement is listed in [`requirements.txt`](https://github.com/StanfordASL/hj_reachability/blob/main/requirements.txt), but in general this package should be compatible with the latest JAX releases (please [file an issue](https://github.com/StanfordASL/hj_reachability/issues) if you find that this is no longer the case!).\n\nIf you only want CPU computation or have already installed JAX with your preferred accelerator support, you may install this package using pip:\n```\npip install --upgrade hj-reachability\n```\n\n## TODOs\nAside from the specific TODOs scattered throughout the codebase, a few general TODOs:\n- Single-line docstrings (at a bare minimum) for everything. Test coverage, book/paper references, and proper documentation to come... eventually.\n- Look into using `jax.pmap`/`jax.lax.ppermute` for multi-device parallelism; see, e.g., [JAX demo notebooks](https://github.com/google/jax/tree/master/cloud_tpu_colabs).\n- Incorporate neural-network-based PDE solvers; see, e.g., [Bansal, S. and Tomlin, C. \"DeepReach: A Deep Learning Approach to High-Dimensional Reachability.\" (2020)](https://arxiv.org/abs/2011.02082).\n\n[JAX]: https://github.com/google/jax\n"
  },
  {
    "path": "examples/quickstart.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# hj_reachability quickstart\\n\",\n    \"\\n\",\n    \"Notebook dependencies:\\n\",\n    \"- System: python3, ffmpeg (for rendering animations)\\n\",\n    \"- Python: jupyter, jax, numpy, matplotlib, plotly, tqdm, hj_reachability\\n\",\n    \"\\n\",\n    \"Example setup for a Ubuntu system (Mac users, maybe `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):\\n\",\n    \"```\\n\",\n    \"sudo apt install ffmpeg\\n\",\n    \"/usr/bin/python3 -m pip install --upgrade pip\\n\",\n    \"pip install --upgrade jupyter jax[cpu] numpy matplotlib plotly tqdm hj-reachability\\n\",\n    \"jupyter notebook  # from the directory of this notebook\\n\",\n    \"```\\n\",\n    \"Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/hj_reachability/blob/main/examples/quickstart.ipynb) and run a cell containing this command:\\n\",\n    \"```\\n\",\n    \"!pip install --upgrade hj-reachability\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from IPython.display import HTML\\n\",\n    \"import matplotlib.animation as anim\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import plotly.graph_objects as go\\n\",\n    \"\\n\",\n    \"import hj_reachability as hj\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Example system: `Air3d`\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dynamics = hj.systems.Air3d()\\n\",\n    \"grid = 
hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(np.array([-6., -10., 0.]),\\n\",\n    \"                                                                           np.array([20., 10., 2 * np.pi])),\\n\",\n    \"                                                               (51, 40, 50),\\n\",\n    \"                                                               periodic_dims=2)\\n\",\n    \"values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5\\n\",\n    \"\\n\",\n    \"solver_settings = hj.SolverSettings.with_accuracy(\\\"very_high\\\",\\n\",\n    \"                                                  hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### `hj.step`: propagate the HJ PDE from `(time, values)` to `target_time`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"time = 0.\\n\",\n    \"target_time = -2.8\\n\",\n    \"target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"plt.jet()\\n\",\n    \"plt.figure(figsize=(13, 8))\\n\",\n    \"plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values[:, :, 30].T)\\n\",\n    \"plt.colorbar()\\n\",\n    \"plt.contour(grid.coordinate_vectors[0],\\n\",\n    \"            grid.coordinate_vectors[1],\\n\",\n    \"            target_values[:, :, 30].T,\\n\",\n    \"            levels=0,\\n\",\n    \"            colors=\\\"black\\\",\\n\",\n    \"            linewidths=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"go.Figure(data=go.Isosurface(x=grid.states[..., 0].ravel(),\\n\",\n    \"           
                  y=grid.states[..., 1].ravel(),\\n\",\n    \"                             z=grid.states[..., 2].ravel(),\\n\",\n    \"                             value=target_values.ravel(),\\n\",\n    \"                             colorscale=\\\"jet\\\",\\n\",\n    \"                             isomin=0,\\n\",\n    \"                             surface_count=1,\\n\",\n    \"                             isomax=0))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### `hj.solve`: solve for `all_values` at a range of `times` (basically just iterating `hj.step`).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"times = np.linspace(0, -2.8, 57)\\n\",\n    \"initial_values = values\\n\",\n    \"all_values = hj.solve(solver_settings, dynamics, grid, times, initial_values)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vmin, vmax = all_values.min(), all_values.max()\\n\",\n    \"levels = np.linspace(round(vmin), round(vmax), round(vmax) - round(vmin) + 1)\\n\",\n    \"fig = plt.figure(figsize=(13, 8))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def render_frame(i, colorbar=False):\\n\",\n    \"    plt.contourf(grid.coordinate_vectors[0],\\n\",\n    \"                 grid.coordinate_vectors[1],\\n\",\n    \"                 all_values[i, :, :, 30].T,\\n\",\n    \"                 vmin=vmin,\\n\",\n    \"                 vmax=vmax,\\n\",\n    \"                 levels=levels)\\n\",\n    \"    if colorbar:\\n\",\n    \"        plt.colorbar()\\n\",\n    \"    plt.contour(grid.coordinate_vectors[0],\\n\",\n    \"                grid.coordinate_vectors[1],\\n\",\n    \"                target_values[:, :, 30].T,\\n\",\n    \"                levels=0,\\n\",\n    \"                colors=\\\"black\\\",\\n\",\n    \"                
linewidths=3)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"render_frame(0, True)\\n\",\n    \"animation = HTML(anim.FuncAnimation(fig, render_frame, all_values.shape[0], interval=50).to_html5_video())\\n\",\n    \"plt.close(); animation\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Defining your own dynamics: `AccelerationCurvatureCar`\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"class AccelerationCurvatureCar(hj.ControlAndDisturbanceAffineDynamics):\\n\",\n    \"\\n\",\n    \"    def __init__(self,\\n\",\n    \"                 max_acceleration=1.,\\n\",\n    \"                 max_curvature=1.,\\n\",\n    \"                 max_position_disturbance=0.25,\\n\",\n    \"                 control_mode=\\\"min\\\",\\n\",\n    \"                 disturbance_mode=\\\"max\\\",\\n\",\n    \"                 control_space=None,\\n\",\n    \"                 disturbance_space=None):\\n\",\n    \"        if control_space is None:\\n\",\n    \"            control_space = hj.sets.Box(jnp.array([-max_acceleration, -max_curvature]),\\n\",\n    \"                                        jnp.array([max_acceleration, max_curvature]))\\n\",\n    \"        if disturbance_space is None:\\n\",\n    \"            disturbance_space = hj.sets.Ball(jnp.zeros(2), max_position_disturbance)\\n\",\n    \"        super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)\\n\",\n    \"\\n\",\n    \"    def open_loop_dynamics(self, state, time):\\n\",\n    \"        _, _, v, q = state\\n\",\n    \"        return jnp.array([v * jnp.cos(q), v * jnp.sin(q), 0., 0.])\\n\",\n    \"\\n\",\n    \"    def control_jacobian(self, state, time):\\n\",\n    \"        v = state[2]\\n\",\n    \"        return jnp.array([\\n\",\n    \"            [0., 0.],\\n\",\n    \"            [0., 0.],\\n\",\n    \"            [1., 0.],\\n\",\n    \"  
          [0., v],\\n\",\n    \"        ])\\n\",\n    \"\\n\",\n    \"    def disturbance_jacobian(self, state, time):\\n\",\n    \"        return jnp.array([\\n\",\n    \"            [1., 0.],\\n\",\n    \"            [0., 1.],\\n\",\n    \"            [0., 0.],\\n\",\n    \"            [0., 0.],\\n\",\n    \"        ])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dynamics = AccelerationCurvatureCar()\\n\",\n    \"grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([-5., -5., -1., -np.pi]),\\n\",\n    \"                                                                           hi=np.array([5., 5., 1., np.pi])),\\n\",\n    \"                                                               (40, 40, 50, 50),\\n\",\n    \"                                                               periodic_dims=3)\\n\",\n    \"values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 1\\n\",\n    \"\\n\",\n    \"solver_settings = hj.SolverSettings.with_accuracy(\\\"low\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"time = 0.\\n\",\n    \"target_time = -2.0\\n\",\n    \"target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"go.Figure(data=go.Isosurface(x=grid.states[:, :, -1, :, 0].ravel(),\\n\",\n    \"                             y=grid.states[:, :, -1, :, 1].ravel(),\\n\",\n    \"                             z=grid.states[:, :, -1, :, 3].ravel(),\\n\",\n    \"                             value=target_values[:, :, -1, :].ravel(),\\n\",\n    \"                             colorscale='jet',\\n\",\n    \"                             isomin=0,\\n\",\n    \"                     
        surface_count=1,\\n\",\n    \"                             isomax=0))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "hj_reachability/__init__.py",
    "content": "from hj_reachability import artificial_dissipation\nfrom hj_reachability import boundary_conditions\nfrom hj_reachability import finite_differences\nfrom hj_reachability import sets\nfrom hj_reachability import solver\nfrom hj_reachability import systems\nfrom hj_reachability import time_integration\nfrom hj_reachability import utils\nfrom hj_reachability.dynamics import ControlAndDisturbanceAffineDynamics, Dynamics\nfrom hj_reachability.grid import Grid\nfrom hj_reachability.solver import SolverSettings, solve, step\n\n__version__ = \"0.7.0\"\n\n__all__ = (\"ControlAndDisturbanceAffineDynamics\", \"Dynamics\", \"Grid\", \"SolverSettings\", \"artificial_dissipation\",\n           \"boundary_conditions\", \"finite_differences\", \"sets\", \"solve\", \"solver\", \"step\", \"systems\",\n           \"time_integration\", \"utils\")\n"
  },
  {
    "path": "hj_reachability/artificial_dissipation.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import sets\nfrom hj_reachability import utils\n\n\ndef global_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):\n    \"\"\"Implements the Global Lax-Friedrichs (GLF) scheme for computing dissipation coefficients.\"\"\"\n    grid_axes = np.arange(values.ndim)\n    grad_value_box = sets.Box(jnp.minimum(jnp.min(left_grad_values, grid_axes), jnp.min(right_grad_values, grid_axes)),\n                              jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))\n    return utils.multivmap(lambda state, value: partial_max_magnitudes(state, time, value, grad_value_box),\n                           grid_axes)(states, values)\n\n\ndef local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):\n    \"\"\"Implements the Local Lax-Friedrichs (LLF) scheme for computing dissipation coefficients.\"\"\"\n    grid_axes = np.arange(values.ndim)\n    global_grad_value_box = sets.Box(\n        jnp.minimum(jnp.min(left_grad_values, grid_axes), jnp.min(right_grad_values, grid_axes)),\n        jnp.maximum(jnp.max(left_grad_values, grid_axes), jnp.max(right_grad_values, grid_axes)))\n    local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),\n                                            jnp.maximum(left_grad_values, right_grad_values))\n    local_grad_value_boxes = jax.tree.map(\n        lambda global_grad_value, local_local_grad_values:\n        (jnp.broadcast_to(global_grad_value, values.shape +\n                          (values.ndim,) * 2).at[..., grid_axes, grid_axes].set(local_local_grad_values)),\n        global_grad_value_box, local_local_grad_value_boxes)\n    return utils.multivmap(\n        lambda state, value, grad_value_box: partial_max_magnitudes(state, time, value, grad_value_box),\n        grid_axes)(states, 
values, local_grad_value_boxes)\n\n\ndef local_local_lax_friedrichs(partial_max_magnitudes, states, time, values, left_grad_values, right_grad_values):\n    \"\"\"Implements the Local Local Lax-Friedrichs (LLLF) scheme for computing dissipation coefficients.\"\"\"\n    grid_axes = np.arange(values.ndim)\n    local_local_grad_value_boxes = sets.Box(jnp.minimum(left_grad_values, right_grad_values),\n                                            jnp.maximum(left_grad_values, right_grad_values))\n    return utils.multivmap(\n        lambda state, value, grad_value_box: partial_max_magnitudes(state, time, value, grad_value_box),\n        grid_axes)(states, values, local_local_grad_value_boxes)\n"
  },
  {
    "path": "hj_reachability/boundary_conditions.py",
    "content": "import jax.numpy as jnp\n\nfrom typing import Any, Callable\n\nArray = Any\n\nBoundaryCondition = Callable[[Array, int], Array]\n\n\ndef periodic(x: Array, pad_width: int) -> Array:\n    \"\"\"Pads a 1D array `x` by wrapping values, using the start values to pad the end and vice versa.\"\"\"\n    return jnp.pad(x, ((pad_width, pad_width)), \"wrap\")\n\n\ndef extrapolate(x: Array, pad_width: int) -> Array:\n    \"\"\"Pads a 1D array `x` by extrapolating using the slope at each end.\"\"\"\n    return jnp.concatenate(\n        [x[0] + (x[1] - x[0]) * jnp.arange(-pad_width, 0), x, x[-1] + (x[-1] - x[-2]) * jnp.arange(1, pad_width + 1)])\n\n\ndef extrapolate_away_from_zero(x: Array, pad_width: int) -> Array:\n    \"\"\"Pads a 1D array `x` by extrapolating away from zero using the (possibly negated) slope at each end.\"\"\"\n    return jnp.concatenate([\n        x[0] - jnp.sign(x[0]) * jnp.abs(x[1] - x[0]) * jnp.arange(-pad_width, 0), x,\n        x[-1] + jnp.sign(x[-1]) * jnp.abs(x[-1] - x[-2]) * jnp.arange(1, pad_width + 1)\n    ])\n"
  },
  {
    "path": "hj_reachability/boundary_conditions_test.py",
    "content": "from absl.testing import absltest\nimport numpy as np\n\nfrom hj_reachability import boundary_conditions\n\n\nclass BoundaryConditionsTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n\n    def test_periodic(self):\n        x = np.arange(5)\n        np.testing.assert_array_equal(boundary_conditions.periodic(x, 0), x)\n        np.testing.assert_array_equal(boundary_conditions.periodic(x, 1), [4, 0, 1, 2, 3, 4, 0])\n        np.testing.assert_array_equal(boundary_conditions.periodic(x, 2), [3, 4, 0, 1, 2, 3, 4, 0, 1])\n\n    def test_extrapolate(self):\n        x = np.arange(5)\n        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 0), x)\n        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 1), np.arange(-1, 6))\n        np.testing.assert_array_equal(boundary_conditions.extrapolate(x, 2), np.arange(-2, 7))\n\n    def test_extrapolate_away_from_zero(self):\n        x = np.arange(1, 5)\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 0), x)\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 1), [2, 1, 2, 3, 4, 5])\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 2), [3, 2, 1, 2, 3, 4, 5, 6])\n\n        x = x[::-1]\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 0), x)\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 1), [5, 4, 3, 2, 1, 2])\n        np.testing.assert_array_equal(boundary_conditions.extrapolate_away_from_zero(x, 2), [6, 5, 4, 3, 2, 1, 2, 3])\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "hj_reachability/dynamics.py",
    "content": "import abc\n\nimport jax.numpy as jnp\n\n\nclass Dynamics(metaclass=abc.ABCMeta):\n    \"\"\"Abstract base class for representing continuous-time dynamics in the context of Hamilton-Jacobi reachability.\n\n    TODO: Consider allowing for state/time-dependent control/disturbance spaces.\n\n    Attributes:\n        control_mode: Whether the controller is trying to \"max\"imize or \"min\"imize the value.\n        disturbance_mode: Whether the disturbance is trying to \"max\"imize or \"min\"imize the value.\n        control_space: A `BoundedSet` defining the (time-invariant) set of possible controls.\n        disturbance_space: A `BoundedSet` defining the (time-invariant) set of possible disturbances.\n    \"\"\"\n\n    def __init__(self, control_mode, disturbance_mode, control_space, disturbance_space):\n        self.control_mode = control_mode\n        self.disturbance_mode = disturbance_mode\n        self.control_space = control_space\n        self.disturbance_space = disturbance_space\n\n    @abc.abstractmethod\n    def __call__(self, state, control, disturbance, time):\n        \"\"\"Implements the continuous-time dynamics ODE.\"\"\"\n\n    @abc.abstractmethod\n    def optimal_control_and_disturbance(self, state, time, grad_value):\n        \"\"\"Computes the optimal control and disturbance realized by the HJ PDE Hamiltonian.\"\"\"\n\n    def optimal_control(self, state, time, grad_value):\n        \"\"\"Computes the optimal control realized by the HJ PDE Hamiltonian.\"\"\"\n        return self.optimal_control_and_disturbance(state, time, grad_value)[0]\n\n    def optimal_disturbance(self, state, time, grad_value):\n        \"\"\"Computes the optimal disturbance realized by the HJ PDE Hamiltonian.\"\"\"\n        return self.optimal_control_and_disturbance(state, time, grad_value)[1]\n\n    def hamiltonian(self, state, time, value, grad_value):\n        \"\"\"Evaluates the HJ PDE Hamiltonian.\"\"\"\n        del value  # unused\n        control, 
disturbance = self.optimal_control_and_disturbance(state, time, grad_value)\n        return grad_value @ self(state, control, disturbance, time)\n\n    @abc.abstractmethod\n    def partial_max_magnitudes(self, state, time, value, grad_value_box):\n        \"\"\"Computes the max magnitudes of the Hamiltonian partials over the `grad_value_box` in each dimension.\"\"\"\n\n\nclass ControlAndDisturbanceAffineDynamics(Dynamics):\n    \"\"\"Abstract base class for representing control- and disturbance-affine dynamics.\"\"\"\n\n    def __call__(self, state, control, disturbance, time):\n        \"\"\"Implements the affine dynamics `dx_dt = f(x, t) + G_u(x, t) @ u + G_d(x, t) @ d`.\"\"\"\n        return (self.open_loop_dynamics(state, time) + self.control_jacobian(state, time) @ control +\n                self.disturbance_jacobian(state, time) @ disturbance)\n\n    @abc.abstractmethod\n    def open_loop_dynamics(self, state, time):\n        \"\"\"Implements the open loop dynamics `f(x, t)`.\"\"\"\n\n    @abc.abstractmethod\n    def control_jacobian(self, state, time):\n        \"\"\"Implements the control Jacobian `G_u(x, t)`.\"\"\"\n\n    @abc.abstractmethod\n    def disturbance_jacobian(self, state, time):\n        \"\"\"Implements the disturbance Jacobian `G_d(x, t)`.\"\"\"\n\n    def optimal_control_and_disturbance(self, state, time, grad_value):\n        \"\"\"Computes the optimal control and disturbance realized by the HJ PDE Hamiltonian.\"\"\"\n        control_direction = grad_value @ self.control_jacobian(state, time)\n        if self.control_mode == \"min\":\n            control_direction = -control_direction\n        disturbance_direction = grad_value @ self.disturbance_jacobian(state, time)\n        if self.disturbance_mode == \"min\":\n            disturbance_direction = -disturbance_direction\n        return (self.control_space.extreme_point(control_direction),\n                self.disturbance_space.extreme_point(disturbance_direction))\n\n    def 
partial_max_magnitudes(self, state, time, value, grad_value_box):\n        \"\"\"Computes the max magnitudes of the Hamiltonian partials over the `grad_value_box` in each dimension.\"\"\"\n        del value, grad_value_box  # unused\n        # An overestimation; see Eq. (25) from https://www.cs.ubc.ca/~mitchell/ToolboxLS/toolboxLS-1.1.pdf.\n        return (jnp.abs(self.open_loop_dynamics(state, time)) +\n                jnp.abs(self.control_jacobian(state, time)) @ self.control_space.max_magnitudes +\n                jnp.abs(self.disturbance_jacobian(state, time)) @ self.disturbance_space.max_magnitudes)\n"
  },
  {
    "path": "hj_reachability/finite_differences/__init__.py",
    "content": "from hj_reachability.finite_differences.upwind_first import (ENO1, ENO2, ENO3, WENO1, WENO3, WENO5,\n                                                             essentially_non_oscillatory, first_order,\n                                                             weighted_essentially_non_oscillatory)\n\n__all__ = (\"ENO1\", \"ENO2\", \"ENO3\", \"WENO1\", \"WENO3\", \"WENO5\", \"essentially_non_oscillatory\", \"first_order\",\n           \"weighted_essentially_non_oscillatory\")\n"
  },
  {
    "path": "hj_reachability/finite_differences/upwind_first.py",
    "content": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport numpy.polynomial.polynomial as poly\n\nfrom types import ModuleType\nfrom typing import Any, Callable, Optional, Tuple\n\nArray = Any\n\nWENO_EPS = 1e-6\n\n\ndef weighted_essentially_non_oscillatory(eno_order: int, values: Array, spacing: float,\n                                         boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:\n    \"\"\"Implements an upwind weighted essentially non-oscillatory (WENO) scheme for first derivative approximation.\n\n    Args:\n        eno_order: The order of the underlying essentially non-oscillatory (ENO) scheme; the resulting WENO scheme is\n            `(2 * eno_order - 1)`th-order accurate.\n        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.\n        spacing: Grid spacing of the `values`.\n        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).\n\n    Returns:\n        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,\n        respectively, left and right approximations of the first derivative at the grid points of `values`.\n    \"\"\"\n    if eno_order < 1:\n        raise ValueError(f\"`eno_order` must be at least 1; got {eno_order}.\")\n\n    values = boundary_condition(values, eno_order)\n    diffs = (values[1:] - values[:-1]) / spacing\n\n    if eno_order == 1:\n        return (diffs[:-1], diffs[1:])\n\n    substencil_approximations = tuple(\n        _unrolled_correlate(diffs[i:len(diffs) - eno_order + i], c)\n        for (i, c) in enumerate(_diff_coefficients(eno_order)))\n    diffs2 = diffs[1:] - diffs[:-1]\n    smoothness_indicators = [\n        sum(\n            _unrolled_correlate(diffs2[i + j:len(diffs2) - eno_order + i + 1], L[j:, j])**2\n            for j in range(eno_order - 1))\n        for (i, L) in 
enumerate(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)))\n    ]\n    left_and_right_unnormalized_weights = [[\n        c / (s[i:len(s) + i - 1] + WENO_EPS)**2 for (c, s) in zip(coefficients, smoothness_indicators)\n    ] for (i, coefficients) in enumerate(_substencil_coefficients(eno_order))]\n    return tuple(\n        sum(w * a for (w, a) in zip(unnormalized_weights, substencil_approximations[i:eno_order + i])) /\n        sum(unnormalized_weights) for (i, unnormalized_weights) in enumerate(left_and_right_unnormalized_weights))\n\n\ndef essentially_non_oscillatory(order: int, values: Array, spacing: float,\n                                boundary_condition: Callable[[Array, int], Array]) -> Tuple[Array, Array]:\n    \"\"\"Implements an upwind essentially non-oscillatory (ENO) scheme for first derivative approximation.\n\n    Args:\n        order: The desired order of accuracy for the ENO scheme.\n        values: 1-dimensional array of function values assumed to be evaluated at a uniform grid in the domain.\n        spacing: Grid spacing of the `values`.\n        boundary_condition: A function used to pad `values` to implement a boundary condition (e.g., periodic).\n\n    Returns:\n        A tuple of arrays `(left_derivatives, right_derivatives)` each the same shape as `values` which contain,\n        respectively, left and right approximations of the first derivative at the grid points of `values`.\n    \"\"\"\n    if order < 1:\n        raise ValueError(f\"`order` must be at least 1; got {order}.\")\n\n    values = boundary_condition(values, order)\n    diffs = (values[1:] - values[:-1]) / spacing\n\n    if order == 1:\n        return (diffs[:-1], diffs[1:])\n\n    substencil_approximations = tuple(\n        _unrolled_correlate(diffs[i:len(diffs) - order + i], c) for (i, c) in enumerate(_diff_coefficients(order)))\n\n    undivided_differences = []\n    for i in range(2, order):\n        diffs = diffs[1:] - diffs[:-1]\n        
undivided_differences.append(diffs[order - i:i - order])\n\n    abs_diffs = jnp.abs(diffs[1:] - diffs[:-1])\n    stencil_indices = abs_diffs[1:] < abs_diffs[:-1]\n    for diffs in reversed(undivided_differences):\n        abs_diffs = jnp.abs(diffs)\n        stencil_indices = jnp.where(abs_diffs[1:] < abs_diffs[:-1], stencil_indices[1:] + 1, stencil_indices[:-1])\n\n    return (jnp.select([stencil_indices[:-1] == i for i in range(order - 1)], substencil_approximations[:-2],\n                       substencil_approximations[-2]),\n            jnp.select([stencil_indices[1:] == i for i in range(order - 1)], substencil_approximations[1:-1],\n                       substencil_approximations[-1]))\n\n\nfirst_order = WENO1 = functools.partial(weighted_essentially_non_oscillatory, 1)\nWENO3 = functools.partial(weighted_essentially_non_oscillatory, 2)\nWENO5 = functools.partial(weighted_essentially_non_oscillatory, 3)\nENO1 = functools.partial(essentially_non_oscillatory, 1)\nENO2 = functools.partial(essentially_non_oscillatory, 2)\nENO3 = functools.partial(essentially_non_oscillatory, 3)\n\n\ndef _weighted_essentially_non_oscillatory_vectorized(\n        eno_order: int, values: Array, spacing: float, boundary_condition: Callable[[Array, int],\n                                                                                    Array]) -> Tuple[Array, Array]:\n    \"\"\"Implements a more \"vectorized\" but ultimately slower version of `weighted_essentially_non_oscillatory`.\"\"\"\n    if eno_order < 1:\n        raise ValueError(f\"`eno_order` must be at least 1; got {eno_order}.\")\n\n    values = boundary_condition(values, eno_order)\n    diffs = (values[1:] - values[:-1]) / spacing\n\n    if eno_order == 1:\n        return (diffs[:-1], diffs[1:])\n\n    substencil_approximations = _align_substencil_values(\n        jax.vmap(jnp.correlate, (None, 0), 0)(diffs, _diff_coefficients(eno_order)), jnp)\n    diffs2 = diffs[1:] - diffs[:-1]\n    chol_T = 
jnp.asarray(np.linalg.cholesky(_smoothness_indicator_quad_form(eno_order)).swapaxes(-1, -2))\n    smoothness_indicators = _align_substencil_values(\n        jnp.sum(jnp.square(jax.vmap(jax.vmap(jnp.correlate, (None, 0), 1), (None, 0), 0)(diffs2, chol_T)), -1), jnp)\n    unscaled_weights = 1 / jnp.square(smoothness_indicators + WENO_EPS)\n    unnormalized_weights = (jnp.asarray(_substencil_coefficients(eno_order)[..., np.newaxis]) *\n                            jnp.stack([unscaled_weights[:, :-1], unscaled_weights[:, 1:]]))\n    weights = unnormalized_weights / jnp.sum(unnormalized_weights, 1, keepdims=True)\n    return tuple(jnp.sum(jnp.stack([substencil_approximations[:-1], substencil_approximations[1:]]) * weights, 1))\n\n\ndef _unrolled_correlate(a: Array, v: Array) -> Array:\n    \"\"\"An unrolled equivalent of `np.correlate`.\"\"\"\n    return sum(a[i:len(a) - len(v) + i + 1] * x for (i, x) in enumerate(v))\n\n\ndef _substencils(k: int) -> Array:\n    \"\"\"Returns the `k + 1` subranges of length `k + 1` from the full stencil range `[-k, k + 1)`.\"\"\"\n    return np.arange(k + 1) + np.arange(k + 1)[:, np.newaxis] - k\n\n\ndef _spread_substencil_values(x: Array, np: ModuleType = np) -> Array:\n    \"\"\"Offsets each successive row of a matrix `x` by one additional column.\"\"\"\n    return np.reshape(np.reshape(np.pad(x, ((0, 0), (0, x.shape[0]))), -1)[:-x.shape[0]], (x.shape[0], -1))\n\n\ndef _align_substencil_values(x: Array, np: ModuleType = np) -> Array:\n    \"\"\"Slices and stacks windows, each offset by one column from the previous, from rows of a matrix `x`.\"\"\"\n    return np.reshape(np.pad(np.reshape(x, -1), (0, x.shape[0])), (x.shape[0], -1))[:, :-x.shape[0]]\n\n\ndef _diff_coefficients(k: Optional[int] = None, stencil: Optional[Array] = None) -> Array:\n    \"\"\"Returns first derivative approximation finite difference coefficients for function value first differences.\"\"\"\n    if k is None:\n        if stencil is None:\n            raise 
ValueError(\"One of `k` or `stencil` must be provided.\")\n        k = stencil.shape[-1] - 1\n    else:\n        if stencil is None:\n            stencil = _substencils(k)\n        elif k != stencil.shape[-1] - 1:\n            raise ValueError(\"`k` must match `stencil.shape[-1] - 1` if both arguments are provided; got \"\n                             f\"{(k, stencil.shape[-1] - 1)}.\")\n    return np.linalg.solve(\n        np.diff(poly.polyvander(stencil, k), axis=-2)[..., 1:].swapaxes(-1, -2),\n        np.eye(k)[(np.newaxis,) * (stencil.ndim - 1) + (0, ..., np.newaxis)])[..., 0]\n\n\ndef _substencil_coefficients(k: int) -> Array:\n    \"\"\"Returns coefficients for combining substencil approximations to yield higher order left/right approximations.\"\"\"\n    left_coefficients = np.linalg.solve(\n        _spread_substencil_values(_diff_coefficients(k))[:-1, :k].T,\n        _diff_coefficients(stencil=np.arange(-k, k))[:k])\n    return np.array([left_coefficients, left_coefficients[::-1]])\n\n\ndef _polyder_operator(k: int, d: int) -> Array:\n    \"\"\"Returns a matrix `D` such that `D @ p == poly.polyder(p, d)` for polynomials `p` of degree `k`.\"\"\"\n    return np.concatenate([np.zeros((k + 1 - d, d)), np.diag(poly.polyder(np.ones(k + 1), d))], 1)\n\n\ndef _smoothness_indicator_quad_form(k: int) -> Array:\n    \"\"\"Returns quadratic forms for computing substencil smoothness indicators as functions of second differences.\"\"\"\n    interp_poly_second_der = (poly.polyder(np.ones(k + 1), 2)[:, np.newaxis] *\n                              np.linalg.inv(np.diff(poly.polyvander(_substencils(k)[1:], k), 2, axis=-2)[..., 2:]))\n\n    quad_form = np.zeros((k, k - 1, k - 1))\n    for m in range(k - 1):\n        integrator_matrix = 1 / (np.arange(k - 1 - m) + np.arange(k - 1 - m)[:, np.newaxis] + 1)\n        interp_poly_m_plus_2_der = _polyder_operator(k - 2, m) @ interp_poly_second_der\n        quad_form += interp_poly_m_plus_2_der.swapaxes(-1, -2) @ integrator_matrix @ 
interp_poly_m_plus_2_der\n    return quad_form\n"
  },
  {
    "path": "hj_reachability/finite_differences/upwind_first_test.py",
    "content": "import math\n\nfrom absl.testing import absltest\nimport jax\nimport numpy as np\n\nfrom hj_reachability import boundary_conditions\nfrom hj_reachability.finite_differences import upwind_first\n\n\nclass UpwindFirstTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n\n    def test_weighted_essentially_non_oscillatory(self):\n\n        def _WENO5(values, spacing, boundary_condition):\n            values = boundary_condition(values, 3)\n            diffs = (values[1:] - values[:-1]) / spacing\n\n            def compute_weno(v):\n                phi = [\n                    v[0] / 3 - 7 * v[1] / 6 + 11 * v[2] / 6,\n                    -v[1] / 6 + 5 * v[2] / 6 + v[3] / 3,\n                    v[2] / 3 + 5 * v[3] / 6 - v[4] / 6,\n                ]\n                s = [(13 / 12) * (v[0] - 2 * v[1] + v[2])**2 + (1 / 4) * (v[0] - 4 * v[1] + 3 * v[2])**2,\n                     (13 / 12) * (v[1] - 2 * v[2] + v[3])**2 + (1 / 4) * (v[1] - v[3])**2,\n                     (13 / 12) * (v[2] - 2 * v[3] + v[4])**2 + (1 / 4) * (3 * v[2] - 4 * v[3] + v[4])**2]\n                a = [w / (x + upwind_first.WENO_EPS)**2 for (w, x) in zip([0.1, 0.6, 0.3], s)]\n                w = [x / sum(a) for x in a]\n                return sum(p * w for (p, w) in zip(phi, w))\n\n            return (compute_weno([diffs[i:-5 + i] for i in range(5)]),\n                    compute_weno([diffs[5 - i:None if i == 0 else -i] for i in range(5)]))\n\n        values = np.random.rand(1000)\n        spacing = 0.1\n        jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),\n                     upwind_first.WENO5(values, spacing, boundary_conditions.periodic),\n                     _WENO5(values, spacing, boundary_conditions.periodic))\n\n    def test_essentially_non_oscillatory(self):\n\n        def _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_condition):\n\n            def _divided_difference(x, i, spacing=1):\n            
    if isinstance(i, int):\n                    return x[i]\n                order = len(i) - 1\n                return np.diff(x[i], order)[0] / (math.factorial(order) * spacing**order)\n\n            v = np.array(boundary_condition(values, order))\n            x = np.arange(len(v)) * spacing\n\n            p = [np.poly1d(v[i]) for i in range(order - 1, len(v) - order)]\n            ks = []\n            for i in range(len(p)):\n                j = i + order - 1\n                p[i] += _divided_difference(v, [j, j + 1], spacing) * np.poly1d([x[j]], True)\n                k = j\n                for d in range(2, order + 1):\n                    a = _divided_difference(v, np.arange(k, k + d + 1), spacing)\n                    b = _divided_difference(v, np.arange(k - 1, k + d), spacing)\n                    if np.abs(a) >= np.abs(b):\n                        c = b\n                        k_next = k - 1\n                    else:\n                        c = a\n                        k_next = k\n                    p[i] += c * np.poly1d(x[k:k + d], True)\n                    k = k_next\n                ks.append(k - j)\n            p_x = [np.polyder(f) for f in p]\n            return (np.array([np.polyval(f, x) for (f, x) in zip(p_x[:-1], x[order:-order])]),\n                    np.array([np.polyval(f, x) for (f, x) in zip(p_x[1:], x[order:-order])]))\n\n        values = np.random.rand(1000)\n        spacing = 0.1\n        for order in range(1, 5):\n            jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),\n                         upwind_first.essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic),\n                         _brute_force_essentially_non_oscillatory(order, values, spacing, boundary_conditions.periodic))\n\n    def test_weighted_essentially_non_oscillatory_vectorized(self):\n        values = np.random.rand(1000)\n        spacing = 0.1\n        for eno_order in range(1, 5):\n            
jax.tree.map(\n                lambda x, y: np.testing.assert_allclose(x, y, atol=1e-5),\n                upwind_first.weighted_essentially_non_oscillatory(eno_order, values, spacing,\n                                                                  boundary_conditions.periodic),\n                upwind_first._weighted_essentially_non_oscillatory_vectorized(eno_order, values, spacing,\n                                                                              boundary_conditions.periodic))\n\n    def test_diff_coefficients(self):\n        # k = 1\n        np.testing.assert_allclose(upwind_first._diff_coefficients(1), np.ones((2, 1)))\n\n        # k = 2\n        np.testing.assert_allclose(upwind_first._diff_coefficients(2), np.array([[-1, 3], [1, 1], [3, -1]]) / 2)\n\n        # k = 3\n        np.testing.assert_allclose(upwind_first._diff_coefficients(3),\n                                   np.array([[2, -7, 11], [-1, 5, 2], [2, 5, -1], [11, -7, 2]]) / 6)\n\n    def test_substencil_coefficients(self):\n        # k = 1\n        np.testing.assert_allclose(upwind_first._substencil_coefficients(1), np.ones((2, 1)))\n\n        # k = 2\n        np.testing.assert_allclose(upwind_first._substencil_coefficients(2), np.array([[1, 2], [2, 1]]) / 3)\n\n        # k = 3\n        np.testing.assert_allclose(upwind_first._substencil_coefficients(3), np.array([[1, 6, 3], [3, 6, 1]]) / 10)\n\n    def test_smoothness_indicator_quad_form(self):\n        diff_operator = lambda k: np.eye(k - 1, k, 1) - np.eye(k - 1, k, 0)\n        square_outer = lambda v: v[..., np.newaxis] * v[..., np.newaxis, :]\n\n        # k = 1\n        np.testing.assert_allclose(\n            diff_operator(1).T @ upwind_first._smoothness_indicator_quad_form(1) @ diff_operator(1), [[[0]]])\n\n        # k = 2\n        np.testing.assert_allclose(\n            diff_operator(2).T @ upwind_first._smoothness_indicator_quad_form(2) @ diff_operator(2),\n            square_outer(np.array([[1, -1], [1, -1]])))\n\n        # 
k = 3\n        np.testing.assert_allclose(\n            diff_operator(3).T @ upwind_first._smoothness_indicator_quad_form(3) @ diff_operator(3),\n            (13 / 12) * square_outer(np.array([[1, -2, 1], [1, -2, 1], [1, -2, 1]])) +\n            (1 / 4) * square_outer(np.array([[1, -4, 3], [1, 0, -1], [3, -4, 1]])))\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "hj_reachability/grid.py",
    "content": "import functools\n\nfrom flax import struct\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import boundary_conditions as _boundary_conditions\nfrom hj_reachability.finite_differences import upwind_first\nfrom hj_reachability import sets\nfrom hj_reachability import utils\n\nfrom typing import Any, Callable, Optional, Tuple, Union\nfrom hj_reachability.boundary_conditions import BoundaryCondition\n\nArray = Any\n\n\n@struct.dataclass\nclass Grid:\n    \"\"\"Class for representing Cartesian state grids with uniform spacing in each dimension.\n\n    Attributes:\n        states: An `(N + 1)` dimensional array containing the state values at each grid location. The first `N`\n            dimensions correspond to the location in the grid, while the last dimension (itself of size `N`) contains\n            the state vector.\n        domain: A `Box` representing the domain of grid.\n        coordinate_vectors: A tuple of `N` arrays containing the discrete state values in each dimension. The `states`\n            attribute is produced by `stack`ing a `meshgrid` of these coordinate vectors.\n        spacings: A tuple of `N` scalars containing the grid spacing (the difference between successive elements of the\n            corresponding coordinate vector) in each dimension.\n        boundary_conditions: A tuple of `N` boundary conditions for each dimension. These boundary conditions are\n            functions used to pad values (notably not stored in this `Grid` data structure) to implement a boundary\n            condition (e.g., periodic).\n    \"\"\"\n    states: Array\n    domain: sets.Box\n    coordinate_vectors: Tuple[Array, ...]\n    spacings: Tuple[Array, ...]\n    boundary_conditions: Tuple[BoundaryCondition, ...] 
= struct.field(pytree_node=False)\n\n    @classmethod\n    def from_lattice_parameters_and_boundary_conditions(\n            cls,\n            domain: sets.Box,\n            shape: Tuple[int, ...],\n            boundary_conditions: Optional[Tuple[BoundaryCondition, ...]] = None,\n            periodic_dims: Optional[Union[int, Tuple[int, ...]]] = None) -> \"Grid\":\n        \"\"\"Constructs a `Grid` from a domain, shape, and boundary conditions.\n\n        Args:\n            domain: A `Box` representing the domain of grid.\n            shape: A tuple of `N` integers denoting the number of discretization nodes in each dimension.\n            boundary_conditions: A tuple of `N` boundary conditions for each dimension. If not provided, defaults to\n                `extrapolate_away_from_zero` in each dimension, with the exception of those dimensions that appear in\n                `periodic_dims` where the `periodic` boundary condition is used instead.\n            periodic_dims: A single integer or tuple of integers denoting which dimensions are periodic in the case that\n                the `boundary_conditions` are not explicitly provided as input to this factory method.\n\n        Returns:\n            A `Grid` constructed according to the provided specifications.\n        \"\"\"\n        ndim = len(shape)\n        if boundary_conditions is None:\n            if not isinstance(periodic_dims, tuple):\n                periodic_dims = (periodic_dims,)\n            boundary_conditions = tuple(\n                _boundary_conditions.periodic if i in periodic_dims else _boundary_conditions.extrapolate_away_from_zero\n                for i in range(ndim))\n\n        coordinate_vectors, spacings = zip(\n            *(jnp.linspace(l, h, n, endpoint=bc is not _boundary_conditions.periodic, retstep=True)\n              for l, h, n, bc in zip(domain.lo, domain.hi, shape, boundary_conditions)))\n        states = jnp.stack(jnp.meshgrid(*coordinate_vectors, indexing=\"ij\"), 
-1)\n\n        return cls(states, domain, coordinate_vectors, spacings, boundary_conditions)\n\n    @property\n    def ndim(self) -> int:\n        \"\"\"Returns the dimension `N` of the grid.\"\"\"\n        return self.states.ndim - 1\n\n    @property\n    def shape(self) -> Tuple[int, ...]:\n        \"\"\"Returns the shape of the grid, a tuple of `N` integers.\"\"\"\n        return self.states.shape[:-1]\n\n    def upwind_grad_values(self, upwind_scheme: Callable, values: Array) -> Tuple[Array, Array]:\n        \"\"\"Returns `(left_grad_values, right_grad_values)`.\"\"\"\n        left_derivatives, right_derivatives = zip(*[\n            utils.multivmap(lambda values: upwind_scheme(values, spacing, boundary_condition),\n                            np.array([j\n                                      for j in range(self.ndim)\n                                      if j != i]))(values)\n            for i, (spacing, boundary_condition) in enumerate(zip(self.spacings, self.boundary_conditions))\n        ])\n        return (jnp.stack(left_derivatives, -1), jnp.stack(right_derivatives, -1))\n\n    def grad_values(self, values: Array, upwind_scheme: Optional[Callable] = None) -> Array:\n        \"\"\"Returns a central difference-based approximation of `grad_values`.\"\"\"\n        # TODO: Implement central difference schemes in `hj_reachability.finite_differences`.\n        if upwind_scheme is None:\n            upwind_scheme = upwind_first.first_order\n        return sum(self.upwind_grad_values(upwind_scheme, values)) / 2\n\n    def position(self, state: Array) -> Array:\n        \"\"\"Returns an array of `float`s corresponding to the position of `state` in the grid.\"\"\"\n        position = (state - self.domain.lo) / jnp.array(self.spacings)\n        return jnp.where(self._is_periodic_dim, position % np.array(self.shape), position)\n\n    def nearest_index(self, state: Array) -> Array:\n        \"\"\"Returns the result of rounding `self.position(state)` to the nearest 
grid index.\"\"\"\n        return jnp.round(self.position(state)).astype(jnp.int32)\n\n    def interpolate(self, values, state):\n        \"\"\"Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`.\"\"\"\n        position = (state - self.domain.lo) / jnp.array(self.spacings)\n        index_lo = jnp.floor(position).astype(jnp.int32)\n        index_hi = index_lo + 1\n        weight_hi = position - index_lo\n        weight_lo = 1 - weight_hi\n        index_lo, index_hi = tuple(\n            jnp.where(self._is_periodic_dim, index % np.array(self.shape), jnp.clip(index, 0,\n                                                                                    np.array(self.shape) - 1))\n            for index in (index_lo, index_hi))\n        weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))\n        # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.\n        result = jnp.sum(\n            weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *\n            values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))\n        return jnp.where(jnp.any(~self._is_periodic_dim & ((state < self.domain.lo) | (state > self.domain.hi))),\n                         jnp.nan, result)\n\n    @property\n    def _is_periodic_dim(self) -> Array:\n        \"\"\"Returns a boolean vector indicating which dimensions (if any) are periodic.\"\"\"\n        return np.array([bc is _boundary_conditions.periodic for bc in self.boundary_conditions])\n"
  },
  {
    "path": "hj_reachability/grid_test.py",
    "content": "from absl.testing import absltest\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import grid as _grid\nfrom hj_reachability import sets\n\n\nclass BoundaryConditionsTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n\n    def test_grid_interpolate(self):\n        grid_domain = sets.Box(np.zeros(2), np.ones(2))\n        grid_shape = (3, 2)\n        grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=1)\n        values = np.random.random((3, 2))\n        np.testing.assert_allclose(grid.interpolate(values, np.array([0.25, 2.75])), np.mean(values[0:2, 0:2]))\n        np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0])\n        np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0])\n        values = np.random.random((3, 2, 3, 4))\n        np.testing.assert_allclose(grid.interpolate(values, np.array([0.75, 2.75])), np.mean(values[1:3, 0:2], (0, 1)))\n        np.testing.assert_allclose(grid.interpolate(values, np.zeros(2)), values[0, 0])\n        np.testing.assert_allclose(grid.interpolate(values, np.ones(2)), values[-1, 0])\n\n    def test_grid_interpolate_on_grid(self):\n        grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2))\n        grid_shape = (3, 4)\n        for value_shape in ((), (5,)):\n            values = jnp.array(np.random.random(grid_shape + value_shape))\n            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape)\n            np.testing.assert_allclose(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x)))(grid.states),\n                                       values,\n                                       atol=1e-6)\n\n            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=0)\n            states = grid.states + (grid._is_periodic_dim * np.arange(-3, 4)[:, None, None, 
None] *\n                                    (grid.domain.hi - grid.domain.lo))\n            np.testing.assert_allclose(jax.vmap(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x))))(states),\n                                       np.broadcast_to(values, states.shape[:1] + values.shape),\n                                       atol=1e-6)\n\n    def test_grid_interpolate_off_grid(self):\n        grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2))\n        grid_shape = (3, 4)\n        for value_shape in ((), (5,)):\n            a = np.random.random((2,) + value_shape)\n            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape)\n            values = grid.states @ a\n            states = grid.domain.lo + np.random.random((100, 2)) * (grid.domain.hi - grid.domain.lo)\n            np.testing.assert_allclose(jax.vmap(lambda x: grid.interpolate(values, x))(states), states @ a, atol=1e-6)\n\n            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape, periodic_dims=0)\n            values = jnp.array(np.random.random(grid_shape + value_shape))\n            grid_unwrapped = _grid.Grid.from_lattice_parameters_and_boundary_conditions(\n                grid.domain, tuple(d + 1 if p else d for d, p in zip(grid.shape, grid._is_periodic_dim)))\n            values_unwrapped = jnp.concatenate([values, values[:1]])\n            states = states + (grid._is_periodic_dim * np.arange(-3, 4)[:, None, None] *\n                               (grid.domain.hi - grid.domain.lo))\n            np.testing.assert_allclose(jax.vmap(jax.vmap(lambda x: grid.interpolate(values, x)))(states),\n                                       jax.vmap(jax.vmap(lambda x: grid_unwrapped.interpolate(values_unwrapped, x)))\n                                       ((states - grid.domain.lo) % (grid.domain.hi - grid.domain.lo) + grid.domain.lo),\n                                       atol=1e-6)\n\n    def 
test_grid_interpolate_extrapolate_nan(self):\n        grid_domain = sets.Box(jnp.zeros(2), jnp.ones(2))\n        grid_shape = (3, 4)\n        for value_shape in ((), (5,)):\n            values = jnp.array(np.random.random(grid_shape + value_shape))\n            grid = _grid.Grid.from_lattice_parameters_and_boundary_conditions(grid_domain, grid_shape)\n            states = grid.domain.lo + (grid.domain.hi - grid.domain.lo) * np.array(\n                [[0.5 + dx, 0.5 + dy] for dx in [-1, 0, 1] for dy in [-1, 0, 1] if dx or dy])\n            result = jax.vmap(lambda x: grid.interpolate(values, x))(states)\n            self.assertEqual(result.shape, (8,) + value_shape)\n            self.assertTrue(np.all(np.isnan(result)))\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "hj_reachability/sets.py",
    "content": "import abc\n\nfrom flax import struct\nimport jax.numpy as jnp\n\nfrom hj_reachability import utils\n\nfrom typing import Any\n\nArray = Any\n\n\n@struct.dataclass\nclass BoundedSet(metaclass=abc.ABCMeta):\n    \"\"\"Abstract base class for representing bounded subsets of Euclidean space.\"\"\"\n\n    @abc.abstractmethod\n    def extreme_point(self, direction: Array) -> Array:\n        \"\"\"Computes the point `x` in the set such that the dot product `x @ direction` is greatest.\"\"\"\n\n    @property\n    @abc.abstractmethod\n    def bounding_box(self) -> \"Box\":\n        \"\"\"Returns an axis-aligned bounding box for the set.\"\"\"\n\n    @property\n    def max_magnitudes(self) -> Array:\n        \"\"\"Returns the maximum magnitude (per dimension) of points in the set.\"\"\"\n        return jnp.maximum(jnp.abs(self.bounding_box.lo), jnp.abs(self.bounding_box.hi))\n\n    @property\n    def ndim(self) -> int:\n        \"\"\"Returns the dimension of the Euclidean space the set lies within.\"\"\"\n        return self.bounding_box.ndim\n\n\n@struct.dataclass\nclass Box(BoundedSet):\n    \"\"\"Class for representing axis-aligned boxes.\"\"\"\n    lo: Array\n    hi: Array\n\n    def extreme_point(self, direction: Array) -> Array:\n        \"\"\"Computes the point `x` in the set such that the dot product `x @ direction` is greatest.\"\"\"\n        return jnp.where(direction < 0, self.lo, self.hi)\n\n    @property\n    def bounding_box(self) -> \"Box\":\n        \"\"\"Returns an axis-aligned bounding box for the set.\"\"\"\n        return self\n\n    @property\n    def ndim(self) -> int:\n        \"\"\"Returns the dimension of the Euclidean space the set lies within.\"\"\"\n        return self.lo.shape[-1]\n\n\n@struct.dataclass\nclass Ball(BoundedSet):\n    \"\"\"Class for representing Euclidean (L2) balls.\"\"\"\n    center: Array\n    radius: Array\n\n    def extreme_point(self, direction: Array) -> Array:\n        \"\"\"Computes the point `x` in the 
set such that the dot product `x @ direction` is greatest.\"\"\"\n        return self.center + self.radius * utils.unit_vector(direction)\n\n    @property\n    def bounding_box(self) -> \"Box\":\n        \"\"\"Returns an axis-aligned bounding box for the set.\"\"\"\n        return Box(self.center - jnp.expand_dims(self.radius, -1), self.center + jnp.expand_dims(self.radius, -1))\n"
  },
  {
    "path": "hj_reachability/sets_test.py",
    "content": "from absl.testing import absltest\nimport jax\nimport numpy as np\n\nfrom hj_reachability import sets\n\n\nclass SetsTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n\n    def test_box(self):\n        box = sets.Box(np.ones(3), 2 * np.ones(3))\n        np.testing.assert_allclose(box.extreme_point(np.array([1, -1, 1])), np.array([2, 1, 2]))\n        self.assertTrue(np.all(np.isfinite(box.extreme_point(np.zeros(3)))))\n        self.assertEqual(box.bounding_box, box)\n        np.testing.assert_allclose(box.max_magnitudes, 2 * np.ones(3))\n        self.assertEqual(box.ndim, 3)\n\n    def test_ball(self):\n        ball = sets.Ball(np.ones(3), np.sqrt(3))\n        np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6)\n        self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3)))))\n        jax.tree.map(np.testing.assert_allclose, ball.bounding_box,\n                     sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3)))\n        np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3))\n        self.assertEqual(ball.ndim, 3)\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "hj_reachability/solver.py",
    "content": "import contextlib\nimport functools\n\nfrom flax import struct\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import artificial_dissipation\nfrom hj_reachability import time_integration\nfrom hj_reachability.finite_differences import upwind_first\n\nfrom typing import Callable, Text\n\n# Hamiltonian postprocessors.\nidentity = lambda *x: x[-1]  # Returns the last argument so that this may also be used as a value postprocessor.\nbackwards_reachable_tube = lambda x: jnp.minimum(x, 0)\n\n# Value postprocessors.\nstatic_obstacle = lambda obstacle: (lambda t, v: jnp.maximum(v, obstacle))\n\n\n@struct.dataclass\nclass SolverSettings:\n    upwind_scheme: Callable = struct.field(\n        default=upwind_first.WENO5,\n        pytree_node=False,\n    )\n    artificial_dissipation_scheme: Callable = struct.field(\n        default=artificial_dissipation.global_lax_friedrichs,\n        pytree_node=False,\n    )\n    hamiltonian_postprocessor: Callable = struct.field(\n        default=identity,\n        pytree_node=False,\n    )\n    time_integrator: Callable = struct.field(\n        default=time_integration.third_order_total_variation_diminishing_runge_kutta,\n        pytree_node=False,\n    )\n    value_postprocessor: Callable = struct.field(\n        default=identity,\n        pytree_node=False,\n    )\n    CFL_number: float = 0.75\n\n    @classmethod\n    def with_accuracy(cls, accuracy: Text, **kwargs) -> \"SolverSettings\":\n        \"\"\"Creates `SolverSettings` whose upwind scheme and time integrator match the requested `accuracy`.\n\n        Args:\n            accuracy: One of \"low\", \"medium\", \"high\", or \"very_high\".\n            **kwargs: Additional `SolverSettings` fields (e.g. `CFL_number`, `value_postprocessor`).\n\n        Raises:\n            ValueError: if `accuracy` is not one of the supported levels.\n        \"\"\"\n        if accuracy == \"low\":\n            upwind_scheme = upwind_first.first_order\n            time_integrator = time_integration.first_order_total_variation_diminishing_runge_kutta\n        elif accuracy == \"medium\":\n            upwind_scheme = upwind_first.ENO2\n            time_integrator = time_integration.second_order_total_variation_diminishing_runge_kutta\n        elif accuracy == \"high\":\n            upwind_scheme = upwind_first.WENO3\n            time_integrator = time_integration.third_order_total_variation_diminishing_runge_kutta\n        elif accuracy == \"very_high\":\n            upwind_scheme = upwind_first.WENO5\n            time_integrator = time_integration.third_order_total_variation_diminishing_runge_kutta\n        else:\n            raise ValueError(f\"Unknown accuracy {accuracy!r}; expected one of 'low', 'medium', 'high', 'very_high'.\")\n        return cls(upwind_scheme=upwind_scheme, time_integrator=time_integrator, **kwargs)\n\n\n@functools.partial(jax.jit, static_argnames=(\"dynamics\", \"progress_bar\"))\ndef step(solver_settings, dynamics, grid, time, values, target_time, progress_bar=True):\n    with (_try_get_progress_bar(time, target_time)\n          if progress_bar is True else contextlib.nullcontext(progress_bar)) as bar:\n\n        def sub_step(time_values):\n            t, v = solver_settings.time_integrator(solver_settings, dynamics, grid, *time_values, target_time)\n            if bar is not False:\n                bar.update_to(jnp.abs(t - bar.reference_time))\n            return t, v\n\n        return jax.lax.while_loop(lambda time_values: jnp.abs(target_time - time_values[0]) > 0, sub_step,\n                                  (time, values))[1]\n\n\n@functools.partial(jax.jit, static_argnames=(\"dynamics\", \"progress_bar\"))\ndef solve(solver_settings, dynamics, grid, times, initial_values, progress_bar=True):\n    with (_try_get_progress_bar(times[0], times[-1])\n          if progress_bar is True else contextlib.nullcontext(progress_bar)) as bar:\n        make_carry_and_output_slice = lambda t, v: ((t, v), v)\n        return jnp.concatenate([\n            initial_values[np.newaxis],\n            jax.lax.scan(\n                lambda time_values, target_time: make_carry_and_output_slice(\n                    target_time, step(solver_settings, dynamics, grid, *time_values, target_time, bar)),\n                (times[0], initial_values), times[1:])[1]\n        ])\n\n\ndef _try_get_progress_bar(reference_time, target_time):\n    try:\n        import tqdm\n    except ImportError:\n        raise ImportError(\"The option `progress_bar=True` requires the 'tqdm' package to be installed.\")\n    return TqdmWrapper(tqdm,\n                       reference_time,\n                       total=jnp.abs(target_time - reference_time),\n                       unit=\"sim_s\",\n                       bar_format=\"{l_bar}{bar}| {n:7.4f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]\",\n                       ascii=True)\n\n\nclass TqdmWrapper:\n\n    def __init__(self, tqdm, reference_time, total, *args, **kwargs):\n        self.reference_time = reference_time\n        jax.experimental.io_callback(\n            lambda total: self._create_tqdm(tqdm, float(total), *args, **kwargs),\n            None,\n            total,\n            ordered=True,\n        )\n\n    def _create_tqdm(self, tqdm, total, *args, **kwargs):\n        self._tqdm = tqdm.tqdm(total=total, *args, **kwargs)\n\n    def update_to(self, n):\n        jax.experimental.io_callback(\n            lambda n: self._tqdm.update(float(n) - self._tqdm.n) and None,\n            None,\n            n,\n            ordered=True,\n        )\n\n    def close(self):\n        jax.experimental.io_callback(\n            lambda: self._tqdm.close(),\n            None,\n            ordered=True,\n        )\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.close()\n"
  },
  {
    "path": "hj_reachability/solver_test.py",
    "content": "from absl.testing import absltest\nimport numpy as np\n\nimport hj_reachability as hj\n\n\nclass SolverTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n        solver_settings = hj.SolverSettings.with_accuracy(\"low\")\n        dynamics = hj.systems.Air3d()\n        grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(np.array([-6., -10., 0.]),\n                                                                                   np.array([20., 10., 2 * np.pi])),\n                                                                       (11, 10, 10),\n                                                                       periodic_dims=2)\n        self.problem_definition = {\n            \"solver_settings\": solver_settings,\n            \"dynamics\": dynamics,\n            \"grid\": grid,\n        }\n\n    def test_step(self):\n        values = np.linalg.norm(self.problem_definition[\"grid\"].states[..., :2], axis=-1) - 5\n        target_values = hj.step(**self.problem_definition, time=0., values=values, target_time=-0.1, progress_bar=False)\n        self.assertEqual(target_values.shape, values.shape)\n        np.testing.assert_allclose(\n            target_values,\n            hj.step(**self.problem_definition, time=0., values=values, target_time=-0.1, progress_bar=True))\n\n    def test_solve(self):\n        times = np.linspace(0, -0.1, 3)\n        initial_values = np.linalg.norm(self.problem_definition[\"grid\"].states[..., :2], axis=-1) - 5\n        all_values = hj.solve(**self.problem_definition, times=times, initial_values=initial_values, progress_bar=False)\n        self.assertEqual(all_values.shape, (len(times),) + initial_values.shape)\n        np.testing.assert_allclose(all_values[0], initial_values)\n        np.testing.assert_allclose(all_values[-1],\n                                   hj.step(**self.problem_definition,\n                                           time=0.,\n                                           values=initial_values,\n                                           target_time=-0.1,\n                                           progress_bar=False),\n                                   atol=1e-2)\n        np.testing.assert_allclose(\n            all_values,\n            hj.solve(**self.problem_definition, times=times, initial_values=initial_values, progress_bar=True))\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "hj_reachability/systems/__init__.py",
    "content": "from hj_reachability.systems.air3d import Air3d, DubinsCarCAvoid\n\n__all__ = (\"Air3d\", \"DubinsCarCAvoid\")\n"
  },
  {
    "path": "hj_reachability/systems/air3d.py",
    "content": "import jax.numpy as jnp\n\nfrom hj_reachability import dynamics\nfrom hj_reachability import sets\n\n\nclass Air3d(dynamics.ControlAndDisturbanceAffineDynamics):\n\n    def __init__(self,\n                 evader_speed=5.,\n                 pursuer_speed=5.,\n                 evader_max_turn_rate=1.,\n                 pursuer_max_turn_rate=1.,\n                 control_mode=\"max\",\n                 disturbance_mode=\"min\",\n                 control_space=None,\n                 disturbance_space=None):\n        self.evader_speed = evader_speed\n        self.pursuer_speed = pursuer_speed\n        if control_space is None:\n            control_space = sets.Box(jnp.array([-evader_max_turn_rate]), jnp.array([evader_max_turn_rate]))\n        if disturbance_space is None:\n            disturbance_space = sets.Box(jnp.array([-pursuer_max_turn_rate]), jnp.array([pursuer_max_turn_rate]))\n        super().__init__(control_mode, disturbance_mode, control_space, disturbance_space)\n\n    def open_loop_dynamics(self, state, time):\n        _, _, psi = state\n        v_a, v_b = self.evader_speed, self.pursuer_speed\n        return jnp.array([-v_a + v_b * jnp.cos(psi), v_b * jnp.sin(psi), 0.])\n\n    def control_jacobian(self, state, time):\n        x, y, _ = state\n        return jnp.array([\n            [y],\n            [-x],\n            [-1.],\n        ])\n\n    def disturbance_jacobian(self, state, time):\n        return jnp.array([\n            [0.],\n            [0.],\n            [1.],\n        ])\n\n\nDubinsCarCAvoid = Air3d\n"
  },
  {
    "path": "hj_reachability/time_integration.py",
    "content": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import utils\n\n\ndef lax_friedrichs_numerical_hamiltonian(hamiltonian, state, time, value, left_grad_value, right_grad_value,\n                                         dissipation_coefficients):\n    hamiltonian_value = hamiltonian(state, time, value, (left_grad_value + right_grad_value) / 2)\n    dissipation_value = dissipation_coefficients @ (right_grad_value - left_grad_value) / 2\n    return hamiltonian_value - dissipation_value\n\n\n@functools.partial(jax.jit, static_argnames=\"dynamics\")\ndef euler_step(solver_settings, dynamics, grid, time, values, time_step=None, max_time_step=None):\n    time_direction = jnp.sign(max_time_step) if time_step is None else jnp.sign(time_step)\n    signed_hamiltonian = lambda *args, **kwargs: time_direction * dynamics.hamiltonian(*args, **kwargs)\n    left_grad_values, right_grad_values = grid.upwind_grad_values(solver_settings.upwind_scheme, values)\n    dissipation_coefficients = solver_settings.artificial_dissipation_scheme(dynamics.partial_max_magnitudes,\n                                                                             grid.states, time, values,\n                                                                             left_grad_values, right_grad_values)\n    dvalues_dt = -solver_settings.hamiltonian_postprocessor(time_direction * utils.multivmap(\n        lambda state, value, left_grad_value, right_grad_value, dissipation_coefficients:\n        (lax_friedrichs_numerical_hamiltonian(signed_hamiltonian, state, time, value,\n                                              left_grad_value, right_grad_value, dissipation_coefficients)),\n        np.arange(grid.ndim))(grid.states, values, left_grad_values, right_grad_values, dissipation_coefficients))\n    if time_step is None:\n        time_step_bound = 1 / jnp.max(jnp.sum(dissipation_coefficients / jnp.array(grid.spacings), -1))\n        
time_step = time_direction * jnp.minimum(solver_settings.CFL_number * time_step_bound, jnp.abs(max_time_step))\n    # TODO: Think carefully about whether `solver_settings.value_postprocessor` should be applied here instead.\n    return time + time_step, values + time_step * dvalues_dt\n\n\ndef first_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):\n    time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)\n    return time_1, solver_settings.value_postprocessor(time_1, values_1)\n\n\ndef second_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):\n    time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)\n    time_step = time_1 - time\n    _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step)\n    return time_1, solver_settings.value_postprocessor(time_1, (values + values_2) / 2)\n\n\ndef third_order_total_variation_diminishing_runge_kutta(solver_settings, dynamics, grid, time, values, target_time):\n    time_1, values_1 = euler_step(solver_settings, dynamics, grid, time, values, max_time_step=target_time - time)\n    time_step = time_1 - time\n    _, values_2 = euler_step(solver_settings, dynamics, grid, time_1, values_1, time_step)\n    time_0_5, values_0_5 = time + time_step / 2, (3 / 4) * values + (1 / 4) * values_2\n    _, values_1_5 = euler_step(solver_settings, dynamics, grid, time_0_5, values_0_5, time_step)\n    return time_1, solver_settings.value_postprocessor(time_1, (1 / 3) * values + (2 / 3) * values_1_5)\n"
  },
  {
    "path": "hj_reachability/utils.py",
    "content": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom typing import Any, Callable, Iterable, List, Mapping, Optional, TypeVar, Union\n\nT = TypeVar(\"T\")\nTree = Union[T, Iterable[\"Tree[T]\"], Mapping[Any, \"Tree[T]\"]]\n\n\ndef multivmap(fun: Callable,\n              in_axes: Tree[Optional[np.ndarray]],\n              out_axes: Tree[Optional[np.ndarray]] = None) -> Callable:\n    \"\"\"Applies `jax.vmap` over multiple axes (equivalent to multiple nested `jax.vmap`s).\n\n    Args:\n        fun: Function to be mapped over additional axes (see `jax.vmap` for more details).\n        in_axes: Similar to the specification of `in_axes` for `jax.vmap`, with the main difference being that instead\n            of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding input of\n            `fun`, the `np.ndarray` specifies a sequence of axes to `jax.vmap` over; note that these axes are not\n            specified directly as a `list` so as not to conflict with the possible structure of `in_axes`. All\n            non-`None` leaves of `in_axes` (there must be at least one) must have the same length. This length is the\n            number of times `jax.vmap` will be applied to `fun`.\n        out_axes: Similar to the specification of `out_axes` for `jax.vmap`, with the main difference being that instead\n            of `Optional[int]` for axis specification, it's `Optional[np.ndarray]`. For each corresponding output of\n            `fun`, the `np.ndarray` specifies a sequence of additional mapped axes to appear in the output. The length\n            of non-`None` leaves of `out_axes` must be the same as the length of non-`None` leaves of `in_axes`; the\n            order of both axes specifications corresponds to successive nested `jax.vmap` applications. 
If not provided,\n            `out_axes` defaults to `in_axes`.\n\n    Returns:\n        A batched/vectorized version of `fun` with arguments that correspond to those of `fun`, but with (possibly\n        multiple per input) extra array axes at positions indicated by `in_axes`, and a return value that corresponds\n        to that of `fun`, but with (possibly multiple per output) extra array axes at positions indicated by `out_axes`.\n\n    Raises:\n        ValueError: if any specified axes are negative or repeated.\n    \"\"\"\n\n    def get_axis_sequence(axis_array: np.ndarray) -> List:\n        axis_list = axis_array.tolist()\n        if any(axis < 0 for axis in axis_list):\n            raise ValueError(f\"All `multivmap` axes must be nonnegative; got {axis_list}.\")\n        if len(axis_list) != len(set(axis_list)):\n            raise ValueError(f\"All `multivmap` axes must be distinct; got {axis_list}.\")\n        for i in range(len(axis_list)):\n            for j in range(i + 1, len(axis_list)):\n                if axis_list[i] > axis_list[j]:\n                    axis_list[i] -= 1\n        return axis_list\n\n    multivmap_kwargs = {\"in_axes\": in_axes, \"out_axes\": in_axes if out_axes is None else out_axes}\n    axis_sequence_structure = jax.tree.structure(next(a for a in jax.tree.leaves(in_axes) if a is not None).tolist())\n    vmap_kwargs = jax.tree.transpose(jax.tree.structure(multivmap_kwargs), axis_sequence_structure,\n                                     jax.tree.map(get_axis_sequence, multivmap_kwargs))\n    return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun)\n\n\ndef unit_vector(x):\n    \"\"\"Normalizes a vector `x`, returning a unit vector in the same direction, or a zero vector if `x` is zero.\"\"\"\n    norm2 = jnp.sum(jnp.square(x))\n    iszero = norm2 < jnp.finfo(jnp.zeros(()).dtype).eps**2\n    return jnp.where(iszero, jnp.zeros_like(x), x / jnp.sqrt(jnp.where(iszero, 1, norm2)))\n"
  },
  {
    "path": "hj_reachability/utils_test.py",
    "content": "from absl.testing import absltest\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom hj_reachability import utils\n\n\nclass UtilsTest(absltest.TestCase):\n\n    def setUp(self):\n        np.random.seed(0)\n\n    def test_multivmap(self):\n        a = np.random.random((3, 4, 5, 6))\n        np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1]))(a), np.max(a, (2, 3)))\n        np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 2]))(a), np.max(a, -1))\n        np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 3]), np.array([0, 1, 2]))(a), np.max(a, 2))\n        np.testing.assert_allclose(\n            utils.multivmap(jnp.max, np.array([1, 0, 2]), np.array([0, 1, 2]))(a),\n            np.max(a, 3).swapaxes(0, 1))\n        np.testing.assert_allclose(\n            utils.multivmap(jnp.max, np.array([3, 2]), np.array([0, 1]))(a),\n            np.max(a, (0, 1)).swapaxes(0, 1))\n\n    def test_unit_vector(self):\n        unsafe_unit_vector = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True)\n        for d in range(1, 4):\n            np.testing.assert_array_equal(utils.unit_vector(np.zeros(d)), np.zeros(d))\n            self.assertTrue(np.all(np.isfinite(jax.jacobian(utils.unit_vector)(np.zeros(d)))))\n            self.assertTrue(np.all(np.isnan(jax.jacobian(unsafe_unit_vector)(np.zeros(d)))))\n            a = np.random.random((100, d))\n            np.testing.assert_allclose(jax.vmap(utils.unit_vector)(a), unsafe_unit_vector(a), atol=1e-6)\n            np.testing.assert_allclose(jax.vmap(jax.jacobian(utils.unit_vector))(a),\n                                       jax.vmap(jax.jacobian(unsafe_unit_vector))(a),\n                                       atol=1e-6)\n\n\nif __name__ == \"__main__\":\n    absltest.main()\n"
  },
  {
    "path": "requirements-test.txt",
    "content": "absl-py>=0.12.0\ntqdm>=4.60.0\n"
  },
  {
    "path": "requirements.txt",
    "content": "flax>=0.6.6\njax>=0.4.25\nnumpy>=1.22\n"
  },
  {
    "path": "setup.cfg",
    "content": "[yapf]\nbased_on_style = google\ncolumn_limit = 120\n\n[flake8]\nmax-line-length = 120\nignore =\n    # E731: do not assign a lambda expression, use a def\n    E731\n    # E741: do not use variables named 'I', 'O', or 'l'\n    E741\n    # W504: line break occurred after a binary operator\n    W504\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport setuptools\n\n_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))\n\n\ndef _get_version():\n    with open(os.path.join(_CURRENT_DIR, \"hj_reachability\", \"__init__.py\")) as f:\n        for line in f:\n            if line.startswith(\"__version__\") and \"=\" in line:\n                version = line[line.find(\"=\") + 1:].strip(\" '\\\"\\n\")\n                if version:\n                    return version\n        raise ValueError(\"`__version__` not defined in `hj_reachability/__init__.py`\")\n\n\ndef _parse_requirements(file):\n    with open(os.path.join(_CURRENT_DIR, file)) as f:\n        return [line.rstrip() for line in f if not (line.isspace() or line.startswith(\"#\"))]\n\n\ndef _read_file(file):\n    # Resolve relative to this file (not the CWD) and close the handle promptly.\n    with open(os.path.join(_CURRENT_DIR, file), encoding=\"utf-8\") as f:\n        return f.read()\n\n\nsetuptools.setup(name=\"hj_reachability\",\n                 version=_get_version(),\n                 description=\"Hamilton-Jacobi reachability analysis in JAX.\",\n                 long_description=_read_file(\"README.md\"),\n                 long_description_content_type=\"text/markdown\",\n                 author=\"Ed Schmerling\",\n                 author_email=\"ednerd@gmail.com\",\n                 url=\"https://github.com/StanfordASL/hj_reachability\",\n                 license=\"MIT\",\n                 packages=setuptools.find_packages(),\n                 install_requires=_parse_requirements(\"requirements.txt\"),\n                 tests_require=_parse_requirements(\"requirements-test.txt\"),\n                 python_requires=\"~=3.9\")\n"
  }
]