[
  {
    "path": ".coveragerc",
    "content": "[run]\nsource = scico\ncommand_line = -m pytest\nomit =\n    scico/test/*\n    scico/plot.py\n    scico/trace.py\n    scico/linop/xray/_axitom/*.py\n\n[report]\n# Regexes for lines to exclude from consideration\nexclude_lines =\n    # Have to re-enable the standard pragma\n    pragma: no cover\n    def __repr__\n"
  },
  {
    "path": ".flake8",
    "content": "[flake8]\nmax-line-length = 100\nignore =\n#E731: do not assign a lambda expression, use a def\n  E731\n"
  },
  {
    "path": ".github/codecov.yml",
    "content": "coverage:\n  precision: 2\n  round: nearest\n  range: \"80...100\"\n\n  status:\n    project:\n      default:\n        target: auto\n        threshold: 0.05%\n    patch: false\n"
  },
  {
    "path": ".github/isbin.sh",
    "content": "#! /bin/bash\n\n# Determine whether files are acceptable for commit into main scico repo\n\nsize_threshold=65536\n\nSAVEIFS=$IFS\nIFS=$(echo -en \"\\n\\b\")\nOS=$(uname -a | cut -d ' ' -f 1)\n\nfor f in $@; do\n    echo $f\n    case \"$OS\" in\n        Linux)  size=$(stat --format \"%s\" $f);;\n        Darwin) size=$(stat -f \"%z\" $f);;\n        *)      echo \"Error: unsupported operating system $OS\" >&2; exit 1;;\n    esac\n    # Exception on maximum size for pytest-split .test_durations file\n    if [ $size -gt $size_threshold ] && [ \"$(basename $f)\" != \".test_durations\" ]; then\n        echo \"file exceeds maximum allowable size of $size_threshold bytes\"\n        echo \"raw data and ipynb files should go in scico-data\"\n        exit 2\n    fi\n    charset=$(file -b --mime $f | sed -e 's/.*charset=//')\n    if [ ! -L \"$f\" ] && [ \"$charset\" = \"binary\" ]; then\n        echo \"binary files cannot be committed to the repository\"\n        echo \"raw data and ipynb files should go in scico-data\"\n        exit 3\n    fi\n    basename=$(basename -- \"$f\")\n    ext=\"${basename##*.}\"\n    if [ \"$ext\" = \"ipynb\" ]; then\n        echo \"ipynb files cannot be committed to the repository\"\n        echo \"raw data and ipynb files should go in scico-data\"\n        exit 4\n    fi\ndone\n\nIFS=$SAVEIFS\n\nexit 0\n"
  },
  {
    "path": ".github/workflows/check_files.yml",
    "content": "# Check file types and sizes\n\nname: check files\n\non: [push, pull_request]\n\njobs:\n  checkfiles:\n    runs-on: ubuntu-latest\n    steps:\n    - name: checkout\n      uses: actions/checkout@v5\n    - id: files\n      uses: Ana06/get-changed-files@v2.3.0\n      continue-on-error: true\n    - run: |\n       for f in ${{ steps.files.outputs.added }}; do\n           ${GITHUB_WORKSPACE}/.github/./isbin.sh $f\n       done\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "# Run black, isort, and pylint on pushes to main and on all pull requests\n\nname: lint\n\non:\n    push:\n        branches:\n          - main\n    pull_request:\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n      - uses: actions/setup-python@v6\n        with:\n          python-version: \"3.12\"\n      - name: Black code formatter\n        uses: psf/black@stable\n        with:\n          version: \">=24.3.0\"\n      - name: Isort import sorter\n        uses: isort/isort-action@v1\n      - name: Pylint code analysis\n        run: |\n          pip install pylint\n          pylint --disable=all --enable=missing-docstring,broad-exception-raised scico\n"
  },
  {
    "path": ".github/workflows/mypy.yml",
    "content": "# Install and run mypy\n\nname: mypy\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\n  workflow_dispatch:\n\njobs:\n  mypy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      - name: Install Python 3\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.12\"\n      - name: Install dependencies\n        run: |\n          pip install mypy\n      - name: Run mypy\n        run: |\n          mypy --follow-imports=skip --ignore-missing-imports  --exclude \"(numpy|test)\" scico/ scico/numpy/util.py\n"
  },
  {
    "path": ".github/workflows/pypi_upload.yml",
    "content": "# When a tag is pushed, build packages and upload to PyPI\n\nname: pypi upload\n\n# Trigger when tags are pushed\non:\n  push:\n    tags:\n      - '*'\n\n  workflow_dispatch:\n\njobs:\n  build-and-upload:\n    name: Upload package to PyPI\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      - name: Install Python 3\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.12\"\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          sudo apt-get install -y libopenblas-dev\n          pip install -r requirements.txt\n          pip install -r dev_requirements.txt\n          pip install wheel\n          python setup.py sdist bdist_wheel\n      - name: Upload package to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          user: __token__\n          password: ${{ secrets.PYPI_API_TOKEN }}\n          verbose: true\n"
  },
  {
    "path": ".github/workflows/pytest_latest.yml",
    "content": "# Install scico requirements and run pytest with latest jax version\n\nname: unit tests (latest jax)\n\n# Controls when the workflow will run\non:\n  # Run workflow every Sunday at midnight UTC\n  schedule:\n    - cron: \"0 0 * * 0\"\n\n  # Allows you to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\njobs:\n  pytest-latest-jax:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      - name: Install Python 3\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.12\"\n      - name: Install lastversion\n        run: |\n          python -m pip install --upgrade pip\n          pip install lastversion\n      - name: Install dependencies\n        run: |\n          rjaxlib=$(grep jaxlib requirements.txt | sed -e 's/jaxlib.*<=\\([0-9\\.]*$\\)/\\1/')\n          rjax=$(grep -E \"jax[^lib]\" requirements.txt | sed -e 's/jax.*<=\\([0-9\\.]*$\\)/\\1/')\n          ljaxlib=$(lastversion --at pip jaxlib)\n          ljax=$(lastversion --at pip jax)\n          echo jaxlib  required: $rjaxlib  latest: $ljaxlib\n          echo jax     required: $rjax  latest: $ljax\n          if [ \"$rjaxlib\" = \"$ljaxlib\" ] && [ \"$rjax\" = \"$ljax\" ]; then\n            echo Test is redundant: required and latest jaxlib/jax versions match\n            echo \"TEST=cancel\" >> $GITHUB_ENV\n          else\n            echo \"TEST=run\" >> $GITHUB_ENV\n            sudo apt-get install -y libopenblas-dev\n            pip install -r requirements.txt\n            pip install -r dev_requirements.txt\n            pip install -e .\n            pip install --upgrade \"jax[cpu]\"\n          fi\n      - name: Run tests with pytest\n        run: |\n          TEST=\"${{ env.TEST }}\"\n          if [ \"$TEST\" = \"run\" ]; then\n            pytest\n          else\n            exit 0\n          fi\n"
  },
  {
    "path": ".github/workflows/pytest_macos.yml",
    "content": "# Install scico requirements and run pytest\n\nname: unit tests (macos)\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\n  workflow_dispatch:\n\njobs:\n\n  test:\n    runs-on: macos-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        group: [1, 2, 3, 4, 5]\n    name: pytest split ${{ matrix.group }} (macos)\n    defaults:\n      run:\n        shell: bash -l {0}\n    steps:\n      # Check-out the repository under $GITHUB_WORKSPACE\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      # Set up conda environment\n      - name: Set up miniconda\n        uses: conda-incubator/setup-miniconda@v3\n        with:\n            miniforge-version: latest\n            activate-environment: test-env\n            python-version: \"3.12\"\n      # Configure conda environment cache\n      - name: Set up conda environment cache\n        uses: actions/cache@v4\n        with:\n          path: ${{ env.CONDA }}/envs\n          key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ env.CACHE_NUMBER }}\n        env:\n          CACHE_NUMBER: 0  # Increase this value to force cache reset\n        id: cache\n      # Display environment details\n      - name: Display environment details\n        run: |\n          conda info\n          printenv | sort\n      # Install dependencies in conda environment\n      - name: Install dependencies\n        if: steps.cache.outputs.cache-hit != 'true'\n        run: |\n          conda install -c conda-forge pytest pytest-cov\n          python -m pip install --upgrade pip\n          pip install pytest-split\n          pip install -r requirements.txt\n          pip install -r dev_requirements.txt\n          pip install \"bm3d>=4.0.0\"\n          pip install \"bm4d>=4.0.0\"\n          pip install \"ray[tune]>=2.44\"\n          pip install hyperopt\n          pip install 
\"setuptools<82.0.0\"  # workaround for hyperopt 0.2.7\n          pip install pydantic\n          pip install \"orbax-checkpoint>=0.5.0\"\n          #conda install -c conda-forge \"svmbir>=0.4.0\"\n          conda install -c astra-toolbox astra-toolbox\n          conda install -c conda-forge pyyaml\n      # Install package to be tested\n      - name: Install package to be tested\n        run: pip install -e .\n      # Run unit tests\n      - name: Run main unit tests\n        run: |\n          DURATIONS_FILE=$(mktemp)\n          bzcat data/pytest/durations_macos.bz2 > $DURATIONS_FILE\n          pytest -x --level=1 --durations-path=$DURATIONS_FILE --splits=5 --group=${{ matrix.group }} --pyargs scico\n"
  },
  {
    "path": ".github/workflows/pytest_ubuntu.yml",
    "content": "# Install scico requirements and run pytest\n\nname: unit tests (ubuntu)\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\n  workflow_dispatch:\n    inputs:\n      debug_enabled:\n        type: boolean\n        description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)'\n        required: false\n        default: false\njobs:\n\n  test:\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        group: [1, 2, 3, 4, 5]\n    name: pytest split ${{ matrix.group }} (ubuntu)\n    defaults:\n      run:\n        shell: bash -l {0}\n    steps:\n      # Check-out the repository under $GITHUB_WORKSPACE\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      # Set up conda environment\n      - name: Set up miniconda\n        uses: conda-incubator/setup-miniconda@v3\n        with:\n            miniforge-version: latest\n            activate-environment: test-env\n            python-version: \"3.12\"\n      # Configure conda environment cache\n      - name: Set up conda environment cache\n        uses: actions/cache@v4\n        with:\n          path: ${{ env.CONDA }}/envs\n          key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ env.CACHE_NUMBER }}\n        env:\n          CACHE_NUMBER: 0  # Increase this value to force cache reset\n        id: cache\n      # Display environment details\n      - name: Display environment details\n        run: |\n          conda info\n          printenv | sort\n      # Install required system package\n      - name: Install required system package\n        run: sudo apt-get install -y libopenblas-dev\n      # Install dependencies in conda environment\n      - name: Install dependencies\n        if: steps.cache.outputs.cache-hit != 'true'\n        run: |\n          conda install -c conda-forge 
pytest pytest-cov\n          python -m pip install --upgrade pip\n          pip install pytest-split\n          pip install -r requirements.txt\n          pip install -r dev_requirements.txt\n          pip install \"bm4d>=4.2.2\"\n          pip install \"bm3d>=4.0.0\"\n          pip install \"ray[tune]>=2.44\"\n          pip install hyperopt\n          pip install \"setuptools<82.0.0\"  # workaround for hyperopt 0.2.7\n          pip install pydantic\n          pip install \"orbax-checkpoint>=0.5.0\"\n          conda install -c conda-forge \"svmbir>=0.4.0\"\n          conda install -c conda-forge astra-toolbox\n          conda install -c conda-forge pyyaml\n      # Install package to be tested\n      - name: Install package to be tested\n        run: pip install -e .\n      # Enable tmate debugging of manually-triggered workflows if the input option was provided\n      - name: Setup tmate session\n        uses: mxschmitt/action-tmate@v3\n        if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }}\n      # Run unit tests\n      - name: Run main unit tests\n        run: |\n          DURATIONS_FILE=$(mktemp)\n          bzcat data/pytest/durations_ubuntu.bz2 > $DURATIONS_FILE\n          pytest -x --cov --level=2 --durations-path=$DURATIONS_FILE --splits=5 --group=${{ matrix.group }} --pyargs scico\n      # Upload coverage data\n      - name: Upload coverage\n        uses: actions/upload-artifact@v4\n        with:\n          include-hidden-files: true\n          name: coverage${{ matrix.group }}\n          path: ${{ github.workspace }}/.coverage\n      # Run doc tests\n      - name: Run doc tests\n        if: matrix.group == 1\n        run: |\n          pytest --ignore-glob=\"*test_*.py\" --ignore=scico/linop/xray --doctest-modules scico\n          pytest --doctest-glob=\"*.rst\" docs\n\n  coverage:\n    needs: test\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n      - name: Set up Python 3.12\n        uses: 
actions/setup-python@v6\n        with:\n          python-version: \"3.12\"\n      - name: Install deps\n        run: |\n          python -m pip install --upgrade pip\n          pip install coverage\n      - name: Download all artifacts\n        # Downloads coverage1, coverage2, etc.\n        uses: actions/download-artifact@v4\n      - name: Run coverage\n        run: |\n          coverage combine coverage?/.coverage\n          coverage report\n          coverage xml\n      - uses: codecov/codecov-action@v4\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        with:\n          env_vars: OS,PYTHON\n          fail_ci_if_error: false\n          files: coverage.xml\n          flags: unittests\n          name: codecov-umbrella\n          verbose: true\n"
  },
  {
    "path": ".github/workflows/test_examples.yml",
    "content": "# Install scico requirements and run short versions of example scripts\n\nname: test examples\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\n  # Allow this workflow to be run manually from the Actions tab\n  workflow_dispatch:\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n    name: test examples (ubuntu)\n    defaults:\n      run:\n        shell: bash -l {0}\n    steps:\n      # Check-out the repository under $GITHUB_WORKSPACE\n      - uses: actions/checkout@v5\n        with:\n          submodules: recursive\n      # Set up conda environment\n      - name: Set up miniconda\n        uses: conda-incubator/setup-miniconda@v3\n        with:\n            miniforge-version: latest\n            activate-environment: test-env\n            python-version: \"3.12\"\n      # Configure conda environment cache\n      - name: Set up conda environment cache\n        uses: actions/cache@v4\n        with:\n          path: ${{ env.CONDA }}/envs\n          key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ hashFiles('examples/examples_requirements.txt') }}-${{ env.CACHE_NUMBER }}\n        env:\n          CACHE_NUMBER: 0  # Increase this value to force cache reset\n        id: cache\n      # Display environment details\n      - name: Display environment details\n        run: |\n          conda info\n          printenv | sort\n      # Install required system package\n      - name: Install required system package\n        run: sudo apt-get install -y libopenblas-dev\n      # Install dependencies in conda environment\n      - name: Install dependencies\n        if: steps.cache.outputs.cache-hit != 'true'\n        run: |\n          conda install -c conda-forge pytest pytest-cov\n          python -m pip install --upgrade pip\n          pip install -r requirements.txt\n          pip install -r dev_requirements.txt\n       
   conda install -c conda-forge astra-toolbox\n          conda install -c conda-forge pyyaml\n          pip install --upgrade --force-reinstall \"scipy>=1.6.0\"  # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version\n          pip install -r examples/examples_requirements.txt\n      # Install package to be tested\n      - name: Install package to be tested\n        run: pip install -e .\n      # Run example test\n      - name: Run example test\n        run: |\n          ${GITHUB_WORKSPACE}/examples/scriptcheck.sh -e -d -t -g\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Editor backups\n.*~\n\n# Docs generation\ndocs/source/_autosummary/\ndocs/source/examples/\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# VS Code settings\n.vscode/\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# macos files\n*.DS_Store\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"data\"]\n\tpath = data\n\turl = https://github.com/lanl/scico-data.git\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# See https://pre-commit.com for more information\n# See https://pre-commit.com/hooks.html for more hooks\nrepos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v2.3.0\n    hooks:\n    -   id: end-of-file-fixer\n    -   id: trailing-whitespace\n-   repo: local\n    hooks:\n    - id: check-for-binary\n      name: check for binary/ipynb files\n      entry: .github/isbin.sh\n      language: script\n      pass_filenames: true\n    - id: autoflake\n      name: autoflake\n      entry: autoflake\n      language: python\n      language_version:  python3\n      types: [python]\n      args: ['-i', '--remove-all-unused-imports',  '--ignore-init-module-imports']\n    - id: isort\n      name: isort (python)\n      entry: isort\n      language: python\n      language_version:  python3\n      types: [python]\n    - id: isort\n      name: isort (cython)\n      entry: isort\n      language: python\n      language_version:  python3\n      types: [cython]\n    - id: black\n      name: black\n      entry: black\n      description: 'Black: The uncompromising Python code formatter'\n      language: python\n      language_version:  python3\n      types: [python]\n    - id: pylint\n      name: pylint\n      entry: pylint\n      language: python\n      language_version: python3\n      types: [python]\n      exclude: ^(scico/test/|examples|docs)\n      args: ['--score=n', '--disable=all', '--enable=missing-docstring,broad-exception-raised']\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# .readthedocs.yaml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Get submodules\nsubmodules:\n  include: all\n  recursive: true\n\n# Set the version of Python and other tools you might need\nbuild:\n  os: ubuntu-24.04\n  tools:\n    python: \"3.12\"\n  jobs:\n    pre_build:\n      - mkdir -p docs/source/examples\n      - |\n        for f in data/notebooks/*; do\n          b=$(basename $f)\n          if [ ! -f \"docs/source/examples/$b\" ]; then\n            ln -s -t docs/source/examples \"../../../$f\"\n          fi\n        done\n    post_build:  # unclear why this is necessary\n      - cp docs/source/_static/scico.css  _readthedocs/html/_static\n  apt_packages:\n    - graphviz\n    - libopenblas-dev\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n   builder: html\n   configuration: docs/source/conf.py\n   fail_on_warning: false\n\n# Declare the Python requirements required to build your docs\npython:\n   install:\n   - requirements: docs/docs_requirements.txt\n   - requirements: docs/rtd_requirements.txt\n"
  },
  {
    "path": "CHANGES.rst",
    "content": "===================\nSCICO Release Notes\n===================\n\nVersion 0.0.8   (unreleased)\n----------------------------\n\n• Enable certain parameters of array creation functions to trigger\n  ``BlockArray`` creation when they receive lists (currently ``device``).\n• New functional ``functional.BoxIndicator``.\n• Support ``jaxlib`` and ``jax`` versions 0.5.0 to 0.10.0.\n• Support ``flax`` versions 0.8.0 to 0.12.7.\n• Various bug fixes and minor improvements.\n\n\n\nVersion 0.0.7   (2025-12-09)\n----------------------------\n\n• New module ``scico.trace`` for tracing function/method calls.\n• New generic functional ``functional.ComposedFunctional`` representing\n  a functional composed with an orthogonal linear operator.\n• New optimizer methods ``save_state`` and ``load_state`` supporting\n  algorithm state checkpointing.\n• New classes for creating a volume from an image by symmetry, and\n  for cone beam X-ray transform of a cylindrically symmetric object\n  in module ``linop.xray.symcone``.\n• New utility functions for CT reconstruction preprocessing added in\n  module ``linop.xray``.\n• Moved ``linop.abel`` module to ``linop.xray.abel``.\n• Make ``orbax-checkpoint`` dependency optional due to absence of recent\n  conda-forge packages.\n• Support ``jaxlib`` and ``jax`` versions 0.5.0 to 0.8.1.\n• Support ``flax`` versions 0.8.0 to 0.12.0.\n\n\n\nVersion 0.0.6   (2024-10-25)\n----------------------------\n\n• Significant changes to ``linop.xray.astra`` API.\n• Rename integrated 2D X-ray transform class to\n  ``linop.xray.XRayTransform2D`` and add filtered back projection method\n  ``fbp``.\n• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.\n• New functional ``functional.IsotropicTVNorm`` and faster implementation\n  of ``functional.AnisotropicTVNorm``.\n• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``,\n  ``linop.CylindricalGradient``, and ``linop.SphericalGradient``.\n• Rename 
``scico.numpy.util.parse_axes`` to\n  ``scico.numpy.util.normalize_axes``.\n• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to\n  ``scico.flax.save_variables`` and ``scico.flax.load_variables``\n  respectively.\n• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.4.35.\n• Support ``flax`` versions 0.8.0 to 0.10.0.\n\n\n\nVersion 0.0.5   (2023-12-18)\n----------------------------\n\n• New functionals ``functional.AnisotropicTVNorm`` and\n  ``functional.ProximalAverage`` with proximal operator approximations.\n• New integrated Radon/X-ray transform ``linop.XRayTransform``.\n• New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``.\n• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and\n  ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes\n  to ``XRayTransform``.\n• Rename ``AbelProjector`` to ``AbelTransform``.\n• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.\n• Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and\n  ``linop.VerticalStack``.\n• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.23.\n• Support ``flax`` versions up to 0.7.5.\n• Use ``orbax`` for checkpointing ``flax`` models.\n\n\n\nVersion 0.0.4   (2023-08-03)\n----------------------------\n\n• Add new ``Function`` class for representing array-to-array mappings with\n  more than one input.\n• Add new methods and a function for computing Jacobian-vector products for\n  ``Operator`` objects.\n• Add new proximal ADMM solvers.\n• Add new ADMM subproblem solvers for problems involving a sum-of-convolutions\n  operator.\n• Extend support for other ML models including UNet, ODP and MoDL.\n• Add functionality for training Flax-based ML models and for data generation.\n• Enable diagnostics for ML training loops.\n• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.14.\n• Change required packages and version numbers, including more recent version\n  for ``flax``.\n• Drop support for Python 
3.7.\n• Add support for 3D tomographic projection with the ASTRA Toolbox.\n\n\n\nVersion 0.0.3   (2022-09-21)\n----------------------------\n\n• Change required packages and version numbers, including more recent version\n  requirements for ``numpy``, ``scipy``, ``svmbir``, and ``ray``.\n• Package ``bm4d`` removed from main requirements list due to issue #342.\n• Support ``jaxlib`` versions 0.3.0 to 0.3.15 and ``jax`` versions\n  0.3.0 to 0.3.17.\n• Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules\n  to ``TomographicProjector``.\n• Add support for fan beam CT in ``radon_svmbir`` module.\n• Add function ``linop.linop_from_function`` for constructing linear\n  operators from functions.\n• Enable addition operator for functionals.\n• Completely new implementation of ``BlockArray`` class.\n• Additional solvers in ``scico.solver``.\n• New Huber norm (``HuberNorm``) and set distance functionals (``SetDistance``\n  and ``SquaredSetDistance``).\n• New loss functions ``loss.SquaredL2AbsLoss`` and\n  ``loss.SquaredL2SquaredAbsLoss`` for phase retrieval problems.\n• Add interface to BM4D denoiser.\n• Change interfaces of ``linop.FiniteDifference`` and ``linop.DFT``.\n• Change filenames of some example scripts (and corresponding notebooks).\n• Add support for Python 3.7.\n• New ``DiagonalStack`` linear operator.\n• Add support for non-linear operators to ``optimize.PDHG`` optimizer class.\n• Various bug fixes.\n\n\n\nVersion 0.0.2   (2022-02-14)\n----------------------------\n\n• Additional optimization algorithms: Linearized ADMM and PDHG.\n• Additional Abel transform and array slicing linear operators.\n• Additional nuclear norm functional.\n• New module ``scico.ray.tune`` providing a simplified interface to Ray Tune.\n• Move optimization algorithms into ``optimize`` subpackage.\n• Additional iteration stats columns for iterative ADMM subproblem solvers.\n• Renamed \"Primal Rsdl\" to \"Prml Rsdl\" in displayed iteration stats.\n• Move some functions 
from ``util`` and ``math`` modules to new ``array``\n  module.\n• Bump pinned ``jaxlib`` and ``jax`` versions to 0.3.0.\n\n\nVersion 0.0.1   (2021-11-24)\n----------------------------\n\n• Initial release.\n"
  },
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2021-2025, Los Alamos National Laboratory\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include MANIFEST.in\ninclude README.md\ninclude CHANGES.rst\ninclude LICENSE\ninclude setup.py\ninclude conftest.py\ninclude pyproject.toml\ninclude pytest.ini\ninclude requirements.txt\ninclude dev_requirements.txt\ninclude docs/docs_requirements.txt\n\nrecursive-include scico *.py\nrecursive-include scico/data *.png *.mpk *.rst\nrecursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.png *.ico\nrecursive-include examples *_requirements.txt *.txt *.rst *.py *.sh\nrecursive-include misc *.py *.sh *.rst\n"
  },
  {
    "path": "README.md",
    "content": "[![Python \\>= 3.8](https://img.shields.io/badge/python-3.8+-green.svg)](https://www.python.org/)\n[![Package License](https://img.shields.io/github/license/lanl/scico.svg)](https://github.com/lanl/scico/blob/main/LICENSE)\n[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n[![Documentation Status](https://readthedocs.org/projects/scico/badge/?version=latest)](http://scico.readthedocs.io/en/latest/?badge=latest)\n[![JOSS paper](https://joss.theoj.org/papers/10.21105/joss.04722/status.svg)](https://doi.org/10.21105/joss.04722)\\\n[![Lint status](https://github.com/lanl/scico/actions/workflows/lint.yml/badge.svg)](https://github.com/lanl/scico/actions/workflows/lint.yml)\n[![Test status](https://github.com/lanl/scico/actions/workflows/pytest_ubuntu.yml/badge.svg)](https://github.com/lanl/scico/actions/workflows/pytest_ubuntu.yml)\n[![Test coverage](https://codecov.io/gh/lanl/scico/branch/main/graph/badge.svg?token=wQimmjnzFf)](https://codecov.io/gh/lanl/scico)\n[![CodeFactor](https://www.codefactor.io/repository/github/lanl/scico/badge/main)](https://www.codefactor.io/repository/github/lanl/scico/overview/main)\\\n[![PyPI package version](https://badge.fury.io/py/scico.svg)](https://badge.fury.io/py/scico)\n[![PyPI download statistics](https://static.pepy.tech/personalized-badge/scico?period=total&left_color=grey&right_color=brightgreen&left_text=downloads)](https://pepy.tech/project/scico)\n[![Conda Forge Release](https://img.shields.io/conda/vn/conda-forge/scico.svg)](https://anaconda.org/conda-forge/scico)\n[![Conda Forge Downloads](https://img.shields.io/conda/dn/conda-forge/scico.svg)](https://anaconda.org/conda-forge/scico)\\\n[![View notebooks at nbviewer](https://raw.githubusercontent.com/jupyter/design/master/logos/Badges/nbviewer_badge.svg)](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb)\n[![Run notebooks on 
binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb)\n[![Run notebooks on google colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb)\n\n\n# Scientific Computational Imaging Code (SCICO)\n\nSCICO is a Python package for solving the inverse problems that arise in\nscientific imaging applications. Its primary focus is providing methods\nfor solving ill-posed inverse problems by using an appropriate prior\nmodel of the reconstruction space. SCICO includes a growing suite of\noperators, cost functionals, regularizers, and optimization routines\nthat may be combined to solve a wide range of problems, and is designed\nso that it is easy to add new building blocks. SCICO is built on top of\n[JAX](https://github.com/google/jax), which provides features such as\nautomatic gradient calculation and GPU acceleration.\n\n[Documentation](https://scico.rtfd.io/) is available online. If you use\nthis software for published work, please cite the corresponding [JOSS\nPaper](https://doi.org/10.21105/joss.04722) (see bibtex entry\n`balke-2022-scico` in `docs/source/references.bib`).\n\n\n# Installation\n\nThe online documentation includes detailed\n[installation instructions](https://scico.rtfd.io/en/latest/install.html).\n\n\n# Usage Examples\n\nUsage examples are available as Python scripts and Jupyter Notebooks.\nExample scripts are located in `examples/scripts`. The corresponding\nJupyter Notebooks are provided in the\n[scico-data](https://github.com/lanl/scico-data) submodule and symlinked\nto `examples/notebooks`. 
They are also viewable on\n[GitHub](https://github.com/lanl/scico-data/tree/main/notebooks) or\n[nbviewer](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb),\nand can be run online on\n[binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb)\nor\n[google colab](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb).\n\n\n# License\n\nSCICO is distributed as open-source software under a BSD 3-Clause\nLicense (see the `LICENSE` file for details).\n\nLANL open source approval reference C20091.\n\n\\(c\\) 2020-2026. Triad National Security, LLC. All rights reserved. This\nprogram was produced under U.S. Government contract 89233218CNA000001\nfor Los Alamos National Laboratory (LANL), which is operated by Triad\nNational Security, LLC for the U.S. Department of Energy/National\nNuclear Security Administration. All rights in the program are reserved\nby Triad National Security, LLC, and the U.S. Department of\nEnergy/National Nuclear Security Administration. The Government has\ngranted for itself and others acting on its behalf a nonexclusive,\npaid-up, irrevocable worldwide license in this material to reproduce,\nprepare derivative works, distribute copies to the public, perform\npublicly and display publicly, and to permit others to do so.\n"
  },
  {
    "path": "conftest.py",
    "content": "\"\"\"\nConfigure pytest.\n\"\"\"\n\nimport os\n\nimport numpy as np\n\nimport pytest\n\nos.environ[\"RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO\"] = \"0\"  # suppress ray warning\ntry:\n    import ray  # noqa: F401\nexcept ImportError:\n    have_ray = False\nelse:\n    have_ray = True\n    ray.init(num_cpus=1)  # call required to be here: see ray-project/ray#44087\n\nimport jax.numpy as jnp\n\nimport scico.numpy as snp\n\n\ndef pytest_sessionstart(session):\n    \"\"\"Initialize before start of test session.\"\"\"\n    # placeholder: currently unused\n\n\ndef pytest_sessionfinish(session, exitstatus):\n    \"\"\"Clean up after end of test session.\"\"\"\n    if have_ray:\n        ray.shutdown()\n\n\n@pytest.fixture(autouse=True)\ndef add_modules(doctest_namespace):\n    \"\"\"Add common modules for use in docstring examples.\n\n    Necessary because `np` is used in doc strings for jax functions\n    (e.g. `linear_transpose`) that get pulled into `scico/__init__.py`.\n    Also allow `snp` and `jnp` to be used without explicitly importing.\n    \"\"\"\n    doctest_namespace[\"np\"] = np\n    doctest_namespace[\"snp\"] = snp\n    doctest_namespace[\"jnp\"] = jnp\n"
  },
  {
    "path": "dev_requirements.txt",
    "content": "-r requirements.txt\npylint\npytest>=7.3.0\npytest-split\npackaging\npre-commit\nblack>=24.3.0\nisort\nautoflake\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Makefile for Sphinx documentation\n\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS\t?=\nSPHINXBUILD\t?= sphinx-build\nSOURCEDIR\t= source\nBUILDDIR\t= ../build/sphinx\n\n.PHONY: help clean Makefile\n\n\n# Put this first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\nclean:\n\trm -rf $(BUILDDIR)/*\n\trm -f $(SOURCEDIR)/_autosummary/*\n\trm -f $(SOURCEDIR)/examples/*\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@mkdir -p source/examples; \\\n\tfor f in ../data/notebooks/*; do \\\n\t  b=$$(basename $$f) ; \\\n\t  if [ ! -f \"source/examples/$$b\" ]; then \\\n\t      echo Creating soft link for notebook $$b ; \\\n\t      ln -s -t source/examples \"../../$$f\" ; \\\n\t  fi \\\n\tdone\n\t$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/docs_requirements.txt",
    "content": "-r ../requirements.txt\nsphinx>=5.0.0\nsphinxcontrib-napoleon\nsphinxcontrib-bibtex\nsphinx-autodoc-typehints\nfuro>=2024.5.6\njinja2<3.1.0  # temporary fix for jinja2/nbconvert bug\ntraitlets!=5.2.2  # temporary fix for ipython/traitlets#741\nnbsphinx\nipython\nipython_genutils\npy2jn\npygraphviz>=1.9\npandoc\ndocutils>=0.18\n"
  },
  {
    "path": "docs/rtd_requirements.txt",
    "content": "# nbconvert>=7.5 requires a version of pandoc that is not available\n#   in the readthedocs build environment\nnbconvert<7.5\n"
  },
  {
    "path": "docs/source/_static/scico.css",
    "content": "/* furo theme customization */\n\nbody[data-theme=\"dark\"] figure img {\n  filter: invert(100%);\n}\n\n.sidebar-drawer {\n  width: fit-content !important;\n}\n\n.main > .content {\n  min-width: 75%;\n  width: fit-content !important;\n  max-width: 80em;\n}\n\n.highlight {\n  background: #e9efff;\n}\n\n.sidebar-brand-text {\n  font-size: 1.0rem !important;\n  text-align: center;\n  padding-top: 0.5em;\n}\n\n\n/* Code display section */\n\ndiv.doctest.highlight-default {\n  background-color: #f9f9f4;\n}\n\n\n/* Style for autosummary API docs */\n\n[data-theme=light] dl.field-list.simple {\n  background-color: #f5f5f5;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.field-list.simple > dt.field-odd {\n  background-color: #f2f2f2;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.field-list.simple > dt.field-even {\n  background-color: #f2f2f2;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.data {\n  background-color: #fdfafa;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.data > dt {\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.attribute {\n  background-color: #fdfafa;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.attribute > dt {\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.function {\n  background-color: #fdfafa;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.function > dt {\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.function blockquote {\n  background-color: #f5f5f5;\n  border-left: 0px;\n}\n\n[data-theme=light] dl.py.class {\n  background-color: #fdfafa;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.class > dt {\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.method {\n  background-color: #f6f6f6;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.method > dt {\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.property {\n  background-color: #f6f6f6;\n  border-radius: 4px;\n}\n\n[data-theme=light] dl.py.property > dt {\n  border-radius: 4px;\n}\n\n\n/* Style for 
figure captions */\n\ndiv.figure p.caption span.caption-text,\nfigcaption span.caption-text {\n  font-size: var(--font-size--small);\n  margin-left: 5%;\n  margin-right: 5%;\n  display: inline-block;\n  text-align: justify;\n}\n"
  },
  {
    "path": "docs/source/_templates/autosummary/module.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. automodule:: {{ fullname }}\n\n   {% block attributes %}\n   {% if attributes %}\n   .. rubric:: {{ _('Module Attributes') }}\n\n   .. autosummary::\n   {% for item in attributes %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n\n   {% block modules %}\n   {% if modules %}\n   .. rubric:: Modules\n\n   .. autosummary::\n      :toctree:\n      :recursive:\n   {% for item in modules %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n\n   {% block functions %}\n   {% if functions %}\n   .. rubric:: {{ _('Functions') }}\n\n   .. autosummary::\n   {% for item in functions %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block classes %}\n   {% if classes %}\n   .. rubric:: {{ _('Classes') }}\n\n   .. autosummary::\n   {% for item in classes %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block exceptions %}\n   {% if exceptions %}\n   .. rubric:: {{ _('Exceptions') }}\n\n   .. autosummary::\n   {% for item in exceptions %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n"
  },
  {
    "path": "docs/source/_templates/package.rst",
    "content": "API Reference\n=============\n\n.. automodule:: {{ fullname }}\n\n   {% block modules %}\n   {% if modules %}\n   .. autosummary::\n     :toctree:\n     :recursive:\n   {% for item in modules %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n"
  },
  {
    "path": "docs/source/_templates/sidebar/brand.html",
    "content": "{#-\n\nHi there!\n\nYou might be interested in https://pradyunsg.me/furo/customisation/sidebar/\n\nAlthough if you're reading this, chances are that you're either familiar\nenough with Sphinx that you know what you're doing, or landed here from that\ndocumentation page.\n\nHope your day's going well. :)\n\n-#}\n<a class=\"sidebar-brand{% if logo %} centered{% endif %}\" href=\"{{ pathto(master_doc) }}\">\n  {% block brand_content %}\n  {%- if logo_url %}\n  <div class=\"sidebar-logo-container\">\n    <img class=\"sidebar-logo\" src=\"{{ logo_url }}\" alt=\"Logo\"/>\n  </div>\n  {%- endif %}\n  {%- if theme_light_logo and theme_dark_logo %}\n  <div class=\"sidebar-logo-container\">\n    <img class=\"sidebar-logo only-light\" src=\"{{ pathto('_static/' + theme_light_logo, 1) }}\" alt=\"Light Logo\"/>\n    <img class=\"sidebar-logo only-dark\" src=\"{{ pathto('_static/' + theme_dark_logo, 1) }}\" alt=\"Dark Logo\"/>\n  </div>\n  {%- endif %}\n  {% if not theme_sidebar_hide_name %}\n  <span class=\"sidebar-brand-text\">{{ version }}</span>\n  {%- endif %}\n  {% endblock brand_content %}\n</a>\n"
  },
  {
    "path": "docs/source/advantages.rst",
    "content": "Why SCICO?\n==========\n\nAdvantages of JAX-based Design\n------------------------------\n\nThe vast majority of scientific computing packages in Python are based\non `NumPy <https://numpy.org/>`__ and `SciPy <https://scipy.org/>`__.\nSCICO, in contrast, is based on `JAX\n<https://jax.readthedocs.io/en/latest/>`__, which provides most of the\nsame features, but with the addition of automatic differentiation, GPU\nsupport, and just-in-time (JIT) compilation. (The availability of\nthese features in SCICO is subject to some :ref:`caveats\n<non_jax_dep>`.) SCICO users and developers are advised to become\nfamiliar with the `differences between JAX and\nNumPy. <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html>`_.\n\nWhile recent advances in automatic differentiation have primarily been\ndriven by its important role in deep learning, it is also invaluable in\na functional minimization framework such as SCICO. The most obvious\nadvantage is allowing the use of gradient-based minimization methods\nwithout the need for tedious mathematical derivation of an expression\nfor the gradient. Equally valuable, though, is the ability to\nautomatically compute the adjoint operator of a linear operator, the\nmanual derivation of which is often time-consuming.\n\nGPU support and JIT compilation both offer the potential for significant\ncode acceleration, with the speed gains that can be obtained depending\non the algorithm/function to be executed. In many cases, a speed\nimprovement by an order of magnitude or more can be obtained by running\nthe same code on a GPU rather than a CPU, and similar speed gains can\nsometimes also be obtained via JIT compilation.\n\nThe figure below shows timing results obtained on a compute server\nwith an Intel Xeon Gold 6230 CPU and NVIDIA GeForce RTX 2080 Ti\nGPU. 
It is interesting to note that for :class:`.FiniteDifference` the\nGPU provides no acceleration, while JIT provides more than an order of\nmagnitude of speed improvement on both CPU and GPU. For :class:`.DFT`\nand :class:`.Convolve`, significant JIT acceleration is limited to the\nGPU, which also provides significant acceleration over the CPU.\n\n\n.. image:: /figures/jax-timing.png\n     :align: center\n     :width: 95%\n     :alt: Timing results for SCICO operators on CPU and GPU with and without JIT\n\n\nRelated Packages\n----------------\n\nMany elements of SCICO are partially available in other packages. We\nbriefly review them here, highlighting some of the main differences with\nSCICO.\n\n`GlobalBioIm <https://biomedical-imaging-group.github.io/GlobalBioIm/>`__\nis similar in structure to SCICO (and a major inspiration for SCICO),\nproviding linear operators and solvers for inverse problems in imaging.\nHowever, it is written in MATLAB and is thus not usable in a completely\nfree environment. It also lacks the automatic adjoint calculation and\nsimple GPU support offered by SCICO.\n\n`PyLops <https://pylops.readthedocs.io>`__ provides a linear operator\nclass and many built-in linear operators. These operators are compatible\nwith many `SciPy <https://scipy.org/>`__ solvers. GPU support is\nprovided via `CuPy <https://cupy.dev>`__, which has the disadvantage\nthat switching for a CPU to GPU requires code changes, unlike SCICO and\n`JAX <https://jax.readthedocs.io/en/latest/>`__. SCICO is more focused\non computational imaging that PyLops and has several specialized\noperators that PyLops does not.\n\n`Pycsou <https://matthieumeo.github.io/pycsou/html/index>`__, like\nSCICO, is a Python project inspired by GlobalBioIm. Since it is based on\nPyLops, it shares the disadvantages with respect to SCICO of that\nproject.\n\n`ODL <https://odlgroup.github.io/odl/>`__ provides a variety of\noperators and related infrastructure for prototyping of inverse\nproblems. 
It is built on top of\n`NumPy <https://numpy.org/>`__/`SciPy <https://scipy.org/>`__, and does\nnot support any of the advanced features of\n`JAX <https://jax.readthedocs.io/en/latest/>`__.\n\n`ProxImaL <http://www.proximal-lang.org/en/latest/>`__ is a Python\npackage for image optimization problems. Like SCICO and many of the\nother projects listed here, problems are specified by combining objects\nrepresenting, operators, functionals, and solvers. It does not support\nany of the advanced features of\n`JAX <https://jax.readthedocs.io/en/latest/>`__.\n\n`ProxMin <https://github.com/pmelchior/proxmin>`__ provides a set of\nproximal optimization algorithms for minimizing non-smooth functionals.\nIt is built on top of\n`NumPy <https://numpy.org/>`__/`SciPy <https://scipy.org/>`__, and does\nnot support any of the advanced features of\n`JAX <https://jax.readthedocs.io/en/latest/>`__ (however, an open issue\nsuggests that `JAX <https://jax.readthedocs.io/en/latest/>`__\ncompatibility is planned).\n\n`CVXPY <https://www.cvxpy.org>`__ provides a flexible language for\ndefining optimization problems and a wide selection of solvers, but has\nlimited support for matrix-free methods.\n\nOther related projects that may be of interest include:\n\n-  `ToMoBAR <https://github.com/dkazanc/ToMoBAR>`__\n-  `CCPi-Regularisation Toolkit <https://github.com/vais-ral/CCPi-Regularisation-Toolkit>`__\n-  `SPORCO <https://github.com/lanl/sporco>`__\n-  `SigPy <https://github.com/mikgroup/sigpy>`__\n-  `MIRT <https://github.com/JeffFessler/MIRT.jl>`__\n-  `BART <http://mrirecon.github.io/bart/>`__\n"
  },
  {
    "path": "docs/source/api.rst",
    "content": ":orphan:\n\nAPI Documentation\n=================\n\n.. autosummary::\n   :toctree: _autosummary\n   :template: package.rst\n   :caption: API Reference\n   :recursive:\n\n   scico\n"
  },
  {
    "path": "docs/source/classes.rst",
    "content": ".. _classes:\n\n******************\nMain SCICO Classes\n******************\n\n.. include:: include/blockarray.rst\n.. include:: include/operator.rst\n.. include:: include/functional.rst\n.. include:: include/optimizer.rst\n.. include:: include/learning.rst\n"
  },
  {
    "path": "docs/source/conf/10-project.py",
    "content": "from scico._version import package_version\n\n# General information about the project.\nproject = \"SCICO\"\ncopyright = \"2020-2026, SCICO Developers\"\n\n# The version info for the project you're documenting, acts as replacement for\n# |version| and |release|, also used in various other places throughout the\n# built documents.\n#\n# The short X.Y version.\nversion = package_version()\n# The full version, including alpha/beta/rc tags.\nrelease = version\n"
  },
  {
    "path": "docs/source/conf/15-theme.py",
    "content": "# -- Options for HTML output ----------------------------------------------\n\n# The theme to use for HTML and HTML Help pages. See the documentation for\n# a list of builtin themes.\n# html_theme = \"python_docs_theme\"\nhtml_theme = \"furo\"\n\nhtml_theme_options = {\n    \"top_of_page_buttons\": [],\n    # \"sidebar_hide_name\": True,\n}\n\nif html_theme == \"python_docs_theme\":\n    html_sidebars = {\n        \"**\": [\"globaltoc.html\", \"sourcelink.html\", \"searchbox.html\"],\n    }\n\n# These folders are copied to the documentation's HTML output\nhtml_static_path = [\"_static\"]\n\n# These paths are either relative to html_static_path or fully qualified\n# paths (eg. https://...)\nhtml_css_files = [\n    \"scico.css\",\n    \"http://netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css\",\n]\n\n# The name of an image file (relative to this directory) to place at the top\n# of the sidebar.\nhtml_logo = \"_static/logo.svg\"\n\n# The name of an image file (within the static path) to use as favicon of the\n# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32\n# pixels large.\nhtml_favicon = \"_static/scico.ico\"\n"
  },
  {
    "path": "docs/source/conf/20-extensions.py",
    "content": "# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx_autodoc_typehints\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.doctest\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.viewcode\",\n    \"sphinxcontrib.bibtex\",\n    \"sphinx.ext.inheritance_diagram\",\n    \"matplotlib.sphinxext.plot_directive\",\n    \"sphinx.ext.todo\",\n    \"nbsphinx\",\n]\n\n\nbibtex_bibfiles = [\"references.bib\"]\n"
  },
  {
    "path": "docs/source/conf/25-napoleon.py",
    "content": "from sphinx.ext.napoleon.docstring import GoogleDocstring\n\n\n## See\n##   https://github.com/sphinx-doc/sphinx/issues/2115\n##   https://michaelgoerz.net/notes/extending-sphinx-napoleon-docstring-sections.html\n##\n# first, we define new methods for any new sections and add them to the class\ndef parse_keys_section(self, section):\n    return self._format_fields(\"Keys\", self._consume_fields())\n\n\nGoogleDocstring._parse_keys_section = parse_keys_section\n\n\ndef parse_attributes_section(self, section):\n    return self._format_fields(\"Attributes\", self._consume_fields())\n\n\nGoogleDocstring._parse_attributes_section = parse_attributes_section\n\n\ndef parse_class_attributes_section(self, section):\n    return self._format_fields(\"Class Attributes\", self._consume_fields())\n\n\nGoogleDocstring._parse_class_attributes_section = parse_class_attributes_section\n\n\n# we now patch the parse method to guarantee that the the above methods are\n# assigned to the _section dict\ndef patched_parse(self):\n    self._sections[\"keys\"] = self._parse_keys_section\n    self._sections[\"class attributes\"] = self._parse_class_attributes_section\n    self._unpatched_parse()\n\n\nGoogleDocstring._unpatched_parse = GoogleDocstring._parse\nGoogleDocstring._parse = patched_parse\n\n\n# napoleon_include_init_with_doc = True\nnapoleon_use_ivar = True\nnapoleon_use_rtype = False\n\n# See https://github.com/sphinx-doc/sphinx/issues/9119\n# napoleon_custom_sections = [(\"Returns\", \"params_style\")]\n"
  },
  {
    "path": "docs/source/conf/30-autodoc.py",
    "content": "autodoc_default_options = {\n    \"member-order\": \"bysource\",\n    \"inherited-members\": False,\n    \"ignore-module-all\": False,\n    \"show-inheritance\": True,\n    \"members\": True,\n    \"special-members\": \"__call__\",\n}\nautodoc_docstring_signature = True\nautoclass_content = \"both\"\n\n\n# See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports\nautodoc_mock_imports = [\"astra\", \"svmbir\", \"ray\"]\n\n\n# See\n#  https://stackoverflow.com/questions/2701998#62613202\n#  https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion\nautosummary_generate = True\n\n\n# See https://stackoverflow.com/questions/5599254\nautoclass_content = \"both\"\n"
  },
  {
    "path": "docs/source/conf/40-intersphinx.py",
    "content": "# Intersphinx mapping\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/3/\", None),\n    \"numpy\": (\"https://numpy.org/doc/stable/\", None),\n    \"scipy\": (\"https://docs.scipy.org/doc/scipy/\", None),\n    \"matplotlib\": (\"https://matplotlib.org/stable/\", None),\n    \"jax\": (\"https://docs.jax.dev/en/latest/\", None),\n    \"flax\": (\"https://flax.readthedocs.io/en/latest/\", None),\n    \"ray\": (\"https://docs.ray.io/en/latest/\", None),\n    \"svmbir\": (\"https://svmbir.readthedocs.io/en/latest/\", None),\n}\n# Added timeout due to periodic scipy.org down time\n# intersphinx_timeout = 30\n"
  },
  {
    "path": "docs/source/conf/45-mathjax.py",
    "content": "import os\n\nif os.environ.get(\"NO_MATHJAX\"):\n    extensions.append(\"sphinx.ext.imgmath\")\n    imgmath_image_format = \"svg\"\nelse:\n    extensions.append(\"sphinx.ext.mathjax\")\n    # To use local copy of MathJax for offline use, set MATHJAX_URI to\n    #    file:///[path-to-mathjax-repo-root]/es5/tex-mml-chtml.js\n    if os.environ.get(\"MATHJAX_URI\"):\n        mathjax_path = os.environ.get(\"MATHJAX_URI\")\n\nmathjax3_config = {\n    \"tex\": {\n        \"macros\": {\n            \"mb\": [r\"\\mathbf{#1}\", 1],\n            \"mbs\": [r\"\\boldsymbol{#1}\", 1],\n            \"mbb\": [r\"\\mathbb{#1}\", 1],\n            \"norm\": [r\"\\lVert #1 \\rVert\", 1],\n            \"abs\": [r\"\\left| #1 \\right|\", 1],\n            \"argmin\": [r\"\\mathop{\\mathrm{argmin}}\"],\n            \"sign\": [r\"\\mathop{\\mathrm{sign}}\"],\n            \"prox\": [r\"\\mathrm{prox}\"],\n            \"det\": [r\"\\mathrm{det}\"],\n            \"exp\": [r\"\\mathrm{exp}\"],\n            \"loss\": [r\"\\mathop{\\mathrm{loss}}\"],\n            \"kp\": [r\"k_{\\|}\"],\n            \"rp\": [r\"r_{\\|}\"],\n        }\n    }\n}\n"
  },
  {
    "path": "docs/source/conf/50-graphviz.py",
    "content": "graphviz_output_format = \"svg\"\ninheritance_graph_attrs = dict(rankdir=\"LR\", fontsize=9, ratio=\"compress\", bgcolor=\"transparent\")\ninheritance_edge_attrs = dict(\n    color='\"#2962ffff\"',\n)\ninheritance_node_attrs = dict(\n    shape=\"box\",\n    fontsize=9,\n    height=0.4,\n    margin='\"0.08, 0.03\"',\n    style='\"rounded,filled\"',\n    color='\"#2962ffff\"',\n    fontcolor='\"#2962ffff\"',\n    fillcolor='\"#f0f0f8b0\"',\n)\n"
  },
  {
    "path": "docs/source/conf/55-nbsphinx.py",
    "content": "nbsphinx_prolog = \"\"\"\n.. raw:: html\n\n    <style>\n    .nbinput .prompt, .nboutput .prompt {\n        display: none;\n    }\n    div.highlight {\n        background-color: #f9f9f4;\n    }\n    p {\n        margin-bottom: 0.8em;\n        margin-top: 0.8em;\n    }\n    </style>\n\"\"\"\n\nnbsphinx_execute = \"never\"\n"
  },
  {
    "path": "docs/source/conf/60-rtd.py",
    "content": "import os\n\non_rtd = os.environ.get(\"READTHEDOCS\") == \"True\"\n\n\nif on_rtd:\n    print(\"Building on ReadTheDocs\\n\")\n    print(\"  current working directory: {}\".format(os.path.abspath(os.curdir)))\n    print(\"  rootpath: %s\" % rootpath)\n    print(\"  confpath: %s\" % confpath)\n\n    html_static_path = []\n\n    # See https://about.readthedocs.com/blog/2024/07/addons-by-default/#how-to-opt-in-to-addons-now\n    html_baseurl = os.environ.get(\"READTHEDOCS_CANONICAL_URL\", \"\")\n    if \"html_context\" not in globals():\n        html_context = {}\n    html_context[\"READTHEDOCS\"] = True\n\n    import matplotlib\n\n    matplotlib.use(\"agg\")\n\nelse:\n    # Add any paths that contain custom static files (such as style sheets) here,\n    # relative to this directory. They are copied after the builtin static files,\n    # so a file named \"default.css\" will overwrite the builtin \"default.css\".\n    html_static_path = [\"_static\"]\n"
  },
  {
    "path": "docs/source/conf/70-latex.py",
    "content": "# -- Options for LaTeX output ---------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #'papersize': 'letterpaper',\n    # The font size ('10pt', '11pt' or '12pt').\n    #'pointsize': '10pt',\n    # Additional stuff for the LaTeX preamble.\n    #'preamble': '',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (\"index\", \"scico.tex\", \"SCICO Documentation\", \"The SCICO Developers\", \"manual\"),\n]\n\nlatex_engine = \"xelatex\"\n\n# latex_use_xindy = False\n\n\n# mathjax3_config must already be defined\nlatex_macros = []\nfor k, v in mathjax3_config[\"tex\"][\"macros\"].items():\n    if len(v) == 1:\n        latex_macros.append(r\"\\newcommand{\\%s}{%s}\" % (k, v[0]))\n    else:\n        latex_macros.append(r\"\\newcommand{\\%s}[1]{%s}\" % (k, v[0]))\n\nimgmath_latex_preamble = \"\\n\".join(latex_macros)\n\nlatex_elements = {\"preamble\": \"\\n\".join(latex_macros)}\n"
  },
  {
    "path": "docs/source/conf/71-texinfo.py",
    "content": "# -- Options for Texinfo output -------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (\n        \"index\",\n        \"SCICO\",\n        \"SCICO Documentation\",\n        \"SCICO Developers\",\n        \"SCICO\",\n        \"Scientific Computational Imaging COde (SCICO)\",\n        \"Miscellaneous\",\n    ),\n]\n"
  },
  {
    "path": "docs/source/conf/72-man_page.py",
    "content": "# -- Options for manual page output ---------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(\"index\", \"scico\", \"SCICO Documentation\", [\"SCICO Developers\"], 1)]\n\n# If true, show URL addresses after external links.\n# man_show_urls = False\n"
  },
  {
    "path": "docs/source/conf/80-scico_numpy.py",
    "content": "import re\nfrom inspect import getmembers, isfunction\n\n# Rewrite module names for certain functions imported into scico.numpy so that they are\n# included in the docs for that module. While a bit messy to do so here rather than in a\n# function run via app.connect, it is necessary (for some yet to be identified reason)\n# to do it here to ensure that the relevant API docs include a table of functions.\nimport scico.numpy\n\nfor module in (scico.numpy, scico.numpy.fft, scico.numpy.linalg, scico.numpy.testing):\n    for _, f in getmembers(module, isfunction):\n        # Rewrite module name so that function is included in docs\n        f.__module__ = module.__name__\n        f.__doc__ = re.sub(\n            r\"^:func:`([\\w_]+)` wrapped to operate\",\n            r\":obj:`jax.numpy.\\1` wrapped to operate\",\n            str(f.__doc__),\n            flags=re.M,\n        )\n        modname = \".\".join(module.__name__.split(\".\")[1:])\n        f.__doc__ = re.sub(\n            r\"^LAX-backend implementation of :func:`([\\w_]+)`.\",\n            r\"LAX-backend implementation of :obj:`%s.\\1`.\" % modname,\n            str(f.__doc__),\n            flags=re.M,\n        )\n        # Improve formatting of jax.numpy warning\n        f.__doc__ = re.sub(\n            r\"^\\*\\*\\* This function is not yet implemented by jax.numpy, and will \"\n            r\"raise NotImplementedError \\*\\*\\*\",\n            \"**WARNING**: This function is not yet implemented by jax.numpy, \"\n            \" and will raise :exc:`NotImplementedError`.\",\n            f.__doc__,\n            flags=re.M,\n        )\n        # Remove cross-references to section NEP35\n        f.__doc__ = re.sub(\":ref:`NEP 35 <NEP35>`\", \"NEP 35\", f.__doc__, re.M)\n        # Remove cross-reference to numpydoc style references section\n        f.__doc__ = re.sub(r\" \\[(\\d+)\\]_\", \"\", f.__doc__, flags=re.M)\n        # Remove entire numpydoc references section\n        f.__doc__ = 
re.sub(r\"References\\n----------\\n.*\\n\", \"\", f.__doc__, flags=re.DOTALL)\n\n\n# Fix various docstring formatting errors\nscico.numpy.testing.break_cycles.__doc__ = re.sub(\n    \"calling gc.collect$\",\n    \"calling gc.collect.\\n\\n\",\n    scico.numpy.testing.break_cycles.__doc__,\n    flags=re.M,\n)\nscico.numpy.testing.break_cycles.__doc__ = re.sub(\n    r\" __del__\\) inside\", r\"__del__\\) inside\", scico.numpy.testing.break_cycles.__doc__, flags=re.M\n)\nscico.numpy.testing.assert_raises_regex.__doc__ = re.sub(\n    r\"\\*args,\\n.*\\*\\*kwargs\",\n    \"*args, **kwargs\",\n    scico.numpy.testing.assert_raises_regex.__doc__,\n    flags=re.M,\n)\nscico.numpy.BlockArray.global_shards.__doc__ = re.sub(\n    r\"`Shard`s\", r\"`Shard`\\ s\", scico.numpy.BlockArray.global_shards.__doc__, flags=re.M\n)\n"
  },
  {
    "path": "docs/source/conf/81-scico_scipy.py",
    "content": "import re\nfrom inspect import getmembers, isfunction\n\n# Similar processing for scico.scipy\nimport scico.scipy\n\nssp_func = getmembers(scico.scipy.special, isfunction)\nfor _, f in ssp_func:\n    if f.__module__[0:11] == \"scico.scipy\" or f.__module__[0:14] == \"jax._src.scipy\":\n        # Rewrite module name so that function is included in docs\n        f.__module__ = \"scico.scipy.special\"\n        # Attempt to fix incorrect cross-reference\n        f.__doc__ = re.sub(\n            r\"^:func:`([\\w_]+)` wrapped to operate\",\n            r\":obj:`jax.scipy.special.\\1` wrapped to operate\",\n            str(f.__doc__),\n            flags=re.M,\n        )\n        modname = \"scipy.special\"\n        f.__doc__ = re.sub(\n            r\"^LAX-backend implementation of :func:`([\\w_]+)`.\",\n            r\"LAX-backend implementation of :obj:`%s.\\1`.\" % modname,\n            str(f.__doc__),\n            flags=re.M,\n        )\n        # Remove cross-reference to numpydoc style references section\n        f.__doc__ = re.sub(r\"(^|\\ )\\[(\\d+)\\]_\", \"\", f.__doc__, flags=re.M)\n        # Remove entire numpydoc references section\n        f.__doc__ = re.sub(r\"References\\n----------\\n.*\\n\", \"\", f.__doc__, flags=re.DOTALL)\n        # Remove problematic citation\n        f.__doc__ = re.sub(r\"See \\[dlmf\\]_ for details.\", \"\", f.__doc__, re.M)\n        f.__doc__ = re.sub(r\"\\[dlmf\\]_\", \"NIST DLMF\", f.__doc__, re.M)\n\n# Fix indentation problems\nif hasattr(scico.scipy.special, \"sph_harm\"):\n    scico.scipy.special.sph_harm.__doc__ = re.sub(\n        \"^Computes the\", \"  Computes the\", scico.scipy.special.sph_harm.__doc__, flags=re.M\n    )\n"
  },
  {
    "path": "docs/source/conf/85-dtype_typehints.py",
    "content": "from typing import Optional, Sequence, Union  # needed for typehints_formatter hack\n\nfrom scico.typing import (  # needed for typehints_formatter hack\n    ArrayIndex,\n    AxisIndex,\n    DType,\n)\n\n\n# An explanation for this nasty hack, the primary purpose of which is to avoid\n# the very long definition of the scico.typing.DType appearing explicitly in the\n# docs. This is handled correctly by sphinx.ext.autodoc in some circumstances,\n# but only when sphinx_autodoc_typehints is not included in the extension list,\n# and the appearance of the type hints (e.g. whether links to definitions are\n# included) seems to depend on whether \"from __future__ import annotations\" was\n# used in the module being documented, which is not ideal from a consistency\n# perspective. (It's also worth noting that sphinx.ext.autodoc provides some\n# configurability for type aliases via the autodoc_type_aliases sphinx\n# configuration option.) The alternative is to include sphinx_autodoc_typehints,\n# which gives a consistent appearance to the type hints, but the\n# autodoc_type_aliases configuration option is ignored, and type aliases are\n# always expanded. This hack avoids expansion for the type aliases with the\n# longest definitions by defining a custom function for formatting the\n# type hints, using an option provided by sphinx_autodoc_typehints. 
For\n# more information, see\n#   https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases\n#   https://github.com/tox-dev/sphinx-autodoc-typehints/issues/284\n#   https://github.com/tox-dev/sphinx-autodoc-typehints/blob/main/README.md\ndef typehints_formatter_function(annotation, config):\n    markup = {\n        DType: \":obj:`~scico.typing.DType`\",\n        # Compound types involving DType must be added here to avoid their DType\n        # component being expanded in the docs.\n        Optional[DType]: r\":obj:`~typing.Optional`\\ [\\ :obj:`~scico.typing.DType`\\ ]\",\n        Union[DType, Sequence[DType]]: (\n            r\":obj:`~typing.Union`\\ [\\ :obj:`~scico.typing.DType`\\ , \"\n            r\":obj:`~typing.Sequence`\\ [\\ :obj:`~scico.typing.DType`\\ ]]\"\n        ),\n        AxisIndex: \":obj:`~scico.typing.AxisIndex`\",\n        ArrayIndex: \":obj:`~scico.typing.ArrayIndex`\",\n    }\n    if annotation in markup:\n        return markup[annotation]\n    else:\n        return None\n\n\ntypehints_formatter = typehints_formatter_function\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport os\nimport sys\n\nconfpath = os.path.dirname(__file__)\nsys.path.append(confpath)\nrootpath = os.path.realpath(os.path.join(confpath, \"..\", \"..\"))\nsys.path.append(rootpath)\n\nfrom docsutil import insert_inheritance_diagram, package_classes, run_conf_files\n\n# Process settings in files in conf directory\n_vardict = run_conf_files(vardict={\"confpath\": confpath, \"rootpath\": rootpath})\nfor _k, _v in _vardict.items():\n    globals()[_k] = _v\ndel _vardict, _k, _v\n\n\n# If your documentation needs a minimal Sphinx version, state it here.\nneeds_sphinx = \"5.0.0\"\n\n# The suffix of source filenames.\nsource_suffix = \".rst\"\n\n# The encoding of source files.\nsource_encoding = \"utf-8\"\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = \"SCICOdoc\"\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\nexclude_patterns = [\"_build\", \"**tests**\", \"**README.rst\", \"include\"]\n\n# If true, '()' will be appended to :func: etc. 
cross-reference text.\nadd_function_parentheses = False\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"sphinx\"\n\n# Include TODOs\ntodo_include_todos = True\n\n\ndef class_inherit_diagrams(_):\n    # Insert inheritance diagrams for classes that have base classes\n    import scico\n\n    custom_parts = {\"scico.ray.tune.Tuner\": 4}\n    clslst = package_classes(scico)\n    for cls in clslst:\n        insert_inheritance_diagram(cls, parts=custom_parts)\n\n\ndef process_docstring(app, what, name, obj, options, lines):\n    # Don't show docs for inherited members in classes in scico.flax.\n    # This is primarily useful for silencing warnings due to problems in\n    # the current release of flax, but is arguably also useful in avoiding\n    # extensive documentation of methods that are likely to be of limited\n    # interest to users of the scico.flax classes.\n    #\n    # Note: this event handler currently has no effect since inclusion of\n    #   inherited members is currently globally disabled (see\n    #   \"inherited-members\" in autodoc_default_options), but is left in\n    #   place in case a decision is ever made to revert the global setting.\n    #\n    # See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html\n    # for documentation of the autodoc-process-docstring event used here.\n    if what == \"class\" and \"scico.flax.\" in name:\n        options[\"inherited-members\"] = False\n\n\ndef setup(app):\n    app.connect(\"builder-inited\", class_inherit_diagrams)\n    app.connect(\"autodoc-process-docstring\", process_docstring)\n"
  },
  {
    "path": "docs/source/contributing.rst",
    "content": ".. _scico_dev_contributing:\n\nContributing\n============\n\nContributions to SCICO are welcome. Before starting work, please\ncontact the maintainers, either via email or the GitHub issue system,\nto discuss the relevance of your contribution and the most appropriate\nlocation within the existing package structure.\n\n\n.. _installing_dev:\n\nInstalling a Development Version\n--------------------------------\n\n1. Fork both the ``scico`` and ``scico-data`` repositories, creating\n   copies of these repositories in your own git account.\n\n2. Make sure that you have Python 3.10 or later installed in order to\n   create a conda virtual environment.\n\n3. Clone your fork from the source repo.\n\n   ::\n\n      git clone --recurse-submodules git@github.com:<username>/scico.git\n\n4. Create a conda environment using Python 3.10 or later, e.g.:\n\n   ::\n\n      conda create -n scico python=3.12\n\n5. Activate the created conda virtual environment:\n\n   ::\n\n      conda activate scico\n\n6. Change directory to the root of the cloned repository:\n\n   ::\n\n      cd scico\n\n7. Add the ``scico`` repo as an upstream remote to sync your changes:\n\n   ::\n\n      git remote add upstream https://www.github.com/lanl/scico\n\n8. After adding the upstream, the recommended way to install SCICO and\n   its dependencies is via pip:\n\n   ::\n\n      pip install -r requirements.txt  # Installs basic requirements\n      pip install -r dev_requirements.txt  # Installs developer requirements\n      pip install -r docs/docs_requirements.txt # Installs documentation requirements\n      pip install -e .  # Installs SCICO from the current directory in editable mode\n\n   For installing dependencies related to the examples, please see :ref:`example_notebooks`.\n   Installing these is necessary for the tests to run successfully.\n\n9. 
The SCICO project uses the `black\n   <https://black.readthedocs.io/en/stable/>`_, `isort\n   <https://pypi.org/project/isort/>`_ and `pylint\n   <https://pylint.pycqa.org/en/latest/>`_ code formatting\n   and linting utilities. It is important to set up a `pre-commit hook\n   <https://pre-commit.com>`_ to ensure that any modified code passes\n   format check before it is committed to the development repo:\n\n   ::\n\n      pre-commit install  # Sets up git pre-commit hooks\n\n   It is also recommended to `pin the conda package version\n   <https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-pkgs.html#preventing-packages-from-updating-pinning>`__\n   of `black <https://black.readthedocs.io/en/stable/>`_ to the version\n   number specified in ``dev_requirements.txt``.\n\n10. For testing see `Tests`_.\n\n\n\nBuilding Documentation\n----------------------\n\nBefore building the documentation, one must install the\ndocumentation-specific dependencies by running\n\n::\n\n   pip install -r docs/docs_requirements.txt\n\nThen, a local copy of the documentation can be built from the\nrepository root directory by running\n\n::\n\n  python setup.py build_sphinx\n\n\nAlternatively, one can also build the documentation by running the\nfollowing from the ``docs/`` directory:\n\n::\n\n   make html\n\n\n\nContributing Code\n-----------------\n\n- New features / bugs / documentation are *always* developed in separate branches.\n- Branches should be named in the form\n  `<username>/<brief-description>`, where `<brief-description>`\n  provides a highly condensed description of the purpose of the branch\n  (e.g. `address_todo`), and may include an issue number if\n  appropriate (e.g. `fix_223`).\n\n\nA feature development workflow might look like this:\n\n\n1. Follow the instructions in `Installing a Development Version`_.\n\n2. Sync with the upstream repository:\n\n   ::\n\n      git pull --rebase origin main --recurse-submodules\n\n3. 
Create a branch to develop from:\n\n   ::\n\n      git checkout -b <username>/<brief-description>\n\n4. Make your desired changes.\n\n5. Run the test suite:\n\n   ::\n\n      pytest\n\n   You can limit the test suite to a specific file, for example:\n\n   ::\n\n      pytest scico/test/test_blockarray.py\n\n6. When you are finished making changes, create a new commit:\n\n   ::\n\n      git add file1.py file2.py\n      git commit -m \"A good commit message\"\n\n   If you have added or modified an example script, see `Usage Examples`_.\n   If your contribution involves any significant new features or changes,\n   add a corresponding entry to the change summary for the next release\n   in the ``CHANGES.rst`` file.\n\n7. Sync with the upstream repository:\n\n   ::\n\n      git fetch upstream\n      git rebase upstream/main\n\n8. Push your development branch to your fork:\n\n   ::\n\n      git push --set-upstream origin <username>/<brief-description>\n\n9. Create a new pull request to the ``main`` branch; see `the GitHub instructions <https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request>`_.\n\n10. The SCICO maintainers will review and merge your PR.\n    The SCICO project recommends the ``squash and merge`` option for merging PRs.\n\n11. Delete the branch after it has been merged.\n\n\nAdding Data\n-----------\n\nThe following steps show how to add new data, ``new_data.npz``, to the\npackaged data. We assume the ``scico`` repository has been cloned to\n``scico/``. Note that the data is located in the ``scico-data``\nsubmodule, which is attached to the main ``scico`` repository via the\ndirectory ``scico/data`` (i.e. the ``data/`` subdirectory of the\nrepository root directory, *not* the ``scico/data`` subdirectory of\nthe repository root directory). When adding new data, both the\n``scico`` and ``scico-data`` repositories must be updated and kept in\nsync.\n\n\n1. 
Create new branches in the main ``scico`` repository as well as in\n   the submodule corresponding to the ``scico-data`` repository (which\n   can be achieved by following the usual branch creation procedure\n   after changing the current directory to ``scico/data``).\n\n2. Add the ``new_data.npz`` file to the appropriate subdirectory\n   (creating a new one if necessary) of the ``scico/data`` directory.\n\n3. Change directory to this directory (taken to be ``scico/data/flax``\n   for the purposes of this example) and add/commit the new data file:\n\n   ::\n\n      cd scico/data/flax\n      git add new_data.npz\n      git commit -m \"Add new data file\"\n\n4. Return to the ``scico`` repository root directory, add/commit the\n   new data, and update submodule:\n\n   ::\n\n      cd ../..  # pwd now `scico` repo root\n      git add data\n      git commit -m \"Add data and update data module\"\n\n5. Push both repositories:\n\n   ::\n\n      git submodule foreach --recursive 'git push' && git push\n\n\n\nType Checking\n-------------\n\nAll code is required to pass ``mypy`` type checking.\n\nInstall ``mypy``:\n\n::\n\n   conda install mypy\n\nTo run the type checker, execute the following from the scico repository root:\n\n::\n\n   mypy --follow-imports=skip --ignore-missing-imports  --exclude \"(numpy|test)\" scico/\n\n\n\nTests\n-----\n\nAll functions and classes should have corresponding ``pytest`` unit tests.\n\n\nRunning Tests\n^^^^^^^^^^^^^\n\n\nTo be able to run the tests, install ``pytest`` and, optionally,\n``pytest-runner``:\n\n::\n\n    conda install pytest pytest-runner\n\nThe tests can be run by\n\n::\n\n    pytest\n\nor (if ``pytest-runner`` is installed)\n\n::\n\n    python setup.py test\n\nfrom the ``scico`` repository root directory. 
Tests can be run in an installed\nversion of ``scico`` by\n\n::\n\n   pytest --pyargs scico\n\nWhen any significant changes are made to the test suite, the ``pytest-split`` test\ntime database files in ``data/pytest`` should be updated using\n\n::\n\n   pytest --store-durations --durations-path data/pytest/durations_ubuntu --level 2\n\n(for Ubuntu CI), and\n\n::\n\n   pytest --store-durations --durations-path data/pytest/durations_macos --level 1\n\n(for MacOS CI). These updated files should be bzipped and committed into the\n``scico-data`` repository, replacing the current versions.\n\n\nTest Coverage\n^^^^^^^^^^^^^\n\nTest coverage is a measure of the fraction of the package code that is\nexercised by the tests. While this should not be the primary criterion\nin designing tests, it is a useful tool for finding obvious areas of\nomission.\n\nTo be able to check test coverage, install ``coverage``:\n\n::\n\n    conda install coverage\n\nA coverage report can be obtained by\n\n::\n\n    coverage run\n    coverage report\n\n\n\n\n\nUsage Examples\n--------------\n\nNew usage examples should adhere to the same general structure as the\nexisting examples to ensure that the mechanism for automatically\ngenerating corresponding Jupyter notebooks functions correctly. In\nparticular:\n\n1. The initial lines of the script should consist of a comment block,\n   followed by a blank line, followed by a multiline string with an\n   RST heading on the first line, e.g.,\n\n   ::\n\n     #!/usr/bin/env python\n     # -*- coding: utf-8 -*-\n     # This file is part of the SCICO package. Details of the copyright\n     # and user license can be found in the 'LICENSE.txt' file distributed\n     # with the package.\n\n     \"\"\"\n     Script Title\n     ============\n\n     Script description.\n     \"\"\"\n\n2. 
The final line of the script is an ``input`` statement intended to\n   avoid the script terminating immediately, thereby closing all\n   figures:\n\n   ::\n\n     input(\"\\nWaiting for input to close figures and exit\")\n\n3. Citations are included using the standard `Sphinx\n   <https://www.sphinx-doc.org/en/master/>`__ ``:cite:`cite-key```\n   syntax, where ``cite-key`` is the key of an entry in\n   ``docs/source/references.bib``.\n\n4. Cross-references to other components of the documentation are\n   included using the syntax described in the `nbsphinx documentation\n   <https://nbsphinx.readthedocs.io/en/latest/markdown-cells.html#Links-to-*.rst-Files-(and-Other-Sphinx-Source-Files)>`__.\n\n5. External links are included using Markdown syntax ``[link text](url)``.\n\n6. When constructing a synthetic image/volume for use in the example,\n   define a global variable `N` that controls the size of the problem,\n   and where relevant, define a global variable `maxiter` that\n   controls the number of iterations of optimization algorithms such\n   as ADMM. Adhering to this convention allows the\n   ``examples/scriptcheck.sh`` utility to automatically construct less\n   computationally expensive versions of the example scripts for\n   testing that they run without any errors.\n\n\nAdding new examples\n^^^^^^^^^^^^^^^^^^^\n\nThe following steps show how to add a new example, ``new_example.py``,\nto the packaged usage examples. We assume the ``scico`` repository has\nbeen cloned to ``scico/``.\n\nNote that the ``.py`` scripts are included in\n``scico/examples/scripts``, while the compiled Jupyter Notebooks are\nlocated in the scico-data submodule, which is symlinked to\n``scico/data``. When adding a new usage example, both the ``scico``\nand ``scico-data`` repositories must be updated and kept in sync.\n\n.. warning:: Ensure that all binary data (including raw data, images,\n   ``.ipynb`` files) are added to ``scico-data``, not the main\n   ``scico`` repo.\n\n\n1. 
Create new branches in the main ``scico`` repository as well as in\n   the submodule corresponding to the ``scico-data`` repository (which\n   can be achieved by following the usual branch creation procedure\n   after changing the current directory to ``scico/data``).\n\n2. Add the ``new_example.py`` script to the ``scico/examples/scripts`` directory.\n\n3. Add the basename of the script (i.e., without the pathname; in this\n   case, ``new_example.py``) to the appropriate section of\n   ``examples/scripts/index.rst``.\n\n4. Convert your new example to a Jupyter notebook by changing\n   directory to the ``scico/examples`` directory and following the\n   instructions in ``scico/examples/README.rst``.\n\n5. Change directory to the ``data`` directory and add/commit the new\n   Jupyter Notebook:\n\n   ::\n\n      cd scico/data\n      git add notebooks/new_example.ipynb\n      git commit -m \"Add new usage example\"\n\n6. Return to the main ``scico`` repository root directory, ensure the\n   branch created in step 1 is checked out, add/commit the new script and\n   updated submodule:\n\n   ::\n\n      cd ..  # pwd now `scico` repo root\n      git add data\n      git add examples/scripts/new_example.py\n      git commit -m \"Add usage example and update data module\"\n\n7. Push both repositories:\n\n   ::\n\n      git submodule foreach --recursive 'git push' && git push\n"
  },
  {
    "path": "docs/source/docsutil.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\n\"\"\"Utilities for building docs.\"\"\"\n\nimport importlib\nimport inspect\nimport os\nimport pkgutil\nimport sys\nfrom glob import glob\nfrom runpy import run_path\n\n\ndef run_conf_files(vardict=None, path=None):\n    \"\"\"Execute Python files in conf directory.\n\n    Args:\n        vardict: Dictionary into which variable names should be inserted.\n            Defaults to empty dict.\n        path: Path to the directory containing the conf directory.\n            Defaults to the directory containing this module.\n\n    Returns:\n        A dict populated with variables defined during execution of the\n        configuration files.\n    \"\"\"\n    if vardict is None:\n        vardict = {}\n    if path is None:\n        path = os.path.dirname(__file__)\n\n    files = os.path.join(path, \"conf\", \"*.py\")\n    for f in sorted(glob(files)):\n        conf = run_path(f, init_globals=vardict)\n        for k, v in conf.items():\n            if len(k) >= 4 and k[0:2] == \"__\" and k[-2:] == \"__\":  # ignore __<name>__ variables\n                continue\n            vardict[k] = v\n    return vardict\n\n\ndef package_classes(package):\n    \"\"\"Get a list of classes in a package.\n\n    Return a list of qualified names of classes in the specified\n    package. Classes in modules with names beginning with an \"_\" are\n    omitted, as are classes whose internal module name record is not\n    the same as the module in which they are found (i.e. 
indicating\n    that they have been imported from elsewhere).\n\n    Args:\n        package: Reference to package for which classes are to be listed\n          (not package name string).\n\n    Returns:\n        A list of qualified names of classes in the specified package.\n    \"\"\"\n\n    classes = []\n    # Iterate over modules in package\n    for importer, modname, _ in pkgutil.walk_packages(\n        path=package.__path__, prefix=(package.__name__ + \".\"), onerror=lambda x: None\n    ):\n        # Skip modules whose names begin with a \"_\"\n        if modname.split(\".\")[-1][0] == \"_\":\n            continue\n        importlib.import_module(modname)\n        # Iterate over module members\n        for name, obj in inspect.getmembers(sys.modules[modname]):\n            if inspect.isclass(obj):\n                # Get internal module name of class for comparison with working module name\n                try:\n                    objmodname = getattr(sys.modules[modname], obj.__name__).__module__\n                except Exception:\n                    objmodname = None\n                if objmodname == modname:\n                    classes.append(modname + \".\" + obj.__name__)\n\n    return classes\n\n\ndef get_text_indentation(text, skiplines=0):\n    \"\"\"Compute the leading whitespace indentation in a block of text.\n\n    Args:\n        text: A block of text as a string.\n        skiplines: Number of initial lines to exclude from the computation.\n\n    Returns:\n        Indentation length, or None if text does not have more than\n        skiplines lines.\n    \"\"\"\n    min_indent = len(text)\n    lines = text.splitlines()\n    if len(lines) > skiplines:\n        lines = lines[skiplines:]\n    else:\n        return None\n    for line in lines:\n        if len(line) > 0:\n            indent = len(line) - len(line.lstrip())\n            if indent < min_indent:\n                min_indent = indent\n    return min_indent\n\n\ndef add_text_indentation(text, indent):\n    \"\"\"Insert leading whitespace into a block of text.\n\n    Args:\n        text: A block of text as a string.\n        indent: 
Number of leading spaces to insert on each line.\n\n    Returns:\n        Text with additional indentation.\n    \"\"\"\n    lines = text.splitlines()\n    for n, line in enumerate(lines):\n        if len(line) > 0:\n            lines[n] = (\" \" * indent) + line\n    return \"\\n\".join(lines)\n\n\ndef insert_inheritance_diagram(clsqname, parts=None, default_nparts=2):\n    \"\"\"Insert an inheritance diagram into a class docstring.\n\n    No action is taken for classes without a base class, and for classes\n    without a docstring.\n\n    Args:\n        clsqname: Qualified name (i.e. including module name path) of class.\n        parts: A dict mapping qualified class names to custom values for\n          the \":parts:\" directive.\n        default_nparts: Default value for the \":parts:\" directive.\n    \"\"\"\n\n    # Extract module name and class name from qualified class name\n    clspth = clsqname.split(\".\")\n    modname = \".\".join(clspth[0:-1])\n    clsname = clspth[-1]\n    # Get reference to class\n    cls = getattr(sys.modules[modname], clsname)\n    # Return immediately if class has no base classes\n    if getattr(cls, \"__bases__\") == (object,):\n        return\n    # Get current docstring\n    docstr = getattr(cls, \"__doc__\")\n    # Return immediately if class has no docstring\n    if docstr is None:\n        return\n    # Use class-specific parts or default parts directive value\n    if parts and clsqname in parts:\n        nparts = parts[clsqname]\n    else:\n        nparts = default_nparts\n    # Split docstring into individual lines\n    lines = docstr.splitlines()\n    # Return immediately if there are no lines\n    if not lines:\n        return\n    # Cut leading whitespace lines\n    n = 0\n    for n, line in enumerate(lines):\n        if line != \"\":\n            break\n    lines = lines[n:]\n    # Define inheritance diagram insertion text\n    idstr = f\"\"\"\n\n    .. 
inheritance-diagram:: {clsname}\n       :parts: {nparts}\n\n\n    \"\"\"\n    docstr_indent = get_text_indentation(docstr, skiplines=1)\n    if docstr_indent is not None and docstr_indent > 4:\n        idstr = add_text_indentation(idstr, docstr_indent - 4)\n    # Insert inheritance diagram after summary line and whitespace line following it\n    lines.insert(2, idstr)\n    # Construct new docstring and attach it to the class\n    extdocstr = \"\\n\".join(lines)\n    setattr(cls, \"__doc__\", extdocstr)\n"
  },
  {
    "path": "docs/source/examples.rst",
    "content": ".. _example_notebooks:\n\nUsage Examples\n==============\n\n.. toctree::\n   :maxdepth: 1\n\n.. include:: include/examplenotes.rst\n\n\nOrganized by Application\n------------------------\n\n.. toctree::\n   :maxdepth: 1\n\n\nComputed Tomography\n^^^^^^^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_abel_tv_admm\n   examples/ct_abel_tv_admm_tune\n   examples/ct_symcone_tv_padmm\n   examples/ct_astra_noreg_pcg\n   examples/ct_astra_3d_tv_admm\n   examples/ct_astra_3d_tv_padmm\n   examples/ct_tv_admm\n   examples/ct_astra_tv_admm\n   examples/ct_multi_tv_admm\n   examples/ct_astra_weighted_tv_admm\n   examples/ct_svmbir_tv_multi\n   examples/ct_svmbir_ppp_bm3d_admm_cg\n   examples/ct_svmbir_ppp_bm3d_admm_prox\n   examples/ct_fan_svmbir_ppp_bm3d_admm_prox\n   examples/ct_modl_train_foam2\n   examples/ct_odp_train_foam2\n   examples/ct_unet_train_foam2\n   examples/ct_projector_comparison_2d\n   examples/ct_projector_comparison_3d\n\nDeconvolution\n^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/deconv_circ_tv_admm\n   examples/deconv_tv_admm\n   examples/deconv_tv_padmm\n   examples/deconv_tv_admm_tune\n   examples/deconv_microscopy_tv_admm\n   examples/deconv_microscopy_allchn_tv_admm\n   examples/deconv_ppp_bm3d_admm\n   examples/deconv_ppp_bm3d_apgm\n   examples/deconv_ppp_dncnn_admm\n   examples/deconv_ppp_dncnn_padmm\n   examples/deconv_ppp_bm4d_admm\n   examples/deconv_modl_train_foam1\n   examples/deconv_odp_train_foam1\n\n\nSparse Coding\n^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/sparsecode_nn_admm\n   examples/sparsecode_nn_apgm\n   examples/sparsecode_conv_admm\n   examples/sparsecode_conv_md_admm\n   examples/sparsecode_apgm\n   examples/sparsecode_poisson_apgm\n\n\nMiscellaneous\n^^^^^^^^^^^^^\n\n.. 
toctree::\n   :maxdepth: 1\n\n   examples/demosaic_ppp_bm3d_admm\n   examples/superres_ppp_dncnn_admm\n   examples/denoise_l1tv_admm\n   examples/denoise_ptv_pdhg\n   examples/denoise_tv_admm\n   examples/denoise_tv_apgm\n   examples/denoise_tv_multi\n   examples/denoise_approx_tv_multi\n   examples/denoise_cplx_tv_nlpadmm\n   examples/denoise_cplx_tv_pdhg\n   examples/denoise_dncnn_universal\n   examples/diffusercam_tv_admm\n   examples/video_rpca_admm\n   examples/ct_datagen_foam2\n   examples/deconv_datagen_bsds\n   examples/deconv_datagen_foam1\n   examples/denoise_datagen_bsds\n\n\nOrganized by Regularization\n---------------------------\n\n.. toctree::\n   :maxdepth: 1\n\nPlug and Play Priors\n^^^^^^^^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_svmbir_ppp_bm3d_admm_cg\n   examples/ct_svmbir_ppp_bm3d_admm_prox\n   examples/ct_fan_svmbir_ppp_bm3d_admm_prox\n   examples/deconv_ppp_bm3d_admm\n   examples/deconv_ppp_bm3d_apgm\n   examples/deconv_ppp_dncnn_admm\n   examples/deconv_ppp_dncnn_padmm\n   examples/deconv_ppp_bm4d_admm\n   examples/demosaic_ppp_bm3d_admm\n   examples/superres_ppp_dncnn_admm\n\n\nTotal Variation\n^^^^^^^^^^^^^^^\n\n.. 
toctree::\n   :maxdepth: 1\n\n   examples/ct_abel_tv_admm\n   examples/ct_abel_tv_admm_tune\n   examples/ct_symcone_tv_padmm\n   examples/ct_tv_admm\n   examples/ct_multi_tv_admm\n   examples/ct_astra_tv_admm\n   examples/ct_astra_3d_tv_admm\n   examples/ct_astra_3d_tv_padmm\n   examples/ct_astra_weighted_tv_admm\n   examples/ct_svmbir_tv_multi\n   examples/deconv_circ_tv_admm\n   examples/deconv_tv_admm\n   examples/deconv_tv_admm_tune\n   examples/deconv_tv_padmm\n   examples/deconv_microscopy_tv_admm\n   examples/deconv_microscopy_allchn_tv_admm\n   examples/denoise_l1tv_admm\n   examples/denoise_ptv_pdhg\n   examples/denoise_tv_admm\n   examples/denoise_tv_apgm\n   examples/denoise_tv_multi\n   examples/denoise_approx_tv_multi\n   examples/denoise_cplx_tv_nlpadmm\n   examples/denoise_cplx_tv_pdhg\n   examples/diffusercam_tv_admm\n\n\n\nSparsity\n^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/diffusercam_tv_admm\n   examples/sparsecode_nn_admm\n   examples/sparsecode_nn_apgm\n   examples/sparsecode_conv_admm\n   examples/sparsecode_conv_md_admm\n   examples/sparsecode_apgm\n   examples/sparsecode_poisson_apgm\n   examples/video_rpca_admm\n\n\nMachine Learning\n^^^^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_datagen_foam2\n   examples/ct_modl_train_foam2\n   examples/ct_odp_train_foam2\n   examples/ct_unet_train_foam2\n   examples/deconv_datagen_bsds\n   examples/deconv_datagen_foam1\n   examples/deconv_modl_train_foam1\n   examples/deconv_odp_train_foam1\n   examples/denoise_datagen_bsds\n   examples/denoise_dncnn_train_bsds\n   examples/denoise_dncnn_universal\n\n\nOrganized by Optimization Algorithm\n-----------------------------------\n\n.. toctree::\n   :maxdepth: 1\n\nADMM\n^^^^\n\n.. 
toctree::\n   :maxdepth: 1\n\n   examples/ct_abel_tv_admm\n   examples/ct_abel_tv_admm_tune\n   examples/ct_symcone_tv_padmm\n   examples/ct_astra_tv_admm\n   examples/ct_tv_admm\n   examples/ct_astra_3d_tv_admm\n   examples/ct_astra_weighted_tv_admm\n   examples/ct_multi_tv_admm\n   examples/ct_svmbir_tv_multi\n   examples/ct_svmbir_ppp_bm3d_admm_cg\n   examples/ct_svmbir_ppp_bm3d_admm_prox\n   examples/ct_fan_svmbir_ppp_bm3d_admm_prox\n   examples/deconv_circ_tv_admm\n   examples/deconv_tv_admm\n   examples/deconv_tv_admm_tune\n   examples/deconv_microscopy_tv_admm\n   examples/deconv_microscopy_allchn_tv_admm\n   examples/deconv_ppp_bm3d_admm\n   examples/deconv_ppp_dncnn_admm\n   examples/deconv_ppp_bm4d_admm\n   examples/diffusercam_tv_admm\n   examples/sparsecode_nn_admm\n   examples/sparsecode_conv_admm\n   examples/sparsecode_conv_md_admm\n   examples/demosaic_ppp_bm3d_admm\n   examples/superres_ppp_dncnn_admm\n   examples/denoise_l1tv_admm\n   examples/denoise_tv_admm\n   examples/denoise_tv_multi\n   examples/denoise_approx_tv_multi\n   examples/video_rpca_admm\n\n\nLinearized ADMM\n^^^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_svmbir_tv_multi\n   examples/denoise_tv_multi\n\n\nProximal ADMM\n^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_astra_3d_tv_padmm\n   examples/deconv_tv_padmm\n   examples/denoise_tv_multi\n   examples/deconv_ppp_dncnn_padmm\n\n\nNon-linear Proximal ADMM\n^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/denoise_cplx_tv_nlpadmm\n\n\nPDHG\n^^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/ct_svmbir_tv_multi\n   examples/denoise_ptv_pdhg\n   examples/denoise_tv_multi\n   examples/denoise_cplx_tv_pdhg\n\n\nPGM\n^^^\n\n.. toctree::\n   :maxdepth: 1\n\n   examples/deconv_ppp_bm3d_apgm\n   examples/sparsecode_apgm\n   examples/sparsecode_nn_apgm\n   examples/sparsecode_poisson_apgm\n   examples/denoise_tv_apgm\n   examples/denoise_approx_tv_multi\n\n\nPCG\n^^^\n\n.. 
toctree::\n   :maxdepth: 1\n\n   examples/ct_astra_noreg_pcg\n"
  },
  {
    "path": "docs/source/include/blockarray.rst",
    "content": ".. _blockarray_class:\n\nBlockArray\n==========\n\n.. testsetup::\n\n   >>> import numpy as np\n   >>> import scico\n   >>> import scico.random\n   >>> import scico.linop\n   >>> import scico.numpy as snp\n   >>> from scico.numpy import BlockArray\n\nThe class :class:`.BlockArray` provides a way to combine arrays of\ndifferent shapes into a single object for use with other SCICO classes.\nA :class:`.BlockArray` consists of a list of :class:`jax.Array` objects,\nwhich we refer to as blocks. A :class:`.BlockArray` differs from a list in\nthat, whenever possible, :class:`.BlockArray` properties and methods\n(including unary and binary operators like +, -, \\*, ...) automatically\nmap along the blocks, returning another :class:`.BlockArray` or tuple as\nappropriate. For example,\n\n::\n\n    >>> x = snp.blockarray((\n    ...     [[1, 3, 7],\n    ...      [2, 2, 1]],\n    ...     [2, 4, 8]\n    ... ))\n\n    >>> x.shape  # returns tuple\n    ((2, 3), (3,))\n\n    >>> x * 2  # returns BlockArray   # doctest: +ELLIPSIS\n    BlockArray([...Array([[ 2,  6, 14],\n\t\t [ 4,  4,  2]], dtype=...), ...Array([ 4,  8, 16], dtype=...)])\n\n    >>> y = snp.blockarray((\n    ...        [[.2],\n    ...         [.3]],\n    ...        [.4]\n    ... ))\n\n    >>> x + y  # returns BlockArray  # doctest: +ELLIPSIS\n    BlockArray([...Array([[1.2, 3.2, 7.2],\n\t\t  [2.3, 2.3, 1.3]], dtype=...), ...Array([2.4, 4.4, 8.4], dtype=...)])\n\n\n.. 
_numpy_functions_blockarray:\n\nNumPy and SciPy Functions\n-------------------------\n\n:mod:`scico.numpy`, :mod:`scico.numpy.testing`, and\n:mod:`scico.scipy.special` provide wrappers around :mod:`jax.numpy`,\n:mod:`numpy.testing` and :mod:`jax.scipy.special` where many of the\nfunctions have been extended to work with instances of :class:`.BlockArray`.\nIn particular:\n\n* When a tuple of tuples is passed as the `shape`\n  argument to an array creation routine, a :class:`.BlockArray` is created.\n* When a :class:`.BlockArray` is passed to a reduction function, the blocks are\n  ravelled (i.e., reshaped to be 1D) and concatenated before the reduction\n  is applied. This behavior may be prevented by passing the `axis`\n  argument, in which case the function is mapped over the blocks.\n* When one or more :class:`.BlockArray` instances are passed to a mathematical\n  function that is not a reduction, the function is mapped over\n  (corresponding) blocks.\n\nFor a list of array creation routines, see\n\n::\n\n   >>> scico.numpy.creation_routines  # doctest: +ELLIPSIS\n   ('empty', ...)\n\nFor a list of  reduction functions, see\n\n::\n\n   >>> scico.numpy.reduction_functions  # doctest: +ELLIPSIS\n   ('sum', ...)\n\nFor lists of the remaining wrapped functions, see\n\n::\n\n   >>> scico.numpy.mathematical_functions  # doctest: +ELLIPSIS\n   ('sin', ...)\n   >>> scico.numpy.testing_functions  # doctest: +ELLIPSIS\n   ('testing.assert_allclose', ...)\n   >>> import scico.scipy\n   >>> scico.scipy.special.functions  # doctest: +ELLIPSIS\n   ('betainc', ...)\n\nNote that:\n\n* The functional and method versions of the \"same\" function differ in their\n  behavior, with the method version only applying the reduction within each\n  block, and the function version applying the reduction across all blocks.\n  For example, :func:`scico.numpy.sum` applied to a :class:`.BlockArray` with\n  two blocks returns a scalar value, while :meth:`.BlockArray.sum` returns a\n  
:class:`.BlockArray` with two scalar blocks.\n* Similarly, :func:`scico.numpy.ravel` returns a fully flattened, single\n  :class:`jax.Array`, while :meth:`.BlockArray.ravel` returns a\n  :class:`.BlockArray` with ravelled blocks.\n\n\nMotivating Example\n------------------\n\nThe discrete differences of a two-dimensional array, :math:`\\mb{x} \\in\n\\mbb{R}^{n \\times m}`, in the horizontal and vertical directions can\nbe represented by the arrays :math:`\\mb{x}_h \\in \\mbb{R}^{n \\times\n(m-1)}` and :math:`\\mb{x}_v \\in \\mbb{R}^{(n-1) \\times m}`\nrespectively. While it is usually useful to consider the output of a\ndifference operator as a single entity, we cannot combine these two\narrays into a single array since they have different shapes. We could\nvectorize each array and concatenate the resulting vectors, leading to\n:math:`\\mb{\\bar{x}} \\in \\mbb{R}^{n(m-1) + m(n-1)}`, which can be\nstored as a one-dimensional array, but this makes it hard to access\nthe individual components :math:`\\mb{x}_h` and :math:`\\mb{x}_v`.\n\nInstead, we can construct a :class:`.BlockArray`, :math:`\\mb{x}_B =\n[\\mb{x}_h, \\mb{x}_v]`:\n\n\n::\n\n  >>> n = 32\n  >>> m = 16\n  >>> x_h, key = scico.random.randn((n, m-1))\n  >>> x_v, _ = scico.random.randn((n-1, m), key=key)\n\n  # Form the blockarray\n  >>> x_B = snp.blockarray([x_h, x_v])\n\n  # The blockarray shape is a tuple of tuples\n  >>> x_B.shape\n  ((32, 15), (31, 16))\n\n  # Each block component can be easily accessed\n  >>> x_B[0].shape\n  (32, 15)\n  >>> x_B[1].shape\n  (31, 16)\n\n\nConstructing a BlockArray\n-------------------------\n\nThe recommended way to construct a :class:`.BlockArray` is by using the\n:func:`~scico.numpy.blockarray` function.\n\n::\n\n   >>> import scico.numpy as snp\n   >>> x0, key = scico.random.randn((32, 32))\n   >>> x1, _ = scico.random.randn((16,), key=key)\n   >>> X = snp.blockarray((x0, x1))\n   >>> X.shape\n   ((32, 32), (16,))\n   >>> X.size\n   (1024, 16)\n   >>> len(X)\n   
2\n\nWhile :func:`~scico.numpy.blockarray` will accept arguments of type\n:class:`~numpy.ndarray` or :class:`~jax.Array`, arguments of type :class:`~numpy.ndarray` will be converted to :class:`~jax.Array` type.\n\n\nOperating on a BlockArray\n-------------------------\n\n\n.. _blockarray_indexing:\n\nIndexing\n^^^^^^^^\n\n:class:`.BlockArray` indexing works just like indexing a list.\n\n\nMultiplication between BlockArray and LinearOperator\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe :class:`.Operator` and :class:`.LinearOperator` classes are designed\nto work on instances of :class:`.BlockArray` in addition to instances of\n:obj:`~jax.Array`. For example\n\n::\n\n    >>> x, key = scico.random.randn((3, 4))\n    >>> A_1 = scico.linop.Identity(x.shape)\n    >>> A_1.shape  # array -> array\n    ((3, 4), (3, 4))\n\n    >>> A_2 = scico.linop.FiniteDifference(x.shape)\n    >>> A_2.shape  # array -> BlockArray\n    (((2, 4), (3, 3)), (3, 4))\n\n    >>> diag = snp.blockarray([np.array(1.0), np.array(2.0)])\n    >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape))\n    >>> A_3.shape  # BlockArray -> BlockArray\n    (((2, 4), (3, 3)), ((2, 4), (3, 3)))\n"
  },
  {
    "path": "docs/source/include/examplenotes.rst",
    "content": ".. _example_depend:\n\nExample Dependencies\n--------------------\n\nSome examples use additional dependencies, which are listed in `examples_requirements.txt <https://github.com/lanl/scico/blob/main/examples/examples_requirements.txt>`_.\nThe additional requirements should be installed via pip, with the exception of ``astra-toolbox``,\nwhich should be installed via conda:\n\n::\n\n   conda install astra-toolbox\n   pip install -r examples/examples_requirements.txt  # Installs other example requirements\n\nThe dependencies can also be installed individually as required.\n\nNote that ``astra-toolbox`` should be installed on a host with one or more CUDA GPUs to ensure\nthat the version with GPU support is installed.\n\n\nRun Time\n--------\n\nMost of these examples have been constructed with sufficiently small test problems to\nallow them to run to completion in 5 minutes or less on a reasonable workstation.\nNote, however, that it was not feasible to construct meaningful examples of the training\nof some of the deep learning algorithms that complete within a relatively short time;\nthe examples \"CT Training and Reconstructions with MoDL\" and \"CT Training and\nReconstructions with ODP\" in particular are much slower, and can require multiple hours\nto run on a workstation with multiple GPUs.\n\n|\n"
  },
  {
    "path": "docs/source/include/functional.rst",
    "content": "Functionals\n===========\n\nA functional is\na mapping from :math:`\\mathbb{R}^n` or :math:`\\mathbb{C}^n` to :math:`\\mathbb{R}`.\nIn SCICO, functionals are\nprimarily used to represent a cost to be minimized\nand are represented by instances of the :class:`.Functional` class.\nAn instance of :class:`.Functional`, ``f``, may provide three core operations.\n\n* Evaluation\n   - ``f(x)`` returns the value of the functional\n     evaluated at the point ``x``.\n   - A functional that can be evaluated\n     has the attribute ``f.has_eval == True``.\n   - Not all functionals can be evaluated:  see `Plug-and-Play`_.\n* Gradient\n   - ``f.grad(x)`` returns the gradient of the functional evaluated at ``x``.\n   - Gradients are calculated using JAX reverse-mode automatic differentiation,\n     exposed through :func:`scico.grad`.\n   - *Note:*  The gradient of a functional ``f`` can be evaluated even if that functional is not smooth.\n     All that is required is that the functional can be evaluated, ``f.has_eval == True``.\n     However, the result may not be a valid gradient (or subgradient) for all inputs.\n* Proximal operator\n   - ``f.prox(v, lam)`` returns the result of the scaled proximal\n     operator of ``f``, i.e., the proximal operator of ``lambda x:\n     lam * f(x)``, evaluated at the point ``v``.\n   - The proximal operator of a functional :math:`f : \\mathbb{R}^n \\to\n     \\mathbb{R}` is the mapping :math:`\\mathrm{prox}_f : \\mathbb{R}^n\n     \\to \\mathbb{R}^n` defined as\n\n     .. 
math::\n      \\mathrm{prox}_f (\\mb{v}) =  \\argmin_{\\mb{x}} f(\\mb{x}) +\n      \\frac{1}{2} \\norm{\\mb{v} - \\mb{x}}_2^2\\;.\n\n\nPlug-and-Play\n-------------\n\nFor the plug-and-play framework :cite:`sreehari-2016-plug`,\nwe encapsulate generic denoisers including CNNs\nin :class:`.Functional` objects that **cannot be evaluated**.\nThe denoiser is applied via the proximal operator.\nFor examples, see :ref:`example_notebooks`.\n\n\nProximal Calculus\n-----------------\n\nWe support a limited subset of proximal calculus rules:\n\n\nScaled Functionals\n^^^^^^^^^^^^^^^^^^\n\nGiven a scalar ``c`` and a functional ``f`` with a defined proximal method, we can\ndetermine the proximal method of ``c * f`` as\n\n.. math::\n\n   \\begin{align}\n    \\mathrm{prox}_{c f} (v, \\lambda) &=  \\argmin_x \\lambda (c f)(x) + \\frac{1}{2} \\norm{v - x}_2^2  \\\\\n    &=  \\argmin_x (\\lambda c) f(x) + \\frac{1}{2} \\norm{v - x}_2^2 \\\\\n    &= \\mathrm{prox}_{f} (v, c \\lambda) \\;.\n    \\end{align}\n\nNote that we have made no assumptions regarding homogeneity of ``f``;\nrather, only that the proximal method of ``f`` is given\nin the parameterized form :math:`\\mathrm{prox}_{f} (v, \\lambda)`.\n\nIn SCICO, multiplying a :class:`.Functional` by a scalar\nwill return a :class:`.ScaledFunctional`.\nThis :class:`.ScaledFunctional` retains the ``has_eval`` and ``has_prox`` attributes\nfrom the original :class:`.Functional`,\nbut the proximal method is modified to accommodate the additional scalar.\n\n\nSeparable Functionals\n^^^^^^^^^^^^^^^^^^^^^\n\nA separable functional :math:`f : \\mathbb{C}^N \\to \\mathbb{R}` can be written as the sum\nof functionals :math:`f_i : \\mathbb{C}^{N_i} \\to \\mathbb{R}` with :math:`\\sum_i N_i = N`. In particular,\n\n.. 
math::\n   f(\\mb{x}) = f(\\mb{x}_1, \\dots, \\mb{x}_M) = f_1(\\mb{x}_1) + \\dots + f_M(\\mb{x}_M) \\;.\n\nThe proximal operator of a separable :math:`f` can be written\nin terms of the proximal operators of the :math:`f_i`\n(see Theorem 6.6 of :cite:`beck-2017-first`):\n\n.. math::\n    \\mathrm{prox}_f(\\mb{x}, \\lambda)\n    =\n    \\begin{pmatrix}\n      \\mathrm{prox}_{f_1}(\\mb{x}_1, \\lambda) \\\\\n      \\vdots \\\\\n      \\mathrm{prox}_{f_M}(\\mb{x}_M, \\lambda) \\\\\n    \\end{pmatrix} \\;.\n\nSeparable functionals are implemented in the :class:`.SeparableFunctional` class. Separable functionals naturally accept :class:`.BlockArray` inputs and return the prox as a :class:`.BlockArray`.\n\n\n\nAdding New Functionals\n----------------------\n\nTo add a new functional,\ncreate a class which\n\n1. inherits from base :class:`.Functional`;\n2. has ``has_eval`` and ``has_prox`` flags;\n3. has ``_eval`` and ``prox`` methods, as necessary.\n\nFor example,\n\n::\n\n   class MyFunctional(scico.functional.Functional):\n\n       has_eval = True\n       has_prox = True\n\n       def _eval(self, x: JaxArray) -> float:\n            return snp.sum(x)\n\n       def prox(self, x: JaxArray, lam : float) -> JaxArray:\n            return x - lam\n\n\nLosses\n------\n\nIn SCICO, a loss is a special type of functional\n\n.. math::\n   f(\\mb{x}) = \\alpha l( \\mb{y}, A(\\mb{x}) ) \\;,\n\nwhere :math:`\\alpha` is a scaling parameter,\n:math:`l` is a functional,\n:math:`\\mb{y}` is a set of measurements,\nand :math:`A` is an operator.\nSCICO uses the class :class:`.Loss` to represent losses.\nLoss functionals commonly arise in the context of solving\ninverse problems in scientific imaging,\nwhere they are used to represent the mismatch\nbetween predicted measurements :math:`A(\\mb{x})`\nand actual ones :math:`\\mb{y}`.\n"
  },
  {
    "path": "docs/source/include/learning.rst",
    "content": "Learned Models\n==============\n\nIn SCICO, neural network models are used to represent imaging problems and provide different modes of data-driven regularization.\nThe models are implemented in `Flax <https://flax.readthedocs.io/>`_, and constitute a representative sample of frequently used networks.\n\n\nFlaxMap\n-------\n\nSCICO interfaces with the implemented models via :class:`.FlaxMap`. This provides standardized access to all trained models via the model definition and the learned parameters. Further specialized functionality, such as learned denoisers, is built on top of :class:`.FlaxMap`. The specific models that have been implemented are described below.\n\n\n\nDnCNN\n-----\n\nThe denoiser convolutional neural network model (DnCNN) :cite:`zhang-2017-dncnn`, implemented as :class:`.DnCNNNet`, is used to denoise images that have been corrupted with additive Gaussian noise.\n\n\n\nODP\n---\n\nThe unrolled optimization with deep priors (ODP) :cite:`diamond-2018-odp`, implemented as :class:`.ODPNet`, is used to solve inverse problems in imaging by adapting classical iterative methods into an end-to-end framework that incorporates deep networks as well as knowledge of the image formation model.\n\nThe framework aims to solve the optimization problem\n\n.. math::\n   \\argmin_{\\mb{x}} \\; f(A \\mb{x}, \\mb{y}) + r(\\mb{x}) \\;,\n\nwhere :math:`A` represents a linear forward model and :math:`r` a regularization function encoding prior information, by unrolling the iterative solution method into a network where each iteration corresponds to a different stage in the ODP network. Different iterative solutions produce different unrolled optimization algorithms which, in turn, produce different ODP networks. The ones implemented in SCICO are described below.\n\n\nProximal Map\n^^^^^^^^^^^^\n\nThis algorithm corresponds to solving\n\n.. 
math::\n   :label: eq:odp_prox\n\n   \\argmin_{\\mb{x}} \\; \\alpha_k \\, f(A \\mb{x}, \\mb{y}) + \\frac{1}{2} \\| \\mb{x} - \\mb{x}^k - \\mb{x}^{k+1/2} \\|_2^2 \\;,\n\nwith :math:`k` corresponding to the index of the iteration, which translates to an index of the stage of the network, :math:`f(A \\mb{x}, \\mb{y})` a fidelity term, usually an :math:`\\ell_2` norm, and :math:`\\mb{x}^{k+1/2}` a regularization representing :math:`\\mathrm{prox}_r (\\mb{x}^k)` and usually implemented as a convolutional neural network (CNN). This proximal map representation is used when the minimization problem :eq:`eq:odp_prox` can be solved in a computationally efficient manner.\n\n:class:`.ODPProxDnBlock` uses this formulation to solve a denoising problem, which, according to :cite:`diamond-2018-odp`, can be solved by\n\n.. math::\n   \\mb{x}^{k+1} = (\\alpha_k \\, \\mb{y} + \\mb{x}^k + \\mb{x}^{k+1/2}) \\, / \\, (\\alpha_k + 1) \\;,\n\nwhere :math:`A` corresponds to the identity operator and is therefore omitted, :math:`\\mb{y}` is the noisy signal, :math:`\\alpha_k > 0` is a learned stage-wise parameter weighting the contribution of the fidelity term and :math:`\\mb{x}^k + \\mb{x}^{k+1/2}` is the regularization, usually represented by a residual CNN.\n\n\n:class:`.ODPProxDblrBlock` uses this formulation to solve a deblurring problem, which, according to :cite:`diamond-2018-odp`, can be solved by\n\n.. 
math::\n   \\mb{x}^{k+1} = \\mathcal{F}^{-1} \\mathrm{diag} (\\alpha_k | \\mathcal{F}(K)|^2 + 1 )^{-1} \\mathcal{F} \\, (\\alpha_k K^T * \\mb{y} + \\mb{x}^k + \\mb{x}^{k+1/2}) \\;,\n\nwhere :math:`A` is the blurring operator, :math:`K` is the blurring kernel, :math:`\\mb{y}` is the blurred signal, :math:`\\mathcal{F}` is the DFT, :math:`\\alpha_k > 0` is a learned stage-wise parameter weighting the contribution of the fidelity term and :math:`\\mb{x}^k + \\mb{x}^{k+1/2}` is the regularization represented by a residual CNN.\n\n\nGradient Descent\n^^^^^^^^^^^^^^^^\n\nWhen the solution of the optimization problem in :eq:`eq:odp_prox` cannot simply be represented by an analytical step, a formulation based on a gradient descent iteration is preferred. This yields\n\n.. math::\n   \\mb{x}^{k+1} = \\mb{x}^k + \\mb{x}^{k+1/2} - \\alpha_k \\, A^T \\nabla_x \\, f(A \\mb{x}^k, \\mb{y}) \\;,\n\nwhere :math:`\\mb{x}^{k+1/2}` represents :math:`\\nabla r(\\mb{x}^k)`.\n\n:class:`.ODPGrDescBlock` uses this formulation to solve a generic problem with :math:`\\ell_2` fidelity as\n\n.. math::\n   \\mb{x}^{k+1} = \\mb{x}^k + \\mb{x}^{k+1/2} - \\alpha_k \\, A^T (A \\mb{x}^k - \\mb{y}) \\;,\n\nwith :math:`\\mb{y}` the measured signal and :math:`\\mb{x}^k + \\mb{x}^{k+1/2}` a residual CNN.\n\n\nMoDL\n----\n\nThe model-based deep learning (MoDL) :cite:`aggarwal-2019-modl`, implemented as :class:`.MoDLNet`, is used to solve inverse problems in imaging also by adapting classical iterative methods into an end-to-end deep learning framework, but, in contrast to ODP, it solves the optimization problem\n\n.. math::\n   \\argmin_{\\mb{x}} \\; \\| A \\mb{x} - \\mb{y}\\|_2^2 + \\lambda \\, \\| \\mb{x} - \\mathrm{D}_w(\\mb{x})\\|_2^2 \\;,\n\nby directly computing the update\n\n.. math::\n   \\mb{x}^{k+1} = (A^T A + \\lambda \\, I)^{-1} (A^T \\mb{y} + \\lambda \\, \\mb{z}^k) \\;,\n\nvia conjugate gradient. 
The regularization :math:`\\mb{z}^k = \\mathrm{D}_w(\\mb{x}^{k})` incorporates prior information, usually in the form of a denoiser model. In this case, the denoiser :math:`\\mathrm{D}_w` is shared between all the stages of the network requiring relatively less memory than other unrolling methods. This also allows for deploying a different number of iterations in testing than the ones used in training.\n"
  },
  {
    "path": "docs/source/include/operator.rst",
    "content": "Operators\n=========\n\nAn operator is a map from :math:`\\mathbb{R}^n` or :math:`\\mathbb{C}^n`\nto :math:`\\mathbb{R}^m` or :math:`\\mathbb{C}^m`. In SCICO, operators\nare primarily used to represent imaging systems and provide\nregularization. SCICO operators are represented by instances of the\n:class:`.Operator` class.\n\nSCICO :class:`.Operator` objects extend the notion of \"shape\" and\n\"size\" from the usual NumPy ``ndarray`` class. Each\n:class:`.Operator` object has an ``input_shape`` and ``output_shape``;\nthese shapes can be either tuples or a tuple of tuples (in the case of\na :class:`.BlockArray`). The ``matrix_shape`` attribute describes the\nshape of the :class:`.LinearOperator` if it were to act on vectorized,\nor flattened, inputs.\n\nFor example, consider a two-dimensional array :math:`\\mb{x} \\in\n\\mathbb{R}^{n \\times m}`. We compute the discrete differences of\n:math:`\\mb{x}` in the horizontal and vertical directions, generating\ntwo new arrays: :math:`\\mb{x}_h \\in \\mathbb{R}^{n \\times (m-1)}` and\n:math:`\\mb{x}_v \\in \\mathbb{R}^{(n-1) \\times m}`. We represent this\nlinear operator by :math:`\\mb{A} : \\mathbb{R}^{n \\times m} \\to\n\\mathbb{R}^{n \\times (m-1)} \\times \\mathbb{R}^{(n-1) \\times m}`. In\nSCICO, this linear operator will return a :class:`.BlockArray` with\nthe horizontal and vertical differences stored as blocks. Letting\n:math:`\\mb{y} = \\mb{A} \\mb{x}`, we have ``y.shape = ((n, m-1), (n-1, m))`` and\n\n::\n\n    A.input_shape = (n, m)\n    A.output_shape = ((n, m-1), (n-1, m))\n    A.shape = (((n, m-1), (n-1, m)), (n, m))   # (output_shape, input_shape)\n    A.input_size = n*m\n    A.output_size = n*(m-1) + (n-1)*m\n    A.matrix_shape = (n*(m-1) + (n-1)*m, n*m)    # (output_size, input_size)\n\n\nOperator Calculus\n-----------------\n\nSCICO supports a variety of operator calculus rules, allowing new\noperators to be defined in terms of old ones. 
The following table\nsummarizes the available operations.\n\n+----------------+-----------------+\n| Operation      |  Result         |\n+----------------+-----------------+\n| ``(A+B)(x)``   | ``A(x) + B(x)`` |\n+----------------+-----------------+\n| ``(A-B)(x)``   | ``A(x) - B(x)`` |\n+----------------+-----------------+\n| ``(c * A)(x)`` | ``c * A(x)``    |\n+----------------+-----------------+\n| ``(A/c)(x)``   | ``A(x)/c``      |\n+----------------+-----------------+\n| ``(-A)(x)``    | ``-A(x)``       |\n+----------------+-----------------+\n| ``A(B)(x)``    | ``A(B(x))``     |\n+----------------+-----------------+\n| ``A(B)``       | ``Operator``    |\n+----------------+-----------------+\n\n\nDefining a New Operator\n-----------------------\n\nTo define a new operator, pass a callable to the :class:`.Operator`\nconstructor:\n\n::\n\n    A = Operator(input_shape=(32,), eval_fn = lambda x: 2 * x)\n\n\nOr use subclassing:\n\n::\n\n   >>> from scico.operator import Operator\n   >>> class MyOp(Operator):\n   ...\n   ...     def _eval(self, x):\n   ...         return 2 * x\n\n   >>> A = MyOp(input_shape=(32,))\n\nAt a minimum, the ``_eval`` function must be overridden.  If either\n``output_shape`` or ``output_dtype`` are unspecified, they are\ndetermined by evaluating the operator on an input of appropriate shape\nand dtype.\n\n\nLinear Operators\n================\n\nLinear operators are those for which\n\n.. math::\n\n   H(a \\mb{x} + b \\mb{y}) = a H(\\mb{x}) + b H(\\mb{y}) \\;.\n\nSCICO represents linear operators as instances of the class\n:class:`.LinearOperator`.  While finite-dimensional linear operators\ncan always be associated with a matrix, it is often useful to\nrepresent them in a matrix-free manner.  Most of SCICO's linear\noperators are implemented matrix-free.\n\n\n\nUsing a LinearOperator\n----------------------\n\nWe implement two ways to evaluate a :class:`.LinearOperator`. The\nfirst is using standard callable syntax: ``A(x)``. 
The second mimics\nthe NumPy matrix multiplication syntax: ``A @ x``. Both methods\nperform shape and type checks to validate the input before ultimately\neither calling ``A._eval`` or generating a new :class:`.LinearOperator`.\n\nFor linear operators that map real-valued inputs to real-valued\noutputs, there are two ways to apply the adjoint: ``A.adj(y)`` and\n``A.T @ y``.\n\nFor complex-valued linear operators, there are three ways to apply the\nadjoint: ``A.adj(y)``, ``A.H @ y``, and ``A.conj().T @ y``.  Note that\nin this case, ``A.T`` returns the non-conjugated transpose of the\n:class:`.LinearOperator`.\n\nWhile the cost of evaluating the linear operator is virtually\nidentical for ``A(x)`` and ``A @ x``, the ``A.H`` and ``A.conj().T``\nmethods are somewhat slower, especially the latter. This is because\ntwo intermediate linear operators must be created before the function\nis evaluated.  Evaluating ``A.conj().T @ y`` is equivalent to:\n\n::\n\n  def f(y):\n    B = A.conj()  # New LinearOperator #1\n    C = B.T       # New LinearOperator #2\n    return C @ y\n\n**Note**: the speed differences between these methods vanish if\napplied inside of a jit-ed function.  
For instance:\n\n::\n\n   f = jax.jit(lambda x:  A.conj().T @ x)\n\n\n+------------------+-----------------+\n|  Public Method   |  Private Method |\n+------------------+-----------------+\n|  ``__call__``    |  ``._eval``     |\n+------------------+-----------------+\n|  ``adj``         |  ``._adj``      |\n+------------------+-----------------+\n|  ``gram``        |  ``._gram``     |\n+------------------+-----------------+\n\nThe public methods perform shape and type checking to validate the\ninput before either calling the corresponding private method or\nreturning a composite LinearOperator.\n\n\nLinear Operator Calculus\n------------------------\n\nSCICO supports several linear operator calculus rules.\nGiven\n``A`` and ``B`` of class :class:`.LinearOperator` and of appropriate shape,\n``x`` an array of appropriate shape,\n``c`` a scalar, and\n``O`` an :class:`.Operator`,\nwe have\n\n+----------------+----------------------------+\n| Operation      |  Result                    |\n+----------------+----------------------------+\n| ``(A+B)(x)``   | ``A(x) + B(x)``            |\n+----------------+----------------------------+\n| ``(A-B)(x)``   | ``A(x) - B(x)``            |\n+----------------+----------------------------+\n| ``(c * A)(x)`` | ``c * A(x)``               |\n+----------------+----------------------------+\n| ``(A/c)(x)``   | ``A(x)/c``                 |\n+----------------+----------------------------+\n| ``(-A)(x)``    | ``-A(x)``                  |\n+----------------+----------------------------+\n| ``(A@B)(x)``   | ``A@B@x``                  |\n+----------------+----------------------------+\n| ``A @ B``      | ``ComposedLinearOperator`` |\n+----------------+----------------------------+\n| ``A @ O``      | ``Operator``               |\n+----------------+----------------------------+\n| ``O(A)``       | ``Operator``               |\n+----------------+----------------------------+\n\n\n\nDefining a New Linear 
Operator\n------------------------------\n\nTo define a new linear operator, pass a callable to the\n:class:`.LinearOperator` constructor:\n\n::\n\n   >>> from scico.linop import LinearOperator\n   >>> A = LinearOperator(input_shape=(32,),\n   ...       eval_fn = lambda x: 2 * x)\n\nOr, use subclassing:\n\n::\n\n   >>> class MyLinearOperator(LinearOperator):\n   ...    def _eval(self, x):\n   ...        return 2 * x\n\n   >>> A = MyLinearOperator(input_shape=(32,))\n\nAt a minimum, the ``_eval`` method must be overridden.  If the\n``_adj`` method is not overridden, the adjoint is determined using\n:func:`scico.linear_adjoint`.  If either ``output_shape`` or\n``output_dtype`` are unspecified, they are determined by evaluating\nthe Operator on an input of appropriate shape and dtype.\n\n\n🔪 Sharp Edges 🔪\n------------------\n\nStrict Types in Adjoint\n^^^^^^^^^^^^^^^^^^^^^^^\n\nSCICO silently promotes real types to complex types in forward\napplication, but enforces strict type checking in the adjoint.  This\nis due to the strict type-safe nature of JAX adjoints.\n\n\nLinearOperators from External Code\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nExternal code may be wrapped as a subclass of :class:`.Operator` or\n:class:`.LinearOperator` and used in SCICO optimization routines;\nhowever, this process can be complicated and error-prone.  As a\nstarting point, look at the source for\n:class:`.radon_svmbir.TomographicProjector` or\n:class:`.radon_astra.TomographicProjector` and the JAX documentation\nfor the `vector-jacobian product\n<https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff>`_\nand `custom VJP rules\n<https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.\n"
  },
  {
    "path": "docs/source/include/optimizer.rst",
    "content": ".. _optimizer:\n\nOptimization Algorithms\n=======================\n\nADMM\n----\n\nThe Alternating Direction Method of Multipliers (ADMM)\n:cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual` is an\nalgorithm for minimizing problems of the form\n\n.. math::\n   :label: eq:admm_prob\n\n   \\argmin_{\\mb{x}, \\mb{z}} \\; f(\\mb{x}) + g(\\mb{z}) \\; \\text{such that}\n   \\; \\acute{A} \\mb{x} + \\acute{B} \\mb{z} = \\mb{c} \\;,\n\nwhere :math:`f` and :math:`g` are convex (but not necessarily smooth)\nfunctionals, :math:`\\acute{A}` and :math:`\\acute{B}` are linear operators,\nand :math:`\\mb{c}` is a constant vector. (For a thorough introduction and\noverview, see :cite:`boyd-2010-distributed`.)\n\nThe SCICO ADMM solver, :class:`.ADMM`, solves problems of the form\n\n.. math::\n   \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_{i=1}^N g_i(C_i \\mb{x}) \\;,\n\nwhere :math:`f` and the :math:`g_i` are instances of :class:`.Functional`,\nand the :math:`C_i` are :class:`.LinearOperator`, by defining\n\n.. math::\n   g(\\mb{z}) = \\sum_{i=1}^N g_i(\\mb{z}_i) \\qquad \\mb{z}_i = C_i \\mb{x}\n\nin :eq:`eq:admm_prob`, corresponding to defining\n\n.. 
math::\n  \\acute{A} = \\left( \\begin{array}{c} C_1 \\\\ C_2 \\\\ C_3 \\\\\n              \\vdots \\end{array} \\right)  \\quad\n  \\acute{B} = \\left( \\begin{array}{cccc}\n              -I & 0 & 0 & \\ldots \\\\\n              0 & -I & 0 & \\ldots \\\\\n              0 &  0  & -I & \\ldots \\\\\n              \\vdots & \\vdots & \\vdots & \\ddots\n              \\end{array} \\right) \\quad\n  \\mb{z} = \\left( \\begin{array}{c} \\mb{z}_1 \\\\ \\mb{z}_2 \\\\ \\mb{z}_3 \\\\\n              \\vdots \\end{array} \\right)  \\quad\n  \\mb{c} = \\left( \\begin{array}{c} 0 \\\\ 0 \\\\ 0 \\\\\n              \\vdots \\end{array} \\right) \\;.\n\nIn :class:`.ADMM`, :math:`f` is a :class:`.Functional`, typically a\n:class:`.Loss`, corresponding to the forward model of an imaging\nproblem, and the :math:`g_i` are :class:`.Functional`, typically\ncorresponding to a regularization term or constraint. Each of the\n:math:`g_i` must have a proximal operator defined. It is also possible\nto set ``f = None``, which corresponds to defining :math:`f = 0`,\ni.e. the zero function.\n\n\nSubproblem Solvers\n^^^^^^^^^^^^^^^^^^\n\nThe most computationally expensive component of the ADMM iterations is typically\nthe :math:`\\mb{x}`-update,\n\n.. math::\n   :label: eq:admm_x_step\n\n   \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_i \\frac{\\rho_i}{2}\n   \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i - C_i \\mb{x}}_2^2 \\;.\n\n\nThe available solvers for this problem are:\n\n* :class:`.admm.GenericSubproblemSolver`\n\n  This is the default subproblem solver as it is applicable in all cases. However,\n  it is only suitable for relatively small-scale problems as it makes use of\n  :func:`.solver.minimize`, which wraps :func:`scipy.optimize.minimize`.\n\n* :class:`.admm.LinearSubproblemSolver`\n\n  This subproblem solver can be used when :math:`f` takes the form\n  :math:`\\norm{\\mb{A} \\mb{x} - \\mb{y}}^2_W`. 
It makes use of the conjugate\n  gradient method, and is significantly more efficient than\n  :class:`.admm.GenericSubproblemSolver` when it can be used.\n\n* :class:`.admm.MatrixSubproblemSolver`\n\n  This subproblem solver can be used when :math:`f` takes the form\n  :math:`\\norm{\\mb{A} \\mb{x} - \\mb{y}}^2_W`, and :math:`A` and all of the\n  :math:`C_i` are diagonal (:class:`.Diagonal`) or matrix operators\n  (:class:`.MatrixOperator`). It exploits a pre-computed matrix factorization\n  for a significantly more efficient solution than conjugate gradient.\n\n* :class:`.admm.CircularConvolveSolver`\n\n  This subproblem solver can be used when :math:`f` takes the form\n  :math:`\\norm{\\mb{A} \\mb{x} - \\mb{y}}^2_W` and :math:`\\mb{A}` and all\n  the :math:`C_i` are circulant (i.e., diagonalized by the DFT).\n\n* :class:`.admm.FBlockCircularConvolveSolver` and :class:`.admm.G0BlockCircularConvolveSolver`\n\n  These subproblem solvers can be used when the primary linear operator\n  is block-circulant (i.e. an operator with blocks that are diagonalized\n  by the DFT).\n\n\nFor more details of these solvers and how to specify them, see the API\nreference page for :mod:`scico.optimize.admm`.\n\n\nProximal ADMM\n-------------\n\nProximal ADMM :cite:`deng-2015-global` is an algorithm for solving\nproblems of the form\n\n.. math::\n\n   \\argmin_{\\mb{x}, \\mb{z}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n   \\text{such that}\\; A \\mb{x} + B \\mb{z} = \\mb{c} \\;,\n\nwhere :math:`f` and :math:`g` are convex (but not necessarily\nsmooth) functionals and :math:`A` and :math:`B` are linear\noperators. 
Although convergence per iteration is typically somewhat\nworse than that of ADMM, the iterations can be much cheaper than those\nof ADMM, giving Proximal ADMM competitive time convergence\nperformance.\n\nThe SCICO Proximal ADMM solver, :class:`.ProximalADMM`, requires\n:math:`f` and :math:`g` to be instances of :class:`.Functional`, and\nto have a proximal operator defined (:meth:`.Functional.prox`), and\n:math:`A` and :math:`B` are required to be instances of\n:class:`.LinearOperator`.\n\n\nNon-Linear Proximal ADMM\n------------------------\n\nNon-Linear Proximal ADMM :cite:`benning-2016-preconditioned` is an\nalgorithm for solving problems of the form\n\n.. math::\n   \\argmin_{\\mb{x}, \\mb{z}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n   \\text{such that}\\; H(\\mb{x}, \\mb{z}) = 0 \\;,\n\nwhere :math:`f` and :math:`g` are convex (but not necessarily\nsmooth) functionals and :math:`H` is a function of two vector variables.\n\nThe SCICO Non-Linear Proximal ADMM solver, :class:`.NonLinearPADMM`, requires\n:math:`f` and :math:`g` to be instances of :class:`.Functional`, and\nto have a proximal operator defined (:meth:`.Functional.prox`), and\n:math:`H` is required to be an instance of :class:`.Function`.\n\n\n\nLinearized ADMM\n---------------\n\nLinearized ADMM :cite:`yang-2012-linearized`\n:cite:`parikh-2014-proximal` (Sec. 4.4.2) is an algorithm for solving\nproblems of the form\n\n.. math::\n   \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\;,\n\nwhere :math:`f` and :math:`g` are convex (but not necessarily\nsmooth) functionals. 
Although convergence per iteration is typically\nsignificantly worse than that of ADMM, the :math:`\\mb{x}`-update can\nbe much cheaper than that of ADMM, giving Linearized ADMM competitive\ntime convergence performance.\n\nThe SCICO Linearized ADMM solver, :class:`.LinearizedADMM`,\nrequires :math:`f` and :math:`g` to be instances of :class:`.Functional`,\nand to have a proximal operator defined (:meth:`.Functional.prox`), and\n:math:`C` is required to be an instance of :class:`.LinearOperator`.\n\n\n\nPDHG\n----\n\nThe Primal–Dual Hybrid Gradient (PDHG) algorithm\n:cite:`esser-2010-general` :cite:`chambolle-2010-firstorder`\n:cite:`pock-2011-diagonal` solves problems of the form\n\n.. math::\n   \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\;,\n\nwhere :math:`f` and :math:`g` are convex (but not necessarily smooth)\nfunctionals. The algorithm has similar advantages over ADMM to those of Linearized ADMM, but typically exhibits better convergence properties.\n\nThe SCICO PDHG solver, :class:`.PDHG`,\nrequires :math:`f` and :math:`g` to be instances of :class:`.Functional`,\nand to have a proximal operator defined (:meth:`.Functional.prox`), and\n:math:`C` is required to be an instance of :class:`.Operator` or :class:`.LinearOperator`.\n\n\n\nPGM\n---\n\nThe Proximal Gradient Method (PGM) :cite:`daubechies-2004-iterative`\n:cite:`beck-2010-gradient` and Accelerated Proximal Gradient Method\n(AcceleratedPGM) :cite:`beck-2009-fast` are algorithms for minimizing\nproblems of the form\n\n.. math::\n   \\argmin_{\\mb{x}} f(\\mb{x}) + g(\\mb{x}) \\;,\n\nwhere :math:`g` is convex and :math:`f` is smooth and convex. The\ncorresponding SCICO solvers are :class:`.PGM` and :class:`.AcceleratedPGM`\nrespectively. In most cases :class:`.AcceleratedPGM` is expected to provide\nfaster convergence. 
In both of these classes, :math:`f` and :math:`g` are\nboth of type :class:`.Functional`, where :math:`f` must be differentiable,\nand :math:`g` must have a proximal operator defined.\n\nWhile ADMM provides significantly more flexibility than PGM, and often\nconverges faster, the latter is preferred when solving the ADMM\n:math:`\\mb{x}`-step is very computationally expensive, such as in the case of\n:math:`f(\\mb{x}) = \\norm{\\mb{A} \\mb{x} - \\mb{y}}^2_W` where :math:`A` is\nlarge and does not have any special structure that would allow an efficient\nsolution of :eq:`eq:admm_x_step`.\n\n\n\nStep Size Options\n^^^^^^^^^^^^^^^^^\n\nThe step size (usually referred to in terms of its reciprocal,\n:math:`L`) for the gradient descent in :class:`PGM` can be adapted via\nBarzilai-Borwein methods (also called spectral methods) and iterative\nline search methods.\n\nThe available step size policy classes are:\n\n* :class:`.BBStepSize`\n\n  This implements the step size adaptation based on the Barzilai-Borwein\n  method :cite:`barzilai-1988-stepsize`. The step size :math:`\\alpha` is\n  estimated as\n\n  .. math::\n     \\mb{\\Delta x} = \\mb{x}_k - \\mb{x}_{k-1} \\; \\\\\n     \\mb{\\Delta g} = \\nabla f(\\mb{x}_k) - \\nabla f (\\mb{x}_{k-1}) \\; \\\\\n     \\alpha = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta g}}{\\mb{\\Delta g}^T\n     \\mb{\\Delta g}} \\;.\n\n  Since the PGM solver uses the reciprocal of the step size, the value\n  :math:`L = 1 / \\alpha` is returned.\n\n\n* :class:`.AdaptiveBBStepSize`\n\n  This implements the adaptive Barzilai-Borwein method as introduced in\n  :cite:`zhou-2006-adaptive`. The adaptive step size rule computes\n\n  .. 
math::\n     \\mb{\\Delta x} = \\mb{x}_k - \\mb{x}_{k-1} \\; \\\\\n     \\mb{\\Delta g} = \\nabla f(\\mb{x}_k) - \\nabla f (\\mb{x}_{k-1}) \\; \\\\\n     \\alpha^{\\mathrm{BB1}} = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta x}}\n     {\\mb{\\Delta x}^T \\mb{\\Delta g}} \\; \\\\\n     \\alpha^{\\mathrm{BB2}} = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta g}}\n     {\\mb{\\Delta g}^T \\mb{\\Delta g}} \\;.\n\n  The determination of the new step size is made via the rule\n\n  .. math::\n     \\alpha = \\left\\{ \\begin{array}{ll} \\alpha^{\\mathrm{BB2}}  &\n     \\mathrm{~if~} \\alpha^{\\mathrm{BB2}} / \\alpha^{\\mathrm{BB1}}\n     < \\kappa \\; \\\\\n     \\alpha^{\\mathrm{BB1}}  & \\mathrm{~otherwise} \\end{array}\n     \\right . \\;,\n\n  with :math:`\\kappa \\in (0, 1)`.\n\n  Since the PGM solver uses the reciprocal of the step size, the value\n  :math:`L = 1 / \\alpha` is returned.\n\n\n* :class:`.LineSearchStepSize`\n\n  This implements the line search strategy described in :cite:`beck-2009-fast`.\n  This strategy estimates :math:`L` such that\n  :math:`f(\\mb{x}) \\leq \\hat{f}_{L}(\\mb{x})` is satisfied with\n  :math:`\\hat{f}_{L}` a quadratic approximation to :math:`f` defined as\n\n  .. math::\n     \\hat{f}_{L}(\\mb{x}, \\mb{y}) = f(\\mb{y}) + \\nabla f(\\mb{y})^H\n     (\\mb{x} - \\mb{y}) + \\frac{L}{2} \\left\\| \\mb{x} - \\mb{y}\n     \\right\\|_2^2 \\;,\n\n  with :math:`\\mb{x}` the potential new update and :math:`\\mb{y}` the\n  current solution or current extrapolation (if using :class:`.AcceleratedPGM`).\n\n\n* :class:`.RobustLineSearchStepSize`\n\n  This implements the robust line search strategy described in\n  :cite:`florea-2017-robust`. This strategy estimates :math:`L` such that\n  :math:`f(\\mb{x}) \\leq \\hat{f}_{L}(\\mb{x})` is satisfied with\n  :math:`\\hat{f}_{L}` a quadratic approximation to :math:`f` defined as\n\n  .. 
math::\n     \\hat{f}_{L}(\\mb{x}, \\mb{y}) = f(\\mb{y}) + \\nabla f(\\mb{y})^H\n     (\\mb{x} - \\mb{y}) + \\frac{L}{2} \\left\\| \\mb{x} - \\mb{y} \\right\\|_2^2 \\;,\n\n  with :math:`\\mb{x}` the potential new update and :math:`\\mb{y}` the\n  auxiliary extrapolation state. Note that this should only be used\n  with :class:`.AcceleratedPGM`.\n\n\nFor more details of these step size managers and how to specify them, see\nthe API reference page for :mod:`scico.optimize.pgm`.\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": "SCICO Documentation\n===================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: User Documentation\n\n   overview\n   inverse\n   advantages\n   install\n   classes\n   notes\n   examples\n   API Reference <_autosummary/scico.rst>\n   zreferences\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Developer Documentation\n\n   team\n   contributing\n   style\n\n\nIndices\n=======\n\n* :ref:`genindex`\n* :ref:`modindex`\n"
  },
  {
    "path": "docs/source/install.rst",
    "content": ".. _installing:\n\nInstalling SCICO\n================\n\nSCICO requires Python version 3.8 or later. (Version 3.12 is\nrecommended as it is the version under which SCICO is tested in GitHub\ncontinuous integration, and since the most recent versions of JAX require\nversion 3.10 or later.) SCICO is supported on both Linux and\nMacOS, but is not currently supported on Windows due to the limited\nsupport for ``jaxlib`` on Windows. However, Windows users can use\nSCICO via the `Windows Subsystem for Linux\n<https://docs.microsoft.com/en-us/windows/wsl/about>`_ (WSL). Guides\nexist for using WSL with\n`CPU only <https://docs.microsoft.com/en-us/windows/wsl/install-win10>`_\nand with\n`GPU support <https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl>`_.\n\nWhile not required, installation of SCICO and its dependencies within a\n`Conda <https://conda.io/projects/conda/en/latest/user-guide/index.html>`_\nenvironment is recommended.\n`Scripts <https://github.com/lanl/scico/tree/main/misc/conda>`_\nare provided for creating a\n`miniconda <https://docs.conda.io/en/latest/miniconda.html>`_\ninstallation and an environment including all primary SCICO dependencies\nas well as dependencies for usage example, testing, and building the\ndocumentation.\n\n\nFrom PyPI\n---------\n\nThe simplest way to install the most recent release of SCICO from\n`PyPI <https://pypi.python.org/pypi/scico/>`_ is\n::\n\n   pip install scico\n\nwhich will install SCICO and its primary dependencies. 
If the additional\ndependencies for the example scripts are also desired, it can instead be\ninstalled using\n::\n\n   pip install scico[examples]\n\nNote, however, that since the ``astra-toolbox`` package available from\nPyPI is not straightforward to install (it has numerous build requirements\nthat are not specified as package dependencies), it is recommended to\nfirst install this package via conda\n::\n\n   conda install astra-toolbox\n\n\n\nFrom conda-forge\n----------------\n\nSCICO can also be installed from `conda-forge <https://anaconda.org/conda-forge/scico>`_\n::\n\n  conda install -c conda-forge \"scico>0.0.5\"\n\nwhere the version constraint is required to avoid installation of an old\npackage with broken dependencies.\n\nNote, however, that installation from conda forge is only possible on a Linux\nplatform since there is no conda package for the secondary dependency\n``tensorstore`` under MacOS. There are also complications on Linux platforms\nwith Python versions 3.9 or earlier due to the automatic installation of a\nversion of secondary dependency ``etils`` that does not support Python versions\nearlier than 3.10. This can be rectified by\n::\n\n  conda install etils=1.5.1\n\nThe most recent SCICO conda forge package also includes dependencies for\nthe example scripts, except for ``bm3d``, ``bm4d``, and\n``colour_demosaicing``, for which conda packages are not available. These\ncan be installed from PyPI\n::\n\n  pip install bm3d bm4d colour_demosaicing\n\n\n\nFrom GitHub\n-----------\n\nThe development version of SCICO can be downloaded from the `GitHub repo\n<https://github.com/lanl/scico>`_. 
Note that, since the SCICO repo has\na submodule, it should be cloned via the command\n::\n\n   git clone --recurse-submodules git@github.com:lanl/scico.git\n\nInstall using the commands\n::\n\n   cd scico\n   pip install -r requirements.txt\n   pip install -e .\n\n\nIf a clone of the SCICO repository is not needed, it is simpler to\ninstall directly using ``pip``\n::\n\n   pip install git+https://github.com/lanl/scico\n\n\n\nGPU Support\n-----------\n\nThe instructions above install a CPU-only version of SCICO. To install\na version with GPU support:\n\n1. Follow the CPU-only instructions, above\n\n2. Install the version of jaxlib with GPU support, as described in the `JAX installation\n   instructions  <https://jax.readthedocs.io/en/latest/installation.html>`_.\n   In the simplest case, the appropriate command is\n   ::\n\n      pip install --upgrade \"jax[cuda12]\"\n\n   for CUDA 12, but it may be necessary to explicitly specify the\n   ``jaxlib`` version if the most recent release is not yet supported\n   by SCICO (as specified in the ``requirements.txt`` file).\n\n\nThe script\n`misc/gpu/envinfo.py <https://github.com/lanl/scico/blob/main/misc/gpu/envinfo.py>`_\nin the source distribution is provided as an aid to debugging GPU support\nissues. The script\n`misc/gpu/availgpu.py <https://github.com/lanl/scico/blob/main/misc/gpu/availgpu.py>`_\ncan be used to automatically recommend a setting of the CUDA_VISIBLE_DEVICES\nenvironment variable that excludes GPUs that are already in use.\n\n\n\nAdditional Dependencies\n-----------------------\n\nSee :ref:`example_depend` for instructions on installing dependencies\nrelated to the examples.\n\n\nFor Developers\n--------------\n\nSee :ref:`scico_dev_contributing` for instructions on installing a\nversion of SCICO suitable for development.\n"
  },
  {
    "path": "docs/source/inverse.rst",
    "content": "Inverse Problems\n================\n\nIn traditional imaging, the burden of image formation is placed on\nphysical components, such as a lens, with the resulting image being\ntaken from the sensor with minimal processing. In computational\nimaging, in contrast, the burden of image formation is shared with or\nshifted to computation, with the resulting image typically being very\ndifferent from the measured data. Common examples of computational\nimaging include demosaicing in consumer cameras, computed tomography\nand magnetic resonance imaging in medicine, and synthetic aperture\nradar in remote sensing. This is an active and growing area of\nresearch, and many of these problems have common properties that could\nbe supported by shared implementations of solution components.\n\nThe goal of SCICO is to provide a general research tool for\ncomputational imaging, with a particular focus on scientific imaging\napplications, which are particularly underrepresented in the existing\nrange of open-source packages in this area. While a number of other\npackages overlap somewhat in functionality with SCICO, only a few\nsupport execution of the same code on both CPU and GPU devices, and we\nare not aware of any that support just-in-time compilation and\nautomatic gradient computation, which is invaluable in computational\nimaging. SCICO provides all three of these valuable features (subject\nto some :ref:`caveats <non_jax_dep>`) by being built on top of `JAX\n<https://jax.readthedocs.io/en/latest/>`__ rather than `NumPy\n<https://numpy.org/>`__.\n\n\nThe remainder of this section outlines the steps involved in solving\nan inverse problem, and shows how each concept maps to a component of\nSCICO. 
More detail on the main classes involved in setting up and\nsolving an inverse problem can be found in :ref:`classes`.\n\n\nForward Modeling\n----------------\n\nIn order to solve a computational imaging problem we need to know how\nthe image we wish to reconstruct, :math:`\\mathbf{x}`, is related to the\ndata that we can measure, :math:`\\mathbf{y}`. This is represented via a\nmodel of the measurement process,\n\n.. math:: \\mathbf{y} = A(\\mathbf{x}) \\,.\n\nSCICO provides the :class:`.Operator` and :class:`.LinearOperator`\nclasses, which may be subclassed by users, in order to implement the\nforward operator, :math:`A`. It also has several built-in operators,\nmost of which are linear, e.g., finite convolutions, discrete Fourier\ntransforms, optical propagators, Abel transforms, and X-ray transforms\n(the same as Radon transforms in 2D). For example,\n\n.. code:: python\n\n       input_shape = (512, 512)\n       angles = np.linspace(0, 2 * np.pi, 180, endpoint=False)\n       channels = 512\n       A = scico.linop.xray.svmbir.XRayTransform(input_shape, angles, channels)\n\ndefines a tomographic projection operator.\n\nA significant advantage of SCICO being built on top of `JAX\n<https://jax.readthedocs.io/en/latest/>`__ is that the adjoints of\nlinear operators, which can be quite time consuming to implement even\nwhen the operator itself is straightforward, are computed\nautomatically by exploiting the automatic differentiation features of\n`JAX <https://jax.readthedocs.io/en/latest/>`__. If :code:`A` is a\n:class:`.LinearOperator`, then its adjoint is simply :code:`A.T` for\nreal transforms and :code:`A.H` for complex transforms. Likewise,\nJacobian-vector products can be automatically computed for non-linear\noperators, allowing for simple linearization and gradient\ncalculations.\n\nSCICO operators can be composed to construct new operators. (If both\noperands are linear, then the result is also linear.) 
For example, if\n:code:`A` and :code:`B` have been defined as distinct linear\noperators, then\n\n.. code:: python\n\n       C = B @ A\n\ndefines a new linear operator :code:`C` that first applies operator\n:code:`A` and then applies operator :code:`B` to the result\n(i.e. :math:`C = B A` in math notation). This operator algebra can be\nused to build complicated forward operators from simpler building\nblocks.\n\nSCICO also handles cases where either the image we want to\nreconstruct, :math:`\\mb{x}`, or its measurements, :math:`\\mb{y}`, do\nnot fit neatly into a multi-dimensional array. This is achieved via\n:class:`.BlockArray` objects, which consist of a :class:`list` of\nmulti-dimensional array *blocks*. A :class:`.BlockArray` differs from\na :class:`list` in that, whenever possible, :class:`.BlockArray`\nproperties and methods (including unary and binary operators like\n``+``, ``-``, ``*``, …) automatically map along the blocks, returning\nanother :class:`.BlockArray` or :class:`tuple` as appropriate. For\nexample, consider a system that measures the column sums and row sums\nof an image. If the input image has shape :math:`M \\times N`, the\nresulting measurement will have shape :math:`M + N`, which is awkward\nto represent as a multi-dimensional array. In SCICO, we can represent\nthis operator by\n\n.. code:: python\n\n       input_shape = (130, 50)\n       H0 = scico.linop.Sum(input_shape, axis=0)\n       H1 = scico.linop.Sum(input_shape, axis=1)\n       H = scico.linop.VerticalStack((H0, H1))\n\nThe result of applying ``H`` to an image with shape ``(130, 50)`` is a\n:class:`.BlockArray` with shape ``((50,), (130,))``. This result is\ncompatible with the rest of SCICO and may be used, e.g., as the input\nof other operators.\n\n\nInverse Problem Formulation\n---------------------------\n\nIn order to estimate the image from the measured data, we need to solve\nan *inverse problem*. 
In its simplest form, the solution to such an\ninverse problem can be expressed as the optimization problem\n\n.. math:: \\hat{\\mb{x}} = \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} f( \\mb{x} ) \\,,\n\nwhere :math:`\\mb{x}` is the unknown image and :math:`\\hat{\\mb{x}}` is\nthe recovered image. A common choice of :math:`f` is\n\n.. math:: f(\\mb{x}) = (1/2) \\| A(\\mb{x}) - \\mb{y} \\|_2^2 \\,,\n\nwhere :math:`\\mb{y}` is the measured data and :math:`A` is the\nforward operator; in this case the minimization problem is a least\nsquares problem.\n\nIn SCICO, the :mod:`.functional` module provides implementations of common\nfunctionals such as :math:`\\ell_2` and :math:`\\ell_1` norms. The\n:mod:`.loss` module is used to implement a special type of functional\n\n.. math:: f(\\mb{x}) = \\alpha l(A(\\mb{x}),\\mb{y}) \\,,\n\nwhere :math:`\\alpha` is a scaling parameter and :math:`l(\\cdot)` is\nanother functional. The SCICO :mod:`.loss` module contains a variety\nof loss functionals that are commonly used in computational\nimaging. For example, the squared :math:`\\ell_2` loss written above\nfor a forward operator, :math:`A`, can be defined in SCICO using the\ncode:\n\n.. code:: python\n\n       f = scico.loss.SquaredL2Loss(y=y, A=A)\n\nThe difficulty of the inverse problem depends on the amount of noise in\nthe measured data and the properties of the forward operator. In\nparticular, if :math:`A` is a linear operator, then the difficulty of\nthe inverse problem depends significantly on the condition number of\n:math:`A`, since a large condition number implies that large changes in\n:math:`\\mb{x}` can correspond to small changes in\n:math:`\\mb{y}`, making it difficult to estimate :math:`\\mb{x}`\nfrom :math:`\\mb{y}`. 
When there is a significant amount of\nmeasurement noise or ill-conditioning of :math:`A`, the standard\napproach to resolve the limitations in the information available from\nthe measured data is to introduce a *prior model* of the solution space,\nwhich is typically achieved by adding a *regularization term* to the\ndata fidelity term, resulting in the optimization problem\n\n.. math:: \\hat{\\mb{x}} = \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} f(\\mb{x}) + g(C (\\mb{x})) \\,,\n\nwhere the functional :math:`g(C(\\cdot))` is designed to increase the\ncost for solutions that are considered less likely or desirable, based\non prior knowledge of the properties of the solution space. A common\nchoice of :math:`g(C(\\cdot))` is the total variation norm\n\n.. math:: g(C \\mb{x}) = \\lambda \\| C \\mb{x} \\|_{2,1} \\,,\n\nwhere :math:`\\lambda` is a scalar controlling the regularization\nstrength, :math:`C` is a linear operator that computes the spatial\ngradients of its argument, and :math:`\\| \\cdot \\|_{2,1}` denotes the\n:math:`\\ell_{2,1}` norm, which promotes group sparsity. Use of this\nfunctional as a regularization term corresponds to the assumption that\nthe images of interest are piecewise constant. In SCICO, we can\nrepresent this regularization functional using a built-in linear\noperator and a member of the :mod:`.functional` module:\n\n.. code:: python\n\n       C = scico.linop.FiniteDifference(A.input_shape, append=0)\n       λ = 1.0e-1\n       g = λ * scico.functional.L21Norm()\n\nComputing the value of the regularizer then closely matches the math:\n:code:`g(C(x))`.\n\nFinally, the overall objective function needs to be optimized. 
One of\nthe primary goals of SCICO is to make the solution of such problems\naccessible to application domain scientists with limited expertise in\ncomputational imaging, providing infrastructure for solving this type of\nproblem efficiently, without the need for the user to implement complex\nalgorithms.\n\n\nSolvers\n-------\n\nOnce an inverse problem has been specified using the above components,\nthe resulting functional must be minimized in order to solve the\nproblem. SCICO provides a number of optimization algorithms for\naddressing a wide range of problems. These optimization algorithms\nbelong to two distinct categories.\n\n\nBasic Solvers\n~~~~~~~~~~~~~\n\nThe :mod:`scico.solver` module provides a number of functions for\nsolving linear systems and simple optimization problems, some of which\nare useful as subproblem solvers within the proximal algorithms\ndescribed in the following section. It also provides an interface to\nfunctions in :mod:`scipy.optimize`, supporting their use with\nmulti-dimensional arrays and scico :class:`.Functional` objects. These\nalgorithms are useful both as subproblem solvers within the proximal\nalgorithms described below, as well as for direct solution of\nhigher-level problems.\n\nFor example,\n\n.. code:: python\n\n       f = scico.loss.PoissonLoss(y=y, A=A)\n       method = 'BFGS' # or any method available for scipy.optimize.minimize\n       x0 = scico.numpy.ones(A.input_shape)\n       res = scico.solver.minimize(f, x0=x0, method=method)\n       x_hat = res.x\n\ndefines a Poisson objective function and minimizes it using the BFGS\n:cite:`nocedal-2006-numerical` algorithm.\n\n\nProximal Algorithms\n~~~~~~~~~~~~~~~~~~~\n\nThe :mod:`scico.optimize` sub-package provides a set of *proximal\nalgorithms* :cite:`parikh-2014-proximal` that have proven to be useful\nfor solving imaging inverse problems. The common feature of these\nalgorithms is their exploitation of the *proximal operator*\n:cite:`beck-2017-first` (Ch. 
6), of the components of the functions\nthat they minimize.\n\n**ADMM** The most flexible of the proximal algorithms supported by SCICO\nis the alternating direction method of multipliers (ADMM)\n:cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual`\n:cite:`boyd-2010-distributed`, which supports solving problems of the form\n\n.. math:: \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} \\; f(\\mb{x}) + \\sum_{i=1}^N g_i(C_i \\mb{x}) \\,.\n\nWhen :math:`f(\\cdot)` is an instance of ``scico.loss.SquaredL2Loss``,\ni.e.,\n\n.. math:: f(\\mb{x}) = (1/2) \\| A \\mb{x} - \\mb{y} \\|_2^2 \\,,\n\nfor linear operator :math:`A` and constant vector :math:`\\mb{y}`,\nthe primary computational cost of the algorithm is typically in solving\na linear system involving a weighted sum of :math:`A^\\top A` and the\n:math:`C_i^\\top C_i`, assuming that the proximal operators of the\nfunctionals :math:`g_i(\\cdot)` can be computed efficiently. This linear\nsystem can also be solved efficiently when :math:`A` and all of the\n:math:`C_i` are either identity operators or circular convolutions.\n\n**Proximal ADMM** Proximal ADMM :cite:`deng-2015-global` solves problems of\nthe form\n\n.. math::\n    \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n    \\text{such that}\\; A \\mb{x} + B \\mb{z} = \\mb{c} \\;,\n\nwhere :math:`A` and :math:`B` are linear operators. There is also a non-linear\nPADMM solver :cite:`benning-2016-preconditioned` for problems of the form\n\n.. math::\n    \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n    \\text{such that}\\; H(\\mb{x}, \\mb{z}) = 0 \\;,\n\nwhere :math:`H` is a function. For some problems, proximal ADMM converges\nsubstantially faster than ADMM or linearized ADMM.\n\n**Linearized ADMM** Linearized ADMM :cite:`yang-2012-linearized`\n:cite:`parikh-2014-proximal` solves a more restricted problem form,\n\n.. 
math:: \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\,.\n\nIt is an effective algorithm when the proximal operators of both\n:math:`f(\\cdot)` and :math:`g(\\cdot)` can be computed efficiently, and\nhas the advantage over \"standard\" ADMM of avoiding the need for solving\na linear system involving :math:`C^\\top C`.\n\n**PDHG** Primal–dual hybrid gradient (PDHG) :cite:`esser-2010-general`\n:cite:`chambolle-2010-firstorder` :cite:`pock-2011-diagonal` solves\nthe same form of problem as linearized ADMM\n\n.. math:: \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\,,\n\nbut unlike the linearized ADMM implementation, both linear and\nnon-linear operators :math:`C` are supported. For some problems, PDHG\nconverges substantially faster than ADMM or linearized ADMM.\n\n**PGM and Accelerated PGM** The proximal gradient method (PGM)\n:cite:`daubechies-2004-iterative` and accelerated proximal gradient method\n(APGM), which is also known as FISTA :cite:`beck-2017-first`, solve problems\nof the form\n\n.. math:: \\mathop{\\mathrm{arg\\,min}}_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{x}) \\,,\n\nwhere :math:`f(\\cdot)` is assumed to be differentiable, and\n:math:`g(\\cdot)` is assumed to have a proximal operator that can be\ncomputed efficiently. These algorithms typically require more iterations\nfor convergence than ADMM, but can provide faster convergence with time\nwhen the linear solve required by ADMM is slow to compute.\n\n\nMachine Learning\n----------------\n\nWhile relatively simple regularization terms such as the total\nvariation norm can be effective when the underlying assumptions are\nwell matched to the data (e.g., the reconstructed images for certain\nmaterials science applications really are approximately piecewise\nconstant), it is difficult to design mathematically simple\nregularization terms that adequately represent the properties of the\ncomplex data that is often encountered in practice. 
A widely-used\nalternative framework for regularizing the solution of imaging inverse\nproblems is *plug-and-play priors* (PPP)\n:cite:`venkatakrishnan-2013-plugandplay2` :cite:`sreehari-2016-plug`\n:cite:`kamilov-2023-plugandplay`, which provides a mechanism for\nexploiting image denoisers such as BM3D :cite:`dabov-2008-image` as\nimplicit priors. With the rise of deep learning methods, PPP provided\none of the first frameworks for applying machine learning methods to\ninverse problems via the use of learned denoisers such as DnCNN\n:cite:`zhang-2017-dncnn`.\n\nSCICO supports PPP inverse problems solutions with both BM3D and DnCNN\ndenoisers, and provides usage examples for both choices. BM3D is more\nflexible, as it includes a tunable noise level parameter, while SCICO\nonly includes DnCNN models trained at three different noise levels (as\nin the original DnCNN paper), but DnCNN has a significant speed\nadvantage when GPUs are available. As an example, the following code\noutline demonstrates a PPP solution, with a non-negativity constraint\nand a 17-layer DnCNN denoiser as a regularizer, of an inverse problem\nwith measurement, :math:`\\mb{y}`, and a generic linear forward\noperator, :math:`A`.\n\n.. code:: python\n\n       ρ = 0.3  # ADMM penalty parameter\n       maxiter = 10 # number of ADMM iterations\n\n       f = scico.loss.SquaredL2Loss(y=y, A=A)\n       g1 = scico.functional.DnCNN(\"17M\")\n       g2 = scico.functional.NonNegativeIndicator()\n       C = scico.linop.Identity(A.input_shape)\n\n       solver = scico.optimize.admm.ADMM(\n         f=f,\n         g_list=[g1, g2],\n         C_list=[C, C],\n         rho_list=[ρ, ρ],\n         x0=A.T @ y,\n         maxiter=maxiter,\n         subproblem_solver=scico.optimize.admm.LinearSubproblemSolver(),\n         itstat_options={\"display\": True, \"period\": 5},\n       )\n\n       x_hat = solver.solve()\n\nExample results for this type of approach applied to image deconvolution\n(i.e. 
with forward operator, :math:`A`, as a convolution) are shown in\nthe figure below.\n\n.. image:: /figures/deconv_ppp_dncnn.png\n     :align: center\n     :width: 95%\n     :alt: Image deconvolution via PPP with DnCNN denoiser.\n\n|\n\nMore recently, a wider variety of frameworks have been developed for\napplying deep learning methods to inverse problems, including the\napplication of the adjoint of the forward operator to map the\nmeasurement to the solution space followed by an artifact removal CNN\n:cite:`jin-2017-unet`, and learned networks with structures based on\nthe unrolling of iterative algorithms such as PPP\n:cite:`monga-2021-algorithm`. A number of these methods are currently\nbeing implemented, and will be included in a future SCICO release. It\nis worth noting, however, that while some of these methods offer\nsuperior performance to PPP, it is at the cost of having to train the\nmodels with problem-specific data, which may be difficult to obtain,\nwhile PPP is often able to function well with a denoiser trained on\ngeneric image data.\n"
  },
  {
    "path": "docs/source/notes.rst",
    "content": "*****\nNotes\n*****\n\n\nDebugging\n=========\n\nIf difficulties are encountered in debugging jitted functions, jit can\nbe globally disabled by setting the environment variable\n``JAX_DISABLE_JIT=1`` before running Python, as in\n\n::\n\n   JAX_DISABLE_JIT=1 python test_script.py\n\n\nDouble Precision\n================\n\nBy default, JAX enforces single-precision numbers. Double precision\ncan be enabled in one of two ways:\n\n1. Setting the environment variable ``JAX_ENABLE_X64=TRUE`` before\n   launching Python.\n2. Manually setting the ``jax_enable_x64`` flag **at program\n   startup**; that is, **before** importing SCICO.\n\n::\n\n   from jax.config import config\n   config.update(\"jax_enable_x64\", True)\n   import scico # continue as usual\n\n\nFor more information, see the `JAX notes on double precision <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_.\n\nDevice Control\n==============\n\nUse of the CPU device can be forced even when GPUs are present by setting the\nenvironment variable ``JAX_PLATFORM_NAME=cpu`` before running Python. This also\nserves to disable the warning that older versions of JAX issued when running\non a platform without a GPU, but this should no longer be necessary for any\nJAX versions supported by SCICO.\n\nBy default, JAX views a multi-core CPU as a single device. Primarily for testing\npurposes, it may be useful to instruct JAX to emulate multiple CPU devices, by\nsetting the environment variable ``XLA_FLAGS='--xla_force_host_platform_device_count=<n>'``,\nwhere ``<n>`` is an integer number of devices. For more detail see the relevant\n`section of the JAX docs <https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#aside-hosts-and-devices-in-jax>`__.\n\nBy default, JAX will preallocate a large chunk of GPU memory on startup. 
This\nbehavior can be controlled using environment variables ``XLA_PYTHON_CLIENT_PREALLOCATE``,\n``XLA_PYTHON_CLIENT_MEM_FRACTION``, and ``XLA_PYTHON_CLIENT_ALLOCATOR``, as described in\nthe relevant `section of the JAX docs <https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html>`__.\n\n\nRandom Number Generation\n========================\n\nJAX implements an explicit, non-stateful pseudorandom number generator (PRNG).\nThe user is responsible for generating a PRNG key and mutating it each time a\nnew random number is generated. We recommend users read the `JAX documentation\n<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers>`_\nfor information on the design of JAX random number functionality.\n\n\nIn :mod:`scico.random` we provide convenient wrappers around several `jax.random\n<https://jax.readthedocs.io/en/stable/jax.random.html>`_ routines to handle\nthe generation and splitting of PRNG keys.\n\n::\n\n   # Calls to scico.random functions always return a PRNG key\n   # If no key is passed to the function, a new key is generated\n   x, key = scico.random.randn((2,))\n   print(x)   # [ 0.19307713 -0.52678305]\n\n   # scico.random functions automatically split the PRNGkey and return\n   # an updated key\n   y, key = scico.random.randn((2,), key=key)\n   print(y) # [ 0.00870693 -0.04888531]\n\nThe user is responsible for passing the PRNG key to\n:mod:`scico.random` functions. If no key is passed, repeated calls to\n:mod:`scico.random` functions will return the same random numbers:\n\n::\n\n   x, key = scico.random.randn((2,))\n   print(x)   # [ 0.19307713 -0.52678305]\n\n   # No key passed, will return the same random numbers!\n   y, key = scico.random.randn((2,))\n   print(y)   # [ 0.19307713 -0.52678305]\n\n\n\n.. 
_non_jax_dep:\n\nCompiled Dependency Packages\n============================\n\nThe code acceleration and automatic differentiation features of JAX\nare not available for some components of SCICO that are provided via\ninterfaces to compiled C code. When these components are used on a\nplatform with GPUs, the remainder of the code will run on a GPU, but\nthere is potential for a considerable delay due to host-GPU memory\ntransfers. This issue primarily affects:\n\n\nDenoisers\n---------\n\nThe :func:`.bm3d` and :func:`.bm4d` denoisers (and the corresponding\n:class:`.BM3D` and :class:`.BM4D` pseudo-functionals) are implemented\nvia interfaces to the `bm3d <https://pypi.org/project/bm3d/>`__ and\n`bm4d <https://pypi.org/project/bm4d/>`__ packages respectively. The\n:class:`~.denoiser.DnCNN` denoiser (and the corresponding\n:class:`~.functional.DnCNN` pseudo-functional) should be used instead\nwhen the full benefits of JAX-based code are required.\n\n\nTomographic Projectors/Radon Transforms\n---------------------------------------\n\nNote that the tomographic projections that are frequently referred\nto as Radon transforms are referred to as X-ray transforms in SCICO.\nWhile the Radon transform is far better known than the X-ray\ntransform, which is the same as the Radon transform for projections\nin two dimensions, these two transforms differ in higher numbers of\ndimensions, and it is the X-ray transform that is the appropriate\nmathematical model for beam-attenuation-based imaging in three or\nmore dimensions.\n\nSCICO includes three different implementations of X-ray transforms.\nOf these, :class:`.linop.XRayTransform` is an integral component of\nSCICO, while the other two depend on external packages.\nThe :class:`.xray.svmbir.XRayTransform` class is implemented\nvia an interface to the `svmbir\n<https://svmbir.readthedocs.io/en/latest/>`__ package. 
The\n:class:`.xray.astra.XRayTransform2D` and\n:class:`.xray.astra.XRayTransform3D` classes are implemented via an\ninterface to the `ASTRA toolbox\n<https://www.astra-toolbox.com/>`__. This toolbox does provide some\nGPU acceleration support, but efficiency is expected to be lower than\nthat of JAX-based code due to host-GPU memory transfers.\n\n\nAutomatic Differentiation Caveats\n=================================\n\n\nComplex Functions\n-----------------\n\nThe JAX-defined gradient of a complex-valued function is a\ncomplex-conjugated version of the usual gradient used in mathematical\noptimization and computational imaging. Minimizing a function using\nthe JAX convention involves taking steps in the direction of the\ncomplex conjugated gradient.\n\nThe function :func:`scico.grad` returns the expected gradient, that\nis, the conjugate of the JAX gradient. For further discussion, see\nthis `JAX issue <https://github.com/google/jax/issues/4891>`_.\n\nAs a concrete example, consider the function :math:`f(\\mb{x}) =\n\\frac{1}{2}\\norm{\\mb{A} \\mb{x}}_2^2` where :math:`\\mb{A}` is a complex\nmatrix. The gradient of :math:`f` is usually given by :math:`(\\nabla\nf)(\\mb{x}) = \\mb{A}^H \\mb{A} \\mb{x}`, where :math:`\\mb{A}^H` is the\nconjugate transpose of :math:`\\mb{A}`. 
Applying :func:`jax.grad` to\n:math:`f` will yield :math:`(\\mb{A}^H \\mb{A} \\mb{x})^*`, where\n:math:`\\cdot^*` denotes complex conjugation.\n\nThe following code demonstrates the use of :func:`jax.grad` and\n:func:`scico.grad`:\n\n\n::\n\n    m, n = (4, 3)\n    A, key = randn((m, n), dtype=np.complex64, key=None)\n    x, key = randn((n,), dtype=np.complex64, key=key)\n\n    def f(x):\n        return 0.5 * snp.linalg.norm(A @ x)**2\n\n    an_grad = A.conj().T @ A @ x  # The expected gradient\n\n    np.testing.assert_allclose(jax.grad(f)(x), an_grad.conj(), rtol=1e-4)\n    np.testing.assert_allclose(scico.grad(f)(x), an_grad, rtol=1e-4)\n\n\nNon-differentiable Functionals\n------------------------------\n\n:func:`scico.grad` can be applied to any function, but has undefined\nbehavior for non-differentiable functions. For non-differentiable\nfunctions, :func:`scico.grad` may or may not return a valid\nsubgradient. As an example, ``scico.grad(snp.abs)(0.) = 0``, which is\na valid subgradient. However, ``scico.grad(snp.linalg.norm)([0., 0.])\n= [nan, nan]``.\n\nDifferentiable functions that are written as the composition of a\ndifferentiable and non-differentiable function should be avoided. As\nan example, :math:`f(x) = \\norm{x}_2^2` can be implemented as ``f =\nlambda x: snp.linalg.norm(x)**2``. This involves first calculating the\nun-squared :math:`\\ell_2` norm, then squaring it. The un-squared\n:math:`\\ell_2` norm is not differentiable at zero. 
When evaluating\nthe gradient of ``f`` at 0, :func:`scico.grad` returns :data:`~numpy.NaN`:\n\n::\n\n   >>> import scico\n   >>> import scico.numpy as snp\n   >>> f = lambda x: snp.linalg.norm(x)**2\n   >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32))  # doctest: +SKIP\n   Array([nan, nan], dtype=float32)\n\nThis can be fixed (assuming real-valued arrays only) by defining the\nsquared :math:`\\ell_2` norm directly as ``g = lambda x: snp.sum(x**2)``.\nThe gradient will work as expected:\n\n::\n\n   >>> g = lambda x: snp.sum(x**2)\n   >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))  #doctest: +SKIP\n   Array([0., 0.], dtype=float32)\n\nIf complex-valued arrays also need to be supported, a minor modification is\nnecessary:\n\n::\n\n   >>> g = lambda x: snp.sum(snp.abs(x)**2)\n   >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))  #doctest: +SKIP\n   Array([0., 0.], dtype=float32)\n   >>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64))  #doctest: +SKIP\n   Array([0.-0.j, 0.-0.j], dtype=complex64)\n\n\nAn alternative is to define a `custom derivative rule\n<https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_\nto enforce a particular derivative convention at a point.\n\n\nJAX Arrays\n==========\n\nJAX utilizes a new array type :class:`~jax.Array`, which is similar to\nNumPy :class:`~numpy.ndarray`, but can be backed by CPU, GPU, or TPU\nmemory and is immutable.\n\n\nJAX and NumPy Arrays\n--------------------\n\nSCICO and JAX functions can be applied directly to NumPy arrays\nwithout explicit conversion to JAX arrays, but this is not\nrecommended, as it can result in repeated data transfers from the CPU\nto GPU. 
Consider this toy example on a system with a GPU present:\n\n::\n\n   x = np.random.randn(8)    # Array on host\n   A = np.random.randn(8, 8) # Array on host\n   y = snp.dot(A, x)         # A, x transferred to GPU\n                             # y resides on GPU\n   z = y + x                 # x must be transferred to GPU again\n\n\nThe unnecessary transfer can be avoided by first converting ``A`` and ``x`` to\nJAX arrays:\n\n::\n\n   x = np.random.randn(8)    # array on host\n   A = np.random.randn(8, 8) # array on host\n   x = jax.device_put(x)     # transfer to GPU\n   A = jax.device_put(A)\n   y = snp.dot(A, x)         # no transfer needed\n   z = y + x                 # no transfer needed\n\n\nWe recommend that input data be converted to JAX arrays via\n:func:`jax.device_put` before calling any SCICO optimizers.\n\nOn a multi-GPU system, :func:`jax.device_put` can place data on a specific\nGPU. See the `JAX notes on data placement\n<https://jax.readthedocs.io/en/latest/faq.html?highlight=data%20placement#controlling-data-and-computation-placement-on-devices>`_.\n\n\nJAX Arrays are Immutable\n------------------------\n\nUnlike standard NumPy arrays, JAX arrays are immutable: once they have\nbeen created, they cannot be changed. This prohibits in-place updating\nof JAX arrays. JAX provides special syntax for updating individual\narray elements through the `indexed update operators\n<https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_\nof the ``.at`` property (e.g. ``x = x.at[0].set(1.0)``), which return\na new, updated array.\n"
  },
  {
    "path": "docs/source/overview.rst",
    "content": "Overview\n========\n\n`Scientific Computational Imaging Code (SCICO)\n<https://github.com/lanl/scico>`__ is a Python package for solving the\ninverse problems that arise in scientific imaging applications. Its\nprimary focus is providing methods for solving ill-posed inverse\nproblems by using an appropriate prior model of the reconstruction\nspace. SCICO includes a growing suite of operators, cost functionals,\nregularizers, and optimization algorithms that may be combined to\nsolve a wide range of problems, and is designed so that it is easy to\nadd new building blocks. When solving a problem, these components are\ncombined in a way that makes code for optimization routines look like\nthe pseudocode in scientific papers. SCICO is built on top of `JAX\n<https://jax.readthedocs.io/en/latest/>`__ rather than `NumPy\n<https://numpy.org/>`__, enabling GPU/TPU acceleration, just-in-time\ncompilation, and automatic gradient functionality, which is used to\nautomatically compute the adjoints of linear operators. An example of\nhow to solve a multi-channel tomography problem with SCICO is shown in\nthe figure below.\n\n\n.. image:: /figures/scico-tomo-overview.png\n     :align: center\n     :width: 95%\n     :alt: Solving a multi-channel tomography problem with SCICO.\n\n|\n\nThe SCICO source code is available from `GitHub\n<https://github.com/lanl/scico>`__, and pre-built packages are\navailable from `PyPI <https://pypi.org/project/scico/>`__. 
(Detailed\ninstructions for installing SCICO are available in :ref:`installing`.)\nIt has extensive `online documentation <https://scico.rtfd.io/>`__,\nincluding :doc:`API documentation <_autosummary/scico>` and\n:ref:`usage examples <example_notebooks>`, which can be run online at\n`Google Colab\n<https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb>`__\nand `Binder\n<https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb>`__.\n\n\nIf you use this package for published work, please cite\n:cite:`balke-2022-scico` (see BibTeX entry ``balke-2022-scico`` in\n`docs/source/references.bib\n<https://github.com/lanl/scico/blob/main/docs/source/references.bib>`_\nin the source distribution).\n\n\n\nContributing\n------------\n\nBug reports, feature requests, and general suggestions are welcome,\nand should be submitted via the `GitHub issue system\n<https://github.com/lanl/scico/issues>`__. More substantial\ncontributions are also :ref:`welcome <scico_dev_contributing>`.\n\n\n\nLicense\n-------\n\nSCICO is distributed as open-source software under a BSD 3-Clause\nLicense (see the `LICENSE\n<https://github.com/lanl/scico/blob/main/LICENSE>`__ file for\ndetails). LANL open source approval reference C20091.\n\n© 2020-2025. Triad National Security, LLC. All rights reserved.\nThis program was produced under U.S. Government contract\n89233218CNA000001 for Los Alamos National Laboratory (LANL), which is\noperated by Triad National Security, LLC for the U.S. Department of\nEnergy/National Nuclear Security Administration.  All rights in the\nprogram are reserved by Triad National Security, LLC, and the\nU.S. 
Department of Energy/National Nuclear Security Administration.\nThe Government has granted for itself and others acting on its behalf\na nonexclusive, paid-up, irrevocable worldwide license in this\nmaterial to reproduce, prepare derivative works, distribute copies to\nthe public, perform publicly and display publicly, and to permit\nothers to do so.\n"
  },
  {
    "path": "docs/source/pyfigures/cylindgrad.py",
    "content": "import numpy as np\n\nimport scico.linop as scl\nfrom scico import plot\n\ninput_shape = (7, 7, 7)\ncentre = (np.array(input_shape) - 1) / 2\nend = np.array(input_shape) - centre\ng0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]\n\ncg = scl.CylindricalGradient(input_shape=input_shape)\n\nang = cg.coord[0]\nrad = cg.coord[1]\naxi = cg.coord[2]\n\ntheta = np.arctan2(g0, g1)\nclr = theta\n# See https://stackoverflow.com/a/49888126\nclr = (clr.ravel() - clr.min()) / np.ptp(clr)\nclr = np.concatenate((clr, np.repeat(clr, 2)))\nclr = plot.plt.cm.plasma(clr)\n\nplot.plt.rcParams[\"savefig.transparent\"] = True\n\nfig = plot.plt.figure(figsize=(20, 6))\nax = fig.add_subplot(1, 3, 1, projection=\"3d\")\nax.quiver(g0, g1, g2, ang[0], ang[1], ang[2], colors=clr, length=0.9)\nax.set_title(\"Angular local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nax = fig.add_subplot(1, 3, 2, projection=\"3d\")\nax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)\nax.set_title(\"Radial local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nax = fig.add_subplot(1, 3, 3, projection=\"3d\")\nax.quiver(g0, g1, g2, axi[0], axi[1], axi[2], colors=clr[0], length=0.9)\nax.set_title(\"Axial local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nfig.tight_layout()\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/polargrad.py",
    "content": "import numpy as np\n\nimport scico.linop as scl\nfrom scico import plot\n\ninput_shape = (21, 21)\ncentre = (np.array(input_shape) - 1) / 2\nend = np.array(input_shape) - centre\ng0, g1 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1]]\n\npg = scl.PolarGradient(input_shape=input_shape)\n\nang = pg.coord[0]\nrad = pg.coord[1]\n\nclr = (np.arctan2(ang[1], ang[0]) + np.pi) / (2 * np.pi)\n\nplot.plt.rcParams[\"image.cmap\"] = \"plasma\"\nplot.plt.rcParams[\"savefig.transparent\"] = True\n\nfig, ax = plot.plt.subplots(nrows=1, ncols=2, figsize=(13, 6))\nax[0].quiver(g0, g1, ang[0], ang[1], clr)\nax[0].set_title(\"Angular local coordinate axis\", fontsize=16)\nax[0].set_xlabel(\"$x$\", fontsize=14)\nax[0].set_ylabel(\"$y$\", fontsize=14)\nax[0].tick_params(labelsize=14)\nax[0].xaxis.set_ticks((-10, -5, 0, 5, 10))\nax[0].yaxis.set_ticks((-10, -5, 0, 5, 10))\nax[1].quiver(g0, g1, rad[0], rad[1], clr)\nax[1].set_title(\"Radial local coordinate axis\", fontsize=16)\nax[1].set_xlabel(\"$x$\", fontsize=14)\nax[1].set_ylabel(\"$y$\", fontsize=14)\nax[1].tick_params(labelsize=14)\nax[1].xaxis.set_ticks((-10, -5, 0, 5, 10))\nax[1].yaxis.set_ticks((-10, -5, 0, 5, 10))\nfig.tight_layout()\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/spheregrad.py",
    "content": "import numpy as np\n\nimport scico.linop as scl\nfrom scico import plot\n\ninput_shape = (7, 7, 7)\ncentre = (np.array(input_shape) - 1) / 2\nend = np.array(input_shape) - centre\ng0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]\n\nsg = scl.SphericalGradient(input_shape=input_shape)\n\nazi = sg.coord[0]\npol = sg.coord[1]\nrad = sg.coord[2]\n\ntheta = np.arctan2(g0, g1)\nphi = np.arctan2(np.sqrt(g0**2 + g1**2), g2)\nclr = theta * phi\n# See https://stackoverflow.com/a/49888126\nclr = (clr.ravel() - clr.min()) / np.ptp(clr)\nclr = np.concatenate((clr, np.repeat(clr, 2)))\nclr = plot.plt.cm.plasma(clr)\n\nplot.plt.rcParams[\"savefig.transparent\"] = True\n\nfig = plot.plt.figure(figsize=(20, 6))\nax = fig.add_subplot(1, 3, 1, projection=\"3d\")\nax.quiver(g0, g1, g2, azi[0], azi[1], azi[2], colors=clr, length=0.9)\nax.set_title(\"Azimuthal local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nax = fig.add_subplot(1, 3, 2, projection=\"3d\")\nax.quiver(g0, g1, g2, pol[0], pol[1], pol[2], colors=clr, length=0.9)\nax.set_title(\"Polar local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nax = fig.add_subplot(1, 3, 3, projection=\"3d\")\nax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)\nax.set_title(\"Radial local coordinate axis\", fontsize=18)\nax.set_xlabel(\"$x$\", fontsize=15)\nax.set_ylabel(\"$y$\", fontsize=15)\nax.set_zlabel(\"$z$\", fontsize=15)\nax.tick_params(labelsize=15)\nfig.tight_layout()\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/xray_2d_geom.py",
    "content": "import numpy as np\n\nimport matplotlib as mpl\nimport matplotlib.patches as patches\nimport matplotlib.pyplot as plt\n\nmpl.rcParams[\"savefig.transparent\"] = True\n\n\nc = 1.0 / np.sqrt(2.0)\ne = 1e-2\nstyle = \"Simple, tail_width=0.5, head_width=4, head_length=8\"\nfig, ax = plt.subplots(nrows=1, ncols=3, figsize=(21, 7))\n\n# all plots\nfor n in range(3):\n    ax[n].set_aspect(1.0)\n    ax[n].set_xlim(-1.1, 1.1)\n    ax[n].set_ylim(-1.1, 1.1)\n    ax[n].set_xticks(np.linspace(-1.0, 1.0, 5))\n    ax[n].set_yticks(np.linspace(-1.0, 1.0, 5))\n    ax[n].tick_params(axis=\"x\", labelsize=14)\n    ax[n].tick_params(axis=\"y\", labelsize=14)\n    ax[n].set_xlabel(\"axis 1\", fontsize=16)\n    ax[n].set_ylabel(\"axis 0\", fontsize=16)\n\n\n# scico\nax[0].set_title(\"scico\", fontsize=18)\nplist = [\n    patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch((-c, -c), (-c / 2.0, -c / 2.0), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch(\n        (\n            0.0,\n            -1.0,\n        ),\n        (0.0, -0.5),\n        arrowstyle=style,\n        color=\"r\",\n    ),\n    patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=180, theta2=-45.0, color=\"b\", lw=2, ls=\"dotted\"),\n    patches.FancyArrowPatch((c - e, -c - e), (c + e, -c + e), arrowstyle=style, color=\"b\"),\n]\nfor p in plist:\n    ax[0].add_patch(p)\n\nax[0].text(-0.88, 0.02, r\"$\\theta=0$\", color=\"r\", fontsize=16)\nax[0].text(-3 * c / 4 - 0.01, -3 * c / 4 - 0.1, r\"$\\theta=\\frac{\\pi}{4}$\", color=\"r\", fontsize=16)\nax[0].text(0.03, -0.8, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"r\", fontsize=16)\n\nax[0].plot((1.0, 1.0), (-0.375, 0.375), color=\"orange\", lw=2)\nax[0].arrow(\n    0.94,\n    0.375,\n    0.0,\n    -0.75,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[0].text(0.7, 0.0, r\"$\\theta=0$\", color=\"orange\", ha=\"left\", 
fontsize=16)\nax[0].plot((-0.375, 0.375), (1.0, 1.0), color=\"orange\", lw=2)\nax[0].arrow(\n    -0.375,\n    0.94,\n    0.75,\n    0.0,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[0].text(0.0, 0.82, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"orange\", ha=\"center\", fontsize=16)\n\n\n# astra\nax[1].set_title(\"astra\", fontsize=18)\nplist = [\n    patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color=\"r\"),\n    patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color=\"b\", lw=2, ls=\"dotted\"),\n    patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color=\"b\"),\n]\nfor p in plist:\n    ax[1].add_patch(p)\n\nax[1].text(0.02, -0.75, r\"$\\theta=0$\", color=\"r\", fontsize=16)\nax[1].text(3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r\"$\\theta=\\frac{\\pi}{4}$\", color=\"r\", fontsize=16)\nax[1].text(0.65, 0.05, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"r\", fontsize=16)\n\nax[1].plot((-0.375, 0.375), (1.0, 1.0), color=\"orange\", lw=2)\nax[1].arrow(\n    -0.375,\n    0.94,\n    0.75,\n    0.0,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[1].text(0.0, 0.82, r\"$\\theta=0$\", color=\"orange\", ha=\"center\", fontsize=16)\nax[1].plot((-1.0, -1.0), (-0.375, 0.375), color=\"orange\", lw=2)\nax[1].arrow(\n    -0.94,\n    -0.375,\n    0.0,\n    0.75,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[1].text(-0.9, 0.0, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"orange\", ha=\"left\", fontsize=16)\n\n\n# svmbir\nax[2].set_title(\"svmbir\", fontsize=18)\nplist = [\n    patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color=\"r\"),\n    
patches.FancyArrowPatch((-c, c), (-c / 2.0, c / 2.0), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch(\n        (\n            0.0,\n            1.0,\n        ),\n        (0.0, 0.5),\n        arrowstyle=style,\n        color=\"r\",\n    ),\n    patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=45, theta2=180, color=\"b\", lw=2, ls=\"dotted\"),\n    patches.FancyArrowPatch((c - e, c + e), (c + e, c - e), arrowstyle=style, color=\"b\"),\n]\nfor p in plist:\n    ax[2].add_patch(p)\nax[2].text(-0.88, 0.02, r\"$\\theta=0$\", color=\"r\", fontsize=16)\nax[2].text(-3 * c / 4 + 0.01, 3 * c / 4 + 0.01, r\"$\\theta=\\frac{\\pi}{4}$\", color=\"r\", fontsize=16)\nax[2].text(0.03, 0.75, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"r\", fontsize=16)\n\nax[2].plot((1.0, 1.0), (-0.375, 0.375), color=\"orange\", lw=2)\nax[2].arrow(\n    0.94,\n    0.375,\n    0.0,\n    -0.75,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[2].text(0.7, 0.0, r\"$\\theta=0$\", color=\"orange\", ha=\"left\", fontsize=16)\n\nax[2].plot((-0.375, 0.375), (-1.0, -1.0), color=\"orange\", lw=2)\nax[2].arrow(\n    0.375,\n    -0.94,\n    -0.75,\n    0.0,\n    color=\"orange\",\n    lw=1.0,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax[2].text(0.0, -0.82, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"orange\", ha=\"center\", fontsize=16)\n\n\nfig.tight_layout()\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/xray_3d_ang.py",
    "content": "import numpy as np\n\nimport matplotlib as mpl\nimport matplotlib.patches as patches\nimport matplotlib.pyplot as plt\n\nmpl.rcParams[\"savefig.transparent\"] = True\n\n\nc = 1.0 / np.sqrt(2.0)\ne = 1e-2\nstyle = \"Simple, tail_width=0.5, head_width=4, head_length=8\"\nfig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))\nax.set_aspect(1.0)\nax.set_xlim(-1.1, 1.1)\nax.set_ylim(-1.1, 1.1)\nax.set_xticks(np.linspace(-1.0, 1.0, 5))\nax.set_yticks(np.linspace(-1.0, 1.0, 5))\nax.tick_params(axis=\"x\", labelsize=12)\nax.tick_params(axis=\"y\", labelsize=12)\nax.set_xlabel(\"$x$\", fontsize=14)\nax.set_ylabel(\"$y$\", fontsize=14)\n\nplist = [\n    patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color=\"r\"),\n    patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color=\"r\"),\n    patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color=\"b\", lw=2, ls=\"dotted\"),\n    patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color=\"b\"),\n]\nfor p in plist:\n    ax.add_patch(p)\nax.text(0.02, -0.75, r\"$\\theta=0$\", color=\"r\", fontsize=14)\nax.text(\n    3 * c / 4 + 0.01,\n    -3 * c / 4 + 0.01,\n    r\"$\\theta=\\frac{\\pi}{4}$\",\n    color=\"r\",\n    fontsize=14,\n)\nax.text(0.65, 0.05, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"r\", fontsize=14)\n\nax.plot((-0.375, 0.375), (1.0, 1.0), color=\"orange\", lw=2)\nax.arrow(\n    -0.375,\n    0.94,\n    0.75,\n    0.0,\n    color=\"orange\",\n    lw=0.5,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax.text(0.0, 0.82, r\"$\\theta=0$\", color=\"orange\", ha=\"center\", fontsize=14)\n\nax.plot((-1.0, -1.0), (-0.375, 0.375), color=\"orange\", lw=2)\nax.arrow(\n    -0.94,\n    -0.375,\n    0.0,\n    0.75,\n    color=\"orange\",\n    lw=0.5,\n    ls=\"--\",\n    head_width=0.03,\n    length_includes_head=True,\n)\nax.text(-0.9, 
0.0, r\"$\\theta=\\frac{\\pi}{2}$\", color=\"orange\", ha=\"left\", fontsize=14)\n\nfig.tight_layout()\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/xray_3d_vec.py",
    "content": "import numpy as np\n\nimport matplotlib as mpl\nfrom matplotlib import pyplot as plt\nfrom matplotlib.patches import FancyArrowPatch\nfrom mpl_toolkits.mplot3d import proj3d\n\nmpl.rcParams[\"savefig.transparent\"] = True\n\n\n# See https://github.com/matplotlib/matplotlib/issues/21688\nclass Arrow3D(FancyArrowPatch):\n    def __init__(self, xs, ys, zs, *args, **kwargs):\n        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)\n        self._verts3d = xs, ys, zs\n\n    def do_3d_projection(self, renderer=None):\n        xs3d, ys3d, zs3d = self._verts3d\n        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)\n        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))\n\n        return np.min(zs)\n\n\n# Define vector components\n𝜃 = 10 * np.pi / 180.0  # angle in x-y plane (azimuth angle)\n𝛼 = 70 * np.pi / 180.0  # angle with z axis (zenith angle)\n𝛥p, 𝛥d = 0.3, 1.0\nd = (-𝛥d * np.sin(𝛼) * np.sin(𝜃), 𝛥d * np.sin(𝛼) * np.cos(𝜃), 𝛥d * np.cos(𝛼))\nu = (𝛥p * np.cos(𝜃), 𝛥p * np.sin(𝜃), 0.0)\nv = (𝛥p * np.cos(𝛼) * np.sin(𝜃), -𝛥p * np.cos(𝛼) * np.cos(𝜃), 𝛥p * np.sin(𝛼))\n\n# Location of text labels\nd_txtpos = np.array(d) + np.array([0, 0, -0.12])\nu_txtpos = np.array(d) + np.array(u) + np.array([0, 0, -0.1])\nv_txtpos = np.array(d) + np.array(v) + np.array([0, 0, 0.03])\n\n\narrowstyle = \"-|>,head_width=2.5,head_length=9\"\n\nfig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n\n# Set view\nax.set_aspect(\"equal\")\nax.elev = 15\nax.azim = -50\nax.set_box_aspect(None, zoom=2)\nax.set_xlim((-1.1, 1.1))\nax.set_ylim((-1.1, 1.1))\nax.set_zlim((-1.1, 1.1))\n\n# Disable shaded 3d axis grids\nax.set_axis_off()\n\n# Draw central x,y,z axes and labels\naxis_crds = np.array([[-1, 1], [0, 0], [0, 0]])\naxis_lbls = (\"$x$\", \"$y$\", \"$z$\")\nfor k in range(3):\n    crd = np.roll(axis_crds, k, axis=0)\n    ax.add_artist(\n        Arrow3D(\n            *crd.tolist(),\n            lw=1.5,\n            ls=\"--\",\n            
arrowstyle=arrowstyle,\n            color=\"black\",\n        )\n    )\n    ax.text(*(1.05 * crd[:, 1]).tolist(), axis_lbls[k], fontsize=12)\n\n# Draw d, u, v and labels\nax.quiver(0, 0, 0, *d, arrow_length_ratio=0.08, lw=2, color=\"blue\")\nax.quiver(*d, *u, arrow_length_ratio=0.08 / 𝛥p, lw=2, color=\"blue\")\nax.quiver(*d, *v, arrow_length_ratio=0.08 / 𝛥p, lw=2, color=\"blue\")\nax.text(*d_txtpos, r\"$\\mathbf{d}$\", fontsize=12)\nax.text(*u_txtpos, r\"$\\mathbf{u}$\", fontsize=12)\nax.text(*v_txtpos, r\"$\\mathbf{v}$\", fontsize=12)\n\nfig.tight_layout()\nfig.subplots_adjust(-0.1, -0.06, 1, 1)\nfig.show()\n"
  },
  {
    "path": "docs/source/pyfigures/xray_3d_vol.py",
    "content": "import numpy as np\n\nimport matplotlib as mpl\nfrom matplotlib import pyplot as plt\nfrom matplotlib.patches import FancyArrowPatch\nfrom mpl_toolkits.mplot3d import proj3d\n\nmpl.rcParams[\"savefig.transparent\"] = True\n\n\n# See https://github.com/matplotlib/matplotlib/issues/21688\nclass Arrow3D(FancyArrowPatch):\n    def __init__(self, xs, ys, zs, *args, **kwargs):\n        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)\n        self._verts3d = xs, ys, zs\n\n    def do_3d_projection(self, renderer=None):\n        xs3d, ys3d, zs3d = self._verts3d\n        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)\n        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))\n\n        return np.min(zs)\n\n\n# Define vector components\n𝜃 = 10 * np.pi / 180.0  # angle in x-y plane (azimuth angle)\n𝛼 = 70 * np.pi / 180.0  # angle with z axis (zenith angle)\n𝛥p, 𝛥d = 0.3, 1.0\nd = (-𝛥d * np.sin(𝛼) * np.sin(𝜃), 𝛥d * np.sin(𝛼) * np.cos(𝜃), 𝛥d * np.cos(𝛼))\nu = (𝛥p * np.cos(𝜃), 𝛥p * np.sin(𝜃), 0.0)\nv = (𝛥p * np.cos(𝛼) * np.sin(𝜃), -𝛥p * np.cos(𝛼) * np.cos(𝜃), 𝛥p * np.sin(𝛼))\n\n# Location of text labels\nd_txtpos = np.array(d) + np.array([0, 0, -0.12])\nu_txtpos = np.array(d) + np.array(u) + np.array([0, 0, -0.1])\nv_txtpos = np.array(d) + np.array(v) + np.array([0, 0, 0.03])\n\n\narrowstyle = \"-|>,head_width=2.5,head_length=9\"\n\nfig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n\n# Set view\nax.set_aspect(\"equal\")\nax.elev = 40\nax.azim = -60\nax.set_box_aspect(None, zoom=1.8)\nax.set_xlim((-10.5, 10.5))\nax.set_ylim((-10.5, 10.5))\nax.set_zlim((-10.5, 10.5))\n\n# Disable shaded 3d axis grids\nax.set_axis_off()\n\n# Draw central x,y,z axes and labels\naxis_crds = np.array([[-10, 10], [0, 0], [0, 0]])\naxis_lbls = (\"$x$\", \"$y$\", \"$z$\")\nfor k in range(3):\n    crd = np.roll(axis_crds, k, axis=0)\n    ax.add_artist(\n        Arrow3D(\n            *crd.tolist(),\n            lw=1.5,\n            ls=\"--\",\n  
          arrowstyle=arrowstyle,\n            color=\"black\",\n        )\n    )\n    ax.text(*(1.05 * crd[:, 1]).tolist(), axis_lbls[k], fontsize=12)\n\nwx = 4\nwy = 3\nwz = 2\nbx = np.array([-wx, wx, wx, wx, -wx, -wx, -wx])\nby = np.array([-wy, -wy, wy, wy, wy, -wy, -wy])\nbz = np.array([-wz, -wz, -wz, wz, wz, wz, -wz])\nax.plot(bx, by, bz, lw=2, color=\"blue\")\nax.plot(bx[0:3], by[0:3], -bz[0:3], lw=2, color=\"blue\")\nbx = np.array([wx, wx])\nby = np.array([-wy, -wy])\nbz = np.array([-wz, wz])\nax.plot(bx, by, bz, lw=2, color=\"blue\")\nbx = np.array([-wx, -wx, wx])\nby = np.array([-wy, wy, wy])\nbz = np.array([-wz, -wz, -wz])\nax.plot(bx, by, bz, lw=2, ls=\"--\", color=\"blue\")\nbx = np.array([-wx, -wx])\nby = np.array([wy, wy])\nbz = np.array([-wz, wz])\nax.plot(bx, by, bz, lw=2, ls=\"--\", color=\"blue\")\n\nfig.tight_layout()\nfig.subplots_adjust(-0.1, -0.1, 1, 1.07)\nfig.show()\n"
  },
  {
    "path": "docs/source/references.bib",
    "content": "@Article {aggarwal-2019-modl,\n  author =\t {Aggarwal, Hemant K. and Mani, Merry P. and Jacob,\n                  Mathews},\n  journal =\t {IEEE Transactions on Medical Imaging},\n  title =\t {{MoDL}: Model-Based Deep Learning Architecture for\n                  Inverse Problems},\n  year =\t 2019,\n  volume =\t 38,\n  number =\t 2,\n  pages =\t {394--405},\n  doi =\t\t {10.1109/TMI.2018.2865356}\n}\n\n@Article {alliney-1992-digital,\n  author =\t {Alliney, Stefano},\n  journal =\t {IEEE Transactions on Signal Processing},\n  title =\t {Digital filters as absolute norm regularizers},\n  year =\t 1992,\n  volume =\t 40,\n  number =\t 6,\n  pages =\t {1548--1562},\n  doi =\t\t {10.1109/78.139258},\n  month =\t Jun\n}\n\n@Article {almeida-2013-deconvolving,\n  author =\t {Almeida, Mariana S. C. and Figueiredo, M\\'ario},\n  journal =\t {IEEE Transactions on Image Processing},\n  title =\t {Deconvolving Images With Unknown Boundaries Using\n                  the Alternating Direction Method of Multipliers},\n  year =\t 2013,\n  month =\t Aug,\n  volume =\t 22,\n  number =\t 8,\n  pages =\t {3074--3086},\n  doi =\t\t {10.1109/TIP.2013.2258354}\n}\n\n@Article {antipa-2018-diffusercam,\n  author =\t {Nick Antipa and Grace Kuo and Reinhard Heckel and\n                  Ben Mildenhall and Emrah Bostan and Ren Ng and Laura\n                  Waller},\n  title =\t {{DiffuserCam}: lensless single-exposure 3{D}\n                  imaging},\n  journal =\t {Optica},\n  year =\t 2018,\n  month =\t Jan,\n  volume =\t 5,\n  number =\t 1,\n  doi =\t\t {10.1364/optica.5.000001},\n  pages =\t {1--9}\n}\n\n@Article {balke-2022-scico,\n  author =\t {Thilo Balke and Fernando Davis and Cristina\n                  Garcia-Cardona and Soumendu Majee and Michael McCann\n                  and Luke Pfister and Brendt Wohlberg},\n  title =\t {Scientific Computational Imaging Code ({SCICO})},\n  journal =\t {Journal of Open Source Software},\n  year =\t 2022,\n  volume =\t 7,\n  
number =\t 78,\n  pages =\t 4722,\n  doi =\t\t {10.21105/joss.04722}\n}\n\n@Article {barzilai-1988-stepsize,\n  author =\t {Jonathan Barzilai and Jonathan M. Borwein},\n  title =\t {Two-point step size gradient methods},\n  journal =\t {{IMA} Journal of Numerical Analysis},\n  volume =\t 8,\n  pages =\t {141--148},\n  year =\t 1988,\n  month =\t Jan,\n  doi =\t\t {10.1093/imanum/8.1.141}\n}\n\n@Article {beck-2009-fast,\n  title =\t {A Fast Iterative Shrinkage-Thresholding Algorithm\n                  for Linear Inverse Problems},\n  author =\t {Beck, Amir and Teboulle, Marc},\n  journal =\t {SIAM Journal on Imaging Sciences},\n  year =\t 2009,\n  volume =\t 2,\n  number =\t 1,\n  pages =\t {183--202},\n  doi =\t\t {10.1137/080716542}\n}\n\n@Article {beck-2009-tv,\n  title =\t {Fast Gradient-Based Algorithms for Constrained Total\n                  Variation Image Denoising and Deblurring Problems},\n  author =\t {Beck, Amir and Teboulle, Marc},\n  journal =\t {IEEE Transactions on Image Processing},\n  year =\t 2009,\n  month =\t Nov,\n  volume =\t 18,\n  number =\t 11,\n  pages =\t {2419--2434},\n  doi =\t\t {10.1109/TIP.2009.2028250}\n}\n\n@InCollection {beck-2010-gradient,\n  author =\t {Amir Beck and Marc Teboulle},\n  editor =\t {Daniel P. Palomar and Yonina C. 
Eldar},\n  title =\t {Gradient-based algorithms with applications to\n                  signal-recovery problems},\n  booktitle =\t {Convex Optimization in Signal Processing and\n                  Communications},\n  pages =\t {42--88},\n  publisher =\t {Cambridge University Press},\n  year =\t 2010,\n  doi =\t\t {10.1017/CBO9780511804458.003},\n  url =          {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf}\n}\n\n@Software {bradbury-2018-jax,\n  author =\t {James Bradbury and Roy Frostig and Peter Hawkins and\n                  Matthew James Johnson and Chris Leary and Dougal\n                  Maclaurin and George Necula and Adam Paszke and Jake\n                  Vander{P}las and Skye Wanderman-{M}ilne and Qiao\n                  Zhang},\n  title =\t {{JAX}: composable transformations of\n                  {P}ython+{N}um{P}y programs},\n  url =\t\t {http://github.com/google/jax},\n  version =\t {0.2.5},\n  year =\t 2018\n}\n\n@Book {beck-2017-first,\n  title =\t {First-order methods in optimization},\n  author =\t {Beck, Amir},\n  year =\t 2017,\n  publisher =\t {Society for Industrial and Applied Mathematics\n                  (SIAM)},\n  doi =\t\t {10.1137/1.9781611974997},\n  isbn =\t 1611974984\n}\n\n@InProceedings {benning-2016-preconditioned,\n  title =\t {Preconditioned {ADMM} with nonlinear operator\n                  constraint},\n  author =\t {Benning, Martin and Knoll, Florian and\n                  Sch{\\\"o}nlieb, Carola-Bibiane and Valkonen, Tuomo},\n  booktitle =\t {IFIP Conference on System Modeling and Optimization\n                  (CSMO) 2015},\n  pages =\t {117--126},\n  year =\t 2016,\n  doi =\t\t {10.1007/978-3-319-55795-3_10}\n}\n\n@Article {boyd-2010-distributed,\n  title =\t {Distributed optimization and statistical learning\n                  via the alternating direction method of multipliers},\n  author =\t {Boyd, Stephen and Parikh, Neal and Chu, Eric and\n                  Peleato, Borja and Eckstein, 
Jonathan},\n  journal =\t {Foundations and Trends in Machine Learning},\n  year =\t 2010,\n  volume =\t 3,\n  number =\t 1,\n  pages =\t {1--122},\n  doi =\t\t {10.1561/2200000016}\n}\n\n@Article {buzzard-2018-plug,\n  title =\t {Plug-and-play unplugged: Optimization-free\n                  reconstruction using consensus equilibrium},\n  author =\t {Buzzard, Gregery T. and Chan, Stanley H. and\n                  Sreehari, Suhas and Bouman, Charles A.},\n  journal =\t {SIAM Journal on Imaging Sciences},\n  volume =\t 11,\n  number =\t 3,\n  pages =\t {2001--2020},\n  year =\t 2018,\n  doi =\t\t {10.1137/17M1122451}\n}\n\n@Article {cai-2010-singular,\n  title =\t {A Singular Value Thresholding Algorithm for Matrix\n                  Completion},\n  author =\t {Cai, Jian-Feng and Cand{\\`e}s, Emmanuel J. and Shen,\n                  Zuowei},\n  journal =\t {SIAM Journal on Optimization},\n  year =\t 2010,\n  volume =\t 20,\n  number =\t 4,\n  pages =\t {1956--1982},\n  doi =\t\t {10.1137/080738970}\n}\n\n@Article {chambolle-2010-firstorder,\n  author =\t {Antonin Chambolle and Thomas Pock},\n  title =\t {A First-Order Primal-Dual Algorithm for Convex\n                  Problems with~Applications to Imaging},\n  journal =\t {Journal of Mathematical Imaging and Vision},\n  doi =\t\t {10.1007/s10851-010-0251-1},\n  year =\t 2010,\n  month =\t Dec,\n  volume =\t 40,\n  number =\t 1,\n  pages =\t {120--145}\n}\n\n@Misc {chandler-2024-closedform,\n  author =\t {Edward P. Chandler and Shirin Shoushtari and Brendt\n                  Wohlberg and Ulugbek S. Kamilov},\n  title =\t {Closed-Form Approximation of the Total Variation\n                  Proximal Operator},\n  year =\t 2024,\n  eprint =\t {2412.07718}\n}\n\n@Article {clinthorne-1993-preconditioning,\n  author =\t {Clinthorne, Neal H. and Pan, Tin-Su and Chiao,\n                  Ping-Chun and Rogers, W. 
Leslie and Stamos, John A.},\n  title =\t {Preconditioning methods for improved convergence\n                  rates in iterative reconstructions},\n  journal =\t {IEEE Transactions on Medical Imaging},\n  year =\t 1993,\n  volume =\t 12,\n  number =\t 1,\n  pages =\t {78--83},\n  month =\t Mar,\n  doi =\t\t {10.1109/42.222670}\n}\n\n@InProceedings {dabov-2008-image,\n  author =\t {Kostadin Dabov and Alessandro Foi and Vladimir\n                  Katkovnik and Karen Egiazarian},\n  title =\t {Image restoration by sparse {3D} transform-domain\n                  collaborative filtering},\n  volume =\t 6812,\n  booktitle =\t {Image Processing: Algorithms and Systems VI},\n  editor =\t {Jaakko T. Astola and Karen O. Egiazarian and Edward\n                  R. Dougherty},\n  organization = {International Society for Optics and Photonics},\n  publisher =\t {SPIE},\n  pages =\t {62--73},\n  year =\t 2008,\n  month =\t Mar,\n  doi =\t\t {10.1117/12.766355}\n}\n\n@Article {daubechies-2004-iterative,\n  title =\t {An iterative thresholding algorithm for linear\n                  inverse problems with a sparsity constraint},\n  author =\t {Daubechies, Ingrid and Defrise, Michel and De Mol,\n                  Christine},\n  journal =\t {Communications on Pure and Applied Mathematics},\n  volume =\t 57,\n  number =\t 11,\n  pages =\t {1413--1457},\n  year =\t 2004,\n  doi =\t\t {10.1002/cpa.20042}\n}\n\n@Article {deng-2015-global,\n  author =\t {Wei Deng and Wotao Yin},\n  title =\t {On the Global and Linear Convergence of the\n                  Generalized Alternating Direction Method of\n                  Multipliers},\n  journal =\t {Journal of Scientific Computing},\n  year =\t 2015,\n  month =\t May,\n  volume =\t 66,\n  number =\t 3,\n  pages =\t {889--916},\n  doi =\t\t {10.1007/s10915-015-0048-x},\n}\n\n@Misc {diamond-2018-odp,\n  author =\t {Steven Diamond and Vincent Sitzmann and Felix Heide\n                  and Gordon Wetzstein},\n  title =\t {Unrolled Optimization 
with Deep Priors},\n  year =\t 2018,\n  eprint =\t {1705.08041v2}\n}\n\n@Article {esser-2010-general,\n  author =\t {Ernie Esser and Xiaoqun Zhang and Tony F. Chan},\n  title =\t {A General Framework for a Class of First Order\n                  Primal-Dual Algorithms for Convex Optimization in\n                  Imaging Science},\n  journal =\t {SIAM Journal on Imaging Sciences},\n  doi =\t\t {10.1137/09076934x},\n  year =\t 2010,\n  month =\t Jan,\n  volume =\t 3,\n  number =\t 4,\n  pages =\t {1015--1046}\n}\n\n@PhDThesis {esser-2010-primal,\n  author =\t {Ernie Esser},\n  title =\t {Primal Dual Algorithms for Convex Models and\n                  Applications to Image Restoration, Registration and\n                  Nonlocal Inpainting},\n  school =\t {University of California Los Angeles},\n  year =\t 2010\n}\n\n@InProceedings {florea-2017-robust,\n  title =\t {A Robust {FISTA}-Like Algorithm},\n  author =\t {Mihai I. Florea and Sergiy A. Vorobyov},\n  booktitle =\t {Proceedings of the IEEE International Conference on\n                  Acoustics, Speech and Signal Processing (ICASSP)},\n  year =\t 2017,\n  month =\t Mar,\n  pages =\t {4521--4525},\n  doi =\t\t {10.1109/ICASSP.2017.7953012},\n  location =\t {New Orleans, LA, USA}\n}\n\n@Article {gabay-1976-dual,\n  title =\t {A dual algorithm for the solution of nonlinear\n                  variational problems via finite element\n                  approximation},\n  author =\t {Gabay, Daniel and Mercier, Bertrand},\n  journal =\t {Computers \\& Mathematics with Applications},\n  volume =\t 2,\n  number =\t 1,\n  pages =\t {17--40},\n  year =\t 1976,\n  doi =\t\t {10.1016/0898-1221(76)90003-1}\n}\n\n@Article {glowinski-1975-approximation,\n  title =\t {Sur l'approximation, par {\\'e}l{\\'e}ments finis\n                  d'ordre un, et la r{\\'e}solution, par\n                  p{\\'e}nalisation-dualit{\\'e} d'une classe de\n                  probl{\\`e}mes de Dirichlet non lin{\\'e}aires},\n  author =\t 
{Glowinski, Roland and Marroco, Americo},\n  journal =\t {ESAIM: Mathematical Modelling and Numerical Analysis\n                  - Mod{\\'e}lisation Math{\\'e}matique et Analyse\n                  Num{\\'e}rique},\n  volume =\t 9,\n  number =\t {R2},\n  pages =\t {41--76},\n  year =\t 1975,\n  url =\t\t {http://eudml.org/doc/193269}\n}\n\n@Article {goldstein-2009-split,\n  author =\t {Tom Goldstein and Stanley Osher},\n  title =\t {The Split {B}regman Method for L1-Regularized\n                  Problems},\n  journal =\t {SIAM Journal on Imaging Sciences},\n  volume =\t 2,\n  number =\t 2,\n  pages =\t {323--343},\n  year =\t 2009,\n  doi =\t\t {10.1137/080725891}\n}\n\n@Misc {goldstein-2014-fasta,\n  title =\t {A Field Guide to Forward-Backward Splitting with a\n                  {FASTA} Implementation},\n  author =\t {Tom Goldstein and Christoph Studer and Richard\n                  Baraniuk},\n  year =\t 2014,\n  eprint =\t {1411.3406},\n  url =\t\t {http://arxiv.org/abs/1411.3406},\n}\n\n@Book {goodman-2005-fourier,\n  author =\t {Goodman, Joseph W.},\n  title =\t {Introduction to {F}ourier Optics},\n  publisher =\t {McGraw-Hill},\n  year =\t 2005,\n  isbn =\t 9780974707723,\n  edition =\t 3\n}\n\n@Misc {hossein-2024-total,\n  title =\t {Total Variation Regularization for Tomographic\n                  Reconstruction of Cylindrically Symmetric Objects},\n  author =\t {Maliha Hossain and Charles A. Bouman and Brendt\n                  Wohlberg},\n  year =\t 2024,\n  eprint =\t {2406.17928}\n}\n\n@Article {hoyer-2004-nonnegative,\n  title =\t {Non-negative matrix factorization with sparseness\n                  constraints},\n  author =\t {Patrik O. 
Hoyer},\n  journal =\t {Journal of Machine Learning Research},\n  volume =\t 5,\n  number =\t Nov,\n  pages =\t {1457--1469},\n  year =\t 2004,\n  url =          {https://www.jmlr.org/papers/volume5/hoyer04a/hoyer04a.pdf}\n}\n\n@Article {huber-1964-robust,\n  doi =\t\t {10.1214/aoms/1177703732},\n  year =\t 1964,\n  month =\t Mar,\n  volume =\t 35,\n  number =\t 1,\n  pages =\t {73--101},\n  author =\t {Peter J. Huber},\n  title =\t {Robust Estimation of a Location Parameter},\n  journal =\t {The Annals of Mathematical Statistics}\n}\n\n@Article {jin-2017-unet,\n  title =\t {Deep Convolutional Neural Network for Inverse\n                  Problems in Imaging},\n  author =\t {Kyong Hwan Jin and Michael T. McCann and Emmanuel\n                  Froustey and Michael Unser},\n  journal =\t {IEEE Transactions on Image Processing},\n  volume =\t 26,\n  number =\t 9,\n  pages =\t {4509--4522},\n  year =\t 2017,\n  doi =\t\t {10.1109/TIP.2017.2713099}\n}\n\n@Book {kak-1988-principles,\n  author = \t {Avinash C. Kak and Malcolm Slaney},\n  title = \t {Principles of Computerized Tomographic Imaging},\n  publisher = \t {IEEE Press},\n  year = \t 1988\n}\n\n@TechReport {kamilov-2016-minimizing,\n  author =\t {Ulugbek S. Kamilov},\n  title =\t {Minimizing Isotropic Total Variation without\n                  Subiterations},\n  institution =\t {Mitsubishi Electric Research Laboratories (MERL)},\n  year =\t 2016,\n  number =\t {TR2016-109},\n  month =\t Aug,\n  note =\t {Presented at International Traveling Workshop on\n                  Interactions Between Sparse Models and Technology\n                  (iTWIST) 2016},\n  url =           {https://www.merl.com/publications/docs/TR2016-109.pdf}\n}\n\n@Article {kamilov-2016-parallel,\n  title =\t {A parallel proximal algorithm for anisotropic total\n                  variation minimization},\n  author =\t {Ulugbek S. 
Kamilov},\n  journal =\t {IEEE Transactions on Image Processing},\n  volume =\t 26,\n  number =\t 2,\n  pages =\t {539--548},\n  year =\t 2016,\n  doi =\t\t {10.1109/tip.2016.2629449}\n}\n\n@Article {kamilov-2017-plugandplay,\n  author =\t {Ulugbek S. Kamilov and Hassan Mansour and Brendt\n                  Wohlberg},\n  title =\t {A Plug-and-Play Priors Approach for Solving\n                  Nonlinear Imaging Inverse Problems},\n  year =\t 2017,\n  month =\t Dec,\n  journal =\t {IEEE Signal Processing Letters},\n  volume =\t 24,\n  number =\t 12,\n  doi =\t\t {10.1109/LSP.2017.2763583},\n  pages =\t {1872--1876}\n}\n\n@Article {kamilov-2023-plugandplay,\n  author =\t {Ulugbek S. Kamilov and Charles A. Bouman and Gregery\n                  T. Buzzard and Brendt Wohlberg},\n  title =\t {Plug-and-Play Methods for Integrating Physical and\n                  Learned Models in Computational Imaging},\n  journal =\t {IEEE Signal Processing Magazine},\n  year =\t 2023,\n  month =\t Jan,\n  volume =\t 40,\n  number =\t 1,\n  pages =\t {85--97},\n  doi =\t\t {10.1109/MSP.2022.3199595}\n}\n\n@Article {liu-2018-first,\n  author =\t {Jialin Liu and Cristina Garcia-Cardona and Brendt\n                  Wohlberg and Wotao Yin},\n  title =\t {First and Second Order Methods for Online\n                  Convolutional Dictionary Learning},\n  journal =\t {SIAM Journal on Imaging Sciences},\n  year =\t 2018,\n  volume =\t 11,\n  number =\t 2,\n  pages =\t {1589--1628},\n  doi =\t\t {10.1137/17M1145689},\n  eprint =\t {1709.00106}\n}\n\n@Article {lou-2018-fast,\n  title =\t {Fast {L1-L2} Minimization via a Proximal Operator},\n  author =\t {Yifei Lou and Ming Yan},\n  journal =\t {Journal of Scientific Computing},\n  volume =\t 74,\n  number =\t 2,\n  pages =\t {767--785},\n  year =\t 2018,\n  doi =\t\t {10.1007/s10915-017-0463-2}\n}\n\n@Article {maggioni-2012-nonlocal,\n  title =\t {Nonlocal transform-domain filter for volumetric data\n                  denoising and 
reconstruction},\n  author =\t {Maggioni, Matteo and Katkovnik, Vladimir and\n                  Egiazarian, Karen and Foi, Alessandro},\n  journal =\t {IEEE Transactions on Image Processing},\n  volume =\t 22,\n  number =\t 1,\n  pages =\t {119--133},\n  year =\t 2012,\n  doi =\t\t {10.1109/TIP.2012.2210725}\n}\n\n@InProceedings {makinen-2019-exact,\n  author =\t {Ymir M\\\"akinen and Lucio Azzari and Alessandro Foi},\n  booktitle =\t {IEEE International Conference on Image Processing\n                  (ICIP)},\n  title =\t {Exact Transform-Domain Noise Variance for\n                  Collaborative Filtering of Stationary Correlated\n                  Noise},\n  year =\t 2019,\n  pages =\t {185--189},\n  doi =\t\t {10.1109/ICIP.2019.8802964},\n  month =\t Sep\n}\n\n@Article {menon-2007-demosaicing,\n  title =\t {Demosaicing With Directional Filtering and a\n                  posteriori Decision},\n  author =\t {Daniele Menon and Stefano Andriani and Giancarlo\n                  Calvagno},\n  journal =\t {IEEE Transactions on Image Processing},\n  year =\t 2007,\n  month =\t Jan,\n  volume =\t 16,\n  number =\t 1,\n  pages =\t {132--141},\n  doi =\t\t {10.1109/tip.2006.884928}\n}\n\n@Article {monga-2021-algorithm,\n  author =\t {Monga, Vishal and Li, Yuelong and Eldar, Yonina C.},\n  journal =\t {IEEE Signal Processing Magazine},\n  title =\t {Algorithm Unrolling: Interpretable, Efficient Deep\n                  Learning for Signal and Image Processing},\n  year =\t 2021,\n  volume =\t 38,\n  number =\t 2,\n  pages =\t {18--44},\n  doi =\t\t {10.1109/MSP.2020.3016905}\n}\n\n@Book {nocedal-2006-numerical,\n  title =\t {Numerical Optimization},\n  author =\t {Jorge Nocedal and Stephen J. 
Wright},\n  year =\t 2006,\n  publisher =\t {Springer},\n  doi =\t\t {10.1007/978-0-387-40065-5},\n  isbn =\t 9780387303031\n}\n\n@Article {olufsen-2019-axitom,\n  title =\t {{AXITOM}: A {P}ython package for reconstruction of\n                  axisymmetric tomograms acquired by a conical beam},\n  volume =\t 4,\n  doi =\t\t {10.21105/joss.01704},\n  number =\t 42,\n  journal =\t {Journal of Open Source Software},\n  author =\t {Olufsen, Sindre},\n  year =\t 2019,\n  month =\t Oct,\n  pages =\t {1704}\n}\n\n@Book {paganin-2006-coherent,\n  doi =\t\t {10.1093/acprof:oso/9780198567288.001.0001},\n  isbn =\t 9780198567288,\n  year =\t 2006,\n  month =\t Jan,\n  publisher =\t {Oxford University Press},\n  author =\t {David Paganin},\n  title =\t {Coherent X-Ray Optics}\n}\n\n@Article {parikh-2014-proximal,\n  title =\t {Proximal algorithms},\n  author =\t {Parikh, Neal and Boyd, Stephen},\n  journal =\t {Foundations and Trends in Optimization},\n  volume =\t 1,\n  number =\t 3,\n  pages =\t {127--239},\n  year =\t 2014,\n  doi =\t\t {10.1561/2400000003}\n}\n\n@InProceedings {pock-2011-diagonal,\n  author =\t {Thomas Pock and Antonin Chambolle},\n  title =\t {Diagonal preconditioning for first order primal-dual\n                  algorithms in convex optimization},\n  booktitle =\t {Proceedings of the International Conference on\n                  Computer Vision (ICCV)},\n  doi =\t\t {10.1109/iccv.2011.6126441},\n  pages =\t {1762--1769},\n  year =\t 2011,\n  month =\t Nov,\n  address =\t {Barcelona, Spain}\n}\n\n@Misc {pyabel-2022,\n  author =\t {Stephen Gibson and Daniel Hickstein and Roman\n                  Yurchak and Mikhail Ryazanov and Dhrubajyoti Das and\n                  Gilbert Shih},\n  title =\t {PyAbel},\n  howpublished = {PyAbel/PyAbel: v0.8.5},\n  year =\t 2022,\n  doi =\t\t {10.5281/zenodo.5888391}\n}\n\n@InProceedings {ronneberger-2015-unet,\n  author =\t {Olaf Ronneberger and Philipp Fischer and Thomas\n                  Brox},\n  title =\t {{U}-{N}et: 
Convolutional Networks for Biomedical\n                  Image Segmentation},\n  booktitle =\t {Proceedings of the 18th International Conference on\n                  Medical Image Computing and Computer-Assisted\n                  Intervention},\n  doi =\t\t {10.1007/978-3-319-24574-4_28},\n  volume =\t 9351,\n  pages =\t {234--241},\n  year =\t 2015,\n  month =\t Oct,\n  address =\t {Munich, Germany},\n}\n\n@Article {rudin-1992-nonlinear,\n  author =\t {Leonid I. Rudin and Stanley Osher and Emad Fatemi},\n  title =\t {Nonlinear total variation based noise removal\n                  algorithms},\n  journal =\t {Physica D: Nonlinear Phenomena},\n  volume =\t 60,\n  number =\t {1--4},\n  pages =\t {259--268},\n  year =\t 1992,\n  doi =\t\t {10.1016/0167-2789(92)90242-F}\n}\n\n@Article {sauer-1993-local,\n  title =\t {A local update strategy for iterative reconstruction\n                  from projections},\n  author =\t {Sauer, Ken and Bouman, Charles},\n  journal =\t {IEEE Transactions on Signal Processing},\n  year =\t 1993,\n  month =\t Feb,\n  number =\t 2,\n  pages =\t {534--548},\n  volume =\t 41,\n  doi =\t\t {10.1109/78.193196}\n}\n\n@Article {soulez-2016-proximity,\n  author =\t {Ferr{\\'{e}}ol Soulez and {\\'{E}}ric Thi{\\'{e}}baut\n                  and Antony Schutz and Andr{\\'{e}} Ferrari and\n                  Fr{\\'{e}}d{\\'{e}}ric Courbin and Michael Unser},\n  title =\t {Proximity operators for phase retrieval},\n  journal =\t {Applied Optics},\n  doi =\t\t {10.1364/ao.55.007412},\n  year =\t 2016,\n  month =\t Sep,\n  volume =\t 55,\n  number =\t 26,\n  pages =\t {7412--7421}\n}\n\n@Article {sreehari-2016-plug,\n  author =\t {Suhas Sreehari and Singanallur V. Venkatakrishnan\n                  and Brendt Wohlberg and Gregery T. Buzzard and\n                  Lawrence F. Drummy and Jeffrey P. Simmons and\n                  Charles A. 
Bouman},\n  title =\t {Plug-and-Play Priors for Bright Field Electron\n                  Tomography and Sparse Interpolation},\n  year =\t 2016,\n  month =\t Dec,\n  journal =\t {IEEE Transactions on Computational Imaging},\n  volume =\t 2,\n  number =\t 4,\n  doi =\t\t {10.1109/TCI.2016.2599778},\n  pages =\t {408--423}\n}\n\n@Misc {svmbir-2020,\n  author =\t {SVMBIR Development Team},\n  title =\t {{S}uper-{V}oxel {M}odel {B}ased {I}terative\n                  {R}econstruction ({SVMBIR})},\n  howpublished = {Software library available from\n                  \\url{https://github.com/cabouman/svmbir}},\n  year =\t 2020\n}\n\n@Article {valkonen-2014-primal,\n  title =\t {A primal--dual hybrid gradient method for nonlinear\n                  operators with applications to {MRI}},\n  author =\t {Valkonen, Tuomo},\n  journal =\t {Inverse Problems},\n  volume =\t 30,\n  number =\t 5,\n  pages =\t 055012,\n  year =\t 2014,\n  doi =\t\t {10.1088/0266-5611/30/5/055012}\n}\n\n@InProceedings {venkatakrishnan-2013-plugandplay2,\n  author =\t {Singanallur V. Venkatakrishnan and Charles A. Bouman\n                  and Brendt Wohlberg},\n  title =\t {Plug-and-Play Priors for Model Based Reconstruction},\n  year =\t 2013,\n  month =\t Dec,\n  booktitle =\t {Proceedings of IEEE Global Conference on Signal and\n                  Information Processing (GlobalSIP)},\n  address =\t {Austin, TX, USA},\n  doi =\t\t {10.1109/GlobalSIP.2013.6737048},\n  pages =\t {945--948}\n}\n\n@Article {voelz-2009-digital,\n  author =\t {David G. Voelz and Michael C. 
Roggemann},\n  title =\t {Digital Simulation of Scalar Optical Diffraction:\n                  Revisiting Chirp Function Sampling Criteria and\n                  Consequences},\n  journal =\t {Applied Optics},\n  volume =\t 48,\n  number =\t 32,\n  pages =\t 6132,\n  year =\t 2009,\n  doi =\t\t {10.1364/ao.48.006132},\n}\n\n@Book {voelz-2011-computational,\n  author =\t {Voelz, David},\n  title =\t {Computational {F}ourier optics : a {MATLAB}\n                  tutorial},\n  year =\t 2011,\n  publisher =\t {SPIE Press},\n  address =\t {Bellingham, Wash},\n  isbn =\t 9780819482044,\n}\n\n@InProceedings {wohlberg-2014-efficient,\n  author =\t {Brendt Wohlberg},\n  title =\t {Efficient Convolutional Sparse Coding},\n  booktitle =\t {Proceedings of IEEE International Conference on\n                  Acoustics, Speech, and Signal Processing (ICASSP)},\n  year =\t 2014,\n  month =\t May,\n  doi =\t\t {10.1109/ICASSP.2014.6854992},\n  pages =\t {7173--7177},\n  location =\t {Florence, Italy}\n}\n\n@Article {wohlberg-2021-psf,\n  author =\t {Brendt Wohlberg and Przemek Wozniak},\n  title =\t {PSF Estimation in Crowded Astronomical Imagery as a\n                  Convolutional Dictionary Learning Problem},\n  year =\t 2021,\n  month =\t Feb,\n  journal =\t {IEEE Signal Processing Letters},\n  volume =\t 28,\n  doi =\t\t {10.1109/LSP.2021.3050706},\n  pages =\t {374--378}\n}\n\n\n@Article {yang-2012-linearized,\n  author =\t {Junfeng Yang and Xiaoming Yuan},\n  title =\t {Linearized augmented {L}agrangian and alternating\n                  direction methods for nuclear norm minimization},\n  journal =\t {Mathematics of Computation},\n  doi =\t\t {10.1090/s0025-5718-2012-02598-1},\n  year =\t 2012,\n  month =\t Mar,\n  volume =\t 82,\n  number =\t 281,\n  pages =\t {301--329}\n}\n\n\n@InProceedings {yu-2013-better,\n  author =\t {Yu, Yao-Liang},\n  booktitle =\t {Advances in Neural Information Processing Systems},\n  editor =\t {C.J. Burges and L. Bottou and M. 
Welling and\n                  Z. Ghahramani and K.Q. Weinberger},\n  title =\t {Better Approximation and Faster Algorithm Using the\n                  Proximal Average},\n  url =          {https://proceedings.neurips.cc/paper_files/paper/2013/file/49182f81e6a13cf5eaa496d51fea6406-Paper.pdf},\n  volume =\t 26,\n  year =\t 2013\n}\n\n@Article {zhang-2017-dncnn,\n  author =\t {Kai Zhang and Wangmeng Zuo and Yunjin Chen and Deyu\n                  Meng and Lei Zhang},\n  title =\t {Beyond a {G}aussian Denoiser: Residual Learning of\n                  Deep {CNN} for Image Denoising},\n  year =\t 2017,\n  month =\t Jul,\n  journal =\t {IEEE Transactions on Image Processing},\n  volume =\t 26,\n  number =\t 7,\n  doi =\t\t {10.1109/TIP.2017.2662206},\n  pages =\t {3142--3155}\n}\n\n@Article {zhang-2021-plug,\n  author =\t {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and\n                  Zhang, Lei and Van Gool, Luc and Timofte, Radu},\n  title =\t {Plug-and-Play Image Restoration With Deep Denoiser\n                  Prior},\n  journal =\t {IEEE Transactions on Pattern Analysis and Machine\n                  Intelligence},\n  year =\t 2022,\n  volume =\t 44,\n  number =\t 10,\n  doi =\t\t {10.1109/TPAMI.2021.3088914},\n  pages =\t {6360--6376}\n}\n\n@Article {zhou-2006-adaptive,\n  author =\t {Bin Zhou and Li Gao and Yu-Hong Dai},\n  title =\t {Gradient Methods with Adaptive Step-Sizes},\n  year =\t 2006,\n  month =\t Mar,\n  journal =\t {Computational Optimization and Applications},\n  volume =\t 35,\n  doi =\t\t {10.1007/s10589-006-6446-0},\n  pages =\t {69--86}\n}\n"
  },
  {
    "path": "docs/source/style.rst",
    "content": ".. _scico_dev_style:\n\n\nStyle Guide\n===========\n\n\nOverview\n--------\n\nWe adhere to `PEP 8 <https://www.python.org/dev/peps/pep-0008/>`_ with\nthe exception of allowing a line length limit of 99 characters (as\nopposed to 79 characters). The standard limit of 72 characters for\n\"flowing long blocks of text\" in docstrings or comments is\nretained. We use `Black <https://github.com/psf/black>`_ as our PEP 8\nformatter and `isort <https://pypi.org/project/isort/>`_ to sort\nimports. (Please set up a `pre-commit hook <https://pre-commit.com>`_\nto ensure any modified code passes the format check before it is\ncommitted to the development repo.)\n\nWe aim to incorporate `PEP 484\n<https://www.python.org/dev/peps/pep-0484/>`_ type annotations\nthroughout the library. See the `Mypy\n<https://mypy.readthedocs.io/en/stable/>`_ type annotation `cheat\nsheet <https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html>`_\nfor usage examples. Custom types are defined in :mod:`.typing`.\n\nOur coding conventions are based on both the `NumPy conventions\n<https://numpydoc.readthedocs.io/en/latest/format.html#overview>`_ and\nthe `Google Python style guide\n<https://google.github.io/styleguide/pyguide.html>`_.\n\nUnicode variable names are allowed for internal usage (e.g. Greek\ncharacters representing mathematical symbols), but not as part of the\npublic interface for functions or methods.\n\n\nNaming\n------\n\nWe follow the `Google naming conventions <https://google.github.io/styleguide/pyguide.html#3164-guidelines-derived-from-guidos-recommendations>`_:\n\n.. 
list-table:: Naming Conventions\n   :widths: 20 20\n   :header-rows: 1\n\n   * - Component\n     - Naming Convention\n   * - Modules\n     - module_name\n   * - Package\n     - package_name\n   * - Class\n     - ClassName\n   * - Method\n     - method_name\n   * - Function\n     - function_name\n   * - Exception\n     - ExceptionName\n   * - Variable\n     - var_name\n   * - Parameter\n     - parameter_name\n   * - Constant\n     - CONSTANT_NAME\n\nThese names should be descriptive and unambiguous to avoid confusion\nwithin the code and other modules in the future.\n\nExample:\n\n.. code:: Python\n\n    d = 6  # Day of the week == Saturday\n    if d < 5:\n        print(\"Weekday\")\n\nHere the code could be hard to follow since the name ``d`` is not\ndescriptive and requires extra comments to explain the code, which\nwould have been unnecessary with a descriptive name.\n\nExample:\n\n.. code:: Python\n\n   fldln = 5 # field length\n\nThis could be improved by using the descriptive variable ``field_len``.\n\nThings to avoid:\n\n- Single character names except for the following special cases:\n    - counters or iterators (``i``, ``j``);\n    - ``e`` as an exception identifier (``except Exception as e``);\n    - ``f`` as a file handle in ``with`` statements;\n    - mathematical notation in which a reference to the paper or\n      algorithm with said notation is preferred if not clear from the\n      intended purpose.\n\n- Leading underscores unless the component is meant to be protected or private:\n    - protected: Use a single underscore, ``_``, for protected access; and\n    - pseudo-private: Use double underscores, ``__``, for\n      pseudo-private access via name mangling.\n\n\nDisplaying and Printing Strings\n-------------------------------\n\nWe follow the `Google string conventions\n<https://google.github.io/styleguide/pyguide.html#310-strings>`_. Notably,\nprefer to use Python f-strings, rather than ``.format`` or ``%``\nsyntax. For example:\n\n.. 
code:: Python\n\n   state = \"active\"\n   print(\"The state is %s\" % state) # Not preferred\n   print(f\"The state is {state}\")   # Preferred\n\n\nImports\n-------\n\nWe follow the `Google import conventions\n<https://google.github.io/styleguide/pyguide.html#22-imports>`_. The\nuse of ``import`` statements should be reserved for packages and\nmodules only, i.e. individual classes and functions should not be\nimported. The only exception to this is the ``typing`` module.\n\n- Use ``import x`` for importing packages and modules, where x is the package or\n  module name.\n- Use ``from x import y`` where x is the package name and y is the module name.\n- Use ``from x import y as z`` if two modules named ``y`` are imported\n  or if ``y`` is too long a name.\n- Use ``import y as z`` when ``z`` is a standard abbreviation like\n  ``import numpy as np``.\n\n\nVariables\n---------\n\nWe follow the `Google variable typing conventions\n<https://google.github.io/styleguide/pyguide.html#3198-typing-variables>`_,\nwhich state that there are a few extra documentation and coding\npractices that can be applied to variables such as:\n\n- The type of a variable may be annotated by adding ``: type`` after\n  its name when it is assigned, e.g.,\n\n  .. code-block:: python\n\n     a: Foo = SomeDecoratedFunction()\n\n- Avoid global variables.\n- A function can refer to variables defined in enclosing functions but\n  cannot assign to them.\n\n\nParameters\n----------\n\nThere are three important style components for parameters inspired by\nthe `NumPy parameter conventions\n<https://numpydoc.readthedocs.io/en/latest/format.html#parameters>`_:\n\n1. Typing\n\n   We use type annotations, meaning that we specify the types of the\n   inputs and outputs of any method.  From the ``typing`` module we can\n   use more types such as ``Optional``, ``Union``, and ``Any``.  For\n   example,\n\n   .. 
code-block:: python\n\n      def foo(a: str) -> str:\n          \"\"\"Take an input of type string and return a value of type string.\"\"\"\n          ...\n\n2. Default Values\n\n   Parameters should include ``parameter_name = value`` where value is\n   the default for that particular parameter. If the parameter has a\n   type then the format is ``parameter_name: Type = value``. When\n   documenting parameters, if a parameter can only assume one of a\n   fixed set of values, those values can be listed in braces, with the\n   default appearing first. For example,\n\n   .. code-block:: python\n\n      \"\"\"\n      letters: {'A', 'B', 'C'}\n         Description of `letters`.\n      \"\"\"\n\n3. NoneType\n\n   In Python, ``NoneType`` is a first-class type, meaning the type\n   itself can be passed into and returned from functions.  ``None`` is\n   the most commonly used alias for ``NoneType``. If any of the\n   parameters of a function can be ``None`` then it has to be\n   declared. ``Optional[T]`` is preferred over ``Union[T, None]``.\n   For example,\n\n   .. code-block:: python\n\n      def foo(a: Optional[str], b: Optional[Union[str, int]]) -> str:\n          ...\n\n   For documentation purposes, ``NoneType`` or ``None`` should be\n   written with double backticks.\n\n\nDocstrings\n----------\n\nA docstring is a string literal that appears as the first statement\nof a package, module, class, or function, and documents it. To\ngenerate documentation from the docstrings in the code, use `pydoc\n<https://docs.python.org/3/library/pydoc.html>`_.\n\n\nTyping\n~~~~~~\n\nWe follow the `NumPy parameter conventions\n<https://numpydoc.readthedocs.io/en/latest/format.html#parameters>`_. 
The\nfollowing are docstring-specific usages:\n\n- Always enclose variables in single backticks.\n- For the parameter types, be as precise as possible and do not use backticks.\n\n\nModules\n~~~~~~~\n\nWe follow the `Google module conventions\n<https://google.github.io/styleguide/pyguide.html#382-modules>`_. Notably,\nfiles must start with a docstring that describes the functionality of\nthe module. For example,\n\n.. code-block:: python\n\n    \"\"\"A one-line summary of the module must be terminated by a period.\n\n    Leave a blank line and describe the module or program. Optionally\n    describe exported classes, functions, and/or usage examples.\n\n    Usage Example:\n\n    foo = ClassFoo()\n    bar = foo.FunctionBar()\n    \"\"\"\n\n\nFunctions\n~~~~~~~~~\n\nThe word *function* encompasses functions, methods, or generators in\nthis section.  The docstring should give enough information to make\ncalls to the function without needing to read the function's code.\n\nWe follow the `Google function conventions\n<https://google.github.io/styleguide/pyguide.html#383-functions-and-methods>`_.\nNotably, functions should contain docstrings unless they are:\n\n- not externally visible (the function name is prefaced with an underscore), or\n- very short.\n\nThe docstring should be imperative-style ``\"\"\"Fetch rows from a\nTable\"\"\"`` instead of the descriptive-style ``\"\"\"Fetches rows from a\nTable\"\"\"``. If the method overrides a method from a base class then it\nmay use a simple docstring referencing that base class such as\n``\"\"\"See base class\"\"\"``, unless the behavior is different from the\noverridden method or there are extra details that need to be\ndocumented.\n\n| There are three sections to function docstrings:\n\n- Args:\n    - List each parameter by name, and include a description for each parameter.\n- Returns: (or Yields: in the case of generators)\n    - Describe the type of the return value. 
If a function only\n      returns ``None`` then this section is not required.\n- Raises:\n   - List all exceptions followed by a description. The name and\n     description should be separated by a colon followed by a space.\n\nExample:\n\n.. code-block:: python\n\n    def fetch_smalltable_rows(table_handle: smalltable.Table,\n                              keys: Sequence[Union[bytes, str]],\n                              require_all_keys: bool = False,\n    ) -> Mapping[bytes, Tuple[str]]:\n        \"\"\"Fetch rows from a Smalltable.\n\n        Retrieve rows pertaining to the given keys from the Table instance\n        represented by table_handle. String keys will be UTF-8 encoded.\n\n        Args:\n            table_handle:\n               An open smalltable.Table instance.\n            keys:\n               A sequence of strings representing the key of each table\n               row to fetch. String `keys` will be UTF-8 encoded.\n            require_all_keys: Optional\n               If `require_all_keys` is ``True`` only\n               rows with values set for all keys will be returned.\n\n        Returns:\n            A dict mapping keys to the corresponding table row data\n            fetched. Each row is represented as a tuple of strings. For\n            example:\n\n            {b'Serak': ('Rigel VII', 'Preparer'),\n             b'Zim': ('Irk', 'Invader'),\n             b'Lrrr': ('Omicron Persei 8', 'Emperor')}\n\n            Returned keys are always bytes. If a key from the keys argument is\n            missing from the dictionary, then that row was not found in the\n            table (and require_all_keys must have been False).\n\n        Raises:\n            IOError: An error occurred accessing the smalltable.\n        \"\"\"\n\n\nClasses\n~~~~~~~\n\nWe follow the `Google class conventions\n<https://google.github.io/styleguide/pyguide.html#384-classes>`_. 
Classes,\nlike functions, should have a docstring below the definition\ndescribing the class and the class functionality. If the class\ncontains public attributes, the class should have an attributes\nsection where each attribute is listed by name and followed by a\ndescription, separated by a colon, like for function parameters. For\nexample,\n\n| Example:\n\n.. code:: Python\n\n    class Foo:\n        \"\"\"One-liner describing the class.\n\n        Additional information or description for the class.\n        Can be multi-line.\n\n        Attributes:\n            attr1: First attribute of the class.\n            attr2: Second attribute of the class.\n        \"\"\"\n\n        def __init__(self):\n            \"\"\"Should have a docstring of type function.\"\"\"\n            pass\n\n        def method(self):\n            \"\"\"Should have a docstring of type function.\"\"\"\n            pass\n\n\nExtra Sections\n~~~~~~~~~~~~~~\n\nWe follow the `NumPy style guide\n<https://numpydoc.readthedocs.io/en/latest/format.html#sections>`_. Notably,\nthe following are sections that can be added to functions, modules,\nclasses, or method definitions.\n\n-  See Also:\n\n   - Refers to related code. Used to direct users to other modules,\n     functions, or classes that they may not be aware of.\n   - When referring to functions in the same sub-module, no prefix is\n     needed. Example: For ``numpy.mean`` inside the same sub-module:\n\n     .. code-block:: python\n\n       \"\"\"\n       See Also\n       --------\n       average: Weighted average.\n       \"\"\"\n\n   - For a reference to ``fft`` in another module:\n\n     .. code-block:: python\n\n       \"\"\"\n       See Also\n       --------\n       fft.fft2: 2-D fast discrete Fourier transform.\n       \"\"\"\n\n-  Notes\n\n   - Provide additional information about the code. May include\n     mathematical equations in LaTeX format. For example,\n\n     .. 
code-block:: python\n\n       r\"\"\"\n       Notes\n       -----\n       The FFT is a fast implementation of the discrete Fourier transform:\n\n       .. math::\n\n            X_k = \\sum_{n=0}^{N-1} x_n e^{-j 2 \\pi k n / N}\n       \"\"\"\n\n     Math can also be used inline:\n\n     .. code-block:: python\n\n       r\"\"\"\n       Notes\n       -----\n       The value of :math:`\\omega` is larger than 5.\n       \"\"\"\n\n     For a list of available LaTeX macros, search for \"macros\" in\n     `docs/source/conf.py <https://github.com/lanl/scico/blob/main/docs/source/conf.py>`_.\n\n-  Examples:\n\n   - Uses the doctest format and is meant to showcase usage.\n   - If there are multiple examples, include blank lines before and\n     after each example. For example,\n\n     .. code-block:: python\n\n       \"\"\"\n       Examples\n       --------\n       Necessary imports\n       >>> import numpy as np\n\n       Comment explaining example 1.\n\n       >>> int(np.add(1, 2))\n       3\n\n       Comment explaining a new example.\n\n       >>> np.add([1, 2], [3, 4])\n       array([4, 6])\n\n       If an example statement spans more than one line, start each\n       continuation line with ``...``\n\n       >>> np.add([[1, 2], [3, 4]],\n       ...        [[5, 6], [7, 8]])\n       array([[ 6,  8],\n              [10, 12]])\n\n       \"\"\"\n\n\nComments\n~~~~~~~~\n\nThere are two types of comments: *block* and *inline*. A good rule of\nthumb to follow for when to include a comment in your code is *if you\nhave to explain it or it is too hard to figure out at first glance, then\ncomment it*.  An example of this, taken from the `Google comment\nconventions\n<https://google.github.io/styleguide/pyguide.html#385-block-and-inline-comments>`_,\nis a complicated operation, which most likely requires a block of\ncomments beforehand.\n\n.. code-block:: Python\n\n    # We use a block comment because the following code performs a\n    # difficult operation. 
Here we can explain the variables or\n    # what the concept of the operation does in an easier\n    # to understand way.\n\n    if i & (i-1) == 0:  # true if i is 0 or a power of 2 [explains the concept not the code]\n\nIf a comment consists of one or more full sentences (as is typically\nthe case for *block* comments), it should start with an upper case\nletter and end with a period. *Inline* comments often consist of a\nbrief phrase which is not a full sentence, in which case they should\nhave a lower case initial letter and not have a terminating period.\n\n\nMarkup\n~~~~~~\n\nThe following components require the recommended markup taken from the\n`NumPy Conventions\n<https://numpydoc.readthedocs.io/en/latest/format.html#common-rest-concepts>`__:\n\n- Paragraphs:\n  Indentation is significant and indicates the indentation of the output. New\n  paragraphs are marked with a blank line.\n- Variable, parameter, module, function, method, and class names:\n  Should be written between single back-ticks (e.g. \\`x\\`, rendered as `x`), but\n  note that use of `Sphinx cross-reference syntax <https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#cross-referencing-python-objects>`_ is preferred for modules (`:mod:\\`module-name\\`` ), functions (`:func:\\`function-name\\`` ), methods (`:meth:\\`method-name\\`` ) and classes (`:class:\\`class-name\\`` ).\n- None, NoneType, True, and False:\n  Should be written between double back-ticks (e.g. \\`\\`None\\`\\`, \\`\\`True\\`\\`,\n  rendered as ``None``, ``True``).\n- Types:\n  Should be written between double back-ticks (e.g. 
\\`\\`int\\`\\`, rendered as ``int``).\n  NumPy dtypes, however, should be written using cross-reference syntax, e.g.\n  \\:attr\\:\\`~numpy.float32\\` for :attr:`~numpy.float32`.\n\nOther components can use \\*italics\\*, \\*\\*bold\\*\\*, and \\`\\`monospace\\`\\`\n(respectively rendered as *italics*, **bold**, and ``monospace``) if needed, but\nnot for variable names, doctest code, or multi-line code.\n\n\nDocumentation\n-------------\n\nDocumentation that is separate from code (like this page) should follow the\n`IEEE Style Manual\n<https://journals.ieeeauthorcenter.ieee.org/your-role-in-article-production/ieee-editorial-style-manual/>`_.\nFor additional grammar and usage guidance,\nrefer to `The Chicago Manual of Style <https://www.chicagomanualofstyle.org/>`_.\nA few notable guidelines:\n\n* Equations which conclude a sentence should end with a period,\n  e.g., \"Poisson's equation is\n\n  .. math::\n\n     \\Delta \\varphi = f \\;.\"\n\n* Do not capitalize the expanded words of an acronym or initialism when\n  defining it, e.g., \"computer-aided system engineering (CASE),\"\n  \"fast Fourier transform (FFT).\"\n\n* Avoid capitalization in text except where absolutely necessary,\n  e.g., \"Newton's first law.\"\n\n* Use a single space after the period at the end of a sentence.\n\n\nThe source code (`.rst` files) for these pages does not have a hard\nline-length guideline, but line breaks at or before 79 characters are\nencouraged.\n"
  },
  {
    "path": "docs/source/team.rst",
    "content": "Developers\n==========\n\nCore Developers\n---------------\n\n- `Cristina Garcia Cardona <https://github.com/crstngc>`_\n- `Michael McCann <https://github.com/Michael-T-McCann>`_\n- `Brendt Wohlberg <https://github.com/bwohlberg>`_\n\n\nEmeritus Developers\n-------------------\n\n- `Thilo Balke <https://github.com/tbalke>`_\n- `Fernando Davis <https://github.com/FernandoDavis>`_\n- `Soumendu Majee <https://github.com/smajee>`_\n- `Luke Pfister <https://github.com/lukepfister>`_\n\n\nContributors\n------------\n\n- `Weijie Gan <https://github.com/wjgancn>`_ (Non-blind variant of DnCNN)\n- `Oleg Korobkin <https://github.com/korobkin>`_ (BlockArray improvements)\n- `Andrew Leong <https://scholar.google.com/citations?user=-2wRWbcAAAAJ&hl=en>`_ (Improvements to optics module documentation)\n- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)\n- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)\n- `Li-Ta (Ollie) Lo <https://github.com/ollielo>`_ (ASTRA interface improvements)\n- `Renat Sibgatulin <https://github.com/Sibgatulin>`_ (Docs corrections)\n- `Salman Naqvi <https://github.com/shnaqvi>`_ (Contributions to approximate TV norm prox and proximal average implementation)\n- `Eddie Chandler <https://github.com/edchandler00>`_ (Contributions to approximate isotropic TV norm prox)\n"
  },
  {
    "path": "docs/source/zreferences.rst",
    "content": "References\n==========\n\n.. bibliography:: references.bib\n   :style: plain\n"
  },
  {
    "path": "docs/tikxfigures/img_align.tex",
    "content": "\\documentclass[tikz]{standalone}\n\\usetikzlibrary{calc,angles,quotes}\n\\begin{document}\n\\begin{tikzpicture}[scale=2]\n  \\footnotesize\n\n  % Define rectangle dimensions\n  \\def\\width{2}       % base width\n  \\def\\aspect{1.25}   % aspect ratio (height/width)\n  \\pgfmathsetmacro{\\height}{\\width*\\aspect}\n\n  % Rotate rectangle by 20 degrees\n  \\begin{scope}[rotate around={-20:(0,0)}]\n    % Draw rectangle with bottom-left corner at origin\n    \\draw[thick] (0,0) -| (\\width,\\height)\n           node[pos=0.25,below] {$N_1$}\n           node[pos=0.75,right] {$N_0$}\n           -| (0,0);\n\n    % Save post-rotation rectangle corners\n    \\coordinate (BL) at (0,0);\n    \\coordinate (BR) at (\\width,0);\n    \\coordinate (TL) at (0,\\height);\n    \\coordinate (TR) at (\\width,\\height);\n  \\end{scope}\n\n  \\def\\liney{2.5} % top line height\n  \\coordinate (PL) at (BL |- 0,\\liney); % vertical intersection from bottom-left\n  \\coordinate (PR) at (TR |- 0,\\liney); % vertical intersection from top-right\n\n   % Horizontal line representing sensor\n  \\draw[blue,thick] (PL) -- (PR);\n\n  % Draw verticals to meet horizontal line\n  \\draw[blue,thick,dashed] (BL) -- (PL);\n  \\draw[blue,thick,dashed] (TR) -- (PR);\n\n  % Double-sided arrow for width label\n  \\draw[<->,blue,dashed] (PL) ++(0,0.15) -- ($(PR)+(0,0.15)$)\n    node[midway,above] {$w_0 + w_1$};\n\n  % Central vertical line through top-left corner\n  \\draw[blue,thick,dashed] (TL |- BL) -- (TL |- 0,\\liney);\n\n  % Horizontal lines with labels\n  \\draw[blue,thick,dashed] (TR) -- (TL |- TR) node[midway,below] {$w_1 = N_1 \\cos(\\theta)$};\n  \\draw[blue,thick,dashed] (BL) -- (TL |- BL) node[right,below] {$\\qquad w_0 = N_0 \\sin(\\theta)$};\n\n  % Define intersection point with central vertical line\n  \\coordinate (VL) at (TL |- BR);\n  \\coordinate (HL) at (TL |- TR);\n\n  % θ between left rectangle side and central vertical line\n  \\pic [draw, ->, \"$\\theta$\", angle 
radius=30] {angle = BL--TL--VL};\n\n  % 90-θ between top rectangle side and horizontal line\n  \\pic [draw, ->, \"$90\\!-\\!\\theta\\quad\\;\\;$\", angle radius=50] {angle = TL--TR--HL};\n\n\\end{tikzpicture}\n\\end{document}\n"
  },
  {
    "path": "docs/tikxfigures/makesvg.sh",
    "content": "#! /bin/bash\n\npdf2svg vol_align_xyz.pdf vol_align_xyz.svg\npdf2svg vol_align_xz.pdf vol_align_xz.svg\npdf2svg vol_align_yz.pdf vol_align_yz.svg\npdf2svg img_align.pdf img_align.svg\n"
  },
  {
    "path": "docs/tikxfigures/vol_align_xyz.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz, tikz-3dplot}\n\\begin{document}\n\n\\tdplotsetmaincoords{70}{110}\n\n\\begin{tikzpicture}[scale=5,tdplot_main_coords]\n  \\footnotesize\n  \\draw[thick,->] (0,0,0) -- (1,0,0) node[anchor=north east]{$x$};\n  \\draw[thick,->] (0,0,0) -- (0,1,0) node[anchor=north west]{$y$};\n  \\draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$};\n\n  \\coordinate (O) at (0,0,0);\n  \\tdplotsetcoord{P}{1}{30}{40}\n\n  \\draw[-stealth,thick,color=red] (O) -- (P);\n  \\node[draw=none,color=red] at (0.0,0.11,0.77) {$(x, y, z)$};\n\n  \\draw[dashed, color=blue] (P) -- (Pxz);\n  \\draw[dashed, color=blue] (P) -- (Pyz);\n  \\draw[dashed, color=blue] (O) -- (Pxz);\n  \\draw[dashed, color=blue] (O) -- (Pyz);\n  \\draw[dashed, color=blue] (Pz) -- (Pxz);\n  \\draw[dashed, color=blue] (Pz) -- (Pyz);\n\n  \\node[draw=none,color=blue] at (0.38,0.0,0.5) {$r_x$};\n  \\node[draw=none,color=blue] at (0.0,0.24,0.4) {$r_y$};\n\n  \\tdplotsetthetaplanecoords{0}\n  \\tdplotdrawarc[tdplot_rotated_coords,blue,dotted]{(O)}{.25}{22.7}{90}{anchor=mid east}{$\\theta_x$}\n  \\tdplotsetthetaplanecoords{90}\n  \\tdplotdrawarc[tdplot_rotated_coords,blue,dotted]{(O)}{.25}{23}{90}{anchor=mid west}{$\\theta_y$}\n\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/tikxfigures/vol_align_xz.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz, tikz-3dplot}\n\\begin{document}\n\n\\tdplotsetmaincoords{90}{0}\n\n\\begin{tikzpicture}[scale=5,tdplot_main_coords]\n  \\footnotesize\n  \\draw[thick,->] (0,0,0) -- (1,0,0) node[anchor=west]{$x$};\n  \\draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$};\n\n  \\coordinate (O) at (0,0,0);\n  \\tdplotsetcoord{P}{1}{30}{40}\n\n  \\draw[-stealth,thick,color=red] (O) -- (P)  node[anchor=west]{$\\!(x, z)$};\n\n  \\draw[dashed, color=blue] (P) -- (Pyz);\n  \\draw[dashed, color=blue] (P) -- (Px);\n\n  \\node[draw=none,color=blue] at (0.4,0.0,-0.055) {$r_x \\cos (\\theta_x)$};\n  \\node[draw=none,rotate=90,color=blue] at (-0.055,0.0,0.84) {$r_x \\sin (\\theta_x)$};\n  \\tdplotsetthetaplanecoords{0}\n  \\tdplotdrawarc[tdplot_rotated_coords,red,dotted]{(O)}{.25}{22.7}{90}{anchor=north east}{$\\theta_x$}\n\n  \\node[draw=none,color=red] at (0.14,0.0,0.5) {$r_x$};\n\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/tikxfigures/vol_align_yz.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz, tikz-3dplot}\n\\begin{document}\n\n\\tdplotsetmaincoords{90}{90}\n\n\\begin{tikzpicture}[scale=5,tdplot_main_coords]\n  \\footnotesize\n  \\draw[thick,->] (0,0,0) -- (0,1,0) node[anchor=west]{$y$};\n  \\draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$};\n\n  \\coordinate (O) at (0,0,0);\n  \\tdplotsetcoord{P}{1}{30}{40}\n\n  \\draw[-stealth,thick,color=red] (O) -- (P)  node[anchor=west]{$\\!(y, z)$};\n\n  \\draw[dashed, color=blue] (P) -- (Pxz);\n  \\draw[dashed, color=blue] (P) -- (Py);\n\n  \\node[draw=none,color=blue] at (0.0,0.35,-0.055) {$r_y \\cos (\\theta_y)$};\n  \\node[draw=none,rotate=90,color=blue] at (0.0,-0.055,0.84) {$r_y \\sin (\\theta_y)$};\n  \\tdplotsetthetaplanecoords{90}\n  \\tdplotdrawarc[tdplot_rotated_coords,red,dotted]{(O)}{.25}{23}{90}{anchor=north east}{$\\theta_y$}\n\n  \\node[draw=none,color=red] at (0.05,0.11,0.5) {$r_y$};\n\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "examples/README.rst",
    "content": "SCICO Usage Examples\n====================\n\nThis directory contains usage examples for the SCICO package. The primary form of these examples is the Python scripts in the directory ``scripts``. A corresponding set of Jupyter notebooks, in the directory ``notebooks``, is auto-generated from these usage example scripts.\n\n\nBuilding Notebooks\n------------------\n\nThe scripts for building Jupyter notebooks from the source example scripts are currently only supported under Linux. All scripts described below should be run from this directory, i.e. ``[repo root]/examples``.\n\n\nRunning on a GPU\n^^^^^^^^^^^^^^^^\n\nSince some of the examples require a considerable amount of memory (``deconv_microscopy_tv_admm.py`` and ``deconv_microscopy_allchn_tv_admm.py`` in particular), it is recommended to set the following environment variables prior to building the notebooks:\n\n::\n\n  export XLA_PYTHON_CLIENT_ALLOCATOR=platform\n  export XLA_PYTHON_CLIENT_PREALLOCATE=false\n\n\nRunning on a CPU\n^^^^^^^^^^^^^^^^\n\nIf a GPU is not available, or if the available GPU does not have sufficient memory to build the notebooks, set the environment variable\n\n::\n\n  JAX_PLATFORM_NAME=cpu\n\nto run on the CPU instead.\n\n\nBuilding Specific Examples\n--------------------------\n\nTo build or rebuild notebooks for specific examples, the example script names can be specified on the command line, e.g.\n\n::\n\n  python makenotebooks.py ct_astra_pcg.py ct_astra_tv_admm.py\n\nWhen rebuilding notebooks for examples that themselves make use of ``ray``\nfor parallelization (e.g. 
``deconv_microscopy_allchn_tv_admm.py``), it is recommended to specify serial notebook execution, as in\n\n::\n\n  python makenotebooks.py --no-ray deconv_microscopy_allchn_tv_admm.py\n\n\nBuilding All Examples\n---------------------\n\nBy default, ``makenotebooks.py`` only rebuilds notebooks that are out of date with respect to their corresponding example scripts, as determined by their respective file timestamps. However, timestamps for files retrieved from version control may not be meaningful for this purpose. To rebuild all examples, the following commands (assuming that GPUs are available) are recommended:\n\n::\n\n  export XLA_PYTHON_CLIENT_ALLOCATOR=platform\n  export XLA_PYTHON_CLIENT_PREALLOCATE=false\n\n  touch scripts/*.py\n\n  python makenotebooks.py --no-ray deconv_microscopy_tv_admm.py deconv_microscopy_allchn_tv_admm.py\n\n  python makenotebooks.py\n\n\nUpdating Notebooks in the Repo\n------------------------------\n\nThe recommended procedure for rebuilding notebooks for inclusion in the ``data`` submodule is:\n\n1. Add and commit the modified script(s).\n\n2. Rebuild the notebooks as described above.\n\n3. Add and commit the updated notebooks following the submodule handling procedure described in the developer docs.\n\n\nAdding a New Notebook\n---------------------\n\nThe procedure for adding a new notebook is:\n\n1. Add an entry for the source file in ``scripts/index.rst``. Note that a script that is not listed in this index will not be converted into a notebook.\n\n2. Run ``makeindex.py`` to update the example scripts README file, the notebook index file, and the examples index in the docs.\n\n3. Build the corresponding notebook following the instructions above.\n\n4. 
Add and commit the new script, the ``scripts/index.rst`` script index file, the auto-generated ``scripts/README.rst`` file and ``docs/source/examples.rst`` index file, and the new or updated notebooks and the auto-generated ``notebooks/index.ipynb`` file in the notebooks directory, following the submodule handling procedure as described in the developer docs.\n\n\n\nManagement Utilities\n--------------------\n\nA number of files in this directory assist in the management of the usage examples:\n\n`examples_requirements.txt <examples_requirements.txt>`_\n   Requirements file (as used by ``pip``) listing additional dependencies for running the usage example scripts.\n\n`notebooks_requirements.txt <notebooks_requirements.txt>`_\n   Requirements file (as used by ``pip``) listing additional dependencies for building the Jupyter notebooks from the usage example scripts.\n\n`makenotebooks.py <makenotebooks.py>`_\n   Auto-generate Jupyter notebooks from the example scripts.\n\n`updatejnbmd.py <updatejnbmd.py>`_\n   Update markdown cells in notebooks from corresponding example scripts.\n\n`makeindex.py <makeindex.py>`_\n   Auto-generate the example scripts README ``scripts/README.rst``, the notebook index ``notebooks/index.ipynb``, and the docs example index ``docs/source/examples.rst`` from the example scripts index ``scripts/index.rst``.\n\n`scriptcheck.sh <scriptcheck.sh>`_\n   Run all example scripts with smaller problems and a reduced number of iterations as a rapid check that they are functioning correctly.\n"
  },
  {
    "path": "examples/examples_requirements.txt",
    "content": "-r ../requirements.txt\ncolorama\ncolour_demosaicing\nsvmbir>=0.4.0\nastra-toolbox\nxdesign>=0.5.5\nray[tune,train]>=2.44\nhyperopt\nsetuptools<82.0.0  # workaround for hyperopt 0.2.7\npydantic\norbax-checkpoint>=0.5.0\nbm3d>=4.0.0\nbm4d>=4.2.2\n"
  },
  {
    "path": "examples/jnb.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Support functions for manipulating Jupyter notebooks.\"\"\"\n\nimport re\nfrom timeit import default_timer as timer\n\nimport nbformat\nfrom nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor\nfrom py2jn.tools import py_string_to_notebook, write_notebook\n\n\ndef py_file_to_string(src):\n    \"\"\"Preprocess example script file and return result as a string.\"\"\"\n\n    with open(src, \"r\") as srcfile:\n        # Drop header comment\n        for line in srcfile:\n            if line[0] != \"#\":\n                break  # assume first non-comment line is a newline that can be dropped\n        # Insert notebook plot config after last import\n        lines = []\n        import_seen = False\n        for line in srcfile:\n            line = re.sub('^r\"\"\"', '\"\"\"', line)  # remove r from r\"\"\"\n            line = re.sub(\":cite:`([^`]+)`\", r'<cite data-cite=\"\\1\"/>', line)  # fix cite format\n            if import_seen:\n                # Once an import statement has been seen, break on encountering a line that\n                # is neither an import statement nor a newline, nor a component of an import\n                # statement extended over multiple lines, nor an os.environ statement, nor a\n                # ray.init statement, nor components of a try/except construction (note that\n                # handling of these final two cases is probably not very robust).\n                if not re.match(\n                    r\"(^import|^from|^\\n$|^\\W+[^\\W]|^\\)$|^os.environ|^ray.init|^try:$|^except)\",\n                    line,\n                ):\n                    lines.append(line)\n                    break\n            else:\n            
    # Set flag indicating that an import statement has been seen once one has\n                # been encountered\n                if re.match(\"^import|^from .* import\", line):\n                    import_seen = True\n            lines.append(line)\n\n        if \"plot\" in \"\".join(lines):\n            # Backtrack through list of lines to find last import statement\n            n = 1\n            for line in lines[-2::-1]:\n                if re.match(\"^(import|from)\", line):\n                    break\n                else:\n                    n += 1\n            # Insert notebook plotting config directly after last import statement\n            lines.insert(-n, \"plot.config_notebook_plotting()\\n\")\n\n        # Process remainder of source file\n        for line in srcfile:\n            if re.match(r\"^input\\(\", line):  # end processing when input statement encountered\n                break\n            line = re.sub('^r\"\"\"', '\"\"\"', line)  # remove r from r\"\"\"\n            line = re.sub(r\":cite:\\`([^`]+)\\`\", r'<cite data-cite=\"\\1\"/>', line)  # fix cite format\n            lines.append(line)\n\n        # Backtrack through list of lines to remove trailing newlines\n        n = 0\n        for line in lines[::-1]:\n            if re.match(\"^\\n$\", line):\n                n += 1\n            else:\n                break\n        if n > 0:\n            lines = lines[0:-n]\n\n        return \"\".join(lines)\n\n\ndef script_to_notebook(src, dst):\n    \"\"\"Convert a Python example script into a Jupyter notebook.\"\"\"\n\n    s = py_file_to_string(src)\n    nb = py_string_to_notebook(s)\n    write_notebook(nb, dst)\n\n\ndef read_notebook(fname):\n    \"\"\"Read a notebook from the specified notebook file.\"\"\"\n\n    try:\n        nb = nbformat.read(fname, as_version=4)\n    except (AttributeError, nbformat.reader.NotJSONError):\n        raise RuntimeError(\"Error reading notebook file %s.\" % fname)\n    return nb\n\n\ndef 
execute_notebook(fname):\n    \"\"\"Execute the specified notebook file.\"\"\"\n\n    with open(fname) as f:\n        nb = nbformat.read(f, as_version=4)\n    ep = ExecutePreprocessor(timeout=None)\n    try:\n        t0 = timer()\n        out = ep.preprocess(nb)\n        t1 = timer()\n        with open(fname, \"w\", encoding=\"utf-8\") as f:\n            nbformat.write(nb, f)\n    except CellExecutionError:\n        print(f\"ERROR executing {fname}\")\n        return False\n    print(f\"{fname} done in {(t1 - t0):.1e} s\")\n    return True\n\n\ndef notebook_executed(nbfn):\n    \"\"\"Determine whether the notebook at `nbfn` has been executed.\"\"\"\n\n    try:\n        nb = nbformat.read(nbfn, as_version=4)\n    except (AttributeError, nbformat.reader.NotJSONError):\n        raise RuntimeError(\"Error reading notebook file %s.\" % pth)\n    cells = nb[\"worksheets\"][0][\"cells\"]\n    for n in range(len(nb[\"cells\"])):\n        if cells[n].cell_type == \"code\" and cells[n].execution_count is None:\n            return False\n    return True\n\n\ndef same_notebook_code(nb1, nb2):\n    \"\"\"Return ``True`` if the code cells of notebook objects `nb1` and `nb2`\n    are all the same.\n    \"\"\"\n\n    if \"cells\" in nb1:\n        nb1c = nb1[\"cells\"]\n    else:\n        nb1c = nb1[\"worksheets\"][0][\"cells\"]\n    if \"cells\" in nb2:\n        nb2c = nb2[\"cells\"]\n    else:\n        nb2c = nb2[\"worksheets\"][0][\"cells\"]\n\n    # Notebooks do not match if the number of cells differ\n    if len(nb1c) != len(nb2c):\n        return False\n\n    # Iterate over cells in nb1\n    for n in range(len(nb1c)):\n        # Notebooks do not match if corresponding cells have different\n        # types\n        if nb1c[n][\"cell_type\"] != nb2c[n][\"cell_type\"]:\n            return False\n        # Notebooks do not match if source of corresponding code cells\n        # differ\n        if nb1c[n][\"cell_type\"] == \"code\" and nb1c[n][\"source\"] != nb2c[n][\"source\"]:\n  
          return False\n\n    return True\n\n\ndef same_notebook_markdown(nb1, nb2):\n    \"\"\"Return ``True`` if the markdown cells of notebook objects `nb1`\n    and `nb2` are all the same.\n    \"\"\"\n\n    if \"cells\" in nb1:\n        nb1c = nb1[\"cells\"]\n    else:\n        nb1c = nb1[\"worksheets\"][0][\"cells\"]\n    if \"cells\" in nb2:\n        nb2c = nb2[\"cells\"]\n    else:\n        nb2c = nb2[\"worksheets\"][0][\"cells\"]\n\n    # Notebooks do not match if the number of cells differ\n    if len(nb1c) != len(nb2c):\n        return False\n\n    # Iterate over cells in nb1\n    for n in range(len(nb1c)):\n        # Notebooks do not match if corresponding cells have different\n        # types\n        if nb1c[n][\"cell_type\"] != nb2c[n][\"cell_type\"]:\n            return False\n        # Notebooks do not match if source of corresponding code cells\n        # differ\n        if nb1c[n][\"cell_type\"] == \"markdown\" and nb1c[n][\"source\"] != nb2c[n][\"source\"]:\n            return False\n\n    return True\n\n\ndef replace_markdown_cells(src, dst):\n    \"\"\"Overwrite markdown cells in notebook object `dst` with corresponding\n    cells in notebook object `src`.\n    \"\"\"\n\n    if \"cells\" in src:\n        srccell = src[\"cells\"]\n    else:\n        srccell = src[\"worksheets\"][0][\"cells\"]\n    if \"cells\" in dst:\n        dstcell = dst[\"cells\"]\n    else:\n        dstcell = dst[\"worksheets\"][0][\"cells\"]\n\n    # It is an error to attempt markdown replacement if src and dst\n    # have different numbers of cells\n    if len(srccell) != len(dstcell):\n        raise ValueError(\"Notebooks do not have the same number of cells.\")\n\n    # Iterate over cells in src\n    for n in range(len(srccell)):\n        # It is an error to attempt markdown replacement if any\n        # corresponding pair of cells have different type\n        if srccell[n][\"cell_type\"] != dstcell[n][\"cell_type\"]:\n            raise ValueError(\"Cell number %d of 
different type in src and dst.\")\n        # If current src cell is a markdown cell, copy the src cell to\n        # the dst cell\n        if srccell[n][\"cell_type\"] == \"markdown\":\n            dstcell[n][\"source\"] = srccell[n][\"source\"]\n\n\ndef remove_error_output(src):\n    \"\"\"Remove output to stderr from all cells in `src`.\"\"\"\n\n    if \"cells\" in src:\n        cells = src[\"cells\"]\n    else:\n        cells = src[\"worksheets\"][0][\"cells\"]\n\n    modified = False\n    for c in cells:\n        if \"outputs\" in c:\n            dellist = []\n            for n, out in enumerate(c[\"outputs\"]):\n                if \"name\" in out and out[\"name\"] == \"stderr\":\n                    dellist.append(n)\n                    modified = True\n            for n in dellist[::-1]:\n                del c[\"outputs\"][n]\n\n    return modified\n"
  },
  {
    "path": "examples/makeindex.py",
    "content": "#!/usr/bin/env python\n\n# Construct an index README file and a docs example index file from\n# source index file \"scripts/index.rst\".\n# Run as\n#     python makeindex.py\n\n\nimport re\nfrom pathlib import Path\n\nimport nbformat as nbf\nimport py2jn\nimport pypandoc\n\nsrc = \"scripts/index.rst\"\n\n# Make dict mapping script names to docstring header titles\ntitles = {}\nscripts = list(Path(\"scripts\").glob(\"*py\"))\nfor s in scripts:\n    prevline = None\n    with open(s, \"r\") as sfile:\n        for line in sfile:\n            if line[0:3] == \"===\":\n                titles[s.name] = prevline.rstrip()\n                break\n            else:\n                prevline = line\n\n\n# Build README in scripts directory\ndst = \"scripts/README.rst\"\nwith open(dst, \"w\") as dstfile:\n    with open(src, \"r\") as srcfile:\n        for line in srcfile:\n            # Detect lines containing script filenames\n            m = re.match(r\"(\\s+)- ([^\\s]+.py)\", line)\n            if m:\n                prespace = m.group(1)\n                name = m.group(2)\n                title = titles[name]\n                print(\n                    \"%s`%s <%s>`_\\n%s   %s\" % (prespace, name, name, prespace, title), file=dstfile\n                )\n            else:\n                print(line, end=\"\", file=dstfile)\n\n\n# Build notebooks index file in notebooks directory\ndst = \"notebooks/index.ipynb\"\nrst_text = \"\"\nwith open(src, \"r\") as srcfile:\n    for line in srcfile:\n        # Detect lines containing script filenames\n        m = re.match(r\"(\\s+)- ([^\\s]+).py\", line)\n        if m:\n            prespace = m.group(1)\n            name = m.group(2)\n            title = titles[name + \".py\"]\n            rst_text += \"%s- `%s <%s.ipynb>`_\\n\" % (prespace, title, name)\n        else:\n            rst_text += line\n# Convert text from rst to markdown\nmd_format = \"markdown_github+tex_math_dollars+fenced_code_attributes\"\nmd_text = 
pypandoc.convert_text(rst_text, md_format, format=\"rst\", extra_args=[\"--atx-headers\"])\nmd_text = '\"\"\"' + md_text + '\"\"\"'\n# Convert from python to notebook format and write notebook\nnb = py2jn.py_string_to_notebook(md_text)\npy2jn.tools.write_notebook(nb, dst, nbver=4)\nnb = nbf.read(dst, nbf.NO_CONVERT)\nnb.metadata = {\"nbsphinx\": {\"orphan\": True}}\nnbf.write(nb, dst)\n\n\n# Build examples index for docs\ndst = \"../docs/source/examples.rst\"\nprfx = \"examples/\"\nwith open(dst, \"w\") as dstfile:\n    print(\".. _example_notebooks:\\n\", file=dstfile)\n    with open(src, \"r\") as srcfile:\n        for line in srcfile:\n            # Add toctree and include statements after main heading\n            if line[0:3] == \"===\":\n                print(line, end=\"\", file=dstfile)\n                print(\"\\n.. toctree::\\n   :maxdepth: 1\", file=dstfile)\n                print(\"\\n.. include:: include/examplenotes.rst\", file=dstfile)\n                continue\n            # Detect lines containing script filenames\n            m = re.match(r\"(\\s+)- ([^\\s]+).py\", line)\n            if m:\n                print(\"   \" + prfx + m.group(2), file=dstfile)\n            else:\n                print(line, end=\"\", file=dstfile)\n                # Add toctree statement after section headings\n                if line[0:3] == line[0] * 3 and line[0] in [\"=\", \"-\", \"^\"]:\n                    print(\"\\n.. toctree::\\n   :maxdepth: 1\", file=dstfile)\n"
  },
  {
    "path": "examples/makenotebooks.py",
    "content": "#!/usr/bin/env python\n\n# Extract a list of Python scripts from \"scripts/index.rst\" and\n# create/update and execute any Jupyter notebooks that are out\n# of date with respect to their source Python scripts. If script\n# names specified on command line, process them instead.\n# Run\n#     python makenotebooks.py -h\n# for usage details.\n\nimport argparse\nimport os\nimport re\nimport signal\nimport sys\nfrom pathlib import Path\n\nimport psutil\nfrom jnb import execute_notebook, script_to_notebook\n\nexamples_dir = Path(__file__).resolve().parent  # absolute path to ../scico/examples/\n\nhave_ray = True\ntry:\n    import ray\nexcept ImportError:\n    have_ray = False\n\n\ndef script_uses_ray(fname):\n    \"\"\"Determine whether a script uses ray.\"\"\"\n\n    with open(fname, \"r\") as f:\n        text = f.read()\n    return bool(re.search(\"^import ray\", text, re.MULTILINE)) or bool(\n        re.search(\"^import scico.ray\", text, re.MULTILINE)\n    )\n\n\ndef script_path(sname):\n    \"\"\"Get script path from script name.\"\"\"\n\n    return examples_dir / \"scripts\" / Path(sname)\n\n\ndef notebook_path(sname):\n    \"\"\"Get notebook path from script path.\"\"\"\n\n    return examples_dir / \"notebooks\" / Path(Path(sname).stem + \".ipynb\")\n\n\nargparser = argparse.ArgumentParser(\n    description=\"Convert Python example scripts to Jupyter notebooks.\"\n)\nargparser.add_argument(\n    \"--all\",\n    action=\"store_true\",\n    help=\"Process all notebooks, without checking timestamps. 
\"\n    \"Has no effect when files to process are explicitly specified.\",\n)\nargparser.add_argument(\n    \"--no-exec\", action=\"store_true\", help=\"Create/update notebooks but don't execute them.\"\n)\nargparser.add_argument(\n    \"--no-ray\",\n    action=\"store_true\",\n    help=\"Execute notebooks serially, without the use of ray parallelization.\",\n)\nargparser.add_argument(\n    \"--verbose\",\n    action=\"store_true\",\n    help=\"Verbose operation.\",\n)\nargparser.add_argument(\n    \"--test\",\n    action=\"store_true\",\n    help=\"Show actions that would be taken but don't do anything.\",\n)\nargparser.add_argument(\"filename\", nargs=\"*\", help=\"Optional Python example script filenames\")\nargs = argparser.parse_args()\n\n\n# Raise error if ray needed but not present\nif not have_ray and not args.no_ray:\n    raise RuntimeError(\"The ray package is required to run this script, try --no-ray\")\n\n\nif args.filename:\n    # Script names specified on command line\n    scriptnames = [os.path.basename(s) for s in args.filename]\nelse:\n    # Read script names from index file\n    scriptnames = []\n    srcidx = examples_dir / \"scripts\" / \"index.rst\"\n    with open(srcidx, \"r\") as idxfile:\n        for line in idxfile:\n            m = re.match(r\"(\\s+)- ([^\\s]+.py)\", line)\n            if m:\n                scriptnames.append(m.group(2))\n\n# Ensure list entries are unique\nscriptnames = sorted(list(set(scriptnames)))\n\n# Create list of selected scripts.\nscripts = []\nfor s in scriptnames:\n    spath = script_path(s)\n    npath = notebook_path(s)\n    # If scripts specified on command line or --all flag specified, convert all scripts.\n    # Otherwise, only convert scripts that have a newer timestamp than their corresponding\n    # notebooks, or that have not previously been converted (i.e. 
corresponding notebook\n    # file does not exist).\n    if (\n        args.all\n        or args.filename\n        or not npath.is_file()\n        or spath.stat().st_mtime > npath.stat().st_mtime\n    ):\n        # Add to the list of selected scripts\n        scripts.append(s)\n\nif not scripts:\n    if args.verbose:\n        print(\"No scripts require conversion\")\n    sys.exit(0)\n\n# Display status information\nif args.verbose:\n    print(f\"Processing scripts {', '.join(scripts)}\")\n\n# Convert selected scripts to corresponding notebooks and determine which can be run in parallel\nserial_scripts = []\nparallel_scripts = []\nfor s in scripts:\n    spath = script_path(s)\n    npath = notebook_path(s)\n    # Determine how script should be executed\n    if script_uses_ray(spath):\n        serial_scripts.append(s)\n    else:\n        parallel_scripts.append(s)\n    # Make notebook file\n    if args.verbose or args.test:\n        print(f\"Converting script {s} to notebook\")\n    if not args.test:\n        script_to_notebook(spath, npath)\n\nif args.no_exec:\n    if args.verbose:\n        print(\"Notebooks will not be executed\")\n    sys.exit(0)\n\n\n# If ray disabled or not worth using, run all serially\nif args.no_ray or len(parallel_scripts) < 2:\n    serial_scripts.extend(parallel_scripts)\n    parallel_scripts = []\n\n# Execute notebooks corresponding to serial_scripts\nfor s in serial_scripts:\n    npath = notebook_path(s)\n    if args.verbose or args.test:\n        print(f\"Executing notebook corresponding to script {s}\")\n    if not args.test:\n        execute_notebook(npath)\n\n\n# Execute notebooks corresponding to parallel_scripts\nif parallel_scripts:\n    if args.verbose or args.test:\n        print(\n            f\"Notebooks corresponding to scripts {', '.join(parallel_scripts)} will \"\n            \"be executed in parallel\"\n        )\n\n    # Execute notebooks in parallel using ray\n    nproc = len(parallel_scripts)\n    ray.init()\n\n    ngpu = 
0\n    ar = ray.available_resources()\n    ncpu = max(int(ar[\"CPU\"]) // nproc, 1)\n    if \"GPU\" in ar:\n        ngpu = max(int(ar[\"GPU\"]) // nproc, 1)\n    if args.verbose or args.test:\n        print(f\"    Running on {ncpu} CPUs and {ngpu} GPUs per process\")\n\n    # Function to execute each notebook with available resources suitably divided\n    @ray.remote(num_cpus=ncpu, num_gpus=ngpu)\n    def ray_run_nb(fname):\n        execute_notebook(fname)\n\n    if not args.test:\n        # Execute relevant notebooks in parallel\n        try:\n            notebooks = [notebook_path(s) for s in parallel_scripts]\n            objrefs = [ray_run_nb.remote(nbfile) for nbfile in notebooks]\n            ray.wait(objrefs, num_returns=len(objrefs))\n        except KeyboardInterrupt:\n            print(\"\\nTerminating on keyboard interrupt\")\n            for ref in objrefs:\n                ray.cancel(ref, force=True)\n            ray.shutdown()\n            # Clean up sub-processes not ended by ray.cancel\n            process = psutil.Process()\n            children = process.children(recursive=True)\n            for child in children:\n                os.kill(child.pid, signal.SIGTERM)\n"
  },
  {
    "path": "examples/notebooks_requirements.txt",
    "content": "-r examples-requirements.txt\nipykernel\nipywidgets\nnbformat\nnbconvert\nnb_conda_kernels<=2.5.1  # 2.5.2 broken: see anaconda/nb_conda_kernels#280\npsutil\npy2jn\npypandoc\n"
  },
  {
    "path": "examples/removejnberr.py",
    "content": "#!/usr/bin/env python\n\n# Remove output to stderr in notebooks. NB: use with caution!\n# Run as\n#     python removejnberr.py\n\nimport glob\nimport os\n\nfrom jnb import read_notebook, remove_error_output\nfrom py2jn.tools import write_notebook\n\nfor src in glob.glob(os.path.join(\"notebooks\", \"*.ipynb\")):\n    nb = read_notebook(src)\n    modflg = remove_error_output(nb)\n    if modflg:\n        print(f\"Removing output to stderr from {src}\")\n        write_notebook(nb, src)\n"
  },
  {
    "path": "examples/scriptcheck.sh",
    "content": "#!/usr/bin/env bash\n\n# Basic test of example script functionality by running them all with\n# optimization algorithms configured to use only a small number of iterations.\n# Currently only supported under Linux.\n\nSCRIPT=$(basename $0)\nSCRIPTPATH=$(realpath $(dirname $0))\nUSAGE=$(cat <<-EOF\nUsage: $SCRIPT [-h] [-d]\n          [-h] Display usage information\n          [-e] Display excerpt of error message on failure\n          [-d] Skip tests involving additional data downloads\n          [-t] Skip tests related to learned model training\n          [-g] Skip tests that need a GPU\nEOF\n)\n\nOPTIND=1\nDISPLAY_ERROR=0\nSKIP_DOWNLOAD=0\nSKIP_TRAINING=0\nSKIP_GPU=0\nwhile getopts \":hedtg\" opt; do\n    case $opt in\n    h) echo \"$USAGE\"; exit 0;;\n    e) DISPLAY_ERROR=1;;\n    d) SKIP_DOWNLOAD=1;;\n    t) SKIP_TRAINING=1;;\n    g) SKIP_GPU=1;;\n    \\?) echo \"Error: invalid option -$OPTARG\" >&2\n            echo \"$USAGE\" >&2\n            exit 1\n            ;;\n    esac\ndone\n\nshift $((OPTIND-1))\nif [ ! $# -eq 0 ] ; then\n    echo \"Error: no positional arguments\" >&2\n    echo \"$USAGE\" >&2\n    exit 2\nfi\n\n# Set environment variables and paths. 
This script is assumed to be run\n# from its root directory.\nexport PYTHONPATH=$SCRIPTPATH/..\nexport PYTHONIOENCODING=utf-8\nexport MPLBACKEND=agg\nexport PYTHONWARNINGS=ignore:Matplotlib:UserWarning\nd='/tmp/scriptcheck_'$$\nmkdir -p $d\nretval=0\n\n# On SIGINT clean up temporary script directory and exit.\nfunction cleanupexit {\n    rm $d/*.py\n    rmdir $d\n    exit 2\n}\ntrap cleanupexit SIGINT\n\n# Define regex strings.\nre1=\"s/'maxiter' ?: ?[0-9]+/'maxiter': 2/g; \"\nre2=\"s/^maxiter ?= ?[0-9]+/maxiter = 2/g; \"\nre3=\"s/^N ?= ?[0-9]+/N = 32/g; \"\nre4=\"s/num_samples= ?[0-9]+/num_samples = 2/g; \"\nre5='s/\\\"cpu\\\": ?[0-9]+/\\\"cpu\\\": 1/g; '\nre6=\"s/^downsampling_rate ?= ?[0-9]+/downsampling_rate = 12/g; \"\nre7=\"s/input\\(/#input\\(/g; \"\nre8=\"s/fig.show\\(/#fig.show\\(/g\"\n\n# Iterate over all scripts.\nfor f in $SCRIPTPATH/scripts/*.py; do\n\n    printf \"%-50s \" $(basename $f)\n\n    # Skip problem cases.\n    if [ $SKIP_DOWNLOAD -eq 1 ] && grep -q '_microscopy' <<< $f; then\n        printf \"%s\\n\" skipped\n        continue\n    fi\n    if [ $SKIP_TRAINING -eq 1 ]; then\n    if grep -q '_datagen' <<< $f || grep -q '_train' <<< $f; then\n        printf \"%s\\n\" skipped\n        continue\n        fi\n    fi\n    if [ $SKIP_GPU -eq 1 ] && grep -q '_astra_3d' <<< $f; then\n        printf \"%s\\n\" skipped\n        continue\n    fi\n    if [ $SKIP_GPU -eq 1 ] && grep -q 'ct_projector_comparison_3d' <<< $f; then\n        printf \"%s\\n\" skipped\n        continue\n    fi\n\n    # Create temporary copy of script with all algorithm maxiter values set\n    # to small number and final input statements commented out.\n    g=$d/$(basename $f)\n    sed -E -e \"$re1$re2$re3$re4$re5$re6$re7$re8\" $f > $g\n\n    # Run temporary script and print status message.\n    if output=$(timeout 180s python $g 2>&1); then\n        printf \"%s\\n\" succeeded\n    else\n        printf \"%s\\n\" FAILED\n        retval=1\n    if [ $DISPLAY_ERROR -eq 1 ]; then\n       
echo \"$output\" | tail -8 | sed -e 's/^/    /'\n    fi\n    fi\n\n    # Remove temporary script.\n    rm -f $g\n\ndone\n\n# Remove temporary script directory.\nrmdir $d\n\nexit $retval\n"
  },
  {
    "path": "examples/scripts/README.rst",
    "content": "Usage Examples\n==============\n\n\nOrganized by Application\n------------------------\n\n\nComputed Tomography\n^^^^^^^^^^^^^^^^^^^\n\n   `ct_abel_tv_admm.py <ct_abel_tv_admm.py>`_\n      TV-Regularized Abel Inversion\n   `ct_abel_tv_admm_tune.py <ct_abel_tv_admm_tune.py>`_\n      Parameter Tuning for TV-Regularized Abel Inversion\n   `ct_symcone_tv_padmm.py <ct_symcone_tv_padmm.py>`_\n      TV-Regularized Cone Beam CT for Symmetric Objects\n   `ct_astra_noreg_pcg.py <ct_astra_noreg_pcg.py>`_\n      CT Reconstruction with CG and PCG\n   `ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_\n      3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver)\n   `ct_astra_3d_tv_padmm.py <ct_astra_3d_tv_padmm.py>`_\n      3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)\n   `ct_tv_admm.py <ct_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Integrated Projector)\n   `ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector)\n   `ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)\n   `ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_\n      TV-Regularized Low-Dose CT Reconstruction\n   `ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_\n      TV-Regularized CT Reconstruction (Multiple Algorithms)\n   `ct_svmbir_ppp_bm3d_admm_cg.py <ct_svmbir_ppp_bm3d_admm_cg.py>`_\n      PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver)\n   `ct_svmbir_ppp_bm3d_admm_prox.py <ct_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)\n   `ct_fan_svmbir_ppp_bm3d_admm_prox.py <ct_fan_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) Fan-Beam CT Reconstruction\n   `ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_\n      CT Training and Reconstruction with MoDL\n   `ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_\n      CT Training and 
Reconstruction with ODP\n   `ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_\n      CT Training and Reconstructions with UNet\n   `ct_projector_comparison_2d.py <ct_projector_comparison_2d.py>`_\n      2D X-ray Transform Comparison\n   `ct_projector_comparison_3d.py <ct_projector_comparison_3d.py>`_\n      3D X-ray Transform Comparison\n\nDeconvolution\n^^^^^^^^^^^^^\n\n   `deconv_circ_tv_admm.py <deconv_circ_tv_admm.py>`_\n      Circulant Blur Image Deconvolution with TV Regularization\n   `deconv_tv_admm.py <deconv_tv_admm.py>`_\n      Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_tv_padmm.py <deconv_tv_padmm.py>`_\n      Image Deconvolution with TV Regularization (Proximal ADMM Solver)\n   `deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_\n      Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_microscopy_tv_admm.py <deconv_microscopy_tv_admm.py>`_\n      Deconvolution Microscopy (Single Channel)\n   `deconv_microscopy_allchn_tv_admm.py <deconv_microscopy_allchn_tv_admm.py>`_\n      Deconvolution Microscopy (All Channels)\n   `deconv_ppp_bm3d_admm.py <deconv_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_bm3d_apgm.py <deconv_ppp_bm3d_apgm.py>`_\n      PPP (with BM3D) Image Deconvolution (APGM Solver)\n   `deconv_ppp_dncnn_admm.py <deconv_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_dncnn_padmm.py <deconv_ppp_dncnn_padmm.py>`_\n      PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)\n   `deconv_ppp_bm4d_admm.py <deconv_ppp_bm4d_admm.py>`_\n      PPP (with BM4D) Volume Deconvolution\n   `deconv_modl_train_foam1.py <deconv_modl_train_foam1.py>`_\n      Deconvolution Training and Reconstructions with MoDL\n   `deconv_odp_train_foam1.py <deconv_odp_train_foam1.py>`_\n      Deconvolution Training and Reconstructions with ODP\n\n\nSparse Coding\n^^^^^^^^^^^^^\n\n   `sparsecode_nn_admm.py 
<sparsecode_nn_admm.py>`_\n      Non-Negative Basis Pursuit DeNoising (ADMM)\n   `sparsecode_nn_apgm.py <sparsecode_nn_apgm.py>`_\n      Non-Negative Basis Pursuit DeNoising (APGM)\n   `sparsecode_conv_admm.py <sparsecode_conv_admm.py>`_\n      Convolutional Sparse Coding (ADMM)\n   `sparsecode_conv_md_admm.py <sparsecode_conv_md_admm.py>`_\n      Convolutional Sparse Coding with Mask Decoupling (ADMM)\n   `sparsecode_apgm.py <sparsecode_apgm.py>`_\n      Basis Pursuit DeNoising (APGM)\n   `sparsecode_poisson_apgm.py <sparsecode_poisson_apgm.py>`_\n      Non-negative Poisson Loss Reconstruction (APGM)\n\n\nMiscellaneous\n^^^^^^^^^^^^^\n\n   `demosaic_ppp_bm3d_admm.py <demosaic_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Demosaicing\n   `superres_ppp_dncnn_admm.py <superres_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Superresolution\n   `denoise_l1tv_admm.py <denoise_l1tv_admm.py>`_\n      ℓ1 Total Variation Denoising\n   `denoise_ptv_pdhg.py <denoise_ptv_pdhg.py>`_\n      Polar Total Variation Denoising (PDHG)\n   `denoise_tv_admm.py <denoise_tv_admm.py>`_\n      Total Variation Denoising (ADMM)\n   `denoise_tv_apgm.py <denoise_tv_apgm.py>`_\n      Total Variation Denoising with Constraint (APGM)\n   `denoise_tv_multi.py <denoise_tv_multi.py>`_\n      Comparison of Optimization Algorithms for Total Variation Denoising\n   `denoise_approx_tv_multi.py <denoise_approx_tv_multi.py>`_\n      Denoising with Approximate Total Variation Proximal Operator\n   `denoise_cplx_tv_nlpadmm.py <denoise_cplx_tv_nlpadmm.py>`_\n      Complex Total Variation Denoising with NLPADMM Solver\n   `denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_\n      Complex Total Variation Denoising with PDHG Solver\n   `denoise_dncnn_universal.py <denoise_dncnn_universal.py>`_\n      Comparison of DnCNN Variants for Image Denoising\n   `diffusercam_tv_admm.py <diffusercam_tv_admm.py>`_\n      TV-Regularized 3D DiffuserCam Reconstruction\n   `video_rpca_admm.py <video_rpca_admm.py>`_\n    
  Video Decomposition via Robust PCA\n   `ct_datagen_foam2.py <ct_datagen_foam2.py>`_\n      CT Data Generation for NN Training\n   `deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_\n      Blurred Data Generation (Natural Images) for NN Training\n   `deconv_datagen_foam1.py <deconv_datagen_foam1.py>`_\n      Blurred Data Generation (Foams) for NN Training\n   `denoise_datagen_bsds.py <denoise_datagen_bsds.py>`_\n      Noisy Data Generation for NN Training\n\n\nOrganized by Regularization\n---------------------------\n\nPlug and Play Priors\n^^^^^^^^^^^^^^^^^^^^\n\n   `ct_svmbir_ppp_bm3d_admm_cg.py <ct_svmbir_ppp_bm3d_admm_cg.py>`_\n      PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver)\n   `ct_svmbir_ppp_bm3d_admm_prox.py <ct_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)\n   `ct_fan_svmbir_ppp_bm3d_admm_prox.py <ct_fan_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) Fan-Beam CT Reconstruction\n   `deconv_ppp_bm3d_admm.py <deconv_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_bm3d_apgm.py <deconv_ppp_bm3d_apgm.py>`_\n      PPP (with BM3D) Image Deconvolution (APGM Solver)\n   `deconv_ppp_dncnn_admm.py <deconv_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_dncnn_padmm.py <deconv_ppp_dncnn_padmm.py>`_\n      PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)\n   `deconv_ppp_bm4d_admm.py <deconv_ppp_bm4d_admm.py>`_\n      PPP (with BM4D) Volume Deconvolution\n   `demosaic_ppp_bm3d_admm.py <demosaic_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Demosaicing\n   `superres_ppp_dncnn_admm.py <superres_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Superresolution\n\n\nTotal Variation\n^^^^^^^^^^^^^^^\n\n   `ct_abel_tv_admm.py <ct_abel_tv_admm.py>`_\n      TV-Regularized Abel Inversion\n   `ct_abel_tv_admm_tune.py <ct_abel_tv_admm_tune.py>`_\n      Parameter Tuning for TV-Regularized 
Abel Inversion\n   `ct_symcone_tv_padmm.py <ct_symcone_tv_padmm.py>`_\n      TV-Regularized Cone Beam CT for Symmetric Objects\n   `ct_tv_admm.py <ct_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Integrated Projector)\n   `ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)\n   `ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector)\n   `ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_\n      3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver)\n   `ct_astra_3d_tv_padmm.py <ct_astra_3d_tv_padmm.py>`_\n      3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)\n   `ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_\n      TV-Regularized Low-Dose CT Reconstruction\n   `ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_\n      TV-Regularized CT Reconstruction (Multiple Algorithms)\n   `deconv_circ_tv_admm.py <deconv_circ_tv_admm.py>`_\n      Circulant Blur Image Deconvolution with TV Regularization\n   `deconv_tv_admm.py <deconv_tv_admm.py>`_\n      Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_\n      Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_tv_padmm.py <deconv_tv_padmm.py>`_\n      Image Deconvolution with TV Regularization (Proximal ADMM Solver)\n   `deconv_microscopy_tv_admm.py <deconv_microscopy_tv_admm.py>`_\n      Deconvolution Microscopy (Single Channel)\n   `deconv_microscopy_allchn_tv_admm.py <deconv_microscopy_allchn_tv_admm.py>`_\n      Deconvolution Microscopy (All Channels)\n   `denoise_l1tv_admm.py <denoise_l1tv_admm.py>`_\n      ℓ1 Total Variation Denoising\n   `denoise_ptv_pdhg.py <denoise_ptv_pdhg.py>`_\n      Polar Total Variation Denoising (PDHG)\n   `denoise_tv_admm.py <denoise_tv_admm.py>`_\n      Total Variation Denoising (ADMM)\n   `denoise_tv_apgm.py 
<denoise_tv_apgm.py>`_\n      Total Variation Denoising with Constraint (APGM)\n   `denoise_tv_multi.py <denoise_tv_multi.py>`_\n      Comparison of Optimization Algorithms for Total Variation Denoising\n   `denoise_approx_tv_multi.py <denoise_approx_tv_multi.py>`_\n      Denoising with Approximate Total Variation Proximal Operator\n   `denoise_cplx_tv_nlpadmm.py <denoise_cplx_tv_nlpadmm.py>`_\n      Complex Total Variation Denoising with NLPADMM Solver\n   `denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_\n      Complex Total Variation Denoising with PDHG Solver\n   `diffusercam_tv_admm.py <diffusercam_tv_admm.py>`_\n      TV-Regularized 3D DiffuserCam Reconstruction\n\n\n\nSparsity\n^^^^^^^^\n\n   `diffusercam_tv_admm.py <diffusercam_tv_admm.py>`_\n      TV-Regularized 3D DiffuserCam Reconstruction\n   `sparsecode_nn_admm.py <sparsecode_nn_admm.py>`_\n      Non-Negative Basis Pursuit DeNoising (ADMM)\n   `sparsecode_nn_apgm.py <sparsecode_nn_apgm.py>`_\n      Non-Negative Basis Pursuit DeNoising (APGM)\n   `sparsecode_conv_admm.py <sparsecode_conv_admm.py>`_\n      Convolutional Sparse Coding (ADMM)\n   `sparsecode_conv_md_admm.py <sparsecode_conv_md_admm.py>`_\n      Convolutional Sparse Coding with Mask Decoupling (ADMM)\n   `sparsecode_apgm.py <sparsecode_apgm.py>`_\n      Basis Pursuit DeNoising (APGM)\n   `sparsecode_poisson_apgm.py <sparsecode_poisson_apgm.py>`_\n      Non-negative Poisson Loss Reconstruction (APGM)\n   `video_rpca_admm.py <video_rpca_admm.py>`_\n      Video Decomposition via Robust PCA\n\n\nMachine Learning\n^^^^^^^^^^^^^^^^\n\n   `ct_datagen_foam2.py <ct_datagen_foam2.py>`_\n      CT Data Generation for NN Training\n   `ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_\n      CT Training and Reconstruction with MoDL\n   `ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_\n      CT Training and Reconstruction with ODP\n   `ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_\n      CT Training and Reconstructions with UNet\n   
`deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_\n      Blurred Data Generation (Natural Images) for NN Training\n   `deconv_datagen_foam1.py <deconv_datagen_foam1.py>`_\n      Blurred Data Generation (Foams) for NN Training\n   `deconv_modl_train_foam1.py <deconv_modl_train_foam1.py>`_\n      Deconvolution Training and Reconstructions with MoDL\n   `deconv_odp_train_foam1.py <deconv_odp_train_foam1.py>`_\n      Deconvolution Training and Reconstructions with ODP\n   `denoise_datagen_bsds.py <denoise_datagen_bsds.py>`_\n      Noisy Data Generation for NN Training\n   `denoise_dncnn_train_bsds.py <denoise_dncnn_train_bsds.py>`_\n      Training of DnCNN for Denoising\n   `denoise_dncnn_universal.py <denoise_dncnn_universal.py>`_\n      Comparison of DnCNN Variants for Image Denoising\n\n\nOrganized by Optimization Algorithm\n-----------------------------------\n\nADMM\n^^^^\n\n   `ct_abel_tv_admm.py <ct_abel_tv_admm.py>`_\n      TV-Regularized Abel Inversion\n   `ct_abel_tv_admm_tune.py <ct_abel_tv_admm_tune.py>`_\n      Parameter Tuning for TV-Regularized Abel Inversion\n   `ct_symcone_tv_padmm.py <ct_symcone_tv_padmm.py>`_\n      TV-Regularized Cone Beam CT for Symmetric Objects\n   `ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector)\n   `ct_tv_admm.py <ct_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Integrated Projector)\n   `ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_\n      3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver)\n   `ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_\n      TV-Regularized Low-Dose CT Reconstruction\n   `ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_\n      TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)\n   `ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_\n      TV-Regularized CT Reconstruction (Multiple Algorithms)\n   `ct_svmbir_ppp_bm3d_admm_cg.py <ct_svmbir_ppp_bm3d_admm_cg.py>`_\n      PPP (with 
BM3D) CT Reconstruction (ADMM with CG Subproblem Solver)\n   `ct_svmbir_ppp_bm3d_admm_prox.py <ct_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)\n   `ct_fan_svmbir_ppp_bm3d_admm_prox.py <ct_fan_svmbir_ppp_bm3d_admm_prox.py>`_\n      PPP (with BM3D) Fan-Beam CT Reconstruction\n   `deconv_circ_tv_admm.py <deconv_circ_tv_admm.py>`_\n      Circulant Blur Image Deconvolution with TV Regularization\n   `deconv_tv_admm.py <deconv_tv_admm.py>`_\n      Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_\n      Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver)\n   `deconv_microscopy_tv_admm.py <deconv_microscopy_tv_admm.py>`_\n      Deconvolution Microscopy (Single Channel)\n   `deconv_microscopy_allchn_tv_admm.py <deconv_microscopy_allchn_tv_admm.py>`_\n      Deconvolution Microscopy (All Channels)\n   `deconv_ppp_bm3d_admm.py <deconv_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_dncnn_admm.py <deconv_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Deconvolution (ADMM Solver)\n   `deconv_ppp_bm4d_admm.py <deconv_ppp_bm4d_admm.py>`_\n      PPP (with BM4D) Volume Deconvolution\n   `diffusercam_tv_admm.py <diffusercam_tv_admm.py>`_\n      TV-Regularized 3D DiffuserCam Reconstruction\n   `sparsecode_nn_admm.py <sparsecode_nn_admm.py>`_\n      Non-Negative Basis Pursuit DeNoising (ADMM)\n   `sparsecode_conv_admm.py <sparsecode_conv_admm.py>`_\n      Convolutional Sparse Coding (ADMM)\n   `sparsecode_conv_md_admm.py <sparsecode_conv_md_admm.py>`_\n      Convolutional Sparse Coding with Mask Decoupling (ADMM)\n   `demosaic_ppp_bm3d_admm.py <demosaic_ppp_bm3d_admm.py>`_\n      PPP (with BM3D) Image Demosaicing\n   `superres_ppp_dncnn_admm.py <superres_ppp_dncnn_admm.py>`_\n      PPP (with DnCNN) Image Superresolution\n   `denoise_l1tv_admm.py <denoise_l1tv_admm.py>`_\n      ℓ1 Total Variation 
Denoising\n   `denoise_tv_admm.py <denoise_tv_admm.py>`_\n      Total Variation Denoising (ADMM)\n   `denoise_tv_multi.py <denoise_tv_multi.py>`_\n      Comparison of Optimization Algorithms for Total Variation Denoising\n   `denoise_approx_tv_multi.py <denoise_approx_tv_multi.py>`_\n      Denoising with Approximate Total Variation Proximal Operator\n   `video_rpca_admm.py <video_rpca_admm.py>`_\n      Video Decomposition via Robust PCA\n\n\nLinearized ADMM\n^^^^^^^^^^^^^^^\n\n    `ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_\n       TV-Regularized CT Reconstruction (Multiple Algorithms)\n    `denoise_tv_multi.py <denoise_tv_multi.py>`_\n       Comparison of Optimization Algorithms for Total Variation Denoising\n\n\nProximal ADMM\n^^^^^^^^^^^^^\n\n    `ct_astra_3d_tv_padmm.py <ct_astra_3d_tv_padmm.py>`_\n       3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)\n    `deconv_tv_padmm.py <deconv_tv_padmm.py>`_\n       Image Deconvolution with TV Regularization (Proximal ADMM Solver)\n    `denoise_tv_multi.py <denoise_tv_multi.py>`_\n       Comparison of Optimization Algorithms for Total Variation Denoising\n    `deconv_ppp_dncnn_padmm.py <deconv_ppp_dncnn_padmm.py>`_\n       PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)\n\n\nNon-linear Proximal ADMM\n^^^^^^^^^^^^^^^^^^^^^^^^\n\n    `denoise_cplx_tv_nlpadmm.py <denoise_cplx_tv_nlpadmm.py>`_\n       Complex Total Variation Denoising with NLPADMM Solver\n\n\nPDHG\n^^^^\n\n    `ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_\n       TV-Regularized CT Reconstruction (Multiple Algorithms)\n    `denoise_ptv_pdhg.py <denoise_ptv_pdhg.py>`_\n       Polar Total Variation Denoising (PDHG)\n    `denoise_tv_multi.py <denoise_tv_multi.py>`_\n       Comparison of Optimization Algorithms for Total Variation Denoising\n    `denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_\n       Complex Total Variation Denoising with PDHG Solver\n\n\nPGM\n^^^\n\n   `deconv_ppp_bm3d_apgm.py 
<deconv_ppp_bm3d_apgm.py>`_\n      PPP (with BM3D) Image Deconvolution (APGM Solver)\n   `sparsecode_apgm.py <sparsecode_apgm.py>`_\n      Basis Pursuit DeNoising (APGM)\n   `sparsecode_nn_apgm.py <sparsecode_nn_apgm.py>`_\n      Non-Negative Basis Pursuit DeNoising (APGM)\n   `sparsecode_poisson_apgm.py <sparsecode_poisson_apgm.py>`_\n      Non-negative Poisson Loss Reconstruction (APGM)\n   `denoise_tv_apgm.py <denoise_tv_apgm.py>`_\n      Total Variation Denoising with Constraint (APGM)\n   `denoise_approx_tv_multi.py <denoise_approx_tv_multi.py>`_\n      Denoising with Approximate Total Variation Proximal Operator\n\n\nPCG\n^^^\n\n   `ct_astra_noreg_pcg.py <ct_astra_noreg_pcg.py>`_\n      CT Reconstruction with CG and PCG\n"
  },
  {
    "path": "examples/scripts/ct_abel_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Abel Inversion\n=============================\n\nThis example demonstrates a total variation (TV) regularized Abel\ninversion by solving the problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_1 \\;,$$\n\nwhere $A$ is the Abel projector (with an implementation based on a\nprojector from PyAbel :cite:`pyabel-2022`), $\\mathbf{y}$ is the measured\ndata, $C$ is a 2D finite difference operator, and $\\mathbf{x}$ is the\nsolution.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import create_circular_phantom\nfrom scico.linop.xray.abel import AbelTransform\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nx_gt = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n\n\n\"\"\"\nSet up the forward operator and create a test measurement.\n\"\"\"\nA = AbelTransform(x_gt.shape)\ny = A @ x_gt\nnp.random.seed(12345)\ny = y + np.random.normal(size=y.shape).astype(np.float32)\n\n\n\"\"\"\nCompute inverse Abel transform solution.\n\"\"\"\nx_inv = A.inverse(y)\n\n\n\"\"\"\nSet up the problem to be solved. 
Anisotropic TV, which gives slightly\nbetter performance than isotropic TV for this problem, is used here.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=A)\nλ = 2.35e1  # ℓ1 norm regularization parameter\ng = λ * functional.L1Norm()  # Note the use of anisotropic TV\nC = linop.FiniteDifference(input_shape=x_gt.shape)\n\n\n\"\"\"\nSet up ADMM solver object.\n\"\"\"\nρ = 1.03e2  # ADMM penalty parameter\nmaxiter = 100  # number of ADMM iterations\ncg_tol = 1e-4  # CG relative tolerance\ncg_maxiter = 25  # maximum CG iterations per ADMM iteration\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=snp.clip(x_inv, 0.0, 1.0),\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nsolver.solve()\nx_tv = snp.clip(solver.x, 0.0, 1.0)\n\n\n\"\"\"\nShow results.\n\"\"\"\nnorm = plot.matplotlib.colors.Normalize(vmin=-0.1, vmax=1.2)\nfig, ax = plot.subplots(nrows=2, ncols=2, figsize=(12, 12))\nplot.imview(x_gt, title=\"Ground Truth\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0], norm=norm)\nplot.imview(y, title=\"Measurement\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])\nplot.imview(\n    x_inv,\n    title=\"Inverse Abel: %.2f (dB)\" % metric.psnr(x_gt, x_inv),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1, 0],\n    norm=norm,\n)\nplot.imview(\n    x_tv,\n    title=\"TV-Regularized Inversion: %.2f (dB)\" % metric.psnr(x_gt, x_tv),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1, 1],\n    norm=norm,\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_abel_tv_admm_tune.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nParameter Tuning for TV-Regularized Abel Inversion\n==================================================\n\nThis example demonstrates the use of\n[scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune\nparameters for the companion [example script](ct_abel_tv_admm.rst). The\n`ray.tune` class API is used in this example.\n\nThis script is hard-coded to run on CPU only to avoid the large number of\nwarnings that are emitted when GPU resources are requested but not\navailable, and due to the difficulty of suppressing these warnings in a\nway that does not force use of the CPU only. To enable GPU usage, comment\nout the `os.environ` statements near the beginning of the script, and\nchange the value of the \"gpu\" entry in the `resources` dict from 0 to 1.\nNote that two environment variables are set to suppress the warnings\nbecause `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but\nthis change has yet to be correctly implemented\n(see [google/jax#6805](https://github.com/google/jax/issues/6805) and\n[google/jax#10272](https://github.com/google/jax/pull/10272)).\n\"\"\"\n\n# isort: off\nimport os\n\nos.environ[\"JAX_PLATFORM_NAME\"] = \"cpu\"\nos.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import create_circular_phantom\nfrom scico.linop.xray.abel import AbelTransform\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.ray import tune\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nx_gt = create_circular_phantom((N, N), [0.4 * 
N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n\n\n\"\"\"\nSet up the forward operator and create a test measurement.\n\"\"\"\nA = AbelTransform(x_gt.shape)\ny = A @ x_gt\nnp.random.seed(12345)\ny = y + np.random.normal(size=y.shape).astype(np.float32)\n\n\n\"\"\"\nCompute inverse Abel transform solution for use as initial solution.\n\"\"\"\nx_inv = A.inverse(y)\nx0 = snp.clip(x_inv, 0.0, 1.0)\n\n\n\"\"\"\nDefine performance evaluation class.\n\"\"\"\n\n\nclass Trainable(tune.Trainable):\n    \"\"\"Parameter evaluation class.\"\"\"\n\n    def setup(self, config, x_gt, x0, y):\n        \"\"\"This method initializes a new parameter evaluation object. It\n        is called once when a new parameter evaluation object is created.\n        The `config` parameter is a dict of specific parameters for\n        evaluation of a single parameter set (a pair of parameters in\n        this case). The remaining parameters are objects that are passed\n        to the evaluation function via the ray object store.\n        \"\"\"\n        # Get arrays passed by tune call.\n        self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)\n        # Set up problem to be solved.\n        self.A = AbelTransform(self.x_gt.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.C = linop.FiniteDifference(input_shape=self.x_gt.shape)\n        self.reset_config(config)\n\n    def reset_config(self, config):\n        \"\"\"This method is only required when `scico.ray.tune.Tuner` is\n        initialized with `reuse_actors` set to ``True`` (the default). 
In\n        this case, a set of parameter evaluation processes and\n        corresponding objects are created once (including initialization\n        via a call to the `setup` method), and this method is called when\n        switching to evaluation of a different parameter configuration.\n        If `reuse_actors` is set to ``False``, then a new process and\n        object are created for each parameter configuration, and this\n        method is not used.\n        \"\"\"\n        # Extract solver parameters from config dict.\n        λ, ρ = config[\"lambda\"], config[\"rho\"]\n        # Set up parameter-dependent functional.\n        g = λ * functional.L1Norm()\n        # Define solver.\n        cg_tol = 1e-4\n        cg_maxiter = 25\n        self.solver = ADMM(\n            f=self.f,\n            g_list=[g],\n            C_list=[self.C],\n            rho_list=[ρ],\n            x0=self.x0,\n            maxiter=10,\n            subproblem_solver=LinearSubproblemSolver(\n                cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}\n            ),\n        )\n        return True\n\n    def step(self):\n        \"\"\"This method is called for each step in the evaluation of a\n        single parameter configuration. 
The maximum number of times it\n        can be called is controlled by the `num_iterations` parameter\n        in the initialization of a `scico.ray.tune.Tuner` object.\n        \"\"\"\n        # Perform 10 solver steps for every ray.tune step\n        x_tv = snp.clip(self.solver.solve(), 0.0, 1.0)\n        return {\"psnr\": float(metric.psnr(self.x_gt, x_tv))}\n\n\n\"\"\"\nDefine parameter search space and resources per trial.\n\"\"\"\nconfig = {\"lambda\": tune.loguniform(1e0, 1e2), \"rho\": tune.loguniform(1e1, 1e3)}\nresources = {\"gpu\": 0, \"cpu\": 1}  # gpus per trial, cpus per trial\n\n\n\"\"\"\nRun parameter search.\n\"\"\"\ntuner = tune.Tuner(\n    tune.with_parameters(Trainable, x_gt=x_gt, x0=x0, y=y),\n    param_space=config,\n    resources=resources,\n    metric=\"psnr\",\n    mode=\"max\",\n    num_samples=100,  # perform 100 parameter evaluations\n    num_iterations=10,  # perform at most 10 steps for each parameter evaluation\n)\nresults = tuner.fit()\nray.shutdown()\n\n\n\"\"\"\nDisplay best parameters and corresponding performance.\n\"\"\"\nbest_result = results.get_best_result()\nbest_config = best_result.config\nprint(f\"Best PSNR: {best_result.metrics['psnr']:.2f} dB\")\nprint(\"Best config: \" + \", \".join([f\"{k}: {v:.2e}\" for k, v in best_config.items()]))\n\n\n\"\"\"\nPlot parameter values visited during parameter search. Marker sizes are\nproportional to number of iterations run at each parameter pair. 
The best\npoint in the parameter space is indicated in red.\n\"\"\"\nfig = plot.figure(figsize=(8, 8))\ntrials = results.get_dataframe()\nfor t in trials.iloc:\n    n = t[\"training_iteration\"]\n    plot.plot(\n        t[\"config/lambda\"],\n        t[\"config/rho\"],\n        ptyp=\"loglog\",\n        lw=0,\n        ms=(0.5 + 1.5 * n),\n        marker=\"o\",\n        mfc=\"blue\",\n        mec=\"blue\",\n        fig=fig,\n    )\nplot.plot(\n    best_config[\"lambda\"],\n    best_config[\"rho\"],\n    ptyp=\"loglog\",\n    title=\"Parameter search sampling locations\\n(marker size proportional to number of iterations)\",\n    xlbl=r\"$\\rho$\",\n    ylbl=r\"$\\lambda$\",\n    lw=0,\n    ms=5.0,\n    marker=\"o\",\n    mfc=\"red\",\n    mec=\"red\",\n    fig=fig,\n)\nax = fig.axes[0]\nax.set_xlim([config[\"rho\"].lower, config[\"rho\"].upper])\nax.set_ylim([config[\"lambda\"].lower, config[\"lambda\"].upper])\nfig.show()\n\n\n\"\"\"\nPlot parameter values visited during parameter search and corresponding\nreconstruction PSNRs. The best point in the parameter space is indicated\nin red.\n\"\"\"\n𝜌 = [t[\"config/rho\"] for t in trials.iloc]\n𝜆 = [t[\"config/lambda\"] for t in trials.iloc]\npsnr = [t[\"psnr\"] for t in trials.iloc]\nminpsnr = min(max(psnr), 20.0)\n𝜌, 𝜆, psnr = zip(*filter(lambda x: x[2] >= minpsnr, zip(𝜌, 𝜆, psnr)))\nfig, ax = plot.subplots(figsize=(10, 8))\nsc = ax.scatter(𝜌, 𝜆, c=psnr, cmap=plot.cm.plasma_r)\nfig.colorbar(sc)\nplot.plot(\n    best_config[\"lambda\"],\n    best_config[\"rho\"],\n    ptyp=\"loglog\",\n    lw=0,\n    ms=12.0,\n    marker=\"2\",\n    mfc=\"red\",\n    mec=\"red\",\n    fig=fig,\n    ax=ax,\n)\nax.set_xscale(\"log\")\nax.set_yscale(\"log\")\nax.set_xlabel(r\"$\\rho$\")\nax.set_ylabel(r\"$\\lambda$\")\nax.set_title(\"PSNR at each sample location\\n(values below 20 dB omitted)\")\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_astra_3d_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\n3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver)\n=============================================================\n\nThis example demonstrates solution of a sparse-view, 3D CT\nreconstruction problem with isotropic total variation (TV)\nregularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator,\nand $\\mathbf{x}$ is the reconstructed image.\n\nIn this example the problem is solved via ADMM, while proximal\nADMM is used in a [companion example](ct_astra_3d_tv_padmm.rst).\n\"\"\"\n\nimport numpy as np\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import create_tangle_phantom\nfrom scico.linop.xray.astra import XRayTransform3D\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image and projector.\n\"\"\"\nNx = 128\nNy = 256\nNz = 64\n\ntangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))\n\nn_projection = 10  # number of projections\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\nC = XRayTransform3D(\n    tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles\n)  # CT projection operator\ny = C @ tangle  # sinogram\n\n\n\"\"\"\nSet up problem and solver.\n\"\"\"\nλ = 2e0  # ℓ2,1 norm regularization parameter\nρ = 5e0  # ADMM penalty parameter\nmaxiter = 25  # number of ADMM iterations\ncg_tol = 1e-4  # CG relative 
tolerance\ncg_maxiter = 25  # maximum CG iterations per ADMM iteration\n\n# The append=0 option makes the results of horizontal and vertical\n# finite differences the same shape, which is required for the L21Norm,\n# which is used so that g(Dx) corresponds to isotropic TV.\nD = linop.FiniteDifference(input_shape=tangle.shape, append=0)\ng = λ * functional.L21Norm()\nf = loss.SquaredL2Loss(y=y, A=C)\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[D],\n    rho_list=[ρ],\n    x0=C.T(y),\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\ntangle_recon = solver.solve()\n\nprint(\n    \"TV Restruction\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))\n)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6))\nplot.imview(\n    tangle[32],\n    title=\"Ground truth (central slice)\",\n    cmap=plot.cm.Blues,\n    cbar=None,\n    fig=fig,\n    ax=ax[0],\n)\nplot.imview(\n    tangle_recon[32],\n    title=\"TV Reconstruction (central slice)\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1],\n)\ndivider = make_axes_locatable(ax[1])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[1].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_astra_3d_tv_padmm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\n3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)\n======================================================================\n\nThis example demonstrates solution of a sparse-view, 3D CT\nreconstruction problem with isotropic total variation (TV)\nregularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator,\nand $\\mathbf{x}$ is the reconstructed image.\n\nIn this example the problem is solved via proximal ADMM, while standard\nADMM is used in a [companion example](ct_astra_3d_tv_admm.rst).\n\"\"\"\n\nimport numpy as np\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import create_tangle_phantom\nfrom scico.linop.xray.astra import XRayTransform3D, angle_to_vector\nfrom scico.optimize import ProximalADMM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image and projector.\n\"\"\"\nNx = 128\nNy = 256\nNz = 64\n\ntangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))\n\nn_projection = 10  # number of projections\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\ndet_spacing = [1.0, 1.0]\ndet_count = [Nz, max(Nx, Ny)]\nvectors = angle_to_vector(det_spacing, angles)\n\n# It would have been more straightforward to use the det_spacing and angles keywords\n# in this case (since vectors is just computed directly from these two quantities), but\n# the more general form is used here as a demonstration.\nC = 
XRayTransform3D(tangle.shape, det_count=det_count, vectors=vectors)  # CT projection operator\ny = C @ tangle  # sinogram\n\n\nr\"\"\"\nSet up problem and solver. We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is the X-ray transform and $D$ is a finite difference\noperator. This problem can be expressed as\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; (1/2) \\| \\mathbf{y} -\n  \\mathbf{z}_0 \\|_2^2 + \\lambda \\| \\mathbf{z}_1 \\|_{2,1} \\;\\;\n  \\text{such that} \\;\\; \\mathbf{z}_0 = C \\mathbf{x} \\;\\; \\text{and} \\;\\;\n  \\mathbf{z}_1 = D \\mathbf{x} \\;,$$\n\nwhich can be written in the form of a standard ADMM problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; f(\\mathbf{x}) + g(\\mathbf{z})\n  \\;\\; \\text{such that} \\;\\; A \\mathbf{x} + B \\mathbf{z} = \\mathbf{c}$$\n\nwith\n\n  $$f = 0 \\qquad g = g_0 + g_1$$\n  $$g_0(\\mathbf{z}_0) = (1/2) \\| \\mathbf{y} - \\mathbf{z}_0 \\|_2^2 \\qquad\n  g_1(\\mathbf{z}_1) = \\lambda \\| \\mathbf{z}_1 \\|_{2,1}$$\n  $$A = \\left( \\begin{array}{c} C \\\\ D \\end{array} \\right) \\qquad\n  B = \\left( \\begin{array}{cc} -I & 0 \\\\ 0 & -I \\end{array} \\right) \\qquad\n  \\mathbf{c} = \\left( \\begin{array}{c} 0 \\\\ 0 \\end{array} \\right) \\;.$$\n\nThis is a more complex splitting than that used in the\n[companion example](ct_astra_3d_tv_admm.rst), but it allows the use of a\nproximal ADMM solver in a way that avoids the need for the conjugate\ngradient sub-iterations used by the ADMM solver in the\n[companion example](ct_astra_3d_tv_admm.rst).\n\"\"\"\n𝛼 = 1e2  # improve problem conditioning by balancing C and D components of A\nλ = 2e0  # ℓ2,1 norm regularization parameter\nρ = 5e-3  # ADMM penalty parameter\nmaxiter = 1000  # number of ADMM iterations\n\nf = functional.ZeroFunctional()\ng0 = loss.SquaredL2Loss(y=y)\ng1 = (λ / 𝛼) * 
functional.L21Norm()\ng = functional.SeparableFunctional((g0, g1))\nD = linop.FiniteDifference(input_shape=tangle.shape, append=0)\n\nA = linop.VerticalStack((C, 𝛼 * D))\nmu, nu = ProximalADMM.estimate_parameters(A)\n\nsolver = ProximalADMM(\n    f=f,\n    g=g,\n    A=A,\n    B=None,\n    rho=ρ,\n    mu=mu,\n    nu=nu,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 50},\n)\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\ntangle_recon = solver.solve()\n\nprint(\n    \"TV Restruction\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))\n)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6))\nplot.imview(\n    tangle[32],\n    title=\"Ground truth (central slice)\",\n    cmap=plot.cm.Blues,\n    cbar=None,\n    fig=fig,\n    ax=ax[0],\n)\nplot.imview(\n    tangle_recon[32],\n    title=\"TV Reconstruction (central slice)\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1],\n)\ndivider = make_axes_locatable(ax[1])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[1].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_astra_noreg_pcg.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nCT Reconstruction with CG and PCG\n=================================\n\nThis example demonstrates a simple iterative CT reconstruction using\nconjugate gradient (CG) and preconditioned conjugate gradient (PCG)\nalgorithms to solve the problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 \\;,$$\n\nwhere $A$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, and $\\mathbf{x}$ is the reconstructed image.\n\"\"\"\n\nfrom time import time\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\nfrom xdesign import Foam, discrete_phantom\n\nfrom scico import loss, plot\nfrom scico.linop import CircularConvolve\nfrom scico.linop.xray.astra import XRayTransform2D\nfrom scico.solver import cg\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # phantom size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = jnp.array(x_gt)  # convert to jax type\n\n\n\"\"\"\nConfigure a CT projection operator and generate synthetic measurements.\n\"\"\"\nn_projection = N  # matches the phantom size so this is not few-view CT\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\nA = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles)  # CT projection operator\ny = A @ x_gt  # sinogram\n\n\nr\"\"\"\nForward and back project a single pixel (Kronecker delta) to compute\nan approximate impulse response for $\\mathbf{A}^T \\mathbf{A}$.\n\"\"\"\nH = CircularConvolve.from_operator(A.T @ A)\n\n\nr\"\"\"\nInvert in the Fourier domain to form a preconditioner $\\mathbf{M}\n\\approx (\\mathbf{A}^T \\mathbf{A})^{-1}$ (see\n:cite:`clinthorne-1993-preconditioning` Section V.A. 
for more details).\n\"\"\"\n# γ limits the gain of the preconditioner; higher gives a weaker filter.\nγ = 1e-2\n\n# The imaginary part comes from numerical errors in A.T and needs to be\n# removed to ensure H is symmetric, positive definite.\nfrequency_response = np.real(H.h_dft)\ninv_frequency_response = 1 / (frequency_response + γ)\n# Using circular convolution without padding is sufficient here because\n# M is approximate anyway.\nM = CircularConvolve(inv_frequency_response, x_gt.shape, h_is_dft=True)\n\n\nr\"\"\"\nCheck that $\\mathbf{M}$ does approximately invert $\\mathbf{A}^T \\mathbf{A}$.\n\"\"\"\nplot_args = dict(\n    norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5), cmap=plot.matplotlib.cm.Blues_r\n)\n\nfig, axes = plot.subplots(nrows=1, ncols=3, figsize=(12, 4.5))\nplot.imview(x_gt, title=\"Ground truth, $x_{gt}$\", fig=fig, ax=axes[0], **plot_args)\nplot.imview(\n    A.T @ A @ x_gt, title=r\"$\\mathbf{A}^T \\mathbf{A} x_{gt}$\", fig=fig, ax=axes[1], **plot_args\n)\nplot.imview(\n    M @ A.T @ A @ x_gt,\n    title=r\"$\\mathbf{M} \\mathbf{A}^T \\mathbf{A} x_{gt}$\",\n    fig=fig,\n    ax=axes[2],\n    **plot_args,\n)\nfig.suptitle(r\"$\\mathbf{M}$ approximately inverts $\\mathbf{A}^T \\mathbf{A}$\")\nfig.tight_layout()\nfig.colorbar(\n    axes[2].get_images()[0],\n    ax=axes,\n    location=\"right\",\n    shrink=0.82,\n    pad=0.02,\n    label=\"Arbitrary Units\",\n)\nfig.show()\n\n\n\"\"\"\nReconstruct with both standard and preconditioned conjugate gradient.\n\"\"\"\nstart_time = time()\nx_cg, info_cg = cg(\n    A.T @ A,\n    A.T @ y,\n    jnp.zeros(A.input_shape, dtype=A.input_dtype),\n    tol=1e-5,\n    info=True,\n)\ntime_cg = time() - start_time\n\nstart_time = time()\nx_pcg, info_pcg = cg(\n    A.T @ A,\n    A.T @ y,\n    jnp.zeros(A.input_shape, dtype=A.input_dtype),\n    tol=2e-5,  # preconditioning affects the problem scaling so tol differs between CG and PCG\n    info=True,\n    M=M,\n)\ntime_pcg = time() - 
start_time\n\n\n\"\"\"\nCompare CG and PCG in terms of reconstruction time and data fidelity.\n\"\"\"\nf_cg = loss.SquaredL2Loss(y=A.T @ y, A=A.T @ A)\nf_data = loss.SquaredL2Loss(y=y, A=A)\nprint(\n    f\"{'Method':8s}{'Iterations':>11s}{'Time (s)':>12s}{'||ATAx - ATy||':>17s}{'||Ax - y||':>15s}\"\n)\nprint(\n    f\"{'CG':8s}{info_cg['num_iter']:>11d}{time_cg:>12.2f}{f_cg(x_cg):>17.2e}{f_data(x_cg):>15.2e}\"\n)\nprint(\n    f\"{'PCG':8s}{info_pcg['num_iter']:>11d}{time_pcg:>12.2f}{f_cg(x_pcg):>17.2e}\"\n    f\"{f_data(x_pcg):>15.2e}\"\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_astra_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Sparse-View CT Reconstruction (ASTRA Projector)\n==============================================================\n\nThis example demonstrates solution of a sparse-view CT reconstruction\nproblem with isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and\n$\\mathbf{x}$ is the reconstructed image. This example uses the CT\nprojector provided by the astra package, while the companion\n[example script](ct_tv_admm.rst) uses the projector integrated into\nscico.\n\"\"\"\n\nimport numpy as np\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.linop.xray.astra import XRayTransform2D\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 512  # phantom size\nnp.random.seed(1234)\nx_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))\n\n\n\"\"\"\nConfigure CT projection operator and generate synthetic measurements.\n\"\"\"\nn_projection = 45  # number of projections\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\ndet_count = int(N * 1.05 / np.sqrt(2.0))\ndet_spacing = np.sqrt(2)\nA = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles)  # CT projection operator\ny = A @ x_gt  # sinogram\n\n\n\"\"\"\nSet up problem 
functional and ADMM solver object.\n\"\"\"\nλ = 2e0  # ℓ1 norm regularization parameter\nρ = 5e0  # ADMM penalty parameter\nmaxiter = 25  # number of ADMM iterations\ncg_tol = 1e-4  # CG relative tolerance\ncg_maxiter = 25  # maximum CG iterations per ADMM iteration\n\n# The append=0 option makes the results of horizontal and vertical\n# finite differences the same shape, which is required for the L21Norm,\n# which is used so that g(Cx) corresponds to isotropic TV.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\ng = λ * functional.L21Norm()\nf = loss.SquaredL2Loss(y=y, A=A)\nx0 = snp.clip(A.fbp(y), 0, 1.0)\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=x0,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nsolver.solve()\nhist = solver.itstat_object.history(transpose=True)\nx_reconstruction = snp.clip(solver.x, 0, 1.0)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    x0,\n    title=\"FBP Reconstruction: \\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    x_reconstruction,\n    title=\"TV Reconstruction\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    
hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_astra_weighted_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Low-Dose CT Reconstruction\n=========================================\n\nThis example demonstrates solution of a low-dose CT reconstruction\nproblem with isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_W^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is the X-ray transform (the CT forward projection),\n$\\mathbf{y}$ is the sinogram, the norm weighting $W$ is chosen so that\nthe weighted norm is an approximation to the Poisson negative log\nlikelihood :cite:`sauer-1993-local`, $C$ is a 2D finite difference\noperator, and $\\mathbf{x}$ is the reconstructed image.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Soil, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.linop.xray.astra import XRayTransform2D\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 512  # phantom size\nnp.random.seed(0)\nx_gt = discrete_phantom(Soil(porosity=0.80), size=384)\nx_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))\nx_gt = np.clip(x_gt, 0, np.inf)  # clip to positive values\nx_gt = snp.array(x_gt)  # convert to jax type\n\n\n\"\"\"\nConfigure CT projection operator and generate synthetic measurements.\n\"\"\"\nn_projection = 360  # number of projections\nIo = 1e3  # source flux\n𝛼 = 1e-2  # attenuation coefficient\nangles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\nA = XRayTransform2D(x_gt.shape, N, 1.0, angles)  # CT projection operator\ny_c = A @ x_gt  # sinogram\n\n\nr\"\"\"\nAdd Poisson noise to projections according 
to\n\n$$\\mathrm{counts} \\sim \\mathrm{Poi}\\left(I_0 \\exp (- \\alpha A\n\\mathbf{x} ) \\right)$$\n\n$$\\mathbf{y} = - \\frac{1}{\\alpha} \\log\\left(\\mathrm{counts} /\nI_0\\right) \\;.$$\n\nWe use the NumPy random functionality so we can generate using 64-bit\nnumbers.\n\"\"\"\ncounts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt))\ncounts = np.clip(counts, a_min=1, a_max=np.inf)  # replace any 0s count with 1\ny = -1 / 𝛼 * np.log(counts / Io)\ny = snp.array(y)  # convert back to float32 as a jax array\n\n\n\"\"\"\nSet up post processing. For this example, we clip all reconstructions\nto the range of the ground truth.\n\"\"\"\n\n\ndef postprocess(x):\n    return snp.clip(x, 0, snp.max(x_gt))\n\n\n\"\"\"\nCompute an FBP reconstruction as an initial guess.\n\"\"\"\nx0 = postprocess(A.fbp(y))\n\n\nr\"\"\"\nSet up and solve the un-weighted reconstruction problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;.$$\n\"\"\"\n# Note that rho and lambda were selected via a parameter sweep (not\n# shown here).\nρ = 2.5e3  # ADMM penalty parameter\nlambda_unweighted = 3e2  # regularization strength\nmaxiter = 100  # number of ADMM iterations\ncg_tol = 1e-5  # CG relative tolerance\ncg_maxiter = 10  # maximum CG iterations per ADMM iteration\nf = loss.SquaredL2Loss(y=y, A=A)\nadmm_unweighted = ADMM(\n    f=f,\n    g_list=[lambda_unweighted * functional.L21Norm()],\n    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],\n    rho_list=[ρ],\n    x0=x0,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(f\"Solving on {device_info()}\\n\")\nadmm_unweighted.solve()\nx_unweighted = postprocess(admm_unweighted.x)\n\n\nr\"\"\"\nSet up and solve the weighted reconstruction problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A 
\\mathbf{x}\n  \\|_W^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere\n\n  $$W = \\mathrm{diag}( \\mathrm{counts} / I_0 ) \\;.$$\n\nThe data fidelity term in this formulation follows\n:cite:`sauer-1993-local` (9) except for the scaling by $I_0$, which we\nuse to maintain balance between the data and regularization terms if\n$I_0$ changes.\n\"\"\"\nlambda_weighted = 5e1\nweights = snp.array(counts / Io)\nf = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))\nadmm_weighted = ADMM(\n    f=f,\n    g_list=[lambda_weighted * functional.L21Norm()],\n    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],\n    rho_list=[ρ],\n    maxiter=maxiter,\n    x0=x0,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint()\nadmm_weighted.solve()\nx_weighted = postprocess(admm_weighted.x)\n\n\n\"\"\"\nShow recovered images.\n\"\"\"\n\n\ndef plot_recon(x, title, ax):\n    \"\"\"Plot an image with title indicating error metrics.\"\"\"\n    plot.imview(\n        x,\n        title=f\"{title}\\nSNR: {metric.snr(x_gt, x):.2f} (dB), MAE: {metric.mae(x_gt, x):.3f}\",\n        fig=fig,\n        ax=ax,\n    )\n\n\nfig, ax = plot.subplots(nrows=2, ncols=2, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0])\nplot_recon(x0, \"FBP Reconstruction\", ax=ax[0, 1])\nplot_recon(x_unweighted, \"Unweighted TV Reconstruction\", ax=ax[1, 0])\nplot_recon(x_weighted, \"Weighted TV Reconstruction\", ax=ax[1, 1])\nfor ax_ in ax.ravel():\n    ax_.set_xlim(64, 448)\n    ax_.set_ylim(64, 448)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"arbitrary units\"\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_datagen_foam2.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nCT Data Generation for NN Training\n==================================\n\nThis example demonstrates how to generate synthetic CT data for training\nneural network models. If desired, a basic reconstruction can be\ngenerated using filtered back projection (FBP).\n\"\"\"\n\n# isort: off\nimport os\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nfrom scico import plot\nfrom scico.flax.examples import load_ct_data\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nN = 256  # phantom size\ntrain_nimg = 536  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\nn_projection = 45  # CT views\n\ntrdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)\n\n\n\"\"\"\nPlot randomly selected sample.\n\"\"\"\nindx_tr = np.random.randint(0, train_nimg)\nindx_te = np.random.randint(0, test_nimg)\nfig, axes = plot.subplots(nrows=2, ncols=3, figsize=(9, 9))\nplot.imview(\n    trdt[\"img\"][indx_tr, ..., 0], title=\"Ground truth - Training Sample\", fig=fig, ax=axes[0, 0]\n)\nplot.imview(\n    trdt[\"sino\"][indx_tr, ..., 0], title=\"Sinogram - Training Sample\", fig=fig, ax=axes[0, 1]\n)\nplot.imview(\n    trdt[\"fbp\"][indx_tr, ..., 0],\n    title=\"FBP - Training Sample\",\n    fig=fig,\n    ax=axes[0, 2],\n)\nplot.imview(\n    ttdt[\"img\"][indx_te, ..., 0],\n    title=\"Ground truth - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 0],\n)\nplot.imview(\n    ttdt[\"sino\"][indx_te, ..., 0], 
title=\"Sinogram - Testing Sample\", fig=fig, ax=axes[1, 1]\n)\nplot.imview(\n    ttdt[\"fbp\"][indx_te, ..., 0],\n    title=\"FBP - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 2],\n)\nfig.suptitle(r\"Training and Testing samples\")\nfig.tight_layout()\nfig.colorbar(\n    axes[0, 2].get_images()[0],\n    ax=axes,\n    shrink=0.5,\n    pad=0.05,\n    label=\"Arbitrary Units\",\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) Fan-Beam CT Reconstruction\n==========================================\n\nThis example demonstrates solution of a fan-beam tomographic reconstruction\nproblem using the Plug-and-Play Priors framework\n:cite:`venkatakrishnan-2013-plugandplay2`, using BM3D\n:cite:`dabov-2008-image` as a denoiser and SVMBIR\n:cite:`svmbir-2020` for tomographic projection.\n\nThis example uses the data fidelity term as one of the ADMM $g$\nfunctionals so that the optimization with respect to the data fidelity is\nable to exploit the internal prox of the `SVMBIRExtendedLoss` functional.\n\nWe solve the problem in two different ways:\n1. Approximating the fan-beam geometry using parallel-beam and using the\n   parallel beam projector to compute the reconstruction.\n2. 
Using the correct fan-beam geometry to perform a reconstruction.\n\"\"\"\n\nimport numpy as np\n\nimport matplotlib.pyplot as plt\nimport svmbir\nfrom matplotlib.ticker import MaxNLocator\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import metric, plot\nfrom scico.functional import BM3D\nfrom scico.linop import Diagonal, Identity\nfrom scico.linop.xray.svmbir import SVMBIRExtendedLoss, XRayTransform\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nGenerate a ground truth image.\n\"\"\"\nN = 256  # image size\ndensity = 0.025  # attenuation density of the image\nnp.random.seed(1234)\npad_len = 5\nx_gt = discrete_phantom(\n    Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 2 * pad_len\n)\nx_gt = x_gt / np.max(x_gt) * density\nx_gt = np.pad(x_gt, pad_len)\nx_gt[x_gt < 0] = 0\n\n\n\"\"\"\nGenerate tomographic projector and sinogram for fan beam and parallel beam.\nFor fan beam, use view angles spanning 2π since unlike parallel beam, views\nat 0 and π are not equivalent.\n\"\"\"\nnum_angles = int(N / 2)\nnum_channels = N\n\n# Use angles in the range [0, 2*pi] for fan beam\nangles = snp.linspace(0, 2 * snp.pi, num_angles, endpoint=False, dtype=snp.float32)\n\ndist_source_detector = 1500.0\nmagnification = 1.2\nA_fan = XRayTransform(\n    x_gt.shape,\n    angles,\n    num_channels,\n    geometry=\"fan-curved\",\n    dist_source_detector=dist_source_detector,\n    magnification=magnification,\n)\nA_parallel = XRayTransform(\n    x_gt.shape,\n    angles,\n    num_channels,\n    geometry=\"parallel\",\n)\n\nsino_fan = A_fan @ x_gt\n\n\n\"\"\"\nImpose Poisson noise on sinograms. 
Higher max_intensity means less noise.\n\"\"\"\n\n\ndef add_poisson_noise(sino, max_intensity):\n    expected_counts = max_intensity * np.exp(-sino)\n    noisy_counts = np.random.poisson(expected_counts).astype(np.float32)\n    noisy_counts[noisy_counts == 0] = 1  # deal with 0s\n    y = -np.log(noisy_counts / max_intensity)\n\n    return y\n\n\ny_fan = add_poisson_noise(sino_fan, max_intensity=500)\n\n\n\"\"\"\nReconstruct using default prior of SVMBIR :cite:`svmbir-2020`.\n\"\"\"\nweights_fan = svmbir.calc_weights(y_fan, weight_type=\"transmission\")\n\nx_mrf_fan = svmbir.recon(\n    np.array(y_fan[:, np.newaxis]),\n    np.array(angles),\n    weights=weights_fan[:, np.newaxis],\n    num_rows=N,\n    num_cols=N,\n    positivity=True,\n    verbose=0,\n    stop_threshold=0.0,\n    geometry=\"fan-curved\",\n    dist_source_detector=dist_source_detector,\n    magnification=magnification,\n    delta_channel=1.0,\n    delta_pixel=1.0 / magnification,\n)[0]\n\nx_mrf_parallel = svmbir.recon(\n    np.array(y_fan[:, np.newaxis]),\n    np.array(angles),\n    weights=weights_fan[:, np.newaxis],\n    num_rows=N,\n    num_cols=N,\n    positivity=True,\n    verbose=0,\n    stop_threshold=0.0,\n    geometry=\"parallel\",\n)[0]\n\n\n\"\"\"\nConvert numpy arrays to jax arrays.\n\"\"\"\ny_fan = snp.array(y_fan)\nx0_fan = snp.array(x_mrf_fan)\nweights_fan = snp.array(weights_fan)\nx0_parallel = snp.array(x_mrf_parallel)\n\n\n\"\"\"\nSet problem parameters and BM3D pseudo-functional.\n\"\"\"\nρ = 10  # ADMM penalty parameter\nσ = density * 0.6  # denoiser sigma\ng0 = σ * ρ * BM3D()\n\n\n\"\"\"\nSet up problem using `SVMBIRExtendedLoss`.\n\"\"\"\nf_extloss_fan = SVMBIRExtendedLoss(\n    y=y_fan,\n    A=A_fan,\n    W=Diagonal(weights_fan),\n    scale=0.5,\n    positivity=True,\n    prox_kwargs={\"maxiter\": 5, \"ctol\": 0.0},\n)\nf_extloss_parallel = SVMBIRExtendedLoss(\n    y=y_fan,\n    A=A_parallel,\n    W=Diagonal(weights_fan),\n    scale=0.5,\n    positivity=True,\n    
prox_kwargs={\"maxiter\": 5, \"ctol\": 0.0},\n)\n\nsolver_extloss_fan = ADMM(\n    f=None,\n    g_list=[f_extloss_fan, g0],\n    C_list=[Identity(x_mrf_fan.shape), Identity(x_mrf_fan.shape)],\n    rho_list=[ρ, ρ],\n    x0=x0_fan,\n    maxiter=20,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\nsolver_extloss_parallel = ADMM(\n    f=None,\n    g_list=[f_extloss_parallel, g0],\n    C_list=[Identity(x_mrf_parallel.shape), Identity(x_mrf_parallel.shape)],\n    rho_list=[ρ, ρ],\n    x0=x0_parallel,\n    maxiter=20,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the ADMM solvers.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx_extloss_fan = solver_extloss_fan.solve()\nhist_extloss_fan = solver_extloss_fan.itstat_object.history(transpose=True)\n\nprint()\nx_extloss_parallel = solver_extloss_parallel.solve()\nhist_extloss_parallel = solver_extloss_parallel.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered images. 
The parallel beam reconstruction is poor because\nthe parallel beam is a poor approximation of the specific fan beam geometry\nused here.\n\"\"\"\nnorm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)\n\nfig, ax = plt.subplots(1, 3, figsize=(20, 7))\nplot.imview(img=x_gt, title=\"Ground Truth Image\", cbar=True, fig=fig, ax=ax[0], norm=norm)\nplot.imview(\n    img=x_mrf_parallel,\n    title=f\"Parallel-beam MRF (PSNR: {metric.psnr(x_gt, x_mrf_parallel):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1],\n    norm=norm,\n)\nplot.imview(\n    img=x_extloss_parallel,\n    title=f\"Parallel-beam Extended Loss (PSNR: {metric.psnr(x_gt, x_extloss_parallel):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[2],\n    norm=norm,\n)\nfig.show()\n\n\nfig, ax = plt.subplots(1, 3, figsize=(20, 7))\nplot.imview(img=x_gt, title=\"Ground Truth Image\", cbar=True, fig=fig, ax=ax[0], norm=norm)\nplot.imview(\n    img=x_mrf_fan,\n    title=f\"Fan-beam MRF (PSNR: {metric.psnr(x_gt, x_mrf_fan):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1],\n    norm=norm,\n)\nplot.imview(\n    img=x_extloss_fan,\n    title=f\"Fan-beam Extended Loss (PSNR: {metric.psnr(x_gt, x_extloss_fan):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[2],\n    norm=norm,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plt.subplots(1, 2, figsize=(15, 6))\nplot.plot(\n    snp.array((hist_extloss_parallel.Prml_Rsdl, hist_extloss_parallel.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals for parallel-beam reconstruction\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[0],\n)\nax[0].set_ylim([1e-1, 1e1])\nax[0].xaxis.set_major_locator(MaxNLocator(integer=True))\nplot.plot(\n    snp.array((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals for fan-beam reconstruction\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n  
  ax=ax[1],\n)\nax[1].set_ylim([1e-1, 1e1])\nax[1].xaxis.set_major_locator(MaxNLocator(integer=True))\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_modl_train_foam2.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nCT Training and Reconstruction with MoDL\n========================================\n\nThis example demonstrates the training and application of a\nmodel-based deep learning (MoDL) architecture described in\n:cite:`aggarwal-2019-modl` applied to a CT reconstruction problem.\n\nThe source images are foam phantoms generated with xdesign.\n\nA class\n[scico.flax.MoDLNet](../_autosummary/scico.flax.rst#scico.flax.MoDLNet)\nimplements the MoDL architecture, which solves the optimization\nproblem\n\n$$\\mathrm{argmin}_{\\mathbf{x}} \\; \\| A \\mathbf{x} - \\mathbf{y} \\|_2^2\n+ \\lambda \\, \\| \\mathbf{x} - \\mathrm{D}_w(\\mathbf{x})\\|_2^2 \\;,$$\n\nwhere $A$ is a tomographic projector, $\\mathbf{y}$ is a set of sinograms,\n$\\mathrm{D}_w$ is the regularization (a denoiser), and $\\mathbf{x}$ is\nthe set of reconstructed images. The MoDL abstracts the iterative\nsolution by an unrolled network where each iteration corresponds to a\ndifferent stage in the MoDL network and updates the prediction by solving\n\n$$\\mathbf{x}^{k+1} = (A^T A + \\lambda \\, I)^{-1} (A^T \\mathbf{y} +\n\\lambda \\, \\mathbf{z}^k) \\;,$$\n\nvia conjugate gradient. In the expression, $k$ is the index of the stage\n(iteration), $\\mathbf{z}^k = \\mathrm{ResNet}(\\mathbf{x}^{k})$ is the\nregularization (a denoiser implemented as a residual convolutional neural\nnetwork), $\\mathbf{x}^k$ is the output of the previous stage,\n$\\lambda > 0$ is a learned regularization parameter, and $I$ is the\nidentity operator. 
The output of the final stage is the set of\nreconstructed images.\n\"\"\"\n\n# isort: off\nimport os\nfrom functools import partial\nfrom time import time\n\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_ct_data\nfrom scico.flax.train.traversals import clip_positive, construct_traversal\nfrom scico.linop.xray import XRayTransform2D\n\n\"\"\"\nPrepare parallel processing. Set an arbitrary processor count (only\napplies if GPU is not available).\n\"\"\"\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nN = 256  # phantom size\ntrain_nimg = 536  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\nn_projection = 45  # CT views\n\ntrdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)\n\n\n\"\"\"\nBuild CT projection operator. Parameters are chosen so that the operator\nis equivalent to the one used to generate the training data.\n\"\"\"\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\nA = XRayTransform2D(\n    input_shape=(N, N),\n    angles=angles,\n    det_count=int(N * 1.05 / np.sqrt(2.0)),\n    dx=1.0 / np.sqrt(2),\n)\nA = (1.0 / N) * A  # normalize projection operator\n\n\n\"\"\"\nBuild training and testing structures. Inputs are the sinograms and\noutputs are the original generated foams. 
Keep training and testing\npartitions.\n\"\"\"\nnumtr = 100\nnumtt = 16\ntrain_ds = {\"image\": trdt[\"sino\"][:numtr], \"label\": trdt[\"img\"][:numtr]}\ntest_ds = {\"image\": ttdt[\"sino\"][:numtt], \"label\": ttdt[\"img\"][:numtt]}\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and\nrelatively short training. The model depth is akin to the number of\nunrolled iterations in the MoDL model. The block depth controls the\nnumber of layers at each unrolled iteration. The number of filters is\nuniform throughout the iterations. The iterations used for the\nconjugate gradient (CG) solver can also be specified. Better\nperformance may be obtained by increasing depth, block depth, number\nof filters, CG iterations, or training epochs, but may require longer\ntraining times.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 10,\n    \"num_filters\": 64,\n    \"block_depth\": 4,\n    \"cg_iter_1\": 3,\n    \"cg_iter_2\": 8,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 12345,\n    \"opt_type\": \"SGD\",\n    \"momentum\": 0.9,\n    \"batch_size\": 16,\n    \"num_epochs\": 20,\n    \"base_learning_rate\": 1e-2,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 40,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct functionality for ensuring that the learned\nregularization parameter is always positive.\n\"\"\"\nlmbdatrav = construct_traversal(\"lmbda\")  # select lmbda parameters in model\nlmbdapos = partial(\n    clip_positive,  # apply this function\n    traversal=lmbdatrav,  # to lmbda parameters in model\n    minval=5e-4,\n)\n\n\n\"\"\"\nPrint configuration of distributed run.\n\"\"\"\nprint(f\"\\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}\")\nprint(f\"JAX local devices: {jax.local_devices()}\\n\")\n\n\n\"\"\"\nCheck for iterated trained model. 
If not found, construct MoDLNet\nmodel, using only one iteration (depth) in model and few CG iterations\nfor faster initialization. Run first stage (initialization) training\nloop followed by a second stage (depth iterations) training loop.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nworkdir2 = os.path.join(\n    os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"modl_ct_out\", \"iterated\"\n)\n\nstats_object_ini = None\nstats_object = None\n\ncheckpoint_files = []\nfor dirpath, dirnames, filenames in os.walk(workdir2):\n    checkpoint_files = [fn for fn in filenames]\n\nif len(checkpoint_files) > 0:\n    model = sflax.MoDLNet(\n        operator=A,\n        depth=model_conf[\"depth\"],\n        channels=channels,\n        num_filters=model_conf[\"num_filters\"],\n        block_depth=model_conf[\"block_depth\"],\n        cg_iter=model_conf[\"cg_iter_2\"],\n    )\n\n    train_conf[\"post_lst\"] = [lmbdapos]\n    # Parameters for 2nd stage\n    train_conf[\"workdir\"] = workdir2\n    train_conf[\"opt_type\"] = \"ADAM\"\n    train_conf[\"num_epochs\"] = 150\n    # Construct training object\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n    )\n    start_time = time()\n    modvar, stats_object = trainer.train()\n    time_train = time() - start_time\n    time_init = 0.0\n    epochs_init = 0\nelse:\n    # One iteration (depth) in model and few CG iterations\n    model = sflax.MoDLNet(\n        operator=A,\n        depth=1,\n        channels=channels,\n        num_filters=model_conf[\"num_filters\"],\n        block_depth=model_conf[\"block_depth\"],\n        cg_iter=model_conf[\"cg_iter_1\"],\n    )\n    # First stage: initialization training loop.\n    workdir1 = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"modl_ct_out\")\n    train_conf[\"workdir\"] = workdir1\n    train_conf[\"post_lst\"] = [lmbdapos]\n    # Construct training object\n    trainer = 
sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n    )\n\n    start_time = time()\n    modvar, stats_object_ini = trainer.train()\n    time_init = time() - start_time\n    epochs_init = train_conf[\"num_epochs\"]\n\n    print(\n        f\"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}\"\n        f\"{'time[s]:':21s}{time_init:>7.2f}\"\n    )\n\n    # Second stage: depth iterations training loop.\n    model.depth = model_conf[\"depth\"]\n    model.cg_iter = model_conf[\"cg_iter_2\"]\n    train_conf[\"opt_type\"] = \"ADAM\"\n    train_conf[\"num_epochs\"] = 150\n    train_conf[\"workdir\"] = workdir2\n    # Construct training object, include current model parameters\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n        variables0=modvar,\n    )\n    start_time = time()\n    modvar, stats_object = trainer.train()\n    time_train = time() - start_time\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ndel train_ds[\"image\"]\ndel train_ds[\"label\"]\n\nfmap = sflax.FlaxMap(model, modvar)\ndel model, modvar\n\nmaxn = numtt\nstart_time = time()\noutput = fmap(test_ds[\"image\"][:maxn])\ntime_eval = time() - start_time\noutput = np.clip(output, a_min=0, a_max=1.0)\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time\nand data fidelity.\n\"\"\"\ntotal_epochs = epochs_init + train_conf[\"num_epochs\"]\ntotal_time_train = time_init + time_train\nsnr_eval = metric.snr(test_ds[\"label\"][:maxn], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:maxn], output)\nprint(\n    f\"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}\"\n    f\"{'time[s]:':10s}{total_time_train:>7.2f}\"\n)\nprint(\n    f\"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}\"\n    f\"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\"\"\"\nPlot 
comparison.\n\"\"\"\nnp.random.seed(123)\nindx = np.random.randint(0, high=maxn)\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"Sinogram\",\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"MoDLNet Reconstruction\\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n# Stats for initialization loop\nif stats_object_ini is not None and len(stats_object_ini.iterations) > 0:\n    hist = stats_object_ini.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 
5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        ptyp=\"semilogy\",\n        title=\"Loss function - Initialization\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        title=\"Metric - Initialization\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_multi_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)\n==================================================================\n\nThis example demonstrates solution of a sparse-view CT reconstruction\nproblem with isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and\n$\\mathbf{x}$ is the reconstructed image. The solution is computed and\ncompared for all three 2D CT projectors available in scico, using a\nsinogram computed with the astra projector.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.linop.xray import XRayTransform2D, astra, svmbir\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 512  # phantom size\nnp.random.seed(1234)\nx_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))\n\n\n\"\"\"\nDefine CT geometry and construct array of (approximately) equivalent projectors.\n\"\"\"\nn_projection = 45  # number of projections\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\ndet_count = int(N * 1.05 / np.sqrt(2.0))\ndet_spacing = np.sqrt(2)\nprojectors = {\n    \"astra\": astra.XRayTransform2D(\n        x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0\n    ),  # astra\n    \"svmbir\": svmbir.XRayTransform(\n        
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing\n    ),  # svmbir\n    \"scico\": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing),  # scico\n}\n\n\n\"\"\"\nCompute common sinogram using astra projector.\n\"\"\"\nA = projectors[\"astra\"]\nnoise = np.random.normal(size=(n_projection, det_count)).astype(np.float32)\ny = A @ x_gt + 2.0 * noise\n\n\n\"\"\"\nConstruct initial solution for regularized problem.\n\"\"\"\nx0 = A.fbp(y)\n\n\n\"\"\"\nSolve the same problem using the different projectors.\n\"\"\"\nprint(f\"Solving on {device_info()}\")\nx_rec, hist = {}, {}\nfor p in projectors.keys():\n    print(f\"\\nSolving with {p} projector\")\n\n    # Set up ADMM solver object.\n    λ = 2e1  # TV (L21 norm) regularization parameter\n    ρ = 1e3  # ADMM penalty parameter\n    maxiter = 100  # number of ADMM iterations\n    cg_tol = 1e-4  # CG relative tolerance\n    cg_maxiter = 50  # maximum CG iterations per ADMM iteration\n\n    # The append=0 option makes the results of horizontal and vertical\n    # finite differences the same shape, which is required for the L21Norm,\n    # which is used so that g(Cx) corresponds to isotropic TV.\n    C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n    g = λ * functional.L21Norm()\n    A = projectors[p]\n    f = loss.SquaredL2Loss(y=y, A=A)\n\n    # Set up the solver.\n    solver = ADMM(\n        f=f,\n        g_list=[g],\n        C_list=[C],\n        rho_list=[ρ],\n        x0=x0,\n        maxiter=maxiter,\n        subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n        itstat_options={\"display\": True, \"period\": 5},\n    )\n\n    # Run the solver.\n    solver.solve()\n    hist[p] = solver.itstat_object.history(transpose=True)\n    x_rec[p] = solver.x\n\n    if p == \"scico\":\n        x_rec[p] = x_rec[p] * det_spacing  # to match ASTRA's scaling\n\n\n\"\"\"\nCompare reconstruction 
results.\n\"\"\"\nprint(\"Reconstruction SNR:\")\nfor p in projectors.keys():\n    print(f\"  {(p + ':'):7s}  {metric.snr(x_gt, x_rec[p]):5.2f} dB\")\n\n\n\"\"\"\nDisplay sinogram.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3))\nplot.imview(y, title=\"sinogram\", fig=fig, ax=ax)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5))\nplot.plot(\n    np.array([hist[p].Objective for p in projectors.keys()]).T,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    lgnd=projectors.keys(),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    np.array([hist[p].Prml_Rsdl for p in projectors.keys()]).T,\n    ptyp=\"semilogy\",\n    title=\"Primal Residual\",\n    xlbl=\"Iteration\",\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    np.array([hist[p].Dual_Rsdl for p in projectors.keys()]).T,\n    ptyp=\"semilogy\",\n    title=\"Dual Residual\",\n    xlbl=\"Iteration\",\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\n\"\"\"\nShow the recovered images.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nfor n, p in enumerate(projectors.keys()):\n    plot.imview(\n        x_rec[p],\n        title=\"%s  SNR: %.2f (dB)\" % (p, metric.snr(x_gt, x_rec[p])),\n        fig=fig,\n        ax=ax[n + 1],\n    )\nfor ax in ax:\n    ax.get_images()[0].set_clim(-0.1, 1.1)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_odp_train_foam2.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nCT Training and Reconstruction with ODP\n=======================================\n\nThis example demonstrates the training of the unrolled optimization with\ndeep priors (ODP) gradient descent architecture described in\n:cite:`diamond-2018-odp` applied to a CT reconstruction problem.\n\nThe source images are foam phantoms generated with xdesign.\n\nA class\n[scico.flax.ODPNet](../_autosummary/scico.flax.rst#scico.flax.ODPNet)\nimplements the ODP architecture, which solves the optimization problem\n\n$$\\mathrm{argmin}_{\\mathbf{x}} \\; \\| A \\mathbf{x} - \\mathbf{y} \\|_2^2\n+ r(\\mathbf{x}) \\;,$$\n\nwhere $A$ is a tomographic projector, $\\mathbf{y}$ is a set of sinograms,\n$r$ is a regularizer and $\\mathbf{x}$ is the set of reconstructed images.\nThe ODP, gradient descent architecture, abstracts the iterative solution\nby an unrolled network where each iteration corresponds to a different\nstage in the ODP network and updates the prediction by solving\n\n$$\\mathbf{x}^{k+1} = \\mathrm{argmin}_{\\mathbf{x}} \\; \\alpha_k \\| A\n\\mathbf{x} - \\mathbf{y} \\|_2^2 + \\frac{1}{2} \\| \\mathbf{x} -\n\\mathbf{x}^k - \\mathbf{x}^{k+1/2} \\|_2^2 \\;,$$\n\nwhich for the CT problem, using gradient descent, corresponds to\n\n$$\\mathbf{x}^{k+1} = \\mathbf{x}^k + \\mathbf{x}^{k+1/2} - \\alpha_k \\,\nA^T \\, (A \\mathbf{x}^k - \\mathbf{y}) \\;,$$\n\nwhere $k$ is the index of the stage (iteration), $\\mathbf{x}^k +\n\\mathbf{x}^{k+1/2} = \\mathrm{ResNet}(\\mathbf{x}^{k})$ is the\nregularization (implemented as a residual convolutional neural network),\n$\\mathbf{x}^k$ is the output of the previous stage and $\\alpha_k > 0$ is\na learned stage-wise parameter weighting the contribution of the fidelity\nterm. 
The output of the final stage is the set of reconstructed images.\n\"\"\"\n\n# isort: off\nimport os\nfrom functools import partial\nfrom time import time\n\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_ct_data\nfrom scico.flax.train.traversals import clip_positive, construct_traversal\nfrom scico.linop.xray import XRayTransform2D\n\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nN = 256  # phantom size\ntrain_nimg = 536  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\nn_projection = 45  # CT views\n\ntrdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)\n\n\n\"\"\"\nBuild CT projection operator. Parameters are chosen so that the operator\nis equivalent to the one used to generate the training data.\n\"\"\"\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\nA = XRayTransform2D(\n    input_shape=(N, N),\n    angles=angles,\n    det_count=int(N * 1.05 / np.sqrt(2.0)),\n    dx=1.0 / np.sqrt(2),\n)\nA = (1.0 / N) * A  # normalize projection operator\n\n\n\"\"\"\nBuild training and testing structures. Inputs are the sinograms and\noutputs are the original generated foams. 
Keep subsets of the training\nand testing partitions.\n\"\"\"\nnumtr = 320\nnumtt = 32\ntrain_ds = {\"image\": trdt[\"sino\"][:numtr], \"label\": trdt[\"img\"][:numtr]}\ntest_ds = {\"image\": ttdt[\"sino\"][:numtt], \"label\": ttdt[\"img\"][:numtt]}\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and relatively\nshort training. The model depth is akin to the number of unrolled\niterations in the ODP model. The block depth controls the number of\nlayers at each unrolled iteration. The number of filters is uniform\nthroughout the iterations. Better performance may be obtained by\nincreasing depth, block depth, number of filters, or training epochs,\nbut may require longer training times.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 8,\n    \"num_filters\": 64,\n    \"block_depth\": 6,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 1234,\n    \"opt_type\": \"ADAM\",\n    \"batch_size\": 16,\n    \"num_epochs\": 200,\n    \"base_learning_rate\": 1e-3,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 160,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct functionality for ensuring that the learned fidelity weight\nparameter is always positive.\n\"\"\"\nalphatrav = construct_traversal(\"alpha\")  # select alpha parameters in model\nalphapost = partial(\n    clip_positive,  # apply this function\n    traversal=alphatrav,  # to alpha parameters in model\n    minval=1e-3,\n)\n\n\n\"\"\"\nPrint configuration of distributed run.\n\"\"\"\nprint(f\"\\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}\")\nprint(f\"JAX local devices: {jax.local_devices()}\\n\")\n\n\n\"\"\"\nConstruct ODPNet model.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nmodel = sflax.ODPNet(\n    operator=A,\n    depth=model_conf[\"depth\"],\n   
 channels=channels,\n    num_filters=model_conf[\"num_filters\"],\n    block_depth=model_conf[\"block_depth\"],\n    odp_block=sflax.inverse.ODPGrDescBlock,\n    alpha_ini=1e-2,\n)\n\n\n\"\"\"\nRun training loop.\n\"\"\"\nworkdir = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"odp_ct_out\")\n\ntrain_conf[\"workdir\"] = workdir\ntrain_conf[\"post_lst\"] = [alphapost]\n# Construct training object\ntrainer = sflax.BasicFlaxTrainer(\n    train_conf,\n    model,\n    train_ds,\n    test_ds,\n)\nmodvar, stats_object = trainer.train()\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ndel train_ds[\"image\"]\ndel train_ds[\"label\"]\n\nfmap = sflax.FlaxMap(model, modvar)\ndel model, modvar\n\nmaxn = numtt\nstart_time = time()\noutput = fmap(test_ds[\"image\"][:maxn])\ntime_eval = time() - start_time\noutput = np.clip(output, a_min=0, a_max=1.0)\nepochs = train_conf[\"num_epochs\"]\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time and data fidelity.\n\"\"\"\nsnr_eval = metric.snr(test_ds[\"label\"][:maxn], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:maxn], output)\nprint(\n    f\"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}\"\n    f\"{'time[s]:':10s}{trainer.train_time:>7.2f}\"\n)\nprint(\n    f\"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}\"\n    f\"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\n\"\"\"\nPlot comparison.\n\"\"\"\nnp.random.seed(123)\nindx = np.random.randint(0, high=maxn)\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"Sinogram\",\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"ODPNet Reconstruction\\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], 
output[indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_projector_comparison_2d.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\nr\"\"\"\n2D X-ray Transform Comparison\n=============================\n\nThis example compares SCICO's native 2D X-ray transform algorithm\nto that of the ASTRA toolbox.\n\"\"\"\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.linop.xray.astra as astra\nfrom scico import plot\nfrom scico.linop.xray import XRayTransform2D\nfrom scico.util import Timer\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 512\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = jnp.array(x_gt)\n\n\n\"\"\"\nTime projector instantiation.\n\"\"\"\nnum_angles = 500\nangles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)\ndet_count = int(N * 1.02 / jnp.sqrt(2.0))\n\ntimer = Timer()\n\nprojectors = {}\ntimer.start(\"scico_init\")\nprojectors[\"scico\"] = XRayTransform2D((N, N), angles, det_count=det_count)\ntimer.stop(\"scico_init\")\n\ntimer.start(\"astra_init\")\nprojectors[\"astra\"] = astra.XRayTransform2D(\n    (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0\n)\ntimer.stop(\"astra_init\")\n\n\n\"\"\"\nTime first projector application, which might include JIT overhead.\n\"\"\"\nys = {}\nfor name, H in projectors.items():\n    timer_label = f\"{name}_first_fwd\"\n    timer.start(timer_label)\n    ys[name] = H @ x_gt\n    jax.block_until_ready(ys[name])\n    timer.stop(timer_label)\n\n\n\"\"\"\nCompute average time for a projector application.\n\"\"\"\nnum_repeats = 3\nfor name, H in projectors.items():\n    timer_label = f\"{name}_avg_fwd\"\n    timer.start(timer_label)\n    for _ in range(num_repeats):\n        ys[name] = H @ x_gt\n        jax.block_until_ready(ys[name])\n    timer.stop(timer_label)\n    
timer.td[timer_label] /= num_repeats\n\n\n\"\"\"\nTime first back projection, which might include JIT overhead.\n\"\"\"\ny = np.zeros(H.output_shape, dtype=np.float32)\ny[num_angles // 3, det_count // 2] = 1.0\ny = jnp.array(y)\n\nHTys = {}\nfor name, H in projectors.items():\n    timer_label = f\"{name}_first_back\"\n    timer.start(timer_label)\n    HTys[name] = H.T @ y\n    jax.block_until_ready(ys[name])\n    timer.stop(timer_label)\n\n\n\"\"\"\nCompute average time for back projection.\n\"\"\"\nnum_repeats = 3\nfor name, H in projectors.items():\n    timer_label = f\"{name}_avg_back\"\n    timer.start(timer_label)\n    for _ in range(num_repeats):\n        HTys[name] = H.T @ y\n        jax.block_until_ready(ys[name])\n    timer.stop(timer_label)\n    timer.td[timer_label] /= num_repeats\n\n\n\"\"\"\nDisplay timing results.\n\nOn our server, when using the GPU, the SCICO projector (both forward\nand backward) is faster than ASTRA. When using the CPU, it is slower\nfor forward projection and faster for back projection. 
The SCICO object\ninitialization and first back projection are slow due to JIT\noverhead.\n\nOn our server, using the GPU:\n```\ninit         astra    4.81e-02 s\ninit         scico    2.53e-01 s\n\nfirst  fwd   astra    4.44e-02 s\nfirst  fwd   scico    2.82e-02 s\n\nfirst  back  astra    3.31e-02 s\nfirst  back  scico    2.80e-01 s\n\navg    fwd   astra    4.76e-02 s\navg    fwd   scico    2.83e-02 s\n\navg    back  astra    3.96e-02 s\navg    back  scico    6.80e-04 s\n```\n\nUsing the CPU:\n```\ninit         astra    1.72e-02 s\ninit         scico    2.88e+00 s\n\nfirst  fwd   astra    1.02e+00 s\nfirst  fwd   scico    2.40e+00 s\n\nfirst  back  astra    1.03e+00 s\nfirst  back  scico    3.53e+00 s\n\navg    fwd   astra    1.03e+00 s\navg    fwd   scico    2.54e+00 s\n\navg    back  astra    1.01e+00 s\navg    back  scico    5.98e-01 s\n```\n\"\"\"\nprint(f\"init         astra    {timer.td['astra_init']:.2e} s\")\nprint(f\"init         scico    {timer.td['scico_init']:.2e} s\")\nprint(\"\")\nfor tstr in (\"first\", \"avg\"):\n    for dstr in (\"fwd\", \"back\"):\n        for pstr in (\"astra\", \"scico\"):\n            print(\n                f\"{tstr:5s}  {dstr:4s}  {pstr}    {timer.td[pstr + '_' + tstr + '_' + dstr]:.2e} s\"\n            )\n        print()\n\n\n\"\"\"\nShow projections.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))\nplot.imview(ys[\"scico\"], title=\"SCICO projection\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(ys[\"astra\"], title=\"ASTRA projection\", cbar=None, fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nShow back projections of a single detector element, i.e., a line.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))\nplot.imview(HTys[\"scico\"], title=\"SCICO back projection (zoom)\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(HTys[\"astra\"], title=\"ASTRA back projection (zoom)\", cbar=None, fig=fig, ax=ax[1])\nfor ax_i in ax:\n    ax_i.set_xlim(2 * N / 5, N - 2 * N / 5)\n    ax_i.set_ylim(2 * N 
/ 5, N - 2 * N / 5)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_projector_comparison_3d.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\nr\"\"\"\n3D X-ray Transform Comparison\n=============================\n\nThis example shows how to define a SCICO native 3D X-ray transform using\nASTRA toolbox conventions and vice versa.\n\"\"\"\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\n\nimport scico.linop.xray.astra as astra\nfrom scico import plot\nfrom scico.examples import create_block_phantom\nfrom scico.linop.xray import XRayTransform3D\nfrom scico.util import ContextTimer, Timer\n\n\"\"\"\nCreate a ground truth image and set detector dimensions.\n\"\"\"\nN = 64\n# use rectangular volume to check whether axes are handled correctly\nin_shape = (N + 1, N + 2, N + 3)\nx = create_block_phantom(in_shape)\nx = jnp.array(x)\n\n# use rectangular detector to check whether axes are handled correctly\nout_shape = (N, N + 1)\n\n\n\"\"\"\nSet up SCICO projection.\n\"\"\"\nnum_angles = 3\n\n\nrot_X = 90.0 - 16.0\nrot_Y = np.linspace(0, 180, num_angles, endpoint=False)\nangles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)\nmatrices = XRayTransform3D.matrices_from_euler_angles(\n    in_shape, out_shape, \"XY\", angles, degrees=True\n)\n\n\"\"\"\nSpecify geometry using SCICO conventions and project.\n\"\"\"\nnum_repeats = 3\n\ntimer_scico = Timer()\nwith ContextTimer(timer_scico, \"init\"):\n    H_scico = XRayTransform3D(in_shape, matrices, out_shape)\n\nwith ContextTimer(timer_scico, \"first_fwd\"):\n    y_scico = H_scico @ x\n    jax.block_until_ready(y_scico)\n\nwith ContextTimer(timer_scico, \"avg_fwd\"):\n    for _ in range(num_repeats):\n        y_scico = H_scico @ x\n        jax.block_until_ready(y_scico)\ntimer_scico.td[\"avg_fwd\"] /= num_repeats\n\nwith ContextTimer(timer_scico, \"first_back\"):\n    HTy_scico = H_scico.T @ y_scico\n\nwith 
ContextTimer(timer_scico, \"avg_back\"):\n    for _ in range(num_repeats):\n        HTy_scico = H_scico.T @ y_scico\n        jax.block_until_ready(HTy_scico)\ntimer_scico.td[\"avg_back\"] /= num_repeats\n\n\n\"\"\"\nConvert SCICO geometry to ASTRA and project.\n\"\"\"\n\nvectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape)\n\ntimer_astra = Timer()\nwith ContextTimer(timer_astra, \"init\"):\n    H_astra_from_scico = astra.XRayTransform3D(\n        input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico\n    )\n\nwith ContextTimer(timer_astra, \"first_fwd\"):\n    y_astra_from_scico = H_astra_from_scico @ x\n    jax.block_until_ready(y_astra_from_scico)\n\nwith ContextTimer(timer_astra, \"avg_fwd\"):\n    for _ in range(num_repeats):\n        y_astra_from_scico = H_astra_from_scico @ x\n        jax.block_until_ready(y_astra_from_scico)\ntimer_astra.td[\"avg_fwd\"] /= num_repeats\n\nwith ContextTimer(timer_astra, \"first_back\"):\n    HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico\n\nwith ContextTimer(timer_astra, \"avg_back\"):\n    for _ in range(num_repeats):\n        HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico\n        jax.block_until_ready(HTy_astra_from_scico)\ntimer_astra.td[\"avg_back\"] /= num_repeats\n\n\n\"\"\"\nSpecify geometry with ASTRA conventions and project.\n\"\"\"\n\nangles = np.random.rand(num_angles) * 180  # random projection angles\ndet_spacing = [1.0, 1.0]\nvectors = astra.angle_to_vector(det_spacing, angles)\n\nH_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors)\n\ny_astra = H_astra @ x\nHTy_astra = H_astra.T @ y_astra\n\n\n\"\"\"\nConvert ASTRA geometry to SCICO and project.\n\"\"\"\n\nP_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)\nH_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape)\n\ny_scico_from_astra = H_scico_from_astra @ x\nHTy_scico_from_astra = 
H_scico_from_astra.T @ y_scico_from_astra\n\n\n\"\"\"\nPrint timing results.\n\"\"\"\nprint(f\"init         astra    {timer_astra.td['init']:.2e} s\")\nprint(f\"init         scico    {timer_scico.td['init']:.2e} s\")\nprint(\"\")\nfor tstr in (\"first\", \"avg\"):\n    for dstr in (\"fwd\", \"back\"):\n        for timer, pstr in zip((timer_astra, timer_scico), (\"astra\", \"scico\")):\n            print(f\"{tstr:5s}  {dstr:4s}  {pstr}    {timer.td[tstr + '_' + dstr]:.2e} s\")\n        print()\n\n\n\"\"\"\nShow projections.\n\"\"\"\nfig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))\nplot.imview(y_scico[0], title=\"SCICO projections\", cbar=None, fig=fig, ax=ax[0, 0])\nplot.imview(y_scico[1], cbar=None, fig=fig, ax=ax[1, 0])\nplot.imview(y_scico[2], cbar=None, fig=fig, ax=ax[2, 0])\nplot.imview(y_astra_from_scico[:, 0], title=\"ASTRA projections\", cbar=None, fig=fig, ax=ax[0, 1])\nplot.imview(y_astra_from_scico[:, 1], cbar=None, fig=fig, ax=ax[1, 1])\nplot.imview(y_astra_from_scico[:, 2], cbar=None, fig=fig, ax=ax[2, 1])\nfig.suptitle(\"Using SCICO conventions\")\nfig.tight_layout()\nfig.show()\n\nfig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))\nplot.imview(y_scico_from_astra[0], title=\"SCICO projections\", cbar=None, fig=fig, ax=ax[0, 0])\nplot.imview(y_scico_from_astra[1], cbar=None, fig=fig, ax=ax[1, 0])\nplot.imview(y_scico_from_astra[2], cbar=None, fig=fig, ax=ax[2, 0])\nplot.imview(y_astra[:, 0], title=\"ASTRA projections\", cbar=None, fig=fig, ax=ax[0, 1])\nplot.imview(y_astra[:, 1], cbar=None, fig=fig, ax=ax[1, 1])\nplot.imview(y_astra[:, 2], cbar=None, fig=fig, ax=ax[2, 1])\nfig.suptitle(\"Using ASTRA conventions\")\nfig.tight_layout()\nfig.show()\n\n\n\"\"\"\nShow back projections.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))\nplot.imview(HTy_scico[N // 2], title=\"SCICO back projection\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    HTy_astra_from_scico[N // 2], title=\"ASTRA back projection\", cbar=None, 
fig=fig, ax=ax[1]\n)\nfig.suptitle(\"Using SCICO conventions\")\nfig.tight_layout()\nfig.show()\n\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))\nplot.imview(\n    HTy_scico_from_astra[N // 2], title=\"SCICO back projection\", cbar=None, fig=fig, ax=ax[0]\n)\nplot.imview(HTy_astra[N // 2], title=\"ASTRA back projection\", cbar=None, fig=fig, ax=ax[1])\nfig.suptitle(\"Using ASTRA conventions\")\nfig.tight_layout()\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver)\n==================================================================\n\nThis example demonstrates solution of a tomographic reconstruction problem\nusing the Plug-and-Play Priors framework\n:cite:`venkatakrishnan-2013-plugandplay2`, using BM3D\n:cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for\ntomographic projection.\n\nThere are two versions of this example, solving the same problem in two\ndifferent ways. This version uses the data fidelity term as the ADMM $f$,\nand thus the optimization with respect to the data fidelity uses CG rather\nthan the prox of the `SVMBIRSquaredL2Loss` functional, as in the\n[other version](ct_svmbir_ppp_bm3d_admm_prox.rst).\n\"\"\"\n\nimport numpy as np\n\nimport matplotlib.pyplot as plt\nimport svmbir\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import metric, plot\nfrom scico.functional import BM3D, NonNegativeIndicator\nfrom scico.linop import Diagonal, Identity\nfrom scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nGenerate a ground truth image.\n\"\"\"\nN = 256  # image size\ndensity = 0.025  # attenuation density of the image\nnp.random.seed(1234)\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10)\nx_gt = x_gt / np.max(x_gt) * density\nx_gt = np.pad(x_gt, 5)\nx_gt[x_gt < 0] = 0\n\n\n\"\"\"\nGenerate tomographic projector and sinogram.\n\"\"\"\nnum_angles = int(N / 2)\nnum_channels = N\nangles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)\nA = XRayTransform(x_gt.shape, angles, 
num_channels)\nsino = A @ x_gt\n\n\n\"\"\"\nImpose Poisson noise on sinogram. Higher max_intensity means less noise.\n\"\"\"\nmax_intensity = 2000\nexpected_counts = max_intensity * np.exp(-sino)\nnoisy_counts = np.random.poisson(expected_counts).astype(np.float32)\nnoisy_counts[noisy_counts == 0] = 1  # deal with 0s\ny = -np.log(noisy_counts / max_intensity)\n\n\n\"\"\"\nReconstruct using default prior of SVMBIR :cite:`svmbir-2020`.\n\"\"\"\nweights = svmbir.calc_weights(y, weight_type=\"transmission\")\n\nx_mrf = svmbir.recon(\n    np.array(y[:, np.newaxis]),\n    np.array(angles),\n    weights=weights[:, np.newaxis],\n    num_rows=N,\n    num_cols=N,\n    positivity=True,\n    verbose=0,\n)[0]\n\n\n\"\"\"\nSet up an ADMM solver.\n\"\"\"\ny = snp.array(y)\nx0 = snp.array(x_mrf)\nweights = snp.array(weights)\n\nρ = 15  # ADMM penalty parameter\nσ = density * 0.18  # denoiser sigma\n\nf = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)\ng0 = σ * ρ * BM3D()\ng1 = NonNegativeIndicator()\n\nsolver = ADMM(\n    f=f,\n    g_list=[g0, g1],\n    C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)],\n    rho_list=[ρ, ρ],\n    x0=x0,\n    maxiter=20,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 100}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx_bm3d = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nnorm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)\nfig, ax = plt.subplots(1, 3, figsize=[15, 5])\nplot.imview(img=x_gt, title=\"Ground Truth Image\", cbar=True, fig=fig, ax=ax[0], norm=norm)\nplot.imview(\n    img=x_mrf,\n    title=f\"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1],\n    norm=norm,\n)\nplot.imview(\n    img=x_bm3d,\n    title=f\"BM3D (PSNR: {metric.psnr(x_gt, x_bm3d):.2f} 
dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[2],\n    norm=norm,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)\n==============================================================\n\nThis example demonstrates solution of a tomographic reconstruction\nproblem using the Plug-and-Play Priors framework\n:cite:`venkatakrishnan-2013-plugandplay2`, using BM3D\n:cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for\ntomographic projection.\n\nThere are two versions of this example, solving the same problem in two\ndifferent ways. This version uses the data fidelity term as one of the\nADMM $g$ functionals so that the optimization with respect to the data\nfidelity is able to exploit the internal prox of the `SVMBIRExtendedLoss`\nand `SVMBIRSquaredL2Loss` functionals. The\n[other version](ct_svmbir_ppp_bm3d_admm_cg.rst) solves the ADMM subproblem\ncorresponding to the data fidelity term via CG.\n\nTwo ways of exploiting the SVMBIR internal prox are explored in this\nexample:\n1. Using the `SVMBIRSquaredL2Loss` together with the BM3D pseudo-functional\n   and a non-negative indicator function, and\n2. 
Using the `SVMBIRExtendedLoss`, which includes a non-negativity\n   constraint, together with the BM3D pseudo-functional.\n\"\"\"\n\nimport numpy as np\n\nimport matplotlib.pyplot as plt\nimport svmbir\nfrom matplotlib.ticker import MaxNLocator\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import metric, plot\nfrom scico.functional import BM3D, NonNegativeIndicator\nfrom scico.linop import Diagonal, Identity\nfrom scico.linop.xray.svmbir import (\n    SVMBIRExtendedLoss,\n    SVMBIRSquaredL2Loss,\n    XRayTransform,\n)\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nGenerate a ground truth image.\n\"\"\"\nN = 256  # image size\ndensity = 0.025  # attenuation density of the image\nnp.random.seed(1234)\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10)\nx_gt = x_gt / np.max(x_gt) * density\nx_gt = np.pad(x_gt, 5)\nx_gt[x_gt < 0] = 0\n\n\n\"\"\"\nGenerate tomographic projector and sinogram.\n\"\"\"\nnum_angles = int(N / 2)\nnum_channels = N\nangles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)\nA = XRayTransform(x_gt.shape, angles, num_channels)\nsino = A @ x_gt\n\n\n\"\"\"\nImpose Poisson noise on sinogram. 
Higher max_intensity means less noise.\n\"\"\"\nmax_intensity = 2000\nexpected_counts = max_intensity * np.exp(-sino)\nnoisy_counts = np.random.poisson(expected_counts).astype(np.float32)\nnoisy_counts[noisy_counts == 0] = 1  # deal with 0s\ny = -np.log(noisy_counts / max_intensity)\n\n\n\"\"\"\nReconstruct using default prior of SVMBIR :cite:`svmbir-2020`.\n\"\"\"\nweights = svmbir.calc_weights(y, weight_type=\"transmission\")\n\nx_mrf = svmbir.recon(\n    np.array(y[:, np.newaxis]),\n    np.array(angles),\n    weights=weights[:, np.newaxis],\n    num_rows=N,\n    num_cols=N,\n    positivity=True,\n    verbose=0,\n)[0]\n\n\n\"\"\"\nConvert numpy arrays to jax arrays.\n\"\"\"\ny = snp.array(y)\nx0 = snp.array(x_mrf)\nweights = snp.array(weights)\n\n\n\"\"\"\nSet problem parameters and BM3D pseudo-functional.\n\"\"\"\nρ = 10  # ADMM penalty parameter\nσ = density * 0.26  # denoiser sigma\ng0 = σ * ρ * BM3D()\n\n\n\"\"\"\nSet up problem using `SVMBIRSquaredL2Loss` and `NonNegativeIndicator`.\n\"\"\"\nf_l2loss = SVMBIRSquaredL2Loss(\n    y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={\"maxiter\": 5, \"ctol\": 0.0}\n)\ng1 = NonNegativeIndicator()\n\nsolver_l2loss = ADMM(\n    f=None,\n    g_list=[f_l2loss, g0, g1],\n    C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape), Identity(x_mrf.shape)],\n    rho_list=[ρ, ρ, ρ],\n    x0=x0,\n    maxiter=20,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the ADMM solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx_l2loss = solver_l2loss.solve()\nhist_l2loss = solver_l2loss.itstat_object.history(transpose=True)\n\n\n\"\"\"\nSet up problem using `SVMBIRExtendedLoss`, without need for `NonNegativeIndicator`.\n\"\"\"\nf_extloss = SVMBIRExtendedLoss(\n    y=y,\n    A=A,\n    W=Diagonal(weights),\n    scale=0.5,\n    positivity=True,\n    prox_kwargs={\"maxiter\": 5, \"ctol\": 
0.0},\n)\n\nsolver_extloss = ADMM(\n    f=None,\n    g_list=[f_extloss, g0],\n    C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)],\n    rho_list=[ρ, ρ],\n    x0=x0,\n    maxiter=20,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the ADMM solver.\n\"\"\"\nprint()\nx_extloss = solver_extloss.solve()\nhist_extloss = solver_extloss.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered images.\n\"\"\"\nnorm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)\nfig, ax = plt.subplots(2, 2, figsize=(15, 15))\nplot.imview(img=x_gt, title=\"Ground Truth Image\", cbar=True, fig=fig, ax=ax[0, 0], norm=norm)\nplot.imview(\n    img=x_mrf,\n    title=f\"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[0, 1],\n    norm=norm,\n)\nplot.imview(\n    img=x_l2loss,\n    title=f\"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1, 0],\n    norm=norm,\n)\nplot.imview(\n    img=x_extloss,\n    title=f\"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1, 1],\n    norm=norm,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plt.subplots(1, 2, figsize=(15, 5))\nplot.plot(\n    snp.array((hist_l2loss.Prml_Rsdl, hist_l2loss.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals (SquaredL2Loss + non-negativity)\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[0],\n)\nax[0].set_ylim([1e-1, 5e0])\nax[0].xaxis.set_major_locator(MaxNLocator(integer=True))\nplot.plot(\n    snp.array((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals (ExtendedLoss)\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    
ax=ax[1],\n)\nax[1].set_ylim([1e-1, 5e0])\nax[1].xaxis.set_major_locator(MaxNLocator(integer=True))\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_svmbir_tv_multi.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized CT Reconstruction (Multiple Algorithms)\n======================================================\n\nThis example demonstrates the use of different optimization algorithms to\nsolve the TV-regularized CT problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is the X-ray transform (implemented using the SVMBIR\n:cite:`svmbir-2020` tomographic projection), $\\mathbf{y}$ is the sinogram,\n$C$ is a 2D finite difference operator, and $\\mathbf{x}$ is the\nreconstructed image.\n\"\"\"\n\nimport numpy as np\n\nimport matplotlib.pyplot as plt\nimport svmbir\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, metric, plot\nfrom scico.linop import Diagonal\nfrom scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform\nfrom scico.optimize import PDHG, LinearizedADMM\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nGenerate a ground truth image.\n\"\"\"\nN = 256  # image size\ndensity = 0.025  # attenuation density of the image\nnp.random.seed(1234)\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10)\nx_gt = x_gt / np.max(x_gt) * density\nx_gt = np.pad(x_gt, 5)\nx_gt[x_gt < 0] = 0\n\n\n\"\"\"\nGenerate tomographic projector and sinogram.\n\"\"\"\nnum_angles = int(N / 2)\nnum_channels = N\nangles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)\nA = XRayTransform(x_gt.shape, angles, num_channels)\nsino = A @ x_gt\n\n\n\"\"\"\nImpose Poisson noise on sinogram. 
Higher max_intensity means less noise.\n\"\"\"\nmax_intensity = 2000\nexpected_counts = max_intensity * np.exp(-sino)\nnoisy_counts = np.random.poisson(expected_counts).astype(np.float32)\nnoisy_counts[noisy_counts == 0] = 1  # deal with 0s\ny = -snp.log(noisy_counts / max_intensity)\n\n\n\"\"\"\nReconstruct using default prior of SVMBIR :cite:`svmbir-2020`.\n\"\"\"\nweights = svmbir.calc_weights(y, weight_type=\"transmission\")\n\nx_mrf = svmbir.recon(\n    np.array(y[:, np.newaxis]),\n    np.array(angles),\n    weights=weights[:, np.newaxis],\n    num_rows=N,\n    num_cols=N,\n    positivity=True,\n    verbose=0,\n)[0]\n\n\n\"\"\"\nSet up problem.\n\"\"\"\nx0 = snp.array(x_mrf)\nweights = snp.array(weights)\nλ = 1e-1  # ℓ1 norm regularization parameter\nf = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)\ng = λ * functional.L21Norm()  # regularization functional\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n\n\n\"\"\"\nSolve via ADMM.\n\"\"\"\nsolve_admm = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[2e1],\n    x0=x0,\n    maxiter=50,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 10}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(f\"Solving on {device_info()}\\n\")\nprint(\"ADMM:\")\nx_admm = solve_admm.solve()\nhist_admm = solve_admm.itstat_object.history(transpose=True)\nprint(f\"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\\n\")\n\n\n\"\"\"\nSolve via Linearized ADMM.\n\"\"\"\nsolver_ladmm = LinearizedADMM(\n    f=f,\n    g=g,\n    C=C,\n    mu=3e-2,\n    nu=2e-1,\n    x0=x0,\n    maxiter=50,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(\"Linearized ADMM:\")\nx_ladmm = solver_ladmm.solve()\nhist_ladmm = solver_ladmm.itstat_object.history(transpose=True)\nprint(f\"PSNR: {metric.psnr(x_gt, 
x_ladmm):.2f} dB\\n\")\n\n\n\"\"\"\nSolve via PDHG.\n\"\"\"\nsolver_pdhg = PDHG(\n    f=f,\n    g=g,\n    C=C,\n    tau=2e-2,\n    sigma=8e0,\n    x0=x0,\n    maxiter=50,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(\"PDHG:\")\nx_pdhg = solver_pdhg.solve()\nhist_pdhg = solver_pdhg.itstat_object.history(transpose=True)\nprint(f\"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\\n\")\n\n\n\"\"\"\nShow the recovered images.\n\"\"\"\nnorm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)\nfig, ax = plt.subplots(1, 2, figsize=[10, 5])\nplot.imview(img=x_gt, title=\"Ground Truth Image\", cbar=True, fig=fig, ax=ax[0], norm=norm)\nplot.imview(\n    img=x_mrf,\n    title=f\"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1],\n    norm=norm,\n)\nfig.show()\n\nfig, ax = plt.subplots(1, 3, figsize=[15, 5])\nplot.imview(\n    img=x_admm,\n    title=f\"TV ADMM (PSNR: {metric.psnr(x_gt, x_admm):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[0],\n    norm=norm,\n)\nplot.imview(\n    img=x_ladmm,\n    title=f\"TV LinADMM (PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[1],\n    norm=norm,\n)\nplot.imview(\n    img=x_pdhg,\n    title=f\"TV PDHG (PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB)\",\n    cbar=True,\n    fig=fig,\n    ax=ax[2],\n    norm=norm,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", 
\"PDHG\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_symcone_tv_padmm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Cone Beam CT for Symmetric Objects\n=================================================\n\nThis example demonstrates a total variation (TV) regularized\nreconstruction for cone beam CT of a cylindrically symmetric object,\nby solving the problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_1 \\;,$$\n\nwhere $C$ is a single-view X-ray transform (with an implementation based\non a projector from the AXITOM package :cite:`olufsen-2019-axitom`),\n$\\mathbf{y}$ is the measured data, $D$ is a 2D finite difference\noperator, and $\\mathbf{x}$ is the solution.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import create_circular_phantom\nfrom scico.linop.xray.symcone import SymConeXRayTransform\nfrom scico.optimize import ProximalADMM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nx_gt = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n\n\n\"\"\"\nSet up the forward operator and create a test measurement.\n\"\"\"\nC = SymConeXRayTransform(x_gt.shape, obj_dist=5e2 * N, det_dist=6e2 * N, num_slabs=4)\ny = C @ x_gt\nnp.random.seed(12345)\ny = y + np.random.normal(size=y.shape).astype(np.float32)\n\n\n\"\"\"\nCompute FDK reconstruction.\n\"\"\"\nx_inv = C.fdk(y)\n\n\nr\"\"\"\nSet up problem and solver. We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_1 \\;,$$\n\nwhere $C$ is the X-ray transform and $D$ is a finite difference\noperator. 
We use anisotropic TV, which gives slightly better performance\nthan isotropic TV in this case. This problem can be expressed as\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; (1/2) \\| \\mathbf{y} -\n  \\mathbf{z}_0 \\|_2^2 + \\lambda \\| \\mathbf{z}_1 \\|_1 \\;\\;\n  \\text{such that} \\;\\; \\mathbf{z}_0 = C \\mathbf{x} \\;\\; \\text{and} \\;\\;\n  \\mathbf{z}_1 = D \\mathbf{x} \\;,$$\n\nwhich can be written in the form of a standard ADMM problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; f(\\mathbf{x}) + g(\\mathbf{z})\n  \\;\\; \\text{such that} \\;\\; A \\mathbf{x} + B \\mathbf{z} = \\mathbf{c}$$\n\nwith\n\n  $$f = 0 \\qquad g = g_0 + g_1$$\n  $$g_0(\\mathbf{z}_0) = (1/2) \\| \\mathbf{y} - \\mathbf{z}_0 \\|_2^2 \\qquad\n  g_1(\\mathbf{z}_1) = \\lambda \\| \\mathbf{z}_1 \\|_1$$\n  $$A = \\left( \\begin{array}{c} C \\\\ D \\end{array} \\right) \\qquad\n  B = \\left( \\begin{array}{cc} -I & 0 \\\\ 0 & -I \\end{array} \\right) \\qquad\n  \\mathbf{c} = \\left( \\begin{array}{c} 0 \\\\ 0 \\end{array} \\right) \\;.$$\n\"\"\"\n𝛼 = 7e1  # improve problem conditioning by balancing C and D components of A\nλ = 8e0  # ℓ1 norm regularization parameter\nρ = 1e-2  # ADMM penalty parameter\nmaxiter = 250  # number of ADMM iterations\n\nf = functional.ZeroFunctional()\ng0 = loss.SquaredL2Loss(y=y)\ng1 = (λ / 𝛼) * functional.L1Norm()\ng = functional.SeparableFunctional((g0, g1))\nD = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n\nA = linop.VerticalStack((C, 𝛼 * D))\nmu, nu = ProximalADMM.estimate_parameters(A, maxiter=20)\n\nsolver = ProximalADMM(\n    f=f,\n    g=g,\n    A=A,\n    B=None,\n    rho=ρ,\n    mu=mu,\n    nu=nu,\n    x0=snp.clip(x_inv, 0.0, 1.0),\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 20},\n)\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx_tv = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow results.\n\"\"\"\nnorm = 
plot.matplotlib.colors.Normalize(vmin=-0.1, vmax=1.2)\nfig, ax = plot.subplots(nrows=2, ncols=2, figsize=(12, 12))\nplot.imview(x_gt, title=\"Ground Truth\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0], norm=norm)\nplot.imview(y, title=\"Measurement\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])\nplot.imview(\n    x_inv,\n    title=\"FDK: %.2f (dB)\" % metric.psnr(x_gt, x_inv),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1, 0],\n    norm=norm,\n)\nplot.imview(\n    x_tv,\n    title=\"TV-Regularized Inversion: %.2f (dB)\" % metric.psnr(x_gt, x_tv),\n    cmap=plot.cm.Blues,\n    fig=fig,\n    ax=ax[1, 1],\n    norm=norm,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized Sparse-View CT Reconstruction (Integrated Projector)\n===================================================================\n\nThis example demonstrates solution of a sparse-view CT reconstruction\nproblem with isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is the X-ray transform (the CT forward projection operator),\n$\\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and\n$\\mathbf{x}$ is the reconstructed image. This example uses the CT\nprojector integrated into scico, while the companion\n[example script](ct_astra_tv_admm.rst) uses the projector provided by\nthe astra package.\n\"\"\"\n\nimport numpy as np\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.linop.xray import XRayTransform2D\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 512  # phantom size\nnp.random.seed(1234)\nx_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))\n\n\n\"\"\"\nConfigure CT projection operator and generate synthetic measurements.\n\"\"\"\nn_projection = 45  # number of projections\nangles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles\ndet_count = int(N * 1.05 / np.sqrt(2.0))\ndx = 1.0 / np.sqrt(2)\nA = XRayTransform2D(\n    (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx\n)  # CT projection operator\ny = A @ x_gt  # 
sinogram\n\n\n\"\"\"\nSet up problem functional and ADMM solver object.\n\"\"\"\nλ = 2e0  # ℓ1 norm regularization parameter\nρ = 5e0  # ADMM penalty parameter\nmaxiter = 25  # number of ADMM iterations\ncg_tol = 1e-4  # CG relative tolerance\ncg_maxiter = 25  # maximum CG iterations per ADMM iteration\n\n# The append=0 option makes the results of horizontal and vertical\n# finite differences the same shape, which is required for the L21Norm,\n# which is used so that g(Cx) corresponds to isotropic TV.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\ng = λ * functional.L21Norm()\nf = loss.SquaredL2Loss(y=y, A=A)\nx0 = snp.clip(A.fbp(y), 0, 1.0)\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=x0,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": cg_tol, \"maxiter\": cg_maxiter}),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nsolver.solve()\nhist = solver.itstat_object.history(transpose=True)\nx_reconstruction = snp.clip(solver.x, 0, 1.0)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    x0,\n    title=\"FBP Reconstruction: \\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    x_reconstruction,\n    title=\"TV Reconstruction\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, 
figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/ct_unet_train_foam2.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nCT Training and Reconstructions with UNet\n=========================================\n\nThis example demonstrates the training and application of UNet to denoise\npreviously filtered back projections (FBP) for CT reconstruction inspired\nby :cite:`jin-2017-unet`.\n\"\"\"\n\n# isort: off\nimport os\nfrom time import time\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nimport numpy as np\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_ct_data\n\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nN = 256  # phantom size\ntrain_nimg = 498  # number of training images\ntest_nimg = 32  # number of testing images\nnimg = train_nimg + test_nimg\nn_projection = 45  # CT views\n\ntrdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)\n\n\n\"\"\"\nBuild training and testing structures. 
Inputs are the filtered\nback-projected sinograms and outputs are the original generated foams.\nKeep training and testing partitions.\n\"\"\"\ntrain_ds = {\"image\": trdt[\"fbp\"], \"label\": trdt[\"img\"]}\ntest_ds = {\"image\": ttdt[\"fbp\"], \"label\": ttdt[\"img\"]}\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and relatively\nshort training. The model depth controls the levels of pooling in the\nU-Net model. The block depth controls the number of layers at each level\nof depth. The number of filters controls the number of filters at the\ninput and output levels and doubles (halves) at each pooling (unpooling)\noperation. Better performance may be obtained by increasing depth, block\ndepth, number of filters or training epochs, but may require longer\ntraining times.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 2,\n    \"num_filters\": 64,\n    \"block_depth\": 2,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 0,\n    \"opt_type\": \"SGD\",\n    \"momentum\": 0.9,\n    \"batch_size\": 16,\n    \"num_epochs\": 200,\n    \"base_learning_rate\": 1e-2,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 1000,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct UNet model.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nmodel = sflax.UNet(\n    depth=model_conf[\"depth\"],\n    channels=channels,\n    num_filters=model_conf[\"num_filters\"],\n    block_depth=model_conf[\"block_depth\"],\n)\n\n\n\"\"\"\nRun training loop.\n\"\"\"\nworkdir = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"unet_ct_out\")\ntrain_conf[\"workdir\"] = workdir\nprint(f\"\\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}\")\nprint(f\"JAX local devices: {jax.local_devices()}\\n\")\n\ntrainer = sflax.BasicFlaxTrainer(\n    train_conf,\n    model,\n    train_ds,\n    test_ds,\n)\nmodvar, 
stats_object = trainer.train()\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ndel train_ds[\"image\"]\ndel train_ds[\"label\"]\n\nfmap = sflax.FlaxMap(model, modvar)\ndel model, modvar\n\nmaxn = test_nimg // 2\nstart_time = time()\noutput = fmap(test_ds[\"image\"][:maxn])\ntime_eval = time() - start_time\noutput = jax.numpy.clip(output, a_min=0, a_max=1.0)\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time and data fidelity.\n\"\"\"\nsnr_eval = metric.snr(test_ds[\"label\"][:maxn], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:maxn], output)\nprint(\n    f\"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}\"\n    f\"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}\"\n)\nprint(\n    f\"{'UNet testing':15s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}\"\n    f\"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\n\"\"\"\nPlot comparison.\n\"\"\"\nkey = jax.random.key(123)\nindx = jax.random.randint(key, shape=(1,), minval=0, maxval=maxn)[0]\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"FBP Reconstruction: \\nSNR: %.2f (dB), MAE: %.3f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n        metric.mae(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n    ),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"UNet Reconstruction\\nSNR: %.2f (dB), MAE: %.3f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n        metric.mae(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, 
label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_circ_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nCirculant Blur Image Deconvolution with TV Regularization\n=========================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nwith isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $A$ is a circular convolution operator, $\\mathbf{y}$ is the blurred\nimage, $C$ is a 2D finite difference operator, and $\\mathbf{x}$ is the\ndeconvolved image.\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize.admm import ADMM, CircularConvolveSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nphantom = SiemensStar(32)\nN = 256  # image size\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\n\n\n\"\"\"\nSet up the forward operator and create a test signal consisting of a\nblurred signal with additive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.CircularConvolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = scico.random.randn(Ax.shape, seed=0)\ny = Ax + σ * noise\n\n\n\"\"\"\nSet up an ADMM solver object.\n\"\"\"\nλ = 2e-2  # ℓ2,1 norm regularization parameter\nρ = 5e-1  # ADMM penalty parameter\nmaxiter = 50  # number of ADMM iterations\n\nf = loss.SquaredL2Loss(y=y, A=A)\n# Penalty parameters must be accounted for in the gi functions, not as\n# additional inputs.\ng = λ * functional.L21Norm()  # regularization functionals gi\nC = 
linop.FiniteDifference(x_gt.shape, circular=True)\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=A.adj(y),\n    maxiter=maxiter,\n    subproblem_solver=CircularConvolveSolver(),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])\nplot.imview(x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_datagen_bsds.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nBlurred Data Generation (Natural Images) for NN Training\n========================================================\n\nThis example demonstrates how to generate blurred image data for\ntraining neural network models for deconvolution (deblurring). The\noriginal images are part of the\n[BSDS500 dataset](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/)\nprovided by the Berkeley Segmentation Dataset and Benchmark project.\n\"\"\"\n\nimport numpy as np\n\nfrom jax import vmap\n\nfrom scico import plot\nfrom scico.flax.examples import PaddedCircularConvolve, load_image_data\n\n\"\"\"\nDefine blur operator.\n\"\"\"\noutput_size = 256  # patch size\nchannels = 1  # gray scale problem\nblur_shape = (9, 9)  # shape of blur kernel\nblur_sigma = 5  # Gaussian blur kernel parameter\n\nopBlur = PaddedCircularConvolve(output_size, channels, blur_shape, blur_sigma)\nopBlur_vmap = vmap(opBlur)  # for batch processing\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\ntrain_nimg = 400  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\ngray = True  # use gray scale images\ndata_mode = \"dcnv\"  # deconvolution problem\nnoise_level = 0.005  # standard deviation of noise\nnoise_range = False  # use fixed noise level\nstride = 100  # stride to sample multiple patches from each image\naugment = True  # augment data via rotations and flips\n\ntrain_ds, test_ds = load_image_data(\n    train_nimg,\n    test_nimg,\n    output_size,\n    gray,\n    data_mode,\n    verbose=True,\n    noise_level=noise_level,\n    noise_range=noise_range,\n    transf=opBlur_vmap,\n    stride=stride,\n    augment=augment,\n)\n\n\n\"\"\"\nPlot randomly selected sample.\n\"\"\"\nindx_tr 
= np.random.randint(0, train_nimg)\nindx_te = np.random.randint(0, test_nimg)\nfig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7))\nplot.imview(\n    train_ds[\"label\"][indx_tr, ..., 0],\n    title=\"Ground truth - Training Sample\",\n    fig=fig,\n    ax=axes[0, 0],\n)\nplot.imview(\n    train_ds[\"image\"][indx_tr, ..., 0],\n    title=\"Blurred Image - Training Sample\",\n    fig=fig,\n    ax=axes[0, 1],\n)\nplot.imview(\n    test_ds[\"label\"][indx_te, ..., 0],\n    title=\"Ground truth - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 0],\n)\nplot.imview(\n    test_ds[\"image\"][indx_te, ..., 0],\n    title=\"Blurred Image - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 1],\n)\nfig.suptitle(r\"Training and Testing samples\")\nfig.tight_layout()\nfig.colorbar(\n    axes[0, 1].get_images()[0],\n    ax=axes,\n    shrink=0.5,\n    pad=0.05,\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_datagen_foam1.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nBlurred Data Generation (Foams) for NN Training\n===============================================\n\nThis example demonstrates how to generate blurred image data for\ntraining neural network models for deconvolution (deblurring), using foam\nphantoms generated by `xdesign`.\n\"\"\"\n\n# isort: off\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\nfrom scico import plot\nfrom scico.flax.examples import load_blur_data\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nn = 3  # convolution kernel size\nσ = 20.0 / 255  # noise level\npsf = np.ones((n, n)) / (n * n)  # kernel\n\ntrain_nimg = 416  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\noutput_size = 256  # image size\n\ntrain_ds, test_ds = load_blur_data(\n    train_nimg,\n    test_nimg,\n    output_size,\n    psf,\n    σ,\n    verbose=True,\n)\n\n\n\"\"\"\nPlot randomly selected sample.\n\"\"\"\nindx_tr = np.random.randint(0, train_nimg)\nindx_te = np.random.randint(0, test_nimg)\nfig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7))\nplot.imview(\n    train_ds[\"label\"][indx_tr, ..., 0],\n    title=\"Ground truth - Training Sample\",\n    fig=fig,\n    ax=axes[0, 0],\n)\nplot.imview(\n    train_ds[\"image\"][indx_tr, ..., 0],\n    title=\"Blurred Image - Training Sample\",\n    fig=fig,\n    ax=axes[0, 1],\n)\nplot.imview(\n    test_ds[\"label\"][indx_te, ..., 0],\n    title=\"Ground truth - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 0],\n)\nplot.imview(\n    test_ds[\"image\"][indx_te, ..., 0],\n    title=\"Blurred Image - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 
1],\n)\nfig.suptitle(r\"Training and Testing samples\")\nfig.tight_layout()\nfig.colorbar(\n    axes[0, 1].get_images()[0],\n    ax=axes,\n    shrink=0.5,\n    pad=0.05,\n    label=\"Arbitrary Units\",\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_microscopy_allchn_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nDeconvolution Microscopy (All Channels)\n=======================================\n\nThis example partially replicates a [GlobalBioIm\nexample](https://biomedical-imaging-group.github.io/GlobalBioIm/examples.html)\nusing the [microscopy data](http://bigwww.epfl.ch/deconvolution/bio/)\nprovided by the EPFL Biomedical Imaging Group.\n\nThe deconvolution problem is solved using class\n[admm.ADMM](../_autosummary/scico.optimize.rst#scico.optimize.ADMM) to\nsolve an image deconvolution problem with isotropic total variation (TV)\nregularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| M (\\mathbf{y} - A \\mathbf{x})\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} +\n  \\iota_{\\mathrm{NN}}(\\mathbf{x}) \\;,$$\n\nwhere $M$ is a mask operator, $A$ is circular convolution,\n$\\mathbf{y}$ is the blurred image, $C$ is a convolutional gradient\noperator, $\\iota_{\\mathrm{NN}}$ is the indicator function of the\nnon-negativity constraint, and $\\mathbf{x}$ is the deconvolved image.\n\"\"\"\n\n# isort: off\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot\nfrom scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices\nfrom scico.optimize.admm import ADMM, CircularConvolveSolver\n\n\"\"\"\nGet and preprocess data. The data is downsampled to limit the memory\nrequirements and run time of the example. Reducing the downsampling rate\nwill make the example slower and more memory-intensive. 
To run this\nexample on a GPU it may be necessary to set environment variables\n`XLA_PYTHON_CLIENT_ALLOCATOR=platform` and\n`XLA_PYTHON_CLIENT_PREALLOCATE=false`. If your GPU does not have enough\nmemory, try setting the environment variable `JAX_PLATFORM_NAME=cpu` to\nrun on CPU.\n\"\"\"\ndownsampling_rate = 2\n\ny_list = []\ny_pad_list = []\npsf_list = []\nfor channel in range(3):\n    y, psf = epfl_deconv_data(channel, verbose=True)  # get data\n    y = downsample_volume(y, downsampling_rate)  # downsample\n    psf = downsample_volume(psf, downsampling_rate)\n    y -= y.min()  # normalize y\n    y /= y.max()\n    psf /= psf.sum()  # normalize psf\n    if channel == 0:\n        padding = [[0, p] for p in snp.array(psf.shape) - 1]\n        mask = snp.pad(snp.ones_like(y), padding)\n    y_pad = snp.pad(y, padding)  # zero-padded version of y\n    y_list.append(y)\n    y_pad_list.append(y_pad)\n    psf_list.append(psf)\ny = snp.stack(y_list, axis=-1)\nyshape = y.shape\ndel y_list\n\n\n\"\"\"\nDefine problem and algorithm parameters.\n\"\"\"\nλ = 2e-6  # ℓ1 norm regularization parameter\nρ0 = 1e-3  # ADMM penalty parameter for first auxiliary variable\nρ1 = 1e-3  # ADMM penalty parameter for second auxiliary variable\nρ2 = 1e-3  # ADMM penalty parameter for third auxiliary variable\nmaxiter = 100  # number of ADMM iterations\n\n\n\"\"\"\nDetermine available computing resources, and put large arrays in ray\nobject store.\n\"\"\"\nngpu = 0\nar = ray.available_resources()\nncpu = max(int(ar[\"CPU\"]) // 3, 1)\nif \"GPU\" in ar:\n    ngpu = int(ar[\"GPU\"]) // 3\nprint(f\"Running on {ncpu} CPUs and {ngpu} GPUs per process\")\n\ny_pad_list = ray.put(y_pad_list)\npsf_list = ray.put(psf_list)\nmask_store = ray.put(mask)\n\n\n\"\"\"\nDefine ray remote function for parallel solves.\n\"\"\"\n\n\n@ray.remote(num_cpus=ncpu, num_gpus=ngpu)\ndef deconvolve_channel(channel):\n    \"\"\"Deconvolve a single channel.\"\"\"\n    y_pad = ray.get(y_pad_list)[channel]\n    psf = 
ray.get(psf_list)[channel]\n    mask = ray.get(mask_store)\n    M = linop.Diagonal(mask)\n    C0 = linop.CircularConvolve(\n        h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5  # forward operator\n    )\n    C1 = linop.FiniteDifference(input_shape=mask.shape, circular=True)  # gradient operator\n    C2 = linop.Identity(mask.shape)  # identity operator\n    g0 = loss.SquaredL2Loss(y=y_pad, A=M)  # loss function (forward model)\n    g1 = λ * functional.L21Norm()  # TV penalty (when applied to gradient)\n    g2 = functional.NonNegativeIndicator()  # non-negativity constraint\n    if channel == 0:\n        print(\"Displaying solver status for channel 0\")\n        display = True\n    else:\n        display = False\n    solver = ADMM(\n        f=None,\n        g_list=[g0, g1, g2],\n        C_list=[C0, C1, C2],\n        rho_list=[ρ0, ρ1, ρ2],\n        maxiter=maxiter,\n        itstat_options={\"display\": display, \"period\": 10, \"overwrite\": False},\n        x0=y_pad,\n        subproblem_solver=CircularConvolveSolver(),\n    )\n    x_pad = solver.solve()\n    x = x_pad[: yshape[0], : yshape[1], : yshape[2]]\n    return (x, solver.itstat_object.history(transpose=True))\n\n\n\"\"\"\nSolve problems for all three channels in parallel and extract results.\n\"\"\"\nray_return = ray.get([deconvolve_channel.remote(channel) for channel in range(3)])\nx = snp.stack([t[0] for t in ray_return], axis=-1)\nsolve_stats = [t[1] for t in ray_return]\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7))\nplot.imview(tile_volume_slices(y), title=\"Blurred measurements\", fig=fig, ax=ax[0])\nplot.imview(tile_volume_slices(x), title=\"Deconvolved image\", fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(18, 5))\nplot.plot(\n    np.stack([s.Objective for s in solve_stats]).T,\n    title=\"Objective function\",\n    
xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    lgnd=(\"CY3\", \"DAPI\", \"FITC\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    np.stack([s.Prml_Rsdl for s in solve_stats]).T,\n    ptyp=\"semilogy\",\n    title=\"Primal Residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"CY3\", \"DAPI\", \"FITC\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    np.stack([s.Dual_Rsdl for s in solve_stats]).T,\n    ptyp=\"semilogy\",\n    title=\"Dual Residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"CY3\", \"DAPI\", \"FITC\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_microscopy_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nDeconvolution Microscopy (Single Channel)\n=========================================\n\nThis example partially replicates a [GlobalBioIm\nexample](https://biomedical-imaging-group.github.io/GlobalBioIm/examples.html)\nusing the [microscopy data](http://bigwww.epfl.ch/deconvolution/bio/)\nprovided by the EPFL Biomedical Imaging Group.\n\nClass\n[admm.ADMM](../_autosummary/scico.optimize.rst#scico.optimize.ADMM) is\nused to solve the image deconvolution problem with isotropic total\nvariation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| M (\\mathbf{y} - A \\mathbf{x})\n  \\|_2^2 + \\lambda \\| C \\mathbf{x} \\|_{2,1} +\n  \\iota_{\\mathrm{NN}}(\\mathbf{x}) \\;,$$\n\nwhere $M$ is a mask operator, $A$ is circular convolution,\n$\\mathbf{y}$ is the blurred image, $C$ is a convolutional gradient\noperator, $\\iota_{\\mathrm{NN}}$ is the indicator function of the\nnon-negativity constraint, and $\\mathbf{x}$ is the deconvolved image.\n\"\"\"\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot, util\nfrom scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices\nfrom scico.optimize.admm import ADMM, CircularConvolveSolver\n\n\"\"\"\nGet and preprocess data. The data is downsampled to limit the memory\nrequirements and run time of the example. Reducing the downsampling rate\nwill make the example slower and more memory-intensive. To run this\nexample on a GPU it may be necessary to set environment variables\n`XLA_PYTHON_CLIENT_ALLOCATOR=platform` and\n`XLA_PYTHON_CLIENT_PREALLOCATE=false`. 
If your GPU does not have enough\nmemory, try setting the environment variable `JAX_PLATFORM_NAME=cpu` to\nrun on CPU.\n\"\"\"\nchannel = 0\ndownsampling_rate = 2\n\ny, psf = epfl_deconv_data(channel, verbose=True)\ny = downsample_volume(y, downsampling_rate)\npsf = downsample_volume(psf, downsampling_rate)\n\ny -= y.min()\ny /= y.max()\npsf /= psf.sum()\n\n\n\"\"\"\nPad data and create mask.\n\"\"\"\npadding = [[0, p] for p in snp.array(psf.shape) - 1]\ny_pad = snp.pad(y, padding)\nmask = snp.pad(snp.ones_like(y), padding)\n\n\n\"\"\"\nDefine problem and algorithm parameters.\n\"\"\"\nλ = 2e-6  # ℓ1 norm regularization parameter\nρ0 = 1e-3  # ADMM penalty parameter for first auxiliary variable\nρ1 = 1e-3  # ADMM penalty parameter for second auxiliary variable\nρ2 = 1e-3  # ADMM penalty parameter for third auxiliary variable\nmaxiter = 100  # number of ADMM iterations\n\n\n\"\"\"\nCreate operators.\n\"\"\"\nM = linop.Diagonal(mask)\nC0 = linop.CircularConvolve(h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5)\nC1 = linop.FiniteDifference(input_shape=mask.shape, circular=True)\nC2 = linop.Identity(mask.shape)\n\n\n\"\"\"\nCreate functionals.\n\"\"\"\ng0 = loss.SquaredL2Loss(y=y_pad, A=M)  # loss function (forward model)\ng1 = λ * functional.L21Norm()  # TV penalty (when applied to gradient)\ng2 = functional.NonNegativeIndicator()  # non-negativity constraint\n\n\n\"\"\"\nSet up ADMM solver object and solve problem.\n\"\"\"\nsolver = ADMM(\n    f=None,\n    g_list=[g0, g1, g2],\n    C_list=[C0, C1, C2],\n    rho_list=[ρ0, ρ1, ρ2],\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n    x0=y_pad,\n    subproblem_solver=CircularConvolveSolver(),\n)\n\nprint(\"Solving on %s\\n\" % util.device_info())\nsolver.solve()\nsolve_stats = solver.itstat_object.history(transpose=True)\nx_pad = solver.x\nx = x_pad[: y.shape[0], : y.shape[1], : y.shape[2]]\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = 
plot.subplots(nrows=1, ncols=2, figsize=(14, 7))\nplot.imview(tile_volume_slices(y), title=\"Blurred measurements\", fig=fig, ax=ax[0])\nplot.imview(tile_volume_slices(x), title=\"Deconvolved image\", fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    solve_stats.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((solve_stats.Prml_Rsdl, solve_stats.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_modl_train_foam1.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nDeconvolution Training and Reconstructions with MoDL\n====================================================\n\nThis example demonstrates the training and application of a\nmodel-based deep learning (MoDL) architecture described in\n:cite:`aggarwal-2019-modl` for a deconvolution (deblurring) problem.\n\nThe source images are foam phantoms generated with xdesign.\n\nA class\n[scico.flax.MoDLNet](../_autosummary/scico.flax.rst#scico.flax.MoDLNet)\nimplements the MoDL architecture, which solves the optimization\nproblem\n\n$$\\mathrm{argmin}_{\\mathbf{x}} \\; \\| A \\mathbf{x} - \\mathbf{y} \\|_2^2\n+ \\lambda \\, \\| \\mathbf{x} - \\mathrm{D}_w(\\mathbf{x})\\|_2^2 \\;,$$\n\nwhere $A$ is a circular convolution, $\\mathbf{y}$ is a set of blurred\nimages, $\\mathrm{D}_w$ is the regularization (a denoiser), and\n$\\mathbf{x}$ is the set of deblurred images. The MoDL abstracts the\niterative solution by an unrolled network where each iteration\ncorresponds to a different stage in the MoDL network and updates the\nprediction by solving\n\n$$\\mathbf{x}^{k+1} = (A^T A + \\lambda \\, I)^{-1} (A^T \\mathbf{y} +\n\\lambda \\, \\mathbf{z}^k) \\;,$$\n\nvia conjugate gradient. In the expression, $k$ is the index of the stage\n(iteration), $\\mathbf{z}^k = \\mathrm{ResNet}(\\mathbf{x}^{k})$ is the\nregularization (a denoiser implemented as a residual convolutional neural\nnetwork), $\\mathbf{x}^k$ is the output of the previous stage,\n$\\lambda > 0$ is a learned regularization parameter, and $I$ is the\nidentity operator. 
The output of the final stage is the set of deblurred\nimages.\n\"\"\"\n\n# isort: off\nimport os\nfrom functools import partial\nfrom time import time\n\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_blur_data\nfrom scico.flax.train.traversals import clip_positive, construct_traversal\nfrom scico.linop import CircularConvolve\n\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nDefine blur operator.\n\"\"\"\noutput_size = 256  # image size\n\nn = 3  # convolution kernel size\nσ = 20.0 / 255  # noise level\npsf = np.ones((n, n)) / (n * n)  # blur kernel\n\nishape = (output_size, output_size)\nopBlur = CircularConvolve(h=psf, input_shape=ishape)\nopBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\ntrain_nimg = 416  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\n\ntrain_ds, test_ds = load_blur_data(\n    train_nimg,\n    test_nimg,\n    output_size,\n    psf,\n    σ,\n    verbose=True,\n)\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and relatively\nshort training. The model depth is akin to the number of unrolled\niterations in the MoDL model. The block depth controls the number of\nlayers at each unrolled iteration. 
The number of filters is uniform\nthroughout the iterations. The iterations used for the conjugate gradient\n(CG) solver can also be specified. Better performance may be obtained by\nincreasing depth, block depth, number of filters, CG iterations, or\ntraining epochs, but may require longer training times.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 2,\n    \"num_filters\": 64,\n    \"block_depth\": 4,\n    \"cg_iter\": 4,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 0,\n    \"opt_type\": \"SGD\",\n    \"momentum\": 0.9,\n    \"batch_size\": 16,\n    \"num_epochs\": 25,\n    \"base_learning_rate\": 1e-2,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 100,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct functionality for ensuring that the learned regularization\nparameter is always positive.\n\"\"\"\nlmbdatrav = construct_traversal(\"lmbda\")  # select lmbda parameters in model\nlmbdapos = partial(\n    clip_positive,  # apply this function\n    traversal=lmbdatrav,  # to lmbda parameters in model\n    minval=5e-4,\n)\n\n\n\"\"\"\nPrint configuration of distributed run.\n\"\"\"\nprint(f\"\\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}\")\nprint(f\"JAX local devices: {jax.local_devices()}\\n\")\n\n\n\"\"\"\nCheck for iterated trained model. If not found, construct MoDLNet model,\nusing only one iteration (depth) in model and few CG iterations for\nfaster initialization. 
Run first stage (initialization) training loop\nfollowed by a second stage (depth iterations) training loop.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nworkdir2 = os.path.join(\n    os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"modl_dcnv_out\", \"iterated\"\n)\n\nstats_object_ini = None\nstats_object = None\n\ncheckpoint_files = []\nfor dirpath, dirnames, filenames in os.walk(workdir2):\n    checkpoint_files = [fn for fn in filenames]\n\nif len(checkpoint_files) > 0:\n    model = sflax.MoDLNet(\n        operator=opBlur,\n        depth=model_conf[\"depth\"],\n        channels=channels,\n        num_filters=model_conf[\"num_filters\"],\n        block_depth=model_conf[\"block_depth\"],\n        cg_iter=model_conf[\"cg_iter\"],\n    )\n\n    train_conf[\"workdir\"] = workdir2\n    train_conf[\"post_lst\"] = [lmbdapos]\n    # Construct training object\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n    )\n    start_time = time()\n    modvar, stats_object = trainer.train()\n    time_train = time() - start_time\n    time_init = 0.0\n    epochs_init = 0\nelse:\n    # One iteration (depth) in model and few CG iterations\n    model = sflax.MoDLNet(\n        operator=opBlur,\n        depth=1,\n        channels=channels,\n        num_filters=model_conf[\"num_filters\"],\n        block_depth=model_conf[\"block_depth\"],\n        cg_iter=model_conf[\"cg_iter\"],\n    )\n    # First stage: initialization training loop.\n    workdir1 = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"modl_dcnv_out\")\n    train_conf[\"workdir\"] = workdir1\n    train_conf[\"post_lst\"] = [lmbdapos]\n    # Construct training object\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n    )\n\n    start_time = time()\n    modvar, stats_object_ini = trainer.train()\n    time_init = time() - start_time\n    epochs_init 
= train_conf[\"num_epochs\"]\n\n    print(\n        f\"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}\"\n        f\"{'time[s]:':21s}{time_init:>7.2f}\"\n    )\n\n    # Second stage: depth iterations training loop.\n    model.depth = model_conf[\"depth\"]\n    train_conf[\"workdir\"] = workdir2\n    # Construct training object, include current model parameters\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        train_ds,\n        test_ds,\n        variables0=modvar,\n    )\n\n    start_time = time()\n    modvar, stats_object = trainer.train()\n    time_train = time() - start_time\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ndel train_ds[\"image\"]\ndel train_ds[\"label\"]\n\nfmap = sflax.FlaxMap(model, modvar)\ndel model, modvar\n\nmaxn = test_nimg // 4\nstart_time = time()\noutput = fmap(test_ds[\"image\"][:maxn])\ntime_eval = time() - start_time\noutput = np.clip(output, a_min=0, a_max=1.0)\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time\nand data fidelity.\n\"\"\"\ntotal_epochs = epochs_init + train_conf[\"num_epochs\"]\ntotal_time_train = time_init + time_train\nsnr_eval = metric.snr(test_ds[\"label\"][:maxn], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:maxn], output)\nprint(\n    f\"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}\"\n    f\"{'time[s]:':10s}{total_time_train:>7.2f}\"\n)\nprint(\n    f\"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}\"\n    f\"{'':3s}{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\n\"\"\"\nPlot comparison.\n\"\"\"\nnp.random.seed(123)\nindx = np.random.randint(0, high=maxn)\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"Blurred: \\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        
metric.snr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n    ),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"MoDLNet Reconstruction\\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n# Stats for initialization loop\nif stats_object_ini is not None and len(stats_object_ini.iterations) > 0:\n    hist = stats_object_ini.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        
title=\"Loss function - Initialization\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric - Initialization\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_odp_train_foam1.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nDeconvolution Training and Reconstructions with ODP\n===================================================\n\nThis example demonstrates the training and application of the unrolled\noptimization with deep priors (ODP) with proximal map architecture\ndescribed in :cite:`diamond-2018-odp` for a deconvolution (deblurring)\nproblem.\n\nThe source images are foam phantoms generated with xdesign.\n\nA class\n[scico.flax.ODPNet](../_autosummary/scico.flax.rst#scico.flax.ODPNet)\nimplements the ODP architecture, which solves the optimization problem\n\n$$\\mathrm{argmin}_{\\mathbf{x}} \\; \\| A \\mathbf{x} - \\mathbf{y} \\|_2^2\n+ r(\\mathbf{x}) \\;,$$\n\nwhere $A$ is a circular convolution, $\\mathbf{y}$ is a set of blurred\nimages, $r$ is a regularizer and $\\mathbf{x}$ is the set of deblurred\nimages. 
The ODP, proximal map architecture, abstracts the iterative\nsolution by an unrolled network where each iteration corresponds to a\ndifferent stage in the ODP network and updates the prediction by\nsolving\n\n$$\\mathbf{x}^{k+1} = \\mathrm{argmin}_{\\mathbf{x}} \\; \\alpha_k \\| A\n\\mathbf{x} - \\mathbf{y} \\|_2^2 + \\frac{1}{2} \\| \\mathbf{x} -\n\\mathbf{x}^k - \\mathbf{x}^{k+1/2} \\|_2^2 \\;,$$\n\nwhich for the deconvolution problem corresponds to\n\n$$\\mathbf{x}^{k+1} = \\mathcal{F}^{-1} \\mathrm{diag} (\\alpha_k |\n\\mathcal{K}|^2 + 1 )^{-1} \\mathcal{F} \\, (\\alpha_k K^T * \\mathbf{y} +\n\\mathbf{x}^k + \\mathbf{x}^{k+1/2}) \\;,$$\n\nwhere $k$ is the index of the stage (iteration), $\\mathbf{x}^k +\n\\mathbf{x}^{k+1/2} = \\mathrm{ResNet}(\\mathbf{x}^{k})$ is the\nregularization (implemented as a residual convolutional neural network),\n$\\mathbf{x}^k$ is the output of the previous stage, $\\alpha_k > 0$ is a\nlearned stage-wise parameter weighting the contribution of the fidelity\nterm, $\\mathcal{F}$ is the DFT, $K$ is the blur kernel, and\n$\\mathcal{K}$ is the DFT of $K$. 
The output of the final stage is the\nset of deblurred images.\n\"\"\"\n\n# isort: off\nimport os\nfrom functools import partial\nfrom time import time\n\nimport numpy as np\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_blur_data\nfrom scico.flax.train.traversals import clip_positive, construct_traversal\nfrom scico.linop import CircularConvolve\n\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nDefine blur operator.\n\"\"\"\noutput_size = 256  # patch size\n\nn = 3  # convolution kernel size\nσ = 20.0 / 255  # noise level\npsf = np.ones((n, n)) / (n * n)  # blur kernel\n\nishape = (output_size, output_size)\nopBlur = CircularConvolve(h=psf, input_shape=ishape)\nopBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\ntrain_nimg = 416  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\n\ntrain_ds, test_ds = load_blur_data(\n    train_nimg,\n    test_nimg,\n    output_size,\n    psf,\n    σ,\n    verbose=True,\n)\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and relatively\nshort training. The model depth is akin to the number of unrolled\niterations in the ODP model. The block depth controls the number of\nlayers at each unrolled iteration. 
The number of filters is uniform\nthroughout the iterations. Better performance may be obtained by\nincreasing depth, block depth, number of filters or training epochs, but\nmay require longer training times.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 2,\n    \"num_filters\": 64,\n    \"block_depth\": 3,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 0,\n    \"opt_type\": \"SGD\",\n    \"momentum\": 0.9,\n    \"batch_size\": 16,\n    \"num_epochs\": 50,\n    \"base_learning_rate\": 1e-2,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 100,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct ODPNet model.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nmodel = sflax.ODPNet(\n    operator=opBlur,\n    depth=model_conf[\"depth\"],\n    channels=channels,\n    num_filters=model_conf[\"num_filters\"],\n    block_depth=model_conf[\"block_depth\"],\n    odp_block=sflax.inverse.ODPProxDcnvBlock,\n)\n\n\n\"\"\"\nConstruct functionality for ensuring that the learned fidelity weight\nparameter is always positive.\n\"\"\"\nalphatrav = construct_traversal(\"alpha\")  # select alpha parameters in model\nalphapos = partial(\n    clip_positive,  # apply this function\n    traversal=alphatrav,  # to alpha parameters in model\n    minval=1e-3,\n)\n\n\n\"\"\"\nRun training loop.\n\"\"\"\nprint(f\"\\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}\")\nprint(f\"JAX local devices: {jax.local_devices()}\\n\")\n\nworkdir = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"odp_dcnv_out\")\ntrain_conf[\"workdir\"] = workdir\ntrain_conf[\"post_lst\"] = [alphapos]\n# Construct training object\ntrainer = sflax.BasicFlaxTrainer(\n    train_conf,\n    model,\n    train_ds,\n    test_ds,\n)\nmodvar, stats_object = trainer.train()\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ndel train_ds[\"image\"]\ndel train_ds[\"label\"]\n\nfmap = sflax.FlaxMap(model, modvar)\ndel 
model, modvar\n\nmaxn = test_nimg // 4\nstart_time = time()\noutput = fmap(test_ds[\"image\"][:maxn])\ntime_eval = time() - start_time\noutput = np.clip(output, a_min=0, a_max=1.0)\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time and data\nfidelity.\n\"\"\"\nsnr_eval = metric.snr(test_ds[\"label\"][:maxn], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:maxn], output)\nprint(\n    f\"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}\"\n    f\"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}\"\n)\nprint(\n    f\"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}\"\n    f\"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\n\"\"\"\nPlot comparison.\n\"\"\"\nnp.random.seed(123)\nindx = np.random.randint(0, high=maxn)\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"Blurred: \\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n    ),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"ODPNet Reconstruction\\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. 
if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_ppp_bm3d_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) Image Deconvolution (ADMM Solver)\n=================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nusing the ADMM Plug-and-Play Priors (PPP) algorithm\n:cite:`venkatakrishnan-2013-plugandplay2`, with the BM3D\n:cite:`dabov-2008-image` denoiser.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot, random\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 512  # image size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nSet up forward operator and test signal consisting of blurred signal with\nadditive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = random.randn(Ax.shape)\ny = Ax + σ * noise\n\n\n\"\"\"\nSet up ADMM solver.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=A)\nC = linop.Identity(x_gt.shape)\n\nλ = 20.0 / 255  # BM3D regularization strength\ng = λ * functional.BM3D()\n\nρ = 1.0  # ADMM penalty parameter\nmaxiter = 10  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=A.T @ y,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on 
{device_info()}\\n\")\nx = solver.solve()\nx = snp.clip(x, 0, 1)\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_ppp_bm3d_apgm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) Image Deconvolution (APGM Solver)\n=================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nusing the APGM Plug-and-Play Priors (PPP) algorithm\n:cite:`kamilov-2017-plugandplay`, with the BM3D :cite:`dabov-2008-image`\ndenoiser.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot, random\nfrom scico.optimize.pgm import AcceleratedPGM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 512  # image size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nSet up forward operator and test signal consisting of blurred signal with\nadditive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = random.randn(Ax.shape)\ny = Ax + σ * noise\n\n\n\"\"\"\nSet up PGM solver.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=A)\n\nL0 = 15  # APGM inverse step size parameter\nλ = L0 * 2.0 / 255  # BM3D regularization strength\ng = λ * functional.BM3D()\n\nmaxiter = 50  # number of APGM iterations\n\nsolver = AcceleratedPGM(\n    f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={\"display\": True, \"period\": 10}\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nx = snp.clip(x, 0, 1)\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered 
image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(hist.Residual, ptyp=\"semilogy\", title=\"PGM Residual\", xlbl=\"Iteration\", ylbl=\"Residual\")\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_ppp_bm4d_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM4D) Volume Deconvolution\n====================================\n\nThis example demonstrates the solution of a 3D image deconvolution problem\n(involving recovering a 3D volume that has been convolved with a 3D kernel\nand corrupted by noise) using the ADMM Plug-and-Play Priors (PPP)\nalgorithm :cite:`venkatakrishnan-2013-plugandplay2`, with the BM4D\n:cite:`maggioni-2012-nonlocal` denoiser.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot, random\nfrom scico.examples import create_3d_foam_phantom, downsample_volume, tile_volume_slices\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 128  # phantom size\nNx, Ny, Nz = N, N, N // 4\nupsamp = 2\nx_gt_hires = create_3d_foam_phantom((upsamp * Nz, upsamp * Ny, upsamp * Nx), N_sphere=100)\nx_gt = downsample_volume(x_gt_hires, upsamp)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nSet up forward operator and test signal consisting of blurred signal with\nadditive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n, n)) / (n**3)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = random.randn(Ax.shape)\ny = Ax + σ * noise\n\n\n\"\"\"\nSet up ADMM solver.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=A)\nC = linop.Identity(x_gt.shape)\n\nλ = 40.0 / 255  # BM4D regularization strength\ng = λ * functional.BM4D()\n\nρ = 1.0  # ADMM penalty parameter\nmaxiter = 10  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=A.T @ 
y,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nx = snp.clip(x, 0, 1)\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow slices of the recovered 3D volume.\n\"\"\"\nshow_id = Nz // 2\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(tile_volume_slices(x_gt), title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = y[nc:-nc, nc:-nc, nc:-nc]\nyc = snp.clip(yc, 0, 1)\nplot.imview(\n    tile_volume_slices(yc),\n    title=\"Slices of blurred, noisy volume: %.2f (dB)\" % metric.psnr(x_gt, yc),\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    tile_volume_slices(x),\n    title=\"Slices of deconvolved volume: %.2f (dB)\" % metric.psnr(x_gt, x),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_ppp_dncnn_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with DnCNN) Image Deconvolution (ADMM Solver)\n==================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nusing the ADMM Plug-and-Play Priors (PPP) algorithm\n:cite:`venkatakrishnan-2013-plugandplay2` with the DnCNN\n:cite:`zhang-2017-dncnn` denoiser.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot, random\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 512  # image size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nSet up forward operator and test signal consisting of blurred signal with\nadditive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = random.randn(Ax.shape)\ny = Ax + σ * noise\n\n\nr\"\"\"\nSet up the problem to be solved. We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + R(\\mathbf{x}) \\;$$\n\nwhere $R(\\cdot)$ is a pseudo-functional having the DnCNN denoiser as its\nproximal operator. 
The problem is solved via ADMM, using the standard\nvariable splitting for problems of this form, which requires the use of\nconjugate gradient sub-iterations in the ADMM step that involves the data\nfidelity term.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=A)\ng = functional.DnCNN(\"17M\")\nC = linop.Identity(x_gt.shape)\n\n\n\"\"\"\nSet up ADMM solver.\n\"\"\"\nρ = 0.2  # ADMM penalty parameter\nmaxiter = 10  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=A.T @ y,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 30}),\n    itstat_options={\"display\": True},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nx = snp.clip(x, 0, 1)\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_ppp_dncnn_padmm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)\n===========================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nusing a proximal ADMM variant of the Plug-and-Play Priors (PPP) algorithm\n:cite:`venkatakrishnan-2013-plugandplay2` with the DnCNN\n:cite:`zhang-2017-dncnn` denoiser.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot, random\nfrom scico.optimize import ProximalADMM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 512  # image size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nSet up forward operator $A$ and test signal consisting of blurred signal with\nadditive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = random.randn(Ax.shape)\ny = Ax + σ * noise\n\n\nr\"\"\"\nSet up the problem to be solved. We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - A \\mathbf{x}\n  \\|_2^2 + R(\\mathbf{x}) \\;$$\n\nwhere $R(\\cdot)$ is a pseudo-functional having the DnCNN denoiser as its\nproximal operator. 
A slightly unusual variable splitting is used,\nincluding setting the $f$ functional to the $R(\\cdot)$ term and the $g$\nfunctional to the data fidelity term to allow the use of proximal ADMM,\nwhich avoids the need for conjugate gradient sub-iterations in the solver\nsteps.\n\"\"\"\nf = functional.DnCNN(variant=\"17M\")\ng = loss.SquaredL2Loss(y=y)\n\n\n\"\"\"\nSet up proximal ADMM solver.\n\"\"\"\nρ = 0.2  # ADMM penalty parameter\nmaxiter = 10  # number of proximal ADMM iterations\nmu, nu = ProximalADMM.estimate_parameters(A)\n\nsolver = ProximalADMM(\n    f=f,\n    g=g,\n    A=A,\n    rho=ρ,\n    mu=mu,\n    nu=nu,\n    x0=A.T @ y,\n    maxiter=maxiter,\n    itstat_options={\"display\": True},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nx = snp.clip(x, 0, 1)\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nImage Deconvolution with TV Regularization (ADMM Solver)\n========================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nwith isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is a convolution operator, $\\mathbf{y}$ is the blurred image,\n$D$ is a 2D finite difference operator, and $\\mathbf{x}$ is the\ndeconvolved image.\n\nIn this example the problem is solved via standard ADMM, while proximal\nADMM is used in a [companion example](deconv_tv_padmm.rst).\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nphantom = SiemensStar(32)\nN = 256  # image size\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\n\n\n\"\"\"\nSet up the forward operator and create a test signal consisting of a\nblurred signal with additive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nC = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nCx = C(x_gt)  # blurred image\nnoise, key = scico.random.randn(Cx.shape, seed=0)\ny = Cx + σ * noise\n\n\nr\"\"\"\nSet up the problem to be solved. 
We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is the convolution operator and $D$ is a finite difference\noperator. This problem can be expressed as\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; (1/2) \\| \\mathbf{y} -\n  C \\mathbf{x} \\|_2^2 + \\lambda \\| \\mathbf{z} \\|_{2,1} \\;\\;\n  \\text{such that} \\;\\; \\mathbf{z} = D \\mathbf{x} \\;,$$\n\nwhich is easily written in the form of a standard ADMM problem.\n\nThis is a simpler splitting than that used in the\n[companion example](deconv_tv_padmm.rst), but it requires the use of\nconjugate gradient sub-iterations to solve the ADMM step associated with\nthe data fidelity term.\n\"\"\"\nf = loss.SquaredL2Loss(y=y, A=C)\n# Penalty parameters must be accounted for in the gi functions, not as\n# additional inputs.\nλ = 2.1e-2  # ℓ2,1 norm regularization parameter\ng = λ * functional.L21Norm()\n# The append=0 option makes the results of horizontal and vertical\n# finite differences the same shape, which is required for the L21Norm,\n# which is used so that g(Dx) corresponds to isotropic TV.\nD = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n\n\n\"\"\"\nSet up an ADMM solver object.\n\"\"\"\nρ = 1.0e-1  # ADMM penalty parameter\nmaxiter = 50  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[D],\n    rho_list=[ρ],\n    x0=C.adj(y),\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = y[nc:-nc, nc:-nc]\nplot.imview(y, 
title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(\n    solver.x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, solver.x), fig=fig, ax=ax[2]\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_tv_admm_tune.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nParameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver)\n=============================================================================\n\nThis example demonstrates the use of\n[scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune parameters\nfor the companion [example script](deconv_tv_admm.rst). The `ray.tune`\nfunction API is used in this example.\n\nThis script is hard-coded to run on CPU only to avoid the large number of\nwarnings that are emitted when GPU resources are requested but not available,\nand due to the difficulty of suppressing these warnings in a way that does\nnot force use of the CPU only. To enable GPU usage, comment out the\n`os.environ` statements near the beginning of the script, and change the\nvalue of the \"gpu\" entry in the `resources` dict from 0 to 1. 
Note that\ntwo environment variables are set to suppress the warnings because\n`JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change\nhas yet to be correctly implemented\n(see [google/jax#6805](https://github.com/google/jax/issues/6805) and\n[google/jax#10272](https://github.com/google/jax/pull/10272)).\n\"\"\"\n\n# isort: off\nimport os\n\nos.environ[\"JAX_PLATFORM_NAME\"] = \"cpu\"\nos.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport logging\nimport ray\n\nray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.ray import report, tune\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nphantom = SiemensStar(32)\nN = 256  # image size\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\n\n\n\"\"\"\nSet up the forward operator and create a test signal consisting of a\nblurred signal with additive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nA = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nAx = A(x_gt)  # blurred image\nnoise, key = scico.random.randn(Ax.shape, seed=0)\ny = Ax + σ * noise\n\n\n\"\"\"\nDefine performance evaluation function.\n\"\"\"\n\n\ndef eval_params(config, x_gt, psf, y):\n    \"\"\"Parameter evaluation function. The `config` parameter is a\n    dict of specific parameters for evaluation of a single parameter\n    set (a pair of parameters in this case). 
The remaining parameters\n    are objects that are passed to the evaluation function via the\n    ray object store.\n    \"\"\"\n    # Extract solver parameters from config dict.\n    λ, ρ = config[\"lambda\"], config[\"rho\"]\n    # Set up problem to be solved.\n    A = linop.Convolve(h=psf, input_shape=x_gt.shape)\n    f = loss.SquaredL2Loss(y=y, A=A)\n    g = λ * functional.L21Norm()\n    C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n    # Define solver.\n    solver = ADMM(\n        f=f,\n        g_list=[g],\n        C_list=[C],\n        rho_list=[ρ],\n        x0=A.adj(y),\n        maxiter=10,\n        subproblem_solver=LinearSubproblemSolver(),\n    )\n    # Perform 50 iterations, reporting performance to ray.tune every 10 iterations.\n    for step in range(5):\n        x_admm = solver.solve()\n        report({\"psnr\": float(metric.psnr(x_gt, x_admm))})\n\n\n\"\"\"\nDefine parameter search space and resources per trial.\n\"\"\"\nconfig = {\"lambda\": tune.loguniform(1e-3, 1e-1), \"rho\": tune.loguniform(1e-2, 1e0)}\nresources = {\"cpu\": 4, \"gpu\": 0}  # cpus per trial, gpus per trial\n\n\n\"\"\"\nRun parameter search.\n\"\"\"\ntuner = tune.Tuner(\n    tune.with_parameters(eval_params, x_gt=x_gt, psf=psf, y=y),\n    param_space=config,\n    resources=resources,\n    metric=\"psnr\",\n    mode=\"max\",\n    num_samples=100,  # perform 100 parameter evaluations\n)\nresults = tuner.fit()\nray.shutdown()\n\n\n\"\"\"\nDisplay best parameters and corresponding performance.\n\"\"\"\nbest_result = results.get_best_result()\nbest_config = best_result.config\nprint(f\"Best PSNR: {best_result.metrics['psnr']:.2f} dB\")\nprint(\"Best config: \" + \", \".join([f\"{k}: {v:.2e}\" for k, v in best_config.items()]))\n\n\n\"\"\"\nPlot parameter values visited during parameter search. Marker sizes are\nproportional to number of iterations run at each parameter pair. 
The best\npoint in the parameter space is indicated in red.\n\"\"\"\nfig = plot.figure(figsize=(8, 8))\ntrials = results.get_dataframe()\nfor t in trials.iloc:\n    n = t[\"training_iteration\"]\n    plot.plot(\n        t[\"config/lambda\"],\n        t[\"config/rho\"],\n        ptyp=\"loglog\",\n        lw=0,\n        ms=(0.5 + 1.5 * n),\n        marker=\"o\",\n        mfc=\"blue\",\n        mec=\"blue\",\n        fig=fig,\n    )\nplot.plot(\n    best_config[\"lambda\"],\n    best_config[\"rho\"],\n    ptyp=\"loglog\",\n    title=\"Parameter search sampling locations\\n(marker size proportional to number of iterations)\",\n    xlbl=r\"$\\rho$\",\n    ylbl=r\"$\\lambda$\",\n    lw=0,\n    ms=5.0,\n    marker=\"o\",\n    mfc=\"red\",\n    mec=\"red\",\n    fig=fig,\n)\nax = fig.axes[0]\nax.set_xlim([config[\"rho\"].lower, config[\"rho\"].upper])\nax.set_ylim([config[\"lambda\"].lower, config[\"lambda\"].upper])\nfig.show()\n\n\n\"\"\"\nPlot parameter values visited during parameter search and corresponding\nreconstruction PSNRs. The best point in the parameter space is indicated\nin red.\n\"\"\"\n𝜌 = [t[\"config/rho\"] for t in trials.iloc]\n𝜆 = [t[\"config/lambda\"] for t in trials.iloc]\npsnr = [t[\"psnr\"] for t in trials.iloc]\nminpsnr = min(max(psnr), 18.0)\n𝜌, 𝜆, psnr = zip(*filter(lambda x: x[2] >= minpsnr, zip(𝜌, 𝜆, psnr)))\nfig, ax = plot.subplots(figsize=(10, 8))\nsc = ax.scatter(𝜌, 𝜆, c=psnr, cmap=plot.cm.plasma_r)\nfig.colorbar(sc)\nplot.plot(\n    best_config[\"lambda\"],\n    best_config[\"rho\"],\n    ptyp=\"loglog\",\n    lw=0,\n    ms=12.0,\n    marker=\"2\",\n    mfc=\"red\",\n    mec=\"red\",\n    fig=fig,\n    ax=ax,\n)\nax.set_xscale(\"log\")\nax.set_yscale(\"log\")\nax.set_xlabel(r\"$\\rho$\")\nax.set_ylabel(r\"$\\lambda$\")\nax.set_title(\"PSNR at each sample location\\n(values below 18 dB omitted)\")\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/deconv_tv_padmm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nImage Deconvolution with TV Regularization (Proximal ADMM Solver)\n=================================================================\n\nThis example demonstrates the solution of an image deconvolution problem\nwith isotropic total variation (TV) regularization\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is a convolution operator, $\\mathbf{y}$ is the blurred image,\n$D$ is a 2D finite difference operator, and $\\mathbf{x}$ is the\ndeconvolved image.\n\nIn this example the problem is solved via proximal ADMM, while standard\nADMM is used in a [companion example](deconv_tv_admm.rst).\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize import ProximalADMM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nphantom = SiemensStar(32)\nN = 256  # image size\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\n\n\n\"\"\"\nSet up the forward operator and create a test signal consisting of a\nblurred signal with additive Gaussian noise.\n\"\"\"\nn = 5  # convolution kernel size\nσ = 20.0 / 255  # noise level\n\npsf = snp.ones((n, n)) / (n * n)\nC = linop.Convolve(h=psf, input_shape=x_gt.shape)\n\nCx = C(x_gt)  # blurred image\nnoise, key = scico.random.randn(Cx.shape, seed=0)\ny = Cx + σ * noise\n\n\nr\"\"\"\nSet up the problem to be solved. 
We want to minimize the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - C \\mathbf{x}\n  \\|_2^2 + \\lambda \\| D \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $C$ is the convolution operator and $D$ is a finite difference\noperator. This problem can be expressed as\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; (1/2) \\| \\mathbf{y} -\n  \\mathbf{z}_0 \\|_2^2 + \\lambda \\| \\mathbf{z}_1 \\|_{2,1} \\;\\;\n  \\text{such that} \\;\\; \\mathbf{z}_0 = C \\mathbf{x} \\;\\; \\text{and} \\;\\;\n  \\mathbf{z}_1 = D \\mathbf{x} \\;,$$\n\nwhich can be written in the form of a standard ADMM problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}} \\; f(\\mathbf{x}) + g(\\mathbf{z})\n  \\;\\; \\text{such that} \\;\\; A \\mathbf{x} + B \\mathbf{z} = \\mathbf{c}$$\n\nwith\n\n  $$f = 0 \\qquad g = g_0 + g_1$$\n  $$g_0(\\mathbf{z}_0) = (1/2) \\| \\mathbf{y} - \\mathbf{z}_0 \\|_2^2 \\qquad\n  g_1(\\mathbf{z}_1) = \\lambda \\| \\mathbf{z}_1 \\|_{2,1}$$\n  $$A = \\left( \\begin{array}{c} C \\\\ D \\end{array} \\right) \\qquad\n  B = \\left( \\begin{array}{cc} -I & 0 \\\\ 0 & -I \\end{array} \\right) \\qquad\n  \\mathbf{c} = \\left( \\begin{array}{c} 0 \\\\ 0 \\end{array} \\right) \\;.$$\n\nThis is a more complex splitting than that used in the\n[companion example](deconv_tv_admm.rst), but it allows the use of a\nproximal ADMM solver in a way that avoids the need for the conjugate\ngradient sub-iterations used by the ADMM solver in the\n[companion example](deconv_tv_admm.rst).\n\"\"\"\nf = functional.ZeroFunctional()\ng0 = loss.SquaredL2Loss(y=y)\nλ = 2.0e-2  # ℓ2,1 norm regularization parameter\ng1 = λ * functional.L21Norm()\ng = functional.SeparableFunctional((g0, g1))\n\nD = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\nA = linop.VerticalStack((C, D))\n\n\n\"\"\"\nSet up a proximal ADMM solver object.\n\"\"\"\nρ = 5.0e-2  # ADMM penalty parameter\nmaxiter = 50  # number of ADMM iterations\nmu, nu = 
ProximalADMM.estimate_parameters(A)\n\nsolver = ProximalADMM(\n    f=f,\n    g=g,\n    A=A,\n    B=None,\n    rho=ρ,\n    mu=mu,\n    nu=nu,\n    x0=C.adj(y),\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered image.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0])\nnc = n // 2\nyc = y[nc:-nc, nc:-nc]\nplot.imview(y, title=\"Blurred, noisy image: %.2f (dB)\" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])\nplot.imview(\n    solver.x, title=\"Deconvolved image: %.2f (dB)\" % metric.psnr(x_gt, solver.x), fig=fig, ax=ax[2]\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/demosaic_ppp_bm3d_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with BM3D) Image Demosaicing\n=================================\n\nThis example demonstrates the use of the ADMM Plug and Play Priors (PPP)\nalgorithm :cite:`venkatakrishnan-2013-plugandplay2`, with the BM3D\n:cite:`dabov-2008-image` denoiser, for solving a raw image demosaicing\nproblem.\n\"\"\"\n\nimport numpy as np\n\nfrom bm3d import bm3d_rgb\n\n# Workarounds for colour_demosaicing incompatibility with NumPy 2.x\nnp.float_ = np.float64\nnp.float = np.float64\nnp.complex = np.complex128\nnp.sctypes = {\n    \"float\": [np.float16, np.float32, np.float64, np.longdouble],\n    \"int\": [np.int8, np.int16, np.int32, np.int64],\n}\nfrom colour_demosaicing import demosaicing_CFA_Bayer_Menon2007\n\nimport scico\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.data import kodim23\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nRead a ground truth image.\n\"\"\"\nimg = snp.array(kodim23(asfloat=True)[160:416, 60:316])\n\n\n\"\"\"\nDefine demosaicing forward operator and its transpose.\n\"\"\"\n\n\ndef Afn(x):\n    \"\"\"Map an RGB image to a single channel image with each pixel\n    representing a single colour according to the colour filter array.\n    \"\"\"\n\n    y = snp.zeros(x.shape[0:2])\n    y = y.at[1::2, 1::2].set(x[1::2, 1::2, 0])\n    y = y.at[0::2, 1::2].set(x[0::2, 1::2, 1])\n    y = y.at[1::2, 0::2].set(x[1::2, 0::2, 1])\n    y = y.at[0::2, 0::2].set(x[0::2, 0::2, 2])\n    return y\n\n\ndef ATfn(x):\n    \"\"\"Back project a single channel raw image to an RGB image with zeros\n    at the locations of undefined samples.\n    \"\"\"\n\n    y = snp.zeros(x.shape + (3,))\n    y = y.at[1::2, 1::2, 
0].set(x[1::2, 1::2])\n    y = y.at[0::2, 1::2, 1].set(x[0::2, 1::2])\n    y = y.at[1::2, 0::2, 1].set(x[1::2, 0::2])\n    y = y.at[0::2, 0::2, 2].set(x[0::2, 0::2])\n    return y\n\n\n\"\"\"\nDefine a baseline demosaicing function based on the demosaicing\nalgorithm of :cite:`menon-2007-demosaicing` from package\n[colour_demosaicing](https://github.com/colour-science/colour-demosaicing).\n\"\"\"\n\n\ndef demosaic(cfaimg):\n    \"\"\"Apply baseline demosaicing.\"\"\"\n    return demosaicing_CFA_Bayer_Menon2007(cfaimg, pattern=\"BGGR\").astype(np.float32)\n\n\n\"\"\"\nCreate a test image by color filter array sampling and adding Gaussian\nwhite noise.\n\"\"\"\ns = Afn(img)\nrgbshp = s.shape + (3,)  # shape of reconstructed RGB image\nσ = 2e-2  # noise standard deviation\nnoise, key = scico.random.randn(s.shape, seed=0)\nsn = s + σ * noise\n\n\n\"\"\"\nCompute a baseline demosaicing solution.\n\"\"\"\nimgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32))\n\n\n\"\"\"\nSet up an ADMM solver object. Note the use of the baseline solution\nas an initializer. 
We use BM3D :cite:`dabov-2008-image` as the\ndenoiser, using the [code](https://pypi.org/project/bm3d) released\nwith :cite:`makinen-2019-exact`.\n\"\"\"\nA = linop.LinearOperator(input_shape=rgbshp, output_shape=s.shape, eval_fn=Afn, adj_fn=ATfn)\nf = loss.SquaredL2Loss(y=sn, A=A)\nC = linop.Identity(input_shape=rgbshp)\ng = 1.8e-1 * 6.1e-2 * functional.BM3D(is_rgb=True)\nρ = 1.8e-1  # ADMM penalty parameter\nmaxiter = 12  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=imgb,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 100}),\n    itstat_options={\"display\": True},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow reference and demosaiced images.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))\nplot.imview(img, title=\"Reference\", fig=fig, ax=ax[0])\nplot.imview(imgb, title=\"Baseline demosaic: %.2f (dB)\" % metric.psnr(img, imgb), fig=fig, ax=ax[1])\nplot.imview(x, title=\"PPP demosaic: %.2f (dB)\" % metric.psnr(img, x), fig=fig, ax=ax[2])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_approx_tv_multi.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nDenoising with Approximate Total Variation Proximal Operator\n============================================================\n\nThis example demonstrates use of approximations to the proximal\noperators of isotropic :cite:`kamilov-2016-minimizing` and anisotropic\n:cite:`kamilov-2016-parallel` total variation norms for solving\ndenoising problems using proximal algorithms.\n\"\"\"\n\nimport matplotlib\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize import AcceleratedPGM\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\nx_gt = x_gt / x_gt.max()\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.5  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=0)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDenoise with isotropic total variation, solved via ADMM.\n\"\"\"\nλ_iso = 1.0e0\nf = loss.SquaredL2Loss(y=y)\ng_iso = λ_iso * functional.L21Norm()\nC = linop.FiniteDifference(input_shape=x_gt.shape, circular=True)\n\nsolver = ADMM(\n    f=f,\n    g_list=[g_iso],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=200,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 25}),\n    itstat_options={\"display\": True, \"period\": 25},\n)\nprint(f\"Solving on {device_info()}\\n\")\nx_iso = solver.solve()\nprint()\n\n\n\"\"\"\nDenoise with anisotropic total variation, solved via ADMM.\n\"\"\"\n# Tune the weight to give the same data fidelity as the 
isotropic case.\nλ_aniso = 8.68e-1\ng_aniso = λ_aniso * functional.L1Norm()\n\nsolver = ADMM(\n    f=f,\n    g_list=[g_aniso],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=200,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 25}),\n    itstat_options={\"display\": True, \"period\": 25},\n)\nx_aniso = solver.solve()\nprint()\n\n\n\"\"\"\nDenoise with isotropic total variation, solved using an approximation of\nthe TV norm proximal operator.\n\"\"\"\nh = λ_iso * functional.IsotropicTVNorm(circular=True, input_shape=y.shape)\nsolver = AcceleratedPGM(\n    f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={\"display\": True, \"period\": 50}\n)\nx_iso_aprx = solver.solve()\nprint()\n\n\n\"\"\"\nDenoise with anisotropic total variation, solved using an approximation\nof the TV norm proximal operator.\n\"\"\"\nh = λ_aniso * functional.AnisotropicTVNorm(circular=True, input_shape=y.shape)\nsolver = AcceleratedPGM(\n    f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={\"display\": True, \"period\": 50}\n)\nx_aniso_aprx = solver.solve()\nprint()\n\n\n\"\"\"\nCompute and print the data fidelity.\n\"\"\"\nfor x, name in zip(\n    (x_iso, x_aniso, x_iso_aprx, x_aniso_aprx),\n    (\"Isotropic\", \"Anisotropic\", \"Approx. Isotropic\", \"Approx. Anisotropic\"),\n):\n    df = f(x)\n    print(f\"Data fidelity for {name} TV: {' ' * (20 - len(name))} {df:.2e}\")\n\n\n\"\"\"\nPlot results.\n\"\"\"\nmatplotlib.rc(\"font\", size=9)\nplt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))\nfig, ax = plot.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(15, 8))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(\n    y, title=f\"Noisy version SNR: {metric.snr(x_gt, y):.2f} dB\", fig=fig, ax=ax[1, 0], **plt_args\n)\nplot.imview(\n    x_iso,\n    title=f\"Iso. 
TV denoising SNR: {metric.snr(x_gt, x_iso):.2f} dB\",\n    fig=fig,\n    ax=ax[0, 1],\n    **plt_args,\n)\nplot.imview(\n    x_aniso,\n    title=f\"Aniso. TV denoising SNR: {metric.snr(x_gt, x_aniso):.2f} dB\",\n    fig=fig,\n    ax=ax[1, 1],\n    **plt_args,\n)\nplot.imview(\n    x_iso_aprx,\n    title=f\"Approx. Iso. TV denoising SNR: {metric.snr(x_gt, x_iso_aprx):.2f} dB\",\n    fig=fig,\n    ax=ax[0, 2],\n    **plt_args,\n)\nplot.imview(\n    x_aniso_aprx,\n    title=f\"Approx. Aniso. TV denoising SNR: {metric.snr(x_gt, x_aniso_aprx):.2f} dB\",\n    fig=fig,\n    ax=ax[1, 2],\n    **plt_args,\n)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison\")\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_cplx_tv_nlpadmm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nComplex Total Variation Denoising with NLPADMM Solver\n=====================================================\n\nThis example demonstrates solution of a problem of the form\n\n$$\\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{z}) \\; \\text{such that}\\;\nH(\\mb{x}, \\mb{z}) = 0 \\;,$$\n\nwhere $H$ is a nonlinear function, via a variant of the proximal ADMM\nalgorithm for problems with a non-linear operator constraint\n:cite:`benning-2016-preconditioned`. The example problem represents\ntotal variation (TV) denoising applied to a complex image with\npiece-wise smooth magnitude and non-smooth phase. (This example is rather\ncontrived, and was not constructed to represent a specific real imaging\nproblem, but it does have some properties in common with synthetic\naperture radar single look complex data in which the magnitude has much\nmore discernible structure than the phase.) The appropriate TV denoising\nformulation for this problem is\n\n$$\\argmin_{\\mb{x}} \\; (1/2) \\| \\mb{y} - \\mb{x} \\|_2^2 + \\lambda\n\\| C(\\mb{x}) \\|_{2,1} \\;,$$\n\nwhere $\\mb{y}$ is the measurement, $\\|\\cdot\\|_{2,1}$ is the\n$\\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator consisting of\na linear difference operator applied to the magnitude of a complex array.\nThis problem is represented in the form above by taking $H(\\mb{x},\n\\mb{z}) = C(\\mb{x}) - \\mb{z}$. 
The standard TV solution, which is\nalso computed for comparison purposes, gives very poor results since\nthe difference is applied independently to real and imaginary\ncomponents of the complex image.\n\"\"\"\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import function, functional, linop, loss, metric, operator, plot\nfrom scico.examples import phase_diff\nfrom scico.optimize import NonLinearPADMM, ProximalADMM\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0\nx_mag /= x_mag.max()\n# Create reference image with structured magnitude and random phase\nx_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.25  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDenoise with standard total variation.\n\"\"\"\nλ_tv = 6e-2\nf = loss.SquaredL2Loss(y=y)\ng = λ_tv * functional.L21Norm()\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.complex64, append=0)\n\nsolver_tv = ProximalADMM(\n    f=f,\n    g=g,\n    A=C,\n    rho=1.0,\n    mu=8.0,\n    nu=1.0,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 20},\n)\nprint(f\"Solving on {device_info()}\\n\")\nx_tv = solver_tv.solve()\nprint()\nhist_tv = solver_tv.itstat_object.history(transpose=True)\n\n\n\"\"\"\nDenoise with total variation applied to the magnitude of a complex image.\n\"\"\"\nλ_nltv = 2e-1\ng = λ_nltv * functional.L21Norm()\n# Redefine C for real input (now applied to magnitude of a complex array)\nC = 
linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.float32, append=0)\n# Operator computing differences of absolute values\nD = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64)\n# Constraint function imposing z = D(x) constraint\nH = function.Function(\n    (C.shape[1], C.shape[0]),\n    output_shape=C.shape[0],\n    eval_fn=lambda x, z: D(x) - z,\n    input_dtypes=(snp.complex64, snp.float32),\n    output_dtype=snp.float32,\n)\n\nsolver_nltv = NonLinearPADMM(\n    f=f,\n    g=g,\n    H=H,\n    rho=5.0,\n    mu=6.0,\n    nu=1.0,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 20},\n)\nx_nltv = solver_nltv.solve()\nhist_nltv = solver_nltv.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot results.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array((hist_tv.Objective, hist_nltv.Objective)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard TV\", \"Magnitude TV\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard TV\", \"Magnitude TV\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard TV\", \"Magnitude TV\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\nfig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10))\nnorm = plot.matplotlib.colors.Normalize(\n    vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()),\n    vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()),\n)\nplot.imview(snp.abs(x_gt), title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0, 0], norm=norm)\nplot.imview(\n    snp.abs(y),\n 
   title=\"Measured: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(y)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 1],\n    norm=norm,\n)\nplot.imview(\n    snp.abs(x_tv),\n    title=\"Standard TV: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 2],\n    norm=norm,\n)\nplot.imview(\n    snp.abs(x_nltv),\n    title=\"Magnitude TV: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 3],\n    norm=norm,\n)\ndivider = make_axes_locatable(ax[0, 3])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[0, 3].get_images()[0], cax=cax)\nnorm = plot.matplotlib.colors.Normalize(\n    vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()),\n    vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()),\n)\nplot.imview(\n    snp.angle(x_gt),\n    title=\"Ground truth\",\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 0],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(y),\n    title=\"Measured: Mean phase diff. %.2f\" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 1],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(x_tv),\n    title=\"Standard TV: Mean phase diff. %.2f\"\n    % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 2],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(x_nltv),\n    title=\"Magnitude TV: Mean phase diff. %.2f\"\n    % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 3],\n    norm=norm,\n)\ndivider = make_axes_locatable(ax[1, 3])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[1, 3].get_images()[0], cax=cax)\nax[0, 0].set_ylabel(\"Magnitude\")\nax[1, 0].set_ylabel(\"Phase\")\nfig.tight_layout()\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_cplx_tv_pdhg.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nComplex Total Variation Denoising with PDHG Solver\n==================================================\n\nThis example demonstrates solution of a problem of the form\n\n  $$\\argmin_{\\mathbf{x}} \\; f(\\mathbf{x}) + g(C(\\mathbf{x})) \\;,$$\n\nwhere $C$ is a nonlinear operator, via non-linear PDHG\n:cite:`valkonen-2014-primal`. The example problem represents total\nvariation (TV) denoising applied to a complex image with piece-wise\nsmooth magnitude and non-smooth phase. The appropriate TV denoising\nformulation for this problem is\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - \\mathbf{x}\n  \\|_2^2 + \\lambda \\| C(\\mathbf{x}) \\|_{2,1} \\;,$$\n\nwhere $\\mathbf{y}$ is the measurement, $\\|\\cdot\\|_{2,1}$ is the\n$\\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator that applies a\nlinear difference operator to the magnitude of a complex array. 
The\nstandard TV solution, which is also computed for comparison purposes,\ngives very poor results since the difference is applied independently to\nreal and imaginary components of the complex image.\n\"\"\"\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, operator, plot\nfrom scico.examples import phase_diff\nfrom scico.optimize import PDHG\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0\nx_mag /= x_mag.max()\n# Create reference image with structured magnitude and random phase\nx_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.25  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDenoise with standard total variation.\n\"\"\"\nλ_tv = 6e-2\nf = loss.SquaredL2Loss(y=y)\ng = λ_tv * functional.L21Norm()\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, input_dtype=snp.complex64, append=0)\nsolver_tv = PDHG(\n    f=f,\n    g=g,\n    C=C,\n    tau=4e-1,\n    sigma=4e-1,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(f\"Solving on {device_info()}\\n\")\nx_tv = solver_tv.solve()\nhist_tv = solver_tv.itstat_object.history(transpose=True)\n\n\n\"\"\"\nDenoise with total variation applied to the magnitude of a complex image.\n\"\"\"\nλ_nltv = 2e-1\ng = λ_nltv * functional.L21Norm()\n# Redefine C for real input (now applied to magnitude of a complex array)\nC = linop.FiniteDifference(input_shape=x_gt.shape, 
input_dtype=snp.float32, append=0)\n# Operator computing differences of absolute values\nD = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64)\nsolver_nltv = PDHG(\n    f=f,\n    g=g,\n    C=D,\n    tau=4e-1,\n    sigma=4e-1,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nx_nltv = solver_nltv.solve()\nhist_nltv = solver_nltv.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot results.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array((hist_tv.Objective, hist_nltv.Objective)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    lgnd=(\"PDHG\", \"NL-PDHG\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"PDHG\", \"NL-PDHG\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"PDHG\", \"NL-PDHG\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\nfig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10))\nnorm = plot.matplotlib.colors.Normalize(\n    vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()),\n    vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()),\n)\nplot.imview(snp.abs(x_gt), title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0, 0], norm=norm)\nplot.imview(\n    snp.abs(y),\n    title=\"Measured: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(y)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 1],\n    norm=norm,\n)\nplot.imview(\n    snp.abs(x_tv),\n    title=\"TV: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 2],\n    norm=norm,\n)\nplot.imview(\n    
snp.abs(x_nltv),\n    title=\"NL-TV: PSNR %.2f (dB)\" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)),\n    cbar=None,\n    fig=fig,\n    ax=ax[0, 3],\n    norm=norm,\n)\ndivider = make_axes_locatable(ax[0, 3])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[0, 3].get_images()[0], cax=cax)\nnorm = plot.matplotlib.colors.Normalize(\n    vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()),\n    vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()),\n)\nplot.imview(\n    snp.angle(x_gt),\n    title=\"Ground truth\",\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 0],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(y),\n    title=\"Measured: Mean phase diff. %.2f\" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 1],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(x_tv),\n    title=\"TV: Mean phase diff. %.2f\" % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 2],\n    norm=norm,\n)\nplot.imview(\n    snp.angle(x_nltv),\n    title=\"NL-TV: Mean phase diff. %.2f\" % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(),\n    cbar=None,\n    fig=fig,\n    ax=ax[1, 3],\n    norm=norm,\n)\ndivider = make_axes_locatable(ax[1, 3])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[1, 3].get_images()[0], cax=cax)\nax[0, 0].set_ylabel(\"Magnitude\")\nax[1, 0].set_ylabel(\"Phase\")\nfig.tight_layout()\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_datagen_bsds.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nNoisy Data Generation for NN Training\n=====================================\n\nThis example demonstrates how to generate noisy image data for\ntraining neural network models for denoising. The original images are\npart of the\n[BSDS500 dataset](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/)\nprovided by the Berkeley Segmentation Dataset and Benchmark project.\n\"\"\"\n\nimport numpy as np\n\nfrom scico import plot\nfrom scico.flax.examples import load_image_data\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nsize = 40  # patch size\ntrain_nimg = 400  # number of training images\ntest_nimg = 64  # number of testing images\nnimg = train_nimg + test_nimg\ngray = True  # use gray scale images\ndata_mode = \"dn\"  # Denoising problem\nnoise_level = 0.1  # Standard deviation of noise\nnoise_range = False  # Use fixed noise level\nstride = 23  # Stride to sample multiple patches from each image\n\ntrain_ds, test_ds = load_image_data(\n    train_nimg,\n    test_nimg,\n    size,\n    gray,\n    data_mode,\n    verbose=True,\n    noise_level=noise_level,\n    noise_range=noise_range,\n    stride=stride,\n)\n\n\n\"\"\"\nPlot randomly selected sample. 
Note that the patches are small, so the\nplots may show unidentifiable image fragments.\n\"\"\"\nindx_tr = np.random.randint(0, train_nimg)\nindx_te = np.random.randint(0, test_nimg)\nfig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7))\nplot.imview(\n    train_ds[\"label\"][indx_tr, ..., 0],\n    title=\"Ground truth - Training Sample\",\n    fig=fig,\n    ax=axes[0, 0],\n)\nplot.imview(\n    train_ds[\"image\"][indx_tr, ..., 0],\n    title=\"Noisy Image - Training Sample\",\n    fig=fig,\n    ax=axes[0, 1],\n)\nplot.imview(\n    test_ds[\"label\"][indx_te, ..., 0],\n    title=\"Ground truth - Testing Sample\",\n    fig=fig,\n    ax=axes[1, 0],\n)\nplot.imview(\n    test_ds[\"image\"][indx_te, ..., 0], title=\"Noisy Image - Testing Sample\", fig=fig, ax=axes[1, 1]\n)\nfig.suptitle(r\"Training and Testing samples\")\nfig.tight_layout()\nfig.colorbar(\n    axes[0, 1].get_images()[0],\n    ax=axes,\n    shrink=0.5,\n    pad=0.05,\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_dncnn_train_bsds.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTraining of DnCNN for Denoising\n===============================\n\nThis example demonstrates the training and application of the DnCNN model\nfrom :cite:`zhang-2017-dncnn` to denoise images that have been corrupted\nwith additive Gaussian noise.\n\"\"\"\n\n# isort: off\nimport os\nfrom time import time\n\nimport numpy as np\n\n# Set an arbitrary processor count (only applies if GPU is not available).\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\n\nfrom scico import flax as sflax\nfrom scico import metric, plot\nfrom scico.flax.examples import load_image_data\n\nplatform = get_backend().platform\nprint(\"Platform: \", platform)\n\n\n\"\"\"\nRead data from cache or generate if not available.\n\"\"\"\nsize = 40  # patch size\ntrain_nimg = 400  # number of training images\ntest_nimg = 16  # number of testing images\nnimg = train_nimg + test_nimg\ngray = True  # use gray scale images\ndata_mode = \"dn\"  # Denoising problem\nnoise_level = 0.1  # Standard deviation of noise\nnoise_range = False  # Use fixed noise level\nstride = 23  # Stride to sample multiple patches from each image\n\ntrain_ds, test_ds = load_image_data(\n    train_nimg,\n    test_nimg,\n    size,\n    gray,\n    data_mode,\n    verbose=True,\n    noise_level=noise_level,\n    noise_range=noise_range,\n    stride=stride,\n)\n\n\n\"\"\"\nDefine configuration dictionary for model and training loop.\n\nParameters have been selected for demonstration purposes and relatively\nshort training. 
The depth of the model has been reduced to 6, instead of\nthe 17 of the original model. The suggested settings can be found in the\noriginal paper.\n\"\"\"\n# model configuration\nmodel_conf = {\n    \"depth\": 6,\n    \"num_filters\": 64,\n}\n# training configuration\ntrain_conf: sflax.ConfigDict = {\n    \"seed\": 0,\n    \"opt_type\": \"ADAM\",\n    \"batch_size\": 128,\n    \"num_epochs\": 50,\n    \"base_learning_rate\": 1e-3,\n    \"warmup_epochs\": 0,\n    \"log_every_steps\": 5000,\n    \"log\": True,\n    \"checkpointing\": True,\n}\n\n\n\"\"\"\nConstruct DnCNN model.\n\"\"\"\nchannels = train_ds[\"image\"].shape[-1]\nmodel = sflax.DnCNNNet(\n    depth=model_conf[\"depth\"],\n    channels=channels,\n    num_filters=model_conf[\"num_filters\"],\n)\n\n\n\"\"\"\nRun training loop.\n\"\"\"\nworkdir = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"dncnn_out\")\ntrain_conf[\"workdir\"] = workdir\nprint(f\"\\nJAX local devices: {jax.local_devices()}\\n\")\n\ntrainer = sflax.BasicFlaxTrainer(\n    train_conf,\n    model,\n    train_ds,\n    test_ds,\n)\nmodvar, stats_object = trainer.train()\n\n\n\"\"\"\nEvaluate on testing data.\n\"\"\"\ntest_patches = 720\nstart_time = time()\nfmap = sflax.FlaxMap(model, modvar)\noutput = fmap(test_ds[\"image\"][:test_patches])\ntime_eval = time() - start_time\noutput = np.clip(output, a_min=0, a_max=1.0)\n\n\n\"\"\"\nEvaluate trained model in terms of reconstruction time and data fidelity.\n\"\"\"\nsnr_eval = metric.snr(test_ds[\"label\"][:test_patches], output)\npsnr_eval = metric.psnr(test_ds[\"label\"][:test_patches], output)\nprint(\n    f\"{'DnCNNNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}\"\n    f\"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}\"\n)\nprint(\n    f\"{'DnCNNNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}\"\n    f\"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}\"\n)\n\n\n\"\"\"\nPlot comparison. 
Note that plots may display unidentifiable image\nfragments due to the small patch size.\n\"\"\"\nnp.random.seed(123)\nindx = np.random.randint(0, high=test_patches)\n\nfig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))\nplot.imview(test_ds[\"label\"][indx, ..., 0], title=\"Ground truth\", cbar=None, fig=fig, ax=ax[0])\nplot.imview(\n    test_ds[\"image\"][indx, ..., 0],\n    title=\"Noisy: \\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], test_ds[\"image\"][indx, ..., 0]),\n    ),\n    cbar=None,\n    fig=fig,\n    ax=ax[1],\n)\nplot.imview(\n    output[indx, ..., 0],\n    title=\"DnCNNNet Reconstruction\\nSNR: %.2f (dB), PSNR: %.2f\"\n    % (\n        metric.snr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n        metric.psnr(test_ds[\"label\"][indx, ..., 0], output[indx, ..., 0]),\n    ),\n    fig=fig,\n    ax=ax[2],\n)\ndivider = make_axes_locatable(ax[2])\ncax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\nfig.colorbar(ax[2].get_images()[0], cax=cax, label=\"arbitrary units\")\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics. Statistics are generated only if a training\ncycle was done (i.e. 
if not reading final epoch results from checkpoint).\n\"\"\"\nif stats_object is not None and len(stats_object.iterations) > 0:\n    hist = stats_object.history(transpose=True)\n    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n    plot.plot(\n        np.array((hist.Train_Loss, hist.Eval_Loss)).T,\n        x=hist.Epoch,\n        ptyp=\"semilogy\",\n        title=\"Loss function\",\n        xlbl=\"Epoch\",\n        ylbl=\"Loss value\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[0],\n    )\n    plot.plot(\n        np.array((hist.Train_SNR, hist.Eval_SNR)).T,\n        x=hist.Epoch,\n        title=\"Metric\",\n        xlbl=\"Epoch\",\n        ylbl=\"SNR (dB)\",\n        lgnd=(\"Train\", \"Test\"),\n        fig=fig,\n        ax=ax[1],\n    )\n    fig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_dncnn_universal.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nComparison of DnCNN Variants for Image Denoising\n================================================\n\nThis example demonstrates the solution of an image denoising problem\nusing DnCNN :cite:`zhang-2017-dncnn` networks trained for different noise\nlevels, as well as custom variants with fewer network layers, and  with a\nnoise level input.\n\nThe networks trained for specific noise levels are labeled 6L, 6M, 6H,\n17L, 17M, and 17H, where {6, 17} denote the number of layers, and {L, M,\nH} represent noise standard deviation of the training images (0.06, 0.10,\nand 0.20 respectively). The networks with a noise standard deviation\ninput are labeled 6N and 17N, where {6, 17} again denote the number of\nlayers.\n\"\"\"\n\nimport numpy as np\n\nfrom xdesign import Foam, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import metric, plot\nfrom scico.denoiser import DnCNN\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nnp.random.seed(1234)\nN = 512  # image size\nx_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)\nx_gt = snp.array(x_gt)  # convert to jax array\n\n\n\"\"\"\nTest different DnCNN variants on images with different noise levels.\n\"\"\"\nprint(\"  σ   | variant | noisy image PSNR (dB)   | denoised image PSNR (dB)\")\nfor σ in [0.06, 0.10, 0.20]:\n    print(\"------+---------+-------------------------+-------------------------\")\n    for variant in [\"17L\", \"17M\", \"17H\", \"17N\", \"6L\", \"6M\", \"6H\", \"6N\"]:\n        # Instantiate a DnCNN.\n        denoiser = DnCNN(variant=variant)\n\n        # Generate a noisy image.\n        noise, key = scico.random.randn(x_gt.shape, seed=0)\n        y = x_gt + σ * noise\n\n        if variant in [\"6N\", \"17N\"]:\n     
       x_hat = denoiser(y, sigma=σ)\n        else:\n            x_hat = denoiser(y)\n\n        x_hat = np.clip(x_hat, a_min=0, a_max=1.0)\n\n        if variant[0] == \"6\":\n            variant += \" \"  # add spaces to maintain alignment\n\n        print(\n            \" %.2f | %s     |          %.2f          |          %.2f          \"\n            % (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))\n        )\n\n\n\"\"\"\nShow reference and denoised images for σ=0.2 and variant=6N.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))\nplot.imview(x_gt, title=\"Reference\", fig=fig, ax=ax[0])\nplot.imview(y, title=\"Noisy image: %.2f (dB)\" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])\nplot.imview(x_hat, title=\"Denoised image: %.2f (dB)\" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2])\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_l1tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nℓ1 Total Variation Denoising\n============================\n\nThis example demonstrates impulse noise removal via ℓ1 total variation\n:cite:`alliney-1992-digital` :cite:`esser-2010-primal` (Sec. 2.4.4)\n(i.e. total variation regularization with an ℓ1 data fidelity term),\nminimizing the functional\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\;  \\| \\mathbf{y} - \\mathbf{x}\n  \\|_1 + \\lambda \\| C \\mathbf{x} \\|_{2,1} \\;,$$\n\nwhere $\\mathbf{y}$ is the noisy image, $C$ is a 2D finite difference\noperator, and $\\mathbf{x}$ is the denoised image.\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.examples import spnoise\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\nfrom scipy.ndimage import median_filter\n\n\"\"\"\nCreate a ground truth image and impose salt & pepper noise to create a\nnoisy test image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\nx_gt = 0.5 * x_gt / x_gt.max()\ny = spnoise(x_gt, 0.5)\n\n\n\"\"\"\nDenoise with median filtering.\n\"\"\"\nx_med = median_filter(y, size=(5, 5))\n\n\n\"\"\"\nDenoise with ℓ1 total variation.\n\"\"\"\nλ = 1.5e0\ng_loss = loss.Loss(y=y, f=functional.L1Norm())\ng_tv = λ * functional.L21Norm()\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\n\nsolver = ADMM(\n    f=None,\n    g_list=[g_loss, g_tv],\n    C_list=[linop.Identity(input_shape=y.shape), C],\n    rho_list=[5e0, 5e0],\n    x0=y,\n    maxiter=100,\n    
subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 20}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\nprint(f\"Solving on {device_info()}\\n\")\nx_tv = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot results.\n\"\"\"\nplt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.0))\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(13, 12))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy image\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(\n    x_med,\n    title=f\"Median filtering: {metric.psnr(x_gt, x_med):.2f} (dB)\",\n    fig=fig,\n    ax=ax[1, 0],\n    **plt_args,\n)\nplot.imview(\n    x_tv,\n    title=f\"ℓ1-TV denoising: {metric.psnr(x_gt, x_tv):.2f} (dB)\",\n    fig=fig,\n    ax=ax[1, 1],\n    **plt_args,\n)\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_ptv_pdhg.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nPolar Total Variation Denoising (PDHG)\n======================================\n\nThis example compares denoising via standard isotropic total\nvariation (TV) regularization :cite:`rudin-1992-nonlinear`\n:cite:`goldstein-2009-split` and a variant based on local polar\ncoordinates, as described in :cite:`hossein-2024-total`. It solves the\ndenoising problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - \\mathbf{x}\n  \\|_2^2 + \\lambda R(\\mathbf{x}) \\;,$$\n\nwhere $R$ is either the isotropic or polar TV regularizer, via the\nprimal–dual hybrid gradient (PDHG) algorithm.\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, metric, plot\nfrom scico.optimize import PDHG\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\nx_gt = x_gt / x_gt.max()\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.75  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=0)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDenoise with standard isotropic total variation.\n\"\"\"\nλ_std = 0.8e0\nf = loss.SquaredL2Loss(y=y)\ng_std = λ_std * functional.L21Norm()\n\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\ntau, sigma = PDHG.estimate_parameters(C, ratio=20.0)\nsolver = PDHG(\n    f=f,\n    g=g_std,\n    C=C,\n    tau=tau,\n    sigma=sigma,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 
10},\n)\nprint(f\"Solving on {device_info()}\\n\")\nsolver.solve()\nhist_std = solver.itstat_object.history(transpose=True)\nx_std = solver.x\nprint()\n\n\n\"\"\"\nDenoise with polar total variation for comparison.\n\"\"\"\n# Tune the weight to give the same data fidelity as the isotropic case.\nλ_plr = 1.2e0\ng_plr = λ_plr * functional.L1Norm()\n\nG = linop.PolarGradient(input_shape=x_gt.shape)\nD = linop.Diagonal(snp.array([0.3, 1.0]).reshape((2, 1, 1)), input_shape=G.shape[0])\nC = D @ G\n\ntau, sigma = PDHG.estimate_parameters(C, ratio=20.0)\nsolver = PDHG(\n    f=f,\n    g=g_plr,\n    C=C,\n    tau=tau,\n    sigma=sigma,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nsolver.solve()\nhist_plr = solver.itstat_object.history(transpose=True)\nx_plr = solver.x\nprint()\n\n\n\"\"\"\nCompute and print the data fidelity.\n\"\"\"\nfor x, name in zip((x_std, x_plr), (\"Isotropic\", \"Polar\")):\n    df = f(x)\n    print(f\"Data fidelity for {(name + ' TV'):12}: {df:.2e}   SNR: {metric.snr(x_gt, x):5.2f} dB\")\n\n\n\"\"\"\nPlot results.\n\"\"\"\nplt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_std, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_plr, title=\"Polar TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison\")\nfig.show()\n\n# zoomed version\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, 
ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_std, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_plr, title=\"Polar TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nax[0, 0].set_xlim(N // 4, N // 4 + N // 2)\nax[0, 0].set_ylim(N // 4, N // 4 + N // 2)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison (zoomed)\")\nfig.show()\n\n\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(20, 5))\nplot.plot(\n    snp.array((hist_std.Objective, hist_plr.Objective)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard\", \"Polar\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist_std.Prml_Rsdl, hist_plr.Prml_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard\", \"Polar\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array((hist_std.Dual_Rsdl, hist_plr.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Standard\", \"Polar\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTotal Variation Denoising (ADMM)\n================================\n\nThis example compares denoising via isotropic and anisotropic total\nvariation (TV) regularization :cite:`rudin-1992-nonlinear`\n:cite:`goldstein-2009-split`. It solves the denoising problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - \\mathbf{x}\n  \\|_2^2 + \\lambda R(\\mathbf{x}) \\;,$$\n\nwhere $R$ is either the isotropic or anisotropic TV regularizer.\nIn SCICO, switching between these two regularizers involves a one-line\nchange: replacing an\n[L1Norm](../_autosummary/scico.functional.rst#scico.functional.L1Norm)\nwith a\n[L21Norm](../_autosummary/scico.functional.rst#scico.functional.L21Norm).\nNote that the isotropic version exhibits fewer block-like artifacts on\nedges that are not vertical or horizontal.\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, plot\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\nx_gt = x_gt / x_gt.max()\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.75  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=0)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDenoise with isotropic total variation.\n\"\"\"\nλ_iso = 1.4e0\nf = loss.SquaredL2Loss(y=y)\ng_iso = λ_iso * functional.L21Norm()\n\n# The append=0 option makes the results of horizontal and vertical finite\n# differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, 
append=0)\nsolver = ADMM(\n    f=f,\n    g_list=[g_iso],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=100,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 20}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\nprint(f\"Solving on {device_info()}\\n\")\nsolver.solve()\nx_iso = solver.x\nprint()\n\n\"\"\"\nDenoise with anisotropic total variation for comparison.\n\"\"\"\n# Tune the weight to give the same data fidelity as the isotropic case.\nλ_aniso = 1.2e0\ng_aniso = λ_aniso * functional.L1Norm()\n\nsolver = ADMM(\n    f=f,\n    g_list=[g_aniso],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=100,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 20}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\nsolver.solve()\nx_aniso = solver.x\nprint()\n\n\n\"\"\"\nCompute and print the data fidelity.\n\"\"\"\nfor x, name in zip((x_iso, x_aniso), (\"Isotropic\", \"Anisotropic\")):\n    df = f(x)\n    print(f\"Data fidelity for {name} TV was {df:.2e}\")\n\n\n\"\"\"\nPlot results.\n\"\"\"\nplt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_iso, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_aniso, title=\"Anisotropic TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison\")\nfig.show()\n\n# zoomed version\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 
10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_iso, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_aniso, title=\"Anisotropic TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nax[0, 0].set_xlim(N // 4, N // 4 + N // 2)\nax[0, 0].set_ylim(N // 4, N // 4 + N // 2)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison (zoomed)\")\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_tv_apgm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTotal Variation Denoising with Constraint (APGM)\n================================================\n\nThis example demonstrates the solution of the total variation\n(TV) denoising problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - \\mathbf{x}\n  \\|_2^2 + \\lambda R(\\mathbf{x}) + \\iota_C(\\mathbf{x}) \\;,$$\n\nwhere $R$ is a TV regularizer, $\\iota_C(\\cdot)$ is the indicator function\nof constraint set $C$, and $C = \\{ \\mathbf{x} \\, | \\, x_i \\in [0, 1] \\}$,\ni.e. the set of vectors with components constrained to be in the interval\n$[0, 1]$. The problem is solved separately with $R$ taken as isotropic\nand anisotropic TV regularization.\n\nThe solution via APGM is based on the approach in :cite:`beck-2009-tv`,\nwhich involves constructing a dual for the constrained denoising problem.\nThe APGM solution minimizes the resulting dual. 
In this case, switching\nbetween the two regularizers corresponds to switching between two\ndifferent projectors.\n\"\"\"\n\nfrom typing import Callable, Optional, Union\n\nimport jax.numpy as jnp\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, operator, plot\nfrom scico.numpy import Array, BlockArray\nfrom scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nN = 256  # image size\nphantom = SiemensStar(16)\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\nx_gt = x_gt / x_gt.max()\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 0.75  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=0)\ny = x_gt + σ * noise\n\n\n\"\"\"\nDefine finite difference operator and adjoint.\n\"\"\"\n# The append=0 option appends 0 to the input along the axis\n# prior to performing the difference to make the results of\n# horizontal and vertical finite differences the same shape.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\nA = C.adj\n\n\n\"\"\"\nDefine a zero array as initial estimate.\n\"\"\"\nx0 = jnp.zeros(C(y).shape)\n\n\n\"\"\"\nDefine the dual of the total variation denoising problem.\n\"\"\"\n\n\nclass DualTVLoss(loss.Loss):\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        lmbda: float = 0.5,\n    ):\n        self.functional = functional.SquaredL2Norm()\n        super().__init__(y=y, A=A, scale=1.0)\n        self.lmbda = lmbda\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        xint = self.y - self.lmbda * self.A(x)\n        return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint)\n\n\n\"\"\"\nDenoise with isotropic total variation. 
Define projector for isotropic\ntotal variation.\n\"\"\"\n\n\n# Evaluation of functional set to zero.\nclass IsoProjector(functional.Functional):\n    has_eval = True\n    has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return 0.0\n\n    def prox(self, v: Array, lam: float, **kwargs) -> Array:\n        norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))\n\n        x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)\n        out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1]))\n        x_out = x_out.at[0, :, -1].set(out1)\n        out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :]))\n        x_out = x_out.at[1, -1, :].set(out2)\n\n        return x_out\n\n\n\"\"\"\nSet up `AcceleratedPGM` solver object using `RobustLineSearchStepSize`\nstep size policy. Run the solver.\n\"\"\"\nreg_weight_iso = 1.4e0\nf_iso = DualTVLoss(y=y, A=A, lmbda=reg_weight_iso)\ng_iso = IsoProjector()\n\nsolver_iso = AcceleratedPGM(\n    f=f_iso,\n    g=g_iso,\n    L0=16.0 * f_iso.lmbda**2,\n    x0=x0,\n    maxiter=100,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=RobustLineSearchStepSize(),\n)\n\n# Run the solver.\nprint(f\"Solving on {device_info()}\\n\")\nx = solver_iso.solve()\nhist_iso = solver_iso.itstat_object.history(transpose=True)\n# Project to constraint set.\nx_iso = jnp.clip(y - f_iso.lmbda * f_iso.A(x), 0.0, 1.0)\n\n\n\"\"\"\nDenoise with anisotropic total variation for comparison. 
Define\nprojector for anisotropic total variation.\n\"\"\"\n\n\n# Evaluation of functional set to zero.\nclass AnisoProjector(functional.Functional):\n    has_eval = True\n    has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return 0.0\n\n    def prox(self, v: Array, lam: float, **kwargs) -> Array:\n        return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v))\n\n\n\"\"\"\nSet up `AcceleratedPGM` solver object using `RobustLineSearchStepSize`\nstep size policy. (Weight was tuned to give the same data fidelity as the\nisotropic case.) Run the solver.\n\"\"\"\n\nreg_weight_aniso = 1.2e0\nf = DualTVLoss(y=y, A=A, lmbda=reg_weight_aniso)\ng = AnisoProjector()\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=16.0 * f.lmbda**2,\n    x0=x0,\n    maxiter=100,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=RobustLineSearchStepSize(),\n)\n\n# Run the solver.\nprint()\nx = solver.solve()\n# Project to constraint set.\nx_aniso = jnp.clip(y - f.lmbda * f.A(x), 0.0, 1.0)\n\n\n\"\"\"\nCompute the data fidelity.\n\"\"\"\ndf = hist_iso.Objective[-1]\nprint(f\"\\nData fidelity for isotropic TV was {df:.2e}\")\nhist = solver.itstat_object.history(transpose=True)\ndf = hist.Objective[-1]\nprint(f\"Data fidelity for anisotropic TV was {df:.2e}\")\n\n\n\"\"\"\nPlot results.\n\"\"\"\nplt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_iso, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_aniso, title=\"Anisotropic TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, 
location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison\")\nfig.show()\n\n# zoomed version\nfig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))\nplot.imview(x_gt, title=\"Ground truth\", fig=fig, ax=ax[0, 0], **plt_args)\nplot.imview(y, title=\"Noisy version\", fig=fig, ax=ax[0, 1], **plt_args)\nplot.imview(x_iso, title=\"Isotropic TV denoising\", fig=fig, ax=ax[1, 0], **plt_args)\nplot.imview(x_aniso, title=\"Anisotropic TV denoising\", fig=fig, ax=ax[1, 1], **plt_args)\nax[0, 0].set_xlim(N // 4, N // 4 + N // 2)\nax[0, 0].set_ylim(N // 4, N // 4 + N // 2)\nfig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)\nfig.colorbar(\n    ax[0, 0].get_images()[0], ax=ax, location=\"right\", shrink=0.9, pad=0.05, label=\"Arbitrary Units\"\n)\nfig.suptitle(\"Denoising comparison (zoomed)\")\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/denoise_tv_multi.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nComparison of Optimization Algorithms for Total Variation Denoising\n===================================================================\n\nThis example compares the performance of alternating direction method of\nmultipliers (ADMM), linearized ADMM, proximal ADMM, and primal–dual\nhybrid gradient (PDHG) in solving the isotropic total variation (TV)\ndenoising problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - \\mathbf{x}\n  \\|_2^2 + \\lambda R(\\mathbf{x}) \\;,$$\n\nwhere $R$ is the isotropic TV: the sum of the norms of the gradient\nvectors at each point in the image $\\mathbf{x}$.\n\"\"\"\n\nfrom xdesign import SiemensStar, discrete_phantom\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, linop, loss, plot\nfrom scico.optimize import PDHG, LinearizedADMM, ProximalADMM\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate a ground truth image.\n\"\"\"\nphantom = SiemensStar(32)\nN = 256  # image size\nx_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)\n\n\n\"\"\"\nAdd noise to create a noisy test image.\n\"\"\"\nσ = 1.0  # noise standard deviation\nnoise, key = scico.random.randn(x_gt.shape, seed=0)\ny = x_gt + σ * noise\n\n\n\"\"\"\nConstruct operators and functionals and set regularization parameter.\n\"\"\"\n# The append=0 option makes the results of horizontal and vertical\n# finite differences the same shape, which is required for the L21Norm.\nC = linop.FiniteDifference(input_shape=x_gt.shape, append=0)\nf = loss.SquaredL2Loss(y=y)\nλ = 1e0\ng = λ * functional.L21Norm()\n\n\n\"\"\"\nThe first step of the first-run solver is much slower than the\nfollowing steps, presumably due to just-in-time compilation 
of\nrelevant operators in first use. The code below performs a preliminary\nsolver step, the result of which is discarded, to reduce this bias in\nthe timing results. The precise cause of the remaining differences in\ntime required to compute the first step of each algorithm is unknown,\nbut it is worth noting that this difference becomes negligible when\njust-in-time compilation is disabled (e.g. via the `JAX_DISABLE_JIT`\nenvironment variable).\n\"\"\"\nsolver_admm = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=1,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"maxiter\": 1}),\n)\nsolver_admm.solve();  # fmt: skip\n# trailing semi-colon suppresses output in notebook\n\n\n\"\"\"\nSolve via ADMM with a maximum of 2 CG iterations.\n\"\"\"\nsolver_admm = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[1e1],\n    x0=y,\n    maxiter=200,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"maxiter\": 2}),\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(f\"Solving on {device_info()}\\n\")\nprint(\"ADMM solver\")\nsolver_admm.solve()\nhist_admm = solver_admm.itstat_object.history(transpose=True)\n\n\n\"\"\"\nSolve via Linearized ADMM.\n\"\"\"\nsolver_ladmm = LinearizedADMM(\n    f=f,\n    g=g,\n    C=C,\n    mu=1e-2,\n    nu=1e-1,\n    x0=y,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(\"\\nLinearized ADMM solver\")\nsolver_ladmm.solve()\nhist_ladmm = solver_ladmm.itstat_object.history(transpose=True)\n\n\n\"\"\"\nSolve via Proximal ADMM.\n\"\"\"\nmu, nu = ProximalADMM.estimate_parameters(C)\nsolver_padmm = ProximalADMM(\n    f=f,\n    g=g,\n    A=C,\n    rho=1e0,\n    mu=mu,\n    nu=nu,\n    x0=y,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(\"\\nProximal ADMM solver\")\nsolver_padmm.solve()\nhist_padmm = solver_padmm.itstat_object.history(transpose=True)\n\n\n\"\"\"\nSolve via 
PDHG.\n\"\"\"\ntau, sigma = PDHG.estimate_parameters(C, factor=1.5)\nsolver_pdhg = PDHG(\n    f=f,\n    g=g,\n    C=C,\n    tau=tau,\n    sigma=sigma,\n    maxiter=200,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nprint(\"\\nPDHG solver\")\nsolver_pdhg.solve()\nhist_pdhg = solver_pdhg.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot results. It is worth noting that:\n\n1. PDHG outperforms ADMM both with respect to iterations and time.\n2. Proximal ADMM has similar performance to PDHG with respect to iterations,\n   but is slightly inferior with respect to time.\n3. ADMM greatly outperforms Linearized ADMM with respect to iterations.\n4. ADMM slightly outperforms Linearized ADMM with respect to time. This is\n   possible because the ADMM $\\mathbf{x}$-update can be solved relatively\n   cheaply, with only 2 CG iterations. If more CG iterations were required,\n   the time comparison would be favorable to Linearized ADMM.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array(\n        (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective)\n    ).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array(\n        (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)\n    ).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array(\n        (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)\n    ).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Iteration\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\nfig, ax = 
plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))\nplot.plot(\n    snp.array(\n        (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective)\n    ).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Objective function\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array(\n        (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)\n    ).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Primal residual\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[1],\n)\nplot.plot(\n    snp.array(\n        (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)\n    ).T,\n    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,\n    ptyp=\"semilogy\",\n    title=\"Dual residual\",\n    xlbl=\"Time (s)\",\n    lgnd=(\"ADMM\", \"LinADMM\", \"ProxADMM\", \"PDHG\"),\n    fig=fig,\n    ax=ax[2],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/diffusercam_tv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nTV-Regularized 3D DiffuserCam Reconstruction\n============================================\n\nThis example demonstrates reconstruction of a 3D DiffuserCam\n:cite:`antipa-2018-diffusercam`\n[dataset](https://github.com/Waller-Lab/DiffuserCam/tree/master/example_data).\nThe inverse problem can be written as\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; \\frac{1}{2} \\Big\\| \\mathbf{y} -\n  M \\Big( \\sum_k \\mathbf{h}_k \\ast \\mathbf{x}_k \\Big) \\Big\\|_2^2 +\n  \\lambda_0 \\sum_k \\| D \\mathbf{x}_k \\|_{2,1} +\n  \\lambda_1 \\sum_k \\| \\mathbf{x}_k \\|_1  \\;,$$\n\nwhere the $\\mathbf{h}_k$ are the components of the PSF stack, the\n$\\mathbf{x}_k$ are the corresponding components of the reconstructed\nvolume, $\\mathbf{y}$ is the measured image, $D$ is a 2D finite\ndifference operator, and $M$ is a cropping operator that allows the\nboundary artifacts resulting from circular convolution to be avoided. 
Following the mask decoupling approach\n:cite:`almeida-2013-deconvolving`, the problem is posed in ADMM form\nas\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}_0, \\mathbf{z}_1,\n  \\mathbf{z}_2} \\; \\frac{1}{2} \\| \\mathbf{y} - M \\mathbf{z}_0 \\|_2^2 +\n  \\lambda_0 \\sum_k \\| \\mathbf{z}_{1,k} \\|_{2,1} +\n  \\lambda_1 \\sum_k \\| \\mathbf{z}_{2,k}\n  \\|_1  \\\\ \\;\\; \\text{s.t.} \\;\\; \\mathbf{z}_0 = \\sum_k \\mathbf{h}_k \\ast\n  \\mathbf{x}_k \\qquad \\mathbf{z}_{1,k} = D \\mathbf{x}_k \\qquad\n  \\mathbf{z}_{2,k} = \\mathbf{x}_k \\;.$$\n\nThe most computationally expensive step in the ADMM algorithm is solved\nusing the frequency-domain approach proposed in\n:cite:`wohlberg-2014-efficient`.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import plot\nfrom scico.examples import ucb_diffusercam_data\nfrom scico.functional import L1Norm, L21Norm, ZeroFunctional\nfrom scico.linop import CircularConvolve, Crop, FiniteDifference, Identity, Sum\nfrom scico.loss import SquaredL2Loss\nfrom scico.optimize.admm import ADMM, G0BlockCircularConvolveSolver\nfrom scico.util import device_info\n\n\"\"\"\nLoad the DiffuserCam PSF stack and measured image. The computational cost\nof the reconstruction is reduced slightly by removing parts of the PSF\nstack that don't make a significant contribution to the reconstruction.\n\"\"\"\ny, psf = ucb_diffusercam_data()\npsf = psf[..., 1:-7]\n\n\n\"\"\"\nTo avoid boundary artifacts, the measured image is padded by half the PSF\nwidth/height and then cropped within the data fidelity term. This padding\nis implicit in that the reconstruction volume is computed at the padded\nsize, but the actual measured image is never explicitly padded since it is\nused at the original (unpadded) size within the data fidelity term due to\nthe cropping operation. 
The PSF axis order is modified to put the stack\naxis at index 0, as required by components of the ADMM solver to be used.\nFinally, each PSF in the stack is individually normalized.\n\"\"\"\nhalf_psf = np.array(psf.shape[0:2]) // 2\npad_spec = ((half_psf[0],) * 2, (half_psf[1],) * 2)\ny_pad_shape = tuple(np.array(y.shape) + np.array(pad_spec).sum(axis=1))\nx_shape = (psf.shape[-1],) + y_pad_shape\npsf = psf.transpose((2, 0, 1))\npsf /= np.sqrt(np.sum(psf**2, axis=(1, 2), keepdims=True))\n\n\n\"\"\"\nConvert the image and PSF stack to JAX arrays with `float32` dtype since\nJAX by default does not support double-precision floating point\narithmetic. This limited precision leads to relatively poor, but still\nacceptable accuracy within the ADMM solver x-step. To experiment with the\neffect of higher numerical precision, set the environment variable\n`JAX_ENABLE_X64=True` and change `dtype` below to `np.float64`.\n\"\"\"\ndtype = np.float32\ny = snp.array(y.astype(dtype))\npsf = snp.array(psf.astype(dtype))\n\n\n\"\"\"\nDefine problem and algorithm parameters.\n\"\"\"\nλ0 = 3e-3  # TV regularization parameter\nλ1 = 1e-2  # ℓ1 norm regularization parameter\nρ0 = 1e0  # ADMM penalty parameter for first auxiliary variable\nρ1 = 5e0  # ADMM penalty parameter for second auxiliary variable\nρ2 = 1e1  # ADMM penalty parameter for third auxiliary variable\nmaxiter = 100  # number of ADMM iterations\n\n\n\"\"\"\nCreate operators.\n\"\"\"\nC = CircularConvolve(psf, input_shape=x_shape, input_dtype=dtype, h_center=half_psf, ndims=2)\nS = Sum(input_shape=x_shape, input_dtype=dtype, axis=0)\nM = Crop(pad_spec, input_shape=y_pad_shape, input_dtype=dtype)\n\n\n\"\"\"\nCreate functionals.\n\"\"\"\ng0 = SquaredL2Loss(y=y, A=M)\ng1 = λ0 * L21Norm()\ng2 = λ1 * L1Norm()\nC0 = S @ C\nC1 = FiniteDifference(input_shape=x_shape, input_dtype=dtype, axes=(-2, -1), circular=True)\nC2 = Identity(input_shape=x_shape, input_dtype=dtype)\n\n\n\"\"\"\nSet up ADMM solver object and solve 
problem.\n\"\"\"\nsolver = ADMM(\n    f=ZeroFunctional(),\n    g_list=[g0, g1, g2],\n    C_list=[C0, C1, C2],\n    rho_list=[ρ0, ρ1, ρ2],\n    alpha=1.4,\n    maxiter=maxiter,\n    nanstop=True,\n    subproblem_solver=G0BlockCircularConvolveSolver(ndims=2, check_solve=True),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the measured image and samples from the PSF stack.\n\"\"\"\nplot.imview(y, cmap=plot.plt.cm.Blues, cbar=True, title=\"Measured Image\")\n\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7))\nplot.imview(psf[0], title=\"Nearest PSF\", cmap=plot.plt.cm.Blues, fig=fig, ax=ax[0])\nplot.imview(psf[-1], title=\"Furthest PSF\", cmap=plot.plt.cm.Blues, fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nShow the recovered volume with depth indicated by color.\n\"\"\"\nXCrop = Crop(((0, 0),) + pad_spec, input_shape=x_shape, input_dtype=dtype)\nxm = np.array(XCrop(x[..., ::-1]))\nxmr = xm.transpose((1, 2, 0))[..., np.newaxis] / xm.max()\ncmap = plot.plt.cm.viridis_r\ncmval = cmap(np.arange(0, xm.shape[0]).reshape(1, 1, -1) / (xm.shape[0] - 1))\nxms = np.sum(cmval * xmr, axis=2)[..., 0:3]\n\nplot.imview(xms, cmap=cmap, cbar=True, title=\"Recovered Volume\")\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/index.rst",
    "content": "Usage Examples\n==============\n\n\nOrganized by Application\n------------------------\n\n\nComputed Tomography\n^^^^^^^^^^^^^^^^^^^\n\n   - ct_abel_tv_admm.py\n   - ct_abel_tv_admm_tune.py\n   - ct_symcone_tv_padmm.py\n   - ct_astra_noreg_pcg.py\n   - ct_astra_3d_tv_admm.py\n   - ct_astra_3d_tv_padmm.py\n   - ct_tv_admm.py\n   - ct_astra_tv_admm.py\n   - ct_multi_tv_admm.py\n   - ct_astra_weighted_tv_admm.py\n   - ct_svmbir_tv_multi.py\n   - ct_svmbir_ppp_bm3d_admm_cg.py\n   - ct_svmbir_ppp_bm3d_admm_prox.py\n   - ct_fan_svmbir_ppp_bm3d_admm_prox.py\n   - ct_modl_train_foam2.py\n   - ct_odp_train_foam2.py\n   - ct_unet_train_foam2.py\n   - ct_projector_comparison_2d.py\n   - ct_projector_comparison_3d.py\n\nDeconvolution\n^^^^^^^^^^^^^\n\n   - deconv_circ_tv_admm.py\n   - deconv_tv_admm.py\n   - deconv_tv_padmm.py\n   - deconv_tv_admm_tune.py\n   - deconv_microscopy_tv_admm.py\n   - deconv_microscopy_allchn_tv_admm.py\n   - deconv_ppp_bm3d_admm.py\n   - deconv_ppp_bm3d_apgm.py\n   - deconv_ppp_dncnn_admm.py\n   - deconv_ppp_dncnn_padmm.py\n   - deconv_ppp_bm4d_admm.py\n   - deconv_modl_train_foam1.py\n   - deconv_odp_train_foam1.py\n\n\nSparse Coding\n^^^^^^^^^^^^^\n\n   - sparsecode_nn_admm.py\n   - sparsecode_nn_apgm.py\n   - sparsecode_conv_admm.py\n   - sparsecode_conv_md_admm.py\n   - sparsecode_apgm.py\n   - sparsecode_poisson_apgm.py\n\n\nMiscellaneous\n^^^^^^^^^^^^^\n\n   - demosaic_ppp_bm3d_admm.py\n   - superres_ppp_dncnn_admm.py\n   - denoise_l1tv_admm.py\n   - denoise_ptv_pdhg.py\n   - denoise_tv_admm.py\n   - denoise_tv_apgm.py\n   - denoise_tv_multi.py\n   - denoise_approx_tv_multi.py\n   - denoise_cplx_tv_nlpadmm.py\n   - denoise_cplx_tv_pdhg.py\n   - denoise_dncnn_universal.py\n   - diffusercam_tv_admm.py\n   - video_rpca_admm.py\n   - ct_datagen_foam2.py\n   - deconv_datagen_bsds.py\n   - deconv_datagen_foam1.py\n   - denoise_datagen_bsds.py\n\n\nOrganized by Regularization\n---------------------------\n\nPlug and Play 
Priors\n^^^^^^^^^^^^^^^^^^^^\n\n   - ct_svmbir_ppp_bm3d_admm_cg.py\n   - ct_svmbir_ppp_bm3d_admm_prox.py\n   - ct_fan_svmbir_ppp_bm3d_admm_prox.py\n   - deconv_ppp_bm3d_admm.py\n   - deconv_ppp_bm3d_apgm.py\n   - deconv_ppp_dncnn_admm.py\n   - deconv_ppp_dncnn_padmm.py\n   - deconv_ppp_bm4d_admm.py\n   - demosaic_ppp_bm3d_admm.py\n   - superres_ppp_dncnn_admm.py\n\n\nTotal Variation\n^^^^^^^^^^^^^^^\n\n   - ct_abel_tv_admm.py\n   - ct_abel_tv_admm_tune.py\n   - ct_symcone_tv_padmm.py\n   - ct_tv_admm.py\n   - ct_multi_tv_admm.py\n   - ct_astra_tv_admm.py\n   - ct_astra_3d_tv_admm.py\n   - ct_astra_3d_tv_padmm.py\n   - ct_astra_weighted_tv_admm.py\n   - ct_svmbir_tv_multi.py\n   - deconv_circ_tv_admm.py\n   - deconv_tv_admm.py\n   - deconv_tv_admm_tune.py\n   - deconv_tv_padmm.py\n   - deconv_microscopy_tv_admm.py\n   - deconv_microscopy_allchn_tv_admm.py\n   - denoise_l1tv_admm.py\n   - denoise_ptv_pdhg.py\n   - denoise_tv_admm.py\n   - denoise_tv_apgm.py\n   - denoise_tv_multi.py\n   - denoise_approx_tv_multi.py\n   - denoise_cplx_tv_nlpadmm.py\n   - denoise_cplx_tv_pdhg.py\n   - diffusercam_tv_admm.py\n\n\n\nSparsity\n^^^^^^^^\n\n   - diffusercam_tv_admm.py\n   - sparsecode_nn_admm.py\n   - sparsecode_nn_apgm.py\n   - sparsecode_conv_admm.py\n   - sparsecode_conv_md_admm.py\n   - sparsecode_apgm.py\n   - sparsecode_poisson_apgm.py\n   - video_rpca_admm.py\n\n\nMachine Learning\n^^^^^^^^^^^^^^^^\n\n   - ct_datagen_foam2.py\n   - ct_modl_train_foam2.py\n   - ct_odp_train_foam2.py\n   - ct_unet_train_foam2.py\n   - deconv_datagen_bsds.py\n   - deconv_datagen_foam1.py\n   - deconv_modl_train_foam1.py\n   - deconv_odp_train_foam1.py\n   - denoise_datagen_bsds.py\n   - denoise_dncnn_train_bsds.py\n   - denoise_dncnn_universal.py\n\n\nOrganized by Optimization Algorithm\n-----------------------------------\n\nADMM\n^^^^\n\n   - ct_abel_tv_admm.py\n   - ct_abel_tv_admm_tune.py\n   - ct_symcone_tv_padmm.py\n   - ct_astra_tv_admm.py\n   - ct_tv_admm.py\n   - 
ct_astra_3d_tv_admm.py\n   - ct_astra_weighted_tv_admm.py\n   - ct_multi_tv_admm.py\n   - ct_svmbir_tv_multi.py\n   - ct_svmbir_ppp_bm3d_admm_cg.py\n   - ct_svmbir_ppp_bm3d_admm_prox.py\n   - ct_fan_svmbir_ppp_bm3d_admm_prox.py\n   - deconv_circ_tv_admm.py\n   - deconv_tv_admm.py\n   - deconv_tv_admm_tune.py\n   - deconv_microscopy_tv_admm.py\n   - deconv_microscopy_allchn_tv_admm.py\n   - deconv_ppp_bm3d_admm.py\n   - deconv_ppp_dncnn_admm.py\n   - deconv_ppp_bm4d_admm.py\n   - diffusercam_tv_admm.py\n   - sparsecode_nn_admm.py\n   - sparsecode_conv_admm.py\n   - sparsecode_conv_md_admm.py\n   - demosaic_ppp_bm3d_admm.py\n   - superres_ppp_dncnn_admm.py\n   - denoise_l1tv_admm.py\n   - denoise_tv_admm.py\n   - denoise_tv_multi.py\n   - denoise_approx_tv_multi.py\n   - video_rpca_admm.py\n\n\nLinearized ADMM\n^^^^^^^^^^^^^^^\n\n    - ct_svmbir_tv_multi.py\n    - denoise_tv_multi.py\n\n\nProximal ADMM\n^^^^^^^^^^^^^\n\n    - ct_astra_3d_tv_padmm.py\n    - deconv_tv_padmm.py\n    - denoise_tv_multi.py\n    - deconv_ppp_dncnn_padmm.py\n\n\nNon-linear Proximal ADMM\n^^^^^^^^^^^^^^^^^^^^^^^^\n\n    - denoise_cplx_tv_nlpadmm.py\n\n\nPDHG\n^^^^\n\n    - ct_svmbir_tv_multi.py\n    - denoise_ptv_pdhg.py\n    - denoise_tv_multi.py\n    - denoise_cplx_tv_pdhg.py\n\n\nPGM\n^^^\n\n   - deconv_ppp_bm3d_apgm.py\n   - sparsecode_apgm.py\n   - sparsecode_nn_apgm.py\n   - sparsecode_poisson_apgm.py\n   - denoise_tv_apgm.py\n   - denoise_approx_tv_multi.py\n\n\nPCG\n^^^\n\n   - ct_astra_noreg_pcg.py\n"
  },
  {
    "path": "examples/scripts/sparsecode_apgm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nBasis Pursuit DeNoising (APGM)\n==============================\n\nThis example demonstrates the solution of the sparse coding problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - D \\mathbf{x}\n  \\|_2^2 + \\lambda \\| \\mathbf{x} \\|_1\\;,$$\n\nwhere $D$ is the dictionary, $\\mathbf{y}$ is the signal to be represented,\nand $\\mathbf{x}$ is the sparse representation.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot\nfrom scico.optimize.pgm import AcceleratedPGM\nfrom scico.util import device_info\n\n\"\"\"\nConstruct a random dictionary, a reference random sparse\nrepresentation, and a test signal consisting of the synthesis of the\nreference sparse representation.\n\"\"\"\nm = 512  # Signal size\nn = 4 * m  # Dictionary size\ns = 32  # Sparsity level (number of non-zeros)\nσ = 0.5  # Noise level\n\nnp.random.seed(12345)\nD = np.random.randn(m, n).astype(np.float32)\nL0 = np.linalg.norm(D, 2) ** 2\n\nx_gt = np.zeros(n, dtype=np.float32)  # true signal\nidx = np.random.permutation(list(range(0, n - 1)))\nx_gt[idx[0:s]] = np.random.randn(s)\ny = D @ x_gt + σ * np.random.randn(m)  # synthetic signal\n\nx_gt = snp.array(x_gt)  # convert to jax array\ny = snp.array(y)  # convert to jax array\n\n\n\"\"\"\nSet up the forward operator and `AcceleratedPGM` solver object.\n\"\"\"\nmaxiter = 100\nλ = 2.98e1\nA = linop.MatrixOperator(D)\nf = loss.SquaredL2Loss(y=y, A=A)\ng = λ * functional.L1Norm()\nsolver = AcceleratedPGM(\n    f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={\"display\": True, \"period\": 10}\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = 
solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot the recovered coefficients and convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    np.vstack((x_gt, x)).T,\n    title=\"Coefficients\",\n    lgnd=(\"Ground Truth\", \"Recovered\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    np.array((hist.Objective, hist.Residual)).T,\n    ptyp=\"semilogy\",\n    title=\"Convergence\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Objective\", \"Residual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/sparsecode_conv_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nConvolutional Sparse Coding (ADMM)\n==================================\n\nThis example demonstrates the solution of a simple convolutional sparse\ncoding problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; \\frac{1}{2} \\Big\\| \\mathbf{y} -\n  \\sum_k \\mathbf{h}_k \\ast \\mathbf{x}_k \\Big\\|_2^2 + \\lambda \\sum_k\n  ( \\| \\mathbf{x}_k \\|_1 - \\| \\mathbf{x}_k \\|_2 ) \\;,$$\n\nwhere the $\\mathbf{h}_k$ are a set of filters comprising the dictionary,\nthe $\\mathbf{x}_k$ are a corresponding set of coefficient maps, and\n$\\mathbf{y}$ is the signal to be represented. The problem is solved via\nan ADMM algorithm using the frequency-domain approach proposed in\n:cite:`wohlberg-2014-efficient`.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import plot\nfrom scico.examples import create_conv_sparse_phantom\nfrom scico.functional import L1MinusL2Norm\nfrom scico.linop import CircularConvolve, Identity, Sum\nfrom scico.loss import SquaredL2Loss\nfrom scico.optimize.admm import ADMM, FBlockCircularConvolveSolver\nfrom scico.util import device_info\n\n\"\"\"\nSet problem size and create random convolutional dictionary (a set of\nfilters) and a corresponding sparse random set of coefficient maps.\n\"\"\"\nN = 128  # image size\nNnz = 128  # number of non-zeros in coefficient maps\nh, x0 = create_conv_sparse_phantom(N, Nnz)\n\n\n\"\"\"\nNormalize dictionary filters and scale coefficient maps accordingly.\n\"\"\"\nhnorm = np.sqrt(np.sum(h**2, axis=(1, 2), keepdims=True))\nh /= hnorm\nx0 *= hnorm\n\n\n\"\"\"\nConvert numpy arrays to jax arrays.\n\"\"\"\nh = snp.array(h)\nx0 = snp.array(x0)\n\n\n\"\"\"\nSet up sum-of-convolutions forward operator.\n\"\"\"\nC = CircularConvolve(h, input_shape=x0.shape, ndims=2)\nS = 
Sum(input_shape=C.output_shape, axis=0)\nA = S @ C\n\n\n\"\"\"\nConstruct test image from dictionary $\\mathbf{h}$ and coefficient maps\n$\\mathbf{x}_0$.\n\"\"\"\ny = A(x0)\n\n\n\"\"\"\nSet functional and solver parameters.\n\"\"\"\nλ = 1e0  # ℓ1-ℓ2 norm regularization parameter\nρ = 2e0  # ADMM penalty parameter\nmaxiter = 200  # number of ADMM iterations\n\n\n\"\"\"\nDefine loss function and regularization. Note the use of the\n$\\ell_1 - \\ell_2$ norm, which has been found to provide slightly better\nperformance than the $\\ell_1$ norm in this type of problem\n:cite:`wohlberg-2021-psf`.\n\"\"\"\nf = SquaredL2Loss(y=y, A=A)\ng0 = λ * L1MinusL2Norm()\nC0 = Identity(input_shape=x0.shape)\n\n\n\"\"\"\nInitialize ADMM solver.\n\"\"\"\nsolver = ADMM(\n    f=f,\n    g_list=[g0],\n    C_list=[C0],\n    rho_list=[ρ],\n    alpha=1.8,\n    maxiter=maxiter,\n    subproblem_solver=FBlockCircularConvolveSolver(check_solve=True),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx1 = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered coefficient maps.\n\"\"\"\nfig, ax = plot.subplots(nrows=2, ncols=3, figsize=(12, 8.6))\nplot.imview(x0[0], title=\"Coef. map 0\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0])\nax[0, 0].set_ylabel(\"Ground truth\")\nplot.imview(x0[1], title=\"Coef. map 1\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])\nplot.imview(x0[2], title=\"Coef. 
map 2\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 2])\nplot.imview(x1[0], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0])\nax[1, 0].set_ylabel(\"Recovered\")\nplot.imview(x1[1], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1])\nplot.imview(x1[2], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 2])\nfig.tight_layout()\nfig.show()\n\n\n\"\"\"\nShow test image and reconstruction from recovered coefficient maps.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))\nplot.imview(y, title=\"Test image\", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[0])\nplot.imview(A(x1), title=\"Reconstructed image\", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/sparsecode_conv_md_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nConvolutional Sparse Coding with Mask Decoupling (ADMM)\n=======================================================\n\nThis example demonstrates the solution of a convolutional sparse coding\nproblem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; \\frac{1}{2} \\Big\\| \\mathbf{y} -\n  B \\Big( \\sum_k \\mathbf{h}_k \\ast \\mathbf{x}_k \\Big) \\Big\\|_2^2 +\n  \\lambda \\sum_k ( \\| \\mathbf{x}_k \\|_1 - \\| \\mathbf{x}_k \\|_2 ) \\;,$$\n\nwhere the $\\mathbf{h}_k$ are a set of filters comprising the dictionary,\nthe $\\mathbf{x}_k$ are a corresponding set of coefficient maps,\n$\\mathbf{y}$ is the signal to be represented, and $B$ is a cropping\noperator that allows the boundary artifacts resulting from circular\nconvolution to be avoided. Following the mask decoupling approach\n:cite:`almeida-2013-deconvolving`, the problem is posed in ADMM form\nas\n\n  $$\\mathrm{argmin}_{\\mathbf{x}, \\mathbf{z}_0, \\mathbf{z}_1} \\; (1/2) \\|\n  \\mathbf{y} - B \\mathbf{z}_0 \\|_2^2 + \\lambda \\sum_k ( \\| \\mathbf{z}_{1,k}\n  \\|_1 - \\| \\mathbf{z}_{1,k} \\|_2 ) \\\\ \\;\\; \\text{s.t.} \\;\\;\n  \\mathbf{z}_0 = \\sum_k \\mathbf{h}_k \\ast \\mathbf{x}_k \\;\\;\n  \\mathbf{z}_{1,k} = \\mathbf{x}_k \\;.$$\n\nThe most computationally expensive step in the ADMM algorithm is solved\nusing the frequency-domain approach proposed in\n:cite:`wohlberg-2014-efficient`.\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import plot\nfrom scico.examples import create_conv_sparse_phantom\nfrom scico.functional import L1MinusL2Norm, ZeroFunctional\nfrom scico.linop import CircularConvolve, Crop, Identity, Sum\nfrom scico.loss import SquaredL2Loss\nfrom scico.optimize.admm import ADMM, G0BlockCircularConvolveSolver\nfrom scico.util import 
device_info\n\n\"\"\"\nSet problem size and create random convolutional dictionary (a set of\nfilters) and a corresponding sparse random set of coefficient maps.\n\"\"\"\nN = 121  # image size\nNnz = 128  # number of non-zeros in coefficient maps\nh, x0 = create_conv_sparse_phantom(N, Nnz)\n\n\n\"\"\"\nNormalize dictionary filters and scale coefficient maps accordingly.\n\"\"\"\nhnorm = np.sqrt(np.sum(h**2, axis=(1, 2), keepdims=True))\nh /= hnorm\nx0 *= hnorm\n\n\n\"\"\"\nConvert numpy arrays to jax arrays.\n\"\"\"\nh = snp.array(h)\nx0 = snp.array(x0)\n\n\n\"\"\"\nSet up required padding and corresponding crop operator.\n\"\"\"\nh_center = (h.shape[1] // 2, h.shape[2] // 2)\npad_width = ((0, 0), (h_center[0], h_center[0]), (h_center[1], h_center[1]))\nx0p = snp.pad(x0, pad_width=pad_width)\nB = Crop(pad_width[1:], input_shape=x0p.shape[1:])\n\n\n\"\"\"\nSet up sum-of-convolutions forward operator.\n\"\"\"\nC = CircularConvolve(h, input_shape=x0p.shape, ndims=2, h_center=h_center)\nS = Sum(input_shape=C.output_shape, axis=0)\nA = S @ C\n\n\n\"\"\"\nConstruct test image from dictionary $\\mathbf{h}$ and padded version of\ncoefficient maps $\\mathbf{x}_0$.\n\"\"\"\ny = B(A(x0p))\n\n\n\"\"\"\nSet functional and solver parameters.\n\"\"\"\nλ = 1e0  # ℓ1-ℓ2 norm regularization parameter\nρ0 = 1e0  # ADMM penalty parameters\nρ1 = 3e0\nmaxiter = 200  # number of ADMM iterations\n\n\n\"\"\"\nDefine loss function and regularization. 
Note the use of the\n$\\ell_1 - \\ell_2$ norm, which has been found to provide slightly better\nperformance than the $\\ell_1$ norm in this type of problem\n:cite:`wohlberg-2021-psf`.\n\"\"\"\nf = ZeroFunctional()\ng0 = SquaredL2Loss(y=y, A=B)\ng1 = λ * L1MinusL2Norm()\nC0 = A\nC1 = Identity(input_shape=x0p.shape)\n\n\n\"\"\"\nInitialize ADMM solver.\n\"\"\"\nsolver = ADMM(\n    f=f,\n    g_list=[g0, g1],\n    C_list=[C0, C1],\n    rho_list=[ρ0, ρ1],\n    alpha=1.8,\n    maxiter=maxiter,\n    subproblem_solver=G0BlockCircularConvolveSolver(check_solve=True),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx1 = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nShow the recovered coefficient maps.\n\"\"\"\nfig, ax = plot.subplots(nrows=2, ncols=3, figsize=(12, 8.6))\nplot.imview(x0[0], title=\"Coef. map 0\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0])\nax[0, 0].set_ylabel(\"Ground truth\")\nplot.imview(x0[1], title=\"Coef. map 1\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])\nplot.imview(x0[2], title=\"Coef. map 2\", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 2])\nplot.imview(x1[0], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0])\nax[1, 0].set_ylabel(\"Recovered\")\nplot.imview(x1[1], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1])\nplot.imview(x1[2], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 2])\nfig.tight_layout()\nfig.show()\n\n\n\"\"\"\nShow test image and reconstruction from recovered coefficient maps. 
Note\nthe absence of the wrap-around effects at the boundary that can be seen\nin the corresponding images in the [related example](sparsecode_conv_admm.rst).\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))\nplot.imview(y, title=\"Test image\", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[0])\nplot.imview(B(A(x1)), title=\"Reconstructed image\", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[1])\nfig.show()\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/sparsecode_nn_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nNon-Negative Basis Pursuit DeNoising (ADMM)\n===========================================\n\nThis example demonstrates the solution of a non-negative sparse coding\nproblem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - D \\mathbf{x} \\|_2^2\n  + \\lambda \\| \\mathbf{x} \\|_1 + \\iota_{\\mathrm{NN}}(\\mathbf{x}) \\;,$$\n\nwhere $D$ is the dictionary, $\\mathbf{y}$ is the signal to be represented,\n$\\mathbf{x}$ is the sparse representation, and $\\iota_{\\mathrm{NN}}$ is\nthe indicator function of the non-negativity constraint.\n\nIn this example the problem is solved via ADMM, while Accelerated PGM is\nused in a [companion example](sparsecode_nn_apgm.rst).\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot\nfrom scico.optimize.admm import ADMM, MatrixSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nCreate random dictionary, reference random sparse representation, and\ntest signal consisting of the synthesis of the reference sparse\nrepresentation.\n\"\"\"\nm = 32  # signal size\nn = 128  # dictionary size\ns = 10  # sparsity level\n\nnp.random.seed(1)\nD = np.random.randn(m, n).astype(np.float32)\nD = D / np.linalg.norm(D, axis=0, keepdims=True)  # normalize dictionary\n\nxt = np.zeros(n, dtype=np.float32)  # true signal\nidx = np.random.randint(low=0, high=n, size=s)  # support of xt\nxt[idx] = np.random.rand(s)\ny = D @ xt + 5e-2 * np.random.randn(m)  # synthetic signal\n\nxt = snp.array(xt)  # convert to jax array\ny = snp.array(y)  # convert to jax array\n\n\n\"\"\"\nSet up the forward operator and ADMM solver object.\n\"\"\"\nlmbda = 1e-1\nA = linop.MatrixOperator(D)\nf = loss.SquaredL2Loss(y=y, A=A)\ng_list = [lmbda * functional.L1Norm(), 
functional.NonNegativeIndicator()]\nC_list = [linop.Identity((n)), linop.Identity((n))]\nrho_list = [1.0, 1.0]\nmaxiter = 100  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=g_list,\n    C_list=C_list,\n    rho_list=rho_list,\n    x0=A.adj(y),\n    maxiter=maxiter,\n    subproblem_solver=MatrixSubproblemSolver(),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\n\n\n\"\"\"\nPlot the recovered coefficients and signal.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    np.vstack((xt, solver.x)).T,\n    title=\"Coefficients\",\n    lgnd=(\"Ground Truth\", \"Recovered\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    np.vstack((D @ xt, y, D @ solver.x)).T,\n    title=\"Signal\",\n    lgnd=(\"Ground Truth\", \"Noisy\", \"Recovered\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/sparsecode_nn_apgm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nNon-Negative Basis Pursuit DeNoising (APGM)\n===========================================\n\nThis example demonstrates the solution of a non-negative sparse coding\nproblem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - D \\mathbf{x} \\|_2^2\n  + \\lambda \\| \\mathbf{x} \\|_1 + \\iota_{\\mathrm{NN}}(\\mathbf{x}) \\;,$$\n\nwhere $D$ is the dictionary, $\\mathbf{y}$ is the signal to be represented,\n$\\mathbf{x}$ is the sparse representation, and $\\iota_{\\mathrm{NN}}$ is\nthe indicator function of the non-negativity constraint.\n\nIn this example the problem is solved via Accelerated PGM, using the\nproximal averaging method :cite:`yu-2013-better` to approximate the\nproximal operator of the sum of the $\\ell_1$ norm and an indicator\nfunction, while ADMM is used in a\n[companion example](sparsecode_nn_admm.rst).\n\"\"\"\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot\nfrom scico.optimize.pgm import AcceleratedPGM\nfrom scico.util import device_info\n\n\"\"\"\nCreate random dictionary, reference random sparse representation, and\ntest signal consisting of the synthesis of the reference sparse\nrepresentation.\n\"\"\"\nm = 32  # signal size\nn = 128  # dictionary size\ns = 10  # sparsity level\n\nnp.random.seed(1)\nD = np.random.randn(m, n).astype(np.float32)\nD = D / np.linalg.norm(D, axis=0, keepdims=True)  # normalize dictionary\nL0 = max(np.linalg.norm(D, 2) ** 2, 5e1)\n\nxt = np.zeros(n, dtype=np.float32)  # true signal\nidx = np.random.randint(low=0, high=n, size=s)  # support of xt\nxt[idx] = np.random.rand(s)\ny = D @ xt + 5e-2 * np.random.randn(m)  # synthetic signal\n\nxt = snp.array(xt)  # convert to jax array\ny = snp.array(y)  # convert to jax 
array\n\n\n\"\"\"\nSet up the forward operator and APGM solver object.\n\"\"\"\nlmbda = 2e-1\nA = linop.MatrixOperator(D)\nf = loss.SquaredL2Loss(y=y, A=A)\ng = functional.ProximalAverage([lmbda * functional.L1Norm(), functional.NonNegativeIndicator()])\nmaxiter = 250  # number of APGM iterations\nsolver = AcceleratedPGM(\n    f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={\"display\": True, \"period\": 20}\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\n\n\n\"\"\"\nPlot the recovered coefficients and signal.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    np.vstack((xt, solver.x)).T,\n    title=\"Coefficients\",\n    lgnd=(\"Ground Truth\", \"Recovered\"),\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    np.vstack((D @ xt, y, D @ solver.x)).T,\n    title=\"Signal\",\n    lgnd=(\"Ground Truth\", \"Noisy\", \"Recovered\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/sparsecode_poisson_apgm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nNon-negative Poisson Loss Reconstruction (APGM)\n===============================================\n\nThis example demonstrates the use of class\n[pgm.PGMStepSize](../_autosummary/scico.optimize.pgm.rst#scico.optimize.pgm.PGMStepSize)\nto solve the non-negative reconstruction problem with Poisson negative\nlog likelihood loss\n\n  $$\\mathrm{argmin}_{\\mathbf{x}} \\; \\frac{1}{2} \\left( A(\\mathbf{x}) -\n  \\mathbf{y} \\log\\left( A(\\mathbf{x}) \\right) + \\log(\\mathbf{y}!) \\right) +\n  \\iota_{\\mathrm{NN}}(\\mathbf{x}_0) \\;,$$\n\nwhere $A$ is the forward operator, $\\mathbf{y}$ is the measurement,\n$\\mathbf{x}$ is the signal reconstruction, and $\\iota_{\\mathrm{NN}}$ is\nthe indicator function of the non-negativity constraint.\n\nThis example also demonstrates the application of\n[numpy.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray),\n[functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional),\nand\n[functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional)\nto implement the forward operator\n$A(\\mathbf{x}) = A_0(\\mathbf{x}_0) + A_1(\\mathbf{x}_1)$\nand the selective non-negativity constraint that only applies to\n$\\mathbf{x}_0$.\n\"\"\"\n\nimport matplotlib.gridspec as gridspec\nimport matplotlib.pyplot as plt\n\nimport scico.numpy as snp\nimport scico.random\nfrom scico import functional, loss, plot\nfrom scico.numpy import BlockArray\nfrom scico.operator import Operator\nfrom scico.optimize.pgm import (\n    AcceleratedPGM,\n    AdaptiveBBStepSize,\n    BBStepSize,\n    LineSearchStepSize,\n    RobustLineSearchStepSize,\n)\nfrom scico.typing import Shape\nfrom scico.util import device_info\nfrom scipy.linalg import 
dft\n\n\"\"\"\nConstruct a dictionary, a reference random reconstruction, and a test\nmeasurement signal consisting of the synthesis of the reference\nreconstruction.\n\"\"\"\nm = 1024  # signal size\nn = 8  # dictionary size\nn0 = 2\nn1 = n - n0\n\n# Create dictionary with bump-like features.\nD = ((snp.real(dft(m))[1 : n + 1, :m]) ** 12).T\nD0 = D[:, :n0]\nD1 = D[:, n0:]\n\n\n# Define composed operator.\nclass ForwardOperator(Operator):\n    \"\"\"Toy problem non-linear forward operator with different treatment\n       of x[0] and x[1].\n\n    Attributes:\n        D0: Matrix multiplying x[0].\n        D1: Matrix multiplying x[1].\n    \"\"\"\n\n    def __init__(self, input_shape: Shape, D0, D1, jit: bool = True):\n        self.D0 = D0\n        self.D1 = D1\n\n        output_shape = (D0.shape[0],)\n\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=snp.complex64,\n            output_dtype=snp.complex64,\n            output_shape=output_shape,\n            jit=jit,\n        )\n\n    def _eval(self, x: BlockArray) -> BlockArray:\n        return 10 * snp.exp(-D0 @ x[0]) + 5 * snp.exp(-D1 @ x[1])\n\n\nx_gt, key = scico.random.uniform(((n0,), (n1,)), seed=12345)  # true coefficients\n\nA = ForwardOperator(x_gt.shape, D0, D1)\n\nlam = A(x_gt)\ny, key = scico.random.poisson(lam, shape=lam.shape, key=key)  # synthetic signal\n\n\n\"\"\"\nSet up the loss function and the regularization.\n\"\"\"\nf = loss.PoissonLoss(y=y, A=A)\n\ng0 = functional.NonNegativeIndicator()\ng1 = functional.ZeroFunctional()\ng = functional.SeparableFunctional([g0, g1])\n\n\n\"\"\"\nDefine common setup: maximum of iterations and initial estimate of solution.\n\"\"\"\nmaxiter = 50\nx0, key = scico.random.uniform(((n0,), (n1,)), key=key)\n\n\n\"\"\"\nDefine plotting functionality.\n\"\"\"\n\n\ndef plot_results(hist, str_ss, L0, xsol, xgt, Aop):\n    # Plot signal, coefficients and convergence statistics.\n    fig = plot.figure(\n        figsize=(12, 6),\n      
  tight_layout=True,\n    )\n    gs = gridspec.GridSpec(nrows=2, ncols=3)\n\n    fig.suptitle(\n        \"Results for PGM Solver and \" + str_ss + r\" ($L_0$: \" + \"{:4.2f}\".format(L0) + \")\",\n        fontsize=14,\n    )\n\n    ax0 = fig.add_subplot(gs[0, 0])\n    plot.plot(\n        hist.Objective,\n        ptyp=\"semilogy\",\n        title=\"Objective\",\n        xlbl=\"Iteration\",\n        fig=fig,\n        ax=ax0,\n    )\n\n    ax1 = fig.add_subplot(gs[0, 1])\n    plot.plot(\n        hist.Residual,\n        ptyp=\"semilogy\",\n        title=\"Residual\",\n        xlbl=\"Iteration\",\n        fig=fig,\n        ax=ax1,\n    )\n\n    ax2 = fig.add_subplot(gs[0, 2])\n    plot.plot(\n        hist.L,\n        ptyp=\"semilogy\",\n        title=\"L\",\n        xlbl=\"Iteration\",\n        fig=fig,\n        ax=ax2,\n    )\n\n    ax3 = fig.add_subplot(gs[1, 0])\n    plt.stem(snp.concatenate((xgt[0], xgt[1])), linefmt=\"C1-\", markerfmt=\"C1o\", basefmt=\"C1-\")\n    plt.stem(snp.concatenate((xsol[0], xsol[1])), linefmt=\"C2-\", markerfmt=\"C2x\", basefmt=\"C1-\")\n    plt.legend([\"Ground Truth\", \"Recovered\"])\n    plt.xlabel(\"Index\")\n    plt.title(\"Coefficients\")\n\n    ax4 = fig.add_subplot(gs[1, 1:])\n    plot.plot(\n        snp.vstack((y, Aop(xgt), Aop(xsol))).T,\n        title=\"Fit\",\n        xlbl=\"Index\",\n        lgnd=(\"y\", \"A(x_gt)\", \"A(x)\"),\n        fig=fig,\n        ax=ax4,\n    )\n    fig.show()\n\n\n\"\"\"\nUse default `PGMStepSize` object, set L0 based on norm of forward\noperator and set up `AcceleratedPGM` solver object. 
Run the solver and\nplot the reconstructed signal and convergence statistics.\n\"\"\"\nL0 = 1e3\nstr_L0 = \"(Specifically chosen so that convergence occurs)\"\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=L0,\n    x0=x0,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n)\nstr_ss = type(solver.step_size).__name__\n\nprint(f\"Solving on {device_info()}\\n\")\nprint(\"============================================================\")\nprint(\"Running solver with step size of class: \", str_ss)\nprint(\"L0 \" + str_L0 + \": \", L0, \"\\n\")\n\nx = solver.solve()  # run the solver\nhist = solver.itstat_object.history(transpose=True)\nplot_results(hist, str_ss, L0, x, x_gt, A)\n\n\n\"\"\"\nUse `BBStepSize` object, set L0 with arbitrary initial value and set up\n`AcceleratedPGM` solver object. Run the solver and plot the\nreconstructed signal and convergence statistics.\n\"\"\"\nL0 = 90.0  # initial reciprocal of gradient descent step size\nstr_L0 = \"(Arbitrary Initialization)\"\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=L0,\n    x0=x0,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=BBStepSize(),\n)\nstr_ss = type(solver.step_size).__name__\n\nprint(\"===================================================\")\nprint(\"Running solver with step size of class: \", str_ss)\nprint(\"L0 \" + str_L0 + \": \", L0, \"\\n\")\n\nx = solver.solve()  # run the solver\nhist = solver.itstat_object.history(transpose=True)\nplot_results(hist, str_ss, L0, x, x_gt, A)\n\n\n\"\"\"\nUse `AdaptiveBBStepSize` object, set L0 with arbitrary initial value and\nset up `AcceleratedPGM` solver object.
Run the solver and plot the\nreconstructed signal and convergence statistics.\n\"\"\"\nL0 = 90.0  # initial reciprocal of gradient descent step size\nstr_L0 = \"(Arbitrary Initialization)\"\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=L0,\n    x0=x0,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=AdaptiveBBStepSize(kappa=0.75),\n)\nstr_ss = type(solver.step_size).__name__\n\nprint(\"===========================================================\")\nprint(\"Running solver with step size of class: \", str_ss)\nprint(\"L0 \" + str_L0 + \": \", L0, \"\\n\")\n\nx = solver.solve()  # run the solver\nhist = solver.itstat_object.history(transpose=True)\nplot_results(hist, str_ss, L0, x, x_gt, A)\n\n\n\"\"\"\nUse `LineSearchStepSize` object, set L0 with arbitrary initial value and\nset up `AcceleratedPGM` solver object. Run the solver and plot the\nreconstructed signal and convergence statistics.\n\"\"\"\nL0 = 90.0  # initial reciprocal of gradient descent step size\nstr_L0 = \"(Arbitrary Initialization)\"\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=L0,\n    x0=x0,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=LineSearchStepSize(),\n)\nstr_ss = type(solver.step_size).__name__\n\nprint(\"===========================================================\")\nprint(\"Running solver with step size of class: \", str_ss)\nprint(\"L0 \" + str_L0 + \": \", L0, \"\\n\")\n\nx = solver.solve()  # run the solver\nhist = solver.itstat_object.history(transpose=True)\nplot_results(hist, str_ss, L0, x, x_gt, A)\n\n\n\"\"\"\nUse `RobustLineSearchStepSize` object, set L0 with arbitrary initial\nvalue and set up `AcceleratedPGM` solver object.
Run the solver and\nplot the reconstructed signal and convergence statistics.\n\"\"\"\nL0 = 90.0  # initial reciprocal of gradient descent step size\nstr_L0 = \"(Arbitrary Initialization)\"\n\nsolver = AcceleratedPGM(\n    f=f,\n    g=g,\n    L0=L0,\n    x0=x0,\n    maxiter=maxiter,\n    itstat_options={\"display\": True, \"period\": 10},\n    step_size=RobustLineSearchStepSize(),\n)\nstr_ss = type(solver.step_size).__name__\n\nprint(\"=================================================================\")\nprint(\"Running solver with step size of class: \", str_ss)\nprint(\"L0 \" + str_L0 + \": \", L0, \"\\n\")\n\nx = solver.solve()  # run the solver\nhist = solver.itstat_object.history(transpose=True)\nplot_results(hist, str_ss, L0, x, x_gt, A)\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/superres_ppp_dncnn_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"\nPPP (with DnCNN) Image Superresolution\n======================================\n\nThis example demonstrates the use of the ADMM Plug and Play Priors\n(PPP) algorithm :cite:`venkatakrishnan-2013-plugandplay2`, with DnCNN\n:cite:`zhang-2017-dncnn` denoiser, for solving a simple image\nsuperresolution problem.\n\"\"\"\n\nimport scico\nimport scico.numpy as snp\nimport scico.random\nfrom scico import denoiser, functional, linop, loss, metric, plot\nfrom scico.data import kodim23\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.solver import cg\nfrom scico.util import device_info\n\n\"\"\"\nDefine downsampling function.\n\"\"\"\n\n\ndef downsample_image(img, rate):\n    img = snp.mean(snp.reshape(img, (-1, rate, img.shape[1], img.shape[2])), axis=1)\n    img = snp.mean(snp.reshape(img, (img.shape[0], -1, rate, img.shape[2])), axis=2)\n    return img\n\n\n\"\"\"\nRead a ground truth image.\n\"\"\"\nimg = snp.array(kodim23(asfloat=True)[160:416, 60:316])\n\n\n\"\"\"\nCreate a test image by downsampling and adding Gaussian white noise.\n\"\"\"\nrate = 4  # downsampling rate\nσ = 2e-2  # noise standard deviation\n\nAfn = lambda x: downsample_image(x, rate=rate)\ns = Afn(img)\ninput_shape = img.shape\noutput_shape = s.shape\nnoise, key = scico.random.randn(s.shape, seed=0)\nsn = s + σ * noise\n\n\n\"\"\"\nSet up the PPP problem pseudo-functional. 
The DnCNN denoiser\n:cite:`zhang-2017-dncnn` is used as a regularizer.\n\"\"\"\nA = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn)\nf = loss.SquaredL2Loss(y=sn, A=A)\nC = linop.Identity(input_shape=input_shape)\ng = functional.DnCNN(\"17M\")\n\n\n\"\"\"\nCompute a baseline solution via denoising of the pseudo-inverse of the\nforward operator. This baseline solution is also used to initialize the\nPPP solver.\n\"\"\"\nxpinv, info = cg(A.T @ A, A.T @ sn, snp.zeros(input_shape))\ndncnn = denoiser.DnCNN(\"17M\")\nxden = dncnn(xpinv)\n\n\n\"\"\"\nSet up an ADMM solver and solve.\n\"\"\"\nρ = 3.4e-2  # ADMM penalty parameter\nmaxiter = 12  # number of ADMM iterations\nsolver = ADMM(\n    f=f,\n    g_list=[g],\n    C_list=[C],\n    rho_list=[ρ],\n    x0=xden,\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-3, \"maxiter\": 10}),\n    itstat_options={\"display\": True},\n)\n\nprint(f\"Solving on {device_info()}\\n\")\nxppp = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n)\n\n\n\"\"\"\nShow reference and test images.\n\"\"\"\nfig = plot.figure(figsize=(8, 6))\nax0 = plot.plt.subplot2grid((1, rate + 1), (0, 0), colspan=rate)\nplot.imview(img, title=\"Reference\", fig=fig, ax=ax0)\nax1 = plot.plt.subplot2grid((1, rate + 1), (0, rate))\nplot.imview(sn, title=\"Downsampled\", fig=fig, ax=ax1)\nfig.show()\n\n\n\"\"\"\nShow recovered full-resolution images.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))\nplot.imview(xpinv, title=\"Pseudo-inverse: %.2f (dB)\" % metric.psnr(img, xpinv), fig=fig, ax=ax[0])\nplot.imview(\n    xden, title=\"Denoised pseudo-inverse: %.2f (dB)\" % metric.psnr(img, xden), fig=fig, 
ax=ax[1]\n)\nplot.imview(xppp, title=\"PPP solution: %.2f (dB)\" % metric.psnr(img, xppp), fig=fig, ax=ax[2])\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/scripts/trace_example.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nSCICO Call Tracing\n==================\n\nThis example demonstrates the call tracing functionality provided by the\n[trace](../_autosummary/scico.trace.rst) module. It is based on the\n[non-negative BPDN example](sparsecode_nn_admm.rst).\n\"\"\"\n\nimport numpy as np\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric\nfrom scico.optimize.admm import ADMM, MatrixSubproblemSolver\nfrom scico.trace import register_variable, trace_scico_calls\nfrom scico.util import device_info\n\n\"\"\"\nInitialize tracing. JIT must be disabled for correct tracing.\n\nThe call tracing mechanism prints the name, arguments, and return values\nof functions/methods as they are called. Module and class names are\nprinted in light red, function and method names in dark red, arguments\nand return values in light blue, and the names of registered variables\nin light yellow. When a method defined in a class is called for an object\nof a derived class type, the class of that object is printed in light\nmagenta, in square brackets. 
Function names and return values are\ndistinguished by initial \">>\" and \"<<\" characters respectively.\n\"\"\"\njax.config.update(\"jax_disable_jit\", True)\ntrace_scico_calls()\n\n\n\"\"\"\nCreate random dictionary, reference random sparse representation, and\ntest signal consisting of the synthesis of the reference sparse\nrepresentation.\n\"\"\"\nm = 32  # signal size\nn = 128  # dictionary size\ns = 10  # sparsity level\n\nnp.random.seed(1)\nD = np.random.randn(m, n).astype(np.float32)\nD = D / np.linalg.norm(D, axis=0, keepdims=True)  # normalize dictionary\n\nxt = np.zeros(n, dtype=np.float32)  # true signal\nidx = np.random.randint(low=0, high=n, size=s)  # support of xt\nxt[idx] = np.random.rand(s)\ny = D @ xt + 5e-2 * np.random.randn(m)  # synthetic signal\n\nxt = snp.array(xt)  # convert to jax array\ny = snp.array(y)  # convert to jax array\n\n\n\"\"\"\nRegister a variable so that it can be referenced by name in the call trace.\nAny hashable object and numpy arrays may be registered, but JAX arrays\ncannot.\n\"\"\"\nregister_variable(D, \"D\")\n\n\n\"\"\"\nSet up the forward operator and ADMM solver object.\n\"\"\"\nlmbda = 1e-1\nA = linop.MatrixOperator(D)\nregister_variable(A, \"A\")\nf = loss.SquaredL2Loss(y=y, A=A)\ng_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]\nC_list = [linop.Identity((n)), linop.Identity((n))]\nrho_list = [1.0, 1.0]\nmaxiter = 1  # number of ADMM iterations (set to small value to simplify trace output)\n\nregister_variable(f, \"f\")\nregister_variable(g_list[0], \"g_list[0]\")\nregister_variable(g_list[1], \"g_list[1]\")\nregister_variable(C_list[0], \"C_list[0]\")\nregister_variable(C_list[1], \"C_list[1]\")\n\nsolver = ADMM(\n    f=f,\n    g_list=g_list,\n    C_list=C_list,\n    rho_list=rho_list,\n    x0=A.adj(y),\n    maxiter=maxiter,\n    subproblem_solver=MatrixSubproblemSolver(),\n    itstat_options={\"display\": True, \"period\": 5},\n)\n\nregister_variable(solver, 
\"solver\")\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nmse = metric.mse(xt, x)\n"
  },
  {
    "path": "examples/scripts/video_rpca_admm.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n# This file is part of the SCICO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\nr\"\"\"\nVideo Decomposition via Robust PCA\n==================================\n\nThis example demonstrates video foreground/background separation via a\nvariant of the Robust PCA problem\n\n  $$\\mathrm{argmin}_{\\mathbf{x}_0, \\mathbf{x}_1} \\; (1/2) \\| \\mathbf{x}_0\n      + \\mathbf{x}_1 - \\mathbf{y} \\|_2^2 + \\lambda_0 \\| \\mathbf{x}_0 \\|_*\n      + \\lambda_1 \\| \\mathbf{x}_1 \\|_1 \\;,$$\n\nwhere $\\mathbf{x}_0$ and $\\mathbf{x}_1$ are respectively low-rank and\nsparse components, $\\| \\cdot \\|_*$ denotes the nuclear norm, and\n$\\| \\cdot \\|_1$ denotes the $\\ell_1$ norm.\n\nNote: while video foreground/background separation is not an example of\nthe scientific and computational imaging problems that are the focus of\nSCICO, it provides a convenient demonstration of Robust PCA, which does\nhave potential application in scientific imaging problems.\n\"\"\"\n\nimport imageio.v3 as iio\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, plot\nfrom scico.examples import rgb2gray\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.util import device_info\n\n\"\"\"\nLoad example video.\n\"\"\"\nvid = rgb2gray(\n    iio.imread(\"imageio:newtonscradle.gif\").transpose((1, 2, 3, 0)).astype(snp.float32) / 255.0\n)\n\n\n\"\"\"\nConstruct matrix with each column consisting of a vectorised video frame.\n\"\"\"\ny = vid.reshape((-1, vid.shape[-1]))\n\n\n\"\"\"\nDefine functional for Robust PCA problem.\n\"\"\"\nA = linop.Sum(axis=0, input_shape=(2,) + y.shape)\nf = loss.SquaredL2Loss(y=y, A=A)\nC0 = linop.Slice(idx=0, input_shape=(2,) + y.shape)\ng0 = functional.NuclearNorm()\nC1 = linop.Slice(idx=1, input_shape=(2,) + y.shape)\ng1 = functional.L1Norm()\n\n\n\"\"\"\nSet up an ADMM solver 
object.\n\"\"\"\nλ0 = 1e1  # nuclear norm regularization parameter\nλ1 = 3e1  # ℓ1 norm regularization parameter\nρ0 = 2e1  # ADMM penalty parameter\nρ1 = 2e1  # ADMM penalty parameter\nmaxiter = 50  # number of ADMM iterations\n\nsolver = ADMM(\n    f=f,\n    g_list=[λ0 * g0, λ1 * g1],\n    C_list=[C0, C1],\n    rho_list=[ρ0, ρ1],\n    x0=A.adj(y),\n    maxiter=maxiter,\n    subproblem_solver=LinearSubproblemSolver(),\n    itstat_options={\"display\": True, \"period\": 10},\n)\n\n\n\"\"\"\nRun the solver.\n\"\"\"\nprint(f\"Solving on {device_info()}\\n\")\nx = solver.solve()\nhist = solver.itstat_object.history(transpose=True)\n\n\n\"\"\"\nPlot convergence statistics.\n\"\"\"\nfig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\nplot.plot(\n    hist.Objective,\n    title=\"Objective function\",\n    xlbl=\"Iteration\",\n    ylbl=\"Functional value\",\n    fig=fig,\n    ax=ax[0],\n)\nplot.plot(\n    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,\n    ptyp=\"semilogy\",\n    title=\"Residuals\",\n    xlbl=\"Iteration\",\n    lgnd=(\"Primal\", \"Dual\"),\n    fig=fig,\n    ax=ax[1],\n)\nfig.show()\n\n\n\"\"\"\nReshape low-rank component as background video sequence and sparse component\nas foreground video sequence.\n\"\"\"\nxlr = C0(x)\nxsp = C1(x)\nvbg = xlr.reshape(vid.shape)\nvfg = xsp.reshape(vid.shape)\n\n\n\"\"\"\nDisplay original video frames and corresponding background and foreground frames.\n\"\"\"\nfig, ax = plot.subplots(nrows=4, ncols=3, figsize=(10, 10))\nax[0][0].set_title(\"Original\")\nax[0][1].set_title(\"Background\")\nax[0][2].set_title(\"Foreground\")\nfor n, fn in enumerate(range(1, 9, 2)):\n    plot.imview(vid[..., fn], fig=fig, ax=ax[n][0])\n    plot.imview(vbg[..., fn], fig=fig, ax=ax[n][1])\n    plot.imview(vfg[..., fn], fig=fig, ax=ax[n][2])\n    ax[n][0].set_ylabel(\"Frame %d\" % fn, labelpad=5, rotation=90, size=\"large\")\nfig.tight_layout()\nfig.show()\n\n\ninput(\"\\nWaiting for input to close figures and exit\")\n"
  },
  {
    "path": "examples/updatejnbcode.py",
    "content": "#!/usr/bin/env python\n\n# Update code cells in notebooks from corresponding scripts without\n# the need to re-execute the notebook. NB: use with caution!\n# Run as\n#     python updatejnbcode.py <script-name.py>\n\nimport os\nimport sys\n\nfrom jnb import py_file_to_string, read_notebook\nfrom py2jn.tools import py_string_to_notebook, write_notebook\n\n\ndef replace_code_cells(src, dst):\n    \"\"\"Overwrite code cells in notebook object `dst` with corresponding\n    cells in notebook object `src`.\n    \"\"\"\n\n    if \"cells\" in src:\n        srccell = src[\"cells\"]\n    else:\n        srccell = src[\"worksheets\"][0][\"cells\"]\n    if \"cells\" in dst:\n        dstcell = dst[\"cells\"]\n    else:\n        dstcell = dst[\"worksheets\"][0][\"cells\"]\n\n    # It is an error to attempt replacement if src and dst have different\n    # numbers of cells\n    if len(srccell) != len(dstcell):\n        raise ValueError(\"Notebooks do not have the same number of cells.\")\n\n    # Iterate over cells in src\n    for n in range(len(srccell)):\n        # It is an error to attempt replacement if any corresponding pair\n        # of cells have different type\n        if srccell[n][\"cell_type\"] != dstcell[n][\"cell_type\"]:\n            raise ValueError(\"Cell number %d of different type in src and dst.\")\n        # If current src cell is a code cell, copy the src cell to the dst cell\n        if srccell[n][\"cell_type\"] == \"code\":\n            dstcell[n][\"source\"] = srccell[n][\"source\"]\n\n\nsrc = sys.argv[1]\ndst = os.path.join(\"notebooks\", os.path.splitext(os.path.basename(src))[0] + \".ipynb\")\nprint(f\"Updating code cells in {dst} from {src}\")\nif os.path.exists(dst):\n    srcnb = py_string_to_notebook(py_file_to_string(src), nbver=4)\n    dstnb = read_notebook(dst)\n    replace_code_cells(srcnb, dstnb)\n    write_notebook(dstnb, dst)\n"
  },
  {
    "path": "examples/updatejnbmd.py",
    "content": "#!/usr/bin/env python\n\n# Update markdown cells in notebooks from corresponding scripts without\n# the need to re-execute the notebook. Only applicable if the changes to\n# the script since generation of the corresponding notebook only affect\n# markdown cells.\n# Run as\n#     python updatejnbmd.py\n\nimport glob\nimport os\n\nfrom jnb import (\n    py_file_to_string,\n    read_notebook,\n    replace_markdown_cells,\n    same_notebook_code,\n    same_notebook_markdown,\n)\nfrom py2jn.tools import py_string_to_notebook, write_notebook\n\nfor src in glob.glob(os.path.join(\"scripts\", \"*.py\")):\n    dst = os.path.join(\"notebooks\", os.path.splitext(os.path.basename(src))[0] + \".ipynb\")\n    if os.path.exists(dst):\n        srcnb = py_string_to_notebook(py_file_to_string(src), nbver=4)\n        dstnb = read_notebook(dst)\n        if not same_notebook_code(srcnb, dstnb):\n            print(f\"Non-markup changes in {src}\")\n            continue\n        if not same_notebook_markdown(srcnb, dstnb):\n            print(f\"Updating markdown in {dst}\")\n            replace_markdown_cells(srcnb, dstnb)\n            write_notebook(dstnb, dst)\n"
  },
  {
    "path": "misc/README.rst",
    "content": "Miscellaneous\n=============\n\nThis directory is a temporary location for content for which there is no\nobviously more appropriate location:\n\n- ``conda``: Scripts intended to facilitate the installation of miniconda and an environment with all SCICO requirements.\n- ``gpu``: Scripts for debugging and managing JAX use of GPUs.\n- ``pytest``: Scripts for specialized use of ``pytest``.\n"
  },
  {
    "path": "misc/conda/README.rst",
    "content": "Conda Installation Scripts\n==========================\n\nThese scripts are intended to facilitate the installation of `miniconda <https://docs.conda.io/en/latest/miniconda.html>`__ and an environment with all SCICO requirements:\n\n- ``install_conda.sh``:  Install miniconda\n- ``make_conda_env.sh``:  Create a conda environment with all SCICO requirements\n\nFor usage details, run the scripts with the ``-h`` flag, e.g. ``./install_conda.sh -h``.\n\n\nExample Usage\n-------------\n\nTo install miniconda in ``/opt/conda`` do\n\n::\n\n   ./install_conda.sh -y /opt/conda\n\n\nTo create a conda environment called ``scico`` with Python version 3.12 and without GPU support\n\n::\n\n   ./make_conda_env.sh -y -p 3.12 -e scico\n\n\nTo include GPU support, follow the `jax installation instructions <https://github.com/google/jax#pip-installation-gpu-cuda>`__ after\nrunning this script and activating the environment created by it.\n\n\nCaveats\n-------\n\nThese scripts should function correctly out-of-the-box on a standard Linux installation. (If you find that this is not the case, please create a GitHub issue, providing details of the Linux variant and version.)\n\nWhile these scripts are supported under OSX (MacOS), there are some caveats:\n\n- Required utilities ``realpath`` and ``gsed`` (GNU sed) must be installed via MacPorts or some other third-party package management system.\n- Installation of jaxlib with GPU capabilities is not supported.\n- While ``make_conda_env.sh`` installs ``matplotlib``, it does not attempt to resolve the `additional complications <https://matplotlib.org/faq/osx_framework.html>`_ in using a conda-installed matplotlib under OSX.\n"
  },
  {
    "path": "misc/conda/install_conda.sh",
    "content": "#!/usr/bin/env bash\n\n# This script installs miniconda3 in the specified path\n#\n# Run with -h flag for usage information\n\nURLROOT=https://repo.continuum.io/miniconda/\nINSTLINUX=Miniconda3-latest-Linux-x86_64.sh\nINSTMACOSX=Miniconda3-latest-MacOSX-x86_64.sh\n\nSCRIPT=$(basename $0)\nUSAGE=$(cat <<-EOF\nUsage: $SCRIPT [-h] [-y] install_path\n          [-h] Display usage information\n          [-y] Do not ask for confirmation\nEOF\n)\nAGREE=no\n\nOPTIND=1\nwhile getopts \":hy\" opt; do\n  case $opt in\n    h) echo \"$USAGE\"; exit 0;;\n    y) AGREE=yes;;\n    \\?) echo \"Error: invalid option -$OPTARG\" >&2\n\techo \"$USAGE\" >&2\n\texit 1\n\t;;\n  esac\ndone\n\nshift $((OPTIND-1))\nif [ ! $# -eq 1 ] ; then\n    echo \"Error: one positional argument required\" >&2\n    echo \"$USAGE\" >&2\n    exit 1\nfi\n\nOS=$(uname -a | cut -d ' ' -f 1)\ncase \"$OS\" in\n    Linux)    SOURCEURL=$URLROOT$INSTLINUX;;\n    Darwin)   SOURCEURL=$URLROOT$INSTMACOSX;;\n    *)        echo \"Error: unsupported operating system $OS\" >&2; exit 2;;\nesac\n\nif [ ! \"$(which wget 2>/dev/null)\" ]; then\n    has_wget=0\nelse\n    has_wget=1\nfi\n\nif [ ! \"$(which curl 2>/dev/null)\" ]; then\n    has_curl=0\nelse\n    has_curl=1\nfi\n\nif [ $has_curl -eq 0 ] && [ $has_wget -eq 0 ]; then\n    echo \"Error: neither curl nor wget found; at least one required\" >&2\n    exit 3\nfi\n\nINSTALLROOT=$1\nif [ ! -d \"$INSTALLROOT\" ] || [ ! 
-w \"$INSTALLROOT\" ]; then\n    echo \"Error: installation root path \\\"$INSTALLROOT\\\" is not a directory \"\\\n\t \"or is not writable\"  >&2\n    exit 4\nfi\n\nCONDAHOME=$INSTALLROOT/miniconda3\nif [ -d \"$CONDAHOME\" ]; then\n    echo \"Error: miniconda3 installation directory $CONDAHOME already exists\"\\\n\t >&2\n    exit 5\nfi\n\nif [ \"$AGREE\" == \"no\" ]; then\n    read -r -p \"Confirm conda installation in root path $INSTALLROOT [y/N] \"\\\n\t CNFRM\n    if [ \"$CNFRM\" != 'y' ] && [ \"$CNFRM\" != 'Y' ]; then\n\techo \"Cancelling installation\"\n\texit 6\n    fi\nfi\n\n# Get miniconda bash archive and install it\nif [ $has_wget -eq 1 ]; then\n    wget $SOURCEURL -O /tmp/miniconda.sh\nelif [ $has_curl -eq 1 ]; then\n    curl -L $SOURCEURL -o /tmp/miniconda.sh\nfi\n\nbash /tmp/miniconda.sh -b -p $CONDAHOME\nrm -f /tmp/miniconda.sh\n\n# Initial conda setup\nexport PATH=\"$CONDAHOME/bin:$PATH\"\nhash -r\nconda config --set always_yes yes\nconda update -q conda\nconda info -a\n\necho \"Add the following to your .bashrc or .bash_aliases file\"\necho \"  export CONDAHOME=$CONDAHOME\"\necho \"  export PATH=\\$PATH:\\$CONDAHOME/bin\"\n\nexit 0\n"
  },
  {
    "path": "misc/conda/make_conda_env.sh",
    "content": "#!/usr/bin/env bash\n\n# This script installs a conda environment with all required and\n# optional scico dependencies. The user is assumed to have write\n# permission for the conda installation. It should function correctly\n# under both Linux and OSX, but note that there are some additional\n# complications in using a conda installed matplotlib under OSX\n#   https://matplotlib.org/faq/osx_framework.html\n# that are not addressed, and that installation of jaxlib with GPU\n# capabilities is not supported under OSX. Note also that additional\n# utilities realpath and gsed (gnu sed), available from MacPorts, are\n# required to run this script under OSX.\n#\n# Run with -h flag for usage information\nset -e  # exit when any command fails\n\nif [ \"$(cut -d '.' -f 1 <<< \"$BASH_VERSION\")\" -lt \"4\" ]; then\n    echo \"Error: this script requires bash version 4 or later\" >&2\n    exit 1\nfi\n\n\nSCRIPT=$(basename $0)\nREPOPATH=$(realpath $(dirname $0))\nUSAGE=$(cat <<-EOF\nUsage: $SCRIPT [-h] [-y] [-g] [-p python_version] [-e env_name]\n          [-h] Display usage information\n          [-v] Verbose operation\n          [-t] Display actions that would be taken but do nothing\n          [-y] Do not ask for confirmation\n          [-p python_version] Specify Python version (e.g. 3.12)\n          [-e env_name] Specify conda environment name\nEOF\n)\n\nAGREE=no\nVERBOSE=no\nTEST=no\nPYVER=\"3.12\"\nENVNM=py$(echo $PYVER | sed -e 's/\\.//g')\n\n# Project requirements files\nREQUIRE=$(cat <<-EOF\nrequirements.txt\ndev_requirements.txt\ndocs/docs_requirements.txt\nexamples/examples_requirements.txt\nexamples/notebooks_requirements.txt\nEOF\n)\n# Requirements that cannot be installed via conda (i.e. 
have to use pip)\nNOCONDA=$(cat <<-EOF\nflax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]\nEOF\n)\n\n\nOPTIND=1\nwhile getopts \":hvtyp:e:\" opt; do\n    case $opt in\n\tp|e) if [ -z \"$OPTARG\" ] || [ \"${OPTARG:0:1}\" = \"-\" ] ; then\n\t\t     echo \"Error: option -$opt requires an argument\" >&2\n\t\t     echo \"$USAGE\" >&2\n\t\t     exit 2\n\t\t fi\n\t\t ;;&\n\th) echo \"$USAGE\"; exit 0;;\n\tt) VERBOSE=yes;TEST=yes;;\n\tv) VERBOSE=yes;;\n\ty) AGREE=yes;;\n\tp) PYVER=$OPTARG;;\n\te) ENVNM=$OPTARG;;\n\t:) echo \"Error: option -$OPTARG requires an argument\" >&2\n           echo \"$USAGE\" >&2\n           exit 2\n           ;;\n\t\\?) echo \"Error: invalid option -$OPTARG\" >&2\n            echo \"$USAGE\" >&2\n            exit 2\n            ;;\n    esac\ndone\n\nshift $((OPTIND-1))\nif [ ! $# -eq 0 ] ; then\n    echo \"Error: no positional arguments\" >&2\n    echo \"$USAGE\" >&2\n    exit 2\nfi\n\nif [ ! \"$(which conda 2>/dev/null)\" ]; then\n    echo \"Error: conda command required but not found\" >&2\n    exit 3\nfi\n\n# Not available on BSD systems such as OSX: install via MacPorts etc.\nif [ ! \"$(which realpath 2>/dev/null)\" ]; then\n    echo \"Error: realpath command required but not found\" >&2\n    exit 4\nfi\n\n# Ensure that a C compiler is available; required for installing svmbir\n# On debian/ubuntu linux systems, install package build-essential\nif [ -z \"$CC\" ] && [ ! 
\"$(which gcc 2>/dev/null)\" ]; then\n    echo \"Error: gcc command not found and CC environment variable not set\"\n    echo \"       set CC to the path of your C compiler, or install gcc.\"\n    echo \"       On debian/ubuntu, you may need to do\"\n    echo \"           sudo apt install build-essential\"\n    exit 5\nfi\n\nOS=$(uname -a | cut -d ' ' -f 1)\ncase \"$OS\" in\n    Linux)    SOURCEURL=$URLROOT$INSTLINUX; SED=\"sed\";;\n    Darwin)   SOURCEURL=$URLROOT$INSTMACOSX; SED=\"gsed\";;\n    *)        echo \"Error: unsupported operating system $OS\" >&2; exit 6;;\nesac\nif [ \"$OS\" == \"Darwin\" ] && [ \"$GPU\" == yes ]; then\n    echo \"Error: GPU-enabled jaxlib installation not supported under OSX\" >&2\n    exit 7\nfi\nif [ \"$OS\" == \"Darwin\" ]; then\n    if [ ! \"$(which gsed 2>/dev/null)\" ]; then\n\techo \"Error: gsed command required but not found\" >&2\n\texit 8\n    fi\nfi\n\nJLVER=$($SED -n 's/^jaxlib>=.*<=\\([0-9\\.]*\\).*/\\1/p' \\\n\t     $REPOPATH/../../requirements.txt)\nJXVER=$($SED -n 's/^jax>=.*<=\\([0-9\\.]*\\).*/\\1/p' \\\n\t     $REPOPATH/../../requirements.txt)\n\n# Construct merged list of all requirements\nif [ \"$OS\" == \"Darwin\" ]; then\n    ALLREQUIRE=$(/usr/bin/mktemp -t condaenv)\nelse\n    ALLREQUIRE=$(mktemp -t condaenv_XXXXXX.txt)\nfi\nfor req in $REQUIRE; do\n    pthreq=\"$REPOPATH/../../$req\"\n    cat $pthreq >> $ALLREQUIRE\ndone\n\n# Construct filtered list of requirements: sort, remove duplicates, and\n# remove requirements that require special handling\nif [ \"$OS\" == \"Darwin\" ]; then\n    FLTREQUIRE=$(mktemp -t condaenv)\nelse\n    FLTREQUIRE=$(mktemp -t condaenv_XXXXXX.txt)\nfi\n# Filter the list of requirements; sed patterns are for\n#  1st: escape >,<,| characters with a backslash\n#  2nd: remove comments in requirements file\n#  3rd: remove recursive include (-r) lines and packages that require\n#       special handling, e.g. 
jaxlib\nsort $ALLREQUIRE | uniq | $SED -E 's/(>|<|\\|)/\\\\\\1/g' \\\n    | $SED -E 's/\\#.*$//g' \\\n    | $SED -E '/^-r.*|^jaxlib.*|^jax.*/d' > $FLTREQUIRE\n# Remove requirements that cannot be installed via conda\nPIPREQ=\"\"\nfor nc in $NOCONDA; do\n    # Escape [ and ] for use in regex\n    nc=$(echo $nc | $SED -E 's/(\\[|\\])/\\\\\\1/g')\n    # Add package to pip package list\n    PIPREQ=\"$PIPREQ \"$(grep \"$nc\" $FLTREQUIRE | $SED 's/\\\\//g')\n    # Remove package $nc from conda package list\n    $SED -i \"/^$nc.*\\$/d\" $FLTREQUIRE\ndone\n# Get list of requirements to be installed via conda\nCONDAREQ=$(cat $FLTREQUIRE | xargs)\n\nif [ \"$VERBOSE\" == \"yes\" ]; then\n    echo \"Create python $PYVER environment $ENVNM in conda installation\"\n    echo \"    $CONDAHOME\"\n    echo \"Packages to be installed via conda:\"\n    echo \"    $CONDAREQ\" | fmt -w 79\n    echo \"Packages to be installed via pip:\"\n    echo \"    jaxlib==$JLVER jax==$JXVER $PIPREQ\" | fmt -w 79\n    if [ \"$TEST\" == \"yes\" ]; then\n\texit 0\n    fi\nfi\n\nCONDAHOME=$(conda info --base)\nENVDIR=$CONDAHOME/envs/$ENVNM\nif [ -d \"$ENVDIR\" ]; then\n    echo \"Error: environment $ENVNM already exists\"\n    exit 9\nfi\n\nif [ \"$AGREE\" == \"no\" ]; then\n    RSTR=\"Confirm creation of conda environment $ENVNM with Python $PYVER\"\n    RSTR=\"$RSTR [y/N] \"\n    read -r -p \"$RSTR\" CNFRM\n    if [ \"$CNFRM\" != 'y' ] && [ \"$CNFRM\" != 'Y' ]; then\n\techo \"Cancelling environment creation\"\n\texit 10\n    fi\nelse\n    echo \"Creating conda environment $ENVNM with Python $PYVER\"\nfi\n\nif [ \"$AGREE\" == \"yes\" ]; then\n    CONDA_FLAGS=\"-y\"\nelse\n    CONDA_FLAGS=\"\"\nfi\n\n\n# Update conda, create new environment, and activate it\nconda update $CONDA_FLAGS -n base conda\nconda create $CONDA_FLAGS -n $ENVNM python=$PYVER\n\n# See https://stackoverflow.com/a/56155771/1666357\neval \"$(conda shell.bash hook)\"  # required to avoid errors re: `conda init`\nconda activate $ENVNM  
# Q: why not `source activate`? A: not always in the path\n\n# Add conda-forge channel\nconda config --append channels conda-forge\n\n# Install required conda packages (and extra useful packages)\nconda install $CONDA_FLAGS $CONDAREQ ipython\n\n# Utility ffmpeg is required by imageio for reading mp4 video files;\n# it can also be installed via the system package manager, e.g.\n#    sudo apt install ffmpeg\nif [ \"$(which ffmpeg)\" = '' ]; then\n    conda install $CONDA_FLAGS ffmpeg\nfi\n\n# Install jaxlib and jax\npip install --upgrade jaxlib==$JLVER jax==$JXVER\n\n# Install other packages that require installation via pip\npip install $PIPREQ\n\n# Warn if libopenblas-dev not installed on debian/ubuntu\nif [ \"$(which dpkg 2>/dev/null)\" ]; then\n    if [ ! \"$(dpkg -s libopenblas-dev 2>/dev/null)\" ]; then\n\techo \"Warning (debian/ubuntu): package libopenblas-dev,\"\n\techo \"which is required by bm3d, does not appear to be\"\n\techo \"installed; install using the command\"\n\techo \"   sudo apt install libopenblas-dev\"\n    fi\nfi\n\necho\necho \"Activate the conda environment with the command\"\necho \"  conda activate $ENVNM\"\necho \"The environment can be deactivated with the command\"\necho \"  conda deactivate\"\necho\necho \"JAX installed without GPU support. To enable GPU support, install a\"\necho \"version of jaxlib with CUDA support following the instructions at\"\necho \"   https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu\"\necho \"In most cases this just requires the command\"\necho \"   pip install -U \\\"jax[cuda12]\\\"\"\necho\necho \"ASTRA Toolbox installed without GPU support if this script was\"\necho \"run on a host without CUDA drivers installed. To enable GPU support,\"\necho \"uninstall and then reinstall the astra-toolbox conda package on a\"\necho \"host with CUDA drivers installed.\"\n\nexit 0\n"
  },
  {
    "path": "misc/gpu/README.rst",
    "content": "GPU Utility Scripts\n===================\n\nThese scripts are intended for debugging and managing JAX use of GPUs:\n\n- ``availgpu.py``: Automatically recommend a setting of the ``CUDA_VISIBLE_DEVICES`` environment variable that excludes GPUs that are already in use.\n- ``envinfo.py``: An aid to debugging JAX GPU access.\n"
  },
  {
    "path": "misc/gpu/availgpu.py",
    "content": "#!/usr/bin/env python\n\n# Determine which GPUs are available for use and recommend CUDA_VISIBLE_DEVICES\n# setting if any are already in use.\n\n# pylint: disable=missing-module-docstring\n\n\nimport GPUtil\n\nprint(\"GPU utilization\")\nGPUtil.showUtilization()\n\ndevIDs = GPUtil.getAvailable(\n    order=\"first\", limit=65536, maxLoad=0.1, maxMemory=0.1, includeNan=False\n)\n\nNgpu = len(GPUtil.getGPUs())\nif len(devIDs) == Ngpu:\n    print(f\"All {Ngpu} GPUs available for use\")\nelse:\n    print(f\"Only {len(devIDs)} of {Ngpu} GPUs available for use\")\n    print(\"To avoid attempting to use GPUs already in use, run the command\")\n    print(f\"    export CUDA_VISIBLE_DEVICES={','.join(map(str, devIDs))}\")\n"
  },
  {
    "path": "misc/gpu/envinfo.py",
    "content": "#!/usr/bin/env python\n\n# Print host and environment information. Useful for determining whether\n# a Python host has available GPUs, and if so, whether the JAX installation\n# is able to make use of them.\n\n# pylint: disable=missing-module-docstring\n\nimport sys\n\nmissing = []\n\ntry:\n    import psutil\n\n    have_psutil = True\nexcept ImportError:\n    have_psutil = False\n    missing.append(\"psutil\")\n\ntry:\n    import GPUtil\n\n    have_gputil = True\nexcept ImportError:\n    have_gputil = False\n    missing.append(\"gputil\")\n\nimport jax\n\nimport jaxlib\n\ntry:\n    import scico\n\n    have_scico = True\nexcept ImportError:\n    scico = None\n    have_scico = False\n    missing.append(\"scico\")\n\n\nif missing:\n    print(\"Some output not available due to missing modules: \" + \", \".join(missing))\n\npyver = \".\".join([f\"{v}\" for v in sys.version_info[0:3]])\nprint(f\"Python version: {pyver}\")\nprint(\"Packages:\")\npackages = [jaxlib, jax, scico]\nfor p in packages:\n    if hasattr(p, \"__version__\") and hasattr(p, \"__name__\"):\n        v = getattr(p, \"__version__\")\n        n = getattr(p, \"__name__\")\n        print(f\"    {n:15s} {v}\")\n\nif have_psutil:\n    print(f\"Number of CPU cores: {psutil.cpu_count(logical=False)}\")\n\nif have_gputil:\n    if GPUtil.getAvailable():\n        print(\"GPUs:\")\n        for gpu in GPUtil.getGPUs():\n            print(f\"    {gpu.id:2d}  {gpu.name:10s}  {gpu.memoryTotal} MB RAM\")\n    else:\n        print(\"No GPUs available\")\n\nsys.stderr = open(\"/dev/null\", \"w\")  # suppress annoying jax warning\nnumdev = jax.device_count()\nif jax.devices()[0].device_kind == \"cpu\":\n    print(\"No GPUs available to JAX (JAX device is CPU)\")\nelse:\n    print(f\"Number of GPUs available to JAX: {numdev}\")\n"
  },
  {
    "path": "misc/pytest/README.rst",
    "content": "Specialized Pytest Usage\n========================\n\nThese scripts support specialized ``pytest`` usage:\n\n- ``pytest_cov.sh``: This script runs ``scico`` unit tests using the ``pytest-cov`` plugin for test coverage analysis.\n- ``pytest_fast.sh``: This script runs ``pytest`` tests in parallel using the ``pytest-xdist`` plugin. Some tests (those that do not function correctly when run in parallel) are run separately.\n- ``pytest_time.sh``: This script runs each ``scico`` unit test module and lists them all in order of decreasing run time.\n\nAll of these scripts must be run from the repository root directory.\n"
  },
  {
    "path": "misc/pytest/pytest_cov.sh",
    "content": "#!/usr/bin/env bash\n\n# This script runs scico unit tests using the pytest-cov plugin for test\n# coverage analysis. It must be run from the repository root directory.\n\nplugin=\"pytest-cov\"\nif ! pytest -VV | grep -o $plugin > /dev/null; then\n    echo Required pytest plugin $plugin not installed\n    exit 1\nfi\n\npytest --cov=scico --cov-report html\n\necho \"To view the report, open htmlcov/index.html in a web browser.\"\n\nexit 0\n"
  },
  {
    "path": "misc/pytest/pytest_fast.sh",
    "content": "#!/usr/bin/env bash\n\n# This script runs pytest tests in parallel using the pytest-xdist plugin.\n# Some tests that do not function correctly when run in parallel are run\n# separately. It must be run from the repository root directory.\n\nplugin=\"pytest-xdist\"\nif ! pytest -VV | grep -o $plugin > /dev/null; then\n    echo Required pytest plugin $plugin not installed\n    exit 1\nfi\n\npytest --deselect scico/test/test_ray_tune.py \\\n       --deselect scico/test/functional/test_core.py -x -n 2\npytest -x scico/test/test_ray_tune.py scico/test/functional/test_core.py\n\nexit 0\n"
  },
  {
    "path": "misc/pytest/pytest_time.sh",
    "content": "#!/usr/bin/env bash\n\n# This script runs each scico unit test module and lists them all in order\n# of decreasing run time. It must be run from the repository root directory.\n\ntmp=/tmp/pytest_time.$$\nrm -f $tmp\nfor f in $(find scico/test -name \"test_*.py\"); do\n    tstr=$(/usr/bin/time -p pytest -qqq --disable-warnings $f 2>&1 | tail -4)\n    # Warning does not work in OSX bash\n    if grep -q \"Command exited with non-zero status\" <<<\"$tstr\"; then\n\techo \"WARNING: test failure in $f\" >&2\n    fi\n    t=$(grep \"^real\" <<<\"$tstr\" | grep -o -E \"[0-9\\.]*$\")\n    printf \"%6.2f  %s\\n\" $t $f >> $tmp\ndone\nsort -r -n $tmp\nrm $tmp\n\nexit 0\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n\n[tool.black]\nline-length = 100\ntarget-version = ['py312']\ninclude = '\\.pyi?$'\nexclude = '''\n\n(\n  /(\n      \\.eggs         # exclude a few common directories in the\n    | \\.git          # root of the project\n    | \\.hg\n    | \\.mypy_cache\n    | \\.tox\n    | \\.venv\n    | _build\n    | buck-out\n    | build\n    | dist\n  )/\n  | foo.py           # also separately exclude a file named foo.py in\n                     # the root of the project\n)\n'''\n\n[tool.isort]\nprofile = \"black\"\nmulti_line_output = 3\nknown_jax = ['jax']\nknown_numpy = ['numpy']\nsections = ['FUTURE', 'STDLIB', 'NUMPY', 'JAX', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER']\nsrc_paths = [\"scico\", \"examples/scripts\"]\n\n[tool.mypy]\npython_version = 3.12\ndisable_error_code = ['attr-defined']\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\ntestpaths = scico/test docs\naddopts = --doctest-glob=\"*rst\"\ndoctest_optionflags = NORMALIZE_WHITESPACE NUMBER\nfilterwarnings =\n    ignore::DeprecationWarning:.*pkg_resources.*\n    ignore::DeprecationWarning:.*hyperopt.*\n    ignore::DeprecationWarning:.*flax.*\n    ignore::DeprecationWarning:.*.tensorboardx.*\n    ignore::DeprecationWarning:.*xdesign.*\n    ignore:.*pkg_resources.*:DeprecationWarning\n    ignore:.*imp module.*:DeprecationWarning\n"
  },
  {
    "path": "requirements.txt",
    "content": "typing_extensions\nnumpy>=2.0\nscipy>=1.13\nimageio>=2.17\ntifffile\nmatplotlib\njaxlib>=0.5.0,<=0.10.0\njax>=0.5.0,<=0.10.0\nflax>=0.8.0,<=0.12.7\npyabel>=0.9.1\n"
  },
  {
    "path": "scico/__init__.py",
    "content": "# Copyright (C) 2021-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Scientific Computational Imaging COde (SCICO) is a Python package for\nsolving the inverse problems that arise in scientific imaging applications.\n\"\"\"\n\n__version__ = \"0.0.8.dev0\"\n\nimport logging\nimport sys\n\n# isort: off\n\n# Suppress jax device warning. See https://github.com/google/jax/issues/6805\nlogging.getLogger(\"jax._src.xla_bridge\").addFilter(  # jax 0.4.8 and later\n    logging.Filter(\"No GPU/TPU found, falling back to CPU.\")\n)\n\n# isort: on\n\nimport jax\nfrom jax import custom_jvp, custom_vjp, hessian, jacfwd, jvp, linearize, vjp\n\nimport jaxlib\n\nfrom . import numpy\nfrom ._core import *\nfrom ._core import __all__ as _core_all\n\n# See https://github.com/google/jax/issues/19444\njax.config.update(\"jax_default_matmul_precision\", \"highest\")\n\n__all__ = _core_all + [\n    \"custom_jvp\",\n    \"custom_vjp\",\n    \"hessian\",\n    \"jacfwd\",\n    \"jvp\",\n    \"linearize\",\n    \"vjp\",\n]\n\n# Imported items in __all__ appear to originate in top-level functional module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/_core.py",
    "content": "# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Extensions of core jax functions, including tools for automatic differentiation\nand shape evaluation.\"\"\"\n\nimport sys\nfrom typing import Any, Callable, Optional, Sequence, Tuple, Union\n\nimport jax\nfrom jax.tree_util import tree_map\n\nimport scico.numpy\nimport scico.numpy.util\nimport scico.util\n\n__all__ = [\n    \"cvjp\",\n    \"eval_shape\",\n    \"grad\",\n    \"jacrev\",\n    \"linear_adjoint\",\n    \"linear_transpose\",\n    \"value_and_grad\",\n]\n\n\ndef _append_jax_docs(fn, jaxfn=None):\n    \"\"\"Append the jax function docs.\n\n    Given wrapper function `fn`, concatenate its docstring with the\n    docstring of the wrapped jax function.\n    \"\"\"\n\n    name = fn.__name__\n    if jaxfn is None:\n        jaxfn = getattr(jax, name)\n    doc = \"  \" + fn.__doc__.replace(\"\\n    \", \"\\n  \")  # deal with indentation differences\n    jaxdoc = \"\\n\".join(jaxfn.__doc__.split(\"\\n\")[2:])  # strip initial lines\n    return doc + f\"\\n  Docstring for :func:`jax.{name}`:\\n\\n\" + jaxdoc\n\n\ndef _convert_ba_dts(arg: Any) -> Any:\n    \"\"\"Convert a ShapeDtypeStruct with nested shape into a BlockArray\n    of ShapeDtypeStruct.\n    \"\"\"\n    if isinstance(arg, jax.ShapeDtypeStruct) and scico.numpy.util.is_nested(arg.shape):\n        return scico.numpy.BlockArray(\n            [jax.ShapeDtypeStruct(blk_shape, dtype=arg.dtype) for blk_shape in arg.shape]\n        )\n    else:\n        return arg\n\n\ndef eval_shape(fun: Callable, *args, **kwargs) -> Any:\n    \"\"\"Compute the shape and dtype of a function without executing it.\n\n    Compute the shape and dtype of a function without executing it, via\n    a call to :func:`jax.eval_shape`, with ``args`` and ``kwargs`` 
mapped\n    to handle :class:`jax.ShapeDtypeStruct` objects with nested shapes\n    corresponding to :class:`.BlockArray` objects.\n    \"\"\"\n    mapped_args = jax.tree_util.tree_map(_convert_ba_dts, args)\n    mapped_kwargs = jax.tree_util.tree_map(_convert_ba_dts, kwargs)\n    return jax.eval_shape(fun, *mapped_args, **mapped_kwargs)\n\n\ndef grad(\n    fun: Callable,\n    argnums: Union[int, Sequence[int]] = 0,\n    has_aux: bool = False,\n    holomorphic: bool = False,\n    allow_int: bool = False,\n) -> Callable:\n    \"\"\"Create a function that evaluates the gradient of `fun`.\n\n    :func:`scico.grad` differs from :func:`jax.grad` in that the output\n    is conjugated.\n    \"\"\"\n\n    jax_grad = jax.grad(\n        fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int\n    )\n\n    def conjugated_grad_aux(*args, **kwargs):\n        jg, aux = jax_grad(*args, **kwargs)\n        return tree_map(jax.numpy.conj, jg), aux\n\n    def conjugated_grad(*args, **kwargs):\n        jg = jax_grad(*args, **kwargs)\n        return tree_map(jax.numpy.conj, jg)\n\n    return conjugated_grad_aux if has_aux else conjugated_grad\n\n\ndef value_and_grad(\n    fun: Callable,\n    argnums: Union[int, Sequence[int]] = 0,\n    has_aux: bool = False,\n    holomorphic: bool = False,\n    allow_int: bool = False,\n) -> Callable[..., Tuple[Any, Any]]:\n    \"\"\"Create a function that evaluates both `fun` and its gradient.\n\n    :func:`scico.value_and_grad` differs from :func:`jax.value_and_grad`\n    in that the gradient is conjugated.\n    \"\"\"\n    jax_val_grad = jax.value_and_grad(\n        fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int\n    )\n\n    def conjugated_value_and_grad_aux(*args, **kwargs):\n        (value, aux), jg = jax_val_grad(*args, **kwargs)\n        conj_grad = tree_map(jax.numpy.conj, jg)\n        return (value, aux), conj_grad\n\n    def conjugated_value_and_grad(*args, **kwargs):\n 
       value, jax_grad = jax_val_grad(*args, **kwargs)\n        conj_grad = tree_map(jax.numpy.conj, jax_grad)\n        return value, conj_grad\n\n    return conjugated_value_and_grad_aux if has_aux else conjugated_value_and_grad\n\n\ndef linear_transpose(fun: Callable, *primals) -> Callable:\n    \"\"\"Transpose a function that is guaranteed to be linear.\n\n    :func:`scico.linear_transpose` differs from :func:`jax.linear_transpose`\n    in that it correctly handles primals consisting of\n    :class:`jax.ShapeDtypeStruct` objects with nested shapes, i.e.\n    corresponding to :class:`.BlockArray` shapes.\n    \"\"\"\n    mapped_primals = jax.tree_util.tree_map(_convert_ba_dts, primals)\n    return jax.linear_transpose(fun, *mapped_primals)\n\n\ndef linear_adjoint(fun: Callable, *primals) -> Callable:\n    \"\"\"Conjugate transpose a function that is guaranteed to be linear.\n\n    :func:`scico.linear_adjoint` differs from :func:`jax.linear_transpose`\n    for complex inputs in that the conjugate transpose (adjoint) of `fun`\n    is returned. 
:func:`scico.linear_adjoint` is identical to\n    :func:`jax.linear_transpose` for real-valued primals.\n    \"\"\"\n\n    def conj_fun(*primals):\n        conj_primals = tree_map(jax.numpy.conj, primals)\n        return tree_map(jax.numpy.conj, fun(*(conj_primals)))\n\n    return linear_transpose(conj_fun, *primals)\n\n\ndef jacrev(\n    fun: Callable,\n    argnums: Union[int, Sequence[int]] = 0,\n    holomorphic: bool = False,\n    allow_int: bool = False,\n) -> Callable:\n    \"\"\"Jacobian of `fun` evaluated row-by-row using reverse-mode AD.\n\n    :func:`scico.jacrev` differs from :func:`jax.jacrev` in that the\n    output is conjugated.\n    \"\"\"\n\n    jax_jacrev = jax.jacrev(fun=fun, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int)\n\n    def conjugated_jacrev(*args, **kwargs):\n        tmp = jax_jacrev(*args, **kwargs)\n        return tree_map(jax.numpy.conj, tmp)\n\n    return conjugated_jacrev\n\n\ndef cvjp(fun: Callable, *primals, jidx: Optional[int] = None) -> Tuple[Tuple[Any, ...], Callable]:\n    r\"\"\"Compute a vector-Jacobian product with conjugate transpose.\n\n    Compute the product :math:`[J(\\mb{x})]^H \\mb{v}` where\n    :math:`[J(\\mb{x})]` is the Jacobian of function `fun` evaluated at\n    :math:`\\mb{x}`. Instead of directly evaluating the product, a\n    function is returned that takes :math:`\\mb{v}` as an argument. 
If\n    `fun` has multiple positional parameters, the Jacobian can be taken\n    with respect to only one of them by setting the `jidx` parameter of\n    this function to the positional index of that parameter.\n\n    Args:\n        fun: Function for which the Jacobian is implicitly computed.\n        primals: Sequence of values at which the Jacobian is\n           evaluated, with length equal to the number of positional\n           arguments of `fun`.\n        jidx: Index of the positional parameter of `fun` with respect\n           to which the Jacobian is taken.\n\n    Returns:\n        A pair `(primals_out, conj_vjp)` where `primals_out` is the\n        output of `fun` evaluated at `primals`, i.e. `primals_out\n        = fun(*primals)`, and `conj_vjp` is a function that computes the\n        product of the conjugate (Hermitian) transpose of the Jacobian of\n        `fun` and its argument. If the `jidx` parameter is an integer,\n        then the Jacobian is only taken with respect to the corresponding\n        positional parameter of `fun`.\n    \"\"\"\n\n    if jidx is None:\n        primals_out, fun_vjp = jax.vjp(fun, *primals)\n    else:\n        fixidx = tuple(range(0, jidx)) + tuple(range(jidx + 1, len(primals)))\n        fixprm = primals[0:jidx] + primals[jidx + 1 :]\n        pfun = scico.util.partial(fun, fixidx, *fixprm)\n        primals_out, fun_vjp = jax.vjp(pfun, primals[jidx])\n\n    def conj_vjp(tangent):\n        return jax.tree_util.tree_map(jax.numpy.conj, fun_vjp(tangent.conj()))\n\n    return primals_out, conj_vjp\n\n\n# Append docstring from original jax function\nfor name in __all__:\n    if name == \"cvjp\":\n        continue\n    func = getattr(sys.modules[__name__], name)\n    jaxfn = jax.linear_transpose if name == \"linear_adjoint\" else None\n    func.__doc__ = _append_jax_docs(func, jaxfn=jaxfn)\n"
  },
  {
    "path": "scico/_version.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Support functions for determining the package version.\"\"\"\n\nimport os\nimport re\nfrom ast import parse\nfrom subprocess import PIPE, Popen\nfrom typing import Any, Optional, Tuple, Union\n\n\ndef root_init_path() -> str:  # pragma: no cover\n    \"\"\"Get the path to the package root `__init__.py` file.\n\n    Returns:\n       Path to the package root `__init__.py` file.\n    \"\"\"\n    return os.path.join(os.path.dirname(__file__), \"__init__.py\")\n\n\ndef variable_assign_value(path: str, var: str) -> Any:\n    \"\"\"Get variable initialization value from a Python file.\n\n    Args:\n        path: Path of Python file.\n        var: Name of variable.\n\n    Returns:\n        Value to which variable `var` is initialized.\n\n    Raises:\n        RuntimeError: If the statement initializing variable `var` is not\n           found.\n    \"\"\"\n    with open(path) as f:\n        try:\n            # See https://stackoverflow.com/a/30471662\n            value_obj = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value  # type: ignore\n            value = value_obj.value  # type: ignore\n        except StopIteration:\n            raise RuntimeError(f\"Could not find initialization of variable {var}\")\n    return value\n\n\ndef init_variable_assign_value(var: str) -> Any:  # pragma: no cover\n    \"\"\"Get variable initialization value from package `__init__.py` file.\n\n    Args:\n        var: Name of variable.\n\n    Returns:\n        Value to which variable `var` is initialized.\n\n    Raises:\n        RuntimeError: If the statement initializing variable `var` is not\n           found.\n    \"\"\"\n    return variable_assign_value(root_init_path(), 
var)\n\n\ndef current_git_hash() -> Optional[str]:  # nosec  pragma: no cover\n    \"\"\"Get current short git hash.\n\n    Returns:\n       Short git hash of current commit, or ``None`` if no git repo found.\n    \"\"\"\n    process = Popen([\"git\", \"rev-parse\", \"--short\", \"HEAD\"], shell=False, stdout=PIPE, stderr=PIPE)\n    git_hash: Optional[str] = process.communicate()[0].strip().decode(\"utf-8\")\n    if git_hash == \"\":\n        git_hash = None\n    return git_hash\n\n\ndef package_version(split: bool = False) -> Union[str, Tuple[str, str]]:  # pragma: no cover\n    \"\"\"Get current package version.\n\n    Args:\n        split: Flag indicating whether to return the package version as a\n           single string or split into a tuple of components.\n\n    Returns:\n        Package version string or tuple of strings.\n    \"\"\"\n    version = init_variable_assign_value(\"__version__\")\n    # don't extend purely numeric version numbers, possibly ending with post<n>\n    if re.match(r\"^[0-9\\.]+(post[0-9]+)?$\", version):\n        git_hash = None\n    else:\n        git_hash = current_git_hash()\n    if git_hash:\n        git_hash = \"+\" + git_hash\n    else:\n        git_hash = \"\"\n    if split:\n        version = (version, git_hash)\n    else:\n        version = version + git_hash\n    return version\n"
  },
  {
    "path": "scico/data/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Data files for usage examples.\"\"\"\n\nimport os.path\nfrom typing import Optional\n\nfrom imageio.v3 import imread\n\nimport scico.numpy as snp\n\n__all__ = [\"kodim23\"]\n\n\ndef _imread(filename: str, path: Optional[str] = None, asfloat: bool = False) -> snp.Array:\n    \"\"\"Read an image from disk.\n\n    Args:\n        filename: Base filename (i.e. without path) of image file.\n        path: Path to directory containing the image file.\n        asfloat: Flag indicating whether the returned image should be\n          converted to :attr:`~numpy.float32` dtype with a range [0, 1].\n\n    Returns:\n       Image data array.\n    \"\"\"\n\n    if path is None:\n        path = os.path.join(os.path.dirname(__file__), \"examples\")\n    im = imread(os.path.join(path, filename))\n    if asfloat:\n        im = im.astype(snp.float32) / 255.0\n    return im\n\n\ndef kodim23(asfloat: bool = False) -> snp.Array:\n    \"\"\"Return the `kodim23` test image.\n\n    Args:\n        asfloat: Flag indicating whether the returned image should be\n          converted to :attr:`~numpy.float32` dtype with a range [0, 1].\n\n    Returns:\n       Image data array.\n    \"\"\"\n\n    return _imread(\"kodim23.png\", asfloat=asfloat)\n\n\ndef _flax_data_path(filename: str) -> str:\n    \"\"\"Get the full filename of a flax data file.\n\n    Args:\n        filename: Base filename (i.e. without path) of data file.\n\n    Returns:\n       Full filename, with path, of data file.\n    \"\"\"\n\n    return os.path.join(os.path.dirname(__file__), \"flax\", filename)\n"
  },
  {
    "path": "scico/denoiser.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Interfaces to standard denoisers.\"\"\"\n\nfrom typing import Any, Optional, Union\n\nimport numpy as np\n\nimport jax\n\ntry:\n    import bm3d as tubm3d\nexcept ImportError:\n    have_bm3d = False\n    BM3DProfile = Any\nelse:\n    have_bm3d = True\n    from bm3d.profiles import BM3DProfile  # type: ignore\n\ntry:\n    import bm4d as tubm4d\nexcept ImportError:\n    have_bm4d = False\n    BM4DProfile = Any\nelse:\n    have_bm4d = True\n    from bm4d.profiles import BM4DProfile  # type: ignore\n\nimport scico.numpy as snp\nfrom scico.data import _flax_data_path\nfrom scico.flax import DnCNNNet, FlaxMap, load_variables\n\n\ndef bm3d(x: snp.Array, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = \"np\"):\n    r\"\"\"An interface to the BM3D denoiser :cite:`dabov-2008-image`.\n\n    BM3D denoising is performed using the\n    `code <https://pypi.org/project/bm3d>`__ released with\n    :cite:`makinen-2019-exact`. Since this package is an interface\n    to compiled C code, JAX features such as automatic differentiation\n    and support for GPU devices are not available.\n\n    Args:\n        x: Input image. Expected to be a 2D array (gray-scale denoising)\n            or 3D array (color denoising). 
Higher-dimensional arrays are\n            tolerated only if the additional dimensions are singletons.\n            For color denoising, the color channel is assumed to be in\n            the last non-singleton dimension.\n        sigma: Noise parameter.\n        is_rgb: Flag indicating use of BM3D with a color transform.\n            Default: ``False``.\n        profile: Parameter configuration for BM3D.\n\n    Returns:\n        Denoised output.\n    \"\"\"\n    if not have_bm3d:\n        raise RuntimeError(\"Package bm3d is required for use of this function.\")\n\n    if is_rgb is True:\n\n        def bm3d_eval(x: snp.Array, sigma: float):\n            return tubm3d.bm3d_rgb(x, sigma, profile=profile)\n\n    else:\n\n        def bm3d_eval(x: snp.Array, sigma: float):\n            return tubm3d.bm3d(x, sigma, profile=profile)\n\n    if snp.util.is_complex_dtype(x.dtype):\n        raise TypeError(f\"BM3D requires real-valued inputs, got {x.dtype}.\")\n\n    # Support arrays with more than three axes when the additional axes are singletons.\n    x_in_shape = x.shape\n\n    if isinstance(x.ndim, tuple) or x.ndim < 2:\n        raise ValueError(\n            f\"BM3D requires two-dimensional or three-dimensional inputs; got ndim = {x.ndim}.\"\n        )\n\n    # This check is also performed inside the BM3D call, but due to the callback,\n    # no exception is raised and the program will crash with no traceback.\n    # NOTE: if BM3D is extended to allow for different profiles, the block size must be\n    #       updated; this presumes 'np' profile (bs=8)\n    if profile == \"np\" and np.min(x.shape[:2]) < 8:\n        raise ValueError(\n            \"Two leading dimensions of input cannot be smaller than block size \"\n            f\"(8); got image size = {x.shape}.\"\n        )\n\n    if x.ndim > 3:\n        if all(k == 1 for k in x.shape[3:]):\n            x = x.squeeze()\n        else:\n            raise ValueError(\n                \"Arrays with more than three axes 
are only supported when \"\n                \" the additional axes are singletons.\"\n            )\n\n    y = jax.pure_callback(\n        lambda args: bm3d_eval(*args).astype(x.dtype),\n        jax.ShapeDtypeStruct(x.shape, x.dtype),\n        (x, sigma),\n    )\n\n    # undo squeezing, if neccessary\n    y = y.reshape(x_in_shape)\n\n    return y\n\n\ndef bm4d(x: snp.Array, sigma: float, profile: Union[BM4DProfile, str] = \"np\"):\n    r\"\"\"An interface to the BM4D denoiser :cite:`maggioni-2012-nonlocal`.\n\n    BM4D denoising is performed using the\n    `code <https://pypi.org/project/bm4d/>`__ released by the authors of\n    :cite:`maggioni-2012-nonlocal`. Since this package is an interface\n    to compiled C code, JAX features such as automatic differentiation\n    and support for GPU devices are not available.\n\n    Args:\n        x: Input image. Expected to be a 3D array. Higher-dimensional\n            arrays are tolerated only if the additional dimensions are\n            singletons.\n        sigma: Noise parameter.\n        profile: Parameter configuration for BM4D.\n\n    Returns:\n        Denoised output.\n    \"\"\"\n    if not have_bm4d:\n        raise RuntimeError(\"Package bm4d is required for use of this function.\")\n\n    def bm4d_eval(x: snp.Array, sigma: float):\n        return tubm4d.bm4d(x, sigma, profile=profile)\n\n    if snp.util.is_complex_dtype(x.dtype):\n        raise TypeError(f\"BM4D requires real-valued inputs, got {x.dtype}.\")\n\n    # Support arrays with more than three axes when the additional axes are singletons.\n    x_in_shape = x.shape\n\n    if isinstance(x.ndim, tuple) or x.ndim < 3:\n        raise ValueError(f\"BM4D requires three-dimensional inputs; got ndim = {x.ndim}.\")\n\n    # This check is also performed inside the BM4D call, but due to the callback,\n    # no exception is raised and the program will crash with no traceback.\n    # NOTE: if BM4D is extended to allow for different profiles, the block size must be\n  
  #       updated; this presumes 'np' profile (bs=8)\n    if profile == \"np\" and np.min(x.shape[:3]) < 8:\n        raise ValueError(\n            \"Three leading dimensions of input cannot be smaller than block size \"\n            f\"(8); got image size = {x.shape}.\"\n        )\n\n    if x.ndim > 3:\n        if all(k == 1 for k in x.shape[3:]):\n            x = x.squeeze()\n        else:\n            raise ValueError(\n                \"Arrays with more than three axes are only supported when \"\n                \" the additional axes are singletons.\"\n            )\n\n    y = jax.pure_callback(\n        lambda args: bm4d_eval(*args).astype(x.dtype),\n        jax.ShapeDtypeStruct(x.shape, x.dtype),\n        (x, sigma),\n    )\n\n    # undo squeezing, if neccessary\n    y = y.reshape(x_in_shape)\n\n    return y\n\n\nclass DnCNN(FlaxMap):\n    \"\"\"Flax implementation of the DnCNN denoiser.\n\n    A flax implementation of the DnCNN denoiser :cite:`zhang-2017-dncnn`.\n    Note that :class:`.DnCNNNet` represents an untrained form of the\n    generic DnCNN CNN structure, while this class represents a trained\n    form with six or seventeen layers.\n\n    The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not\n    have a noise level input. This implementation of DnCNN also supports\n    a custom variant that includes a noise standard deviation input,\n    `sigma`, which is included in the network as an additional channel\n    consisting of a constant array with value `sigma`. This network was\n    trained with image data on the range [0, 1], and with noise standard\n    deviations ranging from 0.0 to 0.2. 
It is worth noting that DRUNet\n    :cite:`zhang-2021-plug`, another recent approach to including a noise\n    level input in a CNN denoiser, is based on a substantially different\n    network architecture.\n    \"\"\"\n\n    def __init__(self, variant: str = \"6M\"):\n        \"\"\"\n        Note that all DnCNN models are trained for single-channel image\n        input. Multi-channel input is supported via independent denoising\n        of each channel. Input images are expected to have pixel values\n        in the range [0, 1].\n\n        Args:\n            variant: Identify the DnCNN model to be used. Options are\n                '6L', '6M' (default), '6H', '6N', '17L', '17M', '17H',\n                and '17N', where the integer indicates the number of\n                layers in the network, and the postfix indicates the\n                training noise standard deviation (with respect to data\n                in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10,\n                H (high) = 0.20, or N indicating that a noise standard\n                deviation input, `sigma`, is available.\n        \"\"\"\n\n        self.variant = variant\n\n        if variant not in [\"6L\", \"6M\", \"6H\", \"17L\", \"17M\", \"17H\", \"6N\", \"17N\"]:\n            raise ValueError(f\"Invalid value {variant} of parameter variant.\")\n        if variant[0] == \"6\":\n            nlayer = 6\n        else:\n            nlayer = 17\n        channels = 2 if variant in [\"6N\", \"17N\"] else 1\n\n        if variant in [\"6N\", \"17N\"]:\n            self.is_blind = False\n        else:\n            self.is_blind = True\n\n        model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32)\n        variables = load_variables(_flax_data_path(\"dncnn%s.mpk\" % variant))\n        super().__init__(model, variables)\n\n    def __call__(self, x: snp.Array, sigma: Optional[float] = None) -> snp.Array:\n        r\"\"\"Apply DnCNN denoiser.\n\n        Args:\n            
x: Input array.\n            sigma: Noise standard deviation (for variants `6N` and `17N`).\n\n        Returns:\n            Denoised output.\n        \"\"\"\n        if sigma is not None and self.is_blind:\n            raise ValueError(\n                \"A non-default value for the sigma parameter may \"\n                \"only be specified when the variant is 6N or 17N\"\n                f\"; got variant = {self.variant}.\"\n            )\n\n        if sigma is None and not self.is_blind:\n            raise ValueError(\n                \"A float value must be specified for the sigma \"\n                \"parameter when the variant is 6N or 17N.\"\n            )\n\n        if snp.util.is_complex_dtype(x.dtype):\n            raise TypeError(f\"DnCNN requries real-valued inputs, got {x.dtype}.\")\n\n        if isinstance(x.ndim, tuple) or x.ndim < 2:\n            raise ValueError(\n                \"DnCNN requires two-dimensional (M, N) or three-dimensional (M, N, C)\"\n                f\" inputs; got ndim = {x.ndim}.\"\n            )\n\n        x_in_shape = x.shape\n        if x.ndim > 3:\n            if all(k == 1 for k in x.shape[3:]):\n                x = x.squeeze()\n            else:\n                raise ValueError(\n                    \"Arrays with more than three axes are only supported when\"\n                    \" the additional axes are singletons.\"\n                )\n\n        if x.ndim == 3:\n            y = snp.swapaxes(x, 0, -1)\n\n            if sigma is not None:\n                y = snp.stack([y, snp.ones_like(y) * sigma], -1)\n            else:\n                y = y[..., np.newaxis]\n\n            # swap channel axis to batch axis and add singleton axis at end\n            y = super().__call__(y)\n            # drop singleton axis and swap axes back to original positions\n            y = snp.swapaxes(y[..., 0], 0, -1)\n\n        else:\n            if sigma is not None:\n                x = snp.stack([x, snp.ones_like(x) * sigma], -1)\n     
           x = x[np.newaxis, ...]\n\n            y = super().__call__(x)\n\n            if sigma is not None:\n                y = y[0, ..., 0]\n\n        y = y.reshape(x_in_shape)\n\n        return y\n"
  },
  {
    "path": "scico/diagnostics.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Diagnostic information for iterative solvers.\"\"\"\n\nimport re\nimport warnings\nfrom collections import OrderedDict, namedtuple\nfrom typing import List, NamedTuple, Optional, Tuple, Union\n\nfrom scico.numpy.util import is_array\n\n\nclass IterationStats:\n    \"\"\"Display and record iterative algorithms statistics.\n\n    Display and record statistics related to convergence of iterative\n    algorithms.\n    \"\"\"\n\n    def __init__(\n        self,\n        fields: OrderedDict,\n        ident: Optional[dict] = None,\n        display: bool = False,\n        period: int = 1,\n        shift_cycles: bool = True,\n        overwrite: bool = True,\n        colsep: int = 2,\n    ):\n        \"\"\"\n        The `fields` parameter represents an OrderedDict (to ensure that\n        field order is retained) specifying field names for each value to\n        be inserted and a corresponding format string for when it is\n        displayed. When inserted values are printed in tabular form, the\n        field lengths are taken as the maxima of the header string\n        lengths and the field lengths embedded in the format strings (if\n        specified). For best results, the field lengths should be\n        manually specified based on knowledge of the ranges of values\n        that may be encountered. 
For example, for a '%e' format string,\n        the specified field length should be at least the precision (e.g.\n        '%.2e' specifies a precision of 2 places) plus 6 when only\n        positive values may be encountered, and plus 7 when negative values\n        may be encountered.\n\n        Args:\n            fields: A dictionary associating field names with format\n                strings for displaying the corresponding values.\n            ident: A dictionary associating field names\n                with corresponding valid identifiers for use within the\n                namedtuple used to record results. Defaults to ``None``.\n            display: Flag indicating whether results should be printed\n                to stdout. Defaults to ``False``.\n            period: Only display one result in every cycle of length\n                `period`.\n            shift_cycles: If ``True``, apply an offset to the iteration\n                count so that display cycles end at iterations 0,\n                `period`, 2 * `period`, etc. Otherwise, cycles end at\n                iterations `period` - 1, 2 * `period` - 1, etc.\n                (Iterations are counted from 0.)\n            overwrite: If ``True``, display all results, but each one\n                is overwritten by the next, except for one result per cycle.\n            colsep: Number of spaces separating fields in displayed\n                tables. 
Defaults to 2.\n\n        Raises:\n            TypeError: If the `fields` parameter is not a dict.\n        \"\"\"\n\n        # Parameter fields must be specified as an OrderedDict to ensure\n        # that field order is retained\n        if not isinstance(fields, dict):\n            raise TypeError(\"Argument 'fields' must be an instance of dict.\")\n        # Subsampling rate of results that are to be displayed\n        self.period: int = period\n        # Offset to iteration count for determining start of period\n        self.period_offset = 1 if shift_cycles else 0\n        # Flag indicating whether to display and overwrite, or not display at all\n        self.overwrite: bool = overwrite\n        # Number of spaces seperating fields in displayed tables\n        self.colsep: int = colsep\n        # Main list of inserted values\n        self.iterations: List = []\n        # Total length of header string in displayed tables\n        self.headlength: int = 0\n        # List of field names\n        self.fieldname: List[str] = []\n        # List of field format strings\n        self.fieldformat: List[str] = []\n        # List of lengths of each field in displayed tables\n        self.fieldlength: List[int] = []\n        # Names of fields in namedtuple used to record iteration values\n        self.tuplefields: List[str] = []\n        # Compile regex for decomposing format strings\n        fmre = re.compile(r\"%(\\+?-?)((?:\\d+)?)(\\.?)((?:\\d+)?)([a-z])\")\n        # Iterate over field names\n        for name in fields:\n            # Get format string and decompose it using compiled regex\n            fmt = fields[name]\n            fmtmatch = fmre.match(fmt)\n            if not fmtmatch:\n                raise ValueError(f\"Format string '{fmt}' could not be parsed.\")\n            fmflg, fmlen, fmdot, fmprc, fmtyp = fmtmatch.groups()\n            flen = len(fmt % 0)\n            # Warn if actual formatted length longer than specified field\n            # length, 
e.g. as in \"%4e\"\n            if fmlen != \"\" and flen > int(fmlen):\n                warnings.warn(\n                    f'Actual length {flen} of format \"{fmt}\" for field '\n                    f'\"{name}\" is longer than specified value {fmlen}',\n                    stacklevel=2,\n                )\n            # If the actual formatted length is less than that of the header\n            # string, insert a field length specifier to increase the\n            # length to that of the header string\n            if flen < len(name):\n                fmt = f\"%{fmflg}{len(name)}{fmdot}{fmprc}{fmtyp}\"\n                flen = len(name)\n            self.fieldname.append(name)\n            self.fieldformat.append(fmt)\n            self.fieldlength.append(flen)\n            self.headlength += flen + colsep\n\n            # If a distinct identifier is specified for this field, use it\n            # as the namedtuple identifier, otherwise compute it from the\n            # field name\n            if ident is not None and name in ident:\n                self.tuplefields.append(ident[name])\n            else:\n                # See https://stackoverflow.com/a/3305731\n                tfnm = re.sub(r\"\\W+|^(?=\\d)\", \"_\", name)\n                if tfnm[0] == \"_\":\n                    tfnm = tfnm[1:]\n                self.tuplefields.append(tfnm)\n\n        # Decrement head length to account for final colsep added\n        self.headlength -= colsep\n\n        # Construct namedtuple used to record values\n        self.IterTuple = namedtuple(\"IterationStatsTuple\", self.tuplefields)  # type: ignore\n\n        # Set up table header string display if requested\n        self.display = display\n        self.disphdr = None\n        if display:\n            self.disphdr = (\n                (\" \" * colsep).join(\n                    [\"%-*s\" % (fl, fn) for fl, fn in zip(self.fieldlength, self.fieldname)]\n                )\n                + \"\\n\"\n                + 
\"-\" * self.headlength\n            )\n\n    def insert(self, values: Union[List, Tuple]):\n        \"\"\"Insert a list of values for a single iteration.\n\n        Args:\n            values: Statistics for a single iteration.\n        \"\"\"\n\n        scalar_values = [v.item() if is_array(v) else v for v in values]\n        self.iterations.append(self.IterTuple(*scalar_values))\n\n        if self.display:\n            if self.disphdr is not None:\n                print(self.disphdr)\n                self.disphdr = None\n            if self.overwrite:\n                if (len(self.iterations) - self.period_offset) % self.period == 0:\n                    end = \"\\n\"\n                else:\n                    end = \"\\r\"\n                print((\" \" * self.colsep).join(self.fieldformat) % values, end=end)\n            else:\n                if (len(self.iterations) - self.period_offset) % self.period == 0:\n                    print((\" \" * self.colsep).join(self.fieldformat) % values)\n\n    def end(self):\n        \"\"\"Mark end of iterations.\n\n        This method should be called at the end of a set of iterations.\n        Its only function is to ensure that the displayed output is left\n        in an appropriate state when overwriting is active with a display\n        period other than unity.\n        \"\"\"\n        if (\n            self.display\n            and self.overwrite\n            and self.period > 1\n            and (len(self.iterations) - self.period_offset) % self.period\n        ):\n            print()\n\n    def history(self, transpose: bool = False) -> Union[List[NamedTuple], Tuple[List]]:\n        \"\"\"Retrieve record of all inserted iterations.\n\n        Args:\n            transpose: Flag indicating whether results should be returned\n                in \"transposed\" form, i.e. 
as a namedtuple of lists\n                rather than a list of namedtuples.\n\n        Returns:\n            list of namedtuple or namedtuple of lists: Record of all\n            inserted iterations.\n        \"\"\"\n\n        if transpose and self.iterations:\n            return self.IterTuple(\n                *[\n                    [self.iterations[m][n] for m in range(len(self.iterations))]\n                    for n in range(len(self.iterations[0]))\n                ]\n            )\n        return self.iterations\n"
  },
  {
    "path": "scico/examples.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utility functions used by example scripts.\"\"\"\n\nimport glob\nimport os\nimport tempfile\nimport zipfile\nfrom functools import partial\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nimport imageio.v3 as iio\n\nimport scico.numpy as snp\nfrom scico import random, util\nfrom scico.typing import Shape\nfrom scipy.io import loadmat\nfrom scipy.ndimage import zoom\n\n\ndef rgb2gray(rgb: np.ndarray) -> np.ndarray:\n    \"\"\"Convert an RGB image (or images) to grayscale.\n\n    Args:\n        rgb: RGB image as Nr x Nc x 3 or Nr x Nc x 3 x K array.\n\n    Returns:\n        Grayscale image as Nr x Nc or Nr x Nc x K array.\n    \"\"\"\n\n    shape: Union[Tuple[int, int, int], Tuple[int, int, int, int]]\n    if rgb.ndim == 3:\n        shape = (1, 1, 3)\n    else:\n        shape = (1, 1, 3, 1)\n    w = np.array([0.299, 0.587, 0.114], dtype=rgb.dtype).reshape(shape)\n    return np.sum(w * rgb, axis=2)\n\n\ndef volume_read(path: str, ext: str = \"tif\") -> np.ndarray:\n    \"\"\"Read a 3D volume from a set of files in the specified directory.\n\n    All files with extension `ext` (i.e. 
matching glob `*.ext`)\n    in directory `path` are assumed to be image files and are read.\n    The filenames are assumed to be such that their alphanumeric\n    ordering corresponds to their order as volume slices.\n\n    Args:\n        path: Path to directory containing the image files.\n        ext: Filename extension.\n\n    Returns:\n        Volume as a 3D array.\n    \"\"\"\n\n    slices = []\n    for file in sorted(glob.glob(os.path.join(path, \"*.\" + ext))):\n        image = iio.imread(file)\n        slices.append(image)\n    return np.dstack(slices)\n\n\ndef get_epfl_deconv_data(channel: int, path: str, verbose: bool = False):  # pragma: no cover\n    \"\"\"Download example data from EPFL Biomedical Imaging Group.\n\n    Download deconvolution problem data from EPFL Biomedical Imaging\n    Group. The downloaded data is converted to `.npz` format for\n    convenient access via :func:`numpy.load`. The converted data is saved\n    in a file `epfl_big_deconv_<channel>.npz` in the directory specified\n    by `path`.\n\n    Args:\n        channel: Channel number between 0 and 2.\n        path: Directory in which converted data is saved.\n        verbose: Flag indicating whether to print status messages.\n    \"\"\"\n\n    # data source URL and filenames\n    data_base_url = \"http://bigwww.epfl.ch/deconvolution/bio/\"\n    data_zip_files = [\"CElegans-CY3.zip\", \"CElegans-DAPI.zip\", \"CElegans-FITC.zip\"]\n    psf_zip_files = [\"PSF-\" + data for data in data_zip_files]\n\n    # ensure path directory exists\n    if not os.path.isdir(path):\n        raise ValueError(f\"Path {path} does not exist or is not a directory.\")\n\n    # create temporary directory\n    temp_dir = tempfile.TemporaryDirectory()\n    # download data and psf files for selected channel into temporary directory\n    for zip_file in (data_zip_files[channel], psf_zip_files[channel]):\n        if verbose:\n            print(f\"Downloading {zip_file} from {data_base_url}\")\n        data = 
util.url_get(data_base_url + zip_file)\n        f = open(os.path.join(temp_dir.name, zip_file), \"wb\")\n        f.write(data.read())\n        f.close()\n        if verbose:\n            print(\"Download complete\")\n\n    # unzip downloaded data into temporary directory\n    for zip_file in (data_zip_files[channel], psf_zip_files[channel]):\n        if verbose:\n            print(f\"Extracting content from zip file {zip_file}\")\n        with zipfile.ZipFile(os.path.join(temp_dir.name, zip_file), \"r\") as zip_ref:\n            zip_ref.extractall(temp_dir.name)\n\n    # read unzipped data files into 3D arrays and save as .npz\n    zip_file = data_zip_files[channel]\n    y = volume_read(os.path.join(temp_dir.name, zip_file[:-4]))\n    zip_file = psf_zip_files[channel]\n    psf = volume_read(os.path.join(temp_dir.name, zip_file[:-4]))\n\n    npz_file = os.path.join(path, f\"epfl_big_deconv_{channel}.npz\")\n    if verbose:\n        print(f\"Saving as {npz_file}\")\n    np.savez(npz_file, y=y, psf=psf)\n\n\ndef epfl_deconv_data(\n    channel: int, verbose: bool = False, cache_path: Optional[str] = None\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get deconvolution problem data from EPFL Biomedical Imaging Group.\n\n    If the data has previously been downloaded, it will be retrieved from\n    a local cache.\n\n    Args:\n        channel: Channel number between 0 and 2.\n        verbose: Flag indicating whether to print status messages.\n        cache_path: Directory in which downloaded data is cached. 
The\n           default is `~/.cache/scico/examples`, where `~` represents\n           the user home directory.\n\n    Returns:\n       tuple: A tuple (y, psf) containing:\n\n           - **y** : (np.ndarray): Blurred channel data.\n           - **psf** : (np.ndarray): Channel psf.\n    \"\"\"\n\n    # set default cache path if not specified\n    if cache_path is None:  # pragma: no cover\n        cache_path = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\")\n\n    # create cache directory and download data if not already present\n    npz_file = os.path.join(cache_path, f\"epfl_big_deconv_{channel}.npz\")\n    if not os.path.isfile(npz_file):  # pragma: no cover\n        if not os.path.isdir(cache_path):\n            os.makedirs(cache_path)\n        get_epfl_deconv_data(channel, path=cache_path, verbose=verbose)\n\n    # load data and return y and psf arrays converted to float32\n    npz = np.load(npz_file)\n    y = npz[\"y\"].astype(np.float32)\n    psf = npz[\"psf\"].astype(np.float32)\n    return y, psf\n\n\ndef get_ucb_diffusercam_data(path: str, verbose: bool = False):  # pragma: no cover\n    \"\"\"Download data from UC Berkeley Waller Lab diffusercam project.\n\n    Download deconvolution problem data from UC Berkeley Waller Lab\n    diffusercam project.  The downloaded data is converted to `.npz`\n    format for convenient access via :func:`numpy.load`.  
The\n    converted data is saved in a file `ucb_diffcam_data.npz` in\n    the directory specified by `path`.\n\n    Args:\n        path: Directory in which converted data is saved.\n        verbose: Flag indicating whether to print status messages.\n    \"\"\"\n\n    # data source URL, filenames, and request header\n    data_base_url = \"https://github.com/Waller-Lab/DiffuserCam/blob/master/example_data/\"\n    data_files = [\"example_psfs.mat\", \"example_raw.png\"]\n    headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64)\", \"Referer\": data_base_url}\n\n    # ensure path directory exists\n    if not os.path.isdir(path):\n        raise ValueError(f\"Path {path} does not exist or is not a directory.\")\n\n    # create temporary directory\n    temp_dir = tempfile.TemporaryDirectory()\n    # download data files into temporary directory\n    for data_file in data_files:\n        if verbose:\n            print(f\"Downloading {data_file} from {data_base_url}\")\n        data = util.url_get(data_base_url + data_file + \"?raw=true\", headers=headers)\n        f = open(os.path.join(temp_dir.name, data_file), \"wb\")\n        f.write(data.read())\n        f.close()\n        if verbose:\n            print(\"Download complete\")\n\n    # load data, normalize it, and save as npz\n    y = iio.imread(os.path.join(temp_dir.name, \"example_raw.png\"))\n    y = y.astype(np.float32)\n    y -= 100.0\n    y /= y.max()\n    mat = loadmat(os.path.join(temp_dir.name, \"example_psfs.mat\"))\n    psf = mat[\"psf\"].astype(np.float64)\n    psf -= 102.0\n    psf /= np.linalg.norm(psf, axis=(0, 1)).min()\n\n    # save as .npz\n    npz_file = os.path.join(path, \"ucb_diffcam_data.npz\")\n    if verbose:\n        print(f\"Saving as {npz_file}\")\n    np.savez(npz_file, y=y, psf=psf)\n\n\ndef ucb_diffusercam_data(\n    verbose: bool = False, cache_path: Optional[str] = None\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get example data from UC Berkeley Waller Lab diffusercam 
project.\n\n    If the data has previously been downloaded, it will be retrieved from\n    a local cache.\n\n    Args:\n        verbose: Flag indicating whether to print status messages.\n        cache_path: Directory in which downloaded data is cached. The\n           default is `~/.cache/scico/examples`, where `~` represents\n           the user home directory.\n\n    Returns:\n       tuple: A tuple (y, psf) containing:\n\n           - **y** : (np.ndarray): Measured image\n           - **psf** : (np.ndarray): Stack of psfs.\n    \"\"\"\n\n    # set default cache path if not specified\n    if cache_path is None:  # pragma: no cover\n        cache_path = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\")\n\n    # create cache directory and download data if not already present\n    npz_file = os.path.join(cache_path, \"ucb_diffcam_data.npz\")\n    if not os.path.isfile(npz_file):  # pragma: no cover\n        if not os.path.isdir(cache_path):\n            os.makedirs(cache_path)\n        get_ucb_diffusercam_data(path=cache_path, verbose=verbose)\n\n    # load data and return y and psf arrays converted to float32\n    npz = np.load(npz_file)\n    y = npz[\"y\"].astype(np.float32)\n    psf = npz[\"psf\"].astype(np.float64)\n    return y, psf\n\n\ndef downsample_volume(vol: np.ndarray, rate: int) -> np.ndarray:\n    \"\"\"Downsample a 3D array.\n\n    Downsample a 3D array. If the volume dimensions can be divided by\n    `rate`, this is achieved via averaging distinct `rate` x `rate` x\n    `rate` block in `vol`. 
Otherwise it is achieved via a call to\n    :func:`scipy.ndimage.zoom`.\n\n    Args:\n        vol: Input volume.\n        rate: Downsampling rate.\n\n    Returns:\n        Downsampled volume.\n    \"\"\"\n\n    if rate == 1:\n        return vol\n\n    if np.all([n % rate == 0 for n in vol.shape]):\n        vol = np.mean(np.reshape(vol, (-1, rate, vol.shape[1], vol.shape[2])), axis=1)\n        vol = np.mean(np.reshape(vol, (vol.shape[0], -1, rate, vol.shape[2])), axis=2)\n        vol = np.mean(np.reshape(vol, (vol.shape[0], vol.shape[1], -1, rate)), axis=3)\n    else:\n        vol = zoom(vol, 1.0 / rate)\n\n    return vol\n\n\ndef tile_volume_slices(x: np.ndarray, sep_width: int = 10) -> np.ndarray:\n    \"\"\"Make an image with tiled slices from an input volume.\n\n    Make an image with tiled `xy`, `xz`, and `yz` slices from an input\n    volume.\n\n    Args:\n        x: Input volume consisting of a 3D or 4D array. If the input is\n           4D, the final axis represents a channel index.\n        sep_width: Number of pixels separating the slices in the output\n           image.\n\n    Returns:\n        Image containing tiled slices.\n    \"\"\"\n\n    if x.ndim == 3:\n        fshape: Tuple[int, ...] = (x.shape[0], sep_width)\n    else:\n        fshape = (x.shape[0], sep_width, 3)\n    out = np.concatenate(\n        (\n            x[:, :, x.shape[2] // 2],\n            np.full(fshape, np.nan),\n            x[:, x.shape[1] // 2, :],\n        ),\n        axis=1,\n    )\n\n    if x.ndim == 3:\n        fshape0: Tuple[int, ...] = (sep_width, out.shape[1])\n        fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width)\n        trans: Tuple[int, ...] 
= (1, 0)\n\n    else:\n        fshape0 = (sep_width, out.shape[1], 3)\n        fshape1 = (x.shape[2], x.shape[2] + sep_width, 3)\n        trans = (1, 0, 2)\n    out = np.concatenate(\n        (\n            out,\n            np.full(fshape0, np.nan),\n            np.concatenate(\n                (\n                    x[x.shape[0] // 2, :, :].transpose(trans),\n                    np.full(fshape1, np.nan),\n                ),\n                axis=1,\n            ),\n        ),\n        axis=0,\n    )\n\n    out = np.where(np.isnan(out), np.nanmax(out), out)\n\n    return out\n\n\ndef gaussian(shape: Shape, sigma: Optional[np.ndarray] = None) -> np.ndarray:\n    r\"\"\"Construct a multivariate Gaussian distribution function.\n\n    Construct a zero-mean multivariate Gaussian distribution function\n\n    .. math::\n        f(\\mb{x}) = (2 \\pi)^{-N/2} \\, \\det(\\Sigma)^{-1/2} \\, \\exp \\left(\n        -\\frac{\\mb{x}^T \\, \\Sigma^{-1} \\, \\mb{x}}{2} \\right) \\;,\n\n    where :math:`\\Sigma` is the covariance matrix of the distribution.\n\n    Args:\n        shape: Shape of output array.\n        sigma: Covariance matrix.\n\n    Returns:\n        Sampled function.\n\n    Raises:\n        ValueError: If the array `sigma` cannot be inverted.\n    \"\"\"\n\n    if sigma is None:\n        sigma = np.diag(np.array(shape) / 7) ** 2\n    N = len(shape)\n    try:\n        sigmainv = np.linalg.inv(sigma)\n        sigmadet = np.linalg.det(sigma)\n    except np.linalg.LinAlgError as e:\n        raise ValueError(f\"Invalid covariance matrix {sigma}.\") from e\n    grd = np.stack(np.mgrid[[slice(-(n - 1) / 2, (n + 1) / 2) for n in shape]], axis=-1)\n    sigmax = np.dot(grd, sigmainv)\n    xtsigmax = np.sum(grd * np.dot(grd, sigmainv), axis=-1)\n    const = ((2.0 * np.pi) ** (-N / 2.0)) * (sigmadet ** (-1.0 / 2.0))\n    return const * np.exp(-xtsigmax / 2.0)\n\n\ndef create_cone(shape: Shape, center: Optional[List[float]] = None) -> np.ndarray:\n    \"\"\"Compute a map of 
distances from a center pixel.\n\n    Args:\n        shape: Shape of the array for which the distance map is to be\n            computed.\n        center: Tuple of center coordinates. If ``None``, it is set to\n            the center of the array.\n\n    Returns:\n        An array containing a map of the distances.\n    \"\"\"\n\n    if center is None:\n        center = [(dim - 1) / 2 for dim in shape]\n\n    coords = [np.arange(0, dim) for dim in shape]\n    coord_mesh = np.meshgrid(*coords, sparse=True, indexing=\"ij\")\n\n    dist_map = sum([(coord_mesh[i] - center[i]) ** 2 for i in range(len(coord_mesh))])\n    dist_map = np.sqrt(dist_map)\n\n    return dist_map\n\n\ndef create_circular_phantom(\n    shape: Shape, radius_list: list, val_list: list, center: Optional[list] = None\n) -> np.ndarray:\n    \"\"\"Construct a circular phantom with given radii and intensities.\n\n    This function supports both circular (``shape`` is 2D) and spherical\n    (``shape`` is 3D) phantoms.\n\n    Args:\n        shape: Shape of the phantom to be created.\n        radius_list: List of radii of the rings in the phantom.\n        val_list: List of intensity values of the rings in the phantom.\n        center: Tuple of center coordinates. 
If ``None``, it is set to\n           the center of the array.\n\n    Returns:\n        The computed phantom.\n    \"\"\"\n\n    dist_map = create_cone(shape, center)\n\n    img = np.zeros(shape)\n    for r, val in zip(radius_list, val_list):\n        # In numpy: img[dist_map < r] = val\n        # In jax.numpy: img = img.at[dist_map < r].set(val)\n        img[dist_map < r] = val\n\n    return img\n\n\ndef create_3d_foam_phantom(\n    im_shape: Shape,\n    N_sphere: int,\n    r_mean: float = 0.1,\n    r_std: float = 0.001,\n    pad: float = 0.01,\n    is_random: bool = False,\n) -> np.ndarray:\n    \"\"\"Construct a 3D phantom with random radii and centers.\n\n    Args:\n        im_shape: Shape of input image.\n        N_sphere: Number of spheres added.\n        r_mean: Mean radius of sphere (normalized to 1 along each axis).\n                Default 0.1.\n        r_std: Standard deviation of radius of sphere (normalized to 1\n                along each axis). Default 0.001.\n        pad: Padding length (normalized to 1 along each axis). Default 0.01.\n        is_random: Flag used to control randomness of phantom generation.\n                If ``False``, random seed is set to 1 in order to make the\n                process deterministic. 
Default ``False``.\n\n    Returns:\n        3D phantom of shape `im_shape`.\n    \"\"\"\n    c_lo = 0.0\n    c_hi = 1.0\n\n    if not is_random:\n        np.random.seed(1)\n\n    coord_list = [np.linspace(0, 1, N) for N in im_shape]\n    x = np.stack(np.meshgrid(*coord_list, indexing=\"ij\"), axis=-1)\n\n    centers = np.random.uniform(low=r_mean + pad, high=1 - r_mean - pad, size=(N_sphere, 3))\n    radii = r_std * np.random.randn(N_sphere) + r_mean\n\n    im = np.zeros(im_shape) + c_lo\n    for c, r in zip(centers, radii):  # type: ignore\n        dist = np.sum((x - c) ** 2, axis=-1)\n        select = im[dist < r**2]\n        if select.size > 0 and np.mean(select - c_lo) < 0.01 * c_hi:\n            # In numpy: im[dist < r**2] = c_hi\n            # In jax.numpy: im = im.at[dist < r**2].set(c_hi)\n            im[dist < r**2] = c_hi\n\n    return im\n\n\ndef create_conv_sparse_phantom(Nx: int, Nnz: int) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Construct a disc dictionary and sparse coefficient maps.\n\n    Construct a disc dictionary and a corresponding set of sparse\n    coefficient maps for testing convolutional sparse coding algorithms.\n\n    Args:\n        Nx: Size of coefficient maps (3 x Nx x Nx).\n        Nnz: Number of non-zero coefficients across all coefficient maps.\n\n    Returns:\n        A tuple consisting of a stack of 2D filters and the coefficient\n           map array.\n    \"\"\"\n\n    # constant parameters\n    M = 3\n    Nh = 7\n    e = 1\n\n    # create disc filters\n    h = np.zeros((M, 2 * Nh + 1, 2 * Nh + 1))\n    gr, gc = np.ogrid[-Nh : Nh + 1, -Nh : Nh + 1]\n    for m in range(M):\n        r = 2 * m + 3\n        d = np.sqrt(gr**2 + gc**2)\n        v = (np.clip(d, r - e, r + e) - (r - e)) / (2 * e)\n        v = 1.0 - v\n        h[m] = v\n\n    # create sparse random coefficient maps\n    np.random.seed(1234)\n    x = np.zeros((M, Nx, Nx))\n    idx0 = np.random.randint(0, M, size=(Nnz,))\n    idx1 = np.random.randint(0, Nx, size=(2, 
Nnz))\n    val = np.random.uniform(0, 5, size=(Nnz,))\n    x[idx0, idx1[0], idx1[1]] = val\n\n    return h, x\n\n\ndef create_tangle_phantom(nx: int, ny: int, nz: int) -> np.ndarray:\n    \"\"\"Construct a 3D phantom using the tangle function.\n\n    Args:\n        nx: x-size of output.\n        ny: y-size of output.\n        nz: z-size of output.\n\n    Returns:\n        An array with shape (nz, ny, nx).\n\n    \"\"\"\n    xs = 1.0 * np.linspace(-1.0, 1.0, nx)\n    ys = 1.0 * np.linspace(-1.0, 1.0, ny)\n    zs = 1.0 * np.linspace(-1.0, 1.0, nz)\n\n    # default ordering for meshgrid is `xy`, this makes inputs of length\n    # M, N, P will create a mesh of N, M, P. Thus we want ys, zs and xs.\n    xx: np.ndarray\n    yy: np.ndarray\n    zz: np.ndarray\n    xx, yy, zz = np.meshgrid(ys, zs, xs, copy=True)\n    xx = 3.0 * xx\n    yy = 3.0 * yy\n    zz = 3.0 * zz\n    values = (\n        xx * xx * xx * xx\n        - 5.0 * xx * xx\n        + yy * yy * yy * yy\n        - 5.0 * yy * yy\n        + zz * zz * zz * zz\n        - 5.0 * zz * zz\n        + 11.8\n    ) * 0.2 + 0.5\n    return (values < 2.0).astype(float)\n\n\n@partial(jax.jit, static_argnums=0)\ndef create_block_phantom(out_shape: Shape) -> np.ndarray:\n    \"\"\"Construct a blocky 3D phantom.\n\n    Args:\n        out_shape: desired phantom shape.\n\n    Returns:\n        Phantom.\n\n    \"\"\"\n    # make the phantom at a low resolution\n    low_res = np.array(\n        [\n            [\n                [0.0, 0.0, 0.0],\n                [0.0, 0.0, 0.0],\n                [0.0, 0.0, 0.0],\n            ],\n            [\n                [0.0, 1.0, 0.0],\n                [0.0, 1.0, 0.0],\n                [1.0, 1.0, 0.0],\n            ],\n            [\n                [0.0, 1.0, 0.0],\n                [0.0, 0.0, 0.0],\n                [0.0, 0.0, 0.0],\n            ],\n        ]\n    )\n    positions = np.stack(\n        np.meshgrid(*[np.linspace(-0.5, 2.5, s) for s in out_shape], indexing=\"ij\")\n    )\n    
indices = np.round(positions).astype(int)\n    return low_res[indices[0], indices[1], indices[2]]\n\n\ndef spnoise(\n    img: Union[np.ndarray, snp.Array], nfrac: float, nmin: float = 0.0, nmax: float = 1.0\n) -> Union[np.ndarray, snp.Array]:\n    \"\"\"Return image with salt & pepper noise imposed on it.\n\n    Args:\n        img: Input image.\n        nfrac: Desired fraction of pixels corrupted by noise.\n        nmin: Lower value for noise (pepper). Default 0.0.\n        nmax: Upper value for noise (salt). Default 1.0.\n\n    Returns:\n        Noisy image\n    \"\"\"\n\n    if isinstance(img, np.ndarray):\n        spm = np.random.uniform(-1.0, 1.0, img.shape)  # type: ignore\n        imgn = img.copy()\n        imgn[spm < nfrac - 1.0] = nmin\n        imgn[spm > 1.0 - nfrac] = nmax\n    else:\n        spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0)  # type: ignore\n        imgn = img\n        imgn = imgn.at[spm < nfrac - 1.0].set(nmin)  # type: ignore\n        imgn = imgn.at[spm > 1.0 - nfrac].set(nmax)  # type: ignore\n    return imgn\n\n\ndef phase_diff(x: snp.Array, y: snp.Array) -> snp.Array:\n    \"\"\"Distance between phase angles.\n\n    Compute the distance between two arrays of phase angles, with\n    appropriate phase wrapping to minimize the distance.\n\n    Args:\n        x: Input array.\n        y: Input array.\n\n    Returns:\n        Array of angular distances.\n    \"\"\"\n    mod = snp.mod(snp.abs(x - y), 2 * snp.pi)\n    return snp.minimum(mod, 2 * snp.pi - mod)\n"
  },
  {
    "path": "scico/flax/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Neural network models implemented in `Flax <https://flax.readthedocs.io/en/latest/>`_ and utility functions.\n\nMany of the function and parameter names used in this sub-package are\nbased on the somewhat non-standard Flax terminology for neural network\ncomponents:\n\n`model`\n    The model is an abstract representation of the network structure that\n    does not include specific weight values.\n\n`parameters`\n    The parameters of a model are the weights of the network represented\n    by the model.\n\n`variables`\n    The variables encompass both the parameters (i.e. network weights)\n    and secondary values that are set from training data, such as\n    layer-dependent statistics used in batch normalization.\n\n`state`\n    The state encompasses both a set of model parameters as well as\n    optimizer parameters involved in training of that model. 
Storing the\n    state rather than just the variables enables a warm start for\n    additional training.\n\n|\n\"\"\"\n\nimport sys\n\n# isort: off\nfrom ._flax import FlaxMap, load_variables, save_variables\nfrom ._models import ConvBNNet, DnCNNNet, ResNet, UNet\nfrom .inverse import MoDLNet, ODPNet\nfrom .train.input_pipeline import create_input_iter\nfrom .train.typed_dict import ConfigDict\nfrom .train.trainer import BasicFlaxTrainer\nfrom .train.apply import only_apply\nfrom .train.clu_utils import count_parameters\n\n__all__ = [\n    \"FlaxMap\",\n    \"load_variables\",\n    \"save_variables\",\n    \"ConvBNNet\",\n    \"DnCNNNet\",\n    \"ResNet\",\n    \"UNet\",\n    \"MoDLNet\",\n    \"ODPNet\",\n    \"create_input_iter\",\n    \"ConfigDict\",\n    \"BasicFlaxTrainer\",\n    \"only_apply\",\n    \"count_parameters\",\n]\n\n# Imported items in __all__ appear to originate in top-level flax module\n# except ConfigDict.\nfor name in __all__:\n    if name != \"ConfigDict\":\n        getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/flax/_flax.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Convolutional neural network models implemented in Flax.\"\"\"\n\nimport warnings\nfrom typing import Any, Optional\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom flax import serialization\nfrom flax.linen.module import Module\nfrom scico.numpy import Array, BlockArray\nfrom scico.typing import Shape\n\nPyTree = Any\n\n\ndef load_variables(filename: str) -> PyTree:\n    \"\"\"Load trained model variables.\n\n    Args:\n        filename: Name of file containing trained model variables.\n\n    Returns:\n        A tree-like structure containing the values of the model\n        variables.\n    \"\"\"\n    with open(filename, \"rb\") as data_file:\n        bytes_input = data_file.read()\n\n    variables = serialization.msgpack_restore(bytes_input)\n\n    var_in = {\"params\": variables[\"params\"], \"batch_stats\": variables[\"batch_stats\"]}\n\n    return var_in\n\n\ndef save_variables(variables: PyTree, filename: str):\n    \"\"\"Save trained model weights.\n\n    Args:\n        filename: Name of file to to which model variables should be\n            saved.\n        variables: Model variables to save.\n    \"\"\"\n    bytes_output = serialization.msgpack_serialize(variables)\n\n    with open(filename, \"wb\") as data_file:\n        data_file.write(bytes_output)\n\n\nclass FlaxMap:\n    r\"\"\"A trained flax model.\"\"\"\n\n    def __init__(self, model: Module, variables: PyTree):\n        r\"\"\"Initialize a :class:`FlaxMap` object.\n\n        Args:\n            model: Flax model to apply.\n            variables: Parameters and batch stats of trained model.\n        \"\"\"\n        self.model = model\n        self.variables = variables\n        
super().__init__()\n\n    def __call__(self, x: Array) -> Array:\n        r\"\"\"Apply trained flax model.\n\n        Args:\n            x: Input array.\n\n        Returns:\n            Output of flax model.\n        \"\"\"\n        if isinstance(x, BlockArray):\n            raise NotImplementedError\n\n        # Add singleton to input as necessary:\n        #   scico typically works with (H x W) or (H x W x C) arrays\n        #   flax expects (K x H x W x C) arrays\n        #   H: spatial height  W: spatial width\n        #   K: batch size  C: channel size\n        xndim = x.ndim\n        axsqueeze: Optional[Shape] = None\n        if xndim == 2:\n            x = x.reshape((1,) + x.shape + (1,))\n            axsqueeze = (0, 3)\n        elif xndim == 3:\n            x = x.reshape((1,) + x.shape)\n            axsqueeze = (0,)\n        y = self.model.apply(self.variables, x, train=False, mutable=False)\n        if y.ndim != xndim:\n            return y.squeeze(axis=axsqueeze)\n        return y\n"
  },
  {
    "path": "scico/flax/_models.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Flax implementation of different convolutional nets.\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom functools import partial\nfrom typing import Any, Callable, Tuple\n\nimport jax.numpy as jnp\n\nfrom flax.core import Scope  # noqa\nfrom flax.linen import BatchNorm, Conv, max_pool, relu\nfrom flax.linen.initializers import kaiming_normal, xavier_normal\nfrom flax.linen.module import _Sentinel  # noqa\nfrom flax.linen.module import Module, compact\nfrom scico.flax.blocks import (\n    ConvBNBlock,\n    ConvBNMultiBlock,\n    ConvBNPoolBlock,\n    ConvBNUpsampleBlock,\n    upscale_nn,\n)\nfrom scico.numpy import Array\n\n# The imports of Scope and _Sentinel (above) are required to silence\n# \"cannot resolve forward reference\" warnings when building sphinx api\n# docs.\n\n\nModuleDef = Any\n\n\nclass DnCNNNet(Module):\n    r\"\"\"Flax implementation of DnCNN :cite:`zhang-2017-dncnn`.\n\n    Flax implementation of the convolutional neural network (CNN)\n    architecture for denoising described in :cite:`zhang-2017-dncnn`.\n\n    Attributes:\n        depth: Number of layers in the neural network.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layers.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n        act: Class of activation function to apply. 
Default:\n            :func:`~flax.linen.activation.relu`.\n    \"\"\"\n\n    depth: int\n    channels: int\n    num_filters: int = 64\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    dtype: Any = jnp.float32\n    act: Callable = relu\n\n    @compact\n    def __call__(\n        self,\n        inputs: Array,\n        train: bool = True,\n    ) -> Array:\n        \"\"\"Apply DnCNN denoiser.\n\n        Args:\n            inputs: The array to be transformed.\n            train: Flag to differentiate between training and testing stages.\n\n        Returns:\n            The denoised input.\n        \"\"\"\n        # Definition using arguments common to all convolutions.\n        conv = partial(\n            Conv, use_bias=False, padding=\"CIRCULAR\", dtype=self.dtype, kernel_init=kaiming_normal()\n        )\n        # Definition using arguments common to all batch normalizations.\n        norm = partial(\n            BatchNorm,\n            use_running_average=not train,\n            momentum=0.99,\n            epsilon=1e-5,\n            dtype=self.dtype,\n        )\n\n        # Definition and application of DnCNN model.\n        base = inputs\n        y = conv(\n            self.num_filters,\n            self.kernel_size,\n            strides=self.strides,\n            name=\"conv_start\",\n        )(inputs)\n        y = self.act(y)\n        for _ in range(self.depth - 2):\n            y = ConvBNBlock(\n                self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=self.act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(y)\n        y = conv(\n            self.channels,\n            self.kernel_size,\n            strides=self.strides,\n            name=\"conv_end\",\n        )(y)\n        return base - y  # residual-like network\n\n\nclass ResNet(Module):\n    \"\"\"Flax implementation of convolutional network with residual 
connection.\n\n    Net constructed from sucessive applications of convolution plus batch\n    normalization blocks and ending with residual connection (i.e. adding\n    the input to the output of the block).\n\n    Args:\n        depth: Depth of residual net.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the layers of the block.\n            Corresponds to the number of channels in the network\n            processing.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    depth: int\n    channels: int\n    num_filters: int = 64\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    dtype: Any = jnp.float32\n\n    @compact\n    def __call__(self, x: Array, train: bool = True) -> Array:\n        \"\"\"Apply ResNet.\n\n        Args:\n            x: The array to be transformed.\n            train: Flag to differentiate between training and testing stages.\n\n        Returns:\n            The ResNet result.\n        \"\"\"\n\n        residual = x\n\n        # Definition using arguments common to all convolutions.\n        conv = partial(\n            Conv, use_bias=False, padding=\"CIRCULAR\", dtype=self.dtype, kernel_init=xavier_normal()\n        )\n\n        # Definition using arguments common to all batch normalizations.\n        norm = partial(\n            BatchNorm,\n            use_running_average=not train,\n            momentum=0.99,\n            epsilon=1e-5,\n            dtype=self.dtype,\n        )\n        act = relu\n\n        # Definition and application of ResNet.\n        for _ in range(self.depth - 1):\n            x = ConvBNBlock(\n                self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            
)(x)\n\n        x = conv(\n            self.channels,\n            self.kernel_size,\n            strides=self.strides,\n        )(x)\n        x = norm()(x)\n\n        return x + residual\n\n\nclass ConvBNNet(Module):\n    \"\"\"Convolution and batch normalization net.\n\n    Net constructed from sucessive applications of convolution plus batch\n    normalization blocks. No residual connection.\n\n    Args:\n        depth: Depth of net.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the layers of the block.\n            Corresponds to the number of channels in the network\n            processing.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    depth: int\n    channels: int\n    num_filters: int = 64\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    dtype: Any = jnp.float32\n\n    @compact\n    def __call__(self, x: Array, train: bool = True) -> Array:\n        \"\"\"Apply ConvBNNet.\n\n        Args:\n            x: The array to be transformed.\n            train: Flag to differentiate between training and testing stages.\n\n        Returns:\n            The ConvBNNet result.\n        \"\"\"\n        # Definition using arguments common to all convolutions.\n        conv = partial(\n            Conv, use_bias=False, padding=\"CIRCULAR\", dtype=self.dtype, kernel_init=xavier_normal()\n        )\n\n        # Definition using arguments common to all batch normalizations.\n        norm = partial(\n            BatchNorm,\n            use_running_average=not train,\n            momentum=0.99,\n            epsilon=1e-5,\n            dtype=self.dtype,\n        )\n        act = relu\n\n        # Definition and application of ConvBNNet.\n        for _ in range(self.depth - 1):\n            x = ConvBNBlock(\n                self.num_filters,\n                
conv=conv,\n                norm=norm,\n                act=act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(x)\n\n        x = conv(\n            self.channels,\n            self.kernel_size,\n            strides=self.strides,\n        )(x)\n        x = norm()(x)\n\n        return x\n\n\nclass UNet(Module):\n    \"\"\"Flax implementation of U-Net model :cite:`ronneberger-2015-unet`.\n\n    Args:\n        depth: Depth of U-Net.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the network\n            processing.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        block_depth: Number of processing layers per block.\n        window_shape: Window for reduction for pooling and downsampling.\n        upsampling: Factor for expanding.\n        dtype: Output dtype. 
Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    depth: int\n    channels: int\n    num_filters: int = 64\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    block_depth: int = 2\n    window_shape: Tuple[int, int] = (2, 2)\n    upsampling: int = 2\n    dtype: Any = jnp.float32\n\n    @compact\n    def __call__(self, x: Array, train: bool = True) -> Array:\n        \"\"\"Apply U-Net.\n\n        Args:\n            x: The array to be transformed.\n            train: Flag to differentiate between training and testing stages.\n\n        Returns:\n            The U-Net result.\n        \"\"\"\n        # Definition using arguments common to all convolutions.\n        conv = partial(\n            Conv, use_bias=False, padding=\"CIRCULAR\", dtype=self.dtype, kernel_init=kaiming_normal()\n        )\n\n        # Definition using arguments common to all batch normalizations.\n        norm = partial(\n            BatchNorm,\n            use_running_average=not train,\n            momentum=0.99,\n            epsilon=1e-5,\n            dtype=self.dtype,\n        )\n\n        act = relu\n\n        # Definition of upscaling function.\n        upfn = partial(upscale_nn, scale=self.upsampling)\n\n        # Definition and application of U-Net.\n        x = ConvBNMultiBlock(\n            self.block_depth,\n            self.num_filters,\n            conv=conv,\n            norm=norm,\n            act=act,\n            kernel_size=self.kernel_size,\n            strides=self.strides,\n        )(x)\n        residual = []\n        # going down\n        j: int = 1\n        for _ in range(self.depth - 1):\n            residual.append(x)  # for skip connections\n            x = ConvBNPoolBlock(\n                2 * j * self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=act,\n                pool=max_pool,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n                
window_shape=self.window_shape,\n            )(x)\n            x = ConvBNMultiBlock(\n                self.block_depth,\n                2 * j * self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(x)\n            j = 2 * j\n\n        # going up\n        j = j // 2  # undo last\n        res_ind = -1\n        for _ in range(self.depth - 1):\n            x = ConvBNUpsampleBlock(\n                j * self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=act,\n                upfn=upfn,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(x)\n            # skip connection\n            x = jnp.concatenate((residual[res_ind], x), axis=3)\n            x = ConvBNMultiBlock(\n                self.block_depth,\n                j * self.num_filters,\n                conv=conv,\n                norm=norm,\n                act=act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(x)\n            res_ind -= 1\n            j = j // 2\n\n        # final conv1x1\n        ksz_out = (1, 1)\n        x = conv(self.channels, ksz_out, strides=self.strides)(x)\n\n        return x\n"
  },
  {
    "path": "scico/flax/blocks.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Flax implementation of different convolutional blocks.\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom typing import Any, Callable, Tuple\n\nimport jax.numpy as jnp\n\nfrom flax.core import Scope  # noqa\nfrom flax.linen.module import _Sentinel  # noqa\nfrom flax.linen.module import Module, compact\nfrom scico.numpy import Array\n\n# The imports of Scope and _Sentinel (above) are required to silence\n# \"cannot resolve forward reference\" warnings when building sphinx api\n# docs.\n\nModuleDef = Any\n\n\nclass ConvBNBlock(Module):\n    \"\"\"Define convolution and batch normalization Flax block.\n\n    Args:\n        num_filters: Number of filters in the\n            convolutional layer of the block.\n            Corresponds to the number of channels in the output tensor.\n        conv: Flax module implementing the convolution layer to apply.\n        norm: Flax module implementing the batch normalization layer to\n            apply.\n        act: Flax function defining the activation operation to apply.\n        kernel_size: A shape tuple defining the size of the convolution\n            filters.\n        strides: A shape tuple defining the size of strides in\n            convolution.\n    \"\"\"\n\n    num_filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable[..., Array]\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n\n    @compact\n    def __call__(\n        self,\n        inputs: Array,\n    ) -> Array:\n        \"\"\"Apply convolution followed by normalization and activation.\n\n        Args:\n            inputs: The array to be transformed.\n\n        
Returns:\n            The transformed input.\n        \"\"\"\n        y = self.conv(\n            self.num_filters,\n            self.kernel_size,\n            strides=self.strides,\n        )(inputs)\n        y = self.norm()(y)\n        return self.act(y)\n\n\nclass ConvBlock(Module):\n    \"\"\"Define Flax convolution block.\n\n    Args:\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        conv: Flax module implementing the convolution layer to apply.\n        act: Flax function defining the activation operation to apply.\n        kernel_size: A shape tuple defining the size of the convolution\n            filters.\n        strides: A shape tuple defining the size of strides in\n            convolution.\n    \"\"\"\n\n    num_filters: int\n    conv: ModuleDef\n    act: Callable[..., Array]\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n\n    @compact\n    def __call__(\n        self,\n        inputs: Array,\n    ) -> Array:\n        \"\"\"Apply convolution followed by activation.\n\n        Args:\n            inputs: The array to be transformed.\n\n        Returns:\n            The transformed input.\n        \"\"\"\n        y = self.conv(\n            self.num_filters,\n            self.kernel_size,\n            strides=self.strides,\n        )(inputs)\n        return self.act(y)\n\n\nclass ConvBNPoolBlock(Module):\n    \"\"\"Define convolution, batch normalization and pooling Flax block.\n\n    Args:\n        num_filters: Number of filters in the convolutional layer of the\n            block. 
Corresponds to the number of channels in the output\n            tensor.\n        conv: Flax module implementing the convolution layer to apply.\n        norm: Flax module implementing the batch normalization layer to\n            apply.\n        act: Flax function defining the activation operation to apply.\n        pool: Flax function defining the pooling operation to apply.\n        kernel_size: A shape tuple defining the size of the convolution\n            filters.\n        strides: A shape tuple defining the size of strides in convolution.\n        window_shape: A shape tuple defining the window to reduce over in\n            the pooling operation.\n    \"\"\"\n\n    num_filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable[..., Array]\n    pool: Callable[..., Array]\n    kernel_size: Tuple[int, int]\n    strides: Tuple[int, int]\n    window_shape: Tuple[int, int]\n\n    @compact\n    def __call__(\n        self,\n        inputs: Array,\n    ) -> Array:\n        \"\"\"Apply convolution followed by normalization, activation and pooling.\n\n        Args:\n            inputs: The array to be transformed.\n\n        Returns:\n            The transformed input.\n        \"\"\"\n        y = self.conv(\n            self.num_filters,\n            self.kernel_size,\n            strides=self.strides,\n        )(inputs)\n        y = self.norm()(y)\n        y = self.act(y)\n        # 'SAME': pads so as to have the same output shape as input if the stride is 1.\n        return self.pool(y, self.window_shape, strides=self.window_shape, padding=\"SAME\")\n\n\nclass ConvBNUpsampleBlock(Module):\n    \"\"\"Define convolution, batch normalization and upsample Flax block.\n\n    Args:\n        num_filters: Number of filters in the convolutional layer of the\n            block. 
Corresponds to the number of channels in the output\n            tensor.\n        conv: Flax module implementing the convolution layer to apply.\n        norm: Flax module implementing the batch normalization layer to\n            apply.\n        act: Flax function defining the activation operation to apply.\n        upfn: Flax function defining the upsampling operation to apply.\n        kernel_size: A shape tuple defining the size of the convolution\n            filters.\n        strides: A shape tuple defining the size of strides in convolution.\n    \"\"\"\n\n    num_filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable[..., Array]\n    upfn: Callable[..., Array]\n    kernel_size: Tuple[int, int]\n    strides: Tuple[int, int]\n\n    @compact\n    def __call__(\n        self,\n        inputs: Array,\n    ) -> Array:\n        \"\"\"Apply convolution followed by normalization, activation and upsampling.\n\n        Args:\n            inputs: The array to be transformed.\n\n        Returns:\n            The transformed input.\n        \"\"\"\n        y = self.conv(\n            self.num_filters,\n            self.kernel_size,\n            strides=self.strides,\n        )(inputs)\n        y = self.norm()(y)\n        y = self.act(y)\n        return self.upfn(y)\n\n\nclass ConvBNMultiBlock(Module):\n    \"\"\"Block constructed from sucessive applications of :class:`ConvBNBlock`.\n\n    Args:\n        num_blocks: Number of convolutional batch normalization blocks to\n            apply. Each block has its own parameters for convolution\n            and batch normalization.\n        num_filters: Number of filters in the convolutional layer of the\n            block. 
Corresponds to the number of channels in the output\n            tensor.\n        conv: Flax module implementing the convolution layer to apply.\n        norm: Flax module implementing the batch normalization layer to\n            apply.\n        act: Flax function defining the activation operation to apply.\n        kernel_size: A shape tuple defining the size of the convolution\n            filters.\n        strides: A shape tuple defining the size of strides in\n            convolution.\n    \"\"\"\n\n    num_blocks: int\n    num_filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable[..., Array]\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n\n    @compact\n    def __call__(\n        self,\n        x: Array,\n    ) -> Array:\n        \"\"\"Apply sucessive convolution normalization and activation blocks.\n\n        Apply sucessive blocks, each one composed of convolution\n        normalization and activation.\n\n        Args:\n            x: The array to be transformed.\n\n        Returns:\n            The transformed input.\n        \"\"\"\n        for _ in range(self.num_blocks):\n            x = ConvBNBlock(\n                self.num_filters,\n                conv=self.conv,\n                norm=self.norm,\n                act=self.act,\n                kernel_size=self.kernel_size,\n                strides=self.strides,\n            )(x)\n\n        return x\n\n\ndef upscale_nn(x: Array, scale: int = 2) -> Array:\n    \"\"\"Nearest neighbor upscale for image batches of shape (N, H, W, C).\n\n    Args:\n        x: Input tensor of shape (N, H, W, C).\n        scale: Integer scaling factor.\n\n    Returns:\n        Output tensor of shape (N, H * scale, W * scale, C).\n    \"\"\"\n    s = x.shape\n    x = x.reshape((s[0],) + (s[1], 1, s[2], 1) + (s[3],))\n    x = jnp.tile(x, (1, 1, scale, 1, scale, 1))\n    return x.reshape((s[0],) + (scale * s[1], scale * s[2]) + (s[3],))\n"
  },
  {
    "path": "scico/flax/examples/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Data utility functions used by Flax example scripts.\"\"\"\n\nfrom .data_preprocessing import PaddedCircularConvolve, build_blur_kernel\nfrom .examples import load_blur_data, load_ct_data, load_image_data\n\n__all__ = [\n    \"load_ct_data\",\n    \"load_blur_data\",\n    \"load_image_data\",\n    \"PaddedCircularConvolve\",\n    \"build_blur_kernel\",\n]\n"
  },
  {
    "path": "scico/flax/examples/data_generation.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionality to generate training data for Flax example scripts.\n\nComputation is distributed via ray to reduce processing time.\n\"\"\"\n\nfrom functools import partial\nfrom time import time\nfrom typing import Callable, List, Tuple, Union\n\nimport numpy as np\n\ntry:\n    import xdesign  # noqa: F401\nexcept ImportError:\n    have_xdesign = False\n\n    # pylint: disable=missing-class-docstring\n    class UnitCircle:\n        pass\n\n    # pylint: enable=missing-class-docstring\nelse:\n    have_xdesign = True\n    from xdesign import (  # type: ignore\n        Foam,\n        SimpleMaterial,\n        UnitCircle,\n        discrete_phantom,\n    )\n\ntry:\n    import os\n\n    os.environ[\"RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO\"] = \"0\"  # suppress ray warning\n    import ray  # noqa: F401\nexcept ImportError:\n    have_ray = False\nelse:\n    have_ray = True\n\nimport jax\nimport jax.numpy as jnp\n\ntry:\n    from jax.extend.backend import get_backend  # introduced in jax 0.4.33\nexcept ImportError:\n    from jax.lib.xla_bridge import get_backend\n\nfrom scico.linop import CircularConvolve\nfrom scico.linop.xray import XRayTransform2D\nfrom scico.numpy import Array\n\n\nclass Foam2(UnitCircle):\n    \"\"\"Foam-like material with two attenuations.\n\n    Define functionality to generate phantom with structure similar\n    to foam with two different attenuation properties.\"\"\"\n\n    def __init__(\n        self,\n        size_range: Union[float, List[float]] = [0.05, 0.01],\n        gap: float = 0,\n        porosity: float = 1,\n        attn1: float = 1.0,\n        attn2: float = 10.0,\n    ):\n        \"\"\"Foam-like structure with two different attenuations.\n        
Circles for material 1 are more sparse than for material 2\n        by design.\n\n        Args:\n            size_range: The radius, or range of radius, of the\n                circles to be added. Default: [0.05, 0.01].\n            gap: Minimum distance between circle boundaries.\n                Default: 0.\n            porosity: Target porosity. Must be a value between\n                [0, 1]. Default: 1.\n            attn1: Mass attenuation parameter for material 1.\n                Default: 1.\n            attn2: Mass attenuation parameter for material 2.\n                Default: 10.\n        \"\"\"\n        if porosity < 0 or porosity > 1:\n            raise ValueError(\"Argument 'porosity' must be in the range [0,1).\")\n        super().__init__(radius=0.5, material=SimpleMaterial(attn1))  # type: ignore\n        self.sprinkle(  # type: ignore\n            300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0\n        ) + self.sprinkle(  # type: ignore\n            300, size_range, gap, material=SimpleMaterial(20), max_density=porosity\n        )\n\n\ndef generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray:\n    \"\"\"Generate batch of xdesign foam-like structures.\n\n    Generate batch of images with `xdesign` foam-like structure, which\n    uses one attenuation.\n\n    Args:\n        seed: Seed for data generation.\n        size: Size of image to generate.\n        ndata: Number of images to generate.\n\n    Returns:\n        Array of generated data.\n    \"\"\"\n    if not have_xdesign:\n        raise RuntimeError(\"Package xdesign is required for use of this function.\")\n    np.random.seed(seed)\n    saux: np.ndarray = np.zeros((ndata, size, size, 1), dtype=np.float32)\n    for i in range(ndata):\n        foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)\n        saux[i, ..., 0] = discrete_phantom(foam, size=size)\n\n    return saux\n\n\ndef generate_foam2_images(seed: float, size: int, ndata: 
int) -> np.ndarray:\n    \"\"\"Generate batch of foam2 structures.\n\n    Generate batch of images with :class:`Foam2` structure\n    (foam-like material with two different attenuations).\n\n    Args:\n        seed: Seed for data generation.\n        size: Size of image to generate.\n        ndata: Number of images to generate.\n\n    Returns:\n        Array of generated data.\n    \"\"\"\n    if not have_xdesign:\n        raise RuntimeError(\"Package xdesign is required for use of this function.\")\n    np.random.seed(seed)\n    saux: np.ndarray = np.zeros((ndata, size, size, 1), dtype=np.float32)\n    for i in range(ndata):\n        foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)\n        saux[i, ..., 0] = discrete_phantom(foam, size=size)\n    # normalize\n    saux /= np.max(saux, axis=(1, 2), keepdims=True)\n\n    return saux\n\n\ndef vector_f(f_: Callable, v: Array) -> Array:\n    \"\"\"Vectorize application of operator.\n\n    Args:\n        f_: Operator to apply.\n        v:  Array to evaluate.\n\n    Returns:\n       Result of evaluating operator over given arrays.\n    \"\"\"\n    lf = lambda x: jnp.atleast_3d(f_(x.squeeze()))\n    auto_batch = jax.vmap(lf)\n    return auto_batch(v)\n\n\ndef batched_f(f_: Callable, vr: Array) -> Array:\n    \"\"\"Distribute application of operator over a batch of vectors\n       among available processes.\n\n    Args:\n        f_: Operator to apply.\n        vr: Batch of arrays to evaluate.\n\n    Returns:\n       Result of evaluating operator over given batch of arrays. 
This\n       evaluation preserves the batch axis.\n    \"\"\"\n    nproc = jax.device_count()\n    if vr.shape[0] != nproc:\n        vrr = vr.reshape((nproc, -1, *vr.shape[:1]))\n    else:\n        vrr = vr\n    res = jax.pmap(partial(vector_f, f_))(vrr)\n    return res\n\n\ndef generate_ct_data(\n    nimg: int,\n    size: int,\n    nproj: int,\n    imgfunc: Callable = generate_foam2_images,\n    seed: int = 1234,\n    verbose: bool = False,\n) -> Tuple[Array, Array, Array]:\n    \"\"\"Generate batch of computed tomography (CT) data.\n\n    Generate batch of CT data for training of machine learning network\n    models.\n\n    Args:\n        nimg: Number of images to generate.\n        size: Size of reconstruction images.\n        nproj: Number of CT views.\n        imgfunc: Function for generating input images (e.g. foams).\n        seed: Seed for data generation.\n        verbose: Flag indicating whether to print status messages.\n\n    Returns:\n       tuple: A tuple (img, sino, fbp) containing:\n\n           - **img** : (:class:`jax.Array`): Generated foam images.\n           - **sino** : (:class:`jax.Array`): Corresponding sinograms.\n           - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections.\n    \"\"\"\n    if not (have_ray and have_xdesign):\n        raise RuntimeError(\"Packages ray and xdesign are required for use of this function.\")\n\n    # Generate input data.\n    start_time = time()\n    img = distributed_data_generation(imgfunc, size, nimg, seed)\n    time_dtgen = time() - start_time\n    # clip to [0,1] range\n    img = jnp.clip(img, 0, 1)\n\n    nproc = jax.device_count()\n    if img.shape[0] % nproc > 0:\n        # Decrease nimg to be a multiple of nproc if it isn't already\n        nimg = (img.shape[0] // nproc) * nproc\n        img = img[:nimg]\n\n    # Configure a CT projection operator to generate synthetic measurements.\n    angles = np.linspace(0, jnp.pi, nproj)  # evenly spaced projection angles\n    gt_shape = 
(size, size)\n    dx = 1.0 / np.sqrt(2)\n    det_count = int(size * 1.05 / np.sqrt(2.0))\n    A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count)\n    # Compute sinograms in parallel.\n    start_time = time()\n    if nproc > 1:\n        # shard array\n        imgshd = img.reshape((nproc, -1, size, size, 1))\n        sinoshd = batched_f(A, imgshd)\n        sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1))\n    else:\n        sino = vector_f(A, img)\n\n    time_sino = time() - start_time\n\n    # Compute filtered back-projection in parallel.\n    start_time = time()\n    if nproc > 1:\n        fbpshd = batched_f(A.fbp, sinoshd)\n        fbp = fbpshd.reshape((-1, size, size, 1))\n    else:\n        fbp = vector_f(A.fbp, sino)\n    time_fbp = time() - start_time\n\n    # Normalize sinogram.\n    sino = sino / size\n    # Clip FBP to [0,1] range.\n    fbp = np.clip(fbp, 0, 1)\n\n    if verbose:  # pragma: no cover\n        platform = get_backend().platform\n        print(f\"{'Platform':26s}{':':4s}{platform}\")\n        print(f\"{'Device count':26s}{':':4s}{jax.device_count()}\")\n        print(f\"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}\")\n        print(f\"{'Sinogram':19s}{'time[s]:':10s}{time_sino:>7.2f}\")\n        print(f\"{'FBP':19s}{'time[s]:':10s}{time_fbp:>7.2f}\")\n\n    return img, sino, fbp\n\n\ndef generate_blur_data(\n    nimg: int,\n    size: int,\n    blur_kernel: Array,\n    noise_sigma: float,\n    imgfunc: Callable = generate_foam1_images,\n    seed: int = 4321,\n    verbose: bool = False,\n) -> Tuple[Array, Array]:\n    \"\"\"Generate batch of blurred data.\n\n    Generate batch of blurred data for training of machine learning\n    network models.\n\n    Args:\n        nimg: Number of images to generate.\n        size: Size of reconstruction images.\n        blur_kernel: Kernel for blurring the generated images.\n        noise_sigma: Level of additive Gaussian noise to apply.\n        imgfunc: Function to 
generate foams.\n        seed: Seed for data generation.\n        verbose: Flag indicating whether to print status messages.\n\n    Returns:\n       tuple: A tuple (img, blurn) containing:\n\n           - **img** : Generated foam images.\n           - **blurn** : Corresponding blurred and noisy images.\n    \"\"\"\n    if not (have_ray and have_xdesign):\n        raise RuntimeError(\"Packages ray and xdesign are required for use of this function.\")\n    start_time = time()\n    img = distributed_data_generation(imgfunc, size, nimg, seed)\n    time_dtgen = time() - start_time\n\n    # Clip to [0,1] range.\n    img = jnp.clip(img, 0, 1)\n\n    nproc = jax.device_count()\n    if img.shape[0] % nproc > 0:\n        # Decrease nimg to be a multiple of nproc if it isn't already\n        nimg = (img.shape[0] // nproc) * nproc\n        img = img[:nimg]\n\n    # Configure blur operator\n    ishape = (size, size)\n    A = CircularConvolve(h=blur_kernel, input_shape=ishape)\n\n    # Compute blurred images in parallel\n    start_time = time()\n    if nproc > 1:\n        # Shard array\n        imgshd = img.reshape((nproc, -1, size, size, 1))\n        blurshd = batched_f(A, imgshd)\n        blur = blurshd.reshape((-1, size, size, 1))\n    else:\n        blur = vector_f(A, img)\n    time_blur = time() - start_time\n    # Normalize blurred images\n    blur = blur / jnp.max(blur, axis=(1, 2), keepdims=True)\n    # Add Gaussian noise\n    key = jax.random.key(seed)\n    noise = jax.random.normal(key, blur.shape)\n    blurn = blur + noise_sigma * noise\n    # Clip to [0,1] range.\n    blurn = jnp.clip(blurn, 0, 1)\n\n    if verbose:  # pragma: no cover\n        platform = get_backend().platform\n        print(f\"{'Platform':26s}{':':4s}{platform}\")\n        print(f\"{'Device count':26s}{':':4s}{jax.device_count()}\")\n        print(f\"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}\")\n        print(f\"{'Blur generation':19s}{'time[s]:':10s}{time_blur:>7.2f}\")\n\n    
return img, blurn\n\n\ndef distributed_data_generation(\n    imgenf: Callable, size: int, nimg: int, seedg: float = 123\n) -> np.ndarray:\n    \"\"\"Data generation distributed among processes using ray.\n\n    *Warning:* callable `imgenf` should not make use of any jax functions\n    to avoid the risk of errors when running with GPU devices, in which\n    case jax is initialized to expect the availability of GPUs, which are\n    then not available within the `ray.remote` function due to the absence\n    of any declared GPUs as a `num_gpus` parameter of `@ray.remote`.\n\n    Args:\n        imgenf: Function for batch-data generation, called as\n            `imgenf(seed, size, ndata)`.\n        size: Size of image to generate.\n        nimg: Number of images to generate. May be increased to the\n            next multiple of the number of ray processes used.\n        seedg: Base seed for data generation.\n\n    Returns:\n        Array of generated data.\n\n    Raises:\n        RuntimeError: If package ray is not available or has not been\n            initialized via :func:`ray.init`.\n    \"\"\"\n    if not have_ray:\n        raise RuntimeError(\"Package ray is required for use of this function.\")\n    if not ray.is_initialized():\n        raise RuntimeError(\"Ray must be initialized via ray.init() before calling this function.\")\n\n    # Use half of available CPU resources\n    ar = ray.available_resources()\n    nproc = max(int(ar.get(\"CPU\", 1)) // 2, 1)\n\n    # Attempt to avoid ray/jax conflicts. 
This solution is a nasty hack that\n    # can severely limit parallel execution (since ray will ensure that only\n    # as many actors as available GPUs are created), and is expected to be\n    # rather brittle.\n    if \"GPU\" in ar:\n        num_gpus = 1\n        nproc = min(nproc, int(ar.get(\"GPU\")))\n    else:\n        num_gpus = 0\n\n    if nproc > nimg:\n        nproc = nimg\n    if nimg % nproc > 0:\n        # Increase nimg to be a multiple of nproc if it isn't already\n        nimg = (nimg // nproc + 1) * nproc\n\n    ndata_per_proc = int(nimg // nproc)\n\n    @ray.remote(num_gpus=num_gpus)\n    def data_gen(seed, size, ndata, imgf):\n        return imgf(seed, size, ndata)\n\n    ray_return = ray.get(\n        [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)]\n    )\n    imgs = np.vstack([t for t in ray_return])\n\n    return imgs\n"
  },
  {
    "path": "scico/flax/examples/data_preprocessing.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Image manipulation utils.\"\"\"\n\nimport glob\nimport math\nimport os\nimport tarfile\nimport tempfile\nfrom typing import Any, Callable, Optional, Tuple, Union\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\nimport imageio.v3 as iio\n\nfrom scico import util\nfrom scico.examples import rgb2gray\nfrom scico.flax.train.typed_dict import DataSetDict\nfrom scico.linop import CircularConvolve, LinearOperator\nfrom scico.numpy import Array\nfrom scico.typing import Shape\n\nfrom .typed_dict import ConfigImageSetDict\n\n\ndef rotation90(img: Array) -> Array:\n    \"\"\"Rotate an image, or a batch of images, by 90 degrees.\n\n    Rotate an image or a batch of images by 90 degrees counterclockwise.\n    An image is an array with size H x W x C with H and W spatial\n    dimensions and C number of channels. A batch of images is an\n    array with size N x H x W x C with N number of images.\n\n    Args:\n        img: The array to be rotated.\n\n    Returns:\n       An image, or batch of images, rotated by 90 degrees\n       counterclockwise.\n    \"\"\"\n    if img.ndim < 4:\n        return np.swapaxes(img, 0, 1)\n    else:\n        return np.swapaxes(img, 1, 2)\n\n\ndef flip(img: Array) -> Array:\n    \"\"\"Horizontal flip of an image or a batch of images.\n\n    Horizontally flip an image or a batch of images. An image is an\n    array with size H x W x C with H and W spatial dimensions and C\n    number of channels. 
A batch of images is an array with size\n    N x H x W x C with N number of images.\n\n    Args:\n        img: The array to be flipped.\n\n    Returns:\n       An image, or batch of images, flipped horizontally.\n    \"\"\"\n    if img.ndim < 4:\n        return img[:, ::-1, ...]\n    else:\n        return img[..., ::-1, :]\n\n\nclass CenterCrop:\n    \"\"\"Crop central part of an image to a specified size.\n\n    Crop central part of an image. An image is an array with size\n    H x W x C with H and W spatial dimensions and C number of channels.\n    \"\"\"\n\n    def __init__(self, output_size: Union[Shape, int]):\n        \"\"\"\n        Args:\n            output_size: Desired output size. If int, square crop is\n                made.\n        \"\"\"\n        # assert isinstance(output_size, (int, tuple))\n        if isinstance(output_size, int):\n            self.output_size: Shape = (output_size, output_size)\n        else:\n            assert len(output_size) == 2\n            self.output_size = output_size\n\n    def __call__(self, image: Array) -> Array:\n        \"\"\"Apply center crop.\n\n        Args:\n            image: The array to be cropped.\n\n        Returns:\n            The cropped image.\n        \"\"\"\n\n        h, w = image.shape[:2]\n        new_h, new_w = self.output_size\n        top = (h - new_h) // 2\n        left = (w - new_w) // 2\n\n        image = image[top : top + new_h, left : left + new_w]\n\n        return image\n\n\nclass PositionalCrop:\n    \"\"\"Crop an image from a given corner to a specified size.\n\n    Crop an image from a given corner. An image is an array with size\n    H x W x C with H and W spatial dimensions and C number of channels.\n    \"\"\"\n\n    def __init__(self, output_size: Union[Shape, int]):\n        \"\"\"\n        Args:\n            output_size: Desired output size. 
If int, square crop is\n                made.\n        \"\"\"\n        # assert isinstance(output_size, (int, tuple))\n        if isinstance(output_size, int):\n            self.output_size: Shape = (output_size, output_size)\n        else:\n            assert len(output_size) == 2\n            self.output_size = output_size\n\n    def __call__(self, image: Array, top: int, left: int) -> Array:\n        \"\"\"Apply positional crop.\n\n        Args:\n            image: The array to be cropped.\n            top: Vertical top coordinate of corner to start cropping.\n            left: Horizontal left coordinate of corner to start\n                cropping.\n\n        Returns:\n            The cropped image.\n        \"\"\"\n\n        h, w = image.shape[:2]\n        new_h, new_w = self.output_size\n\n        image = image[top : top + new_h, left : left + new_w]\n\n        return image\n\n\nclass RandomNoise:\n    \"\"\"Add Gaussian noise to an image or a batch of images.\n\n    Add Gaussian noise to an image or a batch of images. An image is\n    an array with size H x W x C with H and W spatial dimensions\n    and C number of channels. A batch of images is an array with\n    size N x H x W x C with N number of images. The Gaussian noise is\n    a Gaussian random variable with mean zero and given standard\n    deviation. 
The standard deviation can be a fixed value corresponding\n    to the specified noise level or randomly selected from the range\n    between 50% and 100% of the specified noise level. The noisy\n    result is clipped to the range [0, 1].\n    \"\"\"\n\n    def __init__(self, noise_level: float, range_flag: bool = False):\n        \"\"\"\n        Args:\n            noise_level: Standard deviation of the Gaussian noise.\n            range_flag: If ``True``, the standard deviation is randomly\n                selected between 50% and 100% of `noise_level`,\n                independently for each image in a batch.\n                Default: ``False``.\n        \"\"\"\n        self.range_flag = range_flag\n        if range_flag:\n            self.noise_level_low = 0.5 * noise_level\n        self.noise_level = noise_level\n\n    def __call__(self, image: Array) -> Array:\n        \"\"\"Add Gaussian noise.\n\n        Args:\n            image: The array to add noise to.\n\n        Returns:\n            The noisy image.\n        \"\"\"\n\n        noise_level = self.noise_level\n\n        if self.range_flag:\n            if image.ndim > 3:\n                num_img = image.shape[0]\n            else:\n                num_img = 1\n            noise_level_range = np.random.uniform(self.noise_level_low, self.noise_level, num_img)\n            noise_level = noise_level_range.reshape(\n                (noise_level_range.shape[0],) + (1,) * (image.ndim - 1)\n            )\n\n        imgnoised = image + np.random.normal(0.0, noise_level, image.shape)\n        imgnoised = np.clip(imgnoised, 0.0, 1.0)\n\n        return imgnoised\n\n\ndef preprocess_images(\n    images: Array,\n    output_size: Union[Shape, int],\n    gray_flag: bool = False,\n    num_img: Optional[int] = None,\n    multi_flag: bool = False,\n    stride: Optional[Union[Shape, int]] = None,\n    dtype: Any = np.float32,\n) -> Array:\n    \"\"\"Preprocess (scale, crop, etc.) 
set of images.\n\n    Preprocess set of images, converting to gray scale, or cropping or\n    sampling multiple patches from each one, or selecting a subset of\n    them, according to specified setup.\n\n    Args:\n        images: Array of color images.\n        output_size: Desired output size. If int, square crop is made.\n        gray_flag: If ``True``, converts to gray scale.\n        num_img: If specified, number of images (from the start of\n            `images`) to use; otherwise all images are used.\n        multi_flag: If ``True``, samples multiple patches of specified\n            size in each image.\n        stride: Stride between patch origins (indexed from left-top\n            corner). If int, the same stride is used in h and w.\n            Required if `multi_flag` is ``True``.\n        dtype: dtype of array. Default: :attr:`~numpy.float32`.\n\n    Returns:\n        Preprocessed array.\n    \"\"\"\n\n    # Get number of images to use.\n    if num_img is None:\n        num_img = images.shape[0]\n\n    # Get channels of output image.\n    C = 3\n    if gray_flag:\n        C = 1\n\n    # Define functionality to crop and create signal array.\n    if multi_flag:\n        tsfm = PositionalCrop(output_size)\n        assert stride is not None\n        if isinstance(stride, int):\n            stride_multi = (stride, stride)\n        S = np.zeros((num_img, images.shape[1], images.shape[2], C), dtype=dtype)\n    else:\n        tsfm_crop = CenterCrop(output_size)\n        S = np.zeros((num_img, tsfm_crop.output_size[0], tsfm_crop.output_size[1], C), dtype=dtype)\n\n    # Convert to gray scale and/or crop.\n    for i in range(S.shape[0]):\n        img = images[i] / 255.0\n        if gray_flag:\n            imgG = rgb2gray(img)\n            # Keep channel singleton.\n            img = imgG.reshape(imgG.shape + (1,))\n        if not multi_flag:\n            # Crop image\n            img = tsfm_crop(img)\n        S[i] = img\n\n    if multi_flag:\n        # Sample multiple patches from image\n        h = S.shape[1]\n        w = 
S.shape[2]\n        nh = int(math.floor((h - tsfm.output_size[0]) / stride_multi[0])) + 1\n        nw = int(math.floor((w - tsfm.output_size[1]) / stride_multi[1])) + 1\n        saux = np.zeros(\n            (nh * nw * num_img, tsfm.output_size[0], tsfm.output_size[1], S.shape[-1]), dtype=dtype\n        )\n        count2 = 0\n        for i in range(S.shape[0]):\n            for top in range(0, h - tsfm.output_size[0], stride_multi[0]):\n                for left in range(0, w - tsfm.output_size[1], stride_multi[1]):\n                    saux[count2, ...] = tsfm(S[i], top, left)\n                    count2 += 1\n        S = saux\n    return S\n\n\ndef build_image_dataset(\n    imgs_train, imgs_test, config: ConfigImageSetDict, transf: Optional[Callable] = None\n) -> Tuple[DataSetDict, ...]:\n    \"\"\"Preprocess and assemble dataset for training.\n\n    Preprocess images according to the specified configuration and\n    assemble a dataset into a structure that can be used for training\n    machine learning models. Keep training and testing partitions.\n    Each dictionary returned has images and labels, which are arrays\n    of dimensions (N, H, W, C) with N: number of images; H,\n    W: spatial dimensions and C: number of channels.\n\n    Args:\n        imgs_train: 4D array (NHWC) with images for training.\n        imgs_test: 4D array (NHWC) with images for testing.\n        config: Configuration of image data set to read.\n        transf: Operator for blurring or other non-trivial\n            transformations. 
Default: ``None``.\n\n    Returns:\n       tuple: A tuple (train_ds, test_ds) containing:\n\n           - **train_ds** : Dictionary of training data (includes images and labels).\n           - **test_ds** : Dictionary of testing data (includes images and labels).\n    \"\"\"\n    # Preprocess images by converting to gray scale or sampling multiple\n    # patches according to specified configuration.\n    S_train = preprocess_images(\n        imgs_train,\n        config[\"output_size\"],\n        gray_flag=config[\"run_gray\"],\n        num_img=config[\"num_img\"],\n        multi_flag=config[\"multi\"],\n        stride=config[\"stride\"],\n    )\n    S_test = preprocess_images(\n        imgs_test,\n        config[\"output_size\"],\n        gray_flag=config[\"run_gray\"],\n        num_img=config[\"test_num_img\"],\n        multi_flag=config[\"multi\"],\n        stride=config[\"stride\"],\n    )\n\n    # Check for transformation\n    tsfm: Optional[Callable] = None\n    # Processing: add noise or blur or etc.\n    if config[\"data_mode\"] == \"dn\":  # Denoise problem\n        tsfm = RandomNoise(config[\"noise_level\"], config[\"noise_range\"])\n    elif config[\"data_mode\"] == \"dcnv\":  # Deconvolution problem\n        assert transf is not None\n        tsfm = transf\n\n    if config[\"augment\"]:  # Augment training data set by flip and 90 degrees rotation\n\n        strain1 = rotation90(S_train.copy())\n        strain2 = flip(S_train.copy())\n\n        S_train = np.concatenate((S_train, strain1, strain2), axis=0)\n\n    # Processing: apply transformation\n    if tsfm is not None:\n        if config[\"data_mode\"] == \"dn\":\n            Stsfm_train = tsfm(S_train.copy())\n            Stsfm_test = tsfm(S_test.copy())\n        elif config[\"data_mode\"] == \"dcnv\":\n            tsfm2 = RandomNoise(config[\"noise_level\"], config[\"noise_range\"])\n            Stsfm_train = tsfm2(tsfm(S_train.copy()))\n            Stsfm_test = tsfm2(tsfm(S_test.copy()))\n\n    # 
Shuffle data\n    rng = np.random.default_rng(config[\"seed\"])\n    perm_tr = rng.permutation(Stsfm_train.shape[0])\n    perm_tt = rng.permutation(Stsfm_test.shape[0])\n    train_ds: DataSetDict = {\"image\": Stsfm_train[perm_tr], \"label\": S_train[perm_tr]}\n    test_ds: DataSetDict = {\"image\": Stsfm_test[perm_tt], \"label\": S_test[perm_tt]}\n\n    return train_ds, test_ds\n\n\ndef images_read(path: str, ext: str = \"jpg\") -> Array:  # pragma: no cover\n    \"\"\"Read a collection of color images from a set of files.\n\n    Read a collection of color images from a set of files in the\n    specified directory. All files with extension `ext` (i.e.\n    matching glob `*.ext`) in directory `path` are assumed to be image\n    files and are read. Images may have different aspect ratios,\n    therefore, they are transposed to keep the aspect ratio of the first\n    image read.\n\n    Args:\n        path: Path to directory containing the image files.\n        ext: Filename extension.\n\n    Returns:\n        Collection of color images as a 4D array.\n    \"\"\"\n\n    slices = []\n    shape = None\n    for file in sorted(glob.glob(os.path.join(path, \"*.\" + ext))):\n        image = iio.imread(file)\n        if shape is None:\n            shape = image.shape[:2]\n        if shape != image.shape[:2]:\n            image = np.transpose(image, (1, 0, 2))\n        slices.append(image)\n    return np.stack(slices)\n\n\ndef get_bsds_data(path: str, verbose: bool = False):  # pragma: no cover\n    \"\"\"Download BSDS500 data from the BSDB project.\n\n    Download the BSDS500 dataset, a set of 500 color images of size\n    481x321 or 321x481, from the Berkeley Segmentation Dataset and\n    Benchmark project.\n\n    The downloaded data is converted to `.npz` format for\n    convenient access via :func:`numpy.load`. The converted data\n    is saved in a file `bsds500.npz` in the directory specified by\n    `path`. 
Note that train and test folders are merged to get a\n    set of 400 images for training while the val folder is reserved\n    as a set of 100 images for testing. This is done in multiple\n    works such as :cite:`zhang-2017-dncnn`.\n\n    Args:\n        path: Directory in which converted data is saved.\n        verbose: Flag indicating whether to print status messages.\n    \"\"\"\n    # data source URL and filenames\n    data_base_url = \"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/\"\n    data_tar_file = \"BSR_bsds500.tgz\"\n    # ensure path directory exists\n    if not os.path.isdir(path):\n        raise ValueError(f\"Path {path} does not exist or is not a directory.\")\n    # create temporary directory\n    temp_dir = tempfile.TemporaryDirectory()\n    if verbose:\n        print(f\"Downloading {data_tar_file} from {data_base_url}\")\n    data = util.url_get(data_base_url + data_tar_file)\n    f = open(os.path.join(temp_dir.name, data_tar_file), \"wb\")\n    f.write(data.read())\n    f.close()\n    if verbose:\n        print(\"Download complete\")\n\n    # untar downloaded data into temporary directory\n    if verbose:\n        print(f\"Extracting content from tar file {data_tar_file}\")\n\n    with tarfile.open(os.path.join(temp_dir.name, data_tar_file), \"r\") as tar_ref:\n        tar_ref.extractall(temp_dir.name)\n\n    # read untared data files into 4D arrays and save as .npz\n    data_path = os.path.join(\"BSR\", \"BSDS500\", \"data\", \"images\")\n    train_path = os.path.join(data_path, \"train\")\n    imgs_train = images_read(os.path.join(temp_dir.name, train_path))\n    val_path = os.path.join(data_path, \"val\")\n    imgs_val = images_read(os.path.join(temp_dir.name, val_path))\n    test_path = os.path.join(data_path, \"test\")\n    imgs_test = images_read(os.path.join(temp_dir.name, test_path))\n\n    # Train and test data merge into train.\n    # Leave val data for testing.\n    imgs400 = np.vstack([imgs_train, 
imgs_test])\n    if verbose:\n        print(f\"Read {imgs400.shape[0]} images for training\")\n        print(f\"Read {imgs_val.shape[0]} images for testing\")\n\n    npz_file = os.path.join(path, \"bsds500.npz\")\n    if verbose:\n        subpath = str.split(npz_file, \".cache\")\n        npz_file_display = \"~/.cache\" + subpath[-1]\n        print(f\"Saving as {npz_file_display}\")\n    np.savez(npz_file, imgstr=imgs400, imgstt=imgs_val)\n\n\ndef build_blur_kernel(\n    kernel_size: Shape,\n    blur_sigma: float,\n    dtype: Any = np.float32,\n):\n    \"\"\"Construct a blur kernel as specified.\n\n    Args:\n        kernel_size: Size of the blur kernel.\n        blur_sigma: Standard deviation of the blur kernel.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n    \"\"\"\n    kernel = 1.0\n    meshgrids = np.meshgrid(*[np.arange(size, dtype=dtype) for size in kernel_size])\n    for size, mgrid in zip(kernel_size, meshgrids):\n        mean = (size - 1) / 2\n        kernel *= np.exp(-(((mgrid - mean) / blur_sigma) ** 2) / 2)\n    # Make sure norm of values in gaussian kernel equals 1.\n    knorm = np.sqrt(np.sum(kernel * kernel))\n    kernel = kernel / knorm\n\n    return kernel\n\n\nclass PaddedCircularConvolve(LinearOperator):\n    \"\"\"Define padded convolutional operator.\n\n    The operator pads the signal with a reflection of the borders\n    before convolving with the kernel provided at initialization. 
It\n    crops the result of the convolution to maintain the same signal\n    size.\n    \"\"\"\n\n    def __init__(\n        self,\n        output_size: Union[Shape, int],\n        channels: int,\n        kernel_size: Union[Shape, int],\n        blur_sigma: float,\n        dtype: Any = np.float32,\n    ):\n        \"\"\"\n        Args:\n            output_size: Size of the image to blur.\n            channels: Number of channels in image to blur.\n            kernel_size: Size of the blur kernel.\n            blur_sigma: Standard deviation of the blur kernel.\n            dtype: Output dtype. Default: :attr:`~numpy.float32`.\n        \"\"\"\n        if isinstance(output_size, int):\n            output_size = (output_size, output_size)\n        else:\n            assert len(output_size) == 2\n\n        if isinstance(kernel_size, int):\n            kernel_size = (kernel_size, kernel_size)\n        else:\n            assert len(kernel_size) == 2\n\n        # Define padding.\n        self.padsz = (\n            (kernel_size[0] // 2, kernel_size[0] // 2),\n            (kernel_size[1] // 2, kernel_size[1] // 2),\n            (0, 0),\n        )\n\n        shape = (output_size[0], output_size[1], channels)\n        with_pad = (\n            output_size[0] + self.padsz[0][0] + self.padsz[0][1],\n            output_size[1] + self.padsz[1][0] + self.padsz[1][1],\n        )\n        shape_padded = (with_pad[0], with_pad[1], channels)\n\n        # Define data types.\n        input_dtype = dtype\n        output_dtype = dtype\n\n        # Construct blur kernel as specified.\n        kernel = build_blur_kernel(kernel_size, blur_sigma)\n\n        # Define convolution part.\n        self.conv = CircularConvolve(kernel, input_shape=shape_padded, ndims=2, input_dtype=dtype)\n\n        # Initialize Linear Operator.\n        super().__init__(\n            input_shape=shape,\n            output_shape=shape,\n            input_dtype=input_dtype,\n            output_dtype=output_dtype,\n   
         jit=True,\n        )\n\n    def _eval(self, x: Array) -> Array:\n        \"\"\"Apply operator.\n\n        Args:\n            x: The array with input signal. The input to the\n                constructed operator should be HWC with H and W spatial\n                dimensions given by `output_size` and C the given\n                `channels`.\n\n        Returns:\n            The result of padding, convolving and cropping the signal.\n            The output signal has the same HWC dimensions as the input\n            signal.\n        \"\"\"\n        xpadd: Array = jnp.pad(x, self.padsz, mode=\"reflect\")\n        rconv: Array = self.conv(xpadd)\n        return rconv[self.padsz[0][0] : -self.padsz[0][1], self.padsz[1][0] : -self.padsz[1][1], :]\n"
  },
  {
    "path": "scico/flax/examples/examples.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Generation and loading of data used in Flax example scripts.\"\"\"\n\nimport os\nfrom typing import Callable, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom scico.flax.train.typed_dict import DataSetDict\nfrom scico.numpy import Array\nfrom scico.typing import Shape\n\nfrom .data_generation import generate_blur_data, generate_ct_data\nfrom .data_preprocessing import ConfigImageSetDict, build_image_dataset, get_bsds_data\nfrom .typed_dict import CTDataSetDict\n\n\ndef get_cache_path(cache_path: Optional[str] = None) -> Tuple[str, str]:\n    \"\"\"Get input/output SCICO cache path.\n\n    Args:\n        cache_path: Given cache path. If ``None`` SCICO default cache\n            path is constructed.\n\n    Returns:\n        The cache path and a display string with private user path\n        information stripped.\n    \"\"\"\n    if cache_path is None:\n        cache_path = os.path.join(os.path.expanduser(\"~\"), \".cache\", \"scico\", \"examples\", \"data\")\n        subpath = str.split(cache_path, \".cache\")\n        cache_path_display = \"~/.cache\" + subpath[-1]\n    else:\n        cache_path_display = cache_path\n\n    return cache_path, cache_path_display\n\n\ndef load_ct_data(\n    train_nimg: int,\n    test_nimg: int,\n    size: int,\n    nproj: int,\n    cache_path: Optional[str] = None,\n    verbose: bool = False,\n) -> Tuple[CTDataSetDict, ...]:  # pragma: no cover\n    \"\"\"\n    Load or generate CT data.\n\n    Load or generate CT data for training of machine learning network\n    models. 
If cached file exists and enough data of the requested\n    size is available, data is loaded and returned.\n\n    If the requested `size` or `nproj` does not match the data read\n    from the cached file, a `RuntimeError` is generated.\n\n    If no cached file is found or not enough data is contained in the\n    file, a new data set is generated and stored in `cache_path`. The\n    data is stored in `.npz` format for convenient access via\n    :func:`numpy.load`. The data is saved in two distinct files,\n    `ct_foam2_train.npz` and `ct_foam2_test.npz`, to keep the\n    training and testing partitions separate.\n\n    Args:\n        train_nimg: Number of images required for training.\n        test_nimg: Number of images required for testing.\n        size: Size of reconstruction images.\n        nproj: Number of CT views.\n        cache_path: Directory in which generated data is saved.\n            Default: ``None``.\n        verbose: Flag indicating whether to print status messages.\n            Default: ``False``.\n\n    Returns:\n       tuple: A tuple (trdt, ttdt) containing:\n\n           - **trdt** : (Dictionary): Collection of images (key `img`),\n               sinograms (key `sino`) and filtered back projections\n               (key `fbp`) for training.\n           - **ttdt** : (Dictionary): Collection of images (key `img`),\n               sinograms (key `sino`) and filtered back projections\n               (key `fbp`) for testing.\n    \"\"\"\n    # Set default cache path if not specified\n    cache_path, cache_path_display = get_cache_path(cache_path)\n\n    # Create cache directory and generate data if not already present.\n    npz_train_file = os.path.join(cache_path, \"ct_foam2_train.npz\")\n    npz_test_file = os.path.join(cache_path, \"ct_foam2_test.npz\")\n\n    if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file):\n        # Load data\n        trdt_in = np.load(npz_train_file)\n        ttdt_in = np.load(npz_test_file)\n        # 
Check image size\n        if trdt_in[\"img\"].shape[1] != size:\n            runtime_error_scalar(\"size\", \"training\", size, trdt_in[\"img\"].shape[1])\n        if ttdt_in[\"img\"].shape[1] != size:\n            runtime_error_scalar(\"size\", \"testing\", size, ttdt_in[\"img\"].shape[1])\n        # Check number of projections\n        if trdt_in[\"sino\"].shape[1] != nproj:\n            runtime_error_scalar(\"views\", \"training\", nproj, trdt_in[\"sino\"].shape[1])\n        if ttdt_in[\"sino\"].shape[1] != nproj:\n            runtime_error_scalar(\"views\", \"testing\", nproj, ttdt_in[\"sino\"].shape[1])\n        # Check that enough data is available\n        if trdt_in[\"img\"].shape[0] >= train_nimg:\n            if ttdt_in[\"img\"].shape[0] >= test_nimg:\n                trdt: CTDataSetDict = {\n                    \"img\": trdt_in[\"img\"][:train_nimg],\n                    \"sino\": trdt_in[\"sino\"][:train_nimg],\n                    \"fbp\": trdt_in[\"fbp\"][:train_nimg],\n                }\n                ttdt: CTDataSetDict = {\n                    \"img\": ttdt_in[\"img\"][:test_nimg],\n                    \"sino\": ttdt_in[\"sino\"][:test_nimg],\n                    \"fbp\": ttdt_in[\"fbp\"][:test_nimg],\n                }\n                if verbose:\n                    print_input_path(cache_path_display)\n                    print_data_size(\"training\", trdt[\"img\"].shape[0])\n                    print_data_size(\"testing \", ttdt[\"img\"].shape[0])\n                    print_data_range(\"images  \", trdt[\"img\"])\n                    print_data_range(\"sinogram\", trdt[\"sino\"])\n                    print_data_range(\"FBP     \", trdt[\"fbp\"])\n\n                return trdt, ttdt\n\n            elif verbose:\n                print_data_warning(\"testing\", test_nimg, ttdt_in[\"img\"].shape[0])\n        elif verbose:\n            print_data_warning(\"training\", train_nimg, trdt_in[\"img\"].shape[0])\n\n    # Generate new data.\n    nimg = 
train_nimg + test_nimg\n    img, sino, fbp = generate_ct_data(\n        nimg,\n        size,\n        nproj,\n        verbose=verbose,\n    )\n    # Separate training and testing partitions.\n    trdt = {\"img\": img[:train_nimg], \"sino\": sino[:train_nimg], \"fbp\": fbp[:train_nimg]}\n    ttdt = {\"img\": img[train_nimg:], \"sino\": sino[train_nimg:], \"fbp\": fbp[train_nimg:]}\n\n    # Store images, sinograms and filtered back-projections.\n    os.makedirs(cache_path, exist_ok=True)\n    np.savez(\n        npz_train_file,\n        img=img[:train_nimg],\n        sino=sino[:train_nimg],\n        fbp=fbp[:train_nimg],\n    )\n    np.savez(\n        npz_test_file,\n        img=img[train_nimg:],\n        sino=sino[train_nimg:],\n        fbp=fbp[train_nimg:],\n    )\n\n    if verbose:\n        print_output_path(cache_path_display)\n        print_data_size(\"training\", train_nimg)\n        print_data_size(\"testing \", test_nimg)\n        print_data_range(\"images  \", img)\n        print_data_range(\"sinogram\", sino)\n        print_data_range(\"FBP     \", fbp)\n\n    return trdt, ttdt\n\n\ndef load_blur_data(\n    train_nimg: int,\n    test_nimg: int,\n    size: int,\n    blur_kernel: Array,\n    noise_sigma: float,\n    cache_path: Optional[str] = None,\n    verbose: bool = False,\n) -> Tuple[DataSetDict, ...]:  # pragma: no cover\n    \"\"\"Load or generate blurred data based on xdesign foam structures.\n\n    Load or generate blurred data for training of machine learning\n    network models. If cached file exists and enough data of the\n    requested size is available, data is loaded and returned.\n\n    If `size`, `blur_kernel` or `noise_sigma` requested do not match\n    the data read from the cached file, a `RunTimeError` is generated.\n\n    If no cached file is found or not enough data is contained in the\n    file a new data set is generated and stored in `cache_path`. 
The\n    data is stored in `.npz` format for convenient access via\n    :func:`numpy.load`. The data is saved in two distinct files:\n    `dcnv_foam1_train.npz` and `dcnv_foam1_test.npz` to keep separated\n    training and testing partitions.\n\n    Args:\n        train_nimg: Number of images required for training.\n        test_nimg: Number of images required for testing.\n        size: Size of reconstruction images.\n        blur_kernel: Kernel for blurring the generated images.\n        noise_sigma: Level of additive Gaussian noise to apply.\n        cache_path: Directory in which generated data is saved.\n            Default: ``None``.\n        verbose: Flag indicating whether to print status messages.\n            Default: ``False``.\n\n    Returns:\n       tuple: A tuple (train_ds, test_ds) containing:\n\n           - **train_ds** : Dictionary of training data (includes images\n                            and labels).\n           - **test_ds** : Dictionary of testing data (includes images\n                           and labels).\n    \"\"\"\n    # Set default cache path if not specified\n    cache_path, cache_path_display = get_cache_path(cache_path)\n\n    # Create cache directory and generate data if not already present.\n    npz_train_file = os.path.join(cache_path, \"dcnv_foam1_train.npz\")\n    npz_test_file = os.path.join(cache_path, \"dcnv_foam1_test.npz\")\n\n    if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file):\n        # Load data and convert arrays to float32.\n        trdt = np.load(npz_train_file)  # Training\n        ttdt = np.load(npz_test_file)  # Testing\n        train_in = trdt[\"image\"].astype(np.float32)\n        train_out = trdt[\"label\"].astype(np.float32)\n        test_in = ttdt[\"image\"].astype(np.float32)\n        test_out = ttdt[\"label\"].astype(np.float32)\n\n        # Check image size\n        if train_in.shape[1] != size:\n            runtime_error_scalar(\"size\", \"training\", size, train_in.shape[1])\n    
    if test_in.shape[1] != size:\n            runtime_error_scalar(\"size\", \"testing \", size, test_in.shape[1])\n\n        # Check noise_sigma\n        if trdt[\"noise\"] != noise_sigma:\n            runtime_error_scalar(\"noise\", \"training\", noise_sigma, trdt[\"noise\"])\n        if ttdt[\"noise\"] != noise_sigma:\n            runtime_error_scalar(\"noise\", \"testing \", noise_sigma, ttdt[\"noise\"])\n\n        # Check blur kernel\n        blur_train = trdt[\"blur\"].astype(np.float32)\n        if not np.allclose(blur_kernel, blur_train):\n            runtime_error_array(\"blur\", \"testing \", np.abs(blur_kernel - blur_train).max())\n        blur_test = ttdt[\"blur\"].astype(np.float32)\n        if not np.allclose(blur_kernel, blur_test):\n            runtime_error_array(\"blur\", \"testing \", np.abs(blur_kernel - blur_test).max())\n\n        # Check that enough images were restored.\n        if trdt[\"numimg\"] >= train_nimg:\n            if ttdt[\"numimg\"] >= test_nimg:\n                train_ds: DataSetDict = {\n                    \"image\": train_in,\n                    \"label\": train_out,\n                }\n                test_ds: DataSetDict = {\n                    \"image\": test_in,\n                    \"label\": test_out,\n                }\n                if verbose:\n                    print_info(\n                        \"in\",\n                        cache_path_display,\n                        train_ds[\"image\"],\n                        train_ds[\"label\"],\n                        test_ds[\"image\"].shape[0],\n                    )\n\n                return train_ds, test_ds\n\n            elif verbose:\n                print_data_warning(\"testing \", test_nimg, ttdt[\"numimg\"])\n        elif verbose:\n            print_data_warning(\"training\", train_nimg, trdt[\"numimg\"])\n\n    # Generate new data.\n    nimg = train_nimg + test_nimg\n    img, blrn = generate_blur_data(\n        nimg,\n        size,\n        
blur_kernel,\n        noise_sigma,\n        verbose=verbose,\n    )\n    # Separate training and testing partitions.\n    train_ds = {\"image\": blrn[:train_nimg], \"label\": img[:train_nimg]}\n    test_ds = {\"image\": blrn[train_nimg:], \"label\": img[train_nimg:]}\n\n    # Store original and blurred images.\n    os.makedirs(cache_path, exist_ok=True)\n    np.savez(\n        npz_train_file,\n        image=train_ds[\"image\"],\n        label=train_ds[\"label\"],\n        numimg=train_nimg,\n        noise=noise_sigma,\n        blur=blur_kernel.astype(np.float32),\n    )\n    np.savez(\n        npz_test_file,\n        image=test_ds[\"image\"],\n        label=test_ds[\"label\"],\n        numimg=test_nimg,\n        noise=noise_sigma,\n        blur=blur_kernel.astype(np.float32),\n    )\n\n    if verbose:\n        print_info(\n            \"out\",\n            cache_path_display,\n            train_ds[\"image\"],\n            train_ds[\"label\"],\n            test_ds[\"image\"].shape[0],\n        )\n\n    return train_ds, test_ds\n\n\ndef load_image_data(\n    train_nimg: int,\n    test_nimg: int,\n    size: int,\n    gray_flag: bool,\n    data_mode: str = \"dn\",\n    cache_path: Optional[str] = None,\n    verbose: bool = False,\n    noise_level: float = 0.1,\n    noise_range: bool = False,\n    transf: Optional[Callable] = None,\n    stride: Optional[int] = None,\n    augment: bool = False,\n) -> Tuple[DataSetDict, ...]:  # pragma: no cover\n    \"\"\"Load or load and preprocess image data.\n\n    Load or load and preprocess image data for training of neural\n    network models. The original source is the BSDS500 data from the\n    Berkeley Segmentation Dataset and Benchmark project. 
Depending on\n    the intended applications, different preprocessings can be performed\n    to the source data.\n\n    If a cached file exists, and enough images were sampled, data is\n    loaded and returned.\n\n    If either `size` or type of data (gray scale or color) requested\n    does not match the data read from the cached file, a\n    `RunTimeError` is generated. In contrast, there is no checking for\n    the specific contamination (i.e. noise level, blur kernel, etc.).\n\n    If no cached file is found or not enough images were sampled and\n    stored in the file, a new data set is generated and stored in\n    `cache_path`. The data is stored in `.npz` format for convenient\n    access via :func:`numpy.load`. The data is saved in two distinct\n    files: `*_bsds_train.npz` and `*_bsds_test.npz` to keep separated\n    training and testing partitions. The * stands for `dn` if\n    denoising problem or `dcnv` if deconvolution problem. Other types\n    of pre-processings may be specified via the `transf` operator.\n\n    Args:\n        train_nimg: Number of images required for sampling training data.\n        test_nimg: Number of images required for sampling testing data.\n        size: Size of reconstruction images.\n        gray_flag: Flag to indicate if gray scale images or color\n            images. When ``True`` gray scale images are used.\n        data_mode: Type of image problem. Options are: `dn` for\n            denosing, `dcnv` for deconvolution.\n        cache_path: Directory in which processed data is saved.\n            Default: ``None``.\n        verbose: Flag indicating whether to print status messages.\n            Default: ``False``.\n        noise_level: Standard deviation of the Gaussian noise.\n        noise_range: Flag to indicate if a fixed or a random standard\n            deviation must be used. Default: ``False`` i.e. 
fixed\n            standard deviation given by `noise_level`.\n        transf: Operator for blurring or other non-trivial\n            transformations. Should be able to handle batched (NHWC)\n            data. Default: ``None``.\n        stride: Stride between patch origins (indexed from left-top\n            corner). Default: 0 (i.e. no stride, only one patch per\n            image).\n        augment: Augment training data set by flip and 90 degrees\n            rotation. Default: ``False`` (i.e. no augmentation).\n\n    Returns:\n       tuple: A tuple (train_ds, test_ds) containing:\n\n           - **train_ds** : (DataSetDict): Dictionary of training data\n                            (includes images and labels).\n           - **test_ds** : (DataSetDict): Dictionary of testing data\n                           (includes images and labels).\n    \"\"\"\n    # Set default cache path if not specified\n    cache_path, cache_path_display = get_cache_path(cache_path)\n\n    # Create cache directory and generate data if not already present.\n    npz_train_file = os.path.join(cache_path, data_mode + \"_bsds_train.npz\")\n    npz_test_file = os.path.join(cache_path, data_mode + \"_bsds_test.npz\")\n\n    if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file):\n        # Load data and convert arrays to float32.\n        trdt = np.load(npz_train_file)  # Training\n        ttdt = np.load(npz_test_file)  # Testing\n        train_in = trdt[\"image\"].astype(np.float32)\n        train_out = trdt[\"label\"].astype(np.float32)\n        test_in = ttdt[\"image\"].astype(np.float32)\n        test_out = ttdt[\"label\"].astype(np.float32)\n\n        if check_img_data_requirements(\n            train_nimg,\n            test_nimg,\n            size,\n            gray_flag,\n            train_in.shape,\n            test_in.shape,\n            trdt[\"numimg\"],\n            ttdt[\"numimg\"],\n            verbose,\n        ):\n\n            train_ds: DataSetDict = {\n          
      \"image\": train_in,\n                \"label\": train_out,\n            }\n            test_ds: DataSetDict = {\n                \"image\": test_in,\n                \"label\": test_out,\n            }\n            if verbose:\n                print_info(\n                    \"in\",\n                    cache_path_display,\n                    train_ds[\"image\"],\n                    train_ds[\"label\"],\n                    test_ds[\"image\"].shape[0],\n                )\n\n                print(\n                    \"NOTE: If blur kernel or noise parameter are changed, the cache \"\n                    \"must be manually\\n      deleted to ensure that the training data\"\n                    \" is regenerated with the new\\n      parameters.\"\n                )\n\n            return train_ds, test_ds\n\n    # Check if BSDS folder exists if not create and download BSDS data.\n    bsds_cache_path = os.path.join(cache_path, \"BSDS\")\n    if not os.path.isdir(bsds_cache_path):\n        os.makedirs(bsds_cache_path)\n        get_bsds_data(path=bsds_cache_path, verbose=verbose)\n    # Load data, convert arrays to float32 and return\n    # after pre-processing for specified data_mode.\n    npz_file = os.path.join(bsds_cache_path, \"bsds500.npz\")\n    npz = np.load(npz_file)\n    imgs_train = npz[\"imgstr\"].astype(np.float32)\n    imgs_test = npz[\"imgstt\"].astype(np.float32)\n\n    # Generate new data.\n    if stride is None:\n        multi = False\n    else:\n        multi = True\n\n    config: ConfigImageSetDict = {\n        \"output_size\": size,\n        \"stride\": stride,\n        \"multi\": multi,\n        \"augment\": augment,\n        \"run_gray\": gray_flag,\n        \"num_img\": train_nimg,\n        \"test_num_img\": test_nimg,\n        \"data_mode\": data_mode,\n        \"noise_level\": noise_level,\n        \"noise_range\": noise_range,\n        \"test_split\": 0.2,\n        \"seed\": 1234,\n    }\n    train_ds, test_ds = 
build_image_dataset(imgs_train, imgs_test, config, transf)\n    # Store generated images.\n    os.makedirs(cache_path, exist_ok=True)\n    np.savez(\n        npz_train_file,\n        image=train_ds[\"image\"],\n        label=train_ds[\"label\"],\n        numimg=train_nimg,\n    )\n    np.savez(\n        npz_test_file,\n        image=test_ds[\"image\"],\n        label=test_ds[\"label\"],\n        numimg=test_nimg,\n    )\n\n    if verbose:\n        print_info(\n            \"out\",\n            cache_path_display,\n            train_ds[\"image\"],\n            train_ds[\"label\"],\n            test_ds[\"image\"].shape[0],\n        )\n\n    return train_ds, test_ds\n\n\ndef check_img_data_requirements(\n    train_nimg: int,\n    test_nimg: int,\n    size: int,\n    gray_flag: bool,\n    train_in_shp: Shape,\n    test_in_shp: Shape,\n    train_nimg_avail: int,\n    test_nimg_avail: int,\n    verbose: bool,\n) -> bool:  # pragma: no cover\n    \"\"\"Check data loaded with respect to data requirements.\n\n    Args:\n        train_nimg: Number of images required for training data.\n        test_nimg: Number of images required for testing data.\n        size: Size of images requested.\n        gray_flag: Flag to indicate if gray scale images or color images\n            are requested. 
When ``True`` gray scale images are used,\n            therefore, one channel is expected.\n        train_in_shp: Shape of images/patches loaded as training data.\n        test_in_shp: Shape of images/patches loaded as testing data.\n        train_nimg_avail: Number of images available in loaded training\n            image data.\n        test_nimg_avail: Number of images available in loaded testing\n            image data.\n        verbose: Flag indicating whether to print status messages.\n\n    Returns:\n       ``True`` if the loaded image data satifies requirements of size,\n       number of samples and number of channels and ``False`` otherwise.\n    \"\"\"\n    # Check image size\n    if train_in_shp[1] != size:\n        runtime_error_scalar(\"size\", \"training\", size, train_in_shp[1])\n    if test_in_shp[1] != size:\n        runtime_error_scalar(\"size\", \"testing \", size, test_in_shp[1])\n\n    # Check gray scale or color images.\n    C_train = train_in_shp[-1]\n    C_test = test_in_shp[-1]\n    if gray_flag:\n        C = 1\n    else:\n        C = 3\n    if C_train != C:\n        runtime_error_scalar(\"channels\", \"training\", C, C_train)\n    if C_test != C:\n        runtime_error_scalar(\"channels\", \"testing \", C, C_test)\n\n    # Check that enough images were sampled.\n    if train_nimg_avail >= train_nimg:\n        if test_nimg_avail >= test_nimg:\n            return True\n\n        elif verbose:\n            print_data_warning(\"testing \", test_nimg, test_nimg_avail)\n    elif verbose:\n        print_data_warning(\"training\", train_nimg, train_nimg_avail)\n\n    return False\n\n\ndef print_input_path(path_display: str):  # pragma: no cover\n    \"\"\"Display path from where data is being loaded.\n\n    Args:\n        path_display: Path for loading data.\n    \"\"\"\n    print(f\"Data read from path: {path_display}\")\n\n\ndef print_output_path(path_display: str):  # pragma: no cover\n    \"\"\"Display path where data is being stored.\n\n    
Args:\n        path_display: Path for storing data.\n    \"\"\"\n    print(f\"Storing data in path: {path_display}\")\n\n\ndef print_data_range(idstring: str, data: Array):  # pragma: no cover\n    \"\"\"Display min and max values of given data array.\n\n    Args:\n        idstring: Data descriptive string.\n        data: Array to compute min and max.\n    \"\"\"\n    print(f\"Data range --{idstring}--  Min: {data.min():>5.2f}  \" f\"Max: {data.max():>5.2f}\")\n\n\ndef print_data_size(idstring: str, size: int):  # pragma: no cover\n    \"\"\"Display integer given.\n\n    Args:\n        idstring: Data descriptive string.\n        size: Integer representing size of a set.\n    \"\"\"\n    print(f\"Set --{idstring}-- size: {size}\")\n\n\ndef print_info(\n    iomode: str, path_display: str, train_in: Array, train_out: Array, test_size: int\n):  # pragma: no cover\n    \"\"\"Display information related to data input/output.\n\n    Args:\n        iomode: Identification of input (load) or ouput (save)\n            operation.\n        path_display: Input or output path.\n        train_in: Input features in training set.\n        train_out: Outputs in training set.\n        test_size: Size of testing set.\n    \"\"\"\n    if iomode == \"in\":\n        print_input_path(path_display)\n    else:\n        print_output_path(path_display)\n    print_data_size(\"training\", train_in.shape[0])\n    print_data_size(\"testing \", test_size)\n    print_data_range(\" images \", train_in)\n    print_data_range(\" labels \", train_out)\n\n\ndef print_data_warning(idstring: str, requested: int, available: int):  # pragma: no cover\n    \"\"\"Display warning related to data size demands not satisfied.\n\n    Args:\n        idstring: Data descriptive string.\n        requested: Size of data set requested.\n        available: Size of data set available.\n    \"\"\"\n    print(\n        f\"Not enough images sampled in {idstring} file. 
\"\n        f\"Requested: {requested}  Available: {available}\"\n    )\n\n\ndef runtime_error_scalar(\n    type: str, idstring: str, requested: Union[int, float], available: Union[int, float]\n):\n    \"\"\"Raise run time error related to unsatisfied scalar parameter request.\n\n    Raise run time error related to scalar parameter request not satisfied\n    in available data.\n\n    Args:\n        type: Type of parameter in the request.\n        idstring: Data descriptive string.\n        requested: Parameter value requested.\n        available: Parameter value available in data.\n    \"\"\"\n    raise RuntimeError(\n        f\"Requested value of argument '{type}' does not match value \"\n        f\"read from {idstring} file. Requested: {requested}  Available: \"\n        f\"{available}.\\nDelete cache and check data source.\"\n    )\n\n\ndef runtime_error_array(type: str, idstring: str, maxdiff: float):\n    \"\"\"Raise run time error related to unsatisfied array parameter request.\n\n    Raise run time error related to array parameter request not satisfied\n    in available data.\n\n    Args:\n        type: Type of parameter in the request.\n        idstring: Data descriptive string.\n        maxdiff: Maximum error between requested and available array\n           entries.\n    \"\"\"\n    raise RuntimeError(\n        f\"Requested value of argument '{type}' does not match value \"\n        f\"read from {idstring} file. Maximum array difference: \"\n        f\"{maxdiff:>5.3f}.\\nDelete cache and check data source.\"\n    )\n"
  },
  {
    "path": "scico/flax/examples/typed_dict.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Definition of typed dictionaries for training data.\"\"\"\n\nimport sys\nfrom typing import Optional, Union\n\nif sys.version_info >= (3, 8):\n    from typing import TypedDict  # pylint: disable=no-name-in-module\nelse:\n    from typing_extensions import TypedDict\n\nfrom scico.numpy import Array\nfrom scico.typing import Shape\n\n\nclass CTDataSetDict(TypedDict):\n    \"\"\"Definition of the structure to store generated CT data.\"\"\"\n\n    img: Array  # original image\n    sino: Array  # sinogram\n    fbp: Array  # filtered back projection\n\n\nclass ConfigImageSetDict(TypedDict):\n    \"\"\"Definition of the configuration for image data preprocessing.\"\"\"\n\n    output_size: Union[int, Shape]  # size of output images/patches\n    stride: Optional[Union[Shape, int]]  # stride between patch origins\n    multi: bool  # multiple patches per image (set when a stride is given)\n    augment: bool  # flag for augmentation by flips and rotations\n    run_gray: bool  # flag for gray scale (True) or color (False) images\n    num_img: int  # number of images for sampling training data\n    test_num_img: int  # number of images for sampling testing data\n    data_mode: str  # problem type: dn (denoising) or dcnv (deconvolution)\n    noise_level: float  # standard deviation of the Gaussian noise\n    noise_range: bool  # flag for random (True) or fixed (False) noise level\n    test_split: float  # presumably fraction of data for testing; TODO confirm\n    seed: int  # integer random seed\n"
  },
  {
    "path": "scico/flax/inverse.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\"\"\"Flax implementation of different imaging inversion models.\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom functools import partial\nfrom typing import Any, Callable, Tuple\n\nimport jax.numpy as jnp\nfrom jax import jit, lax, random\n\nfrom flax.core import Scope  # noqa\nfrom flax.linen.module import _Sentinel  # noqa\nfrom flax.linen.module import Module, compact\nfrom scico.flax import ResNet\nfrom scico.linop import LinearOperator\nfrom scico.numpy import Array\nfrom scico.typing import DType, PRNGKey, Shape\n\n# The imports of Scope and _Sentinel (above) are required to silence\n# \"cannot resolve forward reference\" warnings when building sphinx api\n# docs.\n\n\nModuleDef = Any\n\n\nclass MoDLNet(Module):\n    \"\"\"Flax implementation of MoDL :cite:`aggarwal-2019-modl`.\n\n    Flax implementation of the model-based deep learning (MoDL)\n    architecture for inverse problems described in :cite:`aggarwal-2019-modl`.\n\n    Args:\n        operator: Operator for computing forward and adjoint mappings.\n        depth: Depth of MoDL net.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        block_depth: Number of layers in the computational block.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        lmbda_ini: Initial value of the regularization weight `lambda`.\n        dtype: Output dtype. 
Default: :attr:`~numpy.float32`.\n        cg_iter: Number of iterations for cg solver.\n    \"\"\"\n\n    operator: ModuleDef\n    depth: int\n    channels: int\n    num_filters: int\n    block_depth: int\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    lmbda_ini: float = 0.5\n    dtype: Any = jnp.float32\n    cg_iter: int = 10\n\n    @compact\n    def __call__(self, y: Array, train: bool = True) -> Array:\n        \"\"\"Apply MoDL net for inversion.\n\n        Args:\n            y: The array with signal to invert.\n            train: Flag to differentiate between training and testing\n               stages.\n\n        Returns:\n            The reconstructed signal.\n        \"\"\"\n\n        def lmbda_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:\n            return jnp.ones(shape, dtype) * self.lmbda_ini\n\n        lmbda = self.param(\"lmbda\", lmbda_init_wrap, (1,))\n\n        resnet = ResNet(\n            self.block_depth,\n            self.channels,\n            self.num_filters,\n            self.kernel_size,\n            self.strides,\n            dtype=self.dtype,\n        )\n\n        ah_f = lambda v: jnp.atleast_3d(self.operator.adj(v.reshape(self.operator.output_shape)))\n\n        Ahb = lax.map(ah_f, y)\n        x = Ahb\n\n        ahaI_f = lambda v: self.operator.adj(self.operator(v)) + lmbda * v\n\n        cgsol = lambda b: jnp.atleast_3d(\n            cg_solver(ahaI_f, b.reshape(self.operator.input_shape), maxiter=self.cg_iter)\n        )\n\n        for i in range(self.depth):\n            z = resnet(x, train)\n            # Solve: (AH A + lmbda I) x = Ahb + lmbda * z\n            b = Ahb + lmbda * z\n            x = lax.map(cgsol, b)\n        return x\n\n\ndef cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Array:\n    r\"\"\"Conjugate gradient solver.\n\n    Solve the linear system :math:`A\\mb{x} = \\mb{b}`, where :math:`A` is\n    positive definite, via the 
conjugate gradient method. This is a light\n    version constructed to be differentiable with the autograd\n    functionality from jax. Therefore, (i) it uses :meth:`jax.lax.scan`\n    to execute a fixed number of iterations and (ii) it assumes that the\n    linear operator may use :meth:`jax.pure_callback`. Due to the\n    utilization of a while cycle, :meth:`scico.cg` is not differentiable\n    by jax and :meth:`jax.scipy.sparse.linalg.cg` does not support\n    functions using :meth:`jax.pure_callback`, which is why an additional\n    conjugate gradient function has been implemented.\n\n    Args:\n        A: Function implementing linear operator :math:`A`, should be\n            positive definite.\n        b: Input array :math:`\\mb{b}`.\n        x0: Initial solution.\n        maxiter: Maximum iterations.\n\n    Returns:\n        x: Solution array.\n    \"\"\"\n\n    def fun(carry, _):\n        \"\"\"Function implementing one iteration of the conjugate gradient solver.\"\"\"\n        x, r, p, num = carry\n        Ap = A(p)\n        alpha = num / (p.ravel().conj().T @ Ap.ravel())\n        x = x + alpha * p\n        r = r - alpha * Ap\n        num_old = num\n        num = r.ravel().conj().T @ r.ravel()\n        beta = num / num_old\n        p = r + beta * p\n\n        return (x, r, p, num), None\n\n    if x0 is None:\n        x0 = jnp.zeros_like(b)\n    r0 = b - A(x0)\n    num0 = r0.ravel().conj().T @ r0.ravel()\n    carry = (x0, r0, r0, num0)\n    carry, _ = lax.scan(fun, carry, xs=None, length=maxiter)\n    return carry[0]\n\n\nclass ODPProxDnBlock(Module):\n    \"\"\"Flax implementation of ODP proximal gradient denoise block.\n\n    Flax implementation of the unrolled optimization with deep priors\n    (ODP) proximal gradient block for denoising :cite:`diamond-2018-odp`.\n\n    Args:\n        operator: Operator for computing forward and adjoint mappings.\n            In this case it corresponds to the identity operator and is\n            used at the network 
level.\n        depth: Number of layers in block.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        alpha_ini: Initial value of the fidelity weight `alpha`.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    operator: ModuleDef\n    depth: int\n    channels: int\n    num_filters: int\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    alpha_ini: float = 0.2\n    dtype: Any = jnp.float32\n\n    def batch_op_adj(self, y: Array) -> Array:\n        \"\"\"Batch application of adjoint operator.\"\"\"\n        return self.operator.adj(y)\n\n    @compact\n    def __call__(self, x: Array, y: Array, train: bool = True) -> Array:\n        \"\"\"Apply denoising block.\n\n        Args:\n            x: The array with current stage of denoised signal.\n            y: The array with noisy signal.\n            train: Flag to differentiate between training and testing\n                stages.\n\n        Returns:\n            The block output (i.e. 
next stage of denoised signal).\n        \"\"\"\n\n        def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:\n            return jnp.ones(shape, dtype) * self.alpha_ini\n\n        alpha = self.param(\"alpha\", alpha_init_wrap, (1,))\n\n        resnet = ResNet(\n            self.depth,\n            self.channels,\n            self.num_filters,\n            self.kernel_size,\n            self.strides,\n            dtype=self.dtype,\n        )\n\n        x = (resnet(x, train) + y * alpha) / (1.0 + alpha)\n\n        return x\n\n\nclass ODPProxDcnvBlock(Module):\n    \"\"\"Flax implementation of ODP proximal gradient deconvolution block.\n\n    Flax implementation of the unrolled optimization with deep priors\n    (ODP) proximal gradient block for deconvolution under Gaussian noise\n    :cite:`diamond-2018-odp`.\n\n    Args:\n        operator: Operator for computing forward and adjoint mappings.\n            In this case it correponds to a circular convolution operator.\n        depth: Number of layers in block.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        alpha_ini: Initial value of the fidelity weight `alpha`.\n        dtype: Output dtype. 
Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    operator: ModuleDef\n    depth: int\n    channels: int\n    num_filters: int\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    alpha_ini: float = 0.99\n    dtype: Any = jnp.float32\n\n    def setup(self):\n        \"\"\"Compute the operator norm, set up the operator for batch\n        evaluation, and define the network layers.\"\"\"\n        self.operator_norm = jnp.sqrt(power_iteration(self.operator.H @ self.operator)[0].real)\n\n        self.ah_f = lambda v: jnp.atleast_3d(\n            self.operator.adj(v.reshape(self.operator.output_shape))\n        )\n\n        self.resnet = ResNet(\n            self.depth,\n            self.channels,\n            self.num_filters,\n            self.kernel_size,\n            self.strides,\n            dtype=self.dtype,\n        )\n\n        def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:\n            return jnp.ones(shape, dtype) * self.alpha_ini\n\n        self.alpha = self.param(\"alpha\", alpha_init_wrap, (1,))\n\n    def batch_op_adj(self, y: Array) -> Array:\n        \"\"\"Batch application of adjoint operator.\"\"\"\n        return lax.map(self.ah_f, y)\n\n    def __call__(self, x: Array, y: Array, train: bool = True) -> Array:\n        \"\"\"Apply deblurring block.\n\n        Args:\n            x: The array with current stage of reconstructed signal.\n            y: The array with signal to invert.\n            train: Flag to differentiate between training and testing\n                stages.\n\n        Returns:\n            The block output (i.e. 
next stage of reconstructed signal).\n        \"\"\"\n\n        # DFT over spatial dimensions\n        fft_shape: Shape = x.shape[1:-1]\n        fft_axes: Tuple[int, int] = (1, 2)\n\n        scale = 1.0 / (self.alpha * self.operator_norm**2 + 1)\n\n        x = jnp.fft.irfftn(\n            jnp.fft.rfftn(\n                self.alpha * self.batch_op_adj(y) + self.resnet(x, train),\n                s=fft_shape,\n                axes=fft_axes,\n            )\n            / scale,\n            s=fft_shape,\n            axes=fft_axes,\n        )\n\n        return x\n\n\nclass ODPGrDescBlock(Module):\n    r\"\"\"Flax implementation of ODP gradient descent with :math:`\\ell_2` loss block.\n\n    Flax implementation of the unrolled optimization with deep priors\n    (ODP) gradient descent block for inversion using :math:`\\ell_2` loss\n    described in :cite:`diamond-2018-odp`.\n\n    Args:\n        operator: Operator for computing forward and adjoint mappings. In\n            this case it corresponds to the identity operator and is used\n            at the network level.\n        depth: Number of layers in block.\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        alpha_ini: Initial value of the fidelity weight `alpha`.\n        dtype: Output dtype. 
Default: :attr:`~numpy.float32`.\n    \"\"\"\n\n    operator: ModuleDef\n    depth: int\n    channels: int\n    num_filters: int\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    alpha_ini: float = 0.2\n    dtype: Any = jnp.float32\n\n    def setup(self):\n        \"\"\"Setting operator for batch evaluation and defining network layers.\"\"\"\n        self.ah_f = lambda v: jnp.atleast_3d(\n            self.operator.adj(v.reshape(self.operator.output_shape))\n        )\n        self.a_f = lambda v: jnp.atleast_3d(self.operator(v.reshape(self.operator.input_shape)))\n\n        self.resnet = ResNet(\n            self.depth,\n            self.channels,\n            self.num_filters,\n            self.kernel_size,\n            self.strides,\n            dtype=self.dtype,\n        )\n\n        def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:\n            return jnp.ones(shape, dtype) * self.alpha_ini\n\n        self.alpha = self.param(\"alpha\", alpha_init_wrap, (1,))\n\n    def batch_op_adj(self, y: Array) -> Array:\n        \"\"\"Batch application of adjoint operator.\"\"\"\n        return lax.map(self.ah_f, y)\n\n    def __call__(self, x: Array, y: Array, train: bool = True) -> Array:\n        \"\"\"Apply gradient descent block.\n\n        Args:\n            x: The array with current stage of reconstructed signal.\n            y: The array with signal to invert.\n            train: Flag to differentiate between training and testing\n                stages.\n\n        Returns:\n            The block output (i.e. 
next stage of inverted signal).\n        \"\"\"\n\n        x = self.resnet(x, train) - self.alpha * self.batch_op_adj(lax.map(self.a_f, x) - y)\n\n        return x\n\n\nclass ODPNet(Module):\n    \"\"\"Flax implementation of ODP network :cite:`diamond-2018-odp`.\n\n    Flax implementation of the unrolled optimization with deep priors\n    (ODP) network for inverse problems described in\n    :cite:`diamond-2018-odp`. It can be constructed with proximal gradient\n    blocks or gradient descent blocks.\n\n    Args:\n        operator: Operator for computing forward and adjoint mappings.\n        depth: Depth of ODP net (number of unrolled blocks).\n        channels: Number of channels of input tensor.\n        num_filters: Number of filters in the convolutional layer of the\n            block. Corresponds to the number of channels in the output\n            tensor.\n        block_depth: Number of layers in the computational block.\n        kernel_size: Size of the convolution filters.\n        strides: Convolution strides.\n        alpha_ini: Initial value of the fidelity weight `alpha` of the\n            first block. The initial value is halved for each subsequent\n            block.\n        dtype: Output dtype. Default: :attr:`~numpy.float32`.\n        odp_block: Processing block to apply. 
Default\n            :class:`.ODPProxDnBlock`.\n    \"\"\"\n\n    operator: ModuleDef\n    depth: int\n    channels: int\n    num_filters: int\n    block_depth: int\n    kernel_size: Tuple[int, int] = (3, 3)\n    strides: Tuple[int, int] = (1, 1)\n    alpha_ini: float = 0.5\n    dtype: Any = jnp.float32\n    odp_block: Callable = ODPProxDnBlock\n\n    @compact\n    def __call__(self, y: Array, train: bool = True) -> Array:\n        \"\"\"Apply ODP net for inversion.\n\n        Args:\n            y: The array with signal to invert.\n            train: Flag to differentiate between training and testing\n                stages.\n\n        Returns:\n            The reconstructed signal.\n        \"\"\"\n        block = partial(\n            self.odp_block,\n            operator=self.operator,\n            depth=self.block_depth,\n            channels=self.channels,\n            num_filters=self.num_filters,\n            kernel_size=self.kernel_size,\n            strides=self.strides,\n            dtype=self.dtype,\n        )\n\n        # Initial block handles initial inversion.\n        # Not all operators are batch-ready.\n        alpha0_i = self.alpha_ini\n        block0 = block(alpha_ini=alpha0_i)\n        x = block0.batch_op_adj(y)\n        x = block0(x, y, train)\n        alpha0_i /= 2.0\n\n        for i in range(self.depth - 1):\n            x = block(alpha_ini=alpha0_i)(x, y, train)\n            alpha0_i /= 2.0\n        return x\n\n\n@partial(jit, static_argnums=0)\ndef power_iteration(A: LinearOperator, maxiter: int = 100):\n    \"\"\"Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`.\n\n    Compute largest eigenvalue of a diagonalizable :class:`LinearOperator`\n    using power iteration. This function has the same functionality as\n    :class:`.linop.power_iteration` but is implemented using lax operations to\n    allow jitting and general jax function composition.\n\n    Args:\n        A: :class:`LinearOperator` used for computation. 
Must be diagonalizable.\n        maxiter: Maximum number of power iterations to use.\n\n    Returns:\n        tuple: A tuple (`mu`, `v`) containing:\n\n            - **mu**: Estimate of largest eigenvalue of `A`.\n            - **v**: Eigenvector of `A` with eigenvalue `mu`.\n\n    \"\"\"\n    key = random.PRNGKey(0)\n    v = random.normal(key, shape=A.input_shape, dtype=A.input_dtype)\n\n    v = v / jnp.linalg.norm(v)\n\n    init_val = (0, v, v, 1.0)\n\n    def cond_fun(val):\n        return jnp.logical_and(val[0] <= maxiter, val[3] > 0.0)\n\n    def body_fun(val):\n        i, v, Av, normAv = val\n        v = Av / normAv\n        i = i + 1\n        Av = A @ v\n        normAv = jnp.linalg.norm(Av)\n        return (i, v, Av, normAv)\n\n    def true_fun(v, Av, normAv):\n        return jnp.sum(v.conj() * Av) / jnp.linalg.norm(v) ** 2, Av / normAv\n\n    def false_fun(v, Av, normAv):\n        return 0.0 * normAv, Av  # Multiplication by zero used to preserve data type\n\n    i, v, Av, normAv = lax.while_loop(cond_fun, body_fun, init_val)\n    mu, v = lax.cond(normAv > 0.0, true_fun, false_fun, v, Av, normAv)\n    return mu, v\n"
  },
  {
    "path": "scico/flax/train/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utilities for training Flax models.\"\"\"\n"
  },
  {
    "path": "scico/flax/train/apply.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionality to evaluate Flax trained model.\n\nUses data parallel evaluation.\n\"\"\"\n\nfrom typing import Any, Callable, Optional, Tuple\n\nimport jax\nimport jax.numpy as jnp\n\nfrom flax import jax_utils\nfrom scico.flax import create_input_iter\nfrom scico.numpy import Array\n\nfrom .checkpoints import checkpoint_restore\nfrom .clu_utils import get_parameter_overview\nfrom .learning_rate import create_cnst_lr_schedule\nfrom .state import create_basic_train_state\nfrom .typed_dict import ConfigDict, DataSetDict, ModelVarDict\n\nModuleDef = Any\n\n\ndef apply_fn(model: ModuleDef, variables: ModelVarDict, batch: DataSetDict) -> Array:\n    \"\"\"Apply current model.\n\n    Assumes sharded batched data and replicated variables for\n    distributed processing.\n\n    This function is intended to be used via\n    :meth:`~scico.flax.only_apply`, not directly.\n\n    Args:\n        model: Flax model to apply.\n        variables: State of model parameters (replicated).\n        batch: Sharded and batched training data.\n\n    Returns:\n        Output computed by given model.\n    \"\"\"\n    output = model.apply(variables, batch[\"image\"], train=False, mutable=False)\n    return output\n\n\ndef only_apply(\n    config: ConfigDict,\n    model: ModuleDef,\n    test_ds: DataSetDict,\n    apply_fn: Callable = apply_fn,\n    variables: Optional[ModelVarDict] = None,\n) -> Tuple[Array, ModelVarDict]:\n    \"\"\"Execute model application loop.\n\n    Args:\n        config: Hyperparameter configuration.\n        model: Flax model to apply.\n        test_ds: Dictionary of testing data (includes images and\n            labels).\n        apply_fn: A hook for a function that applies 
current model.\n            Default: :meth:`~scico.flax.train.apply.apply_fn`, i.e. use\n            the standard apply function.\n        variables: Model parameters to use for evaluation. Default:\n            ``None`` (i.e. read from checkpoint).\n\n    Returns:\n        Output of model evaluated at the input provided in `test_ds`.\n\n    Raises:\n        RuntimeError: If no model variables and no checkpoint are\n            specified.\n    \"\"\"\n    if \"workdir\" in config:\n        workdir: str = config[\"workdir\"]\n    else:\n        workdir = \"./\"\n\n    if \"checkpointing\" in config:\n        checkpointing: bool = config[\"checkpointing\"]\n    else:\n        checkpointing = False\n\n    # Configure seed.\n    key = jax.random.key(config[\"seed\"])\n\n    if variables is None:\n        if checkpointing:  # pragma: no cover\n            ishape = test_ds[\"image\"].shape[1:3]\n            lr_ = create_cnst_lr_schedule(config)\n            empty_state = create_basic_train_state(key, config, model, ishape, lr_)\n            state = checkpoint_restore(empty_state, workdir)\n            if hasattr(state, \"batch_stats\"):\n                variables = {\n                    \"params\": state.params,\n                    \"batch_stats\": state.batch_stats,\n                }  # type: ignore\n                print(get_parameter_overview(variables[\"params\"]))\n                print(get_parameter_overview(variables[\"batch_stats\"]))\n            else:\n                variables = {\"params\": state.params, \"batch_stats\": {}}\n                print(get_parameter_overview(variables[\"params\"]))\n        else:\n            raise RuntimeError(\"No variables or checkpoint provided.\")\n\n    # For distributed testing\n    local_batch_size = config[\"batch_size\"] // jax.process_count()\n    size_device_prefetch = 2  # Set for GPU\n    # Set data iterator\n    eval_dt_iter = create_input_iter(\n        key,  # eval: no permutation\n        test_ds,\n        
local_batch_size,\n        size_device_prefetch,\n        model.dtype,\n        train=False,\n    )\n    p_apply_step = jax.pmap(apply_fn, axis_name=\"batch\", static_broadcasted_argnums=0)\n\n    # Evaluate model with provided variables\n    variables = jax_utils.replicate(variables)\n    num_examples = test_ds[\"image\"].shape[0]\n    steps_ = num_examples // config[\"batch_size\"]\n    output_lst = []\n    for _ in range(steps_):\n        eval_batch = next(eval_dt_iter)\n        output_batch = p_apply_step(model, variables, eval_batch)\n        output_lst.append(output_batch.reshape((-1,) + output_batch.shape[-3:]))\n\n    # Allow for completing the async run\n    jax.random.normal(jax.random.key(0), ()).block_until_ready()\n\n    # Extract one copy of variables\n    variables = jax_utils.unreplicate(variables)\n    # Convert to array\n    output = jnp.array(output_lst)\n    # Remove leading dimension\n    output = output.reshape((-1,) + output.shape[-3:])\n\n    return output, variables  # type: ignore\n"
  },
  {
    "path": "scico/flax/train/checkpoints.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utilities for checkpointing Flax models.\"\"\"\n\nimport logging\nfrom pathlib import Path\nfrom typing import Union\n\ntry:\n    import orbax.checkpoint as ocp\n\n    have_orbax = True\n    if not hasattr(ocp, \"CheckpointManager\") or not hasattr(ocp, \"checkpoint_managers\"):\n        have_orbax = False\nexcept ImportError:\n    have_orbax = False\n\nif have_orbax:\n    from orbax.checkpoint.checkpoint_managers import LatestN\n\n    logging.getLogger(\"absl\").addFilter(logging.Filter(\"could not be identified as a temporary\"))\n\n# remove the handler that orbax.checkpoint adds to the root logger.\n# see https://github.com/google/orbax/issues/1951\nfor h in logging.root.handlers.copy():\n    h.close()\n    logging.root.removeHandler(h)\n\nfrom .state import TrainState\nfrom .typed_dict import ConfigDict\n\n\ndef checkpoint_restore(\n    state: TrainState, workdir: Union[str, Path], ok_no_ckpt: bool = False\n) -> TrainState:\n    \"\"\"Load model and optimiser state.\n\n    Args:\n        state: Flax train state which includes model and optimiser\n            parameters.\n        workdir: Checkpoint file or directory of checkpoints to restore\n            from.\n        ok_no_ckpt: Flag to indicate if a checkpoint is expected. If\n            ``False``, an error is generated if a checkpoint is not\n            found.\n\n    Returns:\n        A restored Flax train state updated from checkpoint file is\n        returned. 
If `workdir` exists but contains no checkpoint, or if `workdir` does\n        not exist and `ok_no_ckpt` is ``True``, the passed-in `state` is\n        returned unchanged.\n\n    Raises:\n        RuntimeError: If package orbax.checkpoint is not available.\n        FileNotFoundError: If `workdir` does not exist and `ok_no_ckpt`\n            is ``False``.\n    \"\"\"\n    if not have_orbax:\n        raise RuntimeError(\"Package orbax.checkpoint is required for use of this function.\")\n    # Check if workdir is Path or convert to Path\n    workdir_ = workdir\n    if isinstance(workdir_, str):\n        workdir_ = Path(workdir_)\n    if workdir_.exists():\n        mngr = ocp.CheckpointManager(\n            workdir_,\n        )\n        step = mngr.latest_step()\n        if step is not None:\n            restored = mngr.restore(\n                step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state))\n            )\n            mngr.wait_until_finished()\n            mngr.close()\n            state = restored.state\n    elif not ok_no_ckpt:\n        raise FileNotFoundError(\"Could not read from checkpoint: \" + str(workdir) + \".\")\n\n    return state\n\n\ndef checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, Path]):\n    \"\"\"Store model, model configuration, and optimiser state.\n\n    Note that naming is slightly different to distinguish from Flax\n    functions.\n\n    Args:\n        state: Flax train state which includes model and optimiser\n            parameters.\n        config: Python dictionary including model train configuration.\n            Any ``post_lst`` entry is omitted from the stored copy since\n            it is not serializable.\n        workdir: Path in which to store checkpoint files.\n\n    Raises:\n        RuntimeError: If package orbax.checkpoint is not available.\n    \"\"\"\n    if not have_orbax:\n        raise RuntimeError(\"Package orbax.checkpoint is required for use of this function.\")\n    # Check if workdir is Path or convert to Path\n    workdir_ = workdir\n    if isinstance(workdir_, str):\n        workdir_ = Path(workdir_)\n    options = ocp.CheckpointManagerOptions(preservation_policy=LatestN(3), create=True)\n    mngr = ocp.CheckpointManager(\n        workdir_,\n        options=options,\n    )\n    step = 
int(state.step)\n    # Remove non-serializable partial functools in post_lst if it exists\n    config_ = config.copy()\n    if \"post_lst\" in config_:\n        config_.pop(\"post_lst\", None)  # type: ignore\n    mngr.save(\n        step,\n        args=ocp.args.Composite(\n            state=ocp.args.StandardSave(state),\n            config=ocp.args.JsonSave(config_),\n        ),\n    )\n    mngr.wait_until_finished()\n    mngr.close()\n"
  },
  {
    "path": "scico/flax/train/clu_utils.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utilities for displaying Flax models.\"\"\"\n\n# These utilities have been copied from the Common Loop Utils (CLU)\n#   https://github.com/google/CommonLoopUtils/tree/main/clu\n# and have been modified to remove TensorFlow dependencies\n# CLU is licensed under the Apache License, Version 2.0, which may\n# be obtained from\n#   http://www.apache.org/licenses/LICENSE-2.0\n\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nimport dataclasses\nfrom typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nimport flax\n\nPyTree = Any\nParamsContainer = Union[Dict[str, np.ndarray], Mapping[str, Mapping[str, Any]]]\n\n\n@dataclasses.dataclass\nclass ParamRow:\n    \"\"\"Definition of the structure of a row for printing parameters without stats.\"\"\"\n\n    name: str\n    shape: Tuple[int]\n    size: int\n\n\n@dataclasses.dataclass\nclass ParamRowWithStats(ParamRow):\n    \"\"\"Definition of the structure of a row for printing parameters with stats.\"\"\"\n\n    mean: float\n    std: float\n\n\ndef flatten_dict(\n    input_dict: Dict[str, Any], prefix: str = \"\", delimiter: str = \"/\"\n) -> Dict[str, Any]:\n    \"\"\"Flatten keys of a nested dictionary.\n\n    Args:\n        input_dict: Nested dictionary.\n        prefix: Prefix of already flatten. Default: empty string.\n        delimiter: Delimiter for displaying. 
Default: ``/``.\n\n    Returns:\n        A flat dictionary whose keys are the nested keys joined by\n        `delimiter`.\n    \"\"\"\n    output_dict = {}\n    for key, value in input_dict.items():\n        nested_key = f\"{prefix}{delimiter}{key}\" if prefix else key\n        if isinstance(value, (dict, flax.core.FrozenDict)):\n            output_dict.update(flatten_dict(value, prefix=nested_key, delimiter=delimiter))\n        else:\n            output_dict[nested_key] = value\n    return output_dict\n\n\ndef count_parameters(params: PyTree) -> int:\n    \"\"\"Return the total number of parameter values in the parameter dictionary.\n\n    Args:\n        params: Flax model parameters.\n\n    Returns:\n        The number of parameters in the model.\n    \"\"\"\n    flat_params = flatten_dict(params)\n    return sum(np.prod(v.shape) for v in flat_params.values())  # type: ignore\n\n\ndef get_parameter_rows(\n    params: ParamsContainer,\n    *,\n    include_stats: bool = False,\n) -> List[Union[ParamRow, ParamRowWithStats]]:\n    \"\"\"Return information about parameters as a list of table rows.\n\n    Args:\n        params: Dictionary with parameters as NumPy arrays. The dictionary\n            can be nested.\n        include_stats: If ``True``, add columns with mean and std for each\n            variable. 
Note that this can be considerably more compute\n            intensive and cause a lot of memory to be transferred to the\n            host.\n\n    Returns:\n        A list of `ParamRow`, or `ParamRowWithStats`, depending on the\n        passed value of `include_stats`.\n    \"\"\"\n    assert isinstance(params, (dict, flax.core.FrozenDict))\n    if params:\n        params = flatten_dict(params)\n        names, values = map(list, tuple(zip(*sorted(params.items()))))\n    else:\n        names, values = [], []\n\n    def make_row(name, value):\n        if include_stats:\n            return ParamRowWithStats(\n                name=name,\n                shape=value.shape,\n                size=int(np.prod(value.shape)),\n                mean=float(value.mean()),\n                std=float(value.std()),\n            )\n        else:\n            return ParamRow(name=name, shape=value.shape, size=int(np.prod(value.shape)))\n\n    return [make_row(name, value) for name, value in zip(names, values)]\n\n\ndef _default_table_value_formatter(value):\n    \"\"\"Format ints with \",\" between thousands, and floats to 3 significant digits.\"\"\"\n    if isinstance(value, bool):\n        return str(value)\n    elif isinstance(value, int):\n        return \"{:,}\".format(value)\n    elif isinstance(value, float):\n        return \"{:.3}\".format(value)\n    else:\n        return str(value)\n\n\ndef make_table(\n    rows: List[Any],\n    *,\n    column_names: Optional[Sequence[str]] = None,\n    value_formatter: Callable[[Any], str] = _default_table_value_formatter,\n    max_lines: Optional[int] = None,\n) -> str:\n    \"\"\"Render list of rows to a table.\n\n    Args:\n        rows: List of dataclass instances of a single type\n            (e.g. `ParamRow`).\n        column_names: Names of the columns that should be included in\n            the output. 
If not provided, then the columns are taken from keys\n            of the first row.\n        value_formatter: Callable used to format cell values.\n        max_lines: Don't render a table longer than this.\n\n    Returns:\n        A string representation of a table as in the example below.\n\n        ::\n\n          +---------+---------+\n          | Col1    | Col2    |\n          +---------+---------+\n          | value11 | value12 |\n          | value21 | value22 |\n          +---------+---------+\n    \"\"\"\n    if any(not dataclasses.is_dataclass(row) for row in rows):\n        raise ValueError(\"Expected argument 'rows' to be list of dataclasses\")\n    if len(set(map(type, rows))) > 1:\n        raise ValueError(\"Expected elements of argument 'rows' be of same type.\")\n\n    class Column:\n        \"\"\"Definition of a column for printing parameters.\"\"\"\n\n        def __init__(self, name, values):\n            self.name = name.capitalize()\n            self.values = values\n            self.width = max(len(v) for v in values + [name])\n\n    if column_names is None:\n        if not rows:\n            return \"(empty table)\"\n        column_names = [field.name for field in dataclasses.fields(rows[0])]\n\n    columns = [\n        Column(name, [value_formatter(getattr(row, name)) for row in rows]) for name in column_names\n    ]\n\n    var_line_format = \"|\" + \"\".join(f\" {{: <{c.width}s}} |\" for c in columns)\n    sep_line_format = var_line_format.replace(\" \", \"-\").replace(\"|\", \"+\")\n    header = var_line_format.replace(\">\", \"<\").format(*[c.name for c in columns])\n    separator = sep_line_format.format(*[\"\" for c in columns])\n\n    lines = [separator, header, separator]\n    for i in range(len(rows)):\n        if max_lines and len(lines) >= max_lines - 3:\n            lines.append(\"[...]\")\n            break\n        lines.append(var_line_format.format(*[c.values[i] for c in columns]))\n    lines.append(separator)\n\n    return 
\"\\n\".join(lines)\n\n\ndef get_parameter_overview(\n    params: ParamsContainer, *, include_stats: bool = True, max_lines: Optional[int] = None\n) -> str:\n    \"\"\"Return string with variables names, their shapes, count.\n\n    Args:\n        params: Dictionary with parameters as NumPy arrays. The dictionary\n            can be nested.\n        include_stats: If ``True``, add columns with mean and std for each\n            variable.\n        max_lines: If not ``None``, the maximum number of variables to\n            include.\n\n    Returns:\n        A string with a table as in the example below.\n\n        ::\n\n          +----------------+---------------+------------+\n          | Name           | Shape         | Size       |\n          +----------------+---------------+------------+\n          | FC_1/weights:0 | (63612, 1024) | 65,138,688 |\n          | FC_1/biases:0  |       (1024,) |      1,024 |\n          | FC_2/weights:0 |    (1024, 32) |     32,768 |\n          | FC_2/biases:0  |         (32,) |         32 |\n          +----------------+---------------+------------+\n          Total weights: 65,172,512\n    \"\"\"\n    if isinstance(params, (dict, flax.core.FrozenDict)):\n        params = jax.tree_util.tree_map(np.asarray, params)\n    rows = get_parameter_rows(params, include_stats=include_stats)\n    total_weights = count_parameters(params)\n    RowType = ParamRowWithStats if include_stats else ParamRow\n    # Pass in `column_names` to enable rendering empty tables.\n    column_names = [field.name for field in dataclasses.fields(RowType)]\n    table = make_table(rows, max_lines=max_lines, column_names=column_names)\n    return table + f\"\\nTotal weights: {total_weights:,}\"\n"
  },
  {
    "path": "scico/flax/train/diagnostics.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utilities for computing and displaying performance metrics during training.\n\nAssumes sharded batched data.\n\"\"\"\n\nfrom typing import Callable, Dict, Tuple, Union\n\nfrom jax import lax\n\nfrom scico.diagnostics import IterationStats\nfrom scico.metric import snr\nfrom scico.numpy import Array\n\nfrom .losses import mse_loss\nfrom .typed_dict import MetricsDict\n\n\ndef compute_metrics(output: Array, labels: Array, criterion: Callable = mse_loss) -> MetricsDict:\n    \"\"\"Compute diagnostic metrics.\n\n    Assumes sharded batched data (i.e. it only works inside pmap because\n    it needs an axis name).\n\n    Args:\n        output: Comparison signal.\n        labels: Reference signal.\n        criterion: Loss function. Default: :meth:`~scico.flax.train.losses.mse_loss`.\n\n    Returns:\n        Loss and SNR between `output` and `labels`.\n    \"\"\"\n    loss = criterion(output, labels)\n    snr_ = snr(labels, output)\n    metrics: MetricsDict = {\n        \"loss\": loss,\n        \"snr\": snr_,\n    }\n    metrics = lax.pmean(metrics, axis_name=\"batch\")\n    return metrics\n\n\nclass ArgumentStruct:\n    \"\"\"Class that converts a dictionary into an object with named entries.\n\n    Class that converts a python dictionary into an object with named\n    entries given by the dictionary keys. 
After instantiation, each\n    entry is accessible as an attribute of the object (dictionary-style\n    indexing is not supported).\n    \"\"\"\n\n    def __init__(self, **entries):\n        self.__dict__.update(entries)\n\n\ndef stats_obj() -> Tuple[IterationStats, Callable]:\n    \"\"\"Functionality to log and store iteration statistics.\n\n    This function initializes an object\n    :class:`~.diagnostics.IterationStats` to log and store iteration\n    statistics if logging is enabled during training. The statistics\n    collected are: epoch, time, learning rate, loss and snr in training\n    and loss and snr in evaluation. The\n    :class:`~.diagnostics.IterationStats` object takes care of both\n    printing stats to command line and storing them for further analysis.\n\n    Returns:\n        A tuple (`itstat_object`, `itstat_insert_func`) where\n        `itstat_object` is the :class:`~.diagnostics.IterationStats`\n        object and `itstat_insert_func` maps an object with attributes\n        ``epoch``, ``time``, ``train_learning_rate``, ``train_loss``,\n        ``train_snr``, ``loss`` and ``snr`` to the tuple of values to\n        insert.\n    \"\"\"\n    # epoch, time, learning rate, loss and snr (train and\n    # eval) fields\n    itstat_fields = {\n        \"Epoch\": \"%d\",\n        \"Time\": \"%8.2e\",\n        \"Train_LR\": \"%.6f\",\n        \"Train_Loss\": \"%.6f\",\n        \"Train_SNR\": \"%.2f\",\n        \"Eval_Loss\": \"%.6f\",\n        \"Eval_SNR\": \"%.2f\",\n    }\n    itstat_attrib = [\n        \"epoch\",\n        \"time\",\n        \"train_learning_rate\",\n        \"train_loss\",\n        \"train_snr\",\n        \"loss\",\n        \"snr\",\n    ]\n\n    # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831\n    itstat_return = \"return(\" + \", \".join([\"obj.\" + attr for attr in itstat_attrib]) + \")\"\n    scope: Dict[str, Callable] = {}\n    exec(\"def itstat_func(obj): \" + itstat_return, scope)\n    default_itstat_options: Dict[str, Union[dict, Callable, bool]] = {\n        \"fields\": itstat_fields,\n        \"itstat_func\": scope[\"itstat_func\"],\n        \"display\": True,\n    }\n    itstat_insert_func: Callable = default_itstat_options.pop(\"itstat_func\")  # type: ignore\n    itstat_object = IterationStats(**default_itstat_options)  # type: ignore\n\n    return itstat_object, itstat_insert_func\n"
  },
  {
    "path": "scico/flax/train/input_pipeline.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Generalized data handling for training script.\n\nIncludes construction of data iterator and\ninstantiation for parallel processing.\n\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom typing import Any, Optional, Union\n\nimport jax\nimport jax.numpy as jnp\n\nfrom flax import jax_utils\nfrom scico.numpy import Array\n\nfrom .typed_dict import DataSetDict\n\nDType = Any\nKeyArray = Union[Array, jax.Array]\n\n\nclass IterateData:\n    \"\"\"Class to load data for training and testing.\n\n    It uses the generator pattern to obtain an iterable object.\n    \"\"\"\n\n    def __init__(\n        self, dt: DataSetDict, batch_size: int, train: bool = True, key: Optional[KeyArray] = None\n    ):\n        r\"\"\"Initialize a :class:`IterateData` object.\n\n        Args:\n            dt: Dictionary of data for supervised training including\n                images and labels.\n            batch_size: Size of batch for iterating through the data.\n            train: Flag indicating use of iterator for training. Both\n                iterators are infinite: the iterator for training\n                reshuffles the data each time it is exhausted, while the\n                iterator for testing cycles through the data in a fixed\n                order. Default: ``True``.\n            key: A PRNGKey used as the random key. If ``None``, a key\n                generated with seed 0 is used. 
Default: ``None``.\n        \"\"\"\n        self.dt = dt\n        self.batch_size = batch_size\n        self.train = train\n        self.n = dt[\"image\"].shape[0]\n        self.key = jax.random.key(0) if key is None else key\n        # Number of complete batches: samples that do not fill a\n        # complete batch are skipped\n        self.steps_per_epoch = self.n // batch_size\n        self.reset()\n\n    def reset(self):\n        \"\"\"Reset batch counter and, in training, re-shuffle data.\"\"\"\n        if self.train:\n            self.key, subkey = jax.random.split(self.key)\n            self.perms = jax.random.permutation(subkey, self.n)\n        else:\n            self.perms = jnp.arange(self.n)\n\n        self.perms = self.perms[: self.steps_per_epoch * self.batch_size]  # skips incomplete batch\n        self.perms = self.perms.reshape((self.steps_per_epoch, self.batch_size))\n        self.ns = 0\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        \"\"\"Get next batch.\n\n        When the data is exhausted, the batches are reshuffled during\n        training, while during testing iteration restarts from the\n        first batch.\"\"\"\n        if self.ns >= self.steps_per_epoch:\n            if self.train:\n                self.reset()\n            else:\n                self.ns = 0\n        batch = {k: v[self.perms[self.ns], ...] 
for k, v in self.dt.items()}\n        self.ns += 1\n        return batch\n\n\ndef prepare_data(xs: Array) -> Any:\n    \"\"\"Reshape input batch for parallel training.\n\n    The leading (batch) dimension of each array is split into a device\n    dimension of size :func:`jax.local_device_count` and a per-device\n    batch dimension, and must therefore be divisible by the number of\n    local devices.\n    \"\"\"\n    local_device_count = jax.local_device_count()\n\n    def _prepare(x: Array) -> Array:\n        # reshape (host_batch_size, height, width, channels) to\n        # (local_devices, device_batch_size, height, width, channels)\n        return x.reshape((local_device_count, -1) + x.shape[1:])\n\n    return jax.tree_util.tree_map(_prepare, xs)\n\n\ndef create_input_iter(\n    key: KeyArray,\n    dataset: DataSetDict,\n    batch_size: int,\n    size_device_prefetch: int = 2,\n    dtype: DType = jnp.float32,\n    train: bool = True,\n) -> Any:\n    \"\"\"Create data iterator for training.\n\n    Create data iterator for training by sharding and prefetching batches\n    on device.\n\n    Args:\n        key: A PRNGKey used for random data permutations.\n        dataset: Dictionary of data for supervised training including\n            images and labels.\n        batch_size: Size of batch for iterating through the data.\n        size_device_prefetch: Size of prefetch buffer. Default: 2.\n        dtype: Type of data to handle. Currently not used by this\n            function. Default: :attr:`~numpy.float32`.\n        train: Flag indicating the type of iterator to construct and use.\n            The iterator for training permutes data on each epoch while\n            the iterator for testing passes through the data without\n            permuting it. 
Default: ``True``.\n\n    Returns:\n        Array-like data sharded to specific devices coming from an\n        iterator built from the provided dataset.\n    \"\"\"\n    ds = IterateData(dataset, batch_size, train, key)\n    it = map(prepare_data, ds)\n    it = jax_utils.prefetch_to_device(it, size_device_prefetch)\n    return it\n"
  },
  {
    "path": "scico/flax/train/learning_rate.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Learning rate schedulers.\"\"\"\n\nimport optax\n\nfrom .typed_dict import ConfigDict\n\n\ndef create_cnst_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:\n    \"\"\"Create learning rate schedule with a constant specified value.\n\n    Args:\n        config: Dictionary of configuration. The value to use corresponds\n            to the `base_learning_rate` keyword.\n\n    Returns:\n        schedule: A function that maps step counts to values.\n    \"\"\"\n    schedule = optax.constant_schedule(config[\"base_learning_rate\"])\n    return schedule\n\n\ndef create_exp_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:\n    \"\"\"Create learning rate schedule to have an exponential decay.\n\n    The learning rate at step count `t` is `base_learning_rate`\n    multiplied by `lr_decay_rate` raised to the power\n    `t / (num_epochs * steps_per_epoch)`, i.e. `lr_decay_rate` is the\n    total decay factor over the configured number of epochs.\n\n    Args:\n        config: Dictionary of configuration. 
The values to use correspond\n            to `base_learning_rate`, `num_epochs`, `steps_per_epoch` and\n            `lr_decay_rate`.\n\n    Returns:\n        schedule: A function that maps step counts to values.\n    \"\"\"\n    decay_steps = config[\"num_epochs\"] * config[\"steps_per_epoch\"]\n    schedule = optax.exponential_decay(\n        init_value=config[\"base_learning_rate\"],\n        transition_steps=decay_steps,\n        decay_rate=config[\"lr_decay_rate\"],\n    )\n    return schedule\n\n\ndef create_cosine_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:\n    \"\"\"Create learning rate to follow a pre-specified schedule.\n\n    Create learning rate to follow a pre-specified schedule with warmup\n    and cosine stages. During the warmup stage the learning rate grows\n    linearly from 0 to `base_learning_rate`; it then decays following a\n    cosine curve over the remaining epochs (at least one epoch).\n\n    Args:\n        config: Dictionary of configuration. The parameters to use\n            correspond to keywords: `base_learning_rate`, `num_epochs`,\n            `warmup_epochs` and `steps_per_epoch`.\n\n    Returns:\n        schedule: A function that maps step counts to values.\n    \"\"\"\n    warmup_steps = config[\"warmup_epochs\"] * config[\"steps_per_epoch\"]\n    # Warmup stage\n    warmup_fn = optax.linear_schedule(\n        init_value=0.0,\n        end_value=config[\"base_learning_rate\"],\n        transition_steps=warmup_steps,\n    )\n    # Cosine stage (lasts at least one epoch)\n    cosine_epochs = max(config[\"num_epochs\"] - config[\"warmup_epochs\"], 1)\n    cosine_fn = optax.cosine_decay_schedule(\n        init_value=config[\"base_learning_rate\"],\n        decay_steps=cosine_epochs * config[\"steps_per_epoch\"],\n    )\n\n    # Switch from warmup to cosine stage at the end of the warmup epochs\n    schedule = optax.join_schedules(\n        schedules=[warmup_fn, cosine_fn],\n        boundaries=[warmup_steps],\n    )\n\n    return schedule\n"
  },
  {
    "path": "scico/flax/train/losses.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Definition of loss functions for model optimization.\"\"\"\n\nimport jax.numpy as jnp\n\nimport optax\n\nfrom scico.numpy import Array\n\n\ndef mse_loss(output: Array, labels: Array) -> float:\n    \"\"\"Compute Mean Squared Error (MSE) loss for training via Optax.\n\n    The elementwise loss is computed with :func:`optax.l2_loss`, which\n    includes a factor of 1/2, so the returned value is one half of the\n    mean of the squared differences between `output` and `labels`. This\n    scaling does not change the minimizer.\n\n    Args:\n        output: Comparison signal.\n        labels: Reference signal.\n\n    Returns:\n        Half of the MSE between `output` and `labels`.\n    \"\"\"\n    half_sq_err = optax.l2_loss(output, labels)\n    return jnp.mean(half_sq_err)\n"
  },
  {
    "path": "scico/flax/train/spectral.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utils for spectral normalization of convolutional layers in Flax models.\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nfrom typing import Any, Callable, Sequence\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax import lax\n\nimport scipy\nfrom flax.core import freeze, unfreeze\nfrom flax.linen import Conv\nfrom flax.linen.module import Module, compact\nfrom scico.numpy import Array\nfrom scico.typing import Shape\n\nfrom .traversals import ModelParamTraversal\n\nPyTree = Any\n\n\n# From https://github.com/deepmind/dm-haiku/issues/71\ndef _l2_normalize(x: Array, eps: float = 1e-12) -> Array:\n    r\"\"\"Normalize array by its :math:`\\ell_2` norm.\n\n    Args:\n        x: Array to be normalized.\n        eps: Small value to prevent divide by zero. Default: 1e-12.\n\n    Returns:\n        Normalized array.\n    \"\"\"\n    return x * lax.rsqrt((x**2).sum() + eps)\n\n\n# From https://nbviewer.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc\ndef estimate_spectral_norm(\n    f: Callable, input_shape: Shape, seed: int = 0, n_steps: int = 10, eps: float = 1e-12\n):\n    \"\"\"Estimate spectral norm of operator.\n\n    This function estimates the spectral norm of an operator by\n    estimating the singular vectors of the operator via the power\n    iteration method and the transpose operator enabled by nested\n    autodiff in JAX.\n\n    Args:\n        f: Operator to compute spectral norm.\n        input_shape: Shape of input to operator.\n        seed: Integer value to seed the random generation. 
Default: 0.\n        n_steps: Number of power iterations to compute. Default: 10.\n        eps: Small value to prevent divide by zero. Default: 1e-12.\n\n    Returns:\n        Spectral norm.\n    \"\"\"\n    rng = jax.random.key(seed)\n    u0 = jax.random.normal(rng, input_shape)\n    v0 = jnp.zeros_like(f(u0))\n\n    def fun(carry, _):\n        u, v = carry\n        v, f_vjp = jax.vjp(f, u)\n        v = _l2_normalize(v, eps)\n        (u,) = f_vjp(v)\n        u = _l2_normalize(u, eps)\n        return (u, v), None\n\n    (u, v), _ = lax.scan(fun, (u0, v0), xs=None, length=n_steps)\n    return jnp.vdot(v, f(u))\n\n\nclass CNN(Module):\n    \"\"\"Evaluation of convolution operator via Flax convolutional layer.\n\n    Evaluation of convolution operator via Flax implementation of a\n    convolutional layer. This form of convolution is used only for the\n    estimation of the spectral norm of the operator. Therefore, the value\n    of the kernel is provided too.\n\n    Attributes:\n        kernel_size: Size of the convolution filter.\n        kernel0: Convolution filter.\n        dtype: Output type.\n    \"\"\"\n\n    kernel_size: Sequence[int]\n    kernel0: Array\n    dtype: Any\n\n    @compact\n    def __call__(self, x):\n        \"\"\"Apply CNN layer.\n\n        Args:\n            x: The array to be convolved.\n\n        Returns:\n            The result of the convolution with `kernel0`.\n        \"\"\"\n\n        def kinit_wrap(rng, shape, dtype=self.dtype):\n            return jnp.array(self.kernel0, dtype)\n\n        return Conv(\n            features=self.kernel_size[3],\n            kernel_size=self.kernel_size[:2],\n            use_bias=False,\n            padding=\"CIRCULAR\",\n            kernel_init=kinit_wrap,\n        )(x)\n\n\ndef conv(inputs: Array, kernel: Array) -> Array:\n    \"\"\"Compute convolution between input and kernel.\n\n    The convolution is evaluated via a CNN Flax model.\n\n    Args:\n        inputs: Array to compute convolution.\n        kernel: Filter of the convolutional operator.\n\n    Returns:\n        Result of convolution of input with 
kernel.\n    \"\"\"\n\n    dtype = kernel.dtype\n    inputs = jnp.asarray(inputs, dtype)\n    kernel = jnp.asarray(kernel, dtype)\n\n    rng = jax.random.key(0)  # not used\n    model = CNN(kernel_size=kernel.shape, kernel0=kernel, dtype=dtype)\n    variables = model.init(rng, np.zeros(inputs.shape))\n    y = model.apply(variables, inputs)\n\n    return y\n\n\ndef spectral_normalization_conv(\n    params: PyTree, traversal: ModelParamTraversal, xshape: Shape, n_steps: int = 10\n) -> PyTree:\n    \"\"\"Normalize parameters of convolutional layer by its spectral norm.\n\n    Each convolution kernel selected by `traversal` is divided by 1.02\n    times an estimate of its spectral norm, so that the normalized\n    operator has a spectral norm of approximately 1/1.02.\n\n    Args:\n        params: Current model parameters.\n        traversal: Utility to select model parameters.\n        xshape: Shape of input. Only the spatial dimensions `xshape[1]`\n            and `xshape[2]` are used.\n        n_steps: Number of power iterations to compute. Default: 10.\n\n    Returns:\n        Frozen model parameters with normalized convolution kernels.\n    \"\"\"\n\n    def _normalize(kernel: Array) -> Array:\n        # Single sample with the spatial size of the data and the number\n        # of input channels of the kernel\n        input_shape = (1, xshape[1], xshape[2], kernel.shape[2])\n        # n_steps must be passed by keyword: the third positional\n        # parameter of estimate_spectral_norm is the random seed\n        snorm = estimate_spectral_norm(lambda x: conv(x, kernel), input_shape, n_steps=n_steps)\n        # The 1.02 factor keeps the normalized norm slightly below one\n        return kernel / (snorm * 1.02)\n\n    params_out = traversal.update(_normalize, unfreeze(params))\n\n    return freeze(params_out)\n\n\n# From https://nbviewer.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc\ndef exact_spectral_norm(f, input_shape):\n    \"\"\"Compute spectral norm of operator.\n\n    This function computes the spectral norm of an operator via 
autodiff\n    in JAX.\n\n    Args:\n        f: Operator to compute spectral norm.\n        input_shape: Shape of input to operator.\n\n    Returns:\n        Spectral norm.\n    \"\"\"\n    dummy_input = jnp.zeros(input_shape)\n    jacobian = jax.jacfwd(f)(dummy_input)\n    shape = (np.prod(jacobian.shape[: -dummy_input.ndim]), np.prod(input_shape))\n    return scipy.linalg.svdvals(jacobian.reshape(shape)).max()\n"
  },
  {
    "path": "scico/flax/train/state.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Configuration of Flax Train State.\"\"\"\n\nfrom typing import Any, Optional, Tuple, Union\n\nimport jax\nimport jax.numpy as jnp\n\nimport optax\n\nfrom flax.training import train_state\nfrom scico.numpy import Array\nfrom scico.typing import Shape\n\nfrom .typed_dict import ConfigDict, ModelVarDict\n\nModuleDef = Any\nKeyArray = Union[Array, jax.Array]\nPyTree = Any\nArrayTree = optax.Params\n\n\nclass TrainState(train_state.TrainState):\n    \"\"\"Definition of Flax train state.\n\n    Definition of Flax train state including `batch_stats` for batch\n    normalization.\n    \"\"\"\n\n    batch_stats: Any\n\n\ndef initialize(\n    key: KeyArray, model: ModuleDef, ishape: Shape\n) -> Union[PyTree, Tuple[PyTree, PyTree]]:\n    \"\"\"Initialize Flax model.\n\n    Args:\n        key: A PRNGKey used as the random key.\n        model: Flax model to train.\n        ishape: Shape of signal (image) to process by `model`. Make sure\n            that no batch dimension is included.\n\n    Returns:\n        A tuple `(params, batch_stats)` of initial model parameters and\n        batch statistics if the model has a `batch_stats` collection,\n        otherwise only the initial model parameters.\n    \"\"\"\n    input_shape = (1, ishape[0], ishape[1], model.channels)\n\n    @jax.jit\n    def init(*args):\n        return model.init(*args)\n\n    variables = init({\"params\": key}, jnp.ones(input_shape, model.dtype))\n    if \"batch_stats\" in variables:\n        return variables[\"params\"], variables[\"batch_stats\"]\n    return variables[\"params\"]\n\n\ndef create_basic_train_state(\n    key: KeyArray,\n    config: ConfigDict,\n    model: ModuleDef,\n    ishape: Shape,\n    learning_rate_fn: optax._src.base.Schedule,\n    variables0: Optional[ModelVarDict] = None,\n) -> TrainState:\n    \"\"\"Create Flax basic train state and 
initialize.\n\n    Args:\n        key: A PRNGKey used as the random key.\n        config: Dictionary of configuration. The values to use correspond\n            to keywords: `opt_type` and `momentum`.\n        model: Flax model to train.\n        ishape: Shape of signal (image) to process by `model`. Ensure\n            that no batch dimension is included.\n        learning_rate_fn: A function that maps step counts to values.\n        variables0: Optional initial state of model parameters. If not\n            provided a random initialization is performed. Default:\n            ``None``.\n\n    Returns:\n        state: Flax train state which includes the model apply function,\n           the model parameters and an Optax optimizer. Its `batch_stats`\n           attribute is ``None`` if the model has no batch statistics.\n\n    Raises:\n        NotImplementedError: If `opt_type` is not one of SGD, ADAM or\n            ADAMW.\n    \"\"\"\n    batch_stats = None\n    if variables0 is None:\n        aux = initialize(key, model, ishape)\n        # initialize returns either a (params, batch_stats) tuple or the\n        # bare params. Test the type rather than the length, since a\n        # params dict with several top-level entries also has length > 1.\n        if isinstance(aux, tuple):\n            params, batch_stats = aux\n        else:\n            params = aux\n    else:\n        params = variables0[\"params\"]\n        if \"batch_stats\" in variables0:\n            batch_stats = variables0[\"batch_stats\"]\n\n    if config[\"opt_type\"] == \"SGD\":\n        # Stochastic Gradient Descent optimiser (with Nesterov momentum\n        # if a momentum value is configured)\n        if \"momentum\" in config:\n            tx = optax.sgd(\n                learning_rate=learning_rate_fn, momentum=config[\"momentum\"], nesterov=True\n            )\n        else:\n            tx = optax.sgd(learning_rate=learning_rate_fn)\n    elif config[\"opt_type\"] == \"ADAM\":\n        # Adam optimiser\n        tx = optax.adam(\n            learning_rate=learning_rate_fn,\n        )\n    elif config[\"opt_type\"] == \"ADAMW\":\n        # Adam with weight decay regularization\n        tx = optax.adamw(\n            learning_rate=learning_rate_fn,\n        )\n    else:\n        raise NotImplementedError(\n            f\"Optimizer specified {config['opt_type']} has not been included in SCICO.\"\n        )\n\n    # batch_stats is declared without a default in TrainState, so it is\n    # always passed explicitly (as None when the model has no batch\n    # statistics) to avoid a missing-argument error.\n    state = TrainState.create(\n        apply_fn=model.apply,\n        params=params,\n        tx=tx,\n        batch_stats=batch_stats,\n    )\n\n    return state\n"
  },
  {
    "path": "scico/flax/train/steps.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Definition of steps to iterate during training or evaluation.\"\"\"\n\nfrom typing import Any, Callable, List, Tuple, Union\n\nimport jax\nfrom jax import lax\n\nimport optax\n\nfrom scico.numpy import Array\n\nfrom .state import TrainState\nfrom .typed_dict import DataSetDict, MetricsDict\n\nKeyArray = Union[Array, jax.Array]\nPyTree = Any\n\n\ndef train_step(\n    state: TrainState,\n    batch: DataSetDict,\n    learning_rate_fn: optax._src.base.Schedule,\n    criterion: Callable,\n    metrics_fn: Callable,\n) -> Tuple[TrainState, MetricsDict]:\n    \"\"\"Perform a single data parallel training step.\n\n    Assumes sharded batched data. This function is intended to be used via\n    :class:`~scico.flax.BasicFlaxTrainer`, not directly. It must be run\n    under :func:`jax.pmap` with `axis_name=\"batch\"`, since gradients are\n    averaged across devices via :func:`jax.lax.pmean` over that axis.\n\n    Args:\n        state: Flax train state which includes the model apply function,\n            the model parameters and an Optax optimizer.\n        batch: Sharded and batched training data.\n        learning_rate_fn: A function to map step\n            counts to values. This is only used for display purposes\n            (optax optimizers are stateless, so the current learning rate\n            is not stored). The real learning rate schedule applied is the\n            one defined when creating the Flax state. 
If a different\n            object is passed here, then the displayed value will be\n            inaccurate.\n        criterion: A function that specifies the loss being minimized in\n            training.\n        metrics_fn: A function to evaluate quality of current model.\n\n    Returns:\n        Updated train state and diagnostic statistics. The statistics\n        include the learning rate for the current step under the key\n        `learning_rate`.\n    \"\"\"\n\n    def loss_fn(params: PyTree):\n        \"\"\"Loss function used for training.\"\"\"\n        output, new_model_state = state.apply_fn(\n            {\n                \"params\": params,\n                \"batch_stats\": state.batch_stats,\n            },\n            batch[\"image\"],\n            mutable=[\"batch_stats\"],\n        )\n        loss = criterion(output, batch[\"label\"])\n        return loss, (new_model_state, output)\n\n    step = state.step\n    # Only to figure out current learning rate, which cannot be stored in stateless optax.\n    # Requires agreement between the function passed here and the one used to create the\n    # train state.\n    lr = learning_rate_fn(step)\n\n    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n    (_, (new_model_state, output)), grads = grad_fn(state.params)\n    # Re-use same axis_name as in call to pmap\n    grads = lax.pmean(grads, axis_name=\"batch\")\n    metrics = metrics_fn(output, batch[\"label\"], criterion)\n    metrics[\"learning_rate\"] = lr\n\n    # Update params and stats\n    new_state = state.apply_gradients(\n        grads=grads,\n        batch_stats=new_model_state[\"batch_stats\"],\n    )\n\n    return new_state, metrics\n\n\ndef train_step_post(\n    state: TrainState,\n    batch: DataSetDict,\n    learning_rate_fn: optax._src.base.Schedule,\n    criterion: Callable,\n    train_step_fn: Callable,\n    metrics_fn: Callable,\n    post_lst: List[Callable],\n) -> Tuple[TrainState, MetricsDict]:\n    \"\"\"Perform a single data parallel training step with postprocessing.\n\n    A list of postprocessing functions (e.g. 
for spectral normalization\n    or positivity condition, etc.) is applied after the gradient update,\n    in list order, each one to the parameters produced by the previous\n    one. Assumes sharded batched data.\n\n    This function is intended to be used via\n    :class:`~scico.flax.BasicFlaxTrainer`, not directly.\n\n    Args:\n        state: Flax train state which includes the model apply function,\n            the model parameters and an Optax optimizer.\n        batch: Sharded and batched training data.\n        learning_rate_fn: A function to map step counts to values.\n        criterion: A function that specifies the loss being minimized in\n            training.\n        train_step_fn: A function that executes a training step.\n        metrics_fn: A function to evaluate quality of current model.\n        post_lst: List of postprocessing functions to apply to parameter\n            set after optimizer step (e.g. clip to a specified range,\n            normalize, etc.).\n\n    Returns:\n        Updated train state, with parameters fulfilling additional\n        constraints, and diagnostic statistics.\n    \"\"\"\n\n    new_state, metrics = train_step_fn(state, batch, learning_rate_fn, criterion, metrics_fn)\n\n    # Post-process parameters\n    for post_fn in post_lst:\n        new_params = post_fn(new_state.params)\n        new_state = new_state.replace(params=new_params)\n\n    return new_state, metrics\n\n\ndef eval_step(\n    state: TrainState, batch: DataSetDict, criterion: Callable, metrics_fn: Callable\n) -> MetricsDict:\n    \"\"\"Evaluate current model state.\n\n    Assumes sharded batched data. This function is intended to be used\n    via :class:`~scico.flax.BasicFlaxTrainer` or\n    :meth:`~scico.flax.only_evaluate`, not directly. 
The model apply\n    function must accept a `train` keyword argument.\n\n    Args:\n        state: Flax train state which includes the model apply function\n            and the model parameters.\n        batch: Sharded and batched evaluation data.\n        criterion: Loss function.\n        metrics_fn: A function to evaluate quality of current model.\n\n    Returns:\n        Current diagnostic statistics.\n    \"\"\"\n    variables = {\n        \"params\": state.params,\n        \"batch_stats\": state.batch_stats,\n    }\n    output = state.apply_fn(variables, batch[\"image\"], train=False, mutable=False)\n    return metrics_fn(output, batch[\"label\"], criterion)\n"
  },
  {
    "path": "scico/flax/train/trainer.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Class providing integrated access to functionality for training Flax\n   models.\n\nAssumes sharded batched data and uses data parallel training.\n\"\"\"\n\nimport warnings\n\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\nimport functools\nimport time\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport jax\nfrom jax import lax\n\nfrom flax import jax_utils\nfrom flax.training import common_utils\nfrom scico.diagnostics import IterationStats\nfrom scico.numpy import Array\n\nfrom .checkpoints import checkpoint_restore, checkpoint_save\nfrom .clu_utils import get_parameter_overview\nfrom .diagnostics import ArgumentStruct, compute_metrics, stats_obj\nfrom .input_pipeline import create_input_iter\nfrom .learning_rate import create_cnst_lr_schedule\nfrom .losses import mse_loss\nfrom .state import TrainState, create_basic_train_state\nfrom .steps import eval_step, train_step, train_step_post\nfrom .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict\n\nModuleDef = Any\nKeyArray = Union[Array, jax.Array]\nPyTree = Any\nDType = Any\n\n\n# sync across replicas\ndef sync_batch_stats(state: TrainState) -> TrainState:\n    \"\"\"Sync the batch statistics across replicas.\"\"\"\n    # Each device has its own version of the running average batch\n    # statistics and those are synced before evaluation\n    return state.replace(batch_stats=cross_replica_mean(state.batch_stats))\n\n\n# pmean only works inside pmap because it needs an axis name.\n#: This function will average the inputs across all devices.\ncross_replica_mean = jax.pmap(lambda x: lax.pmean(x, \"x\"), \"x\")\n\n\nclass BasicFlaxTrainer:\n    \"\"\"Class 
encapsulating Flax training configuration and execution.\"\"\"\n\n    def __init__(\n        self,\n        config: ConfigDict,\n        model: ModuleDef,\n        train_ds: DataSetDict,\n        test_ds: DataSetDict,\n        variables0: Optional[ModelVarDict] = None,\n    ):\n        \"\"\"Initializer for :class:`BasicFlaxTrainer`.\n\n        Initializer for :class:`BasicFlaxTrainer` to configure model\n        training and evaluation loop. Construct a Flax train state (which\n        includes the model apply function, the model parameters and an\n        Optax optimizer). This uses data parallel training assuming\n        sharded batched data.\n\n        Args:\n            config: Hyperparameter configuration.\n            model: Flax model to train.\n            train_ds: Dictionary of training data (includes images and\n                labels).\n            test_ds: Dictionary of testing data (includes images and\n                labels).\n            variables0: Optional initial state of model parameters.\n        \"\"\"\n        # Configure seed\n        if \"seed\" not in config:\n            key = jax.random.key(0)\n        else:\n            key = jax.random.key(config[\"seed\"])\n        # Split seed for data iterators and model initialization\n        key1, key2 = jax.random.split(key)\n\n        # Object for storing iteration stats\n        self.itstat_object: Optional[IterationStats] = None\n\n        # Configure training loop\n        len_train = train_ds[\"image\"].shape[0]\n        len_test = test_ds[\"image\"].shape[0]\n        self.set_training_parameters(config, len_train, len_test)\n        self.construct_data_iterators(train_ds, test_ds, key1, model.dtype)\n\n        self.define_parallel_training_functions()\n\n        self.initialize_training_state(config, key2, model, variables0)\n\n        # Store configuration\n        self.config = config\n\n    def set_training_parameters(\n        self,\n        config: ConfigDict,\n        len_train: 
int,\n        len_test: int,\n    ):\n        \"\"\"Extract configuration parameters and construct training functions.\n\n        Parameters and functions are passed in the configuration\n        dictionary. Default values are used when parameters are not\n        included in configuration.\n\n        Args:\n            config: Hyperparameter configuration.\n            len_train: Number of samples in training set.\n            len_test: Number of samples in testing set.\n        \"\"\"\n        self.configure_steps(config, len_train, len_test)\n        self.configure_reporting(config)\n        self.configure_training_functions(config)\n\n    def configure_steps(\n        self,\n        config: ConfigDict,\n        len_train: int,\n        len_test: int,\n    ):\n        \"\"\"Configure training, evaluation and monitoring steps.\n\n        Args:\n            config: Hyperparameter configuration.\n            len_train: Number of samples in training set.\n            len_test: Number of samples in testing set.\n        \"\"\"\n        # Set required defaults if not present\n        if \"batch_size\" not in config:\n            batch_size = 2 * jax.device_count()\n        else:\n            batch_size = config[\"batch_size\"]\n        if \"num_epochs\" not in config:\n            num_epochs = 10\n        else:\n            num_epochs = config[\"num_epochs\"]\n\n        # Determine sharded vs. 
batch partition\n        if batch_size % jax.device_count() > 0:\n            raise ValueError(\"Batch size must be divisible by the number of devices.\")\n        self.local_batch_size: int = batch_size // jax.process_count()\n\n        # Training steps\n        self.steps_per_epoch: int = len_train // batch_size\n        config[\"steps_per_epoch\"] = self.steps_per_epoch  # needed for creating lr schedule\n        self.num_steps: int = int(self.steps_per_epoch * num_epochs)\n\n        # Evaluation (over testing set) steps\n        num_validation_examples: int = len_test\n        if \"steps_per_eval\" not in config:\n            self.steps_per_eval: int = num_validation_examples // batch_size\n        else:\n            self.steps_per_eval = config[\"steps_per_eval\"]\n\n        # Determine monitoring steps\n        if \"steps_per_checkpoint\" not in config:\n            self.steps_per_checkpoint: int = self.steps_per_epoch * 10\n        else:\n            self.steps_per_checkpoint = config[\"steps_per_checkpoint\"]\n\n        if \"log_every_steps\" not in config:\n            self.log_every_steps: int = self.steps_per_epoch * 20\n        else:\n            self.log_every_steps = config[\"log_every_steps\"]\n\n    def configure_reporting(self, config: ConfigDict):\n        \"\"\"Configure logging and checkpointing.\n\n        The parameters configured correspond to\n\n        - **logflag**: A flag for logging to the output terminal the\n              evolution of results. Default: ``False``.\n        - **workdir**: Directory to write checkpoints. Default: execution\n              directory.\n        - **checkpointing**: A flag for checkpointing model state.\n              Default: ``False``.\n        - **return_state**: A flag for returning the train state instead\n              of the model variables. Default: ``False``, i.e. 
return\n              model variables.\n\n        Args:\n            config: Hyperparameter configuration.\n        \"\"\"\n\n        # Determine logging configuration\n        if \"log\" in config:\n            self.logflag: bool = config[\"log\"]\n            if self.logflag:\n                self.itstat_object, self.itstat_insert_func = stats_obj()\n        else:\n            self.logflag = False\n\n        # Determine checkpointing configuration\n        if \"workdir\" in config:\n            self.workdir: str = config[\"workdir\"]\n        else:\n            self.workdir = \"./\"\n\n        if \"checkpointing\" in config:\n            self.checkpointing: bool = config[\"checkpointing\"]\n        else:\n            self.checkpointing = False\n\n        # Determine variable to return at end of training\n        if \"return_state\" in config:\n            # Returning Flax train state\n            self.return_state = config[\"return_state\"]\n        else:\n            # Return model variables\n            self.return_state = False\n\n    def configure_training_functions(self, config: ConfigDict):\n        \"\"\"Construct training functions.\n\n        Default functions are used if not specified in configuration.\n\n        The parameters configured correspond to\n\n        - **lr_schedule**: A function that creates an Optax learning rate\n              schedule. Default: :meth:`~scico.flax.train.learning_rate.create_cnst_lr_schedule`.\n        - **criterion**: A function that specifies the loss being minimized\n              in training. Default: :meth:`~scico.flax.train.losses.mse_loss`.\n        - **create_train_state**: A function that creates a Flax train state\n              and initializes it. A train state object helps to keep optimizer\n              and module functionality grouped for training. 
Default:\n              :meth:`~scico.flax.train.state.create_basic_train_state`.\n        - **train_step_fn**: A function that executes a training step.\n              Default: :meth:`~scico.flax.train.steps.train_step`, i.e.\n              use the standard train step.\n        - **eval_step_fn**: A function that executes an eval step. Default:\n              :meth:`~scico.flax.train.steps.eval_step`, i.e. use the\n              standard eval step.\n        - **metrics_fn**: A function that computes metrics. Default:\n              :meth:`~scico.flax.train.diagnostics.compute_metrics`, i.e.\n              use the standard compute metrics function.\n        - **post_lst**: List of postprocessing functions to apply to\n              parameter set after optimizer step (e.g. clip to a specified\n              range, normalize, etc.).\n\n        Args:\n            config: Hyperparameter configuration.\n        \"\"\"\n\n        if \"lr_schedule\" in config:\n            create_lr_schedule: Callable = config[\"lr_schedule\"]\n            self.lr_schedule = create_lr_schedule(config)\n        else:\n            self.lr_schedule = create_cnst_lr_schedule(config)\n\n        if \"criterion\" in config:\n            self.criterion: Callable = config[\"criterion\"]\n        else:\n            self.criterion = mse_loss\n\n        if \"create_train_state\" in config:\n            self.create_train_state: Callable = config[\"create_train_state\"]\n        else:\n            self.create_train_state = create_basic_train_state\n\n        if \"train_step_fn\" in config:\n            self.train_step_fn: Callable = config[\"train_step_fn\"]\n        else:\n            self.train_step_fn = train_step\n\n        if \"eval_step_fn\" in config:\n            self.eval_step_fn: Callable = config[\"eval_step_fn\"]\n        else:\n            self.eval_step_fn = eval_step\n\n        if \"metrics_fn\" in config:\n            self.metrics_fn: Callable = config[\"metrics_fn\"]\n        else:\n   
         self.metrics_fn = compute_metrics\n\n        self.post_lst: Optional[List[Callable]] = None\n        if \"post_lst\" in config:\n            self.post_lst = config[\"post_lst\"]\n\n    def construct_data_iterators(\n        self,\n        train_ds: DataSetDict,\n        test_ds: DataSetDict,\n        key: KeyArray,\n        mdtype: DType,\n    ):\n        \"\"\"Construct iterators for training and testing (evaluation) sets.\n\n        Args:\n            train_ds: Dictionary of training data (includes images\n                and labels).\n            test_ds: Dictionary of testing data (includes images\n                and labels).\n            key: A PRNGKey used as the random key.\n            mdtype: Output type of Flax model to be trained.\n        \"\"\"\n        size_device_prefetch = 2  # Set for GPU\n\n        self.train_dt_iter = create_input_iter(\n            key,\n            train_ds,\n            self.local_batch_size,\n            size_device_prefetch,\n            mdtype,\n            train=True,\n        )\n        self.eval_dt_iter = create_input_iter(\n            key,  # eval: no permutation\n            test_ds,\n            self.local_batch_size,\n            size_device_prefetch,\n            mdtype,\n            train=False,\n        )\n\n        self.ishape = train_ds[\"image\"].shape[1:3]\n        self.log(\n            \"channels: %d   training signals: %d   testing\"\n            \" signals: %d   signal size: %d\\n\"\n            % (\n                train_ds[\"label\"].shape[-1],\n                train_ds[\"label\"].shape[0],\n                test_ds[\"label\"].shape[0],\n                train_ds[\"label\"].shape[1],\n            )\n        )\n\n    def define_parallel_training_functions(self):\n        \"\"\"Construct parallel versions of training functions.\n\n        Construct parallel versions of training functions via\n        :func:`jax.pmap`.\n        \"\"\"\n        if self.post_lst is not None:\n            
self.p_train_step = jax.pmap(\n                functools.partial(\n                    train_step_post,\n                    train_step_fn=self.train_step_fn,\n                    learning_rate_fn=self.lr_schedule,\n                    criterion=self.criterion,\n                    metrics_fn=self.metrics_fn,\n                    post_lst=self.post_lst,\n                ),\n                axis_name=\"batch\",\n            )\n        else:\n            self.p_train_step = jax.pmap(\n                functools.partial(\n                    self.train_step_fn,\n                    learning_rate_fn=self.lr_schedule,\n                    criterion=self.criterion,\n                    metrics_fn=self.metrics_fn,\n                ),\n                axis_name=\"batch\",\n            )\n        self.p_eval_step = jax.pmap(\n            functools.partial(\n                self.eval_step_fn, criterion=self.criterion, metrics_fn=self.metrics_fn\n            ),\n            axis_name=\"batch\",\n        )\n\n    def initialize_training_state(\n        self,\n        config: ConfigDict,\n        key: KeyArray,\n        model: ModuleDef,\n        variables0: Optional[ModelVarDict] = None,\n    ):\n        \"\"\"Construct and initialize Flax train state.\n\n        A train state object helps to keep optimizer and module\n        functionality grouped for training.\n\n        Args:\n            config: Hyperparameter configuration.\n            key: A PRNGKey used as the random key.\n            model: Flax model to train.\n            variables0: Optional initial state of model parameters.\n        \"\"\"\n        # Create Flax training state\n        state = self.create_train_state(\n            key, config, model, self.ishape, self.lr_schedule, variables0\n        )\n        # Only restore if no initialization is provided\n        if self.checkpointing and variables0 is None:\n            ok_no_ckpt = True  # It is ok if no checkpoint is found\n            state = 
checkpoint_restore(state, self.workdir, ok_no_ckpt)\n\n        self.log(\"Network Structure:\")\n        self.log(get_parameter_overview(state.params) + \"\\n\")\n        if hasattr(state, \"batch_stats\"):\n            self.log(\"Batch Normalization:\")\n            self.log(get_parameter_overview(state.batch_stats) + \"\\n\")\n\n        self.state = state\n\n    def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]:\n        \"\"\"Execute training loop.\n\n        Returns:\n            Model variables extracted from :class:`.TrainState` and\n            iteration stats object obtained after executing the training\n            loop. Alternatively the :class:`.TrainState` can be returned\n            directly instead of the model variables. Note that the\n            iteration stats object is not ``None`` only if log is enabled\n            when configuring the training loop.\n        \"\"\"\n        state = self.state\n        step_offset = int(state.step)  # > 0 if restarting from checkpoint\n\n        # For parallel training\n        state = jax_utils.replicate(state)\n        # Execute training loop and register stats\n        t0 = time.time()\n        self.log(\"Initial compilation, which might take some time ...\")\n\n        train_metrics: List[Any] = []\n\n        for step, batch in zip(range(step_offset, self.num_steps), self.train_dt_iter):\n            state, metrics = self.p_train_step(state, batch)\n            # Training metrics computed in step\n            train_metrics.append(metrics)\n            if step == step_offset:\n                self.log(\"Initial compilation completed.\\n\")\n            if (step + 1) % self.log_every_steps == 0:\n                # sync batch statistics across replicas\n                state = sync_batch_stats(state)\n                self.update_metrics(state, step, train_metrics, t0)\n                train_metrics = []\n            if (step + 1) % self.steps_per_checkpoint == 0 or step + 1 == 
self.num_steps:\n                # sync batch statistics across replicas\n                state = sync_batch_stats(state)\n                self.checkpoint(state)\n\n        # Wait for finishing asynchronous execution\n        jax.random.normal(jax.random.key(0), ()).block_until_ready()\n        # Close object for iteration stats if logging\n        if self.logflag:\n            assert self.itstat_object is not None\n            self.itstat_object.end()\n\n        state = sync_batch_stats(state)\n        # Final checkpointing\n        self.checkpoint(state)\n        # Extract one copy of state\n        state = jax_utils.unreplicate(state)\n        if self.return_state:\n            return state, self.itstat_object  # type: ignore\n\n        dvar: ModelVarDict = {\n            \"params\": state.params,\n            \"batch_stats\": state.batch_stats,\n        }\n\n        self.train_time = time.time() - t0\n\n        return dvar, self.itstat_object  # type: ignore\n\n    def update_metrics(self, state: TrainState, step: int, train_metrics: List[MetricsDict], t0):\n        \"\"\"Compute metrics for current model state.\n\n        Metrics for training and testing (eval) sets are computed and\n        stored in an iteration stats object. 
This is executed only if\n        logging is enabled.\n\n        Args:\n            state: Flax train state which includes the model apply\n                function and the model parameters.\n            step: Current step in training.\n            train_metrics: List of diagnostic statistics computed from\n                training set.\n            t0: Time when training loop started.\n        \"\"\"\n        if not self.logflag:\n            return\n\n        eval_metrics: List[Any] = []\n\n        # Build summary dictionary for logging\n        # Include training stats\n        train_metrics = common_utils.get_metrics(train_metrics)\n        summary = {\n            f\"train_{k}\": v\n            for k, v in jax.tree_util.tree_map(lambda x: x.mean(), train_metrics).items()\n        }\n        epoch = step // self.steps_per_epoch\n        summary[\"epoch\"] = epoch\n        summary[\"time\"] = time.time() - t0\n\n        # Eval over testing set\n        for _ in range(self.steps_per_eval):\n            eval_batch = next(self.eval_dt_iter)\n            metrics = self.p_eval_step(state, eval_batch)\n            eval_metrics.append(metrics)\n        # Compute testing metrics\n        eval_metrics = common_utils.get_metrics(eval_metrics)\n\n        # Add testing stats to summary\n        summary_eval = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)\n        summary.update(summary_eval)\n\n        # Update iteration stats object\n        assert isinstance(self.itstat_object, IterationStats)  # for mypy\n        self.itstat_object.insert(self.itstat_insert_func(ArgumentStruct(**summary)))\n\n    def checkpoint(self, state: TrainState):  # pragma: no cover\n        \"\"\"Checkpoint training state if enabled.\n\n        Args:\n            state: Flax train state.\n        \"\"\"\n        if self.checkpointing:\n            checkpoint_save(jax_utils.unreplicate(state), self.config, self.workdir)\n\n    def log(self, logstr: str):\n        \"\"\"Print stats to 
output terminal if logging is enabled.\n\n        Args:\n            logstr: String to be logged.\n        \"\"\"\n        if self.logflag:\n            print(logstr)\n"
  },
  {
    "path": "scico/flax/train/traversals.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionality to traverse, select, and update model parameters.\"\"\"\n\nfrom typing import Any\n\nimport jax.numpy as jnp\n\nfrom flax.traverse_util import ModelParamTraversal\n\nPyTree = Any\n\n\ndef construct_traversal(prmname: str) -> ModelParamTraversal:\n    \"\"\"Construct utility to select model parameters using a name filter.\n\n    Args:\n        prmname: Name of parameter to select.\n\n    Returns:\n        Flax utility to traverse and select model parameters.\n    \"\"\"\n    return ModelParamTraversal(lambda path, _: prmname in path)\n\n\ndef clip_positive(params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4) -> PyTree:\n    \"\"\"Clip parameters to positive range.\n\n    Args:\n        params: Current model parameters.\n        traversal: Utility to select model parameters.\n        minval: Minimum value to clip selected model parameters and keep\n            them in a positive range. Default: 1e-4.\n    \"\"\"\n    params_out = traversal.update(lambda x: jnp.clip(x, minval), params)\n\n    return params_out\n\n\ndef clip_range(\n    params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4, maxval: float = 1\n) -> PyTree:\n    \"\"\"Clip parameters to specified range.\n\n    Args:\n        params: Current model parameters.\n        traversal: Utility to select model parameters.\n        minval: Minimum value to clip selected model parameters.\n            Default: 1e-4.\n        maxval: Maximum value to clip selected model parameters.\n            Default: 1.\n    \"\"\"\n    params_out = traversal.update(lambda x: jnp.clip(x, minval, maxval), params)\n\n    return params_out\n"
  },
  {
    "path": "scico/flax/train/typed_dict.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Definition of typed dictionaries for objects in training functionality.\"\"\"\n\nimport sys\nfrom typing import Any, Callable, List\n\nif sys.version_info >= (3, 8):\n    from typing import TypedDict  # pylint: disable=no-name-in-module\nelse:\n    from typing_extensions import TypedDict\n\nfrom scico.numpy import Array\n\nPyTree = Any\n\n\nclass DataSetDict(TypedDict):\n    \"\"\"Dictionary structure for training data sets.\n\n    Definition of the dictionary structure\n    expected for the training data sets.\n    \"\"\"\n\n    #: Input (Num. samples x Height x Width x Channels).\n    image: Array\n    #: Output (Num. samples x Height x Width x Channels) or (Num. samples x Classes).\n    label: Array\n\n\nclass ConfigDict(TypedDict):\n    \"\"\"Dictionary structure for training parameters.\n\n    Definition of the dictionary structure expected for specifying\n    training parameters.\n    \"\"\"\n\n    #: Value to initialize seed for random generation.\n    seed: float\n    #: Type of optimizer. 
Options: SGD, ADAM, ADAMW.\n    opt_type: str\n    #: Momentum for SGD optimizer in case Nesterov is ``True``.\n    momentum: float\n    #: Size of batch for training.\n    batch_size: int\n    #: Number of epochs for training (an epoch is one whole pass through the training dataset).\n    num_epochs: int\n    #: Starting learning rate for scheduling.\n    base_learning_rate: float\n    #: Rate for decaying learning rate when scheduling is used.\n    lr_decay_rate: float\n    #: Number of epochs if warmup scheduling is used.\n    warmup_epochs: int\n    #: Period of training steps to evaluate over test set.\n    steps_per_eval: int\n    #: Period of training steps to print current train and test metrics.\n    log_every_steps: int\n    #: Training steps to be executed per epoch (depends on batch size).\n    steps_per_epoch: int\n    #: Period of training steps to save model (if checkpointing is ``True``).\n    steps_per_checkpoint: int\n    #: Flag to indicate if evolution metrics are to be printed.\n    log: bool\n    #: Path to directory for checkpointing model parameters.\n    workdir: str\n    #: Flag to indicate if model parameters and optimizer state are to\n    #: be stored while training.\n    checkpointing: bool\n    #: Flag to indicate if state (params and batch_stats) are to\n    #: be returned at the end of training.\n    return_state: bool\n    #: Function to modify the learning rate while training (type optax schedule).\n    lr_schedule: Callable\n    #: Criterion to optimize during training.\n    criterion: Callable\n    #: Function to create and initialize trainig state. 
Should include initialization\n    #: of optimizer and of batch_stats (if applicable).\n    create_train_state: Callable\n    #: Function to execute each training step.\n    train_step_fn: Callable\n    #: Function to execute each evaluation step.\n    eval_step_fn: Callable\n    #: Function to track metrics during training.\n    metrics_fn: Callable\n    #: List of post-processing functions to apply after a train step (if any).\n    post_lst: List[Callable]\n\n\nclass ModelVarDict(TypedDict):\n    \"\"\"Dictionary structure for Flax variables.\n\n    Definition of the dictionary structure grouping all Flax model\n    variables.\n    \"\"\"\n\n    #: Model weights and biases.\n    params: PyTree\n    #: Batch statistics (e.g. normalization parameters that depend on training data).\n    batch_stats: PyTree\n\n\nclass MetricsDict(TypedDict, total=False):\n    \"\"\"Dictionary structure for training metrics.\n\n    Definition of the dictionary structure for metrics computed or\n    updates made during training.\n    \"\"\"\n\n    loss: float  #: Evaluation of criterion being optimized.\n    snr: float  #: Evaluation of signal to noise ratio.\n    learning_rate: float  #: Current learning rate.\n"
  },
  {
    "path": "scico/function.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Function class.\"\"\"\n\nfrom typing import Any, Callable, Optional, Sequence, Tuple, Union\n\nimport jax\n\nimport scico\nimport scico.numpy as snp\nfrom scico.linop import LinearOperator, jacobian\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import dtype_name\nfrom scico.operator import Operator\nfrom scico.typing import BlockShape, DType, Shape\n\n\nclass Function:\n    r\"\"\"Function class.\n\n    A :class:`Function` maps multiple :code:`array-like` arguments to\n    another :code:`array-like`. It is more general than both\n    :class:`.Functional`, which is a mapping to a scalar, and\n    :class:`.Operator`, which takes a single argument.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shapes: Sequence[Union[Shape, BlockShape]],\n        output_shape: Optional[Union[Shape, BlockShape]] = None,\n        eval_fn: Optional[Callable] = None,\n        input_dtypes: Union[DType, Sequence[DType]] = snp.float32,\n        output_dtype: Optional[DType] = None,\n        jit: bool = False,\n    ):\n        \"\"\"\n        Args:\n            input_shapes: Shapes of input arrays.\n            output_shape: Shape of output array. Defaults to ``None``.\n                If ``None``, `output_shape` is determined by evaluating\n                `self.__call__` on input arrays of zeros.\n            eval_fn: Function used in evaluating this :class:`Function`.\n                Defaults to ``None``. Required unless `__init__` is being\n                called from a derived class with an `_eval` method.\n            input_dtypes: `dtype` for input argument. 
If a single `dtype`\n                is specified, it implies a common `dtype` for all inputs,\n                otherwise a list or tuple of values should be provided,\n                one per input. Defaults to :attr:`~numpy.float32`.\n            output_dtype: `dtype` for output argument. Defaults to\n                ``None``. If ``None``, `output_dtype` is determined by\n                evaluating `self.__call__` on an input arrays of zeros.\n            jit: If ``True``,  jit the evaluation function.\n        \"\"\"\n        self.jit = jit\n        self.input_shapes = input_shapes\n        if isinstance(input_dtypes, (list, tuple)):\n            self.input_dtypes = input_dtypes\n        else:\n            self.input_dtypes = (input_dtypes,) * len(input_shapes)\n\n        if eval_fn is not None:\n            self._eval = jax.jit(eval_fn) if jit else eval_fn\n        elif not hasattr(self, \"_eval\"):\n            raise NotImplementedError(\n                \"Function is an abstract base class when argument 'eval_fn' is not specified.\"\n            )\n\n        # If the output shape/dtype aren't specified, they can be inferred\n        # using scico.eval_shape\n        if output_shape is None or output_dtype is None:\n            dts_in = [\n                jax.ShapeDtypeStruct(shape, dtype=dtype)\n                for (shape, dtype) in zip(self.input_shapes, self.input_dtypes)\n            ]\n            dts_out = scico.eval_shape(self._eval, *dts_in)\n        if output_shape is None:\n            self.output_shape = dts_out.shape  # type: ignore\n        else:\n            self.output_shape = output_shape\n        if output_dtype is None:\n            self.output_dtype = dts_out.dtype\n        else:\n            self.output_dtype = output_dtype\n\n    def __repr__(self):\n        return f\"\"\"{self.__module__}.{self.__class__.__qualname__}\n  input_shapes: {self.input_shapes}\n  output_shape: {self.output_shape}\n  input_dtypes: {\", \".join([dtype_name(dt) 
for dt in self.input_dtypes])}\n  output_dtype: {dtype_name(self.output_dtype)}\n\"\"\"\n\n    def __call__(self, *args: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Evaluate this function with the specified parameters.\n\n        Args:\n           *args: Parameters at which to evaluate the function.\n\n        Returns:\n           Value of function with specified parameters.\n        \"\"\"\n        return self._eval(*args)\n\n    def slice(self, index: int, *fix_args: Union[Array, BlockArray]) -> Operator:\n        \"\"\"Fix all but one parameter, returning a :class:`.Operator`.\n\n        Args:\n           index: Index of parameter that remains free.\n           *fix_args: Fixed values for remaining parameters.\n\n        Returns:\n           An :class:`.Operator` taking the free parameter of the\n           :class:`Function` as its input.\n        \"\"\"\n\n        def pfunc(var_arg):\n            args = fix_args[0:index] + (var_arg,) + fix_args[index:]\n            return self._eval(*args)\n\n        return Operator(\n            self.input_shapes[index],\n            output_shape=self.output_shape,\n            eval_fn=pfunc,\n            input_dtype=self.input_dtypes[index],\n            output_dtype=self.output_dtype,\n            jit=self.jit,\n        )\n\n    def join(self) -> Operator:\n        \"\"\"Combine inputs into a :class:`.BlockArray`.\n\n        Construct an equivalent :class:`.Operator` taking a single\n        :class:`.BlockArray` input combining all inputs of this\n        :class:`Function`.\n\n        Returns:\n           An :class:`.Operator` taking a :class:`.BlockArray` as its\n           input.\n        \"\"\"\n        for dtype in self.input_dtypes[1:]:\n            if dtype != self.input_dtypes[0]:\n                raise ValueError(\n                    \"The join method may only be applied to Functions that have \"\n                    \"homogeneous input dtypes.\"\n                )\n\n        def 
jfunc(blkarr):\n            return self._eval(*blkarr.arrays)\n\n        return Operator(\n            self.input_shapes,  # type: ignore\n            output_shape=self.output_shape,\n            eval_fn=jfunc,\n            input_dtype=self.input_dtypes[0],\n            output_dtype=self.output_dtype,\n            jit=self.jit,\n        )\n\n    def jvp(\n        self, index: int, v: Union[Array, BlockArray], *args: Union[Array, BlockArray]\n    ) -> Tuple[Union[Array, BlockArray], Union[Array, BlockArray]]:\n        \"\"\"Jacobian-vector product with respect to a single parameter.\n\n        Compute a Jacobian-vector product with respect to a single\n        parameter of a :class:`Function`. Note that the order of the\n        parameters specifying where to evaluate the Jacobian and the\n        vector in the product is reverse with respect to :func:`jax.jvp`.\n\n        Args:\n           index: Index of parameter with respect to which the Jacobian\n              is to be computed.\n           v: Vector against which the Jacobian-vector product is to be\n              computed.\n           *args: Values of function parameters at which Jacobian is to\n              be computed.\n\n        Returns:\n           A pair consisting of the operator evaluated at the parameters\n           specified by `*args` and the Jacobian-vector product.\n        \"\"\"\n        var_arg = args[index]\n        fix_args = args[0:index] + args[(index + 1) :]\n        F = self.slice(index, *fix_args)\n        return F.jvp(var_arg, v)\n\n    def vjp(\n        self, index: int, *args: Union[Array, BlockArray], conjugate: Optional[bool] = True\n    ) -> Tuple[Tuple[Any, ...], Callable]:\n        \"\"\"Vector-Jacobian product with respect to a single parameter.\n\n        Compute a vector-Jacobian product with respect to a single\n        parameter of a :class:`Function`.\n\n        Args:\n           index: Index of parameter with respect to which the Jacobian\n              is to be 
computed.\n           *args: Values of function parameters at which Jacobian is to\n              be computed.\n           conjugate: If ``True``, compute the product using the\n               conjugate (Hermitian) transpose.\n\n        Returns:\n           A pair consisting of the operator evaluated at the parameters\n           specified by `*args` and a function that computes the\n           vector-Jacobian product.\n        \"\"\"\n        var_arg = args[index]\n        fix_args = args[0:index] + args[(index + 1) :]\n        F = self.slice(index, *fix_args)\n        return F.vjp(var_arg, conjugate=conjugate)\n\n    def jacobian(\n        self, index: int, *args: Union[Array, BlockArray], include_eval: Optional[bool] = False\n    ) -> LinearOperator:\n        \"\"\"Construct Jacobian linear operator for the function.\n\n        Construct a Jacobian :class:`.LinearOperator` that computes\n        vector products with the Jacobian with respect to a specified\n        variable of the function.\n\n        Args:\n           index: Index of parameter with respect to which the Jacobian\n              is to be computed.\n           *args: Values of function parameters at which Jacobian is to\n              be computed.\n           include_eval: Flag indicating whether the result of evaluating\n              the :class:`.Operator` should be included (as the first\n              component of a :class:`.BlockArray`) in the output of the\n              Jacobian :class:`.LinearOperator` constructed by this\n              function.\n\n        Returns:\n           A :class:`.LinearOperator` capable of computing Jacobian-vector\n           products.\n        \"\"\"\n        var_arg = args[index]\n        fix_args = args[0:index] + args[(index + 1) :]\n        F = self.slice(index, *fix_args)\n        return jacobian(F, var_arg, include_eval=include_eval)\n"
  },
  {
    "path": "scico/functional/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionals and functionals classes.\"\"\"\n\nimport sys\n\n# isort: off\nfrom ._functional import (\n    Functional,\n    FunctionalSum,\n    ComposedFunctional,\n    ScaledFunctional,\n    SeparableFunctional,\n    ZeroFunctional,\n)\nfrom ._norm import (\n    HuberNorm,\n    L0Norm,\n    L1Norm,\n    SquaredL2Norm,\n    L2Norm,\n    L21Norm,\n    NuclearNorm,\n    L1MinusL2Norm,\n)\nfrom ._tvnorm import AnisotropicTVNorm, IsotropicTVNorm, TVNorm\nfrom ._proxavg import ProximalAverage\nfrom ._indicator import NonNegativeIndicator, L2BallIndicator, BoxIndicator\nfrom ._denoiser import BM3D, BM4D, DnCNN\nfrom ._dist import SetDistance, SquaredSetDistance\n\n__all__ = [\n    \"AnisotropicTVNorm\",\n    \"IsotropicTVNorm\",\n    \"TVNorm\",\n    \"Functional\",\n    \"FunctionalSum\",\n    \"ComposedFunctional\",\n    \"ScaledFunctional\",\n    \"SeparableFunctional\",\n    \"ZeroFunctional\",\n    \"HuberNorm\",\n    \"L0Norm\",\n    \"L1Norm\",\n    \"SquaredL2Norm\",\n    \"L2Norm\",\n    \"L21Norm\",\n    \"L1MinusL2Norm\",\n    \"NonNegativeIndicator\",\n    \"BoxIndicator\",\n    \"NuclearNorm\",\n    \"L2BallIndicator\",\n    \"ProximalAverage\",\n    \"SetDistance\",\n    \"SquaredSetDistance\",\n    \"BM3D\",\n    \"BM4D\",\n    \"DnCNN\",\n]\n\n# Imported items in __all__ appear to originate in top-level functional module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/functional/_denoiser.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Pseudo-functionals that have denoisers as their proximal operators.\"\"\"\n\nfrom typing import Union\n\nfrom scico import denoiser\nfrom scico.numpy import Array\n\nfrom ._functional import Functional\n\n\nclass BM3D(Functional):\n    r\"\"\"Pseudo-functional whose prox applies the BM3D denoising algorithm.\n\n    A pseudo-functional that has the BM3D algorithm\n    :cite:`dabov-2008-image` as its proximal operator, which calls\n    :func:`.denoiser.bm3d`. Since this function provides an interface\n    to compiled C code, JAX features such as automatic differentiation\n    and support for GPU devices are not available.\n    \"\"\"\n\n    has_eval = False\n    has_prox = True\n\n    def __init__(self, is_rgb: bool = False, profile: Union[denoiser.BM3DProfile, str] = \"np\"):\n        r\"\"\"Initialize a :class:`BM3D` object.\n\n        Args:\n            is_rgb: Flag indicating use of BM3D with a color transform.\n                    Default: ``False``.\n            profile: Parameter configuration for BM3D.\n        \"\"\"\n\n        self.is_rgb = is_rgb\n        self.profile = profile\n        super().__init__()\n\n    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore\n        r\"\"\"Apply BM3D denoiser.\n\n        Args:\n            x: Input image.\n            lam: Noise parameter.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Denoised output.\n        \"\"\"\n        return denoiser.bm3d(x, lam, self.is_rgb, profile=self.profile)\n\n\nclass BM4D(Functional):\n    r\"\"\"Pseudo-functional whose prox applies the BM4D denoising algorithm.\n\n    A 
pseudo-functional that has the BM4D algorithm\n    :cite:`maggioni-2012-nonlocal` as its proximal operator, which calls\n    :func:`.denoiser.bm4d`. Since this function provides an interface\n    to compiled C code, JAX features such as automatic differentiation\n    and support for GPU devices are not available.\n    \"\"\"\n\n    has_eval = False\n    has_prox = True\n\n    def __init__(self, profile: Union[denoiser.BM4DProfile, str] = \"np\"):\n        r\"\"\"Initialize a :class:`BM4D` object.\n\n        Args:\n            profile: Parameter configuration for BM4D.\n        \"\"\"\n        self.profile = profile\n        super().__init__()\n\n    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore\n        r\"\"\"Apply BM4D denoiser.\n\n        Args:\n            x: Input image.\n            lam: Noise parameter.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Denoised output.\n        \"\"\"\n        return denoiser.bm4d(x, lam, profile=self.profile)\n\n\nclass DnCNN(Functional):\n    \"\"\"Pseudo-functional whose prox applies the DnCNN denoising algorithm.\n\n    A pseudo-functional that has the DnCNN algorithm\n    :cite:`zhang-2017-dncnn` as its proximal operator, implemented via\n    :class:`.denoiser.DnCNN`.\n    \"\"\"\n\n    has_eval = False\n    has_prox = True\n\n    def __init__(self, variant: str = \"6M\"):\n        \"\"\"\n        Args:\n            variant: Identify the DnCNN model to be used. 
See\n               :class:`.denoiser.DnCNN` for valid values.\n        \"\"\"\n        self.dncnn = denoiser.DnCNN(variant)\n        if self.dncnn.is_blind:\n\n            def denoise(x, sigma):\n                return self.dncnn(x)\n\n        else:\n\n            def denoise(x, sigma):\n                return self.dncnn(x, sigma)\n\n        self._denoise = denoise\n\n    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore\n        r\"\"\"Apply DnCNN denoiser.\n\n        *Warning*: The `lam` parameter is ignored, and has no effect on\n        the output for :class:`.DnCNN` objects initialized with\n        :code:`variant` parameter values other than `6N` and `17N`.\n\n        Args:\n            x: Input array.\n            lam: Noise parameter (ignored).\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Denoised output.\n        \"\"\"\n        return self._denoise(x, lam)\n"
  },
  {
    "path": "scico/functional/_dist.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Distance functions.\"\"\"\n\nfrom typing import Callable, Union\n\nfrom scico import numpy as snp\nfrom scico.numpy import Array, BlockArray\n\nfrom ._functional import Functional\n\n\nclass SetDistance(Functional):\n    r\"\"\"Distance to a closed convex set.\n\n    This functional computes the :math:`\\ell_2` distance from a vector to\n    a closed convex set :math:`C`\n\n    .. math::\n        d(\\mb{x}) = \\min_{\\mb{y} \\in C} \\, \\| \\mb{x} - \\mb{y} \\|_2 \\;.\n\n    The set is not specified directly, but in terms of a function\n    computing the projection into that set, i.e.\n\n\n    .. math::\n        d(\\mb{x}) = \\| \\mb{x} - P_C(\\mb{x}) \\|_2 \\;,\n\n    where :math:`P_C(\\mb{x})` is the projection of :math:`\\mb{x}` into\n    set :math:`C`.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, proj: Callable, args=()):\n        r\"\"\"\n        Args:\n            proj: Function computing the projection into the convex set.\n            args: Additional arguments for function `proj`.\n        \"\"\"\n        self.proj = proj\n        self.args = args\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` distance to the set.\n\n        Compute the distance :math:`d(\\mb{x})` between :math:`\\mb{x}` and\n        the set :math:`C`.\n\n        Args:\n            x: Input array :math:`\\mb{x}`.\n\n        Returns:\n            Euclidean distance from `x` to the projection of `x`.\n        \"\"\"\n        y = self.proj(*((x,) + self.args))\n        return snp.linalg.norm(x - y)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> 
Union[Array, BlockArray]:\n        r\"\"\"Proximal operator of the :math:`\\ell_2` distance function.\n\n        Compute the proximal operator of the :math:`\\ell_2` distance\n        function :math:`d(\\mb{x})` :cite:`beck-2017-first` (Lemma 6.43).\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Scaled proximal operator evaluated at `v`.\n        \"\"\"\n        y = self.proj(*((v,) + self.args))\n        d = snp.linalg.norm(v - y)\n        𝜃 = lam / d if d >= lam else 1.0\n        return 𝜃 * y + (1.0 - 𝜃) * v\n\n\nclass SquaredSetDistance(Functional):\n    r\"\"\"Squared :math:`\\ell_2` distance to a closed convex set.\n\n    This functional computes the :math:`\\ell_2` distance from a vector to\n    a closed convex set :math:`C`\n\n    .. math::\n        d(\\mb{x}) = \\min_{\\mb{y} \\in C} \\, (1/2) \\| \\mb{x} - \\mb{y} \\|_2^2\n        \\;.\n\n    The set is not specified directly, but in terms of a function\n    computing the projection into that set, i.e.\n\n\n    .. 
math::\n        d(\\mb{x}) = (1/2) \\| \\mb{x} - P_C(\\mb{x}) \\|_2^2 \\;,\n\n    where :math:`P_C(\\mb{x})` is the projection of :math:`\\mb{x}` into\n    set :math:`C`.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, proj: Callable, args=()):\n        r\"\"\"\n        Args:\n            proj: Function computing the projection into the convex set.\n            args: Additional arguments for function `proj`.\n        \"\"\"\n        self.proj = proj\n        self.args = args\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        r\"\"\"Compute the squared :math:`\\ell_2` distance to the set.\n\n        Compute the distance :math:`d(\\mb{x})` between :math:`\\mb{x}` and\n        the set :math:`C`.\n\n        Args:\n            x: Input array :math:`\\mb{x}`.\n\n        Returns:\n            Squared :math:`\\ell_2` distance from `x` to the projection of `x`.\n        \"\"\"\n        y = self.proj(*((x,) + self.args))\n        return 0.5 * snp.linalg.norm(x - y) ** 2\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Proximal operator of the squared :math:`\\ell_2` distance function.\n\n        Compute the proximal operator of the squared :math:`\\ell_2` distance\n        function :math:`d(\\mb{x})` :cite:`beck-2017-first` (Example 6.65).\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Scaled proximal operator evaluated at `v`.\n        \"\"\"\n        y = self.proj(*((v,) + self.args))\n        𝛼 = 1.0 / (1.0 + lam)\n        return 𝛼 * v + lam * 𝛼 * y\n"
  },
  {
    "path": "scico/functional/_functional.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functional base class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import List, Optional, Union\n\nimport scico\nfrom scico import numpy as snp\nfrom scico.linop import LinearOperator\nfrom scico.numpy import Array, BlockArray\n\n\nclass Functional:\n    r\"\"\"Base class for functionals.\n\n    A functional maps an :code:`array-like` to a scalar; abstractly, a\n    functional is a mapping from :math:`\\mathbb{R}^n` or\n    :math:`\\mathbb{C}^n` to :math:`\\mathbb{R}`.\n    \"\"\"\n\n    #: True if this functional can be evaluated, False otherwise.\n    #: This attribute must be overridden and set to True or False in any derived classes.\n    has_eval: Optional[bool] = None\n\n    #: True if this functional has the prox method, False otherwise.\n    #: This attribute must be overridden and set to True or False in any derived classes.\n    has_prox: Optional[bool] = None\n\n    def __init__(self):\n        self._grad = scico.grad(self.__call__)\n\n    def __str__(self):\n        return f\"\"\"{self.__module__}.{self.__class__.__qualname__}\"\"\"\n\n    def __repr__(self):\n        return self.__str__() + f\"\"\"\\n  has_eval: {self.has_eval}\\n  has_prox: {self.has_prox}\\n\"\"\"\n\n    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:\n        if snp.util.is_scalar_equiv(other):\n            return ScaledFunctional(self, other)\n        return NotImplemented\n\n    def __rmul__(self, other: Union[float, int]) -> ScaledFunctional:\n        return self.__mul__(other)\n\n    def __add__(self, other: Functional) -> 
FunctionalSum:\n        if isinstance(other, Functional):\n            return FunctionalSum(self, other)\n        return NotImplemented\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        r\"\"\"Evaluate this functional at point :math:`\\mb{x}`.\n\n        Args:\n           x: Point at which to evaluate this functional.\n\n        Returns:\n           Result of evaluating the functional at `x`.\n        \"\"\"\n        # Functionals that can be evaluated should override this method.\n        raise NotImplementedError(f\"Functional {type(self)} cannot be evaluated.\")\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Scaled proximal operator of functional.\n\n        Evaluate scaled proximal operator of this functional, with\n        scaling :math:`\\lambda` = `lam` and evaluated at point\n        :math:`\\mb{v}` = `v`. The scaled proximal operator is defined as\n\n        .. math::\n           \\prox_{\\lambda f}(\\mb{v}) = \\argmin_{\\mb{x}}\n           \\lambda f(\\mb{x}) +\n           \\frac{1}{2} \\norm{\\mb{v} - \\mb{x}}_2^2\\;,\n\n        where :math:`\\lambda f(\\mb{x})` represents this functional evaluated at\n        :math:`\\mb{x}` multiplied by :math:`\\lambda`.\n\n        Args:\n            v: Point at which to evaluate prox function.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes. 
These include `x0`, an initial guess for the\n                minimizer in the definition of :math:`\\prox`.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        # Functionals that have a prox should override this method.\n        raise NotImplementedError(f\"Functional {type(self)} does not have a prox.\")\n\n    def conj_prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Scaled proximal operator of convex conjugate of functional.\n\n        Evaluate scaled proximal operator of convex conjugate (Fenchel\n        conjugate) of this functional, with scaling\n        :math:`\\lambda` = `lam`, and evaluated at point\n        :math:`\\mb{v}` = `v`. Denoting this functional by :math:`f` and\n        its convex conjugate by :math:`f^*`, the proximal operator of\n        :math:`f^*` is computed as follows by exploiting the extended\n        Moreau decomposition (see Sec. 6.6 of :cite:`beck-2017-first`)\n\n        .. 
math::\n           \\prox_{\\lambda f^*}(\\mb{v}) = \\mb{v} - \\lambda \\,\n           \\prox_{\\lambda^{-1} f}(\\mb{v / \\lambda}) \\;.\n\n        Args:\n            v: Point at which to evaluate prox function.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional keyword args, passed directly to\n               `self.prox`.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs)\n\n    def grad(self, x: Union[Array, BlockArray]):\n        r\"\"\"Evaluate the gradient of this functional at :math:`\\mb{x}`.\n\n        Args:\n            x: Point at which to evaluate gradient.\n\n        Returns:\n            The gradient at `x`.\n        \"\"\"\n        return self._grad(x)\n\n\nclass ScaledFunctional(Functional):\n    r\"\"\"A functional multiplied by a scalar.\"\"\"\n\n    def __init__(self, functional: Functional, scale: float):\n        self.functional = functional\n        self.scale = scale\n        self.has_eval = functional.has_eval\n        self.has_prox = functional.has_prox\n        super().__init__()\n\n    def __repr__(self):\n        return (\n            f\"\"\"{Functional.__repr__(self)}\"\"\"\n            f\"\"\"  functional: {Functional.__str__(self.functional)}\\n\"\"\"\n            f\"\"\"  scale:      {self.scale}\\n\"\"\"\n        )\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self.scale * self.functional(x)\n\n    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:\n        if snp.util.is_scalar_equiv(other):\n            return ScaledFunctional(self.functional, other * self.scale)\n        return NotImplemented\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate the scaled proximal operator of the scaled functional.\n\n        Note that, by 
definition, the scaled proximal operator of a\n        functional is the proximal operator of the scaled functional. The\n        scaled proximal operator of a scaled functional is the scaled\n        proximal operator of the unscaled functional with the proximal\n        operator scaling consisting of the product of the two scaling\n        factors, i.e., for functional :math:`f` and scaling factors\n        :math:`\\alpha` and :math:`\\beta`, the proximal operator with\n        scaling parameter :math:`\\alpha` of scaled functional\n        :math:`\\beta f` is the proximal operator with scaling parameter\n        :math:`\\alpha \\beta` of functional :math:`f`,\n\n        .. math::\n           \\prox_{\\alpha (\\beta f)}(\\mb{v}) =\n           \\prox_{(\\alpha \\beta) f}(\\mb{v}) \\;.\n\n\n        Args:\n            v: Point at which to evaluate prox function.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes. These include `x0`, an initial guess for the\n                minimizer in the definition of :math:`\\prox`.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return self.functional.prox(v, lam * self.scale, **kwargs)\n\n\nclass SeparableFunctional(Functional):\n    r\"\"\"A functional that is separable in its arguments.\n\n    A separable functional :math:`f : \\mathbb{C}^N \\to \\mathbb{R}` can\n    be written as the sum of functionals :math:`f_i : \\mathbb{C}^{N_i}\n    \\to \\mathbb{R}` with :math:`\\sum_i N_i = N`. In particular,\n\n    .. 
math::\n       f(\\mb{x}) = f(\\mb{x}_1, \\dots, \\mb{x}_N) = f_1(\\mb{x}_1) + \\dots\n       + f_N(\\mb{x}_N) \\;.\n\n    A :class:`SeparableFunctional` accepts a :class:`.BlockArray` and is\n    separable in the block components.\n    \"\"\"\n\n    def __init__(self, functional_list: List[Functional]):\n        r\"\"\"\n        Args:\n            functional_list: List of component functionals f_i. This\n               functional takes as an input a :class:`.BlockArray` with\n               `num_blocks == len(functional_list)`.\n        \"\"\"\n        self.functional_list: List[Functional] = functional_list\n        self.has_eval: bool = all(fi.has_eval for fi in functional_list)\n        self.has_prox: bool = all(fi.has_prox for fi in functional_list)\n        super().__init__()\n\n    def __repr__(self):\n        return (\n            Functional.__repr__(self)\n            + \"  components: \"\n            + \", \".join([str(f) for f in self.functional_list])\n            + \"\\n\"\n        )\n\n    def __call__(self, x: BlockArray) -> float:\n        if len(x.shape) == len(self.functional_list):\n            return snp.sum(snp.array([fi(xi) for fi, xi in zip(self.functional_list, x)]))\n        raise ValueError(\n            f\"Number of blocks in x, {len(x.shape)}, and length of functional_list, \"\n            f\"{len(self.functional_list)}, do not match.\"\n        )\n\n    def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:\n        r\"\"\"Evaluate proximal operator of the separable functional.\n\n        Evaluate proximal operator of the separable functional (see\n        Theorem 6.6 of :cite:`beck-2017-first`).\n\n          .. 
math::\n             \\prox_{\\lambda f}(\\mb{v})\n             =\n             \\begin{bmatrix}\n               \\prox_{\\lambda f_1}(\\mb{v}_1) \\\\ \\vdots \\\\\n               \\prox_{\\lambda f_N}(\\mb{v}_N) \\\\\n             \\end{bmatrix} \\;.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        if len(v.shape) == len(self.functional_list):\n            return snp.blockarray(\n                [fi.prox(vi, lam, **kwargs) for fi, vi in zip(self.functional_list, v)]\n            )\n        raise ValueError(\n            f\"Number of blocks in v, {len(v.shape)}, and length of functional_list, \"\n            f\"{len(self.functional_list)}, do not match.\"\n        )\n\n\nclass ComposedFunctional(Functional):\n    r\"\"\"A functional constructed by composition.\n\n    A functional constructed by composition of a functional with an\n    orthogonal linear operator, i.e.\n\n    .. math::\n       f(\\mb{x}) = g(A \\mb{x})\n\n    where :math:`f` is the composed functional, :math:`g` is the\n    functional from which it is composed, and :math:`A` is an orthogonal\n    linear operator. Note that the resulting :class:`Functional` can only\n    be applied (either via evaluation or :meth:`prox` calls) to inputs\n    of shape and dtype corresponding to the input specification of the\n    linear operator.\n    \"\"\"\n\n    def __init__(self, functional: Functional, linop: LinearOperator):\n        r\"\"\"\n        Args:\n            functional: The functional :math:`g` to be composed.\n            linop: The linear operator :math:`A` to be composed. Note\n              that it is the user's responsibility to confirm that\n              the linear operator is orthogonal. 
If it is not, the\n              result of :meth:`prox` will be incorrect.\n        \"\"\"\n        self.functional = functional\n        self.linop = linop\n        self.has_eval = functional.has_eval\n        self.has_prox = functional.has_prox\n        super().__init__()\n\n    def __repr__(self):\n        return (\n            Functional.__repr__(self)\n            + f\"\"\"  composition of: {self.functional.__str__()} and {self.linop.__str__()}\\n\"\"\"\n        )\n\n    def __call__(self, x: BlockArray) -> float:\n        return self.functional(self.linop(x))\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of a composed functional.\n\n        Evaluate proximal operator :math:`f(\\mb{x}) = g(A \\mb{x})`, where\n        :math:`A` is an orthogonal linear operator, via a special case of\n        Theorem 6.15 of :cite:`beck-2017-first`\n\n        .. math::\n           \\prox_{\\lambda f}(\\mb{v}) = A^T \\prox_{\\lambda g}(A \\mb{v}) \\;.\n\n        Examples of orthogonal linear operator in SCICO include\n        :class:`.linop.Reshape` and :class:`.linop.Transpose`.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return self.linop.H(self.functional.prox(self.linop(v), lam=lam, **kwargs))\n\n\nclass FunctionalSum(Functional):\n    r\"\"\"A sum of two functionals.\"\"\"\n\n    def __init__(self, functional1: Functional, functional2: Functional):\n        self.functional1 = functional1\n        self.functional2 = functional2\n        self.has_eval = functional1.has_eval and functional2.has_eval\n        self.has_prox = False\n        super().__init__()\n\n    def 
__repr__(self):\n        return (\n            Functional.__repr__(self)\n            + f\"\"\"  sum of functionals: {Functional.__str__(self.functional1)} and \"\"\"\n            + f\"\"\"{Functional.__str__(self.functional2)}\\n\"\"\"\n        )\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self.functional1(x) + self.functional2(x)\n\n\nclass ZeroFunctional(Functional):\n    r\"\"\"Zero functional, :math:`f(\\mb{x}) = 0 \\in \\mbb{R}` for any input.\"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return 0.0\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        return v\n"
  },
  {
    "path": "scico/functional/_indicator.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionals that are indicator functions/constraints.\"\"\"\n\nfrom typing import Union\n\nimport jax\n\nfrom scico import numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.linalg import norm\n\nfrom ._functional import Functional\n\n\nclass NonNegativeIndicator(Functional):\n    r\"\"\"Indicator function for non-negative orthant.\n\n    Returns 0 if all elements of input array-like are non-negative, and\n    `inf` otherwise\n\n    .. math::\n        I(\\mb{x}) = \\begin{cases}\n        0  & \\text{ if } x_i \\geq 0 \\; \\forall i \\\\\n        \\infty  & \\text{ otherwise} \\;.\n        \\end{cases}\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        if snp.util.is_complex_dtype(x.dtype):\n            raise ValueError(\"Not defined for complex input.\")\n\n        # Equivalent to snp.inf if snp.any(x < 0) else 0.0\n        return jax.lax.cond(snp.any(x < 0), lambda x: snp.inf, lambda x: 0.0, None)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"The scaled proximal operator of the non-negative indicator.\n\n        Evaluate the scaled proximal operator of the indicator over\n        the non-negative orthant, :math:`I`,\n\n        .. 
math::\n            [\\mathrm{prox}_{\\lambda I}(\\mb{v})]_i =\n            \\begin{cases}\n            v_i\\, & \\text{ if } v_i \\geq 0 \\\\\n            0\\, & \\text{ otherwise} \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda` (has no effect).\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return snp.maximum(v, 0)\n\n\nclass L2BallIndicator(Functional):\n    r\"\"\"Indicator function for :math:`\\ell_2` ball of given radius.\n\n    Indicator function for :math:`\\ell_2` ball of given radius, :math:`r`\n\n    .. math::\n        I(\\mb{x}) = \\begin{cases}\n        0  & \\text{ if } \\norm{\\mb{x}}_2 \\leq r \\\\\n        \\infty  & \\text{ otherwise} \\;.\n        \\end{cases}\n\n    Attributes:\n        radius: Radius of :math:`\\ell_2` ball.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, radius: float = 1.0):\n        r\"\"\"Initialize a :class:`L2BallIndicator` object.\n\n        Args:\n            radius: Radius of :math:`\\ell_2` ball. Default: 1.0.\n        \"\"\"\n        self.radius = radius\n        super().__init__()\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        # Equivalent to: snp.inf if norm(x) > self.radius else 0.0\n        return jax.lax.cond(norm(x) > self.radius, lambda x: snp.inf, lambda x: 0.0, None)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"The scaled proximal operator of the :math:`\\ell_2` ball indicator.\n\n        Evaluate the scaled proximal operator of the indicator, :math:`I`,\n        of the :math:`\\ell_2` ball with radius :math:`r`\n\n        .. 
math::\n            \\mathrm{prox}_{\\lambda I}(\\mb{v}) = \\begin{cases}\n            \\mb{v}  & \\text{ if } \\norm{\\mb{v}}_2 \\leq r \\\\\n            r \\frac{\\mb{v}}{\\norm{\\mb{v}}_2}  & \\text{ otherwise} \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda` (has no effect).\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return jax.lax.cond(\n            norm(v) > self.radius, lambda v: self.radius * v / norm(v), lambda v: v, v\n        )\n\n\nclass BoxIndicator(Functional):\n    r\"\"\"Box indicator function..\n\n    Indicator function of the constraint set :math:`a \\leq x \\leq b` for\n    lower and upper bounds :math:`a` and :math:`b` respectively.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, lb: float = 0.0, ub: float = 1.0):\n        r\"\"\"Initialize a :class:`BoxIndicator` object.\n\n        Args:\n            lb: Lower bound.\n            ub: Upper bound.\n        \"\"\"\n        self.lb = lb\n        self.ub = ub\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        if snp.util.is_complex_dtype(x.dtype):\n            raise ValueError(\"Not defined for complex input.\")\n        constr = snp.logical_and(self.lb <= x, x <= self.ub)\n        return jax.lax.cond(snp.all(constr), lambda x: 0.0, lambda x: snp.inf, None)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"The scaled proximal operator of the box indicator.\n\n        Evaluate the scaled proximal operator of the constraint set\n        :math:`a \\leq x \\leq b` for lower and upper bounds :math:`a` and\n        :math:`b` respectively.\n\n        Args:\n            v: Input array 
:math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda` (has no effect).\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return snp.clip(v, self.lb, self.ub)\n"
  },
  {
    "path": "scico/functional/_norm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functionals that are norms.\"\"\"\n\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\n\nfrom jax import jit, lax\n\nfrom scico import numpy as snp\nfrom scico.numpy import Array, BlockArray, count_nonzero\nfrom scico.numpy.linalg import norm\nfrom scico.numpy.util import no_nan_divide\n\nfrom ._functional import Functional\n\n\nclass L0Norm(Functional):\n    r\"\"\"The :math:`\\ell_0` 'norm'.\n\n    The :math:`\\ell_0` 'norm' counts the number of non-zero elements in\n    an array.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    @staticmethod\n    @jit\n    def __call__(x: Union[Array, BlockArray]) -> float:\n        return count_nonzero(x)\n\n    @staticmethod\n    @jit\n    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate scaled proximal operator of :math:`\\ell_0` norm.\n\n        Evaluate scaled proximal operator of :math:`\\ell_0` norm using\n\n        .. 
math::\n\n            \\left[ \\prox_{\\lambda\\| \\cdot \\|_0}(\\mb{v}) \\right]_i =\n            \\begin{cases}\n            v_i  & \\text{ if } \\abs{v_i} \\geq \\lambda \\\\\n            0  & \\text{ otherwise } \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Thresholding parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return snp.where(snp.abs(v) >= lam, v, 0)\n\n\nclass L1Norm(Functional):\n    r\"\"\"The :math:`\\ell_1` norm.\n\n    Computes\n\n    .. math::\n       \\norm{\\mb{x}}_1 = \\sum_i \\abs{x_i}^2 \\;.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    @staticmethod\n    @jit\n    def __call__(x: Union[Array, BlockArray]) -> float:\n        return snp.sum(snp.abs(x))\n\n    @staticmethod\n    @jit\n    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Array:\n        r\"\"\"Evaluate scaled proximal operator of :math:`\\ell_1` norm.\n\n        Evaluate scaled proximal operator of :math:`\\ell_1` norm using\n\n        .. math::\n            \\left[ \\prox_{\\lambda \\|\\cdot\\|_1}(\\mb{v}) \\right]_i =\n            \\sign(v_i) (\\abs{v_i} - \\lambda)_+ \\;,\n\n        where\n\n        .. 
math::\n            (x)_+ = \\begin{cases}\n            x  & \\text{ if } x \\geq 0 \\\\\n            0  & \\text{ otherwise} \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Thresholding parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        tmp = snp.abs(v) - lam\n        tmp = 0.5 * (tmp + snp.abs(tmp))\n        if snp.util.is_complex_dtype(v.dtype):\n            out = snp.exp(1j * snp.angle(v)) * tmp\n        else:\n            out = snp.sign(v) * tmp\n        return out\n\n\nclass SquaredL2Norm(Functional):\n    r\"\"\"The squared :math:`\\ell_2` norm.\n\n    Squared :math:`\\ell_2` norm\n\n    .. math::\n       \\norm{\\mb{x}}^2_2 = \\sum_i \\abs{x_i}^2 \\;.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    @staticmethod\n    @jit\n    def __call__(x: Union[Array, BlockArray]) -> float:\n        # Directly implement the squared l2 norm to avoid nondifferentiable\n        # behavior of snp.norm(x) at 0.\n        return snp.sum(snp.abs(x) ** 2)\n\n    @staticmethod\n    @jit\n    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of squared :math:`\\ell_2` norm.\n\n        Evaluate proximal operator of squared :math:`\\ell_2` norm using\n\n        .. 
math::\n            \\prox_{\\lambda \\| \\cdot \\|_2^2}(\\mb{v})\n            = \\frac{\\mb{v}}{1 + 2 \\lambda} \\;.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return v / (1.0 + 2.0 * lam)\n\n\nclass L2Norm(Functional):\n    r\"\"\"The :math:`\\ell_2` norm.\n\n    .. math::\n       \\norm{\\mb{x}}_2 = \\sqrt{\\sum_i \\abs{x_i}^2} \\;.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    @staticmethod\n    @jit\n    def __call__(x: Union[Array, BlockArray]) -> float:\n        return norm(x)\n\n    @staticmethod\n    @jit\n    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of :math:`\\ell_2` norm.\n\n        Evaluate proximal operator of :math:`\\ell_2` norm using\n\n        .. math::\n            \\prox_{\\lambda \\| \\cdot \\|_2}(\\mb{v}) = \\mb{v} \\,\n            \\left(1 - \\frac{\\lambda}{\\norm{\\mb{v}}_2} \\right)_+ \\;,\n\n        where\n\n        .. 
math::\n            (x)_+ = \\begin{cases}\n            x  & \\text{ if } x \\geq 0 \\\\\n            0  & \\text{ otherwise} \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        norm_v = norm(v)\n        return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)\n\n\nclass L21Norm(Functional):\n    r\"\"\"The :math:`\\ell_{2,1}` norm.\n\n    For a :math:`M \\times N` matrix, :math:`\\mb{A}`, by default,\n\n    .. math::\n           \\norm{\\mb{A}}_{2,1} = \\sum_{n=1}^N \\sqrt{\\sum_{m=1}^M\n           \\abs{A_{m,n}}^2} \\;.\n\n    The norm generalizes to more dimensions by first computing the\n    :math:`\\ell_2` norm along one or more (user-specified) axes,\n    followed by a sum over all remaining axes. :class:`.BlockArray` inputs\n    require parameter `l2_axis` to be  ``None``, in which case the\n    :math:`\\ell_2` norm is computed over each block.\n\n    A typical use case is computing the isotropic total variation norm.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, l2_axis: Union[None, int, Tuple] = 0):\n        r\"\"\"\n        Args:\n            l2_axis: Axis/axes over which to take the l2 norm. 
Required\n               to be ``None`` for :class:`.BlockArray` inputs to be\n               supported.\n        \"\"\"\n        self.l2_axis = l2_axis\n\n    @staticmethod\n    @partial(jit, static_argnames=(\"axis\", \"keepdims\"))\n    def _l2norm(\n        x: Union[Array, BlockArray], axis: Union[None, int, Tuple], keepdims: Optional[bool] = False\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Return the :math:`\\ell_2` norm of an array.\"\"\"\n        return snp.sqrt((snp.abs(x) ** 2).sum(axis=axis, keepdims=keepdims))\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        if isinstance(x, snp.BlockArray) and self.l2_axis is not None:\n            raise ValueError(\"Initializer argument 'l2_axis' must be None for BlockArray input.\")\n        l2 = L21Norm._l2norm(x, axis=self.l2_axis)\n        return snp.sum(snp.abs(l2))\n\n    @staticmethod\n    @partial(jit, static_argnames=(\"axis\"))\n    def _prox(\n        v: Union[Array, BlockArray], lam: float, axis: Union[None, int, Tuple]\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of the :math:`\\ell_{2,1}` norm.\"\"\"\n        length = L21Norm._l2norm(v, axis=axis, keepdims=True)\n        direction = no_nan_divide(v, length)\n        new_length = length - lam\n        # set negative values to zero without `if`\n        new_length = 0.5 * (new_length + snp.abs(new_length))\n        return new_length * direction\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of the :math:`\\ell_{2,1}` norm.\n\n        In two dimensions,\n\n        .. math::\n            \\prox_{\\lambda \\|\\cdot\\|_{2,1}}(\\mb{v}, \\lambda)_{:, n} =\n             \\frac{\\mb{v}_{:, n}}{\\|\\mb{v}_{:, n}\\|_2}\n             (\\|\\mb{v}_{:, n}\\|_2 - \\lambda)_+ \\;,\n\n        where\n\n        .. 
math::\n            (x)_+ = \\begin{cases}\n            x  & \\text{ if } x \\geq 0 \\\\\n            0  & \\text{ otherwise} \\;.\n            \\end{cases}\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        if isinstance(v, snp.BlockArray) and self.l2_axis is not None:\n            raise ValueError(\"Initializer argument 'l2_axis' must be None for BlockArray input.\")\n        return L21Norm._prox(v, lam=lam, axis=self.l2_axis)\n\n\nclass L1MinusL2Norm(Functional):\n    r\"\"\"Difference of :math:`\\ell_1` and :math:`\\ell_2` norms.\n\n    Difference of :math:`\\ell_1` and scaled :math:`\\ell_2` norms,\n\n    .. math::\n        \\norm{\\mb{x}}_1 - \\beta \\norm{\\mb{x}}_2 \\;.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, beta: float = 1.0):\n        r\"\"\"\n        Args:\n            beta: Parameter :math:`\\beta` in the norm definition.\n        \"\"\"\n        self.beta = beta\n\n    @staticmethod\n    @jit\n    def _l1minusl2norm(x: Union[Array, BlockArray], beta: float) -> float:\n        r\"\"\"Return the :math:`\\ell_1 - \\ell_2` norm of an array.\"\"\"\n        return snp.sum(snp.abs(x)) - beta * norm(x)\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return L1MinusL2Norm._l1minusl2norm(x, self.beta)\n\n    @staticmethod\n    def _prox_vamx_ge_thresh(v, va, vs, alpha, beta):\n        u = snp.zeros(v.shape, dtype=v.dtype)\n        idx = va.ravel().argmax()\n        u = (\n            u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])\n        ).reshape(v.shape)\n        return u\n\n    @staticmethod\n    def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta):\n        return snp.where(\n            
vamx < (1.0 - beta) * alpha,\n            snp.zeros(v.shape, dtype=v.dtype),\n            L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta),\n        )\n\n    @staticmethod\n    def _prox_vamx_gt_alpha(v, va, vs, alpha, beta):\n        u = snp.maximum(va - alpha, 0.0) * vs\n        l2u = norm(u)\n        u *= (l2u + alpha * beta) / l2u\n        return u\n\n    @staticmethod\n    def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta):\n        return snp.where(\n            vamx > alpha,\n            L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta),\n            L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta),\n        )\n\n    @staticmethod\n    @jit\n    def _prox(v: Union[Array, BlockArray], lam: float, beta: float) -> Union[Array, BlockArray]:\n        r\"\"\"Proximal operator of :math:`\\ell_1 - \\ell_2` norm.\"\"\"\n        alpha = lam\n        va = snp.abs(v)\n        vamx = snp.max(va)\n        if snp.util.is_complex_dtype(v.dtype):\n            vs = snp.exp(1j * snp.angle(v))\n        else:\n            vs = snp.sign(v)\n\n        return snp.where(\n            vamx > 0.0,\n            L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta),\n            snp.zeros(v.shape, dtype=v.dtype),\n        )\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Proximal operator of difference of :math:`\\ell_1` and :math:`\\ell_2` norms.\n\n        Evaluate the proximal operator of the difference of :math:`\\ell_1`\n        and :math:`\\ell_2` norms, i.e. :math:`\\alpha \\left( \\| \\mb{x}\n        \\|_1 - \\beta \\| \\mb{x} \\|_2 \\right)` :cite:`lou-2018-fast`. 
Note\n        that this is not a proximal operator according to the strict\n        definition since the loss function is non-convex.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return L1MinusL2Norm._prox(v, lam=lam, beta=self.beta)\n\n\nclass HuberNorm(Functional):\n    r\"\"\"Huber norm.\n\n    Compute a norm based on the Huber function :cite:`huber-1964-robust`\n    :cite:`beck-2017-first` (Sec. 6.7.1). In the non-separable case the\n    norm is\n\n    .. math::\n         H_{\\delta}(\\mb{x}) = \\begin{cases}\n         (1/2) \\norm{ \\mb{x} }_2^2  & \\text{ when } \\norm{ \\mb{x} }_2\n         \\leq \\delta \\\\\n         \\delta \\left( \\norm{ \\mb{x} }_2  - (\\delta / 2) \\right) &\n         \\text{ when } \\norm{ \\mb{x} }_2 > \\delta \\;,\n         \\end{cases}\n\n    where :math:`\\delta` is a parameter controlling the transitions\n    between :math:`\\ell_1`-norm like and :math:`\\ell_2`-norm like\n    behavior. In the separable case the norm is\n\n    .. math::\n         H_{\\delta}(\\mb{x}) = \\sum_i h_{\\delta}(x_i) \\,,\n\n    where\n\n    .. 
math::\n         h_{\\delta}(x) = \\begin{cases}\n         (1/2) \\abs{ x }^2  & \\text{ when } \\abs{ x } \\leq \\delta \\\\\n         \\delta \\left( \\abs{ x }  - (\\delta / 2) \\right) &\n         \\text{ when } \\abs{ x } > \\delta \\;.\n         \\end{cases}\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(self, delta: float = 1.0, separable: bool = True):\n        r\"\"\"\n        Args:\n            delta: Huber function parameter :math:`\\delta`.\n            separable: Flag indicating whether to compute separable or\n               non-separable form.\n        \"\"\"\n        self.delta = delta\n        self.separable = separable\n\n        if separable:\n            self._call = self._call_sep\n            self._prox = self._prox_sep\n        else:\n            self._call = self._call_nonsep\n            self._prox = self._prox_nonsep\n\n        super().__init__()\n\n    @staticmethod\n    @jit\n    def _call_sep(x: Union[Array, BlockArray], delta: float) -> float:\n        xabs = snp.abs(x)\n        hx = snp.where(xabs <= delta, 0.5 * xabs**2, delta * (xabs - (delta / 2.0)))\n        return snp.sum(hx)\n\n    @staticmethod\n    @jit\n    def _call_nonsep(x: Union[Array, BlockArray], delta: float) -> float:\n        xl2 = snp.linalg.norm(x)\n        return lax.cond(\n            xl2 <= delta, lambda xl2: 0.5 * xl2**2, lambda xl2: delta * (xl2 - delta / 2.0), xl2\n        )\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self._call(x, self.delta)\n\n    @staticmethod\n    @jit\n    def _prox_sep(\n        v: Union[Array, BlockArray], lam: float, delta: float\n    ) -> Union[Array, BlockArray]:\n        den = snp.maximum(snp.abs(v), delta * (1.0 + lam))\n        return (1.0 - ((delta * lam) / den)) * v\n\n    @staticmethod\n    @jit\n    def _prox_nonsep(\n        v: Union[Array, BlockArray], lam: float, delta: float\n    ) -> Union[Array, BlockArray]:\n        vl2 = snp.linalg.norm(v)\n        den 
= snp.maximum(vl2, delta * (1.0 + lam))\n        return (1.0 - ((delta * lam) / den)) * v\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of the Huber function.\n\n        Evaluate scaled proximal operator of the Huber function\n        :cite:`beck-2017-first` (Sec. 6.7.3). The prox is\n\n        .. math::\n             \\prox_{\\lambda H_{\\delta}} (\\mb{v}) = \\left( 1 -\n             \\frac{\\lambda \\delta} {\\max\\left\\{\\norm{\\mb{v}}_2,\n             \\delta + \\lambda \\delta\\right\\} } \\right) \\mb{v}\n\n        in the non-separable case, and\n\n        .. math::\n             \\left[ \\prox_{\\lambda H_{\\delta}} (\\mb{v}) \\right]_i =\n             \\left( 1 - \\frac{\\lambda \\delta} {\\max\\left\\{\\abs{v_i},\n             \\delta + \\lambda \\delta\\right\\} } \\right) v_i\n\n        in the separable case.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return self._prox(v, lam=lam, delta=self.delta)\n\n\nclass NuclearNorm(Functional):\n    r\"\"\"Nuclear norm.\n\n    Compute the nuclear norm\n\n    .. 
math::\n        \\| X \\|_* = \\sum_i \\sigma_i\n\n    where :math:`\\sigma_i` are the singular values of matrix :math:`X`.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    @staticmethod\n    @jit\n    def __call__(x: Union[Array, BlockArray]) -> float:\n        if x.ndim != 2:\n            raise ValueError(\"Input array must be two dimensional.\")\n        return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False))\n\n    @staticmethod\n    @jit\n    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:\n        r\"\"\"Evaluate proximal operator of the nuclear norm.\n\n        Evaluate proximal operator of the nuclear norm\n        :cite:`cai-2010-singular`.\n\n        Args:\n            v: Input array :math:`\\mb{v}`. Required to be two-dimensional.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        if v.ndim != 2:\n            raise ValueError(\"Input array must be two dimensional.\")\n        svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)\n        svdS = snp.maximum(0, svdS - lam)\n        return svdU @ snp.diag(svdS) @ svdV\n"
  },
  {
    "path": "scico/functional/_proxavg.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Implementation of the proximal average method.\"\"\"\n\nfrom typing import List, Optional, Union\n\nfrom scico.numpy import Array, BlockArray, isinf\n\nfrom ._functional import Functional\n\n\nclass ProximalAverage(Functional):\n    \"\"\"Weighted average of functionals.\n\n    A functional that is composed of a weighted average of functionals.\n    All of the component functionals are required to have proximal\n    operators. The proximal operator of the composite functional is\n    approximated via the proximal average method :cite:`yu-2013-better`,\n    which holds for small scaling parameters. This does not imply that it\n    can only be applied to problems requiring a small regularization\n    parameter since most proximal algorithms include an additional\n    algorithm parameter that also plays a role in the parameter of the\n    proximal operator. For example, in :class:`.PGM` and\n    :class:`.AcceleratedPGM`, the scaled proximal operator parameter\n    is the regularization parameter divided by the `L0` algorithm\n    parameter, and for :class:`.ADMM`, the scaled proximal operator\n    parameters are the regularization parameters divided by the entries\n    in the `rho_list` algorithm parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        func_list: List[Functional],\n        alpha_list: Optional[List[float]] = None,\n        no_inf_eval=True,\n    ):\n        \"\"\"\n        Args:\n            func_list: List of component :class:`.Functional` objects,\n                all of which must have a proximal operator.\n            alpha_list: List of scalar weights for each\n                :class:`.Functional`. 
If not specified, defaults to equal\n                weights. If specified, the list of weights must have the\n                same length as the :class:`.Functional` list. If the\n                weights do not sum to unity, they are scaled to ensure\n                that they do.\n            no_inf_eval: If ``True``, exclude infinite values (typically\n                associated with a functional that is an indicator\n                function) from the evaluation of the sum of component\n                functionals.\n        \"\"\"\n        self.has_prox = all([f.has_prox for f in func_list])\n        if not self.has_prox:\n            raise ValueError(\"All functionals in 'func_list' must have has_prox == True.\")\n        self.has_eval = all([f.has_eval for f in func_list])\n        self.no_inf_eval = no_inf_eval\n        self.func_list = func_list\n        N = len(func_list)\n        if alpha_list is None:\n            self.alpha_list = [1.0 / N] * N\n        else:\n            if len(alpha_list) != N:\n                raise ValueError(\n                    \"If specified, argument 'alpha_list' must have the same length as func_list\"\n                )\n            alpha_sum = sum(alpha_list)\n            if alpha_sum != 1.0:\n                alpha_list = [alpha / alpha_sum for alpha in alpha_list]\n            self.alpha_list = alpha_list\n\n    def __repr__(self):\n        return (\n            Functional.__repr__(self)\n            + \"  components: \"\n            + \", \".join([str(f) for f in self.func_list])\n            + \"\\n  weights:    \"\n            + \", \".join([str(alpha) for alpha in self.alpha_list])\n            + \"\\n\"\n        )\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        \"\"\"Evaluate the weighted average of component functionals.\"\"\"\n        if self.has_eval:\n            weight_func_vals = [alpha * f(x) for (alpha, f) in zip(self.alpha_list, self.func_list)]\n            if self.no_inf_eval:\n     
           weight_func_vals = list(filter(lambda x: not isinf(x), weight_func_vals))\n            return sum(weight_func_vals)\n        else:\n            raise ValueError(\n                \"At least one functional in argument 'func_list' has has_eval == False.\"\n            )\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Approximate proximal operator of the average of functionals.\n\n        Approximation of the proximal operator of a weighted average of\n        functionals computed via the proximal average method\n        :cite:`yu-2013-better`.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lam`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        return sum(\n            [\n                alpha * f.prox(v, lam, **kwargs)\n                for (alpha, f) in zip(self.alpha_list, self.func_list)\n            ]\n        )\n"
  },
  {
    "path": "scico/functional/_tvnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Total variation norms.\"\"\"\n\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport jax\n\nfrom scico import numpy as snp\nfrom scico.linop import (\n    Crop,\n    FiniteDifference,\n    LinearOperator,\n    Pad,\n    SingleAxisFiniteDifference,\n    VerticalStack,\n    linop_over_axes,\n)\nfrom scico.numpy import Array\nfrom scico.numpy.util import normalize_axes\nfrom scico.typing import Axes, DType, Shape\n\nfrom ._functional import Functional\nfrom ._norm import L1Norm, L21Norm\n\n\nclass TVNorm(Functional):\n    r\"\"\"Generic total variation (TV) norm.\n\n    Generic total variation (TV) norm with approximation of the scaled\n    proximal operator :cite:`kamilov-2016-parallel`\n    :cite:`kamilov-2016-minimizing` :cite:`chandler-2024-closedform`.\n    \"\"\"\n\n    has_eval = True\n    has_prox = True\n\n    def __init__(\n        self,\n        norm: Functional,\n        circular: bool = True,\n        axes: Optional[Axes] = None,\n        input_shape: Optional[Shape] = None,\n        input_dtype: DType = snp.float32,\n    ):\n        \"\"\"\n        While initializers for :class:`.Functional` objects typically do\n        not take `input_shape` and `input_dtype` parameters, they are\n        included here because methods :meth:`__call__` and :meth:`prox`\n        require instantiation of some :class:`.LinearOperator` objects,\n        which do take these parameters. 
If these parameters are not\n        provided on initialization of a :class:`TVNorm` object, then\n        creation of the required :class:`.LinearOperator` objects is\n        deferred until these methods are called, which can result in\n        `JAX tracer <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables>`__\n        errors when they are components of a jitted function.\n\n        Args:\n            norm: Norm functional from which the TV norm is composed.\n            circular: Flag indicating use of circular boundary conditions.\n            axes: Axis or axes over which to apply finite difference\n                operator. If not specified, or ``None``, differences are\n                evaluated along all axes.\n            input_shape: Shape of input arrays of :meth:`__call__` and\n                :meth:`prox`.\n            input_dtype: `dtype` of input arrays of :meth:`__call__` and\n                :meth:`prox`.\n        \"\"\"\n        self.norm = norm\n        self.circular = circular\n        self.axes = axes\n        self.G: Optional[LinearOperator] = None\n        self.WP: Optional[LinearOperator] = None\n        self.prox_ndims: Optional[int] = None\n        self.prox_slice: Optional[Tuple] = None\n\n        if input_shape is not None:\n            self.G = self._call_operator(input_shape, input_dtype)\n            self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators(\n                input_shape, input_dtype\n            )\n\n    def _call_operator(self, input_shape: Shape, input_dtype: DType) -> LinearOperator:\n        \"\"\"Construct operator required by __call__ method.\"\"\"\n        G = FiniteDifference(\n            input_shape,\n            input_dtype=input_dtype,\n            axes=self.axes,\n            circular=self.circular,\n            # For non-circular boundary conditions, zero-pad to the right\n            # for equivalence with boundary conditions 
implemented in the\n            # prox calculation.\n            append=None if self.circular else 0,\n            jit=True,\n        )\n        return G\n\n    def __call__(self, x: Array) -> float:\n        \"\"\"Compute the TV norm of an array.\n\n        Args:\n            x: Array for which the TV norm should be computed.\n\n        Returns:\n              TV norm of `x`.\n        \"\"\"\n        if self.G is None or self.G.shape[1] != x.shape:\n            self.G = self._call_operator(x.shape, x.dtype)\n        return self.norm(self.G @ x)\n\n    def _prox_operators(\n        self, input_shape: Shape, input_dtype: DType\n    ) -> Tuple[LinearOperator, LinearOperator, int, Tuple]:\n        \"\"\"Construct operators required by prox method.\"\"\"\n        axes = normalize_axes(self.axes, input_shape)\n        ndims = len(axes)\n        w_input_shape = (\n            # circular boundary: shape of input array\n            input_shape\n            if self.circular\n            # non-circular boundary: shape of input array on non-differenced\n            #    axes and one greater for axes that are differenced\n            else tuple([s + 1 if i in axes else s for i, s in enumerate(input_shape)])  # type: ignore\n        )\n        W = HaarTransform(w_input_shape, input_dtype=input_dtype, axes=axes, jit=True)  # type: ignore\n        if self.circular:\n            # slice selecting highpass component of shift-invariant Haar transform\n            slce = snp.s_[:, 1]\n            # No boundary extension, so fused extend and forward transform, and fused\n            # adjoint transform and crop are just forward and adjoint respectively.\n            WP, CWT = W, W.T\n        else:\n            # slice selecting non-boundary region of highpass component of\n            # shift-invariant Haar transform\n            slce = (\n                snp.s_[:],\n                snp.s_[1],\n            ) + tuple(\n                [snp.s_[:-1] if i in axes else snp.s_[:] for i, s 
in enumerate(input_shape)]\n            )  # type: ignore\n            # Replicate-pad to the right (resulting in a zero after finite differencing)\n            # on all axes subject to finite differencing.\n            pad_width = [(0, 1) if i in axes else (0, 0) for i, s in enumerate(input_shape)]  # type: ignore\n            P = Pad(\n                input_shape, input_dtype=input_dtype, pad_width=pad_width, mode=\"edge\", jit=True\n            )\n            # fused boundary extend and forward transform linop\n            WP = W @ P\n            # crop operation that is inverse of the padding operation\n            C = Crop(\n                crop_width=pad_width, input_shape=w_input_shape, input_dtype=input_dtype, jit=True\n            )\n            # fused adjoint transform and crop linop\n            CWT = C @ W.T\n        return WP, CWT, ndims, slce\n\n    @staticmethod\n    def _slice_tuple_to_tuple(st: Tuple) -> Tuple:\n        \"\"\"Convert a tuple of slice or int to a tuple of tuple or int.\n\n        Required here as a workaround for the unhashability of slices in\n        Python < 3.12, since jax.jit requires static arguments to be\n        hashable.\n        \"\"\"\n        return tuple([(s.start, s.stop, s.step) if isinstance(s, slice) else s for s in st])\n\n    @staticmethod\n    def _slice_tuple_from_tuple(st: Tuple) -> Tuple:\n        \"\"\"Convert a tuple of tuple or int to a tuple of slice or int.\n\n        Required here as a workaround for the unhashability of slices in\n        Python < 3.12, since jax.jit requires static arguments to be\n        hashable.\n        \"\"\"\n        return tuple([slice(*s) if isinstance(s, tuple) else s for s in st])\n\n    @staticmethod\n    @partial(jax.jit, static_argnums=(0, 1, 2, 4))\n    def _prox_core(\n        WP: LinearOperator,\n        CWT: LinearOperator,\n        norm: Functional,\n        K: int,\n        slce_rep: Tuple,\n        v: Array,\n        lam: float = 1.0,\n    ) -> Array:\n        
\"\"\"Core component of prox calculation.\"\"\"\n        # Apply boundary extension (when circular==False) and single-level Haar\n        # transform to input array.\n        WPv: Array = WP(v)\n        # Convert tuple of slices/ints to tuple of tuples/ints to avoid jax.jit\n        # complaints about unhashability of slices.\n        slce = TVNorm._slice_tuple_from_tuple(slce_rep)\n        # Apply shrinkage to highpass component of shift-invariant Haar transform\n        # of padded input (or to non-boundary region thereof when circular==False).\n        WPv = WPv.at[slce].set(norm.prox(WPv[slce], snp.sqrt(2) * K * lam))\n        # Apply adjoint of single-level Haar transform and crop extended\n        # part of array (when circular==False).\n        return (1.0 / K) * CWT(WPv)\n\n    def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:\n        r\"\"\"Approximate scaled proximal operator of the TV norm.\n\n        Approximation of the scaled proximal operator of the TV norm,\n        computed via the methods described in\n        :cite:`kamilov-2016-parallel` :cite:`kamilov-2016-minimizing`\n        :cite:`chandler-2024-closedform`.\n\n        Args:\n            v: Input array :math:`\\mb{v}`.\n            lam: Proximal parameter :math:`\\lam`.\n            **kwargs: Additional arguments that may be used by derived\n                classes.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        if self.WP is None or self.WP.shape[1] != v.shape:\n            self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators(\n                v.shape, v.dtype\n            )\n        assert self.prox_ndims is not None\n        assert self.prox_slice is not None\n        K = 2 * self.prox_ndims\n        u = TVNorm._prox_core(\n            self.WP, self.CWT, self.norm, K, TVNorm._slice_tuple_to_tuple(self.prox_slice), v, lam\n        )\n\n        return u\n\n\nclass 
AnisotropicTVNorm(TVNorm):\n    r\"\"\"The anisotropic total variation (TV) norm.\n\n    The anisotropic total variation (TV) norm computed by\n\n    .. code-block:: python\n\n       ATV = scico.functional.AnisotropicTVNorm(circular=True)\n       x_norm = ATV(x)\n\n    is equivalent to\n\n    .. code-block:: python\n\n       C = linop.FiniteDifference(input_shape=x.shape, circular=True)\n       L1 = functional.L1Norm()\n       x_norm = L1(C @ x)\n\n    The scaled proximal operator is computed using an approximation that\n    holds for small scaling parameters :cite:`kamilov-2016-parallel`.\n    This does not imply that it can only be applied to problems requiring\n    a small regularization parameter since most proximal algorithms\n    include an additional algorithm parameter that also plays a role in\n    the parameter of the proximal operator. For example, in :class:`.PGM`\n    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter\n    is the regularization parameter divided by the `L0` algorithm\n    parameter, and for :class:`.ADMM`, the scaled proximal operator\n    parameters are the regularization parameters divided by the entries\n    in the `rho_list` algorithm parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        circular: bool = False,\n        axes: Optional[Axes] = None,\n        input_shape: Optional[Shape] = None,\n        input_dtype: DType = snp.float32,\n    ):\n        \"\"\"\n        Args:\n            circular: Flag indicating use of circular boundary conditions.\n            axes: Axis or axes over which to apply finite difference\n                operator. 
If not specified, or ``None``, differences are\n                evaluated along all axes.\n            input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and\n                :meth:`~.TVNorm.prox`.\n            input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and\n                :meth:`~.TVNorm.prox`.\n        \"\"\"\n        super().__init__(\n            L1Norm(),\n            circular=circular,\n            axes=axes,\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n        )\n\n\nclass IsotropicTVNorm(TVNorm):\n    r\"\"\"The isotropic total variation (TV) norm.\n\n    The isotropic total variation (TV) norm computed by\n\n    .. code-block:: python\n\n       ITV = scico.functional.IsotropicTVNorm(circular=True)\n       x_norm = ITV(x)\n\n    is equivalent to\n\n    .. code-block:: python\n\n       C = linop.FiniteDifference(input_shape=x.shape, circular=True)\n       L21 = functional.L21Norm()\n       x_norm = L21(C @ x)\n\n    The scaled proximal operator is computed using an approximation that\n    holds for small scaling parameters :cite:`kamilov-2016-minimizing`.\n    This does not imply that it can only be applied to problems requiring\n    a small regularization parameter since most proximal algorithms\n    include an additional algorithm parameter that also plays a role in\n    the parameter of the proximal operator. 
For example, in :class:`.PGM`\n    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter\n    is the regularization parameter divided by the `L0` algorithm\n    parameter, and for :class:`.ADMM`, the scaled proximal operator\n    parameters are the regularization parameters divided by the entries\n    in the `rho_list` algorithm parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        circular: bool = False,\n        axes: Optional[Axes] = None,\n        input_shape: Optional[Shape] = None,\n        input_dtype: DType = snp.float32,\n    ):\n        r\"\"\"\n        Args:\n            circular: Flag indicating use of circular boundary conditions.\n            axes: Axis or axes over which to apply finite difference\n                operator. If not specified, or ``None``, differences are\n                evaluated along all axes.\n            input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and\n                :meth:`~.TVNorm.prox`.\n            input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and\n                :meth:`~.TVNorm.prox`.\n        \"\"\"\n        super().__init__(\n            L21Norm(),\n            circular=circular,\n            axes=axes,\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n        )\n\n\nclass SingleAxisFiniteSum(LinearOperator):\n    r\"\"\"Two-point sum operator acting along a single axis.\n\n    Boundary handling is circular, so that the sum operator corresponds\n    to the matrix\n\n    .. 
math::\n\n       \\left(\\begin{array}{rrrrr}\n       1 & 1 & 0 & \\ldots & 0\\\\\n       0 & 1 & 1 & \\ldots & 0\\\\\n       \\vdots & \\vdots & \\ddots & \\ddots & \\vdots\\\\\n       0 & 0 & \\ldots & 1 & 1\\\\\n       1 & 0 & \\dots & 0 & 1\n       \\end{array}\\right) \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = snp.float32,\n        axis: int = -1,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            axis: Axis over which to apply sum operator.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n\n        if not isinstance(axis, int):\n            raise TypeError(\n                f\"Expected argument 'axis' to be of type int, got {type(axis)} instead.\"\n            )\n\n        if axis < 0:\n            axis = len(input_shape) + axis\n        if axis >= len(input_shape):\n            raise ValueError(\n                f\"Invalid argument 'axis' specified ({axis}); 'axis' must be less than \"\n                f\"len(input_shape)={len(input_shape)}.\"\n            )\n        self.axis = axis\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=input_shape,\n            input_dtype=input_dtype,\n            output_dtype=input_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        return x + snp.roll(x, -1, self.axis)\n\n\nclass FiniteSum(VerticalStack):\n    \"\"\"Two-point sum operator.\n\n    Compute two-point sums along the specified axes, returning the\n    results stacked on axis 0 of a :class:`jax.Array`.\n    See :class:`SingleAxisFiniteSum` for boundary 
handling details.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = snp.float32,\n        axes: Optional[Axes] = None,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            axes: Axis or axes over which to apply sum operator. If not\n                specified, or ``None``, sums are evaluated along all axes.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n        self.axes, ops = linop_over_axes(\n            SingleAxisFiniteSum,\n            input_shape,\n            axes=axes,\n            input_dtype=input_dtype,\n            jit=False,\n        )\n        super().__init__(\n            ops,  # type: ignore\n            jit=jit,\n            **kwargs,\n        )\n\n\nclass SingleAxisHaarTransform(VerticalStack):\n    \"\"\"Single-level shift-invariant Haar transform along a single axis.\n\n    Compute one level of a shift-invariant Haar transform along the\n    specified axis, returning the results in a :class:`jax.Array`\n    consisting of sum and difference components (corresponding to lowpass\n    and highpass filtered components respectively) stacked on axis 0.\n    See :class:`SingleAxisFiniteSum` for boundary handling details.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = snp.float32,\n        axis: int = -1,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. 
Defaults to\n                :attr:`~numpy.float32`.\n            axis: Axis over which to apply Haar transform.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n        self.axis = axis\n        self.HaarL = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteSum(\n            input_shape, input_dtype=input_dtype, axis=axis, jit=jit, **kwargs\n        )\n        self.HaarH = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteDifference(\n            input_shape, input_dtype=input_dtype, axis=axis, circular=True, jit=jit, **kwargs\n        )\n        super().__init__(\n            (self.HaarL, self.HaarH),\n            jit=jit,\n            **kwargs,\n        )\n\n\nclass HaarTransform(VerticalStack):\n    \"\"\"Single-level shift-invariant Haar transform.\n\n    Compute one level of a shift-invariant Haar transform along the\n    specified axes, returning the results in a :class:`jax.Array`.\n    See :class:`SingleAxisHaarTransform` for details of the transform\n    along each axis.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = snp.float32,\n        axes: Optional[Axes] = None,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            axes: Axis or axes over which to apply Haar transform. 
If not\n                specified, or ``None``, the transform is evaluated along\n                all axes.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n        self.axes, ops = linop_over_axes(\n            SingleAxisHaarTransform,\n            input_shape,\n            axes=axes,\n            input_dtype=input_dtype,\n            jit=False,\n        )\n        super().__init__(\n            ops,  # type: ignore\n            jit=jit,\n            **kwargs,\n        )\n"
  },
  {
    "path": "scico/linop/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linear operator functions and classes.\"\"\"\n\nimport sys\n\nfrom ._circconv import CircularConvolve\nfrom ._convolve import Convolve, ConvolveByX\nfrom ._dft import DFT\nfrom ._diag import Diagonal, Identity, ScaledIdentity\nfrom ._diff import FiniteDifference, SingleAxisFiniteDifference\nfrom ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function\nfrom ._grad import (\n    CylindricalGradient,\n    PolarGradient,\n    ProjectedGradient,\n    SphericalGradient,\n)\nfrom ._linop import ComposedLinearOperator, LinearOperator\nfrom ._matrix import MatrixOperator\nfrom ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes\nfrom ._util import jacobian, operator_norm, power_iteration, valid_adjoint\n\n__all__ = [\n    \"CircularConvolve\",\n    \"Convolve\",\n    \"DFT\",\n    \"Diagonal\",\n    \"FiniteDifference\",\n    \"ProjectedGradient\",\n    \"PolarGradient\",\n    \"CylindricalGradient\",\n    \"SphericalGradient\",\n    \"SingleAxisFiniteDifference\",\n    \"Identity\",\n    \"DiagonalReplicated\",\n    \"VerticalStack\",\n    \"DiagonalStack\",\n    \"MatrixOperator\",\n    \"Pad\",\n    \"Crop\",\n    \"Reshape\",\n    \"ScaledIdentity\",\n    \"Slice\",\n    \"Sum\",\n    \"Transpose\",\n    \"LinearOperator\",\n    \"ComposedLinearOperator\",\n    \"linop_from_function\",\n    \"linop_over_axes\",\n    \"operator_norm\",\n    \"power_iteration\",\n    \"valid_adjoint\",\n    \"jacobian\",\n]\n\n# Imported items in __all__ appear to originate in top-level linop module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/linop/_circconv.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Circular convolution linear operator.\"\"\"\n\nimport math\nfrom typing import Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nfrom jax.dtypes import result_type\n\nimport scico.numpy as snp\nfrom scico.numpy.util import is_nested\nfrom scico.operator import Operator\nfrom scico.typing import DType, Shape\n\nfrom ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar\n\n\nclass CircularConvolve(LinearOperator):\n    r\"\"\"A circular convolution linear operator.\n\n    This linear operator implements circular, multi-dimensional\n    convolution via pointwise multiplication in the DFT domain. In its\n    simplest form, it implements a single convolution and can be\n    represented by linear operator :math:`H` such that\n\n    .. math::\n       H \\mb{x} = \\mb{h} \\ast \\mb{x} \\;,\n\n    where :math:`\\mb{h}` is a user-defined filter.\n\n    More complex forms, corresponding to the case where either the input\n    (as represented by parameter `input_shape`) or filter (parameter `h`)\n    have additional axes that are not involved in the convolution are\n    also supported. These follow numpy broadcasting rules. For example:\n\n    Additional axes in the input :math:`\\mb{x}` and not in :math:`\\mb{h}`\n    corresponds to the operation\n\n    .. 
math::\n       H \\mb{x} = \\left( \\begin{array}{ccc}  H' & 0 & \\ldots\\\\\n                                            0 & H' & \\ldots\\\\\n                                            \\vdots & \\vdots & \\ddots\n                        \\end{array} \\right)\n       \\left( \\begin{array}{c}  \\mb{x}_0\\\\ \\mb{x}_1\\\\ \\vdots \\end{array}\n       \\right) \\;.\n\n    Additional axes in :math:`\\mb{h}` correspond to multiple filters,\n    which will be denoted by :math:`\\{\\mb{h}_m\\}`, with corresponding\n    individual linear operations being denoted by :math:`H_m \\mb{x}_m =\n    \\mb{h}_m \\ast \\mb{x}_m`. The full linear operator can then be\n    represented as\n\n    .. math::\n       H \\mb{x} = \\left( \\begin{array}{c}  H_0\\\\ H_1\\\\ \\vdots \\end{array}\n       \\right) \\mb{x} \\;,\n\n    if the input is singleton, and as\n\n    .. math::\n       H \\mb{x} = \\left( \\begin{array}{ccc}  H_0 & 0 & \\ldots\\\\\n                                            0 & H_1 & \\ldots\\\\\n                                            \\vdots & \\vdots & \\ddots\n                        \\end{array} \\right)\n       \\left( \\begin{array}{c}  \\mb{x}_0\\\\ \\mb{x}_1\\\\ \\vdots \\end{array}\n       \\right)\n\n    otherwise.\n    \"\"\"\n\n    def __init__(\n        self,\n        h: snp.Array,\n        input_shape: Shape,\n        ndims: Optional[int] = None,\n        input_dtype: DType = snp.float32,\n        h_is_dft: bool = False,\n        h_center: Optional[Union[snp.Array, np.ndarray, Sequence, float, int]] = None,\n        jit: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            h: Array of filters.\n            input_shape: Shape of input array.\n            ndims: Number of (trailing) dimensions of the input and `h`\n                involved in the convolution. Defaults to the number of\n                dimensions in the input.\n            input_dtype: `dtype` for input argument. 
Defaults to\n                :attr:`~numpy.float32`.\n            h_is_dft: Flag indicating whether `h` is in the DFT domain.\n            h_center: Array of length `ndims` specifying the center of\n                the filter. Defaults to the upper left corner, i.e.,\n                `h_center = [0, 0, ..., 0]`, may be noninteger. May be a\n                ``float`` or ``int`` if `h` is one-dimensional.\n            jit:  If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n\n        if ndims is None:\n            self.ndims = len(input_shape)\n        else:\n            self.ndims = ndims\n\n        if h_is_dft and h_center is not None:\n            raise ValueError(\"Argument 'h_center' must be None when h_is_dft=True.\")\n        self.h_center = h_center\n\n        if h_is_dft:\n            self.h_dft = h\n            output_dtype = snp.dtype(input_dtype)  # cannot infer from h_dft because it is complex\n        else:\n            fft_shape = input_shape[-self.ndims :]\n            fft_axes = list(range(h.ndim - self.ndims, h.ndim))\n            self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes)\n            output_dtype = result_type(h.dtype, input_dtype)\n\n            if self.h_center is not None:\n                shift = self._dft_center_shift(input_shape)\n                self.h_dft = self.h_dft * shift\n\n        self.real = output_dtype.kind != \"c\"\n\n        try:\n            output_shape = np.broadcast_shapes(self.h_dft.shape, input_shape)\n        except ValueError:\n            raise ValueError(\n                f\"Shape of 'h' after padding was {self.h_dft.shape}, needs to be compatible \"\n                f\"for broadcasting with {input_shape}.\"\n            )\n\n        self.batch_axes = tuple(\n            range(0, len(output_shape) - len(input_shape))\n        )  # used in adjoint to undo broadcasting\n\n        self.ifft_axes = list(range(len(output_shape) - 
self.ndims, len(output_shape)))\n        self.x_fft_axes = list(range(len(input_shape) - self.ndims, len(input_shape)))\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=input_dtype,\n            output_dtype=output_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _dft_center_shift(self, input_shape) -> np.ndarray:\n        \"\"\"Compute DFT domain shift required for centering.\n\n        See doi:10.1109/78.700979 and doi:10.1109/LSP.2012.2191280 for\n        details of the shift computation.\n        \"\"\"\n        if isinstance(self.h_center, (float, int)):  # support float/int h_center\n            offset = -np.array(\n                [\n                    self.h_center,\n                ]\n            )\n        else:  # support array/list/tuple h_center\n            offset = -np.array(self.h_center)\n        shifts: Tuple[np.ndarray, ...] = np.ix_(\n            *tuple(\n                np.select(\n                    [np.arange(s) < s / 2, np.arange(s) == s / 2, np.arange(s) > s / 2],\n                    [\n                        np.exp(-1j * k * 2 * np.pi * np.arange(s) / s),\n                        np.cos(k * np.pi),\n                        np.exp(1j * k * 2 * np.pi * (s - np.arange(s)) / s),\n                    ],  # type: ignore\n                )\n                for k, s in zip(offset, input_shape[-self.ndims :])\n            )\n        )\n        # prevent accidental promotion to double\n        shifts = tuple(s.astype(self.h_dft.dtype) for s in shifts)\n        shift = math.prod(shifts)  # np.prod warns\n        assert isinstance(shift, np.ndarray)\n        return shift\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        x = x.astype(self.input_dtype)\n        x_dft = snp.fft.fftn(x, axes=self.x_fft_axes)\n        hx = snp.fft.ifftn(\n            self.h_dft * x_dft,\n            axes=self.ifft_axes,\n        )\n        if 
self.real:\n            hx = hx.real\n        return hx\n\n    def _adj(self, x: snp.Array) -> snp.Array:  # type: ignore\n        x_dft = snp.fft.fftn(x, axes=self.ifft_axes)\n        H_adj_x = snp.fft.ifftn(\n            snp.conj(self.h_dft) * x_dft,\n            axes=self.ifft_axes,\n            s=self.input_shape[-self.ndims :],\n        )\n        H_adj_x = snp.sum(H_adj_x, axis=self.batch_axes)  # adjoint of the broadcast\n        if self.real:\n            H_adj_x = H_adj_x.real\n        return H_adj_x\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        if self.ndims != other.ndims:\n            raise ValueError(f\"Incompatible ndims: {self.ndims} != {other.ndims}.\")\n\n        return CircularConvolve(\n            h=self.h_dft + other.h_dft,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, other.input_dtype),\n            ndims=self.ndims,\n            h_is_dft=True,\n        )\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        if self.ndims != other.ndims:\n            raise ValueError(f\"Incompatible ndims: {self.ndims} != {other.ndims}.\")\n\n        return CircularConvolve(\n            h=self.h_dft - other.h_dft,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, other.input_dtype),\n            ndims=self.ndims,\n            h_is_dft=True,\n        )\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, scalar):\n        return CircularConvolve(\n            h=self.h_dft * scalar,\n            input_shape=self.input_shape,\n            ndims=self.ndims,\n            input_dtype=self.input_dtype,\n            h_is_dft=True,\n        )\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, scalar):\n        return CircularConvolve(\n            h=self.h_dft / scalar,\n            input_shape=self.input_shape,\n            ndims=self.ndims,\n            input_dtype=self.input_dtype,\n            h_is_dft=True,\n        )\n\n    
@staticmethod\n    def from_operator(\n        H: Operator, ndims: Optional[int] = None, center: Optional[Shape] = None, jit: bool = True\n    ):\n        r\"\"\"Construct a CircularConvolve version of a given operator.\n\n        Construct a CircularConvolve version of a given operator,\n        which is assumed to be linear and shift invariant (LSI).\n\n        Args:\n            H: Input operator.\n            ndims: Number of trailing dimensions over which `H` acts.\n                Defaults to `len(H.input_shape)`.\n            center: Location at which to place the Kronecker delta. For\n              LSI inputs, this will not matter. Defaults to the center\n              of H.input_shape, i.e., (n_1 // 2, n_2 // 2, ...).\n            jit: If ``True``, jit the resulting `CircularConvolve`.\n        \"\"\"\n\n        if is_nested(H.input_shape):\n            raise ValueError(\n                f\"H.input_shape ({H.input_shape}) suggests that H \"\n                \"takes a BlockArray as input, which is not supported \"\n                \"by this function.\"\n            )\n\n        if ndims is None:\n            ndims = len(H.input_shape)\n        else:\n            ndims = ndims\n\n        if center is None:\n            center = tuple(d // 2 for d in H.input_shape[-ndims:])  # type: ignore\n\n        # compute impulse response\n        d = snp.zeros(H.input_shape, H.input_dtype)\n        d = d.at[(Ellipsis,) + center].set(1.0)\n        Hd = H @ d\n\n        # build CircularConvolve\n        return CircularConvolve(\n            Hd,\n            H.input_shape,  # type: ignore\n            ndims=ndims,\n            input_dtype=H.input_dtype,\n            h_center=snp.array(center),\n            jit=jit,\n        )\n"
  },
  {
    "path": "scico/linop/_convolve.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Convolution linear operator class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nimport numpy as np\n\nfrom jax.dtypes import result_type\nfrom jax.scipy.signal import convolve\n\nimport scico.numpy as snp\nfrom scico.typing import DType, Shape\n\nfrom ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar\n\n\nclass Convolve(LinearOperator):\n    \"\"\"A convolution linear operator.\"\"\"\n\n    def __init__(\n        self,\n        h: snp.Array,\n        input_shape: Shape,\n        input_dtype: DType = np.float32,\n        mode: str = \"full\",\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"Wrap :func:`jax.scipy.signal.convolve` as a :class:`.LinearOperator`.\n\n        Args:\n            h: Convolutional filter. Must have same number of dimensions\n                as `len(input_shape)`.\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            mode: A string indicating the size of the output. One of\n                \"full\", \"valid\", \"same\". 
Defaults to \"full\".\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n\n        For more details on `mode`, see :func:`jax.scipy.signal.convolve`.\n        \"\"\"\n\n        self.h: snp.Array  # : Convolution kernel\n        self.mode: str  # : Convolution mode\n\n        if h.ndim != len(input_shape):\n            raise ValueError(f\"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}.\")\n        self.h = h\n\n        if mode not in [\"full\", \"valid\", \"same\"]:\n            raise ValueError(f\"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.\")\n\n        self.mode = mode\n\n        if input_dtype is None:\n            input_dtype = self.h.dtype\n\n        output_dtype = result_type(input_dtype, self.h.dtype)\n\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            output_dtype=output_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        return convolve(x, self.h, mode=self.mode)\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        if self.mode != other.mode:\n            raise ValueError(f\"Incompatible modes:  {self.mode} != {other.mode}.\")\n\n        if self.h.shape == other.h.shape:\n            return Convolve(\n                h=self.h + other.h,\n                input_shape=self.input_shape,\n                input_dtype=result_type(self.input_dtype, other.input_dtype),\n                mode=self.mode,\n                output_shape=self.output_shape,\n                adj_fn=lambda x: self.adj(x) + other.adj(x),\n            )\n\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        if self.mode != other.mode:\n            raise ValueError(f\"Incompatible modes:  {self.mode} != {other.mode}.\")\n\n        if self.h.shape 
== other.h.shape:\n            return Convolve(\n                h=self.h - other.h,\n                input_shape=self.input_shape,\n                input_dtype=result_type(self.input_dtype, other.input_dtype),\n                mode=self.mode,\n                output_shape=self.output_shape,\n                adj_fn=lambda x: self.adj(x) - other.adj(x),\n            )\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, scalar):\n        return Convolve(\n            h=self.h * scalar,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, type(scalar)),\n            mode=self.mode,\n            output_shape=self.output_shape,\n            adj_fn=lambda x: snp.conj(scalar) * self.adj(x),\n        )\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, scalar):\n        return Convolve(\n            h=self.h / scalar,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, type(scalar)),\n            mode=self.mode,\n            output_shape=self.output_shape,\n            adj_fn=lambda x: self.adj(x) / snp.conj(scalar),\n        )\n\n\nclass ConvolveByX(LinearOperator):\n    \"\"\"A LinearOperator that performs convolution as a function of the first argument.\n\n    The :class:`LinearOperator` `ConvolveByX(x=x)(y)` implements\n    `jax.scipy.signal.convolve(x, y)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        x: snp.Array,\n        input_shape: Shape,\n        input_dtype: DType = np.float32,\n        mode: str = \"full\",\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n\n        Args:\n            x: Convolutional filter. Must have same number of dimensions\n                as `len(input_shape)`.\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. 
Defaults to\n                :attr:`~numpy.float32`.\n            mode: A string indicating the size of the output. One of\n                \"full\", \"valid\", \"same\". Defaults to \"full\".\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n\n        For more details on `mode`, see :func:`jax.scipy.signal.convolve`.\n        \"\"\"\n\n        self.x: snp.Array  # : Fixed signal to convolve with\n        self.mode: str  # : Convolution mode\n\n        if x.ndim != len(input_shape):\n            raise ValueError(f\"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.\")\n\n        # Ensure that x is a numpy or jax array.\n        if not snp.util.is_arraylike(x):\n            raise TypeError(f\"Expected numpy or jax array, got {type(x)}.\")\n        self.x = x\n\n        if mode not in [\"full\", \"valid\", \"same\"]:\n            raise ValueError(f\"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.\")\n\n        self.mode = mode\n\n        if input_dtype is None:\n            input_dtype = x.dtype\n\n        output_dtype = result_type(input_dtype, x.dtype)\n\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            output_dtype=output_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, h: snp.Array) -> snp.Array:\n        return convolve(self.x, h, mode=self.mode)\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        if self.mode != other.mode:\n            raise ValueError(f\"Incompatible modes:  {self.mode} != {other.mode}.\")\n        if self.x.shape == other.x.shape:\n            return ConvolveByX(\n                x=self.x + other.x,\n                input_shape=self.input_shape,\n                input_dtype=result_type(self.input_dtype, other.input_dtype),\n                mode=self.mode,\n                output_shape=self.output_shape,\n                
adj_fn=lambda x: self.adj(x) + other.adj(x),\n            )\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        if self.mode != other.mode:\n            raise ValueError(f\"Incompatible modes:  {self.mode} != {other.mode}.\")\n\n        if self.x.shape == other.x.shape:\n            return ConvolveByX(\n                x=self.x - other.x,\n                input_shape=self.input_shape,\n                input_dtype=result_type(self.input_dtype, other.input_dtype),\n                mode=self.mode,\n                output_shape=self.output_shape,\n                adj_fn=lambda x: self.adj(x) - other.adj(x),\n            )\n\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, scalar):\n        return ConvolveByX(\n            x=self.x * scalar,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, type(scalar)),\n            mode=self.mode,\n            output_shape=self.output_shape,\n            adj_fn=lambda x: snp.conj(scalar) * self.adj(x),\n        )\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, scalar):\n        return ConvolveByX(\n            x=self.x / scalar,\n            input_shape=self.input_shape,\n            input_dtype=result_type(self.input_dtype, type(scalar)),\n            mode=self.mode,\n            output_shape=self.output_shape,\n            adj_fn=lambda x: self.adj(x) / snp.conj(scalar),\n        )\n"
  },
  {
    "path": "scico/linop/_dft.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Discrete Fourier transform linear operator class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Optional, Sequence\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico.typing import Shape\n\nfrom ._linop import LinearOperator\n\n\nclass DFT(LinearOperator):\n    r\"\"\"Multi-dimensional discrete Fourier transform.\"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        axes: Optional[Sequence] = None,\n        axes_shape: Optional[Shape] = None,\n        norm: Optional[str] = None,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            axes: Axes over which to compute the DFT. If ``None``, the\n                DFT is computed over all axes.\n            axes_shape: Output shape on the subset of array axes selected\n                by `axes`. This parameter has the same behavior as the\n                `s` parameter of :func:`numpy.fft.fftn`.\n            norm: DFT normalization mode. 
See the `norm` parameter of\n                :func:`numpy.fft.fftn`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the LinearOperator.\n        \"\"\"\n        if axes is not None and axes_shape is not None and len(axes) != len(axes_shape):\n            raise ValueError(\n                f\"len(axes)={len(axes)} does not equal len(axes_shape)={len(axes_shape)}.\"\n            )\n\n        if axes_shape is not None:\n            if axes is None:\n                axes = tuple(range(len(input_shape) - len(axes_shape), len(input_shape)))\n            tmp_output_shape = list(input_shape)\n            for i, s in zip(axes, axes_shape):\n                tmp_output_shape[i] = s\n            output_shape = tuple(tmp_output_shape)\n        else:\n            output_shape = input_shape\n\n        if axes is None or axes_shape is None:\n            self.inv_axes_shape = None\n        else:\n            self.inv_axes_shape = [input_shape[i] for i in axes]\n\n        self.axes = axes\n        self.axes_shape = axes_shape\n        self.norm = norm\n\n        # To satisfy mypy -- DFT shapes must be tuples, not list of tuple\n        # These get set inside of super().__init__ call, but we want to have\n        # more restrictive type than the general LinearOperator\n        self.input_shape: Shape\n        self.output_shape: Shape\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=np.complex64,\n            output_dtype=np.complex64,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        return snp.fft.fftn(x, s=self.axes_shape, axes=self.axes, norm=self.norm)\n\n    def inv(self, z: snp.Array) -> snp.Array:\n        \"\"\"Compute the inverse of this LinearOperator.\n\n        Compute the inverse of this LinearOperator applied to `z`.\n\n        Args:\n            z: Input array to inverse 
DFT.\n        \"\"\"\n        return snp.fft.ifftn(z, s=self.inv_axes_shape, axes=self.axes, norm=self.norm)\n"
  },
  {
    "path": "scico/linop/_diag.py",
    "content": "# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Miscellaneous linear operator definitions.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Optional, Union\n\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import broadcast_nested_shapes, is_nested\nfrom scico.operator._operator import _wrap_mul_div_scalar\nfrom scico.typing import BlockShape, DType, Shape\n\nfrom ._linop import LinearOperator, _wrap_add_sub\n\n__all__ = [\"Diagonal\", \"Identity\", \"ScaledIdentity\"]\n\n\nclass Diagonal(LinearOperator):\n    \"\"\"Diagonal linear operator.\"\"\"\n\n    def __init__(\n        self,\n        diagonal: Union[Array, BlockArray],\n        input_shape: Optional[Union[Shape, BlockShape]] = None,\n        input_dtype: Optional[DType] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            diagonal: Diagonal elements of this :class:`LinearOperator`.\n            input_shape: Shape of input array. By default, equal to\n               `diagonal.shape`, but may also be set to a shape that is\n               broadcast-compatible with `diagonal.shape`.\n            input_dtype: `dtype` of input argument. 
The default,\n               ``None``, means `diagonal.dtype`.\n        \"\"\"\n        self._diagonal = diagonal\n\n        if input_shape is None:\n            input_shape = self._diagonal.shape\n\n        if input_dtype is None:\n            input_dtype = self._diagonal.dtype\n\n        if isinstance(diagonal, BlockArray) and is_nested(input_shape):\n            output_shape = broadcast_nested_shapes(input_shape, self._diagonal.shape)\n        elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape):\n            output_shape = snp.broadcast_shapes(input_shape, self._diagonal.shape)\n        elif isinstance(diagonal, BlockArray):\n            raise ValueError(\"Argument 'diagonal' was a BlockArray but input_shape was not nested.\")\n        else:\n            raise ValueError(\"Argument 'diagonal' was not a BlockArray but input_shape was nested.\")\n\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            output_shape=output_shape,\n            output_dtype=input_dtype,\n            **kwargs,\n        )\n\n    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        return self._diagonal * x\n\n    @property\n    def diagonal(self) -> Union[Array, BlockArray]:\n        \"\"\"Return an array representing the diagonal component.\"\"\"\n        return self._diagonal\n\n    @property\n    def T(self) -> Diagonal:\n        \"\"\"Transpose of this :class:`Diagonal`.\"\"\"\n        return self\n\n    def conj(self) -> Diagonal:\n        \"\"\"Complex conjugate of this :class:`Diagonal`.\"\"\"\n        return Diagonal(diagonal=self.diagonal.conj())\n\n    @property\n    def H(self) -> Diagonal:\n        \"\"\"Hermitian transpose of this :class:`Diagonal`.\"\"\"\n        return self.conj()\n\n    @property\n    def gram_op(self) -> Diagonal:\n        \"\"\"Gram operator of this :class:`Diagonal`.\n\n        Return a new :class:`Diagonal` :code:`G` such that\n        
:code:`G(x) = A.adj(A(x))`.\n        \"\"\"\n        return Diagonal(diagonal=self.diagonal.conj() * self.diagonal)\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        if self.shape == other.shape:\n            return Diagonal(diagonal=self.diagonal + other.diagonal)\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        if self.shape == other.shape:\n            return Diagonal(diagonal=self.diagonal - other.diagonal)\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, scalar):\n        return Diagonal(diagonal=self.diagonal * scalar)\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, scalar):\n        return Diagonal(diagonal=self.diagonal / scalar)\n\n    def __matmul__(self, other):\n        # self @ other\n        if isinstance(other, Diagonal):\n            if self.shape == other.shape:\n                return Diagonal(diagonal=self.diagonal * other.diagonal)\n            raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n        else:\n            return self(other)\n\n    def norm(self, ord=None):  # pylint: disable=W0622\n        \"\"\"Compute the matrix norm of the diagonal operator.\n\n        Valid values of `ord` and the corresponding norm definition\n        are those listed under \"norm for matrices\" in the\n        :func:`scico.numpy.linalg.norm` documentation.\n        \"\"\"\n        ordfunc = {\n            \"fro\": lambda x: snp.linalg.norm(x),\n            \"nuc\": lambda x: snp.sum(snp.abs(x)),\n            -snp.inf: lambda x: snp.abs(x).min(),\n            snp.inf: lambda x: snp.abs(x).max(),\n        }\n        mord = ord\n        if mord is None:\n            mord = \"fro\"\n        elif mord in (-1, -2):\n            mord = -snp.inf\n        elif mord in (1, 2):\n            mord = snp.inf\n        if mord not in ordfunc:\n    
        raise ValueError(f\"Invalid value {ord} for argument 'ord'.\")\n        return ordfunc[mord](self._diagonal)\n\n\nclass ScaledIdentity(Diagonal):\n    \"\"\"Scaled identity operator.\"\"\"\n\n    def __init__(\n        self,\n        scalar: float,\n        input_shape: Union[Shape, BlockShape],\n        input_dtype: DType = snp.float32,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            scalar: Scaling of the identity.\n            input_shape: Shape of input array.\n            input_dtype: `dtype` of input argument.\n        \"\"\"\n        if is_nested(input_shape):\n            diagonal = scalar * snp.ones(((),) * len(input_shape), dtype=input_dtype)\n        else:\n            diagonal = scalar * snp.ones((), dtype=input_dtype)\n        super().__init__(\n            diagonal=diagonal,\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            **kwargs,\n        )\n\n    @property\n    def diagonal(self) -> Union[Array, BlockArray]:\n        return self._diagonal\n\n    def conj(self) -> ScaledIdentity:\n        \"\"\"Complex conjugate of this :class:`ScaledIdentity`.\"\"\"\n        return ScaledIdentity(\n            scalar=self._diagonal.conj(), input_shape=self.input_shape, input_dtype=self.input_dtype\n        )\n\n    @property\n    def gram_op(self) -> ScaledIdentity:\n        \"\"\"Gram operator of this :class:`ScaledIdentity`.\"\"\"\n        return ScaledIdentity(\n            scalar=self._diagonal * self._diagonal.conj(),\n            input_shape=self.input_shape,\n            input_dtype=self.input_dtype,\n        )\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        if self.input_shape == other.input_shape:\n            return ScaledIdentity(\n                scalar=self._diagonal + other._diagonal,\n                input_shape=self.input_shape,\n                input_dtype=self.input_dtype,\n            )\n        raise ValueError(f\"Incompatible shapes: {self.shape} != 
{other.shape}.\")\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        if self.input_shape == other.input_shape:\n            return ScaledIdentity(\n                scalar=self._diagonal - other._diagonal,\n                input_shape=self.input_shape,\n                input_dtype=self.input_dtype,\n            )\n        raise ValueError(f\"Incompatible shapes: {self.shape} != {other.shape}.\")\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, scalar):\n        return ScaledIdentity(\n            scalar=self._diagonal * scalar,\n            input_shape=self.input_shape,\n            input_dtype=self.input_dtype,\n        )\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, scalar):\n        return ScaledIdentity(\n            scalar=self._diagonal / scalar,\n            input_shape=self.input_shape,\n            input_dtype=self.input_dtype,\n        )\n\n    def __matmul__(self, other):\n        # self @ other\n        if isinstance(other, Diagonal):\n            if self.shape != other.shape:\n                raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n            if isinstance(other, ScaledIdentity):\n                return ScaledIdentity(\n                    scalar=self._diagonal * other._diagonal,\n                    input_shape=self.input_shape,\n                    input_dtype=self.input_dtype,\n                )\n            else:\n                return Diagonal(diagonal=self._diagonal * other.diagonal)\n        else:\n            return self(other)\n\n    def norm(self, ord=None):  # pylint: disable=W0622\n        \"\"\"Compute the matrix norm of the identity operator.\n\n        Valid values of `ord` and the corresponding norm definition\n        are those listed under \"norm for matrices\" in the\n        :func:`scico.numpy.linalg.norm` documentation.\n        \"\"\"\n        N = self.input_size\n        if ord is None or ord == \"fro\":\n            return snp.abs(self._diagonal) * snp.sqrt(N)\n        
elif ord == \"nuc\":\n            return snp.abs(self._diagonal) * N\n        elif ord in (-snp.inf, -1, -2, 1, 2, snp.inf):\n            return snp.abs(self._diagonal)\n        else:\n            raise ValueError(f\"Invalid value {ord} for argument 'ord'.\")\n\n\nclass Identity(ScaledIdentity):\n    \"\"\"Identity operator.\"\"\"\n\n    def __init__(\n        self, input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, **kwargs\n    ):\n        \"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` of input argument.\n        \"\"\"\n        super().__init__(\n            scalar=1.0,\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            **kwargs,\n        )\n\n    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        return x\n\n    def conj(self) -> Identity:\n        \"\"\"Complex conjugate of this :class:`Identity`.\"\"\"\n        return self\n\n    @property\n    def gram_op(self) -> Identity:\n        \"\"\"Gram operator of this :class:`Identity`.\"\"\"\n        return self\n\n    def __matmul__(self, other):\n        return other\n\n    def __rmatmul__(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        return x\n"
  },
  {
    "path": "scico/linop/_diff.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Finite difference linear operator class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Literal, Optional, Union\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico.typing import Axes, DType, Shape\n\nfrom ._linop import LinearOperator\nfrom ._stack import VerticalStack, linop_over_axes\n\n\nclass FiniteDifference(VerticalStack):\n    \"\"\"Finite difference operator.\n\n    Compute finite differences along the specified axes, returning the\n    results in a :class:`jax.Array` (when possible) or :class:`BlockArray`.\n    See :class:`VerticalStack` for details on how this choice is made.\n    See :class:`SingleAxisFiniteDifference` for the mathematical\n    implications of the different boundary handling options `prepend`,\n    `append`, and `circular`.\n\n    Example\n    -------\n    >>> A = FiniteDifference((2, 3))\n    >>> x = snp.array([[1, 2, 4],\n    ...                [0, 4, 1]])\n    >>> (A @ x)[0]\n    Array([[-1,  2, -3]], dtype=int32)\n    >>> (A @ x)[1]\n    Array([[ 1,  2],\n           [ 4, -3]], dtype=int32)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = np.float32,\n        axes: Optional[Axes] = None,\n        prepend: Optional[Union[Literal[0], Literal[1]]] = None,\n        append: Optional[Union[Literal[0], Literal[1]]] = None,\n        circular: bool = False,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. 
Defaults to\n                :attr:`~numpy.float32`.\n            axes: Axis or axes over which to apply finite difference\n                operator. If not specified, or ``None``, differences are\n                evaluated along all axes.\n            prepend: Flag indicating handling of the left/top/etc.\n                boundary. If ``None``, there is no boundary extension.\n                Values of `0` or `1` indicate respectively that zeros or\n                the initial value in the array are prepended to the\n                difference array.\n            append: Flag indicating handling of the right/bottom/etc.\n                boundary. If ``None``, there is no boundary extension.\n                Values of `0` or `1` indicate respectively that zeros or\n                -1 times the final value in the array are appended to the\n                difference array.\n            circular: If ``True``, perform circular differences, i.e.,\n                include x[-1] - x[0]. If ``True``, `prepend` and `append`\n                must both be ``None``.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n        self.axes, ops = linop_over_axes(\n            SingleAxisFiniteDifference,\n            input_shape,\n            axes=axes,\n            input_dtype=input_dtype,\n            prepend=prepend,\n            append=append,\n            circular=circular,\n            jit=False,\n        )\n        super().__init__(\n            ops,  # type: ignore\n            jit=jit,\n            **kwargs,\n        )\n\n\nclass SingleAxisFiniteDifference(LinearOperator):\n    r\"\"\"Finite difference operator acting along a single axis.\n\n    By default (i.e. `prepend` and `append` set to ``None`` and `circular`\n    set to ``False``), the difference operator corresponds to the matrix\n\n    .. 
math::\n\n       \\left(\\begin{array}{rrrrr}\n       -1 & 1 & 0 & \\ldots & 0\\\\\n       0 & -1 & 1 & \\ldots & 0\\\\\n       \\vdots & \\vdots & \\ddots & \\ddots & \\vdots\\\\\n       0 & 0 & \\ldots & -1 & 1\n       \\end{array}\\right) \\;,\n\n    mapping :math:`\\mbb{R}^N \\rightarrow \\mbb{R}^{N-1}`, while if `circular`\n    is ``True``, it corresponds to the :math:`\\mbb{R}^N \\rightarrow \\mbb{R}^N`\n    mapping\n\n    .. math::\n\n       \\left(\\begin{array}{rrrrr}\n       -1 & 1 & 0 & \\ldots & 0\\\\\n       0 & -1 & 1 & \\ldots & 0\\\\\n       \\vdots & \\vdots & \\ddots & \\ddots & \\vdots\\\\\n       0 & 0 & \\ldots & -1 & 1\\\\\n       1 & 0 & \\dots & 0 & -1\n       \\end{array}\\right) \\;.\n\n    Other possible choices include `prepend` set to ``None`` and `append`\n    set to `0`, giving the :math:`\\mbb{R}^N \\rightarrow \\mbb{R}^N`\n    mapping\n\n    .. math::\n\n       \\left(\\begin{array}{rrrrr}\n       -1 & 1 & 0 & \\ldots & 0\\\\\n       0 & -1 & 1 & \\ldots & 0\\\\\n       \\vdots & \\vdots & \\ddots & \\ddots & \\vdots\\\\\n       0 & 0 & \\ldots & -1 & 1\\\\\n       0 & 0 & \\dots & 0 & 0\n       \\end{array}\\right) \\;,\n\n    and both `prepend` and `append` set to `1`, giving the\n    :math:`\\mbb{R}^N \\rightarrow \\mbb{R}^{N+1}` mapping\n\n    .. 
math::\n\n       \\left(\\begin{array}{rrrrr}\n        1 & 0 & 0 & \\ldots & 0\\\\\n       -1 & 1 & 0 & \\ldots & 0\\\\\n       0 & -1 & 1 & \\ldots & 0\\\\\n       \\vdots & \\vdots & \\ddots & \\ddots & \\vdots\\\\\n       0 & 0 & \\ldots & -1 & 1\\\\\n       0 & 0 & \\dots & 0 & -1\n       \\end{array}\\right) \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = np.float32,\n        axis: int = -1,\n        prepend: Optional[Union[Literal[0], Literal[1]]] = None,\n        append: Optional[Union[Literal[0], Literal[1]]] = None,\n        circular: bool = False,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            axis: Axis over which to apply finite difference operator.\n            prepend: Flag indicating handling of the left/top/etc.\n                boundary. If ``None``, there is no boundary extension.\n                Values of `0` or `1` indicate respectively that zeros or\n                the initial value in the array are prepended to the\n                difference array.\n            append: Flag indicating handling of the right/bottom/etc.\n                boundary. If ``None``, there is no boundary extension.\n                Values of `0` or `1` indicate respectively that zeros or\n                -1 times the final value in the array are appended to the\n                difference array.\n            circular: If ``True``, perform circular differences, i.e.,\n                include x[-1] - x[0]. 
If ``True``, `prepend` and `append`\n                must both be ``None``.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n\n        if not isinstance(axis, int):\n            raise TypeError(\n                f\"Expected argument 'axis' to be of type int, got {type(axis)} instead.\"\n            )\n\n        if axis < 0:\n            axis = len(input_shape) + axis\n        if axis >= len(input_shape):\n            raise ValueError(\n                f\"Invalid axis {axis} specified; axis must be less than \"\n                f\"len(input_shape)={len(input_shape)}.\"\n            )\n\n        self.axis = axis\n\n        if circular and (prepend is not None or append is not None):\n            raise ValueError(\n                \"Argument 'circular' must be False if either prepend or append is not None.\"\n            )\n        if prepend not in [None, 0, 1]:\n            raise ValueError(\"Argument 'prepend' may only take values None, 0, or 1.\")\n        if append not in [None, 0, 1]:\n            raise ValueError(\"Argument 'append' may only take values None, 0, or 1.\")\n\n        self.prepend = prepend\n        self.append = append\n        self.circular = circular\n\n        if self.circular:\n            output_shape = input_shape\n        else:\n            output_shape = tuple(\n                x + ((i == axis) * ((self.prepend is not None) + (self.append is not None) - 1))\n                for i, x in enumerate(input_shape)\n            )\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=input_dtype,\n            output_dtype=input_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        prepend = None\n        append = None\n        if self.circular:\n            # Append a copy of the initial value at the end of 
the array so that the difference\n            # array includes the difference across the right/bottom/etc. boundary.\n            ind = tuple(\n                slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape))\n            )\n            append = x[ind]\n        else:\n            if self.prepend == 0:\n                # Prepend a 0 to the difference array by prepending a copy of the initial value\n                # before the difference is computed.\n                ind = tuple(\n                    slice(0, 1) if i == self.axis else slice(None)\n                    for i in range(len(self.input_shape))\n                )\n                prepend = x[ind]\n            elif self.prepend == 1:\n                # Prepend a copy of the initial value to the difference array by prepending a 0\n                # before the difference is computed.\n                prepend = 0\n            if self.append == 0:\n                # Append a 0 to the difference array by appending a copy of the final value\n                # before the difference is computed.\n                ind = tuple(\n                    slice(-1, None) if i == self.axis else slice(None)\n                    for i in range(len(self.input_shape))\n                )\n                append = x[ind]\n            elif self.append == 1:\n                # Append -1 times the final value to the difference array by appending a 0\n                # before the difference is computed.\n                append = 0\n\n        return snp.diff(x, axis=self.axis, prepend=prepend, append=append)\n"
  },
  {
    "path": "scico/linop/_func.py",
    "content": "# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linear operators constructed from functions.\"\"\"\n\nfrom typing import Any, Callable, Optional, Sequence, Union\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico._core import linear_transpose\nfrom scico.numpy.util import indexed_shape, is_nested\nfrom scico.typing import ArrayIndex, BlockShape, DType, Shape\n\nfrom ._linop import LinearOperator\n\n__all__ = [\"operator_from_function\", \"Tranpose\", \"Sum\", \"Crop\", \"Pad\", \"Reshape\", \"Slice\"]\n\n\ndef linop_from_function(f: Callable, classname: str, f_name: Optional[str] = None):\n    \"\"\"Make a :class:`LinearOperator` from a function.\n\n    Example\n    -------\n    >>> Sum = linop_from_function(snp.sum, 'Sum')\n    >>> H = Sum((2, 10), axis=1)\n    >>> H @ snp.ones((2, 10))\n    Array([10., 10.], dtype=float32)\n\n    Args:\n        f: Function from which to create a :class:`LinearOperator`.\n        classname: Name of the resulting class.\n        f_name: Name of `f` for use in docstrings. Useful for getting\n            the correct version of wrapped functions. Defaults to\n            `f\"{f.__module__}.{f.__name__}\"`.\n    \"\"\"\n\n    if f_name is None:\n        f_name = f\"{f.__module__}.{f.__name__}\"\n\n    f_doc = rf\"\"\"\n\n        Args:\n            input_shape: Shape of input array.\n            args: Positional arguments passed to :func:`{f_name}`.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`. 
If the :class:`LinearOperator`\n                implements complex-valued operations, this must be a\n                complex dtype (typically :attr:`~numpy.complex64`) for\n                correct adjoint and gradient calculation.\n            output_shape: Shape of output array. Defaults to ``None``.\n                If ``None``, `output_shape` is determined by evaluating\n                `self.__call__` on an input array of zeros.\n            output_dtype: `dtype` for output argument. Defaults to\n                ``None``. If ``None``, `output_dtype` is determined by\n                evaluating `self.__call__` on an input array of zeros.\n            jit: If ``True``, call :meth:`~.LinearOperator.jit` on this\n                :class:`LinearOperator` to jit the forward, adjoint, and\n                gram functions. Same as calling\n                :meth:`~.LinearOperator.jit` after the\n                :class:`LinearOperator` is created.\n            kwargs: Keyword arguments passed to :func:`{f_name}`.\n        \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Union[Shape, BlockShape],\n        *args: Any,\n        input_dtype: DType = snp.float32,\n        output_shape: Optional[Union[Shape, BlockShape]] = None,\n        output_dtype: Optional[DType] = None,\n        jit: bool = True,\n        **kwargs: Any,\n    ):\n        self._eval = lambda x: f(x, *args, **kwargs)\n        self.kwargs = kwargs\n        super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit)  # type: ignore\n\n    OpClass = type(classname, (LinearOperator,), {\"__init__\": __init__})\n    __class__ = OpClass  # needed for super() to work\n\n    OpClass.__doc__ = f\"Linear operator version of :func:`{f_name}`.\"\n    OpClass.__init__.__doc__ = f_doc  # type: ignore\n\n    return OpClass\n\n\nTranspose = linop_from_function(snp.transpose, \"Transpose\", \"scico.numpy.transpose\")\nReshape = 
linop_from_function(snp.reshape, \"Reshape\")\nPad = linop_from_function(snp.pad, \"Pad\", \"scico.numpy.pad\")\nSum = linop_from_function(snp.sum, \"Sum\")\n\n\nclass Crop(LinearOperator):\n    \"\"\"A linear operator for cropping an array.\"\"\"\n\n    def __init__(\n        self,\n        crop_width: Union[int, Sequence],\n        input_shape: Shape,\n        input_dtype: DType = snp.float32,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            crop_width: Specify the crop width using the same format as\n                the `pad_width` parameter of :func:`snp.pad`.\n            input_shape: Shape of input :class:`jax.Array`.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n               functions of the :class:`LinearOperator`.\n        \"\"\"\n\n        self.crop_width = crop_width\n        # The crop function is defined as the adjoint of snp.pad\n        pad = lambda x: snp.pad(x, pad_width=crop_width)\n        # The output shape of this operator is the input shape of the corresponding\n        # pad operation of which it is the adjoint. 
Since we don't know this output\n        # shape, we assume that it can be computed by subtracting the difference in\n        # output and input shapes resulting from applying the pad operator to the\n        # input shape of this operator.\n        pad_shape = jax.eval_shape(pad, jax.ShapeDtypeStruct(input_shape, dtype=input_dtype)).shape\n        output_shape = tuple((2 * snp.array(input_shape) - snp.array(pad_shape)).tolist())\n        pad_adjoint = linear_transpose(pad, jax.ShapeDtypeStruct(output_shape, dtype=input_dtype))\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            eval_fn=lambda x: pad_adjoint(x)[0],\n            output_shape=output_shape,\n            output_dtype=input_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n\nclass Slice(LinearOperator):\n    \"\"\"A linear operator for slicing an array.\"\"\"\n\n    def __init__(\n        self,\n        idx: ArrayIndex,\n        input_shape: Union[Shape, BlockShape],\n        input_dtype: DType = snp.float32,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        This operator may be applied to either a :class:`jax.Array` or a\n        :class:`.BlockArray`. In the latter case, parameter `idx` must\n        conform to the\n        :ref:`BlockArray indexing requirements <blockarray_indexing>`.\n\n        Args:\n            idx: An array indexing expression, as generated by\n                :data:`numpy.s_`, for example.\n            input_shape: Shape of input :class:`jax.Array` or :class:`.BlockArray`.\n            input_dtype: `dtype` for input argument. 
Defaults to\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the :class:`LinearOperator`.\n        \"\"\"\n\n        output_shape: Union[Shape, BlockShape]\n        if is_nested(input_shape):\n            output_shape = input_shape[idx]  # type: ignore\n        else:\n            output_shape = indexed_shape(input_shape, idx)  # type: ignore\n\n        self.idx: ArrayIndex = idx\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=input_dtype,\n            output_dtype=input_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _eval(self, x: snp.Array) -> snp.Array:\n        return x[self.idx]\n"
  },
  {
    "path": "scico/linop/_grad.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Non-Cartesian gradient linear operators.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.typing import DType, Shape\n\nfrom ._linop import LinearOperator\n\n\ndef diffstack(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:\n    \"\"\"Compute the discrete difference along multiple axes.\n\n    Apply :func:`snp.diff` along multiple axes, stacking the results on\n    a newly inserted axis at index 0. The `append` parameter of\n    :func:`snp.diff` is exploited to give output of the same length as\n    the input, which is achieved by zero-padding the output at the end\n    of each axis.\n\n\n    \"\"\"\n    if axis is None:\n        axis = tuple(range(x.ndim))\n    elif isinstance(axis, int):\n        axis = (axis,)\n    dstack = [\n        snp.diff(\n            x,\n            axis=ax,\n            append=x[tuple(slice(-1, None) if i == ax else slice(None) for i in range(x.ndim))],\n        )\n        for ax in axis\n    ]\n    return snp.stack(dstack)\n\n\nclass ProjectedGradient(LinearOperator):\n    \"\"\"Gradient projected onto local coordinate system.\n\n    This class represents a linear operator that computes gradients of\n    arrays projected onto a local coordinate system that may differ at\n    every position in the array, as described in\n    :cite:`hossein-2024-total`. 
In the 2D illustration below :math:`x`\n    and :math:`y` represent the standard coordinate system defined by the\n    array axes, :math:`(g_x, g_y)` is the gradient vector within that\n    coordinate system, :math:`x'` and :math:`y'` are the local coordinate\n    axes, and :math:`(g_x', g_y')` is the gradient vector within the\n    local coordinate system.\n\n    .. image:: /figures/projgrad.svg\n         :align: center\n         :alt: Figure illustrating projection of gradient onto local\n               coordinate system.\n\n    Each of the local coordinate axes (e.g. :math:`x'` and :math:`y'` in\n    the illustration above) is represented by a separate array in the\n    `coord` tuple of arrays parameter of the class initializer.\n\n    .. note::\n\n       This operator should not be confused with the Projected Gradient\n       optimization algorithm (a special case of Proximal Gradient), with\n       which it is unrelated.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        axes: Optional[Tuple[int, ...]] = None,\n        coord: Optional[Sequence[Union[Array, BlockArray]]] = None,\n        cdiff: bool = False,\n        input_dtype: DType = np.float32,\n        jit: bool = True,\n    ):\n        r\"\"\"\n\n        The result of applying the operator is always a\n        :class:`jax.Array`. If `coord` is a singleton tuple, it has the\n        same shape as the input array. 
Otherwise, the gradients for each\n        of the local coordinate axes are stacked on an additional axis at\n        index 0.\n\n        If `coord` is ``None``, which is the default, gradients are\n        computed in the standard axis-aligned coordinate system, and the\n        shape of the returned array depends on the number of axes on\n        which the gradient is calculated, as specified explicitly or\n        implicitly via the `axes` parameter.\n\n        Args:\n            input_shape: Shape of input array.\n            axes: Axes over which to compute the gradient. Defaults to\n                ``None``, in which case the gradient is computed along\n                all axes.\n            coord: A tuple of arrays, each of which specifies a local\n                coordinate axis direction. Each member of the tuple\n                should either be a :class:`jax.Array` or a\n                :class:`.BlockArray`. If it is the former, it should have\n                shape :math:`N \\times M_0 \\times M_1 \\times \\ldots`,\n                where :math:`N` is the number of axes specified by\n                parameter `axes`, and :math:`M_i` is the size of the\n                :math:`i^{\\mrm{th}}` axis. If it is the latter, it should\n                consist of :math:`N` blocks, each of which has a shape\n                that is suitable for multiplication with an array of\n                shape :math:`M_0 \\times M_1 \\times \\ldots`.\n            cdiff: If ``True``, estimate gradients using the second order\n                central difference returned by :func:`snp.gradient`,\n                otherwise use the first order asymmetric difference\n                returned by :func:`snp.diff`.\n            input_dtype: `dtype` for input argument. 
Default is\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the LinearOperator.\n        \"\"\"\n        if axes is None:\n            # If axes is None, set it to all axes in input shape.\n            self.axes = tuple(range(len(input_shape)))\n        else:\n            # Ensure no invalid axis indices specified.\n            if snp.any(np.array(axes) >= len(input_shape)):\n                raise ValueError(\n                    \"Invalid axes specified; all elements of argument 'axes' must \"\n                    f\"be less than len(input_shape)={len(input_shape)}.\"\n                )\n            self.axes = axes\n        output_shape: Shape\n        if coord is None:\n            # If coord is None, output shape is determined by number of axes.\n            if len(self.axes) == 1:\n                output_shape = input_shape\n            else:\n                output_shape = (len(self.axes),) + input_shape\n        else:\n            # If coord is not None, output shape is determined by number of coord arrays.\n            if len(coord) == 1:\n                output_shape = input_shape\n            else:\n                output_shape = (len(coord),) + input_shape\n        self.coord = coord\n        self.cdiff = cdiff\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=input_dtype,\n            output_dtype=input_dtype,\n            jit=jit,\n        )\n\n    def _eval(self, x: Array) -> Union[Array, BlockArray]:\n\n        if self.cdiff:\n            grad = snp.stack(snp.gradient(x, axis=self.axes))\n        else:\n            grad = diffstack(x, axis=self.axes)\n        if self.coord is None:\n            # If coord attribute is None, just return gradients on specified axes.\n            if len(self.axes) == 1:\n                return grad[0]\n            else:\n                return grad\n   
     else:\n            # If coord attribute is not None, return gradients projected onto specified local\n            # coordinate systems.\n            projgrad = [sum([c[m] * grad[m] for m in range(len(self.axes))]) for c in self.coord]\n            if len(self.coord) == 1:\n                return projgrad[0]\n            else:\n                return snp.stack(projgrad)\n\n\nclass PolarGradient(ProjectedGradient):\n    \"\"\"Gradient projected into polar coordinates.\n\n    Compute gradients projected onto angular and/or radial axis\n    directions, as described in :cite:`hossein-2024-total`. Local\n    coordinate axes are illustrated in the figure below.\n\n    .. plot:: pyfigures/polargrad.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n\n    |\n\n    If only one of `angular` and `radial` is ``True``, the operator\n    output has the same shape as the input, otherwise the gradients for\n    the two local coordinate axes are stacked on an additional axis at\n    index 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        axes: Optional[Tuple[int, ...]] = None,\n        center: Optional[Union[Tuple[int, ...], Array]] = None,\n        angular: bool = True,\n        radial: bool = True,\n        cdiff: bool = False,\n        input_dtype: DType = np.float32,\n        jit: bool = True,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            axes: Axes over which to compute the gradient. Should be a\n                tuple :math:`(i_x, i_y)`, where :math:`i_x` and\n                :math:`i_y` are input array axes assigned to :math:`x`\n                and :math:`y` coordinates respectively. Defaults to\n                ``None``, in which case the axes are taken to be `(0, 1)`.\n            center: Center of the polar coordinate system in array\n                indexing coordinates. 
Default is ``None``, which places\n                the center at the center of the input array.\n            angular: Flag indicating whether to compute gradients in the\n                angular (i.e. tangent to circles) direction.\n            radial: Flag indicating whether to compute gradients in the\n                radial (i.e. directed outwards from the origin) direction.\n            cdiff: If ``True``, estimate gradients using the second order\n                central different returned by :func:`snp.gradient`,\n                otherwise use the first order asymmetric difference\n                returned by :func:`snp.diff`.\n            input_dtype: `dtype` for input argument. Default is\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the LinearOperator.\n        \"\"\"\n\n        if len(input_shape) < 2:\n            raise ValueError(\"Invalid input shape; input must have at least two axes.\")\n        if axes is not None and len(axes) != 2:\n            raise ValueError(\"Invalid axes specified; exactly two axes must be specified.\")\n        if not angular and not radial:\n            raise ValueError(\"At least one of angular and radial must be True.\")\n\n        real_input_dtype = snp.util.real_dtype(input_dtype)\n        if axes is None:\n            axes = (0, 1)\n        axes_shape = [input_shape[ax] for ax in axes]\n        if center is None:\n            center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2\n        else:\n            center = snp.array(center, dtype=real_input_dtype)\n        end = snp.array(axes_shape, dtype=real_input_dtype) - center\n        g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]]\n        theta = snp.arctan2(g0, g1)\n        # Re-order theta axes in case indices in axes parameter are not in increasing order.\n        axis_order = np.argsort(axes)\n        theta = snp.transpose(theta, axis_order)\n      
  if len(input_shape) > 2:\n            # Construct list of input axes that are not included in the gradient axes.\n            single = tuple(set(range(len(input_shape))) - set(axes))\n            # Insert singleton axes to align theta for multiplication with gradients.\n            theta = snp.expand_dims(theta, single)\n        coord = []\n        if angular:\n            coord.append(snp.blockarray([-snp.cos(theta), snp.sin(theta)]))\n        if radial:\n            coord.append(snp.blockarray([snp.sin(theta), snp.cos(theta)]))\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            axes=axes,\n            coord=coord,\n            cdiff=cdiff,\n            jit=jit,\n        )\n\n\nclass CylindricalGradient(ProjectedGradient):\n    \"\"\"Gradient projected into cylindrical coordinates.\n\n    Compute gradients projected onto cylindrical coordinate axes, as\n    described in :cite:`hossein-2024-total`. The local coordinate axes\n    are illustrated in the figure below.\n\n    .. plot:: pyfigures/cylindgrad.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n\n    |\n\n    If only one of `angular`, `radial`, and `axial` is ``True``, the\n    operator output has the same shape as the input, otherwise the\n    gradients for the selected local coordinate axes are stacked on an\n    additional axis at index 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        axes: Optional[Tuple[int, ...]] = None,\n        center: Optional[Union[Tuple[int, ...], Array]] = None,\n        angular: bool = True,\n        radial: bool = True,\n        axial: bool = True,\n        cdiff: bool = False,\n        input_dtype: DType = np.float32,\n        jit: bool = True,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            axes: Axes over which to compute the gradient. 
Should be a\n                tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`,\n                :math:`i_y` and :math:`i_z` are input array axes assigned\n                to :math:`x`, :math:`y`, and :math:`z` coordinates\n                respectively. Defaults to ``None``, in which case the\n                axes are taken to be `(0, 1, 2)`. Exactly three axes\n                must be specified.\n            center: Center of the cylindrical coordinate system in array\n                indexing coordinates. Default is ``None``, which places\n                the center at the center of the two polar axes of the\n                input array and at the zero index of the axial axis.\n            angular: Flag indicating whether to compute gradients in the\n                angular (i.e. tangent to circles) direction.\n            radial: Flag indicating whether to compute gradients in the\n                radial (i.e. directed outwards from the origin) direction.\n            axial: Flag indicating whether to compute gradients in the\n                direction of the axis of the cylinder.\n            cdiff: If ``True``, estimate gradients using the second order\n                central difference returned by :func:`snp.gradient`,\n                otherwise use the first order asymmetric difference\n                returned by :func:`snp.diff`.\n            input_dtype: `dtype` for input argument. 
Default is\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the LinearOperator.\n        \"\"\"\n\n        if len(input_shape) < 3:\n            raise ValueError(\"Invalid input shape; input must have at least three axes.\")\n        if axes is not None and len(axes) != 3:\n            raise ValueError(\"Invalid axes specified; exactly three axes must be specified.\")\n        if not angular and not radial and not axial:\n            raise ValueError(\"At least one of angular, radial, and axial must be True.\")\n\n        real_input_dtype = snp.util.real_dtype(input_dtype)\n        if axes is None:\n            axes = (0, 1, 2)\n        axes_shape = [input_shape[ax] for ax in axes]\n        if center is None:\n            center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2\n            center = center.at[-1].set(0)  # type: ignore\n        else:\n            center = snp.array(center, dtype=real_input_dtype)\n        end = snp.array(axes_shape, dtype=real_input_dtype) - center\n        g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]]\n        g0 = g0[..., np.newaxis]\n        g1 = g1[..., np.newaxis]\n        theta = snp.arctan2(g0, g1)\n        # Re-order theta axes in case indices in axes parameter are not in increasing order.\n        axis_order = np.argsort(axes)\n        theta = snp.transpose(theta, axis_order)\n        if len(input_shape) > 3:\n            # Construct list of input axes that are not included in the gradient axes.\n            single = tuple(set(range(len(input_shape))) - set(axes))\n            # Insert singleton axes to align theta for multiplication with gradients.\n            theta = snp.expand_dims(theta, single)\n        coord = []\n        if angular:\n            coord.append(\n                snp.blockarray(\n                    [-snp.cos(theta), snp.sin(theta), snp.array([0.0], dtype=real_input_dtype)]\n                
)\n            )\n        if radial:\n            coord.append(\n                snp.blockarray(\n                    [snp.sin(theta), snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)]\n                )\n            )\n        if axial:\n            coord.append(\n                snp.blockarray(\n                    [\n                        snp.array([0.0], dtype=real_input_dtype),\n                        snp.array([0.0], dtype=real_input_dtype),\n                        snp.array([1.0], dtype=real_input_dtype),\n                    ]\n                )\n            )\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            axes=axes,\n            cdiff=cdiff,\n            coord=coord,\n            jit=jit,\n        )\n\n\nclass SphericalGradient(ProjectedGradient):\n    \"\"\"Gradient projected into spherical coordinates.\n\n    Compute gradients projected onto spherical coordinate axes, based on\n    the approach described in :cite:`hossein-2024-total`. The local\n    coordinate axes are illustrated in the figure below.\n\n    .. 
plot:: pyfigures/spheregrad.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n\n    |\n\n    If only one of `azimuthal`, `polar`, and `radial` is ``True``, the\n    operator output has the same shape as the input, otherwise the\n    gradients for the selected local coordinate axes are stacked on an\n    additional axis at index 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        axes: Optional[Tuple[int, ...]] = None,\n        center: Optional[Union[Tuple[int, ...], Array]] = None,\n        azimuthal: bool = True,\n        polar: bool = True,\n        radial: bool = True,\n        cdiff: bool = False,\n        input_dtype: DType = np.float32,\n        jit: bool = True,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            axes: Axes over which to compute the gradient. Should be a\n                tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`,\n                :math:`i_y` and :math:`i_z` are input array axes assigned\n                to :math:`x`, :math:`y`, and :math:`z` coordinates\n                respectively. Defaults to ``None``, in which case the\n                axes are taken to be `(0, 1, 2)`. Exactly three axes\n                must be specified.\n            center: Center of the spherical coordinate system in array\n                indexing coordinates. 
Default is ``None``, which places\n                the center at the center of the input array.\n            azimuthal: Flag indicating whether to compute gradients in\n                the azimuthal direction.\n            polar: Flag indicating whether to compute gradients in the\n                polar direction.\n            radial: Flag indicating whether to compute gradients in the\n                radial direction.\n            cdiff: If ``True``, estimate gradients using the second order\n                central different returned by :func:`snp.gradient`,\n                otherwise use the first order asymmetric difference\n                returned by :func:`snp.diff`.\n            input_dtype: `dtype` for input argument. Default is\n                :attr:`~numpy.float32`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n                functions of the LinearOperator.\n        \"\"\"\n\n        if len(input_shape) < 3:\n            raise ValueError(\"Invalid input shape; input must have at least three axes.\")\n        if axes is not None and len(axes) != 3:\n            raise ValueError(\"Invalid axes specified; exactly three axes must be specified.\")\n        if not azimuthal and not polar and not radial:\n            raise ValueError(\"At least one of azimuthal, polar, and radial must be True.\")\n\n        real_input_dtype = snp.util.real_dtype(input_dtype)\n        if axes is None:\n            axes = (0, 1, 2)\n        axes_shape = [input_shape[ax] for ax in axes]\n        if center is None:\n            center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2\n        else:\n            center = snp.array(center, dtype=real_input_dtype)\n        end = snp.array(axes_shape, dtype=real_input_dtype) - center\n        g0, g1, g2 = snp.ogrid[-center[0] : end[0], -center[1] : end[1], -center[2] : end[2]]\n        theta = snp.arctan2(g1, g0)\n        phi = snp.arctan2(snp.sqrt(g0**2 + g1**2), g2)\n        # Re-order theta and 
phi axes in case indices in axes parameter are not in\n        # increasing order.\n        axis_order = np.argsort(axes)\n        theta = snp.transpose(theta, axis_order)\n        phi = snp.transpose(phi, axis_order)\n        if len(input_shape) > 3:\n            # Construct list of input axes that are not included in the gradient axes.\n            single = tuple(set(range(len(input_shape))) - set(axes))\n            # Insert singleton axes to align theta for multiplication with gradients.\n            theta = snp.expand_dims(theta, single)\n            phi = snp.expand_dims(phi, single)\n        coord = []\n        if azimuthal:\n            coord.append(\n                snp.blockarray(\n                    [snp.sin(theta), -snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)]\n                )\n            )\n        if polar:\n            coord.append(\n                snp.blockarray(\n                    [snp.cos(phi) * snp.cos(theta), snp.cos(phi) * snp.sin(theta), -snp.sin(phi)]\n                )\n            )\n        if radial:\n            coord.append(\n                snp.blockarray(\n                    [snp.sin(phi) * snp.cos(theta), snp.sin(phi) * snp.sin(theta), snp.cos(phi)]\n                )\n            )\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n            axes=axes,\n            coord=coord,\n            cdiff=cdiff,\n            jit=jit,\n        )\n"
  },
  {
    "path": "scico/linop/_linop.py",
    "content": "# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linear operator base class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom functools import wraps\nfrom typing import Callable, Optional, Union\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.dtypes import result_type\n\nimport scico.numpy as snp\nfrom scico._core import linear_adjoint\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import is_complex_dtype\nfrom scico.operator._operator import Operator, _wrap_mul_div_scalar\nfrom scico.typing import BlockShape, DType, Shape\n\n\ndef _wrap_add_sub(func: Callable) -> Callable:\n    r\"\"\"Wrapper function for defining `__add__` and `__sub__`.\n\n    Wrapper function for defining `__add__` and ` __sub__` between\n    :class:`LinearOperator` and derived classes. Operations\n    between :class:`LinearOperator` and :class:`.Operator`\n    types are also supported.\n\n    Handles shape checking and function dispatch based on types of\n    operands `a` and `b` in the call `func(a, b)`. Note that `func`\n    will always be a method of the type of `a`, and since this wrapper\n    should only be applied within :class:`LinearOperator` or derived\n    classes, we can assume that `a` is always an instance of\n    :class:`LinearOperator`. The general rule for dispatch is that the\n    `__add__` or `__sub__` operator of the nearest common base class\n    of `a` and `b` should be called. If `b` is derived from `a`, this\n    entails using the operator defined in the class of `a`, and\n    vice-versa. 
If one of the operands is not a descendant of the other\n    in the class hierarchy, then it is assumed that their common base\n    class is either :class:`.Operator` or :class:`LinearOperator`,\n    depending on the type of `b`.\n\n    - If `b` is not an instance of :class:`.Operator`, a :exc:`TypeError`\n      is raised.\n    - If the shapes of `a` and `b` do not match, a :exc:`ValueError` is\n      raised.\n    - If `b` is an instance of the type of `a` then `func(a, b)` is\n      called where `func` is the argument of this wrapper, i.e.\n      the unwrapped function defined in the class of `a`.\n    - If `a` is an instance of the type of `b` then `func(a, b)` is\n      called where `func` is the unwrapped function defined in the class\n      of `b`.\n    - If `b` is a :class:`LinearOperator` then `func(a, b)` is called\n      where `func` is the operator defined in :class:`LinearOperator`.\n    - Otherwise, `func(a, b)` is called where `func` is the operator\n      defined in :class:`.Operator`.\n\n    Args:\n        func: should be either `.__add__` or `.__sub__`.\n\n    Returns:\n       Wrapped version of `func`.\n\n    Raises:\n        ValueError: If the shapes of two operators do not match.\n        TypeError: If one of the two operands is not an\n            :class:`.Operator` or :class:`LinearOperator`.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(\n        a: LinearOperator, b: Union[Operator, LinearOperator]\n    ) -> Union[Operator, LinearOperator]:\n        if isinstance(b, Operator):\n            if a.shape == b.shape:\n                if isinstance(b, type(a)):\n                    # b is an instance of the class of a: call the unwrapped operator\n                    # defined in the class of a, which is the func argument of this\n                    # wrapper\n                    return func(a, b)\n                if isinstance(a, type(b)):\n                    # a is an instance of class b: call the unwrapped operator\n                    # 
defined in the class of b. A test is required because\n                    # the operators defined in Operator and non-LinearOperator\n                    # derived classes are not wrapped.\n                    if hasattr(getattr(type(b), func.__name__), \"_unwrapped\"):\n                        uwfunc = getattr(type(b), func.__name__)._unwrapped\n                    else:\n                        uwfunc = getattr(type(b), func.__name__)\n                    return uwfunc(a, b)\n                # The most general approach here would be to automatically determine\n                # the nearest common ancestor of the classes of a and b (e.g. as\n                # discussed in https://stackoverflow.com/a/58290475 ), but the\n                # simpler approach adopted here is to just assume that the common\n                # base of two classes that do not have an ancestor-descendant\n                # relationship is either Operator or LinearOperator.\n                if isinstance(b, LinearOperator):\n                    # LinearOperator + LinearOperator -> LinearOperator\n                    uwfunc = getattr(LinearOperator, func.__name__)._unwrapped\n                    return uwfunc(a, b)\n                # LinearOperator + Operator -> Operator (access to the function\n                # definition differs from that for LinearOperator because\n                # Operator __add__ and __sub__ are not wrapped)\n                uwfunc = getattr(Operator, func.__name__)\n                return uwfunc(a, b)\n            raise ValueError(f\"Shapes {a.shape} and {b.shape} do not match.\")\n        raise TypeError(f\"Operation {func.__name__} not defined between {type(a)} and {type(b)}.\")\n\n    wrapper._unwrapped = func  # type: ignore\n\n    return wrapper\n\n\nclass LinearOperator(Operator):\n    \"\"\"Generic linear operator base class\"\"\"\n\n    def __init__(\n        self,\n        input_shape: Union[Shape, BlockShape],\n        output_shape: Optional[Union[Shape, 
BlockShape]] = None,\n        eval_fn: Optional[Callable] = None,\n        adj_fn: Optional[Callable] = None,\n        input_dtype: DType = np.float32,\n        output_dtype: Optional[DType] = None,\n        jit: bool = False,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            output_shape: Shape of output array. Defaults to ``None``.\n                If ``None``, `output_shape` is determined by evaluating\n                `self.__call__` on an input array of zeros.\n            eval_fn: Function used in evaluating this\n                :class:`LinearOperator`. Defaults to ``None``. If\n                ``None``, then `self.__call__` must be defined in any\n                derived classes.\n            adj_fn: Function used to evaluate the adjoint of this\n                :class:`LinearOperator`. Defaults to ``None``. If\n                ``None``, the adjoint is not set, and the\n                :meth:`._set_adjoint` will be called silently at the\n                first :meth:`.adj` call or can be called manually.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`. If the :class:`.LinearOperator`\n                implements complex-valued operations, this must be a\n                complex dtype (typically :attr:`~numpy.complex64`) for\n                correct adjoint and gradient calculation.\n            output_dtype: `dtype` for output argument. Defaults to\n                ``None``. If ``None``, `output_dtype` is determined by\n                evaluating `self.__call__` on an input array of zeros.\n            jit: If ``True``, call :meth:`.jit()` on this\n                :class:`LinearOperator` to jit the forward, adjoint, and\n                gram functions. 
Same as calling :meth:`.jit` after the\n                :class:`LinearOperator` is created.\n        \"\"\"\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            eval_fn=eval_fn,\n            input_dtype=input_dtype,\n            output_dtype=output_dtype,\n            jit=False,\n        )\n\n        if not hasattr(self, \"_adj\"):\n            self._adj: Optional[Callable] = None\n        if not hasattr(self, \"_gram\"):\n            self._gram: Optional[Callable] = None\n        if callable(adj_fn):\n            self._adj = adj_fn\n            self._gram = lambda x: self.adj(self(x))\n        elif adj_fn is not None:\n            raise TypeError(f\"Argument 'adj_fn' must be either a Callable or None; got {adj_fn}.\")\n\n        if jit:\n            self.jit()\n\n    def _set_adjoint(self):\n        \"\"\"Automatically create adjoint method.\"\"\"\n        adj_fun = linear_adjoint(\n            self._eval, jax.ShapeDtypeStruct(self.input_shape, dtype=self.input_dtype)\n        )\n        self._adj = lambda x: adj_fun(x)[0]\n\n    def _set_gram(self):\n        \"\"\"Automatically create gram method.\"\"\"\n        self._gram = lambda x: self.adj(self(x))\n\n    def jit(self):\n        \"\"\"Replace the private functions :meth:`._eval`, :meth:`_adj`, :meth:`._gram`\n        with jitted versions.\n        \"\"\"\n        if self._adj is None:\n            self._set_adjoint()\n\n        if self._gram is None:\n            self._set_gram()\n\n        self._eval = jax.jit(self._eval)\n        self._adj = jax.jit(self._adj)\n        self._gram = jax.jit(self._gram)\n\n    @_wrap_add_sub\n    def __add__(self, other):\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(x) + other(x),\n            adj_fn=lambda x: self.adj(x) + other.adj(x),\n            input_dtype=self.input_dtype,\n            
output_dtype=result_type(self.output_dtype, other.output_dtype),\n        )\n\n    @_wrap_add_sub\n    def __sub__(self, other):\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(x) - other(x),\n            adj_fn=lambda x: self.adj(x) - other.adj(x),\n            input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other.output_dtype),\n        )\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, other):\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: other * self(x),\n            adj_fn=lambda x: snp.conj(other) * self.adj(x),\n            input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other),\n        )\n\n    @_wrap_mul_div_scalar\n    def __rmul__(self, other):\n        return self.__mul__(other)  # scalar multiplication is commutative\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, other):\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(x) / other,\n            adj_fn=lambda x: self.adj(x) / snp.conj(other),\n            input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other),\n        )\n\n    def __matmul__(self, other):\n        # self @ other\n        return self(other)\n\n    def __rmatmul__(self, other):\n        # other @ self\n        if isinstance(other, LinearOperator):\n            return other(self)\n\n        if isinstance(other, (np.ndarray, jnp.ndarray)):\n            # for real valued inputs: y @ self == (self.T @ y.T).T\n            # for complex:  y @ self == (self.conj().T @ y.conj().T).conj().T\n            # self.conj().T == self.adj\n            return self.adj(other.conj().T).conj().T\n\n        raise 
NotImplementedError(\n            f\"Operation __rmatmul__ not defined between {type(self)} and {type(other)}.\"\n        )\n\n    def __call__(\n        self, x: Union[LinearOperator, Array, BlockArray]\n    ) -> Union[LinearOperator, Array, BlockArray]:\n        r\"\"\"Evaluate this :class:`LinearOperator` at the point :math:`\\mb{x}`.\n\n        Args:\n            x: Point at which to evaluate this :class:`LinearOperator`.\n               If `x` is a :class:`jax.Array` or :class:`.BlockArray`,\n               must have `shape == self.input_shape`. If `x` is a\n               :class:`LinearOperator`, must have\n               `x.output_shape == self.input_shape`.\n        \"\"\"\n        if isinstance(x, LinearOperator):\n            return ComposedLinearOperator(self, x)\n        # Use Operator __call__ for LinearOperator @ array or LinearOperator @ Operator\n        return super().__call__(x)\n\n    def adj(\n        self, y: Union[LinearOperator, Array, BlockArray]\n    ) -> Union[LinearOperator, Array, BlockArray]:\n        \"\"\"Adjoint of this :class:`LinearOperator`.\n\n        Compute the adjoint of this :class:`LinearOperator` applied to\n        input `y`.\n\n        Args:\n            y: Point at which to compute adjoint. If `y` is\n                :class:`jax.Array` or :class:`.BlockArray`, must have\n                `shape == self.output_shape`. 
If `y` is a\n                :class:`LinearOperator`, must have\n                `y.output_shape == self.output_shape`.\n\n        Returns:\n            Adjoint evaluated at `y`.\n        \"\"\"\n        if self._adj is None:\n            self._set_adjoint()\n\n        if isinstance(y, LinearOperator):\n            return ComposedLinearOperator(self.H, y)\n        if self.output_dtype != y.dtype:\n            raise ValueError(f\"Dtype error: expected {self.output_dtype}, got {y.dtype}.\")\n        if self.output_shape != y.shape:\n            raise ValueError(\n                f\"\"\"Shapes do not conform: input array with shape {y.shape} does not match\n                LinearOperator output_shape {self.output_shape}.\"\"\"\n            )\n        assert self._adj is not None\n        return self._adj(y)\n\n    @property\n    def T(self) -> LinearOperator:\n        \"\"\"Transpose of this :class:`LinearOperator`.\n\n        Return a new :class:`LinearOperator` that implements the\n        transpose of this :class:`LinearOperator`. For a real-valued\n        :class:`LinearOperator` `A` (`A.input_dtype` is\n        :attr:`~numpy.float32` or :attr:`~numpy.float64`), the\n        :class:`LinearOperator` `A.T` implements the adjoint:\n        `A.T(y) == A.adj(y)`. For a complex-valued :class:`LinearOperator`\n        `A` (`A.input_dtype` is :attr:`~numpy.complex64` or\n        :attr:`~numpy.complex128`), the :class:`LinearOperator` `A.T` is\n        not the adjoint. 
For the conjugate transpose, use `.conj().T` or\n        :meth:`.H`.\n        \"\"\"\n        if is_complex_dtype(self.input_dtype):\n            return LinearOperator(\n                input_shape=self.output_shape,\n                output_shape=self.input_shape,\n                eval_fn=lambda x: self.adj(x.conj()).conj(),\n                adj_fn=self.__call__,\n                input_dtype=self.input_dtype,\n                output_dtype=self.output_dtype,\n            )\n        return LinearOperator(\n            input_shape=self.output_shape,\n            output_shape=self.input_shape,\n            eval_fn=self.adj,\n            adj_fn=self.__call__,\n            input_dtype=self.output_dtype,\n            output_dtype=self.input_dtype,\n        )\n\n    @property\n    def H(self) -> LinearOperator:\n        \"\"\"Hermitian transpose of this :class:`LinearOperator`.\n\n        Return a new :class:`LinearOperator` that is the Hermitian\n        transpose of this :class:`LinearOperator`. For a real-valued\n        :class:`LinearOperator` `A` (`A.input_dtype` is\n        :attr:`~numpy.float32` or :attr:`~numpy.float64`), the\n        :class:`LinearOperator` `A.H` is equivalent to `A.T`. 
For a\n        complex-valued :class:`LinearOperator` `A` (`A.input_dtype` is\n        :attr:`~numpy.complex64` or :attr:`~numpy.complex128`), the\n        :class:`LinearOperator` `A.H` implements the adjoint of `A`:\n        `A.H @ y == A.adj(y) == A.conj().T @ y`.\n\n        For the non-conjugate transpose, see :meth:`.T`.\n        \"\"\"\n        return LinearOperator(\n            input_shape=self.output_shape,\n            output_shape=self.input_shape,\n            eval_fn=self.adj,\n            adj_fn=self.__call__,\n            input_dtype=self.output_dtype,\n            output_dtype=self.input_dtype,\n        )\n\n    def conj(self) -> LinearOperator:\n        \"\"\"Complex conjugate of this :class:`LinearOperator`.\n\n        Return a new :class:`LinearOperator` `Ac` such that\n        `Ac(x) = conj(A)(x)`.\n        \"\"\"\n        # A.conj() x == (A @ x.conj()).conj()\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(x.conj()).conj(),\n            adj_fn=lambda x: self.adj(x.conj()).conj(),\n            input_dtype=self.input_dtype,\n            output_dtype=self.output_dtype,\n        )\n\n    @property\n    def gram_op(self) -> LinearOperator:\n        \"\"\"Gram operator of this :class:`LinearOperator`.\n\n        Return a new :class:`LinearOperator` `G` such that\n        `G(x) = A.adj(A(x))`.\n        \"\"\"\n        if self._gram is None:\n            self._set_gram()\n\n        return LinearOperator(\n            input_shape=self.input_shape,\n            output_shape=self.input_shape,\n            eval_fn=self.gram,\n            adj_fn=self.gram,\n            input_dtype=self.input_dtype,\n            output_dtype=self.output_dtype,\n        )\n\n    def gram(\n        self, x: Union[LinearOperator, Array, BlockArray]\n    ) -> Union[LinearOperator, Array, BlockArray]:\n        \"\"\"Compute `A.adj(A(x))`.\n\n        Args:\n            x: 
Point at which to evaluate the gram operator. If `x` is\n               a :class:`jax.Array` or :class:`.BlockArray`, must have\n               `shape == self.input_shape`. If `x` is a\n               :class:`LinearOperator`, must have\n               `x.output_shape == self.input_shape`.\n\n        Returns:\n            Result of `A.adj(A(x))`.\n        \"\"\"\n        if self._gram is None:\n            self._set_gram()\n        assert self._gram is not None\n        return self._gram(x)\n\n\nclass ComposedLinearOperator(LinearOperator):\n    \"\"\"A composition of two :class:`LinearOperator` objects.\n\n    A new :class:`LinearOperator` formed by the composition of two other\n    :class:`LinearOperator` objects.\n    \"\"\"\n\n    def __init__(self, A: LinearOperator, B: LinearOperator, jit: bool = False):\n        r\"\"\"\n        A :class:`ComposedLinearOperator` `AB` implements\n        `AB @ x == A @ B @ x`. :class:`LinearOperator` `A` and `B` are\n        stored as attributes of the :class:`ComposedLinearOperator`.\n\n        :class:`LinearOperator` `A` and `B` must have compatible shapes\n        and dtypes: `A.input_shape == B.output_shape` and\n        `A.input_dtype == B.output_dtype`.\n\n        Args:\n            A: First (left) :class:`LinearOperator`.\n            B: Second (right) :class:`LinearOperator`.\n            jit: If ``True``, call :meth:`~.LinearOperator.jit()` on this\n                :class:`LinearOperator` to jit the forward, adjoint, and\n                gram functions. 
Same as calling\n                :meth:`~.LinearOperator.jit` after the\n                :class:`LinearOperator` is created.\n        \"\"\"\n        if not isinstance(A, LinearOperator):\n            raise TypeError(\n                \"The first argument to ComposedLinearOperator must be a LinearOperator; \"\n                f\"got {type(A)}.\"\n            )\n        if not isinstance(B, LinearOperator):\n            raise TypeError(\n                \"The second argument to ComposedLinearOperator must be a LinearOperator; \"\n                f\"got {type(B)}.\"\n            )\n        if A.input_shape != B.output_shape:\n            raise ValueError(f\"Incompatable LinearOperator shapes {A.shape}, {B.shape}.\")\n        if A.input_dtype != B.output_dtype:\n            raise ValueError(\n                f\"Incompatable LinearOperator dtypes {A.input_dtype}, {B.output_dtype}.\"\n            )\n\n        self.A = A\n        self.B = B\n\n        super().__init__(\n            input_shape=self.B.input_shape,\n            output_shape=self.A.output_shape,\n            input_dtype=self.B.input_dtype,\n            output_dtype=self.A.output_dtype,\n            eval_fn=lambda x: self.A(self.B(x)),\n            adj_fn=lambda z: self.B.adj(self.A.adj(z)),\n            jit=jit,\n        )\n"
  },
  {
    "path": "scico/linop/_matrix.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Matrix linear operator classes.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nimport operator\nfrom functools import partial, wraps\n\nimport numpy as np\n\nimport jax.numpy as jnp\nfrom jax.typing import ArrayLike\n\nimport scico.numpy as snp\nfrom scico.operator._operator import Operator\n\nfrom ._diag import Identity\nfrom ._linop import LinearOperator\n\n\ndef _wrap_add_sub_matrix(func, op):\n    @wraps(func)\n    def wrapper(a, b):\n        if np.isscalar(b):\n            return MatrixOperator(op(a.A, b))\n\n        if isinstance(b, MatrixOperator):\n            if a.shape == b.shape:\n                return MatrixOperator(op(a.A, b.A))\n\n            raise ValueError(f\"MatrixOperator shapes {a.shape} and {b.shape} do not match.\")\n\n        if isinstance(b, (jnp.ndarray, np.ndarray)):\n            if a.matrix_shape == b.shape:\n                return MatrixOperator(op(a.A, b))\n\n            raise ValueError(f\"Shapes {a.matrix_shape} and {b.shape} do not match.\")\n\n        if isinstance(b, Operator):\n            if a.shape != b.shape:\n                raise ValueError(f\"Shapes {a.shape} and {b.shape} do not match.\")\n\n        if isinstance(b, LinearOperator):\n            uwfunc = getattr(LinearOperator, func.__name__)._unwrapped\n            return uwfunc(a, b)\n\n        if isinstance(b, Operator):\n            uwfunc = getattr(Operator, func.__name__)\n            return uwfunc(a, b)\n\n        raise TypeError(f\"Operation {func.__name__} not defined between {type(a)} and {type(b)}.\")\n\n    return wrapper\n\n\nclass 
MatrixOperator(LinearOperator):\n    \"\"\"Linear operator implementing matrix multiplication.\"\"\"\n\n    def __init__(self, A: ArrayLike, input_cols: int = 0):\n        \"\"\"\n        Args:\n            A: Dense array. The action of the created\n                :class:`.LinearOperator` will\n                implement matrix multiplication with `A`.\n            input_cols: If this parameter is set to the default of 0, the\n                :class:`MatrixOperator` takes a vector (one-dimensional\n                array) input. If the input is intended to be a matrix\n                (two-dimensional array), this parameter should specify\n                number of columns in the matrix.\n        \"\"\"\n        self.A: snp.Array  #: Dense array implementing this matrix\n\n        # Ensure that A is a numpy or jax array.\n        if not snp.util.is_arraylike(A):\n            raise TypeError(f\"Expected numpy or jax array, got {type(A)}.\")\n        self.A = A\n\n        # Can only do rank-2 arrays\n        if A.ndim != 2:\n            raise TypeError(f\"Expected a two-dimensional array, got array of shape {A.shape}.\")\n\n        self.__array__ = A.__array__  # enables jnp.array(H)\n\n        if input_cols == 0:\n            input_shape = A.shape[1]\n            output_shape = A.shape[0]\n        else:\n            input_shape = (A.shape[1], input_cols)\n            output_shape = (A.shape[0], input_cols)\n\n        super().__init__(\n            input_shape=input_shape, output_shape=output_shape, input_dtype=self.A.dtype\n        )\n\n    def __call__(self, other):\n        if isinstance(other, LinearOperator):\n            if self.input_shape == other.output_shape:\n                if isinstance(other, Identity):\n                    return self\n\n                if isinstance(other, MatrixOperator):\n                    return MatrixOperator(A=self.A @ other.A)\n\n                # must be a generic linop so return composition of the two\n                return 
LinearOperator(\n                    input_shape=other.input_shape,\n                    output_shape=self.output_shape,\n                    eval_fn=lambda x: self(other(x)),\n                    input_dtype=self.input_dtype,\n                )\n\n            raise ValueError(\n                \"Cannot compute MatrixOperator-LinearOperator product, \"\n                f\"{other.output_shape} does not match {self.input_shape}.\"\n            )\n\n        return self._eval(other)\n\n    def _eval(self, other):\n        return self.A @ other\n\n    def gram(self, other):\n        return self.A.conj().T @ self.A @ other\n\n    @partial(_wrap_add_sub_matrix, op=operator.add)\n    def __add__(self, other):\n        pass\n\n    @partial(_wrap_add_sub_matrix, op=operator.sub)\n    def __sub__(self, other):\n        pass\n\n    def __radd__(self, other):\n        # Addition is commutative\n        return self + other\n\n    def __rsub__(self, other):\n        return -self + other\n\n    def __neg__(self):\n        return MatrixOperator(-self.A)\n\n    # Could write another wrapper for mul, truediv, and rtuediv, but there is\n    # no operator.__rtruediv__;  have to write that case out manually anyway.\n    def __mul__(self, other):\n        if np.isscalar(other):\n            return MatrixOperator(other * self.A)\n\n        if isinstance(other, MatrixOperator):\n            if self.shape == other.shape:\n                return MatrixOperator(self.A * other.A)\n\n            raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n\n        if isinstance(other, (jnp.ndarray, np.ndarray)):\n            if self.matrix_shape == other.shape:\n                return MatrixOperator(self.A * other)\n\n            raise ValueError(f\"Shapes {self.matrix_shape} and {other.shape} do not match.\")\n\n        # includes generic LinearOperator\n        raise TypeError(f\"Operation __mul__ not defined between {type(self)} and {type(other)}.\")\n\n    def __rmul__(self, 
other):\n        # Multiplication is commutative\n        return self * other\n\n    def __truediv__(self, other):\n        if np.isscalar(other):\n            return MatrixOperator(self.A / other)\n\n        if isinstance(other, MatrixOperator):\n            if self.shape == other.shape:\n                return MatrixOperator(self.A / other.A)\n            raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n\n        if isinstance(other, (jnp.ndarray, np.ndarray)):\n            if self.matrix_shape == other.shape:\n                return MatrixOperator(self.A / other)\n\n            raise ValueError(f\"Shapes {self.matrix_shape} and {other.shape} do not match.\")\n\n        raise TypeError(\n            f\"Operation __truediv__ not defined between {type(self)} and {type(other)}.\"\n        )\n\n    def __rtruediv__(self, other):\n        if np.isscalar(other):\n            return MatrixOperator(other / self.A)\n\n        if isinstance(other, (jnp.ndarray, np.ndarray)):\n            if self.matrix_shape == other.shape:\n                return MatrixOperator(other / self.A)\n\n            raise ValueError(f\"Shapes {other.shape} and {self.matrix_shape} do not match.\")\n\n        raise TypeError(\n            f\"Operation __truediv__ not defined between {type(other)} and {type(self)}.\"\n        )\n\n    def __getitem__(self, key):\n        return self.A[key]\n\n    @property\n    def T(self):\n        \"\"\"Transpose of this :class:`.MatrixOperator`.\n\n        Return a :class:`.MatrixOperator` corresponding to the transpose\n        of this matrix.\n        \"\"\"\n        return MatrixOperator(self.A.T)\n\n    @property\n    def H(self):\n        \"\"\"Hermitian (conjugate) transpose of this :class:`.MatrixOperator`.\n\n        Return a :class:`.MatrixOperator` corresponding to the Hermitian\n        (conjugate) transpose of this matrix.\n        \"\"\"\n        return MatrixOperator(self.A.conj().T)\n\n    def conj(self):\n        
\"\"\"Complex conjugate of this :class:`.MatrixOperator`.\n\n        Return a :class:`.MatrixOperator` with complex conjugated\n        elements.\n        \"\"\"\n        return MatrixOperator(A=self.A.conj())\n\n    def adj(self, y):\n        return self.A.conj().T @ y\n\n    def to_array(self):\n        \"\"\"Return a :class:`numpy.ndarray` containing `self.A`.\"\"\"\n        return np.array(self.A)\n\n    @property\n    def gram_op(self):\n        \"\"\"Gram operator of this :class:`.MatrixOperator`.\n\n        Return a new :class:`.MatrixOperator` `G` such that\n        `G(x) = A.adj(A(x))`.\"\"\"\n        return MatrixOperator(A=self.A.conj().T @ self.A)\n\n    def norm(self, ord=None, axis=None, keepdims=False):  # pylint: disable=W0622\n        \"\"\"Compute the norm of the dense matrix `self.A`.\n\n        Call :func:`scico.numpy.linalg.norm` on the dense matrix `self.A`.\n        \"\"\"\n        return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims)\n"
  },
  {
    "path": "scico/linop/_stack.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Stack of linear operators classes.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any, List, Optional, Sequence, Union\n\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import normalize_axes\nfrom scico.operator._stack import DiagonalReplicated as DiagonalReplicatedOperator\nfrom scico.operator._stack import DiagonalStack as DiagonalStackOperator\nfrom scico.operator._stack import VerticalStack as VerticalStackOperator\nfrom scico.typing import Axes, Shape\n\nfrom ._linop import LinearOperator\n\n\nclass VerticalStack(VerticalStackOperator, LinearOperator):\n    r\"\"\"A vertical stack of linear operators.\n\n    Given linear operators :math:`A_1, A_2, \\dots, A_N`, create the\n    linear operator\n\n    .. math::\n       H =\n       \\begin{pmatrix}\n            A_1 \\\\\n            A_2 \\\\\n            \\vdots \\\\\n            A_N \\\\\n       \\end{pmatrix} \\qquad\n       \\text{such that} \\qquad\n       H \\mb{x}\n       =\n       \\begin{pmatrix}\n            A_1(\\mb{x}) \\\\\n            A_2(\\mb{x}) \\\\\n            \\vdots \\\\\n            A_N(\\mb{x}) \\\\\n       \\end{pmatrix} \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        ops: Sequence[LinearOperator],\n        collapse_output: Optional[bool] = True,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            ops: Linear operators to stack.\n            collapse_output: If ``True`` and the output would be a\n                :class:`BlockArray` with shape ((m, n, ...), (m, n, ...),\n                ...), the output is instead a :class:`jax.Array` with\n                shape (S, m, n, ...) 
where S is the length of `ops`.\n            jit: See `jit` in :class:`LinearOperator`.\n        \"\"\"\n        if not all(isinstance(op, LinearOperator) for op in ops):\n            raise TypeError(\"All elements of 'ops' must be of type LinearOperator.\")\n\n        super().__init__(ops=ops, collapse_output=collapse_output, jit=jit, **kwargs)\n\n    def _adj(self, y: Union[Array, BlockArray]) -> Array:  # type: ignore\n        return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)])  # type: ignore\n\n\nclass DiagonalStack(DiagonalStackOperator, LinearOperator):\n    r\"\"\"A diagonal stack of linear operators.\n\n    Given linear operators :math:`A_1, A_2, \\dots, A_N`, create the\n    linear operator\n\n    .. math::\n       H =\n       \\begin{pmatrix}\n            A_1 & 0   & \\ldots & 0\\\\\n            0   & A_2 & \\ldots & 0\\\\\n            \\vdots & \\vdots & \\ddots & \\vdots\\\\\n            0   & 0 & \\ldots & A_N \\\\\n       \\end{pmatrix} \\qquad\n       \\text{such that} \\qquad\n       H\n       \\begin{pmatrix}\n            \\mb{x}_1 \\\\\n            \\mb{x}_2 \\\\\n            \\vdots \\\\\n            \\mb{x}_N \\\\\n       \\end{pmatrix}\n       =\n       \\begin{pmatrix}\n            A_1(\\mb{x}_1) \\\\\n            A_2(\\mb{x}_2) \\\\\n            \\vdots \\\\\n            A_N(\\mb{x}_N) \\\\\n       \\end{pmatrix} \\;.\n\n    By default, if the inputs :math:`\\mb{x}_1, \\mb{x}_2, \\dots,\n    \\mb{x}_N` all have the same (possibly nested) shape, `S`, this\n    operator will work on the stack, i.e., have an input shape of `(N,\n    *S)`. If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`,\n    this operator will work on the block concatenation, i.e.,\n    have an input shape of `(S1, S2, ..., SN)`. 
The same holds for the\n    output shape.\n    \"\"\"\n\n    def __init__(\n        self,\n        ops: Sequence[LinearOperator],\n        collapse_input: Optional[bool] = True,\n        collapse_output: Optional[bool] = True,\n        jit: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            ops: Operators to stack.\n            collapse_input: If ``True``, inputs are expected to be\n                stacked along the first dimension when possible.\n            collapse_output: If ``True``, the output will be\n                stacked along the first dimension when possible.\n            jit: See `jit` in :class:`LinearOperator`.\n\n        \"\"\"\n        if not all(isinstance(op, LinearOperator) for op in ops):\n            raise TypeError(\"All elements of 'ops' must be of type LinearOperator.\")\n\n        super().__init__(\n            ops=ops,\n            collapse_input=collapse_input,\n            collapse_output=collapse_output,\n            jit=jit,\n            **kwargs,\n        )\n\n    def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]:  # type: ignore\n        result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y))  # type: ignore\n        if self.collapse_input:\n            return snp.stack(result)\n        return snp.blockarray(result)\n\n\nclass DiagonalReplicated(DiagonalReplicatedOperator, LinearOperator):\n    r\"\"\"A diagonal stack constructed from a single linear operator.\n\n    Given linear operator :math:`A`, create the linear operator\n\n    .. 
math::\n       H =\n       \\begin{pmatrix}\n            A & 0   & \\ldots & 0\\\\\n            0   & A & \\ldots & 0\\\\\n            \\vdots & \\vdots & \\ddots & \\vdots\\\\\n            0   & 0 & \\ldots & A \\\\\n       \\end{pmatrix} \\qquad\n       \\text{such that} \\qquad\n       H\n       \\begin{pmatrix}\n            \\mb{x}_1 \\\\\n            \\mb{x}_2 \\\\\n            \\vdots \\\\\n            \\mb{x}_N \\\\\n       \\end{pmatrix}\n       =\n       \\begin{pmatrix}\n            A(\\mb{x}_1) \\\\\n            A(\\mb{x}_2) \\\\\n            \\vdots \\\\\n            A(\\mb{x}_N) \\\\\n       \\end{pmatrix} \\;.\n\n    The application of :math:`A` to each component :math:`\\mb{x}_k` is\n    computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape\n    for linear operator :math:`A` should exclude the array axis on which\n    :math:`A` is replicated to form :math:`H`. For example, if :math:`A`\n    has input shape `(3, 4)` and :math:`H` is constructed to replicate\n    on axis 0 with 2 replicates, the input shape of :math:`H` will be\n    `(2, 3, 4)`.\n\n    Linear operators taking :class:`.BlockArray` input are not supported.\n    \"\"\"\n\n    def __init__(\n        self,\n        op: LinearOperator,\n        replicates: int,\n        input_axis: int = 0,\n        output_axis: Optional[int] = None,\n        map_type: str = \"auto\",\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            op: Linear operator to replicate.\n            replicates: Number of replicates of `op`.\n            input_axis: Input axis over which `op` should be replicated.\n            output_axis: Index of replication axis in output array.\n               If ``None``, the input replication axis is used.\n            map_type: If \"pmap\" or \"vmap\", apply replicated mapping using\n               :func:`jax.pmap` or :func:`jax.vmap` respectively. 
If\n               \"auto\", use :func:`jax.pmap` if sufficient devices are\n               available for the number of replicates, otherwise use\n               :func:`jax.vmap`.\n        \"\"\"\n        if not isinstance(op, LinearOperator):\n            raise TypeError(\"Argument 'op' must be of type LinearOperator.\")\n\n        super().__init__(\n            op,\n            replicates,\n            input_axis=input_axis,\n            output_axis=output_axis,\n            map_type=map_type,\n            **kwargs,\n        )\n\n        self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis)\n\n\ndef linop_over_axes(\n    linop: type[LinearOperator],\n    input_shape: Shape,\n    *args: Any,\n    axes: Optional[Axes] = None,\n    **kwargs: Any,\n) -> List[LinearOperator]:\n    \"\"\"Construct a list of :class:`LinearOperator` by iterating over axes.\n\n    Construct a list of :class:`LinearOperator` by iterating over a\n    specified sequence of axes, passing each value in sequence to the\n    `axis` keyword argument of the :class:`LinearOperator` initializer.\n\n    Args:\n        linop: Type of :class:`LinearOperator` to construct for each axis.\n        input_shape: Shape of input array.\n        *args: Positional arguments for the :class:`LinearOperator`\n            initializer.\n        axes: Axis or axes over which to construct the list. If not\n            specified, or ``None``, use all axes corresponding to\n            `input_shape`.\n        **kwargs: Keyword arguments for the :class:`LinearOperator`\n            initializer.\n\n    Returns:\n        A tuple (`axes`, `ops`) where `axes` is a tuple of the axes used\n        to construct the list of :class:`LinearOperator`, and `ops` is\n        the list itself.\n    \"\"\"\n    axes = normalize_axes(axes, input_shape)  # type: ignore\n    return axes, [linop(input_shape, *args, axis=axis, **kwargs) for axis in axes]  # type: ignore\n"
  },
  {
    "path": "scico/linop/_util.py",
    "content": "# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linear operator utility functions.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Optional, Union\n\nimport scico.numpy as snp\nfrom scico.operator._operator import Operator\nfrom scico.random import randn\nfrom scico.typing import PRNGKey\n\nfrom ._linop import LinearOperator\n\n\ndef power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None):\n    \"\"\"Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`.\n\n    Compute largest eigenvalue of a diagonalizable\n    :class:`.LinearOperator` using power iteration.\n\n    Args:\n        A: :class:`.LinearOperator` used for computation. Must be\n            diagonalizable.\n        maxiter: Maximum number of power iterations to use.\n        key: Jax PRNG key. 
Defaults to ``None``, in which case a new key\n            is created.\n\n    Returns:\n        tuple: A tuple (`mu`, `v`) containing:\n\n            - **mu**: Estimate of largest eigenvalue of `A`.\n            - **v**: Eigenvector of `A` with eigenvalue `mu`.\n\n    \"\"\"\n    v, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype)\n    v = v / snp.linalg.norm(v)\n\n    for i in range(maxiter):\n        Av = A @ v\n        normAv = snp.linalg.norm(Av)\n        if normAv == 0.0:  # Assume that ||Av|| == 0 implies A is a zero operator\n            mu = 0.0\n            v = Av\n            break\n        mu = snp.sum(v.conj() * Av) / snp.linalg.norm(v) ** 2\n        v = Av / normAv\n    return mu, v\n\n\ndef operator_norm(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None):\n    r\"\"\"Estimate the norm of a :class:`.LinearOperator`.\n\n    Estimate the operator norm\n    `induced <https://en.wikipedia.org/wiki/Matrix_norm#Matrix_norms_induced_by_vector_norms>`_\n    by the :math:`\\ell_2` vector norm, i.e. for :class:`.LinearOperator`\n    :math:`A`,\n\n    .. math::\n       \\| A \\|_2 &= \\max \\{ \\| A \\mb{x} \\|_2 \\, : \\, \\| \\mb{x} \\|_2 \\leq 1 \\} \\\\\n                 &= \\sqrt{ \\lambda_{ \\mathrm{max} }( A^H A ) }\n                 = \\sigma_{\\mathrm{max}}(A) \\;,\n\n    where :math:`\\lambda_{\\mathrm{max}}(B)` and\n    :math:`\\sigma_{\\mathrm{max}}(B)` respectively denote the\n    largest eigenvalue of :math:`B` and the largest singular value of\n    :math:`B`. The value is estimated via power iteration, using\n    :func:`power_iteration`, to estimate\n    :math:`\\lambda_{\\mathrm{max}}(A^H A)`.\n\n    Args:\n        A: :class:`.LinearOperator` for which operator norm is desired.\n        maxiter: Maximum number of power iterations to use. Default: 100\n        key: Jax PRNG key. 
Defaults to ``None``, in which case a new key\n            is created.\n\n    Returns:\n        float: Norm of operator :math:`A`.\n\n    \"\"\"\n    return snp.sqrt(power_iteration(A.H @ A, maxiter, key)[0].real)\n\n\ndef valid_adjoint(\n    A: LinearOperator,\n    AT: LinearOperator,\n    eps: Optional[float] = 1e-7,\n    x: Optional[snp.Array] = None,\n    y: Optional[snp.Array] = None,\n    key: Optional[PRNGKey] = None,\n) -> Union[bool, float]:\n    r\"\"\"Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`.\n\n    Check whether :class:`.LinearOperator` :math:`\\mathsf{AT}` is the\n    adjoint of :math:`\\mathsf{A}`. The test exploits the identity\n\n    .. math::\n      \\mathbf{y}^T (A \\mathbf{x}) = (\\mathbf{y}^T A) \\mathbf{x} =\n      (A^T \\mathbf{y})^T \\mathbf{x}\n\n    by computing :math:`\\mathbf{u} = \\mathsf{A}(\\mathbf{x})` and\n    :math:`\\mathbf{v} = \\mathsf{AT}(\\mathbf{y})` for random\n    :math:`\\mathbf{x}` and :math:`\\mathbf{y}` and confirming that\n\n    .. math::\n      \\frac{| \\mathbf{y}^T \\mathbf{u} - \\mathbf{v}^T \\mathbf{x} |}\n      {\\max \\left\\{ | \\mathbf{y}^T \\mathbf{u} |,\n       | \\mathbf{v}^T \\mathbf{x} | \\right\\}}\n      < \\epsilon \\;.\n\n    If :math:`\\mathsf{A}` is a complex operator (with a complex\n    `input_dtype`) then the test checks whether :math:`\\mathsf{AT}` is\n    the Hermitian conjugate of :math:`\\mathsf{A}`, with a test as above,\n    but with all the :math:`(\\cdot)^T` replaced with :math:`(\\cdot)^H`.\n\n    Args:\n        A: Primary :class:`.LinearOperator`.\n        AT: Adjoint :class:`.LinearOperator`.\n        eps: Error threshold for validation of :math:`\\mathsf{AT}` as\n           adjoint of :math:`\\mathsf{A}`. If ``None``, the relative\n           error is returned instead of a boolean value.\n        x: If not the default ``None``, use the specified array instead\n           of a random array as test vector :math:`\\mb{x}`. 
If specified,\n           the array must have shape `A.input_shape`.\n        y: If not the default ``None``, use the specified array instead\n           of a random array as test vector :math:`\\mb{y}`. If specified,\n           the array must have shape `AT.input_shape`.\n        key: Jax PRNG key. Defaults to ``None``, in which case a new key\n           is created.\n\n    Returns:\n      Boolean value indicating whether validation passed, or relative\n      error of test if parameter `eps` is ``None``.\n    \"\"\"\n\n    if x is None:\n        x, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype)\n    else:\n        if x.shape != A.input_shape:\n            raise ValueError(\"Shape of 'x' array not appropriate as an input for operator 'A'.\")\n    if y is None:\n        y, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype)\n    else:\n        if y.shape != AT.input_shape:\n            raise ValueError(\"Shape of 'y' array not appropriate as an input for operator AT.\")\n\n    u = A(x)\n    v = AT(y)\n    yTu = snp.sum(y.conj() * u)  # type: ignore\n    vTx = snp.sum(v.conj() * x)  # type: ignore\n    err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx))\n    if eps is None:\n        return err\n    return float(err) < eps\n\n\ndef jacobian(F: Operator, u: snp.Array, include_eval: Optional[bool] = False) -> LinearOperator:\n    \"\"\"Construct Jacobian linear operator for a general operator.\n\n    For a specified :class:`.Operator`, construct a corresponding\n    Jacobian :class:`LinearOperator`, the application of which is\n    equivalent to multiplication by the Jacobian of the\n    :class:`.Operator` at a specified input value.\n\n    The implementation of this function is based on :meth:`.Operator.jvp`\n    and :meth:`.Operator.vjp`, which are themselves based on\n    :func:`jax.jvp` and :func:`jax.vjp`. 
For reasons of computational\n    efficiency, these functions return the value of the :class:`.Operator`\n    evaluated at the specified point in addition to the requested\n    Jacobian-vector product. If the `include_eval` parameter of this\n    function is ``True``, the constructed :class:`LinearOperator` returns\n    a :class:`.BlockArray` output, the first component of which is the\n    result of the :class:`.Operator` evaluation, and the second component\n    of which is the requested Jacobian-vector product. If `include_eval`\n    is ``False``, then the :class:`.Operator` evaluations computed by\n    :func:`jax.jvp` and :func:`jax.vjp` are discarded.\n\n    Args:\n        F: :class:`.Operator` of which the Jacobian is to be computed.\n        u: Input value of the :class:`.Operator` at which the Jacobian is\n           to be computed.\n        include_eval: Flag indicating whether the result of evaluating\n           the :class:`.Operator` should be included (as the first\n           component of a :class:`.BlockArray`) in the output of the\n           Jacobian :class:`LinearOperator` constructed by this function.\n\n    Returns:\n      A :class:`LinearOperator` capable of computing Jacobian-vector\n      products.\n    \"\"\"\n    if include_eval:\n        Fu, G = F.vjp(u, conjugate=True)\n\n        def adj_fn(v):\n            return snp.blockarray((Fu, G(v)))\n\n        def eval_fn(v):\n            return snp.blockarray(F.jvp(u, v))\n\n    else:\n        adj_fn = F.vjp(u, conjugate=True)[1]\n\n        def eval_fn(v):\n            return F.jvp(u, v)[1]\n\n    return LinearOperator(\n        F.input_shape,\n        output_shape=F.output_shape,\n        eval_fn=eval_fn,\n        adj_fn=adj_fn,\n        input_dtype=F.input_dtype,\n        output_dtype=F.output_dtype,\n    )\n"
  },
  {
    "path": "scico/linop/optics.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\nr\"\"\"Optical propagator classes.\n\nThis module provides classes that model the propagation of a\nmonochromatic waveform between two parallel planes in a homogeneous\nmedium. The corresponding linear operators are referred to here as\n\"propagators\", which represents a departure from standard terminology,\nin which \"propagator\" refers specifically to the Fourier domain\ncomponent of the linear operator, i.e. if the full linear operator\ncan be written as :math:`F^{-1} D F` where :math:`F` is the Fourier\ntransform, then :math:`D` is usually referred to as the propagator.\n\n\nThe following notation is used throughout the module:\n\n.. math ::\n     \\begin{align}\n     \\Delta x, \\Delta y  & \\quad \\text{Sampling intervals in } x\n     \\text{ and } y \\text{ axes}\\\\\n     z  & \\quad \\text{Propagation distance} \\;\\; (z \\geq 0) \\\\\n     N_x, N_y  & \\quad \\text{Number of samples in } x \\text{ and } y\n     \\text{ axes}\\\\\n     k_0 & \\quad \\text{Illumination wavenumber corresponding to } 2\\pi /\n     \\text{wavelength} \\;.\n     \\end{align}\n\nVariables :math:`\\Delta x, \\Delta y, z,` and :math:`k_0` represent\nphysical quantities. Any units may be chosen, but they must be consistent\nacross all of these variables, e.g. m (metres) for :math:`\\Delta x,\n\\Delta y, z,` and :math:`\\mathrm{m}^{-1}` for :math:`k_0`, as well as with\nthe units for the physical dimensions of the source wavefield.\n\nSubscripts :math:`S` and :math:`D` are used to refer to the source and\ndestination planes respectively when it is necessary to distinguish\nbetween them. In the absence of subscripts, the variables refer to the\nsource plane (e.g. 
both :math:`\\Delta x` and :math:`\\Delta x_S` refer to\nthe :math:`x`-axis sampling interval in the source plane, while\n:math:`\\Delta x_D` refers to it in the destination plane).\n\nNote that :math:`x` corresponds to axis 0 (rows, increasing downwards)\nand :math:`y` to axis 1 (columns, increasing to the right).\n\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Any, Tuple, Union\n\nimport numpy as np\nfrom numpy.lib.scimath import sqrt  # complex sqrt\n\nfrom typing_extensions import TypeGuard\n\nimport scico.numpy as snp\nfrom scico.linop import Diagonal, Identity, LinearOperator\nfrom scico.numpy.util import no_nan_divide\nfrom scico.typing import Shape\n\nfrom ._dft import DFT\n\n\ndef _isscalar(element: Any) -> TypeGuard[Union[int, float]]:\n    \"\"\"Type guard interface to `snp.isscalar`.\"\"\"\n    return snp.isscalar(element)\n\n\ndef radial_transverse_frequency(\n    input_shape: Shape, dx: Union[float, Tuple[float, ...]]\n) -> np.ndarray:\n    r\"\"\"Construct radial Fourier coordinate system.\n\n    Args:\n        input_shape: Tuple of length 1 or 2 containing the number of\n            samples per dimension, i.e. :math:`(N_x,)` or\n            :math:`(N_x, N_y)`\n        dx: Sampling interval at source plane. If a float and\n            `len(input_shape)==2` the same sampling interval is applied\n            to both dimensions. If `dx` is a tuple, it must have same\n            length as `input_shape`, and corresponds to either\n            :math:`(\\Delta x,)` or :math:`(\\Delta x, \\Delta y)`.\n\n    Returns:\n        If `len(input_shape)==1`, returns an ndarray containing\n        corresponding Fourier coordinates. 
If `len(input_shape) == 2`,\n        returns an ndarray containing the radial Fourier coordinates\n        :math:`\\sqrt{k_x^2 + k_y^2}\\,`.\n    \"\"\"\n    ndim: int = len(input_shape)  # 1 or 2 dimensions\n    if ndim not in (1, 2):\n        raise ValueError(\"Invalid input dimensions; must be 1 or 2.\")\n\n    if _isscalar(dx):\n        dx = (dx,) * ndim\n    else:\n        assert isinstance(dx, tuple)\n        if len(dx) != ndim:\n            raise ValueError(\n                \"Argument 'dx' must be a scalar or have len(dx) == len(input_shape); \"\n                f\"got len(dx)={len(dx)}, len(input_shape)={ndim}.\"\n            )\n    assert isinstance(dx, tuple)\n\n    if ndim == 1:\n        kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0])\n        kp = kx\n    elif ndim == 2:\n        kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0])\n        ky = 2 * np.pi * np.fft.fftfreq(input_shape[1], dx[1])\n        kp = np.sqrt(kx[None, :] ** 2 + ky[:, None] ** 2)\n    return kp\n\n\nclass Propagator(LinearOperator):\n    \"\"\"Base class for angular spectrum and Fresnel propagators.\"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        dx: Union[float, Tuple[float, ...]],\n        k0: float,\n        z: float,\n        pad_factor: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array as a tuple of length\n               1 or 2, corresponding to :math:`(N_x,)` or\n               :math:`(N_x, N_y)`.\n            dx: Sampling interval at source plane. If a float and\n               `len(input_shape)==2` the same sampling interval is applied\n               to both dimensions. 
If `dx` is a tuple, it must have same\n               length as `input_shape`, and corresponds to either\n               :math:`(\\Delta x,)` or :math:`(\\Delta x, \\Delta y)`.\n            k0: Illumination wavenumber, :math:`k_0`, corresponding to\n               :math:`2 \\pi` / wavelength.\n            z: Propagation distance, :math:`z`.\n            pad_factor: The padded input shape is the input shape\n               multiplied by this integer factor.\n        \"\"\"\n        ndim = len(input_shape)  # 1 or 2 dimensions\n        if ndim not in (1, 2):\n            raise ValueError(\"Invalid input dimensions; must be 1 or 2.\")\n\n        if _isscalar(dx):\n            dx = (dx,) * ndim\n        else:\n            assert isinstance(dx, tuple)\n            if len(dx) != ndim:\n                raise ValueError(\n                    \"Argument 'dx' must be a scalar or have len(dx) == len(input_shape); \"\n                    f\"got len(dx)={len(dx)}, len(input_shape)={ndim}.\"\n                )\n        assert isinstance(dx, tuple)\n\n        #: Illumination wavenumber; 2𝜋/wavelength\n        self.k0: float = k0\n        #: Shape of input after padding\n        self.padded_shape: Shape = tuple(pad_factor * s for s in input_shape)\n        #: Padded source plane side length (dx[i] * padded_shape[i])\n        self.L: Tuple[float, ...] 
= tuple(\n            s * d for s, d in zip(self.padded_shape, dx)\n        )  # computational plane size\n        #: Transverse Fourier coordinates (radial)\n        self.kp = radial_transverse_frequency(self.padded_shape, dx)\n        #: Source plane sampling interval\n        self.dx: Union[float, Tuple[float, ...]] = dx\n        #: Propagation distance\n        self.z: float = z\n\n        # Fourier operator\n        self.F = DFT(input_shape=input_shape, axes_shape=self.padded_shape, jit=False)\n\n        # Diagonal operator; phase shifting\n        self.D: LinearOperator = Identity(self.kp.shape)\n\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=np.complex64,\n            output_shape=input_shape,\n            output_dtype=np.complex64,\n            adj_fn=None,\n            **kwargs,\n        )\n\n    def __repr__(self):\n        extra_repr = f\"\"\"  k0:  {self.k0}\n  λ:   {2*np.pi/self.k0}\n  z:   {self.z}\n  dx:  {self.dx}\n  L:   {self.L}\n\"\"\"\n        return LinearOperator.__repr__(self) + extra_repr\n\n    def _eval(self, x):\n        return self.F.inv(self.D @ self.F @ x)\n\n\nclass AngularSpectrumPropagator(Propagator):\n    r\"\"\"Angular spectrum propagator.\n\n    Propagates a planar source field with coordinates :math:`(x, y, z_0)`\n    to a destination plane at a distance :math:`z` with coordinates\n    :math:`(x, y, z_0 + z)`. The action of this linear operator is\n    given by (Eq. 3.74, :cite:`goodman-2005-fourier`)\n\n    .. math ::\n         (A \\mb{u})(x, y, z_0 + z) = \\frac{1}{2 \\pi} \\iint_{-\\infty}^{\\infty}\n         \\mb{\\hat{u}}(k_x, k_y) e^{j \\sqrt{k_0^2 - k_x^2 - k_y^2} \\,\n         z} e^{j (x k_x + y k_y) } d k_x \\ d k_y \\;,\n\n    where the :math:`\\mb{\\hat{u}}` is the Fourier transform of the\n    field :math:`\\mb{u}(x, y)` in the plane :math:`z=z_0`, given by\n\n    .. 
math ::\n         \\mb{\\hat{u}}(k_x, k_y) = \\iint_{-\\infty}^{\\infty}\n         \\mb{u}(x, y) e^{- j (x k_x + y k_y)} dx \\ dy \\;,\n\n    where :math:`(k_x, k_y)` are the :math:`x` and :math:`y` components\n    respectively of the wave-vector of the plane wave, and :math:`j` is\n    the imaginary unit.\n\n    The angular spectrum propagator can be written\n\n    .. math ::\n         A\\mb{u} = F^{-1} D F \\mb{u} \\;,\n\n    where :math:`F` is the Fourier transform with respect to\n    :math:`(x, y)`, :math:`F^{-1}` is the inverse transform with respect\n    to :math:`(k_x, k_y)`, and the propagator term is given by\n\n    .. math ::\n         D = \\exp \\left( j \\sqrt{k_0^2 - k_x^2 - k_y^2} \\, z \\right) \\;.\n\n    Aliasing of the wavefield at the destination plane is avoided when\n    the propagator term is adequately sampled according to\n    :cite:`voelz-2009-digital`\n\n    .. math ::\n         (\\Delta x)^2 \\geq \\frac{\\pi}{k_0 N_x} \\sqrt{ (\\Delta x)^2 N_x^2 +\n         4 z^2} \\quad \\text{and} \\quad\n         (\\Delta y)^2 \\geq \\frac{\\pi}{k_0 N_y} \\sqrt{ (\\Delta y)^2 N_y^2 +\n         4 z^2} \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        dx: Union[float, Tuple[float, ...]],\n        k0: float,\n        z: float,\n        pad_factor: int = 1,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array as a tuple of length\n               1 or 2, corresponding to :math:`(N_x,)` or\n               :math:`(N_x, N_y)`.\n            dx: Sampling interval, :math:`\\Delta x`, at source plane. If\n               a float and `len(input_shape)==2` the same sampling\n               interval is applied to both dimensions. 
If `dx` is a tuple,\n               must have same length as `input_shape`.\n            k0: Illumination wavenumber, :math:`k_0`, corresponding to\n               :math:`2 \\pi` / wavelength.\n            z: Propagation distance, :math:`z`.\n            pad_factor: The padded input shape is the input shape\n               multiplied by this integer factor.\n            jit: If ``True``, call :meth:`~.Operator.jit` on this\n               :class:`LinearOperator` to jit the forward, adjoint, and\n               gram functions. Same as calling :meth:`~.Operator.jit`\n               after the :class:`LinearOperator` is created.\n        \"\"\"\n        # Diagonal operator; phase shifting\n        super().__init__(\n            input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs\n        )\n\n        self.phase = snp.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64)\n        self.D = Diagonal(self.phase)\n        self._set_adjoint()\n\n        if jit:\n            self.jit()\n\n    def adequate_sampling(self):\n        r\"\"\"Verify the angular spectrum kernel is not aliased.\n\n        Checks the condition for adequate sampling\n        :cite:`voelz-2009-digital`,\n\n         .. 
math ::\n             (\\Delta x)^2 \\geq \\frac{\\pi}{k_0 N_x} \\sqrt{ (\\Delta x)^2 N_x^2 +\n             4 z^2} \\quad \\text{and} \\quad\n             (\\Delta y)^2 \\geq \\frac{\\pi}{k_0 N_y} \\sqrt{ (\\Delta y)^2 N_y^2 +\n             4 z^2} \\;.\n\n        Returns:\n             ``True`` if the angular spectrum kernel is adequately sampled,\n             ``False`` otherwise.\n        \"\"\"\n        tmp = []\n        for d, N in zip(self.dx, self.padded_shape):\n            tmp.append(d**2 > np.pi / (self.k0 * N) * np.sqrt(d**2 * N**2 + 4 * self.z**2))\n        return np.all(tmp)\n\n    def pinv(self, y):\n        \"\"\"Apply pseudoinverse of Angular Spectrum propagator.\"\"\"\n        diag_inv = no_nan_divide(1, self.D.diagonal)\n        return self.F.inv(diag_inv * self.F(y))\n\n\nclass FresnelPropagator(Propagator):\n    r\"\"\"Fresnel (small-angle/paraxial) propagator.\n\n    Propagates a planar source field with coordinates :math:`(x, y, z_0)`\n    to a destination plane at a distance :math:`z` with coordinates\n    :math:`(x, y, z_0 + z)`. The action of this linear operator is given\n    by (Eq. 4.20, :cite:`goodman-2005-fourier`)\n\n    .. math ::\n        (A \\mb{u})(x, y, z + z_0) = e^{j k_0 z} \\frac{1}{2 \\pi}\n        \\iint_{-\\infty}^{\\infty} \\mb{\\hat{u}}(k_x, k_y)\n        e^{-j \\frac{z}{2 k_0}\\left(k_x^2 + k_y^2\\right) }\n        e^{j (x k_x + y k_y) } d k_x \\ d k_y \\;,\n\n    where the :math:`\\mb{\\hat{u}}` is the Fourier transform of the field\n    in the source plane, given by\n\n    .. math ::\n        \\mb{\\hat{u}}(k_x, k_y) = \\iint_{-\\infty}^{\\infty} \\mb{u}(x, y)\n        e^{- j (x k_x + y k_y)} dx \\ dy \\;.\n\n    This linear operator is valid when :math:`k_x^2 + k_y^2 << k_0^2`.\n    The Fresnel propagator can be written\n\n    .. 
math ::\n        A\\mb{u} = F^{-1} D F \\mb{u} \\;,\n\n    where :math:`F` is the Fourier transform with respect to\n    :math:`(x, y)`, :math:`F^{-1}` is the inverse transform with respect\n    to :math:`(k_x, k_y)`, and the propagator term is given by\n\n    .. math ::\n        D = \\exp \\left( -j \\frac{z}{2 k_0}\\left(k_x^2 + k_y^2 \\right)\n        \\right) \\;,\n\n    where :math:`(k_x, k_y)` are the :math:`x` and :math:`y` components\n    respectively of the wave-vector of the plane wave, and :math:`j` is\n    the imaginary unit.\n\n    The propagator term is adequately sampled when\n    :cite:`voelz-2011-computational`\n\n    .. math ::\n         (\\Delta x)^2 \\geq \\frac{2 \\pi z }{k_0 N_x} \\quad \\text{and}\n         \\quad (\\Delta y)^2 \\geq \\frac{2 \\pi z }{k_0 N_y} \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        dx: float,\n        k0: float,\n        z: float,\n        pad_factor: int = 1,\n        jit: bool = True,\n        **kwargs,\n    ):\n        super().__init__(\n            input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs\n        )\n\n        self.phase = snp.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64)\n        self.D = Diagonal(self.phase)\n\n        self._set_adjoint()\n\n        if jit:\n            self.jit()\n\n    def adequate_sampling(self):\n        r\"\"\"Verify the Fresnel propagation kernel is not aliased.\n\n        Checks the condition for adequate sampling\n        :cite:`voelz-2011-computational`,\n\n        .. 
math ::\n            (\\Delta x)^2 \\geq \\frac{2 \\pi z }{k_0 N_x} \\quad \\text{and}\n            \\quad (\\Delta y)^2 \\geq \\frac{2 \\pi z }{k_0 N_y} \\;.\n\n\n        Returns:\n            ``True`` if the Fresnel propagation kernel is adequately sampled,\n            ``False`` otherwise.\n        \"\"\"\n        tmp = []\n        for d, N in zip(self.dx, self.padded_shape):\n            tmp.append(d**2 > 2 * np.pi * self.z / (self.k0 * N))\n        return np.all(tmp)\n\n\nclass FraunhoferPropagator(LinearOperator):\n    r\"\"\"Fraunhofer (far-field) propagator.\n\n    Propagates a source field with coordinates :math:`(x_S, y_S)` to\n    a destination plane at a distance :math:`z` with coordinates\n    :math:`(x_D, y_D)`.\n\n    The action of this linear operator is given by (Eq. 4.25,\n    :cite:`goodman-2005-fourier`)\n\n    .. math ::\n        (A \\mb{u})(x_D, y_D) = \\underbrace{\\frac{k_0}{2 \\pi}\n        \\frac{e^{j k_0 z}}{j z} \\mathrm{exp} \\left( j \\frac{k_0}{2 z}\n        (x_D^2 + y_D^2) \\right)}_{\\triangleq P(x_D, y_D)}\n        \\int \\mb{u}(x_S, y_S) e^{-j \\frac{k_0}{z} (x_D x_S + y_D y_S)\n        } dx_S \\ dy_S \\;.\n\n    This is valid when :math:`N_F << 1`, where :math:`N_F` is the\n    Fresnel number (Sec. 1.5, Sec. 4.7.2.1) :cite:`paganin-2006-coherent`.\n    Writing the Fourier transform of the field :math:`\\mb{u}` as\n\n    .. math ::\n        \\hat{\\mb{u}}(k_x, k_y) = \\int e^{-j (k_x x + k_y y)}\n        \\mb{u}(x, y) dx \\ dy \\;,\n\n    the action of this linear operator can be written\n\n    .. 
math ::\n        (A \\mb{u})(x_D, y_D) = P(x_D, y_D) \\ \\hat{\\mb{u}}\n        \\left({\\frac{k_0}{z} x_D, \\frac{k_0}{z} y_D}\\right) \\;.\n\n    Ignoring multiplicative prefactors, the Fraunhofer propagated\n    field is the Fourier transform of the source field, evaluated at\n    coordinates :math:`(k_x, k_y) = (\\frac{k_0}{z} x_D,\n    \\frac{k_0}{z} y_D)`.\n\n    In general, the sampling intervals (and thus plane lengths)\n    differ between source and destination planes. In particular,\n    (Eq. 5.18, :cite:`voelz-2011-computational`)\n\n    .. math ::\n        \\Delta x_D =  \\frac{2 \\pi z}{k_0 L_{Sx} } \\quad \\text{and}\n        \\quad L_{Dx} =  \\frac{2 \\pi z}{k_0 \\Delta x_S } \\;,\n\n    and similarly for the :math:`y` axis.\n\n    The Fraunhofer propagator term :math:`P(x_D, y_D)` is adequately\n    sampled when\n\n    .. math ::\n         \\Delta x_S \\geq \\sqrt{\\frac{2 \\pi z}{N_x k_0}} \\quad \\text{and}\n         \\quad \\Delta y_S \\geq \\sqrt{\\frac{2 \\pi z}{N_y k_0}} \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        dx: Union[float, Tuple[float, ...]],\n        k0: float,\n        z: float,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array as a tuple of length\n               1 or 2, corresponding to :math:`(N_x,)` or\n               :math:`(N_x, N_y)`.\n            dx: Sampling interval at source plane. If a float and\n               `len(input_shape)==2` the same sampling interval is applied\n               to both dimensions. 
If `dx` is a tuple, it must have same\n               length as `input_shape`, and corresponds to either\n               :math:`(\\Delta x,)` or :math:`(\\Delta x, \\Delta y)`.\n            k0: Illumination wavenumber, :math:`k_0`, corresponding to\n               :math:`2 \\pi` / wavelength.\n            z: Propagation distance, :math:`z`.\n            jit: If ``True``, jit the evaluation, adjoint, and gram\n               functions of this :class:`LinearOperator`. Default:\n               ``True``.\n        \"\"\"\n        ndim = len(input_shape)  # 1 or 2 dimensions\n        if ndim not in (1, 2):\n            raise ValueError(\"Invalid input dimensions; must be 1 or 2.\")\n\n        if _isscalar(dx):\n            dx = (dx,) * ndim\n        else:\n            assert isinstance(dx, tuple)\n            if len(dx) != ndim:\n                raise ValueError(\n                    \"Argument 'dx' must be a scalar or have len(dx) == len(input_shape); \"\n                    f\"got len(dx)={len(dx)}, len(input_shape)={ndim}.\"\n                )\n        assert isinstance(dx, tuple)\n\n        L: Tuple[float, ...] = tuple(s * d for s, d in zip(input_shape, dx))\n\n        #: Illumination wavenumber\n        self.k0: float = k0\n        #: Propagation distance\n        self.z: float = z\n        #: Source plane side length (dx[i] * input_shape[i])\n        self.L: Tuple[float, ...] = L\n        #: Source plane sampling interval\n        self.dx: Tuple[float, ...] = dx\n\n        #: Destination plane sampling interval\n        self.dx_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * l)).item() for l in L)\n        #: Destination plane side length\n        self.L_D: Tuple[float, ...] 
= tuple(np.abs(2 * np.pi * z / (k0 * d)).item() for d in dx)\n        x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D))  # type: ignore\n\n        # set up radial coordinate system; either x^2 or (x^2 + y^2)\n        if ndim == 1:\n            self.r2 = x_D[0]\n        elif ndim == 2:\n            self.r2 = np.sqrt(x_D[0][:, None] ** 2 + x_D[1][None, :] ** 2)\n\n        phase = -1j * snp.exp(1j * k0 * z) * snp.exp(1j * 0.5 * k0 / z * self.r2**2)\n        phase *= k0 / (2 * np.pi) * np.abs(1 / z)\n        phase *= np.prod(dx)  # from approximating continouous FT with DFT\n        phase = phase.astype(np.complex64)\n\n        self.F = DFT(input_shape=input_shape, jit=False)\n        self.D = Diagonal(phase)\n        super().__init__(\n            input_shape=input_shape,\n            input_dtype=np.complex64,\n            output_shape=input_shape,\n            output_dtype=np.complex64,\n            **kwargs,\n        )\n\n        if jit:\n            self.jit()\n\n    def __repr__(self):\n        extra_repr = f\"\"\"  k0:   {self.k0}\n  λ:    {2*np.pi/self.k0}\n  z:    {self.z}\n  dx:   {self.dx}\n  L:    {self.L}\n  dx_D: {self.dx_D}\n  L_D:  {self.L_D}\n\"\"\"\n        return LinearOperator.__repr__(self) + extra_repr\n\n    def _eval(self, x):\n        x = snp.fft.fftshift(x)\n        y = self.D @ self.F @ x\n        y = snp.fft.ifftshift(y)\n        return y\n\n    def adequate_sampling(self):\n        r\"\"\"Verify the Fraunhofer propagation kernel is not aliased.\n\n        Checks the condition for adequate sampling\n        :cite:`voelz-2011-computational`,\n\n        .. 
math ::\n            \\Delta x_S \\geq \\sqrt{\\frac{2 \\pi z}{N_x k_0}} \\quad \\text{and}\n            \\quad \\Delta y_S \\geq \\sqrt{\\frac{2 \\pi z}{N_y k_0}} \\;.\n\n        Returns:\n             ``True`` if the Fraunhofer propagation kernel is adequately\n             sampled, ``False`` otherwise.\n        \"\"\"\n        tmp = []\n        for d, N in zip(self.dx, self.input_shape):\n            tmp.append(d**2 > 2 * np.pi * self.z / (self.k0 * N))\n        return np.all(tmp)\n"
  },
  {
    "path": "scico/linop/xray/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\nr\"\"\"X-ray transform classes.\n\nThe tomographic projections that are frequently referred to as Radon\ntransforms are referred to as X-ray transforms in SCICO. While the Radon\ntransform is far better known than the X-ray transform, which is the\nsame as the Radon transform for projections in two dimensions, these two\ntransforms differ in higher numbers of dimensions, and it is the X-ray\ntransform that is the appropriate mathematical model for beam attenuation\nbased imaging in three or more dimensions.\n\nSCICO includes its own integrated 2D and 3D X-ray transforms, and also\nprovides interfaces to those implemented in the\n`ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_\nand the `svmbir <https://github.com/cabouman/svmbir>`_ package.\n\n\n**2D Transforms**\n\nThe SCICO, ASTRA, and svmbir transforms use different conventions for\nview angle directions, as illustrated in the figure below.\n\n.. plot:: pyfigures/xray_2d_geom.py\n   :align: center\n   :include-source: False\n   :show-source-link: False\n   :caption: Comparison of 2D X-ray projector geometries. The radial\n      arrows are directed towards the locations of the corresponding\n      detectors, with the direction of increasing pixel indices indicated\n      by the arrows on the dotted lines parallel to the detectors.\n\n|\n\nThe conversion from the SCICO projection angle convention to those of the\nother two transforms is\n\n.. 
math::\n\n   \\begin{aligned}\n   \\theta_{\\text{astra}} &= \\theta_{\\text{scico}} - \\frac{\\pi}{2} \\\\\n   \\theta_{\\text{svmbir}} &= 2 \\pi - \\theta_{\\text{scico}} \\;.\n   \\end{aligned}\n\n\n**3D Transforms**\n\nThere are more significant differences in the interfaces for the 3D SCICO\nand ASTRA transforms. The SCICO 3D transform :class:`.xray.XRayTransform3D`\ndefines the projection geometry in terms of a set of projection matrices,\nwhile the geometry for the ASTRA 3D transform\n:class:`.astra.XRayTransform3D` may either be specified in terms of a set\nof view angles, or via a more general set of vectors specifying projection\ndirection and detector orientation. A number of support functions are\nprovided for converting between these conventions.\n\nNote that the SCICO transform is implemented in JAX and can be run on\nboth CPU and GPU devices, while the ASTRA transform is implemented in\nCUDA, and can only be run on GPU devices.\n\"\"\"\n\nimport sys\n\nfrom ._util import (\n    center_image,\n    image_alignment_rotation,\n    image_centroid,\n    rotate_volume,\n    volume_alignment_rotation,\n)\nfrom ._xray import XRayTransform2D, XRayTransform3D\n\n__all__ = [\n    \"XRayTransform2D\",\n    \"XRayTransform3D\",\n    \"image_centroid\",\n    \"center_image\",\n    \"rotate_volume\",\n    \"image_alignment_rotation\",\n    \"volume_alignment_rotation\",\n]\n\n\n# Imported items in __all__ appear to originate in top-level xray module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/linop/xray/_axitom/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 PolymerGuy\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "scico/linop/xray/_axitom/README.md",
    "content": "# AXITOM\n\nThe modules in this directory are derived from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package. All original\ncomponents of AXITOM are subject to its license, which is included in\nthe file \"LICENSE\".\n"
  },
  {
    "path": "scico/linop/xray/_axitom/backprojection.py",
    "content": "\"\"\"\nThis file is a modified version of \"backprojection.py\" from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package.\n\nFiltered back projection functions\n\nThis module contains the Feldkamp David Kress filtered back projection\nroutines.\n\"\"\"\n\nfrom typing import Optional\n\nimport jax.numpy as jnp\nfrom jax import Array\nfrom jax.scipy.ndimage import map_coordinates\n\nfrom .config import Config\nfrom .filtering import ramp_filter_and_weight\nfrom .utilities import rotate_coordinates\n\n\ndef map_object_to_detector_coords(object_xs, object_ys, object_zs, config):\n    \"\"\"Map the object coordinates to detector pixel coordinates\n    accounting for cone beam divergence.\n\n    Parameters\n    ----------\n    object_xs : np.ndarray\n        The x-coordinate array of the object to be reconstructed\n    object_ys : np.ndarray\n        The y-coordinate array of the object to be reconstructed\n    object_zs : np.ndarray\n        The z-coordinate array of the object to be reconstructed\n    config : obj\n        The config object containing all necessary settings for the\n        reconstruction\n\n    Returns\n    -------\n    detector_cords_a\n        The detector coordinates along the a-axis corresponding to the\n        given points\n    detector_cords_b\n        The detector coordinates along the b-axis corresponding to the\n        given points\n    \"\"\"\n\n    detector_cords_a = (\n        ((object_ys * config.source_to_detector_dist) / (object_xs + config.source_to_object_dist))\n        - config.detector_us[0]\n    ) / config.pixel_size_u\n\n    if object_xs.ndim == 2:\n        detector_cords_b = (\n            (\n                (object_zs[jnp.newaxis, jnp.newaxis, :] * config.source_to_detector_dist)\n                / (object_xs[:, :, jnp.newaxis] + config.source_to_object_dist)\n            )\n            - config.detector_vs[0]\n        ) / config.pixel_size_v\n\n    elif object_xs.ndim == 1:\n        detector_cords_b 
= (\n            (\n                (object_zs[jnp.newaxis, :] * config.source_to_detector_dist)\n                / (object_xs[:, jnp.newaxis] + config.source_to_object_dist)\n            )\n            - config.detector_vs[0]\n        ) / config.pixel_size_v\n    else:\n        raise ValueError(\"Invalid dimensions on the object coordinates\")\n\n    return detector_cords_a, detector_cords_b\n\n\ndef _fdk_axisym(projection_filtered, config, angles):\n    \"\"\"Filtered back projection algorithm as proposed by Feldkamp David\n    Kress, adapted for axisymmetry.\n\n    This implementation has been adapted for axis-symmetry by using a\n    single projection only and by only reconstructing a single R-Z slice.\n\n    This algorithm is based on:\n       https://doi.org/10.1364/JOSAA.1.000612\n    but follows the notation used by:\n    Henrik Turbell, Cone-Beam Reconstruction Using Filtered\n    Backprojection, PhD Thesis, Linkoping Studies in Science and\n    Technology\n    https://people.csail.mit.edu/bkph/courses/papers/Exact_Conebeam/Turbell_Thesis_FBP_2001.pdf\n\n    Parameters\n    ----------\n    projection_filtered : jnp.ndarray\n        The ramp filtered and weighted projection used in the reconstruction\n    config : obj\n        The config object containing all necessary settings for the reconstruction\n\n\n    Returns\n    -------\n    ndarray\n        The reconstructed slice is a R-Z plane of a axis-symmetric\n        tomogram where Z is the symmetry axis.\n    \"\"\"\n    proj_width, proj_height = projection_filtered.shape\n    proj_center = int(proj_width / 2)\n\n    # Allocate an empty array\n    recon_slice = jnp.zeros((proj_width, proj_height), dtype=jnp.float32)\n\n    for frame_nr, angle in enumerate(angles):\n        x_rotated, y_rotated = rotate_coordinates(\n            jnp.zeros_like(config.object_xs),\n            config.object_ys,\n            jnp.radians(angle),\n        )\n        detector_cords_a, detector_cords_b = 
map_object_to_detector_coords(\n            x_rotated, y_rotated, config.object_zs, config\n        )\n        # a is independent of Z but has to match the shape of b\n        detector_cords_a = detector_cords_a[:, jnp.newaxis] * jnp.ones_like(detector_cords_b)\n        # This term is caused by the divergent cone geometry\n        ratio = (config.source_to_object_dist**2.0) / (\n            config.source_to_object_dist + x_rotated\n        ) ** 2.0\n        recon_slice = recon_slice + ratio[:, jnp.newaxis] * map_coordinates(\n            projection_filtered, [detector_cords_a, detector_cords_b], cval=0.0, order=1\n        )\n\n    return recon_slice / angles.size\n\n\ndef fdk(projection: Array, config: Config, angles: Optional[Array] = None) -> Array:\n    \"\"\"Filtered back projection algorithm as proposed by Feldkamp David\n    Kress, adapted for axisymmetry.\n\n    This implementation has been adapted for axis-symmetry by using a\n    single projection only and by only reconstructing a single R-Z slice.\n\n    This algorithm is based on:\n       https://doi.org/10.1364/JOSAA.1.000612\n    but follows the notation used by:\n    Henrik Turbell, Cone-Beam Reconstruction Using Filtered\n    Backprojection, PhD Thesis, Linkoping Studies in Science and\n    Technology\n    https://people.csail.mit.edu/bkph/courses/papers/Exact_Conebeam/Turbell_Thesis_FBP_2001.pdf\n\n    Args:\n      projection: The projection used in the reconstruction\n      config: The config object containing all necessary settings for the\n        reconstruction.\n      angles: Array of angles at which reconstruction should be computed.\n        Defaults to 0 to 359 degrees with a 1 degree step.\n\n    Returns:\n        The reconstructed slice is a R-Z plane of a axis-symmetric\n        tomogram where Z is the symmetry axis.\n    \"\"\"\n    if angles is None:\n        angles = jnp.arange(0, 360)\n\n    if not isinstance(config, Config):\n        raise ValueError(\"Only instances of Config are 
valid settings\")\n\n    if projection.ndim == 2:\n        projection_filtered = ramp_filter_and_weight(projection, config)\n    else:\n        raise ValueError(\"The projection has to be a 2D array\")\n\n    tomo = _fdk_axisym(projection_filtered, config, angles)\n    return tomo\n"
  },
  {
    "path": "scico/linop/xray/_axitom/config.py",
    "content": "\"\"\"\nThis file is a modified version of \"config.py\" from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package.\n\nConfig object and factory.\n\nThis module contains the Config class which has all the settings that are\nused during the reconstruction of the tomogram.\n\"\"\"\n\nimport numpy as np\n\n\nclass Config:\n    \"\"\"Configuration object for the forward projection.\"\"\"\n\n    def __init__(\n        self,\n        n_pixels_u: int,\n        n_pixels_v: int,\n        pixel_size_u: float,\n        pixel_size_v: float,\n        source_to_detector_dist: float,\n        source_to_object_dist: float,\n        **kwargs,\n    ):\n        \"\"\"\n        Note that invalid arguments are ignored without warning.\n\n        Args:\n            n_pixels_u: Number of pixels in the u direction of the sensor.\n            n_pixels_v: Number of pixels in the v direction of the sensor.\n            pixel_size_u: Pixel size in the u direction [mm].\n            pixel_size_v: Pixel size in the v direction [mm].\n            source_to_detector_dist: Distance between source and\n              detector [mm].\n            source_to_object_dist: Distance between source and object\n              [mm].\n        \"\"\"\n\n        self.n_pixels_u = n_pixels_u\n        self.n_pixels_v = n_pixels_v\n\n        self.pixel_size_u = pixel_size_u\n        self.pixel_size_v = pixel_size_v\n\n        self.detector_size_u = self.pixel_size_u * self.n_pixels_u\n        self.detector_size_v = self.pixel_size_v * self.n_pixels_v\n\n        self.source_to_detector_dist = source_to_detector_dist\n        self.source_to_object_dist = source_to_object_dist\n\n        # All values below are calculated\n\n        self.object_size_x = (\n            self.detector_size_u * self.source_to_object_dist / self.source_to_detector_dist\n        )\n        self.object_size_y = (\n            self.detector_size_u * self.source_to_object_dist / self.source_to_detector_dist\n        )\n  
      self.object_size_z = (\n            self.detector_size_v * self.source_to_object_dist / self.source_to_detector_dist\n        )\n\n        self.voxel_size_x = self.object_size_x / self.n_pixels_u\n        self.voxel_size_y = self.object_size_y / self.n_pixels_u\n        self.voxel_size_z = self.object_size_z / self.n_pixels_v\n\n        self.object_ys = (\n            np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0\n        ) * self.voxel_size_y\n        self.object_xs = (\n            np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0\n        ) * self.voxel_size_x\n        self.object_zs = (\n            np.arange(self.n_pixels_v, dtype=np.float32) - self.n_pixels_v / 2.0\n        ) * self.voxel_size_z\n\n        self.detector_us = (\n            np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0\n        ) * self.pixel_size_u\n        self.detector_vs = (\n            np.arange(self.n_pixels_v, dtype=np.float32) - self.n_pixels_v / 2.0\n        ) * self.pixel_size_v\n\n    def __repr__(self):\n        str = f\"Source-object distance: {self.source_to_object_dist}  \"\n        str += f\"Source-detector distance: {self.source_to_detector_dist}\\n\"\n        str += f\"Detector pixels: {self.n_pixels_u}, {self.n_pixels_v}  \"\n        str += f\"Detector size: {self.detector_size_u:.3e}, {self.detector_size_v:.3e}\\n\"\n        str += f\"Pixel size: {self.pixel_size_u:.3e}, {self.pixel_size_v:.3e}\\n\"\n        str += (\n            f\"Voxel size: {self.voxel_size_x:.3e}, {self.voxel_size_y:.3e}, \"\n            f\"{self.voxel_size_z:.3e}\"\n        )\n        return str\n\n    def with_param(self, **kwargs):\n        \"\"\"Get a clone of the object with changed parameters.\n\n        Get a clone of the object with changed parameters and all\n        calculations updated.\n\n        Args:\n          kwargs: The arguments of the config object that should be\n            changed.\n\n        Returns:\n        
  obj: Config object with modified settings.\n\n        \"\"\"\n        params = self.__dict__.copy()\n\n        for arg, value in kwargs.items():\n            params[arg] = value\n        return Config(**params)\n"
  },
  {
    "path": "scico/linop/xray/_axitom/filtering.py",
    "content": "\"\"\"\nThis file is a modified version of \"filtering.py\" from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package.\n\nFilter tools\n\nThis module contains the ramp filter and the weighting function.\n\"\"\"\n\nimport numpy as np\n\nimport jax.numpy as jnp\nimport jax.scipy.signal as sig\n\n\ndef _ramp_kernel_real(cutoff, length):\n    \"\"\"Ramp filter kernel in real space defined by the cut-off frequency\n    and the spatial dimension.\n\n    Parameters\n    ----------\n    cutoff : float\n        The cut-off frequency\n    length : int\n        The kernel filter length\n\n    Returns\n    -------\n    ndarray\n        The filter kernel\n    \"\"\"\n    pos = jnp.arange(-length, length, 1)\n    return cutoff**2.0 * (2.0 * jnp.sinc(2 * pos * cutoff) - jnp.sinc(pos * cutoff) ** 2.0)\n\n\ndef _add_weights(projection, config):\n    \"\"\"Add weights to the projection according to the ray length traveled\n    through a voxel.\n\n    Parameters\n    ----------\n    projection : jnp.ndarray\n        The projection used in the reconstruction\n    config : obj\n        The config object containing all necessary settings for the\n        reconstruction\n\n    Returns\n    -------\n    ndarray\n        The projections weighted by the ray length\n    \"\"\"\n    uu, vv = jnp.meshgrid(config.detector_vs, config.detector_us)\n\n    weights = config.source_to_detector_dist / jnp.sqrt(\n        config.source_to_detector_dist**2.0 + uu**2.0 + vv**2.0\n    )\n\n    return projection * weights\n\n\ndef ramp_filter_and_weight(projection, config):\n    \"\"\"Add weights to the projection and apply a ramp-high-pass filter\n    set to 0.5*Nyquist_frequency\n\n    Parameters\n    ----------\n    projection : jnp.ndarray\n        The projection used in the reconstruction\n    config : obj\n        The config object containing all necessary settings for the\n        reconstruction\n\n    Returns\n    -------\n    ndarray\n        The projections weighted by the 
ray length and filtered by ramp\n        filter\n    \"\"\"\n    projections_weighted = _add_weights(projection, config)\n\n    n_pixels_u, _ = np.shape(projections_weighted)\n    ramp_kernel = _ramp_kernel_real(0.5, n_pixels_u)\n\n    projections_filtered = np.zeros_like(projections_weighted)\n\n    _, n_lines = projections_weighted.shape\n\n    for j in range(n_lines):\n        projections_filtered[:, j] = sig.fftconvolve(\n            projections_weighted[:, j], ramp_kernel, mode=\"same\"\n        )\n\n    scale_factor = (\n        1.0\n        / config.pixel_size_u\n        * np.pi\n        * (config.source_to_detector_dist / config.source_to_object_dist)\n    )\n\n    return projections_filtered * scale_factor\n"
  },
  {
    "path": "scico/linop/xray/_axitom/projection.py",
    "content": "\"\"\"\nThis file is a modified version of \"projection.py\" from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package.\n\nForward projection routines.\n\nThis module contains the functions used to forward project a volume onto\na sensor plane.\n\"\"\"\n\nfrom functools import partial\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax import Array, jit\nfrom jax.scipy.ndimage import map_coordinates\n\nfrom .config import Config\n\n\n@partial(jit, static_argnames=[\"config\", \"input_2d\"])\ndef _partial_forward_project(\n    volume: Array,\n    uu: Array,\n    vv: Array,\n    irslab,\n    config: Config,\n    input_2d: bool = False,\n) -> Array:\n    \"\"\"Partial projection of a volume onto a sensor plane.\n\n    Partial projection of a cylindrically symmetric volume onto a sensor\n    plane using conical beam geometry: this functional only sums along\n    the section of the imaging direction specified by :code:`ratios`.\n\n    Args:\n      volume: The volume that will be projected onto the sensor.\n      uu: Detector grid in axis 1 direction.\n      vv: Detector grid in axis 0 direction.\n      irslab: Array of indices and ratios.\n      config: The settings object.\n      input_2d: If ``True``, the input is a 2D image from which a 3D\n        volume is constructed by rotation about the center of axis 1\n        of the image.\n\n    Returns:\n        The projection.\n    \"\"\"\n    islab = irslab[0]\n    rslab = irslab[1]\n    N = config.object_ys.size\n\n    pvs = (\n        vv[:, jnp.newaxis, :] * rslab[jnp.newaxis, :, jnp.newaxis] - config.object_zs[0]\n    ) / config.voxel_size_z\n    pys = islab[jnp.newaxis, :, jnp.newaxis] * jnp.ones_like(pvs)\n    pus = (\n        uu[:, jnp.newaxis, :] * rslab[jnp.newaxis, :, jnp.newaxis] - config.object_xs[0]\n    ) / config.voxel_size_x\n\n    if input_2d:\n        ax0c, ax1c, ax2c = ((np.array(pvs.shape) + 1) / 2 - 1).tolist()\n        ax1c = (N + 1) / 2 - 1\n        r = 
jnp.hypot(pus - ax2c, pys - ax1c)\n        ax1 = jnp.where(pys >= ax1c, ax1c + r, ax1c - r)\n        proj2d = jnp.sum(map_coordinates(volume, [pvs, ax1], cval=0.0, order=1), axis=1)\n    else:\n        proj2d = jnp.sum(map_coordinates(volume, [pvs, pys, pus], cval=0.0, order=1), axis=1)\n\n    dist = (\n        jnp.sqrt(config.source_to_detector_dist**2.0 + uu**2.0 + vv**2.0)\n        / (config.source_to_detector_dist)\n        * config.voxel_size_y\n    )\n\n    return proj2d * dist\n\n\n@partial(jit, static_argnames=[\"config\", \"num_slabs\", \"input_2d\"])\ndef forward_project(\n    volume: Array, config: Config, num_slabs: int = 8, input_2d: bool = False\n) -> Array:\n    \"\"\"Projection of a volume onto a sensor plane.\n\n    Projection of a cylindrically symmetric volume onto a sensor plane\n    using conical beam geometry.\n\n    Args:\n      volume: The volume that will be projected onto the sensor.\n      config: The settings object.\n      num_slabs: Number of slabs into which the volume should be\n        divided (for serial processing, to limit memory usage) in\n        the imaging direction.\n      input_2d: If ``True``, the input is a 2D image from which a 3D\n        volume is constructed by rotation about the center of axis 1\n        of the image.\n\n    Returns:\n        The projection.\n    \"\"\"\n    uu, vv = jnp.meshgrid(config.detector_us, config.detector_vs)\n    ratios = (config.object_ys + config.source_to_object_dist) / config.source_to_detector_dist\n    N = ratios.size\n    slab_size = N // num_slabs\n    remainder = N % num_slabs\n    islabs = jnp.stack(jnp.split(jnp.arange(0, slab_size * num_slabs), num_slabs))\n    rslabs = jnp.stack(jnp.split(ratios[0 : slab_size * num_slabs], num_slabs))\n    irslabs = jnp.stack((islabs, rslabs), axis=1)\n\n    func = lambda irslab: _partial_forward_project(\n        volume, uu, vv, irslab, config, input_2d=input_2d\n    )\n    # jax.checkpoint used to avoid excessive memory requirements\n    
proj = jnp.sum(jax.lax.map(jax.checkpoint(func), irslabs), axis=0)\n\n    if remainder:\n        irslab = jnp.stack((jnp.arange(slab_size * num_slabs, N), ratios[-remainder:]))\n        proj += jax.checkpoint(func)(irslab)\n\n    return proj\n"
  },
  {
    "path": "scico/linop/xray/_axitom/utilities.py",
    "content": "\"\"\"\nThis file is a modified version of \"utilities.py\" from the\n[AXITOM](https://github.com/PolymerGuy/AXITOM) package.\n\nUtilites\n\nThis module contains various utility functions that does not have any\nother obvious home.\n\"\"\"\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\n\ndef _find_center_of_gravity_in_projection(projection, background_internsity=0.9):\n    \"\"\"Find axis of rotation in the projection.\n    This is done by binarization of the image into object and background\n    and determining the center of gravity of the object.\n\n    Parameters\n    ----------\n    projection : ndarray\n        The projection, normalized between 0 and 1\n    background_internsity : float\n        The background intensity threshold\n\n\n    Returns\n    -------\n    float64\n        The center of gravity in the u-direction\n    float64\n        The center of gravity in the v-direction\n\n    \"\"\"\n    m, n = np.shape(projection)\n\n    binary_proj = np.zeros_like(projection, dtype=np.float)\n    binary_proj[projection < background_internsity] = 1.0\n\n    area_x = np.sum(binary_proj, axis=1)\n    area_y = np.sum(binary_proj, axis=0)\n\n    non_zero_rows = np.arange(n)[area_y != 0.0]\n    non_zero_columns = np.arange(m)[area_x != 0.0]\n\n    # Now removing all columns that does not intersect the object\n    object_pixels = binary_proj[non_zero_columns, :][:, non_zero_rows]\n    area_x = area_x[non_zero_columns]\n    area_y = area_y[non_zero_rows]\n    xs, ys = np.meshgrid(non_zero_rows, non_zero_columns)\n\n    # Determine center of gravity\n    center_of_grav_x = np.average(np.sum(xs * object_pixels, axis=1) / area_x) - n / 2.0\n    center_of_grav_y = np.average(np.sum(ys * object_pixels, axis=0) / area_y) - m / 2.0\n    return center_of_grav_x, center_of_grav_y\n\n\ndef find_center_of_rotation(projection, background_internsity=0.9, method=\"center_of_gravity\"):\n    \"\"\"Find the axis of rotation of the object in the projection\n\n    
Parameters\n    ----------\n    projection : ndarray\n        The 2D projection, normalized between 0 and 1\n    background_internsity : float\n        The background intensity threshold\n    method : string\n        The method used to locate the axis of rotation; only\n        center_of_gravity is supported\n\n\n    Returns\n    -------\n    float64\n        The center of gravity in the v-direction\n    float64\n        The center of gravity in the u-direction\n\n    Raises\n    ------\n    ValueError\n        If the projection is not 2D or the method is not supported\n\n    \"\"\"\n    if projection.ndim != 2:\n        raise ValueError(\"Invalid projection shape. It has to be a 2d numpy array\")\n\n    if method == \"center_of_gravity\":\n        center_v, center_u = _find_center_of_gravity_in_projection(\n            projection, background_internsity\n        )\n    else:\n        raise ValueError(\"Invalid method\")\n\n    return center_v, center_u\n\n\ndef rotate_coordinates(xs_array, ys_array, angle_rad):\n    \"\"\"Rotate coordinate arrays by a given angle\n\n    Parameters\n    ----------\n    xs_array : ndarray\n        Two dimensional coordinate array with x-coordinates\n    ys_array : ndarray\n        Two dimensional coordinate array with y-coordinates\n    angle_rad : float\n        Rotation angle in radians\n\n    Returns\n    -------\n    ndarray\n        The rotated x-coordinates\n    ndarray\n        The rotated y-coordinates\n\n    \"\"\"\n    rx = xs_array * jnp.cos(angle_rad) + ys_array * jnp.sin(angle_rad)\n    ry = -xs_array * jnp.sin(angle_rad) + ys_array * jnp.cos(angle_rad)\n    return rx, ry\n"
  },
  {
    "path": "scico/linop/xray/_util.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2024-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Utilities for CT data.\"\"\"\n\nfrom typing import Optional, Tuple\n\nimport numpy as np\n\nimport jax.numpy as jnp\nimport jax.scipy.spatial.transform as jsst\nfrom jax import Array\nfrom jax.image import ResizeMethod, scale_and_translate\nfrom jax.scipy.ndimage import map_coordinates\nfrom jax.typing import ArrayLike\n\ntry:\n    import scico.linop.xray.astra\n\n    have_astra = True\nexcept ModuleNotFoundError as e:\n    if e.name == \"astra\":\n        have_astra = False\n    else:\n        raise e\nimport scipy.spatial.transform as sst\n\n\ndef image_centroid(v: ArrayLike, center_offset: bool = False) -> Tuple[float, ...]:\n    \"\"\"Compute the centroid of an image.\n\n    Compute the centroid of an image or higher-dimensional array.\n\n    Args:\n        v: Array for which centroid is to be computed.\n        center_offset: If ``True``, compute centroid coordinates\n           relative to the spatial center of the image.\n\n    Returns:\n        Tuple of centroid coordinates.\n    \"\"\"\n    if center_offset:\n        offset = (jnp.array(v.shape, dtype=jnp.float32) - 1.0) / 2.0\n    else:\n        offset = jnp.zeros((v.ndim,), dtype=jnp.float32)\n    g1d = [jnp.arange(size, dtype=jnp.float32) - offset[idx] for idx, size in enumerate(v.shape)]\n    g = jnp.meshgrid(*g1d, sparse=True, indexing=\"ij\")\n    m00 = v.astype(jnp.float32).sum()\n    if m00 == 0.0:\n        c = (0.0,) * v.ndim\n    else:\n        c = tuple([(jnp.sum(v * g[idx]) / m00).item() for idx in range(v.ndim)])\n\n    return c\n\n\ndef center_image(\n    v: ArrayLike,\n    axes: Optional[Tuple[int, ...]] = None,\n    method: ResizeMethod = ResizeMethod.LANCZOS3,\n) -> Array:\n    \"\"\"Translate an 
image to center the centroid.\n\n    Translate an image (or higher-dimensional array) so that the centroid\n    is at the spatial center of the image grid.\n\n    Args:\n        v: Array to be centered.\n        axes: Array axes on which centering is to be applied. Defaults to\n          all axes.\n        method: Interpolation method for image translation.\n\n    Returns:\n        Centered array.\n    \"\"\"\n    if axes is None:\n        axes = tuple(range(v.ndim))\n    c = jnp.array(image_centroid(v, center_offset=True), dtype=jnp.float32)\n    scale = jnp.ones((v.ndim,), dtype=jnp.float32)[jnp.array(axes)]\n    trans = -c[jnp.array(axes)]\n    cv = scale_and_translate(v, v.shape, axes, scale, trans, method=method)\n    return cv\n\n\ndef rotate_volume(\n    vol: ArrayLike,\n    rot: jsst.Rotation,\n    x: Optional[ArrayLike] = None,\n    y: Optional[ArrayLike] = None,\n    z: Optional[ArrayLike] = None,\n    center: Optional[ArrayLike] = None,\n) -> Array:\n    \"\"\"Rotate a 3D array.\n\n    Rotate a 3D array as specified by an instance of\n    :class:`~jax.scipy.spatial.transform.Rotation`. 
Any axis coordinates\n    that are not specified default to a range corresponding to the size\n    of the array on that axis, starting at zero.\n\n    Args:\n        vol: Array to be rotated.\n        rot: Rotation specification.\n        x: Coordinates for :code:`x` axis (axis 0).\n        y: Coordinates for :code:`y` axis (axis 1).\n        z: Coordinates for :code:`z` axis (axis 2).\n        center: A 3-vector specifying the center of rotation.\n           Defaults to the center of the array.\n\n    Returns:\n        Rotated array.\n    \"\"\"\n    shape = vol.shape\n    if x is None:\n        x = jnp.arange(shape[0])\n    if y is None:\n        y = jnp.arange(shape[1])\n    if z is None:\n        z = jnp.arange(shape[2])\n    if center is None:\n        center = (jnp.array(shape, dtype=jnp.float32) - 1.0) / 2.0\n    gx, gy, gz = jnp.meshgrid(x - center[0], y - center[1], z - center[2], indexing=\"ij\")\n    crd = jnp.stack((gx.ravel(), gy.ravel(), gz.ravel()))\n    rot_crd = rot.as_matrix() @ crd + center[:, jnp.newaxis]  # faster than rot.apply(crd.T)\n    rot_vol = map_coordinates(vol, rot_crd.reshape((3,) + shape), order=1)\n    return rot_vol\n\n\ndef image_alignment_rotation(\n    img: ArrayLike, max_angle: float = 2.5, angle_step: float = 0.025, center_factor: float = 5e-3\n) -> float:\n    r\"\"\"Estimate an image alignment rotation.\n\n    Estimate the rotation that best aligns vertical straight lines in\n    the image with the vertical axis.\n\n    The approach is roughly based on that used in the\n    :code:`find_img_rotation_2D` function in the `cSAXS base package`\n    released by the CXS group at the Paul Scherrer Institute, which\n    finds the rotation angle that results in the sparsest column sum\n    according to the sparsity measure proposed in Sec 3.1 of\n    :cite:`hoyer-2004-nonnegative`. 
(Note that an :math:`\\ell_1` norm\n    sparsity measure is not suitable for this purpose since it is, in\n    typical cases, approximately invariant to the rotation angle.) The\n    implementation here uses the plain ratio of :math:`\\ell_1` and\n    :math:`\\ell_2` norms as a sparsity measure, more efficiently computes\n    the column sums at different angles by exploiting the 2D X-ray\n    transform, and includes a small bias for smaller angle rotations that\n    improves performance when a range of rotation angles have the same\n    sparsity measure.\n\n    Args:\n        img: Array of pixel values.\n        max_angle: Maximum angle (negative and positive) to test, in\n          degrees.\n        angle_step: Increment in angle values for range of angles to\n          test, in degrees.\n        center_factor: The absolute angle, multiplied by this scalar\n          and by the range of the sparsity measure over the tested\n          angles, is added to the sparsity measure to slightly prefer\n          smaller-angle solutions.\n\n    Returns:\n        Rotation angle (in degrees) providing best alignment with the\n        vertical (0) axis.\n\n    Raises:\n        RuntimeError: If package astra is not available.\n\n    Notes:\n        The number of detector pixels for the 2D X-ray transform\n        is chosen based on the shape :math:`(N_0, N_1)` of :code:`img`\n        and the value :math:`\\theta` of parameter :code:`max_angle`, as\n        indicated in Fig. 1.\n\n        .. figure:: /figures/img_align.svg\n           :align: center\n           :width: 40%\n\n           Fig 1. 
Calculation of the number of detector pixels for the 2D\n           X-ray transform.\n\n\n    \"\"\"\n    if not have_astra:\n        raise RuntimeError(\"Package astra is required for use of this function.\")\n    angles = np.arange(-max_angle, max_angle, angle_step)\n    max_angle_rad = max_angle * np.pi / 180\n    # choose det_count so that projected image is within the detector bounds\n    det_count = int(\n        1.05 * (img.shape[0] * np.sin(max_angle_rad) + img.shape[1] * np.cos(max_angle_rad))\n    )\n    A = scico.linop.xray.astra.XRayTransform2D(\n        img.shape,\n        det_count=det_count,\n        det_spacing=1.0,\n        angles=angles * np.pi / 180.0,\n    )\n    y = A @ jnp.abs(img)\n    # compute the ratio of the ℓ1 and ℓ2 norms of the projection for each view angle\n    cost = jnp.sum(jnp.abs(y), axis=1) / jnp.sqrt(jnp.sum(y**2, axis=1))\n    ext_cost = cost + center_factor * (cost.max() - cost.min()) * jnp.abs(angles)\n    idx = jnp.argmin(ext_cost)\n    return angles[idx]\n\n\ndef volume_alignment_rotation(\n    vol: ArrayLike,\n    xslice: Optional[int] = None,\n    yslice: Optional[int] = None,\n    max_angle: float = 2.5,\n    angle_step: float = 0.025,\n    center_factor: float = 5e-3,\n) -> jsst.Rotation:\n    r\"\"\"Estimate a volume alignment rotation.\n\n    Estimate the 3D rotation that best aligns planar structures in a\n    volume with the x-y (0-1) plane. The algorithm is based on\n    independent rotation angle estimates, obtained using\n    :func:`image_alignment_rotation`, within 2D slices in the x-z (0-2)\n    and y-z (1-2) planes. 
These estimates are integrated into a\n    combined 3D rotation specification as explained in the technical note\n    below.\n\n    Args:\n        vol: Array of voxel values.\n        xslice: Index of slice on axis 0. Defaults to the central slice.\n        yslice: Index of slice on axis 1. Defaults to the central slice.\n        max_angle: Maximum angle (negative and positive) to test, in\n          degrees.\n        angle_step: Increment in angle values for range of angles to\n          test, in degrees.\n        center_factor: The angle multiplied by this scalar is added to\n          the sparsity measure to slightly prefer smaller-angle\n          solutions.\n\n    Returns:\n        A :class:`~jax.scipy.spatial.transform.Rotation` object.\n\n    Notes:\n        The estimation of the 3D rotation required to align planar\n        structures in the volume with the x-y (0-1) plane is approached\n        by estimating the 3D normal vector to these structures, illustrated\n        in Fig. 1. The independent rotation angle estimates within the\n        x-z (0-2) and y-z (1-2) planes are exploited (after a 90°\n        rotation of each) as estimates of the projections of this\n        normal vector onto the x-z (0-2) and y-z (1-2) planes,\n        illustrated in Figs. 2 and 3 respectively.\n\n        .. figure:: /figures/vol_align_xyz.svg\n           :align: center\n           :width: 60%\n\n           Fig 1. 3D orientation of the normal to the plane that is\n           desired to be aligned with the x-y plane.\n\n\n        .. list-table::\n           :width: 100\n\n           * - .. figure:: /figures/vol_align_xz.svg\n                  :align: center\n                  :width: 100%\n\n                  Fig 2. Projection of the normal onto the x-z plane.\n\n             - .. figure:: /figures/vol_align_yz.svg\n                  :align: center\n                  :width: 100%\n\n                  Fig 3. Projection of the normal onto the y-z plane.\n\n        It can be observed from these figures that\n\n        .. math::\n\n           x &= r_x \\cos (\\theta_x) \\\\\n           y &= r_y \\cos (\\theta_y) \\\\\n           z &= r_x \\sin (\\theta_x) = r_y \\sin (\\theta_y) \\;,\n\n        where :math:`(x, y, z)` are the coordinates of the normal\n        vector. We can write\n\n        .. math::\n\n           r_x = \\frac{z}{\\sin(\\theta_x)} \\quad \\text{and} \\quad\n           r_y = \\frac{z}{\\sin(\\theta_y)} \\;,\n\n        and therefore\n\n        .. math::\n           x = z \\cot (\\theta_x) \\quad \\text{and} \\quad\n           y = z \\cot (\\theta_y) \\;.\n\n        Since :math:`(x, y, z) = z (\\cot (\\theta_x), \\cot (\\theta_y), 1)`,\n        it is clear that the choice of :math:`z` only affects the norm of\n        the vector, and can therefore be set to unity. The rotation of\n        this vector is then determined by computing the rotation required\n        to align it (after normalization) with the :math:`z` axis\n        :math:`(0, 0, 1)`.\n    \"\"\"\n    # x, y, z volume axes correspond to axes 0, 1, 2\n    if xslice is None:\n        xslice = vol.shape[0] // 2  # default to central slice\n    if yslice is None:\n        yslice = vol.shape[1] // 2  # default to central slice\n    # projected angles of normal to plane angles identified in yz and xz slices\n    angle_y = (\n        (90 - image_alignment_rotation(vol[xslice], max_angle=max_angle, angle_step=angle_step))\n        * np.pi\n        / 180\n    )\n    angle_x = (\n        (90 - image_alignment_rotation(vol[:, yslice], max_angle=max_angle, angle_step=angle_step))\n        * np.pi\n        / 180\n    )\n    # unit vector normal to plane\n    vec = np.array([1.0 / np.tan(angle_x), 1.0 / np.tan(angle_y), 1.0])\n    vec /= np.linalg.norm(vec)\n    # rotation required to align unit vector with z axis\n    r = sst.Rotation.align_vectors(vec, np.array([0, 0, 1]))[0]\n    # jax.scipy.spatial.transform.Rotation does not have align_vectors method\n    return jsst.Rotation.from_quat(r.as_quat())\n"
  },
  {
    "path": "scico/linop/xray/_xray.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"X-ray transform classes.\"\"\"\n\nfrom functools import partial\nfrom typing import Optional, Tuple\nfrom warnings import warn\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.typing import ArrayLike\n\nimport scico.numpy as snp\nfrom scico.numpy.util import is_scalar_equiv\nfrom scico.typing import Shape\nfrom scipy.spatial.transform import Rotation\n\nfrom .._linop import LinearOperator\n\n\nclass XRayTransform2D(LinearOperator):\n    r\"\"\"Parallel ray, single axis, 2D X-ray projector.\n\n    This implementation approximates the projection of each rectangular\n    pixel as a boxcar function (whereas the exact projection is a\n    trapezoid). Detector pixels are modeled as bins (rather than points)\n    and this approximation allows fast calculation of the contribution\n    of each pixel to each bin because the integral of the boxcar is\n    simple.\n\n    By requiring the width of a projected pixel to be less than or equal\n    to the bin width (which is defined to be 1.0), we ensure that\n    each pixel contributes to at most two bins, which accelerates the\n    accumulation of pixel values into bins (equivalently, makes the\n    linear operator sparse).\n\n    Warning: The default pixel spacing is :math:`\\sqrt{2}/2` (rather\n    than 1) in order to satisfy the aforementioned spacing requirement.\n\n    `x0`, `dx`, and `y0` should be expressed in units such that the\n    detector spacing `dy` is 1.0.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        angles: ArrayLike,\n        x0: Optional[ArrayLike] = None,\n        dx: Optional[ArrayLike] = None,\n        y0: Optional[float] = None,\n        det_count: 
Optional[int] = None,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            angles: (num_angles,) array of angles in radians. Viewing an\n                (M, N) array as a matrix with M rows and N columns, an\n                angle of 0 corresponds to summing rows, an angle of pi/2\n                corresponds to summing columns, and an angle of pi/4\n                corresponds to summing along antidiagonals.\n            x0: (x, y) position of the corner of the pixel `im[0,0]`. By\n                default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.\n            dx: Image pixel side length in x- and y-direction (axis 0 and\n                1 respectively). Must be set so that the width of a\n                projected pixel is never larger than 1.0. By default,\n                [:math:`\\sqrt{2}/2`, :math:`\\sqrt{2}/2`].\n            y0: Location of the edge of the first detector bin. By\n                default, `-det_count / 2`\n            det_count: Number of elements in detector. 
If ``None``,\n                defaults to the size of the diagonal of `input_shape`.\n        \"\"\"\n        self.input_shape = input_shape\n        self.angles = angles\n\n        self.nx = tuple(input_shape)\n        if dx is None:\n            dx = 2 * (np.sqrt(2) / 2,)\n        if is_scalar_equiv(dx):\n            dx = 2 * (dx,)\n        self.dx = dx\n\n        # check projected pixel width assumption\n        Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles)))\n        Pdiag1 = np.abs(Pdx[0] + Pdx[1])\n        Pdiag2 = np.abs(Pdx[0] - Pdx[1])\n        max_width: float = np.max(np.maximum(Pdiag1, Pdiag2))\n\n        if max_width > 1:\n            warn(\n                f\"A projected pixel has width {max_width} > 1.0, \"\n                \"which will reduce projector accuracy.\"\n            )\n\n        if x0 is None:\n            x0 = -(np.array(self.nx) * self.dx) / 2\n        self.x0 = x0\n\n        if det_count is None:\n            det_count = int(np.ceil(np.linalg.norm(input_shape)))\n        self.det_count = det_count\n        self.ny = det_count\n        self.output_shape = (len(angles), det_count)\n\n        if y0 is None:\n            y0 = -self.ny / 2\n        self.y0 = y0\n        self.dy = 1.0\n\n        self.fbp_filter: Optional[snp.Array] = None\n        self.fbp_mask: Optional[snp.Array] = None\n\n        super().__init__(\n            input_shape=self.input_shape,\n            input_dtype=np.float32,\n            output_shape=self.output_shape,\n            output_dtype=np.float32,\n            eval_fn=self.project,\n            adj_fn=self.back_project,\n        )\n\n    def project(self, im: ArrayLike) -> snp.Array:\n        \"\"\"Compute X-ray projection, equivalent to `H @ im`.\n\n        Args:\n            im: Input array representing the image to project.\n        \"\"\"\n        return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)\n\n    def back_project(self, y: ArrayLike) -> snp.Array:\n 
       \"\"\"Compute X-ray back projection, equivalent to `H.T @ y`.\n\n        Args:\n            y: Input array representing the sinogram to back project.\n        \"\"\"\n        return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)\n\n    def fbp(self, y: ArrayLike) -> snp.Array:\n        r\"\"\"Compute filtered back projection (FBP) inverse of projection.\n\n        Compute the filtered back projection inverse by filtering each\n        row of the sinogram with the filter defined in (61) in\n        :cite:`kak-1988-principles` and then back projecting. The\n        projection angles are assumed to be evenly spaced in\n        :math:`[0, \\pi)`; reconstruction quality may be poor if\n        this assumption is violated. Poor quality reconstructions should\n        also be expected when `dx[0]` and `dx[1]` are not equal.\n\n        Args:\n            y: Input projection, (num_angles, N).\n\n        Returns:\n            FBP inverse of projection.\n        \"\"\"\n        N = y.shape[1]\n\n        if self.fbp_filter is None:\n            nvec = jnp.arange(N) - (N - 1) // 2\n            self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)\n\n        if self.fbp_mask is None:\n            unit_sino = jnp.ones(self.output_shape, dtype=np.float32)\n            # Threshold is multiplied by 0.99... 
fudge factor to account for numerical errors\n            # in back projection.\n            self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5))  # type: ignore\n\n        # Apply ramp filter in the frequency domain, padding to avoid\n        # boundary effects\n        h = self.fbp_filter\n        hf = jnp.fft.fft(h, n=2 * N - 1, axis=1)\n        yf = jnp.fft.fft(y, n=2 * N - 1, axis=1)\n        hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[\n            :, (N - 1) // 2 : -(N - 1) // 2\n        ].real.astype(jnp.float32)\n\n        x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy)  # type: ignore\n        return x\n\n    @staticmethod\n    def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array:\n        \"\"\"Compute coefficients of ramp filter used in FBP.\n\n        Compute coefficients of ramp filter used in FBP, as defined in\n        (61) in :cite:`kak-1988-principles`.\n\n        Args:\n            x: Integer sample indices, relative to the filter center, at\n                which to compute filter coefficients.\n            tau: Sampling interval.\n\n        Returns:\n            Spatial-domain coefficients of ramp filter.\n        \"\"\"\n        # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0)\n        # is included to avoid division by zero when x == 0\n        # since jnp.where evaluates all values for both True and False\n        # branches.\n        return jnp.where(\n            x == 0,\n            1.0 / (4.0 * tau**2),\n            jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0),\n        )\n\n    @staticmethod\n    @partial(jax.jit, static_argnames=[\"ny\"])\n    def _project(\n        im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike\n    ) -> snp.Array:\n        r\"\"\"Compute X-ray projection.\n\n        Args:\n            im: Input array, (M, N).\n            x0: (x, y) position of the corner of the pixel im[0,0].\n            dx: Pixel side length in x- 
and y-direction. Units are such\n                that the detector bins have length 1.0.\n            y0: Location of the edge of the first detector bin.\n            ny: Number of detector bins.\n            angles: (num_angles,) array of angles in radians. Pixels are\n                projected onto unit vectors pointing in these directions.\n        \"\"\"\n        nx = im.shape\n        inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)\n\n        # avoid incompatible types in the .add (scatter operation)\n        weights = weights.astype(im.dtype)\n\n        # Handle out of bounds indices by setting weight to zero\n        weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)\n        y = (\n            jnp.zeros((len(angles), ny), dtype=im.dtype)\n            .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]\n            .add(im * weights_valid)\n        )\n\n        weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)\n        y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid)\n\n        return y\n\n    @staticmethod\n    @partial(jax.jit, static_argnames=[\"nx\"])\n    def _back_project(\n        y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike\n    ) -> snp.Array:\n        r\"\"\"Compute X-ray back projection.\n\n        Args:\n            y: Input projection, (num_angles, N).\n            x0: (x, y) position of the corner of the pixel im[0,0].\n            dx: Pixel side length in x- and y-direction. Units are such\n                that the detector bins have length 1.0.\n            nx: Shape of back projection.\n            y0: Location of the edge of the first detector bin.\n            angles: (num_angles,) array of angles in radians. 
Pixels are\n                projected onto units vectors pointing in these directions.\n        \"\"\"\n        ny = y.shape[1]\n        inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)\n        # Handle out of bounds indices by setting weight to zero\n        weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)\n\n        # the idea: [y[0, inds[0]], y[1, inds[1]], ...]\n        HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0)\n\n        weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)\n        HTy = HTy + jnp.sum(\n            y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0\n        )\n\n        return HTy.astype(jnp.float32)\n\n    @staticmethod\n    @partial(jax.jit, static_argnames=[\"nx\"])\n    @partial(jax.vmap, in_axes=(None, None, None, 0, None))\n    def _calc_weights(\n        x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float\n    ) -> Tuple[snp.Array, snp.Array]:\n        \"\"\"\n\n        Args:\n            x0: Location of the corner of the pixel im[0,0].\n            dx: Pixel side length in x- and y-direction. Units are such\n                that the detector bins have length 1.0.\n            nx: Input image shape.\n            angles: (num_angles,) array of angles in radians. 
Pixels are\n                projected onto units vectors pointing in these directions.\n                (This argument is `vmap`ed.)\n            y0: Location of the edge of the first detector bin.\n        \"\"\"\n        u = [jnp.cos(angles), jnp.sin(angles)]\n        Px0 = x0[0] * u[0] + x0[1] * u[1] - y0\n        Pdx = [dx[0] * u[0], dx[1] * u[1]]\n        Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]]))\n\n        Px = (\n            Pxmin\n            + Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1)\n            + Pdx[1] * jnp.arange(nx[1]).reshape(1, -1)\n        )\n\n        # detector bin inds\n        inds = jnp.floor(Px).astype(int)\n\n        # weights\n        Pdx = jnp.array(u) * jnp.array(dx)\n        diag1 = jnp.abs(Pdx[0] + Pdx[1])\n        diag2 = jnp.abs(Pdx[0] - Pdx[1])\n        w = jnp.max(jnp.array([diag1, diag2]))\n        f = jnp.min(jnp.array([diag1, diag2]))\n\n        width = (w + f) / 2\n        distance_to_next = 1 - (Px - inds)  # always in (0, 1]\n        weights = jnp.minimum(distance_to_next, width) / width\n\n        return inds, weights\n\n\nclass XRayTransform3D(LinearOperator):\n    r\"\"\"General-purpose, 3D, parallel ray X-ray projector.\n\n    This projector approximates cubic voxels projecting onto\n    rectangular pixels and provides a back projector that is the exact\n    adjoint of the forward projector. It is written purely in JAX,\n    allowing it to run on either CPU or GPU and minimizing host copies.\n\n    Warning: This class is experimental and may be up to ten times slower\n    than :class:`scico.linop.xray.astra.XRayTransform3D`.\n\n    For each view, the projection geometry is specified by an array\n    with shape (2, 4) that specifies a :math:`2 \\times 3` projection\n    matrix and a :math:`2 \\times 1` offset vector. 
Denoting the matrix\n    by :math:`\\mathbf{M}` and the offset by :math:`\\mathbf{t}`, a voxel at array\n    index `(i, j, k)` has its center projected to the detector coordinates\n\n    .. math::\n        \\mathbf{M} \\begin{bmatrix}\n        i + \\frac{1}{2} \\\\ j + \\frac{1}{2} \\\\ k + \\frac{1}{2}\n        \\end{bmatrix} + \\mathbf{t} \\,.\n\n    The detector pixel at index `(i, j)` covers detector coordinates\n    :math:`[i+1) \\times [j+1)`.\n\n    :meth:`XRayTransform3D.matrices_from_euler_angles` can help to\n    make these geometry arrays.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        matrices: ArrayLike,\n        det_shape: Shape,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input image.\n            matrices: (num_views, 2, 4) array of homogeneous projection matrices.\n            det_shape: Shape of detector.\n        \"\"\"\n\n        self.input_shape: Shape = input_shape\n        self.matrices = jnp.asarray(matrices, dtype=np.float32)\n        self.det_shape = det_shape\n        self.output_shape = (len(matrices), *det_shape)\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=self.output_shape,\n            eval_fn=self.project,\n            adj_fn=self.back_project,\n        )\n\n    def project(self, im: ArrayLike) -> snp.Array:\n        \"\"\"Compute X-ray projection.\"\"\"\n        return XRayTransform3D._project(im, self.matrices, self.det_shape)\n\n    def back_project(self, proj: ArrayLike) -> snp.Array:\n        \"\"\"Compute X-ray back projection\"\"\"\n        return XRayTransform3D._back_project(proj, self.matrices, self.input_shape)\n\n    @staticmethod\n    def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array:\n        r\"\"\"\n        Args:\n            im: Input image.\n            matrix: (num_views, 2, 4) array of homogeneous projection matrices.\n            det_shape: Shape of detector.\n        
\"\"\"\n        MAX_SLICE_LEN = 10\n        slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN))\n\n        num_views = len(matrices)\n        proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype)\n        for view_ind, matrix in enumerate(matrices):\n            for slice_offset in slice_offsets:\n                proj = proj.at[view_ind].set(\n                    XRayTransform3D._project_single(\n                        im[slice_offset : slice_offset + MAX_SLICE_LEN],\n                        matrix,\n                        proj[view_ind],\n                        slice_offset=slice_offset,\n                    )\n                )\n        return proj\n\n    @staticmethod\n    @partial(jax.jit, donate_argnames=\"proj\")\n    def _project_single(\n        im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0\n    ) -> snp.Array:\n        r\"\"\"\n        Args:\n            im: Input image.\n            matrix: (2, 4) homogeneous projection matrix.\n            det_shape: Shape of detector.\n        \"\"\"\n\n        ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(\n            im.shape, matrix, proj.shape, slice_offset\n        )\n        proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode=\"drop\")\n        proj = proj.at[ul_ind[0] + 1, ul_ind[1]].add(ur_weight * im, mode=\"drop\")\n        proj = proj.at[ul_ind[0], ul_ind[1] + 1].add(ll_weight * im, mode=\"drop\")\n        proj = proj.at[ul_ind[0] + 1, ul_ind[1] + 1].add(lr_weight * im, mode=\"drop\")\n        return proj\n\n    @staticmethod\n    def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:\n        r\"\"\"\n        Args:\n            proj: Input (set of) projection(s).\n            matrix: (num_views, 2, 4) array of homogeneous projection matrices.\n            input_shape: Shape of desired back projection.\n        \"\"\"\n        MAX_SLICE_LEN = 10\n        slice_offsets = list(range(0, 
input_shape[0], MAX_SLICE_LEN))\n\n        HTy = jnp.zeros(input_shape, dtype=proj.dtype)\n        for view_ind, matrix in enumerate(matrices):\n            for slice_offset in slice_offsets:\n                HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set(\n                    XRayTransform3D._back_project_single(\n                        proj[view_ind],\n                        matrix,\n                        HTy[slice_offset : slice_offset + MAX_SLICE_LEN],\n                        slice_offset=slice_offset,\n                    )\n                )\n                HTy.block_until_ready()  # prevent OOM\n\n        return HTy\n\n    @staticmethod\n    @partial(jax.jit, donate_argnames=\"HTy\")\n    def _back_project_single(\n        y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0\n    ) -> snp.Array:\n        ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(\n            HTy.shape, matrix, y.shape, slice_offset\n        )\n        HTy = HTy + y[ul_ind[0], ul_ind[1]] * ul_weight\n        HTy = HTy + y[ul_ind[0] + 1, ul_ind[1]] * ur_weight\n        HTy = HTy + y[ul_ind[0], ul_ind[1] + 1] * ll_weight\n        HTy = HTy + y[ul_ind[0] + 1, ul_ind[1] + 1] * lr_weight\n        return HTy\n\n    @staticmethod\n    def _calc_weights(\n        input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0\n    ) -> snp.Array:\n        # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)\n        x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5  # (3, ...)\n        x = x.at[0].add(slice_offset)\n\n        Px = jnp.stack(\n            (\n                matrix[0, 0] * x[0] + matrix[0, 1] * x[1] + matrix[0, 2] * x[2] + matrix[0, 3],\n                matrix[1, 0] * x[0] + matrix[1, 1] * x[1] + matrix[1, 2] * x[2] + matrix[1, 3],\n            )\n        )  # (2, ...)\n\n        # calculate weight on 4 intersecting pixels\n        w = 0.5  # assumed <= 1.0\n        
left_edge = Px - w / 2\n        to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w)\n        ul_ind = jnp.floor(left_edge).astype(\"int32\")\n\n        ul_weight = to_next[0] * to_next[1] * (1 / w**2)\n        ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2)\n        ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2)\n        lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2)\n\n        # set weights to zero out of bounds\n        ul_weight = jnp.where(\n            (ul_ind[0] >= 0)\n            * (ul_ind[0] < det_shape[0])\n            * (ul_ind[1] >= 0)\n            * (ul_ind[1] < det_shape[1]),\n            ul_weight,\n            0.0,\n        )\n        ur_weight = jnp.where(\n            (ul_ind[0] + 1 >= 0)\n            * (ul_ind[0] + 1 < det_shape[0])\n            * (ul_ind[1] >= 0)\n            * (ul_ind[1] < det_shape[1]),\n            ur_weight,\n            0.0,\n        )\n        ll_weight = jnp.where(\n            (ul_ind[0] >= 0)\n            * (ul_ind[0] < det_shape[0])\n            * (ul_ind[1] + 1 >= 0)\n            * (ul_ind[1] + 1 < det_shape[1]),\n            ll_weight,\n            0.0,\n        )\n        lr_weight = jnp.where(\n            (ul_ind[0] + 1 >= 0)\n            * (ul_ind[0] + 1 < det_shape[0])\n            * (ul_ind[1] + 1 >= 0)\n            * (ul_ind[1] + 1 < det_shape[1]),\n            lr_weight,\n            0.0,\n        )\n\n        return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight\n\n    @staticmethod\n    def matrices_from_euler_angles(\n        input_shape: Shape,\n        output_shape: Shape,\n        seq: str,\n        angles: ArrayLike,\n        degrees: bool = False,\n        voxel_spacing: ArrayLike = None,\n        det_spacing: ArrayLike = None,\n    ) -> snp.Array:\n        \"\"\"\n        Create a set of projection matrices from Euler angles. 
The\n        input voxels will undergo the specified rotation and then be\n        projected onto the global xy-plane.\n\n        Args:\n            input_shape: Shape of input image.\n            output_shape: Shape of output (detector).\n            str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'}\n                for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and\n                intrinsic rotations cannot be mixed in one function call.\n            angles: (num_views, N), N = 1, 2, or 3 Euler angles.\n            degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians.\n            voxel_spacing: (3,) array giving the spacing of image\n                voxels.  Default: `[1.0, 1.0, 1.0]`. Experimental.\n            det_spacing: (2,) array giving the spacing of detector\n                pixels.  Default: `[1.0, 1.0]`. Experimental.\n\n\n        Returns:\n            (num_views, 2, 4) array of homogeneous projection matrices.\n        \"\"\"\n\n        if voxel_spacing is None:\n            voxel_spacing = np.ones(3)\n\n        if det_spacing is None:\n            det_spacing = np.ones(2)\n\n        # make projection matrix: form a rotation matrix and chop off the last row\n        matrices = Rotation.from_euler(seq, angles, degrees=degrees).as_matrix()\n        matrices = matrices[:, :2, :]  # (num_views, 2, 3)\n\n        # handle scaling\n        M_voxel = np.diag(voxel_spacing)  # (3, 3)\n        M_det = np.diag(1 / np.array(det_spacing))  # (2, 2)\n\n        # idea: M_det * M * M_voxel, but with a leading batch dimension\n        matrices = np.einsum(\"vmn,nn->vmn\", matrices, M_voxel)\n        matrices = np.einsum(\"mm,vmn->vmn\", M_det, matrices)\n\n        # add translation to line up the centers\n        x0 = np.array(input_shape) / 2\n        t = -np.einsum(\"vmn,n->vm\", matrices, x0) + np.array(output_shape) / 2\n        matrices = 
snp.concatenate((matrices, t[..., np.newaxis]), axis=2)\n\n        return matrices\n"
  },
  {
    "path": "scico/linop/xray/abel.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Abel transform LinearOperator wrapping the pyabel package.\n\nAbel transform LinearOperator wrapping the\n`pyabel <https://github.com/PyAbel/PyAbel>`_ package.\n\"\"\"\n\nimport math\nfrom typing import Optional\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nimport jax.numpy.fft as jnfft\n\nimport abel\n\nfrom scico.linop import LinearOperator\nfrom scico.typing import Shape\nfrom scipy.linalg import solve_triangular\n\n\nclass AbelTransform(LinearOperator):\n    r\"\"\"Abel transform based on `PyAbel <https://github.com/PyAbel/PyAbel>`_.\n\n    Perform Abel transform (parallel beam projection of cylindrically\n    symmetric objects) for a 2D image. The input 2D image is assumed to\n    be centered and left-right symmetric.\n    \"\"\"\n\n    def __init__(self, img_shape: Shape):\n        \"\"\"\n        Args:\n            img_shape: Shape of the input image.\n        \"\"\"\n        self.proj_mat_quad = _pyabel_daun_get_proj_matrix(img_shape)\n\n        super().__init__(\n            input_shape=img_shape,\n            output_shape=img_shape,\n            input_dtype=np.float32,\n            output_dtype=np.float32,\n            adj_fn=self._adj,\n            jit=True,\n        )\n\n    def _eval(self, x: jax.Array) -> jax.Array:\n        return _pyabel_transform(x, direction=\"forward\", proj_mat_quad=self.proj_mat_quad).astype(\n            self.output_dtype\n        )\n\n    def _adj(self, x: jax.Array) -> jax.Array:  # type: ignore\n        return _pyabel_transform(x, direction=\"transpose\", proj_mat_quad=self.proj_mat_quad).astype(\n            self.input_dtype\n        )\n\n    def inverse(self, y: jax.Array) -> jax.Array:\n        \"\"\"Perform inverse Abel transform.\n\n        Args:\n            y: Input image (assumed to be a result of an Abel transform).\n\n        Returns:\n            Output of inverse Abel transform.\n        \"\"\"\n        return _pyabel_transform(y, direction=\"inverse\", proj_mat_quad=self.proj_mat_quad).astype(\n            self.input_dtype\n        )\n\n\ndef _pyabel_transform(\n    x: jax.Array, direction: str, proj_mat_quad: jax.Array, symmetry_axis: Optional[list] = None\n) -> jax.Array:\n    \"\"\"Apply Abel transforms (forward, inverse and transposed).\n\n    This function contains code copied from `PyAbel <https://github.com/PyAbel/PyAbel>`_.\n\n    Args:\n        x: Input image.\n        direction: Transform to apply; one of \"forward\", \"transpose\"\n            or \"inverse\".\n        proj_mat_quad: Single-quadrant projection matrix.\n        symmetry_axis: Symmetry axis specification passed to the PyAbel\n            quadrant functions. If ``None``, ``[None]`` is used.\n\n    Returns:\n        Transformed image, reassembled from its quadrants using the\n        shape of `x`.\n\n    Raises:\n        ValueError: If `direction` is not a supported value.\n    \"\"\"\n\n    if symmetry_axis is None:\n        symmetry_axis = [None]\n\n    Q0, Q1, Q2, Q3 = get_image_quadrants(\n        x, symmetry_axis=symmetry_axis, use_quadrants=(True, True, True, True)\n    )\n\n    def transform_quad(data):\n        # Apply the requested linear map to a single image quadrant.\n        if direction == \"forward\":\n            return data.dot(proj_mat_quad)\n        elif direction == \"transpose\":\n            return data.dot(proj_mat_quad.T)\n        elif direction == \"inverse\":\n            return solve_triangular(proj_mat_quad.T, data.T).T\n        else:\n            raise ValueError(f\"Unsupported direction '{direction}'.\")\n\n    AQ0 = AQ1 = AQ2 = AQ3 = None\n    AQ1 = transform_quad(Q1)\n\n    if 1 not in symmetry_axis:\n        AQ2 = transform_quad(Q2)\n\n    if 0 not in symmetry_axis:\n        AQ0 = transform_quad(Q0)\n\n    if None in symmetry_axis:\n        AQ3 = transform_quad(Q3)\n\n    return put_image_quadrants(\n        (AQ0, AQ1, AQ2, AQ3), original_image_shape=x.shape, symmetry_axis=symmetry_axis\n    )\n\n\ndef _pyabel_daun_get_proj_matrix(img_shape: Shape) -> jax.Array:\n    \"\"\"Get single-quadrant projection matrix.\"\"\"\n    proj_matrix = abel.daun.get_bs_cached(\n        math.ceil(img_shape[1] / 2),\n        degree=0,\n        reg_type=None,\n        strength=0,\n        direction=\"forward\",\n        verbose=False,\n    )\n    return jnp.array(proj_matrix)\n\n\n# Read abel.tools.symmetry module into a string.\nmod_file = abel.tools.symmetry.__file__\nwith open(mod_file, \"r\") as f:\n    mod_str = f.read()\n\n# Replace numpy functions that touch the main arrays with corresponding jax.numpy functions\nmod_str = mod_str.replace(\"fftpack.\", \"jnfft.\")\nmod_str = mod_str.replace(\"np.atleast_2d\", \"jnp.atleast_2d\")\nmod_str = mod_str.replace(\"np.flip\", \"jnp.flip\")\nmod_str = mod_str.replace(\"np.concat\", \"jnp.concat\")\n\n# Exec the modified module source and extract the required functions from the exec scope\nscope = {\"jnp\": jnp, \"jnfft\": jnfft}\nexec(mod_str, scope)\nget_image_quadrants = scope[\"get_image_quadrants\"]\nput_image_quadrants = scope[\"put_image_quadrants\"]\n"
  },
  {
    "path": "scico/linop/xray/astra.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"X-ray transform LinearOperators wrapping the ASTRA toolbox.\n\nX-ray transform :class:`.LinearOperator` wrapping the parallel beam\nprojections in the\n`ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.\nThis package provides both C and CUDA implementations of core\nfunctionality, but note that use of the CUDA/GPU implementation is\nexpected to result in GPU-host-GPU memory copies when transferring\nJAX arrays. Other JAX features such as automatic differentiation are\nnot available.\n\nFunctions here refer to three coordinate systems: world coordinates,\nvolume coordinates, and detector coordinates. World coordinates are 3D\ncoordinates representing a point in physical space. Volume coordinates\nrefer to a position in the reconstruction volume, where the voxel with\nits intensity value stored at `vol[i, j, k]` has its center at volume\ncoordinate (i+0.5, j+0.5, k+0.5) and side lengths of 1. 
Detector\ncoordinates refer to a position on the detector array, and the pixel at\n`det[i, j]` has its center at detector coordinates (i+0.5, j+0.5) and\nside lengths of one.\n\n\"\"\"\n\nfrom typing import List, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport numpy.typing\n\nimport jax\nfrom jax.typing import ArrayLike\n\nfrom scipy.spatial.transform import Rotation\n\ntry:\n    import astra\nexcept ModuleNotFoundError as e:\n    if e.name == \"astra\":\n        new_e = ModuleNotFoundError(\"Could not import astra; please install the ASTRA toolbox.\")\n        new_e.name = \"astra\"\n        raise new_e from e\n    else:\n        raise e\n\ntry:\n    from collections import Iterable  # type: ignore\nexcept ImportError:\n    import collections\n\n    # Monkey patching required because latest astra release uses old module path for Iterable\n    collections.Iterable = collections.abc.Iterable  # type: ignore\n\nfrom scico.linop import LinearOperator\nfrom scico.typing import Shape, TypeAlias\n\nVolumeGeometry: TypeAlias = dict\nProjectionGeometry: TypeAlias = dict\n\n\ndef set_astra_gpu_index(idx: Union[int, Sequence[int]]):\n    \"\"\"Set the index/indices of GPU(s) to be used by astra.\n\n    Args:\n        idx: Index or indices of GPU(s).\n    \"\"\"\n    astra.set_gpu_index(idx)\n\n\ndef _project_coords(\n    x_volume: np.ndarray, vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry\n) -> np.ndarray:\n    \"\"\"\n    Project volume coordinates into detector coordinates based on ASTRA\n    geometry objects.\n\n    Args:\n        x_volume: (..., 3) vector(s) of volume coordinates.\n        vol_geom: ASTRA volume geometry object.\n        proj_geom: ASTRA projection geometry object.\n\n    Returns:\n        (num_angles, ..., 2) array of detector coordinates corresponding\n        to projections of the points in `x_volume`.\n\n    \"\"\"\n    det_shape = (proj_geom[\"DetectorRowCount\"], proj_geom[\"DetectorColCount\"])\n    x_world = 
volume_coords_to_world_coords(x_volume, vol_geom=vol_geom)\n    x_dets = []\n    for vec in proj_geom[\"Vectors\"]:\n        ray, d, u, v = vec[0:3], vec[3:6], vec[6:9], vec[9:12]\n        x_det = project_world_coordinates(x_world, ray, d, u, v, det_shape)\n        x_dets.append(x_det)\n\n    return np.stack(x_dets)\n\n\ndef project_world_coordinates(\n    x: np.ndarray,\n    ray: np.typing.ArrayLike,\n    d: np.typing.ArrayLike,\n    u: np.typing.ArrayLike,\n    v: np.typing.ArrayLike,\n    det_shape: Sequence[int],\n) -> np.ndarray:\n    \"\"\"Project world coordinates along ray into the specified basis.\n\n    Project world coordinates along `ray` into the basis described by `u`\n    and `v` with center `d`.\n\n    Args:\n        x: (..., 3) vector(s) of world coordinates.\n        ray: (3,) ray direction\n        d: (3,) center of the detector\n        u: (3,) vector from detector pixel (0,0) to (0,1), columns, x\n        v: (3,) vector from detector pixel (0,0) to (1,0), rows, y\n\n    Returns:\n        (..., 2) vector(s) in the detector coordinates\n\n    \"\"\"\n    Phi = np.stack((ray, u, v), axis=1)\n    x = x - d  # express with respect to detector center\n    alpha = np.linalg.pinv(Phi) @ x[..., :, np.newaxis]  # (3,3) times <stack of> (3,1)\n    alpha = alpha[..., 0]  # squash from (..., 3, 1) to (..., 3)\n    Palpha = alpha[..., 1:]  # throw away ray coordinate\n    det_center_idx = (\n        np.array(det_shape)[::-1] / 2 - 0.5\n    )  # center of length-2 is index 0.5, length-3 -> index 1\n    ind_xy = Palpha + det_center_idx\n    ind_ij = ind_xy[..., ::-1]\n    return ind_ij\n\n\ndef volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:\n    \"\"\"Convert a volume coordinate into a world coordinate.\n\n    Convert a volume coordinate into a world coordinate using ASTRA\n    conventions.\n\n    Args:\n        idx: (..., 2) or (..., 3) vector(s) of index coordinates.\n        vol_geom: ASTRA volume geometry 
object.\n\n    Returns:\n        (..., 2) or (..., 3) vector(s) of world coordinates.\n\n    \"\"\"\n    if \"GridSliceCount\" not in vol_geom:\n        return _volume_index_to_astra_world_2d(idx, vol_geom)\n\n    return _volume_index_to_astra_world_3d(idx, vol_geom)\n\n\ndef _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:\n    \"\"\"Convert a 2D volume coordinate into a 2D world coordinate.\"\"\"\n    coord = idx[..., [1, 0]]  # x:col, y:row,\n    nx = np.array(  # (x, y) order\n        (\n            vol_geom[\"GridColCount\"],\n            vol_geom[\"GridRowCount\"],\n        )\n    )\n    opt = vol_geom[\"option\"]\n    dx = np.array(\n        (\n            (opt[\"WindowMaxX\"] - opt[\"WindowMinX\"]) / nx[0],\n            (opt[\"WindowMaxY\"] - opt[\"WindowMinY\"]) / nx[1],\n        )\n    )\n    center_coord = nx / 2 - 0.5  # center of length-2 is index 0.5, center of length-3 is index 1\n    return (coord - center_coord) * dx\n\n\ndef _volume_index_to_astra_world_3d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray:\n    \"\"\"Convert a 3D volume coordinate into a 3D world coordinate.\"\"\"\n    coord = idx[..., [2, 1, 0]]  # x:col, y:row, z:slice\n    nx = np.array(  # (x, y, z) order\n        (\n            vol_geom[\"GridColCount\"],\n            vol_geom[\"GridRowCount\"],\n            vol_geom[\"GridSliceCount\"],\n        )\n    )\n    opt = vol_geom[\"option\"]\n    dx = np.array(\n        (\n            (opt[\"WindowMaxX\"] - opt[\"WindowMinX\"]) / nx[0],\n            (opt[\"WindowMaxY\"] - opt[\"WindowMinY\"]) / nx[1],\n            (opt[\"WindowMaxZ\"] - opt[\"WindowMinZ\"]) / nx[2],\n        )\n    )\n    center_coord = nx / 2 - 0.5  # center of length-2 is index 0.5, center of length-3 is index 1\n    return (coord - center_coord) * dx\n\n\nclass XRayTransform2D(LinearOperator):\n    r\"\"\"2D parallel beam X-ray transform based on the ASTRA toolbox.\n\n    Perform tomographic projection (also 
called X-ray projection) of an\n    image at specified angles, using the\n    `ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        det_count: int,\n        det_spacing: float,\n        angles: np.ndarray,\n        volume_geometry: Optional[List[float]] = None,\n        device: str = \"auto\",\n    ):\n        \"\"\"\n        Args:\n            input_shape: Shape of the input array.\n            det_count: Number of detector elements. See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries>`__\n               for more information.\n            det_spacing: Spacing between detector elements. See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries>`__\n               for more information..\n            angles: Array of projection angles in radians.\n            volume_geometry: Specification of the shape of the\n               discretized reconstruction volume. Must either ``None``,\n               in which case it is inferred from `input_shape`, or\n               follow the syntax described in the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom2d.html#volume-geometries>`__.\n            device: Specifies device for projection operation.\n               One of [\"auto\", \"gpu\", \"cpu\"]. 
If \"auto\", a GPU is used if\n               available, otherwise, the CPU is used.\n        \"\"\"\n\n        self.num_dims = len(input_shape)\n        if self.num_dims != 2:\n            raise ValueError(\n                f\"Only 2D projections are supported, but 'input_shape' is {input_shape}.\"\n            )\n        if not isinstance(det_count, int):\n            raise ValueError(\"Expected argument 'det_count' to be an int.\")\n        output_shape: Shape = (len(angles), det_count)\n\n        # Set up all the ASTRA config\n        self.det_spacing = det_spacing\n        self.det_count = det_count\n        self.angles: np.ndarray = np.array(angles)\n\n        self.proj_geom: dict = astra.create_proj_geom(\n            \"parallel\", det_spacing, det_count, self.angles\n        )\n\n        self.proj_id: Optional[int]\n        self.input_shape: tuple = input_shape\n\n        if volume_geometry is None:\n            self.vol_geom = astra.create_vol_geom(*input_shape)\n        else:\n            if len(volume_geometry) == 4:\n                self.vol_geom = astra.create_vol_geom(*input_shape, *volume_geometry)\n            else:\n                raise ValueError(\n                    \"Argument 'volume_geometry' must be a tuple of len 4.\"\n                    \"Please see the astra documentation for details.\"\n                )\n\n        if device in [\"cpu\", \"gpu\"]:\n            # If cpu or gpu selected, attempt to comply (no checking to\n            # confirm that a gpu is available to astra).\n            self.device = device\n        elif device == \"auto\":\n            # If auto selected, use cpu or gpu depending on the default\n            # jax device (for simplicity, no checking whether gpu is\n            # available to astra when one is not available to jax).\n            dev0 = jax.devices()[0]\n            self.device = dev0.platform\n        else:\n            raise ValueError(f\"Invalid 'device' specified; got {device}.\")\n\n        if 
self.device == \"cpu\":\n            self.proj_id = astra.create_projector(\"line\", self.proj_geom, self.vol_geom)\n        elif self.device == \"gpu\":\n            self.proj_id = astra.create_projector(\"cuda\", self.proj_geom, self.vol_geom)\n\n        # Wrap our non-jax function to indicate we will supply fwd/rev mode functions\n        self._eval = jax.custom_vjp(self._proj)\n        self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),))  # type: ignore\n        self._adj = jax.custom_vjp(self._bproj)\n        self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),))  # type: ignore\n\n        super().__init__(\n            input_shape=self.input_shape,\n            output_shape=output_shape,\n            input_dtype=np.float32,\n            output_dtype=np.float32,\n            adj_fn=self._adj,\n            jit=False,\n        )\n\n    def _proj(self, x: jax.Array) -> jax.Array:\n        # apply the forward projector and generate a sinogram\n\n        def f(x):\n            x = _ensure_writeable(x)\n            proj_id, result = astra.create_sino(x, self.proj_id)\n            astra.data2d.delete(proj_id)\n            return result\n\n        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x)\n\n    def _bproj(self, y: jax.Array) -> jax.Array:\n        # apply backprojector\n        def f(y):\n            y = _ensure_writeable(y)\n            proj_id, result = astra.create_backprojection(y, self.proj_id)\n            astra.data2d.delete(proj_id)\n            return result\n\n        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y)\n\n    def fbp(self, sino: jax.Array, filter_type: str = \"Ram-Lak\") -> jax.Array:\n        \"\"\"Filtered back projection (FBP) reconstruction.\n\n        Perform tomographic reconstruction using the filtered back\n        projection (FBP) algorithm.\n\n        Args:\n            sino: Sinogram to 
reconstruct.\n            filter_type: Select the filter to use. For a list of options\n               see `cfg.FilterType` in the `ASTRA documentation\n               <https://www.astra-toolbox.com/docs/algs/FBP_CUDA.html>`__.\n\n        Returns:\n            Reconstructed volume.\n        \"\"\"\n\n        def f(sino):\n            sino = _ensure_writeable(sino)\n            sino_id = astra.data2d.create(\"-sino\", self.proj_geom, sino)\n\n            # create memory for result\n            rec_id = astra.data2d.create(\"-vol\", self.vol_geom)\n\n            # start to populate config\n            cfg = astra.astra_dict(\"FBP_CUDA\" if self.device == \"gpu\" else \"FBP\")\n            cfg[\"ReconstructionDataId\"] = rec_id\n            cfg[\"ProjectorId\"] = self.proj_id\n            cfg[\"ProjectionDataId\"] = sino_id\n            cfg[\"option\"] = {\"FilterType\": filter_type}\n\n            # initialize algorithm; run\n            alg_id = astra.algorithm.create(cfg)\n            astra.algorithm.run(alg_id)\n\n            # get the result\n            out = astra.data2d.get(rec_id)\n\n            # cleanup FBP-specific arra\n            astra.algorithm.delete(alg_id)\n            astra.data2d.delete(rec_id)\n            astra.data2d.delete(sino_id)\n            return out\n\n        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino)\n\n\ndef convert_from_scico_geometry(\n    in_shape: Shape, matrices: ArrayLike, det_shape: Shape\n) -> np.ndarray:\n    \"\"\"Convert SCICO projection matrices into ASTRA \"parallel3d_vec\" vectors.\n\n    For 3D arrays,\n    in ASTRA, the dimensions go (slices, rows, columns) and (z, y, x);\n    in SCICO, the dimensions go (x, y, z).\n\n    In ASTRA, the x-grid (recon) is centered on the origin and the y-grid (projection) can move.\n    In SCICO, the x-grid origin is the center of x[0, 0, 0], the y-grid origin is the center\n    of y[0, 0].\n\n    See section \"parallel3d_vec\" in the\n    
`astra documentation <https://astra-toolbox.com/docs/geom3d.html#projection-geometries>`__.\n\n    Args:\n        in_shape: Shape of input image.\n        matrices: (num_angles, 2, 4) array of homogeneous projection matrices.\n        det_shape: Shape of detector.\n\n    Returns:\n        (num_angles, 12) vector array in the ASTRA \"parallel3d_vec\" convention.\n    \"\"\"\n    # ray is perpendicular to projection axes\n    ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3])\n    # detector center comes from lifting the center index to 3D\n    y_center = (np.array(det_shape) - 1) / 2\n    x_center = (\n        np.einsum(\"...mn,n->...m\", matrices[..., :3], (np.array(in_shape) - 1) / 2)\n        + matrices[..., 3]\n    )\n    d = np.einsum(\"...mn,...m->...n\", matrices[..., :3], y_center - x_center)  # (V, 2, 3) x (V, 2)\n    u = matrices[:, 1, :3]\n    v = matrices[:, 0, :3]\n\n    # handle different axis conventions\n    ray = ray[:, [2, 1, 0]]\n    d = d[:, [2, 1, 0]]\n    u = u[:, [2, 1, 0]]\n    v = v[:, [2, 1, 0]]\n\n    vectors = np.concatenate((ray, d, u, v), axis=1)  # (v, 12)\n    return vectors\n\n\ndef _astra_to_scico_geometry(vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry) -> np.ndarray:\n    \"\"\"Convert ASTRA geometry objects into a SCICO projection matrix.\n\n    Convert ASTRA volume and projection geometry into a SCICO X-ray\n    projection matrix, assuming \"parallel3d_vec\" format.\n\n    The approach is to locate 4 points in the volume domain,\n    deduce the corresponding projection locations, and, then, solve a\n    linear system to determine the affine relationship between them.\n\n    Args:\n        vol_geom: ASTRA volume geometry object.\n        proj_geom: ASTRA projection geometry object.\n\n    Returns:\n        (num_angles, 2, 4) array of homogeneous projection matrices.\n\n    \"\"\"\n    x_volume = np.concatenate((np.zeros((1, 3)), np.eye(3)), axis=0)  # (4, 3)\n    x_dets = _project_coords(x_volume, vol_geom, proj_geom) 
 # (num_angles, 4, 2)\n\n    x_volume_aug = np.concatenate((x_volume, np.ones((4, 1))), axis=1)  # (4, 4)\n    matrices = []\n    for x_det in x_dets:\n        M = np.linalg.solve(x_volume_aug, x_det).T\n        np.testing.assert_allclose(M @ x_volume_aug[0], x_det[0])\n        matrices.append(M)\n\n    return np.stack(matrices)\n\n\ndef convert_to_scico_geometry(\n    input_shape: Shape,\n    det_count: Tuple[int, int],\n    det_spacing: Optional[Tuple[float, float]] = None,\n    angles: Optional[np.ndarray] = None,\n    vectors: Optional[np.ndarray] = None,\n) -> np.ndarray:\n    \"\"\"Convert X-ray geometry specification to a SCICO projection matrix.\n\n    The approach is to locate 4 points in the volume domain,\n    deduce the corresponding projection locations, and, then, solve a\n    linear system to determine the affine relationship between them.\n\n    Args:\n        input_shape: Shape of the input array.\n        det_count: Number of detector elements. See the\n           `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n           for more information.\n        det_spacing: Spacing between detector elements. See the\n           `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n           for more information.\n        angles: Array of projection angles in radians. This parameter is\n            mutually exclusive with `vectors`.\n        vectors: Array of ASTRA geometry specification vectors. 
This\n            parameter is mutually exclusive with `angles`.\n\n    Returns:\n        (num_angles, 2, 4) array of homogeneous projection matrices.\n\n    \"\"\"\n    if angles is not None and vectors is not None:\n        raise ValueError(\"Arguments 'angles' and 'vectors' are mutually exclusive.\")\n    if angles is None and vectors is None:\n        raise ValueError(\"Exactly one of arguments 'angles' and 'vectors' must be provided.\")\n    vol_geom, proj_geom = XRayTransform3D.create_astra_geometry(\n        input_shape, det_count, det_spacing=det_spacing, angles=angles, vectors=vectors\n    )\n    return _astra_to_scico_geometry(vol_geom, proj_geom)\n\n\nclass XRayTransform3D(LinearOperator):  # pragma: no cover\n    r\"\"\"3D parallel beam X-ray transform based on the ASTRA toolbox.\n\n    Perform tomographic projection (also called X-ray projection) of a\n    volume at specified angles, using the\n    `ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.\n    The `3D geometries <https://astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n    \"parallel3d\" and \"parallel3d_vec\" are supported by this interface.\n    Note that a CUDA GPU is required for the primary functionality of\n    this class; if no GPU is available, initialization will fail with a\n    :exc:`RuntimeError`.\n\n    The volume is fixed with respect to the coordinate system, centered\n    at the origin, as illustrated below:\n\n    .. plot:: pyfigures/xray_3d_vol.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n\n    The voxel sides have unit length (in arbitrary units), which defines\n    the scale for all other dimensions in the source-volume-detector\n    configuration. Geometry axes `z`, `y`, and `x` correspond to volume\n    array axes 0, 1, and 2 respectively. 
The projected array axes 0, 1,\n    and 2 correspond respectively to detector rows, views, and detector\n    columns.\n\n    In the \"parallel3d\" case, the source and detector rotate clockwise\n    about the `z` axis in the `x`-`y` plane, as illustrated below:\n\n    .. plot:: pyfigures/xray_3d_ang.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n       :caption: Each radial arrow indicates the direction of the beam\n          towards the detector (indicated in orange in the \"light\"\n          display mode) and the arrow parallel to the detector indicates\n          the direction of increasing pixel indices.\n\n    In this case the `z` axis is in the same direction as the\n    vertical/row axis of the detector and its projection corresponds to\n    a vertical line in the center of the horizontal/column detector axis.\n    Note that the view images must be displayed with the origin at the\n    bottom left (i.e. vertically inverted from the top left origin image\n    indexing convention) in order for the projections to correspond to\n    the positive up/negative down orientation of the `z` axis in the\n    figures here.\n\n    In the \"parallel3d_vec\" case, each view is determined by the following\n    vectors:\n\n    .. list-table:: View definition vectors\n       :widths: 10 90\n\n       * - :math:`\\mb{r}`\n         - Direction of the parallel beam\n       * - :math:`\\mb{d}`\n         - Center of the detector\n       * - :math:`\\mb{u}`\n         - Vector from detector pixel (0,0) to (0,1) (direction of\n           increasing detector column index)\n       * - :math:`\\mb{v}`\n         - Vector from detector pixel (0,0) to (1,0) (direction of\n           increasing detector row index)\n\n    Note that the components of these vectors are in `x`, `y`, `z` order,\n    not the `z`, `y`, `x` order of the volume axes.\n\n    .. 
plot:: pyfigures/xray_3d_vec.py\n       :align: center\n       :include-source: False\n       :show-source-link: False\n\n    Vector :math:`\\mb{r}` is not illustrated to avoid cluttering the\n    figure, but will typically be directed toward the center of the\n    detector (i.e. in the direction of :math:`\\mb{d}` in the figure.)\n    Since the volume-detector distance does not have a geometric effect\n    for a parallel-beam configuration, :math:`\\mb{d}` may be set to the\n    zero vector when the detector and beam centers coincide (e.g., as in\n    the case of the \"parallel3d\" geometry). Note that the view images\n    must be displayed with the origin at the bottom left (i.e. vertically\n    inverted from the top left origin image indexing convention) in order\n    for the row indexing of the projections to correspond to the\n    direction of :math:`\\mb{v}` in the figure.\n\n    These vectors are concatenated into a single row vector\n    :math:`(\\mb{r}, \\mb{d}, \\mb{u}, \\mb{v})` to form the full\n    geometry specification for a single view, and multiple such\n    row vectors are stacked to specify the geometry for a set\n    of views.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        det_count: Tuple[int, int],\n        det_spacing: Optional[Tuple[float, float]] = None,\n        angles: Optional[np.ndarray] = None,\n        vectors: Optional[np.ndarray] = None,\n    ):\n        \"\"\"\n        Keyword arguments `det_spacing` and `angles` should be specified\n        to use the \"parallel3d\" geometry, and keyword argument `vectors`\n        should be specified to use the \"parallel3d_vec\" geometry. These\n        parameters are mutually exclusive.\n\n        Args:\n            input_shape: Shape of the input array.\n            det_count: Number of detector elements. 
See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n               for more information.\n            det_spacing: Spacing between detector elements. See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n               for more information.\n            angles: Array of projection angles in radians. This\n                parameter is  mutually exclusive with `vectors`.\n            vectors: Array of ASTRA geometry specification vectors. This\n                parameter is mutually exclusive with `angles`.\n\n        Raises:\n            RuntimeError: If a CUDA GPU is not available to the ASTRA\n                toolbox.\n        \"\"\"\n        if not astra.use_cuda():\n            raise RuntimeError(\"CUDA GPU required but not available or not enabled.\")\n\n        if not (\n            (det_spacing is not None and angles is not None and vectors is None)\n            or (vectors is not None and det_spacing is None and angles is None)\n        ):\n            raise ValueError(\n                \"Keyword arguments 'det_spacing' and 'angles', or keyword argument \"\n                \"'vectors' must be specified, but not both.\"\n            )\n\n        self.num_dims = len(input_shape)\n        if self.num_dims != 3:\n            raise ValueError(\n                f\"Only 3D projections are supported, but 'input_shape' is {input_shape}.\"\n            )\n\n        if not isinstance(det_count, (list, tuple)) or len(det_count) != 2:\n            raise ValueError(\"Expected argument 'det_count' to be a tuple with 2 elements.\")\n        if angles is not None and vectors is not None:\n            raise ValueError(\"Arguments 'angles' and 'vectors' are mutually exclusive.\")\n        if angles is None and vectors is None:\n            raise ValueError(\n                \"Exactly one of the arguments 'angles' and 'vectors' must be 
provided.\"\n            )\n        if angles is not None:\n            Nview = angles.size\n            self.angles: Optional[np.ndarray] = np.array(angles)\n            self.vectors: Optional[np.ndarray] = None\n        if vectors is not None:\n            Nview = vectors.shape[0]\n            self.vectors = np.array(vectors)\n            self.angles = None\n        output_shape: Shape = (det_count[0], Nview, det_count[1])\n\n        self.det_count = det_count\n        assert isinstance(det_count, (list, tuple))\n        self.input_shape: tuple = input_shape\n        self.vol_geom, self.proj_geom = self.create_astra_geometry(\n            input_shape,\n            det_count,\n            det_spacing=det_spacing,\n            angles=self.angles,\n            vectors=self.vectors,\n        )\n\n        # Wrap our non-jax function to indicate we will supply fwd/rev mode functions\n        self._eval = jax.custom_vjp(self._proj)\n        self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),))  # type: ignore\n        self._adj = jax.custom_vjp(self._bproj)\n        self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),))  # type: ignore\n\n        super().__init__(\n            input_shape=self.input_shape,\n            output_shape=output_shape,\n            input_dtype=np.float32,\n            output_dtype=np.float32,\n            adj_fn=self._adj,\n            jit=False,\n        )\n\n    @staticmethod\n    def create_astra_geometry(\n        input_shape: Shape,\n        det_count: Tuple[int, int],\n        det_spacing: Optional[Tuple[float, float]] = None,\n        angles: Optional[np.ndarray] = None,\n        vectors: Optional[np.ndarray] = None,\n    ) -> Tuple[VolumeGeometry, ProjectionGeometry]:\n        \"\"\"Create ASTRA 3D geometry objects.\n\n        Keyword arguments `det_spacing` and `angles` should be specified\n        to use the \"parallel3d\" geometry, and keyword argument `vectors`\n        should 
be specified to use the \"parallel3d_vec\" geometry. These\n        parameters are mutually exclusive.\n\n        Args:\n            input_shape: Shape of the input array.\n            det_count: Number of detector elements. See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n               for more information.\n            det_spacing: Spacing between detector elements. See the\n               `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n               for more information.\n            angles: Array of projection angles in radians.\n            vectors: Array of geometry specification vectors.\n\n        Returns:\n            A tuple `(vol_geom, proj_geom)` of ASTRA volume geometry and\n            projection geometry objects.\n        \"\"\"\n        vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0])\n        if angles is not None:\n            assert det_spacing is not None\n            proj_geom = astra.create_proj_geom(\n                \"parallel3d\",\n                det_spacing[0],\n                det_spacing[1],\n                det_count[0],\n                det_count[1],\n                angles,\n            )\n        else:\n            proj_geom = astra.create_proj_geom(\n                \"parallel3d_vec\", det_count[0], det_count[1], vectors\n            )\n        return vol_geom, proj_geom\n\n    def _proj(self, x: jax.Array) -> jax.Array:\n        # apply the forward projector and generate a sinogram\n\n        def f(x):\n            x = _ensure_writeable(x)\n            proj_id, result = astra.create_sino3d_gpu(x, self.proj_geom, self.vol_geom)\n            astra.data3d.delete(proj_id)\n            return result\n\n        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x)\n\n    def _bproj(self, y: jax.Array) -> jax.Array:\n        # apply backprojector\n    
    def f(y):\n            y = _ensure_writeable(y)\n            proj_id, result = astra.create_backprojection3d_gpu(y, self.proj_geom, self.vol_geom)\n            astra.data3d.delete(proj_id)\n            return result\n\n        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y)\n\n\ndef angle_to_vector(det_spacing: Tuple[float, float], angles: np.ndarray) -> np.ndarray:\n    \"\"\"Convert det_spacing and angles to vector geometry specification.\n\n    Args:\n        det_spacing: Spacing between detector elements. See the\n            `astra documentation <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__\n            for more information.\n        angles: Array of projection angles in radians.\n\n    Returns:\n        Array of geometry specification vectors.\n    \"\"\"\n    vectors = np.zeros((angles.size, 12))\n    vectors[:, 0] = np.sin(angles)\n    vectors[:, 1] = -np.cos(angles)\n    vectors[:, 6] = np.cos(angles) * det_spacing[0]\n    vectors[:, 7] = np.sin(angles) * det_spacing[0]\n    vectors[:, 11] = det_spacing[1]\n    return vectors\n\n\ndef rotate_vectors(vectors: np.ndarray, rot: Rotation) -> np.ndarray:\n    \"\"\"Rotate geometry specification vectors.\n\n    Rotate ASTRA \"parallel3d_vec\" geometry specification vectors.\n\n    Args:\n        vectors: Array of geometry specification vectors.\n        rot: Rotation.\n\n    Returns:\n        Rotated geometry specification vectors.\n    \"\"\"\n    rot_vecs = vectors.copy()\n    for k in range(0, 12, 3):\n        rot_vecs[:, k : k + 3] = rot.apply(rot_vecs[:, k : k + 3])\n    return rot_vecs\n\n\ndef _ensure_writeable(x):\n    \"\"\"Ensure that `x.flags.writeable` is ``True``, copying if needed.\"\"\"\n    if hasattr(x, \"flags\"):  # x is a numpy array\n        if not x.flags.writeable:\n            try:\n                x.setflags(write=True)\n            except ValueError:\n                x = x.copy()\n    else:  # x is a jax array 
(which is immutable)\n        x = np.array(x)\n    return x\n"
  },
  {
    "path": "scico/linop/xray/svmbir.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"X-ray transform LinearOperator wrapping the svmbir package.\n\nX-ray transform :class:`.LinearOperator` wrapping the\n`svmbir <https://github.com/cabouman/svmbir>`_ package. Since this\npackage is an interface to compiled C code, JAX features such as\nautomatic differentiation and support for GPU devices are not available.\n\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico.loss import Loss, SquaredL2Loss\nfrom scico.typing import Shape\n\nfrom .._diag import Diagonal, Identity\nfrom .._linop import LinearOperator\n\ntry:\n    import svmbir\nexcept ImportError:\n    raise ImportError(\"Could not import svmbir; please install it.\")\n\n\nclass XRayTransform(LinearOperator):\n    r\"\"\"X-ray transform based on svmbir.\n\n    Perform tomographic projection of an image at specified angles, using\n    the `svmbir <https://github.com/cabouman/svmbir>`_ package. The\n    `is_masked` option selects whether a valid region for projections\n    (pixels outside this region are ignored when performing the\n    projection) is active. This region of validity is also respected by\n    :meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss`\n    is initialized with a :class:`XRayTransform` with this option\n    enabled.\n\n    A brief description of the supported scanner geometries can be found\n    in the `svmbir documentation <https://svmbir.readthedocs.io/en/latest/overview.html>`_.\n    Parallel beam geometry and two different fan beam geometries are supported.\n\n    .. list-table::\n\n       * - .. 
figure:: /figures/geom-parallel.png\n              :align: center\n              :width: 75%\n\n              Fig 1. Parallel beam geometry.\n\n         - .. figure:: /figures/geom-fan.png\n              :align: center\n              :width: 75%\n\n              Fig 2. Curved fan beam geometry.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        angles: snp.Array,\n        num_channels: int,\n        center_offset: float = 0.0,\n        is_masked: bool = False,\n        geometry: str = \"parallel\",\n        dist_source_detector: Optional[float] = None,\n        magnification: Optional[float] = None,\n        delta_channel: Optional[float] = None,\n        delta_pixel: Optional[float] = None,\n    ):\n        \"\"\"\n        The output of this linear operator is an array of shape\n        `(num_angles, num_channels)` when input_shape is 2D, or of shape\n        `(num_angles, num_slices, num_channels)` when input_shape is 3D,\n        where `num_angles` is the length of the `angles` argument, and\n        `num_slices` is inferred from the `input_shape` argument.\n\n        Most of the following arguments have the same name as and\n        correspond to arguments of :func:`svmbir.project`. A brief\n        summary of each is provided here, but the documentation for\n        :func:`svmbir.project` should be consulted for further details.\n\n        Args:\n            input_shape: Shape of the input array. May be of length 2 (a\n                2D array) or 3 (a 3D array). When specifying a 2D array,\n                the format for the input_shape is `(num_rows, num_cols)`.\n                For a 3D array, the format for the input_shape is\n                `(num_slices, num_rows, num_cols)`, where `num_slices`\n                denotes the number of slices in the input, and `num_rows`\n                and `num_cols` denote the number of rows and columns in a\n                single slice of the input. 
A slice is a plane\n                perpendicular to the axis of rotation of the tomographic\n                system. At angle zero, each row is oriented along the\n                X-rays (parallel beam) or the X-ray beam directed toward\n                the detector center (fan beam).  Note that\n                `input_shape=(num_rows, num_cols)` and\n                `input_shape=(1, num_rows, num_cols)` result in the\n                same underlying projector.\n            angles: Array of projection angles in radians, should be\n                increasing.\n            num_channels: Number of detector channels in the sinogram\n                data.\n            center_offset: Position of the detector center relative to\n                the projection of the center of rotation onto the\n                detector, in units of pixels.\n            is_masked: If ``True``, the valid region of the image is\n                determined by a mask defined as the circle inscribed\n                within the image boundary. Otherwise, the whole image\n                array is taken into account by projections.\n            geometry: Scanner geometry, either \"parallel\", \"fan-curved\",\n                or \"fan-flat\". Note that the `dist_source_detector` and\n                `magnification` arguments must be provided for the fan\n                beam geometries.\n            dist_source_detector: Distance from X-ray focal spot to\n                detectors in units of pixel pitch. 
Only used when geometry\n                is \"fan-flat\" or \"fan-curved\".\n            magnification: Magnification factor of the scanner geometry.\n                Only used when geometry is \"fan-flat\" or \"fan-curved\".\n            delta_channel: Detector channel spacing.\n            delta_pixel: Spacing between image pixels in the 2D slice\n                plane.\n        \"\"\"\n        self.angles = angles\n        self.num_channels = num_channels\n        self.center_offset = center_offset\n\n        if len(input_shape) == 2:  # 2D input\n            self.svmbir_input_shape = (1,) + input_shape\n            output_shape: Tuple[int, ...] = (len(angles), num_channels)\n            self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2]\n        elif len(input_shape) == 3:  # 3D input\n            self.svmbir_input_shape = input_shape\n            output_shape = (len(angles), input_shape[0], num_channels)\n            self.svmbir_output_shape = output_shape\n        else:\n            raise ValueError(\n                f\"Only 2D and 3D inputs are supported, but input_shape was {input_shape}.\"\n            )\n\n        self.is_masked = is_masked\n        if self.is_masked:\n            self.roi_radius = None\n        else:\n            self.roi_radius = max(self.svmbir_input_shape[1], self.svmbir_input_shape[2])\n\n        self.geometry = geometry\n        self.dist_source_detector = dist_source_detector\n        self.magnification = magnification\n\n        if delta_channel is None:\n            self.delta_channel = 1.0\n        else:\n            self.delta_channel = delta_channel\n\n        if self.geometry == \"fan-curved\" or self.geometry == \"fan-flat\":\n            if self.dist_source_detector is None:\n                raise ValueError(\n                    \"Argument 'dist_source_detector' must be specified for fan beam geometry.\"\n                )\n            if self.magnification is None:\n                raise ValueError(\n 
                   \"Argument 'magnification' must be specified for fan beam geometry.\"\n                )\n\n            if delta_pixel is None:\n                self.delta_pixel = self.delta_channel / self.magnification\n            else:\n                self.delta_pixel = delta_pixel\n\n        elif self.geometry == \"parallel\":\n            self.magnification = 1.0\n            if delta_pixel is None:\n                self.delta_pixel = self.delta_channel\n            else:\n                self.delta_pixel = delta_pixel\n\n        else:\n            raise ValueError(\"Unspecified geometry {}.\".format(self.geometry))\n\n        # Set up custom_vjp for _eval and _adj so jax.grad works on them.\n        self._eval = jax.custom_vjp(self._proj_hcb)\n        self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),))  # type: ignore\n\n        self._adj = jax.custom_vjp(self._bproj_hcb)\n        self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),))  # type: ignore\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=np.float32,\n            output_dtype=np.float32,\n            adj_fn=self._adj,\n            jit=False,\n        )\n\n    @staticmethod\n    def _proj(\n        x: snp.Array,\n        angles: snp.Array,\n        num_channels: int,\n        center_offset: float = 0.0,\n        roi_radius: Optional[float] = None,\n        geometry: str = \"parallel\",\n        dist_source_detector: Optional[float] = None,\n        magnification: Optional[float] = None,\n        delta_channel: Optional[float] = None,\n        delta_pixel: Optional[float] = None,\n    ) -> snp.Array:\n        return snp.array(\n            svmbir.project(\n                np.array(x),\n                np.array(angles),\n                num_channels,\n                verbose=0,\n                center_offset=center_offset,\n                
roi_radius=roi_radius,\n                geometry=geometry,\n                dist_source_detector=dist_source_detector,\n                magnification=magnification,\n                delta_channel=delta_channel,\n                delta_pixel=delta_pixel,\n            )\n        )\n\n    def _proj_hcb(self, x):\n        x = x.reshape(self.svmbir_input_shape)\n        # callback wrapper for _proj\n        y = jax.pure_callback(\n            lambda x: self._proj(\n                x,\n                self.angles,\n                self.num_channels,\n                center_offset=self.center_offset,\n                roi_radius=self.roi_radius,\n                geometry=self.geometry,\n                dist_source_detector=self.dist_source_detector,\n                magnification=self.magnification,\n                delta_channel=self.delta_channel,\n                delta_pixel=self.delta_pixel,\n            ),\n            jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype),\n            x,\n        )\n        return y.reshape(self.output_shape)\n\n    @staticmethod\n    def _bproj(\n        y: snp.Array,\n        angles: snp.Array,\n        num_rows: int,\n        num_cols: int,\n        center_offset: Optional[float] = 0.0,\n        roi_radius: Optional[float] = None,\n        geometry: str = \"parallel\",\n        dist_source_detector: Optional[float] = None,\n        magnification: Optional[float] = None,\n        delta_channel: Optional[float] = None,\n        delta_pixel: Optional[float] = None,\n    ) -> snp.Array:\n        return snp.array(\n            svmbir.backproject(\n                np.array(y),\n                np.array(angles),\n                num_rows=num_rows,\n                num_cols=num_cols,\n                verbose=0,\n                center_offset=center_offset,\n                roi_radius=roi_radius,\n                geometry=geometry,\n                dist_source_detector=dist_source_detector,\n                
magnification=magnification,\n                delta_channel=delta_channel,\n                delta_pixel=delta_pixel,\n            )\n        )\n\n    def _bproj_hcb(self, y):\n        y = y.reshape(self.svmbir_output_shape)\n        # callback wrapper for _bproj\n        x = jax.pure_callback(\n            lambda y: self._bproj(\n                y,\n                self.angles,\n                self.svmbir_input_shape[1],\n                self.svmbir_input_shape[2],\n                center_offset=self.center_offset,\n                roi_radius=self.roi_radius,\n                geometry=self.geometry,\n                dist_source_detector=self.dist_source_detector,\n                magnification=self.magnification,\n                delta_channel=self.delta_channel,\n                delta_pixel=self.delta_pixel,\n            ),\n            jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype),\n            y,\n        )\n        return x.reshape(self.input_shape)\n\n\nclass SVMBIRExtendedLoss(Loss):\n    r\"\"\"Extended squared :math:`\\ell_2` loss with svmbir tomographic projector.\n\n    Generalization of the weighted squared :math:`\\ell_2` loss for a CT\n    reconstruction problem,\n\n    .. math::\n        \\alpha \\norm{\\mb{y} - A(\\mb{x})}_W^2 =\n        \\alpha \\left(\\mb{y} - A(\\mb{x})\\right)^T W \\left(\\mb{y} -\n        A(\\mb{x})\\right) \\;,\n\n    where :math:`A` is a :class:`.XRayTransform`,\n    :math:`\\alpha` is the scaling parameter and :math:`W` is an instance\n    of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set\n    to :class:`scico.linop.Identity`.\n\n    The extended loss differs from a typical weighted squared\n    :math:`\\ell_2` loss as follows. When `positivity=True`, the prox\n    projects onto the non-negative orthant and the loss is infinite if\n    any element of the input is negative. 
When the `is_masked` option\n    of the associated :class:`.XRayTransform` is ``True``, the\n    reconstruction is computed over a masked region of the image as\n    described in class :class:`.XRayTransform`.\n    \"\"\"\n\n    A: XRayTransform\n    W: Union[Identity, Diagonal]\n\n    def __init__(\n        self,\n        *args,\n        scale: float = 0.5,\n        prox_kwargs: Optional[dict] = None,\n        positivity: bool = False,\n        W: Optional[Diagonal] = None,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`SVMBIRExtendedLoss` object.\n\n        Args:\n            y: Sinogram measurement.\n            A: Forward operator.\n            scale: Scaling parameter.\n            prox_kwargs: Dictionary of arguments passed to the\n               :meth:`svmbir.recon` prox routine. Defaults to\n               {\"maxiter\": 1000, \"ctol\": 0.001}.\n            positivity: Enforce positivity in the prox operation. The\n               loss is infinite if any element of the input is negative.\n            W: Weighting diagonal operator. 
Must be non-negative.\n               If ``None``, defaults to :class:`.Identity`.\n        \"\"\"\n        super().__init__(*args, scale=scale, **kwargs)  # type: ignore\n\n        if not isinstance(self.A, XRayTransform):\n            raise ValueError(\"LinearOperator A must be a radon_svmbir.XRayTransform.\")\n\n        self.has_prox = True\n\n        if prox_kwargs is None:\n            prox_kwargs = {}\n\n        default_prox_args = {\"maxiter\": 1000, \"ctol\": 0.001}\n        default_prox_args.update(prox_kwargs)\n\n        svmbir_prox_args = {}\n        if \"maxiter\" in default_prox_args:\n            svmbir_prox_args[\"max_iterations\"] = default_prox_args[\"maxiter\"]\n        if \"ctol\" in default_prox_args:\n            svmbir_prox_args[\"stop_threshold\"] = default_prox_args[\"ctol\"]\n        self.svmbir_prox_args = svmbir_prox_args\n\n        self.positivity = positivity\n\n        if W is None:\n            self.W = Identity(self.y.shape)\n        elif isinstance(W, Diagonal):\n            if snp.all(W.diagonal >= 0):\n                self.W = W\n            else:\n                raise ValueError(f\"The weights, W, must be non-negative.\")\n        else:\n            raise TypeError(f\"Argument 'W' must be None or a linop.Diagonal, got {type(W)}.\")\n\n    def __call__(self, x: snp.Array) -> float:\n        if self.positivity and snp.sum(x < 0) > 0:\n            return snp.inf\n        else:\n            return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()\n\n    def prox(self, v: snp.Array, lam: float = 1, **kwargs) -> snp.Array:\n        v = v.reshape(self.A.svmbir_input_shape)\n        y = self.y.reshape(self.A.svmbir_output_shape)\n        weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)\n        sigma_p = snp.sqrt(lam)\n        if \"v0\" in kwargs and kwargs[\"v0\"] is not None:\n            v0: Union[float, np.ndarray] = np.reshape(\n                np.array(kwargs[\"v0\"]), self.A.svmbir_input_shape\n 
           )\n        else:\n            v0 = 0.0\n\n        # change: stop, mask-rad, init\n        result = svmbir.recon(\n            np.array(y),\n            np.array(self.A.angles),\n            weights=np.array(weights),\n            prox_image=np.array(v),\n            num_rows=self.A.svmbir_input_shape[1],\n            num_cols=self.A.svmbir_input_shape[2],\n            center_offset=self.A.center_offset,\n            roi_radius=self.A.roi_radius,\n            geometry=self.A.geometry,\n            dist_source_detector=self.A.dist_source_detector,\n            magnification=self.A.magnification,\n            delta_channel=self.A.delta_channel,\n            delta_pixel=self.A.delta_pixel,\n            sigma_p=float(sigma_p),\n            sigma_y=1.0,\n            positivity=self.positivity,\n            verbose=0,\n            init_image=v0,\n            **self.svmbir_prox_args,\n        )\n        if np.sum(np.isnan(result)):\n            raise ValueError(\"Result contains NaNs.\")\n\n        return snp.array(result.reshape(self.A.input_shape))\n\n\nclass SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss):\n    r\"\"\"Weighted squared :math:`\\ell_2` loss with svmbir tomographic projector.\n\n    Weighted squared :math:`\\ell_2` loss of a CT reconstruction problem,\n\n    .. math::\n        \\alpha \\norm{\\mb{y} - A(\\mb{x})}_W^2 =\n        \\alpha \\left(\\mb{y} - A(\\mb{x})\\right)^T W \\left(\\mb{y} -\n        A(\\mb{x})\\right) \\;,\n\n    where :math:`A` is a :class:`.XRayTransform`, :math:`\\alpha`\n    is the scaling parameter and :math:`W` is an instance\n    of :class:`scico.linop.Diagonal`. 
If :math:`W` is ``None``, it is set\n    to :class:`scico.linop.Identity`.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        prox_kwargs: Optional[dict] = None,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`SVMBIRSquaredL2Loss` object.\n\n        Args:\n            y: Sinogram measurement.\n            A: Forward operator.\n            scale: Scaling parameter.\n            W: Weighting diagonal operator. Must be non-negative.\n               If ``None``, defaults to :class:`.Identity`.\n            prox_kwargs: Dictionary of arguments passed to the\n               :meth:`svmbir.recon` prox routine. Defaults to\n               {\"maxiter\": 1000, \"ctol\": 0.001}.\n        \"\"\"\n        super().__init__(*args, **kwargs, prox_kwargs=prox_kwargs, positivity=False)\n\n        if self.A.is_masked:\n            raise ValueError(\n                \"Argument 'is_masked' must be False for the XRayTransform in SVMBIRSquaredL2Loss.\"\n            )\n"
  },
  {
    "path": "scico/linop/xray/symcone.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Cone beam X-ray transform for cylindrically symmetric objects.\n\nCone beam X-ray transform and FDK reconstruction for cylindrically\nsymmetric objects; essentialy a cone-beam variant of the Abel transform.\nThe implementation is based on code modified from the\n`axitom <https://github.com/PolymerGuy/AXITOM>`_ package\n:cite:`olufsen-2019-axitom`.\n\"\"\"\n\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport numpy as np\n\nimport jax.numpy as jnp\nfrom jax import Array, jit, vjp\nfrom jax.scipy.ndimage import map_coordinates\nfrom jax.typing import ArrayLike\n\nfrom scico.typing import DType, Shape\n\nfrom .._linop import LinearOperator\nfrom ._axitom import backprojection, config, projection\n\n\n@partial(jit, static_argnames=[\"axis\", \"center\"])\ndef _volume_by_axial_symmetry(\n    x: Array, axis: int = 0, center: Optional[int] = None, zrange: Optional[Array] = None\n) -> Array:\n    \"\"\"Create a volume by axial rotation of a plane.\n\n    Args:\n        x: 2D array that is rotated about an axis to generate a volume.\n        axis: Index of axis of symmetry (must be 0 or 1).\n        center: Location of the axis of symmetry on the other axis. If\n          ``None``, defaults to center of that axis. Otherwise identifies\n          the center coordinate on that axis.\n        zrange: 1D array of points at which the extended axis is\n          constructed. 
Defaults to the same as for axis :code:`1 - axis`.\n\n    Returns:\n        Volume as a 3D array.\n    \"\"\"\n    N0, N1 = x.shape\n    N0h, N1h = (N0 + 1) / 2 - 1, (N1 + 1) / 2 - 1\n    half_shape = (N0h, N1h)\n    if zrange is None:\n        N2 = x.shape[1 - axis]\n        N2h = (N2 + 1) / 2 - 1\n        zrange = jnp.arange(-N2h, N2h + 1)\n    if axis == 0:\n        g1d = [np.arange(0, N0), jnp.arange(-N1h, N1h + 1), zrange]\n    else:\n        g1d = [np.arange(-N0h, N0h + 1), jnp.arange(0, N1), zrange]\n\n    if center is None:\n        offset = 0\n    else:\n        offset = center - half_shape[1 - axis]\n\n    g0, g1, g2 = jnp.meshgrid(*g1d, indexing=\"ij\")\n    grids = (g0, g1, g2)\n    r = jnp.hypot(grids[1 - axis], g2)\n    sym_ax_crd = jnp.where(\n        grids[1 - axis] >= 0, half_shape[1 - axis] + offset + r, half_shape[1 - axis] + offset - r\n    )\n    if axis == 0:\n        coords = [grids[axis], sym_ax_crd]\n    else:\n        coords = [sym_ax_crd, grids[axis]]\n    v = map_coordinates(x, coords, cval=0.0, order=1)\n\n    return v\n\n\nclass AxiallySymmetricVolume(LinearOperator):\n    \"\"\"Create a volume by axial rotation of a plane.\"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        input_dtype: DType = np.float32,\n        axis: int = 0,\n        center: Optional[int] = None,\n    ):\n        \"\"\"\n        Args:\n            input_shape: Input image shape.\n            input_dtype: Input image dtype.\n            axis: Index of axis of symmetry (must be 0 or 1).\n            center: If ``None``, defaults to the center of the image on\n              the specified axis. 
Otherwise identifies the center\n              coordinate on that axis.\n        \"\"\"\n        self.axis = axis\n        self.center = center\n        output_shape = input_shape + (input_shape[axis],)\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=input_dtype,\n            output_dtype=input_dtype,\n            eval_fn=lambda x: _volume_by_axial_symmetry(x, axis=self.axis, center=self.center),\n            jit=True,\n        )\n\n\nclass SymConeXRayTransform(LinearOperator):\n    \"\"\"Cone beam X-ray transform for cylindrically symmetric objects.\n\n    Cone beam X-ray transform of a cylindrically symmetric volume, which\n    may be represented by a 2D central slice, which is rotated about\n    the specified axis to generate a 3D volume for projection.\n    The implementation is based on code modified from the AXITOM package\n    :cite:`olufsen-2019-axitom`.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Shape,\n        obj_dist: float,\n        det_dist: float,\n        axis: int = 0,\n        pixel_size: Optional[Tuple[float, float]] = None,\n        num_slabs: int = 1,\n    ):\n        \"\"\"\n        Args:\n            input_shape: Shape of the input array. 
If 2D, the input is\n              extended to 3D (onto a new axis 1) by cylindrical symmetry.\n            obj_dist: Source-object distance in arbitrary length units (ALU).\n            det_dist: Source-detector distance in ALU.\n            axis: Index of axis of symmetry (must be 0 or 1).\n            pixel_size: Tuple of pixel size values in ALU. If ``None``,\n              defaults to ``(1.0, 1.0)``.\n            num_slabs: Number of slabs into which the volume should be\n              divided (for serial processing, to limit memory usage) in\n              the imaging direction.\n        \"\"\"\n        if len(input_shape) == 2:\n            self.input_2d = True\n            output_shape = input_shape[::-1]\n        else:\n            self.input_2d = False\n            output_shape = (input_shape[2], input_shape[0])\n        if pixel_size is None:\n            pixel_size = (1.0, 1.0)\n        self.axis = axis\n        self.config = config.Config(*output_shape, *pixel_size, det_dist, obj_dist)\n        self.num_slabs = num_slabs\n        if len(input_shape) == 2 and axis == 1:\n            eval_fn = lambda x: projection.forward_project(\n                x.T, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d\n            ).T\n        else:\n            eval_fn = lambda x: projection.forward_project(\n                x, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d\n            )\n        # use vjp rather than linear_transpose due to jax-ml/jax#30552\n        adj_fn = vjp(eval_fn, jnp.zeros(input_shape))[1]\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=np.float32,\n            output_dtype=np.float32,\n            eval_fn=eval_fn,\n            adj_fn=lambda x: adj_fn(x)[0],\n            jit=True,\n        )\n\n    def fdk(self, y: ArrayLike, num_angles: int = 360):\n        \"\"\"Reconstruct central slice from projection.\n\n        Reconstruct the central slice of the cylindrically symmetric\n        volume 
from a projection. The reconstruction makes use of the\n        Feldkamp-Davis-Kress (FDK) algorithm implemented in the\n        `axitom <https://github.com/PolymerGuy/AXITOM>`_ package.\n\n        Args:\n          y: The projection to be reconstructed.\n          num_angles: Number of angles to be averaged in the\n            reconstruction.\n\n        Returns:\n          Reconstruction of the central slice of the volume.\n        \"\"\"\n        angles = jnp.linspace(0, 360, num_angles, endpoint=False)\n        x = backprojection.fdk(y if self.axis == 1 else y.T, self.config, angles)\n        return x if self.axis == 1 else x.T\n"
  },
  {
    "path": "scico/loss.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\n\"\"\"Loss function classes.\"\"\"\n\nimport warnings\nfrom copy import copy\nfrom functools import wraps\nfrom typing import Callable, Optional, Union\n\nimport jax\n\nimport scico\nimport scico.numpy as snp\nfrom scico import functional, linop, operator\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import no_nan_divide\nfrom scico.scipy.special import gammaln  # type: ignore\nfrom scico.solver import cg\n\n\ndef _loss_mul_div_wrapper(func):\n    @wraps(func)\n    def wrapper(self, other):\n        if snp.isscalar(other) or isinstance(other, jax.core.Tracer):\n            return func(self, other)\n\n        raise NotImplementedError(\n            f\"Operation {func} not defined between {type(self)} and {type(other)}.\"\n        )\n\n    return wrapper\n\n\nclass Loss(functional.Functional):\n    r\"\"\"Generic loss function.\n\n    Generic loss function\n\n    .. math::\n        \\alpha f(\\mb{y}, A(\\mb{x})) \\;,\n\n    where :math:`\\alpha` is the scaling parameter and :math:`f(\\cdot)` is\n    the loss functional.\n    \"\"\"\n\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        f: Optional[functional.Functional] = None,\n        scale: float = 1.0,\n    ):\n        r\"\"\"\n        Args:\n            y: Measurement.\n            A: Forward operator. Defaults to ``None``, in which case\n               `self.A` is a :class:`.Identity` with input shape\n               and dtype determined by the shape and dtype of `y`.\n            f: Functional :math:`f`. If defined, the loss function is\n               :math:`\\alpha f(\\mb{y} - A(\\mb{x}))`. 
If ``None``, then\n               :meth:`__call__` and :meth:`prox` (where appropriate) must\n               be defined in a derived class.\n            scale: Scaling parameter. Default: 1.0.\n        \"\"\"\n        self.y = y\n        if A is None:\n            # y and x must have same shape\n            A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype)  # type: ignore\n        self.A = A\n        self.f = f\n        self.scale = scale\n\n        # Set functional-specific flags\n        self.has_eval = True\n        if self.f is not None and isinstance(self.A, linop.Identity):\n            self.has_prox = True\n        else:\n            self.has_prox = False\n        super().__init__()\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        r\"\"\"Evaluate this loss at point :math:`\\mb{x}`.\n\n        Args:\n            x: Point at which to evaluate loss.\n\n        Returns:\n            Result of evaluating the loss at `x`.\n        \"\"\"\n        if self.f is None:\n            raise NotImplementedError(\n                \"Functional f is not defined and __call__ has not been overridden.\"\n            )\n        return self.scale * self.f(self.A(x) - self.y)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1, **kwargs\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Scaled proximal operator of loss function.\n\n        Evaluate scaled proximal operator of this loss function, with\n        scaling :math:`\\lambda` = `lam` and evaluated at point\n        :math:`\\mb{v}` = `v`. 
If :meth:`prox` is not defined in a derived\n        class, and if operator :math:`A` is the identity operator, then\n        the proximal operator is computed using the proximal operator of\n        functional :math:`f`, via Theorem 6.11 in :cite:`beck-2017-first`.\n\n        Args:\n            v: Point at which to evaluate prox function.\n            lam: Proximal parameter :math:`\\lambda`.\n            **kwargs: Additional arguments that may be used by derived\n               classes. These include `x0`, an initial guess for the\n               minimizer in the definition of :math:`\\mathrm{prox}`.\n\n        Returns:\n            Result of evaluating the scaled proximal operator at `v`.\n        \"\"\"\n        if not self.has_prox:\n            raise NotImplementedError(\n                f\"Method prox is not implemented for {type(self)} when A is {type(self.A)}; \"\n                \"A must be an Identity.\"\n            )\n        assert self.f is not None\n        return self.f.prox(v - self.y, self.scale * lam, **kwargs) + self.y\n\n    @_loss_mul_div_wrapper\n    def __mul__(self, other):\n        new_loss = copy(self)\n        new_loss._grad = scico.grad(new_loss.__call__)\n        new_loss.set_scale(self.scale * other)\n        return new_loss\n\n    def __rmul__(self, other):\n        return self.__mul__(other)\n\n    @_loss_mul_div_wrapper\n    def __truediv__(self, other):\n        new_loss = copy(self)\n        new_loss._grad = scico.grad(new_loss.__call__)\n        new_loss.set_scale(self.scale / other)\n        return new_loss\n\n    def set_scale(self, new_scale: float):\n        r\"\"\"Update the scale attribute.\"\"\"\n        self.scale = new_scale\n\n\nclass SquaredL2Loss(Loss):\n    r\"\"\"Weighted squared :math:`\\ell_2` loss.\n\n    Weighted squared :math:`\\ell_2` loss\n\n    .. 
math::\n        \\alpha \\norm{\\mb{y} - A(\\mb{x})}_W^2 =\n        \\alpha \\left(\\mb{y} - A(\\mb{x})\\right)^T W \\left(\\mb{y} -\n        A(\\mb{x})\\right) \\;,\n\n    where :math:`\\alpha` is the scaling parameter and :math:`W` is an\n    instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``,\n    the weighting is an identity operator, giving an unweighted squared\n    :math:`\\ell_2` loss.\n    \"\"\"\n\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        scale: float = 0.5,\n        W: Optional[linop.Diagonal] = None,\n        prox_kwargs: Optional[dict] = None,\n    ):\n        r\"\"\"\n        Args:\n            y: Measurement.\n            A: Forward operator. If ``None``, defaults to :class:`.Identity`.\n            scale: Scaling parameter.\n            W: Weighting diagonal operator. Must be non-negative.\n                If ``None``, defaults to :class:`.Identity`.\n            prox_kwargs: Dictionary of arguments passed to\n                :func:`scico.solver.cg` when :meth:`prox` is computed\n                by conjugate gradient. Entries update the defaults\n                ``{\"maxiter\": 100, \"tol\": 1e-5}``.\n        \"\"\"\n        self.W: linop.Diagonal\n\n        if W is None:\n            self.W = linop.Identity(y.shape)  # type: ignore\n        elif isinstance(W, linop.Diagonal):\n            if snp.all(W.diagonal >= 0):  # type: ignore\n                self.W = W\n            else:\n                raise ValueError(f\"The weights, W.diagonal, must be non-negative.\")\n        else:\n            raise TypeError(f\"Parameter W must be None or a linop.Diagonal, got {type(W)}.\")\n\n        super().__init__(y=y, A=A, scale=scale)\n\n        default_prox_kwargs = {\"maxiter\": 100, \"tol\": 1e-5}\n        if prox_kwargs:\n            default_prox_kwargs.update(prox_kwargs)\n        self.prox_kwargs = default_prox_kwargs\n\n        if isinstance(self.A, linop.LinearOperator):\n            self.has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2)\n\n    def prox(\n        
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        if not isinstance(self.A, linop.LinearOperator):\n            raise NotImplementedError(\n                f\"Method prox is not implemented for {type(self)} when A is {type(self.A)}; \"\n                \"A must be a LinearOperator.\"\n            )\n\n        if isinstance(self.A, linop.Diagonal):\n            c = 2.0 * self.scale * lam\n            A = self.A.diagonal\n            W = self.W.diagonal\n            lhs = c * A.conj() * W * self.y + v  # type: ignore\n            ATWA = c * A.conj() * W * A  # type: ignore\n            return lhs / (ATWA + 1.0)\n\n        #   prox_f(v) = arg min  1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W\n        #                  x\n        #   with solution: (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y\n        W = self.W\n        A = self.A\n        𝛼 = self.scale\n        y = self.y\n        if \"x0\" in kwargs and kwargs[\"x0\"] is not None:\n            x0 = kwargs[\"x0\"]\n        else:\n            x0 = snp.zeros_like(v)\n        hessian = self.hessian  # = (2𝛼 A^T W A)\n        lhs = linop.Identity(v.shape) + lam * hessian\n        rhs = v + 2 * lam * 𝛼 * A.adj(W(y))\n        x, _ = cg(lhs, rhs, x0, **self.prox_kwargs)  # type: ignore\n        return x\n\n    @property\n    def hessian(self) -> linop.LinearOperator:\n        r\"\"\"Compute the Hessian of linear operator `A`.\n\n        If `self.A` is a :class:`scico.linop.LinearOperator`, returns a\n        :class:`scico.linop.LinearOperator` corresponding to  the Hessian\n        :math:`2 \\alpha \\mathrm{A^H W A}`. 
Otherwise not implemented.\n        \"\"\"\n        A = self.A\n        W = self.W\n        if isinstance(A, linop.LinearOperator):\n            return linop.LinearOperator(\n                input_shape=A.input_shape,\n                output_shape=A.input_shape,\n                eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore\n                adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore\n                input_dtype=A.input_dtype,\n            )\n\n        raise NotImplementedError(\n            f\"Hessian is not implemented for {type(self)} when A is {type(A)}; \"\n            \"A must be LinearOperator.\"\n        )\n\n\nclass PoissonLoss(Loss):\n    r\"\"\"Poisson negative log likelihood loss.\n\n    Poisson negative log likelihood loss\n\n    .. math::\n        \\alpha \\sum_i \\left( [A(x)]_i - y_i \\log\\left( [A(x)]_i \\right) +\n        \\log(y_i!) \\right) \\;,\n\n    where :math:`\\alpha` is the scaling parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        scale: float = 0.5,\n    ):\n        r\"\"\"\n        Args:\n            y: Measurement.\n            A: Forward operator. Defaults to ``None``, in which case\n                `self.A` is a :class:`.Identity`.\n            scale: Scaling parameter. Default: 0.5.\n        \"\"\"\n        super().__init__(y=y, A=A, scale=scale)\n\n        #: Constant term, :math:`\\ln(y!)`, in Poisson log likelihood.\n        self.const = gammaln(self.y + 1.0)\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        Ax = self.A(x)\n        return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const)\n\n\nclass SquaredL2AbsLoss(Loss):\n    r\"\"\"Weighted squared :math:`\\ell_2` with absolute value loss.\n\n    Weighted squared :math:`\\ell_2` with absolute value loss\n\n    .. 
math::\n        \\alpha \\norm{\\mb{y} - | A(\\mb{x}) |\\,}_W^2 =\n        \\alpha \\left(\\mb{y} - | A(\\mb{x}) |\\right)^T W \\left(\\mb{y} -\n        | A(\\mb{x}) |\\right) \\;,\n\n    where :math:`\\alpha` is the scaling parameter and :math:`W` is an\n    instance of :class:`scico.linop.Diagonal`.\n\n    Proximal operator :meth:`prox` is implemented when :math:`A` is an\n    instance of :class:`scico.linop.Identity` and all entries of\n    :math:`\\mb{y}` are non-negative. This is not a proximal operator\n    according to the strict definition since the loss function is\n    non-convex (Sec. 3) :cite:`soulez-2016-proximity`.\n    \"\"\"\n\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        scale: float = 0.5,\n        W: Optional[linop.Diagonal] = None,\n    ):\n        r\"\"\"\n        Args:\n            y: Measurement.\n            A: Forward operator. If ``None``, defaults to :class:`.Identity`.\n            scale: Scaling parameter.\n            W: Weighting diagonal operator. 
Must be non-negative.\n                If ``None``, defaults to :class:`.Identity`.\n        \"\"\"\n        if W is None:\n            self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape)\n        elif isinstance(W, linop.Diagonal):\n            if snp.all(W.diagonal >= 0):\n                self.W = W\n            else:\n                raise ValueError(\"The weights, W.diagonal, must be non-negative.\")\n        else:\n            raise TypeError(f\"Parameter W must be None or a linop.Diagonal, got {type(W)}.\")\n\n        super().__init__(y=y, A=A, scale=scale)\n\n        if isinstance(self.A, linop.Identity) and snp.all(y >= 0):\n            self.has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x))) ** 2)\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        if not self.has_prox:\n            raise NotImplementedError(f\"Method prox is not implemented.\")\n\n        𝛼 = lam * 2.0 * self.scale * self.W.diagonal\n        y = self.y\n        r = snp.abs(v)\n        𝛽 = (𝛼 * y + r) / (𝛼 + 1.0)\n        x = snp.where(r > 0, (𝛽 / r) * v, 𝛽)\n        return x\n\n\ndef _cbrt(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n    \"\"\"Compute the cube root of the argument.\n\n    The two standard options for computing the cube root of an array are\n    :func:`numpy.cbrt`, or raising to the power of (1/3), i.e. `x ** (1/3)`.\n    The former cannot be used for complex values, and the latter returns\n    a complex root of a negative real value. 
This function can be used\n    for both real and complex values, and returns the real root of\n    negative real values.\n\n    Args:\n        x: Input array.\n\n    Returns:\n        Array of cube roots of input `x`.\n    \"\"\"\n    s = snp.where(snp.abs(snp.angle(x)) <= 2 * snp.pi / 3, 1, -1)\n    return s * (s * x) ** (1 / 3)\n\n\ndef _check_root(\n    x: Union[Array, BlockArray],\n    p: Union[Array, BlockArray],\n    q: Union[Array, BlockArray],\n    tol: float = 1e-4,\n):\n    \"\"\"Check the precision of a cubic equation solution.\n\n    Check the precision of an array of depressed cubic equation solutions,\n    issuing a warning if any of the errors exceed a specified tolerance.\n\n    Args:\n        x: Array of roots of a depressed cubic equation.\n        p: Array of linear parameters of a depressed cubic equation.\n        q: Array of constant parameters of a depressed cubic equation.\n        tol: Expected tolerance for solution precision.\n    \"\"\"\n    err = snp.abs(x**3 + p * x + q)\n    if not snp.allclose(err, 0, atol=tol):\n        idx = snp.argmax(err)\n        msg = (\n            \"Low precision in root calculation. Worst error is \"\n            f\"{err.ravel()[idx]:.3e} for p={p.ravel()[idx]} and q={q.ravel()[idx]}\"\n        )\n        warnings.warn(msg)\n\n\ndef _dep_cubic_root(\n    p: Union[Array, BlockArray], q: Union[Array, BlockArray]\n) -> Union[Array, BlockArray]:\n    r\"\"\"Compute a real root of a depressed cubic equation.\n\n    A depressed cubic equation is one that can be written in the form\n\n    .. math::\n       x^3 + px + q = 0 \\;.\n\n    The discriminant is\n\n    .. 
math::\n       \\Delta = (q/2)^2 + (p/3)^3 \\;.\n\n    When :math:`\\Delta > 0` this equation has one real root and two\n    complex (conjugate) roots, when :math:`\\Delta = 0`, all three roots\n    are real, with at least two being equal, and when :math:`\\Delta < 0`,\n    all roots are real and unequal.\n\n    According to Vieta's formulas, the roots :math:`x_0, x_1`, and\n    :math:`x_2` of this equation satisfy\n\n    .. math::\n       x_0 + x_1 + x_2 &= 0 \\\\\n       x_0 x_1 + x_0 x_2 + x_1 x_2 &= p \\\\\n       x_0 x_1 x_2 &= -q \\;.\n\n    Therefore, when :math:`q` is negative, the product of the roots is\n    positive and the equation has a single real positive root: if all\n    roots are real, at least one must be negative for their sum to be\n    zero, and their product could not be positive if only one root were\n    negative; if two roots are a complex conjugate pair, their product is\n    positive, so the real root must also be positive. This function\n    always returns a real root; when :math:`q` is negative, it returns\n    the single positive root.\n\n    The solution is computed using\n    `Vieta's substitution <https://mathworld.wolfram.com/CubicFormula.html>`__,\n\n    .. math::\n       x = w - \\frac{p}{3w} \\;,\n\n    which reduces the depressed cubic equation to\n\n    .. math::\n       w^3 - \\frac{p^3}{27w^3} + q = 0\\;,\n\n    which can be expressed as a quadratic equation in :math:`w^3` by\n    multiplication by :math:`w^3`, leading to\n\n    .. math::\n       w^3 = -\\frac{q}{2} \\pm \\sqrt{\\frac{q^2}{4} + \\frac{p^3}{27}} \\;.\n\n    Note that the multiplication by :math:`w^3` introduces a spurious\n    solution at zero in the case :math:`p = 0`, which must be handled\n    separately as\n\n    .. math::\n       w^3 = -q \\;.\n\n    Despite taking this into account, very poor numerical precision can\n    be obtained when :math:`p` is small but non-zero since, in this case\n\n    .. 
math::\n       \\sqrt{\\Delta} = \\sqrt{(q/2)^2 + (p/3)^3} \\approx |q|/2 \\;,\n\n    so that the incorrect solutions :math:`w^3 = 0` or :math:`w^3 = -q`\n    are obtained, depending on the choice of sign in the equation for\n    :math:`w^3`.\n\n    An alternative derivation leads to the equation\n\n    .. math::\n       x = \\sqrt[3]{-q/2 + \\sqrt{\\Delta}} + \\sqrt[3]{-q/2 - \\sqrt{\\Delta}}\n\n    for the real root, but this is also prone to severe numerical errors\n    in single precision arithmetic.\n\n    Args:\n       p: Array of :math:`p` values.\n       q: Array of :math:`q` values.\n\n    Returns:\n       Array of real roots of the cubic equation.\n    \"\"\"\n    Δ = (q**2) / 4.0 + (p**3) / 27.0\n    w3 = snp.where(snp.abs(p) <= 1e-7, -q, -q / 2.0 + snp.sqrt(Δ + 0j))\n    w = _cbrt(w3)\n    r = (w - no_nan_divide(p, 3 * w)).real\n    _check_root(r, p, q)\n    return r\n\n\nclass SquaredL2SquaredAbsLoss(Loss):\n    r\"\"\"Weighted squared :math:`\\ell_2` with squared absolute value loss.\n\n    Weighted squared :math:`\\ell_2` with squared absolute value loss\n\n    .. math::\n        \\alpha \\norm{\\mb{y} - | A(\\mb{x}) |^2 \\,}_W^2 =\n        \\alpha \\left(\\mb{y} - | A(\\mb{x}) |^2 \\right)^T W \\left(\\mb{y} -\n        | A(\\mb{x}) |^2 \\right) \\;,\n\n    where :math:`\\alpha` is the scaling parameter and :math:`W` is an\n    instance of :class:`scico.linop.Diagonal`.\n\n    Proximal operator :meth:`prox` is implemented when :math:`A` is an\n    instance of :class:`scico.linop.Identity` and all entries of\n    :math:`\\mb{y}` are non-negative. This is not a proximal operator\n    according to the strict definition since the loss function is\n    non-convex (Sec. 
3) :cite:`soulez-2016-proximity`.\n    \"\"\"\n\n    def __init__(\n        self,\n        y: Union[Array, BlockArray],\n        A: Optional[Union[Callable, operator.Operator]] = None,\n        scale: float = 0.5,\n        W: Optional[linop.Diagonal] = None,\n    ):\n        r\"\"\"\n        Args:\n            y: Measurement.\n            A: Forward operator. If ``None``, defaults to :class:`.Identity`.\n            scale: Scaling parameter.\n            W: Weighting diagonal operator. Must be non-negative.\n                If ``None``, defaults to :class:`.Identity`.\n        \"\"\"\n        if W is None:\n            self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape)\n        elif isinstance(W, linop.Diagonal):\n            if snp.all(W.diagonal >= 0):\n                self.W = W\n            else:\n                raise ValueError(\"The weights, W.diagonal, must be non-negative.\")\n        else:\n            raise TypeError(f\"Parameter W must be None or a linop.Diagonal, got {type(W)}.\")\n\n        super().__init__(y=y, A=A, scale=scale)\n\n        if isinstance(self.A, linop.Identity) and snp.all(y >= 0):\n            self.has_prox = True\n\n    def __call__(self, x: Union[Array, BlockArray]) -> float:\n        return self.scale * snp.sum(\n            self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x)) ** 2) ** 2\n        )\n\n    def prox(\n        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs\n    ) -> Union[Array, BlockArray]:\n        if not self.has_prox:\n            raise NotImplementedError(f\"Method prox is not implemented.\")\n\n        𝛼 = lam * 4.0 * self.scale * self.W.diagonal\n        𝛽 = snp.abs(v)\n        p = no_nan_divide(1.0 - 𝛼 * self.y, 𝛼)\n        q = no_nan_divide(-𝛽, 𝛼)\n        r = _dep_cubic_root(p, q)\n        φ = snp.where(𝛽 > 0, v / snp.abs(v), 1.0)\n        x = snp.where(𝛼 > 0, r * φ, v)\n        return x\n"
  },
  {
    "path": "scico/metric.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Image quality metrics and related functions.\"\"\"\n\n# This module is copied from https://github.com/bwohlberg/sporco\n\nfrom typing import Optional, Union\n\nimport numpy as np\n\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\n\n\ndef mae(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float:\n    \"\"\"Compute Mean Absolute Error (MAE) between two images.\n\n    Args:\n        reference: Reference image.\n        comparison: Comparison image.\n\n    Returns:\n        MAE between `reference` and `comparison`.\n    \"\"\"\n\n    return snp.mean(snp.abs(reference - comparison).ravel())\n\n\ndef mse(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float:\n    \"\"\"Compute Mean Squared Error (MSE) between two images.\n\n    Args:\n        reference : Reference image.\n        comparison : Comparison image.\n\n    Returns:\n        MSE between `reference` and `comparison`.\n    \"\"\"\n\n    return snp.mean(snp.abs(reference - comparison).ravel() ** 2)\n\n\ndef snr(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float:\n    \"\"\"Compute Signal to Noise Ratio (SNR) of two images.\n\n    Args:\n        reference: Reference image.\n        comparison: Comparison image.\n\n    Returns:\n        SNR of `comparison` with respect to `reference`.\n    \"\"\"\n\n    dv = snp.var(reference)\n    with np.errstate(divide=\"ignore\"):\n        rt = dv / mse(reference, comparison)\n    return 10.0 * snp.log10(rt)\n\n\ndef psnr(\n    reference: Union[Array, BlockArray],\n    comparison: Union[Array, BlockArray],\n    signal_range: Optional[Union[int, float]] = None,\n) -> 
float:\n    \"\"\"Compute Peak Signal to Noise Ratio (PSNR) of two images.\n\n    The PSNR calculation defaults to using the less common definition\n    in terms of the actual range (i.e. max minus min) of the reference\n    signal instead of the maximum possible range for the data type\n    (i.e. :math:`2^b-1` for a :math:`b` bit representation).\n\n    Args:\n        reference: Reference image.\n        comparison: Comparison image.\n        signal_range: Signal range, either the value to use (e.g. 255 for\n            8 bit samples) or ``None``, in which case the actual range of\n            the reference signal is used.\n\n    Returns:\n        PSNR of `comparison` with respect to `reference`.\n    \"\"\"\n\n    if signal_range is None:\n        signal_range = snp.abs(snp.max(reference) - snp.min(reference))\n    with np.errstate(divide=\"ignore\"):\n        rt = signal_range**2 / mse(reference, comparison)\n    return 10.0 * snp.log10(rt)\n\n\ndef isnr(\n    reference: Union[Array, BlockArray],\n    degraded: Union[Array, BlockArray],\n    restored: Union[Array, BlockArray],\n) -> float:\n    \"\"\"Compute Improvement Signal to Noise Ratio (ISNR).\n\n    Compute Improvement Signal to Noise Ratio (ISNR) for reference,\n    degraded, and restored images.\n\n    Args:\n        reference: Reference image.\n        degraded: Degraded/observed image.\n        restored: Restored/estimated image.\n\n    Returns:\n        ISNR of `restored` with respect to `reference` and `degraded`.\n    \"\"\"\n\n    msedeg = mse(reference, degraded)\n    mserst = mse(reference, restored)\n    with np.errstate(divide=\"ignore\"):\n        rt = msedeg / mserst\n    return 10.0 * snp.log10(rt)\n\n\ndef bsnr(blurry: Union[Array, BlockArray], noisy: Union[Array, BlockArray]) -> float:\n    \"\"\"Compute Blurred Signal to Noise Ratio (BSNR).\n\n    Compute Blurred Signal to Noise Ratio (BSNR) for a blurred and noisy\n    image.\n\n    Args:\n        blurry: Blurred noise free image.\n     
   noisy: Blurred image with additive noise.\n\n    Returns:\n        BSNR of `noisy` with respect to `blurry`.\n    \"\"\"\n\n    blrvar = snp.var(blurry)\n    nsevar = snp.var(noisy - blurry)\n    with np.errstate(divide=\"ignore\"):\n        rt = blrvar / nsevar\n    return 10.0 * snp.log10(rt)\n\n\ndef rel_res(ax: Union[BlockArray, Array], b: Union[BlockArray, Array]) -> float:\n    r\"\"\"Relative residual of the solution to a linear equation.\n\n    The standard relative residual for the linear system\n    :math:`A \\mathbf{x} = \\mathbf{b}` is :math:`\\|\\mathbf{b} -\n    A \\mathbf{x}\\|_2 / \\|\\mathbf{b}\\|_2`. This function computes a\n    variant :math:`\\|\\mathbf{b} - A \\mathbf{x}\\|_2 /\n    \\max(\\|A\\mathbf{x}\\|_2, \\|\\mathbf{b}\\|_2)` that is robust to the case\n    :math:`\\mathbf{b} = 0`.\n\n    Args:\n        ax: Linear component :math:`A \\mathbf{x}` of equation.\n        b: Constant component :math:`\\mathbf{b}` of equation.\n\n    Returns:\n        Relative residual value.\n    \"\"\"\n\n    nrm = max(snp.linalg.norm(ax.ravel()), snp.linalg.norm(b.ravel()))\n    if nrm == 0.0:\n        return 0.0\n    return snp.linalg.norm((b - ax).ravel()) / nrm\n"
  },
  {
    "path": "scico/numpy/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\nr\"\"\":class:`.BlockArray` and compatible functions.\n\nThis module consists of :class:`.BlockArray` and functions that support\nboth instances of this class and jax arrays. This includes all the\nfunctions from :mod:`jax.numpy` and :mod:`numpy.testing`, where many have\nbeen extended to automatically map over block array blocks as described\nin :ref:`numpy_functions_blockarray`. Also included are additional\nfunctions unique to SCICO in :mod:`.util`.\n\"\"\"\n\nimport sys\nfrom functools import partial\nfrom typing import Union\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax import Array\n\nfrom . import _wrappers, fft, linalg, testing, util\nfrom ._blockarray import BlockArray\nfrom ._wrapped_function_lists import (\n    creation_routines,\n    mathematical_functions,\n    reduction_functions,\n    testing_functions,\n)\n\n__all__ = [\"fft\", \"linalg\", \"testing\", \"util\"]\n\n# allow snp.blockarray(...) 
to create BlockArrays\nblockarray = BlockArray.blockarray\nblockarray.__module__ = __name__  # so that blockarray can be referenced in docs\n\n# BlockArray appears to originate in this module\nsys.modules[__name__].BlockArray.__module__ = __name__\n\n# copy most of jnp without wrapping\n_wrappers.add_attributes(to_dict=vars(), from_dict=jnp.__dict__)\n\n# wrap jnp funcs\n_wrappers.wrap_recursively(\n    vars(),\n    creation_routines,\n    partial(\n        _wrappers.map_func_over_args,\n        map_if_nested_args=[\"shape\"],\n        map_if_list_args=[\"device\"],\n    ),\n)\n_wrappers.wrap_recursively(vars(), mathematical_functions, _wrappers.map_func_over_args)\n_wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction)\n\n\ndef ravel(ba: Union[Array | BlockArray]) -> Array:\n    \"\"\"Completely flatten a :class:`BlockArray` into a single ``Array``.\n\n    When called on an ``Array``, flattens the array.\n\n    Args:\n        ba: The :class:`BlockArray` to flatten.\n\n    Returns:\n        `ba` flattened into a single ``Array.``\n    \"\"\"\n    if isinstance(ba, BlockArray):\n        return jax.numpy.concatenate([arr.flatten() for arr in ba])\n\n    return ba.ravel()\n\n\n# wrap testing funcs\n_wrappers.wrap_recursively(\n    vars(), testing_functions, partial(_wrappers.map_func_over_args, is_void=True)\n)\n\n# clean up\ndel np, jnp, _wrappers\n"
  },
  {
    "path": "scico/numpy/_blockarray.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SPORCO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"Block array class.\"\"\"\n\nimport inspect\nfrom functools import WRAPPER_ASSIGNMENTS, wraps\nfrom typing import Callable\n\nimport jax\nimport jax.numpy as jnp\n\nfrom ._wrapped_function_lists import binary_ops, unary_ops\nfrom .util import is_collapsible\n\n# Determine type of \"standard\" jax array since jax.Array is an abstract\n# base class type that is not suitable for use here.\nJaxArray = type(jnp.array([0]))\n\n\nclass BlockArray:\n    \"\"\"Block array class.\n\n    A block array provides a way to combine arrays of different shapes\n    into a single object for use with other SCICO classes. For further\n    information, see the\n    :ref:`detailed BlockArray documentation <blockarray_class>`.\n\n    Example\n    -------\n\n    >>> x = snp.blockarray((\n    ...     [[1, 3, 7],\n    ...      [2, 2, 1]],\n    ...     [2, 4, 8]\n    ... 
))\n    >>> x.shape\n    ((2, 3), (3,))\n    >>> snp.sum(x)\n    Array(30, dtype=int32)\n    \"\"\"\n\n    # Ensure we use BlockArray.__radd__, __rmul__, etc for binary\n    # operations of the form op(np.ndarray, BlockArray) See\n    # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__\n    __array_priority__ = 1\n\n    def __init__(self, inputs):\n        # convert inputs to jax arrays\n        self.arrays = [x if isinstance(x, jax.ShapeDtypeStruct) else jnp.array(x) for x in inputs]\n\n        # check that dtypes match\n        if not all(a.dtype == self.arrays[0].dtype for a in self.arrays):\n            raise ValueError(\"Heterogeneous dtypes not supported.\")\n\n    @property\n    def dtype(self):\n        \"\"\"Return the dtype of the blocks, which must currently be homogeneous.\n\n        This allows `snp.zeros(x.shape, x.dtype)` to work without a mechanism\n        to handle lists of dtypes.\n        \"\"\"\n        return self.arrays[0].dtype\n\n    def __len__(self):\n        return self.arrays.__len__()\n\n    def __getitem__(self, key):\n        \"\"\"Indexing method equivalent to x[key].\n\n        This is overridden to make, e.g., x[:2] return a BlockArray\n        rather than a list.\n        \"\"\"\n        result = self.arrays[key]\n        if isinstance(result, list):\n            return BlockArray(result)  # x[k:k+1] returns a BlockArray\n        return result  # x[k] returns a jax array\n\n    def __setitem__(self, key, value):\n        self.arrays[key] = value\n\n    @staticmethod\n    def blockarray(iterable):\n        \"\"\"Construct a :class:`.BlockArray` from a list or tuple of existing array-like.\"\"\"\n        return BlockArray(iterable)\n\n    def __repr__(self):\n        return f\"BlockArray({repr(self.arrays)})\"\n\n    def stack(self, axis=0):\n        \"\"\"Collapse a :class:`.BlockArray` to :class:`jax.Array`.\n\n        Collapse a :class:`.BlockArray` to :class:`jax.Array` by 
stacking\n        the blocks on axis `axis`.\n\n        Args:\n            axis: Index of new axis on which blocks are to be stacked.\n\n        Returns:\n            A :class:`jax.Array` obtained by stacking.\n\n        Raises:\n            ValueError: When called on a :class:`.BlockArray` that is not\n               stackable.\n        \"\"\"\n        if is_collapsible(self.shape):\n            return jnp.stack(self.arrays, axis=axis)\n        else:\n            raise ValueError(f\"BlockArray of shape {self.shape} cannot be collapsed to an Array.\")\n\n\n# Register BlockArray as a jax pytree; without this, jax autograd won't work.\n# Taken from what is done with tuples in jax._src.tree_util\njax.tree_util.register_pytree_node(\n    BlockArray,\n    lambda xs: (xs, None),  # to iter\n    lambda _, xs: BlockArray(xs),  # from iter\n)\n\n\n# Wrap unary ops like -x.\ndef _unary_op_wrapper(op_name):\n    op = getattr(JaxArray, op_name)\n\n    @wraps(op)\n    def op_block_array(self):\n        return BlockArray(op(x) for x in self)\n\n    return op_block_array\n\n\nfor op_name in unary_ops:\n    setattr(BlockArray, op_name, _unary_op_wrapper(op_name))\n\n\n# Wrap binary ops like x + y. \"\"\"\ndef _binary_op_wrapper(op_name):\n    op = getattr(JaxArray, op_name)\n\n    @wraps(op)\n    def op_block_array(self, other):\n        # If other is a block array, we can assume the operation is\n        # implemented (because block arrays must contain jax arrays)\n        if isinstance(other, BlockArray):\n            return BlockArray(op(x, y) for x, y in zip(self, other))\n\n        # If not, need to handle possible NotImplemented. 
Without this,\n        # block_array + 'hi' -> [NotImplemented, NotImplemented, ...]\n        result = list(op(x, other) for x in self)\n        if NotImplemented in result:\n            return NotImplemented\n        return BlockArray(result)\n\n    return op_block_array\n\n\nfor op_name in binary_ops:\n    setattr(BlockArray, op_name, _binary_op_wrapper(op_name))\n\n\n# Wrap jax array properties.\ndef _jax_array_prop_wrapper(prop_name):\n    prop = getattr(JaxArray, prop_name)\n\n    @property\n    @wraps(prop)\n    def prop_block_array(self):\n        result = tuple(getattr(x, prop_name) for x in self)\n\n        # If each jax_array.prop is a jax array, ...\n        if all([isinstance(x, jnp.ndarray) for x in result]):\n            # ...return a block array...\n            return BlockArray(result)\n\n        # ... otherwise return a tuple.\n        return result\n\n    return prop_block_array\n\n\nskip_props = (\"at\",)\njax_array_props = [\n    k\n    for k, v in dict(inspect.getmembers(JaxArray)).items()  # (name, method) pairs\n    if isinstance(v, property) and k[0] != \"_\" and k not in dir(BlockArray) and k not in skip_props\n]\n\nfor prop_name in jax_array_props:\n    setattr(BlockArray, prop_name, _jax_array_prop_wrapper(prop_name))\n\n\n# Wrap jax array methods.\ndef _jax_array_method_wrapper(method_name):\n    method = getattr(JaxArray, method_name)\n\n    # Don't try to set attributes that are None. Not clear why some\n    # functions/methods (e.g. 
block_until_ready) have None values\n    # for these attributes.\n    wrapper_assignments = WRAPPER_ASSIGNMENTS\n    for attr in (\"__name__\", \"__qualname__\"):\n        if getattr(method, attr) is None:\n            wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr)\n\n    @wraps(method, assigned=wrapper_assignments)\n    def method_block_array(self, *args, **kwargs):\n        result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self)\n\n        # If each jax_array.method(...) call returns a jax array, ...\n        if all([isinstance(x, jnp.ndarray) for x in result]):\n            # ... return a block array...\n            return BlockArray(result)\n\n        # ... otherwise return a tuple.\n        return result\n\n    return method_block_array\n\n\nskip_methods = ()\njax_array_methods = [\n    k\n    for k, v in dict(inspect.getmembers(JaxArray)).items()  # (name, method) pairs\n    if isinstance(v, Callable)\n    and k[0] != \"_\"\n    and k not in dir(BlockArray)\n    and k not in skip_methods\n]\n\nfor method_name in jax_array_methods:\n    setattr(BlockArray, method_name, _jax_array_method_wrapper(method_name))\n"
  },
  {
    "path": "scico/numpy/_wrapped_function_lists.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SPORCO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\n\"\"\"\nLists of functions to be wrapped in scico.numpy.\n\nThese are intended to be the functions in :mod:`jax.numpy` that should\neither\n   #. map over the blocks of a block array (for math functions);\n   #. map over a tuple of tuples to create a block array (for creation\n      functions); or\n   #. reduce a block array to a scalar (for reductions).\n\nThe links to the numpy docs in the comments are useful for distinguishing\nbetween these three cases, but note that these lists of numpy functions\ninclude extra functions that are not in :mod:`jax.numpy`, and that are\ntherefore not listed here.\n\"\"\"\n\n\"\"\" BlockArray \"\"\"\nunary_ops = (  # found from dir() on jax array\n    \"__abs__\",\n    \"__neg__\",\n    \"__pos__\",\n)\n\nbinary_ops = (  # found from dir() on jax array\n    \"__add__\",\n    \"__eq__\",\n    \"__floordiv__\",\n    \"__ge__\",\n    \"__gt__\",\n    \"__le__\",\n    \"__lt__\",\n    \"__matmul__\",\n    \"__mod__\",\n    \"__mul__\",\n    \"__ne__\",\n    \"__pow__\",\n    \"__radd__\",\n    \"__rfloordiv__\",\n    \"__rmatmul__\",\n    \"__rmul__\",\n    \"__rpow__\",\n    \"__rsub__\",\n    \"__rtruediv__\",\n    \"__sub__\",\n    \"__truediv__\",\n)\n\n\"\"\" jax.numpy \"\"\"\n\ncreation_routines = (\n    \"empty\",\n    \"ones\",\n    \"zeros\",\n    \"full\",\n)\n\nmathematical_functions = (\n    \"sin\",  # https://numpy.org/doc/stable/reference/routines.math.html\n    \"cos\",\n    \"tan\",\n    \"arcsin\",\n    \"arccos\",\n    \"arctan\",\n    \"hypot\",\n    \"arctan2\",\n    \"degrees\",\n    \"radians\",\n    \"unwrap\",\n    \"deg2rad\",\n    \"rad2deg\",\n    \"sinh\",\n    \"cosh\",\n    \"tanh\",\n    \"arcsinh\",\n    
\"arccosh\",\n    \"arctanh\",\n    \"around\",\n    \"round\",\n    \"rint\",\n    \"floor\",\n    \"ceil\",\n    \"trunc\",\n    \"prod\",\n    \"sum\",\n    \"nanprod\",\n    \"nansum\",\n    \"cumprod\",\n    \"cumsum\",\n    \"nancumprod\",\n    \"nancumsum\",\n    \"diff\",\n    \"ediff1d\",\n    \"gradient\",\n    \"cross\",\n    \"exp\",\n    \"expm1\",\n    \"exp2\",\n    \"log\",\n    \"log10\",\n    \"log2\",\n    \"log1p\",\n    \"logaddexp\",\n    \"logaddexp2\",\n    \"i0\",\n    \"sinc\",\n    \"signbit\",\n    \"copysign\",\n    \"frexp\",\n    \"ldexp\",\n    \"nextafter\",\n    \"lcm\",\n    \"gcd\",\n    \"add\",\n    \"reciprocal\",\n    \"positive\",\n    \"negative\",\n    \"multiply\",\n    \"divide\",\n    \"power\",\n    \"subtract\",\n    \"true_divide\",\n    \"floor_divide\",\n    \"float_power\",\n    \"fmod\",\n    \"mod\",\n    \"modf\",\n    \"remainder\",\n    \"divmod\",\n    \"angle\",\n    \"real\",\n    \"imag\",\n    \"conj\",\n    \"conjugate\",\n    \"maximum\",\n    \"fmax\",\n    \"amax\",\n    \"nanmax\",\n    \"minimum\",\n    \"fmin\",\n    \"amin\",\n    \"nanmin\",\n    \"convolve\",\n    \"clip\",\n    \"sqrt\",\n    \"cbrt\",\n    \"square\",\n    \"abs\",\n    \"absolute\",\n    \"fabs\",\n    \"sign\",\n    \"heaviside\",\n    \"nan_to_num\",\n    \"interp\",\n    \"sort\",  # https://numpy.org/doc/stable/reference/routines.sort.html\n    \"lexsort\",\n    \"argsort\",\n    \"sort_complex\",\n    \"partition\",\n    \"argmax\",\n    \"nanargmax\",\n    \"argmin\",\n    \"nanargmin\",\n    \"argwhere\",\n    \"nonzero\",\n    \"flatnonzero\",\n    \"where\",\n    \"searchsorted\",\n    \"extract\",\n    \"count_nonzero\",\n    \"dot\",  # https://numpy.org/doc/stable/reference/routines.linalg.html\n    \"linalg.multi_dot\",\n    \"vdot\",\n    \"inner\",\n    \"outer\",\n    \"matmul\",\n    \"tensordot\",\n    \"einsum\",\n    \"einsum_path\",\n    \"linalg.matrix_power\",\n    \"kron\",\n    \"linalg.cholesky\",\n 
   \"linalg.qr\",\n    \"linalg.svd\",\n    \"linalg.eig\",\n    \"linalg.eigh\",\n    \"linalg.eigvals\",\n    \"linalg.eigvalsh\",\n    \"linalg.norm\",\n    \"linalg.cond\",\n    \"linalg.det\",\n    \"linalg.matrix_rank\",\n    \"linalg.slogdet\",\n    \"trace\",\n    \"linalg.solve\",\n    \"linalg.tensorsolve\",\n    \"linalg.lstsq\",\n    \"linalg.inv\",\n    \"linalg.pinv\",\n    \"linalg.tensorinv\",\n    \"shape\",  # https://numpy.org/doc/stable/reference/routines.array-manipulation.html\n    \"reshape\",\n    \"moveaxis\",\n    \"rollaxis\",\n    \"swapaxes\",\n    \"transpose\",\n    \"atleast_1d\",\n    \"atleast_2d\",\n    \"atleast_3d\",\n    \"expand_dims\",\n    \"squeeze\",\n    \"asarray\",\n    \"stack\",\n    \"block\",\n    \"vstack\",\n    \"hstack\",\n    \"dstack\",\n    \"column_stack\",\n    \"split\",\n    \"array_split\",\n    \"dsplit\",\n    \"hsplit\",\n    \"vsplit\",\n    \"tile\",\n    \"repeat\",\n    \"insert\",\n    \"append\",\n    \"resize\",\n    \"trim_zeros\",\n    \"unique\",\n    \"pad\",\n    \"flip\",\n    \"fliplr\",\n    \"flipud\",\n    \"reshape\",\n    \"roll\",\n    \"rot90\",\n    \"all\",\n    \"any\",\n    \"isfinite\",\n    \"isinf\",\n    \"isnan\",\n    \"isneginf\",\n    \"isposinf\",\n    \"iscomplex\",\n    \"iscomplexobj\",\n    \"isreal\",\n    \"isrealobj\",\n    \"isscalar\",\n    \"logical_and\",\n    \"logical_or\",\n    \"logical_not\",\n    \"logical_xor\",\n    \"allclose\",\n    \"isclose\",\n    \"array_equal\",\n    \"array_equiv\",\n    \"greater\",\n    \"greater_equal\",\n    \"less\",\n    \"less_equal\",\n    \"equal\",\n    \"not_equal\",\n    \"empty_like\",  # https://numpy.org/doc/stable/reference/routines.array-creation.html\n    \"ones_like\",\n    \"zeros_like\",\n    \"full_like\",\n)\n\n# these may also appear in the mathematical function list\nreduction_functions = (\"sum\", \"linalg.norm\", \"count_nonzero\", \"all\", \"any\")\n\n\"\"\" testing \"\"\"\n\ntesting_functions = 
(\"testing.assert_allclose\", \"testing.assert_array_equal\")\n"
  },
  {
    "path": "scico/numpy/_wrappers.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SPORCO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\n\"\"\"Utilities for wrapping jnp functions to handle BlockArray inputs.\"\"\"\n\nimport sys\nimport warnings\nfrom functools import wraps\nfrom inspect import Parameter, signature\nfrom types import ModuleType\nfrom typing import Callable, Iterable, Optional\n\nimport jax.numpy as jnp\n\nimport scico.numpy as snp\n\nfrom ._blockarray import BlockArray\n\n\ndef add_attributes(\n    to_dict: dict,\n    from_dict: dict,\n    modules_to_recurse: Optional[Iterable[str]] = None,\n):\n    \"\"\"Add attributes in `from_dict` to `to_dict`.\n\n    Underscore attributes are ignored. Modules are ignored, except those\n    listed in `modules_to_recurse`, which are added recursively. All\n    others are added.\n    \"\"\"\n\n    if modules_to_recurse is None:\n        modules_to_recurse = ()\n\n    for name, obj in from_dict.items():\n        if name[0] == \"_\":\n            continue\n        if isinstance(obj, ModuleType):\n            if name in modules_to_recurse:\n                qualname = to_dict[\"__name__\"] + \".\" + name\n                to_dict[name] = ModuleType(name, doc=obj.__doc__)\n                to_dict[name].__package__ = to_dict[\"__name__\"]\n                # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm`\n                sys.modules[qualname] = to_dict[name]\n                sys.modules[qualname].__name__ = qualname\n                add_attributes(to_dict[name].__dict__, obj.__dict__)\n        else:\n            to_dict[name] = obj\n\n\ndef wrap_recursively(\n    target_dict: dict,\n    names: Iterable[str],\n    wrap: Callable,\n):\n    \"\"\"Call wrap functions in `target_dict`, correctly handling names like 
`\"linalg.norm\"`.\"\"\"\n\n    for name in names:\n        if \".\" in name:\n            module, rest = name.split(\".\", maxsplit=1)\n            wrap_recursively(target_dict[module].__dict__, [rest], wrap)\n        else:\n            if name in target_dict:\n                target_dict[name] = wrap(target_dict[name])\n            else:\n                warnings.warn(f\"In call to wrap_recursively, name {name} is not in target_dict\")\n\n\ndef map_func_over_args(\n    func: Callable,\n    map_if_nested_args: Optional[list[str]] = [],\n    map_if_list_args: Optional[list[str]] = [],\n    is_void: Optional[bool] = False,\n):\n    \"\"\"\n    Wrap a function so that it automatically maps over its arguments,\n    returning a BlockArray.\n\n    BlockArray arguments always trigger mapping. Other arguments trigger\n    mapping if they meet specified criteria.\n    \"\"\"\n    # check inputs\n    func_signature = signature(func)\n    for arg in map_if_nested_args + map_if_list_args:\n        if arg not in func_signature.parameters:\n            raise ValueError(f\"`{arg}` is not an argument of {func.__name__}\")\n\n    # define wrapped function\n    @wraps(func)\n    def wrapped(*args, **kwargs):\n        arg_names = [\n            k\n            for k, v in func_signature.parameters.items()\n            if v.kind\n            in (\n                Parameter.POSITIONAL_ONLY,\n                Parameter.POSITIONAL_OR_KEYWORD,\n            )\n        ]\n\n        # look in args for mapping triggers\n        arg_is_mapping = []\n        for arg_num, arg_val in enumerate(args):\n            if (\n                isinstance(arg_val, BlockArray)\n                or (\n                    snp.util.is_nested(arg_val)\n                    and arg_num < len(arg_names)\n                    and arg_names[arg_num] in map_if_nested_args\n                )\n                or (\n                    isinstance(arg_val, (list, tuple))\n                    and arg_num < len(arg_names)\n   
                 and arg_names[arg_num] in map_if_list_args\n                )\n            ):\n                arg_is_mapping.append(True)\n            else:\n                arg_is_mapping.append(False)\n\n        # look in kwargs for mapping triggers\n        kwarg_is_mapping = {}\n        for arg_name, arg_val in kwargs.items():\n            if (\n                isinstance(arg_val, BlockArray)\n                or (arg_name in map_if_nested_args and snp.util.is_nested(arg_val))\n                or (arg_name in map_if_list_args and isinstance(arg_val, (list, tuple)))\n            ):\n                kwarg_is_mapping[arg_name] = True\n            else:\n                kwarg_is_mapping[arg_name] = False\n\n        # no arguments that trigger mapping? call as usual\n        if sum(arg_is_mapping) == 0 and sum(kwarg_is_mapping.values()) == 0:\n            return func(*args, **kwargs)\n\n        # count number of blocks\n        num_blocks = (\n            len(\n                args[\n                    [index for index, mapping_flag in enumerate(arg_is_mapping) if mapping_flag][0]\n                ]\n            )  # first mapping arg\n            if sum(arg_is_mapping)\n            else len(\n                kwargs[[k for k, mapping_flag in kwarg_is_mapping.items() if mapping_flag][0]]\n            )  # first mapping kwarg\n        )\n\n        # map func over the mapping args\n        results = []\n        for block_ind in range(num_blocks):\n            result = func(\n                *[\n                    arg[block_ind] if is_mapping else arg\n                    for arg, is_mapping in zip(args, arg_is_mapping)\n                ],\n                **{\n                    k: kwargs[k][block_ind] if is_mapping else kwargs[k]\n                    for k, is_mapping in kwarg_is_mapping.items()\n                },\n            )\n            results.append(result)\n        if is_void:\n            return\n\n        return BlockArray(results)\n\n    return 
wrapped\n\n\ndef add_full_reduction(func: Callable, axis_arg_name: Optional[str] = \"axis\"):\n    \"\"\"Wrap a function so that it can fully reduce a BlockArray.\n\n    Wrap a function so that it can fully reduce a :class:`.BlockArray`. If\n    nothing is passed for the axis argument and the function is called\n    on a :class:`.BlockArray`, it is fully ravelled before the function is\n    called.\n\n    Should be outside :func:`map_func_over_args`.\n    \"\"\"\n    sig = signature(func)\n    if axis_arg_name not in sig.parameters:\n        raise ValueError(\n            f\"Cannot wrap {func} as a reduction because it has no {axis_arg_name} argument.\"\n        )\n\n    @wraps(func)\n    def wrapped(*args, **kwargs):\n        bound_args = sig.bind(*args, **kwargs)\n\n        ba_args = {}\n        for k, v in list(bound_args.arguments.items()):\n            if isinstance(v, BlockArray):\n                ba_args[k] = bound_args.arguments.pop(k)\n\n        if \"axis\" in bound_args.arguments:\n            return func(*bound_args.args, **bound_args.kwargs, **ba_args)  # call func as normal\n\n        if len(ba_args) > 1:\n            raise ValueError(\"Cannot perform a full reduction with multiple BlockArray arguments.\")\n\n        # fully ravel the ba argument\n        ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()}\n        return func(*bound_args.args, **bound_args.kwargs, **ba_args)\n\n    return wrapped\n"
  },
  {
    "path": "scico/numpy/fft.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Discrete Fourier Transform functions.\"\"\"\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\nfrom . import _wrappers\n\n_wrappers.add_attributes(\n    to_dict=vars(),\n    from_dict=jnp.fft.__dict__,\n)\n\n# clean up\ndel np, jnp, _wrappers\n"
  },
  {
    "path": "scico/numpy/linalg.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linear algebra functions.\"\"\"\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\nfrom . import _wrappers\n\n_wrappers.add_attributes(\n    to_dict=vars(),\n    from_dict=jnp.linalg.__dict__,\n)\n\n# clean up\ndel np, jnp, _wrappers\n"
  },
  {
    "path": "scico/numpy/testing.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Test support functions.\"\"\"\n\nimport numpy as np\n\nfrom . import _wrappers\n\n_wrappers.add_attributes(\n    to_dict=vars(),\n    from_dict=np.testing.__dict__,\n)\n\n# clean up\ndel np, _wrappers\n"
  },
  {
    "path": "scico/numpy/util.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SPORCO package. Details of the copyright\n# and user license can be found in the 'LICENSE.txt' file distributed\n# with the package.\n\n\"\"\"Utility functions for working with jax arrays and BlockArrays.\"\"\"\n\nfrom __future__ import annotations\n\nimport collections\nfrom math import prod\nfrom typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nfrom typing_extensions import TypeGuard\n\nimport scico.numpy as snp\nfrom scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, Shape\n\n\ndef transpose_ntpl_of_list(ntpl: NamedTuple) -> List[NamedTuple]:\n    \"\"\"Convert a namedtuple of lists/arrays to a list of namedtuples.\n\n    Args:\n        ntpl: Named tuple object to be transposed.\n\n    Returns:\n        List of namedtuple objects.\n    \"\"\"\n    cls = ntpl.__class__\n    numentry = len(ntpl[0]) if isinstance(ntpl[0], list) else ntpl[0].shape[0]\n    nfields = len(ntpl._fields)\n    return [cls(*[ntpl[m][n] for m in range(nfields)]) for n in range(numentry)]\n\n\ndef transpose_list_of_ntpl(ntlist: List[NamedTuple]) -> NamedTuple:\n    \"\"\"Convert a list of namedtuples to namedtuple of lists.\n\n    Args:\n        ntpl: List of namedtuple objects to be transposed.\n\n    Returns:\n        Named tuple of lists.\n    \"\"\"\n    cls = ntlist[0].__class__\n    numentry = len(ntlist)\n    nfields = len(ntlist[0])\n    return cls(*[[ntlist[m][n] for m in range(numentry)] for n in range(nfields)])  # type: ignore\n\n\ndef namedtuple_to_array(ntpl: NamedTuple) -> snp.Array:\n    \"\"\"Convert a namedtuple to an array.\n\n    Convert a :func:`collections.namedtuple` object to a\n    :class:`numpy.ndarray` object that can be saved using\n    :func:`numpy.savez`.\n\n    Args:\n        ntpl: Named tuple object to be converted to 
ndarray.\n\n    Returns:\n      Array representation of input named tuple.\n    \"\"\"\n    return np.asarray(\n        {\n            \"name\": ntpl.__class__.__name__,\n            \"fields\": ntpl._fields,\n            \"data\": {fname: fval for fname, fval in zip(ntpl._fields, ntpl)},\n        }\n    )\n\n\ndef array_to_namedtuple(array: snp.Array) -> NamedTuple:\n    \"\"\"Convert an array representation of a namedtuple back to a namedtuple.\n\n    Convert a :class:`numpy.ndarray` object constructed by\n    :func:`namedtuple_to_array` back to the original\n    :func:`collections.namedtuple` representation.\n\n    Args:\n      Array representation of named tuple constructed by\n        :func:`namedtuple_to_array`.\n\n    Returns:\n      Named tuple object with the same name and fields as the original\n      named tuple object provided to :func:`namedtuple_to_array`.\n    \"\"\"\n    cls = collections.namedtuple(array.item()[\"name\"], array.item()[\"fields\"])  # type: ignore\n    return cls(**array.item()[\"data\"])\n\n\ndef normalize_axes(\n    axes: Optional[Axes],\n    shape: Optional[Shape] = None,\n    default: Optional[List[int]] = None,\n    sort: bool = False,\n) -> Sequence[int]:\n    \"\"\"Normalize `axes` to a sequence and optionally ensure correctness.\n\n    Normalize `axes` to a tuple or list and (optionally) ensure that\n    entries refer to axes that exist in `shape`.\n\n    Args:\n        axes: User specification of one or more axes: int, list, tuple,\n           or ``None``. Negative values count from the last to the first\n           axis.\n        shape: The shape of the array of which axes are being specified.\n           If not ``None``, `axes` is checked to make sure its entries\n           refer to axes that exist in `shape`.\n        default: Default value to return if `axes` is ``None``. 
By\n           default, `tuple(range(len(shape)))`.\n        sort: If ``True``, sort the returned axis indices.\n\n    Returns:\n        Tuple or list of axes (never an int, never ``None``). The output\n        will only be a list if the input is a list or if the input is\n        ``None`` and `defaults` is a list.\n    \"\"\"\n\n    if axes is None:\n        if default is None:\n            if shape is None:\n                raise ValueError(\n                    \"Argument 'axes' cannot be None without a default or shape specified.\"\n                )\n            axes = tuple(range(len(shape)))\n        else:\n            axes = default\n    elif isinstance(axes, (list, tuple)):\n        axes = axes\n    elif isinstance(axes, int):\n        axes = (axes,)\n    else:\n        raise ValueError(f\"Could not understand argument 'axes' {axes} as a list of axes.\")\n    if shape is not None:\n        if min(axes) < 0:\n            axes = tuple([len(shape) + a if a < 0 else a for a in axes])\n        if max(axes) >= len(shape):\n            raise ValueError(\n                f\"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}.\"\n            )\n    if len(set(axes)) != len(axes):\n        raise ValueError(f\"Duplicate value in axes {axes}; each axis must be unique.\")\n    if sort:\n        axes = tuple(sorted(axes))\n    return axes\n\n\ndef slice_length(length: int, idx: AxisIndex) -> Optional[int]:\n    \"\"\"Determine the length of an array axis after indexing.\n\n    Determine the length of an array axis after slicing. An exception is\n    raised if the indexing expression is an integer that is out of bounds\n    for the specified axis length. 
A value of ``None`` is returned for\n    valid integer indexing expressions as an indication that the\n    corresponding axis shape is an empty tuple; this value should be\n    converted to a unit integer if the axis size is required.\n\n    Args:\n        length: Length of axis being sliced.\n        idx: Indexing/slice to be applied to axis.\n\n    Returns:\n        Length of indexed/sliced axis.\n\n    Raises:\n        ValueError: If `idx` is an integer index that is out bounds for\n            the axis length or if the type of `idx` is not one of\n            `Ellipsis`, `int`, or `slice`.\n    \"\"\"\n    if idx is Ellipsis:\n        return length\n    if isinstance(idx, int):\n        if idx < -length or idx > length - 1:\n            raise ValueError(f\"Index {idx} out of bounds for axis of length {length}.\")\n        return None\n    if not isinstance(idx, slice):\n        raise ValueError(f\"Index expression {idx} is of an unrecognized type.\")\n    start, stop, stride = idx.indices(length)\n    if start > stop:\n        start = stop\n    return (stop - start + stride - 1) // stride\n\n\ndef indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:\n    \"\"\"Determine the shape of an array after indexing/slicing.\n\n    The indexed shape is determined by replicating the observed effects\n    of NumPy/JAX array indexing/slicing syntax. 
It is significantly\n    faster than :func:`.jax_indexed_shape`, and has a minimal memory\n    footprint in all circumstances.\n\n    Args:\n        shape: Shape of array.\n        idx: Indexing expression (singleton or tuple of `Ellipsis`,\n           `int`, `slice`, or ``None`` (`np.newaxis`)).\n\n    Returns:\n        Shape of indexed/sliced array.\n\n    Raises:\n        ValueError: If any element of `idx` is not one of `Ellipsis`,\n        `int`, `slice`, or ``None`` (`np.newaxis`), or if an integer\n        index is out bounds for the corresponding axis length.\n    \"\"\"\n    if not isinstance(idx, tuple):\n        idx = (idx,)\n    idx_shape: List[Optional[int]] = list(shape)\n    offset = 0\n    newaxis = 0\n    for axis, ax_idx in enumerate(idx):\n        if ax_idx is None:\n            idx_shape.insert(axis, 1)\n            newaxis += 1\n            continue\n        if ax_idx is Ellipsis:\n            offset = len(shape) - len(idx)\n            continue\n        idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx)\n    return tuple(filter(lambda x: x is not None, idx_shape))  # type: ignore\n\n\ndef jax_indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:\n    \"\"\"Determine the shape of an array after indexing/slicing.\n\n    The indexed shape is determined by constructing and indexing an array\n    of the appropriate shape, relying on :func:`jax.jit` to avoid memory\n    allocation. 
It is potentially more reliable than\n    :func:`.indexed_shape` because the indexing/slicing calculations are\n    referred to JAX, but is significantly slower, and may involve\n    significant memory allocations if JIT is disabled, e.g.\n    for debugging purposes.\n\n    Args:\n        shape: Shape of array.\n        idx: Indexing expression (singleton or tuple of `Ellipsis`,\n           `int`, `slice`, or ``None`` (`np.newaxis`)).\n\n    Returns:\n        Shape of indexed/sliced array.\n    \"\"\"\n    if not isinstance(idx, tuple):\n        idx = (idx,)\n\n    # Convert any slices to their representation (slice, (start, stop, step))\n    # allowing hashing, needed for jax.jit\n    idx = tuple(exp.__reduce__() if isinstance(exp, slice) else exp for exp in idx)  # type: ignore\n\n    def get_shape(in_shape, ind_expr):\n        # convert slice representations back to slices\n        ind_expr = tuple(\n            (slice(*exp[1]) if isinstance(exp, tuple) and len(exp) > 0 and exp[0] == slice else exp)\n            for exp in ind_expr\n        )\n        return jax.numpy.empty(in_shape)[ind_expr].shape\n\n    # This compiles each time it gets new arguments because all arguments are static.\n    f = jax.jit(get_shape, static_argnums=(0, 1))\n\n    return tuple(t.item() for t in f(shape, idx))  # type: ignore\n\n\ndef no_nan_divide(\n    x: Union[snp.BlockArray, snp.Array], y: Union[snp.BlockArray, snp.Array]\n) -> Union[snp.BlockArray, snp.Array]:\n    \"\"\"Return `x/y`, with 0 instead of :data:`~numpy.NaN` where `y` is 0.\n\n    Args:\n        x: Numerator.\n        y: Denominator.\n\n    Returns:\n        `x / y` with 0 wherever `y == 0`.\n    \"\"\"\n\n    return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0)\n\n\ndef _readable_size(size: int) -> str:\n    \"\"\"Return a human-readable representation of an array size.\n\n    Args:\n        size: A positive integer array size.\n\n    Returns:\n        A string representation of the 
size.\n    \"\"\"\n    factor = [1, 1024, 1024**2, 1024**3, 1024**4]\n    units = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"]\n    idx_tuple = np.nonzero([size // f for f in factor[::-1]])\n    if idx_tuple[0].size == 0:\n        idx = len(factor) - 1\n    else:\n        idx = int(idx_tuple[0][0])\n    val = size // factor[::-1][idx]\n    ustr = units[::-1][idx]\n    return f\"{val} {ustr}\"\n\n\ndef array_info(x: Union[snp.BlockArray, snp.Array]) -> str:\n    \"\"\"Return a string providing information about an array.\n\n    Args:\n        x: A numpy or jax array or scico :class:`BlockArray`.\n\n    Returns:\n        A string containing information on the array.\n\n    Raises:\n       TypeError: If the array is not of a recognized type.\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        array_type = \"numpy.ndarray\"\n    elif isinstance(x, jax.Array):\n        array_type = \"jax.Array\"\n    elif isinstance(x, snp.BlockArray):\n        array_type = \"scico.numpy.BlockArray\"\n    else:\n        raise TypeError(f\"Unrecognized array type {type(x)}.\")\n    totalbytes = np.sum(x.nbytes).item()  # type: ignore\n    return (\n        f\"\"\"{array_type}\n  shape:    {x.shape}\n  size:     {x.size}\n  bytes:    {totalbytes} ({_readable_size(totalbytes)})\n\"\"\"\n        + (f\"  device:   {x.device}\\n\" if hasattr(x, \"device\") else \"\")\n        + f\"\"\"  dtype:    {dtype_name(x.dtype)}\n  id:       {id(x)}\n  min, max: {snp.ravel(x).min()}, {snp.ravel(x).max()}\n\"\"\"\n    )\n\n\ndef shape_to_size(shape: Union[Shape, BlockShape]) -> int:\n    r\"\"\"Compute array size corresponding to a specified shape.\n\n    Compute array size corresponding to a specified shape, which may be\n    nested, i.e. 
corresponding to a :class:`BlockArray`.\n\n    Args:\n        shape: A shape tuple.\n\n    Returns:\n        The number of elements in an array or :class:`BlockArray` with\n        shape `shape`.\n    \"\"\"\n\n    if is_nested(shape):\n        return sum(prod(s) for s in shape)  # type: ignore\n\n    return prod(shape)  # type: ignore\n\n\ndef is_array(x: Any) -> bool:\n    \"\"\"Check if input is of type :class:`jax.Array` or :class:`numpy.ndarray`.\n\n    Check if input is an array, of type :class:`jax.Array` or\n    :class:`numpy.ndarray`.\n\n    Args:\n        x: Object to be tested.\n\n    Returns:\n        ``True`` if `x` is an array, ``False`` otherwise.\n    \"\"\"\n    return isinstance(x, (np.ndarray, jax.Array))\n\n\ndef is_arraylike(x: Any) -> bool:\n    \"\"\"Check if input is of type :class:`jax.typing.ArrayLike`.\n\n    `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10,\n    see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices.\n\n    Args:\n        x: Object to be tested.\n\n    Returns:\n        ``True`` if `x` is an ArrayLike, ``False`` otherwise.\n    \"\"\"\n    return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x)\n\n\ndef is_nested(x: Any) -> bool:\n    \"\"\"Check if input is a list/tuple containing at least one list/tuple.\n\n    Args:\n        x: Object to be tested.\n\n    Returns:\n        ``True`` if `x` is a list/tuple containing at least one\n        list/tuple, ``False`` otherwise.\n\n    Example:\n        >>> is_nested([1, 2, 3])\n        False\n        >>> is_nested([(1,2), (3,)])\n        True\n        >>> is_nested([[1, 2], 3])\n        True\n\n    \"\"\"\n    return isinstance(x, (list, tuple)) and any([isinstance(_, (list, tuple)) for _ in x])\n\n\ndef is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool:\n    \"\"\"Determine whether a sequence of shapes can be collapsed.\n\n    Return ``True`` if the a list of shapes represent arrays that can\n    be 
stacked, i.e., they are all the same.\n\n    Args:\n        shapes: A sequence of shapes.\n\n    Returns:\n        A boolean value indicating whether the shapes are all the same.\n    \"\"\"\n    return all(s == shapes[0] for s in shapes)\n\n\ndef is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]:\n    \"\"\"Determine whether a sequence of shapes could be a :class:`BlockArray` shape.\n\n    Return ``True`` if the sequence of shapes represent arrays that can\n    be combined into a :class:`BlockArray`, i.e., none are nested.\n\n    Args:\n        shapes: A sequence of shapes.\n\n    Returns:\n        A boolean value indicating whether any of the shapes are nested.\n    \"\"\"\n    return not any(is_nested(s) for s in shapes)\n\n\ndef shape_dtype_rep(\n    shape: Union[Shape, BlockShape], dtype: DType\n) -> Union[jax.ShapeDtypeStruct, snp.BlockArray]:\n    \"\"\"Construct a representation of array or blockarray shape and dtype.\n\n    Construct a representation of array or block array shape and dtype\n    that is suitable for both jax arrays and scico blockarrays.\n\n    Args:\n       shape: Array or blockarray shape.\n       dtype: Array or blockarray dtype.\n\n    Returns:\n       A :class:`jax.ShapeDtypeStruct` or a :class:`.BlockArray`\n       containing :class:`jax.ShapeDtypeStruct`s.\n    \"\"\"\n    if is_nested(shape):  # block array\n        return snp.BlockArray([jax.ShapeDtypeStruct(blk_shape, dtype=dtype) for blk_shape in shape])\n    else:  # standard array\n        return jax.ShapeDtypeStruct(shape, dtype=dtype)\n\n\ndef broadcast_nested_shapes(\n    shape_a: Union[Shape, BlockShape], shape_b: Union[Shape, BlockShape]\n) -> Union[Shape, BlockShape]:\n    r\"\"\"Compute the result of broadcasting on array shapes.\n\n    Compute the result of applying a broadcasting binary operator to\n    (block) arrays with (possibly nested) shapes `shape_a` and `shape_b`.\n    Extends :func:`numpy.broadcast_shapes` to also 
support the nested\n    tuple shapes of :class:`BlockArray`\\ s.\n\n    Args:\n        shape_a: First array shape.\n        shape_b: Second array shape.\n\n    Returns:\n        A (possibly nested) shape tuple.\n\n    Example:\n        >>> broadcast_nested_shapes(((1, 1, 3), (2, 3, 1)), ((2, 3,), (2, 1, 4)))\n        ((1, 2, 3), (2, 3, 4))\n    \"\"\"\n    if not is_nested(shape_a) and not is_nested(shape_b):\n        return snp.broadcast_shapes(shape_a, shape_b)\n\n    if is_nested(shape_a) and not is_nested(shape_b):\n        return tuple(snp.broadcast_shapes(s, shape_b) for s in shape_a)\n\n    if not is_nested(shape_a) and is_nested(shape_b):\n        return tuple(snp.broadcast_shapes(shape_a, s) for s in shape_b)\n\n    if is_nested(shape_a) and is_nested(shape_b):\n        return tuple(snp.broadcast_shapes(s_a, s_b) for s_a, s_b in zip(shape_a, shape_b))\n\n    raise RuntimeError(\"Unexpected case encountered in broadcast_nested_shapes.\")\n\n\ndef is_real_dtype(dtype: DType) -> bool:\n    \"\"\"Determine whether a dtype is real.\n\n    Args:\n        dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g.\n               :attr:`~numpy.float32`, :attr:`~numpy.complex64`).\n\n    Returns:\n        ``False`` if the dtype is complex, otherwise ``True``.\n    \"\"\"\n    return snp.dtype(dtype).kind != \"c\"\n\n\ndef is_complex_dtype(dtype: DType) -> bool:\n    \"\"\"Determine whether a dtype is complex.\n\n    Args:\n        dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g.\n               :attr:`~numpy.float32`, :attr:`~numpy.complex64`).\n\n    Returns:\n        ``True`` if the dtype is complex, otherwise ``False``.\n    \"\"\"\n    return snp.dtype(dtype).kind == \"c\"\n\n\ndef real_dtype(dtype: DType) -> DType:\n    \"\"\"Construct the corresponding real dtype for a given complex dtype.\n\n    Construct the corresponding real dtype for a given complex dtype,\n    e.g. 
the real dtype corresponding to :attr:`~numpy.complex64` is\n    :attr:`~numpy.float32`.\n\n    Args:\n        dtype: A complex numpy or scico.numpy dtype (e.g.\n               :attr:`~numpy.complex64`, :attr:`~numpy.complex128`).\n\n    Returns:\n        The real dtype corresponding to the input dtype\n    \"\"\"\n\n    return snp.zeros(1, dtype).real.dtype\n\n\ndef complex_dtype(dtype: DType) -> DType:\n    \"\"\"Construct the corresponding complex dtype for a given real dtype.\n\n    Construct the corresponding complex dtype for a given real dtype,\n    e.g. the complex dtype corresponding to :attr:`~numpy.float32` is\n    :attr:`~numpy.complex64`.\n\n    Args:\n        dtype: A real numpy or scico.numpy dtype (e.g. :attr:`~numpy.float32`,\n               :attr:`~numpy.float64`).\n\n    Returns:\n        The complex dtype corresponding to the input dtype.\n    \"\"\"\n\n    return (snp.zeros(1, dtype) + 1j).dtype\n\n\ndef dtype_name(dtype: DType) -> str:\n    \"\"\"Return the name of a dtype.\n\n    Construct a string representation of a dtype name.\n\n    Args:\n        dtype: The dtype for which the name is required.\n\n    Returns:\n        The name of the dtype.\n    \"\"\"\n    if type(dtype).__module__ == \"numpy.dtypes\":\n        return f\"\"\"numpy.{dtype.name}\"\"\"  # type: ignore\n    return f\"\"\"{dtype.__module__}.{dtype.__qualname__}\"\"\"  # type: ignore\n\n\ndef is_scalar_equiv(s: Any) -> bool:\n    \"\"\"Determine whether an object is a scalar or is scalar-equivalent.\n\n    Determine whether an object is a scalar or a singleton array.\n\n    Args:\n        s: Object to be tested.\n\n    Returns:\n        ``True`` if the object is a scalar or a singleton array,\n        otherwise ``False``.\n    \"\"\"\n    return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0)\n"
  },
  {
    "path": "scico/operator/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Operator functions and classes.\"\"\"\n\nimport sys\n\n# isort: off\nfrom ._operator import Operator\nfrom .biconvolve import BiConvolve\nfrom ._func import operator_from_function, Abs, Angle, Exp\nfrom ._stack import DiagonalStack, VerticalStack, DiagonalReplicated\n\n__all__ = [\n    \"Operator\",\n    \"BiConvolve\",\n    \"DiagonalReplicated\",\n    \"DiagonalStack\",\n    \"VerticalStack\",\n    \"operator_from_function\",\n    \"Abs\",\n    \"Angle\",\n    \"Exp\",\n]\n\n# Imported items in __all__ appear to originate in top-level operator module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/operator/_func.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Operators constructed from functions.\"\"\"\n\nfrom typing import Any, Callable, Optional, Union\n\nimport scico.numpy as snp\nfrom scico.typing import BlockShape, DType, Shape\n\nfrom ._operator import Operator\n\n__all__ = [\n    \"operator_from_function\",\n    \"Abs\",\n    \"Angle\",\n    \"Exp\",\n]\n\n\ndef operator_from_function(f: Callable, classname: str, f_name: Optional[str] = None):\n    \"\"\"Make an :class:`.Operator` from a function.\n\n    Example\n    -------\n    >>> AbsVal = operator_from_function(snp.abs, 'AbsVal')\n    >>> H = AbsVal((2,))\n    >>> H(snp.array([1.0, -1.0]))\n    Array([1., 1.], dtype=float32)\n\n    Args:\n        f: Function from which to create an :class:`.Operator`.\n        classname: Name of the resulting class.\n        f_name: Name of `f` for use in docstrings. Useful for getting\n            the correct version of wrapped functions. Defaults to\n            `f\"{f.__module__}.{f.__name__}\"`.\n    \"\"\"\n\n    if f_name is None:\n        f_name = f\"{f.__module__}.{f.__name__}\"\n\n    f_doc = rf\"\"\"\n\n        Args:\n            input_shape: Shape of input array.\n            args: Positional arguments passed to :func:`{f_name}`.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`. If the :class:`.Operator`\n                implements complex-valued operations, this must be a\n                complex dtype (typically :attr:`~numpy.complex64`) for\n                correct adjoint and gradient calculation.\n            output_shape: Shape of output array. 
Defaults to ``None``.\n                If ``None``, `output_shape` is determined by evaluating\n                `self.__call__` on an input array of zeros.\n            output_dtype: `dtype` for output argument. Defaults to\n                ``None``. If ``None``, `output_dtype` is determined by\n                evaluating `self.__call__` on an input array of zeros.\n            jit: If ``True``, call :meth:`.Operator.jit` on this\n                `Operator` to jit the forward, adjoint, and gram\n                functions. Same as calling :meth:`.Operator.jit` after\n                the :class:`.Operator` is created.\n            **kwargs: Keyword arguments passed to :func:`{f_name}`.\n        \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Union[Shape, BlockShape],\n        *args: Any,\n        input_dtype: DType = snp.float32,\n        output_shape: Optional[Union[Shape, BlockShape]] = None,\n        output_dtype: Optional[DType] = None,\n        jit: bool = True,\n        **kwargs: Any,\n    ):\n        self._eval = lambda x: f(x, *args, **kwargs)\n        super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit)  # type: ignore\n\n    OpClass = type(classname, (Operator,), {\"__init__\": __init__})\n    __class__ = OpClass  # needed for super() to work\n\n    OpClass.__doc__ = f\"Operator version of :func:`{f_name}`.\"\n    OpClass.__init__.__doc__ = f_doc  # type: ignore\n\n    return OpClass\n\n\nAbs = operator_from_function(snp.abs, \"Abs\", \"scico.numpy.abs\")\nAngle = operator_from_function(snp.angle, \"Angle\", \"scico.numpy.angle\")\nExp = operator_from_function(snp.exp, \"Exp\", \"scico.numpy.exp\")\n"
  },
  {
    "path": "scico/operator/_operator.py",
    "content": "# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Operator base class.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom functools import wraps\nfrom typing import Callable, Optional, Tuple, Union\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.dtypes import result_type\n\nimport scico\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import dtype_name, is_nested, shape_to_size\nfrom scico.typing import BlockShape, DType, Shape\n\n\ndef _wrap_mul_div_scalar(func: Callable) -> Callable:\n    r\"\"\"Wrapper function for multiplication and division operators.\n\n    Wrapper function for defining `__mul__`, `__rmul__`, and\n    `__truediv__` between a scalar and an `Operator`.\n\n    If one of these binary operations is called in the form\n    `binop(Operator, other)` and `other` is a scalar (as determined by\n    :func:`scico.numpy.util.is_scalar_equiv`), the wrapped `func` is\n    called to construct the resulting :class:`.Operator`.\n\n    Args:\n        func: Should be either `.__mul__()`, `.__rmul__()`,\n           or `.__truediv__()`.\n\n    Returns:\n       Wrapped version of `func`.\n\n    Raises:\n        TypeError: If a binop with the form `binop(Operator, other)` is\n        called and `other` is not a scalar.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(a, b):\n        if snp.util.is_scalar_equiv(b):\n            return func(a, b)\n\n        raise TypeError(f\"Operation {func.__name__} not defined between {type(a)} and {type(b)}.\")\n\n    wrapper._unwrapped = func  # type: ignore\n\n    return wrapper\n\n\nclass Operator:\n    \"\"\"Generic operator class.\"\"\"\n\n    # See 
https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__\n    __array_priority__ = 1\n\n    def __init__(\n        self,\n        input_shape: Union[Shape, BlockShape],\n        output_shape: Optional[Union[Shape, BlockShape]] = None,\n        eval_fn: Optional[Callable] = None,\n        input_dtype: DType = np.float32,\n        output_dtype: Optional[DType] = None,\n        jit: bool = False,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input array.\n            output_shape: Shape of output array. Defaults to ``None``.\n                If ``None``, `output_shape` is determined by evaluating\n                `self.__call__` on an input array of zeros.\n            eval_fn: Function used in evaluating this :class:`.Operator`.\n                Defaults to ``None``. Required unless `__init__` is being\n                called from a derived class with an `_eval` method.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`. If the :class:`.Operator`\n                implements complex-valued operations, this must be a\n                complex dtype (typically :attr:`~numpy.complex64`) for\n                correct adjoint and gradient calculation.\n            output_dtype: `dtype` for output argument. Defaults to\n                ``None``. If ``None``, `output_dtype` is determined by\n                evaluating `self.__call__` on an input array of zeros.\n            jit: If ``True``, call :meth:`Operator.jit()` on this\n                :class:`.Operator` to jit the forward, adjoint, and gram\n                functions. 
Same as calling :meth:`Operator.jit` after the\n                :class:`.Operator` is created.\n\n        Raises:\n            NotImplementedError: If the `eval_fn` parameter is not\n               specified and the `_eval` method is not defined in a\n               derived class.\n        \"\"\"\n\n        #: Shape of input array or :class:`.BlockArray`.\n        self.input_shape: Union[Shape, BlockShape]\n\n        #: Size of flattened input. Sum of product of `input_shape` tuples.\n        self.input_size: int\n\n        #: Shape of output array or :class:`.BlockArray`\n        self.output_shape: Union[Shape, BlockShape]\n\n        #: Size of flattened output. Sum of product of `output_shape` tuples.\n        self.output_size: int\n\n        #: Shape Operator would take if it operated on flattened arrays.\n        #: Consists of (output_size, input_size)\n        self.matrix_shape: Tuple[int, int]\n\n        #: Shape of Operator, consisting of (output_shape, input_shape).\n        self.shape: Tuple[Union[Shape, BlockShape], Union[Shape, BlockShape]]\n\n        #: Dtype of input\n        self.input_dtype: DType\n\n        #: Dtype of operator\n        self.dtype: DType\n\n        if isinstance(input_shape, int):\n            self.input_shape = (input_shape,)\n        else:\n            self.input_shape = input_shape\n        self.input_dtype = input_dtype\n\n        # Allows for dynamic creation of new Operator/LinearOperator, e.g. 
for adjoints\n        if eval_fn:\n            self._eval = eval_fn  # type: ignore\n        elif not hasattr(self, \"_eval\"):\n            raise NotImplementedError(\n                \"Operator is an abstract base class when argument 'eval_fn' is not specified.\"\n            )\n\n        # If the output shape/dtype aren't specified, they can be inferred\n        # using scico.eval_shape\n        if output_shape is None or output_dtype is None:\n            dts = scico.eval_shape(\n                self._eval, jax.ShapeDtypeStruct(self.input_shape, dtype=input_dtype)\n            )\n        if output_shape is None:\n            self.output_shape = dts.shape  # type: ignore\n        else:\n            self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape\n        if output_dtype is None:\n            self.output_dtype = dts.dtype\n        else:\n            self.output_dtype = output_dtype\n\n        # Determine the shape of the \"vectorized\" operator (as an element of ℝ^{n × m}\n        # If the function returns a BlockArray we need to compute the size of each block,\n        # then sum.\n        self.input_size = shape_to_size(self.input_shape)\n        self.output_size = shape_to_size(self.output_shape)\n\n        self.shape = (self.output_shape, self.input_shape)\n        self.matrix_shape = (self.output_size, self.input_size)\n\n        if jit:\n            self.jit()\n\n    def jit(self):\n        \"\"\"Activate just-in-time compilation for the `_eval` method.\"\"\"\n        self._eval = jax.jit(self._eval)\n\n    def __str__(self):\n        return f\"\"\"{self.__module__}.{self.__class__.__qualname__}\"\"\"\n\n    def __repr__(self):\n        return f\"\"\"{str(self)}\n  input_shape:  {self.input_shape}\n  output_shape: {self.output_shape}\n  input_dtype:  {dtype_name(self.input_dtype)}\n  output_dtype: {dtype_name(self.output_dtype)}\n\"\"\"\n\n    def __call__(self, x: Union[Operator, Array, BlockArray]) -> Union[Operator, 
Array, BlockArray]:\n        r\"\"\"Evaluate this :class:`Operator` at the point :math:`\\mb{x}`.\n\n        Args:\n            x: Point at which to evaluate this :class:`.Operator`. If `x`\n               is a :class:`jax.Array` or :class:`.BlockArray`, it must\n               have `shape == self.input_shape`. If `x` is a\n               :class:`.Operator` or :class:`.LinearOperator`, it must\n               have `x.output_shape == self.input_shape`.\n\n        Returns:\n             :class:`.Operator` evaluated at `x`.\n\n        Raises:\n            ValueError: If the `input_shape` attribute of the\n                :class:`.Operator` is not equal to the input array shape,\n                or to the `output_shape` attribute of another\n                :class:`.Operator` with which it is composed.\n        \"\"\"\n\n        if isinstance(x, Operator):\n            # Compose the two operators if shapes conform\n            if self.input_shape == x.output_shape:\n                return Operator(\n                    input_shape=x.input_shape,\n                    output_shape=self.output_shape,\n                    eval_fn=lambda z: self(x(z)),\n                    input_dtype=self.input_dtype,\n                    output_dtype=x.output_dtype,\n                )\n            raise ValueError(f\"Incompatible shapes {self.shape}, {x.shape}.\")\n\n        if self.input_shape != x.shape:\n            raise ValueError(\n                f\"Cannot evaluate {type(self)} with input_shape={self.input_shape} \"\n                f\"on array with shape={x.shape}.\"\n            )\n\n        return self._eval(x)\n\n    def __add__(self, other: Operator) -> Operator:\n        if isinstance(other, Operator):\n            if self.shape == other.shape:\n                return Operator(\n                    input_shape=self.input_shape,\n                    output_shape=self.output_shape,\n                    eval_fn=lambda x: self(x) + other(x),\n                    
input_dtype=self.input_dtype,\n                    output_dtype=result_type(self.output_dtype, other.output_dtype),\n                )\n            raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n        raise TypeError(f\"Operation __add__ not defined between {type(self)} and {type(other)}.\")\n\n    def __sub__(self, other: Operator) -> Operator:\n        if isinstance(other, Operator):\n            if self.shape == other.shape:\n                return Operator(\n                    input_shape=self.input_shape,\n                    output_shape=self.output_shape,\n                    eval_fn=lambda x: self(x) - other(x),\n                    input_dtype=self.input_dtype,\n                    output_dtype=result_type(self.output_dtype, other.output_dtype),\n                )\n            raise ValueError(f\"Shapes {self.shape} and {other.shape} do not match.\")\n        raise TypeError(f\"Operation __sub__ not defined between {type(self)} and {type(other)}.\")\n\n    @_wrap_mul_div_scalar\n    def __mul__(self, other):\n        return Operator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: other * self(x),\n            input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other),\n        )\n\n    def __neg__(self) -> Operator:\n        return -1.0 * self\n\n    @_wrap_mul_div_scalar\n    def __rmul__(self, other):\n        return Operator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: other * self(x),\n            input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other),\n        )\n\n    @_wrap_mul_div_scalar\n    def __truediv__(self, other):\n        return Operator(\n            input_shape=self.input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(x) / other,\n            
input_dtype=self.input_dtype,\n            output_dtype=result_type(self.output_dtype, other),\n        )\n\n    def jvp(self, u, v):\n        r\"\"\"Compute a Jacobian-vector product.\n\n        Compute the product :math:`J_F(\\mb{u}) \\mb{v}` where :math:`F`\n        represents this operator and :math:`J_F(\\mb{u})` is the Jacobian\n        of :math:`F` evaluated at :math:`\\mb{u}`. This method is\n        implemented via a call to :func:`jax.jvp`.\n\n        Args:\n            u: Value at which the Jacobian is evaluated.\n            v: Vector in the Jacobian-vector product.\n\n        Returns:\n           A pair :math:`(F(\\mb{u}), J_F(\\mb{u}) \\mb{v})`, i.e. a pair\n           consisting of the operator evaluated at :math:`\\mb{u}` and the\n           Jacobian-vector product.\n        \"\"\"\n        return jax.jvp(self, (u,), (v,))\n\n    def vjp(self, u, conjugate=True):\n        r\"\"\"Compute a vector-Jacobian product.\n\n        Compute the product :math:`[J_F(\\mb{u})]^T \\mb{v}` where :math:`F`\n        represents this operator and :math:`J_F(\\mb{u})` is the Jacobian\n        of :math:`F` evaluated at :math:`\\mb{u}`. Instead of directly\n        computing the vector-Jacobian product, this method returns a\n        function, taking :math:`\\mb{v}` as an argument, that returns\n        the product. 
This method is implemented via a call to\n        :func:`jax.vjp`.\n\n        Args:\n            u: Value at which the Jacobian is evaluated.\n            conjugate: If ``True``, compute the product using the\n               conjugate (Hermitian) transpose.\n\n        Returns:\n            A pair :math:`(F(\\mb{u}), G(\\cdot))` where :math:`G(\\cdot)`\n            is a function that computes the vector-Jacobian product, i.e.\n            :math:`G(\\mb{v}) = [J_F(\\mb{u})]^T \\mb{v}` when `conjugate`\n            is ``False``, or :math:`G(\\mb{v}) = [J_F(\\mb{u})]^H \\mb{v}`\n            when `conjugate` is ``True``.\n        \"\"\"\n        Fu, G = jax.vjp(self, u)\n\n        if conjugate:\n\n            def Gmap(v):\n                return G(v.conj())[0].conj()\n\n        else:\n\n            def Gmap(v):\n                return G(v)[0]\n\n        return Fu, Gmap\n\n    def freeze(self, argnum: int, val: Union[Array, BlockArray]) -> Operator:\n        \"\"\"Return a new :class:`.Operator` with fixed block argument.\n\n        Return a new :class:`.Operator` with block argument `argnum`\n        fixed to value `val`.\n\n        Args:\n            argnum: Index of block to freeze. 
Must be less than or equal\n               to the number of blocks in an input array.\n            val: Value to fix the `argnum`-th input to.\n\n        Returns:\n            A new :class:`.Operator` with one of the blocks of the input\n            fixed to the specified value.\n\n        Raises:\n            ValueError: If the :class:`.Operator` does not take a\n               :class:`.BlockArray` as its input, if the block index\n               equals or exceeds the number of blocks, or if the shape of\n               the fixed value differs from the shape of the specified\n               block.\n        \"\"\"\n\n        if not is_nested(self.input_shape):\n            raise ValueError(\n                \"The freeze method can only be applied to Operators that take BlockArray inputs.\"\n            )\n\n        input_ndim = len(self.input_shape)\n        if argnum > input_ndim - 1:\n            raise ValueError(\n                f\"Argument 'argnum' must be fewer than the number of input arguments to \"\n                f\"this operator ({input_ndim}); got {argnum}.\"\n            )\n\n        if val.shape != self.input_shape[argnum]:\n            raise ValueError(\n                f\"Value to be frozen at position {argnum} must have shape \"\n                f\"{self.input_shape[argnum]}, got {val.shape}.\"\n            )\n\n        input_shape: Union[Shape, BlockShape]\n        input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum)  # type: ignore\n\n        if len(input_shape) == 1:\n            input_shape = input_shape[0]  # type: ignore\n\n        def concat_args(args):\n            # Create a blockarray with args and the frozen value in the correct place\n            # E.g. 
if this operator takes a blockarray with two blocks, then\n            # concat_args(args) = snp.blockarray([val, args]) if argnum = 0\n            # concat_args(args) = snp.blockarray([args, val]) if argnum = 1\n\n            if isinstance(args, (jnp.ndarray, np.ndarray)):\n                # In the case that the original operator takes a blockarray with two\n                # blocks, wrap in a list so we can use the same indexing as >2 block case\n                args = [args]\n\n            arg_list = []\n            for i in range(input_ndim):\n                if i < argnum:\n                    arg_list.append(args[i])\n                elif i > argnum:\n                    arg_list.append(args[i - 1])\n                else:\n                    arg_list.append(val)\n            return snp.blockarray(arg_list)\n\n        return Operator(\n            input_shape=input_shape,\n            output_shape=self.output_shape,\n            eval_fn=lambda x: self(concat_args(x)),\n        )\n"
  },
  {
    "path": "scico/operator/_stack.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Stack of operators classes.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import is_blockable, is_collapsible, is_nested\nfrom scico.typing import BlockShape, Shape\n\nfrom ._operator import Operator\n\n\ndef collapse_shapes(\n    shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True\n) -> Tuple[Union[Shape, BlockShape], bool]:\n    \"\"\"Compute the collapsed representation of a sequence of shapes.\n\n    Decide whether to collapse a sequence of shapes, returning either\n    the sequence of shapes or a collapsed shape, and a boolean indicating\n    whether the shape was collapsed.\"\"\"\n\n    if is_collapsible(shapes) and allow_collapse:\n        return (len(shapes), *shapes[0]), True\n\n    if is_blockable(shapes):\n        return shapes, False\n\n    raise ValueError(\n        \"Combining these shapes would result in a twice-nested BlockArray, which is not supported.\"\n    )\n\n\nclass VerticalStack(Operator):\n    r\"\"\"A vertical stack of operators.\n\n    Given operators :math:`A_1, A_2, \\dots, A_N`, create the operator\n    :math:`H` such that\n\n    .. 
math::\n       H(\\mb{x})\n       =\n       \\begin{pmatrix}\n            A_1(\\mb{x}) \\\\\n            A_2(\\mb{x}) \\\\\n            \\vdots \\\\\n            A_N(\\mb{x}) \\\\\n       \\end{pmatrix} \\;.\n    \"\"\"\n\n    def __init__(\n        self,\n        ops: Sequence[Operator],\n        collapse_output: Optional[bool] = True,\n        jit: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            ops: Operators to stack.\n            collapse_output: If ``True`` and the output would be a\n                :class:`BlockArray` with shape ((m, n, ...), (m, n, ...),\n                ...), the output is instead a :class:`jax.Array` with\n                shape (S, m, n, ...) where S is the length of `ops`.\n            jit: See `jit` in :class:`Operator`.\n        \"\"\"\n        VerticalStack.check_if_stackable(ops)\n\n        self.ops = ops\n        self.collapse_output = collapse_output\n\n        output_shapes = tuple(op.output_shape for op in ops)\n        self.output_collapsible = is_collapsible(output_shapes)\n\n        if self.output_collapsible and self.collapse_output:\n            output_shape = (len(ops),) + output_shapes[0]  # collapse to jax array\n        else:\n            output_shape = output_shapes\n\n        super().__init__(\n            input_shape=ops[0].input_shape,\n            output_shape=output_shape,  # type: ignore\n            input_dtype=ops[0].input_dtype,\n            output_dtype=ops[0].output_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    @staticmethod\n    def check_if_stackable(ops: Sequence[Operator]):\n        \"\"\"Check that input ops are suitable for stack creation.\"\"\"\n        if not isinstance(ops, (list, tuple)):\n            raise TypeError(\"Expected a list of Operator.\")\n\n        input_shapes = [op.shape[1] for op in ops]\n        if not all(input_shapes[0] == s for s in input_shapes):\n            raise ValueError(\n                \"Expected all Operators 
to have the same input shapes, \" f\"but got {input_shapes}.\"\n            )\n\n        input_dtypes = [op.input_dtype for op in ops]\n        if not all(input_dtypes[0] == s for s in input_dtypes):\n            raise ValueError(\n                \"Expected all Operators to have the same input dtype, \" f\"but got {input_dtypes}.\"\n            )\n\n        if any([is_nested(op.shape[0]) for op in ops]):\n            raise ValueError(\"Cannot stack Operators with nested output shapes.\")\n\n        output_dtypes = [op.output_dtype for op in ops]\n        if not np.all(output_dtypes[0] == s for s in output_dtypes):\n            raise ValueError(\"Expected all Operators to have the same output dtype.\")\n\n    def _eval(self, x: Array) -> Union[Array, BlockArray]:\n        if self.output_collapsible and self.collapse_output:\n            return snp.stack([op(x) for op in self.ops])\n        return BlockArray([op(x) for op in self.ops])\n\n    def __repr__(self):\n        crepr = \", \".join([str(f) for f in self.ops])\n        return Operator.__repr__(self) + f\"\"\"  components: {crepr}\\n\"\"\"\n\n\nclass DiagonalStack(Operator):\n    r\"\"\"A diagonal stack of operators.\n\n    Given operators :math:`A_1, A_2, \\dots, A_N`, create the operator\n    :math:`H` such that\n\n    .. math::\n       H \\left(\n       \\begin{pmatrix}\n            \\mb{x}_1 \\\\\n            \\mb{x}_2 \\\\\n            \\vdots \\\\\n            \\mb{x}_N \\\\\n       \\end{pmatrix} \\right)\n       =\n       \\begin{pmatrix}\n            A_1(\\mb{x}_1) \\\\\n            A_2(\\mb{x}_2) \\\\\n            \\vdots \\\\\n            A_N(\\mb{x}_N) \\\\\n       \\end{pmatrix} \\;.\n\n    By default, if the inputs :math:`\\mb{x}_1, \\mb{x}_2, \\dots,\n    \\mb{x}_N` all have the same (possibly nested) shape, `S`, this\n    operator will work on the stack, i.e., have an input shape of `(N,\n    *S)`. 
If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`,\n    this operator will work on the block concatenation, i.e.,\n    have an input shape of `(S1, S2, ..., SN)`. The same holds for the\n    output shape.\n    \"\"\"\n\n    def __init__(\n        self,\n        ops: Sequence[Operator],\n        collapse_input: Optional[bool] = True,\n        collapse_output: Optional[bool] = True,\n        jit: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            ops: Operators to stack.\n            collapse_input: If ``True``, inputs are expected to be\n                stacked along the first dimension when possible.\n            collapse_output: If ``True``, the output will be\n                stacked along the first dimension when possible.\n            jit: See `jit` in :class:`Operator`.\n\n        \"\"\"\n        DiagonalStack.check_if_stackable(ops)\n\n        self.ops = ops\n\n        input_shape, self.collapse_input = collapse_shapes(\n            tuple(op.input_shape for op in ops),\n            collapse_input,\n        )\n        output_shape, self.collapse_output = collapse_shapes(\n            tuple(op.output_shape for op in ops),\n            collapse_output,\n        )\n\n        super().__init__(\n            input_shape=input_shape,\n            output_shape=output_shape,\n            input_dtype=ops[0].input_dtype,\n            output_dtype=ops[0].output_dtype,\n            jit=jit,\n            **kwargs,\n        )\n\n    @staticmethod\n    def check_if_stackable(ops: Sequence[Operator]):\n        \"\"\"Check that input ops are suitable for stack creation.\"\"\"\n        if not isinstance(ops, (list, tuple)):\n            raise TypeError(\"Expected a list of Operator.\")\n\n        if any([is_nested(op.shape[0]) for op in ops]):\n            raise ValueError(\"Cannot stack Operators with nested output shapes.\")\n\n        output_dtypes = [op.output_dtype for op in ops]\n        if not np.all(output_dtypes[0] == s for s in 
output_dtypes):\n            raise ValueError(\"Expected all Operators to have the same output dtype.\")\n\n    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        result = tuple(op(x_n) for op, x_n in zip(self.ops, x))\n        if self.collapse_output:\n            return snp.stack(result)\n        return snp.blockarray(result)\n\n    def __repr__(self):\n        crepr = \", \".join([str(f) for f in self.ops])\n        return Operator.__repr__(self) + f\"\"\"  components: {crepr}\\n\"\"\"\n\n\nclass DiagonalReplicated(Operator):\n    r\"\"\"A diagonal stack constructed from a single operator.\n\n    Given operator :math:`A`, create the operator :math:`H` such that\n\n    .. math::\n       H \\left(\n       \\begin{pmatrix}\n            \\mb{x}_1 \\\\\n            \\mb{x}_2 \\\\\n            \\vdots \\\\\n            \\mb{x}_N \\\\\n       \\end{pmatrix} \\right)\n       =\n       \\begin{pmatrix}\n            A(\\mb{x}_1) \\\\\n            A(\\mb{x}_2) \\\\\n            \\vdots \\\\\n            A(\\mb{x}_N) \\\\\n       \\end{pmatrix} \\;.\n\n    The application of :math:`A` to each component :math:`\\mb{x}_k` is\n    computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape\n    for operator :math:`A` should exclude the array axis on which\n    :math:`A` is replicated to form :math:`H`. 
For example, if :math:`A`\n    has input shape `(3, 4)` and :math:`H` is constructed to replicate\n    on axis 0 with 2 replicates, the input shape of :math:`H` will be\n    `(2, 3, 4)`.\n\n    Operators taking :class:`.BlockArray` input are not supported.\n    \"\"\"\n\n    def __init__(\n        self,\n        op: Operator,\n        replicates: int,\n        input_axis: int = 0,\n        output_axis: Optional[int] = None,\n        map_type: str = \"auto\",\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            op: Operator to replicate.\n            replicates: Number of replicates of `op`.\n            input_axis: Input axis over which `op` should be replicated.\n            output_axis: Index of replication axis in output array.\n               If ``None``, the input replication axis is used.\n            map_type: If \"pmap\" or \"vmap\", apply replicated mapping using\n               :func:`jax.pmap` or :func:`jax.vmap` respectively. If\n               \"auto\", use :func:`jax.pmap` if sufficient devices are\n               available for the number of replicates, otherwise use\n               :func:`jax.vmap`.\n        \"\"\"\n        if map_type not in [\"auto\", \"pmap\", \"vmap\"]:\n            raise ValueError(\"Argument 'map_type' must be one of 'auto', 'pmap, or 'vmap'.\")\n        if input_axis < 0:\n            input_axis = len(op.input_shape) + 1 + input_axis\n        if input_axis < 0 or input_axis > len(op.input_shape):\n            raise ValueError(\n                \"Argument 'input_axis' must be positive and less than the number of axes \"\n                \"in the input shape of argument 'op'.\"\n            )\n        if is_nested(op.input_shape):\n            raise ValueError(\"Argument 'op' may not be an Operator taking BlockArray input.\")\n        if is_nested(op.output_shape):\n            raise ValueError(\"Argument 'op' may not be an Operator with BlockArray output.\")\n        self.op = op\n        self.replicates = 
replicates\n        self.input_axis = input_axis\n        self.output_axis = self.input_axis if output_axis is None else output_axis\n\n        if map_type == \"auto\":\n            self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap\n        else:\n            if map_type == \"pmap\" and replicates > jax.device_count():\n                raise ValueError(\n                    \"Requested pmap mapping but number of replicates exceeds device count.\"\n                )\n            else:\n                self.jaxmap = jax.pmap if map_type == \"pmap\" else jax.vmap\n\n        eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis)\n\n        input_shape = (\n            op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :]\n        )\n        output_shape = (\n            op.output_shape[0 : self.output_axis]\n            + (replicates,)\n            + op.output_shape[self.output_axis :]\n        )\n\n        super().__init__(\n            input_shape=input_shape,  # type: ignore\n            output_shape=output_shape,  # type: ignore\n            eval_fn=eval_fn,\n            input_dtype=op.input_dtype,\n            output_dtype=op.output_dtype,\n            jit=False,\n            **kwargs,\n        )\n\n    def __repr__(self):\n        return (\n            Operator.__repr__(self)\n            + f\"\"\"  component:  {str(self.op)}\\n  replicates: {self.replicates}\\n\"\"\"\n        )\n"
  },
  {
    "path": "scico/operator/biconvolve.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Biconvolution operator.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Tuple, cast\n\nimport numpy as np\n\nfrom jax.scipy.signal import convolve\n\nimport scico.linop\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import is_nested\nfrom scico.typing import DType, Shape\n\nfrom ._operator import Operator\n\n\nclass BiConvolve(Operator):\n    \"\"\"Biconvolution operator.\n\n    A :class:`.BiConvolve` operator accepts a :class:`.BlockArray` input\n    with two blocks of equal ndims, and convolves the first block with\n    the second.\n\n    If `A` is a :class:`.BiConvolve` operator, then\n    `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        input_shape: Tuple[Shape, Shape],\n        input_dtype: DType = np.float32,\n        mode: str = \"full\",\n        jit: bool = True,\n    ):\n        r\"\"\"\n        Args:\n            input_shape: Shape of input :class:`.BlockArray`. Must\n                correspond to a :class:`.BlockArray` with two blocks of\n                equal ndims.\n            input_dtype: `dtype` for input argument. Defaults to\n                :attr:`~numpy.float32`.\n            mode: A string indicating the size of the output. One of\n                \"full\", \"valid\", \"same\". 
Defaults to \"full\".\n            jit: If ``True``, jit the evaluation of this\n                :class:`.Operator`.\n\n        For more details on `mode`, see :func:`jax.scipy.signal.convolve`.\n        \"\"\"\n\n        if not is_nested(input_shape):\n            raise ValueError(\"A BlockShape is expected; got {input_shape}.\")\n        if len(input_shape) != 2:\n            raise ValueError(\n                f\"Argument 'input_shape' must have two blocks; got {len(input_shape)}.\"\n            )\n        if len(input_shape[0]) != len(input_shape[1]):\n            raise ValueError(\n                f\"Both input blocks must have same number of dimensions; got \"\n                f\"{len(input_shape[0]), len(input_shape[1])}.\"\n            )\n\n        if mode not in [\"full\", \"valid\", \"same\"]:\n            raise ValueError(f\"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.\")\n\n        self.mode = mode\n\n        super().__init__(input_shape=input_shape, input_dtype=input_dtype, jit=jit)\n\n    def _eval(self, x: BlockArray) -> Array:\n        return convolve(x[0], x[1], mode=self.mode)\n\n    def freeze(self, argnum: int, val: Array) -> scico.linop.LinearOperator:\n        \"\"\"Freeze the `argnum` parameter.\n\n        Return a new :class:`.LinearOperator` with block argument\n        `argnum` fixed to value `val`.\n\n        If `argnum == 0`, a :class:`.ConvolveByX` object is returned.\n        If `argnum == 1`, a :class:`.Convolve` object is returned.\n\n        Args:\n            argnum: Index of block to freeze. Must be 0 or 1.\n            val: Value to fix the `argnum`-th input to.\n        \"\"\"\n\n        if argnum == 0:\n            return scico.linop.ConvolveByX(\n                x=val,\n                input_shape=cast(Shape, self.input_shape[1]),\n                input_dtype=self.input_dtype,\n                output_shape=self.output_shape,\n                mode=self.mode,\n            )\n        if argnum == 1:\n            return scico.linop.Convolve(\n                h=val,\n 
               input_shape=cast(Shape, self.input_shape[0]),\n                input_dtype=self.input_dtype,\n                output_shape=self.output_shape,\n                mode=self.mode,\n            )\n        raise ValueError(f\"Argument 'argnum' must be 0 or 1; got {argnum}.\")\n"
  },
  {
    "path": "scico/optimize/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Optimization algorithms.\"\"\"\n\nimport sys\n\n# isort: off\nfrom .admm import ADMM\nfrom ._common import Optimizer\nfrom ._ladmm import LinearizedADMM\nfrom .pgm import PGM, AcceleratedPGM\nfrom ._primaldual import PDHG\nfrom ._padmm import ProximalADMM, NonLinearPADMM, ProximalADMMBase\n\n__all__ = [\n    \"ADMM\",\n    \"LinearizedADMM\",\n    \"ProximalADMM\",\n    \"ProximalADMMBase\",\n    \"NonLinearPADMM\",\n    \"PGM\",\n    \"AcceleratedPGM\",\n    \"PDHG\",\n    \"Optimizer\",\n]\n\n# Make imported items in __all__ appear to originate in the top-level optimize module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/optimize/_admm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"ADMM solver.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import List, Optional, Tuple, Union\n\nimport scico.numpy as snp\nfrom scico.functional import Functional\nfrom scico.linop import LinearOperator\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.linalg import norm\n\nfrom ._admmaux import (\n    FBlockCircularConvolveSolver,\n    G0BlockCircularConvolveSolver,\n    GenericSubproblemSolver,\n    LinearSubproblemSolver,\n    MatrixSubproblemSolver,\n    SubproblemSolver,\n)\nfrom ._common import Optimizer\n\n\nclass ADMM(Optimizer):\n    r\"\"\"Basic Alternating Direction Method of Multipliers (ADMM) algorithm.\n\n    |\n\n    Solve an optimization problem of the form\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_{i=1}^N g_i(C_i \\mb{x}) \\;,\n\n    where :math:`f` and the :math:`g_i` are instances of\n    :class:`.Functional`, and the :math:`C_i` are\n    :class:`.LinearOperator`.\n\n    The optimization problem is solved by introducing the splitting\n    :math:`\\mb{z}_i = C_i \\mb{x}` and solving\n\n    .. math::\n        \\argmin_{\\mb{x}, \\mb{z}_i} \\; f(\\mb{x}) + \\sum_{i=1}^N\n        g_i(\\mb{z}_i) \\; \\text{such that}\\; C_i \\mb{x} = \\mb{z}_i \\;,\n\n    via an ADMM algorithm :cite:`glowinski-1975-approximation`\n    :cite:`gabay-1976-dual` :cite:`boyd-2010-distributed` consisting of\n    the iterations (see :meth:`step`)\n\n    .. 
math::\n       \\begin{aligned}\n       \\mb{x}^{(k+1)} &= \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_i\n       \\frac{\\rho_i}{2} \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i - C_i\n       \\mb{x}}_2^2 \\\\\n       \\mb{z}_i^{(k+1)} &= \\argmin_{\\mb{z}_i} \\; g_i(\\mb{z}_i) +\n       \\frac{\\rho_i}{2}\n       \\norm{\\mb{z}_i - \\mb{u}^{(k)}_i - C_i \\mb{x}^{(k+1)}}_2^2  \\\\\n       \\mb{u}_i^{(k+1)} &=  \\mb{u}_i^{(k)} + C_i \\mb{x}^{(k+1)} -\n       \\mb{z}^{(k+1)}_i  \\; .\n       \\end{aligned}\n\n\n    Attributes:\n        f (:class:`.Functional`): Functional :math:`f` (usually a\n            :class:`.Loss`)\n        g_list (list of :class:`.Functional`): List of :math:`g_i`\n            functionals. Must be same length as :code:`C_list` and\n            :code:`rho_list`.\n        C_list (list of :class:`.LinearOperator`): List of :math:`C_i`\n            operators.\n        rho_list (list of scalars): List of :math:`\\rho_i` penalty\n            parameters. Must be same length as :code:`C_list` and\n            :code:`g_list`.\n        alpha (float): Relaxation parameter.\n        u_list (list of array-like): List of scaled Lagrange multipliers\n            :math:`\\mb{u}_i` at current iteration.\n        x (array-like): Solution.\n        subproblem_solver (:class:`.SubproblemSolver`): Solver for\n            :math:`\\mb{x}`-update step.\n        z_list (list of array-like): List of auxiliary variables\n            :math:`\\mb{z}_i` at current iteration.\n        z_list_old (list of array-like): List of auxiliary variables\n            :math:`\\mb{z}_i` at previous iteration.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g_list: List[Functional],\n        C_list: List[LinearOperator],\n        rho_list: List[float],\n        alpha: float = 1.0,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        subproblem_solver: Optional[SubproblemSolver] = None,\n        **kwargs,\n    ):\n        r\"\"\"Initialize an 
:class:`ADMM` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g_list: List of :math:`g_i` functionals. Must be same length\n                 as :code:`C_list` and :code:`rho_list`.\n            C_list: List of :math:`C_i` operators.\n            rho_list: List of :math:`\\rho_i` penalty parameters.\n                Must be same length as :code:`C_list` and :code:`g_list`.\n            alpha: Relaxation parameter. No relaxation for default 1.0.\n            x0: Initial value for :math:`\\mb{x}`. If ``None``, defaults\n                to an array of zeros.\n            subproblem_solver: Solver for :math:`\\mb{x}`-update step.\n                Defaults to ``None``, which implies use of an instance of\n                :class:`GenericSubproblemSolver`.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        N = len(g_list)\n        if len(C_list) != N:\n            raise ValueError(f\"len(C_list)={len(C_list)} not equal to len(g_list)={N}.\")\n        if len(rho_list) != N:\n            raise ValueError(f\"len(rho_list)={len(rho_list)} not equal to len(g_list)={N}.\")\n\n        self.f: Functional = f\n        self.g_list: List[Functional] = g_list\n        self.C_list: List[LinearOperator] = C_list\n        self.rho_list: List[float] = rho_list\n        self.alpha: float = alpha\n\n        if subproblem_solver is None:\n            subproblem_solver = GenericSubproblemSolver()\n        self.subproblem_solver: SubproblemSolver = subproblem_solver\n        self.subproblem_solver.internal_init(self)\n\n        if x0 is None:\n            input_shape = C_list[0].input_shape\n            dtype = C_list[0].input_dtype\n            x0 = snp.zeros(input_shape, dtype=dtype)\n        self.x = x0\n        self.z_list, self.z_list_old = self.z_init(self.x)\n        self.u_list = self.u_init(self.x)\n\n        
super().__init__(**kwargs)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine where ``NaN`` of ``Inf`` encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        for v in (\n            [\n                self.x,\n            ]\n            + self.z_list\n            + self.u_list\n        ):\n            if not snp.all(snp.isfinite(v)):\n                return False\n        return True\n\n    def _objective_evaluatable(self):\n        \"\"\"Determine whether the objective function can be evaluated.\"\"\"\n        return (not self.f or self.f.has_eval) and all([_.has_eval for _ in self.g_list])\n\n    def _itstat_extra_fields(self):\n        \"\"\"Define ADMM-specific iteration statistics fields.\"\"\"\n        itstat_fields = {\"Prml Rsdl\": \"%9.3e\", \"Dual Rsdl\": \"%9.3e\"}\n        itstat_attrib = [\"norm_primal_residual()\", \"norm_dual_residual()\"]\n\n        # subproblem solver info when available\n        if isinstance(self.subproblem_solver, GenericSubproblemSolver):\n            itstat_fields.update({\"Num FEv\": \"%6d\", \"Num It\": \"%6d\"})\n            itstat_attrib.extend(\n                [\"subproblem_solver.info['nfev']\", \"subproblem_solver.info['nit']\"]\n            )\n        elif (\n            type(self.subproblem_solver) == LinearSubproblemSolver\n            and self.subproblem_solver.cg_function == \"scico\"\n        ):\n            itstat_fields.update({\"CG It\": \"%5d\", \"CG Res\": \"%9.3e\"})\n            itstat_attrib.extend(\n                [\"subproblem_solver.info['num_iter']\", \"subproblem_solver.info['rel_res']\"]\n            )\n        elif (\n            type(self.subproblem_solver)\n            in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]\n            and self.subproblem_solver.check_solve\n        ):\n            itstat_fields.update({\"Slv Res\": 
\"%9.3e\"})\n            itstat_attrib.extend([\"subproblem_solver.accuracy\"])\n\n        return itstat_fields, itstat_attrib\n\n    def _state_variable_names(self) -> List[str]:\n        # While x is in the most abstract sense not part of the algorithm\n        # state, it does form part of the state in pratice due to its use\n        # as an initializer for iterative solvers for the x step of the\n        # ADMM algorithm.\n        return [\"x\", \"z_list\", \"z_list_old\", \"u_list\"]\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        return self.x\n\n    def objective(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z_list: Optional[List[Union[Array, BlockArray]]] = None,\n    ) -> float:\n        r\"\"\"Evaluate the objective function.\n\n        Evaluate the objective function\n\n        .. math::\n            f(\\mb{x}) + \\sum_{i=1}^N g_i(\\mb{z}_i) \\;.\n\n        Note that this form is cheaper to compute, but may have very poor\n        accuracy compared with the \"true\" objective function\n\n        .. math::\n            f(\\mb{x}) + \\sum_{i=1}^N g_i(C_i \\mb{x}) \\;.\n\n        when the primal residual is large.\n\n        Args:\n            x: Point at which to evaluate objective function. If ``None``,\n                the objective is  evaluated at the current iterate\n                :code:`self.x`.\n            z_list: Point at which to evaluate objective function. 
If\n                ``None``, the objective is evaluated at the current iterate\n                :code:`self.z_list`.\n\n        Returns:\n            Value of the objective function.\n        \"\"\"\n        if (x is None) != (z_list is None):\n            raise ValueError(\"Both or neither of arguments 'x' and 'z_list' must be supplied.\")\n        if x is None:\n            x = self.x\n            z_list = self.z_list\n        assert z_list is not None\n        out = 0.0\n        if self.f:\n            out += self.f(x)\n        for g, z in zip(self.g_list, z_list):\n            out += g(z)\n        return out\n\n    def norm_primal_residual(self, x: Optional[Union[Array, BlockArray]] = None) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the primal residual.\n\n        Compute the :math:`\\ell_2` norm of the primal residual\n\n        .. math::\n            \\left( \\sum_{i=1}^N \\rho_i \\left\\| C_i \\mb{x} -\n            \\mb{z}_i^{(k)} \\right\\|_2^2\\right)^{1/2} \\;.\n\n        Args:\n            x: Point at which to evaluate primal residual. If ``None``,\n                the primal residual is evaluated at the current iterate\n                :code:`self.x`.\n\n        Returns:\n            Norm of primal residual.\n        \"\"\"\n        if x is None:\n            x = self.x\n\n        sum = 0.0\n        for rhoi, Ci, zi in zip(self.rho_list, self.C_list, self.z_list):\n            sum += rhoi * norm(Ci(self.x) - zi) ** 2\n        return snp.sqrt(sum)\n\n    def norm_dual_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the dual residual.\n\n        Compute the :math:`\\ell_2` norm of the dual residual\n\n        .. 
math::\n            \\left\\| \\sum_{i=1}^N \\rho_i C_i^T \\left( \\mb{z}^{(k)}_i -\n            \\mb{z}^{(k-1)}_i \\right) \\right\\|_2 \\;.\n\n        Returns:\n            Norm of dual residual.\n\n        \"\"\"\n        sum = 0.0\n        for rhoi, zi, ziold, Ci in zip(self.rho_list, self.z_list, self.z_list_old, self.C_list):\n            sum += rhoi * Ci.adj(zi - ziold)\n        return norm(sum)\n\n    def z_init(\n        self, x0: Union[Array, BlockArray]\n    ) -> Tuple[List[Union[Array, BlockArray]], List[Union[Array, BlockArray]]]:\n        r\"\"\"Initialize auxiliary variables :math:`\\mb{z}_i`.\n\n        Initialized to\n\n        .. math::\n            \\mb{z}_i = C_i \\mb{x}^{(0)} \\;.\n\n        :code:`z_list` and :code:`z_list_old` are initialized to the same\n        value.\n\n        Args:\n            x0: Initial value of :math:`\\mb{x}`.\n        \"\"\"\n        z_list: List[Union[Array, BlockArray]] = [Ci(x0) for Ci in self.C_list]\n        z_list_old = z_list.copy()\n        return z_list, z_list_old\n\n    def u_init(self, x0: Union[Array, BlockArray]) -> List[Union[Array, BlockArray]]:\n        r\"\"\"Initialize scaled Lagrange multipliers :math:`\\mb{u}_i`.\n\n        Initialized to\n\n        .. math::\n            \\mb{u}_i = \\mb{0} \\;.\n\n        Note that the parameter `x0` is unused, but is provided for\n        potential use in an overridden method.\n\n        Args:\n            x0: Initial value of :math:`\\mb{x}`.\n        \"\"\"\n        u_list = [snp.zeros(Ci.output_shape, dtype=Ci.output_dtype) for Ci in self.C_list]\n        return u_list\n\n    def step(self):\n        r\"\"\"Perform a single ADMM iteration.\n\n        The primary variable :math:`\\mb{x}` is updated by solving the\n        optimization problem\n\n        .. 
math::\n            \\mb{x}^{(k+1)} = \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_i\n            \\frac{\\rho_i}{2} \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i -\n            C_i \\mb{x}}_2^2 \\;.\n\n        Update auxiliary variables :math:`\\mb{z}_i` and scaled Lagrange\n        multipliers :math:`\\mb{u}_i`. The auxiliary variables are updated\n        according to\n\n        .. math::\n            \\begin{aligned}\n            \\mb{z}_i^{(k+1)} &= \\argmin_{\\mb{z}_i} \\; g_i(\\mb{z}_i) +\n            \\frac{\\rho_i}{2} \\norm{\\mb{z}_i - \\mb{u}^{(k)}_i - C_i\n            \\mb{x}^{(k+1)}}_2^2  \\\\\n            &= \\mathrm{prox}_{g_i}(C_i \\mb{x} + \\mb{u}_i, 1 / \\rho_i) \\;,\n            \\end{aligned}\n\n        and the scaled Lagrange multipliers are updated according to\n\n        .. math::\n            \\mb{u}_i^{(k+1)} =  \\mb{u}_i^{(k)} + C_i \\mb{x}^{(k+1)} -\n            \\mb{z}^{(k+1)}_i \\;.\n        \"\"\"\n\n        self.x = self.subproblem_solver.solve(self.x)\n\n        self.z_list_old = self.z_list.copy()\n\n        for i, (rhoi, gi, Ci, zi, ui) in enumerate(\n            zip(self.rho_list, self.g_list, self.C_list, self.z_list, self.u_list)\n        ):\n            if self.alpha == 1.0:\n                Cix = Ci(self.x)\n            else:\n                Cix = self.alpha * Ci(self.x) + (1.0 - self.alpha) * zi\n            zi = gi.prox(Cix + ui, 1 / rhoi, v0=zi)\n            ui = ui + Cix - zi\n            self.z_list[i] = zi\n            self.u_list[i] = ui\n"
  },
  {
    "path": "scico/optimize/_admmaux.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"ADMM auxiliary classes.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom functools import reduce\nfrom typing import Any, Optional, Union\n\nimport jax\nfrom jax.scipy.sparse.linalg import cg as jax_cg\n\nimport scico.numpy as snp\nimport scico.optimize.admm as soa\nfrom scico.functional import ZeroFunctional\nfrom scico.linop import (\n    CircularConvolve,\n    ComposedLinearOperator,\n    Diagonal,\n    Identity,\n    LinearOperator,\n    MatrixOperator,\n)\nfrom scico.loss import SquaredL2Loss\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import is_real_dtype\nfrom scico.solver import ConvATADSolver, MatrixATADSolver\nfrom scico.solver import cg as scico_cg\nfrom scico.solver import minimize\n\n\nclass SubproblemSolver:\n    r\"\"\"Base class for solvers for the non-separable ADMM step.\n\n    The ADMM solver implemented by :class:`.ADMM` addresses a general\n    problem form for which one of the corresponding ADMM algorithm\n    subproblems is separable into distinct subproblems for each of the\n    :math:`g_i`, and another that is non-separable, involving function\n    :math:`f` and a sum over :math:`\\ell_2` norm terms involving all\n    operators :math:`C_i`. This class is a base class for solvers of\n    the latter subproblem\n\n    ..  
math::\n\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + \\sum_i \\frac{\\rho_i}{2}\n        \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i - C_i \\mb{x}}_2^2 \\;.\n\n    Attributes:\n        admm (:class:`.ADMM`): ADMM solver object to which the\n            solver is attached.\n    \"\"\"\n\n    def internal_init(self, admm: soa.ADMM):\n        \"\"\"Second stage initializer to be called by :meth:`.ADMM.__init__`.\n\n        Args:\n            admm: Reference to :class:`.ADMM` object to which the\n               :class:`.SubproblemSolver` object is to be attached.\n        \"\"\"\n        self.admm = admm\n\n\nclass GenericSubproblemSolver(SubproblemSolver):\n    \"\"\"Solver for generic problem without special structure.\n\n    Note that this solver is only suitable for small-scale problems.\n\n    Attributes:\n        admm (:class:`.ADMM`): ADMM solver object to which the solver is\n           attached.\n        minimize_kwargs (dict): Dictionary of arguments for\n           :func:`scico.solver.minimize`.\n    \"\"\"\n\n    def __init__(self, minimize_kwargs: dict = {\"options\": {\"maxiter\": 100}}):\n        \"\"\"Initialize a :class:`GenericSubproblemSolver` object.\n\n        Args:\n            minimize_kwargs: Dictionary of arguments for\n                :func:`scico.solver.minimize`.\n        \"\"\"\n        self.minimize_kwargs = minimize_kwargs\n        self.info: dict = {}\n\n    def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Solve the ADMM step.\n\n        Args:\n           x0: Initial value.\n\n        Returns:\n            Computed solution.\n        \"\"\"\n\n        @jax.jit\n        def obj(x):\n            out = 0.0\n            for rhoi, Ci, zi, ui in zip(\n                self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list\n            ):\n                out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2)\n            if self.admm.f is not None:\n                out += 
self.admm.f(x)\n            return out\n\n        res = minimize(obj, x0, **self.minimize_kwargs)\n        for attrib in (\"success\", \"status\", \"message\", \"nfev\", \"njev\", \"nhev\", \"nit\", \"maxcv\"):\n            self.info[attrib] = getattr(res, attrib, None)\n\n        return res.x\n\n\nclass LinearSubproblemSolver(SubproblemSolver):\n    r\"\"\"Solver for quadratic functionals.\n\n    Solver for the case in which :code:`f` is a quadratic function of\n    :math:`\\mb{x}`. It is a specialization of :class:`.SubproblemSolver`\n    for the case where :code:`f` is an :math:`\\ell_2` or weighted\n    :math:`\\ell_2` norm, and :code:`f.A` is a linear operator, so that\n    the subproblem involves solving a linear equation. This requires that\n    :code:`f` be an instance of :class:`.SquaredL2Loss` and that the\n    forward operator :code:`f.A` be an instance of\n    :class:`.LinearOperator`.\n\n    The :math:`\\mb{x}`-update step is\n\n    ..  math::\n\n        \\mb{x}^{(k+1)} = \\argmin_{\\mb{x}} \\; \\frac{1}{2}\n        \\norm{\\mb{y} - A \\mb{x}}_W^2 + \\sum_i \\frac{\\rho_i}{2}\n        \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i - C_i \\mb{x}}_2^2 \\;,\n\n    where :math:`W` is a weighting :class:`.Diagonal` operator\n    or an :class:`.Identity` operator (i.e., no weighting).\n    This update step reduces to the solution of the linear system\n\n    ..  
math::\n\n        \\left(A^H W A + \\sum_{i=1}^N \\rho_i C_i^H C_i \\right)\n        \\mb{x}^{(k+1)} = \\;\n        A^H W \\mb{y} + \\sum_{i=1}^N \\rho_i C_i^H ( \\mb{z}^{(k)}_i -\n        \\mb{u}^{(k)}_i) \\;.\n\n\n    Attributes:\n        admm (:class:`.ADMM`): ADMM solver object to which the solver is\n            attached.\n        cg_kwargs (dict): Dictionary of arguments for CG solver.\n        cg (func): CG solver function (:func:`scico.solver.cg` or\n            :func:`jax.scipy.sparse.linalg.cg`).\n        lhs_op: Operator implementing the left hand side of the linear\n            system solved in the :math:`\\mb{x}` update step.\n    \"\"\"\n\n    def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str = \"scico\"):\n        \"\"\"Initialize a :class:`LinearSubproblemSolver` object.\n\n        Args:\n            cg_kwargs: Dictionary of arguments for CG solver. See\n                documentation for :func:`scico.solver.cg` or\n                :func:`jax.scipy.sparse.linalg.cg`,\n                including how to specify a preconditioner.\n                Default values are the same as those of\n                :func:`scico.solver.cg`, except for\n                `\"tol\": 1e-4` and `\"maxiter\": 100`.\n            cg_function: String indicating which CG implementation to\n                use. One of \"jax\" or \"scico\"; default \"scico\". If\n                \"scico\", uses :func:`scico.solver.cg`. If \"jax\", uses\n                :func:`jax.scipy.sparse.linalg.cg`. 
The \"jax\" option is\n                slower on small-scale problems or problems involving\n                external functions, but can be differentiated through.\n                The \"scico\" option is faster on small-scale problems, but\n                slower on large-scale problems where the forward\n                operator is written entirely in jax.\n        \"\"\"\n\n        default_cg_kwargs = {\"tol\": 1e-4, \"maxiter\": 100}\n        if cg_kwargs:\n            default_cg_kwargs.update(cg_kwargs)\n        self.cg_kwargs = default_cg_kwargs\n        self.cg_function = cg_function\n        if cg_function == \"scico\":\n            self.cg = scico_cg\n        elif cg_function == \"jax\":\n            self.cg = jax_cg\n        else:\n            raise ValueError(\n                f\"Argument 'cg_function' must be one of 'jax', 'scico'; got {cg_function}.\"\n            )\n        self.info = None\n\n    def internal_init(self, admm: soa.ADMM):\n        if admm.f is not None:\n            if not isinstance(admm.f, SquaredL2Loss):\n                raise TypeError(\n                    \"LinearSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; \"\n                    f\"got {type(admm.f)}.\"\n                )\n            if not isinstance(admm.f.A, LinearOperator):\n                raise TypeError(\n                    \"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; \"\n                    f\"got {type(admm.f.A)}.\"\n                )\n\n        super().internal_init(admm)  # call method of SubproblemSolver via GenericSubproblemSolver\n\n        # Set lhs_op =  \\sum_i rho_i * Ci.H @ Ci\n        # Use reduce as the initialization of this sum is messy otherwise\n        lhs_op = reduce(\n            lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]\n        )\n        if admm.f is not None:\n            # hessian = A.T @ W @ A; W may be identity\n            lhs_op += 
admm.f.hessian\n\n        self.lhs_op = lhs_op\n\n    def compute_rhs(self) -> Union[Array, BlockArray]:\n        r\"\"\"Compute the right hand side of the linear equation to be solved.\n\n        Compute\n\n        .. math::\n\n            A^H W \\mb{y} + \\sum_{i=1}^N \\rho_i C_i^H ( \\mb{z}^{(k)}_i -\n            \\mb{u}^{(k)}_i) \\;.\n\n        Returns:\n            Computed solution.\n        \"\"\"\n\n        C0 = self.admm.C_list[0]\n        rhs = snp.zeros(C0.input_shape, C0.input_dtype)\n\n        if self.admm.f is not None:\n            ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y)  # type: ignore\n            rhs += 2.0 * self.admm.f.scale * ATWy  # type: ignore\n\n        for rhoi, Ci, zi, ui in zip(\n            self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list\n        ):\n            rhs += rhoi * Ci.adj(zi - ui)\n        return rhs\n\n    def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Solve the ADMM step.\n\n        Args:\n            x0: Initial value.\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        rhs = self.compute_rhs()\n        x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs)  # type: ignore\n        return x\n\n\nclass MatrixSubproblemSolver(LinearSubproblemSolver):\n    r\"\"\"Solver for quadratic functionals involving matrix operators.\n\n    Solver for the case in which :math:`f` is a quadratic function of\n    :math:`\\mb{x}`, and :math:`A` and all the :math:`C_i` are diagonal\n    or matrix operators. It is a specialization of\n    :class:`.LinearSubproblemSolver`.\n\n    As for :class:`.LinearSubproblemSolver`, the :math:`\\mb{x}`-update\n    step is\n\n    ..  
math::\n\n        \\mb{x}^{(k+1)} = \\argmin_{\\mb{x}} \\; \\frac{1}{2}\n        \\norm{\\mb{y} - A \\mb{x}}_W^2 + \\sum_i \\frac{\\rho_i}{2}\n        \\norm{\\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i - C_i \\mb{x}}_2^2 \\;,\n\n    where :math:`W` is a weighting :class:`.Diagonal` operator\n    or an :class:`.Identity` operator (i.e., no weighting).\n    This update step reduces to the solution of the linear system\n\n    ..  math::\n\n        \\left(A^H W A + \\sum_{i=1}^N \\rho_i C_i^H C_i \\right)\n        \\mb{x}^{(k+1)} = \\;\n        A^H W \\mb{y} + \\sum_{i=1}^N \\rho_i C_i^H ( \\mb{z}^{(k)}_i -\n        \\mb{u}^{(k)}_i) \\;,\n\n    which is solved by factorization of the left hand side of the\n    equation, using :class:`.MatrixATADSolver`.\n\n\n    Attributes:\n        admm (:class:`.ADMM`): ADMM solver object to which the solver is\n            attached.\n        solve_kwargs (dict): Dictionary of arguments for solver\n            :class:`.MatrixATADSolver` initialization.\n    \"\"\"\n\n    def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None):\n        \"\"\"Initialize a :class:`MatrixSubproblemSolver` object.\n\n        Args:\n            check_solve: If ``True``, compute solver accuracy after each\n                solve.\n            solve_kwargs: Dictionary of arguments for solver\n                :class:`.MatrixATADSolver` initialization.\n        \"\"\"\n        self.check_solve = check_solve\n        default_solve_kwargs = {\"cho_factor\": False}\n        if solve_kwargs:\n            default_solve_kwargs.update(solve_kwargs)\n        self.solve_kwargs = default_solve_kwargs\n\n    def internal_init(self, admm: soa.ADMM):\n        if admm.f is not None:\n            if not isinstance(admm.f, SquaredL2Loss):\n                raise TypeError(\n                    \"MatrixSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; \"\n                    f\"got {type(admm.f)}.\"\n                )\n            if not 
isinstance(admm.f.A, (Diagonal, MatrixOperator)):\n                raise TypeError(\n                    \"MatrixSubproblemSolver requires f.A to be a Diagonal or MatrixOperator; \"\n                    f\"got {type(admm.f.A)}.\"\n                )\n        for i, Ci in enumerate(admm.C_list):\n            if not isinstance(Ci, (Diagonal, MatrixOperator)):\n                raise TypeError(\n                    \"MatrixSubproblemSolver requires C[{i}] to be a Diagonal or MatrixOperator; \"\n                    f\"got {type(Ci)}.\"\n                )\n\n        SubproblemSolver.internal_init(self, admm)\n\n        if admm.f is None:\n            A = snp.zeros(admm.C_list[0].input_shape[0], dtype=admm.C_list[0].input_dtype)\n            W = None\n        else:\n            A = admm.f.A\n            W = 2.0 * self.admm.f.scale * admm.f.W  # type: ignore\n\n        Csum = reduce(\n            lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]\n        )\n        self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs)\n\n    def solve(self, x0: Array) -> Array:\n        \"\"\"Solve the ADMM step.\n\n        Args:\n            x0: Initial value (ignored).\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        rhs = self.compute_rhs()\n        x = self.solver.solve(rhs)\n        if self.check_solve:\n            self.accuracy = self.solver.accuracy(x, rhs)\n\n        return x\n\n\nclass CircularConvolveSolver(LinearSubproblemSolver):\n    r\"\"\"Solver for linear operators diagonalized in the DFT domain.\n\n    Specialization of :class:`.LinearSubproblemSolver` for the case\n    where :code:`f` is ``None``, or an instance of\n    :class:`.SquaredL2Loss` with a forward operator :code:`f.A` that is\n    either an instance of :class:`.Identity` or\n    :class:`.CircularConvolve`, and the :code:`C_i` are all shift\n    invariant linear operators, examples of which include instances of\n    :class:`.Identity` as 
well as some instances (depending on\n    initializer parameters) of :class:`.CircularConvolve` and\n    :class:`.FiniteDifference`. None of the instances of\n    :class:`.CircularConvolve` may sum over any of their axes.\n\n    Attributes:\n        admm (:class:`.ADMM`): ADMM solver object to which the solver is\n            attached.\n        lhs_f (array): Left hand side, in the DFT domain, of the linear\n            equation to be solved.\n    \"\"\"\n\n    def __init__(self, ndims: Optional[int] = None):\n        \"\"\"Initialize a :class:`CircularConvolveSolver` object.\n\n        Args:\n            ndims: Number of trailing dimensions of the input and kernel\n                involved in the :class:`.CircularConvolve` convolutions.\n                In most cases this value is automatically determined from\n                the optimization problem specification, but this is not\n                possible when :code:`f` is ``None`` and none of the\n                :code:`C_i` are of type :class:`.CircularConvolve`. 
When\n                not ``None``, this parameter overrides the automatic\n                mechanism.\n        \"\"\"\n        self.ndims = ndims\n\n    def internal_init(self, admm: soa.ADMM):\n        if admm.f is None:\n            is_cc = [isinstance(C, CircularConvolve) for C in admm.C_list]\n            if any(is_cc):\n                auto_ndims = admm.C_list[is_cc.index(True)].ndims\n            else:\n                auto_ndims = None\n        else:\n            if not isinstance(admm.f, SquaredL2Loss):\n                raise TypeError(\n                    \"CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; \"\n                    f\"got {type(admm.f)}.\"\n                )\n            if not isinstance(admm.f.A, (CircularConvolve, Identity)):\n                raise TypeError(\n                    \"CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve \"\n                    f\"or scico.linop.Identity; got {type(admm.f.A)}.\"\n                )\n            auto_ndims = admm.f.A.ndims if isinstance(admm.f.A, CircularConvolve) else None\n\n        if self.ndims is None:\n            self.ndims = auto_ndims\n\n        SubproblemSolver.internal_init(self, admm)\n\n        self.real_result = is_real_dtype(admm.C_list[0].input_dtype)\n\n        # All of the C operators are assumed to be linear and shift invariant\n        # but this is not checked.\n        lhs_op_list = [\n            rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)\n            for rho, C in zip(admm.rho_list, admm.C_list)\n        ]\n        A_lhs = reduce(lambda a, b: a + b, lhs_op_list)\n        if self.admm.f is not None:\n            A_lhs += (\n                2.0\n                * admm.f.scale\n                * CircularConvolve.from_operator(admm.f.A.gram_op, ndims=self.ndims)\n            )\n\n        self.A_lhs = A_lhs\n\n    def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Solve the 
ADMM step.\n\n        Args:\n            x0: Initial value (unused, has no effect).\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        rhs = self.compute_rhs()\n        rhs_dft = snp.fft.fftn(rhs, axes=self.A_lhs.x_fft_axes)\n        x_dft = rhs_dft / self.A_lhs.h_dft\n        x = snp.fft.ifftn(x_dft, axes=self.A_lhs.x_fft_axes)\n        if self.real_result:\n            x = x.real\n\n        return x\n\n\nclass FBlockCircularConvolveSolver(LinearSubproblemSolver):\n    r\"\"\"Solver for linear operators block-diagonalized in the DFT domain.\n\n    Specialization of :class:`.LinearSubproblemSolver` for the case where\n    :code:`f` is an instance of :class:`.SquaredL2Loss`, the forward\n    operator :code:`f.A` is a composition of a :class:`.Sum` operator and\n    a :class:`.CircularConvolve` operator. The former must sum over the\n    first axis of its input, and the latter must be initialized so that\n    it convolves a set of filters, indexed by the first axis, with an\n    input array that has the same number of axes as the filter array, and\n    has an initial axis of the same length as that of the filter array.\n    The :math:`C_i` must all be shift invariant linear operators,\n    examples of which include instances of :class:`.Identity` as well as\n    some instances (depending on initializer parameters) of\n    :class:`.CircularConvolve` and :class:`.FiniteDifference`. None of\n    the instances of :class:`.CircularConvolve` may be summed over any of\n    their axes.\n\n    The solver is based on the frequency-domain approach proposed in\n    :cite:`wohlberg-2014-efficient`. We have :math:`f = \\omega\n    \\norm{A \\mb{x} - \\mb{y}}_2^2`, where typically :math:`\\omega = 1/2`,\n    and :math:`A` is a block-row operator with circulant blocks, i.e. it\n    can be written as\n\n    .. 
math::\n\n       A = \\left( \\begin{array}{cccc} A_1 & A_2 & \\ldots & A_{K}\n           \\end{array} \\right) \\;,\n\n    where all of the :math:`A_k` are circular convolution operators. The\n    complete functional to be minimized is\n\n    .. math::\n\n       \\omega \\norm{A \\mb{x} - \\mb{y}}_2^2 + \\sum_{i=1}^N g_i(C_i \\mb{x})\n       \\;,\n\n    where the :math:`C_i` are either identity or circular convolutions,\n    and the ADMM x-step is\n\n    .. math::\n\n       \\mb{x}^{(j+1)} = \\argmin_{\\mb{x}} \\; \\omega \\norm{A \\mb{x}\n       - \\mb{y}}_2^2 + \\sum_i \\frac{\\rho_i}{2} \\norm{C_i \\mb{x} -\n       (\\mb{z}^{(j)}_i - \\mb{u}^{(j)}_i)}_2^2 \\;.\n\n    This subproblem is most easily solved in the DFT transform domain,\n    where the circular convolutions become diagonal operators. Denoting\n    the frequency-domain versions of variables with a circumflex (e.g.\n    :math:`\\hat{\\mb{x}}` is the frequency-domain version of\n    :math:`\\mb{x}`), the solution of the subproblem can be written as\n\n    .. math::\n\n       \\left( \\hat{A}^H \\hat{A} + \\frac{1}{2 \\omega} \\sum_i \\rho_i\n       \\hat{C}_i^H \\hat{C}_i \\right) \\hat{\\mathbf{x}} = \\hat{A}^H\n       \\hat{\\mb{y}} + \\frac{1}{2 \\omega} \\sum_i \\rho_i \\hat{C}_i^H\n       (\\hat{\\mb{z}}_i - \\hat{\\mb{u}}_i) \\;.\n\n    This linear equation is computationally expensive to solve because\n    the left hand side includes the term :math:`\\hat{A}^H \\hat{A}`,\n    which corresponds to the outer product of :math:`\\hat{A}^H`\n    and :math:`\\hat{A}`. A computationally efficient solution is possible,\n    however, by exploiting the Woodbury matrix identity\n\n    .. math::\n\n       (D + U G V)^{-1} = D^{-1} - D^{-1} U (G^{-1} + V D^{-1} U)^{-1}\n       V D^{-1} \\;.\n\n    Setting\n\n    .. 
math::\n\n       D &= \\frac{1}{2 \\omega} \\sum_i \\rho_i \\hat{C}_i^H \\hat{C}_i \\\\\n       U &= \\hat{A}^H \\\\\n       G &= I \\\\\n       V &= \\hat{A}\n\n    we have\n\n    .. 
math::\n\n       (D + \\hat{A}^H \\hat{A})^{-1} = D^{-1} - D^{-1} \\hat{A}^H\n       (I + \\hat{A} D^{-1} \\hat{A}^H)^{-1} \\hat{A} D^{-1}\n\n    which can be simplified to\n\n    .. math::\n\n       (D + \\hat{A}^H \\hat{A})^{-1} = D^{-1} (I - \\hat{A}^H E^{-1}\n       \\hat{A} D^{-1})\n\n    by defining :math:`E = I + \\hat{A} D^{-1} \\hat{A}^H`. The right\n    hand side is much cheaper to compute because the only matrix\n    inversions involve :math:`D`, which is diagonal, and :math:`E`,\n    which is a weighted inner product of :math:`\\hat{A}^H` and\n    :math:`\\hat{A}`.\n    \"\"\"\n\n    def __init__(self, ndims: Optional[int] = None, check_solve: bool = False):\n        \"\"\"Initialize a :class:`FBlockCircularConvolveSolver` object.\n\n        Args:\n            check_solve: If ``True``, compute solver accuracy after each\n                solve.\n        \"\"\"\n        self.ndims = ndims\n        self.check_solve = check_solve\n        self.accuracy: Optional[float] = None\n\n    def internal_init(self, admm: soa.ADMM):\n        if admm.f is None:\n            raise ValueError(\"FBlockCircularConvolveSolver does not allow f to be None.\")\n        else:\n            if not isinstance(admm.f, SquaredL2Loss):\n                raise TypeError(\n                    \"FBlockCircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; \"\n                    f\"got {type(admm.f)}.\"\n                )\n            if not isinstance(admm.f.A, ComposedLinearOperator):\n                raise TypeError(\n                    \"FBlockCircularConvolveSolver requires f.A to be a composition of Sum \"\n                    f\"and CircularConvolve linear operators; got {type(admm.f.A)}.\"\n                )\n\n        SubproblemSolver.internal_init(self, admm)\n\n        assert isinstance(self.admm.f, SquaredL2Loss)\n        assert isinstance(self.admm.f.A, ComposedLinearOperator)\n\n        # All of the C operators are assumed to be linear and shift invariant\n   
     # but this is not checked.\n        c_gram_list = [\n            rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)\n            for rho, C in zip(admm.rho_list, admm.C_list)\n        ]\n        D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale)\n        self.solver = ConvATADSolver(self.admm.f.A, D)\n\n    def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Solve the ADMM step.\n\n        Args:\n            x0: Initial value (unused, has no effect).\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        assert isinstance(self.admm.f, SquaredL2Loss)\n\n        rhs = self.compute_rhs() / (2.0 * self.admm.f.scale)\n        x = self.solver.solve(rhs)\n        if self.check_solve:\n            self.accuracy = self.solver.accuracy(x, rhs)\n\n        return x\n\n\nclass G0BlockCircularConvolveSolver(SubproblemSolver):\n    r\"\"\"Solver for linear operators block-diagonalized in the DFT\n    domain.\n\n    Specialization of :class:`.SubproblemSolver` for the case\n    where :math:`f = 0` (i.e., :code:`f` is ``None`` or a\n    :class:`.ZeroFunctional`), :math:`g_1` is an instance of\n    :class:`.SquaredL2Loss`, and :math:`C_1` is a composition of a\n    :class:`.Sum` operator and a\n    :class:`.CircularConvolve` operator.  The former must sum over the\n    first axis of its input, and the latter must be initialized so\n    that it convolves a set of filters, indexed by the first axis,\n    with an input array that has the same number of axes as the filter\n    array, and has an initial axis of the same length as that of the\n    filter array.  
The other :math:`C_i` must all be shift invariant\n    linear operators, examples of which include instances of\n    :class:`.Identity` as well as some instances (depending on\n    initializer parameters) of :class:`.CircularConvolve` and\n    :class:`.FiniteDifference`.  
None of these instances of\n    :class:`.CircularConvolve` may be summed over any of their axes.\n\n    The solver is based on the frequency-domain approach proposed in\n    :cite:`wohlberg-2014-efficient`.  We have :math:`g_1 = \\omega\n    \\norm{B A \\mb{x} - \\mb{y}}_2^2`, where typically :math:`\\omega =\n    1/2`, :math:`B` is the identity or a diagonal operator, and\n    :math:`A` is a block-row operator with circulant blocks, i.e. it\n    can be written as\n\n    .. math::\n\n       A = \\left( \\begin{array}{cccc} A_1 & A_2 & \\ldots & A_{K}\n           \\end{array} \\right) \\;,\n\n    where all of the :math:`A_k` are circular convolution operators. The\n    complete functional to be minimized is\n\n    .. math::\n\n       \\sum_{i=1}^N g_i(C_i \\mb{x}) \\;,\n\n    where\n\n    .. math::\n\n       g_1(\\mb{z}) &= \\omega \\norm{B \\mb{z} - \\mb{y}}_2^2\\\\\n       C_1 &= A \\;,\n\n    and the other :math:`C_i` are either identity or circular\n    convolutions. The ADMM x-step is\n\n    .. math::\n\n       \\mb{x}^{(j+1)} = \\argmin_{\\mb{x}} \\; \\rho_1 \\omega \\norm{\n       A \\mb{x} - (\\mb{z}^{(j)}_1 - \\mb{u}^{(j)}_1)}_2^2 + \\sum_{i=2}^N\n       \\frac{\\rho_i}{2} \\norm{C_i \\mb{x} - (\\mb{z}^{(j)}_i -\n       \\mb{u}^{(j)}_i)}_2^2 \\;.\n\n    This subproblem is most easily solved in the DFT transform domain,\n    where the circular convolutions become diagonal operators. Denoting\n    the frequency-domain versions of variables with a circumflex (e.g.\n    :math:`\\hat{\\mb{x}}` is the frequency-domain version of\n    :math:`\\mb{x}`), the solution of the subproblem can be written as\n\n    .. 
math::\n\n       \\left( \\hat{A}^H \\hat{A} + \\frac{1}{2 \\omega \\rho_1} \\sum_{i=2}^N\n       \\rho_i \\hat{C}_i^H \\hat{C}_i \\right) \\hat{\\mathbf{x}} =\n       \\hat{A}^H (\\hat{\\mb{z}}_1 - \\hat{\\mb{u}}_1) +\n       \\frac{1}{2 \\omega \\rho_1} \\sum_{i=2}^N \\rho_i\n       \\hat{C}_i^H (\\hat{\\mb{z}}_i - \\hat{\\mb{u}}_i) \\;.\n\n    This linear equation is computationally expensive to solve because\n    the left hand side includes the term :math:`\\hat{A}^H \\hat{A}`,\n    which corresponds to the outer product of :math:`\\hat{A}^H`\n    and :math:`\\hat{A}`. A computationally efficient solution is possible,\n    however, by exploiting the Woodbury matrix identity\n\n    .. math::\n\n       (D + U G V)^{-1} = D^{-1} - D^{-1} U (G^{-1} + V D^{-1} U)^{-1}\n       V D^{-1} \\;.\n\n    Setting\n\n    .. math::\n\n       D &= \\frac{1}{2 \\omega \\rho_1} \\sum_{i=2}^N \\rho_i \\hat{C}_i^H\n            \\hat{C}_i \\\\\n       U &= \\hat{A}^H \\\\\n       G &= I \\\\\n       V &= \\hat{A}\n\n    we have\n\n    .. math::\n\n       (D + \\hat{A}^H \\hat{A})^{-1} = D^{-1} - D^{-1} \\hat{A}^H\n       (I + \\hat{A} D^{-1} \\hat{A}^H)^{-1} \\hat{A} D^{-1}\n\n    which can be simplified to\n\n    .. math::\n\n       (D + \\hat{A}^H \\hat{A})^{-1} = D^{-1} (I - \\hat{A}^H E^{-1}\n       \\hat{A} D^{-1})\n\n    by defining :math:`E = I + \\hat{A} D^{-1} \\hat{A}^H`. 
The right\n    hand side is much cheaper to compute because the only matrix\n    inversions involve :math:`D`, which is diagonal, and :math:`E`,\n    which is a weighted inner product of :math:`\\hat{A}^H` and\n    :math:`\\hat{A}`.\n    \"\"\"\n\n    def __init__(self, ndims: Optional[int] = None, check_solve: bool = False):\n        \"\"\"Initialize a :class:`G0BlockCircularConvolveSolver` object.\n\n        Args:\n            check_solve: If ``True``, compute solver accuracy after each\n                solve.\n        \"\"\"\n        self.ndims = ndims\n        self.check_solve = check_solve\n        self.accuracy: Optional[float] = None\n\n    def internal_init(self, admm: soa.ADMM):\n        if admm.f is not None and not isinstance(admm.f, ZeroFunctional):\n            raise ValueError(\n                \"G0BlockCircularConvolveSolver requires f to be None or a ZeroFunctional\"\n            )\n        if not isinstance(admm.g_list[0], SquaredL2Loss):\n            raise TypeError(\n                \"G0BlockCircularConvolveSolver requires g_1 to be a scico.loss.SquaredL2Loss; \"\n                f\"got {type(admm.g_list[0])}.\"\n            )\n        if not isinstance(admm.C_list[0], ComposedLinearOperator):\n            raise TypeError(\n                \"G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum \"\n                f\"and CircularConvolve linear operators; got {type(admm.C_list[0])}.\"\n            )\n\n        SubproblemSolver.internal_init(self, admm)\n\n        assert isinstance(self.admm.g_list[0], SquaredL2Loss)\n        assert isinstance(self.admm.C_list[0], ComposedLinearOperator)\n\n        # All of the C operators are assumed to be linear and shift invariant\n        # but this is not checked.\n        c_gram_list = [\n            rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)\n            for rho, C in zip(admm.rho_list[1:], admm.C_list[1:])\n        ]\n        D = reduce(lambda a, b: a + b, 
c_gram_list) / (\n            2.0 * self.admm.g_list[0].scale * admm.rho_list[0]\n        )\n        self.solver = ConvATADSolver(self.admm.C_list[0], D)\n\n    def compute_rhs(self) -> Union[Array, BlockArray]:\n        r\"\"\"Compute the right hand side of the linear equation to be solved.\n\n        Compute\n\n        .. math::\n\n            C_1^H  ( \\mb{z}^{(k)}_1 - \\mb{u}^{(k)}_1) +\n            \\frac{1}{2 \\omega \\rho_1}\\sum_{i=2}^N \\rho_i C_i^H\n            ( \\mb{z}^{(k)}_i - \\mb{u}^{(k)}_i) \\;.\n\n        Returns:\n            Right hand side of the linear equation.\n        \"\"\"\n        assert isinstance(self.admm.g_list[0], SquaredL2Loss)\n\n        C0 = self.admm.C_list[0]\n        rhs = snp.zeros(C0.input_shape, C0.input_dtype)\n        omega = self.admm.g_list[0].scale\n        omega_list = [\n            2.0 * omega,\n        ] + [\n            1.0,\n        ] * (len(self.admm.C_list) - 1)\n        for omegai, rhoi, Ci, zi, ui in zip(\n            omega_list, self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list\n        ):\n            rhs += omegai * rhoi * Ci.adj(zi - ui)\n        return rhs\n\n    def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        \"\"\"Solve the ADMM step.\n\n        Args:\n            x0: Initial value (unused, has no effect).\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        assert isinstance(self.admm.g_list[0], SquaredL2Loss)\n\n        rhs = self.compute_rhs() / (2.0 * self.admm.g_list[0].scale * self.admm.rho_list[0])\n        x = self.solver.solve(rhs)\n        if self.check_solve:\n            self.accuracy = self.solver.accuracy(x, rhs)\n\n        return x\n"
  },
  {
    "path": "scico/optimize/_common.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2023-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Functions common to multiple optimizer modules.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom scico.diagnostics import IterationStats\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import (\n    array_to_namedtuple,\n    namedtuple_to_array,\n    transpose_ntpl_of_list,\n)\nfrom scico.util import Timer\n\n\ndef itstat_func_and_object(\n    itstat_fields: dict, itstat_attrib: List, itstat_options: Optional[dict] = None\n) -> Tuple[Callable, IterationStats]:\n    \"\"\"Iteration statistics initialization.\n\n    Iteration statistics initialization steps common to all optimizer\n    classes.\n\n    Args:\n        itstat_fields: A dictionary associating field names with format\n              strings for displaying the corresponding values.\n        itstat_attrib: A list of expressions corresponding to optimizer\n              class attributes to be evaluated when computing iteration\n              statistics.\n        itstat_options: A dict of named parameters to be passed to\n              the :class:`.diagnostics.IterationStats` initializer. The\n              dict may also include an additional key \"itstat_func\"\n              with the corresponding value being a function with a\n              single parameter, an :class:`Optimizer` object,\n              responsible for constructing a tuple ready for insertion\n              into the :class:`.diagnostics.IterationStats` object. 
If\n              ``None``, default values are used for the dict entries,\n              otherwise the default dict is updated with the dict\n              specified by this parameter.\n\n    Returns:\n        A tuple consisting of the statistics insertion function and the\n        :class:`.diagnostics.IterationStats` object.\n    \"\"\"\n    # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831\n    itstat_return = \"return(\" + \", \".join([\"obj.\" + attr for attr in itstat_attrib]) + \")\"\n    scope: Dict[str, Callable] = {}\n    exec(\"def itstat_func(obj): \" + itstat_return, scope)\n\n    # determine itstat options and initialize IterationStats object\n    default_itstat_options: Dict[str, Union[dict, Callable, bool]] = {\n        \"fields\": itstat_fields,\n        \"itstat_func\": scope[\"itstat_func\"],\n        \"display\": False,\n    }\n    if itstat_options:\n        default_itstat_options.update(itstat_options)\n\n    itstat_insert_func: Callable = default_itstat_options.pop(\"itstat_func\", None)  # type: ignore\n    itstat_object = IterationStats(**default_itstat_options)  # type: ignore\n\n    return itstat_insert_func, itstat_object\n\n\nclass Optimizer:\n    \"\"\"Base class for optimizer classes.\n\n    Attributes:\n        itnum (int): Optimizer iteration counter.\n        maxiter (int): Maximum number of optimizer outer-loop iterations.\n        timer (:class:`.Timer`): Iteration timer.\n    \"\"\"\n\n    def __init__(self, **kwargs: Any):\n        \"\"\"Initialize common attributes of :class:`Optimizer` objects.\n\n        Args:\n            **kwargs: Optional parameter dict. Valid keys are:\n\n                iter0:\n                  Initial value of iteration counter. Default value is 0.\n\n                maxiter:\n                  Maximum iterations on call to :meth:`solve`. 
Default\n                  value is 100.\n\n                nanstop:\n                  If ``True``, stop iterations if a ``NaN`` or ``Inf``\n                  value is encountered in a solver working variable.\n                  Default value is ``False``.\n\n                itstat_options:\n                  A dict of named parameters to be passed to\n                  the :class:`.diagnostics.IterationStats` initializer.\n                  The dict may also include an additional key\n                  \"itstat_func\" with the corresponding value being a\n                  function with a single parameter, an\n                  :class:`Optimizer` object, responsible for constructing\n                  a tuple ready for insertion into the\n                  :class:`.diagnostics.IterationStats` object. If\n                  ``None``, default values are used for the dict entries,\n                  otherwise the default dict is updated with the dict\n                  specified by this parameter.\n        \"\"\"\n        iter0 = kwargs.pop(\"iter0\", 0)\n        self.maxiter: int = kwargs.pop(\"maxiter\", 100)\n        self.nanstop: bool = kwargs.pop(\"nanstop\", False)\n        itstat_options = kwargs.pop(\"itstat_options\", None)\n\n        if kwargs:\n            raise TypeError(f\"Unrecognized keyword argument(s) {', '.join([k for k in kwargs])}\")\n\n        self.itnum: int = iter0\n        self.timer: Timer = Timer()\n\n        itstat_fields, itstat_attrib = self._itstat_default_fields()\n        itstat_extra_fields, itstat_extra_attrib = self._itstat_extra_fields()\n        itstat_fields.update(itstat_extra_fields)\n        itstat_attrib.extend(itstat_extra_attrib)\n\n        self.itstat_insert_func, self.itstat_object = itstat_func_and_object(\n            itstat_fields, itstat_attrib, itstat_options\n        )\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether a ``NaN`` or ``Inf`` value was encountered in solve.\n\n        Return 
``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        raise NotImplementedError(\n            \"NaN check requested but _working_vars_finite not implemented.\" \"\"\n        )\n\n    def _itstat_default_fields(self) -> Tuple[Dict[str, str], List[str]]:\n        \"\"\"Define iterations stats default fields.\n\n        Return a dict mapping field names to format strings, and a list\n        of strings containing the names of attributes or methods to call\n        in order to determine the value for each field.\n        \"\"\"\n        # iteration number and time fields\n        itstat_fields = {\n            \"Iter\": \"%d\",\n            \"Time\": \"%8.2e\",\n        }\n        itstat_attrib = [\"itnum\", \"timer.elapsed()\"]\n        # objective function can be evaluated if 'g' function can be evaluated\n        if self._objective_evaluatable():\n            itstat_fields.update({\"Objective\": \"%9.3e\"})\n            itstat_attrib.append(\"objective()\")\n\n        return itstat_fields, itstat_attrib\n\n    def _objective_evaluatable(self) -> bool:\n        \"\"\"Determine whether the objective function can be evaluated.\n\n        Determine whether the objective function for a :class:`Optimizer`\n        object can be evaluated.\n        \"\"\"\n        return False\n\n    def _itstat_extra_fields(self) -> Tuple[Dict[str, str], List[str]]:\n        \"\"\"Define additional iterations stats fields.\n\n        Define iterations stats fields that are not common to all\n        :class:`Optimizer` classes. 
Return a dict mapping field names to\n        format strings, and a list of strings containing the names of\n        attributes or methods to call in order to determine the value for\n        each field.\n        \"\"\"\n        return {}, []\n\n    def _state_variable_names(self) -> List[str]:\n        \"\"\"Get optimizer state variable names.\n\n        Get optimizer state variable names.\n\n        Returns:\n            List of names of class attributes that represent algorithm\n            state variables.\n        \"\"\"\n        raise NotImplementedError(\n            f\"Method _state_variable_names is not implemented for {type(self)}.\"\n        )\n\n    def _get_state_variables(self) -> dict[str, Any]:\n        \"\"\"Get optimizer state variables.\n\n        Get optimizer state variables.\n\n        Returns:\n            Dict of state variable names and corresponding values.\n        \"\"\"\n        return {k: getattr(self, k) for k in self._state_variable_names()}\n\n    def _set_state_variables(self, **kwargs):\n        \"\"\"Set optimizer state variables.\n\n        Set optimizer state variables.\n\n        Args:\n            **kwargs: State variables to be set, with parameter names\n                corresponding to their class attribute names.\n        \"\"\"\n        valid_vars = self._state_variable_names()\n        for k, v in kwargs.items():\n            if k not in valid_vars:\n                raise RuntimeError(f\"{k} is not a valid state variable for {type(self)}.\")\n            setattr(self, k, v)\n\n    def save_state(self, path: str):\n        \"\"\"Save optimizer state to a file.\n\n        Save optimizer state to a file.\n\n        Args:\n            path: Filename of file to which state should be saved.\n        \"\"\"\n        state_vars = self._get_state_variables()\n        np.savez(\n            path,\n            opt_class=self.__class__,\n            itnum=self.itnum,\n            
history=namedtuple_to_array(self.history(transpose=True)),  # type: ignore\n       
     **state_vars,\n        )\n\n    def load_state(self, path: str):\n        \"\"\"Load optimizer state from a file.\n\n        Restore optimizer state from a file.\n\n        Args:\n            path: Filename of state file saved using :meth:`save_state`.\n\n        Raises:\n            TypeError: If the state file was saved by an optimizer of a\n                different class.\n        \"\"\"\n        # Read all arrays into memory so that the file is closed\n        # immediately. Pickling must be allowed to restore the saved\n        # optimizer class, so only load state files from trusted sources.\n        with np.load(path, allow_pickle=True) as npz:\n            npzd = dict(npz)\n        opt_class = npzd.pop(\"opt_class\")\n        if opt_class != self.__class__:\n            raise TypeError(\n                f\"Cannot load state for {opt_class} into optimizer \"\n                f\"of type {self.__class__}.\"\n            )\n        self.itnum = int(npzd.pop(\"itnum\"))\n        history = transpose_ntpl_of_list(array_to_namedtuple(npzd.pop(\"history\")))\n        self.itstat_object.iterations = history\n        self._set_state_variables(**npzd)\n\n    def history(self, transpose: bool = False) -> Union[List[NamedTuple], Tuple[List]]:\n        \"\"\"Retrieve record of algorithm iterations.\n\n        Retrieve record of algorithm iterations.\n\n        Args:\n            transpose: Flag indicating whether results should be returned\n                in \"transposed\" form, i.e. 
as a namedtuple of lists\n                rather than a list of namedtuples.\n\n        Returns:\n            Record of all iterations.\n        \"\"\"\n        return self.itstat_object.history(transpose=transpose)\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        \"\"\"Return the current estimate of the functional minimizer.\n\n        Returns:\n            Current estimate of the functional minimizer.\n        \"\"\"\n        raise NotImplementedError(f\"Method minimizer is not implemented for {type(self)}.\")\n\n    def step(self):\n        \"\"\"Perform a single optimizer step.\"\"\"\n        raise NotImplementedError(f\"Method step is not implemented for {type(self)}.\")\n\n    def solve(\n        self,\n        callback: Optional[Callable[[Optimizer], None]] = None,\n    ) -> Union[Array, BlockArray]:\n        r\"\"\"Initialize and run the optimization algorithm.\n\n        Initialize and run the optimization algorithm for a total of\n        `self.maxiter` iterations.\n\n        Args:\n            callback: An optional callback function, taking a single\n              argument of type :class:`Optimizer`, that is called\n              at the end of every iteration.\n\n        Returns:\n            Computed solution.\n        \"\"\"\n        self.timer.start()\n        itnum_end = self.itnum + self.maxiter\n        for self.itnum in range(self.itnum, itnum_end):\n            self.step()\n            if self.nanstop and not self._working_vars_finite():\n                raise ValueError(\n                    f\"NaN or Inf value encountered in working variable in iteration {self.itnum}.\"\n                    \"\"\n                )\n            self.itstat_object.insert(self.itstat_insert_func(self))\n            if callback:\n                self.timer.stop()\n                callback(self)\n                self.timer.start()\n        self.timer.stop()\n        # Set the counter explicitly rather than incrementing it so that it\n        # is not advanced when maxiter is zero and the loop body never runs.\n        self.itnum 
= itnum_end\n        self.itstat_object.end()\n        return self.minimizer()\n"
  },
  {
    "path": "scico/optimize/_ladmm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Linearized ADMM solver.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import List, Optional, Tuple, Union\n\nimport scico.numpy as snp\nfrom scico.functional import Functional\nfrom scico.linop import LinearOperator\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.linalg import norm\n\nfrom ._common import Optimizer\n\n\nclass LinearizedADMM(Optimizer):\n    r\"\"\"Linearized alternating direction method of multipliers algorithm.\n\n    |\n\n    Solve an optimization problem of the form\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\;,\n\n    where :math:`f` and :math:`g` are instances of :class:`.Functional`,\n    (in most cases :math:`f` will, more specifically, be an instance\n    of :class:`.Loss`), and :math:`C` is an instance of\n    :class:`.LinearOperator`.\n\n    The optimization problem is solved by introducing the splitting\n    :math:`\\mb{z} = C \\mb{x}` and solving\n\n    .. math::\n        \\argmin_{\\mb{x}, \\mb{z}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n       \\text{such that}\\; C \\mb{x} = \\mb{z} \\;,\n\n    via a linearized ADMM algorithm :cite:`yang-2012-linearized`\n    :cite:`parikh-2014-proximal` (Sec. 4.4.2) consisting of the\n    iterations (see :meth:`step`)\n\n    .. 
math::\n       \\begin{aligned}\n       \\mb{x}^{(k+1)} &= \\mathrm{prox}_{\\mu f} \\left( \\mb{x}^{(k)} -\n       (\\mu / \\nu) C^T \\left(C \\mb{x}^{(k)} - \\mb{z}^{(k)} + \\mb{u}^{(k)}\n       \\right) \\right) \\\\\n       \\mb{z}^{(k+1)} &= \\mathrm{prox}_{\\nu g} \\left(C \\mb{x}^{(k+1)} +\n       \\mb{u}^{(k)} \\right) \\\\\n       \\mb{u}^{(k+1)} &=  \\mb{u}^{(k)} + C \\mb{x}^{(k+1)} -\n       \\mb{z}^{(k+1)}  \\;.\n       \\end{aligned}\n\n    Parameters :math:`\\mu` and :math:`\\nu` are required to satisfy\n\n    .. math::\n       0 < \\mu < \\nu \\| C \\|_2^{-2} \\;.\n\n\n    Attributes:\n        f (:class:`.Functional`): Functional :math:`f` (usually a\n           :class:`.Loss`).\n        g (:class:`.Functional`): Functional :math:`g`.\n        C (:class:`.LinearOperator`): :math:`C` operator.\n        mu (scalar): First algorithm parameter.\n        nu (scalar): Second algorithm parameter.\n        u (array-like): Scaled Lagrange multipliers :math:`\\mb{u}` at\n           current iteration.\n        x (array-like): Solution variable.\n        z (array-like): Auxiliary variables :math:`\\mb{z}` at current\n          iteration.\n        z_old (array-like): Auxiliary variables :math:`\\mb{z}` at\n          previous iteration.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g: Functional,\n        C: LinearOperator,\n        mu: float,\n        nu: float,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`LinearizedADMM` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g: Functional :math:`g`.\n            C: Operator :math:`C`.\n            mu: First algorithm parameter.\n            nu: Second algorithm parameter.\n            x0: Starting point for :math:`\\mb{x}`. 
If ``None``, defaults\n                to an array of zeros.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        self.f: Functional = f\n        self.g: Functional = g\n        self.C: LinearOperator = C\n        self.mu: float = mu\n        self.nu: float = nu\n\n        if x0 is None:\n            input_shape = C.input_shape\n            dtype = C.input_dtype\n            x0 = snp.zeros(input_shape, dtype=dtype)\n        self.x = x0\n        self.z, self.z_old = self.z_init(self.x)\n        self.u = self.u_init(self.x)\n\n        super().__init__(**kwargs)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether a ``NaN`` or ``Inf`` value was encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        return (\n            snp.all(snp.isfinite(self.x))\n            and snp.all(snp.isfinite(self.z))\n            and snp.all(snp.isfinite(self.u))\n        )\n\n    def _objective_evaluatable(self):\n        \"\"\"Determine whether the objective function can be evaluated.\"\"\"\n        return self.f.has_eval and self.g.has_eval\n\n    def _itstat_extra_fields(self):\n        \"\"\"Define linearized ADMM-specific iteration statistics fields.\"\"\"\n        itstat_fields = {\"Prml Rsdl\": \"%9.3e\", \"Dual Rsdl\": \"%9.3e\"}\n        itstat_attrib = [\"norm_primal_residual()\", \"norm_dual_residual()\"]\n        return itstat_fields, itstat_attrib\n\n    def _state_variable_names(self) -> List[str]:\n        return [\"x\", \"z\", \"z_old\", \"u\"]\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        return self.x\n\n    def objective(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z: Optional[Union[Array, BlockArray]] = None,\n    ) -> float:\n        r\"\"\"Evaluate the objective function.\n\n        Evaluate the 
objective function\n\n        .. math::\n            f(\\mb{x}) + g(\\mb{z}) \\;.\n\n        Args:\n            x: Point at which to evaluate objective function. If\n               ``None``, the objective is evaluated at the current\n               iterate :code:`self.x`.\n            z: Point at which to evaluate objective function. If\n               ``None``, the objective is evaluated at the current\n               iterate :code:`self.z`.\n\n        Returns:\n            scalar: Value of the objective function.\n        \"\"\"\n        if (x is None) != (z is None):\n            raise ValueError(\"Both or neither of arguments 'x' and 'z' must be supplied.\")\n        if x is None:\n            x = self.x\n            z = self.z\n        return self.f(x) + self.g(z)\n\n    def norm_primal_residual(self, x: Optional[Union[Array, BlockArray]] = None) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the primal residual.\n\n        Compute the :math:`\\ell_2` norm of the primal residual\n\n        .. math::\n            \\norm{C \\mb{x} - \\mb{z}}_2 \\;.\n\n        Args:\n            x: Point at which to evaluate primal residual. If ``None``,\n               the primal residual is evaluated at the current iterate\n               :code:`self.x`.\n\n        Returns:\n            Norm of primal residual.\n        \"\"\"\n        if x is None:\n            x = self.x\n\n        return norm(self.C(x) - self.z)\n\n    def norm_dual_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the dual residual.\n\n        Compute the :math:`\\ell_2` norm of the dual residual\n\n        .. 
math::\n            \\norm{C^T \\left( \\mb{z}^{(k)} - \\mb{z}^{(k-1)} \\right)}_2 \\;.\n\n        Returns:\n            Current norm of dual residual.\n        \"\"\"\n        return norm(self.C.adj(self.z - self.z_old))\n\n    def z_init(\n        self, x0: Union[Array, BlockArray]\n    ) -> Tuple[Union[Array, BlockArray], Union[Array, BlockArray]]:\n        r\"\"\"Initialize auxiliary variable :math:`\\mb{z}`.\n\n        Initialized to\n\n        .. math::\n            \\mb{z} = C \\mb{x}^{(0)} \\;.\n\n        :code:`z` and :code:`z_old` are initialized to the same value.\n\n        Args:\n            x0: Starting point for :math:`\\mb{x}`.\n        \"\"\"\n        z = self.C(x0)\n        z_old = z\n        return z, z_old\n\n    def u_init(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n        r\"\"\"Initialize scaled Lagrange multiplier :math:`\\mb{u}`.\n\n        Initialized to\n\n        .. math::\n            \\mb{u} = \\mb{0} \\;.\n\n        Note that the parameter `x0` is unused, but is provided for\n        potential use in an overridden method.\n\n        Args:\n            x0: Starting point for :math:`\\mb{x}`.\n        \"\"\"\n        u = snp.zeros(self.C.output_shape, dtype=self.C.output_dtype)\n        return u\n\n    def step(self):\n        r\"\"\"Perform a single linearized ADMM iteration.\n\n        The primary variable :math:`\\mb{x}` is updated by computing\n\n        .. math::\n            \\mb{x}^{(k+1)} = \\mathrm{prox}_{\\mu f} \\left( \\mb{x}^{(k)} -\n            (\\mu / \\nu) C^T \\left(C \\mb{x}^{(k)} - \\mb{z}^{(k)} +\n            \\mb{u}^{(k)} \\right) \\right) \\;.\n\n        The auxiliary variable is updated according to\n\n        .. math::\n            \\mb{z}^{(k+1)} = \\mathrm{prox}_{\\nu g} \\left(C \\mb{x}^{(k+1)}\n            + \\mb{u}^{(k)} \\right) \\;,\n\n        and the scaled Lagrange multiplier is updated according to\n\n        .. 
math::\n            \\mb{u}^{(k+1)} =  \\mb{u}^{(k)} + C \\mb{x}^{(k+1)} -\n            \\mb{z}^{(k+1)} \\;.\n        \"\"\"\n        proxarg = self.x - (self.mu / self.nu) * self.C.conj().T(self.C(self.x) - self.z + self.u)\n        self.x = self.f.prox(proxarg, self.mu, v0=self.x)\n\n        self.z_old = self.z\n        Cx = self.C(self.x)\n        self.z = self.g.prox(Cx + self.u, self.nu, v0=self.z)\n        self.u = self.u + Cx - self.z\n"
  },
  {
    "path": "scico/optimize/_padmm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Proximal ADMM solvers.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import List, Optional, Tuple, Union\n\nimport scico.numpy as snp\nfrom scico import cvjp, jvp\nfrom scico.function import Function\nfrom scico.functional import Functional\nfrom scico.linop import Identity, LinearOperator, operator_norm\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.linalg import norm\nfrom scico.typing import BlockShape, DType, PRNGKey, Shape\n\nfrom ._common import Optimizer\n\n# mypy: disable-error-code=override\n\n\nclass ProximalADMMBase(Optimizer):\n    r\"\"\"Base class for proximal ADMM solvers.\n\n    Attributes:\n        f (:class:`.Functional`): Functional :math:`f` (usually a\n           :class:`.Loss`).\n        g (:class:`.Functional`): Functional :math:`g`.\n        rho (scalar): Penalty parameter.\n        mu (scalar): First algorithm parameter.\n        nu (scalar): Second algorithm parameter.\n        x (array-like): Solution variable.\n        z (array-like): Auxiliary variables :math:`\\mb{z}` at current\n          iteration.\n        z_old (array-like): Auxiliary variables :math:`\\mb{z}` at\n          previous iteration.\n        u (array-like): Scaled Lagrange multipliers :math:`\\mb{u}` at\n           current iteration.\n        u_old (array-like): Scaled Lagrange multipliers :math:`\\mb{u}` at\n           previous iteration.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g: Functional,\n        rho: float,\n        mu: float,\n        nu: float,\n        xshape: Union[Shape, 
BlockShape],\n        zshape: Union[Shape, BlockShape],\n        ushape: Union[Shape, BlockShape],\n        xdtype: DType,\n        zdtype: DType,\n        udtype: DType,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        z0: Optional[Union[Array, BlockArray]] = None,\n        u0: Optional[Union[Array, BlockArray]] = None,\n        fast_dual_residual: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`ProximalADMMBase` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g: Functional :math:`g`.\n            rho: Penalty parameter.\n            mu: First algorithm parameter.\n            nu: Second algorithm parameter.\n            xshape: Shape of variable :math:`\\mb{x}`.\n            zshape: Shape of variable :math:`\\mb{z}`.\n            ushape: Shape of variable :math:`\\mb{u}`.\n            xdtype: Dtype of variable :math:`\\mb{x}`.\n            zdtype: Dtype of variable :math:`\\mb{z}`.\n            udtype: Dtype of variable :math:`\\mb{u}`.\n            x0: Initial value for :math:`\\mb{x}`. If ``None``, defaults\n                to an array of zeros.\n            z0: Initial value for :math:`\\mb{z}`. If ``None``, defaults\n                to an array of zeros.\n            u0: Initial value for :math:`\\mb{u}`. 
If ``None``, defaults\n                to an array of zeros.\n            fast_dual_residual: Flag indicating whether to use fast\n                approximation to the dual residual, or a slower but more\n                accurate calculation.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        self.f: Functional = f\n        self.g: Functional = g\n\n        self.rho: float = rho\n        self.mu: float = mu\n        self.nu: float = nu\n        self.fast_dual_residual: bool = fast_dual_residual\n\n        if x0 is None:\n            x0 = snp.zeros(xshape, dtype=xdtype)\n        self.x = x0\n        if z0 is None:\n            z0 = snp.zeros(zshape, dtype=zdtype)\n        self.z = z0\n        self.z_old = self.z\n        if u0 is None:\n            u0 = snp.zeros(ushape, dtype=udtype)\n        self.u = u0\n        self.u_old = self.u\n\n        super().__init__(**kwargs)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether a ``NaN`` or ``Inf`` value was encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        return (\n            snp.all(snp.isfinite(self.x))\n            and snp.all(snp.isfinite(self.z))\n            and snp.all(snp.isfinite(self.u))\n        )\n\n    def _objective_evaluatable(self):\n        \"\"\"Determine whether the objective function can be evaluated.\"\"\"\n        return self.f.has_eval and self.g.has_eval\n\n    def _itstat_extra_fields(self):\n        \"\"\"Define proximal ADMM-specific iteration statistics fields.\"\"\"\n        itstat_fields = {\"Prml Rsdl\": \"%9.3e\", \"Dual Rsdl\": \"%9.3e\"}\n        itstat_attrib = [\"norm_primal_residual()\", \"norm_dual_residual()\"]\n        return itstat_fields, itstat_attrib\n\n    def _state_variable_names(self) -> List[str]:\n        return [\"x\", \"z\", \"z_old\", \"u\", 
\"u_old\"]\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        return self.x\n\n    def objective(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z: Optional[Union[Array, BlockArray]] = None,\n    ) -> float:\n        r\"\"\"Evaluate the objective function.\n\n        Evaluate the objective function\n\n        .. math::\n            f(\\mb{x}) + g(\\mb{z}) \\;.\n\n        Args:\n            x: Point at which to evaluate objective function. If\n               ``None``, the objective is evaluated at the current\n               iterate :code:`self.x`.\n            z: Point at which to evaluate objective function. If\n               ``None``, the objective is evaluated at the current\n               iterate :code:`self.z`.\n\n        Returns:\n            scalar: Current value of the objective function.\n        \"\"\"\n        if (x is None) != (z is None):\n            raise ValueError(\"Both or neither of arguments 'x' and 'z' must be supplied\")\n        if x is None:\n            x = self.x\n            z = self.z\n        return self.f(x) + self.g(z)\n\n\nclass ProximalADMM(ProximalADMMBase):\n    r\"\"\"Proximal alternating direction method of multipliers.\n\n    |\n\n    Solve an optimization problem of the form\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n        \\text{such that}\\; A \\mb{x} + B \\mb{z} = \\mb{c} \\;,\n\n    where :math:`f` and :math:`g` are instances of :class:`.Functional`,\n    (in most cases :math:`f` will, more specifically be an instance\n    of :class:`.Loss`), and :math:`A` and :math:`B` are instances of\n    :class:`LinearOperator`.\n\n    The optimization problem is solved via a variant of the proximal ADMM\n    algorithm :cite:`deng-2015-global`, consisting of the iterations\n    (see :meth:`step`)\n\n    .. 
math::\n       \\begin{aligned}\n       \\mb{x}^{(k+1)} &= \\mathrm{prox}_{\\rho^{-1} \\mu^{-1} f} \\left(\n         \\mb{x}^{(k)} - \\mu^{-1} A^T \\left(2 \\mb{u}^{(k)} -\n         \\mb{u}^{(k-1)} \\right) \\right) \\\\\n       \\mb{z}^{(k+1)} &= \\mathrm{prox}_{\\rho^{-1} \\nu^{-1} g} \\left(\n         \\mb{z}^{(k)} - \\nu^{-1} B^T \\left(\n         A \\mb{x}^{(k+1)} + B \\mb{z}^{(k)} - \\mb{c} + \\mb{u}^{(k)}\n         \\right) \\right) \\\\\n       \\mb{u}^{(k+1)} &=  \\mb{u}^{(k)} + A \\mb{x}^{(k+1)} + B\n         \\mb{z}^{(k+1)} - \\mb{c}  \\;.\n       \\end{aligned}\n\n    Parameters :math:`\\mu` and :math:`\\nu` are required to satisfy\n\n    .. math::\n       \\mu > \\norm{ A }_2^2 \\quad \\text{and} \\quad \\nu > \\norm{ B }_2^2 \\;.\n\n\n    Attributes:\n        A (:class:`.LinearOperator`): :math:`A` linear operator.\n        B (:class:`.LinearOperator`): :math:`B` linear operator.\n        c (array-like): constant :math:`\\mb{c}`.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g: Functional,\n        A: LinearOperator,\n        rho: float,\n        mu: float,\n        nu: float,\n        B: Optional[LinearOperator] = None,\n        c: Optional[Union[float, Array, BlockArray]] = None,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        z0: Optional[Union[Array, BlockArray]] = None,\n        u0: Optional[Union[Array, BlockArray]] = None,\n        fast_dual_residual: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`ProximalADMM` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g: Functional :math:`g`.\n            A: Linear operator :math:`A`.\n            rho: Penalty parameter.\n            mu: First algorithm parameter.\n            nu: Second algorithm parameter.\n            B: Linear operator :math:`B` (if ``None``, :math:`B = -I`\n               where :math:`I` is the identity operator).\n            c: Constant 
:math:`\\mb{c}`. If ``None``, defaults to zero.\n            x0: Starting value for :math:`\\mb{x}`. If ``None``, defaults\n                to an array of zeros.\n            z0: Starting value for :math:`\\mb{z}`. If ``None``, defaults\n                to an array of zeros.\n            u0: Starting value for :math:`\\mb{u}`. If ``None``, defaults\n                to an array of zeros.\n            fast_dual_residual: Flag indicating whether to use fast\n                approximation to the dual residual, or a slower but more\n                accurate calculation.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        self.A: LinearOperator = A\n        if B is None:\n            self.B = -Identity(self.A.output_shape, self.A.output_dtype)\n        else:\n            self.B = B\n        if c is None:\n            self.c = 0.0\n        else:\n            self.c = c\n\n        super().__init__(\n            f,\n            g,\n            rho,\n            mu,\n            nu,\n            self.A.input_shape,\n            self.B.input_shape,\n            self.A.output_shape,\n            self.A.input_dtype,\n            self.B.input_dtype,\n            self.A.output_dtype,\n            x0=x0,\n            z0=z0,\n            u0=u0,\n            fast_dual_residual=fast_dual_residual,\n            **kwargs,\n        )\n\n    def norm_primal_residual(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z: Optional[Union[Array, BlockArray]] = None,\n    ) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the primal residual.\n\n        Compute the :math:`\\ell_2` norm of the primal residual\n\n        .. math::\n            \\norm{A \\mb{x} + B \\mb{z} - \\mb{c}}_2 \\;.\n\n        Args:\n            x: Point at which to evaluate primal residual. 
If ``None``,\n               the primal residual is evaluated at the current iterate\n               :code:`self.x`.\n            z: Point at which to evaluate primal residual. If ``None``,\n               the primal residual is evaluated at the current iterate\n               :code:`self.z`.\n\n        Returns:\n            Norm of primal residual.\n        \"\"\"\n        if (x is None) != (z is None):\n            raise ValueError(\"Both or neither of arguments 'x' and 'z' must be supplied\")\n        if x is None:\n            x = self.x\n            z = self.z\n\n        return norm(self.A(x) + self.B(z) - self.c)\n\n    def norm_dual_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the dual residual.\n\n        Compute the :math:`\\ell_2` norm of the dual residual. If the flag\n        requesting a fast approximate calculation is set, it is computed\n        as\n\n        .. math::\n            \\norm{\\mb{z}^{(k+1)} - \\mb{z}^{(k)}}_2 \\;,\n\n        otherwise it is computed as\n\n        .. 
math::\n            \\norm{A^T B ( \\mb{z}^{(k+1)} - \\mb{z}^{(k)} ) }_2 \\;.\n\n        Returns:\n            Current norm of dual residual.\n        \"\"\"\n        if self.fast_dual_residual:\n            rsdl = self.z - self.z_old  # fast but poor approximation\n        else:\n            rsdl = self.A.H(self.B(self.z - self.z_old))\n        return norm(rsdl)\n\n    def step(self):\n        r\"\"\"Perform a single algorithm iteration.\n\n        Perform a single algorithm iteration.\n        \"\"\"\n        proxarg = self.x - (1.0 / self.mu) * self.A.H(2.0 * self.u - self.u_old)\n        self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x)\n        proxarg = self.z - (1.0 / self.nu) * self.B.H(\n            self.A(self.x) + self.B(self.z) - self.c + self.u\n        )\n        self.z_old = self.z\n        self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z)\n        self.u_old = self.u\n        self.u = self.u + self.A(self.x) + self.B(self.z) - self.c\n\n    @staticmethod\n    def estimate_parameters(\n        A: LinearOperator,\n        B: Optional[LinearOperator] = None,\n        factor: Optional[float] = 1.01,\n        maxiter: int = 100,\n        key: Optional[PRNGKey] = None,\n    ) -> Tuple[float, float]:\n        r\"\"\"Estimate `mu` and `nu` parameters of :class:`ProximalADMM`.\n\n        Find values of the `mu` and `nu` parameters of :class:`ProximalADMM`\n        that respect the constraints\n\n        .. math::\n           \\mu > \\norm{ A }_2^2 \\quad \\text{and} \\quad \\nu >\n           \\norm{ B }_2^2 \\;.\n\n        Args:\n            A: Linear operator :math:`A`.\n            B: Linear operator :math:`B` (if ``None``, :math:`B = -I`\n               where :math:`I` is the identity operator).\n            factor: Safety factor with which to multiply estimated\n               operator norms to ensure strict inequality compliance. 
If\n               ``None``, return the estimated squared operator norms.\n            maxiter: Maximum number of power iterations to use in operator\n               norm estimation (see :func:`.operator_norm`). Default: 100.\n            key: Jax PRNG key to use in operator norm estimation (see\n               :func:`.operator_norm`). Defaults to ``None``, in which\n               case a new key is created.\n\n        Returns:\n            A tuple (`mu`, `nu`) representing the estimated parameter\n            values or corresponding squared operator norm values,\n            depending on the value of the `factor` parameter.\n        \"\"\"\n        mu = operator_norm(A, maxiter=maxiter, key=key) ** 2\n        if B is None:\n            nu = 1.0\n        else:\n            nu = operator_norm(B, maxiter=maxiter, key=key) ** 2\n        if factor is None:\n            return (mu, nu)\n        else:\n            return (factor * mu, factor * nu)\n\n\nclass NonLinearPADMM(ProximalADMMBase):\n    r\"\"\"Non-linear proximal alternating direction method of multipliers.\n\n    |\n\n    Solve an optimization problem of the form\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(\\mb{z}) \\;\n        \\text{such that}\\; H(\\mb{x}, \\mb{z}) = 0 \\;,\n\n    where :math:`f` and :math:`g` are instances of :class:`.Functional`,\n    (in most cases :math:`f` will, more specifically be an instance\n    of :class:`.Loss`), and :math:`H` is a function.\n\n    The optimization problem is solved via a variant of the proximal ADMM\n    algorithm for problems with a non-linear operator constraint\n    :cite:`benning-2016-preconditioned`, consisting of the\n    iterations (see :meth:`step`)\n\n    .. 
math::\n       \\begin{aligned}\n       A^{(k)} &= J_{\\mb{x}} H(\\mb{x}^{(k)}, \\mb{z}^{(k)}) \\\\\n       \\mb{x}^{(k+1)} &= \\mathrm{prox}_{\\rho^{-1} \\mu^{-1} f} \\left(\n         \\mb{x}^{(k)} - \\mu^{-1} (A^{(k)})^T \\left(2 \\mb{u}^{(k)} -\n         \\mb{u}^{(k-1)} \\right) \\right) \\\\\n       B^{(k)} &= J_{\\mb{z}} H(\\mb{x}^{(k+1)}, \\mb{z}^{(k)}) \\\\\n       \\mb{z}^{(k+1)} &= \\mathrm{prox}_{\\rho^{-1} \\nu^{-1} g} \\left(\n         \\mb{z}^{(k)} - \\nu^{-1} (B^{(k)})^T \\left(\n         H(\\mb{x}^{(k+1)}, \\mb{z}^{(k)}) + \\mb{u}^{(k)} \\right) \\right) \\\\\n       \\mb{u}^{(k+1)} &=  \\mb{u}^{(k)} + H(\\mb{x}^{(k+1)},\n         \\mb{z}^{(k+1)})  \\;.\n       \\end{aligned}\n\n    Parameters :math:`\\mu` and :math:`\\nu` are required to satisfy\n\n    .. math::\n       \\mu > \\norm{ A^{(k)} }_2^2 \\quad \\text{and} \\quad \\nu > \\norm{ B^{(k)} }_2^2\n\n    for all :math:`A^{(k)}` and :math:`B^{(k)}`.\n\n\n    Attributes:\n        H (:class:`.Function`): :math:`H` function.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g: Functional,\n        H: Function,\n        rho: float,\n        mu: float,\n        nu: float,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        z0: Optional[Union[Array, BlockArray]] = None,\n        u0: Optional[Union[Array, BlockArray]] = None,\n        fast_dual_residual: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`NonLinearPADMM` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g: Functional :math:`g`.\n            H: Function :math:`H`.\n            rho: Penalty parameter.\n            mu: First algorithm parameter.\n            nu: Second algorithm parameter.\n            x0: Starting value for :math:`\\mb{x}`. If ``None``, defaults\n                to an array of zeros.\n            z0: Starting value for :math:`\\mb{z}`. 
If ``None``, defaults\n                to an array of zeros.\n            u0: Starting value for :math:`\\mb{u}`. If ``None``, defaults\n                to an array of zeros.\n            fast_dual_residual: Flag indicating whether to use fast\n                approximation to the dual residual, or a slower but more\n                accurate calculation.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        self.H: Function = H\n\n        super().__init__(\n            f,\n            g,\n            rho,\n            mu,\n            nu,\n            H.input_shapes[0],\n            H.input_shapes[1],\n            H.output_shape,\n            H.input_dtypes[0],\n            H.input_dtypes[1],\n            H.output_dtype,\n            x0=x0,\n            z0=z0,\n            u0=u0,\n            fast_dual_residual=fast_dual_residual,\n            **kwargs,\n        )\n\n    def norm_primal_residual(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z: Optional[Union[Array, BlockArray]] = None,\n    ) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the primal residual.\n\n        Compute the :math:`\\ell_2` norm of the primal residual\n\n        .. math::\n            \\norm{H(\\mb{x}, \\mb{z})}_2 \\;.\n\n        Args:\n            x: Point at which to evaluate primal residual. If ``None``,\n               the primal residual is evaluated at the current iterate\n               :code:`self.x`.\n            z: Point at which to evaluate primal residual. 
If ``None``,\n               the primal residual is evaluated at the current iterate\n               :code:`self.z`.\n\n        Returns:\n            Norm of primal residual.\n        \"\"\"\n        if (x is None) != (z is None):\n            raise ValueError(\"Both or neither of arguments 'x' and 'z' must be supplied\")\n        if x is None:\n            x = self.x\n            z = self.z\n\n        return norm(self.H(x, z))\n\n    def norm_dual_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the dual residual.\n\n        Compute the :math:`\\ell_2` norm of the dual residual. If the flag\n        requesting a fast approximate calculation is set, it is computed\n        as\n\n        .. math::\n            \\norm{\\mb{z}^{(k+1)} - \\mb{z}^{(k)}}_2 \\;,\n\n        otherwise it is computed as\n\n        .. math::\n            \\norm{A^T B ( \\mb{z}^{(k+1)} - \\mb{z}^{(k)} ) }_2 \\;,\n\n        where\n\n        .. math::\n            A &= J_{\\mb{x}} H(\\mb{x}^{(k+1)}, \\mb{z}^{(k+1)}) \\\\\n            B &= J_{\\mb{z}} H(\\mb{x}^{(k+1)}, \\mb{z}^{(k+1)}) \\;.\n\n        Returns:\n            Current norm of dual residual.\n        \"\"\"\n        if self.fast_dual_residual:\n            rsdl = self.z - self.z_old  # fast but poor approximation\n        else:\n            Hz = lambda z: self.H(self.x, z)\n            B = lambda u: jvp(Hz, (self.z,), (u,))[1]\n            Hx = lambda x: self.H(x, self.z)\n            AH = cvjp(Hx, self.x)[1]\n            rsdl = AH(B(self.z - self.z_old))\n        return norm(rsdl)\n\n    def step(self):\n        r\"\"\"Perform a single algorithm iteration.\n\n        Perform a single algorithm iteration.\n        \"\"\"\n        AH = self.H.vjp(0, self.x, self.z, conjugate=True)[1]\n        proxarg = self.x - (1.0 / self.mu) * AH(2.0 * self.u - self.u_old)\n        self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x)\n        BH = self.H.vjp(1, self.x, self.z, conjugate=True)[1]\n        
proxarg = self.z - (1.0 / self.nu) * BH(self.H(self.x, self.z) + self.u)\n        self.z_old = self.z\n        self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z)\n        self.u_old = self.u\n        self.u = self.u + self.H(self.x, self.z)\n\n    @staticmethod\n    def estimate_parameters(\n        H: Function,\n        x: Optional[Union[Array, BlockArray]] = None,\n        z: Optional[Union[Array, BlockArray]] = None,\n        factor: Optional[float] = 1.01,\n        maxiter: int = 100,\n        key: Optional[PRNGKey] = None,\n    ) -> Tuple[float, float]:\n        r\"\"\"Estimate `mu` and `nu` parameters of :class:`NonLinearPADMM`.\n\n        Find values of the `mu` and `nu` parameters of :class:`NonLinearPADMM`\n        that respect the constraints\n\n        .. math::\n           \\mu > \\norm{ J_x H(\\mb{x}, \\mb{z}) }_2^2 \\quad \\text{and} \\quad\n           \\nu > \\norm{ J_z H(\\mb{x}, \\mb{z}) }_2^2 \\;.\n\n        Args:\n            H: Constraint function :math:`H`.\n            x: Value of :math:`\\mb{x}` at which to evaluate the Jacobian.\n               If ``None``, defaults to an array of zeros.\n            z: Value of :math:`\\mb{z}` at which to evaluate the Jacobian.\n               If ``None``, defaults to an array of zeros.\n            factor: Safety factor with which to multiply estimated\n               operator norms to ensure strict inequality compliance. If\n               ``None``, return the estimated squared operator norms.\n            maxiter: Maximum number of power iterations to use in operator\n               norm estimation (see :func:`.operator_norm`). Default: 100.\n            key: Jax PRNG key to use in operator norm estimation (see\n               :func:`.operator_norm`). 
Defaults to ``None``, in which\n               case a new key is created.\n\n        Returns:\n            A tuple (`mu`, `nu`) representing the estimated parameter\n            values or corresponding squared operator norm values,\n            depending on the value of the `factor` parameter.\n        \"\"\"\n        if x is None:\n            x = snp.zeros(H.input_shapes[0], dtype=H.input_dtypes[0])\n        if z is None:\n            z = snp.zeros(H.input_shapes[1], dtype=H.input_dtypes[1])\n        Jx = H.jacobian(0, x, z)\n        Jz = H.jacobian(1, x, z)\n        mu = operator_norm(Jx, maxiter=maxiter, key=key) ** 2\n        nu = operator_norm(Jz, maxiter=maxiter, key=key) ** 2\n        if factor is None:\n            return (mu, nu)\n        else:\n            return (factor * mu, factor * nu)\n"
  },
  {
    "path": "scico/optimize/_pgm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Proximal Gradient Method classes.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom functools import partial\nfrom typing import List, Optional, Union\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico.functional import Functional\nfrom scico.loss import Loss\nfrom scico.numpy import Array, BlockArray\n\nfrom ._common import Optimizer\nfrom ._pgmaux import (\n    AdaptiveBBStepSize,\n    BBStepSize,\n    PGMStepSize,\n    RobustLineSearchStepSize,\n)\n\n\nclass PGM(Optimizer):\n    r\"\"\"Proximal gradient method (PGM) algorithm.\n\n    Minimize a functional of the form :math:`f(\\mb{x}) + g(\\mb{x})`,\n    where :math:`f` and the :math:`g` are instances of\n    :class:`.Functional`. Functional :math:`f` should be differentiable\n    and have a Lipschitz continuous derivative, and functional :math:`g`\n    should have a proximal operator defined.\n\n    The step size :math:`\\alpha` of the algorithm is defined in terms of\n    its reciprocal :math:`L`, i.e. :math:`\\alpha = 1 / L`. The initial\n    value for this parameter, `L0`, is required to satisfy\n\n    .. math::\n       L_0 \\geq K(\\nabla f) \\;,\n\n    where :math:`K(\\nabla f)` denotes the Lipschitz constant of the\n    gradient of :math:`f`. When `f` is an instance of\n    :class:`.SquaredL2Loss` with a :class:`.LinearOperator` `A`,\n\n    .. 
math::\n       K(\\nabla f) = \\lambda_{ \\mathrm{max} }( A^H A ) = \\| A \\|_2^2 \\;,\n\n    where :math:`\\lambda_{\\mathrm{max}}(B)` denotes the largest\n    eigenvalue of :math:`B`.\n\n    The evolution of the step size is controlled by auxiliary class\n    :class:`.PGMStepSize` and derived classes. The default\n    :class:`.PGMStepSize` simply sets :math:`L = L_0`, while the derived\n    classes implement a variety of adaptive strategies.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Union[Loss, Functional],\n        g: Functional,\n        L0: float,\n        x0: Union[Array, BlockArray],\n        step_size: Optional[PGMStepSize] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n\n        Args:\n            f: Instance of :class:`.Loss` or :class:`.Functional` with\n               defined `grad` method.\n            g: Instance of :class:`.Functional` with defined prox method.\n            L0: Initial estimate of Lipschitz constant of gradient of `f`.\n            x0: Starting point for :math:`\\mb{x}`.\n            step_size: Instance of an auxiliary class of type\n                :class:`.PGMStepSize` determining the evolution of the\n                algorithm step size.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n\n        #: Functional or Loss to minimize; must have grad method defined.\n        self.f: Union[Loss, Functional] = f\n\n        if g.has_prox is not True:\n            raise ValueError(f\"Functional 'g' ({type(g)}) must have a prox method.\")\n\n        #: Functional to minimize; must have prox defined\n        self.g: Functional = g\n\n        if step_size is None:\n            step_size = PGMStepSize()\n        self.step_size: PGMStepSize = step_size\n        self.step_size.internal_init(self)\n        self.L: float = L0  # reciprocal of step size (estimate of Lipschitz constant of ∇f)\n        self.fixed_point_residual = 
snp.inf\n\n        self.x: Union[Array, BlockArray] = x0  # current estimate of solution\n\n        super().__init__(**kwargs)\n\n    def x_step(self, v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:\n        \"\"\"Compute update for variable `x`.\"\"\"\n        return PGM._x_step(self.f, self.g, v, L)\n\n    @staticmethod\n    @partial(jax.jit, static_argnums=(0, 1))\n    def _x_step(\n        f: Functional, g: Functional, v: Union[Array, BlockArray], L: float\n    ) -> Union[Array, BlockArray]:\n        \"\"\"Jit-able static method for computing update for variable `x`.\"\"\"\n        return g.prox(v - 1.0 / L * f.grad(v), 1.0 / L)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether a ``NaN`` or ``Inf`` was encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        return snp.all(snp.isfinite(self.x))\n\n    def _objective_evaluatable(self):\n        \"\"\"Determine whether the objective function can be evaluated.\"\"\"\n        return self.f.has_eval and self.g.has_eval\n\n    def _itstat_extra_fields(self):\n        \"\"\"Define PGM-specific iteration statistics fields.\"\"\"\n        itstat_fields = {\"L\": \"%9.3e\", \"Residual\": \"%9.3e\"}\n        itstat_attrib = [\"L\", \"norm_residual()\"]\n        return itstat_fields, itstat_attrib\n\n    def _state_variable_names(self) -> List[str]:\n        return [\"x\", \"L\"]\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        return self.x\n\n    def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float:\n        r\"\"\"Evaluate the objective function :math:`f(\\mb{x}) + g(\\mb{x})`.\"\"\"\n        if x is None:\n            x = self.x\n        return self.f(x) + self.g(x)\n\n    def f_quad_approx(\n        self, x: Union[Array, BlockArray], y: Union[Array, BlockArray], L: float\n    ) -> float:\n        r\"\"\"Evaluate the quadratic approximation to function :math:`f`.\n\n        Evaluate the quadratic approximation to function :math:`f`,\n        corresponding to :math:`\\hat{f}_{L}(\\mb{x}, \\mb{y}) = f(\\mb{y}) +\n        \\nabla f(\\mb{y})^H (\\mb{x} - \\mb{y}) + \\frac{L}{2} \\left\\|\\mb{x}\n        - \\mb{y}\\right\\|_2^2`.\n        \"\"\"\n        diff_xy = x - y\n        return (\n            self.f(y)\n            + snp.sum(snp.real(snp.conj(self.f.grad(y)) * diff_xy))\n            + 0.5 * L * snp.linalg.norm(diff_xy) ** 2\n        )\n\n    def norm_residual(self) -> float:\n        r\"\"\"Return the fixed point residual.\n\n        Return the fixed point residual (see Sec. 4.3 of\n        :cite:`liu-2018-first`).\n        \"\"\"\n        return self.fixed_point_residual\n\n    def step(self):\n        \"\"\"Take a single PGM step.\"\"\"\n        # Update reciprocal of step size using current solution.\n        self.L = self.step_size.update(self.x)\n        x = self.x_step(self.x, self.L)\n        self.fixed_point_residual = snp.linalg.norm(self.x - x)\n        self.x = x\n\n\nclass AcceleratedPGM(PGM):\n    r\"\"\"Accelerated proximal gradient method (APGM) algorithm.\n\n    Minimize a function of the form :math:`f(\\mb{x}) + g(\\mb{x})`, where\n    :math:`f` and :math:`g` are instances of :class:`.Functional`.\n    The accelerated form of PGM is also known as FISTA\n    :cite:`beck-2009-fast`.\n\n    See :class:`.PGM` for more detailed documentation.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Union[Loss, Functional],\n        g: Functional,\n        L0: float,\n        x0: Union[Array, BlockArray],\n        step_size: Optional[PGMStepSize] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Args:\n            f: Instance of :class:`.Loss` or :class:`.Functional` with\n               defined `grad` method.\n            g: Instance of :class:`.Functional` with defined prox method.\n            L0: Initial estimate of Lipschitz constant of gradient of `f`.\n            x0: Starting point for :math:`\\mb{x}`.\n            step_size: Instance of an auxiliary class of type\n                :class:`.PGMStepSize` determining the evolution of the\n                algorithm step size.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        super().__init__(f=f, g=g, L0=L0, x0=x0, step_size=step_size, **kwargs)\n\n        self.v = x0\n        self.t = 1.0\n\n    def step(self):\n        \"\"\"Take a single AcceleratedPGM step.\"\"\"\n        x_old = self.x\n        # Update reciprocal of step size using current extrapolation.\n        if isinstance(self.step_size, (AdaptiveBBStepSize, BBStepSize)):\n            self.L = self.step_size.update(self.x)\n        else:\n            self.L = self.step_size.update(self.v)\n        if isinstance(self.step_size, RobustLineSearchStepSize):\n            # Robust line search step size uses a different extrapolation sequence.\n            # Update in solution is computed while updating the reciprocal of step size.\n            self.x = self.step_size.Z\n            self.fixed_point_residual = snp.linalg.norm(self.x - x_old)\n        else:\n            self.x = self.x_step(self.v, self.L)\n\n            self.fixed_point_residual = snp.linalg.norm(self.x - self.v)\n            t_old = self.t\n            self.t = 0.5 * (1 + snp.sqrt(1 + 4 * t_old**2))\n            self.v = self.x + ((t_old - 1) / self.t) * (self.x - x_old)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether a ``NaN`` or ``Inf`` was encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.v))\n\n    def _state_variable_names(self) -> List[str]:\n        \"\"\"Get optimizer state variable names.\"\"\"\n        return [\"x\", \"v\", \"t\", \"L\"]\n"
  },
  {
    "path": "scico/optimize/_pgmaux.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Proximal Gradient Method auxiliary classes.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import Optional, Union\n\nimport jax\n\nimport scico.numpy as snp\nimport scico.optimize.pgm as sop\nfrom scico.numpy import Array, BlockArray\n\n\nclass PGMStepSize:\n    r\"\"\"Base class for computing the PGM step size.\n\n    Base class for computing the reciprocal of the step size for PGM\n    solvers.\n\n    The PGM solver implemented by :class:`.PGM` addresses a general\n    proximal gradient form that requires the specification of a step size\n    for the gradient descent step. This class is a base class for methods\n    that estimate the reciprocal of the step size (:math:`L` in PGM\n    equations).\n\n    Attributes:\n        pgm (:class:`.PGM`): PGM solver object to which the solver is\n           attached.\n    \"\"\"\n\n    def internal_init(self, pgm: sop.PGM):\n        \"\"\"Second stage initializer to be called by :meth:`.PGM.__init__`.\n\n        Args:\n            pgm: Reference to :class:`.PGM` object to which the\n              :class:`.StepSize` object is to be attached.\n        \"\"\"\n        self.pgm = pgm\n\n    def update(self, v: Union[Array, BlockArray]) -> float:\n        \"\"\"Hook for updating the step size in derived classes.\n\n        Hook for updating the reciprocal of the step size in derived\n        classes. 
The base class does not compute any update.\n\n        Args:\n            v: Current solution or current extrapolation (if accelerated\n               PGM).\n\n        Returns:\n            Current reciprocal of the step size.\n        \"\"\"\n        return self.pgm.L\n\n\nclass BBStepSize(PGMStepSize):\n    r\"\"\"Scheme for step size estimation based on Barzilai-Borwein method.\n\n    The Barzilai-Borwein method :cite:`barzilai-1988-stepsize` estimates\n    the step size :math:`\\alpha` as\n\n    .. math::\n       \\mb{\\Delta x} = \\mb{x}_k - \\mb{x}_{k-1} \\; \\\\\n       \\mb{\\Delta g} = \\nabla f(\\mb{x}_k) - \\nabla f (\\mb{x}_{k-1}) \\; \\\\\n       \\alpha = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta g}}{\\mb{\\Delta g}^T\n       \\mb{\\Delta g}} \\;\\;.\n\n    Since the PGM solver uses the reciprocal of the step size, the value\n    :math:`L = 1 / \\alpha` is returned.\n\n    When applied to complex-valued problems, only the real part of the\n    inner product is used. When the inner product is negative, the\n    previous iterate is used instead.\n\n    Attributes:\n        pgm (:class:`.PGM`): PGM solver object to which the solver is\n           attached.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize a :class:`BBStepSize` object.\"\"\"\n        self.xprev = None\n        self.gradprev = None\n\n    def update(self, v: Union[Array, BlockArray]) -> float:\n        \"\"\"Update the reciprocal of the step size.\n\n        Args:\n            v: Current solution or current extrapolation (if accelerated\n               PGM).\n\n        Returns:\n            Updated reciprocal of the step size.\n        \"\"\"\n\n        if self.xprev is None:\n            # Solution and gradient of previous iterate are required.\n            # For first iteration these variables are stored and current estimate is returned.\n            self.xprev = v\n            self.gradprev = self.pgm.f.grad(self.xprev)\n            L = self.pgm.L\n        else:\n            
Δx = v - self.xprev\n            gradv = self.pgm.f.grad(v)\n            Δg = gradv - self.gradprev\n            # Taking real part of inner products in case of complex-value problem.\n            den = snp.real(snp.sum(Δx.conj() * Δg))\n            num = snp.real(snp.sum(Δg.conj() * Δg))\n            L = num / den\n            # Revert to previous iterate if update results in nan or negative value.\n            if snp.isnan(L) or L <= 0.0:\n                L = self.pgm.L\n            # Store current state and gradient for next update.\n            self.xprev = v\n            self.gradprev = gradv\n        return L\n\n\nclass AdaptiveBBStepSize(PGMStepSize):\n    r\"\"\"Adaptive Barzilai-Borwein method to determine step size.\n\n    Adaptive Barzilai-Borwein method to determine step size in PGM, as\n    introduced in :cite:`zhou-2006-adaptive`.\n\n    The adaptive step size rule computes\n\n    .. math::\n\n       \\mb{\\Delta x} = \\mb{x}_k - \\mb{x}_{k-1} \\; \\\\\n       \\mb{\\Delta g} = \\nabla f(\\mb{x}_k) - \\nabla f (\\mb{x}_{k-1}) \\; \\\\\n       \\alpha^{\\mathrm{BB1}} = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta x}}\n       {\\mb{\\Delta x}^T \\mb{\\Delta g}} \\; \\\\\n       \\alpha^{\\mathrm{BB2}} = \\frac{\\mb{\\Delta x}^T \\mb{\\Delta g}}\n       {\\mb{\\Delta g}^T \\mb{\\Delta g}} \\;\\;.\n\n    The determination of the new steps size is made via the rule\n\n    .. math::\n\n        \\alpha = \\left\\{ \\begin{matrix} \\alpha^{\\mathrm{BB2}}  &\n        \\mathrm{~if~} \\alpha^{\\mathrm{BB2}} / \\alpha^{\\mathrm{BB1}}\n        < \\kappa \\; \\\\\n        \\alpha^{\\mathrm{BB1}} & \\mathrm{~otherwise} \\end{matrix}\n        \\right . \\;,\n\n    with :math:`\\kappa \\in (0, 1)`.\n\n    Since the PGM solver uses the reciprocal of the step size, the value\n    :math:`L = 1 / \\alpha` is returned.\n\n    When applied to complex-valued problems, only the real part of the\n    inner product is used. 
When the inner product is negative, the\n    previous iterate is used instead.\n\n    Attributes:\n        pgm (:class:`.PGM`): PGM solver object to which the solver is\n           attached.\n    \"\"\"\n\n    def __init__(self, kappa: float = 0.5):\n        r\"\"\"Initialize a :class:`AdaptiveBBStepSize` object.\n\n        Args:\n            kappa : Threshold for step size selection :math:`\\kappa`.\n        \"\"\"\n        self.kappa: float = kappa\n        self.xprev: Union[Array, BlockArray] = None\n        self.gradprev: Union[Array, BlockArray] = None\n        self.Lbb1prev: Optional[float] = None\n        self.Lbb2prev: Optional[float] = None\n\n    def update(self, v: Union[Array, BlockArray]) -> float:\n        \"\"\"Update the reciprocal of the step size.\n\n        Args:\n            v: Current solution or current extrapolation (if accelerated\n               PGM).\n\n        Returns:\n            Updated reciprocal of the step size.\n        \"\"\"\n\n        if self.xprev is None:\n            # Solution and gradient of previous iterate are required.\n            # For first iteration these variables are stored and current estimate is returned.\n            self.xprev = v\n            self.gradprev = self.pgm.f.grad(self.xprev)\n            L = self.pgm.L\n        else:\n            Δx = v - self.xprev\n            gradv = self.pgm.f.grad(v)\n            Δg = gradv - self.gradprev\n            # Taking real part of inner products in case of complex-value problem.\n            innerxx = snp.real(snp.sum(Δx.conj() * Δx))\n            innerxg = snp.real(snp.sum(Δx.conj() * Δg))\n            innergg = snp.real(snp.sum(Δg.conj() * Δg))\n            Lbb1 = innerxg / innerxx\n            # Revert to previous iterate if computation results in nan or negative value.\n            if snp.isnan(Lbb1) or Lbb1 <= 0.0:\n                Lbb1 = self.Lbb1prev\n            Lbb2 = innergg / innerxg\n            # Revert to previous iterate if computation results in nan or 
negative value.\n            if snp.isnan(Lbb2) or Lbb2 <= 0.0:\n                Lbb2 = self.Lbb2prev\n            # If possible, apply adaptive selection rule, if not, revert to previous iterate\n            if Lbb1 is not None and Lbb2 is not None:\n                if (Lbb1 / Lbb2) < self.kappa:\n                    L = Lbb2\n                else:\n                    L = Lbb1\n            else:\n                L = self.pgm.L\n            # Store current state and gradient for next update.\n            self.xprev = v\n            self.gradprev = gradv\n            # Store current estimates of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2).\n            self.Lbb1prev = Lbb1\n            self.Lbb2prev = Lbb2\n\n        return L\n\n\nclass LineSearchStepSize(PGMStepSize):\n    r\"\"\"Line search for estimating the step size for PGM solvers.\n\n    Line search for estimating the reciprocal of step size for PGM\n    solvers. The line search strategy described in :cite:`beck-2009-fast`\n    estimates :math:`L` such that :math:`f(\\mb{x}) <= \\hat{f}_{L}(\\mb{x})`\n    is satisfied with :math:`\\hat{f}_{L}` a quadratic approximation to\n    :math:`f` defined as\n\n    .. 
math::\n       \\hat{f}_{L}(\\mb{x}, \\mb{y}) = f(\\mb{y}) + \\nabla f(\\mb{y})^H\n       (\\mb{x} - \\mb{y}) + \\frac{L}{2} \\left\\| \\mb{x} - \\mb{y}\n       \\right\\|_2^2 \\;,\n\n    with :math:`\\mb{x}` the potential new update and :math:`\\mb{y}` the\n    current solution or current extrapolation (if accelerated PGM).\n\n    Attributes:\n        pgm (:class:`.PGM`): PGM solver object to which the solver is\n           attached.\n    \"\"\"\n\n    def __init__(self, gamma_u: float = 1.2, maxiter: int = 50):\n        r\"\"\"Initialize a :class:`LineSearchStepSize` object.\n\n        Args:\n            gamma_u: Rate of increment in :math:`L`.\n            maxiter: Maximum iterations in line search.\n        \"\"\"\n        self.gamma_u: float = gamma_u\n        self.maxiter: int = maxiter\n\n        def g_prox(v, gradv, L):\n            return self.pgm.g.prox(v - 1.0 / L * gradv, 1.0 / L)\n\n        self.g_prox = jax.jit(g_prox)\n\n    def update(self, v: Union[Array, BlockArray]) -> float:\n        \"\"\"Update the reciprocal of the step size.\n\n        Args:\n            v: Current solution or current extrapolation (if accelerated\n               PGM).\n\n        Returns:\n            Updated reciprocal of the step size.\n        \"\"\"\n\n        gradv = self.pgm.f.grad(v)\n        L = self.pgm.L\n        it = 0\n        while it < self.maxiter:\n            z = self.g_prox(v, gradv, L)\n            fz = self.pgm.f(z)\n            fquad = self.pgm.f_quad_approx(z, v, L)\n            if fz <= fquad:\n                break\n            else:\n                L *= self.gamma_u\n            it += 1\n        return L\n\n\nclass RobustLineSearchStepSize(LineSearchStepSize):\n    r\"\"\"Robust line search for estimating the accelerated PGM step size.\n\n    A robust line search for estimating the reciprocal of step size for\n    accelerated PGM solvers.\n\n    The robust line search strategy described in :cite:`florea-2017-robust`\n    estimates :math:`L` such 
that :math:`f(\\mb{x}) <= \\hat{f}_{L}(\\mb{x})`\n    is satisfied with :math:`\\hat{f}_{L}` a quadratic approximation to\n    :math:`f` defined as\n\n    .. math::\n       \\hat{f}_{L}(\\mb{x}, \\mb{y}) = f(\\mb{y}) + \\nabla f(\\mb{y})^H\n       (\\mb{x} - \\mb{y}) + \\frac{L}{2} \\left\\| \\mb{x} - \\mb{y}\n       \\right\\|_2^2 \\;,\n\n    with :math:`\\mb{x}` the potential new update and :math:`\\mb{y}` the\n    auxiliary extrapolation state.\n\n    Attributes:\n        pgm (:class:`.PGM`): PGM solver object to which the solver is\n           attached.\n    \"\"\"\n\n    def __init__(self, gamma_d: float = 0.9, gamma_u: float = 2.0, maxiter: int = 50):\n        r\"\"\"Initialize a :class:`RobustLineSearchStepSize` object.\n\n        Args:\n            gamma_d: Rate of decrement in :math:`L`.\n            gamma_u: Rate of increment in :math:`L`.\n            maxiter: Maximum iterations in line search.\n        \"\"\"\n        super(RobustLineSearchStepSize, self).__init__(gamma_u, maxiter)\n        self.gamma_d: float = gamma_d\n        self.Tk: float = 0.0\n        # State needed for computing auxiliary extrapolation sequence in robust line search.\n        self.Zrb: Union[Array, BlockArray] = None\n        #: Current estimate of solution in robust line search.\n        self.Z: Union[Array, BlockArray] = None\n\n    def update(self, v: Union[Array, BlockArray]) -> float:\n        \"\"\"Update the reciprocal of the step size.\n\n        Args:\n            v: Current solution or current extrapolation (if accelerated\n               PGM).\n\n        Returns:\n            Updated reciprocal of the step size.\n        \"\"\"\n        if self.Zrb is None:\n            self.Zrb = self.pgm.x\n\n        L = self.pgm.L * self.gamma_d\n\n        it = 0\n        while it < self.maxiter:\n            t = (1.0 + snp.sqrt(1.0 + 4.0 * L * self.Tk)) / (2.0 * L)\n            T = self.Tk + t\n            # Auxiliary extrapolation sequence.\n            y = (self.Tk * self.pgm.x 
+ t * self.Zrb) / T\n            # New update based on auxiliary extrapolation and current L estimate.\n            z = self.pgm.x_step(y, L)\n            fz = self.pgm.f(z)\n            fquad = self.pgm.f_quad_approx(z, y, L)\n            if fz <= fquad:\n                break\n            else:\n                L *= self.gamma_u\n            it += 1\n        self.Tk = T\n        self.Zrb += t * L * (z - y)\n        self.Z = z\n\n        return L\n"
  },
  {
    "path": "scico/optimize/_primaldual.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Primal-dual solvers.\"\"\"\n\n# Needed to annotate a class method that returns the encapsulating class;\n# see https://www.python.org/dev/peps/pep-0563/\nfrom __future__ import annotations\n\nfrom typing import List, Optional, Union\n\nimport scico.numpy as snp\nfrom scico.functional import Functional\nfrom scico.linop import LinearOperator, jacobian, operator_norm\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.linalg import norm\nfrom scico.operator import Operator\nfrom scico.typing import PRNGKey\n\nfrom ._common import Optimizer\n\n\nclass PDHG(Optimizer):\n    r\"\"\"Primal–dual hybrid gradient (PDHG) algorithm.\n\n    |\n\n    Primal–dual hybrid gradient (PDHG) is a family of algorithms\n    :cite:`esser-2010-general` that includes the Chambolle-Pock\n    primal-dual algorithm :cite:`chambolle-2010-firstorder`. The form\n    implemented here is a minor variant :cite:`pock-2011-diagonal` of the\n    original Chambolle-Pock algorithm.\n\n    Solve an optimization problem of the form\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; f(\\mb{x}) + g(C \\mb{x}) \\;,\n\n    where :math:`f` and :math:`g` are instances of :class:`.Functional`,\n    (in most cases :math:`f` will, more specifically be an an instance\n    of :class:`.Loss`), and :math:`C` is an instance of\n    :class:`.Operator` or :class:`.LinearOperator`.\n\n    When `C` is a :class:`.LinearOperator`, the algorithm iterations are\n\n    .. 
math::\n       \\begin{aligned}\n       \\mb{x}^{(k+1)} &= \\mathrm{prox}_{\\tau f} \\left( \\mb{x}^{(k)} -\n       \\tau C^T \\mb{z}^{(k)} \\right) \\\\\n       \\mb{z}^{(k+1)} &= \\mathrm{prox}_{\\sigma g^*} \\left( \\mb{z}^{(k)}\n       + \\sigma C((1 + \\alpha) \\mb{x}^{(k+1)} - \\alpha \\mb{x}^{(k)}\n       \\right) \\;,\n       \\end{aligned}\n\n    where :math:`g^*` denotes the convex conjugate of :math:`g`.\n    Parameters :math:`\\tau > 0` and :math:`\\sigma > 0` are also required\n    to satisfy\n\n    .. math::\n       \\tau \\sigma < \\| C \\|_2^{-2} \\;,\n\n    and it is required that :math:`\\alpha \\in [0, 1]`.\n\n    When `C` is a non-linear :class:`.Operator`, a non-linear PDHG variant\n    :cite:`valkonen-2014-primal` is used, with the same iterations except\n    for :math:`\\mb{x}` update\n\n    .. math::\n       \\mb{x}^{(k+1)} = \\mathrm{prox}_{\\tau f} \\left( \\mb{x}^{(k)} -\n       \\tau [J_x C(\\mb{x}^{(k)})]^T \\mb{z}^{(k)} \\right) \\;.\n\n\n    Attributes:\n        f (:class:`.Functional`): Functional :math:`f` (usually a\n          :class:`.Loss`).\n        g (:class:`.Functional`): Functional :math:`g`.\n        C (:class:`.Operator`): :math:`C` operator.\n        tau (scalar): First algorithm parameter.\n        sigma (scalar): Second algorithm parameter.\n        alpha (scalar): Relaxation parameter.\n        x (array-like): Primal variable :math:`\\mb{x}` at current\n          iteration.\n        x_old (array-like): Primal variable :math:`\\mb{x}` at previous\n          iteration.\n        z (array-like): Dual variable :math:`\\mb{z}` at current\n          iteration.\n        z_old (array-like): Dual variable :math:`\\mb{z}` at previous\n          iteration.\n    \"\"\"\n\n    def __init__(\n        self,\n        f: Functional,\n        g: Functional,\n        C: Operator,\n        tau: float,\n        sigma: float,\n        alpha: float = 1.0,\n        x0: Optional[Union[Array, BlockArray]] = None,\n        z0: 
Optional[Union[Array, BlockArray]] = None,\n        **kwargs,\n    ):\n        r\"\"\"Initialize a :class:`PDHG` object.\n\n        Args:\n            f: Functional :math:`f` (usually a loss function).\n            g: Functional :math:`g`.\n            C: Operator :math:`C`.\n            tau: First algorithm parameter.\n            sigma: Second algorithm parameter.\n            alpha: Relaxation parameter.\n            x0: Starting point for :math:`\\mb{x}`. If ``None``, defaults\n               to an array of zeros.\n            z0: Starting point for :math:`\\mb{z}`. If ``None``, defaults\n               to an array of zeros.\n            **kwargs: Additional optional parameters handled by\n                initializer of base class :class:`.Optimizer`.\n        \"\"\"\n        self.f: Functional = f\n        self.g: Functional = g\n        self.C: Operator = C\n        self.tau: float = tau\n        self.sigma: float = sigma\n        self.alpha: float = alpha\n\n        if x0 is None:\n            input_shape = C.input_shape\n            dtype = C.input_dtype\n            x0 = snp.zeros(input_shape, dtype=dtype)\n        self.x = x0\n        self.x_old = self.x\n        if z0 is None:\n            input_shape = C.output_shape\n            dtype = C.output_dtype\n            z0 = snp.zeros(input_shape, dtype=dtype)\n        self.z = z0\n        self.z_old = self.z\n\n        super().__init__(**kwargs)\n\n    def _working_vars_finite(self) -> bool:\n        \"\"\"Determine whether ``NaN`` or ``Inf`` encountered in solve.\n\n        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in\n        a solver working variable.\n        \"\"\"\n        return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.z))\n\n    def _objective_evaluatable(self):\n        \"\"\"Determine whether the objective function can be evaluated.\"\"\"\n        return self.f.has_eval and self.g.has_eval\n\n    def _itstat_extra_fields(self):\n        \"\"\"Define PDHG-
specific iteration statistics fields.\"\"\"\n        itstat_fields = {\"Prml Rsdl\": \"%9.3e\", \"Dual Rsdl\": \"%9.3e\"}\n        itstat_attrib = [\"norm_primal_residual()\", \"norm_dual_residual()\"]\n        return itstat_fields, itstat_attrib\n\n    def _state_variable_names(self) -> List[str]:\n        return [\"x\", \"x_old\", \"z\", \"z_old\"]\n\n    def minimizer(self) -> Union[Array, BlockArray]:\n        return self.x\n\n    def objective(\n        self,\n        x: Optional[Union[Array, BlockArray]] = None,\n    ) -> float:\n        r\"\"\"Evaluate the objective function.\n\n        Evaluate the objective function\n\n        .. math::\n            f(\\mb{x}) + g(C \\mb{x}) \\;.\n\n        Args:\n            x: Point at which to evaluate objective function. If ``None``,\n                the objective is evaluated at the current iterate\n                :code:`self.x`\n\n        Returns:\n            scalar: Value of the objective function.\n        \"\"\"\n        if x is None:\n            x = self.x\n        return self.f(x) + self.g(self.C(x))\n\n    def norm_primal_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the primal residual.\n\n        Compute the :math:`\\ell_2` norm of the primal residual\n\n        .. math::\n            \\tau^{-1} \\norm{\\mb{x}^{(k)} - \\mb{x}^{(k-1)}}_2 \\;.\n\n        Returns:\n            Current norm of primal residual.\n        \"\"\"\n\n        return norm(self.x - self.x_old) / self.tau  # type: ignore\n\n    def norm_dual_residual(self) -> float:\n        r\"\"\"Compute the :math:`\\ell_2` norm of the dual residual.\n\n        Compute the :math:`\\ell_2` norm of the dual residual\n\n        .. 
math::\n            \\sigma^{-1} \\norm{\\mb{z}^{(k)} - \\mb{z}^{(k-1)}}_2 \\;.\n\n        Returns:\n            Current norm of dual residual.\n\n        \"\"\"\n        return norm(self.z - self.z_old) / self.sigma\n\n    def step(self):\n        \"\"\"Perform a single iteration.\"\"\"\n        self.x_old = self.x\n        self.z_old = self.z\n        if isinstance(self.C, LinearOperator):\n            proxarg = self.x - self.tau * self.C.conj().T(self.z)\n        else:\n            proxarg = self.x - self.tau * self.C.vjp(self.x, conjugate=True)[1](self.z)\n        self.x = self.f.prox(proxarg, self.tau, v0=self.x)\n        proxarg = self.z + self.sigma * self.C(\n            (1.0 + self.alpha) * self.x - self.alpha * self.x_old\n        )\n        self.z = self.g.conj_prox(proxarg, self.sigma, v0=self.z)\n\n    @staticmethod\n    def estimate_parameters(\n        C: Operator,\n        x: Optional[Union[Array, BlockArray]] = None,\n        ratio: float = 1.0,\n        factor: Optional[float] = 1.01,\n        maxiter: int = 100,\n        key: Optional[PRNGKey] = None,\n    ):\n        r\"\"\"Estimate `tau` and `sigma` parameters of :class:`PDHG`.\n\n        Find values of the `tau` and `sigma` parameters of :class:`PDHG`\n        that respect the constraint\n\n        .. 
math::\n           \\tau \\sigma < \\| C \\|_2^{-2} \\quad \\text{or} \\quad\n           \\tau \\sigma < \\| J_x C(\\mb{x}) \\|_2^{-2} \\;,\n\n        depending on whether :math:`C` is a :class:`.LinearOperator` or\n        not.\n\n        Args:\n            C: Operator :math:`C`.\n            x: Value of :math:`\\mb{x}` at which to evaluate the Jacobian\n               of :math:`C` (when it is not a :class:`.LinearOperator`).\n               If ``None``, defaults to an array of zeros.\n            ratio: Desired ratio between return :math:`\\tau` and\n               :math:`\\sigma` values (:math:`\\sigma = \\mathrm{ratio}\n               \\tau`).\n            factor: Safety factor with which to multiply :math:`\\| C\n               \\|_2^{-2}` to ensure strict inequality compliance. If\n               ``None``, the value is set to 1.0.\n            maxiter: Maximum number of power iterations to use in operator\n               norm estimation (see :func:`.operator_norm`). Default: 100.\n            key: Jax PRNG key to use in operator norm estimation (see\n               :func:`.operator_norm`). Defaults to ``None``, in which\n               case a new key is created.\n\n        Returns:\n            A tuple (`tau`, `sigma`) representing the estimated parameter\n            values.\n        \"\"\"\n        if x is None:\n            x = snp.zeros(C.input_shape, dtype=C.input_dtype)\n        if factor is None:\n            factor = 1.0\n        if isinstance(C, LinearOperator):\n            J = C\n        else:\n            J = jacobian(C, x)\n        Cnrm = operator_norm(J, maxiter=maxiter, key=key)\n        tau = snp.sqrt(factor / ratio) / Cnrm\n        sigma = ratio * tau\n        return (tau, sigma)\n"
  },
  {
    "path": "scico/optimize/admm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2023 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"ADMM solver and auxiliary classes.\"\"\"\n\nimport sys\n\n# isort: off\nfrom ._admmaux import (\n    SubproblemSolver,\n    GenericSubproblemSolver,\n    LinearSubproblemSolver,\n    MatrixSubproblemSolver,\n    CircularConvolveSolver,\n    FBlockCircularConvolveSolver,\n    G0BlockCircularConvolveSolver,\n)\nfrom ._admm import ADMM\n\n__all__ = [\n    \"SubproblemSolver\",\n    \"GenericSubproblemSolver\",\n    \"LinearSubproblemSolver\",\n    \"MatrixSubproblemSolver\",\n    \"CircularConvolveSolver\",\n    \"FBlockCircularConvolveSolver\",\n    \"G0BlockCircularConvolveSolver\",\n    \"ADMM\",\n]\n\n# Set __module__ so that imported items in __all__ appear to originate in this module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/optimize/pgm.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"PGM solvers and auxiliary classes.\"\"\"\n\nimport sys\n\n# isort: off\nfrom ._pgmaux import (\n    PGMStepSize,\n    BBStepSize,\n    AdaptiveBBStepSize,\n    LineSearchStepSize,\n    RobustLineSearchStepSize,\n)\nfrom ._pgm import PGM, AcceleratedPGM\n\n__all__ = [\n    \"PGMStepSize\",\n    \"BBStepSize\",\n    \"AdaptiveBBStepSize\",\n    \"LineSearchStepSize\",\n    \"RobustLineSearchStepSize\",\n    \"PGM\",\n    \"AcceleratedPGM\",\n]\n\n# Set __module__ so that imported items in __all__ appear to originate in this module\nfor name in __all__:\n    getattr(sys.modules[__name__], name).__module__ = __name__\n"
  },
  {
    "path": "scico/plot.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Plotting/visualization functions.\n\nOptional alternative high-level interface to selected :mod:`matplotlib`\nplotting functions.\n\"\"\"\n\n# This module is copied from https://github.com/bwohlberg/sporco\n\nimport os\nimport sys\n\nimport numpy as np\n\nimport matplotlib\nimport matplotlib.cm as cm\nimport matplotlib.pyplot as plt\nfrom matplotlib.pyplot import figure, gca, gcf, savefig, subplot, subplots  # noqa\nfrom mpl_toolkits.axes_grid1 import make_axes_locatable\nfrom mpl_toolkits.mplot3d import Axes3D  # noqa\n\ntry:\n    import mpldatacursor as mpldc\nexcept ImportError:\n    have_mpldc = False\nelse:\n    have_mpldc = True\n\n\n__all__ = [\n    \"plot\",\n    \"surf\",\n    \"contour\",\n    \"imview\",\n    \"close\",\n    \"set_ipython_plot_backend\",\n    \"set_notebook_plot_backend\",\n    \"config_notebook_plotting\",\n]\n\n\ndef _attach_keypress(fig, scaling=1.1):\n    \"\"\"Attach a key press event handler.\n\n    Attach a key press event handler that configures keys for closing a\n    figure and changing the figure size. 
Keys 'e' and 'c' respectively\n    expand and contract the figure, and key 'q' closes it.\n\n    **Note:** Resizing may not function correctly with all matplotlib\n    backends\n    (a `bug <https://github.com/matplotlib/matplotlib/issues/10083>`__\n    has been reported).\n\n    Args:\n        fig (:class:`matplotlib.figure.Figure` object): Figure to which\n            event handling is to be attached.\n        scaling (float, optional (default 1.1)): Scaling factor for\n            figure size changes.\n\n    Returns:\n        function: Key press event handler function.\n    \"\"\"\n\n    def press(event):\n        if event.key == \"q\":\n            plt.close(fig)\n        elif event.key == \"e\":\n            fig.set_size_inches(scaling * fig.get_size_inches(), forward=True)\n        elif event.key == \"c\":\n            fig.set_size_inches(fig.get_size_inches() / scaling, forward=True)\n\n    # Avoid multiple event handlers attached to the same figure\n    if not hasattr(fig, \"_scico_keypress_cid\"):\n        cid = fig.canvas.mpl_connect(\"key_press_event\", press)\n        fig._scico_keypress_cid = cid\n\n    return press\n\n\ndef _attach_zoom(ax, scaling=2.0):\n    \"\"\"Attach a scroll wheel event handler.\n\n    Attach an event handler that supports zooming within a plot using the\n    mouse scroll wheel.\n\n    Args:\n        ax (:class:`matplotlib.axes.Axes` object): Axes to which event\n            handling is to be attached.\n        scaling (float, optional (default 2.0)): Scaling factor for\n            zooming in and out.\n\n    Returns:\n        function: Mouse scroll wheel event handler function.\n    \"\"\"\n\n    # See https://stackoverflow.com/questions/11551049\n    def zoom(event):\n        # Get the current x and y limits\n        cur_xlim = ax.get_xlim()\n        cur_ylim = ax.get_ylim()\n        # Get event location\n        xdata = event.xdata\n        ydata = event.ydata\n        # Return if cursor is not over valid region of plot\n     
   if xdata is None or ydata is None:\n            return\n\n        if event.button == \"up\":\n            # Deal with zoom in\n            scale_factor = 1.0 / scaling\n        elif event.button == \"down\":\n            # Deal with zoom out\n            scale_factor = scaling\n\n        # Get distance from the cursor to the edge of the figure frame\n        x_left = xdata - cur_xlim[0]\n        x_right = cur_xlim[1] - xdata\n        y_top = ydata - cur_ylim[0]\n        y_bottom = cur_ylim[1] - ydata\n\n        # Calculate new x and y limits\n        new_xlim = (xdata - x_left * scale_factor, xdata + x_right * scale_factor)\n        new_ylim = (ydata - y_top * scale_factor, ydata + y_bottom * scale_factor)\n\n        # Ensure that x limit range is no larger than that of the reference\n        if np.diff(new_xlim) > np.diff(zoom.xlim_ref):\n            new_xlim *= np.diff(zoom.xlim_ref) / np.diff(new_xlim)\n        # Ensure that lower x limit is not less than that of the reference\n        if new_xlim[0] < zoom.xlim_ref[0]:\n            new_xlim += np.array(zoom.xlim_ref[0] - new_xlim[0])\n        # Ensure that upper x limit is not greater than that of the reference\n        if new_xlim[1] > zoom.xlim_ref[1]:\n            new_xlim -= np.array(new_xlim[1] - zoom.xlim_ref[1])\n\n        # Ensure that ylim tuple has the smallest value first\n        if zoom.ylim_ref[1] < zoom.ylim_ref[0]:\n            ylim_ref = zoom.ylim_ref[::-1]\n            new_ylim = new_ylim[::-1]\n        else:\n            ylim_ref = zoom.ylim_ref\n\n        # Ensure that y limit range is no larger than that of the reference\n        if np.diff(new_ylim) > np.diff(ylim_ref):\n            new_ylim *= np.diff(ylim_ref) / np.diff(new_ylim)\n        # Ensure that lower y limit is not less than that of the reference\n        if new_ylim[0] < ylim_ref[0]:\n            new_ylim += np.array(ylim_ref[0] - new_ylim[0])\n        # Ensure that upper y limit is not greater than that of the reference\n    
    if new_ylim[1] > ylim_ref[1]:\n            new_ylim -= np.array(new_ylim[1] - ylim_ref[1])\n\n        # Return the ylim tuple to its original order\n        if zoom.ylim_ref[1] < zoom.ylim_ref[0]:\n            new_ylim = new_ylim[::-1]\n\n        # Set new x and y limits\n        ax.set_xlim(new_xlim)\n        ax.set_ylim(new_ylim)\n\n        # Force redraw\n        ax.figure.canvas.draw()\n\n    # Record reference x and y limits prior to any zooming\n    zoom.xlim_ref = ax.get_xlim()\n    zoom.ylim_ref = ax.get_ylim()\n\n    # Get figure for specified axes and attach the event handler\n    fig = ax.get_figure()\n    fig.canvas.mpl_connect(\"scroll_event\", zoom)\n\n    return zoom\n\n\ndef plot(y, x=None, ptyp=\"plot\", xlbl=None, ylbl=None, title=None, lgnd=None, lglc=None, **kwargs):\n    \"\"\"Plot points or lines in 2D.\n\n    Plot points or lines in 2D. If a figure object is specified then the\n    plot is drawn in that figure, and `fig.show()` is not called. The\n    figure is closed on key entry 'q'.\n\n    Args:\n        y (array_like): 1d or 2d array of data to plot. 
If a 2d array,\n            each column is plotted as a separate curve.\n        x (array_like, optional (default ``None``)): Values for x-axis of\n            the plot.\n        ptyp (string, optional (default 'plot')): Plot type specification\n            (options are 'plot', 'semilogx', 'semilogy', and 'loglog').\n        xlbl (string, optional (default ``None``)): Label for x-axis.\n        ylbl (string, optional (default ``None``)): Label for y-axis.\n        title (string, optional (default ``None``)): Figure title.\n        lgnd (list of strings, optional (default ``None``)): List of\n            legend string.\n        lglc (string, optional (default ``None``)): Legend location string.\n        **kwargs: :class:`matplotlib.lines.Line2D` properties or figure\n            properties.\n\n            Keyword arguments specifying :class:`matplotlib.lines.Line2D`\n            properties, e.g. `lw=2.0` sets a line width of 2, or\n            properties of the figure and axes. If not specified, the\n            defaults for line width (`lw`) and marker size (`ms`) are\n            1.5 and 6.0 respectively. The valid figure and axes keyword\n            arguments are listed below:\n\n            .. |mplfg| replace:: :class:`matplotlib.figure.Figure` object\n            .. |mplax| replace:: :class:`matplotlib.axes.Axes` object\n\n            .. 
rst-class:: kwargs\n\n            =====  ==================== ===================================\n            kwarg  Accepts              Description\n            =====  ==================== ===================================\n            fgsz   tuple (width,height) Specify figure dimensions in inches\n            fgnm   integer              Figure number of figure\n            fig    |mplfg|              Draw in specified figure instead of\n                                        creating one\n            ax     |mplax|              Plot in specified axes instead of\n                                        current axes of figure\n            =====  ==================== ===================================\n\n    Returns:\n        - **fig** (:class:`matplotlib.figure.Figure` object):\n          Figure object for this figure.\n        - **ax** (:class:`matplotlib.axes.Axes` object):\n          Axes object for this plot.\n\n    Raises:\n        ValueError: If an invalid plot type is specified via parameter\n           `ptyp`.\n    \"\"\"\n\n    # Extract kwargs entries that are not related to line properties\n    fgsz = kwargs.pop(\"fgsz\", None)\n    fgnm = kwargs.pop(\"fgnm\", None)\n    fig = kwargs.pop(\"fig\", None)\n    ax = kwargs.pop(\"ax\", None)\n\n    figp = fig\n    if fig is None:\n        fig = plt.figure(num=fgnm, figsize=fgsz)\n        fig.clf()\n        ax = fig.gca()\n    elif ax is None:\n        ax = fig.gca()\n\n    # Set defaults for line width and marker size\n    if \"lw\" not in kwargs and \"linewidth\" not in kwargs:\n        kwargs[\"lw\"] = 1.5\n    if \"ms\" not in kwargs and \"markersize\" not in kwargs:\n        kwargs[\"ms\"] = 6.0\n\n    if ptyp not in (\"plot\", \"semilogx\", \"semilogy\", \"loglog\"):\n        raise ValueError(\"Invalid plot type '%s'.\" % ptyp)\n    pltmth = getattr(ax, ptyp)\n    if x is None:\n        pltln = pltmth(y, **kwargs)\n    else:\n        pltln = pltmth(x, y, **kwargs)\n\n    ax.fmt_xdata = \"{: 
.2f}\".format\n    ax.fmt_ydata = \"{: .2f}\".format\n\n    if title is not None:\n        ax.set_title(title)\n    if xlbl is not None:\n        ax.set_xlabel(xlbl)\n    if ylbl is not None:\n        ax.set_ylabel(ylbl)\n    if lgnd is not None:\n        ax.legend(lgnd, loc=lglc)\n\n    _attach_keypress(fig)\n    _attach_zoom(ax)\n\n    if have_mpldc:\n        mpldc.datacursor(pltln)\n\n    if figp is None:\n        fig.show()\n\n    return fig, ax\n\n\ndef surf(\n    z,\n    x=None,\n    y=None,\n    elev=None,\n    azim=None,\n    xlbl=None,\n    ylbl=None,\n    zlbl=None,\n    title=None,\n    lblpad=8.0,\n    alpha=1.0,\n    cntr=None,\n    cmap=None,\n    fgsz=None,\n    fgnm=None,\n    fig=None,\n    ax=None,\n):\n    \"\"\"Plot a 2D surface in 3D.\n\n    Plot a 2D surface in 3D. If a figure object is specified then the\n    surface is drawn in that figure, and `fig.show()` is not called.\n    The figure is closed on key entry 'q'.\n\n    Args:\n        z (array_like): 2d array of data to plot.\n        x (array_like, optional (default ``None``)): Values for x-axis of\n            the plot.\n        y (array_like, optional (default ``None``)): Values for y-axis of\n            the plot.\n        elev (float): Elevation angle (in degrees) in the z plane.\n        azim (float): Azimuth angle  (in degrees) in the x,y plane.\n        xlbl (string, optional (default ``None``)): Label for x-axis.\n        ylbl (string, optional (default ``None``)): Label for y-axis.\n        zlbl (string, optional (default ``None``)): Label for z-axis.\n        title (string, optional (default ``None``)): Figure title.\n        lblpad (float, optional (default 8.0)): Label padding.\n        alpha (float between 0.0 and 1.0, optional (default 1.0)):\n            Transparency.\n        cntr (int or sequence of ints, optional (default ``None``)): If\n            not ``None``, plot contours of the surface on the lower end\n            of the z-axis. 
An int specifies the number of contours to\n            plot, and a sequence specifies the specific contour levels to\n            plot.\n        cmap (:class:`matplotlib.colors.Colormap` object, optional (default ``None``)):\n            Color map for surface. If none specifed, defaults to `cm.YlOrRd`.\n        fgsz (tuple (width,height), optional (default ``None``)): Specify\n            figure dimensions in inches.\n        fgnm (integer, optional (default ``None``)): Figure number of figure.\n        fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)):\n            Draw in specified figure instead of creating one.\n        ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)):\n            Plot in specified axes instead of creating one.\n\n    Returns:\n        - **fig** (:class:`matplotlib.figure.Figure` object):\n          Figure object for this figure.\n        - **ax** (:class:`matplotlib.axes.Axes` object):\n          Axes object for this plot.\n    \"\"\"\n\n    figp = fig\n    if fig is None:\n        fig = plt.figure(num=fgnm, figsize=fgsz)\n        fig.clf()\n        ax = plt.axes(projection=\"3d\")\n    else:\n        if ax is None:\n            ax = plt.axes(projection=\"3d\")\n        else:\n            # See https://stackoverflow.com/a/43563804\n            #     https://stackoverflow.com/a/35221116\n            if ax.name != \"3d\":\n                ax.remove()\n                ax = fig.add_subplot(ax.get_subplotspec(), projection=\"3d\")\n\n    if elev is not None or azim is not None:\n        ax.view_init(elev=elev, azim=azim)\n\n    if cmap is None:\n        cmap = cm.YlOrRd\n\n    if x is None:\n        x = range(z.shape[1])\n    if y is None:\n        y = range(z.shape[0])\n\n    xg, yg = np.meshgrid(x, y)\n    ax.plot_surface(xg, yg, z, rstride=1, cstride=1, alpha=alpha, cmap=cmap)\n\n    if cntr is not None:\n        offset = np.around(z.min() - 0.2 * (z.max() - z.min()), 3)\n        ax.contour(xg, 
yg, z, cntr, cmap=cmap, linewidths=2, linestyles=\"solid\", offset=offset)\n        ax.set_zlim(offset, ax.get_zlim()[1])\n\n    ax.fmt_xdata = \"{: .2f}\".format\n    ax.fmt_ydata = \"{: .2f}\".format\n    ax.fmt_zdata = \"{: .2f}\".format\n\n    if title is not None:\n        ax.set_title(title)\n    if xlbl is not None:\n        ax.set_xlabel(xlbl, labelpad=lblpad)\n    if ylbl is not None:\n        ax.set_ylabel(ylbl, labelpad=lblpad)\n    if zlbl is not None:\n        ax.set_zlabel(zlbl, labelpad=lblpad)\n\n    _attach_keypress(fig)\n\n    if figp is None:\n        fig.show()\n\n    return fig, ax\n\n\ndef contour(\n    z,\n    x=None,\n    y=None,\n    v=5,\n    xlog=False,\n    ylog=False,\n    xlbl=None,\n    ylbl=None,\n    title=None,\n    cfmt=None,\n    cfntsz=10,\n    lfntsz=None,\n    alpha=1.0,\n    cmap=None,\n    vmin=None,\n    vmax=None,\n    fgsz=None,\n    fgnm=None,\n    fig=None,\n    ax=None,\n):\n    \"\"\"Contour plot of a 2D surface.\n\n    Contour plot of a 2D surface. 
If a figure object is specified then\n    the plot is drawn in that figure, and `fig.show()` is not called.\n    The figure is closed on key entry 'q'.\n\n    Args:\n        z (array_like): 2d array of data to plot.\n        x (array_like, optional (default ``None``)): Values for x-axis of\n            the plot.\n        y (array_like, optional (default ``None``)): Values for y-axis of\n            the plot.\n        v (int or sequence of floats, optional (default 5)): An int\n            specifies the number of contours to plot, and a sequence\n            specifies the specific contour levels to plot.\n        xlog (boolean, optional (default ``False``)): Set x-axis to log\n            scale.\n        ylog (boolean, optional (default ``False``)): Set y-axis to log\n            scale.\n        xlbl (string, optional (default ``None``)): Label for x-axis.\n        ylbl (string, optional (default ``None``)): Label for y-axis.\n        title (string, optional (default ``None``)): Figure title.\n        cfmt (string, optional (default ``None``)): Format string for\n            contour labels.\n        cfntsz (int or ``None``, optional (default 10)): Contour label\n            font size. No contour labels are displayed if set to 0 or\n            ``None``.\n        lfntsz (int, optional (default ``None``)): Axis label font size.\n            The default font size is used if set to ``None``.\n        alpha (float, optional (default 1.0)): Underlying image display\n            alpha value.\n        cmap (:class:`matplotlib.colors.Colormap`, optional (default ``None``)):\n            Color map for surface. 
If none specifed, defaults to `cm.YlOrRd`.\n        vmin, vmax (float, optional (default ``None``)): Set upper and\n            lower bounds for the color map (see the corresponding\n            parameters of :meth:`matplotlib.axes.Axes.imshow`).\n        fgsz (tuple (width,height), optional (default ``None``)): Specify\n            figure dimensions in inches.\n        fgnm (integer, optional (default ``None``)): Figure number of figure.\n        fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)):\n            Draw in specified figure instead of creating one.\n        ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)):\n            Plot in specified axes instead of current axes of figure.\n\n    Returns:\n        - **fig** (:class:`matplotlib.figure.Figure` object):\n          Figure object for this figure.\n        - **ax** (:class:`matplotlib.axes.Axes` object):\n          Axes object for this plot.\n    \"\"\"\n\n    figp = fig\n    if fig is None:\n        fig = plt.figure(num=fgnm, figsize=fgsz)\n        fig.clf()\n        ax = fig.gca()\n    elif ax is None:\n        ax = fig.gca()\n\n    if xlog:\n        ax.set_xscale(\"log\")\n    if ylog:\n        ax.set_yscale(\"log\")\n\n    if cmap is None:\n        cmap = cm.YlOrRd\n\n    if x is None:\n        x = np.arange(z.shape[1])\n    else:\n        x = np.array(x)\n    if y is None:\n        y = np.arange(z.shape[0])\n    else:\n        y = np.array(y)\n    xg, yg = np.meshgrid(x, y)\n\n    cntr = ax.contour(xg, yg, z, v, colors=\"black\")\n    kwargs = {}\n    if cfntsz is not None and cfntsz > 0:\n        kwargs[\"fontsize\"] = cfntsz\n    if cfmt is not None:\n        kwargs[\"fmt\"] = cfmt\n    if kwargs:\n        plt.clabel(cntr, inline=True, **kwargs)\n    pc = ax.pcolormesh(\n        xg,\n        yg,\n        z,\n        cmap=cmap,\n        vmin=vmin,\n        vmax=vmax,\n        alpha=alpha,\n        shading=\"gouraud\",\n        clim=(vmin, vmax),\n    
)\n\n    if xlog:\n        ax.fmt_xdata = \"{: .2e}\".format\n    else:\n        ax.fmt_xdata = \"{: .2f}\".format\n    if ylog:\n        ax.fmt_ydata = \"{: .2e}\".format\n    else:\n        ax.fmt_ydata = \"{: .2f}\".format\n\n    if title is not None:\n        ax.set_title(title)\n    if xlbl is not None:\n        ax.set_xlabel(xlbl, fontsize=lfntsz)\n    if ylbl is not None:\n        ax.set_ylabel(ylbl, fontsize=lfntsz)\n\n    divider = make_axes_locatable(ax)\n    cax = divider.append_axes(\"right\", size=\"5%\", pad=0.2)\n    plt.colorbar(pc, ax=ax, cax=cax)\n\n    _attach_keypress(fig)\n    _attach_zoom(ax)\n\n    if have_mpldc:\n        mpldc.datacursor()\n\n    if figp is None:\n        fig.show()\n\n    return fig, ax\n\n\ndef imview(\n    img,\n    title=None,\n    copy=True,\n    fltscl=False,\n    intrp=\"nearest\",\n    norm=None,\n    cbar=False,\n    cmap=None,\n    fgsz=None,\n    fgnm=None,\n    fig=None,\n    ax=None,\n):\n    \"\"\"Display an image.\n\n    Display an image. Pixel values are displayed when the pointer is over\n    valid image data. If a figure object is specified then the image is\n    drawn in that figure, and `fig.show()` is not called. The figure is\n    closed on key entry 'q'.\n\n    Args:\n        img (array_like, shape (Nr, Nc) or (Nr, Nc, 3)):\n            Image to display.\n        title (string, optional (default ``None``)): Figure title.\n        copy (boolean, optional (default ``True``)): If ``True``, create\n            a copy of input `img` as a reference for displayed pixel\n            values, ensuring that displayed values do not change when the\n            array changes in the calling scope. 
Set this flag to\n            ``False`` if the overhead of an additional copy of the input\n            image is not acceptable.\n        fltscl (boolean, optional (default ``False``)): If ``True``,\n            rescale and shift floating point arrays to [0,1].\n        intrp (string, optional (default 'nearest')): Specify type of\n            interpolation used to display image (see `interpolation`\n            parameter of :meth:`matplotlib.axes.Axes.imshow`).\n        norm (:class:`matplotlib.colors.Normalize` object, optional (default ``None``)):\n            Specify the :class:`matplotlib.colors.Normalize` instance used\n            to scale pixel values for input to the color map.\n        cbar (boolean, optional (default ``False``)): Flag indicating\n            whether to display colorbar.\n        cmap (:class:`matplotlib.colors.Colormap`, optional (default ``None``)):\n            Color map for image. If none specifed, defaults to\n            `cm.Greys_r` for monochrome image.\n        fgsz (tuple (width,height), optional (default ``None``)): Specify\n            figure dimensions in inches.\n        fgnm (integer, optional (default ``None``)): Figure number of\n            figure.\n        fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)):\n            Draw in specified figure instead of creating one.\n        ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)):\n            Plot in specified axes instead of current axes of figure.\n\n    Returns:\n        - **fig** (:class:`matplotlib.figure.Figure` object):\n          Figure object for this figure.\n        - **ax** (:class:`matplotlib.axes.Axes` object):\n          Axes object for this plot.\n\n    Raises:\n        ValueError: If the input array is not of the required shape.\n    \"\"\"\n\n    if img.ndim > 2 and img.shape[2] != 3:\n        raise ValueError(\"Argument 'img' must be an Nr x Nc array or an Nr x Nc x 3 array.\")\n\n    figp = fig\n    if fig 
is None:\n        fig = plt.figure(num=fgnm, figsize=fgsz)\n        fig.clf()\n        ax = fig.gca()\n    elif ax is None:\n        ax = fig.gca()\n\n    # Deal with removal of 'box-forced' adjustable in Matplotlib 2.2.0\n    mplv = matplotlib.__version__.split(\".\")\n    if int(mplv[0]) > 2 or (int(mplv[0]) == 2 and int(mplv[1]) >= 2):\n        try:\n            ax.set_adjustable(\"box\")\n        except Exception:\n            ax.set_adjustable(\"datalim\")\n    else:\n        ax.set_adjustable(\"box-forced\")\n\n    imgd = img.copy()\n    if copy:\n        # Keep a separate copy of the input image so that the original\n        # pixel values can be display rather than the scaled pixel\n        # values that are actually plotted.\n        img = img.copy()\n\n    if cmap is None and img.ndim == 2:\n        cmap = cm.Greys_r\n\n    if np.issubdtype(img.dtype, np.floating):\n        if fltscl:\n            imgd -= imgd.min()\n            imgd /= imgd.max()\n        if img.ndim > 2:\n            imgd = np.clip(imgd, 0.0, 1.0)\n    elif img.dtype == np.uint16:\n        imgd = np.float16(imgd) / np.iinfo(np.uint16).max\n    elif img.dtype == np.int16:\n        imgd = np.float16(imgd) - imgd.min()\n        imgd /= imgd.max()\n\n    if norm is None:\n        im = ax.imshow(imgd, cmap=cmap, interpolation=intrp, vmin=imgd.min(), vmax=imgd.max())\n    else:\n        im = ax.imshow(imgd, cmap=cmap, interpolation=intrp, norm=norm)\n\n    ax.set_yticklabels([])\n    ax.set_xticklabels([])\n\n    if title is not None:\n        ax.set_title(title)\n\n    if cbar or cbar is None:\n        orient = \"vertical\" if img.shape[0] >= img.shape[1] else \"horizontal\"\n        pos = \"right\" if orient == \"vertical\" else \"bottom\"\n        divider = make_axes_locatable(ax)\n        cax = divider.append_axes(pos, size=\"5%\", pad=0.2)\n        if cbar is None:\n            # See http://chris35wills.github.io/matplotlib_axis\n            if hasattr(cax, \"set_facecolor\"):\n          
      cax.set_facecolor(\"none\")\n            else:\n                cax.set_axis_bgcolor(\"none\")\n            for axis in [\"top\", \"bottom\", \"left\", \"right\"]:\n                cax.spines[axis].set_linewidth(0)\n            cax.set_xticks([])\n            cax.set_yticks([])\n        else:\n            plt.colorbar(im, ax=ax, cax=cax, orientation=orient)\n\n    def format_coord(x, y):\n        nr, nc = imgd.shape[0:2]\n        col = int(x + 0.5)\n        row = int(y + 0.5)\n        if col >= 0 and col < nc and row >= 0 and row < nr:\n            z = img[row, col]\n            if imgd.ndim == 2:\n                return \"x=%6.2f, y=%6.2f, z=%.2f\" % (x, y, z)\n            return \"x=%6.2f, y=%6.2f, z=(%.2f,%.2f,%.2f)\" % sum(((x,), (y,), tuple(z)), ())\n        return \"x=%.2f, y=%.2f\" % (x, y)\n\n    ax.format_coord = format_coord\n\n    if fig.canvas.toolbar is not None:\n        # See https://stackoverflow.com/a/47086132\n        def mouse_move(self, event):\n            if event.inaxes and event.inaxes.get_navigate():\n                s = event.inaxes.format_coord(event.xdata, event.ydata)\n                self.set_message(s)\n\n        def mouse_move_patch(arg):\n            return mouse_move(fig.canvas.toolbar, arg)\n\n        fig.canvas.toolbar._idDrag = fig.canvas.mpl_connect(\"motion_notify_event\", mouse_move_patch)\n\n    _attach_keypress(fig)\n    _attach_zoom(ax)\n\n    if have_mpldc:\n        mpldc.datacursor(display=\"single\")\n\n    if figp is None:\n        fig.show()\n\n    return fig, ax\n\n\ndef close(fig=None):\n    \"\"\"Close figure(s).\n\n    Close figure(s). 
If a figure object reference or figure number is\n    provided, close the specified figure, otherwise close all figures.\n\n    Args:\n        fig (:class:`matplotlib.figure.Figure` object or integer, optional (default ``None``)):\n          Figure object or number of figure to close.\n    \"\"\"\n\n    if fig is None:\n        plt.close(\"all\")\n    else:\n        plt.close(fig)\n\n\ndef _in_ipython():\n    \"\"\"Determine whether code is running in an ipython shell.\n\n    Returns:\n        bool: ``True`` if running in an ipython shell, ``False``\n           otherwise.\n    \"\"\"\n\n    try:\n        # See https://stackoverflow.com/questions/15411967\n        shell = get_ipython().__class__.__name__\n        return bool(shell == \"TerminalInteractiveShell\")\n    except NameError:\n        return False\n\n\ndef _in_notebook():\n    \"\"\"Determine whether code is running in a Jupyter Notebook shell.\n\n    Returns:\n        bool: ``True`` if running in a notebook shell, ``False``\n           otherwise.\n    \"\"\"\n\n    try:\n        # See https://stackoverflow.com/questions/15411967\n        shell = get_ipython().__class__.__name__\n        return bool(shell == \"ZMQInteractiveShell\")\n    except NameError:\n        return False\n\n\ndef set_ipython_plot_backend(backend=\"qt\"):\n    \"\"\"Set matplotlib backend within an ipython shell.\n\n    Set matplotlib backend within an ipython shell. 
This function has the\n    same effect as the line magic `%matplotlib [backend]` but is called\n    as a function and includes a check to determine whether the code is\n    running in an ipython shell, so that it can safely be used within a\n    normal python script since it has no effect when not running in an\n    ipython shell.\n\n    Args:\n        backend (string, optional (default 'qt')): Name of backend to be\n            passed to the `%matplotlib` line magic command.\n    \"\"\"\n\n    if _in_ipython():\n        # See https://stackoverflow.com/questions/35595766\n        get_ipython().run_line_magic(\"matplotlib\", backend)\n\n\ndef set_notebook_plot_backend(backend=\"inline\"):\n    \"\"\"Set matplotlib backend within a Jupyter Notebook shell.\n\n    Set matplotlib backend within a Jupyter Notebook shell. This function\n    has the same effect as the line magic `%matplotlib [backend]` but\n    is called as a function and includes a check to determine whether the\n    code is running in a notebook shell, so that it can safely be used\n    within a normal python script since it has no effect when not running\n    in a notebook shell.\n\n    Args:\n        backend (string, optional (default 'inline')): Name of backend to\n            be passed to the `%matplotlib` line magic command.\n    \"\"\"\n\n    if _in_notebook():\n        # See https://stackoverflow.com/questions/35595766\n        get_ipython().run_line_magic(\"matplotlib\", backend)\n\n\ndef config_notebook_plotting():\n    \"\"\"Configure plotting functions for inline plotting.\n\n    Configure plotting functions for inline plotting within a Jupyter\n    Notebook shell. This function has no effect when not within a\n    notebook shell, and may therefore be used within a normal python\n    script. 
If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set,\n    the matplotlib backend is explicitly set to the specified value.\n    \"\"\"\n\n    # Check whether running within a notebook shell and have\n    # not already monkey patched the plot function\n    module = sys.modules[__name__]\n    if _in_notebook() and module.plot.__name__ == \"plot\":\n\n        # Set backend if specified by environment variable\n        if \"MATPLOTLIB_IPYNB_BACKEND\" in os.environ:\n            set_notebook_plot_backend(os.environ[\"MATPLOTLIB_IPYNB_BACKEND\"])\n\n        # Replace plot function with a wrapper function that discards\n        # its return value (within a notebook with inline plotting, plots\n        # are duplicated if the return value from the original function is\n        # not assigned to a variable)\n        plot_original = module.plot\n\n        def plot_wrap(*args, **kwargs):\n            plot_original(*args, **kwargs)\n\n        module.plot = plot_wrap\n\n        # Replace surf function with a wrapper function that discards\n        # its return value (see comment for plot function)\n        surf_original = module.surf\n\n        def surf_wrap(*args, **kwargs):\n            surf_original(*args, **kwargs)\n\n        module.surf = surf_wrap\n\n        # Replace contour function with a wrapper function that discards\n        # its return value (see comment for plot function)\n        contour_original = module.contour\n\n        def contour_wrap(*args, **kwargs):\n            contour_original(*args, **kwargs)\n\n        module.contour = contour_wrap\n\n        # Replace imview function with a wrapper function that discards\n        # its return value (see comment for plot function)\n        imview_original = module.imview\n\n        def imview_wrap(*args, **kwargs):\n            imview_original(*args, **kwargs)\n\n        module.imview = imview_wrap\n\n        # Disable figure show method (results in a warning if used within\n        # a notebook with inline 
plotting)\n        import matplotlib.figure\n\n        def show_disable(self):\n            pass\n\n        matplotlib.figure.Figure.show = show_disable\n"
  },
  {
    "path": "scico/random.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Random number generation.\n\nThis module provides convenient wrappers around several `jax.random\n<https://jax.readthedocs.io/en/stable/jax.random.html>`_ routines to\nhandle the generation and splitting of PRNG keys, as well as the\ngeneration of random :class:`.BlockArray`:\n\n::\n\n   # Calls to scico.random functions always return a PRNG key\n   # If no key is passed to the function, a new key is generated\n   x, key = scico.random.randn((2,))\n   print(x)   # [ 0.19307713 -0.52678305]\n\n   # scico.random functions automatically split the PRNG key and return\n   # an updated key\n   y, key = scico.random.randn((2,), key=key)\n   print(y) # [ 0.00870693 -0.04888531]\n\nThe user is responsible for passing the PRNG key to :mod:`scico.random`\nfunctions. 
If no key is passed, repeated calls to :mod:`scico.random`\nfunctions will return the same random numbers:\n\n::\n\n   x, key = scico.random.randn((2,))\n   print(x)   # [ 0.19307713 -0.52678305]\n\n   # No key passed, will return the same random numbers!\n   y, key = scico.random.randn((2,))\n   print(y)   # [ 0.19307713 -0.52678305]\n\n\nIf the desired shape is a tuple containing tuples, a :class:`.BlockArray`\nis returned:\n\n::\n\n   x, key = scico.random.randn( ((1, 1), (2,)), key=key)\n   print(x)  # scico.numpy.BlockArray:\n             # Array([ 1.1378784 , -1.220955  , -0.59153646], dtype=float32)\n\n\"\"\"\n\nimport inspect\nimport sys\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\n\nimport jax\n\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy._wrappers import map_func_over_args\nfrom scico.typing import BlockShape, DType, PRNGKey, Shape\n\n\ndef _add_seed(fun):\n    \"\"\"\n    Modify a :mod:`jax.random` function to add a `seed` argument.\n\n    Args:\n        fun: function to be modified, e.g., :func:`jax.random.normal`.\n        Expects `key` to be the first argument.\n\n    Returns:\n        fun_alt: a version of `fun` supporting an optional `seed`\n           argument that is used to create a :func:`jax.random.key`\n           that is passed along as the `key`. The `key` argument may\n           still be used, but is moved to be second-to-last. By default,\n           `seed=0`. The `seed` argument is added last. 
Other arguments\n           are unchanged.\n    \"\"\"\n\n    # find number of non-keyword-only parameters of fun\n    num_params = len(\n        [\n            param\n            for param in inspect.signature(fun).parameters.values()\n            if param.kind != param.KEYWORD_ONLY\n        ]\n    )\n\n    def fun_alt(*args, key=None, seed=None, **kwargs):\n\n        # key and seed may be in *args, look for them\n        if len(args) >= num_params:  # they passed all position args including key\n            key = args[num_params - 1]\n        if len(args) > num_params:  # they passed all position args including key and seed\n            seed = args[num_params]\n\n        if key is not None and seed is not None:\n            raise ValueError(\"Arguments 'key' and 'seed' may not both be specified.\")\n\n        if key is None:\n            if seed is None:\n                seed = 0\n            key = jax.random.key(seed)\n\n        result = fun(key, *args[: num_params - 1], **kwargs)\n\n        key, subkey = jax.random.split(key, 2)\n        return result, key\n\n    lines = fun.__doc__.split(\"\\n\\n\")\n    fun_alt.__doc__ = \"\\n\\n\".join(\n        lines[0:1]\n        + [\n            f\"  Wrapped version of `jax.random.{fun.__name__} \"\n            f\"<https://jax.readthedocs.io/en/stable/jax.random.html#jax.random.{fun.__name__}>`_. \"\n            \"The SCICO version of this function moves the `key` argument to the end of the \"\n            \"argument list, adds an additional `seed` argument after that, and allows the \"\n            \"`shape` argument to accept a nested list, in which case a `BlockArray` is returned. \"\n            \"Always returns a `(result, key)` tuple. 
Original docstring below.\",\n        ]\n        + lines[1:]\n    )\n\n    return fun_alt\n\n\ndef _wrap(fun):\n    fun_wrapped = _add_seed(map_func_over_args(fun, map_if_nested_args=[\"shape\"]))\n    fun_wrapped.__module__ = __name__  # so it appears in docs\n    return fun_wrapped\n\n\ndef _is_wrappable(fun):\n    params = inspect.signature(getattr(jax.random, fun)).parameters\n    prmkey = list(params.keys())\n    return prmkey and (prmkey[0] == \"key\") and (\"shape\" in params.keys())\n\n\nwrappable_func_names = [\n    t[0] for t in inspect.getmembers(jax.random, inspect.isfunction) if _is_wrappable(t[0])\n]\n\nfor name in wrappable_func_names:\n    setattr(sys.modules[__name__], name, _wrap(getattr(jax.random, name)))\n\n\ndef randn(\n    shape: Union[Shape, BlockShape],\n    dtype: DType = np.float32,\n    key: Optional[PRNGKey] = None,\n    seed: Optional[int] = None,\n) -> Tuple[Union[Array, BlockArray], PRNGKey]:\n    \"\"\"Return an array drawn from the standard normal distribution.\n\n    Alias for :func:`scico.random.normal`.\n\n    Args:\n        shape: Shape of output array. If shape is a tuple, a\n            :class:`jax.Array` is returned. If shape is a tuple of tuples,\n            a :class:`.BlockArray` is returned.\n        dtype: dtype for returned value. Defaults to :attr:`~numpy.float32`.\n            If a complex dtype such as :attr:`~numpy.complex64` is\n            specified, an array sampled from a complex normal\n            distribution is generated.\n        key: JAX PRNGKey. Defaults to ``None``, in which case a new key\n            is created using the `seed` argument.\n        seed: Seed for new PRNGKey, used when `key` is ``None``.\n            Default: 0.\n\n    Returns:\n        tuple: A tuple (x, key) containing:\n\n           - **x** : (:class:`jax.Array`): Generated random array.\n           - **key** : Updated random PRNGKey.\n    \"\"\"\n    return normal(shape, dtype, key, seed)  # type: ignore\n"
  },
  {
    "path": "scico/ray/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2022-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Simplified interfaces to :doc:`Ray <ray:index>`.\"\"\"\n\nimport os\n\nos.environ[\"RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO\"] = \"0\"  # suppress ray warning\ntry:\n    from ray import get, put\n    from ray.tune import report\nexcept ImportError:\n    raise ImportError(\"Could not import ray; please install it.\")\n"
  },
  {
    "path": "scico/ray/tune.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Parameter tuning using :doc:`ray.tune <ray:tune/index>`.\"\"\"\n\nimport datetime\nimport getpass\nimport logging\nimport os\nimport tempfile\nfrom typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union\n\nimport ray\n\ntry:\n    import ray.tune\n\n    os.environ[\"RAY_AIR_NEW_OUTPUT\"] = \"0\"\nexcept ImportError:\n    raise ImportError(\"Could not import ray.tune; please install it.\")\nimport ray.air\nfrom ray.tune import (  # noqa\n    CheckpointConfig,\n    RunConfig,\n    Trainable,\n    loguniform,\n    uniform,\n    with_parameters,\n)\nfrom ray.tune.experiment.trial import Trial\nfrom ray.tune.progress_reporter import TuneReporterBase, _get_trials_by_state\nfrom ray.tune.result_grid import ResultGrid\nfrom ray.tune.schedulers import AsyncHyperBandScheduler\nfrom ray.tune.search.hyperopt import HyperOptSearch\n\n\nclass _CustomReporter(TuneReporterBase):\n    \"\"\"Custom status reporter for :mod:`ray.tune`.\"\"\"\n\n    def should_report(self, trials: List[Trial], done: bool = False):\n        \"\"\"Return boolean indicating whether progress should be reported.\"\"\"\n        # Don't report on final call when done to avoid duplicate final output.\n        return not done\n\n    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):\n        \"\"\"Report progress across trials.\"\"\"\n        # Get dict of trials in each state.\n        trials_by_state = _get_trials_by_state(trials)\n        # Construct list of number of trials in each of three possible states.\n        num_trials = [len(trials_by_state[state]) for state in [\"PENDING\", \"RUNNING\", \"TERMINATED\"]]\n        # Construct string description of number of trials in each 
state.\n        num_trials_str = f\"P: {num_trials[0]:3d} R: {num_trials[1]:3d} T: {num_trials[2]:3d} \"\n        # Get current best trial.\n        current_best_trial, metric = self._current_best_trial(trials)\n        if current_best_trial is None:\n            rslt_str = \"\"\n        else:\n            # If current best trial exists, construct string summary\n            val = current_best_trial.last_result[metric]\n            config = current_best_trial.last_result.get(\"config\", {})\n            rslt_str = f\" {metric}: {val:.2e} at \" + \", \".join(\n                [f\"{k}: {v:.2e}\" for k, v in config.items()]\n            )\n        # If all trials terminated, print with newline, otherwise carriage return for overwrite\n        if num_trials[0] + num_trials[1] == 0:\n            end = \"\\n\"\n        else:\n            end = \"\\r\"\n        print(num_trials_str + rslt_str, end=end)\n\n\ndef run(\n    run_or_experiment: Union[str, Callable, Type],\n    metric: str,\n    mode: str,\n    time_budget_s: Union[None, int, float, datetime.timedelta] = None,\n    num_samples: int = 1,\n    resources_per_trial: Union[None, Mapping[str, Union[float, int, Mapping]]] = None,\n    max_concurrent_trials: Optional[int] = None,\n    config: Optional[Dict[str, Any]] = None,\n    hyperopt: bool = True,\n    verbose: bool = True,\n    storage_path: Optional[str] = None,\n) -> ray.tune.ExperimentAnalysis:\n    \"\"\"Simplified wrapper for `ray.tune.run`_.\n\n    .. 
_ray.tune.run: https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py#L232\n\n    The `ray.tune.run`_ interface appears to be scheduled for deprecation.\n    Use of :class:`Tuner`, which is a simplified interface to\n    :class:`ray.tune.Tuner` is recommended instead.\n\n    Args:\n        run_or_experiment: Function that reports performance values.\n        metric: Name of the metric reported in the performance evaluation\n            function.\n        mode: Either \"min\" or \"max\", indicating which represents better\n            performance.\n        time_budget_s: Maximum time allowed in seconds for the parameter\n            search.\n        num_samples: Number of parameter evaluation samples to compute.\n        resources_per_trial: A dict mapping keys \"cpu\" and \"gpu\" to\n            integers specifying the corresponding resources to allocate\n            for each performance evaluation trial.\n        max_concurrent_trials: Maximum number of trials to run\n            concurrently.\n        config: Specification of the parameter search space.\n        hyperopt: If ``True``, use\n            :class:`~ray.tune.search.hyperopt.HyperOptSearch` search,\n            otherwise use simple random search (see\n            :class:`~ray.tune.search.basic_variant.BasicVariantGenerator`).\n        verbose: Flag indicating whether verbose operation is desired.\n            When verbose operation is enabled, the number of pending,\n            running, and terminated trials are indicated by \"P:\", \"R:\",\n            and \"T:\" respectively, followed by the current best metric\n            value and the parameters at which it was reported.\n        storage_path: Directory in which to save tuning results. Defaults to\n            a subdirectory \"<username>/ray_results\" within the path returned by\n            `tempfile.gettempdir()`, corresponding e.g. 
to\n            \"/tmp/<username>/ray_results\" under Linux.\n\n    Returns:\n        Result of parameter search.\n    \"\"\"\n    kwargs = {}\n    if hyperopt:\n        kwargs.update(\n            {\n                \"search_alg\": HyperOptSearch(metric=metric, mode=mode),\n                \"scheduler\": AsyncHyperBandScheduler(),\n            }\n        )\n    if verbose:\n        kwargs.update({\"verbose\": 1, \"progress_reporter\": _CustomReporter()})\n    else:\n        kwargs.update({\"verbose\": 0})\n\n    if isinstance(run_or_experiment, str):\n        name = run_or_experiment\n    else:\n        name = run_or_experiment.__name__\n    name += \"_\" + datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n\n    if storage_path is None:\n        try:\n            user = getpass.getuser()\n        except Exception:  # pragma: no cover\n            user = \"NOUSER\"\n        storage_path = os.path.join(tempfile.gettempdir(), user, \"ray_results\")\n\n    # Record original logger.info\n    logger_info = ray.tune.tune.logger.info\n\n    # Replace logger.info with filtered version\n    def logger_info_filter(msg, *args, **kwargs):\n        if msg[0:15] != \"Total run time:\":\n            logger_info(msg, *args, **kwargs)\n\n    ray.tune.tune.logger.info = logger_info_filter\n\n    result = ray.tune.run(\n        run_or_experiment,\n        metric=metric,\n        mode=mode,\n        name=name,\n        time_budget_s=time_budget_s,\n        num_samples=num_samples,\n        storage_path=storage_path,\n        resources_per_trial=resources_per_trial,\n        max_concurrent_trials=max_concurrent_trials,\n        reuse_actors=True,\n        config=config,\n        checkpoint_freq=0,\n        **kwargs,\n    )\n\n    # Restore original logger.info\n    ray.tune.tune.logger.info = logger_info\n\n    return result\n\n\nclass Tuner(ray.tune.Tuner):\n    \"\"\"Simplified interface for :class:`ray.tune.Tuner`.\"\"\"\n\n    def __init__(\n        self,\n        trainable: 
Union[Type[ray.tune.Trainable], Callable],\n        *,\n        param_space: Optional[Dict[str, Any]] = None,\n        resources: Optional[Dict] = None,\n        max_concurrent_trials: Optional[int] = None,\n        metric: Optional[str] = None,\n        mode: Optional[str] = None,\n        num_samples: Optional[int] = None,\n        num_iterations: Optional[int] = None,\n        time_budget: Optional[int] = None,\n        reuse_actors: bool = True,\n        hyperopt: bool = True,\n        verbose: bool = True,\n        storage_path: Optional[str] = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n           trainable: Function that reports performance values.\n           param_space: Specification of the parameter search space.\n           resources: A dict mapping keys \"cpu\" and \"gpu\" to integers\n              specifying the corresponding resources to allocate for each\n              performance evaluation trial.\n           max_concurrent_trials: Maximum number of trials to run\n            concurrently.\n           metric: Name of the metric reported in the performance\n              evaluation function.\n           mode: Either \"min\" or \"max\", indicating which represents\n              better performance.\n           num_samples: Number of parameter evaluation samples to compute.\n           num_iterations: Number of training iterations for evaluation\n              of a single configuration. 
Only required for the Tune Class\n              API.\n           time_budget: Maximum time allowed in seconds for a single\n              parameter evaluation.\n           reuse_actors: If ``True``, reuse the same process/object for\n              multiple hyperparameters.\n           hyperopt: If ``True``, use\n              :class:`~ray.tune.search.hyperopt.HyperOptSearch` search,\n              otherwise use simple random search (see\n              :class:`~ray.tune.search.basic_variant.BasicVariantGenerator`).\n           verbose: Flag indicating whether verbose operation is desired.\n              When verbose operation is enabled, the number of pending,\n              running, and terminated trials are indicated by \"P:\", \"R:\",\n              and \"T:\" respectively, followed by the current best metric\n              value and the parameters at which it was reported.\n           storage_path: Directory in which to save tuning results. Defaults\n              to a subdirectory \"<username>/ray_results\" within the path\n              returned by `tempfile.gettempdir()`, corresponding e.g. 
to\n              \"/tmp/<username>/ray_results\" under Linux.\n        \"\"\"\n\n        k: Any  # Avoid typing errors\n        v: Any\n\n        if resources is None:\n            trainable_with_resources = trainable\n        else:\n            trainable_with_resources = ray.tune.with_resources(trainable, resources)\n\n        tune_config = kwargs.pop(\"tune_config\", None)\n        tune_config_kwargs = {\n            \"mode\": mode,\n            \"metric\": metric,\n            \"num_samples\": num_samples,\n            \"reuse_actors\": reuse_actors,\n        }\n        if hyperopt:\n            tune_config_kwargs.update(\n                {\n                    \"search_alg\": HyperOptSearch(metric=metric, mode=mode),\n                    \"scheduler\": AsyncHyperBandScheduler(),\n                }\n            )\n        if max_concurrent_trials is not None:\n            tune_config_kwargs.update({\"max_concurrent_trials\": max_concurrent_trials})\n        if tune_config is None:\n            tune_config = ray.tune.TuneConfig(**tune_config_kwargs)\n        else:\n            for k, v in tune_config_kwargs.items():\n                setattr(tune_config, k, v)\n\n        name = trainable.__name__ + \"_\" + datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n        if storage_path is None:\n            try:\n                user = getpass.getuser()\n            except Exception:  # pragma: no cover\n                user = \"NOUSER\"\n            storage_path = os.path.join(tempfile.gettempdir(), user, \"ray_results\")\n\n        run_config = kwargs.pop(\"run_config\", None)\n        run_config_kwargs = {\"name\": name, \"storage_path\": storage_path, \"verbose\": 0}\n        if verbose:\n            run_config_kwargs.update({\"verbose\": 1, \"progress_reporter\": _CustomReporter()})\n        if num_iterations is not None or time_budget is not None:\n            stop_criteria = {}\n            if num_iterations is not None:\n                
stop_criteria.update({\"training_iteration\": num_iterations})\n            if time_budget is not None:\n                stop_criteria.update({\"time_total_s\": time_budget})\n            run_config_kwargs.update({\"stop\": stop_criteria})\n        if run_config is None:\n            run_config_kwargs.update(\n                {\"checkpoint_config\": CheckpointConfig(checkpoint_at_end=False)}\n            )\n            run_config = RunConfig(**run_config_kwargs)\n        else:\n            for k, v in run_config_kwargs.items():\n                setattr(run_config, k, v)\n\n        super().__init__(\n            trainable_with_resources,\n            param_space=param_space,\n            tune_config=tune_config,\n            run_config=run_config,\n            **kwargs,\n        )\n\n    def fit(self) -> ResultGrid:\n        \"\"\"Initialize ray and call :meth:`ray.tune.Tuner.fit`.\n\n        Initialize ray if not already initialized, and call\n        :meth:`ray.tune.Tuner.fit`. If ray was not previously initialized,\n        shut it down after fit process has completed.\n\n        Returns:\n           Result of parameter search.\n        \"\"\"\n        ray_init = ray.is_initialized()\n        if not ray_init:\n            ray.init(logging_level=logging.ERROR)\n\n        results = super().fit()\n\n        if not ray_init:\n            ray.shutdown()\n\n        return results\n"
  },
  {
    "path": "scico/scipy/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2024 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Wrapped versions of `jax.scipy <https://jax.readthedocs.io/en/latest/jax.scipy.html>`_ functions.\n\nThis module currently serves simply as a namespace for :mod:`scico.scipy.special`.\n\"\"\"\n\nfrom . import special\n"
  },
  {
    "path": "scico/scipy/special.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\":class:`~scico.numpy.BlockArray`-compatible :mod:`jax.scipy.special`\nfunctions.\n\nThis module is a wrapper for :mod:`jax.scipy.special` where some\nfunctions have been extended to automatically map over block array\nblocks as described in :ref:`numpy_functions_blockarray`.\n\"\"\"\n\nfrom typing import Tuple\n\nimport jax.scipy.special as js\n\nfrom scico.numpy import _wrappers\n\n# add almost everything in jax.scipy.special to this module\n_wrappers.add_attributes(\n    to_dict=vars(),\n    from_dict=js.__dict__,\n)\n\n# wrap select functions\nfunctions: Tuple[str, ...] = (\n    \"betainc\",\n    \"entr\",\n    \"erf\",\n    \"erfc\",\n    \"erfinv\",\n    \"expit\",\n    \"gammainc\",\n    \"gammaincc\",\n    \"gammaln\",\n    \"i0\",\n    \"i0e\",\n    \"i1\",\n    \"i1e\",\n    \"log_ndtr\",\n    \"logit\",\n    \"logsumexp\",\n    \"multigammaln\",\n    \"ndtr\",\n    \"ndtri\",\n    \"polygamma\",\n    \"xlog1py\",\n    \"xlogy\",\n    \"zeta\",\n    \"digamma\",\n)\nif hasattr(js, \"sph_harm_y\"):  # not available in all supported jax versions\n    functions += (\"sph_harm_y\",)\nelse:\n    functions += (\"sph_harm\",)\n_wrappers.wrap_recursively(vars(), functions, _wrappers.map_func_over_args)\n\n# clean up\ndel js, _wrappers\n"
  },
  {
    "path": "scico/solver.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Solver and optimization algorithms.\n\nThis module provides a number of functions for solving linear systems and\noptimization problems, some of which are used as subproblem solvers\nwithin the iterations of the proximal algorithms in the\n:mod:`scico.optimize` subpackage.\n\nThis module also provides scico interface wrappers for functions\nfrom :mod:`scipy.optimize`, since jax directly implements only a very\nlimited subset of these functions: only CG and BFGS are fully\nsupported, with limited, experimental support for\n`L-BFGS-B <https://github.com/google/jax/pull/6053>`_. These wrappers\nare required because the functions in :mod:`scipy.optimize` only\nsupport 1D, real-valued, numpy\narrays. 
These limitations are addressed by:\n\n- Enabling the use of multi-dimensional arrays by flattening and reshaping\n  within the wrapper.\n- Enabling the use of jax arrays by automatically converting to and from\n  numpy arrays.\n- Enabling the use of complex arrays by splitting them into real and\n  imaginary parts.\n\nThe wrapper also JIT compiles the function and gradient evaluations.\n\nThese wrapper functions have a number of advantages and disadvantages\nwith respect to those in :mod:`jax.scipy.optimize`:\n\n- This module provides many more algorithms than\n  :mod:`jax.scipy.optimize`.\n- The functions in this module tend to be faster for small-scale problems\n  (presumably due to some overhead in the jax functions).\n- The functions in this module are slower for large problems due to the\n  frequent host-device copies corresponding to conversion between numpy\n  arrays and jax arrays.\n- The solvers in this module can't be JIT compiled, and gradients cannot\n  be taken through them.\n\nIn the future, these wrapper functions may be replaced with a dependency\non `JAXopt <https://github.com/google/jaxopt>`__.\n\"\"\"\n\nfrom functools import wraps\nfrom typing import Any, Callable, Optional, Sequence, Tuple, Union\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nimport jax.scipy.linalg as jsl\n\nimport scico.numpy as snp\nfrom scico.linop import (\n    CircularConvolve,\n    ComposedLinearOperator,\n    Diagonal,\n    LinearOperator,\n    MatrixOperator,\n    Sum,\n)\nfrom scico.metric import rel_res\nfrom scico.numpy import Array, BlockArray\nfrom scico.numpy.util import is_complex_dtype, is_nested, is_real_dtype\nfrom scico.typing import BlockShape, DType, Shape\nfrom scipy import optimize as spopt\n\n\ndef _wrap_func(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable:\n    \"\"\"Function evaluation for use in :mod:`scipy.optimize`.\n\n    Compute function evaluation (without gradient) for use in\n    :mod:`scipy.optimize` 
functions. Reshapes the input to `func` to\n    have `shape`. Evaluates `func`.\n\n    Args:\n        func: The function to minimize.\n        shape: Shape of input to `func`.\n        dtype: Data type of input to `func`.\n    \"\"\"\n\n    val_func = jax.jit(func)\n\n    @wraps(func)\n    def wrapper(x, *args):\n        # apply val_grad_func to un-vectorized input\n        val = val_func(_unravel(x, shape).astype(dtype), *args)\n\n        # Convert val into numpy array, cast to float, convert to scalar\n        val = np.array(val).astype(float)\n        val = val.item() if val.ndim == 0 else val[0].item()\n\n        return val\n\n    return wrapper\n\n\ndef _wrap_func_and_grad(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable:\n    \"\"\"Function evaluation and gradient for use in :mod:`scipy.optimize`.\n\n    Compute function evaluation and gradient for use in\n    :mod:`scipy.optimize` functions. Reshapes the input to `func` to\n    have `shape`.  Evaluates `func` and computes gradient. 
Ensures\n    the returned `grad` is an ndarray.\n\n    Args:\n        func: The function to minimize.\n        shape: Shape of input to `func`.\n        dtype: Data type of input to `func`.\n    \"\"\"\n\n    # argnums=0 ensures only differentiate func wrt first argument,\n    #   in case func signature is func(x, *args)\n    val_grad_func = jax.jit(jax.value_and_grad(func, argnums=0))\n\n    @wraps(func)\n    def wrapper(x, *args):\n        # apply val_grad_func to un-vectorized input\n        val, grad = val_grad_func(_unravel(x, shape).astype(dtype), *args)\n\n        # Convert val & grad into numpy arrays, then cast to float\n        # Convert 'val' into a scalar, rather than ndarray of shape (1,)\n        val = np.array(val).astype(float).item()\n        grad = np.array(grad).astype(float).ravel()\n        return val, grad\n\n    return wrapper\n\n\ndef _split_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n    \"\"\"Split an array of shape (N, M, ...) into real and imaginary parts.\n\n    Args:\n        x: Array to split.\n\n    Returns:\n        A real ndarray with stacked real/imaginary parts. If `x` has\n        shape (M, N, ...), the returned array will have shape\n        (2, M, N, ...) where the first slice contains the `x.real` and\n        the second contains `x.imag`. If `x` is a BlockArray, this\n        function is called on each block and the output is joined into a\n        BlockArray.\n    \"\"\"\n    if isinstance(x, BlockArray):\n        return snp.blockarray([_split_real_imag(_) for _ in x])\n    return snp.stack((snp.real(x), snp.imag(x)))\n\n\ndef _join_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:\n    \"\"\"Join a real array of shape (2,N,M,...) into a complex array.\n\n    Join a real array of shape (2,N,M,...) 
into a complex array of length\n    (N,M, ...).\n\n    Args:\n        x: Array to join.\n\n    Returns:\n        A complex array with real and imaginary parts taken from `x[0]`\n        and `x[1]` respectively.\n    \"\"\"\n    if isinstance(x, BlockArray):\n        return snp.blockarray([_join_real_imag(_) for _ in x])\n    return x[0] + 1j * x[1]\n\n\ndef _ravel(x: Union[Array, BlockArray]) -> Array:\n    \"\"\"Vectorize an array or blockarray to a 1d array.\n\n    Args:\n        x: Array or blockarray to be vectorized.\n\n    Returns:\n        Vectorized array.\n    \"\"\"\n    if isinstance(x, snp.BlockArray):\n        return jnp.hstack(x.ravel().arrays)\n    else:\n        return x.ravel()\n\n\ndef _unravel(x: Array, shape: Union[Shape, BlockShape]) -> Union[Array, BlockArray]:\n    \"\"\"Return a vectorized array or blockarray to its original shape.\n\n    Args:\n        x: Vectorized array representation.\n        shape: Shape of original array or blockarray.\n\n    Returns:\n        Array or blockarray with original shape.\n    \"\"\"\n    if is_nested(shape):\n        sizes = [np.prod(e).item() for e in shape]\n        indices = np.cumsum(sizes[:-1])\n        chunks = jnp.split(x, indices)\n        return snp.BlockArray([chunks[k].reshape(cs) for k, cs in enumerate(shape)])\n    else:\n        return x.reshape(shape)\n\n\ndef minimize(\n    func: Callable,\n    x0: Union[Array, BlockArray],\n    args: Union[Tuple, Tuple[Any]] = (),\n    method: str = \"L-BFGS-B\",\n    hess: Optional[Union[Callable, str]] = None,\n    hessp: Optional[Callable] = None,\n    bounds: Optional[Union[Sequence, spopt.Bounds]] = None,\n    constraints: Union[spopt.LinearConstraint, spopt.NonlinearConstraint, dict] = (),\n    tol: Optional[float] = None,\n    callback: Optional[Callable] = None,\n    options: Optional[dict] = None,\n) -> spopt.OptimizeResult:\n    \"\"\"Minimization of scalar function of one or more variables.\n\n    Wrapper around :func:`scipy.optimize.minimize`. 
This function differs\n    from :func:`scipy.optimize.minimize` in three ways:\n\n        - The `jac` options of :func:`scipy.optimize.minimize` are not\n          supported. The gradient is calculated using :func:`jax.grad`.\n        - Functions mapping from N-dimensional arrays -> float are\n          supported.\n        - Functions mapping from complex arrays -> float are supported.\n\n    For more detail, including descriptions of the optimization methods\n    and custom minimizers, refer to the original docs for\n    :func:`scipy.optimize.minimize`.\n    \"\"\"\n\n    if is_complex_dtype(x0.dtype):\n        # scipy minimize function requires real-valued arrays, so\n        # we split x0 into a vector with real/imaginary parts stacked\n        # and compose `func` with a `_join_real_imag`\n        iscomplex = True\n        func_real = lambda x: func(_join_real_imag(x))\n        x0 = _split_real_imag(x0)\n    else:\n        iscomplex = False\n        func_real = func\n\n    x0_shape = x0.shape\n    x0_dtype = x0.dtype\n    x0 = _ravel(x0)\n\n    # Run the SciPy minimizer\n    if method in (\n        \"CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, \"\n        \"trust-exact, trust-constr\"\n    ).split(\n        \", \"\n    ):  # uses gradient info\n        min_func = _wrap_func_and_grad(func_real, x0_shape, x0_dtype)\n        jac = True  # see scipy.minimize docs\n    else:  # does not use gradient info\n        min_func = _wrap_func(func_real, x0_shape, x0_dtype)\n        jac = False\n\n    res = spopt.OptimizeResult({\"x\": None})\n\n    def fun(x0):\n        nonlocal res  # To use the external res\n        res = spopt.minimize(\n            min_func,\n            x0=x0,\n            args=args,\n            jac=jac,\n            method=method,\n            options=options,\n        )  # Return OptimizeResult with x0 as ndarray\n        return res.x.astype(x0_dtype)\n\n    res.x = jax.pure_callback(\n        fun,\n        
jax.ShapeDtypeStruct(x0.shape, x0_dtype),\n        x0,\n    )\n\n    res.x = _unravel(res.x, x0_shape)  # un-vectorize the output array from spopt.minimize\n    if iscomplex:\n        res.x = _join_real_imag(res.x)\n\n    return res\n\n\ndef minimize_scalar(\n    func: Callable,\n    bracket: Optional[Sequence[float]] = None,\n    bounds: Optional[Sequence[float]] = None,\n    args: Union[Tuple, Tuple[Any]] = (),\n    method: str = \"brent\",\n    tol: Optional[float] = None,\n    options: Optional[dict] = None,\n) -> spopt.OptimizeResult:\n    \"\"\"Minimization of scalar function of one variable.\n\n    Wrapper around :func:`scipy.optimize.minimize_scalar`.\n\n    For more detail, including descriptions of the optimization methods\n    and custom minimizers, refer to the original docstring for\n    :func:`scipy.optimize.minimize_scalar`.\n    \"\"\"\n\n    def f(x, *args):\n        # Wrap jax-based function `func` to return a numpy float rather\n        # than a jax array of size (1,)\n        y = func(x, *args)\n        return y.item() if y.ndim == 0 else y[0].item()\n\n    res = spopt.minimize_scalar(\n        fun=f,\n        bracket=bracket,\n        bounds=bounds,\n        args=args,\n        method=method,\n        tol=tol,\n        options=options,\n    )\n    return res\n\n\ndef cg(\n    A: Callable,\n    b: Array,\n    x0: Optional[Array] = None,\n    *,\n    tol: float = 1e-5,\n    atol: float = 0.0,\n    maxiter: int = 1000,\n    info: bool = True,\n    M: Optional[Callable] = None,\n) -> Tuple[Array, dict]:\n    r\"\"\"Conjugate Gradient solver.\n\n    Solve the linear system :math:`A\\mb{x} = \\mb{b}`, where :math:`A` is\n    positive definite, via the conjugate gradient method.\n\n    Args:\n        A: Callable implementing linear operator :math:`A`, which should\n           be positive definite.\n        b: Input array :math:`\\mb{b}`.\n        x0: Initial solution. 
If `A` is a :class:`.LinearOperator`, this\n          parameter need not be specified, and defaults to a zero array.\n          Otherwise, it is required.\n        tol: Relative residual stopping tolerance. Convergence occurs\n           when `norm(residual) <= max(tol * norm(b), atol)`.\n        atol: Absolute residual stopping tolerance. Convergence occurs\n           when `norm(residual) <= max(tol * norm(b), atol)`.\n        maxiter: Maximum iterations. Default: 1000.\n        info: If ``True``, return a tuple consisting of the solution array\n           and a dictionary containing diagnostic information, otherwise\n           just return the solution.\n        M: Preconditioner for `A`. The preconditioner should approximate\n           the inverse of `A`. The default, ``None``, uses no\n           preconditioner.\n\n    Returns:\n        tuple: A tuple (x, info) containing:\n\n            - **x** : Solution array.\n            - **info**: Dictionary containing diagnostic information.\n    \"\"\"\n    if x0 is None:\n        if isinstance(A, LinearOperator):\n            x0 = snp.zeros(A.input_shape, b.dtype)\n        else:\n            raise ValueError(\n                \"Argument 'x0' must be specified if argument 'A' is not a LinearOperator.\"\n            )\n\n    if M is None:\n        M = lambda x: x\n\n    x = x0\n    Ax = A(x0)\n    bn = snp.linalg.norm(b)\n    r = b - Ax\n    z = M(r)\n    p = z\n    num = snp.sum(r.conj() * z)\n    ii = 0\n\n    # termination tolerance (uses the \"non-legacy\" form of scipy.sparse.linalg.cg)\n    termination_tol_sq = snp.maximum(tol * bn, atol) ** 2\n\n    while (ii < maxiter) and (num > termination_tol_sq):\n        Ap = A(p)\n        alpha = num / snp.sum(p.conj() * Ap)\n        x = x + alpha * p\n        r = r - alpha * Ap\n        z = M(r)\n        num_old = num\n        num = snp.sum(r.conj() * z)\n        beta = num / num_old\n        p = z + beta * p\n        ii += 1\n\n    if info:\n        return (x, 
{\"num_iter\": ii, \"rel_res\": snp.sqrt(num).real / bn})\n    else:\n        return x\n\n\ndef lstsq(\n    A: Callable,\n    b: Array,\n    x0: Optional[Array] = None,\n    tol: float = 1e-5,\n    atol: float = 0.0,\n    maxiter: int = 1000,\n    info: bool = False,\n    M: Optional[Callable] = None,\n) -> Tuple[Array, dict]:\n    r\"\"\"Least squares solver.\n\n    Solve the least squares problem\n\n    .. math::\n        \\argmin_{\\mb{x}} \\; (1/2) \\norm{ A \\mb{x} - \\mb{b} }_2^2 \\;,\n\n    where :math:`A` is a linear operator and :math:`\\mb{b}` is a vector.\n    The problem is solved using :func:`cg`.\n\n    Args:\n        A: Callable implementing linear operator :math:`A`.\n        b: Input array :math:`\\mb{b}`.\n        x0: Initial solution. If `A` is a :class:`.LinearOperator`, this\n          parameter need not be specified, and defaults to a zero array.\n          Otherwise, it is required.\n        tol: Relative residual stopping tolerance. Convergence occurs\n           when `norm(residual) <= max(tol * norm(b), atol)`.\n        atol: Absolute residual stopping tolerance. Convergence occurs\n           when `norm(residual) <= max(tol * norm(b), atol)`.\n        maxiter: Maximum iterations. Default: 1000.\n        info: If ``True``, return a tuple consisting of the solution array\n           and a dictionary containing diagnostic information, otherwise\n           just return the solution.\n        M: Preconditioner for :math:`A^T A`, the operator to which\n           :func:`cg` is applied. The preconditioner should approximate\n           the inverse of :math:`A^T A`. 
The default, ``None``, uses no\n           preconditioner.\n\n    Returns:\n        tuple: A tuple (x, info) containing:\n\n            - **x** : Solution array.\n            - **info**: Dictionary containing diagnostic information.\n    \"\"\"\n    if isinstance(A, LinearOperator):\n        Aop = A\n    else:\n        assert x0 is not None\n        Aop = LinearOperator(\n            input_shape=x0.shape,\n            output_shape=b.shape,\n            eval_fn=A,\n            input_dtype=b.dtype,\n            output_dtype=b.dtype,\n        )\n\n    ATA = Aop.T @ Aop\n    ATb = Aop.T @ b\n    return cg(ATA, ATb, x0=x0, tol=tol, atol=atol, maxiter=maxiter, info=info, M=M)\n\n\ndef bisect(\n    f: Callable,\n    a: Array,\n    b: Array,\n    args: Tuple = (),\n    xtol: float = 1e-7,\n    ftol: float = 1e-7,\n    maxiter: int = 100,\n    full_output: bool = False,\n    range_check: bool = True,\n) -> Union[Array, dict]:\n    \"\"\"Vectorised root finding via bisection method.\n\n    Vectorised root finding via bisection method, supporting\n    simultaneous finding of multiple roots on a function defined over a\n    multi-dimensional array. When the function is array-valued, each of\n    these values is treated as the independent application of a scalar\n    function. 
The initial interval `[a, b]` must bracket the root for all\n    scalar functions.\n\n    The interface is similar to that of :func:`scipy.optimize.bisect`,\n    which is much faster when `f` is a scalar function and `a` and `b`\n    are scalars.\n\n    Args:\n        f: Function returning a float or an array of floats.\n        a: Lower bound of interval on which to apply bisection.\n        b: Upper bound of interval on which to apply bisection.\n        args: Additional arguments for function `f`.\n        xtol: Stopping tolerance based on maximum bisection interval\n            length over array.\n        ftol: Stopping tolerance based on maximum absolute function value\n            over array.\n        maxiter: Maximum number of algorithm iterations.\n        full_output: If ``False``, return just the root, otherwise return a\n            tuple `(x, info)` where `x` is the root and `info` is a dict\n            containing algorithm status information.\n        range_check: If ``True``, check to ensure that the initial\n            `[a, b]` range brackets the root of `f`.\n\n    Returns:\n        tuple: A tuple `(x, info)` containing:\n\n            - **x** : Root array.\n            - **info**: Dictionary containing diagnostic information.\n    \"\"\"\n\n    fa = f(*((a,) + args))\n    fb = f(*((b,) + args))\n    if range_check and snp.any(snp.sign(fa) == snp.sign(fb)):\n        raise ValueError(\"Initial bisection range does not bracket zero.\")\n\n    for numiter in range(maxiter):\n        c = (a + b) / 2.0\n        fc = f(*((c,) + args))\n        fcs = snp.sign(fc)\n        a = snp.where(snp.logical_or(snp.sign(fa) * fcs == 1, fc == 0.0), c, a)\n        b = snp.where(snp.logical_or(fcs * snp.sign(fb) == 1, fc == 0.0), c, b)\n        fa = f(*((a,) + args))\n        fb = f(*((b,) + args))\n        xerr = snp.max(snp.abs(b - a))\n        ferr = snp.max(snp.abs(fc))\n        if xerr <= xtol and ferr <= ftol:\n            break\n\n    idx = 
snp.argmin(snp.stack((snp.abs(fa), snp.abs(fb))), axis=0)\n    x = snp.choose(idx, (a, b))\n    if full_output:\n        r = x, {\"iter\": numiter, \"xerr\": xerr, \"ferr\": ferr, \"a\": a, \"b\": b}\n    else:\n        r = x\n    return r\n\n\ndef golden(\n    f: Callable,\n    a: Array,\n    b: Array,\n    c: Optional[Array] = None,\n    args: Tuple = (),\n    xtol: float = 1e-7,\n    maxiter: int = 100,\n    full_output: bool = False,\n) -> Union[Array, dict]:\n    \"\"\"Vectorised scalar minimization via golden section method.\n\n    Vectorised scalar minimization via golden section method, supporting\n    simultaneous minimization of a function defined over a\n    multi-dimensional array. When the function is array-valued, each of\n    these values is treated as the independent application of a scalar\n    function. The minimizer must lie within the interval `(a, b)` for all\n    scalar functions, and, if specified, `c` must be within that interval.\n\n    The interface is more similar to that of :func:`.bisect` than that of\n    :func:`scipy.optimize.golden`, which is much faster when `f` is a\n    scalar function and `a`, `b`, and `c` are scalars.\n\n    Args:\n        f: Function returning a float or an array of floats.\n        a: Lower bound of interval on which to search.\n        b: Upper bound of interval on which to search.\n        c: Initial value for first search point interior to bounding\n            interval `(a, b)`.\n        args: Additional arguments for function `f`.\n        xtol: Stopping tolerance based on maximum search interval length\n            over array.\n        maxiter: Maximum number of algorithm iterations.\n        full_output: If ``False``, return just the minimizer, otherwise\n            return a tuple `(x, info)` where `x` is the minimizer and\n            `info` is a dict containing algorithm status information.\n\n    Returns:\n        tuple: A tuple `(x, info)` containing:\n\n            - **x** : Minimizer array.\n        
    - **info**: Dictionary containing diagnostic information.\n    \"\"\"\n    gr = 2 / (snp.sqrt(5) + 1)\n    if c is None:\n        c = b - gr * (b - a)\n    d = a + gr * (b - a)\n    for numiter in range(maxiter):\n        fc = f(*((c,) + args))\n        fd = f(*((d,) + args))\n        b = snp.where(fc < fd, d, b)\n        a = snp.where(fc >= fd, c, a)\n        xerr = snp.amax(snp.abs(b - a))\n        if xerr <= xtol:\n            break\n        c = b - gr * (b - a)\n        d = a + gr * (b - a)\n\n    fa = f(*((a,) + args))\n    fb = f(*((b,) + args))\n    idx = snp.argmin(snp.stack((fa, fb)), axis=0)\n    x = snp.choose(idx, (a, b))\n    if full_output:\n        r = (x, {\"iter\": numiter, \"xerr\": xerr})\n    else:\n        r = x\n    return r\n\n\nclass MatrixATADSolver:\n    r\"\"\"Solver for linear system involving a symmetric product.\n\n    Solve a linear system of the form\n\n    .. math::\n\n       (A^T W A + D) \\mb{x} = \\mb{b}\n\n    or\n\n    .. math::\n\n       (A^T W A + D) X = B \\;,\n\n    where :math:`A \\in \\mbb{R}^{M \\times N}`,\n    :math:`W \\in \\mbb{R}^{M \\times M}` and\n    :math:`D \\in \\mbb{R}^{N \\times N}`. :math:`A` must be an instance of\n    :class:`.MatrixOperator` or an array; :math:`D` must be an instance\n    of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and\n    :math:`W`, if specified, must be an instance of :class:`.Diagonal`\n    or an array.\n\n\n    The solution is computed by factorization of matrix\n    :math:`A^T W A + D` and solution via Gaussian elimination. If\n    :math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is\n    smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized\n    and the original problem is solved via the Woodbury matrix identity\n\n    .. math::\n\n       (E + U C V)^{-1} = E^{-1} - E^{-1} U (C^{-1} + V E^{-1} U)^{-1}\n       V E^{-1} \\;.\n\n    Setting\n\n    .. 
math::\n\n       E &= D \\\\\n       U &= A^T \\\\\n       C &= W \\\\\n       V &= A\n\n    we have\n\n    .. math::\n\n       (D + A^T W A)^{-1} = D^{-1} - D^{-1} A^T (W^{-1} + A D^{-1} A^T)^{-1} A\n       D^{-1}\n\n    which can be simplified to\n\n    .. math::\n\n       (D + A^T W A)^{-1} = D^{-1} (I - A^T G^{-1} A D^{-1})\n\n    by defining :math:`G = W^{-1} + A D^{-1} A^T`. We therefore have that\n\n    .. math::\n\n       \\mb{x} = (D + A^T W A)^{-1} \\mb{b} = D^{-1} (I - A^T G^{-1} A\n       D^{-1}) \\mb{b} \\;.\n\n    If we have a Cholesky factorization of :math:`G`, e.g.\n    :math:`G = L L^T`, we can define\n\n    .. math::\n\n       \\mb{w} = G^{-1} A D^{-1} \\mb{b}\n\n    so that\n\n    .. math::\n\n       G \\mb{w} &= A D^{-1} \\mb{b} \\\\\n       L L^T \\mb{w} &= A D^{-1} \\mb{b} \\;.\n\n    The Cholesky factorization can be exploited by solving for\n    :math:`\\mb{z}` in\n\n    .. math::\n\n       L \\mb{z} = A D^{-1} \\mb{b}\n\n    and then for :math:`\\mb{w}` in\n\n    .. math::\n\n       L^T \\mb{w} = \\mb{z} \\;,\n\n    so that\n\n    .. math::\n\n       \\mb{x} = D^{-1} \\mb{b} - D^{-1} A^T \\mb{w} \\;.\n\n    (Functions :func:`~jax.scipy.linalg.cho_solve` and\n    :func:`~jax.scipy.linalg.lu_solve` allow direct solution for\n    :math:`\\mb{w}` without the two-step procedure described here.) A\n    Cholesky factorization should only be used when :math:`G` is\n    positive-definite (e.g. 
:math:`D` is diagonal and positive); if not,\n    an LU factorization should be used.\n\n    Complex-valued problems are also supported, in which case the\n    transpose :math:`\\cdot^T` in the equations above should be taken to\n    represent the conjugate transpose.\n\n    To solve problems directly involving a matrix of the form\n    :math:`A W A^T + D`, initialize with :code:`A.T` (or\n    :code:`A.T.conj()` for complex problems) instead of :code:`A`.\n    \"\"\"\n\n    def __init__(\n        self,\n        A: Union[MatrixOperator, Array],\n        D: Union[MatrixOperator, Diagonal, Array],\n        W: Optional[Union[Diagonal, Array]] = None,\n        cho_factor: bool = False,\n        lower: bool = False,\n        check_finite: bool = True,\n    ):\n        r\"\"\"\n        Args:\n            A: Matrix :math:`A`.\n            D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`,\n                specifies the 2D matrix :math:`D`. If 1D array or\n                :class:`Diagonal`, specifies the diagonal elements\n                of :math:`D`.\n            W: Matrix :math:`W`. Specifies the diagonal elements of\n                :math:`W`. 
Defaults to an array with unit entries.\n            cho_factor: Flag indicating whether to use Cholesky\n                (``True``) or LU (``False``) factorization.\n            lower: Flag indicating whether lower (``True``) or upper\n                (``False``) triangular factorization should be computed.\n                Only relevant to Cholesky factorization.\n            check_finite: Flag indicating whether the input array should\n                be checked for ``Inf`` and ``NaN`` values.\n        \"\"\"\n        A = jnp.array(A)\n\n        if isinstance(D, Diagonal):\n            D = D.diagonal\n            if D.ndim > 1:  # Identity operator has 0D diagonal\n                raise ValueError(\"If Diagonal, 'D' should have a 0D or 1D diagonal.\")\n        else:\n            D = jnp.array(D)\n            if not D.ndim in [1, 2]:\n                raise ValueError(\"If array or MatrixOperator, 'D' should be 1D or 2D.\")\n\n        if W is None:\n            W = snp.ones(A.shape[0], dtype=A.dtype)\n        elif isinstance(W, Diagonal):\n            W = W.diagonal\n            assert hasattr(W, \"ndim\")\n            if W.ndim > 1:  # Identity operator has 0D diagonal\n                raise ValueError(\"If Diagonal, 'W' should have a 0D or 1D diagonal.\")\n        elif not isinstance(W, Array):\n            raise TypeError(\n                f\"Operator 'W' is required to be None, a Diagonal, or an array; got a {type(W)}.\"\n            )\n\n        self.A = A\n        self.D = D\n        self.W = W\n        self.cho_factor = cho_factor\n        self.lower = lower\n        self.check_finite = check_finite\n\n        assert isinstance(W, Array)\n        N, M = A.shape\n        if N < M and D.ndim <= 1:\n            D2 = D if D.ndim == 0 else D[:, snp.newaxis]\n            if W.ndim == 1:\n                G = snp.diag(1.0 / W) + A @ (A.T.conj() / D2)\n            else:  # W is 0 dimensional (scalar equivalent)\n                G = A @ (A.T.conj() / D2)\n            
    G = jnp.fill_diagonal(G, G.diagonal() + (1.0 / W), inplace=False)\n        else:\n            W2 = W if W.ndim == 0 else W[:, snp.newaxis]\n            if D.ndim == 1:\n                G = A.T.conj() @ (W2 * A) + snp.diag(D)\n            else:\n                G = A.T.conj() @ (W2 * A) + D\n\n        if cho_factor:\n            c, lower = jsl.cho_factor(G, lower=lower, check_finite=check_finite)\n            self.factor = (c, lower)\n        else:\n            lu, piv = jsl.lu_factor(G, check_finite=check_finite)\n            self.factor = (lu, piv)\n\n    def solve(self, b: Array, check_finite: Optional[bool] = None) -> Array:\n        r\"\"\"Solve the linear system.\n\n        Solve the linear system with right hand side :math:`\\mb{b}` (`b`\n        is a vector) or :math:`B` (`b` is a 2d array).\n\n        Args:\n           b: Vector :math:`\\mathbf{b}` or matrix :math:`B`.\n           check_finite: Flag indicating whether the input array should\n               be checked for ``Inf`` and ``NaN`` values. 
If ``None``,\n               use the value selected on initialization.\n\n        Returns:\n          Solution to the linear system.\n        \"\"\"\n        if check_finite is None:\n            check_finite = self.check_finite\n        if self.cho_factor:\n            fact_solve = lambda x: jsl.cho_solve(self.factor, x, check_finite=check_finite)\n        else:\n            fact_solve = lambda x: jsl.lu_solve(self.factor, x, trans=0, check_finite=check_finite)\n\n        if b.ndim <= 1:\n            D = self.D\n        else:\n            D = self.D[:, snp.newaxis]\n        N, M = self.A.shape\n        if N < M and self.D.ndim <= 1:\n            w = fact_solve(self.A @ (b / D))\n            x = (b - (self.A.T.conj() @ w)) / D\n        else:\n            x = fact_solve(b)\n\n        return x\n\n    def accuracy(self, x: Array, b: Array) -> float:\n        r\"\"\"Compute solution relative residual.\n\n        Args:\n           x: Array :math:`\\mathbf{x}` (solution).\n           b: Array :math:`\\mathbf{b}` (right hand side of linear system).\n\n        Returns:\n           Relative residual of solution.\n        \"\"\"\n        if b.ndim == 1:\n            D = self.D\n        else:\n            D = self.D[:, snp.newaxis]\n        assert isinstance(self.W, Array)\n        return rel_res(self.A.T.conj() @ (self.W[:, snp.newaxis] * self.A) @ x + D * x, b)\n\n\nclass ConvATADSolver:\n    r\"\"\"Solver for a linear system involving a sum of convolutions.\n\n    Solve a linear system of the form\n\n    .. math::\n\n       (A^H A + D) \\mb{x} = \\mb{b}\n\n    where :math:`A` is a block-row operator with circulant blocks, i.e. it\n    can be written as\n\n    .. math::\n\n       A = \\left( \\begin{array}{cccc} A_1 & A_2 & \\ldots & A_{K}\n           \\end{array} \\right) \\;,\n\n    where all of the :math:`A_k` are circular convolution operators, and\n    :math:`D` is a circular convolution operator. 
This problem is most\n    easily solved in the DFT transform domain, where the circular\n    convolutions become diagonal operators. Denoting the frequency-domain\n    versions of variables with a circumflex (e.g. :math:`\\hat{\\mb{x}}` is\n    the frequency-domain version of :math:`\\mb{x}`), the problem can\n    be written as\n\n    .. math::\n\n       (\\hat{A}^H \\hat{A} + \\hat{D}) \\hat{\\mb{x}} = \\hat{\\mb{b}} \\;,\n\n    where\n\n    .. math::\n\n       \\hat{A} = \\left( \\begin{array}{cccc} \\hat{A}_1 & \\hat{A}_2 &\n       \\ldots & \\hat{A}_{K} \\end{array} \\right) \\;,\n\n    and :math:`\\hat{D}` and all the :math:`\\hat{A}_k` are diagonal\n    operators.\n\n    This linear equation is computationally expensive to solve because\n    the left hand side includes the term :math:`\\hat{A}^H \\hat{A}`,\n    which corresponds to the outer product of :math:`\\hat{A}^H`\n    and :math:`\\hat{A}`. A computationally efficient solution is possible,\n    however, by exploiting the Woodbury matrix identity\n    :cite:`wohlberg-2014-efficient`\n\n    .. math::\n\n       (B + U C V)^{-1} = B^{-1} - B^{-1} U (C^{-1} + V B^{-1} U)^{-1}\n       V B^{-1} \\;.\n\n    Setting\n\n    .. math::\n\n       B &= \\hat{D} \\\\\n       U &= \\hat{A}^H \\\\\n       C &= I \\\\\n       V &= \\hat{A}\n\n    we have\n\n    .. math::\n\n       (\\hat{D} + \\hat{A}^H \\hat{A})^{-1} = \\hat{D}^{-1} - \\hat{D}^{-1}\n       \\hat{A}^H (I + \\hat{A} \\hat{D}^{-1} \\hat{A}^H)^{-1} \\hat{A}\n       \\hat{D}^{-1}\n\n    which can be simplified to\n\n    .. math::\n\n       (\\hat{D} + \\hat{A}^H \\hat{A})^{-1} = \\hat{D}^{-1} (I - \\hat{A}^H\n       \\hat{E}^{-1} \\hat{A} \\hat{D}^{-1})\n\n    by defining :math:`\\hat{E} = I + \\hat{A} \\hat{D}^{-1} \\hat{A}^H`. 
The\n    right hand side is much cheaper to compute because the only matrix\n    inversions involve :math:`\\hat{D}`, which is diagonal, and\n    :math:`\\hat{E}`, which is a weighted inner product of\n    :math:`\\hat{A}^H` and :math:`\\hat{A}`.\n    \"\"\"\n\n    def __init__(self, A: ComposedLinearOperator, D: CircularConvolve):\n        r\"\"\"\n        Args:\n            A: Operator :math:`A`.\n            D: Operator :math:`D`.\n        \"\"\"\n        if not isinstance(A, ComposedLinearOperator):\n            raise TypeError(\n                f\"Operator 'A' is required to be a ComposedLinearOperator; got a {type(A)}.\"\n            )\n        if not isinstance(A.A, Sum) or not isinstance(A.B, CircularConvolve):\n            raise TypeError(\n                \"Operator 'A' is required to be a composition of Sum and CircularConvolve\"\n                f\"linear operators; got a composition of {type(A.A)} and {type(A.B)}.\"\n            )\n\n        self.A = A\n        self.D = D\n        self.sum_axis = A.A.kwargs[\"axis\"]\n        if not isinstance(self.sum_axis, int):\n            raise ValueError(\n                \"Sum component of operator 'A' must sum over a single axis of its input.\"\n            )\n        self.fft_axes = A.B.x_fft_axes\n        self.real_result = is_real_dtype(D.input_dtype)\n\n        Ahat = A.B.h_dft\n        Dhat = D.h_dft\n        self.AHEinv = Ahat.conj() / (\n            1.0 + snp.sum(Ahat * (Ahat.conj() / Dhat), axis=self.sum_axis, keepdims=True)\n        )\n\n    def solve(self, b: Array) -> Array:\n        r\"\"\"Solve the linear system.\n\n        Solve the linear system with right hand side :math:`\\mb{b}`.\n\n        Args:\n           b: Array :math:`\\mathbf{b}`.\n\n        Returns:\n          Solution to the linear system.\n        \"\"\"\n        assert isinstance(self.A.B, CircularConvolve)\n\n        Ahat = self.A.B.h_dft\n        Dhat = self.D.h_dft\n        bhat = snp.fft.fftn(b, axes=self.fft_axes)\n        xhat 
= (\n            bhat - (self.AHEinv * (snp.sum(Ahat * bhat / Dhat, axis=self.sum_axis, keepdims=True)))\n        ) / Dhat\n        x = snp.fft.ifftn(xhat, axes=self.fft_axes)\n        if self.real_result:\n            x = x.real\n\n        return x\n\n    def accuracy(self, x: Array, b: Array) -> float:\n        r\"\"\"Compute solution relative residual.\n\n        Args:\n           x: Array :math:`\\mathbf{x}` (solution).\n           b: Array :math:`\\mathbf{b}` (right hand side of linear system).\n\n        Returns:\n           Relative residual of solution.\n        \"\"\"\n        return rel_res(self.A.gram_op(x) + self.D(x), b)\n"
  },
  {
    "path": "scico/test/conftest.py",
    "content": "\"\"\"\nConfigure the --level pytest option and its functionality.\n\"\"\"\n\nimport pytest\n\n\ndef pytest_addoption(parser, pluginmanager):\n    \"\"\"Add --level pytest option.\n\n    Level definitions:\n      1  Critical tests only\n      2  Skip tests that have a significant impact on coverage\n      3  All standard tests\n      4  Run all tests, including those marked as slow to run\n    \"\"\"\n    parser.addoption(\n        \"--level\", action=\"store\", default=3, type=int, help=\"Set test level to be run\"\n    )\n\n\ndef pytest_configure(config):\n    \"\"\"Add marker description.\"\"\"\n    config.addinivalue_line(\"markers\", \"slow: mark test as slow to run\")\n\n\ndef pytest_collection_modifyitems(config, items):\n    \"\"\"Skip slow tests depending on selected testing level.\"\"\"\n    if config.getoption(\"--level\") >= 4:\n        # don't skip tests at level 4 or higher\n        return\n    level_skip = pytest.mark.skip(reason=\"test not appropriate for selected level\")\n    for item in items:\n        if \"slow\" in item.keywords:\n            item.add_marker(level_skip)\n"
  },
  {
    "path": "scico/test/flax/test_apply.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport jax\n\nimport pytest\nfrom test_trainer import SetupTest\n\nfrom flax.traverse_util import flatten_dict\nfrom scico import flax as sflax\nfrom scico.flax.train.apply import apply_fn\nfrom scico.flax.train.checkpoints import checkpoint_save, have_orbax\nfrom scico.flax.train.input_pipeline import IterateData\nfrom scico.flax.train.learning_rate import create_cnst_lr_schedule\nfrom scico.flax.train.state import create_basic_train_state\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\ndef test_apply_fn(testobj):\n    key = jax.random.key(seed=531)\n    key1, key2 = jax.random.split(key)\n\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    variables = model.init({\"params\": key1}, np.ones(input_shape, model.dtype))\n\n    ds = IterateData(testobj.test_ds, testobj.bsize, train=False)\n\n    try:\n        batch = next(ds)\n        output = apply_fn(model, variables, batch)\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert output.shape[1:] == testobj.test_ds[\"label\"].shape[1:]\n\n\ndef test_except_only_apply(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    with pytest.raises(RuntimeError):\n        out_ = sflax.only_apply(\n            testobj.train_conf,\n            model,\n            testobj.test_ds,\n        )\n\n\n@pytest.mark.parametrize(\"model_cls\", [sflax.DnCNNNet, sflax.ResNet, sflax.ConvBNNet, sflax.UNet])\ndef test_eval(testobj, model_cls):\n    depth = testobj.model_conf[\"depth\"]\n    model = model_cls(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n    if isinstance(model, sflax.DnCNNNet):\n        depth = 3\n        model = sflax.DnCNNNet(depth, testobj.chn, 
testobj.model_conf[\"num_filters\"])\n\n    key = jax.random.key(123)\n    variables = model.init(key, testobj.train_ds[\"image\"])\n\n    # from train script\n    out_, _ = sflax.only_apply(\n        testobj.train_conf,\n        model,\n        testobj.test_ds,\n        variables=variables,\n    )\n    # from scico FlaxMap util\n    fmap = sflax.FlaxMap(model, variables)\n    out_fmap = fmap(testobj.test_ds[\"image\"])\n\n    np.testing.assert_allclose(out_, out_fmap, atol=5e-6)\n\n\n@pytest.mark.skipif(not have_orbax, reason=\"orbax.checkpoint package not installed\")\ndef test_apply_from_checkpoint(testobj):\n    depth = 3\n    model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n\n    key = jax.random.key(123)\n    variables = model.init(key, testobj.train_ds[\"image\"])\n\n    temp_dir = tempfile.TemporaryDirectory()\n    workdir = os.path.join(temp_dir.name, \"temp_ckp\")\n\n    # State initialization\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(\n        key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate\n    )\n    flat_params1 = flatten_dict(state.params)\n    flat_bstats1 = flatten_dict(state.batch_stats)\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"checkpointing\"] = True\n    train_conf[\"workdir\"] = workdir\n    checkpoint_save(state, train_conf, workdir)\n\n    try:\n        output, variables = sflax.only_apply(\n            train_conf,\n            model,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        flat_params2 = flatten_dict(variables[\"params\"])\n        flat_bstats2 = flatten_dict(variables[\"batch_stats\"])\n        params2 = [t[1] for t in sorted(flat_params2.items())]\n        bstats2 = [t[1] for t in 
sorted(flat_bstats2.items())]\n\n        for i in range(len(params1)):\n            np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n        for i in range(len(bstats1)):\n            np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n"
  },
  {
    "path": "scico/test/flax/test_checkpoints.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport jax\n\nimport pytest\nfrom test_trainer import SetupTest\n\nfrom flax.traverse_util import flatten_dict\nfrom scico import flax as sflax\nfrom scico.flax.train.checkpoints import checkpoint_restore, checkpoint_save, have_orbax\nfrom scico.flax.train.learning_rate import create_cnst_lr_schedule\nfrom scico.flax.train.state import create_basic_train_state\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\n@pytest.mark.skipif(not have_orbax, reason=\"orbax.checkpoint package not installed\")\ndef test_checkpoint(testobj):\n    depth = 3\n    model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n\n    key = jax.random.key(123)\n    variables = model.init(key, testobj.train_ds[\"image\"])\n\n    temp_dir = tempfile.TemporaryDirectory()\n    workdir = os.path.join(temp_dir.name, \"temp_ckp\")\n\n    # State initialization\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(\n        key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate\n    )\n    flat_params1 = flatten_dict(state.params)\n    flat_bstats1 = flatten_dict(state.batch_stats)\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    try:\n        checkpoint_save(state, testobj.train_conf, workdir)\n        state_in = checkpoint_restore(state, workdir)\n\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n\n        flat_params2 = flatten_dict(state_in.params)\n        flat_bstats2 = flatten_dict(state_in.batch_stats)\n        params2 = [t[1] for t in sorted(flat_params2.items())]\n        bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n\n        for i in range(len(params1)):\n            np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n        for i in range(len(bstats1)):\n            
np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\n@pytest.mark.skipif(not have_orbax, reason=\"orbax.checkpoint package not installed\")\n@pytest.mark.parametrize(\"model_cls\", [sflax.DnCNNNet, sflax.ResNet])\ndef test_checkpointing_from_trainer(testobj, model_cls):\n    depth = 3\n    model = model_cls(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n\n    temp_dir = tempfile.TemporaryDirectory()\n    workdir = os.path.join(temp_dir.name, \"temp_ckp\")\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"checkpointing\"] = True\n    train_conf[\"workdir\"] = workdir\n    train_conf[\"return_state\"] = True\n\n    # Create training object\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n    try:\n        state_out, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        # Model parameters from training\n        flat_params1 = flatten_dict(state_out.params)\n        params1 = [t[1] for t in sorted(flat_params1.items())]\n\n        # Model parameteres from checkpoint\n        state_in = checkpoint_restore(state_out, workdir)\n        flat_params2 = flatten_dict(state_in.params)\n        params2 = [t[1] for t in sorted(flat_params2.items())]\n\n        for i in range(len(params1)):\n            np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n\n        if hasattr(state_out, \"batch_stats\"):\n            # Batch stats from training\n            flat_bstats1 = flatten_dict(state_out.batch_stats)\n            bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n            # Batch stats from checkpoint\n            flat_bstats2 = flatten_dict(state_in.batch_stats)\n            bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n            for i in range(len(bstats1)):\n                np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\n@pytest.mark.skipif(not 
have_orbax, reason=\"orbax.checkpoint package not installed\")\ndef test_checkpoint_exception(testobj):\n    depth = 3\n    model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n\n    key = jax.random.key(123)\n    variables = model.init(key, testobj.train_ds[\"image\"])\n\n    temp_dir = tempfile.TemporaryDirectory()\n    workdir = os.path.join(temp_dir.name, \"temp_ckp\")\n\n    # State initialization\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(\n        key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate\n    )\n\n    with pytest.raises(FileNotFoundError):\n        state_in = checkpoint_restore(state, workdir)\n"
  },
  {
    "path": "scico/test/flax/test_clu.py",
    "content": "import numpy as np\n\nimport jax\n\nfrom flax.linen import Conv\nfrom flax.linen.module import Module, compact\nfrom scico import flax as sflax\nfrom scico.flax.train.clu_utils import (\n    _default_table_value_formatter,\n    get_parameter_overview,\n)\n\n\ndef test_count_parameters():\n    N = 128  # signal size\n    chn = 1  # number of channels\n\n    # Model configuration\n    mconf = {\n        \"depth\": 2,\n        \"num_filters\": 16,\n    }\n\n    model = sflax.ResNet(mconf[\"depth\"], chn, mconf[\"num_filters\"])\n\n    key = jax.random.key(seed=1234)\n    input_shape = (1, N, N, chn)\n    variables = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n\n    filter_sz = model.kernel_size[0] * model.kernel_size[1]\n    # filter parameters output layer\n    sum_manual_params = filter_sz * mconf[\"num_filters\"] * chn\n    # bias and scale of batch normalization output layer\n    sum_manual_params += chn * 2\n    # mean and bar of batch normalization output layer\n    sum_manual_bst = chn * 2\n    chn_prev = 1\n    for i in range(mconf[\"depth\"] - 1):\n        # filter parameters\n        sum_manual_params += filter_sz * mconf[\"num_filters\"] * chn_prev\n        # bias and scale of batch normalization\n        sum_manual_params += mconf[\"num_filters\"] * 2\n        # mean and bar of batch normalization\n        sum_manual_bst += mconf[\"num_filters\"] * 2\n        chn_prev = mconf[\"num_filters\"]\n\n    total_nvar_params = sflax.count_parameters(variables[\"params\"])\n    total_nvar_bst = sflax.count_parameters(variables[\"batch_stats\"])\n\n    assert total_nvar_params == sum_manual_params\n    assert total_nvar_bst == sum_manual_bst\n\n\ndef test_count_parameters_empty():\n    assert sflax.count_parameters({}) == 0\n\n\n# From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py\nEMPTY_PARAMETER_OVERVIEW = \"\"\"+------+-------+------+------+-----+\n| Name | Shape | Size | Mean | Std 
|\n+------+-------+------+------+-----+\n+------+-------+------+------+-----+\nTotal weights: 0\"\"\"\n\nFLAX_CONV2D_PARAMETER_OVERVIEW = \"\"\"+-------------+--------------+------+\n| Name        | Shape        | Size |\n+-------------+--------------+------+\n| conv/bias   | (2,)         | 2    |\n| conv/kernel | (3, 3, 3, 2) | 54   |\n+-------------+--------------+------+\nTotal weights: 56\"\"\"\n\nFLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = \"\"\"+-------------+--------------+------+------+-----+\n| Name        | Shape        | Size | Mean | Std |\n+-------------+--------------+------+------+-----+\n| conv/bias   | (2,)         | 2    | 1.0  | 0.0 |\n| conv/kernel | (3, 3, 3, 2) | 54   | 1.0  | 0.0 |\n+-------------+--------------+------+------+-----+\nTotal weights: 56\"\"\"\n\nFLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = \"\"\"+--------------------+--------------+------+------+-----+\n| Name               | Shape        | Size | Mean | Std |\n+--------------------+--------------+------+------+-----+\n| params/conv/bias   | (2,)         | 2    | 1.0  | 0.0 |\n| params/conv/kernel | (3, 3, 3, 2) | 54   | 1.0  | 0.0 |\n+--------------------+--------------+------+------+-----+\nTotal weights: 56\"\"\"\n\n\n# From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py\ndef test_get_parameter_overview_empty():\n    assert get_parameter_overview({}) == EMPTY_PARAMETER_OVERVIEW\n\n\nclass CNN(Module):\n    @compact\n    def __call__(self, x):\n        return Conv(features=2, kernel_size=(3, 3), name=\"conv\")(x)\n\n\n# From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py\ndef test_get_parameter_overview():\n    rng = jax.random.key(42)\n    # Weights of a 2D convolution with 2 filters..\n    variables = CNN().init(rng, np.zeros((2, 5, 5, 3)))\n    variables = jax.tree_util.tree_map(jax.numpy.ones_like, variables)\n    assert (\n        get_parameter_overview(variables[\"params\"], 
include_stats=False)\n        == FLAX_CONV2D_PARAMETER_OVERVIEW\n    )\n    assert get_parameter_overview(variables[\"params\"]) == FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS\n    assert get_parameter_overview(variables) == FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS\n\n\n# From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py\ndef test_printing_bool():\n    assert _default_table_value_formatter(True) == \"True\"\n    assert _default_table_value_formatter(False) == \"False\"\n"
  },
  {
    "path": "scico/test/flax/test_examples_flax.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\nfrom scico import random\nfrom scico.flax.examples.data_generation import (\n    distributed_data_generation,\n    generate_blur_data,\n    generate_ct_data,\n    generate_foam1_images,\n    generate_foam2_images,\n    have_ray,\n    have_xdesign,\n)\nfrom scico.flax.examples.data_preprocessing import (\n    CenterCrop,\n    PaddedCircularConvolve,\n    PositionalCrop,\n    RandomNoise,\n    build_image_dataset,\n    flip,\n    preprocess_images,\n    rotation90,\n)\nfrom scico.flax.examples.examples import (\n    get_cache_path,\n    runtime_error_array,\n    runtime_error_scalar,\n)\nfrom scico.flax.examples.typed_dict import ConfigImageSetDict\nfrom scico.typing import Shape\n\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\n# These tests are for the scico.flax.examples module, NOT the example scripts\n\n\n@pytest.mark.skipif(not have_xdesign, reason=\"xdesign package not installed\")\ndef test_foam1_gen():\n    seed = 4444\n    N = 32\n    ndata = 2\n\n    dt = generate_foam1_images(seed, N, ndata)\n    assert dt.shape == (ndata, N, N, 1)\n\n\n@pytest.mark.skipif(not have_xdesign, reason=\"xdesign package not installed\")\ndef test_foam2_gen():\n    seed = 4321\n    N = 32\n    ndata = 2\n\n    dt = generate_foam2_images(seed, N, ndata)\n    assert dt.shape == (ndata, N, N, 1)\n\n\n@pytest.mark.skipif(not have_ray, reason=\"ray package not installed\")\ndef test_distdatagen():\n    N = 16\n    nimg = 8\n\n    def random_data_gen(seed, N, ndata):\n        np.random.seed(seed)\n        dt = np.random.randn(ndata, N, N, 1)\n        return dt\n\n    dt = distributed_data_generation(random_data_gen, N, nimg)\n    assert dt.ndim == 4\n    assert dt.shape == (nimg, N, N, 1)\n\n\n@pytest.mark.skipif(\n    not have_ray or not have_xdesign,\n    reason=\"ray or xdesign package not installed\",\n)\ndef test_ct_data_generation():\n    N = 32\n    nimg = 8\n    
nproj = 45\n\n    def random_img_gen(seed, size, ndata):\n        np.random.seed(seed)\n        shape = (ndata, size, size, 1)\n        return np.random.randn(*shape)\n\n    img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen)\n    assert img.shape == (nimg, N, N, 1)\n    assert sino.shape == (nimg, nproj, sino.shape[2], 1)\n    assert fbp.shape == (nimg, N, N, 1)\n\n\n@pytest.mark.skipif(not have_ray or not have_xdesign, reason=\"ray or xdesign package not installed\")\ndef test_blur_data_generation():\n    N = 32\n    nimg = 8\n    n = 3  # convolution kernel size\n    blur_kernel = np.ones((n, n)) / (n * n)\n\n    def random_img_gen(seed, size, ndata):\n        np.random.seed(seed)\n        shape = (ndata, size, size, 1)\n        return np.random.randn(*shape)\n\n    img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen)\n    assert img.shape == (nimg, N, N, 1)\n    assert blurn.shape == (nimg, N, N, 1)\n\n\ndef test_rotation90():\n    N = 128\n    x, key = random.randn((N, N), seed=4321)\n    x2, key = random.randn((10, N, N, 1), key=key)\n    x_rot = rotation90(x)\n    x2_rot = rotation90(x2)\n\n    np.testing.assert_allclose(x_rot, np.swapaxes(x, 0, 1), rtol=1e-5)\n    np.testing.assert_allclose(x2_rot, np.swapaxes(x2, 1, 2), rtol=1e-5)\n\n\ndef test_flip():\n    N = 128\n    x, key = random.randn((N, N), seed=4321)\n    x2, key = random.randn((10, N, N, 1), key=key)\n    x_flip = flip(x)\n    x2_flip = flip(x2)\n\n    np.testing.assert_allclose(x_flip, x[:, ::-1, ...], rtol=1e-5)\n    np.testing.assert_allclose(x2_flip, x2[..., ::-1, :], rtol=1e-5)\n\n\n@pytest.mark.parametrize(\"output_size\", [128, (128, 128), (128, 64)])\ndef test_center_crop(output_size):\n    N = 256\n    x, key = random.randn((N, N), seed=4321)\n    if isinstance(output_size, int):\n        ccrop = CenterCrop(output_size)\n    else:\n        shp: Shape = output_size\n        ccrop = CenterCrop(shp)\n\n    x_crop = ccrop(x)\n 
   if isinstance(output_size, int):\n        assert x_crop.shape[0] == output_size\n        assert x_crop.shape[1] == output_size\n    else:\n        assert x_crop.shape == output_size\n\n\n@pytest.mark.parametrize(\"output_size\", [128, (128, 128), (128, 64)])\ndef test_positional_crop(output_size):\n    N = 256\n    x, key = random.randn((N, N), seed=4321)\n    top, key = random.randint(shape=(1,), minval=0, maxval=N - 128, key=key)\n    left, key = random.randint(shape=(1,), minval=0, maxval=N - 128, key=key)\n    pcrop = PositionalCrop(output_size)\n\n    x_crop = pcrop(x, top[0], left[0])\n    if isinstance(output_size, int):\n        assert x_crop.shape[0] == output_size\n        assert x_crop.shape[1] == output_size\n    else:\n        assert x_crop.shape == output_size\n\n\n@pytest.mark.parametrize(\"range_flag\", [False, True])\ndef test_random_noise1(range_flag):\n    N = 128\n    x, key = random.randn((N, N), seed=4321)\n    noise = RandomNoise(0.1, range_flag)\n    xn = noise(x)\n    x2, key = random.randn((10, N, N, 1), key=key)\n    xn2 = noise(x2)\n\n    assert x.shape == xn.shape\n    assert x2.shape == xn2.shape\n\n\n@pytest.mark.parametrize(\"shape\", [(128, 128), (128, 128, 3), (5, 128, 128, 1)])\ndef test_random_noise2(shape):\n    x, key = random.randn(shape, seed=4321)\n    noise = RandomNoise(0.1, True)\n    xn = noise(x)\n\n    assert x.shape == xn.shape\n\n\n@pytest.mark.parametrize(\"output_size\", [64, (64, 64)])\n@pytest.mark.parametrize(\"gray_flag\", [False, True])\n@pytest.mark.parametrize(\"num_img_req\", [None, 4])\ndef test_preprocess_images(output_size, gray_flag, num_img_req):\n\n    num_img = 10\n    N = 128\n    C = 3\n    shape = (num_img, N, N, C)\n    images, key = random.randn(shape, seed=4444)\n\n    stride = 1\n    try:\n        output = preprocess_images(\n            images, output_size, gray_flag, num_img_req, multi_flag=False, stride=stride\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n   
 else:\n        assert output.shape[1] == 64\n        assert output.shape[2] == 64\n\n        if gray_flag:\n            assert output.shape[-1] == 1\n        else:\n            assert output.shape[-1] == C\n\n        if num_img_req is None:\n            assert output.shape[0] == num_img\n        else:\n            assert output.shape[0] == num_img_req\n\n\ndef test_preprocess_images_multi_flag():\n    num_img = 10\n    N = 128\n    C = 3\n    shape = (num_img, N, N, C)\n    images, key = random.randn(shape, seed=4444)\n\n    output_size = (64, 64)\n    gray_flag = True\n    num_img_req = 4\n\n    stride = 64  # 2 per side = 4 patches per image\n    try:\n        output = preprocess_images(\n            images, output_size, gray_flag, num_img_req, multi_flag=True, stride=stride\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert output.shape[0] == (4 * num_img_req)\n        assert output.shape[1] == 64\n        assert output.shape[2] == 64\n        assert output.shape[-1] == 1\n\n\nclass SetupTest:\n    def __init__(self):\n        # Data configuration\n        self.dtconf: ConfigImageSetDict = {\n            \"seed\": 0,\n            \"output_size\": 64,\n            \"stride\": 1,\n            \"multi\": False,\n            \"augment\": False,\n            \"run_gray\": True,\n            \"num_img\": 10,\n            \"test_num_img\": 4,\n            \"data_mode\": \"dn\",\n            \"noise_level\": 0.01,\n            \"noise_range\": False,\n            \"test_split\": 0.1,\n        }\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\n@pytest.mark.parametrize(\"augment\", [False, True])\ndef test_build_image_dataset(testobj, augment):\n    num_train = testobj.dtconf[\"num_img\"]\n    num_test = testobj.dtconf[\"test_num_img\"]\n    N = 128\n    C = 3\n    shape = (num_train, N, N, C)\n    img_train, key = random.randn(shape, seed=4444)\n    img_test, key = 
random.randn((num_test, N, N, C), key=key)\n\n    dtconf = dict(testobj.dtconf)\n    dtconf[\"augment\"] = augment\n\n    train_ds, test_ds = build_image_dataset(img_train, img_test, dtconf)\n    assert train_ds[\"image\"].shape == train_ds[\"label\"].shape\n    assert test_ds[\"image\"].shape == test_ds[\"label\"].shape\n    assert test_ds[\"label\"].shape[0] == num_test\n    if augment:\n        assert train_ds[\"label\"].shape[0] == num_train * 3\n    else:\n        assert train_ds[\"label\"].shape[0] == num_train\n\n\ndef test_padded_circular_convolve():\n    N = 64\n    C = 3\n    kernel_size = 5\n    blur_sigma = 2.1\n\n    x, key = random.randn((N, N, C), seed=2468)\n\n    pcc_op = PaddedCircularConvolve(N, C, kernel_size, blur_sigma)\n    xblur = pcc_op(x)\n    assert xblur.shape == x.shape\n\n\ndef test_runtime_error_scalar():\n    with pytest.raises(RuntimeError):\n        runtime_error_scalar(\"channels\", \"testing \", 3, 1)\n\n\ndef test_runtime_error_array():\n    with pytest.raises(RuntimeError):\n        runtime_error_array(\"channels\", \"testing \", 1e-2)\n\n\ndef test_default_cache_path():\n    try:\n        cache_path, cache_path_display = get_cache_path()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        cache_path_display == \"~/.cache/scico/examples/data\"\n\n\ndef test_cache_path():\n    try:\n        temp_dir = tempfile.TemporaryDirectory()\n        cache_path = os.path.join(temp_dir.name, \".cache\")\n        cache_path_, cache_path_display = get_cache_path(cache_path)\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        cache_path_ == cache_path\n        cache_path_display == cache_path\n"
  },
  {
    "path": "scico/test/flax/test_flax.py",
    "content": "import os\nimport tempfile\nfrom functools import partial\n\nimport numpy as np\n\nimport pytest\n\nfrom flax.core import unfreeze\nfrom flax.errors import ScopeParamShapeError\nfrom flax.linen import BatchNorm, Conv, elu, leaky_relu, max_pool, relu\nfrom scico import flax as sflax\nfrom scico.data import _flax_data_path\nfrom scico.random import randn\n\n\nclass TestSet:\n    def test_convnblock_default(self):\n        nflt = 16  # number of filters\n        conv = partial(Conv, dtype=np.float32)\n        norm = partial(BatchNorm, dtype=np.float32)\n        flxm = sflax.blocks.ConvBNBlock(\n            num_filters=nflt,\n            conv=conv,\n            norm=norm,\n            act=relu,\n        )\n        assert flxm.kernel_size == (3, 3)  # size of kernel\n        assert flxm.strides == (1, 1)  # stride of convolution\n\n    def test_convnblock_args(self):\n        nflt = 16  # number of filters\n        ksz = (5, 5)  # size of kernel\n        strd = (2, 2)  # stride of convolution\n        conv = partial(Conv, dtype=np.float32)\n        norm = partial(BatchNorm, dtype=np.float32)\n        flxm = sflax.blocks.ConvBNBlock(\n            num_filters=nflt,\n            conv=conv,\n            norm=norm,\n            act=leaky_relu,\n            kernel_size=ksz,\n            strides=strd,\n        )\n        assert flxm.act == leaky_relu\n        assert flxm.kernel_size == ksz  # size of kernel\n        assert flxm.strides == strd  # stride of convolution\n\n    def test_convblock_default(self):\n        nflt = 16  # number of filters\n        conv = partial(Conv, dtype=np.float32)\n        flxm = sflax.blocks.ConvBlock(\n            num_filters=nflt,\n            conv=conv,\n            act=relu,\n        )\n        assert flxm.kernel_size == (3, 3)  # size of kernel\n        assert flxm.strides == (1, 1)  # stride of convolution\n\n    def test_convblock_args(self):\n        nflt = 16  # number of filters\n        ksz = (5, 5)  # size of kernel\n 
       strd = (2, 2)  # stride of convolution\n        conv = partial(Conv, dtype=np.float32)\n        flxm = sflax.blocks.ConvBlock(\n            num_filters=nflt,\n            conv=conv,\n            act=elu,\n            kernel_size=ksz,\n            strides=strd,\n        )\n        assert flxm.act == elu\n        assert flxm.kernel_size == ksz  # size of kernel\n        assert flxm.strides == strd  # stride of convolution\n\n    def test_convblock_call(self):\n        nflt = 16  # number of filters\n        ksz = (5, 5)  # size of kernel\n        strd = (2, 2)  # stride of convolution\n        conv = partial(Conv, dtype=np.float32)\n        flxb = sflax.blocks.ConvBlock(\n            num_filters=nflt,\n            conv=conv,\n            act=elu,\n            kernel_size=ksz,\n            strides=strd,\n        )\n        chn = 1  # number of channels\n        N = 128  # image size\n        x, key = randn((10, N, N, chn), seed=1234)\n        variables = flxb.init(key, x)\n        # Test for the construction / forward pass.\n        cbx = flxb.apply(variables, x)\n        assert x.dtype == cbx.dtype\n\n    def test_convnpblock_args(self):\n        nflt = 16  # number of filters\n        ksz = (5, 5)  # size of kernel\n        strd = (2, 2)  # stride of convolution\n        wnd = (2, 2)  # window for pooling\n        conv = partial(Conv, dtype=np.float32)\n        norm = partial(BatchNorm, dtype=np.float32)\n        flxm = sflax.blocks.ConvBNPoolBlock(\n            num_filters=nflt,\n            conv=conv,\n            norm=norm,\n            act=relu,\n            pool=max_pool,\n            kernel_size=ksz,\n            strides=strd,\n            window_shape=wnd,\n        )\n        assert flxm.act == relu\n        assert flxm.kernel_size == ksz  # size of kernel\n        assert flxm.strides == strd  # stride of convolution\n\n    def test_convnublock_args(self):\n        nflt = 16  # number of filters\n        ksz = (5, 5)  # size of kernel\n        strd = 
(2, 2)  # stride of convolution\n        upsampling = 2  # upsampling factor\n        conv = partial(Conv, dtype=np.float32)\n        norm = partial(BatchNorm, dtype=np.float32)\n        upfn = partial(sflax.blocks.upscale_nn, scale=upsampling)\n        flxm = sflax.blocks.ConvBNUpsampleBlock(\n            num_filters=nflt,\n            conv=conv,\n            norm=norm,\n            act=relu,\n            upfn=upfn,\n            kernel_size=ksz,\n            strides=strd,\n        )\n        assert flxm.act == relu\n        assert flxm.kernel_size == ksz  # size of kernel\n        assert flxm.strides == strd  # stride of convolution\n\n    def test_convmnblock_default(self):\n        nblck = 2  # number of blocks\n        nflt = 16  # number of filters\n        conv = partial(Conv, dtype=np.float32)\n        norm = partial(BatchNorm, dtype=np.float32)\n        flxm = sflax.blocks.ConvBNMultiBlock(\n            num_blocks=nblck,\n            num_filters=nflt,\n            conv=conv,\n            norm=norm,\n            act=relu,\n        )\n        assert flxm.kernel_size == (3, 3)  # size of kernel\n        assert flxm.strides == (1, 1)  # stride of convolution\n\n    def test_upscale(self):\n        N = 128  # image size\n        chn = 3  # channels\n        x, key = randn((10, N, N, chn), seed=1234)\n\n        xups = sflax.blocks.upscale_nn(x)\n        assert xups.shape == (10, 2 * N, 2 * N, chn)\n\n    def test_resnet_default(self):\n        depth = 3  # depth of model\n        chn = 1  # number of channels\n        num_filters = 16  # number of filters per layer\n        N = 128  # image size\n        x, key = randn((10, N, N, chn), seed=1234)\n        resnet = sflax.ResNet(\n            depth=depth,\n            channels=chn,\n            num_filters=num_filters,\n        )\n        variables = resnet.init(key, x)\n        # Test for the construction / forward pass.\n        rnx = resnet.apply(variables, x, train=False, mutable=False)\n        assert x.dtype 
== rnx.dtype\n\n    def test_unet_default(self):\n        depth = 2  # depth of model\n        chn = 1  # number of channels\n        num_filters = 16  # number of filters per layer\n        N = 128  # image size\n        x, key = randn((10, N, N, chn), seed=1234)\n        unet = sflax.UNet(\n            depth=depth,\n            channels=chn,\n            num_filters=num_filters,\n        )\n        variables = unet.init(key, x)\n        # Test for the construction / forward pass.\n        unx = unet.apply(variables, x, train=False, mutable=False)\n        assert x.dtype == unx.dtype\n\n\nclass DnCNNNetTest:\n    def __init__(self):\n        depth = 3  # depth of model\n        chn = 1  # number of channels\n        num_filters = 16  # number of filters per layer\n        N = 128  # image size\n        self.x, key = randn((10, N, N, chn), seed=1234)\n        self.dncnn = sflax.DnCNNNet(\n            depth=depth,\n            channels=chn,\n            num_filters=num_filters,\n        )\n        self.variables = self.dncnn.init(key, self.x)\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield DnCNNNetTest()\n\n\ndef test_DnCNN_call(testobj):\n    # Test for the construction / forward pass.\n    dnx = testobj.dncnn.apply(testobj.variables, testobj.x, train=False, mutable=False)\n    assert testobj.x.dtype == dnx.dtype\n\n\ndef test_DnCNN_train(testobj):\n    # Test effect of training flag.\n    bn0bias_before = testobj.variables[\"params\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"bias\"]\n    bn0mean_before = testobj.variables[\"batch_stats\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"mean\"]\n    dnx, new_state = testobj.dncnn.apply(\n        testobj.variables, testobj.x, train=True, mutable=[\"batch_stats\"]\n    )\n    bn0mean_new = new_state[\"batch_stats\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"mean\"]\n    bn0bias_after = testobj.variables[\"params\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"bias\"]\n    bn0mean_after = 
testobj.variables[\"batch_stats\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"mean\"]\n    try:\n        np.testing.assert_allclose(bn0bias_before, bn0bias_after, rtol=1e-5)\n        np.testing.assert_allclose(\n            bn0mean_new - bn0mean_before, bn0mean_new + bn0mean_after, rtol=1e-5\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_DnCNN_test(testobj):\n    # Test effect of training flag.\n    bn0var_before = testobj.variables[\"batch_stats\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"var\"]\n    dnx, new_state = testobj.dncnn.apply(\n        testobj.variables, testobj.x, train=False, mutable=[\"batch_stats\"]\n    )\n    bn0var_after = new_state[\"batch_stats\"][\"ConvBNBlock_0\"][\"BatchNorm_0\"][\"var\"]\n    np.testing.assert_allclose(bn0var_before, bn0var_after, rtol=1e-5)\n\n\ndef test_FlaxMap_call(testobj):\n    # Test for the usage of flax model as a map.\n    # 2D evaluation signal.\n    fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables)\n    N = 128  # image size\n    x, key = randn((N, N))\n    out = fmap(x)\n    assert x.dtype == out.dtype\n    assert x.ndim == out.ndim\n\n\ndef test_FlaxMap_3D_call(testobj):\n    # Test for the usage of flax model as a map.\n    # 3D evaluation signal.\n    fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables)\n    N = 128  # image size\n    chn = 1  # channels\n    x, key = randn((N, N, chn))\n    out = fmap(x)\n    assert x.dtype == out.dtype\n    assert x.ndim == out.ndim\n\n\ndef test_FlaxMap_batch_call(testobj):\n    # Test for the usage of flax model as a map.\n    # 4D evaluation signal.\n    fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables)\n    N = 128  # image size\n    chn = 1  # channels\n    batch = 8  # batch size\n    x, key = randn((batch, N, N, chn))\n    out = fmap(x)\n    assert x.dtype == out.dtype\n    assert x.ndim == out.ndim\n\n\ndef test_FlaxMap_blockarray_exception(testobj):\n\n    from scico.numpy import BlockArray\n\n    fmap = 
sflax.FlaxMap(testobj.dncnn, testobj.variables)\n\n    x0, key = randn(shape=(3, 4), seed=4321)\n    x1, key = randn(shape=(4, 5, 6), key=key)\n    x = BlockArray((x0, x1))\n\n    with pytest.raises(NotImplementedError):\n        fmap(x)\n\n\n@pytest.mark.parametrize(\"variant\", [\"6L\", \"6M\", \"6H\", \"17L\", \"17M\", \"17H\"])\ndef test_variable_load(variant):\n    N = 128  # image size\n    chn = 1  # channels\n    x, key = randn((10, N, N, chn), seed=1234)\n\n    if variant[0] == \"6\":\n        nlayer = 6\n    else:\n        nlayer = 17\n\n    model = sflax.DnCNNNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32)\n    # Load weights for DnCNN.\n    variables = sflax.load_variables(_flax_data_path(\"dncnn%s.mpk\" % variant))\n\n    try:\n        fmap = sflax.FlaxMap(model, variables)\n        out = fmap(x)\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_variable_load_mismatch():\n    N = 128  # image size\n    chn = 1  # channels\n    x, key = randn((10, N, N, chn), seed=1234)\n\n    nlayer = 6\n    model = sflax.ResNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32)\n    # Load weights for DnCNN.\n    variables = sflax.load_variables(_flax_data_path(\"dncnn6L.mpk\"))\n\n    # created with mismatched parameters\n    fmap = sflax.FlaxMap(model, variables)\n    with pytest.raises(ScopeParamShapeError):\n        fmap(x)\n\n\ndef test_variable_save():\n    N = 128  # image size\n    chn = 1  # channels\n    x, key = randn((10, N, N, chn), seed=1234)\n\n    nlayer = 6\n    model = sflax.ResNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32)\n\n    aux, key = randn((1,), seed=23432)\n    input_shape = (1, N, N, chn)\n    variables = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n\n    try:\n        temp_dir = tempfile.TemporaryDirectory()\n        sflax.save_variables(unfreeze(variables), os.path.join(temp_dir.name, \"vres6.mpk\"))\n    except Exception as e:\n        
print(e)\n        assert 0\n"
  },
  {
    "path": "scico/test/flax/test_inv.py",
    "content": "import os\nfrom functools import partial\n\nimport numpy as np\n\nimport jax.numpy as jnp\nfrom jax import lax\n\nfrom scico import flax as sflax\nfrom scico import random\nfrom scico.flax.examples import PaddedCircularConvolve, build_blur_kernel\nfrom scico.flax.train.traversals import clip_positive, clip_range, construct_traversal\nfrom scico.linop import CircularConvolve, Identity\nfrom scico.linop.xray import XRayTransform2D\n\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\n\nclass TestSet:\n    def setup_method(self, method):\n        self.depth = 1  # depth (equivalent to number of blocks) of model\n        self.chn = 1  # number of channels\n        self.num_filters = 16  # number of filters per layer\n        self.block_depth = 2  # number of layers in block\n        self.N = 128  # image size\n\n    def test_odpdn_default(self):\n        y, key = random.randn((10, self.N, self.N, self.chn), seed=1234)\n\n        opI = Identity(y.shape)\n\n        odpdn = sflax.ODPNet(\n            operator=opI,\n            depth=self.depth,\n            channels=self.chn,\n            num_filters=self.num_filters,\n            block_depth=self.block_depth,\n        )\n\n        variables = odpdn.init(key, y)\n        # Test for the construction / forward pass.\n        mny = odpdn.apply(variables, y, train=False, mutable=False)\n        assert y.dtype == mny.dtype\n        assert y.shape == mny.shape\n\n    def test_odpdcnv_default(self):\n        y, key = random.randn((10, self.N, self.N, self.chn), seed=1234)\n\n        blur_shape = (9, 9)\n        blur_sigma = 2.24\n        kernel = build_blur_kernel(blur_shape, blur_sigma)\n\n        ishape = (self.N, self.N)\n        opBlur = CircularConvolve(h=kernel, input_shape=ishape)\n\n        odpdb = sflax.ODPNet(\n            operator=opBlur,\n            depth=self.depth,\n            channels=self.chn,\n            num_filters=self.num_filters,\n            
block_depth=self.block_depth,\n            odp_block=sflax.inverse.ODPProxDcnvBlock,\n        )\n\n        variables = odpdb.init(key, y)\n        # Test for the construction / forward pass.\n        mny = odpdb.apply(variables, y, train=False, mutable=False)\n        assert y.dtype == mny.dtype\n        assert y.shape == mny.shape\n\n    def test_odpdcnv_padded(self):\n        y, key = random.randn((10, self.N, self.N, self.chn), seed=1234)\n\n        blur_shape = (9, 9)\n        blur_sigma = 2.24\n        opBlur = PaddedCircularConvolve(self.N, self.chn, blur_shape, blur_sigma)\n\n        odpdb = sflax.ODPNet(\n            operator=opBlur,\n            depth=self.depth,\n            channels=self.chn,\n            num_filters=self.num_filters,\n            block_depth=self.block_depth,\n            odp_block=sflax.inverse.ODPProxDcnvBlock,\n        )\n\n        variables = odpdb.init(key, y)\n        # Test for the construction / forward pass.\n        mny = odpdb.apply(variables, y, train=False, mutable=False)\n        assert y.dtype == mny.dtype\n        assert y.shape == mny.shape\n\n    def test_train_odpdcnv_default(self):\n        xt, key = random.randn((10, self.N, self.N, self.chn), seed=4444)\n\n        blur_shape = (7, 7)\n        blur_sigma = 3.3\n        kernel = build_blur_kernel(blur_shape, blur_sigma)\n\n        ishape = (self.N, self.N)\n        opBlur = CircularConvolve(h=kernel, input_shape=ishape)\n\n        model = sflax.ODPNet(\n            operator=opBlur,\n            depth=self.depth,\n            channels=self.chn,\n            num_filters=self.num_filters,\n            block_depth=self.block_depth,\n            odp_block=sflax.inverse.ODPProxDcnvBlock,\n        )\n\n        train_conf: sflax.ConfigDict = {\n            \"seed\": 0,\n            \"opt_type\": \"ADAM\",\n            \"batch_size\": 8,\n            \"num_epochs\": 2,\n            \"base_learning_rate\": 1e-3,\n            \"warmup_epochs\": 0,\n            
\"num_train_steps\": -1,\n            \"steps_per_eval\": -1,\n            \"steps_per_epoch\": 1,\n            \"log_every_steps\": 1000,\n        }\n\n        a_f = lambda v: jnp.atleast_3d(opBlur(v.reshape(opBlur.input_shape)))\n        y = lax.map(a_f, xt)\n\n        train_ds = {\"image\": y, \"label\": xt}\n        test_ds = {\"image\": y, \"label\": xt}\n\n        try:\n            alphatrav = construct_traversal(\"alpha\")\n            alphapos = partial(clip_positive, traversal=alphatrav, minval=1e-3)\n            train_conf[\"post_lst\"] = [alphapos]\n            trainer = sflax.BasicFlaxTrainer(\n                train_conf,\n                model,\n                train_ds,\n                test_ds,\n            )\n            modvar, _ = trainer.train()\n        except Exception as e:\n            print(e)\n            assert 0\n        else:\n            alphaval = np.array([alpha for alpha in alphatrav.iterate(modvar[\"params\"])])\n            np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval)\n\n\nclass TestCT:\n    def setup_method(self, method):\n        self.N = 32  # signal size\n        self.chn = 1  # number of channels\n        self.bsize = 16  # batch size\n        xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321)\n\n        self.nproj = 60  # number of projections\n        angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32)\n        self.opCT = XRayTransform2D(\n            input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0)\n        )  # Radon transform operator\n        a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze()))\n        y = lax.map(a_f, xt)\n\n        self.train_ds = {\"image\": y, \"label\": xt}\n        self.test_ds = {\"image\": y, \"label\": xt}\n\n        # Model configuration\n        self.model_conf = {\n            \"depth\": 1,\n            \"num_filters\": 16,\n            \"block_depth\": 2,\n        }\n\n    
    # Training configuration\n        self.train_conf: sflax.ConfigDict = {\n            \"seed\": 0,\n            \"opt_type\": \"ADAM\",\n            \"batch_size\": self.bsize,\n            \"num_epochs\": 2,\n            \"base_learning_rate\": 1e-3,\n            \"warmup_epochs\": 0,\n            \"num_train_steps\": -1,\n            \"steps_per_eval\": -1,\n            \"steps_per_epoch\": 1,\n            \"log_every_steps\": 1000,\n        }\n\n    def test_odpct_default(self):\n        y, key = random.randn((10, self.nproj, self.N, self.chn), seed=1234)\n\n        model = sflax.ODPNet(\n            operator=self.opCT,\n            depth=self.model_conf[\"depth\"],\n            channels=self.chn,\n            num_filters=self.model_conf[\"num_filters\"],\n            block_depth=self.model_conf[\"block_depth\"],\n            odp_block=sflax.inverse.ODPGrDescBlock,\n        )\n\n        variables = model.init(key, y)\n        # Test for the construction / forward pass.\n        oy = model.apply(variables, y, train=False, mutable=False)\n        assert y.dtype == oy.dtype\n\n    def test_modlct_default(self):\n        y, key = random.randn((10, self.nproj, self.N, self.chn), seed=1234)\n\n        model = sflax.MoDLNet(\n            operator=self.opCT,\n            depth=self.model_conf[\"depth\"],\n            channels=self.chn,\n            num_filters=self.model_conf[\"num_filters\"],\n            block_depth=self.model_conf[\"block_depth\"],\n        )\n\n        variables = model.init(key, y)\n        # Test for the construction / forward pass.\n        mny = model.apply(variables, y, train=False, mutable=False)\n        assert y.dtype == mny.dtype\n\n    def test_train_modl(self):\n        model = sflax.MoDLNet(\n            operator=self.opCT,\n            depth=self.model_conf[\"depth\"],\n            channels=self.chn,\n            num_filters=self.model_conf[\"num_filters\"],\n            block_depth=self.model_conf[\"block_depth\"],\n        )\n      
  try:\n            minval = 1.1e-2\n            lmbdatrav = construct_traversal(\"lmbda\")\n            lmbdapos = partial(\n                clip_positive,\n                traversal=lmbdatrav,\n                minval=minval,\n            )\n            train_conf = dict(self.train_conf)\n            train_conf[\"post_lst\"] = [lmbdapos]\n            trainer = sflax.BasicFlaxTrainer(\n                train_conf,\n                model,\n                self.train_ds,\n                self.test_ds,\n            )\n            modvar, _ = trainer.train()\n        except Exception as e:\n            print(e)\n            assert 0\n        else:\n            lmbdaval = np.array([lmb for lmb in lmbdatrav.iterate(modvar[\"params\"])])\n            np.testing.assert_array_less(1e-2 * np.ones(lmbdaval.shape), lmbdaval)\n\n    def test_train_odpct(self):\n        model = sflax.ODPNet(\n            operator=self.opCT,\n            depth=self.model_conf[\"depth\"],\n            channels=self.chn,\n            num_filters=self.model_conf[\"num_filters\"],\n            block_depth=self.model_conf[\"block_depth\"],\n            odp_block=sflax.inverse.ODPGrDescBlock,\n        )\n\n        try:\n            minval = 1.1e-2\n            maxval = 1e2\n            alphatrav = construct_traversal(\"alpha\")\n            alpharange = partial(clip_range, traversal=alphatrav, minval=minval, maxval=maxval)\n            train_conf = dict(self.train_conf)\n            train_conf[\"post_lst\"] = [alpharange]\n            trainer = sflax.BasicFlaxTrainer(\n                train_conf,\n                model,\n                self.train_ds,\n                self.test_ds,\n            )\n            modvar, _ = trainer.train()\n        except Exception as e:\n            print(e)\n            assert 0\n        else:\n            alphaval = np.array([alpha for alpha in alphatrav.iterate(modvar[\"params\"])])\n            np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval)\n"
  },
  {
    "path": "scico/test/flax/test_spectral.py",
    "content": "from functools import partial\nfrom typing import Any, Tuple\n\nimport numpy as np\n\nimport jax\n\nimport pytest\n\nfrom flax.linen import Conv\nfrom flax.linen.module import Module, compact\nfrom scico import flax as sflax\nfrom scico import linop\nfrom scico.flax.train.spectral import (\n    _l2_normalize,\n    conv,\n    estimate_spectral_norm,\n    exact_spectral_norm,\n    spectral_normalization_conv,\n)\nfrom scico.flax.train.traversals import construct_traversal\nfrom scico.random import randn\n\n\ndef test_l2_normalize():\n    N = 256\n    x, key = randn((N, N), seed=135)\n\n    eps = 1e-6\n    l2_jnp = jax.numpy.sqrt((x**2).sum())\n    l2n_jnp = x / (l2_jnp + eps)\n    l2n_util = _l2_normalize(x, eps)\n    np.testing.assert_allclose(l2n_jnp, l2n_util, rtol=eps)\n\n\n@pytest.mark.parametrize(\"kernel_size\", [(3, 3, 1, 1), (11, 11, 1, 1)])\ndef test_conv(kernel_size):\n    key = jax.random.key(97531)\n    kernel, key = randn(kernel_size, dtype=np.float32, key=key)\n\n    input_size = (1, 128, 128, 1)\n    x, key = randn(input_size, dtype=np.float32, key=key)\n\n    pads = (\n        [(0, 0)]\n        + [(kernel_size[0] // 2, kernel_size[0] // 2)]\n        + [(kernel_size[1] // 2, kernel_size[1] // 2)]\n        + [(0, 0)]\n    )\n    xext = np.pad(x, pads, mode=\"wrap\")\n\n    y = jax.scipy.signal.convolve(xext.squeeze(), jax.numpy.flip(kernel).squeeze(), mode=\"valid\")\n\n    y_util = conv(x, kernel).squeeze()\n\n    np.testing.assert_allclose(y, y_util)\n\n\nclass CNN(Module):\n    kernel_size: Tuple[int, int]\n    kernel0: Any\n\n    @compact\n    def __call__(self, x):\n        def kinit_wrap(rng, shape, dtype=np.float32):\n            return np.array(self.kernel0, dtype)\n\n        return Conv(\n            features=1,\n            kernel_size=self.kernel_size,\n            use_bias=False,\n            padding=\"CIRCULAR\",\n            kernel_init=kinit_wrap,\n        )(x)\n\n\n@pytest.mark.parametrize(\"kernel_size\", [(3, 3, 1, 1), 
(11, 11, 1, 1)])\ndef test_conv_layer(kernel_size):\n    key = jax.random.key(12345)\n    kernel, key = randn(kernel_size, dtype=np.float32, key=key)\n\n    input_size = (1, 128, 128, 1)\n    x, key = randn(input_size, dtype=np.float32, key=key)\n\n    rng = jax.random.key(42)\n    model = CNN(kernel_size=kernel_size[:2], kernel0=kernel)\n    variables = model.init(rng, np.zeros(x.shape))\n    prms = variables[\"params\"]\n    np.testing.assert_allclose(prms[\"Conv_0\"][\"kernel\"], kernel)\n\n    y_layer = model.apply(variables, x)\n    y_util = conv(x, kernel)\n\n    np.testing.assert_allclose(y_layer, y_util)\n\n\n@pytest.mark.parametrize(\"input_shape\", [(8,), (128,)])\ndef test_spectral_norm(input_shape):\n    key = jax.random.key(1357)\n    diagonal, key = randn(input_shape, dtype=np.float32, key=key)\n\n    mu = np.linalg.norm(np.diag(diagonal), 2)\n\n    D = linop.Diagonal(diagonal=diagonal)\n    x, key = randn(input_shape, dtype=np.float32, key=key)\n    mu_util = estimate_spectral_norm(lambda x: D @ x, x.shape, n_steps=200)\n\n    np.testing.assert_allclose(mu, mu_util, rtol=1e-6)\n\n\n@pytest.mark.parametrize(\"kernel_shape\", [(3, 3, 1, 1), (7, 7, 1, 1)])\ndef test_spectral_norm_conv(kernel_shape):\n\n    key = jax.random.key(2468)\n    kernel, key = randn(kernel_shape, dtype=np.float32, key=key)\n\n    input_shape = (1, 32, 32, 1)\n    x, key = randn(input_shape, dtype=np.float32, key=key)\n\n    sn = exact_spectral_norm(lambda x: conv(x, kernel), x.shape)\n\n    sn_util = estimate_spectral_norm(lambda x: conv(x, kernel), x.shape, n_steps=100)\n\n    np.testing.assert_allclose(sn, sn_util, rtol=1e-3, atol=1e-2)\n\n\ndef test_train_spectral_norm():\n    depth = 3\n    channels = 1\n    num_filters = 16\n    model = sflax.DnCNNNet(depth, channels, num_filters)\n\n    train_conf: sflax.ConfigDict = {\n        \"seed\": 0,\n        \"opt_type\": \"ADAM\",\n        \"batch_size\": 16,\n        \"num_epochs\": 1,\n        \"base_learning_rate\": 1e-3,\n     
   \"lr_decay_rate\": 0.95,\n        \"warmup_epochs\": 0,\n        \"num_train_steps\": -1,\n        \"steps_per_eval\": -1,\n        \"steps_per_epoch\": 1,\n        \"log_every_steps\": 1000,\n    }\n\n    N = 64\n    xtr, key = randn((train_conf[\"batch_size\"], N, N, channels), seed=4321)\n    xtt, key = randn((train_conf[\"batch_size\"], N, N, channels), key=key)\n    train_ds = {\"image\": xtr, \"label\": xtr}\n    test_ds = {\"image\": xtt, \"label\": xtt}\n\n    try:\n        xshape = (1,) + train_ds[\"label\"][0].shape\n        convtrav = construct_traversal(\"kernel\")\n        kernelnorm = partial(\n            spectral_normalization_conv,\n            traversal=convtrav,\n            xshape=xshape,\n        )\n        train_conf[\"post_lst\"] = [kernelnorm]\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            train_ds,\n            test_ds,\n        )\n        modvar, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        knlsn = np.array(\n            [\n                estimate_spectral_norm(\n                    lambda x: conv(x, kernel), (1, xshape[1], xshape[2], kernel.shape[2])\n                )\n                for kernel in convtrav.iterate(modvar[\"params\"])\n            ]\n        )\n        np.testing.assert_array_less(knlsn, np.ones(knlsn.shape))\n"
  },
  {
    "path": "scico/test/flax/test_steps.py",
    "content": "import functools\n\nimport jax\n\nimport pytest\nfrom test_trainer import SetupTest\n\nfrom flax import jax_utils\nfrom scico import flax as sflax\nfrom scico.flax.train.diagnostics import compute_metrics\nfrom scico.flax.train.learning_rate import create_cnst_lr_schedule\nfrom scico.flax.train.losses import mse_loss\nfrom scico.flax.train.state import create_basic_train_state\nfrom scico.flax.train.steps import eval_step, train_step, train_step_post\nfrom scico.flax.train.traversals import clip_range, construct_traversal\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\ndef test_basic_train_step(testobj):\n    key = jax.random.key(seed=531)\n    key1, key2 = jax.random.split(key)\n\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate)\n    criterion = mse_loss\n\n    local_batch_size = testobj.train_conf[\"batch_size\"] // jax.process_count()\n    size_device_prefetch = 2\n    train_dt_iter = sflax.create_input_iter(\n        key2,\n        testobj.train_ds,\n        local_batch_size,\n        size_device_prefetch,\n        model.dtype,\n        train=True,\n    )\n    # Training is configured as parallel operation\n    state = jax_utils.replicate(state)\n    p_train_step = jax.pmap(\n        functools.partial(\n            train_step,\n            learning_rate_fn=learning_rate,\n            criterion=criterion,\n            metrics_fn=compute_metrics,\n        ),\n        axis_name=\"batch\",\n    )\n\n    try:\n        batch = next(train_dt_iter)\n        p_train_step(state, batch)\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_post_train_step(testobj):\n    key = jax.random.key(seed=531)\n   
 key1, key2 = jax.random.split(key)\n\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate)\n    criterion = mse_loss\n\n    local_batch_size = testobj.train_conf[\"batch_size\"] // jax.process_count()\n    size_device_prefetch = 2\n    train_dt_iter = sflax.create_input_iter(\n        key2,\n        testobj.train_ds,\n        local_batch_size,\n        size_device_prefetch,\n        model.dtype,\n        train=True,\n    )\n    # Dum range requirement over kernel parameters\n    ktrav = construct_traversal(\"kernel\")\n    krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1)\n    # Training is configured as parallel operation\n    state = jax_utils.replicate(state)\n    p_train_step = jax.pmap(\n        functools.partial(\n            train_step_post,\n            learning_rate_fn=learning_rate,\n            criterion=criterion,\n            train_step_fn=train_step,\n            metrics_fn=compute_metrics,\n            post_lst=[krange],\n        ),\n        axis_name=\"batch\",\n    )\n\n    try:\n        batch = next(train_dt_iter)\n        p_train_step(state, batch)\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_basic_eval_step(testobj):\n    key = jax.random.key(seed=531)\n    key1, key2 = jax.random.split(key)\n\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate)\n    criterion = mse_loss\n\n    local_batch_size = 
testobj.train_conf[\"batch_size\"] // jax.process_count()\n    size_device_prefetch = 2\n    eval_dt_iter = sflax.create_input_iter(\n        key2,\n        testobj.test_ds,\n        local_batch_size,\n        size_device_prefetch,\n        model.dtype,\n        train=False,\n    )\n    # Evaluation is configured as parallel operation\n    state = jax_utils.replicate(state)\n    p_eval_step = jax.pmap(\n        functools.partial(eval_step, criterion=criterion, metrics_fn=compute_metrics),\n        axis_name=\"batch\",\n    )\n\n    try:\n        batch = next(eval_dt_iter)\n        p_eval_step(state, batch)\n    except Exception as e:\n        print(e)\n        assert 0\n"
  },
  {
    "path": "scico/test/flax/test_train_aux.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\nfrom test_trainer import SetupTest\n\nfrom scico import flax as sflax\nfrom scico import random\nfrom scico.flax.train.clu_utils import flatten_dict\nfrom scico.flax.train.diagnostics import ArgumentStruct, compute_metrics, stats_obj\nfrom scico.flax.train.input_pipeline import IterateData, prepare_data\nfrom scico.flax.train.learning_rate import (\n    create_cnst_lr_schedule,\n    create_cosine_lr_schedule,\n    create_exp_lr_schedule,\n)\nfrom scico.flax.train.losses import mse_loss\nfrom scico.flax.train.state import create_basic_train_state, initialize\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\ndef test_mse_loss():\n    N = 256\n    x, key = random.randn((N, N), seed=4321)\n    y, key = random.randn((N, N), key=key)\n    # Optax uses a 0.5 factor.\n    mse_jnp = 0.5 * jax.numpy.mean((x - y) ** 2)\n    mse_optx = mse_loss(y, x)\n    np.testing.assert_allclose(mse_jnp, mse_optx)\n\n\n@pytest.mark.parametrize(\"batch_size\", [4, 8, 16])\ndef test_data_iterator(testobj, batch_size):\n    ds = IterateData(testobj.test_ds_simple, batch_size, train=False)\n    N = testobj.test_ds_simple[\"image\"].shape[0]\n    assert ds.steps_per_epoch == N // batch_size\n    assert ds.key is not None\n\n\n@pytest.mark.parametrize(\"local_batch\", [8, 16, 24])\ndef test_dstrain(testobj, local_batch):\n\n    key = jax.random.key(seed=1234)\n\n    train_iter = sflax.create_input_iter(\n        key,\n        testobj.train_ds_simple,\n        local_batch,\n    )\n\n    nproc = jax.device_count()\n    ll = []\n    num_steps = 40\n    for step, batch in zip(range(num_steps), train_iter):\n        for j in range(nproc):\n            ll.append(batch[\"image\"][j])\n\n    ll_ = np.array(jax.device_get(ll)).flatten()\n    ll_ar = np.array(list(set(np.sort(ll_))))\n\n    np.testing.assert_allclose(ll_ar, np.arange(80))\n\n\n@pytest.mark.parametrize(\"local_batch\", [8, 16, 32])\ndef 
test_dstest(testobj, local_batch):\n\n    key = jax.random.key(seed=1234)\n\n    train_iter = sflax.create_input_iter(key, testobj.test_ds_simple, local_batch, train=False)\n\n    nproc = jax.device_count()\n    ll = []\n    num_steps = 20\n    for step, batch in zip(range(num_steps), train_iter):\n        for j in range(nproc):\n            ll.append(batch[\"image\"][j])\n\n    ll_ = np.array(jax.device_get(ll)).flatten()\n    ll_ar = np.array(list(set(np.sort(ll_))))\n\n    np.testing.assert_allclose(ll_ar, np.arange(80, 112))\n\n\ndef test_prepare_data(testobj):\n    xbtch = prepare_data(testobj.x)\n    local_device_count = jax.local_device_count()\n    shrdsz = testobj.x.shape[0] // local_device_count\n    assert xbtch.shape == (local_device_count, shrdsz, testobj.N, testobj.N, testobj.chn)\n\n\ndef test_compute_metrics(testobj):\n    xbtch = prepare_data(testobj.x)\n\n    xbtch = xbtch / jax.numpy.sqrt(jax.numpy.var(xbtch, axis=(1, 2, 3, 4)))\n    ybtch = xbtch + 1\n\n    p_eval = jax.pmap(compute_metrics, axis_name=\"batch\")\n    eval_metrics = p_eval(ybtch, xbtch)\n    mtrcs = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)\n    assert np.abs(mtrcs[\"loss\"]) < 0.51\n    assert mtrcs[\"snr\"] < 5e-4\n\n\ndef test_cnst_learning_rate(testobj):\n    step = 1\n    cnst_sch = create_cnst_lr_schedule(testobj.train_conf)\n    lr = cnst_sch(step)\n    assert lr == testobj.train_conf[\"base_learning_rate\"]\n\n\ndef test_cos_learning_rate(testobj):\n    step = 1\n    len_train = testobj.train_ds[\"label\"].shape[0]\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"steps_per_epoch\"] = len_train // testobj.train_conf[\"batch_size\"]\n    sch = create_cosine_lr_schedule(train_conf)\n    lr = sch(step)\n    decay_steps = (train_conf[\"num_epochs\"] - train_conf[\"warmup_epochs\"]) * train_conf[\n        \"steps_per_epoch\"\n    ]\n    cosine_decay = 0.5 * (1 + np.cos(np.pi * step / decay_steps))\n    np.testing.assert_allclose(lr, 
train_conf[\"base_learning_rate\"] * cosine_decay, rtol=1e-06)\n\n\ndef test_exp_learning_rate(testobj):\n    step = 1\n    len_train = testobj.train_ds[\"label\"].shape[0]\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"steps_per_epoch\"] = len_train // testobj.train_conf[\"batch_size\"]\n    steps = train_conf[\"steps_per_epoch\"] * train_conf[\"num_epochs\"]\n    sch = create_exp_lr_schedule(train_conf)\n    lr = sch(step)\n    exp_decay = train_conf[\"lr_decay_rate\"] ** float(step / steps)\n\n    np.testing.assert_allclose(lr, train_conf[\"base_learning_rate\"] * exp_decay, rtol=1e-06)\n\n\ndef test_train_initialize_function(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=4444)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n\n    # Via initialize function\n    dparams1, dbstats1 = initialize(key, model, input_shape[1:3])\n    flat_params1 = flatten_dict(dparams1)\n    flat_bstats1 = flatten_dict(dbstats1)\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    # Via model initialization\n    variables2 = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    flat_params2 = flatten_dict(variables2[\"params\"])\n    flat_bstats2 = flatten_dict(variables2[\"batch_stats\"])\n    params2 = [t[1] for t in sorted(flat_params2.items())]\n    bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n\n    for i in range(len(params1)):\n        np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n    for i in range(len(bstats1)):\n        np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\ndef test_create_basic_train_state_default(testobj):\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=432)\n    input_shape = 
(1, testobj.N, testobj.N, testobj.chn)\n\n    # Model initialization\n    variables1 = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    flat_params1 = flatten_dict(variables1[\"params\"])\n    flat_bstats1 = flatten_dict(variables1[\"batch_stats\"])\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n\n    try:\n        # State initialization\n        state = create_basic_train_state(\n            key, testobj.train_conf, model, input_shape[1:3], learning_rate\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        flat_params2 = flatten_dict(state.params)\n        flat_bstats2 = flatten_dict(state.batch_stats)\n        params2 = [t[1] for t in sorted(flat_params2.items())]\n        bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n\n        for i in range(len(params1)):\n            np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n        for i in range(len(bstats1)):\n            np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\ndef test_create_basic_train_state(testobj):\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=432)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n\n    # Model initialization\n    variables1 = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    flat_params1 = flatten_dict(variables1[\"params\"])\n    flat_bstats1 = flatten_dict(variables1[\"batch_stats\"])\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n\n    try:\n        # State initialization\n        state = create_basic_train_state(\n            key, testobj.train_conf, model, 
input_shape[1:3], learning_rate, variables1\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        flat_params2 = flatten_dict(state.params)\n        flat_bstats2 = flatten_dict(state.batch_stats)\n        params2 = [t[1] for t in sorted(flat_params2.items())]\n        bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n\n        for i in range(len(params1)):\n            np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n        for i in range(len(bstats1)):\n            np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\ndef test_sgd_train_state(testobj):\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=432)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n\n    # Model initialization\n    variables = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"opt_type\"] = \"SGD\"\n\n    try:\n        # State initialization\n        state = create_basic_train_state(\n            key, train_conf, model, input_shape[1:3], learning_rate, variables\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_sgd_no_momentum_train_state(testobj):\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=432)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n\n    # Model initialization\n    variables = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"opt_type\"] = \"SGD\"\n    train_conf.pop(\"momentum\")\n\n    try:\n        # State initialization\n        state = 
create_basic_train_state(\n            key, train_conf, model, input_shape[1:3], learning_rate, variables\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_argument_struct():\n    dictaux = {\"epochs\": 5, \"num_steps\": 10, \"seed\": 0}\n    try:\n        dictstruct = ArgumentStruct(**dictaux)\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert hasattr(dictstruct, \"epochs\")\n        assert hasattr(dictstruct, \"num_steps\")\n        assert hasattr(dictstruct, \"seed\")\n\n\ndef test_complete_stats_obj():\n    try:\n        itstat_object, itstat_insert_func = stats_obj()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        summary = {\n            \"epoch\": 3,\n            \"time\": 231.0,\n            \"train_learning_rate\": 1e-2,\n            \"train_loss\": 1.4e-2,\n            \"train_snr\": 3,\n            \"loss\": 1.6e-2,\n            \"snr\": 2.4,\n        }\n        try:\n            itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary)))\n        except Exception as e:\n            print(e)\n            assert 0\n\n\ndef test_except_incomplete_stats_obj():\n\n    itstat_object, itstat_insert_func = stats_obj()\n    summary = {\n        \"epoch\": 3,\n        \"time\": 231.0,\n        \"train_learning_rate\": 1e-2,\n        \"train_loss\": 1.4e-2,\n        \"train_snr\": 3,\n        \"loss\": 1.6e-2,\n        \"snr\": 2.4,\n    }\n    itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary)))\n    summary2 = {\n        \"epoch\": 3,\n        \"time\": 231.0,\n        \"train_learning_rate\": 1e-2,\n        \"train_loss\": 1.4e-2,\n        \"train_snr\": 3,\n    }\n    with pytest.raises(AttributeError):\n        itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary2)))\n\n\ndef test_patch_incomplete_stats_obj():\n\n    itstat_object, itstat_insert_func = stats_obj()\n    summary = {\n        \"epoch\": 3,\n        
\"time\": 231.0,\n        \"train_learning_rate\": 1e-2,\n        \"train_loss\": 1.4e-2,\n        \"train_snr\": 3,\n        \"loss\": 1.6e-2,\n        \"snr\": 2.4,\n    }\n    itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary)))\n    summary2 = {\n        \"epoch\": 3,\n        \"time\": 231.0,\n        \"train_learning_rate\": 1e-2,\n        \"train_loss\": 1.4e-2,\n        \"train_snr\": 3,\n    }\n\n    try:\n        summary2[\"loss\"] = -1\n        summary2[\"snr\"] = -1\n        itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary2)))\n    except Exception as e:\n        print(e)\n        assert 0\n"
  },
  {
    "path": "scico/test/flax/test_trainer.py",
    "content": "import functools\n\nimport numpy as np\n\nimport jax\n\nimport optax\nimport pytest\n\nfrom flax import jax_utils\nfrom scico import flax as sflax\nfrom scico import random\nfrom scico.flax.train.clu_utils import flatten_dict\nfrom scico.flax.train.learning_rate import create_cnst_lr_schedule\nfrom scico.flax.train.state import create_basic_train_state\nfrom scico.flax.train.steps import eval_step, train_step\nfrom scico.flax.train.trainer import sync_batch_stats\nfrom scico.flax.train.traversals import clip_positive, clip_range, construct_traversal\n\n\nclass SetupTest:\n    def __init__(self):\n        datain = np.arange(80)\n        datain_test = np.arange(80, 112)\n        dataout = np.zeros(80)\n        dataout[:40] = 1\n        dataout_test = np.zeros(40)\n        dataout_test[:20] = 1\n\n        self.train_ds_simple = {\"image\": datain, \"label\": dataout}\n        self.test_ds_simple = {\"image\": datain_test, \"label\": dataout_test}\n\n        # More complex data structure\n        self.N = 128  # signal size\n        self.chn = 1  # number of channels\n        self.bsize = 16  # batch size\n        self.x, key = random.randn((4 * self.bsize, self.N, self.N, self.chn), seed=4321)\n\n        xt, key = random.randn((32, self.N, self.N, self.chn), key=key)\n\n        self.train_ds = {\"image\": self.x, \"label\": self.x}\n        self.test_ds = {\"image\": xt, \"label\": xt}\n\n        # Model configuration\n        self.model_conf = {\n            \"depth\": 2,\n            \"num_filters\": 16,\n            \"block_depth\": 2,\n        }\n\n        # Training configuration\n        self.train_conf: sflax.ConfigDict = {\n            \"seed\": 0,\n            \"opt_type\": \"ADAM\",\n            \"momentum\": 0.9,\n            \"batch_size\": self.bsize,\n            \"num_epochs\": 1,\n            \"base_learning_rate\": 1e-3,\n            \"lr_decay_rate\": 0.95,\n            \"warmup_epochs\": 0,\n            \"log_every_steps\": 1000,\n   
     }\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\n@pytest.mark.parametrize(\"opt_type\", [\"SGD\", \"ADAM\", \"ADAMW\"])\ndef test_optimizers(testobj, opt_type):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"opt_type\"] = opt_type\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n        modvar, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_optimizers_exception(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"opt_type\"] = \"\"\n    with pytest.raises(NotImplementedError):\n        sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n\n\ndef test_sync_batch_stats(testobj):\n    key = jax.random.key(seed=12345)\n    key1, key2 = jax.random.split(key)\n\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    learning_rate = create_cnst_lr_schedule(testobj.train_conf)\n    state0 = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate)\n\n    # For parallel training\n    state = jax_utils.replicate(state0)\n    state = sync_batch_stats(state)\n    state1 = jax_utils.unreplicate(state)\n\n    flat_bstats0 = flatten_dict(state0.batch_stats)\n    bstats0 = [t[1] for t in sorted(flat_bstats0.items())]\n\n    flat_bstats1 = flatten_dict(state1.batch_stats)\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    for i in 
range(len(bstats0)):\n        np.testing.assert_allclose(bstats0[i], bstats1[i], rtol=1e-5)\n\n\ndef test_class_train_default_init(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            testobj.train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.itstat_object is None\n\n\ndef test_class_train_default_noseed(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    train_conf = dict(testobj.train_conf)\n    train_conf.pop(\"seed\", None)\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            testobj.train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n\n\ndef test_class_train_nolog(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"log\"] = False\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.itstat_object is None\n\n\ndef test_class_train_required_steps(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    train_conf = dict(testobj.train_conf)\n    train_conf.pop(\"batch_size\", None)\n    train_conf.pop(\"num_epochs\", None)\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n        
    testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        batch_size = 2 * jax.device_count()\n        local_batch_size = batch_size // jax.process_count()\n        num_epochs = 10\n        num_steps = int(trainer.steps_per_epoch * num_epochs)\n        assert trainer.local_batch_size == local_batch_size\n        assert trainer.num_steps == num_steps\n\n\n@pytest.mark.skipif(jax.device_count() == 1, reason=\"single device present\")\ndef test_except_class_train_batch_size(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"batch_size\"] = jax.device_count() + 1\n    with pytest.raises(ValueError):\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n\n\ndef test_class_train_set_steps(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"steps_per_eval\"] = 1\n    train_conf[\"steps_per_checkpoint\"] = 1\n    train_conf[\"log_every_steps\"] = 3\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.steps_per_eval == train_conf[\"steps_per_eval\"]\n        assert trainer.steps_per_checkpoint == train_conf[\"steps_per_checkpoint\"]\n        assert trainer.log_every_steps == train_conf[\"log_every_steps\"]\n\n\ndef test_class_train_set_reporting(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    
train_conf = dict(testobj.train_conf)\n    train_conf[\"log\"] = True\n    train_conf[\"workdir\"] = \"./out/\"\n    train_conf[\"checkpointing\"] = False\n    train_conf[\"return_state\"] = True\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.logflag == train_conf[\"log\"]\n        assert trainer.workdir == train_conf[\"workdir\"]\n        assert trainer.checkpointing == train_conf[\"checkpointing\"]\n        assert trainer.return_state == train_conf[\"return_state\"]\n\n\ndef test_class_train_set_functions(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    def huber_loss(output, labels):\n        return jax.numpy.mean(optax.huber_loss(output, labels))\n\n    # Dum range requirement over kernel parameters\n    ktrav = construct_traversal(\"kernel\")\n    krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1)\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"criterion\"] = huber_loss\n    train_conf[\"create_train_state\"] = create_basic_train_state\n    train_conf[\"train_step_fn\"] = train_step\n    train_conf[\"eval_step_fn\"] = eval_step\n    train_conf[\"post_lst\"] = [krange]\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.criterion == train_conf[\"criterion\"]\n        assert trainer.create_train_state == train_conf[\"create_train_state\"]\n        assert trainer.train_step_fn == train_conf[\"train_step_fn\"]\n        assert trainer.eval_step_fn == train_conf[\"eval_step_fn\"]\n        assert 
trainer.post_lst[0] == train_conf[\"post_lst\"][0]\n        assert hasattr(trainer, \"lr_schedule\")\n\n\ndef test_class_train_set_iterators(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            testobj.train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert trainer.ishape == testobj.train_ds[\"image\"].shape[1:3]\n        assert hasattr(trainer, \"train_dt_iter\")\n        assert hasattr(trainer, \"eval_dt_iter\")\n\n\n@pytest.mark.parametrize(\"postl\", [False, True])\ndef test_class_train_set_parallel(testobj, postl):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n\n    train_conf[\"post_lst\"] = []\n    if postl:\n        # Dum range requirement over kernel parameters\n        ktrav = construct_traversal(\"kernel\")\n        krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1)\n        train_conf[\"post_lst\"] = [krange]\n\n    try:\n        trainer = sflax.BasicFlaxTrainer(\n            train_conf,\n            model,\n            testobj.train_ds,\n            testobj.test_ds,\n        )\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert hasattr(trainer, \"p_train_step\")\n        assert hasattr(trainer, \"p_eval_step\")\n\n\n@pytest.mark.parametrize(\"chkflag\", [False, True])\ndef test_class_train_external_init(testobj, chkflag):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    key = jax.random.key(seed=1234)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n\n    # Via model initialization\n    
variables1 = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n    flat_params1 = flatten_dict(variables1[\"params\"])\n    flat_bstats1 = flatten_dict(variables1[\"batch_stats\"])\n    params1 = [t[1] for t in sorted(flat_params1.items())]\n    bstats1 = [t[1] for t in sorted(flat_bstats1.items())]\n\n    # Via BasicFlaxTrainer object initialization\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"checkpointing\"] = chkflag\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n        variables0=variables1,\n    )\n    flat_params2 = flatten_dict(trainer.state.params)\n    flat_bstats2 = flatten_dict(trainer.state.batch_stats)\n    params2 = [t[1] for t in sorted(flat_params2.items())]\n    bstats2 = [t[1] for t in sorted(flat_bstats2.items())]\n\n    for i in range(len(params1)):\n        np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5)\n    for i in range(len(bstats1)):\n        np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5)\n\n\n@pytest.mark.parametrize(\"model_cls\", [sflax.DnCNNNet, sflax.ResNet, sflax.ConvBNNet, sflax.UNet])\ndef test_class_train_train_loop(testobj, model_cls):\n    depth = testobj.model_conf[\"depth\"]\n    model = model_cls(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n    if isinstance(model, sflax.DnCNNNet):\n        depth = 3\n        model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf[\"num_filters\"])\n\n    # Create training object\n    trainer = sflax.BasicFlaxTrainer(\n        testobj.train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n\n    try:\n        modvar, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert \"params\" in modvar\n        assert \"batch_stats\" in modvar\n\n\ndef test_class_train_train_post_loop(testobj):\n    depth = testobj.model_conf[\"depth\"]\n    model = 
sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n\n    # Dum positive requirement over kernel parameters\n    ktrav = construct_traversal(\"kernel\")\n    kpos = functools.partial(clip_positive, traversal=ktrav, minval=1e-5)\n    train_conf[\"post_lst\"] = [kpos]\n\n    # Create training object\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n\n    try:\n        modvar, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert \"params\" in modvar\n        assert \"batch_stats\" in modvar\n\n\ndef test_class_train_return_state(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"return_state\"] = True\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n    try:\n        state, _ = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert hasattr(state, \"params\")\n        assert hasattr(state, \"batch_stats\")\n\n\ndef test_class_train_update_metrics(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"log\"] = True\n    train_conf[\"log_every_steps\"] = 1\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n    total_steps = (testobj.train_ds[\"label\"].shape[0] // testobj.bsize) * train_conf[\"num_epochs\"]\n    try:\n        state, stats_object = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 
0\n    else:\n        hist = stats_object.history(transpose=True)\n        assert len(hist.Train_Loss) == total_steps\n\n\ndef test_class_train_update_metrics_nolog(testobj):\n    model = sflax.ResNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    train_conf = dict(testobj.train_conf)\n    train_conf[\"log\"] = False\n    train_conf[\"log_every_steps\"] = 1\n    trainer = sflax.BasicFlaxTrainer(\n        train_conf,\n        model,\n        testobj.train_ds,\n        testobj.test_ds,\n    )\n    try:\n        state, stats_object = trainer.train()\n    except Exception as e:\n        print(e)\n        assert 0\n    else:\n        assert stats_object is None\n"
  },
  {
    "path": "scico/test/flax/test_traversal.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\nfrom test_trainer import SetupTest\n\nfrom scico import flax as sflax\nfrom scico.flax.train.traversals import construct_traversal\n\n\n@pytest.fixture(scope=\"module\")\ndef testobj():\n    yield SetupTest()\n\n\n@pytest.mark.parametrize(\"pname\", [\"kernel\", \"bias\", \"scale\"])\ndef test_construct_traversal(testobj, pname):\n    model = sflax.ConvBNNet(\n        testobj.model_conf[\"depth\"], testobj.chn, testobj.model_conf[\"num_filters\"]\n    )\n\n    ndim = 1\n    if pname == \"kernel\":\n        ndim = 4\n\n    key = jax.random.key(seed=432)\n    input_shape = (1, testobj.N, testobj.N, testobj.chn)\n    variables = model.init({\"params\": key}, np.ones(input_shape, model.dtype))\n\n    ptrav = construct_traversal(pname)\n    for pm in ptrav.iterate(variables[\"params\"]):\n        assert len(pm.shape) == ndim\n"
  },
  {
    "path": "scico/test/functional/prox.py",
    "content": "import numpy as np\n\nimport scico.numpy as snp\nfrom scico.solver import minimize\n\n\ndef prox_func(x, v, f, alpha):\n    \"\"\"Evaluate functional of which the proximal operator is the argmin.\"\"\"\n    return 0.5 * snp.sum(snp.abs(x.reshape(v.shape) - v) ** 2) + alpha * snp.array(\n        f(x.reshape(v.shape)), dtype=snp.float64\n    )\n\n\ndef prox_solve(v, v0, f, alpha):\n    \"\"\"Evaluate the alpha-scaled proximal operator of f at v, using v0 as an\n    initial point for the optimization.\"\"\"\n    fnc = lambda x: prox_func(x, v, f, alpha)\n    fmn = minimize(\n        fnc,\n        v0,\n        method=\"Nelder-Mead\",\n        options={\"maxiter\": 1000, \"xatol\": 1e-9, \"fatol\": 1e-9},\n    )\n\n    return fmn.x.reshape(v.shape), fmn.fun\n\n\ndef prox_test(v, nrm, prx, alpha, x0=None, rtol=1e-6):\n    \"\"\"Test the alpha-scaled proximal operator function prx of norm functional nrm\n    at point v.\"\"\"\n    # Evaluate the proximal operator at v\n    px = snp.array(prx(v, alpha, v0=x0))\n    # Proximal operator functional value (i.e. Moreau envelope) at v\n    pf = prox_func(px, v, nrm, alpha)\n    # Brute-force solve of the proximal operator at v\n    mx, mf = prox_solve(v, px, nrm, alpha)\n\n    # Compare prox functional value with brute-force solution\n    if pf < mf:\n        return  # prox gave a lower cost than brute force, so it passes\n\n    np.testing.assert_allclose(pf, mf, rtol=rtol)\n"
  },
  {
    "path": "scico/test/functional/test_composed.py",
    "content": "import numpy as np\n\nfrom jax import config\n\nfrom prox import prox_test\n\nfrom scico import functional, linop\nfrom scico.random import randn\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\n\nclass TestComposed:\n    def setup_method(self):\n        key = None\n        self.shape = (2, 3, 4)\n        self.dtype = np.float32\n        self.x, key = randn(self.shape, key=key, dtype=self.dtype)\n        self.composed = functional.ComposedFunctional(\n            functional.L2Norm(),\n            linop.Reshape(self.x.shape, (2, -1), input_dtype=self.dtype),\n        )\n\n    def test_eval(self):\n        np.testing.assert_allclose(self.composed(self.x), self.composed.functional(self.x))\n\n    def test_prox(self):\n        prox_test(self.x, self.composed.__call__, self.composed.prox, 1.0)\n"
  },
  {
    "path": "scico/test/functional/test_denoiser_func.py",
    "content": "import numpy as np\n\nimport pytest\n\nfrom scico import denoiser, functional\nfrom scico.denoiser import BM3DProfile, BM4DProfile, have_bm3d, have_bm4d\nfrom scico.random import randn\nfrom scico.test.osver import osx_ver_geq_than\n\n\n# bm3d is known to be broken on OSX 11.6.5. It may be broken on earlier versions too,\n# but this has not been confirmed\n@pytest.mark.skipif(osx_ver_geq_than(\"11.6.5\"), reason=\"bm3d broken on this platform\")\n@pytest.mark.skipif(not have_bm3d, reason=\"bm3d package not installed\")\nclass TestBM3D:\n    def setup_method(self):\n        key = None\n        self.x_gry, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_rgb, key = randn((33, 34, 3), key=key, dtype=np.float32)\n        self.profile = BM3DProfile()\n        self.profile.num_threads = 1  # Make processing deterministic\n        self.f_gry = functional.BM3D(profile=self.profile)\n        self.f_rgb = functional.BM3D(is_rgb=True, profile=self.profile)\n\n    def test_gry(self):\n        y0 = self.f_gry.prox(self.x_gry, 1.0)\n        y1 = denoiser.bm3d(self.x_gry, 1.0, profile=self.profile)\n        assert np.linalg.norm(y1 - y0) < 1e-6\n\n    def test_rgb(self):\n        y0 = self.f_rgb.prox(self.x_rgb, 1.0)\n        y1 = denoiser.bm3d(self.x_rgb, 1.0, is_rgb=True, profile=self.profile)\n        assert np.linalg.norm(y1 - y0) < 1e-6\n\n\n# bm4d is known to be broken on OSX 11.6.5. 
It may be broken on earlier versions too,\n# but this has not been confirmed\n@pytest.mark.skipif(osx_ver_geq_than(\"11.6.5\"), reason=\"bm4d broken on this platform\")\n@pytest.mark.skipif(not have_bm4d, reason=\"bm4d package not installed\")\nclass TestBM4D:\n    def setup_method(self):\n        key = None\n        self.x, key = randn((16, 17, 14), key=key, dtype=np.float32)\n        self.profile = BM4DProfile()\n        self.profile.num_threads = 1  # Make processing deterministic\n        self.f = functional.BM4D(profile=self.profile)\n\n    def test(self):\n        y0 = self.f.prox(self.x, 1.0)\n        y1 = denoiser.bm4d(self.x, 1.0, profile=self.profile)\n        assert np.linalg.norm(y1 - y0) < 1e-6\n\n\nclass TestBlindDnCNN:\n    def setup_method(self):\n        key = None\n        self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32)\n        self.dncnn = denoiser.DnCNN(variant=\"6M\")\n        self.f = functional.DnCNN(variant=\"6M\")\n\n    def test_sngchn(self):\n        y0 = self.f.prox(self.x_sngchn, 1.0)\n        y1 = self.dncnn(self.x_sngchn)\n        np.testing.assert_allclose(y0, y1, rtol=1e-5)\n\n    def test_mltchn(self):\n        y0 = self.f.prox(self.x_mltchn, 1.0)\n        y1 = self.dncnn(self.x_mltchn)\n        np.testing.assert_allclose(y0, y1, rtol=1e-5)\n\n\nclass TestNonBlindDnCNN:\n    def setup_method(self):\n        key = None\n        self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32)\n        self.dncnn = denoiser.DnCNN(variant=\"6N\")\n        self.f = functional.DnCNN(variant=\"6N\")\n\n    def test_sngchn(self):\n        y0 = self.f.prox(self.x_sngchn, 1.5)\n        y1 = self.dncnn(self.x_sngchn, 1.5)\n        np.testing.assert_allclose(y0, y1, rtol=1e-5)\n\n    def test_mltchn(self):\n        y0 = self.f.prox(self.x_mltchn, 0.7)\n        y1 = 
self.dncnn(self.x_mltchn, 0.7)\n        np.testing.assert_allclose(y0, y1, rtol=1e-5)\n"
  },
  {
    "path": "scico/test/functional/test_funcional_core.py",
    "content": "import numpy as np\n\nimport jax.numpy as jnp\nfrom jax import config\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\n\nimport pytest\nfrom prox import prox_test\n\nimport scico.numpy as snp\nfrom scico import functional\nfrom scico.random import randn\n\nNO_BLOCK_ARRAY = [\n    functional.L21Norm,\n    functional.L1MinusL2Norm,\n    functional.NuclearNorm,\n    functional.AnisotropicTVNorm,\n    functional.IsotropicTVNorm,\n    functional.TVNorm,\n]\nNO_COMPLEX = [functional.NonNegativeIndicator, functional.BoxIndicator]\n\n\ndef pytest_generate_tests(metafunc):\n    level = int(metafunc.config.getoption(\"--level\"))\n    alpha_range = [1e-2, 1e-1, 1e0, 1e1]\n    dtype_range = [np.float32, np.complex64, np.float64, np.complex128]\n    if level == 2:\n        alpha_range = [1e-2, 1e1]\n        dtype_range = [np.float32, np.complex64, np.float64]\n    elif level < 2:\n        alpha_range = [1e-2]\n        dtype_range = [np.float32, np.complex64]\n    if \"alpha\" in metafunc.fixturenames:\n        metafunc.parametrize(\"alpha\", alpha_range)\n    if \"test_dtype\" in metafunc.fixturenames:\n        metafunc.parametrize(\"test_dtype\", dtype_range)\n\n\nclass ProxTestObj:\n    def __init__(self, dtype):\n        key = None\n        self.v, key = randn(shape=(11, 1), dtype=dtype, key=key, seed=3)\n        self.vb, key = randn(shape=((3, 4), (2,)), dtype=dtype, key=key)\n        self.scalar = np.pi\n        self.vz = snp.zeros((3, 4), dtype=dtype)\n\n\n@pytest.fixture\ndef test_prox_obj(test_dtype):\n    return ProxTestObj(test_dtype)\n\n\nclass SeparableTestObject:\n    def __init__(self, dtype):\n        self.f = functional.L1Norm()\n        self.g = functional.SquaredL2Norm()\n        self.fg = functional.SeparableFunctional([self.f, self.g])\n\n        n = 4\n        m = 6\n        key = None\n\n        self.v1, key = randn((n,), key=key, dtype=dtype)  # point for prox eval\n        self.v2, key = 
randn((m,), key=key, dtype=dtype)  # point for prox eval\n        self.vb = snp.blockarray([self.v1, self.v2])\n\n\n@pytest.fixture\ndef test_separable_obj(test_dtype):\n    return SeparableTestObject(test_dtype)\n\n\ndef test_separable_eval(test_separable_obj):\n    fv1 = test_separable_obj.f(test_separable_obj.v1)\n    gv2 = test_separable_obj.g(test_separable_obj.v2)\n    fgv = test_separable_obj.fg(test_separable_obj.vb)\n    np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2)\n\n\ndef test_separable_prox(test_separable_obj):\n    alpha = 0.1\n    fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha)\n    gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha)\n    fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha)\n    out = snp.blockarray((fv1, gv2))\n    snp.testing.assert_allclose(out, fgv, rtol=5e-2)\n\n\ndef test_separable_grad(test_separable_obj):\n    # Test the separable grad\n    fv1 = test_separable_obj.f.grad(test_separable_obj.v1)\n    gv2 = test_separable_obj.g.grad(test_separable_obj.v2)\n    fgv = test_separable_obj.fg.grad(test_separable_obj.vb)\n    out = snp.blockarray((fv1, gv2))\n    snp.testing.assert_allclose(out, fgv, rtol=5e-2)\n\n\nclass HuberNormSep(functional.HuberNorm):\n    def __init__(self, delta=1.0):\n        super().__init__(delta=delta, separable=True)\n\n\nclass HuberNormNonSep(functional.HuberNorm):\n    def __init__(self, delta=1.0):\n        super().__init__(delta=delta, separable=False)\n\n\nclass TestNormProx:\n    normlist = [\n        functional.L0Norm,\n        functional.L1Norm,\n        functional.SquaredL2Norm,\n        functional.L2Norm,\n        functional.L21Norm,\n        functional.L1MinusL2Norm,\n        HuberNormSep,\n        HuberNormNonSep,\n        functional.NuclearNorm,\n        functional.ZeroFunctional,\n    ]\n\n    normlist_blockarray_ready = set(normlist.copy()) - set(NO_BLOCK_ARRAY)\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_prox(self, norm, 
alpha, test_prox_obj):\n        nrmobj = norm()\n        nrm = nrmobj.__call__\n        prx = nrmobj.prox\n        pf = prox_test(test_prox_obj.v, nrm, prx, alpha)\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_conj_prox(self, norm, alpha, test_prox_obj):\n        nrmobj = norm()\n        v = test_prox_obj.v\n        # Test checks extended Moreau decomposition at a random vector\n        lhs = nrmobj.prox(v, alpha) + alpha * nrmobj.conj_prox(v / alpha, 1.0 / alpha)\n        rhs = v\n        np.testing.assert_allclose(lhs, rhs, rtol=1e-6, atol=0.0)\n\n    @pytest.mark.parametrize(\"norm\", normlist_blockarray_ready)\n    def test_prox_blockarray(self, norm, alpha, test_prox_obj):\n        nrmobj = norm()\n        nrm = nrmobj.__call__\n        prx = nrmobj.prox\n        pf = nrmobj.prox(snp.ravel(test_prox_obj.vb), alpha)\n        pf_b = nrmobj.prox(test_prox_obj.vb, alpha)\n\n        assert pf.dtype == test_prox_obj.vb.dtype\n        assert pf_b.dtype == test_prox_obj.vb.dtype\n\n        snp.testing.assert_allclose(pf, snp.ravel(pf_b), rtol=1e-6)\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_prox_zeros(self, norm, test_prox_obj):\n        nrmobj = norm()\n        nrm = nrmobj.__call__\n        prx = nrmobj.prox\n        pf = prox_test(test_prox_obj.vz, nrm, prx, alpha=1.0)\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_scaled_attrs(self, norm, test_prox_obj):\n        alpha = np.sqrt(2)\n        unscaled = norm()\n        scaled = test_prox_obj.scalar * norm()\n\n        assert scaled.has_eval == unscaled.has_eval\n        assert scaled.has_prox == unscaled.has_prox\n        assert scaled.scale == test_prox_obj.scalar\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_scaled_eval(self, norm, alpha, test_prox_obj):\n        unscaled = norm()\n        scaled = test_prox_obj.scalar * norm()\n\n        a = test_prox_obj.scalar * unscaled(test_prox_obj.v)\n        b = scaled(test_prox_obj.v)\n      
  np.testing.assert_allclose(a, b)\n\n    @pytest.mark.parametrize(\"norm\", normlist)\n    def test_scaled_prox(self, norm, alpha, test_prox_obj):\n        # Test prox\n        unscaled = norm()\n        scaled = test_prox_obj.scalar * norm()\n        a = unscaled.prox(test_prox_obj.v, alpha * test_prox_obj.scalar)\n        b = scaled.prox(test_prox_obj.v, alpha)\n        np.testing.assert_allclose(a, b)\n\n\nclass TestBlockArrayEval:\n    # Ensures that functionals evaluate properly on a blockarray\n    # By convention, should be the same as evaluating on the flattened array\n\n    # Generate a list of all functionals in scico.functionals that we will check\n    ignore = [\n        functional.Functional,\n        functional.ScaledFunctional,\n        functional.SetDistance,\n        functional.SquaredSetDistance,\n    ]\n    to_check = []\n    for name, cls in functional.__dict__.items():\n        if isinstance(cls, type):\n            if issubclass(cls, functional.Functional):\n                if cls not in ignore and cls.has_eval is True:\n                    to_check.append(cls)\n\n    to_check = set(to_check) - set(NO_BLOCK_ARRAY)\n\n    @pytest.mark.parametrize(\"cls\", to_check)\n    def test_eval(self, cls, test_prox_obj):\n        func = cls()  # instantiate the functional we are testing\n\n        if cls in NO_COMPLEX and snp.util.is_complex_dtype(test_prox_obj.vb.dtype):\n            with pytest.raises(ValueError):\n                x = func(test_prox_obj.vb)\n            return\n\n        x = func(test_prox_obj.vb)\n        y = func(test_prox_obj.vb.ravel())\n\n        assert jnp.isscalar(x) or x.ndim == 0\n        assert jnp.isscalar(y) or y.ndim == 0\n\n        np.testing.assert_allclose(x, y, rtol=1e-6)\n\n\n# only check double precision on projections\n@pytest.fixture(params=[np.float64, np.complex128])\ndef test_proj_obj(request):\n    return ProxTestObj(request.param)\n\n\nclass TestProj:\n    cnstrlist = [functional.NonNegativeIndicator, 
functional.L2BallIndicator]\n    sdistlist = [functional.SetDistance, functional.SquaredSetDistance]\n\n    @pytest.mark.parametrize(\"cnstr\", cnstrlist)\n    def test_prox(self, cnstr, test_proj_obj):\n        alpha = 1\n        cnsobj = cnstr()\n        cns = cnsobj.__call__\n        prx = cnsobj.prox\n\n        if cnstr in NO_COMPLEX and snp.util.is_complex_dtype(test_proj_obj.v.dtype):\n            with pytest.raises(ValueError):\n                prox_test(test_proj_obj.v, cns, prx, alpha)\n            return\n\n        prox_test(test_proj_obj.v, cns, prx, alpha)\n\n    @pytest.mark.parametrize(\"cnstr\", cnstrlist)\n    def test_prox_scale_invariance(self, cnstr, test_proj_obj):\n        alpha1 = 1e-2\n        alpha2 = 1e0\n        cnsobj = cnstr()\n        u1 = cnsobj.prox(test_proj_obj.v, alpha1)\n        u2 = cnsobj.prox(test_proj_obj.v, alpha2)\n        assert np.linalg.norm(u1 - u2) / np.linalg.norm(u1) <= 1e-7\n\n    @pytest.mark.parametrize(\"sdist\", sdistlist)\n    @pytest.mark.parametrize(\"cnstr\", cnstrlist)\n    def test_setdistance(self, sdist, cnstr, alpha, test_proj_obj):\n        if cnstr in NO_COMPLEX and snp.util.is_complex_dtype(test_proj_obj.v.dtype):\n            return\n        cnsobj = cnstr()\n        proj = cnsobj.prox\n        sdobj = sdist(proj)\n        call = sdobj.__call__\n        prox = sdobj.prox\n        prox_test(test_proj_obj.v, call, prox, alpha)\n"
  },
  {
    "path": "scico/test/functional/test_indicator.py",
    "content": "import pytest\n\nimport scico.numpy as snp\nfrom scico import functional\nfrom scico.random import randn\n\nINDICATOR = [\n    functional.L2BallIndicator,\n    functional.NonNegativeIndicator,\n    functional.BoxIndicator,\n]\n\n\n@pytest.mark.parametrize(\"indicator\", INDICATOR)\ndef test_indicator(indicator):\n    x, key = randn(shape=(8,), dtype=snp.float32)\n    func = indicator()\n    assert func(func.prox(x)) == 0.0\n"
  },
  {
    "path": "scico/test/functional/test_loss.py",
    "content": "import numpy as np\n\nfrom jax import config\n\nimport pytest\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\nfrom prox import prox_test\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss\nfrom scico.numpy.util import complex_dtype\nfrom scico.random import randn, uniform\n\n\nclass TestLoss:\n    def setup_method(self):\n        n = 4\n        dtype = np.float64\n        A, key = randn((n, n), key=None, dtype=dtype, seed=1234)\n        D, key = randn((n,), key=key, dtype=dtype)\n        W, key = randn((n,), key=key, dtype=dtype)\n        W = 0.1 * W + 1.0\n        self.Ao = linop.MatrixOperator(A)\n        self.Ao_abs = linop.MatrixOperator(snp.abs(A))\n        self.Do = linop.Diagonal(D)\n        self.W = linop.Diagonal(W)\n        self.y, key = randn((n,), key=key, dtype=dtype)\n        self.v, key = randn((n,), key=key, dtype=dtype)  # point for prox eval\n        scalar, key = randn((1,), key=key, dtype=dtype)\n        self.key = key\n        self.scalar = scalar[0].item()\n\n    def test_generic_squared_l2(self):\n        A = linop.Identity(input_shape=self.y.shape)\n        f = functional.SquaredL2Norm()\n        L0 = loss.Loss(self.y, A=A, f=f, scale=0.5)\n        L1 = loss.SquaredL2Loss(y=self.y, A=A)\n        np.testing.assert_allclose(L0(self.v), L1(self.v))\n        np.testing.assert_allclose(\n            L0.prox(self.v, self.scalar), L1.prox(self.v, self.scalar), rtol=1e-6\n        )\n\n    def test_generic_exception(self):\n        A = linop.Diagonal(self.v)\n        L = loss.Loss(self.y, A=A, scale=0.5)\n        with pytest.raises(NotImplementedError):\n            L(self.v)\n        f = functional.L1Norm()\n        L = loss.Loss(self.y, A=A, f=f, scale=0.5)\n        assert not L.has_prox\n        with pytest.raises(NotImplementedError):\n            L.prox(self.v, self.scalar)\n\n    def test_squared_l2(self):\n        L = loss.SquaredL2Loss(y=self.y, A=self.Ao)\n      
  assert L.has_eval\n        assert L.has_prox\n\n        # test eval\n        np.testing.assert_allclose(L(self.v), 0.5 * ((self.Ao @ self.v - self.y) ** 2).sum())\n\n        cL = self.scalar * L\n        assert L.scale == 0.5  # hasn't changed\n        assert cL.scale == self.scalar * L.scale\n        assert cL(self.v) == self.scalar * L(self.v)\n\n        # squared l2 loss with diagonal linop has a prox\n        L_d = loss.SquaredL2Loss(y=self.y, A=self.Do)\n\n        # test eval\n        np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum())\n\n        assert L_d.has_eval\n        assert L_d.has_prox\n\n        cL = self.scalar * L_d\n        assert L_d.scale == 0.5  # hasn't changed\n        assert cL.scale == self.scalar * L_d.scale\n        assert cL(self.v) == self.scalar * L_d(self.v)\n\n        pf = prox_test(self.v, L_d, L_d.prox, 0.75)\n        pf = prox_test(self.v, L, L.prox, 0.75)\n\n    def test_squared_l2_grad(self):\n        La = loss.SquaredL2Loss(y=self.y)\n        Lb = loss.SquaredL2Loss(y=self.y, scale=5e0)\n        Lc = 1e1 * La\n        ga = La.grad(self.v)\n        gb = Lb.grad(self.v)\n        gc = Lc.grad(self.v)\n        np.testing.assert_allclose(1e1 * ga, gb)\n        np.testing.assert_allclose(gb, gc)\n\n    def test_weighted_squared_l2(self):\n        L = loss.SquaredL2Loss(y=self.y, A=self.Ao, W=self.W)\n        assert L.has_eval\n        assert L.has_prox\n        np.testing.assert_allclose(\n            L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()\n        )\n        pf = prox_test(self.v, L, L.prox, 0.75)\n\n        # weighted l2 loss with diagonal linop has a prox\n        L_d = loss.SquaredL2Loss(y=self.y, A=self.Do, W=self.W)\n        assert L_d.has_eval\n        assert L_d.has_prox\n        np.testing.assert_allclose(\n            L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()\n        )\n        pf = prox_test(self.v, L_d, L_d.prox, 0.75)\n\n    def 
test_poisson(self):\n        L = loss.PoissonLoss(y=self.y, A=self.Ao_abs)\n        assert L.has_eval\n        assert not L.has_prox\n\n        # test eval\n        v = snp.abs(self.v)\n        Av = self.Ao_abs @ v\n        np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const))\n\n        cL = self.scalar * L\n        assert L.scale == 0.5  # hasn't changed\n        assert cL.scale == self.scalar * L.scale\n        assert cL(v) == self.scalar * L(v)\n\n\nclass TestAbsLoss:\n    abs_loss = (\n        (loss.SquaredL2AbsLoss, snp.abs),\n        (loss.SquaredL2SquaredAbsLoss, lambda x: snp.abs(x) ** 2),\n    )\n\n    def setup_method(self):\n        n = 4\n        dtype = np.float64\n        A, key = randn((n, n), key=None, dtype=dtype, seed=1234)\n        W, key = randn((n,), key=key, dtype=dtype)\n        W = 0.1 * W + 1.0\n        self.Ao = linop.MatrixOperator(A)\n        self.Ao_abs = linop.MatrixOperator(snp.abs(A))\n        self.W = linop.Diagonal(W)\n        self.x, key = randn((n,), key=key, dtype=complex_dtype(dtype))\n        self.v, key = randn((n,), key=key, dtype=complex_dtype(dtype))  # point for prox eval\n        scalar, key = randn((1,), key=key, dtype=dtype)\n        self.scalar = scalar[0].item()\n\n    @pytest.mark.parametrize(\"loss_tuple\", abs_loss)\n    def test_properties(self, loss_tuple):\n        loss_class = loss_tuple[0]\n        loss_func = loss_tuple[1]\n\n        y = loss_func(self.Ao(self.x))\n        L = loss_class(y=y, A=self.Ao, W=self.W)\n        assert L.has_eval\n        assert not L.has_prox\n\n        cL = self.scalar * L\n        assert L.scale == 0.5  # hasn't changed\n        assert cL.scale == self.scalar * L.scale\n        assert cL(self.v) == self.scalar * L(self.v)\n\n        with pytest.raises(NotImplementedError):\n            px = L.prox(self.v, 0.75)\n\n        np.testing.assert_allclose(L(self.x), 0)\n\n        y = loss_func(self.x)\n        L = loss_class(y=y, A=None, W=None)\n        
assert L.has_eval\n        assert L.has_prox\n\n        cL = self.scalar * L\n        assert L.scale == 0.5  # hasn't changed\n        assert cL.scale == self.scalar * L.scale\n        assert cL(self.v) == self.scalar * L(self.v)\n\n        np.testing.assert_allclose(L(self.x), 0)\n\n        W = -1 * self.W\n        with pytest.raises(ValueError):\n            L = loss_class(y=y, W=W)\n\n        with pytest.raises(TypeError):\n            L = loss_class(y=y, W=linop.Sum(input_shape=W.input_shape))\n\n    @pytest.mark.parametrize(\"loss_tuple\", abs_loss)\n    def test_prox(self, loss_tuple):\n        loss_class = loss_tuple[0]\n        loss_func = loss_tuple[1]\n\n        y = loss_func(self.x)\n        L = loss_class(y=y, A=None, W=self.W)\n\n        pf = prox_test(self.v.real, L, L.prox, 0.5)  # real v\n\n        pf = prox_test(self.v, L, L.prox, 0.0)  # complex v\n        pf = prox_test(self.v, L, L.prox, 0.1)  # complex v\n        pf = prox_test(self.v, L, L.prox, 2.0)  # complex v\n\n        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0)  # complex zero v\n        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0)  # complex zero v\n        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 2.0)  # complex zero v\n\n        # zero y\n        y = snp.zeros(self.x.shape)\n        L = loss_class(y=y, A=None, W=self.W)\n\n        pf = prox_test(self.v.real, L, L.prox, 0.5)  # real v\n\n        pf = prox_test(self.v, L, L.prox, 0.0)  # complex v\n        pf = prox_test(self.v, L, L.prox, 0.1)  # complex v\n\n        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0)  # complex zero v\n        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0)  # complex zero v\n\n\ndef test_cubic_root():\n    N = 10000\n    p, key = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, seed=1234)\n    q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, key=key)\n    # 
Avoid cases of very poor numerical precision\n    p = p.at[snp.logical_and(snp.abs(p) < 2, q > 5e-2 * snp.abs(p))].set(1e1)\n    r = loss._dep_cubic_root(p, q)\n    err = snp.abs(r**3 + p * r + q)\n    assert err.max() < 2e-4\n    # Test loss of precision warning\n    p = snp.array(1e-4, dtype=snp.float32)\n    q = snp.array(1e1, dtype=snp.float32)\n    with pytest.warns(UserWarning):\n        r = loss._dep_cubic_root(p, q)\n"
  },
  {
    "path": "scico/test/functional/test_misc.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop\n\n\nclass TestCheckAttrs:\n    # Ensure that the has_eval, has_prox attrs are overridden\n    # and set to True/False in the Functional subclasses.\n\n    # Generate a list of all functionals in scico.functionals that we will check\n    ignore = [\n        functional.Functional,\n        functional.FunctionalSum,\n        functional.ScaledFunctional,\n        functional.SeparableFunctional,\n        functional.ComposedFunctional,\n        functional.ProximalAverage,\n    ]\n    to_check = []\n    for name, cls in functional.__dict__.items():\n        if isinstance(cls, type):\n            if issubclass(cls, functional.Functional):\n                if cls not in ignore:\n                    to_check.append(cls)\n\n    @pytest.mark.parametrize(\"cls\", to_check)\n    def test_has_eval(self, cls):\n        assert isinstance(cls.has_eval, bool)\n\n    @pytest.mark.parametrize(\"cls\", to_check)\n    def test_has_prox(self, cls):\n        assert isinstance(cls.has_prox, bool)\n\n\nclass TestJit:\n    # Test whether functionals can be jitted.\n\n    # Generate a list of all functionals in scico.functionals that we will check\n    ignore = [\n        functional.Functional,\n        functional.ScaledFunctional,\n        functional.SeparableFunctional,\n        functional.AnisotropicTVNorm,  # requires input_shape parameter in order to be jittable\n        functional.IsotropicTVNorm,  # requires input_shape parameter in order to be jittable\n        functional.TVNorm,  # requires input_shape parameter in order to be jittable\n        functional.BM3D,\n        functional.BM4D,\n    ]\n    to_check = []\n    for name, cls in functional.__dict__.items():\n        if isinstance(cls, type):\n            if issubclass(cls, functional.Functional):\n                if cls not in ignore:\n                    to_check.append(cls)\n\n    
@pytest.mark.parametrize(\"cls\", to_check)\n    def test_jit(self, cls):\n        # Only test functionals that have no required __init__ parameters.\n        try:\n            f = cls()\n        except TypeError:\n            pass\n        else:\n            v = snp.arange(4.0)\n            # Only test functionals that can take 1D input.\n            try:\n                u0 = f.prox(v)\n            except ValueError:\n                pass\n            else:\n                fprox = jax.jit(f.prox)\n                u1 = fprox(v)\n                assert np.allclose(u0, u1)\n\n\ndef test_functional_sum():\n    x = np.random.randn(4, 4)\n    f0 = functional.L1Norm()\n    f1 = 2.0 * functional.L2Norm()\n    f = f0 + f1\n    assert f(x) == f0(x) + f1(x)\n    with pytest.raises(TypeError):\n        f = f0 + 2.0\n\n\ndef test_scalar_vmap():\n    x = np.random.randn(4, 4)\n    f = functional.L1Norm()\n\n    def foo(c):\n        return (c * f)(x)\n\n    c_list = [1.0, 2.0, 3.0]\n    non_vmap = np.array([foo(c) for c in c_list])\n\n    vmapped = jax.vmap(foo)(snp.array(c_list))\n    np.testing.assert_allclose(non_vmap, vmapped)\n\n\ndef test_scalar_pmap():\n    x = np.random.randn(4, 4)\n    f = functional.L1Norm()\n\n    def foo(c):\n        return (c * f)(x)\n\n    c_list = np.random.randn(jax.device_count())\n    non_pmap = np.array([foo(c) for c in c_list])\n    pmapped = jax.pmap(foo)(c_list)\n    np.testing.assert_allclose(non_pmap, pmapped)\n\n\ndef test_scalar_aggregation():\n    f = functional.L2Norm()\n    g = 2.0 * f\n    h = 5.0 * g\n    assert isinstance(g, functional.ScaledFunctional)\n    assert isinstance(g.functional, functional.L2Norm)\n    assert g.scale == 2.0\n    assert isinstance(h, functional.ScaledFunctional)\n    assert isinstance(h.functional, functional.L2Norm)\n    assert h.scale == 10.0\n\n\n@pytest.mark.parametrize(\n    \"func\",\n    [\n        functional.ZeroFunctional(),\n        functional.SeparableFunctional((functional.ZeroFunctional(), 
functional.ZeroFunctional())),\n        functional.ComposedFunctional(functional.ZeroFunctional(), linop.Identity((4,))),\n        functional.FunctionalSum(functional.ZeroFunctional(), functional.ZeroFunctional()),\n    ],\n)\ndef test_repr_str(func):\n    fname = str(func)\n    frepr = repr(func)\n    assert fname in frepr\n    assert \"has_eval:\" in frepr\n    assert \"has_prox:\" in frepr\n"
  },
  {
    "path": "scico/test/functional/test_norm.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional\n\n\n@pytest.mark.parametrize(\"axis\", [0, 1, (0, 2)])\ndef test_l21norm(axis):\n    x = np.ones((3, 4, 5))\n    if isinstance(axis, int):\n        l2axis = (axis,)\n    else:\n        l2axis = axis\n    l2shape = [x.shape[k] for k in l2axis]\n    l1axis = tuple(set(range(len(x))) - set(l2axis))\n    l1shape = [x.shape[k] for k in l1axis]\n\n    l21ana = np.sqrt(np.prod(l2shape)) * np.prod(l1shape)\n    F = functional.L21Norm(l2_axis=axis)\n    l21num = F(x)\n    np.testing.assert_allclose(l21ana, l21num, rtol=1e-5)\n\n    l2ana = np.sqrt(np.prod(l2shape))\n    prxana = (l2ana - 1.0) / l2ana * x\n    prxnum = F.prox(x, 1.0)\n    np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)\n\n\ndef test_l2norm_blockarray():\n    xa = np.random.randn(2, 3, 4)\n    xb = snp.blockarray((xa[0], xa[1]))\n\n    fa = functional.L21Norm(l2_axis=(1, 2))\n    fb = functional.L21Norm(l2_axis=None)\n\n    np.testing.assert_allclose(fa(xa), fb(xb), rtol=1e-6)\n\n    ya = fa.prox(xa)\n    yb = fb.prox(xb)\n\n    np.testing.assert_allclose(ya[0], yb[0], rtol=1e-6)\n    np.testing.assert_allclose(ya[1], yb[1], rtol=1e-6)\n"
  },
  {
    "path": "scico/test/functional/test_proxavg.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.optimize.pgm import AcceleratedPGM\n\n\ndef test_proxavg_init():\n    g0 = functional.L1Norm()\n    g1 = functional.L2Norm()\n\n    with pytest.raises(ValueError):\n        h = functional.ProximalAverage(\n            [g0, g1],\n            alpha_list=[\n                0.1,\n            ],\n        )\n\n    h = functional.ProximalAverage([g0, g1], alpha_list=[0.1, 0.1])\n    assert sum(h.alpha_list) == 1.0\n\n    g1.has_prox = False\n    with pytest.raises(ValueError):\n        h = functional.ProximalAverage([g0, g1])\n\n\ndef test_proxavg():\n    N = 128\n    g = np.linspace(0, 2 * np.pi, N, dtype=np.float32)\n    y = np.sin(2 * g)\n    y[y > 0.5] = 0.5\n    y[y < -0.5] = -0.5\n    y *= 2\n    y = snp.array(y)\n\n    λ0 = 6e-1\n    λ1 = 6e-1\n    f = loss.SquaredL2Loss(y=y)\n    g0 = λ0 * functional.L1Norm()\n    g1 = λ1 * functional.L2Norm()\n\n    solver = ADMM(\n        f=f,\n        g_list=[0.5 * g0, 0.5 * g1],\n        C_list=[linop.Identity(y.shape), linop.Identity(y.shape)],\n        rho_list=[1e1, 1e1],\n        x0=y,\n        maxiter=100,\n        subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-5, \"maxiter\": 20}),\n    )\n    x_admm = solver.solve()\n\n    h = functional.ProximalAverage([λ0 * functional.L1Norm(), λ1 * functional.L2Norm()])\n    solver = AcceleratedPGM(f=f, g=h, L0=3.4e2, x0=y, maxiter=250)\n    x_prxavg = solver.solve()\n\n    assert metric.snr(x_admm, x_prxavg) > 50\n"
  },
  {
    "path": "scico/test/functional/test_separable.py",
    "content": "import numpy as np\n\nfrom jax import config\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\n\nimport pytest\n\nfrom scico import functional\nfrom scico.numpy import blockarray\nfrom scico.numpy.testing import assert_allclose\nfrom scico.random import randn\n\n\nclass SeparableTestObject:\n    def __init__(self, dtype):\n        self.f = functional.L1Norm()\n        self.g = functional.SquaredL2Norm()\n        self.fg = functional.SeparableFunctional([self.f, self.g])\n\n        n = 4\n        m = 6\n        key = None\n\n        self.v1, key = randn((n,), key=key, dtype=dtype)  # point for prox eval\n        self.v2, key = randn((m,), key=key, dtype=dtype)  # point for prox eval\n        self.vb = blockarray([self.v1, self.v2])\n\n\n@pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128])\ndef test_separable_obj(request):\n    return SeparableTestObject(request.param)\n\n\ndef test_separable_eval(test_separable_obj):\n    fv1 = test_separable_obj.f(test_separable_obj.v1)\n    gv2 = test_separable_obj.g(test_separable_obj.v2)\n    fgv = test_separable_obj.fg(test_separable_obj.vb)\n    assert_allclose(fv1 + gv2, fgv, rtol=5e-2)\n\n\ndef test_separable_prox(test_separable_obj):\n    alpha = 0.1\n    fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha)\n    gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha)\n    fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha)\n    out = blockarray((fv1, gv2)).ravel()\n    assert_allclose(out, fgv.ravel(), rtol=5e-2)\n\n\ndef test_separable_grad(test_separable_obj):\n    # Tests the separable grad\n    fv1 = test_separable_obj.f.grad(test_separable_obj.v1)\n    gv2 = test_separable_obj.g.grad(test_separable_obj.v2)\n    fgv = test_separable_obj.fg.grad(test_separable_obj.vb)\n    out = blockarray((fv1, gv2)).ravel()\n    assert_allclose(out, fgv.ravel(), rtol=5e-2)\n"
  },
  {
    "path": "scico/test/functional/test_tvnorm.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.random\nfrom scico import functional, linop, loss, metric\nfrom scico.examples import create_circular_phantom\nfrom scico.functional._tvnorm import HaarTransform, SingleAxisHaarTransform\nfrom scico.optimize.admm import ADMM, LinearSubproblemSolver\nfrom scico.optimize.pgm import AcceleratedPGM\n\n\n@pytest.mark.parametrize(\"axis\", [0, 1])\ndef test_single_axis_haar_transform(axis):\n    x, key = scico.random.randn((3, 4), seed=1234)\n    HT = SingleAxisHaarTransform(x.shape, axis=axis)\n    np.testing.assert_allclose(2 * x, HT.T(HT(x)), rtol=1e-6)\n\n\ndef test_haar_transform():\n    x, key = scico.random.randn((3, 4), seed=1234)\n    HT = HaarTransform(x.shape)\n    np.testing.assert_allclose(4 * x, HT.T(HT(x)), rtol=1e-6)\n\n\n@pytest.mark.parametrize(\"circular\", [True, False])\ndef test_aniso_1d(circular):\n    N = 128\n    g = np.linspace(0, 2 * np.pi, N, dtype=np.float32)\n    x_gt = np.sin(2 * g)\n    x_gt[x_gt > 0.5] = 0.5\n    x_gt[x_gt < -0.5] = -0.5\n    σ = 0.02\n    noise, key = scico.random.randn(x_gt.shape, seed=0)\n    y = x_gt + σ * noise\n\n    λ = 5e-2\n    f = loss.SquaredL2Loss(y=y)\n\n    C = linop.FiniteDifference(\n        input_shape=x_gt.shape, circular=circular, append=None if circular else 0\n    )\n    g = λ * functional.L1Norm()\n    solver = ADMM(\n        f=f,\n        g_list=[g],\n        C_list=[C],\n        rho_list=[1e1],\n        x0=y,\n        maxiter=50,\n        subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 25}),\n    )\n    x_tvdn = solver.solve()\n\n    h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape)\n    solver = AcceleratedPGM(f=f, g=h, L0=5e2, x0=y, maxiter=100)\n    x_approx = solver.solve()\n\n    assert metric.snr(x_tvdn, x_approx) > 50\n    assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6\n\n\nclass Test2D:\n    def setup_method(self):\n        N = 32\n        x_gt = 
create_circular_phantom(\n            (N, N), [0.6 * N, 0.4 * N, 0.2 * N, 0.1 * N], [0.25, 1, 0, 0.5]\n        ).astype(np.float32)\n        gr, gc = np.ogrid[0:N, 0:N]\n        x_gt += ((gr + gc) / (4 * N)).astype(np.float32)\n        σ = 0.02\n        noise, key = scico.random.randn(x_gt.shape, seed=0, dtype=np.float32)\n        y = x_gt + σ * noise\n        self.x_gt = x_gt\n        self.y = y\n\n    @pytest.mark.parametrize(\"circular\", [True, False])\n    @pytest.mark.parametrize(\"tvtype\", [\"aniso\", \"iso\"])\n    def test_2d(self, tvtype, circular):\n        x_gt = self.x_gt\n        y = self.y\n\n        λ = 5e-2\n        f = loss.SquaredL2Loss(y=y)\n        if tvtype == \"aniso\":\n            g = λ * functional.L1Norm()\n        else:\n            g = λ * functional.L21Norm()\n        C = linop.FiniteDifference(\n            input_shape=x_gt.shape, circular=circular, append=None if circular else 0\n        )\n\n        solver = ADMM(\n            f=f,\n            g_list=[g],\n            C_list=[C],\n            rho_list=[1e1],\n            x0=y,\n            maxiter=150,\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 25}),\n        )\n        x_tvdn = solver.solve()\n\n        if tvtype == \"aniso\":\n            h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape)\n        else:\n            h = λ * functional.IsotropicTVNorm(circular=circular, input_shape=y.shape)\n\n        solver = AcceleratedPGM(\n            f=f,\n            g=h,\n            L0=1e3,\n            x0=y,\n            maxiter=400,\n        )\n        x_aprx = solver.solve()\n\n        assert metric.snr(x_tvdn, x_aprx) > 50\n        assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6\n\n\nclass Test3D:\n    def setup_method(self):\n        N = 32\n        x2d = create_circular_phantom(\n            (N, N), [0.6 * N, 0.4 * N, 0.2 * N, 0.1 * N], [0.25, 1, 0, 0.5]\n        ).astype(np.float32)\n        gr, gc = 
np.ogrid[0:N, 0:N]\n        x2d += ((gr + gc) / (4 * N)).astype(np.float32)\n        x_gt = np.stack((0.9 * x2d, np.zeros(x2d.shape), 1.1 * x2d), dtype=np.float32)\n        σ = 0.02\n        noise, key = scico.random.randn(x_gt.shape, seed=0, dtype=np.float32)\n        y = x_gt + σ * noise\n        self.x_gt = x_gt\n        self.y = y\n\n    @pytest.mark.parametrize(\"circular\", [False])\n    @pytest.mark.parametrize(\"tvtype\", [\"iso\"])\n    def test_3d(self, tvtype, circular):\n        x_gt = self.x_gt\n        y = self.y\n\n        λ = 5e-2\n        f = loss.SquaredL2Loss(y=y)\n        if tvtype == \"aniso\":\n            g = λ * functional.L1Norm()\n        else:\n            g = λ * functional.L21Norm()\n        C = linop.FiniteDifference(\n            input_shape=x_gt.shape, axes=(1, 2), circular=circular, append=None if circular else 0\n        )\n\n        solver = ADMM(\n            f=f,\n            g_list=[g],\n            C_list=[C],\n            rho_list=[5e0],\n            x0=y,\n            maxiter=150,\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4, \"maxiter\": 25}),\n        )\n        x_tvdn = solver.solve()\n\n        if tvtype == \"aniso\":\n            h = λ * functional.AnisotropicTVNorm(\n                circular=circular, axes=(1, 2), input_shape=y.shape\n            )\n        else:\n            h = λ * functional.IsotropicTVNorm(circular=circular, axes=(1, 2), input_shape=y.shape)\n\n        solver = AcceleratedPGM(\n            f=f,\n            g=h,\n            L0=1e3,\n            x0=y,\n            maxiter=400,\n        )\n        x_aprx = solver.solve()\n\n        assert metric.snr(x_tvdn, x_aprx) > 50\n        assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6\n"
  },
  {
    "path": "scico/test/linop/test_binop.py",
    "content": "import operator as op\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.operator import Abs, Operator\n\n\nclass TestBinaryOp:\n    def setup_method(self, method):\n        self.input_shape = (5,)\n        self.input_dtype = snp.float32\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_case1(self, operator):\n        A = linop.Convolve(\n            snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode=\"same\"\n        )\n        B = Abs(input_shape=self.input_shape, input_dtype=self.input_dtype)\n\n        assert type(operator(A, B)) == Operator\n        assert type(operator(B, A)) == Operator\n        assert type(operator(2.0 * A, 3.0 * B)) == Operator\n        assert type(operator(2.0 * B, 3.0 * A)) == Operator\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_case2(self, operator):\n        A = linop.Convolve(\n            snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode=\"same\"\n        )\n        B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype)\n\n        assert type(operator(A, B)) == linop.LinearOperator\n        assert type(operator(B, A)) == linop.LinearOperator\n        assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator\n        assert type(operator(2.0 * B, 3.0 * A)) == linop.LinearOperator\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_case3(self, operator):\n        A = linop.SingleAxisFiniteDifference(\n            input_shape=self.input_shape, input_dtype=self.input_dtype, circular=True\n        )\n        B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype)\n\n        assert type(operator(A, B)) == linop.LinearOperator\n        assert type(operator(B, A)) == linop.LinearOperator\n        assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator\n        assert type(operator(2.0 * 
B, 3.0 * A)) == linop.LinearOperator\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_case4(self, operator):\n        A = linop.ScaledIdentity(\n            scalar=0.5, input_shape=self.input_shape, input_dtype=self.input_dtype\n        )\n        B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype)\n\n        assert type(operator(A, B)) == linop.ScaledIdentity\n        assert type(operator(B, A)) == linop.ScaledIdentity\n        assert type(operator(2.0 * A, 3.0 * B)) == linop.ScaledIdentity\n        assert type(operator(2.0 * B, 3.0 * A)) == linop.ScaledIdentity\n"
  },
  {
    "path": "scico/test/linop/test_circconv.py",
    "content": "import operator as op\n\nimport numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop import CircularConvolve, Convolve, Diagonal\nfrom scico.random import randint, randn, uniform\nfrom scico.test.linop.test_linop import adjoint_test\n\nSHAPE_SPECS = [\n    ((12,), None, (3,)),  # 1D\n    ((12, 8), None, (3, 2)),  # 2D\n    ((6, 8, 12), None, (3, 2, 4)),  # 3D\n    ((2, 12, 8), 2, (3, 2)),  # batching x\n    ((12, 8), None, (2, 3, 2)),  # batching h\n    ((2, 12, 8), 2, (2, 3, 2)),  # batching both\n    # (M, N, b) x (H, W, 1)  # this was the old way\n    # (M, N, b) x (H, W)  # this won't work: Luke, firm-no\n    # (M, b, N) x (H, W)  # do we even want this?\n    # (M, b, N) x (b, H, W) # no, no, no\n]\n\n\nclass TestCircularConvolve:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    def test_eval(self, axes_shape_spec, input_dtype, jit):\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n        x, key = randn(tuple(x_shape), dtype=input_dtype, key=key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)\n\n        Ax = A @ x\n\n        # check that a specific pixel of Ax computes an inner product between x and\n        # (flipped, padded, shifted) h\n        h_flipped = np.flip(h, range(-A.ndims, 0))  # flip only in the spatial dims (not batches)\n\n        x_inds = (...,) + tuple(\n            slice(-h.shape[a], None) for a in range(-A.ndims, 0)\n        )  # bottom right corner of x\n        Ax_inds = (...,) + tuple(-1 for _ in range(A.ndims))\n        sum_axes = tuple(-(a + 1) for a in range(A.ndims))  # ndims=2 -> -1, -2\n        np.testing.assert_allclose(\n            np.sum(h_flipped 
* x[x_inds], axis=sum_axes), Ax[Ax_inds], rtol=1e-5\n        )\n\n        # np.testing.assert_allclose(Ax.ravel(), hx.ravel(), rtol=5e-4)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    def test_adjoint(self, axes_shape_spec, input_dtype, jit):\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)\n\n        adjoint_test(A, self.key)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    @pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n    def test_scalar_left(self, axes_shape_spec, operator, jit):\n        input_dtype = np.float32\n        scalar = np.float32(3.141)\n\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)\n\n        cA = operator(A, scalar)\n\n        np.testing.assert_allclose(operator(A.h_dft.ravel(), scalar), cA.h_dft.ravel(), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    @pytest.mark.parametrize(\"operator\", [op.mul])\n    def test_scalar_right(self, axes_shape_spec, operator, jit):\n        input_dtype = np.float32\n        scalar = np.float32(3.141)\n\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)\n        cA = operator(scalar, A)\n\n        np.testing.assert_allclose(operator(scalar, A.h_dft.ravel()), cA.h_dft.ravel(), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n   
 @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    def test_add_sub(self, axes_shape_spec, jit):\n        input_dtype = np.float32\n\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n        g, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)\n        B = CircularConvolve(g, x_shape, ndims, input_dtype, jit=jit)\n\n        np.testing.assert_allclose(A.h_dft + B.h_dft, (A + B).h_dft, rtol=5e-5)\n        np.testing.assert_allclose(A.h_dft - B.h_dft, (A - B).h_dft, rtol=5e-5)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    def test_matches_convolve(self, input_dtype, jit):\n        h, key = randint(minval=0, maxval=3, shape=(3, 4), key=self.key)\n        x, key = uniform(minval=0, maxval=1, shape=(5, 4), key=key)\n\n        h = h.astype(input_dtype)\n        x = (x <= 0.1).astype(input_dtype)\n\n        # pad to m + n -1\n        x_pad = snp.pad(x, ((0, h.shape[0] - 1), (0, h.shape[1] - 1)))\n\n        A = Convolve(h=h, input_shape=x.shape, jit=jit, input_dtype=input_dtype)\n        B = CircularConvolve(h, input_shape=x_pad.shape, jit=jit, input_dtype=input_dtype)\n\n        actual = B @ x_pad\n        desired = A @ x\n        np.testing.assert_allclose(actual, desired, atol=1e-6)\n\n    @pytest.mark.parametrize(\n        \"center\",\n        [\n            1,\n            [\n                1,\n            ],\n            snp.array([2]),\n        ],\n    )\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    def test_center(self, center, jit):\n        x, key = uniform(minval=-1, maxval=1, shape=(16,), key=self.key)\n        h = snp.array([0.5, 1.0, 0.25])\n        A = CircularConvolve(h=h, input_shape=x.shape, h_center=center, jit=jit)\n        B = CircularConvolve(h=h, input_shape=x.shape, jit=jit)\n 
       if isinstance(center, int):\n            shift = -center\n        else:\n            shift = -center[0]\n        np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    def test_fractional_center(self, jit):\n        \"\"\"A fractional center should keep outputs real.\"\"\"\n        x, key = uniform(minval=-1, maxval=1, shape=(4, 5), key=self.key)\n        h, _ = uniform(minval=-1, maxval=1, shape=(2, 2), key=key)\n        A = CircularConvolve(h=h, input_shape=x.shape, h_center=[0.1, 2.7], jit=jit)\n\n        # taken from CircularConvolve._eval\n        x_dft = snp.fft.fftn(x, axes=A.x_fft_axes)\n        hx = snp.fft.ifftn(\n            A.h_dft * x_dft,\n            axes=A.ifft_axes,\n        )\n\n        np.testing.assert_allclose(hx, snp.real(hx))\n\n    @pytest.mark.parametrize(\"axes_shape_spec\", SHAPE_SPECS)\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"jit_old_op\", [True, False])\n    @pytest.mark.parametrize(\"jit_new_op\", [True, False])\n    def test_from_operator(self, axes_shape_spec, input_dtype, jit_old_op, jit_new_op):\n        x_shape, ndims, h_shape = axes_shape_spec\n\n        h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)\n        x, key = randn(tuple(x_shape), dtype=input_dtype, key=key)\n\n        A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit_old_op)\n\n        B = CircularConvolve.from_operator(A, ndims, jit=jit_new_op)\n\n        np.testing.assert_allclose(A @ x, B @ x, atol=1e-5)\n\n    def test_from_operator_block_array(self):\n        \"\"\"`from_operator` should throw an exception if asked to work\n        on an operator with blockarray inputs.\"\"\"\n\n        H = Diagonal(snp.zeros(((1, 2), (3,))))\n\n        with pytest.raises(ValueError):\n            CircularConvolve.from_operator(H)\n"
  },
  {
    "path": "scico/test/linop/test_conversions.py",
    "content": "\"\"\"\nTest methods that make one kind of Operator out of another.\n\"\"\"\n\nimport numpy as np\n\nimport pytest\n\nfrom scico.linop import CircularConvolve, FiniteDifference\nfrom scico.random import randn\n\n\n@pytest.mark.parametrize(\n    \"shape_axes\",\n    [\n        ((3, 4), None),  # 2d\n        ((3, 4, 5), None),  # 3d\n        # ((3, 4, 5), [0, 2]),  # 3d specific axes -- not supported\n    ],\n)\n@pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n@pytest.mark.parametrize(\"jit_old\", [False, True])\n@pytest.mark.parametrize(\"jit_new\", [False, True])\ndef testCircularConvolve_from_FiniteDifference(shape_axes, input_dtype, jit_old, jit_new):\n    input_shape, axes = shape_axes\n    x, _ = randn(input_shape, dtype=input_dtype)\n\n    # make a CircularConvolve from a FiniteDifference\n    A = FiniteDifference(\n        input_shape=input_shape, input_dtype=input_dtype, axes=axes, circular=True, jit=jit_old\n    )\n\n    B = CircularConvolve.from_operator(A, ndims=x.ndim, jit=jit_new)\n    np.testing.assert_allclose(A @ x, B @ x, atol=1e-5)\n\n    # try the same on the FiniteDifference Gram\n    ATA = A.gram_op\n\n    B = CircularConvolve.from_operator(ATA, ndims=x.ndim, jit=jit_new)\n    np.testing.assert_allclose(ATA @ x, B @ x, atol=1e-5)\n"
  },
  {
    "path": "scico/test/linop/test_convolve.py",
    "content": "import operator as op\n\nimport numpy as np\n\nimport jax\nimport jax.scipy.signal as signal\n\nimport pytest\n\nfrom scico.linop import Convolve, ConvolveByX, LinearOperator\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import AbsMatOp, adjoint_test\n\n\nclass TestConvolve:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", [(16,), (16, 24)])\n    @pytest.mark.parametrize(\"mode\", [\"full\", \"valid\", \"same\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_eval(self, input_shape, input_dtype, mode, jit):\n        ndim = len(input_shape)\n\n        filter_shape = (3, 4)[:ndim]\n\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        psf, key = randn(filter_shape, dtype=input_dtype, key=key)\n        A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)\n        Ax = A @ x\n        y = signal.convolve(x, psf, mode=mode)\n        np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", [(16,), (16, 24)])\n    @pytest.mark.parametrize(\"mode\", [\"full\", \"valid\", \"same\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_adjoint(self, input_shape, mode, jit, input_dtype):\n        ndim = len(input_shape)\n        filter_shape = (3, 4)[:ndim]\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        psf, key = randn(filter_shape, dtype=input_dtype, key=key)\n\n        A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)\n\n        adjoint_test(A, self.key)\n\n\nclass ConvolveTestObj:\n    def __init__(self):\n        dtype = np.float32\n        key = jax.random.key(12345)\n\n        self.psf_A, 
key = randn((3,), dtype=dtype, key=key)\n        self.psf_B, key = randn((3,), dtype=dtype, key=key)\n        self.psf_C, key = randn((5,), dtype=dtype, key=key)\n\n        self.A = Convolve(input_shape=(16,), h=self.psf_A)\n        self.B = Convolve(input_shape=(16,), h=self.psf_B)\n        self.C = Convolve(input_shape=(16,), h=self.psf_C)\n\n        # Matrix for a 'generic linop'\n        m = self.A.output_shape[0]\n        n = self.A.input_shape[0]\n        G_mat, key = randn((m, n), dtype=dtype, key=key)\n        self.G = AbsMatOp(G_mat)\n\n        self.x, key = randn((16,), dtype=dtype, key=key)\n\n        self.scalar = 3.141\n\n\n@pytest.fixture\ndef testobj(request):\n    yield ConvolveTestObj()\n\n\ndef test_init(testobj):\n    with pytest.raises(ValueError):\n        A = Convolve(input_shape=(16, 16), h=testobj.psf_A)\n    with pytest.raises(ValueError):\n        A = Convolve(input_shape=(16,), h=testobj.psf_A, mode=\"invalid\")\n    A = Convolve(input_shape=(16,), input_dtype=None, h=testobj.psf_A)\n    assert A.input_dtype == testobj.psf_A.dtype\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_scalar_left(testobj, operator):\n    A = operator(testobj.A, testobj.scalar)\n    x = testobj.x\n    B = Convolve(input_shape=(16,), h=operator(testobj.psf_A, testobj.scalar))\n    np.testing.assert_allclose(A @ x, B @ x, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_scalar_right(testobj, operator):\n    if operator == op.truediv:\n        pytest.xfail(\"scalar / LinearOperator is not supported\")\n    A = operator(testobj.scalar, testobj.A)\n    x = testobj.x\n    B = Convolve(input_shape=(16,), h=operator(testobj.scalar, testobj.psf_A))\n    np.testing.assert_allclose(A @ x, B @ x, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_convolve_add_sub(testobj, operator):\n    A = testobj.A\n    B = testobj.B\n    C = testobj.C\n    x = testobj.x\n\n    # Two 
operators of same size\n    AB = operator(A, B)\n    ABx = AB @ x\n    AxBx = operator(A @ x, B @ x)\n    np.testing.assert_allclose(ABx, AxBx, rtol=5e-5)\n\n    # Two operators of different size\n    with pytest.raises(ValueError):\n        operator(A, C)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sub_different_mode(testobj, operator):\n    # These tests get caught inside of the _wrap_add_sub input/output shape checks,\n    # not the explicit mode check inside of the wrapped __add__ method\n    B_same = Convolve(input_shape=(16,), h=testobj.psf_B, mode=\"same\")\n    with pytest.raises(ValueError):\n        operator(testobj.A, B_same)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sum_generic_linop(testobj, operator):\n    # Combine a AbsMatOp and Convolve, get a generic LinearOperator\n    AG = operator(testobj.A, testobj.G)\n    assert isinstance(AG, LinearOperator)\n\n    # Check evaluation\n    a = AG @ testobj.x\n    b = operator(testobj.A @ testobj.x, testobj.G @ testobj.x)\n    np.testing.assert_allclose(a, b, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sum_conv(testobj, operator):\n    # Combine a AbsMatOp and Convolve, get a generic LinearOperator\n    AA = operator(testobj.A, testobj.A)\n    assert isinstance(AA, Convolve)\n\n    # Check evaluation\n    a = AA @ testobj.x\n    b = operator(testobj.A @ testobj.x, testobj.A @ testobj.x)\n    np.testing.assert_allclose(a, b, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_mul_div_generic_linop(testobj, operator):\n    # not defined between Convolve and AbsMatOp\n    with pytest.raises(TypeError):\n        operator(testobj.A, testobj.G)\n\n\ndef test_invalid_mode(testobj):\n    # mode that doesn't exist\n    with pytest.raises(ValueError):\n        Convolve(input_shape=(16,), h=testobj.psf_A, mode=\"foo\")\n\n\ndef test_dimension_mismatch(testobj):\n    with 
pytest.raises(ValueError):\n        # 2-dim input shape, 1-dim filter\n        Convolve(input_shape=(16, 16), h=testobj.psf_A)\n\n\nclass TestConvolveByX:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", [(16,), (16, 24)])\n    @pytest.mark.parametrize(\"mode\", [\"full\", \"valid\", \"same\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_eval(self, input_shape, input_dtype, mode, jit):\n        ndim = len(input_shape)\n\n        x_shape = (3, 4)[:ndim]\n\n        h, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        x, key = randn(x_shape, dtype=input_dtype, key=key)\n\n        A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)\n        Ax = A @ h\n        y = signal.convolve(x, h, mode=mode)\n        np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", [(16,), (16, 24)])\n    @pytest.mark.parametrize(\"mode\", [\"full\", \"valid\", \"same\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_adjoint(self, input_shape, mode, jit, input_dtype):\n        ndim = len(input_shape)\n        x_shape = (3, 4)[:ndim]\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        x, key = randn(x_shape, dtype=input_dtype, key=key)\n\n        A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)\n\n        adjoint_test(A, self.key)\n\n\nclass ConvolveByXTestObj:\n    def __init__(self):\n        dtype = np.float32\n        key = jax.random.key(12345)\n\n        self.x_A, key = randn((3,), dtype=dtype, key=key)\n        self.x_B, key = randn((3,), dtype=dtype, key=key)\n        self.x_C, key = randn((5,), dtype=dtype, key=key)\n\n        self.A = 
ConvolveByX(input_shape=(16,), x=self.x_A)\n        self.B = ConvolveByX(input_shape=(16,), x=self.x_B)\n        self.C = ConvolveByX(input_shape=(16,), x=self.x_C)\n\n        # Matrix for a 'generic linop'\n        m = self.A.output_shape[0]\n        n = self.A.input_shape[0]\n        G_mat, key = randn((m, n), dtype=dtype, key=key)\n        self.G = AbsMatOp(G_mat)\n\n        self.h, key = randn((16,), dtype=dtype, key=key)\n\n        self.scalar = 3.141\n\n\n@pytest.fixture\ndef cbx_testobj(request):\n    yield ConvolveByXTestObj()\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_cbx_scalar_left(cbx_testobj, operator):\n    A = operator(cbx_testobj.A, cbx_testobj.scalar)\n    h = cbx_testobj.h\n    B = ConvolveByX(input_shape=(16,), x=operator(cbx_testobj.x_A, cbx_testobj.scalar))\n    np.testing.assert_allclose(A @ h, B @ h, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_cbx_scalar_right(cbx_testobj, operator):\n    if operator == op.truediv:\n        pytest.xfail(\"scalar / LinearOperator is not supported\")\n    A = operator(cbx_testobj.scalar, cbx_testobj.A)\n    h = cbx_testobj.h\n    B = ConvolveByX(input_shape=(16,), x=operator(cbx_testobj.scalar, cbx_testobj.x_A))\n    np.testing.assert_allclose(A @ h, B @ h, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_convolve_add_sub(cbx_testobj, operator):\n    A = cbx_testobj.A\n    B = cbx_testobj.B\n    C = cbx_testobj.C\n    h = cbx_testobj.h\n\n    # Two operators of same size\n    AB = operator(A, B)\n    ABh = AB @ h\n    AfiltBh = operator(A @ h, B @ h)\n    np.testing.assert_allclose(ABh, AfiltBh, rtol=5e-5)\n\n    # Two operators of different size\n    with pytest.raises(ValueError):\n        operator(A, C)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sub_different_mode(cbx_testobj, operator):\n    # These tests get caught inside of the _wrap_add_sub input/output shape 
checks,\n    # not the explicit mode check inside of the wrapped __add__ method\n    B_same = ConvolveByX(input_shape=(16,), x=cbx_testobj.x_B, mode=\"same\")\n    with pytest.raises(ValueError):\n        operator(cbx_testobj.A, B_same)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sum_generic_linop(cbx_testobj, operator):\n    # Combine a AbsMatOp and ConvolveByX, get a generic LinearOperator\n    AG = operator(cbx_testobj.A, cbx_testobj.G)\n    assert isinstance(AG, LinearOperator)\n\n    # Check evaluation\n    a = AG @ cbx_testobj.h\n    b = operator(cbx_testobj.A @ cbx_testobj.h, cbx_testobj.G @ cbx_testobj.h)\n    np.testing.assert_allclose(a, b, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_add_sum_conv(cbx_testobj, operator):\n    # Combine a AbsMatOp and ConvolveByX, get a generic LinearOperator\n    AA = operator(cbx_testobj.A, cbx_testobj.A)\n    assert isinstance(AA, ConvolveByX)\n\n    # Check evaluation\n    a = AA @ cbx_testobj.h\n    b = operator(cbx_testobj.A @ cbx_testobj.h, cbx_testobj.A @ cbx_testobj.h)\n    np.testing.assert_allclose(a, b, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\ndef test_mul_div_generic_linop(cbx_testobj, operator):\n    # not defined between ConvolveByX and AbsMatOp\n    with pytest.raises(TypeError):\n        operator(cbx_testobj.A, cbx_testobj.G)\n\n\ndef test_invalid_mode(cbx_testobj):\n    # mode that doesn't exist\n    with pytest.raises(ValueError):\n        ConvolveByX(input_shape=(16,), x=cbx_testobj.x_A, mode=\"foo\")\n\n\ndef test_dimension_mismatch(cbx_testobj):\n    with pytest.raises(ValueError):\n        # 2-dim input shape, 1-dim xer\n        ConvolveByX(input_shape=(16, 16), x=cbx_testobj.x_A)\n"
  },
  {
    "path": "scico/test/linop/test_dft.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop import DFT\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import adjoint_test\n\n\nclass TestDFT:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"input_shape\", [(16,), (16, 4), (16, 4, 7)])\n    @pytest.mark.parametrize(\n        \"axes_and_shape\",\n        [\n            (None, None),\n            ((0,), None),\n            ((0,), (20,)),\n            ((0, 2), None),\n            ((0, 2), (20, 8)),\n            (None, (6, 8)),\n        ],\n    )\n    @pytest.mark.parametrize(\"norm\", [None, \"backward\", \"ortho\", \"forward\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_dft(self, input_shape, axes_and_shape, norm, jit):\n        axes = axes_and_shape[0]\n        axes_shape = axes_and_shape[1]\n\n        # Skip bad parameter permutations\n        if axes is not None and len(axes) >= len(input_shape):\n            return\n        if axes is not None and max(axes) >= len(input_shape):\n            return\n        if axes_shape is not None and len(axes_shape) > len(input_shape):\n            return\n\n        x, self.key = randn(input_shape, dtype=np.complex64, key=self.key)\n        F = DFT(input_shape=input_shape, axes=axes, axes_shape=axes_shape, norm=norm, jit=jit)\n        Fx = F @ x\n\n        # Test eval\n        snp_result = snp.fft.fftn(x, s=axes_shape, axes=axes, norm=norm).astype(np.complex64)\n        np.testing.assert_allclose(Fx, snp_result, rtol=1e-6)\n\n        # Test adjoint\n        adjoint_test(F, self.key)\n\n        # Test inverse\n        y, self.key = randn(F.output_shape, dtype=np.complex64, key=self.key)\n        Fiy = F.inv(y)\n        snp_result = snp.fft.ifftn(y, s=F.inv_axes_shape, axes=axes, norm=norm).astype(np.complex64)\n        np.testing.assert_allclose(Fiy, snp_result, rtol=1e-6)\n\n    def 
test_axes_check(self):\n        input_shape = (32, 48)\n        axes = (0,)\n        axes_shape = (40, 50)\n        with pytest.raises(ValueError):\n            F = DFT(input_shape=input_shape, axes=axes, axes_shape=axes_shape)\n"
  },
  {
    "path": "scico/test/linop/test_diag.py",
    "content": "import operator as op\n\nimport numpy as np\n\nfrom jax import config\n\nimport pytest\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\nimport jax\n\nfrom test_linop import adjoint_test\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.random import randn\n\n\nclass TestDiagonal:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    input_shapes = [(8,), (8, 12), ((3,), (4, 5))]\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_eval(self, input_shape, diagonal_dtype):\n        diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n        x, key = randn(input_shape, dtype=diagonal_dtype, key=key)\n\n        D = linop.Diagonal(diagonal=diagonal)\n        assert (D @ x).shape == D.output_shape\n        snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5)\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    def test_eval_broadcasting(self, diagonal_dtype):\n        # array broadcast\n        diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key)\n        x, key = randn((5, 1), dtype=diagonal_dtype, key=key)\n        D = linop.Diagonal(diagonal, x.shape)\n        assert (D @ x).shape == (3, 5, 4)\n        np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5)\n\n        # blockarray broadcast\n        diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key)\n        x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key)\n        D = linop.Diagonal(diagonal, x.shape)\n        assert (D @ x).shape == ((3, 5, 4), (5, 5))\n        snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5)\n\n        # blockarray x array -> error\n        diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key)\n        x, key = 
randn((5, 1), dtype=diagonal_dtype, key=key)\n        with pytest.raises(ValueError):\n            D = linop.Diagonal(diagonal, x.shape)\n\n        # array x blockarray -> error\n        diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key)\n        x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key)\n        with pytest.raises(ValueError):\n            D = linop.Diagonal(diagonal, x.shape)\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_adjoint(self, input_shape, diagonal_dtype):\n        diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n        D = linop.Diagonal(diagonal=diagonal)\n\n        adjoint_test(D)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape1\", input_shapes)\n    @pytest.mark.parametrize(\"input_shape2\", input_shapes)\n    def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator):\n        diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)\n        diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)\n        x, key = randn(input_shape1, dtype=diagonal_dtype, key=key)\n\n        D1 = linop.Diagonal(diagonal=diagonal1)\n        D2 = linop.Diagonal(diagonal=diagonal2)\n\n        if input_shape1 != input_shape2:\n            with pytest.raises(ValueError):\n                a = operator(D1, D2) @ x\n        else:\n            a = operator(D1, D2) @ x\n            Dnew = linop.Diagonal(operator(diagonal1, diagonal2))\n            b = Dnew @ x\n            snp.testing.assert_allclose(a, b, rtol=1e-5)\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape1\", input_shapes)\n    @pytest.mark.parametrize(\"input_shape2\", input_shapes)\n    def 
test_matmul(self, input_shape1, input_shape2, diagonal_dtype):\n        diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)\n        diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)\n        x, key = randn(input_shape1, dtype=diagonal_dtype, key=key)\n\n        D1 = linop.Diagonal(diagonal=diagonal1)\n        D2 = linop.Diagonal(diagonal=diagonal2)\n\n        if input_shape1 != input_shape2:\n            with pytest.raises(ValueError):\n                D3 = D1 @ D2\n        else:\n            D3 = D1 @ D2\n            assert isinstance(D3, linop.Diagonal)\n            a = D3 @ x\n            D4 = linop.Diagonal(diagonal1 * diagonal2)\n            b = D4 @ x\n            snp.testing.assert_allclose(a, b, rtol=1e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_binary_op_mismatch(self, operator):\n        diagonal_dtype = np.float32\n        input_shape1 = (8,)\n        input_shape2 = (12,)\n        diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)\n        diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)\n\n        D1 = linop.Diagonal(diagonal=diagonal1)\n        D2 = linop.Diagonal(diagonal=diagonal2)\n        with pytest.raises(ValueError):\n            operator(D1, D2)\n\n    @pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n    def test_scalar_right(self, operator):\n        if operator == op.truediv:\n            pytest.xfail(\"scalar / LinearOperator is not supported\")\n\n        diagonal_dtype = np.float32\n        input_shape = (8,)\n\n        diagonal1, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n        scalar = np.random.randn()\n        x, key = randn(input_shape, dtype=diagonal_dtype, key=key)\n\n        D = linop.Diagonal(diagonal=diagonal1)\n        scaled_D = operator(scalar, D)\n\n        np.testing.assert_allclose(scaled_D @ x, operator(scalar, D @ x), rtol=5e-5)\n\n    
@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n    def test_scalar_left(self, operator):\n        diagonal_dtype = np.float32\n        input_shape = (8,)\n\n        diagonal1, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n        scalar = np.random.randn()\n        x, key = randn(input_shape, dtype=diagonal_dtype, key=key)\n\n        D = linop.Diagonal(diagonal=diagonal1)\n        scaled_D = operator(D, scalar)\n\n        np.testing.assert_allclose(scaled_D @ x, operator(D @ x, scalar), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    def test_gram_op(self, diagonal_dtype):\n        input_shape = (7,)\n        diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n\n        D1 = linop.Diagonal(diagonal=diagonal)\n        D2 = D1.gram_op\n        D3 = D1.H @ D1\n        assert isinstance(D3, linop.Diagonal)\n        snp.testing.assert_allclose(D2.diagonal, D3.diagonal, rtol=1e-6)\n\n    @pytest.mark.parametrize(\"diagonal_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"ord\", [None, \"fro\", \"nuc\", -np.inf, np.inf, 1, -1, 2, -2])\n    def test_norm(self, diagonal_dtype, ord):\n        input_shape = (5,)\n        diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)\n\n        D1 = linop.Diagonal(diagonal=diagonal)\n        D2 = snp.diag(diagonal)\n        n1 = D1.norm(ord=ord)\n        n2 = snp.linalg.norm(D2, ord=ord)\n        snp.testing.assert_allclose(n1, n2, rtol=1e-6)\n\n    def test_norm_except(self):\n        input_shape = (5,)\n        diagonal, key = randn(input_shape, dtype=np.float32, key=self.key)\n\n        D = linop.Diagonal(diagonal=diagonal)\n        with pytest.raises(ValueError):\n            n = D.norm(ord=3)\n\n\nclass TestScaledIdentity:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    input_shapes = [(8,), (8, 12), ((3,), (4, 5))]\n\n    
@pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_eval(self, input_shape, input_dtype):\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        scalar, key = randn((), dtype=input_dtype, key=key)\n\n        Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype)\n        assert (Id @ x).shape == Id.output_shape\n        snp.testing.assert_allclose(scalar * x, Id @ x, rtol=1e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_binary_op(self, input_shape, operator):\n        input_dtype = np.float32\n        diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        x, key = randn(input_shape, dtype=input_dtype, key=key)\n        scalar, key = randn((), dtype=input_dtype, key=key)\n\n        Id = linop.ScaledIdentity(scalar, input_shape=input_shape)\n        D = linop.Diagonal(diagonal=diagonal)\n\n        IdD = operator(Id, D)\n        assert isinstance(IdD, linop.Diagonal)\n        snp.testing.assert_allclose(IdD @ x, operator(scalar, diagonal) * x, rtol=1e-6)\n\n        DId = operator(D, Id)\n        assert isinstance(DId, linop.Diagonal)\n        snp.testing.assert_allclose(DId @ x, operator(diagonal, scalar) * x, rtol=1e-6)\n\n    def test_scale(self):\n        input_shape = (5,)\n        input_dtype = np.float32\n        scalar1, key = randn((), dtype=input_dtype, key=self.key)\n        scalar2, key = randn((), dtype=input_dtype, key=key)\n\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        Id = linop.ScaledIdentity(scalar=scalar1, input_shape=input_shape, input_dtype=input_dtype)\n\n        sId = scalar2 * Id\n        assert isinstance(sId, linop.ScaledIdentity)\n        snp.testing.assert_allclose(sId @ x, scalar1 * scalar2 * x, rtol=1e-6)\n\n        Ids = Id * scalar2\n     
   assert isinstance(Ids, linop.ScaledIdentity)\n        snp.testing.assert_allclose(Ids @ x, scalar1 * scalar2 * x, rtol=1e-6)\n\n        Idds = Id / scalar2\n        assert isinstance(Idds, linop.ScaledIdentity)\n        snp.testing.assert_allclose(Idds @ x, x * scalar1 / scalar2, rtol=1e-6)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"ord\", [None, \"fro\", \"nuc\", -np.inf, np.inf, 1, -1, 2, -2])\n    def test_norm(self, input_dtype, ord):\n        input_shape = (5,)\n        scalar, key = randn((), dtype=input_dtype, key=self.key)\n\n        Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype)\n        D = linop.Diagonal(\n            diagonal=scalar * snp.ones(input_shape),\n            input_shape=input_shape,\n            input_dtype=input_dtype,\n        )\n        n1 = Id.norm(ord=ord)\n        n2 = D.norm(ord=ord)\n        snp.testing.assert_allclose(n1, n2, rtol=1e-6)\n\n    def test_norm_except(self):\n        input_shape = (5,)\n\n        Id = linop.Identity(input_shape=input_shape, input_dtype=np.float32)\n        with pytest.raises(ValueError):\n            n = Id.norm(ord=3)\n\n\nclass TestIdentity:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    input_shapes = [(8,), (8, 12), ((3,), (4, 5))]\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_eval(self, input_shape, input_dtype):\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n\n        Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype)\n        assert (Id @ x).shape == Id.output_shape\n        snp.testing.assert_allclose(x, Id @ x, rtol=1e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    @pytest.mark.parametrize(\"input_shape\", input_shapes)\n    def test_binary_op(self, input_shape, operator):\n 
       input_dtype = np.float32\n        diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        scalar, key = randn((), dtype=input_dtype, key=key)\n        x, key = randn(input_shape, dtype=input_dtype, key=key)\n\n        Id = linop.Identity(input_shape=input_shape)\n        Ids = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape)\n        D = linop.Diagonal(diagonal=diagonal)\n\n        IdD = operator(Id, D)\n        assert isinstance(IdD, linop.Diagonal)\n        snp.testing.assert_allclose(IdD @ x, operator(1.0, diagonal) * x, rtol=1e-6)\n\n        DId = operator(D, Id)\n        assert isinstance(DId, linop.Diagonal)\n        snp.testing.assert_allclose(DId @ x, operator(diagonal, 1.0) * x, rtol=1e-6)\n\n        IdIds = operator(Id, Ids)\n        assert isinstance(IdIds, linop.ScaledIdentity)\n        snp.testing.assert_allclose(IdIds @ x, operator(1.0, scalar) * x, rtol=1e-6)\n\n        IdsId = operator(Ids, Id)\n        assert isinstance(IdsId, linop.ScaledIdentity)\n        snp.testing.assert_allclose(IdsId @ x, operator(scalar, 1.0) * x, rtol=1e-6)\n\n    def test_scale(self):\n        input_shape = (5,)\n        input_dtype = np.float32\n        scalar, key = randn((), dtype=input_dtype, key=self.key)\n        x, key = randn(input_shape, dtype=input_dtype, key=key)\n        Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype)\n\n        sId = scalar * Id\n        assert isinstance(sId, linop.ScaledIdentity)\n        snp.testing.assert_allclose(sId @ x, scalar * x, rtol=1e-6)\n\n        Ids = Id * scalar\n        assert isinstance(Ids, linop.ScaledIdentity)\n        snp.testing.assert_allclose(Ids @ x, scalar * x, rtol=1e-6)\n\n        Idds = Id / scalar\n        assert isinstance(Idds, linop.ScaledIdentity)\n        snp.testing.assert_allclose(Idds @ x, x / scalar, rtol=1e-6)\n"
  },
  {
    "path": "scico/test/linop/test_diff.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop import FiniteDifference, SingleAxisFiniteDifference\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import adjoint_test\n\n\ndef test_eval():\n    with pytest.raises(ValueError):  # axis 3 does not exist\n        A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3))\n\n    A = FiniteDifference(input_shape=(2, 3), append=1)\n\n    x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32)\n\n    Ax = A @ x\n\n    snp.testing.assert_allclose(\n        Ax[0],  # down columns x[1] - x[0], ..., append - x[N-1]\n        snp.array([[0, 1, -1], [-1, -1, 0]]),\n    )\n    snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]]))  # along rows\n\n    # test scale\n    B = 2.0 * A\n    Bx = B @ x\n\n    snp.testing.assert_allclose(\n        Bx[0],  # down columns x[1] - x[0], ..., append - x[N-1]\n        2.0 * snp.array([[0, 1, -1], [-1, -1, 0]]),\n    )\n    snp.testing.assert_allclose(Bx[1], 2.0 * snp.array([[-1, 1, -1], [0, -1, 0]]))  # along rows\n\n\ndef test_except():\n    with pytest.raises(TypeError):  # axis is not an int\n        A = SingleAxisFiniteDifference(input_shape=(3,), axis=2.5)\n\n    with pytest.raises(ValueError):  # invalid parameter combination\n        A = SingleAxisFiniteDifference(input_shape=(3,), prepend=0, circular=True)\n\n    with pytest.raises(ValueError):  # invalid prepend value\n        A = SingleAxisFiniteDifference(input_shape=(3,), prepend=2)\n\n    with pytest.raises(ValueError):  # invalid append value\n        A = SingleAxisFiniteDifference(input_shape=(3,), append=\"a\")\n\n\ndef test_eval_prepend():\n    x = snp.arange(1, 6)\n    A = SingleAxisFiniteDifference(input_shape=(5,), prepend=0)\n    snp.testing.assert_allclose(A @ x, snp.array([0, 1, 1, 1, 1]))\n    A = SingleAxisFiniteDifference(input_shape=(5,), prepend=1)\n    snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 1]))\n\n\ndef 
test_eval_append():\n    x = snp.arange(1, 6)\n    A = SingleAxisFiniteDifference(input_shape=(5,), append=0)\n    snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 0]))\n    A = SingleAxisFiniteDifference(input_shape=(5,), append=1)\n    snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, -5]))\n\n\n@pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n@pytest.mark.parametrize(\"input_shape\", [(16,), (16, 24)])\n@pytest.mark.parametrize(\"axes\", [0, 1, (0,), (1,), None])\n@pytest.mark.parametrize(\"jit\", [False, True])\ndef test_adjoint(input_shape, input_dtype, axes, jit):\n    ndim = len(input_shape)\n    if axes in [1, (1,)] and ndim == 1:\n        return\n\n    A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit)\n    adjoint_test(A)\n\n\n@pytest.mark.parametrize(\n    \"shape_axes\",\n    [\n        ((3, 4), None),  # 2d\n        ((3, 4), 0),  # 2d specific axis\n        ((3, 4, 5), None),  # 3d\n        ((3, 4, 5), [0, 2]),  # 3d specific axes\n    ],\n)\n@pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n@pytest.mark.parametrize(\"jit\", [False, True])\ndef test_eval_circular(shape_axes, input_dtype, jit):\n    input_shape, axes = shape_axes\n    x, _ = randn(input_shape, dtype=input_dtype)\n    A = FiniteDifference(\n        input_shape=input_shape, input_dtype=input_dtype, axes=axes, circular=True, jit=jit\n    )\n    Ax = A @ x\n\n    # check that correct differences are returned\n    for ax in A.axes:\n        np.testing.assert_allclose(np.roll(x, -1, ax) - x, Ax[ax], atol=1e-5, rtol=0)\n\n    # check that the all results match noncircular results except at the last pixel\n    B = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit)\n    Bx = B @ x\n\n    for ax_ind, ax in enumerate(A.axes):\n        np.testing.assert_allclose(\n            Ax[\n                (ax_ind,)\n                + tuple(slice(0, -1) if i == ax else 
slice(None) for i in range(len(input_shape)))\n            ],\n            Bx[ax_ind],\n            atol=1e-5,\n            rtol=0,\n        )\n"
  },
  {
    "path": "scico/test/linop/test_func.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import adjoint_test\n\n\ndef test_transpose():\n    shape = (1, 2, 3, 4)\n    perm = (1, 0, 3, 2)\n    x, _ = randn(shape)\n    H = linop.Transpose(shape, perm)\n    np.testing.assert_array_equal(H @ x, x.transpose(perm))\n\n    # transpose transpose is transpose inverse\n    np.testing.assert_array_equal(H.T @ H @ x, x)\n\n\ndef test_transpose_ext_init():\n    shape = (1, 2, 3, 4)\n    perm = (1, 0, 3, 2)\n    x, _ = randn(shape)\n    H = linop.Transpose(\n        shape, perm, input_dtype=snp.float32, output_shape=shape, output_dtype=snp.float32\n    )\n    np.testing.assert_array_equal(H @ x, x.transpose(perm))\n\n\ndef test_reshape():\n    shape = (1, 2, 3, 4)\n    newshape = (2, 12)\n    x, _ = randn(shape)\n    H = linop.Reshape(shape, newshape)\n    np.testing.assert_array_equal(H @ x, x.reshape(newshape))\n\n    # reshape reshape is reshape inverse\n    np.testing.assert_array_equal(H.T @ H @ x, x)\n\n\ndef test_pad():\n    shape = (2, 3, 4)\n    pad = 1\n    x, _ = randn(shape)\n    H = linop.Pad(shape, pad)\n\n    pad_shape = tuple(n + 2 * pad for n in shape)\n    y = snp.zeros(pad_shape)\n    y = y.at[pad:-pad, pad:-pad, pad:-pad].set(x)\n    np.testing.assert_array_equal(H @ x, y)\n\n    # pad transpose is crop\n    y, _ = randn(pad_shape)\n    np.testing.assert_array_equal(H.T @ y, y[pad:-pad, pad:-pad, pad:-pad])\n\n\ndef test_crop():\n    shape = (7, 9)\n    crop = (1, 2)\n    x, _ = randn(shape)\n    H = linop.Crop(crop, shape)\n\n    y = x[crop[0] : -crop[1], crop[0] : -crop[1]]\n    np.testing.assert_array_equal(H @ x, y)\n\n\n@pytest.mark.parametrize(\"pad\", [1, (1, 2), ((1, 0), (0, 1)), ((1, 1), (2, 2))])\ndef test_crop_pad_adjoint(pad):\n    shape = (9, 10)\n    H = linop.Pad(shape, pad)\n    G = linop.Crop(pad, H.output_shape)\n    assert linop.valid_adjoint(H, G, 
eps=1e-5)\n\n\nclass SliceTestObj:\n    def __init__(self, dtype):\n        self.x = snp.zeros((4, 5, 6, 7), dtype=dtype)\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.complex64])\ndef slicetestobj(request):\n    yield SliceTestObj(request.param)\n\n\nslice_examples = [\n    np.s_[1:],\n    np.s_[:, 2:],\n    np.s_[..., 3:],\n    np.s_[1:, :-3],\n    np.s_[1:, :, :3],\n    np.s_[1:, ..., 2:],\n    np.s_[np.newaxis],\n    np.s_[:, np.newaxis],\n]\n\n\n@pytest.mark.parametrize(\"idx\", slice_examples)\ndef test_slice_eval(slicetestobj, idx):\n    x = slicetestobj.x\n    A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype)\n    assert (A @ x).shape == x[idx].shape\n\n\n@pytest.mark.parametrize(\"idx\", slice_examples)\ndef test_slice_adj(slicetestobj, idx):\n    x = slicetestobj.x\n    A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype)\n    adjoint_test(A)\n\n\nblock_slice_examples = [\n    1,\n    np.s_[0:1],\n    np.s_[:1],\n]\n\n\n@pytest.mark.parametrize(\"idx\", block_slice_examples)\ndef test_slice_blockarray(idx):\n    x = snp.BlockArray((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6))))\n    A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype)\n    assert (A @ x).shape == x[idx].shape\n"
  },
  {
    "path": "scico/test/linop/test_grad.py",
    "content": "from itertools import combinations\n\nimport numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop import (\n    CylindricalGradient,\n    PolarGradient,\n    ProjectedGradient,\n    SphericalGradient,\n)\nfrom scico.numpy import Array\nfrom scico.random import randn\n\n\ndef test_proj_grad():\n    x = snp.ones((4, 5))\n\n    P = ProjectedGradient(x.shape, axes=(0,))\n    assert P(x).shape == (4, 5)\n\n    P = ProjectedGradient(x.shape)\n    assert P(x).shape == (2, 4, 5)\n\n    P = ProjectedGradient(x.shape, coord=(np.arange(0, 4)[:, np.newaxis],))\n    assert P(x).shape == (4, 5)\n\n    coord = (\n        snp.blockarray([snp.array([0.0]), snp.array([1.0])]),\n        snp.blockarray([snp.array([1.0]), snp.array([0.0])]),\n    )\n    P = ProjectedGradient(x.shape, coord=coord)\n    assert P(x).shape == (2, 4, 5)\n\n\nclass TestPolarGradient:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"outflags\", [(True, True), (True, False), (False, True)])\n    @pytest.mark.parametrize(\"center\", [None, (-2, 3), (1.2, -3.5)])\n    @pytest.mark.parametrize(\n        \"shape_axes\",\n        [\n            ((20, 20), None),\n            ((20, 21), (0, 1)),\n            ((16, 17, 3), (0, 1)),\n            ((2, 17, 16), (1, 2)),\n            ((2, 17, 16, 3), (2, 1)),\n        ],\n    )\n    @pytest.mark.parametrize(\"cdiff\", [True, False])\n    def test_eval(self, cdiff, shape_axes, center, outflags, input_dtype, jit):\n\n        input_shape, axes = shape_axes\n        if axes is None:\n            testaxes = (0, 1)\n        else:\n            testaxes = axes\n        if center is not None:\n            axes_shape = [input_shape[ax] for ax in testaxes]\n            center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)\n        
angular, radial = outflags\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        A = PolarGradient(\n            input_shape,\n            axes=axes,\n            center=center,\n            angular=angular,\n            radial=radial,\n            cdiff=cdiff,\n            input_dtype=input_dtype,\n            jit=jit,\n        )\n        Ax = A @ x\n        assert isinstance(Ax, Array)\n        if angular and radial:\n            assert Ax.shape[0] == 2\n            assert Ax.shape[1:] == input_shape\n        else:\n            assert Ax.shape == input_shape\n        assert Ax.dtype == input_dtype\n\n        # Test orthogonality of coordinate axes\n        coord = A.coord\n        for n0, n1 in combinations(range(len(coord)), 2):\n            c0 = coord[n0]\n            c1 = coord[n1]\n            assert snp.abs(snp.sum(c0 * c1)) < 1e-5\n\n\nclass TestCylindricalGradient:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\n        \"outflags\",\n        [\n            (True, True, True),\n            (True, True, False),\n            (True, False, True),\n            (True, False, False),\n            (False, True, True),\n            (False, True, False),\n            (False, False, True),\n        ],\n    )\n    @pytest.mark.parametrize(\"center\", [None, (-2, 3, 0), (1.2, -3.5, 1.5)])\n    @pytest.mark.parametrize(\n        \"shape_axes\",\n        [\n            ((20, 20, 20), None),\n            ((17, 18, 19), (0, 1, 2)),\n            ((16, 17, 18, 3), (0, 1, 2)),\n            ((2, 17, 16, 15), (1, 2, 3)),\n            ((17, 2, 16, 15), (0, 2, 3)),\n            ((17, 2, 16, 15), (3, 2, 0)),\n        ],\n    )\n    def test_eval(self, shape_axes, center, outflags, input_dtype, jit):\n\n        input_shape, axes = shape_axes\n        if axes is 
None:\n            testaxes = (0, 1, 2)\n        else:\n            testaxes = axes\n        if center is not None:\n            axes_shape = [input_shape[ax] for ax in testaxes]\n            center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)\n        angular, radial, axial = outflags\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        A = CylindricalGradient(\n            input_shape,\n            axes=axes,\n            center=center,\n            angular=angular,\n            radial=radial,\n            axial=axial,\n            input_dtype=input_dtype,\n            jit=jit,\n        )\n        Ax = A @ x\n        assert isinstance(Ax, Array)\n        Nc = sum([angular, radial, axial])\n        if Nc > 1:\n            assert Ax.shape[0] == Nc\n            assert Ax.shape[1:] == input_shape\n        else:\n            assert Ax.shape == input_shape\n        assert Ax.dtype == input_dtype\n\n        # Test orthogonality of coordinate axes\n        coord = A.coord\n        for n0, n1 in combinations(range(len(coord)), 2):\n            c0 = coord[n0]\n            c1 = coord[n1]\n            s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum()\n            assert snp.abs(s) < 1e-5\n\n\nclass TestSphericalGradient:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\n        \"outflags\",\n        [\n            (True, True, True),\n            (True, True, False),\n            (True, False, True),\n            (True, False, False),\n            (False, True, True),\n            (False, True, False),\n            (False, False, True),\n        ],\n    )\n    @pytest.mark.parametrize(\"center\", [None, (-2, 3, 0), (1.2, -3.5, 1.5)])\n    @pytest.mark.parametrize(\n        \"shape_axes\",\n        [\n            ((20, 20, 20), None),\n            
((17, 18, 19), (0, 1, 2)),\n            ((16, 17, 18, 3), (0, 1, 2)),\n            ((2, 17, 16, 15), (1, 2, 3)),\n            ((17, 2, 16, 15), (0, 2, 3)),\n            ((17, 2, 16, 15), (3, 2, 0)),\n        ],\n    )\n    def test_eval(self, shape_axes, center, outflags, input_dtype, jit):\n\n        input_shape, axes = shape_axes\n        if axes is None:\n            testaxes = (0, 1, 2)\n        else:\n            testaxes = axes\n        if center is not None:\n            axes_shape = [input_shape[ax] for ax in testaxes]\n            center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)\n        azimuthal, polar, radial = outflags\n        x, key = randn(input_shape, dtype=input_dtype, key=self.key)\n        A = SphericalGradient(\n            input_shape,\n            axes=axes,\n            center=center,\n            azimuthal=azimuthal,\n            polar=polar,\n            radial=radial,\n            input_dtype=input_dtype,\n            jit=jit,\n        )\n        Ax = A @ x\n        assert isinstance(Ax, Array)\n        Nc = sum([azimuthal, polar, radial])\n        if Nc > 1:\n            assert Ax.shape[0] == Nc\n            assert Ax.shape[1:] == input_shape\n        else:\n            assert Ax.shape == input_shape\n        assert Ax.dtype == input_dtype\n\n        # Test orthogonality of coordinate axes\n        coord = A.coord\n        for n0, n1 in combinations(range(len(coord)), 2):\n            c0 = coord[n0]\n            c1 = coord[n1]\n            s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum()\n            assert snp.abs(s) < 1e-5\n"
  },
  {
    "path": "scico/test/linop/test_linop.py",
    "content": "import operator as op\n\nimport numpy as np\n\nfrom jax import config\n\nimport pytest\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\nfrom typing import Optional\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.random import randn\nfrom scico.typing import PRNGKey\n\nSCALARS = (2, 1e0, snp.array(1.0))\n\n\ndef adjoint_test(\n    A: linop.LinearOperator,\n    key: Optional[PRNGKey] = None,\n    rtol: float = 1e-4,\n    x: Optional[snp.Array] = None,\n    y: Optional[snp.Array] = None,\n):\n    \"\"\"Check the validity of A.conj().T as the adjoint for a LinearOperator A.\n\n    Args:\n        A: LinearOperator to test.\n        key: PRNGKey for generating `x`.\n        rtol: Relative tolerance.\n    \"\"\"\n\n    assert linop.valid_adjoint(A, A.H, key=key, eps=rtol, x=x, y=y)\n\n\nclass AbsMatOp(linop.LinearOperator):\n    \"\"\"Simple LinearOperator subclass for testing purposes.\n\n    Similar to linop.MatrixOperator, but does not use the specialized\n    MatrixOperator methods (.T, adj, etc). 
Used to verify the\n    LinearOperator interface.\n    \"\"\"\n\n    def __init__(self, A, adj_fn=None):\n        self.A = A\n        super().__init__(\n            input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=A.dtype, adj_fn=adj_fn\n        )\n\n    def _eval(self, x):\n        return self.A @ x\n\n\nclass LinearOperatorTestObj:\n    def __init__(self, dtype):\n        M, N = (8, 16)\n        key = jax.random.key(12345)\n        self.dtype = dtype\n\n        self.A, key = randn((M, N), dtype=dtype, key=key)\n        self.B, key = randn((M, N), dtype=dtype, key=key)\n        self.C, key = randn((N, M), dtype=dtype, key=key)\n        self.D, key = randn((M, N - 1), dtype=dtype, key=key)\n\n        self.x, key = randn((N,), dtype=dtype, key=key)\n        self.y, key = randn((M,), dtype=dtype, key=key)\n\n        self.Ao = AbsMatOp(self.A)\n        self.Bo = AbsMatOp(self.B)\n        self.Co = AbsMatOp(self.C)\n        self.Do = AbsMatOp(self.D)\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.float64, np.complex64, np.complex128])\ndef testobj(request):\n    yield LinearOperatorTestObj(request.param)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_binary_op(testobj, operator):\n    # Our AbsMatOp class does not override the __add__, etc\n    # so AbsMatOp + AbsMatOp -> LinearOperator\n    # So to verify results, we evaluate the new LinearOperator on a random input\n\n    comp_mat = operator(testobj.A, testobj.B)  # composite matrix\n    comp_op = operator(testobj.Ao, testobj.Bo)  # composite linop\n\n    assert isinstance(comp_op, linop.LinearOperator)  # Ensure we don't get a Map\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=0, atol=1e-5)\n\n    # linops of different sizes\n    with pytest.raises(ValueError):\n        operator(testobj.Ao, testobj.Co)\n    with pytest.raises(ValueError):\n        operator(testobj.Ao, 
testobj.Do)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n@pytest.mark.parametrize(\"scalar\", SCALARS)\ndef test_scalar_left(testobj, operator, scalar):\n    comp_mat = operator(testobj.A, scalar)\n    comp_op = operator(testobj.Ao, scalar)\n    assert isinstance(comp_op, linop.LinearOperator)  # Ensure we don't get a Map\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)\n    np.testing.assert_allclose(comp_mat.conj().T @ testobj.y, comp_op.adj(testobj.y), rtol=2e-4)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n@pytest.mark.parametrize(\"scalar\", SCALARS)\ndef test_scalar_right(testobj, operator, scalar):\n    if operator == op.truediv:\n        pytest.xfail(\"scalar / LinearOperator is not supported\")\n    comp_mat = operator(scalar, testobj.A)\n    comp_op = operator(scalar, testobj.Ao)\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)\n\n\ndef test_negation(testobj):\n    comp_mat = -testobj.A\n    comp_op = -testobj.Ao\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_invalid_add_sub_array(testobj, operator):\n    # Try to add or subtract an ndarray with AbsMatOp\n    with pytest.raises(TypeError):\n        operator(testobj.A, testobj.Ao)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_invalid_add_sub_scalar(testobj, operator):\n    # Try to add or subtract a scalar with AbsMatOp\n    with pytest.raises(TypeError):\n        operator(1.0, testobj.Ao)\n\n\ndef test_matmul_left(testobj):\n    comp_mat = testobj.A @ testobj.C\n    comp_op = testobj.Ao @ testobj.Co\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.y, 
comp_op @ testobj.y, rtol=5e-5)\n\n\ndef test_matmul_right(testobj):\n    comp_mat = testobj.C @ testobj.A\n    comp_op = testobj.Co @ testobj.Ao\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)\n\n\ndef test_matvec_left(testobj):\n    comp_mat = testobj.A @ testobj.x\n    comp_op = testobj.Ao @ testobj.x\n    assert comp_op.dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat, comp_op, rtol=5e-5)\n\n\ndef test_matvec_right(testobj):\n    comp_mat = testobj.C @ testobj.y\n    comp_op = testobj.Co @ testobj.y\n    assert comp_op.dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat, comp_op, rtol=5e-5)\n\n\ndef test_gram(testobj):\n    Ao = testobj.Ao\n    a = Ao.gram(testobj.x)\n    b = Ao.conj().T @ Ao @ testobj.x\n    c = Ao.gram_op @ testobj.x\n\n    comp_mat = testobj.A.conj().T @ testobj.A @ testobj.x\n\n    np.testing.assert_allclose(a, comp_mat, rtol=5e-5)\n    np.testing.assert_allclose(b, comp_mat, rtol=5e-5)\n    np.testing.assert_allclose(c, comp_mat, rtol=5e-5)\n\n\ndef test_matvec_call(testobj):\n    # A @ x and A(x) should return same\n    np.testing.assert_allclose(testobj.Ao @ testobj.x, testobj.Ao(testobj.x), rtol=5e-5)\n\n\ndef test_adj_composition(testobj):\n    Ao = testobj.Ao\n    Bo = testobj.Bo\n    A = testobj.A\n    B = testobj.B\n    x = testobj.x\n\n    comp_mat = A.conj().T @ B\n    a = Ao.conj().T @ Bo\n    b = Ao.adj(Bo)\n    assert a.input_dtype == testobj.A.dtype\n    assert b.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ x, a @ x, rtol=5e-5)\n    np.testing.assert_allclose(comp_mat @ x, b @ x, rtol=5e-5)\n\n\ndef test_transpose_matvec(testobj):\n    Ao = testobj.Ao\n    y = testobj.y\n\n    a = Ao.T @ y\n    b = y.T @ Ao\n\n    comp_mat = testobj.A.T @ y\n\n    assert a.dtype == testobj.A.dtype\n    assert b.dtype == testobj.A.dtype\n    np.testing.assert_allclose(a, comp_mat, rtol=2e-4)\n    
np.testing.assert_allclose(a, b, rtol=5e-5)\n\n\ndef test_transpose_matmul(testobj):\n    Ao = testobj.Ao\n    Bo = testobj.Bo\n    x = testobj.x\n    comp_op = Ao.T @ Bo\n    comp_mat = testobj.A.T @ testobj.B\n    assert comp_op.input_dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ x, comp_op @ x, rtol=5e-5)\n\n\ndef test_conj_transpose_matmul(testobj):\n    Ao = testobj.Ao\n    Bo = testobj.Bo\n    x = testobj.x\n    comp_op = Ao.conj().T @ Bo\n    comp_mat = testobj.A.conj().T @ testobj.B\n    assert comp_mat.dtype == testobj.A.dtype\n    np.testing.assert_allclose(comp_mat @ x, comp_op @ x, rtol=5e-5)\n\n\ndef test_conj_matvec(testobj):\n    Ao = testobj.Ao\n    x = testobj.x\n    a = Ao.conj() @ x\n    comp_mat = testobj.A.conj() @ x\n    assert a.dtype == testobj.A.dtype\n    np.testing.assert_allclose(a, comp_mat, rtol=5e-5)\n\n\ndef test_adjoint_matvec(testobj):\n    Ao = testobj.Ao\n    y = testobj.y\n\n    a = Ao.adj(y)\n    b = Ao.conj().T @ y\n    c = (y.conj().T @ Ao).conj()\n\n    comp_mat = testobj.A.conj().T @ y\n\n    assert a.dtype == testobj.A.dtype\n    assert b.dtype == testobj.A.dtype\n    assert c.dtype == testobj.A.dtype\n    np.testing.assert_allclose(a, comp_mat, rtol=2e-4)\n    np.testing.assert_allclose(a, b, rtol=5e-5)\n    np.testing.assert_allclose(a, c, rtol=5e-5)\n\n\ndef test_adjoint_matmul(testobj):\n    # shape mismatch\n    Ao = testobj.Ao\n    Co = testobj.Co\n\n    with pytest.raises(ValueError):\n        Ao.adj(Co)\n\n\ndef test_hermitian(testobj):\n    Ao = testobj.Ao\n    y = testobj.y\n\n    np.testing.assert_allclose(Ao.conj().T @ y, Ao.H @ y)\n\n\ndef test_shape(testobj):\n    Ao = testobj.Ao\n    x = testobj.x\n    y = testobj.y\n\n    with pytest.raises(ValueError):\n        _ = Ao @ y\n\n    with pytest.raises(ValueError):\n        _ = Ao(y)\n\n    with pytest.raises(ValueError):\n        _ = Ao.T @ x\n\n    with pytest.raises(ValueError):\n        _ = Ao.adj(x)\n\n\ndef test_adj_lazy():\n    
dtype = np.float32\n    M, N = (8, 16)\n    A, key = randn((M, N), dtype=np.float32, key=None)\n    y, key = randn((M,), dtype=np.float32, key=key)\n    Ao = AbsMatOp(A, adj_fn=None)  # defer setting the linop\n\n    assert Ao._adj is None\n    a = Ao.adj(y)  # Adjoint is set when .adj() is called\n    b = A.T @ y\n    np.testing.assert_allclose(a, b, rtol=1e-5)\n\n\ndef test_jit_adj_lazy():\n    dtype = np.float32\n    M, N = (8, 16)\n    A, key = randn((M, N), dtype=np.float32, key=None)\n    y, key = randn((M,), dtype=np.float32, key=key)\n    Ao = AbsMatOp(A, adj_fn=None)  # defer setting the linop\n    assert Ao._adj is None\n    Ao.jit()  # Adjoint set here\n    assert Ao._adj is not None\n    a = Ao.adj(y)\n    b = A.T @ y\n    np.testing.assert_allclose(a, b, rtol=1e-5)\n"
  },
  {
    "path": "scico/test/linop/test_linop_stack.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop import (\n    Convolve,\n    DiagonalReplicated,\n    DiagonalStack,\n    Identity,\n    Sum,\n    VerticalStack,\n)\nfrom scico.operator import Abs\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import adjoint_test\n\n\nclass TestVerticalStack:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_construct(self, jit):\n        # requires a list of LinearOperators\n        Id = Identity((42,))\n        with pytest.raises(TypeError):\n            H = VerticalStack(Id, jit=jit)\n\n        # requires all list elements to be LinearOperators\n        A = Abs((42,))\n        with pytest.raises(TypeError):\n            H = VerticalStack((A, Id), jit=jit)\n\n        # checks input sizes\n        A = Identity((3, 2))\n        B = Identity((7, 2))\n        with pytest.raises(ValueError):\n            H = VerticalStack([A, B], jit=jit)\n\n        # in general, returns a BlockArray\n        A = Convolve(snp.ones((3, 3)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], jit=jit)\n        x = np.ones((7, 11))\n        y = H @ x\n        assert y.shape == ((9, 13), (8, 12))\n\n        # ... result should be [A@x, B@x]\n        assert np.allclose(y[0], A @ x)\n        assert np.allclose(y[1], B @ x)\n\n        # by default, collapse_output to jax array when possible\n        A = Convolve(snp.ones((2, 2)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], jit=jit)\n        x = np.ones((7, 11))\n        y = H @ x\n        assert y.shape == (2, 8, 12)\n\n        # ... 
result should be [A@x, B@x]\n        assert np.allclose(y[0], A @ x)\n        assert np.allclose(y[1], B @ x)\n\n        # let user turn off collapsing\n        A = Convolve(snp.ones((2, 2)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], collapse_output=False, jit=jit)\n        x = np.ones((7, 11))\n        y = H @ x\n        assert y.shape == ((8, 12), (8, 12))\n\n    @pytest.mark.parametrize(\"collapse_output\", [False, True])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_adjoint(self, collapse_output, jit):\n        # general case\n        A = Convolve(snp.ones((3, 3)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n        adjoint_test(H, self.key)\n\n        # collapsable case\n        A = Convolve(snp.ones((2, 2)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n        adjoint_test(H, self.key)\n\n    @pytest.mark.parametrize(\"collapse_output\", [False, True])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_algebra(self, collapse_output, jit):\n        # adding\n        A = Convolve(snp.ones((2, 2)), (7, 11))\n        B = Convolve(snp.ones((2, 2)), (7, 11))\n        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n\n        A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))\n        B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11))\n        G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n\n        x = np.ones((7, 11))\n        S = H + G\n\n        # test correctness of addition\n        assert S.output_shape == H.output_shape\n        assert S.input_shape == H.input_shape\n        np.testing.assert_allclose((S @ x)[0], (H @ x + G @ x)[0])\n        np.testing.assert_allclose((S @ x)[1], (H @ x + G @ x)[1])\n\n\nclass 
TestBlockDiagonalLinearOperator:\n    def test_construct(self):\n        Id = Identity((42,))\n        A = Abs((42,))\n        with pytest.raises(TypeError):\n            H = DiagonalStack((A, Id))\n\n    def test_apply(self):\n        S1 = (3, 4)\n        S2 = (3, 5)\n        S3 = (2, 2)\n        A1 = Identity(S1)\n        A2 = 2 * Identity(S2)\n        A3 = Sum(S3)\n        H = DiagonalStack((A1, A2, A3))\n\n        x = snp.ones((S1, S2, S3))\n        y = H @ x\n        y_expected = snp.blockarray((snp.ones(S1), 2 * snp.ones(S2), snp.sum(snp.ones(S3))))\n\n        np.testing.assert_equal(y, y_expected)\n\n    def test_adjoint(self):\n        S1 = (3, 4)\n        S2 = (3, 5)\n        S3 = (2, 2)\n        A1 = Identity(S1)\n        A2 = 2 * Identity(S2)\n        A3 = Sum(S3)\n        H = DiagonalStack((A1, A2, A3))\n\n        y = snp.ones((S1, S2, ()), dtype=snp.float32)\n        x = H.T @ y\n        x_expected = snp.blockarray(\n            (\n                snp.ones(S1),\n                snp.ones(S2),\n                snp.ones(S3),\n            )\n        )\n\n        assert x == x_expected\n\n    def test_input_collapse(self):\n        S = (3, 4)\n        A1 = Identity(S)\n        A2 = Sum(S)\n\n        H = DiagonalStack((A1, A2))\n        assert H.input_shape == (2, *S)\n\n        H = DiagonalStack((A1, A2), collapse_input=False)\n        assert H.input_shape == (S, S)\n\n    def test_output_collapse(self):\n        S1 = (3, 4)\n        S2 = (5, 3, 4)\n        A1 = Identity(S1)\n        A2 = Sum(S2, axis=0)\n\n        H = DiagonalStack((A1, A2))\n        assert H.output_shape == (2, *S1)\n\n        H = DiagonalStack((A1, A2), collapse_output=False)\n        assert H.output_shape == (S1, S1)\n\n\nclass TestDiagonalReplicated:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    def test_adjoint(self):\n        x, key = randn((2, 3, 4), key=self.key)\n        A = Sum(x.shape[1:], axis=-1)\n        D = DiagonalReplicated(A, 
x.shape[0])\n        y = D.T(D(x))\n        np.testing.assert_allclose(y[0], A.T(A(x[0])))\n        np.testing.assert_allclose(y[1], A.T(A(x[1])))\n"
  },
  {
    "path": "scico/test/linop/test_linop_util.py",
    "content": "import numpy as np\n\nfrom jax import config\n\nimport pytest\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.operator import Operator\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import AbsMatOp\n\n\ndef test_valid_adjoint():\n    diagonal, key = randn((10,), dtype=np.float32)\n    D = linop.Diagonal(diagonal=diagonal)\n    assert linop.valid_adjoint(D, D.T, key=key, eps=None) < 1e-4\n    x, key = randn((5,), dtype=np.float32)\n    y, key = randn((5,), dtype=np.float32)\n    with pytest.raises(ValueError):\n        linop.valid_adjoint(D, D.T, key=key, x=x)\n    with pytest.raises(ValueError):\n        linop.valid_adjoint(D, D.T, key=key, y=y)\n\n\nclass PowerIterTestObj:\n    def __init__(self, dtype):\n        M, N = (4, 4)\n        key = jax.random.key(12345)\n        self.dtype = dtype\n\n        A, key = randn((M, N), dtype=dtype, key=key)\n        self.A = A.conj().T @ A  # ensure symmetric\n\n        self.Ao = linop.MatrixOperator(self.A)\n        self.Bo = AbsMatOp(self.A)\n\n        self.key = key\n        self.ev = snp.linalg.norm(\n            self.A, 2\n        )  # The largest eigenvalue of A is the spectral norm of A\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.complex64])\ndef pitestobj(request):\n    yield PowerIterTestObj(request.param)\n\n\ndef test_power_iteration(pitestobj):\n    \"\"\"Verify that power iteration calculates largest eigenvalue for real and complex\n    symmetric matrices.\n    \"\"\"\n    # Test using the LinearOperator MatrixOperator\n    mu, v = linop.power_iteration(A=pitestobj.Ao, maxiter=100, key=pitestobj.key)\n    assert np.abs(mu - pitestobj.ev) < 1e-4\n\n    # Test using the AbsMatOp for test_linop.py\n    mu, v = linop.power_iteration(A=pitestobj.Bo, maxiter=100, key=pitestobj.key)\n    assert np.abs(mu - pitestobj.ev) < 1e-4\n\n\ndef 
test_operator_norm():\n    Iop = linop.Identity(8)\n    Inorm = linop.operator_norm(Iop)\n    assert np.abs(Inorm - 1.0) < 1e-5\n    key = jax.random.key(12345)\n    for dtype in [np.float32, np.complex64]:\n        d, key = randn((16,), dtype=dtype, key=key)\n        D = linop.Diagonal(d)\n        Dnorm = linop.operator_norm(D)\n        assert np.abs(Dnorm - snp.abs(d).max()) < 1e-5\n    Zop = linop.MatrixOperator(snp.zeros((3, 3)))\n    Znorm = linop.operator_norm(Zop)\n    assert np.abs(Znorm) < 1e-6\n\n\n@pytest.mark.parametrize(\"dtype\", [snp.float32, snp.complex64])\n@pytest.mark.parametrize(\"inc_eval\", [True, False])\ndef test_jacobian(dtype, inc_eval):\n    N = 7\n    M = 8\n    key = None\n    fmx, key = randn((M, N), key=key, dtype=dtype)\n    F = Operator(\n        (N, 1),\n        output_shape=(M, 1),\n        eval_fn=lambda x: fmx @ x,\n        input_dtype=dtype,\n        output_dtype=dtype,\n    )\n    u, key = randn((N, 1), key=key, dtype=dtype)\n    v, key = randn((N, 1), key=key, dtype=dtype)\n    w, key = randn((M, 1), key=key, dtype=dtype)\n\n    J = linop.jacobian(F, u, include_eval=inc_eval)\n    Jv = J(v)\n    JHw = J.H(w)\n\n    if inc_eval:\n        np.testing.assert_allclose(Jv[0], F(u))\n        np.testing.assert_allclose(Jv[1], F.jvp(u, v)[1])\n        np.testing.assert_allclose(JHw[0], F(u))\n        np.testing.assert_allclose(JHw[1], F.vjp(u)[1](w))\n    else:\n        np.testing.assert_allclose(Jv, F.jvp(u, v)[1])\n        np.testing.assert_allclose(JHw, F.vjp(u)[1](w))\n"
  },
  {
    "path": "scico/test/linop/test_matrix.py",
    "content": "import operator as op\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import linop\nfrom scico.linop import MatrixOperator\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import AbsMatOp\n\n\nclass TestMatrix:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_eval(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n\n        x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(A @ x, Ao @ x)\n\n        # Invalid shapes\n        with pytest.raises(TypeError):\n            y, key = randn((64,), dtype=Ao.input_dtype, key=key)\n            _ = Ao @ y\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_adjoint(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n\n        x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(A.conj().T @ x, Ao.conj().T @ x)\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_adjoint_method(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n   
     x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(Ao.adj(x), Ao.conj().T @ x)\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_hermetian_method(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n        x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(Ao.H @ x, Ao.conj().T @ x)\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_gram_method(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n        x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(Ao.gram(x), A.conj().T @ A @ x, rtol=5e-5)\n\n    @pytest.mark.parametrize(\"input_cols\", [0, 2])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"matrix_shape\", [(3, 3), (3, 4)])\n    def test_gram_op(self, matrix_shape, input_dtype, input_cols):\n        A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A, input_cols=input_cols)\n        G = Ao.gram_op\n        x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key)\n        np.testing.assert_allclose(G @ x, A.conj().T @ A @ x, rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_add_sub(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n        C, key = randn((4, 4), 
key=key)\n        x, key = randn((6,), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        Co = MatrixOperator(C)\n\n        ABx = operator(Ao, Bo) @ x\n        AxBx = operator(Ao @ x, Bo @ x)\n        np.testing.assert_allclose(ABx, AxBx, rtol=5e-5)\n\n        with pytest.raises(ValueError):\n            operator(Ao, Co)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_scalar_left(self, operator):\n        scalar = np.float32(np.random.randn())\n\n        A, key = randn((4, 6), key=self.key)\n        x, key = randn((6,), key=key)\n        Ao = MatrixOperator(A)\n\n        np.testing.assert_allclose(operator(scalar, Ao) @ x, operator(scalar, A) @ x, rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_scalar_right(self, operator):\n        scalar = np.float32(np.random.randn())\n\n        A, key = randn((4, 6), key=self.key)\n        x, key = randn((6,), key=key)\n        Ao = MatrixOperator(A)\n\n        np.testing.assert_allclose(operator(Ao, scalar) @ x, operator(A, scalar) @ x, rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_elementwise_matops(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n\n        np.testing.assert_allclose(operator(Ao, Bo).A, operator(A, B), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_elementwise_array_left(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        np.testing.assert_allclose(operator(Ao, B).A, operator(A, B), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def 
test_elementwise_array_right(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        np.testing.assert_allclose(operator(A, Bo).A, operator(A, B), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_elementwise_matop_shape_mismatch(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 4), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        with pytest.raises(ValueError):\n            operator(Ao, Bo)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub, op.mul, op.truediv])\n    def test_elementwise_array_shape_mismatch(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 4), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        with pytest.raises(ValueError):\n            operator(Ao, B)\n\n        with pytest.raises(ValueError):\n            operator(B, Ao)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_elementwise_linop(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n        Ao = MatrixOperator(A)\n        Bo = AbsMatOp(B)\n        x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key)\n\n        np.testing.assert_allclose(operator(Ao, Bo) @ x, operator(Ao @ x, Bo @ x), rtol=5e-5)\n\n    @pytest.mark.parametrize(\"operator\", [op.add, op.sub])\n    def test_elementwise_linop_mismatch(self, operator):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((4, 4), key=key)\n        Ao = MatrixOperator(A)\n        Bo = AbsMatOp(B)\n        with pytest.raises(ValueError):\n            operator(Ao, Bo)\n\n    @pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n    def test_elementwise_linop_invalid(self, operator):\n        A, key = 
randn((4, 6), key=self.key)\n        B, key = randn((4, 6), key=key)\n        Ao = MatrixOperator(A)\n        Bo = AbsMatOp(B)\n        with pytest.raises(TypeError):\n            operator(Ao, Bo)\n\n        with pytest.raises(TypeError):\n            operator(Bo, Ao)\n\n    def test_matmul(self):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((6, 3), key=key)\n        Ao = MatrixOperator(A)\n        Bo = MatrixOperator(B)\n        x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key)\n\n        AB = Ao @ Bo\n        np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5)\n\n    def test_matmul_cols(self):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((6, 3), key=key)\n        Ao = MatrixOperator(A, input_cols=2)\n        Bo = MatrixOperator(B, input_cols=2)\n        x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key)\n\n        AB = Ao @ Bo\n        np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5)\n\n    def test_matmul_linop(self):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((6, 3), key=key)\n        Ao = MatrixOperator(A)\n        Bo = AbsMatOp(B)\n        x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key)\n\n        AB = Ao @ Bo\n        np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5)\n\n    def test_matmul_linop_shape_mismatch(self):\n        A, key = randn((4, 6), key=self.key)\n        B, key = randn((5, 3), key=key)\n        Ao = MatrixOperator(A)\n        Bo = AbsMatOp(B)\n        with pytest.raises(ValueError):\n            _ = Ao @ Bo\n\n    def test_matmul_identity(self):\n        A, key = randn((4, 6), key=self.key)\n        Ao = MatrixOperator(A)\n        I = linop.Identity(input_shape=(6,))\n        assert Ao == Ao @ I\n\n    def test_init_array(self):\n        Am = np.random.randn(4, 6)\n        A = MatrixOperator(Am)\n        assert isinstance(A.A, np.ndarray)\n\n        A = 
MatrixOperator(jnp.array(Am))\n        assert isinstance(A.A, jnp.ndarray)\n        np.testing.assert_array_equal(A.A, jnp.array(A))\n\n        with pytest.raises(TypeError):\n            MatrixOperator([1.0, 3.0])\n\n    @pytest.mark.parametrize(\"matrix_shape\", [(3,), (2, 3, 4)])\n    def test_init_wrong_dims(self, matrix_shape):\n        A = np.random.randn(*matrix_shape)\n        with pytest.raises(TypeError):\n            Ao = MatrixOperator(A)\n\n    def test_to_array(self):\n        A = np.random.randn(4, 6)\n        Ao = MatrixOperator(A)\n        A_array = Ao.to_array()\n        assert isinstance(A_array, np.ndarray)\n        np.testing.assert_allclose(A_array, A)\n\n        A_array = jnp.array(Ao)\n        assert isinstance(A_array, jax.Array)\n        np.testing.assert_allclose(A_array, A)\n\n    @pytest.mark.parametrize(\"ord\", [\"fro\", 2])\n    @pytest.mark.parametrize(\"axis\", [None, 0, 1])\n    @pytest.mark.parametrize(\"keepdims\", [True, False])\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    def test_norm(self, ord, axis, keepdims, input_dtype):  # pylint: disable=W0622\n        A, key = randn((4, 6), dtype=input_dtype, key=self.key)\n        Ao = MatrixOperator(A)\n\n        if ord == \"fro\" and axis is not None:\n            # Not defined;\n            pass\n        else:\n            x = Ao.norm(ord=ord, axis=axis, keepdims=keepdims)\n            y = snp.linalg.norm(A, ord=ord, axis=axis, keepdims=keepdims)\n            np.testing.assert_allclose(x, y, rtol=5e-5)\n"
  },
  {
    "path": "scico/test/linop/test_optics.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nfrom scico.linop.optics import (\n    AngularSpectrumPropagator,\n    FraunhoferPropagator,\n    FresnelPropagator,\n    radial_transverse_frequency,\n)\nfrom scico.random import randn\nfrom scico.test.linop.test_linop import adjoint_test\n\nprop_list = [AngularSpectrumPropagator, FresnelPropagator, FraunhoferPropagator]\n\n\nclass TestPropagator:\n    def setup_method(self, method):\n        key = jax.random.key(12345)\n        self.N = 128\n        self.dx = 1\n        self.k0 = 1\n        self.z = 1\n        self.key = key\n\n    @pytest.mark.parametrize(\"ndim\", [1, 2])\n    @pytest.mark.parametrize(\"prop\", prop_list)\n    def test_prop_adjoint(self, prop, ndim):\n        A = prop(input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z)\n        adjoint_test(A, self.key)\n\n    @pytest.mark.parametrize(\"ndim\", [1, 2])\n    def test_AS_inverse(self, ndim):\n        A = AngularSpectrumPropagator(\n            input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z\n        )\n        x, key = randn(A.input_shape, dtype=np.complex64, key=self.key)\n        Ax = A @ x\n        AiAx = A.pinv(Ax)\n        np.testing.assert_allclose(x, AiAx, rtol=6e-4)\n\n    @pytest.mark.parametrize(\"prop\", prop_list)\n    def test_3d_invalid(self, prop):\n        with pytest.raises(ValueError):\n            prop(input_shape=(self.N, self.N, self.N), dx=self.dx, k0=self.k0, z=self.z)\n\n    @pytest.mark.parametrize(\"prop\", prop_list)\n    def test_shape_dx_mismatch(self, prop):\n        with pytest.raises(ValueError):\n            prop(input_shape=(self.N,), dx=(self.dx, self.dx), k0=self.k0, z=self.z)\n\n    def test_3d_invalid_radial(self):\n        with pytest.raises(ValueError):\n            radial_transverse_frequency(input_shape=(self.N, self.N, self.N), dx=self.dx)\n\n    def test_shape_dx_mismatch_radial(self):\n        with pytest.raises(ValueError):\n            
radial_transverse_frequency(input_shape=(self.N,), dx=(self.dx, self.dx))\n\n\n@pytest.mark.parametrize(\"ndim\", [1, 2])\ndef test_asp_sampling(ndim):\n    N = 128\n    dx = 1\n    z = 1\n    A = AngularSpectrumPropagator(input_shape=(N,) * ndim, dx=dx, k0=1, z=z)\n    assert not A.adequate_sampling()\n    A = AngularSpectrumPropagator(input_shape=(N,) * ndim, dx=dx, k0=100, z=z)\n    assert A.adequate_sampling()\n\n\n@pytest.mark.parametrize(\"ndim\", [1, 2])\ndef test_fresnel_sampling(ndim):\n    N = 128\n    dx = 1\n    k0 = 1\n    A = FresnelPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=N**2)\n    assert not A.adequate_sampling()\n    A = FresnelPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=1)\n    assert A.adequate_sampling()\n\n\n@pytest.mark.parametrize(\"ndim\", [1, 2])\ndef test_fraunhofer_sampling(ndim):\n    N = 128\n    dx = 1\n    k0 = 1\n    A = FraunhoferPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=N**2)\n    assert not A.adequate_sampling()\n    A = FraunhoferPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=1)\n    assert A.adequate_sampling()\n"
  },
  {
    "path": "scico/test/linop/xray/test_abel.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.linop.xray.abel import AbelTransform\nfrom scico.test.linop.test_linop import adjoint_test\n\nBIG_INPUT = (128, 128)\nSMALL_INPUT = (4, 5)\n\n\ndef make_im(Nx, Ny):\n    x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny))\n\n    im = snp.where(x**2 + y**2 < 0.3, 1.0, 0.0)\n\n    return im\n\n\n@pytest.mark.parametrize(\"Nx, Ny\", (BIG_INPUT, SMALL_INPUT))\ndef test_inverse(Nx, Ny):\n    im = make_im(Nx, Ny)\n    A = AbelTransform(im.shape)\n\n    Ax = A @ im\n    im_hat = A.inverse(Ax)\n\n    np.testing.assert_allclose(im_hat, im, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"Nx, Ny\", (BIG_INPUT, SMALL_INPUT))\ndef test_adjoint(Nx, Ny):\n    im = make_im(Nx, Ny)\n    A = AbelTransform(im.shape)\n    adjoint_test(A)\n\n\n@pytest.mark.parametrize(\"Nx, Ny\", (BIG_INPUT, SMALL_INPUT))\ndef test_ATA(Nx, Ny):\n    x = make_im(Nx, Ny)\n    A = AbelTransform(x.shape)\n    Ax = A(x)\n    ATAx = A.adj(Ax)\n    np.testing.assert_allclose(np.sum(x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"Nx, Ny\", (BIG_INPUT, SMALL_INPUT))\ndef test_grad(Nx, Ny):\n    # ensure that we can take grad on a function using our projector\n    # grad || A(x) ||_2^2 == 2 A.T @ A x\n    x = make_im(Nx, Ny)\n    A = AbelTransform(x.shape)\n    g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2\n    np.testing.assert_allclose(jax.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"Nx, Ny\", (BIG_INPUT, SMALL_INPUT))\ndef test_adjoint_grad(Nx, Ny):\n    x = make_im(Nx, Ny)\n    A = AbelTransform(x.shape)\n    Ax = A @ x\n    f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2\n    np.testing.assert_allclose(jax.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5)\n"
  },
  {
    "path": "scico/test/linop/xray/test_astra.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico\nimport scico.numpy as snp\nfrom scico.linop import DiagonalStack\nfrom scico.test.linop.test_linop import adjoint_test\nfrom scipy.spatial.transform import Rotation\n\ntry:\n    from scico.linop.xray.astra import (\n        XRayTransform2D,\n        XRayTransform3D,\n        _ensure_writeable,\n        angle_to_vector,\n        rotate_vectors,\n    )\nexcept ModuleNotFoundError as e:\n    if e.name == \"astra\":\n        pytest.skip(\"astra not installed\", allow_module_level=True)\n    else:\n        raise e\n\n\nN = 128\nRTOL_CPU = 1e-4\nRTOL_GPU = 1e-1\nRTOL_GPU_RANDOM_INPUT = 2.0\n\n\ndef make_im(Nx, Ny, is_3d=True):\n    x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny), indexing=\"ij\")\n\n    im = snp.where((x - 0.25) ** 2 / 3 + y**2 < 0.1, 1.0, 0.0)\n    if is_3d:\n        im = im[snp.newaxis, :, :]\n    im = im.astype(snp.float32)\n\n    return im\n\n\ndef get_tol():\n    if jax.devices()[0].device_kind == \"cpu\":\n        rtol = RTOL_CPU\n    else:\n        rtol = RTOL_GPU  # astra inaccurate in GPU\n    return rtol\n\n\ndef get_tol_random_input():\n    if jax.devices()[0].device_kind == \"cpu\":\n        rtol = RTOL_CPU\n    else:\n        rtol = RTOL_GPU_RANDOM_INPUT  # astra more inaccurate in GPU for random inputs\n    return rtol\n\n\nclass XRayTransform2DTest:\n    def __init__(self, volume_geometry):\n        N_proj = 180  # number of projection angles\n        N_det = 384\n        det_spacing = 1\n        angles = np.linspace(0, np.pi, N_proj, False)\n\n        np.random.seed(1234)\n        self.x = np.random.randn(N, N).astype(np.float32)\n        self.y = np.random.randn(N_proj, N_det).astype(np.float32)\n        self.A = XRayTransform2D(\n            input_shape=(N, N),\n            det_count=N_det,\n            det_spacing=det_spacing,\n            angles=angles,\n            volume_geometry=volume_geometry,\n        
)\n\n\n@pytest.fixture(params=[None, [-N / 2, N / 2, -N / 2, N / 2]])\ndef testobj(request):\n    yield XRayTransform2DTest(request.param)\n\n\ndef test_init(testobj):\n    with pytest.raises(ValueError):\n        A = XRayTransform2D(\n            input_shape=(16, 16, 16),\n            det_count=16,\n            det_spacing=1.0,\n            angles=np.linspace(0, np.pi, 32, False),\n        )\n    with pytest.raises(ValueError):\n        A = XRayTransform2D(\n            input_shape=(16, 16),\n            det_count=16.3,\n            det_spacing=1.0,\n            angles=np.linspace(0, np.pi, 32, False),\n        )\n    with pytest.raises(ValueError):\n        A = XRayTransform2D(\n            input_shape=(16, 16),\n            det_count=16,\n            det_spacing=1.0,\n            angles=np.linspace(0, np.pi, 32, False),\n            device=\"invalid\",\n        )\n\n\ndef test_ATA_call(testobj):\n    # Test for the call-based interface\n    Ax = testobj.A(testobj.x)\n    ATAx = testobj.A.adj(Ax)\n    np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol())\n\n\ndef test_ATA_matmul(testobj):\n    # Test for the matmul interface\n    Ax = testobj.A @ testobj.x\n    ATAx = testobj.A.T @ Ax\n    np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol())\n\n\ndef test_AAT_call(testobj):\n    # Test for the call-based interface\n    ATy = testobj.A.adj(testobj.y)\n    AATy = testobj.A(ATy)\n    np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol())\n\n\ndef test_AAT_matmul(testobj):\n    # Test for the matmul interface\n    ATy = testobj.A.T @ testobj.y\n    AATy = testobj.A @ ATy\n    np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol())\n\n\ndef test_grad(testobj):\n    # ensure that we can take grad on a function using our projector\n    # grad || A(x) ||_2^2 == 2 A.T @ A x\n    A = testobj.A\n    x = testobj.x\n   
 g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2\n    np.testing.assert_allclose(\n        scico.grad(g)(x), 2 * A.adj(A(x)), atol=get_tol() * x.max(), rtol=get_tol()\n    )\n\n\ndef test_adjoint_grad(testobj):\n    A = testobj.A\n    x = testobj.x\n    Ax = A @ x\n    f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2\n    np.testing.assert_allclose(scico.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=get_tol())\n\n\ndef test_adjoint_random(testobj):\n    A = testobj.A\n    adjoint_test(A, rtol=10 * get_tol_random_input())\n\n\ndef test_adjoint_typical_input(testobj):\n    A = testobj.A\n    x = make_im(A.input_shape[0], A.input_shape[1], is_3d=False)\n\n    adjoint_test(A, x=x, rtol=get_tol())\n\n\ndef test_fbp(testobj):\n    x = testobj.A.fbp(testobj.y)\n    # Test for a bug (related to calling the Astra CPU FBP implementation\n    # when using a FPU device) that resulted in a constant zero output.\n    assert np.sum(np.abs(x)) > 0.0\n\n\ndef test_jit_in_DiagonalStack():\n    \"\"\"See https://github.com/lanl/scico/issues/331\"\"\"\n    N = 10\n    H = DiagonalStack([XRayTransform2D((N, N), N, 1.0, snp.linspace(0, snp.pi, N))])\n    H.T @ snp.zeros(H.output_shape, dtype=snp.float32)\n\n\n@pytest.mark.skipif(jax.devices()[0].platform != \"gpu\", reason=\"checking GPU behavior\")\ndef test_3D_on_GPU():\n    x = snp.zeros((4, 5, 6))\n    A = XRayTransform3D(\n        x.shape, det_count=[6, 6], det_spacing=[1.0, 1.0], angles=snp.linspace(0, snp.pi, 10)\n    )\n\n    assert A.num_dims == 3\n    y = A @ x\n    ATy = A.T @ y\n\n\n@pytest.mark.skipif(jax.devices()[0].platform != \"gpu\", reason=\"GPU required for test\")\ndef test_3D_api_equiv():\n    x = np.random.randn(4, 5, 6).astype(np.float32)\n    det_count = [7, 8]\n    det_spacing = [1.0, 1.5]\n    angles = snp.linspace(0, snp.pi, 10)\n    A = XRayTransform3D(x.shape, det_count=det_count, det_spacing=det_spacing, angles=angles)\n    vectors = angle_to_vector(det_spacing, angles)\n    B = XRayTransform3D(x.shape, 
det_count=det_count, vectors=vectors)\n    ya = A @ x\n    yb = B @ x\n    np.testing.assert_allclose(ya, yb, rtol=get_tol())\n\n\ndef test_angle_to_vector():\n    angles = snp.linspace(0, snp.pi, 5)\n    det_spacing = [0.9, 1.5]\n    vectors = angle_to_vector(det_spacing, angles)\n    assert vectors.shape == (angles.size, 12)\n\n\ndef test_rotate_vectors():\n    v0 = angle_to_vector([1.0, 1.0], np.linspace(0, np.pi / 2, 4, endpoint=False))\n    v1 = angle_to_vector([1.0, 1.0], np.linspace(np.pi / 2, np.pi, 4, endpoint=False))\n    r = Rotation.from_euler(\"z\", np.pi / 2)\n    v0r = rotate_vectors(v0, r)\n    np.testing.assert_allclose(v1, v0r, atol=1e-7)\n\n\n## conversion functions\n@pytest.fixture(scope=\"module\")\ndef test_geometry():\n    \"\"\"\n    In this geometry, if vol[i, j, k]==1, we expect proj[j-2, k-1]==1.\n\n    Because:\n    - We project along z, i.e. `ray=(0,0,1)`, i.e., we remove axis=0.\n    - We set `v=(0, 1, 0)`, so detector rows go with y axis, axis=1.\n    - We set `u=(1, 0, 0)`, so detector columns go with x axis, axis=2.\n    - We shift the detector by (x=1, y=2, z=3) <-> i-3, j-2, k-1\n    \"\"\"\n    in_shape = (30, 31, 32)\n    # in ASTRA terminology:\n    n_rows = in_shape[1]  # y\n    n_cols = in_shape[2]  # x\n    n_slices = in_shape[0]  # z\n    vol_geom = scico.linop.xray.astra.astra.create_vol_geom(n_rows, n_cols, n_slices)\n\n    assert vol_geom[\"option\"][\"WindowMinX\"] == -n_cols / 2\n    assert vol_geom[\"option\"][\"WindowMinY\"] == -n_rows / 2\n    assert vol_geom[\"option\"][\"WindowMinZ\"] == -n_slices / 2\n\n    # project along z, axis=0\n    det_row_count = n_rows\n    det_col_count = n_cols\n    ray = (0, 0, 1)\n    d = (1, 2, 3)  # axis=2 offset by 1, axis=1 offset by 2, axis=0 offset by 3\n    u = (1, 0, 0)  # increments columns, goes with X\n    v = (0, 1, 0)  # increments rows, goes with Y\n    vectors = np.array(ray + d + u + v)[np.newaxis, :]\n    proj_geom = scico.linop.xray.astra.astra.create_proj_geom(\n    
    \"parallel3d_vec\", det_row_count, det_col_count, vectors\n    )\n\n    return vol_geom, proj_geom\n\n\n@pytest.mark.skipif(jax.devices()[0].platform != \"gpu\", reason=\"GPU required for test\")\ndef test_projection_convention(test_geometry):\n    \"\"\"\n    If vol[i, j, k]==1, test that astra puts proj[j-2, k-1]==1.\n\n    See `test_geometry` for the setup.\n    \"\"\"\n    vol_geom, proj_geom = test_geometry\n    in_shape = scico.linop.xray.astra.astra.functions.geom_size(vol_geom)\n    vol = np.zeros(in_shape)\n\n    i, j, k = [np.random.randint(0, s) for s in in_shape]\n    vol[i, j, k] = 1.0\n\n    proj_id, proj = scico.linop.xray.astra.astra.create_sino3d_gpu(vol, proj_geom, vol_geom)\n    scico.linop.xray.astra.astra.data3d.delete(proj_id)\n    proj = proj[:, 0, :]  # get first view\n    assert len(np.unique(proj) == 2)\n\n    idx_proj_i, idx_proj_j = np.nonzero(proj)\n    np.testing.assert_array_equal(idx_proj_i, j - 2)\n    np.testing.assert_array_equal(idx_proj_j, k - 1)\n\n\ndef test_project_coords(test_geometry):\n    \"\"\"\n    If vol[i, j, k]==1, test that we predict proj[j-2, k-1]==1.\n\n    See `test_geometry` for the setup and `test_projection_convention`\n    for proof ASTRA works this way.\n    \"\"\"\n    vol_geom, proj_geom = test_geometry\n    in_shape = scico.linop.xray.astra.astra.functions.geom_size(vol_geom)\n    x_vol = np.array([np.random.randint(0, s) for s in in_shape])\n    x_proj_gt = np.array(\n        [[x_vol[1] - 2, x_vol[2] - 1]]\n    )  # projection along slices removes first index\n    x_proj = scico.linop.xray.astra._project_coords(x_vol, vol_geom, proj_geom)\n    np.testing.assert_array_equal(x_proj_gt, x_proj)\n\n\ndef test_convert_to_scico_geometry(test_geometry):\n    \"\"\"\n    Basic regression test, `test_project_coords` tests the logic.\n    \"\"\"\n    vol_geom, proj_geom = test_geometry\n    matrices_truth = scico.linop.xray.astra._astra_to_scico_geometry(vol_geom, proj_geom)\n    truth = np.array([[[0.0, 1.0, 
0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]])\n    np.testing.assert_allclose(matrices_truth, truth)\n\n\ndef test_convert_from_scico_geometry(test_geometry):\n    \"\"\"\n    Basic regression test, `test_project_coords` tests the logic.\n    \"\"\"\n    in_shape = (30, 31, 32)\n    matrices = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]])\n    det_shape = (31, 32)\n    vectors = scico.linop.xray.astra.convert_from_scico_geometry(in_shape, matrices, det_shape)\n\n    _, proj_geom_truth = test_geometry\n    # skip testing element 5, as it is detector center along the ray and doesn't matter\n    np.testing.assert_allclose(vectors[0, :5], proj_geom_truth[\"Vectors\"][0, :5])\n    np.testing.assert_allclose(vectors[0, 6:], proj_geom_truth[\"Vectors\"][0, 6:])\n\n\ndef test_vol_coord_to_world_coord():\n    vol_geom = scico.linop.xray.astra.astra.create_vol_geom(16, 16)\n    vc = np.array([[0.0, 0.0], [1.0, 1.0]])\n    wc = scico.linop.xray.astra.volume_coords_to_world_coords(vc, vol_geom)\n    assert wc.shape == (2, 2)\n\n\ndef test_ensure_writeable():\n    assert isinstance(_ensure_writeable(np.ones((2, 1))), np.ndarray)\n    assert isinstance(_ensure_writeable(snp.ones((2, 1))), np.ndarray)\n"
  },
  {
    "path": "scico/test/linop/xray/test_svmbir.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico\nimport scico.numpy as snp\nfrom scico.linop import Diagonal\nfrom scico.loss import SquaredL2Loss\nfrom scico.test.functional.prox import prox_test\nfrom scico.test.linop.test_linop import adjoint_test\n\ntry:\n    import svmbir\n\n    from scico.linop.xray.svmbir import (\n        SVMBIRExtendedLoss,\n        SVMBIRSquaredL2Loss,\n        XRayTransform,\n    )\nexcept ImportError as e:\n    pytest.skip(\"svmbir not installed\", allow_module_level=True)\n\n\nBIG_INPUT = (32, 33, 50, 51, 125, 1.2)\nSMALL_INPUT = (4, 5, 7, 8, 16, 1.2)\n\n\ndef pytest_generate_tests(metafunc):\n    param_ranges = {\n        \"is_3d\": (True, False),\n        \"is_masked\": (True, False),\n        \"geometry\": (\"parallel\", \"fan-curved\", \"fan-flat\"),\n        \"center_offset_small\": (0, 0.1),\n        \"center_offset_big\": (0, 3),\n        \"delta_channel\": (None, 0.5),\n        \"delta_pixel\": (None, 0.5),\n        \"positivity\": (True, False),\n        \"weight_type\": (\"transmission\", \"unweighted\"),\n    }\n    level = int(metafunc.config.getoption(\"--level\"))\n    if level < 3:\n        param_ranges.update({\"is_3d\": (False,), \"is_masked\": (False,), \"positivity\": (False,)})\n    if level < 2:\n        param_ranges.update(\n            {\n                \"geometry\": (\"parallel\",),\n                \"center_offset_small\": (0.1,),\n                \"center_offset_big\": (3,),\n                \"delta_channel\": (None,),\n                \"delta_pixel\": (None,),\n                \"weight_type\": (\"transmission\",),\n            }\n        )\n\n    for k, v in param_ranges.items():\n        if k in metafunc.fixturenames:\n            metafunc.parametrize(k, v)\n\n\ndef make_im(Nx, Ny, is_3d=True):\n    x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny), indexing=\"ij\")\n\n    im = snp.where((x - 0.25) ** 2 / 3 + y**2 < 0.1, 1.0, 0.0)\n    if is_3d:\n        
im = im[snp.newaxis, :, :]\n    im = im.astype(snp.float32)\n\n    return im\n\n\ndef make_angles(num_angles):\n    return snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)\n\n\ndef make_A(\n    im,\n    num_angles,\n    num_channels,\n    center_offset,\n    is_masked,\n    geometry=\"parallel\",\n    dist_source_detector=None,\n    magnification=None,\n    delta_channel=None,\n    delta_pixel=None,\n):\n    angles = make_angles(num_angles)\n    A = XRayTransform(\n        im.shape,\n        angles,\n        num_channels,\n        center_offset=center_offset,\n        is_masked=is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n\n    return A\n\n\ndef test_grad(\n    is_3d,\n    center_offset_big,\n    is_masked,\n    geometry,\n):\n    Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = BIG_INPUT\n    im = make_im(Nx, Ny, is_3d)\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        center_offset_big,\n        is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n\n    def f(im):\n        return snp.sum(A._eval(im) ** 2)\n\n    val_1 = jax.grad(f)(im)\n    val_2 = 2 * A.adj(A(im))\n\n    np.testing.assert_allclose(val_1, val_2)\n\n\ndef test_adjoint(\n    is_3d,\n    center_offset_big,\n    is_masked,\n    geometry,\n):\n    Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = BIG_INPUT\n    im = make_im(Nx, Ny, is_3d)\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        center_offset_big,\n        is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n\n    adjoint_test(A)\n\n\n@pytest.mark.slow\ndef test_prox(\n    is_3d,\n    center_offset_small,\n    is_masked,\n    geometry,\n):\n    Nx, Ny, num_angles, 
num_channels, dist_source_detector, magnification = SMALL_INPUT\n    im = make_im(Nx, Ny, is_3d)\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        center_offset_small,\n        is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n\n    sino = A @ im\n    v, _ = scico.random.normal(im.shape, dtype=im.dtype)\n\n    if is_masked:\n        f = SVMBIRExtendedLoss(y=sino, A=A, positivity=False, prox_kwargs={\"maxiter\": 5})\n    else:\n        f = SVMBIRSquaredL2Loss(y=sino, A=A, prox_kwargs={\"maxiter\": 5})\n\n    prox_test(v, f, f.prox, alpha=0.25, rtol=5e-4)\n\n\n@pytest.mark.slow\ndef test_prox_weights(\n    is_3d,\n    center_offset_small,\n    is_masked,\n    geometry,\n):\n    Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT\n    im = make_im(Nx, Ny, is_3d)\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        center_offset_small,\n        is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n\n    sino = A @ im\n    v, _ = scico.random.normal(im.shape, dtype=im.dtype)\n\n    # test with weights\n    weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype)\n    W = scico.linop.Diagonal(weights)\n\n    if is_masked:\n        f = SVMBIRExtendedLoss(y=sino, A=A, W=W, positivity=False, prox_kwargs={\"maxiter\": 5})\n    else:\n        f = SVMBIRSquaredL2Loss(y=sino, A=A, W=W, prox_kwargs={\"maxiter\": 5})\n\n    prox_test(v, f, f.prox, alpha=0.25, rtol=5e-5)\n\n\ndef test_prox_cg(\n    is_3d,\n    weight_type,\n    center_offset_small,\n    is_masked,\n    geometry,\n):\n    Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT\n    im = make_im(Nx, Ny, is_3d=is_3d) / Nx * 10\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        
center_offset_small,\n        is_masked=is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n    )\n    y = A @ im\n    A_colsum = A.H @ snp.ones(\n        y.shape, dtype=snp.float32\n    )  # backproject ones to get sum over cols of A\n    if is_masked:\n        mask = np.asarray(A_colsum) > 0  # cols of A which are not all zeros\n    else:\n        mask = np.ones(im.shape) > 0\n\n    W = svmbir.calc_weights(y, weight_type=weight_type).astype(\"float32\")\n    W = snp.array(W)\n    λ = 0.01\n\n    if is_masked:\n        f_sv = SVMBIRExtendedLoss(\n            y=y, A=A, W=Diagonal(W), positivity=False, prox_kwargs={\"maxiter\": 5}\n        )\n    else:\n        f_sv = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={\"maxiter\": 5})\n\n    f_wg = SquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={\"tol\": 5e-4})\n\n    v, _ = scico.random.normal(im.shape, dtype=im.dtype)\n    v *= im.max() * 0.5\n\n    xprox_sv = f_sv.prox(v, λ)\n    xprox_cg = f_wg.prox(v, λ)  # this uses cg\n\n    assert snp.linalg.norm(xprox_sv[mask] - xprox_cg[mask]) / snp.linalg.norm(xprox_sv[mask]) < 5e-4\n\n\ndef test_approx_prox(\n    is_3d,\n    weight_type,\n    center_offset_big,\n    is_masked,\n    positivity,\n    geometry,\n    delta_channel,\n    delta_pixel,\n):\n    Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT\n    im = make_im(Nx, Ny, is_3d)\n    A = make_A(\n        im,\n        num_angles,\n        num_channels,\n        center_offset_big,\n        is_masked,\n        geometry=geometry,\n        dist_source_detector=dist_source_detector,\n        magnification=magnification,\n        delta_channel=delta_channel,\n        delta_pixel=delta_pixel,\n    )\n\n    y = A @ im\n    W = svmbir.calc_weights(y, weight_type=weight_type).astype(\"float32\")\n    W = snp.array(W)\n    λ = 0.01\n\n    v, _ = scico.random.normal(im.shape, dtype=im.dtype)\n    if 
is_masked or positivity:\n        f = SVMBIRExtendedLoss(\n            y=y, A=A, W=Diagonal(W), positivity=positivity, prox_kwargs={\"maxiter\": 5}\n        )\n    else:\n        f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={\"maxiter\": 5})\n\n    xprox = snp.array(f.prox(v, lam=λ))\n\n    if is_masked or positivity:\n        f_approx = SVMBIRExtendedLoss(\n            y=y, A=A, W=Diagonal(W), prox_kwargs={\"maxiter\": 2}, positivity=positivity\n        )\n    else:\n        f_approx = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={\"maxiter\": 2})\n\n    xprox_approx = snp.array(f_approx.prox(v, lam=λ, v0=xprox))\n\n    assert snp.linalg.norm(xprox - xprox_approx) / snp.linalg.norm(xprox) < 5e-5\n"
  },
  {
    "path": "scico/test/linop/xray/test_symcone.py",
    "content": "import numpy as np\n\nimport pytest\n\nfrom scico import metric\nfrom scico.examples import create_circular_phantom\nfrom scico.linop.xray.symcone import (\n    AxiallySymmetricVolume,\n    SymConeXRayTransform,\n    _volume_by_axial_symmetry,\n)\nfrom scipy.ndimage import gaussian_filter\n\n\nclass TestAxialSymm:\n    def setup_method(self, method):\n        N = 64\n        self.N = N\n        self.x2d = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n        self.x3d = create_circular_phantom((N, N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n        self.x2d = gaussian_filter(self.x2d, 1.0)\n        self.x3d = gaussian_filter(self.x3d, 1.0)\n\n    @pytest.mark.parametrize(\"axis\", [0, 1])\n    def test_vbas(self, axis):\n        v0 = _volume_by_axial_symmetry(self.x2d, axis=axis)\n        assert metric.rel_res(self.x3d, v0) < 5e-2\n\n        offset = -3\n        x2dr = np.roll(self.x2d, offset, axis=1 - axis)\n        Nh = (self.N + 1) / 2 - 1\n        v1 = _volume_by_axial_symmetry(x2dr, axis=axis, center=Nh + offset)\n        assert metric.rel_res(v0, v1) < 1e-5\n\n        zrange = np.arange(-Nh, 0)\n        v2 = _volume_by_axial_symmetry(self.x2d, axis=axis, zrange=zrange)\n        assert metric.rel_res(self.x3d[..., 0 : self.N // 2], v2) < 5e-2\n\n        A = AxiallySymmetricVolume((self.N, self.N), axis=axis)\n        vl = A(self.x2d)\n        assert metric.rel_res(v0, vl) < 1e-7\n\n\nclass TestAbelCone:\n    def setup_method(self, method):\n        N = 64\n        self.N = N\n        self.x2d = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n        self.x3d = create_circular_phantom((N, N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5])\n        self.x2d = gaussian_filter(self.x2d, 1.0)\n        self.x3d = gaussian_filter(self.x3d, 1.0)\n\n    @pytest.mark.parametrize(\"num_slabs\", [1, 2, 3])\n    def test_2d(self, num_slabs):\n        A = SymConeXRayTransform(self.x2d.shape, 1e8, 
1e8 + 1, num_slabs=num_slabs)\n        ya = A(self.x2d)\n        x2ds = _volume_by_axial_symmetry(self.x2d, axis=0)\n        ys = np.sum(x2ds, axis=1)\n        assert metric.rel_res(ys, ya) < 1e-6\n\n    @pytest.mark.parametrize(\"num_slabs\", [1, 2, 3])\n    def test_2d_unequal(self, num_slabs):\n        x2dc = self.x2d[1:-1]\n        A = SymConeXRayTransform(x2dc.shape, 1e8, 1e8 + 1, num_slabs=num_slabs)\n        ya = A(x2dc)\n        x2ds = _volume_by_axial_symmetry(x2dc, axis=0)\n        ys = np.sum(x2ds, axis=1)\n        assert metric.rel_res(ys, ya) < 1e-6\n\n    @pytest.mark.parametrize(\"num_slabs\", [1, 2, 3])\n    def test_3d(self, num_slabs):\n        A = SymConeXRayTransform(self.x3d.shape, 1e8, 1e8 + 1, num_slabs=num_slabs)\n        ya = A(self.x3d)\n        ys = np.sum(self.x3d, axis=1)\n        assert metric.rel_res(ys, ya) < 1e-6\n\n    @pytest.mark.parametrize(\"num_slabs\", [1, 2, 3])\n    def test_3d_unequal(self, num_slabs):\n        x3dc = self.x3d[1:-1, 2:-2]\n        A = SymConeXRayTransform(x3dc.shape, 1e8, 1e8 + 1, num_slabs=num_slabs)\n        ya = A(x3dc)\n        ys = np.sum(x3dc, axis=1)\n        assert metric.rel_res(ys, ya) < 1e-6\n\n    @pytest.mark.parametrize(\"num_slabs\", [1, 2, 3])\n    def test_2d3d_unequal(self, num_slabs):\n        A2d = SymConeXRayTransform(self.x2d.shape, 5e1, 6e1, num_slabs=num_slabs)\n        A3d = SymConeXRayTransform(self.x3d.shape, 5e1, 6e1, num_slabs=num_slabs)\n        y2d = A2d(self.x2d)\n        y3d = A3d(self.x3d)\n        assert metric.rel_res(y3d, y2d) < 2e-2\n\n    @pytest.mark.parametrize(\"axis\", [0, 1])\n    def test_proj_axis(self, axis):\n        N = self.N\n        N2 = N // 2\n        N4 = N // 4\n        x = np.zeros((N, N))\n        if axis == 0:\n            x[N2 - 1 : N2 + 1, N4 - 1 : N4 + 1] = 1\n        else:\n            x[N4 - 1 : N4 + 1, N2 - 1 : N2 + 1] = 1\n        A = SymConeXRayTransform(x.shape, 1e2, 2e2, axis=axis, num_slabs=1)\n        y = A(x)\n        if axis == 0:\n   
         assert np.sum(np.sum(y, axis=1) > 0) <= 4\n            assert np.sum(np.sum(y, axis=0) > 0) >= N2\n        else:\n            assert np.sum(np.sum(y, axis=0) > 0) <= 4\n            assert np.sum(np.sum(y, axis=1) > 0) >= N2\n\n    @pytest.mark.parametrize(\"axis\", [0, 1])\n    def test_fdk(self, axis):\n        A = SymConeXRayTransform(self.x3d.shape, 1e2, 2e2, axis=axis, num_slabs=1)\n        y = A(self.x3d)\n        z = A.fdk(y)\n        assert metric.rel_res(self.x2d, z) < 0.2\n"
  },
  {
    "path": "scico/test/linop/xray/test_xray_2d.py",
    "content": "import numpy as np\n\nimport jax\nimport jax.numpy as jnp\n\nimport pytest\n\nimport scico\nimport scico.linop\nfrom scico.linop.xray import XRayTransform2D\nfrom scico.metric import psnr\n\n\n@pytest.mark.filterwarnings(\"error\")\ndef test_init():\n    input_shape = (3, 3)\n\n    # no warning with default settings, even at 45 degrees\n    H = XRayTransform2D(input_shape, jnp.array([jnp.pi / 4]))\n\n    # no warning if we project orthogonally with oversized pixels\n    H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1, 1]))\n\n    # warning if the projection angle changes\n    with pytest.warns(UserWarning):\n        H = XRayTransform2D(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1]))\n\n    # warning if the pixels get any larger\n    with pytest.warns(UserWarning):\n        H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1]))\n\n\ndef test_apply():\n    im_shape = (12, 13)\n    num_angles = 10\n    x = jnp.ones(im_shape)\n\n    angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)\n\n    # general projection\n    H = XRayTransform2D(x.shape, angles)\n    y = H @ x\n    assert y.shape[0] == (num_angles)\n\n    # fixed det_count\n    det_count = 14\n    H = XRayTransform2D(x.shape, angles, det_count=det_count)\n    y = H @ x\n    assert y.shape[1] == det_count\n\n\ndef test_apply_adjoint():\n    im_shape = (12, 13)\n    num_angles = 10\n    x = jnp.ones(im_shape, dtype=jnp.float32)\n\n    angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)\n\n    # general projection\n    H = XRayTransform2D(x.shape, angles)\n    y = H @ x\n    assert y.shape[0] == (num_angles)\n\n    # adjoint\n    bp = H.T @ y\n    assert scico.linop.valid_adjoint(\n        H, H.T, eps=1e-4\n    )  # associative reductions might cause small errors, hence 1e-5\n\n    # fixed det_length\n    det_count = 14\n    H = XRayTransform2D(x.shape, angles, det_count=det_count)\n    y = H @ x\n    assert y.shape[1] == 
det_count\n\n\ndef test_matched_adjoint():\n    \"\"\"See https://github.com/lanl/scico/issues/560.\"\"\"\n    N = 16\n    det_count = int(N * 1.05 / np.sqrt(2.0))\n    dx = 1.0 / np.sqrt(2)\n    n_projection = 3\n    angles = np.linspace(0, np.pi, n_projection, endpoint=False)\n    A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)\n    assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)\n\n\n@pytest.mark.parametrize(\"dx\", [0.5, 1.0 / np.sqrt(2)])\n@pytest.mark.parametrize(\"det_count_factor\", [1.02 / np.sqrt(2.0), 1.0])\ndef test_fbp(dx, det_count_factor):\n    N = 256\n    x_gt = np.zeros((N, N), dtype=np.float32)\n    N4 = N // 4\n    x_gt[N4:-N4, N4:-N4] = 1.0\n\n    det_count = int(det_count_factor * N)\n    n_proj = 360\n    angles = np.linspace(0, np.pi, n_proj, endpoint=False)\n    A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx)\n    y = A(x_gt)\n    x_fbp = A.fbp(y)\n    assert psnr(x_gt, x_fbp) > 28\n\n\ndef test_fbp_jit():\n    N = 64\n    x_gt = np.ones((N, N), dtype=np.float32)\n\n    det_count = N\n    n_proj = 90\n    angles = np.linspace(0, np.pi, n_proj, endpoint=False)\n    A = XRayTransform2D(x_gt.shape, angles, det_count=det_count)\n    y = A(x_gt)\n    fbp = jax.jit(A.fbp)\n    x_fbp = fbp(y)\n"
  },
  {
    "path": "scico/test/linop/xray/test_xray_3d.py",
    "content": "import numpy as np\n\nimport jax.numpy as jnp\n\nimport scico.linop\nfrom scico.linop.xray import XRayTransform3D\n\n\ndef test_matched_adjoint():\n    \"\"\"See https://github.com/lanl/scico/issues/560.\"\"\"\n    N = 16\n    det_count = int(N * 1.05 / np.sqrt(2.0))\n    n_projection = 3\n\n    input_shape = (N, N, N)\n    det_shape = (det_count, det_count)\n\n    M = XRayTransform3D.matrices_from_euler_angles(\n        input_shape,\n        det_shape,\n        \"X\",\n        np.linspace(0, np.pi, n_projection, endpoint=False)[:, None],  # make (n_projection, 1)\n    )\n    H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)\n\n    assert scico.linop.valid_adjoint(H, H.T, eps=1e-5)\n\n\ndef test_scaling():\n    x = jnp.zeros((4, 4, 1))\n    x = x.at[1:3, 1:3, 0].set(1.0)\n\n    input_shape = x.shape\n    det_shape = x.shape[:2]\n\n    # default spacing\n    M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, \"X\", [[0.0]])\n    H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)\n    # fmt: off\n    truth = jnp.array(\n        [[[0.0, 0.0, 0.0, 0.0],\n          [0.0, 1.0, 1.0, 0.0],\n          [0.0, 1.0, 1.0, 0.0],\n          [0.0, 0.0, 0.0, 0.0]]]\n    )  # fmt: on\n    np.testing.assert_allclose(H @ x, truth)\n\n    # bigger voxels in the x (first index) direction\n    M = XRayTransform3D.matrices_from_euler_angles(\n        input_shape, det_shape, \"X\", [[0.0]], voxel_spacing=[2.0, 1.0, 1.0]\n    )\n    H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)\n    # fmt: off\n    truth = jnp.array(\n        [[[0. , 0.5, 0.5, 0. ],\n          [0. , 0.5, 0.5, 0. ],\n          [0. , 0.5, 0.5, 0. ],\n          [0. , 0.5, 0.5, 0. 
]]]\n    )  # fmt: on\n    np.testing.assert_allclose(H @ x, truth)\n\n    # bigger detector pixels in the x (first index) direction\n    M = XRayTransform3D.matrices_from_euler_angles(\n        input_shape, det_shape, \"X\", [[0.0]], det_spacing=[2.0, 1.0]\n    )\n    H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)\n    # fmt: off\n    truth = None  # fmt: on  # TODO: Check this case more closely.\n    # np.testing.assert_allclose(H @ x, truth)\n"
  },
  {
    "path": "scico/test/linop/xray/test_xray_util.py",
    "content": "import numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.scipy.spatial.transform import Rotation\n\nimport pytest\n\nimport scipy.ndimage\nfrom scico.linop.xray import (\n    center_image,\n    image_alignment_rotation,\n    image_centroid,\n    rotate_volume,\n    volume_alignment_rotation,\n)\n\ntry:\n    import astra  # noqa\n\n    have_astra = True\nexcept ModuleNotFoundError as e:\n    if e.name == \"astra\":\n        have_astra = False\n    else:\n        raise e\n\n\ndef test_image_centroid():\n    v = np.zeros((4, 5))\n    v[1:-1, 1:-1] = 1\n    assert image_centroid(v) == (1.5, 2.0)\n    image_centroid(v, center_offset=True) == (0.0, 0.0)\n\n\ndef test_center_image():\n    u = np.zeros((4, 5))\n    u[0:-2, 0:-2] = 1\n    v = center_image(u)\n    np.testing.assert_allclose(image_centroid(v, center_offset=True), (0.0, 0.0), atol=1e-7)\n    v = center_image(u, axes=(0,))\n    np.testing.assert_allclose(image_centroid(v, center_offset=True), (0.0, -1.0), atol=1e-7)\n\n\ndef test_rotate_volume():\n    vol = np.arange(27).reshape((3, 3, 3))\n    rot = Rotation.from_euler(\"XY\", jnp.array([90.0, 90.0]), degrees=True)\n    vol_rot = rotate_volume(vol, rot)\n    np.testing.assert_allclose(vol.transpose((1, 2, 0)), vol_rot, rtol=1e-7)\n\n\ndef align_test_tol():\n    if jax.devices()[0].device_kind == \"cpu\":\n        tol = 1e-3\n    else:\n        tol = 5e-2  # less accurate on gpu\n    return tol\n\n\n@pytest.mark.skipif(not have_astra, reason=\"astra not installed\")\ndef test_image_alignment():\n    u = np.zeros((256, 256), dtype=np.float32)\n    u[:, 8::16] = 1\n    u[:, 9::16] = 1\n    angle = image_alignment_rotation(u)\n    assert np.abs(angle) < 1e-3\n    ur = scipy.ndimage.rotate(u, 0.75)\n    angle = image_alignment_rotation(ur)\n    assert np.abs(angle - 0.75) < align_test_tol()\n\n\n@pytest.mark.skipif(not have_astra, reason=\"astra not installed\")\ndef test_volume_alignment():\n    u = np.zeros((256, 256, 32), 
dtype=np.float32)\n    u[8::16, :, 2::6] = 1\n    u[9::16, :, 2::6] = 1\n    u[:, 8::16, 2::6] = 1\n    u[:, 9::16, 2::6] = 1\n    u[8::16, :, 3::6] = 1\n    u[9::16, :, 3::6] = 1\n    u[:, 8::16, 3::6] = 1\n    u[:, 9::16, 3::6] = 1\n    rot = volume_alignment_rotation(u)\n    assert rot.magnitude() < 1e-5\n    ref_rot = Rotation.from_euler(\"XY\", jnp.array([1.6, -0.9]), degrees=True)\n    ur = rotate_volume(u, ref_rot)\n    rot = volume_alignment_rotation(ur)\n    assert (\n        np.abs(ref_rot.as_euler(\"XYZ\", degrees=True) - rot.as_euler(\"XYZ\", degrees=True)).max()\n        < 1e-1\n    )\n"
  },
  {
    "path": "scico/test/numpy/test_blockarray.py",
    "content": "import itertools\nimport operator as op\n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.numpy import BlockArray\nfrom scico.numpy._wrapped_function_lists import testing_functions\nfrom scico.numpy.testing import assert_array_equal\nfrom scico.numpy.util import shape_dtype_rep\nfrom scico.util import rgetattr\n\nmath_ops = [op.add, op.sub, op.mul, op.truediv, op.pow]  # op.floordiv doesn't work on complex\ncomp_ops = [op.le, op.lt, op.ge, op.gt, op.eq]\n\n\ndef make_arbitrary_jax_array(shape, dtype):\n    \"\"\"\n    Make an arbitrary jax array of the given shape and dtype.\n    \"\"\"\n    return jnp.array(np.random.randn(*shape)).astype(dtype)\n\n\ndef sequence_assert_allclose(x, y, *args, **kwargs):\n    \"\"\"Assert sequences x and y have the same length and corresponding\n    elements are allclose.\"\"\"\n    assert len(x) == len(y)\n    for x_i, y_i in zip(x, y):\n        np.testing.assert_allclose(x_i, y_i, *args, **kwargs)\n\n\nclass OperatorsTestObj:\n    operators = math_ops + comp_ops\n\n    def __init__(self, dtype):\n        self.scalar = 1.0\n\n        self.a0 = make_arbitrary_jax_array((2, 3), dtype)\n        self.a1 = make_arbitrary_jax_array((2, 3, 4), dtype)\n        self.a = BlockArray((self.a0, self.a1))\n\n        self.b0 = make_arbitrary_jax_array((2, 3), dtype)\n        self.b1 = make_arbitrary_jax_array((2, 3, 4), dtype)\n        self.b = BlockArray((self.b0, self.b1))\n\n        self.d0 = make_arbitrary_jax_array((3, 2), dtype)\n        self.d1 = make_arbitrary_jax_array((2, 4, 3), dtype)\n        self.d = BlockArray((self.d0, self.d1))\n\n        c0 = make_arbitrary_jax_array((2, 3), dtype)\n        self.c = BlockArray((c0,))\n\n        # A flat device array with same size as self.a & self.b\n        self.flat_da = make_arbitrary_jax_array(self.a.size, dtype)\n        self.flat_nd = np.array(self.flat_da)\n\n        # A device array with length == 
self.a.num_blocks\n        self.block_da, key = make_arbitrary_jax_array((len(self.a),), dtype)\n\n        # block_da but as a numpy array\n        self.block_nd = np.array(self.block_da)\n\n        self.key = key\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.complex64])\ndef test_operator_obj(request):\n    yield OperatorsTestObj(request.param)\n\n\n# Operations between a blockarray and scalar\n@pytest.mark.parametrize(\"operator\", math_ops + comp_ops)\ndef test_operator_left(test_operator_obj, operator):\n    scalar = test_operator_obj.scalar\n    a = test_operator_obj.a\n    x = operator(scalar, a)\n    y = BlockArray(operator(scalar, a_i) for a_i in a)\n    sequence_assert_allclose(x, y)\n\n\n@pytest.mark.parametrize(\"operator\", math_ops + comp_ops)\ndef test_operator_right(test_operator_obj, operator):\n    scalar = test_operator_obj.scalar\n    a = test_operator_obj.a\n    x = operator(a, scalar)\n    y = BlockArray(operator(a_i, scalar) for a_i in a)\n    sequence_assert_allclose(x, y)\n\n\n# Operations between two blockarrays of same size\n@pytest.mark.parametrize(\"operator\", math_ops + comp_ops)\ndef test_ba_ba_operator(test_operator_obj, operator):\n    a = test_operator_obj.a\n    b = test_operator_obj.b\n    x = operator(a, b)\n    y = BlockArray(operator(a_i, b_i) for a_i, b_i in zip(a, b))\n    sequence_assert_allclose(x, y)\n\n\n# Testing the @ interface for blockarrays of same size, and a blockarray and flattened\n# ndarray/devicearray\ndef test_ba_ba_matmul(test_operator_obj):\n    a = test_operator_obj.a\n    b = test_operator_obj.d\n    c = test_operator_obj.c\n\n    a0 = test_operator_obj.a0\n    a1 = test_operator_obj.a1\n    d0 = test_operator_obj.d0\n    d1 = test_operator_obj.d1\n\n    x = a @ b\n\n    y = BlockArray([a0 @ d0, a1 @ d1])\n    assert x.shape == y.shape\n    sequence_assert_allclose(x, y)\n\n    with pytest.raises(TypeError):\n        z = a @ c\n\n\ndef test_conj(test_operator_obj):\n    a = 
test_operator_obj.a\n    ac = a.conj()\n\n    assert a.shape == ac.shape\n    sequence_assert_allclose(BlockArray(a_i.conj() for a_i in a), ac)\n\n\ndef test_real(test_operator_obj):\n    a = test_operator_obj.a\n    ac = a.real\n\n    sequence_assert_allclose(BlockArray(a_i.real for a_i in a), ac)\n\n\ndef test_imag(test_operator_obj):\n    a = test_operator_obj.a\n    ac = a.imag\n\n    sequence_assert_allclose(BlockArray(a_i.imag for a_i in a), ac)\n\n\ndef test_ndim(test_operator_obj):\n    assert test_operator_obj.a.ndim == (2, 3)\n    assert test_operator_obj.c.ndim == (2,)\n\n\ndef test_getitem(test_operator_obj):\n    # make a length-4 blockarray\n    a0 = test_operator_obj.a0\n    a1 = test_operator_obj.a1\n    b0 = test_operator_obj.b0\n    b1 = test_operator_obj.b1\n    x = BlockArray([a0, a1, b0, b1])\n\n    # positive indexing\n    np.testing.assert_allclose(x[0], a0)\n    np.testing.assert_allclose(x[1], a1)\n    np.testing.assert_allclose(x[2], b0)\n    np.testing.assert_allclose(x[3], b1)\n\n    # negative indexing\n    np.testing.assert_allclose(x[-4], a0)\n    np.testing.assert_allclose(x[-3], a1)\n    np.testing.assert_allclose(x[-2], b0)\n    np.testing.assert_allclose(x[-1], b1)\n\n\ndef test_split(test_operator_obj):\n    a = test_operator_obj.a\n    np.testing.assert_allclose(a[0], test_operator_obj.a0)\n    np.testing.assert_allclose(a[1], test_operator_obj.a1)\n\n\ndef test_blockarray_from_one_array():\n    # BlockArray(np.jnp.zeros((3,6))) makes a block array\n    # with 3 length-6 blocks\n    x = BlockArray(np.random.randn(3, 6))\n    assert len(x) == 3\n\n\n@pytest.mark.parametrize(\"axis\", [None, 1])\n@pytest.mark.parametrize(\"keepdims\", [True, False])\ndef test_sum_method(test_operator_obj, axis, keepdims):\n    a = test_operator_obj.a\n\n    method_result = a.sum(axis=axis, keepdims=keepdims)\n    snp_result = snp.sum(a, axis=axis, keepdims=keepdims)\n\n    sequence_assert_allclose(method_result, snp_result)\n\n\ndef 
test_eval_shape_1arg(test_operator_obj):\n    def foo(x):\n        return snp.atleast_3d(x)\n\n    x = test_operator_obj.a\n    es = jax.eval_shape(foo, shape_dtype_rep(x.shape, x.dtype))\n    ba = foo(x)\n    assert es.shape == ba.shape\n    assert es.dtype == ba.dtype\n\n\ndef test_eval_shape_2arg(test_operator_obj):\n    def foo(x, y):\n        return x * y\n\n    x = test_operator_obj.a\n    y = test_operator_obj.b\n\n    args = [\n        BlockArray([jax.ShapeDtypeStruct(b_i.shape, b_i.dtype) for b_i in x]),\n        BlockArray([jax.ShapeDtypeStruct(b_i.shape, b_i.dtype) for b_i in y]),\n    ]\n\n    es = jax.eval_shape(foo, *args)\n    assert es.shape == x.shape\n    assert es.dtype == x.dtype\n\n\ndef test_linear_transpose(test_operator_obj):\n    fun = lambda x: 2 * x\n    x = test_operator_obj.a\n    tfun_ba = jax.linear_transpose(fun, x)\n    tfun_dts = jax.linear_transpose(fun, shape_dtype_rep(x.shape, x.dtype))\n    assert tfun_ba.args == tfun_dts.args\n\n\n@pytest.mark.parametrize(\"operator\", [snp.dot, snp.matmul])\ndef test_ba_ba_dot(test_operator_obj, operator):\n    a = test_operator_obj.a\n    d = test_operator_obj.d\n    a0 = test_operator_obj.a0\n    a1 = test_operator_obj.a1\n    d0 = test_operator_obj.d0\n    d1 = test_operator_obj.d1\n\n    x = operator(a, d)\n    y = BlockArray([operator(a0, d0), operator(a1, d1)])\n    sequence_assert_allclose(x, y)\n\n\n# reduction tests\nreduction_funcs = [\n    snp.sum,\n    snp.linalg.norm,\n]\n\nreal_reduction_funcs = []\n\n\nclass BlockArrayReductionObj:\n    def __init__(self, dtype):\n        key = None\n\n        a0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)\n        a1 = make_arbitrary_jax_array(shape=(2, 3, 4), dtype=dtype)\n        b0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)\n        b1 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)\n        c0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)\n        c1 = make_arbitrary_jax_array(shape=(3,), dtype=dtype)\n\n 
       self.a = BlockArray((a0, a1))\n        self.b = BlockArray((b0, b1))\n        self.c = BlockArray((c0, c1))\n\n\n@pytest.fixture(scope=\"module\")  # so that random objects are cached\ndef reduction_obj(request):\n    yield BlockArrayReductionObj(request.param)\n\n\nREDUCTION_PARAMS = dict(\n    argnames=\"reduction_obj, func\",\n    argvalues=(\n        list(zip(itertools.repeat(np.float32), reduction_funcs))\n        + list(zip(itertools.repeat(np.complex64), reduction_funcs))\n        + list(zip(itertools.repeat(np.float32), real_reduction_funcs))\n    ),\n    indirect=[\"reduction_obj\"],\n)\n\n\n@pytest.mark.parametrize(**REDUCTION_PARAMS)\ndef test_reduce(reduction_obj, func):\n    x = func(reduction_obj.a)\n    x_jit = jax.jit(func)(reduction_obj.a)\n    y = func(snp.ravel(reduction_obj.a))\n    np.testing.assert_allclose(x, x_jit, atol=1e-6)  # test jitted function\n    np.testing.assert_allclose(x, y, atol=1e-6)  # test for correctness\n\n\n@pytest.mark.parametrize(**REDUCTION_PARAMS)\n@pytest.mark.parametrize(\"axis\", (0, 1))\ndef test_reduce_axis(reduction_obj, func, axis):\n    f = lambda x: func(x, axis=axis)\n    x = f(reduction_obj.a)\n    x_jit = jax.jit(f)(reduction_obj.a)\n\n    sequence_assert_allclose(x, x_jit, rtol=1e-4)  # test jitted function\n\n    # test for correctness\n    y0 = func(reduction_obj.a[0], axis=axis)\n    y1 = func(reduction_obj.a[1], axis=axis)\n    y = BlockArray((y0, y1))\n    sequence_assert_allclose(x, y)\n\n\n@pytest.mark.parametrize(**REDUCTION_PARAMS)\ndef test_reduce_singleton(reduction_obj, func):\n    # Case where one block is reduced to a singleton\n    f = lambda x: func(x, axis=0)\n    x = f(reduction_obj.c)\n    x_jit = jax.jit(f)(reduction_obj.c)\n\n    sequence_assert_allclose(x, x_jit, rtol=1e-4)  # test jitted function\n\n    y0 = func(reduction_obj.c[0], axis=0)\n    y1 = func(reduction_obj.c[1], axis=0)[None]  # Ensure size (1,)\n    y = BlockArray((y0, y1))\n    sequence_assert_allclose(x, 
y)\n\n\nclass TestCreators:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.a_shape = (2, 3)\n        self.b_shape = (2, 4, 3)\n        self.c_shape = (1,)\n        self.shape = (self.a_shape, self.b_shape, self.c_shape)\n        self.size = np.prod(self.a_shape) + np.prod(self.b_shape) + np.prod(self.c_shape)\n\n    def test_zeros(self):\n        x = snp.zeros(self.shape, dtype=np.float32)\n        assert x.shape == self.shape\n        assert snp.all(x == 0)\n\n    def test_empty(self):\n        x = snp.empty(self.shape, dtype=np.float32)\n        assert x.shape == self.shape\n        assert snp.all(x == 0)\n\n    def test_ones(self):\n        x = snp.ones(self.shape, dtype=np.float32)\n        assert x.shape == self.shape\n        assert snp.all(x == 1)\n\n    def test_full(self):\n        fill_value = np.float32(np.random.randn())\n        x = snp.full(self.shape, fill_value=fill_value, dtype=np.float32)\n        assert x.shape == self.shape\n        assert x.dtype == np.float32\n        assert snp.all(x == fill_value)\n\n    def test_full_nodtype(self):\n        fill_value = np.float32(np.random.randn())\n        x = snp.full(self.shape, fill_value=fill_value, dtype=None)\n        assert x.shape == self.shape\n        assert x.dtype == fill_value.dtype\n        assert snp.all(x == fill_value)\n\n\ndef test_list_triggering():\n    device_list = 4 * [jax.devices()[0]]\n    ba = snp.ones((3, 3), device=device_list)\n    assert isinstance(ba, BlockArray)\n    assert ba.shape == 4 * ((3, 3),)\n\n\n# testing function tests\n@pytest.mark.parametrize(\"func\", testing_functions)\ndef test_test_func(func):\n    a = snp.array([1.0, 2.0])\n    b = snp.blockarray((a, a))\n    f = rgetattr(snp, func)\n    retval = f(b, b)\n    assert retval is None\n\n\n# tests added for the BlockArray refactor\n@pytest.fixture\ndef x():\n    # any BlockArray, arbitrary shape, content, type\n    return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], 
[42.0]])\n\n\n@pytest.fixture\ndef y():\n    # another BlockArray, content, type, shape matches x\n    return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]])\n\n\n@pytest.mark.parametrize(\"op\", [op.neg, op.pos, op.abs])\ndef test_unary(op, x):\n    actual = op(x)\n    expected = BlockArray(op(x_i) for x_i in x)\n    assert_array_equal(actual, expected)\n    assert actual.dtype == expected.dtype\n\n\n@pytest.mark.parametrize(\n    \"op\",\n    [\n        op.mul,\n        op.mod,\n        op.lt,\n        op.le,\n        op.gt,\n        op.ge,\n        op.floordiv,\n        op.eq,\n        op.add,\n        op.truediv,\n        op.sub,\n        op.ne,\n    ],\n)\ndef test_elementwise_binary(op, x, y):\n    actual = op(x, y)\n    expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y))\n    assert_array_equal(actual, expected)\n    assert actual.dtype == expected.dtype\n\n\ndef test_not_implemented_binary(x):\n    with pytest.raises(TypeError, match=r\"unsupported operand type\\(s\\)\"):\n        y = x + \"a string\"\n\n\ndef test_matmul(x):\n    # x is ((2, 3), (1,))\n    # y is ((3, 1), (1, 2))\n    y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]])\n    actual = x @ y\n    expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]])\n    assert_array_equal(actual, expected)\n    assert actual.dtype == expected.dtype\n\n\ndef test_property():\n    x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0]))\n    actual = x.shape\n    expected = ((2, 3), (1,))\n    assert actual == expected\n\n\ndef test_method():\n    x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]))\n    actual = x.max()\n    expected = BlockArray([[3.0], [42.0]])\n    assert_array_equal(actual, expected)\n    assert actual.dtype == expected.dtype\n\n\ndef test_stack():\n    x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))\n    assert x.stack().shape == (2, 3)\n    assert x.stack(axis=1).shape == (3, 2)\n    y = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0]]))\n    with 
pytest.raises(ValueError):\n        z = y.stack()\n\n\ndef test_ravel():\n    # snp.ravel completely flattens a BlockArray\n    ba = snp.ones([[2, 3], [3, 4]])\n    assert snp.ravel(ba).shape == (2 * 3 + 3 * 4,)\n\n    # snp.ravel also flattens an Array\n    arr = snp.ones((2, 3))\n    assert snp.ravel(arr).shape == (2 * 3,)\n\n    # ba.flatten maps over BlockArray blocks\n    assert ba.flatten().shape == ((2 * 3,), (3 * 4,))\n\n    # ba.ravel also maps over BlockArray blocks\n    assert ba.ravel().shape == ((2 * 3,), (3 * 4,))\n\n    # snp.ravel works with scalar blocks\n    # fmt: off\n    scalar_ba = snp.ones(\n        [\n            [],\n            [1,],\n            [1, 1],\n        ]\n    )  # fmt: on\n    assert_array_equal(snp.ravel(scalar_ba), [1, 1, 1])\n"
  },
  {
    "path": "scico/test/numpy/test_numpy.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.numpy import _wrappers\n\n\ndef on_cpu():\n    return jax.devices()[0].device_kind == \"cpu\"\n\n\ndef check_results(jout, sout):\n    if isinstance(jout, (tuple, list)) and isinstance(sout, (tuple, list)):\n        # multiple outputs from the function\n        for x, y in zip(jout, sout):\n            np.testing.assert_allclose(x, y, rtol=1e-4)\n    elif isinstance(jout, jax.Array) and isinstance(sout, jax.Array):\n        # single array output from the function\n        np.testing.assert_allclose(sout, jout, rtol=1e-4)\n    elif jout.shape == () and sout.shape == ():\n        # single scalar output from the function\n        np.testing.assert_allclose(jout, sout, rtol=1e-4)\n    else:\n        # some type of output that isn't being captured?\n        raise TypeError(f\"Unexpected input type {type(jout)} or {type(sout)}.\")\n\n\ndef test_reshape_array():\n    a = np.random.randn(4, 4)\n    np.testing.assert_allclose(snp.reshape(a.ravel(), (4, 4)), a)\n\n\ndef test_ufunc_abs():\n    A = snp.array([-1, 2, 5])\n    res = snp.array([1, 2, 5])\n    np.testing.assert_allclose(snp.abs(A), res)\n\n    A = snp.array([-1, -1, -1])\n    res = snp.array([1, 1, 1])\n    np.testing.assert_allclose(snp.abs(A), res)\n\n    Ba = snp.blockarray((snp.array([-1, 2, 5]),))\n    res = snp.blockarray((snp.array([1, 2, 5]),))\n    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())\n\n    Ba = snp.blockarray((snp.array([-1, -1, -1]),))\n    res = snp.blockarray((snp.array([1, 1, 1]),))\n    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())\n\n    Ba = snp.blockarray((snp.array([-1, 2, -3]), snp.array([1, -2, 3])))\n    res = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2, 3])))\n    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())\n\n\ndef test_ufunc_maximum():\n    A = snp.array([1, 2, 5])\n    B = snp.array([2, 3, 4])\n    res = snp.array([2, 
3, 5])\n    np.testing.assert_allclose(snp.maximum(A, B), res)\n    np.testing.assert_allclose(snp.maximum(B, A), res)\n\n    A = snp.array([1, 1, 1])\n    B = snp.array([2, 2, 2])\n    res = snp.array([2, 2, 2])\n    np.testing.assert_allclose(snp.maximum(A, B), res)\n    np.testing.assert_allclose(snp.maximum(B, A), res)\n\n    A = 4\n    B = snp.array([3, 5, 2])\n    res = snp.array([4, 5, 4])\n    np.testing.assert_allclose(snp.maximum(A, B), res)\n    np.testing.assert_allclose(snp.maximum(B, A), res)\n\n    A = 5\n    B = 6\n    res = 6\n    np.testing.assert_allclose(snp.maximum(A, B), res)\n    np.testing.assert_allclose(snp.maximum(B, A), res)\n\n    A = snp.array([1, 2, 3])\n    B = snp.array([2, 3, 4])\n    C = snp.array([5, 6])\n    D = snp.array([2, 7])\n    Ba = snp.blockarray((A, C))\n    Bb = snp.blockarray((B, D))\n    res = snp.blockarray((snp.array([2, 3, 4]), snp.array([5, 7])))\n    Bmax = snp.maximum(Ba, Bb)\n    snp.testing.assert_allclose(Bmax, res)\n\n    A = snp.array([1, 6, 3])\n    B = snp.array([6, 3, 8])\n    C = 5\n    Ba = snp.blockarray((A, B))\n    res = snp.blockarray((snp.array([5, 6, 5]), snp.array([6, 5, 8])))\n    Bmax = snp.maximum(Ba, C)\n    snp.testing.assert_allclose(Bmax, res)\n\n\ndef test_ufunc_sign():\n    A = snp.array([10, -5, 0])\n    res = snp.array([1, -1, 0])\n    np.testing.assert_allclose(snp.sign(A), res)\n\n    Ba = snp.blockarray((snp.array([10, -5, 0]),))\n    res = snp.blockarray((snp.array([1, -1, 0]),))\n    snp.testing.assert_allclose(snp.sign(Ba), res)\n\n    Ba = snp.blockarray((snp.array([10, -5, 0]), snp.array([0, 5, -6])))\n    res = snp.blockarray((snp.array([1, -1, 0]), snp.array([0, 1, -1])))\n    snp.testing.assert_allclose(snp.sign(Ba), res)\n\n\ndef test_ufunc_where():\n    A = snp.array([1, 2, 4, 5])\n    B = snp.array([-1, -1, -1, -1])\n    cond = snp.array([False, False, True, True])\n    res = snp.array([-1, -1, 4, 5])\n    np.testing.assert_allclose(snp.where(cond, A, B), res)\n\n    Ba 
= snp.blockarray((snp.array([1, 2, 4, 5]),))\n    Bb = snp.blockarray((snp.array([-1, -1, -1, -1]),))\n    Bcond = snp.blockarray((snp.array([False, False, True, True]),))\n    Bres = snp.blockarray((snp.array([-1, -1, 4, 5]),))\n    assert snp.where(Bcond, Ba, Bb).shape == Bres.shape\n    np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel())\n\n    Ba = snp.blockarray((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5])))\n    Bb = snp.blockarray((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1])))\n    Bcond = snp.blockarray(\n        (snp.array([False, False, True, True]), snp.array([True, True, False, False]))\n    )\n    Bres = snp.blockarray((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1])))\n    assert snp.where(Bcond, Ba, Bb).shape == Bres.shape\n    np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel())\n\n\ndef test_ufunc_true_divide():\n    A = snp.array([1, 2, 3])\n    B = snp.array([3, 3, 3])\n    res = snp.array([0.33333333, 0.66666667, 1.0])\n    np.testing.assert_allclose(snp.true_divide(A, B), res)\n\n    A = snp.array([1, 2, 3])\n    B = 3\n    res = snp.array([0.33333333, 0.66666667, 1.0])\n    np.testing.assert_allclose(snp.true_divide(A, B), res)\n\n    Ba = snp.blockarray((snp.array([1, 2, 3]),))\n    Bb = snp.blockarray((snp.array([3, 3, 3]),))\n    res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]),))\n    snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res)\n\n    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))\n    Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2])))\n    res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0])))\n    snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res)\n\n    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))\n    A = 2\n    res = snp.blockarray((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0])))\n    snp.testing.assert_allclose(snp.true_divide(Ba, A), res)\n\n\ndef 
test_ufunc_floor_divide():\n    A = snp.array([1, 2, 3])\n    B = snp.array([3, 3, 3])\n    res = snp.array([0, 0, 1.0])\n    np.testing.assert_allclose(snp.floor_divide(A, B), res)\n\n    A = snp.array([4, 2, 3])\n    B = 3\n    res = snp.array([1.0, 0, 1.0])\n    np.testing.assert_allclose(snp.floor_divide(A, B), res)\n\n    Ba = snp.blockarray((snp.array([1, 2, 3]),))\n    Bb = snp.blockarray((snp.array([3, 3, 3]),))\n    res = snp.blockarray((snp.array([0, 0, 1.0]),))\n    snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res)\n\n    Ba = snp.blockarray((snp.array([1, 7, 3]), snp.array([1, 2])))\n    Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2])))\n    res = snp.blockarray((snp.array([0, 2, 1.0]), snp.array([0, 1.0])))\n    snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res)\n\n    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))\n    A = 2\n    res = snp.blockarray((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0])))\n    snp.testing.assert_allclose(snp.floor_divide(Ba, A), res)\n\n\ndef test_ufunc_real():\n    A = snp.array([1 + 3j])\n    res = snp.array([1])\n    np.testing.assert_allclose(snp.real(A), res)\n\n    A = snp.array([1 + 3j, 4.0 + 2j])\n    res = snp.array([1, 4.0])\n    np.testing.assert_allclose(snp.real(A), res)\n\n    Ba = snp.blockarray((snp.array([1 + 3j]),))\n    res = snp.blockarray((snp.array([1]),))\n    snp.testing.assert_allclose(snp.real(Ba), res)\n\n    Ba = snp.blockarray((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0])))\n    res = snp.blockarray((snp.array([1.0]), snp.array([1, 4.0])))\n    snp.testing.assert_allclose(snp.real(Ba), res)\n\n\ndef test_ufunc_imag():\n    A = snp.array([1 + 3j])\n    res = snp.array([3])\n    np.testing.assert_allclose(snp.imag(A), res)\n\n    A = snp.array([1 + 3j, 4.0 + 2j])\n    res = snp.array([3, 2])\n    np.testing.assert_allclose(snp.imag(A), res)\n\n    Ba = snp.blockarray((snp.array([1 + 3j]),))\n    res = snp.blockarray((snp.array([3]),))\n    
snp.testing.assert_allclose(snp.imag(Ba), res)\n\n    Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0])))\n    res = snp.blockarray((snp.array([3]), snp.array([3, 0])))\n    snp.testing.assert_allclose(snp.imag(Ba), res)\n\n\ndef test_ufunc_conj():\n    A = snp.array([1 + 3j])\n    res = snp.array([1 - 3j])\n    np.testing.assert_allclose(snp.conj(A), res)\n\n    A = snp.array([1 + 3j, 4.0 + 2j])\n    res = snp.array([1 - 3j, 4.0 - 2j])\n    np.testing.assert_allclose(snp.conj(A), res)\n\n    Ba = snp.blockarray((snp.array([1 + 3j]),))\n    res = snp.blockarray((snp.array([1 - 3j]),))\n    snp.testing.assert_allclose(snp.conj(Ba), res)\n\n    Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0])))\n    res = snp.blockarray((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j])))\n    snp.testing.assert_allclose(snp.conj(Ba), res)\n\n\ndef test_create_zeros():\n    A = snp.zeros(2)\n    assert np.all(A == 0)\n    assert isinstance(A, jax.Array)\n\n    A = snp.zeros((2,))\n    assert isinstance(A, jax.Array)\n\n    A = snp.zeros(((2,), (2,)))\n    assert snp.all(A == 0)\n    assert isinstance(A, snp.BlockArray)\n\n    A = snp.zeros(())\n    assert isinstance(A, jax.Array)  # from issue 499\n\n\ndef test_create_ones():\n    A = snp.ones(2, dtype=np.float32)\n    assert np.all(A == 1)\n\n    A = snp.ones(((2,), (2,)))\n    assert snp.all(A == 1)\n\n\ndef test_create_empty():\n    A = snp.empty(2)\n    assert np.all(A == 0)\n\n    A = snp.empty(((2,), (2,)))\n    assert snp.all(A == 0)\n\n\ndef test_create_full():\n    A = snp.full((2,), 1)\n    assert np.all(A == 1)\n\n    A = snp.full((2,), 1, dtype=np.float32)\n    assert np.all(A == 1)\n\n    A = snp.full(((2,), (2,)), 1)\n    assert snp.all(A == 1)\n\n\ndef test_create_zeros_like():\n    A = snp.ones(2, dtype=np.float32)\n    B = snp.zeros_like(A)\n    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.ones(2, dtype=np.float32)\n    B = snp.zeros_like(A)\n  
  assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.ones(((2,), (2,)), dtype=np.float32)\n    B = snp.zeros_like(A)\n    assert snp.all(B == 0)\n    assert A.shape == B.shape\n    assert A.dtype == B.dtype\n\n\ndef test_create_empty_like():\n    A = snp.ones(2, dtype=np.float32)\n    B = snp.empty_like(A)\n    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.ones(2, dtype=np.float32)\n    B = snp.empty_like(A)\n    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.ones(((2,), (2,)), dtype=np.float32)\n    B = snp.empty_like(A)\n    assert snp.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype\n\n\ndef test_create_ones_like():\n    A = snp.zeros(2, dtype=np.float32)\n    B = snp.ones_like(A)\n    assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.zeros(2, dtype=np.float32)\n    B = snp.ones_like(A)\n    assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype\n\n    A = snp.zeros(((2,), (2,)), dtype=np.float32)\n    B = snp.ones_like(A)\n    assert snp.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype\n\n\ndef test_create_full_like():\n    A = snp.zeros(2, dtype=np.float32)\n    B = snp.full_like(A, 1.0)\n    assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)\n\n    A = snp.zeros(2, dtype=np.float32)\n    B = snp.full_like(A, 1)\n    assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)\n\n    A = snp.zeros(((2,), (2,)), dtype=np.float32)\n    B = snp.full_like(A, 1)\n    assert snp.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)\n\n\ndef test_wrap_recursively():\n    target_dict = {\"a\": 1, \"b\": 2}\n    names = [\"a\", \"c\"]\n    wrap = lambda x: x + 1\n    with pytest.warns(Warning):\n        _wrappers.wrap_recursively(target_dict, names, wrap)\n\n\ndef test_add_full_reduction():\n    with pytest.raises(ValueError):\n        
_wrappers.add_full_reduction(np.sum, axis_arg_name=\"not_axis\")\n"
  },
  {
    "path": "scico/test/numpy/test_numpy_util.py",
    "content": "import collections\n\nimport numpy as np\n\nimport jax.numpy as jnp\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.numpy.util import (\n    array_info,\n    array_to_namedtuple,\n    complex_dtype,\n    dtype_name,\n    indexed_shape,\n    is_blockable,\n    is_collapsible,\n    is_complex_dtype,\n    is_nested,\n    is_real_dtype,\n    is_scalar_equiv,\n    jax_indexed_shape,\n    namedtuple_to_array,\n    no_nan_divide,\n    normalize_axes,\n    real_dtype,\n    shape_dtype_rep,\n    slice_length,\n    transpose_list_of_ntpl,\n    transpose_ntpl_of_list,\n)\nfrom scico.random import randn\n\n\ndef test_ntpl_list_transpose():\n    nt = collections.namedtuple(\"NT\", (\"a\", \"b\", \"c\"))\n    ntlist0 = [nt(0, 1, 2), nt(3, 4, 5)]\n    listnt = transpose_list_of_ntpl(ntlist0)\n    ntlist1 = transpose_ntpl_of_list(listnt)\n    assert ntlist0[0] == ntlist1[0]\n    assert ntlist0[1] == ntlist1[1]\n\n\ndef test_namedtuple_to_array():\n    nt = collections.namedtuple(\"NT\", (\"A\", \"B\", \"C\"))\n    t0 = nt(0, 1, 2)\n    t0a = namedtuple_to_array(t0)\n    t1 = array_to_namedtuple(t0a)\n    assert t0 == t1\n\n\ndef test_no_nan_divide_array():\n    x, key = randn((4,), dtype=np.float32)\n    y, key = randn(x.shape, dtype=np.float32, key=key)\n    y = y.at[0].set(0)\n\n    res = no_nan_divide(x, y)\n\n    assert res[0] == 0\n    idx = y != 0\n    np.testing.assert_allclose(res[idx], x[idx] / y[idx])\n\n\ndef test_no_nan_divide_blockarray():\n    x, key = randn(((3, 3), (4,)), dtype=np.float32)\n\n    y, key = randn(x.shape, dtype=np.float32, key=key)\n    y[1] = y[1].at[:].set(0 * y[1])\n\n    res = no_nan_divide(x, y)\n\n    assert snp.all(res[1] == 0.0)\n    np.testing.assert_allclose(res[0], x[0] / y[0])\n\n\ndef test_array_info():\n    x = np.array([0.0, 0.1])\n    xinfo = array_info(x)\n    assert \"numpy.ndarray\" in xinfo\n    x = jnp.array([0.0, 0.1])\n    xinfo = array_info(x)\n    assert \"jax.Array\" in xinfo\n    x = snp.ones(((2, 
3), (2,)))\n    xinfo = array_info(x)\n    assert \"scico.numpy.BlockArray\" in xinfo\n\n\ndef test_normalize_axes():\n    axes = None\n    np.testing.assert_raises(ValueError, normalize_axes, axes)\n\n    axes = None\n    assert normalize_axes(axes, np.shape([[1, 1], [1, 1]])) == (0, 1)\n\n    axes = None\n    assert normalize_axes(axes, np.shape([[1, 1], [1, 1]]), default=[0]) == [0]\n\n    axes = [1, 2]\n    assert normalize_axes(axes) == axes\n\n    axes = 1\n    assert normalize_axes(axes) == (1,)\n\n    axes = (-1,)\n    assert normalize_axes(axes, shape=(1, 2)) == (1,)\n\n    axes = (0, 2, 1)\n    assert normalize_axes(axes, shape=(2, 3, 4), sort=True) == (0, 1, 2)\n\n    axes = \"axes\"\n    np.testing.assert_raises(ValueError, normalize_axes, axes)\n\n    axes = 2\n    np.testing.assert_raises(ValueError, normalize_axes, axes, np.shape([1]))\n\n    axes = (1, 2, 2)\n    np.testing.assert_raises(ValueError, normalize_axes, axes)\n\n\n@pytest.mark.parametrize(\"length\", (4, 5, 8, 16, 17))\n@pytest.mark.parametrize(\"start\", (None, 0, 1, 2, 3))\n@pytest.mark.parametrize(\"stop\", (None, 0, 1, 2, -2, -1))\n@pytest.mark.parametrize(\"stride\", (None, 1, 2, 3))\ndef test_slice_length(length, start, stop, stride):\n    x = np.zeros(length)\n    slc = slice(start, stop, stride)\n    assert x[slc].size == slice_length(length, slc)\n\n\n@pytest.mark.parametrize(\"length\", (4, 5))\n@pytest.mark.parametrize(\"slc\", (0, 1, -4, Ellipsis))\ndef test_slice_length_other(length, slc):\n    x = np.zeros(length)\n    if isinstance(slc, int):\n        assert slice_length(length, slc) is None\n    else:\n        assert x[slc].size == slice_length(length, slc)\n\n\n@pytest.mark.parametrize(\"shape\", ((8, 8, 1), (7, 1, 6, 5)))\n@pytest.mark.parametrize(\n    \"slc\",\n    (\n        np.s_[0],\n        np.s_[0:5],\n        np.s_[:, 0:4],\n        np.s_[2:, :, :-2],\n        np.s_[..., 2:],\n        np.s_[..., 2:, :],\n        np.s_[1:, ..., 2:],\n        np.s_[np.newaxis],\n  
      np.s_[:, np.newaxis],\n        np.s_[np.newaxis, :, np.newaxis],\n        np.s_[np.newaxis, ..., 0:2, :],\n    ),\n)\ndef test_indexed_shape(shape, slc):\n    x = np.zeros(shape)\n    assert x[slc].shape == indexed_shape(shape, slc)\n    assert x[slc].shape == jax_indexed_shape(shape, slc)\n\n\ndef test_is_nested():\n    # list\n    assert is_nested([1, 2, 3]) == False\n\n    # tuple\n    assert is_nested((1, 2, 3)) == False\n\n    # list of lists\n    assert is_nested([[1, 2], [4, 5], [3]]) == True\n\n    # list of lists + scalar\n    assert is_nested([[1, 2], 3]) == True\n\n    # list of tuple + scalar\n    assert is_nested([(1, 2), 3]) == True\n\n    # tuple of tuple + scalar\n    assert is_nested(((1, 2), 3)) == True\n\n    # tuple of lists + scalar\n    assert is_nested(([1, 2], 3)) == True\n\n\ndef test_is_collapsible():\n    shape1 = ((1, 2, 3), (1, 2, 3), (1, 3, 3))\n    shape2 = ((1, 2, 3), (1, 2, 3), (1, 2, 3))\n    assert not is_collapsible(shape1)\n    assert is_collapsible(shape2)\n\n\ndef test_is_blockable():\n    shape1 = ((1, 2, 3), (1, 2, 3), (1, 2, 3))\n    shape2 = ((1, 2, 3), ((1, 2, 3), (1, 2, 3)))\n    assert is_blockable(shape1)\n    assert not is_blockable(shape2)\n\n\n@pytest.mark.parametrize(\"shape\", [(3, 4), ((3, 4), (5,))])\ndef test_shape_dtype_rep(shape):\n    dtype = np.float32\n    assert shape_dtype_rep(shape, dtype).shape == shape\n\n\ndef test_is_real_dtype():\n    assert not is_real_dtype(snp.complex64)\n    assert is_real_dtype(snp.float32)\n\n\ndef test_is_complex_dtype():\n    assert is_complex_dtype(snp.complex64)\n    assert not is_complex_dtype(snp.float32)\n\n\ndef test_real_dtype():\n    assert real_dtype(snp.complex64) == snp.float32\n\n\ndef test_complex_dtype():\n    assert complex_dtype(snp.float32) == snp.complex64\n\n\ndef test_dtype_name():\n    assert dtype_name(np.float32) == \"numpy.float32\"\n    assert dtype_name(snp.float32) == \"jax.numpy.float32\"\n\n\ndef test_broadcast_nested_shapes():\n    # 
unnested should work as usual\n    assert snp.util.broadcast_nested_shapes((1, 3, 4, 7), (3, 1, 7)) == (1, 3, 4, 7)\n\n    # nested + unested\n    assert snp.util.broadcast_nested_shapes(((2, 3), (1, 1, 3)), (2, 3)) == ((2, 3), (1, 2, 3))\n\n    # unested + nested\n    assert snp.util.broadcast_nested_shapes((1, 1, 3), ((2, 3), (7, 3))) == ((1, 2, 3), (1, 7, 3))\n\n    # nested + nested\n    snp.util.broadcast_nested_shapes(((1, 1, 3), (1, 7, 1, 3)), ((2, 3), (7, 4, 3))) == (\n        (1, 2, 3),\n        (1, 7, 4, 3),\n    )\n\n\ndef test_is_scalar_equiv():\n    assert is_scalar_equiv(1e0)\n    assert is_scalar_equiv(snp.array(1e0))\n    assert is_scalar_equiv(snp.sum(snp.zeros(1)))\n    assert not is_scalar_equiv(snp.array([1e0]))\n    assert not is_scalar_equiv(snp.array([1e0, 2e0]))\n"
  },
  {
    "path": "scico/test/operator/test_biconvolve.py",
    "content": "import numpy as np\n\nimport jax\nimport jax.scipy.signal as signal\n\nimport pytest\n\nfrom scico.linop import Convolve, ConvolveByX\nfrom scico.numpy import blockarray\nfrom scico.operator.biconvolve import BiConvolve\nfrom scico.random import randn\n\n\nclass TestBiConvolve:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"input_dtype\", [np.float32, np.complex64])\n    @pytest.mark.parametrize(\"mode\", [\"full\", \"valid\", \"same\"])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_eval(self, input_dtype, mode, jit):\n        x, key = randn((32, 32), dtype=input_dtype, key=self.key)\n        h, key = randn((4, 4), dtype=input_dtype, key=self.key)\n\n        x_h = blockarray([x, h])\n\n        A = BiConvolve(input_shape=x_h.shape, mode=mode, jit=jit)\n        signal_out = signal.convolve(x, h, mode=mode)\n        np.testing.assert_allclose(A(x_h), signal_out, rtol=1e-4)\n\n        # Test freezing\n        A_x = A.freeze(0, x)\n        assert isinstance(A_x, ConvolveByX)\n        np.testing.assert_allclose(A_x(h), signal_out, rtol=1e-4)\n\n        A_h = A.freeze(1, h)\n        assert isinstance(A_h, Convolve)\n        np.testing.assert_allclose(A_h(x), signal_out, rtol=1e-4)\n\n        with pytest.raises(ValueError):\n            A.freeze(2, x)\n\n    def test_invalid_shapes(self):\n        with pytest.raises(ValueError):\n            A = BiConvolve(input_shape=(2, 2))\n\n        with pytest.raises(ValueError):\n            shape = ((2, 2), (3, 3), (4, 4))  # 3 blocks\n            A = BiConvolve(input_shape=shape)\n\n        with pytest.raises(ValueError):\n            shape = ((2, 2), (3,))  # 3 blocks\n            A = BiConvolve(input_shape=shape)\n"
  },
  {
    "path": "scico/test/operator/test_op_stack.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.operator import (\n    Abs,\n    DiagonalReplicated,\n    DiagonalStack,\n    Operator,\n    VerticalStack,\n)\nfrom scico.random import randn\n\nTestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x)))\nTestOpB = Operator(\n    input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, x))\n)\nTestOpC = Operator(\n    input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, 2 * x))\n)\n\n\nclass TestVerticalStack:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_construct(self, jit):\n        # requires a list of Operators\n        A = Abs((42,))\n        with pytest.raises(TypeError):\n            H = VerticalStack(A, jit=jit)\n\n        # checks input sizes\n        A = Abs((3, 2))\n        B = Abs((7, 2))\n        with pytest.raises(ValueError):\n            H = VerticalStack([A, B], jit=jit)\n\n        # in general, returns a BlockArray\n        A = TestOpA\n        B = TestOpB\n        H = VerticalStack([A, B], jit=jit)\n        x = np.ones((3, 4))\n        y = H(x)\n        assert y.shape == ((2, 3, 4), (6, 4))\n\n        # ... result should be [A@x, B@x]\n        assert np.allclose(y[0], A(x))\n        assert np.allclose(y[1], B(x))\n\n        # by default, collapse_output to jax array when possible\n        A = TestOpB\n        B = TestOpB\n        H = VerticalStack([A, B], jit=jit)\n        x = np.ones((3, 4))\n        y = H(x)\n        assert y.shape == (2, 6, 4)\n\n        # ... 
result should be [A@x, B@x]\n        assert np.allclose(y[0], A(x))\n        assert np.allclose(y[1], B(x))\n\n        # let user turn off collapsing\n        A = TestOpA\n        B = TestOpA\n        H = VerticalStack([A, B], collapse_output=False, jit=jit)\n        x = np.ones((3, 4))\n        y = H(x)\n        assert y.shape == ((2, 3, 4), (2, 3, 4))\n\n    @pytest.mark.parametrize(\"collapse_output\", [False, True])\n    @pytest.mark.parametrize(\"jit\", [False, True])\n    def test_algebra(self, collapse_output, jit):\n        # adding\n        A = TestOpB\n        B = TestOpB\n        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n\n        A = TestOpC\n        B = TestOpC\n        G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)\n\n        x = np.ones((3, 4))\n        S = H + G\n\n        # test correctness of addition\n        assert S.output_shape == H.output_shape\n        assert S.input_shape == H.input_shape\n        np.testing.assert_allclose((S(x))[0], (H(x) + G(x))[0])\n        np.testing.assert_allclose((S(x))[1], (H(x) + G(x))[1])\n\n\nclass TestBlockDiagonalOperator:\n    def test_construct(self):\n        # requires a list of Operators\n        A = Abs((8,))\n        with pytest.raises(TypeError):\n            H = VerticalStack(A)\n\n        # no nested output shapes\n        A = Abs(((8,), (10,)))\n        with pytest.raises(ValueError):\n            H = VerticalStack((A, A))\n\n        # output dtypes must be the same\n        A = Abs(input_shape=(8,), input_dtype=snp.float32)\n        B = Abs(input_shape=(8,), input_dtype=snp.int32)\n        with pytest.raises(ValueError):\n            H = VerticalStack((A, B))\n\n    def test_apply(self):\n        S1 = (3, 4)\n        S2 = (3, 5)\n        S3 = (2, 2)\n        A1 = Abs(S1)\n        A2 = 2 * Abs(S2)\n        A3 = Abs(S3)\n        H = DiagonalStack((A1, A2, A3))\n\n        x = snp.ones((S1, S2, S3))\n        y = H(x)\n        y_expected = 
snp.blockarray((snp.ones(S1), 2 * snp.ones(S2), snp.sum(snp.ones(S3))))\n\n        np.testing.assert_equal(y, y_expected)\n\n    def test_input_collapse(self):\n        S = (3, 4)\n        A1 = TestOpA\n        A2 = TestOpB\n\n        H = DiagonalStack((A1, A2))\n        assert H.input_shape == (2, *S)\n\n        H = DiagonalStack((A1, A2), collapse_input=False)\n        assert H.input_shape == (S, S)\n\n    def test_output_collapse(self):\n        A1 = TestOpB\n        A2 = TestOpC\n\n        H = DiagonalStack((A1, A2))\n        assert H.output_shape == (2, *A1.output_shape)\n\n        H = DiagonalStack((A1, A2), collapse_output=False)\n        assert H.output_shape == (A1.output_shape, A1.output_shape)\n\n\nclass TestDiagonalReplicated:\n    def setup_method(self, method):\n        self.key = jax.random.key(12345)\n\n    @pytest.mark.parametrize(\"map_type\", [\"auto\", \"vmap\"])\n    @pytest.mark.parametrize(\"input_axis\", [0, 1])\n    def test_map_auto_vmap(self, input_axis, map_type):\n        x, key = randn((2, 3, 4), key=self.key)\n        mapshape = (3, 4) if input_axis == 0 else (2, 4)\n        replicates = x.shape[input_axis]\n        A = Abs(mapshape)\n        D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type)\n        y = D(x)\n        assert y.shape[input_axis] == replicates\n\n    @pytest.mark.skipif(jax.device_count() < 2, reason=\"multiple devices required for test\")\n    def test_map_auto_pmap(self):\n        x, key = randn((2, 3, 4), key=self.key)\n        A = Abs(x.shape[1:])\n        replicates = x.shape[0]\n        D = DiagonalReplicated(A, replicates, map_type=\"pmap\")\n        y = D(x)\n        assert y.shape[0] == replicates\n\n    def test_input_axis(self):\n        # Ensure that operators can be stacked on final axis\n        x, key = randn((2, 3, 4), key=self.key)\n        A = Abs(x.shape[0:2])\n        replicates = x.shape[2]\n        D = DiagonalReplicated(A, replicates, input_axis=2)\n        y = D(x)\n 
       assert y.shape == (2, 3, 4)\n        D = DiagonalReplicated(A, replicates, input_axis=-1)\n        y = D(x)\n        assert y.shape == (2, 3, 4)\n\n    def test_output_axis(self):\n        x, key = randn((2, 3, 4), key=self.key)\n        A = Abs(x.shape[1:])\n        replicates = x.shape[0]\n        D = DiagonalReplicated(A, replicates, output_axis=1)\n        y = D(x)\n        assert y.shape == (3, 2, 4)\n"
  },
  {
    "path": "scico/test/operator/test_operator.py",
    "content": "import operator as op\n\nimport numpy as np\n\nfrom jax import config\n\nimport pytest\n\n# enable 64-bit mode for output dtype checks\nconfig.update(\"jax_enable_x64\", True)\n\nimport jax\n\nimport scico.numpy as snp\nfrom scico.operator import Abs, Angle, Exp, Operator, operator_from_function\nfrom scico.random import randn\n\nSCALARS = (2, 1e0, snp.array(1.0))\n\n\nclass AbsOperator(Operator):\n    def _eval(self, x):\n        return snp.sum(snp.abs(x))\n\n\nclass SquareOperator(Operator):\n    def _eval(self, x):\n        return x**2\n\n\nclass SumSquareOperator(Operator):\n    def _eval(self, x):\n        return snp.sum(x**2)\n\n\nclass OperatorTestObj:\n    def __init__(self, dtype):\n        M, N = (32, 64)\n        key = jax.random.key(12345)\n        self.dtype = dtype\n\n        self.A = AbsOperator(input_shape=(N,), input_dtype=dtype)\n        self.B = SquareOperator(input_shape=(N,), input_dtype=dtype)\n        self.S = SumSquareOperator(input_shape=(N,), input_dtype=dtype)\n\n        self.mat = randn(self.A.input_shape, dtype=dtype, key=key)\n        self.x, key = randn((N,), dtype=dtype, key=key)\n\n        self.z, key = randn((2 * N,), dtype=dtype, key=key)\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.float64, np.complex64, np.complex128])\ndef testobj(request):\n    yield OperatorTestObj(request.param)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_binary_op(testobj, operator):\n    # Our AbsOperator class does not override the __add__, etc\n    # so AbsOperator + AbsMatOp -> Operator\n\n    x = testobj.x\n    # Composite operator\n    comp_op = operator(testobj.A, testobj.S)\n\n    # evaluate Operators separately, then add/sub\n    res = operator(testobj.A(x), testobj.S(x))\n\n    assert comp_op.output_dtype == res.dtype\n    np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_binary_op_same(testobj, operator):\n    
x = testobj.x\n    # Composite operator\n    comp_op = operator(testobj.A, testobj.A)\n\n    # evaluate Operators separately, then add/sub\n    res = operator(testobj.A(x), testobj.A(x))\n\n    assert isinstance(comp_op, Operator)\n    assert comp_op.output_dtype == res.dtype\n    np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n@pytest.mark.parametrize(\"scalar\", SCALARS)\ndef test_scalar_left(testobj, operator, scalar):\n    x = testobj.x\n    comp_op = operator(testobj.A, scalar)\n    res = operator(testobj.A(x), scalar)\n    assert comp_op.output_dtype == res.dtype\n    np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.mul, op.truediv])\n@pytest.mark.parametrize(\"scalar\", SCALARS)\ndef test_scalar_right(testobj, operator, scalar):\n    if operator == op.truediv:\n        pytest.xfail(\"scalar / Operator is not supported\")\n    x = testobj.x\n    comp_op = operator(scalar, testobj.A)\n    res = operator(scalar, testobj.A(x))\n    assert comp_op.output_dtype == res.dtype\n    np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)\n\n\ndef test_negation(testobj):\n    x = testobj.x\n    comp_op = -testobj.A\n    res = -(testobj.A(x))\n    assert comp_op.input_dtype == testobj.A.input_dtype\n    np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_invalid_add_sub_array(testobj, operator):\n    # Try to add or subtract an ndarray with Operator\n    with pytest.raises(TypeError):\n        operator(testobj.A, testobj.mat)\n\n\n@pytest.mark.parametrize(\"operator\", [op.add, op.sub])\ndef test_invalid_add_sub_scalar(testobj, operator):\n    # Try to add or subtract a scalar with AbsMatOp\n    with pytest.raises(TypeError):\n        operator(1.0, testobj.A)\n\n\ndef test_call_operator_operator(testobj):\n    x = testobj.x\n    A = testobj.A\n    B = testobj.B\n    
np.testing.assert_allclose(A(B)(x), A(B(x)))\n\n    with pytest.raises(ValueError):\n        # incompatible shapes\n        A(testobj.S)\n\n\ndef test_shape_call_vec(testobj):\n    # evaluate operator on an array of incompatible size\n    with pytest.raises(ValueError):\n        testobj.A(testobj.z)\n\n\ndef test_scale_vmap(testobj):\n    A = testobj.A\n    x = testobj.x\n\n    def foo(c):\n        return (c * A)(x)\n\n    c_list = [1.0, 2.0, 3.0]\n    non_vmap = np.array([foo(c) for c in c_list])\n    vmapped = jax.vmap(foo)(snp.array(c_list))\n    np.testing.assert_allclose(non_vmap, vmapped)\n\n\ndef test_scale_pmap(testobj):\n    A = testobj.A\n    x = testobj.x\n\n    def foo(c):\n        return (c * A)(x)\n\n    c_list = np.random.randn(jax.device_count())\n    non_pmap = np.array([foo(c) for c in c_list])\n    pmapped = jax.pmap(foo)(c_list)\n    np.testing.assert_allclose(non_pmap, pmapped, rtol=1e-6)\n\n\ndef test_freeze_3arg():\n    A = Operator(\n        input_shape=((1, 3, 4), (2, 1, 4), (2, 3, 1)), eval_fn=lambda x: x[0] * x[1] * x[2]\n    )\n\n    a, _ = randn((1, 3, 4))\n    b, _ = randn((2, 1, 4))\n    c, _ = randn((2, 3, 1))\n\n    x = snp.blockarray([a, b, c])\n    Abc = A.freeze(0, a)  # A as a function of b, c\n    Aac = A.freeze(1, b)  # A as a function of a, c\n    Aab = A.freeze(2, c)  # A as a function of a, b\n\n    assert Abc.input_shape == ((2, 1, 4), (2, 3, 1))\n    assert Aac.input_shape == ((1, 3, 4), (2, 3, 1))\n    assert Aab.input_shape == ((1, 3, 4), (2, 1, 4))\n\n    bc = snp.blockarray([b, c])\n    ac = snp.blockarray([a, c])\n    ab = snp.blockarray([a, b])\n    np.testing.assert_allclose(A(x), Abc(bc), rtol=5e-4)\n    np.testing.assert_allclose(A(x), Aac(ac), rtol=5e-4)\n    np.testing.assert_allclose(A(x), Aab(ab), rtol=5e-4)\n\n\ndef test_freeze_2arg():\n    A = Operator(input_shape=((1, 3, 4), (2, 1, 4)), eval_fn=lambda x: x[0] * x[1])\n\n    a, _ = randn((1, 3, 4))\n    b, _ = randn((2, 1, 4))\n\n    x = snp.blockarray([a, 
b])\n    Ab = A.freeze(0, a)  # A as a function of 'b' only\n    Aa = A.freeze(1, b)  # A as a function of 'a' only\n\n    assert Ab.input_shape == (2, 1, 4)\n    assert Aa.input_shape == (1, 3, 4)\n\n    np.testing.assert_allclose(A(x), Ab(b), rtol=5e-4)\n    np.testing.assert_allclose(A(x), Aa(a), rtol=5e-4)\n\n\n@pytest.mark.parametrize(\"dtype\", [np.float32, np.complex64])\n@pytest.mark.parametrize(\"op_fn\", [(Abs, snp.abs), (Angle, snp.angle), (Exp, snp.exp)])\ndef test_func_op(op_fn, dtype):\n    op = op_fn[0]\n    fn = op_fn[1]\n    shape = (2, 3)\n    x, _ = randn(shape, dtype=dtype)\n    H = op(input_shape=shape, input_dtype=dtype)\n    np.testing.assert_array_equal(H(x), fn(x))\n\n\ndef test_make_func_op():\n    AbsVal = operator_from_function(snp.abs, \"AbsVal\")\n    shape = (2,)\n    x, _ = randn(shape, dtype=np.float32)\n    H = AbsVal(input_shape=shape, input_dtype=np.float32)\n    np.testing.assert_array_equal(H(x), snp.abs(x))\n\n\ndef test_make_func_op_ext_init():\n    AbsVal = operator_from_function(snp.abs, \"AbsVal\")\n    shape = (2,)\n    x, _ = randn(shape, dtype=np.float32)\n    H = AbsVal(\n        input_shape=shape, output_shape=shape, input_dtype=np.float32, output_dtype=np.float32\n    )\n    np.testing.assert_array_equal(H(x), snp.abs(x))\n\n\nclass TestJacobianProdReal:\n    def setup_method(self):\n        N = 7\n        M = 8\n        key = None\n        dtype = snp.float32\n        self.fmx, key = randn((M, N), key=key, dtype=dtype)\n        self.F = Operator(\n            (N, 1),\n            output_shape=(M, 1),\n            eval_fn=lambda x: self.fmx @ x,\n            input_dtype=dtype,\n            output_dtype=dtype,\n        )\n        self.u, key = randn((N, 1), key=key, dtype=dtype)\n        self.v, key = randn((N, 1), key=key, dtype=dtype)\n        self.w, key = randn((M, 1), key=key, dtype=dtype)\n\n    def test_jvp(self):\n        Fu, JFuv = self.F.jvp(self.u, self.v)\n        np.testing.assert_allclose(Fu, 
self.F(self.u))\n        np.testing.assert_allclose(JFuv, self.fmx @ self.v, atol=1e-6, rtol=0.0)\n\n    def test_vjp_conj(self):\n        Fu, G = self.F.vjp(self.u, conjugate=True)\n        JFTw = G(self.w)\n        np.testing.assert_allclose(Fu, self.F(self.u))\n        np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, atol=1e-6, rtol=0.0)\n\n    def test_vjp_noconj(self):\n        Fu, G = self.F.vjp(self.u, conjugate=False)\n        JFTw = G(self.w)\n        np.testing.assert_allclose(Fu, self.F(self.u))\n        np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, atol=1e-6, rtol=0.0)\n\n\nclass TestJacobianProdComplex:\n    def setup_method(self):\n        N = 7\n        M = 8\n        key = None\n        dtype = snp.complex64\n        self.fmx, key = randn((M, N), key=key, dtype=dtype)\n        self.F = Operator(\n            (N, 1),\n            output_shape=(M, 1),\n            eval_fn=lambda x: self.fmx @ x,\n            input_dtype=dtype,\n            output_dtype=dtype,\n        )\n        self.u, key = randn((N, 1), key=key, dtype=dtype)\n        self.v, key = randn((N, 1), key=key, dtype=dtype)\n        self.w, key = randn((M, 1), key=key, dtype=dtype)\n\n    def test_jvp(self):\n        Fu, JFuv = self.F.jvp(self.u, self.v)\n        np.testing.assert_allclose(Fu, self.F(self.u))\n        np.testing.assert_allclose(JFuv, self.fmx @ self.v, rtol=1e-6)\n\n    def test_vjp_conj(self):\n        Fu, G = self.F.vjp(self.u, conjugate=True)\n        JFTw = G(self.w)\n        np.testing.assert_allclose(Fu, self.F(self.u))\n        np.testing.assert_allclose(JFTw, self.fmx.T.conj() @ self.w, rtol=1e-6)\n\n    def test_vjp_noconj(self):\n        Fu, G = self.F.vjp(self.u, conjugate=False)\n        JFTw = G(self.w)\n        np.testing.assert_allclose(Fu, self.F(self.u))\n        np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, rtol=1e-6)\n"
  },
  {
    "path": "scico/test/optimize/test_admm.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, metric, operator, random\nfrom scico.optimize import ADMM\nfrom scico.optimize.admm import (\n    CircularConvolveSolver,\n    FBlockCircularConvolveSolver,\n    G0BlockCircularConvolveSolver,\n    GenericSubproblemSolver,\n    LinearSubproblemSolver,\n    MatrixSubproblemSolver,\n)\n\n\nclass TestMisc:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.array(np.random.randn(16, 17).astype(np.float32))\n\n    def test_admm(self):\n        maxiter = 2\n        ρ = 1e-1\n        A = linop.Identity(self.y.shape)\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = functional.DnCNN()\n        C = linop.Identity(self.y.shape)\n\n        itstat_fields = {\"Iter\": \"%d\", \"Time\": \"%8.2e\"}\n\n        def itstat_func(obj):\n            return (obj.itnum, obj.timer.elapsed())\n\n        admm_ = ADMM(\n            f=f,\n            g_list=[g],\n            C_list=[C],\n            rho_list=[ρ],\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n        )\n        assert len(admm_.itstat_object.fieldname) == 6\n        assert snp.sum(admm_.x) == 0.0\n\n        admm_ = ADMM(\n            f=f,\n            g_list=[g],\n            C_list=[C],\n            rho_list=[ρ],\n            maxiter=maxiter,\n            itstat_options={\"fields\": itstat_fields, \"itstat_func\": itstat_func, \"display\": False},\n        )\n        assert len(admm_.itstat_object.fieldname) == 2\n\n        admm_.test_flag = False\n\n        def callback(obj):\n            obj.test_flag = True\n\n        x = admm_.solve(callback=callback)\n        assert admm_.test_flag\n\n        with pytest.raises(TypeError):\n            admm_ = ADMM(f=f, g_list=[g], C_list=[C], rho_list=[ρ], invalid_keyword_arg=None)\n\n        admm_ = ADMM(f=f, g_list=[g], C_list=[C], rho_list=[ρ], 
maxiter=maxiter, nanstop=True)\n        admm_.step()\n        admm_.x = admm_.x.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            admm_.solve()\n\n    @pytest.mark.parametrize(\n        \"solver\", [LinearSubproblemSolver, MatrixSubproblemSolver, CircularConvolveSolver]\n    )\n    def test_admm_aux(self, solver):\n        maxiter = 2\n        ρ = 1e-1\n        A = operator.Abs(self.y.shape)\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = functional.DnCNN()\n        C = linop.Identity(self.y.shape)\n\n        with pytest.raises(TypeError):\n            admm_ = ADMM(\n                f=f,\n                g_list=[g],\n                C_list=[C],\n                rho_list=[ρ],\n                maxiter=maxiter,\n                subproblem_solver=solver(),\n            )\n\n        with pytest.raises(TypeError):\n            admm_ = ADMM(\n                f=g,\n                g_list=[g],\n                C_list=[C],\n                rho_list=[ρ],\n                maxiter=maxiter,\n                subproblem_solver=solver(),\n            )\n\n\nclass TestReal:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        MA = 4\n        MB = 5\n        N = 6\n        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx = np.random.randn(MA, N).astype(np.float32)\n        Bmx = np.random.randn(MB, N).astype(np.float32)\n        y = np.random.randn(MA).astype(np.float32)\n        𝛼 = np.pi  # sort of random number chosen to test non-default scale factor\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.𝛼 = 𝛼\n        self.λ = λ\n        # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y\n        self.grdA = lambda x: (𝛼 * Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = 𝛼 * Amx.T @ y\n\n    def test_admm_generic(self):\n        maxiter = 25\n        ρ = 2e-1\n        A = 
linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=GenericSubproblemSolver(minimize_kwargs={\"options\": {\"maxiter\": 50}}),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3\n\n    def test_admm_saveload(self):\n        maxiter = 5\n        x_ref = np.ones((16, 16), dtype=np.float32)\n        x_ref[4:-4, 4:-4] = 1.0\n        n = 3\n        psf = snp.ones((n, n), dtype=np.float32) / (n * n)\n        A = linop.CircularConvolve(h=psf, input_shape=x_ref.shape)\n        y = A(x_ref)\n        λ = 2e-2\n        ρ = 5e-1\n        f = loss.SquaredL2Loss(y=y, A=A)\n        g = λ * functional.L21Norm()\n        C = linop.FiniteDifference(x_ref.shape, circular=True)\n        admm0 = ADMM(\n            f=f,\n            g_list=[g],\n            C_list=[C],\n            rho_list=[ρ],\n            x0=A.adj(y),\n            maxiter=maxiter,\n            subproblem_solver=CircularConvolveSolver(),\n        )\n        admm0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"admm.npz\")\n            admm0.save_state(path)\n            admm0.solve()\n            h0 = admm0.history()\n            admm1 = ADMM(\n                f=f,\n                g_list=[g],\n                C_list=[C],\n                rho_list=[ρ],\n                x0=A.adj(y),\n                maxiter=maxiter,\n                subproblem_solver=CircularConvolveSolver(),\n            )\n            admm1.load_state(path)\n         
   admm1.solve()\n            h1 = admm1.history()\n            np.testing.assert_allclose(admm0.minimizer(), admm1.minimizer(), atol=1e-7)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n    def test_admm_quadratic_scico(self):\n        maxiter = 25\n        ρ = 4e-1\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4}, cg_function=\"scico\"),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_admm_quadratic_jax(self):\n        maxiter = 25\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4}, cg_function=\"jax\"),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_admm_quadratic_relax(self):\n        maxiter = 25\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, 
A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            alpha=1.6,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4}, cg_function=\"jax\"),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n\nclass TestRealWeighted:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        MA = 4\n        MB = 5\n        N = 6\n        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_W^2 + (λ/2) ||B x||_2^2\n        Amx = np.random.randn(MA, N).astype(np.float32)\n        W = np.abs(np.random.randn(MA, 1).astype(np.float32))\n        Bmx = np.random.randn(MB, N).astype(np.float32)\n        y = np.random.randn(MA).astype(np.float32)\n        𝛼 = np.pi  # sort of random number chosen to test non-default scale factor\n        λ = np.e\n        self.Amx = Amx\n        self.W = snp.array(W)\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.𝛼 = 𝛼\n        self.λ = λ\n        # Solution of problem is given by linear system\n        #   (𝛼 A^T W A + λ B^T B) x = 𝛼 A^T W y\n        self.grdA = lambda x: (𝛼 * Amx.T @ (W * Amx) + λ * Bmx.T @ Bmx) @ x\n        self.grdb = 𝛼 * Amx.T @ (W[:, 0] * y)\n\n    def test_admm_quadratic_linear(self):\n        maxiter = 100\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n       
     f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(cg_kwargs={\"tol\": 1e-4}, cg_function=\"scico\"),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_admm_quadratic_matrix(self):\n        maxiter = 50\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=MatrixSubproblemSolver(),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5\n\n\nclass TestComplex:\n    def setup_method(self, method):\n        MA = 4\n        MB = 5\n        N = 6\n        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)\n        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)\n        y, key = random.randn((MA,), dtype=np.complex64, key=key)\n        𝛼 = 1.0 / 3.0\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = y\n        self.𝛼 = 𝛼\n        self.λ = λ\n        # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (𝛼 * Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x\n        self.grdb = 𝛼 * 
Amx.conj().T @ y\n\n    def test_admm_generic(self):\n        maxiter = 30\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=GenericSubproblemSolver(minimize_kwargs={\"options\": {\"maxiter\": 50}}),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3\n\n    def test_admm_quadratic_linear(self):\n        maxiter = 50\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n            f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(\n                cg_kwargs={\"tol\": 1e-4},\n            ),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_admm_quadratic_matrix(self):\n        maxiter = 50\n        ρ = 1e0\n        A = linop.MatrixOperator(self.Amx)\n        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)\n        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]\n        C_list = [linop.MatrixOperator(self.Bmx)]\n        rho_list = [ρ]\n        admm_ = ADMM(\n   
         f=f,\n            g_list=g_list,\n            C_list=C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=A.adj(self.y),\n            subproblem_solver=MatrixSubproblemSolver(),\n        )\n        x = admm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5\n\n\n@pytest.mark.parametrize(\"extra_axis\", (False, True))\n@pytest.mark.parametrize(\"center\", (None, [-1.0, 2.5]))\nclass TestCircularConvolveSolve:\n\n    @pytest.fixture(scope=\"function\", autouse=True)\n    def setup_and_teardown(self, extra_axis, center):\n        np.random.seed(12345)\n        Nx = 8\n        x = snp.pad(snp.ones((Nx, Nx), dtype=np.float32), Nx)\n        Npsf = 3\n        psf = snp.ones((Npsf, Npsf), dtype=np.float32) / (Npsf**2)\n        if extra_axis:\n            x = x[np.newaxis]\n            psf = psf[np.newaxis]\n        self.A = linop.CircularConvolve(\n            h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32, h_center=center\n        )\n        self.y = self.A(x)\n        λ = 1e-2\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g_list = [λ * functional.L1Norm()]\n        self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)]\n        yield\n\n    def test_admm(self):\n        maxiter = 50\n        ρ = 1e-1\n        rho_list = [ρ]\n        admm_lin = ADMM(\n            f=self.f,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=self.A.adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(),\n        )\n        x_lin = admm_lin.solve()\n        admm_dft = ADMM(\n            f=self.f,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            
itstat_options={\"display\": False},\n            x0=self.A.adj(self.y),\n            subproblem_solver=CircularConvolveSolver(),\n        )\n        assert admm_dft.subproblem_solver.A_lhs.ndims == 2\n        x_dft = admm_dft.solve()\n        np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)\n        assert metric.mse(x_lin, x_dft) < 1e-9\n\n\n@pytest.mark.parametrize(\"with_cconv\", (False, True))\nclass TestSpecialCaseCircularConvolveSolve:\n\n    @pytest.fixture(scope=\"function\", autouse=True)\n    def setup_and_teardown(self, with_cconv):\n        np.random.seed(12345)\n        Nx = 8\n        x = snp.pad(snp.ones((1, Nx, Nx), dtype=np.float32), Nx)\n        if with_cconv:\n            Npsf = 3\n            psf = snp.ones((1, Npsf, Npsf), dtype=np.float32) / (Npsf**2)\n            C0 = linop.CircularConvolve(h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32)\n        else:\n            C0 = linop.FiniteDifference(input_shape=x.shape, axes=(1, 2), circular=True)\n        C1 = linop.Identity(input_shape=x.shape)\n        self.y = C0(x)\n        self.g_list = [loss.SquaredL2Loss(y=self.y), functional.L2Norm()]\n        self.C_list = [C0, C1]\n        self.with_cconv = with_cconv\n        yield\n\n    def test_admm(self):\n        maxiter = 50\n        ρ = 1e-1\n        rho_list = [ρ, ρ]\n        admm_lin = ADMM(\n            f=None,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            x0=self.C_list[0].adj(self.y),\n            subproblem_solver=LinearSubproblemSolver(),\n        )\n        x_lin = admm_lin.solve()\n        ndims = None if self.with_cconv else 2\n        admm_dft = ADMM(\n            f=None,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            
x0=self.C_list[0].adj(self.y),\n            subproblem_solver=CircularConvolveSolver(ndims=ndims),\n        )\n        assert admm_dft.subproblem_solver.A_lhs.ndims == 2\n        x_dft = admm_dft.solve()\n        np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)\n        assert metric.mse(x_lin, x_dft) < 1e-9\n\n\nclass TestBlockCircularConvolveSolve:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        Nx = 8\n        x = np.zeros((2, Nx, Nx), dtype=np.float32)\n        x[0, 2, 2] = 1.0\n        x[1, 3, 3] = 1.0\n        Npsf = 3\n        psf = np.zeros((2, Npsf, Npsf), dtype=np.float32)\n        psf[0, 1] = 1.0\n        psf[1, :, 1] = 1.0\n        C = linop.CircularConvolve(h=psf, input_shape=x.shape, input_dtype=np.float32, ndims=2)\n        S = linop.Sum(input_shape=x.shape, axis=0)\n        self.A = S @ C\n        self.y = self.A(x)\n        λ = 1e-1\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g_list = [λ * functional.L1Norm()]\n        self.C_list = [linop.Identity(input_shape=x.shape)]\n\n    def test_fblock_init(self):\n        with pytest.raises(ValueError):\n            slvr = ADMM(\n                f=None,\n                g_list=self.g_list,\n                C_list=self.C_list,\n                rho_list=[1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=FBlockCircularConvolveSolver(),\n            )\n        with pytest.raises(TypeError):\n            slvr = ADMM(\n                f=loss.PoissonLoss(y=self.y),\n                g_list=self.g_list,\n                C_list=self.C_list,\n                rho_list=[1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=FBlockCircularConvolveSolver(),\n            )\n        with pytest.raises(TypeError):\n            slvr = ADMM(\n                f=loss.SquaredL2Loss(y=self.y, A=self.A.A),\n                g_list=self.g_list,\n                C_list=self.C_list,\n   
             rho_list=[1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=FBlockCircularConvolveSolver(),\n            )\n\n    def test_g0block_init(self):\n        with pytest.raises(ValueError):\n            slvr = ADMM(\n                f=self.f,\n                g_list=self.g_list,\n                C_list=self.C_list,\n                rho_list=[1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=G0BlockCircularConvolveSolver(),\n            )\n        with pytest.raises(TypeError):\n            slvr = ADMM(\n                f=functional.ZeroFunctional(),\n                g_list=[loss.PoissonLoss(y=self.y)],\n                C_list=self.C_list,\n                rho_list=[1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=G0BlockCircularConvolveSolver(),\n            )\n        with pytest.raises(TypeError):\n            slvr = ADMM(\n                f=functional.ZeroFunctional(),\n                g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list,\n                C_list=[self.A.A] + self.C_list,\n                rho_list=[1.0, 1.0],\n                itstat_options={\"display\": False},\n                subproblem_solver=G0BlockCircularConvolveSolver(),\n            )\n\n    def test_solve(self):\n        maxiter = 50\n        ρ = 1e1\n        rho_list = [ρ]\n        admm_lin = ADMM(\n            f=self.f,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            subproblem_solver=LinearSubproblemSolver(),\n        )\n        x_lin = admm_lin.solve()\n\n        admm_dft1 = ADMM(\n            f=self.f,\n            g_list=self.g_list,\n            C_list=self.C_list,\n            rho_list=rho_list,\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            
subproblem_solver=FBlockCircularConvolveSolver(check_solve=True),\n        )\n        x_dft1 = admm_dft1.solve()\n        np.testing.assert_allclose(x_dft1, x_lin, atol=1e-4, rtol=0)\n        assert metric.mse(x_lin, x_dft1) < 1e-9\n        assert admm_dft1.subproblem_solver.accuracy <= 1e-6\n\n        admm_dft2 = ADMM(\n            f=functional.ZeroFunctional(),\n            g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list,\n            C_list=[self.A] + self.C_list,\n            rho_list=[1.0, ρ],\n            maxiter=maxiter,\n            itstat_options={\"display\": False},\n            subproblem_solver=G0BlockCircularConvolveSolver(check_solve=True),\n        )\n        admm_dft2.z_list[0] = self.y  # significantly improves convergence\n        x_dft2 = admm_dft2.solve()\n        np.testing.assert_allclose(x_dft2, x_lin, atol=1e-4, rtol=0)\n        assert metric.mse(x_lin, x_dft2) < 1e-9\n        assert admm_dft2.subproblem_solver.accuracy <= 1e-6\n"
  },
  {
    "path": "scico/test/optimize/test_ladmm.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, random\nfrom scico.numpy import BlockArray\nfrom scico.optimize import LinearizedADMM\n\n\nclass TestMisc:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))\n        self.maxiter = 2\n        self.μ = 1e-1\n        self.ν = 1e-1\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = functional.DnCNN()\n        self.C = linop.Identity(self.y.shape)\n\n    def test_itstat(self):\n        itstat_fields = {\"Iter\": \"%d\", \"Time\": \"%8.2e\"}\n\n        def itstat_func(obj):\n            return (obj.itnum, obj.timer.elapsed())\n\n        ladmm_ = LinearizedADMM(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        assert len(ladmm_.itstat_object.fieldname) == 4\n        assert snp.sum(ladmm_.x) == 0.0\n\n        ladmm_ = LinearizedADMM(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n            itstat_options={\"fields\": itstat_fields, \"itstat_func\": itstat_func, \"display\": False},\n        )\n        assert len(ladmm_.itstat_object.fieldname) == 2\n\n    def test_callback(self):\n        ladmm_ = LinearizedADMM(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        ladmm_.test_flag = False\n\n        def callback(obj):\n            obj.test_flag = True\n\n        x = ladmm_.solve(callback=callback)\n        assert ladmm_.test_flag\n\n    def test_finite_check(self):\n        ladmm_ = LinearizedADMM(\n            
f=self.f, g=self.g, C=self.C, mu=self.μ, nu=self.ν, maxiter=self.maxiter, nanstop=True\n        )\n        ladmm_.step()\n        ladmm_.x = ladmm_.x.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            ladmm_.solve()\n\n\nclass TestBlockArray:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.blockarray(\n            (\n                np.random.randn(32, 33).astype(np.float32),\n                np.random.randn(\n                    17,\n                ).astype(np.float32),\n            )\n        )\n        self.λ = 1e0\n        self.maxiter = 1\n        self.μ = 1e-1\n        self.ν = 1e-1\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = (self.λ / 2) * functional.L2Norm()\n        self.C = linop.Identity(self.y.shape)\n\n    def test_blockarray(self):\n        ladmm_ = LinearizedADMM(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        x = ladmm_.solve()\n        assert isinstance(x, BlockArray)\n\n\nclass TestReal:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx = np.diag(np.random.randn(N).astype(np.float32))\n        Bmx = np.random.randn(MB, N).astype(np.float32)\n        y = np.random.randn(N).astype(np.float32)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = Amx.T @ y\n\n    def test_ladmm(self):\n        maxiter = 400\n        μ = 1e-2\n        ν = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = 
loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        ladmm_ = LinearizedADMM(\n            f=f,\n            g=g,\n            C=C,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        x = ladmm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_ladmm_saveload(self):\n        maxiter = 5\n        μ = 1e-2\n        ν = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        ladmm0 = LinearizedADMM(\n            f=f,\n            g=g,\n            C=C,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        ladmm0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"ladmm.npz\")\n            ladmm0.save_state(path)\n            ladmm0.solve()\n            h0 = ladmm0.history()\n            ladmm1 = LinearizedADMM(\n                f=f,\n                g=g,\n                C=C,\n                mu=μ,\n                nu=ν,\n                maxiter=maxiter,\n                x0=A.adj(self.y),\n            )\n            ladmm1.load_state(path)\n            ladmm1.solve()\n            h1 = ladmm1.history()\n            np.testing.assert_allclose(ladmm0.minimizer(), ladmm1.minimizer(), rtol=1e-6)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n\nclass TestComplex:\n    def setup_method(self, method):\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx, key = random.randn((N,), dtype=np.complex64, key=None)\n        Amx = snp.diag(Amx)\n        Bmx, key = random.randn((MB, N), 
dtype=np.complex64, key=key)\n        y, key = random.randn((N,), dtype=np.complex64, key=key)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x\n        self.grdb = Amx.conj().T @ y\n\n    def test_ladmm(self):\n        maxiter = 500\n        μ = 1e-2\n        ν = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        ladmm_ = LinearizedADMM(\n            f=f,\n            g=g,\n            C=C,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        x = ladmm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 5e-4\n"
  },
  {
    "path": "scico/test/optimize/test_padmm.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import function, functional, linop, loss, random\nfrom scico.numpy import BlockArray\nfrom scico.optimize import NonLinearPADMM, ProximalADMM\n\n\nclass TestMisc:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))\n        self.maxiter = 2\n        self.ρ = 1e0\n        self.μ = 1e0\n        self.ν = 1e0\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = functional.DnCNN()\n        self.H = function.Function(\n            (self.A.input_shape, self.A.input_shape),\n            output_shape=self.A.input_shape,\n            eval_fn=lambda x, z: x - z,\n            input_dtypes=np.float32,\n            output_dtype=np.float32,\n        )\n        self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32)\n\n    def test_itstat_padmm(self):\n        itstat_fields = {\"Iter\": \"%d\", \"Time\": \"%8.2e\"}\n\n        def itstat_func(obj):\n            return (obj.itnum, obj.timer.elapsed())\n\n        padmm_ = ProximalADMM(\n            f=self.f,\n            g=self.g,\n            A=self.A,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            x0=self.x0,\n            z0=self.x0,\n            u0=self.x0,\n            maxiter=self.maxiter,\n        )\n        assert len(padmm_.itstat_object.fieldname) == 4\n        assert snp.sum(padmm_.x) == 0.0\n\n        padmm_ = ProximalADMM(\n            f=self.f,\n            g=self.g,\n            A=self.A,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            B=None,\n            maxiter=self.maxiter,\n            itstat_options={\"fields\": itstat_fields, \"itstat_func\": itstat_func, \"display\": False},\n        )\n        assert len(padmm_.itstat_object.fieldname) == 2\n\n    def 
test_itstat_nlpadmm(self):\n        itstat_fields = {\"Iter\": \"%d\", \"Time\": \"%8.2e\"}\n\n        def itstat_func(obj):\n            return (obj.itnum, obj.timer.elapsed())\n\n        nlpadmm_ = NonLinearPADMM(\n            f=self.f,\n            g=self.g,\n            H=self.H,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            x0=self.x0,\n            z0=self.x0,\n            u0=self.x0,\n            maxiter=self.maxiter,\n        )\n        assert len(nlpadmm_.itstat_object.fieldname) == 4\n        assert snp.sum(nlpadmm_.x) == 0.0\n\n        nlpadmm_ = NonLinearPADMM(\n            f=self.f,\n            g=self.g,\n            H=self.H,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n            itstat_options={\"fields\": itstat_fields, \"itstat_func\": itstat_func, \"display\": False},\n        )\n        assert len(nlpadmm_.itstat_object.fieldname) == 2\n\n    def test_callback(self):\n        padmm_ = ProximalADMM(\n            f=self.f,\n            g=self.g,\n            A=self.A,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        padmm_.test_flag = False\n\n        def callback(obj):\n            obj.test_flag = True\n\n        x = padmm_.solve(callback=callback)\n        assert padmm_.test_flag\n\n    def test_finite_check(self):\n        padmm_ = ProximalADMM(\n            f=self.f,\n            g=self.g,\n            A=self.A,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n            nanstop=True,\n        )\n        padmm_.step()\n        padmm_.x = padmm_.x.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            padmm_.solve()\n\n\nclass TestBlockArray:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.blockarray(\n            (\n                
np.random.randn(32, 33).astype(np.float32),\n                np.random.randn(\n                    17,\n                ).astype(np.float32),\n            )\n        )\n        self.λ = 1e0\n        self.maxiter = 1\n        self.ρ = 1e0\n        self.μ = 1e0\n        self.ν = 1e0\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = (self.λ / 2) * functional.L2Norm()\n        self.H = function.Function(\n            (self.A.input_shape, self.A.input_shape),\n            output_shape=self.A.input_shape,\n            eval_fn=lambda x, z: x - z,\n            input_dtypes=np.float32,\n            output_dtype=np.float32,\n        )\n        self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32)\n\n    def test_blockarray_padmm(self):\n        padmm_ = ProximalADMM(\n            f=self.f,\n            g=self.g,\n            A=self.A,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        x = padmm_.solve()\n        assert isinstance(x, BlockArray)\n\n    def test_blockarray_nlpadmm(self):\n        nlpadmm_ = NonLinearPADMM(\n            f=self.f,\n            g=self.g,\n            H=self.H,\n            rho=self.ρ,\n            mu=self.μ,\n            nu=self.ν,\n            maxiter=self.maxiter,\n        )\n        x = nlpadmm_.solve()\n        assert isinstance(x, BlockArray)\n\n\nclass TestReal:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx = np.diag(np.random.randn(N).astype(np.float32))\n        Bmx = np.random.randn(MB, N).astype(np.float32)\n        y = np.random.randn(N).astype(np.float32)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + 
λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = Amx.T @ y\n\n    def test_padmm(self):\n        maxiter = 200\n        ρ = 1e0\n        μ = 5e1\n        ν = 1e0\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        padmm_ = ProximalADMM(\n            f=f,\n            g=g,\n            A=C,\n            rho=ρ,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n        )\n        x = padmm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_padmm_saveload(self):\n        maxiter = 5\n        ρ = 1e0\n        μ = 5e1\n        ν = 1e0\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        padmm0 = ProximalADMM(\n            f=f,\n            g=g,\n            A=C,\n            rho=ρ,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n        )\n        padmm0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"padmm.npz\")\n            padmm0.save_state(path)\n            padmm0.solve()\n            h0 = padmm0.history()\n            padmm1 = ProximalADMM(\n                f=f,\n                g=g,\n                A=C,\n                rho=ρ,\n                mu=μ,\n                nu=ν,\n                maxiter=maxiter,\n            )\n            padmm1.load_state(path)\n            padmm1.solve()\n            h1 = padmm1.history()\n            np.testing.assert_allclose(padmm0.minimizer(), padmm1.minimizer(), rtol=1e-6)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n    def test_nlpadmm(self):\n        maxiter = 200\n        ρ = 
1e0\n        μ = 5e1\n        ν = 1e0\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        H = function.Function(\n            (C.input_shape, C.output_shape),\n            output_shape=C.output_shape,\n            eval_fn=lambda x, z: C(x) - z,\n            input_dtypes=snp.float32,\n            output_dtype=snp.float32,\n        )\n        nlpadmm_ = NonLinearPADMM(\n            f=f,\n            g=g,\n            H=H,\n            rho=ρ,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n        )\n        x = nlpadmm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n\nclass TestComplex:\n    def setup_method(self, method):\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx, key = random.randn((N,), dtype=np.complex64, key=None)\n        Amx = snp.diag(Amx)\n        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)\n        y, key = random.randn((N,), dtype=np.complex64, key=key)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x\n        self.grdb = Amx.conj().T @ y\n\n    def test_nlpadmm(self):\n        maxiter = 300\n        ρ = 1e0\n        μ = 3e1\n        ν = 1e0\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        H = function.Function(\n            (C.input_shape, C.output_shape),\n            output_shape=C.output_shape,\n            eval_fn=lambda x, z: C(x) - z,\n  
          input_dtypes=snp.complex64,\n            output_dtype=snp.complex64,\n        )\n        nlpadmm_ = NonLinearPADMM(\n            f=f,\n            g=g,\n            H=H,\n            rho=ρ,\n            mu=μ,\n            nu=ν,\n            maxiter=maxiter,\n        )\n        x = nlpadmm_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n\nclass TestEstimateParameters:\n    def setup_method(self):\n        shape = (32, 33)\n        self.A = linop.Identity(shape)\n        self.Hr = function.Function(\n            (shape, shape),\n            output_shape=shape,\n            eval_fn=lambda x, z: x - z,\n            input_dtypes=np.float32,\n            output_dtype=np.float32,\n        )\n        self.Hc = function.Function(\n            (shape, shape),\n            output_shape=shape,\n            eval_fn=lambda x, z: x - z,\n            input_dtypes=np.complex64,\n            output_dtype=np.complex64,\n        )\n\n    def test_padmm_a(self):\n        mu, nu = ProximalADMM.estimate_parameters(self.A, factor=1.0)\n        assert snp.abs(mu - 1.0) < 1e-6\n        assert snp.abs(nu - 1.0) < 1e-6\n\n    def test_padmm_ab(self):\n        mu, nu = ProximalADMM.estimate_parameters(self.A, self.A, factor=1.0)\n        assert snp.abs(mu - 1.0) < 1e-6\n        assert snp.abs(nu - 1.0) < 1e-6\n\n    def test_real(self):\n        mu, nu = NonLinearPADMM.estimate_parameters(self.Hr, factor=1.0)\n        assert snp.abs(mu - 1.0) < 1e-6\n        assert snp.abs(nu - 1.0) < 1e-6\n\n    def test_complex(self):\n        mu, nu = NonLinearPADMM.estimate_parameters(self.Hc, factor=1.0)\n        assert snp.abs(mu - 1.0) < 1e-6\n        assert snp.abs(nu - 1.0) < 1e-6\n"
  },
  {
    "path": "scico/test/optimize/test_pdhg.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, operator, random\nfrom scico.numpy import BlockArray\nfrom scico.optimize import PDHG\n\n\nclass TestMisc:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))\n        self.maxiter = 2\n        self.τ = 1e-1\n        self.σ = 1e-1\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = functional.DnCNN()\n        self.C = linop.Identity(self.y.shape)\n\n    def test_itstat(self):\n        itstat_fields = {\"Iter\": \"%d\", \"Time\": \"%8.2e\"}\n\n        def itstat_func(obj):\n            return (obj.itnum, obj.timer.elapsed())\n\n        pdhg_ = PDHG(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            tau=self.τ,\n            sigma=self.σ,\n            maxiter=self.maxiter,\n        )\n        assert len(pdhg_.itstat_object.fieldname) == 4\n        assert snp.sum(pdhg_.x) == 0.0\n\n        pdhg_ = PDHG(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            tau=self.τ,\n            sigma=self.σ,\n            maxiter=self.maxiter,\n            itstat_options={\"fields\": itstat_fields, \"itstat_func\": itstat_func, \"display\": False},\n        )\n        assert len(pdhg_.itstat_object.fieldname) == 2\n\n    def test_callback(self):\n        pdhg_ = PDHG(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            tau=self.τ,\n            sigma=self.σ,\n            maxiter=self.maxiter,\n        )\n        pdhg_.test_flag = False\n\n        def callback(obj):\n            obj.test_flag = True\n\n        x = pdhg_.solve(callback=callback)\n        assert pdhg_.test_flag\n\n    def test_finite_check(self):\n        pdhg_ = PDHG(\n            f=self.f,\n            g=self.g,\n           
 C=self.C,\n            tau=self.τ,\n            sigma=self.σ,\n            maxiter=self.maxiter,\n            nanstop=True,\n        )\n        pdhg_.step()\n        pdhg_.x = pdhg_.x.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            pdhg_.solve()\n\n\nclass TestBlockArray:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        self.y = snp.blockarray(\n            (\n                np.random.randn(32, 33).astype(np.float32),\n                np.random.randn(\n                    17,\n                ).astype(np.float32),\n            )\n        )\n        self.λ = 1e0\n        self.maxiter = 1\n        self.τ = 1e-1\n        self.σ = 1e-1\n        self.A = linop.Identity(self.y.shape)\n        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)\n        self.g = (self.λ / 2) * functional.L2Norm()\n        self.C = linop.Identity(self.y.shape)\n\n    def test_blockarray(self):\n        pdhg_ = PDHG(\n            f=self.f,\n            g=self.g,\n            C=self.C,\n            tau=self.τ,\n            sigma=self.σ,\n            maxiter=self.maxiter,\n        )\n        x = pdhg_.solve()\n        assert isinstance(x, BlockArray)\n\n\nclass TestReal:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx = np.diag(np.random.randn(N).astype(np.float32))\n        Bmx = np.random.randn(MB, N).astype(np.float32)\n        y = np.random.randn(N).astype(np.float32)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = Amx.T @ y\n\n    def test_pdhg(self):\n        maxiter = 300\n        τ = 2e-1\n        σ = 2e-1\n        A = 
linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        pdhg_ = PDHG(\n            f=f,\n            g=g,\n            C=C,\n            tau=τ,\n            sigma=σ,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        x = pdhg_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n    def test_pdhg_saveload(self):\n        maxiter = 5\n        τ = 2e-1\n        σ = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        pdhg0 = PDHG(\n            f=f,\n            g=g,\n            C=C,\n            tau=τ,\n            sigma=σ,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        pdhg0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"pdhg.npz\")\n            pdhg0.save_state(path)\n            pdhg0.solve()\n            h0 = pdhg0.history()\n            pdhg1 = PDHG(\n                f=f,\n                g=g,\n                C=C,\n                tau=τ,\n                sigma=σ,\n                maxiter=maxiter,\n                x0=A.adj(self.y),\n            )\n            pdhg1.load_state(path)\n            pdhg1.solve()\n            h1 = pdhg1.history()\n            np.testing.assert_allclose(pdhg0.minimizer(), pdhg1.minimizer(), atol=1e-7)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n    def test_nlpdhg(self):\n        maxiter = 300\n        τ = 2e-1\n        σ = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        cfn = lambda x: self.Bmx @ x\n        Cop = 
operator.operator_from_function(cfn, \"Cop\")\n        C = Cop(input_shape=self.Bmx.shape[1:])\n        pdhg_ = PDHG(\n            f=f,\n            g=g,\n            C=C,\n            tau=τ,\n            sigma=σ,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        x = pdhg_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4\n\n\nclass TestComplex:\n    def setup_method(self, method):\n        N = 8\n        MB = 10\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx, key = random.randn((N,), dtype=np.complex64, key=None)\n        Amx = snp.diag(Amx)\n        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)\n        y, key = random.randn((N,), dtype=np.complex64, key=key)\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = snp.array(y)\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x\n        self.grdb = Amx.conj().T @ y\n\n    def test_pdhg(self):\n        maxiter = 300\n        τ = 2e-1\n        σ = 2e-1\n        A = linop.Diagonal(snp.diag(self.Amx))\n        f = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2) * functional.SquaredL2Norm()\n        C = linop.MatrixOperator(self.Bmx)\n        pdhg_ = PDHG(\n            f=f,\n            g=g,\n            C=C,\n            tau=τ,\n            sigma=σ,\n            maxiter=maxiter,\n            x0=A.adj(self.y),\n        )\n        x = pdhg_.solve()\n        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 5e-4\n\n\nclass TestEstimateParameters:\n    def setup_method(self):\n        shape = (32, 33)\n        A = linop.Identity(shape, input_dtype=np.float32)\n        B = linop.Identity(shape, input_dtype=np.complex64)\n        opcls = 
operator.operator_from_function(lambda x: snp.abs(x), \"op\")\n        C = opcls(input_shape=shape, input_dtype=np.float32)\n        D = opcls(input_shape=shape, input_dtype=np.complex64)\n        self.operators = [A, B, C, D]\n\n    def test_operators_dlft(self):\n        for op in self.operators[0:2]:\n            tau, sigma = PDHG.estimate_parameters(op, factor=1.0)\n            assert snp.abs(tau - sigma) < 1e-6\n            assert snp.abs(tau - 1.0) < 1e-6\n\n    def test_operators(self):\n        for op in self.operators:\n            x = snp.ones(op.input_shape, op.input_dtype)\n            tau, sigma = PDHG.estimate_parameters(op, x=x, factor=None)\n            assert snp.abs(tau - sigma) < 1e-6\n            assert snp.abs(tau - 1.0) < 1e-6\n\n    def test_ratio(self):\n        op = self.operators[0]\n        tau, sigma = PDHG.estimate_parameters(op, factor=1.0, ratio=10.0)\n        assert snp.abs(tau * sigma - 1.0) < 1e-6\n        assert snp.abs(sigma - 10.0 * tau) < 1e-6\n"
  },
  {
    "path": "scico/test/optimize/test_pgm.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import functional, linop, loss, random\nfrom scico.optimize import PGM, AcceleratedPGM\nfrom scico.optimize.pgm import (\n    AdaptiveBBStepSize,\n    BBStepSize,\n    LineSearchStepSize,\n    RobustLineSearchStepSize,\n)\n\n\nclass TestSet:\n    def setup_method(self, method):\n        np.random.seed(12345)\n        M = 5\n        N = 4\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2\n        Amx = np.random.randn(M, N).astype(np.float32)\n        Bmx = np.identity(N)\n        y = snp.array(np.random.randn(M).astype(np.float32))\n        λ = 1e0\n        self.Amx = Amx\n        self.y = y\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = Amx.T @ y\n\n    def test_pgm(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n        x = pgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_pgm_saveload(self):\n        maxiter = 5\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm0 = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n        pgm0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"pgm.npz\")\n            pgm0.save_state(path)\n            pgm0.solve()\n            h0 = pgm0.history()\n            pgm1 = 
PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n            pgm1.load_state(path)\n            pgm1.solve()\n            h1 = pgm1.history()\n            np.testing.assert_allclose(pgm0.minimizer(), pgm1.minimizer(), rtol=1e-6)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n    def test_pgm_isfinite(self):\n        maxiter = 5\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y), nanstop=True)\n        pgm_.step()\n        pgm_.x = pgm_.x.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            pgm_.solve()\n\n    def test_accelerated_pgm(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_accelerated_pgm_saveload(self):\n        maxiter = 5\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm0 = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n        apgm0.solve()\n        with tempfile.TemporaryDirectory() as tmpdir:\n            path = os.path.join(tmpdir, \"pgm.npz\")\n            apgm0.save_state(path)\n            apgm0.solve()\n            h0 = apgm0.history()\n            apgm1 = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))\n            apgm1.load_state(path)\n            
apgm1.solve()\n            h1 = apgm1.history()\n            np.testing.assert_allclose(apgm0.minimizer(), apgm1.minimizer(), rtol=1e-6)\n            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6\n\n    def test_accelerated_pgm_isfinite(self):\n        maxiter = 5\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0]\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y), nanstop=True)\n        apgm_.step()\n        apgm_.v = apgm_.v.at[0].set(np.nan)\n        with pytest.raises(ValueError):\n            apgm_.solve()\n\n    def test_pgm_BB_step_size(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=BBStepSize(),\n            maxiter=maxiter,\n        )\n        x = pgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_pgm_adaptive_BB_step_size(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=AdaptiveBBStepSize(),\n            maxiter=maxiter,\n        )\n        x = pgm_.solve()\n\n    def test_accelerated_pgm_line_search(self):\n        maxiter = 150\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = 
loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55),\n            maxiter=maxiter,\n        )\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_accelerated_pgm_robust_line_search(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80),\n            maxiter=maxiter,\n        )\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_pgm_BB_step_size_jit(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=BBStepSize(),\n            maxiter=maxiter,\n        )\n        x = pgm_.x\n        try:\n            update_step = jax.jit(pgm_.step_size.update)\n            L = update_step(x)\n        except Exception as e:\n            print(e)\n            assert 0\n\n    def test_accelerated_pgm_adaptive_BB_step_size_jit(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = 
(self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=AdaptiveBBStepSize(),\n            maxiter=maxiter,\n        )\n        x = apgm_.x\n        try:\n            update_step = jax.jit(apgm_.step_size.update)\n            L = update_step(x)\n        except Exception as e:\n            print(e)\n            assert 0\n\n\nclass TestComplex:\n    def setup_method(self, method):\n        M = 5\n        N = 4\n        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||x||_2^2\n        Amx, key = random.randn((M, N), dtype=np.complex64, key=None)\n        Bmx = np.identity(N)\n        y = snp.array(np.random.randn(M))\n        λ = 1e0\n        self.Amx = Amx\n        self.Bmx = Bmx\n        self.y = y\n        self.λ = λ\n        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y\n        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.T @ Bmx) @ x\n        self.grdb = Amx.conj().T @ y\n\n    def test_pgm(self):\n        maxiter = 150\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 50.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            maxiter=maxiter,\n        )\n        x = pgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_accelerated_pgm(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 50.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, x0=A.adj(self.y), maxiter=maxiter)\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def 
test_pgm_BB_step_size(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 10.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=BBStepSize(),\n            maxiter=maxiter,\n        )\n        x = pgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_pgm_adaptive_BB_step_size(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 10.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        pgm_ = PGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=AdaptiveBBStepSize(),\n            maxiter=maxiter,\n        )\n        x = pgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_accelerated_pgm_line_search(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 10.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            x0=A.adj(self.y),\n            step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55),\n            maxiter=maxiter,\n        )\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n\n    def test_accelerated_pgm_robust_line_search(self):\n        maxiter = 100\n        A = linop.MatrixOperator(self.Amx)\n        L0 = 10.0\n        loss_ = loss.SquaredL2Loss(y=self.y, A=A)\n        g = (self.λ / 2.0) * functional.SquaredL2Norm()\n        apgm_ = AcceleratedPGM(\n            f=loss_,\n            g=g,\n            L0=L0,\n            
x0=A.adj(self.y),\n            step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80),\n            maxiter=maxiter,\n        )\n        x = apgm_.solve()\n        np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)\n"
  },
  {
    "path": "scico/test/osver.py",
    "content": "import platform\n\nfrom packaging.version import parse\n\n\ndef osx_ver_geq_than(verstr):\n    \"\"\"Determine relative platform OSX version.\n\n    Determine whether platform has OSX version that is as recent as or\n    more recent than verstr. Returns ``False`` if the OS is not OSX.\n    \"\"\"\n    if platform.system() != \"Darwin\":\n        return False\n    osxver = platform.mac_ver()[0]\n    return parse(osxver) >= parse(verstr)\n"
  },
  {
    "path": "scico/test/test_core.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico\nimport scico.numpy as snp\nfrom scico.random import randn\n\n\nclass GradTestObj:\n    def __init__(self, dtype):\n        M, N = (3, 4)\n        key = jax.random.key(12345)\n        self.dtype = dtype\n\n        self.A, key = randn((M, N), dtype=dtype, key=key)\n        self.x, key = randn((N,), dtype=dtype, key=key)\n        self.y, key = randn((M,), dtype=dtype, key=key)\n\n        self.f = lambda x: 0.5 * snp.sum(snp.abs(self.y - self.A @ x) ** 2)\n\n\n@pytest.fixture(scope=\"module\", params=[np.float32, np.complex64])\ndef testobj(request):\n    yield GradTestObj(request.param)\n\n\ndef test_grad(testobj):\n    A = testobj.A\n    x = testobj.x\n    y = testobj.y\n    f = testobj.f\n\n    sgrad = scico.grad(f)(x)\n    an_grad = A.conj().T @ (A @ x - y)\n\n    np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4)\n\n\ndef test_grad_aux(testobj):\n    A = testobj.A\n    x = testobj.x\n    y = testobj.y\n\n    def g(x):\n        return testobj.f(x), True\n\n    sgrad, aux = scico.grad(g, has_aux=True)(x)\n    an_grad = A.conj().T @ (A @ x - y)\n\n    assert aux == True\n    np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4)\n\n\ndef test_value_and_grad(testobj):\n    A = testobj.A\n    x = testobj.x\n    y = testobj.y\n    f = testobj.f\n\n    svalue, sgrad = scico.value_and_grad(f)(x)\n\n    an_val = f(x)\n    an_grad = A.conj().T @ (A @ x - y)\n\n    np.testing.assert_allclose(svalue, an_val, rtol=1e-4)\n    np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4)\n\n\ndef test_value_and_grad_aux(testobj):\n    A = testobj.A\n    x = testobj.x\n    y = testobj.y\n\n    def g(x):\n        return testobj.f(x), True\n\n    (svalue, aux), sgrad = scico.value_and_grad(g, has_aux=True)(x)\n\n    an_val, aux_ = g(x)\n    an_grad = A.conj().T @ (A @ x - y)\n\n    assert aux == aux_\n    np.testing.assert_allclose(svalue, an_val, rtol=1e-4)\n    np.testing.assert_allclose(sgrad, an_grad, 
rtol=1e-4)\n\n\n@pytest.mark.parametrize(\"shape\", [(2, 3), ((2, 3), (4,))])\ndef test_linear_transpose(shape):\n    fun = lambda x: snp.pad(x, 2)\n    za = snp.zeros(shape, dtype=snp.float32)\n    fza = fun(za)\n    dts = jax.ShapeDtypeStruct(shape, dtype=snp.float32)\n    lt_za = scico.linear_transpose(fun, za)\n    lt_dts = scico.linear_transpose(fun, dts)\n    lt_za_fza = lt_za(fza)[0]\n    lt_dts_fza = lt_dts(fza)[0]\n    assert lt_za_fza.shape == lt_dts_fza.shape\n    assert lt_za_fza.dtype == lt_dts_fza.dtype\n\n\n@pytest.mark.parametrize(\"shape\", [(2, 3), ((2, 3), (4,))])\ndef test_linear_adjoint_shape(shape):\n    fun = lambda x: snp.pad(x, 2)\n    za = snp.zeros(shape, dtype=snp.float32)\n    fza = fun(za)\n    dts = jax.ShapeDtypeStruct(shape, dtype=snp.float32)\n    lt_za = scico.linear_adjoint(fun, za)\n    lt_dts = scico.linear_adjoint(fun, dts)\n    lt_za_fza = lt_za(fza)[0]\n    lt_dts_fza = lt_dts(fza)[0]\n    assert lt_za_fza.shape == lt_dts_fza.shape\n    assert lt_za_fza.dtype == lt_dts_fza.dtype\n\n\ndef test_linear_adjoint(testobj):\n    # Verify that linear_adjoint returns a function that\n    # implements f(y) = A.conj().T @ y\n    A = testobj.A\n    x = testobj.x\n    y = testobj.y\n\n    f = lambda x: A @ x\n\n    A_adj = scico.linear_adjoint(f, x)\n    np.testing.assert_allclose(A.conj().T @ y, A_adj(testobj.y)[0], rtol=1e-4)\n\n    # Test a function with multiple inputs\n    # Same as np.array([0.5, -0.5j])\n    f = lambda x, y: 0.5 * x - 0.5j * y\n\n    f_transpose = scico.linear_adjoint(f, 1.0j, 1.0j)\n    a, b = f_transpose(1.0 + 0.0j)\n    assert a == 0.5\n    assert b == 0.5j\n\n\ndef test_linear_adjoint_r_to_c():\n    f = snp.fft.rfft\n    x, key = randn((32,))\n    adj = scico.linear_adjoint(f, x)\n\n    a = snp.sum(x * adj(f(x))[0])\n    b = snp.linalg.norm(f(x)) ** 2\n\n    np.testing.assert_allclose(a, b, rtol=1e-4)\n\n\ndef test_linear_adjoint_c_to_r():\n    f = snp.fft.irfft\n    x, key = randn((32,), 
dtype=np.complex64)\n    adj = scico.linear_adjoint(f, x)\n\n    a = snp.sum(x.conj() * adj(f(x))[0])\n    b = snp.linalg.norm(f(x)) ** 2\n\n    np.testing.assert_allclose(a.real, b.real, rtol=1e-4)\n    np.testing.assert_allclose(a.imag, 0, atol=1e-2)\n\n\n@pytest.mark.parametrize(\"dtype\", [np.float32, np.complex64])\ndef test_cvjp(dtype):\n    A, key = randn((3, 3), dtype=dtype)\n    B, key = randn((3, 4), dtype=dtype, key=key)\n    xp, key = randn((3,), dtype=dtype, key=key)\n    yp, key = randn((4,), dtype=dtype, key=key)\n\n    def fun(x, y):\n        return A @ x + B @ y\n\n    px, jfnx = scico.cvjp(fun, xp, yp, jidx=0)\n    py, jfny = scico.cvjp(fun, xp, yp, jidx=1)\n\n    for k in range(3):\n        v = np.zeros((3,), dtype=dtype)\n        v[k] = 1.0\n        np.testing.assert_allclose(jfnx(v)[0], A[k].conj())\n        np.testing.assert_allclose(jfny(v)[0], B[k].conj())\n\n\n@pytest.mark.parametrize(\n    \"argskwargs\",\n    [\n        [(snp.ones((3,)), snp.ones((3,)), 1.0), {}],\n        [(1.1 * snp.ones((3,)), snp.ones((3,))), {\"z\": snp.zeros((3,))}],\n        [(snp.ones(((2,), (3, 2))), 1.0, 1.0), {}],\n        [\n            (snp.ones(((2,), (3, 2))), snp.blockarray(((2,), (3, 2)))),\n            {\"z\": 2.0 * snp.ones(((2,), (3, 2)))},\n        ],\n    ],\n)\ndef test_eval_shape_1(argskwargs):\n    def _fun(x, y, z):\n        \"\"\"Test function\"\"\"\n        return x + y * z\n\n    def _conv(arg):\n        \"\"\"Convert array to jax.ShapeDtypeStruct.\"\"\"\n        if hasattr(arg, \"shape\"):\n            return jax.ShapeDtypeStruct(arg.shape, dtype=arg.dtype)\n        else:\n            return arg\n\n    args, kwargs = argskwargs\n    # Reference shape computed for array objects\n    ref_shape = jax.eval_shape(_fun, *args, **kwargs)\n    map_args = [_conv(v) for v in args]\n    map_kwargs = {k: _conv(v) for k, v in kwargs.items()}\n    # Test shape computed for jax.ShapeDtypeStruct objects\n    tst_shape = scico.eval_shape(_fun, *map_args, 
**map_kwargs)\n    assert tst_shape.shape == ref_shape.shape\n\n\n@pytest.mark.parametrize(\n    \"arrdts\",\n    [\n        [snp.ones((3, 2), dtype=snp.float32), jax.ShapeDtypeStruct((3, 2), dtype=snp.float32)],\n        [\n            snp.ones(((3,), (2, 3)), dtype=snp.float32),\n            jax.ShapeDtypeStruct(((3,), (2, 3)), dtype=snp.float32),\n        ],\n    ],\n)\ndef test_eval_shape_2(arrdts):\n    _fun = lambda x: snp.pad(x, 2)\n    arr, dts = arrdts\n    # Reference shape computed for array\n    ref_shape = jax.eval_shape(_fun, arr)\n    # Test shape computed for jax.ShapeDtypeStruct\n    tst_shape = scico.eval_shape(_fun, dts)\n    assert tst_shape.shape == ref_shape.shape\n"
  },
  {
    "path": "scico/test/test_data.py",
    "content": "import os\n\nimport pytest\n\nfrom scico import data\n\nskipif_reason = (\n    \"\\nThe data submodule must be cloned and initialized. If the main repository\"\n    \" is already cloned, use the following in the root directory to get the data\"\n    \" submodule:\\n\\tgit submodule update --init --recursive\\nOtherwise, make sure\"\n    \" to clone using:\\n\\tgit clone --recurse-submodules git@github.com:lanl/scico.git\"\n    \"\\nAnd after cloning run:\\n\\tgit submodule init && git submodule update.\\n\"\n)\n\nexamples = os.path.join(os.path.dirname(data.__file__), \"examples\")\npytestmark = pytest.mark.skipif(not os.path.isdir(examples), reason=skipif_reason)\n\n\nclass TestSet:\n    def test_kodim23_uint(self):\n        x = data.kodim23()\n        assert x.dtype.name == \"uint8\"\n        assert x.shape == (512, 768, 3)\n\n    def test_kodim23_float(self):\n        x = data.kodim23(asfloat=True)\n        assert x.dtype.name == \"float32\"\n        assert x.shape == (512, 768, 3)\n"
  },
  {
    "path": "scico/test/test_denoiser.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nfrom scico.denoiser import DnCNN, bm3d, bm4d, have_bm3d, have_bm4d\nfrom scico.metric import rel_res\nfrom scico.random import randn\nfrom scico.test.osver import osx_ver_geq_than\n\nlevel = 3\n\n\n@pytest.fixture(autouse=True, scope=\"module\")\ndef module_setup_teardown(request):\n    global level\n    level = int(request.config.getoption(\"--level\"))\n\n\n# bm3d is known to be broken on OSX 11.6.5. It may be broken on earlier versions too,\n# but this has not been confirmed\n@pytest.mark.skipif(osx_ver_geq_than(\"11.6.5\"), reason=\"bm3d broken on this platform\")\n@pytest.mark.skipif(not have_bm3d, reason=\"bm3d package not installed\")\nclass TestBM3D:\n    def setup_method(self):\n        key = None\n        self.x_gry, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_rgb, key = randn((33, 34, 3), key=key, dtype=np.float32)\n\n    def test_shape(self):\n        assert bm3d(self.x_gry, 1.0).shape == self.x_gry.shape\n        assert bm3d(self.x_rgb, 1.0, is_rgb=True).shape == self.x_rgb.shape\n\n    def test_gry(self):\n        no_jit = bm3d(self.x_gry, 1.0)\n        assert no_jit.dtype == np.float32\n        if level > 2:\n            jitted = jax.jit(bm3d)(self.x_gry, 1.0)\n            assert np.linalg.norm(no_jit - jitted) < 1e-3\n            assert jitted.dtype == np.float32\n\n    def test_rgb(self):\n        no_jit = bm3d(self.x_rgb, 1.0)\n        assert no_jit.dtype == np.float32\n        if level > 2:\n            jitted = jax.jit(bm3d)(self.x_rgb, 1.0, is_rgb=True)\n            assert np.linalg.norm(no_jit - jitted) < 1e-3\n            assert jitted.dtype == np.float32\n\n    def test_bad_inputs(self):\n        x, key = randn((32,), key=None, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm3d(x, 1.0)\n        x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm3d(x, 1.0)\n        
x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm3d(x, 1.0)\n        x, key = randn((5, 9), key=key, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm3d(x, 1.0)\n        z, key = randn((32, 32), key=key, dtype=np.complex64)\n        with pytest.raises(TypeError):\n            bm3d(z, 1.0)\n\n\n# bm4d is known to be broken on OSX 11.6.5. It may be broken on earlier versions too,\n# but this has not been confirmed\n@pytest.mark.skipif(osx_ver_geq_than(\"11.6.5\"), reason=\"bm4d broken on this platform\")\n@pytest.mark.skipif(not have_bm4d, reason=\"bm4d package not installed\")\nclass TestBM4D:\n    def setup_method(self):\n        key = None\n        self.x1, key = randn((16, 17, 18), key=key, dtype=np.float32)\n        self.x2, key = randn((16, 17, 8), key=key, dtype=np.float32)\n        self.x3, key = randn((16, 17, 9, 1, 1), key=key, dtype=np.float32)\n\n    def test_shape(self):\n        if level > 2:\n            assert bm4d(self.x1, 1.0).shape == self.x1.shape\n        assert bm4d(self.x2, 1.0).shape == self.x2.shape\n        if level > 1:\n            assert bm4d(self.x3, 1.0).shape == self.x3.shape\n\n    def test_jit(self):\n        if level > 2:\n            no_jit = bm4d(self.x1, 1.0)\n            jitted = jax.jit(bm4d)(self.x1, 1.0)\n            assert np.linalg.norm(no_jit - jitted) < 2e-3\n            assert no_jit.dtype == np.float32\n            assert jitted.dtype == np.float32\n\n        no_jit = bm4d(self.x2, 1.0)\n        assert no_jit.dtype == np.float32\n        if level > 1:\n            jitted = jax.jit(bm4d)(self.x2, 1.0)\n            assert np.linalg.norm(no_jit - jitted) < 2e-3\n            assert jitted.dtype == np.float32\n\n    def test_bad_inputs(self):\n        x, key = randn((32,), key=None, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm4d(x, 1.0)\n        x, key = randn((12, 12, 4, 3), key=key, 
dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm4d(x, 1.0)\n        x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm4d(x, 1.0)\n        x, key = randn((5, 9), key=key, dtype=np.float32)\n        with pytest.raises(ValueError):\n            bm4d(x, 1.0)\n        z, key = randn((32, 32), key=key, dtype=np.complex64)\n        with pytest.raises(TypeError):\n            bm4d(z, 1.0)\n\n\nclass TestDnCNN:\n    def setup_method(self):\n        key = None\n        self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32)\n        self.dncnn = DnCNN()\n\n    def test_single_channel(self):\n        no_jit = self.dncnn(self.x_sngchn)\n        jitted = jax.jit(self.dncnn)(self.x_sngchn)\n        assert rel_res(no_jit, jitted) < 1e-6\n        assert no_jit.dtype == np.float32\n        assert jitted.dtype == np.float32\n\n    def test_multi_channel(self):\n        no_jit = self.dncnn(self.x_mltchn)\n        jitted = jax.jit(self.dncnn)(self.x_mltchn)\n        assert rel_res(no_jit, jitted) < 1e-6\n        assert no_jit.dtype == np.float32\n        assert jitted.dtype == np.float32\n\n    def test_init(self):\n        dncnn = DnCNN(variant=\"6L\")\n        x = dncnn(self.x_sngchn)\n        dncnn = DnCNN(variant=\"17H\")\n        x = dncnn(self.x_mltchn)\n        with pytest.raises(ValueError):\n            dncnn = DnCNN(variant=\"3A\")\n\n    def test_bad_inputs(self):\n        x, key = randn((32,), key=None, dtype=np.float32)\n        with pytest.raises(ValueError):\n            self.dncnn(x)\n        x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)\n        with pytest.raises(ValueError):\n            self.dncnn(x)\n        x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)\n        with pytest.raises(ValueError):\n            self.dncnn(x)\n        z, key = randn((32, 32), 
key=None, dtype=np.complex64)\n        with pytest.raises(TypeError):\n            self.dncnn(z)\n\n\nclass TestNonBLindDnCNN:\n    def setup_method(self):\n        key = None\n        self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32)\n        self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32)\n        self.sigma = 0.1\n        self.dncnn = DnCNN(variant=\"6N\")\n\n    def test_single_channel(self):\n        rslt = self.dncnn(self.x_sngchn, sigma=self.sigma)\n        assert rslt.dtype == np.float32\n\n    def test_multi_channel(self):\n        rslt = self.dncnn(self.x_mltchn, sigma=self.sigma)\n        assert rslt.dtype == np.float32\n\n    def test_bad_inputs(self):\n        with pytest.raises(ValueError):\n            rslt = self.dncnn(self.x_sngchn)\n"
  },
  {
    "path": "scico/test/test_diagnostics.py",
    "content": "from collections import OrderedDict\n\nimport pytest\n\nfrom scico import diagnostics\n\n\nclass TestSet:\n    def test_itstat(self):\n        its = diagnostics.IterationStats(OrderedDict({\"Iter\": \"%d\", \"Obj Val\": \"%8.2e\"}))\n        its.insert((0, 1.5))\n        its.insert((1, 1e2))\n        assert its.history()[0].Iter == 0\n        assert its.history()[1].Iter == 1\n        assert its.history()[1].Obj_Val == 1e2\n        assert its.history(transpose=True).Obj_Val == [1.5, 100.0]\n\n    def test_display(self, capsys):\n        its = diagnostics.IterationStats({\"Iter\": \"%d\"}, display=True, period=2, overwrite=False)\n        its.insert((0,))\n        cap = capsys.readouterr()\n        assert cap.out == \"Iter\\n----\\n   0\\n\"\n        its.insert((1,))\n        cap = capsys.readouterr()\n        assert cap.out == \"\"\n        its.insert((2,))\n        cap = capsys.readouterr()\n        assert cap.out == \"   2\\n\"\n\n    def test_exception(self):\n        with pytest.raises(TypeError):\n            its = diagnostics.IterationStats([\"Iter\", \"%z4d\"], display=False)\n        with pytest.raises(ValueError):\n            its = diagnostics.IterationStats({\"Iter\": \"%z4d\"}, display=False)\n\n    def test_warning(self):\n        with pytest.warns(UserWarning):\n            its = diagnostics.IterationStats({\"Iter\": \"%4e\"}, display=False)\n"
  },
  {
    "path": "scico/test/test_examples.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport imageio.v3 as iio\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.examples import (\n    create_3d_foam_phantom,\n    create_circular_phantom,\n    create_cone,\n    create_conv_sparse_phantom,\n    create_tangle_phantom,\n    downsample_volume,\n    epfl_deconv_data,\n    gaussian,\n    phase_diff,\n    rgb2gray,\n    spnoise,\n    tile_volume_slices,\n    ucb_diffusercam_data,\n    volume_read,\n)\n\n# These tests are for the scico.examples module, NOT the example scripts\n\n\ndef test_rgb2gray():\n    rgb = np.ones((31, 32, 3), dtype=np.float32)\n    gry = rgb2gray(rgb)\n    assert np.abs(gry.mean() - 1.0) < 1e-6\n\n\ndef test_volume_read():\n    temp_dir = tempfile.TemporaryDirectory()\n    v0 = np.zeros((32, 32), dtype=np.uint16)\n    v1 = np.ones((32, 32), dtype=np.uint16)\n    iio.imwrite(os.path.join(temp_dir.name, \"v0.tif\"), v0)\n    iio.imwrite(os.path.join(temp_dir.name, \"v1.tif\"), v1)\n    vol = volume_read(temp_dir.name, ext=\"tif\")\n    assert np.allclose(v0, vol[..., 0]) and np.allclose(v1, vol[..., 1])\n\n\ndef test_epfl_deconv_data():\n    temp_dir = tempfile.TemporaryDirectory()\n    y0 = np.zeros((32, 32), dtype=np.uint16)\n    psf0 = np.ones((32, 32), dtype=np.uint16)\n    np.savez(os.path.join(temp_dir.name, \"epfl_big_deconv_0.npz\"), y=y0, psf=psf0)\n    y, psf = epfl_deconv_data(0, cache_path=temp_dir.name)\n    assert np.allclose(y0, y) and np.allclose(psf0, psf)\n\n\ndef test_ucb_diffusercam_data():\n    temp_dir = tempfile.TemporaryDirectory()\n    y0 = np.zeros((32, 32), dtype=np.uint16)\n    psf0 = np.ones((8, 32, 32), dtype=np.uint16)\n    np.savez(os.path.join(temp_dir.name, \"ucb_diffcam_data.npz\"), y=y0, psf=psf0)\n    y, psf = ucb_diffusercam_data(cache_path=temp_dir.name)\n    assert np.allclose(y0, y) and np.allclose(psf0, psf)\n\n\ndef test_downsample_volume():\n    v0 = np.zeros((32, 32, 16))\n    v1 = downsample_volume(v0, rate=1)\n    
assert v0.shape == v1.shape\n    v0 = np.zeros((32, 32, 16))\n    v1 = downsample_volume(v0, rate=2)\n    assert tuple([n // 2 for n in v0.shape]) == v1.shape\n    v0 = np.zeros((32, 32, 16))\n    v1 = downsample_volume(v0, rate=3)\n    assert tuple([round(n / 3) for n in v0.shape]) == v1.shape\n\n\ndef test_tile_volume_slices():\n    v = np.ones((16, 16, 16))\n    tvs = tile_volume_slices(v)\n    assert tvs.ndim == 2\n    v = np.ones((16, 16, 16, 3))\n    tvs = tile_volume_slices(v)\n    assert tvs.ndim == 3 and tvs.shape[-1] == 3\n\n\ndef test_gaussian():\n    g0 = gaussian((5, 5))\n    assert g0.shape == (5, 5)\n    g1 = gaussian((5, 5), sigma=np.array([[3, 0], [0, 2]]))\n    assert np.sum(g1 / g1.max()) > np.sum(g0 / g0.max())\n    with pytest.raises(ValueError):\n        g2 = gaussian((5, 5), sigma=np.array([[2, 2], [2, 2]]))\n\n\ndef test_create_circular_phantom():\n    img_shape = (32, 32)\n    radius_list = [2, 4, 8]\n    val_list = [2, 4, 8]\n    x_gt = create_circular_phantom(img_shape, radius_list, val_list)\n\n    assert x_gt.shape == img_shape\n    assert np.max(x_gt) == max(val_list)\n    assert np.min(x_gt) == 0\n\n\n@pytest.mark.parametrize(\n    \"img_shape\",\n    (\n        (3, 3),\n        (50, 51),\n        (3, 3, 3),\n    ),\n)\ndef test_create_cone(img_shape):\n    x_gt = create_cone(img_shape)\n    assert x_gt.shape == img_shape\n    # check symmetry\n    assert np.abs(x_gt[(0,) * len(img_shape)] - x_gt[(-1,) * len(img_shape)]) < 1e-6\n\n\n@pytest.mark.parametrize(\n    \"img_shape\",\n    (\n        (3, 3, 3),\n        (20, 21, 22),\n        (15, 15, 5),\n    ),\n)\n@pytest.mark.parametrize(\"N_sphere\", (3, 10, 20))\ndef test_create_3d_foam_phantom(img_shape, N_sphere):\n    x_gt = create_3d_foam_phantom(img_shape, N_sphere)\n    assert x_gt.shape == img_shape\n\n\ndef test_conv_sparse_phantom():\n    h, x = create_conv_sparse_phantom(64, 32)\n    assert h.shape == (3, 15, 15)\n    assert x.shape == (3, 64, 64)\n    assert np.sum(x > 0) == 
32\n\n\ndef test_tangle_phantom():\n    v = create_tangle_phantom(3, 4, 5)\n    assert v.shape == (5, 4, 3)\n\n\ndef test_spnoise():\n    x = 0.5 * np.ones((10, 11))\n    y = spnoise(x, 0.5, nmin=0.01, nmax=0.99)\n    assert np.all(y >= 0.01)\n    assert np.all(y <= 0.99)\n    x = 0.5 * snp.ones((10, 11))\n    y = spnoise(x, 0.5, nmin=0.01, nmax=0.99)\n    assert np.all(y >= 0.01)\n    assert np.all(y <= 0.99)\n\n\ndef test_phase_diff():\n    x = np.pi * np.random.randn(16)\n    y = np.pi * np.random.randn(16)\n    d = phase_diff(x, y)\n    assert np.all(d >= 0)\n    assert np.all(d <= np.pi)\n"
  },
  {
    "path": "scico/test/test_function.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.function import Function\nfrom scico.linop import jacobian\nfrom scico.random import randn\n\n\nclass TestFunction:\n    def setup_method(self):\n        key = None\n        self.shape = (7, 8)\n        self.dtype = snp.float32\n        self.x, key = randn(self.shape, key=key, dtype=self.dtype)\n        self.y, key = randn(self.shape, key=key, dtype=self.dtype)\n        self.func = lambda x, y: snp.abs(x) + snp.abs(y)\n\n    def test_init(self):\n        F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func)\n        assert F.output_shape == self.shape\n        assert len(F.input_dtypes) == 2\n        assert F.output_dtype == self.dtype\n\n    def test_eval(self):\n        F = Function(\n            (self.shape, self.shape),\n            output_shape=self.shape,\n            eval_fn=self.func,\n            input_dtypes=(self.dtype, self.dtype),\n            output_dtype=self.dtype,\n        )\n        np.testing.assert_allclose(self.func(self.x, self.y), F(self.x, self.y))\n\n    def test_eval_jit(self):\n        F = Function(\n            (self.shape, self.shape),\n            output_shape=self.shape,\n            eval_fn=self.func,\n            input_dtypes=(self.dtype, self.dtype),\n            output_dtype=self.dtype,\n            jit=True,\n        )\n        np.testing.assert_allclose(self.func(self.x, self.y), F(self.x, self.y))\n\n    def test_slice(self):\n        F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func)\n        Op = F.slice(0, self.y)\n        np.testing.assert_allclose(Op(self.x), F(self.x, self.y))\n\n    def test_join(self):\n        F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func)\n        Op = F.join()\n        np.testing.assert_allclose(Op(snp.blockarray((self.x, self.y))), F(self.x, self.y))\n\n    def test_join_raise(self):\n        F = Function(\n      
      (self.shape, self.shape), input_dtypes=(snp.float32, snp.complex64), eval_fn=self.func\n        )\n        with pytest.raises(ValueError):\n            Op = F.join()\n\n\n@pytest.mark.parametrize(\"dtype\", [snp.float32, snp.complex64])\ndef test_jacobian(dtype):\n    N = 7\n    M = 8\n    key = None\n    fmx, key = randn((M, N), key=key, dtype=dtype)\n    gmx, key = randn((M, N), key=key, dtype=dtype)\n    F = Function(((N, 1), (N, 1)), input_dtypes=dtype, eval_fn=lambda x, y: fmx @ x + gmx @ y)\n    u0, key = randn((N, 1), key=key, dtype=dtype)\n    u1, key = randn((N, 1), key=key, dtype=dtype)\n    v, key = randn((N, 1), key=key, dtype=dtype)\n    w, key = randn((M, 1), key=key, dtype=dtype)\n\n    op = F.slice(0, u1)\n    J0op = jacobian(op, u0)\n    np.testing.assert_allclose(J0op(v), F.jvp(0, v, u0, u1)[1])\n    np.testing.assert_allclose(J0op.H(w), F.vjp(0, u0, u1)[1](w))\n    J0fn = F.jacobian(0, u0, u1)\n    np.testing.assert_allclose(J0op(v), J0fn(v))\n    np.testing.assert_allclose(J0op.H(w), J0fn.H(w))\n\n    op = F.slice(1, u0)\n    J1op = jacobian(op, u1)\n    np.testing.assert_allclose(J1op(v), F.jvp(1, v, u0, u1)[1])\n    np.testing.assert_allclose(J1op.H(w), F.vjp(1, u0, u1)[1](w))\n    J1fn = F.jacobian(1, u0, u1)\n    np.testing.assert_allclose(J1op(v), J1fn(v))\n    np.testing.assert_allclose(J1op.H(w), J1fn.H(w))\n"
  },
  {
    "path": "scico/test/test_metric.py",
    "content": "import numpy as np\n\nimport scico.numpy as snp\nfrom scico import metric\n\n\nclass TestSet:\n    def setup_method(self, method):\n        np.random.seed(12345)\n\n    def test_mae_mse(self):\n        N = 16\n        x = np.random.randn(N)\n        y = x.copy()\n        y[0] = 0\n        xe = np.abs(x[0])\n        e1 = metric.mae(x, y)\n        e2 = metric.mse(x, y)\n        assert np.abs(e1 - xe / N) < 1e-12\n        assert np.abs(e2 - (xe**2) / N) < 1e-12\n\n    def test_snr_nrm(self):\n        N = 16\n        x = np.random.randn(N)\n        x /= np.sqrt(np.var(x))\n        y = x + 1\n        assert np.abs(metric.snr(x, y)) < 1e-6\n\n    def test_snr_signal_range(self):\n        N = 16\n        x = np.random.randn(N)\n        x -= x.min()\n        x /= x.max()\n        y = x + 1\n        assert np.abs(metric.psnr(x, y)) < 1e-6\n\n    def test_psnr(self):\n        N = 16\n        x = np.random.randn(N)\n        y = x + 1\n        assert np.abs(metric.psnr(x, y, signal_range=1.0)) < 1e-6\n\n    def test_isnr(self):\n        N = 16\n        x = np.random.randn(N)\n        y = np.random.randn(N)\n        assert np.abs(metric.isnr(x, y, y)) < 1e-6\n\n    def test_bsnr(self):\n        N = 16\n        x = np.random.randn(N)\n        x /= np.sqrt(np.var(x))\n        n = np.random.randn(N)\n        n /= np.sqrt(np.var(n))\n        y = x + n\n        assert np.abs(metric.bsnr(x, y)) < 1e-6\n\n\ndef test_rel_res():\n    A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32)\n    x = snp.array([[3], [-2]], dtype=snp.float32)\n    Ax = snp.matmul(A, x)\n    b = snp.array([[8], [3], [-5]], dtype=snp.float32)\n    assert 0.0 == metric.rel_res(Ax, b)\n\n    A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32)\n    x = snp.array([[0], [0]], dtype=snp.float32)\n    Ax = snp.matmul(A, x)\n    b = snp.array([[0], [0], [0]], dtype=snp.float32)\n    assert 0.0 == metric.rel_res(Ax, b)\n"
  },
  {
    "path": "scico/test/test_random.py",
    "content": "import numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.random\n\n\n@pytest.mark.parametrize(\"seed\", [None, 42])\ndef test_wrapped_funcs(seed):\n    fun = jax.random.normal\n    fun_wrapped = scico.random.normal\n\n    # test seed argument\n    if seed is None:\n        key = jax.random.key(0)\n    else:\n        key = jax.random.key(seed)\n\n    np.testing.assert_array_equal(fun(key), fun_wrapped(seed=seed)[0])\n\n    # test blockarray\n    shape = ((7,), (3, 2), (2, 4, 1))\n    seed = 42\n    key = jax.random.key(seed)\n\n    result, _ = fun_wrapped(shape, seed=seed)\n\n\ndef test_add_seed_adapter():\n    fun = jax.random.normal\n\n    fun_alt = scico.random._add_seed(fun)\n\n    # specify a seed instead of a key\n    assert fun(jax.random.key(42)) == fun_alt(seed=42)[0]\n\n    # seed defaults to zero\n    assert fun(jax.random.key(0)) == fun_alt()[0]\n\n    # other parameters still work ...\n    key = jax.random.key(0)\n    sz = (10, 3)\n    dtype = np.float64\n\n    # ... positional\n    np.testing.assert_array_equal(fun(key, sz), fun_alt(sz)[0])\n    np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype)[0])\n    np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype, key)[0])\n    np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype, None, 0)[0])\n\n    # ... keyword\n    np.testing.assert_array_equal(fun(shape=sz, key=key), fun_alt(shape=sz)[0])\n    np.testing.assert_array_equal(\n        fun(shape=sz, key=key, dtype=dtype), fun_alt(dtype=dtype, shape=sz)[0]\n    )\n\n    # ... mixed\n    np.testing.assert_array_equal(\n        fun(key, dtype=dtype, shape=sz), fun_alt(dtype=dtype, shape=sz)[0]\n    )\n\n    # get back the split key\n    _, key_a = fun_alt(seed=42)\n    key_b, _ = jax.random.split(jax.random.key(42), 2)\n    assert key_a == key_b\n\n    # error when key and seed are specified\n    with pytest.raises(ValueError):\n        _ = fun_alt(key=jax.random.key(0), seed=42)[0]\n"
  },
  {
    "path": "scico/test/test_ray_tune.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\n\nimport pytest\n\ntry:\n    import ray\n    from scico.ray import report, tune\nexcept ImportError as e:\n    pytest.skip(\"ray.tune not installed\", allow_module_level=True)\n\n\ndef test_random_run():\n    def eval_params(config):\n        x, y = config[\"x\"], config[\"y\"]\n        cost = x**2 + (y - 0.5) ** 2\n        report({\"cost\": cost})\n\n    config = {\"x\": tune.uniform(-1, 1), \"y\": tune.uniform(-1, 1)}\n    resources = {\"gpu\": 0, \"cpu\": 1}\n    tune.ray.tune.register_trainable(\"eval_func\", eval_params)\n    analysis = tune.run(\n        \"eval_func\",\n        metric=\"cost\",\n        mode=\"min\",\n        num_samples=100,\n        config=config,\n        resources_per_trial=resources,\n        hyperopt=False,\n        verbose=False,\n        storage_path=os.path.join(tempfile.gettempdir(), \"ray_test\"),\n    )\n    best_config = analysis.get_best_config(metric=\"cost\", mode=\"min\")\n    assert np.abs(best_config[\"x\"]) < 0.25\n    assert np.abs(best_config[\"y\"] - 0.5) < 0.25\n\n\ndef test_random_tune():\n    def eval_params(config):\n        x, y = config[\"x\"], config[\"y\"]\n        cost = x**2 + (y - 0.5) ** 2\n        report({\"cost\": cost})\n\n    config = {\"x\": tune.uniform(-1, 1), \"y\": tune.uniform(-1, 1)}\n    resources = {\"gpu\": 0, \"cpu\": 1}\n    tuner = tune.Tuner(\n        eval_params,\n        param_space=config,\n        resources=resources,\n        metric=\"cost\",\n        mode=\"min\",\n        num_samples=100,\n        hyperopt=False,\n        verbose=False,\n        storage_path=os.path.join(tempfile.gettempdir(), \"ray_test\"),\n    )\n    results = tuner.fit()\n    best_config = results.get_best_result().config\n    assert np.abs(best_config[\"x\"]) < 0.25\n    assert np.abs(best_config[\"y\"] - 0.5) < 0.25\n\n\ndef test_hyperopt_run():\n    def eval_params(config):\n        x, y = config[\"x\"], config[\"y\"]\n        cost = x**2 + (y - 
0.5) ** 2\n        report({\"cost\": cost})\n\n    config = {\"x\": tune.uniform(-1, 1), \"y\": tune.uniform(-1, 1)}\n    resources = {\"gpu\": 0, \"cpu\": 1}\n    analysis = tune.run(\n        eval_params,\n        metric=\"cost\",\n        mode=\"min\",\n        num_samples=50,\n        config=config,\n        resources_per_trial=resources,\n        hyperopt=True,\n        verbose=True,\n    )\n    best_config = analysis.get_best_config(metric=\"cost\", mode=\"min\")\n    assert np.abs(best_config[\"x\"]) < 0.25\n    assert np.abs(best_config[\"y\"] - 0.5) < 0.25\n\n\ndef test_hyperopt_tune():\n    def eval_params(config):\n        x, y = config[\"x\"], config[\"y\"]\n        cost = x**2 + (y - 0.5) ** 2\n        report({\"cost\": cost})\n\n    config = {\"x\": tune.uniform(-1, 1), \"y\": tune.uniform(-1, 1)}\n    resources = {\"gpu\": 0, \"cpu\": 1}\n    tuner = tune.Tuner(\n        eval_params,\n        param_space=config,\n        resources=resources,\n        metric=\"cost\",\n        mode=\"min\",\n        num_samples=50,\n        hyperopt=True,\n        verbose=True,\n    )\n    results = tuner.fit()\n    best_config = results.get_best_result().config\n    assert np.abs(best_config[\"x\"]) < 0.25\n    assert np.abs(best_config[\"y\"] - 0.5) < 0.25\n\n\ndef test_hyperopt_tune_alt_init():\n    def eval_params(config):\n        x, y = config[\"x\"], config[\"y\"]\n        cost = x**2 + (y - 0.5) ** 2\n        report({\"cost\": cost})\n\n    config = {\"x\": tune.uniform(-1, 1), \"y\": tune.uniform(-1, 1)}\n    tuner = tune.Tuner(\n        eval_params,\n        param_space=config,\n        max_concurrent_trials=4,\n        metric=\"cost\",\n        mode=\"min\",\n        num_samples=50,\n        time_budget=2,\n        hyperopt=True,\n        verbose=True,\n        tune_config=ray.tune.TuneConfig(),\n        run_config=ray.tune.RunConfig(),\n    )\n    results = tuner.fit()\n    best_config = results.get_best_result().config\n    assert 
np.abs(best_config[\"x\"]) < 0.25\n    assert np.abs(best_config[\"y\"] - 0.5) < 0.25\n"
  },
  {
    "path": "scico/test/test_scipy_special.py",
    "content": "import numpy as np\n\nimport pytest\n\nimport scico.scipy.special as ss\nfrom scico.random import randn\n\n# these are functions that take only a single ndarray as input\none_arg_funcs = [\n    ss.digamma,\n    ss.entr,\n    ss.erf,\n    ss.erfc,\n    ss.erfinv,\n    ss.expit,\n    ss.gammaln,\n    ss.i0,\n    ss.i0e,\n    ss.i1,\n    ss.i1e,\n    ss.ndtr,\n    ss.log_ndtr,\n    ss.logit,\n    ss.ndtri,\n]\n\n\n@pytest.mark.parametrize(\"func\", one_arg_funcs)\ndef test_one_arg_funcs(func):\n\n    # blockarray array\n    x, key = randn(((8, 8), (4,)), key=None)\n\n    Fx = func(x)\n\n    fx0 = func(x[0])\n    fx1 = func(x[1])\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n\n\ndef test_betainc():\n    a, key = randn(((8, 8), (4,)), key=None)\n    b, key = randn(((8, 8), (4,)), key=key)\n    x, key = randn(((8, 8), (4,)), key=key)\n\n    Fx = ss.betainc(a, b, x)\n    fx0 = ss.betainc(a[0], b[0], x[0])\n    fx1 = ss.betainc(a[1], b[1], x[1])\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n\n\n@pytest.mark.parametrize(\"func\", [ss.gammainc, ss.gammaincc])\ndef test_gammainc(func):\n    a, key = randn(((8, 8), (4,)), key=None)\n    b, key = randn(((8, 8), (4,)), key=key)\n    x, key = randn(((8, 8), (4,)), key=key)\n\n    Fx = ss.betainc(a, b, x)\n    fx0 = ss.betainc(a[0], b[0], x[0])\n    fx1 = ss.betainc(a[1], b[1], x[1])\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n\n\ndef test_multigammaln():\n    x, key = randn(((8, 8), (4,)), key=None)\n    d = 2\n\n    Fx = ss.multigammaln(x, d)\n    fx0 = ss.multigammaln(x[0], d)\n    fx1 = ss.multigammaln(x[1], d)\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    
np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n\n\n@pytest.mark.parametrize(\"func\", [ss.xlog1py, ss.xlogy])\ndef test_logs(func):\n    x, key = randn(((8, 8), (4,)), key=None)\n    y, key = randn(((8, 8), (4,)), key=key)\n\n    Fx = func(x, y)\n    fx0 = func(x[0], y[0])\n    fx1 = func(x[1], y[1])\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n\n\ndef test_zeta():\n    x, key = randn(((8, 8), (4,)), key=None)\n    y, key = randn(((8, 8), (4,)), key=None)\n\n    Fx = ss.zeta(x, y)\n    fx0 = ss.zeta(x[0], y[0])\n    fx1 = ss.zeta(x[1], y[1])\n    np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4)\n    np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4)\n"
  },
  {
    "path": "scico/test/test_solver.py",
    "content": "import numpy as np\n\nfrom jax.scipy.linalg import block_diag\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico import linop, metric, random, solver\n\n\nclass TestSet:\n    def setup_method(self, method):\n        np.random.seed(12345)\n\n    def test_wrap_func_and_grad(self):\n        N = 8\n        A = snp.array(np.random.randn(N, N))\n        x = snp.array(np.random.randn(N))\n\n        f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2\n\n        func_and_grad = solver._wrap_func_and_grad(f, shape=(N,), dtype=x.dtype)\n        fx, grad = func_and_grad(x)\n\n        np.testing.assert_allclose(fx, f(x), rtol=5e-5)\n        np.testing.assert_allclose(grad, A.T @ A @ x, rtol=5e-5)\n\n    def test_cg_std(self):\n        N = 64\n        Ac = np.random.randn(N, N)\n        Am = Ac.dot(Ac.T)\n        A = Am.dot\n        x = np.random.randn(N)\n        b = Am.dot(x)\n        x0 = np.zeros((N,))\n        tol = 1e-12\n        try:\n            xcg, info = solver.cg(A, b, x0, tol=tol)\n        except Exception as e:\n            print(e)\n            assert 0\n        assert info[\"rel_res\"].ndim == 0\n        assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6\n\n    def test_cg_op(self):\n        N = 32\n        Ac = np.random.randn(N, N).astype(np.float32)\n        Am = Ac.dot(Ac.T)\n        A = Am.dot\n        x = np.random.randn(N).astype(np.float32)\n        b = Am.dot(x)\n        tol = 1e-12\n        try:\n            xcg, info = solver.cg(linop.MatrixOperator(Am), b, tol=tol)\n        except Exception as e:\n            print(e)\n            assert 0\n        assert info[\"rel_res\"].ndim == 0\n        assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6\n\n    def test_cg_no_info(self):\n        N = 64\n        Ac = np.random.randn(N, N)\n        Am = Ac.dot(Ac.T)\n        A = Am.dot\n        x = np.random.randn(N)\n        b = Am.dot(x)\n        x0 = np.zeros((N,))\n        tol = 1e-12\n        try:\n            xcg = 
solver.cg(A, b, x0, tol=tol, info=False)\n        except Exception as e:\n            print(e)\n            assert 0\n        assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6\n\n    def test_cg_complex(self):\n        N = 64\n        Ac = np.random.randn(N, N) + 1j * np.random.randn(N, N)\n        Am = Ac.dot(Ac.conj().T)\n        A = Am.dot\n        x = np.random.randn(N) + 1j * np.random.randn(N)\n        b = Am.dot(x)\n        x0 = np.zeros_like(x)\n        tol = 1e-12\n        try:\n            xcg, info = solver.cg(A, b, x0, tol=tol)\n        except Exception as e:\n            print(e)\n            assert 0\n        assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6\n\n    def test_preconditioned_cg(self):\n        N = 64\n        D = np.diag(np.linspace(0.1, 20, N))\n        Ac = D @ np.random.randn(\n            N, N\n        )  # Poorly scaled matrix; good fit for diagonal preconditioning\n        Am = Ac.dot(Ac.conj().T)\n\n        A = Am.dot\n\n        Mm = np.diag(1 / np.diag(Am))  # inverse of diagonal of Am\n        M = Mm.dot\n\n        x = np.random.randn(N) + 1j * np.random.randn(N)\n        b = Am.dot(x)\n        x0 = np.zeros_like(x)\n        tol = 1e-12\n        x_cg, cg_info = solver.cg(A, b, x0, tol=tol, info=True, M=None, maxiter=3)\n        x_pcg, pcg_info = solver.cg(A, b, x0, tol=tol, info=True, M=M, maxiter=3)\n\n        # Assert that PCG converges faster in a few iterations\n        assert cg_info[\"rel_res\"] > 3 * pcg_info[\"rel_res\"]\n\n    def test_lstsq_func(self):\n        N = 24\n        M = 32\n        Ac = snp.array(np.random.randn(N, M).astype(np.float32))\n        Am = Ac.dot(Ac.T)\n        A = Am.dot\n        x = snp.array(np.random.randn(N).astype(np.float32))\n        b = Am.dot(x)\n        x0 = snp.zeros((N,), dtype=np.float32)\n        tol = 1e-6\n        try:\n            xlsq = solver.lstsq(A, b, x0=x0, tol=tol)\n        except Exception as e:\n            print(e)\n            assert 0\n        
assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 5e-6\n\n    def test_lstsq_op(self):\n        N = 32\n        M = 24\n        Ac = snp.array(np.random.randn(N, M).astype(np.float32))\n        A = linop.MatrixOperator(Ac)\n        x = snp.array(np.random.randn(M).astype(np.float32))\n        b = Ac.dot(x)\n        tol = 1e-7\n        try:\n            xlsq = solver.lstsq(A, b, tol=tol)\n        except Exception as e:\n            print(e)\n            assert 0\n        assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 1e-6\n\n\nclass TestOptimizeScalar:\n    # Adopted from SciPy minimize_scalar tests\n    # https://github.com/scipy/scipy/blob/701ffcc8a6f04509d115aac5e5681c538b5265a2/scipy/optimize/tests/test_optimize.py#L1364\n    def setup_method(self):\n        self.solution = 1.5\n        self.rtol = 1e-3\n\n    def fun(self, x, a=1.5):\n        \"\"\"Objective function\"\"\"\n        # Jax version of (x - a)**2 - 0.8; will return a devicearray\n        return snp.square(x - a) - 0.8\n\n    def test_minimize_scalar(self):\n        # combine all tests above for the minimize_scalar wrapper\n        x = solver.minimize_scalar(self.fun).x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, method=\"Brent\")\n        np.testing.assert_(x.success)\n\n        x = solver.minimize_scalar(self.fun, method=\"Brent\", options=dict(maxiter=3))\n        np.testing.assert_(not x.success)\n\n        x = solver.minimize_scalar(self.fun, bracket=(-3, -2), args=(1.5,), method=\"Brent\").x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, method=\"Brent\", args=(1.5,)).x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, bracket=(-15, -1, 15), args=(1.5,), method=\"Brent\").x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = 
solver.minimize_scalar(self.fun, bracket=(-3, -2), args=(1.5,), method=\"golden\").x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, method=\"golden\", args=(1.5,)).x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, bracket=(-15, -1, 15), args=(1.5,), method=\"golden\").x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(self.fun, bounds=(0, 1), args=(1.5,), method=\"Bounded\").x\n        np.testing.assert_allclose(x, 1, rtol=1e-4)\n\n        x = solver.minimize_scalar(self.fun, bounds=(1, 5), args=(1.5,), method=\"bounded\").x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n        x = solver.minimize_scalar(\n            self.fun,\n            bounds=(np.array([1]), np.array([5])),\n            args=(np.array([1.5]),),\n            method=\"bounded\",\n        ).x\n        np.testing.assert_allclose(x, self.solution, rtol=self.rtol)\n\n\n@pytest.mark.parametrize(\"dtype\", [snp.float32, snp.complex64])\n@pytest.mark.parametrize(\"method\", [\"CG\", \"L-BFGS-B\"])\ndef test_minimize_vector(dtype, method):\n    B, M, N = (4, 3, 2)\n\n    # model a 12x8 block-diagonal matrix with 3x2 blocks\n    A, key = random.randn((B, M, N), dtype=dtype)\n    x, key = random.randn((B, N), dtype=dtype, key=key)\n    y = snp.sum(A * x[:, None], axis=2)  # contract along the N axis\n\n    # result by directly inverting the dense matrix\n    A_mat = block_diag(*A)\n    expected = snp.linalg.pinv(A_mat) @ y.ravel()\n\n    def f(x):\n        return 0.5 * snp.linalg.norm(y - snp.sum(A * x[:, None], axis=2)) ** 2\n\n    out = solver.minimize(f, x0=snp.zeros_like(x), method=method)\n\n    assert out.x.shape == x.shape\n    np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)\n\n\n@pytest.mark.parametrize(\"dtype\", 
[snp.float32])\n@pytest.mark.parametrize(\"method\", [\"CG\"])\ndef test_minimize_blockarray(dtype, method):\n    # model a 6x8 block-diagonal matrix with 3x4 blocks\n    A, key = random.randn(((3, 4), (3, 4)), dtype=dtype)\n    x, key = random.randn(((4,), (4,)), dtype=dtype, key=key)\n    y = A @ x\n\n    # result by directly inverting the dense matrix\n    A_mat = block_diag(*A)\n    expected = snp.linalg.pinv(A_mat) @ y.stack(axis=0).ravel()\n\n    def f(x):\n        return 0.5 * snp.linalg.norm(y - A @ x) ** 2\n\n    out = solver.minimize(f, x0=snp.zeros_like(x), method=method)\n\n    assert out.x.shape == x.shape\n    np.testing.assert_allclose(solver._ravel(out.x), expected, rtol=5e-4)\n\n\ndef test_split_join_array():\n    x, key = random.randn((4, 4), dtype=np.complex64)\n    x_s = solver._split_real_imag(x)\n    assert x_s.shape == (2, 4, 4)\n    np.testing.assert_allclose(x_s[0], snp.real(x))\n    np.testing.assert_allclose(x_s[1], snp.imag(x))\n\n    x_j = solver._join_real_imag(x_s)\n    np.testing.assert_allclose(x_j, x, rtol=1e-4)\n\n\ndef test_split_join_blockarray():\n    x, key = random.randn(((4, 4), (3,)), dtype=np.complex64)\n    x_s = solver._split_real_imag(x)\n    assert x_s.shape == ((2, 4, 4), (2, 3))\n\n    real_block = snp.blockarray((x_s[0][0], x_s[1][0]))\n    imag_block = snp.blockarray((x_s[0][1], x_s[1][1]))\n    snp.testing.assert_allclose(real_block, snp.real(x), rtol=1e-4)\n    snp.testing.assert_allclose(imag_block, snp.imag(x), rtol=1e-4)\n\n    x_j = solver._join_real_imag(x_s)\n    snp.testing.assert_allclose(x_j, x, rtol=1e-4)\n\n\ndef test_bisect():\n    f = lambda x: x**3\n    x, info = solver.bisect(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True)\n    assert snp.sum(snp.abs(x)) == 0.0\n    assert info[\"iter\"] == 0\n    x = solver.bisect(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5, ftol=1e-5)\n    assert snp.max(snp.abs(x)) <= 1e-5\n    assert snp.max(snp.abs(f(x))) <= 1e-5\n    c, key = 
random.randn((5, 1), dtype=np.float32)\n    f = lambda x, c: x**3 - c**3\n    x = solver.bisect(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5, ftol=1e-5)\n    assert snp.max(snp.abs(x - c)) <= 1e-5\n    assert snp.max(snp.abs(f(x, c))) <= 1e-5\n\n\ndef test_golden():\n    f = lambda x: x**2\n    x, info = solver.golden(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True)\n    assert snp.max(snp.abs(x)) <= 1e-7\n    x = solver.golden(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5)\n    assert snp.max(snp.abs(x)) <= 1e-5\n    c, key = random.randn((5, 1), dtype=np.float32)\n    f = lambda x, c: (x - c) ** 2\n    x = solver.golden(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5)\n    assert snp.max(snp.abs(x - c)) <= 1e-5\n\n\n@pytest.mark.parametrize(\"cho_factor\", [True, False])\n@pytest.mark.parametrize(\"wide\", [True, False])\n@pytest.mark.parametrize(\"weighted\", [True, False])\n@pytest.mark.parametrize(\"alpha\", [1e-1, 1e1])\ndef test_solve_atai(cho_factor, wide, weighted, alpha):\n    A, key = random.randn((5, 8), dtype=snp.float32)\n    if wide:\n        x0, key = random.randn((8,), key=key)\n    else:\n        A = A.T\n        x0, key = random.randn((5,), key=key)\n\n    if weighted:\n        W, key = random.randn((A.shape[0],), key=key)\n        W = snp.abs(W)\n        Wa = W[:, snp.newaxis]\n    else:\n        W = None\n        Wa = snp.array([1.0])[:, snp.newaxis]\n\n    D = alpha * snp.ones((A.shape[1],))\n    ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1])\n    b = ATAD @ x0\n    slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor)\n    x1 = slv.solve(b)\n    assert metric.rel_res(x0, x1) < 5e-5\n\n\n@pytest.mark.parametrize(\"cho_factor\", [True, False])\n@pytest.mark.parametrize(\"wide\", [True, False])\n@pytest.mark.parametrize(\"alpha\", [1e-1, 1e1])\ndef test_solve_aati(cho_factor, wide, alpha):\n    A, key = random.randn((5, 8), dtype=snp.float32)\n    if wide:\n        x0, key = 
random.randn((5,), key=key)\n    else:\n        A = A.T\n        x0, key = random.randn((8,), key=key)\n\n    D = alpha * snp.ones((A.shape[0],))\n    AATD = A @ A.T + alpha * snp.identity(A.shape[0])\n    b = AATD @ x0\n    slv = solver.MatrixATADSolver(A.T, D)\n    x1 = slv.solve(b)\n    assert metric.rel_res(x0, x1) < 5e-5\n\n\n@pytest.mark.parametrize(\"cho_factor\", [True, False])\n@pytest.mark.parametrize(\"wide\", [True, False])\n@pytest.mark.parametrize(\"vector\", [True, False])\ndef test_solve_atad(cho_factor, wide, vector):\n    A, key = random.randn((5, 8), dtype=snp.float32)\n    if wide:\n        D, key = random.randn((8,), key=key)\n        if vector:\n            x0, key = random.randn((8,), key=key)\n        else:\n            x0, key = random.randn((8, 3), key=key)\n    else:\n        A = A.T\n        D, key = random.randn((5,), key=key)\n        if vector:\n            x0, key = random.randn((5,), key=key)\n        else:\n            x0, key = random.randn((5, 3), key=key)\n\n    D = snp.abs(D)  # only required for Cholesky, but improved accuracy for LU\n    ATAD = A.T @ A + snp.diag(D)\n    b = ATAD @ x0\n    slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor)\n    x1 = slv.solve(b)\n    assert metric.rel_res(x0, x1) < 5e-5\n    assert slv.accuracy(x1, b) < 5e-5\n"
  },
  {
    "path": "scico/test/test_util.py",
    "content": "import socket\nimport urllib.error as urlerror\n\nimport numpy as np\n\nimport jax\n\nimport pytest\n\nimport scico.numpy as snp\nfrom scico.util import (\n    ContextTimer,\n    Timer,\n    check_for_tracer,\n    partial,\n    rgetattr,\n    rsetattr,\n    url_get,\n)\n\n\ndef test_rattr():\n    class A:\n        class B:\n            c = 0\n\n        b = B()\n\n    a = A()\n    rsetattr(a, \"b.c\", 1)\n    assert rgetattr(a, \"b.c\") == 1\n\n    assert rgetattr(a, \"c.d\", 10) == 10\n\n    with pytest.raises(AttributeError):\n        assert rgetattr(a, \"c.d\")\n\n\ndef test_partial_pos():\n    def func(a, b, c, d):\n        return a + 2 * b + 4 * c + 8 * d\n\n    pfunc = partial(func, (0, 2), 0, 0)\n    assert pfunc(1, 0) == 2 and pfunc(0, 1) == 8\n\n\ndef test_partial_kw():\n    def func(a=1, b=1, c=1, d=1):\n        return a + 2 * b + 4 * c + 8 * d\n\n    pfunc = partial(func, (), a=0, c=0)\n    assert pfunc(b=1, d=0) == 2 and pfunc(b=0, d=1) == 8\n\n\ndef test_partial_pos_and_kw():\n    def func(a, b, c=1, d=1):\n        return a + 2 * b + 4 * c + 8 * d\n\n    pfunc = partial(func, (0,), 0, c=0)\n    assert pfunc(1, d=0) == 2 and pfunc(0, d=1) == 8\n\n\n# See https://stackoverflow.com/a/33117579\ndef _internet_connected(host=\"8.8.8.8\", port=53, timeout=3):\n    \"\"\"Check if internet connection available.\n\n    Host: 8.8.8.8 (google-public-dns-a.google.com)\n    OpenPort: 53/tcp\n    Service: domain (DNS/TCP)\n    \"\"\"\n    try:\n        socket.setdefaulttimeout(timeout)\n        socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port))\n        return True\n    except socket.error as ex:\n        return False\n\n\n@pytest.mark.skipif(not _internet_connected(), reason=\"No internet connection\")\ndef test_url_get():\n    url = \"https://github.com/lanl/scico/blob/main/README.md\"\n    headers = {\n        \"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64)\",\n        \"Referer\": 
\"https://github.com/lanl/scico/blob/main\",\n    }\n    try:\n        uget = url_get(url, headers=headers)\n    except urlerror.HTTPError as e:\n        if e.code != 429:\n            raise\n    else:\n        assert not uget.getvalue().find(b\"SCICO\") == -1\n\n    np.testing.assert_raises(ValueError, url_get, url, maxtry=-1)\n\n    url = \"about:blank\"\n    np.testing.assert_raises(urlerror.URLError, url_get, url)\n\n\ndef test_check_for_tracer():\n    # Using examples from Jax documentation\n\n    A = snp.ones((5, 5))\n    x = snp.ones((10, 5))\n\n    @check_for_tracer\n    def norm(X):\n        X = X - X.mean(0)\n        return X / X.std(0)\n\n    with pytest.raises(TypeError):\n        check_norm = jax.jit(norm)\n        check_norm(x)\n\n    vv = check_for_tracer(lambda x: A @ x)\n    with pytest.raises(TypeError):\n        mv = jax.vmap(vv)\n        mv(x)\n\n\ndef test_timer_basic():\n    t = Timer()\n    t.start()\n    t0 = t.elapsed()\n    t.stop()\n    t1 = t.elapsed()\n    assert t0 >= 0.0\n    assert t1 >= t0\n    assert len(t.__str__()) > 0\n    assert len(t.labels()) > 0\n\n\ndef test_timer_multi():\n    t = Timer(\"a\")\n    t.start([\"a\", \"b\"])\n    t0 = t.elapsed(\"a\")\n    t.stop(\"a\")\n    t.stop(\"b\")\n    t.stop([\"a\", \"b\"])\n    assert t.elapsed(\"a\") >= 0.0\n    assert t.elapsed(\"b\") >= 0.0\n    assert t.elapsed(\"a\", total=False) == 0.0\n\n\ndef test_timer_reset():\n    t = Timer(\"a\")\n    t.start([\"a\", \"b\"])\n    t.reset(\"a\")\n    assert t.elapsed(\"a\") == 0.0\n    t.reset(\"all\")\n    assert t.elapsed(\"b\") == 0.0\n\n\ndef test_ctxttimer_basic():\n    t = Timer()\n    with ContextTimer(t):\n        t0 = t.elapsed()\n    assert t.elapsed() >= 0.0\n\n\ndef test_ctxttimer_stopstart():\n    t = Timer()\n    t.start()\n    with ContextTimer(t, action=\"StopStart\"):\n        t0 = t.elapsed()\n    t.stop()\n    assert t.elapsed() >= 0.0\n"
  },
  {
    "path": "scico/test/test_version.py",
    "content": "from scico._version import variable_assign_value\n\ntest_var_num = 12345\ntest_var_str = \"12345\"\n\n\ndef test_var_val():\n    assert variable_assign_value(__file__, \"test_var_num\") == test_var_num\n    assert variable_assign_value(__file__, \"test_var_str\") == test_var_str\n"
  },
  {
    "path": "scico/trace.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2024-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Call tracing of scico functions and methods.\n\nJIT must be disabled for tracing to function correctly (set environment\nvariable :code:`JAX_DISABLE_JIT=1`, or call\n:code:`jax.config.update('jax_disable_jit', True)` before importing `jax`\nor `scico`). Call :code:`trace_scico_calls` to initialize tracing, and\ncall :code:`register_variable` to associate a name with a variable so\nthat it can be referenced by name in the call trace.\n\nThe call trace is color-code as follows if\n`colorama <https://github.com/tartley/colorama>`_ is installed:\n\n- `module and class names`: light red\n- `function and method names`: dark red\n- `arguments and return values`: light blue\n- `names of registered variables`: light yellow\n\nWhen a method defined in a class is called for an object of a derived\nclass type, the class of that object is displayed in light magenta, in\nsquare brackets. 
Function names and return values are distinguished by\ninitial ``>>`` and ``<<`` characters respectively.\n\nA usage example is provided in the script :code:`trace_example.py`.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport inspect\nimport sys\nimport types\nimport warnings\nfrom collections import defaultdict\nfrom functools import wraps\nfrom typing import Any, Callable, Optional, Sequence\n\nimport numpy as np\n\nimport jax\n\ntry:\n    from jaxlib.xla_extension import PjitFunction\nexcept ImportError:\n    from jaxlib._jax import PjitFunction  # jax >= 0.6.1\n\n\ntry:\n    import colorama\n\n    have_colorama = True\nexcept ImportError:\n    have_colorama = False\n\n\nif have_colorama:\n    clr_main = colorama.Fore.LIGHTRED_EX  # main trace information\n    clr_rvar = colorama.Fore.LIGHTYELLOW_EX  # registered variable names\n    clr_self = colorama.Fore.LIGHTMAGENTA_EX  # type of object for which method is called\n    clr_func = colorama.Fore.RED  # function/method name\n    clr_args = colorama.Fore.LIGHTBLUE_EX  # function/method arguments\n    clr_retv = colorama.Fore.LIGHTBLUE_EX  # function/method return values\n    clr_devc = colorama.Fore.CYAN  # JAX array device and sharding\n    clr_reset = colorama.Fore.RESET  # reset color\nelse:\n    clr_main, clr_rvar, clr_self, clr_func = \"\", \"\", \"\", \"\"\n    clr_args, clr_retv, clr_devc, clr_reset = \"\", \"\", \"\", \"\"\n\n\ndef _get_hash(val: Any) -> Optional[int]:\n    \"\"\"Get a hash representing an object.\n\n    Args:\n        val: An object for which the hash is required.\n\n    Returns:\n        A hash value of ``None`` if a hash cannot be computed.\n    \"\"\"\n    if isinstance(val, np.ndarray):\n        hash = val.ctypes.data  # for an ndarray, hash is the memory address\n    elif hasattr(val, \"__hash__\") and callable(val.__hash__):\n        try:\n            hash = val.__hash__()\n        except TypeError:\n            hash = None\n    else:\n        hash = None\n    return 
hash\n\n\ndef _trace_arg_repr(val: Any) -> str:\n    \"\"\"Compute string representation of function arguments.\n\n    Args:\n        val: Argument value\n\n    Returns:\n        A string representation of the argument.\n    \"\"\"\n    if val is None:\n        return \"None\"\n    elif np.isscalar(val):  # a scalar value\n        return str(val)\n    elif isinstance(val, tuple) and len(val) < 6 and all([np.isscalar(s) for s in val]):\n        return f\"{val}\"  # a short sequence of scalars\n    elif isinstance(val, np.dtype):  # a numpy dtype\n        return f\"numpy.{val}\"\n    elif isinstance(val, type):  # a class name\n        return f\"{val.__module__}.{val.__qualname__}\"\n    elif isinstance(val, np.ndarray) and _get_hash(val) in call_trace.instance_hash:  # type: ignore\n        return f\"{clr_rvar}{call_trace.instance_hash[_get_hash(val)]}{clr_args}\"  # type: ignore\n    elif isinstance(val, (np.ndarray, jax.Array)):  # a jax or numpy array\n        if val.shape == ():\n            return str(val)\n        else:\n            dev_str, shard_str = \"\", \"\"\n            if isinstance(val, jax.Array) and not isinstance(\n                val, jax._src.interpreters.partial_eval.JaxprTracer\n            ):\n                if call_trace.show_jax_device:  # type: ignore\n                    platform = list(val.devices())[0].platform  # assume all of same type\n                    devices = \",\".join(map(str, sorted([d.id for d in val.devices()])))\n                    dev_str = f\"{clr_devc}{{dev={platform}({devices})}}{clr_args}\"\n                if call_trace.show_jax_sharding and isinstance(  # type: ignore\n                    val.sharding, jax._src.sharding_impls.PositionalSharding\n                ):\n                    shard_str = f\"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}\"\n            return f\"Array{val.shape}{dev_str}{shard_str}\"\n    else:\n        if _get_hash(val) in call_trace.instance_hash:  # type: ignore\n            return 
f\"{clr_rvar}{call_trace.instance_hash[val.__hash__()]}{clr_args}\"  # type: ignore\n        else:\n            return f\"[{type(val).__name__}]\"\n\n\ndef register_variable(var: Any, name: str):\n    \"\"\"Register a variable name for call tracing.\n\n    Any hashable object (or numpy array, with the memory address\n    used as a hash) may be registered. JAX arrays may not be registered\n    since they are not hashable and there is no clear mechanism for\n    associating them with a unique memory address.\n\n    Args:\n        var: The variable to be registered.\n        name: The name to be associated with the variable.\n    \"\"\"\n    hash = _get_hash(var)\n    if hash is None:\n        raise ValueError(f\"Can't get hash for variable '{name}'.\")\n    call_trace.instance_hash[hash] = name  # type: ignore\n\n\ndef _call_wrapped_function(func: Callable, *args, **kwargs) -> Any:\n    \"\"\"Call a wrapped function within the wrapper.\n\n    Handle different call mechanisms required for static and class\n    methods.\n\n    Args:\n        func: Wrapped function\n        *args: Positional arguments\n        **kwargs: Named arguments\n\n    Returns:\n        Return value of wrapped function.\n    \"\"\"\n    if isinstance(func, staticmethod):\n        # If the type of the first argument is the same as the class to\n        # which the static method belongs, assume that it was called as\n        # <object>.<staticmethod>(<args>), which requires that the first\n        # argument be stripped before calling the method. 
This is\n        # somewhat heuristic, and may fail, but there is no obvious\n        # mechanism for reliably determining how the method was called in\n        # the calling scope.\n        if inspect._findclass(func) == type(args[0]):  # type: ignore\n            call_args = args[1:]\n        else:\n            call_args = args\n        ret = func(*call_args, **kwargs)\n    elif isinstance(func, classmethod):\n        ret = func.__func__(*args, **kwargs)\n    else:\n        ret = func(*args, **kwargs)\n    return ret\n\n\ndef call_trace(func: Callable) -> Callable:\n    \"\"\"Print log of calls to `func`.\n\n    Decorator for printing a log of calls to the wrapped function. A\n    record of call levels is maintained so that call nesting is indicated\n    by call log indentation.\n    \"\"\"\n    try:\n        method_class = inspect._findclass(func)  # type: ignore\n    except AttributeError:\n        method_class = None\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        name = f\"{func.__module__}.{clr_func}{func.__qualname__}\"\n        arg_idx = 0\n        if (\n            args\n            and hasattr(args[0], \"__hash__\")\n            and callable(args[0].__hash__)\n            and method_class\n            and isinstance(args[0], method_class)\n        ):  # first argument is self for a method call\n            arg_idx = 1  # skip self in handling arguments\n            if args[0].__hash__() in call_trace.instance_hash:\n                # self object registered using register_variable\n                name = (\n                    f\"{clr_rvar}{call_trace.instance_hash[args[0].__hash__()]}.\"\n                    f\"{clr_func}{func.__name__}\"\n                )\n            else:\n                # self object not registered\n                func_class = method_class.__name__\n                self_class = args[0].__class__.__name__\n                # If the class in which this method is defined is same as that\n                # of the self 
object for which it's called, just display the\n                # class name. Otherwise, display the name of the defining\n                # class followed by the name of the self object class in\n                # square brackets.\n                if func_class == self_class:\n                    class_name = func_class\n                else:\n                    class_name = f\"{func_class}{clr_self}[{self_class}]{clr_main}\"\n                name = f\"{func.__module__}.{class_name}.{clr_func}{func.__name__}\"\n        args_repr = [_trace_arg_repr(val) for val in args[arg_idx:]]\n        kwargs_repr = [f\"{key}={_trace_arg_repr(val)}\" for key, val in kwargs.items()]\n        args_str = clr_args + \", \".join(args_repr + kwargs_repr) + clr_main\n        print(\n            f\"{clr_main}>> {' ' * 2 * call_trace.trace_level}{name}\"\n            f\"({args_str}{clr_func}){clr_reset}\",\n            file=sys.stderr,\n        )\n        # call wrapped function\n        call_trace.trace_level += 1\n        ret = _call_wrapped_function(func, *args, **kwargs)\n        call_trace.trace_level -= 1\n        # print representation of return value\n        if ret is not None and call_trace.show_return_value:\n            print(\n                f\"{clr_main}<< {' ' * 2 * call_trace.trace_level}{clr_retv}\"\n                f\"{_trace_arg_repr(ret)}{clr_reset}\",\n                file=sys.stderr,\n            )\n        return ret\n\n    # Set flag indicating that function is already wrapped\n    wrapper._call_trace_wrap = True  # type: ignore\n    # Avoid multiple wrapper layers\n    if hasattr(func, \"_call_trace_wrap\"):\n        return func\n    else:\n        return wrapper\n\n\n# call level counter for call_trace decorator\ncall_trace.trace_level = 0  # type: ignore\n# hash dict allowing association of objects with variable names\ncall_trace.instance_hash = {}  # type: ignore\n# flag indicating whether to show function return value\ncall_trace.show_return_value = 
True  # type: ignore\n# flag indicating whether to show JAX array devices\ncall_trace.show_jax_device = False  # type: ignore\n# flag indicating whether to show JAX array sharding shape\ncall_trace.show_jax_sharding = False  # type: ignore\n\n\ndef _submodule_name(module, obj):\n    if (\n        len(obj.__name__) > len(module.__name__)\n        and obj.__name__[0 : len(module.__name__)] == module.__name__\n    ):\n        short_name = obj.__name__[len(module.__name__) + 1 :]\n    else:\n        short_name = \"\"\n    return short_name\n\n\ndef _is_scico_object(obj: Any) -> bool:\n    \"\"\"Determine whether an object is defined in a scico submodule.\n\n    Args:\n        obj: Object to check.\n\n    Returns:\n        A boolean value indicating whether `obj` is defined in a scico\n        submodule.\n    \"\"\"\n    return hasattr(obj, \"__module__\") and obj.__module__[0:5] == \"scico\"\n\n\ndef _is_scico_module(mod: types.ModuleType) -> bool:\n    \"\"\"Determine whether a module is a scico submodule.\n\n    Args:\n        mod: Module to check.\n\n    Returns:\n        A boolean value indicating whether `mod` is a scico submodule.\n    \"\"\"\n    return hasattr(mod, \"__name__\") and mod.__name__[0:5] == \"scico\"\n\n\ndef _in_module(mod: types.ModuleType, obj: Any) -> bool:\n    \"\"\"Determine whether an object is defined in a module.\n\n    Args:\n        mod: Module of interest.\n        obj: Object to check.\n\n    Returns:\n        A boolean value indicating whether `obj` is defined in `mod`.\n    \"\"\"\n    return obj.__module__ == mod.__name__\n\n\ndef _is_submodule(mod: types.ModuleType, submod: types.ModuleType) -> bool:\n    \"\"\"Determine whether a module is a submodule of another module.\n\n    Args:\n        mod: Parent module of interest.\n        submod: Possible submodule to check.\n\n    Returns:\n        A boolean value indicating whether `submod` is defined in `mod`.\n    \"\"\"\n    return submod.__name__[0 : len(mod.__name__)] == 
mod.__name__\n\n\ndef apply_decorator(\n    module: types.ModuleType,\n    decorator: Callable,\n    recursive: bool = True,\n    skip: Optional[Sequence] = None,\n    seen: Optional[defaultdict[str, int]] = None,\n    verbose: bool = False,\n    level: int = 0,\n) -> defaultdict[str, int]:\n    \"\"\"Apply a decorator function to all functions in a scico module.\n\n    Apply a decorator function to all functions in a scico module,\n    including methods of classes in that module.\n\n    Args:\n        module: The module containing the functions/methods to be\n          decorated.\n        decorator: The decorator function to apply to each module\n          function/method.\n        recursive: Flag indicating whether to recurse into submodules\n          of the specified module. (Hidden modules with a name starting\n          with an underscore are ignored.)\n        skip: A list of class/function/method names to be skipped.\n        seen: A :class:`defaultdict` providing a count of the number of\n          times each function/method was seen.\n        verbose: Flag indicating whether to print a log of functions\n          as they are encountered.\n        level: Counter for recursive call levels.\n\n    Returns:\n        A :class:`defaultdict` providing a count of the number of times\n        each function/method was seen.\n    \"\"\"\n    indent = \" \" * 4 * level\n    if skip is None:\n        skip = []\n    if seen is None:\n        seen = defaultdict(int)\n    if verbose:\n        print(f\"{indent}Module: {module.__name__}\")\n    indent += \" \" * 4\n\n    # Iterate over functions in module\n    for name, func in inspect.getmembers(\n        module,\n        lambda obj: isinstance(obj, (types.FunctionType, PjitFunction)) and _in_module(module, obj),\n    ):\n        if name in skip:\n            continue\n        qualname = func.__module__ + \".\" + func.__qualname__\n        if not seen[qualname]:  # avoid multiple applications of decorator\n            
setattr(module, name, decorator(func))\n            seen[qualname] += 1\n            if verbose:\n                print(f\"{indent}Function: {qualname}\")\n\n    # Iterate over classes in module\n    for name, cls in inspect.getmembers(\n        module, lambda obj: inspect.isclass(obj) and _in_module(module, obj)\n    ):\n        qualname = cls.__module__ + \".\" + cls.__qualname__  # type: ignore\n        if verbose:\n            print(f\"{indent}Class: {qualname}\")\n\n        # Iterate over methods in class\n        for name, func in inspect.getmembers(\n            cls, lambda obj: isinstance(obj, (types.FunctionType, PjitFunction))\n        ):\n            if name in skip:\n                continue\n            qualname = func.__module__ + \".\" + func.__qualname__  # type: ignore\n            if not seen[qualname]:  # avoid multiple applications of decorator\n                # Can't use cls returned by inspect.getmembers because it uses plain\n                # getattr internally, which interferes with identification of static\n                # methods. 
From Python 3.11 onwards one could use\n                # inspect.getmembers_static instead of inspect.getmembers, but that\n                # would imply incompatibility with earlier Python versions.\n                func = inspect.getattr_static(cls, name)\n                setattr(cls, name, decorator(func))\n                seen[qualname] += 1\n                if verbose:\n                    print(f\"{indent + '    '}Method: {qualname}\")\n\n    # Iterate over submodules of module\n    if recursive:\n        for name, mod in inspect.getmembers(\n            module, lambda obj: inspect.ismodule(obj) and _is_submodule(module, obj)\n        ):\n            if name[0:1] == \"_\":\n                continue\n            seen = apply_decorator(\n                mod,\n                decorator,\n                recursive=recursive,\n                skip=skip,\n                seen=seen,\n                verbose=verbose,\n                level=level + 1,\n            )\n\n    return seen\n\n\ndef trace_scico_calls(verbose: bool = False):\n    \"\"\"Enable tracing of calls to all significant scico functions/methods.\n\n    Enable tracing of calls to all significant scico functions and\n    methods. Note that JIT should be disabled to ensure correct\n    functioning of the tracing mechanism.\n    \"\"\"\n    if not jax.config.jax_disable_jit:\n        warnings.warn(\n            \"Call tracing requested but jit is not disabled. 
Disable jit\"\n            \" by setting the environment variable JAX_DISABLE_JIT=1, or use\"\n            \" jax.config.update('jax_disable_jit', True).\"\n        )\n    from scico import (\n        function,\n        functional,\n        linop,\n        loss,\n        metric,\n        operator,\n        optimize,\n        solver,\n    )\n\n    seen = None\n    for module in (functional, linop, loss, operator, optimize, function, metric, solver):\n        seen = apply_decorator(module, call_trace, skip=[\"__repr__\"], seen=seen, verbose=verbose)\n"
  },
  {
    "path": "scico/typing.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2021-2025 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"Type definitions.\"\"\"\n\nfrom typing import Any, List, Tuple, Union\n\ntry:\n    # available in python 3.10\n    from types import EllipsisType  # type: ignore\n    from typing import TypeAlias  # type: ignore\nexcept ImportError:\n    from typing_extensions import TypeAlias  # type: ignore\n\n    EllipsisType: TypeAlias = Any  # type: ignore\n\n\nimport jax.numpy as jnp\nfrom jax import Array\n\nPRNGKey: TypeAlias = Array\n\"\"\"A key for jax random number generators (see :mod:`jax.random`).\"\"\"\n\nDType: TypeAlias = Union[\n    jnp.int8,\n    jnp.int16,\n    jnp.int32,\n    jnp.int64,\n    jnp.uint8,\n    jnp.uint16,\n    jnp.uint32,\n    jnp.uint64,\n    jnp.float16,\n    jnp.float32,\n    jnp.float64,\n    jnp.complex64,\n    jnp.complex128,\n    bool,\n]\n\"\"\"A jax dtype.\"\"\"\n\nShape: TypeAlias = Tuple[int, ...]\n\"\"\"A shape of a numpy or jax array.\"\"\"\n\nBlockShape: TypeAlias = Tuple[Tuple[int, ...], ...]\n\"\"\"A shape of a :class:`.BlockArray`.\"\"\"\n\nAxes: TypeAlias = Union[int, Tuple[int, ...], List[int]]\n\"\"\"Specification of one or more array axes.\"\"\"\n\nAxisIndex: TypeAlias = Union[slice, EllipsisType, int]\n\"\"\"An entity suitable for indexing/slicing of a single array axis; either\na slice object, Ellipsis, or int.\"\"\"\n\nArrayIndex: TypeAlias = Union[AxisIndex, Tuple[AxisIndex, ...]]\n\"\"\"An entity suitable for indexing/slicing of multi-dimensional arrays.\"\"\"\n"
  },
  {
    "path": "scico/util.py",
    "content": "# -*- coding: utf-8 -*-\n# Copyright (C) 2020-2026 by SCICO Developers\n# All rights reserved. BSD 3-clause License.\n# This file is part of the SCICO package. Details of the copyright and\n# user license can be found in the 'LICENSE' file distributed with the\n# package.\n\n\"\"\"General utility functions.\"\"\"\n\nfrom __future__ import annotations\n\nimport io\nimport socket\nimport urllib.error as urlerror\nimport urllib.request as urlrequest\nfrom functools import reduce, wraps\nfrom timeit import default_timer as timer\nfrom typing import Any, Callable, Dict, List, Optional, Sequence, Union\n\nimport jax\n\n\ndef rgetattr(obj: object, name: str, default: Optional[Any] = None) -> Any:\n    \"\"\"Recursive version of :func:`getattr`.\n\n    Args:\n        obj: Object with the attribute to be accessed.\n        name: Path to the attribute, with components delimited by a \".\"\n           character.\n        default: Default value to be returned if the attribute does not\n           exist.\n\n    Returns:\n        Attribute value, or default if attribute does not exist.\n    \"\"\"\n\n    try:\n        return reduce(getattr, name.split(\".\"), obj)\n    except AttributeError as e:\n        if default is not None:\n            return default\n        else:\n            raise e\n\n\ndef rsetattr(obj: object, name: str, value: Any):\n    \"\"\"Recursive version of :func:`setattr`.\n\n    Args:\n        obj: Object with the attribute to be set.\n        name: Path to the attribute, with components delimited by a \".\"\n           character.\n        value: Value to which the attribute is to be set.\n    \"\"\"\n\n    # See goo.gl/BVJ7MN\n    path = name.split(\".\")\n    setattr(reduce(getattr, path[:-1], obj), path[-1], value)\n\n\ndef partial(func: Callable, indices: Sequence, *fixargs: Any, **fixkwargs: Any) -> Callable:\n    \"\"\"Flexible partial function creation.\n\n    This function is similar to :func:`functools.partial`, but allows\n    fixing of 
arbitrary positional arguments rather than just some number\n    of trailing positional arguments.\n\n    Args:\n        func: Function from which partial function is to be derived.\n        indices: Tuple of indices of positional args of `func` that are\n           to be fixed to the values specified in `fixargs`.\n        *fixargs: Fixed values for specified positional arguments.\n        **fixkwargs: Fixed values for keyword arguments.\n\n    Returns:\n       The partial function with fixed arguments.\n    \"\"\"\n\n    def pfunc(*freeargs, **freekwargs):\n        numargs = len(fixargs) + len(freeargs)\n        args = [\n            None,\n        ] * numargs\n        kfix = 0\n        kfree = 0\n        for k in range(numargs):\n            if k in indices:\n                args[k] = fixargs[kfix]\n                kfix += 1\n            else:\n                args[k] = freeargs[kfree]\n                kfree += 1\n        kwargs = freekwargs.copy()\n        kwargs.update(fixkwargs)\n        return func(*args, **kwargs)\n\n    posdoc = \"\"\n    if indices:\n        posdoc = f\"positional arguments {','.join(map(str, indices))}\"\n    kwdoc = \"\"\n    if fixkwargs:\n        kwdoc = f\"keyword arguments {','.join(fixkwargs.keys())}\"\n    pfunc.__doc__ = f\"Partial function derived from function {func.__name__}\"\n    if posdoc or kwdoc:\n        pfunc.__doc__ += \" by fixing \" + (\" and \".join(filter(None, (posdoc, kwdoc))))\n    return pfunc\n\n\ndef device_info(devid: int = 0) -> str:  # pragma: no cover\n    \"\"\"Get a string describing the specified device.\n\n    Args:\n        devid: ID number of device.\n\n    Returns:\n        Device description string.\n    \"\"\"\n    numdev = jax.device_count()\n    if devid >= numdev:\n        raise RuntimeError(f\"Requested information for device {devid} but only {numdev} present.\")\n    dev = jax.devices()[devid]\n    if dev.platform == \"cpu\":\n        info = \"CPU\"\n    else:\n        info = 
f\"{dev.platform.upper()} ({dev.device_kind})\"\n    return info\n\n\ndef check_for_tracer(func: Callable) -> Callable:\n    \"\"\"Check if positional arguments to `func` are jax tracers.\n\n    This is intended to be used as a decorator for functions that call\n    external code from within SCICO. At present, external functions\n    cannot be jit-ed or vmap/pmaped. This decorator checks for signs of\n    jit/vmap/pmap and raises an appropriate exception.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        if any([isinstance(x, jax.core.Tracer) for x in args]):\n            raise TypeError(\n                f\"JAX tracer found in {func.__name__}; did you jit/vmap/pmap this function?\"\n            )\n        return func(*args, **kwargs)\n\n    return wrapper\n\n\ndef url_get(\n    url: str, headers: Optional[dict] = None, maxtry: int = 3, timeout: int = 10\n) -> io.BytesIO:  # pragma: no cover\n    \"\"\"Get content of a file via a URL.\n\n    Args:\n        url: URL of the file to be downloaded.\n        headers: Dict of header strings for request.\n        maxtry: Maximum number of download retries.\n        timeout: Timeout in seconds for blocking operations.\n\n    Returns:\n        Buffered I/O stream.\n\n    Raises:\n        ValueError: If the maxtry parameter is not greater than zero.\n        urllib.error.URLError: If the file cannot be downloaded.\n    \"\"\"\n    if maxtry <= 0:\n        raise ValueError(\"Argument 'maxtry' should be greater than zero.\")\n\n    if headers is None:\n        headers = {}\n    req = urlrequest.Request(url, headers=headers)\n    for ntry in range(maxtry):\n        try:\n            rspns = urlrequest.urlopen(req, timeout=timeout)\n            cntnt = rspns.read()\n            break\n        except urlerror.URLError as e:\n            if not isinstance(e.reason, socket.timeout):\n                raise\n\n    return io.BytesIO(cntnt)\n\n\n# Timer classes are copied from 
https://github.com/bwohlberg/sporco\n\n\nclass Timer:\n    \"\"\"Timer class supporting multiple independent labeled timers.\n\n    The timer is based on the relative time returned by\n    :func:`timeit.default_timer`.\n    \"\"\"\n\n    def __init__(\n        self,\n        labels: Optional[Union[str, List[str]]] = None,\n        default_label: str = \"main\",\n        all_label: str = \"all\",\n    ):\n        \"\"\"\n        Args:\n            labels: Label(s) of the timer(s) to be initialised to zero.\n            default_label: Default timer label to be used when methods\n                are called without specifying a label.\n            all_label: Label string that will be used to denote all\n                timer labels.\n        \"\"\"\n\n        # Initialise current and accumulated time dictionaries\n        self.t0: Dict[str, Optional[float]] = {}\n        self.td: Dict[str, float] = {}\n        # Record default label and string indicating all labels\n        self.default_label = default_label\n        self.all_label = all_label\n        # Initialise dictionary entries for labels to be created\n        # immediately\n        if labels is not None:\n            if not isinstance(labels, (list, tuple)):\n                labels = [\n                    labels,\n                ]\n            for lbl in labels:\n                self.td[lbl] = 0.0\n                self.t0[lbl] = None\n\n    def start(self, labels: Optional[Union[str, List[str]]] = None):\n        \"\"\"Start specified timer(s).\n\n        Args:\n            labels: Label(s) of the timer(s) to be started. 
If it is\n               ``None``, start the default timer with label specified by\n               the `default_label` parameter of :meth:`__init__`.\n        \"\"\"\n\n        # Default label is self.default_label\n        if labels is None:\n            labels = self.default_label\n        # If label is not a list or tuple, create a singleton list\n        # containing it\n        if not isinstance(labels, (list, tuple)):\n            labels = [\n                labels,\n            ]\n        # Iterate over specified label(s)\n        t = timer()\n        for lbl in labels:\n            # On first call to start for a label, set its accumulator to zero\n            if lbl not in self.td:\n                self.td[lbl] = 0.0\n                self.t0[lbl] = None\n            # Record the time at which start was called for this lbl if\n            # it isn't already running\n            if self.t0[lbl] is None:\n                self.t0[lbl] = t\n\n    def stop(self, labels: Optional[Union[str, List[str]]] = None):\n        \"\"\"Stop specified timer(s).\n\n        Args:\n            labels: Label(s) of the timer(s) to be stopped. If it is\n               ``None``, stop the default timer with label specified by\n               the `default_label` parameter of :meth:`__init__`. 
If it\n               is equal to the string specified by the `all_label`\n               parameter of :meth:`__init__`, stop all timers.\n        \"\"\"\n\n        # Get current time\n        t = timer()\n        # Default label is self.default_label\n        if labels is None:\n            labels = self.default_label\n        # All timers are affected if label is equal to self.all_label,\n        # otherwise only the timer(s) specified by label\n        if labels == self.all_label:\n            labels = list(self.t0.keys())\n        elif not isinstance(labels, (list, tuple)):\n            labels = [\n                labels,\n            ]\n        # Iterate over specified label(s)\n        for lbl in labels:\n            if lbl not in self.t0:\n                raise KeyError(f\"Unrecognized timer key {lbl}.\")\n            # If self.t0[lbl] is None, the corresponding timer is\n            # already stopped, so no action is required\n            if self.t0[lbl] is not None:\n                # Increment time accumulator from the elapsed time\n                # since most recent start call\n                self.td[lbl] += t - self.t0[lbl]  # type: ignore\n                # Set start time to None to indicate timer is not running\n                self.t0[lbl] = None\n\n    def reset(self, labels: Optional[Union[str, List[str]]] = None):\n        \"\"\"Reset specified timer(s).\n\n        Args:\n            labels: Label(s) of the timer(s) to be reset. If it is\n                ``None``, reset the default timer with label specified by\n                the `default_label` parameter of :meth:`__init__`. 
If it\n                is equal to the string specified by the `all_label`\n                parameter of :meth:`__init__`, reset all timers.\n        \"\"\"\n\n        # Default label is self.default_label\n        if labels is None:\n            labels = self.default_label\n        # All timers are affected if label is equal to self.all_label,\n        # otherwise only the timer(s) specified by label\n        if labels == self.all_label:\n            labels = list(self.t0.keys())\n        elif not isinstance(labels, (list, tuple)):\n            labels = [\n                labels,\n            ]\n        # Iterate over specified label(s)\n        for lbl in labels:\n            if lbl not in self.t0:\n                raise KeyError(f\"Unrecognized timer key {lbl}.\")\n            # Set start time to None to indicate timer is not running\n            self.t0[lbl] = None\n            # Set time accumulator to zero\n            self.td[lbl] = 0.0\n\n    def elapsed(self, label: Optional[str] = None, total: bool = True) -> float:\n        \"\"\"Get elapsed time since timer start.\n\n        Args:\n           label: Label of the timer for which the elapsed time is\n               required. 
If it is ``None``, the default timer with label\n               specified by the `default_label` parameter of\n               :meth:`__init__` is selected.\n           total: If ``True`` return the total elapsed time since the\n               first call of :meth:`start` for the selected timer,\n               otherwise return the elapsed time since the most recent\n               call of :meth:`start` for which there has not been a\n               corresponding call to :meth:`stop`.\n\n        Returns:\n           Elapsed time.\n        \"\"\"\n\n        # Get current time\n        t = timer()\n        # Default label is self.default_label\n        if label is None:\n            label = self.default_label\n            # Return 0.0 if default timer selected and it is not initialised\n            if label not in self.t0:\n                return 0.0\n        # Raise exception if timer with specified label does not exist\n        if label not in self.t0:\n            raise KeyError(f\"Unrecognized timer key {label}.\")\n        # If total flag is True return sum of accumulated time from\n        # previous start/stop calls and current start call, otherwise\n        # return just the time since the current start call\n        te = 0.0\n        if self.t0[label] is not None:\n            te = t - self.t0[label]  # type: ignore\n        if total:\n            te += self.td[label]\n\n        return te\n\n    def labels(self) -> List[str]:\n        \"\"\"Get a list of timer labels.\n\n        Returns:\n          List of timer labels.\n        \"\"\"\n\n        return list(self.t0.keys())\n\n    def __str__(self) -> str:\n        \"\"\"Return string representation of object.\n\n        The representation consists of a table with the following columns:\n\n          * Timer label.\n          * Accumulated time from past start/stop calls.\n          * Time since current start call, or 'Stopped' if timer is not\n            currently running.\n        \"\"\"\n\n        # Get 
current time\n        t = timer()\n        # Length of label field, calculated from max label length\n        fldlen = [len(lbl) for lbl in self.t0] + [\n            len(self.default_label),\n        ]\n        lfldln = max(fldlen) + 2\n        # Header string for table of timers\n        s = f\"{'Label':{lfldln}s}  Accum.       Current\\n\"\n        s += \"-\" * (lfldln + 25) + \"\\n\"\n        # Construct table of timer details\n        for lbl in sorted(self.t0):\n            td = self.td[lbl]\n            if self.t0[lbl] is None:\n                ts = \" Stopped\"\n            else:\n                ts = f\" {(t - self.t0[lbl]):.2e} s\" % (t - self.t0[lbl])  # type: ignore\n            s += f\"{lbl:{lfldln}s}  {td:.2e} s  {ts}\\n\"\n\n        return s\n\n\nclass ContextTimer:\n    \"\"\"A wrapper class for :class:`Timer` that enables its use as a\n    context manager.\n\n    For example, instead of\n\n    >>> t = Timer()\n    >>> t.start()\n    >>> x = sum(range(1000))\n    >>> t.stop()\n    >>> elapsed = t.elapsed()\n\n    one can use\n\n    >>> t = Timer()\n    >>> with ContextTimer(t):\n    ...   x = sum(range(1000))\n    >>> elapsed = t.elapsed()\n    \"\"\"\n\n    def __init__(\n        self,\n        timer: Optional[Timer] = None,\n        label: Optional[str] = None,\n        action: str = \"StartStop\",\n    ):\n        \"\"\"\n        Args:\n           timer: Timer object to be used as a context manager. If\n              ``None``, a new :class:`Timer` object is constructed.\n           label: Label of the timer to be used. If it is ``None``, start\n              the default timer.\n           action: Actions to be taken on context entry and exit. 
If the\n              value is 'StartStop', start the timer on entry and stop on\n              exit; if it is 'StopStart', stop the timer on entry and\n              start it on exit.\n        \"\"\"\n\n        if action not in [\"StartStop\", \"StopStart\"]:\n            raise ValueError(f\"Unrecognized action {action}.\")\n        if timer is None:\n            self.timer = Timer()\n        else:\n            self.timer = timer\n        self.label = label\n        self.action = action\n\n    def __enter__(self):\n        \"\"\"Start the timer and return this ContextTimer instance.\"\"\"\n\n        if self.action == \"StartStop\":\n            self.timer.start(self.label)\n        else:\n            self.timer.stop(self.label)\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        \"\"\"Stop the timer and return ``True`` if no exception was raised\n        within the `with` block, otherwise return ``False``.\n        \"\"\"\n\n        if self.action == \"StartStop\":\n            self.timer.stop(self.label)\n        else:\n            self.timer.start(self.label)\n        return not exc_type\n\n    def elapsed(self, total: bool = True) -> float:\n        \"\"\"Return the elapsed time for the timer.\n\n        Args:\n          total: If ``True`` return the total elapsed time since the\n             first call of :meth:`start` for the selected timer,\n             otherwise return the elapsed time since the most recent call\n             of :meth:`start` for which there has not been a\n             corresponding call to :meth:`stop`.\n\n        Returns:\n          Elapsed time.\n        \"\"\"\n\n        return self.timer.elapsed(self.label, total=total)\n"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"SCICO package configuration.\"\"\"\n\nimport importlib.util\nimport os\nimport os.path\nimport site\nimport sys\n\nfrom setuptools import find_namespace_packages, setup\n\n# Import module scico._version without executing __init__.py\nspec = importlib.util.spec_from_file_location(\"_version\", os.path.join(\"scico\", \"_version.py\"))\nmodule = importlib.util.module_from_spec(spec)\nsys.modules[\"_version\"] = module\nspec.loader.exec_module(module)\nfrom _version import package_version\n\nname = \"scico\"\nversion = package_version()\n# Add argument exclude=[\"test\", \"test.*\"] to exclude test subpackage\npackages = find_namespace_packages(where=\"scico\")\npackages = [\"scico\"] + [f\"scico.{m}\" for m in packages]\n\n\nlongdesc = \"\"\"\nSCICO is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization routines that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. 
SCICO is built on top of JAX, which provides features such as automatic gradient calculation and GPU acceleration.\n\"\"\"\n\n# Set install_requires from requirements.txt file\nwith open(\"requirements.txt\") as f:\n    lines = f.readlines()\ninstall_requires = [line.strip() for line in lines]\n\npython_requires = \">=3.8\"\ntests_require = [\"pytest\", \"pytest-runner\"]\n\nextra_require_files = [\n    \"dev_requirements.txt\",\n    os.path.join(\"docs\", \"docs_requirements.txt\"),\n    os.path.join(\"examples\", \"examples_requirements.txt\"),\n    os.path.join(\"examples\", \"notebooks_requirements.txt\"),\n]\nextras_require = {\"tests\": tests_require}\nfor require_file in extra_require_files:\n    extras_label = os.path.basename(require_file).partition(\"_\")[0]\n    with open(require_file) as f:\n        lines = f.readlines()\n    extras_require[extras_label] = [line.strip() for line in lines if line[0:2] != \"-r\"]\n\n# PEP517 workaround, see https://www.scivision.dev/python-pip-devel-user-install/\nsite.ENABLE_USER_SITE = True\n\nsetup(\n    name=name,\n    version=version,\n    description=\"Scientific Computational Imaging COde: A Python \"\n    \"package for scientific imaging problems\",\n    long_description=longdesc,\n    keywords=[\n        \"Computational Imaging\",\n        \"Scientific Imaging\",\n        \"Inverse Problems\",\n        \"Plug-and-Play Priors\",\n        \"Total Variation\",\n        \"Optimization\",\n        \"ADMM\",\n        \"Linearized ADMM\",\n        \"PDHG\",\n        \"PGM\",\n    ],\n    platforms=\"Any\",\n    license=\"BSD-3-Clause\",\n    url=\"https://github.com/lanl/scico\",\n    author=\"SCICO Developers\",\n    author_email=\"brendt@ieee.org\",  # Temporary\n    packages=packages,\n    package_data={\"scico\": [\"data/*/*.png\", \"data/*/*.npz\"]},\n    include_package_data=True,\n    python_requires=python_requires,\n    install_requires=install_requires,\n    extras_require=extras_require,\n    classifiers=[\n  
      \"Development Status :: 4 - Beta\",\n        \"Intended Audience :: Education\",\n        \"Intended Audience :: Science/Research\",\n        \"Operating System :: OS Independent\",\n        \"Programming Language :: Python :: 3\",\n        \"Topic :: Scientific/Engineering :: Information Analysis\",\n        \"Topic :: Scientific/Engineering :: Mathematics\",\n        \"Topic :: Software Development :: Libraries :: Python Modules\",\n    ],\n    zip_safe=False,\n)\n"
  }
]