Repository: lanl/scico Branch: main Commit: 010a7714678a Files: 335 Total size: 1.8 MB Directory structure: gitextract_aiioi9sl/ ├── .coveragerc ├── .flake8 ├── .github/ │ ├── codecov.yml │ ├── isbin.sh │ └── workflows/ │ ├── check_files.yml │ ├── lint.yml │ ├── mypy.yml │ ├── pypi_upload.yml │ ├── pytest_latest.yml │ ├── pytest_macos.yml │ ├── pytest_ubuntu.yml │ └── test_examples.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGES.rst ├── LICENSE ├── MANIFEST.in ├── README.md ├── conftest.py ├── dev_requirements.txt ├── docs/ │ ├── Makefile │ ├── docs_requirements.txt │ ├── rtd_requirements.txt │ ├── source/ │ │ ├── _static/ │ │ │ └── scico.css │ │ ├── _templates/ │ │ │ ├── autosummary/ │ │ │ │ └── module.rst │ │ │ ├── package.rst │ │ │ └── sidebar/ │ │ │ └── brand.html │ │ ├── advantages.rst │ │ ├── api.rst │ │ ├── classes.rst │ │ ├── conf/ │ │ │ ├── 10-project.py │ │ │ ├── 15-theme.py │ │ │ ├── 20-extensions.py │ │ │ ├── 25-napoleon.py │ │ │ ├── 30-autodoc.py │ │ │ ├── 40-intersphinx.py │ │ │ ├── 45-mathjax.py │ │ │ ├── 50-graphviz.py │ │ │ ├── 55-nbsphinx.py │ │ │ ├── 60-rtd.py │ │ │ ├── 70-latex.py │ │ │ ├── 71-texinfo.py │ │ │ ├── 72-man_page.py │ │ │ ├── 80-scico_numpy.py │ │ │ ├── 81-scico_scipy.py │ │ │ └── 85-dtype_typehints.py │ │ ├── conf.py │ │ ├── contributing.rst │ │ ├── docsutil.py │ │ ├── examples.rst │ │ ├── include/ │ │ │ ├── blockarray.rst │ │ │ ├── examplenotes.rst │ │ │ ├── functional.rst │ │ │ ├── learning.rst │ │ │ ├── operator.rst │ │ │ └── optimizer.rst │ │ ├── index.rst │ │ ├── install.rst │ │ ├── inverse.rst │ │ ├── notes.rst │ │ ├── overview.rst │ │ ├── pyfigures/ │ │ │ ├── cylindgrad.py │ │ │ ├── polargrad.py │ │ │ ├── spheregrad.py │ │ │ ├── xray_2d_geom.py │ │ │ ├── xray_3d_ang.py │ │ │ ├── xray_3d_vec.py │ │ │ └── xray_3d_vol.py │ │ ├── references.bib │ │ ├── style.rst │ │ ├── team.rst │ │ └── zreferences.rst │ └── tikxfigures/ │ ├── img_align.tex │ ├── makesvg.sh │ ├── vol_align_xyz.tex 
│ ├── vol_align_xz.tex │ └── vol_align_yz.tex ├── examples/ │ ├── README.rst │ ├── examples_requirements.txt │ ├── jnb.py │ ├── makeindex.py │ ├── makenotebooks.py │ ├── notebooks_requirements.txt │ ├── removejnberr.py │ ├── scriptcheck.sh │ ├── scripts/ │ │ ├── README.rst │ │ ├── ct_abel_tv_admm.py │ │ ├── ct_abel_tv_admm_tune.py │ │ ├── ct_astra_3d_tv_admm.py │ │ ├── ct_astra_3d_tv_padmm.py │ │ ├── ct_astra_noreg_pcg.py │ │ ├── ct_astra_tv_admm.py │ │ ├── ct_astra_weighted_tv_admm.py │ │ ├── ct_datagen_foam2.py │ │ ├── ct_fan_svmbir_ppp_bm3d_admm_prox.py │ │ ├── ct_modl_train_foam2.py │ │ ├── ct_multi_tv_admm.py │ │ ├── ct_odp_train_foam2.py │ │ ├── ct_projector_comparison_2d.py │ │ ├── ct_projector_comparison_3d.py │ │ ├── ct_svmbir_ppp_bm3d_admm_cg.py │ │ ├── ct_svmbir_ppp_bm3d_admm_prox.py │ │ ├── ct_svmbir_tv_multi.py │ │ ├── ct_symcone_tv_padmm.py │ │ ├── ct_tv_admm.py │ │ ├── ct_unet_train_foam2.py │ │ ├── deconv_circ_tv_admm.py │ │ ├── deconv_datagen_bsds.py │ │ ├── deconv_datagen_foam1.py │ │ ├── deconv_microscopy_allchn_tv_admm.py │ │ ├── deconv_microscopy_tv_admm.py │ │ ├── deconv_modl_train_foam1.py │ │ ├── deconv_odp_train_foam1.py │ │ ├── deconv_ppp_bm3d_admm.py │ │ ├── deconv_ppp_bm3d_apgm.py │ │ ├── deconv_ppp_bm4d_admm.py │ │ ├── deconv_ppp_dncnn_admm.py │ │ ├── deconv_ppp_dncnn_padmm.py │ │ ├── deconv_tv_admm.py │ │ ├── deconv_tv_admm_tune.py │ │ ├── deconv_tv_padmm.py │ │ ├── demosaic_ppp_bm3d_admm.py │ │ ├── denoise_approx_tv_multi.py │ │ ├── denoise_cplx_tv_nlpadmm.py │ │ ├── denoise_cplx_tv_pdhg.py │ │ ├── denoise_datagen_bsds.py │ │ ├── denoise_dncnn_train_bsds.py │ │ ├── denoise_dncnn_universal.py │ │ ├── denoise_l1tv_admm.py │ │ ├── denoise_ptv_pdhg.py │ │ ├── denoise_tv_admm.py │ │ ├── denoise_tv_apgm.py │ │ ├── denoise_tv_multi.py │ │ ├── diffusercam_tv_admm.py │ │ ├── index.rst │ │ ├── sparsecode_apgm.py │ │ ├── sparsecode_conv_admm.py │ │ ├── sparsecode_conv_md_admm.py │ │ ├── sparsecode_nn_admm.py │ │ ├── sparsecode_nn_apgm.py │ │ ├── 
sparsecode_poisson_apgm.py │ │ ├── superres_ppp_dncnn_admm.py │ │ ├── trace_example.py │ │ └── video_rpca_admm.py │ ├── updatejnbcode.py │ └── updatejnbmd.py ├── misc/ │ ├── README.rst │ ├── conda/ │ │ ├── README.rst │ │ ├── install_conda.sh │ │ └── make_conda_env.sh │ ├── gpu/ │ │ ├── README.rst │ │ ├── availgpu.py │ │ └── envinfo.py │ └── pytest/ │ ├── README.rst │ ├── pytest_cov.sh │ ├── pytest_fast.sh │ └── pytest_time.sh ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── scico/ │ ├── __init__.py │ ├── _core.py │ ├── _version.py │ ├── data/ │ │ └── __init__.py │ ├── denoiser.py │ ├── diagnostics.py │ ├── examples.py │ ├── flax/ │ │ ├── __init__.py │ │ ├── _flax.py │ │ ├── _models.py │ │ ├── blocks.py │ │ ├── examples/ │ │ │ ├── __init__.py │ │ │ ├── data_generation.py │ │ │ ├── data_preprocessing.py │ │ │ ├── examples.py │ │ │ └── typed_dict.py │ │ ├── inverse.py │ │ └── train/ │ │ ├── __init__.py │ │ ├── apply.py │ │ ├── checkpoints.py │ │ ├── clu_utils.py │ │ ├── diagnostics.py │ │ ├── input_pipeline.py │ │ ├── learning_rate.py │ │ ├── losses.py │ │ ├── spectral.py │ │ ├── state.py │ │ ├── steps.py │ │ ├── trainer.py │ │ ├── traversals.py │ │ └── typed_dict.py │ ├── function.py │ ├── functional/ │ │ ├── __init__.py │ │ ├── _denoiser.py │ │ ├── _dist.py │ │ ├── _functional.py │ │ ├── _indicator.py │ │ ├── _norm.py │ │ ├── _proxavg.py │ │ └── _tvnorm.py │ ├── linop/ │ │ ├── __init__.py │ │ ├── _circconv.py │ │ ├── _convolve.py │ │ ├── _dft.py │ │ ├── _diag.py │ │ ├── _diff.py │ │ ├── _func.py │ │ ├── _grad.py │ │ ├── _linop.py │ │ ├── _matrix.py │ │ ├── _stack.py │ │ ├── _util.py │ │ ├── optics.py │ │ └── xray/ │ │ ├── __init__.py │ │ ├── _axitom/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── backprojection.py │ │ │ ├── config.py │ │ │ ├── filtering.py │ │ │ ├── projection.py │ │ │ └── utilities.py │ │ ├── _util.py │ │ ├── _xray.py │ │ ├── abel.py │ │ ├── astra.py │ │ ├── svmbir.py │ │ └── symcone.py │ ├── loss.py │ ├── metric.py │ ├── numpy/ │ │ ├── 
__init__.py │ │ ├── _blockarray.py │ │ ├── _wrapped_function_lists.py │ │ ├── _wrappers.py │ │ ├── fft.py │ │ ├── linalg.py │ │ ├── testing.py │ │ └── util.py │ ├── operator/ │ │ ├── __init__.py │ │ ├── _func.py │ │ ├── _operator.py │ │ ├── _stack.py │ │ └── biconvolve.py │ ├── optimize/ │ │ ├── __init__.py │ │ ├── _admm.py │ │ ├── _admmaux.py │ │ ├── _common.py │ │ ├── _ladmm.py │ │ ├── _padmm.py │ │ ├── _pgm.py │ │ ├── _pgmaux.py │ │ ├── _primaldual.py │ │ ├── admm.py │ │ └── pgm.py │ ├── plot.py │ ├── random.py │ ├── ray/ │ │ ├── __init__.py │ │ └── tune.py │ ├── scipy/ │ │ ├── __init__.py │ │ └── special.py │ ├── solver.py │ ├── test/ │ │ ├── conftest.py │ │ ├── flax/ │ │ │ ├── test_apply.py │ │ │ ├── test_checkpoints.py │ │ │ ├── test_clu.py │ │ │ ├── test_examples_flax.py │ │ │ ├── test_flax.py │ │ │ ├── test_inv.py │ │ │ ├── test_spectral.py │ │ │ ├── test_steps.py │ │ │ ├── test_train_aux.py │ │ │ ├── test_trainer.py │ │ │ └── test_traversal.py │ │ ├── functional/ │ │ │ ├── prox.py │ │ │ ├── test_composed.py │ │ │ ├── test_denoiser_func.py │ │ │ ├── test_funcional_core.py │ │ │ ├── test_indicator.py │ │ │ ├── test_loss.py │ │ │ ├── test_misc.py │ │ │ ├── test_norm.py │ │ │ ├── test_proxavg.py │ │ │ ├── test_separable.py │ │ │ └── test_tvnorm.py │ │ ├── linop/ │ │ │ ├── test_binop.py │ │ │ ├── test_circconv.py │ │ │ ├── test_conversions.py │ │ │ ├── test_convolve.py │ │ │ ├── test_dft.py │ │ │ ├── test_diag.py │ │ │ ├── test_diff.py │ │ │ ├── test_func.py │ │ │ ├── test_grad.py │ │ │ ├── test_linop.py │ │ │ ├── test_linop_stack.py │ │ │ ├── test_linop_util.py │ │ │ ├── test_matrix.py │ │ │ ├── test_optics.py │ │ │ └── xray/ │ │ │ ├── test_abel.py │ │ │ ├── test_astra.py │ │ │ ├── test_svmbir.py │ │ │ ├── test_symcone.py │ │ │ ├── test_xray_2d.py │ │ │ ├── test_xray_3d.py │ │ │ └── test_xray_util.py │ │ ├── numpy/ │ │ │ ├── test_blockarray.py │ │ │ ├── test_numpy.py │ │ │ └── test_numpy_util.py │ │ ├── operator/ │ │ │ ├── test_biconvolve.py │ │ │ ├── 
test_op_stack.py │ │ │ └── test_operator.py │ │ ├── optimize/ │ │ │ ├── test_admm.py │ │ │ ├── test_ladmm.py │ │ │ ├── test_padmm.py │ │ │ ├── test_pdhg.py │ │ │ └── test_pgm.py │ │ ├── osver.py │ │ ├── test_core.py │ │ ├── test_data.py │ │ ├── test_denoiser.py │ │ ├── test_diagnostics.py │ │ ├── test_examples.py │ │ ├── test_function.py │ │ ├── test_metric.py │ │ ├── test_random.py │ │ ├── test_ray_tune.py │ │ ├── test_scipy_special.py │ │ ├── test_solver.py │ │ ├── test_util.py │ │ └── test_version.py │ ├── trace.py │ ├── typing.py │ └── util.py └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .coveragerc ================================================ [run] source = scico command_line = -m pytest omit = scico/test/* scico/plot.py scico/trace.py scico/linop/xray/_axitom/*.py [report] # Regexes for lines to exclude from consideration exclude_lines = # Have to re-enable the standard pragma pragma: no cover def __repr__ ================================================ FILE: .flake8 ================================================ [flake8] max-line-length = 100 ignore = #E731: do not assign a lambda expression, use a def E731 ================================================ FILE: .github/codecov.yml ================================================ coverage: precision: 2 round: nearest range: "80...100" status: project: default: target: auto threshold: 0.05% patch: false ================================================ FILE: .github/isbin.sh ================================================ #! 
/bin/bash # Determine whether files are acceptable for commit into main scico repo size_threshold=65536 SAVEIFS=$IFS IFS=$(echo -en "\n\b") OS=$(uname -a | cut -d ' ' -f 1) for f in $@; do echo $f case "$OS" in Linux) size=$(stat --format "%s" $f);; Darwin) size=$(stat -f "%z" $f);; *) echo "Error: unsupported operating system $OS" >&2; exit 1;; esac # Exception on maximum size for pytest-split .test_durations file if [ $size -gt $size_threshold ] && [ "$(basename $f)" != ".test_durations" ]; then echo "file exceeds maximum allowable size of $size_threshold bytes" echo "raw data and ipynb files should go in scico-data" exit 2 fi charset=$(file -b --mime $f | sed -e 's/.*charset=//') if [ ! -L "$f" ] && [ "$charset" = "binary" ]; then echo "binary files cannot be commited to the repository" echo "raw data and ipynb files should go in scico-data" exit 3 fi basename=$(basename -- "$f") ext="${basename##*.}" if [ "$ext" = "ipynb" ]; then echo "ipynb files cannot be commited to the repository" echo "raw data and ipynb files should go in scico-data" exit 4 fi done IFS=$SAVEIFS exit 0 ================================================ FILE: .github/workflows/check_files.yml ================================================ # Check file types and sizes name: check files on: [push, pull_request] jobs: checkfiles: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v5 - id: files uses: Ana06/get-changed-files@v2.3.0 continue-on-error: true - run: | for f in ${{ steps.files.outputs.added }}; do ${GITHUB_WORKSPACE}/.github/./isbin.sh $f done ================================================ FILE: .github/workflows/lint.yml ================================================ # Run isort and black on pushes to main and any pull requests name: lint on: push: branches: - main pull_request: jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - uses: actions/setup-python@v6 with: python-version: "3.12" - name: Black code formatter uses: psf/black@stable 
with: version: ">=24.3.0" - name: Isort import sorter uses: isort/isort-action@v1 - name: Pylint code analysis run: | pip install pylint pylint --disable=all --enable=missing-docstring,broad-exception-raised scico ================================================ FILE: .github/workflows/mypy.yml ================================================ # Install and run mypy name: mypy on: push: branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: jobs: mypy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 with: submodules: recursive - name: Install Python 3 uses: actions/setup-python@v6 with: python-version: "3.12" - name: Install dependencies run: | pip install mypy - name: Run mypy run: | mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/ scico/numpy/util.py ================================================ FILE: .github/workflows/pypi_upload.yml ================================================ # When a tag is pushed, build packages and upload to PyPI name: pypi upload # Trigger when tags are pushed on: push: tags: - '*' workflow_dispatch: jobs: build-and-upload: name: Upload package to PyPI runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 with: submodules: recursive - name: Install Python 3 uses: actions/setup-python@v6 with: python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip sudo apt-get install -y libopenblas-dev pip install -r requirements.txt pip install -r dev_requirements.txt pip install wheel python setup.py sdist bdist_wheel - name: Upload package to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} verbose: true ================================================ FILE: .github/workflows/pytest_latest.yml ================================================ # Install scico requirements and run pytest with latest jax version name: unit tests (latest jax) # Controls when the workflow will run on: # 
Run workflow every Sunday at midnight UTC schedule: - cron: "0 0 * * 0" # Allows you to run this workflow manually from the Actions tab workflow_dispatch: jobs: pytest-latest-jax: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 with: submodules: recursive - name: Install Python 3 uses: actions/setup-python@v6 with: python-version: "3.12" - name: Install lastversion run: | python -m pip install --upgrade pip pip install lastversion - name: Install dependencies run: | rjaxlib=$(grep jaxlib requirements.txt | sed -e 's/jaxlib.*<=\([0-9\.]*$\)/\1/') rjax=$(grep -E "jax[^lib]" requirements.txt | sed -e 's/jax.*<=\([0-9\.]*$\)/\1/') ljaxlib=$(lastversion --at pip jaxlib) ljax=$(lastversion --at pip jax) echo jaxlib required: $rjaxlib latest: $ljaxlib echo jax required: $rjax latest: $ljax if [ "$rjaxlib" = "$ljaxlib" ] && [ "$rjax" = "$ljax" ]; then echo Test is redundant: required and latest jaxlib/jax versions match echo "TEST=cancel" >> $GITHUB_ENV else echo "TEST=run" >> $GITHUB_ENV sudo apt-get install -y libopenblas-dev pip install -r requirements.txt pip install -r dev_requirements.txt pip install -e . 
pip install --upgrade "jax[cpu]" fi - name: Run tests with pytest run: | TEST="${{ env.TEST }}" if [ "$TEST" = "run" ]; then pytest else exit 0 fi ================================================ FILE: .github/workflows/pytest_macos.yml ================================================ # Install scico requirements and run pytest name: unit tests (macos) on: push: branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: jobs: test: runs-on: macos-latest strategy: fail-fast: false matrix: group: [1, 2, 3, 4, 5] name: pytest split ${{ matrix.group }} (macos) defaults: run: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - uses: actions/checkout@v5 with: submodules: recursive # Set up conda environment - name: Set up miniconda uses: conda-incubator/setup-miniconda@v3 with: miniforge-version: latest activate-environment: test-env python-version: "3.12" # Configure conda environment cache - name: Set up conda environment cache uses: actions/cache@v4 with: path: ${{ env.CONDA }}/envs key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ env.CACHE_NUMBER }} env: CACHE_NUMBER: 0 # Increase this value to force cache reset id: cache # Display environment details - name: Display environment details run: | conda info printenv | sort # Install dependencies in conda environment - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: | conda install -c conda-forge pytest pytest-cov python -m pip install --upgrade pip pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt pip install "bm3d>=4.0.0" pip install "bm4d>=4.0.0" pip install "ray[tune]>=2.44" pip install hyperopt pip install "setuptools<82.0.0" # workaround for hyperopt 0.2.7 pip install pydantic pip install "orbax-checkpoint>=0.5.0" #conda install -c conda-forge "svmbir>=0.4.0" conda install -c astra-toolbox astra-toolbox conda install -c 
conda-forge pyyaml # Install package to be tested - name: Install package to be tested run: pip install -e . # Run unit tests - name: Run main unit tests run: | DURATIONS_FILE=$(mktemp) bzcat data/pytest/durations_macos.bz2 > $DURATIONS_FILE pytest -x --level=1 --durations-path=$DURATIONS_FILE --splits=5 --group=${{ matrix.group }} --pyargs scico ================================================ FILE: .github/workflows/pytest_ubuntu.yml ================================================ # Install scico requirements and run pytest name: unit tests (ubuntu) on: push: branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: inputs: debug_enabled: type: boolean description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)' required: false default: false jobs: test: runs-on: ubuntu-latest strategy: fail-fast: false matrix: group: [1, 2, 3, 4, 5] name: pytest split ${{ matrix.group }} (ubuntu) defaults: run: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - uses: actions/checkout@v5 with: submodules: recursive # Set up conda environment - name: Set up miniconda uses: conda-incubator/setup-miniconda@v3 with: miniforge-version: latest activate-environment: test-env python-version: "3.12" # Configure conda environment cache - name: Set up conda environment cache uses: actions/cache@v4 with: path: ${{ env.CONDA }}/envs key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ env.CACHE_NUMBER }} env: CACHE_NUMBER: 0 # Increase this value to force cache reset id: cache # Display environment details - name: Display environment details run: | conda info printenv | sort # Install required system package - name: Install required system package run: sudo apt-get install -y libopenblas-dev # Install dependencies in conda environment - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: | conda 
install -c conda-forge pytest pytest-cov python -m pip install --upgrade pip pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt pip install "bm4d>=4.2.2" pip install "bm3d>=4.0.0" pip install "ray[tune]>=2.44" pip install hyperopt pip install "setuptools<82.0.0" # workaround for hyperopt 0.2.7 pip install pydantic pip install "orbax-checkpoint>=0.5.0" conda install -c conda-forge "svmbir>=0.4.0" conda install -c conda-forge astra-toolbox conda install -c conda-forge pyyaml # Install package to be tested - name: Install package to be tested run: pip install -e . # Enable tmate debugging of manually-triggered workflows if the input option was provided - name: Setup tmate session uses: mxschmitt/action-tmate@v3 if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }} # Run unit tests - name: Run main unit tests run: | DURATIONS_FILE=$(mktemp) bzcat data/pytest/durations_ubuntu.bz2 > $DURATIONS_FILE pytest -x --cov --level=2 --durations-path=$DURATIONS_FILE --splits=5 --group=${{ matrix.group }} --pyargs scico # Upload coverage data - name: Upload coverage uses: actions/upload-artifact@v4 with: include-hidden-files: true name: coverage${{ matrix.group }} path: ${{ github.workspace }}/.coverage # Run doc tests - name: Run doc tests if: matrix.group == 1 run: | pytest --ignore-glob="*test_*.py" --ignore=scico/linop/xray --doctest-modules scico pytest --doctest-glob="*.rst" docs coverage: needs: test runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - name: Set up Python 3.12 uses: actions/setup-python@v6 with: python-version: "3.12" - name: Install deps run: | python -m pip install --upgrade pip pip install coverage - name: Download all artifacts # Downloads coverage1, coverage2, etc. 
uses: actions/download-artifact@v4 - name: Run coverage run: | coverage combine coverage?/.coverage coverage report coverage xml - uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: env_vars: OS,PYTHON fail_ci_if_error: false files: coverage.xml flags: unittests name: codecov-umbrella verbose: true ================================================ FILE: .github/workflows/test_examples.yml ================================================ # Install scico requirements and run short versions of example scripts name: test examples on: push: branches: [ main ] pull_request: branches: [ main ] # Allow this workflow to be run manually from the Actions tab workflow_dispatch: jobs: build: runs-on: ubuntu-latest strategy: fail-fast: false name: test examples (ubuntu) defaults: run: shell: bash -l {0} steps: # Check-out the repository under $GITHUB_WORKSPACE - uses: actions/checkout@v5 with: submodules: recursive # Set up conda environment - name: Set up miniconda uses: conda-incubator/setup-miniconda@v3 with: miniforge-version: latest activate-environment: test-env python-version: "3.12" # Configure conda environment cache - name: Set up conda environment cache uses: actions/cache@v4 with: path: ${{ env.CONDA }}/envs key: conda-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev_requirements.txt') }}-${{ hashFiles('examples/examples_requirements.txt') }}-${{ env.CACHE_NUMBER }} env: CACHE_NUMBER: 0 # Increase this value to force cache reset id: cache # Display environment details - name: Display environment details run: | conda info printenv | sort # Install required system package - name: Install required system package run: sudo apt-get install -y libopenblas-dev # Install dependencies in conda environment - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: | conda install -c conda-forge pytest pytest-cov python -m pip install --upgrade pip pip install -r requirements.txt 
pip install -r dev_requirements.txt conda install -c conda-forge astra-toolbox conda install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version pip install -r examples/examples_requirements.txt # Install package to be tested - name: Install package to be tested run: pip install -e . # Run example test - name: Run example test run: | ${GITHUB_WORKSPACE}/examples/scriptcheck.sh -e -d -t -g ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Editor backups .*~ # Docs generation docs/source/_autosummary/ docs/source/examples/ # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # VS Code settings .vscode/ # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # macos files *.DS_Store ================================================ FILE: .gitmodules ================================================ [submodule "data"] path = data url = https://github.com/lanl/scico-data.git ================================================ FILE: .pre-commit-config.yaml ================================================ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - repo: local hooks: - id: check-for-binary name: check for binary/ipynb files entry: .github/isbin.sh language: script pass_filenames: true - id: autoflake name: autoflake entry: autoflake language: python language_version: python3 types: [python] args: ['-i', '--remove-all-unused-imports', '--ignore-init-module-imports'] - id: isort name: isort (python) entry: isort language: python language_version: python3 types: [python] - id: isort name: isort (cython) entry: isort language: python language_version: python3 types: [cython] - id: black name: black entry: black description: 'Black: The uncompromising Python code formatter' language: python language_version: python3 types: [python] - id: pylint name: pylint entry: pylint language: 
python language_version: python3 types: [python] exclude: ^(scico/test/|examples|docs) args: ['--score=n', '--disable=all', '--enable=missing-docstring,broad-exception-raised'] ================================================ FILE: .readthedocs.yaml ================================================ # .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Get submodules submodules: include: all recursive: true # Set the version of Python and other tools you might need build: os: ubuntu-24.04 tools: python: "3.12" jobs: pre_build: - mkdir -p docs/source/examples - | for f in data/notebooks/*; do b=$(basename $f) if [ ! -f "docs/source/examples/$b" ]; then ln -s -t docs/source/examples "../../../$f" fi done post_build: # unclear why this is necessary - cp docs/source/_static/scico.css _readthedocs/html/_static apt_packages: - graphviz - libopenblas-dev # Build documentation in the docs/ directory with Sphinx sphinx: builder: html configuration: docs/source/conf.py fail_on_warning: false # Declare the Python requirements required to build your docs python: install: - requirements: docs/docs_requirements.txt - requirements: docs/rtd_requirements.txt ================================================ FILE: CHANGES.rst ================================================ =================== SCICO Release Notes =================== Version 0.0.8 (unreleased) ---------------------------- • Enable certain parameters of array creation functions to trigger ``BlockArray`` creation when they receive lists (currently ``device``). • New functional ``functional.BoxIndicator``. • Support ``jaxlib`` and ``jax`` versions 0.5.0 to 0.10.0. • Support ``flax`` versions 0.8.0 to 0.12.7. • Various bug fixes and minor improvements. Version 0.0.7 (2025-12-09) ---------------------------- • New module ``scico.trace`` for tracing function/method calls. 
• New generic functional ``functional.ComposedFunctional`` representing a functional composed with an orthogonal linear operator. • New optimizer methods ``save_state`` and ``load_state`` supporting algorithm state checkpointing. • New classes for creating a volume from an image by symmetry, and for the cone beam X-ray transform of a cylindrically symmetric object in module ``linop.xray.symcone``. • New utility functions for CT reconstruction preprocessing added in module ``linop.xray``. • Move ``linop.abel`` module to ``linop.xray.abel``. • Make ``orbax-checkpoint`` dependency optional due to absence of recent conda-forge packages. • Support ``jaxlib`` and ``jax`` versions 0.5.0 to 0.8.1. • Support ``flax`` versions 0.8.0 to 0.12.0. Version 0.0.6 (2024-10-25) ---------------------------- • Significant changes to ``linop.xray.astra`` API. • Rename integrated 2D X-ray transform class to ``linop.xray.XRayTransform2D`` and add filtered back projection method ``fbp``. • New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. • New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``, ``linop.CylindricalGradient``, and ``linop.SphericalGradient``. • Rename ``scico.numpy.util.parse_axes`` to ``scico.numpy.util.normalize_axes``. • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. • Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.4.35. • Support ``flax`` versions 0.8.0 to 0.10.0. Version 0.0.5 (2023-12-18) ---------------------------- • New functionals ``functional.AnisotropicTVNorm`` and ``functional.ProximalAverage`` with proximal operator approximations. • New integrated Radon/X-ray transform ``linop.XRayTransform``. • New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes to ``XRayTransform``. • Rename ``AbelProjector`` to ``AbelTransform``. • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. • Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and ``linop.VerticalStack``. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.23. • Support ``flax`` versions up to 0.7.5. • Use ``orbax`` for checkpointing ``flax`` models. Version 0.0.4 (2023-08-03) ---------------------------- • Add new ``Function`` class for representing array-to-array mappings with more than one input. • Add new methods and a function for computing Jacobian-vector products for ``Operator`` objects. • Add new proximal ADMM solvers. • Add new ADMM subproblem solvers for problems involving a sum-of-convolutions operator. • Extend support for other ML models including UNet, ODP and MoDL. • Add functionality for training Flax-based ML models and for data generation. • Enable diagnostics for ML training loops. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.14. • Change required packages and version numbers, including a more recent version requirement for ``flax``. • Drop support for Python 3.7. • Add support for 3D tomographic projection with the ASTRA Toolbox. Version 0.0.3 (2022-09-21) ---------------------------- • Change required packages and version numbers, including more recent version requirements for ``numpy``, ``scipy``, ``svmbir``, and ``ray``. • Remove package ``bm4d`` from main requirements list due to issue #342. • Support ``jaxlib`` versions 0.3.0 to 0.3.15 and ``jax`` versions 0.3.0 to 0.3.17. • Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules to ``TomographicProjector``. • Add support for fan beam CT in ``radon_svmbir`` module. • Add function ``linop.linop_from_function`` for constructing linear operators from functions. • Enable addition operator for functionals.
• Completely new implementation of ``BlockArray`` class. • Additional solvers in ``scico.solver``. • New Huber norm (``HuberNorm``) and set distance functionals (``SetDistance`` and ``SquaredSetDistance``). • New loss functions ``loss.SquaredL2AbsLoss`` and ``loss.SquaredL2SquaredAbsLoss`` for phase retrieval problems. • Add interface to BM4D denoiser. • Change interfaces of ``linop.FiniteDifference`` and ``linop.DFT``. • Change filenames of some example scripts (and corresponding notebooks). • Add support for Python 3.7. • New ``DiagonalStack`` linear operator. • Add support for non-linear operators to ``optimize.PDHG`` optimizer class. • Various bug fixes. Version 0.0.2 (2022-02-14) ---------------------------- • Additional optimization algorithms: Linearized ADMM and PDHG. • Additional Abel transform and array slicing linear operators. • Additional nuclear norm functional. • New module ``scico.ray.tune`` providing a simplified interface to Ray Tune. • Move optimization algorithms into ``optimize`` subpackage. • Additional iteration stats columns for iterative ADMM subproblem solvers. • Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats. • Move some functions from ``util`` and ``math`` modules to new ``array`` module. • Bump pinned ``jaxlib`` and ``jax`` versions to 0.3.0. Version 0.0.1 (2021-11-24) ---------------------------- • Initial release. ================================================ FILE: LICENSE ================================================ BSD 3-Clause License Copyright (c) 2021-2025, Los Alamos National Laboratory All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: MANIFEST.in ================================================ include MANIFEST.in include README.md include CHANGES.rst include LICENSE include setup.py include conftest.py include pyproject.toml include pytest.ini include requirements.txt include dev_requirements.txt include docs/docs_requirements.txt recursive-include scico *.py recursive-include scico/data *.png *.mpk *.rst recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.png *.ico recursive-include examples *_requirements.txt *.txt *.rst *.py *.sh recursive-include misc *.py *.sh *.rst ================================================ FILE: README.md ================================================ [![Python \>= 3.8](https://img.shields.io/badge/python-3.8+-green.svg)](https://www.python.org/) [![Package License](https://img.shields.io/github/license/lanl/scico.svg)](https://github.com/lanl/scico/blob/main/LICENSE) [![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Documentation Status](https://readthedocs.org/projects/scico/badge/?version=latest)](http://scico.readthedocs.io/en/latest/?badge=latest) [![JOSS paper](https://joss.theoj.org/papers/10.21105/joss.04722/status.svg)](https://doi.org/10.21105/joss.04722)\ [![Lint status](https://github.com/lanl/scico/actions/workflows/lint.yml/badge.svg)](https://github.com/lanl/scico/actions/workflows/lint.yml) [![Test status](https://github.com/lanl/scico/actions/workflows/pytest_ubuntu.yml/badge.svg)](https://github.com/lanl/scico/actions/workflows/pytest_ubuntu.yml) [![Test coverage](https://codecov.io/gh/lanl/scico/branch/main/graph/badge.svg?token=wQimmjnzFf)](https://codecov.io/gh/lanl/scico) [![CodeFactor](https://www.codefactor.io/repository/github/lanl/scico/badge/main)](https://www.codefactor.io/repository/github/lanl/scico/overview/main)\ [![PyPI package 
version](https://badge.fury.io/py/scico.svg)](https://badge.fury.io/py/scico) [![PyPI download statistics](https://static.pepy.tech/personalized-badge/scico?period=total&left_color=grey&right_color=brightgreen&left_text=downloads)](https://pepy.tech/project/scico) [![Conda Forge Release](https://img.shields.io/conda/vn/conda-forge/scico.svg)](https://anaconda.org/conda-forge/scico) [![Conda Forge Downloads](https://img.shields.io/conda/dn/conda-forge/scico.svg)](https://anaconda.org/conda-forge/scico)\ [![View notebooks at nbviewer](https://raw.githubusercontent.com/jupyter/design/master/logos/Badges/nbviewer_badge.svg)](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb) [![Run notebooks on binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb) [![Run notebooks on google colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb) # Scientific Computational Imaging Code (SCICO) SCICO is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization routines that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. SCICO is built on top of [JAX](https://github.com/google/jax), which provides features such as automatic gradient calculation and GPU acceleration. [Documentation](https://scico.rtfd.io/) is available online. If you use this software for published work, please cite the corresponding [JOSS Paper](https://doi.org/10.21105/joss.04722) (see bibtex entry `balke-2022-scico` in `docs/source/references.bib`). 
# Installation The online documentation includes detailed [installation instructions](https://scico.rtfd.io/en/latest/install.html). # Usage Examples Usage examples are available as Python scripts and Jupyter Notebooks. Example scripts are located in `examples/scripts`. The corresponding Jupyter Notebooks are provided in the [scico-data](https://github.com/lanl/scico-data) submodule and symlinked to `examples/notebooks`. They are also viewable on [GitHub](https://github.com/lanl/scico-data/tree/main/notebooks) or [nbviewer](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb), and can be run online on [binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb) or [google colab](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb). # License SCICO is distributed as open-source software under a BSD 3-Clause License (see the `LICENSE` file for details). LANL open source approval reference C20091. \(c\) 2020-2026. Triad National Security, LLC. All rights reserved. This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. Department of Energy/National Nuclear Security Administration. All rights in the program are reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear Security Administration. The Government has granted for itself and others acting on its behalf a nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so. ================================================ FILE: conftest.py ================================================ """ Configure pytest. 
""" import os import numpy as np import pytest os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" # suppress ray warning try: import ray # noqa: F401 except ImportError: have_ray = False else: have_ray = True ray.init(num_cpus=1) # call required to be here: see ray-project/ray#44087 import jax.numpy as jnp import scico.numpy as snp def pytest_sessionstart(session): """Initialize before start of test session.""" # placeholder: currently unused def pytest_sessionfinish(session, exitstatus): """Clean up after end of test session.""" if have_ray: ray.shutdown() @pytest.fixture(autouse=True) def add_modules(doctest_namespace): """Add common modules for use in docstring examples. Necessary because `np` is used in doc strings for jax functions (e.g. `linear_transpose`) that get pulled into `scico/__init__.py`. Also allow `snp` and `jnp` to be used without explicitly importing. """ doctest_namespace["np"] = np doctest_namespace["snp"] = snp doctest_namespace["jnp"] = jnp ================================================ FILE: dev_requirements.txt ================================================ -r requirements.txt pylint pytest>=7.3.0 pytest-split packaging pre-commit black>=24.3.0 isort autoflake ================================================ FILE: docs/Makefile ================================================ # Makefile for Sphinx documentation # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = ../build/sphinx .PHONY: help clean Makefile # Put this first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) clean: rm -rf $(BUILDDIR)/* rm -f $(SOURCEDIR)/_autosummary/* rm -f $(SOURCEDIR)/examples/* # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @mkdir -p source/examples; \ for f in ../data/notebooks/*; do \ b=$$(basename $$f) ; \ if [ ! -f "source/examples/$$b" ]; then \ echo Creating soft link for notebook $$b ; \ ln -s -t source/examples "../../$$f" ; \ fi \ done $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/docs_requirements.txt ================================================ -r ../requirements.txt sphinx>=5.0.0 sphinxcontrib-napoleon sphinxcontrib-bibtex sphinx-autodoc-typehints furo>=2024.5.6 jinja2<3.1.0 # temporary fix for jinja2/nbconvert bug traitlets!=5.2.2 # temporary fix for ipython/traitlets#741 nbsphinx ipython ipython_genutils py2jn pygraphviz>=1.9 pandoc docutils>=0.18 ================================================ FILE: docs/rtd_requirements.txt ================================================ # nbconvert>=7.5 requires a version of pandoc that is not available # in the readthedocs build environment nbconvert<7.5 ================================================ FILE: docs/source/_static/scico.css ================================================ /* furo theme customization */ body[data-theme="dark"] figure img { filter: invert(100%); } .sidebar-drawer { width: fit-content !important; } .main > .content { min-width: 75%; width: fit-content !important; max-width: 80em; } .highlight { background: #e9efff; } .sidebar-brand-text { font-size: 1.0rem !important; text-align: center; padding-top: 0.5em; } /* Code display section */ div.doctest.highlight-default { background-color: #f9f9f4; } /* Style for autosummary API docs */ [data-theme=light] dl.field-list.simple { background-color: #f5f5f5; border-radius: 4px; } [data-theme=light] dl.field-list.simple > dt.field-odd { background-color: #f2f2f2; border-radius: 4px; } [data-theme=light] dl.field-list.simple > dt.field-even { background-color: #f2f2f2; border-radius: 4px; } [data-theme=light] dl.py.data { background-color: #fdfafa; border-radius: 4px; 
} [data-theme=light] dl.py.data > dt { border-radius: 4px; } [data-theme=light] dl.py.attribute { background-color: #fdfafa; border-radius: 4px; } [data-theme=light] dl.py.attribute > dt { border-radius: 4px; } [data-theme=light] dl.py.function { background-color: #fdfafa; border-radius: 4px; } [data-theme=light] dl.py.function > dt { border-radius: 4px; } [data-theme=light] dl.py.function blockquote { background-color: #f5f5f5; border-left: 0px; } [data-theme=light] dl.py.class { background-color: #fdfafa; border-radius: 4px; } [data-theme=light] dl.py.class > dt { border-radius: 4px; } [data-theme=light] dl.py.method { background-color: #f6f6f6; border-radius: 4px; } [data-theme=light] dl.py.method > dt { border-radius: 4px; } [data-theme=light] dl.py.property { background-color: #f6f6f6; border-radius: 4px; } [data-theme=light] dl.py.property > dt { border-radius: 4px; } /* Style for figure captions */ div.figure p.caption span.caption-text, figcaption span.caption-text { font-size: var(--font-size--small); margin-left: 5%; margin-right: 5%; display: inline-block; text-align: justify; } ================================================ FILE: docs/source/_templates/autosummary/module.rst ================================================ {{ fullname | escape | underline}} .. automodule:: {{ fullname }} {% block attributes %} {% if attributes %} .. rubric:: {{ _('Module Attributes') }} .. autosummary:: {% for item in attributes %} {{ item }} {%- endfor %} {% endif %} {% endblock %} {% block modules %} {% if modules %} .. rubric:: Modules .. autosummary:: :toctree: :recursive: {% for item in modules %} {{ item }} {%- endfor %} {% endif %} {% endblock %} {% block functions %} {% if functions %} .. rubric:: {{ _('Functions') }} .. autosummary:: {% for item in functions %} {{ item }} {%- endfor %} {% endif %} {% endblock %} {% block classes %} {% if classes %} .. rubric:: {{ _('Classes') }} .. 
autosummary:: {% for item in classes %} {{ item }} {%- endfor %} {% endif %} {% endblock %} {% block exceptions %} {% if exceptions %} .. rubric:: {{ _('Exceptions') }} .. autosummary:: {% for item in exceptions %} {{ item }} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/source/_templates/package.rst ================================================ API Reference ============= .. automodule:: {{ fullname }} {% block modules %} {% if modules %} .. autosummary:: :toctree: :recursive: {% for item in modules %} {{ item }} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/source/_templates/sidebar/brand.html ================================================ {#- Hi there! You might be interested in https://pradyunsg.me/furo/customisation/sidebar/ Although if you're reading this, chances are that you're either familiar enough with Sphinx that you know what you're doing, or landed here from that documentation page. Hope your day's going well. :) -#} ================================================ FILE: docs/source/advantages.rst ================================================ Why SCICO? ========== Advantages of JAX-based Design ------------------------------ The vast majority of scientific computing packages in Python are based on `NumPy `__ and `SciPy `__. SCICO, in contrast, is based on `JAX `__, which provides most of the same features, but with the addition of automatic differentiation, GPU support, and just-in-time (JIT) compilation. (The availability of these features in SCICO is subject to some :ref:`caveats `.) SCICO users and developers are advised to become familiar with the `differences between JAX and NumPy. `_. While recent advances in automatic differentiation have primarily been driven by its important role in deep learning, it is also invaluable in a functional minimization framework such as SCICO. 
The most obvious advantage is allowing the use of gradient-based minimization methods without the need for tedious mathematical derivation of an expression for the gradient. Equally valuable, though, is the ability to automatically compute the adjoint operator of a linear operator, the manual derivation of which is often time-consuming. GPU support and JIT compilation both offer the potential for significant code acceleration, with the speed gains that can be obtained depending on the algorithm/function to be executed. In many cases, a speed improvement by an order of magnitude or more can be obtained by running the same code on a GPU rather than a CPU, and similar speed gains can sometimes also be obtained via JIT compilation. The figure below shows timing results obtained on a compute server with an Intel Xeon Gold 6230 CPU and NVIDIA GeForce RTX 2080 Ti GPU. It is interesting to note that for :class:`.FiniteDifference` the GPU provides no acceleration, while JIT provides more than an order of magnitude of speed improvement on both CPU and GPU. For :class:`.DFT` and :class:`.Convolve`, significant JIT acceleration is limited to the GPU, which also provides significant acceleration over the CPU. .. image:: /figures/jax-timing.png :align: center :width: 95% :alt: Timing results for SCICO operators on CPU and GPU with and without JIT Related Packages ---------------- Many elements of SCICO are partially available in other packages. We briefly review them here, highlighting some of the main differences with SCICO. `GlobalBioIm `__ is similar in structure to SCICO (and a major inspiration for SCICO), providing linear operators and solvers for inverse problems in imaging. However, it is written in MATLAB and is thus not usable in a completely free environment. It also lacks the automatic adjoint calculation and simple GPU support offered by SCICO. `PyLops `__ provides a linear operator class and many built-in linear operators. 
These operators are compatible with many `SciPy `__ solvers. GPU support is provided via `CuPy `__, which has the disadvantage that switching from CPU to GPU requires code changes, unlike SCICO and `JAX `__. SCICO is more focused on computational imaging than PyLops and has several specialized operators that PyLops does not. `Pycsou `__, like SCICO, is a Python project inspired by GlobalBioIm. Since it is based on PyLops, it shares that project's disadvantages relative to SCICO. `ODL `__ provides a variety of operators and related infrastructure for prototyping of inverse problems. It is built on top of `NumPy `__/`SciPy `__, and does not support any of the advanced features of `JAX `__. `ProxImaL `__ is a Python package for image optimization problems. Like SCICO and many of the other projects listed here, problems are specified by combining objects representing operators, functionals, and solvers. It does not support any of the advanced features of `JAX `__. `ProxMin `__ provides a set of proximal optimization algorithms for minimizing non-smooth functionals. It is built on top of `NumPy `__/`SciPy `__, and does not support any of the advanced features of `JAX `__ (however, an open issue suggests that `JAX `__ compatibility is planned). `CVXPY `__ provides a flexible language for defining optimization problems and a wide selection of solvers, but has limited support for matrix-free methods. Other related projects that may be of interest include: - `ToMoBAR `__ - `CCPi-Regularisation Toolkit `__ - `SPORCO `__ - `SigPy `__ - `MIRT `__ - `BART `__ ================================================ FILE: docs/source/api.rst ================================================ :orphan: API Documentation ================= .. autosummary:: :toctree: _autosummary :template: package.rst :caption: API Reference :recursive: scico ================================================ FILE: docs/source/classes.rst ================================================ ..
_classes: ****************** Main SCICO Classes ****************** .. include:: include/blockarray.rst .. include:: include/operator.rst .. include:: include/functional.rst .. include:: include/optimizer.rst .. include:: include/learning.rst ================================================ FILE: docs/source/conf/10-project.py ================================================ from scico._version import package_version # General information about the project. project = "SCICO" copyright = "2020-2026, SCICO Developers" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = package_version() # The full version, including alpha/beta/rc tags. release = version ================================================ FILE: docs/source/conf/15-theme.py ================================================ # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "python_docs_theme" html_theme = "furo" html_theme_options = { "top_of_page_buttons": [], # "sidebar_hide_name": True, } if html_theme == "python_docs_theme": html_sidebars = { "**": ["globaltoc.html", "sourcelink.html", "searchbox.html"], } # These folders are copied to the documentation's HTML output html_static_path = ["_static"] # These paths are either relative to html_static_path or fully qualified # paths (eg. https://...) html_css_files = [ "scico.css", "http://netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css", ] # The name of an image file (relative to this directory) to place at the top # of the sidebar. html_logo = "_static/logo.svg" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. 
html_favicon = "_static/scico.ico" ================================================ FILE: docs/source/conf/20-extensions.py ================================================ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.napoleon", "sphinx.ext.autodoc", "sphinx_autodoc_typehints", "sphinx.ext.autosummary", "sphinx.ext.doctest", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinxcontrib.bibtex", "sphinx.ext.inheritance_diagram", "matplotlib.sphinxext.plot_directive", "sphinx.ext.todo", "nbsphinx", ] bibtex_bibfiles = ["references.bib"] ================================================ FILE: docs/source/conf/25-napoleon.py ================================================ from sphinx.ext.napoleon.docstring import GoogleDocstring ## See ## https://github.com/sphinx-doc/sphinx/issues/2115 ## https://michaelgoerz.net/notes/extending-sphinx-napoleon-docstring-sections.html ## # first, we define new methods for any new sections and add them to the class def parse_keys_section(self, section): return self._format_fields("Keys", self._consume_fields()) GoogleDocstring._parse_keys_section = parse_keys_section def parse_attributes_section(self, section): return self._format_fields("Attributes", self._consume_fields()) GoogleDocstring._parse_attributes_section = parse_attributes_section def parse_class_attributes_section(self, section): return self._format_fields("Class Attributes", self._consume_fields()) GoogleDocstring._parse_class_attributes_section = parse_class_attributes_section # we now patch the parse method to guarantee that the above methods are # registered in the _sections dict before the original _parse runs def patched_parse(self): self._sections["keys"] = self._parse_keys_section self._sections["class attributes"] = self._parse_class_attributes_section self._unpatched_parse() GoogleDocstring._unpatched_parse = GoogleDocstring._parse GoogleDocstring._parse =
patched_parse # napoleon_include_init_with_doc = True napoleon_use_ivar = True napoleon_use_rtype = False # See https://github.com/sphinx-doc/sphinx/issues/9119 # napoleon_custom_sections = [("Returns", "params_style")] ================================================ FILE: docs/source/conf/30-autodoc.py ================================================ autodoc_default_options = { "member-order": "bysource", "inherited-members": False, "ignore-module-all": False, "show-inheritance": True, "members": True, "special-members": "__call__", } autodoc_docstring_signature = True autoclass_content = "both" # See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports autodoc_mock_imports = ["astra", "svmbir", "ray"] # See # https://stackoverflow.com/questions/2701998#62613202 # https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion autosummary_generate = True # See https://stackoverflow.com/questions/5599254 autoclass_content = "both" ================================================ FILE: docs/source/conf/40-intersphinx.py ================================================ # Intersphinx mapping intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "numpy": ("https://numpy.org/doc/stable/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "matplotlib": ("https://matplotlib.org/stable/", None), "jax": ("https://docs.jax.dev/en/latest/", None), "flax": ("https://flax.readthedocs.io/en/latest/", None), "ray": ("https://docs.ray.io/en/latest/", None), "svmbir": ("https://svmbir.readthedocs.io/en/latest/", None), } # Added timeout due to periodic scipy.org down time # intersphinx_timeout = 30 ================================================ FILE: docs/source/conf/45-mathjax.py ================================================ import os if os.environ.get("NO_MATHJAX"): extensions.append("sphinx.ext.imgmath") imgmath_image_format = "svg" else: extensions.append("sphinx.ext.mathjax") # To use local copy of 
MathJax for offline use, set MATHJAX_URI to # file:///[path-to-mathjax-repo-root]/es5/tex-mml-chtml.js if os.environ.get("MATHJAX_URI"): mathjax_path = os.environ.get("MATHJAX_URI") mathjax3_config = { "tex": { "macros": { "mb": [r"\mathbf{#1}", 1], "mbs": [r"\boldsymbol{#1}", 1], "mbb": [r"\mathbb{#1}", 1], "norm": [r"\lVert #1 \rVert", 1], "abs": [r"\left| #1 \right|", 1], "argmin": [r"\mathop{\mathrm{argmin}}"], "sign": [r"\mathop{\mathrm{sign}}"], "prox": [r"\mathrm{prox}"], "det": [r"\mathrm{det}"], "exp": [r"\mathrm{exp}"], "loss": [r"\mathop{\mathrm{loss}}"], "kp": [r"k_{\|}"], "rp": [r"r_{\|}"], } } } ================================================ FILE: docs/source/conf/50-graphviz.py ================================================ graphviz_output_format = "svg" inheritance_graph_attrs = dict(rankdir="LR", fontsize=9, ratio="compress", bgcolor="transparent") inheritance_edge_attrs = dict( color='"#2962ffff"', ) inheritance_node_attrs = dict( shape="box", fontsize=9, height=0.4, margin='"0.08, 0.03"', style='"rounded,filled"', color='"#2962ffff"', fontcolor='"#2962ffff"', fillcolor='"#f0f0f8b0"', ) ================================================ FILE: docs/source/conf/55-nbsphinx.py ================================================ nbsphinx_prolog = """ .. 
raw:: html """ nbsphinx_execute = "never" ================================================ FILE: docs/source/conf/60-rtd.py ================================================ import os on_rtd = os.environ.get("READTHEDOCS") == "True" if on_rtd: print("Building on ReadTheDocs\n") print(" current working directory: {}".format(os.path.abspath(os.curdir))) print(" rootpath: %s" % rootpath) print(" confpath: %s" % confpath) html_static_path = [] # See https://about.readthedocs.com/blog/2024/07/addons-by-default/#how-to-opt-in-to-addons-now html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "") if "html_context" not in globals(): html_context = {} html_context["READTHEDOCS"] = True import matplotlib matplotlib.use("agg") else: # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] ================================================ FILE: docs/source/conf/70-latex.py ================================================ # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). 
latex_documents = [ ("index", "scico.tex", "SCICO Documentation", "The SCICO Developers", "manual"), ] latex_engine = "xelatex" # latex_use_xindy = False # mathjax3_config must already be defined latex_macros = [] for k, v in mathjax3_config["tex"]["macros"].items(): if len(v) == 1: latex_macros.append(r"\newcommand{\%s}{%s}" % (k, v[0])) else: latex_macros.append(r"\newcommand{\%s}[1]{%s}" % (k, v[0])) imgmath_latex_preamble = "\n".join(latex_macros) latex_elements = {"preamble": "\n".join(latex_macros)} ================================================ FILE: docs/source/conf/71-texinfo.py ================================================ # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ( "index", "SCICO", "SCICO Documentation", "SCICO Developers", "SCICO", "Scientific Computational Imaging COde (SCICO)", "Miscellaneous", ), ] ================================================ FILE: docs/source/conf/72-man_page.py ================================================ # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [("index", "scico", "SCICO Documentation", ["SCICO Developers"], 1)] # If true, show URL addresses after external links. # man_show_urls = False ================================================ FILE: docs/source/conf/80-scico_numpy.py ================================================ import re from inspect import getmembers, isfunction # Rewrite module names for certain functions imported into scico.numpy so that they are # included in the docs for that module. 
While a bit messy to do so here rather than in a # function run via app.connect, it is necessary (for some yet to be identified reason) # to do it here to ensure that the relevant API docs include a table of functions. import scico.numpy for module in (scico.numpy, scico.numpy.fft, scico.numpy.linalg, scico.numpy.testing): for _, f in getmembers(module, isfunction): # Rewrite module name so that function is included in docs f.__module__ = module.__name__ f.__doc__ = re.sub( r"^:func:`([\w_]+)` wrapped to operate", r":obj:`jax.numpy.\1` wrapped to operate", str(f.__doc__), flags=re.M, ) modname = ".".join(module.__name__.split(".")[1:]) f.__doc__ = re.sub( r"^LAX-backend implementation of :func:`([\w_]+)`.", r"LAX-backend implementation of :obj:`%s.\1`." % modname, str(f.__doc__), flags=re.M, ) # Improve formatting of jax.numpy warning f.__doc__ = re.sub( r"^\*\*\* This function is not yet implemented by jax.numpy, and will " r"raise NotImplementedError \*\*\*", "**WARNING**: This function is not yet implemented by jax.numpy, " " and will raise :exc:`NotImplementedError`.", f.__doc__, flags=re.M, ) # Remove cross-references to section NEP35 f.__doc__ = re.sub(":ref:`NEP 35 `", "NEP 35", f.__doc__, re.M) # Remove cross-reference to numpydoc style references section f.__doc__ = re.sub(r" \[(\d+)\]_", "", f.__doc__, flags=re.M) # Remove entire numpydoc references section f.__doc__ = re.sub(r"References\n----------\n.*\n", "", f.__doc__, flags=re.DOTALL) # Fix various docstring formatting errors scico.numpy.testing.break_cycles.__doc__ = re.sub( "calling gc.collect$", "calling gc.collect.\n\n", scico.numpy.testing.break_cycles.__doc__, flags=re.M, ) scico.numpy.testing.break_cycles.__doc__ = re.sub( r" __del__\) inside", r"__del__\) inside", scico.numpy.testing.break_cycles.__doc__, flags=re.M ) scico.numpy.testing.assert_raises_regex.__doc__ = re.sub( r"\*args,\n.*\*\*kwargs", "*args, **kwargs", scico.numpy.testing.assert_raises_regex.__doc__, flags=re.M, ) 
scico.numpy.BlockArray.global_shards.__doc__ = re.sub( r"`Shard`s", r"`Shard`\ s", scico.numpy.BlockArray.global_shards.__doc__, flags=re.M ) ================================================ FILE: docs/source/conf/81-scico_scipy.py ================================================ import re from inspect import getmembers, isfunction # Similar processing for scico.scipy import scico.scipy ssp_func = getmembers(scico.scipy.special, isfunction) for _, f in ssp_func: if f.__module__[0:11] == "scico.scipy" or f.__module__[0:14] == "jax._src.scipy": # Rewrite module name so that function is included in docs f.__module__ = "scico.scipy.special" # Attempt to fix incorrect cross-reference f.__doc__ = re.sub( r"^:func:`([\w_]+)` wrapped to operate", r":obj:`jax.scipy.special.\1` wrapped to operate", str(f.__doc__), flags=re.M, ) modname = "scipy.special" f.__doc__ = re.sub( r"^LAX-backend implementation of :func:`([\w_]+)`.", r"LAX-backend implementation of :obj:`%s.\1`." % modname, str(f.__doc__), flags=re.M, ) # Remove cross-reference to numpydoc style references section f.__doc__ = re.sub(r"(^|\ )\[(\d+)\]_", "", f.__doc__, flags=re.M) # Remove entire numpydoc references section f.__doc__ = re.sub(r"References\n----------\n.*\n", "", f.__doc__, flags=re.DOTALL) # Remove problematic citation f.__doc__ = re.sub(r"See \[dlmf\]_ for details.", "", f.__doc__, re.M) f.__doc__ = re.sub(r"\[dlmf\]_", "NIST DLMF", f.__doc__, re.M) # Fix indentation problems if hasattr(scico.scipy.special, "sph_harm"): scico.scipy.special.sph_harm.__doc__ = re.sub( "^Computes the", " Computes the", scico.scipy.special.sph_harm.__doc__, flags=re.M ) ================================================ FILE: docs/source/conf/85-dtype_typehints.py ================================================ from typing import Optional, Sequence, Union # needed for typehints_formatter hack from scico.typing import ( # needed for typehints_formatter hack ArrayIndex, AxisIndex, DType, ) # An explanation for this nasty 
hack, the primary purpose of which is to avoid # the very long definition of the scico.typing.DType appearing explicitly in the # docs. This is handled correctly by sphinx.ext.autodoc in some circumstances, # but only when sphinx_autodoc_typehints is not included in the extension list, # and the appearance of the type hints (e.g. whether links to definitions are # included) seems to depend on whether "from __future__ import annotations" was # used in the module being documented, which is not ideal from a consistency # perspective. (It's also worth noting that sphinx.ext.autodoc provides some # configurability for type aliases via the autodoc_type_aliases sphinx # configuration option.) The alternative is to include sphinx_autodoc_typehints, # which gives a consistent appearance to the type hints, but the # autodoc_type_aliases configuration option is ignored, and type aliases are # always expanded. This hack avoids expansion for the type aliases with the # longest definitions by definining a custom function for formatting the # type hints, using an option provided by sphinx_autodoc_typehints. For # more information, see # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/284 # https://github.com/tox-dev/sphinx-autodoc-typehints/blob/main/README.md def typehints_formatter_function(annotation, config): markup = { DType: ":obj:`~scico.typing.DType`", # Compound types involving DType must be added here to avoid their DType # component being expanded in the docs. 
Optional[DType]: r":obj:`~typing.Optional`\ [\ :obj:`~scico.typing.DType`\ ]", Union[DType, Sequence[DType]]: ( r":obj:`~typing.Union`\ [\ :obj:`~scico.typing.DType`\ , " r":obj:`~typing.Sequence`\ [\ :obj:`~scico.typing.DType`\ ]]" ), AxisIndex: ":obj:`~scico.typing.AxisIndex`", ArrayIndex: ":obj:`~scico.typing.ArrayIndex`", } if annotation in markup: return markup[annotation] else: return None typehints_formatter = typehints_formatter_function ================================================ FILE: docs/source/conf.py ================================================ # -*- coding: utf-8 -*- import os import sys confpath = os.path.dirname(__file__) sys.path.append(confpath) rootpath = os.path.realpath(os.path.join(confpath, "..", "..")) sys.path.append(rootpath) from docsutil import insert_inheritance_diagram, package_classes, run_conf_files # Process settings in files in conf directory _vardict = run_conf_files(vardict={"confpath": confpath, "rootpath": rootpath}) for _k, _v in _vardict.items(): globals()[_k] = _v del _vardict, _k, _v # If your documentation needs a minimal Sphinx version, state it here. needs_sphinx = "5.0.0" # The suffix of source filenames. source_suffix = ".rst" # The encoding of source files. source_encoding = "utf-8" # The master toctree document. master_doc = "index" # Output file base name for HTML help builder. htmlhelp_basename = "SCICOdoc" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ["_build", "**tests**", "**README.rst", "include"] # If true, '()' will be appended to :func: etc. cross-reference text. add_function_parentheses = False # The name of the Pygments (syntax highlighting) style to use. 
pygments_style = "sphinx" # Include TODOs todo_include_todos = True def class_inherit_diagrams(_): # Insert inheritance diagrams for classes that have base classes import scico custom_parts = {"scico.ray.tune.Tuner": 4} clslst = package_classes(scico) for cls in clslst: insert_inheritance_diagram(cls, parts=custom_parts) def process_docstring(app, what, name, obj, options, lines): # Don't show docs for inherited members in classes in scico.flax. # This is primarily useful for silencing warnings due to problems in # the current release of flax, but is arguably also useful in avoiding # extensive documentation of methods that are likely to be of limited # interest to users of the scico.flax classes. # # Note: this event handler currently has no effect since inclusion of # inherited members is currently globally disabled (see # "inherited-members" in autodoc_default_options), but is left in # place in case a decision is ever made to revert the global setting. # # See https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html # for documentation of the autodoc-process-docstring event used here. if what == "class" and "scico.flax." in name: options["inherited-members"] = False def setup(app): app.connect("builder-inited", class_inherit_diagrams) app.connect("autodoc-process-docstring", process_docstring) ================================================ FILE: docs/source/contributing.rst ================================================ .. _scico_dev_contributing: Contributing ============ Contributions to SCICO are welcome. Before starting work, please contact the maintainers, either via email or the GitHub issue system, to discuss the relevance of your contribution and the most appropriate location within the existing package structure. .. _installing_dev: Installing a Development Version -------------------------------- 1. Fork both the ``scico`` and ``scico-data`` repositories, creating copies of these repositories in your own git account. 2. 
Make sure that you have Python 3.10 or later installed in order to create a conda virtual environment. 3. Clone your fork from the source repo. :: git clone --recurse-submodules git@github.com:/scico.git 4. Create a conda environment using Python 3.10 or later, e.g.: :: conda create -n scico python=3.12 5. Activate the created conda virtual environment: :: conda activate scico 6. Change directory to the root of the cloned repository: :: cd scico 7. Add the ``scico`` repo as an upstream remote to sync your changes: :: git remote add upstream https://www.github.com/lanl/scico 8. After adding the upstream, the recommended way to install SCICO and its dependencies is via pip: :: pip install -r requirements.txt # Installs basic requirements pip install -r dev_requirements.txt # Installs developer requirements pip install -r docs/docs_requirements.txt # Installs documentation requirements pip install -e . # Installs SCICO from the current directory in editable mode For installing dependencies related to the examples please see :ref:`example_notebooks`. Installing these is necessary for the tests to run successfully. 9. The SCICO project uses the `black `_, `isort `_ and `pylint `_ code formatting and linting utilities. It is important to set up a `pre-commit hook `_ to ensure that any modified code passes the format checks before it is committed to the development repo: :: pre-commit install # Sets up git pre-commit hooks It is also recommended to `pin the conda package version `__ of `black `_ to the version number specified in ``dev_requirements.txt``. 10. For testing see `Tests`_.
Building Documentation ---------------------- Before building the documentation, one must install the documentation-specific dependencies by running :: pip install -r docs/docs_requirements.txt Then, a local copy of the documentation can be built from the repository root directory by running :: python setup.py build_sphinx Alternatively, the documentation can be built by running the following from the `docs/` directory :: make html Contributing Code ----------------- - New features / bugs / documentation are *always* developed in separate branches. - Branches should be named in the form `/`, where `` provides a highly condensed description of the purpose of the branch (e.g. `address_todo`), and may include an issue number if appropriate (e.g. `fix_223`). A feature development workflow might look like this: 1. Follow the instructions in `Installing a Development Version`_. 2. Sync with the upstream repository: :: git pull --rebase origin main --recurse-submodules 3. Create a branch to develop from: :: git checkout -b / 4. Make your desired changes. 5. Run the test suite: :: pytest You can limit the test suite to a specific file, for example: :: pytest scico/test/test_blockarray.py 6. When you are finished making changes, create a new commit: :: git add file1.py git add file2.py git commit -m "A good commit message" If you have added or modified an example script, see `Usage Examples`_. If your contribution involves any significant new features or changes, add a corresponding entry to the change summary for the next release in the ``CHANGES.rst`` file. 7. Sync with the upstream repository: :: git fetch upstream git rebase upstream/main 8. Push your development upstream: :: git push --set-upstream origin / 9. Create a new pull request to the ``main`` branch; see `the GitHub instructions `_. 10. The SCICO maintainers will review and merge your PR. The SCICO project recommends the ``squash and merge`` option for merging PRs. 11.
Delete the branch after it has been merged. Adding Data ----------- The following steps show how to add new data, ``new_data.npz``, to the packaged data. We assume the ``scico`` repository has been cloned to ``scico/``. Note that the data is located in the ``scico-data`` submodule, which is attached to the main `scico` repository via the directory ``scico/data`` (i.e. the ``data/`` subdirectory of the repository root directory, *not* the ``scico/data`` subdirectory of the repository root directory). When adding new data, both the ``scico`` and ``scico-data`` repositories must be updated and kept in sync. 1. Create new branches in the main ``scico`` repository as well as in the submodule corresponding to the ``scico-data`` repository (which can be achieved by following the usual branch creation procedure after changing the current directory to ``scico/data``). 2. Add the ``new_data.npz`` file to the appropriate subdirectory (creating a new one if necessary) of the ``scico/data`` directory. 3. Change directory to this directory (taken to be ``scico/data/flax`` for the purposes of this example) and add/commit the new data file: :: cd scico/data/flax git add new_data.npz git commit -m "Add new data file" 4. Return to the ``scico`` repository root directory, add/commit the new data, and update submodule: :: cd ../.. # pwd now `scico` repo root git add data git commit -m "Add data and update data module" 5. Push both repositories: :: git submodule foreach --recursive 'git push' && git push Type Checking ------------- All code is required to pass ``mypy`` type checking. Install ``mypy``: :: conda install mypy To run the type checker, execute the following from the scico repository root: :: mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/ Tests ----- All functions and classes should have corresponding ``pytest`` unit tests. 
Running Tests ^^^^^^^^^^^^^ To be able to run the tests, install ``pytest`` and, optionally, ``pytest-runner``: :: conda install pytest pytest-runner The tests can be run by :: pytest or (if ``pytest-runner`` is installed) :: python setup.py test from the ``scico`` repository root directory. Tests can be run in an installed version of ``scico`` by :: pytest --pyargs scico When any significant changes are made to the test suite, the ``pytest-split`` test time database files in ``data/pytest`` should be updated using :: pytest --store-durations --durations-path data/pytest/durations_ubuntu --level 2 (for Ubuntu CI), and :: pytest --store-durations --durations-path data/pytest/durations_macos --level 1 (for MacOS CI). These updated files should be bzipped and committed into the ``scico-data`` repository, replacing the current versions. Test Coverage ^^^^^^^^^^^^^ Test coverage is a measure of the fraction of the package code that is exercised by the tests. While this should not be the primary criterion in designing tests, it is a useful tool for finding obvious areas of omission. To be able to check test coverage, install ``coverage``: :: conda install coverage A coverage report can be obtained by :: coverage run coverage report Usage Examples -------------- New usage examples should adhere to the same general structure as the existing examples to ensure that the mechanism for automatically generating corresponding Jupyter notebooks functions correctly. In particular: 1. The initial lines of the script should consist of a comment block, followed by a blank line, followed by a multiline string with an RST heading on the first line, e.g., :: #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ Script Title ============ Script description. """ 2. 
The final line of the script is an ``input`` statement intended to avoid the script terminating immediately, thereby closing all figures: :: input("\nWaiting for input to close figures and exit") 3. Citations are included using the standard `Sphinx `__ ``:cite:`cite-key``` syntax, where ``cite-key`` is the key of an entry in ``docs/source/references.bib``. 4. Cross-references to other components of the documentation are included using the syntax described in the `nbsphinx documentation `__. 5. External links are included using Markdown syntax ``[link text](url)``. 6. When constructing a synthetic image/volume for use in the example, define a global variable `N` that controls the size of the problem, and where relevant, define a global variable `maxiter` that controls the number of iterations of optimization algorithms such as ADMM. Adhering to this convention allows the ``examples/scriptcheck.sh`` utility to automatically construct less computationally expensive versions of the example scripts for testing that they run without any errors. Adding new examples ^^^^^^^^^^^^^^^^^^^ The following steps show how to add a new example, ``new_example.py``, to the packaged usage examples. We assume the ``scico`` repository has been cloned to ``scico/``. Note that the ``.py`` scripts are included in ``scico/examples/scripts``, while the compiled Jupyter Notebooks are located in the scico-data submodule, which is symlinked to ``scico/data``. When adding a new usage example, both the ``scico`` and ``scico-data`` repositories must be updated and kept in sync. .. warning:: Ensure that all binary data (including raw data, images, ``.ipynb`` files) are added to ``scico-data``, not the main ``scico`` repo. 1. Create new branches in the main `scico` repository as well as in the submodule corresponding to the `scico-data` repository (which can be achieved by following the usual branch creation procedure after changing the current directory to ``scico/data``). 2. 
Add the ``new_example.py`` script to the ``scico/examples/scripts`` directory. 3. Add the basename of the script (i.e., without the pathname; in this case, ``new_example.py``) to the appropriate section of ``examples/scripts/index.rst``. 4. Convert your new example to a Jupyter notebook by changing directory to the ``scico/examples`` directory and following the instructions in ``scico/examples/README.rst``. 5. Change directory to the ``data`` directory and add/commit the new Jupyter Notebook: :: cd scico/data git add notebooks/new_example.ipynb git commit -m "Add new usage example" 6. Return to the main ``scico`` repository root directory, ensure the ``main`` branch is checked out, add/commit the new script and updated submodule: :: cd .. # pwd now `scico` repo root git add data git add examples/scripts/new_filename.py git commit -m "Add usage example and update data module" 7. Push both repositories: :: git submodule foreach --recursive 'git push' && git push ================================================ FILE: docs/source/docsutil.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for building docs.""" import importlib import inspect import os import pkgutil import sys from glob import glob from runpy import run_path def run_conf_files(vardict=None, path=None): """Execute Python files in conf directory. Args: vardict: Dictionary into which variable names should be inserted. Defaults to empty dict. path: Path to conf directory. Defaults to path to this module. Returns: A dict populated with variables defined during execution of the configuration files. 
""" if vardict is None: vardict = {} if path is None: path = os.path.dirname(__file__) files = os.path.join(path, "conf", "*.py") for f in sorted(glob(files)): conf = run_path(f, init_globals=vardict) for k, v in conf.items(): if len(k) >= 4 and k[0:2] == "__" and k[-2:] == "__": # ignore ____ variables continue vardict[k] = v return vardict def package_classes(package): """Get a list of classes in a package. Return a list of qualified names of classes in the specified package. Classes in modules with names beginning with an "_" are omitted, as are classes whose internal module name record is not the same as the module in which they are found (i.e. indicating that they have been imported from elsewhere). Args: package: Reference to package for which classes are to be listed (not package name string). Returns: A list of qualified names of classes in the specified package. """ classes = [] # Iterate over modules in package for importer, modname, _ in pkgutil.walk_packages( path=package.__path__, prefix=(package.__name__ + "."), onerror=lambda x: None ): # Skip modules whose names begin with a "_" if modname.split(".")[-1][0] == "_": continue importlib.import_module(modname) # Iterate over module members for name, obj in inspect.getmembers(sys.modules[modname]): if inspect.isclass(obj): # Get internal module name of class for comparison with working module name try: objmodname = getattr(sys.modules[modname], obj.__name__).__module__ except Exception: objmodname = None if objmodname == modname: classes.append(modname + "." + obj.__name__) return classes def get_text_indentation(text, skiplines=0): """Compute the leading whitespace indentation in a block of text. Args: text: A block of text as a string. Returns: Indentation length. 
""" min_indent = len(text) lines = text.splitlines() if len(lines) > skiplines: lines = lines[skiplines:] else: return None for line in lines: if len(line) > 0: indent = len(line) - len(line.lstrip()) if indent < min_indent: min_indent = indent return min_indent def add_text_indentation(text, indent): """Insert leading whitespace into a block of text. Args: text: A block of text as a string. indent: Number of leading spaces to insert on each line. Returns: Text with additional indentation. """ lines = text.splitlines() for n, line in enumerate(lines): if len(line) > 0: lines[n] = (" " * indent) + line return "\n".join(lines) def insert_inheritance_diagram(clsqname, parts=None, default_nparts=2): """Insert an inheritance diagram into a class docstring. No action is taken for classes without a base clase, and for classes without a docstring. Args: clsqname: Qualified name (i.e. including module name path) of class. parts: A dict mapping qualified class names to custom values for the ":parts:" directive. default_nparts: Default value for the ":parts:" directive. """ # Extract module name and class name from qualified class name clspth = clsqname.split(".") modname = ".".join(clspth[0:-1]) clsname = clspth[-1] # Get reference to class cls = getattr(sys.modules[modname], clsname) # Return immediately if class has no base classes if getattr(cls, "__bases__") == (object,): return # Get current docstring docstr = getattr(cls, "__doc__") # Return immediately if class has no docstring if docstr is None: return # Use class-specific parts or default parts directive value if parts and clsqname in parts: nparts = parts[clsqname] else: nparts = default_nparts # Split docstring into individual lines lines = docstr.splitlines() # Return immediately if there are no lines if not lines: return # Cut leading whitespace lines n = 0 for n, line in enumerate(lines): if line != "": break lines = lines[n:] # Define inheritance diagram insertion text idstr = f""" .. 
inheritance-diagram:: {clsname} :parts: {nparts} """ docstr_indent = get_text_indentation(docstr, skiplines=1) if docstr_indent is not None and docstr_indent > 4: idstr = add_text_indentation(idstr, docstr_indent - 4) # Insert inheritance diagram after summary line and whitespace line following it lines.insert(2, idstr) # Construct new docstring and attach it to the class extdocstr = "\n".join(lines) setattr(cls, "__doc__", extdocstr) ================================================ FILE: docs/source/examples.rst ================================================ .. _example_notebooks: Usage Examples ============== .. toctree:: :maxdepth: 1 .. include:: include/examplenotes.rst Organized by Application ------------------------ .. toctree:: :maxdepth: 1 Computed Tomography ^^^^^^^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/ct_abel_tv_admm examples/ct_abel_tv_admm_tune examples/ct_symcone_tv_padmm examples/ct_astra_noreg_pcg examples/ct_astra_3d_tv_admm examples/ct_astra_3d_tv_padmm examples/ct_tv_admm examples/ct_astra_tv_admm examples/ct_multi_tv_admm examples/ct_astra_weighted_tv_admm examples/ct_svmbir_tv_multi examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox examples/ct_modl_train_foam2 examples/ct_odp_train_foam2 examples/ct_unet_train_foam2 examples/ct_projector_comparison_2d examples/ct_projector_comparison_3d Deconvolution ^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/deconv_circ_tv_admm examples/deconv_tv_admm examples/deconv_tv_padmm examples/deconv_tv_admm_tune examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/deconv_ppp_bm3d_admm examples/deconv_ppp_bm3d_apgm examples/deconv_ppp_dncnn_admm examples/deconv_ppp_dncnn_padmm examples/deconv_ppp_bm4d_admm examples/deconv_modl_train_foam1 examples/deconv_odp_train_foam1 Sparse Coding ^^^^^^^^^^^^^ .. 
toctree:: :maxdepth: 1 examples/sparsecode_nn_admm examples/sparsecode_nn_apgm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm examples/sparsecode_apgm examples/sparsecode_poisson_apgm Miscellaneous ^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/demosaic_ppp_bm3d_admm examples/superres_ppp_dncnn_admm examples/denoise_l1tv_admm examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi examples/denoise_approx_tv_multi examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg examples/denoise_dncnn_universal examples/diffusercam_tv_admm examples/video_rpca_admm examples/ct_datagen_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/denoise_datagen_bsds Organized by Regularization --------------------------- .. toctree:: :maxdepth: 1 Plug and Play Priors ^^^^^^^^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox examples/deconv_ppp_bm3d_admm examples/deconv_ppp_bm3d_apgm examples/deconv_ppp_dncnn_admm examples/deconv_ppp_dncnn_padmm examples/deconv_ppp_bm4d_admm examples/demosaic_ppp_bm3d_admm examples/superres_ppp_dncnn_admm Total Variation ^^^^^^^^^^^^^^^ .. 
toctree:: :maxdepth: 1 examples/ct_abel_tv_admm examples/ct_abel_tv_admm_tune examples/ct_symcone_tv_padmm examples/ct_tv_admm examples/ct_multi_tv_admm examples/ct_astra_tv_admm examples/ct_astra_3d_tv_admm examples/ct_astra_3d_tv_padmm examples/ct_astra_weighted_tv_admm examples/ct_svmbir_tv_multi examples/deconv_circ_tv_admm examples/deconv_tv_admm examples/deconv_tv_admm_tune examples/deconv_tv_padmm examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/denoise_l1tv_admm examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi examples/denoise_approx_tv_multi examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg examples/diffusercam_tv_admm Sparsity ^^^^^^^^ .. toctree:: :maxdepth: 1 examples/diffusercam_tv_admm examples/sparsecode_nn_admm examples/sparsecode_nn_apgm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm examples/sparsecode_apgm examples/sparsecode_poisson_apgm examples/video_rpca_admm Machine Learning ^^^^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/ct_datagen_foam2 examples/ct_modl_train_foam2 examples/ct_odp_train_foam2 examples/ct_unet_train_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/deconv_modl_train_foam1 examples/deconv_odp_train_foam1 examples/denoise_datagen_bsds examples/denoise_dncnn_train_bsds examples/denoise_dncnn_universal Organized by Optimization Algorithm ----------------------------------- .. toctree:: :maxdepth: 1 ADMM ^^^^ .. 
toctree:: :maxdepth: 1 examples/ct_abel_tv_admm examples/ct_abel_tv_admm_tune examples/ct_symcone_tv_padmm examples/ct_astra_tv_admm examples/ct_tv_admm examples/ct_astra_3d_tv_admm examples/ct_astra_weighted_tv_admm examples/ct_multi_tv_admm examples/ct_svmbir_tv_multi examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox examples/deconv_circ_tv_admm examples/deconv_tv_admm examples/deconv_tv_admm_tune examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/deconv_ppp_bm3d_admm examples/deconv_ppp_dncnn_admm examples/deconv_ppp_bm4d_admm examples/diffusercam_tv_admm examples/sparsecode_nn_admm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm examples/demosaic_ppp_bm3d_admm examples/superres_ppp_dncnn_admm examples/denoise_l1tv_admm examples/denoise_tv_admm examples/denoise_tv_multi examples/denoise_approx_tv_multi examples/video_rpca_admm Linearized ADMM ^^^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/ct_svmbir_tv_multi examples/denoise_tv_multi Proximal ADMM ^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/ct_astra_3d_tv_padmm examples/deconv_tv_padmm examples/denoise_tv_multi examples/deconv_ppp_dncnn_padmm Non-linear Proximal ADMM ^^^^^^^^^^^^^^^^^^^^^^^^ .. toctree:: :maxdepth: 1 examples/denoise_cplx_tv_nlpadmm PDHG ^^^^ .. toctree:: :maxdepth: 1 examples/ct_svmbir_tv_multi examples/denoise_ptv_pdhg examples/denoise_tv_multi examples/denoise_cplx_tv_pdhg PGM ^^^ .. toctree:: :maxdepth: 1 examples/deconv_ppp_bm3d_apgm examples/sparsecode_apgm examples/sparsecode_nn_apgm examples/sparsecode_poisson_apgm examples/denoise_tv_apgm examples/denoise_approx_tv_multi PCG ^^^ .. toctree:: :maxdepth: 1 examples/ct_astra_noreg_pcg ================================================ FILE: docs/source/include/blockarray.rst ================================================ .. _blockarray_class: BlockArray ========== .. 
testsetup:: >>> import numpy as np >>> import scico >>> import scico.random >>> import scico.linop >>> import scico.numpy as snp >>> from scico.numpy import BlockArray The class :class:`.BlockArray` provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. A :class:`.BlockArray` consists of a list of :class:`jax.Array` objects, which we refer to as blocks. A :class:`.BlockArray` differs from a list in that, whenever possible, :class:`.BlockArray` properties and methods (including unary and binary operators like +, -, \*, ...) automatically map along the blocks, returning another :class:`.BlockArray` or tuple as appropriate. For example, :: >>> x = snp.blockarray(( ... [[1, 3, 7], ... [2, 2, 1]], ... [2, 4, 8] ... )) >>> x.shape # returns tuple ((2, 3), (3,)) >>> x * 2 # returns BlockArray # doctest: +ELLIPSIS BlockArray([...Array([[ 2, 6, 14], [ 4, 4, 2]], dtype=...), ...Array([ 4, 8, 16], dtype=...)]) >>> y = snp.blockarray(( ... [[.2], ... [.3]], ... [.4] ... )) >>> x + y # returns BlockArray # doctest: +ELLIPSIS BlockArray([...Array([[1.2, 3.2, 7.2], [2.3, 2.3, 1.3]], dtype=...), ...Array([2.4, 4.4, 8.4], dtype=...)]) .. _numpy_functions_blockarray: NumPy and SciPy Functions ------------------------- :mod:`scico.numpy`, :mod:`scico.numpy.testing`, and :mod:`scico.scipy.special` provide wrappers around :mod:`jax.numpy`, :mod:`numpy.testing` and :mod:`jax.scipy.special` where many of the functions have been extended to work with instances of :class:`.BlockArray`. In particular: * When a tuple of tuples is passed as the `shape` argument to an array creation routine, a :class:`.BlockArray` is created. * When a :class:`.BlockArray` is passed to a reduction function, the blocks are ravelled (i.e., reshaped to be 1D) and concatenated before the reduction is applied. This behavior may be prevented by passing the `axis` argument, in which case the function is mapped over the blocks. 
* When one or more :class:`.BlockArray` instances are passed to a mathematical function that is not a reduction, the function is mapped over (corresponding) blocks. For a list of array creation routines, see :: >>> scico.numpy.creation_routines # doctest: +ELLIPSIS ('empty', ...) For a list of reduction functions, see :: >>> scico.numpy.reduction_functions # doctest: +ELLIPSIS ('sum', ...) For lists of the remaining wrapped functions, see :: >>> scico.numpy.mathematical_functions # doctest: +ELLIPSIS ('sin', ...) >>> scico.numpy.testing_functions # doctest: +ELLIPSIS ('testing.assert_allclose', ...) >>> import scico.scipy >>> scico.scipy.special.functions # doctest: +ELLIPSIS ('betainc', ...) Note that: * The functional and method versions of the "same" function differ in their behavior, with the method version only applying the reduction within each block, and the function version applying the reduction across all blocks. For example, :func:`scico.numpy.sum` applied to a :class:`.BlockArray` with two blocks returns a scalar value, while :meth:`.BlockArray.sum` returns a :class:`.BlockArray` with two scalar blocks. * Similarly, :func:`scico.numpy.ravel` returns a fully flattened, single :class:`jax.Array`, while :meth:`.BlockArray.ravel` returns a :class:`.BlockArray` with ravelled blocks. Motivating Example ------------------ The discrete differences of a two-dimensional array, :math:`\mb{x} \in \mbb{R}^{n \times m}`, in the horizontal and vertical directions can be represented by the arrays :math:`\mb{x}_h \in \mbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mbb{R}^{(n-1) \times m}` respectively. While it is usually useful to consider the output of a difference operator as a single entity, we cannot combine these two arrays into a single array since they have different shapes.
We could vectorize each array and concatenate the resulting vectors, leading to :math:`\mb{\bar{x}} \in \mbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional array, but this makes it hard to access the individual components :math:`\mb{x}_h` and :math:`\mb{x}_v`. Instead, we can construct a :class:`.BlockArray`, :math:`\mb{x}_B = [\mb{x}_h, \mb{x}_v]`: :: >>> n = 32 >>> m = 16 >>> x_h, key = scico.random.randn((n, m-1)) >>> x_v, _ = scico.random.randn((n-1, m), key=key) # Form the blockarray >>> x_B = snp.blockarray([x_h, x_v]) # The blockarray shape is a tuple of tuples >>> x_B.shape ((32, 15), (31, 16)) # Each block component can be easily accessed >>> x_B[0].shape (32, 15) >>> x_B[1].shape (31, 16) Constructing a BlockArray ------------------------- The recommended way to construct a :class:`.BlockArray` is by using the :func:`~scico.numpy.blockarray` function. :: >>> import scico.numpy as snp >>> x0, key = scico.random.randn((32, 32)) >>> x1, _ = scico.random.randn((16,), key=key) >>> X = snp.blockarray((x0, x1)) >>> X.shape ((32, 32), (16,)) >>> X.size (1024, 16) >>> len(X) 2 While :func:`~scico.numpy.blockarray` will accept arguments of type :class:`~numpy.ndarray` or :class:`~jax.Array`, arguments of type :class:`~numpy.ndarray` will be converted to :class:`~jax.Array` type. Operating on a BlockArray ------------------------- .. _blockarray_indexing: Indexing ^^^^^^^^ :class:`.BlockArray` indexing works just like indexing a list. Multiplication between BlockArray and LinearOperator ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :class:`.Operator` and :class:`.LinearOperator` classes are designed to work on instances of :class:`.BlockArray` in addition to instances of :obj:`~jax.Array`. 
For example :: >>> x, key = scico.random.randn((3, 4)) >>> A_1 = scico.linop.Identity(x.shape) >>> A_1.shape # array -> array ((3, 4), (3, 4)) >>> A_2 = scico.linop.FiniteDifference(x.shape) >>> A_2.shape # array -> BlockArray (((2, 4), (3, 3)), (3, 4)) >>> diag = snp.blockarray([np.array(1.0), np.array(2.0)]) >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) >>> A_3.shape # BlockArray -> BlockArray (((2, 4), (3, 3)), ((2, 4), (3, 3))) ================================================ FILE: docs/source/include/examplenotes.rst ================================================ .. _example_depend: Example Dependencies -------------------- Some examples use additional dependencies, which are listed in `examples_requirements.txt `_. The additional requirements should be installed via pip, with the exception of ``astra-toolbox``, which should be installed via conda: :: conda install astra-toolbox pip install -r examples/examples_requirements.txt # Installs other example requirements The dependencies can also be installed individually as required. Note that ``astra-toolbox`` should be installed on a host with one or more CUDA GPUs to ensure that the version with GPU support is installed. Run Time -------- Most of these examples have been constructed with sufficiently small test problems to allow them to run to completion in 5 minutes or less on a reasonable workstation. Note, however, that it was not feasible to construct meaningful examples of the training of some of the deep learning algorithms that complete within a relatively short time; the examples "CT Training and Reconstructions with MoDL" and "CT Training and Reconstructions with ODP" in particular are much slower, and can require multiple hours to run on a workstation with multiple GPUs.
| ================================================ FILE: docs/source/include/functional.rst ================================================ Functionals =========== A functional is a mapping from :math:`\mathbb{R}^n` or :math:`\mathbb{C}^n` to :math:`\mathbb{R}`. In SCICO, functionals are primarily used to represent a cost to be minimized and are represented by instances of the :class:`.Functional` class. An instance of :class:`.Functional`, ``f``, may provide three core operations: * Evaluation - ``f(x)`` returns the value of the functional evaluated at the point ``x``. - A functional that can be evaluated has the attribute ``f.has_eval == True``. - Not all functionals can be evaluated: see `Plug-and-Play`_. * Gradient - ``f.grad(x)`` returns the gradient of the functional evaluated at ``x``. - Gradients are calculated using JAX reverse-mode automatic differentiation, exposed through :func:`scico.grad`. - *Note:* The gradient of a functional ``f`` can be evaluated even if that functional is not smooth. All that is required is that the functional can be evaluated, ``f.has_eval == True``. However, the result may not be a valid gradient (or subgradient) for all inputs. * Proximal operator - ``f.prox(v, lam)`` returns the result of the scaled proximal operator of ``f``, i.e., the proximal operator of ``lambda x: lam * f(x)``, evaluated at the point ``v``. - The proximal operator of a functional :math:`f : \mathbb{R}^n \to \mathbb{R}` is the mapping :math:`\mathrm{prox}_f : \mathbb{R}^n \to \mathbb{R}^n` defined as .. math:: \mathrm{prox}_f (\mb{v}) = \argmin_{\mb{x}} f(\mb{x}) + \frac{1}{2} \norm{\mb{v} - \mb{x}}_2^2\;. Plug-and-Play ------------- For the plug-and-play framework :cite:`sreehari-2016-plug`, we encapsulate generic denoisers including CNNs in :class:`.Functional` objects that **cannot be evaluated**. The denoiser is applied via the proximal operator. For examples, see :ref:`example_notebooks`.
Proximal Calculus ----------------- We support a limited subset of proximal calculus rules: Scaled Functionals ^^^^^^^^^^^^^^^^^^ Given a scalar ``c`` and a functional ``f`` with a defined proximal method, we can determine the proximal method of ``c * f`` as .. math:: \begin{align} \mathrm{prox}_{c f} (v, \lambda) &= \argmin_x \lambda (c f)(x) + \frac{1}{2} \norm{v - x}_2^2 \\ &= \argmin_x (\lambda c) f(x) + \frac{1}{2} \norm{v - x}_2^2 \\ &= \mathrm{prox}_{f} (v, c \lambda) \;. \end{align} Note that we have made no assumptions regarding homogeneity of ``f``; rather, only that the proximal method of ``f`` is given in the parameterized form :math:`\mathrm{prox}_{c f}`. In SCICO, multiplying a :class:`.Functional` by a scalar will return a :class:`.ScaledFunctional`. This :class:`.ScaledFunctional` retains the ``has_eval`` and ``has_prox`` attributes from the original :class:`.Functional`, but the proximal method is modified to accommodate the additional scalar. Separable Functionals ^^^^^^^^^^^^^^^^^^^^^ A separable functional :math:`f : \mathbb{C}^N \to \mathbb{R}` can be written as the sum of :math:`K` functionals :math:`f_i : \mathbb{C}^{N_i} \to \mathbb{R}` with :math:`\sum_{i=1}^K N_i = N`. In particular, .. math:: f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_K) = f_1(\mb{x}_1) + \dots + f_K(\mb{x}_K) \;. The proximal operator of a separable :math:`f` can be written in terms of the proximal operators of the :math:`f_i` (see Theorem 6.6 of :cite:`beck-2017-first`): .. math:: \mathrm{prox}_f(\mb{x}, \lambda) = \begin{pmatrix} \mathrm{prox}_{f_1}(\mb{x}_1, \lambda) \\ \vdots \\ \mathrm{prox}_{f_K}(\mb{x}_K, \lambda) \\ \end{pmatrix} \;. Separable Functionals are implemented in the :class:`.SeparableFunctional` class. Separable functionals naturally accept :class:`.BlockArray` inputs and return the prox as a :class:`.BlockArray`. Adding New Functionals ---------------------- To add a new functional, create a class which 1. inherits from base :class:`.Functional`; 2.
has ``has_eval`` and ``has_prox`` flags; 3. has ``_eval`` and ``prox`` methods, as necessary. For example, :: class MyFunctional(scico.functional.Functional): has_eval = True has_prox = True def _eval(self, x: JaxArray) -> float: return snp.sum(x) def prox(self, x: JaxArray, lam : float) -> JaxArray: return x - lam Losses ------ In SCICO, a loss is a special type of functional .. math:: f(\mb{x}) = \alpha l( \mb{y}, A(\mb{x}) ) \;, where :math:`\alpha` is a scaling parameter, :math:`l` is a functional, :math:`\mb{y}` is a set of measurements, and :math:`A` is an operator. SCICO uses the class :class:`.Loss` to represent losses. Loss functionals commonly arise in the context of solving inverse problems in scientific imaging, where they are used to represent the mismatch between predicted measurements :math:`A(\mb{x})` and actual ones :math:`\mb{y}`. ================================================ FILE: docs/source/include/learning.rst ================================================ Learned Models ============== In SCICO, neural network models are used to represent imaging problems and provide different modes of data-driven regularization. The models are implemented in `Flax `_, and constitute a representative sample of frequently used networks. FlaxMap ------- SCICO interfaces with the implemented models via :class:`.FlaxMap`. This provides standardized access to all trained models via the model definition and the learned parameters. Further specialized functionality, such as learned denoisers, is built on top of :class:`.FlaxMap`. The specific models that have been implemented are described below. DnCNN ----- The denoiser convolutional neural network model (DnCNN) :cite:`zhang-2017-dncnn`, implemented as :class:`.DnCNNNet`, is used to denoise images that have been corrupted with additive Gaussian noise.
ODP --- The unrolled optimization with deep priors (ODP) :cite:`diamond-2018-odp`, implemented as :class:`.ODPNet`, is used to solve inverse problems in imaging by adapting classical iterative methods into an end-to-end framework that incorporates deep networks as well as knowledge of the image formation model. The framework aims to solve the optimization problem .. math:: \argmin_{\mb{x}} \; f(A \mb{x}, \mb{y}) + r(\mb{x}) \;, where :math:`A` represents a linear forward model and :math:`r` a regularization function encoding prior information, by unrolling the iterative solution method into a network where each iteration corresponds to a different stage in the ODP network. Different iterative solutions produce different unrolled optimization algorithms which, in turn, produce different ODP networks. The ones implemented in SCICO are described below. Proximal Map ^^^^^^^^^^^^ This algorithm corresponds to solving .. math:: :label: eq:odp_prox \argmin_{\mb{x}} \; \alpha_k \, f(A \mb{x}, \mb{y}) + \frac{1}{2} \| \mb{x} - \mb{x}^k - \mb{x}^{k+1/2} \|_2^2 \;, with :math:`k` corresponding to the index of the iteration, which translates to an index of the stage of the network, :math:`f(A \mb{x}, \mb{y})` a fidelity term, usually an :math:`\ell_2` norm, and :math:`\mb{x}^{k+1/2}` a regularization representing :math:`\mathrm{prox}_r (\mb{x}^k)` and usually implemented as a convolutional neural network (CNN). This proximal map representation is used when minimization problem :eq:`eq:odp_prox` can be solved in a computationally efficient manner. :class:`.ODPProxDnBlock` uses this formulation to solve a denoising problem, which, according to :cite:`diamond-2018-odp`, can be solved by .. 
math:: \mb{x}^{k+1} = (\alpha_k \, \mb{y} + \mb{x}^k + \mb{x}^{k+1/2}) \, / \, (\alpha_k + 1) \;, where :math:`A` corresponds to the identity operator and is therefore omitted, :math:`\mb{y}` is the noisy signal, :math:`\alpha_k > 0` is a learned stage-wise parameter weighting the contribution of the fidelity term and :math:`\mb{x}^k + \mb{x}^{k+1/2}` is the regularization, usually represented by a residual CNN. :class:`.ODPProxDblrBlock` uses this formulation to solve a deblurring problem, which, according to :cite:`diamond-2018-odp`, can be solved by .. math:: \mb{x}^{k+1} = \mathcal{F}^{-1} \mathrm{diag} (\alpha_k | \mathcal{F}(K)|^2 + 1 )^{-1} \mathcal{F} \, (\alpha_k K^T * \mb{y} + \mb{x}^k + \mb{x}^{k+1/2}) \;, where :math:`A` is the blurring operator, :math:`K` is the blurring kernel, :math:`\mb{y}` is the blurred signal, :math:`\mathcal{F}` is the DFT, :math:`\alpha_k > 0` is a learned stage-wise parameter weighting the contribution of the fidelity term and :math:`\mb{x}^k + \mb{x}^{k+1/2}` is the regularization represented by a residual CNN. Gradient Descent ^^^^^^^^^^^^^^^^ When the solution of the optimization problem in :eq:`eq:odp_prox` cannot be represented simply by an analytical step, a formulation based on a gradient descent iteration is preferred. This yields .. math:: \mb{x}^{k+1} = \mb{x}^k + \mb{x}^{k+1/2} - \alpha_k \, A^T \nabla_x \, f(A \mb{x}^k, \mb{y}) \;, where :math:`\mb{x}^{k+1/2}` represents :math:`\nabla r(\mb{x}^k)`. :class:`.ODPGrDescBlock` uses this formulation to solve a generic problem with :math:`\ell_2` fidelity as .. math:: \mb{x}^{k+1} = \mb{x}^k + \mb{x}^{k+1/2} - \alpha_k \, A^T (A \mb{x}^k - \mb{y}) \;, with :math:`\mb{y}` the measured signal and :math:`\mb{x}^k + \mb{x}^{k+1/2}` a residual CNN.
MoDL ---- The model-based deep learning (MoDL) :cite:`aggarwal-2019-modl`, implemented as :class:`.MoDLNet`, is used to solve inverse problems in imaging also by adapting classical iterative methods into an end-to-end deep learning framework, but, in contrast to ODP, it solves the optimization problem .. math:: \argmin_{\mb{x}} \; \| A \mb{x} - \mb{y}\|_2^2 + \lambda \, \| \mb{x} - \mathrm{D}_w(\mb{x})\|_2^2 \;, by directly computing the update .. math:: \mb{x}^{k+1} = (A^T A + \lambda \, I)^{-1} (A^T \mb{y} + \lambda \, \mb{z}^k) \;, via conjugate gradient. The regularization :math:`\mb{z}^k = \mathrm{D}_w(\mb{x}^{k})` incorporates prior information, usually in the form of a denoiser model. In this case, the denoiser :math:`\mathrm{D}_w` is shared between all the stages of the network requiring relatively less memory than other unrolling methods. This also allows for deploying a different number of iterations in testing than the ones used in training. ================================================ FILE: docs/source/include/operator.rst ================================================ Operators ========= An operator is a map from :math:`\mathbb{R}^n` or :math:`\mathbb{C}^n` to :math:`\mathbb{R}^m` or :math:`\mathbb{C}^m`. In SCICO, operators are primarily used to represent imaging systems and provide regularization. SCICO operators are represented by instances of the :class:`.Operator` class. SCICO :class:`.Operator` objects extend the notion of "shape" and "size" from the usual NumPy ``ndarray`` class. Each :class:`.Operator` object has an ``input_shape`` and ``output_shape``; these shapes can be either tuples or a tuple of tuples (in the case of a :class:`.BlockArray`). The ``matrix_shape`` attribute describes the shape of the :class:`.LinearOperator` if it were to act on vectorized, or flattened, inputs. For example, consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. 
We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions, generating two new arrays: :math:`\mb{x}_h \in \mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) \times m}`. We represent this linear operator by :math:`\mb{A} : \mathbb{R}^{n \times m} \to \mathbb{R}^{n \times (m-1)} \times \mathbb{R}^{(n-1) \times m}`. In SCICO, this linear operator will return a :class:`.BlockArray` with the horizontal and vertical differences stored as blocks. Letting :math:`\mb{y} = \mb{A} \mb{x}`, we have ``y.shape = ((n, m-1), (n-1, m))`` and :: A.input_shape = (n, m) A.output_shape = ((n, m-1), (n-1, m)) A.shape = ( ((n, m-1), (n-1, m)), (n, m)) # (output_shape, input_shape) A.input_size = n*m A.output_size = n*(m-1) + (n-1)*m A.matrix_shape = (n*(m-1) + (n-1)*m, n*m) # (output_size, input_size) Operator Calculus ----------------- SCICO supports a variety of operator calculus rules, allowing new operators to be defined in terms of old ones. The following table summarizes the available operations. +----------------+-----------------+ | Operation | Result | +----------------+-----------------+ | ``(A+B)(x)`` | ``A(x) + B(x)`` | +----------------+-----------------+ | ``(A-B)(x)`` | ``A(x) - B(x)`` | +----------------+-----------------+ | ``(c * A)(x)`` | ``c * A(x)`` | +----------------+-----------------+ | ``(A/c)(x)`` | ``A(x)/c`` | +----------------+-----------------+ | ``(-A)(x)`` | ``-A(x)`` | +----------------+-----------------+ | ``A(B)(x)`` | ``A(B(x))`` | +----------------+-----------------+ | ``A(B)`` | ``Operator`` | +----------------+-----------------+ Defining a New Operator ----------------------- To define a new operator, pass a callable to the :class:`.Operator` constructor: :: A = Operator(input_shape=(32,), eval_fn = lambda x: 2 * x) Or use subclassing: :: >>> from scico.operator import Operator >>> class MyOp(Operator): ... ... def _eval(self, x): ...
return 2 * x >>> A = MyOp(input_shape=(32,)) At a minimum, the ``_eval`` function must be overridden. If either ``output_shape`` or ``output_dtype`` is unspecified, it is determined by evaluating the operator on an input of appropriate shape and dtype. Linear Operators ================ Linear operators are those for which .. math:: H(a \mb{x} + b \mb{y}) = a H(\mb{x}) + b H(\mb{y}) \;. SCICO represents linear operators as instances of the class :class:`.LinearOperator`. While finite-dimensional linear operators can always be associated with a matrix, it is often useful to represent them in a matrix-free manner. Most of SCICO's linear operators are implemented matrix-free. Using a LinearOperator ---------------------- We implement two ways to evaluate a :class:`.LinearOperator`. The first is using standard callable syntax: ``A(x)``. The second mimics the NumPy matrix multiplication syntax: ``A @ x``. Both methods perform shape and type checks to validate the input before ultimately either calling ``A._eval`` or generating a new :class:`.LinearOperator`. For linear operators that map real-valued inputs to real-valued outputs, there are two ways to apply the adjoint: ``A.adj(y)`` and ``A.T @ y``. For complex-valued linear operators, there are three ways to apply the adjoint: ``A.adj(y)``, ``A.H @ y``, and ``A.conj().T @ y``. Note that in this case, ``A.T`` returns the non-conjugated transpose of the :class:`.LinearOperator`. While the cost of evaluating the linear operator is virtually identical for ``A(x)`` and ``A @ x``, the ``A.H`` and ``A.conj().T`` methods are somewhat slower, especially the latter. This is because two intermediate linear operators must be created before the function is evaluated. Evaluating ``A.conj().T @ y`` is equivalent to: :: def f(y): B = A.conj() # New LinearOperator #1 C = B.T # New LinearOperator #2 return C @ y **Note**: the speed differences between these methods vanish if applied inside of a jitted function.
For instance: :: f = jax.jit(lambda x: A.conj().T @ x) +------------------+-----------------+ | Public Method | Private Method | +------------------+-----------------+ | ``__call__`` | ``._eval`` | +------------------+-----------------+ | ``adj`` | ``._adj`` | +------------------+-----------------+ | ``gram`` | ``._gram`` | +------------------+-----------------+ The public methods perform shape and type checking to validate the input before either calling the corresponding private method or returning a composite LinearOperator. Linear Operator Calculus ------------------------ SCICO supports several linear operator calculus rules. Given ``A`` and ``B`` of class :class:`.LinearOperator` and of appropriate shape, ``x`` an array of appropriate shape, ``c`` a scalar, and ``O`` an :class:`.Operator`, we have +----------------+----------------------------+ | Operation | Result | +----------------+----------------------------+ | ``(A+B)(x)`` | ``A(x) + B(x)`` | +----------------+----------------------------+ | ``(A-B)(x)`` | ``A(x) - B(x)`` | +----------------+----------------------------+ | ``(c * A)(x)`` | ``c * A(x)`` | +----------------+----------------------------+ | ``(A/c)(x)`` | ``A(x)/c`` | +----------------+----------------------------+ | ``(-A)(x)`` | ``-A(x)`` | +----------------+----------------------------+ | ``(A@B)(x)`` | ``A@B@x`` | +----------------+----------------------------+ | ``A @ B`` | ``ComposedLinearOperator`` | +----------------+----------------------------+ | ``A @ O`` | ``Operator`` | +----------------+----------------------------+ | ``O(A)`` | ``Operator`` | +----------------+----------------------------+ Defining a New Linear Operator ------------------------------ To define a new linear operator, pass a callable to the :class:`.LinearOperator` constructor :: >>> from scico.linop import LinearOperator >>> A = LinearOperator(input_shape=(32,), ... 
eval_fn = lambda x: 2 * x) Or, use subclassing: :: >>> class MyLinearOperator(LinearOperator): ... def _eval(self, x): ... return 2 * x >>> A = MyLinearOperator(input_shape=(32,)) At a minimum, the ``_eval`` method must be overridden. If the ``_adj`` method is not overridden, the adjoint is determined using :func:`scico.linear_adjoint`. If either ``output_shape`` or ``output_dtype`` is unspecified, it is determined by evaluating the Operator on an input of appropriate shape and dtype. 🔪 Sharp Edges 🔪 ------------------ Strict Types in Adjoint ^^^^^^^^^^^^^^^^^^^^^^^ SCICO silently promotes real types to complex types in forward application, but enforces strict type checking in the adjoint. This is due to the strict type-safe nature of JAX adjoints. LinearOperators from External Code ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ External code may be wrapped as a subclass of :class:`.Operator` or :class:`.LinearOperator` and used in SCICO optimization routines; however, this process can be complicated and error-prone. As a starting point, look at the source for :class:`.radon_svmbir.TomographicProjector` or :class:`.radon_astra.TomographicProjector` and the JAX documentation for the `vector-jacobian product `_ and `custom VJP rules `_. ================================================ FILE: docs/source/include/optimizer.rst ================================================ .. _optimizer: Optimization Algorithms ======================= ADMM ---- The Alternating Direction Method of Multipliers (ADMM) :cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual` is an algorithm for minimizing problems of the form .. math:: :label: eq:admm_prob \argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that} \; \acute{A} \mb{x} + \acute{B} \mb{z} = \mb{c} \;, where :math:`f` and :math:`g` are convex (but not necessarily smooth) functionals, :math:`\acute{A}` and :math:`\acute{B}` are linear operators, and :math:`\mb{c}` is a constant vector.
(For a thorough introduction and overview, see :cite:`boyd-2010-distributed`.) The SCICO ADMM solver, :class:`.ADMM`, solves problems of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;, where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`, and the :math:`C_i` are :class:`.LinearOperator`, by defining .. math:: g(\mb{z}) = \sum_{i=1}^N g_i(\mb{z}_i) \qquad \mb{z}_i = C_i \mb{x} in :eq:`eq:admm_prob`, corresponding to defining .. math:: \acute{A} = \left( \begin{array}{c} C_1 \\ C_2 \\ C_3 \\ \vdots \end{array} \right) \quad \acute{B} = \left( \begin{array}{cccc} -I & 0 & 0 & \ldots \\ 0 & -I & 0 & \ldots \\ 0 & 0 & -I & \ldots \\ \vdots & \vdots & \vdots & \ddots \end{array} \right) \quad \mb{z} = \left( \begin{array}{c} \mb{z}_1 \\ \mb{z}_2 \\ \mb{z}_3 \\ \vdots \end{array} \right) \quad \mb{c} = \left( \begin{array}{c} 0 \\ 0 \\ 0 \\ \vdots \end{array} \right) \;. In :class:`.ADMM`, :math:`f` is a :class:`.Functional`, typically a :class:`.Loss`, corresponding to the forward model of an imaging problem, and the :math:`g_i` are :class:`.Functional`, typically corresponding to a regularization term or constraint. Each of the :math:`g_i` must have a proximal operator defined. It is also possible to set ``f = None``, which corresponds to defining :math:`f = 0`, i.e., the zero function. Subproblem Solvers ^^^^^^^^^^^^^^^^^^ The most computationally expensive component of the ADMM iterations is typically the :math:`\mb{x}`-update, .. math:: :label: eq:admm_x_step \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;. The available solvers for this problem are: * :class:`.admm.GenericSubproblemSolver` This is the default subproblem solver as it is applicable in all cases. However, it is only suitable for relatively small-scale problems as it makes use of :func:`.solver.minimize`, which wraps :func:`scipy.optimize.minimize`.
* :class:`.admm.LinearSubproblemSolver` This subproblem solver can be used when :math:`f` takes the form :math:`\norm{\mb{A} \mb{x} - \mb{y}}^2_W`. It makes use of the conjugate gradient method, and is significantly more efficient than :class:`.admm.GenericSubproblemSolver` when it can be used. * :class:`.admm.MatrixSubproblemSolver` This subproblem solver can be used when :math:`f` takes the form :math:`\norm{\mb{A} \mb{x} - \mb{y}}^2_W`, and :math:`\mb{A}` and all of the :math:`C_i` are diagonal (:class:`.Diagonal`) or matrix operators (:class:`.MatrixOperator`). It exploits a pre-computed matrix factorization for a significantly more efficient solution than conjugate gradient. * :class:`.admm.CircularConvolveSolver` This subproblem solver can be used when :math:`f` takes the form :math:`\norm{\mb{A} \mb{x} - \mb{y}}^2_W` and :math:`\mb{A}` and all of the :math:`C_i` are circulant (i.e., diagonalized by the DFT). * :class:`.admm.FBlockCircularConvolveSolver` and :class:`.admm.G0BlockCircularConvolveSolver` These subproblem solvers can be used when the primary linear operator is block-circulant (i.e., an operator with blocks that are diagonalized by the DFT). For more details of these solvers and how to specify them, see the API reference page for :mod:`scico.optimize.admm`. Proximal ADMM ------------- Proximal ADMM :cite:`deng-2015-global` is an algorithm for solving problems of the form .. math:: \argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; A \mb{x} + B \mb{z} = \mb{c} \;, where :math:`f` and :math:`g` are convex (but not necessarily smooth) functionals and :math:`A` and :math:`B` are linear operators. Although convergence per iteration is typically somewhat worse than that of ADMM, the iterations can be much cheaper than those of ADMM, giving Proximal ADMM competitive time convergence performance.
The SCICO Proximal ADMM solver, :class:`.ProximalADMM`, requires :math:`f` and :math:`g` to be instances of :class:`.Functional`, and to have a proximal operator defined (:meth:`.Functional.prox`), and :math:`A` and :math:`B` are required to be instances of :class:`.LinearOperator`. Non-Linear Proximal ADMM ------------------------ Non-Linear Proximal ADMM :cite:`benning-2016-preconditioned` is an algorithm for solving problems of the form .. math:: \argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; H(\mb{x}, \mb{z}) = 0 \;, where :math:`f` and :math:`g` are convex (but not necessarily smooth) functionals and :math:`H` is a function of two vector variables. The SCICO Non-Linear Proximal ADMM solver, :class:`.NonLinearPADMM`, requires :math:`f` and :math:`g` to be instances of :class:`.Functional`, and to have a proximal operator defined (:meth:`.Functional.prox`), and :math:`H` is required to be an instance of :class:`.Function`. Linearized ADMM --------------- Linearized ADMM :cite:`yang-2012-linearized` :cite:`parikh-2014-proximal` (Sec. 4.4.2) is an algorithm for solving problems of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;, where :math:`f` and :math:`g` are convex (but not necessarily smooth) functionals. Although convergence per iteration is typically significantly worse than that of ADMM, the :math:`\mb{x}`-update can be much cheaper than that of ADMM, giving Linearized ADMM competitive time convergence performance. The SCICO Linearized ADMM solver, :class:`.LinearizedADMM`, requires :math:`f` and :math:`g` to be instances of :class:`.Functional`, and to have a proximal operator defined (:meth:`.Functional.prox`), and :math:`C` is required to be an instance of :class:`.LinearOperator`. PDHG ---- The Primal–Dual Hybrid Gradient (PDHG) algorithm :cite:`esser-2010-general` :cite:`chambolle-2010-firstorder` :cite:`pock-2011-diagonal` solves problems of the form ..
math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;, where :math:`f` and :math:`g` are convex (but not necessarily smooth) functionals. The algorithm has similar advantages over ADMM to those of Linearized ADMM, but typically exhibits better convergence properties. The SCICO PDHG solver, :class:`.PDHG`, requires :math:`f` and :math:`g` to be instances of :class:`.Functional`, and to have a proximal operator defined (:meth:`.Functional.prox`), and :math:`C` is required to be an instance of :class:`.Operator` or :class:`.LinearOperator`. PGM --- The Proximal Gradient Method (PGM) :cite:`daubechies-2004-iterative` :cite:`beck-2010-gradient` and Accelerated Proximal Gradient Method (AcceleratedPGM) :cite:`beck-2009-fast` are algorithms for minimizing problems of the form .. math:: \argmin_{\mb{x}} f(\mb{x}) + g(\mb{x}) \;, where :math:`g` is convex and :math:`f` is smooth and convex. The corresponding SCICO solvers are :class:`.PGM` and :class:`.AcceleratedPGM` respectively. In most cases :class:`.AcceleratedPGM` is expected to provide faster convergence. In both of these classes, :math:`f` and :math:`g` are of type :class:`.Functional`, where :math:`f` must be differentiable, and :math:`g` must have a proximal operator defined. While ADMM provides significantly more flexibility than PGM, and often converges faster, the latter is preferred when solving the ADMM :math:`\mb{x}`-step is very computationally expensive, such as in the case of :math:`f(\mb{x}) = \norm{\mb{A} \mb{x} - \mb{y}}^2_W` where :math:`\mb{A}` is large and does not have any special structure that would allow an efficient solution of :eq:`eq:admm_x_step`. Step Size Options ^^^^^^^^^^^^^^^^^ The step size (usually referred to in terms of its reciprocal, :math:`L`) for the gradient descent in :class:`.PGM` can be adapted via Barzilai-Borwein methods (also called spectral methods) and iterative line search methods.
The available step size policy classes are: * :class:`.BBStepSize` This implements the step size adaptation based on the Barzilai-Borwein method :cite:`barzilai-1988-stepsize`. The step size :math:`\alpha` is estimated as .. math:: \mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha = \frac{\mb{\Delta x}^T \mb{\Delta g}}{\mb{\Delta g}^T \mb{\Delta g}} \;. Since the PGM solver uses the reciprocal of the step size, the value :math:`L = 1 / \alpha` is returned. * :class:`.AdaptiveBBStepSize` This implements the adaptive Barzilai-Borwein method as introduced in :cite:`zhou-2006-adaptive`. The adaptive step size rule computes .. math:: \mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha^{\mathrm{BB1}} = \frac{\mb{\Delta x}^T \mb{\Delta x}} {\mb{\Delta x}^T \mb{\Delta g}} \; \\ \alpha^{\mathrm{BB2}} = \frac{\mb{\Delta x}^T \mb{\Delta g}} {\mb{\Delta g}^T \mb{\Delta g}} \;. The determination of the new step size is made via the rule .. math:: \alpha = \left\{ \begin{array}{ll} \alpha^{\mathrm{BB2}} & \mathrm{~if~} \alpha^{\mathrm{BB2}} / \alpha^{\mathrm{BB1}} < \kappa \; \\ \alpha^{\mathrm{BB1}} & \mathrm{~otherwise} \end{array} \right . \;, with :math:`\kappa \in (0, 1)`. Since the PGM solver uses the reciprocal of the step size, the value :math:`L = 1 / \alpha` is returned. * :class:`.LineSearchStepSize` This implements the line search strategy described in :cite:`beck-2009-fast`. This strategy estimates :math:`L` such that :math:`f(\mb{x}) \leq \hat{f}_{L}(\mb{x})` is satisfied with :math:`\hat{f}_{L}` a quadratic approximation to :math:`f` defined as .. 
math:: \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;, with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the current solution or current extrapolation (if using :class:`.AcceleratedPGM`). * :class:`.RobustLineSearchStepSize` This implements the robust line search strategy described in :cite:`florea-2017-robust`. This strategy estimates :math:`L` such that :math:`f(\mb{x}) \leq \hat{f}_{L}(\mb{x})` is satisfied with :math:`\hat{f}_{L}` a quadratic approximation to :math:`f` defined as .. math:: \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;, with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the auxiliary extrapolation state. Note that this should only be used with :class:`.AcceleratedPGM`. For more details of these step size managers and how to specify them, see the API reference page for :mod:`scico.optimize.pgm`. ================================================ FILE: docs/source/index.rst ================================================ SCICO Documentation =================== .. toctree:: :maxdepth: 2 :caption: User Documentation overview inverse advantages install classes notes examples API Reference <_autosummary/scico.rst> zreferences .. toctree:: :maxdepth: 2 :caption: Developer Documentation team contributing style Indices ======= * :ref:`genindex` * :ref:`modindex` ================================================ FILE: docs/source/install.rst ================================================ .. _installing: Installing SCICO ================ SCICO requires Python version 3.8 or later. (Version 3.12 is recommended as it is the version under which SCICO is tested in GitHub continuous integration, and since the most recent versions of JAX require version 3.10 or later.) 
SCICO is supported on both Linux and MacOS, but is not currently supported on Windows due to the limited support for ``jaxlib`` on Windows. However, Windows users can use SCICO via the `Windows Subsystem for Linux `_ (WSL). Guides exist for using WSL with `CPU only `_ and with `GPU support `_. While not required, installation of SCICO and its dependencies within a `Conda `_ environment is recommended. `Scripts `_ are provided for creating a `miniconda `_ installation and an environment including all primary SCICO dependencies as well as dependencies for usage examples, testing, and building the documentation. From PyPI --------- The simplest way to install the most recent release of SCICO from `PyPI `_ is :: pip install scico which will install SCICO and its primary dependencies. If the additional dependencies for the example scripts are also desired, it can instead be installed using :: pip install "scico[examples]" Note, however, that since the ``astra-toolbox`` package available from PyPI is not straightforward to install (it has numerous build requirements that are not specified as package dependencies), it is recommended to first install this package via conda :: conda install astra-toolbox From conda-forge ---------------- SCICO can also be installed from `conda-forge `_ :: conda install -c conda-forge "scico>0.0.5" where the version constraint is required to avoid installation of an old package with broken dependencies. Note, however, that installation from conda-forge is only possible on a Linux platform since there is no conda package for the secondary dependency ``tensorstore`` under MacOS. There are also complications on Linux platforms with Python versions 3.9 or earlier due to the automatic installation of a version of the secondary dependency ``etils`` that does not support Python versions earlier than 3.10.
This can be rectified by :: conda install etils=1.5.1 The most recent SCICO conda forge package also includes dependencies for the example scripts, except for ``bm3d``, ``bm4d``, and ``colour_demosaicing``, for which conda packages are not available. These can be installed from PyPI :: pip install bm3d bm4d colour_demosaicing From GitHub ----------- The development version of SCICO can be downloaded from the `GitHub repo `_. Note that, since the SCICO repo has a submodule, it should be cloned via the command :: git clone --recurse-submodules git@github.com:lanl/scico.git Install using the commands :: cd scico pip install -r requirements.txt pip install -e . If a clone of the SCICO repository is not needed, it is simpler to install directly using ``pip`` :: pip install git+https://github.com/lanl/scico GPU Support ----------- The instructions above install a CPU-only version of SCICO. To install a version with GPU support: 1. Follow the CPU-only instructions, above 2. Install the version of jaxlib with GPU support, as described in the `JAX installation instructions `_. In the simplest case, the appropriate command is :: pip install --upgrade "jax[cuda12]" for CUDA 12, but it may be necessary to explicitly specify the ``jaxlib`` version if the most recent release is not yet supported by SCICO (as specified in the ``requirements.txt`` file). The script `misc/gpu/envinfo.py `_ in the source distribution is provided as an aid to debugging GPU support issues. The script `misc/gpu/availgpu.py `_ can be used to automatically recommend a setting of the CUDA_VISIBLE_DEVICES environment variable that excludes GPUs that are already in use. Additional Dependencies ----------------------- See :ref:`example_depend` for instructions on installing dependencies related to the examples. For Developers -------------- See :ref:`scico_dev_contributing` for instructions on installing a version of SCICO suitable for development. 
================================================ FILE: docs/source/inverse.rst ================================================ Inverse Problems ================ In traditional imaging, the burden of image formation is placed on physical components, such as a lens, with the resulting image being taken from the sensor with minimal processing. In computational imaging, in contrast, the burden of image formation is shared with or shifted to computation, with the resulting image typically being very different from the measured data. Common examples of computational imaging include demosaicing in consumer cameras, computed tomography and magnetic resonance imaging in medicine, and synthetic aperture radar in remote sensing. This is an active and growing area of research, and many of these problems have common properties that could be supported by shared implementations of solution components. The goal of SCICO is to provide a general research tool for computational imaging, with a particular focus on scientific imaging applications, which are particularly underrepresented in the existing range of open-source packages in this area. While a number of other packages overlap somewhat in functionality with SCICO, only a few support execution of the same code on both CPU and GPU devices, and we are not aware of any that support just-in-time compilation and automatic gradient computation, which is invaluable in computational imaging. SCICO provides all three of these valuable features (subject to some :ref:`caveats `) by being built on top of `JAX `__ rather than `NumPy `__. The remainder of this section outlines the steps involved in solving an inverse problem, and shows how each concept maps to a component of SCICO. More detail on the main classes involved in setting up and solving an inverse problem can be found in :ref:`classes`. 
Forward Modeling ---------------- In order to solve a computational imaging problem we need to know how the image we wish to reconstruct, :math:`\mathbf{x}`, is related to the data that we can measure, :math:`\mathbf{y}`. This is represented via a model of the measurement process, .. math:: \mathbf{y} = A(\mathbf{x}) \,. SCICO provides the :class:`.Operator` and :class:`.LinearOperator` classes, which may be subclassed by users, in order to implement the forward operator, :math:`A`. It also has several built-in operators, most of which are linear, e.g., finite convolutions, discrete Fourier transforms, optical propagators, Abel transforms, and X-ray transforms (the same as Radon transforms in 2D). For example, .. code:: python input_shape = (512, 512) angles = np.linspace(0, 2 * np.pi, 180, endpoint=False) channels = 512 A = scico.linop.xray.svmbir.XRayTransform(input_shape, angles, channels) defines a tomographic projection operator. A significant advantage of SCICO being built on top of `JAX `__ is that the adjoints of linear operators, which can be quite time-consuming to implement even when the operator itself is straightforward, are computed automatically by exploiting the automatic differentiation features of `JAX `__. If :code:`A` is a :class:`.LinearOperator`, then its adjoint is simply :code:`A.T` for real transforms and :code:`A.H` for complex transforms. Likewise, Jacobian-vector products can be automatically computed for non-linear operators, allowing for simple linearization and gradient calculations. SCICO operators can be composed to construct new operators. (If both operands are linear, then the result is also linear.) For example, if :code:`A` and :code:`B` have been defined as distinct linear operators, then .. code:: python C = B @ A defines a new linear operator :code:`C` that first applies operator :code:`A` and then applies operator :code:`B` to the result (i.e., :math:`C = B A` in math notation).
This operator algebra can be used to build complicated forward operators from simpler building blocks. SCICO also handles cases where either the image we want to reconstruct, :math:`\mb{x}`, or its measurements, :math:`\mb{y}`, do not fit neatly into a multi-dimensional array. This is achieved via :class:`.BlockArray` objects, which consist of a :class:`list` of multi-dimensional array *blocks*. A :class:`.BlockArray` differs from a :class:`list` in that, whenever possible, :class:`.BlockArray` properties and methods (including unary and binary operators like ``+``, ``-``, ``*``, …) automatically map along the blocks, returning another :class:`.BlockArray` or :class:`tuple` as appropriate. For example, consider a system that measures the column sums and row sums of an image. If the input image has shape :math:`M \times N`, the resulting measurement will have shape :math:`M + N`, which is awkward to represent as a multi-dimensional array. In SCICO, we can represent this operator by .. code:: python input_shape = (130, 50) H0 = scico.linop.Sum(input_shape, axis=0) H1 = scico.linop.Sum(input_shape, axis=1) H = scico.linop.VerticalStack((H0, H1)) The result of applying ``H`` to an image with shape ``(130, 50)`` is a :class:`.BlockArray` with shape ``((50,), (130,))``. This result is compatible with the rest of SCICO and may be used, e.g., as the input of other operators. Inverse Problem Formulation --------------------------- In order to estimate the image from the measured data, we need to solve an *inverse problem*. In its simplest form, the solution to such an inverse problem can be expressed as the optimization problem .. math:: \hat{\mb{x}} = \mathop{\mathrm{arg\,min}}_{\mb{x}} f( \mb{x} ) \,, where :math:`\mb{x}` is the unknown image and :math:`\hat{\mb{x}}` is the recovered image. A common choice of :math:`f` is .. 
math:: f(\mb{x}) = (1/2) \| A(\mb{x}) - \mb{y} \|_2^2 \,, where :math:`\mb{y}` is the measured data and :math:`A` is the forward operator; in this case the minimization problem is a least squares problem. In SCICO, the :mod:`.functional` module provides implementations of common functionals such as :math:`\ell_2` and :math:`\ell_1` norms. The :mod:`.loss` module is used to implement a special type of functional .. math:: f(\mb{x}) = \alpha l(A(\mb{x}),\mb{y}) \,, where :math:`\alpha` is a scaling parameter and :math:`l(\cdot)` is another functional. The SCICO :mod:`.loss` module contains a variety of loss functionals that are commonly used in computational imaging. For example, the squared :math:`\ell_2` loss written above for a forward operator, :math:`A`, can be defined in SCICO using the code: .. code:: python f = scico.loss.SquaredL2Loss(y=y, A=A) The difficulty of the inverse problem depends on the amount of noise in the measured data and the properties of the forward operator. In particular, if :math:`A` is a linear operator, then the difficulty of the inverse problem depends significantly on the condition number of :math:`A`, since a large condition number implies that large changes in :math:`\mb{x}` can correspond to small changes in :math:`\mb{y}`, making it difficult to estimate :math:`\mb{x}` from :math:`\mb{y}`. When there is a significant amount of measurement noise or ill-conditioning of :math:`A`, the standard approach to resolve the limitations in the information available from the measured data is to introduce a *prior model* of the solution space, which is typically achieved by adding a *regularization term* to the data fidelity term, resulting in the optimization problem .. 
math:: \hat{\mb{x}} = \mathop{\mathrm{arg\,min}}_{\mb{x}} f(\mb{x}) + g(C (\mb{x})) \,, where the functional :math:`g(C(\cdot))` is designed to increase the cost for solutions that are considered less likely or less desirable, based on prior knowledge of the properties of the solution space. A common choice of :math:`g(C(\cdot))` is the total variation norm .. math:: g(C \mb{x}) = \lambda \| C \mb{x} \|_{2,1} \,, where :math:`\lambda` is a scalar controlling the regularization strength, :math:`C` is a linear operator that computes the spatial gradients of its argument, and :math:`\| \cdot \|_{2,1}` denotes the :math:`\ell_{2,1}` norm, which promotes group sparsity. Use of this functional as a regularization term corresponds to the assumption that the images of interest are piecewise constant. In SCICO, we can represent this regularization functional using a built-in linear operator and a member of the :mod:`.functional` module: .. code:: python C = scico.linop.FiniteDifference(A.input_shape, append=0) λ = 1.0e-1 g = λ * scico.functional.L21Norm() Computing the value of the regularizer then closely matches the math: :code:`g(C(x))`. Finally, the overall objective function needs to be optimized. One of the primary goals of SCICO is to make the solution of such problems accessible to application domain scientists with limited expertise in computational imaging, providing infrastructure for solving this type of problem efficiently, without the need for the user to implement complex algorithms. Solvers ------- Once an inverse problem has been specified using the above components, the resulting functional must be minimized in order to solve the problem. SCICO provides a number of optimization algorithms for addressing a wide range of problems. These optimization algorithms belong to two distinct categories.
Basic Solvers ~~~~~~~~~~~~~ The :mod:`scico.solver` module provides a number of functions for solving linear systems and simple optimization problems, some of which are useful as subproblem solvers within the proximal algorithms described in the following section. It also provides an interface to functions in :mod:`scipy.optimize`, supporting their use with multi-dimensional arrays and SCICO :class:`.Functional` objects. These functions are useful both as subproblem solvers within the proximal algorithms described below and for the direct solution of higher-level problems. For example, .. code:: python f = scico.loss.PoissonLoss(y=y, A=A) method = 'BFGS' # or any method available for scipy.optimize.minimize x0 = scico.numpy.ones(A.input_shape) res = scico.solver.minimize(f, x0=x0, method=method) x_hat = res.x defines a Poisson objective function and minimizes it using the BFGS :cite:`nocedal-2006-numerical` algorithm. Proximal Algorithms ~~~~~~~~~~~~~~~~~~~ The :mod:`scico.optimize` sub-package provides a set of *proximal algorithms* :cite:`parikh-2014-proximal` that have proven to be useful for solving imaging inverse problems. The common feature of these algorithms is their exploitation of the *proximal operators* :cite:`beck-2017-first` (Ch. 6) of the components of the functions that they minimize. **ADMM** The most flexible of the proximal algorithms supported by SCICO is the alternating direction method of multipliers (ADMM) :cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual` :cite:`boyd-2010-distributed`, which supports solving problems of the form .. math:: \mathop{\mathrm{arg\,min}}_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \,. When :math:`f(\cdot)` is an instance of ``scico.loss.SquaredL2Loss``, i.e., ..
math:: f(\mb{x}) = (1/2) \| A \mb{x} - \mb{y} \|_2^2 \,, for linear operator :math:`A` and constant vector :math:`\mb{y}`, the primary computational cost of the algorithm is typically in solving a linear system involving a weighted sum of :math:`A^\top A` and the :math:`C_i^\top C_i`, assuming that the proximal operators of the functionals :math:`g_i(\cdot)` can be computed efficiently. This linear system can also be solved efficiently when :math:`A` and all of the :math:`C_i` are either identity operators or circular convolutions. **Proximal ADMM** Proximal ADMM :cite:`deng-2015-global` solves problems of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; A \mb{x} + B \mb{z} = \mb{c} \;, where :math:`A` and :math:`B` are linear operators. There is also a non-linear PADMM solver :cite:`benning-2016-preconditioned` for problems of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; H(\mb{x}, \mb{z}) = 0 \;, where :math:`H` is a function. For some problems, proximal ADMM converges substantially faster than ADMM or linearized ADMM. **Linearized ADMM** Linearized ADMM :cite:`yang-2012-linearized` :cite:`parikh-2014-proximal` solves a more restricted problem form, .. math:: \mathop{\mathrm{arg\,min}}_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \,. It is an effective algorithm when the proximal operators of both :math:`f(\cdot)` and :math:`g(\cdot)` can be computed efficiently, and has the advantage over "standard" ADMM of avoiding the need for solving a linear system involving :math:`C^\top C`. **PDHG** Primal–dual hybrid gradient (PDHG) :cite:`esser-2010-general` :cite:`chambolle-2010-firstorder` :cite:`pock-2011-diagonal` solves the same form of problem as linearized ADMM .. math:: \mathop{\mathrm{arg\,min}}_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \,, but unlike the linearized ADMM implementation, both linear and non-linear operators :math:`C` are supported. 
For some problems, PDHG converges substantially faster than ADMM or linearized ADMM. **PGM and Accelerated PGM** The proximal gradient method (PGM) :cite:`daubechies-2004-iterative` and accelerated proximal gradient method (APGM), which is also known as FISTA :cite:`beck-2017-first`, solve problems of the form .. math:: \mathop{\mathrm{arg\,min}}_{\mb{x}} \; f(\mb{x}) + g(\mb{x}) \,, where :math:`f(\cdot)` is assumed to be differentiable, and :math:`g(\cdot)` is assumed to have a proximal operator that can be computed efficiently. These algorithms typically require more iterations for convergence than ADMM, but can provide faster convergence with time when the linear solve required by ADMM is slow to compute. Machine Learning ---------------- While relatively simple regularization terms such as the total variation norm can be effective when the underlying assumptions are well matched to the data (e.g., the reconstructed images for certain materials science applications really are approximately piecewise constant), it is difficult to design mathematically simple regularization terms that adequately represent the properties of the complex data that is often encountered in practice. A widely-used alternative framework for regularizing the solution of imaging inverse problems is *plug-and-play priors* (PPP) :cite:`venkatakrishnan-2013-plugandplay2` :cite:`sreehari-2016-plug` :cite:`kamilov-2023-plugandplay`, which provides a mechanism for exploiting image denoisers such as BM3D :cite:`dabov-2008-image` as implicit priors. With the rise of deep learning methods, PPP provided one of the first frameworks for applying machine learning methods to inverse problems via the use of learned denoisers such as DnCNN :cite:`zhang-2017-dncnn`. SCICO supports PPP inverse problems solutions with both BM3D and DnCNN denoisers, and provides usage examples for both choices. 
BM3D is more flexible, as it includes a tunable noise level parameter, while SCICO only includes DnCNN models trained at three different noise levels (as in the original DnCNN paper), but DnCNN has a significant speed advantage when GPUs are available. As an example, the following code outline demonstrates a PPP solution, with a non-negativity constraint and a 17-layer DnCNN denoiser as a regularizer, of an inverse problem with measurement, :math:`\mb{y}`, and a generic linear forward operator, :math:`A`. .. code:: python ρ = 0.3 # ADMM penalty parameter maxiter = 10 # number of ADMM iterations f = scico.loss.SquaredL2Loss(y=y, A=A) g1 = scico.functional.DnCNN("17M") g2 = scico.functional.NonNegativeIndicator() C = scico.linop.Identity(A.input_shape) solver = scico.optimize.admm.ADMM( f=f, g_list=[g1, g2], C_list=[C, C], rho_list=[ρ, ρ], x0=A.T @ y, maxiter=maxiter, subproblem_solver=scico.optimize.admm.LinearSubproblemSolver(), itstat_options={"display": True, "period": 5}, ) x_hat = solver.solve() Example results for this type of approach applied to image deconvolution (i.e. with forward operator, :math:`A`, as a convolution) are shown in the figure below. .. image:: /figures/deconv_ppp_dncnn.png :align: center :width: 95% :alt: Image deconvolution via PPP with DnCNN denoiser. | More recently, a wider variety of frameworks have been developed for applying deep learning methods to inverse problems, including the application of the adjoint of the forward operator to map the measurement to the solution space followed by an artifact removal CNN :cite:`jin-2017-unet`, and learned networks with structures based on the unrolling of iterative algorithms such as PPP :cite:`monga-2021-algorithm`. A number of these methods are currently being implemented, and will be included in a future SCICO release. 
It is worth noting, however, that while some of these methods offer superior performance to PPP, it is at the cost of having to train the models with problem-specific data, which may be difficult to obtain, while PPP is often able to function well with a denoiser trained on generic image data. ================================================ FILE: docs/source/notes.rst ================================================ ***** Notes ***** Debugging ========= If difficulties are encountered in debugging jitted functions, jit can be globally disabled by setting the environment variable ``JAX_DISABLE_JIT=1`` before running Python, as in :: JAX_DISABLE_JIT=1 python test_script.py Double Precision ================ By default, JAX uses single-precision floating-point numbers. Double precision can be enabled in one of two ways: 1. Setting the environment variable ``JAX_ENABLE_X64=TRUE`` before launching Python. 2. Manually setting the ``jax_enable_x64`` flag **at program startup**; that is, **before** importing SCICO. :: import jax jax.config.update("jax_enable_x64", True) import scico # continue as usual For more information, see the `JAX notes on double precision `_. Device Control ============== Use of the CPU device can be forced even when GPUs are present by setting the environment variable ``JAX_PLATFORM_NAME=cpu`` before running Python. This also serves to disable the warning that older versions of JAX issued when running on a platform without a GPU, but this should no longer be necessary for any JAX versions supported by SCICO. By default, JAX views a multi-core CPU as a single device. Primarily for testing purposes, it may be useful to instruct JAX to emulate multiple CPU devices, by setting the environment variable ``XLA_FLAGS='--xla_force_host_platform_device_count=<N>'``, where ``<N>`` is an integer number of devices. For more detail, see the relevant `section of the JAX docs `__. By default, JAX will preallocate a large chunk of GPU memory on startup.
This behavior can be controlled using environment variables ``XLA_PYTHON_CLIENT_PREALLOCATE``, ``XLA_PYTHON_CLIENT_MEM_FRACTION``, and ``XLA_PYTHON_CLIENT_ALLOCATOR``, as described in the relevant `section of the JAX docs `__. Random Number Generation ======================== JAX implements an explicit, non-stateful pseudorandom number generator (PRNG). The user is responsible for generating a PRNG key and splitting it to obtain a fresh key each time a new random number is generated (JAX keys are immutable and are never modified in place). We recommend users read the `JAX documentation `_ for information on the design of JAX random number functionality. In :mod:`scico.random` we provide convenient wrappers around several `jax.random `_ routines to handle the generation and splitting of PRNG keys. :: # Calls to scico.random functions always return a PRNG key # If no key is passed to the function, a new key is generated x, key = scico.random.randn((2,)) print(x) # [ 0.19307713 -0.52678305] # scico.random functions automatically split the PRNG key and return # an updated key y, key = scico.random.randn((2,), key=key) print(y) # [ 0.00870693 -0.04888531] The user is responsible for passing the PRNG key to :mod:`scico.random` functions. If no key is passed, repeated calls to :mod:`scico.random` functions will return the same random numbers: :: x, key = scico.random.randn((2,)) print(x) # [ 0.19307713 -0.52678305] # No key passed, will return the same random numbers! y, key = scico.random.randn((2,)) print(y) # [ 0.19307713 -0.52678305] .. _non_jax_dep: Compiled Dependency Packages ============================ The code acceleration and automatic differentiation features of JAX are not available for some components of SCICO that are provided via interfaces to compiled C code. When these components are used on a platform with GPUs, the remainder of the code will run on a GPU, but there is potential for a considerable delay due to host-GPU memory transfers.
This issue primarily affects: Denoisers --------- The :func:`.bm3d` and :func:`.bm4d` denoisers (and the corresponding :class:`.BM3D` and :class:`.BM4D` pseudo-functionals) are implemented via interfaces to the `bm3d `__ and `bm4d `__ packages, respectively. The :class:`~.denoiser.DnCNN` denoiser (and the corresponding :class:`~.functional.DnCNN` pseudo-functional) should be used instead when the full benefits of JAX-based code are required. Tomographic Projectors/Radon Transforms --------------------------------------- Note that the tomographic projections that are frequently referred to as Radon transforms are referred to as X-ray transforms in SCICO. While the Radon transform is far better known than the X-ray transform, which is the same as the Radon transform for projections in two dimensions, these two transforms differ in higher numbers of dimensions, and it is the X-ray transform that is the appropriate mathematical model for beam-attenuation-based imaging in three or more dimensions. SCICO includes three different implementations of X-ray transforms. Of these, :class:`.linop.XRayTransform` is an integral component of SCICO, while the other two depend on external packages. The :class:`.xray.svmbir.XRayTransform` class is implemented via an interface to the `svmbir `__ package. The :class:`.xray.astra.XRayTransform2D` and :class:`.xray.astra.XRayTransform3D` classes are implemented via an interface to the `ASTRA toolbox `__. This toolbox does provide some GPU acceleration support, but efficiency is expected to be lower than JAX-based code due to host-GPU memory transfers. Automatic Differentiation Caveats ================================= Complex Functions ----------------- The JAX-defined gradient of a complex-valued function is a complex-conjugated version of the usual gradient used in mathematical optimization and computational imaging.
Minimizing a function using the JAX convention involves taking steps in the direction of the complex-conjugated gradient. The function :func:`scico.grad` returns the expected gradient, that is, the conjugate of the JAX gradient. For further discussion, see this `JAX issue `_. As a concrete example, consider the function :math:`f(\mb{x}) = \frac{1}{2}\norm{\mb{A} \mb{x}}_2^2` where :math:`\mb{A}` is a complex matrix. The gradient of :math:`f` is usually given by :math:`(\nabla f)(\mb{x}) = \mb{A}^H \mb{A} \mb{x}`, where :math:`\mb{A}^H` is the conjugate transpose of :math:`\mb{A}`. Applying :func:`jax.grad` to :math:`f` will yield :math:`(\mb{A}^H \mb{A} \mb{x})^*`, where :math:`\cdot^*` denotes complex conjugation. The following code demonstrates the use of :func:`jax.grad` and :func:`scico.grad`: :: m, n = (4, 3) A, key = randn((m, n), dtype=np.complex64, key=None) x, key = randn((n,), dtype=np.complex64, key=key) def f(x): return 0.5 * snp.linalg.norm(A @ x)**2 an_grad = A.conj().T @ A @ x # The expected gradient np.testing.assert_allclose(jax.grad(f)(x), an_grad.conj(), rtol=1e-4) np.testing.assert_allclose(scico.grad(f)(x), an_grad, rtol=1e-4) Non-differentiable Functionals ------------------------------ :func:`scico.grad` can be applied to any function, but has undefined behavior for non-differentiable functions. For non-differentiable functions, :func:`scico.grad` may or may not return a valid subgradient. As an example, ``scico.grad(snp.abs)(0.) = 0``, which is a valid subgradient. However, ``scico.grad(snp.linalg.norm)([0., 0.]) = [nan, nan]``. Differentiable functions that are written as the composition of a differentiable and a non-differentiable function should be avoided. As an example, :math:`f(\mb{x}) = \norm{\mb{x}}_2^2` can be implemented as ``f = lambda x: snp.linalg.norm(x)**2``. This involves first calculating the non-squared :math:`\ell_2` norm, then squaring it. The un-squared :math:`\ell_2` norm is not differentiable at zero.
When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns :data:`~numpy.NaN`: :: >>> import scico >>> import scico.numpy as snp >>> f = lambda x: snp.linalg.norm(x)**2 >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # doctest: +SKIP Array([nan, nan], dtype=float32) This can be fixed (assuming real-valued arrays only) by defining the squared :math:`\ell_2` norm directly as ``g = lambda x: snp.sum(x**2)``. The gradient will work as expected: :: >>> g = lambda x: snp.sum(x**2) >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP Array([0., 0.], dtype=float32) If complex-valued arrays also need to be supported, a minor modification is necessary: :: >>> g = lambda x: snp.sum(snp.abs(x)**2) >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP Array([0., 0.], dtype=float32) >>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64)) #doctest: +SKIP Array([0.-0.j, 0.-0.j], dtype=complex64) An alternative is to define a `custom derivative rule `_ to enforce a particular derivative convention at a point. JAX Arrays ========== JAX utilizes a new array type :class:`~jax.Array`, which is similar to NumPy :class:`~numpy.ndarray`, but can be backed by CPU, GPU, or TPU memory and is immutable. JAX and NumPy Arrays -------------------- SCICO and JAX functions can be applied directly to NumPy arrays without explicit conversion to JAX arrays, but this is not recommended, as it can result in repeated data transfers from the CPU to GPU. 
Consider this toy example on a system with a GPU present: :: x = np.random.randn(8) # Array on host A = np.random.randn(8, 8) # Array on host y = snp.dot(A, x) # A, x transferred to GPU # y resides on GPU z = y + x # x must be transferred to GPU again The unnecessary transfer can be avoided by first converting ``A`` and ``x`` to JAX arrays: :: x = np.random.randn(8) # array on host A = np.random.randn(8, 8) # array on host x = jax.device_put(x) # transfer to GPU A = jax.device_put(A) y = snp.dot(A, x) # no transfer needed z = y + x # no transfer needed We recommend that input data be converted to JAX arrays via :func:`jax.device_put` before calling any SCICO optimizers. On a multi-GPU system, :func:`jax.device_put` can place data on a specific GPU. See the `JAX notes on data placement `_. JAX Arrays are Immutable ------------------------ Unlike standard NumPy arrays, JAX arrays are immutable: once they have been created, they cannot be changed. This prohibits in-place updating of JAX arrays. JAX provides special syntax for updating individual array elements through the `indexed update operators `_. ================================================ FILE: docs/source/overview.rst ================================================ Overview ======== `Scientific Computational Imaging Code (SCICO) `__ is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization algorithms that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. When solving a problem, these components are combined in a way that makes code for optimization routines look like the pseudocode in scientific papers.
SCICO is built on top of `JAX `__ rather than `NumPy `__, enabling GPU/TPU acceleration, just-in-time compilation, and automatic gradient functionality, which is used to automatically compute the adjoints of linear operators. An example of how to solve a multi-channel tomography problem with SCICO is shown in the figure below. .. image:: /figures/scico-tomo-overview.png :align: center :width: 95% :alt: Solving a multi-channel tomography problem with SCICO. | The SCICO source code is available from `GitHub `__, and pre-built packages are available from `PyPI `__. (Detailed instructions for installing SCICO are available in :ref:`installing`.) It has extensive `online documentation `__, including :doc:`API documentation <_autosummary/scico>` and :ref:`usage examples `, which can be run online at `Google Colab `__ and `binder `__. If you use this package for published work, please cite :cite:`balke-2022-scico` (see bibtex entry ``balke-2022-scico`` in `docs/source/references.bib `_ in the source distribution). Contributing ------------ Bug reports, feature requests, and general suggestions are welcome, and should be submitted via the `GitHub issue system `__. More substantial contributions are also :ref:`welcome `. License ------- SCICO is distributed as open-source software under a BSD 3-Clause License (see the `LICENSE `__ file for details). LANL open source approval reference C20091. © 2020-2025. Triad National Security, LLC. All rights reserved. This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. Department of Energy/National Nuclear Security Administration. All rights in the program are reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear Security Administration. 
The Government has granted for itself and others acting on its behalf a nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so. ================================================ FILE: docs/source/pyfigures/cylindgrad.py ================================================ import numpy as np import scico.linop as scl from scico import plot input_shape = (7, 7, 7) centre = (np.array(input_shape) - 1) / 2 end = np.array(input_shape) - centre g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]] cg = scl.CylindricalGradient(input_shape=input_shape) ang = cg.coord[0] rad = cg.coord[1] axi = cg.coord[2] theta = np.arctan2(g0, g1) clr = theta # See https://stackoverflow.com/a/49888126 clr = (clr.ravel() - clr.min()) / np.ptp(clr) clr = np.concatenate((clr, np.repeat(clr, 2))) clr = plot.plt.cm.plasma(clr) plot.plt.rcParams["savefig.transparent"] = True fig = plot.plt.figure(figsize=(20, 6)) ax = fig.add_subplot(1, 3, 1, projection="3d") ax.quiver(g0, g1, g2, ang[0], ang[1], ang[2], colors=clr, length=0.9) ax.set_title("Angular local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 2, projection="3d") ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9) ax.set_title("Radial local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 3, projection="3d") ax.quiver(g0, g1, g2, axi[0], axi[1], axi[2], colors=clr[0], length=0.9) ax.set_title("Axial local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) 
ax.tick_params(labelsize=15) fig.tight_layout() fig.show() ================================================ FILE: docs/source/pyfigures/polargrad.py ================================================ import numpy as np import scico.linop as scl from scico import plot input_shape = (21, 21) centre = (np.array(input_shape) - 1) / 2 end = np.array(input_shape) - centre g0, g1 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1]] pg = scl.PolarGradient(input_shape=input_shape) ang = pg.coord[0] rad = pg.coord[1] clr = (np.arctan2(ang[1], ang[0]) + np.pi) / (2 * np.pi) plot.plt.rcParams["image.cmap"] = "plasma" plot.plt.rcParams["savefig.transparent"] = True fig, ax = plot.plt.subplots(nrows=1, ncols=2, figsize=(13, 6)) ax[0].quiver(g0, g1, ang[0], ang[1], clr) ax[0].set_title("Angular local coordinate axis", fontsize=16) ax[0].set_xlabel("$x$", fontsize=14) ax[0].set_ylabel("$y$", fontsize=14) ax[0].tick_params(labelsize=14) ax[0].xaxis.set_ticks((-10, -5, 0, 5, 10)) ax[0].yaxis.set_ticks((-10, -5, 0, 5, 10)) ax[1].quiver(g0, g1, rad[0], rad[1], clr) ax[1].set_title("Radial local coordinate axis", fontsize=16) ax[1].set_xlabel("$x$", fontsize=14) ax[1].set_ylabel("$y$", fontsize=14) ax[1].tick_params(labelsize=14) ax[1].xaxis.set_ticks((-10, -5, 0, 5, 10)) ax[1].yaxis.set_ticks((-10, -5, 0, 5, 10)) fig.tight_layout() fig.show() ================================================ FILE: docs/source/pyfigures/spheregrad.py ================================================ import numpy as np import scico.linop as scl from scico import plot input_shape = (7, 7, 7) centre = (np.array(input_shape) - 1) / 2 end = np.array(input_shape) - centre g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]] sg = scl.SphericalGradient(input_shape=input_shape) azi = sg.coord[0] pol = sg.coord[1] rad = sg.coord[2] theta = np.arctan2(g0, g1) phi = np.arctan2(np.sqrt(g0**2 + g1**2), g2) clr = theta * phi # See https://stackoverflow.com/a/49888126 clr = (clr.ravel() - 
clr.min()) / np.ptp(clr) clr = np.concatenate((clr, np.repeat(clr, 2))) clr = plot.plt.cm.plasma(clr) plot.plt.rcParams["savefig.transparent"] = True fig = plot.plt.figure(figsize=(20, 6)) ax = fig.add_subplot(1, 3, 1, projection="3d") ax.quiver(g0, g1, g2, azi[0], azi[1], azi[2], colors=clr, length=0.9) ax.set_title("Azimuthal local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 2, projection="3d") ax.quiver(g0, g1, g2, pol[0], pol[1], pol[2], colors=clr, length=0.9) ax.set_title("Polar local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) ax.tick_params(labelsize=15) ax = fig.add_subplot(1, 3, 3, projection="3d") ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9) ax.set_title("Radial local coordinate axis", fontsize=18) ax.set_xlabel("$x$", fontsize=15) ax.set_ylabel("$y$", fontsize=15) ax.set_zlabel("$z$", fontsize=15) ax.tick_params(labelsize=15) fig.tight_layout() fig.show() ================================================ FILE: docs/source/pyfigures/xray_2d_geom.py ================================================ import numpy as np import matplotlib as mpl import matplotlib.patches as patches import matplotlib.pyplot as plt mpl.rcParams["savefig.transparent"] = True c = 1.0 / np.sqrt(2.0) e = 1e-2 style = "Simple, tail_width=0.5, head_width=4, head_length=8" fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(21, 7)) # all plots for n in range(3): ax[n].set_aspect(1.0) ax[n].set_xlim(-1.1, 1.1) ax[n].set_ylim(-1.1, 1.1) ax[n].set_xticks(np.linspace(-1.0, 1.0, 5)) ax[n].set_yticks(np.linspace(-1.0, 1.0, 5)) ax[n].tick_params(axis="x", labelsize=14) ax[n].tick_params(axis="y", labelsize=14) ax[n].set_xlabel("axis 1", fontsize=16) ax[n].set_ylabel("axis 0", fontsize=16) # scico ax[0].set_title("scico", fontsize=18) plist = 
[ patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((-c, -c), (-c / 2.0, -c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch( ( 0.0, -1.0, ), (0.0, -0.5), arrowstyle=style, color="r", ), patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=180, theta2=-45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c - e, -c - e), (c + e, -c + e), arrowstyle=style, color="b"), ] for p in plist: ax[0].add_patch(p) ax[0].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=16) ax[0].text(-3 * c / 4 - 0.01, -3 * c / 4 - 0.1, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) ax[0].text(0.03, -0.8, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) ax[0].plot((1.0, 1.0), (-0.375, 0.375), color="orange", lw=2) ax[0].arrow( 0.94, 0.375, 0.0, -0.75, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[0].text(0.7, 0.0, r"$\theta=0$", color="orange", ha="left", fontsize=16) ax[0].plot((-0.375, 0.375), (1.0, 1.0), color="orange", lw=2) ax[0].arrow( -0.375, 0.94, 0.75, 0.0, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[0].text(0.0, 0.82, r"$\theta=\frac{\pi}{2}$", color="orange", ha="center", fontsize=16) # astra ax[1].set_title("astra", fontsize=18) plist = [ patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"), patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"), patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"), ] for p in plist: ax[1].add_patch(p) ax[1].text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=16) ax[1].text(3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) ax[1].text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) ax[1].plot((-0.375, 
0.375), (1.0, 1.0), color="orange", lw=2) ax[1].arrow( -0.375, 0.94, 0.75, 0.0, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[1].text(0.0, 0.82, r"$\theta=0$", color="orange", ha="center", fontsize=16) ax[1].plot((-1.0, -1.0), (-0.375, 0.375), color="orange", lw=2) ax[1].arrow( -0.94, -0.375, 0.0, 0.75, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[1].text(-0.9, 0.0, r"$\theta=\frac{\pi}{2}$", color="orange", ha="left", fontsize=16) # svmbir ax[2].set_title("svmbir", fontsize=18) plist = [ patches.FancyArrowPatch((-1.0, 0.0), (-0.5, 0.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((-c, c), (-c / 2.0, c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch( ( 0.0, 1.0, ), (0.0, 0.5), arrowstyle=style, color="r", ), patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=45, theta2=180, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c - e, c + e), (c + e, c - e), arrowstyle=style, color="b"), ] for p in plist: ax[2].add_patch(p) ax[2].text(-0.88, 0.02, r"$\theta=0$", color="r", fontsize=16) ax[2].text(-3 * c / 4 + 0.01, 3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=16) ax[2].text(0.03, 0.75, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=16) ax[2].plot((1.0, 1.0), (-0.375, 0.375), color="orange", lw=2) ax[2].arrow( 0.94, 0.375, 0.0, -0.75, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[2].text(0.7, 0.0, r"$\theta=0$", color="orange", ha="left", fontsize=16) ax[2].plot((-0.375, 0.375), (-1.0, -1.0), color="orange", lw=2) ax[2].arrow( 0.375, -0.94, -0.75, 0.0, color="orange", lw=1.0, ls="--", head_width=0.03, length_includes_head=True, ) ax[2].text(0.0, -0.82, r"$\theta=\frac{\pi}{2}$", color="orange", ha="center", fontsize=16) fig.tight_layout() fig.show() ================================================ FILE: docs/source/pyfigures/xray_3d_ang.py ================================================ import numpy as np import matplotlib 
as mpl import matplotlib.patches as patches import matplotlib.pyplot as plt mpl.rcParams["savefig.transparent"] = True c = 1.0 / np.sqrt(2.0) e = 1e-2 style = "Simple, tail_width=0.5, head_width=4, head_length=8" fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5)) ax.set_aspect(1.0) ax.set_xlim(-1.1, 1.1) ax.set_ylim(-1.1, 1.1) ax.set_xticks(np.linspace(-1.0, 1.0, 5)) ax.set_yticks(np.linspace(-1.0, 1.0, 5)) ax.tick_params(axis="x", labelsize=12) ax.tick_params(axis="y", labelsize=12) ax.set_xlabel("$x$", fontsize=14) ax.set_ylabel("$y$", fontsize=14) plist = [ patches.FancyArrowPatch((0.0, -1.0), (0.0, -0.5), arrowstyle=style, color="r"), patches.FancyArrowPatch((c, -c), (c / 2.0, -c / 2.0), arrowstyle=style, color="r"), patches.FancyArrowPatch((1.0, 0.0), (0.5, 0.0), arrowstyle=style, color="r"), patches.Arc((0.0, 0.0), 2.0, 2.0, theta1=-90, theta2=45.0, color="b", lw=2, ls="dotted"), patches.FancyArrowPatch((c + e, c - e), (c - e, c + e), arrowstyle=style, color="b"), ] for p in plist: ax.add_patch(p) ax.text(0.02, -0.75, r"$\theta=0$", color="r", fontsize=14) ax.text( 3 * c / 4 + 0.01, -3 * c / 4 + 0.01, r"$\theta=\frac{\pi}{4}$", color="r", fontsize=14, ) ax.text(0.65, 0.05, r"$\theta=\frac{\pi}{2}$", color="r", fontsize=14) ax.plot((-0.375, 0.375), (1.0, 1.0), color="orange", lw=2) ax.arrow( -0.375, 0.94, 0.75, 0.0, color="orange", lw=0.5, ls="--", head_width=0.03, length_includes_head=True, ) ax.text(0.0, 0.82, r"$\theta=0$", color="orange", ha="center", fontsize=14) ax.plot((-1.0, -1.0), (-0.375, 0.375), color="orange", lw=2) ax.arrow( -0.94, -0.375, 0.0, 0.75, color="orange", lw=0.5, ls="--", head_width=0.03, length_includes_head=True, ) ax.text(-0.9, 0.0, r"$\theta=\frac{\pi}{2}$", color="orange", ha="left", fontsize=14) fig.tight_layout() fig.show() ================================================ FILE: docs/source/pyfigures/xray_3d_vec.py ================================================ import numpy as np import matplotlib as mpl from matplotlib 
import pyplot as plt from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d mpl.rcParams["savefig.transparent"] = True # See https://github.com/matplotlib/matplotlib/issues/21688 class Arrow3D(FancyArrowPatch): def __init__(self, xs, ys, zs, *args, **kwargs): FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs def do_3d_projection(self, renderer=None): xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) return np.min(zs) # Define vector components 𝜃 = 10 * np.pi / 180.0 # angle in x-y plane (azimuth angle) 𝛼 = 70 * np.pi / 180.0 # angle with z axis (zenith angle) 𝛥p, 𝛥d = 0.3, 1.0 d = (-𝛥d * np.sin(𝛼) * np.sin(𝜃), 𝛥d * np.sin(𝛼) * np.cos(𝜃), 𝛥d * np.cos(𝛼)) u = (𝛥p * np.cos(𝜃), 𝛥p * np.sin(𝜃), 0.0) v = (𝛥p * np.cos(𝛼) * np.sin(𝜃), -𝛥p * np.cos(𝛼) * np.cos(𝜃), 𝛥p * np.sin(𝛼)) # Location of text labels d_txtpos = np.array(d) + np.array([0, 0, -0.12]) u_txtpos = np.array(d) + np.array(u) + np.array([0, 0, -0.1]) v_txtpos = np.array(d) + np.array(v) + np.array([0, 0, 0.03]) arrowstyle = "-|>,head_width=2.5,head_length=9" fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) # Set view ax.set_aspect("equal") ax.elev = 15 ax.azim = -50 ax.set_box_aspect(None, zoom=2) ax.set_xlim((-1.1, 1.1)) ax.set_ylim((-1.1, 1.1)) ax.set_zlim((-1.1, 1.1)) # Disable shaded 3d axis grids ax.set_axis_off() # Draw central x,y,z axes and labels axis_crds = np.array([[-1, 1], [0, 0], [0, 0]]) axis_lbls = ("$x$", "$y$", "$z$") for k in range(3): crd = np.roll(axis_crds, k, axis=0) ax.add_artist( Arrow3D( *crd.tolist(), lw=1.5, ls="--", arrowstyle=arrowstyle, color="black", ) ) ax.text(*(1.05 * crd[:, 1]).tolist(), axis_lbls[k], fontsize=12) # Draw d, u, v and labels ax.quiver(0, 0, 0, *d, arrow_length_ratio=0.08, lw=2, color="blue") ax.quiver(*d, *u, arrow_length_ratio=0.08 / 𝛥p, lw=2, color="blue") ax.quiver(*d, *v, 
arrow_length_ratio=0.08 / 𝛥p, lw=2, color="blue") ax.text(*d_txtpos, r"$\mathbf{d}$", fontsize=12) ax.text(*u_txtpos, r"$\mathbf{u}$", fontsize=12) ax.text(*v_txtpos, r"$\mathbf{v}$", fontsize=12) fig.tight_layout() fig.subplots_adjust(-0.1, -0.06, 1, 1) fig.show() ================================================ FILE: docs/source/pyfigures/xray_3d_vol.py ================================================ import numpy as np import matplotlib as mpl from matplotlib import pyplot as plt from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d mpl.rcParams["savefig.transparent"] = True # See https://github.com/matplotlib/matplotlib/issues/21688 class Arrow3D(FancyArrowPatch): def __init__(self, xs, ys, zs, *args, **kwargs): FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs def do_3d_projection(self, renderer=None): xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) return np.min(zs) # Define vector components 𝜃 = 10 * np.pi / 180.0 # angle in x-y plane (azimuth angle) 𝛼 = 70 * np.pi / 180.0 # angle with z axis (zenith angle) 𝛥p, 𝛥d = 0.3, 1.0 d = (-𝛥d * np.sin(𝛼) * np.sin(𝜃), 𝛥d * np.sin(𝛼) * np.cos(𝜃), 𝛥d * np.cos(𝛼)) u = (𝛥p * np.cos(𝜃), 𝛥p * np.sin(𝜃), 0.0) v = (𝛥p * np.cos(𝛼) * np.sin(𝜃), -𝛥p * np.cos(𝛼) * np.cos(𝜃), 𝛥p * np.sin(𝛼)) # Location of text labels d_txtpos = np.array(d) + np.array([0, 0, -0.12]) u_txtpos = np.array(d) + np.array(u) + np.array([0, 0, -0.1]) v_txtpos = np.array(d) + np.array(v) + np.array([0, 0, 0.03]) arrowstyle = "-|>,head_width=2.5,head_length=9" fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) # Set view ax.set_aspect("equal") ax.elev = 40 ax.azim = -60 ax.set_box_aspect(None, zoom=1.8) ax.set_xlim((-10.5, 10.5)) ax.set_ylim((-10.5, 10.5)) ax.set_zlim((-10.5, 10.5)) # Disable shaded 3d axis grids ax.set_axis_off() # Draw central x,y,z axes and labels axis_crds = 
np.array([[-10, 10], [0, 0], [0, 0]]) axis_lbls = ("$x$", "$y$", "$z$") for k in range(3): crd = np.roll(axis_crds, k, axis=0) ax.add_artist( Arrow3D( *crd.tolist(), lw=1.5, ls="--", arrowstyle=arrowstyle, color="black", ) ) ax.text(*(1.05 * crd[:, 1]).tolist(), axis_lbls[k], fontsize=12) wx = 4 wy = 3 wz = 2 bx = np.array([-wx, wx, wx, wx, -wx, -wx, -wx]) by = np.array([-wy, -wy, wy, wy, wy, -wy, -wy]) bz = np.array([-wz, -wz, -wz, wz, wz, wz, -wz]) ax.plot(bx, by, bz, lw=2, color="blue") ax.plot(bx[0:3], by[0:3], -bz[0:3], lw=2, color="blue") bx = np.array([wx, wx]) by = np.array([-wy, -wy]) bz = np.array([-wz, wz]) ax.plot(bx, by, bz, lw=2, color="blue") bx = np.array([-wx, -wx, wx]) by = np.array([-wy, wy, wy]) bz = np.array([-wz, -wz, -wz]) ax.plot(bx, by, bz, lw=2, ls="--", color="blue") bx = np.array([-wx, -wx]) by = np.array([wy, wy]) bz = np.array([-wz, wz]) ax.plot(bx, by, bz, lw=2, ls="--", color="blue") fig.tight_layout() fig.subplots_adjust(-0.1, -0.1, 1, 1.07) fig.show() ================================================ FILE: docs/source/references.bib ================================================ @Article {aggarwal-2019-modl, author = {Aggarwal, Hemant K. and Mani, Merry P. and Jacob, Mathews}, journal = {IEEE Transactions on Medical Imaging}, title = {{MoDL}: Model-Based Deep Learning Architecture for Inverse Problems}, year = 2019, volume = 38, number = 2, pages = {394--405}, doi = {10.1109/TMI.2018.2865356} } @Article {alliney-1992-digital, author = {Alliney, Stefano}, journal = {IEEE Transactions on Signal Processing}, title = {Digital filters as absolute norm regularizers}, year = 1992, volume = 40, number = 6, pages = {1548--1562}, doi = {10.1109/78.139258}, month = Jun } @Article {almeida-2013-deconvolving, author = {Almeida, Mariana S. C. 
and Figueiredo, M\'ario}, journal = {IEEE Transactions on Image Processing}, title = {Deconvolving Images With Unknown Boundaries Using the Alternating Direction Method of Multipliers}, year = 2013, month = Aug, volume = 22, number = 8, pages = {3074--3086}, doi = {10.1109/TIP.2013.2258354} } @Article {antipa-2018-diffusercam, author = {Nick Antipa and Grace Kuo and Reinhard Heckel and Ben Mildenhall and Emrah Bostan and Ren Ng and Laura Waller}, title = {{DiffuserCam}: lensless single-exposure 3{D} imaging}, journal = {Optica}, year = 2018, month = Jan, volume = 5, number = 1, doi = {10.1364/optica.5.000001}, pages = {1--9} } @Article {balke-2022-scico, author = {Thilo Balke and Fernando Davis and Cristina Garcia-Cardona and Soumendu Majee and Michael McCann and Luke Pfister and Brendt Wohlberg}, title = {Scientific Computational Imaging Code ({SCICO})}, journal = {Journal of Open Source Software}, year = 2022, volume = 7, number = 78, pages = 4722, doi = {10.21105/joss.04722} } @Article {barzilai-1988-stepsize, author = {Jonathan Barzilai and Jonathan M. Borwein}, title = {Two-point step size gradient methods}, journal = {{IMA} Journal of Numerical Analysis}, volume = 8, pages = {141--148}, year = 1988, month = Jan, doi = {10.1093/imanum/8.1.141} } @Article {beck-2009-fast, title = {A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems}, author = {Beck, Amir and Teboulle, Marc}, journal = {SIAM Journal on Imaging Sciences}, year = 2009, volume = 2, number = 1, pages = {183--202}, doi = {10.1137/080716542} } @Article {beck-2009-tv, title = {Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems}, author = {Beck, Amir and Teboulle, Marc}, journal = {IEEE Transactions on Image Processing}, year = 2009, month = Nov, volume = 18, number = 11, pages = {2419--2434}, doi = {10.1109/TIP.2009.2028250} } @InCollection {beck-2010-gradient, author = {Amir Beck and Marc Teboulle}, editor = {Daniel P. 
Palomar and Yonina C. Eldar}, title = {Gradient-based algorithms with applications to signal-recovery problems}, booktitle = {Convex Optimization in Signal Processing and Communications}, pages = {42--88}, publisher = {Cambridge University Press}, year = 2010, doi = {10.1017/CBO9780511804458.003}, url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} } @Software {bradbury-2018-jax, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, url = {http://github.com/google/jax}, version = {0.2.5}, year = 2018 } @Book {beck-2017-first, title = {First-order methods in optimization}, author = {Beck, Amir}, year = 2017, publisher = {Society for Industrial and Applied Mathematics (SIAM)}, doi = {10.1137/1.9781611974997}, isbn = 1611974984 } @InProceedings {benning-2016-preconditioned, title = {Preconditioned {ADMM} with nonlinear operator constraint}, author = {Benning, Martin and Knoll, Florian and Sch{\"o}nlieb, Carola-Bibiane and Valkonen, Tuomo}, booktitle = {IFIP Conference on System Modeling and Optimization (CSMO) 2015}, pages = {117--126}, year = 2016, doi = {10.1007/978-3-319-55795-3_10} } @Article {boyd-2010-distributed, title = {Distributed optimization and statistical learning via the alternating direction method of multipliers}, author = {Boyd, Stephen and Parikh, Neal and Chu, Eric and Peleato, Borja and Eckstein, Jonathan}, journal = {Foundations and Trends in Machine Learning}, year = 2010, volume = 3, number = 1, pages = {1--122}, doi = {10.1561/2200000016} } @Article {buzzard-2018-plug, title = {Plug-and-play unplugged: Optimization-free reconstruction using consensus equilibrium}, author = {Buzzard, Gregery T. and Chan, Stanley H. 
and Sreehari, Suhas and Bouman, Charles A.}, journal = {SIAM Journal on Imaging Sciences}, volume = 11, number = 3, pages = {2001--2020}, year = 2018, doi = {10.1137/17M1122451} } @Article {cai-2010-singular, title = {A Singular Value Thresholding Algorithm for Matrix Completion}, author = {Cai, Jian-Feng and Cand{\`e}s, Emmanuel J. and Shen, Zuowei}, journal = {SIAM Journal on Optimization}, year = 2010, volume = 20, number = 4, pages = {1956--1982}, doi = {10.1137/080738970} } @Article {chambolle-2010-firstorder, author = {Antonin Chambolle and Thomas Pock}, title = {A First-Order Primal-Dual Algorithm for Convex Problems with~Applications to Imaging}, journal = {Journal of Mathematical Imaging and Vision}, doi = {10.1007/s10851-010-0251-1}, year = 2010, month = Dec, volume = 40, number = 1, pages = {120--145} } @Misc {chandler-2024-closedform, author = {Edward P. Chandler and Shirin Shoushtari and Brendt Wohlberg and Ulugbek S. Kamilov}, title = {Closed-Form Approximation of the Total Variation Proximal Operator}, year = 2024, eprint = {2412.07718} } @Article {clinthorne-1993-preconditioning, author = {Clinthorne, Neal H. and Pan, Tin-Su and Chiao, Ping-Chun and Rogers, W. Leslie and Stamos, John A.}, title = {Preconditioning methods for improved convergence rates in iterative reconstructions}, journal = {IEEE Transactions on Medical Imaging}, year = 1993, volume = 12, number = 1, pages = {78--83}, month = Mar, doi = {10.1109/42.222670} } @InProceedings {dabov-2008-image, author = {Kostadin Dabov and Alessandro Foi and Vladimir Katkovnik and Karen Egiazarian}, title = {Image restoration by sparse {3D} transform-domain collaborative filtering}, volume = 6812, booktitle = {Image Processing: Algorithms and Systems VI}, editor = {Jaakko T. Astola and Karen O. Egiazarian and Edward R. 
Dougherty}, organization = {International Society for Optics and Photonics}, publisher = {SPIE}, pages = {62--73}, year = 2008, month = Mar, doi = {10.1117/12.766355} } @Article {daubechies-2004-iterative, title = {An iterative thresholding algorithm for linear inverse problems with a sparsity constraint}, author = {Daubechies, Ingrid and Defrise, Michel and De Mol, Christine}, journal = {Communications on Pure and Applied Mathematics}, volume = 57, number = 11, pages = {1413--1457}, year = 2004, doi = {10.1002/cpa.20042} } @Article {deng-2015-global, author = {Wei Deng and Wotao Yin}, title = {On the Global and Linear Convergence of the Generalized Alternating Direction Method of Multipliers}, journal = {Journal of Scientific Computing}, year = 2015, month = May, volume = 66, number = 3, pages = {889--916}, doi = {10.1007/s10915-015-0048-x}, } @Misc {diamond-2018-odp, author = {Steven Diamond and Vincent Sitzmann and Felix Heide and Gordon Wetzstein}, title = {Unrolled Optimization with Deep Priors}, year = 2018, eprint = {1705.08041v2} } @Article {esser-2010-general, author = {Ernie Esser and Xiaoqun Zhang and Tony F. Chan}, title = {A General Framework for a Class of First Order Primal-Dual Algorithms for Convex Optimization in Imaging Science}, journal = {SIAM Journal on Imaging Sciences}, doi = {10.1137/09076934x}, year = 2010, month = Jan, volume = 3, number = 4, pages = {1015--1046} } @PhDThesis {esser-2010-primal, author = {Ernie Esser}, title = {Primal Dual Algorithms for Convex Models and Applications to Image Restoration, Registration and Nonlocal Inpainting}, school = {University of California Los Angeles}, year = 2010 } @InProceedings {florea-2017-robust, title = {A Robust {FISTA}-Like Algorithm}, author = {Mihai I. Florea and Sergiy A. 
Vorobyov}, booktitle = {Proceedings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, year = 2017, month = Mar, pages = {4521--4525}, doi = {10.1109/ICASSP.2017.7953012}, location = {New Orleans, LA, USA} } @Article {gabay-1976-dual, title = {A dual algorithm for the solution of nonlinear variational problems via finite element approximation}, author = {Gabay, Daniel and Mercier, Bertrand}, journal = {Computers \& Mathematics with Applications}, volume = 2, number = 1, pages = {17--40}, year = 1976, doi = {10.1016/0898-1221(76)90003-1} } @Article {glowinski-1975-approximation, title = {Sur l'approximation, par {\'e}l{\'e}ments finis d'ordre un, et la r{\'e}solution, par p{\'e}nalisation-dualit{\'e} d'une classe de probl{\`e}mes de Dirichlet non lin{\'e}aires}, author = {Glowinski, Roland and Marroco, Americo}, journal = {ESAIM: Mathematical Modelling and Numerical Analysis - Mod{\'e}lisation Math{\'e}matique et Analyse Num{\'e}rique}, volume = 9, number = {R2}, pages = {41--76}, year = 1975, url = {http://eudml.org/doc/193269} } @Article {goldstein-2009-split, author = {Tom Goldstein and Stanley Osher}, title = {The Split {B}regman Method for L1-Regularized Problems}, journal = {SIAM Journal on Imaging Sciences}, volume = 2, number = 2, pages = {323--343}, year = 2009, doi = {10.1137/080725891} } @Misc {goldstein-2014-fasta, title = {A Field Guide to Forward-Backward Splitting with a {FASTA} Implementation}, author = {Tom Goldstein and Christoph Studer and Richard Baraniuk}, year = 2014, eprint = {1411.3406}, url = {http://arxiv.org/abs/1411.3406}, } @Book {goodman-2005-fourier, author = {Goodman, Joseph W.}, title = {Introduction to {F}ourier Optics}, publisher = {McGraw-Hill}, year = 2005, isbn = 9780974707723, edition = 3 } @Misc {hossein-2024-total, title = {Total Variation Regularization for Tomographic Reconstruction of Cylindrically Symmetric Objects}, author = {Maliha Hossain and Charles A. 
Bouman and Brendt Wohlberg}, year = 2024, eprint = {2406.17928} } @Article {hoyer-2004-nonnegative, title = {Non-negative matrix factorization with sparseness constraints}, author = {Patrik O. Hoyer}, journal = {Journal of Machine Learning Research}, volume = 5, number = Nov, pages = {1457--1469}, year = 2004, url = {https://www.jmlr.org/papers/volume5/hoyer04a/hoyer04a.pdf} } @Article {huber-1964-robust, doi = {10.1214/aoms/1177703732}, year = 1964, month = Mar, volume = 35, number = 1, pages = {73--101}, author = {Peter J. Huber}, title = {Robust Estimation of a Location Parameter}, journal = {The Annals of Mathematical Statistics} } @Article {jin-2017-unet, title = {Deep Convolutional Neural Network for Inverse Problems in Imaging}, author = {Kyong Hwan Jin and Michael T. McCann and Emmanuel Froustey and Michael Unser}, journal = {IEEE Transactions on Image Processing}, volume = 26, number = 9, pages = {4509--4522}, year = 2017, doi = {10.1109/TIP.2017.2713099} } @Book {kak-1988-principles, author = {Avinash C. Kak and Malcolm Slaney}, title = {Principles of Computerized Tomographic Imaging}, publisher = {IEEE Press}, year = 1988 } @TechReport {kamilov-2016-minimizing, author = {Ulugbek S. Kamilov}, title = {Minimizing Isotropic Total Variation without Subiterations}, institution = {Mitsubishi Electric Research Laboratories (MERL)}, year = 2016, number = {TR2016-109}, month = Aug, note = {Presented at International Traveling Workshop on Interactions Between Sparse Models and Technology (iTWIST) 2016}, url = {https://www.merl.com/publications/docs/TR2016-109.pdf} } @Article {kamilov-2016-parallel, title = {A parallel proximal algorithm for anisotropic total variation minimization}, author = {Ulugbek S. Kamilov}, journal = {IEEE Transactions on Image Processing}, volume = 26, number = 2, pages = {539--548}, year = 2016, doi = {10.1109/tip.2016.2629449} } @Article {kamilov-2017-plugandplay, author = {Ulugbek S.
Kamilov and Hassan Mansour and Brendt Wohlberg}, title = {A Plug-and-Play Priors Approach for Solving Nonlinear Imaging Inverse Problems}, year = 2017, month = Dec, journal = {IEEE Signal Processing Letters}, volume = 24, number = 12, doi = {10.1109/LSP.2017.2763583}, pages = {1872--1876} } @Article {kamilov-2023-plugandplay, author = {Ulugbek S. Kamilov and Charles A. Bouman and Gregery T. Buzzard and Brendt Wohlberg}, title = {Plug-and-Play Methods for Integrating Physical and Learned Models in Computational Imaging}, journal = {IEEE Signal Processing Magazine}, year = 2023, month = Jan, volume = 40, number = 1, pages = {85--97}, doi = {10.1109/MSP.2022.3199595} } @Article {liu-2018-first, author = {Jialin Liu and Cristina Garcia-Cardona and Brendt Wohlberg and Wotao Yin}, title = {First and Second Order Methods for Online Convolutional Dictionary Learning}, journal = {SIAM Journal on Imaging Sciences}, year = 2018, volume = 11, number = 2, pages = {1589--1628}, doi = {10.1137/17M1145689}, eprint = {1709.00106} } @Article {lou-2018-fast, title = {Fast {L1-L2} Minimization via a Proximal Operator}, author = {Yifei Lou and Ming Yan}, journal = {Journal of Scientific Computing}, volume = 74, number = 2, pages = {767--785}, year = 2018, doi = {10.1007/s10915-017-0463-2} } @Article {maggioni-2012-nonlocal, title = {Nonlocal transform-domain filter for volumetric data denoising and reconstruction}, author = {Maggioni, Matteo and Katkovnik, Vladimir and Egiazarian, Karen and Foi, Alessandro}, journal = {IEEE Transactions on Image Processing}, volume = 22, number = 1, pages = {119--133}, year = 2012, doi = {10.1109/TIP.2012.2210725} } @InProceedings {makinen-2019-exact, author = {Ymir M\"akinen and Lucio Azzari and Alessandro Foi}, booktitle = {IEEE International Conference on Image Processing (ICIP)}, title = {Exact Transform-Domain Noise Variance for Collaborative Filtering of Stationary Correlated Noise}, year = 2019, pages = {185--189}, doi = 
{10.1109/ICIP.2019.8802964}, month = Sep } @Article {menon-2007-demosaicing, title = {Demosaicing With Directional Filtering and a posteriori Decision}, author = {Daniele Menon and Stefano Andriani and Giancarlo Calvagno}, journal = {IEEE Transactions on Image Processing}, year = 2007, month = Jan, volume = 16, number = 1, pages = {132--141}, doi = {10.1109/tip.2006.884928} } @Article {monga-2021-algorithm, author = {Monga, Vishal and Li, Yuelong and Eldar, Yonina C.}, journal = {IEEE Signal Processing Magazine}, title = {Algorithm Unrolling: Interpretable, Efficient Deep Learning for Signal and Image Processing}, year = 2021, volume = 38, number = 2, pages = {18--44}, doi = {10.1109/MSP.2020.3016905} } @Book {nocedal-2006-numerical, title = {Numerical Optimization}, author = {Jorge Nocedal and Stephen J. Wright}, year = 2006, publisher = {Springer}, doi = {10.1007/978-0-387-40065-5}, isbn = 9780387303031 } @Article {olufsen-2019-axitom, title = {{AXITOM}: A {P}ython package for reconstruction of axisymmetric tomograms acquired by a conical beam}, volume = 4, doi = {10.21105/joss.01704}, number = 42, journal = {Journal of Open Source Software}, author = {Olufsen, Sindre}, year = 2019, month = Oct, pages = {1704} } @Book {paganin-2006-coherent, doi = {10.1093/acprof:oso/9780198567288.001.0001}, isbn = 9780198567288, year = 2006, month = Jan, publisher = {Oxford University Press}, author = {David Paganin}, title = {Coherent X-Ray Optics} } @Article {parikh-2014-proximal, title = {Proximal algorithms}, author = {Parikh, Neal and Boyd, Stephen}, journal = {Foundations and Trends in Optimization}, volume = 1, number = 3, pages = {127--239}, year = 2014, doi = {10.1561/2400000003} } @InProceedings {pock-2011-diagonal, author = {Thomas Pock and Antonin Chambolle}, title = {Diagonal preconditioning for first order primal-dual algorithms in convex optimization}, booktitle = {Proceedings of the International Conference on Computer Vision (ICCV)}, doi =
{10.1109/iccv.2011.6126441}, pages = {1762--1769}, year = 2011, month = Nov, address = {Barcelona, Spain} } @Misc {pyabel-2022, author = {Stephen Gibson and Daniel Hickstein and Roman Yurchak and Mikhail Ryazanov and Dhrubajyoti Das and Gilbert Shih}, title = {PyAbel}, howpublished = {PyAbel/PyAbel: v0.8.5}, year = 2022, doi = {10.5281/zenodo.5888391} } @InProceedings {ronneberger-2015-unet, author = {Olaf Ronneberger and Philipp Fischer and Thomas Brox}, title = {{U}-{N}et: Convolutional Networks for Biomedical Image Segmentation}, booktitle = {Proceedings of the 18th International Conference on Medical Image Computing and Computer-Assisted Intervention}, doi = {10.1007/978-3-319-24574-4_28}, volume = 9351, pages = {234--241}, year = 2015, month = Oct, address = {Munich, Germany}, } @Article {rudin-1992-nonlinear, author = {Leonid I. Rudin and Stanley Osher and Emad Fatemi}, title = {Nonlinear total variation based noise removal algorithms}, journal = {Physica D: Nonlinear Phenomena}, volume = 60, number = {1--4}, pages = {259--268}, year = 1992, doi = {10.1016/0167-2789(92)90242-F} } @Article {sauer-1993-local, title = {A local update strategy for iterative reconstruction from projections}, author = {Sauer, Ken and Bouman, Charles}, journal = {IEEE Transactions on Signal Processing}, year = 1993, month = Feb, number = 2, pages = {534--548}, volume = 41, doi = {10.1109/78.193196} } @Article {soulez-2016-proximity, author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut and Antony Schutz and Andr{\'{e}} Ferrari and Fr{\'{e}}d{\'{e}}ric Courbin and Michael Unser}, title = {Proximity operators for phase retrieval}, journal = {Applied Optics}, doi = {10.1364/ao.55.007412}, year = 2016, month = Sep, volume = 55, number = 26, pages = {7412--7421} } @Article {sreehari-2016-plug, author = {Suhas Sreehari and Singanallur V. Venkatakrishnan and Brendt Wohlberg and Gregery T. Buzzard and Lawrence F. Drummy and Jeffrey P. Simmons and Charles A. 
Bouman}, title = {Plug-and-Play Priors for Bright Field Electron Tomography and Sparse Interpolation}, year = 2016, month = Dec, journal = {IEEE Transactions on Computational Imaging}, volume = 2, number = 4, doi = {10.1109/TCI.2016.2599778}, pages = {408--423} } @Misc {svmbir-2020, author = {SVMBIR Development Team}, title = {{S}uper-{V}oxel {M}odel {B}ased {I}terative {R}econstruction ({SVMBIR})}, howpublished = {Software library available from \url{https://github.com/cabouman/svmbir}}, year = 2020 } @Article {valkonen-2014-primal, title = {A primal--dual hybrid gradient method for nonlinear operators with applications to {MRI}}, author = {Valkonen, Tuomo}, journal = {Inverse Problems}, volume = 30, number = 5, pages = 055012, year = 2014, doi = {10.1088/0266-5611/30/5/055012} } @InProceedings {venkatakrishnan-2013-plugandplay2, author = {Singanallur V. Venkatakrishnan and Charles A. Bouman and Brendt Wohlberg}, title = {Plug-and-Play Priors for Model Based Reconstruction}, year = 2013, month = Dec, booktitle = {Proceedings of IEEE Global Conference on Signal and Information Processing (GlobalSIP)}, address = {Austin, TX, USA}, doi = {10.1109/GlobalSIP.2013.6737048}, pages = {945--948} } @Article {voelz-2009-digital, author = {David G. Voelz and Michael C. 
Roggemann}, title = {Digital Simulation of Scalar Optical Diffraction: Revisiting Chirp Function Sampling Criteria and Consequences}, journal = {Applied Optics}, volume = 48, number = 32, pages = 6132, year = 2009, doi = {10.1364/ao.48.006132}, } @Book {voelz-2011-computational, author = {Voelz, David}, title = {Computational {F}ourier optics : a {MATLAB} tutorial}, year = 2011, publisher = {SPIE Press}, address = {Bellingham, Wash}, isbn = 9780819482044, } @InProceedings {wohlberg-2014-efficient, author = {Brendt Wohlberg}, title = {Efficient Convolutional Sparse Coding}, booktitle = {Proceedings of IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP)}, year = 2014, month = May, doi = {10.1109/ICASSP.2014.6854992}, pages = {7173--7177}, location = {Florence, Italy} } @Article {wohlberg-2021-psf, author = {Brendt Wohlberg and Przemek Wozniak}, title = {PSF Estimation in Crowded Astronomical Imagery as a Convolutional Dictionary Learning Problem}, year = 2021, month = Feb, journal = {IEEE Signal Processing Letters}, volume = 28, doi = {10.1109/LSP.2021.3050706}, pages = {374--378} } @Article {yang-2012-linearized, author = {Junfeng Yang and Xiaoming Yuan}, title = {Linearized augmented {L}agrangian and alternating direction methods for nuclear norm minimization}, journal = {Mathematics of Computation}, doi = {10.1090/s0025-5718-2012-02598-1}, year = 2012, month = Mar, volume = 82, number = 281, pages = {301--329} } @InProceedings {yu-2013-better, author = {Yu, Yao-Liang}, booktitle = {Advances in Neural Information Processing Systems}, editor = {C.J. Burges and L. Bottou and M. Welling and Z. Ghahramani and K.Q. 
Weinberger}, title = {Better Approximation and Faster Algorithm Using the Proximal Average}, url = {https://proceedings.neurips.cc/paper_files/paper/2013/file/49182f81e6a13cf5eaa496d51fea6406-Paper.pdf}, volume = 26, year = 2013 } @Article {zhang-2017-dncnn, author = {Kai Zhang and Wangmeng Zuo and Yunjin Chen and Deyu Meng and Lei Zhang}, title = {Beyond a {G}aussian Denoiser: Residual Learning of Deep {CNN} for Image Denoising}, year = 2017, month = Jul, journal = {IEEE Transactions on Image Processing}, volume = 26, number = 7, doi = {10.1109/TIP.2017.2662206}, pages = {3142--3155} } @Article {zhang-2021-plug, author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, title = {Plug-and-Play Image Restoration With Deep Denoiser Prior}, journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, year = 2022, volume = 44, number = 10, doi = {10.1109/TPAMI.2021.3088914}, pages = {6360--6376} } @Article {zhou-2006-adaptive, author = {Bin Zhou and Li Gao and Yu-Hong Dai}, title = {Gradient Methods with Adaptive Step-Sizes}, year = 2006, month = Mar, journal = {Computational Optimization and Applications}, volume = 35, doi = {10.1007/s10589-006-6446-0}, pages = {69--86} } ================================================ FILE: docs/source/style.rst ================================================ .. _scico_dev_style: Style Guide =========== Overview -------- We adhere to `PEP8 `_ with the exception of allowing a line length limit of 99 characters (as opposed to 79 characters). The standard limit of 72 characters for "flowing long blocks of text" in docstrings or comments is retained. We use `Black `_ as our PEP-8 Formatter and `isort `_ to sort imports. (Please set up a `pre-commit hook `_ to ensure any modified code passes format check before it is committed to the development repo.) We aim to incorporate `PEP 526 `_ type annotations throughout the library. 
See the `Mypy `_ type annotation `cheat sheet `_ for usage examples. Custom types are defined in :mod:`.typing`. Our coding conventions are based on both the `NumPy conventions `_ and the `Google docstring conventions `_. Unicode variable names are allowed for internal usage (e.g. for Greek characters for mathematical symbols), but not as part of the public interface for functions or methods. Naming ------ We follow the `Google naming conventions `_: .. list-table:: Naming Conventions :widths: 20 20 :header-rows: 1 * - Component - Naming Convention * - Modules - module_name * - Package - package_name * - Class - ClassName * - Method - method_name * - Function - function_name * - Exception - ExceptionName * - Variable - var_name * - Parameter - parameter_name * - Constant - CONSTANT_NAME These names should be descriptive and unambiguous to avoid confusion within the code and other modules in the future. Example: .. code:: Python d = 6 # Day of the week == Saturday if d < 5: print("Weekday") Here the code could be hard to follow since the name ``d`` is not descriptive and requires extra comments to explain the code, which would have been solved otherwise by good naming conventions. Example: .. code:: Python fldln = 5 # field length This could be improved by using the descriptive variable ``field_len``. Things to avoid: - Single character names except for the following special cases: - counters or iterators (``i``, ``j``); - `e` as an exception identifier (``except Exception as e``); - `f` as a file in ``with`` statements; - mathematical notation in which a reference to the paper or algorithm with said notation is preferred if not clear from the intended purpose. - Leading underscores unless the component is meant to be protected or private: - protected: Use a single leading underscore, ``_``, for protected access; and - pseudo-private: Use double leading underscores, ``__``, for pseudo-private access via name mangling.
Displaying and Printing Strings ------------------------------- We follow the `Google string conventions `_. Notably, prefer to use Python f-strings, rather than `.format` or `%` syntax. For example: .. code:: Python state = "active" print("The state is %s" % state) # Not preferred print(f"The state is {state}") # Preferred Imports ------- We follow the `Google import conventions `_. The use of ``import`` statements should be reserved for packages and modules only, i.e. individual classes and functions should not be imported. The only exception to this is the typing module. - Use ``import x`` for importing packages and modules, where x is the package or module name. - Use ``from x import y`` where x is the package name and y is the module name. - Use ``from x import y as z`` if two modules named ``y`` are imported or if ``y`` is too long of a name. - Use ``import y as z`` when ``z`` is a standard abbreviation like ``import numpy as np``. Variables --------- We follow the `Google variable typing conventions `_ which state that there are a few extra documentation and coding practices that can be applied to variables such as: - One may annotate the type of a variable by placing ``: type`` after its name, before the value is assigned, e.g., .. code-block:: python a: Foo = SomeDecoratedFunction() - Avoid global variables. - A function can refer to variables defined in enclosing functions but cannot assign to them. Parameters ---------- There are three important style components for parameters inspired by the `NumPy parameter conventions `_: 1. Typing We use type annotations, meaning that we specify the types of the inputs and outputs of any method. From the ``typing`` module we can use more types such as ``Optional``, ``Union``, and ``Any``. For example, .. code-block:: python def foo(a: str) -> str: """Takes an input of type string and returns a value of type string""" ... 2. 
Default Values Parameters should include ``parameter_name = value`` where value is the default for that particular parameter. If the parameter has a type then the format is ``parameter_name: Type = value``. When documenting parameters, if a parameter can only assume one of a fixed set of values, those values can be listed in braces, with the default appearing first. For example, .. code-block:: python """ letters: {'A', 'B', 'C'} Description of `letters`. """ 3. NoneType In Python, ``NoneType`` is a first-class type, meaning the type itself can be passed into and returned from functions. ``None`` is the most commonly used alias for ``NoneType``. If any of the parameters of a function can be ``None`` then it has to be declared. ``Optional[T]`` is preferred over ``Union[T, None]``. For example, .. code-block:: python def foo(a: Optional[str], b: Optional[Union[str, int]]) -> str: ... For documentation purposes, ``NoneType`` or ``None`` should be written with double backticks. Docstrings ---------- A docstring is a way to document code within Python; it is the first statement within a package, module, class, or function. To generate a document with all the documentation for the code use `pydoc `_. Typing ~~~~~~ We follow the `NumPy parameter conventions `_. The following are docstring-specific usages: - Always enclose variables in single backticks. - For the parameter types, be as precise as possible, do not use backticks. Modules ~~~~~~~ We follow the `Google module conventions `_. Notably, files must start with a docstring that describes the functionality of the module. For example, .. code-block:: python """A one-line summary of the module must be terminated by a period. Leave a blank line and describe the module or program. Optionally describe exported classes, functions, and/or usage examples. Usage Example: foo = ClassFoo() bar = foo.FunctionBar() """ Functions ~~~~~~~~~ The word *function* encompasses functions, methods, or generators in this section. 
The docstring should give enough information to make calls to the function without needing to read the function's code. We follow the `Google function conventions `_. Notably, functions should contain docstrings unless: - not externally visible (the function name is prefaced with an underscore) or - very short. The docstring should be imperative-style ``"""Fetch rows from a Table"""`` instead of the descriptive-style ``"""Fetches rows from a Table"""``. If the method overrides a method from a base class then it may use a simple docstring referencing that base class such as ``"""See base class"""``, unless the behavior is different from the overridden method or there are extra details that need to be documented. | There are three sections to function docstrings: - Args: - List each parameter by name, and include a description for each parameter. - Returns: (or Yields: in the case of generators) - Describe the type of the return value. If a function only returns ``None`` then this section is not required. - Raises: - List all exceptions followed by a description. The name and description should be separated by a colon followed by a space. Example: .. code-block:: python def fetch_smalltable_rows(table_handle: smalltable.Table, keys: Sequence[Union[bytes, str]], require_all_keys: bool = False, ) -> Mapping[bytes, Tuple[str]]: """Fetch rows from a Smalltable. Retrieve rows pertaining to the given keys from the Table instance represented by table_handle. String keys will be UTF-8 encoded. Args: table_handle: An open smalltable.Table instance. keys: A sequence of strings representing the key of each table row to fetch. String `keys` will be UTF-8 encoded. require_all_keys: Optional If `require_all_keys` is ``True`` only rows with values set for all keys will be returned. Returns: A dict mapping keys to the corresponding table row data fetched. Each row is represented as a tuple of strings. 
For example: {b'Serak': ('Rigel VII', 'Preparer'), b'Zim': ('Irk', 'Invader'), b'Lrrr': ('Omicron Persei 8', 'Emperor')} Returned keys are always bytes. If a key from the keys argument is missing from the dictionary, then that row was not found in the table (and require_all_keys must have been False). Raises: IOError: An error occurred accessing the smalltable. """ Classes ~~~~~~~ We follow the `Google class conventions `_. Classes, like functions, should have a docstring below the definition describing the class and the class functionality. If the class contains public attributes, the class should have an attributes section where each attribute is listed by name and followed by a description, separated by a colon, like for function parameters. For example, | Example: .. code:: Python class foo: """One-liner describing the class. Additional information or description for the class. Can be multi-line Attributes: attr1: First attribute of the class. attr2: Second attribute of the class. """ def __init__(self): """Should have a docstring of type function.""" pass def method(self): """Should have a docstring of type: function.""" pass Extra Sections ~~~~~~~~~~~~~~ We follow the `NumPy style guide `_. Notably, the following are sections that can be added to functions, modules, classes, or method definitions. - See Also: - Refers to related code. Used to direct users to other modules, functions, or classes that they may not be aware of. - When referring to functions in the same sub-module, no prefix is needed. Example: For ``numpy.mean`` inside the same sub-module: .. code-block:: python """ See Also -------- average: Weighted average. """ - For a reference to ``fft`` in another module: .. code-block:: python """ See Also -------- fft.fft2: 2-D fast discrete Fourier transform. """ - Notes - Provide additional information about the code. May include mathematical equations in LaTeX format. For example, .. 
code-block:: python """ Notes ----- The FFT is a fast implementation of the discrete Fourier transform: .. math:: X(e^{j\omega } ) = \sum_{n} x(n)e^{ - j\omega n} """ Math can also be used inline: .. code-block:: python """ Notes ----- The value of :math:`\omega` is larger than 5. """ For a list of available LaTeX macros, search for "macros" in `docs/source/conf.py `_. - Examples: - Uses the doctest format and is meant to showcase usage. - If there are multiple examples, include blank lines before and after each example. For example, .. code-block:: python """ Examples -------- Necessary imports >>> import numpy as np Comment explaining example 1. >>> int(np.add(1, 2)) 3 Comment explaining a new example. >>> np.add([1, 2], [3, 4]) array([4, 6]) If the example is too long, start each line after the first with ``...`` >>> np.add([[1, 2], [3, 4]], ... [[5, 6], [7, 8]]) array([[ 6, 8], [10, 12]]) """ Comments ~~~~~~~~ There are two types of comments: *block* and *inline*. A good rule of thumb to follow for when to include a comment in your code is *if the code needs explaining or is too hard to figure out at first glance, then comment it*. An example of this, taken from the `Google comment conventions `_, is a complicated operation, which most likely requires a block of comments beforehand. .. code-block:: Python # We use a block comment because the following code performs a # difficult operation. Here we can explain the variables or # what the concept of the operation does in an easier # to understand way. if i & (i-1) == 0: # true if i is 0 or a power of 2 [explains the concept not the code] If a comment consists of one or more full sentences (as is typically the case for *block* comments), it should start with an upper case letter and end with a period. *Inline* comments often consist of a brief phrase which is not a full sentence, in which case they should have a lower case initial letter and not have a terminating period. 
Markup ~~~~~~ The following components require the recommended markup taken from the `NumPy Conventions `__: - Paragraphs: Indentation is significant and indicates the indentation of the output. New paragraphs are marked with a blank line. - Variable, parameter, module, function, method, and class names: Should be written between single back-ticks (e.g. \`x\`, rendered as `x`), but note that use of `Sphinx cross-reference syntax `_ is preferred for modules (`:mod:\`module-name\`` ), functions (`:func:\`function-name\`` ), methods (`:meth:\`method-name\`` ) and classes (`:class:\`class-name\`` ). - None, NoneType, True, and False: Should be written between double back-ticks (e.g. \`\`None\`\`, \`\`True\`\`, rendered as ``None``, ``True``). - Types: Should be written between double back-ticks (e.g. \`\`int\`\`, rendered as ``int``). NumPy dtypes, however, should be written using cross-reference syntax, e.g. \:attr\:\`~numpy.float32\` for :attr:`~numpy.float32`. Other components can use \*italics\*, \*\*bold\*\*, and \`\`monospace\`\` (respectively rendered as *italics*, **bold**, and ``monospace``) if needed, but not for variable names, doctest code, or multi-line code. Documentation ------------- Documentation that is separate from code (like this page) should follow the `IEEE Style Manual `_. For additional grammar and usage guidance, refer to `The Chicago Manual of Style `_. A few notable guidelines: * Equations which conclude a sentence should end with a period, e.g., "Poisson's equation is .. math:: \Delta \varphi = f \;." * Do not capitalize acronyms or initialisms when defining them, e.g., "computer-aided system engineering (CASE)," "fast Fourier transform (FFT)." * Avoid capitalization in text except where absolutely necessary, e.g., "Newton’s first law." * Use a single space after the period at the end of a sentence. 
The source code (`.rst` files) for these pages does not have a hard line-length guideline, but line breaks at or before 79 characters are encouraged. ================================================ FILE: docs/source/team.rst ================================================ Developers ========== Core Developers --------------- - `Cristina Garcia Cardona `_ - `Michael McCann `_ - `Brendt Wohlberg `_ Emeritus Developers ------------------- - `Thilo Balke `_ - `Fernando Davis `_ - `Soumendu Majee `_ - `Luke Pfister `_ Contributors ------------ - `Weijie Gan `_ (Non-blind variant of DnCNN) - `Oleg Korobkin `_ (BlockArray improvements) - `Andrew Leong `_ (Improvements to optics module documentation) - `Saurav Maheshkar `_ (Improvements to pre-commit configuration) - `Yanpeng Yuan `_ (ASTRA interface improvements) - `Li-Ta (Ollie) Lo `_ (ASTRA interface improvements) - `Renat Sibgatulin `_ (Docs corrections) - `Salman Naqvi `_ (Contributions to approximate TV norm prox and proximal average implementation) - `Eddie Chandler `_ (Contributions to approximate isotropic TV norm prox) ================================================ FILE: docs/source/zreferences.rst ================================================ References ========== .. 
bibliography:: references.bib :style: plain ================================================ FILE: docs/tikxfigures/img_align.tex ================================================ \documentclass[tikz]{standalone} \usetikzlibrary{calc,angles,quotes} \begin{document} \begin{tikzpicture}[scale=2] \footnotesize % Define rectangle dimensions \def\width{2} % base width \def\aspect{1.25} % aspect ratio (height/width) \pgfmathsetmacro{\height}{\width*\aspect} % Rotate rectangle by 20 degrees \begin{scope}[rotate around={-20:(0,0)}] % Draw rectangle with bottom-left corner at origin \draw[thick] (0,0) -| (\width,\height) node[pos=0.25,below] {$N_1$} node[pos=0.75,right] {$N_0$} -| (0,0); % Save post-rotation rectangle corners \coordinate (BL) at (0,0); \coordinate (BR) at (\width,0); \coordinate (TL) at (0,\height); \coordinate (TR) at (\width,\height); \end{scope} \def\liney{2.5} % top line height \coordinate (PL) at (BL |- 0,\liney); % vertical intersection from bottom-left \coordinate (PR) at (TR |- 0,\liney); % vertical intersection from top-right % Horizontal line representing sensor \draw[blue,thick] (PL) -- (PR); % Draw verticals to meet horizontal line \draw[blue,thick,dashed] (BL) -- (PL); \draw[blue,thick,dashed] (TR) -- (PR); % Double-sided arrow for width label \draw[<->,blue,dashed] (PL) ++(0,0.15) -- ($(PR)+(0,0.15)$) node[midway,above] {$w_0 + w_1$}; % Central vertical line through top-left corner \draw[blue,thick,dashed] (TL |- BL) -- (TL |- 0,\liney); % Horizontal lines with labels \draw[blue,thick,dashed] (TR) -- (TL |- TR) node[midway,below] {$w_1 = N_1 \cos(\theta)$}; \draw[blue,thick,dashed] (BL) -- (TL |- BL) node[right,below] {$\qquad w_0 = N_0 \sin(\theta)$}; % Define intersection point with central vertical line \coordinate (VL) at (TL |- BR); \coordinate (HL) at (TL |- TR); % θ between left rectangle side and central vertical line \pic [draw, ->, "$\theta$", angle radius=30] {angle = BL--TL--VL}; % 90-θ between top rectangle side and horizontal line 
\pic [draw, ->, "$90\!-\!\theta\quad\;\;$", angle radius=50] {angle = TL--TR--HL}; \end{tikzpicture} \end{document} ================================================ FILE: docs/tikxfigures/makesvg.sh ================================================ #! /bin/bash pdf2svg vol_align_xyz.pdf vol_align_xyz.svg pdf2svg vol_align_xz.pdf vol_align_xz.svg pdf2svg vol_align_yz.pdf vol_align_yz.svg pdf2svg img_align.pdf img_align.svg ================================================ FILE: docs/tikxfigures/vol_align_xyz.tex ================================================ \documentclass{standalone} \usepackage{tikz, tikz-3dplot} \begin{document} \tdplotsetmaincoords{70}{110} \begin{tikzpicture}[scale=5,tdplot_main_coords] \footnotesize \draw[thick,->] (0,0,0) -- (1,0,0) node[anchor=north east]{$x$}; \draw[thick,->] (0,0,0) -- (0,1,0) node[anchor=north west]{$y$}; \draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$}; \coordinate (O) at (0,0,0); \tdplotsetcoord{P}{1}{30}{40} \draw[-stealth,thick,color=red] (O) -- (P); \node[draw=none,color=red] at (0.0,0.11,0.77) {$(x, y, z)$}; \draw[dashed, color=blue] (P) -- (Pxz); \draw[dashed, color=blue] (P) -- (Pyz); \draw[dashed, color=blue] (O) -- (Pxz); \draw[dashed, color=blue] (O) -- (Pyz); \draw[dashed, color=blue] (Pz) -- (Pxz); \draw[dashed, color=blue] (Pz) -- (Pyz); \node[draw=none,color=blue] at (0.38,0.0,0.5) {$r_x$}; \node[draw=none,color=blue] at (0.0,0.24,0.4) {$r_y$}; \tdplotsetthetaplanecoords{0} \tdplotdrawarc[tdplot_rotated_coords,blue,dotted]{(O)}{.25}{22.7}{90}{anchor=mid east}{$\theta_x$} \tdplotsetthetaplanecoords{90} \tdplotdrawarc[tdplot_rotated_coords,blue,dotted]{(O)}{.25}{23}{90}{anchor=mid west}{$\theta_y$} \end{tikzpicture} \end{document} ================================================ FILE: docs/tikxfigures/vol_align_xz.tex ================================================ \documentclass{standalone} \usepackage{tikz, tikz-3dplot} \begin{document} \tdplotsetmaincoords{90}{0} 
\begin{tikzpicture}[scale=5,tdplot_main_coords] \footnotesize \draw[thick,->] (0,0,0) -- (1,0,0) node[anchor=west]{$x$}; \draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$}; \coordinate (O) at (0,0,0); \tdplotsetcoord{P}{1}{30}{40} \draw[-stealth,thick,color=red] (O) -- (P) node[anchor=west]{$\!(x, z)$}; \draw[dashed, color=blue] (P) -- (Pyz); \draw[dashed, color=blue] (P) -- (Px); \node[draw=none,color=blue] at (0.4,0.0,-0.055) {$r_x \cos (\theta_x)$}; \node[draw=none,rotate=90,color=blue] at (-0.055,0.0,0.84) {$r_x \sin (\theta_x)$}; \tdplotsetthetaplanecoords{0} \tdplotdrawarc[tdplot_rotated_coords,red,dotted]{(O)}{.25}{22.7}{90}{anchor=north east}{$\theta_x$} \node[draw=none,color=red] at (0.14,0.0,0.5) {$r_x$}; \end{tikzpicture} \end{document} ================================================ FILE: docs/tikxfigures/vol_align_yz.tex ================================================ \documentclass{standalone} \usepackage{tikz, tikz-3dplot} \begin{document} \tdplotsetmaincoords{90}{90} \begin{tikzpicture}[scale=5,tdplot_main_coords] \footnotesize \draw[thick,->] (0,0,0) -- (0,1,0) node[anchor=west]{$y$}; \draw[thick,->] (0,0,0) -- (0,0,1) node[anchor=south]{$z$}; \coordinate (O) at (0,0,0); \tdplotsetcoord{P}{1}{30}{40} \draw[-stealth,thick,color=red] (O) -- (P) node[anchor=west]{$\!(y, z)$}; \draw[dashed, color=blue] (P) -- (Pxz); \draw[dashed, color=blue] (P) -- (Py); \node[draw=none,color=blue] at (0.0,0.35,-0.055) {$r_y \cos (\theta_y)$}; \node[draw=none,rotate=90,color=blue] at (0.0,-0.055,0.84) {$r_y \sin (\theta_y)$}; \tdplotsetthetaplanecoords{90} \tdplotdrawarc[tdplot_rotated_coords,red,dotted]{(O)}{.25}{23}{90}{anchor=north east}{$\theta_y$} \node[draw=none,color=red] at (0.05,0.11,0.5) {$r_y$}; \end{tikzpicture} \end{document} ================================================ FILE: examples/README.rst ================================================ SCICO Usage Examples ==================== This directory contains usage examples for the SCICO 
package. The primary form of these examples is the Python scripts in the directory ``scripts``. A corresponding set of Jupyter notebooks, in the directory ``notebooks``, is auto-generated from these usage example scripts. Building Notebooks ------------------ The scripts for building Jupyter notebooks from the source example scripts are currently only supported under Linux. All scripts described below should be run from this directory, i.e. ``[repo root]/examples``. Running on a GPU ^^^^^^^^^^^^^^^^ Since some of the examples require a considerable amount of memory (``deconv_microscopy_tv_admm.py`` and ``deconv_microscopy_allchn_tv_admm.py`` in particular), it is recommended to set the following environment variables prior to building the notebooks: :: export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_PYTHON_CLIENT_PREALLOCATE=false Running on a CPU ^^^^^^^^^^^^^^^^ If a GPU is not available, or if the available GPU does not have sufficient memory to build the notebooks, set the environment variable :: JAX_PLATFORM_NAME=cpu to run on the CPU instead. Building Specific Examples -------------------------- To build or rebuild notebooks for specific examples, the example script names can be specified on the command line, e.g. :: python makenotebooks.py ct_astra_pcg.py ct_astra_tv_admm.py When rebuilding notebooks for examples that themselves make use of ``ray`` for parallelization (e.g. ``deconv_microscopy_allchn_tv_admm.py``), it is recommended to specify serial notebook execution, as in :: python makenotebooks.py --no-ray deconv_microscopy_allchn_tv_admm.py Building All Examples --------------------- By default, ``makenotebooks.py`` only rebuilds notebooks that are out of date with respect to their corresponding example scripts, as determined by their respective file timestamps. However, timestamps for files retrieved from version control may not be meaningful for this purpose. 
To rebuild all examples, the following commands (assuming that GPUs are available) are recommended: :: export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_PYTHON_CLIENT_PREALLOCATE=false touch scripts/*.py python makenotebooks.py --no-ray deconv_microscopy_tv_admm.py deconv_microscopy_allchn_tv_admm.py python makenotebooks.py Updating Notebooks in the Repo ------------------------------ The recommended procedure for rebuilding notebooks for inclusion in the ``data`` submodule is: 1. Add and commit the modified script(s). 2. Rebuild the notebooks as described above. 3. Add and commit the updated notebooks following the submodule handling procedure described in the developer docs. Adding a New Notebook --------------------- The procedure for adding a new notebook is: 1. Add an entry for the source file in ``scripts/index.rst``. Note that a script that is not listed in this index will not be converted into a notebook. 2. Run ``makeindex.py`` to update the example scripts README file, the notebook index file, and the examples index in the docs. 3. Build the corresponding notebook following the instructions above. 4. Add and commit the new script, the ``scripts/index.rst`` script index file, the auto-generated ``scripts/README.rst`` file and ``docs/source/examples.rst`` index file, and the new or updated notebooks and the auto-generated ``notebooks/index.ipynb`` file in the notebooks directory, following the submodule handling procedure as described in the developer docs. Management Utilities -------------------- A number of files in this directory assist in the management of the usage examples: `examples_requirements.txt `_ Requirements file (as used by ``pip``) listing additional dependencies for running the usage example scripts. `notebooks_requirements.txt `_ Requirements file (as used by ``pip``) listing additional dependencies for building the Jupyter notebooks from the usage example scripts. 
`makenotebooks.py `_ Auto-generate Jupyter notebooks from the example scripts. `updatejnbmd.py `_ Update markdown cells in notebooks from corresponding example scripts. `makeindex.py `_ Auto-generate the docs example index ``docs/source/examples.rst`` from the example scripts index ``scripts/index.rst``. `scriptcheck.sh `_ Run all example scripts with smaller problems and a reduced number of iterations as a rapid check that they are functioning correctly. ================================================ FILE: examples/examples_requirements.txt ================================================ -r ../requirements.txt colorama colour_demosaicing svmbir>=0.4.0 astra-toolbox xdesign>=0.5.5 ray[tune,train]>=2.44 hyperopt setuptools<82.0.0 # workaround for hyperopt 0.2.7 pydantic orbax-checkpoint>=0.5.0 bm3d>=4.0.0 bm4d>=4.2.2 ================================================ FILE: examples/jnb.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Support functions for manipulating Jupyter notebooks.""" import re from timeit import default_timer as timer import nbformat from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor from py2jn.tools import py_string_to_notebook, write_notebook def py_file_to_string(src): """Preprocess example script file and return result as a string.""" with open(src, "r") as srcfile: # Drop header comment for line in srcfile: if line[0] != "#": break # assume first non-comment line is a newline that can be dropped # Insert notebook plot config after last import lines = [] import_seen = False for line in srcfile: line = re.sub('^r"""', '"""', line) # remove r from r""" line = re.sub(":cite:`([^`]+)`", r'', line) # fix cite format if import_seen: # Once an import statement has been seen, break on encountering a line that # is neither an import statement nor a newline, nor a component of an import # statement extended over multiple lines, nor an os.environ statement, nor a # ray.init statement, nor components of a try/except construction (note that # handling of these final two cases is probably not very robust). 
if not re.match( r"(^import|^from|^\n$|^\W+[^\W]|^\)$|^os.environ|^ray.init|^try:$|^except)", line, ): lines.append(line) break else: # Set flag indicating that an import statement has been seen once one has # been encountered if re.match("^import|^from .* import", line): import_seen = True lines.append(line) if "plot" in "".join(lines): # Backtrack through list of lines to find last import statement n = 1 for line in lines[-2::-1]: if re.match("^(import|from)", line): break else: n += 1 # Insert notebook plotting config directly after last import statement lines.insert(-n, "plot.config_notebook_plotting()\n") # Process remainder of source file for line in srcfile: if re.match(r"^input\(", line): # end processing when input statement encountered break line = re.sub('^r"""', '"""', line) # remove r from r""" line = re.sub(r":cite:\`([^`]+)\`", r'', line) # fix cite format lines.append(line) # Backtrack through list of lines to remove trailing newlines n = 0 for line in lines[::-1]: if re.match("^\n$", line): n += 1 else: break if n > 0: lines = lines[0:-n] return "".join(lines) def script_to_notebook(src, dst): """Convert a Python example script into a Jupyter notebook.""" s = py_file_to_string(src) nb = py_string_to_notebook(s) write_notebook(nb, dst) def read_notebook(fname): """Read a notebook from the specified notebook file.""" try: nb = nbformat.read(fname, as_version=4) except (AttributeError, nbformat.reader.NotJSONError): raise RuntimeError("Error reading notebook file %s." 
% fname) return nb def execute_notebook(fname): """Execute the specified notebook file.""" with open(fname) as f: nb = nbformat.read(f, as_version=4) ep = ExecutePreprocessor(timeout=None) try: t0 = timer() out = ep.preprocess(nb) t1 = timer() with open(fname, "w", encoding="utf-8") as f: nbformat.write(nb, f) except CellExecutionError: print(f"ERROR executing {fname}") return False print(f"{fname} done in {(t1 - t0):.1e} s") return True def notebook_executed(nbfn): """Determine whether the notebook at `nbfn` has been executed.""" try: nb = nbformat.read(nbfn, as_version=4) except (AttributeError, nbformat.reader.NotJSONError): raise RuntimeError("Error reading notebook file %s." % pth) cells = nb["worksheets"][0]["cells"] for n in range(len(nb["cells"])): if cells[n].cell_type == "code" and cells[n].execution_count is None: return False return True def same_notebook_code(nb1, nb2): """Return ``True`` if the code cells of notebook objects `nb1` and `nb2` are all the same. """ if "cells" in nb1: nb1c = nb1["cells"] else: nb1c = nb1["worksheets"][0]["cells"] if "cells" in nb2: nb2c = nb2["cells"] else: nb2c = nb2["worksheets"][0]["cells"] # Notebooks do not match if the number of cells differ if len(nb1c) != len(nb2c): return False # Iterate over cells in nb1 for n in range(len(nb1c)): # Notebooks do not match if corresponding cells have different # types if nb1c[n]["cell_type"] != nb2c[n]["cell_type"]: return False # Notebooks do not match if source of corresponding code cells # differ if nb1c[n]["cell_type"] == "code" and nb1c[n]["source"] != nb2c[n]["source"]: return False return True def same_notebook_markdown(nb1, nb2): """Return ``True`` if the markdown cells of notebook objects `nb1` and `nb2` are all the same. 
""" if "cells" in nb1: nb1c = nb1["cells"] else: nb1c = nb1["worksheets"][0]["cells"] if "cells" in nb2: nb2c = nb2["cells"] else: nb2c = nb2["worksheets"][0]["cells"] # Notebooks do not match if the number of cells differ if len(nb1c) != len(nb2c): return False # Iterate over cells in nb1 for n in range(len(nb1c)): # Notebooks do not match if corresponding cells have different # types if nb1c[n]["cell_type"] != nb2c[n]["cell_type"]: return False # Notebooks do not match if source of corresponding code cells # differ if nb1c[n]["cell_type"] == "markdown" and nb1c[n]["source"] != nb2c[n]["source"]: return False return True def replace_markdown_cells(src, dst): """Overwrite markdown cells in notebook object `dst` with corresponding cells in notebook object `src`. """ if "cells" in src: srccell = src["cells"] else: srccell = src["worksheets"][0]["cells"] if "cells" in dst: dstcell = dst["cells"] else: dstcell = dst["worksheets"][0]["cells"] # It is an error to attempt markdown replacement if src and dst # have different numbers of cells if len(srccell) != len(dstcell): raise ValueError("Notebooks do not have the same number of cells.") # Iterate over cells in src for n in range(len(srccell)): # It is an error to attempt markdown replacement if any # corresponding pair of cells have different type if srccell[n]["cell_type"] != dstcell[n]["cell_type"]: raise ValueError("Cell number %d of different type in src and dst.") # If current src cell is a markdown cell, copy the src cell to # the dst cell if srccell[n]["cell_type"] == "markdown": dstcell[n]["source"] = srccell[n]["source"] def remove_error_output(src): """Remove output to stderr from all cells in `src`.""" if "cells" in src: cells = src["cells"] else: cells = src["worksheets"][0]["cells"] modified = False for c in cells: if "outputs" in c: dellist = [] for n, out in enumerate(c["outputs"]): if "name" in out and out["name"] == "stderr": dellist.append(n) modified = True for n in dellist[::-1]: del c["outputs"][n] 
return modified ================================================ FILE: examples/makeindex.py ================================================ #!/usr/bin/env python # Construct an index README file and a docs example index file from # source index file "scripts/index.rst". # Run as # python makeindex.py import re from pathlib import Path import nbformat as nbf import py2jn import pypandoc src = "scripts/index.rst" # Make dict mapping script names to docstring header titles titles = {} scripts = list(Path("scripts").glob("*py")) for s in scripts: prevline = None with open(s, "r") as sfile: for line in sfile: if line[0:3] == "===": titles[s.name] = prevline.rstrip() break else: prevline = line # Build README in scripts directory dst = "scripts/README.rst" with open(dst, "w") as dstfile: with open(src, "r") as srcfile: for line in srcfile: # Detect lines containing script filenames m = re.match(r"(\s+)- ([^\s]+.py)", line) if m: prespace = m.group(1) name = m.group(2) title = titles[name] print( "%s`%s <%s>`_\n%s %s" % (prespace, name, name, prespace, title), file=dstfile ) else: print(line, end="", file=dstfile) # Build notebooks index file in notebooks directory dst = "notebooks/index.ipynb" rst_text = "" with open(src, "r") as srcfile: for line in srcfile: # Detect lines containing script filenames m = re.match(r"(\s+)- ([^\s]+).py", line) if m: prespace = m.group(1) name = m.group(2) title = titles[name + ".py"] rst_text += "%s- `%s <%s.ipynb>`_\n" % (prespace, title, name) else: rst_text += line # Convert text from rst to markdown md_format = "markdown_github+tex_math_dollars+fenced_code_attributes" md_text = pypandoc.convert_text(rst_text, md_format, format="rst", extra_args=["--atx-headers"]) md_text = '"""' + md_text + '"""' # Convert from python to notebook format and write notebook nb = py2jn.py_string_to_notebook(md_text) py2jn.tools.write_notebook(nb, dst, nbver=4) nb = nbf.read(dst, nbf.NO_CONVERT) nb.metadata = {"nbsphinx": {"orphan": True}} nbf.write(nb, 
dst) # Build examples index for docs dst = "../docs/source/examples.rst" prfx = "examples/" with open(dst, "w") as dstfile: print(".. _example_notebooks:\n", file=dstfile) with open(src, "r") as srcfile: for line in srcfile: # Add toctree and include statements after main heading if line[0:3] == "===": print(line, end="", file=dstfile) print("\n.. toctree::\n :maxdepth: 1", file=dstfile) print("\n.. include:: include/examplenotes.rst", file=dstfile) continue # Detect lines containing script filenames m = re.match(r"(\s+)- ([^\s]+).py", line) if m: print(" " + prfx + m.group(2), file=dstfile) else: print(line, end="", file=dstfile) # Add toctree statement after section headings if line[0:3] == line[0] * 3 and line[0] in ["=", "-", "^"]: print("\n.. toctree::\n :maxdepth: 1", file=dstfile) ================================================ FILE: examples/makenotebooks.py ================================================ #!/usr/bin/env python # Extract a list of Python scripts from "scripts/index.rst" and # create/update and execute any Jupyter notebooks that are out # of date with respect to their source Python scripts. If script # names specified on command line, process them instead. # Run # python makenotebooks.py -h # for usage details. 
import argparse import os import re import signal import sys from pathlib import Path import psutil from jnb import execute_notebook, script_to_notebook examples_dir = Path(__file__).resolve().parent # absolute path to ../scico/examples/ have_ray = True try: import ray except ImportError: have_ray = False def script_uses_ray(fname): """Determine whether a script uses ray.""" with open(fname, "r") as f: text = f.read() return bool(re.search("^import ray", text, re.MULTILINE)) or bool( re.search("^import scico.ray", text, re.MULTILINE) ) def script_path(sname): """Get script path from script name.""" return examples_dir / "scripts" / Path(sname) def notebook_path(sname): """Get notebook path from script path.""" return examples_dir / "notebooks" / Path(Path(sname).stem + ".ipynb") argparser = argparse.ArgumentParser( description="Convert Python example scripts to Jupyter notebooks." ) argparser.add_argument( "--all", action="store_true", help="Process all notebooks, without checking timestamps. " "Has no effect when files to process are explicitly specified.", ) argparser.add_argument( "--no-exec", action="store_true", help="Create/update notebooks but don't execute them." 
) argparser.add_argument( "--no-ray", action="store_true", help="Execute notebooks serially, without the use of ray parallelization.", ) argparser.add_argument( "--verbose", action="store_true", help="Verbose operation.", ) argparser.add_argument( "--test", action="store_true", help="Show actions that would be taken but don't do anything.", ) argparser.add_argument("filename", nargs="*", help="Optional Python example script filenames") args = argparser.parse_args() # Raise error if ray needed but not present if not have_ray and not args.no_ray: raise RuntimeError("The ray package is required to run this script, try --no-ray") if args.filename: # Script names specified on command line scriptnames = [os.path.basename(s) for s in args.filename] else: # Read script names from index file scriptnames = [] srcidx = examples_dir / "scripts" / "index.rst" with open(srcidx, "r") as idxfile: for line in idxfile: m = re.match(r"(\s+)- ([^\s]+.py)", line) if m: scriptnames.append(m.group(2)) # Ensure list entries are unique scriptnames = sorted(list(set(scriptnames))) # Create list of selected scripts. scripts = [] for s in scriptnames: spath = script_path(s) npath = notebook_path(s) # If scripts specified on command line or --all flag specified, convert all scripts. # Otherwise, only convert scripts that have a newer timestamp than their corresponding # notebooks, or that have not previously been converted (i.e. corresponding notebook # file does not exist). 
if ( args.all or args.filename or not npath.is_file() or spath.stat().st_mtime > npath.stat().st_mtime ): # Add to the list of selected scripts scripts.append(s) if not scripts: if args.verbose: print("No scripts require conversion") sys.exit(0) # Display status information if args.verbose: print(f"Processing scripts {', '.join(scripts)}") # Convert selected scripts to corresponding notebooks and determine which can be run in parallel serial_scripts = [] parallel_scripts = [] for s in scripts: spath = script_path(s) npath = notebook_path(s) # Determine how script should be executed if script_uses_ray(spath): serial_scripts.append(s) else: parallel_scripts.append(s) # Make notebook file if args.verbose or args.test: print(f"Converting script {s} to notebook") if not args.test: script_to_notebook(spath, npath) if args.no_exec: if args.verbose: print("Notebooks will not be executed") sys.exit(0) # If ray disabled or not worth using, run all serially if args.no_ray or len(parallel_scripts) < 2: serial_scripts.extend(parallel_scripts) parallel_scripts = [] # Execute notebooks corresponding to serial_scripts for s in serial_scripts: npath = notebook_path(s) if args.verbose or args.test: print(f"Executing notebook corresponding to script {s}") if not args.test: execute_notebook(npath) # Execute notebooks corresponding to parallel_scripts if parallel_scripts: if args.verbose or args.test: print( f"Notebooks corresponding to scripts {', '.join(parallel_scripts)} will " "be executed in parallel" ) # Execute notebooks in parallel using ray nproc = len(parallel_scripts) ray.init() ngpu = 0 ar = ray.available_resources() ncpu = max(int(ar["CPU"]) // nproc, 1) if "GPU" in ar: ngpu = max(int(ar["GPU"]) // nproc, 1) if args.verbose or args.test: print(f" Running on {ncpu} CPUs and {ngpu} GPUs per process") # Function to execute each notebook with available resources suitably divided @ray.remote(num_cpus=ncpu, num_gpus=ngpu) def ray_run_nb(fname): execute_notebook(fname) if not 
args.test: # Execute relevant notebooks in parallel try: notebooks = [notebook_path(s) for s in parallel_scripts] objrefs = [ray_run_nb.remote(nbfile) for nbfile in notebooks] ray.wait(objrefs, num_returns=len(objrefs)) except KeyboardInterrupt: print("\nTerminating on keyboard interrupt") for ref in objrefs: ray.cancel(ref, force=True) ray.shutdown() # Clean up sub-processes not ended by ray.cancel process = psutil.Process() children = process.children(recursive=True) for child in children: os.kill(child.pid, signal.SIGTERM) ================================================ FILE: examples/notebooks_requirements.txt ================================================ -r examples-requirements.txt ipykernel ipywidgets nbformat nbconvert nb_conda_kernels<=2.5.1 # 2.5.2 broken: see anaconda/nb_conda_kernels#280 psutil py2jn pypandoc ================================================ FILE: examples/removejnberr.py ================================================ #!/usr/bin/env python # Remove output to stderr in notebooks. NB: use with caution! # Run as # python removejnberr.py import glob import os from jnb import read_notebook, remove_error_output from py2jn.tools import write_notebook for src in glob.glob(os.path.join("notebooks", "*.ipynb")): nb = read_notebook(src) modflg = remove_error_output(nb) if modflg: print(f"Removing output to stderr from {src}") write_notebook(nb, src) ================================================ FILE: examples/scriptcheck.sh ================================================ #!/usr/bin/env bash # Basic test of example script functionality by running them all with # optimization algorithms configured to use only a small number of iterations. # Currently only supported under Linux. 
SCRIPT=$(basename $0) SCRIPTPATH=$(realpath $(dirname $0)) USAGE=$(cat <<-EOF Usage: $SCRIPT [-h] [-d] [-h] Display usage information [-e] Display excerpt of error message on failure [-d] Skip tests involving additional data downloads [-t] Skip tests related to learned model training [-g] Skip tests that need a GPU EOF ) OPTIND=1 DISPLAY_ERROR=0 SKIP_DOWNLOAD=0 SKIP_TRAINING=0 SKIP_GPU=0 while getopts ":hedtg" opt; do case $opt in h) echo "$USAGE"; exit 0;; e) DISPLAY_ERROR=1;; d) SKIP_DOWNLOAD=1;; t) SKIP_TRAINING=1;; g) SKIP_GPU=1;; \?) echo "Error: invalid option -$OPTARG" >&2 echo "$USAGE" >&2 exit 1 ;; esac done shift $((OPTIND-1)) if [ ! $# -eq 0 ] ; then echo "Error: no positional arguments" >&2 echo "$USAGE" >&2 exit 2 fi # Set environment variables and paths. This script is assumed to be run # from its root directory. export PYTHONPATH=$SCRIPTPATH/.. export PYTHONIOENCODING=utf-8 export MPLBACKEND=agg export PYTHONWARNINGS=ignore:Matplotlib:UserWarning d='/tmp/scriptcheck_'$$ mkdir -p $d retval=0 # On SIGINT clean up temporary script directory and exit. function cleanupexit { rm $d/*.py rmdir $d exit 2 } trap cleanupexit SIGINT # Define regex strings. re1="s/'maxiter' ?: ?[0-9]+/'maxiter': 2/g; " re2="s/^maxiter ?= ?[0-9]+/maxiter = 2/g; " re3="s/^N ?= ?[0-9]+/N = 32/g; " re4="s/num_samples= ?[0-9]+/num_samples = 2/g; " re5='s/\"cpu\": ?[0-9]+/\"cpu\": 1/g; ' re6="s/^downsampling_rate ?= ?[0-9]+/downsampling_rate = 12/g; " re7="s/input\(/#input\(/g; " re8="s/fig.show\(/#fig.show\(/g" # Iterate over all scripts. for f in $SCRIPTPATH/scripts/*.py; do printf "%-50s " $(basename $f) # Skip problem cases. 
if [ $SKIP_DOWNLOAD -eq 1 ] && grep -q '_microscopy' <<< $f; then printf "%s\n" skipped continue fi if [ $SKIP_TRAINING -eq 1 ]; then if grep -q '_datagen' <<< $f || grep -q '_train' <<< $f; then printf "%s\n" skipped continue fi fi if [ $SKIP_GPU -eq 1 ] && grep -q '_astra_3d' <<< $f; then printf "%s\n" skipped continue fi if [ $SKIP_GPU -eq 1 ] && grep -q 'ct_projector_comparison_3d' <<< $f; then printf "%s\n" skipped continue fi # Create temporary copy of script with all algorithm maxiter values set # to small number and final input statements commented out. g=$d/$(basename $f) sed -E -e "$re1$re2$re3$re4$re5$re6$re7$re8" $f > $g # Run temporary script and print status message. if output=$(timeout 180s python $g 2>&1); then printf "%s\n" succeeded else printf "%s\n" FAILED retval=1 if [ $DISPLAY_ERROR -eq 1 ]; then echo "$output" | tail -8 | sed -e 's/^/ /' fi fi # Remove temporary script. rm -f $g done # Remove temporary script directory. rmdir $d exit $retval ================================================ FILE: examples/scripts/README.rst ================================================ Usage Examples ============== Organized by Application ------------------------ Computed Tomography ^^^^^^^^^^^^^^^^^^^ `ct_abel_tv_admm.py `_ TV-Regularized Abel Inversion `ct_abel_tv_admm_tune.py `_ Parameter Tuning for TV-Regularized Abel Inversion `ct_symcone_tv_padmm.py `_ TV-Regularized Cone Beam CT for Symmetric Objects `ct_astra_noreg_pcg.py `_ CT Reconstruction with CG and PCG `ct_astra_3d_tv_admm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver) `ct_astra_3d_tv_padmm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) `ct_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Integrated Projector) `ct_astra_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector) `ct_multi_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) `ct_astra_weighted_tv_admm.py `_ TV-Regularized 
Low-Dose CT Reconstruction `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `ct_svmbir_ppp_bm3d_admm_cg.py `_ PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver) `ct_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) `ct_fan_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) Fan-Beam CT Reconstruction `ct_modl_train_foam2.py `_ CT Training and Reconstruction with MoDL `ct_odp_train_foam2.py `_ CT Training and Reconstruction with ODP `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `ct_projector_comparison_2d.py `_ 2D X-ray Transform Comparison `ct_projector_comparison_3d.py `_ 3D X-ray Transform Comparison Deconvolution ^^^^^^^^^^^^^ `deconv_circ_tv_admm.py `_ Circulant Blur Image Deconvolution with TV Regularization `deconv_tv_admm.py `_ Image Deconvolution with TV Regularization (ADMM Solver) `deconv_tv_padmm.py `_ Image Deconvolution with TV Regularization (Proximal ADMM Solver) `deconv_tv_admm_tune.py `_ Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver) `deconv_microscopy_tv_admm.py `_ Deconvolution Microscopy (Single Channel) `deconv_microscopy_allchn_tv_admm.py `_ Deconvolution Microscopy (All Channels) `deconv_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Deconvolution (ADMM Solver) `deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) `deconv_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Deconvolution (ADMM Solver) `deconv_ppp_dncnn_padmm.py `_ PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver) `deconv_ppp_bm4d_admm.py `_ PPP (with BM4D) Volume Deconvolution `deconv_modl_train_foam1.py `_ Deconvolution Training and Reconstructions with MoDL `deconv_odp_train_foam1.py `_ Deconvolution Training and Reconstructions with ODP Sparse Coding ^^^^^^^^^^^^^ `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) `sparsecode_nn_apgm.py `_ Non-Negative Basis Pursuit DeNoising (APGM) 
`sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) `sparsecode_conv_md_admm.py `_ Convolutional Sparse Coding with Mask Decoupling (ADMM) `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) Miscellaneous ^^^^^^^^^^^^^ `demosaic_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Demosaicing `superres_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Superresolution `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising `denoise_ptv_pdhg.py `_ Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_approx_tv_multi.py `_ Denoising with Approximate Total Variation Proximal Operator `denoise_cplx_tv_nlpadmm.py `_ Complex Total Variation Denoising with NLPADMM Solver `denoise_cplx_tv_pdhg.py `_ Complex Total Variation Denoising with PDHG Solver `denoise_dncnn_universal.py `_ Comparison of DnCNN Variants for Image Denoising `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction `video_rpca_admm.py `_ Video Decomposition via Robust PCA `ct_datagen_foam2.py `_ CT Data Generation for NN Training `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training `deconv_datagen_foam1.py `_ Blurred Data Generation (Foams) for NN Training `denoise_datagen_bsds.py `_ Noisy Data Generation for NN Training Organized by Regularization --------------------------- Plug and Play Priors ^^^^^^^^^^^^^^^^^^^^ `ct_svmbir_ppp_bm3d_admm_cg.py `_ PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver) `ct_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) `ct_fan_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) Fan-Beam CT Reconstruction `deconv_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Deconvolution (ADMM Solver) 
`deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) `deconv_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Deconvolution (ADMM Solver) `deconv_ppp_dncnn_padmm.py `_ PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver) `deconv_ppp_bm4d_admm.py `_ PPP (with BM4D) Volume Deconvolution `demosaic_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Demosaicing `superres_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Superresolution Total Variation ^^^^^^^^^^^^^^^ `ct_abel_tv_admm.py `_ TV-Regularized Abel Inversion `ct_abel_tv_admm_tune.py `_ Parameter Tuning for TV-Regularized Abel Inversion `ct_symcone_tv_padmm.py `_ TV-Regularized Cone Beam CT for Symmetric Objects `ct_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Integrated Projector) `ct_multi_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) `ct_astra_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector) `ct_astra_3d_tv_admm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver) `ct_astra_3d_tv_padmm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) `ct_astra_weighted_tv_admm.py `_ TV-Regularized Low-Dose CT Reconstruction `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `deconv_circ_tv_admm.py `_ Circulant Blur Image Deconvolution with TV Regularization `deconv_tv_admm.py `_ Image Deconvolution with TV Regularization (ADMM Solver) `deconv_tv_admm_tune.py `_ Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver) `deconv_tv_padmm.py `_ Image Deconvolution with TV Regularization (Proximal ADMM Solver) `deconv_microscopy_tv_admm.py `_ Deconvolution Microscopy (Single Channel) `deconv_microscopy_allchn_tv_admm.py `_ Deconvolution Microscopy (All Channels) `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising `denoise_ptv_pdhg.py `_ Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) 
`denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_approx_tv_multi.py `_ Denoising with Approximate Total Variation Proximal Operator `denoise_cplx_tv_nlpadmm.py `_ Complex Total Variation Denoising with NLPADMM Solver `denoise_cplx_tv_pdhg.py `_ Complex Total Variation Denoising with PDHG Solver `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction Sparsity ^^^^^^^^ `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) `sparsecode_nn_apgm.py `_ Non-Negative Basis Pursuit DeNoising (APGM) `sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) `sparsecode_conv_md_admm.py `_ Convolutional Sparse Coding with Mask Decoupling (ADMM) `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) `video_rpca_admm.py `_ Video Decomposition via Robust PCA Machine Learning ^^^^^^^^^^^^^^^^ `ct_datagen_foam2.py `_ CT Data Generation for NN Training `ct_modl_train_foam2.py `_ CT Training and Reconstruction with MoDL `ct_odp_train_foam2.py `_ CT Training and Reconstruction with ODP `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training `deconv_datagen_foam1.py `_ Blurred Data Generation (Foams) for NN Training `deconv_modl_train_foam1.py `_ Deconvolution Training and Reconstructions with MoDL `deconv_odp_train_foam1.py `_ Deconvolution Training and Reconstructions with ODP `denoise_datagen_bsds.py `_ Noisy Data Generation for NN Training `denoise_dncnn_train_bsds.py `_ Training of DnCNN for Denoising `denoise_dncnn_universal.py `_ Comparison of DnCNN Variants for Image Denoising Organized by Optimization Algorithm ----------------------------------- ADMM ^^^^ `ct_abel_tv_admm.py `_ 
TV-Regularized Abel Inversion `ct_abel_tv_admm_tune.py `_ Parameter Tuning for TV-Regularized Abel Inversion `ct_symcone_tv_padmm.py `_ TV-Regularized Cone Beam CT for Symmetric Objects `ct_astra_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector) `ct_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Integrated Projector) `ct_astra_3d_tv_admm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver) `ct_astra_weighted_tv_admm.py `_ TV-Regularized Low-Dose CT Reconstruction `ct_multi_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `ct_svmbir_ppp_bm3d_admm_cg.py `_ PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver) `ct_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) `ct_fan_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) Fan-Beam CT Reconstruction `deconv_circ_tv_admm.py `_ Circulant Blur Image Deconvolution with TV Regularization `deconv_tv_admm.py `_ Image Deconvolution with TV Regularization (ADMM Solver) `deconv_tv_admm_tune.py `_ Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver) `deconv_microscopy_tv_admm.py `_ Deconvolution Microscopy (Single Channel) `deconv_microscopy_allchn_tv_admm.py `_ Deconvolution Microscopy (All Channels) `deconv_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Deconvolution (ADMM Solver) `deconv_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Deconvolution (ADMM Solver) `deconv_ppp_bm4d_admm.py `_ PPP (with BM4D) Volume Deconvolution `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) `sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) `sparsecode_conv_md_admm.py `_ Convolutional Sparse Coding with Mask Decoupling (ADMM) `demosaic_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Demosaicing 
`superres_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Superresolution `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_approx_tv_multi.py `_ Denoising with Approximate Total Variation Proximal Operator `video_rpca_admm.py `_ Video Decomposition via Robust PCA Linearized ADMM ^^^^^^^^^^^^^^^ `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising Proximal ADMM ^^^^^^^^^^^^^ `ct_astra_3d_tv_padmm.py `_ 3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) `deconv_tv_padmm.py `_ Image Deconvolution with TV Regularization (Proximal ADMM Solver) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `deconv_ppp_dncnn_padmm.py `_ PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver) Non-linear Proximal ADMM ^^^^^^^^^^^^^^^^^^^^^^^^ `denoise_cplx_tv_nlpadmm.py `_ Complex Total Variation Denoising with NLPADMM Solver PDHG ^^^^ `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `denoise_ptv_pdhg.py `_ Polar Total Variation Denoising (PDHG) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_cplx_tv_pdhg.py `_ Complex Total Variation Denoising with PDHG Solver PGM ^^^ `deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) `sparsecode_nn_apgm.py `_ Non-Negative Basis Pursuit DeNoising (APGM) `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) `denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) `denoise_approx_tv_multi.py `_ Denoising with Approximate Total Variation Proximal Operator PCG ^^^ `ct_astra_noreg_pcg.py `_ CT Reconstruction with CG 
and PCG ================================================ FILE: examples/scripts/ct_abel_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE' file distributed # with the package. r""" TV-Regularized Abel Inversion ============================= This example demonstrates a total variation (TV) regularized Abel inversion by solving the problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_1 \;,$$ where $A$ is the Abel projector (with an implementation based on a projector from PyAbel :cite:`pyabel-2022`), $\mathbf{y}$ is the measured data, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the solution. """ import numpy as np import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_circular_phantom from scico.linop.xray.abel import AbelTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ N = 256 # image size x_gt = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) """ Set up the forward operator and create a test measurement. """ A = AbelTransform(x_gt.shape) y = A @ x_gt np.random.seed(12345) y = y + np.random.normal(size=y.shape).astype(np.float32) """ Compute inverse Abel transform solution. """ x_inv = A.inverse(y) """ Set up the problem to be solved. Anisotropic TV, which gives slightly better performance than isotropic TV for this problem, is used here. """ f = loss.SquaredL2Loss(y=y, A=A) λ = 2.35e1 # ℓ1 norm regularization parameter g = λ * functional.L1Norm() # Note the use of anisotropic TV C = linop.FiniteDifference(input_shape=x_gt.shape) """ Set up ADMM solver object.
""" ρ = 1.03e2 # ADMM penalty parameter maxiter = 100 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 25 # maximum CG iterations per ADMM iteration solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=snp.clip(x_inv, 0.0, 1.0), maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") solver.solve() x_tv = snp.clip(solver.x, 0.0, 1.0) """ Show results. """ norm = plot.matplotlib.colors.Normalize(vmin=-0.1, vmax=1.2) fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(12, 12)) plot.imview(x_gt, title="Ground Truth", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0], norm=norm) plot.imview(y, title="Measurement", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1]) plot.imview( x_inv, title="Inverse Abel: %.2f (dB)" % metric.psnr(x_gt, x_inv), cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0], norm=norm, ) plot.imview( x_tv, title="TV-Regularized Inversion: %.2f (dB)" % metric.psnr(x_gt, x_tv), cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1], norm=norm, ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_abel_tv_admm_tune.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE' file distributed # with the package. r""" Parameter Tuning for TV-Regularized Abel Inversion ================================================== This example demonstrates the use of [scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune parameters for the companion [example script](ct_abel_tv_admm.rst). The `ray.tune` class API is used in this example.
This script is hard-coded to run on CPU only to avoid the large number of warnings that are emitted when GPU resources are requested but not available, and due to the difficulty of suppressing these warnings in a way that does not force use of the CPU only. To enable GPU usage, comment out the `os.environ` statements near the beginning of the script, and change the value of the "gpu" entry in the `resources` dict from 0 to 1. Note that two environment variables are set to suppress the warnings because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but this change has yet to be correctly implemented (see [google/jax#6805](https://github.com/google/jax/issues/6805) and [google/jax#10272](https://github.com/google/jax/pull/10272)). """ # isort: off import os os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORMS"] = "cpu" import numpy as np import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_circular_phantom from scico.linop.xray.abel import AbelTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.ray import tune """ Create a ground truth image. """ N = 256 # image size x_gt = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) """ Set up the forward operator and create a test measurement. """ A = AbelTransform(x_gt.shape) y = A @ x_gt np.random.seed(12345) y = y + np.random.normal(size=y.shape).astype(np.float32) """ Compute inverse Abel transform solution for use as initial solution. """ x_inv = A.inverse(y) x0 = snp.clip(x_inv, 0.0, 1.0) """ Define performance evaluation class. """ class Trainable(tune.Trainable): """Parameter evaluation class.""" def setup(self, config, x_gt, x0, y): """This method initializes a new parameter evaluation object.
It is called once when a new parameter evaluation object is created. The `config` parameter is a dict of specific parameters for evaluation of a single parameter set (a pair of parameters in this case). The remaining parameters are objects that are passed to the evaluation function via the ray object store. """ # Get arrays passed by tune call (via tune.with_parameters) and convert them with snp.array. self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y) # Set up problem to be solved. self.A = AbelTransform(self.x_gt.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) self.C = linop.FiniteDifference(input_shape=self.x_gt.shape) self.reset_config(config) def reset_config(self, config): """This method is only required when `scico.ray.tune.Tuner` is initialized with `reuse_actors` set to ``True`` (the default). In this case, a set of parameter evaluation processes and corresponding objects are created once (including initialization via a call to the `setup` method), and this method is called when switching to evaluation of a different parameter configuration. If `reuse_actors` is set to ``False``, then a new process and object are created for each parameter configuration, and this method is not used. """ # Extract solver parameters from config dict. λ, ρ = config["lambda"], config["rho"] # Set up parameter-dependent functional. g = λ * functional.L1Norm() # Define solver. cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 25 # maximum CG iterations per ADMM iteration self.solver = ADMM( f=self.f, g_list=[g], C_list=[self.C], rho_list=[ρ], x0=self.x0, maxiter=10, subproblem_solver=LinearSubproblemSolver( cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter} ), ) return True def step(self): """This method is called for each step in the evaluation of a single parameter configuration. The maximum number of times it can be called is controlled by the `num_iterations` parameter in the initialization of a `scico.ray.tune.Tuner` object.
""" # Perform 10 solver steps (maxiter=10, set in reset_config) for every ray.tune step x_tv = snp.clip(self.solver.solve(), 0.0, 1.0) return {"psnr": float(metric.psnr(self.x_gt, x_tv))} """ Define parameter search space and resources per trial. """ config = {"lambda": tune.loguniform(1e0, 1e2), "rho": tune.loguniform(1e1, 1e3)} resources = {"gpu": 0, "cpu": 1} # gpus per trial, cpus per trial """ Run parameter search. """ tuner = tune.Tuner( tune.with_parameters(Trainable, x_gt=x_gt, x0=x0, y=y), param_space=config, resources=resources, metric="psnr", mode="max", num_samples=100, # perform 100 parameter evaluations num_iterations=10, # perform at most 10 steps for each parameter evaluation ) results = tuner.fit() ray.shutdown() """ Display best parameters and corresponding performance. """ best_result = results.get_best_result() best_config = best_result.config print(f"Best PSNR: {best_result.metrics['psnr']:.2f} dB") print("Best config: " + ", ".join([f"{k}: {v:.2e}" for k, v in best_config.items()])) """ Plot parameter values visited during parameter search. Marker sizes are proportional to number of iterations run at each parameter pair. The best point in the parameter space is indicated in red.
""" fig = plot.figure(figsize=(8, 8)) trials = results.get_dataframe() for t in trials.iloc: n = t["training_iteration"] plot.plot( t["config/lambda"], t["config/rho"], ptyp="loglog", lw=0, ms=(0.5 + 1.5 * n), marker="o", mfc="blue", mec="blue", fig=fig, ) plot.plot( best_config["lambda"], best_config["rho"], ptyp="loglog", title="Parameter search sampling locations\n(marker size proportional to number of iterations)", xlbl=r"$\rho$", ylbl=r"$\lambda$", lw=0, ms=5.0, marker="o", mfc="red", mec="red", fig=fig, ) ax = fig.axes[0] ax.set_xlim([config["rho"].lower, config["rho"].upper]) ax.set_ylim([config["lambda"].lower, config["lambda"].upper]) fig.show() """ Plot parameter values visited during parameter search and corresponding reconstruction PSNRs.The best point in the parameter space is indicated in red. """ 𝜌 = [t["config/rho"] for t in trials.iloc] 𝜆 = [t["config/lambda"] for t in trials.iloc] psnr = [t["psnr"] for t in trials.iloc] minpsnr = min(max(psnr), 20.0) 𝜌, 𝜆, psnr = zip(*filter(lambda x: x[2] >= minpsnr, zip(𝜌, 𝜆, psnr))) fig, ax = plot.subplots(figsize=(10, 8)) sc = ax.scatter(𝜌, 𝜆, c=psnr, cmap=plot.cm.plasma_r) fig.colorbar(sc) plot.plot( best_config["lambda"], best_config["rho"], ptyp="loglog", lw=0, ms=12.0, marker="2", mfc="red", mec="red", fig=fig, ax=ax, ) ax.set_xscale("log") ax.set_yscale("log") ax.set_xlabel(r"$\rho$") ax.set_ylabel(r"$\lambda$") ax.set_title("PSNR at each sample location\n(values below 20 dB omitted)") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_astra_3d_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" 3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver) ============================================================= This example demonstrates solution of a sparse-view, 3D CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ where $C$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator, and $\mathbf{x}$ is the reconstructed image. In this example the problem is solved via ADMM, while proximal ADMM is used in a [companion example](ct_astra_3d_tv_padmm.rst). """ import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_tangle_phantom from scico.linop.xray.astra import XRayTransform3D from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image and projector. """ Nx = 128 Ny = 256 Nz = 64 tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles C = XRayTransform3D( tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles ) # CT projection operator y = C @ tangle # sinogram """ Set up problem and solver. """ λ = 2e0 # ℓ2,1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 25 # maximum CG iterations per ADMM iteration # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, # which is used so that g(Ax) corresponds to isotropic TV. 
D = linop.FiniteDifference(input_shape=tangle.shape, append=0) g = λ * functional.L21Norm() f = loss.SquaredL2Loss(y=y, A=C) solver = ADMM( f=f, g_list=[g], C_list=[D], rho_list=[ρ], x0=C.T(y), maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 5}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") tangle_recon = solver.solve() print( "TV Restruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) ) """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6)) plot.imview( tangle[32], title="Ground truth (central slice)", cmap=plot.cm.Blues, cbar=None, fig=fig, ax=ax[0], ) plot.imview( tangle_recon[32], title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)), cmap=plot.cm.Blues, fig=fig, ax=ax[1], ) divider = make_axes_locatable(ax[1]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_astra_3d_tv_padmm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" 3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) ====================================================================== This example demonstrates solution of a sparse-view, 3D CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ where $C$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator, and $\mathbf{x}$ is the reconstructed image. In this example the problem is solved via proximal ADMM, while standard ADMM is used in a [companion example](ct_astra_3d_tv_admm.rst). """ import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_tangle_phantom from scico.linop.xray.astra import XRayTransform3D, angle_to_vector from scico.optimize import ProximalADMM from scico.util import device_info """ Create a ground truth image and projector. """ Nx = 128 Ny = 256 Nz = 64 tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_spacing = [1.0, 1.0] det_count = [Nz, max(Nx, Ny)] vectors = angle_to_vector(det_spacing, angles) # It would have been more straightforward to use the det_spacing and angles keywords # in this case (since vectors is just computed directly from these two quantities), but # the more general form is used here as a demonstration. C = XRayTransform3D(tangle.shape, det_count=det_count, vectors=vectors) # CT projection operator y = C @ tangle # sinogram r""" Set up problem and solver. 
We want to minimize the functional $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ where $C$ is the X-ray transform and $D$ is a finite difference operator. This problem can be expressed as $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_{2,1} \;\; \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; \mathbf{z}_1 = D \mathbf{x} \;,$$ which can be written in the form of a standard ADMM problem $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}$$ with $$f = 0 \qquad g = g_0 + g_1$$ $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ This is a more complex splitting than that used in the [companion example](ct_astra_3d_tv_admm.rst), but it allows the use of a proximal ADMM solver in a way that avoids the need for the conjugate gradient sub-iterations used by the ADMM solver in the [companion example](ct_astra_3d_tv_admm.rst). 
""" 𝛼 = 1e2 # improve problem conditioning by balancing C and D components of A λ = 2e0 # ℓ2,1 norm regularization parameter ρ = 5e-3 # ADMM penalty parameter maxiter = 1000 # number of ADMM iterations f = functional.ZeroFunctional() g0 = loss.SquaredL2Loss(y=y) g1 = (λ / 𝛼) * functional.L21Norm() g = functional.SeparableFunctional((g0, g1)) D = linop.FiniteDifference(input_shape=tangle.shape, append=0) A = linop.VerticalStack((C, 𝛼 * D)) mu, nu = ProximalADMM.estimate_parameters(A) solver = ProximalADMM( f=f, g=g, A=A, B=None, rho=ρ, mu=mu, nu=nu, maxiter=maxiter, itstat_options={"display": True, "period": 50}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") tangle_recon = solver.solve() print( "TV Restruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) ) """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6)) plot.imview( tangle[32], title="Ground truth (central slice)", cmap=plot.cm.Blues, cbar=None, fig=fig, ax=ax[0], ) plot.imview( tangle_recon[32], title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)), cmap=plot.cm.Blues, fig=fig, ax=ax[1], ) divider = make_axes_locatable(ax[1]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_astra_noreg_pcg.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" CT Reconstruction with CG and PCG ================================= This example demonstrates a simple iterative CT reconstruction using conjugate gradient (CG) and preconditioned conjugate gradient (PCG) algorithms to solve the problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 \;,$$ where $A$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, and $\mathbf{x}$ is the reconstructed image. """ from time import time import numpy as np import jax.numpy as jnp from xdesign import Foam, discrete_phantom from scico import loss, plot from scico.linop import CircularConvolve from scico.linop.xray.astra import XRayTransform2D from scico.solver import cg """ Create a ground truth image. """ N = 256 # phantom size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jnp.array(x_gt) # convert to jax type """ Configure a CT projection operator and generate synthetic measurements. """ n_projection = N # matches the phantom size so this is not few-view CT angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y = A @ x_gt # sinogram r""" Forward and back project a single pixel (Kronecker delta) to compute an approximate impulse response for $\mathbf{A}^T \mathbf{A}$. """ H = CircularConvolve.from_operator(A.T @ A) r""" Invert in the Fourier domain to form a preconditioner $\mathbf{M} \approx (\mathbf{A}^T \mathbf{A})^{-1}$ (see :cite:`clinthorne-1993-preconditioning` Section V.A. for more details). """ # γ limits the gain of the preconditioner; higher gives a weaker filter. γ = 1e-2 # The imaginary part comes from numerical errors in A.T and needs to be # removed to ensure H is symmetric, positive definite. 
frequency_response = np.real(H.h_dft) inv_frequency_response = 1 / (frequency_response + γ) # Using circular convolution without padding is sufficient here because # M is approximate anyway. M = CircularConvolve(inv_frequency_response, x_gt.shape, h_is_dft=True) r""" Check that $\mathbf{M}$ does approximately invert $\mathbf{A}^T \mathbf{A}$. """ plot_args = dict( norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5), cmap=plot.matplotlib.cm.Blues_r ) fig, axes = plot.subplots(nrows=1, ncols=3, figsize=(12, 4.5)) plot.imview(x_gt, title="Ground truth, $x_{gt}$", fig=fig, ax=axes[0], **plot_args) plot.imview( A.T @ A @ x_gt, title=r"$\mathbf{A}^T \mathbf{A} x_{gt}$", fig=fig, ax=axes[1], **plot_args ) plot.imview( M @ A.T @ A @ x_gt, title=r"$\mathbf{M} \mathbf{A}^T \mathbf{A} x_{gt}$", fig=fig, ax=axes[2], **plot_args, ) fig.suptitle(r"$\mathbf{M}$ approximately inverts $\mathbf{A}^T \mathbf{A}$") fig.tight_layout() fig.colorbar( axes[2].get_images()[0], ax=axes, location="right", shrink=0.82, pad=0.02, label="Arbitrary Units", ) fig.show() """ Reconstruct with both standard and preconditioned conjugate gradient. """ start_time = time() x_cg, info_cg = cg( A.T @ A, A.T @ y, jnp.zeros(A.input_shape, dtype=A.input_dtype), tol=1e-5, info=True, ) time_cg = time() - start_time start_time = time() x_pcg, info_pcg = cg( A.T @ A, A.T @ y, jnp.zeros(A.input_shape, dtype=A.input_dtype), tol=2e-5, # preconditioning affects the problem scaling so tol differs between CG and PCG info=True, M=M, ) time_pcg = time() - start_time """ Compare CG and PCG in terms of reconstruction time and data fidelity. 
""" f_cg = loss.SquaredL2Loss(y=A.T @ y, A=A.T @ A) f_data = loss.SquaredL2Loss(y=y, A=A) print( f"{'Method':8s}{'Iterations':>11s}{'Time (s)':>12s}{'||ATAx - ATy||':>17s}{'||Ax - y||':>15s}" ) print( f"{'CG':8s}{info_cg['num_iter']:>11d}{time_cg:>12.2f}{f_cg(x_cg):>17.2e}{f_data(x_cg):>15.2e}" ) print( f"{'PCG':8s}{info_pcg['num_iter']:>11d}{time_pcg:>12.2f}{f_cg(x_pcg):>17.2e}" f"{f_data(x_pcg):>15.2e}" ) input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_astra_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" TV-Regularized Sparse-View CT Reconstruction (ASTRA Projector) ============================================================== This example demonstrates solution of a sparse-view CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the reconstructed image. This example uses the CT projector provided by the astra package, while the companion [example script](ct_tv_admm.rst) uses the projector integrated into scico. """ import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.linop.xray.astra import XRayTransform2D from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. 
""" N = 512 # phantom size np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_count = int(N * 1.05 / np.sqrt(2.0)) det_spacing = np.sqrt(2) A = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles) # CT projection operator y = A @ x_gt # sinogram """ Set up problem functional and ADMM solver object. """ λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 25 # maximum CG iterations per ADMM iteration # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() f = loss.SquaredL2Loss(y=y, A=A) x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 5}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") solver.solve() hist = solver.itstat_object.history(transpose=True) x_reconstruction = snp.clip(solver.x, 0, 1.0) """ Show the recovered image. 
""" fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0]) plot.imview( x0, title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)), cbar=None, fig=fig, ax=ax[1], ) plot.imview( x_reconstruction, title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)), fig=fig, ax=ax[2], ) divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_astra_weighted_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" TV-Regularized Low-Dose CT Reconstruction ========================================= This example demonstrates solution of a low-dose CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is the X-ray transform (the CT forward projection), $\mathbf{y}$ is the sinogram, the norm weighting $W$ is chosen so that the weighted norm is an approximation to the Poisson negative log likelihood :cite:`sauer-1993-local`, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the reconstructed image. """ import numpy as np from xdesign import Soil, discrete_phantom import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.linop.xray.astra import XRayTransform2D from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ N = 512 # phantom size np.random.seed(0) x_gt = discrete_phantom(Soil(porosity=0.80), size=384) x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64))) x_gt = np.clip(x_gt, 0, np.inf) # clip to positive values x_gt = snp.array(x_gt) # convert to jax type """ Configure CT projection operator and generate synthetic measurements. """ n_projection = 360 # number of projections Io = 1e3 # source flux 𝛼 = 1e-2 # attenuation coefficient angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y_c = A @ x_gt # sinogram r""" Add Poisson noise to projections according to $$\mathrm{counts} \sim \mathrm{Poi}\left(I_0 \exp (- \alpha A \mathbf{x} ) \right)$$ $$\mathbf{y} = - \frac{1}{\alpha} \log\left(\mathrm{counts} / I_0\right) \;.$$ We use the NumPy random functionality so we can generate using 64-bit numbers. 
""" counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt)) counts = np.clip(counts, a_min=1, a_max=np.inf) # replace any 0s count with 1 y = -1 / 𝛼 * np.log(counts / Io) y = snp.array(y) # convert back to float32 as a jax array """ Set up post processing. For this example, we clip all reconstructions to the range of the ground truth. """ def postprocess(x): return snp.clip(x, 0, snp.max(x_gt)) """ Compute an FBP reconstruction as an initial guess. """ x0 = postprocess(A.fbp(y)) r""" Set up and solve the un-weighted reconstruction problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;.$$ """ # Note that rho and lambda were selected via a parameter sweep (not # shown here). ρ = 2.5e3 # ADMM penalty parameter lambda_unweighted = 3e2 # regularization strength maxiter = 100 # number of ADMM iterations cg_tol = 1e-5 # CG relative tolerance cg_maxiter = 10 # maximum CG iterations per ADMM iteration f = loss.SquaredL2Loss(y=y, A=A) admm_unweighted = ADMM( f=f, g_list=[lambda_unweighted * functional.L21Norm()], C_list=[linop.FiniteDifference(x_gt.shape, append=0)], rho_list=[ρ], x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") admm_unweighted.solve() x_unweighted = postprocess(admm_unweighted.x) r""" Set up and solve the weighted reconstruction problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $$W = \mathrm{diag}( \mathrm{counts} / I_0 ) \;.$$ The data fidelity term in this formulation follows :cite:`sauer-1993-local` (9) except for the scaling by $I_0$, which we use to maintain balance between the data and regularization terms if $I_0$ changes. 
""" lambda_weighted = 5e1 weights = snp.array(counts / Io) f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) admm_weighted = ADMM( f=f, g_list=[lambda_weighted * functional.L21Norm()], C_list=[linop.FiniteDifference(x_gt.shape, append=0)], rho_list=[ρ], maxiter=maxiter, x0=x0, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) print() admm_weighted.solve() x_weighted = postprocess(admm_weighted.x) """ Show recovered images. """ def plot_recon(x, title, ax): """Plot an image with title indicating error metrics.""" plot.imview( x, title=f"{title}\nSNR: {metric.snr(x_gt, x):.2f} (dB), MAE: {metric.mae(x_gt, x):.3f}", fig=fig, ax=ax, ) fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(11, 10)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0]) plot_recon(x0, "FBP Reconstruction", ax=ax[0, 1]) plot_recon(x_unweighted, "Unweighted TV Reconstruction", ax=ax[1, 0]) plot_recon(x_weighted, "Weighted TV Reconstruction", ax=ax[1, 1]) for ax_ in ax.ravel(): ax_.set_xlim(64, 448) ax_.set_ylim(64, 448) fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) fig.colorbar( ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="arbitrary units" ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_datagen_foam2.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ CT Data Generation for NN Training ================================== This example demonstrates how to generate synthetic CT data for training neural network models. If desired, a basic reconstruction can be generated using filtered back projection (FBP). 
""" # isort: off import os import numpy as np import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 # Set an arbitrary processor count (only applies if GPU is not available). os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" from scico import plot from scico.flax.examples import load_ct_data """ Read data from cache or generate if not available. """ N = 256 # phantom size train_nimg = 536 # number of training images test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg n_projection = 45 # CT views trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True) """ Plot randomly selected sample. """ indx_tr = np.random.randint(0, train_nimg) indx_te = np.random.randint(0, test_nimg) fig, axes = plot.subplots(nrows=2, ncols=3, figsize=(9, 9)) plot.imview( trdt["img"][indx_tr, ..., 0], title="Ground truth - Training Sample", fig=fig, ax=axes[0, 0] ) plot.imview( trdt["sino"][indx_tr, ..., 0], title="Sinogram - Training Sample", fig=fig, ax=axes[0, 1] ) plot.imview( trdt["fbp"][indx_tr, ..., 0], title="FBP - Training Sample", fig=fig, ax=axes[0, 2], ) plot.imview( ttdt["img"][indx_te, ..., 0], title="Ground truth - Testing Sample", fig=fig, ax=axes[1, 0], ) plot.imview( ttdt["sino"][indx_te, ..., 0], title="Sinogram - Testing Sample", fig=fig, ax=axes[1, 1] ) plot.imview( ttdt["fbp"][indx_te, ..., 0], title="FBP - Testing Sample", fig=fig, ax=axes[1, 2], ) fig.suptitle(r"Training and Testing samples") fig.tight_layout() fig.colorbar( axes[0, 2].get_images()[0], ax=axes, shrink=0.5, pad=0.05, label="Arbitrary Units", ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. 
# Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
PPP (with BM3D) Fan-Beam CT Reconstruction
==========================================

This example demonstrates solution of a fan-beam tomographic reconstruction
problem using the Plug-and-Play Priors framework
:cite:`venkatakrishnan-2013-plugandplay2`, using BM3D
:cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for
tomographic projection.

This example uses the data fidelity term as one of the ADMM $g$
functionals so that the optimization with respect to the data fidelity is
able to exploit the internal prox of the `SVMBIRExtendedLoss` functional.

We solve the problem in two different ways:
1. Approximating the fan-beam geometry using parallel-beam and using the
   parallel beam projector to compute the reconstruction.
2. Using the correct fan-beam geometry to perform a reconstruction.
"""

import numpy as np

import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import metric, plot
from scico.functional import BM3D
from scico.linop import Diagonal, Identity
from scico.linop.xray.svmbir import SVMBIRExtendedLoss, XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Generate a ground truth image.
"""
N = 256  # image size
density = 0.025  # attenuation density of the image
np.random.seed(1234)
pad_len = 5
# Generate the foam on the interior so that, after padding by pad_len on
# each side, the image is N x N.
x_gt = discrete_phantom(
    Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 2 * pad_len
)
# Scale so the maximum value equals `density`, then pad and clip negatives.
x_gt = x_gt / np.max(x_gt) * density
x_gt = np.pad(x_gt, pad_len)
x_gt[x_gt < 0] = 0


"""
Generate tomographic projector and sinogram for fan beam and parallel beam.
For fan beam, use view angles spanning 2π since unlike parallel beam, views
at 0 and π are not equivalent.
""" num_angles = int(N / 2) num_channels = N # Use angles in the range [0, 2*pi] for fan beam angles = snp.linspace(0, 2 * snp.pi, num_angles, endpoint=False, dtype=snp.float32) dist_source_detector = 1500.0 magnification = 1.2 A_fan = XRayTransform( x_gt.shape, angles, num_channels, geometry="fan-curved", dist_source_detector=dist_source_detector, magnification=magnification, ) A_parallel = XRayTransform( x_gt.shape, angles, num_channels, geometry="parallel", ) sino_fan = A_fan @ x_gt """ Impose Poisson noise on sinograms. Higher max_intensity means less noise. """ def add_poisson_noise(sino, max_intensity): expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s y = -np.log(noisy_counts / max_intensity) return y y_fan = add_poisson_noise(sino_fan, max_intensity=500) """ Reconstruct using default prior of SVMBIR :cite:`svmbir-2020`. """ weights_fan = svmbir.calc_weights(y_fan, weight_type="transmission") x_mrf_fan = svmbir.recon( np.array(y_fan[:, np.newaxis]), np.array(angles), weights=weights_fan[:, np.newaxis], num_rows=N, num_cols=N, positivity=True, verbose=0, stop_threshold=0.0, geometry="fan-curved", dist_source_detector=dist_source_detector, magnification=magnification, delta_channel=1.0, delta_pixel=1.0 / magnification, )[0] x_mrf_parallel = svmbir.recon( np.array(y_fan[:, np.newaxis]), np.array(angles), weights=weights_fan[:, np.newaxis], num_rows=N, num_cols=N, positivity=True, verbose=0, stop_threshold=0.0, geometry="parallel", )[0] """ Convert numpy arrays to jax arrays. """ y_fan = snp.array(y_fan) x0_fan = snp.array(x_mrf_fan) weights_fan = snp.array(weights_fan) x0_parallel = snp.array(x_mrf_parallel) """ Set problem parameters and BM3D pseudo-functional. """ ρ = 10 # ADMM penalty parameter σ = density * 0.6 # denoiser sigma g0 = σ * ρ * BM3D() """ Set up problem using `SVMBIRExtendedLoss`. 
""" f_extloss_fan = SVMBIRExtendedLoss( y=y_fan, A=A_fan, W=Diagonal(weights_fan), scale=0.5, positivity=True, prox_kwargs={"maxiter": 5, "ctol": 0.0}, ) f_extloss_parallel = SVMBIRExtendedLoss( y=y_fan, A=A_parallel, W=Diagonal(weights_fan), scale=0.5, positivity=True, prox_kwargs={"maxiter": 5, "ctol": 0.0}, ) solver_extloss_fan = ADMM( f=None, g_list=[f_extloss_fan, g0], C_list=[Identity(x_mrf_fan.shape), Identity(x_mrf_fan.shape)], rho_list=[ρ, ρ], x0=x0_fan, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True, "period": 5}, ) solver_extloss_parallel = ADMM( f=None, g_list=[f_extloss_parallel, g0], C_list=[Identity(x_mrf_parallel.shape), Identity(x_mrf_parallel.shape)], rho_list=[ρ, ρ], x0=x0_parallel, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True, "period": 5}, ) """ Run the ADMM solvers. """ print(f"Solving on {device_info()}\n") x_extloss_fan = solver_extloss_fan.solve() hist_extloss_fan = solver_extloss_fan.itstat_object.history(transpose=True) print() x_extloss_parallel = solver_extloss_parallel.solve() hist_extloss_parallel = solver_extloss_parallel.itstat_object.history(transpose=True) """ Show the recovered images. The parallel beam reconstruction is poor because the parallel beam is a poor approximation of the specific fan beam geometry used here. 
""" norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density) fig, ax = plt.subplots(1, 3, figsize=(20, 7)) plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm) plot.imview( img=x_mrf_parallel, title=f"Parallel-beam MRF (PSNR: {metric.psnr(x_gt, x_mrf_parallel):.2f} dB)", cbar=True, fig=fig, ax=ax[1], norm=norm, ) plot.imview( img=x_extloss_parallel, title=f"Parallel-beam Extended Loss (PSNR: {metric.psnr(x_gt, x_extloss_parallel):.2f} dB)", cbar=True, fig=fig, ax=ax[2], norm=norm, ) fig.show() fig, ax = plt.subplots(1, 3, figsize=(20, 7)) plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm) plot.imview( img=x_mrf_fan, title=f"Fan-beam MRF (PSNR: {metric.psnr(x_gt, x_mrf_fan):.2f} dB)", cbar=True, fig=fig, ax=ax[1], norm=norm, ) plot.imview( img=x_extloss_fan, title=f"Fan-beam Extended Loss (PSNR: {metric.psnr(x_gt, x_extloss_fan):.2f} dB)", cbar=True, fig=fig, ax=ax[2], norm=norm, ) fig.show() """ Plot convergence statistics. """ fig, ax = plt.subplots(1, 2, figsize=(15, 6)) plot.plot( snp.array((hist_extloss_parallel.Prml_Rsdl, hist_extloss_parallel.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals for parallel-beam reconstruction", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[0], ) ax[0].set_ylim([1e-1, 1e1]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.array((hist_extloss_fan.Prml_Rsdl, hist_extloss_fan.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals for fan-beam reconstruction", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) ax[1].set_ylim([1e-1, 1e1]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_modl_train_foam2.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. 
Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" CT Training and Reconstruction with MoDL ======================================== This example demonstrates the training and application of a model-based deep learning (MoDL) architecture described in :cite:`aggarwal-2019-modl` applied to a CT reconstruction problem. The source images are foam phantoms generated with xdesign. A class [scico.flax.MoDLNet](../_autosummary/scico.flax.rst#scico.flax.MoDLNet) implements the MoDL architecture, which solves the optimization problem $$\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + \lambda \, \| \mathbf{x} - \mathrm{D}_w(\mathbf{x})\|_2^2 \;,$$ where $A$ is a tomographic projector, $\mathbf{y}$ is a set of sinograms, $\mathrm{D}_w$ is the regularization (a denoiser), and $\mathbf{x}$ is the set of reconstructed images. The MoDL abstracts the iterative solution by an unrolled network where each iteration corresponds to a different stage in the MoDL network and updates the prediction by solving $$\mathbf{x}^{k+1} = (A^T A + \lambda \, I)^{-1} (A^T \mathbf{y} + \lambda \, \mathbf{z}^k) \;,$$ via conjugate gradient. In the expression, $k$ is the index of the stage (iteration), $\mathbf{z}^k = \mathrm{ResNet}(\mathbf{x}^{k})$ is the regularization (a denoiser implemented as a residual convolutional neural network), $\mathbf{x}^k$ is the output of the previous stage, $\lambda > 0$ is a learned regularization parameter, and $I$ is the identity operator. The output of the final stage is the set of reconstructed images. 
""" # isort: off import os from functools import partial from time import time import numpy as np import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 import jax try: from jax.extend.backend import get_backend # introduced in jax 0.4.33 except ImportError: from jax.lib.xla_bridge import get_backend from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop.xray import XRayTransform2D """ Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available). """ os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" platform = get_backend().platform print("Platform: ", platform) """ Read data from cache or generate if not available. """ N = 256 # phantom size train_nimg = 536 # number of training images test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg n_projection = 45 # CT views trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True) """ Build CT projection operator. Parameters are chosen so that the operator is equivalent to the one used to generate the training data. """ angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), angles=angles, det_count=int(N * 1.05 / np.sqrt(2.0)), dx=1.0 / np.sqrt(2), ) A = (1.0 / N) * A # normalize projection operator """ Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep training and testing partitions. 
""" numtr = 100 numtt = 16 train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]} test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]} """ Define configuration dictionary for model and training loop. Parameters have been selected for demonstration purposes and relatively short training. The model depth is akin to the number of unrolled iterations in the MoDL model. The block depth controls the number of layers at each unrolled iteration. The number of filters is uniform throughout the iterations. The iterations used for the conjugate gradient (CG) solver can also be specified. Better performance may be obtained by increasing depth, block depth, number of filters, CG iterations, or training epochs, but may require longer training times. """ # model configuration model_conf = { "depth": 10, "num_filters": 64, "block_depth": 4, "cg_iter_1": 3, "cg_iter_2": 8, } # training configuration train_conf: sflax.ConfigDict = { "seed": 12345, "opt_type": "SGD", "momentum": 0.9, "batch_size": 16, "num_epochs": 20, "base_learning_rate": 1e-2, "warmup_epochs": 0, "log_every_steps": 40, "log": True, "checkpointing": True, } """ Construct functionality for ensuring that the learned regularization parameter is always positive. """ lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model lmbdapos = partial( clip_positive, # apply this function traversal=lmbdatrav, # to lmbda parameters in model minval=5e-4, ) """ Print configuration of distributed run. """ print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") print(f"JAX local devices: {jax.local_devices()}\n") """ Check for iterated trained model. If not found, construct MoDLNet model, using only one iteration (depth) in model and few CG iterations for faster intialization. Run first stage (initialization) training loop followed by a second stage (depth iterations) training loop. 
""" channels = train_ds["image"].shape[-1] workdir2 = os.path.join( os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out", "iterated" ) stats_object_ini = None stats_object = None checkpoint_files = [] for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames] if len(checkpoint_files) > 0: model = sflax.MoDLNet( operator=A, depth=model_conf["depth"], channels=channels, num_filters=model_conf["num_filters"], block_depth=model_conf["block_depth"], cg_iter=model_conf["cg_iter_2"], ) train_conf["post_lst"] = [lmbdapos] # Parameters for 2nd stage train_conf["workdir"] = workdir2 train_conf["opt_type"] = "ADAM" train_conf["num_epochs"] = 150 # Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) start_time = time() modvar, stats_object = trainer.train() time_train = time() - start_time time_init = 0.0 epochs_init = 0 else: # One iteration (depth) in model and few CG iterations model = sflax.MoDLNet( operator=A, depth=1, channels=channels, num_filters=model_conf["num_filters"], block_depth=model_conf["block_depth"], cg_iter=model_conf["cg_iter_1"], ) # First stage: initialization training loop. workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out") train_conf["workdir"] = workdir1 train_conf["post_lst"] = [lmbdapos] # Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) start_time = time() modvar, stats_object_ini = trainer.train() time_init = time() - start_time epochs_init = train_conf["num_epochs"] print( f"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}" f"{'time[s]:':21s}{time_init:>7.2f}" ) # Second stage: depth iterations training loop. 
model.depth = model_conf["depth"] model.cg_iter = model_conf["cg_iter_2"] train_conf["opt_type"] = "ADAM" train_conf["num_epochs"] = 150 train_conf["workdir"] = workdir2 # Construct training object, include current model parameters trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, variables0=modvar, ) start_time = time() modvar, stats_object = trainer.train() time_train = time() - start_time """ Evaluate on testing data. """ del train_ds["image"] del train_ds["label"] fmap = sflax.FlaxMap(model, modvar) del model, modvar maxn = numtt start_time = time() output = fmap(test_ds["image"][:maxn]) time_eval = time() - start_time output = np.clip(output, a_min=0, a_max=1.0) """ Evaluate trained model in terms of reconstruction time and data fidelity. """ total_epochs = epochs_init + train_conf["num_epochs"] total_time_train = time_init + time_train snr_eval = metric.snr(test_ds["label"][:maxn], output) psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}" f"{'time[s]:':10s}{total_time_train:>7.2f}" ) print( f"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}" ) """ Plot comparison. 
""" np.random.seed(123) indx = np.random.randint(0, high=maxn) fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0]) plot.imview( test_ds["image"][indx, ..., 0], title="Sinogram", cbar=None, fig=fig, ax=ax[1], ) plot.imview( output[indx, ..., 0], title="MoDLNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f" % ( metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), ), fig=fig, ax=ax[2], ) divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() """ Plot convergence statistics. Statistics are generated only if a training cycle was done (i.e. if not reading final epoch results from checkpoint). """ if stats_object is not None and len(stats_object.iterations) > 0: hist = stats_object.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( np.array((hist.Train_Loss, hist.Eval_Loss)).T, x=hist.Epoch, ptyp="semilogy", title="Loss function", xlbl="Epoch", ylbl="Loss value", lgnd=("Train", "Test"), fig=fig, ax=ax[0], ) plot.plot( np.array((hist.Train_SNR, hist.Eval_SNR)).T, x=hist.Epoch, title="Metric", xlbl="Epoch", ylbl="SNR (dB)", lgnd=("Train", "Test"), fig=fig, ax=ax[1], ) fig.show() # Stats for initialization loop if stats_object_ini is not None and len(stats_object_ini.iterations) > 0: hist = stats_object_ini.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( np.array((hist.Train_Loss, hist.Eval_Loss)).T, ptyp="semilogy", title="Loss function - Initialization", xlbl="Epoch", ylbl="Loss value", lgnd=("Train", "Test"), fig=fig, ax=ax[0], ) plot.plot( np.array((hist.Train_SNR, hist.Eval_SNR)).T, title="Metric - Initialization", xlbl="Epoch", ylbl="SNR (dB)", lgnd=("Train", "Test"), fig=fig, ax=ax[1], ) 
fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_multi_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) ================================================================== This example demonstrates solution of a sparse-view CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the reconstructed image. The solution is computed and compared for all three 2D CT projectors available in scico, using a sinogram computed with the astra projector. """ import numpy as np from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.linop.xray import XRayTransform2D, astra, svmbir from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ N = 512 # phantom size np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ Define CT geometry and construct array of (approximately) equivalent projectors. 
""" n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_count = int(N * 1.05 / np.sqrt(2.0)) det_spacing = np.sqrt(2) projectors = { "astra": astra.XRayTransform2D( x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0 ), # astra "svmbir": svmbir.XRayTransform( x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing ), # svmbir "scico": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico } """ Compute common sinogram using astra projector. """ A = projectors["astra"] noise = np.random.normal(size=(n_projection, det_count)).astype(np.float32) y = A @ x_gt + 2.0 * noise """ Construct initial solution for regularized problem. """ x0 = A.fbp(y) """ Solve the same problem using the different projectors. """ print(f"Solving on {device_info()}") x_rec, hist = {}, {} for p in projectors.keys(): print(f"\nSolving with {p} projector") # Set up ADMM solver object. λ = 2e1 # L1 norm regularization parameter ρ = 1e3 # ADMM penalty parameter maxiter = 100 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 50 # maximum CG iterations per ADMM iteration # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() A = projectors[p] f = loss.SquaredL2Loss(y=y, A=A) # Set up the solver. solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 5}, ) # Run the solver. 
solver.solve() hist[p] = solver.itstat_object.history(transpose=True) x_rec[p] = solver.x if p == "scico": x_rec[p] = x_rec[p] * det_spacing # to match ASTRA's scaling """ Compare reconstruction results. """ print("Reconstruction SNR:") for p in projectors.keys(): print(f" {(p + ':'):7s} {metric.snr(x_gt, x_rec[p]):5.2f} dB") """ Display sinogram. """ fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3)) plot.imview(y, title="sinogram", fig=fig, ax=ax) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5)) plot.plot( np.array([hist[p].Objective for p in projectors.keys()]).T, title="Objective function", xlbl="Iteration", ylbl="Functional value", lgnd=projectors.keys(), fig=fig, ax=ax[0], ) plot.plot( np.array([hist[p].Prml_Rsdl for p in projectors.keys()]).T, ptyp="semilogy", title="Primal Residual", xlbl="Iteration", fig=fig, ax=ax[1], ) plot.plot( np.array([hist[p].Dual_Rsdl for p in projectors.keys()]).T, ptyp="semilogy", title="Dual Residual", xlbl="Iteration", fig=fig, ax=ax[2], ) fig.show() """ Show the recovered images. """ fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) for n, p in enumerate(projectors.keys()): plot.imview( x_rec[p], title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])), fig=fig, ax=ax[n + 1], ) for ax in ax: ax.get_images()[0].set_clim(-0.1, 1.1) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_odp_train_foam2.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" CT Training and Reconstruction with ODP ======================================= This example demonstrates the training of the unrolled optimization with deep priors (ODP) gradient descent architecture described in :cite:`diamond-2018-odp` applied to a CT reconstruction problem. The source images are foam phantoms generated with xdesign. A class [scico.flax.ODPNet](../_autosummary/scico.flax.rst#scico.flax.ODPNet) implements the ODP architecture, which solves the optimization problem $$\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + r(\mathbf{x}) \;,$$ where $A$ is a tomographic projector, $\mathbf{y}$ is a set of sinograms, $r$ is a regularizer and $\mathbf{x}$ is the set of reconstructed images. The ODP, gradient descent architecture, abstracts the iterative solution by an unrolled network where each iteration corresponds to a different stage in the ODP network and updates the prediction by solving $$\mathbf{x}^{k+1} = \mathrm{argmin}_{\mathbf{x}} \; \alpha_k \| A \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} - \mathbf{x}^k - \mathbf{x}^{k+1/2} \|_2^2 \;,$$ which for the CT problem, using gradient descent, corresponds to $$\mathbf{x}^{k+1} = \mathbf{x}^k + \mathbf{x}^{k+1/2} - \alpha_k \, A^T \, (A \mathbf{x}^k - \mathbf{y}) \;,$$ where $k$ is the index of the stage (iteration), $\mathbf{x}^k + \mathbf{x}^{k+1/2} = \mathrm{ResNet}(\mathbf{x}^{k})$ is the regularization (implemented as a residual convolutional neural network), $\mathbf{x}^k$ is the output of the previous stage and $\alpha_k > 0$ is a learned stage-wise parameter weighting the contribution of the fidelity term. The output of the final stage is the set of reconstructed images. 
""" # isort: off import os from functools import partial from time import time import numpy as np import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 # Set an arbitrary processor count (only applies if GPU is not available). os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" import jax try: from jax.extend.backend import get_backend # introduced in jax 0.4.33 except ImportError: from jax.lib.xla_bridge import get_backend from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop.xray import XRayTransform2D platform = get_backend().platform print("Platform: ", platform) """ Read data from cache or generate if not available. """ N = 256 # phantom size train_nimg = 536 # number of training images test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg n_projection = 45 # CT views trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True) """ Build CT projection operator. Parameters are chosen so that the operator is equivalent to the one used to generate the training data. """ angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), angles=angles, det_count=int(N * 1.05 / np.sqrt(2.0)), dx=1.0 / np.sqrt(2), ) A = (1.0 / N) * A # normalize projection operator """ Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep training and testing partitions. """ numtr = 320 numtt = 32 train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]} test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]} """ Define configuration dictionary for model and training loop. 

Parameters have been selected for demonstration purposes and relatively
short training. The model depth is akin to the number of unrolled
iterations in the ODP model. The block depth controls the number of
layers at each unrolled iteration. The number of filters is uniform
throughout the iterations. Better performance may be obtained by
increasing depth, block depth, number of filters, or training epochs,
but may require longer training times.
"""
# model configuration
model_conf = {
    "depth": 8,
    "num_filters": 64,
    "block_depth": 6,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 1234,
    "opt_type": "ADAM",
    "batch_size": 16,
    "num_epochs": 200,
    "base_learning_rate": 1e-3,
    "warmup_epochs": 0,
    "log_every_steps": 160,
    "log": True,
    "checkpointing": True,
}


"""
Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
alphatrav = construct_traversal("alpha")  # select alpha parameters in model
alphapost = partial(
    clip_positive,  # apply this function
    traversal=alphatrav,  # to alpha parameters in model
    minval=1e-3,
)


"""
Print configuration of distributed run.
"""
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")


"""
Construct ODPNet model.
"""
channels = train_ds["image"].shape[-1]
# Gradient-descent ODP stages; alpha_ini is the initial value of the
# learned fidelity weight that alphapost keeps positive.
model = sflax.ODPNet(
    operator=A,
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
    block_depth=model_conf["block_depth"],
    odp_block=sflax.inverse.ODPGrDescBlock,
    alpha_ini=1e-2,
)


"""
Run training loop.
""" workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_ct_out") train_conf["workdir"] = workdir train_conf["post_lst"] = [alphapost] # Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) modvar, stats_object = trainer.train() """ Evaluate on testing data. """ del train_ds["image"] del train_ds["label"] fmap = sflax.FlaxMap(model, modvar) del model, modvar maxn = numtt start_time = time() output = fmap(test_ds["image"][:maxn]) time_eval = time() - start_time output = np.clip(output, a_min=0, a_max=1.0) epochs = train_conf["num_epochs"] """ Evaluate trained model in terms of reconstruction time and data fidelity. """ snr_eval = metric.snr(test_ds["label"][:maxn], output) psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}" f"{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}" ) """ Plot comparison. """ np.random.seed(123) indx = np.random.randint(0, high=maxn) fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0]) plot.imview( test_ds["image"][indx, ..., 0], title="Sinogram", cbar=None, fig=fig, ax=ax[1], ) plot.imview( output[indx, ..., 0], title="ODPNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f" % ( metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), ), fig=fig, ax=ax[2], ) divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() """ Plot convergence statistics. Statistics are generated only if a training cycle was done (i.e. if not reading final epoch results from checkpoint). 
""" if stats_object is not None and len(stats_object.iterations) > 0: hist = stats_object.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( np.array((hist.Train_Loss, hist.Eval_Loss)).T, x=hist.Epoch, ptyp="semilogy", title="Loss function", xlbl="Epoch", ylbl="Loss value", lgnd=("Train", "Test"), fig=fig, ax=ax[0], ) plot.plot( np.array((hist.Train_SNR, hist.Eval_SNR)).T, x=hist.Epoch, title="Metric", xlbl="Epoch", ylbl="SNR (dB)", lgnd=("Train", "Test"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_projector_comparison_2d.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" 2D X-ray Transform Comparison ============================= This example compares SCICO's native 2D X-ray transform algorithm to that of the ASTRA toolbox. """ import numpy as np import jax import jax.numpy as jnp from xdesign import Foam, discrete_phantom import scico.linop.xray.astra as astra from scico import plot from scico.linop.xray import XRayTransform2D from scico.util import Timer """ Create a ground truth image. """ N = 512 x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jnp.array(x_gt) """ Time projector instantiation. 
""" num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) det_count = int(N * 1.02 / jnp.sqrt(2.0)) timer = Timer() projectors = {} timer.start("scico_init") projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count) timer.stop("scico_init") timer.start("astra_init") projectors["astra"] = astra.XRayTransform2D( (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0 ) timer.stop("astra_init") """ Time first projector application, which might include JIT overhead. """ ys = {} for name, H in projectors.items(): timer_label = f"{name}_first_fwd" timer.start(timer_label) ys[name] = H @ x_gt jax.block_until_ready(ys[name]) timer.stop(timer_label) """ Compute average time for a projector application. """ num_repeats = 3 for name, H in projectors.items(): timer_label = f"{name}_avg_fwd" timer.start(timer_label) for _ in range(num_repeats): ys[name] = H @ x_gt jax.block_until_ready(ys[name]) timer.stop(timer_label) timer.td[timer_label] /= num_repeats """ Time first back projection, which might include JIT overhead. """ y = np.zeros(H.output_shape, dtype=np.float32) y[num_angles // 3, det_count // 2] = 1.0 y = jnp.array(y) HTys = {} for name, H in projectors.items(): timer_label = f"{name}_first_back" timer.start(timer_label) HTys[name] = H.T @ y jax.block_until_ready(ys[name]) timer.stop(timer_label) """ Compute average time for back projection. """ num_repeats = 3 for name, H in projectors.items(): timer_label = f"{name}_avg_back" timer.start(timer_label) for _ in range(num_repeats): HTys[name] = H.T @ y jax.block_until_ready(ys[name]) timer.stop(timer_label) timer.td[timer_label] /= num_repeats """ Display timing results. On our server, when using the GPU, the SCICO projector (both forward and backward) is faster than ASTRA. When using the CPU, it is slower for forward projection and faster for back projection. The SCICO object initialization and first back projection are slow due to JIT overhead. 
On our server, using the GPU:
```
init astra 4.81e-02 s
init scico 2.53e-01 s

first fwd  astra 4.44e-02 s
first fwd  scico 2.82e-02 s

first back astra 3.31e-02 s
first back scico 2.80e-01 s

avg   fwd  astra 4.76e-02 s
avg   fwd  scico 2.83e-02 s

avg   back astra 3.96e-02 s
avg   back scico 6.80e-04 s
```

Using the CPU:
```
init astra 1.72e-02 s
init scico 2.88e+00 s

first fwd  astra 1.02e+00 s
first fwd  scico 2.40e+00 s

first back astra 1.03e+00 s
first back scico 3.53e+00 s

avg   fwd  astra 1.03e+00 s
avg   fwd  scico 2.54e+00 s

avg   back astra 1.01e+00 s
avg   back scico 5.98e-01 s
```
"""
print(f"init astra {timer.td['astra_init']:.2e} s")
print(f"init scico {timer.td['scico_init']:.2e} s")
print("")
# Print one line per (timing type, direction, projector), with a blank line
# after each astra/scico pair.
for tstr in ("first", "avg"):
    for dstr in ("fwd", "back"):
        for pstr in ("astra", "scico"):
            print(
                f"{tstr:5s} {dstr:4s} {pstr} {timer.td[pstr + '_' + tstr + '_' + dstr]:.2e} s"
            )
        print()

"""
Show projections.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
fig.show()

"""
Show back projections of a single detector element, i.e., a line.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0])
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1])
# Zoom both panels to the central fifth of the image in each direction.
for ax_i in ax:
    ax_i.set_xlim(2 * N / 5, N - 2 * N / 5)
    ax_i.set_ylim(2 * N / 5, N - 2 * N / 5)
fig.show()

input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/ct_projector_comparison_3d.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
r""" 3D X-ray Transform Comparison ============================= This example shows how to define a SCICO native 3D X-ray transform using ASTRA toolbox conventions and vice versa. """ import numpy as np import jax import jax.numpy as jnp import scico.linop.xray.astra as astra from scico import plot from scico.examples import create_block_phantom from scico.linop.xray import XRayTransform3D from scico.util import ContextTimer, Timer """ Create a ground truth image and set detector dimensions. """ N = 64 # use rectangular volume to check whether axes are handled correctly in_shape = (N + 1, N + 2, N + 3) x = create_block_phantom(in_shape) x = jnp.array(x) # use rectangular detector to check whether axes are handled correctly out_shape = (N, N + 1) """ Set up SCICO projection. """ num_angles = 3 rot_X = 90.0 - 16.0 rot_Y = np.linspace(0, 180, num_angles, endpoint=False) angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1) matrices = XRayTransform3D.matrices_from_euler_angles( in_shape, out_shape, "XY", angles, degrees=True ) """ Specify geometry using SCICO conventions and project. """ num_repeats = 3 timer_scico = Timer() with ContextTimer(timer_scico, "init"): H_scico = XRayTransform3D(in_shape, matrices, out_shape) with ContextTimer(timer_scico, "first_fwd"): y_scico = H_scico @ x jax.block_until_ready(y_scico) with ContextTimer(timer_scico, "avg_fwd"): for _ in range(num_repeats): y_scico = H_scico @ x jax.block_until_ready(y_scico) timer_scico.td["avg_fwd"] /= num_repeats with ContextTimer(timer_scico, "first_back"): HTy_scico = H_scico.T @ y_scico with ContextTimer(timer_scico, "avg_back"): for _ in range(num_repeats): HTy_scico = H_scico.T @ y_scico jax.block_until_ready(HTy_scico) timer_scico.td["avg_back"] /= num_repeats """ Convert SCICO geometry to ASTRA and project. 
""" vectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape) timer_astra = Timer() with ContextTimer(timer_astra, "init"): H_astra_from_scico = astra.XRayTransform3D( input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico ) with ContextTimer(timer_astra, "first_fwd"): y_astra_from_scico = H_astra_from_scico @ x jax.block_until_ready(y_astra_from_scico) with ContextTimer(timer_astra, "avg_fwd"): for _ in range(num_repeats): y_astra_from_scico = H_astra_from_scico @ x jax.block_until_ready(y_astra_from_scico) timer_astra.td["avg_fwd"] /= num_repeats with ContextTimer(timer_astra, "first_back"): HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico with ContextTimer(timer_astra, "avg_back"): for _ in range(num_repeats): HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico jax.block_until_ready(HTy_astra_from_scico) timer_astra.td["avg_back"] /= num_repeats """ Specify geometry with ASTRA conventions and project. """ angles = np.random.rand(num_angles) * 180 # random projection angles det_spacing = [1.0, 1.0] vectors = astra.angle_to_vector(det_spacing, angles) H_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors) y_astra = H_astra @ x HTy_astra = H_astra.T @ y_astra """ Convert ASTRA geometry to SCICO and project. """ P_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom) H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape) y_scico_from_astra = H_scico_from_astra @ x HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra """ Print timing results. """ print(f"init astra {timer_astra.td['init']:.2e} s") print(f"init scico {timer_scico.td['init']:.2e} s") print("") for tstr in ("first", "avg"): for dstr in ("fwd", "back"): for timer, pstr in zip((timer_astra, timer_scico), ("astra", "scico")): print(f"{tstr:5s} {dstr:4s} {pstr} {timer.td[tstr + '_' + dstr]:.2e} s") print() """ Show projections. 
""" fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10)) plot.imview(y_scico[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0]) plot.imview(y_scico[1], cbar=None, fig=fig, ax=ax[1, 0]) plot.imview(y_scico[2], cbar=None, fig=fig, ax=ax[2, 0]) plot.imview(y_astra_from_scico[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1]) plot.imview(y_astra_from_scico[:, 1], cbar=None, fig=fig, ax=ax[1, 1]) plot.imview(y_astra_from_scico[:, 2], cbar=None, fig=fig, ax=ax[2, 1]) fig.suptitle("Using SCICO conventions") fig.tight_layout() fig.show() fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10)) plot.imview(y_scico_from_astra[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0]) plot.imview(y_scico_from_astra[1], cbar=None, fig=fig, ax=ax[1, 0]) plot.imview(y_scico_from_astra[2], cbar=None, fig=fig, ax=ax[2, 0]) plot.imview(y_astra[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1]) plot.imview(y_astra[:, 1], cbar=None, fig=fig, ax=ax[1, 1]) plot.imview(y_astra[:, 2], cbar=None, fig=fig, ax=ax[2, 1]) fig.suptitle("Using ASTRA conventions") fig.tight_layout() fig.show() """ Show back projections. 
""" fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5)) plot.imview(HTy_scico[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0]) plot.imview( HTy_astra_from_scico[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1] ) fig.suptitle("Using SCICO conventions") fig.tight_layout() fig.show() fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5)) plot.imview( HTy_scico_from_astra[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0] ) plot.imview(HTy_astra[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1]) fig.suptitle("Using ASTRA conventions") fig.tight_layout() fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ PPP (with BM3D) CT Reconstruction (ADMM with CG Subproblem Solver) ================================================================== This example demonstrates solution of a tomographic reconstruction problem using the Plug-and-Play Priors framework :cite:`venkatakrishnan-2013-plugandplay2`, using BM3D :cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for tomographic projection. There are two versions of this example, solving the same problem in two different ways. This version uses the data fidelity term as the ADMM $f$, and thus the optimization with respect to the data fidelity uses CG rather than the prox of the `SVMBIRSquaredL2Loss` functional, as in the [other version](ct_svmbir_ppp_bm3d_admm_prox.rst). 
""" import numpy as np import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Generate a ground truth image. """ N = 256 # image size density = 0.025 # attenuation density of the image np.random.seed(1234) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10) x_gt = x_gt / np.max(x_gt) * density x_gt = np.pad(x_gt, 5) x_gt[x_gt < 0] = 0 """ Generate tomographic projector and sinogram. """ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt """ Impose Poisson noise on sinogram. Higher max_intensity means less noise. """ max_intensity = 2000 expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s y = -np.log(noisy_counts / max_intensity) """ Reconstruct using default prior of SVMBIR :cite:`svmbir-2020`. """ weights = svmbir.calc_weights(y, weight_type="transmission") x_mrf = svmbir.recon( np.array(y[:, np.newaxis]), np.array(angles), weights=weights[:, np.newaxis], num_rows=N, num_cols=N, positivity=True, verbose=0, )[0] """ Set up an ADMM solver. 
""" y = snp.array(y) x0 = snp.array(x_mrf) weights = snp.array(weights) ρ = 15 # ADMM penalty parameter σ = density * 0.18 # denoiser sigma f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g0 = σ * ρ * BM3D() g1 = NonNegativeIndicator() solver = ADMM( f=f, g_list=[g0, g1], C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)], rho_list=[ρ, ρ], x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 100}), itstat_options={"display": True, "period": 5}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x_bm3d = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Show the recovered image. """ norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density) fig, ax = plt.subplots(1, 3, figsize=[15, 5]) plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm) plot.imview( img=x_mrf, title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)", cbar=True, fig=fig, ax=ax[1], norm=norm, ) plot.imview( img=x_bm3d, title=f"BM3D (PSNR: {metric.psnr(x_gt, x_bm3d):.2f} dB)", cbar=True, fig=fig, ax=ax[2], norm=norm, ) fig.show() """ Plot convergence statistics. """ plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), ) input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
""" PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) ============================================================== This example demonstrates solution of a tomographic reconstruction problem using the Plug-and-Play Priors framework :cite:`venkatakrishnan-2013-plugandplay2`, using BM3D :cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for tomographic projection. There are two versions of this example, solving the same problem in two different ways. This version uses the data fidelity term as one of the ADMM $g$ functionals so that the optimization with respect to the data fidelity is able to exploit the internal prox of the `SVMBIRExtendedLoss` and `SVMBIRSquaredL2Loss` functionals. The [other version](ct_svmbir_ppp_bm3d_admm_cg.rst) solves the ADMM subproblem corresponding to the data fidelity term via CG. Two ways of exploiting the SVMBIR internal prox are explored in this example: 1. Using the `SVMBIRSquaredL2Loss` together with the BM3D pseudo-functional and a non-negative indicator function, and 2. Using the `SVMBIRExtendedLoss`, which includes a non-negativity constraint, together with the BM3D pseudo-functional. """ import numpy as np import matplotlib.pyplot as plt import svmbir from matplotlib.ticker import MaxNLocator from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity from scico.linop.xray.svmbir import ( SVMBIRExtendedLoss, SVMBIRSquaredL2Loss, XRayTransform, ) from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Generate a ground truth image. 
""" N = 256 # image size density = 0.025 # attenuation density of the image np.random.seed(1234) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10) x_gt = x_gt / np.max(x_gt) * density x_gt = np.pad(x_gt, 5) x_gt[x_gt < 0] = 0 """ Generate tomographic projector and sinogram. """ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt """ Impose Poisson noise on sinogram. Higher max_intensity means less noise. """ max_intensity = 2000 expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s y = -np.log(noisy_counts / max_intensity) """ Reconstruct using default prior of SVMBIR :cite:`svmbir-2020`. """ weights = svmbir.calc_weights(y, weight_type="transmission") x_mrf = svmbir.recon( np.array(y[:, np.newaxis]), np.array(angles), weights=weights[:, np.newaxis], num_rows=N, num_cols=N, positivity=True, verbose=0, )[0] """ Convert numpy arrays to jax arrays. """ y = snp.array(y) x0 = snp.array(x_mrf) weights = snp.array(weights) """ Set problem parameters and BM3D pseudo-functional. """ ρ = 10 # ADMM penalty parameter σ = density * 0.26 # denoiser sigma g0 = σ * ρ * BM3D() """ Set up problem using `SVMBIRSquaredL2Loss` and `NonNegativeIndicator`. """ f_l2loss = SVMBIRSquaredL2Loss( y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0} ) g1 = NonNegativeIndicator() solver_l2loss = ADMM( f=None, g_list=[f_l2loss, g0, g1], C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape), Identity(x_mrf.shape)], rho_list=[ρ, ρ, ρ], x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True, "period": 5}, ) """ Run the ADMM solver. 
""" print(f"Solving on {device_info()}\n") x_l2loss = solver_l2loss.solve() hist_l2loss = solver_l2loss.itstat_object.history(transpose=True) """ Set up problem using `SVMBIRExtendedLoss`, without need for `NonNegativeIndicator`. """ f_extloss = SVMBIRExtendedLoss( y=y, A=A, W=Diagonal(weights), scale=0.5, positivity=True, prox_kwargs={"maxiter": 5, "ctol": 0.0}, ) solver_extloss = ADMM( f=None, g_list=[f_extloss, g0], C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)], rho_list=[ρ, ρ], x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True, "period": 5}, ) """ Run the ADMM solver. """ print() x_extloss = solver_extloss.solve() hist_extloss = solver_extloss.itstat_object.history(transpose=True) """ Show the recovered images. """ norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density) fig, ax = plt.subplots(2, 2, figsize=(15, 15)) plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0, 0], norm=norm) plot.imview( img=x_mrf, title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)", cbar=True, fig=fig, ax=ax[0, 1], norm=norm, ) plot.imview( img=x_l2loss, title=f"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)", cbar=True, fig=fig, ax=ax[1, 0], norm=norm, ) plot.imview( img=x_extloss, title=f"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)", cbar=True, fig=fig, ax=ax[1, 1], norm=norm, ) fig.show() """ Plot convergence statistics. 
""" fig, ax = plt.subplots(1, 2, figsize=(15, 5)) plot.plot( snp.array((hist_l2loss.Prml_Rsdl, hist_l2loss.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals (SquaredL2Loss + non-negativity)", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[0], ) ax[0].set_ylim([1e-1, 5e0]) ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( snp.array((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals (ExtendedLoss)", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) ax[1].set_ylim([1e-1, 5e0]) ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_svmbir_tv_multi.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" TV-Regularized CT Reconstruction (Multiple Algorithms) ====================================================== This example demonstrates the use of different optimization algorithms to solve the TV-regularized CT problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is the X-ray transform (implemented using the SVMBIR :cite:`svmbir-2020` tomographic projection), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the reconstructed image. 
""" import numpy as np import matplotlib.pyplot as plt import svmbir from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import functional, linop, metric, plot from scico.linop import Diagonal from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform from scico.optimize import PDHG, LinearizedADMM from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Generate a ground truth image. """ N = 256 # image size density = 0.025 # attenuation density of the image np.random.seed(1234) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.005], gap=2e-3, porosity=1.0), size=N - 10) x_gt = x_gt / np.max(x_gt) * density x_gt = np.pad(x_gt, 5) x_gt[x_gt < 0] = 0 """ Generate tomographic projector and sinogram. """ num_angles = int(N / 2) num_channels = N angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt """ Impose Poisson noise on sinogram. Higher max_intensity means less noise. """ max_intensity = 2000 expected_counts = max_intensity * np.exp(-sino) noisy_counts = np.random.poisson(expected_counts).astype(np.float32) noisy_counts[noisy_counts == 0] = 1 # deal with 0s y = -snp.log(noisy_counts / max_intensity) """ Reconstruct using default prior of SVMBIR :cite:`svmbir-2020`. """ weights = svmbir.calc_weights(y, weight_type="transmission") x_mrf = svmbir.recon( np.array(y[:, np.newaxis]), np.array(angles), weights=weights[:, np.newaxis], num_rows=N, num_cols=N, positivity=True, verbose=0, )[0] """ Set up problem. """ x0 = snp.array(x_mrf) weights = snp.array(weights) λ = 1e-1 # ℓ1 norm regularization parameter f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = λ * functional.L21Norm() # regularization functional # The append=0 option makes the results of horizontal and vertical finite # differences the same shape, which is required for the L21Norm. 
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) """ Solve via ADMM. """ solve_admm = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[2e1], x0=x0, maxiter=50, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 10}), itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") print("ADMM:") x_admm = solve_admm.solve() hist_admm = solve_admm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n") """ Solve via Linearized ADMM. """ solver_ladmm = LinearizedADMM( f=f, g=g, C=C, mu=3e-2, nu=2e-1, x0=x0, maxiter=50, itstat_options={"display": True, "period": 10}, ) print("Linearized ADMM:") x_ladmm = solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n") """ Solve via PDHG. """ solver_pdhg = PDHG( f=f, g=g, C=C, tau=2e-2, sigma=8e0, x0=x0, maxiter=50, itstat_options={"display": True, "period": 10}, ) print("PDHG:") x_pdhg = solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n") """ Show the recovered images. 
""" norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density) fig, ax = plt.subplots(1, 2, figsize=[10, 5]) plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm) plot.imview( img=x_mrf, title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)", cbar=True, fig=fig, ax=ax[1], norm=norm, ) fig.show() fig, ax = plt.subplots(1, 3, figsize=[15, 5]) plot.imview( img=x_admm, title=f"TV ADMM (PSNR: {metric.psnr(x_gt, x_admm):.2f} dB)", cbar=True, fig=fig, ax=ax[0], norm=norm, ) plot.imview( img=x_ladmm, title=f"TV LinADMM (PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB)", cbar=True, fig=fig, ax=ax[1], norm=norm, ) plot.imview( img=x_pdhg, title=f"TV PDHG (PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB)", cbar=True, fig=fig, ax=ax[2], norm=norm, ) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6)) plot.plot( snp.array((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T, ptyp="semilogy", title="Objective function", xlbl="Iteration", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[0], ) plot.plot( snp.array((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)).T, ptyp="semilogy", title="Primal residual", xlbl="Iteration", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[1], ) plot.plot( snp.array((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T, ptyp="semilogy", title="Dual residual", xlbl="Iteration", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[2], ) fig.show() fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6)) plot.plot( snp.array((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T, snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Objective function", xlbl="Time (s)", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[0], ) plot.plot( snp.array((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, 
hist_pdhg.Prml_Rsdl)).T, snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Primal residual", xlbl="Time (s)", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[1], ) plot.plot( snp.array((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T, snp.array((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Dual residual", xlbl="Time (s)", lgnd=("ADMM", "LinADMM", "PDHG"), fig=fig, ax=ax[2], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_symcone_tv_padmm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" TV-Regularized Cone Beam CT for Symmetric Objects ================================================= This example demonstrates a total variation (TV) regularized reconstruction for cone beam CT of a cylindrically symmetric object, by solving the problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_1 \;,$$ where $C$ is a single-view X-ray transform (with an implementation based on a projector from the AXITOM package :cite:`olufsen-2019-axitom`), $\mathbf{y}$ is the measured data, $D$ is a 2D finite difference operator, and $\mathbf{x}$ is the solution. """ import numpy as np import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.examples import create_circular_phantom from scico.linop.xray.symcone import SymConeXRayTransform from scico.optimize import ProximalADMM from scico.util import device_info """ Create a ground truth image. """ N = 256 # image size x_gt = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) """ Set up the forward operator and create a test measurement. 
""" C = SymConeXRayTransform(x_gt.shape, obj_dist=5e2 * N, det_dist=6e2 * N, num_slabs=4) y = C @ x_gt np.random.seed(12345) y = y + np.random.normal(size=y.shape).astype(np.float32) """ Compute FDK reconstruction. """ x_inv = C.fdk(y) r""" Set up problem and solver. We want to minimize the functional $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_1 \;,$$ where $C$ is the X-ray transform and $D$ is a finite difference operator. We use anisotropic TV, which gives slightly better performance than isotropic TV in this case. This problem can be expressed as $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_1 \;\; \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; \mathbf{z}_1 = D \mathbf{x} \;,$$ which can be written in the form of a standard ADMM problem $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}$$ with $$f = 0 \qquad g = g_0 + g_1$$ $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_1$$ $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ """ 𝛼 = 7e1 # improve problem conditioning by balancing C and D components of A λ = 8e0 # ℓ1 norm regularization parameter ρ = 1e-2 # ADMM penalty parameter maxiter = 250 # number of ADMM iterations f = functional.ZeroFunctional() g0 = loss.SquaredL2Loss(y=y) g1 = (λ / 𝛼) * functional.L1Norm() g = functional.SeparableFunctional((g0, g1)) D = linop.FiniteDifference(input_shape=x_gt.shape, append=0) A = linop.VerticalStack((C, 𝛼 * D)) mu, nu = ProximalADMM.estimate_parameters(A, maxiter=20) solver = ProximalADMM( f=f, g=g, A=A, B=None, rho=ρ, mu=mu, nu=nu, x0=snp.clip(x_inv, 
0.0, 1.0), maxiter=maxiter, itstat_options={"display": True, "period": 20}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x_tv = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Show results. """ norm = plot.matplotlib.colors.Normalize(vmin=-0.1, vmax=1.2) fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(12, 12)) plot.imview(x_gt, title="Ground Truth", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0], norm=norm) plot.imview(y, title="Measurement", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1]) plot.imview( x_inv, title="FDK: %.2f (dB)" % metric.psnr(x_gt, x_inv), cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0], norm=norm, ) plot.imview( x_tv, title="TV-Regularized Inversion: %.2f (dB)" % metric.psnr(x_gt, x_tv), cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1], norm=norm, ) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r""" TV-Regularized Sparse-View CT Reconstruction (Integrated Projector) =================================================================== This example demonstrates solution of a sparse-view CT reconstruction problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the reconstructed image. This example uses the CT projector integrated into scico, while the companion [example script](ct_astra_tv_admm.rst) uses the projector provided by the astra package. """ import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import functional, linop, loss, metric, plot from scico.linop.xray import XRayTransform2D from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ N = 512 # phantom size np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_count = int(N * 1.05 / np.sqrt(2.0)) dx = 1.0 / np.sqrt(2) A = XRayTransform2D( (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx ) # CT projection operator y = A @ x_gt # sinogram """ Set up problem functional and ADMM solver object. 
""" λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance cg_maxiter = 25 # maximum CG iterations per ADMM iteration # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() f = loss.SquaredL2Loss(y=y, A=A) x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 5}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") solver.solve() hist = solver.itstat_object.history(transpose=True) x_reconstruction = snp.clip(solver.x, 0, 1.0) """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0]) plot.imview( x0, title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)), cbar=None, fig=fig, ax=ax[1], ) plot.imview( x_reconstruction, title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)), fig=fig, ax=ax[2], ) divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() """ Plot convergence statistics. 
""" fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/ct_unet_train_foam2.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ CT Training and Reconstructions with UNet ========================================= This example demonstrates the training and application of UNet to denoise previously filtered back projections (FBP) for CT reconstruction inspired by :cite:`jin-2017-unet`. """ # isort: off import os from time import time import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 # Set an arbitrary processor count (only applies if GPU is not available). os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" import jax try: from jax.extend.backend import get_backend # introduced in jax 0.4.33 except ImportError: from jax.lib.xla_bridge import get_backend import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_ct_data platform = get_backend().platform print("Platform: ", platform) """ Read data from cache or generate if not available. 
""" N = 256 # phantom size train_nimg = 498 # number of training images test_nimg = 32 # number of testing images nimg = train_nimg + test_nimg n_projection = 45 # CT views trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True) """ Build training and testing structures. Inputs are the filter back-projected sinograms and outpus are the original generated foams. Keep training and testing partitions. """ train_ds = {"image": trdt["fbp"], "label": trdt["img"]} test_ds = {"image": ttdt["fbp"], "label": ttdt["img"]} """ Define configuration dictionary for model and training loop. Parameters have been selected for demonstration purposes and relatively short training. The model depth controls the levels of pooling in the U-Net model. The block depth controls the number of layers at each level of depth. The number of filters controls the number of filters at the input and output levels and doubles (halves) at each pooling (unpooling) operation. Better performance may be obtained by increasing depth, block depth, number of filters or training epochs, but may require longer training times. """ # model configuration model_conf = { "depth": 2, "num_filters": 64, "block_depth": 2, } # training configuration train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "SGD", "momentum": 0.9, "batch_size": 16, "num_epochs": 200, "base_learning_rate": 1e-2, "warmup_epochs": 0, "log_every_steps": 1000, "log": True, "checkpointing": True, } """ Construct UNet model. """ channels = train_ds["image"].shape[-1] model = sflax.UNet( depth=model_conf["depth"], channels=channels, num_filters=model_conf["num_filters"], block_depth=model_conf["block_depth"], ) """ Run training loop. 
""" workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "unet_ct_out") train_conf["workdir"] = workdir print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") print(f"JAX local devices: {jax.local_devices()}\n") trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) modvar, stats_object = trainer.train() """ Evaluate on testing data. """ del train_ds["image"] del train_ds["label"] fmap = sflax.FlaxMap(model, modvar) del model, modvar maxn = test_nimg // 2 start_time = time() output = fmap(test_ds["image"][:maxn]) time_eval = time() - start_time output = jax.numpy.clip(output, a_min=0, a_max=1.0) """ Evaluate trained model in terms of reconstruction time and data fidelity. """ snr_eval = metric.snr(test_ds["label"][:maxn], output) psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}" f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'UNet testing':15s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}" ) """ Plot comparison. 
""" key = jax.random.key(123) indx = jax.random.randint(key, shape=(1,), minval=0, maxval=maxn)[0] fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0]) plot.imview( test_ds["image"][indx, ..., 0], title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f" % ( metric.snr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]), metric.mae(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]), ), cbar=None, fig=fig, ax=ax[1], ) plot.imview( output[indx, ..., 0], title="UNet Reconstruction\nSNR: %.2f (dB), MAE: %.3f" % ( metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), metric.mae(test_ds["label"][indx, ..., 0], output[indx, ..., 0]), ), fig=fig, ax=ax[2], ) divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() """ Plot convergence statistics. Statistics are generated only if a training cycle was done (i.e. if not reading final epoch results from checkpoint). """ if stats_object is not None and len(stats_object.iterations) > 0: hist = stats_object.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( np.array((hist.Train_Loss, hist.Eval_Loss)).T, x=hist.Epoch, ptyp="semilogy", title="Loss function", xlbl="Epoch", ylbl="Loss value", lgnd=("Train", "Test"), fig=fig, ax=ax[0], ) plot.plot( np.array((hist.Train_SNR, hist.Eval_SNR)).T, x=hist.Epoch, title="Metric", xlbl="Epoch", ylbl="SNR (dB)", lgnd=("Train", "Test"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_circ_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. 
Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Circulant Blur Image Deconvolution with TV Regularization ========================================================= This example demonstrates the solution of an image deconvolution problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ where $A$ is a circular convolution operator, $\mathbf{y}$ is the blurred image, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the deconvolved image. """ from xdesign import SiemensStar, discrete_phantom import scico.numpy as snp import scico.random from scico import functional, linop, loss, metric, plot from scico.optimize.admm import ADMM, CircularConvolveSolver from scico.util import device_info """ Create a ground truth image. """ phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) """ Set up the forward operator and create a test signal consisting of a blurred signal with additive Gaussian noise. """ n = 5 # convolution kernel size σ = 20.0 / 255 # noise level psf = snp.ones((n, n)) / (n * n) A = linop.CircularConvolve(h=psf, input_shape=x_gt.shape) Ax = A(x_gt) # blurred image noise, key = scico.random.randn(Ax.shape, seed=0) y = Ax + σ * noise """ Set up an ADMM solver object. """ λ = 2e-2 # ℓ2,1 norm regularization parameter ρ = 5e-1 # ADMM penalty parameter maxiter = 50 # number of ADMM iterations f = loss.SquaredL2Loss(y=y, A=A) # Penalty parameters must be accounted for in the gi functions, not as # additional inputs. 
g = λ * functional.L21Norm() # regularization functionals gi C = linop.FiniteDifference(x_gt.shape, circular=True) solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=A.adj(y), maxiter=maxiter, subproblem_solver=CircularConvolveSolver(), itstat_options={"display": True, "period": 10}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, y), fig=fig, ax=ax[1]) plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2]) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_datagen_bsds.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Blurred Data Generation (Natural Images) for NN Training ======================================================== This example demonstrates how to generate blurred image data for training neural network models for deconvolution (deblurring). 
The original images are part of the [BSDS500 dataset](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/) provided by the Berkeley Segmentation Dataset and Benchmark project. """

import numpy as np

from jax import vmap

from scico import plot
from scico.flax.examples import PaddedCircularConvolve, load_image_data

""" Define blur operator. """
output_size = 256  # patch size
channels = 1  # gray scale problem
blur_shape = (9, 9)  # shape of blur kernel
blur_sigma = 5  # Gaussian blur kernel parameter
opBlur = PaddedCircularConvolve(output_size, channels, blur_shape, blur_sigma)

opBlur_vmap = vmap(opBlur)  # for batch processing

""" Read data from cache or generate if not available. """
train_nimg = 400  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg
gray = True  # use gray scale images
data_mode = "dcnv"  # deconvolution problem
noise_level = 0.005  # standard deviation of noise
noise_range = False  # use fixed noise level
stride = 100  # stride to sample multiple patches from each image
augment = True  # augment data via rotations and flips

# The batched blur operator is supplied as the data transformation (transf).
train_ds, test_ds = load_image_data(
    train_nimg,
    test_nimg,
    output_size,
    gray,
    data_mode,
    verbose=True,
    noise_level=noise_level,
    noise_range=noise_range,
    transf=opBlur_vmap,
    stride=stride,
    augment=augment,
)

""" Plot randomly selected sample.
""" indx_tr = np.random.randint(0, train_nimg) indx_te = np.random.randint(0, test_nimg) fig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7)) plot.imview( train_ds["label"][indx_tr, ..., 0], title="Ground truth - Training Sample", fig=fig, ax=axes[0, 0], ) plot.imview( train_ds["image"][indx_tr, ..., 0], title="Blurred Image - Training Sample", fig=fig, ax=axes[0, 1], ) plot.imview( test_ds["label"][indx_te, ..., 0], title="Ground truth - Testing Sample", fig=fig, ax=axes[1, 0], ) plot.imview( test_ds["image"][indx_te, ..., 0], title="Blurred Image - Testing Sample", fig=fig, ax=axes[1, 1], ) fig.suptitle(r"Training and Testing samples") fig.tight_layout() fig.colorbar( axes[0, 1].get_images()[0], ax=axes, shrink=0.5, pad=0.05, ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_datagen_foam1.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Blurred Data Generation (Foams) for NN Training =============================================== This example demonstrates how to generate blurred image data for training neural network models for deconvolution (deblurring), using foam phantoms generated by `xdesign`. """ # isort: off import numpy as np import logging import ray ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 from scico import plot from scico.flax.examples import load_blur_data """ Read data from cache or generate if not available. 
""" n = 3 # convolution kernel size σ = 20.0 / 255 # noise level psf = np.ones((n, n)) / (n * n) # kernel train_nimg = 416 # number of training images test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg output_size = 256 # image size train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, psf, σ, verbose=True, ) """ Plot randomly selected sample. """ indx_tr = np.random.randint(0, train_nimg) indx_te = np.random.randint(0, test_nimg) fig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7)) plot.imview( train_ds["label"][indx_tr, ..., 0], title="Ground truth - Training Sample", fig=fig, ax=axes[0, 0], ) plot.imview( train_ds["image"][indx_tr, ..., 0], title="Blurred Image - Training Sample", fig=fig, ax=axes[0, 1], ) plot.imview( test_ds["label"][indx_te, ..., 0], title="Ground truth - Testing Sample", fig=fig, ax=axes[1, 0], ) plot.imview( test_ds["image"][indx_te, ..., 0], title="Blurred Image - Testing Sample", fig=fig, ax=axes[1, 1], ) fig.suptitle(r"Training and Testing samples") fig.tight_layout() fig.colorbar( axes[0, 1].get_images()[0], ax=axes, shrink=0.5, pad=0.05, label="Arbitrary Units", ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_microscopy_allchn_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Deconvolution Microscopy (All Channels) ======================================= This example partially replicates a [GlobalBioIm example](https://biomedical-imaging-group.github.io/GlobalBioIm/examples.html) using the [microscopy data](http://bigwww.epfl.ch/deconvolution/bio/) provided by the EPFL Biomedical Imaging Group. 
The deconvolution problem is solved using class [admm.ADMM](../_autosummary/scico.optimize.rst#scico.optimize.ADMM) to solve an image deconvolution problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| M (\mathbf{y} - A \mathbf{x}) \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} + \iota_{\mathrm{NN}}(\mathbf{x}) \;,$$ where $M$ is a mask operator, $A$ is circular convolution, $\mathbf{y}$ is the blurred image, $C$ is a convolutional gradient operator, $\iota_{\mathrm{NN}}$ is the indicator function of the non-negativity constraint, and $\mathbf{x}$ is the deconvolved image. """

# isort: off
# (ray must be initialized before the scico imports below pull in jax.)
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087

import scico.numpy as snp
from scico import functional, linop, loss, plot
from scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices
from scico.optimize.admm import ADMM, CircularConvolveSolver

""" Get and preprocess data. The data is downsampled to limit the memory requirements and run time of the example. Reducing the downsampling rate will make the example slower and more memory-intensive. To run this example on a GPU it may be necessary to set environment variables `XLA_PYTHON_CLIENT_ALLOCATOR=platform` and `XLA_PYTHON_CLIENT_PREALLOCATE=false`. If your GPU does not have enough memory, try setting the environment variable `JAX_PLATFORM_NAME=cpu` to run on CPU.
""" downsampling_rate = 2 y_list = [] y_pad_list = [] psf_list = [] for channel in range(3): y, psf = epfl_deconv_data(channel, verbose=True) # get data y = downsample_volume(y, downsampling_rate) # downsample psf = downsample_volume(psf, downsampling_rate) y -= y.min() # normalize y y /= y.max() psf /= psf.sum() # normalize psf if channel == 0: padding = [[0, p] for p in snp.array(psf.shape) - 1] mask = snp.pad(snp.ones_like(y), padding) y_pad = snp.pad(y, padding) # zero-padded version of y y_list.append(y) y_pad_list.append(y_pad) psf_list.append(psf) y = snp.stack(y_list, axis=-1) yshape = y.shape del y_list """ Define problem and algorithm parameters. """ λ = 2e-6 # ℓ1 norm regularization parameter ρ0 = 1e-3 # ADMM penalty parameter for first auxiliary variable ρ1 = 1e-3 # ADMM penalty parameter for second auxiliary variable ρ2 = 1e-3 # ADMM penalty parameter for third auxiliary variable maxiter = 100 # number of ADMM iterations """ Determine available computing resources, and put large arrays in ray object store. """ ngpu = 0 ar = ray.available_resources() ncpu = max(int(ar["CPU"]) // 3, 1) if "GPU" in ar: ngpu = int(ar["GPU"]) // 3 print(f"Running on {ncpu} CPUs and {ngpu} GPUs per process") y_pad_list = ray.put(y_pad_list) psf_list = ray.put(psf_list) mask_store = ray.put(mask) """ Define ray remote function for parallel solves. 
""" @ray.remote(num_cpus=ncpu, num_gpus=ngpu) def deconvolve_channel(channel): """Deconvolve a single channel.""" y_pad = ray.get(y_pad_list)[channel] psf = ray.get(psf_list)[channel] mask = ray.get(mask_store) M = linop.Diagonal(mask) C0 = linop.CircularConvolve( h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5 # forward operator ) C1 = linop.FiniteDifference(input_shape=mask.shape, circular=True) # gradient operator C2 = linop.Identity(mask.shape) # identity operator g0 = loss.SquaredL2Loss(y=y_pad, A=M) # loss function (forward model) g1 = λ * functional.L21Norm() # TV penalty (when applied to gradient) g2 = functional.NonNegativeIndicator() # non-negativity constraint if channel == 0: print("Displaying solver status for channel 0") display = True else: display = False solver = ADMM( f=None, g_list=[g0, g1, g2], C_list=[C0, C1, C2], rho_list=[ρ0, ρ1, ρ2], maxiter=maxiter, itstat_options={"display": display, "period": 10, "overwrite": False}, x0=y_pad, subproblem_solver=CircularConvolveSolver(), ) x_pad = solver.solve() x = x_pad[: yshape[0], : yshape[1], : yshape[2]] return (x, solver.itstat_object.history(transpose=True)) """ Solve problems for all three channels in parallel and extract results. """ ray_return = ray.get([deconvolve_channel.remote(channel) for channel in range(3)]) x = snp.stack([t[0] for t in ray_return], axis=-1) solve_stats = [t[1] for t in ray_return] """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7)) plot.imview(tile_volume_slices(y), title="Blurred measurements", fig=fig, ax=ax[0]) plot.imview(tile_volume_slices(x), title="Deconvolved image", fig=fig, ax=ax[1]) fig.show() """ Plot convergence statistics. 
""" fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(18, 5)) plot.plot( np.stack([s.Objective for s in solve_stats]).T, title="Objective function", xlbl="Iteration", ylbl="Functional value", lgnd=("CY3", "DAPI", "FITC"), fig=fig, ax=ax[0], ) plot.plot( np.stack([s.Prml_Rsdl for s in solve_stats]).T, ptyp="semilogy", title="Primal Residual", xlbl="Iteration", lgnd=("CY3", "DAPI", "FITC"), fig=fig, ax=ax[1], ) plot.plot( np.stack([s.Dual_Rsdl for s in solve_stats]).T, ptyp="semilogy", title="Dual Residual", xlbl="Iteration", lgnd=("CY3", "DAPI", "FITC"), fig=fig, ax=ax[2], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_microscopy_tv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Deconvolution Microscopy (Single Channel) ========================================= This example partially replicates a [GlobalBioIm example](https://biomedical-imaging-group.github.io/GlobalBioIm/examples.html) using the [microscopy data](http://bigwww.epfl.ch/deconvolution/bio/) provided by the EPFL Biomedical Imaging Group. The deconvolution problem is solved using class [admm.ADMM](../_autosummary/scico.optimize.rst#scico.optimize.ADMM) to solve an image deconvolution problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| M (\mathbf{y} - A \mathbf{x}) \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} + \iota_{\mathrm{NN}}(\mathbf{x}) \;,$$ where $M$ is a mask operator, $A$ is circular convolution, $\mathbf{y}$ is the blurred image, $C$ is a convolutional gradient operator, $\iota_{\mathrm{NN}}$ is the indicator function of the non-negativity constraint, and $\mathbf{x}$ is the deconvolved image. 
""" import scico.numpy as snp from scico import functional, linop, loss, plot, util from scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices from scico.optimize.admm import ADMM, CircularConvolveSolver """ Get and preprocess data. The data is downsampled to limit the memory requirements and run time of the example. Reducing the downsampling rate will make the example slower and more memory-intensive. To run this example on a GPU it may be necessary to set environment variables `XLA_PYTHON_CLIENT_ALLOCATOR=platform` and `XLA_PYTHON_CLIENT_PREALLOCATE=false`. If your GPU does not have enough memory, try setting the environment variable `JAX_PLATFORM_NAME=cpu` to run on CPU. """ channel = 0 downsampling_rate = 2 y, psf = epfl_deconv_data(channel, verbose=True) y = downsample_volume(y, downsampling_rate) psf = downsample_volume(psf, downsampling_rate) y -= y.min() y /= y.max() psf /= psf.sum() """ Pad data and create mask. """ padding = [[0, p] for p in snp.array(psf.shape) - 1] y_pad = snp.pad(y, padding) mask = snp.pad(snp.ones_like(y), padding) """ Define problem and algorithm parameters. """ λ = 2e-6 # ℓ1 norm regularization parameter ρ0 = 1e-3 # ADMM penalty parameter for first auxiliary variable ρ1 = 1e-3 # ADMM penalty parameter for second auxiliary variable ρ2 = 1e-3 # ADMM penalty parameter for third auxiliary variable maxiter = 100 # number of ADMM iterations """ Create operators. """ M = linop.Diagonal(mask) C0 = linop.CircularConvolve(h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5) C1 = linop.FiniteDifference(input_shape=mask.shape, circular=True) C2 = linop.Identity(mask.shape) """ Create functionals. """ g0 = loss.SquaredL2Loss(y=y_pad, A=M) # loss function (forward model) g1 = λ * functional.L21Norm() # TV penalty (when applied to gradient) g2 = functional.NonNegativeIndicator() # non-negativity constraint """ Set up ADMM solver object and solve problem. 
""" solver = ADMM( f=None, g_list=[g0, g1, g2], C_list=[C0, C1, C2], rho_list=[ρ0, ρ1, ρ2], maxiter=maxiter, itstat_options={"display": True, "period": 10}, x0=y_pad, subproblem_solver=CircularConvolveSolver(), ) print("Solving on %s\n" % util.device_info()) solver.solve() solve_stats = solver.itstat_object.history(transpose=True) x_pad = solver.x x = x_pad[: y.shape[0], : y.shape[1], : y.shape[2]] """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7)) plot.imview(tile_volume_slices(y), title="Blurred measurements", fig=fig, ax=ax[0]) plot.imview(tile_volume_slices(x), title="Deconvolved image", fig=fig, ax=ax[1]) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( solve_stats.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((solve_stats.Prml_Rsdl, solve_stats.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_modl_train_foam1.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Deconvolution Training and Reconstructions with MoDL ==================================================== This example demonstrates the training and application of a model-based deep learning (MoDL) architecture described in :cite:`aggarwal-2019-modl` for a deconvolution (deblurring) problem. The source images are foam phantoms generated with xdesign. 
A class [scico.flax.MoDLNet](../_autosummary/scico.flax.rst#scico.flax.MoDLNet) implements the MoDL architecture, which solves the optimization problem $$\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + \lambda \, \| \mathbf{x} - \mathrm{D}_w(\mathbf{x})\|_2^2 \;,$$ where $A$ is a circular convolution, $\mathbf{y}$ is a set of blurred images, $\mathrm{D}_w$ is the regularization (a denoiser), and $\mathbf{x}$ is the set of deblurred images. The MoDL abstracts the iterative solution by an unrolled network where each iteration corresponds to a different stage in the MoDL network and updates the prediction by solving $$\mathbf{x}^{k+1} = (A^T A + \lambda \, I)^{-1} (A^T \mathbf{y} + \lambda \, \mathbf{z}^k) \;,$$ via conjugate gradient. In the expression, $k$ is the index of the stage (iteration), $\mathbf{z}^k = \mathrm{ResNet}(\mathbf{x}^{k})$ is the regularization (a denoiser implemented as a residual convolutional neural network), $\mathbf{x}^k$ is the output of the previous stage, $\lambda > 0$ is a learned regularization parameter, and $I$ is the identity operator. The output of the final stage is the set of deblurred images. """

# isort: off
# (Import order is significant: ray is initialized and XLA_FLAGS is set
# before jax is imported.)
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # set before importing jax

import jax

try:
    from jax.extend.backend import get_backend  # introduced in jax 0.4.33
except ImportError:
    from jax.lib.xla_bridge import get_backend  # fallback for older jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

platform = get_backend().platform
print("Platform: ", platform)

""" Define blur operator. """
output_size = 256  # image size
n = 3  # convolution kernel size
σ = 20.0 / 255  # noise level
psf = np.ones((n, n)) / (n * n)  # blur kernel (uniform, entries sum to 1)

ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)

# NOTE(review): opBlur_vmap is not passed to load_blur_data below and is not
# otherwise used in this script; data generation receives psf and σ directly.
opBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation

""" Read data from cache or generate if not available. """
train_nimg = 416  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg

train_ds, test_ds = load_blur_data(
    train_nimg,
    test_nimg,
    output_size,
    psf,
    σ,
    verbose=True,
)

""" Define configuration dictionary for model and training loop. Parameters have been selected for demonstration purposes and relatively short training. The model depth is akin to the number of unrolled iterations in the MoDL model. The block depth controls the number of layers at each unrolled iteration. The number of filters is uniform throughout the iterations. The iterations used for the conjugate gradient (CG) solver can also be specified. Better performance may be obtained by increasing depth, block depth, number of filters, CG iterations, or training epochs, but may require longer training times.
""" # model configuration model_conf = { "depth": 2, "num_filters": 64, "block_depth": 4, "cg_iter": 4, } # training configuration train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "SGD", "momentum": 0.9, "batch_size": 16, "num_epochs": 25, "base_learning_rate": 1e-2, "warmup_epochs": 0, "log_every_steps": 100, "log": True, "checkpointing": True, } """ Construct functionality for ensuring that the learned regularization parameter is always positive. """ lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model lmbdapos = partial( clip_positive, # apply this function traversal=lmbdatrav, # to lmbda parameters in model minval=5e-4, ) """ Print configuration of distributed run. """ print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") print(f"JAX local devices: {jax.local_devices()}\n") """ Check for iterated trained model. If not found, construct MoDLNet model, using only one iteration (depth) in model and few CG iterations for faster intialization. Run first stage (initialization) training loop followed by a second stage (depth iterations) training loop. 
""" channels = train_ds["image"].shape[-1] workdir2 = os.path.join( os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out", "iterated" ) stats_object_ini = None stats_object = None checkpoint_files = [] for dirpath, dirnames, filenames in os.walk(workdir2): checkpoint_files = [fn for fn in filenames] if len(checkpoint_files) > 0: model = sflax.MoDLNet( operator=opBlur, depth=model_conf["depth"], channels=channels, num_filters=model_conf["num_filters"], block_depth=model_conf["block_depth"], cg_iter=model_conf["cg_iter"], ) train_conf["workdir"] = workdir2 train_conf["post_lst"] = [lmbdapos] # Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) start_time = time() modvar, stats_object = trainer.train() time_train = time() - start_time time_init = 0.0 epochs_init = 0 else: # One iteration (depth) in model and few CG iterations model = sflax.MoDLNet( operator=opBlur, depth=1, channels=channels, num_filters=model_conf["num_filters"], block_depth=model_conf["block_depth"], cg_iter=model_conf["cg_iter"], ) # First stage: initialization training loop. workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out") train_conf["workdir"] = workdir1 train_conf["post_lst"] = [lmbdapos] # Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) start_time = time() modvar, stats_object_ini = trainer.train() time_init = time() - start_time epochs_init = train_conf["num_epochs"] print( f"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}" f"{'time[s]:':21s}{time_init:>7.2f}" ) # Second stage: depth iterations training loop. 
model.depth = model_conf["depth"] train_conf["workdir"] = workdir2 # Construct training object, include current model parameters trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, variables0=modvar, ) start_time = time() modvar, stats_object = trainer.train() time_train = time() - start_time """ Evaluate on testing data. """ del train_ds["image"] del train_ds["label"] fmap = sflax.FlaxMap(model, modvar) del model, modvar maxn = test_nimg // 4 start_time = time() output = fmap(test_ds["image"][:maxn]) time_eval = time() - start_time output = np.clip(output, a_min=0, a_max=1.0) """ Evaluate trained model in terms of reconstruction time and data fidelity. """ total_epochs = epochs_init + train_conf["num_epochs"] total_time_train = time_init + time_train snr_eval = metric.snr(test_ds["label"][:maxn], output) psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}" f"{'time[s]:':10s}{total_time_train:>7.2f}" ) print( f"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}" f"{'':3s}{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}" ) """ Plot comparison. 
"""
# Fixed NumPy seed so that the same test image index is drawn on every run.
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

# Panels: ground truth, blurred input, and MoDLNet output. Titles report SNR
# and PSNR with respect to the ground-truth label of the selected image.
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Blurred: \nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
    ),
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="MoDLNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
# Colorbar for the reconstruction panel, drawn in a new axis appended to the
# right of that panel.
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()


"""
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
# Loss and SNR histories of the final (full-depth) training stage.
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.array((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.array((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()

# Stats for initialization loop (stats_object_ini is only assigned when the
# depth-1 initialization stage was run).
if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
    hist = stats_object_ini.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.array((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function - Initialization",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.array((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric - Initialization",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()

# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_odp_train_foam1.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Deconvolution Training and Reconstructions with ODP
===================================================

This example demonstrates the training and application of the unrolled
optimization with deep priors (ODP) with proximal map architecture
described in :cite:`diamond-2018-odp` for a deconvolution (deblurring)
problem.

The source images are foam phantoms generated with xdesign.
A class [scico.flax.ODPNet](../_autosummary/scico.flax.rst#scico.flax.ODPNet)
implements the ODP architecture, which solves the optimization problem

  $$\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2
  + r(\mathbf{x}) \;,$$

where $A$ is a circular convolution, $\mathbf{y}$ is a set of blurred
images, $r$ is a regularizer and $\mathbf{x}$ is the set of deblurred
images. The ODP, proximal map architecture, abstracts the iterative
solution by an unrolled network where each iteration corresponds to a
different stage in the ODP network and updates the prediction by
solving

  $$\mathbf{x}^{k+1} = \mathrm{argmin}_{\mathbf{x}} \; \alpha_k \| A
  \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} -
  \mathbf{x}^k - \mathbf{x}^{k+1/2} \|_2^2 \;,$$

which for the deconvolution problem corresponds to

  $$\mathbf{x}^{k+1} = \mathcal{F}^{-1} \mathrm{diag} (\alpha_k |
  \mathcal{K}|^2 + 1 )^{-1} \mathcal{F} \, (\alpha_k K^T * \mathbf{y}
  + \mathbf{x}^k + \mathbf{x}^{k+1/2}) \;,$$

where $k$ is the index of the stage (iteration), $\mathbf{x}^k +
\mathbf{x}^{k+1/2} = \mathrm{ResNet}(\mathbf{x}^{k})$ is the
regularization (implemented as a residual convolutional neural
network), $\mathbf{x}^k$ is the output of the previous stage,
$\alpha_k > 0$ is a learned stage-wise parameter weighting the
contribution of the fidelity term, $\mathcal{F}$ is the DFT, $K$ is
the blur kernel, and $\mathcal{K}$ is the DFT of $K$. The output of
the final stage is the set of deblurred images.
"""

# isort: off
# The import order below is deliberate and must not be re-sorted: ray is
# initialized before jax is imported (ray-project/ray#44087), and XLA_FLAGS
# is set in the environment before the jax import.
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# Use the newer location of get_backend when available, falling back to the
# older jax.lib.xla_bridge location for jax versions before 0.4.33.
try:
    from jax.extend.backend import get_backend  # introduced in jax 0.4.33
except ImportError:
    from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

# Report the platform of the XLA backend that jax is using.
platform = get_backend().platform
print("Platform: ", platform)


"""
Define blur operator.
"""
output_size = 256  # patch size

n = 3  # convolution kernel size
σ = 20.0 / 255  # noise level
psf = np.ones((n, n)) / (n * n)  # blur kernel (uniform n x n average)

ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)

# NOTE(review): opBlur_vmap is not referenced elsewhere in this script.
opBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation


"""
Read data from cache or generate if not available.
"""
train_nimg = 416  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg  # total image count (not used elsewhere in this script)

train_ds, test_ds = load_blur_data(
    train_nimg,
    test_nimg,
    output_size,
    psf,
    σ,
    verbose=True,
)


"""
Define configuration dictionary for model and training loop.

Parameters have been selected for demonstration purposes and relatively
short training. The model depth is akin to the number of unrolled
iterations in the ODP model. The block depth controls the number of
layers at each unrolled iteration. The number of filters is uniform
throughout the iterations. Better performance may be obtained by
increasing depth, block depth, number of filters or training epochs,
but may require longer training times.
"""
# model configuration
model_conf = {
    "depth": 2,
    "num_filters": 64,
    "block_depth": 3,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 0,
    "opt_type": "SGD",
    "momentum": 0.9,
    "batch_size": 16,
    "num_epochs": 50,
    "base_learning_rate": 1e-2,
    "warmup_epochs": 0,
    "log_every_steps": 100,
    "log": True,
    "checkpointing": True,
}


"""
Construct ODPNet model.
"""
# Number of image channels (last axis of the training images).
channels = train_ds["image"].shape[-1]
model = sflax.ODPNet(
    operator=opBlur,
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
    block_depth=model_conf["block_depth"],
    odp_block=sflax.inverse.ODPProxDcnvBlock,
)


"""
Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
# Post-processing step, registered below through train_conf["post_lst"],
# that applies clip_positive (with minval=1e-3) to the "alpha" parameters.
alphatrav = construct_traversal("alpha")  # select alpha parameters in model
alphapos = partial(
    clip_positive,  # apply this function
    traversal=alphatrav,  # to alpha parameters in model
    minval=1e-3,
)


"""
Run training loop.
"""
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")

# Working directory for training output (checkpointing is enabled in train_conf).
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapos]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
    train_conf,
    model,
    train_ds,
    test_ds,
)
modvar, stats_object = trainer.train()


"""
Evaluate on testing data.
"""
# Release the training arrays before inference.
del train_ds["image"]
del train_ds["label"]

# Wrap the trained model and its variables for inference.
fmap = sflax.FlaxMap(model, modvar)
del model, modvar

# Only the first quarter of the test set is evaluated.
maxn = test_nimg // 4
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = np.clip(output, a_min=0, a_max=1.0)


"""
Evaluate trained model in terms of reconstruction time and data fidelity.
"""
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
# The reported training time is the one recorded by the trainer object.
print(
    f"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
    f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)


"""
Plot comparison.
"""
# Fixed NumPy seed so that the same test image index is drawn on every run.
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

# Panels: ground truth, blurred input, and ODPNet output. Titles report SNR
# and PSNR with respect to the ground-truth label of the selected image.
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Blurred: \nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
    ),
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="ODPNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
# Colorbar for the reconstruction panel, drawn in a new axis appended to the
# right of that panel.
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()


"""
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
# Loss and SNR histories of the training run, if any iterations were recorded.
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.array((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.array((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()

# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_ppp_bm3d_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
PPP (with BM3D) Image Deconvolution (ADMM Solver)
=================================================

This example demonstrates the solution of an image deconvolution problem
using the ADMM Plug-and-Play Priors (PPP) algorithm
:cite:`venkatakrishnan-2013-plugandplay2`, with the BM3D
:cite:`dabov-2008-image` denoiser.
"""

import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image.
"""
# Fixed NumPy seed; presumably the xdesign phantom draws from NumPy's global
# RNG, making the phantom reproducible — confirm.
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array


"""
Set up forward operator and test signal consisting of blurred signal with
additive Gaussian noise.
"""
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)  # uniform n x n averaging kernel
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)  # scico.random returns (sample, key)
y = Ax + σ * noise


"""
Set up ADMM solver.
"""
f = loss.SquaredL2Loss(y=y, A=A)
C = linop.Identity(x_gt.shape)

λ = 20.0 / 255  # BM3D regularization strength
g = λ * functional.BM3D()

ρ = 1.0  # ADMM penalty parameter
maxiter = 10  # number of ADMM iterations

# The solver is initialized with A^T y, and the linear (data fidelity)
# subproblem is solved by CG with the given tolerance and iteration cap.
solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=A.T @ y,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
# y is cropped by n // 2 samples on each edge (presumably so that yc has the
# shape of x_gt, the Convolve output being larger than its input — confirm)
# and clipped to [0, 1] for the PSNR in the title; the full y is displayed.
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])
fig.show()


"""
Plot convergence statistics.
"""
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)


# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_ppp_bm3d_apgm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""
PPP (with BM3D) Image Deconvolution (APGM Solver)
=================================================

This example demonstrates the solution of an image deconvolution problem
using the APGM Plug-and-Play Priors (PPP) algorithm
:cite:`kamilov-2017-plugandplay`, with the BM3D :cite:`dabov-2008-image`
denoiser.
"""

import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.optimize.pgm import AcceleratedPGM
from scico.util import device_info

"""
Create a ground truth image.
"""
# Fixed NumPy seed; presumably the xdesign phantom draws from NumPy's global
# RNG, making the phantom reproducible — confirm.
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array


"""
Set up forward operator and test signal consisting of blurred signal with
additive Gaussian noise.
"""
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)  # uniform n x n averaging kernel
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)  # scico.random returns (sample, key)
y = Ax + σ * noise


"""
Set up PGM solver.
"""
f = loss.SquaredL2Loss(y=y, A=A)

L0 = 15  # APGM inverse step size parameter
λ = L0 * 2.0 / 255  # BM3D regularization strength (scaled by L0)
g = λ * functional.BM3D()

maxiter = 50  # number of APGM iterations

# Initialized with A^T y; iteration statistics are displayed with period 10.
solver = AcceleratedPGM(
    f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={"display": True, "period": 10}
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
# y is cropped by n // 2 samples on each edge (presumably so that yc has the
# shape of x_gt — confirm) and clipped to [0, 1] for the PSNR in the title;
# the full y is displayed.
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])
fig.show()


"""
Plot convergence statistics.
"""
plot.plot(hist.Residual, ptyp="semilogy", title="PGM Residual", xlbl="Iteration", ylbl="Residual")


# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_ppp_bm4d_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
PPP (with BM4D) Volume Deconvolution
====================================

This example demonstrates the solution of a 3D image deconvolution problem
(involving recovering a 3D volume that has been convolved with a 3D kernel
and corrupted by noise) using the ADMM Plug-and-Play Priors (PPP) algorithm
:cite:`venkatakrishnan-2013-plugandplay2`, with the BM4D
:cite:`maggioni-2012-nonlocal` denoiser.
"""

import numpy as np

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.examples import create_3d_foam_phantom, downsample_volume, tile_volume_slices
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image.
""" np.random.seed(1234) N = 128 # phantom size Nx, Ny, Nz = N, N, N // 4 upsamp = 2 x_gt_hires = create_3d_foam_phantom((upsamp * Nz, upsamp * Ny, upsamp * Nx), N_sphere=100) x_gt = downsample_volume(x_gt_hires, upsamp) x_gt = snp.array(x_gt) # convert to jax array """ Set up forward operator and test signal consisting of blurred signal with additive Gaussian noise. """ n = 5 # convolution kernel size σ = 20.0 / 255 # noise level psf = snp.ones((n, n, n)) / (n**3) A = linop.Convolve(h=psf, input_shape=x_gt.shape) Ax = A(x_gt) # blurred image noise, key = random.randn(Ax.shape) y = Ax + σ * noise """ Set up ADMM solver. """ f = loss.SquaredL2Loss(y=y, A=A) C = linop.Identity(x_gt.shape) λ = 40.0 / 255 # BM4D regularization strength g = λ * functional.BM4D() ρ = 1.0 # ADMM penalty parameter maxiter = 10 # number of ADMM iterations solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=A.T @ y, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() x = snp.clip(x, 0, 1) hist = solver.itstat_object.history(transpose=True) """ Show slices of the recovered 3D volume. """ show_id = Nz // 2 fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(tile_volume_slices(x_gt), title="Ground truth", fig=fig, ax=ax[0]) nc = n // 2 yc = y[nc:-nc, nc:-nc, nc:-nc] yc = snp.clip(yc, 0, 1) plot.imview( tile_volume_slices(yc), title="Slices of blurred, noisy volume: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1], ) plot.imview( tile_volume_slices(x), title="Slices of deconvolved volume: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2], ) fig.show() """ Plot convergence statistics. 
""" plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), ) input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_ppp_dncnn_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ PPP (with DnCNN) Image Deconvolution (ADMM Solver) ================================================== This example demonstrates the solution of an image deconvolution problem using the ADMM Plug-and-Play Priors (PPP) algorithm :cite:`venkatakrishnan-2013-plugandplay2` with the DnCNN :cite:`zhang-2017-dncnn` denoiser. """ import numpy as np from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import functional, linop, loss, metric, plot, random from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ np.random.seed(1234) N = 512 # image size x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = snp.array(x_gt) # convert to jax array """ Set up forward operator and test signal consisting of blurred signal with additive Gaussian noise. """ n = 5 # convolution kernel size σ = 20.0 / 255 # noise level psf = snp.ones((n, n)) / (n * n) A = linop.Convolve(h=psf, input_shape=x_gt.shape) Ax = A(x_gt) # blurred image noise, key = random.randn(Ax.shape) y = Ax + σ * noise """ Set up the problem to be solved. We want to minimize the functional $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + R(\mathbf{x}) \;$$ where $R(\cdot)$ is a pseudo-functional having the DnCNN denoiser as its proximal operator. 
The problem is solved via ADMM, using the standard variable splitting for
problems of this form, which requires the use of conjugate gradient
sub-iterations in the ADMM step that involves the data fidelity term.
"""
f = loss.SquaredL2Loss(y=y, A=A)
g = functional.DnCNN("17M")
C = linop.Identity(x_gt.shape)


"""
Set up ADMM solver.
"""
ρ = 0.2  # ADMM penalty parameter
maxiter = 10  # number of ADMM iterations

# Initialized with A^T y; the CG sub-iterations for the data fidelity step
# use tolerance 1e-3 and at most 30 iterations.
solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=A.T @ y,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 30}),
    itstat_options={"display": True},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
# y is cropped by n // 2 samples on each edge (presumably so that yc has the
# shape of x_gt — confirm) and clipped to [0, 1] for the PSNR in the title;
# the full y is displayed.
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])
fig.show()


"""
Plot convergence statistics.
"""
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)


# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_ppp_dncnn_padmm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""
PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)
===========================================================

This example demonstrates the solution of an image deconvolution problem
using a proximal ADMM variant of the Plug-and-Play Priors (PPP) algorithm
:cite:`venkatakrishnan-2013-plugandplay2` with the DnCNN
:cite:`zhang-2017-dncnn` denoiser.
"""

import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.optimize import ProximalADMM
from scico.util import device_info

"""
Create a ground truth image.
"""
# Fixed NumPy seed; presumably the xdesign phantom draws from NumPy's global
# RNG, making the phantom reproducible — confirm.
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array


"""
Set up forward operator $A$ and test signal consisting of blurred signal with
additive Gaussian noise.
"""
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)  # uniform n x n averaging kernel
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)  # scico.random returns (sample, key)
y = Ax + σ * noise


r"""
Set up the problem to be solved. We want to minimize the functional

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
  \|_2^2 + R(\mathbf{x}) \;$$

where $R(\cdot)$ is a pseudo-functional having the DnCNN denoiser as its
proximal operator. A slightly unusual variable splitting is used,
including setting the $f$ functional to the $R(\cdot)$ term and the $g$
functional to the data fidelity term to allow the use of proximal ADMM,
which avoids the need for conjugate gradient sub-iterations in the solver
steps.
"""
# f is the DnCNN pseudo-functional; g is the data fidelity term without an
# operator, since A is passed to the ProximalADMM solver separately below.
f = functional.DnCNN(variant="17M")
g = loss.SquaredL2Loss(y=y)


"""
Set up proximal ADMM solver.
"""
ρ = 0.2  # ADMM penalty parameter
maxiter = 10  # number of proximal ADMM iterations
# Obtain the mu and nu solver parameters from A via ProximalADMM's helper.
mu, nu = ProximalADMM.estimate_parameters(A)

solver = ProximalADMM(
    f=f,
    g=g,
    A=A,
    rho=ρ,
    mu=mu,
    nu=nu,
    x0=A.T @ y,
    maxiter=maxiter,
    itstat_options={"display": True},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
# y is cropped by n // 2 samples on each edge (presumably so that yc has the
# shape of x_gt — confirm) and clipped to [0, 1] for the PSNR in the title;
# the full y is displayed.
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])
fig.show()


"""
Plot convergence statistics.
"""
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)


# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_tv_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Image Deconvolution with TV Regularization (ADMM Solver)
========================================================

This example demonstrates the solution of an image deconvolution problem
with isotropic total variation (TV) regularization

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x}
  \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$

where $C$ is a convolution operator, $\mathbf{y}$ is the blurred image,
$D$ is a 2D finite difference operator, and $\mathbf{x}$ is the
deconvolved image.
In this example the problem is solved via standard ADMM, while proximal
ADMM is used in a [companion example](deconv_tv_padmm.rst).
"""

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image.
"""
phantom = SiemensStar(32)
N = 256  # image size
# Phantom of size N - 16, zero-padded by 8 on every side to size N.
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)


"""
Set up the forward operator and create a test signal consisting of a
blurred signal with additive Gaussian noise.
"""
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)  # uniform n x n averaging kernel
C = linop.Convolve(h=psf, input_shape=x_gt.shape)

Cx = C(x_gt)  # blurred image
noise, key = scico.random.randn(Cx.shape, seed=0)
y = Cx + σ * noise


r"""
Set up the problem to be solved. We want to minimize the functional

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x}
  \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$

where $C$ is the convolution operator and $D$ is a finite difference
operator. This problem can be expressed as

  $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} -
  C \mathbf{x} \|_2^2 + \lambda \| \mathbf{z} \|_{2,1} \;\;
  \text{such that} \;\; \mathbf{z} = D \mathbf{x} \;,$$

which is easily written in the form of a standard ADMM problem.

This is a simpler splitting than that used in the
[companion example](deconv_tv_padmm.rst), but it requires the use of
conjugate gradient sub-iterations to solve the ADMM step associated with
the data fidelity term.
"""
f = loss.SquaredL2Loss(y=y, A=C)
# Penalty parameters must be accounted for in the gi functions, not as
# additional inputs.
λ = 2.1e-2  # ℓ2,1 norm regularization parameter
g = λ * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm,
# which is used so that g(Dx) corresponds to isotropic TV.
D = linop.FiniteDifference(input_shape=x_gt.shape, append=0)


"""
Set up an ADMM solver object.
"""
ρ = 1.0e-1  # ADMM penalty parameter
maxiter = 50  # number of ADMM iterations

# Initialized with the adjoint of C applied to y; iteration statistics are
# displayed with period 10.
solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[D],
    rho_list=[ρ],
    x0=C.adj(y),
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(),
    itstat_options={"display": True, "period": 10},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
# y is cropped by n // 2 samples on each edge for the PSNR in the title
# (presumably so that yc has the shape of x_gt — confirm); the full y is displayed.
nc = n // 2
yc = y[nc:-nc, nc:-nc]
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(
    solver.x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, solver.x), fig=fig, ax=ax[2]
)
fig.show()


"""
Plot convergence statistics.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    hist.Objective,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
fig.show()


# Keep the figures open until the user presses Enter.
input("\nWaiting for input to close figures and exit")
================================================
FILE: examples/scripts/deconv_tv_admm_tune.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
r"""
Parameter Tuning for Image Deconvolution with TV Regularization (ADMM Solver)
=============================================================================

This example demonstrates the use of
[scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune parameters
for the companion [example script](deconv_tv_admm.rst). The `ray.tune`
function API is used in this example.

This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not
available, and due to the difficulty of suppressing these warnings in a
way that does not force use of the CPU only. To enable GPU usage, comment
out the `os.environ` statements near the beginning of the script, and
change the value of the "gpu" entry in the `resources` dict from 0 to 1.
Note that two environment variables are set to suppress the warnings
because `JAX_PLATFORMS` was intended to replace `JAX_PLATFORM_NAME` but
this change has yet to be correctly implemented
(see [google/jax#6805](https://github.com/google/jax/issues/6805) and
[google/jax#10272](https://github.com/google/jax/pull/10272)).
"""

# isort: off
# The import order below is deliberate and must not be re-sorted: the JAX
# platform environment variables are set, and ray is initialized, before
# scico (and hence jax) is imported.
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["JAX_PLATFORMS"] = "cpu"

from xdesign import SiemensStar, discrete_phantom

import logging
import ray

ray.init(logging_level=logging.ERROR)  # need to call init before jax import: ray-project/ray#44087

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.ray import report, tune

"""
Create a ground truth image.
"""
phantom = SiemensStar(32)
N = 256  # image size
# Phantom of size N - 16, zero-padded by 8 on every side to size N.
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)


"""
Set up the forward operator and create a test signal consisting of a
blurred signal with additive Gaussian noise.
"""
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)  # uniform n x n averaging kernel
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = scico.random.randn(Ax.shape, seed=0)
y = Ax + σ * noise


"""
Define performance evaluation function.
"""


def eval_params(config, x_gt, psf, y):
    """Parameter evaluation function.

    The `config` parameter is a dict of specific parameters for
    evaluation of a single parameter set (a pair of parameters in this
    case, with keys "lambda" and "rho"). The remaining parameters are
    objects that are passed to the evaluation function via the ray
    object store.

    The TV-regularized deconvolution problem is set up as in the
    companion example, and ADMM is run in 5 rounds of `solver.solve()`
    with `maxiter=10` each. After each round the PSNR of the current
    solution with respect to `x_gt` is passed to `report` under the key
    "psnr". Nothing is returned.
    """
    # Extract solver parameters from config dict.
    λ, ρ = config["lambda"], config["rho"]
    # Set up problem to be solved.
    A = linop.Convolve(h=psf, input_shape=x_gt.shape)
    f = loss.SquaredL2Loss(y=y, A=A)
    g = λ * functional.L21Norm()
    C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
    # Define solver.
    solver = ADMM(
        f=f,
        g_list=[g],
        C_list=[C],
        rho_list=[ρ],
        x0=A.adj(y),
        maxiter=10,
        subproblem_solver=LinearSubproblemSolver(),
    )
    # Perform 50 iterations, reporting performance to ray.tune every 10 iterations.
    # (This assumes each solve() call continues from the previous solver
    # state rather than restarting — confirm against ADMM.solve.)
    for step in range(5):
        x_admm = solver.solve()
        report({"psnr": float(metric.psnr(x_gt, x_admm))})


"""
Define parameter search space and resources per trial.
"""
config = {"lambda": tune.loguniform(1e-3, 1e-1), "rho": tune.loguniform(1e-2, 1e0)}
resources = {"cpu": 4, "gpu": 0}  # cpus per trial, gpus per trial


"""
Run parameter search.
"""
# The problem data are bound to eval_params via tune.with_parameters; trials
# are ranked by maximum "psnr".
tuner = tune.Tuner(
    tune.with_parameters(eval_params, x_gt=x_gt, psf=psf, y=y),
    param_space=config,
    resources=resources,
    metric="psnr",
    mode="max",
    num_samples=100,  # perform 100 parameter evaluations
)
results = tuner.fit()
ray.shutdown()


"""
Display best parameters and corresponding performance.
""" best_result = results.get_best_result() best_config = best_result.config print(f"Best PSNR: {best_result.metrics['psnr']:.2f} dB") print("Best config: " + ", ".join([f"{k}: {v:.2e}" for k, v in best_config.items()])) """ Plot parameter values visited during parameter search. Marker sizes are proportional to number of iterations run at each parameter pair. The best point in the parameter space is indicated in red. """ fig = plot.figure(figsize=(8, 8)) trials = results.get_dataframe() for t in trials.iloc: n = t["training_iteration"] plot.plot( t["config/lambda"], t["config/rho"], ptyp="loglog", lw=0, ms=(0.5 + 1.5 * n), marker="o", mfc="blue", mec="blue", fig=fig, ) plot.plot( best_config["lambda"], best_config["rho"], ptyp="loglog", title="Parameter search sampling locations\n(marker size proportional to number of iterations)", xlbl=r"$\rho$", ylbl=r"$\lambda$", lw=0, ms=5.0, marker="o", mfc="red", mec="red", fig=fig, ) ax = fig.axes[0] ax.set_xlim([config["rho"].lower, config["rho"].upper]) ax.set_ylim([config["lambda"].lower, config["lambda"].upper]) fig.show() """ Plot parameter values visited during parameter search and corresponding reconstruction PSNRs.The best point in the parameter space is indicated in red. 
""" 𝜌 = [t["config/rho"] for t in trials.iloc] 𝜆 = [t["config/lambda"] for t in trials.iloc] psnr = [t["psnr"] for t in trials.iloc] minpsnr = min(max(psnr), 18.0) 𝜌, 𝜆, psnr = zip(*filter(lambda x: x[2] >= minpsnr, zip(𝜌, 𝜆, psnr))) fig, ax = plot.subplots(figsize=(10, 8)) sc = ax.scatter(𝜌, 𝜆, c=psnr, cmap=plot.cm.plasma_r) fig.colorbar(sc) plot.plot( best_config["lambda"], best_config["rho"], ptyp="loglog", lw=0, ms=12.0, marker="2", mfc="red", mec="red", fig=fig, ax=ax, ) ax.set_xscale("log") ax.set_yscale("log") ax.set_xlabel(r"$\rho$") ax.set_ylabel(r"$\lambda$") ax.set_title("PSNR at each sample location\n(values below 18 dB omitted)") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/deconv_tv_padmm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Image Deconvolution with TV Regularization (Proximal ADMM Solver) ================================================================= This example demonstrates the solution of an image deconvolution problem with isotropic total variation (TV) regularization $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ where $C$ is a convolution operator, $\mathbf{y}$ is the blurred image, $D$ is a 2D finite difference operator, and $\mathbf{x}$ is the deconvolved image. In this example the problem is solved via proximal ADMM, while standard ADMM is used in a [companion example](deconv_tv_admm.rst). """ from xdesign import SiemensStar, discrete_phantom import scico.numpy as snp import scico.random from scico import functional, linop, loss, metric, plot from scico.optimize import ProximalADMM from scico.util import device_info """ Create a ground truth image. 
""" phantom = SiemensStar(32) N = 256 # image size x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) """ Set up the forward operator and create a test signal consisting of a blurred signal with additive Gaussian noise. """ n = 5 # convolution kernel size σ = 20.0 / 255 # noise level psf = snp.ones((n, n)) / (n * n) C = linop.Convolve(h=psf, input_shape=x_gt.shape) Cx = C(x_gt) # blurred image noise, key = scico.random.randn(Cx.shape, seed=0) y = Cx + σ * noise r""" Set up the problem to be solved. We want to minimize the functional $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ where $C$ is the convolution operator and $D$ is a finite difference operator. This problem can be expressed as $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_{2,1} \;\; \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; \mathbf{z}_1 = D \mathbf{x} \;,$$ which can be written in the form of a standard ADMM problem $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}$$ with $$f = 0 \qquad g = g_0 + g_1$$ $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ This is a more complex splitting than that used in the [companion example](deconv_tv_admm.rst), but it allows the use of a proximal ADMM solver in a way that avoids the need for the conjugate gradient sub-iterations used by the ADMM solver in the [companion example](deconv_tv_admm.rst). 
""" f = functional.ZeroFunctional() g0 = loss.SquaredL2Loss(y=y) λ = 2.0e-2 # ℓ2,1 norm regularization parameter g1 = λ * functional.L21Norm() g = functional.SeparableFunctional((g0, g1)) D = linop.FiniteDifference(input_shape=x_gt.shape, append=0) A = linop.VerticalStack((C, D)) """ Set up a proximal ADMM solver object. """ ρ = 5.0e-2 # ADMM penalty parameter maxiter = 50 # number of ADMM iterations mu, nu = ProximalADMM.estimate_parameters(A) solver = ProximalADMM( f=f, g=g, A=A, B=None, rho=ρ, mu=mu, nu=nu, x0=C.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 10}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Show the recovered image. """ fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) nc = n // 2 yc = y[nc:-nc, nc:-nc] plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1]) plot.imview( solver.x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, solver.x), fig=fig, ax=ax[2] ) fig.show() """ Plot convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/demosaic_ppp_bm3d_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
""" PPP (with BM3D) Image Demosaicing ================================= This example demonstrates the use of the ADMM Plug and Play Priors (PPP) algorithm :cite:`venkatakrishnan-2013-plugandplay2`, with the BM3D :cite:`dabov-2008-image` denoiser, for solving a raw image demosaicing problem. """ import numpy as np from bm3d import bm3d_rgb # Workarounds for colour_demosaicing incompatibility with NumPy 2.x np.float_ = np.float64 np.float = np.float64 np.complex = np.complex128 np.sctypes = { "float": [np.float16, np.float32, np.float64, np.longdouble], "int": [np.int8, np.int16, np.int32, np.int64], } from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 import scico import scico.numpy as snp import scico.random from scico import functional, linop, loss, metric, plot from scico.data import kodim23 from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Read a ground truth image. """ img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ Define demosaicing forward operator and its transpose. """ def Afn(x): """Map an RGB image to a single channel image with each pixel representing a single colour according to the colour filter array. """ y = snp.zeros(x.shape[0:2]) y = y.at[1::2, 1::2].set(x[1::2, 1::2, 0]) y = y.at[0::2, 1::2].set(x[0::2, 1::2, 1]) y = y.at[1::2, 0::2].set(x[1::2, 0::2, 1]) y = y.at[0::2, 0::2].set(x[0::2, 0::2, 2]) return y def ATfn(x): """Back project a single channel raw image to an RGB image with zeros at the locations of undefined samples. """ y = snp.zeros(x.shape + (3,)) y = y.at[1::2, 1::2, 0].set(x[1::2, 1::2]) y = y.at[0::2, 1::2, 1].set(x[0::2, 1::2]) y = y.at[1::2, 0::2, 1].set(x[1::2, 0::2]) y = y.at[0::2, 0::2, 2].set(x[0::2, 0::2]) return y """ Define a baseline demosaicing function based on the demosaicing algorithm of :cite:`menon-2007-demosaicing` from package [colour_demosaicing](https://github.com/colour-science/colour-demosaicing). 
""" def demosaic(cfaimg): """Apply baseline demosaicing.""" return demosaicing_CFA_Bayer_Menon2007(cfaimg, pattern="BGGR").astype(np.float32) """ Create a test image by color filter array sampling and adding Gaussian white noise. """ s = Afn(img) rgbshp = s.shape + (3,) # shape of reconstructed RGB image σ = 2e-2 # noise standard deviation noise, key = scico.random.randn(s.shape, seed=0) sn = s + σ * noise """ Compute a baseline demosaicing solution. """ imgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)) """ Set up an ADMM solver object. Note the use of the baseline solution as an initializer. We use BM3D :cite:`dabov-2008-image` as the denoiser, using the [code](https://pypi.org/project/bm3d) released with :cite:`makinen-2019-exact`. """ A = linop.LinearOperator(input_shape=rgbshp, output_shape=s.shape, eval_fn=Afn, adj_fn=ATfn) f = loss.SquaredL2Loss(y=sn, A=A) C = linop.Identity(input_shape=rgbshp) g = 1.8e-1 * 6.1e-2 * functional.BM3D(is_rgb=True) ρ = 1.8e-1 # ADMM penalty parameter maxiter = 12 # number of ADMM iterations solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=imgb, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Show reference and demosaiced images. """ fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7)) plot.imview(img, title="Reference", fig=fig, ax=ax[0]) plot.imview(imgb, title="Baseline demoisac: %.2f (dB)" % metric.psnr(img, imgb), fig=fig, ax=ax[1]) plot.imview(x, title="PPP demoisac: %.2f (dB)" % metric.psnr(img, x), fig=fig, ax=ax[2]) fig.show() """ Plot convergence statistics. 
""" plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), ) input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/denoise_approx_tv_multi.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Denoising with Approximate Total Variation Proximal Operator ============================================================ This example demonstrates use of approximations to the proximal operators of isotropic :cite:`kamilov-2016-minimizing` and anisotropic :cite:`kamilov-2016-parallel` total variation norms for solving denoising problems using proximal algorithms. """ import matplotlib from xdesign import SiemensStar, discrete_phantom import scico.numpy as snp import scico.random from scico import functional, linop, loss, metric, plot from scico.optimize import AcceleratedPGM from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Create a ground truth image. """ N = 256 # image size phantom = SiemensStar(16) x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) x_gt = x_gt / x_gt.max() """ Add noise to create a noisy test image. """ σ = 0.5 # noise standard deviation noise, key = scico.random.randn(x_gt.shape, seed=0) y = x_gt + σ * noise """ Denoise with isotropic total variation, solved via ADMM. 
""" λ_iso = 1.0e0 f = loss.SquaredL2Loss(y=y) g_iso = λ_iso * functional.L21Norm() C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True) solver = ADMM( f=f, g_list=[g_iso], C_list=[C], rho_list=[1e1], x0=y, maxiter=200, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}), itstat_options={"display": True, "period": 25}, ) print(f"Solving on {device_info()}\n") x_iso = solver.solve() print() """ Denoise with anisotropic total variation, solved via ADMM. """ # Tune the weight to give the same data fidelity as the isotropic case. λ_aniso = 8.68e-1 g_aniso = λ_aniso * functional.L1Norm() solver = ADMM( f=f, g_list=[g_aniso], C_list=[C], rho_list=[1e1], x0=y, maxiter=200, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}), itstat_options={"display": True, "period": 25}, ) x_aniso = solver.solve() print() """ Denoise with isotropic total variation, solved using an approximation of the TV norm proximal operator. """ h = λ_iso * functional.IsotropicTVNorm(circular=True, input_shape=y.shape) solver = AcceleratedPGM( f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={"display": True, "period": 50} ) x_iso_aprx = solver.solve() print() """ Denoise with anisotropic total variation, solved using an approximation of the TV norm proximal operator. """ h = λ_aniso * functional.AnisotropicTVNorm(circular=True, input_shape=y.shape) solver = AcceleratedPGM( f=f, g=h, L0=1e3, x0=y, maxiter=500, itstat_options={"display": True, "period": 50} ) x_aniso_aprx = solver.solve() print() """ Compute and print the data fidelity. """ for x, name in zip( (x_iso, x_aniso, x_iso_aprx, x_aniso_aprx), ("Isotropic", "Anisotropic", "Approx. Isotropic", "Approx. Anisotropic"), ): df = f(x) print(f"Data fidelity for {name} TV: {' ' * (20 - len(name))} {df:.2e}") """ Plot results. 
""" matplotlib.rc("font", size=9) plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5)) fig, ax = plot.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(15, 8)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) plot.imview( y, title=f"Noisy version SNR: {metric.snr(x_gt, y):.2f} dB", fig=fig, ax=ax[1, 0], **plt_args ) plot.imview( x_iso, title=f"Iso. TV denoising SNR: {metric.snr(x_gt, x_iso):.2f} dB", fig=fig, ax=ax[0, 1], **plt_args, ) plot.imview( x_aniso, title=f"Aniso. TV denoising SNR: {metric.snr(x_gt, x_aniso):.2f} dB", fig=fig, ax=ax[1, 1], **plt_args, ) plot.imview( x_iso_aprx, title=f"Approx. Iso. TV denoising SNR: {metric.snr(x_gt, x_iso_aprx):.2f} dB", fig=fig, ax=ax[0, 2], **plt_args, ) plot.imview( x_aniso_aprx, title=f"Approx. Aniso. TV denoising SNR: {metric.snr(x_gt, x_aniso_aprx):.2f} dB", fig=fig, ax=ax[1, 2], **plt_args, ) fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) fig.colorbar( ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" ) fig.suptitle("Denoising comparison") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/denoise_cplx_tv_nlpadmm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r"""
Complex Total Variation Denoising with NLPADMM Solver
=====================================================

This example demonstrates solution of a problem of the form

$$\argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\;
H(\mb{x}, \mb{z}) = 0 \;,$$

where $H$ is a nonlinear function, via a variant of the proximal ADMM
algorithm for problems with a non-linear operator constraint
:cite:`benning-2016-preconditioned`. The example problem represents
total variation (TV) denoising applied to a complex image with
piece-wise smooth magnitude and non-smooth phase. (This example is rather
contrived, and was not constructed to represent a specific real imaging
problem, but it does have some properties in common with synthetic
aperture radar single look complex data in which the magnitude has much
more discernible structure than the phase.) The appropriate TV denoising
formulation for this problem is

$$\argmin_{\mb{x}} \; (1/2) \| \mb{y} - \mb{x} \|_2^2 + \lambda
\| C(\mb{x}) \|_{2,1} \;,$$

where $\mb{y}$ is the measurement, $\|\cdot\|_{2,1}$ is the
$\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator consisting of
a linear difference operator applied to the magnitude of a complex array.
This problem is represented in the form above by taking
$H(\mb{x}, \mb{z}) = C(\mb{x}) - \mb{z}$. The standard TV solution,
which is also computed for comparison purposes, gives very poor results
since the difference is applied independently to real and imaginary
components of the complex image.
"""

from mpl_toolkits.axes_grid1 import make_axes_locatable

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import function, functional, linop, loss, metric, operator, plot
from scico.examples import phase_diff
from scico.optimize import NonLinearPADMM, ProximalADMM
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256  # image size
phantom = SiemensStar(16)
# Offset by 1.0 so the magnitude is strictly positive, then normalize to max 1.
x_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0
x_mag /= x_mag.max()
# Create reference image with structured magnitude and random phase
# (randn returns a (sample, key) pair; only the sample is used here).
x_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])


"""
Add noise to create a noisy test image.
"""
σ = 0.25  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64)
y = x_gt + σ * noise


"""
Denoise with standard total variation.
"""
λ_tv = 6e-2
f = loss.SquaredL2Loss(y=y)
g = λ_tv * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.complex64, append=0)

solver_tv = ProximalADMM(
    f=f,
    g=g,
    A=C,
    rho=1.0,
    mu=8.0,
    nu=1.0,
    maxiter=200,
    itstat_options={"display": True, "period": 20},
)
print(f"Solving on {device_info()}\n")
x_tv = solver_tv.solve()
print()
hist_tv = solver_tv.itstat_object.history(transpose=True)


"""
Denoise with total variation applied to the magnitude of a complex image.
"""
λ_nltv = 2e-1
g = λ_nltv * functional.L21Norm()
# Redefine C for real input (now applied to magnitude of a complex array)
C = linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.float32, append=0)
# Operator computing differences of absolute values
D = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64)
# Constraint function imposing z = D(x) constraint
# (H takes the pair (x, z) and returns D(x) - z; the solver drives this to zero.)
H = function.Function(
    (C.shape[1], C.shape[0]),
    output_shape=C.shape[0],
    eval_fn=lambda x, z: D(x) - z,
    input_dtypes=(snp.complex64, snp.float32),
    output_dtype=snp.float32,
)

solver_nltv = NonLinearPADMM(
    f=f,
    g=g,
    H=H,
    rho=5.0,
    mu=6.0,
    nu=1.0,
    maxiter=200,
    itstat_options={"display": True, "period": 20},
)
x_nltv = solver_nltv.solve()
hist_nltv = solver_nltv.itstat_object.history(transpose=True)


"""
Plot results.
"""
# Convergence comparison: objective, primal residual, dual residual.
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))
plot.plot(
    snp.array((hist_tv.Objective, hist_nltv.Objective)).T,
    ptyp="semilogy",
    title="Objective function",
    xlbl="Iteration",
    lgnd=("Standard TV", "Magnitude TV"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T,
    ptyp="semilogy",
    title="Primal residual",
    xlbl="Iteration",
    lgnd=("Standard TV", "Magnitude TV"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    snp.array((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Dual residual",
    xlbl="Iteration",
    lgnd=("Standard TV", "Magnitude TV"),
    fig=fig,
    ax=ax[2],
)
fig.show()


# Image comparison: magnitudes in the top row, phases in the bottom row,
# each row sharing a single intensity normalization.
fig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10))
norm = plot.matplotlib.colors.Normalize(
    vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()),
    vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()),
)
plot.imview(snp.abs(x_gt), title="Ground truth", cbar=None, fig=fig, ax=ax[0, 0], norm=norm)
plot.imview(
    snp.abs(y),
    title="Measured: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(y)),
    cbar=None,
    fig=fig,
    ax=ax[0, 1],
    norm=norm,
)
plot.imview(
    snp.abs(x_tv),
    title="Standard TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)),
    cbar=None,
    fig=fig,
    ax=ax[0, 2],
    norm=norm,
)
plot.imview(
    snp.abs(x_nltv),
    title="Magnitude TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)),
    cbar=None,
    fig=fig,
    ax=ax[0, 3],
    norm=norm,
)
divider = make_axes_locatable(ax[0, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[0, 3].get_images()[0], cax=cax)
norm = plot.matplotlib.colors.Normalize(
    vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()),
    vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()),
)
plot.imview(
    snp.angle(x_gt),
    title="Ground truth",
    cbar=None,
    fig=fig,
    ax=ax[1, 0],
    norm=norm,
)
plot.imview(
    snp.angle(y),
    title="Measured: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 1],
    norm=norm,
)
plot.imview(
    snp.angle(x_tv),
    title="Standard TV: Mean phase diff. %.2f"
    % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 2],
    norm=norm,
)
plot.imview(
    snp.angle(x_nltv),
    title="Magnitude TV: Mean phase diff. %.2f"
    % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 3],
    norm=norm,
)
divider = make_axes_locatable(ax[1, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1, 3].get_images()[0], cax=cax)
ax[0, 0].set_ylabel("Magnitude")
ax[1, 0].set_ylabel("Phase")
fig.tight_layout()
fig.show()


input("\nWaiting for input to close figures and exit")

================================================
FILE: examples/scripts/denoise_cplx_tv_pdhg.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Complex Total Variation Denoising with PDHG Solver
==================================================

This example demonstrates solution of a problem of the form

$$\argmin_{\mathbf{x}} \; f(\mathbf{x}) + g(C(\mathbf{x})) \;,$$

where $C$ is a nonlinear operator, via non-linear PDHG
:cite:`valkonen-2014-primal`. The example problem represents total
variation (TV) denoising applied to a complex image with piece-wise
smooth magnitude and non-smooth phase. The appropriate TV denoising
formulation for this problem is

$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
\|_2^2 + \lambda \| C(\mathbf{x}) \|_{2,1} \;,$$

where $\mathbf{y}$ is the measurement, $\|\cdot\|_{2,1}$ is the
$\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator that applies a
linear difference operator to the magnitude of a complex array. The
standard TV solution, which is also computed for comparison purposes,
gives very poor results since the difference is applied independently to
real and imaginary components of the complex image.
"""

from mpl_toolkits.axes_grid1 import make_axes_locatable

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, operator, plot
from scico.examples import phase_diff
from scico.optimize import PDHG
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256  # image size
phantom = SiemensStar(16)
x_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0
x_mag /= x_mag.max()
# Create reference image with structured magnitude and random phase
x_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])


"""
Add noise to create a noisy test image.
"""
σ = 0.25  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64)
y = x_gt + σ * noise


"""
Denoise with standard total variation.
"""
λ_tv = 6e-2
f = loss.SquaredL2Loss(y=y)
g = λ_tv * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, input_dtype=snp.complex64, append=0)
solver_tv = PDHG(
    f=f,
    g=g,
    C=C,
    tau=4e-1,
    sigma=4e-1,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
x_tv = solver_tv.solve()
hist_tv = solver_tv.itstat_object.history(transpose=True)


"""
Denoise with total variation applied to the magnitude of a complex image.
"""
λ_nltv = 2e-1
g = λ_nltv * functional.L21Norm()
# Redefine C for real input (now applied to magnitude of a complex array)
C = linop.FiniteDifference(input_shape=x_gt.shape, input_dtype=snp.float32, append=0)
# Operator computing differences of absolute values
D = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64)
# Same PDHG solver and step sizes as above; only the (now non-linear)
# operator passed as C differs.
solver_nltv = PDHG(
    f=f,
    g=g,
    C=D,
    tau=4e-1,
    sigma=4e-1,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
x_nltv = solver_nltv.solve()
hist_nltv = solver_nltv.itstat_object.history(transpose=True)


"""
Plot results.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))
plot.plot(
    snp.array((hist_tv.Objective, hist_nltv.Objective)).T,
    ptyp="semilogy",
    title="Objective function",
    xlbl="Iteration",
    lgnd=("PDHG", "NL-PDHG"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T,
    ptyp="semilogy",
    title="Primal residual",
    xlbl="Iteration",
    lgnd=("PDHG", "NL-PDHG"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    snp.array((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Dual residual",
    xlbl="Iteration",
    lgnd=("PDHG", "NL-PDHG"),
    fig=fig,
    ax=ax[2],
)
fig.show()


fig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10))
norm = plot.matplotlib.colors.Normalize(
    vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()),
    vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()),
)
plot.imview(snp.abs(x_gt), title="Ground truth", cbar=None, fig=fig, ax=ax[0, 0], norm=norm)
plot.imview(
    snp.abs(y),
    title="Measured: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(y)),
    cbar=None,
    fig=fig,
    ax=ax[0, 1],
    norm=norm,
)
plot.imview(
    snp.abs(x_tv),
    title="TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)),
    cbar=None,
    fig=fig,
    ax=ax[0, 2],
    norm=norm,
)
plot.imview(
    snp.abs(x_nltv),
    title="NL-TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)),
    cbar=None,
    fig=fig,
    ax=ax[0, 3],
    norm=norm,
)
divider = make_axes_locatable(ax[0, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[0, 3].get_images()[0], cax=cax)
norm = plot.matplotlib.colors.Normalize(
    vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()),
    vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()),
)
plot.imview(
    snp.angle(x_gt),
    title="Ground truth",
    cbar=None,
    fig=fig,
    ax=ax[1, 0],
    norm=norm,
)
plot.imview(
    snp.angle(y),
    title="Measured: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 1],
    norm=norm,
)
plot.imview(
    snp.angle(x_tv),
    title="TV: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 2],
    norm=norm,
)
plot.imview(
    snp.angle(x_nltv),
    title="NL-TV: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(),
    cbar=None,
    fig=fig,
    ax=ax[1, 3],
    norm=norm,
)
divider = make_axes_locatable(ax[1, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1, 3].get_images()[0], cax=cax)
ax[0, 0].set_ylabel("Magnitude")
ax[1, 0].set_ylabel("Phase")
fig.tight_layout()
fig.show()


input("\nWaiting for input to close figures and exit")

================================================
FILE: examples/scripts/denoise_datagen_bsds.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
Noisy Data Generation for NN Training
=====================================

This example demonstrates how to generate noisy image data for training
neural network models for denoising. The original images are part of the
[BSDS500 dataset](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/)
provided by the Berkeley Segmentation Dataset and Benchmark project.
"""

import numpy as np

from scico import plot
from scico.flax.examples import load_image_data


"""
Read data from cache or generate if not available.
"""
size = 40  # patch size
train_nimg = 400  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg
gray = True  # use gray scale images
data_mode = "dn"  # Denoising problem
noise_level = 0.1  # Standard deviation of noise
noise_range = False  # Use fixed noise level
stride = 23  # Stride to sample multiple patches from each image

train_ds, test_ds = load_image_data(
    train_nimg,
    test_nimg,
    size,
    gray,
    data_mode,
    verbose=True,
    noise_level=noise_level,
    noise_range=noise_range,
    stride=stride,
)


"""
Plot randomly selected sample. Note that patches have small sizes, thus,
plots may correspond to unidentifiable fragments.
"""
# NOTE(review): indices are drawn from [0, train_nimg) and [0, test_nimg);
# with stride-based patch extraction the datasets presumably hold more
# patches than images, so only a prefix of the patches is sampled — confirm.
indx_tr = np.random.randint(0, train_nimg)
indx_te = np.random.randint(0, test_nimg)

fig, axes = plot.subplots(nrows=2, ncols=2, figsize=(7, 7))
plot.imview(
    train_ds["label"][indx_tr, ..., 0],
    title="Ground truth - Training Sample",
    fig=fig,
    ax=axes[0, 0],
)
plot.imview(
    train_ds["image"][indx_tr, ..., 0],
    title="Noisy Image - Training Sample",
    fig=fig,
    ax=axes[0, 1],
)
plot.imview(
    test_ds["label"][indx_te, ..., 0],
    title="Ground truth - Testing Sample",
    fig=fig,
    ax=axes[1, 0],
)
plot.imview(
    test_ds["image"][indx_te, ..., 0], title="Noisy Image - Testing Sample", fig=fig, ax=axes[1, 1]
)
fig.suptitle(r"Training and Testing samples")
fig.tight_layout()
fig.colorbar(
    axes[0, 1].get_images()[0],
    ax=axes,
    shrink=0.5,
    pad=0.05,
)
fig.show()


input("\nWaiting for input to close figures and exit")

================================================
FILE: examples/scripts/denoise_dncnn_train_bsds.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
r""" Training of DnCNN for Denoising =============================== This example demonstrates the training and application of the DnCNN model from :cite:`zhang-2017-dncnn` to denoise images that have been corrupted with additive Gaussian noise. """ # isort: off import os from time import time import numpy as np # Set an arbitrary processor count (only applies if GPU is not available). os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" import jax try: from jax.extend.backend import get_backend # introduced in jax 0.4.33 except ImportError: from jax.lib.xla_bridge import get_backend from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_image_data platform = get_backend().platform print("Platform: ", platform) """ Read data from cache or generate if not available. """ size = 40 # patch size train_nimg = 400 # number of training images test_nimg = 16 # number of testing images nimg = train_nimg + test_nimg gray = True # use gray scale images data_mode = "dn" # Denoising problem noise_level = 0.1 # Standard deviation of noise noise_range = False # Use fixed noise level stride = 23 # Stride to sample multiple patches from each image train_ds, test_ds = load_image_data( train_nimg, test_nimg, size, gray, data_mode, verbose=True, noise_level=noise_level, noise_range=noise_range, stride=stride, ) """ Define configuration dictionary for model and training loop. Parameters have been selected for demonstration purposes and relatively short training. The depth of the model has been reduced to 6, instead of the 17 of the original model. The suggested settings can be found in the original paper. 
"""
# model configuration
model_conf = {
    "depth": 6,
    "num_filters": 64,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 0,
    "opt_type": "ADAM",
    "batch_size": 128,
    "num_epochs": 50,
    "base_learning_rate": 1e-3,
    "warmup_epochs": 0,
    "log_every_steps": 5000,
    "log": True,
    "checkpointing": True,
}


"""
Construct DnCNN model.
"""
# The number of channels is read from the last axis of the training images.
channels = train_ds["image"].shape[-1]
model = sflax.DnCNNNet(
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
)


"""
Run training loop.
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "dncnn_out")
train_conf["workdir"] = workdir
print(f"\nJAX local devices: {jax.local_devices()}\n")

trainer = sflax.BasicFlaxTrainer(
    train_conf,
    model,
    train_ds,
    test_ds,
)
modvar, stats_object = trainer.train()


"""
Evaluate on testing data.
"""
test_patches = 720  # number of leading testing patches to evaluate
start_time = time()
fmap = sflax.FlaxMap(model, modvar)
output = fmap(test_ds["image"][:test_patches])
time_eval = time() - start_time
# Clip the network output to the interval [0, 1].
output = np.clip(output, a_min=0, a_max=1.0)


"""
Evaluate trained model in terms of reconstruction time and data fidelity.
"""
snr_eval = metric.snr(test_ds["label"][:test_patches], output)
psnr_eval = metric.psnr(test_ds["label"][:test_patches], output)
print(
    f"{'DnCNNNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
)
print(
    f"{'DnCNNNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)


"""
Plot comparison. Note that plots may display unidentifiable image
fragments due to the small patch size.
"""
np.random.seed(123)
indx = np.random.randint(0, high=test_patches)

fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Noisy: \nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
    ),
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="DnCNNNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
# Attach a colorbar axis to the right of the reconstruction panel.
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()


"""
Plot convergence statistics. Statistics are generated only if a training
cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.array((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.array((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()

input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/denoise_dncnn_universal.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
Comparison of DnCNN Variants for Image Denoising
================================================

This example demonstrates the solution of an image denoising problem
using DnCNN :cite:`zhang-2017-dncnn` networks trained for different
noise levels, as well as custom variants with fewer network layers,
and with a noise level input.

The networks trained for specific noise levels are labeled 6L, 6M, 6H,
17L, 17M, and 17H, where {6, 17} denote the number of layers, and
{L, M, H} represent the noise standard deviation of the training images
(0.06, 0.10, and 0.20 respectively). The networks with a noise standard
deviation input are labeled 6N and 17N, where {6, 17} again denote the
number of layers.
"""

import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import metric, plot
from scico.denoiser import DnCNN

"""
Create a ground truth image.
"""
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array


"""
Test different DnCNN variants on images with different noise levels.
"""
print(" σ | variant | noisy image PSNR (dB) | denoised image PSNR (dB)")
for σ in [0.06, 0.10, 0.20]:
    print("------+---------+-------------------------+-------------------------")
    for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:

        # Instantiate a DnCNN.
        denoiser = DnCNN(variant=variant)

        # Generate a noisy image.
        # (seed=0 is fixed, so every variant sees the same noise realization.)
        noise, key = scico.random.randn(x_gt.shape, seed=0)
        y = x_gt + σ * noise

        # The "N" variants take the noise standard deviation as an input.
        if variant in ["6N", "17N"]:
            x_hat = denoiser(y, sigma=σ)
        else:
            x_hat = denoiser(y)

        x_hat = np.clip(x_hat, a_min=0, a_max=1.0)

        if variant[0] == "6":
            variant += " "  # add spaces to maintain alignment

        print(
            " %.2f | %s | %.2f | %.2f "
            % (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))
        )


"""
Show reference and denoised images for σ=0.2 and variant=6N.
"""
# y and x_hat still hold the values from the last loop iteration.
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(x_gt, title="Reference", fig=fig, ax=ax[0])
plot.imview(y, title="Noisy image: %.2f (dB)" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])
plot.imview(x_hat, title="Denoised image: %.2f (dB)" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2])
fig.show()

input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/denoise_l1tv_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
ℓ1 Total Variation Denoising
============================

This example demonstrates impulse noise removal via ℓ1 total variation
:cite:`alliney-1992-digital` :cite:`esser-2010-primal` (Sec. 2.4.4)
(i.e. total variation regularization with an ℓ1 data fidelity term),
minimizing the functional

  $$\mathrm{argmin}_{\mathbf{x}} \; \| \mathbf{y} - \mathbf{x}
  \|_1 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$

where $\mathbf{y}$ is the noisy image, $C$ is a 2D finite difference
operator, and $\mathbf{x}$ is the denoised image.
"""

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import spnoise
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
from scipy.ndimage import median_filter

"""
Create a ground truth image and impose salt & pepper noise to create a
noisy test image.
"""
N = 256  # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = 0.5 * x_gt / x_gt.max()
y = spnoise(x_gt, 0.5)


"""
Denoise with median filtering.
"""
x_med = median_filter(y, size=(5, 5))


"""
Denoise with ℓ1 total variation.
"""
λ = 1.5e0
g_loss = loss.Loss(y=y, f=functional.L1Norm())
g_tv = λ * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)

# Both the ℓ1 data fidelity and the TV term are passed as g functionals
# (hence f=None); the data term is coupled to x via the identity operator.
solver = ADMM(
    f=None,
    g_list=[g_loss, g_tv],
    C_list=[linop.Identity(input_shape=y.shape), C],
    rho_list=[5e0, 5e0],
    x0=y,
    maxiter=100,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x_tv = solver.solve()
hist = solver.itstat_object.history(transpose=True)


"""
Plot results.
"""
plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.0))
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(13, 12))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy image", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(
    x_med,
    title=f"Median filtering: {metric.psnr(x_gt, x_med):.2f} (dB)",
    fig=fig,
    ax=ax[1, 0],
    **plt_args,
)
plot.imview(
    x_tv,
    title=f"ℓ1-TV denoising: {metric.psnr(x_gt, x_tv):.2f} (dB)",
    fig=fig,
    ax=ax[1, 1],
    **plt_args,
)
fig.show()


"""
Plot convergence statistics.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    hist.Objective,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/denoise_ptv_pdhg.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Polar Total Variation Denoising (PDHG)
======================================

This example compares denoising via standard isotropic total
variation (TV) regularization :cite:`rudin-1992-nonlinear`
:cite:`goldstein-2009-split` and a variant based on local polar
coordinates, as described in :cite:`hossein-2024-total`. It solves the
denoising problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
  \|_2^2 + \lambda R(\mathbf{x}) \;,$$

where $R$ is either the isotropic or polar TV regularizer, via the
primal–dual hybrid gradient (PDHG) algorithm.
"""


from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
from scico.optimize import PDHG
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256  # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = x_gt / x_gt.max()


"""
Add noise to create a noisy test image.
"""
σ = 0.75  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise


"""
Denoise with standard isotropic total variation.
"""
λ_std = 0.8e0
f = loss.SquaredL2Loss(y=y)
g_std = λ_std * functional.L21Norm()

# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
tau, sigma = PDHG.estimate_parameters(C, ratio=20.0)
solver = PDHG(
    f=f,
    g=g_std,
    C=C,
    tau=tau,
    sigma=sigma,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
solver.solve()
hist_std = solver.itstat_object.history(transpose=True)
x_std = solver.x
print()


"""
Denoise with polar total variation for comparison.
"""
# Tune the weight to give the same data fidelity as the isotropic case.
λ_plr = 1.2e0
g_plr = λ_plr * functional.L1Norm()

G = linop.PolarGradient(input_shape=x_gt.shape)
# Scale the two components (leading axis) of the polar gradient by 0.3
# and 1.0 respectively.
D = linop.Diagonal(snp.array([0.3, 1.0]).reshape((2, 1, 1)), input_shape=G.shape[0])
C = D @ G

tau, sigma = PDHG.estimate_parameters(C, ratio=20.0)
solver = PDHG(
    f=f,
    g=g_plr,
    C=C,
    tau=tau,
    sigma=sigma,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
solver.solve()
hist_plr = solver.itstat_object.history(transpose=True)
x_plr = solver.x
print()


"""
Compute and print the data fidelity.
"""
for x, name in zip((x_std, x_plr), ("Isotropic", "Polar")):
    df = f(x)
    print(f"Data fidelity for {(name + ' TV'):12}: {df:.2e} SNR: {metric.snr(x_gt, x):5.2f} dB")


"""
Plot results.
"""
# Shared intensity normalization so all panels use one color scale.
plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison")
fig.show()

# zoomed version
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
# Axes are shared, so limiting one panel zooms all four to the central half.
# NOTE(review): image axes are normally y-inverted; increasing y limits may
# flip the zoomed view vertically — confirm against plot.imview.
ax[0, 0].set_xlim(N // 4, N // 4 + N // 2)
ax[0, 0].set_ylim(N // 4, N // 4 + N // 2)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison (zoomed)")
fig.show()


fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(20, 5))
plot.plot(
    snp.array((hist_std.Objective, hist_plr.Objective)).T,
    ptyp="semilogy",
    title="Objective function",
    xlbl="Iteration",
    lgnd=("Standard", "Polar"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist_std.Prml_Rsdl, hist_plr.Prml_Rsdl)).T,
    ptyp="semilogy",
    title="Primal residual",
    xlbl="Iteration",
    lgnd=("Standard", "Polar"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    snp.array((hist_std.Dual_Rsdl, hist_plr.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Dual residual",
    xlbl="Iteration",
    lgnd=("Standard", "Polar"),
    fig=fig,
    ax=ax[2],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/denoise_tv_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Total Variation Denoising (ADMM)
================================

This example compares denoising via isotropic and anisotropic total
variation (TV) regularization :cite:`rudin-1992-nonlinear`
:cite:`goldstein-2009-split`. It solves the denoising problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
  \|_2^2 + \lambda R(\mathbf{x}) \;,$$

where $R$ is either the isotropic or anisotropic TV regularizer.
In SCICO, switching between these two regularizers involves a one-line
change: replacing an
[L1Norm](../_autosummary/scico.functional.rst#scico.functional.L1Norm)
with an
[L21Norm](../_autosummary/scico.functional.rst#scico.functional.L21Norm).
Note that the isotropic version exhibits fewer block-like artifacts on
edges that are not vertical or horizontal.
"""

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, plot
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256  # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = x_gt / x_gt.max()


"""
Add noise to create a noisy test image.
"""
σ = 0.75  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise


"""
Denoise with isotropic total variation.
"""
λ_iso = 1.4e0
f = loss.SquaredL2Loss(y=y)
g_iso = λ_iso * functional.L21Norm()

# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
solver = ADMM(
    f=f,
    g_list=[g_iso],
    C_list=[C],
    rho_list=[1e1],
    x0=y,
    maxiter=100,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
solver.solve()
x_iso = solver.x
print()


"""
Denoise with anisotropic total variation for comparison.
"""
# Tune the weight to give the same data fidelity as the isotropic case.
λ_aniso = 1.2e0
g_aniso = λ_aniso * functional.L1Norm()

# Same finite difference operator C; only the regularizer norm changes.
solver = ADMM(
    f=f,
    g_list=[g_aniso],
    C_list=[C],
    rho_list=[1e1],
    x0=y,
    maxiter=100,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
    itstat_options={"display": True, "period": 10},
)

solver.solve()
x_aniso = solver.x
print()


"""
Compute and print the data fidelity.
"""
for x, name in zip((x_iso, x_aniso), ("Isotropic", "Anisotropic")):
    df = f(x)
    print(f"Data fidelity for {name} TV was {df:.2e}")


"""
Plot results.
""" plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5)) fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args) fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) fig.colorbar( ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" ) fig.suptitle("Denoising comparison") fig.show() # zoomed version fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args) ax[0, 0].set_xlim(N // 4, N // 4 + N // 2) ax[0, 0].set_ylim(N // 4, N // 4 + N // 2) fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) fig.colorbar( ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" ) fig.suptitle("Denoising comparison (zoomed)") fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/denoise_tv_apgm.py ================================================ #!/Usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
r"""
Total Variation Denoising with Constraint (APGM)
================================================

This example demonstrates the solution of the isotropic total variation
(TV) denoising problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
  \|_2^2 + \lambda R(\mathbf{x}) + \iota_C(\mathbf{x}) \;,$$

where $R$ is a TV regularizer, $\iota_C(\cdot)$ is the indicator function
of constraint set $C$, and $C = \{ \mathbf{x} \, | \, x_i \in [0, 1] \}$,
i.e. the set of vectors with components constrained to be in the interval
$[0, 1]$. The problem is solved separately with $R$ taken as isotropic
and anisotropic TV regularization.

The solution via APGM is based on the approach in :cite:`beck-2009-tv`,
which involves constructing a dual for the constrained denoising problem.
The APGM solution minimizes the resulting dual. In this case, switching
between the two regularizers corresponds to switching between two
different projectors.
"""

from typing import Callable, Optional, Union

import jax.numpy as jnp

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.numpy import Array, BlockArray
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256  # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = x_gt / x_gt.max()


"""
Add noise to create a noisy test image.
"""
σ = 0.75  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise


"""
Define finite difference operator and adjoint.
"""
# The append=0 option appends 0 to the input along the axis
# prior to performing the difference to make the results of
# horizontal and vertical finite differences the same shape.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
A = C.adj


"""
Define a zero array as initial estimate.
"""
# The dual variable lives in the output space of C, hence the shape C(y).shape.
x0 = jnp.zeros(C(y).shape)


"""
Define the dual of the total variation denoising problem.
"""


class DualTVLoss(loss.Loss):
    """Dual objective of the [0, 1]-constrained TV denoising problem.

    With ``xint = y - lmbda * A(x)`` and ``F`` a
    :class:`functional.SquaredL2Norm`, evaluates
    ``-F(xint - clip(xint, 0, 1)) + F(xint)``. The weight ``lmbda`` is
    stored as an attribute; the solver setup below reads it for ``L0``
    and for recovering the primal solution.
    """

    def __init__(
        self,
        y: Union[Array, BlockArray],
        A: Optional[Union[Callable, operator.Operator]] = None,
        lmbda: float = 0.5,
    ):
        # Squared ℓ2 norm used to evaluate both terms of the dual objective.
        self.functional = functional.SquaredL2Norm()
        super().__init__(y=y, A=A, scale=1.0)
        self.lmbda = lmbda

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        xint = self.y - self.lmbda * self.A(x)
        # The first term uses the residual of the projection onto [0, 1].
        return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint)


"""
Denoise with isotropic total variation. Define projector for isotropic
total variation.
"""


# Evaluation of functional set to zero.
class IsoProjector(functional.Functional):
    """Projector used as the proximal operator for the isotropic TV dual.

    ``prox`` ignores ``lam``. Each vector ``v[:, i, j]`` is scaled by
    ``1 / max(1, ||v[:, i, j]||_2)`` (norm taken over axis 0). The entries
    ``v[0, :, -1]`` and ``v[1, -1, :]`` are instead scaled elementwise by
    ``1 / max(1, |v|)``.

    NOTE(review): this boundary treatment presumably follows the isotropic
    projection of :cite:`beck-2009-tv` (boundary entries with no partner
    component) — confirm.
    """

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        return 0.0

    def prox(self, v: Array, lam: float, **kwargs) -> Array:
        # Pointwise ℓ2 norm across the leading (component) axis.
        norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))

        x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)
        # Overwrite boundary entries with elementwise (scalar) projections.
        out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1]))
        x_out = x_out.at[0, :, -1].set(out1)
        out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :]))
        x_out = x_out.at[1, -1, :].set(out2)
        return x_out


"""
Set up `AcceleratedPGM` solver object using `RobustLineSearchStepSize`
step size policy. Run the solver.
"""
reg_weight_iso = 1.4e0
f_iso = DualTVLoss(y=y, A=A, lmbda=reg_weight_iso)
g_iso = IsoProjector()

# L0 is the initial Lipschitz-constant estimate. NOTE(review): the value
# 16 * lmbda**2 appears to follow the bound in :cite:`beck-2009-tv` — confirm.
solver_iso = AcceleratedPGM(
    f=f_iso,
    g=g_iso,
    L0=16.0 * f_iso.lmbda**2,
    x0=x0,
    maxiter=100,
    itstat_options={"display": True, "period": 10},
    step_size=RobustLineSearchStepSize(),
)

# Run the solver.
print(f"Solving on {device_info()}\n")
x = solver_iso.solve()
hist_iso = solver_iso.itstat_object.history(transpose=True)
# Project to constraint set.
x_iso = jnp.clip(y - f_iso.lmbda * f_iso.A(x), 0.0, 1.0)


"""
Denoise with anisotropic total variation for comparison. Define projector
for anisotropic total variation.
"""


# Evaluation of functional set to zero.
class AnisoProjector(functional.Functional):
    """Projector used as the proximal operator for the anisotropic TV dual.

    ``prox`` ignores ``lam`` and scales every entry of ``v`` elementwise by
    ``1 / max(1, |v|)``, i.e. clips each entry's magnitude to at most 1.
    """

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        return 0.0

    def prox(self, v: Array, lam: float, **kwargs) -> Array:
        return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v))


"""
Set up `AcceleratedPGM` solver object using `RobustLineSearchStepSize`
step size policy. (Weight was tuned to give the same data fidelity as the
isotropic case.) Run the solver.
"""

reg_weight_aniso = 1.2e0
f = DualTVLoss(y=y, A=A, lmbda=reg_weight_aniso)
g = AnisoProjector()

solver = AcceleratedPGM(
    f=f,
    g=g,
    L0=16.0 * f.lmbda**2,
    x0=x0,
    maxiter=100,
    itstat_options={"display": True, "period": 10},
    step_size=RobustLineSearchStepSize(),
)

# Run the solver.
print()
x = solver.solve()

# Project to constraint set.
x_aniso = jnp.clip(y - f.lmbda * f.A(x), 0.0, 1.0)


"""
Compute the data fidelity.
"""
# NOTE(review): the values printed are the final solver objectives, i.e. the
# DualTVLoss values (g evaluates to 0), not f(x) evaluated on the recovered
# primal images as in the ADMM/PDHG TV examples — confirm this is intended.
df = hist_iso.Objective[-1]
print(f"\nData fidelity for isotropic TV was {df:.2e}")
hist = solver.itstat_object.history(transpose=True)
df = hist.Objective[-1]
print(f"Data fidelity for anisotropic TV was {df:.2e}")


"""
Plot results.
"""
# Shared intensity normalization so all panels use one color scale.
plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison")
fig.show()

# zoomed version
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
# Axes are shared, so limiting one panel zooms all four to the central half.
# NOTE(review): image axes are normally y-inverted; increasing y limits may
# flip the zoomed view vertically — confirm against plot.imview.
ax[0, 0].set_xlim(N // 4, N // 4 + N // 2)
ax[0, 0].set_ylim(N // 4, N // 4 + N // 2)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison (zoomed)")
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/denoise_tv_multi.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
r"""
Comparison of Optimization Algorithms for Total Variation Denoising
===================================================================

This example compares the performance of alternating direction method
of multipliers (ADMM), linearized ADMM, proximal ADMM, and primal–dual
hybrid gradient (PDHG) in solving the isotropic total variation (TV)
denoising problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
  \|_2^2 + \lambda R(\mathbf{x}) \;,$$

where $R$ is the isotropic TV: the sum of the norms of the gradient
vectors at each point in the image $\mathbf{x}$.
"""

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, plot
from scico.optimize import PDHG, LinearizedADMM, ProximalADMM
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image.
"""
phantom = SiemensStar(32)
N = 256  # image size
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)


"""
Add noise to create a noisy test image.
"""
σ = 1.0  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise


"""
Construct operators and functionals and set regularization parameter.
"""
# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
f = loss.SquaredL2Loss(y=y)
λ = 1e0
g = λ * functional.L21Norm()


"""
The first step of the first-run solver is much slower than the
following steps, presumably due to just-in-time compilation of
relevant operators in first use. The code below performs a preliminary
solver step, the result of which is discarded, to reduce this bias in
the timing results.
The precise cause of the remaining differences in time required to
compute the first step of each algorithm is unknown, but it is worth
noting that this difference becomes negligible when just-in-time
compilation is disabled (e.g. via the `JAX_DISABLE_JIT` environment
variable).
"""
# Warm-up run (single iteration, result discarded) to absorb first-use
# compilation cost before the timed runs below.
solver_admm = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[1e1],
    x0=y,
    maxiter=1,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 1}),
)
solver_admm.solve();  # fmt: skip
# trailing semi-colon suppresses output in notebook


"""
Solve via ADMM with a maximum of 2 CG iterations.
"""
solver_admm = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[1e1],
    x0=y,
    maxiter=200,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 2}),
    itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
print("ADMM solver")
solver_admm.solve()
hist_admm = solver_admm.itstat_object.history(transpose=True)


"""
Solve via Linearized ADMM.
"""
solver_ladmm = LinearizedADMM(
    f=f,
    g=g,
    C=C,
    mu=1e-2,
    nu=1e-1,
    x0=y,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
print("\nLinearized ADMM solver")
solver_ladmm.solve()
hist_ladmm = solver_ladmm.itstat_object.history(transpose=True)


"""
Solve via Proximal ADMM.
"""
mu, nu = ProximalADMM.estimate_parameters(C)
solver_padmm = ProximalADMM(
    f=f,
    g=g,
    A=C,
    rho=1e0,
    mu=mu,
    nu=nu,
    x0=y,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
print("\nProximal ADMM solver")
solver_padmm.solve()
hist_padmm = solver_padmm.itstat_object.history(transpose=True)


"""
Solve via PDHG.
"""
tau, sigma = PDHG.estimate_parameters(C, factor=1.5)
solver_pdhg = PDHG(
    f=f,
    g=g,
    C=C,
    tau=tau,
    sigma=sigma,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
print("\nPDHG solver")
solver_pdhg.solve()
hist_pdhg = solver_pdhg.itstat_object.history(transpose=True)


"""
Plot results. It is worth noting that:

1. PDHG outperforms ADMM both with respect to iterations and time.
2. Proximal ADMM has similar performance to PDHG with respect to
   iterations, but is slightly inferior with respect to time.
3. ADMM greatly outperforms Linearized ADMM with respect to iterations.
4. ADMM slightly outperforms Linearized ADMM with respect to time. This
   is possible because the ADMM $\mathbf{x}$-update can be solved
   relatively cheaply, with only 2 CG iterations. If more CG iterations
   were required, the time comparison would be favorable to Linearized
   ADMM.
"""
# Convergence statistics against iteration number.
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))
plot.plot(
    snp.array(
        (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective)
    ).T,
    ptyp="semilogy",
    title="Objective function",
    xlbl="Iteration",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array(
        (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)
    ).T,
    ptyp="semilogy",
    title="Primal residual",
    xlbl="Iteration",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    snp.array(
        (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)
    ).T,
    ptyp="semilogy",
    title="Dual residual",
    xlbl="Iteration",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[2],
)
fig.show()

# Same statistics against elapsed solver time (second positional argument).
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))
plot.plot(
    snp.array(
        (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective)
    ).T,
    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,
    ptyp="semilogy",
    title="Objective function",
    xlbl="Time (s)",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array(
        (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)
    ).T,
    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,
    ptyp="semilogy",
    title="Primal residual",
    xlbl="Time (s)",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    snp.array(
        (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)
    ).T,
    snp.array((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T,
    ptyp="semilogy",
    title="Dual residual",
    xlbl="Time (s)",
    lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"),
    fig=fig,
    ax=ax[2],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/diffusercam_tv_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
TV-Regularized 3D DiffuserCam Reconstruction
============================================

This example demonstrates reconstruction of a 3D DiffuserCam
:cite:`antipa-2018-diffusercam`
[dataset](https://github.com/Waller-Lab/DiffuserCam/tree/master/example_data).
The inverse problem can be written as

  $$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \Big\| \mathbf{y} -
  M \Big( \sum_k \mathbf{h}_k \ast \mathbf{x}_k \Big) \Big\|_2^2 +
  \lambda_0 \sum_k \| D \mathbf{x}_k \|_{2,1} +
  \lambda_1 \sum_k \| \mathbf{x}_k \|_1 \;,$$

where the $\mathbf{h}_k$ are the components of the PSF stack, the
$\mathbf{x}_k$ are the corresponding components of the reconstructed
volume, $\mathbf{y}$ is the measured image, and $M$ is a cropping
operator that allows the boundary artifacts resulting from circular
convolution to be avoided.
Following the mask decoupling approach
:cite:`almeida-2013-deconvolving`, the problem is posed in ADMM form
as

  $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}_0, \mathbf{z}_1,
  \mathbf{z}_2} \; \frac{1}{2} \| \mathbf{y} - M \mathbf{z}_0 \|_2^2 +
  \lambda_0 \sum_k \| \mathbf{z}_{1,k} \|_{2,1} +
  \lambda_1 \sum_k \| \mathbf{z}_{2,k} \|_1  \\
  \;\; \text{s.t.} \;\; \mathbf{z}_0 = \sum_k \mathbf{h}_k \ast
  \mathbf{x}_k \qquad \mathbf{z}_{1,k} = D \mathbf{x}_k \qquad
  \mathbf{z}_{2,k} = \mathbf{x}_k \;.$$

The most computationally expensive step in the ADMM algorithm is solved
using the frequency-domain approach proposed in
:cite:`wohlberg-2014-efficient`.
"""

import numpy as np

import scico.numpy as snp
from scico import plot
from scico.examples import ucb_diffusercam_data
from scico.functional import L1Norm, L21Norm, ZeroFunctional
from scico.linop import CircularConvolve, Crop, FiniteDifference, Identity, Sum
from scico.loss import SquaredL2Loss
from scico.optimize.admm import ADMM, G0BlockCircularConvolveSolver
from scico.util import device_info

"""
Load the DiffuserCam PSF stack and measured image. The computational
cost of the reconstruction is reduced slightly by removing parts of the
PSF stack that don't make a significant contribution to the
reconstruction.
"""
y, psf = ucb_diffusercam_data()
# Drop the first PSF and the last seven PSFs along the stack (last) axis.
psf = psf[..., 1:-7]


"""
To avoid boundary artifacts, the measured image is padded by half the
PSF width/height and then cropped within the data fidelity term. This
padding is implicit in that the reconstruction volume is computed at
the padded size, but the actual measured image is never explicitly
padded since it is used at the original (unpadded) size within the data
fidelity term due to the cropping operation. The PSF axis order is
modified to put the stack axis at index 0, as required by components of
the ADMM solver to be used. Finally, each PSF in the stack is
individually normalized.
"""
half_psf = np.array(psf.shape[0:2]) // 2
pad_spec = ((half_psf[0],) * 2, (half_psf[1],) * 2)
y_pad_shape = tuple(np.array(y.shape) + np.array(pad_spec).sum(axis=1))
# Reconstruction volume: (number of PSFs,) + padded image shape.
x_shape = (psf.shape[-1],) + y_pad_shape

psf = psf.transpose((2, 0, 1))
# Scale each PSF to unit l2 norm over its two image axes.
psf /= np.sqrt(np.sum(psf**2, axis=(1, 2), keepdims=True))


"""
Convert the image and PSF stack to JAX arrays with `float32` dtype since
JAX by default does not support double-precision floating point
arithmetic. This limited precision leads to relatively poor, but still
acceptable accuracy within the ADMM solver x-step. To experiment with
the effect of higher numerical precision, set the environment variable
`JAX_ENABLE_X64=True` and change `dtype` below to `np.float64`.
"""
dtype = np.float32
y = snp.array(y.astype(dtype))
psf = snp.array(psf.astype(dtype))


"""
Define problem and algorithm parameters.
"""
λ0 = 3e-3  # TV regularization parameter
λ1 = 1e-2  # ℓ1 norm regularization parameter
ρ0 = 1e0  # ADMM penalty parameter for first auxiliary variable
ρ1 = 5e0  # ADMM penalty parameter for second auxiliary variable
ρ2 = 1e1  # ADMM penalty parameter for third auxiliary variable
maxiter = 100  # number of ADMM iterations


"""
Create operators.
"""
# C convolves each slice of the volume with its PSF; S sums over the
# stack axis (axis 0); M crops the padded image back to y.shape.
C = CircularConvolve(psf, input_shape=x_shape, input_dtype=dtype, h_center=half_psf, ndims=2)
S = Sum(input_shape=x_shape, input_dtype=dtype, axis=0)
M = Crop(pad_spec, input_shape=y_pad_shape, input_dtype=dtype)


"""
Create functionals.
"""
# All terms are handled in g_list (f is zero): g0/C0 is the
# mask-decoupled data term, g1/C1 the TV term and g2/C2 the l1 term.
g0 = SquaredL2Loss(y=y, A=M)
g1 = λ0 * L21Norm()
g2 = λ1 * L1Norm()
C0 = S @ C
# Finite differences along the last two (image) axes, circular boundary.
C1 = FiniteDifference(input_shape=x_shape, input_dtype=dtype, axes=(-2, -1), circular=True)
C2 = Identity(input_shape=x_shape, input_dtype=dtype)


"""
Set up ADMM solver object and solve problem.
"""
solver = ADMM(
    f=ZeroFunctional(),
    g_list=[g0, g1, g2],
    C_list=[C0, C1, C2],
    rho_list=[ρ0, ρ1, ρ2],
    alpha=1.4,
    maxiter=maxiter,
    nanstop=True,
    subproblem_solver=G0BlockCircularConvolveSolver(ndims=2, check_solve=True),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)


"""
Show the measured image and samples from the PSF stack.
"""
plot.imview(y, cmap=plot.plt.cm.Blues, cbar=True, title="Measured Image")

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7))
plot.imview(psf[0], title="Nearest PSF", cmap=plot.plt.cm.Blues, fig=fig, ax=ax[0])
plot.imview(psf[-1], title="Furthest PSF", cmap=plot.plt.cm.Blues, fig=fig, ax=ax[1])
fig.show()


"""
Show the recovered volume with depth indicated by color.
"""
# Crop away the image padding (the stack axis is not cropped).
XCrop = Crop(((0, 0),) + pad_spec, input_shape=x_shape, input_dtype=dtype)
# NOTE(review): x[..., ::-1] reverses the last (horizontal image) axis,
# not the depth axis; confirm this display flip is intended.
xm = np.array(XCrop(x[..., ::-1]))
# (H, W, depth, 1), normalized by the volume maximum.
xmr = xm.transpose((1, 2, 0))[..., np.newaxis] / xm.max()
cmap = plot.plt.cm.viridis_r
# One RGBA colour per depth slice, shape (1, 1, depth, 4).
cmval = cmap(np.arange(0, xm.shape[0]).reshape(1, 1, -1) / (xm.shape[0] - 1))
# Colour-weighted sum over depth, keeping only the RGB channels.
xms = np.sum(cmval * xmr, axis=2)[..., 0:3]

plot.imview(xms, cmap=cmap, cbar=True, title="Recovered Volume")


"""
Plot convergence statistics.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    hist.Objective,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/index.rst
================================================
Usage Examples
==============


Organized by Application
------------------------


Computed Tomography
^^^^^^^^^^^^^^^^^^^

- ct_abel_tv_admm.py
- ct_abel_tv_admm_tune.py
- ct_symcone_tv_padmm.py
- ct_astra_noreg_pcg.py
- ct_astra_3d_tv_admm.py
- ct_astra_3d_tv_padmm.py
- ct_tv_admm.py
- ct_astra_tv_admm.py
- ct_multi_tv_admm.py
- ct_astra_weighted_tv_admm.py
- ct_svmbir_tv_multi.py
- ct_svmbir_ppp_bm3d_admm_cg.py
- ct_svmbir_ppp_bm3d_admm_prox.py
- ct_fan_svmbir_ppp_bm3d_admm_prox.py
- ct_modl_train_foam2.py
- ct_odp_train_foam2.py
- ct_unet_train_foam2.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py


Deconvolution
^^^^^^^^^^^^^

- deconv_circ_tv_admm.py
- deconv_tv_admm.py
- deconv_tv_padmm.py
- deconv_tv_admm_tune.py
- deconv_microscopy_tv_admm.py
- deconv_microscopy_allchn_tv_admm.py
- deconv_ppp_bm3d_admm.py
- deconv_ppp_bm3d_apgm.py
- deconv_ppp_dncnn_admm.py
- deconv_ppp_dncnn_padmm.py
- deconv_ppp_bm4d_admm.py
- deconv_modl_train_foam1.py
- deconv_odp_train_foam1.py


Sparse Coding
^^^^^^^^^^^^^

- sparsecode_nn_admm.py
- sparsecode_nn_apgm.py
- sparsecode_conv_admm.py
- sparsecode_conv_md_admm.py
- sparsecode_apgm.py
- sparsecode_poisson_apgm.py


Miscellaneous
^^^^^^^^^^^^^

- demosaic_ppp_bm3d_admm.py
- superres_ppp_dncnn_admm.py
- denoise_l1tv_admm.py
- denoise_ptv_pdhg.py
- denoise_tv_admm.py
- denoise_tv_apgm.py
- denoise_tv_multi.py
- denoise_approx_tv_multi.py
- denoise_cplx_tv_nlpadmm.py
- denoise_cplx_tv_pdhg.py
-
denoise_dncnn_universal.py - diffusercam_tv_admm.py - video_rpca_admm.py - ct_datagen_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - denoise_datagen_bsds.py Organized by Regularization --------------------------- Plug and Play Priors ^^^^^^^^^^^^^^^^^^^^ - ct_svmbir_ppp_bm3d_admm_cg.py - ct_svmbir_ppp_bm3d_admm_prox.py - ct_fan_svmbir_ppp_bm3d_admm_prox.py - deconv_ppp_bm3d_admm.py - deconv_ppp_bm3d_apgm.py - deconv_ppp_dncnn_admm.py - deconv_ppp_dncnn_padmm.py - deconv_ppp_bm4d_admm.py - demosaic_ppp_bm3d_admm.py - superres_ppp_dncnn_admm.py Total Variation ^^^^^^^^^^^^^^^ - ct_abel_tv_admm.py - ct_abel_tv_admm_tune.py - ct_symcone_tv_padmm.py - ct_tv_admm.py - ct_multi_tv_admm.py - ct_astra_tv_admm.py - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py - ct_astra_weighted_tv_admm.py - ct_svmbir_tv_multi.py - deconv_circ_tv_admm.py - deconv_tv_admm.py - deconv_tv_admm_tune.py - deconv_tv_padmm.py - deconv_microscopy_tv_admm.py - deconv_microscopy_allchn_tv_admm.py - denoise_l1tv_admm.py - denoise_ptv_pdhg.py - denoise_tv_admm.py - denoise_tv_apgm.py - denoise_tv_multi.py - denoise_approx_tv_multi.py - denoise_cplx_tv_nlpadmm.py - denoise_cplx_tv_pdhg.py - diffusercam_tv_admm.py Sparsity ^^^^^^^^ - diffusercam_tv_admm.py - sparsecode_nn_admm.py - sparsecode_nn_apgm.py - sparsecode_conv_admm.py - sparsecode_conv_md_admm.py - sparsecode_apgm.py - sparsecode_poisson_apgm.py - video_rpca_admm.py Machine Learning ^^^^^^^^^^^^^^^^ - ct_datagen_foam2.py - ct_modl_train_foam2.py - ct_odp_train_foam2.py - ct_unet_train_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - deconv_modl_train_foam1.py - deconv_odp_train_foam1.py - denoise_datagen_bsds.py - denoise_dncnn_train_bsds.py - denoise_dncnn_universal.py Organized by Optimization Algorithm ----------------------------------- ADMM ^^^^ - ct_abel_tv_admm.py - ct_abel_tv_admm_tune.py - ct_symcone_tv_padmm.py - ct_astra_tv_admm.py - ct_tv_admm.py - ct_astra_3d_tv_admm.py - ct_astra_weighted_tv_admm.py 
- ct_multi_tv_admm.py - ct_svmbir_tv_multi.py - ct_svmbir_ppp_bm3d_admm_cg.py - ct_svmbir_ppp_bm3d_admm_prox.py - ct_fan_svmbir_ppp_bm3d_admm_prox.py - deconv_circ_tv_admm.py - deconv_tv_admm.py - deconv_tv_admm_tune.py - deconv_microscopy_tv_admm.py - deconv_microscopy_allchn_tv_admm.py - deconv_ppp_bm3d_admm.py - deconv_ppp_dncnn_admm.py - deconv_ppp_bm4d_admm.py - diffusercam_tv_admm.py - sparsecode_nn_admm.py - sparsecode_conv_admm.py - sparsecode_conv_md_admm.py - demosaic_ppp_bm3d_admm.py - superres_ppp_dncnn_admm.py - denoise_l1tv_admm.py - denoise_tv_admm.py - denoise_tv_multi.py - denoise_approx_tv_multi.py - video_rpca_admm.py Linearized ADMM ^^^^^^^^^^^^^^^ - ct_svmbir_tv_multi.py - denoise_tv_multi.py Proximal ADMM ^^^^^^^^^^^^^ - ct_astra_3d_tv_padmm.py - deconv_tv_padmm.py - denoise_tv_multi.py - deconv_ppp_dncnn_padmm.py Non-linear Proximal ADMM ^^^^^^^^^^^^^^^^^^^^^^^^ - denoise_cplx_tv_nlpadmm.py PDHG ^^^^ - ct_svmbir_tv_multi.py - denoise_ptv_pdhg.py - denoise_tv_multi.py - denoise_cplx_tv_pdhg.py PGM ^^^ - deconv_ppp_bm3d_apgm.py - sparsecode_apgm.py - sparsecode_nn_apgm.py - sparsecode_poisson_apgm.py - denoise_tv_apgm.py - denoise_approx_tv_multi.py PCG ^^^ - ct_astra_noreg_pcg.py ================================================ FILE: examples/scripts/sparsecode_apgm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Basis Pursuit DeNoising (APGM) ============================== This example demonstrates the solution of the the sparse coding problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - D \mathbf{x} \|_2^2 + \lambda \| \mathbf{x} \|_1\;,$$ where $D$ the dictionary, $\mathbf{y}$ the signal to be represented, and $\mathbf{x}$ is the sparse representation. 
""" import numpy as np import scico.numpy as snp from scico import functional, linop, loss, plot from scico.optimize.pgm import AcceleratedPGM from scico.util import device_info """ Construct a random dictionary, a reference random sparse representation, and a test signal consisting of the synthesis of the reference sparse representation. """ m = 512 # Signal size n = 4 * m # Dictionary size s = 32 # Sparsity level (number of non-zeros) σ = 0.5 # Noise level np.random.seed(12345) D = np.random.randn(m, n).astype(np.float32) L0 = np.linalg.norm(D, 2) ** 2 x_gt = np.zeros(n, dtype=np.float32) # true signal idx = np.random.permutation(list(range(0, n - 1))) x_gt[idx[0:s]] = np.random.randn(s) y = D @ x_gt + σ * np.random.randn(m) # synthetic signal x_gt = snp.array(x_gt) # convert to jax array y = snp.array(y) # convert to jax array """ Set up the forward operator and `AcceleratedPGM` solver object. """ maxiter = 100 λ = 2.98e1 A = linop.MatrixOperator(D) f = loss.SquaredL2Loss(y=y, A=A) g = λ * functional.L1Norm() solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 10} ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Plot the recovered coefficients and convergence statistics. """ fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( np.vstack((x_gt, x)).T, title="Coefficients", lgnd=("Ground Truth", "Recovered"), fig=fig, ax=ax[0], ) plot.plot( np.array((hist.Objective, hist.Residual)).T, ptyp="semilogy", title="Convergence", xlbl="Iteration", lgnd=("Objective", "Residual"), fig=fig, ax=ax[1], ) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/sparsecode_conv_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. 
Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Convolutional Sparse Coding (ADMM)
==================================

This example demonstrates the solution of a simple convolutional sparse
coding problem

  $$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \Big\| \mathbf{y} -
  \sum_k \mathbf{h}_k \ast \mathbf{x}_k \Big\|_2^2 + \lambda \sum_k
  ( \| \mathbf{x}_k \|_1 - \| \mathbf{x}_k \|_2 ) \;,$$

where the $\mathbf{h}_k$ form a set of filters comprising the
dictionary, the $\mathbf{x}_k$ form a corresponding set of coefficient
maps, and $\mathbf{y}$ is the signal to be represented. The problem is
solved via an ADMM algorithm using the frequency-domain approach
proposed in :cite:`wohlberg-2014-efficient`.
"""

import numpy as np

import scico.numpy as snp
from scico import plot
from scico.examples import create_conv_sparse_phantom
from scico.functional import L1MinusL2Norm
from scico.linop import CircularConvolve, Identity, Sum
from scico.loss import SquaredL2Loss
from scico.optimize.admm import ADMM, FBlockCircularConvolveSolver
from scico.util import device_info


"""
Set problem size and create random convolutional dictionary (a set of
filters) and a corresponding sparse random set of coefficient maps.
"""
N = 128  # image size
Nnz = 128  # number of non-zeros in coefficient maps
h, x0 = create_conv_sparse_phantom(N, Nnz)


"""
Normalize dictionary filters and scale coefficient maps accordingly.
"""
# Per-filter l2 norms, shape (K, 1, 1), so that each filter is divided
# and its coefficient map multiplied by the same factor.
hnorm = np.sqrt(np.sum(h**2, axis=(1, 2), keepdims=True))
h /= hnorm
x0 *= hnorm


"""
Convert numpy arrays to jax arrays.
"""
h = snp.array(h)
x0 = snp.array(x0)


"""
Set up sum-of-convolutions forward operator.
"""
C = CircularConvolve(h, input_shape=x0.shape, ndims=2)
# Sum the per-filter convolutions over the filter axis (axis 0).
S = Sum(input_shape=C.output_shape, axis=0)
A = S @ C


"""
Construct test image from dictionary $\mathbf{h}$ and coefficient maps
$\mathbf{x}_0$.
"""
y = A(x0)


"""
Set functional and solver parameters.
"""
λ = 1e0  # ℓ1-ℓ2 norm regularization parameter
ρ = 2e0  # ADMM penalty parameter
maxiter = 200  # number of ADMM iterations


"""
Define loss function and regularization. Note the use of the
$\ell_1 - \ell_2$ norm, which has been found to provide slightly better
performance than the $\ell_1$ norm in this type of problem
:cite:`wohlberg-2021-psf`.
"""
f = SquaredL2Loss(y=y, A=A)
g0 = λ * L1MinusL2Norm()
C0 = Identity(input_shape=x0.shape)


"""
Initialize ADMM solver.
"""
solver = ADMM(
    f=f,
    g_list=[g0],
    C_list=[C0],
    rho_list=[ρ],
    alpha=1.8,
    maxiter=maxiter,
    subproblem_solver=FBlockCircularConvolveSolver(check_solve=True),
    itstat_options={"display": True, "period": 10},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x1 = solver.solve()
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered coefficient maps.
"""
# Top row: first three ground-truth maps; bottom row: recovered maps.
fig, ax = plot.subplots(nrows=2, ncols=3, figsize=(12, 8.6))
plot.imview(x0[0], title="Coef. map 0", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0])
ax[0, 0].set_ylabel("Ground truth")
plot.imview(x0[1], title="Coef. map 1", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])
plot.imview(x0[2], title="Coef. map 2", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 2])
plot.imview(x1[0], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0])
ax[1, 0].set_ylabel("Recovered")
plot.imview(x1[1], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1])
plot.imview(x1[2], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 2])
fig.tight_layout()
fig.show()


"""
Show test image and reconstruction from recovered coefficient maps.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
plot.imview(y, title="Test image", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[0])
plot.imview(A(x1), title="Reconstructed image", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[1])
fig.show()


"""
Plot convergence statistics.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    hist.Objective,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/sparsecode_conv_md_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Convolutional Sparse Coding with Mask Decoupling (ADMM)
=======================================================

This example demonstrates the solution of a convolutional sparse coding
problem

  $$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \Big\| \mathbf{y} -
  B \Big( \sum_k \mathbf{h}_k \ast \mathbf{x}_k \Big) \Big\|_2^2 +
  \lambda \sum_k ( \| \mathbf{x}_k \|_1 - \| \mathbf{x}_k \|_2 ) \;,$$

where the $\mathbf{h}_k$ form a set of filters comprising the
dictionary, the $\mathbf{x}_k$ form a corresponding set of coefficient
maps, $\mathbf{y}$ is the signal to be represented, and $B$ is a
cropping operator that allows the boundary artifacts resulting from
circular convolution to be avoided.

Following the mask decoupling approach
:cite:`almeida-2013-deconvolving`, the problem is posed in ADMM form
as

  $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}_0, \mathbf{z}_1} \;
  \frac{1}{2} \| \mathbf{y} - B \mathbf{z}_0 \|_2^2 +
  \lambda \sum_k ( \| \mathbf{z}_{1,k} \|_1 - \| \mathbf{z}_{1,k} \|_2 )
  \\ \;\; \text{s.t.} \;\; \mathbf{z}_0 = \sum_k \mathbf{h}_k \ast
  \mathbf{x}_k \qquad \mathbf{z}_{1,k} = \mathbf{x}_k \;.$$
The most computationally expensive step in the ADMM algorithm is solved
using the frequency-domain approach proposed in
:cite:`wohlberg-2014-efficient`.
"""

import numpy as np

import scico.numpy as snp
from scico import plot
from scico.examples import create_conv_sparse_phantom
from scico.functional import L1MinusL2Norm, ZeroFunctional
from scico.linop import CircularConvolve, Crop, Identity, Sum
from scico.loss import SquaredL2Loss
from scico.optimize.admm import ADMM, G0BlockCircularConvolveSolver
from scico.util import device_info


"""
Set problem size and create random convolutional dictionary (a set of
filters) and a corresponding sparse random set of coefficient maps.
"""
N = 121  # image size
Nnz = 128  # number of non-zeros in coefficient maps
h, x0 = create_conv_sparse_phantom(N, Nnz)


"""
Normalize dictionary filters and scale coefficient maps accordingly.
"""
hnorm = np.sqrt(np.sum(h**2, axis=(1, 2), keepdims=True))
h /= hnorm
x0 *= hnorm


"""
Convert numpy arrays to jax arrays.
"""
h = snp.array(h)
x0 = snp.array(x0)


"""
Set up required padding and corresponding crop operator.
"""
# Pad each coefficient map by the filter half-size on both sides of
# each image axis (no padding on the filter axis); B crops a padded
# image back to the unpadded size.
h_center = (h.shape[1] // 2, h.shape[2] // 2)
pad_width = ((0, 0), (h_center[0], h_center[0]), (h_center[1], h_center[1]))
x0p = snp.pad(x0, pad_width=pad_width)
B = Crop(pad_width[1:], input_shape=x0p.shape[1:])


"""
Set up sum-of-convolutions forward operator.
"""
C = CircularConvolve(h, input_shape=x0p.shape, ndims=2, h_center=h_center)
S = Sum(input_shape=C.output_shape, axis=0)
A = S @ C


"""
Construct test image from dictionary $\mathbf{h}$ and padded version of
coefficient maps $\mathbf{x}_0$.
"""
y = B(A(x0p))


"""
Set functional and solver parameters.
"""
λ = 1e0  # ℓ1-ℓ2 norm regularization parameter
ρ0 = 1e0  # ADMM penalty parameters
ρ1 = 3e0
maxiter = 200  # number of ADMM iterations


"""
Define loss function and regularization. Note the use of the
$\ell_1 - \ell_2$ norm, which has been found to provide slightly better
performance than the $\ell_1$ norm in this type of problem
:cite:`wohlberg-2021-psf`.
"""
# Mask decoupling: f is zero, and the cropped data term acts on
# z0 = A x through g0, while the sparsity term acts on z1 = x through g1.
f = ZeroFunctional()
g0 = SquaredL2Loss(y=y, A=B)
g1 = λ * L1MinusL2Norm()
C0 = A
C1 = Identity(input_shape=x0p.shape)


"""
Initialize ADMM solver.
"""
solver = ADMM(
    f=f,
    g_list=[g0, g1],
    C_list=[C0, C1],
    rho_list=[ρ0, ρ1],
    alpha=1.8,
    maxiter=maxiter,
    subproblem_solver=G0BlockCircularConvolveSolver(check_solve=True),
    itstat_options={"display": True, "period": 10},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x1 = solver.solve()
hist = solver.itstat_object.history(transpose=True)


"""
Show the recovered coefficient maps.
"""
# NOTE(review): x0 is the unpadded ground truth while x1 has the padded
# shape of x0p, so the two rows are displayed at different sizes.
fig, ax = plot.subplots(nrows=2, ncols=3, figsize=(12, 8.6))
plot.imview(x0[0], title="Coef. map 0", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0])
ax[0, 0].set_ylabel("Ground truth")
plot.imview(x0[1], title="Coef. map 1", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1])
plot.imview(x0[2], title="Coef. map 2", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 2])
plot.imview(x1[0], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0])
ax[1, 0].set_ylabel("Recovered")
plot.imview(x1[1], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1])
plot.imview(x1[2], cmap=plot.cm.Blues, fig=fig, ax=ax[1, 2])
fig.tight_layout()
fig.show()


"""
Show test image and reconstruction from recovered coefficient maps. Note
the absence of the wrap-around effects at the boundary that can be seen
in the corresponding images in the
[related example](sparsecode_conv_admm.rst).
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
plot.imview(y, title="Test image", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[0])
plot.imview(B(A(x1)), title="Reconstructed image", cmap=plot.cm.gist_heat_r, fig=fig, ax=ax[1])
fig.show()


"""
Plot convergence statistics.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    hist.Objective,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    fig=fig,
    ax=ax[0],
)
plot.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/sparsecode_nn_admm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Non-Negative Basis Pursuit DeNoising (ADMM)
===========================================

This example demonstrates the solution of a non-negative sparse coding
problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - D \mathbf{x} \|_2^2
  + \lambda \| \mathbf{x} \|_1 + \iota_{\mathrm{NN}}(\mathbf{x}) \;,$$

where $D$ is the dictionary, $\mathbf{y}$ is the signal to be
represented, $\mathbf{x}$ is the sparse representation, and
$\iota_{\mathrm{NN}}$ is the indicator function of the non-negativity
constraint.

In this example the problem is solved via ADMM, while Accelerated PGM is
used in a [companion example](sparsecode_nn_apgm.rst).
"""

import numpy as np

import scico.numpy as snp
from scico import functional, linop, loss, plot
from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.util import device_info

"""
Create random dictionary, reference random sparse representation, and
test signal consisting of the synthesis of the reference sparse
representation.
"""
m = 32  # signal size
n = 128  # dictionary size
s = 10  # sparsity level

np.random.seed(1)
D = np.random.randn(m, n).astype(np.float32)
D = D / np.linalg.norm(D, axis=0, keepdims=True)  # normalize dictionary

xt = np.zeros(n, dtype=np.float32)  # true signal
# NOTE(review): randint samples with replacement, so idx may contain
# repeated indices and xt can end up with fewer than s non-zeros. The
# companion APGM example uses the same data-generation code, so any
# change here should be made in both scripts together.
idx = np.random.randint(low=0, high=n, size=s)  # support of xt
xt[idx] = np.random.rand(s)
y = D @ xt + 5e-2 * np.random.randn(m)  # synthetic signal

xt = snp.array(xt)  # convert to jax array
y = snp.array(y)  # convert to jax array


"""
Set up the forward operator and ADMM solver object.
"""
lmbda = 1e-1
A = linop.MatrixOperator(D)
f = loss.SquaredL2Loss(y=y, A=A)
# Two splittings of x itself: one for the l1 term, one for the
# non-negativity indicator.
g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]
C_list = [linop.Identity((n)), linop.Identity((n))]
rho_list = [1.0, 1.0]
maxiter = 100  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=g_list,
    C_list=C_list,
    rho_list=rho_list,
    x0=A.adj(y),
    maxiter=maxiter,
    subproblem_solver=MatrixSubproblemSolver(),
    itstat_options={"display": True, "period": 10},
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()


"""
Plot the recovered coefficients and signal.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    np.vstack((xt, solver.x)).T,
    title="Coefficients",
    lgnd=("Ground Truth", "Recovered"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    np.vstack((D @ xt, y, D @ solver.x)).T,
    title="Signal",
    lgnd=("Ground Truth", "Noisy", "Recovered"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/sparsecode_nn_apgm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
r"""
Non-Negative Basis Pursuit DeNoising (APGM)
===========================================

This example demonstrates the solution of a non-negative sparse coding
problem

  $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - D \mathbf{x} \|_2^2
  + \lambda \| \mathbf{x} \|_1 + \iota_{\mathrm{NN}}(\mathbf{x}) \;,$$

where $D$ is the dictionary, $\mathbf{y}$ is the signal to be
represented, $\mathbf{x}$ is the sparse representation, and
$\iota_{\mathrm{NN}}$ is the indicator function of the non-negativity
constraint.

In this example the problem is solved via Accelerated PGM, using the
proximal averaging method :cite:`yu-2013-better` to approximate the
proximal operator of the sum of the $\ell_1$ norm and an indicator
function, while ADMM is used in a
[companion example](sparsecode_nn_admm.rst).
"""

import numpy as np

import scico.numpy as snp
from scico import functional, linop, loss, plot
from scico.optimize.pgm import AcceleratedPGM
from scico.util import device_info

"""
Create random dictionary, reference random sparse representation, and
test signal consisting of the synthesis of the reference sparse
representation.
"""
m = 32  # signal size
n = 128  # dictionary size
s = 10  # sparsity level

np.random.seed(1)
D = np.random.randn(m, n).astype(np.float32)
D = D / np.linalg.norm(D, axis=0, keepdims=True)  # normalize dictionary
# Initial L0 for the solver: squared spectral norm of D, but no smaller
# than 50.
L0 = max(np.linalg.norm(D, 2) ** 2, 5e1)

xt = np.zeros(n, dtype=np.float32)  # true signal
# NOTE(review): randint samples with replacement, so xt may have fewer
# than s non-zeros. Kept identical to the companion ADMM example so both
# scripts solve the same problem.
idx = np.random.randint(low=0, high=n, size=s)  # support of xt
xt[idx] = np.random.rand(s)
y = D @ xt + 5e-2 * np.random.randn(m)  # synthetic signal

xt = snp.array(xt)  # convert to jax array
y = snp.array(y)  # convert to jax array


"""
Set up the forward operator and APGM solver object.
"""
lmbda = 2e-1
A = linop.MatrixOperator(D)
f = loss.SquaredL2Loss(y=y, A=A)
# Proximal averaging of the l1 term and the non-negativity indicator
# (see the module docstring).
g = functional.ProximalAverage([lmbda * functional.L1Norm(), functional.NonNegativeIndicator()])
maxiter = 250  # number of APGM iterations
solver = AcceleratedPGM(
    f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 20}
)


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()


"""
Plot the recovered coefficients and signal.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    np.vstack((xt, solver.x)).T,
    title="Coefficients",
    lgnd=("Ground Truth", "Recovered"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    np.vstack((D @ xt, y, D @ solver.x)).T,
    title="Signal",
    lgnd=("Ground Truth", "Noisy", "Recovered"),
    fig=fig,
    ax=ax[1],
)
fig.show()


input("\nWaiting for input to close figures and exit")



================================================
FILE: examples/scripts/sparsecode_poisson_apgm.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Non-negative Poisson Loss Reconstruction (APGM)
===============================================

This example demonstrates the use of class
[pgm.PGMStepSize](../_autosummary/scico.optimize.pgm.rst#scico.optimize.pgm.PGMStepSize)
to solve the non-negative reconstruction problem with Poisson negative
log likelihood loss

  $$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \left( A(\mathbf{x}) -
  \mathbf{y} \log\left( A(\mathbf{x}) \right) + \log(\mathbf{y}!)
  \right) + \iota_{\mathrm{NN}}(\mathbf{x}_0) \;,$$

where $A$ is the forward operator, $\mathbf{y}$ is the measurement,
$\mathbf{x}$ is the signal reconstruction, and $\iota_{\mathrm{NN}}$ is
the indicator function of the non-negativity constraint.
This example also demonstrates the application of [numpy.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray), [functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional), and [functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional) to implement the forward operator $A(\mathbf{x}) = A_0(\mathbf{x}_0) + A_1(\mathbf{x}_1)$ and the selective non-negativity constraint that only applies to $\mathbf{x}_0$. """ import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import scico.numpy as snp import scico.random from scico import functional, loss, plot from scico.numpy import BlockArray from scico.operator import Operator from scico.optimize.pgm import ( AcceleratedPGM, AdaptiveBBStepSize, BBStepSize, LineSearchStepSize, RobustLineSearchStepSize, ) from scico.typing import Shape from scico.util import device_info from scipy.linalg import dft """ Construct a dictionary, a reference random reconstruction, and a test measurement signal consisting of the synthesis of the reference reconstruction. """ m = 1024 # signal size n = 8 # dictionary size n0 = 2 n1 = n - n0 # Create dictionary with bump-like features. D = ((snp.real(dft(m))[1 : n + 1, :m]) ** 12).T D0 = D[:, :n0] D1 = D[:, n0:] # Define composed operator. class ForwardOperator(Operator): """Toy problem non-linear forward operator with different treatment of x[0] and x[1]. Attributes: D0: Matrix multiplying x[0]. D1: Matrix multiplying x[1]. 
""" def __init__(self, input_shape: Shape, D0, D1, jit: bool = True): self.D0 = D0 self.D1 = D1 output_shape = (D0.shape[0],) super().__init__( input_shape=input_shape, input_dtype=snp.complex64, output_dtype=snp.complex64, output_shape=output_shape, jit=jit, ) def _eval(self, x: BlockArray) -> BlockArray: return 10 * snp.exp(-D0 @ x[0]) + 5 * snp.exp(-D1 @ x[1]) x_gt, key = scico.random.uniform(((n0,), (n1,)), seed=12345) # true coefficients A = ForwardOperator(x_gt.shape, D0, D1) lam = A(x_gt) y, key = scico.random.poisson(lam, shape=lam.shape, key=key) # synthetic signal """ Set up the loss function and the regularization. """ f = loss.PoissonLoss(y=y, A=A) g0 = functional.NonNegativeIndicator() g1 = functional.ZeroFunctional() g = functional.SeparableFunctional([g0, g1]) """ Define common setup: maximum of iterations and initial estimate of solution. """ maxiter = 50 x0, key = scico.random.uniform(((n0,), (n1,)), key=key) """ Define plotting functionality. """ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): # Plot signal, coefficients and convergence statistics. 
fig = plot.figure( figsize=(12, 6), tight_layout=True, ) gs = gridspec.GridSpec(nrows=2, ncols=3) fig.suptitle( "Results for PGM Solver and " + str_ss + r" ($L_0$: " + "{:4.2f}".format(L0) + ")", fontsize=14, ) ax0 = fig.add_subplot(gs[0, 0]) plot.plot( hist.Objective, ptyp="semilogy", title="Objective", xlbl="Iteration", fig=fig, ax=ax0, ) ax1 = fig.add_subplot(gs[0, 1]) plot.plot( hist.Residual, ptyp="semilogy", title="Residual", xlbl="Iteration", fig=fig, ax=ax1, ) ax2 = fig.add_subplot(gs[0, 2]) plot.plot( hist.L, ptyp="semilogy", title="L", xlbl="Iteration", fig=fig, ax=ax2, ) ax3 = fig.add_subplot(gs[1, 0]) plt.stem(snp.concatenate((xgt[0], xgt[1])), linefmt="C1-", markerfmt="C1o", basefmt="C1-") plt.stem(snp.concatenate((xsol[0], xsol[1])), linefmt="C2-", markerfmt="C2x", basefmt="C1-") plt.legend(["Ground Truth", "Recovered"]) plt.xlabel("Index") plt.title("Coefficients") ax4 = fig.add_subplot(gs[1, 1:]) plot.plot( snp.vstack((y, Aop(xgt), Aop(xsol))).T, title="Fit", xlbl="Index", lgnd=("y", "A(x_gt)", "A(x)"), fig=fig, ax=ax4, ) fig.show() """ Use default `PGMStepSize` object, set L0 based on norm of forward operator and set up `AcceleratedPGM` solver object. Run the solver and plot the recontructed signal and convergence statistics. """ L0 = 1e3 str_L0 = "(Specifically chosen so that convergence occurs)" solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=x0, maxiter=maxiter, itstat_options={"display": True, "period": 10}, ) str_ss = type(solver.step_size).__name__ print(f"Solving on {device_info()}\n") print("============================================================") print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) """ Use `BBStepSize` object, set L0 with arbitary initial value and set up `AcceleratedPGM` solver object. 
Run the solver and plot the reconstructed signal and convergence statistics. """ L0 = 90.0 # initial reciprocal of gradient descent step size str_L0 = "(Arbitrary Initialization)" solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=x0, maxiter=maxiter, itstat_options={"display": True, "period": 10}, step_size=BBStepSize(), ) str_ss = type(solver.step_size).__name__ print("===================================================") print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) """ Use `AdaptiveBBStepSize` object, set L0 with arbitrary initial value and set up `AcceleratedPGM` solver object. Run the solver and plot the reconstructed signal and convergence statistics. """ L0 = 90.0 # initial reciprocal of gradient descent step size str_L0 = "(Arbitrary Initialization)" solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=x0, maxiter=maxiter, itstat_options={"display": True, "period": 10}, step_size=AdaptiveBBStepSize(kappa=0.75), ) str_ss = type(solver.step_size).__name__ print("===========================================================") print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) """ Use `LineSearchStepSize` object, set L0 with arbitrary initial value and set up `AcceleratedPGM` solver object. Run the solver and plot the reconstructed signal and convergence statistics.
""" L0 = 90.0 # initial reciprocal of gradient descent step size str_L0 = "(Arbitrary Initialization)" solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=x0, maxiter=maxiter, itstat_options={"display": True, "period": 10}, step_size=LineSearchStepSize(), ) str_ss = type(solver.step_size).__name__ print("===========================================================") print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) """ Use `RobustLineSearchStepSize` object, set L0 with arbitrary initial value and set up `AcceleratedPGM` solver object. Run the solver and plot the reconstructed signal and convergence statistics. """ L0 = 90.0 # initial reciprocal of gradient descent step size str_L0 = "(Arbitrary Initialization)" solver = AcceleratedPGM( f=f, g=g, L0=L0, x0=x0, maxiter=maxiter, itstat_options={"display": True, "period": 10}, step_size=RobustLineSearchStepSize(), ) str_ss = type(solver.step_size).__name__ print("=================================================================") print("Running solver with step size of class: ", str_ss) print("L0 " + str_L0 + ": ", L0, "\n") x = solver.solve() # run the solver hist = solver.itstat_object.history(transpose=True) plot_results(hist, str_ss, L0, x, x_gt, A) input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/superres_ppp_dncnn_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package.
""" PPP (with DnCNN) Image Superresolution ====================================== This example demonstrates the use of the ADMM Plug and Play Priors (PPP) algorithm :cite:`venkatakrishnan-2013-plugandplay2`, with DnCNN :cite:`zhang-2017-dncnn` denoiser, for solving a simple image superresolution problem. """ import scico import scico.numpy as snp import scico.random from scico import denoiser, functional, linop, loss, metric, plot from scico.data import kodim23 from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.solver import cg from scico.util import device_info """ Define downsampling function. """ def downsample_image(img, rate): img = snp.mean(snp.reshape(img, (-1, rate, img.shape[1], img.shape[2])), axis=1) img = snp.mean(snp.reshape(img, (img.shape[0], -1, rate, img.shape[2])), axis=2) return img """ Read a ground truth image. """ img = snp.array(kodim23(asfloat=True)[160:416, 60:316]) """ Create a test image by downsampling and adding Gaussian white noise. """ rate = 4 # downsampling rate σ = 2e-2 # noise standard deviation Afn = lambda x: downsample_image(x, rate=rate) s = Afn(img) input_shape = img.shape output_shape = s.shape noise, key = scico.random.randn(s.shape, seed=0) sn = s + σ * noise """ Set up the PPP problem pseudo-functional. The DnCNN denoiser :cite:`zhang-2017-dncnn` is used as a regularizer. """ A = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn) f = loss.SquaredL2Loss(y=sn, A=A) C = linop.Identity(input_shape=input_shape) g = functional.DnCNN("17M") """ Compute a baseline solution via denoising of the pseudo-inverse of the forward operator. This baseline solution is also used to initialize the PPP solver. """ xpinv, info = cg(A.T @ A, A.T @ sn, snp.zeros(input_shape)) dncnn = denoiser.DnCNN("17M") xden = dncnn(xpinv) """ Set up an ADMM solver and solve. 
""" ρ = 3.4e-2 # ADMM penalty parameter maxiter = 12 # number of ADMM iterations solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[ρ], x0=xden, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 10}), itstat_options={"display": True}, ) print(f"Solving on {device_info()}\n") xppp = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Plot convergence statistics. """ plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), ) """ Show reference and test images. """ fig = plot.figure(figsize=(8, 6)) ax0 = plot.plt.subplot2grid((1, rate + 1), (0, 0), colspan=rate) plot.imview(img, title="Reference", fig=fig, ax=ax0) ax1 = plot.plt.subplot2grid((1, rate + 1), (0, rate)) plot.imview(sn, title="Downsampled", fig=fig, ax=ax1) fig.show() """ Show recovered full-resolution images. """ fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7)) plot.imview(xpinv, title="Pseudo-inverse: %.2f (dB)" % metric.psnr(img, xpinv), fig=fig, ax=ax[0]) plot.imview( xden, title="Denoised pseudo-inverse: %.2f (dB)" % metric.psnr(img, xden), fig=fig, ax=ax[1] ) plot.imview(xppp, title="PPP solution: %.2f (dB)" % metric.psnr(img, xppp), fig=fig, ax=ax[2]) fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/scripts/trace_example.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" SCICO Call Tracing ================== This example demonstrates the call tracing functionality provided by the [trace](../_autosummary/scico.trace.rst) module. It is based on the [non-negative BPDN example](sparsecode_nn_admm.rst). 
""" import numpy as np import jax import scico.numpy as snp from scico import functional, linop, loss, metric from scico.optimize.admm import ADMM, MatrixSubproblemSolver from scico.trace import register_variable, trace_scico_calls from scico.util import device_info """ Initialize tracing. JIT must be disabled for correct tracing. The call tracing mechanism prints the name, arguments, and return values of functions/methods as they are called. Module and class names are printed in light red, function and method names in dark red, arguments and return values in light blue, and the names of registered variables in light yellow. When a method defined in a class is called for an object of a derived class type, the class of that object is printed in light magenta, in square brackets. Function names and return values are distinguished by initial ">>" and "<<" characters respectively. """ jax.config.update("jax_disable_jit", True) trace_scico_calls() """ Create random dictionary, reference random sparse representation, and test signal consisting of the synthesis of the reference sparse representation. """ m = 32 # signal size n = 128 # dictionary size s = 10 # sparsity level np.random.seed(1) D = np.random.randn(m, n).astype(np.float32) D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary xt = np.zeros(n, dtype=np.float32) # true signal idx = np.random.randint(low=0, high=n, size=s) # support of xt xt[idx] = np.random.rand(s) y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal xt = snp.array(xt) # convert to jax array y = snp.array(y) # convert to jax array """ Register a variable so that it can be referenced by name in the call trace. Any hashable object and numpy arrays may be registered, but JAX arrays cannot. """ register_variable(D, "D") """ Set up the forward operator and ADMM solver object. 
""" lmbda = 1e-1 A = linop.MatrixOperator(D) register_variable(A, "A") f = loss.SquaredL2Loss(y=y, A=A) g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()] C_list = [linop.Identity((n)), linop.Identity((n))] rho_list = [1.0, 1.0] maxiter = 1 # number of ADMM iterations (set to small value to simplify trace output) register_variable(f, "f") register_variable(g_list[0], "g_list[0]") register_variable(g_list[1], "g_list[1]") register_variable(C_list[0], "C_list[0]") register_variable(C_list[1], "C_list[1]") solver = ADMM( f=f, g_list=g_list, C_list=C_list, rho_list=rho_list, x0=A.adj(y), maxiter=maxiter, subproblem_solver=MatrixSubproblemSolver(), itstat_options={"display": True, "period": 5}, ) register_variable(solver, "solver") """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() mse = metric.mse(xt, x) ================================================ FILE: examples/scripts/video_rpca_admm.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. r""" Video Decomposition via Robust PCA ================================== This example demonstrates video foreground/background separation via a variant of the Robust PCA problem $$\mathrm{argmin}_{\mathbf{x}_0, \mathbf{x}_1} \; (1/2) \| \mathbf{x}_0 + \mathbf{x}_1 - \mathbf{y} \|_2^2 + \lambda_0 \| \mathbf{x}_0 \|_* + \lambda_1 \| \mathbf{x}_1 \|_1 \;,$$ where $\mathbf{x}_0$ and $\mathbf{x}_1$ are respectively low-rank and sparse components, $\| \cdot \|_*$ denotes the nuclear norm, and $\| \cdot \|_1$ denotes the $\ell_1$ norm. 
Note: while video foreground/background separation is not an example of the scientific and computational imaging problems that are the focus of SCICO, it provides a convenient demonstration of Robust PCA, which does have potential application in scientific imaging problems. """ import imageio.v3 as iio import scico.numpy as snp from scico import functional, linop, loss, plot from scico.examples import rgb2gray from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info """ Load example video. """ vid = rgb2gray( iio.imread("imageio:newtonscradle.gif").transpose((1, 2, 3, 0)).astype(snp.float32) / 255.0 ) """ Construct matrix with each column consisting of a vectorised video frame. """ y = vid.reshape((-1, vid.shape[-1])) """ Define functional for Robust PCA problem. """ A = linop.Sum(axis=0, input_shape=(2,) + y.shape) f = loss.SquaredL2Loss(y=y, A=A) C0 = linop.Slice(idx=0, input_shape=(2,) + y.shape) g0 = functional.NuclearNorm() C1 = linop.Slice(idx=1, input_shape=(2,) + y.shape) g1 = functional.L1Norm() """ Set up an ADMM solver object. """ λ0 = 1e1 # nuclear norm regularization parameter λ1 = 3e1 # ℓ1 norm regularization parameter ρ0 = 2e1 # ADMM penalty parameter ρ1 = 2e1 # ADMM penalty parameter maxiter = 50 # number of ADMM iterations solver = ADMM( f=f, g_list=[λ0 * g0, λ1 * g1], C_list=[C0, C1], rho_list=[ρ0, ρ1], x0=A.adj(y), maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(), itstat_options={"display": True, "period": 10}, ) """ Run the solver. """ print(f"Solving on {device_info()}\n") x = solver.solve() hist = solver.itstat_object.history(transpose=True) """ Plot convergence statistics. 
""" fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( hist.Objective, title="Objective function", xlbl="Iteration", ylbl="Functional value", fig=fig, ax=ax[0], ) plot.plot( snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, ptyp="semilogy", title="Residuals", xlbl="Iteration", lgnd=("Primal", "Dual"), fig=fig, ax=ax[1], ) fig.show() """ Reshape low-rank component as background video sequence and sparse component as foreground video sequence. """ xlr = C0(x) xsp = C1(x) vbg = xlr.reshape(vid.shape) vfg = xsp.reshape(vid.shape) """ Display original video frames and corresponding background and foreground frames. """ fig, ax = plot.subplots(nrows=4, ncols=3, figsize=(10, 10)) ax[0][0].set_title("Original") ax[0][1].set_title("Background") ax[0][2].set_title("Foreground") for n, fn in enumerate(range(1, 9, 2)): plot.imview(vid[..., fn], fig=fig, ax=ax[n][0]) plot.imview(vbg[..., fn], fig=fig, ax=ax[n][1]) plot.imview(vfg[..., fn], fig=fig, ax=ax[n][2]) ax[n][0].set_ylabel("Frame %d" % fn, labelpad=5, rotation=90, size="large") fig.tight_layout() fig.show() input("\nWaiting for input to close figures and exit") ================================================ FILE: examples/updatejnbcode.py ================================================ #!/usr/bin/env python # Update code cells in notebooks from corresponding scripts without # the need to re-execute the notebook. NB: use with caution! # Run as # python updatejnbcode.py import os import sys from jnb import py_file_to_string, read_notebook from py2jn.tools import py_string_to_notebook, write_notebook def replace_code_cells(src, dst): """Overwrite code cells in notebook object `dst` with corresponding cells in notebook object `src`. 
""" if "cells" in src: srccell = src["cells"] else: srccell = src["worksheets"][0]["cells"] if "cells" in dst: dstcell = dst["cells"] else: dstcell = dst["worksheets"][0]["cells"] # It is an error to attempt replacement if src and dst have different # numbers of cells if len(srccell) != len(dstcell): raise ValueError("Notebooks do not have the same number of cells.") # Iterate over cells in src for n in range(len(srccell)): # It is an error to attempt replacement if any corresponding pair # of cells have different type if srccell[n]["cell_type"] != dstcell[n]["cell_type"]: raise ValueError("Cell number %d of different type in src and dst.") # If current src cell is a code cell, copy the src cell to the dst cell if srccell[n]["cell_type"] == "code": dstcell[n]["source"] = srccell[n]["source"] src = sys.argv[1] dst = os.path.join("notebooks", os.path.splitext(os.path.basename(src))[0] + ".ipynb") print(f"Updating code cells in {dst} from {src}") if os.path.exists(dst): srcnb = py_string_to_notebook(py_file_to_string(src), nbver=4) dstnb = read_notebook(dst) replace_code_cells(srcnb, dstnb) write_notebook(dstnb, dst) ================================================ FILE: examples/updatejnbmd.py ================================================ #!/usr/bin/env python # Update markdown cells in notebooks from corresponding scripts without # the need to re-execute the notebook. Only applicable if the changes to # the script since generation of the corresponding notebook only affect # markdown cells. 
# Run as # python updatejnbmd.py import glob import os from jnb import ( py_file_to_string, read_notebook, replace_markdown_cells, same_notebook_code, same_notebook_markdown, ) from py2jn.tools import py_string_to_notebook, write_notebook for src in glob.glob(os.path.join("scripts", "*.py")): dst = os.path.join("notebooks", os.path.splitext(os.path.basename(src))[0] + ".ipynb") if os.path.exists(dst): srcnb = py_string_to_notebook(py_file_to_string(src), nbver=4) dstnb = read_notebook(dst) if not same_notebook_code(srcnb, dstnb): print(f"Non-markup changes in {src}") continue if not same_notebook_markdown(srcnb, dstnb): print(f"Updating markdown in {dst}") replace_markdown_cells(srcnb, dstnb) write_notebook(dstnb, dst) ================================================ FILE: misc/README.rst ================================================ Miscellaneous ============= This directory is a temporary location for content for which there is no obviously more appropriate location: - ``conda``: Scripts intended to facilitate the installation of miniconda and an environment with all SCICO requirements. - ``gpu``: Scripts for debugging and managing JAX use of GPUs. - ``pytest``: Scripts for specialized use of ``pytest``. ================================================ FILE: misc/conda/README.rst ================================================ Conda Installation Scripts ========================== These scripts are intended to facilitate the installation of `miniconda `__ and an environment with all SCICO requirements: - ``install_conda.sh``: Install miniconda - ``make_conda_env.sh``: Create a conda environment with all SCICO requirements For usage details, run the scripts with the ``-h`` flag, e.g. ``./install_conda.sh -h``.
Example Usage ------------- To install miniconda in ``/opt/conda`` do :: ./install_conda.sh -y /opt/conda To create a conda environment called ``scico`` with Python version 3.12 and without GPU support :: ./make_conda_env.sh -y -p 3.12 -e scico To include GPU support, follow the `jax installation instructions `__ after running this script and activating the environment created by it. Caveats ------- These scripts should function correctly out-of-the-box on a standard Linux installation. (If you find that this is not the case, please create a GitHub issue, providing details of the Linux variant and version.) While these scripts are supported under OSX (MacOS), there are some caveats: - Required utilities ``realpath`` and ``gsed`` (GNU sed) must be installed via MacPorts or some other 3rd party package management system. - Installation of jaxlib with GPU capabilities is not supported. - While ``make_conda_env.sh`` installs ``matplotlib``, it does not attempt to resolve the `additional complications `_ in using a conda installed matplotlib under OSX. ================================================ FILE: misc/conda/install_conda.sh ================================================ #!/usr/bin/env bash # This script installs miniconda3 in the specified path # # Run with -h flag for usage information URLROOT=https://repo.continuum.io/miniconda/ INSTLINUX=Miniconda3-latest-Linux-x86_64.sh INSTMACOSX=Miniconda3-latest-MacOSX-x86_64.sh SCRIPT=$(basename $0) USAGE=$(cat <<-EOF Usage: $SCRIPT [-h] [-y] install_path [-h] Display usage information [-y] Do not ask for confirmation EOF ) AGREE=no OPTIND=1 while getopts ":hy" opt; do case $opt in h) echo "$USAGE"; exit 0;; y) AGREE=yes;; \?) echo "Error: invalid option -$OPTARG" >&2 echo "$USAGE" >&2 exit 1 ;; esac done shift $((OPTIND-1)) if [ ! 
$# -eq 1 ] ; then echo "Error: one positional argument required" >&2 echo "$USAGE" >&2 exit 1 fi OS=$(uname -a | cut -d ' ' -f 1) case "$OS" in Linux) SOURCEURL=$URLROOT$INSTLINUX;; Darwin) SOURCEURL=$URLROOT$INSTMACOSX;; *) echo "Error: unsupported operating system $OS" >&2; exit 2;; esac if [ ! "$(which wget 2>/dev/null)" ]; then has_wget=0 else has_wget=1 fi if [ ! "$(which curl 2>/dev/null)" ]; then has_curl=0 else has_curl=1 fi if [ $has_curl -eq 0 ] && [ $has_wget -eq 0 ]; then echo "Error: neither curl nor wget found; at least one required" >&2 exit 3 fi INSTALLROOT=$1 if [ ! -d "$INSTALLROOT" ] || [ ! -w "$INSTALLROOT" ]; then echo "Error: installation root path \"$INSTALLROOT\" is not a directory "\ "or is not writable" >&2 exit 4 fi CONDAHOME=$INSTALLROOT/miniconda3 if [ -d "$CONDAHOME" ]; then echo "Error: miniconda3 installation directory $CONDAHOME already exists"\ >&2 exit 5 fi if [ "$AGREE" == "no" ]; then read -r -p "Confirm conda installation in root path $INSTALLROOT [y/N] "\ CNFRM if [ "$CNFRM" != 'y' ] && [ "$CNFRM" != 'Y' ]; then echo "Cancelling installation" exit 6 fi fi # Get miniconda bash archive and install it if [ $has_wget -eq 1 ]; then wget $SOURCEURL -O /tmp/miniconda.sh elif [ $has_curl -eq 1 ]; then curl -L $SOURCEURL -o /tmp/miniconda.sh fi bash /tmp/miniconda.sh -b -p $CONDAHOME rm -f /tmp/miniconda.sh # Initial conda setup export PATH="$CONDAHOME/bin:$PATH" hash -r conda config --set always_yes yes conda update -q conda conda info -a echo "Add the following to your .bashrc or .bash_aliases file" echo " export CONDAHOME=$CONDAHOME" echo " export PATH=\$PATH:\$CONDAHOME/bin" exit 0 ================================================ FILE: misc/conda/make_conda_env.sh ================================================ #!/usr/bin/env bash # This script installs a conda environment with all required and # optional scico dependencies. The user is assumed to have write # permission for the conda installation. 
It should function correctly # under both Linux and OSX, but note that there are some additional # complications in using a conda installed matplotlib under OSX # https://matplotlib.org/faq/osx_framework.html # that are not addressed, and that installation of jaxlib with GPU # capabilities is not supported under OSX. Note also that additional # utilities realpath and gsed (gnu sed), available from MacPorts, are # required to run this script under OSX. # # Run with -h flag for usage information set -e # exit when any command fails if [ "$(cut -d '.' -f 1 <<< "$BASH_VERSION")" -lt "4" ]; then echo "Error: this script requires bash version 4 or later" >&2 exit 1 fi SCRIPT=$(basename $0) REPOPATH=$(realpath $(dirname $0)) USAGE=$(cat <<-EOF Usage: $SCRIPT [-h] [-y] [-g] [-p python_version] [-e env_name] [-h] Display usage information [-v] Verbose operation [-t] Display actions that would be taken but do nothing [-y] Do not ask for confirmation [-p python_version] Specify Python version (e.g. 3.12) [-e env_name] Specify conda environment name EOF ) AGREE=no VERBOSE=no TEST=no PYVER="3.12" ENVNM=py$(echo $PYVER | sed -e 's/\.//g') # Project requirements files REQUIRE=$(cat <<-EOF requirements.txt dev_requirements.txt docs/docs_requirements.txt examples/examples_requirements.txt examples/notebooks_requirements.txt EOF ) # Requirements that cannot be installed via conda (i.e. have to use pip) NOCONDA=$(cat <<-EOF flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train] EOF ) OPTIND=1 while getopts ":hvtyp:e:" opt; do case $opt in p|e) if [ -z "$OPTARG" ] || [ "${OPTARG:0:1}" = "-" ] ; then echo "Error: option -$opt requires an argument" >&2 echo "$USAGE" >&2 exit 2 fi ;;& h) echo "$USAGE"; exit 0;; t) VERBOSE=yes;TEST=yes;; v) VERBOSE=yes;; y) AGREE=yes;; p) PYVER=$OPTARG;; e) ENVNM=$OPTARG;; :) echo "Error: option -$OPTARG requires an argument" >&2 echo "$USAGE" >&2 exit 2 ;; \?) 
echo "Error: invalid option -$OPTARG" >&2 echo "$USAGE" >&2 exit 2 ;; esac done shift $((OPTIND-1)) if [ ! $# -eq 0 ] ; then echo "Error: no positional arguments" >&2 echo "$USAGE" >&2 exit 2 fi if [ ! "$(which conda 2>/dev/null)" ]; then echo "Error: conda command required but not found" >&2 exit 3 fi # Not available on BSD systems such as OSX: install via MacPorts etc. if [ ! "$(which realpath 2>/dev/null)" ]; then echo "Error: realpath command required but not found" >&2 exit 4 fi # Ensure that a C compiler is available; required for installing svmbir # On debian/ubuntu linux systems, install package build-essential if [ -z "$CC" ] && [ ! "$(which gcc 2>/dev/null)" ]; then echo "Error: gcc command not found and CC environment variable not set" echo " set CC to the path of your C compiler, or install gcc." echo " On debian/ubuntu, you may need to do" echo " sudo apt install build-essential" exit 5 fi OS=$(uname -a | cut -d ' ' -f 1) case "$OS" in Linux) SOURCEURL=$URLROOT$INSTLINUX; SED="sed";; Darwin) SOURCEURL=$URLROOT$INSTMACOSX; SED="gsed";; *) echo "Error: unsupported operating system $OS" >&2; exit 6;; esac if [ "$OS" == "Darwin" ] && [ "$GPU" == yes ]; then echo "Error: GPU-enabled jaxlib installation not supported under OSX" >&2 exit 7 fi if [ "$OS" == "Darwin" ]; then if [ ! 
"$(which gsed 2>/dev/null)" ]; then echo "Error: gsed command required but not found" >&2 exit 8 fi fi JLVER=$($SED -n 's/^jaxlib>=.*<=\([0-9\.]*\).*/\1/p' \ $REPOPATH/../../requirements.txt) JXVER=$($SED -n 's/^jax>=.*<=\([0-9\.]*\).*/\1/p' \ $REPOPATH/../../requirements.txt) # Construct merged list of all requirements if [ "$OS" == "Darwin" ]; then ALLREQUIRE=$(/usr/bin/mktemp -t condaenv) else ALLREQUIRE=$(mktemp -t condaenv_XXXXXX.txt) fi for req in $REQUIRE; do pthreq="$REPOPATH/../../$req" cat $pthreq >> $ALLREQUIRE done # Construct filtered list of requirements: sort, remove duplicates, and # remove requirements that require special handling if [ "$OS" == "Darwin" ]; then FLTREQUIRE=$(mktemp -t condaenv) else FLTREQUIRE=$(mktemp -t condaenv_XXXXXX.txt) fi # Filter the list of requirements; sed patterns are for # 1st: escape >,<,| characters with a backslash # 2nd: remove comments in requirements file # 3rd: remove recursive include (-r) lines and packages that require # special handling, e.g. 
jaxlib sort $ALLREQUIRE | uniq | $SED -E 's/(>|<|\|)/\\\1/g' \ | $SED -E 's/\#.*$//g' \ | $SED -E '/^-r.*|^jaxlib.*|^jax.*/d' > $FLTREQUIRE # Remove requirements that cannot be installed via conda PIPREQ="" for nc in $NOCONDA; do # Escape [ and ] for use in regex nc=$(echo $nc | $SED -E 's/(\[|\])/\\\1/g') # Add package to pip package list PIPREQ="$PIPREQ "$(grep "$nc" $FLTREQUIRE | $SED 's/\\//g') # Remove package $nc from conda package list $SED -i "/^$nc.*\$/d" $FLTREQUIRE done # Get list of requirements to be installed via conda CONDAREQ=$(cat $FLTREQUIRE | xargs) # Determine the conda installation root here, before it is reported in the # verbose output below and used to locate the new environment directory CONDAHOME=$(conda info --base) if [ "$VERBOSE" == "yes" ]; then echo "Create python $PYVER environment $ENVNM in conda installation" echo " $CONDAHOME" echo "Packages to be installed via conda:" echo " $CONDAREQ" | fmt -w 79 echo "Packages to be installed via pip:" echo " jaxlib==$JLVER jax==$JXVER $PIPREQ" | fmt -w 79 if [ "$TEST" == "yes" ]; then exit 0 fi fi ENVDIR=$CONDAHOME/envs/$ENVNM if [ -d "$ENVDIR" ]; then echo "Error: environment $ENVNM already exists" exit 9 fi if [ "$AGREE" == "no" ]; then RSTR="Confirm creation of conda environment $ENVNM with Python $PYVER" RSTR="$RSTR [y/N] " read -r -p "$RSTR" CNFRM if [ "$CNFRM" != 'y' ] && [ "$CNFRM" != 'Y' ]; then echo "Cancelling environment creation" exit 10 fi else echo "Creating conda environment $ENVNM with Python $PYVER" fi if [ "$AGREE" == "yes" ]; then CONDA_FLAGS="-y" else CONDA_FLAGS="" fi # Update conda, create new environment, and activate it conda update $CONDA_FLAGS -n base conda conda create $CONDA_FLAGS -n $ENVNM python=$PYVER # See https://stackoverflow.com/a/56155771/1666357 eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init` conda activate $ENVNM # Q: why not `source activate`?
A: not always in the path # Add conda-forge channel conda config --append channels conda-forge # Install required conda packages (and extra useful packages) conda install $CONDA_FLAGS $CONDAREQ ipython # Utility ffmpeg is required by imageio for reading mp4 video files # it can also be installed via the system package manager, .e.g. # sudo apt install ffmpeg if [ "$(which ffmpeg)" = '' ]; then conda install $CONDA_FLAGS ffmpeg fi # Install jaxlib and jax pip install --upgrade jaxlib==$JLVER jax==$JXVER # Install other packages that require installation via pip pip install $PIPREQ # Warn if libopenblas-dev not installed on debian/ubuntu if [ "$(which dpkg 2>/dev/null)" ]; then if [ ! "$(dpkg -s libopenblas-dev 2>/dev/null)" ]; then echo "Warning (debian/ubuntu): package libopenblas-dev," echo "which is required by bm3d, does not appear to be" echo "installed; install using the command" echo " sudo apt install libopenblas-dev" fi fi echo echo "Activate the conda environment with the command" echo " conda activate $ENVNM" echo "The environment can be deactivated with the command" echo " conda deactivate" echo echo "JAX installed without GPU support. To enable GPU support, install a" echo "version of jaxlib with CUDA support following the instructions at" echo " https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu" echo "In most cases this just requires the command" echo " pip install -U \"jax[cuda12]\"" echo echo "ASTRA Toolbox installed without GPU support if this script was" echo "run on a host without CUDA drivers installed. To enable GPU support," echo "uninstall and then reinstall the astra-toolbox conda package on a" echo "host with CUDA drivers installed." 
exit 0 ================================================ FILE: misc/gpu/README.rst ================================================ GPU Utility Scripts =================== These scripts are intended for debugging and managing JAX use of GPUs: - ``availgpu.py``: Automatically recommend a setting of the ``CUDA_VISIBLE_DEVICES`` environment variable that excludes GPUs that are already in use. - ``envinfo.py``: An aid to debugging JAX GPU access. ================================================ FILE: misc/gpu/availgpu.py ================================================ #!/usr/bin/env python # Determine which GPUs available for use and recommend CUDA_VISIBLE_DEVICES # setting if any are already in use. # pylint: disable=missing-module-docstring import GPUtil print("GPU utlizitation") GPUtil.showUtilization() devIDs = GPUtil.getAvailable( order="first", limit=65536, maxLoad=0.1, maxMemory=0.1, includeNan=False ) Ngpu = len(GPUtil.getGPUs()) if len(devIDs) == Ngpu: print(f"All {Ngpu} GPUs available for use") else: print(f"Only {len(devIDs)} of {Ngpu} GPUs available for use") print("To avoid attempting to use GPUs already in use, run the command") print(f" export CUDA_VISIBLE_DEVICES={','.join(map(str, devIDs))}") ================================================ FILE: misc/gpu/envinfo.py ================================================ #!/usr/bin/env python # Print host and environment information. Useful for determining whether # a Python host has available GPUs, and if so, whether the JAX installation # is able to make use of them. 
# pylint: disable=missing-module-docstring import sys missing = [] try: import psutil have_psutil = True except ImportError: have_psutil = False missing.append("psutil") try: import GPUtil have_gputil = True except ImportError: have_gputil = False missing.append("gputil") import jax import jaxlib try: import scico have_scico = True except ImportError: scico = None have_scico = False missing.append("scico") if missing: print("Some output not available due to missing modules: " + ", ".join(missing)) pyver = ".".join([f"{v}" for v in sys.version_info[0:3]]) print(f"Python version: {pyver}") print("Packages:") packages = [jaxlib, jax, scico] for p in packages: if hasattr(p, "__version__") and hasattr(p, "__name__"): v = getattr(p, "__version__") n = getattr(p, "__name__") print(f" {n:15s} {v}") if have_psutil: print(f"Number of CPU cores: {psutil.cpu_count(logical=False)}") if have_gputil: if GPUtil.getAvailable(): print("GPUs:") for gpu in GPUtil.getGPUs(): print(f" {gpu.id:2d} {gpu.name:10s} {gpu.memoryTotal} kB RAM") else: print("No GPUs available") sys.stderr = open("/dev/null") # suppress annoying jax warning numdev = jax.device_count() if jax.devices()[0].device_kind == "cpu": print("No GPUs available to JAX (JAX device is CPU)") else: print(f"Number of GPUs available to JAX: {jax.device_count()}") ================================================ FILE: misc/pytest/README.rst ================================================ Specialized Pytest Usage ======================== These scripts support specialized ``pytest`` usage: - ``pytest_cov.sh``: This script runs ``scico`` unit tests using the ``pytest-cov`` plugin for test coverage analysis. - ``pytest_fast.sh``: This script runs ``pytest`` tests in parallel using the ``pytest-xdist`` plugin. Some tests (those that do not function correctly when run in parallel) are run separately. - ``pytest_time.sh``: This script runs each ``scico`` unit test module and lists them all in order of decreasing run time. 
All of these scripts must be run from the repository root directory.


================================================
FILE: misc/pytest/pytest_cov.sh
================================================
#!/usr/bin/env bash

# This script runs scico unit tests using the pytest-cov plugin for test
# coverage analysis. It must be run from the repository root directory.

plugin="pytest-cov"
if ! pytest -VV | grep -o $plugin > /dev/null; then
    echo Required pytest plugin $plugin not installed
    exit 1
fi

pytest --cov=scico --cov-report html

echo "To view the report, open htmlcov/index.html in a web browser."

exit 0


================================================
FILE: misc/pytest/pytest_fast.sh
================================================
#!/usr/bin/env bash

# This script runs pytest tests in parallel using the pytest-xdist plugin.
# Some tests that do not function correctly when run in parallel are run
# separately. It must be run from the repository root directory.

plugin="pytest-xdist"
if ! pytest -VV | grep -o $plugin > /dev/null; then
    echo Required pytest plugin $plugin not installed
    exit 1
fi

pytest --deselect scico/test/test_ray_tune.py \
       --deselect scico/test/functional/test_core.py -x -n 2
pytest -x scico/test/test_ray_tune.py scico/test/functional/test_core.py

exit 0


================================================
FILE: misc/pytest/pytest_time.sh
================================================
#!/usr/bin/env bash

# This script runs each scico unit test module and lists them all in order
# of decreasing run time. It must be run from the repository root directory.

# Use a uniquely named temporary file rather than a predictable name
tmp=$(mktemp /tmp/pytest_time.XXXXXX)
for f in $(find scico/test -name "test_*.py"); do
    tstr=$(/usr/bin/time -p pytest -qqq --disable-warnings $f 2>&1 | tail -4)
    # Warning does not work in OSX bash
    if grep -q "Command exited with non-zero status" <<<"$tstr"; then
        echo "WARNING: test failure in $f" >&2
    fi
    t=$(grep "^real" <<<"$tstr" | grep -o -E "[0-9\.]*$")
    # Quote arguments so that an empty time value does not shift $f into
    # the numeric format slot
    printf "%6.2f %s\n" "$t" "$f" >> "$tmp"
done
sort -r -n "$tmp"
rm -f "$tmp"

exit 0


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 100
target-version = ['py312']
include = '\.pyi?$'
exclude = '''

(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | _build
    | buck-out
    | build
    | dist
  )/
  | foo.py           # also separately exclude a file named foo.py in
                     # the root of the project
)
'''

[tool.isort]
profile = "black"
multi_line_output = 3
known_jax = ['jax']
known_numpy = ['numpy']
sections = ['FUTURE', 'STDLIB', 'NUMPY', 'JAX', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER']
src_paths = ["scico", "examples/scripts"]

# mypy only reads its settings from a [tool.mypy] table in pyproject.toml
[tool.mypy]
python_version = "3.12"
disable_error_code = ['attr-defined']


================================================
FILE: pytest.ini
================================================
[pytest]
testpaths = scico/test docs
addopts = --doctest-glob="*rst"
doctest_optionflags = NORMALIZE_WHITESPACE NUMBER
filterwarnings =
    ignore::DeprecationWarning:.*pkg_resources.*
    ignore::DeprecationWarning:.*hyperopt.*
    ignore::DeprecationWarning:.*flax.*
    ignore::DeprecationWarning:.*.tensorboardx.*
    ignore::DeprecationWarning:.*xdesign.*
    ignore:.*pkg_resources.*:DeprecationWarning
    ignore:.*imp module.*:DeprecationWarning


================================================
FILE: requirements.txt
================================================
typing_extensions
numpy>=2.0
scipy>=1.13
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.5.0,<=0.10.0
jax>=0.5.0,<=0.10.0
flax>=0.8.0,<=0.12.7
pyabel>=0.9.1


================================================
FILE: scico/__init__.py
================================================
# Copyright (C) 2021-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Scientific Computational Imaging COde (SCICO) is a Python package for
solving the inverse problems that arise in scientific imaging applications.
"""

__version__ = "0.0.8.dev0"

import logging
import sys

# isort: off
# isort is disabled here so that this logger setup stays ahead of the jax
# imports below (presumably so the filter is in place before jax can emit
# the warning).

# Suppress jax device warning. See https://github.com/google/jax/issues/6805
# Note: a logging.Filter constructed with a string filters on logger *name*
# (allowing only records from a logger of that name or its children), not on
# message text. Since records from "jax._src.xla_bridge" do not come from a
# logger with the given name, this appears to drop every record from that
# logger, not only the GPU/TPU fallback warning.
# NOTE(review): confirm this broader suppression is intended.
logging.getLogger("jax._src.xla_bridge").addFilter(  # jax 0.4.8 and later
    logging.Filter("No GPU/TPU found, falling back to CPU.")
)

# isort: on

import jax
from jax import custom_jvp, custom_vjp, hessian, jacfwd, jvp, linearize, vjp

import jaxlib

from . import numpy
from ._core import *
from ._core import __all__ as _core_all

# See https://github.com/google/jax/issues/19444
jax.config.update("jax_default_matmul_precision", "highest")

# Public API: the scico wrappers from _core plus selected jax transforms
# re-exported unchanged
__all__ = _core_all + [
    "custom_jvp",
    "custom_vjp",
    "hessian",
    "jacfwd",
    "jvp",
    "linearize",
    "vjp",
]

# Make imported items in __all__ appear to originate in this top-level scico
# module rather than in jax or scico._core (presumably for documentation)
for name in __all__:
    getattr(sys.modules[__name__], name).__module__ = __name__


================================================
FILE: scico/_core.py
================================================
# Copyright (C) 2020-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Extensions of core jax functions, including tools for automatic differentiation and shape evaluation.""" import sys from typing import Any, Callable, Optional, Sequence, Tuple, Union import jax from jax.tree_util import tree_map import scico.numpy import scico.numpy.util import scico.util __all__ = [ "cvjp", "eval_shape", "grad", "jacrev", "linear_adjoint", "linear_transpose", "value_and_grad", ] def _append_jax_docs(fn, jaxfn=None): """Append the jax function docs. Given wrapper function `fn`, concatenate its docstring with the docstring of the wrapped jax function. """ name = fn.__name__ if jaxfn is None: jaxfn = getattr(jax, name) doc = " " + fn.__doc__.replace("\n ", "\n ") # deal with indentation differences jaxdoc = "\n".join(jaxfn.__doc__.split("\n")[2:]) # strip initial lines return doc + f"\n Docstring for :func:`jax.{name}`:\n\n" + jaxdoc def _convert_ba_dts(arg: Any) -> Any: """Convert a ShapeDtypeStruct with nested shape into a BlockArray of ShapeDtypeStruct. """ if isinstance(arg, jax.ShapeDtypeStruct) and scico.numpy.util.is_nested(arg.shape): return scico.numpy.BlockArray( [jax.ShapeDtypeStruct(blk_shape, dtype=arg.dtype) for blk_shape in arg.shape] ) else: return arg def eval_shape(fun: Callable, *args, **kwargs) -> Any: """Compute the shape and dtype of a function without executing it. Compute the shape and dtype of a function without executing it, via a call to :func:`jax.eval_shape`, with ``args`` and ``kwargs`` mapped to handle :class:`jax.ShapeDtypeStruct` objects with nested shapes corresponding to :class:`.BlockArray` objects. """ mapped_args = jax.tree_util.tree_map(_convert_ba_dts, args) mapped_kwargs = jax.tree_util.tree_map(_convert_ba_dts, kwargs) return jax.eval_shape(fun, *mapped_args, **mapped_kwargs) def grad( fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, ) -> Callable: """Create a function that evaluates the gradient of `fun`. 
:func:`scico.grad` differs from :func:`jax.grad` in that the output is conjugated. """ jax_grad = jax.grad( fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int ) def conjugated_grad_aux(*args, **kwargs): jg, aux = jax_grad(*args, **kwargs) return tree_map(jax.numpy.conj, jg), aux def conjugated_grad(*args, **kwargs): jg = jax_grad(*args, **kwargs) return tree_map(jax.numpy.conj, jg) return conjugated_grad_aux if has_aux else conjugated_grad def value_and_grad( fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, ) -> Callable[..., Tuple[Any, Any]]: """Create a function that evaluates both `fun` and its gradient. :func:`scico.value_and_grad` differs from :func:`jax.value_and_grad` in that the gradient is conjugated. """ jax_val_grad = jax.value_and_grad( fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int ) def conjugated_value_and_grad_aux(*args, **kwargs): (value, aux), jg = jax_val_grad(*args, **kwargs) conj_grad = tree_map(jax.numpy.conj, jg) return (value, aux), conj_grad def conjugated_value_and_grad(*args, **kwargs): value, jax_grad = jax_val_grad(*args, **kwargs) conj_grad = tree_map(jax.numpy.conj, jax_grad) return value, conj_grad return conjugated_value_and_grad_aux if has_aux else conjugated_value_and_grad def linear_transpose(fun: Callable, *primals) -> Callable: """Transpose a function that is guaranteed to be linear. :func:`scico.linear_adjoint` differs from :func:`jax.linear_transpose` in that it correctly handles primals consisting of :class:`jax.ShapeDtypeStruct` objects with nested shapes, i.e. corresponding to :class:`.BlockArray` shapes. """ mapped_primals = jax.tree_util.tree_map(_convert_ba_dts, primals) return jax.linear_transpose(fun, *mapped_primals) def linear_adjoint(fun: Callable, *primals) -> Callable: """Conjugate transpose a function that is guaranteed to be linear. 
:func:`scico.linear_adjoint` differs from :func:`jax.linear_transpose` for complex inputs in that the conjugate transpose (adjoint) of `fun` is returned. :func:`scico.linear_adjoint` is identical to :func:`jax.linear_transpose` for real-valued primals. """ def conj_fun(*primals): conj_primals = tree_map(jax.numpy.conj, primals) return tree_map(jax.numpy.conj, fun(*(conj_primals))) return linear_transpose(conj_fun, *primals) def jacrev( fun: Callable, argnums: Union[int, Sequence[int]] = 0, holomorphic: bool = False, allow_int: bool = False, ) -> Callable: """Jacobian of `fun` evaluated row-by-row using reverse-mode AD. :func:`scico.jacrev` differs from :func:`jax.jacrev` in that the output is conjugated. """ jax_jacrev = jax.jacrev(fun=fun, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int) def conjugated_jacrev(*args, **kwargs): tmp = jax_jacrev(*args, **kwargs) return tree_map(jax.numpy.conj, tmp) return conjugated_jacrev def cvjp(fun: Callable, *primals, jidx: Optional[int] = None) -> Tuple[Tuple[Any, ...], Callable]: r"""Compute a vector-Jacobian product with conjugate transpose. Compute the product :math:`[J(\mb{x})]^H \mb{v}` where :math:`[J(\mb{x})]` is the Jacobian of function `fun` evaluated at :math:`\mb{x}`. Instead of directly evaluating the product, a function is returned that takes :math:`\mb{v}` as an argument. If `fun` has multiple positional parameters, the Jacobian can be taken with respect to only one of them by setting the `jidx` parameter of this function to the positional index of that parameter. Args: fun: Function for which the Jacobian is implicitly computed. primals: Sequence of values at which the Jacobian is evaluated, with length equal to the number of positional arguments of `fun`. jidx: Index of the positional parameter of `fun` with respect to which the Jacobian is taken. Returns: A pair `(primals_out, conj_vjp)` where `primals_out` is the output of `fun` evaluated at `primals`, i.e. 
`primals_out = fun(*primals)`, and `conj_vjp` is a function that computes the product of the conjugate (Hermitian) transpose of the Jacobian of `fun` and its argument. If the `jidx` parameter is an integer, then the Jacobian is only taken with respect to the coresponding positional parameter of `fun`. """ if jidx is None: primals_out, fun_vjp = jax.vjp(fun, *primals) else: fixidx = tuple(range(0, jidx)) + tuple(range(jidx + 1, len(primals))) fixprm = primals[0:jidx] + primals[jidx + 1 :] pfun = scico.util.partial(fun, fixidx, *fixprm) primals_out, fun_vjp = jax.vjp(pfun, primals[jidx]) def conj_vjp(tangent): return jax.tree_util.tree_map(jax.numpy.conj, fun_vjp(tangent.conj())) return primals_out, conj_vjp # Append docstring from original jax function for name in __all__: if name == "cvjp": continue func = getattr(sys.modules[__name__], name) jaxfn = jax.linear_transpose if name == "linear_adjoint" else None func.__doc__ = _append_jax_docs(func, jaxfn=jaxfn) ================================================ FILE: scico/_version.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Support functions for determining the package version.""" import os import re from ast import parse from subprocess import PIPE, Popen from typing import Any, Optional, Tuple, Union def root_init_path() -> str: # pragma: no cover """Get the path to the package root `__init__.py` file. Returns: Path to the package root `__init__.py` file. """ return os.path.join(os.path.dirname(__file__), "__init__.py") def variable_assign_value(path: str, var: str) -> Any: """Get variable initialization value from a Python file. Args: path: Path of Python file. var: Name of variable. Returns: Value to which variable `var` is initialized. 
Raises: RuntimeError: If the statement initializing variable `var` is not found. """ with open(path) as f: try: # See https://stackoverflow.com/a/30471662 value_obj = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value # type: ignore value = value_obj.value # type: ignore except StopIteration: raise RuntimeError(f"Could not find initialization of variable {var}") return value def init_variable_assign_value(var: str) -> Any: # pragma: no cover """Get variable initialization value from package `__init__.py` file. Args: var: Name of variable. Returns: Value to which variable `var` is initialized. Raises: RuntimeError: If the statement initializing variable `var` is not found. """ return variable_assign_value(root_init_path(), var) def current_git_hash() -> Optional[str]: # nosec pragma: no cover """Get current short git hash. Returns: Short git hash of current commit, or ``None`` if no git repo found. """ process = Popen(["git", "rev-parse", "--short", "HEAD"], shell=False, stdout=PIPE, stderr=PIPE) git_hash: Optional[str] = process.communicate()[0].strip().decode("utf-8") if git_hash == "": git_hash = None return git_hash def package_version(split: bool = False) -> Union[str, Tuple[str, str]]: # pragma: no cover """Get current package version. Args: split: Flag indicating whether to return the package version as a single string or split into a tuple of components. Returns: Package version string or tuple of strings. 
""" version = init_variable_assign_value("__version__") # don't extend purely numeric version numbers, possibly ending with post if re.match(r"^[0-9\.]+(post[0-9]+)?$", version): git_hash = None else: git_hash = current_git_hash() if git_hash: git_hash = "+" + git_hash else: git_hash = "" if split: version = (version, git_hash) else: version = version + git_hash return version ================================================ FILE: scico/data/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Data files for usage examples.""" import os.path from typing import Optional from imageio.v3 import imread import scico.numpy as snp __all__ = ["kodim23"] def _imread(filename: str, path: Optional[str] = None, asfloat: bool = False) -> snp.Array: """Read an image from disk. Args: filename: Base filename (i.e. without path) of image file. path: Path to directory containing the image file. asfloat: Flag indicating whether the returned image should be converted to :attr:`~numpy.float32` dtype with a range [0, 1]. Returns: Image data array. """ if path is None: path = os.path.join(os.path.dirname(__file__), "examples") im = imread(os.path.join(path, filename)) if asfloat: im = im.astype(snp.float32) / 255.0 return im def kodim23(asfloat: bool = False) -> snp.Array: """Return the `kodim23` test image. Args: asfloat: Flag indicating whether the returned image should be converted to :attr:`~numpy.float32` dtype with a range [0, 1]. Returns: Image data array. """ return _imread("kodim23.png", asfloat=asfloat) def _flax_data_path(filename: str) -> str: """Get the full filename of a flax data file. Args: filename: Base filename (i.e. without path) of data file. Returns: Full filename, with path, of data file. 
""" return os.path.join(os.path.dirname(__file__), "flax", filename) ================================================ FILE: scico/denoiser.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Interfaces to standard denoisers.""" from typing import Any, Optional, Union import numpy as np import jax try: import bm3d as tubm3d except ImportError: have_bm3d = False BM3DProfile = Any else: have_bm3d = True from bm3d.profiles import BM3DProfile # type: ignore try: import bm4d as tubm4d except ImportError: have_bm4d = False BM4DProfile = Any else: have_bm4d = True from bm4d.profiles import BM4DProfile # type: ignore import scico.numpy as snp from scico.data import _flax_data_path from scico.flax import DnCNNNet, FlaxMap, load_variables def bm3d(x: snp.Array, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = "np"): r"""An interface to the BM3D denoiser :cite:`dabov-2008-image`. BM3D denoising is performed using the `code `__ released with :cite:`makinen-2019-exact`. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. Args: x: Input image. Expected to be a 2D array (gray-scale denoising) or 3D array (color denoising). Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. For color denoising, the color channel is assumed to be in the last non-singleton dimension. sigma: Noise parameter. is_rgb: Flag indicating use of BM3D with a color transform. Default: ``False``. profile: Parameter configuration for BM3D. Returns: Denoised output. 
""" if not have_bm3d: raise RuntimeError("Package bm3d is required for use of this function.") if is_rgb is True: def bm3d_eval(x: snp.Array, sigma: float): return tubm3d.bm3d_rgb(x, sigma, profile=profile) else: def bm3d_eval(x: snp.Array, sigma: float): return tubm3d.bm3d(x, sigma, profile=profile) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"BM3D requires real-valued inputs, got {x.dtype}.") # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "BM3D requires two-dimensional or three dimensional inputs; got ndim = {x.ndim}." ) # This check is also performed inside the BM3D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM3D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) if profile == "np" and np.min(x.shape[:2]) < 8: raise ValueError( "Two leading dimensions of input cannot be smaller than block size " f"(8); got image size = {x.shape}." ) if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when " " the additional axes are singletons." ) y = jax.pure_callback( lambda args: bm3d_eval(*args).astype(x.dtype), jax.ShapeDtypeStruct(x.shape, x.dtype), (x, sigma), ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) return y def bm4d(x: snp.Array, sigma: float, profile: Union[BM4DProfile, str] = "np"): r"""An interface to the BM4D denoiser :cite:`maggioni-2012-nonlocal`. BM4D denoising is performed using the `code `__ released by the authors of :cite:`maggioni-2012-nonlocal`. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. Args: x: Input image. Expected to be a 3D array. 
Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. sigma: Noise parameter. profile: Parameter configuration for BM4D. Returns: Denoised output. """ if not have_bm4d: raise RuntimeError("Package bm4d is required for use of this function.") def bm4d_eval(x: snp.Array, sigma: float): return tubm4d.bm4d(x, sigma, profile=profile) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"BM4D requires real-valued inputs, got {x.dtype}.") # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 3: raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}.") # This check is also performed inside the BM4D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM4D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) if profile == "np" and np.min(x.shape[:3]) < 8: raise ValueError( "Three leading dimensions of input cannot be smaller than block size " f"(8); got image size = {x.shape}." ) if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when " " the additional axes are singletons." ) y = jax.pure_callback( lambda args: bm4d_eval(*args).astype(x.dtype), jax.ShapeDtypeStruct(x.shape, x.dtype), (x, sigma), ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) return y class DnCNN(FlaxMap): """Flax implementation of the DnCNN denoiser. A flax implementation of the DnCNN denoiser :cite:`zhang-2017-dncnn`. Note that :class:`.DnCNNNet` represents an untrained form of the generic DnCNN CNN structure, while this class represents a trained form with six or seventeen layers. The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not have a noise level input. 
This implementation of DnCNN also supports a custom variant that includes a noise standard deviation input, `sigma`, which is included in the network as an additional channel consisting of a constant array with value `sigma`. This network was trained with image data on the range [0, 1], and with noise standard deviations ranging from 0.0 to 0.2. It is worth noting that DRUNet :cite:`zhang-2021-plug`, another recent approach to including a noise level input in a CNN denoiser, is based on a substantially different network architecture. """ def __init__(self, variant: str = "6M"): """ Note that all DnCNN models are trained for single-channel image input. Multi-channel input is supported via independent denoising of each channel. Input images are expected to have pixel values in the range [0, 1]. Args: variant: Identify the DnCNN model to be used. Options are '6L', '6M' (default), '6H', '6N', '17L', '17M', '17H', and '17N', where the integer indicates the number of layers in the network, and the postfix indicates the training noise standard deviation (with respect to data in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10, H (high) = 0.20, or N indicating that a noise standard deviation input, `sigma`, is available. """ self.variant = variant if variant not in ["6L", "6M", "6H", "17L", "17M", "17H", "6N", "17N"]: raise ValueError(f"Invalid value {variant} of parameter variant.") if variant[0] == "6": nlayer = 6 else: nlayer = 17 channels = 2 if variant in ["6N", "17N"] else 1 if variant in ["6N", "17N"]: self.is_blind = False else: self.is_blind = True model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32) variables = load_variables(_flax_data_path("dncnn%s.mpk" % variant)) super().__init__(model, variables) def __call__(self, x: snp.Array, sigma: Optional[float] = None) -> snp.Array: r"""Apply DnCNN denoiser. Args: x: Input array. sigma: Noise standard deviation (for variants `6N` and `17N`). Returns: Denoised output. 
""" if sigma is not None and self.is_blind: raise ValueError( "A non-default value for the sigma parameter may " "only be specified when the variant is 6N or 17N" f"; got variant = {self.variant}." ) if sigma is None and not self.is_blind: raise ValueError( "A float value must be specified for the sigma " "parameter when the variant is 6N or 17N." ) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}.") if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "DnCNN requires two-dimensional (M, N) or three-dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}." ) x_in_shape = x.shape if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when" " the additional axes are singletons." ) if x.ndim == 3: y = snp.swapaxes(x, 0, -1) if sigma is not None: y = snp.stack([y, snp.ones_like(y) * sigma], -1) else: y = y[..., np.newaxis] # swap channel axis to batch axis and add singleton axis at end y = super().__call__(y) # drop singleton axis and swap axes back to original positions y = snp.swapaxes(y[..., 0], 0, -1) else: if sigma is not None: x = snp.stack([x, snp.ones_like(x) * sigma], -1) x = x[np.newaxis, ...] y = super().__call__(x) if sigma is not None: y = y[0, ..., 0] y = y.reshape(x_in_shape) return y ================================================ FILE: scico/diagnostics.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Diagnostic information for iterative solvers.""" import re import warnings from collections import OrderedDict, namedtuple from typing import List, NamedTuple, Optional, Tuple, Union from scico.numpy.util import is_array class IterationStats: """Display and record iterative algorithms statistics. Display and record statistics related to convergence of iterative algorithms. """ def __init__( self, fields: OrderedDict, ident: Optional[dict] = None, display: bool = False, period: int = 1, shift_cycles: bool = True, overwrite: bool = True, colsep: int = 2, ): """ The `fields` parameter represents an OrderedDict (to ensure that field order is retained) specifying field names for each value to be inserted and a corresponding format string for when it is displayed. When inserted values are printed in tabular form, the field lengths are taken as the maxima of the header string lengths and the field lengths embedded in the format strings (if specified). For best results, the field lengths should be manually specified based on knowledge of the ranges of values that may be encountered. For example, for a '%e' format string, the specified field length should be at least the precision (e.g. '%.2e' specifies a precision of 2 places) plus 6 when only positive values may encountered, and plus 7 when negative values may be encountered. Args: fields: A dictionary associating field names with format strings for displaying the corresponding values. ident: A dictionary associating field names. with corresponding valid identifiers for use within the namedtuple used to record results. Defaults to ``None``. display: Flag indicating whether results should be printed to stdout. Defaults to ``False``. period: Only display one result in every cycle of length `period`. shift_cycles: If ``True``, apply an offset to the iteration count so that display cycles end at 0, `period` - 1, etc. Otherwise, cycles end at `period`, 2 * `period`, etc. 
overwrite: If ``True``, display all results, but each one overwrites the next, except for one result per cycle. colsep: Number of spaces seperating fields in displayed tables. Defaults to 2. Raises: TypeError: If the `fields` parameter is not a dict. """ # Parameter fields must be specified as an OrderedDict to ensure # that field order is retained if not isinstance(fields, dict): raise TypeError("Argument 'fields' must be an instance of dict.") # Subsampling rate of results that are to be displayed self.period: int = period # Offset to iteration count for determining start of period self.period_offset = 1 if shift_cycles else 0 # Flag indicating whether to display and overwrite, or not display at all self.overwrite: bool = overwrite # Number of spaces seperating fields in displayed tables self.colsep: int = colsep # Main list of inserted values self.iterations: List = [] # Total length of header string in displayed tables self.headlength: int = 0 # List of field names self.fieldname: List[str] = [] # List of field format strings self.fieldformat: List[str] = [] # List of lengths of each field in displayed tables self.fieldlength: List[int] = [] # Names of fields in namedtuple used to record iteration values self.tuplefields: List[str] = [] # Compile regex for decomposing format strings fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])") # Iterate over field names for name in fields: # Get format string and decompose it using compiled regex fmt = fields[name] fmtmatch = fmre.match(fmt) if not fmtmatch: raise ValueError(f"Format string '{fmt}' could not be parsed.") fmflg, fmlen, fmdot, fmprc, fmtyp = fmtmatch.groups() flen = len(fmt % 0) # Warn if actual formatted length longer than specified field # length, e.g. 
as in "%4e" if fmlen != "" and flen > int(fmlen): warnings.warn( f'Actual length {flen} of format "{fmt}" for field ' f'"{name}" is longer than specified value {fmlen}', stacklevel=2, ) # If the actual formatted length is less than that of the header # string, insert a field length specifier to increase the # length to that of the header string if flen < len(name): fmt = f"%{fmflg}{len(name)}{fmdot}{fmprc}{fmtyp}" flen = len(name) self.fieldname.append(name) self.fieldformat.append(fmt) self.fieldlength.append(flen) self.headlength += flen + colsep # If a distinct identifier is specified for this field, use it # as the namedtuple identifier, otherwise compute it from the # field name if ident is not None and name in ident: self.tuplefields.append(ident[name]) else: # See https://stackoverflow.com/a/3305731 tfnm = re.sub(r"\W+|^(?=\d)", "_", name) if tfnm[0] == "_": tfnm = tfnm[1:] self.tuplefields.append(tfnm) # Decrement head length to account for final colsep added self.headlength -= colsep # Construct namedtuple used to record values self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) # type: ignore # Set up table header string display if requested self.display = display self.disphdr = None if display: self.disphdr = ( (" " * colsep).join( ["%-*s" % (fl, fn) for fl, fn in zip(self.fieldlength, self.fieldname)] ) + "\n" + "-" * self.headlength ) def insert(self, values: Union[List, Tuple]): """Insert a list of values for a single iteration. Args: values: Statistics for a single iteration. 
""" scalar_values = [v.item() if is_array(v) else v for v in values] self.iterations.append(self.IterTuple(*scalar_values)) if self.display: if self.disphdr is not None: print(self.disphdr) self.disphdr = None if self.overwrite: if (len(self.iterations) - self.period_offset) % self.period == 0: end = "\n" else: end = "\r" print((" " * self.colsep).join(self.fieldformat) % values, end=end) else: if (len(self.iterations) - self.period_offset) % self.period == 0: print((" " * self.colsep).join(self.fieldformat) % values) def end(self): """Mark end of iterations. This method should be called at the end of a set of iterations. Its only function is to ensure that the displayed output is left in an appropriate state when overwriting is active with a display period other than unity. """ if ( self.display and self.overwrite and self.period > 1 and (len(self.iterations) - self.period_offset) % self.period ): print() def history(self, transpose: bool = False) -> Union[List[NamedTuple], Tuple[List]]: """Retrieve record of all inserted iterations. Args: transpose: Flag indicating whether results should be returned in "transposed" form, i.e. as a namedtuple of lists rather than a list of namedtuples. Returns: list of namedtuple or namedtuple of lists: Record of all inserted iterations. """ if transpose and self.iterations: return self.IterTuple( *[ [self.iterations[m][n] for m in range(len(self.iterations))] for n in range(len(self.iterations[0])) ] ) return self.iterations ================================================ FILE: scico/examples.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Utility functions used by example scripts.""" import glob import os import tempfile import zipfile from functools import partial from typing import List, Optional, Tuple, Union import numpy as np import jax import imageio.v3 as iio import scico.numpy as snp from scico import random, util from scico.typing import Shape from scipy.io import loadmat from scipy.ndimage import zoom def rgb2gray(rgb: np.ndarray) -> np.ndarray: """Convert an RGB image (or images) to grayscale. Args: rgb: RGB image as Nr x Nc x 3 or Nr x Nc x 3 x K array. Returns: Grayscale image as Nr x Nc or Nr x Nc x K array. """ shape: Union[Tuple[int, int, int], Tuple[int, int, int, int]] if rgb.ndim == 3: shape = (1, 1, 3) else: shape = (1, 1, 3, 1) w = np.array([0.299, 0.587, 0.114], dtype=rgb.dtype).reshape(shape) return np.sum(w * rgb, axis=2) def volume_read(path: str, ext: str = "tif") -> np.ndarray: """Read a 3D volume from a set of files in the specified directory. All files with extension `ext` (i.e. matching glob `*.ext`) in directory `path` are assumed to be image files and are read. The filenames are assumed to be such that their alphanumeric ordering corresponds to their order as volume slices. Args: path: Path to directory containing the image files. ext: Filename extension. Returns: Volume as a 3D array. """ slices = [] for file in sorted(glob.glob(os.path.join(path, "*." + ext))): image = iio.imread(file) slices.append(image) return np.dstack(slices) def get_epfl_deconv_data(channel: int, path: str, verbose: bool = False): # pragma: no cover """Download example data from EPFL Biomedical Imaging Group. Download deconvolution problem data from EPFL Biomedical Imaging Group. The downloaded data is converted to `.npz` format for convenient access via :func:`numpy.load`. The converted data is saved in a file `epfl_big_deconv_.npz` in the directory specified by `path`. Args: channel: Channel number between 0 and 2. path: Directory in which converted data is saved. 
verbose: Flag indicating whether to print status messages. """ # data source URL and filenames data_base_url = "http://bigwww.epfl.ch/deconvolution/bio/" data_zip_files = ["CElegans-CY3.zip", "CElegans-DAPI.zip", "CElegans-FITC.zip"] psf_zip_files = ["PSF-" + data for data in data_zip_files] # ensure path directory exists if not os.path.isdir(path): raise ValueError(f"Path {path} does not exist or is not a directory.") # create temporary directory temp_dir = tempfile.TemporaryDirectory() # download data and psf files for selected channel into temporary directory for zip_file in (data_zip_files[channel], psf_zip_files[channel]): if verbose: print(f"Downloading {zip_file} from {data_base_url}") data = util.url_get(data_base_url + zip_file) f = open(os.path.join(temp_dir.name, zip_file), "wb") f.write(data.read()) f.close() if verbose: print("Download complete") # unzip downloaded data into temporary directory for zip_file in (data_zip_files[channel], psf_zip_files[channel]): if verbose: print(f"Extracting content from zip file {zip_file}") with zipfile.ZipFile(os.path.join(temp_dir.name, zip_file), "r") as zip_ref: zip_ref.extractall(temp_dir.name) # read unzipped data files into 3D arrays and save as .npz zip_file = data_zip_files[channel] y = volume_read(os.path.join(temp_dir.name, zip_file[:-4])) zip_file = psf_zip_files[channel] psf = volume_read(os.path.join(temp_dir.name, zip_file[:-4])) npz_file = os.path.join(path, f"epfl_big_deconv_{channel}.npz") if verbose: print(f"Saving as {npz_file}") np.savez(npz_file, y=y, psf=psf) def epfl_deconv_data( channel: int, verbose: bool = False, cache_path: Optional[str] = None ) -> Tuple[np.ndarray, np.ndarray]: """Get deconvolution problem data from EPFL Biomedical Imaging Group. If the data has previously been downloaded, it will be retrieved from a local cache. Args: channel: Channel number between 0 and 2. verbose: Flag indicating whether to print status messages. 
cache_path: Directory in which downloaded data is cached. The default is `~/.cache/scico/examples`, where `~` represents the user home directory. Returns: tuple: A tuple (y, psf) containing: - **y** : (np.ndarray): Blurred channel data. - **psf** : (np.ndarray): Channel psf. """ # set default cache path if not specified if cache_path is None: # pragma: no cover cache_path = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples") # create cache directory and download data if not already present npz_file = os.path.join(cache_path, f"epfl_big_deconv_{channel}.npz") if not os.path.isfile(npz_file): # pragma: no cover if not os.path.isdir(cache_path): os.makedirs(cache_path) get_epfl_deconv_data(channel, path=cache_path, verbose=verbose) # load data and return y and psf arrays converted to float32 npz = np.load(npz_file) y = npz["y"].astype(np.float32) psf = npz["psf"].astype(np.float32) return y, psf def get_ucb_diffusercam_data(path: str, verbose: bool = False): # pragma: no cover """Download data from UC Berkeley Waller Lab diffusercam project. Download deconvolution problem data from UC Berkeley Waller Lab diffusercam project. The downloaded data is converted to `.npz` format for convenient access via :func:`numpy.load`. The converted data is saved in a file `ucb_diffcam_data.npz.npz` in the directory specified by `path`. Args: path: Directory in which converted data is saved. verbose: Flag indicating whether to print status messages. 
""" # data source URL, filenames, and request header data_base_url = "https://github.com/Waller-Lab/DiffuserCam/blob/master/example_data/" data_files = ["example_psfs.mat", "example_raw.png"] headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64)", "Referer": data_base_url} # ensure path directory exists if not os.path.isdir(path): raise ValueError(f"Path {path} does not exist or is not a directory.") # create temporary directory temp_dir = tempfile.TemporaryDirectory() # download data files into temporary directory for data_file in data_files: if verbose: print(f"Downloading {data_file} from {data_base_url}") data = util.url_get(data_base_url + data_file + "?raw=true", headers=headers) f = open(os.path.join(temp_dir.name, data_file), "wb") f.write(data.read()) f.close() if verbose: print("Download complete") # load data, normalize it, and save as npz y = iio.imread(os.path.join(temp_dir.name, "example_raw.png")) y = y.astype(np.float32) y -= 100.0 y /= y.max() mat = loadmat(os.path.join(temp_dir.name, "example_psfs.mat")) psf = mat["psf"].astype(np.float64) psf -= 102.0 psf /= np.linalg.norm(psf, axis=(0, 1)).min() # save as .npz npz_file = os.path.join(path, "ucb_diffcam_data.npz") if verbose: print(f"Saving as {npz_file}") np.savez(npz_file, y=y, psf=psf) def ucb_diffusercam_data( verbose: bool = False, cache_path: Optional[str] = None ) -> Tuple[np.ndarray, np.ndarray]: """Get example data from UC Berkeley Waller Lab diffusercam project. If the data has previously been downloaded, it will be retrieved from a local cache. Args: verbose: Flag indicating whether to print status messages. cache_path: Directory in which downloaded data is cached. The default is `~/.cache/scico/examples`, where `~` represents the user home directory. Returns: tuple: A tuple (y, psf) containing: - **y** : (np.ndarray): Measured image - **psf** : (np.ndarray): Stack of psfs. 
""" # set default cache path if not specified if cache_path is None: # pragma: no cover cache_path = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples") # create cache directory and download data if not already present npz_file = os.path.join(cache_path, "ucb_diffcam_data.npz") if not os.path.isfile(npz_file): # pragma: no cover if not os.path.isdir(cache_path): os.makedirs(cache_path) get_ucb_diffusercam_data(path=cache_path, verbose=verbose) # load data and return y and psf arrays converted to float32 npz = np.load(npz_file) y = npz["y"].astype(np.float32) psf = npz["psf"].astype(np.float64) return y, psf def downsample_volume(vol: np.ndarray, rate: int) -> np.ndarray: """Downsample a 3D array. Downsample a 3D array. If the volume dimensions can be divided by `rate`, this is achieved via averaging distinct `rate` x `rate` x `rate` block in `vol`. Otherwise it is achieved via a call to :func:`scipy.ndimage.zoom`. Args: vol: Input volume. rate: Downsampling rate. Returns: Downsampled volume. """ if rate == 1: return vol if np.all([n % rate == 0 for n in vol.shape]): vol = np.mean(np.reshape(vol, (-1, rate, vol.shape[1], vol.shape[2])), axis=1) vol = np.mean(np.reshape(vol, (vol.shape[0], -1, rate, vol.shape[2])), axis=2) vol = np.mean(np.reshape(vol, (vol.shape[0], vol.shape[1], -1, rate)), axis=3) else: vol = zoom(vol, 1.0 / rate) return vol def tile_volume_slices(x: np.ndarray, sep_width: int = 10) -> np.ndarray: """Make an image with tiled slices from an input volume. Make an image with tiled `xy`, `xz`, and `yz` slices from an input volume. Args: x: Input volume consisting of a 3D or 4D array. If the input is 4D, the final axis represents a channel index. sep_width: Number of pixels separating the slices in the output image. Returns: Image containing tiled slices. """ if x.ndim == 3: fshape: Tuple[int, ...] 
= (x.shape[0], sep_width) else: fshape = (x.shape[0], sep_width, 3) out = np.concatenate( ( x[:, :, x.shape[2] // 2], np.full(fshape, np.nan), x[:, x.shape[1] // 2, :], ), axis=1, ) if x.ndim == 3: fshape0: Tuple[int, ...] = (sep_width, out.shape[1]) fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width) trans: Tuple[int, ...] = (1, 0) else: fshape0 = (sep_width, out.shape[1], 3) fshape1 = (x.shape[2], x.shape[2] + sep_width, 3) trans = (1, 0, 2) out = np.concatenate( ( out, np.full(fshape0, np.nan), np.concatenate( ( x[x.shape[0] // 2, :, :].transpose(trans), np.full(fshape1, np.nan), ), axis=1, ), ), axis=0, ) out = np.where(np.isnan(out), np.nanmax(out), out) return out def gaussian(shape: Shape, sigma: Optional[np.ndarray] = None) -> np.ndarray: r"""Construct a multivariate Gaussian distribution function. Construct a zero-mean multivariate Gaussian distribution function .. math:: f(\mb{x}) = (2 \pi)^{-N/2} \, \det(\Sigma)^{-1/2} \, \exp \left( -\frac{\mb{x}^T \, \Sigma^{-1} \, \mb{x}}{2} \right) \;, where :math:`\Sigma` is the covariance matrix of the distribution. Args: shape: Shape of output array. sigma: Covariance matrix. Returns: Sampled function. Raises: ValueError: If the array `sigma` cannot be inverted. """ if sigma is None: sigma = np.diag(np.array(shape) / 7) ** 2 N = len(shape) try: sigmainv = np.linalg.inv(sigma) sigmadet = np.linalg.det(sigma) except np.linalg.LinAlgError as e: raise ValueError(f"Invalid covariance matrix {sigma}.") from e grd = np.stack(np.mgrid[[slice(-(n - 1) / 2, (n + 1) / 2) for n in shape]], axis=-1) sigmax = np.dot(grd, sigmainv) xtsigmax = np.sum(grd * np.dot(grd, sigmainv), axis=-1) const = ((2.0 * np.pi) ** (-N / 2.0)) * (sigmadet ** (-1.0 / 2.0)) return const * np.exp(-xtsigmax / 2.0) def create_cone(shape: Shape, center: Optional[List[float]] = None) -> np.ndarray: """Compute a map of distances from a center pixel. Args: shape: Shape of the array for which the distance map is to be computed. 
center: Tuple of center coordinates. If ``None``, it is set to the center of the array. Returns: An array containing a map of the distances. """ if center is None: center = [(dim - 1) / 2 for dim in shape] coords = [np.arange(0, dim) for dim in shape] coord_mesh = np.meshgrid(*coords, sparse=True, indexing="ij") dist_map = sum([(coord_mesh[i] - center[i]) ** 2 for i in range(len(coord_mesh))]) dist_map = np.sqrt(dist_map) return dist_map def create_circular_phantom( shape: Shape, radius_list: list, val_list: list, center: Optional[list] = None ) -> np.ndarray: """Construct a circular phantom with given radii and intensities. This functions supports both circular (``shape`` is 2D) and spherical (``shape`` is 3D) phantoms. Args: shape: Shape of the phantom to be created. radius_list: List of radii of the rings in the phantom. val_list: List of intensity values of the rings in the phantom. center: Tuple of center coordinates. If ``None``, it is set to the center of the array. Returns: The computed phantom. """ dist_map = create_cone(shape, center) img = np.zeros(shape) for r, val in zip(radius_list, val_list): # In numpy: img[dist_map < r] = val # In jax.numpy: img = img.at[dist_map < r].set(val) img[dist_map < r] = val return img def create_3d_foam_phantom( im_shape: Shape, N_sphere: int, r_mean: float = 0.1, r_std: float = 0.001, pad: float = 0.01, is_random: bool = False, ) -> np.ndarray: """Construct a 3D phantom with random radii and centers. Args: im_shape: Shape of input image. N_sphere: Number of spheres added. r_mean: Mean radius of sphere (normalized to 1 along each axis). Default 0.1. r_std: Standard deviation of radius of sphere (normalized to 1 along each axis). Default 0.001. pad: Padding length (normalized to 1 along each axis). Default 0.01. is_random: Flag used to control randomness of phantom generation. If ``False``, random seed is set to 1 in order to make the process deterministic. Default ``False``. Returns: 3D phantom of shape `im_shape`. 
""" c_lo = 0.0 c_hi = 1.0 if not is_random: np.random.seed(1) coord_list = [np.linspace(0, 1, N) for N in im_shape] x = np.stack(np.meshgrid(*coord_list, indexing="ij"), axis=-1) centers = np.random.uniform(low=r_mean + pad, high=1 - r_mean - pad, size=(N_sphere, 3)) radii = r_std * np.random.randn(N_sphere) + r_mean im = np.zeros(im_shape) + c_lo for c, r in zip(centers, radii): # type: ignore dist = np.sum((x - c) ** 2, axis=-1) select = im[dist < r**2] if select.size > 0 and np.mean(select - c_lo) < 0.01 * c_hi: # In numpy: im[dist < r**2] = c_hi # In jax.numpy: im = im.at[dist < r**2].set(c_hi) im[dist < r**2] = c_hi return im def create_conv_sparse_phantom(Nx: int, Nnz: int) -> Tuple[np.ndarray, np.ndarray]: """Construct a disc dictionary and sparse coefficient maps. Construct a disc dictionary and a corresponding set of sparse coefficient maps for testing convolutional sparse coding algorithms. Args: Nx: Size of coefficient maps (3 x Nx x Nx). Nnz: Number of non-zero coefficients across all coefficient maps. Returns: A tuple consisting of a stack of 2D filters and the coefficient map array. """ # constant parameters M = 3 Nh = 7 e = 1 # create disc filters h = np.zeros((M, 2 * Nh + 1, 2 * Nh + 1)) gr, gc = np.ogrid[-Nh : Nh + 1, -Nh : Nh + 1] for m in range(M): r = 2 * m + 3 d = np.sqrt(gr**2 + gc**2) v = (np.clip(d, r - e, r + e) - (r - e)) / (2 * e) v = 1.0 - v h[m] = v # create sparse random coefficient maps np.random.seed(1234) x = np.zeros((M, Nx, Nx)) idx0 = np.random.randint(0, M, size=(Nnz,)) idx1 = np.random.randint(0, Nx, size=(2, Nnz)) val = np.random.uniform(0, 5, size=(Nnz,)) x[idx0, idx1[0], idx1[1]] = val return h, x def create_tangle_phantom(nx: int, ny: int, nz: int) -> np.ndarray: """Construct a 3D phantom using the tangle function. Args: nx: x-size of output. ny: y-size of output. nz: z-size of output. Returns: An array with shape (nz, ny, nx). 
""" xs = 1.0 * np.linspace(-1.0, 1.0, nx) ys = 1.0 * np.linspace(-1.0, 1.0, ny) zs = 1.0 * np.linspace(-1.0, 1.0, nz) # default ordering for meshgrid is `xy`, this makes inputs of length # M, N, P will create a mesh of N, M, P. Thus we want ys, zs and xs. xx: np.ndarray yy: np.ndarray zz: np.ndarray xx, yy, zz = np.meshgrid(ys, zs, xs, copy=True) xx = 3.0 * xx yy = 3.0 * yy zz = 3.0 * zz values = ( xx * xx * xx * xx - 5.0 * xx * xx + yy * yy * yy * yy - 5.0 * yy * yy + zz * zz * zz * zz - 5.0 * zz * zz + 11.8 ) * 0.2 + 0.5 return (values < 2.0).astype(float) @partial(jax.jit, static_argnums=0) def create_block_phantom(out_shape: Shape) -> np.ndarray: """Construct a blocky 3D phantom. Args: out_shape: desired phantom shape. Returns: Phantom. """ # make the phantom at a low resolution low_res = np.array( [ [ [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], [ [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], ], [ [0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], ] ) positions = np.stack( np.meshgrid(*[np.linspace(-0.5, 2.5, s) for s in out_shape], indexing="ij") ) indices = np.round(positions).astype(int) return low_res[indices[0], indices[1], indices[2]] def spnoise( img: Union[np.ndarray, snp.Array], nfrac: float, nmin: float = 0.0, nmax: float = 1.0 ) -> Union[np.ndarray, snp.Array]: """Return image with salt & pepper noise imposed on it. Args: img: Input image. nfrac: Desired fraction of pixels corrupted by noise. nmin: Lower value for noise (pepper). Default 0.0. nmax: Upper value for noise (salt). Default 1.0. 
Returns: Noisy image """ if isinstance(img, np.ndarray): spm = np.random.uniform(-1.0, 1.0, img.shape) # type: ignore imgn = img.copy() imgn[spm < nfrac - 1.0] = nmin imgn[spm > 1.0 - nfrac] = nmax else: spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) # type: ignore imgn = img imgn = imgn.at[spm < nfrac - 1.0].set(nmin) # type: ignore imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) # type: ignore return imgn def phase_diff(x: snp.Array, y: snp.Array) -> snp.Array: """Distance between phase angles. Compute the distance between two arrays of phase angles, with appropriate phase wrapping to minimize the distance. Args: x: Input array. y: Input array. Returns: Array of angular distances. """ mod = snp.mod(snp.abs(x - y), 2 * snp.pi) return snp.minimum(mod, 2 * snp.pi - mod) ================================================ FILE: scico/flax/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Neural network models implemented in `Flax `_ and utility functions. Many of the function and parameter names used in this sub-package are based on the somewhat non-standard Flax terminology for neural network components: `model` The model is an abstract representation of the network structure that does not include specific weight values. `parameters` The parameters of a model are the weights of the network represented by the model. `variables` The variables encompass both the parameters (i.e. network weights) and secondary values that are set from training data, such as layer-dependent statistics used in batch normalization. `state` The state encompasses both a set of model parameters as well as optimizer parameters involved in training of that model. 
Storing the state rather than just the variables enables a warm start for additional training. | """ import sys # isort: off from ._flax import FlaxMap, load_variables, save_variables from ._models import ConvBNNet, DnCNNNet, ResNet, UNet from .inverse import MoDLNet, ODPNet from .train.input_pipeline import create_input_iter from .train.typed_dict import ConfigDict from .train.trainer import BasicFlaxTrainer from .train.apply import only_apply from .train.clu_utils import count_parameters __all__ = [ "FlaxMap", "load_variables", "save_variables", "ConvBNNet", "DnCNNNet", "ResNet", "UNet", "MoDLNet", "ODPNet", "create_input_iter", "ConfigDict", "BasicFlaxTrainer", "only_apply", "count_parameters", ] # Imported items in __all__ appear to originate in top-level flax module # except ConfigDict. for name in __all__: if name != "ConfigDict": getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/flax/_flax.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Convolutional neural network models implemented in Flax.""" import warnings from typing import Any, Optional warnings.simplefilter(action="ignore", category=FutureWarning) from flax import serialization from flax.linen.module import Module from scico.numpy import Array, BlockArray from scico.typing import Shape PyTree = Any def load_variables(filename: str) -> PyTree: """Load trained model variables. Args: filename: Name of file containing trained model variables. Returns: A tree-like structure containing the values of the model variables. 
""" with open(filename, "rb") as data_file: bytes_input = data_file.read() variables = serialization.msgpack_restore(bytes_input) var_in = {"params": variables["params"], "batch_stats": variables["batch_stats"]} return var_in def save_variables(variables: PyTree, filename: str): """Save trained model weights. Args: filename: Name of file to to which model variables should be saved. variables: Model variables to save. """ bytes_output = serialization.msgpack_serialize(variables) with open(filename, "wb") as data_file: data_file.write(bytes_output) class FlaxMap: r"""A trained flax model.""" def __init__(self, model: Module, variables: PyTree): r"""Initialize a :class:`FlaxMap` object. Args: model: Flax model to apply. variables: Parameters and batch stats of trained model. """ self.model = model self.variables = variables super().__init__() def __call__(self, x: Array) -> Array: r"""Apply trained flax model. Args: x: Input array. Returns: Output of flax model. """ if isinstance(x, BlockArray): raise NotImplementedError # Add singleton to input as necessary: # scico typically works with (H x W) or (H x W x C) arrays # flax expects (K x H x W x C) arrays # H: spatial height W: spatial width # K: batch size C: channel size xndim = x.ndim axsqueeze: Optional[Shape] = None if xndim == 2: x = x.reshape((1,) + x.shape + (1,)) axsqueeze = (0, 3) elif xndim == 3: x = x.reshape((1,) + x.shape) axsqueeze = (0,) y = self.model.apply(self.variables, x, train=False, mutable=False) if y.ndim != xndim: return y.squeeze(axis=axsqueeze) return y ================================================ FILE: scico/flax/_models.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Flax implementation of different convolutional nets.""" import warnings warnings.simplefilter(action="ignore", category=FutureWarning) from functools import partial from typing import Any, Callable, Tuple import jax.numpy as jnp from flax.core import Scope # noqa from flax.linen import BatchNorm, Conv, max_pool, relu from flax.linen.initializers import kaiming_normal, xavier_normal from flax.linen.module import _Sentinel # noqa from flax.linen.module import Module, compact from scico.flax.blocks import ( ConvBNBlock, ConvBNMultiBlock, ConvBNPoolBlock, ConvBNUpsampleBlock, upscale_nn, ) from scico.numpy import Array # The imports of Scope and _Sentinel (above) are required to silence # "cannot resolve forward reference" warnings when building sphinx api # docs. ModuleDef = Any class DnCNNNet(Module): r"""Flax implementation of DnCNN :cite:`zhang-2017-dncnn`. Flax implementation of the convolutional neural network (CNN) architecture for denoising described in :cite:`zhang-2017-dncnn`. Attributes: depth: Number of layers in the neural network. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layers. kernel_size: Size of the convolution filters. strides: Convolution strides. dtype: Output dtype. Default: :attr:`~numpy.float32`. act: Class of activation function to apply. Default: :func:`~flax.linen.activation.relu`. """ depth: int channels: int num_filters: int = 64 kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) dtype: Any = jnp.float32 act: Callable = relu @compact def __call__( self, inputs: Array, train: bool = True, ) -> Array: """Apply DnCNN denoiser. Args: inputs: The array to be transformed. train: Flag to differentiate between training and testing stages. Returns: The denoised input. """ # Definition using arguments common to all convolutions. 
conv = partial( Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=kaiming_normal() ) # Definition using arguments common to all batch normalizations. norm = partial( BatchNorm, use_running_average=not train, momentum=0.99, epsilon=1e-5, dtype=self.dtype, ) # Definition and application of DnCNN model. base = inputs y = conv( self.num_filters, self.kernel_size, strides=self.strides, name="conv_start", )(inputs) y = self.act(y) for _ in range(self.depth - 2): y = ConvBNBlock( self.num_filters, conv=conv, norm=norm, act=self.act, kernel_size=self.kernel_size, strides=self.strides, )(y) y = conv( self.channels, self.kernel_size, strides=self.strides, name="conv_end", )(y) return base - y # residual-like network class ResNet(Module): """Flax implementation of convolutional network with residual connection. Net constructed from sucessive applications of convolution plus batch normalization blocks and ending with residual connection (i.e. adding the input to the output of the block). Args: depth: Depth of residual net. channels: Number of channels of input tensor. num_filters: Number of filters in the layers of the block. Corresponds to the number of channels in the network processing. kernel_size: Size of the convolution filters. strides: Convolution strides. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ depth: int channels: int num_filters: int = 64 kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) dtype: Any = jnp.float32 @compact def __call__(self, x: Array, train: bool = True) -> Array: """Apply ResNet. Args: x: The array to be transformed. train: Flag to differentiate between training and testing stages. Returns: The ResNet result. """ residual = x # Definition using arguments common to all convolutions. conv = partial( Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=xavier_normal() ) # Definition using arguments common to all batch normalizations. 
# Remainder of the ResNet ``__call__`` method; its beginning lies outside this chunk.
        norm = partial(
            BatchNorm,
            use_running_average=not train,
            momentum=0.99,
            epsilon=1e-5,
            dtype=self.dtype,
        )
        act = relu

        # Definition and application of ResNet.
        # depth - 1 conv/BN/activation blocks, then a projection back to
        # `channels` channels followed by batch normalization (no activation).
        for _ in range(self.depth - 1):
            x = ConvBNBlock(
                self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)

        x = conv(
            self.channels,
            self.kernel_size,
            strides=self.strides,
        )(x)
        x = norm()(x)

        # `residual` is bound earlier in this method (outside this view).
        return x + residual


class ConvBNNet(Module):
    """Convolution and batch normalization net.

    Net constructed from successive applications of convolution plus
    batch normalization blocks. No residual connection.

    The net applies `depth - 1` :class:`ConvBNBlock` layers with
    `num_filters` output channels each, followed by a final convolution
    back to `channels` output channels and a batch normalization with no
    activation after it.

    Args:
        depth: Depth of net.
        channels: Number of channels of input tensor.
        num_filters: Number of filters in the layers of the block.
            Corresponds to the number of channels in the network
            processing.
        kernel_size: Size of the convolution filters.
        strides: Convolution strides.
        dtype: Output dtype. Default: :attr:`~numpy.float32`.
    """

    depth: int
    channels: int
    num_filters: int = 64
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    dtype: Any = jnp.float32

    @compact
    def __call__(self, x: Array, train: bool = True) -> Array:
        """Apply ConvBNNet.

        Args:
            x: The array to be transformed.
            train: Flag to differentiate between training and testing
                stages. When ``False``, batch normalization uses its
                running averages (``use_running_average=not train``).

        Returns:
            The ConvBNNet result.
        """
        # Definition using arguments common to all convolutions.
        # Bias-free, circularly padded convolutions with Xavier-normal
        # kernel initialization.
        conv = partial(
            Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=xavier_normal()
        )
        # Definition using arguments common to all batch normalizations.
        norm = partial(
            BatchNorm,
            use_running_average=not train,
            momentum=0.99,
            epsilon=1e-5,
            dtype=self.dtype,
        )
        act = relu

        # Definition and application of ConvBNNet.
        for _ in range(self.depth - 1):
            x = ConvBNBlock(
                self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)

        # Final projection back to `channels` channels (no activation).
        x = conv(
            self.channels,
            self.kernel_size,
            strides=self.strides,
        )(x)
        x = norm()(x)

        return x


class UNet(Module):
    """Flax implementation of U-Net model :cite:`ronneberger-2015-unet`.

    The network applies an initial :class:`ConvBNMultiBlock`, then
    `depth - 1` downsampling levels (convolution, batch normalization,
    activation and max pooling, followed by a :class:`ConvBNMultiBlock`),
    then `depth - 1` upsampling levels that each concatenate the
    corresponding down-path activation as a skip connection, and finally
    a 1x1 convolution to `channels` output channels.

    Args:
        depth: Depth of U-Net.
        channels: Number of channels of input tensor.
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the network
            processing.
        kernel_size: Size of the convolution filters.
        strides: Convolution strides.
        block_depth: Number of processing layers per block.
        window_shape: Window for reduction for pooling and downsampling.
        upsampling: Factor for expanding.
        dtype: Output dtype. Default: :attr:`~numpy.float32`.
    """

    depth: int
    channels: int
    num_filters: int = 64
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    block_depth: int = 2
    window_shape: Tuple[int, int] = (2, 2)
    upsampling: int = 2
    dtype: Any = jnp.float32

    @compact
    def __call__(self, x: Array, train: bool = True) -> Array:
        """Apply U-Net.

        Args:
            x: The array to be transformed.
            train: Flag to differentiate between training and testing
                stages. When ``False``, batch normalization uses its
                running averages (``use_running_average=not train``).

        Returns:
            The U-Net result.
        """
        # Definition using arguments common to all convolutions.
        # Bias-free, circularly padded convolutions with Kaiming-normal
        # kernel initialization.
        conv = partial(
            Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=kaiming_normal()
        )
        # Definition using arguments common to all batch normalizations.
        norm = partial(
            BatchNorm,
            use_running_average=not train,
            momentum=0.99,
            epsilon=1e-5,
            dtype=self.dtype,
        )
        act = relu
        # Definition of upscaling function.
        upfn = partial(upscale_nn, scale=self.upsampling)

        # Definition and application of U-Net.
        # Initial processing with `num_filters` channels.
        x = ConvBNMultiBlock(
            self.block_depth,
            self.num_filters,
            conv=conv,
            norm=norm,
            act=act,
            kernel_size=self.kernel_size,
            strides=self.strides,
        )(x)

        residual = []
        # going down
        # Level k (k = 1, 2, ...) uses 2**k * num_filters channels: j
        # doubles after every level and the blocks use 2 * j filters.
        j: int = 1
        for _ in range(self.depth - 1):
            residual.append(x)  # for skip connections
            x = ConvBNPoolBlock(
                2 * j * self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                pool=max_pool,
                kernel_size=self.kernel_size,
                strides=self.strides,
                window_shape=self.window_shape,
            )(x)
            x = ConvBNMultiBlock(
                self.block_depth,
                2 * j * self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)
            j = 2 * j

        # going up
        # Walk the stored activations from the deepest level back to the
        # first (res_ind = -1, -2, ...), halving the filter count each level.
        j = j // 2  # undo last
        res_ind = -1
        for _ in range(self.depth - 1):
            x = ConvBNUpsampleBlock(
                j * self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                upfn=upfn,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)
            # skip connection
            # NOTE(review): axis=3 assumes 4D NHWC input (channel axis last),
            # and the concatenation assumes `upsampling` undoes the
            # `window_shape` pooling so spatial shapes agree — confirm.
            x = jnp.concatenate((residual[res_ind], x), axis=3)
            x = ConvBNMultiBlock(
                self.block_depth,
                j * self.num_filters,
                conv=conv,
                norm=norm,
                act=act,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)
            res_ind -= 1
            j = j // 2

        # final conv1x1
        # Map back to `channels` output channels (no normalization/activation).
        ksz_out = (1, 1)
        x = conv(self.channels, ksz_out, strides=self.strides)(x)

        return x

================================================
FILE: scico/flax/blocks.py
================================================
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Flax implementation of different convolutional blocks."""

import warnings

# Installed before the flax imports below so that any FutureWarning
# they emit is suppressed.
warnings.simplefilter(action="ignore", category=FutureWarning)

from typing import Any, Callable, Tuple

import jax.numpy as jnp

from flax.core import Scope  # noqa
from flax.linen.module import _Sentinel  # noqa
from flax.linen.module import Module, compact
from scico.numpy import Array

# The imports of Scope and _Sentinel (above) are required to silence
# "cannot resolve forward reference" warnings when building sphinx api
# docs.

# Alias used to annotate the `conv` and `norm` fields below; those fields
# are called to construct the Flax layers applied in each block.
ModuleDef = Any


class ConvBNBlock(Module):
    """Define convolution and batch normalization Flax block.

    Args:
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the output
            tensor.
        conv: Flax module implementing the convolution layer to apply.
        norm: Flax module implementing the batch normalization layer to
            apply.
        act: Flax function defining the activation operation to apply.
        kernel_size: A shape tuple defining the size of the convolution
            filters.
        strides: A shape tuple defining the size of strides in
            convolution.
    """

    num_filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable[..., Array]
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)

    @compact
    def __call__(
        self,
        inputs: Array,
    ) -> Array:
        """Apply convolution followed by normalization and activation.

        Args:
            inputs: The array to be transformed.

        Returns:
            The transformed input.
        """
        y = self.conv(
            self.num_filters,
            self.kernel_size,
            strides=self.strides,
        )(inputs)
        y = self.norm()(y)
        return self.act(y)


class ConvBlock(Module):
    """Define Flax convolution block.

    Args:
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the output
            tensor.
        conv: Flax module implementing the convolution layer to apply.
        act: Flax function defining the activation operation to apply.
        kernel_size: A shape tuple defining the size of the convolution
            filters.
        strides: A shape tuple defining the size of strides in
            convolution.
    """

    num_filters: int
    conv: ModuleDef
    act: Callable[..., Array]
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)

    @compact
    def __call__(
        self,
        inputs: Array,
    ) -> Array:
        """Apply convolution followed by activation.

        Args:
            inputs: The array to be transformed.

        Returns:
            The transformed input.
        """
        y = self.conv(
            self.num_filters,
            self.kernel_size,
            strides=self.strides,
        )(inputs)
        return self.act(y)


class ConvBNPoolBlock(Module):
    """Define convolution, batch normalization and pooling Flax block.

    Args:
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the output
            tensor.
        conv: Flax module implementing the convolution layer to apply.
        norm: Flax module implementing the batch normalization layer to
            apply.
        act: Flax function defining the activation operation to apply.
        pool: Flax function defining the pooling operation to apply.
        kernel_size: A shape tuple defining the size of the convolution
            filters.
        strides: A shape tuple defining the size of strides in
            convolution.
        window_shape: A shape tuple defining the window to reduce over in
            the pooling operation.
    """

    num_filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable[..., Array]
    pool: Callable[..., Array]
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int]
    window_shape: Tuple[int, int]

    @compact
    def __call__(
        self,
        inputs: Array,
    ) -> Array:
        """Apply convolution followed by normalization, activation and
        pooling.

        Args:
            inputs: The array to be transformed.

        Returns:
            The transformed input.
        """
        y = self.conv(
            self.num_filters,
            self.kernel_size,
            strides=self.strides,
        )(inputs)
        y = self.norm()(y)
        y = self.act(y)
        # Pooling stride equals the window shape, so pooling windows do
        # not overlap.
        # 'SAME': pads so as to have the same output shape as input if the stride is 1.
        return self.pool(y, self.window_shape, strides=self.window_shape, padding="SAME")


class ConvBNUpsampleBlock(Module):
    """Define convolution, batch normalization and upsample Flax block.

    Args:
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the output
            tensor.
        conv: Flax module implementing the convolution layer to apply.
        norm: Flax module implementing the batch normalization layer to
            apply.
        act: Flax function defining the activation operation to apply.
        upfn: Flax function defining the upsampling operation to apply.
        kernel_size: A shape tuple defining the size of the convolution
            filters.
        strides: A shape tuple defining the size of strides in
            convolution.
    """

    num_filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable[..., Array]
    upfn: Callable[..., Array]
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int]

    @compact
    def __call__(
        self,
        inputs: Array,
    ) -> Array:
        """Apply convolution followed by normalization, activation and
        upsampling.

        Args:
            inputs: The array to be transformed.

        Returns:
            The transformed input.
        """
        y = self.conv(
            self.num_filters,
            self.kernel_size,
            strides=self.strides,
        )(inputs)
        y = self.norm()(y)
        y = self.act(y)
        return self.upfn(y)


class ConvBNMultiBlock(Module):
    """Block constructed from successive applications of :class:`ConvBNBlock`.

    Args:
        num_blocks: Number of convolutional batch normalization blocks to
            apply. Each block has its own parameters for convolution
            and batch normalization.
        num_filters: Number of filters in the convolutional layer of the
            block. Corresponds to the number of channels in the output
            tensor.
        conv: Flax module implementing the convolution layer to apply.
        norm: Flax module implementing the batch normalization layer to
            apply.
        act: Flax function defining the activation operation to apply.
        kernel_size: A shape tuple defining the size of the convolution
            filters.
        strides: A shape tuple defining the size of strides in
            convolution.
    """

    num_blocks: int
    num_filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable[..., Array]
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)

    @compact
    def __call__(
        self,
        x: Array,
    ) -> Array:
        """Apply successive convolution, normalization and activation blocks.

        Apply successive blocks, each one composed of convolution,
        normalization and activation.

        Args:
            x: The array to be transformed.

        Returns:
            The transformed input.
        """
        # A new ConvBNBlock is instantiated on each iteration, so the
        # blocks do not share parameters.
        for _ in range(self.num_blocks):
            x = ConvBNBlock(
                self.num_filters,
                conv=self.conv,
                norm=self.norm,
                act=self.act,
                kernel_size=self.kernel_size,
                strides=self.strides,
            )(x)
        return x


def upscale_nn(x: Array, scale: int = 2) -> Array:
    """Nearest neighbor upscale for image batches of shape (N, H, W, C).

    Args:
        x: Input tensor of shape (N, H, W, C).
        scale: Integer scaling factor.

    Returns:
        Output tensor of shape (N, H * scale, W * scale, C).
    """
    s = x.shape
    # Insert singleton axes after H and W: (N, H, 1, W, 1, C).
    x = x.reshape((s[0],) + (s[1], 1, s[2], 1) + (s[3],))
    # Replicate each pixel `scale` times along both inserted axes.
    x = jnp.tile(x, (1, 1, scale, 1, scale, 1))
    # Merge (H, scale) and (W, scale) so each pixel becomes a scale x scale tile.
    return x.reshape((s[0],) + (scale * s[1], scale * s[2]) + (s[3],))

================================================
FILE: scico/flax/examples/__init__.py
================================================
# -*- coding: utf-8 -*-
# Copyright (C) 2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Data utility functions used by Flax example scripts.""" from .data_preprocessing import PaddedCircularConvolve, build_blur_kernel from .examples import load_blur_data, load_ct_data, load_image_data __all__ = [ "load_ct_data", "load_blur_data", "load_image_data", "PaddedCircularConvolve", "build_blur_kernel", ] ================================================ FILE: scico/flax/examples/data_generation.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functionality to generate training data for Flax example scripts. Computation is distributed via ray to reduce processing time. """ from functools import partial from time import time from typing import Callable, List, Tuple, Union import numpy as np try: import xdesign # noqa: F401 except ImportError: have_xdesign = False # pylint: disable=missing-class-docstring class UnitCircle: pass # pylint: enable=missing-class-docstring else: have_xdesign = True from xdesign import ( # type: ignore Foam, SimpleMaterial, UnitCircle, discrete_phantom, ) try: import os os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" # suppress ray warning import ray # noqa: F401 except ImportError: have_ray = False else: have_ray = True import jax import jax.numpy as jnp try: from jax.extend.backend import get_backend # introduced in jax 0.4.33 except ImportError: from jax.lib.xla_bridge import get_backend from scico.linop import CircularConvolve from scico.linop.xray import XRayTransform2D from scico.numpy import Array class Foam2(UnitCircle): """Foam-like material with two attenuations. 
Define functionality to generate phantom with structure similar to foam with two different attenuation properties.""" def __init__( self, size_range: Union[float, List[float]] = [0.05, 0.01], gap: float = 0, porosity: float = 1, attn1: float = 1.0, attn2: float = 10.0, ): """Foam-like structure with two different attenuations. Circles for material 1 are more sparse than for material 2 by design. Args: size_range: The radius, or range of radius, of the circles to be added. Default: [0.05, 0.01]. gap: Minimum distance between circle boundaries. Default: 0. porosity: Target porosity. Must be a value between [0, 1]. Default: 1. attn1: Mass attenuation parameter for material 1. Default: 1. attn2: Mass attenuation parameter for material 2. Default: 10. """ if porosity < 0 or porosity > 1: raise ValueError("Argument 'porosity' must be in the range [0,1).") super().__init__(radius=0.5, material=SimpleMaterial(attn1)) # type: ignore self.sprinkle( # type: ignore 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 ) + self.sprinkle( # type: ignore 300, size_range, gap, material=SimpleMaterial(20), max_density=porosity ) def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: """Generate batch of xdesign foam-like structures. Generate batch of images with `xdesign` foam-like structure, which uses one attenuation. Args: seed: Seed for data generation. size: Size of image to generate. ndata: Number of images to generate. Returns: Array of generated data. """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") np.random.seed(seed) saux: np.ndarray = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) saux[i, ..., 0] = discrete_phantom(foam, size=size) return saux def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: """Generate batch of foam2 structures. 
Generate batch of images with :class:`Foam2` structure (foam-like material with two different attenuations). Args: seed: Seed for data generation. size: Size of image to generate. ndata: Number of images to generate. Returns: Array of generated data. """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") np.random.seed(seed) saux: np.ndarray = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) saux[i, ..., 0] = discrete_phantom(foam, size=size) # normalize saux /= np.max(saux, axis=(1, 2), keepdims=True) return saux def vector_f(f_: Callable, v: Array) -> Array: """Vectorize application of operator. Args: f_: Operator to apply. v: Array to evaluate. Returns: Result of evaluating operator over given arrays. """ lf = lambda x: jnp.atleast_3d(f_(x.squeeze())) auto_batch = jax.vmap(lf) return auto_batch(v) def batched_f(f_: Callable, vr: Array) -> Array: """Distribute application of operator over a batch of vectors among available processes. Args: f_: Operator to apply. vr: Batch of arrays to evaluate. Returns: Result of evaluating operator over given batch of arrays. This evaluation preserves the batch axis. """ nproc = jax.device_count() if vr.shape[0] != nproc: vrr = vr.reshape((nproc, -1, *vr.shape[:1])) else: vrr = vr res = jax.pmap(partial(vector_f, f_))(vrr) return res def generate_ct_data( nimg: int, size: int, nproj: int, imgfunc: Callable = generate_foam2_images, seed: int = 1234, verbose: bool = False, ) -> Tuple[Array, Array, Array]: """Generate batch of computed tomography (CT) data. Generate batch of CT data for training of machine learning network models. Args: nimg: Number of images to generate. size: Size of reconstruction images. nproj: Number of CT views. imgfunc: Function for generating input images (e.g. foams). seed: Seed for data generation. verbose: Flag indicating whether to print status messages. 
Returns: tuple: A tuple (img, sino, fbp) containing: - **img** : (:class:`jax.Array`): Generated foam images. - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ if not (have_ray and have_xdesign): raise RuntimeError("Packages ray and xdesign are required for use of this function.") # Generate input data. start_time = time() img = distributed_data_generation(imgfunc, size, nimg, seed) time_dtgen = time() - start_time # clip to [0,1] range img = jnp.clip(img, 0, 1) nproc = jax.device_count() if img.shape[0] % nproc > 0: # Decrease nimg to be a multiple of nproc if it isn't already nimg = (img.shape[0] // nproc) * nproc img = img[:nimg] # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles gt_shape = (size, size) dx = 1.0 / np.sqrt(2) det_count = int(size * 1.05 / np.sqrt(2.0)) A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count) # Compute sinograms in parallel. start_time = time() if nproc > 1: # shard array imgshd = img.reshape((nproc, -1, size, size, 1)) sinoshd = batched_f(A, imgshd) sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1)) else: sino = vector_f(A, img) time_sino = time() - start_time # Compute filtered back-projection in parallel. start_time = time() if nproc > 1: fbpshd = batched_f(A.fbp, sinoshd) fbp = fbpshd.reshape((-1, size, size, 1)) else: fbp = vector_f(A.fbp, sino) time_fbp = time() - start_time # Normalize sinogram. sino = sino / size # Clip FBP to [0,1] range. 
fbp = np.clip(fbp, 0, 1) if verbose: # pragma: no cover platform = get_backend().platform print(f"{'Platform':26s}{':':4s}{platform}") print(f"{'Device count':26s}{':':4s}{jax.device_count()}") print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}") print(f"{'Sinogram':19s}{'time[s]:':10s}{time_sino:>7.2f}") print(f"{'FBP':19s}{'time[s]:':10s}{time_fbp:>7.2f}") return img, sino, fbp def generate_blur_data( nimg: int, size: int, blur_kernel: Array, noise_sigma: float, imgfunc: Callable = generate_foam1_images, seed: int = 4321, verbose: bool = False, ) -> Tuple[Array, Array]: """Generate batch of blurred data. Generate batch of blurred data for training of machine learning network models. Args: nimg: Number of images to generate. size: Size of reconstruction images. blur_kernel: Kernel for blurring the generated images. noise_sigma: Level of additive Gaussian noise to apply. imgfunc: Function to generate foams. seed: Seed for data generation. verbose: Flag indicating whether to print status messages. Returns: tuple: A tuple (img, blurn) containing: - **img** : Generated foam images. - **blurn** : Corresponding blurred and noisy images. """ if not (have_ray and have_xdesign): raise RuntimeError("Packages ray and xdesign are required for use of this function.") start_time = time() img = distributed_data_generation(imgfunc, size, nimg, seed) time_dtgen = time() - start_time # Clip to [0,1] range. 
img = jnp.clip(img, 0, 1) nproc = jax.device_count() if img.shape[0] % nproc > 0: # Decrease nimg to be a multiple of nproc if it isn't already nimg = (img.shape[0] // nproc) * nproc img = img[:nimg] # Configure blur operator ishape = (size, size) A = CircularConvolve(h=blur_kernel, input_shape=ishape) # Compute blurred images in parallel start_time = time() if nproc > 1: # Shard array imgshd = img.reshape((nproc, -1, size, size, 1)) blurshd = batched_f(A, imgshd) blur = blurshd.reshape((-1, size, size, 1)) else: blur = vector_f(A, img) time_blur = time() - start_time # Normalize blurred images blur = blur / jnp.max(blur, axis=(1, 2), keepdims=True) # Add Gaussian noise key = jax.random.key(seed) noise = jax.random.normal(key, blur.shape) blurn = blur + noise_sigma * noise # Clip to [0,1] range. blurn = jnp.clip(blurn, 0, 1) if verbose: # pragma: no cover platform = get_backend().platform print(f"{'Platform':26s}{':':4s}{platform}") print(f"{'Device count':26s}{':':4s}{jax.device_count()}") print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}") print(f"{'Blur generation':19s}{'time[s]:':10s}{time_blur:>7.2f}") return img, blurn def distributed_data_generation( imgenf: Callable, size: int, nimg: int, seedg: float = 123 ) -> np.ndarray: """Data generation distributed among processes using ray. *Warning:* callable `imgenf` should not make use of any jax functions to avoid the risk of errors when running with GPU devices, in which case jax is initialized to expect the availability of GPUs, which are then not available within the `ray.remote` function due to the absence of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. Args: imagenf: Function for batch-data generation. size: Size of image to generate. ndata: Number of images to generate. seedg: Base seed for data generation. Returns: Array of generated data. 
""" if not have_ray: raise RuntimeError("Package ray is required for use of this function.") if not ray.is_initialized(): raise RuntimeError("Ray must be initialized via ray.init() before calling this function.") # Use half of available CPU resources ar = ray.available_resources() nproc = max(int(ar.get("CPU", 1)) // 2, 1) # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that # can severely limit parallel execution (since ray will ensure that only # as many actors as available GPUs are created), and is expected to be # rather brittle. if "GPU" in ar: num_gpus = 1 nproc = min(nproc, int(ar.get("GPU"))) else: num_gpus = 0 if nproc > nimg: nproc = nimg if nimg % nproc > 0: # Increase nimg to be a multiple of nproc if it isn't already nimg = (nimg // nproc + 1) * nproc ndata_per_proc = int(nimg // nproc) @ray.remote(num_gpus=num_gpus) def data_gen(seed, size, ndata, imgf): return imgf(seed, size, ndata) ray_return = ray.get( [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] ) imgs = np.vstack([t for t in ray_return]) return imgs ================================================ FILE: scico/flax/examples/data_preprocessing.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Image manipulation utils.""" import glob import math import os import tarfile import tempfile from typing import Any, Callable, Optional, Tuple, Union import numpy as np import jax.numpy as jnp import imageio.v3 as iio from scico import util from scico.examples import rgb2gray from scico.flax.train.typed_dict import DataSetDict from scico.linop import CircularConvolve, LinearOperator from scico.numpy import Array from scico.typing import Shape from .typed_dict import ConfigImageSetDict def rotation90(img: Array) -> Array: """Rotate an image, or a batch of images, by 90 degrees. Rotate an image or a batch of images by 90 degrees counterclockwise. An image is an array with size H x W x C with H and W spatial dimensions and C number of channels. A batch of images is an array with size N x H x W x C with N number of images. Args: img: The array to be rotated. Returns: An image, or batch of images, rotated by 90 degrees counterclockwise. """ if img.ndim < 4: return np.swapaxes(img, 0, 1) else: return np.swapaxes(img, 1, 2) def flip(img: Array) -> Array: """Horizontal flip of an image or a batch of images. Horizontally flip an image or a batch of images. An image is an array with size H x W x C with H and W spatial dimensions and C number of channels. A batch of images is an array with size N x H x W x C with N number of images. Args: img: The array to be flipped. Returns: An image, or batch of images, flipped horizontally. """ if img.ndim < 4: return img[:, ::-1, ...] else: return img[..., ::-1, :] class CenterCrop: """Crop central part of an image to a specified size. Crop central part of an image. An image is an array with size H x W x C with H and W spatial dimensions and C number of channels. """ def __init__(self, output_size: Union[Shape, int]): """ Args: output_size: Desired output size. If int, square crop is made. 
""" # assert isinstance(output_size, (int, tuple)) if isinstance(output_size, int): self.output_size: Shape = (output_size, output_size) else: assert len(output_size) == 2 self.output_size = output_size def __call__(self, image: Array) -> Array: """Apply center crop. Args: image: The array to be cropped. Returns: The cropped image. """ h, w = image.shape[:2] new_h, new_w = self.output_size top = (h - new_h) // 2 left = (w - new_w) // 2 image = image[top : top + new_h, left : left + new_w] return image class PositionalCrop: """Crop an image from a given corner to a specified size. Crop an image from a given corner. An image is an array with size H x W x C with H and W spatial dimensions and C number of channels. """ def __init__(self, output_size: Union[Shape, int]): """ Args: output_size: Desired output size. If int, square crop is made. """ # assert isinstance(output_size, (int, tuple)) if isinstance(output_size, int): self.output_size: Shape = (output_size, output_size) else: assert len(output_size) == 2 self.output_size = output_size def __call__(self, image: Array, top: int, left: int) -> Array: """Apply positional crop. Args: image: The array to be cropped. top: Vertical top coordinate of corner to start cropping. left: Horizontal left coordinate of corner to start cropping. Returns: The cropped image. """ h, w = image.shape[:2] new_h, new_w = self.output_size image = image[top : top + new_h, left : left + new_w] return image class RandomNoise: """Add Gaussian noise to an image or a batch of images. Add Gaussian noise to an image or a batch of images. An image is an array with size H x W x C with H and W spatial dimensions and C number of channels. A batch of images is an array with size N x H x W x C with N number of images. The Gaussian noise is a Gaussian random variable with mean zero and given standard deviation. 
The standard deviation can be a fix value corresponding to the specified noise level or randomly selected on a range between 50% and 100% of the specified noise level. """ def __init__(self, noise_level: float, range_flag: bool = False): """ Args: noise_level: Standard dev of the Gaussian noise. range_flag: If ``True``, the standard dev is randomly selected between 50% and 100% of `noise_level` set. Default: ``False``. """ self.range_flag = range_flag if range_flag: self.noise_level_low = 0.5 * noise_level self.noise_level = noise_level def __call__(self, image: Array) -> Array: """Add Gaussian noise. Args: image: The array to add noise to. Returns: The noisy image. """ noise_level = self.noise_level if self.range_flag: if image.ndim > 3: num_img = image.shape[0] else: num_img = 1 noise_level_range = np.random.uniform(self.noise_level_low, self.noise_level, num_img) noise_level = noise_level_range.reshape( (noise_level_range.shape[0],) + (1,) * (image.ndim - 1) ) imgnoised = image + np.random.normal(0.0, noise_level, image.shape) imgnoised = np.clip(imgnoised, 0.0, 1.0) return imgnoised def preprocess_images( images: Array, output_size: Union[Shape, int], gray_flag: bool = False, num_img: Optional[int] = None, multi_flag: bool = False, stride: Optional[Union[Shape, int]] = None, dtype: Any = np.float32, ) -> Array: """Preprocess (scale, crop, etc.) set of images. Preprocess set of images, converting to gray scale, or cropping or sampling multiple patches from each one, or selecting a subset of them, according to specified setup. Args: images: Array of color images. output_size: Desired output size. If int, square crop is made. gray_flag: If ``True``, converts to gray scale. num_img: If specified, reads that number of images, if not reads all the images in path. multi_flag: If ``True``, samples multiple patches of specified size in each image. stride: Stride between patch origins (indexed from left-top corner). If int, the same stride is used in h and w. 
dtype: dtype of array. Default: :attr:`~numpy.float32`. Returns: Preprocessed array. """ # Get number of images to use. if num_img is None: num_img = images.shape[0] # Get channels of ouput image. C = 3 if gray_flag: C = 1 # Define functionality to crop and create signal array. if multi_flag: tsfm = PositionalCrop(output_size) assert stride is not None if isinstance(stride, int): stride_multi = (stride, stride) S = np.zeros((num_img, images.shape[1], images.shape[2], C), dtype=dtype) else: tsfm_crop = CenterCrop(output_size) S = np.zeros((num_img, tsfm_crop.output_size[0], tsfm_crop.output_size[1], C), dtype=dtype) # Convert to gray scale and/or crop. for i in range(S.shape[0]): img = images[i] / 255.0 if gray_flag: imgG = rgb2gray(img) # Keep channel singleton. img = imgG.reshape(imgG.shape + (1,)) if not multi_flag: # Crop image img = tsfm_crop(img) S[i] = img if multi_flag: # Sample multiple patches from image h = S.shape[1] w = S.shape[2] nh = int(math.floor((h - tsfm.output_size[0]) / stride_multi[0])) + 1 nw = int(math.floor((w - tsfm.output_size[1]) / stride_multi[1])) + 1 saux = np.zeros( (nh * nw * num_img, tsfm.output_size[0], tsfm.output_size[1], S.shape[-1]), dtype=dtype ) count2 = 0 for i in range(S.shape[0]): for top in range(0, h - tsfm.output_size[0], stride_multi[0]): for left in range(0, w - tsfm.output_size[1], stride_multi[1]): saux[count2, ...] = tsfm(S[i], top, left) count2 += 1 S = saux return S def build_image_dataset( imgs_train, imgs_test, config: ConfigImageSetDict, transf: Optional[Callable] = None ) -> Tuple[DataSetDict, ...]: """Preprocess and assemble dataset for training. Preprocess images according to the specified configuration and assemble a dataset into a structure that can be used for training machine learning models. Keep training and testing partitions. Each dictionary returned has images and labels, which are arrays of dimensions (N, H, W, C) with N: number of images; H, W: spatial dimensions and C: number of channels. 
Args: imgs_train: 4D array (NHWC) with images for training. imgs_test: 4D array (NHWC) with images for testing. config: Configuration of image data set to read. transf: Operator for blurring or other non-trivial transformations. Default: ``None``. Returns: tuple: A tuple (train_ds, test_ds) containing: - **train_ds** : Dictionary of training data (includes images and labels). - **test_ds** : Dictionary of testing data (includes images and labels). """ # Preprocess images by converting to gray scale or sampling multiple # patches according to specified configuration. S_train = preprocess_images( imgs_train, config["output_size"], gray_flag=config["run_gray"], num_img=config["num_img"], multi_flag=config["multi"], stride=config["stride"], ) S_test = preprocess_images( imgs_test, config["output_size"], gray_flag=config["run_gray"], num_img=config["test_num_img"], multi_flag=config["multi"], stride=config["stride"], ) # Check for transformation tsfm: Optional[Callable] = None # Processing: add noise or blur or etc. 
if config["data_mode"] == "dn": # Denoise problem tsfm = RandomNoise(config["noise_level"], config["noise_range"]) elif config["data_mode"] == "dcnv": # Deconvolution problem assert transf is not None tsfm = transf if config["augment"]: # Augment training data set by flip and 90 degrees rotation strain1 = rotation90(S_train.copy()) strain2 = flip(S_train.copy()) S_train = np.concatenate((S_train, strain1, strain2), axis=0) # Processing: apply transformation if tsfm is not None: if config["data_mode"] == "dn": Stsfm_train = tsfm(S_train.copy()) Stsfm_test = tsfm(S_test.copy()) elif config["data_mode"] == "dcnv": tsfm2 = RandomNoise(config["noise_level"], config["noise_range"]) Stsfm_train = tsfm2(tsfm(S_train.copy())) Stsfm_test = tsfm2(tsfm(S_test.copy())) # Shuffle data rng = np.random.default_rng(config["seed"]) perm_tr = rng.permutation(Stsfm_train.shape[0]) perm_tt = rng.permutation(Stsfm_test.shape[0]) train_ds: DataSetDict = {"image": Stsfm_train[perm_tr], "label": S_train[perm_tr]} test_ds: DataSetDict = {"image": Stsfm_test[perm_tt], "label": S_test[perm_tt]} return train_ds, test_ds def images_read(path: str, ext: str = "jpg") -> Array: # pragma: no cover """Read a collection of color images from a set of files. Read a collection of color images from a set of files in the specified directory. All files with extension `ext` (i.e. matching glob `*.ext`) in directory `path` are assumed to be image files and are read. Images may have different aspect ratios, therefore, they are transposed to keep the aspect ratio of the first image read. Args: path: Path to directory containing the image files. ext: Filename extension. Returns: Collection of color images as a 4D array. """ slices = [] shape = None for file in sorted(glob.glob(os.path.join(path, "*." 
+ ext))): image = iio.imread(file) if shape is None: shape = image.shape[:2] if shape != image.shape[:2]: image = np.transpose(image, (1, 0, 2)) slices.append(image) return np.stack(slices) def get_bsds_data(path: str, verbose: bool = False): # pragma: no cover """Download BSDS500 data from the BSDB project. Download the BSDS500 dataset, a set of 500 color images of size 481x321 or 321x481, from the Berkeley Segmentation Dataset and Benchmark project. The downloaded data is converted to `.npz` format for convenient access via :func:`numpy.load`. The converted data is saved in a file `bsds500.npz` in the directory specified by `path`. Note that train and test folders are merged to get a set of 400 images for training while the val folder is reserved as a set of 100 images for testing. This is done in multiple works such as :cite:`zhang-2017-dncnn`. Args: path: Directory in which converted data is saved. verbose: Flag indicating whether to print status messages. """ # data source URL and filenames data_base_url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/" data_tar_file = "BSR_bsds500.tgz" # ensure path directory exists if not os.path.isdir(path): raise ValueError(f"Path {path} does not exist or is not a directory.") # create temporary directory temp_dir = tempfile.TemporaryDirectory() if verbose: print(f"Downloading {data_tar_file} from {data_base_url}") data = util.url_get(data_base_url + data_tar_file) f = open(os.path.join(temp_dir.name, data_tar_file), "wb") f.write(data.read()) f.close() if verbose: print("Download complete") # untar downloaded data into temporary directory if verbose: print(f"Extracting content from tar file {data_tar_file}") with tarfile.open(os.path.join(temp_dir.name, data_tar_file), "r") as tar_ref: tar_ref.extractall(temp_dir.name) # read untared data files into 4D arrays and save as .npz data_path = os.path.join("BSR", "BSDS500", "data", "images") train_path = os.path.join(data_path, "train") imgs_train = 
images_read(os.path.join(temp_dir.name, train_path)) val_path = os.path.join(data_path, "val") imgs_val = images_read(os.path.join(temp_dir.name, val_path)) test_path = os.path.join(data_path, "test") imgs_test = images_read(os.path.join(temp_dir.name, test_path)) # Train and test data merge into train. # Leave val data for testing. imgs400 = np.vstack([imgs_train, imgs_test]) if verbose: print(f"Read {imgs400.shape[0]} images for training") print(f"Read {imgs_val.shape[0]} images for testing") npz_file = os.path.join(path, "bsds500.npz") if verbose: subpath = str.split(npz_file, ".cache") npz_file_display = "~/.cache" + subpath[-1] print(f"Saving as {npz_file_display}") np.savez(npz_file, imgstr=imgs400, imgstt=imgs_val) def build_blur_kernel( kernel_size: Shape, blur_sigma: float, dtype: Any = np.float32, ): """Construct a blur kernel as specified. Args: kernel_size: Size of the blur kernel. blur_sigma: Standard deviation of the blur kernel. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ kernel = 1.0 meshgrids = np.meshgrid(*[np.arange(size, dtype=dtype) for size in kernel_size]) for size, mgrid in zip(kernel_size, meshgrids): mean = (size - 1) / 2 kernel *= np.exp(-(((mgrid - mean) / blur_sigma) ** 2) / 2) # Make sure norm of values in gaussian kernel equals 1. knorm = np.sqrt(np.sum(kernel * kernel)) kernel = kernel / knorm return kernel class PaddedCircularConvolve(LinearOperator): """Define padded convolutional operator. The operator pads the signal with a reflection of the borders before convolving with the kernel provided at initialization. It crops the result of the convolution to maintain the same signal size. """ def __init__( self, output_size: Union[Shape, int], channels: int, kernel_size: Union[Shape, int], blur_sigma: float, dtype: Any = np.float32, ): """ Args: output_size: Size of the image to blur. channels: Number of channels in image to blur. kernel_size: Size of the blur kernel. blur_sigma: Standard deviation of the blur kernel. 
dtype: Output dtype. Default: :attr:`~numpy.float32`. """ if isinstance(output_size, int): output_size = (output_size, output_size) else: assert len(output_size) == 2 if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) else: assert len(kernel_size) == 2 # Define padding. self.padsz = ( (kernel_size[0] // 2, kernel_size[0] // 2), (kernel_size[1] // 2, kernel_size[1] // 2), (0, 0), ) shape = (output_size[0], output_size[1], channels) with_pad = ( output_size[0] + self.padsz[0][0] + self.padsz[0][1], output_size[1] + self.padsz[1][0] + self.padsz[1][1], ) shape_padded = (with_pad[0], with_pad[1], channels) # Define data types. input_dtype = dtype output_dtype = dtype # Construct blur kernel as specified. kernel = build_blur_kernel(kernel_size, blur_sigma) # Define convolution part. self.conv = CircularConvolve(kernel, input_shape=shape_padded, ndims=2, input_dtype=dtype) # Initialize Linear Operator. super().__init__( input_shape=shape, output_shape=shape, input_dtype=input_dtype, output_dtype=output_dtype, jit=True, ) def _eval(self, x: Array) -> Array: """Apply operator. Args: x: The array with input signal. The input to the constructed operator should be HWC with H and W spatial dimensions given by `output_size` and C the given `channels`. Returns: The result of padding, convolving and cropping the signal. The output signal has the same HWC dimensions as the input signal. """ xpadd: Array = jnp.pad(x, self.padsz, mode="reflect") rconv: Array = self.conv(xpadd) return rconv[self.padsz[0][0] : -self.padsz[0][1], self.padsz[1][0] : -self.padsz[1][1], :] ================================================ FILE: scico/flax/examples/examples.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Generation and loading of data used in Flax example scripts.""" import os from typing import Callable, Optional, Tuple, Union import numpy as np from scico.flax.train.typed_dict import DataSetDict from scico.numpy import Array from scico.typing import Shape from .data_generation import generate_blur_data, generate_ct_data from .data_preprocessing import ConfigImageSetDict, build_image_dataset, get_bsds_data from .typed_dict import CTDataSetDict def get_cache_path(cache_path: Optional[str] = None) -> Tuple[str, str]: """Get input/output SCICO cache path. Args: cache_path: Given cache path. If ``None`` SCICO default cache path is constructed. Returns: The cache path and a display string with private user path information stripped. """ if cache_path is None: cache_path = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "data") subpath = str.split(cache_path, ".cache") cache_path_display = "~/.cache" + subpath[-1] else: cache_path_display = cache_path return cache_path, cache_path_display def load_ct_data( train_nimg: int, test_nimg: int, size: int, nproj: int, cache_path: Optional[str] = None, verbose: bool = False, ) -> Tuple[CTDataSetDict, ...]: # pragma: no cover """ Load or generate CT data. Load or generate CT data for training of machine learning network models. If cached file exists and enough data of the requested size is available, data is loaded and returned. If either `size` or `nproj` requested does not match the data read from the cached file, a `RunTimeError` is generated. If no cached file is found or not enough data is contained in the file a new data set is generated and stored in `cache_path`. The data is stored in `.npz` format for convenient access via :func:`numpy.load`. The data is saved in two distinct files: `ct_foam2_train.npz` and `ct_foam2_test.npz` to keep separated training and testing partitions. 
Args: train_nimg: Number of images required for training. test_nimg: Number of images required for testing. size: Size of reconstruction images. nproj: Number of CT views. cache_path: Directory in which generated data is saved. Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. Returns: tuple: A tuple (trdt, ttdt) containing: - **trdt** : (Dictionary): Collection of images (key `img`), sinograms (key `sino`) and filtered back projections (key `fbp`) for training. - **ttdt** : (Dictionary): Collection of images (key `img`), sinograms (key `sino`) and filtered back projections (key `fbp`) for testing. """ # Set default cache path if not specified cache_path, cache_path_display = get_cache_path(cache_path) # Create cache directory and generate data if not already present. npz_train_file = os.path.join(cache_path, "ct_foam2_train.npz") npz_test_file = os.path.join(cache_path, "ct_foam2_test.npz") if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file): # Load data trdt_in = np.load(npz_train_file) ttdt_in = np.load(npz_test_file) # Check image size if trdt_in["img"].shape[1] != size: runtime_error_scalar("size", "training", size, trdt_in["img"].shape[1]) if ttdt_in["img"].shape[1] != size: runtime_error_scalar("size", "testing", size, ttdt_in["img"].shape[1]) # Check number of projections if trdt_in["sino"].shape[1] != nproj: runtime_error_scalar("views", "training", nproj, trdt_in["sino"].shape[1]) if ttdt_in["sino"].shape[1] != nproj: runtime_error_scalar("views", "testing", nproj, ttdt_in["sino"].shape[1]) # Check that enough data is available if trdt_in["img"].shape[0] >= train_nimg: if ttdt_in["img"].shape[0] >= test_nimg: trdt: CTDataSetDict = { "img": trdt_in["img"][:train_nimg], "sino": trdt_in["sino"][:train_nimg], "fbp": trdt_in["fbp"][:train_nimg], } ttdt: CTDataSetDict = { "img": ttdt_in["img"][:test_nimg], "sino": ttdt_in["sino"][:test_nimg], "fbp": ttdt_in["fbp"][:test_nimg], } if verbose: 
print_input_path(cache_path_display) print_data_size("training", trdt["img"].shape[0]) print_data_size("testing ", ttdt["img"].shape[0]) print_data_range("images ", trdt["img"]) print_data_range("sinogram", trdt["sino"]) print_data_range("FBP ", trdt["fbp"]) return trdt, ttdt elif verbose: print_data_warning("testing", test_nimg, ttdt_in["img"].shape[0]) elif verbose: print_data_warning("training", train_nimg, trdt_in["img"].shape[0]) # Generate new data. nimg = train_nimg + test_nimg img, sino, fbp = generate_ct_data( nimg, size, nproj, verbose=verbose, ) # Separate training and testing partitions. trdt = {"img": img[:train_nimg], "sino": sino[:train_nimg], "fbp": fbp[:train_nimg]} ttdt = {"img": img[train_nimg:], "sino": sino[train_nimg:], "fbp": fbp[train_nimg:]} # Store images, sinograms and filtered back-projections. os.makedirs(cache_path, exist_ok=True) np.savez( npz_train_file, img=img[:train_nimg], sino=sino[:train_nimg], fbp=fbp[:train_nimg], ) np.savez( npz_test_file, img=img[train_nimg:], sino=sino[train_nimg:], fbp=fbp[train_nimg:], ) if verbose: print_output_path(cache_path_display) print_data_size("training", train_nimg) print_data_size("testing ", test_nimg) print_data_range("images ", img) print_data_range("sinogram", sino) print_data_range("FBP ", fbp) return trdt, ttdt def load_blur_data( train_nimg: int, test_nimg: int, size: int, blur_kernel: Array, noise_sigma: float, cache_path: Optional[str] = None, verbose: bool = False, ) -> Tuple[DataSetDict, ...]: # pragma: no cover """Load or generate blurred data based on xdesign foam structures. Load or generate blurred data for training of machine learning network models. If cached file exists and enough data of the requested size is available, data is loaded and returned. If `size`, `blur_kernel` or `noise_sigma` requested do not match the data read from the cached file, a `RunTimeError` is generated. 
If no cached file is found or not enough data is contained in the file a new data set is generated and stored in `cache_path`. The data is stored in `.npz` format for convenient access via :func:`numpy.load`. The data is saved in two distinct files: `dcnv_foam1_train.npz` and `dcnv_foam1_test.npz` to keep separated training and testing partitions. Args: train_nimg: Number of images required for training. test_nimg: Number of images required for testing. size: Size of reconstruction images. blur_kernel: Kernel for blurring the generated images. noise_sigma: Level of additive Gaussian noise to apply. cache_path: Directory in which generated data is saved. Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. Returns: tuple: A tuple (train_ds, test_ds) containing: - **train_ds** : Dictionary of training data (includes images and labels). - **test_ds** : Dictionary of testing data (includes images and labels). """ # Set default cache path if not specified cache_path, cache_path_display = get_cache_path(cache_path) # Create cache directory and generate data if not already present. npz_train_file = os.path.join(cache_path, "dcnv_foam1_train.npz") npz_test_file = os.path.join(cache_path, "dcnv_foam1_test.npz") if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file): # Load data and convert arrays to float32. 
trdt = np.load(npz_train_file) # Training ttdt = np.load(npz_test_file) # Testing train_in = trdt["image"].astype(np.float32) train_out = trdt["label"].astype(np.float32) test_in = ttdt["image"].astype(np.float32) test_out = ttdt["label"].astype(np.float32) # Check image size if train_in.shape[1] != size: runtime_error_scalar("size", "training", size, train_in.shape[1]) if test_in.shape[1] != size: runtime_error_scalar("size", "testing ", size, test_in.shape[1]) # Check noise_sigma if trdt["noise"] != noise_sigma: runtime_error_scalar("noise", "training", noise_sigma, trdt["noise"]) if ttdt["noise"] != noise_sigma: runtime_error_scalar("noise", "testing ", noise_sigma, ttdt["noise"]) # Check blur kernel blur_train = trdt["blur"].astype(np.float32) if not np.allclose(blur_kernel, blur_train): runtime_error_array("blur", "testing ", np.abs(blur_kernel - blur_train).max()) blur_test = ttdt["blur"].astype(np.float32) if not np.allclose(blur_kernel, blur_test): runtime_error_array("blur", "testing ", np.abs(blur_kernel - blur_test).max()) # Check that enough images were restored. if trdt["numimg"] >= train_nimg: if ttdt["numimg"] >= test_nimg: train_ds: DataSetDict = { "image": train_in, "label": train_out, } test_ds: DataSetDict = { "image": test_in, "label": test_out, } if verbose: print_info( "in", cache_path_display, train_ds["image"], train_ds["label"], test_ds["image"].shape[0], ) return train_ds, test_ds elif verbose: print_data_warning("testing ", test_nimg, ttdt["numimg"]) elif verbose: print_data_warning("training", train_nimg, trdt["numimg"]) # Generate new data. nimg = train_nimg + test_nimg img, blrn = generate_blur_data( nimg, size, blur_kernel, noise_sigma, verbose=verbose, ) # Separate training and testing partitions. train_ds = {"image": blrn[:train_nimg], "label": img[:train_nimg]} test_ds = {"image": blrn[train_nimg:], "label": img[train_nimg:]} # Store original and blurred images. 
os.makedirs(cache_path, exist_ok=True) np.savez( npz_train_file, image=train_ds["image"], label=train_ds["label"], numimg=train_nimg, noise=noise_sigma, blur=blur_kernel.astype(np.float32), ) np.savez( npz_test_file, image=test_ds["image"], label=test_ds["label"], numimg=test_nimg, noise=noise_sigma, blur=blur_kernel.astype(np.float32), ) if verbose: print_info( "out", cache_path_display, train_ds["image"], train_ds["label"], test_ds["image"].shape[0], ) return train_ds, test_ds def load_image_data( train_nimg: int, test_nimg: int, size: int, gray_flag: bool, data_mode: str = "dn", cache_path: Optional[str] = None, verbose: bool = False, noise_level: float = 0.1, noise_range: bool = False, transf: Optional[Callable] = None, stride: Optional[int] = None, augment: bool = False, ) -> Tuple[DataSetDict, ...]: # pragma: no cover """Load or load and preprocess image data. Load or load and preprocess image data for training of neural network models. The original source is the BSDS500 data from the Berkeley Segmentation Dataset and Benchmark project. Depending on the intended applications, different preprocessings can be performed to the source data. If a cached file exists, and enough images were sampled, data is loaded and returned. If either `size` or type of data (gray scale or color) requested does not match the data read from the cached file, a `RunTimeError` is generated. In contrast, there is no checking for the specific contamination (i.e. noise level, blur kernel, etc.). If no cached file is found or not enough images were sampled and stored in the file, a new data set is generated and stored in `cache_path`. The data is stored in `.npz` format for convenient access via :func:`numpy.load`. The data is saved in two distinct files: `*_bsds_train.npz` and `*_bsds_test.npz` to keep separated training and testing partitions. The * stands for `dn` if denoising problem or `dcnv` if deconvolution problem. 
Other types of pre-processings may be specified via the `transf` operator. Args: train_nimg: Number of images required for sampling training data. test_nimg: Number of images required for sampling testing data. size: Size of reconstruction images. gray_flag: Flag to indicate if gray scale images or color images. When ``True`` gray scale images are used. data_mode: Type of image problem. Options are: `dn` for denosing, `dcnv` for deconvolution. cache_path: Directory in which processed data is saved. Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. noise_level: Standard deviation of the Gaussian noise. noise_range: Flag to indicate if a fixed or a random standard deviation must be used. Default: ``False`` i.e. fixed standard deviation given by `noise_level`. transf: Operator for blurring or other non-trivial transformations. Should be able to handle batched (NHWC) data. Default: ``None``. stride: Stride between patch origins (indexed from left-top corner). Default: 0 (i.e. no stride, only one patch per image). augment: Augment training data set by flip and 90 degrees rotation. Default: ``False`` (i.e. no augmentation). Returns: tuple: A tuple (train_ds, test_ds) containing: - **train_ds** : (DataSetDict): Dictionary of training data (includes images and labels). - **test_ds** : (DataSetDict): Dictionary of testing data (includes images and labels). """ # Set default cache path if not specified cache_path, cache_path_display = get_cache_path(cache_path) # Create cache directory and generate data if not already present. npz_train_file = os.path.join(cache_path, data_mode + "_bsds_train.npz") npz_test_file = os.path.join(cache_path, data_mode + "_bsds_test.npz") if os.path.isfile(npz_train_file) and os.path.isfile(npz_test_file): # Load data and convert arrays to float32. 
trdt = np.load(npz_train_file) # Training ttdt = np.load(npz_test_file) # Testing train_in = trdt["image"].astype(np.float32) train_out = trdt["label"].astype(np.float32) test_in = ttdt["image"].astype(np.float32) test_out = ttdt["label"].astype(np.float32) if check_img_data_requirements( train_nimg, test_nimg, size, gray_flag, train_in.shape, test_in.shape, trdt["numimg"], ttdt["numimg"], verbose, ): train_ds: DataSetDict = { "image": train_in, "label": train_out, } test_ds: DataSetDict = { "image": test_in, "label": test_out, } if verbose: print_info( "in", cache_path_display, train_ds["image"], train_ds["label"], test_ds["image"].shape[0], ) print( "NOTE: If blur kernel or noise parameter are changed, the cache " "must be manually\n deleted to ensure that the training data" " is regenerated with the new\n parameters." ) return train_ds, test_ds # Check if BSDS folder exists if not create and download BSDS data. bsds_cache_path = os.path.join(cache_path, "BSDS") if not os.path.isdir(bsds_cache_path): os.makedirs(bsds_cache_path) get_bsds_data(path=bsds_cache_path, verbose=verbose) # Load data, convert arrays to float32 and return # after pre-processing for specified data_mode. npz_file = os.path.join(bsds_cache_path, "bsds500.npz") npz = np.load(npz_file) imgs_train = npz["imgstr"].astype(np.float32) imgs_test = npz["imgstt"].astype(np.float32) # Generate new data. if stride is None: multi = False else: multi = True config: ConfigImageSetDict = { "output_size": size, "stride": stride, "multi": multi, "augment": augment, "run_gray": gray_flag, "num_img": train_nimg, "test_num_img": test_nimg, "data_mode": data_mode, "noise_level": noise_level, "noise_range": noise_range, "test_split": 0.2, "seed": 1234, } train_ds, test_ds = build_image_dataset(imgs_train, imgs_test, config, transf) # Store generated images. 
os.makedirs(cache_path, exist_ok=True) np.savez( npz_train_file, image=train_ds["image"], label=train_ds["label"], numimg=train_nimg, ) np.savez( npz_test_file, image=test_ds["image"], label=test_ds["label"], numimg=test_nimg, ) if verbose: print_info( "out", cache_path_display, train_ds["image"], train_ds["label"], test_ds["image"].shape[0], ) return train_ds, test_ds def check_img_data_requirements( train_nimg: int, test_nimg: int, size: int, gray_flag: bool, train_in_shp: Shape, test_in_shp: Shape, train_nimg_avail: int, test_nimg_avail: int, verbose: bool, ) -> bool: # pragma: no cover """Check data loaded with respect to data requirements. Args: train_nimg: Number of images required for training data. test_nimg: Number of images required for testing data. size: Size of images requested. gray_flag: Flag to indicate if gray scale images or color images are requested. When ``True`` gray scale images are used, therefore, one channel is expected. train_in_shp: Shape of images/patches loaded as training data. test_in_shp: Shape of images/patches loaded as testing data. train_nimg_avail: Number of images available in loaded training image data. test_nimg_avail: Number of images available in loaded testing image data. verbose: Flag indicating whether to print status messages. Returns: ``True`` if the loaded image data satifies requirements of size, number of samples and number of channels and ``False`` otherwise. """ # Check image size if train_in_shp[1] != size: runtime_error_scalar("size", "training", size, train_in_shp[1]) if test_in_shp[1] != size: runtime_error_scalar("size", "testing ", size, test_in_shp[1]) # Check gray scale or color images. C_train = train_in_shp[-1] C_test = test_in_shp[-1] if gray_flag: C = 1 else: C = 3 if C_train != C: runtime_error_scalar("channels", "training", C, C_train) if C_test != C: runtime_error_scalar("channels", "testing ", C, C_test) # Check that enough images were sampled. 
if train_nimg_avail >= train_nimg: if test_nimg_avail >= test_nimg: return True elif verbose: print_data_warning("testing ", test_nimg, test_nimg_avail) elif verbose: print_data_warning("training", train_nimg, train_nimg_avail) return False def print_input_path(path_display: str): # pragma: no cover """Display path from where data is being loaded. Args: path_display: Path for loading data. """ print(f"Data read from path: {path_display}") def print_output_path(path_display: str): # pragma: no cover """Display path where data is being stored. Args: path_display: Path for storing data. """ print(f"Storing data in path: {path_display}") def print_data_range(idstring: str, data: Array): # pragma: no cover """Display min and max values of given data array. Args: idstring: Data descriptive string. data: Array to compute min and max. """ print(f"Data range --{idstring}-- Min: {data.min():>5.2f} " f"Max: {data.max():>5.2f}") def print_data_size(idstring: str, size: int): # pragma: no cover """Display integer given. Args: idstring: Data descriptive string. size: Integer representing size of a set. """ print(f"Set --{idstring}-- size: {size}") def print_info( iomode: str, path_display: str, train_in: Array, train_out: Array, test_size: int ): # pragma: no cover """Display information related to data input/output. Args: iomode: Identification of input (load) or ouput (save) operation. path_display: Input or output path. train_in: Input features in training set. train_out: Outputs in training set. test_size: Size of testing set. """ if iomode == "in": print_input_path(path_display) else: print_output_path(path_display) print_data_size("training", train_in.shape[0]) print_data_size("testing ", test_size) print_data_range(" images ", train_in) print_data_range(" labels ", train_out) def print_data_warning(idstring: str, requested: int, available: int): # pragma: no cover """Display warning related to data size demands not satisfied. Args: idstring: Data descriptive string. 
requested: Size of data set requested. available: Size of data set available. """ print( f"Not enough images sampled in {idstring} file. " f"Requested: {requested} Available: {available}" ) def runtime_error_scalar( type: str, idstring: str, requested: Union[int, float], available: Union[int, float] ): """Raise run time error related to unsatisfied scalar parameter request. Raise run time error related to scalar parameter request not satisfied in available data. Args: type: Type of parameter in the request. idstring: Data descriptive string. requested: Parameter value requested. available: Parameter value available in data. """ raise RuntimeError( f"Requested value of argument '{type}' does not match value " f"read from {idstring} file. Requested: {requested} Available: " f"{available}.\nDelete cache and check data source." ) def runtime_error_array(type: str, idstring: str, maxdiff: float): """Raise run time error related to unsatisfied array parameter request. Raise run time error related to array parameter request not satisfied in available data. Args: type: Type of parameter in the request. idstring: Data descriptive string. maxdiff: Maximum error between requested and available array entries. """ raise RuntimeError( f"Requested value of argument '{type}' does not match value " f"read from {idstring} file. Maximum array difference: " f"{maxdiff:>5.3f}.\nDelete cache and check data source." ) ================================================ FILE: scico/flax/examples/typed_dict.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Definition of typed dictionaries for training data.""" import sys from typing import Optional, Union if sys.version_info >= (3, 8): from typing import TypedDict # pylint: disable=no-name-in-module else: from typing_extensions import TypedDict from scico.numpy import Array from scico.typing import Shape class CTDataSetDict(TypedDict): """Definition of the structure to store generated CT data.""" img: Array # original image sino: Array # sinogram fbp: Array # filtered back projection class ConfigImageSetDict(TypedDict): """Definition of the configuration for image data preprocessing.""" output_size: Union[int, Shape] stride: Optional[Union[Shape, int]] multi: bool augment: bool run_gray: bool num_img: int test_num_img: int data_mode: str noise_level: float noise_range: bool test_split: float seed: float ================================================ FILE: scico/flax/inverse.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Flax implementation of different imaging inversion models.""" import warnings warnings.simplefilter(action="ignore", category=FutureWarning) from functools import partial from typing import Any, Callable, Tuple import jax.numpy as jnp from jax import jit, lax, random from flax.core import Scope # noqa from flax.linen.module import _Sentinel # noqa from flax.linen.module import Module, compact from scico.flax import ResNet from scico.linop import LinearOperator from scico.numpy import Array from scico.typing import DType, PRNGKey, Shape # The imports of Scope and _Sentinel (above) are required to silence # "cannot resolve forward reference" warnings when building sphinx api # docs. 
ModuleDef = Any class MoDLNet(Module): """Flax implementation of MoDL :cite:`aggarwal-2019-modl`. Flax implementation of the model-based deep learning (MoDL) architecture for inverse problems described in :cite:`aggarwal-2019-modl`. Args: operator: Operator for computing forward and adjoint mappings. depth: Depth of MoDL net. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. block_depth: Number of layers in the computational block. kernel_size: Size of the convolution filters. strides: Convolution strides. lmbda_ini: Initial value of the regularization weight `lambda`. dtype: Output dtype. Default: :attr:`~numpy.float32`. cg_iter: Number of iterations for cg solver. """ operator: ModuleDef depth: int channels: int num_filters: int block_depth: int kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) lmbda_ini: float = 0.5 dtype: Any = jnp.float32 cg_iter: int = 10 @compact def __call__(self, y: Array, train: bool = True) -> Array: """Apply MoDL net for inversion. Args: y: The array with signal to invert. train: Flag to differentiate between training and testing stages. Returns: The reconstructed signal. 
""" def lmbda_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array: return jnp.ones(shape, dtype) * self.lmbda_ini lmbda = self.param("lmbda", lmbda_init_wrap, (1,)) resnet = ResNet( self.block_depth, self.channels, self.num_filters, self.kernel_size, self.strides, dtype=self.dtype, ) ah_f = lambda v: jnp.atleast_3d(self.operator.adj(v.reshape(self.operator.output_shape))) Ahb = lax.map(ah_f, y) x = Ahb ahaI_f = lambda v: self.operator.adj(self.operator(v)) + lmbda * v cgsol = lambda b: jnp.atleast_3d( cg_solver(ahaI_f, b.reshape(self.operator.input_shape), maxiter=self.cg_iter) ) for i in range(self.depth): z = resnet(x, train) # Solve: (AH A + lmbda I) x = Ahb + lmbda * z b = Ahb + lmbda * z x = lax.map(cgsol, b) return x def cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Array: r"""Conjugate gradient solver. Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is positive definite, via the conjugate gradient method. This is a light version constructed to be differentiable with the autograd functionality from jax. Therefore, (i) it uses :meth:`jax.lax.scan` to execute a fixed number of iterations and (ii) it assumes that the linear operator may use :meth:`jax.pure_callback`. Due to the utilization of a while cycle, :meth:`scico.cg` is not differentiable by jax and :meth:`jax.scipy.sparse.linalg.cg` does not support functions using :meth:`jax.pure_callback`, which is why an additional conjugate gradient function has been implemented. Args: A: Function implementing linear operator :math:`A`, should be positive definite. b: Input array :math:`\mb{b}`. x0: Initial solution. maxiter: Maximum iterations. Returns: x: Solution array. 
""" def fun(carry, _): """Function implementing one iteration of the conjugate gradient solver.""" x, r, p, num = carry Ap = A(p) alpha = num / (p.ravel().conj().T @ Ap.ravel()) x = x + alpha * p r = r - alpha * Ap num_old = num num = r.ravel().conj().T @ r.ravel() beta = num / num_old p = r + beta * p return (x, r, p, num), None if x0 is None: x0 = jnp.zeros_like(b) r0 = b - A(x0) num0 = r0.ravel().conj().T @ r0.ravel() carry = (x0, r0, r0, num0) carry, _ = lax.scan(fun, carry, xs=None, length=maxiter) return carry[0] class ODPProxDnBlock(Module): """Flax implementation of ODP proximal gradient denoise block. Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for denoising :cite:`diamond-2018-odp`. Args: operator: Operator for computing forward and adjoint mappings. In this case it corresponds to the identity operator and is used at the network level. depth: Number of layers in block. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. kernel_size: Size of the convolution filters. strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ operator: ModuleDef depth: int channels: int num_filters: int kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) alpha_ini: float = 0.2 dtype: Any = jnp.float32 def batch_op_adj(self, y: Array) -> Array: """Batch application of adjoint operator.""" return self.operator.adj(y) @compact def __call__(self, x: Array, y: Array, train: bool = True) -> Array: """Apply denoising block. Args: x: The array with current stage of denoised signal. y: The array with noisy signal. train: Flag to differentiate between training and testing stages. Returns: The block output (i.e. next stage of denoised signal). 
""" def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array: return jnp.ones(shape, dtype) * self.alpha_ini alpha = self.param("alpha", alpha_init_wrap, (1,)) resnet = ResNet( self.depth, self.channels, self.num_filters, self.kernel_size, self.strides, dtype=self.dtype, ) x = (resnet(x, train) + y * alpha) / (1.0 + alpha) return x class ODPProxDcnvBlock(Module): """Flax implementation of ODP proximal gradient deconvolution block. Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for deconvolution under Gaussian noise :cite:`diamond-2018-odp`. Args: operator: Operator for computing forward and adjoint mappings. In this case it correponds to a circular convolution operator. depth: Number of layers in block. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. kernel_size: Size of the convolution filters. strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. dtype: Output dtype. Default: :attr:`~numpy.float32`. 
""" operator: ModuleDef depth: int channels: int num_filters: int kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) alpha_ini: float = 0.99 dtype: Any = jnp.float32 def setup(self): """Computing operator norm and setting operator for batch evaluation and defining network layers.""" self.operator_norm = jnp.sqrt(power_iteration(self.operator.H @ self.operator)[0].real) self.ah_f = lambda v: jnp.atleast_3d( self.operator.adj(v.reshape(self.operator.output_shape)) ) self.resnet = ResNet( self.depth, self.channels, self.num_filters, self.kernel_size, self.strides, dtype=self.dtype, ) def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array: return jnp.ones(shape, dtype) * self.alpha_ini self.alpha = self.param("alpha", alpha_init_wrap, (1,)) def batch_op_adj(self, y: Array) -> Array: """Batch application of adjoint operator.""" return lax.map(self.ah_f, y) def __call__(self, x: Array, y: Array, train: bool = True) -> Array: """Apply debluring block. Args: x: The array with current stage of reconstructed signal. y: The array with signal to invert. train: Flag to differentiate between training and testing stages. Returns: The block output (i.e. next stage of reconstructed signal). """ # DFT over spatial dimensions fft_shape: Shape = x.shape[1:-1] fft_axes: Tuple[int, int] = (1, 2) scale = 1.0 / (self.alpha * self.operator_norm**2 + 1) x = jnp.fft.irfftn( jnp.fft.rfftn( self.alpha * self.batch_op_adj(y) + self.resnet(x, train), s=fft_shape, axes=fft_axes, ) / scale, s=fft_shape, axes=fft_axes, ) return x class ODPGrDescBlock(Module): r"""Flax implementation of ODP gradient descent with :math:`\ell_2` loss block. Flax implementation of the unrolled optimization with deep priors (ODP) gradient descent block for inversion using :math:`\ell_2` loss described in :cite:`diamond-2018-odp`. Args: operator: Operator for computing forward and adjoint mappings. 
In this case it corresponds to the identity operator and is used at the network level. depth: Number of layers in block. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. kernel_size: Size of the convolution filters. strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ operator: ModuleDef depth: int channels: int num_filters: int kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) alpha_ini: float = 0.2 dtype: Any = jnp.float32 def setup(self): """Setting operator for batch evaluation and defining network layers.""" self.ah_f = lambda v: jnp.atleast_3d( self.operator.adj(v.reshape(self.operator.output_shape)) ) self.a_f = lambda v: jnp.atleast_3d(self.operator(v.reshape(self.operator.input_shape))) self.resnet = ResNet( self.depth, self.channels, self.num_filters, self.kernel_size, self.strides, dtype=self.dtype, ) def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array: return jnp.ones(shape, dtype) * self.alpha_ini self.alpha = self.param("alpha", alpha_init_wrap, (1,)) def batch_op_adj(self, y: Array) -> Array: """Batch application of adjoint operator.""" return lax.map(self.ah_f, y) def __call__(self, x: Array, y: Array, train: bool = True) -> Array: """Apply gradient descent block. Args: x: The array with current stage of reconstructed signal. y: The array with signal to invert. train: Flag to differentiate between training and testing stages. Returns: The block output (i.e. next stage of inverted signal). """ x = self.resnet(x, train) - self.alpha * self.batch_op_adj(lax.map(self.a_f, x) - y) return x class ODPNet(Module): """Flax implementation of ODP network :cite:`diamond-2018-odp`. 
Flax implementation of the unrolled optimization with deep priors (ODP) network for inverse problems described in :cite:`diamond-2018-odp`. It can be constructed with proximal gradient blocks or gradient descent blocks. Args: operator: Operator for computing forward and adjoint mappings. depth: Depth of MoDL net. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. block_depth: Number of layers in the computational block. kernel_size: Size of the convolution filters. strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. dtype: Output dtype. Default: :attr:`~numpy.float32`. odp_block: processing block to apply. Default :class:`.ODPProxDnBlock`. """ operator: ModuleDef depth: int channels: int num_filters: int block_depth: int kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1) alpha_ini: float = 0.5 dtype: Any = jnp.float32 odp_block: Callable = ODPProxDnBlock @compact def __call__(self, y: Array, train: bool = True) -> Array: """Apply ODP net for inversion. Args: y: The array with signal to invert. train: Flag to differentiate between training and testing stages. Returns: The reconstructed signal. """ block = partial( self.odp_block, operator=self.operator, depth=self.block_depth, channels=self.channels, num_filters=self.num_filters, kernel_size=self.kernel_size, strides=self.strides, dtype=self.dtype, ) # Initial block handles initial inversion. # Not all operators are batch-ready. alpha0_i = self.alpha_ini block0 = block(alpha_ini=alpha0_i) x = block0.batch_op_adj(y) x = block0(x, y, train) alpha0_i /= 2.0 for i in range(self.depth - 1): x = block(alpha_ini=alpha0_i)(x, y, train) alpha0_i /= 2.0 return x @partial(jit, static_argnums=0) def power_iteration(A: LinearOperator, maxiter: int = 100): """Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`. 
Compute largest eigenvalue of a diagonalizable :class:`LinearOperator` using power iteration. This function has the same functionality as :class:`.linop.power_iteration` but is implemented using lax operations to allow jitting and general jax function composition. Args: A: :class:`LinearOperator` used for computation. Must be diagonalizable. maxiter: Maximum number of power iterations to use. Returns: tuple: A tuple (`mu`, `v`) containing: - **mu**: Estimate of largest eigenvalue of `A`. - **v**: Eigenvector of `A` with eigenvalue `mu`. """ key = random.PRNGKey(0) v = random.normal(key, shape=A.input_shape, dtype=A.input_dtype) v = v / jnp.linalg.norm(v) init_val = (0, v, v, 1.0) def cond_fun(val): return jnp.logical_and(val[0] <= maxiter, val[3] > 0.0) def body_fun(val): i, v, Av, normAv = val v = Av / normAv i = i + 1 Av = A @ v normAv = jnp.linalg.norm(Av) return (i, v, Av, normAv) def true_fun(v, Av, normAv): return jnp.sum(v.conj() * Av) / jnp.linalg.norm(v) ** 2, Av / normAv def false_fun(v, Av, normAv): return 0.0 * normAv, Av # Multiplication by zero used to preserve data type i, v, Av, normAv = lax.while_loop(cond_fun, body_fun, init_val) mu, v = lax.cond(normAv > 0.0, true_fun, false_fun, v, Av, normAv) return mu, v ================================================ FILE: scico/flax/train/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for training Flax models.""" ================================================ FILE: scico/flax/train/apply.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
# Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functionality to evaluate Flax trained model.

Uses data parallel evaluation.
"""

from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp

from flax import jax_utils
from scico.flax import create_input_iter
from scico.numpy import Array

from .checkpoints import checkpoint_restore
from .clu_utils import get_parameter_overview
from .learning_rate import create_cnst_lr_schedule
from .state import create_basic_train_state
from .typed_dict import ConfigDict, DataSetDict, ModelVarDict

ModuleDef = Any


def apply_fn(model: ModuleDef, variables: ModelVarDict, batch: DataSetDict) -> Array:
    """Apply current model.

    Assumes sharded batched data and replicated variables for distributed
    processing.

    This function is intended to be used via
    :meth:`~scico.flax.only_apply`, not directly.

    Args:
        model: Flax model to apply.
        variables: State of model parameters (replicated).
        batch: Sharded and batched training data.

    Returns:
        Output computed by given model.
    """
    output = model.apply(variables, batch["image"], train=False, mutable=False)
    return output


def only_apply(
    config: ConfigDict,
    model: ModuleDef,
    test_ds: DataSetDict,
    apply_fn: Callable = apply_fn,
    variables: Optional[ModelVarDict] = None,
) -> Tuple[Array, ModelVarDict]:
    """Execute model application loop.

    Only complete batches are evaluated: if the number of examples in
    `test_ds` is not a multiple of ``config["batch_size"]``, the trailing
    examples that do not fill a batch are not included in the output.

    Args:
        config: Hyperparameter configuration.
        model: Flax model to apply.
        test_ds: Dictionary of testing data (includes images and labels).
        apply_fn: A hook for a function that applies current model.
            Default: :meth:`~scico.flax.train.apply.apply_fn`, i.e. use
            the standard apply function.
        variables: Model parameters to use for evaluation. Default:
            ``None`` (i.e. read from checkpoint).

    Returns:
        Output of model evaluated at the input provided in `test_ds`.

    Raises:
        RuntimeError: If no model variables and no checkpoint are
            specified.
        ValueError: If `test_ds` has fewer examples than one batch.
    """
    if "workdir" in config:
        workdir: str = config["workdir"]
    else:
        workdir = "./"

    if "checkpointing" in config:
        checkpointing: bool = config["checkpointing"]
    else:
        checkpointing = False

    # Configure seed.
    key = jax.random.key(config["seed"])

    if variables is None:
        if checkpointing:  # pragma: no cover
            ishape = test_ds["image"].shape[1:3]
            lr_ = create_cnst_lr_schedule(config)
            empty_state = create_basic_train_state(key, config, model, ishape, lr_)
            state = checkpoint_restore(empty_state, workdir)
            if hasattr(state, "batch_stats"):
                variables = {
                    "params": state.params,
                    "batch_stats": state.batch_stats,
                }  # type: ignore
                print(get_parameter_overview(variables["params"]))
                print(get_parameter_overview(variables["batch_stats"]))
            else:
                variables = {"params": state.params, "batch_stats": {}}
                print(get_parameter_overview(variables["params"]))
        else:
            raise RuntimeError("No variables or checkpoint provided.")

    # Number of complete batches; an incomplete trailing batch is skipped.
    num_examples = test_ds["image"].shape[0]
    steps_ = num_examples // config["batch_size"]
    if steps_ == 0:
        # Without this check the final reshape of an empty output fails
        # with an obscure error.
        raise ValueError(
            f"Number of test examples ({num_examples}) is smaller than the "
            f"batch size ({config['batch_size']})."
        )

    # For distributed testing
    local_batch_size = config["batch_size"] // jax.process_count()
    size_device_prefetch = 2  # Set for GPU
    # Set data iterator
    eval_dt_iter = create_input_iter(
        key,  # eval: no permutation
        test_ds,
        local_batch_size,
        size_device_prefetch,
        model.dtype,
        train=False,
    )
    p_apply_step = jax.pmap(apply_fn, axis_name="batch", static_broadcasted_argnums=0)

    # Evaluate model with provided variables
    variables = jax_utils.replicate(variables)
    output_lst = []
    for _ in range(steps_):
        eval_batch = next(eval_dt_iter)
        output_batch = p_apply_step(model, variables, eval_batch)
        # Merge the device and per-device batch axes.
        output_lst.append(output_batch.reshape((-1,) + output_batch.shape[-3:]))

    # Allow for completing the async run
    jax.random.normal(jax.random.key(0), ()).block_until_ready()

    # Extract one copy of variables
    variables = jax_utils.unreplicate(variables)
    # Convert to array
    output = jnp.array(output_lst)
    # Remove leading dimension
    output = output.reshape((-1,) + output.shape[-3:])

    return output, variables  # type: ignore
================================================ FILE: scico/flax/train/checkpoints.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for checkpointing Flax models.""" import logging from pathlib import Path from typing import Union try: import orbax.checkpoint as ocp have_orbax = True if not hasattr(ocp, "CheckpointManager") or not hasattr(ocp, "checkpoint_managers"): have_orbax = False except ImportError: have_orbax = False if have_orbax: from orbax.checkpoint.checkpoint_managers import LatestN logging.getLogger("absl").addFilter(logging.Filter("could not be identified as a temporary")) # remove the handler that orbax.checkpoint adds to the root logger. # see https://github.com/google/orbax/issues/1951 for h in logging.root.handlers.copy(): h.close() logging.root.removeHandler(h) from .state import TrainState from .typed_dict import ConfigDict def checkpoint_restore( state: TrainState, workdir: Union[str, Path], ok_no_ckpt: bool = False ) -> TrainState: """Load model and optimiser state. Args: state: Flax train state which includes model and optimiser parameters. workdir: Checkpoint file or directory of checkpoints to restore from. ok_no_ckpt: Flag to indicate if a checkpoint is expected. If ``False``, an error is generated if a checkpoint is not found. Returns: A restored Flax train state updated from checkpoint file is returned. If no checkpoint files are present and checkpoints are not strictly expected it returns the passed-in `state` unchanged. Raises: FileNotFoundError: If a checkpoint is expected and is not found. 
""" if not have_orbax: raise RuntimeError("Package orbax.checkpoint is required for use of this function.") # Check if workdir is Path or convert to Path workdir_ = workdir if isinstance(workdir_, str): workdir_ = Path(workdir_) if workdir_.exists(): mngr = ocp.CheckpointManager( workdir_, ) step = mngr.latest_step() if step is not None: restored = mngr.restore( step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state)) ) mngr.wait_until_finished() mngr.close() state = restored.state elif not ok_no_ckpt: raise FileNotFoundError("Could not read from checkpoint: " + str(workdir) + ".") return state def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, Path]): """Store model, model configuration, and optimiser state. Note that naming is slightly different to distinguish from Flax functions. Args: state: Flax train state which includes model and optimiser parameters. config: Python dictionary including model train configuration. workdir: Path in which to store checkpoint files. """ if not have_orbax: raise RuntimeError("Package orbax.checkpoint is required for use of this function.") # Check if workdir is Path or convert to Path workdir_ = workdir if isinstance(workdir_, str): workdir_ = Path(workdir_) options = ocp.CheckpointManagerOptions(preservation_policy=LatestN(3), create=True) mngr = ocp.CheckpointManager( workdir_, options=options, ) step = int(state.step) # Remove non-serializable partial functools in post_lst if it exists config_ = config.copy() if "post_lst" in config_: config_.pop("post_lst", None) # type: ignore mngr.save( step, args=ocp.args.Composite( state=ocp.args.StandardSave(state), config=ocp.args.JsonSave(config_), ), ) mngr.wait_until_finished() mngr.close() ================================================ FILE: scico/flax/train/clu_utils.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. 
# This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for displaying Flax models.""" # These utilities have been copied from the Common Loop Utils (CLU) # https://github.com/google/CommonLoopUtils/tree/main/clu # and have been modified to remove TensorFlow dependencies # CLU is licensed under the Apache License, Version 2.0, which may # be obtained from # http://www.apache.org/licenses/LICENSE-2.0 import warnings warnings.simplefilter(action="ignore", category=FutureWarning) import dataclasses from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import jax import flax PyTree = Any ParamsContainer = Union[Dict[str, np.ndarray], Mapping[str, Mapping[str, Any]]] @dataclasses.dataclass class ParamRow: """Definition of the structure of a row for printing parameters without stats.""" name: str shape: Tuple[int] size: int @dataclasses.dataclass class ParamRowWithStats(ParamRow): """Definition of the structure of a row for printing parameters with stats.""" mean: float std: float def flatten_dict( input_dict: Dict[str, Any], prefix: str = "", delimiter: str = "/" ) -> Dict[str, Any]: """Flatten keys of a nested dictionary. Args: input_dict: Nested dictionary. prefix: Prefix of already flatten. Default: empty string. delimiter: Delimiter for displaying. Default: ``/``. Returns: A dictionary with the keys flattened. """ output_dict = {} for key, value in input_dict.items(): nested_key = f"{prefix}{delimiter}{key}" if prefix else key if isinstance(value, (dict, flax.core.FrozenDict)): output_dict.update(flatten_dict(value, prefix=nested_key, delimiter=delimiter)) else: output_dict[nested_key] = value return output_dict def count_parameters(params: PyTree) -> int: """Return count of variables for the parameter dictionary. Args: params: Flax model parameters. Returns: The number of parameters in the model. 
""" flat_params = flatten_dict(params) return sum(np.prod(v.shape) for v in flat_params.values()) # type: ignore def get_parameter_rows( params: ParamsContainer, *, include_stats: bool = False, ) -> List[Union[ParamRow, ParamRowWithStats]]: """Return information about parameters as a list of dictionaries. Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. include_stats: If ``True`` add columns with mean and std for each variable. Note that this can be considerably more compute intensive and cause a lot of memory to be transferred to the host. Returns: A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value of `include_stats`. """ assert isinstance(params, (dict, flax.core.FrozenDict)) if params: params = flatten_dict(params) names, values = map(list, tuple(zip(*sorted(params.items())))) else: names, values = [], [] def make_row(name, value): if include_stats: return ParamRowWithStats( name=name, shape=value.shape, size=int(np.prod(value.shape)), mean=float(value.mean()), std=float(value.std()), ) else: return ParamRow(name=name, shape=value.shape, size=int(np.prod(value.shape))) return [make_row(name, value) for name, value in zip(names, values)] def _default_table_value_formatter(value): """Format ints with "," between thousands, and floats to 3 digits.""" if isinstance(value, bool): return str(value) elif isinstance(value, int): return "{:,}".format(value) elif isinstance(value, float): return "{:.3}".format(value) else: return str(value) def make_table( rows: List[Any], *, column_names: Optional[Sequence[str]] = None, value_formatter: Callable[[Any], str] = _default_table_value_formatter, max_lines: Optional[int] = None, ) -> str: """Render list of rows to a table. Args: rows: List of dataclass instances of a single type (e.g. `ParamRow`). column_names: List of columns that that should be included in the output. If not provided, then the columns are taken from keys of the first row. 
value_formatter: Callable used to format cell values. max_lines: Don't render a table longer than this. Returns: A string representation of a table as in the example below. :: +---------+---------+ | Col1 | Col2 | +---------+---------+ | value11 | value12 | | value21 | value22 | +---------+---------+ """ if any(not dataclasses.is_dataclass(row) for row in rows): raise ValueError("Expected argument 'rows' to be list of dataclasses") if len(set(map(type, rows))) > 1: raise ValueError("Expected elements of argument 'rows' be of same type.") class Column: """Definition of a column for printing parameters.""" def __init__(self, name, values): self.name = name.capitalize() self.values = values self.width = max(len(v) for v in values + [name]) if column_names is None: if not rows: return "(empty table)" column_names = [field.name for field in dataclasses.fields(rows[0])] columns = [ Column(name, [value_formatter(getattr(row, name)) for row in rows]) for name in column_names ] var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns) sep_line_format = var_line_format.replace(" ", "-").replace("|", "+") header = var_line_format.replace(">", "<").format(*[c.name for c in columns]) separator = sep_line_format.format(*["" for c in columns]) lines = [separator, header, separator] for i in range(len(rows)): if max_lines and len(lines) >= max_lines - 3: lines.append("[...]") break lines.append(var_line_format.format(*[c.values[i] for c in columns])) lines.append(separator) return "\n".join(lines) def get_parameter_overview( params: ParamsContainer, *, include_stats: bool = True, max_lines: Optional[int] = None ) -> str: """Return string with variables names, their shapes, count. Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. include_stats: If ``True``, add columns with mean and std for each variable. max_lines: If not ``None``, the maximum number of variables to include. 
Returns: A string with a table as in the example below. :: +----------------+---------------+------------+ | Name | Shape | Size | +----------------+---------------+------------+ | FC_1/weights:0 | (63612, 1024) | 65,138,688 | | FC_1/biases:0 | (1024,) | 1,024 | | FC_2/weights:0 | (1024, 32) | 32,768 | | FC_2/biases:0 | (32,) | 32 | +----------------+---------------+------------+ Total weights: 65,172,512 """ if isinstance(params, (dict, flax.core.FrozenDict)): params = jax.tree_util.tree_map(np.asarray, params) rows = get_parameter_rows(params, include_stats=include_stats) total_weights = count_parameters(params) RowType = ParamRowWithStats if include_stats else ParamRow # Pass in `column_names` to enable rendering empty tables. column_names = [field.name for field in dataclasses.fields(RowType)] table = make_table(rows, max_lines=max_lines, column_names=column_names) return table + f"\nTotal weights: {total_weights:,}" ================================================ FILE: scico/flax/train/diagnostics.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for computing and displaying performance metrics during training. Assumes sharded batched data. """ from typing import Callable, Dict, Tuple, Union from jax import lax from scico.diagnostics import IterationStats from scico.metric import snr from scico.numpy import Array from .losses import mse_loss from .typed_dict import MetricsDict def compute_metrics(output: Array, labels: Array, criterion: Callable = mse_loss) -> MetricsDict: """Compute diagnostic metrics. Assumes sharded batched data (i.e. it only works inside pmap because it needs an axis name). Args: output: Comparison signal. labels: Reference signal. criterion: Loss function. 
Default: :meth:`~scico.flax.train.losses.mse_loss`. Returns: Loss and SNR between `output` and `labels`. """ loss = criterion(output, labels) snr_ = snr(labels, output) metrics: MetricsDict = { "loss": loss, "snr": snr_, } metrics = lax.pmean(metrics, axis_name="batch") return metrics class ArgumentStruct: """Class that converts a dictionary into an object with named entries. Class that converts a python dictionary into an object with named entries given by the dictionary keys. After the object instantiation both modes of access (dictionary or object entries) can be used. """ def __init__(self, **entries): self.__dict__.update(entries) def stats_obj() -> Tuple[IterationStats, Callable]: """Functionality to log and store iteration statistics. This function initializes an object :class:`~.diagnostics.IterationStats` to log and store iteration statistics if logging is enabled during training. The statistics collected are: epoch, time, learning rate, loss and snr in training and loss and snr in evaluation. The :class:`~.diagnostics.IterationStats` object takes care of both printing stats to command line and storing them for further analysis. """ # epoch, time learning rate loss and snr (train and # eval) fields itstat_fields = { "Epoch": "%d", "Time": "%8.2e", "Train_LR": "%.6f", "Train_Loss": "%.6f", "Train_SNR": "%.2f", "Eval_Loss": "%.6f", "Eval_SNR": "%.2f", } itstat_attrib = [ "epoch", "time", "train_learning_rate", "train_loss", "train_snr", "loss", "snr", ] # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." 
+ attr for attr in itstat_attrib]) + ")" scope: Dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) default_itstat_options: Dict[str, Union[dict, Callable, bool]] = { "fields": itstat_fields, "itstat_func": scope["itstat_func"], "display": True, } itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore itstat_object = IterationStats(**default_itstat_options) # type: ignore return itstat_object, itstat_insert_func ================================================ FILE: scico/flax/train/input_pipeline.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Generalized data handling for training script. Includes construction of data iterator and instantiation for parallel processing. """ import warnings warnings.simplefilter(action="ignore", category=FutureWarning) from typing import Any, Union import jax import jax.numpy as jnp from flax import jax_utils from scico.numpy import Array from .typed_dict import DataSetDict DType = Any KeyArray = Union[Array, jax.Array] class IterateData: """Class to load data for training and testing. It uses the generator pattern to obtain an iterable object. """ def __init__(self, dt: DataSetDict, batch_size: int, train: bool = True, key: KeyArray = None): r"""Initialize a :class:`IterateData` object. Args: dt: Dictionary of data for supervised training including images and labels. batch_size: Size of batch for iterating through the data. train: Flag indicating use of iterator for training. Iterator for training is infinite, iterator for testing passes once through the data. Default: ``True``. key: A PRNGKey used as the random key. Default: ``None``. 
""" self.dt = dt self.batch_size = batch_size self.train = train self.n = dt["image"].shape[0] self.key = key if key is None: self.key = jax.random.key(0) self.steps_per_epoch = self.n // batch_size self.reset() def reset(self): """Re-shuffle data in training.""" if self.train: self.key, subkey = jax.random.split(self.key) self.perms = jax.random.permutation(subkey, self.n) else: self.perms = jnp.arange(self.n) self.perms = self.perms[: self.steps_per_epoch * self.batch_size] # skips incomplete batch self.perms = self.perms.reshape((self.steps_per_epoch, self.batch_size)) self.ns = 0 def __iter__(self): return self def __next__(self): """Get next batch. During training it reshuffles the batches when the data is exhausted.""" if self.ns >= self.steps_per_epoch: if self.train: self.reset() else: self.ns = 0 batch = {k: v[self.perms[self.ns], ...] for k, v in self.dt.items()} self.ns += 1 return batch def prepare_data(xs: Array) -> Any: """Reshape input batch for parallel training.""" local_device_count = jax.local_device_count() def _prepare(x: Array) -> Array: # reshape (host_batch_size, height, width, channels) to # (local_devices, device_batch_size, height, width, channels) return x.reshape((local_device_count, -1) + x.shape[1:]) return jax.tree_util.tree_map(_prepare, xs) def create_input_iter( key: KeyArray, dataset: DataSetDict, batch_size: int, size_device_prefetch: int = 2, dtype: DType = jnp.float32, train: bool = True, ) -> Any: """Create data iterator for training. Create data iterator for training by sharding and prefetching batches on device. Args: key: A PRNGKey used for random data permutations. dataset: Dictionary of data for supervised training including images and labels. batch_size: Size of batch for iterating through the data. size_device_prefetch: Size of prefetch buffer. Default: 2. dtype: Type of data to handle. Default: :attr:`~numpy.float32`. train: Flag indicating the type of iterator to construct and use. 
The iterator for training permutes data on each epoch while the iterator for testing passes through the data without permuting it. Default: ``True``. Returns: Array-like data sharded to specific devices coming from an iterator built from the provided dataset. """ ds = IterateData(dataset, batch_size, train, key) it = map(prepare_data, ds) it = jax_utils.prefetch_to_device(it, size_device_prefetch) return it ================================================ FILE: scico/flax/train/learning_rate.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Learning rate schedulers.""" import optax from .typed_dict import ConfigDict def create_cnst_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule: """Create learning rate to be a constant specified value. Args: config: Dictionary of configuration. The value to use corresponds to the `base_learning_rate` keyword. Returns: schedule: A function that maps step counts to values. """ schedule = optax.constant_schedule(config["base_learning_rate"]) return schedule def create_exp_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule: """Create learning rate schedule to have an exponential decay. Args: config: Dictionary of configuration. The values to use correspond to `base_learning_rate`, `num_epochs`, `steps_per_epochs` and `lr_decay_rate`. Returns: schedule: A function that maps step counts to values. """ decay_steps = config["num_epochs"] * config["steps_per_epoch"] schedule = optax.exponential_decay( config["base_learning_rate"], decay_steps, config["lr_decay_rate"] ) return schedule def create_cosine_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule: """Create learning rate to follow a pre-specified schedule. 
Create learning rate to follow a pre-specified schedule with warmup and cosine stages. Args: config: Dictionary of configuration. The parameters to use correspond to keywords: `base_learning_rate`, `num_epochs`, `warmup_epochs` and `steps_per_epoch`. Returns: schedule: A function that maps step counts to values. """ # Warmup stage warmup_fn = optax.linear_schedule( init_value=0.0, end_value=config["base_learning_rate"], transition_steps=config["warmup_epochs"] * config["steps_per_epoch"], ) # Cosine stage cosine_epochs = max(config["num_epochs"] - config["warmup_epochs"], 1) cosine_fn = optax.cosine_decay_schedule( init_value=config["base_learning_rate"], decay_steps=cosine_epochs * config["steps_per_epoch"], ) schedule = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[config["warmup_epochs"] * config["steps_per_epoch"]], ) return schedule ================================================ FILE: scico/flax/train/losses.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Definition of loss functions for model optimization.""" import jax.numpy as jnp import optax from scico.numpy import Array def mse_loss(output: Array, labels: Array) -> float: """Compute Mean Squared Error (MSE) loss for training via Optax. Args: output: Comparison signal. labels: Reference signal. Returns: MSE between `output` and `labels`. """ mse = optax.l2_loss(output, labels) return jnp.mean(mse) ================================================ FILE: scico/flax/train/spectral.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utils for spectral normalization of convolutional layers in Flax models.""" import warnings warnings.simplefilter(action="ignore", category=FutureWarning) from typing import Any, Callable, Sequence import numpy as np import jax import jax.numpy as jnp from jax import lax import scipy from flax.core import freeze, unfreeze from flax.linen import Conv from flax.linen.module import Module, compact from scico.numpy import Array from scico.typing import Shape from .traversals import ModelParamTraversal PyTree = Any # From https://github.com/deepmind/dm-haiku/issues/71 def _l2_normalize(x: Array, eps: float = 1e-12) -> Array: r"""Normalize array by its :math:`\el_2` norm. Args: x: Array to be normalized. eps: Small value to prevent divide by zero. Default: 1e-12. Returns: Normalized array. """ return x * lax.rsqrt((x**2).sum() + eps) # From https://nbviewer.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc def estimate_spectral_norm( f: Callable, input_shape: Shape, seed: float = 0, n_steps: int = 10, eps: float = 1e-12 ): """Estimate spectral norm of operator. This function estimates the spectral norm of an operator by estimating the singular vectors of the operator via the power iteration method and the transpose operator enabled by nested autodiff in JAX. Args: f: Operator to compute spectral norm. input_shape: Shape of input to operator. seed: Value to seed the random generation. Default: 0. n_steps: Number of power iterations to compute. Default: 10. eps: Small value to prevent divide by zero. Default: 1e-12. Returns: Spectral norm. 
""" rng = jax.random.key(seed) u0 = jax.random.normal(rng, input_shape) v0 = jnp.zeros_like(f(u0)) def fun(carry, _): u, v = carry v, f_vjp = jax.vjp(f, u) v = _l2_normalize(v, eps) (u,) = f_vjp(v) u = _l2_normalize(u, eps) return (u, v), None (u, v), _ = lax.scan(fun, (u0, v0), xs=None, length=n_steps) return jnp.vdot(v, f(u)) class CNN(Module): """Evaluation of convolution operator via Flax convolutional layer. Evaluation of convolution operator via Flax implementation of a convolutional layer. This is form of convolution is used only for the estimation of the spectral norm of the operator. Therefore, the value of the kernel is provided too. Attributes: kernel_size: Size of the convolution filter. kernel0: Convolution filter. dtype: Output type. """ kernel_size: Sequence[int] kernel0: Array dtype: Any @compact def __call__(self, x): """Apply CNN layer. Args: x: The array to be convolved. Returns: The result of the convolution with `kernel0`. """ def kinit_wrap(rng, shape, dtype=self.dtype): return jnp.array(self.kernel0, dtype) return Conv( features=self.kernel_size[3], kernel_size=self.kernel_size[:2], use_bias=False, padding="CIRCULAR", kernel_init=kinit_wrap, )(x) def conv(inputs: Array, kernel: Array) -> Array: """Compute convolution betwen input and kernel. The convolution is evaluated via a CNN Flax model. Args: inputs: Array to compute convolution. kernel: Filter of the convolutional operator. Returns: Result of convolution of input with kernel. """ dtype = kernel.dtype inputs = jnp.asarray(inputs, dtype) kernel = jnp.asarray(kernel, dtype) rng = jax.random.key(0) # not used model = CNN(kernel_size=kernel.shape, kernel0=kernel, dtype=dtype) variables = model.init(rng, np.zeros(inputs.shape)) y = model.apply(variables, inputs) return y def spectral_normalization_conv( params: PyTree, traversal: ModelParamTraversal, xshape: Shape, n_steps: int = 10 ) -> PyTree: """Normalize parameters of convolutional layer by its spectral norm. 
Args: params: Current model parameters. traversal: Utility to select model parameters. xshape: Shape of input. n_steps: Number of power iterations to compute. Default: 10. """ params_out = traversal.update( lambda kernel: kernel / ( estimate_spectral_norm( lambda x: conv(x, kernel), (1, xshape[1], xshape[2], kernel.shape[2]), n_steps ) * 1.02 ), unfreeze(params), ) return freeze(params_out) # From https://nbviewer.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc def exact_spectral_norm(f, input_shape): """Compute spectral norm of operator. This function computes the spectral norm of an operator via autodiff in JAX. Args: f: Operator to compute spectral norm. input_shape: Shape of input to operator. Returns: Spectral norm. """ dummy_input = jnp.zeros(input_shape) jacobian = jax.jacfwd(f)(dummy_input) shape = (np.prod(jacobian.shape[: -dummy_input.ndim]), np.prod(input_shape)) return scipy.linalg.svdvals(jacobian.reshape(shape)).max() ================================================ FILE: scico/flax/train/state.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Configuration of Flax Train State.""" from typing import Any, Optional, Tuple, Union import jax import jax.numpy as jnp import optax from flax.training import train_state from scico.numpy import Array from scico.typing import Shape from .typed_dict import ConfigDict, ModelVarDict ModuleDef = Any KeyArray = Union[Array, jax.Array] PyTree = Any ArrayTree = optax.Params class TrainState(train_state.TrainState): """Definition of Flax train state. Definition of Flax train state including `batch_stats` for batch normalization. 
""" batch_stats: Any def initialize(key: KeyArray, model: ModuleDef, ishape: Shape) -> Tuple[PyTree, ...]: """Initialize Flax model. Args: key: A PRNGKey used as the random key. model: Flax model to train. ishape: Shape of signal (image) to process by `model`. Make sure that no batch dimension is included. Returns: Initial model parameters (including `batch_stats`). """ input_shape = (1, ishape[0], ishape[1], model.channels) @jax.jit def init(*args): return model.init(*args) variables = init({"params": key}, jnp.ones(input_shape, model.dtype)) if "batch_stats" in variables: return variables["params"], variables["batch_stats"] return variables["params"] def create_basic_train_state( key: KeyArray, config: ConfigDict, model: ModuleDef, ishape: Shape, learning_rate_fn: optax._src.base.Schedule, variables0: Optional[ModelVarDict] = None, ) -> TrainState: """Create Flax basic train state and initialize. Args: key: A PRNGKey used as the random key. config: Dictionary of configuration. The values to use correspond to keywords: `opt_type` and `momentum`. model: Flax model to train. ishape: Shape of signal (image) to process by `model`. Ensure that no batch dimension is included. variables0: Optional initial state of model parameters. If not provided a random initialization is performed. Default: ``None``. learning_rate_fn: A function that maps step counts to values. Returns: state: Flax train state which includes the model apply function, the model parameters and an Optax optimizer. 
""" batch_stats = None if variables0 is None: aux = initialize(key, model, ishape) if len(aux) > 1: params, batch_stats = aux else: params = aux else: params = variables0["params"] if "batch_stats" in variables0: batch_stats = variables0["batch_stats"] if config["opt_type"] == "SGD": # Stochastic Gradient Descent optimiser if "momentum" in config: tx = optax.sgd( learning_rate=learning_rate_fn, momentum=config["momentum"], nesterov=True ) else: tx = optax.sgd(learning_rate=learning_rate_fn) elif config["opt_type"] == "ADAM": # Adam optimiser tx = optax.adam( learning_rate=learning_rate_fn, ) elif config["opt_type"] == "ADAMW": # Adam with weight decay regularization tx = optax.adamw( learning_rate=learning_rate_fn, ) else: raise NotImplementedError( f"Optimizer specified {config['opt_type']} has not been included in SCICO." ) if batch_stats is None: state = TrainState.create( apply_fn=model.apply, params=params, tx=tx, ) else: state = TrainState.create( apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, ) return state ================================================ FILE: scico/flax/train/steps.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Definition of steps to iterate during training or evaluation.""" from typing import Any, Callable, List, Tuple, Union import jax from jax import lax import optax from scico.numpy import Array from .state import TrainState from .typed_dict import DataSetDict, MetricsDict KeyArray = Union[Array, jax.Array] PyTree = Any def train_step( state: TrainState, batch: DataSetDict, learning_rate_fn: optax._src.base.Schedule, criterion: Callable, metrics_fn: Callable, ) -> Tuple[TrainState, MetricsDict]: """Perform a single data parallel training step. 
Assumes sharded batched data. This function is intended to be used via :class:`~scico.flax.BasicFlaxTrainer`, not directly. Args: state: Flax train state which includes the model apply function, the model parameters and an Optax optimizer. batch: Sharded and batched training data. learning_rate_fn: A function to map step counts to values. This is only used for display purposes (optax optimizers are stateless, so the current learning rate is not stored). The real learning rate schedule applied is the one defined when creating the Flax state. If a different object is passed here, then the displayed value will be inaccurate. criterion: A function that specifies the loss being minimized in training. metrics_fn: A function to evaluate quality of current model. Returns: Updated parameters and diagnostic statistics. """ def loss_fn(params: PyTree): """Loss function used for training.""" output, new_model_state = state.apply_fn( { "params": params, "batch_stats": state.batch_stats, }, batch["image"], mutable=["batch_stats"], ) loss = criterion(output, batch["label"]) return loss, (new_model_state, output) step = state.step # Only to figure out current learning rate, which cannot be stored in stateless optax. # Requires agreement between the function passed here and the one used to create the # train state. 
lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) aux, grads = grad_fn(state.params) # Re-use same axis_name as in call to pmap grads = lax.pmean(grads, axis_name="batch") new_model_state, output = aux[1] metrics = metrics_fn(output, batch["label"], criterion) metrics["learning_rate"] = lr # Update params and stats new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state["batch_stats"], ) return new_state, metrics def train_step_post( state: TrainState, batch: DataSetDict, learning_rate_fn: optax._src.base.Schedule, criterion: Callable, train_step_fn: Callable, metrics_fn: Callable, post_lst: List[Callable], ) -> Tuple[TrainState, MetricsDict]: """Perform a single data parallel training step with postprocessing. A list of postprocessing functions (i.e. for spectral normalization or positivity condition, etc.) is applied after the gradient update. Assumes sharded batched data. This function is intended to be used via :class:`~scico.flax.BasicFlaxTrainer`, not directly. Args: state: Flax train state which includes the model apply function, the model parameters and an Optax optimizer. batch: Sharded and batched training data. learning_rate_fn: A function to map step counts to values. criterion: A function that specifies the loss being minimized in training. train_step_fn: A function that executes a training step. metrics_fn: A function to evaluate quality of current model. post_lst: List of postprocessing functions to apply to parameter set after optimizer step (e.g. clip to a specified range, normalize, etc.). Returns: Updated parameters, fulfilling additional constraints, and diagnostic statistics. 
""" new_state, metrics = train_step_fn(state, batch, learning_rate_fn, criterion, metrics_fn) # Post-process parameters for post_fn in post_lst: new_params = post_fn(new_state.params) new_state = new_state.replace(params=new_params) return new_state, metrics def eval_step( state: TrainState, batch: DataSetDict, criterion: Callable, metrics_fn: Callable ) -> MetricsDict: """Evaluate current model state. Assumes sharded batched data. This function is intended to be used via :class:`~scico.flax.BasicFlaxTrainer` or :meth:`~scico.flax.only_evaluate`, not directly. Args: state: Flax train state which includes the model apply function and the model parameters. batch: Sharded and batched training data. criterion: Loss function. metrics_fn: A function to evaluate quality of current model. Returns: Current diagnostic statistics. """ variables = { "params": state.params, "batch_stats": state.batch_stats, } output = state.apply_fn(variables, batch["image"], train=False, mutable=False) return metrics_fn(output, batch["label"], criterion) ================================================ FILE: scico/flax/train/trainer.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Class providing integrated access to functionality for training Flax models. Assumes sharded batched data and uses data parallel training. 
""" import warnings warnings.simplefilter(action="ignore", category=FutureWarning) import functools import time from typing import Any, Callable, Dict, List, Optional, Tuple, Union import jax from jax import lax from flax import jax_utils from flax.training import common_utils from scico.diagnostics import IterationStats from scico.numpy import Array from .checkpoints import checkpoint_restore, checkpoint_save from .clu_utils import get_parameter_overview from .diagnostics import ArgumentStruct, compute_metrics, stats_obj from .input_pipeline import create_input_iter from .learning_rate import create_cnst_lr_schedule from .losses import mse_loss from .state import TrainState, create_basic_train_state from .steps import eval_step, train_step, train_step_post from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict ModuleDef = Any KeyArray = Union[Array, jax.Array] PyTree = Any DType = Any # sync across replicas def sync_batch_stats(state: TrainState) -> TrainState: """Sync the batch statistics across replicas.""" # Each device has its own version of the running average batch # statistics and those are synced before evaluation return state.replace(batch_stats=cross_replica_mean(state.batch_stats)) # pmean only works inside pmap because it needs an axis name. #: This function will average the inputs across all devices. cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") class BasicFlaxTrainer: """Class encapsulating Flax training configuration and execution.""" def __init__( self, config: ConfigDict, model: ModuleDef, train_ds: DataSetDict, test_ds: DataSetDict, variables0: Optional[ModelVarDict] = None, ): """Initializer for :class:`BasicFlaxTrainer`. Initializer for :class:`BasicFlaxTrainer` to configure model training and evaluation loop. Construct a Flax train state (which includes the model apply function, the model parameters and an Optax optimizer). This uses data parallel training assuming sharded batched data. 
Args: config: Hyperparameter configuration. model: Flax model to train. train_ds: Dictionary of training data (includes images and labels). test_ds: Dictionary of testing data (includes images and labels). variables0: Optional initial state of model parameters. """ # Configure seed if "seed" not in config: key = jax.random.key(0) else: key = jax.random.key(config["seed"]) # Split seed for data iterators and model initialization key1, key2 = jax.random.split(key) # Object for storing iteration stats self.itstat_object: Optional[IterationStats] = None # Configure training loop len_train = train_ds["image"].shape[0] len_test = test_ds["image"].shape[0] self.set_training_parameters(config, len_train, len_test) self.construct_data_iterators(train_ds, test_ds, key1, model.dtype) self.define_parallel_training_functions() self.initialize_training_state(config, key2, model, variables0) # Store configuration self.config = config def set_training_parameters( self, config: ConfigDict, len_train: int, len_test: int, ): """Extract configuration parameters and construct training functions. Parameters and functions are passed in the configuration dictionary. Default values are used when parameters are not included in configuration. Args: config: Hyperparameter configuration. len_train: Number of samples in training set. len_test: Number of samples in testing set. """ self.configure_steps(config, len_train, len_test) self.configure_reporting(config) self.configure_training_functions(config) def configure_steps( self, config: ConfigDict, len_train: int, len_test: int, ): """Configure training, evaluation and monitoring steps. Args: config: Hyperparameter configuration. len_train: Number of samples in training set. len_test: Number of samples in testing set. 
""" # Set required defaults if not present if "batch_size" not in config: batch_size = 2 * jax.device_count() else: batch_size = config["batch_size"] if "num_epochs" not in config: num_epochs = 10 else: num_epochs = config["num_epochs"] # Determine sharded vs. batch partition if batch_size % jax.device_count() > 0: raise ValueError("Batch size must be divisible by the number of devices.") self.local_batch_size: int = batch_size // jax.process_count() # Training steps self.steps_per_epoch: int = len_train // batch_size config["steps_per_epoch"] = self.steps_per_epoch # needed for creating lr schedule self.num_steps: int = int(self.steps_per_epoch * num_epochs) # Evaluation (over testing set) steps num_validation_examples: int = len_test if "steps_per_eval" not in config: self.steps_per_eval: int = num_validation_examples // batch_size else: self.steps_per_eval = config["steps_per_eval"] # Determine monitoring steps if "steps_per_checkpoint" not in config: self.steps_per_checkpoint: int = self.steps_per_epoch * 10 else: self.steps_per_checkpoint = config["steps_per_checkpoint"] if "log_every_steps" not in config: self.log_every_steps: int = self.steps_per_epoch * 20 else: self.log_every_steps = config["log_every_steps"] def configure_reporting(self, config: ConfigDict): """Configure logging and checkpointing. The parameters configured correspond to - **logflag**: A flag for logging to the output terminal the evolution of results. Default: ``False``. - **workdir**: Directory to write checkpoints. Default: execution directory. - **checkpointing**: A flag for checkpointing model state. Default: ``False``. - **return_state**: A flag for returning the train state instead of the model variables. Default: ``False``, i.e. return model variables. Args: config: Hyperparameter configuration. 
""" # Determine logging configuration if "log" in config: self.logflag: bool = config["log"] if self.logflag: self.itstat_object, self.itstat_insert_func = stats_obj() else: self.logflag = False # Determine checkpointing configuration if "workdir" in config: self.workdir: str = config["workdir"] else: self.workdir = "./" if "checkpointing" in config: self.checkpointing: bool = config["checkpointing"] else: self.checkpointing = False # Determine variable to return at end of training if "return_state" in config: # Returning Flax train state self.return_state = config["return_state"] else: # Return model variables self.return_state = False def configure_training_functions(self, config: ConfigDict): """Construct training functions. Default functions are used if not specified in configuration. The parameters configured correspond to - **lr_schedule**: A function that creates an Optax learning rate schedule. Default: :meth:`~scico.flax.train.learning_rate.create_cnst_lr_schedule`. - **criterion**: A function that specifies the loss being minimized in training. Default: :meth:`~scico.flax.train.losses.mse_loss`. - **create_train_state**: A function that creates a Flax train state and initializes it. A train state object helps to keep optimizer and module functionality grouped for training. Default: :meth:`~scico.flax.train.state.create_basic_train_state`. - **train_step_fn**: A function that executes a training step. Default: :meth:`~scico.flax.train.steps.train_step`, i.e. use the standard train step. - **eval_step_fn**: A function that executes an eval step. Default: :meth:`~scico.flax.train.steps.eval_step`, i.e. use the standard eval step. - **metrics_fn**: A function that computes metrics. Default: :meth:`~scico.flax.train.diagnostics.compute_metrics`, i.e. use the standard compute metrics function. - **post_lst**: List of postprocessing functions to apply to parameter set after optimizer step (e.g. clip to a specified range, normalize, etc.). 
Args: config: Hyperparameter configuration. """ if "lr_schedule" in config: create_lr_schedule: Callable = config["lr_schedule"] self.lr_schedule = create_lr_schedule(config) else: self.lr_schedule = create_cnst_lr_schedule(config) if "criterion" in config: self.criterion: Callable = config["criterion"] else: self.criterion = mse_loss if "create_train_state" in config: self.create_train_state: Callable = config["create_train_state"] else: self.create_train_state = create_basic_train_state if "train_step_fn" in config: self.train_step_fn: Callable = config["train_step_fn"] else: self.train_step_fn = train_step if "eval_step_fn" in config: self.eval_step_fn: Callable = config["eval_step_fn"] else: self.eval_step_fn = eval_step if "metrics_fn" in config: self.metrics_fn: Callable = config["metrics_fn"] else: self.metrics_fn = compute_metrics self.post_lst: Optional[List[Callable]] = None if "post_lst" in config: self.post_lst = config["post_lst"] def construct_data_iterators( self, train_ds: DataSetDict, test_ds: DataSetDict, key: KeyArray, mdtype: DType, ): """Construct iterators for training and testing (evaluation) sets. Args: train_ds: Dictionary of training data (includes images and labels). test_ds: Dictionary of testing data (includes images and labels). key: A PRNGKey used as the random key. mdtype: Output type of Flax model to be trained. 
""" size_device_prefetch = 2 # Set for GPU self.train_dt_iter = create_input_iter( key, train_ds, self.local_batch_size, size_device_prefetch, mdtype, train=True, ) self.eval_dt_iter = create_input_iter( key, # eval: no permutation test_ds, self.local_batch_size, size_device_prefetch, mdtype, train=False, ) self.ishape = train_ds["image"].shape[1:3] self.log( "channels: %d training signals: %d testing" " signals: %d signal size: %d\n" % ( train_ds["label"].shape[-1], train_ds["label"].shape[0], test_ds["label"].shape[0], train_ds["label"].shape[1], ) ) def define_parallel_training_functions(self): """Construct parallel versions of training functions. Construct parallel versions of training functions via :func:`jax.pmap`. """ if self.post_lst is not None: self.p_train_step = jax.pmap( functools.partial( train_step_post, train_step_fn=self.train_step_fn, learning_rate_fn=self.lr_schedule, criterion=self.criterion, metrics_fn=self.metrics_fn, post_lst=self.post_lst, ), axis_name="batch", ) else: self.p_train_step = jax.pmap( functools.partial( self.train_step_fn, learning_rate_fn=self.lr_schedule, criterion=self.criterion, metrics_fn=self.metrics_fn, ), axis_name="batch", ) self.p_eval_step = jax.pmap( functools.partial( self.eval_step_fn, criterion=self.criterion, metrics_fn=self.metrics_fn ), axis_name="batch", ) def initialize_training_state( self, config: ConfigDict, key: KeyArray, model: ModuleDef, variables0: Optional[ModelVarDict] = None, ): """Construct and initialize Flax train state. A train state object helps to keep optimizer and module functionality grouped for training. Args: config: Hyperparameter configuration. key: A PRNGKey used as the random key. model: Flax model to train. variables0: Optional initial state of model parameters. 
""" # Create Flax training state state = self.create_train_state( key, config, model, self.ishape, self.lr_schedule, variables0 ) # Only restore if no initialization is provided if self.checkpointing and variables0 is None: ok_no_ckpt = True # It is ok if no checkpoint is found state = checkpoint_restore(state, self.workdir, ok_no_ckpt) self.log("Network Structure:") self.log(get_parameter_overview(state.params) + "\n") if hasattr(state, "batch_stats"): self.log("Batch Normalization:") self.log(get_parameter_overview(state.batch_stats) + "\n") self.state = state def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]: """Execute training loop. Returns: Model variables extracted from :class:`.TrainState` and iteration stats object obtained after executing the training loop. Alternatively the :class:`.TrainState` can be returned directly instead of the model variables. Note that the iteration stats object is not ``None`` only if log is enabled when configuring the training loop. 
""" state = self.state step_offset = int(state.step) # > 0 if restarting from checkpoint # For parallel training state = jax_utils.replicate(state) # Execute training loop and register stats t0 = time.time() self.log("Initial compilation, which might take some time ...") train_metrics: List[Any] = [] for step, batch in zip(range(step_offset, self.num_steps), self.train_dt_iter): state, metrics = self.p_train_step(state, batch) # Training metrics computed in step train_metrics.append(metrics) if step == step_offset: self.log("Initial compilation completed.\n") if (step + 1) % self.log_every_steps == 0: # sync batch statistics across replicas state = sync_batch_stats(state) self.update_metrics(state, step, train_metrics, t0) train_metrics = [] if (step + 1) % self.steps_per_checkpoint == 0 or step + 1 == self.num_steps: # sync batch statistics across replicas state = sync_batch_stats(state) self.checkpoint(state) # Wait for finishing asynchronous execution jax.random.normal(jax.random.key(0), ()).block_until_ready() # Close object for iteration stats if logging if self.logflag: assert self.itstat_object is not None self.itstat_object.end() state = sync_batch_stats(state) # Final checkpointing self.checkpoint(state) # Extract one copy of state state = jax_utils.unreplicate(state) if self.return_state: return state, self.itstat_object # type: ignore dvar: ModelVarDict = { "params": state.params, "batch_stats": state.batch_stats, } self.train_time = time.time() - t0 return dvar, self.itstat_object # type: ignore def update_metrics(self, state: TrainState, step: int, train_metrics: List[MetricsDict], t0): """Compute metrics for current model state. Metrics for training and testing (eval) sets are computed and stored in an iteration stats object. This is executed only if logging is enabled. Args: state: Flax train state which includes the model apply function and the model parameters. step: Current step in training. 
train_metrics: List of diagnostic statistics computed from training set. t0: Time when training loop started. """ if not self.logflag: return eval_metrics: List[Any] = [] # Build summary dictionary for logging # Include training stats train_metrics = common_utils.get_metrics(train_metrics) summary = { f"train_{k}": v for k, v in jax.tree_util.tree_map(lambda x: x.mean(), train_metrics).items() } epoch = step // self.steps_per_epoch summary["epoch"] = epoch summary["time"] = time.time() - t0 # Eval over testing set for _ in range(self.steps_per_eval): eval_batch = next(self.eval_dt_iter) metrics = self.p_eval_step(state, eval_batch) eval_metrics.append(metrics) # Compute testing metrics eval_metrics = common_utils.get_metrics(eval_metrics) # Add testing stats to summary summary_eval = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics) summary.update(summary_eval) # Update iteration stats object assert isinstance(self.itstat_object, IterationStats) # for mypy self.itstat_object.insert(self.itstat_insert_func(ArgumentStruct(**summary))) def checkpoint(self, state: TrainState): # pragma: no cover """Checkpoint training state if enabled. Args: state: Flax train state. """ if self.checkpointing: checkpoint_save(jax_utils.unreplicate(state), self.config, self.workdir) def log(self, logstr: str): """Print stats to output terminal if logging is enabled. Args: logstr: String to be logged. """ if self.logflag: print(logstr) ================================================ FILE: scico/flax/train/traversals.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Functionality to traverse, select, and update model parameters.""" from typing import Any import jax.numpy as jnp from flax.traverse_util import ModelParamTraversal PyTree = Any def construct_traversal(prmname: str) -> ModelParamTraversal: """Construct utility to select model parameters using a name filter. Args: prmname: Name of parameter to select. Returns: Flax utility to traverse and select model parameters. """ return ModelParamTraversal(lambda path, _: prmname in path) def clip_positive(params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4) -> PyTree: """Clip parameters to positive range. Args: params: Current model parameters. traversal: Utility to select model parameters. minval: Minimum value to clip selected model parameters and keep them in a positive range. Default: 1e-4. """ params_out = traversal.update(lambda x: jnp.clip(x, minval), params) return params_out def clip_range( params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4, maxval: float = 1 ) -> PyTree: """Clip parameters to specified range. Args: params: Current model parameters. traversal: Utility to select model parameters. minval: Minimum value to clip selected model parameters. Default: 1e-4. maxval: Maximum value to clip selected model parameters. Default: 1. """ params_out = traversal.update(lambda x: jnp.clip(x, minval, maxval), params) return params_out ================================================ FILE: scico/flax/train/typed_dict.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Definition of typed dictionaries for objects in training functionality.""" import sys from typing import Any, Callable, List if sys.version_info >= (3, 8): from typing import TypedDict # pylint: disable=no-name-in-module else: from typing_extensions import TypedDict from scico.numpy import Array PyTree = Any class DataSetDict(TypedDict): """Dictionary structure for training data sets. Definition of the dictionary structure expected for the training data sets. """ #: Input (Num. samples x Height x Width x Channels). image: Array #: Output (Num. samples x Height x Width x Channels) or (Num. samples x Classes). label: Array class ConfigDict(TypedDict): """Dictionary structure for training parameters. Definition of the dictionary structure expected for specifying training parameters. """ #: Value to initialize seed for random generation. seed: float #: Type of optimizer. Options: SGD, ADAM, ADAMW. opt_type: str #: Momentum for SGD optimizer in case Nesterov is ``True``. momentum: float #: Size of batch for training. batch_size: int #: Number of epochs for training (an epoch is one whole pass through the training dataset). num_epochs: int #: Starting learning rate for scheduling. base_learning_rate: float #: Rate for decaying learning rate when scheduling is used. lr_decay_rate: float #: Number of epochs if warmup scheduling is used. warmup_epochs: int #: Period of training steps to evaluate over test set. steps_per_eval: int #: Period of training steps to print current train and test metrics. log_every_steps: int #: Training steps to be executed per epoch (depends on batch size). steps_per_epoch: int #: Period of training steps to save model (if checkpointing is ``True``). steps_per_checkpoint: int #: Flag to indicate if evolution metrics are to be printed. log: bool #: Path to directory for checkpointing model parameters. workdir: str #: Flag to indicate if model parameters and optimizer state are to #: be stored while training. 
checkpointing: bool #: Flag to indicate if state (params and batch_stats) are to #: be returned at the end of training. return_state: bool #: Function to modify the learning rate while training (type optax schedule). lr_schedule: Callable #: Criterion to optimize during training. criterion: Callable #: Function to create and initialize trainig state. Should include initialization #: of optimizer and of batch_stats (if applicable). create_train_state: Callable #: Function to execute each training step. train_step_fn: Callable #: Function to execute each evaluation step. eval_step_fn: Callable #: Function to track metrics during training. metrics_fn: Callable #: List of post-processing functions to apply after a train step (if any). post_lst: List[Callable] class ModelVarDict(TypedDict): """Dictionary structure for Flax variables. Definition of the dictionary structure grouping all Flax model variables. """ #: Model weights and biases. params: PyTree #: Batch statistics (e.g. normalization parameters that depend on training data). batch_stats: PyTree class MetricsDict(TypedDict, total=False): """Dictionary structure for training metrics. Definition of the dictionary structure for metrics computed or updates made during training. """ loss: float #: Evaluation of criterion being optimized. snr: float #: Evaluation of signal to noise ratio. learning_rate: float #: Current learning rate. ================================================ FILE: scico/function.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Function class.""" from typing import Any, Callable, Optional, Sequence, Tuple, Union import jax import scico import scico.numpy as snp from scico.linop import LinearOperator, jacobian from scico.numpy import Array, BlockArray from scico.numpy.util import dtype_name from scico.operator import Operator from scico.typing import BlockShape, DType, Shape class Function: r"""Function class. A :class:`Function` maps multiple :code:`array-like` arguments to another :code:`array-like`. It is more general than both :class:`.Functional`, which is a mapping to a scalar, and :class:`.Operator`, which takes a single argument. """ def __init__( self, input_shapes: Sequence[Union[Shape, BlockShape]], output_shape: Optional[Union[Shape, BlockShape]] = None, eval_fn: Optional[Callable] = None, input_dtypes: Union[DType, Sequence[DType]] = snp.float32, output_dtype: Optional[DType] = None, jit: bool = False, ): """ Args: input_shapes: Shapes of input arrays. output_shape: Shape of output array. Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating `self.__call__` on input arrays of zeros. eval_fn: Function used in evaluating this :class:`Function`. Defaults to ``None``. Required unless `__init__` is being called from a derived class with an `_eval` method. input_dtypes: `dtype` for input argument. If a single `dtype` is specified, it implies a common `dtype` for all inputs, otherwise a list or tuple of values should be provided, one per input. Defaults to :attr:`~numpy.float32`. output_dtype: `dtype` for output argument. Defaults to ``None``. If ``None``, `output_dtype` is determined by evaluating `self.__call__` on an input arrays of zeros. jit: If ``True``, jit the evaluation function. 
""" self.jit = jit self.input_shapes = input_shapes if isinstance(input_dtypes, (list, tuple)): self.input_dtypes = input_dtypes else: self.input_dtypes = (input_dtypes,) * len(input_shapes) if eval_fn is not None: self._eval = jax.jit(eval_fn) if jit else eval_fn elif not hasattr(self, "_eval"): raise NotImplementedError( "Function is an abstract base class when argument 'eval_fn' is not specified." ) # If the output shape/dtype aren't specified, they can be inferred # using scico.eval_shape if output_shape is None or output_dtype is None: dts_in = [ jax.ShapeDtypeStruct(shape, dtype=dtype) for (shape, dtype) in zip(self.input_shapes, self.input_dtypes) ] dts_out = scico.eval_shape(self._eval, *dts_in) if output_shape is None: self.output_shape = dts_out.shape # type: ignore else: self.output_shape = output_shape if output_dtype is None: self.output_dtype = dts_out.dtype else: self.output_dtype = output_dtype def __repr__(self): return f"""{self.__module__}.{self.__class__.__qualname__} input_shapes: {self.input_shapes} output_shape: {self.output_shape} input_dtypes: {", ".join([dtype_name(dt) for dt in self.input_dtypes])} output_dtype: {dtype_name(self.output_dtype)} """ def __call__(self, *args: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Evaluate this function with the specified parameters. Args: *args: Parameters at which to evaluate the function. Returns: Value of function with specified parameters. """ return self._eval(*args) def slice(self, index: int, *fix_args: Union[Array, BlockArray]) -> Operator: """Fix all but one parameter, returning a :class:`.Operator`. Args: index: Index of parameter that remains free. *fix_args: Fixed values for remaining parameters. Returns: An :class:`.Operator` taking the free parameter of the :class:`Function` as its input. 
""" def pfunc(var_arg): args = fix_args[0:index] + (var_arg,) + fix_args[index:] return self._eval(*args) return Operator( self.input_shapes[index], output_shape=self.output_shape, eval_fn=pfunc, input_dtype=self.input_dtypes[index], output_dtype=self.output_dtype, jit=self.jit, ) def join(self) -> Operator: """Combine inputs into a :class:`.BlockArray`. Construct an equivalent :class:`.Operator` taking a single :class:`.BlockArray` input combining all inputs of this :class:`Function`. Returns: An :class:`.Operator` taking a :class:`.BlockArray` as its input. """ for dtype in self.input_dtypes[1:]: if dtype != self.input_dtypes[0]: raise ValueError( "The join method may only be applied to Functions that have " "homogeneous input dtypes." ) def jfunc(blkarr): return self._eval(*blkarr.arrays) return Operator( self.input_shapes, # type: ignore output_shape=self.output_shape, eval_fn=jfunc, input_dtype=self.input_dtypes[0], output_dtype=self.output_dtype, jit=self.jit, ) def jvp( self, index: int, v: Union[Array, BlockArray], *args: Union[Array, BlockArray] ) -> Tuple[Union[Array, BlockArray], Union[Array, BlockArray]]: """Jacobian-vector product with respect to a single parameter. Compute a Jacobian-vector product with respect to a single parameter of a :class:`Function`. Note that the order of the parameters specifying where to evaluate the Jacobian and the vector in the product is reverse with respect to :func:`jax.jvp`. Args: index: Index of parameter with respect to which the Jacobian is to be computed. v: Vector against which the Jacobian-vector product is to be computed. *args: Values of function parameters at which Jacobian is to be computed. Returns: A pair consisting of the operator evaluated at the parameters specified by `*args` and the Jacobian-vector product. 
""" var_arg = args[index] fix_args = args[0:index] + args[(index + 1) :] F = self.slice(index, *fix_args) return F.jvp(var_arg, v) def vjp( self, index: int, *args: Union[Array, BlockArray], conjugate: Optional[bool] = True ) -> Tuple[Tuple[Any, ...], Callable]: """Vector-Jacobian product with respect to a single parameter. Compute a vector-Jacobian product with respect to a single parameter of a :class:`Function`. Args: index: Index of parameter with respect to which the Jacobian is to be computed. *args: Values of function parameters at which Jacobian is to be computed. conjugate: If ``True``, compute the product using the conjugate (Hermitian) transpose. Returns: A pair consisting of the operator evaluated at the parameters specified by `*args` and a function that computes the vector-Jacobian product. """ var_arg = args[index] fix_args = args[0:index] + args[(index + 1) :] F = self.slice(index, *fix_args) return F.vjp(var_arg, conjugate=conjugate) def jacobian( self, index: int, *args: Union[Array, BlockArray], include_eval: Optional[bool] = False ) -> LinearOperator: """Construct Jacobian linear operator for the function. Construct a Jacobian :class:`.LinearOperator` that computes vector products with the Jacobian with respect to a specified variable of the function. Args: index: Index of parameter with respect to which the Jacobian is to be computed. *args: Values of function parameters at which Jacobian is to be computed. include_eval: Flag indicating whether the result of evaluating the :class:`.Operator` should be included (as the first component of a :class:`.BlockArray`) in the output of the Jacobian :class:`.LinearOperator` constructed by this function. Returns: A :class:`.LinearOperator` capable of computing Jacobian-vector products. 
""" var_arg = args[index] fix_args = args[0:index] + args[(index + 1) :] F = self.slice(index, *fix_args) return jacobian(F, var_arg, include_eval=include_eval) ================================================ FILE: scico/functional/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functionals and functionals classes.""" import sys # isort: off from ._functional import ( Functional, FunctionalSum, ComposedFunctional, ScaledFunctional, SeparableFunctional, ZeroFunctional, ) from ._norm import ( HuberNorm, L0Norm, L1Norm, SquaredL2Norm, L2Norm, L21Norm, NuclearNorm, L1MinusL2Norm, ) from ._tvnorm import AnisotropicTVNorm, IsotropicTVNorm, TVNorm from ._proxavg import ProximalAverage from ._indicator import NonNegativeIndicator, L2BallIndicator, BoxIndicator from ._denoiser import BM3D, BM4D, DnCNN from ._dist import SetDistance, SquaredSetDistance __all__ = [ "AnisotropicTVNorm", "IsotropicTVNorm", "TVNorm", "Functional", "FunctionalSum", "ComposedFunctional", "ScaledFunctional", "SeparableFunctional", "ZeroFunctional", "HuberNorm", "L0Norm", "L1Norm", "SquaredL2Norm", "L2Norm", "L21Norm", "L1MinusL2Norm", "NonNegativeIndicator", "BoxIndicator", "NuclearNorm", "L2BallIndicator", "ProximalAverage", "SetDistance", "SquaredSetDistance", "BM3D", "BM4D", "DnCNN", ] # Imported items in __all__ appear to originate in top-level functional module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/functional/_denoiser.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. 
# This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Pseudo-functionals that have denoisers as their proximal operators.""" from typing import Union from scico import denoiser from scico.numpy import Array from ._functional import Functional class BM3D(Functional): r"""Pseudo-functional whose prox applies the BM3D denoising algorithm. A pseudo-functional that has the BM3D algorithm :cite:`dabov-2008-image` as its proximal operator, which calls :func:`.denoiser.bm3d`. Since this function provides an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. """ has_eval = False has_prox = True def __init__(self, is_rgb: bool = False, profile: Union[denoiser.BM3DProfile, str] = "np"): r"""Initialize a :class:`BM3D` object. Args: is_rgb: Flag indicating use of BM3D with a color transform. Default: ``False``. profile: Parameter configuration for BM3D. """ self.is_rgb = is_rgb self.profile = profile super().__init__() def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore r"""Apply BM3D denoiser. Args: x: Input image. lam: Noise parameter. **kwargs: Additional arguments that may be used by derived classes. Returns: Denoised output. """ return denoiser.bm3d(x, lam, self.is_rgb, profile=self.profile) class BM4D(Functional): r"""Pseudo-functional whose prox applies the BM4D denoising algorithm. A pseudo-functional that has the BM4D algorithm :cite:`maggioni-2012-nonlocal` as its proximal operator, which calls :func:`.denoiser.bm4d`. Since this function provides an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. """ has_eval = False has_prox = True def __init__(self, profile: Union[denoiser.BM4DProfile, str] = "np"): r"""Initialize a :class:`BM4D` object. Args: profile: Parameter configuration for BM4D. 
""" self.profile = profile super().__init__() def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore r"""Apply BM4D denoiser. Args: x: Input image. lam: Noise parameter. **kwargs: Additional arguments that may be used by derived classes. Returns: Denoised output. """ return denoiser.bm4d(x, lam, profile=self.profile) class DnCNN(Functional): """Pseudo-functional whose prox applies the DnCNN denoising algorithm. A pseudo-functional that has the DnCNN algorithm :cite:`zhang-2017-dncnn` as its proximal operator, implemented via :class:`.denoiser.DnCNN`. """ has_eval = False has_prox = True def __init__(self, variant: str = "6M"): """ Args: variant: Identify the DnCNN model to be used. See :class:`.denoiser.DnCNN` for valid values. """ self.dncnn = denoiser.DnCNN(variant) if self.dncnn.is_blind: def denoise(x, sigma): return self.dncnn(x) else: def denoise(x, sigma): return self.dncnn(x, sigma) self._denoise = denoise def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore r"""Apply DnCNN denoiser. *Warning*: The `lam` parameter is ignored, and has no effect on the output for :class:`.DnCNN` objects initialized with :code:`variant` parameter values other than `6N` and `17N`. Args: x: Input array. lam: Noise parameter (ignored). **kwargs: Additional arguments that may be used by derived classes. Returns: Denoised output. """ return self._denoise(x, lam) ================================================ FILE: scico/functional/_dist.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Distance functions.""" from typing import Callable, Union from scico import numpy as snp from scico.numpy import Array, BlockArray from ._functional import Functional class SetDistance(Functional): r"""Distance to a closed convex set. This functional computes the :math:`\ell_2` distance from a vector to a closed convex set :math:`C` .. math:: d(\mb{x}) = \min_{\mb{y} \in C} \, \| \mb{x} - \mb{y} \|_2 \;. The set is not specified directly, but in terms of a function computing the projection into that set, i.e. .. math:: d(\mb{x}) = \| \mb{x} - P_C(\mb{x}) \|_2 \;, where :math:`P_C(\mb{x})` is the projection of :math:`\mb{x}` into set :math:`C`. """ has_eval = True has_prox = True def __init__(self, proj: Callable, args=()): r""" Args: proj: Function computing the projection into the convex set. args: Additional arguments for function `proj`. """ self.proj = proj self.args = args def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Compute the :math:`\ell_2` distance to the set. Compute the distance :math:`d(\mb{x})` between :math:`\mb{x}` and the set :math:`C`. Args: x: Input array :math:`\mb{x}`. Returns: Euclidean distance from `x` to the projection of `x`. """ y = self.proj(*((x,) + self.args)) return snp.linalg.norm(x - y) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Proximal operator of the :math:`\ell_2` distance function. Compute the proximal operator of the :math:`\ell_2` distance function :math:`d(\mb{x})` :cite:`beck-2017-first` (Lemma 6.43). Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Scaled proximal operator evaluated at `v`. """ y = self.proj(*((v,) + self.args)) d = snp.linalg.norm(v - y) 𝜃 = lam / d if d >= lam else 1.0 return 𝜃 * y + (1.0 - 𝜃) * v class SquaredSetDistance(Functional): r"""Squared :math:`\ell_2` distance to a closed convex set. 
This functional computes the :math:`\ell_2` distance from a vector to a closed convex set :math:`C` .. math:: d(\mb{x}) = \min_{\mb{y} \in C} \, (1/2) \| \mb{x} - \mb{y} \|_2^2 \;. The set is not specified directly, but in terms of a function computing the projection into that set, i.e. .. math:: d(\mb{x}) = (1/2) \| \mb{x} - P_C(\mb{x}) \|_2^2 \;, where :math:`P_C(\mb{x})` is the projection of :math:`\mb{x}` into set :math:`C`. """ has_eval = True has_prox = True def __init__(self, proj: Callable, args=()): r""" Args: proj: Function computing the projection into the convex set. args: Additional arguments for function `proj`. """ self.proj = proj self.args = args def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Compute the squared :math:`\ell_2` distance to the set. Compute the distance :math:`d(\mb{x})` between :math:`\mb{x}` and the set :math:`C`. Args: x: Input array :math:`\mb{x}`. Returns: Squared :math:`\ell_2` distance from `x` to the projection of `x`. """ y = self.proj(*((x,) + self.args)) return 0.5 * snp.linalg.norm(x - y) ** 2 def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Proximal operator of the squared :math:`\ell_2` distance function. Compute the proximal operator of the squared :math:`\ell_2` distance function :math:`d(\mb{x})` :cite:`beck-2017-first` (Example 6.65). Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Scaled proximal operator evaluated at `v`. """ y = self.proj(*((v,) + self.args)) 𝛼 = 1.0 / (1.0 + lam) return 𝛼 * v + lam * 𝛼 * y ================================================ FILE: scico/functional/_functional.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functional base class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import List, Optional, Union import scico from scico import numpy as snp from scico.linop import LinearOperator from scico.numpy import Array, BlockArray class Functional: r"""Base class for functionals. A functional maps an :code:`array-like` to a scalar; abstractly, a functional is a mapping from :math:`\mathbb{R}^n` or :math:`\mathbb{C}^n` to :math:`\mathbb{R}`. """ #: True if this functional can be evaluated, False otherwise. #: This attribute must be overridden and set to True or False in any derived classes. has_eval: Optional[bool] = None #: True if this functional has the prox method, False otherwise. #: This attribute must be overridden and set to True or False in any derived classes. has_prox: Optional[bool] = None def __init__(self): self._grad = scico.grad(self.__call__) def __str__(self): return f"""{self.__module__}.{self.__class__.__qualname__}""" def __repr__(self): return self.__str__() + f"""\n has_eval: {self.has_eval}\n has_prox: {self.has_prox}\n""" def __mul__(self, other: Union[float, int]) -> ScaledFunctional: if snp.util.is_scalar_equiv(other): return ScaledFunctional(self, other) return NotImplemented def __rmul__(self, other: Union[float, int]) -> ScaledFunctional: return self.__mul__(other) def __add__(self, other: Functional) -> FunctionalSum: if isinstance(other, Functional): return FunctionalSum(self, other) return NotImplemented def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Evaluate this functional at point :math:`\mb{x}`. Args: x: Point at which to evaluate this functional. Returns: Result of evaluating the functional at `x`. """ # Functionals that can be evaluated should override this method. 
raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.") def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Scaled proximal operator of functional. Evaluate scaled proximal operator of this functional, with scaling :math:`\lambda` = `lam` and evaluated at point :math:`\mb{v}` = `v`. The scaled proximal operator is defined as .. math:: \prox_{\lambda f}(\mb{v}) = \argmin_{\mb{x}} \lambda f(\mb{x}) + \frac{1}{2} \norm{\mb{v} - \mb{x}}_2^2\;, where :math:`\lambda f(\mb{x})` represents this functional evaluated at :math:`\mb{x}` multiplied by :math:`\lambda`. Args: v: Point at which to evaluate prox function. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. These include `x0`, an initial guess for the minimizer in the definition of :math:`\prox`. Returns: Result of evaluating the scaled proximal operator at `v`. """ # Functionals that have a prox should override this method. raise NotImplementedError(f"Functional {type(self)} does not have a prox.") def conj_prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Scaled proximal operator of convex conjugate of functional. Evaluate scaled proximal operator of convex conjugate (Fenchel conjugate) of this functional, with scaling :math:`\lambda` = `lam`, and evaluated at point :math:`\mb{v}` = `v`. Denoting this functional by :math:`f` and its convex conjugate by :math:`f^*`, the proximal operator of :math:`f^*` is computed as follows by exploiting the extended Moreau decomposition (see Sec. 6.6 of :cite:`beck-2017-first`) .. math:: \prox_{\lambda f^*}(\mb{v}) = \mb{v} - \lambda \, \prox_{\lambda^{-1} f}(\mb{v / \lambda}) \;. Args: v: Point at which to evaluate prox function. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional keyword args, passed directly to `self.prox`. 
Returns: Result of evaluating the scaled proximal operator at `v`. """ return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs) def grad(self, x: Union[Array, BlockArray]): r"""Evaluate the gradient of this functional at :math:`\mb{x}`. Args: x: Point at which to evaluate gradient. Returns: The gradient at `x`. """ return self._grad(x) class ScaledFunctional(Functional): r"""A functional multiplied by a scalar.""" def __init__(self, functional: Functional, scale: float): self.functional = functional self.scale = scale self.has_eval = functional.has_eval self.has_prox = functional.has_prox super().__init__() def __repr__(self): return ( f"""{Functional.__repr__(self)}""" f""" functional: {Functional.__str__(self.functional)}\n""" f""" scale: {self.scale}\n""" ) def __call__(self, x: Union[Array, BlockArray]) -> float: return self.scale * self.functional(x) def __mul__(self, other: Union[float, int]) -> ScaledFunctional: if snp.util.is_scalar_equiv(other): return ScaledFunctional(self.functional, other * self.scale) return NotImplemented def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Evaluate the scaled proximal operator of the scaled functional. Note that, by definition, the scaled proximal operator of a functional is the proximal operator of the scaled functional. The scaled proximal operator of a scaled functional is the scaled proximal operator of the unscaled functional with the proximal operator scaling consisting of the product of the two scaling factors, i.e., for functional :math:`f` and scaling factors :math:`\alpha` and :math:`\beta`, the proximal operator with scaling parameter :math:`\alpha` of scaled functional :math:`\beta f` is the proximal operator with scaling parameter :math:`\alpha \beta` of functional :math:`f`, .. math:: \prox_{\alpha (\beta f)}(\mb{v}) = \prox_{(\alpha \beta) f}(\mb{v}) \;. Args: v: Point at which to evaluate prox function. lam: Proximal parameter :math:`\lambda`. 
**kwargs: Additional arguments that may be used by derived classes. These include `x0`, an initial guess for the minimizer in the definition of :math:`\prox`. Returns: Result of evaluating the scaled proximal operator at `v`. """ return self.functional.prox(v, lam * self.scale, **kwargs) class SeparableFunctional(Functional): r"""A functional that is separable in its arguments. A separable functional :math:`f : \mathbb{C}^N \to \mathbb{R}` can be written as the sum of functionals :math:`f_i : \mathbb{C}^{N_i} \to \mathbb{R}` with :math:`\sum_i N_i = N`. In particular, .. math:: f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_N) = f_1(\mb{x}_1) + \dots + f_N(\mb{x}_N) \;. A :class:`SeparableFunctional` accepts a :class:`.BlockArray` and is separable in the block components. """ def __init__(self, functional_list: List[Functional]): r""" Args: functional_list: List of component functionals f_i. This functional takes as an input a :class:`.BlockArray` with `num_blocks == len(functional_list)`. """ self.functional_list: List[Functional] = functional_list self.has_eval: bool = all(fi.has_eval for fi in functional_list) self.has_prox: bool = all(fi.has_prox for fi in functional_list) super().__init__() def __repr__(self): return ( Functional.__repr__(self) + " components: " + ", ".join([str(f) for f in self.functional_list]) + "\n" ) def __call__(self, x: BlockArray) -> float: if len(x.shape) == len(self.functional_list): return snp.sum(snp.array([fi(xi) for fi, xi in zip(self.functional_list, x)])) raise ValueError( f"Number of blocks in x, {len(x.shape)}, and length of functional_list, " f"{len(self.functional_list)}, do not match." ) def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray: r"""Evaluate proximal operator of the separable functional. Evaluate proximal operator of the separable functional (see Theorem 6.6 of :cite:`beck-2017-first`). .. 
math:: \prox_{\lambda f}(\mb{v}) = \begin{bmatrix} \prox_{\lambda f_1}(\mb{v}_1) \\ \vdots \\ \prox_{\lambda f_N}(\mb{v}_N) \\ \end{bmatrix} \;. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ if len(v.shape) == len(self.functional_list): return snp.blockarray( [fi.prox(vi, lam, **kwargs) for fi, vi in zip(self.functional_list, v)] ) raise ValueError( f"Number of blocks in v, {len(v.shape)}, and length of functional_list, " f"{len(self.functional_list)}, do not match." ) class ComposedFunctional(Functional): r"""A functional constructed by composition. A functional constructed by composition of a functional with an orthogonal linear operator, i.e. .. math:: f(\mb{x}) = g(A \mb{x}) where :math:`f` is the composed functional, :math:`g` is the functional from which it is composed, and :math:`A` is an orthogonal linear operator. Note that the resulting :class:`Functional` can only be applied (either via evaluation or :meth:`prox` calls) to inputs of shape and dtype corresponding to the input specification of the linear operator. """ def __init__(self, functional: Functional, linop: LinearOperator): r""" Args: functional: The functional :math:`g` to be composed. linop: The linear operator :math:`A` to be composed. Note that it is the user's responsibility to confirm that the linear operator is orthogonal. If it is not, the result of :meth:`prox` will be incorrect. 
""" self.functional = functional self.linop = linop self.has_eval = functional.has_eval self.has_prox = functional.has_prox super().__init__() def __repr__(self): return ( Functional.__repr__(self) + f""" composition of: {self.functional.__str__()} and {self.linop.__str__()}\n""" ) def __call__(self, x: BlockArray) -> float: return self.functional(self.linop(x)) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of a composed functional. Evaluate proximal operator :math:`f(\mb{x}) = g(A \mb{x})`, where :math:`A` is an orthogonal linear operator, via a special case of Theorem 6.15 of :cite:`beck-2017-first` .. math:: \prox_{\lambda f}(\mb{v}) = A^T \prox_{\lambda g}(A \mb{v}) \;. Examples of orthogonal linear operator in SCICO include :class:`.linop.Reshape` and :class:`.linop.Transpose`. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. 
""" return self.linop.H(self.functional.prox(self.linop(v), lam=lam, **kwargs)) class FunctionalSum(Functional): r"""A sum of two functionals.""" def __init__(self, functional1: Functional, functional2: Functional): self.functional1 = functional1 self.functional2 = functional2 self.has_eval = functional1.has_eval and functional2.has_eval self.has_prox = False super().__init__() def __repr__(self): return ( Functional.__repr__(self) + f""" sum of functionals: {Functional.__str__(self.functional1)} and """ + f"""{Functional.__str__(self.functional2)}\n""" ) def __call__(self, x: Union[Array, BlockArray]) -> float: return self.functional1(x) + self.functional2(x) class ZeroFunctional(Functional): r"""Zero functional, :math:`f(\mb{x}) = 0 \in \mbb{R}` for any input.""" has_eval = True has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: return 0.0 def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: return v ================================================ FILE: scico/functional/_indicator.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functionals that are indicator functions/constraints.""" from typing import Union import jax from scico import numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm from ._functional import Functional class NonNegativeIndicator(Functional): r"""Indicator function for non-negative orthant. Returns 0 if all elements of input array-like are non-negative, and `inf` otherwise .. math:: I(\mb{x}) = \begin{cases} 0 & \text{ if } x_i \geq 0 \; \forall i \\ \infty & \text{ otherwise} \;. 
\end{cases} """ has_eval = True has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: if snp.util.is_complex_dtype(x.dtype): raise ValueError("Not defined for complex input.") # Equivalent to snp.inf if snp.any(x < 0) else 0.0 return jax.lax.cond(snp.any(x < 0), lambda x: snp.inf, lambda x: 0.0, None) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""The scaled proximal operator of the non-negative indicator. Evaluate the scaled proximal operator of the indicator over the non-negative orthant, :math:`I`, .. math:: [\mathrm{prox}_{\lambda I}(\mb{v})]_i = \begin{cases} v_i\, & \text{ if } v_i \geq 0 \\ 0\, & \text{ otherwise} \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda` (has no effect). **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ return snp.maximum(v, 0) class L2BallIndicator(Functional): r"""Indicator function for :math:`\ell_2` ball of given radius. Indicator function for :math:`\ell_2` ball of given radius, :math:`r` .. math:: I(\mb{x}) = \begin{cases} 0 & \text{ if } \norm{\mb{x}}_2 \leq r \\ \infty & \text{ otherwise} \;. \end{cases} Attributes: radius: Radius of :math:`\ell_2` ball. """ has_eval = True has_prox = True def __init__(self, radius: float = 1.0): r"""Initialize a :class:`L2BallIndicator` object. Args: radius: Radius of :math:`\ell_2` ball. Default: 1.0. """ self.radius = radius super().__init__() def __call__(self, x: Union[Array, BlockArray]) -> float: # Equivalent to: snp.inf if norm(x) > self.radius else 0.0 return jax.lax.cond(norm(x) > self.radius, lambda x: snp.inf, lambda x: 0.0, None) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""The scaled proximal operator of the :math:`\ell_2` ball indicator. 
Evaluate the scaled proximal operator of the indicator, :math:`I`, of the :math:`\ell_2` ball with radius :math:`r` .. math:: \mathrm{prox}_{\lambda I}(\mb{v}) = \begin{cases} \mb{v} & \text{ if } \norm{\mb{v}}_2 \leq r \\ r \frac{\mb{v}}{\norm{\mb{v}}_2} & \text{ otherwise} \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda` (has no effect). **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ return jax.lax.cond( norm(v) > self.radius, lambda v: self.radius * v / norm(v), lambda v: v, v ) class BoxIndicator(Functional): r"""Box indicator function.. Indicator function of the constraint set :math:`a \leq x \leq b` for lower and upper bounds :math:`a` and :math:`b` respectively. """ has_eval = True has_prox = True def __init__(self, lb: float = 0.0, ub: float = 1.0): r"""Initialize a :class:`BoxIndicator` object. Args: lb: Lower bound. ub: Upper bound. """ self.lb = lb self.ub = ub def __call__(self, x: Union[Array, BlockArray]) -> float: if snp.util.is_complex_dtype(x.dtype): raise ValueError("Not defined for complex input.") constr = snp.logical_and(self.lb <= x, x <= self.ub) return jax.lax.cond(snp.all(constr), lambda x: 0.0, lambda x: snp.inf, None) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""The scaled proximal operator of the box indicator. Evaluate the scaled proximal operator of the constraint set :math:`a \leq x \leq b` for lower and upper bounds :math:`a` and :math:`b` respectively. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda` (has no effect). **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. 
""" return snp.clip(v, self.lb, self.ub) ================================================ FILE: scico/functional/_norm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functionals that are norms.""" from functools import partial from typing import Optional, Tuple, Union from jax import jit, lax from scico import numpy as snp from scico.numpy import Array, BlockArray, count_nonzero from scico.numpy.linalg import norm from scico.numpy.util import no_nan_divide from ._functional import Functional class L0Norm(Functional): r"""The :math:`\ell_0` 'norm'. The :math:`\ell_0` 'norm' counts the number of non-zero elements in an array. """ has_eval = True has_prox = True @staticmethod @jit def __call__(x: Union[Array, BlockArray]) -> float: return count_nonzero(x) @staticmethod @jit def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]: r"""Evaluate scaled proximal operator of :math:`\ell_0` norm. Evaluate scaled proximal operator of :math:`\ell_0` norm using .. math:: \left[ \prox_{\lambda\| \cdot \|_0}(\mb{v}) \right]_i = \begin{cases} v_i & \text{ if } \abs{v_i} \geq \lambda \\ 0 & \text{ otherwise } \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Thresholding parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ return snp.where(snp.abs(v) >= lam, v, 0) class L1Norm(Functional): r"""The :math:`\ell_1` norm. Computes .. math:: \norm{\mb{x}}_1 = \sum_i \abs{x_i}^2 \;. 
""" has_eval = True has_prox = True @staticmethod @jit def __call__(x: Union[Array, BlockArray]) -> float: return snp.sum(snp.abs(x)) @staticmethod @jit def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Array: r"""Evaluate scaled proximal operator of :math:`\ell_1` norm. Evaluate scaled proximal operator of :math:`\ell_1` norm using .. math:: \left[ \prox_{\lambda \|\cdot\|_1}(\mb{v}) \right]_i = \sign(v_i) (\abs{v_i} - \lambda)_+ \;, where .. math:: (x)_+ = \begin{cases} x & \text{ if } x \geq 0 \\ 0 & \text{ otherwise} \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Thresholding parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ tmp = snp.abs(v) - lam tmp = 0.5 * (tmp + snp.abs(tmp)) if snp.util.is_complex_dtype(v.dtype): out = snp.exp(1j * snp.angle(v)) * tmp else: out = snp.sign(v) * tmp return out class SquaredL2Norm(Functional): r"""The squared :math:`\ell_2` norm. Squared :math:`\ell_2` norm .. math:: \norm{\mb{x}}^2_2 = \sum_i \abs{x_i}^2 \;. """ has_eval = True has_prox = True @staticmethod @jit def __call__(x: Union[Array, BlockArray]) -> float: # Directly implement the squared l2 norm to avoid nondifferentiable # behavior of snp.norm(x) at 0. return snp.sum(snp.abs(x) ** 2) @staticmethod @jit def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of squared :math:`\ell_2` norm. Evaluate proximal operator of squared :math:`\ell_2` norm using .. math:: \prox_{\lambda \| \cdot \|_2^2}(\mb{v}) = \frac{\mb{v}}{1 + 2 \lambda} \;. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ return v / (1.0 + 2.0 * lam) class L2Norm(Functional): r"""The :math:`\ell_2` norm. .. 
math:: \norm{\mb{x}}_2 = \sqrt{\sum_i \abs{x_i}^2} \;. """ has_eval = True has_prox = True @staticmethod @jit def __call__(x: Union[Array, BlockArray]) -> float: return norm(x) @staticmethod @jit def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of :math:`\ell_2` norm. Evaluate proximal operator of :math:`\ell_2` norm using .. math:: \prox_{\lambda \| \cdot \|_2}(\mb{v}) = \mb{v} \, \left(1 - \frac{\lambda}{\norm{\mb{v}}_2} \right)_+ \;, where .. math:: (x)_+ = \begin{cases} x & \text{ if } x \geq 0 \\ 0 & \text{ otherwise} \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ norm_v = norm(v) return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v) class L21Norm(Functional): r"""The :math:`\ell_{2,1}` norm. For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, .. math:: \norm{\mb{A}}_{2,1} = \sum_{n=1}^N \sqrt{\sum_{m=1}^M \abs{A_{m,n}}^2} \;. The norm generalizes to more dimensions by first computing the :math:`\ell_2` norm along one or more (user-specified) axes, followed by a sum over all remaining axes. :class:`.BlockArray` inputs require parameter `l2_axis` to be ``None``, in which case the :math:`\ell_2` norm is computed over each block. A typical use case is computing the isotropic total variation norm. """ has_eval = True has_prox = True def __init__(self, l2_axis: Union[None, int, Tuple] = 0): r""" Args: l2_axis: Axis/axes over which to take the l2 norm. Required to be ``None`` for :class:`.BlockArray` inputs to be supported. 
""" self.l2_axis = l2_axis @staticmethod @partial(jit, static_argnames=("axis", "keepdims")) def _l2norm( x: Union[Array, BlockArray], axis: Union[None, int, Tuple], keepdims: Optional[bool] = False ) -> Union[Array, BlockArray]: r"""Return the :math:`\ell_2` norm of an array.""" return snp.sqrt((snp.abs(x) ** 2).sum(axis=axis, keepdims=keepdims)) def __call__(self, x: Union[Array, BlockArray]) -> float: if isinstance(x, snp.BlockArray) and self.l2_axis is not None: raise ValueError("Initializer argument 'l2_axis' must be None for BlockArray input.") l2 = L21Norm._l2norm(x, axis=self.l2_axis) return snp.sum(snp.abs(l2)) @staticmethod @partial(jit, static_argnames=("axis")) def _prox( v: Union[Array, BlockArray], lam: float, axis: Union[None, int, Tuple] ) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm.""" length = L21Norm._l2norm(v, axis=axis, keepdims=True) direction = no_nan_divide(v, length) new_length = length - lam # set negative values to zero without `if` new_length = 0.5 * (new_length + snp.abs(new_length)) return new_length * direction def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm. In two dimensions, .. math:: \prox_{\lambda \|\cdot\|_{2,1}}(\mb{v}, \lambda)_{:, n} = \frac{\mb{v}_{:, n}}{\|\mb{v}_{:, n}\|_2} (\|\mb{v}_{:, n}\|_2 - \lambda)_+ \;, where .. math:: (x)_+ = \begin{cases} x & \text{ if } x \geq 0 \\ 0 & \text{ otherwise} \;. \end{cases} Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. 
""" if isinstance(v, snp.BlockArray) and self.l2_axis is not None: raise ValueError("Initializer argument 'l2_axis' must be None for BlockArray input.") return L21Norm._prox(v, lam=lam, axis=self.l2_axis) class L1MinusL2Norm(Functional): r"""Difference of :math:`\ell_1` and :math:`\ell_2` norms. Difference of :math:`\ell_1` and :math:`\ell_2` norms .. math:: \norm{\mb{x}}_1 - \beta * \norm{\mb{x}}_2 """ has_eval = True has_prox = True def __init__(self, beta: float = 1.0): r""" Args: beta: Parameter :math:`\beta` in the norm definition. """ self.beta = beta @staticmethod @jit def _l1minusl2norm(x: Union[Array, BlockArray], beta: float) -> float: r"""Return the :math:`\ell_1 - \ell_2` norm of an array.""" return snp.sum(snp.abs(x)) - beta * norm(x) def __call__(self, x: Union[Array, BlockArray]) -> float: return L1MinusL2Norm._l1minusl2norm(x, self.beta) @staticmethod def _prox_vamx_ge_thresh(v, va, vs, alpha, beta): u = snp.zeros(v.shape, dtype=v.dtype) idx = va.ravel().argmax() u = ( u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx]) ).reshape(v.shape) return u @staticmethod def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta): return snp.where( vamx < (1.0 - beta) * alpha, snp.zeros(v.shape, dtype=v.dtype), L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta), ) @staticmethod def _prox_vamx_gt_alpha(v, va, vs, alpha, beta): u = snp.maximum(va - alpha, 0.0) * vs l2u = norm(u) u *= (l2u + alpha * beta) / l2u return u @staticmethod def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta): return snp.where( vamx > alpha, L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta), L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta), ) @staticmethod @jit def _prox(v: Union[Array, BlockArray], lam: float, beta: float) -> Union[Array, BlockArray]: r"""Proximal operator of :math:`\ell_1 - \ell_2` norm.""" alpha = lam va = snp.abs(v) vamx = snp.max(va) if snp.util.is_complex_dtype(v.dtype): vs = snp.exp(1j * snp.angle(v)) else: vs 
= snp.sign(v) return snp.where( vamx > 0.0, L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta), snp.zeros(v.shape, dtype=v.dtype), ) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms. Evaluate the proximal operator of the difference of :math:`\ell_1` and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x} \|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note that this is not a proximal operator according to the strict definition since the loss function is non-convex. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ return L1MinusL2Norm._prox(v, lam=lam, beta=self.beta) class HuberNorm(Functional): r"""Huber norm. Compute a norm based on the Huber function :cite:`huber-1964-robust` :cite:`beck-2017-first` (Sec. 6.7.1). In the non-separable case the norm is .. math:: H_{\delta}(\mb{x}) = \begin{cases} (1/2) \norm{ \mb{x} }_2^2 & \text{ when } \norm{ \mb{x} }_2 \leq \delta \\ \delta \left( \norm{ \mb{x} }_2 - (\delta / 2) \right) & \text{ when } \norm{ \mb{x} }_2 > \delta \;, \end{cases} where :math:`\delta` is a parameter controlling the transitions between :math:`\ell_1`-norm like and :math:`\ell_2`-norm like behavior. In the separable case the norm is .. math:: H_{\delta}(\mb{x}) = \sum_i h_{\delta}(x_i) \,, where .. math:: h_{\delta}(x) = \begin{cases} (1/2) \abs{ x }^2 & \text{ when } \abs{ x } \leq \delta \\ \delta \left( \abs{ x } - (\delta / 2) \right) & \text{ when } \abs{ x } > \delta \;. \end{cases} """ has_eval = True has_prox = True def __init__(self, delta: float = 1.0, separable: bool = True): r""" Args: delta: Huber function parameter :math:`\delta`. separable: Flag indicating whether to compute separable or non-separable form. 
""" self.delta = delta self.separable = separable if separable: self._call = self._call_sep self._prox = self._prox_sep else: self._call = self._call_nonsep self._prox = self._prox_nonsep super().__init__() @staticmethod @jit def _call_sep(x: Union[Array, BlockArray], delta: float) -> float: xabs = snp.abs(x) hx = snp.where(xabs <= delta, 0.5 * xabs**2, delta * (xabs - (delta / 2.0))) return snp.sum(hx) @staticmethod @jit def _call_nonsep(x: Union[Array, BlockArray], delta: float) -> float: xl2 = snp.linalg.norm(x) return lax.cond( xl2 <= delta, lambda xl2: 0.5 * xl2**2, lambda xl2: delta * (xl2 - delta / 2.0), xl2 ) def __call__(self, x: Union[Array, BlockArray]) -> float: return self._call(x, self.delta) @staticmethod @jit def _prox_sep( v: Union[Array, BlockArray], lam: float, delta: float ) -> Union[Array, BlockArray]: den = snp.maximum(snp.abs(v), delta * (1.0 + lam)) return (1.0 - ((delta * lam) / den)) * v @staticmethod @jit def _prox_nonsep( v: Union[Array, BlockArray], lam: float, delta: float ) -> Union[Array, BlockArray]: vl2 = snp.linalg.norm(v) den = snp.maximum(vl2, delta * (1.0 + lam)) return (1.0 - ((delta * lam) / den)) * v def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of the Huber function. Evaluate scaled proximal operator of the Huber function :cite:`beck-2017-first` (Sec. 6.7.3). The prox is .. math:: \prox_{\lambda H_{\delta}} (\mb{v}) = \left( 1 - \frac{\lambda \delta} {\max\left\{\norm{\mb{v}}_2, \delta + \lambda \delta\right\} } \right) \mb{v} in the non-separable case, and .. math:: \left[ \prox_{\lambda H_{\delta}} (\mb{v}) \right]_i = \left( 1 - \frac{\lambda \delta} {\max\left\{\abs{v_i}, \delta + \lambda \delta\right\} } \right) v_i in the separable case. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. 
Returns: Result of evaluating the scaled proximal operator at `v`. """ return self._prox(v, lam=lam, delta=self.delta) class NuclearNorm(Functional): r"""Nuclear norm. Compute the nuclear norm .. math:: \| X \|_* = \sum_i \sigma_i where :math:`\sigma_i` are the singular values of matrix :math:`X`. """ has_eval = True has_prox = True @staticmethod @jit def __call__(x: Union[Array, BlockArray]) -> float: if x.ndim != 2: raise ValueError("Input array must be two dimensional.") return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False)) @staticmethod @jit def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]: r"""Evaluate proximal operator of the nuclear norm. Evaluate proximal operator of the nuclear norm :cite:`cai-2010-singular`. Args: v: Input array :math:`\mb{v}`. Required to be two-dimensional. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ if v.ndim != 2: raise ValueError("Input array must be two dimensional.") svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV ================================================ FILE: scico/functional/_proxavg.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Implementation of the proximal average method.""" from typing import List, Optional, Union from scico.numpy import Array, BlockArray, isinf from ._functional import Functional class ProximalAverage(Functional): """Weighted average of functionals. A functional that is composed of a weighted average of functionals. 
All of the component functionals are required to have proximal operators. The proximal operator of the composite functional is approximated via the proximal average method :cite:`yu-2013-better`, which holds for small scaling parameters. This does not imply that it can only be applied to problems requiring a small regularization parameter since most proximal algorithms include an additional algorithm parameter that also plays a role in the parameter of the proximal operator. For example, in :class:`.PGM` and :class:`.AcceleratedPGM`, the scaled proximal operator parameter is the regularization parameter divided by the `L0` algorithm parameter, and for :class:`.ADMM`, the scaled proximal operator parameters are the regularization parameters divided by the entries in the `rho_list` algorithm parameter. """ def __init__( self, func_list: List[Functional], alpha_list: Optional[List[float]] = None, no_inf_eval=True, ): """ Args: func_list: List of component :class:`.Functional` objects, all of which must have a proximal operator. alpha_list: List of scalar weights for each :class:`.Functional`. If not specified, defaults to equal weights. If specified, the list of weights must have the same length as the :class:`.Functional` list. If the weights do not sum to unity, they are scaled to ensure that they do. no_inf_eval: If ``True``, exclude infinite values (typically associated with a functional that is an indicator function) from the evaluation of the sum of component functionals. 
""" self.has_prox = all([f.has_prox for f in func_list]) if not self.has_prox: raise ValueError("All functionals in 'func_list' must have has_prox == True.") self.has_eval = all([f.has_eval for f in func_list]) self.no_inf_eval = no_inf_eval self.func_list = func_list N = len(func_list) if alpha_list is None: self.alpha_list = [1.0 / N] * N else: if len(alpha_list) != N: raise ValueError( "If specified, argument 'alpha_list' must have the same length as func_list" ) alpha_sum = sum(alpha_list) if alpha_sum != 1.0: alpha_list = [alpha / alpha_sum for alpha in alpha_list] self.alpha_list = alpha_list def __repr__(self): return ( Functional.__repr__(self) + " components: " + ", ".join([str(f) for f in self.func_list]) + "\n weights: " + ", ".join([str(alpha) for alpha in self.alpha_list]) + "\n" ) def __call__(self, x: Union[Array, BlockArray]) -> float: """Evaluate the weighted average of component functionals.""" if self.has_eval: weight_func_vals = [alpha * f(x) for (alpha, f) in zip(self.alpha_list, self.func_list)] if self.no_inf_eval: weight_func_vals = list(filter(lambda x: not isinf(x), weight_func_vals)) return sum(weight_func_vals) else: raise ValueError( "At least one functional in argument 'func_list' has has_eval == False." ) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: r"""Approximate proximal operator of the average of functionals. Approximation of the proximal operator of a weighted average of functionals computed via the proximal average method :cite:`yu-2013-better`. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lam`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. 
""" return sum( [ alpha * f.prox(v, lam, **kwargs) for (alpha, f) in zip(self.alpha_list, self.func_list) ] ) ================================================ FILE: scico/functional/_tvnorm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Total variation norms.""" from functools import partial from typing import Optional, Tuple import jax from scico import numpy as snp from scico.linop import ( Crop, FiniteDifference, LinearOperator, Pad, SingleAxisFiniteDifference, VerticalStack, linop_over_axes, ) from scico.numpy import Array from scico.numpy.util import normalize_axes from scico.typing import Axes, DType, Shape from ._functional import Functional from ._norm import L1Norm, L21Norm class TVNorm(Functional): r"""Generic total variation (TV) norm. Generic total variation (TV) norm with approximation of the scaled proximal operator :cite:`kamilov-2016-parallel` :cite:`kamilov-2016-minimizing` :cite:`chandler-2024-closedform`. """ has_eval = True has_prox = True def __init__( self, norm: Functional, circular: bool = True, axes: Optional[Axes] = None, input_shape: Optional[Shape] = None, input_dtype: DType = snp.float32, ): """ While initializers for :class:`.Functional` objects typically do not take `input_shape` and `input_dtype` parameters, they are included here because methods :meth:`__call__` and :meth:`prox` require instantiation of some :class:`.LinearOperator` objects, which do take these parameters. If these parameters are not provided on intialization of a :class:`TVNorm` object, then creation of the required :class:`.LinearOperator` objects is deferred until these methods are called, which can result in `JAX tracer `__ errors when they are components of a jitted function. 
Args: norm: Norm functional from which the TV norm is composed. circular: Flag indicating use of circular boundary conditions. axes: Axis or axes over which to apply finite difference operator. If not specified, or ``None``, differences are evaluated along all axes. input_shape: Shape of input arrays of :meth:`__call__` and :meth:`prox`. input_dtype: `dtype` of input arrays of :meth:`__call__` and :meth:`prox`. """ self.norm = norm self.circular = circular self.axes = axes self.G: Optional[LinearOperator] = None self.WP: Optional[LinearOperator] = None self.prox_ndims: Optional[int] = None self.prox_slice: Optional[Tuple] = None if input_shape is not None: self.G = self._call_operator(input_shape, input_dtype) self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators( input_shape, input_dtype ) def _call_operator(self, input_shape: Shape, input_dtype: DType) -> LinearOperator: """Construct operator required by __call__ method.""" G = FiniteDifference( input_shape, input_dtype=input_dtype, axes=self.axes, circular=self.circular, # For non-circular boundary conditions, zero-pad to the right # for equivalence with boundary conditions implemented in the # prox calculation. append=None if self.circular else 0, jit=True, ) return G def __call__(self, x: Array) -> float: """Compute the TV norm of an array. Args: x: Array for which the TV norm should be computed. Returns: TV norm of `x`. 
""" if self.G is None or self.G.shape[1] != x.shape: self.G = self._call_operator(x.shape, x.dtype) return self.norm(self.G @ x) def _prox_operators( self, input_shape: Shape, input_dtype: DType ) -> Tuple[LinearOperator, LinearOperator, int, Tuple]: """Construct operators required by prox method.""" axes = normalize_axes(self.axes, input_shape) ndims = len(axes) w_input_shape = ( # circular boundary: shape of input array input_shape if self.circular # non-circular boundary: shape of input array on non-differenced # axes and one greater for axes that are differenced else tuple([s + 1 if i in axes else s for i, s in enumerate(input_shape)]) # type: ignore ) W = HaarTransform(w_input_shape, input_dtype=input_dtype, axes=axes, jit=True) # type: ignore if self.circular: # slice selecting highpass component of shift-invariant Haar transform slce = snp.s_[:, 1] # No boundary extension, so fused extend and forward transform, and fused # adjoint transform and crop are just forward and adjoint respectively. WP, CWT = W, W.T else: # slice selecting non-boundary region of highpass component of # shift-invariant Haar transform slce = ( snp.s_[:], snp.s_[1], ) + tuple( [snp.s_[:-1] if i in axes else snp.s_[:] for i, s in enumerate(input_shape)] ) # type: ignore # Replicate-pad to the right (resulting in a zero after finite differencing) # on all axes subject to finite differencing. 
pad_width = [(0, 1) if i in axes else (0, 0) for i, s in enumerate(input_shape)] # type: ignore P = Pad( input_shape, input_dtype=input_dtype, pad_width=pad_width, mode="edge", jit=True ) # fused boundary extend and forward transform linop WP = W @ P # crop operation that is inverse of the padding operation C = Crop( crop_width=pad_width, input_shape=w_input_shape, input_dtype=input_dtype, jit=True ) # fused adjoint transform and crop linop CWT = C @ W.T return WP, CWT, ndims, slce @staticmethod def _slice_tuple_to_tuple(st: Tuple) -> Tuple: """Convert a tuple of slice or int to a tuple of tuple or int. Required here as a workaround for the unhashability of slices in Python < 3.12, since jax.jit requires static arguments to be hashable. """ return tuple([(s.start, s.stop, s.step) if isinstance(s, slice) else s for s in st]) @staticmethod def _slice_tuple_from_tuple(st: Tuple) -> Tuple: """Convert a tuple of tuple or int to a tuple of slice or int. Required here as a workaround for the unhashability of slices in Python < 3.12, since jax.jit requires static arguments to be hashable. """ return tuple([slice(*s) if isinstance(s, tuple) else s for s in st]) @staticmethod @partial(jax.jit, static_argnums=(0, 1, 2, 4)) def _prox_core( WP: LinearOperator, CWT: LinearOperator, norm: Functional, K: int, slce_rep: Tuple, v: Array, lam: float = 1.0, ) -> Array: """Core component of prox calculation.""" # Apply boundary extension (when circular==False) and single-level Haar # transform to input array. WPv: Array = WP(v) # Convert tuple of slices/ints to tuple of tuples/ints to avoid jax.jit # complaints about unhashability of slices. slce = TVNorm._slice_tuple_from_tuple(slce_rep) # Apply shrinkage to highpass component of shift-invariant Haar transform # of padded input (or to non-boundary region thereof when circular==False). 
WPv = WPv.at[slce].set(norm.prox(WPv[slce], snp.sqrt(2) * K * lam)) # Apply adjoint of single-level Haar transform and crop extended # part of array (when circular==False). return (1.0 / K) * CWT(WPv) def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array: r"""Approximate scaled proximal operator of the TV norm. Approximation of the scaled proximal operator of the TV norm, computed via the methods described in :cite:`kamilov-2016-parallel` :cite:`kamilov-2016-minimizing` :cite:`chandler-2024-closedform`. Args: v: Input array :math:`\mb{v}`. lam: Proximal parameter :math:`\lam`. **kwargs: Additional arguments that may be used by derived classes. Returns: Result of evaluating the scaled proximal operator at `v`. """ if self.WP is None or self.WP.shape[1] != v.shape: self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators( v.shape, v.dtype ) assert self.prox_ndims is not None assert self.prox_slice is not None K = 2 * self.prox_ndims u = TVNorm._prox_core( self.WP, self.CWT, self.norm, K, TVNorm._slice_tuple_to_tuple(self.prox_slice), v, lam ) return u class AnisotropicTVNorm(TVNorm): r"""The anisotropic total variation (TV) norm. The anisotropic total variation (TV) norm computed by .. code-block:: python ATV = scico.functional.AnisotropicTVNorm() x_norm = ATV(x) is equivalent to .. code-block:: python C = linop.FiniteDifference(input_shape=x.shape, circular=True) L1 = functional.L1Norm() x_norm = L1(C @ x) The scaled proximal operator is computed using an approximation that holds for small scaling parameters :cite:`kamilov-2016-parallel`. This does not imply that it can only be applied to problems requiring a small regularization parameter since most proximal algorithms include an additional algorithm parameter that also plays a role in the parameter of the proximal operator. 
For example, in :class:`.PGM` and :class:`.AcceleratedPGM`, the scaled proximal operator parameter is the regularization parameter divided by the `L0` algorithm parameter, and for :class:`.ADMM`, the scaled proximal operator parameters are the regularization parameters divided by the entries in the `rho_list` algorithm parameter. """ def __init__( self, circular: bool = False, axes: Optional[Axes] = None, input_shape: Optional[Shape] = None, input_dtype: DType = snp.float32, ): """ Args: circular: Flag indicating use of circular boundary conditions. axes: Axis or axes over which to apply finite difference operator. If not specified, or ``None``, differences are evaluated along all axes. input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`. input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`. """ super().__init__( L1Norm(), circular=circular, axes=axes, input_shape=input_shape, input_dtype=input_dtype, ) class IsotropicTVNorm(TVNorm): r"""The isotropic total variation (TV) norm. The isotropic total variation (TV) norm computed by .. code-block:: python ATV = scico.functional.IsotropicTVNorm() x_norm = ATV(x) is equivalent to .. code-block:: python C = linop.FiniteDifference(input_shape=x.shape, circular=True) L21 = functional.L21Norm() x_norm = L21(C @ x) The scaled proximal operator is computed using an approximation that holds for small scaling parameters :cite:`kamilov-2016-minimizing`. This does not imply that it can only be applied to problems requiring a small regularization parameter since most proximal algorithms include an additional algorithm parameter that also plays a role in the parameter of the proximal operator. 
For example, in :class:`.PGM` and :class:`.AcceleratedPGM`, the scaled proximal operator parameter is the regularization parameter divided by the `L0` algorithm parameter, and for :class:`.ADMM`, the scaled proximal operator parameters are the regularization parameters divided by the entries in the `rho_list` algorithm parameter. """ def __init__( self, circular: bool = False, axes: Optional[Axes] = None, input_shape: Optional[Shape] = None, input_dtype: DType = snp.float32, ): r""" Args: circular: Flag indicating use of circular boundary conditions. axes: Axis or axes over which to apply finite difference operator. If not specified, or ``None``, differences are evaluated along all axes. input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`. input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`. """ super().__init__( L21Norm(), circular=circular, axes=axes, input_shape=input_shape, input_dtype=input_dtype, ) class SingleAxisFiniteSum(LinearOperator): r"""Two-point sum operator acting along a single axis. Boundary handling is circular, so that the sum operator corresponds to the matrix .. math:: \left(\begin{array}{rrrrr} 1 & 0 & 0 & \ldots & 0\\ 1 & 1 & 0 & \ldots & 0\\ 0 & 1 & 1 & \ldots & 0\\ \vdots & \vdots & \ddots & \ddots & \vdots\\ 0 & 0 & \ldots & 1 & 1\\ 1 & 0 & \dots & 0 & 1 \end{array}\right) \;. """ def __init__( self, input_shape: Shape, input_dtype: DType = snp.float32, axis: int = -1, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. axis: Axis over which to apply sum operator. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ if not isinstance(axis, int): raise TypeError( f"Expected argument 'axis' to be of type int, got {type(axis)} instead." 
) if axis < 0: axis = len(input_shape) + axis if axis >= len(input_shape): raise ValueError( f"Invalid argument 'axis' specified ({axis}); 'axis' must be less than " f"len(input_shape)={len(input_shape)}." ) self.axis = axis super().__init__( input_shape=input_shape, output_shape=input_shape, input_dtype=input_dtype, output_dtype=input_dtype, jit=jit, **kwargs, ) def _eval(self, x: snp.Array) -> snp.Array: return x + snp.roll(x, -1, self.axis) class FiniteSum(VerticalStack): """Two-point sum operator. Compute two-point sums along the specified axes, returning the results stacked on axis 0 of a :class:`jax.Array`. See :class:`SingleAxisFiniteSum` for boundary handling details. """ def __init__( self, input_shape: Shape, input_dtype: DType = snp.float32, axes: Optional[Axes] = None, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. axes: Axis or axes over which to apply sum operator. If not specified, or ``None``, sums are evaluated along all axes. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ self.axes, ops = linop_over_axes( SingleAxisFiniteSum, input_shape, axes=axes, input_dtype=input_dtype, jit=False, ) super().__init__( ops, # type: ignore jit=jit, **kwargs, ) class SingleAxisHaarTransform(VerticalStack): """Single-level shift-invariant Haar transform along a single axis. Compute one level of a shift-invariant Haar transform along the specified axis, returning the results in a :class:`jax.Array` consisting of sum and difference components (corresponding to lowpass and highpass filtered components respectively) stacked on axis 0. See :class:`SingleAxisFiniteSum` for boundary handling details. """ def __init__( self, input_shape: Shape, input_dtype: DType = snp.float32, axis: int = -1, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. 
Defaults to :attr:`~numpy.float32`. axis: Axis over which to apply Haar transform. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ self.axis = axis self.HaarL = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteSum( input_shape, input_dtype=input_dtype, axis=axis, jit=jit, **kwargs ) self.HaarH = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteDifference( input_shape, input_dtype=input_dtype, axis=axis, circular=True, jit=jit, **kwargs ) super().__init__( (self.HaarL, self.HaarH), jit=jit, **kwargs, ) class HaarTransform(VerticalStack): """Single-level shift-invariant Haar transform. Compute one level of a shift-invariant Haar transform along the specified axes, returning the results in a :class:`jax.Array`. See :class:`SingleAxisHaarTransform` for details of the transform along each axis. """ def __init__( self, input_shape: Shape, input_dtype: DType = snp.float32, axes: Optional[Axes] = None, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. axes: Axis or axes over which to apply Haar transform. If not specified, or ``None``, the transform is evaluated along all axes. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ self.axes, ops = linop_over_axes( SingleAxisHaarTransform, input_shape, axes=axes, input_dtype=input_dtype, jit=False, ) super().__init__( ops, # type: ignore jit=jit, **kwargs, ) ================================================ FILE: scico/linop/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Linear operator functions and classes.""" import sys from ._circconv import CircularConvolve from ._convolve import Convolve, ConvolveByX from ._dft import DFT from ._diag import Diagonal, Identity, ScaledIdentity from ._diff import FiniteDifference, SingleAxisFiniteDifference from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function from ._grad import ( CylindricalGradient, PolarGradient, ProjectedGradient, SphericalGradient, ) from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes from ._util import jacobian, operator_norm, power_iteration, valid_adjoint __all__ = [ "CircularConvolve", "Convolve", "DFT", "Diagonal", "FiniteDifference", "ProjectedGradient", "PolarGradient", "CylindricalGradient", "SphericalGradient", "SingleAxisFiniteDifference", "Identity", "DiagonalReplicated", "VerticalStack", "DiagonalStack", "MatrixOperator", "Pad", "Crop", "Reshape", "ScaledIdentity", "Slice", "Sum", "Transpose", "LinearOperator", "ComposedLinearOperator", "linop_from_function", "linop_over_axes", "operator_norm", "power_iteration", "valid_adjoint", "jacobian", ] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/linop/_circconv.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Circular convolution linear operator.""" import math from typing import Optional, Sequence, Tuple, Union import numpy as np from jax.dtypes import result_type import scico.numpy as snp from scico.numpy.util import is_nested from scico.operator import Operator from scico.typing import DType, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar class CircularConvolve(LinearOperator): r"""A circular convolution linear operator. This linear operator implements circular, multi-dimensional convolution via pointwise multiplication in the DFT domain. In its simplest form, it implements a single convolution and can be represented by linear operator :math:`H` such that .. math:: H \mb{x} = \mb{h} \ast \mb{x} \;, where :math:`\mb{h}` is a user-defined filter. More complex forms, corresponding to the case where either the input (as represented by parameter `input_shape`) or filter (parameter `h`) have additional axes that are not involved in the convolution are also supported. These follow numpy broadcasting rules. For example: Additional axes in the input :math:`\mb{x}` and not in :math:`\mb{h}` corresponds to the operation .. math:: H \mb{x} = \left( \begin{array}{ccc} H' & 0 & \ldots\\ 0 & H' & \ldots\\ \vdots & \vdots & \ddots \end{array} \right) \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array} \right) \;. Additional axes in :math:`\mb{h}` corresponds to multiple filters, which will be denoted by :math:`\{\mb{h}_m\}`, with corresponding individual linear operations being denoted by :math:`h_m \mb{x}_m = \mb{h}_m \ast \mb{x}_m`. The full linear operator can then be represented as .. math:: H \mb{x} = \left( \begin{array}{c} H_0\\ H_1\\ \vdots \end{array} \right) \mb{x} \;. if the input is singleton, and as .. math:: H \mb{x} = \left( \begin{array}{ccc} H_0 & 0 & \ldots\\ 0 & H_1 & \ldots\\ \vdots & \vdots & \ddots \end{array} \right) \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array} \right) otherwise. 
""" def __init__( self, h: snp.Array, input_shape: Shape, ndims: Optional[int] = None, input_dtype: DType = snp.float32, h_is_dft: bool = False, h_center: Optional[Union[snp.Array, np.ndarray, Sequence, float, int]] = None, jit: bool = True, **kwargs, ): """ Args: h: Array of filters. input_shape: Shape of input array. ndims: Number of (trailing) dimensions of the input and `h` involved in the convolution. Defaults to the number of dimensions in the input. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. h_is_dft: Flag indicating whether `h` is in the DFT domain. h_center: Array of length `ndims` specifying the center of the filter. Defaults to the upper left corner, i.e., `h_center = [0, 0, ..., 0]`, may be noninteger. May be a ``float`` or ``int`` if `h` is one-dimensional. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ if ndims is None: self.ndims = len(input_shape) else: self.ndims = ndims if h_is_dft and h_center is not None: raise ValueError("Argument 'h_center' must be None when h_is_dft=True.") self.h_center = h_center if h_is_dft: self.h_dft = h output_dtype = snp.dtype(input_dtype) # cannot infer from h_dft because it is complex else: fft_shape = input_shape[-self.ndims :] fft_axes = list(range(h.ndim - self.ndims, h.ndim)) self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes) output_dtype = result_type(h.dtype, input_dtype) if self.h_center is not None: shift = self._dft_center_shift(input_shape) self.h_dft = self.h_dft * shift self.real = output_dtype.kind != "c" try: output_shape = np.broadcast_shapes(self.h_dft.shape, input_shape) except ValueError: raise ValueError( f"Shape of 'h' after padding was {self.h_dft.shape}, needs to be compatible " f"for broadcasting with {input_shape}." 
) self.batch_axes = tuple( range(0, len(output_shape) - len(input_shape)) ) # used in adjoint to undo broadcasting self.ifft_axes = list(range(len(output_shape) - self.ndims, len(output_shape))) self.x_fft_axes = list(range(len(input_shape) - self.ndims, len(input_shape))) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, output_dtype=output_dtype, jit=jit, **kwargs, ) def _dft_center_shift(self, input_shape) -> np.ndarray: """Compute DFT domain shift required for centering. See doi:10.1109/78.700979 and doi:10.1109/LSP.2012.2191280 for details of the shift computation. """ if isinstance(self.h_center, (float, int)): # support float/int h_center offset = -np.array( [ self.h_center, ] ) else: # support array/list/tuple h_center offset = -np.array(self.h_center) shifts: Tuple[np.ndarray, ...] = np.ix_( *tuple( np.select( [np.arange(s) < s / 2, np.arange(s) == s / 2, np.arange(s) > s / 2], [ np.exp(-1j * k * 2 * np.pi * np.arange(s) / s), np.cos(k * np.pi), np.exp(1j * k * 2 * np.pi * (s - np.arange(s)) / s), ], # type: ignore ) for k, s in zip(offset, input_shape[-self.ndims :]) ) ) # prevent accidental promotion to double shifts = tuple(s.astype(self.h_dft.dtype) for s in shifts) shift = math.prod(shifts) # np.prod warns assert isinstance(shift, np.ndarray) return shift def _eval(self, x: snp.Array) -> snp.Array: x = x.astype(self.input_dtype) x_dft = snp.fft.fftn(x, axes=self.x_fft_axes) hx = snp.fft.ifftn( self.h_dft * x_dft, axes=self.ifft_axes, ) if self.real: hx = hx.real return hx def _adj(self, x: snp.Array) -> snp.Array: # type: ignore x_dft = snp.fft.fftn(x, axes=self.ifft_axes) H_adj_x = snp.fft.ifftn( snp.conj(self.h_dft) * x_dft, axes=self.ifft_axes, s=self.input_shape[-self.ndims :], ) H_adj_x = snp.sum(H_adj_x, axis=self.batch_axes) # adjoint of the broadcast if self.real: H_adj_x = H_adj_x.real return H_adj_x @_wrap_add_sub def __add__(self, other): if self.ndims != other.ndims: raise 
ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") return CircularConvolve( h=self.h_dft + other.h_dft, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), ndims=self.ndims, h_is_dft=True, ) @_wrap_add_sub def __sub__(self, other): if self.ndims != other.ndims: raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") return CircularConvolve( h=self.h_dft - other.h_dft, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), ndims=self.ndims, h_is_dft=True, ) @_wrap_mul_div_scalar def __mul__(self, scalar): return CircularConvolve( h=self.h_dft * scalar, input_shape=self.input_shape, ndims=self.ndims, input_dtype=self.input_dtype, h_is_dft=True, ) @_wrap_mul_div_scalar def __truediv__(self, scalar): return CircularConvolve( h=self.h_dft / scalar, input_shape=self.input_shape, ndims=self.ndims, input_dtype=self.input_dtype, h_is_dft=True, ) @staticmethod def from_operator( H: Operator, ndims: Optional[int] = None, center: Optional[Shape] = None, jit: bool = True ): r"""Construct a CircularConvolve version of a given operator. Construct a CircularConvolve version of a given operator, which is assumed to be linear and shift invariant (LSI). Args: H: Input operator. ndims: Number of trailing dims over which the H acts. center: Location at which to place the Kronecker delta. For LSI inputs, this will not matter. Defaults to the center of H.input_shape, i.e., (n_1 // 2, n_2 // 2, ...). jit: If ``True``, jit the resulting `CircularConvolve`. """ if is_nested(H.input_shape): raise ValueError( f"H.input_shape ({H.input_shape}) suggests that H " "takes a BlockArray as input, which is not supported " "by this function." 
) if ndims is None: ndims = len(H.input_shape) else: ndims = ndims if center is None: center = tuple(d // 2 for d in H.input_shape[-ndims:]) # type: ignore # compute impulse response d = snp.zeros(H.input_shape, H.input_dtype) d = d.at[(Ellipsis,) + center].set(1.0) Hd = H @ d # build CircularConvolve return CircularConvolve( Hd, H.input_shape, # type: ignore ndims=ndims, input_dtype=H.input_dtype, h_center=snp.array(center), jit=jit, ) ================================================ FILE: scico/linop/_convolve.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Convolution linear operator class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations import numpy as np from jax.dtypes import result_type from jax.scipy.signal import convolve import scico.numpy as snp from scico.typing import DType, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar class Convolve(LinearOperator): """A convolution linear operator.""" def __init__( self, h: snp.Array, input_shape: Shape, input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, **kwargs, ): r"""Wrap :func:`jax.scipy.signal.convolve` as a :class:`.LinearOperator`. Args: h: Convolutional filter. Must have same number of dimensions as `len(input_shape)`. input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. mode: A string indicating the size of the output. One of "full", "valid", "same". Defaults to "full". jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. 
For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ self.h: snp.Array # : Convolution kernel self.mode: str # : Convolution mode if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}.") self.h = h if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") self.mode = mode if input_dtype is None: input_dtype = self.h.dtype output_dtype = result_type(input_dtype, self.h.dtype) super().__init__( input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype, jit=jit, **kwargs, ) def _eval(self, x: snp.Array) -> snp.Array: return convolve(x, self.h, mode=self.mode) @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") if self.h.shape == other.h.shape: return Convolve( h=self.h + other.h, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: self.adj(x) + other.adj(x), ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") if self.h.shape == other.h.shape: return Convolve( h=self.h - other.h, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: self.adj(x) - other.adj(x), ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_mul_div_scalar def __mul__(self, scalar): return Convolve( h=self.h * scalar, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, type(scalar)), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: snp.conj(scalar) * self.adj(x), ) @_wrap_mul_div_scalar def __truediv__(self, scalar): return 
Convolve( h=self.h / scalar, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, type(scalar)), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: self.adj(x) / snp.conj(scalar), ) class ConvolveByX(LinearOperator): """A LinearOperator that performs convolution as a function of the first argument. The :class:`LinearOperator` `ConvolveByX(x=x)(y)` implements `jax.scipy.signal.convolve(x, y)`. """ def __init__( self, x: snp.Array, input_shape: Shape, input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, **kwargs, ): r""" Args: x: Convolutional filter. Must have same number of dimensions as `len(input_shape)`. input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. mode: A string indicating the size of the output. One of "full", "valid", "same". Defaults to "full". jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ self.x: snp.Array # : Fixed signal to convolve with self.mode: str # : Convolution mode if x.ndim != len(input_shape): raise ValueError(f"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.") # Ensure that x is a numpy or jax array. 
if not snp.util.is_arraylike(x): raise TypeError(f"Expected numpy or jax array, got {type(x)}.") self.x = x if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") self.mode = mode if input_dtype is None: input_dtype = x.dtype output_dtype = result_type(input_dtype, x.dtype) super().__init__( input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype, jit=jit, **kwargs, ) def _eval(self, h: snp.Array) -> snp.Array: return convolve(self.x, h, mode=self.mode) @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") if self.x.shape == other.x.shape: return ConvolveByX( x=self.x + other.x, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: self.adj(x) + other.adj(x), ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") if self.x.shape == other.x.shape: return ConvolveByX( x=self.x - other.x, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, other.input_dtype), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: self.adj(x) - other.adj(x), ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_mul_div_scalar def __mul__(self, scalar): return ConvolveByX( x=self.x * scalar, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, type(scalar)), mode=self.mode, output_shape=self.output_shape, adj_fn=lambda x: snp.conj(scalar) * self.adj(x), ) @_wrap_mul_div_scalar def __truediv__(self, scalar): return ConvolveByX( x=self.x / scalar, input_shape=self.input_shape, input_dtype=result_type(self.input_dtype, type(scalar)), mode=self.mode, output_shape=self.output_shape, 
adj_fn=lambda x: self.adj(x) / snp.conj(scalar), ) ================================================ FILE: scico/linop/_dft.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Discrete Fourier transform linear operator class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Optional, Sequence import numpy as np import scico.numpy as snp from scico.typing import Shape from ._linop import LinearOperator class DFT(LinearOperator): r"""Multi-dimensional discrete Fourier transform.""" def __init__( self, input_shape: Shape, axes: Optional[Sequence] = None, axes_shape: Optional[Shape] = None, norm: Optional[str] = None, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. axes: Axes over which to compute the DFT. If ``None``, the DFT is computed over all axes. axes_shape: Output shape on the subset of array axes selected by `axes`. This parameter has the same behavior as the `s` parameter of :func:`numpy.fft.fftn`. norm: DFT normalization mode. See the `norm` parameter of :func:`numpy.fft.fftn`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if axes is not None and axes_shape is not None and len(axes) != len(axes_shape): raise ValueError( f"len(axes)={len(axes)} does not equal len(axes_shape)={len(axes_shape)}." 
) if axes_shape is not None: if axes is None: axes = tuple(range(len(input_shape) - len(axes_shape), len(input_shape))) tmp_output_shape = list(input_shape) for i, s in zip(axes, axes_shape): tmp_output_shape[i] = s output_shape = tuple(tmp_output_shape) else: output_shape = input_shape if axes is None or axes_shape is None: self.inv_axes_shape = None else: self.inv_axes_shape = [input_shape[i] for i in axes] self.axes = axes self.axes_shape = axes_shape self.norm = norm # To satisfy mypy -- DFT shapes must be tuples, not list of tuple # These get set inside of super().__init__ call, but we want to have # more restrictive type than the general LinearOperator self.input_shape: Shape self.output_shape: Shape super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=np.complex64, output_dtype=np.complex64, jit=jit, **kwargs, ) def _eval(self, x: snp.Array) -> snp.Array: return snp.fft.fftn(x, s=self.axes_shape, axes=self.axes, norm=self.norm) def inv(self, z: snp.Array) -> snp.Array: """Compute the inverse of this LinearOperator. Compute the inverse of this LinearOperator applied to `z`. Args: z: Input array to inverse DFT. """ return snp.fft.ifftn(z, s=self.inv_axes_shape, axes=self.axes, norm=self.norm) ================================================ FILE: scico/linop/_diag.py ================================================ # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Miscellaneous linear operator definitions.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Optional, Union import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import broadcast_nested_shapes, is_nested from scico.operator._operator import _wrap_mul_div_scalar from scico.typing import BlockShape, DType, Shape from ._linop import LinearOperator, _wrap_add_sub __all__ = ["Diagonal", "Identity", "ScaledIdentity"] class Diagonal(LinearOperator): """Diagonal linear operator.""" def __init__( self, diagonal: Union[Array, BlockArray], input_shape: Optional[Union[Shape, BlockShape]] = None, input_dtype: Optional[DType] = None, **kwargs, ): r""" Args: diagonal: Diagonal elements of this :class:`LinearOperator`. input_shape: Shape of input array. By default, equal to `diagonal.shape`, but may also be set to a shape that is broadcast-compatible with `diagonal.shape`. input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. 
""" self._diagonal = diagonal if input_shape is None: input_shape = self._diagonal.shape if input_dtype is None: input_dtype = self._diagonal.dtype if isinstance(diagonal, BlockArray) and is_nested(input_shape): output_shape = broadcast_nested_shapes(input_shape, self._diagonal.shape) elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self._diagonal.shape) elif isinstance(diagonal, BlockArray): raise ValueError("Argument 'diagonal' was a BlockArray but input_shape was not nested.") else: raise ValueError("Argument 'diagonal' was not a BlockArray but input_shape was nested.") super().__init__( input_shape=input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=input_dtype, **kwargs, ) def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return self._diagonal * x @property def diagonal(self) -> Union[Array, BlockArray]: """Return an array representing the diagonal component.""" return self._diagonal @property def T(self) -> Diagonal: """Transpose of this :class:`Diagonal`.""" return self def conj(self) -> Diagonal: """Complex conjugate of this :class:`Diagonal`.""" return Diagonal(diagonal=self.diagonal.conj()) @property def H(self) -> Diagonal: """Hermitian transpose of this :class:`Diagonal`.""" return self.conj() @property def gram_op(self) -> Diagonal: """Gram operator of this :class:`Diagonal`. Return a new :class:`Diagonal` :code:`G` such that :code:`G(x) = A.adj(A(x)))`. 
""" return Diagonal(diagonal=self.diagonal.conj() * self.diagonal) @_wrap_add_sub def __add__(self, other): if self.shape == other.shape: return Diagonal(diagonal=self.diagonal + other.diagonal) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_add_sub def __sub__(self, other): if self.shape == other.shape: return Diagonal(diagonal=self.diagonal - other.diagonal) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_mul_div_scalar def __mul__(self, scalar): return Diagonal(diagonal=self.diagonal * scalar) @_wrap_mul_div_scalar def __truediv__(self, scalar): return Diagonal(diagonal=self.diagonal / scalar) def __matmul__(self, other): # self @ other if isinstance(other, Diagonal): if self.shape == other.shape: return Diagonal(diagonal=self.diagonal * other.diagonal) raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") else: return self(other) def norm(self, ord=None): # pylint: disable=W0622 """Compute the matrix norm of the diagonal operator. Valid values of `ord` and the corresponding norm definition are those listed under "norm for matrices" in the :func:`scico.numpy.linalg.norm` documentation. """ ordfunc = { "fro": lambda x: snp.linalg.norm(x), "nuc": lambda x: snp.sum(snp.abs(x)), -snp.inf: lambda x: snp.abs(x).min(), snp.inf: lambda x: snp.abs(x).max(), } mord = ord if mord is None: mord = "fro" elif mord in (-1, -2): mord = -snp.inf elif mord in (1, 2): mord = snp.inf if mord not in ordfunc: raise ValueError(f"Invalid value {ord} for argument 'ord'.") return ordfunc[mord](self._diagonal) class ScaledIdentity(Diagonal): """Scaled identity operator.""" def __init__( self, scalar: float, input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, **kwargs, ): """ Args: scalar: Scaling of the identity. input_shape: Shape of input array. input_dtype: `dtype` of input argument. 
""" if is_nested(input_shape): diagonal = scalar * snp.ones(((),) * len(input_shape), dtype=input_dtype) else: diagonal = scalar * snp.ones((), dtype=input_dtype) super().__init__( diagonal=diagonal, input_shape=input_shape, input_dtype=input_dtype, **kwargs, ) @property def diagonal(self) -> Union[Array, BlockArray]: return self._diagonal def conj(self) -> ScaledIdentity: """Complex conjugate of this :class:`ScaledIdentity`.""" return ScaledIdentity( scalar=self._diagonal.conj(), input_shape=self.input_shape, input_dtype=self.input_dtype ) @property def gram_op(self) -> ScaledIdentity: """Gram operator of this :class:`ScaledIdentity`.""" return ScaledIdentity( scalar=self._diagonal * self._diagonal.conj(), input_shape=self.input_shape, input_dtype=self.input_dtype, ) @_wrap_add_sub def __add__(self, other): if self.input_shape == other.input_shape: return ScaledIdentity( scalar=self._diagonal + other._diagonal, input_shape=self.input_shape, input_dtype=self.input_dtype, ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_add_sub def __sub__(self, other): if self.input_shape == other.input_shape: return ScaledIdentity( scalar=self._diagonal - other._diagonal, input_shape=self.input_shape, input_dtype=self.input_dtype, ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") @_wrap_mul_div_scalar def __mul__(self, scalar): return ScaledIdentity( scalar=self._diagonal * scalar, input_shape=self.input_shape, input_dtype=self.input_dtype, ) @_wrap_mul_div_scalar def __truediv__(self, scalar): return ScaledIdentity( scalar=self._diagonal / scalar, input_shape=self.input_shape, input_dtype=self.input_dtype, ) def __matmul__(self, other): # self @ other if isinstance(other, Diagonal): if self.shape != other.shape: raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") if isinstance(other, ScaledIdentity): return ScaledIdentity( scalar=self._diagonal * other._diagonal, input_shape=self.input_shape, 
input_dtype=self.input_dtype, ) else: return Diagonal(diagonal=self._diagonal * other.diagonal) else: return self(other) def norm(self, ord=None): # pylint: disable=W0622 """Compute the matrix norm of the identity operator. Valid values of `ord` and the corresponding norm definition are those listed under "norm for matrices" in the :func:`scico.numpy.linalg.norm` documentation. """ N = self.input_size if ord is None or ord == "fro": return snp.abs(self._diagonal) * snp.sqrt(N) elif ord == "nuc": return snp.abs(self._diagonal) * N elif ord in (-snp.inf, -1, -2, 1, 2, snp.inf): return snp.abs(self._diagonal) else: raise ValueError(f"Invalid value {ord} for argument 'ord'.") class Identity(ScaledIdentity): """Identity operator.""" def __init__( self, input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, **kwargs ): """ Args: input_shape: Shape of input array. input_dtype: `dtype` of input argument. """ super().__init__( scalar=1.0, input_shape=input_shape, input_dtype=input_dtype, **kwargs, ) def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return x def conj(self) -> Identity: """Complex conjugate of this :class:`Identity`.""" return self @property def gram_op(self) -> Identity: """Gram operator of this :class:`Identity`.""" return self def __matmul__(self, other): return other def __rmatmul__(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return x ================================================ FILE: scico/linop/_diff.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Finite difference linear operator class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Literal, Optional, Union import numpy as np import scico.numpy as snp from scico.typing import Axes, DType, Shape from ._linop import LinearOperator from ._stack import VerticalStack, linop_over_axes class FiniteDifference(VerticalStack): """Finite difference operator. Compute finite differences along the specified axes, returning the results in a :class:`jax.Array` (when possible) or :class:`BlockArray`. See :class:`VerticalStack` for details on how this choice is made. See :class:`SingleAxisFiniteDifference` for the mathematical implications of the different boundary handling options `prepend`, `append`, and `circular`. Example ------- >>> A = FiniteDifference((2, 3)) >>> x = snp.array([[1, 2, 4], ... [0, 4, 1]]) >>> (A @ x)[0] Array([[-1, 2, -3]], dtype=int32) >>> (A @ x)[1] Array([[ 1, 2], [ 4, -3]], dtype=int32) """ def __init__( self, input_shape: Shape, input_dtype: DType = np.float32, axes: Optional[Axes] = None, prepend: Optional[Union[Literal[0], Literal[1]]] = None, append: Optional[Union[Literal[0], Literal[1]]] = None, circular: bool = False, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. axes: Axis or axes over which to apply finite difference operator. If not specified, or ``None``, differences are evaluated along all axes. prepend: Flag indicating handling of the left/top/etc. boundary. If ``None``, there is no boundary extension. Values of `0` or `1` indicate respectively that zeros or the initial value in the array are prepended to the difference array. append: Flag indicating handling of the right/bottom/etc. boundary. If ``None``, there is no boundary extension. 
Values of `0` or `1` indicate respectively that zeros or -1 times the final value in the array are appended to the difference array. circular: If ``True``, perform circular differences, i.e., include x[-1] - x[0]. If ``True``, `prepend` and `append` must both be ``None``. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ self.axes, ops = linop_over_axes( SingleAxisFiniteDifference, input_shape, axes=axes, input_dtype=input_dtype, prepend=prepend, append=append, circular=circular, jit=False, ) super().__init__( ops, # type: ignore jit=jit, **kwargs, ) class SingleAxisFiniteDifference(LinearOperator): r"""Finite difference operator acting along a single axis. By default (i.e. `prepend` and `append` set to ``None`` and `circular` set to ``False``), the difference operator corresponds to the matrix .. math:: \left(\begin{array}{rrrrr} -1 & 1 & 0 & \ldots & 0\\ 0 & -1 & 1 & \ldots & 0\\ \vdots & \vdots & \ddots & \ddots & \vdots\\ 0 & 0 & \ldots & -1 & 1 \end{array}\right) \;, mapping :math:`\mbb{R}^N \rightarrow \mbb{R}^{N-1}`, while if `circular` is ``True``, it corresponds to the :math:`\mbb{R}^N \rightarrow \mbb{R}^N` mapping .. math:: \left(\begin{array}{rrrrr} -1 & 1 & 0 & \ldots & 0\\ 0 & -1 & 1 & \ldots & 0\\ \vdots & \vdots & \ddots & \ddots & \vdots\\ 0 & 0 & \ldots & -1 & 1\\ 1 & 0 & \dots & 0 & -1 \end{array}\right) \;. Other possible choices include `prepend` set to ``None`` and `append` set to `0`, giving the :math:`\mbb{R}^N \rightarrow \mbb{R}^N` mapping .. math:: \left(\begin{array}{rrrrr} -1 & 1 & 0 & \ldots & 0\\ 0 & -1 & 1 & \ldots & 0\\ \vdots & \vdots & \ddots & \ddots & \vdots\\ 0 & 0 & \ldots & -1 & 1\\ 0 & 0 & \dots & 0 & 0 \end{array}\right) \;, and both `prepend` and `append` set to `1`, giving the :math:`\mbb{R}^N \rightarrow \mbb{R}^{N+1}` mapping .. 
math:: \left(\begin{array}{rrrrr} 1 & 0 & 0 & \ldots & 0\\ -1 & 1 & 0 & \ldots & 0\\ 0 & -1 & 1 & \ldots & 0\\ \vdots & \vdots & \ddots & \ddots & \vdots\\ 0 & 0 & \ldots & -1 & 1\\ 0 & 0 & \dots & 0 & -1 \end{array}\right) \;. """ def __init__( self, input_shape: Shape, input_dtype: DType = np.float32, axis: int = -1, prepend: Optional[Union[Literal[0], Literal[1]]] = None, append: Optional[Union[Literal[0], Literal[1]]] = None, circular: bool = False, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. axis: Axis over which to apply finite difference operator. prepend: Flag indicating handling of the left/top/etc. boundary. If ``None``, there is no boundary extension. Values of `0` or `1` indicate respectively that zeros or the initial value in the array are prepended to the difference array. append: Flag indicating handling of the right/bottom/etc. boundary. If ``None``, there is no boundary extension. Values of `0` or `1` indicate respectively that zeros or -1 times the final value in the array are appended to the difference array. circular: If ``True``, perform circular differences, i.e., include x[-1] - x[0]. If ``True``, `prepend` and `append` must both be ``None``. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ if not isinstance(axis, int): raise TypeError( f"Expected argument 'axis' to be of type int, got {type(axis)} instead." ) if axis < 0: axis = len(input_shape) + axis if axis >= len(input_shape): raise ValueError( f"Invalid axis {axis} specified; axis must be less than " f"len(input_shape)={len(input_shape)}." ) self.axis = axis if circular and (prepend is not None or append is not None): raise ValueError( "Argument 'circular' must be False if either prepend or append is not None." 
) if prepend not in [None, 0, 1]: raise ValueError("Argument 'prepend' may only take values None, 0, or 1.") if append not in [None, 0, 1]: raise ValueError("Argument 'append' may only take values None, 0, or 1.") self.prepend = prepend self.append = append self.circular = circular if self.circular: output_shape = input_shape else: output_shape = tuple( x + ((i == axis) * ((self.prepend is not None) + (self.append is not None) - 1)) for i, x in enumerate(input_shape) ) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, output_dtype=input_dtype, jit=jit, **kwargs, ) def _eval(self, x: snp.Array) -> snp.Array: prepend = None append = None if self.circular: # Append a copy of the initial value at the end of the array so that the difference # array includes the difference across the right/bottom/etc. boundary. ind = tuple( slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) append = x[ind] else: if self.prepend == 0: # Prepend a 0 to the difference array by prepending a copy of the initial value # before the difference is computed. ind = tuple( slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) prepend = x[ind] elif self.prepend == 1: # Prepend a copy of the initial value to the difference array by prepending a 0 # before the difference is computed. prepend = 0 if self.append == 0: # Append a 0 to the difference array by appending a copy of the initial value # before the difference is computed. ind = tuple( slice(-1, None) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) append = x[ind] elif self.append == 1: # Append a copy of the initial value to the difference array by appending a 0 # before the difference is computed. 
append = 0 return snp.diff(x, axis=self.axis, prepend=prepend, append=append) ================================================ FILE: scico/linop/_func.py ================================================ # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Linear operators constructed from functions.""" from typing import Any, Callable, Optional, Sequence, Union import jax import scico.numpy as snp from scico._core import linear_transpose from scico.numpy.util import indexed_shape, is_nested from scico.typing import ArrayIndex, BlockShape, DType, Shape from ._linop import LinearOperator __all__ = ["operator_from_function", "Tranpose", "Sum", "Crop", "Pad", "Reshape", "Slice"] def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = None): """Make a :class:`LinearOperator` from a function. Example ------- >>> Sum = linop_from_function(snp.sum, 'Sum') >>> H = Sum((2, 10), axis=1) >>> H @ snp.ones((2, 10)) Array([10., 10.], dtype=float32) Args: f: Function from which to create a :class:`LinearOperator`. classname: Name of the resulting class. f_name: Name of `f` for use in docstrings. Useful for getting the correct version of wrapped functions. Defaults to `f"{f.__module__}.{f.__name__}"`. """ if f_name is None: f_name = f"{f.__module__}.{f.__name__}" f_doc = rf""" Args: input_shape: Shape of input array. args: Positional arguments passed to :func:`{f_name}`. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. If the :class:`LinearOperator` implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. output_shape: Shape of output array. Defaults to ``None``. 
If ``None``, `output_shape` is determined by evaluating `self.__call__` on an input array of zeros. output_dtype: `dtype` for output argument. Defaults to ``None``. If ``None``, `output_dtype` is determined by evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`~.LinearOperator.jit` on this :class:`LinearOperator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`~.LinearOperator.jit` after the :class:`LinearOperator` is created. kwargs: Keyword arguments passed to :func:`{f_name}`. """ def __init__( self, input_shape: Union[Shape, BlockShape], *args: Any, input_dtype: DType = snp.float32, output_shape: Optional[Union[Shape, BlockShape]] = None, output_dtype: Optional[DType] = None, jit: bool = True, **kwargs: Any, ): self._eval = lambda x: f(x, *args, **kwargs) self.kwargs = kwargs super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore OpClass = type(classname, (LinearOperator,), {"__init__": __init__}) __class__ = OpClass # needed for super() to work OpClass.__doc__ = f"Linear operator version of :func:`{f_name}`." OpClass.__init__.__doc__ = f_doc # type: ignore return OpClass Transpose = linop_from_function(snp.transpose, "Transpose", "scico.numpy.transpose") Reshape = linop_from_function(snp.reshape, "Reshape") Pad = linop_from_function(snp.pad, "Pad", "scico.numpy.pad") Sum = linop_from_function(snp.sum, "Sum") class Crop(LinearOperator): """A linear operator for cropping an array.""" def __init__( self, crop_width: Union[int, Sequence], input_shape: Shape, input_dtype: DType = snp.float32, jit: bool = True, **kwargs, ): r""" Args: crop_width: Specify the crop width using the same format as the `pad_width` parameter of :func:`snp.pad`. input_shape: Shape of input :class:`jax.Array`. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. 
jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ self.crop_width = crop_width # The crop function is defined as the adjoint of snp.pad pad = lambda x: snp.pad(x, pad_width=crop_width) # The output shape of this operator is the input shape of the corresponding # pad operation of which it is the adjoint. Since we don't know this output # shape, we assume that it can be computed by subtracting the difference in # output and input shapes resulting from applying the pad operator to the # input shape of this operator. pad_shape = jax.eval_shape(pad, jax.ShapeDtypeStruct(input_shape, dtype=input_dtype)).shape output_shape = tuple((2 * snp.array(input_shape) - snp.array(pad_shape)).tolist()) pad_adjoint = linear_transpose(pad, jax.ShapeDtypeStruct(output_shape, dtype=input_dtype)) super().__init__( input_shape=input_shape, input_dtype=input_dtype, eval_fn=lambda x: pad_adjoint(x)[0], output_shape=output_shape, output_dtype=input_dtype, jit=jit, **kwargs, ) class Slice(LinearOperator): """A linear operator for slicing an array.""" def __init__( self, idx: ArrayIndex, input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, jit: bool = True, **kwargs, ): r""" This operator may be applied to either a :class:`jax.Array` or a :class:`.BlockArray`. In the latter case, parameter `idx` must conform to the :ref:`BlockArray indexing requirements `. Args: idx: An array indexing expression, as generated by :data:`numpy.s_`, for example. input_shape: Shape of input :class:`jax.Array` or :class:`.BlockArray`. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. 
""" output_shape: Union[Shape, BlockShape] if is_nested(input_shape): output_shape = input_shape[idx] # type: ignore else: output_shape = indexed_shape(input_shape, idx) # type: ignore self.idx: ArrayIndex = idx super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, output_dtype=input_dtype, jit=jit, **kwargs, ) def _eval(self, x: snp.Array) -> snp.Array: return x[self.idx] ================================================ FILE: scico/linop/_grad.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Non-Cartesian gradient linear operators.""" # Needed to annotate a class method that returns the encapsulating class # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Optional, Sequence, Tuple, Union import numpy as np import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.typing import DType, Shape from ._linop import LinearOperator def diffstack(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """Compute the discrete difference along multiple axes. Apply :func:`snp.diff` along multiple axes, stacking the results on a newly inserted axis at index 0. The `append` parameter of :func:`snp.diff` is exploited to give output of the same length as the input, which is achieved by zero-padding the output at the end of each axis. """ if axis is None: axis = tuple(range(x.ndim)) elif isinstance(axis, int): axis = (axis,) dstack = [ snp.diff( x, axis=ax, append=x[tuple(slice(-1, None) if i == ax else slice(None) for i in range(x.ndim))], ) for ax in axis ] return snp.stack(dstack) class ProjectedGradient(LinearOperator): """Gradient projected onto local coordinate system. 
This class represents a linear operator that computes gradients of arrays projected onto a local coordinate system that may differ at every position in the array, as described in :cite:`hossein-2024-total`. In the 2D illustration below :math:`x` and :math:`y` represent the standard coordinate system defined by the array axes, :math:`(g_x, g_y)` is the gradient vector within that coordinate system, :math:`x'` and :math:`y'` are the local coordinate axes, and :math:`(g_x', g_y')` is the gradient vector within the local coordinate system. .. image:: /figures/projgrad.svg :align: center :alt: Figure illustrating projection of gradient onto local coordinate system. Each of the local coordinate axes (e.g. :math:`x'` and :math:`y'` in the illustration above) is represented by a separate array in the `coord` tuple of arrays parameter of the class initializer. .. note:: This operator should not be confused with the Projected Gradient optimization algorithm (a special case of Proximal Gradient), with which it is unrelated. """ def __init__( self, input_shape: Shape, axes: Optional[Tuple[int, ...]] = None, coord: Optional[Sequence[Union[Array, BlockArray]]] = None, cdiff: bool = False, input_dtype: DType = np.float32, jit: bool = True, ): r""" The result of applying the operator is always a :class:`jax.Array`. If `coord` is a singleton tuple, it has the same shape as the input array. Otherwise, the gradients for each of the local coordinate axes are stacked on an additional axis at index 0. If `coord` is ``None``, which is the default, gradients are computed in the standard axis-aligned coordinate system, and the shape of the returned array depends on the number of axes on which the gradient is calculated, as specified explicitly or implicitly via the `axes` parameter. Args: input_shape: Shape of input array. axes: Axes over which to compute the gradient. Defaults to ``None``, in which case the gradient is computed along all axes. 
coord: A tuple of arrays, each of which specifies a local coordinate axis direction. Each member of the tuple should either be a :class:`jax.Array` or a :class:`.BlockArray`. If it is the former, it should have shape :math:`N \times M_0 \times M_1 \times \ldots`, where :math:`N` is the number of axes specified by parameter `axes`, and :math:`M_i` is the size of the :math:`i^{\mrm{th}}` axis. If it is the latter, it should consist of :math:`N` blocks, each of which has a shape that is suitable for multiplication with an array of shape :math:`M_0 \times M_1 \times \ldots`. cdiff: If ``True``, estimate gradients using the second order central different returned by :func:`snp.gradient`, otherwise use the first order asymmetric difference returned by :func:`snp.diff`. input_dtype: `dtype` for input argument. Default is :attr:`~numpy.float32`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if axes is None: # If axes is None, set it to all axes in input shape. self.axes = tuple(range(len(input_shape))) else: # Ensure no invalid axis indices specified. if snp.any(np.array(axes) >= len(input_shape)): raise ValueError( "Invalid axes specified; all elements of argument 'axes' must " f"be less than len(input_shape)={len(input_shape)}." ) self.axes = axes output_shape: Shape if coord is None: # If coord is None, output shape is determined by number of axes. if len(self.axes) == 1: output_shape = input_shape else: output_shape = (len(self.axes),) + input_shape else: # If coord is not None, output shape is determined by number of coord arrays. 
if len(coord) == 1: output_shape = input_shape else: output_shape = (len(coord),) + input_shape self.coord = coord self.cdiff = cdiff super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, output_dtype=input_dtype, jit=jit, ) def _eval(self, x: Array) -> Union[Array, BlockArray]: if self.cdiff: grad = snp.stack(snp.gradient(x, axis=self.axes)) else: grad = diffstack(x, axis=self.axes) if self.coord is None: # If coord attribute is None, just return gradients on specified axes. if len(self.axes) == 1: return grad[0] else: return grad else: # If coord attribute is not None, return gradients projected onto specified local # coordinate systems. projgrad = [sum([c[m] * grad[m] for m in range(len(self.axes))]) for c in self.coord] if len(self.coord) == 1: return projgrad[0] else: return snp.stack(projgrad) class PolarGradient(ProjectedGradient): """Gradient projected into polar coordinates. Compute gradients projected onto angular and/or radial axis directions, as described in :cite:`hossein-2024-total`. Local coordinate axes are illustrated in the figure below. .. plot:: pyfigures/polargrad.py :align: center :include-source: False :show-source-link: False | If only one of `angular` and `radial` is ``True``, the operator output has the same shape as the input, otherwise the gradients for the two local coordinate axes are stacked on an additional axis at index 0. """ def __init__( self, input_shape: Shape, axes: Optional[Tuple[int, ...]] = None, center: Optional[Union[Tuple[int, ...], Array]] = None, angular: bool = True, radial: bool = True, cdiff: bool = False, input_dtype: DType = np.float32, jit: bool = True, ): r""" Args: input_shape: Shape of input array. axes: Axes over which to compute the gradient. Should be a tuple :math:`(i_x, i_y)`, where :math:`i_x` and :math:`i_y` are input array axes assigned to :math:`x` and :math:`y` coordinates respectively. Defaults to ``None``, in which case the axes are taken to be `(0, 1)`. 
center: Center of the polar coordinate system in array indexing coordinates. Default is ``None``, which places the center at the center of the input array. angular: Flag indicating whether to compute gradients in the angular (i.e. tangent to circles) direction. radial: Flag indicating whether to compute gradients in the radial (i.e. directed outwards from the origin) direction. cdiff: If ``True``, estimate gradients using the second order central different returned by :func:`snp.gradient`, otherwise use the first order asymmetric difference returned by :func:`snp.diff`. input_dtype: `dtype` for input argument. Default is :attr:`~numpy.float32`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if len(input_shape) < 2: raise ValueError("Invalid input shape; input must have at least two axes.") if axes is not None and len(axes) != 2: raise ValueError("Invalid axes specified; exactly two axes must be specified.") if not angular and not radial: raise ValueError("At least one of angular and radial must be True.") real_input_dtype = snp.util.real_dtype(input_dtype) if axes is None: axes = (0, 1) axes_shape = [input_shape[ax] for ax in axes] if center is None: center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 else: center = snp.array(center, dtype=real_input_dtype) end = snp.array(axes_shape, dtype=real_input_dtype) - center g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] theta = snp.arctan2(g0, g1) # Re-order theta axes in case indices in axes parameter are not in increasing order. axis_order = np.argsort(axes) theta = snp.transpose(theta, axis_order) if len(input_shape) > 2: # Construct list of input axes that are not included in the gradient axes. single = tuple(set(range(len(input_shape))) - set(axes)) # Insert singleton axes to align theta for multiplication with gradients. 
theta = snp.expand_dims(theta, single) coord = [] if angular: coord.append(snp.blockarray([-snp.cos(theta), snp.sin(theta)])) if radial: coord.append(snp.blockarray([snp.sin(theta), snp.cos(theta)])) super().__init__( input_shape=input_shape, input_dtype=input_dtype, axes=axes, coord=coord, cdiff=cdiff, jit=jit, ) class CylindricalGradient(ProjectedGradient): """Gradient projected into cylindrical coordinates. Compute gradients projected onto cylindrical coordinate axes, as described in :cite:`hossein-2024-total`. The local coordinate axes are illustrated in the figure below. .. plot:: pyfigures/cylindgrad.py :align: center :include-source: False :show-source-link: False | If only one of `angular`, `radial`, and `axial` is ``True``, the operator output has the same shape as the input, otherwise the gradients for the selected local coordinate axes are stacked on an additional axis at index 0. """ def __init__( self, input_shape: Shape, axes: Optional[Tuple[int, ...]] = None, center: Optional[Union[Tuple[int, ...], Array]] = None, angular: bool = True, radial: bool = True, axial: bool = True, cdiff: bool = False, input_dtype: DType = np.float32, jit: bool = True, ): r""" Args: input_shape: Shape of input array. axes: Axes over which to compute the gradient. Should be a tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, :math:`i_y` and :math:`i_z` are input array axes assigned to :math:`x`, :math:`y`, and :math:`z` coordinates respectively. Defaults to ``None``, in which case the axes are taken to be `(0, 1, 2)`. If an integer, this operator returns a :class:`jax.Array`. If a tuple or ``None``, the resulting arrays are stacked into a :class:`.BlockArray`. center: Center of the cylindrical coordinate system in array indexing coordinates. Default is ``None``, which places the center at the center of the two polar axes of the input array and at the zero index of the axial axis. angular: Flag indicating whether to compute gradients in the angular (i.e. 
tangent to circles) direction. radial: Flag indicating whether to compute gradients in the radial (i.e. directed outwards from the origin) direction. axial: Flag indicating whether to compute gradients in the direction of the axis of the cylinder. cdiff: If ``True``, estimate gradients using the second order central different returned by :func:`snp.gradient`, otherwise use the first order asymmetric difference returned by :func:`snp.diff`. input_dtype: `dtype` for input argument. Default is :attr:`~numpy.float32`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if len(input_shape) < 3: raise ValueError("Invalid input shape; input must have at least three axes.") if axes is not None and len(axes) != 3: raise ValueError("Invalid axes specified; exactly three axes must be specified.") if not angular and not radial and not axial: raise ValueError("At least one of angular, radial, and axial must be True.") real_input_dtype = snp.util.real_dtype(input_dtype) if axes is None: axes = (0, 1, 2) axes_shape = [input_shape[ax] for ax in axes] if center is None: center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 center = center.at[-1].set(0) # type: ignore else: center = snp.array(center, dtype=real_input_dtype) end = snp.array(axes_shape, dtype=real_input_dtype) - center g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] g0 = g0[..., np.newaxis] g1 = g1[..., np.newaxis] theta = snp.arctan2(g0, g1) # Re-order theta axes in case indices in axes parameter are not in increasing order. axis_order = np.argsort(axes) theta = snp.transpose(theta, axis_order) if len(input_shape) > 3: # Construct list of input axes that are not included in the gradient axes. single = tuple(set(range(len(input_shape))) - set(axes)) # Insert singleton axes to align theta for multiplication with gradients. 
theta = snp.expand_dims(theta, single) coord = [] if angular: coord.append( snp.blockarray( [-snp.cos(theta), snp.sin(theta), snp.array([0.0], dtype=real_input_dtype)] ) ) if radial: coord.append( snp.blockarray( [snp.sin(theta), snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] ) ) if axial: coord.append( snp.blockarray( [ snp.array([0.0], dtype=real_input_dtype), snp.array([0.0], dtype=real_input_dtype), snp.array([1.0], dtype=real_input_dtype), ] ) ) super().__init__( input_shape=input_shape, input_dtype=input_dtype, axes=axes, cdiff=cdiff, coord=coord, jit=jit, ) class SphericalGradient(ProjectedGradient): """Gradient projected into spherical coordinates. Compute gradients projected onto spherical coordinate axes, based on the approach described in :cite:`hossein-2024-total`. The local coordinate axes are illustrated in the figure below. .. plot:: pyfigures/spheregrad.py :align: center :include-source: False :show-source-link: False | If only one of `azimuthal`, `polar`, and `radial` is ``True``, the operator output has the same shape as the input, otherwise the gradients for the selected local coordinate axes are stacked on an additional axis at index 0. """ def __init__( self, input_shape: Shape, axes: Optional[Tuple[int, ...]] = None, center: Optional[Union[Tuple[int, ...], Array]] = None, azimuthal: bool = True, polar: bool = True, radial: bool = True, cdiff: bool = False, input_dtype: DType = np.float32, jit: bool = True, ): r""" Args: input_shape: Shape of input array. axes: Axes over which to compute the gradient. Should be a tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, :math:`i_y` and :math:`i_z` are input array axes assigned to :math:`x`, :math:`y`, and :math:`z` coordinates respectively. Defaults to ``None``, in which case the axes are taken to be `(0, 1, 2)`. If an integer, this operator returns a :class:`jax.Array`. If a tuple or ``None``, the resulting arrays are stacked into a :class:`.BlockArray`. 
center: Center of the spherical coordinate system in array indexing coordinates. Default is ``None``, which places the center at the center of the input array. azimuthal: Flag indicating whether to compute gradients in the azimuthal direction. polar: Flag indicating whether to compute gradients in the polar direction. radial: Flag indicating whether to compute gradients in the radial direction. cdiff: If ``True``, estimate gradients using the second order central different returned by :func:`snp.gradient`, otherwise use the first order asymmetric difference returned by :func:`snp.diff`. input_dtype: `dtype` for input argument. Default is :attr:`~numpy.float32`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if len(input_shape) < 3: raise ValueError("Invalid input shape; input must have at least three axes.") if axes is not None and len(axes) != 3: raise ValueError("Invalid axes specified; exactly three axes must be specified.") if not azimuthal and not polar and not radial: raise ValueError("At least one of azimuthal, polar, and radial must be True.") real_input_dtype = snp.util.real_dtype(input_dtype) if axes is None: axes = (0, 1, 2) axes_shape = [input_shape[ax] for ax in axes] if center is None: center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 else: center = snp.array(center, dtype=real_input_dtype) end = snp.array(axes_shape, dtype=real_input_dtype) - center g0, g1, g2 = snp.ogrid[-center[0] : end[0], -center[1] : end[1], -center[2] : end[2]] theta = snp.arctan2(g1, g0) phi = snp.arctan2(snp.sqrt(g0**2 + g1**2), g2) # Re-order theta and phi axes in case indices in axes parameter are not in # increasing order. axis_order = np.argsort(axes) theta = snp.transpose(theta, axis_order) phi = snp.transpose(phi, axis_order) if len(input_shape) > 3: # Construct list of input axes that are not included in the gradient axes. 
single = tuple(set(range(len(input_shape))) - set(axes)) # Insert singleton axes to align theta for multiplication with gradients. theta = snp.expand_dims(theta, single) phi = snp.expand_dims(phi, single) coord = [] if azimuthal: coord.append( snp.blockarray( [snp.sin(theta), -snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] ) ) if polar: coord.append( snp.blockarray( [snp.cos(phi) * snp.cos(theta), snp.cos(phi) * snp.sin(theta), -snp.sin(phi)] ) ) if radial: coord.append( snp.blockarray( [snp.sin(phi) * snp.cos(theta), snp.sin(phi) * snp.sin(theta), snp.cos(phi)] ) ) super().__init__( input_shape=input_shape, input_dtype=input_dtype, axes=axes, coord=coord, cdiff=cdiff, jit=jit, ) ================================================ FILE: scico/linop/_linop.py ================================================ # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Linear operator base class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from functools import wraps from typing import Callable, Optional, Union import numpy as np import jax import jax.numpy as jnp from jax.dtypes import result_type import scico.numpy as snp from scico._core import linear_adjoint from scico.numpy import Array, BlockArray from scico.numpy.util import is_complex_dtype from scico.operator._operator import Operator, _wrap_mul_div_scalar from scico.typing import BlockShape, DType, Shape def _wrap_add_sub(func: Callable) -> Callable: r"""Wrapper function for defining `__add__` and `__sub__`. Wrapper function for defining `__add__` and ` __sub__` between :class:`LinearOperator` and derived classes. Operations between :class:`LinearOperator` and :class:`.Operator` types are also supported. 
Handles shape checking and function dispatch based on types of operands `a` and `b` in the call `func(a, b)`. Note that `func` will always be a method of the type of `a`, and since this wrapper should only be applied within :class:`LinearOperator` or derived classes, we can assume that `a` is always an instance of :class:`LinearOperator`. The general rule for dispatch is that the `__add__` or `__sub__` operator of the nearest common base class of `a` and `b` should be called. If `b` is derived from `a`, this entails using the operator defined in the class of `a`, and vice-versa. If one of the operands is not a descendant of the other in the class hierarchy, then it is assumed that their common base class is either :class:`.Operator` or :class:`LinearOperator`, depending on the type of `b`. - If `b` is not an instance of :class:`.Operator`, a :exc:`TypeError` is raised. - If the shapes of `a` and `b` do not match, a :exc:`ValueError` is raised. - If `b` is an instance of the type of `a` then `func(a, b)` is called where `func` is the argument of this wrapper, i.e. the unwrapped function defined in the class of `a`. - If `a` is an instance of the type of `b` then `func(a, b)` is called where `func` is the unwrapped function defined in the class of `b`. - If `b` is a :class:`LinearOperator` then `func(a, b)` is called where `func` is the operator defined in :class:`LinearOperator`. - Othwerwise, `func(a, b)` is called where `func` is the operator defined in :class:`.Operator`. Args: func: should be either `.__add__` or `.__sub__`. Returns: Wrapped version of `func`. Raises: ValueError: If the shapes of two operators do not match. TypeError: If one of the two operands is not an :class:`.Operator` or :class:`LinearOperator`. 
""" @wraps(func) def wrapper( a: LinearOperator, b: Union[Operator, LinearOperator] ) -> Union[Operator, LinearOperator]: if isinstance(b, Operator): if a.shape == b.shape: if isinstance(b, type(a)): # b is an instance of the class of a: call the unwrapped operator # defined in the class of a, which is the func argument of this # wrapper return func(a, b) if isinstance(a, type(b)): # a is an instance of class b: call the unwrapped operator # defined in the class of b. A test is required because # the operators defined in Operator and non-LinearOperator # derived classes are not wrapped. if hasattr(getattr(type(b), func.__name__), "_unwrapped"): uwfunc = getattr(type(b), func.__name__)._unwrapped else: uwfunc = getattr(type(b), func.__name__) return uwfunc(a, b) # The most general approach here would be to automatically determine # the nearest common ancestor of the classes of a and b (e.g. as # discussed in https://stackoverflow.com/a/58290475 ), but the # simpler approach adopted here is to just assume that the common # base of two classes that do not have an ancestor-descendant # relationship is either Operator or LinearOperator. 
if isinstance(b, LinearOperator): # LinearOperator + LinearOperator -> LinearOperator uwfunc = getattr(LinearOperator, func.__name__)._unwrapped return uwfunc(a, b) # LinearOperator + Operator -> Operator (access to the function # definition differs from that for LinearOperator because # Operator __add__ and __sub__ are not wrapped) uwfunc = getattr(Operator, func.__name__) return uwfunc(a, b) raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") wrapper._unwrapped = func # type: ignore return wrapper class LinearOperator(Operator): """Generic linear operator base class""" def __init__( self, input_shape: Union[Shape, BlockShape], output_shape: Optional[Union[Shape, BlockShape]] = None, eval_fn: Optional[Callable] = None, adj_fn: Optional[Callable] = None, input_dtype: DType = np.float32, output_dtype: Optional[DType] = None, jit: bool = False, ): r""" Args: input_shape: Shape of input array. output_shape: Shape of output array. Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating `self.__call__` on an input array of zeros. eval_fn: Function used in evaluating this :class:`LinearOperator`. Defaults to ``None``. If ``None``, then `self.__call__` must be defined in any derived classes. adj_fn: Function used to evaluate the adjoint of this :class:`LinearOperator`. Defaults to ``None``. If ``None``, the adjoint is not set, and the :meth:`._set_adjoint` will be called silently at the first :meth:`.adj` call or can be called manually. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. If the :class:`.LinearOperator` implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. output_dtype: `dtype` for output argument. Defaults to ``None``. 
If ``None``, `output_dtype` is determined by evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`.jit()` on this :class:`LinearOperator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`.jit` after the :class:`LinearOperator` is created. """ super().__init__( input_shape=input_shape, output_shape=output_shape, eval_fn=eval_fn, input_dtype=input_dtype, output_dtype=output_dtype, jit=False, ) if not hasattr(self, "_adj"): self._adj: Optional[Callable] = None if not hasattr(self, "_gram"): self._gram: Optional[Callable] = None if callable(adj_fn): self._adj = adj_fn self._gram = lambda x: self.adj(self(x)) elif adj_fn is not None: raise TypeError(f"Argument 'adj_fn' must be either a Callable or None; got {adj_fn}.") if jit: self.jit() def _set_adjoint(self): """Automatically create adjoint method.""" adj_fun = linear_adjoint( self._eval, jax.ShapeDtypeStruct(self.input_shape, dtype=self.input_dtype) ) self._adj = lambda x: adj_fun(x)[0] def _set_gram(self): """Automatically create gram method.""" self._gram = lambda x: self.adj(self(x)) def jit(self): """Replace the private functions :meth:`._eval`, :meth:`_adj`, :meth:`._gram` with jitted versions. 
""" if self._adj is None: self._set_adjoint() if self._gram is None: self._set_gram() self._eval = jax.jit(self._eval) self._adj = jax.jit(self._adj) self._gram = jax.jit(self._gram) @_wrap_add_sub def __add__(self, other): return LinearOperator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) + other(x), adj_fn=lambda x: self.adj(x) + other.adj(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other.output_dtype), ) @_wrap_add_sub def __sub__(self, other): return LinearOperator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) - other(x), adj_fn=lambda x: self.adj(x) - other.adj(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other.output_dtype), ) @_wrap_mul_div_scalar def __mul__(self, other): return LinearOperator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: other * self(x), adj_fn=lambda x: snp.conj(other) * self.adj(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other), ) @_wrap_mul_div_scalar def __rmul__(self, other): return self.__mul__(other) # scalar multiplication is commutative @_wrap_mul_div_scalar def __truediv__(self, other): return LinearOperator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) / other, adj_fn=lambda x: self.adj(x) / snp.conj(other), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other), ) def __matmul__(self, other): # self @ other return self(other) def __rmatmul__(self, other): # other @ self if isinstance(other, LinearOperator): return other(self) if isinstance(other, (np.ndarray, jnp.ndarray)): # for real valued inputs: y @ self == (self.T @ y.T).T # for complex: y @ self == (self.conj().T @ y.conj().T).conj().T # self.conj().T == self.adj return self.adj(other.conj().T).conj().T raise NotImplementedError( f"Operation __rmatmul__ not defined between 
{type(self)} and {type(other)}." ) def __call__( self, x: Union[LinearOperator, Array, BlockArray] ) -> Union[LinearOperator, Array, BlockArray]: r"""Evaluate this :class:`LinearOperator` at the point :math:`\mb{x}`. Args: x: Point at which to evaluate this :class:`LinearOperator`. If `x` is a :class:`jax.Array` or :class:`.BlockArray`, must have `shape == self.input_shape`. If `x` is a :class:`LinearOperator`, must have `x.output_shape == self.input_shape`. """ if isinstance(x, LinearOperator): return ComposedLinearOperator(self, x) # Use Operator __call__ for LinearOperator @ array or LinearOperator @ Operator return super().__call__(x) def adj( self, y: Union[LinearOperator, Array, BlockArray] ) -> Union[LinearOperator, Array, BlockArray]: """Adjoint of this :class:`LinearOperator`. Compute the adjoint of this :class:`LinearOperator` applied to input `y`. Args: y: Point at which to compute adjoint. If `y` is :class:`jax.Array` or :class:`.BlockArray`, must have `shape == self.output_shape`. If `y` is a :class:`LinearOperator`, must have `y.output_shape == self.output_shape`. Returns: Adjoint evaluated at `y`. """ if self._adj is None: self._set_adjoint() if isinstance(y, LinearOperator): return ComposedLinearOperator(self.H, y) if self.output_dtype != y.dtype: raise ValueError(f"Dtype error: expected {self.output_dtype}, got {y.dtype}.") if self.output_shape != y.shape: raise ValueError( f"""Shapes do not conform: input array with shape {y.shape} does not match LinearOperator output_shape {self.output_shape}.""" ) assert self._adj is not None return self._adj(y) @property def T(self) -> LinearOperator: """Transpose of this :class:`LinearOperator`. Return a new :class:`LinearOperator` that implements the transpose of this :class:`LinearOperator`. For a real-valued :class:`LinearOperator` `A` (`A.input_dtype` is :attr:`~numpy.float32` or :attr:`~numpy.float64`), the :class:`LinearOperator` `A.T` implements the adjoint: `A.T(y) == A.adj(y)`. 
For a complex-valued :class:`LinearOperator` `A` (`A.input_dtype` is :attr:`~numpy.complex64` or :attr:`~numpy.complex128`), the :class:`LinearOperator` `A.T` is not the adjoint. For the conjugate transpose, use `.conj().T` or :meth:`.H`. """ if is_complex_dtype(self.input_dtype): return LinearOperator( input_shape=self.output_shape, output_shape=self.input_shape, eval_fn=lambda x: self.adj(x.conj()).conj(), adj_fn=self.__call__, input_dtype=self.input_dtype, output_dtype=self.output_dtype, ) return LinearOperator( input_shape=self.output_shape, output_shape=self.input_shape, eval_fn=self.adj, adj_fn=self.__call__, input_dtype=self.output_dtype, output_dtype=self.input_dtype, ) @property def H(self) -> LinearOperator: """Hermitian transpose of this :class:`LinearOperator`. Return a new :class:`LinearOperator` that is the Hermitian transpose of this :class:`LinearOperator`. For a real-valued :class:`LinearOperator` `A` (`A.input_dtype` is :attr:`~numpy.float32` or :attr:`~numpy.float64`), the :class:`LinearOperator` `A.H` is equivalent to `A.T`. For a complex-valued :class:`LinearOperator` `A` (`A.input_dtype` is :attr:`~numpy.complex64` or :attr:`~numpy.complex128`), the :class:`LinearOperator` `A.H` implements the adjoint of `A : A.H @ y == A.adj(y) == A.conj().T @ y)`. For the non-conjugate transpose, see :meth:`.T`. """ return LinearOperator( input_shape=self.output_shape, output_shape=self.input_shape, eval_fn=self.adj, adj_fn=self.__call__, input_dtype=self.output_dtype, output_dtype=self.input_dtype, ) def conj(self) -> LinearOperator: """Complex conjugate of this :class:`LinearOperator`. Return a new :class:`LinearOperator` `Ac` such that `Ac(x) = conj(A)(x)`. 
""" # A.conj() x == (A @ x.conj()).conj() return LinearOperator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x.conj()).conj(), adj_fn=lambda x: self.adj(x.conj()).conj(), input_dtype=self.input_dtype, output_dtype=self.output_dtype, ) @property def gram_op(self) -> LinearOperator: """Gram operator of this :class:`LinearOperator`. Return a new :class:`LinearOperator` `G` such that `G(x) = A.adj(A(x)))`. """ if self._gram is None: self._set_gram() return LinearOperator( input_shape=self.input_shape, output_shape=self.input_shape, eval_fn=self.gram, adj_fn=self.gram, input_dtype=self.input_dtype, output_dtype=self.output_dtype, ) def gram( self, x: Union[LinearOperator, Array, BlockArray] ) -> Union[LinearOperator, Array, BlockArray]: """Compute `A.adj(A(x)).` Args: x: Point at which to evaluate the gram operator. If `x` is a :class:`jax.Array` or :class:`.BlockArray`, must have `shape == self.input_shape`. If `x` is a :class:`LinearOperator`, must have `x.output_shape == self.input_shape`. Returns: Result of `A.adj(A(x))`. """ if self._gram is None: self._set_gram() assert self._gram is not None return self._gram(x) class ComposedLinearOperator(LinearOperator): """A composition of two :class:`LinearOperator` objects. A new :class:`LinearOperator` formed by the composition of two other :class:`LinearOperator` objects. """ def __init__(self, A: LinearOperator, B: LinearOperator, jit: bool = False): r""" A :class:`ComposedLinearOperator` `AB` implements `AB @ x == A @ B @ x`. :class:`LinearOperator` `A` and `B` are stored as attributes of the :class:`ComposedLinearOperator`. :class:`LinearOperator` `A` and `B` must have compatible shapes and dtypes: `A.input_shape == B.output_shape` and `A.input_dtype == B.input_dtype`. Args: A: First (left) :class:`LinearOperator`. B: Second (right) :class:`LinearOperator`. 
jit: If ``True``, call :meth:`~.LinearOperator.jit()` on this :class:`LinearOperator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`~.LinearOperator.jit` after the :class:`LinearOperator` is created. """ if not isinstance(A, LinearOperator): raise TypeError( "The first argument to ComposedLinearOperator must be a LinearOperator; " f"got {type(A)}." ) if not isinstance(B, LinearOperator): raise TypeError( "The second argument to ComposedLinearOperator must be a LinearOperator; " f"got {type(B)}." ) if A.input_shape != B.output_shape: raise ValueError(f"Incompatable LinearOperator shapes {A.shape}, {B.shape}.") if A.input_dtype != B.output_dtype: raise ValueError( f"Incompatable LinearOperator dtypes {A.input_dtype}, {B.output_dtype}." ) self.A = A self.B = B super().__init__( input_shape=self.B.input_shape, output_shape=self.A.output_shape, input_dtype=self.B.input_dtype, output_dtype=self.A.output_dtype, eval_fn=lambda x: self.A(self.B(x)), adj_fn=lambda z: self.B.adj(self.A.adj(z)), jit=jit, ) ================================================ FILE: scico/linop/_matrix.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Matrix linear operator classes.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations import operator from functools import partial, wraps import numpy as np import jax.numpy as jnp from jax.typing import ArrayLike import scico.numpy as snp from scico.operator._operator import Operator from ._diag import Identity from ._linop import LinearOperator def _wrap_add_sub_matrix(func, op): @wraps(func) def wrapper(a, b): if np.isscalar(b): return MatrixOperator(op(a.A, b)) if isinstance(b, MatrixOperator): if a.shape == b.shape: return MatrixOperator(op(a.A, b.A)) raise ValueError(f"MatrixOperator shapes {a.shape} and {b.shape} do not match.") if isinstance(b, (jnp.ndarray, np.ndarray)): if a.matrix_shape == b.shape: return MatrixOperator(op(a.A, b)) raise ValueError(f"Shapes {a.matrix_shape} and {b.shape} do not match.") if isinstance(b, Operator): if a.shape != b.shape: raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") if isinstance(b, LinearOperator): uwfunc = getattr(LinearOperator, func.__name__)._unwrapped return uwfunc(a, b) if isinstance(b, Operator): uwfunc = getattr(Operator, func.__name__) return uwfunc(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") return wrapper class MatrixOperator(LinearOperator): """Linear operator implementing matrix multiplication.""" def __init__(self, A: ArrayLike, input_cols: int = 0): """ Args: A: Dense array. The action of the created :class:`.LinearOperator` will implement matrix multiplication with `A`. input_cols: If this parameter is set to the default of 0, the :class:`MatrixOperator` takes a vector (one-dimensional array) input. If the input is intended to be a matrix (two-dimensional array), this parameter should specify number of columns in the matrix. """ self.A: snp.Array #: Dense array implementing this matrix # Ensure that A is a numpy or jax array. 
if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") self.A = A # Can only do rank-2 arrays if A.ndim != 2: raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.") self.__array__ = A.__array__ # enables jnp.array(H) if input_cols == 0: input_shape = A.shape[1] output_shape = A.shape[0] else: input_shape = (A.shape[1], input_cols) output_shape = (A.shape[0], input_cols) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=self.A.dtype ) def __call__(self, other): if isinstance(other, LinearOperator): if self.input_shape == other.output_shape: if isinstance(other, Identity): return self if isinstance(other, MatrixOperator): return MatrixOperator(A=self.A @ other.A) # must be a generic linop so return composition of the two return LinearOperator( input_shape=other.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(other(x)), input_dtype=self.input_dtype, ) raise ValueError( "Cannot compute MatrixOperator-LinearOperator product, " f"{other.output_shape} does not match {self.input_shape}." ) return self._eval(other) def _eval(self, other): return self.A @ other def gram(self, other): return self.A.conj().T @ self.A @ other @partial(_wrap_add_sub_matrix, op=operator.add) def __add__(self, other): pass @partial(_wrap_add_sub_matrix, op=operator.sub) def __sub__(self, other): pass def __radd__(self, other): # Addition is commutative return self + other def __rsub__(self, other): return -self + other def __neg__(self): return MatrixOperator(-self.A) # Could write another wrapper for mul, truediv, and rtuediv, but there is # no operator.__rtruediv__; have to write that case out manually anyway. 
def __mul__(self, other): if np.isscalar(other): return MatrixOperator(other * self.A) if isinstance(other, MatrixOperator): if self.shape == other.shape: return MatrixOperator(self.A * other.A) raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") if isinstance(other, (jnp.ndarray, np.ndarray)): if self.matrix_shape == other.shape: return MatrixOperator(self.A * other) raise ValueError(f"Shapes {self.matrix_shape} and {other.shape} do not match.") # includes generic LinearOperator raise TypeError(f"Operation __mul__ not defined between {type(self)} and {type(other)}.") def __rmul__(self, other): # multiplication is commutative return self * other def __truediv__(self, other): if np.isscalar(other): return MatrixOperator(self.A / other) if isinstance(other, MatrixOperator): if self.shape == other.shape: return MatrixOperator(self.A / other.A) raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") if isinstance(other, (jnp.ndarray, np.ndarray)): if self.matrix_shape == other.shape: return MatrixOperator(self.A / other) raise ValueError(f"Shapes {self.matrix_shape} and {other.shape} do not match.") raise TypeError( f"Operation __truediv__ not defined between {type(self)} and {type(other)}." ) def __rtruediv__(self, other): if np.isscalar(other): return MatrixOperator(other / self.A) if isinstance(other, (jnp.ndarray, np.ndarray)): if self.matrix_shape == other.shape: return MatrixOperator(other / self.A) raise ValueError(f"Shapes {other.shape} and {self.matrix_shape} do not match.") raise TypeError( f"Operation __truediv__ not defined between {type(other)} and {type(self)}." ) def __getitem__(self, key): return self.A[key] @property def T(self): """Transpose of this :class:`.MatrixOperator`. Return a :class:`.MatrixOperator` corresponding to the transpose of this matrix. """ return MatrixOperator(self.A.T) @property def H(self): """Hermitian (conjugate) transpose of this :class:`.MatrixOperator`. 
Return a :class:`.MatrixOperator` corresponding to the Hermitian (conjugate) transpose of this matrix. """ return MatrixOperator(self.A.conj().T) def conj(self): """Complex conjugate of this :class:`.MatrixOperator`. Return a :class:`.MatrixOperator` with complex conjugated elements. """ return MatrixOperator(A=self.A.conj()) def adj(self, y): return self.A.conj().T @ y def to_array(self): """Return a :class:`numpy.ndarray` containing `self.A`.""" return np.array(self.A) @property def gram_op(self): """Gram operator of this :class:`.MatrixOperator`. Return a new :class:`.LinearOperator` `G` such that `G(x) = A.adj(A(x)))`.""" return MatrixOperator(A=self.A.conj().T @ self.A) def norm(self, ord=None, axis=None, keepdims=False): # pylint: disable=W0622 """Compute the norm of the dense matrix `self.A`. Call :func:`scico.numpy.linalg.norm` on the dense matrix `self.A`. """ return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims) ================================================ FILE: scico/linop/_stack.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Stack of linear operators classes.""" from __future__ import annotations from typing import Any, List, Optional, Sequence, Union import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import normalize_axes from scico.operator._stack import DiagonalReplicated as DiagonalReplicatedOperator from scico.operator._stack import DiagonalStack as DiagonalStackOperator from scico.operator._stack import VerticalStack as VerticalStackOperator from scico.typing import Axes, Shape from ._linop import LinearOperator class VerticalStack(VerticalStackOperator, LinearOperator): r"""A vertical stack of linear operators. 
Given linear operators :math:`A_1, A_2, \dots, A_N`, create the linear operator .. math:: H = \begin{pmatrix} A_1 \\ A_2 \\ \vdots \\ A_N \\ \end{pmatrix} \qquad \text{such that} \qquad H \mb{x} = \begin{pmatrix} A_1(\mb{x}) \\ A_2(\mb{x}) \\ \vdots \\ A_N(\mb{x}) \\ \end{pmatrix} \;. """ def __init__( self, ops: Sequence[LinearOperator], collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): r""" Args: ops: Linear operators to stack. collapse_output: If ``True`` and the output would be a :class:`BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a :class:`jax.Array` with shape (S, m, n, ...) where S is the length of `ops`. jit: See `jit` in :class:`LinearOperator`. """ if not all(isinstance(op, LinearOperator) for op in ops): raise TypeError("All elements of 'ops' must be of type LinearOperator.") super().__init__(ops=ops, collapse_output=collapse_output, jit=jit, **kwargs) def _adj(self, y: Union[Array, BlockArray]) -> Array: # type: ignore return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) # type: ignore class DiagonalStack(DiagonalStackOperator, LinearOperator): r"""A diagonal stack of linear operators. Given linear operators :math:`A_1, A_2, \dots, A_N`, create the linear operator .. math:: H = \begin{pmatrix} A_1 & 0 & \ldots & 0\\ 0 & A_2 & \ldots & 0\\ \vdots & \vdots & \ddots & \vdots\\ 0 & 0 & \ldots & A_N \\ \end{pmatrix} \qquad \text{such that} \qquad H \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \\ \end{pmatrix} = \begin{pmatrix} A_1(\mb{x}_1) \\ A_2(\mb{x}_2) \\ \vdots \\ A_N(\mb{x}_N) \\ \end{pmatrix} \;. By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots, \mb{x}_N` all have the same (possibly nested) shape, `S`, this operator will work on the stack, i.e., have an input shape of `(N, *S)`. If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`, this operator will work on the block concatenation, i.e., have an input shape of `(S1, S2, ..., SN)`. 
The same holds for the output shape. """ def __init__( self, ops: Sequence[LinearOperator], collapse_input: Optional[bool] = True, collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): """ Args: ops: Operators to stack. collapse_input: If ``True``, inputs are expected to be stacked along the first dimension when possible. collapse_output: If ``True``, the output will be stacked along the first dimension when possible. jit: See `jit` in :class:`LinearOperator`. """ if not all(isinstance(op, LinearOperator) for op in ops): raise TypeError("All elements of 'ops' must be of type LinearOperator.") super().__init__( ops=ops, collapse_input=collapse_input, collapse_output=collapse_output, jit=jit, **kwargs, ) def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type: ignore result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y)) # type: ignore if self.collapse_input: return snp.stack(result) return snp.blockarray(result) class DiagonalReplicated(DiagonalReplicatedOperator, LinearOperator): r"""A diagonal stack constructed from a single linear operator. Given linear operator :math:`A`, create the linear operator .. math:: H = \begin{pmatrix} A & 0 & \ldots & 0\\ 0 & A & \ldots & 0\\ \vdots & \vdots & \ddots & \vdots\\ 0 & 0 & \ldots & A \\ \end{pmatrix} \qquad \text{such that} \qquad H \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \\ \end{pmatrix} = \begin{pmatrix} A(\mb{x}_1) \\ A(\mb{x}_2) \\ \vdots \\ A(\mb{x}_N) \\ \end{pmatrix} \;. The application of :math:`A` to each component :math:`\mb{x}_k` is computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape for linear operator :math:`A` should exclude the array axis on which :math:`A` is replicated to form :math:`H`. For example, if :math:`A` has input shape `(3, 4)` and :math:`H` is constructed to replicate on axis 0 with 2 replicates, the input shape of :math:`H` will be `(2, 3, 4)`. Linear operators taking :class:`.BlockArray` input are not supported. 
""" def __init__( self, op: LinearOperator, replicates: int, input_axis: int = 0, output_axis: Optional[int] = None, map_type: str = "auto", **kwargs, ): """ Args: op: Linear operator to replicate. replicates: Number of replicates of `op`. input_axis: Input axis over which `op` should be replicated. output_axis: Index of replication axis in output array. If ``None``, the input replication axis is used. map_type: If "pmap" or "vmap", apply replicated mapping using :func:`jax.pmap` or :func:`jax.vmap` respectively. If "auto", use :func:`jax.pmap` if sufficient devices are available for the number of replicates, otherwise use :func:`jax.vmap`. """ if not isinstance(op, LinearOperator): raise TypeError("Argument 'op' must be of type LinearOperator.") super().__init__( op, replicates, input_axis=input_axis, output_axis=output_axis, map_type=map_type, **kwargs, ) self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis) def linop_over_axes( linop: type[LinearOperator], input_shape: Shape, *args: Any, axes: Optional[Axes] = None, **kwargs: Any, ) -> List[LinearOperator]: """Construct a list of :class:`LinearOperator` by iterating over axes. Construct a list of :class:`LinearOperator` by iterating over a specified sequence of axes, passing each value in sequence to the `axis` keyword argument of the :class:`LinearOperator` initializer. Args: linop: Type of :class:`LinearOperator` to construct for each axis. input_shape: Shape of input array. *args: Positional arguments for the :class:`LinearOperator` initializer. axes: Axis or axes over which to construct the list. If not specified, or ``None``, use all axes corresponding to `input_shape`. **kwargs: Keyword arguments for the :class:`LinearOperator` initializer. Returns: A tuple (`axes`, `ops`) where `axes` is a tuple of the axes used to construct the list of :class:`LinearOperator`, and `ops` is the list itself. 
""" axes = normalize_axes(axes, input_shape) # type: ignore return axes, [linop(input_shape, *args, axis=axis, **kwargs) for axis in axes] # type: ignore ================================================ FILE: scico/linop/_util.py ================================================ # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Linear operator utility functions.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Optional, Union import scico.numpy as snp from scico.operator._operator import Operator from scico.random import randn from scico.typing import PRNGKey from ._linop import LinearOperator def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None): """Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`. Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator` using power iteration. Args: A: :class:`.LinearOperator` used for computation. Must be diagonalizable. maxiter: Maximum number of power iterations to use. key: Jax PRNG key. Defaults to ``None``, in which case a new key is created. Returns: tuple: A tuple (`mu`, `v`) containing: - **mu**: Estimate of largest eigenvalue of `A`. - **v**: Eigenvector of `A` with eigenvalue `mu`. 
""" v, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype) v = v / snp.linalg.norm(v) for i in range(maxiter): Av = A @ v normAv = snp.linalg.norm(Av) if normAv == 0.0: # Assume that ||Av|| == 0 implies A is a zero operator mu = 0.0 v = Av break mu = snp.sum(v.conj() * Av) / snp.linalg.norm(v) ** 2 v = Av / normAv return mu, v def operator_norm(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None): r"""Estimate the norm of a :class:`.LinearOperator`. Estimate the operator norm `induced `_ by the :math:`\ell_2` vector norm, i.e. for :class:`.LinearOperator` :math:`A`, .. math:: \| A \|_2 &= \max \{ \| A \mb{x} \|_2 \, : \, \| \mb{x} \|_2 \leq 1 \} \\ &= \sqrt{ \lambda_{ \mathrm{max} }( A^H A ) } = \sigma_{\mathrm{max}}(A) \;, where :math:`\lambda_{\mathrm{max}}(B)` and :math:`\sigma_{\mathrm{max}}(B)` respectively denote the largest eigenvalue of :math:`B` and the largest singular value of :math:`B`. The value is estimated via power iteration, using :func:`power_iteration`, to estimate :math:`\lambda_{\mathrm{max}}(A^H A)`. Args: A: :class:`.LinearOperator` for which operator norm is desired. maxiter: Maximum number of power iterations to use. Default: 100 key: Jax PRNG key. Defaults to ``None``, in which case a new key is created. Returns: float: Norm of operator :math:`A`. """ return snp.sqrt(power_iteration(A.H @ A, maxiter, key)[0].real) def valid_adjoint( A: LinearOperator, AT: LinearOperator, eps: Optional[float] = 1e-7, x: Optional[snp.Array] = None, y: Optional[snp.Array] = None, key: Optional[PRNGKey] = None, ) -> Union[bool, float]: r"""Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`. Check whether :class:`.LinearOperator` :math:`\mathsf{AT}` is the adjoint of :math:`\mathsf{A}`. The test exploits the identity .. 
math:: \mathbf{y}^T (A \mathbf{x}) = (\mathbf{y}^T A) \mathbf{x} = (A^T \mathbf{y})^T \mathbf{x} by computing :math:`\mathbf{u} = \mathsf{A}(\mathbf{x})` and :math:`\mathbf{v} = \mathsf{AT}(\mathbf{y})` for random :math:`\mathbf{x}` and :math:`\mathbf{y}` and confirming that .. math:: \frac{| \mathbf{y}^T \mathbf{u} - \mathbf{v}^T \mathbf{x} |} {\max \left\{ | \mathbf{y}^T \mathbf{u} |, | \mathbf{v}^T \mathbf{x} | \right\}} < \epsilon \;. If :math:`\mathsf{A}` is a complex operator (with a complex `input_dtype`) then the test checks whether :math:`\mathsf{AT}` is the Hermitian conjugate of :math:`\mathsf{A}`, with a test as above, but with all the :math:`(\cdot)^T` replaced with :math:`(\cdot)^H`. Args: A: Primary :class:`.LinearOperator`. AT: Adjoint :class:`.LinearOperator`. eps: Error threshold for validation of :math:`\mathsf{AT}` as adjoint of :math:`\mathsf{AT}`. If ``None``, the relative error is returned instead of a boolean value. x: If not the default ``None``, use the specified array instead of a random array as test vector :math:`\mb{x}`. If specified, the array must have shape `A.input_shape`. y: If not the default ``None``, use the specified array instead of a random array as test vector :math:`\mb{y}`. If specified, the array must have shape `AT.input_shape`. key: Jax PRNG key. Defaults to ``None``, in which case a new key is created. Returns: Boolean value indicating whether validation passed, or relative error of test, depending on type of parameter `eps`. 
""" if x is None: x, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype) else: if x.shape != A.input_shape: raise ValueError("Shape of 'x' array not appropriate as an input for operator 'A'.") if y is None: y, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype) else: if y.shape != AT.input_shape: raise ValueError("Shape of 'y' array not appropriate as an input for operator AT.") u = A(x) v = AT(y) yTu = snp.sum(y.conj() * u) # type: ignore vTx = snp.sum(v.conj() * x) # type: ignore err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx)) if eps is None: return err return float(err) < eps def jacobian(F: Operator, u: snp.Array, include_eval: Optional[bool] = False) -> LinearOperator: """Construct Jacobian linear operator for a general operator. For a specified :class:`.Operator`, construct a corresponding Jacobian :class:`LinearOperator`, the application of which is equivalent to multiplication by the Jacobian of the :class:`.Operator` at a specified input value. The implementation of this function is based on :meth:`.Operator.jvp` and :meth:`.Operator.vjp`, which are themselves based on :func:`jax.jvp` and :func:`jax.vjp`. For reasons of computational efficiency, these functions return the value of the :class:`.Operator` evaluated at the specified point in addition to the requested Jacobian-vector product. If the `include_eval` parameter of this function is ``True``, the constructed :class:`LinearOperator` returns a :class:`.BlockArray` output, the first component of which is the result of the :class:`.Operator` evaluation, and the second component of which is the requested Jacobian-vector product. If `include_eval` is ``False``, then the :class:`.Operator` evaluation computed by :func:`jax.jvp` and :func:`jax.vjp` are discarded. Args: F: :class:`.Operator` of which the Jacobian is to be computed. u: Input value of the :class:`.Operator` at which the Jacobian is to be computed. 
include_eval: Flag indicating whether the result of evaluating the :class:`.Operator` should be included (as the first component of a :class:`.BlockArray`) in the output of the Jacobian :class:`LinearOperator` constructed by this function. Returns: A :class:`LinearOperator` capable of computing Jacobian-vector products. """ if include_eval: Fu, G = F.vjp(u, conjugate=True) def adj_fn(v): return snp.blockarray((Fu, G(v))) def eval_fn(v): return snp.blockarray(F.jvp(u, v)) else: adj_fn = F.vjp(u, conjugate=True)[1] def eval_fn(v): return F.jvp(u, v)[1] return LinearOperator( F.input_shape, output_shape=F.output_shape, eval_fn=eval_fn, adj_fn=adj_fn, input_dtype=F.input_dtype, output_dtype=F.output_dtype, ) ================================================ FILE: scico/linop/optics.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. r"""Optical propagator classes. This module provides classes that model the propagation of a monochromatic waveform between two parallel planes in a homogeneous medium. The corresponding linear operators are referred to here as "propagators", which represents a departure from standard terminology, in which "propagator" refers specifically to the Fourier domain component of the linear operator, i.e. if the full linear operator can be written as :math:`F^{-1} D F` where :math:`F` is the Fourier transform, then :math:`D` is usually referred to as the propagator. The following notation is used throughout the module: .. 
math :: \begin{align} \Delta x, \Delta y & \quad \text{Sampling intervals in } x \text{ and } y \text{ axes}\\ z & \quad \text{Propagation distance} \;\; (z \geq 0) \\ N_x, N_y & \quad \text{Number of samples in } x \text{ and } y \text{ axes}\\ k_0 & \quad \text{Illumination wavenumber corresponding to } 2\pi / \text{wavelength} \;. \end{align} Variables :math:`\Delta x, \Delta y, z,` and :math:`k_0` represent physical quantities. Any units may be chosen, but they must be consistent across all of these variables, e.g. m (metres) for :math:`\Delta x, \Delta y, z,` and :math:`\mathrm{m}^{-1}` for :math:`k_0`, as well as with the units for the physical dimensions of the source wavefield. Subscripts :math:`S` and :math:`D` are used to refer to the source and destination planes respectively when it is necessary to distinguish between them. In the absence of subscripts, the variables refer to the source plane (e.g. both :math:`\Delta x` and :math:`\Delta x_S` refer to the :math:`x`-axis sampling interval in the source plane, while :math:`\Delta x_D` refers to it in the destination plane). Note that :math:`x` corresponds to axis 0 (rows, increasing downwards) and :math:`y` to axis 1 (columns, increasing to the right). 
""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Any, Tuple, Union import numpy as np from numpy.lib.scimath import sqrt # complex sqrt from typing_extensions import TypeGuard import scico.numpy as snp from scico.linop import Diagonal, Identity, LinearOperator from scico.numpy.util import no_nan_divide from scico.typing import Shape from ._dft import DFT def _isscalar(element: Any) -> TypeGuard[Union[int, float]]: """Type guard interface to `snp.isscalar`.""" return snp.isscalar(element) def radial_transverse_frequency( input_shape: Shape, dx: Union[float, Tuple[float, ...]] ) -> np.ndarray: r"""Construct radial Fourier coordinate system. Args: input_shape: Tuple of length 1 or 2 containing the number of samples per dimension, i.e. :math:`(N_x,)` or :math:`(N_x, N_y)` dx: Sampling interval at source plane. If a float and `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, it must have same length as `input_shape`, and corresponds to either :math:`(\Delta x,)` or :math:`(\Delta x, \Delta y)`. Returns: If `len(input_shape)==1`, returns an ndarray containing corresponding Fourier coordinates. If `len(input_shape) == 2`, returns an ndarray containing the radial Fourier coordinates :math:`\sqrt{k_x^2 + k_y^2}\,`. """ ndim: int = len(input_shape) # 1 or 2 dimensions if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2.") if _isscalar(dx): dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "Argument 'dx' must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}." 
) assert isinstance(dx, tuple) if ndim == 1: kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0]) kp = kx elif ndim == 2: kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0]) ky = 2 * np.pi * np.fft.fftfreq(input_shape[1], dx[1]) kp = np.sqrt(kx[None, :] ** 2 + ky[:, None] ** 2) return kp class Propagator(LinearOperator): """Base class for angular spectrum and Fresnel propagators.""" def __init__( self, input_shape: Shape, dx: Union[float, Tuple[float, ...]], k0: float, z: float, pad_factor: int = 1, **kwargs, ): r""" Args: input_shape: Shape of input array as a tuple of length 1 or 2, corresponding to :math:`(N_x,)` or :math:`(N_x, N_y)`. dx: Sampling interval at source plane. If a float and `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, it must have same length as `input_shape`, and corresponds to either :math:`(\Delta x,)` or :math:`(\Delta x, \Delta y)`. k0: Illumination wavenumber, :math:`k_0`, corresponding to :math:`2 \pi` / wavelength. z: Propagation distance, :math:`z`. pad_factor: The padded input shape is the input shape multiplied by this integer factor. """ ndim = len(input_shape) # 1 or 2 dimensions if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2.") if _isscalar(dx): dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "Argument 'dx' must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}." ) assert isinstance(dx, tuple) #: Illumination wavenumber; 2𝜋/wavelength self.k0: float = k0 #: Shape of input after padding self.padded_shape: Shape = tuple(pad_factor * s for s in input_shape) #: Padded source plane side length (dx[i] * padded_shape[i]) self.L: Tuple[float, ...] 
= tuple( s * d for s, d in zip(self.padded_shape, dx) ) # computational plane size #: Transverse Fourier coordinates (radial) self.kp = radial_transverse_frequency(self.padded_shape, dx) #: Source plane sampling interval self.dx: Union[float, Tuple[float, ...]] = dx #: Propagation distance self.z: float = z # Fourier operator self.F = DFT(input_shape=input_shape, axes_shape=self.padded_shape, jit=False) # Diagonal operator; phase shifting self.D: LinearOperator = Identity(self.kp.shape) super().__init__( input_shape=input_shape, input_dtype=np.complex64, output_shape=input_shape, output_dtype=np.complex64, adj_fn=None, **kwargs, ) def __repr__(self): extra_repr = f""" k0: {self.k0} λ: {2*np.pi/self.k0} z: {self.z} dx: {self.dx} L: {self.L} """ return LinearOperator.__repr__(self) + extra_repr def _eval(self, x): return self.F.inv(self.D @ self.F @ x) class AngularSpectrumPropagator(Propagator): r"""Angular spectrum propagator. Propagates a planar source field with coordinates :math:`(x, y, z_0)` to a destination plane at a distance :math:`z` with coordinates :math:`(x, y, z_0 + z)`. The action of this linear operator is given by (Eq. 3.74, :cite:`goodman-2005-fourier`) .. math :: (A \mb{u})(x, y, z_0 + z) = \frac{1}{2 \pi} \iint_{-\infty}^{\infty} \mb{\hat{u}}(k_x, k_y) e^{j \sqrt{k_0^2 - k_x^2 - k_y^2} \, z} e^{j (x k_x + y k_y) } d k_x \ d k_y \;, where the :math:`\mb{\hat{u}}` is the Fourier transform of the field :math:`\mb{u}(x, y)` in the plane :math:`z=z_0`, given by .. math :: \mb{\hat{u}}(k_x, k_y) = \iint_{-\infty}^{\infty} \mb{u}(x, y) e^{- j (x k_x + y k_y)} d k_x \ d k_y \;, where :math:`(k_x, k_y)` are the :math:`x` and :math:`y` components respectively of the wave-vector of the plane wave, and :math:`j` is the imaginary unit. The angular spectrum propagator can be written .. 
math :: A\mb{u} = F^{-1} D F \mb{u} \;, where :math:`F` is the Fourier transform with respect to :math:`(x, y)`, :math:`F^{-1}` is the inverse transform with respect to :math:`(k_x, k_y)`, and the propagator term is given by .. math :: D = \exp \left( j \sqrt{k_0^2 - k_x^2 - k_y^2} \, z \right) \;. Aliasing of the wavefield at the destination plane is avoided when the propagator term is adequately sampled according to :cite:`voelz-2009-digital` .. math :: (\Delta x)^2 \geq \frac{\pi}{k_0 N_x} \sqrt{ (\Delta x)^2 N_x^2 + 4 z^2} \quad \text{and} \quad (\Delta y)^2 \geq \frac{\pi}{k_0 N_y} \sqrt{ (\Delta y)^2 N_y^2 + 4 z^2} \;. """ def __init__( self, input_shape: Shape, dx: Union[float, Tuple[float, ...]], k0: float, z: float, pad_factor: int = 1, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array. Can be a tuple of length 2 or 3. dx: Sampling interval, :math:`\Delta x`, at source plane. If a float and `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, must have same length as `input_shape`. k0: Illumination wavenumber, :math:`k_0`, corresponding to :math:`2 \pi` / wavelength. z: Propagation distance, :math:`z`. pad_factor: The padded input shape is the input shape multiplied by this integer factor. jit: If ``True``, call :meth:`~.Operator.jit` on this :class:`LinearOperator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`~.Operator.jit` after the :class:`LinearOperator` is created. """ # Diagonal operator; phase shifting super().__init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) self.phase = snp.exp(1j * z * sqrt(self.k0**2 - self.kp**2)).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() if jit: self.jit() def adequate_sampling(self): r"""Verify the angular spectrum kernel is not aliased. Checks the condition for adequate sampling :cite:`voelz-2009-digital`, .. 
math :: (\Delta x)^2 \geq \frac{\pi}{k_0 N_x} \sqrt{ (\Delta x)^2 N_x^2 + 4 z^2} \quad \text{and} \quad (\Delta y)^2 \geq \frac{\pi}{k_0 N_y} \sqrt{ (\Delta y)^2 N_y^2 + 4 z^2} \;. Returns: ``True`` if the angular spectrum kernel is adequately sampled, ``False`` otherwise. """ tmp = [] for d, N in zip(self.dx, self.padded_shape): tmp.append(d**2 > np.pi / (self.k0 * N) * np.sqrt(d**2 * N**2 + 4 * self.z**2)) return np.all(tmp) def pinv(self, y): """Apply pseudoinverse of Angular Spectrum propagator.""" diag_inv = no_nan_divide(1, self.D.diagonal) return self.F.inv(diag_inv * self.F(y)) class FresnelPropagator(Propagator): r"""Fresnel (small-angle/paraxial) propagator. Propagates a planar source field with coordinates :math:`(x, y, z_0)` to a destination plane at a distance :math:`z` with coordinates :math:`(x, y, z_0 + z)`. The action of this linear operator is given by (Eq. 4.20, :cite:`goodman-2005-fourier`) .. math :: (A \mb{u})(x, y, z + z_0) = e^{j k_0 z} \frac{1}{2 \pi} \iint_{-\infty}^{\infty} \mb{\hat{u}}(k_x, k_y) e^{-j \frac{z}{2 k_0}\left(k_x^2 + k_y^2\right) } e^{j (x k_x + y k_y) } d k_x \ d k_y \;, where the :math:`\mb{\hat{u}}` is the Fourier transform of the field in the source plane, given by .. math :: \mb{\hat{u}}(k_x, k_y) = \iint_{-\infty}^{\infty} \mb{u}(x, y) e^{- j (x k_x + y k_y)} d k_x \ d k_y \;. This linear operator is valid when :math:`k_0^2 << k_x^2 + k_y^2`. The Fresnel propagator can be written .. math :: A\mb{u} = F^{-1} D F \mb{u} \;, where :math:`F` is the Fourier transform with respect to :math:`(x, y)`, :math:`F^{-1}` is the inverse transform with respect to :math:`(k_x, k_y)`, and the propagator term is given by .. math :: D = \exp \left( -j \frac{z}{2 k_0}\left(k_x^2 + k_y^2 \right) \right) \;, where :math:`(k_x, k_y)` are the :math:`x` and :math:`y` components respectively of the wave-vector of the plane wave, and :math:`j` is the imaginary unit. 
The propagator term is adequately sampled when :cite:`voelz-2011-computational` .. math :: (\Delta x)^2 \geq \frac{2 \pi z }{k_0 N_x} \quad \text{and} \quad (\Delta y)^2 \geq \frac{2 \pi z }{k_0 N_y} \;. """ def __init__( self, input_shape: Shape, dx: float, k0: float, z: float, pad_factor: int = 1, jit: bool = True, **kwargs, ): super().__init__( input_shape=input_shape, dx=dx, k0=k0, z=z, pad_factor=pad_factor, **kwargs ) self.phase = snp.exp(1j * z * (self.k0 - self.kp**2 / (2 * self.k0))).astype(np.complex64) self.D = Diagonal(self.phase) self._set_adjoint() if jit: self.jit() def adequate_sampling(self): r"""Verify the Fresnel propagation kernel is not aliased. Checks the condition for adequate sampling :cite:`voelz-2011-computational`, .. math :: (\Delta x)^2 \geq \frac{2 \pi z }{k_0 N_x} \quad \text{and} \quad (\Delta y)^2 \geq \frac{2 \pi z }{k_0 N_y} \;. Returns: ``True`` if the Fresnel propagation kernel is adequately sampled, ``False`` otherwise. """ tmp = [] for d, N in zip(self.dx, self.padded_shape): tmp.append(d**2 > 2 * np.pi * self.z / (self.k0 * N)) return np.all(tmp) class FraunhoferPropagator(LinearOperator): r"""Fraunhofer (far-field) propagator. Propagates a source field with coordinates :math:`(x_S, y_S)` to a destination plane at a distance :math:`z` with coordinates :math:`(x_D, y_D)`. The action of this linear operator is given by (Eq. 4.25, :cite:`goodman-2005-fourier`) .. math :: (A \mb{u})(x_D, y_D) = \underbrace{\frac{k_0}{2 \pi} \frac{e^{j k_0 z}}{j z} \mathrm{exp} \left( j \frac{k_0}{2 z} (x_D^2 + y_D^2) \right)}_{\triangleq P(x_D, y_D)} \int \mb{u}(x_S, y_S) e^{-j \frac{k_0}{z} (x_D x_S + y_D y_S) } dx_S \ dy_S \;. This is valid when :math:`N_F << 1`, where :math:`N_F` is the Fresnel number (Sec. 1.5, Sec. 4.7.2.1) :cite:`paganin-2006-coherent`. Writing the Fourier transform of the field :math:`\mb{u}` as .. 
math :: \hat{\mb{u}}(k_x, k_y) = \int e^{-j (k_x x + k_y y)} \mb{u}(x, y) dx \ dy \;, the action of this linear operator can be written .. math :: (A \mb{u})(x_D, y_D) = P(x_D, y_D) \ \hat{\mb{u}} \left({\frac{k_0}{z} x_D, \frac{k_0}{z} y_D}\right) \;. Ignoring multiplicative prefactors, the Fraunhofer propagated field is the Fourier transform of the source field, evaluated at coordinates :math:`(k_x, k_y) = (\frac{k_0}{z} x_D, \frac{k_0}{z} y_D)`. In general, the sampling intervals (and thus plane lengths) differ between source and destination planes. In particular, (Eq. 5.18, :cite:`voelz-2011-computational`) .. math :: \Delta x_D = \frac{2 \pi z}{k_0 L_{Sx} } \quad \text{and} \quad L_{Dx} = \frac{2 \pi z}{k_0 \Delta x_S } \;, and similarly for the :math:`y` axis. The Fraunhofer propagator term :math:`P(x_D, y_D)` is adequately sampled when .. math :: \Delta x_S \geq \sqrt{\frac{2 \pi z}{N_x k_0}} \quad \text{and} \quad \Delta y_S \geq \sqrt{\frac{2 \pi z}{N_y k_0}} \;. """ def __init__( self, input_shape: Shape, dx: Union[float, Tuple[float, ...]], k0: float, z: float, jit: bool = True, **kwargs, ): r""" Args: input_shape: Shape of input array as a tuple of length 1 or 2, corresponding to :math:`(N_x,)` or :math:`(N_x, N_y)`. dx: Sampling interval at source plane. If a float and `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, it must have same length as `input_shape`, and corresponds to either :math:`(\Delta x,)` or :math:`(\Delta x, \Delta y)`. k0: Illumination wavenumber, :math:`k_0`, corresponding to :math:`2 \pi` / wavelength. z: Propagation distance, :math:`z`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of this :class:`LinearOperator`. Default: ``True``. 
""" ndim = len(input_shape) # 1 or 2 dimensions if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2.") if _isscalar(dx): dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "Argument 'dx' must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}." ) assert isinstance(dx, tuple) L: Tuple[float, ...] = tuple(s * d for s, d in zip(input_shape, dx)) #: Illumination wavenumber self.k0: float = k0 #: Propagation distance self.z: float = z #: Source plane side length (dx[i] * input_shape[i]) self.L: Tuple[float, ...] = L #: Source plane sampling interval self.dx: Tuple[float, ...] = dx #: Destination plane sampling interval self.dx_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * l)).item() for l in L) #: Destination plane side length self.L_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * d)).item() for d in dx) x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D)) # type: ignore # set up radial coordinate system; either x^2 or (x^2 + y^2) if ndim == 1: self.r2 = x_D[0] elif ndim == 2: self.r2 = np.sqrt(x_D[0][:, None] ** 2 + x_D[1][None, :] ** 2) phase = -1j * snp.exp(1j * k0 * z) * snp.exp(1j * 0.5 * k0 / z * self.r2**2) phase *= k0 / (2 * np.pi) * np.abs(1 / z) phase *= np.prod(dx) # from approximating continouous FT with DFT phase = phase.astype(np.complex64) self.F = DFT(input_shape=input_shape, jit=False) self.D = Diagonal(phase) super().__init__( input_shape=input_shape, input_dtype=np.complex64, output_shape=input_shape, output_dtype=np.complex64, **kwargs, ) if jit: self.jit() def __repr__(self): extra_repr = f""" k0: {self.k0} λ: {2*np.pi/self.k0} z: {self.z} dx: {self.dx} L: {self.L} dx_D: {self.dx_D} L_D: {self.L_D} """ return LinearOperator.__repr__(self) + extra_repr def _eval(self, x): x = snp.fft.fftshift(x) y = self.D @ self.F @ x y = snp.fft.ifftshift(y) return y def adequate_sampling(self): 
r"""Verify the Fraunhofer propagation kernel is not aliased. Checks the condition for adequate sampling :cite:`voelz-2011-computational`, .. math :: \Delta x_S \geq \sqrt{\frac{2 \pi z}{N_x k_0}} \quad \text{and} \quad \Delta y_S \geq \sqrt{\frac{2 \pi z}{N_y k_0}} \;. Returns: ``True`` if the Fraunhofer propagation kernel is adequately sampled, ``False`` otherwise. """ tmp = [] for d, N in zip(self.dx, self.input_shape): tmp.append(d**2 > 2 * np.pi * self.z / (self.k0 * N)) return np.all(tmp) ================================================ FILE: scico/linop/xray/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. r"""X-ray transform classes. The tomographic projections that are frequently referred to as Radon transforms are referred to as X-ray transforms in SCICO. While the Radon transform is far more well-known than the X-ray transform, which is the same as the Radon transform for projections in two dimensions, these two transform differ in higher numbers of dimensions, and it is the X-ray transform that is the appropriate mathematical model for beam attenuation based imaging in three or more dimensions. SCICO includes its own integrated 2D and 3D X-ray transforms, and also provides interfaces to those implemented in the `ASTRA toolbox `_ and the `svmbir `_ package. **2D Transforms** The SCICO, ASTRA, and svmbir transforms use different conventions for view angle directions, as illustrated in the figure below. .. plot:: pyfigures/xray_2d_geom.py :align: center :include-source: False :show-source-link: False :caption: Comparison of 2D X-ray projector geometries. 
The radial arrows are directed towards the locations of the corresponding detectors, with the direction of increasing pixel indices indicated by the arrows on the dotted lines parallel to the detectors. | The conversion from the SCICO projection angle convention to those of the other two transforms is .. math:: \begin{aligned} \theta_{\text{astra}} &= \theta_{\text{scico}} - \frac{\pi}{2} \\ \theta_{\text{svmbir}} &= 2 \pi - \theta_{\text{scico}} \;. \end{aligned} **3D Transforms** There are more significant differences in the interfaces for the 3D SCICO and ASTRA transforms. The SCICO 3D transform :class:`.xray.XRayTransform3D` defines the projection geometry in terms of a set of projection matrices, while the geometry for the ASTRA 3D transform :class:`.astra.XRayTransform3D` may either be specified in terms of a set of view angles, or via a more general set of vectors specifying projection direction and detector orientation. A number of support functions are provided for converting between these conventions. Note that the SCICO transform is implemented in JAX and can be run on both CPU and GPU devices, while the ASTRA transform is implemented in CUDA, and can only be run on GPU devices.
"""

import sys

from ._util import (
    center_image,
    image_alignment_rotation,
    image_centroid,
    rotate_volume,
    volume_alignment_rotation,
)
from ._xray import XRayTransform2D, XRayTransform3D

# Public API of the xray package
__all__ = [
    "XRayTransform2D",
    "XRayTransform3D",
    "image_centroid",
    "center_image",
    "rotate_volume",
    "image_alignment_rotation",
    "volume_alignment_rotation",
]

# Imported items in __all__ appear to originate in top-level xray module:
# overwrite the ``__module__`` attribute of each exported object so that it
# names this package rather than the private submodule (``._util`` or
# ``._xray``) in which it is actually defined.
for name in __all__:
    getattr(sys.modules[__name__], name).__module__ = __name__

================================================
FILE: scico/linop/xray/_axitom/LICENSE
================================================
MIT License

Copyright (c) 2019 PolymerGuy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: scico/linop/xray/_axitom/README.md
================================================
# AXITOM

The modules in this directory are derived from the [AXITOM](https://github.com/PolymerGuy/AXITOM) package.
All original components of AXITOM are subject to its license, which is included in the file "LICENSE". ================================================ FILE: scico/linop/xray/_axitom/backprojection.py ================================================ """ This file is a modified version of "backprojection.py" from the [AXITOM](https://github.com/PolymerGuy/AXITOM) package. Filtered back projection functions This module contains the Feldkamp David Kress filtered back projection routines. """ from typing import Optional import jax.numpy as jnp from jax import Array from jax.scipy.ndimage import map_coordinates from .config import Config from .filtering import ramp_filter_and_weight from .utilities import rotate_coordinates def map_object_to_detector_coords(object_xs, object_ys, object_zs, config): """Map the object coordinates to detector pixel coordinates accounting for cone beam divergence. Parameters ---------- object_xs : np.ndarray The x-coordinate array of the object to be reconstructed object_ys : np.ndarray The y-coordinate array of the object to be reconstructed object_zs : np.ndarray The z-coordinate array of the object to be reconstructed config : obj The config object containing all necessary settings for the reconstruction Returns ------- detector_cords_a The detector coordinates along the a-axis corresponding to the given points detector_cords_b The detector coordinates along the b-axis corresponding to the given points """ detector_cords_a = ( ((object_ys * config.source_to_detector_dist) / (object_xs + config.source_to_object_dist)) - config.detector_us[0] ) / config.pixel_size_u if object_xs.ndim == 2: detector_cords_b = ( ( (object_zs[jnp.newaxis, jnp.newaxis, :] * config.source_to_detector_dist) / (object_xs[:, :, jnp.newaxis] + config.source_to_object_dist) ) - config.detector_vs[0] ) / config.pixel_size_v elif object_xs.ndim == 1: detector_cords_b = ( ( (object_zs[jnp.newaxis, :] * config.source_to_detector_dist) / (object_xs[:, jnp.newaxis] + 
config.source_to_object_dist) ) - config.detector_vs[0] ) / config.pixel_size_v else: raise ValueError("Invalid dimensions on the object coordinates") return detector_cords_a, detector_cords_b def _fdk_axisym(projection_filtered, config, angles): """Filtered back projection algorithm as proposed by Feldkamp David Kress, adapted for axisymmetry. This implementation has been adapted for axis-symmetry by using a single projection only and by only reconstructing a single R-Z slice. This algorithm is based on: https://doi.org/10.1364/JOSAA.1.000612 but follows the notation used by: Henrik Turbell, Cone-Beam Reconstruction Using Filtered Backprojection, PhD Thesis, Linkoping Studies in Science and Technology https://people.csail.mit.edu/bkph/courses/papers/Exact_Conebeam/Turbell_Thesis_FBP_2001.pdf Parameters ---------- projection_filtered : jnp.ndarray The ramp filtered and weighted projection used in the reconstruction config : obj The config object containing all necessary settings for the reconstruction Returns ------- ndarray The reconstructed slice is a R-Z plane of a axis-symmetric tomogram where Z is the symmetry axis. 
""" proj_width, proj_height = projection_filtered.shape proj_center = int(proj_width / 2) # Allocate an empty array recon_slice = jnp.zeros((proj_width, proj_height), dtype=jnp.float32) for frame_nr, angle in enumerate(angles): x_rotated, y_rotated = rotate_coordinates( jnp.zeros_like(config.object_xs), config.object_ys, jnp.radians(angle), ) detector_cords_a, detector_cords_b = map_object_to_detector_coords( x_rotated, y_rotated, config.object_zs, config ) # a is independent of Z but has to match the shape of b detector_cords_a = detector_cords_a[:, jnp.newaxis] * jnp.ones_like(detector_cords_b) # This term is caused by the divergent cone geometry ratio = (config.source_to_object_dist**2.0) / ( config.source_to_object_dist + x_rotated ) ** 2.0 recon_slice = recon_slice + ratio[:, jnp.newaxis] * map_coordinates( projection_filtered, [detector_cords_a, detector_cords_b], cval=0.0, order=1 ) return recon_slice / angles.size def fdk(projection: Array, config: Config, angles: Optional[Array] = None) -> Array: """Filtered back projection algorithm as proposed by Feldkamp David Kress, adapted for axisymmetry. This implementation has been adapted for axis-symmetry by using a single projection only and by only reconstructing a single R-Z slice. This algorithm is based on: https://doi.org/10.1364/JOSAA.1.000612 but follows the notation used by: Henrik Turbell, Cone-Beam Reconstruction Using Filtered Backprojection, PhD Thesis, Linkoping Studies in Science and Technology https://people.csail.mit.edu/bkph/courses/papers/Exact_Conebeam/Turbell_Thesis_FBP_2001.pdf Args: projection: The projection used in the reconstruction config: The config object containing all necessary settings for the reconstruction. angles: Array of angles at which reconstruction should be computed. Defaults to 0 to 359 degrees with a 1 degree step. Returns: The reconstructed slice is a R-Z plane of a axis-symmetric tomogram where Z is the symmetry axis. 
""" if angles is None: angles = jnp.arange(0, 360) if not isinstance(config, Config): raise ValueError("Only instances of Config are valid settings") if projection.ndim == 2: projection_filtered = ramp_filter_and_weight(projection, config) else: raise ValueError("The projection has to be a 2D array") tomo = _fdk_axisym(projection_filtered, config, angles) return tomo ================================================ FILE: scico/linop/xray/_axitom/config.py ================================================ """ This file is a modified version of "config.py" from the [AXITOM](https://github.com/PolymerGuy/AXITOM) package. Config object and factory. This module contains the Config class which has all the settings that are used during the reconstruction of the tomogram. """ import numpy as np class Config: """Configuration object for the forward projection.""" def __init__( self, n_pixels_u: int, n_pixels_v: int, pixel_size_u: float, pixel_size_v: float, source_to_detector_dist: float, source_to_object_dist: float, **kwargs, ): """ Note that invalid arguments are neglected without warning. Args: n_pixels_u: Number of pixels in the u direction of the sensor. n_pixels_v: Number of pixels in the u direction of the sensor. pixel_size_u: Pixel size in the u direction [mm]. pixel_size_v: Pixel size in the v direction [mm]. source_to_detector_dist: Distance between source and detector [mm]. source_to_object_dist: Distance between source and object [mm]. 
""" self.n_pixels_u = n_pixels_u self.n_pixels_v = n_pixels_v self.pixel_size_u = pixel_size_u self.pixel_size_v = pixel_size_v self.detector_size_u = self.pixel_size_u * self.n_pixels_u self.detector_size_v = self.pixel_size_v * self.n_pixels_v self.source_to_detector_dist = source_to_detector_dist self.source_to_object_dist = source_to_object_dist # All values below are calculated self.object_size_x = ( self.detector_size_u * self.source_to_object_dist / self.source_to_detector_dist ) self.object_size_y = ( self.detector_size_u * self.source_to_object_dist / self.source_to_detector_dist ) self.object_size_z = ( self.detector_size_v * self.source_to_object_dist / self.source_to_detector_dist ) self.voxel_size_x = self.object_size_x / self.n_pixels_u self.voxel_size_y = self.object_size_y / self.n_pixels_u self.voxel_size_z = self.object_size_z / self.n_pixels_v self.object_ys = ( np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0 ) * self.voxel_size_y self.object_xs = ( np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0 ) * self.voxel_size_x self.object_zs = ( np.arange(self.n_pixels_v, dtype=np.float32) - self.n_pixels_v / 2.0 ) * self.voxel_size_z self.detector_us = ( np.arange(self.n_pixels_u, dtype=np.float32) - self.n_pixels_u / 2.0 ) * self.pixel_size_u self.detector_vs = ( np.arange(self.n_pixels_v, dtype=np.float32) - self.n_pixels_v / 2.0 ) * self.pixel_size_v def __repr__(self): str = f"Source-object distance: {self.source_to_object_dist} " str += f"Source-detector distance: {self.source_to_detector_dist}\n" str += f"Detector pixels: {self.n_pixels_u}, {self.n_pixels_v} " str += f"Detector size: {self.detector_size_u:.3e}, {self.detector_size_v:.3e}\n" str += f"Pixel size: {self.pixel_size_u:.3e}, {self.pixel_size_v:.3e}\n" str += ( f"Voxel size: {self.voxel_size_x:.3e}, {self.voxel_size_y:.3e}, " f"{self.voxel_size_z:.3e}" ) return str def with_param(self, **kwargs): """Get a clone of the object with changed 
parameters. Get a clone of the object with changed parameters and all calculations updated. Args: kwargs: The arguments of the config object that should be changed. Returns: obj: Config object with modified settings. """ params = self.__dict__.copy() for arg, value in kwargs.items(): params[arg] = value return Config(**params) ================================================ FILE: scico/linop/xray/_axitom/filtering.py ================================================ """ This file is a modified version of "filtering.py" from the [AXITOM](https://github.com/PolymerGuy/AXITOM) package. Filter tools This module contains the ramp filter and the weighting function. """ import numpy as np import jax.numpy as jnp import jax.scipy.signal as sig def _ramp_kernel_real(cutoff, length): """Ramp filter kernel in real space defined by the cut-off frequency and the spatial dimension. Parameters ---------- cutoff : float The cut-off frequency length : int The kernel filter length Returns ------- ndarray The filter kernel """ pos = jnp.arange(-length, length, 1) return cutoff**2.0 * (2.0 * jnp.sinc(2 * pos * cutoff) - jnp.sinc(pos * cutoff) ** 2.0) def _add_weights(projection, config): """Add weights to the projection according to the ray length traveled through a voxel. 
Parameters ---------- projection : jnp.ndarray The projection used in the reconstruction config : obj The config object containing all necessary settings for the reconstruction Returns ------- ndarray The projections weighted by the ray length """ uu, vv = jnp.meshgrid(config.detector_vs, config.detector_us) weights = config.source_to_detector_dist / jnp.sqrt( config.source_to_detector_dist**2.0 + uu**2.0 + vv**2.0 ) return projection * weights def ramp_filter_and_weight(projection, config): """Add weights to the projection and apply a ramp-high-pass filter set to 0.5*Nyquist_frequency Parameters ---------- projection : jnp.ndarray The projection used in the reconstruction config : obj The config object containing all necessary settings for the reconstruction Returns ------- ndarray The projections weighted by the ray length and filtered by ramp filter """ projections_weighted = _add_weights(projection, config) n_pixels_u, _ = np.shape(projections_weighted) ramp_kernel = _ramp_kernel_real(0.5, n_pixels_u) projections_filtered = np.zeros_like(projections_weighted) _, n_lines = projections_weighted.shape for j in range(n_lines): projections_filtered[:, j] = sig.fftconvolve( projections_weighted[:, j], ramp_kernel, mode="same" ) scale_factor = ( 1.0 / config.pixel_size_u * np.pi * (config.source_to_detector_dist / config.source_to_object_dist) ) return projections_filtered * scale_factor ================================================ FILE: scico/linop/xray/_axitom/projection.py ================================================ """ This file is a modified version of "projection.py" from the [AXITOM](https://github.com/PolymerGuy/AXITOM) package. Forward projection routines. This module contains the functions used to forward project a volume onto a sensor plane. 
"""

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import Array, jit
from jax.scipy.ndimage import map_coordinates

from .config import Config


@partial(jit, static_argnames=["config", "input_2d"])
def _partial_forward_project(
    volume: Array,
    uu: Array,
    vv: Array,
    irslab,
    config: Config,
    input_2d: bool = False,
) -> Array:
    """Partial projection of a volume onto a sensor plane.

    Partial projection of a cylindrically symmetric volume onto a sensor
    plane using conical beam geometry: this function only sums along the
    section of the imaging direction specified by :code:`irslab`.

    Args:
        volume: The volume that will be projected onto the sensor.
        uu: Detector grid in axis 1 direction.
        vv: Detector grid in axis 0 direction.
        irslab: Array of indices and ratios. Row 0 holds the indices of
            the object planes (along the imaging direction) in this slab,
            and row 1 the corresponding ratios used to scale detector
            coordinates onto each plane.
        config: The settings object.
        input_2d: If ``True``, the input is a 2D image from which a 3D
            volume is constructed by rotation about the center of axis 1
            of the image.

    Returns:
        The projection.
    """
    islab = irslab[0]
    rslab = irslab[1]
    N = config.object_ys.size
    # Fractional voxel coordinates at which each detector ray meets each
    # object plane of the slab: detector coordinates are scaled by the
    # plane's ratio and converted to indices on the object grid.
    pvs = (
        vv[:, jnp.newaxis, :] * rslab[jnp.newaxis, :, jnp.newaxis] - config.object_zs[0]
    ) / config.voxel_size_z
    pys = islab[jnp.newaxis, :, jnp.newaxis] * jnp.ones_like(pvs)
    pus = (
        uu[:, jnp.newaxis, :] * rslab[jnp.newaxis, :, jnp.newaxis] - config.object_xs[0]
    ) / config.voxel_size_x
    if input_2d:
        # Array centers in index units; ax0c is unused and ax1c is
        # replaced by the center of the imaging direction (size N).
        ax0c, ax1c, ax2c = ((np.array(pvs.shape) + 1) / 2 - 1).tolist()
        ax1c = (N + 1) / 2 - 1
        # Distance of each sample point from the rotation axis, mapped back
        # onto axis 1 of the 2D input; the side of the center is chosen by
        # the sign of pys - ax1c.
        r = jnp.hypot(pus - ax2c, pys - ax1c)
        ax1 = jnp.where(pys >= ax1c, ax1c + r, ax1c - r)
        proj2d = jnp.sum(map_coordinates(volume, [pvs, ax1], cval=0.0, order=1), axis=1)
    else:
        proj2d = jnp.sum(map_coordinates(volume, [pvs, pys, pus], cval=0.0, order=1), axis=1)
    # Per-pixel path-length factor: slanted ray length relative to the
    # source-detector distance, times the voxel size along the imaging
    # (y) direction.
    dist = (
        jnp.sqrt(config.source_to_detector_dist**2.0 + uu**2.0 + vv**2.0)
        / (config.source_to_detector_dist)
        * config.voxel_size_y
    )
    return proj2d * dist


@partial(jit, static_argnames=["config", "num_slabs", "input_2d"])
def forward_project(
    volume: Array, config: Config, num_slabs: int = 8, input_2d: bool = False
) -> Array:
    """Projection of a volume onto a sensor plane.

    Projection of a cylindrically symmetric volume onto a sensor plane
    using conical beam geometry.

    Args:
        volume: The volume that will be projected onto the sensor.
        config: The settings object.
        num_slabs: Number of slabs into which the volume should be divided
            (for serial processing, to limit memory usage) in the imaging
            direction.
        input_2d: If ``True``, the input is a 2D image from which a 3D
            volume is constructed by rotation about the center of axis 1
            of the image.

    Returns:
        The projection.
    """
    uu, vv = jnp.meshgrid(config.detector_us, config.detector_vs)
    # One ratio per object plane, (object_ys + source_to_object_dist) /
    # source_to_detector_dist, used to scale detector coordinates onto it.
    ratios = (config.object_ys + config.source_to_object_dist) / config.source_to_detector_dist
    N = ratios.size
    slab_size = N // num_slabs
    remainder = N % num_slabs
    # Group plane indices and ratios into num_slabs equal slabs, giving an
    # array of shape (num_slabs, 2, slab_size); the `remainder` trailing
    # planes that do not fill a slab are handled separately below.
    islabs = jnp.stack(jnp.split(jnp.arange(0, slab_size * num_slabs), num_slabs))
    rslabs = jnp.stack(jnp.split(ratios[0 : slab_size * num_slabs], num_slabs))
    irslabs = jnp.stack((islabs, rslabs), axis=1)
    func = lambda irslab: _partial_forward_project(
        volume, uu, vv, irslab, config, input_2d=input_2d
    )
    # jax.checkpoint used to avoid excessive memory requirements
    proj = jnp.sum(jax.lax.map(jax.checkpoint(func), irslabs), axis=0)
    if remainder:
        # Project the trailing planes not included in any full slab.
        irslab = jnp.stack((jnp.arange(slab_size * num_slabs, N), ratios[-remainder:]))
        proj += jax.checkpoint(func)(irslab)
    return proj


================================================
FILE: scico/linop/xray/_axitom/utilities.py
================================================
"""
This file is a modified version of "utilities.py" from the
[AXITOM](https://github.com/PolymerGuy/AXITOM) package.

Utilities

This module contains various utility functions that do not have any other
obvious home.
"""

import numpy as np

import jax.numpy as jnp


def _find_center_of_gravity_in_projection(projection, background_internsity=0.9):
    """Find axis of rotation in the projection.
This is done by binarization of the image into object and background and determining the center of gravity of the object. Parameters ---------- projection : ndarray The projection, normalized between 0 and 1 background_internsity : float The background intensity threshold Returns ------- float64 The center of gravity in the u-direction float64 The center of gravity in the v-direction """ m, n = np.shape(projection) binary_proj = np.zeros_like(projection, dtype=np.float) binary_proj[projection < background_internsity] = 1.0 area_x = np.sum(binary_proj, axis=1) area_y = np.sum(binary_proj, axis=0) non_zero_rows = np.arange(n)[area_y != 0.0] non_zero_columns = np.arange(m)[area_x != 0.0] # Now removing all columns that does not intersect the object object_pixels = binary_proj[non_zero_columns, :][:, non_zero_rows] area_x = area_x[non_zero_columns] area_y = area_y[non_zero_rows] xs, ys = np.meshgrid(non_zero_rows, non_zero_columns) # Determine center of gravity center_of_grav_x = np.average(np.sum(xs * object_pixels, axis=1) / area_x) - n / 2.0 center_of_grav_y = np.average(np.sum(ys * object_pixels, axis=0) / area_y) - m / 2.0 return center_of_grav_x, center_of_grav_y def find_center_of_rotation(projection, background_internsity=0.9, method="center_of_gravity"): """Find the axis of rotation of the object in the projection Parameters ---------- projection : ndarray The projection, normalized between 0 and 1 background_internsity : float The background intensity threshold method : string The background intensity threshold Returns ------- float64 The center of gravity in the v-direction float64 The center of gravity in the u-direction """ if projection.ndim != 2: raise ValueError("Invalid projection shape. 
It has to be a 2d numpy array") if method == "center_of_gravity": center_v, center_u = _find_center_of_gravity_in_projection( projection, background_internsity ) else: raise ValueError("Invalid method") return center_v, center_u def rotate_coordinates(xs_array, ys_array, angle_rad): """Rotate coordinate arrays by a given angle Parameters ---------- xs_array : ndarray Two dimensional coordinate array with x-coordinates ys_array : ndarray Two dimensional coordinate array with y-coordinates angle_rad : float Rotation angle in radians Returns ------- ndarray The rotated x-coordinates ndarray The rotated x-coordinates """ rx = xs_array * jnp.cos(angle_rad) + ys_array * jnp.sin(angle_rad) ry = -xs_array * jnp.sin(angle_rad) + ys_array * jnp.cos(angle_rad) return rx, ry ================================================ FILE: scico/linop/xray/_util.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2024-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for CT data.""" from typing import Optional, Tuple import numpy as np import jax.numpy as jnp import jax.scipy.spatial.transform as jsst from jax import Array from jax.image import ResizeMethod, scale_and_translate from jax.scipy.ndimage import map_coordinates from jax.typing import ArrayLike try: import scico.linop.xray.astra have_astra = True except ModuleNotFoundError as e: if e.name == "astra": have_astra = False else: raise e import scipy.spatial.transform as sst def image_centroid(v: ArrayLike, center_offset: bool = False) -> Tuple[float, ...]: """Compute the centroid of an image. Compute the centroid of an image or higher-dimensional array. Args: v: Array for which centroid is to be computed. center_offset: If ``True``, compute centroid coordinates relative to the spatial center of the image. 
Returns: Tuple of centroid coordinates. """ if center_offset: offset = (jnp.array(v.shape, dtype=jnp.float32) - 1.0) / 2.0 else: offset = jnp.zeros((v.ndim,), dtype=jnp.float32) g1d = [jnp.arange(size, dtype=jnp.float32) - offset[idx] for idx, size in enumerate(v.shape)] g = jnp.meshgrid(*g1d, sparse=True, indexing="ij") m00 = v.astype(jnp.float32).sum() if m00 == 0.0: c = (0.0,) * v.ndim else: c = tuple([(jnp.sum(v * g[idx]) / m00).item() for idx in range(v.ndim)]) return c def center_image( v: ArrayLike, axes: Optional[Tuple[int, ...]] = None, method: ResizeMethod = ResizeMethod.LANCZOS3, ) -> Array: """Translate an image to center the centroid. Translate an image (or higher-dimensional array) so that the centroid is at the spatial center of the image grid. Args: v: Array to be centered. axes: Array axes on which centering is to be applied. Defaults to all axes. method: Interpolation method for image translation. Returns: Centered array. """ if axes is None: axes = tuple(range(v.ndim)) c = jnp.array(image_centroid(v, center_offset=True), dtype=jnp.float32) scale = jnp.ones((v.ndim,), dtype=jnp.float32)[jnp.array(axes)] trans = -c[jnp.array(axes)] cv = scale_and_translate(v, v.shape, axes, scale, trans, method=method) return cv def rotate_volume( vol: ArrayLike, rot: jsst.Rotation, x: Optional[ArrayLike] = None, y: Optional[ArrayLike] = None, z: Optional[ArrayLike] = None, center: Optional[ArrayLike] = None, ) -> Array: """Rotate a 3D array. Rotate a 3D array as specified by an instance of :class:`~jax.scipy.spatial.transform.Rotation`. Any axis coordinates that are not specified default to a range corresponding to the size of the array on that axis, starting at zero. Args: vol: Array to be rotated. rot: Rotation specification. x: Coordinates for :code:`x` axis (axis 0). y: Coordinates for :code:`y` axis (axis 1). z: Coordinates for :code:`z` axis (axis 2). center: A 3-vector specifying the center of rotation. Defaults to the center of the array. 
Returns: Rotated array. """ shape = vol.shape if x is None: x = jnp.arange(shape[0]) if y is None: y = jnp.arange(shape[1]) if z is None: z = jnp.arange(shape[2]) if center is None: center = (jnp.array(shape, dtype=jnp.float32) - 1.0) / 2.0 gx, gy, gz = jnp.meshgrid(x - center[0], y - center[1], z - center[2], indexing="ij") crd = jnp.stack((gx.ravel(), gy.ravel(), gz.ravel())) rot_crd = rot.as_matrix() @ crd + center[:, jnp.newaxis] # faster than rot.apply(crd.T) rot_vol = map_coordinates(vol, rot_crd.reshape((3,) + shape), order=1) return rot_vol def image_alignment_rotation( img: ArrayLike, max_angle: float = 2.5, angle_step: float = 0.025, center_factor: float = 5e-3 ) -> float: r"""Estimate an image alignment rotation. Estimate the rotation that best aligns vertical straight lines in the image with the vertical axis. The approach is roughly based on that used in the :code:`find_img_rotation_2D` function in the `cSAXS base package` released by the CXS group at the Paul Scherrer Institute, which finds the rotation angle that results in the sparsest column sum according to the sparsity measure proposed in Sec 3.1 of :cite:`hoyer-2004-nonnegative`. (Note that an :math:`\ell_1` norm sparsity measure is not suitable for this purpose since it is, in typical cases, appropximately invariant to the rotation angle.) The implementation here uses the plain ratio of :math:`\ell_1` and :math:`\ell_2` norms as a sparsity measure, more efficiently computes the column sums at different angles by exploiting the 2D X-ray transform, and includes a small bias for smaller angle rotations that improves performance when a range of rotation angles have the same sparsity measure. Args: img: Array of pixel values. max_angle: Maximum angle (negative and positive) to test, in degrees. angle_step: Increment in angle values for range of angles to test, in degrees. center_factor: The angle multiplied by this scalar is added to the sparsity measure to slightly prefer smaller-angle solutions. 
Returns: Rotation angle (in degrees) providing best alignment with the vertical (0) axis. Notes: The number number of detector pixels for the 2D X-ray transform is chosen based on the shape :math:`(N_0, N_1)` of :code:`img` and the value :math:`\theta` of parameter :code:`max_angle`, as indicated in Fig. 1. .. figure:: /figures/img_align.svg :align: center :width: 40% Fig 1. Calculation of the number of detector pixels for the 2D X-ray transform. """ if not have_astra: raise RuntimeError("Package astra is required for use of this function.") angles = np.arange(-max_angle, max_angle, angle_step) max_angle_rad = max_angle * np.pi / 180 # choose det_count so that projected image is within the detector bounds det_count = int( 1.05 * (img.shape[0] * np.sin(max_angle_rad) + img.shape[1] * np.cos(max_angle_rad)) ) A = scico.linop.xray.astra.XRayTransform2D( img.shape, det_count=det_count, det_spacing=1.0, angles=angles * np.pi / 180.0, ) y = A @ jnp.abs(img) # compute the ℓ1/ℓ2 norm of the projection for each view angle cost = jnp.sum(jnp.abs(y), axis=1) / jnp.sqrt(jnp.sum(y**2, axis=1)) ext_cost = cost + center_factor * (cost.max() - cost.min()) * jnp.abs(angles) idx = jnp.argmin(ext_cost) return angles[idx] def volume_alignment_rotation( vol: ArrayLike, xslice: Optional[int] = None, yslice: Optional[int] = None, max_angle: float = 2.5, angle_step: float = 0.025, center_factor: float = 5e-3, ) -> jsst.Rotation: r"""Estimate a volume alignment rotation. Estimate the 3D rotation that best aligns planar structures in a volume with the x-y (0-1) plane. The algorithm is based on independent rotation angle estimates, obtained using :func:`image_alignment_rotation`, within 2D slices in the x-z (0-2) and y-z (1-2) planes. These estimates are integrated into a combined 3D rotation specification as explained in the technical note below. Args: vol: Array of voxel values. xslice: Index of slice on axis 0. yslice: Index of slice on axis 1. 
max_angle: Maximum angle (negative and positive) to test, in degrees. angle_step: Increment in angle values for range of angles to test, in degrees. center_factor: The angle multiplied by this scalar is added to the sparsity measure to slightly prefer smaller-angle solutions. Returns: Rotation object. Notes: The estimation of the 3D rotation required to align planar structure in the volume with the x-y (0-1) plane is approached by estimating the 3D normal vector to this structure, illustrated in Fig. 1. The independent rotation angle estimates with the x-z (0-2) and y-z (1-2) planes are exploited as estimates (after a 90° rotation of each) as estimates of the projections of this normal vector into the x-z (0-2) and y-z (1-2) planes, illustrated in Figs. 2 and 3 respectively. .. figure:: /figures/vol_align_xyz.svg :align: center :width: 60% Fig 1. 3D orientation of the normal to the plane that is desired to be aligned with the x-y plane. .. list-table:: :width: 100 * - .. figure:: /figures/vol_align_xz.svg :align: center :width: 100% Fig 2. Projection of the normal onto the x-z plane. - .. figure:: /figures/vol_align_yz.svg :align: center :width: 100% Fig 3. Projection of the normal onto the y-z plane. It can be observed from these figures that .. math:: x &= r_x \cos (\theta_x) \\ y &= r_y \cos (\theta_y) \\ z &= r_x \sin (\theta_x) = r_y \sin (\theta_y) \;, where :math:`(x, y, z)` are the coordinates of the normal vector. We can write .. math:: r_x = \frac{z}{\sin(\theta_x)} \quad \text{and} \quad r_y = \frac{z}{\sin(\theta_y)} \;, and therefore .. math:: x = z \cot (\theta_x) \quad \text{and} \quad y = z \cot (\theta_y) \;. Since :math:`(x, y, z) = z (\cot (\theta_x), \cot (\theta_y), 1)` it is clear that the choice of :math:`z` only affects the norm of the vector, and can therefore be set to unity. The rotation of this vector is then determined by computing the rotation required to align it (after normalization) with the :math:`z` axis :math:`(0, 0, 1)`. 
""" # x, y, z volume axes correspond to axes 0, 1, 2 if xslice is None: xslice = vol.shape[0] // 2 # default to central slice if yslice is None: yslice = vol.shape[1] // 2 # default to central slice # projected angles of normal to plane angles identified in yz and xz slices angle_y = ( (90 - image_alignment_rotation(vol[xslice], max_angle=max_angle, angle_step=angle_step)) * np.pi / 180 ) angle_x = ( (90 - image_alignment_rotation(vol[:, yslice], max_angle=max_angle, angle_step=angle_step)) * np.pi / 180 ) # unit vector normal to plane vec = np.array([1.0 / np.tan(angle_x), 1.0 / np.tan(angle_y), 1.0]) vec /= np.linalg.norm(vec) # rotation required to align unit vector with z axis r = sst.Rotation.align_vectors(vec, np.array([0, 0, 1]))[0] # jax.scipy.spatial.transform.Rotation does not have align_vectors method return jsst.Rotation.from_quat(r.as_quat()) ================================================ FILE: scico/linop/xray/_xray.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """X-ray transform classes.""" from functools import partial from typing import Optional, Tuple from warnings import warn import numpy as np import jax import jax.numpy as jnp from jax.typing import ArrayLike import scico.numpy as snp from scico.numpy.util import is_scalar_equiv from scico.typing import Shape from scipy.spatial.transform import Rotation from .._linop import LinearOperator class XRayTransform2D(LinearOperator): r"""Parallel ray, single axis, 2D X-ray projector. This implementation approximates the projection of each rectangular pixel as a boxcar function (whereas the exact projection is a trapezoid). 
Detector pixels are modeled as bins (rather than points) and this approximation allows fast calculation of the contribution of each pixel to each bin because the integral of the boxcar is simple. By requiring the width of a projected pixel to be less than or equal to the bin width (which is defined to be 1.0), we ensure that each pixel contributes to at most two bins, which accelerates the accumulation of pixel values into bins (equivalently, makes the linear operator sparse). Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather than 1) in order to satisfy the aforementioned spacing requirement. `x0`, `dx`, and `y0` should be expressed in units such that the detector spacing `dy` is 1.0. """ def __init__( self, input_shape: Shape, angles: ArrayLike, x0: Optional[ArrayLike] = None, dx: Optional[ArrayLike] = None, y0: Optional[float] = None, det_count: Optional[int] = None, ): r""" Args: input_shape: Shape of input array. angles: (num_angles,) array of angles in radians. Viewing an (M, N) array as a matrix with M rows and N columns, an angle of 0 corresponds to summing rows, an angle of pi/2 corresponds to summing columns, and an angle of pi/4 corresponds to summing along antidiagonals. x0: (x, y) position of the corner of the pixel `im[0,0]`. By default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`. dx: Image pixel side length in x- and y-direction (axis 0 and 1 respectively). Must be set so that the width of a projected pixel is never larger than 1.0. By default, [:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`]. y0: Location of the edge of the first detector bin. By default, `-det_count / 2` det_count: Number of elements in detector. If ``None``, defaults to the size of the diagonal of `input_shape`. 
""" self.input_shape = input_shape self.angles = angles self.nx = tuple(input_shape) if dx is None: dx = 2 * (np.sqrt(2) / 2,) if is_scalar_equiv(dx): dx = 2 * (dx,) self.dx = dx # check projected pixel width assumption Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles))) Pdiag1 = np.abs(Pdx[0] + Pdx[1]) Pdiag2 = np.abs(Pdx[0] - Pdx[1]) max_width: float = np.max(np.maximum(Pdiag1, Pdiag2)) if max_width > 1: warn( f"A projected pixel has width {max_width} > 1.0, " "which will reduce projector accuracy." ) if x0 is None: x0 = -(np.array(self.nx) * self.dx) / 2 self.x0 = x0 if det_count is None: det_count = int(np.ceil(np.linalg.norm(input_shape))) self.det_count = det_count self.ny = det_count self.output_shape = (len(angles), det_count) if y0 is None: y0 = -self.ny / 2 self.y0 = y0 self.dy = 1.0 self.fbp_filter: Optional[snp.Array] = None self.fbp_mask: Optional[snp.Array] = None super().__init__( input_shape=self.input_shape, input_dtype=np.float32, output_shape=self.output_shape, output_dtype=np.float32, eval_fn=self.project, adj_fn=self.back_project, ) def project(self, im: ArrayLike) -> snp.Array: """Compute X-ray projection, equivalent to `H @ im`. Args: im: Input array representing the image to project. """ return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) def back_project(self, y: ArrayLike) -> snp.Array: """Compute X-ray back projection, equivalent to `H.T @ y`. Args: y: Input array representing the sinogram to back project. """ return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) def fbp(self, y: ArrayLike) -> snp.Array: r"""Compute filtered back projection (FBP) inverse of projection. Compute the filtered back projection inverse by filtering each row of the sinogram with the filter defined in (61) in :cite:`kak-1988-principles` and then back projecting. 
The projection angles are assumed to be evenly spaced in :math:`[0, \pi)`; reconstruction quality may be poor if this assumption is violated. Poor quality reconstructions should also be expected when `dx[0]` and `dx[1]` are not equal. Args: y: Input projection, (num_angles, N). Returns: FBP inverse of projection. """ N = y.shape[1] if self.fbp_filter is None: nvec = jnp.arange(N) - (N - 1) // 2 self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) if self.fbp_mask is None: unit_sino = jnp.ones(self.output_shape, dtype=np.float32) # Threshold is multiplied by 0.99... fudge factor to account for numerical errors # in back projection. self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore # Apply ramp filter in the frequency domain, padding to avoid # boundary effects h = self.fbp_filter hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ :, (N - 1) // 2 : -(N - 1) // 2 ].real.astype(jnp.float32) x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore return x @staticmethod def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: """Compute coefficients of ramp filter used in FBP. Compute coefficients of ramp filter used in FBP, as defined in (61) in :cite:`kak-1988-principles`. Args: x: Sampling locations at which to compute filter coefficients. tau: Sampling rate. Returns: Spatial-domain coefficients of ramp filter. """ # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0) # is included to avoid division by zero warnings when x == 1 # since np.where evaluates all values for both True and False # branches. 
return jnp.where( x == 0, 1.0 / (4.0 * tau**2), jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), ) @staticmethod @partial(jax.jit, static_argnames=["ny"]) def _project( im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike ) -> snp.Array: r"""Compute X-ray projection. Args: im: Input array, (M, N). x0: (x, y) position of the corner of the pixel im[0,0]. dx: Pixel side length in x- and y-direction. Units are such that the detector bins have length 1.0. y0: Location of the edge of the first detector bin. ny: Number of detector bins. angles: (num_angles,) array of angles in radians. Pixels are projected onto unit vectors pointing in these directions. """ nx = im.shape inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) # avoid incompatible types in the .add (scatter operation) weights = weights.astype(im.dtype) # Handle out of bounds indices by setting weight to zero weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) y = ( jnp.zeros((len(angles), ny), dtype=im.dtype) .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] .add(im * weights_valid) ) weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid) return y @staticmethod @partial(jax.jit, static_argnames=["nx"]) def _back_project( y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike ) -> snp.Array: r"""Compute X-ray back projection. Args: y: Input projection, (num_angles, N). x0: (x, y) position of the corner of the pixel im[0,0]. dx: Pixel side length in x- and y-direction. Units are such that the detector bins have length 1.0. nx: Shape of back projection. y0: Location of the edge of the first detector bin. angles: (num_angles,) array of angles in radians. Pixels are projected onto units vectors pointing in these directions. 
""" ny = y.shape[1] inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) # Handle out of bounds indices by setting weight to zero weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) # the idea: [y[0, inds[0]], y[1, inds[1]], ...] HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0) weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) HTy = HTy + jnp.sum( y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0 ) return HTy.astype(jnp.float32) @staticmethod @partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) def _calc_weights( x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float ) -> Tuple[snp.Array, snp.Array]: """ Args: x0: Location of the corner of the pixel im[0,0]. dx: Pixel side length in x- and y-direction. Units are such that the detector bins have length 1.0. nx: Input image shape. angles: (num_angles,) array of angles in radians. Pixels are projected onto units vectors pointing in these directions. (This argument is `vmap`ed.) y0: Location of the edge of the first detector bin. """ u = [jnp.cos(angles), jnp.sin(angles)] Px0 = x0[0] * u[0] + x0[1] * u[1] - y0 Pdx = [dx[0] * u[0], dx[1] * u[1]] Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]])) Px = ( Pxmin + Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1) + Pdx[1] * jnp.arange(nx[1]).reshape(1, -1) ) # detector bin inds inds = jnp.floor(Px).astype(int) # weights Pdx = jnp.array(u) * jnp.array(dx) diag1 = jnp.abs(Pdx[0] + Pdx[1]) diag2 = jnp.abs(Pdx[0] - Pdx[1]) w = jnp.max(jnp.array([diag1, diag2])) f = jnp.min(jnp.array([diag1, diag2])) width = (w + f) / 2 distance_to_next = 1 - (Px - inds) # always in (0, 1] weights = jnp.minimum(distance_to_next, width) / width return inds, weights class XRayTransform3D(LinearOperator): r"""General-purpose, 3D, parallel ray X-ray projector. 
    This projector approximates cubic voxels projecting onto rectangular
    pixels and provides a back projector that is the exact adjoint of the
    forward projector. It is written purely in JAX, allowing it to run on
    either CPU or GPU and minimizing host copies.

    Warning: This class is experimental and may be up to ten times slower
    than :class:`scico.linop.xray.astra.XRayTransform3D`.

    For each view, the projection geometry is specified by an array with
    shape (2, 4) that specifies a :math:`2 \times 3` projection matrix and
    a :math:`2 \times 1` offset vector. Denoting the matrix by
    :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel at
    array index `(i, j, k)` has its center projected to the detector
    coordinates

    .. math::
        \mathbf{M} \begin{bmatrix}
        i + \frac{1}{2} \\ j + \frac{1}{2} \\ k + \frac{1}{2}
        \end{bmatrix} + \mathbf{t} \,.

    The detector pixel at index `(i, j)` covers detector coordinates
    :math:`[i, i+1) \times [j, j+1)`.

    :meth:`XRayTransform3D.matrices_from_euler_angles` can help to make
    these geometry arrays.
    """

    def __init__(
        self,
        input_shape: Shape,
        matrices: ArrayLike,
        det_shape: Shape,
    ):
        r"""
        Args:
            input_shape: Shape of input image.
            matrices: (num_views, 2, 4) array of homogeneous projection
                matrices.
            det_shape: Shape of detector.
        """

        self.input_shape: Shape = input_shape
        self.matrices = jnp.asarray(matrices, dtype=np.float32)
        self.det_shape = det_shape
        self.output_shape = (len(matrices), *det_shape)
        super().__init__(
            input_shape=input_shape,
            output_shape=self.output_shape,
            eval_fn=self.project,
            adj_fn=self.back_project,
        )

    def project(self, im: ArrayLike) -> snp.Array:
        """Compute X-ray projection."""
        return XRayTransform3D._project(im, self.matrices, self.det_shape)

    def back_project(self, proj: ArrayLike) -> snp.Array:
        """Compute X-ray back projection."""
        return XRayTransform3D._back_project(proj, self.matrices, self.input_shape)

    @staticmethod
    def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array:
        r"""
        Args:
            im: Input image.
            matrices: (num_views, 2, 4) array of homogeneous projection
                matrices.
            det_shape: Shape of detector.
        """
        # The volume is processed in chunks of at most MAX_SLICE_LEN slices
        # along axis 0; each chunk's contribution is accumulated into the
        # projection of the current view.
        MAX_SLICE_LEN = 10
        slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN))

        num_views = len(matrices)
        proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype)
        for view_ind, matrix in enumerate(matrices):
            for slice_offset in slice_offsets:
                proj = proj.at[view_ind].set(
                    XRayTransform3D._project_single(
                        im[slice_offset : slice_offset + MAX_SLICE_LEN],
                        matrix,
                        proj[view_ind],
                        slice_offset=slice_offset,
                    )
                )
        return proj

    @staticmethod
    @partial(jax.jit, donate_argnames="proj")
    def _project_single(
        im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0
    ) -> snp.Array:
        r"""
        Args:
            im: Input image (a chunk of slices along axis 0).
            matrix: (2, 4) homogeneous projection matrix.
            proj: Single-view projection to accumulate into; it is donated
                to the jitted function and the updated array is returned.
            slice_offset: Axis-0 index, within the full volume, of the
                first slice of `im`.
        """
        ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(
            im.shape, matrix, proj.shape, slice_offset
        )
        # Scatter-add each voxel into the 2 x 2 block of detector pixels
        # starting at ul_ind; out-of-range indices are dropped.
        proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode="drop")
        proj = proj.at[ul_ind[0] + 1, ul_ind[1]].add(ur_weight * im, mode="drop")
        proj = proj.at[ul_ind[0], ul_ind[1] + 1].add(ll_weight * im, mode="drop")
        proj = proj.at[ul_ind[0] + 1, ul_ind[1] + 1].add(lr_weight * im, mode="drop")
        return proj

    @staticmethod
    def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:
        r"""
        Args:
            proj: Input (set of) projection(s).
            matrices: (num_views, 2, 4) array of homogeneous projection
                matrices.
            input_shape: Shape of desired back projection.
        """
        # Same axis-0 chunking as in _project.
        MAX_SLICE_LEN = 10
        slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN))

        HTy = jnp.zeros(input_shape, dtype=proj.dtype)
        for view_ind, matrix in enumerate(matrices):
            for slice_offset in slice_offsets:
                HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set(
                    XRayTransform3D._back_project_single(
                        proj[view_ind],
                        matrix,
                        HTy[slice_offset : slice_offset + MAX_SLICE_LEN],
                        slice_offset=slice_offset,
                    )
                )
            HTy.block_until_ready()  # prevent OOM

        return HTy

    @staticmethod
    @partial(jax.jit, donate_argnames="HTy")
    def _back_project_single(
        y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0
    ) -> snp.Array:
        r"""Accumulate the back projection of one view into a volume chunk.

        Args:
            y: Single-view projection.
            matrix: (2, 4) homogeneous projection matrix.
            HTy: Volume chunk (slices along axis 0) to accumulate into; it
                is donated to the jitted function and the updated array is
                returned.
            slice_offset: Axis-0 index, within the full volume, of the
                first slice of `HTy`.
        """
        ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(
            HTy.shape, matrix, y.shape, slice_offset
        )
        # Gather from the same 2 x 2 pixel block used in _project_single;
        # weights of out-of-bounds pixels are zero (see _calc_weights).
        HTy = HTy + y[ul_ind[0], ul_ind[1]] * ul_weight
        HTy = HTy + y[ul_ind[0] + 1, ul_ind[1]] * ur_weight
        HTy = HTy + y[ul_ind[0], ul_ind[1] + 1] * ll_weight
        HTy = HTy + y[ul_ind[0] + 1, ul_ind[1] + 1] * lr_weight
        return HTy

    @staticmethod
    def _calc_weights(
        input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0
    ) -> snp.Array:
        """Compute per-voxel detector indices and interpolation weights.

        Returns the index of the first pixel of the 2 x 2 block of detector
        pixels touched by each voxel's footprint, together with the four
        corresponding weights (zero for pixels outside `det_shape`).
        """
        # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
        x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5  # (3, ...)
        x = x.at[0].add(slice_offset)

        Px = jnp.stack(
            (
                matrix[0, 0] * x[0] + matrix[0, 1] * x[1] + matrix[0, 2] * x[2] + matrix[0, 3],
                matrix[1, 0] * x[0] + matrix[1, 1] * x[1] + matrix[1, 2] * x[2] + matrix[1, 3],
            )
        )  # (2, ...)
# calculate weight on 4 intersecting pixels w = 0.5 # assumed <= 1.0 left_edge = Px - w / 2 to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w) ul_ind = jnp.floor(left_edge).astype("int32") ul_weight = to_next[0] * to_next[1] * (1 / w**2) ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2) ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2) lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2) # set weights to zero out of bounds ul_weight = jnp.where( (ul_ind[0] >= 0) * (ul_ind[0] < det_shape[0]) * (ul_ind[1] >= 0) * (ul_ind[1] < det_shape[1]), ul_weight, 0.0, ) ur_weight = jnp.where( (ul_ind[0] + 1 >= 0) * (ul_ind[0] + 1 < det_shape[0]) * (ul_ind[1] >= 0) * (ul_ind[1] < det_shape[1]), ur_weight, 0.0, ) ll_weight = jnp.where( (ul_ind[0] >= 0) * (ul_ind[0] < det_shape[0]) * (ul_ind[1] + 1 >= 0) * (ul_ind[1] + 1 < det_shape[1]), ll_weight, 0.0, ) lr_weight = jnp.where( (ul_ind[0] + 1 >= 0) * (ul_ind[0] + 1 < det_shape[0]) * (ul_ind[1] + 1 >= 0) * (ul_ind[1] + 1 < det_shape[1]), lr_weight, 0.0, ) return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight @staticmethod def matrices_from_euler_angles( input_shape: Shape, output_shape: Shape, seq: str, angles: ArrayLike, degrees: bool = False, voxel_spacing: ArrayLike = None, det_spacing: ArrayLike = None, ) -> snp.Array: """ Create a set of projection matrices from Euler angles. The input voxels will undergo the specified rotation and then be projected onto the global xy-plane. Args: input_shape: Shape of input image. output_shape: Shape of output (detector). str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'} for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and intrinsic rotations cannot be mixed in one function call. angles: (num_views, N), N = 1, 2, or 3 Euler angles. degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians. voxel_spacing: (3,) array giving the spacing of image voxels. 
Default: `[1.0, 1.0, 1.0]`. Experimental. det_spacing: (2,) array giving the spacing of detector pixels. Default: `[1.0, 1.0]`. Experimental. Returns: (num_views, 2, 4) array of homogeneous projection matrices. """ if voxel_spacing is None: voxel_spacing = np.ones(3) if det_spacing is None: det_spacing = np.ones(2) # make projection matrix: form a rotation matrix and chop off the last row matrices = Rotation.from_euler(seq, angles, degrees=degrees).as_matrix() matrices = matrices[:, :2, :] # (num_views, 2, 3) # handle scaling M_voxel = np.diag(voxel_spacing) # (3, 3) M_det = np.diag(1 / np.array(det_spacing)) # (2, 2) # idea: M_det * M * M_voxel, but with a leading batch dimension matrices = np.einsum("vmn,nn->vmn", matrices, M_voxel) matrices = np.einsum("mm,vmn->vmn", M_det, matrices) # add translation to line up the centers x0 = np.array(input_shape) / 2 t = -np.einsum("vmn,n->vm", matrices, x0) + np.array(output_shape) / 2 matrices = snp.concatenate((matrices, t[..., np.newaxis]), axis=2) return matrices ================================================ FILE: scico/linop/xray/abel.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Abel transform LinearOperator wrapping the pyabel package. Abel transform LinearOperator wrapping the `pyabel `_ package. """ import math from typing import Optional import numpy as np import jax import jax.numpy as jnp import jax.numpy.fft as jnfft import abel from scico.linop import LinearOperator from scico.typing import Shape from scipy.linalg import solve_triangular class AbelTransform(LinearOperator): r"""Abel transform based on `PyAbel `_. Perform Abel transform (parallel beam projection of cylindrically symmetric objects) for a 2D image. 
The input 2D image is assumed to be centered and left-right symmetric. """ def __init__(self, img_shape: Shape): """ Args: img_shape: Shape of the input image. """ self.proj_mat_quad = _pyabel_daun_get_proj_matrix(img_shape) super().__init__( input_shape=img_shape, output_shape=img_shape, input_dtype=np.float32, output_dtype=np.float32, adj_fn=self._adj, jit=True, ) def _eval(self, x: jax.Array) -> jax.Array: return _pyabel_transform(x, direction="forward", proj_mat_quad=self.proj_mat_quad).astype( self.output_dtype ) def _adj(self, x: jax.Array) -> jax.Array: # type: ignore return _pyabel_transform(x, direction="transpose", proj_mat_quad=self.proj_mat_quad).astype( self.input_dtype ) def inverse(self, y: jax.Array) -> jax.Array: """Perform inverse Abel transform. Args: y: Input image (assumed to be a result of an Abel transform). Returns: Output of inverse Abel transform. """ return _pyabel_transform(y, direction="inverse", proj_mat_quad=self.proj_mat_quad).astype( self.input_dtype ) def _pyabel_transform( x: jax.Array, direction: str, proj_mat_quad: jax.Array, symmetry_axis: Optional[list] = None ) -> jax.Array: """Apply Abel transforms (forward, inverse and transposed). This function contains code copied from `PyAbel `_. 
""" if symmetry_axis is None: symmetry_axis = [None] Q0, Q1, Q2, Q3 = get_image_quadrants( x, symmetry_axis=symmetry_axis, use_quadrants=(True, True, True, True) ) def transform_quad(data): if direction == "forward": return data.dot(proj_mat_quad) elif direction == "transpose": return data.dot(proj_mat_quad.T) elif direction == "inverse": return solve_triangular(proj_mat_quad.T, data.T).T else: ValueError("Unsupported direction") AQ0 = AQ1 = AQ2 = AQ3 = None AQ1 = transform_quad(Q1) if 1 not in symmetry_axis: AQ2 = transform_quad(Q2) if 0 not in symmetry_axis: AQ0 = transform_quad(Q0) if None in symmetry_axis: AQ3 = transform_quad(Q3) return put_image_quadrants( (AQ0, AQ1, AQ2, AQ3), original_image_shape=x.shape, symmetry_axis=symmetry_axis ) def _pyabel_daun_get_proj_matrix(img_shape: Shape) -> jax.Array: """Get single-quadrant projection matrix.""" proj_matrix = abel.daun.get_bs_cached( math.ceil(img_shape[1] / 2), degree=0, reg_type=None, strength=0, direction="forward", verbose=False, ) return jnp.array(proj_matrix) # Read abel.tools.symmetry module into a string. mod_file = abel.tools.symmetry.__file__ with open(mod_file, "r") as f: mod_str = f.read() # Replace numpy functions that touch the main arrays with corresponding jax.numpy functions mod_str = mod_str.replace("fftpack.", "jnfft.") mod_str = mod_str.replace("np.atleast_2d", "jnp.atleast_2d") mod_str = mod_str.replace("np.flip", "jnp.flip") mod_str = mod_str.replace("np.concat", "jnp.concat") # Exec the module extract defined functions from the exec scope scope = {"jnp": jnp, "jnfft": jnfft} exec(mod_str, scope) get_image_quadrants = scope["get_image_quadrants"] put_image_quadrants = scope["put_image_quadrants"] ================================================ FILE: scico/linop/xray/astra.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
# Details of the copyright and user license can be found in the
# 'LICENSE' file distributed with the package.

"""X-ray transform LinearOperators wrapping the ASTRA toolbox.

X-ray transform :class:`.LinearOperator` wrapping the parallel beam
projections in the
`ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.
This package provides both C and CUDA implementations of core
functionality, but note that use of the CUDA/GPU implementation is
expected to result in GPU-host-GPU memory copies when transferring
JAX arrays. Other JAX features such as automatic differentiation are
not available.

Functions here refer to three coordinate systems: world coordinates,
volume coordinates, and detector coordinates. World coordinates are 3D
coordinates representing a point in physical space. Volume coordinates
refer to a position in the reconstruction volume, where the voxel with
its intensity value stored at `vol[i, j, k]` has its center at volume
coordinate (i+0.5, j+0.5, k+0.5) and side lengths of 1. Detector
coordinates refer to a position on the detector array, and the pixel
at `det[i, j]` has its center at detector coordinates (i+0.5, j+0.5)
and side lengths of 1.
""" from typing import List, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing import jax from jax.typing import ArrayLike from scipy.spatial.transform import Rotation try: import astra except ModuleNotFoundError as e: if e.name == "astra": new_e = ModuleNotFoundError("Could not import astra; please install the ASTRA toolbox.") new_e.name = "astra" raise new_e from e else: raise e try: from collections import Iterable # type: ignore except ImportError: import collections # Monkey patching required because latest astra release uses old module path for Iterable collections.Iterable = collections.abc.Iterable # type: ignore from scico.linop import LinearOperator from scico.typing import Shape, TypeAlias VolumeGeometry: TypeAlias = dict ProjectionGeometry: TypeAlias = dict def set_astra_gpu_index(idx: Union[int, Sequence[int]]): """Set the index/indices of GPU(s) to be used by astra. Args: idx: Index or indices of GPU(s). """ astra.set_gpu_index(idx) def _project_coords( x_volume: np.ndarray, vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry ) -> np.ndarray: """ Project volume coordinates into detector coordinates based on ASTRA geometry objects. Args: x_volume: (..., 3) vector(s) of volume coordinates. vol_geom: ASTRA volume geometry object. proj_geom: ASTRA projection geometry object. Returns: (num_angles, ..., 2) array of detector coordinates corresponding to projections of the points in `x_volume`. 
""" det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"]) x_world = volume_coords_to_world_coords(x_volume, vol_geom=vol_geom) x_dets = [] for vec in proj_geom["Vectors"]: ray, d, u, v = vec[0:3], vec[3:6], vec[6:9], vec[9:12] x_det = project_world_coordinates(x_world, ray, d, u, v, det_shape) x_dets.append(x_det) return np.stack(x_dets) def project_world_coordinates( x: np.ndarray, ray: np.typing.ArrayLike, d: np.typing.ArrayLike, u: np.typing.ArrayLike, v: np.typing.ArrayLike, det_shape: Sequence[int], ) -> np.ndarray: """Project world coordinates along ray into the specified basis. Project world coordinates along `ray` into the basis described by `u` and `v` with center `d`. Args: x: (..., 3) vector(s) of world coordinates. ray: (3,) ray direction d: (3,) center of the detector u: (3,) vector from detector pixel (0,0) to (0,1), columns, x v: (3,) vector from detector pixel (0,0) to (1,0), rows, y Returns: (..., 2) vector(s) in the detector coordinates """ Phi = np.stack((ray, u, v), axis=1) x = x - d # express with respect to detector center alpha = np.linalg.pinv(Phi) @ x[..., :, np.newaxis] # (3,3) times (3,1) alpha = alpha[..., 0] # squash from (..., 3, 1) to (..., 3) Palpha = alpha[..., 1:] # throw away ray coordinate det_center_idx = ( np.array(det_shape)[::-1] / 2 - 0.5 ) # center of length-2 is index 0.5, length-3 -> index 1 ind_xy = Palpha + det_center_idx ind_ij = ind_xy[..., ::-1] return ind_ij def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: """Convert a volume coordinate into a world coordinate. Convert a volume coordinate into a world coordinate using ASTRA conventions. Args: idx: (..., 2) or (..., 3) vector(s) of index coordinates. vol_geom: ASTRA volume geometry object. Returns: (..., 2) or (..., 3) vector(s) of world coordinates. 
""" if "GridSliceCount" not in vol_geom: return _volume_index_to_astra_world_2d(idx, vol_geom) return _volume_index_to_astra_world_3d(idx, vol_geom) def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: """Convert a 2D volume coordinate into a 2D world coordinate.""" coord = idx[..., [1, 0]] # x:col, y:row, nx = np.array( # (x, y) order ( vol_geom["GridColCount"], vol_geom["GridRowCount"], ) ) opt = vol_geom["option"] dx = np.array( ( (opt["WindowMaxX"] - opt["WindowMinX"]) / nx[0], (opt["WindowMaxY"] - opt["WindowMinY"]) / nx[1], ) ) center_coord = nx / 2 - 0.5 # center of length-2 is index 0.5, center of length-3 is index 1 return (coord - center_coord) * dx def _volume_index_to_astra_world_3d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: """Convert a 3D volume coordinate into a 3D world coordinate.""" coord = idx[..., [2, 1, 0]] # x:col, y:row, z:slice nx = np.array( # (x, y, z) order ( vol_geom["GridColCount"], vol_geom["GridRowCount"], vol_geom["GridSliceCount"], ) ) opt = vol_geom["option"] dx = np.array( ( (opt["WindowMaxX"] - opt["WindowMinX"]) / nx[0], (opt["WindowMaxY"] - opt["WindowMinY"]) / nx[1], (opt["WindowMaxZ"] - opt["WindowMinZ"]) / nx[2], ) ) center_coord = nx / 2 - 0.5 # center of length-2 is index 0.5, center of length-3 is index 1 return (coord - center_coord) * dx class XRayTransform2D(LinearOperator): r"""2D parallel beam X-ray transform based on the ASTRA toolbox. Perform tomographic projection (also called X-ray projection) of an image at specified angles, using the `ASTRA toolbox `_. """ def __init__( self, input_shape: Shape, det_count: int, det_spacing: float, angles: np.ndarray, volume_geometry: Optional[List[float]] = None, device: str = "auto", ): """ Args: input_shape: Shape of the input array. det_count: Number of detector elements. See the `astra documentation `__ for more information. det_spacing: Spacing between detector elements. 
See the `astra documentation `__ for more information.. angles: Array of projection angles in radians. volume_geometry: Specification of the shape of the discretized reconstruction volume. Must either ``None``, in which case it is inferred from `input_shape`, or follow the syntax described in the `astra documentation `__. device: Specifies device for projection operation. One of ["auto", "gpu", "cpu"]. If "auto", a GPU is used if available, otherwise, the CPU is used. """ self.num_dims = len(input_shape) if self.num_dims != 2: raise ValueError( f"Only 2D projections are supported, but 'input_shape' is {input_shape}." ) if not isinstance(det_count, int): raise ValueError("Expected argument 'det_count' to be an int.") output_shape: Shape = (len(angles), det_count) # Set up all the ASTRA config self.det_spacing = det_spacing self.det_count = det_count self.angles: np.ndarray = np.array(angles) self.proj_geom: dict = astra.create_proj_geom( "parallel", det_spacing, det_count, self.angles ) self.proj_id: Optional[int] self.input_shape: tuple = input_shape if volume_geometry is None: self.vol_geom = astra.create_vol_geom(*input_shape) else: if len(volume_geometry) == 4: self.vol_geom = astra.create_vol_geom(*input_shape, *volume_geometry) else: raise ValueError( "Argument 'volume_geometry' must be a tuple of len 4." "Please see the astra documentation for details." ) if device in ["cpu", "gpu"]: # If cpu or gpu selected, attempt to comply (no checking to # confirm that a gpu is available to astra). self.device = device elif device == "auto": # If auto selected, use cpu or gpu depending on the default # jax device (for simplicity, no checking whether gpu is # available to astra when one is not available to jax). 
dev0 = jax.devices()[0] self.device = dev0.platform else: raise ValueError(f"Invalid 'device' specified; got {device}.") if self.device == "cpu": self.proj_id = astra.create_projector("line", self.proj_geom, self.vol_geom) elif self.device == "gpu": self.proj_id = astra.create_projector("cuda", self.proj_geom, self.vol_geom) # Wrap our non-jax function to indicate we will supply fwd/rev mode functions self._eval = jax.custom_vjp(self._proj) self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj) self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) # type: ignore super().__init__( input_shape=self.input_shape, output_shape=output_shape, input_dtype=np.float32, output_dtype=np.float32, adj_fn=self._adj, jit=False, ) def _proj(self, x: jax.Array) -> jax.Array: # apply the forward projector and generate a sinogram def f(x): x = _ensure_writeable(x) proj_id, result = astra.create_sino(x, self.proj_id) astra.data2d.delete(proj_id) return result return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x) def _bproj(self, y: jax.Array) -> jax.Array: # apply backprojector def f(y): y = _ensure_writeable(y) proj_id, result = astra.create_backprojection(y, self.proj_id) astra.data2d.delete(proj_id) return result return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y) def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: """Filtered back projection (FBP) reconstruction. Perform tomographic reconstruction using the filtered back projection (FBP) algorithm. Args: sino: Sinogram to reconstruct. filter_type: Select the filter to use. For a list of options see `cfg.FilterType` in the `ASTRA documentation `__. Returns: Reconstructed volume. 
""" def f(sino): sino = _ensure_writeable(sino) sino_id = astra.data2d.create("-sino", self.proj_geom, sino) # create memory for result rec_id = astra.data2d.create("-vol", self.vol_geom) # start to populate config cfg = astra.astra_dict("FBP_CUDA" if self.device == "gpu" else "FBP") cfg["ReconstructionDataId"] = rec_id cfg["ProjectorId"] = self.proj_id cfg["ProjectionDataId"] = sino_id cfg["option"] = {"FilterType": filter_type} # initialize algorithm; run alg_id = astra.algorithm.create(cfg) astra.algorithm.run(alg_id) # get the result out = astra.data2d.get(rec_id) # cleanup FBP-specific arra astra.algorithm.delete(alg_id) astra.data2d.delete(rec_id) astra.data2d.delete(sino_id) return out return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino) def convert_from_scico_geometry( in_shape: Shape, matrices: ArrayLike, det_shape: Shape ) -> np.ndarray: """Convert SCICO projection matrices into ASTRA "parallel3d_vec" vectors. For 3D arrays, in ASTRA, the dimensions go (slices, rows, columns) and (z, y, x); in SCICO, the dimensions go (x, y, z). In ASTRA, the x-grid (recon) is centered on the origin and the y-grid (projection) can move. In SCICO, the x-grid origin is the center of x[0, 0, 0], the y-grid origin is the center of y[0, 0]. See section "parallel3d_vec" in the `astra documentation `__. Args: in_shape: Shape of input image. matrices: (num_angles, 2, 4) array of homogeneous projection matrices. det_shape: Shape of detector. Returns: (num_angles, 12) vector array in the ASTRA "parallel3d_vec" convention. 
""" # ray is perpendicular to projection axes ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3]) # detector center comes from lifting the center index to 3D y_center = (np.array(det_shape) - 1) / 2 x_center = ( np.einsum("...mn,n->...m", matrices[..., :3], (np.array(in_shape) - 1) / 2) + matrices[..., 3] ) d = np.einsum("...mn,...m->...n", matrices[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2) u = matrices[:, 1, :3] v = matrices[:, 0, :3] # handle different axis conventions ray = ray[:, [2, 1, 0]] d = d[:, [2, 1, 0]] u = u[:, [2, 1, 0]] v = v[:, [2, 1, 0]] vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12) return vectors def _astra_to_scico_geometry(vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry) -> np.ndarray: """Convert ASTRA geometry objects into a SCICO projection matrix. Convert ASTRA volume and projection geometry into a SCICO X-ray projection matrix, assuming "parallel3d_vec" format. The approach is to locate 3 points in the volume domain, deduce the corresponding projection locations, and, then, solve a linear system to determine the affine relationship between them. Args: vol_geom: ASTRA volume geometry object. proj_geom: ASTRA projection geometry object. Returns: (num_angles, 2, 4) array of homogeneous projection matrices. """ x_volume = np.concatenate((np.zeros((1, 3)), np.eye(3)), axis=0) # (4, 3) x_dets = _project_coords(x_volume, vol_geom, proj_geom) # (num_angles, 4, 2) x_volume_aug = np.concatenate((x_volume, np.ones((4, 1))), axis=1) # (4, 4) matrices = [] for x_det in x_dets: M = np.linalg.solve(x_volume_aug, x_det).T np.testing.assert_allclose(M @ x_volume_aug[0], x_det[0]) matrices.append(M) return np.stack(matrices) def convert_to_scico_geometry( input_shape: Shape, det_count: Tuple[int, int], det_spacing: Optional[Tuple[float, float]] = None, angles: Optional[np.ndarray] = None, vectors: Optional[np.ndarray] = None, ) -> np.ndarray: """Convert X-ray geometry specification to a SCICO projection matrix. 
The approach is to locate 3 points in the volume domain, deduce the corresponding projection locations, and, then, solve a linear system to determine the affine relationship between them. Args: input_shape: Shape of the input array. det_count: Number of detector elements. See the `astra documentation `__ for more information. det_spacing: Spacing between detector elements. See the `astra documentation `__ for more information. angles: Array of projection angles in radians. This parameter is mutually exclusive with `vectors`. vectors: Array of ASTRA geometry specification vectors. This parameter is mutually exclusive with `angles`. Returns: (num_angles, 2, 4) array of homogeneous projection matrices. """ if angles is not None and vectors is not None: raise ValueError("Arguments 'angles' and 'vectors' are mutually exclusive.") if angles is None and vectors is None: raise ValueError("Exactly one of arguments 'angles' and 'vectors' must be provided.") vol_geom, proj_geom = XRayTransform3D.create_astra_geometry( input_shape, det_count, det_spacing=det_spacing, angles=angles, vectors=vectors ) return _astra_to_scico_geometry(vol_geom, proj_geom) class XRayTransform3D(LinearOperator): # pragma: no cover r"""3D parallel beam X-ray transform based on the ASTRA toolbox. Perform tomographic projection (also called X-ray projection) of a volume at specified angles, using the `ASTRA toolbox `_. The `3D geometries `__ "parallel3d" and "parallel3d_vec" are supported by this interface. Note that a CUDA GPU is required for the primary functionality of this class; if no GPU is available, initialization will fail with a :exc:`RuntimeError`. The volume is fixed with respect to the coordinate system, centered at the origin, as illustrated below: .. 
plot:: pyfigures/xray_3d_vol.py :align: center :include-source: False :show-source-link: False The voxels sides have unit length (in arbitrary units), which defines the scale for all other dimensions in the source-volume-detector configuration. Geometry axes `z`, `y`, and `x` correspond to volume array axes 0, 1, and 2 respectively. The projected array axes 0, 1, and 2 correspond respectively to detector rows, views, and detector columns. In the "parallel3d" case, the source and detector rotate clockwise about the `z` axis in the `x`-`y` plane, as illustrated below: .. plot:: pyfigures/xray_3d_ang.py :align: center :include-source: False :show-source-link: False :caption: Each radial arrow indicates the direction of the beam towards the detector (indicated in orange in the "light" display mode) and the arrow parallel to the detector indicates the direction of increasing pixel indices. In this case the `z` axis is in the same direction as the vertical/row axis of the detector and its projection corresponds to a vertical line in the center of the horizontal/column detector axis. Note that the view images must be displayed with the origin at the bottom left (i.e. vertically inverted from the top left origin image indexing convention) in order for the projections to correspond to the positive up/negative down orientation of the `z` axis in the figures here. In the "parallel3d_vec" case, each view is determined by the following vectors: .. list-table:: View definition vectors :widths: 10 90 * - :math:`\mb{r}` - Direction of the parallel beam * - :math:`\mb{d}` - Center of the detector * - :math:`\mb{u}` - Vector from detector pixel (0,0) to (0,1) (direction of increasing detector column index) * - :math:`\mb{v}` - Vector from detector pixel (0,0) to (1,0) (direction of increasing detector row index) Note that the components of these vectors are in `x`, `y`, `z` order, not the `z`, `y`, `x` order of the volume axes. .. 
plot:: pyfigures/xray_3d_vec.py :align: center :include-source: False :show-source-link: False Vector :math:`\mb{r}` is not illustrated to avoid cluttering the figure, but will typically be directed toward the center of the detector (i.e. in the direction of :math:`\mb{d}` in the figure.) Since the volume-detector distance does not have a geometric effect for a parallel-beam configuration, :math:`\mb{d}` may be set to the zero vector when the detector and beam centers coincide (e.g., as in the case of the "parallel3d" geometry). Note that the view images must be displayed with the origin at the bottom left (i.e. vertically inverted from the top left origin image indexing convention) in order for the row indexing of the projections to correspond to the direction of :math:`\mb{v}` in the figure. These vectors are concatenated into a single row vector :math:`(\mb{r}, \mb{d}, \mb{u}, \mb{v})` to form the full geometry specification for a single view, and multiple such row vectors are stacked to specify the geometry for a set of views. """ def __init__( self, input_shape: Shape, det_count: Tuple[int, int], det_spacing: Optional[Tuple[float, float]] = None, angles: Optional[np.ndarray] = None, vectors: Optional[np.ndarray] = None, ): """ Keyword arguments `det_spacing` and `angles` should be specified to use the "parallel3d" geometry, and keyword argument `vectors` should be specified to use the "parallel3d_vec" geometry. These parameters are mutually exclusive. Args: input_shape: Shape of the input array. det_count: Number of detector elements. See the `astra documentation `__ for more information. det_spacing: Spacing between detector elements. See the `astra documentation `__ for more information. angles: Array of projection angles in radians. This parameter is mutually exclusive with `vectors`. vectors: Array of ASTRA geometry specification vectors. This parameter is mutually exclusive with `angles`. 
Raises: RuntimeError: If a CUDA GPU is not available to the ASTRA toolbox. """ if not astra.use_cuda(): raise RuntimeError("CUDA GPU required but not available or not enabled.") if not ( (det_spacing is not None and angles is not None and vectors is None) or (vectors is not None and det_spacing is None and angles is None) ): raise ValueError( "Keyword arguments 'det_spacing' and 'angles', or keyword argument " "'vectors' must be specified, but not both." ) self.num_dims = len(input_shape) if self.num_dims != 3: raise ValueError( f"Only 3D projections are supported, but 'input_shape' is {input_shape}." ) if not isinstance(det_count, (list, tuple)) or len(det_count) != 2: raise ValueError("Expected argument 'det_count' to be a tuple with 2 elements.") if angles is not None and vectors is not None: raise ValueError("Arguments 'angles' and 'vectors' are mutually exclusive.") if angles is None and vectors is None: raise ValueError( "Exactly one of the arguments 'angles' and 'vectors' must be provided." 
) if angles is not None: Nview = angles.size self.angles: Optional[np.ndarray] = np.array(angles) self.vectors: Optional[np.ndarray] = None if vectors is not None: Nview = vectors.shape[0] self.vectors = np.array(vectors) self.angles = None output_shape: Shape = (det_count[0], Nview, det_count[1]) self.det_count = det_count assert isinstance(det_count, (list, tuple)) self.input_shape: tuple = input_shape self.vol_geom, self.proj_geom = self.create_astra_geometry( input_shape, det_count, det_spacing=det_spacing, angles=self.angles, vectors=self.vectors, ) # Wrap our non-jax function to indicate we will supply fwd/rev mode functions self._eval = jax.custom_vjp(self._proj) self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj) self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) # type: ignore super().__init__( input_shape=self.input_shape, output_shape=output_shape, input_dtype=np.float32, output_dtype=np.float32, adj_fn=self._adj, jit=False, ) @staticmethod def create_astra_geometry( input_shape: Shape, det_count: Tuple[int, int], det_spacing: Optional[Tuple[float, float]] = None, angles: Optional[np.ndarray] = None, vectors: Optional[np.ndarray] = None, ) -> Tuple[VolumeGeometry, ProjectionGeometry]: """Create ASTRA 3D geometry objects. Keyword arguments `det_spacing` and `angles` should be specified to use the "parallel3d" geometry, and keyword argument `vectors` should be specified to use the "parallel3d_vec" geometry. These parameters are mutually exclusive. Args: input_shape: Shape of the input array. det_count: Number of detector elements. See the `astra documentation `__ for more information. det_spacing: Spacing between detector elements. See the `astra documentation `__ for more information. angles: Array of projection angles in radians. vectors: Array of geometry specification vectors. 
Returns: A tuple `(vol_geom, proj_geom)` of ASTRA volume geometry and projection geometry objects. """ vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0]) if angles is not None: assert det_spacing is not None proj_geom = astra.create_proj_geom( "parallel3d", det_spacing[0], det_spacing[1], det_count[0], det_count[1], angles, ) else: proj_geom = astra.create_proj_geom( "parallel3d_vec", det_count[0], det_count[1], vectors ) return vol_geom, proj_geom def _proj(self, x: jax.Array) -> jax.Array: # apply the forward projector and generate a sinogram def f(x): x = _ensure_writeable(x) proj_id, result = astra.create_sino3d_gpu(x, self.proj_geom, self.vol_geom) astra.data3d.delete(proj_id) return result return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x) def _bproj(self, y: jax.Array) -> jax.Array: # apply backprojector def f(y): y = _ensure_writeable(y) proj_id, result = astra.create_backprojection3d_gpu(y, self.proj_geom, self.vol_geom) astra.data3d.delete(proj_id) return result return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y) def angle_to_vector(det_spacing: Tuple[float, float], angles: np.ndarray) -> np.ndarray: """Convert det_spacing and angles to vector geometry specification. Args: det_spacing: Spacing between detector elements. See the `astra documentation `__ for more information. angles: Array of projection angles in radians. Returns: Array of geometry specification vectors. """ vectors = np.zeros((angles.size, 12)) vectors[:, 0] = np.sin(angles) vectors[:, 1] = -np.cos(angles) vectors[:, 6] = np.cos(angles) * det_spacing[0] vectors[:, 7] = np.sin(angles) * det_spacing[0] vectors[:, 11] = det_spacing[1] return vectors def rotate_vectors(vectors: np.ndarray, rot: Rotation) -> np.ndarray: """Rotate geometry specification vectors. Rotate ASTRA "parallel3d_vec" geometry specification vectors. Args: vectors: Array of geometry specification vectors. 
rot: Rotation. Returns: Rotated geometry specification vectors. """ rot_vecs = vectors.copy() for k in range(0, 12, 3): rot_vecs[:, k : k + 3] = rot.apply(rot_vecs[:, k : k + 3]) return rot_vecs def _ensure_writeable(x): """Ensure that `x.flags.writeable` is ``True``, copying if needed.""" if hasattr(x, "flags"): # x is a numpy array if not x.flags.writeable: try: x.setflags(write=True) except ValueError: x = x.copy() else: # x is a jax array (which is immutable) x = np.array(x) return x ================================================ FILE: scico/linop/xray/svmbir.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """X-ray transform LinearOperator wrapping the svmbir package. X-ray transform :class:`.LinearOperator` wrapping the `svmbir `_ package. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. """ from typing import Optional, Tuple, Union import numpy as np import jax import scico.numpy as snp from scico.loss import Loss, SquaredL2Loss from scico.typing import Shape from .._diag import Diagonal, Identity from .._linop import LinearOperator try: import svmbir except ImportError: raise ImportError("Could not import svmbir; please install it.") class XRayTransform(LinearOperator): r"""X-ray transform based on svmbir. Perform tomographic projection of an image at specified angles, using the `svmbir `_ package. The `is_masked` option selects whether a valid region for projections (pixels outside this region are ignored when performing the projection) is active. 
This region of validity is also respected by :meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss` is initialized with a :class:`XRayTransform` with this option enabled. A brief description of the supported scanner geometries can be found in the `svmbir documentation `_. Parallel beam geometry and two different fan beam geometries are supported. .. list-table:: * - .. figure:: /figures/geom-parallel.png :align: center :width: 75% Fig 1. Parallel beam geometry. - .. figure:: /figures/geom-fan.png :align: center :width: 75% Fig 2. Curved fan beam geometry. """ def __init__( self, input_shape: Shape, angles: snp.Array, num_channels: int, center_offset: float = 0.0, is_masked: bool = False, geometry: str = "parallel", dist_source_detector: Optional[float] = None, magnification: Optional[float] = None, delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, ): """ The output of this linear operator is an array of shape `(num_angles, num_channels)` when input_shape is 2D, or of shape `(num_angles, num_slices, num_channels)` when input_shape is 3D, where `num_angles` is the length of the `angles` argument, and `num_slices` is inferred from the `input_shape` argument. Most of the the following arguments have the same name as and correspond to arguments of :func:`svmbir.project`. A brief summary of each is provided here, but the documentation for :func:`svmbir.project` should be consulted for further details. Args: input_shape: Shape of the input array. May be of length 2 (a 2D array) or 3 (a 3D array). When specifying a 2D array, the format for the input_shape is `(num_rows, num_cols)`. For a 3D array, the format for the input_shape is `(num_slices, num_rows, num_cols)`, where `num_slices` denotes the number of slices in the input, and `num_rows` and `num_cols` denote the number of rows and columns in a single slice of the input. A slice is a plane perpendicular to the axis of rotation of the tomographic system. 
At angle zero, each row is oriented along the X-rays (parallel beam) or the X-ray beam directed toward the detector center (fan beam). Note that `input_shape=(num_rows, num_cols)` and `input_shape=(1, num_rows, num_cols)` result in the same underlying projector. angles: Array of projection angles in radians, should be increasing. num_channels: Number of detector channels in the sinogram data. center_offset: Position of the detector center relative to the projection of the center of rotation onto the detector, in units of pixels. is_masked: If ``True``, the valid region of the image is determined by a mask defined as the circle inscribed within the image boundary. Otherwise, the whole image array is taken into account by projections. geometry: Scanner geometry, either "parallel", "fan-curved", or "fan-flat". Note that the `dist_source_detector` and `magnification` arguments must be provided for then fan beam geometries. dist_source_detector: Distance from X-ray focal spot to detectors in units of pixel pitch. Only used when geometry is "fan-flat" or "fan-curved". magnification: Magnification factor of the scanner geometry. Only used when geometry is "fan-flat" or "fan-curved". delta_channel: Detector channel spacing. delta_pixel: Spacing between image pixels in the 2D slice plane. """ self.angles = angles self.num_channels = num_channels self.center_offset = center_offset if len(input_shape) == 2: # 2D input self.svmbir_input_shape = (1,) + input_shape output_shape: Tuple[int, ...] = (len(angles), num_channels) self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2] elif len(input_shape) == 3: # 3D input self.svmbir_input_shape = input_shape output_shape = (len(angles), input_shape[0], num_channels) self.svmbir_output_shape = output_shape else: raise ValueError( f"Only 2D and 3D inputs are supported, but input_shape was {input_shape}." 
) self.is_masked = is_masked if self.is_masked: self.roi_radius = None else: self.roi_radius = max(self.svmbir_input_shape[1], self.svmbir_input_shape[2]) self.geometry = geometry self.dist_source_detector = dist_source_detector self.magnification = magnification if delta_channel is None: self.delta_channel = 1.0 else: self.delta_channel = delta_channel if self.geometry == "fan-curved" or self.geometry == "fan-flat": if self.dist_source_detector is None: raise ValueError( "Argument 'dist_source_detector' must be specified for fan beam geometry." ) if self.magnification is None: raise ValueError( "Argument 'magnification' must be specified for fan beam geometry." ) if delta_pixel is None: self.delta_pixel = self.delta_channel / self.magnification else: self.delta_pixel = delta_pixel elif self.geometry == "parallel": self.magnification = 1.0 if delta_pixel is None: self.delta_pixel = self.delta_channel else: self.delta_pixel = delta_pixel else: raise ValueError("Unspecified geometry {}.".format(self.geometry)) # Set up custom_vjp for _eval and _adj so jax.grad works on them. 
self._eval = jax.custom_vjp(self._proj_hcb) self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj_hcb) self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),)) # type: ignore super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=np.float32, output_dtype=np.float32, adj_fn=self._adj, jit=False, ) @staticmethod def _proj( x: snp.Array, angles: snp.Array, num_channels: int, center_offset: float = 0.0, roi_radius: Optional[float] = None, geometry: str = "parallel", dist_source_detector: Optional[float] = None, magnification: Optional[float] = None, delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, ) -> snp.Array: return snp.array( svmbir.project( np.array(x), np.array(angles), num_channels, verbose=0, center_offset=center_offset, roi_radius=roi_radius, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, delta_channel=delta_channel, delta_pixel=delta_pixel, ) ) def _proj_hcb(self, x): x = x.reshape(self.svmbir_input_shape) # callback wrapper for _proj y = jax.pure_callback( lambda x: self._proj( x, self.angles, self.num_channels, center_offset=self.center_offset, roi_radius=self.roi_radius, geometry=self.geometry, dist_source_detector=self.dist_source_detector, magnification=self.magnification, delta_channel=self.delta_channel, delta_pixel=self.delta_pixel, ), jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype), x, ) return y.reshape(self.output_shape) @staticmethod def _bproj( y: snp.Array, angles: snp.Array, num_rows: int, num_cols: int, center_offset: Optional[float] = 0.0, roi_radius: Optional[float] = None, geometry: str = "parallel", dist_source_detector: Optional[float] = None, magnification: Optional[float] = None, delta_channel: Optional[float] = None, delta_pixel: Optional[float] = None, ) -> snp.Array: return snp.array( 
svmbir.backproject( np.array(y), np.array(angles), num_rows=num_rows, num_cols=num_cols, verbose=0, center_offset=center_offset, roi_radius=roi_radius, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, delta_channel=delta_channel, delta_pixel=delta_pixel, ) ) def _bproj_hcb(self, y): y = y.reshape(self.svmbir_output_shape) # callback wrapper for _bproj x = jax.pure_callback( lambda y: self._bproj( y, self.angles, self.svmbir_input_shape[1], self.svmbir_input_shape[2], center_offset=self.center_offset, roi_radius=self.roi_radius, geometry=self.geometry, dist_source_detector=self.dist_source_detector, magnification=self.magnification, delta_channel=self.delta_channel, delta_pixel=self.delta_pixel, ), jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype), y, ) return x.reshape(self.input_shape) class SVMBIRExtendedLoss(Loss): r"""Extended squared :math:`\ell_2` loss with svmbir tomographic projector. Generalization of the weighted squared :math:`\ell_2` loss for a CT reconstruction problem, .. math:: \alpha \norm{\mb{y} - A(\mb{x})}_W^2 = \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to :class:`scico.linop.Identity`. The extended loss differs from a typical weighted squared :math:`\ell_2` loss as follows. When `positivity=True`, the prox projects onto the non-negative orthant and the loss is infinite if any element of the input is negative. When the `is_masked` option of the associated :class:`.XRayTransform` is ``True``, the reconstruction is computed over a masked region of the image as described in class :class:`.XRayTransform`. 
""" A: XRayTransform W: Union[Identity, Diagonal] def __init__( self, *args, scale: float = 0.5, prox_kwargs: Optional[dict] = None, positivity: bool = False, W: Optional[Diagonal] = None, **kwargs, ): r"""Initialize a :class:`SVMBIRExtendedLoss` object. Args: y: Sinogram measurement. A: Forward operator. scale: Scaling parameter. prox_kwargs: Dictionary of arguments passed to the :meth:`svmbir.recon` prox routine. Defaults to {"maxiter": 1000, "ctol": 0.001}. positivity: Enforce positivity in the prox operation. The loss is infinite if any element of the input is negative. W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ super().__init__(*args, scale=scale, **kwargs) # type: ignore if not isinstance(self.A, XRayTransform): raise ValueError("LinearOperator A must be a radon_svmbir.XRayTransform.") self.has_prox = True if prox_kwargs is None: prox_kwargs = {} default_prox_args = {"maxiter": 1000, "ctol": 0.001} default_prox_args.update(prox_kwargs) svmbir_prox_args = {} if "maxiter" in default_prox_args: svmbir_prox_args["max_iterations"] = default_prox_args["maxiter"] if "ctol" in default_prox_args: svmbir_prox_args["stop_threshold"] = default_prox_args["ctol"] self.svmbir_prox_args = svmbir_prox_args self.positivity = positivity if W is None: self.W = Identity(self.y.shape) elif isinstance(W, Diagonal): if snp.all(W.diagonal >= 0): self.W = W else: raise ValueError(f"The weights, W, must be non-negative.") else: raise TypeError(f"Argument 'W' must be None or a linop.Diagonal, got {type(W)}.") def __call__(self, x: snp.Array) -> float: if self.positivity and snp.sum(x < 0) > 0: return snp.inf else: return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() def prox(self, v: snp.Array, lam: float = 1, **kwargs) -> snp.Array: v = v.reshape(self.A.svmbir_input_shape) y = self.y.reshape(self.A.svmbir_output_shape) weights = self.W.diagonal.reshape(self.A.svmbir_output_shape) sigma_p = 
snp.sqrt(lam) if "v0" in kwargs and kwargs["v0"] is not None: v0: Union[float, np.ndarray] = np.reshape( np.array(kwargs["v0"]), self.A.svmbir_input_shape ) else: v0 = 0.0 # change: stop, mask-rad, init result = svmbir.recon( np.array(y), np.array(self.A.angles), weights=np.array(weights), prox_image=np.array(v), num_rows=self.A.svmbir_input_shape[1], num_cols=self.A.svmbir_input_shape[2], center_offset=self.A.center_offset, roi_radius=self.A.roi_radius, geometry=self.A.geometry, dist_source_detector=self.A.dist_source_detector, magnification=self.A.magnification, delta_channel=self.A.delta_channel, delta_pixel=self.A.delta_pixel, sigma_p=float(sigma_p), sigma_y=1.0, positivity=self.positivity, verbose=0, init_image=v0, **self.svmbir_prox_args, ) if np.sum(np.isnan(result)): raise ValueError("Result contains NaNs.") return snp.array(result.reshape(self.A.input_shape)) class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): r"""Weighted squared :math:`\ell_2` loss with svmbir tomographic projector. Weighted squared :math:`\ell_2` loss of a CT reconstruction problem, .. math:: \alpha \norm{\mb{y} - A(\mb{x})}_W^2 = \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to :class:`scico.linop.Identity`. """ def __init__( self, *args, prox_kwargs: Optional[dict] = None, **kwargs, ): r"""Initialize a :class:`SVMBIRSquaredL2Loss` object. Args: y: Sinogram measurement. A: Forward operator. scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. prox_kwargs: Dictionary of arguments passed to the :meth:`svmbir.recon` prox routine. Defaults to {"maxiter": 1000, "ctol": 0.001}. 
""" super().__init__(*args, **kwargs, prox_kwargs=prox_kwargs, positivity=False) if self.A.is_masked: raise ValueError( "Argument 'is_masked' must be False for the XRayTransform in SVMBIRSquaredL2Loss." ) ================================================ FILE: scico/linop/xray/symcone.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Cone beam X-ray transform for cylindrically symmetric objects. Cone beam X-ray transform and FDK reconstruction for cylindrically symmetric objects; essentialy a cone-beam variant of the Abel transform. The implementation is based on code modified from the `axitom `_ package :cite:`olufsen-2019-axitom`. """ from functools import partial from typing import Optional, Tuple import numpy as np import jax.numpy as jnp from jax import Array, jit, vjp from jax.scipy.ndimage import map_coordinates from jax.typing import ArrayLike from scico.typing import DType, Shape from .._linop import LinearOperator from ._axitom import backprojection, config, projection @partial(jit, static_argnames=["axis", "center"]) def _volume_by_axial_symmetry( x: Array, axis: int = 0, center: Optional[int] = None, zrange: Optional[Array] = None ) -> Array: """Create a volume by axial rotation of a plane. Args: x: 2D array that is rotated about an axis to generate a volume. axis: Index of axis of symmetry (must be 0 or 1). center: Location of the axis of symmetry on the other axis. If ``None``, defaults to center of that axis. Otherwise identifies the center coordinate on that axis. zrange: 1D array of points at which the extended axis is constructed. Defaults to the same as for axis :code:`1 - axis`. Returns: Volume as a 3D array. 
""" N0, N1 = x.shape N0h, N1h = (N0 + 1) / 2 - 1, (N1 + 1) / 2 - 1 half_shape = (N0h, N1h) if zrange is None: N2 = x.shape[1 - axis] N2h = (N2 + 1) / 2 - 1 zrange = jnp.arange(-N2h, N2h + 1) if axis == 0: g1d = [np.arange(0, N0), jnp.arange(-N1h, N1h + 1), zrange] else: g1d = [np.arange(-N0h, N0h + 1), jnp.arange(0, N1), zrange] if center is None: offset = 0 else: offset = center - half_shape[1 - axis] g0, g1, g2 = jnp.meshgrid(*g1d, indexing="ij") grids = (g0, g1, g2) r = jnp.hypot(grids[1 - axis], g2) sym_ax_crd = jnp.where( grids[1 - axis] >= 0, half_shape[1 - axis] + offset + r, half_shape[1 - axis] + offset - r ) if axis == 0: coords = [grids[axis], sym_ax_crd] else: coords = [sym_ax_crd, grids[axis]] v = map_coordinates(x, coords, cval=0.0, order=1) return v class AxiallySymmetricVolume(LinearOperator): """Create a volume by axial rotation of a plane.""" def __init__( self, input_shape: Shape, input_dtype: DType = np.float32, axis: int = 0, center: Optional[int] = None, ): """ Args: input_shape: Input image shape. input_dtype: Input image dtype. axis: Index of axis of symmetry (must be 0 or 1). center: If ``None``, defaults to the center of the image on the specified axis. Otherwise identifies the center coordinate on that axis. """ self.axis = axis self.center = center output_shape = input_shape + (input_shape[axis],) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, output_dtype=input_dtype, eval_fn=lambda x: _volume_by_axial_symmetry(x, axis=self.axis, center=self.center), jit=True, ) class SymConeXRayTransform(LinearOperator): """Cone beam X-ray transform for cylindrically symmetric objects. Cone beam X-ray transform of a cylindrically symmetric volume, which may be represented by a 2D central slice, which is rotated about the specified axis to generate a 3D volume for projection. The implementation is based on code modified from the AXITOM package :cite:`olufsen-2019-axitom`.. 
""" def __init__( self, input_shape: Shape, obj_dist: float, det_dist: float, axis: int = 0, pixel_size: Optional[Tuple[float, float]] = None, num_slabs: int = 1, ): """ Args: input_shape: Shape of the input array. If 2D, the input is extended to 3D (onto a new axis 1) by cylindrical symmetry. obj_dist: Source-object distance in arbitary length units (ALU). det_dist: Source-detector distance in ALU. axis: Index of axis of symmetry (must be 0 or 1). pixel_size: Tuple of pixel size values in ALU. num_slabs: Number of slabs into which the volume should be divided (for serial processing, to limit memory usage) in the imaging direction. """ if len(input_shape) == 2: self.input_2d = True output_shape = input_shape[::-1] else: self.input_2d = False output_shape = (input_shape[2], input_shape[0]) if pixel_size is None: pixel_size = (1.0, 1.0) self.axis = axis self.config = config.Config(*output_shape, *pixel_size, det_dist, obj_dist) self.num_slabs = num_slabs if len(input_shape) == 2 and axis == 1: eval_fn = lambda x: projection.forward_project( x.T, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d ).T else: eval_fn = lambda x: projection.forward_project( x, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d ) # use vjp rather than linear_transpose due to jax-ml/jax#30552 adj_fn = vjp(eval_fn, jnp.zeros(input_shape))[1] super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=np.float32, output_dtype=np.float32, eval_fn=eval_fn, adj_fn=lambda x: adj_fn(x)[0], jit=True, ) def fdk(self, y: ArrayLike, num_angles: int = 360): """Reconstruct central slice from projection. Reconstruct the central slice of the cylindrically symmetric volume from a projection. The reconstruction makes use of the Feldkamp David Kress (FDK) algorithm implemented in the `axitom `_ package. Args: y: The projection to be reconstructed. num_angles: Number of angles to be averaged in the reconstruction. 
Returns: Reconstruction of the central slice of the volume. """ angles = jnp.linspace(0, 360, num_angles, endpoint=False) x = backprojection.fdk(y if self.axis == 1 else y.T, self.config, angles) return x if self.axis == 1 else x.T ================================================ FILE: scico/loss.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Loss function classes.""" import warnings from copy import copy from functools import wraps from typing import Callable, Optional, Union import jax import scico import scico.numpy as snp from scico import functional, linop, operator from scico.numpy import Array, BlockArray from scico.numpy.util import no_nan_divide from scico.scipy.special import gammaln # type: ignore from scico.solver import cg def _loss_mul_div_wrapper(func): @wraps(func) def wrapper(self, other): if snp.isscalar(other) or isinstance(other, jax.core.Tracer): return func(self, other) raise NotImplementedError( f"Operation {func} not defined between {type(self)} and {type(other)}." ) return wrapper class Loss(functional.Functional): r"""Generic loss function. Generic loss function .. math:: \alpha f(\mb{y}, A(\mb{x})) \;, where :math:`\alpha` is the scaling parameter and :math:`f(\cdot)` is the loss functional. """ def __init__( self, y: Union[Array, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, f: Optional[functional.Functional] = None, scale: float = 1.0, ): r""" Args: y: Measurement. A: Forward operator. Defaults to ``None``, in which case `self.A` is a :class:`.Identity` with input shape and dtype determined by the shape and dtype of `y`. f: Functional :math:`f`. If defined, the loss function is :math:`\alpha f(\mb{y} - A(\mb{x}))`. 
If ``None``, then :meth:`__call__` and :meth:`prox` (where appropriate) must be defined in a derived class. scale: Scaling parameter. Default: 1.0. """ self.y = y if A is None: # y and x must have same shape A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype) # type: ignore self.A = A self.f = f self.scale = scale # Set functional-specific flags self.has_eval = True if self.f is not None and isinstance(self.A, linop.Identity): self.has_prox = True else: self.has_prox = False super().__init__() def __call__(self, x: Union[Array, BlockArray]) -> float: r"""Evaluate this loss at point :math:`\mb{x}`. Args: x: Point at which to evaluate loss. Returns: Result of evaluating the loss at `x`. """ if self.f is None: raise NotImplementedError( "Functional f is not defined and __call__ has not been overridden." ) return self.scale * self.f(self.A(x) - self.y) def prox( self, v: Union[Array, BlockArray], lam: float = 1, **kwargs ) -> Union[Array, BlockArray]: r"""Scaled proximal operator of loss function. Evaluate scaled proximal operator of this loss function, with scaling :math:`\lambda` = `lam` and evaluated at point :math:`\mb{v}` = `v`. If :meth:`prox` is not defined in a derived class, and if operator :math:`A` is the identity operator, then the proximal operator is computed using the proximal operator of functional :math:`l`, via Theorem 6.11 in :cite:`beck-2017-first`. Args: v: Point at which to evaluate prox function. lam: Proximal parameter :math:`\lambda`. **kwargs: Additional arguments that may be used by derived classes. These include `x0`, an initial guess for the minimizer in the defintion of :math:`\mathrm{prox}`. Returns: Result of evaluating the scaled proximal operator at `v`. """ if not self.has_prox: raise NotImplementedError( f"Method prox is not implemented for {type(self)} when A is {type(self.A)}; " "A must be an Identity." 
) assert self.f is not None return self.f.prox(v - self.y, self.scale * lam, **kwargs) + self.y @_loss_mul_div_wrapper def __mul__(self, other): new_loss = copy(self) new_loss._grad = scico.grad(new_loss.__call__) new_loss.set_scale(self.scale * other) return new_loss def __rmul__(self, other): return self.__mul__(other) @_loss_mul_div_wrapper def __truediv__(self, other): new_loss = copy(self) new_loss._grad = scico.grad(new_loss.__call__) new_loss.set_scale(self.scale / other) return new_loss def set_scale(self, new_scale: float): r"""Update the scale attribute.""" self.scale = new_scale class SquaredL2Loss(Loss): r"""Weighted squared :math:`\ell_2` loss. Weighted squared :math:`\ell_2` loss .. math:: \alpha \norm{\mb{y} - A(\mb{x})}_W^2 = \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, the weighting is an identity operator, giving an unweighted squared :math:`\ell_2` loss. """ def __init__( self, y: Union[Array, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, prox_kwargs: Optional[dict] = None, ): r""" Args: y: Measurement. A: Forward operator. If ``None``, defaults to :class:`.Identity`. scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. 
""" self.W: linop.Diagonal if W is None: self.W = linop.Identity(y.shape) # type: ignore elif isinstance(W, linop.Diagonal): if snp.all(W.diagonal >= 0): # type: ignore self.W = W else: raise ValueError(f"The weights, W.diagonal, must be non-negative.") else: raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.") super().__init__(y=y, A=A, scale=scale) default_prox_kwargs = {"maxiter": 100, "tol": 1e-5} if prox_kwargs: default_prox_kwargs.update(prox_kwargs) self.prox_kwargs = default_prox_kwargs if isinstance(self.A, linop.LinearOperator): self.has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: if not isinstance(self.A, linop.LinearOperator): raise NotImplementedError( f"Method prox is not implemented for {type(self)} when A is {type(self.A)}; " "A must be a LinearOperator." ) if isinstance(self.A, linop.Diagonal): c = 2.0 * self.scale * lam A = self.A.diagonal W = self.W.diagonal lhs = c * A.conj() * W * self.y + v # type: ignore ATWA = c * A.conj() * W * A # type: ignore return lhs / (ATWA + 1.0) # prox_f(v) = arg min 1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W # x # with solution: (I + λ 2𝛼 A^T W A) x = v + λ 2𝛼 A^T W y W = self.W A = self.A 𝛼 = self.scale y = self.y if "x0" in kwargs and kwargs["x0"] is not None: x0 = kwargs["x0"] else: x0 = snp.zeros_like(v) hessian = self.hessian # = (2𝛼 A^T W A) lhs = linop.Identity(v.shape) + lam * hessian rhs = v + 2 * lam * 𝛼 * A.adj(W(y)) x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) # type: ignore return x @property def hessian(self) -> linop.LinearOperator: r"""Compute the Hessian of linear operator `A`. If `self.A` is a :class:`scico.linop.LinearOperator`, returns a :class:`scico.linop.LinearOperator` corresponding to the Hessian :math:`2 \alpha \mathrm{A^H W A}`. Otherwise not implemented. 
""" A = self.A W = self.W if isinstance(A, linop.LinearOperator): return linop.LinearOperator( input_shape=A.input_shape, output_shape=A.input_shape, eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore input_dtype=A.input_dtype, ) raise NotImplementedError( f"Hessian is not implemented for {type(self)} when A is {type(A)}; " "A must be LinearOperator." ) class PoissonLoss(Loss): r"""Poisson negative log likelihood loss. Poisson negative log likelihood loss .. math:: \alpha \left( \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!) \right) \;, where :math:`\alpha` is the scaling parameter. """ def __init__( self, y: Union[Array, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, ): r""" Args: y: Measurement. A: Forward operator. Defaults to ``None``, in which case `self.A` is a :class:`.Identity`. scale: Scaling parameter. Default: 0.5. """ super().__init__(y=y, A=A, scale=scale) #: Constant term, :math:`\ln(y!)`, in Poisson log likehood. self.const = gammaln(self.y + 1.0) def __call__(self, x: Union[Array, BlockArray]) -> float: Ax = self.A(x) return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const) class SquaredL2AbsLoss(Loss): r"""Weighted squared :math:`\ell_2` with absolute value loss. Weighted squared :math:`\ell_2` with absolute value loss .. math:: \alpha \norm{\mb{y} - | A(\mb{x}) |\,}_W^2 = \alpha \left(\mb{y} - | A(\mb{x}) |\right)^T W \left(\mb{y} - | A(\mb{x}) |\right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. Proximal operator :meth:`prox` is implemented when :math:`A` is an instance of :class:`scico.linop.Identity`. This is not proximal operator according to the strict definition since the loss function is non-convex (Sec. 3) :cite:`soulez-2016-proximity`. 
""" def __init__( self, y: Union[Array, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): r""" Args: y: Measurement. A: Forward operator. If ``None``, defaults to :class:`.Identity`. scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): if snp.all(W.diagonal >= 0): self.W = W else: raise ValueError("The weights, W.diagonal, must be non-negative.") else: raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.") super().__init__(y=y, A=A, scale=scale) if isinstance(self.A, linop.Identity) and snp.all(y >= 0): self.has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x))) ** 2) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: if not self.has_prox: raise NotImplementedError(f"Method prox is not implemented.") 𝛼 = lam * 2.0 * self.scale * self.W.diagonal y = self.y r = snp.abs(v) 𝛽 = (𝛼 * y + r) / (𝛼 + 1.0) x = snp.where(r > 0, (𝛽 / r) * v, 𝛽) return x def _cbrt(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Compute the cube root of the argument. The two standard options for computing the cube root of an array are :func:`numpy.cbrt`, or raising to the power of (1/3), i.e. `x ** (1/3)`. The former cannot be used for complex values, and the latter returns a complex root of a negative real value. This functions can be used for both real and complex values, and returns the real root of negative real values. Args: x: Input array. Returns: Array of cube roots of input `x`. 
""" s = snp.where(snp.abs(snp.angle(x)) <= 2 * snp.pi / 3, 1, -1) return s * (s * x) ** (1 / 3) def _check_root( x: Union[Array, BlockArray], p: Union[Array, BlockArray], q: Union[Array, BlockArray], tol: float = 1e-4, ): """Check the precision of a cubic equation solution. Check the precision of an array of depressed cubic equation solutions, issuing a warning if any of the errors exceed a specified tolerance. Args: x: Array of roots of a depressed cubic equation. p: Array of linear parameters of a depressed cubic equation. q: Array of constant parameters of a depressed cubic equation. tol: Expected tolerance for solution precision. """ err = snp.abs(x**3 + p * x + q) if not snp.allclose(err, 0, atol=tol): idx = snp.argmax(err) msg = ( "Low precision in root calculation. Worst error is " f"{err.ravel()[idx]:.3e} for p={p.ravel()[idx]} and q={q.ravel()[idx]}" ) warnings.warn(msg) def _dep_cubic_root( p: Union[Array, BlockArray], q: Union[Array, BlockArray] ) -> Union[Array, BlockArray]: r"""Compute a real root of a depressed cubic equation. A depressed cubic equation is one that can be written in the form .. math:: x^3 + px + q = 0 \;. The determinant is .. math:: \Delta = (q/2)^2 + (p/3)^3 \;. When :math:`\Delta > 0` this equation has one real root and two complex (conjugate) roots, when :math:`\Delta = 0`, all three roots are real, with at least two being equal, and when :math:`\Delta < 0`, all roots are real and unequal. According to Vieta's formulas, the roots :math:`x_0, x_1`, and :math:`x_2` of this equation satisfy .. math:: x_0 + x_1 + x_2 &= 0 \\ x_0 x_1 + x_0 x_2 + x_2 x_3 &= p \\ x_0 x_1 x_2 &= -q \;. Therefore, when :math:`q` is negative, the equation has a single real positive root since at least one root must be negative for their sum to be zero, and their product could not be positive if only one root were zero. This function always returns a real root; when :math:`q` is negative, it returns the single positive root. 
The solution is computed using `Vieta's substitution `__, .. math:: w = x - \frac{p}{3w} \;, which reduces the depressed cubic equation to .. math:: w^3 - \frac{p^3}{27w^3} + q = 0\;, which can be expressed as a quadratic equation in :math:`w^3` by multiplication by :math:`w^3`, leading to .. math:: w^3 = -\frac{q}{2} \pm \sqrt{\frac{q^2}{4} + \frac{p^3}{27}} \;. Note that the multiplication by :math:`w^3` introduces a spurious solution at zero in the case :math:`p = 0`, which must be handled separately as .. math:: w^3 = -q \;. Despite taking this into account, very poor numerical precision can be obtained when :math:`p` is small but non-zero since, in this case .. math:: \sqrt{\Delta} = \sqrt{(q/2)^2 + (p/3)^3} \approx q/2 \;, so that an incorrect solutions :math:`w^3 = 0` or :math:`w^3 = -q` are obtained, depending on the choice of sign in the equation for :math:`w^3`. An alternative derivation leads to the equation .. math:: x = \sqrt[3]{-q/2 + \sqrt{\Delta}} + \sqrt[3]{-q/2 - \sqrt{\Delta}} for the real root, but this is also prone to severe numerical errors in single precision arithmetic. Args: p: Array of :math:`p` values. q: Array of :math:`q` values. Returns: Array of real roots of the cubic equation. """ Δ = (q**2) / 4.0 + (p**3) / 27.0 w3 = snp.where(snp.abs(p) <= 1e-7, -q, -q / 2.0 + snp.sqrt(Δ + 0j)) w = _cbrt(w3) r = (w - no_nan_divide(p, 3 * w)).real _check_root(r, p, q) return r class SquaredL2SquaredAbsLoss(Loss): r"""Weighted squared :math:`\ell_2` with squared absolute value loss. Weighted squared :math:`\ell_2` with squared absolute value loss .. math:: \alpha \norm{\mb{y} - | A(\mb{x}) |^2 \,}_W^2 = \alpha \left(\mb{y} - | A(\mb{x}) |^2 \right)^T W \left(\mb{y} - | A(\mb{x}) |^2 \right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. Proximal operator :meth:`prox` is implemented when :math:`A` is an instance of :class:`scico.linop.Identity`. 
This is not proximal operator according to the strict definition since the loss function is non-convex (Sec. 3) :cite:`soulez-2016-proximity`. """ def __init__( self, y: Union[Array, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, ): r""" Args: y: Measurement. A: Forward operator. If ``None``, defaults to :class:`.Identity`. scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ if W is None: self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) elif isinstance(W, linop.Diagonal): if snp.all(W.diagonal >= 0): self.W = W else: raise ValueError("The weights, W.diagonal, must be non-negative.") else: raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.") super().__init__(y=y, A=A, scale=scale) if isinstance(self.A, linop.Identity) and snp.all(y >= 0): self.has_prox = True def __call__(self, x: Union[Array, BlockArray]) -> float: return self.scale * snp.sum( self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x)) ** 2) ** 2 ) def prox( self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs ) -> Union[Array, BlockArray]: if not self.has_prox: raise NotImplementedError(f"Method prox is not implemented.") 𝛼 = lam * 4.0 * self.scale * self.W.diagonal 𝛽 = snp.abs(v) p = no_nan_divide(1.0 - 𝛼 * self.y, 𝛼) q = no_nan_divide(-𝛽, 𝛼) r = _dep_cubic_root(p, q) φ = snp.where(𝛽 > 0, v / snp.abs(v), 1.0) x = snp.where(𝛼 > 0, r * φ, v) return x ================================================ FILE: scico/metric.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Image quality metrics and related functions.""" # This module is copied from https://github.com/bwohlberg/sporco from typing import Optional, Union import numpy as np import scico.numpy as snp from scico.numpy import Array, BlockArray def mae(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float: """Compute Mean Absolute Error (MAE) between two images. Args: reference: Reference image. comparison: Comparison image. Returns: MAE between `reference` and `comparison`. """ return snp.mean(snp.abs(reference - comparison).ravel()) def mse(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float: """Compute Mean Squared Error (MSE) between two images. Args: reference : Reference image. comparison : Comparison image. Returns: MSE between `reference` and `comparison`. """ return snp.mean(snp.abs(reference - comparison).ravel() ** 2) def snr(reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray]) -> float: """Compute Signal to Noise Ratio (SNR) of two images. Args: reference: Reference image. comparison: Comparison image. Returns: SNR of `comparison` with respect to `reference`. """ dv = snp.var(reference) with np.errstate(divide="ignore"): rt = dv / mse(reference, comparison) return 10.0 * snp.log10(rt) def psnr( reference: Union[Array, BlockArray], comparison: Union[Array, BlockArray], signal_range: Optional[Union[int, float]] = None, ) -> float: """Compute Peak Signal to Noise Ratio (PSNR) of two images. The PSNR calculation defaults to using the less common definition in terms of the actual range (i.e. max minus min) of the reference signal instead of the maximum possible range for the data type (i.e. :math:`2^b-1` for a :math:`b` bit representation). Args: reference: Reference image. comparison: Comparison image. signal_range: Signal range, either the value to use (e.g. 255 for 8 bit samples) or ``None``, in which case the actual range of the reference signal is used. 
Returns: PSNR of `comparison` with respect to `reference`. """ if signal_range is None: signal_range = snp.abs(snp.max(reference) - snp.min(reference)) with np.errstate(divide="ignore"): rt = signal_range**2 / mse(reference, comparison) return 10.0 * snp.log10(rt) def isnr( reference: Union[Array, BlockArray], degraded: Union[Array, BlockArray], restored: Union[Array, BlockArray], ) -> float: """Compute Improvement Signal to Noise Ratio (ISNR). Compute Improvement Signal to Noise Ratio (ISNR) for reference, degraded, and restored images. Args: reference: Reference image. degraded: Degraded/observed image. restored: Restored/estimated image. Returns: ISNR of `restored` with respect to `reference` and `degraded`. """ msedeg = mse(reference, degraded) mserst = mse(reference, restored) with np.errstate(divide="ignore"): rt = msedeg / mserst return 10.0 * snp.log10(rt) def bsnr(blurry: Union[Array, BlockArray], noisy: Union[Array, BlockArray]) -> float: """Compute Blurred Signal to Noise Ratio (BSNR). Compute Blurred Signal to Noise Ratio (BSNR) for a blurred and noisy image. Args: blurry: Blurred noise free image. noisy: Blurred image with additive noise. Returns: BSNR of `noisy` with respect to `blurry`. """ blrvar = snp.var(blurry) nsevar = snp.var(noisy - blurry) with np.errstate(divide="ignore"): rt = blrvar / nsevar return 10.0 * snp.log10(rt) def rel_res(ax: Union[BlockArray, Array], b: Union[BlockArray, Array]) -> float: r"""Relative residual of the solution to a linear equation. The standard relative residual for the linear system :math:`A \mathbf{x} = \mathbf{b}` is :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / \|\mathbf{b}\|_2`. This function computes a variant :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / \max(\|A\mathbf{x}\|_2, \|\mathbf{b}\|_2)` that is robust to the case :math:`\mathbf{b} = 0`. Args: ax: Linear component :math:`A \mathbf{x}` of equation. b: Constant component :math:`\mathbf{b}` of equation. Returns: Relative residual value. 
""" nrm = max(snp.linalg.norm(ax.ravel()), snp.linalg.norm(b.ravel())) if nrm == 0.0: return 0.0 return snp.linalg.norm((b - ax).ravel()) / nrm ================================================ FILE: scico/numpy/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. r""":class:`.BlockArray` and compatible functions. This module consists of :class:`.BlockArray` and functions that support both instances of this class and jax arrays. This includes all the functions from :mod:`jax.numpy` and :mod:`numpy.testing`, where many have been extended to automatically map over block array blocks as described in :ref:`numpy_functions_blockarray`. Also included are additional functions unique to SCICO in :mod:`.util`. """ import sys from functools import partial from typing import Union import numpy as np import jax import jax.numpy as jnp from jax import Array from . import _wrappers, fft, linalg, testing, util from ._blockarray import BlockArray from ._wrapped_function_lists import ( creation_routines, mathematical_functions, reduction_functions, testing_functions, ) __all__ = ["fft", "linalg", "testing", "util"] # allow snp.blockarray(...) 
to create BlockArrays blockarray = BlockArray.blockarray blockarray.__module__ = __name__ # so that blockarray can be referenced in docs # BlockArray appears to originate in this module sys.modules[__name__].BlockArray.__module__ = __name__ # copy most of jnp without wrapping _wrappers.add_attributes(to_dict=vars(), from_dict=jnp.__dict__) # wrap jnp funcs _wrappers.wrap_recursively( vars(), creation_routines, partial( _wrappers.map_func_over_args, map_if_nested_args=["shape"], map_if_list_args=["device"], ), ) _wrappers.wrap_recursively(vars(), mathematical_functions, _wrappers.map_func_over_args) _wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction) def ravel(ba: Union[Array | BlockArray]) -> Array: """Completely flatten a :class:`BlockArray` into a single ``Array``. When called on an ``Array``, flattens the array. Args: ba: The :class:`BlockArray` to flatten. Returns: `ba` flattened into a single ``Array.`` """ if isinstance(ba, BlockArray): return jax.numpy.concatenate([arr.flatten() for arr in ba]) return ba.ravel() # wrap testing funcs _wrappers.wrap_recursively( vars(), testing_functions, partial(_wrappers.map_func_over_args, is_void=True) ) # clean up del np, jnp, _wrappers ================================================ FILE: scico/numpy/_blockarray.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
"""Block array class.""" import inspect from functools import WRAPPER_ASSIGNMENTS, wraps from typing import Callable import jax import jax.numpy as jnp from ._wrapped_function_lists import binary_ops, unary_ops from .util import is_collapsible # Determine type of "standard" jax array since jax.Array is an abstract # base class type that is not suitable for use here. JaxArray = type(jnp.array([0])) class BlockArray: """Block array class. A block array provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. For further information, see the :ref:`detailed BlockArray documentation `. Example ------- >>> x = snp.blockarray(( ... [[1, 3, 7], ... [2, 2, 1]], ... [2, 4, 8] ... )) >>> x.shape ((2, 3), (3,)) >>> snp.sum(x) Array(30, dtype=int32) """ # Ensure we use BlockArray.__radd__, __rmul__, etc for binary # operations of the form op(np.ndarray, BlockArray) See # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 def __init__(self, inputs): # convert inputs to jax arrays self.arrays = [x if isinstance(x, jax.ShapeDtypeStruct) else jnp.array(x) for x in inputs] # check that dtypes match if not all(a.dtype == self.arrays[0].dtype for a in self.arrays): raise ValueError("Heterogeneous dtypes not supported.") @property def dtype(self): """Return the dtype of the blocks, which must currently be homogeneous. This allows `snp.zeros(x.shape, x.dtype)` to work without a mechanism to handle lists of dtypes. """ return self.arrays[0].dtype def __len__(self): return self.arrays.__len__() def __getitem__(self, key): """Indexing method equivalent to x[key]. This is overridden to make, e.g., x[:2] return a BlockArray rather than a list. 
""" result = self.arrays[key] if isinstance(result, list): return BlockArray(result) # x[k:k+1] returns a BlockArray return result # x[k] returns a jax array def __setitem__(self, key, value): self.arrays[key] = value @staticmethod def blockarray(iterable): """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" return BlockArray(iterable) def __repr__(self): return f"BlockArray({repr(self.arrays)})" def stack(self, axis=0): """Collapse a :class:`.BlockArray` to :class:`jax.Array`. Collapse a :class:`.BlockArray` to :class:`jax.Array` by stacking the blocks on axis `axis`. Args: axis: Index of new axis on which blocks are to be stacked. Returns: A :class:`jax.Array` obtained by stacking. Raises: ValueError: When called on a :class:`.BlockArray` that is not stackable. """ if is_collapsible(self.shape): return jnp.stack(self.arrays, axis=axis) else: raise ValueError(f"BlockArray of shape {self.shape} cannot be collapsed to an Array.") # Register BlockArray as a jax pytree; without this, jax autograd won't work. # Taken from what is done with tuples in jax._src.tree_util jax.tree_util.register_pytree_node( BlockArray, lambda xs: (xs, None), # to iter lambda _, xs: BlockArray(xs), # from iter ) # Wrap unary ops like -x. def _unary_op_wrapper(op_name): op = getattr(JaxArray, op_name) @wraps(op) def op_block_array(self): return BlockArray(op(x) for x in self) return op_block_array for op_name in unary_ops: setattr(BlockArray, op_name, _unary_op_wrapper(op_name)) # Wrap binary ops like x + y. """ def _binary_op_wrapper(op_name): op = getattr(JaxArray, op_name) @wraps(op) def op_block_array(self, other): # If other is a block array, we can assume the operation is # implemented (because block arrays must contain jax arrays) if isinstance(other, BlockArray): return BlockArray(op(x, y) for x, y in zip(self, other)) # If not, need to handle possible NotImplemented. Without this, # block_array + 'hi' -> [NotImplemented, NotImplemented, ...] 
result = list(op(x, other) for x in self) if NotImplemented in result: return NotImplemented return BlockArray(result) return op_block_array for op_name in binary_ops: setattr(BlockArray, op_name, _binary_op_wrapper(op_name)) # Wrap jax array properties. def _jax_array_prop_wrapper(prop_name): prop = getattr(JaxArray, prop_name) @property @wraps(prop) def prop_block_array(self): result = tuple(getattr(x, prop_name) for x in self) # If each jax_array.prop is a jax array, ... if all([isinstance(x, jnp.ndarray) for x in result]): # ...return a block array... return BlockArray(result) # ... otherwise return a tuple. return result return prop_block_array skip_props = ("at",) jax_array_props = [ k for k, v in dict(inspect.getmembers(JaxArray)).items() # (name, method) pairs if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props ] for prop_name in jax_array_props: setattr(BlockArray, prop_name, _jax_array_prop_wrapper(prop_name)) # Wrap jax array methods. def _jax_array_method_wrapper(method_name): method = getattr(JaxArray, method_name) # Don't try to set attributes that are None. Not clear why some # functions/methods (e.g. block_until_ready) have None values # for these attributes. wrapper_assignments = WRAPPER_ASSIGNMENTS for attr in ("__name__", "__qualname__"): if getattr(method, attr) is None: wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr) @wraps(method, assigned=wrapper_assignments) def method_block_array(self, *args, **kwargs): result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self) # If each jax_array.method(...) call returns a jax array, ... if all([isinstance(x, jnp.ndarray) for x in result]): # ... return a block array... return BlockArray(result) # ... otherwise return a tuple. 
return result return method_block_array skip_methods = () jax_array_methods = [ k for k, v in dict(inspect.getmembers(JaxArray)).items() # (name, method) pairs if isinstance(v, Callable) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_methods ] for method_name in jax_array_methods: setattr(BlockArray, method_name, _jax_array_method_wrapper(method_name)) ================================================ FILE: scico/numpy/_wrapped_function_lists.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """ Lists of functions to be wrapped in scico.numpy. These are intended to be the functions in :mod:`jax.numpy` that should either #. map over the blocks of a block array (for math functions); #. map over a tuple of tuples to create a block array (for creation functions); or #. reduce a block array to a scalar (for reductions). The links to the numpy docs in the comments are useful for distinguishing between these three cases, but note that these lists of numpy functions include extra functions that are not in :mod:`jax.numpy`, and that are therefore not listed here. 
""" """ BlockArray """ unary_ops = ( # found from dir() on jax array "__abs__", "__neg__", "__pos__", ) binary_ops = ( # found from dir() on jax array "__add__", "__eq__", "__floordiv__", "__ge__", "__gt__", "__le__", "__lt__", "__matmul__", "__mod__", "__mul__", "__ne__", "__pow__", "__radd__", "__rfloordiv__", "__rmatmul__", "__rmul__", "__rpow__", "__rsub__", "__rtruediv__", "__sub__", "__truediv__", ) """ jax.numpy """ creation_routines = ( "empty", "ones", "zeros", "full", ) mathematical_functions = ( "sin", # https://numpy.org/doc/stable/reference/routines.math.html "cos", "tan", "arcsin", "arccos", "arctan", "hypot", "arctan2", "degrees", "radians", "unwrap", "deg2rad", "rad2deg", "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh", "around", "round", "rint", "floor", "ceil", "trunc", "prod", "sum", "nanprod", "nansum", "cumprod", "cumsum", "nancumprod", "nancumsum", "diff", "ediff1d", "gradient", "cross", "exp", "expm1", "exp2", "log", "log10", "log2", "log1p", "logaddexp", "logaddexp2", "i0", "sinc", "signbit", "copysign", "frexp", "ldexp", "nextafter", "lcm", "gcd", "add", "reciprocal", "positive", "negative", "multiply", "divide", "power", "subtract", "true_divide", "floor_divide", "float_power", "fmod", "mod", "modf", "remainder", "divmod", "angle", "real", "imag", "conj", "conjugate", "maximum", "fmax", "amax", "nanmax", "minimum", "fmin", "amin", "nanmin", "convolve", "clip", "sqrt", "cbrt", "square", "abs", "absolute", "fabs", "sign", "heaviside", "nan_to_num", "interp", "sort", # https://numpy.org/doc/stable/reference/routines.sort.html "lexsort", "argsort", "sort_complex", "partition", "argmax", "nanargmax", "argmin", "nanargmin", "argwhere", "nonzero", "flatnonzero", "where", "searchsorted", "extract", "count_nonzero", "dot", # https://numpy.org/doc/stable/reference/routines.linalg.html "linalg.multi_dot", "vdot", "inner", "outer", "matmul", "tensordot", "einsum", "einsum_path", "linalg.matrix_power", "kron", "linalg.cholesky", "linalg.qr", 
"linalg.svd", "linalg.eig", "linalg.eigh", "linalg.eigvals", "linalg.eigvalsh", "linalg.norm", "linalg.cond", "linalg.det", "linalg.matrix_rank", "linalg.slogdet", "trace", "linalg.solve", "linalg.tensorsolve", "linalg.lstsq", "linalg.inv", "linalg.pinv", "linalg.tensorinv", "shape", # https://numpy.org/doc/stable/reference/routines.array-manipulation.html "reshape", "moveaxis", "rollaxis", "swapaxes", "transpose", "atleast_1d", "atleast_2d", "atleast_3d", "expand_dims", "squeeze", "asarray", "stack", "block", "vstack", "hstack", "dstack", "column_stack", "split", "array_split", "dsplit", "hsplit", "vsplit", "tile", "repeat", "insert", "append", "resize", "trim_zeros", "unique", "pad", "flip", "fliplr", "flipud", "reshape", "roll", "rot90", "all", "any", "isfinite", "isinf", "isnan", "isneginf", "isposinf", "iscomplex", "iscomplexobj", "isreal", "isrealobj", "isscalar", "logical_and", "logical_or", "logical_not", "logical_xor", "allclose", "isclose", "array_equal", "array_equiv", "greater", "greater_equal", "less", "less_equal", "equal", "not_equal", "empty_like", # https://numpy.org/doc/stable/reference/routines.array-creation.html "ones_like", "zeros_like", "full_like", ) # these may also appear in the mathematical function list reduction_functions = ("sum", "linalg.norm", "count_nonzero", "all", "any") """ testing """ testing_functions = ("testing.assert_allclose", "testing.assert_array_equal") ================================================ FILE: scico/numpy/_wrappers.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
"""Utilities for wrapping jnp functions to handle BlockArray inputs.""" import sys import warnings from functools import wraps from inspect import Parameter, signature from types import ModuleType from typing import Callable, Iterable, Optional import jax.numpy as jnp import scico.numpy as snp from ._blockarray import BlockArray def add_attributes( to_dict: dict, from_dict: dict, modules_to_recurse: Optional[Iterable[str]] = None, ): """Add attributes in `from_dict` to `to_dict`. Underscore attributes are ignored. Modules are ignored, except those listed in `modules_to_recurse`, which are added recursively. All others are added. """ if modules_to_recurse is None: modules_to_recurse = () for name, obj in from_dict.items(): if name[0] == "_": continue if isinstance(obj, ModuleType): if name in modules_to_recurse: qualname = to_dict["__name__"] + "." + name to_dict[name] = ModuleType(name, doc=obj.__doc__) to_dict[name].__package__ = to_dict["__name__"] # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` sys.modules[qualname] = to_dict[name] sys.modules[qualname].__name__ = qualname add_attributes(to_dict[name].__dict__, obj.__dict__) else: to_dict[name] = obj def wrap_recursively( target_dict: dict, names: Iterable[str], wrap: Callable, ): """Call wrap functions in `target_dict`, correctly handling names like `"linalg.norm"`.""" for name in names: if "." in name: module, rest = name.split(".", maxsplit=1) wrap_recursively(target_dict[module].__dict__, [rest], wrap) else: if name in target_dict: target_dict[name] = wrap(target_dict[name]) else: warnings.warn(f"In call to wrap_recursively, name {name} is not in target_dict") def map_func_over_args( func: Callable, map_if_nested_args: Optional[list[str]] = [], map_if_list_args: Optional[list[str]] = [], is_void: Optional[bool] = False, ): """ Wrap a function so that it automatically maps over its arguments, returning a BlockArray. BlockArray arguments always trigger mapping. 
Other arguments trigger mapping if they meet specified criteria. """ # check inputs func_signature = signature(func) for arg in map_if_nested_args + map_if_list_args: if arg not in func_signature.parameters: raise ValueError(f"`{arg}` is not an argument of {func.__name__}") # define wrapped function @wraps(func) def wrapped(*args, **kwargs): arg_names = [ k for k, v in func_signature.parameters.items() if v.kind in ( Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, ) ] # look in args for mapping triggers arg_is_mapping = [] for arg_num, arg_val in enumerate(args): if ( isinstance(arg_val, BlockArray) or ( snp.util.is_nested(arg_val) and arg_num < len(arg_names) and arg_names[arg_num] in map_if_nested_args ) or ( isinstance(arg_val, (list, tuple)) and arg_num < len(arg_names) and arg_names[arg_num] in map_if_list_args ) ): arg_is_mapping.append(True) else: arg_is_mapping.append(False) # look in kwargs for mapping triggers kwarg_is_mapping = {} for arg_name, arg_val in kwargs.items(): if ( isinstance(arg_val, BlockArray) or (arg_name in map_if_nested_args and snp.util.is_nested(arg_val)) or (arg_name in map_if_list_args and isinstance(arg_val, (list, tuple))) ): kwarg_is_mapping[arg_name] = True else: kwarg_is_mapping[arg_name] = False # no arguments that trigger mapping? 
call as usual if sum(arg_is_mapping) == 0 and sum(kwarg_is_mapping.values()) == 0: return func(*args, **kwargs) # count number of blocks num_blocks = ( len( args[ [index for index, mapping_flag in enumerate(arg_is_mapping) if mapping_flag][0] ] ) # first mapping arg if sum(arg_is_mapping) else len( kwargs[[k for k, mapping_flag in kwarg_is_mapping.items() if mapping_flag][0]] ) # first mapping kwarg ) # map func over the mapping args results = [] for block_ind in range(num_blocks): result = func( *[ arg[block_ind] if is_mapping else arg for arg, is_mapping in zip(args, arg_is_mapping) ], **{ k: kwargs[k][block_ind] if is_mapping else kwargs[k] for k, is_mapping in kwarg_is_mapping.items() }, ) results.append(result) if is_void: return return BlockArray(results) return wrapped def add_full_reduction(func: Callable, axis_arg_name: Optional[str] = "axis"): """Wrap a function so that it can fully reduce a BlockArray. Wrap a function so that it can fully reduce a :class:`.BlockArray`. If nothing is passed for the axis argument and the function is called on a :class:`.BlockArray`, it is fully ravelled before the function is called. Should be outside :func:`map_func_over_args`. """ sig = signature(func) if axis_arg_name not in sig.parameters: raise ValueError( f"Cannot wrap {func} as a reduction because it has no {axis_arg_name} argument." 
) @wraps(func) def wrapped(*args, **kwargs): bound_args = sig.bind(*args, **kwargs) ba_args = {} for k, v in list(bound_args.arguments.items()): if isinstance(v, BlockArray): ba_args[k] = bound_args.arguments.pop(k) if "axis" in bound_args.arguments: return func(*bound_args.args, **bound_args.kwargs, **ba_args) # call func as normal if len(ba_args) > 1: raise ValueError("Cannot perform a full reduction with multiple BlockArray arguments.") # fully ravel the ba argument ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()} return func(*bound_args.args, **bound_args.kwargs, **ba_args) return wrapped ================================================ FILE: scico/numpy/fft.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Discrete Fourier Transform functions.""" import numpy as np import jax.numpy as jnp from . import _wrappers _wrappers.add_attributes( to_dict=vars(), from_dict=jnp.fft.__dict__, ) # clean up del np, jnp, _wrappers ================================================ FILE: scico/numpy/linalg.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Linear algebra functions.""" import numpy as np import jax.numpy as jnp from . 
import _wrappers _wrappers.add_attributes( to_dict=vars(), from_dict=jnp.linalg.__dict__, ) # clean up del np, jnp, _wrappers ================================================ FILE: scico/numpy/testing.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Test support functions.""" import numpy as np from . import _wrappers _wrappers.add_attributes( to_dict=vars(), from_dict=np.testing.__dict__, ) # clean up del np, _wrappers ================================================ FILE: scico/numpy/util.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. """Utility functions for working with jax arrays and BlockArrays.""" from __future__ import annotations import collections from math import prod from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np import jax from typing_extensions import TypeGuard import scico.numpy as snp from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, Shape def transpose_ntpl_of_list(ntpl: NamedTuple) -> List[NamedTuple]: """Convert a namedtuple of lists/arrays to a list of namedtuples. Args: ntpl: Named tuple object to be transposed. Returns: List of namedtuple objects. 
""" cls = ntpl.__class__ numentry = len(ntpl[0]) if isinstance(ntpl[0], list) else ntpl[0].shape[0] nfields = len(ntpl._fields) return [cls(*[ntpl[m][n] for m in range(nfields)]) for n in range(numentry)] def transpose_list_of_ntpl(ntlist: List[NamedTuple]) -> NamedTuple: """Convert a list of namedtuples to namedtuple of lists. Args: ntpl: List of namedtuple objects to be transposed. Returns: Named tuple of lists. """ cls = ntlist[0].__class__ numentry = len(ntlist) nfields = len(ntlist[0]) return cls(*[[ntlist[m][n] for m in range(numentry)] for n in range(nfields)]) # type: ignore def namedtuple_to_array(ntpl: NamedTuple) -> snp.Array: """Convert a namedtuple to an array. Convert a :func:`collections.namedtuple` object to a :class:`numpy.ndarray` object that can be saved using :func:`numpy.savez`. Args: ntpl: Named tuple object to be converted to ndarray. Returns: Array representation of input named tuple. """ return np.asarray( { "name": ntpl.__class__.__name__, "fields": ntpl._fields, "data": {fname: fval for fname, fval in zip(ntpl._fields, ntpl)}, } ) def array_to_namedtuple(array: snp.Array) -> NamedTuple: """Convert an array representation of a namedtuple back to a namedtuple. Convert a :class:`numpy.ndarray` object constructed by :func:`namedtuple_to_array` back to the original :func:`collections.namedtuple` representation. Args: Array representation of named tuple constructed by :func:`namedtuple_to_array`. Returns: Named tuple object with the same name and fields as the original named tuple object provided to :func:`namedtuple_to_array`. """ cls = collections.namedtuple(array.item()["name"], array.item()["fields"]) # type: ignore return cls(**array.item()["data"]) def normalize_axes( axes: Optional[Axes], shape: Optional[Shape] = None, default: Optional[List[int]] = None, sort: bool = False, ) -> Sequence[int]: """Normalize `axes` to a sequence and optionally ensure correctness. 
Normalize `axes` to a tuple or list and (optionally) ensure that entries refer to axes that exist in `shape`. Args: axes: User specification of one or more axes: int, list, tuple, or ``None``. Negative values count from the last to the first axis. shape: The shape of the array of which axes are being specified. If not ``None``, `axes` is checked to make sure its entries refer to axes that exist in `shape`. default: Default value to return if `axes` is ``None``. By default, `tuple(range(len(shape)))`. sort: If ``True``, sort the returned axis indices. Returns: Tuple or list of axes (never an int, never ``None``). The output will only be a list if the input is a list or if the input is ``None`` and `defaults` is a list. """ if axes is None: if default is None: if shape is None: raise ValueError( "Argument 'axes' cannot be None without a default or shape specified." ) axes = tuple(range(len(shape))) else: axes = default elif isinstance(axes, (list, tuple)): axes = axes elif isinstance(axes, int): axes = (axes,) else: raise ValueError(f"Could not understand argument 'axes' {axes} as a list of axes.") if shape is not None: if min(axes) < 0: axes = tuple([len(shape) + a if a < 0 else a for a in axes]) if max(axes) >= len(shape): raise ValueError( f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." ) if len(set(axes)) != len(axes): raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.") if sort: axes = tuple(sorted(axes)) return axes def slice_length(length: int, idx: AxisIndex) -> Optional[int]: """Determine the length of an array axis after indexing. Determine the length of an array axis after slicing. An exception is raised if the indexing expression is an integer that is out of bounds for the specified axis length. 
A value of ``None`` is returned for valid integer indexing expressions as an indication that the corresponding axis shape is an empty tuple; this value should be converted to a unit integer if the axis size is required. Args: length: Length of axis being sliced. idx: Indexing/slice to be applied to axis. Returns: Length of indexed/sliced axis. Raises: ValueError: If `idx` is an integer index that is out bounds for the axis length or if the type of `idx` is not one of `Ellipsis`, `int`, or `slice`. """ if idx is Ellipsis: return length if isinstance(idx, int): if idx < -length or idx > length - 1: raise ValueError(f"Index {idx} out of bounds for axis of length {length}.") return None if not isinstance(idx, slice): raise ValueError(f"Index expression {idx} is of an unrecognized type.") start, stop, stride = idx.indices(length) if start > stop: start = stop return (stop - start + stride - 1) // stride def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: """Determine the shape of an array after indexing/slicing. The indexed shape is determined by replicating the observed effects of NumPy/JAX array indexing/slicing syntax. It is significantly faster than :func:`.jax_indexed_shape`, and has a minimal memory footprint in all circumstances. Args: shape: Shape of array. idx: Indexing expression (singleton or tuple of `Ellipsis`, `int`, `slice`, or ``None`` (`np.newaxis`)). Returns: Shape of indexed/sliced array. Raises: ValueError: If any element of `idx` is not one of `Ellipsis`, `int`, `slice`, or ``None`` (`np.newaxis`), or if an integer index is out bounds for the corresponding axis length. 
""" if not isinstance(idx, tuple): idx = (idx,) idx_shape: List[Optional[int]] = list(shape) offset = 0 newaxis = 0 for axis, ax_idx in enumerate(idx): if ax_idx is None: idx_shape.insert(axis, 1) newaxis += 1 continue if ax_idx is Ellipsis: offset = len(shape) - len(idx) continue idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx) return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore def jax_indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: """Determine the shape of an array after indexing/slicing. The indexed shape is determined by constructing and indexing an array of the appropriate shape, relying on :func:`jax.jit` to avoid memory allocation. It is potentially more reliable than :func:`.indexed_shape` because the indexing/slicing calculations are referred to JAX, but is significantly slower, and will involved potentially significant memory allocations if JIT is disabled, e.g. for debugging purposes. Args: shape: Shape of array. idx: Indexing expression (singleton or tuple of `Ellipsis`, `int`, `slice`, or ``None`` (`np.newaxis`)). Returns: Shape of indexed/sliced array. """ if not isinstance(idx, tuple): idx = (idx,) # Convert any slices to its representation (slice, (start, stop, step)) # allowing hashing, needed for jax.jit idx = tuple(exp.__reduce__() if isinstance(exp, slice) else exp for exp in idx) # type: ignore def get_shape(in_shape, ind_expr): # convert slices representations back to slices ind_expr = tuple( (slice(*exp[1]) if isinstance(exp, tuple) and len(exp) > 0 and exp[0] == slice else exp) for exp in ind_expr ) return jax.numpy.empty(in_shape)[ind_expr].shape # This compiles each time it gets new arguments because all arguments are static. 
f = jax.jit(get_shape, static_argnums=(0, 1)) return tuple(t.item() for t in f(shape, idx)) # type: ignore def no_nan_divide( x: Union[snp.BlockArray, snp.Array], y: Union[snp.BlockArray, snp.Array] ) -> Union[snp.BlockArray, snp.Array]: """Return `x/y`, with 0 instead of :data:`~numpy.NaN` where `y` is 0. Args: x: Numerator. y: Denominator. Returns: `x / y` with 0 wherever `y == 0`. """ return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) def _readable_size(size: int) -> str: """Return a human-readable representation of an array size. Args: size: A positive integer array size. Returns: A string representation of the size. """ factor = [1, 1024, 1024**2, 1024**3, 1024**4] units = ["B", "KB", "MB", "GB", "TB"] idx_tuple = np.nonzero([size // f for f in factor[::-1]]) if idx_tuple[0].size == 0: idx = len(factor) - 1 else: idx = int(idx_tuple[0][0]) val = size // factor[::-1][idx] ustr = units[::-1][idx] return f"{val} {ustr}" def array_info(x: Union[snp.BlockArray, snp.Array]) -> str: """Return a string providing information about an array. Args: x: A numpy or jax array or scico :class:`BlockArray`. Returns: A string containing information on the array. Raises: TypeError: If the array is not of a recognized type. """ if isinstance(x, np.ndarray): array_type = "numpy.ndarray" elif isinstance(x, jax.Array): array_type = "jax.Array" elif isinstance(x, snp.BlockArray): array_type = "scico.numpy.BlockArray" else: raise TypeError("Unrecognized array type {type(x)}.") totalbytes = np.sum(x.nbytes).item() # type: ignore return ( f"""{array_type} shape: {x.shape} size: {x.size} bytes: {totalbytes} ({_readable_size(totalbytes)}) """ + (f" device: {x.device}\n" if hasattr(x, "device") else "") + f""" dtype: {dtype_name(x.dtype)} id: {id(x)} min, max: {snp.ravel(x).min()}, {snp.ravel(x).max()} """ ) def shape_to_size(shape: Union[Shape, BlockShape]) -> int: r"""Compute array size corresponding to a specified shape. 
Compute array size corresponding to a specified shape, which may be nested, i.e. corresponding to a :class:`BlockArray`. Args: shape: A shape tuple. Returns: The number of elements in an array or :class:`BlockArray` with shape `shape`. """ if is_nested(shape): return sum(prod(s) for s in shape) # type: ignore return prod(shape) # type: ignore def is_array(x: Any) -> bool: """Check if input is of type :class:`jax.Array` or :class:`numpy.ndarray`. Check if input is an array, of type :class:`jax.Array` or :class:`numpy.ndarray`. Args: x: Object to be tested. Returns: ``True`` if `x` is an array, ``False`` otherwise. """ return isinstance(x, (np.ndarray, jax.Array)) def is_arraylike(x: Any) -> bool: """Check if input is of type :class:`jax.typing.ArrayLike`. `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10, see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices. Args: x: Object to be tested. Returns: ``True`` if `x` is an ArrayLike, ``False`` otherwise. """ return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x) def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. Args: x: Object to be tested. Returns: ``True`` if `x` is a list/tuple containing at least one list/tuple, ``False`` otherwise. Example: >>> is_nested([1, 2, 3]) False >>> is_nested([(1,2), (3,)]) True >>> is_nested([[1, 2], 3]) True """ return isinstance(x, (list, tuple)) and any([isinstance(_, (list, tuple)) for _ in x]) def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool: """Determine whether a sequence of shapes can be collapsed. Return ``True`` if the a list of shapes represent arrays that can be stacked, i.e., they are all the same. Args: shapes: A sequence of shapes. Returns: A boolean value indicating whether the shapes are all the same. 
""" return all(s == shapes[0] for s in shapes) def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]: """Determine whether a sequence of shapes could be a :class:`BlockArray` shape. Return ``True`` if the sequence of shapes represent arrays that can be combined into a :class:`BlockArray`, i.e., none are nested. Args: shapes: A sequence of shapes. Returns: A boolean value indicating whether any of the shapes are nested. """ return not any(is_nested(s) for s in shapes) def shape_dtype_rep( shape: Union[Shape, BlockShape], dtype: DType ) -> Union[jax.ShapeDtypeStruct, snp.BlockArray]: """Construct a representation of array or blockarray shape and dtype. Construct a representation of array or block array shape and dtype that is suitable for both jax arrays and scico blockarrays. Args: shape: Array or blockarray shape. dtype: Array or blockarray dtype. Returns: A :class:`jax.ShapeDtypeStruct` or a :class:`.BlockArray` containing :class:`jax.ShapeDtypeStruct`s. """ if is_nested(shape): # block array return snp.BlockArray([jax.ShapeDtypeStruct(blk_shape, dtype=dtype) for blk_shape in shape]) else: # standard array return jax.ShapeDtypeStruct(shape, dtype=dtype) def broadcast_nested_shapes( shape_a: Union[Shape, BlockShape], shape_b: Union[Shape, BlockShape] ) -> Union[Shape, BlockShape]: r"""Compute the result of broadcasting on array shapes. Compute the result of applying a broadcasting binary operator to (block) arrays with (possibly nested) shapes `shape_a` and `shape_b`. Extends :func:`numpy.broadcast_shapes` to also support the nested tuple shapes of :class:`BlockArray`\ s. Args: shape_a: First array shape. shape_b: Second array shape. Returns: A (possibly nested) shape tuple. 
Example: >>> broadcast_nested_shapes(((1, 1, 3), (2, 3, 1)), ((2, 3,), (2, 1, 4))) ((1, 2, 3), (2, 3, 4)) """ if not is_nested(shape_a) and not is_nested(shape_b): return snp.broadcast_shapes(shape_a, shape_b) if is_nested(shape_a) and not is_nested(shape_b): return tuple(snp.broadcast_shapes(s, shape_b) for s in shape_a) if not is_nested(shape_a) and is_nested(shape_b): return tuple(snp.broadcast_shapes(shape_a, s) for s in shape_b) if is_nested(shape_a) and is_nested(shape_b): return tuple(snp.broadcast_shapes(s_a, s_b) for s_a, s_b in zip(shape_a, shape_b)) raise RuntimeError("Unexpected case encountered in broadcast_nested_shapes.") def is_real_dtype(dtype: DType) -> bool: """Determine whether a dtype is real. Args: dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g. :attr:`~numpy.float32`, :attr:`~numpy.complex64`). Returns: ``False`` if the dtype is complex, otherwise ``True``. """ return snp.dtype(dtype).kind != "c" def is_complex_dtype(dtype: DType) -> bool: """Determine whether a dtype is complex. Args: dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g. :attr:`~numpy.float32`, :attr:`~numpy.complex64`). Returns: ``True`` if the dtype is complex, otherwise ``False``. """ return snp.dtype(dtype).kind == "c" def real_dtype(dtype: DType) -> DType: """Construct the corresponding real dtype for a given complex dtype. Construct the corresponding real dtype for a given complex dtype, e.g. the real dtype corresponding to :attr:`~numpy.complex64` is :attr:`~numpy.float32`. Args: dtype: A complex numpy or scico.numpy dtype (e.g. :attr:`~numpy.complex64`, :attr:`~numpy.complex128`). Returns: The real dtype corresponding to the input dtype """ return snp.zeros(1, dtype).real.dtype def complex_dtype(dtype: DType) -> DType: """Construct the corresponding complex dtype for a given real dtype. Construct the corresponding complex dtype for a given real dtype, e.g. the complex dtype corresponding to :attr:`~numpy.float32` is :attr:`~numpy.complex64`. 
Args: dtype: A real numpy or scico.numpy dtype (e.g. :attr:`~numpy.float32`, :attr:`~numpy.float64`). Returns: The complex dtype corresponding to the input dtype. """ return (snp.zeros(1, dtype) + 1j).dtype def dtype_name(dtype: DType) -> str: """Return the name of a dtype. Construct a string representation of a dtype name. Args: dtype: The dtype for which the name is required. Returns: The name of the dtype. """ if type(dtype).__module__ == "numpy.dtypes": return f"""numpy.{dtype.name}""" # type: ignore return f"""{dtype.__module__}.{dtype.__qualname__}""" # type: ignore def is_scalar_equiv(s: Any) -> bool: """Determine whether an object is a scalar or is scalar-equivalent. Determine whether an object is a scalar or a singleton array. Args: s: Object to be tested. Returns: ``True`` if the object is a scalar or a singleton array, otherwise ``False``. """ return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0) ================================================ FILE: scico/operator/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Operator functions and classes.""" import sys # isort: off from ._operator import Operator from .biconvolve import BiConvolve from ._func import operator_from_function, Abs, Angle, Exp from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated __all__ = [ "Operator", "BiConvolve", "DiagonalReplicated", "DiagonalStack", "VerticalStack", "operator_from_function", "Abs", "Angle", "Exp", ] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/operator/_func.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Operators constructed from functions.""" from typing import Any, Callable, Optional, Union import scico.numpy as snp from scico.typing import BlockShape, DType, Shape from ._operator import Operator __all__ = [ "operator_from_function", "Abs", "Angle", "Exp", ] def operator_from_function(f: Callable, classname: str, f_name: Optional[str] = None): """Make an :class:`.Operator` from a function. Example ------- >>> AbsVal = operator_from_function(snp.abs, 'AbsVal') >>> H = AbsVal((2,)) >>> H(snp.array([1.0, -1.0])) Array([1., 1.], dtype=float32) Args: f: Function from which to create an :class:`.Operator`. classname: Name of the resulting class. f_name: Name of `f` for use in docstrings. Useful for getting the correct version of wrapped functions. Defaults to `f"{f.__module__}.{f.__name__}"`. """ if f_name is None: f_name = f"{f.__module__}.{f.__name__}" f_doc = rf""" Args: input_shape: Shape of input array. args: Positional arguments passed to :func:`{f_name}`. input_dtype: `dtype` for input argument. 
Defaults to :attr:`~numpy.float32`. If the :class:`.Operator` implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. output_shape: Shape of output array. Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating `self.__call__` on an input array of zeros. output_dtype: `dtype` for output argument. Defaults to ``None``. If ``None``, `output_dtype` is determined by evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`.Operator.jit` on this `Operator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`.Operator.jit` after the :class:`.Operator` is created. **kwargs: Keyword arguments passed to :func:`{f_name}`. """ def __init__( self, input_shape: Union[Shape, BlockShape], *args: Any, input_dtype: DType = snp.float32, output_shape: Optional[Union[Shape, BlockShape]] = None, output_dtype: Optional[DType] = None, jit: bool = True, **kwargs: Any, ): self._eval = lambda x: f(x, *args, **kwargs) super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore OpClass = type(classname, (Operator,), {"__init__": __init__}) __class__ = OpClass # needed for super() to work OpClass.__doc__ = f"Operator version of :func:`{f_name}`." OpClass.__init__.__doc__ = f_doc # type: ignore return OpClass Abs = operator_from_function(snp.abs, "Abs", "scico.numpy.abs") Angle = operator_from_function(snp.angle, "Angle", "scico.numpy.angle") Exp = operator_from_function(snp.exp, "Exp", "scico.numpy.exp") ================================================ FILE: scico/operator/_operator.py ================================================ # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Operator base class.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from functools import wraps from typing import Callable, Optional, Tuple, Union import numpy as np import jax import jax.numpy as jnp from jax.dtypes import result_type import scico import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import dtype_name, is_nested, shape_to_size from scico.typing import BlockShape, DType, Shape def _wrap_mul_div_scalar(func: Callable) -> Callable: r"""Wrapper function for multiplication and division operators. Wrapper function for defining `__mul__`, `__rmul__`, and `__truediv__` between a scalar and an `Operator`. If one of these binary operations are called in the form `binop(Operator, other)` and 'b' is a scalar, specialized :class:`.Operator` constructors can be called. Args: func: should be either `.__mul__()`, `.__rmul__()`, or `.__truediv__()`. Returns: Wrapped version of `func`. Raises: TypeError: If a binop with the form `binop(Operator, other)` is called and `other` is not a scalar. """ @wraps(func) def wrapper(a, b): if snp.util.is_scalar_equiv(b): return func(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") wrapper._unwrapped = func # type: ignore return wrapper class Operator: """Generic operator class.""" # See https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 def __init__( self, input_shape: Union[Shape, BlockShape], output_shape: Optional[Union[Shape, BlockShape]] = None, eval_fn: Optional[Callable] = None, input_dtype: DType = np.float32, output_dtype: Optional[DType] = None, jit: bool = False, ): r""" Args: input_shape: Shape of input array. output_shape: Shape of output array. 
Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating `self.__call__` on an input array of zeros. eval_fn: Function used in evaluating this :class:`.Operator`. Defaults to ``None``. Required unless `__init__` is being called from a derived class with an `_eval` method. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. If the :class:`.Operator` implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. output_dtype: `dtype` for output argument. Defaults to ``None``. If ``None``, `output_dtype` is determined by evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`Operator.jit()` on this :class:`.Operator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`Operator.jit` after the :class:`.Operator` is created. Raises: NotImplementedError: If the `eval_fn` parameter is not specified and the `_eval` method is not defined in a derived class. """ #: Shape of input array or :class:`.BlockArray`. self.input_shape: Union[Shape, BlockShape] #: Size of flattened input. Sum of product of `input_shape` tuples. self.input_size: int #: Shape of output array or :class:`.BlockArray` self.output_shape: Union[Shape, BlockShape] #: Size of flattened output. Sum of product of `output_shape` tuples. self.output_size: int #: Shape Operator would take if it operated on flattened arrays. #: Consists of (output_size, input_size) self.matrix_shape: Tuple[int, int] #: Shape of Operator, consisting of (output_shape, input_shape). self.shape: Tuple[Union[Shape, BlockShape], Union[Shape, BlockShape]] #: Dtype of input self.input_dtype: DType #: Dtype of operator self.dtype: DType if isinstance(input_shape, int): self.input_shape = (input_shape,) else: self.input_shape = input_shape self.input_dtype = input_dtype # Allows for dynamic creation of new Operator/LinearOperator, e.g. 
for adjoints if eval_fn: self._eval = eval_fn # type: ignore elif not hasattr(self, "_eval"): raise NotImplementedError( "Operator is an abstract base class when argument 'eval_fn' is not specified." ) # If the output shape/dtype aren't specified, they can be inferred # using scico.eval_shape if output_shape is None or output_dtype is None: dts = scico.eval_shape( self._eval, jax.ShapeDtypeStruct(self.input_shape, dtype=input_dtype) ) if output_shape is None: self.output_shape = dts.shape # type: ignore else: self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape if output_dtype is None: self.output_dtype = dts.dtype else: self.output_dtype = output_dtype # Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m} # If the function returns a BlockArray we need to compute the size of each block, # then sum. self.input_size = shape_to_size(self.input_shape) self.output_size = shape_to_size(self.output_shape) self.shape = (self.output_shape, self.input_shape) self.matrix_shape = (self.output_size, self.input_size) if jit: self.jit() def jit(self): """Activate just-in-time compilation for the `_eval` method.""" self._eval = jax.jit(self._eval) def __str__(self): return f"""{self.__module__}.{self.__class__.__qualname__}""" def __repr__(self): return f"""{str(self)} input_shape: {self.input_shape} output_shape: {self.output_shape} input_dtype: {dtype_name(self.input_dtype)} output_dtype: {dtype_name(self.output_dtype)} """ def __call__(self, x: Union[Operator, Array, BlockArray]) -> Union[Operator, Array, BlockArray]: r"""Evaluate this :class:`Operator` at the point :math:`\mb{x}`. Args: x: Point at which to evaluate this :class:`.Operator`. If `x` is a :class:`jax.Array` or :class:`.BlockArray`, it must have `shape == self.input_shape`. If `x` is a :class:`.Operator` or :class:`.LinearOperator`, it must have `x.output_shape == self.input_shape`. Returns: :class:`.Operator` evaluated at `x`. 
Raises: ValueError: If the `input_shape` attribute of the :class:`.Operator` is not equal to the input array shape, or to the `output_shape` attribute of another :class:`.Operator` with which it is composed. """ if isinstance(x, Operator): # Compose the two operators if shapes conform if self.input_shape == x.output_shape: return Operator( input_shape=x.input_shape, output_shape=self.output_shape, eval_fn=lambda z: self(x(z)), input_dtype=self.input_dtype, output_dtype=x.output_dtype, ) raise ValueError(f"Incompatible shapes {self.shape}, {x.shape}.") if self.input_shape != x.shape: raise ValueError( f"Cannot evaluate {type(self)} with input_shape={self.input_shape} " f"on array with shape={x.shape}." ) return self._eval(x) def __add__(self, other: Operator) -> Operator: if isinstance(other, Operator): if self.shape == other.shape: return Operator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) + other(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other.output_dtype), ) raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") raise TypeError(f"Operation __add__ not defined between {type(self)} and {type(other)}.") def __sub__(self, other: Operator) -> Operator: if isinstance(other, Operator): if self.shape == other.shape: return Operator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) - other(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other.output_dtype), ) raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") raise TypeError(f"Operation __sub__ not defined between {type(self)} and {type(other)}.") @_wrap_mul_div_scalar def __mul__(self, other): return Operator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: other * self(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other), ) def __neg__(self) -> Operator: return -1.0 * 
self @_wrap_mul_div_scalar def __rmul__(self, other): return Operator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: other * self(x), input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other), ) @_wrap_mul_div_scalar def __truediv__(self, other): return Operator( input_shape=self.input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(x) / other, input_dtype=self.input_dtype, output_dtype=result_type(self.output_dtype, other), ) def jvp(self, u, v): r"""Compute a Jacobian-vector product. Compute the product :math:`J_F(\mb{u}) \mb{v}` where :math:`F` represents this operator and :math:`J_F(\mb{u})` is the Jacobian of :math:`F` evaluated at :math:`\mb{u}`. This method is implemented via a call to :func:`jax.jvp`. Args: u: Value at which the Jacobian is evaluated. v: Vector in the Jacobian-vector product. Returns: A pair :math:`(F(\mb{u}), J_F(\mb{u}) \mb{v})`, i.e. a pair consisting of the operator evaluated at :math:`\mb{u}` and the Jacobian-vector product. """ return jax.jvp(self, (u,), (v,)) def vjp(self, u, conjugate=True): r"""Compute a vector-Jacobian product. Compute the product :math:`[J_F(\mb{u})]^T \mb{v}` where :math:`F` represents this operator and :math:`J_F(\mb{u})` is the Jacobian of :math:`F` evaluated at :math:`\mb{u}`. Instead of directly computing the vector-Jacobian product, this method returns a function, taking :math:`\mb{v}` as an argument, that returns the product. This method is implemented via a call to :func:`jax.vjp`. Args: u: Value at which the Jacobian is evaluated. conjugate: If ``True``, compute the product using the conjugate (Hermitian) transpose. Returns: A pair :math:`(F(\mb{u}), G(\cdot))` where :math:`G(\cdot)` is a function that computes the vector-Jacobian product, i.e. :math:`G(\mb{v}) = [J_F(\mb{u})]^T \mb{v}` when `conjugate` is ``False``, or :math:`G(\mb{v}) = [J_F(\mb{u})]^H \mb{v}` when `conjugate` is ``True``. 
""" Fu, G = jax.vjp(self, u) if conjugate: def Gmap(v): return G(v.conj())[0].conj() else: def Gmap(v): return G(v)[0] return Fu, Gmap def freeze(self, argnum: int, val: Union[Array, BlockArray]) -> Operator: """Return a new :class:`.Operator` with fixed block argument. Return a new :class:`.Operator` with block argument `argnum` fixed to value `val`. Args: argnum: Index of block to freeze. Must be less than or equal to the number of blocks in an input array. val: Value to fix the `argnum`-th input to. Returns: A new :class:`.Operator` with one of the blocks of the input fixed to the specified value. Raises: ValueError: If the :class:`.Operator` does not take a :class:`.BlockArray` as its input, if the block index equals or exceeds the number of blocks, or if the shape of the fixed value differs from the shape of the specified block. """ if not is_nested(self.input_shape): raise ValueError( "The freeze method can only be applied to Operators that take BlockArray inputs." ) input_ndim = len(self.input_shape) if argnum > input_ndim - 1: raise ValueError( f"Argument 'argnum' must be fewer than the number of input arguments to " f"this operator ({input_ndim}); got {argnum}." ) if val.shape != self.input_shape[argnum]: raise ValueError( f"Value to be frozen at position {argnum} must have shape " f"{self.input_shape[argnum]}, got {val.shape}." ) input_shape: Union[Shape, BlockShape] input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) # type: ignore if len(input_shape) == 1: input_shape = input_shape[0] # type: ignore def concat_args(args): # Create a blockarray with args and the frozen value in the correct place # E.g. 
if this operator takes a blockarray with two blocks, then # concat_args(args) = snp.blockarray([val, args]) if argnum = 0 # concat_args(args) = snp.blockarray([args, val]) if argnum = 1 if isinstance(args, (jnp.ndarray, np.ndarray)): # In the case that the original operator takes a blockarray with two # blocks, wrap in a list so we can use the same indexing as >2 block case args = [args] arg_list = [] for i in range(input_ndim): if i < argnum: arg_list.append(args[i]) elif i > argnum: arg_list.append(args[i - 1]) else: arg_list.append(val) return snp.blockarray(arg_list) return Operator( input_shape=input_shape, output_shape=self.output_shape, eval_fn=lambda x: self(concat_args(x)), ) ================================================ FILE: scico/operator/_stack.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Stack of operators classes.""" from __future__ import annotations from typing import Optional, Sequence, Tuple, Union import numpy as np import jax import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import is_blockable, is_collapsible, is_nested from scico.typing import BlockShape, Shape from ._operator import Operator def collapse_shapes( shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True ) -> Tuple[Union[Shape, BlockShape], bool]: """Compute the collapsed representation of a sequence of shapes. 
Decide whether to collapse a sequence of shapes, returning either the sequence of shapes or a collapsed shape, and a boolean indicating whether the shape was collapsed.""" if is_collapsible(shapes) and allow_collapse: return (len(shapes), *shapes[0]), True if is_blockable(shapes): return shapes, False raise ValueError( "Combining these shapes would result in a twice-nested BlockArray, which is not supported." ) class VerticalStack(Operator): r"""A vertical stack of operators. Given operators :math:`A_1, A_2, \dots, A_N`, create the operator :math:`H` such that .. math:: H(\mb{x}) = \begin{pmatrix} A_1(\mb{x}) \\ A_2(\mb{x}) \\ \vdots \\ A_N(\mb{x}) \\ \end{pmatrix} \;. """ def __init__( self, ops: Sequence[Operator], collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): r""" Args: ops: Operators to stack. collapse_output: If ``True`` and the output would be a :class:`BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a :class:`jax.Array` with shape (S, m, n, ...) where S is the length of `ops`. jit: See `jit` in :class:`Operator`. 
""" VerticalStack.check_if_stackable(ops) self.ops = ops self.collapse_output = collapse_output output_shapes = tuple(op.output_shape for op in ops) self.output_collapsible = is_collapsible(output_shapes) if self.output_collapsible and self.collapse_output: output_shape = (len(ops),) + output_shapes[0] # collapse to jax array else: output_shape = output_shapes super().__init__( input_shape=ops[0].input_shape, output_shape=output_shape, # type: ignore input_dtype=ops[0].input_dtype, output_dtype=ops[0].output_dtype, jit=jit, **kwargs, ) @staticmethod def check_if_stackable(ops: Sequence[Operator]): """Check that input ops are suitable for stack creation.""" if not isinstance(ops, (list, tuple)): raise TypeError("Expected a list of Operator.") input_shapes = [op.shape[1] for op in ops] if not all(input_shapes[0] == s for s in input_shapes): raise ValueError( "Expected all Operators to have the same input shapes, " f"but got {input_shapes}." ) input_dtypes = [op.input_dtype for op in ops] if not all(input_dtypes[0] == s for s in input_dtypes): raise ValueError( "Expected all Operators to have the same input dtype, " f"but got {input_dtypes}." ) if any([is_nested(op.shape[0]) for op in ops]): raise ValueError("Cannot stack Operators with nested output shapes.") output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): raise ValueError("Expected all Operators to have the same output dtype.") def _eval(self, x: Array) -> Union[Array, BlockArray]: if self.output_collapsible and self.collapse_output: return snp.stack([op(x) for op in self.ops]) return BlockArray([op(x) for op in self.ops]) def __repr__(self): crepr = ", ".join([str(f) for f in self.ops]) return Operator.__repr__(self) + f""" components: {crepr}\n""" class DiagonalStack(Operator): r"""A diagonal stack of operators. Given operators :math:`A_1, A_2, \dots, A_N`, create the operator :math:`H` such that .. 
math:: H \left( \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \\ \end{pmatrix} \right) = \begin{pmatrix} A_1(\mb{x}_1) \\ A_2(\mb{x}_2) \\ \vdots \\ A_N(\mb{x}_N) \\ \end{pmatrix} \;. By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots, \mb{x}_N` all have the same (possibly nested) shape, `S`, this operator will work on the stack, i.e., have an input shape of `(N, *S)`. If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`, this operator will work on the block concatenation, i.e., have an input shape of `(S1, S2, ..., SN)`. The same holds for the output shape. """ def __init__( self, ops: Sequence[Operator], collapse_input: Optional[bool] = True, collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): """ Args: ops: Operators to stack. collapse_input: If ``True``, inputs are expected to be stacked along the first dimension when possible. collapse_output: If ``True``, the output will be stacked along the first dimension when possible. jit: See `jit` in :class:`Operator`. 
""" DiagonalStack.check_if_stackable(ops) self.ops = ops input_shape, self.collapse_input = collapse_shapes( tuple(op.input_shape for op in ops), collapse_input, ) output_shape, self.collapse_output = collapse_shapes( tuple(op.output_shape for op in ops), collapse_output, ) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=ops[0].input_dtype, output_dtype=ops[0].output_dtype, jit=jit, **kwargs, ) @staticmethod def check_if_stackable(ops: Sequence[Operator]): """Check that input ops are suitable for stack creation.""" if not isinstance(ops, (list, tuple)): raise TypeError("Expected a list of Operator.") if any([is_nested(op.shape[0]) for op in ops]): raise ValueError("Cannot stack Operators with nested output shapes.") output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): raise ValueError("Expected all Operators to have the same output dtype.") def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: result = tuple(op(x_n) for op, x_n in zip(self.ops, x)) if self.collapse_output: return snp.stack(result) return snp.blockarray(result) def __repr__(self): crepr = ", ".join([str(f) for f in self.ops]) return Operator.__repr__(self) + f""" components: {crepr}\n""" class DiagonalReplicated(Operator): r"""A diagonal stack constructed from a single operator. Given operator :math:`A`, create the operator :math:`H` such that .. math:: H \left( \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \\ \end{pmatrix} \right) = \begin{pmatrix} A(\mb{x}_1) \\ A(\mb{x}_2) \\ \vdots \\ A(\mb{x}_N) \\ \end{pmatrix} \;. The application of :math:`A` to each component :math:`\mb{x}_k` is computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape for operator :math:`A` should exclude the array axis on which :math:`A` is replicated to form :math:`H`. 
For example, if :math:`A` has input shape `(3, 4)` and :math:`H` is constructed to replicate on axis 0 with 2 replicates, the input shape of :math:`H` will be `(2, 3, 4)`. Operators taking :class:`.BlockArray` input are not supported. """ def __init__( self, op: Operator, replicates: int, input_axis: int = 0, output_axis: Optional[int] = None, map_type: str = "auto", **kwargs, ): """ Args: op: Operator to replicate. replicates: Number of replicates of `op`. input_axis: Input axis over which `op` should be replicated. output_axis: Index of replication axis in output array. If ``None``, the input replication axis is used. map_type: If "pmap" or "vmap", apply replicated mapping using :func:`jax.pmap` or :func:`jax.vmap` respectively. If "auto", use :func:`jax.pmap` if sufficient devices are available for the number of replicates, otherwise use :func:`jax.vmap`. """ if map_type not in ["auto", "pmap", "vmap"]: raise ValueError("Argument 'map_type' must be one of 'auto', 'pmap, or 'vmap'.") if input_axis < 0: input_axis = len(op.input_shape) + 1 + input_axis if input_axis < 0 or input_axis > len(op.input_shape): raise ValueError( "Argument 'input_axis' must be positive and less than the number of axes " "in the input shape of argument 'op'." ) if is_nested(op.input_shape): raise ValueError("Argument 'op' may not be an Operator taking BlockArray input.") if is_nested(op.output_shape): raise ValueError("Argument 'op' may not be an Operator with BlockArray output.") self.op = op self.replicates = replicates self.input_axis = input_axis self.output_axis = self.input_axis if output_axis is None else output_axis if map_type == "auto": self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap else: if map_type == "pmap" and replicates > jax.device_count(): raise ValueError( "Requested pmap mapping but number of replicates exceeds device count." 
) else: self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis) input_shape = ( op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :] ) output_shape = ( op.output_shape[0 : self.output_axis] + (replicates,) + op.output_shape[self.output_axis :] ) super().__init__( input_shape=input_shape, # type: ignore output_shape=output_shape, # type: ignore eval_fn=eval_fn, input_dtype=op.input_dtype, output_dtype=op.output_dtype, jit=False, **kwargs, ) def __repr__(self): return ( Operator.__repr__(self) + f""" component: {str(self.op)}\n replicates: {self.replicates}\n""" ) ================================================ FILE: scico/operator/biconvolve.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Biconvolution operator.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Tuple, cast import numpy as np from jax.scipy.signal import convolve import scico.linop from scico.numpy import Array, BlockArray from scico.numpy.util import is_nested from scico.typing import DType, Shape from ._operator import Operator class BiConvolve(Operator): """Biconvolution operator. A :class:`.BiConvolve` operator accepts a :class:`.BlockArray` input with two blocks of equal ndims, and convolves the first block with the second. If `A` is a :class:`.BiConvolve` operator, then `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. 
""" def __init__( self, input_shape: Tuple[Shape, Shape], input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, ): r""" Args: input_shape: Shape of input :class:`.BlockArray`. Must correspond to a :class:`.`BlockArray` with two blocks of equal ndims. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. mode: A string indicating the size of the output. One of "full", "valid", "same". Defaults to "full". jit: If ``True``, jit the evaluation of this :class:`.Operator`. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ if not is_nested(input_shape): raise ValueError("A BlockShape is expected; got {input_shape}.") if len(input_shape) != 2: raise ValueError( f"Argument 'input_shape' must have two blocks; got {len(input_shape)}." ) if len(input_shape[0]) != len(input_shape[1]): raise ValueError( f"Both input blocks must have same number of dimensions; got " f"{len(input_shape[0]), len(input_shape[1])}." ) if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.") self.mode = mode super().__init__(input_shape=input_shape, input_dtype=input_dtype, jit=jit) def _eval(self, x: BlockArray) -> Array: return convolve(x[0], x[1], mode=self.mode) def freeze(self, argnum: int, val: Array) -> scico.linop.LinearOperator: """Freeze the `argnum` parameter. Return a new :class:`.LinearOperator` with block argument `argnum` fixed to value `val`. If `argnum == 0`, a :class:`.ConvolveByX` object is returned. If `argnum == 1`, a :class:`.Convolve` object is returned. Args: argnum: Index of block to freeze. Must be 0 or 1. val: Value to fix the `argnum`-th input to. 
""" if argnum == 0: return scico.linop.ConvolveByX( x=val, input_shape=cast(Shape, self.input_shape[1]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, ) if argnum == 1: return scico.linop.Convolve( h=val, input_shape=cast(Shape, self.input_shape[0]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, ) raise ValueError(f"Argument 'argnum' must be 0 or 1; got {argnum}.") ================================================ FILE: scico/optimize/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Optimization algorithms.""" import sys # isort: off from .admm import ADMM from ._common import Optimizer from ._ladmm import LinearizedADMM from .pgm import PGM, AcceleratedPGM from ._primaldual import PDHG from ._padmm import ProximalADMM, NonLinearPADMM, ProximalADMMBase __all__ = [ "ADMM", "LinearizedADMM", "ProximalADMM", "ProximalADMMBase", "NonLinearPADMM", "PGM", "AcceleratedPGM", "PDHG", "Optimizer", ] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/optimize/_admm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""ADMM solver.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import List, Optional, Tuple, Union import scico.numpy as snp from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm from ._admmaux import ( FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver, GenericSubproblemSolver, LinearSubproblemSolver, MatrixSubproblemSolver, SubproblemSolver, ) from ._common import Optimizer class ADMM(Optimizer): r"""Basic Alternating Direction Method of Multipliers (ADMM) algorithm. | Solve an optimization problem of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;, where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`, and the :math:`C_i` are :class:`.LinearOperator`. The optimization problem is solved by introducing the splitting :math:`\mb{z}_i = C_i \mb{x}` and solving .. math:: \argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \; \text{such that}\; C_i \mb{x} = \mb{z}_i \;, via an ADMM algorithm :cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual` :cite:`boyd-2010-distributed` consisting of the iterations (see :meth:`step`) .. math:: \begin{aligned} \mb{x}^{(k+1)} &= \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \\ \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2} \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\ \mb{u}_i^{(k+1)} &= \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \; . \end{aligned} Attributes: f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`) g_list (list of :class:`.Functional`): List of :math:`g_i` functionals. Must be same length as :code:`C_list` and :code:`rho_list`. 
C_list (list of :class:`.LinearOperator`): List of :math:`C_i` operators. rho_list (list of scalars): List of :math:`\rho_i` penalty parameters. Must be same length as :code:`C_list` and :code:`g_list`. alpha (float): Relaxation parameter. u_list (list of array-like): List of scaled Lagrange multipliers :math:`\mb{u}_i` at current iteration. x (array-like): Solution. subproblem_solver (:class:`.SubproblemSolver`): Solver for :math:`\mb{x}`-update step. z_list (list of array-like): List of auxiliary variables :math:`\mb{z}_i` at current iteration. z_list_old (list of array-like): List of auxiliary variables :math:`\mb{z}_i` at previous iteration. """ def __init__( self, f: Functional, g_list: List[Functional], C_list: List[LinearOperator], rho_list: List[float], alpha: float = 1.0, x0: Optional[Union[Array, BlockArray]] = None, subproblem_solver: Optional[SubproblemSolver] = None, **kwargs, ): r"""Initialize an :class:`ADMM` object. Args: f: Functional :math:`f` (usually a loss function). g_list: List of :math:`g_i` functionals. Must be same length as :code:`C_list` and :code:`rho_list`. C_list: List of :math:`C_i` operators. rho_list: List of :math:`\rho_i` penalty parameters. Must be same length as :code:`C_list` and :code:`g_list`. alpha: Relaxation parameter. No relaxation for default 1.0. x0: Initial value for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. subproblem_solver: Solver for :math:`\mb{x}`-update step. Defaults to ``None``, which implies use of an instance of :class:`GenericSubproblemSolver`. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. 
""" N = len(g_list) if len(C_list) != N: raise ValueError(f"len(C_list)={len(C_list)} not equal to len(g_list)={N}.") if len(rho_list) != N: raise ValueError(f"len(rho_list)={len(rho_list)} not equal to len(g_list)={N}.") self.f: Functional = f self.g_list: List[Functional] = g_list self.C_list: List[LinearOperator] = C_list self.rho_list: List[float] = rho_list self.alpha: float = alpha if subproblem_solver is None: subproblem_solver = GenericSubproblemSolver() self.subproblem_solver: SubproblemSolver = subproblem_solver self.subproblem_solver.internal_init(self) if x0 is None: input_shape = C_list[0].input_shape dtype = C_list[0].input_dtype x0 = snp.zeros(input_shape, dtype=dtype) self.x = x0 self.z_list, self.z_list_old = self.z_init(self.x) self.u_list = self.u_init(self.x) super().__init__(**kwargs) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. 
""" for v in ( [ self.x, ] + self.z_list + self.u_list ): if not snp.all(snp.isfinite(v)): return False return True def _objective_evaluatable(self): """Determine whether the objective function can be evaluated.""" return (not self.f or self.f.has_eval) and all([_.has_eval for _ in self.g_list]) def _itstat_extra_fields(self): """Define ADMM-specific iteration statistics fields.""" itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"} itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"] # subproblem solver info when available if isinstance(self.subproblem_solver, GenericSubproblemSolver): itstat_fields.update({"Num FEv": "%6d", "Num It": "%6d"}) itstat_attrib.extend( ["subproblem_solver.info['nfev']", "subproblem_solver.info['nit']"] ) elif ( type(self.subproblem_solver) == LinearSubproblemSolver and self.subproblem_solver.cg_function == "scico" ): itstat_fields.update({"CG It": "%5d", "CG Res": "%9.3e"}) itstat_attrib.extend( ["subproblem_solver.info['num_iter']", "subproblem_solver.info['rel_res']"] ) elif ( type(self.subproblem_solver) in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver] and self.subproblem_solver.check_solve ): itstat_fields.update({"Slv Res": "%9.3e"}) itstat_attrib.extend(["subproblem_solver.accuracy"]) return itstat_fields, itstat_attrib def _state_variable_names(self) -> List[str]: # While x is in the most abstract sense not part of the algorithm # state, it does form part of the state in pratice due to its use # as an initializer for iterative solvers for the x step of the # ADMM algorithm. return ["x", "z_list", "z_list_old", "u_list"] def minimizer(self) -> Union[Array, BlockArray]: return self.x def objective( self, x: Optional[Union[Array, BlockArray]] = None, z_list: Optional[List[Union[Array, BlockArray]]] = None, ) -> float: r"""Evaluate the objective function. Evaluate the objective function .. math:: f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;. 
Note that this form is cheaper to compute, but may have very poor accuracy compared with the "true" objective function .. math:: f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;. when the primal residual is large. Args: x: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.x`. z_list: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.z_list`. Returns: Value of the objective function. """ if (x is None) != (z_list is None): raise ValueError("Both or neither of arguments 'x' and 'z_list' must be supplied.") if x is None: x = self.x z_list = self.z_list assert z_list is not None out = 0.0 if self.f: out += self.f(x) for g, z in zip(self.g_list, z_list): out += g(z) return out def norm_primal_residual(self, x: Optional[Union[Array, BlockArray]] = None) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \left( \sum_{i=1}^N \rho_i \left\| C_i \mb{x} - \mb{z}_i^{(k)} \right\|_2^2\right)^{1/2} \;. Args: x: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.x`. Returns: Norm of primal residual. """ if x is None: x = self.x sum = 0.0 for rhoi, Ci, zi in zip(self.rho_list, self.C_list, self.z_list): sum += rhoi * norm(Ci(self.x) - zi) ** 2 return snp.sqrt(sum) def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual .. math:: \left\| \sum_{i=1}^N \rho_i C_i^T \left( \mb{z}^{(k)}_i - \mb{z}^{(k-1)}_i \right) \right\|_2 \;. Returns: Norm of dual residual. 
""" sum = 0.0 for rhoi, zi, ziold, Ci in zip(self.rho_list, self.z_list, self.z_list_old, self.C_list): sum += rhoi * Ci.adj(zi - ziold) return norm(sum) def z_init( self, x0: Union[Array, BlockArray] ) -> Tuple[List[Union[Array, BlockArray]], List[Union[Array, BlockArray]]]: r"""Initialize auxiliary variables :math:`\mb{z}_i`. Initialized to .. math:: \mb{z}_i = C_i \mb{x}^{(0)} \;. :code:`z_list` and :code:`z_list_old` are initialized to the same value. Args: x0: Initial value of :math:`\mb{x}`. """ z_list: List[Union[Array, BlockArray]] = [Ci(x0) for Ci in self.C_list] z_list_old = z_list.copy() return z_list, z_list_old def u_init(self, x0: Union[Array, BlockArray]) -> List[Union[Array, BlockArray]]: r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`. Initialized to .. math:: \mb{u}_i = \mb{0} \;. Note that the parameter `x0` is unused, but is provided for potential use in an overridden method. Args: x0: Initial value of :math:`\mb{x}`. """ u_list = [snp.zeros(Ci.output_shape, dtype=Ci.output_dtype) for Ci in self.C_list] return u_list def step(self): r"""Perform a single ADMM iteration. The primary variable :math:`\mb{x}` is updated by solving the the optimization problem .. math:: \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;. Update auxiliary variables :math:`\mb{z}_i` and scaled Lagrange multipliers :math:`\mb{u}_i`. The auxiliary variables are updated according to .. math:: \begin{aligned} \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2} \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\ &= \mathrm{prox}_{g_i}(C_i \mb{x} + \mb{u}_i, 1 / \rho_i) \;, \end{aligned} and the scaled Lagrange multipliers are updated according to .. math:: \mb{u}_i^{(k+1)} = \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \;. 
""" self.x = self.subproblem_solver.solve(self.x) self.z_list_old = self.z_list.copy() for i, (rhoi, gi, Ci, zi, ui) in enumerate( zip(self.rho_list, self.g_list, self.C_list, self.z_list, self.u_list) ): if self.alpha == 1.0: Cix = Ci(self.x) else: Cix = self.alpha * Ci(self.x) + (1.0 - self.alpha) * zi zi = gi.prox(Cix + ui, 1 / rhoi, v0=zi) ui = ui + Cix - zi self.z_list[i] = zi self.u_list[i] = ui ================================================ FILE: scico/optimize/_admmaux.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """ADMM auxiliary classes.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from functools import reduce from typing import Any, Optional, Union import jax from jax.scipy.sparse.linalg import cg as jax_cg import scico.numpy as snp import scico.optimize.admm as soa from scico.functional import ZeroFunctional from scico.linop import ( CircularConvolve, ComposedLinearOperator, Diagonal, Identity, LinearOperator, MatrixOperator, ) from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray from scico.numpy.util import is_real_dtype from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize class SubproblemSolver: r"""Base class for solvers for the non-separable ADMM step. 
The ADMM solver implemented by :class:`.ADMM` addresses a general problem form for which one of the corresponding ADMM algorithm subproblems is separable into distinct subproblems for each of the :math:`g_i`, and another that is non-separable, involving function :math:`f` and a sum over :math:`\ell_2` norm terms involving all operators :math:`C_i`. This class is a base class for solvers of the latter subproblem .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. """ def internal_init(self, admm: soa.ADMM): """Second stage initializer to be called by :meth:`.ADMM.__init__`. Args: admm: Reference to :class:`.ADMM` object to which the :class:`.SubproblemSolver` object is to be attached. """ self.admm = admm class GenericSubproblemSolver(SubproblemSolver): """Solver for generic problem without special structure. Note that this solver is only suitable for small-scale problems. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. minimize_kwargs (dict): Dictionary of arguments for :func:`scico.solver.minimize`. """ def __init__(self, minimize_kwargs: dict = {"options": {"maxiter": 100}}): """Initialize a :class:`GenericSubproblemSolver` object. Args: minimize_kwargs: Dictionary of arguments for :func:`scico.solver.minimize`. """ self.minimize_kwargs = minimize_kwargs self.info: dict = {} def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. Args: x0: Initial value. Returns: Computed solution. 
""" @jax.jit def obj(x): out = 0.0 for rhoi, Ci, zi, ui in zip( self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list ): out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2) if self.admm.f is not None: out += self.admm.f(x) return out res = minimize(obj, x0, **self.minimize_kwargs) for attrib in ("success", "status", "message", "nfev", "njev", "nhev", "nit", "maxcv"): self.info[attrib] = getattr(res, attrib, None) return res.x class LinearSubproblemSolver(SubproblemSolver): r"""Solver for quadratic functionals. Solver for the case in which :code:`f` is a quadratic function of :math:`\mb{x}`. It is a specialization of :class:`.SubproblemSolver` for the case where :code:`f` is an :math:`\ell_2` or weighted :math:`\ell_2` norm, and :code:`f.A` is a linear operator, so that the subproblem involves solving a linear equation. This requires that :code:`f.functional` be an instance of :class:`.SquaredL2Loss` and for the forward operator :code:`f.A` to be an instance of :class:`.LinearOperator`. The :math:`\mb{x}`-update step is .. math:: \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A \mb{x}}_W^2 + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, where :math:`W` a weighting :class:`.Diagonal` operator or an :class:`.Identity` operator (i.e., no weighting). This update step reduces to the solution of the linear system .. math:: \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right) \mb{x}^{(k+1)} = \; A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. cg_kwargs (dict): Dictionary of arguments for CG solver. cg (func): CG solver function (:func:`scico.solver.cg` or :func:`jax.scipy.sparse.linalg.cg`) lhs (type): Function implementing the linear operator needed for the :math:`\mb{x}` update step. 
""" def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str = "scico"): """Initialize a :class:`LinearSubproblemSolver` object. Args: cg_kwargs: Dictionary of arguments for CG solver. See documentation for :func:`scico.solver.cg` or :func:`jax.scipy.sparse.linalg.cg`, including how to specify a preconditioner. Default values are the same as those of :func:`scico.solver.cg`, except for `"tol": 1e-4` and `"maxiter": 100`. cg_function: String indicating which CG implementation to use. One of "jax" or "scico"; default "scico". If "scico", uses :func:`scico.solver.cg`. If "jax", uses :func:`jax.scipy.sparse.linalg.cg`. The "jax" option is slower on small-scale problems or problems involving external functions, but can be differentiated through. The "scico" option is faster on small-scale problems, but slower on large-scale problems where the forward operator is written entirely in jax. """ default_cg_kwargs = {"tol": 1e-4, "maxiter": 100} if cg_kwargs: default_cg_kwargs.update(cg_kwargs) self.cg_kwargs = default_cg_kwargs self.cg_function = cg_function if cg_function == "scico": self.cg = scico_cg elif cg_function == "jax": self.cg = jax_cg else: raise ValueError( f"Argument 'cg_function' must be one of 'jax', 'scico'; got {cg_function}." ) self.info = None def internal_init(self, admm: soa.ADMM): if admm.f is not None: if not isinstance(admm.f, SquaredL2Loss): raise TypeError( "LinearSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, LinearOperator): raise TypeError( "LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " f"got {type(admm.f.A)}." 
) super().internal_init(admm) # call method of SubproblemSolver via GenericSubproblemSolver # Set lhs_op = \sum_i rho_i * Ci.H @ Ci # Use reduce as the initialization of this sum is messy otherwise lhs_op = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] ) if admm.f is not None: # hessian = A.T @ W @ A; W may be identity lhs_op += admm.f.hessian self.lhs_op = lhs_op def compute_rhs(self) -> Union[Array, BlockArray]: r"""Compute the right hand side of the linear equation to be solved. Compute .. math:: A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. Returns: Computed solution. """ C0 = self.admm.C_list[0] rhs = snp.zeros(C0.input_shape, C0.input_dtype) if self.admm.f is not None: ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y) # type: ignore rhs += 2.0 * self.admm.f.scale * ATWy # type: ignore for rhoi, Ci, zi, ui in zip( self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list ): rhs += rhoi * Ci.adj(zi - ui) return rhs def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. Args: x0: Initial value. Returns: Computed solution. """ rhs = self.compute_rhs() x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) # type: ignore return x class MatrixSubproblemSolver(LinearSubproblemSolver): r"""Solver for quadratic functionals involving matrix operators. Solver for the case in which :math:`f` is a quadratic function of :math:`\mb{x}`, and :math:`A` and all the :math:`C_i` are diagonal or matrix operators. It is a specialization of :class:`.LinearSubproblemSolver`. As for :class:`.LinearSubproblemSolver`, the :math:`\mb{x}`-update step is .. 
math:: \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A \mb{x}}_W^2 + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, where :math:`W` is a weighting :class:`.Diagonal` operator or an :class:`.Identity` operator (i.e., no weighting). This update step reduces to the solution of the linear system .. math:: \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right) \mb{x}^{(k+1)} = \; A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;, which is solved by factorization of the left hand side of the equation, using :class:`.MatrixATADSolver`. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. solve_kwargs (dict): Dictionary of arguments for solver :class:`.MatrixATADSolver` initialization. """ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None): """Initialize a :class:`MatrixSubproblemSolver` object. Args: check_solve: If ``True``, compute solver accuracy after each solve. solve_kwargs: Dictionary of arguments for solver :class:`.MatrixATADSolver` initialization. """ self.check_solve = check_solve default_solve_kwargs = {"cho_factor": False} if solve_kwargs: default_solve_kwargs.update(solve_kwargs) self.solve_kwargs = default_solve_kwargs def internal_init(self, admm: soa.ADMM): if admm.f is not None: if not isinstance(admm.f, SquaredL2Loss): raise TypeError( "MatrixSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, (Diagonal, MatrixOperator)): raise TypeError( "MatrixSubproblemSolver requires f.A to be a Diagonal or MatrixOperator; " f"got {type(admm.f.A)}." ) for i, Ci in enumerate(admm.C_list): if not isinstance(Ci, (Diagonal, MatrixOperator)): raise TypeError( "MatrixSubproblemSolver requires C[{i}] to be a Diagonal or MatrixOperator; " f"got {type(Ci)}." 
) SubproblemSolver.internal_init(self, admm) if admm.f is None: A = snp.zeros(admm.C_list[0].input_shape[0], dtype=admm.C_list[0].input_dtype) W = None else: A = admm.f.A W = 2.0 * self.admm.f.scale * admm.f.W # type: ignore Csum = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] ) self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs) def solve(self, x0: Array) -> Array: """Solve the ADMM step. Args: x0: Initial value (ignored). Returns: Computed solution. """ rhs = self.compute_rhs() x = self.solver.solve(rhs) if self.check_solve: self.accuracy = self.solver.accuracy(x, rhs) return x class CircularConvolveSolver(LinearSubproblemSolver): r"""Solver for linear operators diagonalized in the DFT domain. Specialization of :class:`.LinearSubproblemSolver` for the case where :code:`f` is ``None``, or an instance of :class:`.SquaredL2Loss` with a forward operator :code:`f.A` that is either an instance of :class:`.Identity` or :class:`.CircularConvolve`, and the :code:`C_i` are all shift invariant linear operators, examples of which include instances of :class:`.Identity` as well as some instances (depending on initializer parameters) of :class:`.CircularConvolve` and :class:`.FiniteDifference`. None of the instances of :class:`.CircularConvolve` may sum over any of their axes. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. lhs_f (array): Left hand side, in the DFT domain, of the linear equation to be solved. """ def __init__(self, ndims: Optional[int] = None): """Initialize a :class:`CircularConvolveSolver` object. Args: ndims: Number of trailing dimensions of the input and kernel involved in the :class:`.CircularConvolve` convolutions. In most cases this value is automatically determined from the optimization problem specification, but this is not possible when :code:`f` is ``None`` and none of the :code:`C_i` are of type :class:`.CircularConvolve`. 
When not ``None``, this parameter overrides the automatic mechanism. """ self.ndims = ndims def internal_init(self, admm: soa.ADMM): if admm.f is None: is_cc = [isinstance(C, CircularConvolve) for C in admm.C_list] if any(is_cc): auto_ndims = admm.C_list[is_cc.index(True)].ndims else: auto_ndims = None else: if not isinstance(admm.f, SquaredL2Loss): raise TypeError( "CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, (CircularConvolve, Identity)): raise TypeError( "CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve " f"or scico.linop.Identity; got {type(admm.f.A)}." ) auto_ndims = admm.f.A.ndims if isinstance(admm.f.A, CircularConvolve) else None if self.ndims is None: self.ndims = auto_ndims SubproblemSolver.internal_init(self, admm) self.real_result = is_real_dtype(admm.C_list[0].input_dtype) # All of the C operators are assumed to be linear and shift invariant # but this is not checked. lhs_op_list = [ rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list, admm.C_list) ] A_lhs = reduce(lambda a, b: a + b, lhs_op_list) if self.admm.f is not None: A_lhs += ( 2.0 * admm.f.scale * CircularConvolve.from_operator(admm.f.A.gram_op, ndims=self.ndims) ) self.A_lhs = A_lhs def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. Args: x0: Initial value (unused, has no effect). Returns: Computed solution. """ rhs = self.compute_rhs() rhs_dft = snp.fft.fftn(rhs, axes=self.A_lhs.x_fft_axes) x_dft = rhs_dft / self.A_lhs.h_dft x = snp.fft.ifftn(x_dft, axes=self.A_lhs.x_fft_axes) if self.real_result: x = x.real return x class FBlockCircularConvolveSolver(LinearSubproblemSolver): r"""Solver for linear operators block-diagonalized in the DFT domain. 
Specialization of :class:`.LinearSubproblemSolver` for the case where :code:`f` is an instance of :class:`.SquaredL2Loss`, the forward operator :code:`f.A` is a composition of a :class:`.Sum` operator and a :class:`.CircularConvolve` operator. The former must sum over the first axis of its input, and the latter must be initialized so that it convolves a set of filters, indexed by the first axis, with an input array that has the same number of axes as the filter array, and has an initial axis of the same length as that of the filter array. The :math:`C_i` must all be shift invariant linear operators, examples of which include instances of :class:`.Identity` as well as some instances (depending on initializer parameters) of :class:`.CircularConvolve` and :class:`.FiniteDifference`. None of the instances of :class:`.CircularConvolve` may be summed over any of their axes. The solver is based on the frequency-domain approach proposed in :cite:`wohlberg-2014-efficient`. We have :math:`f = \omega \norm{A \mb{x} - \mb{y}}_2^2`, where typically :math:`\omega = 1/2`, and :math:`A` is a block-row operator with circulant blocks, i.e. it can be written as .. math:: A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K} \end{array} \right) \;, where all of the :math:`A_k` are circular convolution operators. The complete functional to be minimized is .. math:: \omega \norm{A \mb{x} - \mb{y}}_2^2 + \sum_{i=1}^N g_i(C_i \mb{x}) \;, where the :math:`C_i` are either identity or circular convolutions, and the ADMM x-step is .. math:: \mb{x}^{(j+1)} = \argmin_{\mb{x}} \; \omega \norm{A \mb{x} - \mb{y}}_2^2 + \sum_i \frac{\rho_i}{2} \norm{C_i \mb{x} - (\mb{z}^{(j)}_i - \mb{u}^{(j)}_i)}_2^2 \;. This subproblem is most easily solved in the DFT transform domain, where the circular convolutions become diagonal operators. Denoting the frequency-domain versions of variables with a circumflex (e.g. 
:math:`\hat{\mb{x}}` is the frequency-domain version of :math:`\mb{x}`), the solution of the subproblem can be written as .. math:: \left( \hat{A}^H \hat{A} + \frac{1}{2 \omega} \sum_i \rho_i \hat{C}_i^H \hat{C}_i \right) \hat{\mathbf{x}} = \hat{A}^H \hat{\mb{y}} + \frac{1}{2 \omega} \sum_i \rho_i \hat{C}_i^H (\hat{\mb{z}}_i - \hat{\mb{u}}_i) \;. This linear equation is computational expensive to solve because the left hand side includes the term :math:`\hat{A}^H \hat{A}`, which corresponds to the outer product of :math:`\hat{A}^H` and :math:`\hat{A}`. A computationally efficient solution is possible, however, by exploiting the Woodbury matrix identity .. math:: (D + U G V)^{-1} = D^{-1} - D^{-1} U (G^{-1} + V D^{-1} U)^{-1} V D^{-1} \;. Setting .. math:: D &= \frac{1}{2 \omega} \sum_i \rho_i \hat{C}_i^H \hat{C}_i \\ U &= \hat{A}^H \\ G &= I \\ V &= \hat{A} we have .. math:: (D + \hat{A}^H \hat{A})^{-1} = D^{-1} - D^{-1} \hat{A}^H (I + \hat{A} D^{-1} \hat{A}^H)^{-1} \hat{A} D^{-1} which can be simplified to .. math:: (D + \hat{A}^H \hat{A})^{-1} = D^{-1} (I - \hat{A}^H E^{-1} \hat{A} D^{-1}) by defining :math:`E = I + \hat{A} D^{-1} \hat{A}^H`. The right hand side is much cheaper to compute because the only matrix inversions involve :math:`D`, which is diagonal, and :math:`E`, which is a weighted inner product of :math:`\hat{A}^H` and :math:`\hat{A}`. """ def __init__(self, ndims: Optional[int] = None, check_solve: bool = False): """Initialize a :class:`FBlockCircularConvolveSolver` object. Args: check_solve: If ``True``, compute solver accuracy after each solve. """ self.ndims = ndims self.check_solve = check_solve self.accuracy: Optional[float] = None def internal_init(self, admm: soa.ADMM): if admm.f is None: raise ValueError("FBlockCircularConvolveSolver does not allow f to be None.") else: if not isinstance(admm.f, SquaredL2Loss): raise TypeError( "FBlockCircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." 
) if not isinstance(admm.f.A, ComposedLinearOperator): raise TypeError( "FBlockCircularConvolveSolver requires f.A to be a composition of Sum " f"and CircularConvolve linear operators; got {type(admm.f.A)}." ) SubproblemSolver.internal_init(self, admm) assert isinstance(self.admm.f, SquaredL2Loss) assert isinstance(self.admm.f.A, ComposedLinearOperator) # All of the C operators are assumed to be linear and shift invariant # but this is not checked. c_gram_list = [ rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list, admm.C_list) ] D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale) self.solver = ConvATADSolver(self.admm.f.A, D) def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. Args: x0: Initial value (unused, has no effect). Returns: Computed solution. """ assert isinstance(self.admm.f, SquaredL2Loss) rhs = self.compute_rhs() / (2.0 * self.admm.f.scale) x = self.solver.solve(rhs) if self.check_solve: self.accuracy = self.solver.accuracy(x, rhs) return x class G0BlockCircularConvolveSolver(SubproblemSolver): r"""Solver for linear operators block-diagonalized in the DFT domain. Specialization of :class:`.LinearSubproblemSolver` for the case where :math:`f = 0` (i.e, :code:`f` is a :class:`.ZeroFunctional`), :math:`g_1` is an instance of :class:`.SquaredL2Loss`, :math:`C_1` is a composition of a :class:`.Sum` operator an a :class:`.CircularConvolve` operator. The former must sum over the first axis of its input, and the latter must be initialized so that it convolves a set of filters, indexed by the first axis, with an input array that has the same number of axes as the filter array, and has an initial axis of the same length as that of the filter array. 
The other :math:`C_i` must all be shift invariant linear operators, examples of which include instances of :class:`.Identity` as well as some instances (depending on initializer parameters) of :class:`.CircularConvolve` and :class:`.FiniteDifference`. None of these instances of :class:`.CircularConvolve` may be summed over any of their axes. The solver is based on the frequency-domain approach proposed in :cite:`wohlberg-2014-efficient`. We have :math:`g_1 = \omega \norm{B A \mb{x} - \mb{y}}_2^2`, where typically :math:`\omega = 1/2`, :math:`B` is the identity or a diagonal operator, and :math:`A` is a block-row operator with circulant blocks, i.e. it can be written as .. math:: A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K} \end{array} \right) \;, where all of the :math:`A_k` are circular convolution operators. The complete functional to be minimized is .. math:: \sum_{i=1}^N g_i(C_i \mb{x}) \;, where .. math:: g_1(\mb{z}) &= \omega \norm{B \mb{z} - \mb{y}}_2^2\\ C_1 &= A \;, and the other :math:`C_i` are either identity or circular convolutions. The ADMM x-step is .. math:: \mb{x}^{(j+1)} = \argmin_{\mb{x}} \; \rho_1 \omega \norm{ A \mb{x} - (\mb{z}^{(j)}_1 - \mb{u}^{(j)}_1)}_2^2 + \sum_{i=2}^N \frac{\rho_i}{2} \norm{C_i \mb{x} - (\mb{z}^{(j)}_i - \mb{u}^{(j)}_i)}_2^2 \;. This subproblem is most easily solved in the DFT transform domain, where the circular convolutions become diagonal operators. Denoting the frequency-domain versions of variables with a circumflex (e.g. :math:`\hat{\mb{x}}` is the frequency-domain version of :math:`\mb{x}`), the solution of the subproblem can be written as .. math:: \left( \hat{A}^H \hat{A} + \frac{1}{2 \omega \rho_1} \sum_{i=2}^N \rho_i \hat{C}_i^H \hat{C}_i \right) \hat{\mathbf{x}} = \hat{A}^H (\hat{\mb{z}}_1 - \hat{\mb{u}}_1) + \frac{1}{2 \omega \rho_1} \sum_{i=2}^N \rho_i \hat{C}_i^H (\hat{\mb{z}}_i - \hat{\mb{u}}_i) \;. 
This linear equation is computational expensive to solve because the left hand side includes the term :math:`\hat{A}^H \hat{A}`, which corresponds to the outer product of :math:`\hat{A}^H` and :math:`\hat{A}`. A computationally efficient solution is possible, however, by exploiting the Woodbury matrix identity .. math:: (D + U G V)^{-1} = D^{-1} - D^{-1} U (G^{-1} + V D^{-1} U)^{-1} V D^{-1} \;. Setting .. math:: D &= \frac{1}{2 \omega \rho_1} \sum_{i=2}^N \rho_i \hat{C}_i^H \hat{C}_i \\ U &= \hat{A}^H \\ G &= I \\ V &= \hat{A} we have .. math:: (D + \hat{A}^H \hat{A})^{-1} = D^{-1} - D^{-1} \hat{A}^H (I + \hat{A} D^{-1} \hat{A}^H)^{-1} \hat{A} D^{-1} which can be simplified to .. math:: (D + \hat{A}^H \hat{A})^{-1} = D^{-1} (I - \hat{A}^H E^{-1} \hat{A} D^{-1}) by defining :math:`E = I + \hat{A} D^{-1} \hat{A}^H`. The right hand side is much cheaper to compute because the only matrix inversions involve :math:`D`, which is diagonal, and :math:`E`, which is a weighted inner product of :math:`\hat{A}^H` and :math:`\hat{A}`. """ def __init__(self, ndims: Optional[int] = None, check_solve: bool = False): """Initialize a :class:`G0BlockCircularConvolveSolver` object. Args: check_solve: If ``True``, compute solver accuracy after each solve. """ self.ndims = ndims self.check_solve = check_solve self.accuracy: Optional[float] = None def internal_init(self, admm: soa.ADMM): if admm.f is not None and not isinstance(admm.f, ZeroFunctional): raise ValueError( "G0BlockCircularConvolveSolver requires f to be None or a ZeroFunctional" ) if not isinstance(admm.g_list[0], SquaredL2Loss): raise TypeError( "G0BlockCircularConvolveSolver requires g_1 to be a scico.loss.SquaredL2Loss; " f"got {type(admm.g_list[0])}." ) if not isinstance(admm.C_list[0], ComposedLinearOperator): raise TypeError( "G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum " f"and CircularConvolve linear operators; got {type(admm.C_list[0])}." 
) SubproblemSolver.internal_init(self, admm) assert isinstance(self.admm.g_list[0], SquaredL2Loss) assert isinstance(self.admm.C_list[0], ComposedLinearOperator) # All of the C operators are assumed to be linear and shift invariant # but this is not checked. c_gram_list = [ rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list[1:], admm.C_list[1:]) ] D = reduce(lambda a, b: a + b, c_gram_list) / ( 2.0 * self.admm.g_list[0].scale * admm.rho_list[0] ) self.solver = ConvATADSolver(self.admm.C_list[0], D) def compute_rhs(self) -> Union[Array, BlockArray]: r"""Compute the right hand side of the linear equation to be solved. Compute .. math:: C_1^H ( \mb{z}^{(k)}_1 - \mb{u}^{(k)}_1) + \frac{1}{2 \omega \rho_1}\sum_{i=2}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. Returns: Right hand side of the linear equation. """ assert isinstance(self.admm.g_list[0], SquaredL2Loss) C0 = self.admm.C_list[0] rhs = snp.zeros(C0.input_shape, C0.input_dtype) omega = self.admm.g_list[0].scale omega_list = [ 2.0 * omega, ] + [ 1.0, ] * (len(self.admm.C_list) - 1) for omegai, rhoi, Ci, zi, ui in zip( omega_list, self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list ): rhs += omegai * rhoi * Ci.adj(zi - ui) return rhs def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. Args: x0: Initial value (unused, has no effect). Returns: Computed solution. """ assert isinstance(self.admm.g_list[0], SquaredL2Loss) rhs = self.compute_rhs() / (2.0 * self.admm.g_list[0].scale * self.admm.rho_list[0]) x = self.solver.solve(rhs) if self.check_solve: self.accuracy = self.solver.accuracy(x, rhs) return x ================================================ FILE: scico/optimize/_common.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2023-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Functions common to multiple optimizer modules.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union import numpy as np from scico.diagnostics import IterationStats from scico.numpy import Array, BlockArray from scico.numpy.util import ( array_to_namedtuple, namedtuple_to_array, transpose_ntpl_of_list, ) from scico.util import Timer def itstat_func_and_object( itstat_fields: dict, itstat_attrib: List, itstat_options: Optional[dict] = None ) -> Tuple[Callable, IterationStats]: """Iteration statistics initialization. Iteration statistics initialization steps common to all optimizer classes. Args: itstat_fields: A dictionary associating field names with format strings for displaying the corresponding values. itstat_attrib: A list of expressions corresponding of optimizer class attributes to be evaluated when computing iteration statistics. itstat_options: A dict of named parameters to be passed to the :class:`.diagnostics.IterationStats` initializer. The dict may also include an additional key "itstat_func" with the corresponding value being a function with two parameters, an integer and a :class:`Optimizer` object, responsible for constructing a tuple ready for insertion into the :class:`.diagnostics.IterationStats` object. If ``None``, default values are used for the dict entries, otherwise the default dict is updated with the dict specified by this parameter. Returns: A tuple consisting of the statistics insertion function and the :class:`.diagnostics.IterationStats` object. """ # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." 
+ attr for attr in itstat_attrib]) + ")" scope: Dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object default_itstat_options: Dict[str, Union[dict, Callable, bool]] = { "fields": itstat_fields, "itstat_func": scope["itstat_func"], "display": False, } if itstat_options: default_itstat_options.update(itstat_options) itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore itstat_object = IterationStats(**default_itstat_options) # type: ignore return itstat_insert_func, itstat_object class Optimizer: """Base class for optimizer classes. Attributes: itnum (int): Optimizer iteration counter. maxiter (int): Maximum number of optimizer outer-loop iterations. timer (:class:`.Timer`): Iteration timer. """ def __init__(self, **kwargs: Any): """Initialize common attributes of :class:`Optimizer` objects. Args: **kwargs: Optional parameter dict. Valid keys are: iter0: Initial value of iteration counter. Default value is 0. maxiter: Maximum iterations on call to :meth:`solve`. Default value is 100. nanstop: If ``True``, stop iterations if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. Default value is ``False``. itstat_options: A dict of named parameters to be passed to the :class:`.diagnostics.IterationStats` initializer. The dict may also include an additional key "itstat_func" with the corresponding value being a function with two parameters, an integer and an :class:`Optimizer` object, responsible for constructing a tuple ready for insertion into the :class:`.diagnostics.IterationStats` object. If ``None``, default values are used for the dict entries, otherwise the default dict is updated with the dict specified by this parameter. 
""" iter0 = kwargs.pop("iter0", 0) self.maxiter: int = kwargs.pop("maxiter", 100) self.nanstop: bool = kwargs.pop("nanstop", False) itstat_options = kwargs.pop("itstat_options", None) if kwargs: raise TypeError(f"Unrecognized keyword argument(s) {', '.join([k for k in kwargs])}") self.itnum: int = iter0 self.timer: Timer = Timer() itstat_fields, itstat_attrib = self._itstat_default_fields() itstat_extra_fields, itstat_extra_attrib = self._itstat_extra_fields() itstat_fields.update(itstat_extra_fields) itstat_attrib.extend(itstat_extra_attrib) self.itstat_insert_func, self.itstat_object = itstat_func_and_object( itstat_fields, itstat_attrib, itstat_options ) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. """ raise NotImplementedError( "NaN check requested but _working_vars_finite not implemented." "" ) def _itstat_default_fields(self) -> Tuple[Dict[str, str], List[str]]: """Define iterations stats default fields. Return a dict mapping field names to format strings, and a list of strings containing the names of attributes or methods to call in order to determine the value for each field. """ # iteration number and time fields itstat_fields = { "Iter": "%d", "Time": "%8.2e", } itstat_attrib = ["itnum", "timer.elapsed()"] # objective function can be evaluated if 'g' function can be evaluated if self._objective_evaluatable(): itstat_fields.update({"Objective": "%9.3e"}) itstat_attrib.append("objective()") return itstat_fields, itstat_attrib def _objective_evaluatable(self) -> bool: """Determine whether the objective function can be evaluated. Determine whether the objective function for a :class:`Optimizer` object can be evaluated. """ return False def _itstat_extra_fields(self) -> Tuple[Dict[str, str], List[str]]: """Define additional iterations stats fields. 
Define iterations stats fields that are not common to all :class:`Optimizer` classes. Return a dict mapping field names to format strings, and a list of strings containing the names of attributes or methods to call in order to determine the value for each field. """ return {}, [] def _state_variable_names(self) -> List[str]: """Get optimizer state variable names. Get optimizer state variable names. Returns: List of names of class attributes that represent algorithm state variables. """ raise NotImplementedError(f"Method _state_variables is not implemented for {type(self)}.") def _get_state_variables(self) -> dict[str, Any]: """Get optimizer state variables. Get optimizer state variables. Returns: Dict of state variable names and corresponding values. """ return {k: getattr(self, k) for k in self._state_variable_names()} def _set_state_variables(self, **kwargs): """Set optimizer state variables. Set optimizer state variables. Args: **kwargs: State variables to be set, with parameter names corresponding to their class attribute names. """ valid_vars = self._state_variable_names() for k, v in kwargs.items(): if k not in valid_vars: raise RuntimeError(f"{k} is not a valid state variable for {type(self)}.") setattr(self, k, v) def save_state(self, path: str): """Save optimizer state to a file. Save optimizer state to a file. Args: path: Filename of file to which state should be saved. """ state_vars = self._get_state_variables() np.savez( path, opt_class=self.__class__, itnum=self.itnum, history=namedtuple_to_array(self.history(transpose=True)), # type: ignore **state_vars, ) def load_state(self, path: str): """Load optimizer state from a file. Restore optimizer state from a file. Args: path: Filename of state file saved using :meth:`save_state`. """ npz = np.load(path, allow_pickle=True) if npz["opt_class"] != self.__class__: raise TypeError( f"Cannot load state for {npz['solver_class']} into optimizer " f"of type {self.__class__}." 
) npzd = dict(npz) npzd.pop("opt_class") self.itnum = npzd.pop("itnum") history = transpose_ntpl_of_list(array_to_namedtuple(npzd.pop("history"))) self.itstat_object.iterations = history self._set_state_variables(**npzd) def history(self, transpose: bool = False) -> Union[List[NamedTuple], Tuple[List]]: """Retrieve record of algorithm iterations. Retrieve record of algorithm iterations. Args: transpose: Flag indicating whether results should be returned in "transposed" form, i.e. as a namedtuple of lists rather than a list of namedtuples. Returns: Record of all iterations. """ return self.itstat_object.history(transpose=transpose) def minimizer(self) -> Union[Array, BlockArray]: """Return the current estimate of the functional mimimizer. Returns: Current estimate of the functional mimimizer. """ raise NotImplementedError(f"Method minimizer is not implemented for {type(self)}.") def step(self): """Perform a single optimizer step.""" raise NotImplementedError(f"Method step is not implemented for {type(self)}.") def solve( self, callback: Optional[Callable[[Optimizer], None]] = None, ) -> Union[Array, BlockArray]: r"""Initialize and run the optimization algorithm. Initialize and run the opimization algorithm for a total of `self.maxiter` iterations. Args: callback: An optional callback function, taking an a single argument of type :class:`Optimizer`, that is called at the end of every iteration. Returns: Computed solution. """ self.timer.start() for self.itnum in range(self.itnum, self.itnum + self.maxiter): self.step() if self.nanstop and not self._working_vars_finite(): raise ValueError( f"NaN or Inf value encountered in working variable in iteration {self.itnum}." 
"" ) self.itstat_object.insert(self.itstat_insert_func(self)) if callback: self.timer.stop() callback(self) self.timer.start() self.timer.stop() self.itnum += 1 self.itstat_object.end() return self.minimizer() ================================================ FILE: scico/optimize/_ladmm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Linearized ADMM solver.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import List, Optional, Tuple, Union import scico.numpy as snp from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm from ._common import Optimizer class LinearizedADMM(Optimizer): r"""Linearized alternating direction method of multipliers algorithm. | Solve an optimization problem of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;, where :math:`f` and :math:`g` are instances of :class:`.Functional`, (in most cases :math:`f` will, more specifically be an an instance of :class:`.Loss`), and :math:`C` is an instance of :class:`.LinearOperator`. The optimization problem is solved by introducing the splitting :math:`\mb{z} = C \mb{x}` and solving .. math:: \argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; C \mb{x} = \mb{z} \;, via a linearized ADMM algorithm :cite:`yang-2012-linearized` :cite:`parikh-2014-proximal` (Sec. 4.4.2) consisting of the iterations (see :meth:`step`) .. 
math:: \begin{aligned} \mb{x}^{(k+1)} &= \mathrm{prox}_{\mu f} \left( \mb{x}^{(k)} - (\mu / \nu) C^T \left(C \mb{x}^{(k)} - \mb{z}^{(k)} + \mb{u}^{(k)} \right) \right) \\ \mb{z}^{(k+1)} &= \mathrm{prox}_{\nu g} \left(C \mb{x}^{(k+1)} + \mb{u}^{(k)} \right) \\ \mb{u}^{(k+1)} &= \mb{u}^{(k)} + C \mb{x}^{(k+1)} - \mb{z}^{(k+1)} \;. \end{aligned} Parameters :math:`\mu` and :math:`\nu` are required to satisfy .. math:: 0 < \mu < \nu \| C \|_2^{-2} \;. Attributes: f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`). g (:class:`.Functional`): Functional :math:`g`. C (:class:`.LinearOperator`): :math:`C` operator. mu (scalar): First algorithm parameter. nu (scalar): Second algorithm parameter. u (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at current iteration. x (array-like): Solution variable. z (array-like): Auxiliary variables :math:`\mb{z}` at current iteration. z_old (array-like): Auxiliary variables :math:`\mb{z}` at previous iteration. """ def __init__( self, f: Functional, g: Functional, C: LinearOperator, mu: float, nu: float, x0: Optional[Union[Array, BlockArray]] = None, **kwargs, ): r"""Initialize a :class:`LinearizedADMM` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. C: Operator :math:`C`. mu: First algorithm parameter. nu: Second algorithm parameter. x0: Starting point for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ self.f: Functional = f self.g: Functional = g self.C: LinearOperator = C self.mu: float = mu self.nu: float = nu if x0 is None: input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) self.x = x0 self.z, self.z_old = self.z_init(self.x) self.u = self.u_init(self.x) super().__init__(**kwargs) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. 
Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. """ return ( snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.z)) and snp.all(snp.isfinite(self.u)) ) def _objective_evaluatable(self): """Determine whether the objective function can be evaluated.""" return self.f.has_eval and self.g.has_eval def _itstat_extra_fields(self): """Define linearized ADMM-specific iteration statistics fields.""" itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"} itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"] return itstat_fields, itstat_attrib def _state_variable_names(self) -> List[str]: return ["x", "z", "z_old", "u"] def minimizer(self) -> Union[Array, BlockArray]: return self.x def objective( self, x: Optional[Union[Array, BlockArray]] = None, z: Optional[Union[Array, BlockArray]] = None, ) -> float: r"""Evaluate the objective function. Evaluate the objective function .. math:: f(\mb{x}) + g(\mb{z}) \;. Args: x: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.x`. z: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.z`. Returns: scalar: Value of the objective function. """ if (x is None) != (z is None): raise ValueError("Both or neither of arguments 'x' and 'z' must be supplied.") if x is None: x = self.x z = self.z return self.f(x) + self.g(z) def norm_primal_residual(self, x: Optional[Union[Array, BlockArray]] = None) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \norm{C \mb{x} - \mb{z}}_2 \;. Args: x: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.x`. Returns: Norm of primal residual. 
""" if x is None: x = self.x return norm(self.C(self.x) - self.z) def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual .. math:: \norm{\mb{z}^{(k)} - \mb{z}^{(k-1)}}_2 \;. Returns: Current norm of dual residual. """ return norm(self.C.adj(self.z - self.z_old)) def z_init( self, x0: Union[Array, BlockArray] ) -> Tuple[Union[Array, BlockArray], Union[Array, BlockArray]]: r"""Initialize auxiliary variable :math:`\mb{z}`. Initialized to .. math:: \mb{z} = C \mb{x}^{(0)} \;. :code:`z` and :code:`z_old` are initialized to the same value. Args: x0: Starting point for :math:`\mb{x}`. """ z = self.C(x0) z_old = z return z, z_old def u_init(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: r"""Initialize scaled Lagrange multiplier :math:`\mb{u}`. Initialized to .. math:: \mb{u} = \mb{0} \;. Note that the parameter `x0` is unused, but is provided for potential use in an overridden method. Args: x0: Starting point for :math:`\mb{x}`. """ u = snp.zeros(self.C.output_shape, dtype=self.C.output_dtype) return u def step(self): r"""Perform a single linearized ADMM iteration. The primary variable :math:`\mb{x}` is updated by computing .. math:: \mb{x}^{(k+1)} = \mathrm{prox}_{\mu f} \left( \mb{x}^{(k)} - (\mu / \nu) A^T \left(A \mb{x}^{(k)} - \mb{z}^{(k)} + \mb{u}^{(k)} \right) \right) \;. The auxiliary variable is updated according to .. math:: \mb{z}^{(k+1)} = \mathrm{prox}_{\nu g} \left(A \mb{x}^{(k+1)} + \mb{u}^{(k)} \right) \;, and the scaled Lagrange multiplier is updated according to .. math:: \mb{u}^{(k+1)} = \mb{u}^{(k)} + C \mb{x}^{(k+1)} - \mb{z}^{(k+1)} \;. 
""" proxarg = self.x - (self.mu / self.nu) * self.C.conj().T(self.C(self.x) - self.z + self.u) self.x = self.f.prox(proxarg, self.mu, v0=self.x) self.z_old = self.z Cx = self.C(self.x) self.z = self.g.prox(Cx + self.u, self.nu, v0=self.z) self.u = self.u + Cx - self.z ================================================ FILE: scico/optimize/_padmm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Proximal ADMM solvers.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import List, Optional, Tuple, Union import scico.numpy as snp from scico import cvjp, jvp from scico.function import Function from scico.functional import Functional from scico.linop import Identity, LinearOperator, operator_norm from scico.numpy import Array, BlockArray from scico.numpy.linalg import norm from scico.typing import BlockShape, DType, PRNGKey, Shape from ._common import Optimizer # mypy: disable-error-code=override class ProximalADMMBase(Optimizer): r"""Base class for proximal ADMM solvers. Attributes: f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`). g (:class:`.Functional`): Functional :math:`g`. rho (scalar): Penalty parameter. mu (scalar): First algorithm parameter. nu (scalar): Second algorithm parameter. x (array-like): Solution variable. z (array-like): Auxiliary variables :math:`\mb{z}` at current iteration. z_old (array-like): Auxiliary variables :math:`\mb{z}` at previous iteration. u (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at current iteration. u_old (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at previous iteration. 
""" def __init__( self, f: Functional, g: Functional, rho: float, mu: float, nu: float, xshape: Union[Shape, BlockShape], zshape: Union[Shape, BlockShape], ushape: Union[Shape, BlockShape], xdtype: DType, zdtype: DType, udtype: DType, x0: Optional[Union[Array, BlockArray]] = None, z0: Optional[Union[Array, BlockArray]] = None, u0: Optional[Union[Array, BlockArray]] = None, fast_dual_residual: bool = True, **kwargs, ): r"""Initialize a :class:`ProximalADMMBase` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. rho: Penalty parameter. mu: First algorithm parameter. nu: Second algorithm parameter. xshape: Shape of variable :math:`\mb{x}`. zshape: Shape of variable :math:`\mb{z}`. ushape: Shape of variable :math:`\mb{u}`. xdtype: Dtype of variable :math:`\mb{x}`. zdtype: Dtype of variable :math:`\mb{z}`. udtype: Dtype of variable :math:`\mb{u}`. x0: Initial value for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. z0: Initial value for :math:`\mb{z}`. If ``None``, defaults to an array of zeros. u0: Initial value for :math:`\mb{u}`. If ``None``, defaults to an array of zeros. fast_dual_residual: Flag indicating whether to use fast approximation to the dual residual, or a slower but more accurate calculation. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ self.f: Functional = f self.g: Functional = g self.rho: float = rho self.mu: float = mu self.nu: float = nu self.fast_dual_residual: bool = fast_dual_residual if x0 is None: x0 = snp.zeros(xshape, dtype=xdtype) self.x = x0 if z0 is None: z0 = snp.zeros(zshape, dtype=zdtype) self.z = z0 self.z_old = self.z if u0 is None: u0 = snp.zeros(ushape, dtype=udtype) self.u = u0 self.u_old = self.u super().__init__(**kwargs) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. 
""" return ( snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.z)) and snp.all(snp.isfinite(self.u)) ) def _objective_evaluatable(self): """Determine whether the objective function can be evaluated.""" return self.f.has_eval and self.g.has_eval def _itstat_extra_fields(self): """Define linearized ADMM-specific iteration statistics fields.""" itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"} itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"] return itstat_fields, itstat_attrib def _state_variable_names(self) -> List[str]: return ["x", "z", "z_old", "u", "u_old"] def minimizer(self) -> Union[Array, BlockArray]: return self.x def objective( self, x: Optional[Union[Array, BlockArray]] = None, z: Optional[Union[Array, BlockArray]] = None, ) -> float: r"""Evaluate the objective function. Evaluate the objective function .. math:: f(\mb{x}) + g(\mb{z}) \;. Args: x: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.x`. z: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.z`. Returns: scalar: Current value of the objective function. """ if (x is None) != (z is None): raise ValueError("Both or neither of arguments 'x' and 'z' must be supplied") if x is None: x = self.x z = self.z return self.f(x) + self.g(z) class ProximalADMM(ProximalADMMBase): r"""Proximal alternating direction method of multipliers. | Solve an optimization problem of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; A \mb{x} + B \mb{z} = \mb{c} \;, where :math:`f` and :math:`g` are instances of :class:`.Functional`, (in most cases :math:`f` will, more specifically be an instance of :class:`.Loss`), and :math:`A` and :math:`B` are instances of :class:`LinearOperator`. 
The optimization problem is solved via a variant of the proximal ADMM algorithm :cite:`deng-2015-global`, consisting of the iterations (see :meth:`step`) .. math:: \begin{aligned} \mb{x}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \mu^{-1} f} \left( \mb{x}^{(k)} - \mu^{-1} A^T \left(2 \mb{u}^{(k)} - \mb{u}^{(k-1)} \right) \right) \\ \mb{z}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \nu^{-1} g} \left( \mb{z}^{(k)} - \nu^{-1} B^T \left( A \mb{x}^{(k+1)} + B \mb{z}^{(k)} - \mb{c} + \mb{u}^{(k)} \right) \right) \\ \mb{u}^{(k+1)} &= \mb{u}^{(k)} + A \mb{x}^{(k+1)} + B \mb{z}^{(k+1)} - \mb{c} \;. \end{aligned} Parameters :math:`\mu` and :math:`\nu` are required to satisfy .. math:: \mu > \norm{ A }_2^2 \quad \text{and} \quad \nu > \norm{ B }_2^2 \;. Attributes: A (:class:`.LinearOperator`): :math:`A` linear operator. B (:class:`.LinearOperator`): :math:`B` linear operator. c (array-like): constant :math:`\mb{c}`. """ def __init__( self, f: Functional, g: Functional, A: LinearOperator, rho: float, mu: float, nu: float, B: Optional[LinearOperator] = None, c: Optional[Union[float, Array, BlockArray]] = None, x0: Optional[Union[Array, BlockArray]] = None, z0: Optional[Union[Array, BlockArray]] = None, u0: Optional[Union[Array, BlockArray]] = None, fast_dual_residual: bool = True, **kwargs, ): r"""Initialize a :class:`ProximalADMM` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. A: Linear operator :math:`A`. rho: Penalty parameter. mu: First algorithm parameter. nu: Second algorithm parameter. B: Linear operator :math:`B` (if ``None``, :math:`B = -I` where :math:`I` is the identity operator). c: Constant :math:`\mb{c}`. If ``None``, defaults to zero. x0: Starting value for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. z0: Starting value for :math:`\mb{z}`. If ``None``, defaults to an array of zeros. u0: Starting value for :math:`\mb{u}`. If ``None``, defaults to an array of zeros. 
fast_dual_residual: Flag indicating whether to use fast approximation to the dual residual, or a slower but more accurate calculation. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ self.A: LinearOperator = A if B is None: self.B = -Identity(self.A.output_shape, self.A.output_dtype) else: self.B = B if c is None: self.c = 0.0 else: self.c = c super().__init__( f, g, rho, mu, nu, self.A.input_shape, self.B.input_shape, self.A.output_shape, self.A.input_dtype, self.B.input_dtype, self.A.output_dtype, x0=x0, z0=z0, u0=u0, fast_dual_residual=fast_dual_residual, **kwargs, ) def norm_primal_residual( self, x: Optional[Union[Array, BlockArray]] = None, z: Optional[Union[Array, BlockArray]] = None, ) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \norm{A \mb{x} + B \mb{z} - \mb{c}}_2 \;. Args: x: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.x`. z: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.z`. Returns: Norm of primal residual. """ if (x is None) != (z is None): raise ValueError("Both or neither of arguments 'x' and 'z' must be supplied") if x is None: x = self.x z = self.z return norm(self.A(x) + self.B(z) - self.c) def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual. If the flag requesting a fast approximate calculation is set, it is computed as .. math:: \norm{\mb{z}^{(k+1)} - \mb{z}^{(k)}}_2 \;, otherwise it is computed as .. math:: \norm{A^T B ( \mb{z}^{(k+1)} - \mb{z}^{(k)} ) }_2 \;. Returns: Current norm of dual residual. 
""" if self.fast_dual_residual: rsdl = self.z - self.z_old # fast but poor approximation else: rsdl = self.A.H(self.B(self.z - self.z_old)) return norm(rsdl) def step(self): r"""Perform a single algorithm iteration. Perform a single algorithm iteration. """ proxarg = self.x - (1.0 / self.mu) * self.A.H(2.0 * self.u - self.u_old) self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x) proxarg = self.z - (1.0 / self.nu) * self.B.H( self.A(self.x) + self.B(self.z) - self.c + self.u ) self.z_old = self.z self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z) self.u_old = self.u self.u = self.u + self.A(self.x) + self.B(self.z) - self.c @staticmethod def estimate_parameters( A: LinearOperator, B: Optional[LinearOperator] = None, factor: Optional[float] = 1.01, maxiter: int = 100, key: Optional[PRNGKey] = None, ) -> Tuple[float, float]: r"""Estimate `mu` and `nu` parameters of :class:`ProximalADMM`. Find values of the `mu` and `nu` parameters of :class:`ProximalADMM` that respect the constraints .. math:: \mu > \norm{ A }_2^2 \quad \text{and} \quad \nu > \norm{ B }_2^2 \;. Args: A: Linear operator :math:`A`. B: Linear operator :math:`B` (if ``None``, :math:`B = -I` where :math:`I` is the identity operator). factor: Safety factor with which to multiply estimated operator norms to ensure strict inequality compliance. If ``None``, return the estimated squared operator norms. maxiter: Maximum number of power iterations to use in operator norm estimation (see :func:`.operator_norm`). Default: 100. key: Jax PRNG key to use in operator norm estimation (see :func:`.operator_norm`). Defaults to ``None``, in which case a new key is created. Returns: A tuple (`mu`, `nu`) representing the estimated parameter values or corresponding squared operator norm values, depending on the value of the `factor` parameter. 
""" mu = operator_norm(A, maxiter=maxiter, key=key) ** 2 if B is None: nu = 1.0 else: nu = operator_norm(B, maxiter=maxiter, key=key) ** 2 if factor is None: return (mu, nu) else: return (factor * mu, factor * nu) class NonLinearPADMM(ProximalADMMBase): r"""Non-linear proximal alternating direction method of multipliers. | Solve an optimization problem of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; H(\mb{x}, \mb{z}) = 0 \;, where :math:`f` and :math:`g` are instances of :class:`.Functional`, (in most cases :math:`f` will, more specifically be an instance of :class:`.Loss`), and :math:`H` is a function. The optimization problem is solved via a variant of the proximal ADMM algorithm for problems with a non-linear operator constraint :cite:`benning-2016-preconditioned`, consisting of the iterations (see :meth:`step`) .. math:: \begin{aligned} A^{(k)} &= J_{\mb{x}} H(\mb{x}^{(k)}, \mb{z}^{(k)}) \\ \mb{x}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \mu^{-1} f} \left( \mb{x}^{(k)} - \mu^{-1} (A^{(k)})^T \left(2 \mb{u}^{(k)} - \mb{u}^{(k-1)} \right) \right) \\ B^{(k)} &= J_{\mb{z}} H(\mb{x}^{(k+1)}, \mb{z}^{(k)}) \\ \mb{z}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \nu^{-1} g} \left( \mb{z}^{(k)} - \nu^{-1} (B^{(k)})^T \left( H(\mb{x}^{(k+1)}, \mb{z}^{(k)}) + \mb{u}^{(k)} \right) \right) \\ \mb{u}^{(k+1)} &= \mb{u}^{(k)} + H(\mb{x}^{(k+1)}, \mb{z}^{(k+1)}) \;. \end{aligned} Parameters :math:`\mu` and :math:`\nu` are required to satisfy .. math:: \mu > \norm{ A^{(k)} }_2^2 \quad \text{and} \quad \nu > \norm{ B^{(k)} }_2^2 for all :math:`A^{(k)}` and :math:`B^{(k)}`. Attributes: H (:class:`.Function`): :math:`H` function. 
""" def __init__( self, f: Functional, g: Functional, H: Function, rho: float, mu: float, nu: float, x0: Optional[Union[Array, BlockArray]] = None, z0: Optional[Union[Array, BlockArray]] = None, u0: Optional[Union[Array, BlockArray]] = None, fast_dual_residual: bool = True, **kwargs, ): r"""Initialize a :class:`NonLinearPADMM` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. H: Function :math:`H`. rho: Penalty parameter. mu: First algorithm parameter. nu: Second algorithm parameter. x0: Starting value for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. z0: Starting value for :math:`\mb{z}`. If ``None``, defaults to an array of zeros. u0: Starting value for :math:`\mb{u}`. If ``None``, defaults to an array of zeros. fast_dual_residual: Flag indicating whether to use fast approximation to the dual residual, or a slower but more accurate calculation. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ self.H: Function = H super().__init__( f, g, rho, mu, nu, H.input_shapes[0], H.input_shapes[1], H.output_shape, H.input_dtypes[0], H.input_dtypes[1], H.output_dtype, x0=x0, z0=z0, u0=u0, fast_dual_residual=fast_dual_residual, **kwargs, ) def norm_primal_residual( self, x: Optional[Union[Array, BlockArray]] = None, z: Optional[Union[Array, BlockArray]] = None, ) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \norm{H(\mb{x}, \mb{z})}_2 \;. Args: x: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.x`. z: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.z`. Returns: Norm of primal residual. 
""" if (x is None) != (z is None): raise ValueError("Both or neither of arguments 'x' and 'z' must be supplied") if x is None: x = self.x z = self.z return norm(self.H(x, z)) def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual. If the flag requesting a fast approximate calculation is set, it is computed as .. math:: \norm{\mb{z}^{(k+1)} - \mb{z}^{(k)}}_2 \;, otherwise it is computed as .. math:: \norm{A^T B ( \mb{z}^{(k+1)} - \mb{z}^{(k)} ) }_2 \;, where .. math:: A &= J_{\mb{x}} H(\mb{x}^{(k+1)}, \mb{z}^{(k+1)}) \\ B &= J_{\mb{z}} H(\mb{x}^{(k+1)}, \mb{z}^{(k+1)}) \;. Returns: Current norm of dual residual. """ if self.fast_dual_residual: rsdl = self.z - self.z_old # fast but poor approximation else: Hz = lambda z: self.H(self.x, z) B = lambda u: jvp(Hz, (self.z,), (u,))[1] Hx = lambda x: self.H(x, self.z) AH = cvjp(Hx, self.x)[1] rsdl = AH(B(self.z - self.z_old)) return norm(rsdl) def step(self): r"""Perform a single algorithm iteration. Perform a single algorithm iteration. """ AH = self.H.vjp(0, self.x, self.z, conjugate=True)[1] proxarg = self.x - (1.0 / self.mu) * AH(2.0 * self.u - self.u_old) self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x) BH = self.H.vjp(1, self.x, self.z, conjugate=True)[1] proxarg = self.z - (1.0 / self.nu) * BH(self.H(self.x, self.z) + self.u) self.z_old = self.z self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z) self.u_old = self.u self.u = self.u + self.H(self.x, self.z) @staticmethod def estimate_parameters( H: Function, x: Optional[Union[Array, BlockArray]] = None, z: Optional[Union[Array, BlockArray]] = None, factor: Optional[float] = 1.01, maxiter: int = 100, key: Optional[PRNGKey] = None, ) -> Tuple[float, float]: r"""Estimate `mu` and `nu` parameters of :class:`NonLinearPADMM`. Find values of the `mu` and `nu` parameters of :class:`NonLinearPADMM` that respect the constraints .. 
math:: \mu > \norm{ J_x H(\mb{x}, \mb{z}) }_2^2 \quad \text{and} \quad \nu > \norm{ J_z H(\mb{x}, \mb{z}) }_2^2 \;. Args: H: Constraint function :math:`H`. x: Value of :math:`\mb{x}` at which to evaluate the Jacobian. If ``None``, defaults to an array of zeros. z: Value of :math:`\mb{z}` at which to evaluate the Jacobian. If ``None``, defaults to an array of zeros. factor: Safety factor with which to multiply estimated operator norms to ensure strict inequality compliance. If ``None``, return the estimated squared operator norms. maxiter: Maximum number of power iterations to use in operator norm estimation (see :func:`.operator_norm`). Default: 100. key: Jax PRNG key to use in operator norm estimation (see :func:`.operator_norm`). Defaults to ``None``, in which case a new key is created. Returns: A tuple (`mu`, `nu`) representing the estimated parameter values or corresponding squared operator norm values, depending on the value of the `factor` parameter. """ if x is None: x = snp.zeros(H.input_shapes[0], dtype=H.input_dtypes[0]) if z is None: z = snp.zeros(H.input_shapes[1], dtype=H.input_dtypes[1]) Jx = H.jacobian(0, x, z) Jz = H.jacobian(1, x, z) mu = operator_norm(Jx, maxiter=maxiter, key=key) ** 2 nu = operator_norm(Jz, maxiter=maxiter, key=key) ** 2 if factor is None: return (mu, nu) else: return (factor * mu, factor * nu) ================================================ FILE: scico/optimize/_pgm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""Proximal Gradient Method classes.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from functools import partial from typing import List, Optional, Union import jax import scico.numpy as snp from scico.functional import Functional from scico.loss import Loss from scico.numpy import Array, BlockArray from ._common import Optimizer from ._pgmaux import ( AdaptiveBBStepSize, BBStepSize, PGMStepSize, RobustLineSearchStepSize, ) class PGM(Optimizer): r"""Proximal gradient method (PGM) algorithm. Minimize a functional of the form :math:`f(\mb{x}) + g(\mb{x})`, where :math:`f` and the :math:`g` are instances of :class:`.Functional`. Functional :math:`f` should be differentiable and have a Lipschitz continuous derivative, and functional :math:`g` should have a proximal operator defined. The step size :math:`\alpha` of the algorithm is defined in terms of its reciprocal :math:`L`, i.e. :math:`\alpha = 1 / L`. The initial value for this parameter, `L0`, is required to satisfy .. math:: L_0 \geq K(\nabla f) \;, where :math:`K(\nabla f)` denotes the Lipschitz constant of the gradient of :math:`f`. When `f` is an instance of :class:`.SquaredL2Loss` with a :class:`.LinearOperator` `A`, .. math:: K(\nabla f) = \lambda_{ \mathrm{max} }( A^H A ) = \| A \|_2^2 \;, where :math:`\lambda_{\mathrm{max}}(B)` denotes the largest eigenvalue of :math:`B`. The evolution of the step size is controlled by auxiliary class :class:`.PGMStepSize` and derived classes. The default :class:`.PGMStepSize` simply sets :math:`L = L_0`, while the derived classes implement a variety of adaptive strategies. """ def __init__( self, f: Union[Loss, Functional], g: Functional, L0: float, x0: Union[Array, BlockArray], step_size: Optional[PGMStepSize] = None, **kwargs, ): r""" Args: f: Instance of :class:`.Loss` or :class:`.Functional` with defined `grad` method. 
g: Instance of :class:`.Functional` with defined prox method. L0: Initial estimate of Lipschitz constant of gradient of `f`. x0: Starting point for :math:`\mb{x}`. step_size: Instance of an auxiliary class of type :class:`.PGMStepSize` determining the evolution of the algorithm step size. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ #: Functional or Loss to minimize; must have grad method defined. self.f: Union[Loss, Functional] = f if g.has_prox is not True: raise ValueError(f"Functional 'g' ({type(g)}) must have a prox method.") #: Functional to minimize; must have prox defined self.g: Functional = g if step_size is None: step_size = PGMStepSize() self.step_size: PGMStepSize = step_size self.step_size.internal_init(self) self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of ∇f) self.fixed_point_residual = snp.inf self.x: Union[Array, BlockArray] = x0 # current estimate of solution super().__init__(**kwargs) def x_step(self, v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: """Compute update for variable `x`.""" return PGM._x_step(self.f, self.g, v, L) @staticmethod @partial(jax.jit, static_argnums=(0, 1)) def _x_step( f: Functional, g: Functional, v: Union[Array, BlockArray], L: float ) -> Union[Array, BlockArray]: """Jit-able static method for computing update for variable `x`.""" return g.prox(v - 1.0 / L * f.grad(v), 1.0 / L) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. 
""" return snp.all(snp.isfinite(self.x)) def _objective_evaluatable(self): """Determine whether the objective function can be evaluated.""" return self.f.has_eval and self.g.has_eval def _itstat_extra_fields(self): """Define linearized ADMM-specific iteration statistics fields.""" itstat_fields = {"L": "%9.3e", "Residual": "%9.3e"} itstat_attrib = ["L", "norm_residual()"] return itstat_fields, itstat_attrib def _state_variable_names(self) -> List[str]: return ["x", "L"] def minimizer(self) -> Union[Array, BlockArray]: return self.x def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float: r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`.""" if x is None: x = self.x return self.f(x) + self.g(x) def f_quad_approx( self, x: Union[Array, BlockArray], y: Union[Array, BlockArray], L: float ) -> float: r"""Evaluate the quadratic approximation to function :math:`f`. Evaluate the quadratic approximation to function :math:`f`, corresponding to :math:`\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\|\mb{x} - \mb{y}\right\|_2^2`. """ diff_xy = x - y return ( self.f(y) + snp.sum(snp.real(snp.conj(self.f.grad(y)) * diff_xy)) + 0.5 * L * snp.linalg.norm(diff_xy) ** 2 ) def norm_residual(self) -> float: r"""Return the fixed point residual. Return the fixed point residual (see Sec. 4.3 of :cite:`liu-2018-first`). """ return self.fixed_point_residual def step(self): """Take a single PGM step.""" # Update reciprocal of step size using current solution. self.L = self.step_size.update(self.x) x = self.x_step(self.x, self.L) self.fixed_point_residual = snp.linalg.norm(self.x - x) self.x = x class AcceleratedPGM(PGM): r"""Accelerated proximal gradient method (APGM) algorithm. Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`, where :math:`f` and the :math:`g` are instances of :class:`.Functional`. The accelerated form of PGM is also known as FISTA :cite:`beck-2009-fast`. 
See :class:`.PGM` for more detailed documentation. """ def __init__( self, f: Union[Loss, Functional], g: Functional, L0: float, x0: Union[Array, BlockArray], step_size: Optional[PGMStepSize] = None, **kwargs, ): r""" Args: f: Instance of :class:`.Loss` or :class:`.Functional` with defined `grad` method. g: Instance of :class:`.Functional` with defined prox method. L0: Initial estimate of Lipschitz constant of gradient of `f`. x0: Starting point for :math:`\mb{x}`. step_size: Instance of an auxiliary class of type :class:`.PGMStepSize` determining the evolution of the algorithm step size. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ super().__init__(f=f, g=g, L0=L0, x0=x0, step_size=step_size, **kwargs) self.v = x0 self.t = 1.0 def step(self): """Take a single AcceleratedPGM step.""" x_old = self.x # Update reciprocal of step size using current extrapolation. if isinstance(self.step_size, (AdaptiveBBStepSize, BBStepSize)): self.L = self.step_size.update(self.x) else: self.L = self.step_size.update(self.v) if isinstance(self.step_size, RobustLineSearchStepSize): # Robust line search step size uses a different extrapolation sequence. # Update in solution is computed while updating the reciprocal of step size. self.x = self.step_size.Z self.fixed_point_residual = snp.linalg.norm(self.x - x_old) else: self.x = self.x_step(self.v, self.L) self.fixed_point_residual = snp.linalg.norm(self.x - self.v) t_old = self.t self.t = 0.5 * (1 + snp.sqrt(1 + 4 * t_old**2)) self.v = self.x + ((t_old - 1) / self.t) * (self.x - x_old) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. 
""" return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.v)) def _state_variable_names(self) -> List[str]: """Get optimizer state variable names.""" return ["x", "v", "t", "L"] ================================================ FILE: scico/optimize/_pgmaux.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Proximal Gradient Method auxiliary classes.""" # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations from typing import Optional, Union import jax import scico.numpy as snp import scico.optimize.pgm as sop from scico.numpy import Array, BlockArray class PGMStepSize: r"""Base class for computing the PGM step size. Base class for computing the reciprocal of the step size for PGM solvers. The PGM solver implemented by :class:`.PGM` addresses a general proximal gradient form that requires the specification of a step size for the gradient descent step. This class is a base class for methods that estimate the reciprocal of the step size (:math:`L` in PGM equations). Attributes: pgm (:class:`.PGM`): PGM solver object to which the solver is attached. """ def internal_init(self, pgm: sop.PGM): """Second stage initializer to be called by :meth:`.PGM.__init__`. Args: pgm: Reference to :class:`.PGM` object to which the :class:`.StepSize` object is to be attached. """ self.pgm = pgm def update(self, v: Union[Array, BlockArray]) -> float: """Hook for updating the step size in derived classes. Hook for updating the reciprocal of the step size in derived classes. The base class does not compute any update. Args: v: Current solution or current extrapolation (if accelerated PGM). 
Returns: Current reciprocal of the step size. """ return self.pgm.L class BBStepSize(PGMStepSize): r"""Scheme for step size estimation based on Barzilai-Borwein method. The Barzilai-Borwein method :cite:`barzilai-1988-stepsize` estimates the step size :math:`\alpha` as .. math:: \mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha = \frac{\mb{\Delta x}^T \mb{\Delta g}}{\mb{\Delta g}^T \mb{\Delta g}} \;\;. Since the PGM solver uses the reciprocal of the step size, the value :math:`L = 1 / \alpha` is returned. When applied to complex-valued problems, only the real part of the inner product is used. When the inner product is negative, the previous iterate is used instead. Attributes: pgm (:class:`.PGM`): PGM solver object to which the solver is attached. """ def __init__(self): """Initialize a :class:`BBStepSize` object.""" self.xprev = None self.gradprev = None def update(self, v: Union[Array, BlockArray]) -> float: """Update the reciprocal of the step size. Args: v: Current solution or current extrapolation (if accelerated PGM). Returns: Updated reciprocal of the step size. """ if self.xprev is None: # Solution and gradient of previous iterate are required. # For first iteration these variables are stored and current estimate is returned. self.xprev = v self.gradprev = self.pgm.f.grad(self.xprev) L = self.pgm.L else: Δx = v - self.xprev gradv = self.pgm.f.grad(v) Δg = gradv - self.gradprev # Taking real part of inner products in case of complex-value problem. den = snp.real(snp.sum(Δx.conj() * Δg)) num = snp.real(snp.sum(Δg.conj() * Δg)) L = num / den # Revert to previous iterate if update results in nan or negative value. if snp.isnan(L) or L <= 0.0: L = self.pgm.L # Store current state and gradient for next update. self.xprev = v self.gradprev = gradv return L class AdaptiveBBStepSize(PGMStepSize): r"""Adaptive Barzilai-Borwein method to determine step size. 
Adaptive Barzilai-Borwein method to determine step size in PGM, as introduced in :cite:`zhou-2006-adaptive`. The adaptive step size rule computes .. math:: \mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha^{\mathrm{BB1}} = \frac{\mb{\Delta x}^T \mb{\Delta x}} {\mb{\Delta x}^T \mb{\Delta g}} \; \\ \alpha^{\mathrm{BB2}} = \frac{\mb{\Delta x}^T \mb{\Delta g}} {\mb{\Delta g}^T \mb{\Delta g}} \;\;. The determination of the new steps size is made via the rule .. math:: \alpha = \left\{ \begin{matrix} \alpha^{\mathrm{BB2}} & \mathrm{~if~} \alpha^{\mathrm{BB2}} / \alpha^{\mathrm{BB1}} < \kappa \; \\ \alpha^{\mathrm{BB1}} & \mathrm{~otherwise} \end{matrix} \right . \;, with :math:`\kappa \in (0, 1)`. Since the PGM solver uses the reciprocal of the step size, the value :math:`L = 1 / \alpha` is returned. When applied to complex-valued problems, only the real part of the inner product is used. When the inner product is negative, the previous iterate is used instead. Attributes: pgm (:class:`.PGM`): PGM solver object to which the solver is attached. """ def __init__(self, kappa: float = 0.5): r"""Initialize a :class:`AdaptiveBBStepSize` object. Args: kappa : Threshold for step size selection :math:`\kappa`. """ self.kappa: float = kappa self.xprev: Union[Array, BlockArray] = None self.gradprev: Union[Array, BlockArray] = None self.Lbb1prev: Optional[float] = None self.Lbb2prev: Optional[float] = None def update(self, v: Union[Array, BlockArray]) -> float: """Update the reciprocal of the step size. Args: v: Current solution or current extrapolation (if accelerated PGM). Returns: Updated reciprocal of the step size. """ if self.xprev is None: # Solution and gradient of previous iterate are required. # For first iteration these variables are stored and current estimate is returned. 
self.xprev = v self.gradprev = self.pgm.f.grad(self.xprev) L = self.pgm.L else: Δx = v - self.xprev gradv = self.pgm.f.grad(v) Δg = gradv - self.gradprev # Taking real part of inner products in case of complex-value problem. innerxx = snp.real(snp.sum(Δx.conj() * Δx)) innerxg = snp.real(snp.sum(Δx.conj() * Δg)) innergg = snp.real(snp.sum(Δg.conj() * Δg)) Lbb1 = innerxg / innerxx # Revert to previous iterate if computation results in nan or negative value. if snp.isnan(Lbb1) or Lbb1 <= 0.0: Lbb1 = self.Lbb1prev Lbb2 = innergg / innerxg # Revert to previous iterate if computation results in nan or negative value. if snp.isnan(Lbb2) or Lbb2 <= 0.0: Lbb2 = self.Lbb2prev # If possible, apply adaptive selection rule, if not, revert to previous iterate if Lbb1 is not None and Lbb2 is not None: if (Lbb1 / Lbb2) < self.kappa: L = Lbb2 else: L = Lbb1 else: L = self.pgm.L # Store current state and gradient for next update. self.xprev = v self.gradprev = gradv # Store current estimates of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2). self.Lbb1prev = Lbb1 self.Lbb2prev = Lbb2 return L class LineSearchStepSize(PGMStepSize): r"""Line search for estimating the step size for PGM solvers. Line search for estimating the reciprocal of step size for PGM solvers. The line search strategy described in :cite:`beck-2009-fast` estimates :math:`L` such that :math:`f(\mb{x}) <= \hat{f}_{L}(\mb{x})` is satisfied with :math:`\hat{f}_{L}` a quadratic approximation to :math:`f` defined as .. math:: \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;, with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the current solution or current extrapolation (if accelerated PGM). Attributes: pgm (:class:`.PGM`): PGM solver object to which the solver is attached. """ def __init__(self, gamma_u: float = 1.2, maxiter: int = 50): r"""Initialize a :class:`LineSearchStepSize` object. 
Args:
            gamma_u: Rate of increment in :math:`L`.
            maxiter: Maximum iterations in line search.
        """
        self.gamma_u: float = gamma_u
        self.maxiter: int = maxiter

        # Proximal gradient step: a gradient step of size 1 / L from v,
        # followed by the proximal operator of g with scale 1 / L. The
        # function refers to self.pgm, so the PGM solver must be attached
        # before it is first called; it is compiled with jax.jit.
        def g_prox(v, gradv, L):
            return self.pgm.g.prox(v - 1.0 / L * gradv, 1.0 / L)

        self.g_prox = jax.jit(g_prox)

    def update(self, v: Union[Array, BlockArray]) -> float:
        """Update the reciprocal of the step size.

        Starting from the current value of ``self.pgm.L``, the estimate is
        multiplied by `gamma_u` until the condition
        ``f(z) <= f_quad_approx(z, v, L)`` holds for the proximal gradient
        update ``z``, or until `maxiter` iterations have been performed.

        Args:
            v: Current solution or current extrapolation (if accelerated
               PGM).

        Returns:
            Updated reciprocal of the step size.
        """
        # The gradient at v does not depend on L, so it is computed once.
        gradv = self.pgm.f.grad(v)
        L = self.pgm.L
        it = 0
        while it < self.maxiter:
            z = self.g_prox(v, gradv, L)
            fz = self.pgm.f(z)
            fquad = self.pgm.f_quad_approx(z, v, L)
            # Condition satisfied: accept the current trial value of L.
            if fz <= fquad:
                break
            else:
                L *= self.gamma_u
            it += 1
        # If the loop ends without the condition holding, the returned L
        # has been multiplied once more after the last tested value.
        return L


class RobustLineSearchStepSize(LineSearchStepSize):
    r"""Robust line search for estimating the accelerated PGM step size.

    A robust line search for estimating the reciprocal of step size for
    accelerated PGM solvers. The robust line search strategy described in
    :cite:`florea-2017-robust` estimates :math:`L` such that
    :math:`f(\mb{x}) \leq \hat{f}_{L}(\mb{x}, \mb{y})` is satisfied with
    :math:`\hat{f}_{L}` a quadratic approximation to :math:`f` defined as

    .. math::

       \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H
       (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y}
       \right\|_2^2 \;,

    with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the
    auxiliary extrapolation state.

    Attributes:
        pgm (:class:`.PGM`): PGM solver object to which the solver is
           attached.
    """

    def __init__(self, gamma_d: float = 0.9, gamma_u: float = 2.0, maxiter: int = 50):
        r"""Initialize a :class:`RobustLineSearchStepSize` object.

        Args:
            gamma_d: Rate of decrement in :math:`L`.
            gamma_u: Rate of increment in :math:`L`.
            maxiter: Maximum iterations in line search.
        """
        super(RobustLineSearchStepSize, self).__init__(gamma_u, maxiter)
        self.gamma_d: float = gamma_d
        # Running sum of the t values accepted in update().
        self.Tk: float = 0.0
        # State needed for computing auxiliary extrapolation sequence in robust line search.
        self.Zrb: Union[Array, BlockArray] = None
        #: Current estimate of solution in robust line search.
        self.Z: Union[Array, BlockArray] = None

    def update(self, v: Union[Array, BlockArray]) -> float:
        """Update the reciprocal of the step size.

        The argument `v` is not used by this implementation. The
        auxiliary extrapolation point is computed from ``self.pgm.x`` and
        the internally stored state.

        Args:
            v: Current solution or current extrapolation (if accelerated
               PGM).

        Returns:
            Updated reciprocal of the step size.
        """
        # On the first call, initialize the auxiliary sequence from the
        # current solver iterate.
        if self.Zrb is None:
            self.Zrb = self.pgm.x

        # Start the search from the solver's L scaled by gamma_d.
        L = self.pgm.L * self.gamma_d

        it = 0
        # NOTE(review): if maxiter == 0 the loop body never runs and T, t,
        # z and y below are unbound (NameError) -- confirm maxiter >= 1 is
        # required.
        while it < self.maxiter:
            t = (1.0 + snp.sqrt(1.0 + 4.0 * L * self.Tk)) / (2.0 * L)
            T = self.Tk + t
            # Auxiliary extrapolation sequence.
            y = (self.Tk * self.pgm.x + t * self.Zrb) / T
            # New update based on auxiliary extrapolation and current L estimate.
            z = self.pgm.x_step(y, L)
            fz = self.pgm.f(z)
            fquad = self.pgm.f_quad_approx(z, y, L)
            if fz <= fquad:
                break
            else:
                L *= self.gamma_u
            it += 1

        # Commit state for the next call.
        # NOTE(review): when the loop exhausts maxiter without breaking, L
        # here has been multiplied by gamma_u after t, T and z were
        # computed, so the Zrb update mixes values from two trial L's.
        self.Tk = T
        self.Zrb += t * L * (z - y)
        self.Z = z

        return L


================================================
FILE: scico/optimize/_primaldual.py
================================================
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Primal-dual solvers."""

# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import List, Optional, Union

import scico.numpy as snp
from scico.functional import Functional
from scico.linop import LinearOperator, jacobian, operator_norm
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm
from scico.operator import Operator
from scico.typing import PRNGKey

from ._common import Optimizer


class PDHG(Optimizer):
    r"""Primal–dual hybrid gradient (PDHG) algorithm.
| Primal–dual hybrid gradient (PDHG) is a family of algorithms :cite:`esser-2010-general` that includes the Chambolle-Pock primal-dual algorithm :cite:`chambolle-2010-firstorder`. The form implemented here is a minor variant :cite:`pock-2011-diagonal` of the original Chambolle-Pock algorithm. Solve an optimization problem of the form .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;, where :math:`f` and :math:`g` are instances of :class:`.Functional`, (in most cases :math:`f` will, more specifically be an an instance of :class:`.Loss`), and :math:`C` is an instance of :class:`.Operator` or :class:`.LinearOperator`. When `C` is a :class:`.LinearOperator`, the algorithm iterations are .. math:: \begin{aligned} \mb{x}^{(k+1)} &= \mathrm{prox}_{\tau f} \left( \mb{x}^{(k)} - \tau C^T \mb{z}^{(k)} \right) \\ \mb{z}^{(k+1)} &= \mathrm{prox}_{\sigma g^*} \left( \mb{z}^{(k)} + \sigma C((1 + \alpha) \mb{x}^{(k+1)} - \alpha \mb{x}^{(k)} \right) \;, \end{aligned} where :math:`g^*` denotes the convex conjugate of :math:`g`. Parameters :math:`\tau > 0` and :math:`\sigma > 0` are also required to satisfy .. math:: \tau \sigma < \| C \|_2^{-2} \;, and it is required that :math:`\alpha \in [0, 1]`. When `C` is a non-linear :class:`.Operator`, a non-linear PDHG variant :cite:`valkonen-2014-primal` is used, with the same iterations except for :math:`\mb{x}` update .. math:: \mb{x}^{(k+1)} = \mathrm{prox}_{\tau f} \left( \mb{x}^{(k)} - \tau [J_x C(\mb{x}^{(k)})]^T \mb{z}^{(k)} \right) \;. Attributes: f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`). g (:class:`.Functional`): Functional :math:`g`. C (:class:`.Operator`): :math:`C` operator. tau (scalar): First algorithm parameter. sigma (scalar): Second algorithm parameter. alpha (scalar): Relaxation parameter. x (array-like): Primal variable :math:`\mb{x}` at current iteration. x_old (array-like): Primal variable :math:`\mb{x}` at previous iteration. 
z (array-like): Dual variable :math:`\mb{z}` at current iteration. z_old (array-like): Dual variable :math:`\mb{z}` at previous iteration. """ def __init__( self, f: Functional, g: Functional, C: Operator, tau: float, sigma: float, alpha: float = 1.0, x0: Optional[Union[Array, BlockArray]] = None, z0: Optional[Union[Array, BlockArray]] = None, **kwargs, ): r"""Initialize a :class:`PDHG` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. C: Operator :math:`C`. tau: First algorithm parameter. sigma: Second algorithm parameter. alpha: Relaxation parameter. x0: Starting point for :math:`\mb{x}`. If ``None``, defaults to an array of zeros. z0: Starting point for :math:`\mb{z}`. If ``None``, defaults to an array of zeros. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ self.f: Functional = f self.g: Functional = g self.C: Operator = C self.tau: float = tau self.sigma: float = sigma self.alpha: float = alpha if x0 is None: input_shape = C.input_shape dtype = C.input_dtype x0 = snp.zeros(input_shape, dtype=dtype) self.x = x0 self.x_old = self.x if z0 is None: input_shape = C.output_shape dtype = C.output_dtype z0 = snp.zeros(input_shape, dtype=dtype) self.z = z0 self.z_old = self.z super().__init__(**kwargs) def _working_vars_finite(self) -> bool: """Determine where ``NaN`` of ``Inf`` encountered in solve. Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in a solver working variable. 
""" return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.z)) def _objective_evaluatable(self): """Determine whether the objective function can be evaluated.""" return self.f.has_eval and self.g.has_eval def _itstat_extra_fields(self): """Define linearized ADMM-specific iteration statistics fields.""" itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"} itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"] return itstat_fields, itstat_attrib def _state_variable_names(self) -> List[str]: return ["x", "x_old", "z", "z_old"] def minimizer(self) -> Union[Array, BlockArray]: return self.x def objective( self, x: Optional[Union[Array, BlockArray]] = None, ) -> float: r"""Evaluate the objective function. Evaluate the objective function .. math:: f(\mb{x}) + g(C \mb{x}) \;. Args: x: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.x` Returns: scalar: Value of the objective function. """ if x is None: x = self.x return self.f(x) + self.g(self.C(x)) def norm_primal_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \tau^{-1} \norm{\mb{x}^{(k)} - \mb{x}^{(k-1)}}_2 \;. Returns: Current norm of primal residual. """ return norm(self.x - self.x_old) / self.tau # type: ignore def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual .. math:: \sigma^{-1} \norm{\mb{z}^{(k)} - \mb{z}^{(k-1)}}_2 \;. Returns: Current norm of dual residual. 
""" return norm(self.z - self.z_old) / self.sigma def step(self): """Perform a single iteration.""" self.x_old = self.x self.z_old = self.z if isinstance(self.C, LinearOperator): proxarg = self.x - self.tau * self.C.conj().T(self.z) else: proxarg = self.x - self.tau * self.C.vjp(self.x, conjugate=True)[1](self.z) self.x = self.f.prox(proxarg, self.tau, v0=self.x) proxarg = self.z + self.sigma * self.C( (1.0 + self.alpha) * self.x - self.alpha * self.x_old ) self.z = self.g.conj_prox(proxarg, self.sigma, v0=self.z) @staticmethod def estimate_parameters( C: Operator, x: Optional[Union[Array, BlockArray]] = None, ratio: float = 1.0, factor: Optional[float] = 1.01, maxiter: int = 100, key: Optional[PRNGKey] = None, ): r"""Estimate `tau` and `sigma` parameters of :class:`PDHG`. Find values of the `tau` and `sigma` parameters of :class:`PDHG` that respect the constraint .. math:: \tau \sigma < \| C \|_2^{-2} \quad \text{or} \quad \tau \sigma < \| J_x C(\mb{x}) \|_2^{-2} \;, depending on whether :math:`C` is a :class:`.LinearOperator` or not. Args: C: Operator :math:`C`. x: Value of :math:`\mb{x}` at which to evaluate the Jacobian of :math:`C` (when it is not a :class:`.LinearOperator`). If ``None``, defaults to an array of zeros. ratio: Desired ratio between return :math:`\tau` and :math:`\sigma` values (:math:`\sigma = \mathrm{ratio} \tau`). factor: Safety factor with which to multiply :math:`\| C \|_2^{-2}` to ensure strict inequality compliance. If ``None``, the value is set to 1.0. maxiter: Maximum number of power iterations to use in operator norm estimation (see :func:`.operator_norm`). Default: 100. key: Jax PRNG key to use in operator norm estimation (see :func:`.operator_norm`). Defaults to ``None``, in which case a new key is created. Returns: A tuple (`tau`, `sigma`) representing the estimated parameter values. 
""" if x is None: x = snp.zeros(C.input_shape, dtype=C.input_dtype) if factor is None: factor = 1.0 if isinstance(C, LinearOperator): J = C else: J = jacobian(C, x) Cnrm = operator_norm(J, maxiter=maxiter, key=key) tau = snp.sqrt(factor / ratio) / Cnrm sigma = ratio * tau return (tau, sigma) ================================================ FILE: scico/optimize/admm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """ADMM solver and auxiliary classes.""" import sys # isort: off from ._admmaux import ( SubproblemSolver, GenericSubproblemSolver, LinearSubproblemSolver, MatrixSubproblemSolver, CircularConvolveSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver, ) from ._admm import ADMM __all__ = [ "SubproblemSolver", "GenericSubproblemSolver", "LinearSubproblemSolver", "MatrixSubproblemSolver", "CircularConvolveSolver", "FBlockCircularConvolveSolver", "G0BlockCircularConvolveSolver", "ADMM", ] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/optimize/pgm.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""PGM solvers and auxiliary classes.""" import sys # isort: off from ._pgmaux import ( PGMStepSize, BBStepSize, AdaptiveBBStepSize, LineSearchStepSize, RobustLineSearchStepSize, ) from ._pgm import PGM, AcceleratedPGM __all__ = [ "PGMStepSize", "BBStepSize", "AdaptiveBBStepSize", "LineSearchStepSize", "RobustLineSearchStepSize", "PGM", "AcceleratedPGM", ] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ ================================================ FILE: scico/plot.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Plotting/visualization functions. Optional alternative high-level interface to selected :mod:`matplotlib` plotting functions. """ # This module is copied from https://github.com/bwohlberg/sporco import os import sys import numpy as np import matplotlib import matplotlib.cm as cm import matplotlib.pyplot as plt from matplotlib.pyplot import figure, gca, gcf, savefig, subplot, subplots # noqa from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.mplot3d import Axes3D # noqa try: import mpldatacursor as mpldc except ImportError: have_mpldc = False else: have_mpldc = True __all__ = [ "plot", "surf", "contour", "imview", "close", "set_ipython_plot_backend", "set_notebook_plot_backend", "config_notebook_plotting", ] def _attach_keypress(fig, scaling=1.1): """Attach a key press event handler. Attach a key press event handler that configures keys for closing a figure and changing the figure size. Keys 'e' and 'c' respectively expand and contract the figure, and key 'q' closes it. 
**Note:** Resizing may not function correctly with all matplotlib backends (a `bug `__ has been reported). Args: fig (:class:`matplotlib.figure.Figure` object): Figure to which event handling is to be attached. scaling (float, optional (default 1.1)): Scaling factor for figure size changes. Returns: function: Key press event handler function. """ def press(event): if event.key == "q": plt.close(fig) elif event.key == "e": fig.set_size_inches(scaling * fig.get_size_inches(), forward=True) elif event.key == "c": fig.set_size_inches(fig.get_size_inches() / scaling, forward=True) # Avoid multiple event handlers attached to the same figure if not hasattr(fig, "_scico_keypress_cid"): cid = fig.canvas.mpl_connect("key_press_event", press) fig._scico_keypress_cid = cid return press def _attach_zoom(ax, scaling=2.0): """Attach a scroll wheel event handler. Attach an event handler that supports zooming within a plot using the mouse scroll wheel. Args: ax (:class:`matplotlib.axes.Axes` object): Axes to which event handling is to be attached. scaling (float, optional (default 2.0)): Scaling factor for zooming in and out. Returns: function: Mouse scroll wheel event handler function. 
""" # See https://stackoverflow.com/questions/11551049 def zoom(event): # Get the current x and y limits cur_xlim = ax.get_xlim() cur_ylim = ax.get_ylim() # Get event location xdata = event.xdata ydata = event.ydata # Return if cursor is not over valid region of plot if xdata is None or ydata is None: return if event.button == "up": # Deal with zoom in scale_factor = 1.0 / scaling elif event.button == "down": # Deal with zoom out scale_factor = scaling # Get distance from the cursor to the edge of the figure frame x_left = xdata - cur_xlim[0] x_right = cur_xlim[1] - xdata y_top = ydata - cur_ylim[0] y_bottom = cur_ylim[1] - ydata # Calculate new x and y limits new_xlim = (xdata - x_left * scale_factor, xdata + x_right * scale_factor) new_ylim = (ydata - y_top * scale_factor, ydata + y_bottom * scale_factor) # Ensure that x limit range is no larger than that of the reference if np.diff(new_xlim) > np.diff(zoom.xlim_ref): new_xlim *= np.diff(zoom.xlim_ref) / np.diff(new_xlim) # Ensure that lower x limit is not less than that of the reference if new_xlim[0] < zoom.xlim_ref[0]: new_xlim += np.array(zoom.xlim_ref[0] - new_xlim[0]) # Ensure that upper x limit is not greater than that of the reference if new_xlim[1] > zoom.xlim_ref[1]: new_xlim -= np.array(new_xlim[1] - zoom.xlim_ref[1]) # Ensure that ylim tuple has the smallest value first if zoom.ylim_ref[1] < zoom.ylim_ref[0]: ylim_ref = zoom.ylim_ref[::-1] new_ylim = new_ylim[::-1] else: ylim_ref = zoom.ylim_ref # Ensure that y limit range is no larger than that of the reference if np.diff(new_ylim) > np.diff(ylim_ref): new_ylim *= np.diff(ylim_ref) / np.diff(new_ylim) # Ensure that lower y limit is not less than that of the reference if new_ylim[0] < ylim_ref[0]: new_ylim += np.array(ylim_ref[0] - new_ylim[0]) # Ensure that upper y limit is not greater than that of the reference if new_ylim[1] > ylim_ref[1]: new_ylim -= np.array(new_ylim[1] - ylim_ref[1]) # Return the ylim tuple to its original order if 
zoom.ylim_ref[1] < zoom.ylim_ref[0]: new_ylim = new_ylim[::-1] # Set new x and y limits ax.set_xlim(new_xlim) ax.set_ylim(new_ylim) # Force redraw ax.figure.canvas.draw() # Record reference x and y limits prior to any zooming zoom.xlim_ref = ax.get_xlim() zoom.ylim_ref = ax.get_ylim() # Get figure for specified axes and attach the event handler fig = ax.get_figure() fig.canvas.mpl_connect("scroll_event", zoom) return zoom def plot(y, x=None, ptyp="plot", xlbl=None, ylbl=None, title=None, lgnd=None, lglc=None, **kwargs): """Plot points or lines in 2D. Plot points or lines in 2D. If a figure object is specified then the plot is drawn in that figure, and `fig.show()` is not called. The figure is closed on key entry 'q'. Args: y (array_like): 1d or 2d array of data to plot. If a 2d array, each column is plotted as a separate curve. x (array_like, optional (default ``None``)): Values for x-axis of the plot. ptyp (string, optional (default 'plot')): Plot type specification (options are 'plot', 'semilogx', 'semilogy', and 'loglog'). xlbl (string, optional (default ``None``)): Label for x-axis. ylbl (string, optional (default ``None``)): Label for y-axis. title (string, optional (default ``None``)): Figure title. lgnd (list of strings, optional (default ``None``)): List of legend string. lglc (string, optional (default ``None``)): Legend location string. **kwargs: :class:`matplotlib.lines.Line2D` properties or figure properties. Keyword arguments specifying :class:`matplotlib.lines.Line2D` properties, e.g. `lw=2.0` sets a line width of 2, or properties of the figure and axes. If not specified, the defaults for line width (`lw`) and marker size (`ms`) are 1.5 and 6.0 respectively. The valid figure and axes keyword arguments are listed below: .. |mplfg| replace:: :class:`matplotlib.figure.Figure` object .. |mplax| replace:: :class:`matplotlib.axes.Axes` object .. 
rst-class:: kwargs ===== ==================== =================================== kwarg Accepts Description ===== ==================== =================================== fgsz tuple (width,height) Specify figure dimensions in inches fgnm integer Figure number of figure fig |mplfg| Draw in specified figure instead of creating one ax |mplax| Plot in specified axes instead of current axes of figure ===== ==================== =================================== Returns: - **fig** (:class:`matplotlib.figure.Figure` object): Figure object for this figure. - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for this plot. Raises: ValueError: If an invalid plot type is specified via parameter `ptyp`. """ # Extract kwargs entries that are not related to line properties fgsz = kwargs.pop("fgsz", None) fgnm = kwargs.pop("fgnm", None) fig = kwargs.pop("fig", None) ax = kwargs.pop("ax", None) figp = fig if fig is None: fig = plt.figure(num=fgnm, figsize=fgsz) fig.clf() ax = fig.gca() elif ax is None: ax = fig.gca() # Set defaults for line width and marker size if "lw" not in kwargs and "linewidth" not in kwargs: kwargs["lw"] = 1.5 if "ms" not in kwargs and "markersize" not in kwargs: kwargs["ms"] = 6.0 if ptyp not in ("plot", "semilogx", "semilogy", "loglog"): raise ValueError("Invalid plot type '%s'." 
% ptyp) pltmth = getattr(ax, ptyp) if x is None: pltln = pltmth(y, **kwargs) else: pltln = pltmth(x, y, **kwargs) ax.fmt_xdata = "{: .2f}".format ax.fmt_ydata = "{: .2f}".format if title is not None: ax.set_title(title) if xlbl is not None: ax.set_xlabel(xlbl) if ylbl is not None: ax.set_ylabel(ylbl) if lgnd is not None: ax.legend(lgnd, loc=lglc) _attach_keypress(fig) _attach_zoom(ax) if have_mpldc: mpldc.datacursor(pltln) if figp is None: fig.show() return fig, ax def surf( z, x=None, y=None, elev=None, azim=None, xlbl=None, ylbl=None, zlbl=None, title=None, lblpad=8.0, alpha=1.0, cntr=None, cmap=None, fgsz=None, fgnm=None, fig=None, ax=None, ): """Plot a 2D surface in 3D. Plot a 2D surface in 3D. If a figure object is specified then the surface is drawn in that figure, and `fig.show()` is not called. The figure is closed on key entry 'q'. Args: z (array_like): 2d array of data to plot. x (array_like, optional (default ``None``)): Values for x-axis of the plot. y (array_like, optional (default ``None``)): Values for y-axis of the plot. elev (float): Elevation angle (in degrees) in the z plane. azim (float): Azimuth angle (in degrees) in the x,y plane. xlbl (string, optional (default ``None``)): Label for x-axis. ylbl (string, optional (default ``None``)): Label for y-axis. zlbl (string, optional (default ``None``)): Label for z-axis. title (string, optional (default ``None``)): Figure title. lblpad (float, optional (default 8.0)): Label padding. alpha (float between 0.0 and 1.0, optional (default 1.0)): Transparency. cntr (int or sequence of ints, optional (default ``None``)): If not ``None``, plot contours of the surface on the lower end of the z-axis. An int specifies the number of contours to plot, and a sequence specifies the specific contour levels to plot. cmap (:class:`matplotlib.colors.Colormap` object, optional (default ``None``)): Color map for surface. If none specifed, defaults to `cm.YlOrRd`. 
fgsz (tuple (width,height), optional (default ``None``)): Specify figure dimensions in inches. fgnm (integer, optional (default ``None``)): Figure number of figure. fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)): Draw in specified figure instead of creating one. ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)): Plot in specified axes instead of creating one. Returns: - **fig** (:class:`matplotlib.figure.Figure` object): Figure object for this figure. - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for this plot. """ figp = fig if fig is None: fig = plt.figure(num=fgnm, figsize=fgsz) fig.clf() ax = plt.axes(projection="3d") else: if ax is None: ax = plt.axes(projection="3d") else: # See https://stackoverflow.com/a/43563804 # https://stackoverflow.com/a/35221116 if ax.name != "3d": ax.remove() ax = fig.add_subplot(ax.get_subplotspec(), projection="3d") if elev is not None or azim is not None: ax.view_init(elev=elev, azim=azim) if cmap is None: cmap = cm.YlOrRd if x is None: x = range(z.shape[1]) if y is None: y = range(z.shape[0]) xg, yg = np.meshgrid(x, y) ax.plot_surface(xg, yg, z, rstride=1, cstride=1, alpha=alpha, cmap=cmap) if cntr is not None: offset = np.around(z.min() - 0.2 * (z.max() - z.min()), 3) ax.contour(xg, yg, z, cntr, cmap=cmap, linewidths=2, linestyles="solid", offset=offset) ax.set_zlim(offset, ax.get_zlim()[1]) ax.fmt_xdata = "{: .2f}".format ax.fmt_ydata = "{: .2f}".format ax.fmt_zdata = "{: .2f}".format if title is not None: ax.set_title(title) if xlbl is not None: ax.set_xlabel(xlbl, labelpad=lblpad) if ylbl is not None: ax.set_ylabel(ylbl, labelpad=lblpad) if zlbl is not None: ax.set_zlabel(zlbl, labelpad=lblpad) _attach_keypress(fig) if figp is None: fig.show() return fig, ax def contour( z, x=None, y=None, v=5, xlog=False, ylog=False, xlbl=None, ylbl=None, title=None, cfmt=None, cfntsz=10, lfntsz=None, alpha=1.0, cmap=None, vmin=None, vmax=None, fgsz=None, fgnm=None, fig=None, 
ax=None, ): """Contour plot of a 2D surface. Contour plot of a 2D surface. If a figure object is specified then the plot is drawn in that figure, and `fig.show()` is not called. The figure is closed on key entry 'q'. Args: z (array_like): 2d array of data to plot. x (array_like, optional (default ``None``)): Values for x-axis of the plot. y (array_like, optional (default ``None``)): Values for y-axis of the plot. v (int or sequence of floats, optional (default 5)): An int specifies the number of contours to plot, and a sequence specifies the specific contour levels to plot. xlog (boolean, optional (default ``False``)): Set x-axis to log scale. ylog (boolean, optional (default ``False``)): Set y-axis to log scale. xlbl (string, optional (default ``None``)): Label for x-axis. ylbl (string, optional (default ``None``)): Label for y-axis. title (string, optional (default ``None``)): Figure title. cfmt (string, optional (default ``None``)): Format string for contour labels. cfntsz (int or ``None``, optional (default 10)): Contour label font size. No contour labels are displayed if set to 0 or ``None``. lfntsz (int, optional (default ``None``)): Axis label font size. The default font size is used if set to ``None``. alpha (float, optional (default 1.0)): Underlying image display alpha value. cmap (:class:`matplotlib.colors.Colormap`, optional (default ``None``)): Color map for surface. If none specifed, defaults to `cm.YlOrRd`. vmin, vmax (float, optional (default ``None``)): Set upper and lower bounds for the color map (see the corresponding parameters of :meth:`matplotlib.axes.Axes.imshow`). fgsz (tuple (width,height), optional (default ``None``)): Specify figure dimensions in inches. fgnm (integer, optional (default ``None``)): Figure number of figure. fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)): Draw in specified figure instead of creating one. 
ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)): Plot in specified axes instead of current axes of figure. Returns: - **fig** (:class:`matplotlib.figure.Figure` object): Figure object for this figure. - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for this plot. """ figp = fig if fig is None: fig = plt.figure(num=fgnm, figsize=fgsz) fig.clf() ax = fig.gca() elif ax is None: ax = fig.gca() if xlog: ax.set_xscale("log") if ylog: ax.set_yscale("log") if cmap is None: cmap = cm.YlOrRd if x is None: x = np.arange(z.shape[1]) else: x = np.array(x) if y is None: y = np.arange(z.shape[0]) else: y = np.array(y) xg, yg = np.meshgrid(x, y) cntr = ax.contour(xg, yg, z, v, colors="black") kwargs = {} if cfntsz is not None and cfntsz > 0: kwargs["fontsize"] = cfntsz if cfmt is not None: kwargs["fmt"] = cfmt if kwargs: plt.clabel(cntr, inline=True, **kwargs) pc = ax.pcolormesh( xg, yg, z, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha, shading="gouraud", clim=(vmin, vmax), ) if xlog: ax.fmt_xdata = "{: .2e}".format else: ax.fmt_xdata = "{: .2f}".format if ylog: ax.fmt_ydata = "{: .2e}".format else: ax.fmt_ydata = "{: .2f}".format if title is not None: ax.set_title(title) if xlbl is not None: ax.set_xlabel(xlbl, fontsize=lfntsz) if ylbl is not None: ax.set_ylabel(ylbl, fontsize=lfntsz) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.2) plt.colorbar(pc, ax=ax, cax=cax) _attach_keypress(fig) _attach_zoom(ax) if have_mpldc: mpldc.datacursor() if figp is None: fig.show() return fig, ax def imview( img, title=None, copy=True, fltscl=False, intrp="nearest", norm=None, cbar=False, cmap=None, fgsz=None, fgnm=None, fig=None, ax=None, ): """Display an image. Display an image. Pixel values are displayed when the pointer is over valid image data. If a figure object is specified then the image is drawn in that figure, and `fig.show()` is not called. The figure is closed on key entry 'q'. 
Args:
        img (array_like, shape (Nr, Nc) or (Nr, Nc, 3) or (Nr, Nc, 4)):
            Image to display.
        title (string, optional (default ``None``)): Figure title.
        copy (boolean, optional (default ``True``)): If ``True``, create
            a copy of input `img` as a reference for displayed pixel
            values, ensuring that displayed values do not change when the
            array changes in the calling scope. Set this flag to
            ``False`` if the overhead of an additional copy of the input
            image is not acceptable.
        fltscl (boolean, optional (default ``False``)): If ``True``,
            rescale and shift floating point arrays to [0,1].
        intrp (string, optional (default 'nearest')): Specify type of
            interpolation used to display image (see `interpolation`
            parameter of :meth:`matplotlib.axes.Axes.imshow`).
        norm (:class:`matplotlib.colors.Normalize` object, optional (default ``None``)):
            Specify the :class:`matplotlib.colors.Normalize` instance used
            to scale pixel values for input to the color map.
        cbar (boolean, optional (default ``False``)): Flag indicating
            whether to display colorbar. If ``None``, space for a
            colorbar is reserved next to the image, but no colorbar is
            drawn.
        cmap (:class:`matplotlib.colors.Colormap`, optional (default ``None``)):
            Color map for image. If none specified, defaults to
            `cm.Greys_r` for monochrome image.
        fgsz (tuple (width,height), optional (default ``None``)): Specify
            figure dimensions in inches.
        fgnm (integer, optional (default ``None``)): Figure number of
            figure.
        fig (:class:`matplotlib.figure.Figure` object, optional (default ``None``)):
            Draw in specified figure instead of creating one.
        ax (:class:`matplotlib.axes.Axes` object, optional (default ``None``)):
            Plot in specified axes instead of current axes of figure.

    Returns:
        - **fig** (:class:`matplotlib.figure.Figure` object):
          Figure object for this figure.
        - **ax** (:class:`matplotlib.axes.Axes` object):
          Axes object for this plot.

    Raises:
        ValueError: If the input array is not of the required shape.
    """

    # NOTE(review): the docstring lists (Nr, Nc, 4) as an accepted shape,
    # but this check rejects any 3-d input whose last axis is not 3, so
    # RGBA images raise ValueError -- confirm which behavior is intended.
    if img.ndim > 2 and img.shape[2] != 3:
        raise ValueError("Argument 'img' must be an Nr x Nc array or an Nr x Nc x 3 array.")

    figp = fig
    if fig is None:
        fig = plt.figure(num=fgnm, figsize=fgsz)
        fig.clf()
        ax = fig.gca()
    elif ax is None:
        ax = fig.gca()

    # Deal with removal of 'box-forced' adjustable in Matplotlib 2.2.0
    mplv = matplotlib.__version__.split(".")
    if int(mplv[0]) > 2 or (int(mplv[0]) == 2 and int(mplv[1]) >= 2):
        try:
            ax.set_adjustable("box")
        except Exception:
            ax.set_adjustable("datalim")
    else:
        ax.set_adjustable("box-forced")

    # imgd is the (possibly rescaled) array actually passed to imshow; it
    # is always a copy, independent of the `copy` flag.
    imgd = img.copy()

    if copy:
        # Keep a separate copy of the input image so that the original
        # pixel values can be displayed rather than the scaled pixel
        # values that are actually plotted.
        img = img.copy()

    if cmap is None and img.ndim == 2:
        cmap = cm.Greys_r

    if np.issubdtype(img.dtype, np.floating):
        if fltscl:
            imgd -= imgd.min()
            imgd /= imgd.max()
        if img.ndim > 2:
            # Color images are clipped to the [0, 1] range.
            imgd = np.clip(imgd, 0.0, 1.0)
    elif img.dtype == np.uint16:
        imgd = np.float16(imgd) / np.iinfo(np.uint16).max
    elif img.dtype == np.int16:
        imgd = np.float16(imgd) - imgd.min()
        imgd /= imgd.max()

    if norm is None:
        im = ax.imshow(imgd, cmap=cmap, interpolation=intrp, vmin=imgd.min(), vmax=imgd.max())
    else:
        im = ax.imshow(imgd, cmap=cmap, interpolation=intrp, norm=norm)

    ax.set_yticklabels([])
    ax.set_xticklabels([])

    if title is not None:
        ax.set_title(title)

    if cbar or cbar is None:
        # Place the colorbar along the longer image dimension.
        orient = "vertical" if img.shape[0] >= img.shape[1] else "horizontal"
        pos = "right" if orient == "vertical" else "bottom"
        divider = make_axes_locatable(ax)
        cax = divider.append_axes(pos, size="5%", pad=0.2)
        if cbar is None:
            # Reserve the colorbar space but make the axes invisible.
            # See http://chris35wills.github.io/matplotlib_axis
            if hasattr(cax, "set_facecolor"):
                cax.set_facecolor("none")
            else:
                cax.set_axis_bgcolor("none")
            for axis in ["top", "bottom", "left", "right"]:
                cax.spines[axis].set_linewidth(0)
            cax.set_xticks([])
            cax.set_yticks([])
        else:
            plt.colorbar(im, ax=ax, cax=cax, orientation=orient)

    def format_coord(x, y):
        nr, nc = imgd.shape[0:2]
        col = int(x +
0.5) row = int(y + 0.5) if col >= 0 and col < nc and row >= 0 and row < nr: z = img[row, col] if imgd.ndim == 2: return "x=%6.2f, y=%6.2f, z=%.2f" % (x, y, z) return "x=%6.2f, y=%6.2f, z=(%.2f,%.2f,%.2f)" % sum(((x,), (y,), tuple(z)), ()) return "x=%.2f, y=%.2f" % (x, y) ax.format_coord = format_coord if fig.canvas.toolbar is not None: # See https://stackoverflow.com/a/47086132 def mouse_move(self, event): if event.inaxes and event.inaxes.get_navigate(): s = event.inaxes.format_coord(event.xdata, event.ydata) self.set_message(s) def mouse_move_patch(arg): return mouse_move(fig.canvas.toolbar, arg) fig.canvas.toolbar._idDrag = fig.canvas.mpl_connect("motion_notify_event", mouse_move_patch) _attach_keypress(fig) _attach_zoom(ax) if have_mpldc: mpldc.datacursor(display="single") if figp is None: fig.show() return fig, ax def close(fig=None): """Close figure(s). Close figure(s). If a figure object reference or figure number is provided, close the specified figure, otherwise close all figures. Args: fig (:class:`matplotlib.figure.Figure` object or integer (optional (default ``None``)): Figure object or number of figure to close. """ if fig is None: plt.close("all") else: plt.close(fig) def _in_ipython(): """Determine whether code is running in an ipython shell. Returns: bool: ``True`` if running in an ipython shell, ``False`` otherwise. """ try: # See https://stackoverflow.com/questions/15411967 shell = get_ipython().__class__.__name__ return bool(shell == "TerminalInteractiveShell") except NameError: return False def _in_notebook(): """Determine whether code is running in a Jupyter Notebook shell. Returns: bool: ``True`` if running in a notebook shell, ``False`` otherwise. """ try: # See https://stackoverflow.com/questions/15411967 shell = get_ipython().__class__.__name__ return bool(shell == "ZMQInteractiveShell") except NameError: return False def set_ipython_plot_backend(backend="qt"): """Set matplotlib backend within an ipython shell. 
Set matplotlib backend within an ipython shell. This function has the same effect as the line magic `%matplotlib [backend]` but is called as a function and includes a check to determine whether the code is running in an ipython shell, so that it can safely be used within a normal python script since it has no effect when not running in an ipython shell. Args: backend (string, optional (default 'qt')): Name of backend to be passed to the `%matplotlib` line magic command. """ if _in_ipython(): # See https://stackoverflow.com/questions/35595766 get_ipython().run_line_magic("matplotlib", backend) def set_notebook_plot_backend(backend="inline"): """Set matplotlib backend within a Jupyter Notebook shell. Set matplotlib backend within a Jupyter Notebook shell. This function has the same effect as the line magic `%matplotlib [backend]` but is called as a function and includes a check to determine whether the code is running in a notebook shell, so that it can safely be used within a normal python script since it has no effect when not running in a notebook shell. Args: backend (string, optional (default 'inline')): Name of backend to be passed to the `%matplotlib` line magic command. """ if _in_notebook(): # See https://stackoverflow.com/questions/35595766 get_ipython().run_line_magic("matplotlib", backend) def config_notebook_plotting(): """Configure plotting functions for inline plotting. Configure plotting functions for inline plotting within a Jupyter Notebook shell. This function has no effect when not within a notebook shell, and may therefore be used within a normal python script. If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set, the matplotlib backend is explicitly set to the specified value. 
""" # Check whether running within a notebook shell and have # not already monkey patched the plot function module = sys.modules[__name__] if _in_notebook() and module.plot.__name__ == "plot": # Set backend if specified by environment variable if "MATPLOTLIB_IPYNB_BACKEND" in os.environ: set_notebook_plot_backend(os.environ["MATPLOTLIB_IPYNB_BACKEND"]) # Replace plot function with a wrapper function that discards # its return value (within a notebook with inline plotting, plots # are duplicated if the return value from the original function is # not assigned to a variable) plot_original = module.plot def plot_wrap(*args, **kwargs): plot_original(*args, **kwargs) module.plot = plot_wrap # Replace surf function with a wrapper function that discards # its return value (see comment for plot function) surf_original = module.surf def surf_wrap(*args, **kwargs): surf_original(*args, **kwargs) module.surf = surf_wrap # Replace contour function with a wrapper function that discards # its return value (see comment for plot function) contour_original = module.contour def contour_wrap(*args, **kwargs): contour_original(*args, **kwargs) module.contour = contour_wrap # Replace imview function with a wrapper function that discards # its return value (see comment for plot function) imview_original = module.imview def imview_wrap(*args, **kwargs): imview_original(*args, **kwargs) module.imview = imview_wrap # Disable figure show method (results in a warning if used within # a notebook with inline plotting) import matplotlib.figure def show_disable(self): pass matplotlib.figure.Figure.show = show_disable ================================================ FILE: scico/random.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. 
# Details of the copyright and user license can be found in the
# 'LICENSE' file distributed with the package.

"""Random number generation.

This module provides convenient wrappers around several `jax.random
<https://jax.readthedocs.io/en/latest/jax.random.html>`_ routines to
handle the generation and splitting of PRNG keys, as well as the
generation of random :class:`.BlockArray`:

::

   # Calls to scico.random functions always return a PRNG key
   # If no key is passed to the function, a new key is generated
   x, key = scico.random.randn((2,))
   print(x)   # [ 0.19307713 -0.52678305]

   # scico.random functions automatically split the PRNG key and return
   # an updated key
   y, key = scico.random.randn((2,), key=key)
   print(y)   # [ 0.00870693 -0.04888531]

The user is responsible for passing the PRNG key to :mod:`scico.random`
functions. If no key is passed, repeated calls to :mod:`scico.random`
functions will return the same random numbers:

::

   x, key = scico.random.randn((2,))
   print(x)   # [ 0.19307713 -0.52678305]

   # No key passed, will return the same random numbers!
   y, key = scico.random.randn((2,))
   print(y)   # [ 0.19307713 -0.52678305]

If the desired shape is a tuple containing tuples, a :class:`.BlockArray`
is returned:

::

   x, key = scico.random.randn( ((1, 1), (2,)), key=key)
   print(x)  # scico.numpy.BlockArray:
             # Array([ 1.1378784 , -1.220955 , -0.59153646], dtype=float32)
"""

import inspect
import sys
from typing import Optional, Tuple, Union

import numpy as np

import jax

from scico.numpy import Array, BlockArray
from scico.numpy._wrappers import map_func_over_args
from scico.typing import BlockShape, DType, PRNGKey, Shape


def _add_seed(fun):
    """Modify a :mod:`jax.random` function to add a `seed` argument.

    Args:
        fun: Function to be modified, e.g., :func:`jax.random.normal`.
            Expects `key` to be the first argument.

    Returns:
        A version of `fun` supporting an optional `seed` argument that
        is used to create a :func:`jax.random.key` that is passed along
        as the `key`. The `key` argument may still be used, but is moved
        to be second-to-last. By default, `seed=0`. The `seed` argument
        is added last. Other arguments are unchanged. The returned
        function returns a `(result, key)` tuple, where `key` is obtained
        by splitting the key used to compute `result`, and raises
        :exc:`ValueError` if both `key` and `seed` are specified.
    """
    # Number of parameters of `fun` that are not keyword-only, including
    # the leading `key` parameter.
    num_params = len(
        [
            param
            for param in inspect.signature(fun).parameters.values()
            if param.kind != param.KEYWORD_ONLY
        ]
    )

    def fun_alt(*args, key=None, seed=None, **kwargs):
        # `key` and `seed` are the last two positional parameters of the
        # wrapped function, so they may also appear in `args`; positional
        # values take precedence over the keyword values.
        if len(args) >= num_params:  # all positional args, including key
            key = args[num_params - 1]
        if len(args) > num_params:  # all positional args, including key and seed
            seed = args[num_params]
        if key is not None and seed is not None:
            raise ValueError("Arguments 'key' and 'seed' may not both be specified.")
        if key is None:
            if seed is None:
                seed = 0
            key = jax.random.key(seed)
        result = fun(key, *args[: num_params - 1], **kwargs)
        # Hand back a key derived from the one just used so that callers
        # can chain calls; the second half of the split is not needed.
        key, _ = jax.random.split(key, 2)
        return result, key

    # Docstrings are stripped when Python runs with -OO, in which case
    # fun.__doc__ is None and there is nothing to augment.
    if fun.__doc__ is not None:
        paragraphs = fun.__doc__.split("\n\n")
        fun_alt.__doc__ = "\n\n".join(
            paragraphs[0:1]
            + [
                f" Wrapped version of `jax.random.{fun.__name__} "
                f"<https://jax.readthedocs.io/en/latest/_autosummary/jax.random.{fun.__name__}.html>`_. "
                "The SCICO version of this function moves the `key` argument to the end of the "
                "argument list, adds an additional `seed` argument after that, and allows the "
                "`shape` argument to accept a nested list, in which case a `BlockArray` is returned. "
                "Always returns a `(result, key)` tuple. Original docstring below.",
            ]
            + paragraphs[1:]
        )
    return fun_alt


def _wrap(fun):
    """Wrap a :mod:`jax.random` function for export from this module.

    The function is first passed through :func:`map_func_over_args` with
    `shape` designated as the argument to map over when nested, and then
    through :func:`_add_seed`. The module of the result is set to this
    module so that it appears in the documentation for this module.

    Args:
        fun: A :mod:`jax.random` function taking `key` as its first
            argument.

    Returns:
        The wrapped function.
    """
    fun_wrapped = _add_seed(map_func_over_args(fun, map_if_nested_args=["shape"]))
    fun_wrapped.__module__ = __name__  # so it appears in docs
    return fun_wrapped


def _is_wrappable(fun):
    """Determine whether a :mod:`jax.random` function can be wrapped.

    Args:
        fun: Name of a function in :mod:`jax.random`.

    Returns:
        ``True`` if the first parameter of the function is named `key`
        and the function has a parameter named `shape`, otherwise
        ``False``.
    """
    params = inspect.signature(getattr(jax.random, fun)).parameters
    prmkey = list(params.keys())
    return bool(prmkey) and prmkey[0] == "key" and "shape" in params


# Names of the jax.random functions that take `key` as their first
# parameter and also have a `shape` parameter.
wrappable_func_names = [
    t[0] for t in inspect.getmembers(jax.random, inspect.isfunction) if _is_wrappable(t[0])
]

# Install a wrapped version of each such function as an attribute of this
# module (e.g. `normal`, which is used by `randn` below).
for name in wrappable_func_names:
    setattr(sys.modules[__name__], name, _wrap(getattr(jax.random, name)))


def randn(
    shape: Union[Shape, BlockShape],
    dtype: DType = np.float32,
    key: Optional[PRNGKey] = None,
    seed: Optional[int] = None,
) -> Tuple[Union[Array, BlockArray], PRNGKey]:
    """Return an array drawn from the standard normal distribution.

    Alias for :func:`scico.random.normal`.

    Args:
        shape: Shape of output array. If shape is a tuple, a
            :class:`jax.Array` is returned. If shape is a tuple of
            tuples, a :class:`.BlockArray` is returned.
        dtype: dtype for returned value. Defaults to
            :attr:`~numpy.float32`. If a complex dtype such as
            :attr:`~numpy.complex64`, generates an array sampled from
            complex normal distribution.
        key: JAX PRNGKey. Defaults to ``None``, in which case a new key
            is created using the seed arg.
        seed: Seed for new PRNGKey. Default: 0.

    Returns:
        tuple: A tuple (x, key) containing:

           - **x** : (:class:`jax.Array`): Generated random array.
           - **key** : Updated random PRNGKey.

    Raises:
        ValueError: If both `key` and `seed` are specified.
    """
    return normal(shape, dtype, key, seed)  # type: ignore


# ================================================
# FILE: scico/ray/__init__.py
# ================================================
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Simplified interfaces to :doc:`Ray `.""" import os os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" # suppress ray warning try: from ray import get, put from ray.tune import report except ImportError: raise ImportError("Could not import ray; please install it.") ================================================ FILE: scico/ray/tune.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Parameter tuning using :doc:`ray.tune `.""" import datetime import getpass import logging import os import tempfile from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import ray try: import ray.tune os.environ["RAY_AIR_NEW_OUTPUT"] = "0" except ImportError: raise ImportError("Could not import ray.tune; please install it.") import ray.air from ray.tune import ( # noqa CheckpointConfig, RunConfig, Trainable, loguniform, uniform, with_parameters, ) from ray.tune.experiment.trial import Trial from ray.tune.progress_reporter import TuneReporterBase, _get_trials_by_state from ray.tune.result_grid import ResultGrid from ray.tune.schedulers import AsyncHyperBandScheduler from ray.tune.search.hyperopt import HyperOptSearch class _CustomReporter(TuneReporterBase): """Custom status reporter for :mod:`ray.tune`.""" def should_report(self, trials: List[Trial], done: bool = False): """Return boolean indicating whether progress should be reported.""" # Don't report on final call when done to avoid duplicate final output. return not done def report(self, trials: List[Trial], done: bool, *sys_info: Dict): """Report progress across trials.""" # Get dict of trials in each state. trials_by_state = _get_trials_by_state(trials) # Construct list of number of trials in each of three possible states. 
num_trials = [len(trials_by_state[state]) for state in ["PENDING", "RUNNING", "TERMINATED"]] # Construct string description of number of trials in each state. num_trials_str = f"P: {num_trials[0]:3d} R: {num_trials[1]:3d} T: {num_trials[2]:3d} " # Get current best trial. current_best_trial, metric = self._current_best_trial(trials) if current_best_trial is None: rslt_str = "" else: # If current best trial exists, construct string summary val = current_best_trial.last_result[metric] config = current_best_trial.last_result.get("config", {}) rslt_str = f" {metric}: {val:.2e} at " + ", ".join( [f"{k}: {v:.2e}" for k, v in config.items()] ) # If all trials terminated, print with newline, otherwise carriage return for overwrite if num_trials[0] + num_trials[1] == 0: end = "\n" else: end = "\r" print(num_trials_str + rslt_str, end=end) def run( run_or_experiment: Union[str, Callable, Type], metric: str, mode: str, time_budget_s: Union[None, int, float, datetime.timedelta] = None, num_samples: int = 1, resources_per_trial: Union[None, Mapping[str, Union[float, int, Mapping]]] = None, max_concurrent_trials: Optional[int] = None, config: Optional[Dict[str, Any]] = None, hyperopt: bool = True, verbose: bool = True, storage_path: Optional[str] = None, ) -> ray.tune.ExperimentAnalysis: """Simplified wrapper for `ray.tune.run`_. .. _ray.tune.run: https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py#L232 The `ray.tune.run`_ interface appears to be scheduled for deprecation. Use of :class:`Tuner`, which is a simplified interface to :class:`ray.tune.Tuner` is recommended instead. Args: run_or_experiment: Function that reports performance values. metric: Name of the metric reported in the performance evaluation function. mode: Either "min" or "max", indicating which represents better performance. time_budget_s: Maximum time allowed in seconds for the parameter search. num_samples: Number of parameter evaluation samples to compute. 
resources_per_trial: A dict mapping keys "cpu" and "gpu" to integers specifying the corresponding resources to allocate for each performance evaluation trial. max_concurrent_trials: Maximum number of trials to run concurrently. config: Specification of the parameter search space. hyperopt: If ``True``, use :class:`~ray.tune.search.hyperopt.HyperOptSearch` search, otherwise use simple random search (see :class:`~ray.tune.search.basic_variant.BasicVariantGenerator`). verbose: Flag indicating whether verbose operation is desired. When verbose operation is enabled, the number of pending, running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. Returns: Result of parameter search. 
""" kwargs = {} if hyperopt: kwargs.update( { "search_alg": HyperOptSearch(metric=metric, mode=mode), "scheduler": AsyncHyperBandScheduler(), } ) if verbose: kwargs.update({"verbose": 1, "progress_reporter": _CustomReporter()}) else: kwargs.update({"verbose": 0}) if isinstance(run_or_experiment, str): name = run_or_experiment else: name = run_or_experiment.__name__ name += "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") # Record original logger.info logger_info = ray.tune.tune.logger.info # Replace logger.info with filtered version def logger_info_filter(msg, *args, **kwargs): if msg[0:15] != "Total run time:": logger_info(msg, *args, **kwargs) ray.tune.tune.logger.info = logger_info_filter result = ray.tune.run( run_or_experiment, metric=metric, mode=mode, name=name, time_budget_s=time_budget_s, num_samples=num_samples, storage_path=storage_path, resources_per_trial=resources_per_trial, max_concurrent_trials=max_concurrent_trials, reuse_actors=True, config=config, checkpoint_freq=0, **kwargs, ) # Restore original logger.info ray.tune.tune.logger.info = logger_info return result class Tuner(ray.tune.Tuner): """Simplified interface for :class:`ray.tune.Tuner`.""" def __init__( self, trainable: Union[Type[ray.tune.Trainable], Callable], *, param_space: Optional[Dict[str, Any]] = None, resources: Optional[Dict] = None, max_concurrent_trials: Optional[int] = None, metric: Optional[str] = None, mode: Optional[str] = None, num_samples: Optional[int] = None, num_iterations: Optional[int] = None, time_budget: Optional[int] = None, reuse_actors: bool = True, hyperopt: bool = True, verbose: bool = True, storage_path: Optional[str] = None, **kwargs, ): """ Args: trainable: Function that reports performance values. param_space: Specification of the parameter search space. 
resources: A dict mapping keys "cpu" and "gpu" to integers specifying the corresponding resources to allocate for each performance evaluation trial. max_concurrent_trials: Maximum number of trials to run concurrently. metric: Name of the metric reported in the performance evaluation function. mode: Either "min" or "max", indicating which represents better performance. num_samples: Number of parameter evaluation samples to compute. num_iterations: Number of training iterations for evaluation of a single configuration. Only required for the Tune Class API. time_budget: Maximum time allowed in seconds for a single parameter evaluation. reuse_actors: If ``True``, reuse the same process/object for multiple hyperparameters. hyperopt: If ``True``, use :class:`~ray.tune.search.hyperopt.HyperOptSearch` search, otherwise use simple random search (see :class:`~ray.tune.search.basic_variant.BasicVariantGenerator`). verbose: Flag indicating whether verbose operation is desired. When verbose operation is enabled, the number of pending, running, and terminated trials are indicated by "P:", "R:", and "T:" respectively, followed by the current best metric value and the parameters at which it was reported. storage_path: Directory in which to save tuning results. Defaults to a subdirectory "/ray_results" within the path returned by `tempfile.gettempdir()`, corresponding e.g. to "/tmp//ray_results" under Linux. 
""" k: Any # Avoid typing errors v: Any if resources is None: trainable_with_resources = trainable else: trainable_with_resources = ray.tune.with_resources(trainable, resources) tune_config = kwargs.pop("tune_config", None) tune_config_kwargs = { "mode": mode, "metric": metric, "num_samples": num_samples, "reuse_actors": reuse_actors, } if hyperopt: tune_config_kwargs.update( { "search_alg": HyperOptSearch(metric=metric, mode=mode), "scheduler": AsyncHyperBandScheduler(), } ) if max_concurrent_trials is not None: tune_config_kwargs.update({"max_concurrent_trials": max_concurrent_trials}) if tune_config is None: tune_config = ray.tune.TuneConfig(**tune_config_kwargs) else: for k, v in tune_config_kwargs.items(): setattr(tune_config, k, v) name = trainable.__name__ + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if storage_path is None: try: user = getpass.getuser() except Exception: # pragma: no cover user = "NOUSER" storage_path = os.path.join(tempfile.gettempdir(), user, "ray_results") run_config = kwargs.pop("run_config", None) run_config_kwargs = {"name": name, "storage_path": storage_path, "verbose": 0} if verbose: run_config_kwargs.update({"verbose": 1, "progress_reporter": _CustomReporter()}) if num_iterations is not None or time_budget is not None: stop_criteria = {} if num_iterations is not None: stop_criteria.update({"training_iteration": num_iterations}) if time_budget is not None: stop_criteria.update({"time_total_s": time_budget}) run_config_kwargs.update({"stop": stop_criteria}) if run_config is None: run_config_kwargs.update( {"checkpoint_config": CheckpointConfig(checkpoint_at_end=False)} ) run_config = RunConfig(**run_config_kwargs) else: for k, v in run_config_kwargs.items(): setattr(run_config, k, v) super().__init__( trainable_with_resources, param_space=param_space, tune_config=tune_config, run_config=run_config, **kwargs, ) def fit(self) -> ResultGrid: """Initialize ray and call :meth:`ray.tune.Tuner.fit`. 
Initialize ray if not already initialized, and call :meth:`ray.tune.Tuner.fit`. If ray was not previously initialized, shut it down after fit process has completed. Returns: Result of parameter search. """ ray_init = ray.is_initialized() if not ray_init: ray.init(logging_level=logging.ERROR) results = super().fit() if not ray_init: ray.shutdown() return results ================================================ FILE: scico/scipy/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Wrapped versions of `jax.scipy `_ functions. This modules currently serves simply as a namespace for :mod:`scico.scipy.special`. """ from . import special ================================================ FILE: scico/scipy/special.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """:class:`~scico.numpy.BlockArray`-compatible :mod:`jax.scipy.special` functions. This modules is a wrapper for :mod:`jax.scipy.special` where some functions have been extended to automatically map over block array blocks as described in :ref:`numpy_functions_blockarray`. """ from typing import Tuple import jax.scipy.special as js from scico.numpy import _wrappers # add most everything in jax.scipy.special to this module _wrappers.add_attributes( to_dict=vars(), from_dict=js.__dict__, ) # wrap select functions functions: Tuple[str, ...] 
= ( "betainc", "entr", "erf", "erfc", "erfinv", "expit", "gammainc", "gammaincc", "gammaln", "i0", "i0e", "i1", "i1e", "log_ndtr", "logit", "logsumexp", "multigammaln", "ndtr", "ndtri", "polygamma", "xlog1py", "xlogy", "zeta", "digamma", ) if hasattr(js, "sph_harm_y"): # not available in all supported jax versions functions += ("sph_harm_y",) else: functions += ("sph_harm",) _wrappers.wrap_recursively(vars(), functions, _wrappers.map_func_over_args) # clean up del js, _wrappers ================================================ FILE: scico/solver.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Solver and optimization algorithms. This module provides a number of functions for solving linear systems and optimization problems, some of which are used as subproblem solvers within the iterations of the proximal algorithms in the :mod:`scico.optimize` subpackage. This module also provides scico interface wrappers for functions from :mod:`scipy.optimize` since jax directly implements only a very limited subset of these functions (there is limited, experimental support for `L-BFGS-B `_), but only CG and BFGS are fully supported. These wrappers are required because the functions in :mod:`scipy.optimize` only support on 1D, real valued, numpy arrays. These limitations are addressed by: - Enabling the use of multi-dimensional arrays by flattening and reshaping within the wrapper. - Enabling the use of jax arrays by automatically converting to and from numpy arrays. - Enabling the use of complex arrays by splitting them into real and imaginary parts. The wrapper also JIT compiles the function and gradient evaluations. 
These wrapper functions have a number of advantages and disadvantages with respect to those in :mod:`jax.scipy.optimize`: - This module provides many more algorithms than :mod:`jax.scipy.optimize`. - The functions in this module tend to be faster for small-scale problems (presumably due to some overhead in the jax functions). - The functions in this module are slower for large problems due to the frequent host-device copies corresponding to conversion between numpy arrays and jax arrays. - The solvers in this module can't be JIT compiled, and gradients cannot be taken through them. In the future, these wrapper functions may be replaced with a dependency on `JAXopt `__. """ from functools import wraps from typing import Any, Callable, Optional, Sequence, Tuple, Union import numpy as np import jax import jax.numpy as jnp import jax.scipy.linalg as jsl import scico.numpy as snp from scico.linop import ( CircularConvolve, ComposedLinearOperator, Diagonal, LinearOperator, MatrixOperator, Sum, ) from scico.metric import rel_res from scico.numpy import Array, BlockArray from scico.numpy.util import is_complex_dtype, is_nested, is_real_dtype from scico.typing import BlockShape, DType, Shape from scipy import optimize as spopt def _wrap_func(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable: """Function evaluation for use in :mod:`scipy.optimize`. Compute function evaluation (without gradient) for use in :mod:`scipy.optimize` functions. Reshapes the input to `func` to have `shape`. Evaluates `func`. Args: func: The function to minimize. shape: Shape of input to `func`. dtype: Data type of input to `func`. 
""" val_func = jax.jit(func) @wraps(func) def wrapper(x, *args): # apply val_grad_func to un-vectorized input val = val_func(_unravel(x, shape).astype(dtype), *args) # Convert val into numpy array, cast to float, convert to scalar val = np.array(val).astype(float) val = val.item() if val.ndim == 0 else val[0].item() return val return wrapper def _wrap_func_and_grad(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable: """Function evaluation and gradient for use in :mod:`scipy.optimize`. Compute function evaluation and gradient for use in :mod:`scipy.optimize` functions. Reshapes the input to `func` to have `shape`. Evaluates `func` and computes gradient. Ensures the returned `grad` is an ndarray. Args: func: The function to minimize. shape: Shape of input to `func`. dtype: Data type of input to `func`. """ # argnums=0 ensures only differentiate func wrt first argument, # in case func signature is func(x, *args) val_grad_func = jax.jit(jax.value_and_grad(func, argnums=0)) @wraps(func) def wrapper(x, *args): # apply val_grad_func to un-vectorized input val, grad = val_grad_func(_unravel(x, shape).astype(dtype), *args) # Convert val & grad into numpy arrays, then cast to float # Convert 'val' into a scalar, rather than ndarray of shape (1,) val = np.array(val).astype(float).item() grad = np.array(grad).astype(float).ravel() return val, grad return wrapper def _split_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Split an array of shape (N, M, ...) into real and imaginary parts. Args: x: Array to split. Returns: A real ndarray with stacked real/imaginary parts. If `x` has shape (M, N, ...), the returned array will have shape (2, M, N, ...) where the first slice contains the `x.real` and the second contains `x.imag`. If `x` is a BlockArray, this function is called on each block and the output is joined into a BlockArray. 
""" if isinstance(x, BlockArray): return snp.blockarray([_split_real_imag(_) for _ in x]) return snp.stack((snp.real(x), snp.imag(x))) def _join_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Join a real array of shape (2,N,M,...) into a complex array. Join a real array of shape (2,N,M,...) into a complex array of length (N,M, ...). Args: x: Array to join. Returns: A complex array with real and imaginary parts taken from `x[0]` and `x[1]` respectively. """ if isinstance(x, BlockArray): return snp.blockarray([_join_real_imag(_) for _ in x]) return x[0] + 1j * x[1] def _ravel(x: Union[Array, BlockArray]) -> Array: """Vectorize an array or blockarray to a 1d array. Args: x: Array or blockarray to be vectorized. Returns: Vectorized array. """ if isinstance(x, snp.BlockArray): return jnp.hstack(x.ravel().arrays) else: return x.ravel() def _unravel(x: Array, shape: Union[Shape, BlockShape]) -> Union[Array, BlockArray]: """Return a vectorized array or blockarray to its original shape. Args: x: Vectorized array representation. shape: Shape of original array or blockarray. Returns: Array or blockarray with original shape. """ if is_nested(shape): sizes = [np.prod(e).item() for e in shape] indices = np.cumsum(sizes[:-1]) chunks = jnp.split(x, indices) return snp.BlockArray([chunks[k].reshape(cs) for k, cs in enumerate(shape)]) else: return x.reshape(shape) def minimize( func: Callable, x0: Union[Array, BlockArray], args: Union[Tuple, Tuple[Any]] = (), method: str = "L-BFGS-B", hess: Optional[Union[Callable, str]] = None, hessp: Optional[Callable] = None, bounds: Optional[Union[Sequence, spopt.Bounds]] = None, constraints: Union[spopt.LinearConstraint, spopt.NonlinearConstraint, dict] = (), tol: Optional[float] = None, callback: Optional[Callable] = None, options: Optional[dict] = None, ) -> spopt.OptimizeResult: """Minimization of scalar function of one or more variables. Wrapper around :func:`scipy.optimize.minimize`. 
This function differs from :func:`scipy.optimize.minimize` in three ways: - The `jac` options of :func:`scipy.optimize.minimize` are not supported. The gradient is calculated using :func:`jax.grad`. - Functions mapping from N-dimensional arrays -> float are supported. - Functions mapping from complex arrays -> float are supported. For more detail, including descriptions of the optimization methods and custom minimizers, refer to the original docs for :func:`scipy.optimize.minimize`. """ if is_complex_dtype(x0.dtype): # scipy minimize function requires real-valued arrays, so # we split x0 into a vector with real/imaginary parts stacked # and compose `func` with a `_join_real_imag` iscomplex = True func_real = lambda x: func(_join_real_imag(x)) x0 = _split_real_imag(x0) else: iscomplex = False func_real = func x0_shape = x0.shape x0_dtype = x0.dtype x0 = _ravel(x0) # Run the SciPy minimizer if method in ( "CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, " "trust-exact, trust-constr" ).split( ", " ): # uses gradient info min_func = _wrap_func_and_grad(func_real, x0_shape, x0_dtype) jac = True # see scipy.minimize docs else: # does not use gradient info min_func = _wrap_func(func_real, x0_shape, x0_dtype) jac = False res = spopt.OptimizeResult({"x": None}) def fun(x0): nonlocal res # To use the external res res = spopt.minimize( min_func, x0=x0, args=args, jac=jac, method=method, options=options, ) # Return OptimizeResult with x0 as ndarray return res.x.astype(x0_dtype) res.x = jax.pure_callback( fun, jax.ShapeDtypeStruct(x0.shape, x0_dtype), x0, ) res.x = _unravel(res.x, x0_shape) # un-vectorize the output array from spopt.minimize if iscomplex: res.x = _join_real_imag(res.x) return res def minimize_scalar( func: Callable, bracket: Optional[Sequence[float]] = None, bounds: Optional[Sequence[float]] = None, args: Union[Tuple, Tuple[Any]] = (), method: str = "brent", tol: Optional[float] = None, options: Optional[dict] = None, ) -> 
spopt.OptimizeResult: """Minimization of scalar function of one variable. Wrapper around :func:`scipy.optimize.minimize_scalar`. For more detail, including descriptions of the optimization methods and custom minimizers, refer to the original docstring for :func:`scipy.optimize.minimize_scalar`. """ def f(x, *args): # Wrap jax-based function `func` to return a numpy float rather # than a jax array of size (1,) y = func(x, *args) return y.item() if y.ndim == 0 else y[0].item() res = spopt.minimize_scalar( fun=f, bracket=bracket, bounds=bounds, args=args, method=method, tol=tol, options=options, ) return res def cg( A: Callable, b: Array, x0: Optional[Array] = None, *, tol: float = 1e-5, atol: float = 0.0, maxiter: int = 1000, info: bool = True, M: Optional[Callable] = None, ) -> Tuple[Array, dict]: r"""Conjugate Gradient solver. Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is positive definite, via the conjugate gradient method. Args: A: Callable implementing linear operator :math:`A`, which should be positive definite. b: Input array :math:`\mb{b}`. x0: Initial solution. If `A` is a :class:`.LinearOperator`, this parameter need not be specified, and defaults to a zero array. Otherwise, it is required. tol: Relative residual stopping tolerance. Convergence occurs when `norm(residual) <= max(tol * norm(b), atol)`. atol: Absolute residual stopping tolerance. Convergence occurs when `norm(residual) <= max(tol * norm(b), atol)`. maxiter: Maximum iterations. Default: 1000. info: If ``True`` return a tuple consting of the solution array and a dictionary containing diagnostic information, otherwise just return the solution. M: Preconditioner for `A`. The preconditioner should approximate the inverse of `A`. The default, ``None``, uses no preconditioner. Returns: tuple: A tuple (x, info) containing: - **x** : Solution array. - **info**: Dictionary containing diagnostic information. 
""" if x0 is None: if isinstance(A, LinearOperator): x0 = snp.zeros(A.input_shape, b.dtype) else: raise ValueError( "Argument 'x0' must be specified if argument 'A' is not a LinearOperator." ) if M is None: M = lambda x: x x = x0 Ax = A(x0) bn = snp.linalg.norm(b) r = b - Ax z = M(r) p = z num = snp.sum(r.conj() * z) ii = 0 # termination tolerance (uses the "non-legacy" form of scicpy.sparse.linalg.cg) termination_tol_sq = snp.maximum(tol * bn, atol) ** 2 while (ii < maxiter) and (num > termination_tol_sq): Ap = A(p) alpha = num / snp.sum(p.conj() * Ap) x = x + alpha * p r = r - alpha * Ap z = M(r) num_old = num num = snp.sum(r.conj() * z) beta = num / num_old p = z + beta * p ii += 1 if info: return (x, {"num_iter": ii, "rel_res": snp.sqrt(num).real / bn}) else: return x def lstsq( A: Callable, b: Array, x0: Optional[Array] = None, tol: float = 1e-5, atol: float = 0.0, maxiter: int = 1000, info: bool = False, M: Optional[Callable] = None, ) -> Tuple[Array, dict]: r"""Least squares solver. Solve the least squares problem .. math:: \argmin_{\mb{x}} \; (1/2) \norm{ A \mb{x} - \mb{b} }_2^2 \;, where :math:`A` is a linear operator and :math:`\mb{b}` is a vector. The problem is solved using :func:`cg`. Args: A: Callable implementing linear operator :math:`A`. b: Input array :math:`\mb{b}`. x0: Initial solution. If `A` is a :class:`.LinearOperator`, this parameter need not be specified, and defaults to a zero array. Otherwise, it is required. tol: Relative residual stopping tolerance. Convergence occurs when `norm(residual) <= max(tol * norm(b), atol)`. atol: Absolute residual stopping tolerance. Convergence occurs when `norm(residual) <= max(tol * norm(b), atol)`. maxiter: Maximum iterations. Default: 1000. info: If ``True`` return a tuple consting of the solution array and a dictionary containing diagnostic information, otherwise just return the solution. M: Preconditioner for `A`. The preconditioner should approximate the inverse of `A`. 
The default, ``None``, uses no preconditioner. Returns: tuple: A tuple (x, info) containing: - **x** : Solution array. - **info**: Dictionary containing diagnostic information. """ if isinstance(A, LinearOperator): Aop = A else: assert x0 is not None Aop = LinearOperator( input_shape=x0.shape, output_shape=b.shape, eval_fn=A, input_dtype=b.dtype, output_dtype=b.dtype, ) ATA = Aop.T @ Aop ATb = Aop.T @ b return cg(ATA, ATb, x0=x0, tol=tol, atol=atol, maxiter=maxiter, info=info, M=M) def bisect( f: Callable, a: Array, b: Array, args: Tuple = (), xtol: float = 1e-7, ftol: float = 1e-7, maxiter: int = 100, full_output: bool = False, range_check: bool = True, ) -> Union[Array, dict]: """Vectorised root finding via bisection method. Vectorised root finding via bisection method, supporting simultaneous finding of multiple roots on a function defined over a multi-dimensional array. When the function is array-valued, each of these values is treated as the independent application of a scalar function. The initial interval `[a, b]` must bracket the root for all scalar functions. The interface is similar to that of :func:`scipy.optimize.bisect`, which is much faster when `f` is a scalar function and `a` and `b` are scalars. Args: f: Function returning a float or an array of floats. a: Lower bound of interval on which to apply bisection. b: Upper bound of interval on which to apply bisection. args: Additional arguments for function `f`. xtol: Stopping tolerance based on maximum bisection interval length over array. ftol: Stopping tolerance based on maximum absolute function value over array. maxiter: Maximum number of algorithm iterations. full_output: If ``False``, return just the root, otherwise return a tuple `(x, info)` where `x` is the root and `info` is a dict containing algorithm status information. range_check: If ``True``, check to ensure that the initial `[a, b]` range brackets the root of `f`. Returns: tuple: A tuple `(x, info)` containing: - **x** : Root array. 
- **info**: Dictionary containing diagnostic information. """ fa = f(*((a,) + args)) fb = f(*((b,) + args)) if range_check and snp.any(snp.sign(fa) == snp.sign(fb)): raise ValueError("Initial bisection range does not bracket zero.") for numiter in range(maxiter): c = (a + b) / 2.0 fc = f(*((c,) + args)) fcs = snp.sign(fc) a = snp.where(snp.logical_or(snp.sign(fa) * fcs == 1, fc == 0.0), c, a) b = snp.where(snp.logical_or(fcs * snp.sign(fb) == 1, fc == 0.0), c, b) fa = f(*((a,) + args)) fb = f(*((b,) + args)) xerr = snp.max(snp.abs(b - a)) ferr = snp.max(snp.abs(fc)) if xerr <= xtol and ferr <= ftol: break idx = snp.argmin(snp.stack((snp.abs(fa), snp.abs(fb))), axis=0) x = snp.choose(idx, (a, b)) if full_output: r = x, {"iter": numiter, "xerr": xerr, "ferr": ferr, "a": a, "b": b} else: r = x return r def golden( f: Callable, a: Array, b: Array, c: Optional[Array] = None, args: Tuple = (), xtol: float = 1e-7, maxiter: int = 100, full_output: bool = False, ) -> Union[Array, dict]: """Vectorised scalar minimization via golden section method. Vectorised scalar minimization via golden section method, supporting simultaneous minimization of a function defined over a multi-dimensional array. When the function is array-valued, each of these values is treated as the independent application of a scalar function. The minimizer must lie within the interval `(a, b)` for all scalar functions, and, if specified `c` must be within that interval. The interface is more similar to that of :func:`.bisect` than that of :func:`scipy.optimize.golden` which is much faster when `f` is a scalar function and `a`, `b`, and `c` are scalars. Args: f: Function returning a float or an array of floats. a: Lower bound of interval on which to search. b: Upper bound of interval on which to search. c: Initial value for first search point interior to bounding interval `(a, b)` args: Additional arguments for function `f`. xtol: Stopping tolerance based on maximum search interval length over array. 
maxiter: Maximum number of algorithm iterations. full_output: If ``False``, return just the minizer, otherwise return a tuple `(x, info)` where `x` is the minimizer and `info` is a dict containing algorithm status information. Returns: tuple: A tuple `(x, info)` containing: - **x** : Minimizer array. - **info**: Dictionary containing diagnostic information. """ gr = 2 / (snp.sqrt(5) + 1) if c is None: c = b - gr * (b - a) d = a + gr * (b - a) for numiter in range(maxiter): fc = f(*((c,) + args)) fd = f(*((d,) + args)) b = snp.where(fc < fd, d, b) a = snp.where(fc >= fd, c, a) xerr = snp.amax(snp.abs(b - a)) if xerr <= xtol: break c = b - gr * (b - a) d = a + gr * (b - a) fa = f(*((a,) + args)) fb = f(*((b,) + args)) idx = snp.argmin(snp.stack((fa, fb)), axis=0) x = snp.choose(idx, (a, b)) if full_output: r = (x, {"iter": numiter, "xerr": xerr}) else: r = x return r class MatrixATADSolver: r"""Solver for linear system involving a symmetric product. Solve a linear system of the form .. math:: (A^T W A + D) \mb{x} = \mb{b} or .. math:: (A^T W A + D) X = B \;, where :math:`A \in \mbb{R}^{M \times N}`, :math:`W \in \mbb{R}^{M \times M}` and :math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of :class:`.MatrixOperator` or an array; :math:`D` must be an instance of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and :math:`W`, if specified, must be an instance of :class:`.Diagonal` or an array. The solution is computed by factorization of matrix :math:`A^T W A + D` and solution via Gaussian elimination. If :math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized and the original problem is solved via the Woodbury matrix identity .. math:: (E + U C V)^{-1} = E^{-1} - E^{-1} U (C^{-1} + V E^{-1} U)^{-1} V E^{-1} \;. Setting .. math:: E &= D \\ U &= A^T \\ C &= W \\ V &= A we have .. 
math:: (D + A^T W A)^{-1} = D^{-1} - D^{-1} A^T (W^{-1} + A D^{-1} A^T)^{-1} A D^{-1} which can be simplified to .. math:: (D + A^T W A)^{-1} = D^{-1} (I - A^T G^{-1} A D^{-1}) by defining :math:`G = W^{-1} + A D^{-1} A^T`. We therefore have that .. math:: \mb{x} = (D + A^T W A)^{-1} \mb{b} = D^{-1} (I - A^T G^{-1} A D^{-1}) \mb{b} \;. If we have a Cholesky factorization of :math:`G`, e.g. :math:`G = L L^T`, we can define .. math:: \mb{w} = G^{-1} A D^{-1} \mb{b} so that .. math:: G \mb{w} &= A D^{-1} \mb{b} \\ L L^T \mb{w} &= A D^{-1} \mb{b} \;. The Cholesky factorization can be exploited by solving for :math:`\mb{z}` in .. math:: L \mb{z} = A D^{-1} \mb{b} and then for :math:`\mb{w}` in .. math:: L^T \mb{w} = \mb{z} \;, so that .. math:: \mb{x} = D^{-1} \mb{b} - D^{-1} A^T \mb{w} \;. (Functions :func:`~jax.scipy.linalg.cho_solve` and :func:`~jax.scipy.linalg.lu_solve` allow direct solution for :math:`\mb{w}` without the two-step procedure described here.) A Cholesky factorization should only be used when :math:`G` is positive-definite (e.g. :math:`D` is diagonal and positive); if not, an LU factorization should be used. Complex-valued problems are also supported, in which case the transpose :math:`\cdot^T` in the equations above should be taken to represent the conjugate transpose. To solve problems directly involving a matrix of the form :math:`A W A^T + D`, initialize with :code:`A.T` (or :code:`A.T.conj()` for complex problems) instead of :code:`A`. """ def __init__( self, A: Union[MatrixOperator, Array], D: Union[MatrixOperator, Diagonal, Array], W: Optional[Union[Diagonal, Array]] = None, cho_factor: bool = False, lower: bool = False, check_finite: bool = True, ): r""" Args: A: Matrix :math:`A`. D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`, specifies the 2D matrix :math:`D`. If 1D array or :class:`Diagonal`, specifies the diagonal elements of :math:`D`. W: Matrix :math:`W`. Specifies the diagonal elements of :math:`W`. 
Defaults to an array with unit entries. cho_factor: Flag indicating whether to use Cholesky (``True``) or LU (``False``) factorization. lower: Flag indicating whether lower (``True``) or upper (``False``) triangular factorization should be computed. Only relevant to Cholesky factorization. check_finite: Flag indicating whether the input array should be checked for ``Inf`` and ``NaN`` values. """ A = jnp.array(A) if isinstance(D, Diagonal): D = D.diagonal if D.ndim > 1: # Identity operator has 0D diagonal raise ValueError("If Diagonal, 'D' should have a 0D or 1D diagonal.") else: D = jnp.array(D) if not D.ndim in [1, 2]: raise ValueError("If array or MatrixOperator, 'D' should be 1D or 2D.") if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal assert hasattr(W, "ndim") if W.ndim > 1: # Identity operator has 0D diagonal raise ValueError("If Diagonal, 'W' should have a 0D or 1D diagonal.") elif not isinstance(W, Array): raise TypeError( f"Operator 'W' is required to be None, a Diagonal, or an array; got a {type(W)}." ) self.A = A self.D = D self.W = W self.cho_factor = cho_factor self.lower = lower self.check_finite = check_finite assert isinstance(W, Array) N, M = A.shape if N < M and D.ndim <= 1: D2 = D if D.ndim == 0 else D[:, snp.newaxis] if W.ndim == 1: G = snp.diag(1.0 / W) + A @ (A.T.conj() / D2) else: # W is 0 dimensional (scalar equivalent) G = A @ (A.T.conj() / D2) G = jnp.fill_diagonal(G, G.diagonal() + (1.0 / W), inplace=False) else: W2 = W if W.ndim == 0 else W[:, snp.newaxis] if D.ndim == 1: G = A.T.conj() @ (W2 * A) + snp.diag(D) else: G = A.T.conj() @ (W2 * A) + D if cho_factor: c, lower = jsl.cho_factor(G, lower=lower, check_finite=check_finite) self.factor = (c, lower) else: lu, piv = jsl.lu_factor(G, check_finite=check_finite) self.factor = (lu, piv) def solve(self, b: Array, check_finite: Optional[bool] = None) -> Array: r"""Solve the linear system. 
Solve the linear system with right hand side :math:`\mb{b}` (`b` is a vector) or :math:`B` (`b` is a 2d array). Args: b: Vector :math:`\mathbf{b}` or matrix :math:`B`. check_finite: Flag indicating whether the input array should be checked for ``Inf`` and ``NaN`` values. If ``None``, use the value selected on initialization. Returns: Solution to the linear system. """ if check_finite is None: check_finite = self.check_finite if self.cho_factor: fact_solve = lambda x: jsl.cho_solve(self.factor, x, check_finite=check_finite) else: fact_solve = lambda x: jsl.lu_solve(self.factor, x, trans=0, check_finite=check_finite) if b.ndim <= 1: D = self.D else: D = self.D[:, snp.newaxis] N, M = self.A.shape if N < M and self.D.ndim <= 1: w = fact_solve(self.A @ (b / D)) x = (b - (self.A.T.conj() @ w)) / D else: x = fact_solve(b) return x def accuracy(self, x: Array, b: Array) -> float: r"""Compute solution relative residual. Args: x: Array :math:`\mathbf{x}` (solution). b: Array :math:`\mathbf{b}` (right hand side of linear system). Returns: Relative residual of solution. """ if b.ndim == 1: D = self.D else: D = self.D[:, snp.newaxis] assert isinstance(self.W, Array) return rel_res(self.A.T.conj() @ (self.W[:, snp.newaxis] * self.A) @ x + D * x, b) class ConvATADSolver: r"""Solver for a linear system involving a sum of convolutions. Solve a linear system of the form .. math:: (A^H A + D) \mb{x} = \mb{b} where :math:`A` is a block-row operator with circulant blocks, i.e. it can be written as .. math:: A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K} \end{array} \right) \;, where all of the :math:`A_k` are circular convolution operators, and :math:`D` is a circular convolution operator. This problem is most easily solved in the DFT transform domain, where the circular convolutions become diagonal operators. Denoting the frequency-domain versions of variables with a circumflex (e.g. 
:math:`\hat{\mb{x}}` is the frequency-domain version of :math:`\mb{x}`), the the problem can be written as .. math:: (\hat{A}^H \hat{A} + \hat{D}) \hat{\mb{x}} = \hat{\mb{b}} \;, where .. math:: \hat{A} = \left( \begin{array}{cccc} \hat{A}_1 & \hat{A}_2 & \ldots & \hat{A}_{K} \end{array} \right) \;, and :math:`\hat{D}` and all the :math:`\hat{A}_k` are diagonal operators. This linear equation is computational expensive to solve because the left hand side includes the term :math:`\hat{A}^H \hat{A}`, which corresponds to the outer product of :math:`\hat{A}^H` and :math:`\hat{A}`. A computationally efficient solution is possible, however, by exploiting the Woodbury matrix identity :cite:`wohlberg-2014-efficient` .. math:: (B + U C V)^{-1} = B^{-1} - B^{-1} U (C^{-1} + V B^{-1} U)^{-1} V B^{-1} \;. Setting .. math:: B &= \hat{D} \\ U &= \hat{A}^H \\ C &= I \\ V &= \hat{A} we have .. math:: (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} - \hat{D}^{-1} \hat{A}^H (I + \hat{A} \hat{D}^{-1} \hat{A}^H)^{-1} \hat{A} \hat{D}^{-1} which can be simplified to .. math:: (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} (I - \hat{A}^H \hat{E}^{-1} \hat{A} \hat{D}^{-1}) by defining :math:`\hat{E} = I + \hat{A} \hat{D}^{-1} \hat{A}^H`. The right hand side is much cheaper to compute because the only matrix inversions involve :math:`\hat{D}`, which is diagonal, and :math:`\hat{E}`, which is a weighted inner product of :math:`\hat{A}^H` and :math:`\hat{A}`. """ def __init__(self, A: ComposedLinearOperator, D: CircularConvolve): r""" Args: A: Operator :math:`A`. D: Operator :math:`D`. """ if not isinstance(A, ComposedLinearOperator): raise TypeError( f"Operator 'A' is required to be a ComposedLinearOperator; got a {type(A)}." ) if not isinstance(A.A, Sum) or not isinstance(A.B, CircularConvolve): raise TypeError( "Operator 'A' is required to be a composition of Sum and CircularConvolve" f"linear operators; got a composition of {type(A.A)} and {type(A.B)}." 
) self.A = A self.D = D self.sum_axis = A.A.kwargs["axis"] if not isinstance(self.sum_axis, int): raise ValueError( "Sum component of operator 'A' must sum over a single axis of its input." ) self.fft_axes = A.B.x_fft_axes self.real_result = is_real_dtype(D.input_dtype) Ahat = A.B.h_dft Dhat = D.h_dft self.AHEinv = Ahat.conj() / ( 1.0 + snp.sum(Ahat * (Ahat.conj() / Dhat), axis=self.sum_axis, keepdims=True) ) def solve(self, b: Array) -> Array: r"""Solve the linear system. Solve the linear system with right hand side :math:`\mb{b}`. Args: b: Array :math:`\mathbf{b}`. Returns: Solution to the linear system. """ assert isinstance(self.A.B, CircularConvolve) Ahat = self.A.B.h_dft Dhat = self.D.h_dft bhat = snp.fft.fftn(b, axes=self.fft_axes) xhat = ( bhat - (self.AHEinv * (snp.sum(Ahat * bhat / Dhat, axis=self.sum_axis, keepdims=True))) ) / Dhat x = snp.fft.ifftn(xhat, axes=self.fft_axes) if self.real_result: x = x.real return x def accuracy(self, x: Array, b: Array) -> float: r"""Compute solution relative residual. Args: x: Array :math:`\mathbf{x}` (solution). b: Array :math:`\mathbf{b}` (right hand side of linear system). Returns: Relative residual of solution. """ return rel_res(self.A.gram_op(x) + self.D(x), b) ================================================ FILE: scico/test/conftest.py ================================================ """ Configure the --level pytest option and its functionality. """ import pytest def pytest_addoption(parser, pluginmanager): """Add --level pytest option. 
Level definitions: 1 Critical tests only 2 Skip tests that have a significant impact on coverage 3 All standard tests 4 Run all tests, including those marked as slow to run """ parser.addoption( "--level", action="store", default=3, type=int, help="Set test level to be run" ) def pytest_configure(config): """Add marker description.""" config.addinivalue_line("markers", "slow: mark test as slow to run") def pytest_collection_modifyitems(config, items): """Skip slow tests depending on selected testing level.""" if config.getoption("--level") >= 4: # don't skip tests at level 4 or higher return level_skip = pytest.mark.skip(reason="test not appropriate for selected level") for item in items: if "slow" in item.keywords: item.add_marker(level_skip) ================================================ FILE: scico/test/flax/test_apply.py ================================================ import os import tempfile import numpy as np import jax import pytest from test_trainer import SetupTest from flax.traverse_util import flatten_dict from scico import flax as sflax from scico.flax.train.apply import apply_fn from scico.flax.train.checkpoints import checkpoint_save, have_orbax from scico.flax.train.input_pipeline import IterateData from scico.flax.train.learning_rate import create_cnst_lr_schedule from scico.flax.train.state import create_basic_train_state @pytest.fixture(scope="module") def testobj(): yield SetupTest() def test_apply_fn(testobj): key = jax.random.key(seed=531) key1, key2 = jax.random.split(key) model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) input_shape = (1, testobj.N, testobj.N, testobj.chn) variables = model.init({"params": key1}, np.ones(input_shape, model.dtype)) ds = IterateData(testobj.test_ds, testobj.bsize, train=False) try: batch = next(ds) output = apply_fn(model, variables, batch) except Exception as e: print(e) assert 0 else: assert output.shape[1:] == testobj.test_ds["label"].shape[1:] def 
test_except_only_apply(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) with pytest.raises(RuntimeError): out_ = sflax.only_apply( testobj.train_conf, model, testobj.test_ds, ) @pytest.mark.parametrize("model_cls", [sflax.DnCNNNet, sflax.ResNet, sflax.ConvBNNet, sflax.UNet]) def test_eval(testobj, model_cls): depth = testobj.model_conf["depth"] model = model_cls(depth, testobj.chn, testobj.model_conf["num_filters"]) if isinstance(model, sflax.DnCNNNet): depth = 3 model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf["num_filters"]) key = jax.random.key(123) variables = model.init(key, testobj.train_ds["image"]) # from train script out_, _ = sflax.only_apply( testobj.train_conf, model, testobj.test_ds, variables=variables, ) # from scico FlaxMap util fmap = sflax.FlaxMap(model, variables) out_fmap = fmap(testobj.test_ds["image"]) np.testing.assert_allclose(out_, out_fmap, atol=5e-6) @pytest.mark.skipif(not have_orbax, reason="orbax.checkpoint package not installed") def test_apply_from_checkpoint(testobj): depth = 3 model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf["num_filters"]) key = jax.random.key(123) variables = model.init(key, testobj.train_ds["image"]) temp_dir = tempfile.TemporaryDirectory() workdir = os.path.join(temp_dir.name, "temp_ckp") # State initialization learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state( key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate ) flat_params1 = flatten_dict(state.params) flat_bstats1 = flatten_dict(state.batch_stats) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] train_conf = dict(testobj.train_conf) train_conf["checkpointing"] = True train_conf["workdir"] = workdir checkpoint_save(state, train_conf, workdir) try: output, variables = sflax.only_apply( train_conf, model, testobj.test_ds, ) except Exception as e: print(e) 
assert 0 else: flat_params2 = flatten_dict(variables["params"]) flat_bstats2 = flatten_dict(variables["batch_stats"]) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) ================================================ FILE: scico/test/flax/test_checkpoints.py ================================================ import os import tempfile import numpy as np import jax import pytest from test_trainer import SetupTest from flax.traverse_util import flatten_dict from scico import flax as sflax from scico.flax.train.checkpoints import checkpoint_restore, checkpoint_save, have_orbax from scico.flax.train.learning_rate import create_cnst_lr_schedule from scico.flax.train.state import create_basic_train_state @pytest.fixture(scope="module") def testobj(): yield SetupTest() @pytest.mark.skipif(not have_orbax, reason="orbax.checkpoint package not installed") def test_checkpoint(testobj): depth = 3 model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf["num_filters"]) key = jax.random.key(123) variables = model.init(key, testobj.train_ds["image"]) temp_dir = tempfile.TemporaryDirectory() workdir = os.path.join(temp_dir.name, "temp_ckp") # State initialization learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state( key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate ) flat_params1 = flatten_dict(state.params) flat_bstats1 = flatten_dict(state.batch_stats) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] try: checkpoint_save(state, testobj.train_conf, workdir) state_in = checkpoint_restore(state, workdir) except Exception as e: print(e) assert 0 else: flat_params2 = flatten_dict(state_in.params) flat_bstats2 = 
flatten_dict(state_in.batch_stats) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) @pytest.mark.skipif(not have_orbax, reason="orbax.checkpoint package not installed") @pytest.mark.parametrize("model_cls", [sflax.DnCNNNet, sflax.ResNet]) def test_checkpointing_from_trainer(testobj, model_cls): depth = 3 model = model_cls(depth, testobj.chn, testobj.model_conf["num_filters"]) temp_dir = tempfile.TemporaryDirectory() workdir = os.path.join(temp_dir.name, "temp_ckp") train_conf = dict(testobj.train_conf) train_conf["checkpointing"] = True train_conf["workdir"] = workdir train_conf["return_state"] = True # Create training object trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) try: state_out, _ = trainer.train() except Exception as e: print(e) assert 0 else: # Model parameters from training flat_params1 = flatten_dict(state_out.params) params1 = [t[1] for t in sorted(flat_params1.items())] # Model parameteres from checkpoint state_in = checkpoint_restore(state_out, workdir) flat_params2 = flatten_dict(state_in.params) params2 = [t[1] for t in sorted(flat_params2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) if hasattr(state_out, "batch_stats"): # Batch stats from training flat_bstats1 = flatten_dict(state_out.batch_stats) bstats1 = [t[1] for t in sorted(flat_bstats1.items())] # Batch stats from checkpoint flat_bstats2 = flatten_dict(state_in.batch_stats) bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) @pytest.mark.skipif(not have_orbax, reason="orbax.checkpoint package not installed") def test_checkpoint_exception(testobj): depth = 3 
model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf["num_filters"]) key = jax.random.key(123) variables = model.init(key, testobj.train_ds["image"]) temp_dir = tempfile.TemporaryDirectory() workdir = os.path.join(temp_dir.name, "temp_ckp") # State initialization learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state( key, testobj.train_conf, model, (testobj.N, testobj.N), learning_rate ) with pytest.raises(FileNotFoundError): state_in = checkpoint_restore(state, workdir) ================================================ FILE: scico/test/flax/test_clu.py ================================================ import numpy as np import jax from flax.linen import Conv from flax.linen.module import Module, compact from scico import flax as sflax from scico.flax.train.clu_utils import ( _default_table_value_formatter, get_parameter_overview, ) def test_count_parameters(): N = 128 # signal size chn = 1 # number of channels # Model configuration mconf = { "depth": 2, "num_filters": 16, } model = sflax.ResNet(mconf["depth"], chn, mconf["num_filters"]) key = jax.random.key(seed=1234) input_shape = (1, N, N, chn) variables = model.init({"params": key}, np.ones(input_shape, model.dtype)) filter_sz = model.kernel_size[0] * model.kernel_size[1] # filter parameters output layer sum_manual_params = filter_sz * mconf["num_filters"] * chn # bias and scale of batch normalization output layer sum_manual_params += chn * 2 # mean and bar of batch normalization output layer sum_manual_bst = chn * 2 chn_prev = 1 for i in range(mconf["depth"] - 1): # filter parameters sum_manual_params += filter_sz * mconf["num_filters"] * chn_prev # bias and scale of batch normalization sum_manual_params += mconf["num_filters"] * 2 # mean and bar of batch normalization sum_manual_bst += mconf["num_filters"] * 2 chn_prev = mconf["num_filters"] total_nvar_params = sflax.count_parameters(variables["params"]) total_nvar_bst = 
sflax.count_parameters(variables["batch_stats"]) assert total_nvar_params == sum_manual_params assert total_nvar_bst == sum_manual_bst def test_count_parameters_empty(): assert sflax.count_parameters({}) == 0 # From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py EMPTY_PARAMETER_OVERVIEW = """+------+-------+------+------+-----+ | Name | Shape | Size | Mean | Std | +------+-------+------+------+-----+ +------+-------+------+------+-----+ Total weights: 0""" FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+------+ | Name | Shape | Size | +-------------+--------------+------+ | conv/bias | (2,) | 2 | | conv/kernel | (3, 3, 3, 2) | 54 | +-------------+--------------+------+ Total weights: 56""" FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+------+------+-----+ | Name | Shape | Size | Mean | Std | +-------------+--------------+------+------+-----+ | conv/bias | (2,) | 2 | 1.0 | 0.0 | | conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 | +-------------+--------------+------+------+-----+ Total weights: 56""" FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+------+------+-----+ | Name | Shape | Size | Mean | Std | +--------------------+--------------+------+------+-----+ | params/conv/bias | (2,) | 2 | 1.0 | 0.0 | | params/conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 | +--------------------+--------------+------+------+-----+ Total weights: 56""" # From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py def test_get_parameter_overview_empty(): assert get_parameter_overview({}) == EMPTY_PARAMETER_OVERVIEW class CNN(Module): @compact def __call__(self, x): return Conv(features=2, kernel_size=(3, 3), name="conv")(x) # From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py def test_get_parameter_overview(): rng = jax.random.key(42) # Weights of a 2D convolution with 2 filters.. 
variables = CNN().init(rng, np.zeros((2, 5, 5, 3))) variables = jax.tree_util.tree_map(jax.numpy.ones_like, variables) assert ( get_parameter_overview(variables["params"], include_stats=False) == FLAX_CONV2D_PARAMETER_OVERVIEW ) assert get_parameter_overview(variables["params"]) == FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS assert get_parameter_overview(variables) == FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS # From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py def test_printing_bool(): assert _default_table_value_formatter(True) == "True" assert _default_table_value_formatter(False) == "False" ================================================ FILE: scico/test/flax/test_examples_flax.py ================================================ import os import tempfile import numpy as np import pytest from scico import random from scico.flax.examples.data_generation import ( distributed_data_generation, generate_blur_data, generate_ct_data, generate_foam1_images, generate_foam2_images, have_ray, have_xdesign, ) from scico.flax.examples.data_preprocessing import ( CenterCrop, PaddedCircularConvolve, PositionalCrop, RandomNoise, build_image_dataset, flip, preprocess_images, rotation90, ) from scico.flax.examples.examples import ( get_cache_path, runtime_error_array, runtime_error_scalar, ) from scico.flax.examples.typed_dict import ConfigImageSetDict from scico.typing import Shape os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" # These tests are for the scico.flax.examples module, NOT the example scripts @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") def test_foam1_gen(): seed = 4444 N = 32 ndata = 2 dt = generate_foam1_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") def test_foam2_gen(): seed = 4321 N = 32 ndata = 2 dt = generate_foam2_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) 
@pytest.mark.skipif(not have_ray, reason="ray package not installed") def test_distdatagen(): N = 16 nimg = 8 def random_data_gen(seed, N, ndata): np.random.seed(seed) dt = np.random.randn(ndata, N, N, 1) return dt dt = distributed_data_generation(random_data_gen, N, nimg) assert dt.ndim == 4 assert dt.shape == (nimg, N, N, 1) @pytest.mark.skipif( not have_ray or not have_xdesign, reason="ray or xdesign package not installed", ) def test_ct_data_generation(): N = 32 nimg = 8 nproj = 45 def random_img_gen(seed, size, ndata): np.random.seed(seed) shape = (ndata, size, size, 1) return np.random.randn(*shape) img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) assert sino.shape == (nimg, nproj, sino.shape[2], 1) assert fbp.shape == (nimg, N, N, 1) @pytest.mark.skipif(not have_ray or not have_xdesign, reason="ray or xdesign package not installed") def test_blur_data_generation(): N = 32 nimg = 8 n = 3 # convolution kernel size blur_kernel = np.ones((n, n)) / (n * n) def random_img_gen(seed, size, ndata): np.random.seed(seed) shape = (ndata, size, size, 1) return np.random.randn(*shape) img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) assert blurn.shape == (nimg, N, N, 1) def test_rotation90(): N = 128 x, key = random.randn((N, N), seed=4321) x2, key = random.randn((10, N, N, 1), key=key) x_rot = rotation90(x) x2_rot = rotation90(x2) np.testing.assert_allclose(x_rot, np.swapaxes(x, 0, 1), rtol=1e-5) np.testing.assert_allclose(x2_rot, np.swapaxes(x2, 1, 2), rtol=1e-5) def test_flip(): N = 128 x, key = random.randn((N, N), seed=4321) x2, key = random.randn((10, N, N, 1), key=key) x_flip = flip(x) x2_flip = flip(x2) np.testing.assert_allclose(x_flip, x[:, ::-1, ...], rtol=1e-5) np.testing.assert_allclose(x2_flip, x2[..., ::-1, :], rtol=1e-5) @pytest.mark.parametrize("output_size", [128, (128, 128), (128, 64)]) def 
test_center_crop(output_size): N = 256 x, key = random.randn((N, N), seed=4321) if isinstance(output_size, int): ccrop = CenterCrop(output_size) else: shp: Shape = output_size ccrop = CenterCrop(shp) x_crop = ccrop(x) if isinstance(output_size, int): assert x_crop.shape[0] == output_size assert x_crop.shape[1] == output_size else: assert x_crop.shape == output_size @pytest.mark.parametrize("output_size", [128, (128, 128), (128, 64)]) def test_positional_crop(output_size): N = 256 x, key = random.randn((N, N), seed=4321) top, key = random.randint(shape=(1,), minval=0, maxval=N - 128, key=key) left, key = random.randint(shape=(1,), minval=0, maxval=N - 128, key=key) pcrop = PositionalCrop(output_size) x_crop = pcrop(x, top[0], left[0]) if isinstance(output_size, int): assert x_crop.shape[0] == output_size assert x_crop.shape[1] == output_size else: assert x_crop.shape == output_size @pytest.mark.parametrize("range_flag", [False, True]) def test_random_noise1(range_flag): N = 128 x, key = random.randn((N, N), seed=4321) noise = RandomNoise(0.1, range_flag) xn = noise(x) x2, key = random.randn((10, N, N, 1), key=key) xn2 = noise(x2) assert x.shape == xn.shape assert x2.shape == xn2.shape @pytest.mark.parametrize("shape", [(128, 128), (128, 128, 3), (5, 128, 128, 1)]) def test_random_noise2(shape): x, key = random.randn(shape, seed=4321) noise = RandomNoise(0.1, True) xn = noise(x) assert x.shape == xn.shape @pytest.mark.parametrize("output_size", [64, (64, 64)]) @pytest.mark.parametrize("gray_flag", [False, True]) @pytest.mark.parametrize("num_img_req", [None, 4]) def test_preprocess_images(output_size, gray_flag, num_img_req): num_img = 10 N = 128 C = 3 shape = (num_img, N, N, C) images, key = random.randn(shape, seed=4444) stride = 1 try: output = preprocess_images( images, output_size, gray_flag, num_img_req, multi_flag=False, stride=stride ) except Exception as e: print(e) assert 0 else: assert output.shape[1] == 64 assert output.shape[2] == 64 if gray_flag: assert 
output.shape[-1] == 1 else: assert output.shape[-1] == C if num_img_req is None: assert output.shape[0] == num_img else: assert output.shape[0] == num_img_req def test_preprocess_images_multi_flag(): num_img = 10 N = 128 C = 3 shape = (num_img, N, N, C) images, key = random.randn(shape, seed=4444) output_size = (64, 64) gray_flag = True num_img_req = 4 stride = 64 # 2 per side = 4 patches per image try: output = preprocess_images( images, output_size, gray_flag, num_img_req, multi_flag=True, stride=stride ) except Exception as e: print(e) assert 0 else: assert output.shape[0] == (4 * num_img_req) assert output.shape[1] == 64 assert output.shape[2] == 64 assert output.shape[-1] == 1 class SetupTest: def __init__(self): # Data configuration self.dtconf: ConfigImageSetDict = { "seed": 0, "output_size": 64, "stride": 1, "multi": False, "augment": False, "run_gray": True, "num_img": 10, "test_num_img": 4, "data_mode": "dn", "noise_level": 0.01, "noise_range": False, "test_split": 0.1, } @pytest.fixture(scope="module") def testobj(): yield SetupTest() @pytest.mark.parametrize("augment", [False, True]) def test_build_image_dataset(testobj, augment): num_train = testobj.dtconf["num_img"] num_test = testobj.dtconf["test_num_img"] N = 128 C = 3 shape = (num_train, N, N, C) img_train, key = random.randn(shape, seed=4444) img_test, key = random.randn((num_test, N, N, C), key=key) dtconf = dict(testobj.dtconf) dtconf["augment"] = augment train_ds, test_ds = build_image_dataset(img_train, img_test, dtconf) assert train_ds["image"].shape == train_ds["label"].shape assert test_ds["image"].shape == test_ds["label"].shape assert test_ds["label"].shape[0] == num_test if augment: assert train_ds["label"].shape[0] == num_train * 3 else: assert train_ds["label"].shape[0] == num_train def test_padded_circular_convolve(): N = 64 C = 3 kernel_size = 5 blur_sigma = 2.1 x, key = random.randn((N, N, C), seed=2468) pcc_op = PaddedCircularConvolve(N, C, kernel_size, blur_sigma) xblur = 
pcc_op(x) assert xblur.shape == x.shape def test_runtime_error_scalar(): with pytest.raises(RuntimeError): runtime_error_scalar("channels", "testing ", 3, 1) def test_runtime_error_array(): with pytest.raises(RuntimeError): runtime_error_array("channels", "testing ", 1e-2) def test_default_cache_path(): try: cache_path, cache_path_display = get_cache_path() except Exception as e: print(e) assert 0 else: cache_path_display == "~/.cache/scico/examples/data" def test_cache_path(): try: temp_dir = tempfile.TemporaryDirectory() cache_path = os.path.join(temp_dir.name, ".cache") cache_path_, cache_path_display = get_cache_path(cache_path) except Exception as e: print(e) assert 0 else: cache_path_ == cache_path cache_path_display == cache_path ================================================ FILE: scico/test/flax/test_flax.py ================================================ import os import tempfile from functools import partial import numpy as np import pytest from flax.core import unfreeze from flax.errors import ScopeParamShapeError from flax.linen import BatchNorm, Conv, elu, leaky_relu, max_pool, relu from scico import flax as sflax from scico.data import _flax_data_path from scico.random import randn class TestSet: def test_convnblock_default(self): nflt = 16 # number of filters conv = partial(Conv, dtype=np.float32) norm = partial(BatchNorm, dtype=np.float32) flxm = sflax.blocks.ConvBNBlock( num_filters=nflt, conv=conv, norm=norm, act=relu, ) assert flxm.kernel_size == (3, 3) # size of kernel assert flxm.strides == (1, 1) # stride of convolution def test_convnblock_args(self): nflt = 16 # number of filters ksz = (5, 5) # size of kernel strd = (2, 2) # stride of convolution conv = partial(Conv, dtype=np.float32) norm = partial(BatchNorm, dtype=np.float32) flxm = sflax.blocks.ConvBNBlock( num_filters=nflt, conv=conv, norm=norm, act=leaky_relu, kernel_size=ksz, strides=strd, ) assert flxm.act == leaky_relu assert flxm.kernel_size == ksz # size of kernel assert 
flxm.strides == strd # stride of convolution def test_convblock_default(self): nflt = 16 # number of filters conv = partial(Conv, dtype=np.float32) flxm = sflax.blocks.ConvBlock( num_filters=nflt, conv=conv, act=relu, ) assert flxm.kernel_size == (3, 3) # size of kernel assert flxm.strides == (1, 1) # stride of convolution def test_convblock_args(self): nflt = 16 # number of filters ksz = (5, 5) # size of kernel strd = (2, 2) # stride of convolution conv = partial(Conv, dtype=np.float32) flxm = sflax.blocks.ConvBlock( num_filters=nflt, conv=conv, act=elu, kernel_size=ksz, strides=strd, ) assert flxm.act == elu assert flxm.kernel_size == ksz # size of kernel assert flxm.strides == strd # stride of convolution def test_convblock_call(self): nflt = 16 # number of filters ksz = (5, 5) # size of kernel strd = (2, 2) # stride of convolution conv = partial(Conv, dtype=np.float32) flxb = sflax.blocks.ConvBlock( num_filters=nflt, conv=conv, act=elu, kernel_size=ksz, strides=strd, ) chn = 1 # number of channels N = 128 # image size x, key = randn((10, N, N, chn), seed=1234) variables = flxb.init(key, x) # Test for the construction / forward pass. 
cbx = flxb.apply(variables, x) assert x.dtype == cbx.dtype def test_convnpblock_args(self): nflt = 16 # number of filters ksz = (5, 5) # size of kernel strd = (2, 2) # stride of convolution wnd = (2, 2) # window for pooling conv = partial(Conv, dtype=np.float32) norm = partial(BatchNorm, dtype=np.float32) flxm = sflax.blocks.ConvBNPoolBlock( num_filters=nflt, conv=conv, norm=norm, act=relu, pool=max_pool, kernel_size=ksz, strides=strd, window_shape=wnd, ) assert flxm.act == relu assert flxm.kernel_size == ksz # size of kernel assert flxm.strides == strd # stride of convolution def test_convnublock_args(self): nflt = 16 # number of filters ksz = (5, 5) # size of kernel strd = (2, 2) # stride of convolution upsampling = 2 # upsampling factor conv = partial(Conv, dtype=np.float32) norm = partial(BatchNorm, dtype=np.float32) upfn = partial(sflax.blocks.upscale_nn, scale=upsampling) flxm = sflax.blocks.ConvBNUpsampleBlock( num_filters=nflt, conv=conv, norm=norm, act=relu, upfn=upfn, kernel_size=ksz, strides=strd, ) assert flxm.act == relu assert flxm.kernel_size == ksz # size of kernel assert flxm.strides == strd # stride of convolution def test_convmnblock_default(self): nblck = 2 # number of blocks nflt = 16 # number of filters conv = partial(Conv, dtype=np.float32) norm = partial(BatchNorm, dtype=np.float32) flxm = sflax.blocks.ConvBNMultiBlock( num_blocks=nblck, num_filters=nflt, conv=conv, norm=norm, act=relu, ) assert flxm.kernel_size == (3, 3) # size of kernel assert flxm.strides == (1, 1) # stride of convolution def test_upscale(self): N = 128 # image size chn = 3 # channels x, key = randn((10, N, N, chn), seed=1234) xups = sflax.blocks.upscale_nn(x) assert xups.shape == (10, 2 * N, 2 * N, chn) def test_resnet_default(self): depth = 3 # depth of model chn = 1 # number of channels num_filters = 16 # number of filters per layer N = 128 # image size x, key = randn((10, N, N, chn), seed=1234) resnet = sflax.ResNet( depth=depth, channels=chn, num_filters=num_filters, 
) variables = resnet.init(key, x) # Test for the construction / forward pass. rnx = resnet.apply(variables, x, train=False, mutable=False) assert x.dtype == rnx.dtype def test_unet_default(self): depth = 2 # depth of model chn = 1 # number of channels num_filters = 16 # number of filters per layer N = 128 # image size x, key = randn((10, N, N, chn), seed=1234) unet = sflax.UNet( depth=depth, channels=chn, num_filters=num_filters, ) variables = unet.init(key, x) # Test for the construction / forward pass. unx = unet.apply(variables, x, train=False, mutable=False) assert x.dtype == unx.dtype class DnCNNNetTest: def __init__(self): depth = 3 # depth of model chn = 1 # number of channels num_filters = 16 # number of filters per layer N = 128 # image size self.x, key = randn((10, N, N, chn), seed=1234) self.dncnn = sflax.DnCNNNet( depth=depth, channels=chn, num_filters=num_filters, ) self.variables = self.dncnn.init(key, self.x) @pytest.fixture(scope="module") def testobj(): yield DnCNNNetTest() def test_DnCNN_call(testobj): # Test for the construction / forward pass. dnx = testobj.dncnn.apply(testobj.variables, testobj.x, train=False, mutable=False) assert testobj.x.dtype == dnx.dtype def test_DnCNN_train(testobj): # Test effect of training flag. 
bn0bias_before = testobj.variables["params"]["ConvBNBlock_0"]["BatchNorm_0"]["bias"] bn0mean_before = testobj.variables["batch_stats"]["ConvBNBlock_0"]["BatchNorm_0"]["mean"] dnx, new_state = testobj.dncnn.apply( testobj.variables, testobj.x, train=True, mutable=["batch_stats"] ) bn0mean_new = new_state["batch_stats"]["ConvBNBlock_0"]["BatchNorm_0"]["mean"] bn0bias_after = testobj.variables["params"]["ConvBNBlock_0"]["BatchNorm_0"]["bias"] bn0mean_after = testobj.variables["batch_stats"]["ConvBNBlock_0"]["BatchNorm_0"]["mean"] try: np.testing.assert_allclose(bn0bias_before, bn0bias_after, rtol=1e-5) np.testing.assert_allclose( bn0mean_new - bn0mean_before, bn0mean_new + bn0mean_after, rtol=1e-5 ) except Exception as e: print(e) assert 0 def test_DnCNN_test(testobj): # Test effect of training flag. bn0var_before = testobj.variables["batch_stats"]["ConvBNBlock_0"]["BatchNorm_0"]["var"] dnx, new_state = testobj.dncnn.apply( testobj.variables, testobj.x, train=False, mutable=["batch_stats"] ) bn0var_after = new_state["batch_stats"]["ConvBNBlock_0"]["BatchNorm_0"]["var"] np.testing.assert_allclose(bn0var_before, bn0var_after, rtol=1e-5) def test_FlaxMap_call(testobj): # Test for the usage of flax model as a map. # 2D evaluation signal. fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables) N = 128 # image size x, key = randn((N, N)) out = fmap(x) assert x.dtype == out.dtype assert x.ndim == out.ndim def test_FlaxMap_3D_call(testobj): # Test for the usage of flax model as a map. # 3D evaluation signal. fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables) N = 128 # image size chn = 1 # channels x, key = randn((N, N, chn)) out = fmap(x) assert x.dtype == out.dtype assert x.ndim == out.ndim def test_FlaxMap_batch_call(testobj): # Test for the usage of flax model as a map. # 4D evaluation signal. 
fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables) N = 128 # image size chn = 1 # channels batch = 8 # batch size x, key = randn((batch, N, N, chn)) out = fmap(x) assert x.dtype == out.dtype assert x.ndim == out.ndim def test_FlaxMap_blockarray_exception(testobj): from scico.numpy import BlockArray fmap = sflax.FlaxMap(testobj.dncnn, testobj.variables) x0, key = randn(shape=(3, 4), seed=4321) x1, key = randn(shape=(4, 5, 6), key=key) x = BlockArray((x0, x1)) with pytest.raises(NotImplementedError): fmap(x) @pytest.mark.parametrize("variant", ["6L", "6M", "6H", "17L", "17M", "17H"]) def test_variable_load(variant): N = 128 # image size chn = 1 # channels x, key = randn((10, N, N, chn), seed=1234) if variant[0] == "6": nlayer = 6 else: nlayer = 17 model = sflax.DnCNNNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32) # Load weights for DnCNN. variables = sflax.load_variables(_flax_data_path("dncnn%s.mpk" % variant)) try: fmap = sflax.FlaxMap(model, variables) out = fmap(x) except Exception as e: print(e) assert 0 def test_variable_load_mismatch(): N = 128 # image size chn = 1 # channels x, key = randn((10, N, N, chn), seed=1234) nlayer = 6 model = sflax.ResNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32) # Load weights for DnCNN. 
variables = sflax.load_variables(_flax_data_path("dncnn6L.mpk")) # created with mismatched parameters fmap = sflax.FlaxMap(model, variables) with pytest.raises(ScopeParamShapeError): fmap(x) def test_variable_save(): N = 128 # image size chn = 1 # channels x, key = randn((10, N, N, chn), seed=1234) nlayer = 6 model = sflax.ResNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32) aux, key = randn((1,), seed=23432) input_shape = (1, N, N, chn) variables = model.init({"params": key}, np.ones(input_shape, model.dtype)) try: temp_dir = tempfile.TemporaryDirectory() sflax.save_variables(unfreeze(variables), os.path.join(temp_dir.name, "vres6.mpk")) except Exception as e: print(e) assert 0 ================================================ FILE: scico/test/flax/test_inv.py ================================================ import os from functools import partial import numpy as np import jax.numpy as jnp from jax import lax from scico import flax as sflax from scico import random from scico.flax.examples import PaddedCircularConvolve, build_blur_kernel from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal from scico.linop import CircularConvolve, Identity from scico.linop.xray import XRayTransform2D os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" class TestSet: def setup_method(self, method): self.depth = 1 # depth (equivalent to number of blocks) of model self.chn = 1 # number of channels self.num_filters = 16 # number of filters per layer self.block_depth = 2 # number of layers in block self.N = 128 # image size def test_odpdn_default(self): y, key = random.randn((10, self.N, self.N, self.chn), seed=1234) opI = Identity(y.shape) odpdn = sflax.ODPNet( operator=opI, depth=self.depth, channels=self.chn, num_filters=self.num_filters, block_depth=self.block_depth, ) variables = odpdn.init(key, y) # Test for the construction / forward pass. 
mny = odpdn.apply(variables, y, train=False, mutable=False) assert y.dtype == mny.dtype assert y.shape == mny.shape def test_odpdcnv_default(self): y, key = random.randn((10, self.N, self.N, self.chn), seed=1234) blur_shape = (9, 9) blur_sigma = 2.24 kernel = build_blur_kernel(blur_shape, blur_sigma) ishape = (self.N, self.N) opBlur = CircularConvolve(h=kernel, input_shape=ishape) odpdb = sflax.ODPNet( operator=opBlur, depth=self.depth, channels=self.chn, num_filters=self.num_filters, block_depth=self.block_depth, odp_block=sflax.inverse.ODPProxDcnvBlock, ) variables = odpdb.init(key, y) # Test for the construction / forward pass. mny = odpdb.apply(variables, y, train=False, mutable=False) assert y.dtype == mny.dtype assert y.shape == mny.shape def test_odpdcnv_padded(self): y, key = random.randn((10, self.N, self.N, self.chn), seed=1234) blur_shape = (9, 9) blur_sigma = 2.24 opBlur = PaddedCircularConvolve(self.N, self.chn, blur_shape, blur_sigma) odpdb = sflax.ODPNet( operator=opBlur, depth=self.depth, channels=self.chn, num_filters=self.num_filters, block_depth=self.block_depth, odp_block=sflax.inverse.ODPProxDcnvBlock, ) variables = odpdb.init(key, y) # Test for the construction / forward pass. 
mny = odpdb.apply(variables, y, train=False, mutable=False) assert y.dtype == mny.dtype assert y.shape == mny.shape def test_train_odpdcnv_default(self): xt, key = random.randn((10, self.N, self.N, self.chn), seed=4444) blur_shape = (7, 7) blur_sigma = 3.3 kernel = build_blur_kernel(blur_shape, blur_sigma) ishape = (self.N, self.N) opBlur = CircularConvolve(h=kernel, input_shape=ishape) model = sflax.ODPNet( operator=opBlur, depth=self.depth, channels=self.chn, num_filters=self.num_filters, block_depth=self.block_depth, odp_block=sflax.inverse.ODPProxDcnvBlock, ) train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "ADAM", "batch_size": 8, "num_epochs": 2, "base_learning_rate": 1e-3, "warmup_epochs": 0, "num_train_steps": -1, "steps_per_eval": -1, "steps_per_epoch": 1, "log_every_steps": 1000, } a_f = lambda v: jnp.atleast_3d(opBlur(v.reshape(opBlur.input_shape))) y = lax.map(a_f, xt) train_ds = {"image": y, "label": xt} test_ds = {"image": y, "label": xt} try: alphatrav = construct_traversal("alpha") alphapos = partial(clip_positive, traversal=alphatrav, minval=1e-3) train_conf["post_lst"] = [alphapos] trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: alphaval = np.array([alpha for alpha in alphatrav.iterate(modvar["params"])]) np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval) class TestCT: def setup_method(self, method): self.N = 32 # signal size self.chn = 1 # number of channels self.bsize = 16 # batch size xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321) self.nproj = 60 # number of projections angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32) self.opCT = XRayTransform2D( input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0) ) # Radon transform operator a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze())) y = lax.map(a_f, xt) self.train_ds = {"image": 
y, "label": xt} self.test_ds = {"image": y, "label": xt} # Model configuration self.model_conf = { "depth": 1, "num_filters": 16, "block_depth": 2, } # Training configuration self.train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "ADAM", "batch_size": self.bsize, "num_epochs": 2, "base_learning_rate": 1e-3, "warmup_epochs": 0, "num_train_steps": -1, "steps_per_eval": -1, "steps_per_epoch": 1, "log_every_steps": 1000, } def test_odpct_default(self): y, key = random.randn((10, self.nproj, self.N, self.chn), seed=1234) model = sflax.ODPNet( operator=self.opCT, depth=self.model_conf["depth"], channels=self.chn, num_filters=self.model_conf["num_filters"], block_depth=self.model_conf["block_depth"], odp_block=sflax.inverse.ODPGrDescBlock, ) variables = model.init(key, y) # Test for the construction / forward pass. oy = model.apply(variables, y, train=False, mutable=False) assert y.dtype == oy.dtype def test_modlct_default(self): y, key = random.randn((10, self.nproj, self.N, self.chn), seed=1234) model = sflax.MoDLNet( operator=self.opCT, depth=self.model_conf["depth"], channels=self.chn, num_filters=self.model_conf["num_filters"], block_depth=self.model_conf["block_depth"], ) variables = model.init(key, y) # Test for the construction / forward pass. 
mny = model.apply(variables, y, train=False, mutable=False) assert y.dtype == mny.dtype def test_train_modl(self): model = sflax.MoDLNet( operator=self.opCT, depth=self.model_conf["depth"], channels=self.chn, num_filters=self.model_conf["num_filters"], block_depth=self.model_conf["block_depth"], ) try: minval = 1.1e-2 lmbdatrav = construct_traversal("lmbda") lmbdapos = partial( clip_positive, traversal=lmbdatrav, minval=minval, ) train_conf = dict(self.train_conf) train_conf["post_lst"] = [lmbdapos] trainer = sflax.BasicFlaxTrainer( train_conf, model, self.train_ds, self.test_ds, ) modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: lmbdaval = np.array([lmb for lmb in lmbdatrav.iterate(modvar["params"])]) np.testing.assert_array_less(1e-2 * np.ones(lmbdaval.shape), lmbdaval) def test_train_odpct(self): model = sflax.ODPNet( operator=self.opCT, depth=self.model_conf["depth"], channels=self.chn, num_filters=self.model_conf["num_filters"], block_depth=self.model_conf["block_depth"], odp_block=sflax.inverse.ODPGrDescBlock, ) try: minval = 1.1e-2 maxval = 1e2 alphatrav = construct_traversal("alpha") alpharange = partial(clip_range, traversal=alphatrav, minval=minval, maxval=maxval) train_conf = dict(self.train_conf) train_conf["post_lst"] = [alpharange] trainer = sflax.BasicFlaxTrainer( train_conf, model, self.train_ds, self.test_ds, ) modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: alphaval = np.array([alpha for alpha in alphatrav.iterate(modvar["params"])]) np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval) ================================================ FILE: scico/test/flax/test_spectral.py ================================================ from functools import partial from typing import Any, Tuple import numpy as np import jax import pytest from flax.linen import Conv from flax.linen.module import Module, compact from scico import flax as sflax from scico import linop from 
scico.flax.train.spectral import ( _l2_normalize, conv, estimate_spectral_norm, exact_spectral_norm, spectral_normalization_conv, ) from scico.flax.train.traversals import construct_traversal from scico.random import randn def test_l2_normalize(): N = 256 x, key = randn((N, N), seed=135) eps = 1e-6 l2_jnp = jax.numpy.sqrt((x**2).sum()) l2n_jnp = x / (l2_jnp + eps) l2n_util = _l2_normalize(x, eps) np.testing.assert_allclose(l2n_jnp, l2n_util, rtol=eps) @pytest.mark.parametrize("kernel_size", [(3, 3, 1, 1), (11, 11, 1, 1)]) def test_conv(kernel_size): key = jax.random.key(97531) kernel, key = randn(kernel_size, dtype=np.float32, key=key) input_size = (1, 128, 128, 1) x, key = randn(input_size, dtype=np.float32, key=key) pads = ( [(0, 0)] + [(kernel_size[0] // 2, kernel_size[0] // 2)] + [(kernel_size[1] // 2, kernel_size[1] // 2)] + [(0, 0)] ) xext = np.pad(x, pads, mode="wrap") y = jax.scipy.signal.convolve(xext.squeeze(), jax.numpy.flip(kernel).squeeze(), mode="valid") y_util = conv(x, kernel).squeeze() np.testing.assert_allclose(y, y_util) class CNN(Module): kernel_size: Tuple[int, int] kernel0: Any @compact def __call__(self, x): def kinit_wrap(rng, shape, dtype=np.float32): return np.array(self.kernel0, dtype) return Conv( features=1, kernel_size=self.kernel_size, use_bias=False, padding="CIRCULAR", kernel_init=kinit_wrap, )(x) @pytest.mark.parametrize("kernel_size", [(3, 3, 1, 1), (11, 11, 1, 1)]) def test_conv_layer(kernel_size): key = jax.random.key(12345) kernel, key = randn(kernel_size, dtype=np.float32, key=key) input_size = (1, 128, 128, 1) x, key = randn(input_size, dtype=np.float32, key=key) rng = jax.random.key(42) model = CNN(kernel_size=kernel_size[:2], kernel0=kernel) variables = model.init(rng, np.zeros(x.shape)) prms = variables["params"] np.testing.assert_allclose(prms["Conv_0"]["kernel"], kernel) y_layer = model.apply(variables, x) y_util = conv(x, kernel) np.testing.assert_allclose(y_layer, y_util) @pytest.mark.parametrize("input_shape", [(8,), 
(128,)]) def test_spectral_norm(input_shape): key = jax.random.key(1357) diagonal, key = randn(input_shape, dtype=np.float32, key=key) mu = np.linalg.norm(np.diag(diagonal), 2) D = linop.Diagonal(diagonal=diagonal) x, key = randn(input_shape, dtype=np.float32, key=key) mu_util = estimate_spectral_norm(lambda x: D @ x, x.shape, n_steps=200) np.testing.assert_allclose(mu, mu_util, rtol=1e-6) @pytest.mark.parametrize("kernel_shape", [(3, 3, 1, 1), (7, 7, 1, 1)]) def test_spectral_norm_conv(kernel_shape): key = jax.random.key(2468) kernel, key = randn(kernel_shape, dtype=np.float32, key=key) input_shape = (1, 32, 32, 1) x, key = randn(input_shape, dtype=np.float32, key=key) sn = exact_spectral_norm(lambda x: conv(x, kernel), x.shape) sn_util = estimate_spectral_norm(lambda x: conv(x, kernel), x.shape, n_steps=100) np.testing.assert_allclose(sn, sn_util, rtol=1e-3, atol=1e-2) def test_train_spectral_norm(): depth = 3 channels = 1 num_filters = 16 model = sflax.DnCNNNet(depth, channels, num_filters) train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "ADAM", "batch_size": 16, "num_epochs": 1, "base_learning_rate": 1e-3, "lr_decay_rate": 0.95, "warmup_epochs": 0, "num_train_steps": -1, "steps_per_eval": -1, "steps_per_epoch": 1, "log_every_steps": 1000, } N = 64 xtr, key = randn((train_conf["batch_size"], N, N, channels), seed=4321) xtt, key = randn((train_conf["batch_size"], N, N, channels), key=key) train_ds = {"image": xtr, "label": xtr} test_ds = {"image": xtt, "label": xtt} try: xshape = (1,) + train_ds["label"][0].shape convtrav = construct_traversal("kernel") kernelnorm = partial( spectral_normalization_conv, traversal=convtrav, xshape=xshape, ) train_conf["post_lst"] = [kernelnorm] trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: knlsn = np.array( [ estimate_spectral_norm( lambda x: conv(x, kernel), (1, xshape[1], xshape[2], kernel.shape[2]) ) for kernel in 
convtrav.iterate(modvar["params"]) ] ) np.testing.assert_array_less(knlsn, np.ones(knlsn.shape)) ================================================ FILE: scico/test/flax/test_steps.py ================================================ import functools import jax import pytest from test_trainer import SetupTest from flax import jax_utils from scico import flax as sflax from scico.flax.train.diagnostics import compute_metrics from scico.flax.train.learning_rate import create_cnst_lr_schedule from scico.flax.train.losses import mse_loss from scico.flax.train.state import create_basic_train_state from scico.flax.train.steps import eval_step, train_step, train_step_post from scico.flax.train.traversals import clip_range, construct_traversal @pytest.fixture(scope="module") def testobj(): yield SetupTest() def test_basic_train_step(testobj): key = jax.random.key(seed=531) key1, key2 = jax.random.split(key) model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) input_shape = (1, testobj.N, testobj.N, testobj.chn) learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate) criterion = mse_loss local_batch_size = testobj.train_conf["batch_size"] // jax.process_count() size_device_prefetch = 2 train_dt_iter = sflax.create_input_iter( key2, testobj.train_ds, local_batch_size, size_device_prefetch, model.dtype, train=True, ) # Training is configured as parallel operation state = jax_utils.replicate(state) p_train_step = jax.pmap( functools.partial( train_step, learning_rate_fn=learning_rate, criterion=criterion, metrics_fn=compute_metrics, ), axis_name="batch", ) try: batch = next(train_dt_iter) p_train_step(state, batch) except Exception as e: print(e) assert 0 def test_post_train_step(testobj): key = jax.random.key(seed=531) key1, key2 = jax.random.split(key) model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, 
testobj.model_conf["num_filters"] ) input_shape = (1, testobj.N, testobj.N, testobj.chn) learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate) criterion = mse_loss local_batch_size = testobj.train_conf["batch_size"] // jax.process_count() size_device_prefetch = 2 train_dt_iter = sflax.create_input_iter( key2, testobj.train_ds, local_batch_size, size_device_prefetch, model.dtype, train=True, ) # Dum range requirement over kernel parameters ktrav = construct_traversal("kernel") krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1) # Training is configured as parallel operation state = jax_utils.replicate(state) p_train_step = jax.pmap( functools.partial( train_step_post, learning_rate_fn=learning_rate, criterion=criterion, train_step_fn=train_step, metrics_fn=compute_metrics, post_lst=[krange], ), axis_name="batch", ) try: batch = next(train_dt_iter) p_train_step(state, batch) except Exception as e: print(e) assert 0 def test_basic_eval_step(testobj): key = jax.random.key(seed=531) key1, key2 = jax.random.split(key) model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) input_shape = (1, testobj.N, testobj.N, testobj.chn) learning_rate = create_cnst_lr_schedule(testobj.train_conf) state = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate) criterion = mse_loss local_batch_size = testobj.train_conf["batch_size"] // jax.process_count() size_device_prefetch = 2 eval_dt_iter = sflax.create_input_iter( key2, testobj.test_ds, local_batch_size, size_device_prefetch, model.dtype, train=False, ) # Evaluation is configured as parallel operation state = jax_utils.replicate(state) p_eval_step = jax.pmap( functools.partial(eval_step, criterion=criterion, metrics_fn=compute_metrics), axis_name="batch", ) try: batch = next(eval_dt_iter) p_eval_step(state, batch) except 
Exception as e: print(e) assert 0 ================================================ FILE: scico/test/flax/test_train_aux.py ================================================ import numpy as np import jax import pytest from test_trainer import SetupTest from scico import flax as sflax from scico import random from scico.flax.train.clu_utils import flatten_dict from scico.flax.train.diagnostics import ArgumentStruct, compute_metrics, stats_obj from scico.flax.train.input_pipeline import IterateData, prepare_data from scico.flax.train.learning_rate import ( create_cnst_lr_schedule, create_cosine_lr_schedule, create_exp_lr_schedule, ) from scico.flax.train.losses import mse_loss from scico.flax.train.state import create_basic_train_state, initialize @pytest.fixture(scope="module") def testobj(): yield SetupTest() def test_mse_loss(): N = 256 x, key = random.randn((N, N), seed=4321) y, key = random.randn((N, N), key=key) # Optax uses a 0.5 factor. mse_jnp = 0.5 * jax.numpy.mean((x - y) ** 2) mse_optx = mse_loss(y, x) np.testing.assert_allclose(mse_jnp, mse_optx) @pytest.mark.parametrize("batch_size", [4, 8, 16]) def test_data_iterator(testobj, batch_size): ds = IterateData(testobj.test_ds_simple, batch_size, train=False) N = testobj.test_ds_simple["image"].shape[0] assert ds.steps_per_epoch == N // batch_size assert ds.key is not None @pytest.mark.parametrize("local_batch", [8, 16, 24]) def test_dstrain(testobj, local_batch): key = jax.random.key(seed=1234) train_iter = sflax.create_input_iter( key, testobj.train_ds_simple, local_batch, ) nproc = jax.device_count() ll = [] num_steps = 40 for step, batch in zip(range(num_steps), train_iter): for j in range(nproc): ll.append(batch["image"][j]) ll_ = np.array(jax.device_get(ll)).flatten() ll_ar = np.array(list(set(np.sort(ll_)))) np.testing.assert_allclose(ll_ar, np.arange(80)) @pytest.mark.parametrize("local_batch", [8, 16, 32]) def test_dstest(testobj, local_batch): key = jax.random.key(seed=1234) train_iter = 
sflax.create_input_iter(key, testobj.test_ds_simple, local_batch, train=False) nproc = jax.device_count() ll = [] num_steps = 20 for step, batch in zip(range(num_steps), train_iter): for j in range(nproc): ll.append(batch["image"][j]) ll_ = np.array(jax.device_get(ll)).flatten() ll_ar = np.array(list(set(np.sort(ll_)))) np.testing.assert_allclose(ll_ar, np.arange(80, 112)) def test_prepare_data(testobj): xbtch = prepare_data(testobj.x) local_device_count = jax.local_device_count() shrdsz = testobj.x.shape[0] // local_device_count assert xbtch.shape == (local_device_count, shrdsz, testobj.N, testobj.N, testobj.chn) def test_compute_metrics(testobj): xbtch = prepare_data(testobj.x) xbtch = xbtch / jax.numpy.sqrt(jax.numpy.var(xbtch, axis=(1, 2, 3, 4))) ybtch = xbtch + 1 p_eval = jax.pmap(compute_metrics, axis_name="batch") eval_metrics = p_eval(ybtch, xbtch) mtrcs = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics) assert np.abs(mtrcs["loss"]) < 0.51 assert mtrcs["snr"] < 5e-4 def test_cnst_learning_rate(testobj): step = 1 cnst_sch = create_cnst_lr_schedule(testobj.train_conf) lr = cnst_sch(step) assert lr == testobj.train_conf["base_learning_rate"] def test_cos_learning_rate(testobj): step = 1 len_train = testobj.train_ds["label"].shape[0] train_conf = dict(testobj.train_conf) train_conf["steps_per_epoch"] = len_train // testobj.train_conf["batch_size"] sch = create_cosine_lr_schedule(train_conf) lr = sch(step) decay_steps = (train_conf["num_epochs"] - train_conf["warmup_epochs"]) * train_conf[ "steps_per_epoch" ] cosine_decay = 0.5 * (1 + np.cos(np.pi * step / decay_steps)) np.testing.assert_allclose(lr, train_conf["base_learning_rate"] * cosine_decay, rtol=1e-06) def test_exp_learning_rate(testobj): step = 1 len_train = testobj.train_ds["label"].shape[0] train_conf = dict(testobj.train_conf) train_conf["steps_per_epoch"] = len_train // testobj.train_conf["batch_size"] steps = train_conf["steps_per_epoch"] * train_conf["num_epochs"] sch = 
create_exp_lr_schedule(train_conf) lr = sch(step) exp_decay = train_conf["lr_decay_rate"] ** float(step / steps) np.testing.assert_allclose(lr, train_conf["base_learning_rate"] * exp_decay, rtol=1e-06) def test_train_initialize_function(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=4444) input_shape = (1, testobj.N, testobj.N, testobj.chn) # Via initialize function dparams1, dbstats1 = initialize(key, model, input_shape[1:3]) flat_params1 = flatten_dict(dparams1) flat_bstats1 = flatten_dict(dbstats1) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] # Via model initialization variables2 = model.init({"params": key}, np.ones(input_shape, model.dtype)) flat_params2 = flatten_dict(variables2["params"]) flat_bstats2 = flatten_dict(variables2["batch_stats"]) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) def test_create_basic_train_state_default(testobj): model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=432) input_shape = (1, testobj.N, testobj.N, testobj.chn) # Model initialization variables1 = model.init({"params": key}, np.ones(input_shape, model.dtype)) flat_params1 = flatten_dict(variables1["params"]) flat_bstats1 = flatten_dict(variables1["batch_stats"]) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] learning_rate = create_cnst_lr_schedule(testobj.train_conf) try: # State initialization state = create_basic_train_state( key, testobj.train_conf, model, input_shape[1:3], learning_rate ) except Exception as e: print(e) assert 0 else: 
flat_params2 = flatten_dict(state.params) flat_bstats2 = flatten_dict(state.batch_stats) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) def test_create_basic_train_state(testobj): model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=432) input_shape = (1, testobj.N, testobj.N, testobj.chn) # Model initialization variables1 = model.init({"params": key}, np.ones(input_shape, model.dtype)) flat_params1 = flatten_dict(variables1["params"]) flat_bstats1 = flatten_dict(variables1["batch_stats"]) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] learning_rate = create_cnst_lr_schedule(testobj.train_conf) try: # State initialization state = create_basic_train_state( key, testobj.train_conf, model, input_shape[1:3], learning_rate, variables1 ) except Exception as e: print(e) assert 0 else: flat_params2 = flatten_dict(state.params) flat_bstats2 = flatten_dict(state.batch_stats) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) def test_sgd_train_state(testobj): model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=432) input_shape = (1, testobj.N, testobj.N, testobj.chn) # Model initialization variables = model.init({"params": key}, np.ones(input_shape, model.dtype)) learning_rate = create_cnst_lr_schedule(testobj.train_conf) train_conf = dict(testobj.train_conf) train_conf["opt_type"] 
= "SGD" try: # State initialization state = create_basic_train_state( key, train_conf, model, input_shape[1:3], learning_rate, variables ) except Exception as e: print(e) assert 0 def test_sgd_no_momentum_train_state(testobj): model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=432) input_shape = (1, testobj.N, testobj.N, testobj.chn) # Model initialization variables = model.init({"params": key}, np.ones(input_shape, model.dtype)) learning_rate = create_cnst_lr_schedule(testobj.train_conf) train_conf = dict(testobj.train_conf) train_conf["opt_type"] = "SGD" train_conf.pop("momentum") try: # State initialization state = create_basic_train_state( key, train_conf, model, input_shape[1:3], learning_rate, variables ) except Exception as e: print(e) assert 0 def test_argument_struct(): dictaux = {"epochs": 5, "num_steps": 10, "seed": 0} try: dictstruct = ArgumentStruct(**dictaux) except Exception as e: print(e) assert 0 else: assert hasattr(dictstruct, "epochs") assert hasattr(dictstruct, "num_steps") assert hasattr(dictstruct, "seed") def test_complete_stats_obj(): try: itstat_object, itstat_insert_func = stats_obj() except Exception as e: print(e) assert 0 else: summary = { "epoch": 3, "time": 231.0, "train_learning_rate": 1e-2, "train_loss": 1.4e-2, "train_snr": 3, "loss": 1.6e-2, "snr": 2.4, } try: itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary))) except Exception as e: print(e) assert 0 def test_except_incomplete_stats_obj(): itstat_object, itstat_insert_func = stats_obj() summary = { "epoch": 3, "time": 231.0, "train_learning_rate": 1e-2, "train_loss": 1.4e-2, "train_snr": 3, "loss": 1.6e-2, "snr": 2.4, } itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary))) summary2 = { "epoch": 3, "time": 231.0, "train_learning_rate": 1e-2, "train_loss": 1.4e-2, "train_snr": 3, } with pytest.raises(AttributeError): 
itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary2))) def test_patch_incomplete_stats_obj(): itstat_object, itstat_insert_func = stats_obj() summary = { "epoch": 3, "time": 231.0, "train_learning_rate": 1e-2, "train_loss": 1.4e-2, "train_snr": 3, "loss": 1.6e-2, "snr": 2.4, } itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary))) summary2 = { "epoch": 3, "time": 231.0, "train_learning_rate": 1e-2, "train_loss": 1.4e-2, "train_snr": 3, } try: summary2["loss"] = -1 summary2["snr"] = -1 itstat_object.insert(itstat_insert_func(ArgumentStruct(**summary2))) except Exception as e: print(e) assert 0 ================================================ FILE: scico/test/flax/test_trainer.py ================================================ import functools import numpy as np import jax import optax import pytest from flax import jax_utils from scico import flax as sflax from scico import random from scico.flax.train.clu_utils import flatten_dict from scico.flax.train.learning_rate import create_cnst_lr_schedule from scico.flax.train.state import create_basic_train_state from scico.flax.train.steps import eval_step, train_step from scico.flax.train.trainer import sync_batch_stats from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal class SetupTest: def __init__(self): datain = np.arange(80) datain_test = np.arange(80, 112) dataout = np.zeros(80) dataout[:40] = 1 dataout_test = np.zeros(40) dataout_test[:20] = 1 self.train_ds_simple = {"image": datain, "label": dataout} self.test_ds_simple = {"image": datain_test, "label": dataout_test} # More complex data structure self.N = 128 # signal size self.chn = 1 # number of channels self.bsize = 16 # batch size self.x, key = random.randn((4 * self.bsize, self.N, self.N, self.chn), seed=4321) xt, key = random.randn((32, self.N, self.N, self.chn), key=key) self.train_ds = {"image": self.x, "label": self.x} self.test_ds = {"image": xt, "label": xt} # Model configuration 
self.model_conf = { "depth": 2, "num_filters": 16, "block_depth": 2, } # Training configuration self.train_conf: sflax.ConfigDict = { "seed": 0, "opt_type": "ADAM", "momentum": 0.9, "batch_size": self.bsize, "num_epochs": 1, "base_learning_rate": 1e-3, "lr_decay_rate": 0.95, "warmup_epochs": 0, "log_every_steps": 1000, } @pytest.fixture(scope="module") def testobj(): yield SetupTest() @pytest.mark.parametrize("opt_type", ["SGD", "ADAM", "ADAMW"]) def test_optimizers(testobj, opt_type): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["opt_type"] = opt_type try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) modvar, _ = trainer.train() except Exception as e: print(e) assert 0 def test_optimizers_exception(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["opt_type"] = "" with pytest.raises(NotImplementedError): sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) def test_sync_batch_stats(testobj): key = jax.random.key(seed=12345) key1, key2 = jax.random.split(key) model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) input_shape = (1, testobj.N, testobj.N, testobj.chn) learning_rate = create_cnst_lr_schedule(testobj.train_conf) state0 = create_basic_train_state(key1, testobj.train_conf, model, input_shape, learning_rate) # For parallel training state = jax_utils.replicate(state0) state = sync_batch_stats(state) state1 = jax_utils.unreplicate(state) flat_bstats0 = flatten_dict(state0.batch_stats) bstats0 = [t[1] for t in sorted(flat_bstats0.items())] flat_bstats1 = flatten_dict(state1.batch_stats) bstats1 = [t[1] for t in sorted(flat_bstats1.items())] for i in range(len(bstats0)): np.testing.assert_allclose(bstats0[i], bstats1[i], rtol=1e-5) 
def test_class_train_default_init(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) try: trainer = sflax.BasicFlaxTrainer( testobj.train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.itstat_object is None def test_class_train_default_noseed(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf.pop("seed", None) try: trainer = sflax.BasicFlaxTrainer( testobj.train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 def test_class_train_nolog(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["log"] = False try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.itstat_object is None def test_class_train_required_steps(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf.pop("batch_size", None) train_conf.pop("num_epochs", None) try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: batch_size = 2 * jax.device_count() local_batch_size = batch_size // jax.process_count() num_epochs = 10 num_steps = int(trainer.steps_per_epoch * num_epochs) assert trainer.local_batch_size == local_batch_size assert trainer.num_steps == num_steps @pytest.mark.skipif(jax.device_count() == 1, reason="single device present") def test_except_class_train_batch_size(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) 
train_conf["batch_size"] = jax.device_count() + 1 with pytest.raises(ValueError): trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) def test_class_train_set_steps(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["steps_per_eval"] = 1 train_conf["steps_per_checkpoint"] = 1 train_conf["log_every_steps"] = 3 try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.steps_per_eval == train_conf["steps_per_eval"] assert trainer.steps_per_checkpoint == train_conf["steps_per_checkpoint"] assert trainer.log_every_steps == train_conf["log_every_steps"] def test_class_train_set_reporting(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["log"] = True train_conf["workdir"] = "./out/" train_conf["checkpointing"] = False train_conf["return_state"] = True try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.logflag == train_conf["log"] assert trainer.workdir == train_conf["workdir"] assert trainer.checkpointing == train_conf["checkpointing"] assert trainer.return_state == train_conf["return_state"] def test_class_train_set_functions(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) def huber_loss(output, labels): return jax.numpy.mean(optax.huber_loss(output, labels)) # Dum range requirement over kernel parameters ktrav = construct_traversal("kernel") krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1) train_conf = dict(testobj.train_conf) train_conf["criterion"] = huber_loss train_conf["create_train_state"] = create_basic_train_state 
train_conf["train_step_fn"] = train_step train_conf["eval_step_fn"] = eval_step train_conf["post_lst"] = [krange] try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.criterion == train_conf["criterion"] assert trainer.create_train_state == train_conf["create_train_state"] assert trainer.train_step_fn == train_conf["train_step_fn"] assert trainer.eval_step_fn == train_conf["eval_step_fn"] assert trainer.post_lst[0] == train_conf["post_lst"][0] assert hasattr(trainer, "lr_schedule") def test_class_train_set_iterators(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) try: trainer = sflax.BasicFlaxTrainer( testobj.train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert trainer.ishape == testobj.train_ds["image"].shape[1:3] assert hasattr(trainer, "train_dt_iter") assert hasattr(trainer, "eval_dt_iter") @pytest.mark.parametrize("postl", [False, True]) def test_class_train_set_parallel(testobj, postl): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["post_lst"] = [] if postl: # Dum range requirement over kernel parameters ktrav = construct_traversal("kernel") krange = functools.partial(clip_range, traversal=ktrav, minval=1e-5, maxval=1e1) train_conf["post_lst"] = [krange] try: trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) except Exception as e: print(e) assert 0 else: assert hasattr(trainer, "p_train_step") assert hasattr(trainer, "p_eval_step") @pytest.mark.parametrize("chkflag", [False, True]) def test_class_train_external_init(testobj, chkflag): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) key = jax.random.key(seed=1234) input_shape = (1, testobj.N, 
testobj.N, testobj.chn) # Via model initialization variables1 = model.init({"params": key}, np.ones(input_shape, model.dtype)) flat_params1 = flatten_dict(variables1["params"]) flat_bstats1 = flatten_dict(variables1["batch_stats"]) params1 = [t[1] for t in sorted(flat_params1.items())] bstats1 = [t[1] for t in sorted(flat_bstats1.items())] # Via BasicFlaxTrainer object initialization train_conf = dict(testobj.train_conf) train_conf["checkpointing"] = chkflag trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, variables0=variables1, ) flat_params2 = flatten_dict(trainer.state.params) flat_bstats2 = flatten_dict(trainer.state.batch_stats) params2 = [t[1] for t in sorted(flat_params2.items())] bstats2 = [t[1] for t in sorted(flat_bstats2.items())] for i in range(len(params1)): np.testing.assert_allclose(params1[i], params2[i], rtol=1e-5) for i in range(len(bstats1)): np.testing.assert_allclose(bstats1[i], bstats2[i], rtol=1e-5) @pytest.mark.parametrize("model_cls", [sflax.DnCNNNet, sflax.ResNet, sflax.ConvBNNet, sflax.UNet]) def test_class_train_train_loop(testobj, model_cls): depth = testobj.model_conf["depth"] model = model_cls(depth, testobj.chn, testobj.model_conf["num_filters"]) if isinstance(model, sflax.DnCNNNet): depth = 3 model = sflax.DnCNNNet(depth, testobj.chn, testobj.model_conf["num_filters"]) # Create training object trainer = sflax.BasicFlaxTrainer( testobj.train_conf, model, testobj.train_ds, testobj.test_ds, ) try: modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: assert "params" in modvar assert "batch_stats" in modvar def test_class_train_train_post_loop(testobj): depth = testobj.model_conf["depth"] model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) # Dum positive requirement over kernel parameters ktrav = construct_traversal("kernel") kpos = functools.partial(clip_positive, traversal=ktrav, 
minval=1e-5) train_conf["post_lst"] = [kpos] # Create training object trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) try: modvar, _ = trainer.train() except Exception as e: print(e) assert 0 else: assert "params" in modvar assert "batch_stats" in modvar def test_class_train_return_state(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["return_state"] = True trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) try: state, _ = trainer.train() except Exception as e: print(e) assert 0 else: assert hasattr(state, "params") assert hasattr(state, "batch_stats") def test_class_train_update_metrics(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["log"] = True train_conf["log_every_steps"] = 1 trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) total_steps = (testobj.train_ds["label"].shape[0] // testobj.bsize) * train_conf["num_epochs"] try: state, stats_object = trainer.train() except Exception as e: print(e) assert 0 else: hist = stats_object.history(transpose=True) assert len(hist.Train_Loss) == total_steps def test_class_train_update_metrics_nolog(testobj): model = sflax.ResNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) train_conf = dict(testobj.train_conf) train_conf["log"] = False train_conf["log_every_steps"] = 1 trainer = sflax.BasicFlaxTrainer( train_conf, model, testobj.train_ds, testobj.test_ds, ) try: state, stats_object = trainer.train() except Exception as e: print(e) assert 0 else: assert stats_object is None ================================================ FILE: scico/test/flax/test_traversal.py ================================================ import numpy as np import jax import pytest 
from test_trainer import SetupTest from scico import flax as sflax from scico.flax.train.traversals import construct_traversal @pytest.fixture(scope="module") def testobj(): yield SetupTest() @pytest.mark.parametrize("pname", ["kernel", "bias", "scale"]) def test_construct_traversal(testobj, pname): model = sflax.ConvBNNet( testobj.model_conf["depth"], testobj.chn, testobj.model_conf["num_filters"] ) ndim = 1 if pname == "kernel": ndim = 4 key = jax.random.key(seed=432) input_shape = (1, testobj.N, testobj.N, testobj.chn) variables = model.init({"params": key}, np.ones(input_shape, model.dtype)) ptrav = construct_traversal(pname) for pm in ptrav.iterate(variables["params"]): assert len(pm.shape) == ndim ================================================ FILE: scico/test/functional/prox.py ================================================ import numpy as np import scico.numpy as snp from scico.solver import minimize def prox_func(x, v, f, alpha): """Evaluate functional of which the proximal operator is the argmin.""" return 0.5 * snp.sum(snp.abs(x.reshape(v.shape) - v) ** 2) + alpha * snp.array( f(x.reshape(v.shape)), dtype=snp.float64 ) def prox_solve(v, v0, f, alpha): """Evaluate the alpha-scaled proximal operator of f at v, using v0 as an initial point for the optimization.""" fnc = lambda x: prox_func(x, v, f, alpha) fmn = minimize( fnc, v0, method="Nelder-Mead", options={"maxiter": 1000, "xatol": 1e-9, "fatol": 1e-9}, ) return fmn.x.reshape(v.shape), fmn.fun def prox_test(v, nrm, prx, alpha, x0=None, rtol=1e-6): """Test the alpha-scaled proximal operator function prx of norm functional nrm at point v.""" # Evaluate the proximal operator at v px = snp.array(prx(v, alpha, v0=x0)) # Proximal operator functional value (i.e. 
Moreau envelope) at v pf = prox_func(px, v, nrm, alpha) # Brute-force solve of the proximal operator at v mx, mf = prox_solve(v, px, nrm, alpha) # Compare prox functional value with brute-force solution if pf < mf: return # prox gave a lower cost than brute force, so it passes np.testing.assert_allclose(pf, mf, rtol=rtol) ================================================ FILE: scico/test/functional/test_composed.py ================================================ import numpy as np from jax import config from prox import prox_test from scico import functional, linop from scico.random import randn # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) class TestComposed: def setup_method(self): key = None self.shape = (2, 3, 4) self.dtype = np.float32 self.x, key = randn(self.shape, key=key, dtype=self.dtype) self.composed = functional.ComposedFunctional( functional.L2Norm(), linop.Reshape(self.x.shape, (2, -1), input_dtype=self.dtype), ) def test_eval(self): np.testing.assert_allclose(self.composed(self.x), self.composed.functional(self.x)) def test_prox(self): prox_test(self.x, self.composed.__call__, self.composed.prox, 1.0) ================================================ FILE: scico/test/functional/test_denoiser_func.py ================================================ import numpy as np import pytest from scico import denoiser, functional from scico.denoiser import BM3DProfile, BM4DProfile, have_bm3d, have_bm4d from scico.random import randn from scico.test.osver import osx_ver_geq_than # bm3d is known to be broken on OSX 11.6.5. 
It may be broken on earlier versions too, # but this has not been confirmed @pytest.mark.skipif(osx_ver_geq_than("11.6.5"), reason="bm3d broken on this platform") @pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed") class TestBM3D: def setup_method(self): key = None self.x_gry, key = randn((32, 33), key=key, dtype=np.float32) self.x_rgb, key = randn((33, 34, 3), key=key, dtype=np.float32) self.profile = BM3DProfile() self.profile.num_threads = 1 # Make processing deterministic self.f_gry = functional.BM3D(profile=self.profile) self.f_rgb = functional.BM3D(is_rgb=True, profile=self.profile) def test_gry(self): y0 = self.f_gry.prox(self.x_gry, 1.0) y1 = denoiser.bm3d(self.x_gry, 1.0, profile=self.profile) assert np.linalg.norm(y1 - y0) < 1e-6 def test_rgb(self): y0 = self.f_rgb.prox(self.x_rgb, 1.0) y1 = denoiser.bm3d(self.x_rgb, 1.0, is_rgb=True, profile=self.profile) assert np.linalg.norm(y1 - y0) < 1e-6 # bm4d is known to be broken on OSX 11.6.5. It may be broken on earlier versions too, # but this has not been confirmed @pytest.mark.skipif(osx_ver_geq_than("11.6.5"), reason="bm4d broken on this platform") @pytest.mark.skipif(not have_bm4d, reason="bm4d package not installed") class TestBM4D: def setup_method(self): key = None self.x, key = randn((16, 17, 14), key=key, dtype=np.float32) self.profile = BM4DProfile() self.profile.num_threads = 1 # Make processing deterministic self.f = functional.BM4D(profile=self.profile) def test(self): y0 = self.f.prox(self.x, 1.0) y1 = denoiser.bm4d(self.x, 1.0, profile=self.profile) assert np.linalg.norm(y1 - y0) < 1e-6 class TestBlindDnCNN: def setup_method(self): key = None self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32) self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32) self.dncnn = denoiser.DnCNN(variant="6M") self.f = functional.DnCNN(variant="6M") def test_sngchn(self): y0 = self.f.prox(self.x_sngchn, 1.0) y1 = self.dncnn(self.x_sngchn) np.testing.assert_allclose(y0, y1, 
rtol=1e-5) def test_mltchn(self): y0 = self.f.prox(self.x_mltchn, 1.0) y1 = self.dncnn(self.x_mltchn) np.testing.assert_allclose(y0, y1, rtol=1e-5) class TestNonBlindDnCNN: def setup_method(self): key = None self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32) self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32) self.dncnn = denoiser.DnCNN(variant="6N") self.f = functional.DnCNN(variant="6N") def test_sngchn(self): y0 = self.f.prox(self.x_sngchn, 1.5) y1 = self.dncnn(self.x_sngchn, 1.5) np.testing.assert_allclose(y0, y1, rtol=1e-5) def test_mltchn(self): y0 = self.f.prox(self.x_mltchn, 0.7) y1 = self.dncnn(self.x_mltchn, 0.7) np.testing.assert_allclose(y0, y1, rtol=1e-5) ================================================ FILE: scico/test/functional/test_funcional_core.py ================================================ import numpy as np import jax.numpy as jnp from jax import config # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) import pytest from prox import prox_test import scico.numpy as snp from scico import functional from scico.random import randn NO_BLOCK_ARRAY = [ functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm, functional.AnisotropicTVNorm, functional.IsotropicTVNorm, functional.TVNorm, ] NO_COMPLEX = [functional.NonNegativeIndicator, functional.BoxIndicator] def pytest_generate_tests(metafunc): level = int(metafunc.config.getoption("--level")) alpha_range = [1e-2, 1e-1, 1e0, 1e1] dtype_range = [np.float32, np.complex64, np.float64, np.complex128] if level == 2: alpha_range = [1e-2, 1e1] dtype_range = [np.float32, np.complex64, np.float64] elif level < 2: alpha_range = [1e-2] dtype_range = [np.float32, np.complex64] if "alpha" in metafunc.fixturenames: metafunc.parametrize("alpha", alpha_range) if "test_dtype" in metafunc.fixturenames: metafunc.parametrize("test_dtype", dtype_range) class ProxTestObj: def __init__(self, dtype): key = None self.v, key = randn(shape=(11, 1), 
dtype=dtype, key=key, seed=3)
        self.vb, key = randn(shape=((3, 4), (2,)), dtype=dtype, key=key)
        self.scalar = np.pi
        self.vz = snp.zeros((3, 4), dtype=dtype)


@pytest.fixture
def test_prox_obj(test_dtype):
    """Build a ``ProxTestObj`` for the dtype supplied by the ``test_dtype`` fixture."""
    return ProxTestObj(test_dtype)


class SeparableTestObject:
    """Test data for ``functional.SeparableFunctional``.

    Combines an L1 norm (``f``) and a squared L2 norm (``g``) into the
    separable functional ``fg``, and draws random evaluation points ``v1``
    (length 4) and ``v2`` (length 6) together with their block array ``vb``.
    """

    def __init__(self, dtype):
        self.f = functional.L1Norm()
        self.g = functional.SquaredL2Norm()
        self.fg = functional.SeparableFunctional([self.f, self.g])
        n = 4
        m = 6
        key = None

        self.v1, key = randn((n,), key=key, dtype=dtype)  # point for prox eval
        self.v2, key = randn((m,), key=key, dtype=dtype)  # point for prox eval
        self.vb = snp.blockarray([self.v1, self.v2])


@pytest.fixture
def test_separable_obj(test_dtype):
    """Build a ``SeparableTestObject`` for the dtype supplied by ``test_dtype``."""
    return SeparableTestObject(test_dtype)


def test_separable_eval(test_separable_obj):
    # The separable functional evaluated on the block array should equal the
    # sum of the component functionals evaluated on the individual blocks.
    fv1 = test_separable_obj.f(test_separable_obj.v1)
    gv2 = test_separable_obj.g(test_separable_obj.v2)
    fgv = test_separable_obj.fg(test_separable_obj.vb)
    np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2)


def test_separable_prox(test_separable_obj):
    # The prox of the separable functional should equal the block array of
    # the component proxes applied block-wise with the same step size.
    alpha = 0.1
    fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha)
    gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha)
    fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha)
    out = snp.blockarray((fv1, gv2))
    snp.testing.assert_allclose(out, fgv, rtol=5e-2)


def test_separable_grad(test_separable_obj):
    # Test the separable grad
    # (block-wise component gradients should match the gradient of the
    # separable functional evaluated on the block array)
    fv1 = test_separable_obj.f.grad(test_separable_obj.v1)
    gv2 = test_separable_obj.g.grad(test_separable_obj.v2)
    fgv = test_separable_obj.fg.grad(test_separable_obj.vb)
    out = snp.blockarray((fv1, gv2))
    snp.testing.assert_allclose(out, fgv, rtol=5e-2)


# HuberNorm variants with ``separable`` fixed, so that they can be constructed
# with no arguments (``norm()``) like the other entries of ``normlist``.
class HuberNormSep(functional.HuberNorm):
    def __init__(self, delta=1.0):
        super().__init__(delta=delta, separable=True)


class HuberNormNonSep(functional.HuberNorm):
    def __init__(self, delta=1.0):
        super().__init__(delta=delta, separable=False)


class TestNormProx:
    """Prox, conjugate-prox and scaling tests for norm-like functionals.

    ``alpha`` (requested by several tests) is a fixture defined outside this
    view.
    """

    # Norm-like functionals under test; each must be constructible with no
    # arguments.
    normlist = [
        functional.L0Norm,
        functional.L1Norm,
        functional.SquaredL2Norm,
        functional.L2Norm,
        functional.L21Norm,
        functional.L1MinusL2Norm,
        HuberNormSep,
        HuberNormNonSep,
        functional.NuclearNorm,
        functional.ZeroFunctional,
    ]

    # Subset that supports BlockArray input (NO_BLOCK_ARRAY is defined outside
    # this view).
    # NOTE(review): set iteration order depends on object hashes, so the
    # parametrization order/ids may vary between processes (e.g. under
    # pytest-xdist); a sorted list would be deterministic.
    normlist_blockarray_ready = set(normlist.copy()) - set(NO_BLOCK_ARRAY)

    @pytest.mark.parametrize("norm", normlist)
    def test_prox(self, norm, alpha, test_prox_obj):
        # Delegate to the shared ``prox_test`` helper (imported outside this
        # view); its return value is not used.
        nrmobj = norm()
        nrm = nrmobj.__call__
        prx = nrmobj.prox
        pf = prox_test(test_prox_obj.v, nrm, prx, alpha)

    @pytest.mark.parametrize("norm", normlist)
    def test_conj_prox(self, norm, alpha, test_prox_obj):
        nrmobj = norm()
        v = test_prox_obj.v
        # Test checks extended Moreau decomposition at a random vector
        # prox_{alpha f}(v) + alpha * prox_{f*/alpha}(v / alpha) == v
        lhs = nrmobj.prox(v, alpha) + alpha * nrmobj.conj_prox(v / alpha, 1.0 / alpha)
        rhs = v
        np.testing.assert_allclose(lhs, rhs, rtol=1e-6, atol=0.0)

    @pytest.mark.parametrize("norm", normlist_blockarray_ready)
    def test_prox_blockarray(self, norm, alpha, test_prox_obj):
        # The prox of the flattened block array should match the flattened
        # prox of the block array, and both should keep the input dtype.
        # NOTE(review): ``nrm`` and ``prx`` are unused in this test.
        nrmobj = norm()
        nrm = nrmobj.__call__
        prx = nrmobj.prox
        pf = nrmobj.prox(snp.ravel(test_prox_obj.vb), alpha)
        pf_b = nrmobj.prox(test_prox_obj.vb, alpha)
        assert pf.dtype == test_prox_obj.vb.dtype
        assert pf_b.dtype == test_prox_obj.vb.dtype
        snp.testing.assert_allclose(pf, snp.ravel(pf_b), rtol=1e-6)

    @pytest.mark.parametrize("norm", normlist)
    def test_prox_zeros(self, norm, test_prox_obj):
        # Same check as ``test_prox``, at the zero vector with unit step.
        nrmobj = norm()
        nrm = nrmobj.__call__
        prx = nrmobj.prox
        pf = prox_test(test_prox_obj.vz, nrm, prx, alpha=1.0)

    @pytest.mark.parametrize("norm", normlist)
    def test_scaled_attrs(self, norm, test_prox_obj):
        # Scaling a functional must preserve has_eval/has_prox and record the
        # scale factor.
        # NOTE(review): ``alpha`` is assigned but never used.
        alpha = np.sqrt(2)
        unscaled = norm()
        scaled = test_prox_obj.scalar * norm()

        assert scaled.has_eval == unscaled.has_eval
        assert scaled.has_prox == unscaled.has_prox
        assert scaled.scale == test_prox_obj.scalar

    @pytest.mark.parametrize("norm", normlist)
    def test_scaled_eval(self, norm, alpha, test_prox_obj):
        # (s * f)(v) should equal s * f(v). The ``alpha`` fixture is unused.
        unscaled = norm()
        scaled = test_prox_obj.scalar * norm()

        a = test_prox_obj.scalar * unscaled(test_prox_obj.v)
        b = scaled(test_prox_obj.v)
        np.testing.assert_allclose(a, b)

    @pytest.mark.parametrize("norm", normlist)
    def test_scaled_prox(self, norm, alpha, test_prox_obj):
        # Test prox
        # prox of (s * f) with step alpha should equal prox of f with step
        # alpha * s.
        unscaled = norm()
        scaled = test_prox_obj.scalar * norm()
        a = unscaled.prox(test_prox_obj.v, alpha * test_prox_obj.scalar)
        b = scaled.prox(test_prox_obj.v, alpha)
        np.testing.assert_allclose(a, b)


class TestBlockArrayEval:
    # Ensures that functionals evaluate properly on a blockarray
    # By convention, should be the same as evaluating on the flattened array

    # Generate a list of all functionals in scico.functionals that we will check
    # NOTE(review): the class-body loop below leaves ``name`` and ``cls`` as
    # class attributes.
    ignore = [
        functional.Functional,
        functional.ScaledFunctional,
        functional.SetDistance,
        functional.SquaredSetDistance,
    ]
    to_check = []
    for name, cls in functional.__dict__.items():
        if isinstance(cls, type):
            if issubclass(cls, functional.Functional):
                if cls not in ignore and cls.has_eval is True:
                    to_check.append(cls)

    # NOTE(review): a set, so parametrization order is not deterministic.
    to_check = set(to_check) - set(NO_BLOCK_ARRAY)

    @pytest.mark.parametrize("cls", to_check)
    def test_eval(self, cls, test_prox_obj):
        func = cls()  # instantiate the functional we are testing

        # Functionals listed in NO_COMPLEX (defined outside this view) must
        # reject complex input with ValueError.
        if cls in NO_COMPLEX and snp.util.is_complex_dtype(test_prox_obj.vb.dtype):
            with pytest.raises(ValueError):
                x = func(test_prox_obj.vb)
            return

        x = func(test_prox_obj.vb)
        y = func(test_prox_obj.vb.ravel())
        # Both evaluations must be scalars and agree.
        assert jnp.isscalar(x) or x.ndim == 0
        assert jnp.isscalar(y) or y.ndim == 0

        np.testing.assert_allclose(x, y, rtol=1e-6)


# only check double precision on projections
@pytest.fixture(params=[np.float64, np.complex128])
def test_proj_obj(request):
    return ProxTestObj(request.param)


class TestProj:
    """Tests for indicator functionals (projections) and set distances built on them."""

    cnstrlist = [functional.NonNegativeIndicator, functional.L2BallIndicator]
    sdistlist = [functional.SetDistance, functional.SquaredSetDistance]

    @pytest.mark.parametrize("cnstr", cnstrlist)
    def test_prox(self, cnstr, test_proj_obj):
        alpha = 1
        cnsobj = cnstr()
        cns = cnsobj.__call__
        prx = cnsobj.prox
        # Constraints in NO_COMPLEX must raise ValueError on complex input.
        if cnstr in NO_COMPLEX and snp.util.is_complex_dtype(test_proj_obj.v.dtype):
            with pytest.raises(ValueError):
                prox_test(test_proj_obj.v, cns, prx, alpha)
            return
        prox_test(test_proj_obj.v, cns, prx, alpha)

    @pytest.mark.parametrize("cnstr", cnstrlist)
    def test_prox_scale_invariance(self, cnstr, test_proj_obj):
        # The prox of an indicator should not depend on the step size: results
        # for two very different alphas must agree to within 1e-7 relative.
        alpha1 = 1e-2
        alpha2 = 1e0
        cnsobj = cnstr()
        u1 = cnsobj.prox(test_proj_obj.v, alpha1)
        u2 = cnsobj.prox(test_proj_obj.v, alpha2)
        assert np.linalg.norm(u1 - u2) / np.linalg.norm(u1) <= 1e-7

    @pytest.mark.parametrize("sdist", sdistlist)
    @pytest.mark.parametrize("cnstr", cnstrlist)
    def test_setdistance(self, sdist, cnstr, alpha, test_proj_obj):
        # Complex input is skipped (not tested) for constraints in NO_COMPLEX.
        if cnstr in NO_COMPLEX and snp.util.is_complex_dtype(test_proj_obj.v.dtype):
            return
        cnsobj = cnstr()
        proj = cnsobj.prox
        # The set-distance functional is built from the constraint's prox.
        sdobj = sdist(proj)
        call = sdobj.__call__
        prox = sdobj.prox
        prox_test(test_proj_obj.v, call, prox, alpha)


================================================
FILE: scico/test/functional/test_indicator.py
================================================
import pytest

import scico.numpy as snp
from scico import functional
from scico.random import randn

INDICATOR = [
    functional.L2BallIndicator,
    functional.NonNegativeIndicator,
    functional.BoxIndicator,
]


@pytest.mark.parametrize("indicator", INDICATOR)
def test_indicator(indicator):
    """The indicator evaluated at the output of its own prox should be zero."""
    x, key = randn(shape=(8,), dtype=snp.float32)
    func = indicator()
    assert func(func.prox(x)) == 0.0


================================================
FILE: scico/test/functional/test_loss.py
================================================
import numpy as np

from jax import config

import pytest

# enable 64-bit mode for output dtype checks
config.update("jax_enable_x64", True)

from prox import prox_test

import scico.numpy as snp
from scico import functional, linop, loss
from scico.numpy.util import complex_dtype
from scico.random import randn, uniform


class TestLoss:
    """Tests for generic, squared-L2 and Poisson losses."""

    def setup_method(self):
        # Fixed-seed float64 test problem of size n: a dense operator Ao, its
        # elementwise-absolute version Ao_abs, a diagonal operator Do, positive
        # weights W (randn scaled around 1.0), data y, a prox evaluation point v
        # and a random scalar.
        n = 4
        dtype = np.float64
        A, key = randn((n, n), key=None, dtype=dtype, seed=1234)
        D, key = randn((n,), key=key, dtype=dtype)
        W, key = randn((n,), key=key, dtype=dtype)
        W = 0.1 * W + 1.0
        self.Ao = linop.MatrixOperator(A)
        self.Ao_abs = linop.MatrixOperator(snp.abs(A))
        self.Do = linop.Diagonal(D)
        self.W = linop.Diagonal(W)
        self.y, key = randn((n,), key=key, dtype=dtype)
        self.v, key = randn((n,), key=key, dtype=dtype)  # point for prox eval
        scalar, key = randn((1,), key=key, dtype=dtype)
        self.key = key
        self.scalar = scalar[0].item()

    def test_generic_squared_l2(self):
        # A generic Loss with f = SquaredL2Norm and scale 0.5 should match
        # SquaredL2Loss in both value and prox.
        A = linop.Identity(input_shape=self.y.shape)
        f = functional.SquaredL2Norm()
        L0 = loss.Loss(self.y, A=A, f=f, scale=0.5)
        L1 = loss.SquaredL2Loss(y=self.y, A=A)
        np.testing.assert_allclose(L0(self.v), L1(self.v))
        np.testing.assert_allclose(
            L0.prox(self.v, self.scalar), L1.prox(self.v, self.scalar), rtol=1e-6
        )

    def test_generic_exception(self):
        # Without ``f`` a generic Loss cannot be evaluated; with an L1 ``f``
        # and a diagonal A it has no prox.
        A = linop.Diagonal(self.v)
        L = loss.Loss(self.y, A=A, scale=0.5)
        with pytest.raises(NotImplementedError):
            L(self.v)
        f = functional.L1Norm()
        L = loss.Loss(self.y, A=A, f=f, scale=0.5)
        assert not L.has_prox
        with pytest.raises(NotImplementedError):
            L.prox(self.v, self.scalar)

    def test_squared_l2(self):
        L = loss.SquaredL2Loss(y=self.y, A=self.Ao)
        assert L.has_eval
        assert L.has_prox

        # test eval
        np.testing.assert_allclose(L(self.v), 0.5 * ((self.Ao @ self.v - self.y) ** 2).sum())

        cL = self.scalar * L
        assert L.scale == 0.5  # hasn't changed
        assert cL.scale == self.scalar * L.scale
        assert cL(self.v) == self.scalar * L(self.v)

        # squared l2 loss with diagonal linop has a prox
        L_d = loss.SquaredL2Loss(y=self.y, A=self.Do)

        # test eval
        np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum())

        assert L_d.has_eval
        assert L_d.has_prox

        cL = self.scalar * L_d
        assert L_d.scale == 0.5  # hasn't changed
        assert cL.scale == self.scalar * L_d.scale
        assert cL(self.v) == self.scalar * L_d(self.v)

        pf = prox_test(self.v, L_d, L_d.prox, 0.75)

        pf = prox_test(self.v, L, L.prox, 0.75)

    def test_squared_l2_grad(self):
        # The default scale is 0.5 (asserted in test_squared_l2), so scale 5.0
        # and an explicit factor of 10 should both give 10x the default
        # gradient.
        La = loss.SquaredL2Loss(y=self.y)
        Lb = loss.SquaredL2Loss(y=self.y, scale=5e0)
        Lc = 1e1 * La
        ga = La.grad(self.v)
        gb = Lb.grad(self.v)
        gc = Lc.grad(self.v)
        np.testing.assert_allclose(1e1 * ga, gb)
        np.testing.assert_allclose(gb, gc)

    def test_weighted_squared_l2(self):
        L = loss.SquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
        assert L.has_eval
        assert L.has_prox

        # Expected value: 0.5 * sum(W @ (Ao v - y)**2)
        np.testing.assert_allclose(
            L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()
        )
        pf = prox_test(self.v, L, L.prox, 0.75)

        # weighted l2 loss with diagonal linop has a prox
        L_d = loss.SquaredL2Loss(y=self.y, A=self.Do, W=self.W)
        assert L_d.has_eval
        assert L_d.has_prox

        np.testing.assert_allclose(
            L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()
        )
        pf = prox_test(self.v, L_d, L_d.prox, 0.75)

    def test_poisson(self):
        L = loss.PoissonLoss(y=self.y, A=self.Ao_abs)
        assert L.has_eval
        assert not L.has_prox

        # test eval
        # Nonnegative input and operator, so log(Av) below is applied to a
        # nonnegative vector.
        v = snp.abs(self.v)
        Av = self.Ao_abs @ v
        np.testing.assert_allclose(L(v), 0.5 * snp.sum(Av - self.y * snp.log(Av) + L.const))

        cL = self.scalar * L
        assert L.scale == 0.5  # hasn't changed
        assert cL.scale == self.scalar * L.scale
        assert cL(v) == self.scalar * L(v)


class TestAbsLoss:
    """Tests for squared-L2 losses on |Ax| and |Ax|**2 with complex x."""

    # (loss class, forward function used to synthesize noiseless data y)
    abs_loss = (
        (loss.SquaredL2AbsLoss, snp.abs),
        (loss.SquaredL2SquaredAbsLoss, lambda x: snp.abs(x) ** 2),
    )

    def setup_method(self):
        n = 4
        dtype = np.float64
        A, key = randn((n, n), key=None, dtype=dtype, seed=1234)
        W, key = randn((n,), key=key, dtype=dtype)
        W = 0.1 * W + 1.0
        self.Ao = linop.MatrixOperator(A)
        self.Ao_abs = linop.MatrixOperator(snp.abs(A))
        self.W = linop.Diagonal(W)
        self.x, key = randn((n,), key=key, dtype=complex_dtype(dtype))
        self.v, key = randn((n,), key=key, dtype=complex_dtype(dtype))  # point for prox eval
        scalar, key = randn((1,), key=key, dtype=dtype)
        self.scalar = scalar[0].item()

    @pytest.mark.parametrize("loss_tuple", abs_loss)
    def test_properties(self, loss_tuple):
        loss_class = loss_tuple[0]
        loss_func = loss_tuple[1]

        # With a forward operator A: evaluable, no prox, zero at the true x.
        y = loss_func(self.Ao(self.x))
        L = loss_class(y=y, A=self.Ao, W=self.W)
        assert L.has_eval
        assert not L.has_prox

        cL = self.scalar * L
        assert L.scale == 0.5  # hasn't changed
        assert cL.scale == self.scalar * L.scale
        assert cL(self.v) == self.scalar * L(self.v)

        with pytest.raises(NotImplementedError):
            px = L.prox(self.v, 0.75)

        np.testing.assert_allclose(L(self.x), 0)

        # Without A or W: a prox is available.
        y = loss_func(self.x)
        L = loss_class(y=y, A=None, W=None)
        assert L.has_eval
        assert L.has_prox

        cL = self.scalar * L
        assert L.scale == 0.5  # hasn't changed
        assert cL.scale == self.scalar * L.scale
        assert cL(self.v) == self.scalar * L(self.v)

        np.testing.assert_allclose(L(self.x), 0)

        # Invalid weights: negated weights raise ValueError; a non-Diagonal
        # operator (linop.Sum) raises TypeError.
        W = -1 * self.W
        with pytest.raises(ValueError):
            L = loss_class(y=y, W=W)
        with pytest.raises(TypeError):
            L = loss_class(y=y, W=linop.Sum(input_shape=W.input_shape))

    @pytest.mark.parametrize("loss_tuple", abs_loss)
    def test_prox(self, loss_tuple):
        loss_class = loss_tuple[0]
        loss_func = loss_tuple[1]

        # Prox checks over real/complex/zero inputs and several step sizes,
        # including a zero step.
        y = loss_func(self.x)
        L = loss_class(y=y, A=None, W=self.W)

        pf = prox_test(self.v.real, L, L.prox, 0.5)  # real v
        pf = prox_test(self.v, L, L.prox, 0.0)  # complex v
        pf = prox_test(self.v, L, L.prox, 0.1)  # complex v
        pf = prox_test(self.v, L, L.prox, 2.0)  # complex v
        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0)  # complex zero v
        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0)  # complex zero v
        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 2.0)  # complex zero v

        # zero y
        y = snp.zeros(self.x.shape)
        L = loss_class(y=y, A=None, W=self.W)
        pf = prox_test(self.v.real, L, L.prox, 0.5)  # real v
        pf = prox_test(self.v, L, L.prox, 0.0)  # complex v
        pf = prox_test(self.v, L, L.prox, 0.1)  # complex v
        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0)  # complex zero v
        pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0)  # complex zero v


def test_cubic_root():
    """Check ``loss._dep_cubic_root`` returns roots r of r**3 + p*r + q = 0."""
    N = 10000
    p, key = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, seed=1234)
    q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, key=key)
    # Avoid cases of very poor numerical precision
    p = p.at[snp.logical_and(snp.abs(p) < 2, q > 5e-2 * snp.abs(p))].set(1e1)
    r = loss._dep_cubic_root(p, q)
    # Residual of the cubic at the computed roots.
    err = snp.abs(r**3 + p * r + q)
    assert err.max() < 2e-4
    # Test loss of precision warning
    p = snp.array(1e-4, dtype=snp.float32)
    q = snp.array(1e1, dtype=snp.float32)
    with pytest.warns(UserWarning):
        r = loss._dep_cubic_root(p, q)


================================================
FILE: scico/test/functional/test_misc.py
================================================
import numpy as np

import jax

import pytest

import scico.numpy as snp
from scico import functional, linop


class TestCheckAttrs:
    # Ensure that the has_eval, has_prox attrs are overridden
    # and set to True/False in the Functional subclasses.

    # Generate a list of all functionals in scico.functionals that we will check
    ignore = [
        functional.Functional,
        functional.FunctionalSum,
        functional.ScaledFunctional,
        functional.SeparableFunctional,
        functional.ComposedFunctional,
        functional.ProximalAverage,
    ]
    to_check = []
    for name, cls in functional.__dict__.items():
        if isinstance(cls, type):
            if issubclass(cls, functional.Functional):
                if cls not in ignore:
                    to_check.append(cls)

    @pytest.mark.parametrize("cls", to_check)
    def test_has_eval(self, cls):
        assert isinstance(cls.has_eval, bool)

    @pytest.mark.parametrize("cls", to_check)
    def test_has_prox(self, cls):
        assert isinstance(cls.has_prox, bool)


class TestJit:
    # Test whether functionals can be jitted.

    # Generate a list of all functionals in scico.functionals that we will check
    ignore = [
        functional.Functional,
        functional.ScaledFunctional,
        functional.SeparableFunctional,
        functional.AnisotropicTVNorm,  # requires input_shape parameter in order to be jittable
        functional.IsotropicTVNorm,  # requires input_shape parameter in order to be jittable
        functional.TVNorm,  # requires input_shape parameter in order to be jittable
        functional.BM3D,
        functional.BM4D,
    ]
    to_check = []
    for name, cls in functional.__dict__.items():
        if isinstance(cls, type):
            if issubclass(cls, functional.Functional):
                if cls not in ignore:
                    to_check.append(cls)

    @pytest.mark.parametrize("cls", to_check)
    def test_jit(self, cls):
        # The jitted prox must match the un-jitted prox on a small 1D input.
        # Classes that cannot be built without arguments, or that reject 1D
        # input, are silently skipped (the test passes without checking).
        # Only test functionals that have no required __init__ parameters.
        try:
            f = cls()
        except TypeError:
            pass
        else:
            v = snp.arange(4.0)
            # Only test functionals that can take 1D input.
            try:
                u0 = f.prox(v)
            except ValueError:
                pass
            else:
                fprox = jax.jit(f.prox)
                u1 = fprox(v)
                assert np.allclose(u0, u1)


def test_functional_sum():
    # Adding functionals gives a functional whose value is the sum; adding a
    # plain scalar raises TypeError.
    # NOTE(review): exact float equality; np.testing.assert_allclose would be
    # more robust if the summation order ever differs.
    x = np.random.randn(4, 4)
    f0 = functional.L1Norm()
    f1 = 2.0 * functional.L2Norm()
    f = f0 + f1
    assert f(x) == f0(x) + f1(x)
    with pytest.raises(TypeError):
        f = f0 + 2.0


def test_scalar_vmap():
    # Scaling by a traced scalar must work under jax.vmap and match a Python
    # loop over the same scalars.
    x = np.random.randn(4, 4)
    f = functional.L1Norm()

    def foo(c):
        return (c * f)(x)

    c_list = [1.0, 2.0, 3.0]
    non_vmap = np.array([foo(c) for c in c_list])

    vmapped = jax.vmap(foo)(snp.array(c_list))
    np.testing.assert_allclose(non_vmap, vmapped)


def test_scalar_pmap():
    # As test_scalar_vmap, but with jax.pmap over one scalar per device.
    x = np.random.randn(4, 4)
    f = functional.L1Norm()

    def foo(c):
        return (c * f)(x)

    c_list = np.random.randn(jax.device_count())
    non_pmap = np.array([foo(c) for c in c_list])
    pmapped = jax.pmap(foo)(c_list)
    np.testing.assert_allclose(non_pmap, pmapped)


def test_scalar_aggregation():
    # Repeated scaling is collapsed: 5 * (2 * f) is a single ScaledFunctional
    # wrapping f itself with scale 10.
    f = functional.L2Norm()
    g = 2.0 * f
    h = 5.0 * g
    assert isinstance(g, functional.ScaledFunctional)
    assert isinstance(g.functional, functional.L2Norm)
    assert g.scale == 2.0
    assert isinstance(h, functional.ScaledFunctional)
    assert isinstance(h.functional, functional.L2Norm)
    assert h.scale == 10.0


@pytest.mark.parametrize(
    "func",
    [
        functional.ZeroFunctional(),
        functional.SeparableFunctional((functional.ZeroFunctional(), functional.ZeroFunctional())),
        functional.ComposedFunctional(functional.ZeroFunctional(), linop.Identity((4,))),
        functional.FunctionalSum(functional.ZeroFunctional(), functional.ZeroFunctional()),
    ],
)
def test_repr_str(func):
    # repr() must contain str() and report the has_eval/has_prox attributes.
    fname = str(func)
    frepr = repr(func)
    assert fname in frepr
    assert "has_eval:" in frepr
    assert "has_prox:" in frepr


================================================
FILE: scico/test/functional/test_norm.py
================================================
import numpy as np

import pytest

import scico.numpy as snp
from scico import functional
@pytest.mark.parametrize("axis", [0, 1, (0, 2)]) def test_l21norm(axis): x = np.ones((3, 4, 5)) if isinstance(axis, int): l2axis = (axis,) else: l2axis = axis l2shape = [x.shape[k] for k in l2axis] l1axis = tuple(set(range(len(x))) - set(l2axis)) l1shape = [x.shape[k] for k in l1axis] l21ana = np.sqrt(np.prod(l2shape)) * np.prod(l1shape) F = functional.L21Norm(l2_axis=axis) l21num = F(x) np.testing.assert_allclose(l21ana, l21num, rtol=1e-5) l2ana = np.sqrt(np.prod(l2shape)) prxana = (l2ana - 1.0) / l2ana * x prxnum = F.prox(x, 1.0) np.testing.assert_allclose(prxana, prxnum, rtol=1e-5) def test_l2norm_blockarray(): xa = np.random.randn(2, 3, 4) xb = snp.blockarray((xa[0], xa[1])) fa = functional.L21Norm(l2_axis=(1, 2)) fb = functional.L21Norm(l2_axis=None) np.testing.assert_allclose(fa(xa), fb(xb), rtol=1e-6) ya = fa.prox(xa) yb = fb.prox(xb) np.testing.assert_allclose(ya[0], yb[0], rtol=1e-6) np.testing.assert_allclose(ya[1], yb[1], rtol=1e-6) ================================================ FILE: scico/test/functional/test_proxavg.py ================================================ import numpy as np import pytest import scico.numpy as snp from scico import functional, linop, loss, metric from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.optimize.pgm import AcceleratedPGM def test_proxavg_init(): g0 = functional.L1Norm() g1 = functional.L2Norm() with pytest.raises(ValueError): h = functional.ProximalAverage( [g0, g1], alpha_list=[ 0.1, ], ) h = functional.ProximalAverage([g0, g1], alpha_list=[0.1, 0.1]) assert sum(h.alpha_list) == 1.0 g1.has_prox = False with pytest.raises(ValueError): h = functional.ProximalAverage([g0, g1]) def test_proxavg(): N = 128 g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) y = np.sin(2 * g) y[y > 0.5] = 0.5 y[y < -0.5] = -0.5 y *= 2 y = snp.array(y) λ0 = 6e-1 λ1 = 6e-1 f = loss.SquaredL2Loss(y=y) g0 = λ0 * functional.L1Norm() g1 = λ1 * functional.L2Norm() solver = ADMM( f=f, g_list=[0.5 * g0, 0.5 * g1], 
C_list=[linop.Identity(y.shape), linop.Identity(y.shape)], rho_list=[1e1, 1e1], x0=y, maxiter=100, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-5, "maxiter": 20}), ) x_admm = solver.solve() h = functional.ProximalAverage([λ0 * functional.L1Norm(), λ1 * functional.L2Norm()]) solver = AcceleratedPGM(f=f, g=h, L0=3.4e2, x0=y, maxiter=250) x_prxavg = solver.solve() assert metric.snr(x_admm, x_prxavg) > 50 ================================================ FILE: scico/test/functional/test_separable.py ================================================ import numpy as np from jax import config # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) import pytest from scico import functional from scico.numpy import blockarray from scico.numpy.testing import assert_allclose from scico.random import randn class SeparableTestObject: def __init__(self, dtype): self.f = functional.L1Norm() self.g = functional.SquaredL2Norm() self.fg = functional.SeparableFunctional([self.f, self.g]) n = 4 m = 6 key = None self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval self.vb = blockarray([self.v1, self.v2]) @pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) def test_separable_obj(request): return SeparableTestObject(request.param) def test_separable_eval(test_separable_obj): fv1 = test_separable_obj.f(test_separable_obj.v1) gv2 = test_separable_obj.g(test_separable_obj.v2) fgv = test_separable_obj.fg(test_separable_obj.vb) assert_allclose(fv1 + gv2, fgv, rtol=5e-2) def test_separable_prox(test_separable_obj): alpha = 0.1 fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) out = blockarray((fv1, gv2)).ravel() assert_allclose(out, fgv.ravel(), rtol=5e-2) def test_separable_grad(test_separable_obj): # 
Tests the separable grad fv1 = test_separable_obj.f.grad(test_separable_obj.v1) gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) out = blockarray((fv1, gv2)).ravel() assert_allclose(out, fgv.ravel(), rtol=5e-2) ================================================ FILE: scico/test/functional/test_tvnorm.py ================================================ import numpy as np import pytest import scico.random from scico import functional, linop, loss, metric from scico.examples import create_circular_phantom from scico.functional._tvnorm import HaarTransform, SingleAxisHaarTransform from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.optimize.pgm import AcceleratedPGM @pytest.mark.parametrize("axis", [0, 1]) def test_single_axis_haar_transform(axis): x, key = scico.random.randn((3, 4), seed=1234) HT = SingleAxisHaarTransform(x.shape, axis=axis) np.testing.assert_allclose(2 * x, HT.T(HT(x)), rtol=1e-6) def test_haar_transform(): x, key = scico.random.randn((3, 4), seed=1234) HT = HaarTransform(x.shape) np.testing.assert_allclose(4 * x, HT.T(HT(x)), rtol=1e-6) @pytest.mark.parametrize("circular", [True, False]) def test_aniso_1d(circular): N = 128 g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) x_gt = np.sin(2 * g) x_gt[x_gt > 0.5] = 0.5 x_gt[x_gt < -0.5] = -0.5 σ = 0.02 noise, key = scico.random.randn(x_gt.shape, seed=0) y = x_gt + σ * noise λ = 5e-2 f = loss.SquaredL2Loss(y=y) C = linop.FiniteDifference( input_shape=x_gt.shape, circular=circular, append=None if circular else 0 ) g = λ * functional.L1Norm() solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[1e1], x0=y, maxiter=50, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}), ) x_tvdn = solver.solve() h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape) solver = AcceleratedPGM(f=f, g=h, L0=5e2, x0=y, maxiter=100) x_approx = solver.solve() assert metric.snr(x_tvdn, x_approx) > 50 
assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6 class Test2D: def setup_method(self): N = 32 x_gt = create_circular_phantom( (N, N), [0.6 * N, 0.4 * N, 0.2 * N, 0.1 * N], [0.25, 1, 0, 0.5] ).astype(np.float32) gr, gc = np.ogrid[0:N, 0:N] x_gt += ((gr + gc) / (4 * N)).astype(np.float32) σ = 0.02 noise, key = scico.random.randn(x_gt.shape, seed=0, dtype=np.float32) y = x_gt + σ * noise self.x_gt = x_gt self.y = y @pytest.mark.parametrize("circular", [True, False]) @pytest.mark.parametrize("tvtype", ["aniso", "iso"]) def test_2d(self, tvtype, circular): x_gt = self.x_gt y = self.y λ = 5e-2 f = loss.SquaredL2Loss(y=y) if tvtype == "aniso": g = λ * functional.L1Norm() else: g = λ * functional.L21Norm() C = linop.FiniteDifference( input_shape=x_gt.shape, circular=circular, append=None if circular else 0 ) solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[1e1], x0=y, maxiter=150, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}), ) x_tvdn = solver.solve() if tvtype == "aniso": h = λ * functional.AnisotropicTVNorm(circular=circular, input_shape=y.shape) else: h = λ * functional.IsotropicTVNorm(circular=circular, input_shape=y.shape) solver = AcceleratedPGM( f=f, g=h, L0=1e3, x0=y, maxiter=400, ) x_aprx = solver.solve() assert metric.snr(x_tvdn, x_aprx) > 50 assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6 class Test3D: def setup_method(self): N = 32 x2d = create_circular_phantom( (N, N), [0.6 * N, 0.4 * N, 0.2 * N, 0.1 * N], [0.25, 1, 0, 0.5] ).astype(np.float32) gr, gc = np.ogrid[0:N, 0:N] x2d += ((gr + gc) / (4 * N)).astype(np.float32) x_gt = np.stack((0.9 * x2d, np.zeros(x2d.shape), 1.1 * x2d), dtype=np.float32) σ = 0.02 noise, key = scico.random.randn(x_gt.shape, seed=0, dtype=np.float32) y = x_gt + σ * noise self.x_gt = x_gt self.y = y @pytest.mark.parametrize("circular", [False]) @pytest.mark.parametrize("tvtype", ["iso"]) def test_3d(self, tvtype, circular): x_gt = self.x_gt y = self.y λ = 5e-2 f = 
loss.SquaredL2Loss(y=y) if tvtype == "aniso": g = λ * functional.L1Norm() else: g = λ * functional.L21Norm() C = linop.FiniteDifference( input_shape=x_gt.shape, axes=(1, 2), circular=circular, append=None if circular else 0 ) solver = ADMM( f=f, g_list=[g], C_list=[C], rho_list=[5e0], x0=y, maxiter=150, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 25}), ) x_tvdn = solver.solve() if tvtype == "aniso": h = λ * functional.AnisotropicTVNorm( circular=circular, axes=(1, 2), input_shape=y.shape ) else: h = λ * functional.IsotropicTVNorm(circular=circular, axes=(1, 2), input_shape=y.shape) solver = AcceleratedPGM( f=f, g=h, L0=1e3, x0=y, maxiter=400, ) x_aprx = solver.solve() assert metric.snr(x_tvdn, x_aprx) > 50 assert metric.rel_res(g(C(x_tvdn)), h(x_tvdn)) < 1e-6 ================================================ FILE: scico/test/linop/test_binop.py ================================================ import operator as op import pytest import scico.numpy as snp from scico import linop from scico.operator import Abs, Operator class TestBinaryOp: def setup_method(self, method): self.input_shape = (5,) self.input_dtype = snp.float32 @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_case1(self, operator): A = linop.Convolve( snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode="same" ) B = Abs(input_shape=self.input_shape, input_dtype=self.input_dtype) assert type(operator(A, B)) == Operator assert type(operator(B, A)) == Operator assert type(operator(2.0 * A, 3.0 * B)) == Operator assert type(operator(2.0 * B, 3.0 * A)) == Operator @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_case2(self, operator): A = linop.Convolve( snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode="same" ) B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) assert type(operator(A, B)) == linop.LinearOperator assert type(operator(B, A)) == linop.LinearOperator 
assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator assert type(operator(2.0 * B, 3.0 * A)) == linop.LinearOperator @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_case3(self, operator): A = linop.SingleAxisFiniteDifference( input_shape=self.input_shape, input_dtype=self.input_dtype, circular=True ) B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) assert type(operator(A, B)) == linop.LinearOperator assert type(operator(B, A)) == linop.LinearOperator assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator assert type(operator(2.0 * B, 3.0 * A)) == linop.LinearOperator @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_case4(self, operator): A = linop.ScaledIdentity( scalar=0.5, input_shape=self.input_shape, input_dtype=self.input_dtype ) B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) assert type(operator(A, B)) == linop.ScaledIdentity assert type(operator(B, A)) == linop.ScaledIdentity assert type(operator(2.0 * A, 3.0 * B)) == linop.ScaledIdentity assert type(operator(2.0 * B, 3.0 * A)) == linop.ScaledIdentity ================================================ FILE: scico/test/linop/test_circconv.py ================================================ import operator as op import numpy as np import jax import pytest import scico.numpy as snp from scico.linop import CircularConvolve, Convolve, Diagonal from scico.random import randint, randn, uniform from scico.test.linop.test_linop import adjoint_test SHAPE_SPECS = [ ((12,), None, (3,)), # 1D ((12, 8), None, (3, 2)), # 2D ((6, 8, 12), None, (3, 2, 4)), # 3D ((2, 12, 8), 2, (3, 2)), # batching x ((12, 8), None, (2, 3, 2)), # batching h ((2, 12, 8), 2, (2, 3, 2)), # batching both # (M, N, b) x (H, W, 1) # this was the old way # (M, N, b) x (H, W) # this won't work: Luke, firm-no # (M, b, N) x (H, W) # do we even want this? 
# (M, b, N) x (b, H, W) # no, no, no ] class TestCircularConvolve: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) def test_eval(self, axes_shape_spec, input_dtype, jit): x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) x, key = randn(tuple(x_shape), dtype=input_dtype, key=key) A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) Ax = A @ x # check that a specific pixel of Ax computes an inner product between x and # (flipped, padded, shifted) h h_flipped = np.flip(h, range(-A.ndims, 0)) # flip only in the spatial dims (not batches) x_inds = (...,) + tuple( slice(-h.shape[a], None) for a in range(-A.ndims, 0) ) # bottom right corner of x Ax_inds = (...,) + tuple(-1 for _ in range(A.ndims)) sum_axes = tuple(-(a + 1) for a in range(A.ndims)) # ndims=2 -> -1, -2 np.testing.assert_allclose( np.sum(h_flipped * x[x_inds], axis=sum_axes), Ax[Ax_inds], rtol=1e-5 ) # np.testing.assert_allclose(Ax.ravel(), hx.ravel(), rtol=5e-4) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) def test_adjoint(self, axes_shape_spec, input_dtype, jit): x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) adjoint_test(A, self.key) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_scalar_left(self, axes_shape_spec, operator, jit): input_dtype = np.float32 scalar = np.float32(3.141) x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) A = 
CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) cA = operator(A, scalar) np.testing.assert_allclose(operator(A.h_dft.ravel(), scalar), cA.h_dft.ravel(), rtol=5e-5) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) @pytest.mark.parametrize("operator", [op.mul]) def test_scalar_right(self, axes_shape_spec, operator, jit): input_dtype = np.float32 scalar = np.float32(3.141) x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) cA = operator(scalar, A) np.testing.assert_allclose(operator(scalar, A.h_dft.ravel()), cA.h_dft.ravel(), rtol=5e-5) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) def test_add_sub(self, axes_shape_spec, jit): input_dtype = np.float32 x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) g, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) B = CircularConvolve(g, x_shape, ndims, input_dtype, jit=jit) np.testing.assert_allclose(A.h_dft + B.h_dft, (A + B).h_dft, rtol=5e-5) np.testing.assert_allclose(A.h_dft - B.h_dft, (A - B).h_dft, rtol=5e-5) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("jit", [True, False]) def test_matches_convolve(self, input_dtype, jit): h, key = randint(minval=0, maxval=3, shape=(3, 4), key=self.key) x, key = uniform(minval=0, maxval=1, shape=(5, 4), key=key) h = h.astype(input_dtype) x = (x <= 0.1).astype(input_dtype) # pad to m + n -1 x_pad = snp.pad(x, ((0, h.shape[0] - 1), (0, h.shape[1] - 1))) A = Convolve(h=h, input_shape=x.shape, jit=jit, input_dtype=input_dtype) B = CircularConvolve(h, input_shape=x_pad.shape, jit=jit, input_dtype=input_dtype) actual = B @ x_pad desired = A @ x np.testing.assert_allclose(actual, 
desired, atol=1e-6) @pytest.mark.parametrize( "center", [ 1, [ 1, ], snp.array([2]), ], ) @pytest.mark.parametrize("jit", [True, False]) def test_center(self, center, jit): x, key = uniform(minval=-1, maxval=1, shape=(16,), key=self.key) h = snp.array([0.5, 1.0, 0.25]) A = CircularConvolve(h=h, input_shape=x.shape, h_center=center, jit=jit) B = CircularConvolve(h=h, input_shape=x.shape, jit=jit) if isinstance(center, int): shift = -center else: shift = -center[0] np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5) @pytest.mark.parametrize("jit", [True, False]) def test_fractional_center(self, jit): """A fractional center should keep outputs real.""" x, key = uniform(minval=-1, maxval=1, shape=(4, 5), key=self.key) h, _ = uniform(minval=-1, maxval=1, shape=(2, 2), key=key) A = CircularConvolve(h=h, input_shape=x.shape, h_center=[0.1, 2.7], jit=jit) # taken from CircularConvolve._eval x_dft = snp.fft.fftn(x, axes=A.x_fft_axes) hx = snp.fft.ifftn( A.h_dft * x_dft, axes=A.ifft_axes, ) np.testing.assert_allclose(hx, snp.real(hx)) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("jit_old_op", [True, False]) @pytest.mark.parametrize("jit_new_op", [True, False]) def test_from_operator(self, axes_shape_spec, input_dtype, jit_old_op, jit_new_op): x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) x, key = randn(tuple(x_shape), dtype=input_dtype, key=key) A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit_old_op) B = CircularConvolve.from_operator(A, ndims, jit=jit_new_op) np.testing.assert_allclose(A @ x, B @ x, atol=1e-5) def test_from_operator_block_array(self): """`from_operator` should throw an exception if asked to work on an operator with blockarray inputs.""" H = Diagonal(snp.zeros(((1, 2), (3,)))) with pytest.raises(ValueError): CircularConvolve.from_operator(H) 
================================================
FILE: scico/test/linop/test_conversions.py
================================================
"""
Test methods that make one kind of Operator out of another.
"""

import numpy as np

import pytest

from scico.linop import CircularConvolve, FiniteDifference
from scico.random import randn


@pytest.mark.parametrize(
    "shape_axes",
    [
        ((3, 4), None),  # 2d
        ((3, 4, 5), None),  # 3d
        # ((3, 4, 5), [0, 2]),  # 3d specific axes -- not supported
    ],
)
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("jit_old", [False, True])
@pytest.mark.parametrize("jit_new", [False, True])
def testCircularConvolve_from_FiniteDifference(shape_axes, input_dtype, jit_old, jit_new):
    """A circular FiniteDifference, and its Gram operator, should each be
    reproducible as a CircularConvolve built with ``from_operator``."""
    input_shape, axes = shape_axes
    x, _ = randn(input_shape, dtype=input_dtype)

    # make a CircularConvolve from a FiniteDifference
    A = FiniteDifference(
        input_shape=input_shape, input_dtype=input_dtype, axes=axes, circular=True, jit=jit_old
    )

    B = CircularConvolve.from_operator(A, ndims=x.ndim, jit=jit_new)
    np.testing.assert_allclose(A @ x, B @ x, atol=1e-5)

    # try the same on the FiniteDifference Gram
    ATA = A.gram_op
    B = CircularConvolve.from_operator(ATA, ndims=x.ndim, jit=jit_new)
    np.testing.assert_allclose(ATA @ x, B @ x, atol=1e-5)


================================================
FILE: scico/test/linop/test_convolve.py
================================================
import operator as op

import numpy as np

import jax
import jax.scipy.signal as signal

import pytest

from scico.linop import Convolve, ConvolveByX, LinearOperator
from scico.random import randn
from scico.test.linop.test_linop import AbsMatOp, adjoint_test


class TestConvolve:
    def setup_method(self, method):
        self.key = jax.random.key(12345)

    @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
    @pytest.mark.parametrize("input_shape", [(16,), (16, 24)])
    @pytest.mark.parametrize("mode", ["full", "valid", "same"])
    @pytest.mark.parametrize("jit", [False, True])
    def test_eval(self, input_shape, input_dtype, mode, jit):
        # Convolve should match jax.scipy.signal.convolve for every mode; the
        # filter has one dimension per input dimension, (3,) or (3, 4).
        ndim = len(input_shape)
        filter_shape = (3, 4)[:ndim]
        x, key = randn(input_shape, dtype=input_dtype, key=self.key)
        psf, key = randn(filter_shape, dtype=input_dtype, key=key)

        A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)

        Ax = A @ x
        y = signal.convolve(x, psf, mode=mode)
        np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4)

    @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
    @pytest.mark.parametrize("input_shape", [(16,), (16, 24)])
    @pytest.mark.parametrize("mode", ["full", "valid", "same"])
    @pytest.mark.parametrize("jit", [False, True])
    def test_adjoint(self, input_shape, mode, jit, input_dtype):
        # Adjoint check via the shared ``adjoint_test`` helper.
        # NOTE(review): ``x`` is drawn but not used in this test.
        ndim = len(input_shape)
        filter_shape = (3, 4)[:ndim]
        x, key = randn(input_shape, dtype=input_dtype, key=self.key)
        psf, key = randn(filter_shape, dtype=input_dtype, key=key)

        A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)

        adjoint_test(A, self.key)


class ConvolveTestObj:
    """Shared fixture data: three 1D convolutions on length-16 input (A and B
    with length-3 filters, C with a length-5 filter), a dense 'generic'
    operator G with the same shapes as A, a test vector x and a scalar."""

    def __init__(self):
        dtype = np.float32
        key = jax.random.key(12345)

        self.psf_A, key = randn((3,), dtype=dtype, key=key)
        self.psf_B, key = randn((3,), dtype=dtype, key=key)
        self.psf_C, key = randn((5,), dtype=dtype, key=key)

        self.A = Convolve(input_shape=(16,), h=self.psf_A)
        self.B = Convolve(input_shape=(16,), h=self.psf_B)
        self.C = Convolve(input_shape=(16,), h=self.psf_C)

        # Matrix for a 'generic linop'
        m = self.A.output_shape[0]
        n = self.A.input_shape[0]
        G_mat, key = randn((m, n), dtype=dtype, key=key)
        self.G = AbsMatOp(G_mat)

        self.x, key = randn((16,), dtype=dtype, key=key)

        self.scalar = 3.141


@pytest.fixture
def testobj(request):
    # ``request`` is not used.
    yield ConvolveTestObj()


def test_init(testobj):
    # A filter/input dimension mismatch and an unknown mode are rejected with
    # ValueError; with input_dtype=None the dtype is taken from the filter.
    with pytest.raises(ValueError):
        A = Convolve(input_shape=(16, 16), h=testobj.psf_A)
    with pytest.raises(ValueError):
        A = Convolve(input_shape=(16,), h=testobj.psf_A, mode="invalid")
    A = Convolve(input_shape=(16,), input_dtype=None, h=testobj.psf_A)
    assert A.input_dtype == testobj.psf_A.dtype


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
    # Operator (*, /) scalar should equal a Convolve with the filter scaled
    # the same way.
    A = operator(testobj.A, testobj.scalar)
    x = testobj.x
    B = Convolve(input_shape=(16,), h=operator(testobj.psf_A, testobj.scalar))
    np.testing.assert_allclose(A @ x, B @ x, rtol=5e-5)


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
    # scalar * Operator should equal a Convolve with the filter scaled; the
    # division case is marked as an expected failure.
    if operator == op.truediv:
        pytest.xfail("scalar / LinearOperator is not supported")
    A = operator(testobj.scalar, testobj.A)
    x = testobj.x
    B = Convolve(input_shape=(16,), h=operator(testobj.scalar, testobj.psf_A))
    np.testing.assert_allclose(A @ x, B @ x, rtol=5e-5)


@pytest.mark.parametrize("operator", [op.add, op.sub])
def test_convolve_add_sub(testobj, operator):
    A = testobj.A
    B = testobj.B
    C = testobj.C
    x = testobj.x

    # Two operators of same size
    AB = operator(A, B)
    ABx = AB @ x
    AxBx = operator(A @ x, B @ x)
    np.testing.assert_allclose(ABx, AxBx, rtol=5e-5)

    # Two operators of different size
    with pytest.raises(ValueError):
        operator(A, C)


@pytest.mark.parametrize("operator", [op.add, op.sub])
def test_add_sub_different_mode(testobj, operator):
    # These tests get caught inside of the _wrap_add_sub input/output shape checks,
    # not the explicit mode check inside of the wrapped __add__ method
    B_same = Convolve(input_shape=(16,), h=testobj.psf_B, mode="same")
    with pytest.raises(ValueError):
        operator(testobj.A, B_same)


@pytest.mark.parametrize("operator", [op.add, op.sub])
def test_add_sum_generic_linop(testobj, operator):
    # Combine a AbsMatOp and Convolve, get a generic LinearOperator
    AG = operator(testobj.A, testobj.G)
    assert isinstance(AG, LinearOperator)

    # Check evaluation
    a = AG @ testobj.x
    b = operator(testobj.A @ testobj.x, testobj.G @ testobj.x)
    np.testing.assert_allclose(a, b, rtol=5e-5)


@pytest.mark.parametrize("operator", [op.add, op.sub])
def test_add_sum_conv(testobj, operator):
    # Combine a AbsMatOp and Convolve, get a generic LinearOperator
    AA = operator(testobj.A,
testobj.A) assert isinstance(AA, Convolve) # Check evaluation a = AA @ testobj.x b = operator(testobj.A @ testobj.x, testobj.A @ testobj.x) np.testing.assert_allclose(a, b, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_mul_div_generic_linop(testobj, operator): # not defined between Convolve and AbsMatOp with pytest.raises(TypeError): operator(testobj.A, testobj.G) def test_invalid_mode(testobj): # mode that doesn't exist with pytest.raises(ValueError): Convolve(input_shape=(16,), h=testobj.psf_A, mode="foo") def test_dimension_mismatch(testobj): with pytest.raises(ValueError): # 2-dim input shape, 1-dim filter Convolve(input_shape=(16, 16), h=testobj.psf_A) class TestConvolveByX: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_eval(self, input_shape, input_dtype, mode, jit): ndim = len(input_shape) x_shape = (3, 4)[:ndim] h, key = randn(input_shape, dtype=input_dtype, key=self.key) x, key = randn(x_shape, dtype=input_dtype, key=key) A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit) Ax = A @ h y = signal.convolve(x, h, mode=mode) np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, input_shape, mode, jit, input_dtype): ndim = len(input_shape) x_shape = (3, 4)[:ndim] x, key = randn(input_shape, dtype=input_dtype, key=self.key) x, key = randn(x_shape, dtype=input_dtype, key=key) A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit) 
adjoint_test(A, self.key) class ConvolveByXTestObj: def __init__(self): dtype = np.float32 key = jax.random.key(12345) self.x_A, key = randn((3,), dtype=dtype, key=key) self.x_B, key = randn((3,), dtype=dtype, key=key) self.x_C, key = randn((5,), dtype=dtype, key=key) self.A = ConvolveByX(input_shape=(16,), x=self.x_A) self.B = ConvolveByX(input_shape=(16,), x=self.x_B) self.C = ConvolveByX(input_shape=(16,), x=self.x_C) # Matrix for a 'generic linop' m = self.A.output_shape[0] n = self.A.input_shape[0] G_mat, key = randn((m, n), dtype=dtype, key=key) self.G = AbsMatOp(G_mat) self.h, key = randn((16,), dtype=dtype, key=key) self.scalar = 3.141 @pytest.fixture def cbx_testobj(request): yield ConvolveByXTestObj() @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_cbx_scalar_left(cbx_testobj, operator): A = operator(cbx_testobj.A, cbx_testobj.scalar) h = cbx_testobj.h B = ConvolveByX(input_shape=(16,), x=operator(cbx_testobj.x_A, cbx_testobj.scalar)) np.testing.assert_allclose(A @ h, B @ h, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_cbx_scalar_right(cbx_testobj, operator): if operator == op.truediv: pytest.xfail("scalar / LinearOperator is not supported") A = operator(cbx_testobj.scalar, cbx_testobj.A) h = cbx_testobj.h B = ConvolveByX(input_shape=(16,), x=operator(cbx_testobj.scalar, cbx_testobj.x_A)) np.testing.assert_allclose(A @ h, B @ h, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_convolve_add_sub(cbx_testobj, operator): A = cbx_testobj.A B = cbx_testobj.B C = cbx_testobj.C h = cbx_testobj.h # Two operators of same size AB = operator(A, B) ABh = AB @ h AfiltBh = operator(A @ h, B @ h) np.testing.assert_allclose(ABh, AfiltBh, rtol=5e-5) # Two operators of different size with pytest.raises(ValueError): operator(A, C) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_add_sub_different_mode(cbx_testobj, operator): # These tests get caught inside of the _wrap_add_sub 
input/output shape checks, # not the explicit mode check inside of the wrapped __add__ method B_same = ConvolveByX(input_shape=(16,), x=cbx_testobj.x_B, mode="same") with pytest.raises(ValueError): operator(cbx_testobj.A, B_same) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_add_sum_generic_linop(cbx_testobj, operator): # Combine a AbsMatOp and ConvolveByX, get a generic LinearOperator AG = operator(cbx_testobj.A, cbx_testobj.G) assert isinstance(AG, LinearOperator) # Check evaluation a = AG @ cbx_testobj.h b = operator(cbx_testobj.A @ cbx_testobj.h, cbx_testobj.G @ cbx_testobj.h) np.testing.assert_allclose(a, b, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_add_sum_conv(cbx_testobj, operator): # Combine a AbsMatOp and ConvolveByX, get a generic LinearOperator AA = operator(cbx_testobj.A, cbx_testobj.A) assert isinstance(AA, ConvolveByX) # Check evaluation a = AA @ cbx_testobj.h b = operator(cbx_testobj.A @ cbx_testobj.h, cbx_testobj.A @ cbx_testobj.h) np.testing.assert_allclose(a, b, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_mul_div_generic_linop(cbx_testobj, operator): # not defined between ConvolveByX and AbsMatOp with pytest.raises(TypeError): operator(cbx_testobj.A, cbx_testobj.G) def test_invalid_mode(cbx_testobj): # mode that doesn't exist with pytest.raises(ValueError): ConvolveByX(input_shape=(16,), x=cbx_testobj.x_A, mode="foo") def test_dimension_mismatch(cbx_testobj): with pytest.raises(ValueError): # 2-dim input shape, 1-dim xer ConvolveByX(input_shape=(16, 16), x=cbx_testobj.x_A) ================================================ FILE: scico/test/linop/test_dft.py ================================================ import numpy as np import jax import pytest import scico.numpy as snp from scico.linop import DFT from scico.random import randn from scico.test.linop.test_linop import adjoint_test class TestDFT: def setup_method(self, method): self.key = jax.random.key(12345) 
@pytest.mark.parametrize("input_shape", [(16,), (16, 4), (16, 4, 7)]) @pytest.mark.parametrize( "axes_and_shape", [ (None, None), ((0,), None), ((0,), (20,)), ((0, 2), None), ((0, 2), (20, 8)), (None, (6, 8)), ], ) @pytest.mark.parametrize("norm", [None, "backward", "ortho", "forward"]) @pytest.mark.parametrize("jit", [False, True]) def test_dft(self, input_shape, axes_and_shape, norm, jit): axes = axes_and_shape[0] axes_shape = axes_and_shape[1] # Skip bad parameter permutations if axes is not None and len(axes) >= len(input_shape): return if axes is not None and max(axes) >= len(input_shape): return if axes_shape is not None and len(axes_shape) > len(input_shape): return x, self.key = randn(input_shape, dtype=np.complex64, key=self.key) F = DFT(input_shape=input_shape, axes=axes, axes_shape=axes_shape, norm=norm, jit=jit) Fx = F @ x # Test eval snp_result = snp.fft.fftn(x, s=axes_shape, axes=axes, norm=norm).astype(np.complex64) np.testing.assert_allclose(Fx, snp_result, rtol=1e-6) # Test adjoint adjoint_test(F, self.key) # Test inverse y, self.key = randn(F.output_shape, dtype=np.complex64, key=self.key) Fiy = F.inv(y) snp_result = snp.fft.ifftn(y, s=F.inv_axes_shape, axes=axes, norm=norm).astype(np.complex64) np.testing.assert_allclose(Fiy, snp_result, rtol=1e-6) def test_axes_check(self): input_shape = (32, 48) axes = (0,) axes_shape = (40, 50) with pytest.raises(ValueError): F = DFT(input_shape=input_shape, axes=axes, axes_shape=axes_shape) ================================================ FILE: scico/test/linop/test_diag.py ================================================ import operator as op import numpy as np from jax import config import pytest # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) import jax from test_linop import adjoint_test import scico.numpy as snp from scico import linop from scico.random import randn class TestDiagonal: def setup_method(self, method): self.key = jax.random.key(12345) input_shapes = 
[(8,), (8, 12), ((3,), (4, 5))] @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", input_shapes) def test_eval(self, input_shape, diagonal_dtype): diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) x, key = randn(input_shape, dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal=diagonal) assert (D @ x).shape == D.output_shape snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) def test_eval_broadcasting(self, diagonal_dtype): # array broadcast diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key) x, key = randn((5, 1), dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal, x.shape) assert (D @ x).shape == (3, 5, 4) np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5) # blockarray broadcast diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal, x.shape) assert (D @ x).shape == ((3, 5, 4), (5, 5)) snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) # blockarray x array -> error diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) x, key = randn((5, 1), dtype=diagonal_dtype, key=key) with pytest.raises(ValueError): D = linop.Diagonal(diagonal, x.shape) # array x blockarray -> error diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key) x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) with pytest.raises(ValueError): D = linop.Diagonal(diagonal, x.shape) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", input_shapes) def test_adjoint(self, input_shape, diagonal_dtype): diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) D = linop.Diagonal(diagonal=diagonal) adjoint_test(D) 
@pytest.mark.parametrize("operator", [op.add, op.sub]) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape1", input_shapes) @pytest.mark.parametrize("input_shape2", input_shapes) def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) x, key = randn(input_shape1, dtype=diagonal_dtype, key=key) D1 = linop.Diagonal(diagonal=diagonal1) D2 = linop.Diagonal(diagonal=diagonal2) if input_shape1 != input_shape2: with pytest.raises(ValueError): a = operator(D1, D2) @ x else: a = operator(D1, D2) @ x Dnew = linop.Diagonal(operator(diagonal1, diagonal2)) b = Dnew @ x snp.testing.assert_allclose(a, b, rtol=1e-5) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape1", input_shapes) @pytest.mark.parametrize("input_shape2", input_shapes) def test_matmul(self, input_shape1, input_shape2, diagonal_dtype): diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) x, key = randn(input_shape1, dtype=diagonal_dtype, key=key) D1 = linop.Diagonal(diagonal=diagonal1) D2 = linop.Diagonal(diagonal=diagonal2) if input_shape1 != input_shape2: with pytest.raises(ValueError): D3 = D1 @ D2 else: D3 = D1 @ D2 assert isinstance(D3, linop.Diagonal) a = D3 @ x D4 = linop.Diagonal(diagonal1 * diagonal2) b = D4 @ x snp.testing.assert_allclose(a, b, rtol=1e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op_mismatch(self, operator): diagonal_dtype = np.float32 input_shape1 = (8,) input_shape2 = (12,) diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) D1 = linop.Diagonal(diagonal=diagonal1) D2 = linop.Diagonal(diagonal=diagonal2) with 
pytest.raises(ValueError): operator(D1, D2) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_scalar_right(self, operator): if operator == op.truediv: pytest.xfail("scalar / LinearOperator is not supported") diagonal_dtype = np.float32 input_shape = (8,) diagonal1, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) scalar = np.random.randn() x, key = randn(input_shape, dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal=diagonal1) scaled_D = operator(scalar, D) np.testing.assert_allclose(scaled_D @ x, operator(scalar, D @ x), rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_scalar_left(self, operator): diagonal_dtype = np.float32 input_shape = (8,) diagonal1, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) scalar = np.random.randn() x, key = randn(input_shape, dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal=diagonal1) scaled_D = operator(D, scalar) np.testing.assert_allclose(scaled_D @ x, operator(D @ x, scalar), rtol=5e-5) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) def test_gram_op(self, diagonal_dtype): input_shape = (7,) diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) D1 = linop.Diagonal(diagonal=diagonal) D2 = D1.gram_op D3 = D1.H @ D1 assert isinstance(D3, linop.Diagonal) snp.testing.assert_allclose(D2.diagonal, D3.diagonal, rtol=1e-6) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2]) def test_norm(self, diagonal_dtype, ord): input_shape = (5,) diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) D1 = linop.Diagonal(diagonal=diagonal) D2 = snp.diag(diagonal) n1 = D1.norm(ord=ord) n2 = snp.linalg.norm(D2, ord=ord) snp.testing.assert_allclose(n1, n2, rtol=1e-6) def test_norm_except(self): input_shape = (5,) diagonal, key = randn(input_shape, dtype=np.float32, key=self.key) D = 
linop.Diagonal(diagonal=diagonal) with pytest.raises(ValueError): n = D.norm(ord=3) class TestScaledIdentity: def setup_method(self, method): self.key = jax.random.key(12345) input_shapes = [(8,), (8, 12), ((3,), (4, 5))] @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", input_shapes) def test_eval(self, input_shape, input_dtype): x, key = randn(input_shape, dtype=input_dtype, key=self.key) scalar, key = randn((), dtype=input_dtype, key=key) Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype) assert (Id @ x).shape == Id.output_shape snp.testing.assert_allclose(scalar * x, Id @ x, rtol=1e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) @pytest.mark.parametrize("input_shape", input_shapes) def test_binary_op(self, input_shape, operator): input_dtype = np.float32 diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key) x, key = randn(input_shape, dtype=input_dtype, key=key) scalar, key = randn((), dtype=input_dtype, key=key) Id = linop.ScaledIdentity(scalar, input_shape=input_shape) D = linop.Diagonal(diagonal=diagonal) IdD = operator(Id, D) assert isinstance(IdD, linop.Diagonal) snp.testing.assert_allclose(IdD @ x, operator(scalar, diagonal) * x, rtol=1e-6) DId = operator(D, Id) assert isinstance(DId, linop.Diagonal) snp.testing.assert_allclose(DId @ x, operator(diagonal, scalar) * x, rtol=1e-6) def test_scale(self): input_shape = (5,) input_dtype = np.float32 scalar1, key = randn((), dtype=input_dtype, key=self.key) scalar2, key = randn((), dtype=input_dtype, key=key) x, key = randn(input_shape, dtype=input_dtype, key=self.key) Id = linop.ScaledIdentity(scalar=scalar1, input_shape=input_shape, input_dtype=input_dtype) sId = scalar2 * Id assert isinstance(sId, linop.ScaledIdentity) snp.testing.assert_allclose(sId @ x, scalar1 * scalar2 * x, rtol=1e-6) Ids = Id * scalar2 assert isinstance(Ids, linop.ScaledIdentity) snp.testing.assert_allclose(Ids @ x, 
scalar1 * scalar2 * x, rtol=1e-6) Idds = Id / scalar2 assert isinstance(Idds, linop.ScaledIdentity) snp.testing.assert_allclose(Idds @ x, x * scalar1 / scalar2, rtol=1e-6) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2]) def test_norm(self, input_dtype, ord): input_shape = (5,) scalar, key = randn((), dtype=input_dtype, key=self.key) Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype) D = linop.Diagonal( diagonal=scalar * snp.ones(input_shape), input_shape=input_shape, input_dtype=input_dtype, ) n1 = Id.norm(ord=ord) n2 = D.norm(ord=ord) snp.testing.assert_allclose(n1, n2, rtol=1e-6) def test_norm_except(self): input_shape = (5,) Id = linop.Identity(input_shape=input_shape, input_dtype=np.float32) with pytest.raises(ValueError): n = Id.norm(ord=3) class TestIdentity: def setup_method(self, method): self.key = jax.random.key(12345) input_shapes = [(8,), (8, 12), ((3,), (4, 5))] @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", input_shapes) def test_eval(self, input_shape, input_dtype): x, key = randn(input_shape, dtype=input_dtype, key=self.key) Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype) assert (Id @ x).shape == Id.output_shape snp.testing.assert_allclose(x, Id @ x, rtol=1e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) @pytest.mark.parametrize("input_shape", input_shapes) def test_binary_op(self, input_shape, operator): input_dtype = np.float32 diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key) scalar, key = randn((), dtype=input_dtype, key=key) x, key = randn(input_shape, dtype=input_dtype, key=key) Id = linop.Identity(input_shape=input_shape) Ids = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape) D = linop.Diagonal(diagonal=diagonal) IdD = operator(Id, D) assert isinstance(IdD, linop.Diagonal) 
snp.testing.assert_allclose(IdD @ x, operator(1.0, diagonal) * x, rtol=1e-6) DId = operator(D, Id) assert isinstance(DId, linop.Diagonal) snp.testing.assert_allclose(DId @ x, operator(diagonal, 1.0) * x, rtol=1e-6) IdIds = operator(Id, Ids) assert isinstance(IdIds, linop.ScaledIdentity) snp.testing.assert_allclose(IdIds @ x, operator(1.0, scalar) * x, rtol=1e-6) IdsId = operator(Ids, Id) assert isinstance(IdsId, linop.ScaledIdentity) snp.testing.assert_allclose(IdsId @ x, operator(scalar, 1.0) * x, rtol=1e-6) def test_scale(self): input_shape = (5,) input_dtype = np.float32 scalar, key = randn((), dtype=input_dtype, key=self.key) x, key = randn(input_shape, dtype=input_dtype, key=key) Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype) sId = scalar * Id assert isinstance(sId, linop.ScaledIdentity) snp.testing.assert_allclose(sId @ x, scalar * x, rtol=1e-6) Ids = Id * scalar assert isinstance(Ids, linop.ScaledIdentity) snp.testing.assert_allclose(Ids @ x, scalar * x, rtol=1e-6) Idds = Id / scalar assert isinstance(Idds, linop.ScaledIdentity) snp.testing.assert_allclose(Idds @ x, x / scalar, rtol=1e-6) ================================================ FILE: scico/test/linop/test_diff.py ================================================ import numpy as np import pytest import scico.numpy as snp from scico.linop import FiniteDifference, SingleAxisFiniteDifference from scico.random import randn from scico.test.linop.test_linop import adjoint_test def test_eval(): with pytest.raises(ValueError): # axis 3 does not exist A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3)) A = FiniteDifference(input_shape=(2, 3), append=1) x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32) Ax = A @ x snp.testing.assert_allclose( Ax[0], # down columns x[1] - x[0], ..., append - x[N-1] snp.array([[0, 1, -1], [-1, -1, 0]]), ) snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows # test scale B = 2.0 * A Bx = B @ x 
snp.testing.assert_allclose( Bx[0], # down columns x[1] - x[0], ..., append - x[N-1] 2.0 * snp.array([[0, 1, -1], [-1, -1, 0]]), ) snp.testing.assert_allclose(Bx[1], 2.0 * snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows def test_except(): with pytest.raises(TypeError): # axis is not an int A = SingleAxisFiniteDifference(input_shape=(3,), axis=2.5) with pytest.raises(ValueError): # invalid parameter combination A = SingleAxisFiniteDifference(input_shape=(3,), prepend=0, circular=True) with pytest.raises(ValueError): # invalid prepend value A = SingleAxisFiniteDifference(input_shape=(3,), prepend=2) with pytest.raises(ValueError): # invalid append value A = SingleAxisFiniteDifference(input_shape=(3,), append="a") def test_eval_prepend(): x = snp.arange(1, 6) A = SingleAxisFiniteDifference(input_shape=(5,), prepend=0) snp.testing.assert_allclose(A @ x, snp.array([0, 1, 1, 1, 1])) A = SingleAxisFiniteDifference(input_shape=(5,), prepend=1) snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 1])) def test_eval_append(): x = snp.arange(1, 6) A = SingleAxisFiniteDifference(input_shape=(5,), append=0) snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 0])) A = SingleAxisFiniteDifference(input_shape=(5,), append=1) snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, -5])) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(input_shape, input_dtype, axes, jit): ndim = len(input_shape) if axes in [1, (1,)] and ndim == 1: return A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) adjoint_test(A) @pytest.mark.parametrize( "shape_axes", [ ((3, 4), None), # 2d ((3, 4), 0), # 2d specific axis ((3, 4, 5), None), # 3d ((3, 4, 5), [0, 2]), # 3d specific axes ], ) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) 
@pytest.mark.parametrize("jit", [False, True]) def test_eval_circular(shape_axes, input_dtype, jit): input_shape, axes = shape_axes x, _ = randn(input_shape, dtype=input_dtype) A = FiniteDifference( input_shape=input_shape, input_dtype=input_dtype, axes=axes, circular=True, jit=jit ) Ax = A @ x # check that correct differences are returned for ax in A.axes: np.testing.assert_allclose(np.roll(x, -1, ax) - x, Ax[ax], atol=1e-5, rtol=0) # check that the all results match noncircular results except at the last pixel B = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) Bx = B @ x for ax_ind, ax in enumerate(A.axes): np.testing.assert_allclose( Ax[ (ax_ind,) + tuple(slice(0, -1) if i == ax else slice(None) for i in range(len(input_shape))) ], Bx[ax_ind], atol=1e-5, rtol=0, ) ================================================ FILE: scico/test/linop/test_func.py ================================================ import numpy as np import pytest import scico.numpy as snp from scico import linop from scico.random import randn from scico.test.linop.test_linop import adjoint_test def test_transpose(): shape = (1, 2, 3, 4) perm = (1, 0, 3, 2) x, _ = randn(shape) H = linop.Transpose(shape, perm) np.testing.assert_array_equal(H @ x, x.transpose(perm)) # transpose transpose is transpose inverse np.testing.assert_array_equal(H.T @ H @ x, x) def test_transpose_ext_init(): shape = (1, 2, 3, 4) perm = (1, 0, 3, 2) x, _ = randn(shape) H = linop.Transpose( shape, perm, input_dtype=snp.float32, output_shape=shape, output_dtype=snp.float32 ) np.testing.assert_array_equal(H @ x, x.transpose(perm)) def test_reshape(): shape = (1, 2, 3, 4) newshape = (2, 12) x, _ = randn(shape) H = linop.Reshape(shape, newshape) np.testing.assert_array_equal(H @ x, x.reshape(newshape)) # reshape reshape is reshape inverse np.testing.assert_array_equal(H.T @ H @ x, x) def test_pad(): shape = (2, 3, 4) pad = 1 x, _ = randn(shape) H = linop.Pad(shape, pad) pad_shape = tuple(n + 
2 * pad for n in shape) y = snp.zeros(pad_shape) y = y.at[pad:-pad, pad:-pad, pad:-pad].set(x) np.testing.assert_array_equal(H @ x, y) # pad transpose is crop y, _ = randn(pad_shape) np.testing.assert_array_equal(H.T @ y, y[pad:-pad, pad:-pad, pad:-pad]) def test_crop(): shape = (7, 9) crop = (1, 2) x, _ = randn(shape) H = linop.Crop(crop, shape) y = x[crop[0] : -crop[1], crop[0] : -crop[1]] np.testing.assert_array_equal(H @ x, y) @pytest.mark.parametrize("pad", [1, (1, 2), ((1, 0), (0, 1)), ((1, 1), (2, 2))]) def test_crop_pad_adjoint(pad): shape = (9, 10) H = linop.Pad(shape, pad) G = linop.Crop(pad, H.output_shape) assert linop.valid_adjoint(H, G, eps=1e-5) class SliceTestObj: def __init__(self, dtype): self.x = snp.zeros((4, 5, 6, 7), dtype=dtype) @pytest.fixture(scope="module", params=[np.float32, np.complex64]) def slicetestobj(request): yield SliceTestObj(request.param) slice_examples = [ np.s_[1:], np.s_[:, 2:], np.s_[..., 3:], np.s_[1:, :-3], np.s_[1:, :, :3], np.s_[1:, ..., 2:], np.s_[np.newaxis], np.s_[:, np.newaxis], ] @pytest.mark.parametrize("idx", slice_examples) def test_slice_eval(slicetestobj, idx): x = slicetestobj.x A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype) assert (A @ x).shape == x[idx].shape @pytest.mark.parametrize("idx", slice_examples) def test_slice_adj(slicetestobj, idx): x = slicetestobj.x A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype) adjoint_test(A) block_slice_examples = [ 1, np.s_[0:1], np.s_[:1], ] @pytest.mark.parametrize("idx", block_slice_examples) def test_slice_blockarray(idx): x = snp.BlockArray((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6)))) A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype) assert (A @ x).shape == x[idx].shape ================================================ FILE: scico/test/linop/test_grad.py ================================================ from itertools import combinations import numpy as np import jax import pytest import scico.numpy as snp from 
scico.linop import ( CylindricalGradient, PolarGradient, ProjectedGradient, SphericalGradient, ) from scico.numpy import Array from scico.random import randn def test_proj_grad(): x = snp.ones((4, 5)) P = ProjectedGradient(x.shape, axes=(0,)) assert P(x).shape == (4, 5) P = ProjectedGradient(x.shape) assert P(x).shape == (2, 4, 5) P = ProjectedGradient(x.shape, coord=(np.arange(0, 4)[:, np.newaxis],)) assert P(x).shape == (4, 5) coord = ( snp.blockarray([snp.array([0.0]), snp.array([1.0])]), snp.blockarray([snp.array([1.0]), snp.array([0.0])]), ) P = ProjectedGradient(x.shape, coord=coord) assert P(x).shape == (2, 4, 5) class TestPolarGradient: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("outflags", [(True, True), (True, False), (False, True)]) @pytest.mark.parametrize("center", [None, (-2, 3), (1.2, -3.5)]) @pytest.mark.parametrize( "shape_axes", [ ((20, 20), None), ((20, 21), (0, 1)), ((16, 17, 3), (0, 1)), ((2, 17, 16), (1, 2)), ((2, 17, 16, 3), (2, 1)), ], ) @pytest.mark.parametrize("cdiff", [True, False]) def test_eval(self, cdiff, shape_axes, center, outflags, input_dtype, jit): input_shape, axes = shape_axes if axes is None: testaxes = (0, 1) else: testaxes = axes if center is not None: axes_shape = [input_shape[ax] for ax in testaxes] center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) angular, radial = outflags x, key = randn(input_shape, dtype=input_dtype, key=self.key) A = PolarGradient( input_shape, axes=axes, center=center, angular=angular, radial=radial, cdiff=cdiff, input_dtype=input_dtype, jit=jit, ) Ax = A @ x assert isinstance(Ax, Array) if angular and radial: assert Ax.shape[0] == 2 assert Ax.shape[1:] == input_shape else: assert Ax.shape == input_shape assert Ax.dtype == input_dtype # Test orthogonality of coordinate axes coord = A.coord for n0, n1 in 
combinations(range(len(coord)), 2): c0 = coord[n0] c1 = coord[n1] assert snp.abs(snp.sum(c0 * c1)) < 1e-5 class TestCylindricalGradient: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize( "outflags", [ (True, True, True), (True, True, False), (True, False, True), (True, False, False), (False, True, True), (False, True, False), (False, False, True), ], ) @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) @pytest.mark.parametrize( "shape_axes", [ ((20, 20, 20), None), ((17, 18, 19), (0, 1, 2)), ((16, 17, 18, 3), (0, 1, 2)), ((2, 17, 16, 15), (1, 2, 3)), ((17, 2, 16, 15), (0, 2, 3)), ((17, 2, 16, 15), (3, 2, 0)), ], ) def test_eval(self, shape_axes, center, outflags, input_dtype, jit): input_shape, axes = shape_axes if axes is None: testaxes = (0, 1, 2) else: testaxes = axes if center is not None: axes_shape = [input_shape[ax] for ax in testaxes] center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) angular, radial, axial = outflags x, key = randn(input_shape, dtype=input_dtype, key=self.key) A = CylindricalGradient( input_shape, axes=axes, center=center, angular=angular, radial=radial, axial=axial, input_dtype=input_dtype, jit=jit, ) Ax = A @ x assert isinstance(Ax, Array) Nc = sum([angular, radial, axial]) if Nc > 1: assert Ax.shape[0] == Nc assert Ax.shape[1:] == input_shape else: assert Ax.shape == input_shape assert Ax.dtype == input_dtype # Test orthogonality of coordinate axes coord = A.coord for n0, n1 in combinations(range(len(coord)), 2): c0 = coord[n0] c1 = coord[n1] s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() assert snp.abs(s) < 1e-5 class TestSphericalGradient: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize( 
"outflags", [ (True, True, True), (True, True, False), (True, False, True), (True, False, False), (False, True, True), (False, True, False), (False, False, True), ], ) @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) @pytest.mark.parametrize( "shape_axes", [ ((20, 20, 20), None), ((17, 18, 19), (0, 1, 2)), ((16, 17, 18, 3), (0, 1, 2)), ((2, 17, 16, 15), (1, 2, 3)), ((17, 2, 16, 15), (0, 2, 3)), ((17, 2, 16, 15), (3, 2, 0)), ], ) def test_eval(self, shape_axes, center, outflags, input_dtype, jit): input_shape, axes = shape_axes if axes is None: testaxes = (0, 1, 2) else: testaxes = axes if center is not None: axes_shape = [input_shape[ax] for ax in testaxes] center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) azimuthal, polar, radial = outflags x, key = randn(input_shape, dtype=input_dtype, key=self.key) A = SphericalGradient( input_shape, axes=axes, center=center, azimuthal=azimuthal, polar=polar, radial=radial, input_dtype=input_dtype, jit=jit, ) Ax = A @ x assert isinstance(Ax, Array) Nc = sum([azimuthal, polar, radial]) if Nc > 1: assert Ax.shape[0] == Nc assert Ax.shape[1:] == input_shape else: assert Ax.shape == input_shape assert Ax.dtype == input_dtype # Test orthogonality of coordinate axes coord = A.coord for n0, n1 in combinations(range(len(coord)), 2): c0 = coord[n0] c1 = coord[n1] s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() assert snp.abs(s) < 1e-5 ================================================ FILE: scico/test/linop/test_linop.py ================================================ import operator as op import numpy as np from jax import config import pytest # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) from typing import Optional import jax import scico.numpy as snp from scico import linop from scico.random import randn from scico.typing import PRNGKey SCALARS = (2, 1e0, snp.array(1.0)) def adjoint_test( A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 
1e-4, x: Optional[snp.Array] = None, y: Optional[snp.Array] = None, ): """Check the validity of A.conj().T as the adjoint for a LinearOperator A. Args: A: LinearOperator to test. key: PRNGKey for generating `x`. rtol: Relative tolerance. """ assert linop.valid_adjoint(A, A.H, key=key, eps=rtol, x=x, y=y) class AbsMatOp(linop.LinearOperator): """Simple LinearOperator subclass for testing purposes. Similar to linop.MatrixOperator, but does not use the specialized MatrixOperator methods (.T, adj, etc). Used to verify the LinearOperator interface. """ def __init__(self, A, adj_fn=None): self.A = A super().__init__( input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=A.dtype, adj_fn=adj_fn ) def _eval(self, x): return self.A @ x class LinearOperatorTestObj: def __init__(self, dtype): M, N = (8, 16) key = jax.random.key(12345) self.dtype = dtype self.A, key = randn((M, N), dtype=dtype, key=key) self.B, key = randn((M, N), dtype=dtype, key=key) self.C, key = randn((N, M), dtype=dtype, key=key) self.D, key = randn((M, N - 1), dtype=dtype, key=key) self.x, key = randn((N,), dtype=dtype, key=key) self.y, key = randn((M,), dtype=dtype, key=key) self.Ao = AbsMatOp(self.A) self.Bo = AbsMatOp(self.B) self.Co = AbsMatOp(self.C) self.Do = AbsMatOp(self.D) @pytest.fixture(scope="module", params=[np.float32, np.float64, np.complex64, np.complex128]) def testobj(request): yield LinearOperatorTestObj(request.param) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op(testobj, operator): # Our AbsMatOp class does not override the __add__, etc # so AbsMatOp + AbsMatOp -> LinearOperator # So to verify results, we evaluate the new LinearOperator on a random input comp_mat = operator(testobj.A, testobj.B) # composite matrix comp_op = operator(testobj.Ao, testobj.Bo) # composite linop assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op 
@ testobj.x, rtol=0, atol=1e-5) # linops of different sizes with pytest.raises(ValueError): operator(testobj.Ao, testobj.Co) with pytest.raises(ValueError): operator(testobj.Ao, testobj.Do) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) @pytest.mark.parametrize("scalar", SCALARS) def test_scalar_left(testobj, operator, scalar): comp_mat = operator(testobj.A, scalar) comp_op = operator(testobj.Ao, scalar) assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) np.testing.assert_allclose(comp_mat.conj().T @ testobj.y, comp_op.adj(testobj.y), rtol=2e-4) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) @pytest.mark.parametrize("scalar", SCALARS) def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / LinearOperator is not supported") comp_mat = operator(scalar, testobj.A) comp_op = operator(scalar, testobj.Ao) assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) def test_negation(testobj): comp_mat = -testobj.A comp_op = -testobj.Ao assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_invalid_add_sub_array(testobj, operator): # Try to add or subtract an ndarray with AbsMatOp with pytest.raises(TypeError): operator(testobj.A, testobj.Ao) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_invalid_add_sub_scalar(testobj, operator): # Try to add or subtract a scalar with AbsMatOp with pytest.raises(TypeError): operator(1.0, testobj.Ao) def test_matmul_left(testobj): comp_mat = testobj.A @ testobj.C comp_op = testobj.Ao @ testobj.Co assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.y, comp_op @ 
testobj.y, rtol=5e-5) def test_matmul_right(testobj): comp_mat = testobj.C @ testobj.A comp_op = testobj.Co @ testobj.Ao assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) def test_matvec_left(testobj): comp_mat = testobj.A @ testobj.x comp_op = testobj.Ao @ testobj.x assert comp_op.dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat, comp_op, rtol=5e-5) def test_matvec_right(testobj): comp_mat = testobj.C @ testobj.y comp_op = testobj.Co @ testobj.y assert comp_op.dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat, comp_op, rtol=5e-5) def test_gram(testobj): Ao = testobj.Ao a = Ao.gram(testobj.x) b = Ao.conj().T @ Ao @ testobj.x c = Ao.gram_op @ testobj.x comp_mat = testobj.A.conj().T @ testobj.A @ testobj.x np.testing.assert_allclose(a, comp_mat, rtol=5e-5) np.testing.assert_allclose(b, comp_mat, rtol=5e-5) np.testing.assert_allclose(c, comp_mat, rtol=5e-5) def test_matvec_call(testobj): # A @ x and A(x) should return same np.testing.assert_allclose(testobj.Ao @ testobj.x, testobj.Ao(testobj.x), rtol=5e-5) def test_adj_composition(testobj): Ao = testobj.Ao Bo = testobj.Bo A = testobj.A B = testobj.B x = testobj.x comp_mat = A.conj().T @ B a = Ao.conj().T @ Bo b = Ao.adj(Bo) assert a.input_dtype == testobj.A.dtype assert b.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ x, a @ x, rtol=5e-5) np.testing.assert_allclose(comp_mat @ x, b @ x, rtol=5e-5) def test_transpose_matvec(testobj): Ao = testobj.Ao y = testobj.y a = Ao.T @ y b = y.T @ Ao comp_mat = testobj.A.T @ y assert a.dtype == testobj.A.dtype assert b.dtype == testobj.A.dtype np.testing.assert_allclose(a, comp_mat, rtol=2e-4) np.testing.assert_allclose(a, b, rtol=5e-5) def test_transpose_matmul(testobj): Ao = testobj.Ao Bo = testobj.Bo x = testobj.x comp_op = Ao.T @ Bo comp_mat = testobj.A.T @ testobj.B assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ x, 
comp_op @ x, rtol=5e-5) def test_conj_transpose_matmul(testobj): Ao = testobj.Ao Bo = testobj.Bo x = testobj.x comp_op = Ao.conj().T @ Bo comp_mat = testobj.A.conj().T @ testobj.B assert comp_mat.dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ x, comp_op @ x, rtol=5e-5) def test_conj_matvec(testobj): Ao = testobj.Ao x = testobj.x a = Ao.conj() @ x comp_mat = testobj.A.conj() @ x assert a.dtype == testobj.A.dtype np.testing.assert_allclose(a, comp_mat, rtol=5e-5) def test_adjoint_matvec(testobj): Ao = testobj.Ao y = testobj.y a = Ao.adj(y) b = Ao.conj().T @ y c = (y.conj().T @ Ao).conj() comp_mat = testobj.A.conj().T @ y assert a.dtype == testobj.A.dtype assert b.dtype == testobj.A.dtype assert c.dtype == testobj.A.dtype np.testing.assert_allclose(a, comp_mat, rtol=2e-4) np.testing.assert_allclose(a, b, rtol=5e-5) np.testing.assert_allclose(a, c, rtol=5e-5) def test_adjoint_matmul(testobj): # shape mismatch Ao = testobj.Ao Co = testobj.Co with pytest.raises(ValueError): Ao.adj(Co) def test_hermitian(testobj): Ao = testobj.Ao y = testobj.y np.testing.assert_allclose(Ao.conj().T @ y, Ao.H @ y) def test_shape(testobj): Ao = testobj.Ao x = testobj.x y = testobj.y with pytest.raises(ValueError): _ = Ao @ y with pytest.raises(ValueError): _ = Ao(y) with pytest.raises(ValueError): _ = Ao.T @ x with pytest.raises(ValueError): _ = Ao.adj(x) def test_adj_lazy(): dtype = np.float32 M, N = (8, 16) A, key = randn((M, N), dtype=np.float32, key=None) y, key = randn((M,), dtype=np.float32, key=key) Ao = AbsMatOp(A, adj_fn=None) # defer setting the linop assert Ao._adj is None a = Ao.adj(y) # Adjoint is set when .adj() is called b = A.T @ y np.testing.assert_allclose(a, b, rtol=1e-5) def test_jit_adj_lazy(): dtype = np.float32 M, N = (8, 16) A, key = randn((M, N), dtype=np.float32, key=None) y, key = randn((M,), dtype=np.float32, key=key) Ao = AbsMatOp(A, adj_fn=None) # defer setting the linop assert Ao._adj is None Ao.jit() # Adjoint set here assert Ao._adj is not 
None a = Ao.adj(y) b = A.T @ y np.testing.assert_allclose(a, b, rtol=1e-5) ================================================ FILE: scico/test/linop/test_linop_stack.py ================================================ import numpy as np import jax import pytest import scico.numpy as snp from scico.linop import ( Convolve, DiagonalReplicated, DiagonalStack, Identity, Sum, VerticalStack, ) from scico.operator import Abs from scico.random import randn from scico.test.linop.test_linop import adjoint_test class TestVerticalStack: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [False, True]) def test_construct(self, jit): # requires a list of LinearOperators Id = Identity((42,)) with pytest.raises(TypeError): H = VerticalStack(Id, jit=jit) # requires all list elements to be LinearOperators A = Abs((42,)) with pytest.raises(TypeError): H = VerticalStack((A, Id), jit=jit) # checks input sizes A = Identity((3, 2)) B = Identity((7, 2)) with pytest.raises(ValueError): H = VerticalStack([A, B], jit=jit) # in general, returns a BlockArray A = Convolve(snp.ones((3, 3)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x assert y.shape == ((9, 13), (8, 12)) # ... result should be [A@x, B@x] assert np.allclose(y[0], A @ x) assert np.allclose(y[1], B @ x) # by default, collapse_output to jax array when possible A = Convolve(snp.ones((2, 2)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], jit=jit) x = np.ones((7, 11)) y = H @ x assert y.shape == (2, 8, 12) # ... 
result should be [A@x, B@x] assert np.allclose(y[0], A @ x) assert np.allclose(y[1], B @ x) # let user turn off collapsing A = Convolve(snp.ones((2, 2)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse_output=False, jit=jit) x = np.ones((7, 11)) y = H @ x assert y.shape == ((8, 12), (8, 12)) @pytest.mark.parametrize("collapse_output", [False, True]) @pytest.mark.parametrize("jit", [False, True]) def test_adjoint(self, collapse_output, jit): # general case A = Convolve(snp.ones((3, 3)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) adjoint_test(H, self.key) # collapsable case A = Convolve(snp.ones((2, 2)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) adjoint_test(H, self.key) @pytest.mark.parametrize("collapse_output", [False, True]) @pytest.mark.parametrize("jit", [False, True]) def test_algebra(self, collapse_output, jit): # adding A = Convolve(snp.ones((2, 2)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) x = np.ones((7, 11)) S = H + G # test correctness of addition assert S.output_shape == H.output_shape assert S.input_shape == H.input_shape np.testing.assert_allclose((S @ x)[0], (H @ x + G @ x)[0]) np.testing.assert_allclose((S @ x)[1], (H @ x + G @ x)[1]) class TestBlockDiagonalLinearOperator: def test_construct(self): Id = Identity((42,)) A = Abs((42,)) with pytest.raises(TypeError): H = DiagonalStack((A, Id)) def test_apply(self): S1 = (3, 4) S2 = (3, 5) S3 = (2, 2) A1 = Identity(S1) A2 = 2 * Identity(S2) A3 = Sum(S3) H = DiagonalStack((A1, A2, A3)) x = snp.ones((S1, S2, S3)) y = H @ x y_expected = snp.blockarray((snp.ones(S1), 2 * 
snp.ones(S2), snp.sum(snp.ones(S3)))) np.testing.assert_equal(y, y_expected) def test_adjoint(self): S1 = (3, 4) S2 = (3, 5) S3 = (2, 2) A1 = Identity(S1) A2 = 2 * Identity(S2) A3 = Sum(S3) H = DiagonalStack((A1, A2, A3)) y = snp.ones((S1, S2, ()), dtype=snp.float32) x = H.T @ y x_expected = snp.blockarray( ( snp.ones(S1), snp.ones(S2), snp.ones(S3), ) ) assert x == x_expected def test_input_collapse(self): S = (3, 4) A1 = Identity(S) A2 = Sum(S) H = DiagonalStack((A1, A2)) assert H.input_shape == (2, *S) H = DiagonalStack((A1, A2), collapse_input=False) assert H.input_shape == (S, S) def test_output_collapse(self): S1 = (3, 4) S2 = (5, 3, 4) A1 = Identity(S1) A2 = Sum(S2, axis=0) H = DiagonalStack((A1, A2)) assert H.output_shape == (2, *S1) H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (S1, S1) class TestDiagonalReplicated: def setup_method(self, method): self.key = jax.random.key(12345) def test_adjoint(self): x, key = randn((2, 3, 4), key=self.key) A = Sum(x.shape[1:], axis=-1) D = DiagonalReplicated(A, x.shape[0]) y = D.T(D(x)) np.testing.assert_allclose(y[0], A.T(A(x[0]))) np.testing.assert_allclose(y[1], A.T(A(x[1]))) ================================================ FILE: scico/test/linop/test_linop_util.py ================================================ import numpy as np from jax import config import pytest # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) import jax import scico.numpy as snp from scico import linop from scico.operator import Operator from scico.random import randn from scico.test.linop.test_linop import AbsMatOp def test_valid_adjoint(): diagonal, key = randn((10,), dtype=np.float32) D = linop.Diagonal(diagonal=diagonal) assert linop.valid_adjoint(D, D.T, key=key, eps=None) < 1e-4 x, key = randn((5,), dtype=np.float32) y, key = randn((5,), dtype=np.float32) with pytest.raises(ValueError): linop.valid_adjoint(D, D.T, key=key, x=x) with pytest.raises(ValueError): 
linop.valid_adjoint(D, D.T, key=key, y=y) class PowerIterTestObj: def __init__(self, dtype): M, N = (4, 4) key = jax.random.key(12345) self.dtype = dtype A, key = randn((M, N), dtype=dtype, key=key) self.A = A.conj().T @ A # ensure symmetric self.Ao = linop.MatrixOperator(self.A) self.Bo = AbsMatOp(self.A) self.key = key self.ev = snp.linalg.norm( self.A, 2 ) # The largest eigenvalue of A is the spectral norm of A @pytest.fixture(scope="module", params=[np.float32, np.complex64]) def pitestobj(request): yield PowerIterTestObj(request.param) def test_power_iteration(pitestobj): """Verify that power iteration calculates largest eigenvalue for real and complex symmetric matrices. """ # Test using the LinearOperator MatrixOperator mu, v = linop.power_iteration(A=pitestobj.Ao, maxiter=100, key=pitestobj.key) assert np.abs(mu - pitestobj.ev) < 1e-4 # Test using the AbsMatOp for test_linop.py mu, v = linop.power_iteration(A=pitestobj.Bo, maxiter=100, key=pitestobj.key) assert np.abs(mu - pitestobj.ev) < 1e-4 def test_operator_norm(): Iop = linop.Identity(8) Inorm = linop.operator_norm(Iop) assert np.abs(Inorm - 1.0) < 1e-5 key = jax.random.key(12345) for dtype in [np.float32, np.complex64]: d, key = randn((16,), dtype=dtype, key=key) D = linop.Diagonal(d) Dnorm = linop.operator_norm(D) assert np.abs(Dnorm - snp.abs(d).max()) < 1e-5 Zop = linop.MatrixOperator(snp.zeros((3, 3))) Znorm = linop.operator_norm(Zop) assert np.abs(Znorm) < 1e-6 @pytest.mark.parametrize("dtype", [snp.float32, snp.complex64]) @pytest.mark.parametrize("inc_eval", [True, False]) def test_jacobian(dtype, inc_eval): N = 7 M = 8 key = None fmx, key = randn((M, N), key=key, dtype=dtype) F = Operator( (N, 1), output_shape=(M, 1), eval_fn=lambda x: fmx @ x, input_dtype=dtype, output_dtype=dtype, ) u, key = randn((N, 1), key=key, dtype=dtype) v, key = randn((N, 1), key=key, dtype=dtype) w, key = randn((M, 1), key=key, dtype=dtype) J = linop.jacobian(F, u, include_eval=inc_eval) Jv = J(v) JHw = J.H(w) if 
inc_eval: np.testing.assert_allclose(Jv[0], F(u)) np.testing.assert_allclose(Jv[1], F.jvp(u, v)[1]) np.testing.assert_allclose(JHw[0], F(u)) np.testing.assert_allclose(JHw[1], F.vjp(u)[1](w)) else: np.testing.assert_allclose(Jv, F.jvp(u, v)[1]) np.testing.assert_allclose(JHw, F.vjp(u)[1](w)) ================================================ FILE: scico/test/linop/test_matrix.py ================================================ import operator as op import numpy as np import jax import jax.numpy as jnp import pytest import scico.numpy as snp from scico import linop from scico.linop import MatrixOperator from scico.random import randn from scico.test.linop.test_linop import AbsMatOp class TestMatrix: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_eval(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(A @ x, Ao @ x) # Invalid shapes with pytest.raises(TypeError): y, key = randn((64,), dtype=Ao.input_dtype, key=key) _ = Ao @ y @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_adjoint(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(A.conj().T @ x, Ao.conj().T @ x) @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def 
test_adjoint_method(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(Ao.adj(x), Ao.conj().T @ x) @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_hermetian_method(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) x, key = randn(Ao.output_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(Ao.H @ x, Ao.conj().T @ x) @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_gram_method(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(Ao.gram(x), A.conj().T @ A @ x, rtol=5e-5) @pytest.mark.parametrize("input_cols", [0, 2]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_gram_op(self, matrix_shape, input_dtype, input_cols): A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) G = Ao.gram_op x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(G @ x, A.conj().T @ A @ x, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_add_sub(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) C, key = randn((4, 4), key=key) x, key = randn((6,), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) Co 
= MatrixOperator(C) ABx = operator(Ao, Bo) @ x AxBx = operator(Ao @ x, Bo @ x) np.testing.assert_allclose(ABx, AxBx, rtol=5e-5) with pytest.raises(ValueError): operator(Ao, Co) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_scalar_left(self, operator): scalar = np.float32(np.random.randn()) A, key = randn((4, 6), key=self.key) x, key = randn((6,), key=key) Ao = MatrixOperator(A) np.testing.assert_allclose(operator(scalar, Ao) @ x, operator(scalar, A) @ x, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_scalar_right(self, operator): scalar = np.float32(np.random.randn()) A, key = randn((4, 6), key=self.key) x, key = randn((6,), key=key) Ao = MatrixOperator(A) np.testing.assert_allclose(operator(Ao, scalar) @ x, operator(A, scalar) @ x, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_elementwise_matops(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) np.testing.assert_allclose(operator(Ao, Bo).A, operator(A, B), rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_elementwise_array_left(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) np.testing.assert_allclose(operator(Ao, B).A, operator(A, B), rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_elementwise_array_right(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) np.testing.assert_allclose(operator(A, Bo).A, operator(A, B), rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_elementwise_matop_shape_mismatch(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 4), key=key) Ao = MatrixOperator(A) 
Bo = MatrixOperator(B) with pytest.raises(ValueError): operator(Ao, Bo) @pytest.mark.parametrize("operator", [op.add, op.sub, op.mul, op.truediv]) def test_elementwise_array_shape_mismatch(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 4), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) with pytest.raises(ValueError): operator(Ao, B) with pytest.raises(ValueError): operator(B, Ao) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_elementwise_linop(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) Ao = MatrixOperator(A) Bo = AbsMatOp(B) x, key = randn(Ao.input_shape, dtype=Ao.input_dtype, key=key) np.testing.assert_allclose(operator(Ao, Bo) @ x, operator(Ao @ x, Bo @ x), rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_elementwise_linop_mismatch(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 4), key=key) Ao = MatrixOperator(A) Bo = AbsMatOp(B) with pytest.raises(ValueError): operator(Ao, Bo) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) def test_elementwise_linop_invalid(self, operator): A, key = randn((4, 6), key=self.key) B, key = randn((4, 6), key=key) Ao = MatrixOperator(A) Bo = AbsMatOp(B) with pytest.raises(TypeError): operator(Ao, Bo) with pytest.raises(TypeError): operator(Bo, Ao) def test_matmul(self): A, key = randn((4, 6), key=self.key) B, key = randn((6, 3), key=key) Ao = MatrixOperator(A) Bo = MatrixOperator(B) x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key) AB = Ao @ Bo np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5) def test_matmul_cols(self): A, key = randn((4, 6), key=self.key) B, key = randn((6, 3), key=key) Ao = MatrixOperator(A, input_cols=2) Bo = MatrixOperator(B, input_cols=2) x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key) AB = Ao @ Bo np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5) def test_matmul_linop(self): A, key = randn((4, 
6), key=self.key) B, key = randn((6, 3), key=key) Ao = MatrixOperator(A) Bo = AbsMatOp(B) x, key = randn(Bo.input_shape, dtype=Ao.input_dtype, key=key) AB = Ao @ Bo np.testing.assert_allclose((Ao @ Bo) @ x, Ao @ (Bo @ x), rtol=5e-5) def test_matmul_linop_shape_mismatch(self): A, key = randn((4, 6), key=self.key) B, key = randn((5, 3), key=key) Ao = MatrixOperator(A) Bo = AbsMatOp(B) with pytest.raises(ValueError): _ = Ao @ Bo def test_matmul_identity(self): A, key = randn((4, 6), key=self.key) Ao = MatrixOperator(A) I = linop.Identity(input_shape=(6,)) assert Ao == Ao @ I def test_init_array(self): Am = np.random.randn(4, 6) A = MatrixOperator(Am) assert isinstance(A.A, np.ndarray) A = MatrixOperator(jnp.array(Am)) assert isinstance(A.A, jnp.ndarray) np.testing.assert_array_equal(A.A, jnp.array(A)) with pytest.raises(TypeError): MatrixOperator([1.0, 3.0]) @pytest.mark.parametrize("matrix_shape", [(3,), (2, 3, 4)]) def test_init_wrong_dims(self, matrix_shape): A = np.random.randn(*matrix_shape) with pytest.raises(TypeError): Ao = MatrixOperator(A) def test_to_array(self): A = np.random.randn(4, 6) Ao = MatrixOperator(A) A_array = Ao.to_array() assert isinstance(A_array, np.ndarray) np.testing.assert_allclose(A_array, A) A_array = jnp.array(Ao) assert isinstance(A_array, jax.Array) np.testing.assert_allclose(A_array, A) @pytest.mark.parametrize("ord", ["fro", 2]) @pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("keepdims", [True, False]) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) def test_norm(self, ord, axis, keepdims, input_dtype): # pylint: disable=W0622 A, key = randn((4, 6), dtype=input_dtype, key=self.key) Ao = MatrixOperator(A) if ord == "fro" and axis is not None: # Not defined; pass else: x = Ao.norm(ord=ord, axis=axis, keepdims=keepdims) y = snp.linalg.norm(A, ord=ord, axis=axis, keepdims=keepdims) np.testing.assert_allclose(x, y, rtol=5e-5) ================================================ FILE: 
scico/test/linop/test_optics.py ================================================ import numpy as np import jax import pytest from scico.linop.optics import ( AngularSpectrumPropagator, FraunhoferPropagator, FresnelPropagator, radial_transverse_frequency, ) from scico.random import randn from scico.test.linop.test_linop import adjoint_test prop_list = [AngularSpectrumPropagator, FresnelPropagator, FraunhoferPropagator] class TestPropagator: def setup_method(self, method): key = jax.random.key(12345) self.N = 128 self.dx = 1 self.k0 = 1 self.z = 1 self.key = key @pytest.mark.parametrize("ndim", [1, 2]) @pytest.mark.parametrize("prop", prop_list) def test_prop_adjoint(self, prop, ndim): A = prop(input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z) adjoint_test(A, self.key) @pytest.mark.parametrize("ndim", [1, 2]) def test_AS_inverse(self, ndim): A = AngularSpectrumPropagator( input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z ) x, key = randn(A.input_shape, dtype=np.complex64, key=self.key) Ax = A @ x AiAx = A.pinv(Ax) np.testing.assert_allclose(x, AiAx, rtol=6e-4) @pytest.mark.parametrize("prop", prop_list) def test_3d_invalid(self, prop): with pytest.raises(ValueError): prop(input_shape=(self.N, self.N, self.N), dx=self.dx, k0=self.k0, z=self.z) @pytest.mark.parametrize("prop", prop_list) def test_shape_dx_mismatch(self, prop): with pytest.raises(ValueError): prop(input_shape=(self.N,), dx=(self.dx, self.dx), k0=self.k0, z=self.z) def test_3d_invalid_radial(self): with pytest.raises(ValueError): radial_transverse_frequency(input_shape=(self.N, self.N, self.N), dx=self.dx) def test_shape_dx_mismatch_radial(self): with pytest.raises(ValueError): radial_transverse_frequency(input_shape=(self.N,), dx=(self.dx, self.dx)) @pytest.mark.parametrize("ndim", [1, 2]) def test_asp_sampling(ndim): N = 128 dx = 1 z = 1 A = AngularSpectrumPropagator(input_shape=(N,) * ndim, dx=dx, k0=1, z=z) assert not A.adequate_sampling() A = 
AngularSpectrumPropagator(input_shape=(N,) * ndim, dx=dx, k0=100, z=z) assert A.adequate_sampling() @pytest.mark.parametrize("ndim", [1, 2]) def test_fresnel_sampling(ndim): N = 128 dx = 1 k0 = 1 A = FresnelPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=N**2) assert not A.adequate_sampling() A = FresnelPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=1) assert A.adequate_sampling() @pytest.mark.parametrize("ndim", [1, 2]) def test_fraunhofer_sampling(ndim): N = 128 dx = 1 k0 = 1 A = FraunhoferPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=N**2) assert not A.adequate_sampling() A = FraunhoferPropagator(input_shape=(N,) * ndim, dx=dx, k0=k0, z=1) assert A.adequate_sampling() ================================================ FILE: scico/test/linop/xray/test_abel.py ================================================ import numpy as np import jax import pytest import scico.numpy as snp from scico.linop.xray.abel import AbelTransform from scico.test.linop.test_linop import adjoint_test BIG_INPUT = (128, 128) SMALL_INPUT = (4, 5) def make_im(Nx, Ny): x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny)) im = snp.where(x**2 + y**2 < 0.3, 1.0, 0.0) return im @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_inverse(Nx, Ny): im = make_im(Nx, Ny) A = AbelTransform(im.shape) Ax = A @ im im_hat = A.inverse(Ax) np.testing.assert_allclose(im_hat, im, rtol=5e-5) @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_adjoint(Nx, Ny): im = make_im(Nx, Ny) A = AbelTransform(im.shape) adjoint_test(A) @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_ATA(Nx, Ny): x = make_im(Nx, Ny) A = AbelTransform(x.shape) Ax = A(x) ATAx = A.adj(Ax) np.testing.assert_allclose(np.sum(x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5) @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_grad(Nx, Ny): # ensure that we can take grad on a function using our projector # grad || A(x) ||_2^2 == 2 A.T @ A x x = 
make_im(Nx, Ny) A = AbelTransform(x.shape) g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2 np.testing.assert_allclose(jax.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5) @pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) def test_adjoint_grad(Nx, Ny): x = make_im(Nx, Ny) A = AbelTransform(x.shape) Ax = A @ x f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2 np.testing.assert_allclose(jax.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5) ================================================ FILE: scico/test/linop/xray/test_astra.py ================================================ import numpy as np import jax import pytest import scico import scico.numpy as snp from scico.linop import DiagonalStack from scico.test.linop.test_linop import adjoint_test from scipy.spatial.transform import Rotation try: from scico.linop.xray.astra import ( XRayTransform2D, XRayTransform3D, _ensure_writeable, angle_to_vector, rotate_vectors, ) except ModuleNotFoundError as e: if e.name == "astra": pytest.skip("astra not installed", allow_module_level=True) else: raise e N = 128 RTOL_CPU = 1e-4 RTOL_GPU = 1e-1 RTOL_GPU_RANDOM_INPUT = 2.0 def make_im(Nx, Ny, is_3d=True): x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny), indexing="ij") im = snp.where((x - 0.25) ** 2 / 3 + y**2 < 0.1, 1.0, 0.0) if is_3d: im = im[snp.newaxis, :, :] im = im.astype(snp.float32) return im def get_tol(): if jax.devices()[0].device_kind == "cpu": rtol = RTOL_CPU else: rtol = RTOL_GPU # astra inaccurate in GPU return rtol def get_tol_random_input(): if jax.devices()[0].device_kind == "cpu": rtol = RTOL_CPU else: rtol = RTOL_GPU_RANDOM_INPUT # astra more inaccurate in GPU for random inputs return rtol class XRayTransform2DTest: def __init__(self, volume_geometry): N_proj = 180 # number of projection angles N_det = 384 det_spacing = 1 angles = np.linspace(0, np.pi, N_proj, False) np.random.seed(1234) self.x = np.random.randn(N, N).astype(np.float32) self.y = np.random.randn(N_proj, N_det).astype(np.float32) 
self.A = XRayTransform2D( input_shape=(N, N), det_count=N_det, det_spacing=det_spacing, angles=angles, volume_geometry=volume_geometry, ) @pytest.fixture(params=[None, [-N / 2, N / 2, -N / 2, N / 2]]) def testobj(request): yield XRayTransform2DTest(request.param) def test_init(testobj): with pytest.raises(ValueError): A = XRayTransform2D( input_shape=(16, 16, 16), det_count=16, det_spacing=1.0, angles=np.linspace(0, np.pi, 32, False), ) with pytest.raises(ValueError): A = XRayTransform2D( input_shape=(16, 16), det_count=16.3, det_spacing=1.0, angles=np.linspace(0, np.pi, 32, False), ) with pytest.raises(ValueError): A = XRayTransform2D( input_shape=(16, 16), det_count=16, det_spacing=1.0, angles=np.linspace(0, np.pi, 32, False), device="invalid", ) def test_ATA_call(testobj): # Test for the call-based interface Ax = testobj.A(testobj.x) ATAx = testobj.A.adj(Ax) np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol()) def test_ATA_matmul(testobj): # Test for the matmul interface Ax = testobj.A @ testobj.x ATAx = testobj.A.T @ Ax np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol()) def test_AAT_call(testobj): # Test for the call-based interface ATy = testobj.A.adj(testobj.y) AATy = testobj.A(ATy) np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol()) def test_AAT_matmul(testobj): # Test for the matmul interface ATy = testobj.A.T @ testobj.y AATy = testobj.A @ ATy np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol()) def test_grad(testobj): # ensure that we can take grad on a function using our projector # grad || A(x) ||_2^2 == 2 A.T @ A x A = testobj.A x = testobj.x g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2 np.testing.assert_allclose( scico.grad(g)(x), 2 * A.adj(A(x)), atol=get_tol() * x.max(), rtol=get_tol() ) def test_adjoint_grad(testobj): A = testobj.A x = testobj.x Ax = A @ x f = lambda y: 
jax.numpy.linalg.norm(A.T(y)) ** 2 np.testing.assert_allclose(scico.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=get_tol()) def test_adjoint_random(testobj): A = testobj.A adjoint_test(A, rtol=10 * get_tol_random_input()) def test_adjoint_typical_input(testobj): A = testobj.A x = make_im(A.input_shape[0], A.input_shape[1], is_3d=False) adjoint_test(A, x=x, rtol=get_tol()) def test_fbp(testobj): x = testobj.A.fbp(testobj.y) # Test for a bug (related to calling the Astra CPU FBP implementation # when using a FPU device) that resulted in a constant zero output. assert np.sum(np.abs(x)) > 0.0 def test_jit_in_DiagonalStack(): """See https://github.com/lanl/scico/issues/331""" N = 10 H = DiagonalStack([XRayTransform2D((N, N), N, 1.0, snp.linspace(0, snp.pi, N))]) H.T @ snp.zeros(H.output_shape, dtype=snp.float32) @pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="checking GPU behavior") def test_3D_on_GPU(): x = snp.zeros((4, 5, 6)) A = XRayTransform3D( x.shape, det_count=[6, 6], det_spacing=[1.0, 1.0], angles=snp.linspace(0, snp.pi, 10) ) assert A.num_dims == 3 y = A @ x ATy = A.T @ y @pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="GPU required for test") def test_3D_api_equiv(): x = np.random.randn(4, 5, 6).astype(np.float32) det_count = [7, 8] det_spacing = [1.0, 1.5] angles = snp.linspace(0, snp.pi, 10) A = XRayTransform3D(x.shape, det_count=det_count, det_spacing=det_spacing, angles=angles) vectors = angle_to_vector(det_spacing, angles) B = XRayTransform3D(x.shape, det_count=det_count, vectors=vectors) ya = A @ x yb = B @ x np.testing.assert_allclose(ya, yb, rtol=get_tol()) def test_angle_to_vector(): angles = snp.linspace(0, snp.pi, 5) det_spacing = [0.9, 1.5] vectors = angle_to_vector(det_spacing, angles) assert vectors.shape == (angles.size, 12) def test_rotate_vectors(): v0 = angle_to_vector([1.0, 1.0], np.linspace(0, np.pi / 2, 4, endpoint=False)) v1 = angle_to_vector([1.0, 1.0], np.linspace(np.pi / 2, np.pi, 4, endpoint=False)) r = 
Rotation.from_euler("z", np.pi / 2) v0r = rotate_vectors(v0, r) np.testing.assert_allclose(v1, v0r, atol=1e-7) ## conversion functions @pytest.fixture(scope="module") def test_geometry(): """ In this geometry, if vol[i, j, k]==1, we expect proj[j-2, k-1]==1. Because: - We project along z, i.e. `ray=(0,0,1)`, i.e., we remove axis=0. - We set `v=(0, 1, 0)`, so detector rows go with y axis, axis=1. - We set `u=(1, 0, 0)`, so detector columns go with x axis, axis=2. - We shift the detector by (x=1, y=2, z=3) <-> i-3, j-2, k-1 """ in_shape = (30, 31, 32) # in ASTRA terminology: n_rows = in_shape[1] # y n_cols = in_shape[2] # x n_slices = in_shape[0] # z vol_geom = scico.linop.xray.astra.astra.create_vol_geom(n_rows, n_cols, n_slices) assert vol_geom["option"]["WindowMinX"] == -n_cols / 2 assert vol_geom["option"]["WindowMinY"] == -n_rows / 2 assert vol_geom["option"]["WindowMinZ"] == -n_slices / 2 # project along z, axis=0 det_row_count = n_rows det_col_count = n_cols ray = (0, 0, 1) d = (1, 2, 3) # axis=2 offset by 1, axis=1 offset by 2, axis=0 offset by 3 u = (1, 0, 0) # increments columns, goes with X v = (0, 1, 0) # increments rows, goes with Y vectors = np.array(ray + d + u + v)[np.newaxis, :] proj_geom = scico.linop.xray.astra.astra.create_proj_geom( "parallel3d_vec", det_row_count, det_col_count, vectors ) return vol_geom, proj_geom @pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="GPU required for test") def test_projection_convention(test_geometry): """ If vol[i, j, k]==1, test that astra puts proj[j-2, k-1]==1. See `test_geometry` for the setup. 
""" vol_geom, proj_geom = test_geometry in_shape = scico.linop.xray.astra.astra.functions.geom_size(vol_geom) vol = np.zeros(in_shape) i, j, k = [np.random.randint(0, s) for s in in_shape] vol[i, j, k] = 1.0 proj_id, proj = scico.linop.xray.astra.astra.create_sino3d_gpu(vol, proj_geom, vol_geom) scico.linop.xray.astra.astra.data3d.delete(proj_id) proj = proj[:, 0, :] # get first view assert len(np.unique(proj) == 2) idx_proj_i, idx_proj_j = np.nonzero(proj) np.testing.assert_array_equal(idx_proj_i, j - 2) np.testing.assert_array_equal(idx_proj_j, k - 1) def test_project_coords(test_geometry): """ If vol[i, j, k]==1, test that we predict proj[j-2, k-1]==1. See `test_geometry` for the setup and `test_projection_convention` for proof ASTRA works this way. """ vol_geom, proj_geom = test_geometry in_shape = scico.linop.xray.astra.astra.functions.geom_size(vol_geom) x_vol = np.array([np.random.randint(0, s) for s in in_shape]) x_proj_gt = np.array( [[x_vol[1] - 2, x_vol[2] - 1]] ) # projection along slices removes first index x_proj = scico.linop.xray.astra._project_coords(x_vol, vol_geom, proj_geom) np.testing.assert_array_equal(x_proj_gt, x_proj) def test_convert_to_scico_geometry(test_geometry): """ Basic regression test, `test_project_coords` tests the logic. """ vol_geom, proj_geom = test_geometry matrices_truth = scico.linop.xray.astra._astra_to_scico_geometry(vol_geom, proj_geom) truth = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]]) np.testing.assert_allclose(matrices_truth, truth) def test_convert_from_scico_geometry(test_geometry): """ Basic regression test, `test_project_coords` tests the logic. 
""" in_shape = (30, 31, 32) matrices = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]]) det_shape = (31, 32) vectors = scico.linop.xray.astra.convert_from_scico_geometry(in_shape, matrices, det_shape) _, proj_geom_truth = test_geometry # skip testing element 5, as it is detector center along the ray and doesn't matter np.testing.assert_allclose(vectors[0, :5], proj_geom_truth["Vectors"][0, :5]) np.testing.assert_allclose(vectors[0, 6:], proj_geom_truth["Vectors"][0, 6:]) def test_vol_coord_to_world_coord(): vol_geom = scico.linop.xray.astra.astra.create_vol_geom(16, 16) vc = np.array([[0.0, 0.0], [1.0, 1.0]]) wc = scico.linop.xray.astra.volume_coords_to_world_coords(vc, vol_geom) assert wc.shape == (2, 2) def test_ensure_writeable(): assert isinstance(_ensure_writeable(np.ones((2, 1))), np.ndarray) assert isinstance(_ensure_writeable(snp.ones((2, 1))), np.ndarray) ================================================ FILE: scico/test/linop/xray/test_svmbir.py ================================================ import numpy as np import jax import pytest import scico import scico.numpy as snp from scico.linop import Diagonal from scico.loss import SquaredL2Loss from scico.test.functional.prox import prox_test from scico.test.linop.test_linop import adjoint_test try: import svmbir from scico.linop.xray.svmbir import ( SVMBIRExtendedLoss, SVMBIRSquaredL2Loss, XRayTransform, ) except ImportError as e: pytest.skip("svmbir not installed", allow_module_level=True) BIG_INPUT = (32, 33, 50, 51, 125, 1.2) SMALL_INPUT = (4, 5, 7, 8, 16, 1.2) def pytest_generate_tests(metafunc): param_ranges = { "is_3d": (True, False), "is_masked": (True, False), "geometry": ("parallel", "fan-curved", "fan-flat"), "center_offset_small": (0, 0.1), "center_offset_big": (0, 3), "delta_channel": (None, 0.5), "delta_pixel": (None, 0.5), "positivity": (True, False), "weight_type": ("transmission", "unweighted"), } level = int(metafunc.config.getoption("--level")) if level < 3: 
param_ranges.update({"is_3d": (False,), "is_masked": (False,), "positivity": (False,)}) if level < 2: param_ranges.update( { "geometry": ("parallel",), "center_offset_small": (0.1,), "center_offset_big": (3,), "delta_channel": (None,), "delta_pixel": (None,), "weight_type": ("transmission",), } ) for k, v in param_ranges.items(): if k in metafunc.fixturenames: metafunc.parametrize(k, v) def make_im(Nx, Ny, is_3d=True): x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny), indexing="ij") im = snp.where((x - 0.25) ** 2 / 3 + y**2 < 0.1, 1.0, 0.0) if is_3d: im = im[snp.newaxis, :, :] im = im.astype(snp.float32) return im def make_angles(num_angles): return snp.linspace(0, snp.pi, num_angles, dtype=snp.float32) def make_A( im, num_angles, num_channels, center_offset, is_masked, geometry="parallel", dist_source_detector=None, magnification=None, delta_channel=None, delta_pixel=None, ): angles = make_angles(num_angles) A = XRayTransform( im.shape, angles, num_channels, center_offset=center_offset, is_masked=is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) return A def test_grad( is_3d, center_offset_big, is_masked, geometry, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = BIG_INPUT im = make_im(Nx, Ny, is_3d) A = make_A( im, num_angles, num_channels, center_offset_big, is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) def f(im): return snp.sum(A._eval(im) ** 2) val_1 = jax.grad(f)(im) val_2 = 2 * A.adj(A(im)) np.testing.assert_allclose(val_1, val_2) def test_adjoint( is_3d, center_offset_big, is_masked, geometry, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = BIG_INPUT im = make_im(Nx, Ny, is_3d) A = make_A( im, num_angles, num_channels, center_offset_big, is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) adjoint_test(A) @pytest.mark.slow 
def test_prox( is_3d, center_offset_small, is_masked, geometry, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT im = make_im(Nx, Ny, is_3d) A = make_A( im, num_angles, num_channels, center_offset_small, is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) sino = A @ im v, _ = scico.random.normal(im.shape, dtype=im.dtype) if is_masked: f = SVMBIRExtendedLoss(y=sino, A=A, positivity=False, prox_kwargs={"maxiter": 5}) else: f = SVMBIRSquaredL2Loss(y=sino, A=A, prox_kwargs={"maxiter": 5}) prox_test(v, f, f.prox, alpha=0.25, rtol=5e-4) @pytest.mark.slow def test_prox_weights( is_3d, center_offset_small, is_masked, geometry, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT im = make_im(Nx, Ny, is_3d) A = make_A( im, num_angles, num_channels, center_offset_small, is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) sino = A @ im v, _ = scico.random.normal(im.shape, dtype=im.dtype) # test with weights weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype) W = scico.linop.Diagonal(weights) if is_masked: f = SVMBIRExtendedLoss(y=sino, A=A, W=W, positivity=False, prox_kwargs={"maxiter": 5}) else: f = SVMBIRSquaredL2Loss(y=sino, A=A, W=W, prox_kwargs={"maxiter": 5}) prox_test(v, f, f.prox, alpha=0.25, rtol=5e-5) def test_prox_cg( is_3d, weight_type, center_offset_small, is_masked, geometry, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT im = make_im(Nx, Ny, is_3d=is_3d) / Nx * 10 A = make_A( im, num_angles, num_channels, center_offset_small, is_masked=is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, ) y = A @ im A_colsum = A.H @ snp.ones( y.shape, dtype=snp.float32 ) # backproject ones to get sum over cols of A if is_masked: mask = np.asarray(A_colsum) > 0 # cols of A which are not all zeros else: mask 
= np.ones(im.shape) > 0 W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32") W = snp.array(W) λ = 0.01 if is_masked: f_sv = SVMBIRExtendedLoss( y=y, A=A, W=Diagonal(W), positivity=False, prox_kwargs={"maxiter": 5} ) else: f_sv = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 5}) f_wg = SquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"tol": 5e-4}) v, _ = scico.random.normal(im.shape, dtype=im.dtype) v *= im.max() * 0.5 xprox_sv = f_sv.prox(v, λ) xprox_cg = f_wg.prox(v, λ) # this uses cg assert snp.linalg.norm(xprox_sv[mask] - xprox_cg[mask]) / snp.linalg.norm(xprox_sv[mask]) < 5e-4 def test_approx_prox( is_3d, weight_type, center_offset_big, is_masked, positivity, geometry, delta_channel, delta_pixel, ): Nx, Ny, num_angles, num_channels, dist_source_detector, magnification = SMALL_INPUT im = make_im(Nx, Ny, is_3d) A = make_A( im, num_angles, num_channels, center_offset_big, is_masked, geometry=geometry, dist_source_detector=dist_source_detector, magnification=magnification, delta_channel=delta_channel, delta_pixel=delta_pixel, ) y = A @ im W = svmbir.calc_weights(y, weight_type=weight_type).astype("float32") W = snp.array(W) λ = 0.01 v, _ = scico.random.normal(im.shape, dtype=im.dtype) if is_masked or positivity: f = SVMBIRExtendedLoss( y=y, A=A, W=Diagonal(W), positivity=positivity, prox_kwargs={"maxiter": 5} ) else: f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 5}) xprox = snp.array(f.prox(v, lam=λ)) if is_masked or positivity: f_approx = SVMBIRExtendedLoss( y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 2}, positivity=positivity ) else: f_approx = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 2}) xprox_approx = snp.array(f_approx.prox(v, lam=λ, v0=xprox)) assert snp.linalg.norm(xprox - xprox_approx) / snp.linalg.norm(xprox) < 5e-5 ================================================ FILE: scico/test/linop/xray/test_symcone.py ================================================ 
import numpy as np import pytest from scico import metric from scico.examples import create_circular_phantom from scico.linop.xray.symcone import ( AxiallySymmetricVolume, SymConeXRayTransform, _volume_by_axial_symmetry, ) from scipy.ndimage import gaussian_filter class TestAxialSymm: def setup_method(self, method): N = 64 self.N = N self.x2d = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) self.x3d = create_circular_phantom((N, N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) self.x2d = gaussian_filter(self.x2d, 1.0) self.x3d = gaussian_filter(self.x3d, 1.0) @pytest.mark.parametrize("axis", [0, 1]) def test_vbas(self, axis): v0 = _volume_by_axial_symmetry(self.x2d, axis=axis) assert metric.rel_res(self.x3d, v0) < 5e-2 offset = -3 x2dr = np.roll(self.x2d, offset, axis=1 - axis) Nh = (self.N + 1) / 2 - 1 v1 = _volume_by_axial_symmetry(x2dr, axis=axis, center=Nh + offset) assert metric.rel_res(v0, v1) < 1e-5 zrange = np.arange(-Nh, 0) v2 = _volume_by_axial_symmetry(self.x2d, axis=axis, zrange=zrange) assert metric.rel_res(self.x3d[..., 0 : self.N // 2], v2) < 5e-2 A = AxiallySymmetricVolume((self.N, self.N), axis=axis) vl = A(self.x2d) assert metric.rel_res(v0, vl) < 1e-7 class TestAbelCone: def setup_method(self, method): N = 64 self.N = N self.x2d = create_circular_phantom((N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) self.x3d = create_circular_phantom((N, N, N), [0.4 * N, 0.2 * N, 0.1 * N], [1, 0, 0.5]) self.x2d = gaussian_filter(self.x2d, 1.0) self.x3d = gaussian_filter(self.x3d, 1.0) @pytest.mark.parametrize("num_slabs", [1, 2, 3]) def test_2d(self, num_slabs): A = SymConeXRayTransform(self.x2d.shape, 1e8, 1e8 + 1, num_slabs=num_slabs) ya = A(self.x2d) x2ds = _volume_by_axial_symmetry(self.x2d, axis=0) ys = np.sum(x2ds, axis=1) assert metric.rel_res(ys, ya) < 1e-6 @pytest.mark.parametrize("num_slabs", [1, 2, 3]) def test_2d_unequal(self, num_slabs): x2dc = self.x2d[1:-1] A = SymConeXRayTransform(x2dc.shape, 1e8, 1e8 + 1, 
num_slabs=num_slabs) ya = A(x2dc) x2ds = _volume_by_axial_symmetry(x2dc, axis=0) ys = np.sum(x2ds, axis=1) assert metric.rel_res(ys, ya) < 1e-6 @pytest.mark.parametrize("num_slabs", [1, 2, 3]) def test_3d(self, num_slabs): A = SymConeXRayTransform(self.x3d.shape, 1e8, 1e8 + 1, num_slabs=num_slabs) ya = A(self.x3d) ys = np.sum(self.x3d, axis=1) assert metric.rel_res(ys, ya) < 1e-6 @pytest.mark.parametrize("num_slabs", [1, 2, 3]) def test_3d_unequal(self, num_slabs): x3dc = self.x3d[1:-1, 2:-2] A = SymConeXRayTransform(x3dc.shape, 1e8, 1e8 + 1, num_slabs=num_slabs) ya = A(x3dc) ys = np.sum(x3dc, axis=1) assert metric.rel_res(ys, ya) < 1e-6 @pytest.mark.parametrize("num_slabs", [1, 2, 3]) def test_2d3d_unequal(self, num_slabs): A2d = SymConeXRayTransform(self.x2d.shape, 5e1, 6e1, num_slabs=num_slabs) A3d = SymConeXRayTransform(self.x3d.shape, 5e1, 6e1, num_slabs=num_slabs) y2d = A2d(self.x2d) y3d = A3d(self.x3d) assert metric.rel_res(y3d, y2d) < 2e-2 @pytest.mark.parametrize("axis", [0, 1]) def test_proj_axis(self, axis): N = self.N N2 = N // 2 N4 = N // 4 x = np.zeros((N, N)) if axis == 0: x[N2 - 1 : N2 + 1, N4 - 1 : N4 + 1] = 1 else: x[N4 - 1 : N4 + 1, N2 - 1 : N2 + 1] = 1 A = SymConeXRayTransform(x.shape, 1e2, 2e2, axis=axis, num_slabs=1) y = A(x) if axis == 0: assert np.sum(np.sum(y, axis=1) > 0) <= 4 assert np.sum(np.sum(y, axis=0) > 0) >= N2 else: assert np.sum(np.sum(y, axis=0) > 0) <= 4 assert np.sum(np.sum(y, axis=1) > 0) >= N2 @pytest.mark.parametrize("axis", [0, 1]) def test_fdk(self, axis): A = SymConeXRayTransform(self.x3d.shape, 1e2, 2e2, axis=axis, num_slabs=1) y = A(self.x3d) z = A.fdk(y) assert metric.rel_res(self.x2d, z) < 0.2 ================================================ FILE: scico/test/linop/xray/test_xray_2d.py ================================================ import numpy as np import jax import jax.numpy as jnp import pytest import scico import scico.linop from scico.linop.xray import XRayTransform2D from scico.metric import psnr 
@pytest.mark.filterwarnings("error") def test_init(): input_shape = (3, 3) # no warning with default settings, even at 45 degrees H = XRayTransform2D(input_shape, jnp.array([jnp.pi / 4])) # no warning if we project orthogonally with oversized pixels H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1, 1])) # warning if the projection angle changes with pytest.warns(UserWarning): H = XRayTransform2D(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1])) # warning if the pixels get any larger with pytest.warns(UserWarning): H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1])) def test_apply(): im_shape = (12, 13) num_angles = 10 x = jnp.ones(im_shape) angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection H = XRayTransform2D(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) # fixed det_count det_count = 14 H = XRayTransform2D(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count def test_apply_adjoint(): im_shape = (12, 13) num_angles = 10 x = jnp.ones(im_shape, dtype=jnp.float32) angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection H = XRayTransform2D(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) # adjoint bp = H.T @ y assert scico.linop.valid_adjoint( H, H.T, eps=1e-4 ) # associative reductions might cause small errors, hence 1e-5 # fixed det_length det_count = 14 H = XRayTransform2D(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count def test_matched_adjoint(): """See https://github.com/lanl/scico/issues/560.""" N = 16 det_count = int(N * 1.05 / np.sqrt(2.0)) dx = 1.0 / np.sqrt(2) n_projection = 3 angles = np.linspace(0, np.pi, n_projection, endpoint=False) A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx) assert scico.linop.valid_adjoint(A, A.T, eps=1e-5) @pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)]) @pytest.mark.parametrize("det_count_factor", [1.02 / 
np.sqrt(2.0), 1.0]) def test_fbp(dx, det_count_factor): N = 256 x_gt = np.zeros((N, N), dtype=np.float32) N4 = N // 4 x_gt[N4:-N4, N4:-N4] = 1.0 det_count = int(det_count_factor * N) n_proj = 360 angles = np.linspace(0, np.pi, n_proj, endpoint=False) A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) y = A(x_gt) x_fbp = A.fbp(y) assert psnr(x_gt, x_fbp) > 28 def test_fbp_jit(): N = 64 x_gt = np.ones((N, N), dtype=np.float32) det_count = N n_proj = 90 angles = np.linspace(0, np.pi, n_proj, endpoint=False) A = XRayTransform2D(x_gt.shape, angles, det_count=det_count) y = A(x_gt) fbp = jax.jit(A.fbp) x_fbp = fbp(y) ================================================ FILE: scico/test/linop/xray/test_xray_3d.py ================================================ import numpy as np import jax.numpy as jnp import scico.linop from scico.linop.xray import XRayTransform3D def test_matched_adjoint(): """See https://github.com/lanl/scico/issues/560.""" N = 16 det_count = int(N * 1.05 / np.sqrt(2.0)) n_projection = 3 input_shape = (N, N, N) det_shape = (det_count, det_count) M = XRayTransform3D.matrices_from_euler_angles( input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False)[:, None], # make (n_projection, 1) ) H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) assert scico.linop.valid_adjoint(H, H.T, eps=1e-5) def test_scaling(): x = jnp.zeros((4, 4, 1)) x = x.at[1:3, 1:3, 0].set(1.0) input_shape = x.shape det_shape = x.shape[:2] # default spacing M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [[0.0]]) H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) # fmt: off truth = jnp.array( [[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] ) # fmt: on np.testing.assert_allclose(H @ x, truth) # bigger voxels in the x (first index) direction M = XRayTransform3D.matrices_from_euler_angles( input_shape, det_shape, "X", [[0.0]], voxel_spacing=[2.0, 1.0, 
1.0] ) H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) # fmt: off truth = jnp.array( [[[0. , 0.5, 0.5, 0. ], [0. , 0.5, 0.5, 0. ], [0. , 0.5, 0.5, 0. ], [0. , 0.5, 0.5, 0. ]]] ) # fmt: on np.testing.assert_allclose(H @ x, truth) # bigger detector pixels in the x (first index) direction M = XRayTransform3D.matrices_from_euler_angles( input_shape, det_shape, "X", [[0.0]], det_spacing=[2.0, 1.0] ) H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) # fmt: off truth = None # fmt: on # TODO: Check this case more closely. # np.testing.assert_allclose(H @ x, truth) ================================================ FILE: scico/test/linop/xray/test_xray_util.py ================================================ import numpy as np import jax import jax.numpy as jnp from jax.scipy.spatial.transform import Rotation import pytest import scipy.ndimage from scico.linop.xray import ( center_image, image_alignment_rotation, image_centroid, rotate_volume, volume_alignment_rotation, ) try: import astra # noqa have_astra = True except ModuleNotFoundError as e: if e.name == "astra": have_astra = False else: raise e def test_image_centroid(): v = np.zeros((4, 5)) v[1:-1, 1:-1] = 1 assert image_centroid(v) == (1.5, 2.0) image_centroid(v, center_offset=True) == (0.0, 0.0) def test_center_image(): u = np.zeros((4, 5)) u[0:-2, 0:-2] = 1 v = center_image(u) np.testing.assert_allclose(image_centroid(v, center_offset=True), (0.0, 0.0), atol=1e-7) v = center_image(u, axes=(0,)) np.testing.assert_allclose(image_centroid(v, center_offset=True), (0.0, -1.0), atol=1e-7) def test_rotate_volume(): vol = np.arange(27).reshape((3, 3, 3)) rot = Rotation.from_euler("XY", jnp.array([90.0, 90.0]), degrees=True) vol_rot = rotate_volume(vol, rot) np.testing.assert_allclose(vol.transpose((1, 2, 0)), vol_rot, rtol=1e-7) def align_test_tol(): if jax.devices()[0].device_kind == "cpu": tol = 1e-3 else: tol = 5e-2 # less accurate on gpu return tol @pytest.mark.skipif(not have_astra, 
reason="astra not installed") def test_image_alignment(): u = np.zeros((256, 256), dtype=np.float32) u[:, 8::16] = 1 u[:, 9::16] = 1 angle = image_alignment_rotation(u) assert np.abs(angle) < 1e-3 ur = scipy.ndimage.rotate(u, 0.75) angle = image_alignment_rotation(ur) assert np.abs(angle - 0.75) < align_test_tol() @pytest.mark.skipif(not have_astra, reason="astra not installed") def test_volume_alignment(): u = np.zeros((256, 256, 32), dtype=np.float32) u[8::16, :, 2::6] = 1 u[9::16, :, 2::6] = 1 u[:, 8::16, 2::6] = 1 u[:, 9::16, 2::6] = 1 u[8::16, :, 3::6] = 1 u[9::16, :, 3::6] = 1 u[:, 8::16, 3::6] = 1 u[:, 9::16, 3::6] = 1 rot = volume_alignment_rotation(u) assert rot.magnitude() < 1e-5 ref_rot = Rotation.from_euler("XY", jnp.array([1.6, -0.9]), degrees=True) ur = rotate_volume(u, ref_rot) rot = volume_alignment_rotation(ur) assert ( np.abs(ref_rot.as_euler("XYZ", degrees=True) - rot.as_euler("XYZ", degrees=True)).max() < 1e-1 ) ================================================ FILE: scico/test/numpy/test_blockarray.py ================================================ import itertools import operator as op import numpy as np import jax import jax.numpy as jnp import pytest import scico.numpy as snp from scico.numpy import BlockArray from scico.numpy._wrapped_function_lists import testing_functions from scico.numpy.testing import assert_array_equal from scico.numpy.util import shape_dtype_rep from scico.util import rgetattr math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex comp_ops = [op.le, op.lt, op.ge, op.gt, op.eq] def make_arbitrary_jax_array(shape, dtype): """ Make an arbitrary jax array of the given shape and dtype. 
""" return jnp.array(np.random.randn(*shape)).astype(dtype) def sequence_assert_allclose(x, y, *args, **kwargs): """Assert sequences x and y have the same length and corresponding elements are allclose.""" assert len(x) == len(y) for x_i, y_i in zip(x, y): np.testing.assert_allclose(x_i, y_i, *args, **kwargs) class OperatorsTestObj: operators = math_ops + comp_ops def __init__(self, dtype): self.scalar = 1.0 self.a0 = make_arbitrary_jax_array((2, 3), dtype) self.a1 = make_arbitrary_jax_array((2, 3, 4), dtype) self.a = BlockArray((self.a0, self.a1)) self.b0 = make_arbitrary_jax_array((2, 3), dtype) self.b1 = make_arbitrary_jax_array((2, 3, 4), dtype) self.b = BlockArray((self.b0, self.b1)) self.d0 = make_arbitrary_jax_array((3, 2), dtype) self.d1 = make_arbitrary_jax_array((2, 4, 3), dtype) self.d = BlockArray((self.d0, self.d1)) c0 = make_arbitrary_jax_array((2, 3), dtype) self.c = BlockArray((c0,)) # A flat device array with same size as self.a & self.b self.flat_da = make_arbitrary_jax_array(self.a.size, dtype) self.flat_nd = np.array(self.flat_da) # A device array with length == self.a.num_blocks self.block_da, key = make_arbitrary_jax_array((len(self.a),), dtype) # block_da but as a numpy array self.block_nd = np.array(self.block_da) self.key = key @pytest.fixture(scope="module", params=[np.float32, np.complex64]) def test_operator_obj(request): yield OperatorsTestObj(request.param) # Operations between a blockarray and scalar @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_operator_left(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a x = operator(scalar, a) y = BlockArray(operator(scalar, a_i) for a_i in a) sequence_assert_allclose(x, y) @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_operator_right(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a x = operator(a, scalar) y = BlockArray(operator(a_i, scalar) for a_i in a) 
sequence_assert_allclose(x, y)


# Operations between two blockarrays of same size
@pytest.mark.parametrize("operator", math_ops + comp_ops)
def test_ba_ba_operator(test_operator_obj, operator):
    a = test_operator_obj.a
    b = test_operator_obj.b
    x = operator(a, b)
    # expected result: the operator applied block by block
    y = BlockArray(operator(a_i, b_i) for a_i, b_i in zip(a, b))
    sequence_assert_allclose(x, y)


# Testing the @ interface for blockarrays of same size, and a blockarray and flattened
# ndarray/devicearray
def test_ba_ba_matmul(test_operator_obj):
    a = test_operator_obj.a
    b = test_operator_obj.d
    c = test_operator_obj.c

    a0 = test_operator_obj.a0
    a1 = test_operator_obj.a1
    d0 = test_operator_obj.d0
    d1 = test_operator_obj.d1

    x = a @ b

    # blockwise matmul reference
    y = BlockArray([a0 @ d0, a1 @ d1])
    assert x.shape == y.shape
    sequence_assert_allclose(x, y)

    # a @ c is expected to be rejected with a TypeError
    with pytest.raises(TypeError):
        z = a @ c


def test_conj(test_operator_obj):
    a = test_operator_obj.a
    ac = a.conj()

    assert a.shape == ac.shape
    sequence_assert_allclose(BlockArray(a_i.conj() for a_i in a), ac)


def test_real(test_operator_obj):
    a = test_operator_obj.a
    ac = a.real
    sequence_assert_allclose(BlockArray(a_i.real for a_i in a), ac)


def test_imag(test_operator_obj):
    a = test_operator_obj.a
    ac = a.imag
    sequence_assert_allclose(BlockArray(a_i.imag for a_i in a), ac)


def test_ndim(test_operator_obj):
    # BlockArray.ndim is a tuple with one entry per block
    assert test_operator_obj.a.ndim == (2, 3)
    assert test_operator_obj.c.ndim == (2,)


def test_getitem(test_operator_obj):
    # make a length-4 blockarray
    a0 = test_operator_obj.a0
    a1 = test_operator_obj.a1
    b0 = test_operator_obj.b0
    b1 = test_operator_obj.b1
    x = BlockArray([a0, a1, b0, b1])

    # positive indexing
    np.testing.assert_allclose(x[0], a0)
    np.testing.assert_allclose(x[1], a1)
    np.testing.assert_allclose(x[2], b0)
    np.testing.assert_allclose(x[3], b1)

    # negative indexing
    np.testing.assert_allclose(x[-4], a0)
    np.testing.assert_allclose(x[-3], a1)
    np.testing.assert_allclose(x[-2], b0)
    np.testing.assert_allclose(x[-1], b1)


def test_split(test_operator_obj):
    a = test_operator_obj.a
    np.testing.assert_allclose(a[0], test_operator_obj.a0)
    np.testing.assert_allclose(a[1], test_operator_obj.a1)


def test_blockarray_from_one_array():
    # BlockArray(np.random.randn(3, 6)) makes a block array
    # with 3 length-6 blocks (one block per row)
    x = BlockArray(np.random.randn(3, 6))
    assert len(x) == 3


@pytest.mark.parametrize("axis", [None, 1])
@pytest.mark.parametrize("keepdims", [True, False])
def test_sum_method(test_operator_obj, axis, keepdims):
    a = test_operator_obj.a

    # the BlockArray.sum method must agree with the snp.sum function
    method_result = a.sum(axis=axis, keepdims=keepdims)
    snp_result = snp.sum(a, axis=axis, keepdims=keepdims)

    sequence_assert_allclose(method_result, snp_result)


def test_eval_shape_1arg(test_operator_obj):
    def foo(x):
        return snp.atleast_3d(x)

    x = test_operator_obj.a
    # abstract evaluation on a shape/dtype representation must match concrete evaluation
    es = jax.eval_shape(foo, shape_dtype_rep(x.shape, x.dtype))
    ba = foo(x)
    assert es.shape == ba.shape
    assert es.dtype == ba.dtype


def test_eval_shape_2arg(test_operator_obj):
    def foo(x, y):
        return x * y

    x = test_operator_obj.a
    y = test_operator_obj.b
    args = [
        BlockArray([jax.ShapeDtypeStruct(b_i.shape, b_i.dtype) for b_i in x]),
        BlockArray([jax.ShapeDtypeStruct(b_i.shape, b_i.dtype) for b_i in y]),
    ]
    es = jax.eval_shape(foo, *args)
    assert es.shape == x.shape
    assert es.dtype == x.dtype


def test_linear_transpose(test_operator_obj):
    fun = lambda x: 2 * x
    x = test_operator_obj.a
    # transposing against a concrete BlockArray and against its shape/dtype rep must agree
    tfun_ba = jax.linear_transpose(fun, x)
    tfun_dts = jax.linear_transpose(fun, shape_dtype_rep(x.shape, x.dtype))
    assert tfun_ba.args == tfun_dts.args


@pytest.mark.parametrize("operator", [snp.dot, snp.matmul])
def test_ba_ba_dot(test_operator_obj, operator):
    a = test_operator_obj.a
    d = test_operator_obj.d
    a0 = test_operator_obj.a0
    a1 = test_operator_obj.a1
    d0 = test_operator_obj.d0
    d1 = test_operator_obj.d1

    x = operator(a, d)
    y = BlockArray([operator(a0, d0), operator(a1, d1)])
    sequence_assert_allclose(x, y)


# reduction tests
reduction_funcs = [
    snp.sum,
    snp.linalg.norm,
]

# reductions that are only tested with real (float32) inputs; currently none
real_reduction_funcs = []


class BlockArrayReductionObj:
    """Random BlockArray inputs for the reduction tests, of the given dtype.

    ``a`` has blocks of shape (2, 3) and (2, 3, 4), ``b`` has two (2, 3)
    blocks, and ``c`` has blocks of shape (2, 3) and (3,).
    """

    def __init__(self, dtype):
        # NOTE(review): ``key`` is assigned but never used below.
        key = None

        a0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)
        a1 = make_arbitrary_jax_array(shape=(2, 3, 4), dtype=dtype)
        b0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)
        b1 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)
        c0 = make_arbitrary_jax_array(shape=(2, 3), dtype=dtype)
        c1 = make_arbitrary_jax_array(shape=(3,), dtype=dtype)
        self.a = BlockArray((a0, a1))
        self.b = BlockArray((b0, b1))
        self.c = BlockArray((c0, c1))


@pytest.fixture(scope="module")  # so that random objects are cached
def reduction_obj(request):
    # request.param is the dtype, supplied indirectly via REDUCTION_PARAMS
    yield BlockArrayReductionObj(request.param)


# (dtype, reduction function) pairs; the dtype is routed to the reduction_obj fixture
REDUCTION_PARAMS = dict(
    argnames="reduction_obj, func",
    argvalues=(
        list(zip(itertools.repeat(np.float32), reduction_funcs))
        + list(zip(itertools.repeat(np.complex64), reduction_funcs))
        + list(zip(itertools.repeat(np.float32), real_reduction_funcs))
    ),
    indirect=["reduction_obj"],
)


@pytest.mark.parametrize(**REDUCTION_PARAMS)
def test_reduce(reduction_obj, func):
    x = func(reduction_obj.a)
    x_jit = jax.jit(func)(reduction_obj.a)
    # a full reduction of a BlockArray must equal the reduction of its flattened data
    y = func(snp.ravel(reduction_obj.a))
    np.testing.assert_allclose(x, x_jit, atol=1e-6)  # test jitted function
    np.testing.assert_allclose(x, y, atol=1e-6)  # test for correctness


@pytest.mark.parametrize(**REDUCTION_PARAMS)
@pytest.mark.parametrize("axis", (0, 1))
def test_reduce_axis(reduction_obj, func, axis):
    f = lambda x: func(x, axis=axis)
    x = f(reduction_obj.a)
    x_jit = jax.jit(f)(reduction_obj.a)

    sequence_assert_allclose(x, x_jit, rtol=1e-4)  # test jitted function

    # test for correctness
    y0 = func(reduction_obj.a[0], axis=axis)
    y1 = func(reduction_obj.a[1], axis=axis)
    y = BlockArray((y0, y1))
    sequence_assert_allclose(x, y)


@pytest.mark.parametrize(**REDUCTION_PARAMS)
def test_reduce_singleton(reduction_obj, func):
    # Case where one block is reduced to a singleton
    f = lambda x: func(x, axis=0)
    x = f(reduction_obj.c)
    x_jit = jax.jit(f)(reduction_obj.c)

    sequence_assert_allclose(x, x_jit, rtol=1e-4)  # test jitted function

    y0 = func(reduction_obj.c[0], axis=0)
    y1 = func(reduction_obj.c[1], axis=0)[None]  # Ensure size (1,)
    y = BlockArray((y0, y1))
    sequence_assert_allclose(x, y)


class TestCreators:
    """Tests for the snp array-creation functions with nested (BlockArray) shapes."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.a_shape = (2, 3)
        self.b_shape = (2, 4, 3)
        self.c_shape = (1,)
        self.shape = (self.a_shape, self.b_shape, self.c_shape)
        # NOTE(review): self.size is not read by any test in this class.
        self.size = np.prod(self.a_shape) + np.prod(self.b_shape) + np.prod(self.c_shape)

    def test_zeros(self):
        x = snp.zeros(self.shape, dtype=np.float32)
        assert x.shape == self.shape
        assert snp.all(x == 0)

    def test_empty(self):
        x = snp.empty(self.shape, dtype=np.float32)
        assert x.shape == self.shape
        # this assertion requires snp.empty to return zero-filled blocks
        assert snp.all(x == 0)

    def test_ones(self):
        x = snp.ones(self.shape, dtype=np.float32)
        assert x.shape == self.shape
        assert snp.all(x == 1)

    def test_full(self):
        fill_value = np.float32(np.random.randn())
        x = snp.full(self.shape, fill_value=fill_value, dtype=np.float32)
        assert x.shape == self.shape
        assert x.dtype == np.float32
        assert snp.all(x == fill_value)

    def test_full_nodtype(self):
        fill_value = np.float32(np.random.randn())
        x = snp.full(self.shape, fill_value=fill_value, dtype=None)
        assert x.shape == self.shape
        # with dtype=None the result dtype follows the fill value
        assert x.dtype == fill_value.dtype
        assert snp.all(x == fill_value)


def test_list_triggering():
    # a list of devices yields a BlockArray with one (3, 3) block per device
    device_list = 4 * [jax.devices()[0]]
    ba = snp.ones((3, 3), device=device_list)
    assert isinstance(ba, BlockArray)
    assert ba.shape == 4 * ((3, 3),)


# testing function tests
@pytest.mark.parametrize("func", testing_functions)
def test_test_func(func):
    a = snp.array([1.0, 2.0])
    b = snp.blockarray((a, a))
    f = rgetattr(snp, func)
    retval = f(b, b)
    assert retval is None


# tests added for the BlockArray refactor
@pytest.fixture
def x():
    # any BlockArray, arbitrary shape, content, type
    return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]])


@pytest.fixture
def y():
    # another BlockArray, content, type, shape matches x
    return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]])


@pytest.mark.parametrize("op", [op.neg, op.pos, op.abs])
def test_unary(op, x):
    actual = op(x)
    expected = BlockArray(op(x_i) for x_i in x)
    assert_array_equal(actual, expected)
    assert actual.dtype == expected.dtype


@pytest.mark.parametrize(
    "op",
    [
        op.mul,
        op.mod,
        op.lt,
        op.le,
        op.gt,
        op.ge,
        op.floordiv,
        op.eq,
        op.add,
        op.truediv,
        op.sub,
        op.ne,
    ],
)
def test_elementwise_binary(op, x, y):
    actual = op(x, y)
    expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y))
    assert_array_equal(actual, expected)
    assert actual.dtype == expected.dtype


def test_not_implemented_binary(x):
    with pytest.raises(TypeError, match=r"unsupported operand type\(s\)"):
        y = x + "a string"


def test_matmul(x):
    # x is ((2, 3), (1,))
    # y is ((3, 1), (1, 2))
    y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]])
    actual = x @ y
    expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]])
    assert_array_equal(actual, expected)
    assert actual.dtype == expected.dtype


def test_property():
    x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0]))
    actual = x.shape
    expected = ((2, 3), (1,))
    assert actual == expected


def test_method():
    # array methods are applied block by block
    x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]))
    actual = x.max()
    expected = BlockArray([[3.0], [42.0]])
    assert_array_equal(actual, expected)
    assert actual.dtype == expected.dtype


def test_stack():
    x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
    assert x.stack().shape == (2, 3)
    assert x.stack(axis=1).shape == (3, 2)

    # blocks of different shapes cannot be stacked
    y = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0]]))
    with pytest.raises(ValueError):
        z = y.stack()


def test_ravel():
    # snp.ravel completely flattens a BlockArray
    ba = snp.ones([[2, 3], [3, 4]])
    assert snp.ravel(ba).shape == (2 * 3 + 3 * 4,)

    # snp.ravel also flattens an Array
    arr = snp.ones((2, 3))
    assert snp.ravel(arr).shape == (2 * 3,)

    # ba.flatten maps over BlockArray blocks
    assert ba.flatten().shape == ((2 * 3,), (3 * 4,))

    # ba.ravel also maps over BlockArray blocks
    assert ba.ravel().shape == ((2 * 3,), (3 * 4,))

    # snp.ravel works with scalar blocks
    # fmt: off
    scalar_ba = snp.ones(
        [
            [],
            [1,],
            [1, 1],
        ]
    )
    # fmt: on
    assert_array_equal(snp.ravel(scalar_ba), [1, 1, 1])
================================================
FILE: scico/test/numpy/test_numpy.py
================================================
import numpy as np

import jax

import pytest

import scico.numpy as snp
from scico.numpy import _wrappers


def on_cpu():
    """Return True if the first jax device reports device_kind "cpu"."""
    return jax.devices()[0].device_kind == "cpu"


def check_results(jout, sout):
    """Assert that two function outputs agree to rtol=1e-4.

    Handles tuples/lists of outputs (compared pairwise), single jax
    arrays, and scalar (shape ``()``) outputs.

    Raises:
        TypeError: If the pair of outputs matches none of those cases.
    """
    if isinstance(jout, (tuple, list)) and isinstance(sout, (tuple, list)):
        # multiple outputs from the function
        for x, y in zip(jout, sout):
            np.testing.assert_allclose(x, y, rtol=1e-4)
    elif isinstance(jout, jax.Array) and isinstance(sout, jax.Array):
        # single array output from the function
        np.testing.assert_allclose(sout, jout, rtol=1e-4)
    elif jout.shape == () and sout.shape == ():
        # single scalar output from the function
        np.testing.assert_allclose(jout, sout, rtol=1e-4)
    else:
        # some type of output that isn't being captured?
        # NOTE(review): the message says "input type" but the values checked are outputs.
        raise TypeError(f"Unexpected input type {type(jout)} or {type(sout)}.")


def test_reshape_array():
    a = np.random.randn(4, 4)
    np.testing.assert_allclose(snp.reshape(a.ravel(), (4, 4)), a)


def test_ufunc_abs():
    A = snp.array([-1, 2, 5])
    res = snp.array([1, 2, 5])
    np.testing.assert_allclose(snp.abs(A), res)

    A = snp.array([-1, -1, -1])
    res = snp.array([1, 1, 1])
    np.testing.assert_allclose(snp.abs(A), res)

    Ba = snp.blockarray((snp.array([-1, 2, 5]),))
    res = snp.blockarray((snp.array([1, 2, 5]),))
    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())

    Ba = snp.blockarray((snp.array([-1, -1, -1]),))
    res = snp.blockarray((snp.array([1, 1, 1]),))
    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())

    Ba = snp.blockarray((snp.array([-1, 2, -3]), snp.array([1, -2, 3])))
    res = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2, 3])))
    np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel())


def test_ufunc_maximum():
    A = snp.array([1, 2, 5])
    B = snp.array([2, 3, 4])
    res = snp.array([2, 3, 5])
    np.testing.assert_allclose(snp.maximum(A, B), res)
    np.testing.assert_allclose(snp.maximum(B, A), res)

    A = snp.array([1, 1, 1])
    B = snp.array([2, 2, 2])
    res = snp.array([2, 2, 2])
    np.testing.assert_allclose(snp.maximum(A, B), res)
    np.testing.assert_allclose(snp.maximum(B, A), res)

    # scalar vs array
    A = 4
    B = snp.array([3, 5, 2])
    res = snp.array([4, 5, 4])
    np.testing.assert_allclose(snp.maximum(A, B), res)
    np.testing.assert_allclose(snp.maximum(B, A), res)

    # scalar vs scalar
    A = 5
    B = 6
    res = 6
    np.testing.assert_allclose(snp.maximum(A, B), res)
    np.testing.assert_allclose(snp.maximum(B, A), res)

    # BlockArray vs BlockArray
    A = snp.array([1, 2, 3])
    B = snp.array([2, 3, 4])
    C = snp.array([5, 6])
    D = snp.array([2, 7])
    Ba = snp.blockarray((A, C))
    Bb = snp.blockarray((B, D))
    res = snp.blockarray((snp.array([2, 3, 4]), snp.array([5, 7])))
    Bmax = snp.maximum(Ba, Bb)
    snp.testing.assert_allclose(Bmax, res)

    # BlockArray vs scalar
    A = snp.array([1, 6, 3])
    B = snp.array([6, 3, 8])
    C = 5
    Ba = snp.blockarray((A, B))
    res = snp.blockarray((snp.array([5, 6, 5]), snp.array([6, 5, 8])))
    Bmax = snp.maximum(Ba, C)
    snp.testing.assert_allclose(Bmax, res)


def test_ufunc_sign():
    A = snp.array([10, -5, 0])
    res = snp.array([1, -1, 0])
    np.testing.assert_allclose(snp.sign(A), res)

    Ba = snp.blockarray((snp.array([10, -5, 0]),))
    res = snp.blockarray((snp.array([1, -1, 0]),))
    snp.testing.assert_allclose(snp.sign(Ba), res)

    Ba = snp.blockarray((snp.array([10, -5, 0]), snp.array([0, 5, -6])))
    res = snp.blockarray((snp.array([1, -1, 0]), snp.array([0, 1, -1])))
    snp.testing.assert_allclose(snp.sign(Ba), res)


def test_ufunc_where():
    A = snp.array([1, 2, 4, 5])
    B = snp.array([-1, -1, -1, -1])
    cond = snp.array([False, False, True, True])
    res = snp.array([-1, -1, 4, 5])
    np.testing.assert_allclose(snp.where(cond, A, B), res)

    Ba = snp.blockarray((snp.array([1, 2, 4, 5]),))
    Bb = snp.blockarray((snp.array([-1, -1, -1, -1]),))
    Bcond = snp.blockarray((snp.array([False, False, True, True]),))
    Bres = snp.blockarray((snp.array([-1, -1, 4, 5]),))
    assert snp.where(Bcond, Ba, Bb).shape == Bres.shape
    np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel())

    Ba = snp.blockarray((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5])))
    Bb = snp.blockarray((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1])))
    Bcond = snp.blockarray(
        (snp.array([False, False, True, True]), snp.array([True, True, False, False]))
    )
    Bres = snp.blockarray((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1])))
    assert snp.where(Bcond, Ba, Bb).shape == Bres.shape
    np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel())


def test_ufunc_true_divide():
    A = snp.array([1, 2, 3])
    B = snp.array([3, 3, 3])
    res = snp.array([0.33333333, 0.66666667, 1.0])
    np.testing.assert_allclose(snp.true_divide(A, B), res)

    A = snp.array([1, 2, 3])
    B = 3
    res = snp.array([0.33333333, 0.66666667, 1.0])
    np.testing.assert_allclose(snp.true_divide(A, B), res)

    Ba = snp.blockarray((snp.array([1, 2, 3]),))
    Bb = snp.blockarray((snp.array([3, 3, 3]),))
    res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]),))
    snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res)

    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))
    Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2])))
    res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0])))
    snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res)

    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))
    A = 2
    res = snp.blockarray((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0])))
    snp.testing.assert_allclose(snp.true_divide(Ba, A), res)


def test_ufunc_floor_divide():
    A = snp.array([1, 2, 3])
    B = snp.array([3, 3, 3])
    res = snp.array([0, 0, 1.0])
    np.testing.assert_allclose(snp.floor_divide(A, B), res)

    A = snp.array([4, 2, 3])
    B = 3
    res = snp.array([1.0, 0, 1.0])
    np.testing.assert_allclose(snp.floor_divide(A, B), res)

    Ba = snp.blockarray((snp.array([1, 2, 3]),))
    Bb = snp.blockarray((snp.array([3, 3, 3]),))
    res = snp.blockarray((snp.array([0, 0, 1.0]),))
    snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res)

    Ba = snp.blockarray((snp.array([1, 7, 3]), snp.array([1, 2])))
    Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2])))
    res = snp.blockarray((snp.array([0, 2, 1.0]), snp.array([0, 1.0])))
    snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res)

    Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2])))
    A = 2
    res = snp.blockarray((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0])))
    snp.testing.assert_allclose(snp.floor_divide(Ba, A), res)


def test_ufunc_real():
    A = snp.array([1 + 3j])
    res = snp.array([1])
    np.testing.assert_allclose(snp.real(A), res)

    A = snp.array([1 + 3j, 4.0 + 2j])
    res = snp.array([1, 4.0])
    np.testing.assert_allclose(snp.real(A), res)

    Ba = snp.blockarray((snp.array([1 + 3j]),))
    res = snp.blockarray((snp.array([1]),))
    snp.testing.assert_allclose(snp.real(Ba), res)

    Ba = snp.blockarray((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0])))
    res = snp.blockarray((snp.array([1.0]), snp.array([1, 4.0])))
    snp.testing.assert_allclose(snp.real(Ba), res)


def test_ufunc_imag():
    A = snp.array([1 + 3j])
    res = snp.array([3])
    np.testing.assert_allclose(snp.imag(A), res)

    A = snp.array([1 + 3j, 4.0 + 2j])
    res = snp.array([3, 2])
    np.testing.assert_allclose(snp.imag(A), res)

    Ba = snp.blockarray((snp.array([1 + 3j]),))
    res = snp.blockarray((snp.array([3]),))
    snp.testing.assert_allclose(snp.imag(Ba), res)

    Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0])))
    res = snp.blockarray((snp.array([3]), snp.array([3, 0])))
    snp.testing.assert_allclose(snp.imag(Ba), res)


def test_ufunc_conj():
    A = snp.array([1 + 3j])
    res = snp.array([1 - 3j])
    np.testing.assert_allclose(snp.conj(A), res)

    A = snp.array([1 + 3j, 4.0 + 2j])
    res = snp.array([1 - 3j, 4.0 - 2j])
    np.testing.assert_allclose(snp.conj(A), res)

    Ba = snp.blockarray((snp.array([1 + 3j]),))
    res = snp.blockarray((snp.array([1 - 3j]),))
    snp.testing.assert_allclose(snp.conj(Ba), res)

    Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0])))
    res = snp.blockarray((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j])))
    snp.testing.assert_allclose(snp.conj(Ba), res)


def test_create_zeros():
    A = snp.zeros(2)
    assert np.all(A == 0)
    assert isinstance(A, jax.Array)

    A = snp.zeros((2,))
    assert isinstance(A, jax.Array)

    # a nested shape produces a BlockArray
    A = snp.zeros(((2,), (2,)))
    assert snp.all(A == 0)
    assert isinstance(A, snp.BlockArray)

    A = snp.zeros(())
    assert isinstance(A, jax.Array)  # from issue 499


def test_create_ones():
    A = snp.ones(2, dtype=np.float32)
    assert np.all(A == 1)

    A = snp.ones(((2,), (2,)))
    assert snp.all(A == 1)


def test_create_empty():
    # these assertions require snp.empty to return zero-filled arrays
    A = snp.empty(2)
    assert np.all(A == 0)

    A = snp.empty(((2,), (2,)))
    assert snp.all(A == 0)


def test_create_full():
    A = snp.full((2,), 1)
    assert np.all(A == 1)

    A = snp.full((2,), 1, dtype=np.float32)
    assert np.all(A == 1)

    A = snp.full(((2,), (2,)), 1)
    assert snp.all(A == 1)


def test_create_zeros_like():
    A = snp.ones(2, dtype=np.float32)
    B = snp.zeros_like(A)
    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype

    # NOTE(review): this stanza is identical to the one above; confirm whether a
    # different input (e.g. another shape or dtype) was intended.
    A = snp.ones(2, dtype=np.float32)
    B = snp.zeros_like(A)
    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype

    A = snp.ones(((2,), (2,)), dtype=np.float32)
    B = snp.zeros_like(A)
    assert snp.all(B == 0)
    assert A.shape == B.shape
    assert A.dtype == B.dtype


def test_create_empty_like():
    A = snp.ones(2, dtype=np.float32)
    B = snp.empty_like(A)
    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype

    # NOTE(review): this stanza is identical to the one above; confirm intent.
    A = snp.ones(2, dtype=np.float32)
    B = snp.empty_like(A)
    assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype

    A = snp.ones(((2,), (2,)), dtype=np.float32)
    B = snp.empty_like(A)
    assert snp.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype


def test_create_ones_like():
    A = snp.zeros(2, dtype=np.float32)
    B = snp.ones_like(A)
    assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype

    # NOTE(review): this stanza is identical to the one above; confirm intent.
    A = snp.zeros(2, dtype=np.float32)
    B = snp.ones_like(A)
    assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype

    A = snp.zeros(((2,), (2,)), dtype=np.float32)
    B = snp.ones_like(A)
    assert snp.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype


def test_create_full_like():
    # float fill value
    A = snp.zeros(2, dtype=np.float32)
    B = snp.full_like(A, 1.0)
    assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)

    # int fill value: the result dtype still follows A
    A = snp.zeros(2, dtype=np.float32)
    B = snp.full_like(A, 1)
    assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)

    A = snp.zeros(((2,), (2,)), dtype=np.float32)
    B = snp.full_like(A, 1)
    assert snp.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype)


def test_wrap_recursively():
    target_dict = {"a": 1, "b": 2}
    # "c" is not a key of target_dict, which is expected to trigger a warning
    names = ["a", "c"]
    wrap = lambda x: x + 1

    with pytest.warns(Warning):
        _wrappers.wrap_recursively(target_dict, names, wrap)


def test_add_full_reduction():
    with pytest.raises(ValueError):
        _wrappers.add_full_reduction(np.sum, axis_arg_name="not_axis")
================================================
FILE: scico/test/numpy/test_numpy_util.py
================================================
import collections

import numpy as np

import jax.numpy as jnp

import pytest

import scico.numpy as snp
from scico.numpy.util import (
    array_info,
    array_to_namedtuple,
    complex_dtype,
    dtype_name,
    indexed_shape,
    is_blockable,
    is_collapsible,
    is_complex_dtype,
    is_nested,
    is_real_dtype,
    is_scalar_equiv,
    jax_indexed_shape,
    namedtuple_to_array,
    no_nan_divide,
    normalize_axes,
    real_dtype,
    shape_dtype_rep,
    slice_length,
    transpose_list_of_ntpl,
    transpose_ntpl_of_list,
)
from scico.random import randn


def test_ntpl_list_transpose():
    # transposing list-of-namedtuple -> namedtuple-of-list and back is an identity
    nt = collections.namedtuple("NT", ("a", "b", "c"))
    ntlist0 = [nt(0, 1, 2), nt(3, 4, 5)]
    listnt = transpose_list_of_ntpl(ntlist0)
    ntlist1 = transpose_ntpl_of_list(listnt)
    assert ntlist0[0] == ntlist1[0]
    assert ntlist0[1] == ntlist1[1]


def test_namedtuple_to_array():
    # namedtuple -> array -> namedtuple round trip
    nt = collections.namedtuple("NT", ("A", "B", "C"))
    t0 = nt(0, 1, 2)
    t0a = namedtuple_to_array(t0)
    t1 = array_to_namedtuple(t0a)
    assert t0 == t1


def test_no_nan_divide_array():
    x, key = randn((4,), dtype=np.float32)
    y, key = randn(x.shape, dtype=np.float32, key=key)
    y = y.at[0].set(0)

    res = no_nan_divide(x, y)

    # division by zero yields 0; elsewhere ordinary division
    assert res[0] == 0
    idx = y != 0
    np.testing.assert_allclose(res[idx], x[idx] / y[idx])


def test_no_nan_divide_blockarray():
    x, key = randn(((3, 3), (4,)), dtype=np.float32)

    y, key = randn(x.shape, dtype=np.float32, key=key)
    # zero out the whole second block
    y[1] = y[1].at[:].set(0 * y[1])

    res = no_nan_divide(x, y)

    assert snp.all(res[1] == 0.0)
    np.testing.assert_allclose(res[0], x[0] / y[0])


def test_array_info():
    x = np.array([0.0, 0.1])
    xinfo = array_info(x)
    assert "numpy.ndarray" in xinfo
    x = jnp.array([0.0, 0.1])
    xinfo = array_info(x)
    assert "jax.Array" in xinfo
    x = snp.ones(((2, 3), (2,)))
    xinfo = array_info(x)
    assert "scico.numpy.BlockArray" in xinfo


def test_normalize_axes():
    # axes=None without a shape is an error
    axes = None
    np.testing.assert_raises(ValueError, normalize_axes, axes)

    axes = None
    assert normalize_axes(axes, np.shape([[1, 1], [1, 1]])) == (0, 1)

    axes = None
    assert normalize_axes(axes, np.shape([[1, 1], [1, 1]]), default=[0]) == [0]

    axes = [1, 2]
    assert normalize_axes(axes) == axes

    axes = 1
    assert normalize_axes(axes) == (1,)

    axes = (-1,)
    assert normalize_axes(axes, shape=(1, 2)) == (1,)

    axes = (0, 2, 1)
    assert normalize_axes(axes, shape=(2, 3, 4), sort=True) == (0, 1, 2)

    axes = "axes"
    np.testing.assert_raises(ValueError, normalize_axes, axes)

    axes = 2
    np.testing.assert_raises(ValueError, normalize_axes, axes, np.shape([1]))

    # repeated axes are rejected
    axes = (1, 2, 2)
    np.testing.assert_raises(ValueError, normalize_axes, axes)


@pytest.mark.parametrize("length", (4, 5, 8, 16, 17))
@pytest.mark.parametrize("start", (None, 0, 1, 2, 3))
@pytest.mark.parametrize("stop", (None, 0, 1, 2, -2, -1))
@pytest.mark.parametrize("stride", (None, 1, 2, 3))
def test_slice_length(length, start, stop, stride):
    # slice_length must agree with the size numpy produces for the same slice
    x = np.zeros(length)
    slc = slice(start, stop, stride)
    assert x[slc].size == slice_length(length, slc)


@pytest.mark.parametrize("length", (4, 5))
@pytest.mark.parametrize("slc", (0, 1, -4, Ellipsis))
def test_slice_length_other(length, slc):
    x = np.zeros(length)
    if isinstance(slc, int):
        # integer indices are expected to give None
        assert slice_length(length, slc) is None
    else:
        assert x[slc].size == slice_length(length, slc)
@pytest.mark.parametrize("shape", ((8, 8, 1), (7, 1, 6, 5))) @pytest.mark.parametrize( "slc", ( np.s_[0], np.s_[0:5], np.s_[:, 0:4], np.s_[2:, :, :-2], np.s_[..., 2:], np.s_[..., 2:, :], np.s_[1:, ..., 2:], np.s_[np.newaxis], np.s_[:, np.newaxis], np.s_[np.newaxis, :, np.newaxis], np.s_[np.newaxis, ..., 0:2, :], ), ) def test_indexed_shape(shape, slc): x = np.zeros(shape) assert x[slc].shape == indexed_shape(shape, slc) assert x[slc].shape == jax_indexed_shape(shape, slc) def test_is_nested(): # list assert is_nested([1, 2, 3]) == False # tuple assert is_nested((1, 2, 3)) == False # list of lists assert is_nested([[1, 2], [4, 5], [3]]) == True # list of lists + scalar assert is_nested([[1, 2], 3]) == True # list of tuple + scalar assert is_nested([(1, 2), 3]) == True # tuple of tuple + scalar assert is_nested(((1, 2), 3)) == True # tuple of lists + scalar assert is_nested(([1, 2], 3)) == True def test_is_collapsible(): shape1 = ((1, 2, 3), (1, 2, 3), (1, 3, 3)) shape2 = ((1, 2, 3), (1, 2, 3), (1, 2, 3)) assert not is_collapsible(shape1) assert is_collapsible(shape2) def test_is_blockable(): shape1 = ((1, 2, 3), (1, 2, 3), (1, 2, 3)) shape2 = ((1, 2, 3), ((1, 2, 3), (1, 2, 3))) assert is_blockable(shape1) assert not is_blockable(shape2) @pytest.mark.parametrize("shape", [(3, 4), ((3, 4), (5,))]) def test_shape_dtype_rep(shape): dtype = np.float32 assert shape_dtype_rep(shape, dtype).shape == shape def test_is_real_dtype(): assert not is_real_dtype(snp.complex64) assert is_real_dtype(snp.float32) def test_is_complex_dtype(): assert is_complex_dtype(snp.complex64) assert not is_complex_dtype(snp.float32) def test_real_dtype(): assert real_dtype(snp.complex64) == snp.float32 def test_complex_dtype(): assert complex_dtype(snp.float32) == snp.complex64 def test_dtype_name(): assert dtype_name(np.float32) == "numpy.float32" assert dtype_name(snp.float32) == "jax.numpy.float32" def test_broadcast_nested_shapes(): # unnested should work as usual assert 
snp.util.broadcast_nested_shapes((1, 3, 4, 7), (3, 1, 7)) == (1, 3, 4, 7) # nested + unested assert snp.util.broadcast_nested_shapes(((2, 3), (1, 1, 3)), (2, 3)) == ((2, 3), (1, 2, 3)) # unested + nested assert snp.util.broadcast_nested_shapes((1, 1, 3), ((2, 3), (7, 3))) == ((1, 2, 3), (1, 7, 3)) # nested + nested snp.util.broadcast_nested_shapes(((1, 1, 3), (1, 7, 1, 3)), ((2, 3), (7, 4, 3))) == ( (1, 2, 3), (1, 7, 4, 3), ) def test_is_scalar_equiv(): assert is_scalar_equiv(1e0) assert is_scalar_equiv(snp.array(1e0)) assert is_scalar_equiv(snp.sum(snp.zeros(1))) assert not is_scalar_equiv(snp.array([1e0])) assert not is_scalar_equiv(snp.array([1e0, 2e0])) ================================================ FILE: scico/test/operator/test_biconvolve.py ================================================ import numpy as np import jax import jax.scipy.signal as signal import pytest from scico.linop import Convolve, ConvolveByX from scico.numpy import blockarray from scico.operator.biconvolve import BiConvolve from scico.random import randn class TestBiConvolve: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("mode", ["full", "valid", "same"]) @pytest.mark.parametrize("jit", [False, True]) def test_eval(self, input_dtype, mode, jit): x, key = randn((32, 32), dtype=input_dtype, key=self.key) h, key = randn((4, 4), dtype=input_dtype, key=self.key) x_h = blockarray([x, h]) A = BiConvolve(input_shape=x_h.shape, mode=mode, jit=jit) signal_out = signal.convolve(x, h, mode=mode) np.testing.assert_allclose(A(x_h), signal_out, rtol=1e-4) # Test freezing A_x = A.freeze(0, x) assert isinstance(A_x, ConvolveByX) np.testing.assert_allclose(A_x(h), signal_out, rtol=1e-4) A_h = A.freeze(1, h) assert isinstance(A_h, Convolve) np.testing.assert_allclose(A_h(x), signal_out, rtol=1e-4) with pytest.raises(ValueError): A.freeze(2, x) def test_invalid_shapes(self): with 
pytest.raises(ValueError): A = BiConvolve(input_shape=(2, 2)) with pytest.raises(ValueError): shape = ((2, 2), (3, 3), (4, 4)) # 3 blocks A = BiConvolve(input_shape=shape) with pytest.raises(ValueError): shape = ((2, 2), (3,)) # 3 blocks A = BiConvolve(input_shape=shape) ================================================ FILE: scico/test/operator/test_op_stack.py ================================================ import numpy as np import jax import pytest import scico.numpy as snp from scico.operator import ( Abs, DiagonalReplicated, DiagonalStack, Operator, VerticalStack, ) from scico.random import randn TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x))) TestOpB = Operator( input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, x)) ) TestOpC = Operator( input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, 2 * x)) ) class TestVerticalStack: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("jit", [False, True]) def test_construct(self, jit): # requires a list of Operators A = Abs((42,)) with pytest.raises(TypeError): H = VerticalStack(A, jit=jit) # checks input sizes A = Abs((3, 2)) B = Abs((7, 2)) with pytest.raises(ValueError): H = VerticalStack([A, B], jit=jit) # in general, returns a BlockArray A = TestOpA B = TestOpB H = VerticalStack([A, B], jit=jit) x = np.ones((3, 4)) y = H(x) assert y.shape == ((2, 3, 4), (6, 4)) # ... result should be [A@x, B@x] assert np.allclose(y[0], A(x)) assert np.allclose(y[1], B(x)) # by default, collapse_output to jax array when possible A = TestOpB B = TestOpB H = VerticalStack([A, B], jit=jit) x = np.ones((3, 4)) y = H(x) assert y.shape == (2, 6, 4) # ... 
result should be [A@x, B@x] assert np.allclose(y[0], A(x)) assert np.allclose(y[1], B(x)) # let user turn off collapsing A = TestOpA B = TestOpA H = VerticalStack([A, B], collapse_output=False, jit=jit) x = np.ones((3, 4)) y = H(x) assert y.shape == ((2, 3, 4), (2, 3, 4)) @pytest.mark.parametrize("collapse_output", [False, True]) @pytest.mark.parametrize("jit", [False, True]) def test_algebra(self, collapse_output, jit): # adding A = TestOpB B = TestOpB H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) A = TestOpC B = TestOpC G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) x = np.ones((3, 4)) S = H + G # test correctness of addition assert S.output_shape == H.output_shape assert S.input_shape == H.input_shape np.testing.assert_allclose((S(x))[0], (H(x) + G(x))[0]) np.testing.assert_allclose((S(x))[1], (H(x) + G(x))[1]) class TestBlockDiagonalOperator: def test_construct(self): # requires a list of Operators A = Abs((8,)) with pytest.raises(TypeError): H = VerticalStack(A) # no nested output shapes A = Abs(((8,), (10,))) with pytest.raises(ValueError): H = VerticalStack((A, A)) # output dtypes must be the same A = Abs(input_shape=(8,), input_dtype=snp.float32) B = Abs(input_shape=(8,), input_dtype=snp.int32) with pytest.raises(ValueError): H = VerticalStack((A, B)) def test_apply(self): S1 = (3, 4) S2 = (3, 5) S3 = (2, 2) A1 = Abs(S1) A2 = 2 * Abs(S2) A3 = Abs(S3) H = DiagonalStack((A1, A2, A3)) x = snp.ones((S1, S2, S3)) y = H(x) y_expected = snp.blockarray((snp.ones(S1), 2 * snp.ones(S2), snp.sum(snp.ones(S3)))) np.testing.assert_equal(y, y_expected) def test_input_collapse(self): S = (3, 4) A1 = TestOpA A2 = TestOpB H = DiagonalStack((A1, A2)) assert H.input_shape == (2, *S) H = DiagonalStack((A1, A2), collapse_input=False) assert H.input_shape == (S, S) def test_output_collapse(self): A1 = TestOpB A2 = TestOpC H = DiagonalStack((A1, A2)) assert H.output_shape == (2, *A1.output_shape) H = DiagonalStack((A1, A2), 
collapse_output=False) assert H.output_shape == (A1.output_shape, A1.output_shape) class TestDiagonalReplicated: def setup_method(self, method): self.key = jax.random.key(12345) @pytest.mark.parametrize("map_type", ["auto", "vmap"]) @pytest.mark.parametrize("input_axis", [0, 1]) def test_map_auto_vmap(self, input_axis, map_type): x, key = randn((2, 3, 4), key=self.key) mapshape = (3, 4) if input_axis == 0 else (2, 4) replicates = x.shape[input_axis] A = Abs(mapshape) D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type) y = D(x) assert y.shape[input_axis] == replicates @pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test") def test_map_auto_pmap(self): x, key = randn((2, 3, 4), key=self.key) A = Abs(x.shape[1:]) replicates = x.shape[0] D = DiagonalReplicated(A, replicates, map_type="pmap") y = D(x) assert y.shape[0] == replicates def test_input_axis(self): # Ensure that operators can be stacked on final axis x, key = randn((2, 3, 4), key=self.key) A = Abs(x.shape[0:2]) replicates = x.shape[2] D = DiagonalReplicated(A, replicates, input_axis=2) y = D(x) assert y.shape == (2, 3, 4) D = DiagonalReplicated(A, replicates, input_axis=-1) y = D(x) assert y.shape == (2, 3, 4) def test_output_axis(self): x, key = randn((2, 3, 4), key=self.key) A = Abs(x.shape[1:]) replicates = x.shape[0] D = DiagonalReplicated(A, replicates, output_axis=1) y = D(x) assert y.shape == (3, 2, 4) ================================================ FILE: scico/test/operator/test_operator.py ================================================ import operator as op import numpy as np from jax import config import pytest # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) import jax import scico.numpy as snp from scico.operator import Abs, Angle, Exp, Operator, operator_from_function from scico.random import randn SCALARS = (2, 1e0, snp.array(1.0)) class AbsOperator(Operator): def _eval(self, x): return 
snp.sum(snp.abs(x)) class SquareOperator(Operator): def _eval(self, x): return x**2 class SumSquareOperator(Operator): def _eval(self, x): return snp.sum(x**2) class OperatorTestObj: def __init__(self, dtype): M, N = (32, 64) key = jax.random.key(12345) self.dtype = dtype self.A = AbsOperator(input_shape=(N,), input_dtype=dtype) self.B = SquareOperator(input_shape=(N,), input_dtype=dtype) self.S = SumSquareOperator(input_shape=(N,), input_dtype=dtype) self.mat = randn(self.A.input_shape, dtype=dtype, key=key) self.x, key = randn((N,), dtype=dtype, key=key) self.z, key = randn((2 * N,), dtype=dtype, key=key) @pytest.fixture(scope="module", params=[np.float32, np.float64, np.complex64, np.complex128]) def testobj(request): yield OperatorTestObj(request.param) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op(testobj, operator): # Our AbsOperator class does not override the __add__, etc # so AbsOperator + AbsMatOp -> Operator x = testobj.x # Composite operator comp_op = operator(testobj.A, testobj.S) # evaluate Operators separately, then add/sub res = operator(testobj.A(x), testobj.S(x)) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op_same(testobj, operator): x = testobj.x # Composite operator comp_op = operator(testobj.A, testobj.A) # evaluate Operators separately, then add/sub res = operator(testobj.A(x), testobj.A(x)) assert isinstance(comp_op, Operator) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) @pytest.mark.parametrize("scalar", SCALARS) def test_scalar_left(testobj, operator, scalar): x = testobj.x comp_op = operator(testobj.A, scalar) res = operator(testobj.A(x), scalar) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", 
[op.mul, op.truediv]) @pytest.mark.parametrize("scalar", SCALARS) def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / Operator is not supported") x = testobj.x comp_op = operator(scalar, testobj.A) res = operator(scalar, testobj.A(x)) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) def test_negation(testobj): x = testobj.x comp_op = -testobj.A res = -(testobj.A(x)) assert comp_op.input_dtype == testobj.A.input_dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_invalid_add_sub_array(testobj, operator): # Try to add or subtract an ndarray with Operator with pytest.raises(TypeError): operator(testobj.A, testobj.mat) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_invalid_add_sub_scalar(testobj, operator): # Try to add or subtract a scalar with AbsMatOp with pytest.raises(TypeError): operator(1.0, testobj.A) def test_call_operator_operator(testobj): x = testobj.x A = testobj.A B = testobj.B np.testing.assert_allclose(A(B)(x), A(B(x))) with pytest.raises(ValueError): # incompatible shapes A(testobj.S) def test_shape_call_vec(testobj): # evaluate operator on an array of incompatible size with pytest.raises(ValueError): testobj.A(testobj.z) def test_scale_vmap(testobj): A = testobj.A x = testobj.x def foo(c): return (c * A)(x) c_list = [1.0, 2.0, 3.0] non_vmap = np.array([foo(c) for c in c_list]) vmapped = jax.vmap(foo)(snp.array(c_list)) np.testing.assert_allclose(non_vmap, vmapped) def test_scale_pmap(testobj): A = testobj.A x = testobj.x def foo(c): return (c * A)(x) c_list = np.random.randn(jax.device_count()) non_pmap = np.array([foo(c) for c in c_list]) pmapped = jax.pmap(foo)(c_list) np.testing.assert_allclose(non_pmap, pmapped, rtol=1e-6) def test_freeze_3arg(): A = Operator( input_shape=((1, 3, 4), (2, 1, 4), (2, 3, 1)), eval_fn=lambda x: x[0] * x[1] * x[2] ) a, _ = randn((1, 3, 
4)) b, _ = randn((2, 1, 4)) c, _ = randn((2, 3, 1)) x = snp.blockarray([a, b, c]) Abc = A.freeze(0, a) # A as a function of b, c Aac = A.freeze(1, b) # A as a function of a, c Aab = A.freeze(2, c) # A as a function of a, b assert Abc.input_shape == ((2, 1, 4), (2, 3, 1)) assert Aac.input_shape == ((1, 3, 4), (2, 3, 1)) assert Aab.input_shape == ((1, 3, 4), (2, 1, 4)) bc = snp.blockarray([b, c]) ac = snp.blockarray([a, c]) ab = snp.blockarray([a, b]) np.testing.assert_allclose(A(x), Abc(bc), rtol=5e-4) np.testing.assert_allclose(A(x), Aac(ac), rtol=5e-4) np.testing.assert_allclose(A(x), Aab(ab), rtol=5e-4) def test_freeze_2arg(): A = Operator(input_shape=((1, 3, 4), (2, 1, 4)), eval_fn=lambda x: x[0] * x[1]) a, _ = randn((1, 3, 4)) b, _ = randn((2, 1, 4)) x = snp.blockarray([a, b]) Ab = A.freeze(0, a) # A as a function of 'b' only Aa = A.freeze(1, b) # A as a function of 'a' only assert Ab.input_shape == (2, 1, 4) assert Aa.input_shape == (1, 3, 4) np.testing.assert_allclose(A(x), Ab(b), rtol=5e-4) np.testing.assert_allclose(A(x), Aa(a), rtol=5e-4) @pytest.mark.parametrize("dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("op_fn", [(Abs, snp.abs), (Angle, snp.angle), (Exp, snp.exp)]) def test_func_op(op_fn, dtype): op = op_fn[0] fn = op_fn[1] shape = (2, 3) x, _ = randn(shape, dtype=dtype) H = op(input_shape=shape, input_dtype=dtype) np.testing.assert_array_equal(H(x), fn(x)) def test_make_func_op(): AbsVal = operator_from_function(snp.abs, "AbsVal") shape = (2,) x, _ = randn(shape, dtype=np.float32) H = AbsVal(input_shape=shape, input_dtype=np.float32) np.testing.assert_array_equal(H(x), snp.abs(x)) def test_make_func_op_ext_init(): AbsVal = operator_from_function(snp.abs, "AbsVal") shape = (2,) x, _ = randn(shape, dtype=np.float32) H = AbsVal( input_shape=shape, output_shape=shape, input_dtype=np.float32, output_dtype=np.float32 ) np.testing.assert_array_equal(H(x), snp.abs(x)) class TestJacobianProdReal: def setup_method(self): N = 7 M = 8 key = None 
dtype = snp.float32 self.fmx, key = randn((M, N), key=key, dtype=dtype) self.F = Operator( (N, 1), output_shape=(M, 1), eval_fn=lambda x: self.fmx @ x, input_dtype=dtype, output_dtype=dtype, ) self.u, key = randn((N, 1), key=key, dtype=dtype) self.v, key = randn((N, 1), key=key, dtype=dtype) self.w, key = randn((M, 1), key=key, dtype=dtype) def test_jvp(self): Fu, JFuv = self.F.jvp(self.u, self.v) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFuv, self.fmx @ self.v, atol=1e-6, rtol=0.0) def test_vjp_conj(self): Fu, G = self.F.vjp(self.u, conjugate=True) JFTw = G(self.w) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, atol=1e-6, rtol=0.0) def test_vjp_noconj(self): Fu, G = self.F.vjp(self.u, conjugate=False) JFTw = G(self.w) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, atol=1e-6, rtol=0.0) class TestJacobianProdComplex: def setup_method(self): N = 7 M = 8 key = None dtype = snp.complex64 self.fmx, key = randn((M, N), key=key, dtype=dtype) self.F = Operator( (N, 1), output_shape=(M, 1), eval_fn=lambda x: self.fmx @ x, input_dtype=dtype, output_dtype=dtype, ) self.u, key = randn((N, 1), key=key, dtype=dtype) self.v, key = randn((N, 1), key=key, dtype=dtype) self.w, key = randn((M, 1), key=key, dtype=dtype) def test_jvp(self): Fu, JFuv = self.F.jvp(self.u, self.v) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFuv, self.fmx @ self.v, rtol=1e-6) def test_vjp_conj(self): Fu, G = self.F.vjp(self.u, conjugate=True) JFTw = G(self.w) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFTw, self.fmx.T.conj() @ self.w, rtol=1e-6) def test_vjp_noconj(self): Fu, G = self.F.vjp(self.u, conjugate=False) JFTw = G(self.w) np.testing.assert_allclose(Fu, self.F(self.u)) np.testing.assert_allclose(JFTw, self.fmx.T @ self.w, rtol=1e-6) ================================================ FILE: 
scico/test/optimize/test_admm.py
================================================
import os
import tempfile

import numpy as np

import pytest

import scico.numpy as snp
from scico import functional, linop, loss, metric, operator, random
from scico.optimize import ADMM
from scico.optimize.admm import (
    CircularConvolveSolver,
    FBlockCircularConvolveSolver,
    G0BlockCircularConvolveSolver,
    GenericSubproblemSolver,
    LinearSubproblemSolver,
    MatrixSubproblemSolver,
)


class TestMisc:
    """ADMM construction, iteration statistics, callbacks, kwarg and NaN checks."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.y = snp.array(np.random.randn(16, 17).astype(np.float32))

    def test_admm(self):
        maxiter = 2
        ρ = 1e-1
        A = linop.Identity(self.y.shape)
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = functional.DnCNN()
        C = linop.Identity(self.y.shape)

        itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

        def itstat_func(obj):
            return (obj.itnum, obj.timer.elapsed())

        # Default iteration statistics have 6 fields; the default x starts at zero.
        admm_ = ADMM(
            f=f,
            g_list=[g],
            C_list=[C],
            rho_list=[ρ],
            maxiter=maxiter,
            itstat_options={"display": False},
        )
        assert len(admm_.itstat_object.fieldname) == 6
        assert snp.sum(admm_.x) == 0.0

        # Custom fields/function replace the default iteration statistics.
        admm_ = ADMM(
            f=f,
            g_list=[g],
            C_list=[C],
            rho_list=[ρ],
            maxiter=maxiter,
            itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
        )
        assert len(admm_.itstat_object.fieldname) == 2

        # The callback passed to solve() is invoked with the solver object.
        admm_.test_flag = False

        def callback(obj):
            obj.test_flag = True

        x = admm_.solve(callback=callback)
        assert admm_.test_flag

        # Unknown keyword arguments are rejected.
        with pytest.raises(TypeError):
            admm_ = ADMM(f=f, g_list=[g], C_list=[C], rho_list=[ρ], invalid_keyword_arg=None)

        # With nanstop=True, a NaN in the iterate makes solve() raise.
        admm_ = ADMM(f=f, g_list=[g], C_list=[C], rho_list=[ρ], maxiter=maxiter, nanstop=True)
        admm_.step()
        admm_.x = admm_.x.at[0].set(np.nan)
        with pytest.raises(ValueError):
            admm_.solve()

    @pytest.mark.parametrize(
        "solver", [LinearSubproblemSolver, MatrixSubproblemSolver, CircularConvolveSolver]
    )
    def test_admm_aux(self, solver):
        """The specialized subproblem solvers reject unsupported f (nonlinear A, or non-loss f)."""
        maxiter = 2
        ρ = 1e-1
        A = operator.Abs(self.y.shape)
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = functional.DnCNN()
        C = linop.Identity(self.y.shape)
        with pytest.raises(TypeError):
            admm_ = ADMM(
                f=f,
                g_list=[g],
                C_list=[C],
                rho_list=[ρ],
                maxiter=maxiter,
                subproblem_solver=solver(),
            )
        with pytest.raises(TypeError):
            admm_ = ADMM(
                f=g,
                g_list=[g],
                C_list=[C],
                rho_list=[ρ],
                maxiter=maxiter,
                subproblem_solver=solver(),
            )


class TestReal:
    """ADMM on a real quadratic problem with a closed-form normal-equation check."""

    def setup_method(self, method):
        np.random.seed(12345)
        MA = 4
        MB = 5
        N = 6
        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx = np.random.randn(MA, N).astype(np.float32)
        Bmx = np.random.randn(MB, N).astype(np.float32)
        y = np.random.randn(MA).astype(np.float32)
        𝛼 = np.pi  # sort of random number chosen to test non-default scale factor
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.𝛼 = 𝛼
        self.λ = λ
        # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = 𝛼 A^T y
        self.grdA = lambda x: (𝛼 * Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
        self.grdb = 𝛼 * Amx.T @ y

    def test_admm_generic(self):
        maxiter = 25
        ρ = 2e-1
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=GenericSubproblemSolver(minimize_kwargs={"options": {"maxiter": 50}}),
        )
        x = admm_.solve()
        # Relative residual of the normal equations.
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3

    def test_admm_saveload(self):
        """Saving state, then loading it into a fresh solver reproduces further iterations."""
        maxiter = 5
        x_ref = np.ones((16, 16), dtype=np.float32)
        # NOTE(review): x_ref is already all ones, so this assignment has no
        # effect; np.zeros may have been intended above — confirm.
        x_ref[4:-4, 4:-4] = 1.0
        n = 3
        psf = snp.ones((n, n), dtype=np.float32) / (n * n)
        A = linop.CircularConvolve(h=psf, input_shape=x_ref.shape)
        y = A(x_ref)
        λ = 2e-2
        ρ = 5e-1
        f = loss.SquaredL2Loss(y=y, A=A)
        g = λ * functional.L21Norm()
        C = linop.FiniteDifference(x_ref.shape, circular=True)
        admm0 = ADMM(
            f=f,
            g_list=[g],
            C_list=[C],
            rho_list=[ρ],
            x0=A.adj(y),
            maxiter=maxiter,
            subproblem_solver=CircularConvolveSolver(),
        )
        admm0.solve()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "admm.npz")
            admm0.save_state(path)
            admm0.solve()
            h0 = admm0.history()
            admm1 = ADMM(
                f=f,
                g_list=[g],
                C_list=[C],
                rho_list=[ρ],
                x0=A.adj(y),
                maxiter=maxiter,
                subproblem_solver=CircularConvolveSolver(),
            )
            admm1.load_state(path)
            admm1.solve()
            h1 = admm1.history()
            np.testing.assert_allclose(admm0.minimizer(), admm1.minimizer(), atol=1e-7)
            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6

    def test_admm_quadratic_scico(self):
        maxiter = 25
        ρ = 4e-1
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_admm_quadratic_jax(self):
        maxiter = 25
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_admm_quadratic_relax(self):
        """As test_admm_quadratic_jax, but with relaxation parameter alpha=1.6."""
        maxiter = 25
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            alpha=1.6,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestRealWeighted:
    """ADMM on a real quadratic problem with a diagonal weighting W in the data term."""

    def setup_method(self, method):
        np.random.seed(12345)
        MA = 4
        MB = 5
        N = 6
        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_W^2 + (λ/2) ||B x||_2^2
        Amx = np.random.randn(MA, N).astype(np.float32)
        W = np.abs(np.random.randn(MA, 1).astype(np.float32))
        Bmx = np.random.randn(MB, N).astype(np.float32)
        y = np.random.randn(MA).astype(np.float32)
        𝛼 = np.pi  # sort of random number chosen to test non-default scale factor
        λ = np.e
        self.Amx = Amx
        self.W = snp.array(W)
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.𝛼 = 𝛼
        self.λ = λ
        # Solution of problem is given by linear system
        # (𝛼 A^T W A + λ B^T B) x = 𝛼 A^T W y
        self.grdA = lambda x: (𝛼 * Amx.T @ (W * Amx) + λ * Bmx.T @ Bmx) @ x
        self.grdb = 𝛼 * Amx.T @ (W[:, 0] * y)

    def test_admm_quadratic_linear(self):
        maxiter = 100
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_admm_quadratic_matrix(self):
        maxiter = 50
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=MatrixSubproblemSolver(),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5


class TestComplex:
    """ADMM on a complex quadratic problem (uses conjugate transposes)."""

    def setup_method(self, method):
        MA = 4
        MB = 5
        N = 6
        # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)
        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
        y, key = random.randn((MA,), dtype=np.complex64, key=key)
        𝛼 = 1.0 / 3.0
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = y
        self.𝛼 = 𝛼
        self.λ = λ
        # Solution of problem is given by linear system (𝛼 A^H A + λ B^H B) x = 𝛼 A^H y
        # (A^H denotes the conjugate transpose, computed below as .conj().T).
        self.grdA = lambda x: (𝛼 * Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
        self.grdb = 𝛼 * Amx.conj().T @ y

    def test_admm_generic(self):
        maxiter = 30
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=GenericSubproblemSolver(minimize_kwargs={"options": {"maxiter": 50}}),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3

    def test_admm_quadratic_linear(self):
        maxiter = 50
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(
                cg_kwargs={"tol": 1e-4},
            ),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_admm_quadratic_matrix(self):
        maxiter = 50
        ρ = 1e0
        A = linop.MatrixOperator(self.Amx)
        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
        C_list = [linop.MatrixOperator(self.Bmx)]
        rho_list = [ρ]
        admm_ = ADMM(
            f=f,
            g_list=g_list,
            C_list=C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=A.adj(self.y),
            subproblem_solver=MatrixSubproblemSolver(),
        )
        x = admm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5


@pytest.mark.parametrize("extra_axis", (False, True))
@pytest.mark.parametrize("center", (None, [-1.0, 2.5]))
class TestCircularConvolveSolve:
    """CircularConvolveSolver must agree with LinearSubproblemSolver on a deconvolution.

    Parametrized over an optional leading singleton axis and a non-default kernel center.
    """

    @pytest.fixture(scope="function", autouse=True)
    def setup_and_teardown(self, extra_axis, center):
        np.random.seed(12345)
        Nx = 8
        x = snp.pad(snp.ones((Nx, Nx), dtype=np.float32), Nx)
        Npsf = 3
        psf = snp.ones((Npsf, Npsf), dtype=np.float32) / (Npsf**2)
        if extra_axis:
            x = x[np.newaxis]
            psf = psf[np.newaxis]
        self.A = linop.CircularConvolve(
            h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32, h_center=center
        )
        self.y = self.A(x)
        λ = 1e-2
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g_list = [λ * functional.L1Norm()]
        self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)]
        yield

    def test_admm(self):
        maxiter = 50
        ρ = 1e-1
        rho_list = [ρ]
        # Reference solution using the generic linear (CG) subproblem solver.
        admm_lin = ADMM(
            f=self.f,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=self.A.adj(self.y),
            subproblem_solver=LinearSubproblemSolver(),
        )
        x_lin = admm_lin.solve()
        admm_dft = ADMM(
            f=self.f,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=self.A.adj(self.y),
            subproblem_solver=CircularConvolveSolver(),
        )
        assert admm_dft.subproblem_solver.A_lhs.ndims == 2
        x_dft = admm_dft.solve()
        np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
        assert metric.mse(x_lin, x_dft) < 1e-9
@pytest.mark.parametrize("with_cconv", (False, True))
class TestSpecialCaseCircularConvolveSolve:
    """CircularConvolveSolver with f=None, the data term moved into g_list.

    C0 is a CircularConvolve (with_cconv=True) or a circular FiniteDifference;
    the result must agree with LinearSubproblemSolver.
    """

    @pytest.fixture(scope="function", autouse=True)
    def setup_and_teardown(self, with_cconv):
        np.random.seed(12345)
        Nx = 8
        x = snp.pad(snp.ones((1, Nx, Nx), dtype=np.float32), Nx)
        if with_cconv:
            Npsf = 3
            psf = snp.ones((1, Npsf, Npsf), dtype=np.float32) / (Npsf**2)
            C0 = linop.CircularConvolve(h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32)
        else:
            C0 = linop.FiniteDifference(input_shape=x.shape, axes=(1, 2), circular=True)
        C1 = linop.Identity(input_shape=x.shape)
        self.y = C0(x)
        self.g_list = [loss.SquaredL2Loss(y=self.y), functional.L2Norm()]
        self.C_list = [C0, C1]
        self.with_cconv = with_cconv
        yield

    def test_admm(self):
        maxiter = 50
        ρ = 1e-1
        rho_list = [ρ, ρ]
        admm_lin = ADMM(
            f=None,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=self.C_list[0].adj(self.y),
            subproblem_solver=LinearSubproblemSolver(),
        )
        x_lin = admm_lin.solve()
        # ndims is passed explicitly only for the FiniteDifference case; when C0
        # is a CircularConvolve it is left as None, and the assertion below
        # checks that it resolves to 2 either way.
        ndims = None if self.with_cconv else 2
        admm_dft = ADMM(
            f=None,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            x0=self.C_list[0].adj(self.y),
            subproblem_solver=CircularConvolveSolver(ndims=ndims),
        )
        assert admm_dft.subproblem_solver.A_lhs.ndims == 2
        x_dft = admm_dft.solve()
        np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
        assert metric.mse(x_lin, x_dft) < 1e-9


class TestBlockCircularConvolveSolve:
    """F-block and G0-block circular-convolution solvers for A = Sum(axis 0) @ CircularConvolve."""

    def setup_method(self, method):
        np.random.seed(12345)
        Nx = 8
        x = np.zeros((2, Nx, Nx), dtype=np.float32)
        x[0, 2, 2] = 1.0
        x[1, 3, 3] = 1.0
        Npsf = 3
        psf = np.zeros((2, Npsf, Npsf), dtype=np.float32)
        psf[0, 1] = 1.0
        psf[1, :, 1] = 1.0
        C = linop.CircularConvolve(h=psf, input_shape=x.shape, input_dtype=np.float32, ndims=2)
        S = linop.Sum(input_shape=x.shape, axis=0)
        self.A = S @ C
        self.y = self.A(x)
        λ = 1e-1
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g_list = [λ * functional.L1Norm()]
        self.C_list = [linop.Identity(input_shape=x.shape)]

    def test_fblock_init(self):
        """FBlockCircularConvolveSolver rejects f=None (ValueError) and unsupported f (TypeError)."""
        with pytest.raises(ValueError):
            slvr = ADMM(
                f=None,
                g_list=self.g_list,
                C_list=self.C_list,
                rho_list=[1.0],
                itstat_options={"display": False},
                subproblem_solver=FBlockCircularConvolveSolver(),
            )
        with pytest.raises(TypeError):
            slvr = ADMM(
                f=loss.PoissonLoss(y=self.y),
                g_list=self.g_list,
                C_list=self.C_list,
                rho_list=[1.0],
                itstat_options={"display": False},
                subproblem_solver=FBlockCircularConvolveSolver(),
            )
        # A loss whose A is self.A.A rather than the full composed operator self.A.
        with pytest.raises(TypeError):
            slvr = ADMM(
                f=loss.SquaredL2Loss(y=self.y, A=self.A.A),
                g_list=self.g_list,
                C_list=self.C_list,
                rho_list=[1.0],
                itstat_options={"display": False},
                subproblem_solver=FBlockCircularConvolveSolver(),
            )

    def test_g0block_init(self):
        """G0BlockCircularConvolveSolver rejects a non-zero f (ValueError) and unsupported g0/C0 (TypeError)."""
        with pytest.raises(ValueError):
            slvr = ADMM(
                f=self.f,
                g_list=self.g_list,
                C_list=self.C_list,
                rho_list=[1.0],
                itstat_options={"display": False},
                subproblem_solver=G0BlockCircularConvolveSolver(),
            )
        with pytest.raises(TypeError):
            slvr = ADMM(
                f=functional.ZeroFunctional(),
                g_list=[loss.PoissonLoss(y=self.y)],
                C_list=self.C_list,
                rho_list=[1.0],
                itstat_options={"display": False},
                subproblem_solver=G0BlockCircularConvolveSolver(),
            )
        with pytest.raises(TypeError):
            slvr = ADMM(
                f=functional.ZeroFunctional(),
                g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list,
                C_list=[self.A.A] + self.C_list,
                rho_list=[1.0, 1.0],
                itstat_options={"display": False},
                subproblem_solver=G0BlockCircularConvolveSolver(),
            )

    def test_solve(self):
        """Both block solvers match the linear-solver reference and report solve accuracy."""
        maxiter = 50
        ρ = 1e1
        rho_list = [ρ]
        admm_lin = ADMM(
            f=self.f,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            subproblem_solver=LinearSubproblemSolver(),
        )
        x_lin = admm_lin.solve()
        admm_dft1 = ADMM(
            f=self.f,
            g_list=self.g_list,
            C_list=self.C_list,
            rho_list=rho_list,
            maxiter=maxiter,
            itstat_options={"display": False},
            subproblem_solver=FBlockCircularConvolveSolver(check_solve=True),
        )
        x_dft1 = admm_dft1.solve()
        np.testing.assert_allclose(x_dft1, x_lin, atol=1e-4, rtol=0)
        assert metric.mse(x_lin, x_dft1) < 1e-9
        assert admm_dft1.subproblem_solver.accuracy <= 1e-6
        # Same problem with the data term moved into g_list as the first block.
        admm_dft2 = ADMM(
            f=functional.ZeroFunctional(),
            g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list,
            C_list=[self.A] + self.C_list,
            rho_list=[1.0, ρ],
            maxiter=maxiter,
            itstat_options={"display": False},
            subproblem_solver=G0BlockCircularConvolveSolver(check_solve=True),
        )
        admm_dft2.z_list[0] = self.y  # significantly improves convergence
        x_dft2 = admm_dft2.solve()
        np.testing.assert_allclose(x_dft2, x_lin, atol=1e-4, rtol=0)
        assert metric.mse(x_lin, x_dft2) < 1e-9
        assert admm_dft2.subproblem_solver.accuracy <= 1e-6


================================================
FILE: scico/test/optimize/test_ladmm.py
================================================
import os
import tempfile

import numpy as np

import pytest

import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.numpy import BlockArray
from scico.optimize import LinearizedADMM


class TestMisc:
    """LinearizedADMM iteration statistics, callback and NaN-stop checks."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
        self.maxiter = 2
        self.μ = 1e-1
        self.ν = 1e-1
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = functional.DnCNN()
        self.C = linop.Identity(self.y.shape)

    def test_itstat(self):
        itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

        def itstat_func(obj):
            return (obj.itnum, obj.timer.elapsed())

        # Default iteration statistics have 4 fields; the default x starts at zero.
        ladmm_ = LinearizedADMM(
            f=self.f,
            g=self.g,
            C=self.C,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        assert len(ladmm_.itstat_object.fieldname) == 4
        assert snp.sum(ladmm_.x) == 0.0

        ladmm_ = LinearizedADMM(
            f=self.f,
            g=self.g,
            C=self.C,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
            itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
        )
        assert len(ladmm_.itstat_object.fieldname) == 2

    def test_callback(self):
        ladmm_ = LinearizedADMM(
            f=self.f,
            g=self.g,
            C=self.C,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        ladmm_.test_flag = False

        def callback(obj):
            obj.test_flag = True

        x = ladmm_.solve(callback=callback)
        assert ladmm_.test_flag

    def test_finite_check(self):
        # With nanstop=True, a NaN in the iterate makes solve() raise.
        ladmm_ = LinearizedADMM(
            f=self.f, g=self.g, C=self.C, mu=self.μ, nu=self.ν, maxiter=self.maxiter, nanstop=True
        )
        ladmm_.step()
        ladmm_.x = ladmm_.x.at[0].set(np.nan)
        with pytest.raises(ValueError):
            ladmm_.solve()


class TestBlockArray:
    """LinearizedADMM returns a BlockArray when the problem variable is a BlockArray."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.y = snp.blockarray(
            (
                np.random.randn(32, 33).astype(np.float32),
                np.random.randn(
                    17,
                ).astype(np.float32),
            )
        )
        self.λ = 1e0
        self.maxiter = 1
        self.μ = 1e-1
        self.ν = 1e-1
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = (self.λ / 2) * functional.L2Norm()
        self.C = linop.Identity(self.y.shape)

    def test_blockarray(self):
        ladmm_ = LinearizedADMM(
            f=self.f,
            g=self.g,
            C=self.C,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        x = ladmm_.solve()
        assert isinstance(x, BlockArray)


class TestReal:
    """LinearizedADMM on a real quadratic problem with a normal-equation check."""

    def setup_method(self, method):
        np.random.seed(12345)
        N = 8
        MB = 10
        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx = np.diag(np.random.randn(N).astype(np.float32))
        Bmx = np.random.randn(MB, N).astype(np.float32)
        y = np.random.randn(N).astype(np.float32)
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.λ = λ
        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
        self.grdb = Amx.T @ y

    def test_ladmm(self):
        maxiter = 400
        μ = 1e-2
        ν = 2e-1
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        ladmm_ = LinearizedADMM(
            f=f,
            g=g,
            C=C,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
            x0=A.adj(self.y),
        )
        x = ladmm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_ladmm_saveload(self):
        """Saving state, then loading it into a fresh solver reproduces further iterations."""
        maxiter = 5
        μ = 1e-2
        ν = 2e-1
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        ladmm0 = LinearizedADMM(
            f=f,
            g=g,
            C=C,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
            x0=A.adj(self.y),
        )
        ladmm0.solve()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "ladmm.npz")
            ladmm0.save_state(path)
            ladmm0.solve()
            h0 = ladmm0.history()
            ladmm1 = LinearizedADMM(
                f=f,
                g=g,
                C=C,
                mu=μ,
                nu=ν,
                maxiter=maxiter,
                x0=A.adj(self.y),
            )
            ladmm1.load_state(path)
            ladmm1.solve()
            h1 = ladmm1.history()
            np.testing.assert_allclose(ladmm0.minimizer(), ladmm1.minimizer(), rtol=1e-6)
            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6


class TestComplex:
    """LinearizedADMM on a complex quadratic problem (uses conjugate transposes)."""

    def setup_method(self, method):
        N = 8
        MB = 10
        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx, key = random.randn((N,), dtype=np.complex64, key=None)
        Amx = snp.diag(Amx)
        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
        y, key = random.randn((N,), dtype=np.complex64, key=key)
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.λ = λ
        # Solution of problem is given by linear system (A^H A + λ B^H B) x = A^H y
        # (A^H denotes the conjugate transpose, computed below as .conj().T).
        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
        self.grdb = Amx.conj().T @ y

    def test_ladmm(self):
        maxiter = 500
        μ = 1e-2
        ν = 2e-1
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        ladmm_ = LinearizedADMM(
            f=f,
            g=g,
            C=C,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
            x0=A.adj(self.y),
        )
        x = ladmm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 5e-4


================================================
FILE: scico/test/optimize/test_padmm.py
================================================
import os
import tempfile

import numpy as np

import pytest

import scico.numpy as snp
from scico import function, functional, linop, loss, random
from scico.numpy import BlockArray
from scico.optimize import
NonLinearPADMM, ProximalADMM


class TestMisc:
    """ProximalADMM / NonLinearPADMM iteration statistics, callback and NaN-stop checks."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.y = snp.array(np.random.randn(32, 33).astype(np.float32))
        self.maxiter = 2
        self.ρ = 1e0
        self.μ = 1e0
        self.ν = 1e0
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = functional.DnCNN()
        # H(x, z) = x - z: the nonlinear-constraint form of the identity constraint.
        self.H = function.Function(
            (self.A.input_shape, self.A.input_shape),
            output_shape=self.A.input_shape,
            eval_fn=lambda x, z: x - z,
            input_dtypes=np.float32,
            output_dtype=np.float32,
        )
        self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32)

    def test_itstat_padmm(self):
        itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

        def itstat_func(obj):
            return (obj.itnum, obj.timer.elapsed())

        # Default iteration statistics have 4 fields; x0 is all zeros.
        padmm_ = ProximalADMM(
            f=self.f,
            g=self.g,
            A=self.A,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            x0=self.x0,
            z0=self.x0,
            u0=self.x0,
            maxiter=self.maxiter,
        )
        assert len(padmm_.itstat_object.fieldname) == 4
        assert snp.sum(padmm_.x) == 0.0

        padmm_ = ProximalADMM(
            f=self.f,
            g=self.g,
            A=self.A,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            B=None,
            maxiter=self.maxiter,
            itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
        )
        assert len(padmm_.itstat_object.fieldname) == 2

    def test_itstat_nlpadmm(self):
        itstat_fields = {"Iter": "%d", "Time": "%8.2e"}

        def itstat_func(obj):
            return (obj.itnum, obj.timer.elapsed())

        nlpadmm_ = NonLinearPADMM(
            f=self.f,
            g=self.g,
            H=self.H,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            x0=self.x0,
            z0=self.x0,
            u0=self.x0,
            maxiter=self.maxiter,
        )
        assert len(nlpadmm_.itstat_object.fieldname) == 4
        assert snp.sum(nlpadmm_.x) == 0.0

        nlpadmm_ = NonLinearPADMM(
            f=self.f,
            g=self.g,
            H=self.H,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
            itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False},
        )
        assert len(nlpadmm_.itstat_object.fieldname) == 2

    def test_callback(self):
        padmm_ = ProximalADMM(
            f=self.f,
            g=self.g,
            A=self.A,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        padmm_.test_flag = False

        def callback(obj):
            obj.test_flag = True

        x = padmm_.solve(callback=callback)
        assert padmm_.test_flag

    def test_finite_check(self):
        # With nanstop=True, a NaN in the iterate makes solve() raise.
        padmm_ = ProximalADMM(
            f=self.f,
            g=self.g,
            A=self.A,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
            nanstop=True,
        )
        padmm_.step()
        padmm_.x = padmm_.x.at[0].set(np.nan)
        with pytest.raises(ValueError):
            padmm_.solve()


class TestBlockArray:
    """Both proximal ADMM variants return a BlockArray for a BlockArray problem."""

    def setup_method(self, method):
        np.random.seed(12345)
        self.y = snp.blockarray(
            (
                np.random.randn(32, 33).astype(np.float32),
                np.random.randn(
                    17,
                ).astype(np.float32),
            )
        )
        self.λ = 1e0
        self.maxiter = 1
        self.ρ = 1e0
        self.μ = 1e0
        self.ν = 1e0
        self.A = linop.Identity(self.y.shape)
        self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
        self.g = (self.λ / 2) * functional.L2Norm()
        self.H = function.Function(
            (self.A.input_shape, self.A.input_shape),
            output_shape=self.A.input_shape,
            eval_fn=lambda x, z: x - z,
            input_dtypes=np.float32,
            output_dtype=np.float32,
        )
        self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32)

    def test_blockarray_padmm(self):
        padmm_ = ProximalADMM(
            f=self.f,
            g=self.g,
            A=self.A,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        x = padmm_.solve()
        assert isinstance(x, BlockArray)

    def test_blockarray_nlpadmm(self):
        nlpadmm_ = NonLinearPADMM(
            f=self.f,
            g=self.g,
            H=self.H,
            rho=self.ρ,
            mu=self.μ,
            nu=self.ν,
            maxiter=self.maxiter,
        )
        x = nlpadmm_.solve()
        assert isinstance(x, BlockArray)


class TestReal:
    """Proximal ADMM variants on a real quadratic problem with a normal-equation check."""

    def setup_method(self, method):
        np.random.seed(12345)
        N = 8
        MB = 10
        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx = np.diag(np.random.randn(N).astype(np.float32))
        Bmx = np.random.randn(MB, N).astype(np.float32)
        y = np.random.randn(N).astype(np.float32)
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.λ = λ
        # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y
        self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x
        self.grdb = Amx.T @ y

    def test_padmm(self):
        maxiter = 200
        ρ = 1e0
        μ = 5e1
        ν = 1e0
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        padmm_ = ProximalADMM(
            f=f,
            g=g,
            A=C,
            rho=ρ,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
        )
        x = padmm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

    def test_padmm_saveload(self):
        """Saving state, then loading it into a fresh solver reproduces further iterations."""
        maxiter = 5
        ρ = 1e0
        μ = 5e1
        ν = 1e0
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        padmm0 = ProximalADMM(
            f=f,
            g=g,
            A=C,
            rho=ρ,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
        )
        padmm0.solve()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "padmm.npz")
            padmm0.save_state(path)
            padmm0.solve()
            h0 = padmm0.history()
            padmm1 = ProximalADMM(
                f=f,
                g=g,
                A=C,
                rho=ρ,
                mu=μ,
                nu=ν,
                maxiter=maxiter,
            )
            padmm1.load_state(path)
            padmm1.solve()
            h1 = padmm1.history()
            np.testing.assert_allclose(padmm0.minimizer(), padmm1.minimizer(), rtol=1e-6)
            assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6

    def test_nlpadmm(self):
        maxiter = 200
        ρ = 1e0
        μ = 5e1
        ν = 1e0
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        # Express the linear constraint C x - z = 0 as a general function H(x, z).
        H = function.Function(
            (C.input_shape, C.output_shape),
            output_shape=C.output_shape,
            eval_fn=lambda x, z: C(x) - z,
            input_dtypes=snp.float32,
            output_dtype=snp.float32,
        )
        nlpadmm_ = NonLinearPADMM(
            f=f,
            g=g,
            H=H,
            rho=ρ,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
        )
        x = nlpadmm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestComplex:
    """NonLinearPADMM on a complex quadratic problem (uses conjugate transposes)."""

    def setup_method(self, method):
        N = 8
        MB = 10
        # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
        Amx, key = random.randn((N,), dtype=np.complex64, key=None)
        Amx = snp.diag(Amx)
        Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
        y, key = random.randn((N,), dtype=np.complex64, key=key)
        λ = 1e0
        self.Amx = Amx
        self.Bmx = Bmx
        self.y = snp.array(y)
        self.λ = λ
        # Solution of problem is given by linear system (A^H A + λ B^H B) x = A^H y
        # (A^H denotes the conjugate transpose, computed below as .conj().T).
        self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x
        self.grdb = Amx.conj().T @ y

    def test_nlpadmm(self):
        maxiter = 300
        ρ = 1e0
        μ = 3e1
        ν = 1e0
        A = linop.Diagonal(snp.diag(self.Amx))
        f = loss.SquaredL2Loss(y=self.y, A=A)
        g = (self.λ / 2) * functional.SquaredL2Norm()
        C = linop.MatrixOperator(self.Bmx)
        H = function.Function(
            (C.input_shape, C.output_shape),
            output_shape=C.output_shape,
            eval_fn=lambda x, z: C(x) - z,
            input_dtypes=snp.complex64,
            output_dtype=snp.complex64,
        )
        nlpadmm_ = NonLinearPADMM(
            f=f,
            g=g,
            H=H,
            rho=ρ,
            mu=μ,
            nu=ν,
            maxiter=maxiter,
        )
        x = nlpadmm_.solve()
        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestEstimateParameters:
    """estimate_parameters with factor=1.0 gives mu = nu = 1 for these identity-like maps."""

    def setup_method(self):
        shape = (32, 33)
        self.A = linop.Identity(shape)
        self.Hr = function.Function(
            (shape, shape),
            output_shape=shape,
            eval_fn=lambda x, z: x - z,
            input_dtypes=np.float32,
            output_dtype=np.float32,
        )
        self.Hc = function.Function(
            (shape, shape),
            output_shape=shape,
            eval_fn=lambda x, z: x - z,
            input_dtypes=np.complex64,
            output_dtype=np.complex64,
        )

    def test_padmm_a(self):
        mu, nu = ProximalADMM.estimate_parameters(self.A, factor=1.0)
        assert snp.abs(mu - 1.0) < 1e-6
        assert snp.abs(nu - 1.0) < 1e-6

    def test_padmm_ab(self):
        mu, nu = ProximalADMM.estimate_parameters(self.A, self.A, factor=1.0)
        assert snp.abs(mu - 1.0) < 1e-6
        assert snp.abs(nu - 1.0) < 1e-6

    def test_real(self):
        mu, nu = NonLinearPADMM.estimate_parameters(self.Hr, factor=1.0)
        assert snp.abs(mu - 1.0) < 1e-6
        assert snp.abs(nu - 1.0) < 1e-6

    def test_complex(self):
        mu, nu = NonLinearPADMM.estimate_parameters(self.Hc, factor=1.0)
        assert snp.abs(mu - 1.0) < 1e-6
        assert snp.abs(nu - 1.0) < 1e-6


================================================
FILE: scico/test/optimize/test_pdhg.py
================================================
import os
import tempfile
import
numpy as np import pytest import scico.numpy as snp from scico import functional, linop, loss, operator, random from scico.numpy import BlockArray from scico.optimize import PDHG class TestMisc: def setup_method(self, method): np.random.seed(12345) self.y = snp.array(np.random.randn(32, 33).astype(np.float32)) self.maxiter = 2 self.τ = 1e-1 self.σ = 1e-1 self.A = linop.Identity(self.y.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) self.g = functional.DnCNN() self.C = linop.Identity(self.y.shape) def test_itstat(self): itstat_fields = {"Iter": "%d", "Time": "%8.2e"} def itstat_func(obj): return (obj.itnum, obj.timer.elapsed()) pdhg_ = PDHG( f=self.f, g=self.g, C=self.C, tau=self.τ, sigma=self.σ, maxiter=self.maxiter, ) assert len(pdhg_.itstat_object.fieldname) == 4 assert snp.sum(pdhg_.x) == 0.0 pdhg_ = PDHG( f=self.f, g=self.g, C=self.C, tau=self.τ, sigma=self.σ, maxiter=self.maxiter, itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, ) assert len(pdhg_.itstat_object.fieldname) == 2 def test_callback(self): pdhg_ = PDHG( f=self.f, g=self.g, C=self.C, tau=self.τ, sigma=self.σ, maxiter=self.maxiter, ) pdhg_.test_flag = False def callback(obj): obj.test_flag = True x = pdhg_.solve(callback=callback) assert pdhg_.test_flag def test_finite_check(self): pdhg_ = PDHG( f=self.f, g=self.g, C=self.C, tau=self.τ, sigma=self.σ, maxiter=self.maxiter, nanstop=True, ) pdhg_.step() pdhg_.x = pdhg_.x.at[0].set(np.nan) with pytest.raises(ValueError): pdhg_.solve() class TestBlockArray: def setup_method(self, method): np.random.seed(12345) self.y = snp.blockarray( ( np.random.randn(32, 33).astype(np.float32), np.random.randn( 17, ).astype(np.float32), ) ) self.λ = 1e0 self.maxiter = 1 self.τ = 1e-1 self.σ = 1e-1 self.A = linop.Identity(self.y.shape) self.f = loss.SquaredL2Loss(y=self.y, A=self.A) self.g = (self.λ / 2) * functional.L2Norm() self.C = linop.Identity(self.y.shape) def test_blockarray(self): pdhg_ = PDHG( f=self.f, g=self.g, 
C=self.C, tau=self.τ, sigma=self.σ, maxiter=self.maxiter, ) x = pdhg_.solve() assert isinstance(x, BlockArray) class TestReal: def setup_method(self, method): np.random.seed(12345) N = 8 MB = 10 # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 Amx = np.diag(np.random.randn(N).astype(np.float32)) Bmx = np.random.randn(MB, N).astype(np.float32) y = np.random.randn(N).astype(np.float32) λ = 1e0 self.Amx = Amx self.Bmx = Bmx self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x self.grdb = Amx.T @ y def test_pdhg(self): maxiter = 300 τ = 2e-1 σ = 2e-1 A = linop.Diagonal(snp.diag(self.Amx)) f = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2) * functional.SquaredL2Norm() C = linop.MatrixOperator(self.Bmx) pdhg_ = PDHG( f=f, g=g, C=C, tau=τ, sigma=σ, maxiter=maxiter, x0=A.adj(self.y), ) x = pdhg_.solve() assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 def test_pdhg_saveload(self): maxiter = 5 τ = 2e-1 σ = 2e-1 A = linop.Diagonal(snp.diag(self.Amx)) f = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2) * functional.SquaredL2Norm() C = linop.MatrixOperator(self.Bmx) pdhg0 = PDHG( f=f, g=g, C=C, tau=τ, sigma=σ, maxiter=maxiter, x0=A.adj(self.y), ) pdhg0.solve() with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "pdhg.npz") pdhg0.save_state(path) pdhg0.solve() h0 = pdhg0.history() pdhg1 = PDHG( f=f, g=g, C=C, tau=τ, sigma=σ, maxiter=maxiter, x0=A.adj(self.y), ) pdhg1.load_state(path) pdhg1.solve() h1 = pdhg1.history() np.testing.assert_allclose(pdhg0.minimizer(), pdhg1.minimizer(), atol=1e-7) assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6 def test_nlpdhg(self): maxiter = 300 τ = 2e-1 σ = 2e-1 A = linop.Diagonal(snp.diag(self.Amx)) f = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2) * functional.SquaredL2Norm() cfn = lambda x: self.Bmx @ x Cop = 
operator.operator_from_function(cfn, "Cop") C = Cop(input_shape=self.Bmx.shape[1:]) pdhg_ = PDHG( f=f, g=g, C=C, tau=τ, sigma=σ, maxiter=maxiter, x0=A.adj(self.y), ) x = pdhg_.solve() assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 class TestComplex: def setup_method(self, method): N = 8 MB = 10 # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 Amx, key = random.randn((N,), dtype=np.complex64, key=None) Amx = snp.diag(Amx) Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key) y, key = random.randn((N,), dtype=np.complex64, key=key) λ = 1e0 self.Amx = Amx self.Bmx = Bmx self.y = snp.array(y) self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x self.grdb = Amx.conj().T @ y def test_pdhg(self): maxiter = 300 τ = 2e-1 σ = 2e-1 A = linop.Diagonal(snp.diag(self.Amx)) f = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2) * functional.SquaredL2Norm() C = linop.MatrixOperator(self.Bmx) pdhg_ = PDHG( f=f, g=g, C=C, tau=τ, sigma=σ, maxiter=maxiter, x0=A.adj(self.y), ) x = pdhg_.solve() assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 5e-4 class TestEstimateParameters: def setup_method(self): shape = (32, 33) A = linop.Identity(shape, input_dtype=np.float32) B = linop.Identity(shape, input_dtype=np.complex64) opcls = operator.operator_from_function(lambda x: snp.abs(x), "op") C = opcls(input_shape=shape, input_dtype=np.float32) D = opcls(input_shape=shape, input_dtype=np.complex64) self.operators = [A, B, C, D] def test_operators_dlft(self): for op in self.operators[0:2]: tau, sigma = PDHG.estimate_parameters(op, factor=1.0) assert snp.abs(tau - sigma) < 1e-6 assert snp.abs(tau - 1.0) < 1e-6 def test_operators(self): for op in self.operators: x = snp.ones(op.input_shape, op.input_dtype) tau, sigma = PDHG.estimate_parameters(op, x=x, factor=None) assert snp.abs(tau - 
sigma) < 1e-6 assert snp.abs(tau - 1.0) < 1e-6 def test_ratio(self): op = self.operators[0] tau, sigma = PDHG.estimate_parameters(op, factor=1.0, ratio=10.0) assert snp.abs(tau * sigma - 1.0) < 1e-6 assert snp.abs(sigma - 10.0 * tau) < 1e-6 ================================================ FILE: scico/test/optimize/test_pgm.py ================================================ import os import tempfile import numpy as np import jax import pytest import scico.numpy as snp from scico import functional, linop, loss, random from scico.optimize import PGM, AcceleratedPGM from scico.optimize.pgm import ( AdaptiveBBStepSize, BBStepSize, LineSearchStepSize, RobustLineSearchStepSize, ) class TestSet: def setup_method(self, method): np.random.seed(12345) M = 5 N = 4 # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 Amx = np.random.randn(M, N).astype(np.float32) Bmx = np.identity(N) y = snp.array(np.random.randn(M).astype(np.float32)) λ = 1e0 self.Amx = Amx self.y = y self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x self.grdb = Amx.T @ y def test_pgm(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_pgm_saveload(self): maxiter = 5 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm0 = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) pgm0.solve() with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "pgm.npz") pgm0.save_state(path) pgm0.solve() h0 = pgm0.history() pgm1 = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, 
x0=A.adj(self.y)) pgm1.load_state(path) pgm1.solve() h1 = pgm1.history() np.testing.assert_allclose(pgm0.minimizer(), pgm1.minimizer(), rtol=1e-6) assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6 def test_pgm_isfinite(self): maxiter = 5 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y), nanstop=True) pgm_.step() pgm_.x = pgm_.x.at[0].set(np.nan) with pytest.raises(ValueError): pgm_.solve() def test_accelerated_pgm(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_accelerated_pgm_saveload(self): maxiter = 5 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm0 = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) apgm0.solve() with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "pgm.npz") apgm0.save_state(path) apgm0.solve() h0 = apgm0.history() apgm1 = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) apgm1.load_state(path) apgm1.solve() h1 = apgm1.history() np.testing.assert_allclose(apgm0.minimizer(), apgm1.minimizer(), rtol=1e-6) assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6 def test_accelerated_pgm_isfinite(self): maxiter = 5 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y), 
nanstop=True) apgm_.step() apgm_.v = apgm_.v.at[0].set(np.nan) with pytest.raises(ValueError): apgm_.solve() def test_pgm_BB_step_size(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_pgm_adaptive_BB_step_size(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, ) x = pgm_.solve() def test_accelerated_pgm_line_search(self): maxiter = 150 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55), maxiter=maxiter, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_accelerated_pgm_robust_line_search(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80), maxiter=maxiter, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_pgm_BB_step_size_jit(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * 
functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, ) x = pgm_.x try: update_step = jax.jit(pgm_.step_size.update) L = update_step(x) except Exception as e: print(e) assert 0 def test_accelerated_pgm_adaptive_BB_step_size_jit(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 1.05 * linop.power_iteration(A.T @ A)[0] / 5.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, ) x = apgm_.x try: update_step = jax.jit(apgm_.step_size.update) L = update_step(x) except Exception as e: print(e) assert 0 class TestComplex: def setup_method(self, method): M = 5 N = 4 # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||x||_2^2 Amx, key = random.randn((M, N), dtype=np.complex64, key=None) Bmx = np.identity(N) y = snp.array(np.random.randn(M)) λ = 1e0 self.Amx = Amx self.Bmx = Bmx self.y = y self.λ = λ # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.T @ Bmx) @ x self.grdb = Amx.conj().T @ y def test_pgm(self): maxiter = 150 A = linop.MatrixOperator(self.Amx) L0 = 50.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), maxiter=maxiter, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_accelerated_pgm(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 50.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, x0=A.adj(self.y), maxiter=maxiter) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_pgm_BB_step_size(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 10.0 loss_ = 
loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_pgm_adaptive_BB_step_size(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 10.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() pgm_ = PGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_accelerated_pgm_line_search(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 10.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55), maxiter=maxiter, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) def test_accelerated_pgm_robust_line_search(self): maxiter = 100 A = linop.MatrixOperator(self.Amx) L0 = 10.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() apgm_ = AcceleratedPGM( f=loss_, g=g, L0=L0, x0=A.adj(self.y), step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80), maxiter=maxiter, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) ================================================ FILE: scico/test/osver.py ================================================ import platform from packaging.version import parse def osx_ver_geq_than(verstr): """Determine relative platform OSX version. Determine whether platform has OSX version that is as recent as or more recent than verstr. Returns ``False`` if the OS is not OSX. 
""" if platform.system() != "Darwin": return False osxver = platform.mac_ver()[0] return parse(osxver) >= parse(verstr) ================================================ FILE: scico/test/test_core.py ================================================ import numpy as np import jax import pytest import scico import scico.numpy as snp from scico.random import randn class GradTestObj: def __init__(self, dtype): M, N = (3, 4) key = jax.random.key(12345) self.dtype = dtype self.A, key = randn((M, N), dtype=dtype, key=key) self.x, key = randn((N,), dtype=dtype, key=key) self.y, key = randn((M,), dtype=dtype, key=key) self.f = lambda x: 0.5 * snp.sum(snp.abs(self.y - self.A @ x) ** 2) @pytest.fixture(scope="module", params=[np.float32, np.complex64]) def testobj(request): yield GradTestObj(request.param) def test_grad(testobj): A = testobj.A x = testobj.x y = testobj.y f = testobj.f sgrad = scico.grad(f)(x) an_grad = A.conj().T @ (A @ x - y) np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4) def test_grad_aux(testobj): A = testobj.A x = testobj.x y = testobj.y def g(x): return testobj.f(x), True sgrad, aux = scico.grad(g, has_aux=True)(x) an_grad = A.conj().T @ (A @ x - y) assert aux == True np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4) def test_value_and_grad(testobj): A = testobj.A x = testobj.x y = testobj.y f = testobj.f svalue, sgrad = scico.value_and_grad(f)(x) an_val = f(x) an_grad = A.conj().T @ (A @ x - y) np.testing.assert_allclose(svalue, an_val, rtol=1e-4) np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4) def test_value_and_grad_aux(testobj): A = testobj.A x = testobj.x y = testobj.y def g(x): return testobj.f(x), True (svalue, aux), sgrad = scico.value_and_grad(g, has_aux=True)(x) an_val, aux_ = g(x) an_grad = A.conj().T @ (A @ x - y) assert aux == aux_ np.testing.assert_allclose(svalue, an_val, rtol=1e-4) np.testing.assert_allclose(sgrad, an_grad, rtol=1e-4) @pytest.mark.parametrize("shape", [(2, 3), ((2, 3), (4,))]) def 
test_linear_transpose(shape): fun = lambda x: snp.pad(x, 2) za = snp.zeros(shape, dtype=snp.float32) fza = fun(za) dts = jax.ShapeDtypeStruct(shape, dtype=snp.float32) lt_za = scico.linear_transpose(fun, za) lt_dts = scico.linear_transpose(fun, dts) lt_za_fza = lt_za(fza)[0] lt_dts_fza = lt_dts(fza)[0] assert lt_za_fza.shape == lt_dts_fza.shape assert lt_za_fza.dtype == lt_dts_fza.dtype @pytest.mark.parametrize("shape", [(2, 3), ((2, 3), (4,))]) def test_linear_adjoint_shape(shape): fun = lambda x: snp.pad(x, 2) za = snp.zeros(shape, dtype=snp.float32) fza = fun(za) dts = jax.ShapeDtypeStruct(shape, dtype=snp.float32) lt_za = scico.linear_adjoint(fun, za) lt_dts = scico.linear_adjoint(fun, dts) lt_za_fza = lt_za(fza)[0] lt_dts_fza = lt_dts(fza)[0] assert lt_za_fza.shape == lt_dts_fza.shape assert lt_za_fza.dtype == lt_dts_fza.dtype def test_linear_adjoint(testobj): # Verify that linear_adjoint returns a function that # implements f(y) = A.conj().T @ y A = testobj.A x = testobj.x y = testobj.y f = lambda x: A @ x A_adj = scico.linear_adjoint(f, x) np.testing.assert_allclose(A.conj().T @ y, A_adj(testobj.y)[0], rtol=1e-4) # Test a function with with multiple inputs # Same as np.array([0.5, -0.5j]) f = lambda x, y: 0.5 * x - 0.5j * y f_transpose = scico.linear_adjoint(f, 1.0j, 1.0j) a, b = f_transpose(1.0 + 0.0j) assert a == 0.5 assert b == 0.5j def test_linear_adjoint_r_to_c(): f = snp.fft.rfft x, key = randn((32,)) adj = scico.linear_adjoint(f, x) a = snp.sum(x * adj(f(x))[0]) b = snp.linalg.norm(f(x)) ** 2 np.testing.assert_allclose(a, b, rtol=1e-4) def test_linear_adjoint_c_to_r(): f = snp.fft.irfft x, key = randn((32,), dtype=np.complex64) adj = scico.linear_adjoint(f, x) a = snp.sum(x.conj() * adj(f(x))[0]) b = snp.linalg.norm(f(x)) ** 2 np.testing.assert_allclose(a.real, b.real, rtol=1e-4) np.testing.assert_allclose(a.imag, 0, atol=1e-2) @pytest.mark.parametrize("dtype", [np.float32, np.complex64]) def test_cvjp(dtype): A, key = randn((3, 3), dtype=dtype) B, 
key = randn((3, 4), dtype=dtype, key=key) xp, key = randn((3,), dtype=dtype, key=key) yp, key = randn((4,), dtype=dtype, key=key) def fun(x, y): return A @ x + B @ y px, jfnx = scico.cvjp(fun, xp, yp, jidx=0) py, jfny = scico.cvjp(fun, xp, yp, jidx=1) for k in range(3): v = np.zeros((3,), dtype=dtype) v[k] = 1.0 np.testing.assert_allclose(jfnx(v)[0], A[k].conj()) np.testing.assert_allclose(jfny(v)[0], B[k].conj()) @pytest.mark.parametrize( "argskwargs", [ [(snp.ones((3,)), snp.ones((3,)), 1.0), {}], [(1.1 * snp.ones((3,)), snp.ones((3,))), {"z": snp.zeros((3,))}], [(snp.ones(((2,), (3, 2))), 1.0, 1.0), {}], [ (snp.ones(((2,), (3, 2))), snp.blockarray(((2,), (3, 2)))), {"z": 2.0 * snp.ones(((2,), (3, 2)))}, ], ], ) def test_eval_shape_1(argskwargs): def _fun(x, y, z): """Test function""" return x + y * z def _conv(arg): """Convert array to jax.ShapeDtypeStruct.""" if hasattr(arg, "shape"): return jax.ShapeDtypeStruct(arg.shape, dtype=arg.dtype) else: return arg args, kwargs = argskwargs # Reference shape computed for array objects ref_shape = jax.eval_shape(_fun, *args, **kwargs) map_args = [_conv(v) for v in args] map_kwargs = {k: _conv(v) for k, v in kwargs.items()} # Test shape computed for jax.ShapeDtypeStruct objects tst_shape = scico.eval_shape(_fun, *map_args, **map_kwargs) assert tst_shape.shape == ref_shape.shape @pytest.mark.parametrize( "arrdts", [ [snp.ones((3, 2), dtype=snp.float32), jax.ShapeDtypeStruct((3, 2), dtype=snp.float32)], [ snp.ones(((3,), (2, 3)), dtype=snp.float32), jax.ShapeDtypeStruct(((3,), (2, 3)), dtype=snp.float32), ], ], ) def test_eval_shape_2(arrdts): _fun = lambda x: snp.pad(x, 2) arr, dts = arrdts # Reference shape computed for array ref_shape = jax.eval_shape(_fun, arr) # Test shape computed for jax.ShapeDtypeStruct tst_shape = scico.eval_shape(_fun, dts) assert tst_shape.shape == ref_shape.shape ================================================ FILE: scico/test/test_data.py ================================================ import 
os import pytest from scico import data skipif_reason = ( "\nThe data submodule must be cloned and initialized. If the main repository" " is already cloned, use the following in the root directory to get the data" " submodule:\n\tgit submodule update --init --recursive\nOtherwise, make sure" " to clone using:\n\tgit clone --recurse-submodules git@github.com:lanl/scico.git" "\nAnd after cloning run:\n\tgit submodule init && git submodule update.\n" ) examples = os.path.join(os.path.dirname(data.__file__), "examples") pytestmark = pytest.mark.skipif(not os.path.isdir(examples), reason=skipif_reason) class TestSet: def test_kodim23_uint(self): x = data.kodim23() assert x.dtype.name == "uint8" assert x.shape == (512, 768, 3) def test_kodim23_float(self): x = data.kodim23(asfloat=True) assert x.dtype.name == "float32" assert x.shape == (512, 768, 3) ================================================ FILE: scico/test/test_denoiser.py ================================================ import numpy as np import jax import pytest from scico.denoiser import DnCNN, bm3d, bm4d, have_bm3d, have_bm4d from scico.metric import rel_res from scico.random import randn from scico.test.osver import osx_ver_geq_than level = 3 @pytest.fixture(autouse=True, scope="module") def module_setup_teardown(request): global level level = int(request.config.getoption("--level")) # bm3d is known to be broken on OSX 11.6.5. 
# It may be broken on earlier versions too,
# but this has not been confirmed
@pytest.mark.skipif(osx_ver_geq_than("11.6.5"), reason="bm3d broken on this platform")
@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed")
class TestBM3D:
    """Tests of the :func:`bm3d` denoiser wrapper.

    Covers output shape and dtype, agreement between jitted and
    non-jitted evaluation (only when ``level > 2``), and the
    exceptions raised for unsupported inputs.
    """

    def setup_method(self):
        key = None
        # Single-channel (grayscale) and 3-channel (RGB) float32 test images.
        self.x_gry, key = randn((32, 33), key=key, dtype=np.float32)
        self.x_rgb, key = randn((33, 34, 3), key=key, dtype=np.float32)

    def test_shape(self):
        assert bm3d(self.x_gry, 1.0).shape == self.x_gry.shape
        assert bm3d(self.x_rgb, 1.0, is_rgb=True).shape == self.x_rgb.shape

    def test_gry(self):
        no_jit = bm3d(self.x_gry, 1.0)
        assert no_jit.dtype == np.float32
        if level > 2:
            jitted = jax.jit(bm3d)(self.x_gry, 1.0)
            assert np.linalg.norm(no_jit - jitted) < 1e-3
            assert jitted.dtype == np.float32

    def test_rgb(self):
        # The reference must use the same (RGB) mode as the jitted call below;
        # otherwise the two results come from different denoisers.
        no_jit = bm3d(self.x_rgb, 1.0, is_rgb=True)
        assert no_jit.dtype == np.float32
        if level > 2:
            # is_rgb is a Python bool flag: mark it static so jit passes it
            # through unchanged instead of tracing it as an array argument
            # (presumably bm3d branches on it in Python -- confirm in
            # scico.denoiser).
            jitted = jax.jit(bm3d, static_argnames="is_rgb")(self.x_rgb, 1.0, is_rgb=True)
            assert np.linalg.norm(no_jit - jitted) < 1e-3
            assert jitted.dtype == np.float32

    def test_bad_inputs(self):
        # 1-D input.
        x, key = randn((32,), key=None, dtype=np.float32)
        with pytest.raises(ValueError):
            bm3d(x, 1.0)
        # 4-D input.
        x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32)
        with pytest.raises(ValueError):
            bm3d(x, 1.0)
        # Input built from a tuple of shapes (block-array style).
        x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32)
        with pytest.raises(ValueError):
            bm3d(x, 1.0)
        # Small (5, 9) image -- presumably rejected as too small; confirm.
        x, key = randn((5, 9), key=key, dtype=np.float32)
        with pytest.raises(ValueError):
            bm3d(x, 1.0)
        # Complex-valued input is rejected with TypeError.
        z, key = randn((32, 32), key=key, dtype=np.complex64)
        with pytest.raises(TypeError):
            bm3d(z, 1.0)


# bm4d is known to be broken on OSX 11.6.5.
It may be broken on earlier versions too, # but this has not been confirmed @pytest.mark.skipif(osx_ver_geq_than("11.6.5"), reason="bm4d broken on this platform") @pytest.mark.skipif(not have_bm4d, reason="bm4d package not installed") class TestBM4D: def setup_method(self): key = None self.x1, key = randn((16, 17, 18), key=key, dtype=np.float32) self.x2, key = randn((16, 17, 8), key=key, dtype=np.float32) self.x3, key = randn((16, 17, 9, 1, 1), key=key, dtype=np.float32) def test_shape(self): if level > 2: assert bm4d(self.x1, 1.0).shape == self.x1.shape assert bm4d(self.x2, 1.0).shape == self.x2.shape if level > 1: assert bm4d(self.x3, 1.0).shape == self.x3.shape def test_jit(self): if level > 2: no_jit = bm4d(self.x1, 1.0) jitted = jax.jit(bm4d)(self.x1, 1.0) assert np.linalg.norm(no_jit - jitted) < 2e-3 assert no_jit.dtype == np.float32 assert jitted.dtype == np.float32 no_jit = bm4d(self.x2, 1.0) assert no_jit.dtype == np.float32 if level > 1: jitted = jax.jit(bm4d)(self.x2, 1.0) assert np.linalg.norm(no_jit - jitted) < 2e-3 assert jitted.dtype == np.float32 def test_bad_inputs(self): x, key = randn((32,), key=None, dtype=np.float32) with pytest.raises(ValueError): bm4d(x, 1.0) x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32) with pytest.raises(ValueError): bm4d(x, 1.0) x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32) with pytest.raises(ValueError): bm4d(x, 1.0) x, key = randn((5, 9), key=key, dtype=np.float32) with pytest.raises(ValueError): bm4d(x, 1.0) z, key = randn((32, 32), key=key, dtype=np.complex64) with pytest.raises(TypeError): bm4d(z, 1.0) class TestDnCNN: def setup_method(self): key = None self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32) self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32) self.dncnn = DnCNN() def test_single_channel(self): no_jit = self.dncnn(self.x_sngchn) jitted = jax.jit(self.dncnn)(self.x_sngchn) assert rel_res(no_jit, jitted) < 1e-6 assert no_jit.dtype == np.float32 assert 
jitted.dtype == np.float32 def test_multi_channel(self): no_jit = self.dncnn(self.x_mltchn) jitted = jax.jit(self.dncnn)(self.x_mltchn) assert rel_res(no_jit, jitted) < 1e-6 assert no_jit.dtype == np.float32 assert jitted.dtype == np.float32 def test_init(self): dncnn = DnCNN(variant="6L") x = dncnn(self.x_sngchn) dncnn = DnCNN(variant="17H") x = dncnn(self.x_mltchn) with pytest.raises(ValueError): dncnn = DnCNN(variant="3A") def test_bad_inputs(self): x, key = randn((32,), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) z, key = randn((32, 32), key=None, dtype=np.complex64) with pytest.raises(TypeError): self.dncnn(z) class TestNonBLindDnCNN: def setup_method(self): key = None self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32) self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32) self.sigma = 0.1 self.dncnn = DnCNN(variant="6N") def test_single_channel(self): rslt = self.dncnn(self.x_sngchn, sigma=self.sigma) assert rslt.dtype == np.float32 def test_multi_channel(self): rslt = self.dncnn(self.x_mltchn, sigma=self.sigma) assert rslt.dtype == np.float32 def test_bad_inputs(self): with pytest.raises(ValueError): rslt = self.dncnn(self.x_sngchn) ================================================ FILE: scico/test/test_diagnostics.py ================================================ from collections import OrderedDict import pytest from scico import diagnostics class TestSet: def test_itstat(self): its = diagnostics.IterationStats(OrderedDict({"Iter": "%d", "Obj Val": "%8.2e"})) its.insert((0, 1.5)) its.insert((1, 1e2)) assert its.history()[0].Iter == 0 assert its.history()[1].Iter == 1 assert its.history()[1].Obj_Val == 1e2 assert its.history(transpose=True).Obj_Val == [1.5, 100.0] def test_display(self, 
capsys): its = diagnostics.IterationStats({"Iter": "%d"}, display=True, period=2, overwrite=False) its.insert((0,)) cap = capsys.readouterr() assert cap.out == "Iter\n----\n 0\n" its.insert((1,)) cap = capsys.readouterr() assert cap.out == "" its.insert((2,)) cap = capsys.readouterr() assert cap.out == " 2\n" def test_exception(self): with pytest.raises(TypeError): its = diagnostics.IterationStats(["Iter", "%z4d"], display=False) with pytest.raises(ValueError): its = diagnostics.IterationStats({"Iter": "%z4d"}, display=False) def test_warning(self): with pytest.warns(UserWarning): its = diagnostics.IterationStats({"Iter": "%4e"}, display=False) ================================================ FILE: scico/test/test_examples.py ================================================ import os import tempfile import numpy as np import imageio.v3 as iio import pytest import scico.numpy as snp from scico.examples import ( create_3d_foam_phantom, create_circular_phantom, create_cone, create_conv_sparse_phantom, create_tangle_phantom, downsample_volume, epfl_deconv_data, gaussian, phase_diff, rgb2gray, spnoise, tile_volume_slices, ucb_diffusercam_data, volume_read, ) # These tests are for the scico.examples module, NOT the example scripts def test_rgb2gray(): rgb = np.ones((31, 32, 3), dtype=np.float32) gry = rgb2gray(rgb) assert np.abs(gry.mean() - 1.0) < 1e-6 def test_volume_read(): temp_dir = tempfile.TemporaryDirectory() v0 = np.zeros((32, 32), dtype=np.uint16) v1 = np.ones((32, 32), dtype=np.uint16) iio.imwrite(os.path.join(temp_dir.name, "v0.tif"), v0) iio.imwrite(os.path.join(temp_dir.name, "v1.tif"), v1) vol = volume_read(temp_dir.name, ext="tif") assert np.allclose(v0, vol[..., 0]) and np.allclose(v1, vol[..., 1]) def test_epfl_deconv_data(): temp_dir = tempfile.TemporaryDirectory() y0 = np.zeros((32, 32), dtype=np.uint16) psf0 = np.ones((32, 32), dtype=np.uint16) np.savez(os.path.join(temp_dir.name, "epfl_big_deconv_0.npz"), y=y0, psf=psf0) y, psf = epfl_deconv_data(0, 
cache_path=temp_dir.name) assert np.allclose(y0, y) and np.allclose(psf0, psf) def test_ucb_diffusercam_data(): temp_dir = tempfile.TemporaryDirectory() y0 = np.zeros((32, 32), dtype=np.uint16) psf0 = np.ones((8, 32, 32), dtype=np.uint16) np.savez(os.path.join(temp_dir.name, "ucb_diffcam_data.npz"), y=y0, psf=psf0) y, psf = ucb_diffusercam_data(cache_path=temp_dir.name) assert np.allclose(y0, y) and np.allclose(psf0, psf) def test_downsample_volume(): v0 = np.zeros((32, 32, 16)) v1 = downsample_volume(v0, rate=1) assert v0.shape == v1.shape v0 = np.zeros((32, 32, 16)) v1 = downsample_volume(v0, rate=2) assert tuple([n // 2 for n in v0.shape]) == v1.shape v0 = np.zeros((32, 32, 16)) v1 = downsample_volume(v0, rate=3) assert tuple([round(n / 3) for n in v0.shape]) == v1.shape def test_tile_volume_slices(): v = np.ones((16, 16, 16)) tvs = tile_volume_slices(v) assert tvs.ndim == 2 v = np.ones((16, 16, 16, 3)) tvs = tile_volume_slices(v) assert tvs.ndim == 3 and tvs.shape[-1] == 3 def test_gaussian(): g0 = gaussian((5, 5)) assert g0.shape == (5, 5) g1 = gaussian((5, 5), sigma=np.array([[3, 0], [0, 2]])) assert np.sum(g1 / g1.max()) > np.sum(g0 / g0.max()) with pytest.raises(ValueError): g2 = gaussian((5, 5), sigma=np.array([[2, 2], [2, 2]])) def test_create_circular_phantom(): img_shape = (32, 32) radius_list = [2, 4, 8] val_list = [2, 4, 8] x_gt = create_circular_phantom(img_shape, radius_list, val_list) assert x_gt.shape == img_shape assert np.max(x_gt) == max(val_list) assert np.min(x_gt) == 0 @pytest.mark.parametrize( "img_shape", ( (3, 3), (50, 51), (3, 3, 3), ), ) def test_create_cone(img_shape): x_gt = create_cone(img_shape) assert x_gt.shape == img_shape # check symmetry assert np.abs(x_gt[(0,) * len(img_shape)] - x_gt[(-1,) * len(img_shape)]) < 1e-6 @pytest.mark.parametrize( "img_shape", ( (3, 3, 3), (20, 21, 22), (15, 15, 5), ), ) @pytest.mark.parametrize("N_sphere", (3, 10, 20)) def test_create_3d_foam_phantom(img_shape, N_sphere): x_gt = 
create_3d_foam_phantom(img_shape, N_sphere) assert x_gt.shape == img_shape def test_conv_sparse_phantom(): h, x = create_conv_sparse_phantom(64, 32) assert h.shape == (3, 15, 15) assert x.shape == (3, 64, 64) assert np.sum(x > 0) == 32 def test_tangle_phantom(): v = create_tangle_phantom(3, 4, 5) assert v.shape == (5, 4, 3) def test_spnoise(): x = 0.5 * np.ones((10, 11)) y = spnoise(x, 0.5, nmin=0.01, nmax=0.99) assert np.all(y >= 0.01) assert np.all(y <= 0.99) x = 0.5 * snp.ones((10, 11)) y = spnoise(x, 0.5, nmin=0.01, nmax=0.99) assert np.all(y >= 0.01) assert np.all(y <= 0.99) def test_phase_diff(): x = np.pi * np.random.randn(16) y = np.pi * np.random.randn(16) d = phase_diff(x, y) assert np.all(d >= 0) assert np.all(d <= np.pi) ================================================ FILE: scico/test/test_function.py ================================================ import numpy as np import pytest import scico.numpy as snp from scico.function import Function from scico.linop import jacobian from scico.random import randn class TestFunction: def setup_method(self): key = None self.shape = (7, 8) self.dtype = snp.float32 self.x, key = randn(self.shape, key=key, dtype=self.dtype) self.y, key = randn(self.shape, key=key, dtype=self.dtype) self.func = lambda x, y: snp.abs(x) + snp.abs(y) def test_init(self): F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) assert F.output_shape == self.shape assert len(F.input_dtypes) == 2 assert F.output_dtype == self.dtype def test_eval(self): F = Function( (self.shape, self.shape), output_shape=self.shape, eval_fn=self.func, input_dtypes=(self.dtype, self.dtype), output_dtype=self.dtype, ) np.testing.assert_allclose(self.func(self.x, self.y), F(self.x, self.y)) def test_eval_jit(self): F = Function( (self.shape, self.shape), output_shape=self.shape, eval_fn=self.func, input_dtypes=(self.dtype, self.dtype), output_dtype=self.dtype, jit=True, ) np.testing.assert_allclose(self.func(self.x, self.y), 
F(self.x, self.y)) def test_slice(self): F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) Op = F.slice(0, self.y) np.testing.assert_allclose(Op(self.x), F(self.x, self.y)) def test_join(self): F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) Op = F.join() np.testing.assert_allclose(Op(snp.blockarray((self.x, self.y))), F(self.x, self.y)) def test_join_raise(self): F = Function( (self.shape, self.shape), input_dtypes=(snp.float32, snp.complex64), eval_fn=self.func ) with pytest.raises(ValueError): Op = F.join() @pytest.mark.parametrize("dtype", [snp.float32, snp.complex64]) def test_jacobian(dtype): N = 7 M = 8 key = None fmx, key = randn((M, N), key=key, dtype=dtype) gmx, key = randn((M, N), key=key, dtype=dtype) F = Function(((N, 1), (N, 1)), input_dtypes=dtype, eval_fn=lambda x, y: fmx @ x + gmx @ y) u0, key = randn((N, 1), key=key, dtype=dtype) u1, key = randn((N, 1), key=key, dtype=dtype) v, key = randn((N, 1), key=key, dtype=dtype) w, key = randn((M, 1), key=key, dtype=dtype) op = F.slice(0, u1) J0op = jacobian(op, u0) np.testing.assert_allclose(J0op(v), F.jvp(0, v, u0, u1)[1]) np.testing.assert_allclose(J0op.H(w), F.vjp(0, u0, u1)[1](w)) J0fn = F.jacobian(0, u0, u1) np.testing.assert_allclose(J0op(v), J0fn(v)) np.testing.assert_allclose(J0op.H(w), J0fn.H(w)) op = F.slice(1, u0) J1op = jacobian(op, u1) np.testing.assert_allclose(J1op(v), F.jvp(1, v, u0, u1)[1]) np.testing.assert_allclose(J1op.H(w), F.vjp(1, u0, u1)[1](w)) J1fn = F.jacobian(1, u0, u1) np.testing.assert_allclose(J1op(v), J1fn(v)) np.testing.assert_allclose(J1op.H(w), J1fn.H(w)) ================================================ FILE: scico/test/test_metric.py ================================================ import numpy as np import scico.numpy as snp from scico import metric class TestSet: def setup_method(self, method): np.random.seed(12345) def test_mae_mse(self): N = 16 x = np.random.randn(N) y = x.copy() y[0] = 0 xe = 
np.abs(x[0]) e1 = metric.mae(x, y) e2 = metric.mse(x, y) assert np.abs(e1 - xe / N) < 1e-12 assert np.abs(e2 - (xe**2) / N) < 1e-12 def test_snr_nrm(self): N = 16 x = np.random.randn(N) x /= np.sqrt(np.var(x)) y = x + 1 assert np.abs(metric.snr(x, y)) < 1e-6 def test_snr_signal_range(self): N = 16 x = np.random.randn(N) x -= x.min() x /= x.max() y = x + 1 assert np.abs(metric.psnr(x, y)) < 1e-6 def test_psnr(self): N = 16 x = np.random.randn(N) y = x + 1 assert np.abs(metric.psnr(x, y, signal_range=1.0)) < 1e-6 def test_isnr(self): N = 16 x = np.random.randn(N) y = np.random.randn(N) assert np.abs(metric.isnr(x, y, y)) < 1e-6 def test_bsnr(self): N = 16 x = np.random.randn(N) x /= np.sqrt(np.var(x)) n = np.random.randn(N) n /= np.sqrt(np.var(n)) y = x + n assert np.abs(metric.bsnr(x, y)) < 1e-6 def test_rel_res(): A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) x = snp.array([[3], [-2]], dtype=snp.float32) Ax = snp.matmul(A, x) b = snp.array([[8], [3], [-5]], dtype=snp.float32) assert 0.0 == metric.rel_res(Ax, b) A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) x = snp.array([[0], [0]], dtype=snp.float32) Ax = snp.matmul(A, x) b = snp.array([[0], [0], [0]], dtype=snp.float32) assert 0.0 == metric.rel_res(Ax, b) ================================================ FILE: scico/test/test_random.py ================================================ import numpy as np import jax import pytest import scico.random @pytest.mark.parametrize("seed", [None, 42]) def test_wrapped_funcs(seed): fun = jax.random.normal fun_wrapped = scico.random.normal # test seed argument if seed is None: key = jax.random.key(0) else: key = jax.random.key(seed) np.testing.assert_array_equal(fun(key), fun_wrapped(seed=seed)[0]) # test blockarray shape = ((7,), (3, 2), (2, 4, 1)) seed = 42 key = jax.random.key(seed) result, _ = fun_wrapped(shape, seed=seed) def test_add_seed_adapter(): fun = jax.random.normal fun_alt = scico.random._add_seed(fun) # specify a seed instead of a 
key assert fun(jax.random.key(42)) == fun_alt(seed=42)[0] # seed defaults to zero assert fun(jax.random.key(0)) == fun_alt()[0] # other parameters still work ... key = jax.random.key(0) sz = (10, 3) dtype = np.float64 # ... positional np.testing.assert_array_equal(fun(key, sz), fun_alt(sz)[0]) np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype)[0]) np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype, key)[0]) np.testing.assert_array_equal(fun(key, sz, dtype), fun_alt(sz, dtype, None, 0)[0]) # ... keyword np.testing.assert_array_equal(fun(shape=sz, key=key), fun_alt(shape=sz)[0]) np.testing.assert_array_equal( fun(shape=sz, key=key, dtype=dtype), fun_alt(dtype=dtype, shape=sz)[0] ) # ... mixed np.testing.assert_array_equal( fun(key, dtype=dtype, shape=sz), fun_alt(dtype=dtype, shape=sz)[0] ) # get back the split key _, key_a = fun_alt(seed=42) key_b, _ = jax.random.split(jax.random.key(42), 2) assert key_a == key_b # error when key and seed are specified with pytest.raises(ValueError): _ = fun_alt(key=jax.random.key(0), seed=42)[0] ================================================ FILE: scico/test/test_ray_tune.py ================================================ import os import tempfile import numpy as np import pytest try: import ray from scico.ray import report, tune except ImportError as e: pytest.skip("ray.tune not installed", allow_module_level=True) def test_random_run(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} tune.ray.tune.register_trainable("eval_func", eval_params) analysis = tune.run( "eval_func", metric="cost", mode="min", num_samples=100, config=config, resources_per_trial=resources, hyperopt=False, verbose=False, storage_path=os.path.join(tempfile.gettempdir(), "ray_test"), ) best_config = analysis.get_best_config(metric="cost", mode="min") assert 
np.abs(best_config["x"]) < 0.25 assert np.abs(best_config["y"] - 0.5) < 0.25 def test_random_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} tuner = tune.Tuner( eval_params, param_space=config, resources=resources, metric="cost", mode="min", num_samples=100, hyperopt=False, verbose=False, storage_path=os.path.join(tempfile.gettempdir(), "ray_test"), ) results = tuner.fit() best_config = results.get_best_result().config assert np.abs(best_config["x"]) < 0.25 assert np.abs(best_config["y"] - 0.5) < 0.25 def test_hyperopt_run(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} analysis = tune.run( eval_params, metric="cost", mode="min", num_samples=50, config=config, resources_per_trial=resources, hyperopt=True, verbose=True, ) best_config = analysis.get_best_config(metric="cost", mode="min") assert np.abs(best_config["x"]) < 0.25 assert np.abs(best_config["y"] - 0.5) < 0.25 def test_hyperopt_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} tuner = tune.Tuner( eval_params, param_space=config, resources=resources, metric="cost", mode="min", num_samples=50, hyperopt=True, verbose=True, ) results = tuner.fit() best_config = results.get_best_result().config assert np.abs(best_config["x"]) < 0.25 assert np.abs(best_config["y"] - 0.5) < 0.25 def test_hyperopt_tune_alt_init(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} tuner = tune.Tuner( eval_params, param_space=config, 
max_concurrent_trials=4, metric="cost", mode="min", num_samples=50, time_budget=2, hyperopt=True, verbose=True, tune_config=ray.tune.TuneConfig(), run_config=ray.tune.RunConfig(), ) results = tuner.fit() best_config = results.get_best_result().config assert np.abs(best_config["x"]) < 0.25 assert np.abs(best_config["y"] - 0.5) < 0.25 ================================================ FILE: scico/test/test_scipy_special.py ================================================ import numpy as np import pytest import scico.scipy.special as ss from scico.random import randn # these are functions that take only a single ndarray as input one_arg_funcs = [ ss.digamma, ss.entr, ss.erf, ss.erfc, ss.erfinv, ss.expit, ss.gammaln, ss.i0, ss.i0e, ss.i1, ss.i1e, ss.ndtr, ss.log_ndtr, ss.logit, ss.ndtri, ] @pytest.mark.parametrize("func", one_arg_funcs) def test_one_arg_funcs(func): # blockarray array x, key = randn(((8, 8), (4,)), key=None) Fx = func(x) fx0 = func(x[0]) fx1 = func(x[1]) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) def test_betainc(): a, key = randn(((8, 8), (4,)), key=None) b, key = randn(((8, 8), (4,)), key=key) x, key = randn(((8, 8), (4,)), key=key) Fx = ss.betainc(a, b, x) fx0 = ss.betainc(a[0], b[0], x[0]) fx1 = ss.betainc(a[1], b[1], x[1]) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) @pytest.mark.parametrize("func", [ss.gammainc, ss.gammaincc]) def test_gammainc(func): a, key = randn(((8, 8), (4,)), key=None) b, key = randn(((8, 8), (4,)), key=key) x, key = randn(((8, 8), (4,)), key=key) Fx = ss.betainc(a, b, x) fx0 = ss.betainc(a[0], b[0], x[0]) fx1 = ss.betainc(a[1], b[1], x[1]) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) def test_multigammaln(): x, key = randn(((8, 8), (4,)), key=None) d = 2 Fx = 
ss.multigammaln(x, d) fx0 = ss.multigammaln(x[0], d) fx1 = ss.multigammaln(x[1], d) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) @pytest.mark.parametrize("func", [ss.xlog1py, ss.xlogy]) def test_logs(func): x, key = randn(((8, 8), (4,)), key=None) y, key = randn(((8, 8), (4,)), key=key) Fx = func(x, y) fx0 = func(x[0], y[0]) fx1 = func(x[1], y[1]) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) def test_zeta(): x, key = randn(((8, 8), (4,)), key=None) y, key = randn(((8, 8), (4,)), key=None) Fx = ss.zeta(x, y) fx0 = ss.zeta(x[0], y[0]) fx1 = ss.zeta(x[1], y[1]) np.testing.assert_allclose(Fx[0].ravel(), fx0.ravel(), rtol=1e-4) np.testing.assert_allclose(Fx[1].ravel(), fx1.ravel(), rtol=1e-4) ================================================ FILE: scico/test/test_solver.py ================================================ import numpy as np from jax.scipy.linalg import block_diag import pytest import scico.numpy as snp from scico import linop, metric, random, solver class TestSet: def setup_method(self, method): np.random.seed(12345) def test_wrap_func_and_grad(self): N = 8 A = snp.array(np.random.randn(N, N)) x = snp.array(np.random.randn(N)) f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2 func_and_grad = solver._wrap_func_and_grad(f, shape=(N,), dtype=x.dtype) fx, grad = func_and_grad(x) np.testing.assert_allclose(fx, f(x), rtol=5e-5) np.testing.assert_allclose(grad, A.T @ A @ x, rtol=5e-5) def test_cg_std(self): N = 64 Ac = np.random.randn(N, N) Am = Ac.dot(Ac.T) A = Am.dot x = np.random.randn(N) b = Am.dot(x) x0 = np.zeros((N,)) tol = 1e-12 try: xcg, info = solver.cg(A, b, x0, tol=tol) except Exception as e: print(e) assert 0 assert info["rel_res"].ndim == 0 assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6 def test_cg_op(self): N = 32 Ac = np.random.randn(N, N).astype(np.float32) Am = 
Ac.dot(Ac.T) A = Am.dot x = np.random.randn(N).astype(np.float32) b = Am.dot(x) tol = 1e-12 try: xcg, info = solver.cg(linop.MatrixOperator(Am), b, tol=tol) except Exception as e: print(e) assert 0 assert info["rel_res"].ndim == 0 assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6 def test_cg_no_info(self): N = 64 Ac = np.random.randn(N, N) Am = Ac.dot(Ac.T) A = Am.dot x = np.random.randn(N) b = Am.dot(x) x0 = np.zeros((N,)) tol = 1e-12 try: xcg = solver.cg(A, b, x0, tol=tol, info=False) except Exception as e: print(e) assert 0 assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6 def test_cg_complex(self): N = 64 Ac = np.random.randn(N, N) + 1j * np.random.randn(N, N) Am = Ac.dot(Ac.conj().T) A = Am.dot x = np.random.randn(N) + 1j * np.random.randn(N) b = Am.dot(x) x0 = np.zeros_like(x) tol = 1e-12 try: xcg, info = solver.cg(A, b, x0, tol=tol) except Exception as e: print(e) assert 0 assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6 def test_preconditioned_cg(self): N = 64 D = np.diag(np.linspace(0.1, 20, N)) Ac = D @ np.random.randn( N, N ) # Poorly scaled matrix; good fit for diagonal preconditioning Am = Ac.dot(Ac.conj().T) A = Am.dot Mm = np.diag(1 / np.diag(Am)) # inverse of diagonal of Am M = Mm.dot x = np.random.randn(N) + 1j * np.random.randn(N) b = Am.dot(x) x0 = np.zeros_like(x) tol = 1e-12 x_cg, cg_info = solver.cg(A, b, x0, tol=tol, info=True, M=None, maxiter=3) x_pcg, pcg_info = solver.cg(A, b, x0, tol=tol, info=True, M=M, maxiter=3) # Assert that PCG converges faster in a few iterations assert cg_info["rel_res"] > 3 * pcg_info["rel_res"] def test_lstsq_func(self): N = 24 M = 32 Ac = snp.array(np.random.randn(N, M).astype(np.float32)) Am = Ac.dot(Ac.T) A = Am.dot x = snp.array(np.random.randn(N).astype(np.float32)) b = Am.dot(x) x0 = snp.zeros((N,), dtype=np.float32) tol = 1e-6 try: xlsq = solver.lstsq(A, b, x0=x0, tol=tol) except Exception as e: print(e) assert 0 assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 
5e-6 def test_lstsq_op(self): N = 32 M = 24 Ac = snp.array(np.random.randn(N, M).astype(np.float32)) A = linop.MatrixOperator(Ac) x = snp.array(np.random.randn(M).astype(np.float32)) b = Ac.dot(x) tol = 1e-7 try: xlsq = solver.lstsq(A, b, tol=tol) except Exception as e: print(e) assert 0 assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 1e-6 class TestOptimizeScalar: # Adopted from SciPy minimize_scalar tests # https://github.com/scipy/scipy/blob/701ffcc8a6f04509d115aac5e5681c538b5265a2/scipy/optimize/tests/test_optimize.py#L1364 def setup_method(self): self.solution = 1.5 self.rtol = 1e-3 def fun(self, x, a=1.5): """Objective function""" # Jax version of (x - a)**2 - 0.8; will return a devicearray return snp.square(x - a) - 0.8 def test_minimize_scalar(self): # combine all tests above for the minimize_scalar wrapper x = solver.minimize_scalar(self.fun).x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, method="Brent") np.testing.assert_(x.success) x = solver.minimize_scalar(self.fun, method="Brent", options=dict(maxiter=3)) np.testing.assert_(not x.success) x = solver.minimize_scalar(self.fun, bracket=(-3, -2), args=(1.5,), method="Brent").x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, method="Brent", args=(1.5,)).x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, bracket=(-15, -1, 15), args=(1.5,), method="Brent").x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, bracket=(-3, -2), args=(1.5,), method="golden").x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, method="golden", args=(1.5,)).x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar(self.fun, bracket=(-15, -1, 15), args=(1.5,), method="golden").x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = 
solver.minimize_scalar(self.fun, bounds=(0, 1), args=(1.5,), method="Bounded").x np.testing.assert_allclose(x, 1, rtol=1e-4) x = solver.minimize_scalar(self.fun, bounds=(1, 5), args=(1.5,), method="bounded").x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) x = solver.minimize_scalar( self.fun, bounds=(np.array([1]), np.array([5])), args=(np.array([1.5]),), method="bounded", ).x np.testing.assert_allclose(x, self.solution, rtol=self.rtol) @pytest.mark.parametrize("dtype", [snp.float32, snp.complex64]) @pytest.mark.parametrize("method", ["CG", "L-BFGS-B"]) def test_minimize_vector(dtype, method): B, M, N = (4, 3, 2) # model a 12x8 block-diagonal matrix with 3x2 blocks A, key = random.randn((B, M, N), dtype=dtype) x, key = random.randn((B, N), dtype=dtype, key=key) y = snp.sum(A * x[:, None], axis=2) # contract along the N axis # result by directly inverting the dense matrix A_mat = block_diag(*A) expected = snp.linalg.pinv(A_mat) @ y.ravel() def f(x): return 0.5 * snp.linalg.norm(y - snp.sum(A * x[:, None], axis=2)) ** 2 out = solver.minimize(f, x0=snp.zeros_like(x), method=method) assert out.x.shape == x.shape np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) @pytest.mark.parametrize("dtype", [snp.float32]) @pytest.mark.parametrize("method", ["CG"]) def test_minimize_blockarray(dtype, method): # model a 6x8 block-diagonal matrix with 3x4 blocks A, key = random.randn(((3, 4), (3, 4)), dtype=dtype) x, key = random.randn(((4,), (4,)), dtype=dtype, key=key) y = A @ x # result by directly inverting the dense matrix A_mat = block_diag(*A) expected = snp.linalg.pinv(A_mat) @ y.stack(axis=0).ravel() def f(x): return 0.5 * snp.linalg.norm(y - A @ x) ** 2 out = solver.minimize(f, x0=snp.zeros_like(x), method=method) assert out.x.shape == x.shape np.testing.assert_allclose(solver._ravel(out.x), expected, rtol=5e-4) def test_split_join_array(): x, key = random.randn((4, 4), dtype=np.complex64) x_s = solver._split_real_imag(x) assert x_s.shape == (2, 
4, 4) np.testing.assert_allclose(x_s[0], snp.real(x)) np.testing.assert_allclose(x_s[1], snp.imag(x)) x_j = solver._join_real_imag(x_s) np.testing.assert_allclose(x_j, x, rtol=1e-4) def test_split_join_blockarray(): x, key = random.randn(((4, 4), (3,)), dtype=np.complex64) x_s = solver._split_real_imag(x) assert x_s.shape == ((2, 4, 4), (2, 3)) real_block = snp.blockarray((x_s[0][0], x_s[1][0])) imag_block = snp.blockarray((x_s[0][1], x_s[1][1])) snp.testing.assert_allclose(real_block, snp.real(x), rtol=1e-4) snp.testing.assert_allclose(imag_block, snp.imag(x), rtol=1e-4) x_j = solver._join_real_imag(x_s) snp.testing.assert_allclose(x_j, x, rtol=1e-4) def test_bisect(): f = lambda x: x**3 x, info = solver.bisect(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True) assert snp.sum(snp.abs(x)) == 0.0 assert info["iter"] == 0 x = solver.bisect(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5, ftol=1e-5) assert snp.max(snp.abs(x)) <= 1e-5 assert snp.max(snp.abs(f(x))) <= 1e-5 c, key = random.randn((5, 1), dtype=np.float32) f = lambda x, c: x**3 - c**3 x = solver.bisect(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5, ftol=1e-5) assert snp.max(snp.abs(x - c)) <= 1e-5 assert snp.max(snp.abs(f(x, c))) <= 1e-5 def test_golden(): f = lambda x: x**2 x, info = solver.golden(f, -snp.ones((5, 1)), snp.ones((5, 1)), full_output=True) assert snp.max(snp.abs(x)) <= 1e-7 x = solver.golden(f, -2.0 * snp.ones((5, 3)), snp.ones((5, 3)), xtol=1e-5) assert snp.max(snp.abs(x)) <= 1e-5 c, key = random.randn((5, 1), dtype=np.float32) f = lambda x, c: (x - c) ** 2 x = solver.golden(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5) assert snp.max(snp.abs(x - c)) <= 1e-5 @pytest.mark.parametrize("cho_factor", [True, False]) @pytest.mark.parametrize("wide", [True, False]) @pytest.mark.parametrize("weighted", [True, False]) @pytest.mark.parametrize("alpha", [1e-1, 1e1]) def test_solve_atai(cho_factor, wide, weighted, alpha): A, key = random.randn((5, 8), 
dtype=snp.float32) if wide: x0, key = random.randn((8,), key=key) else: A = A.T x0, key = random.randn((5,), key=key) if weighted: W, key = random.randn((A.shape[0],), key=key) W = snp.abs(W) Wa = W[:, snp.newaxis] else: W = None Wa = snp.array([1.0])[:, snp.newaxis] D = alpha * snp.ones((A.shape[1],)) ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1]) b = ATAD @ x0 slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @pytest.mark.parametrize("cho_factor", [True, False]) @pytest.mark.parametrize("wide", [True, False]) @pytest.mark.parametrize("alpha", [1e-1, 1e1]) def test_solve_aati(cho_factor, wide, alpha): A, key = random.randn((5, 8), dtype=snp.float32) if wide: x0, key = random.randn((5,), key=key) else: A = A.T x0, key = random.randn((8,), key=key) D = alpha * snp.ones((A.shape[0],)) AATD = A @ A.T + alpha * snp.identity(A.shape[0]) b = AATD @ x0 slv = solver.MatrixATADSolver(A.T, D) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @pytest.mark.parametrize("cho_factor", [True, False]) @pytest.mark.parametrize("wide", [True, False]) @pytest.mark.parametrize("vector", [True, False]) def test_solve_atad(cho_factor, wide, vector): A, key = random.randn((5, 8), dtype=snp.float32) if wide: D, key = random.randn((8,), key=key) if vector: x0, key = random.randn((8,), key=key) else: x0, key = random.randn((8, 3), key=key) else: A = A.T D, key = random.randn((5,), key=key) if vector: x0, key = random.randn((5,), key=key) else: x0, key = random.randn((5, 3), key=key) D = snp.abs(D) # only required for Cholesky, but improved accuracy for LU ATAD = A.T @ A + snp.diag(D) b = ATAD @ x0 slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 assert slv.accuracy(x1, b) < 5e-5 ================================================ FILE: scico/test/test_util.py ================================================ import socket import urllib.error as 
urlerror import numpy as np import jax import pytest import scico.numpy as snp from scico.util import ( ContextTimer, Timer, check_for_tracer, partial, rgetattr, rsetattr, url_get, ) def test_rattr(): class A: class B: c = 0 b = B() a = A() rsetattr(a, "b.c", 1) assert rgetattr(a, "b.c") == 1 assert rgetattr(a, "c.d", 10) == 10 with pytest.raises(AttributeError): assert rgetattr(a, "c.d") def test_partial_pos(): def func(a, b, c, d): return a + 2 * b + 4 * c + 8 * d pfunc = partial(func, (0, 2), 0, 0) assert pfunc(1, 0) == 2 and pfunc(0, 1) == 8 def test_partial_kw(): def func(a=1, b=1, c=1, d=1): return a + 2 * b + 4 * c + 8 * d pfunc = partial(func, (), a=0, c=0) assert pfunc(b=1, d=0) == 2 and pfunc(b=0, d=1) == 8 def test_partial_pos_and_kw(): def func(a, b, c=1, d=1): return a + 2 * b + 4 * c + 8 * d pfunc = partial(func, (0,), 0, c=0) assert pfunc(1, d=0) == 2 and pfunc(0, d=1) == 8 # See https://stackoverflow.com/a/33117579 def _internet_connected(host="8.8.8.8", port=53, timeout=3): """Check if internet connection available. 
Host: 8.8.8.8 (google-public-dns-a.google.com) OpenPort: 53/tcp Service: domain (DNS/TCP) """ try: socket.setdefaulttimeout(timeout) socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port)) return True except socket.error as ex: return False @pytest.mark.skipif(not _internet_connected(), reason="No internet connection") def test_url_get(): url = "https://github.com/lanl/scico/blob/main/README.md" headers = { "User-Agent": "Mozilla/5.0 (X11; Linux x86_64)", "Referer": "https://github.com/lanl/scico/blob/main", } try: uget = url_get(url, headers=headers) except urlerror.HTTPError as e: if e.code != 429: raise else: assert not uget.getvalue().find(b"SCICO") == -1 np.testing.assert_raises(ValueError, url_get, url, maxtry=-1) url = "about:blank" np.testing.assert_raises(urlerror.URLError, url_get, url) def test_check_for_tracer(): # Using examples from Jax documentation A = snp.ones((5, 5)) x = snp.ones((10, 5)) @check_for_tracer def norm(X): X = X - X.mean(0) return X / X.std(0) with pytest.raises(TypeError): check_norm = jax.jit(norm) check_norm(x) vv = check_for_tracer(lambda x: A @ x) with pytest.raises(TypeError): mv = jax.vmap(vv) mv(x) def test_timer_basic(): t = Timer() t.start() t0 = t.elapsed() t.stop() t1 = t.elapsed() assert t0 >= 0.0 assert t1 >= t0 assert len(t.__str__()) > 0 assert len(t.labels()) > 0 def test_timer_multi(): t = Timer("a") t.start(["a", "b"]) t0 = t.elapsed("a") t.stop("a") t.stop("b") t.stop(["a", "b"]) assert t.elapsed("a") >= 0.0 assert t.elapsed("b") >= 0.0 assert t.elapsed("a", total=False) == 0.0 def test_timer_reset(): t = Timer("a") t.start(["a", "b"]) t.reset("a") assert t.elapsed("a") == 0.0 t.reset("all") assert t.elapsed("b") == 0.0 def test_ctxttimer_basic(): t = Timer() with ContextTimer(t): t0 = t.elapsed() assert t.elapsed() >= 0.0 def test_ctxttimer_stopstart(): t = Timer() t.start() with ContextTimer(t, action="StopStart"): t0 = t.elapsed() t.stop() assert t.elapsed() >= 0.0 
================================================ FILE: scico/test/test_version.py ================================================ from scico._version import variable_assign_value test_var_num = 12345 test_var_str = "12345" def test_var_val(): assert variable_assign_value(__file__, "test_var_num") == test_var_num assert variable_assign_value(__file__, "test_var_str") == test_var_str ================================================ FILE: scico/trace.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2024-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Call tracing of scico functions and methods. JIT must be disabled for tracing to function correctly (set environment variable :code:`JAX_DISABLE_JIT=1`, or call :code:`jax.config.update('jax_disable_jit', True)` before importing `jax` or `scico`). Call :code:`trace_scico_calls` to initialize tracing, and call :code:`register_variable` to associate a name with a variable so that it can be referenced by name in the call trace. The call trace is color-code as follows if `colorama `_ is installed: - `module and class names`: light red - `function and method names`: dark red - `arguments and return values`: light blue - `names of registered variables`: light yellow When a method defined in a class is called for an object of a derived class type, the class of that object is displayed in light magenta, in square brackets. Function names and return values are distinguished by initial ``>>`` and ``<<`` characters respectively. A usage example is provided in the script :code:`trace_example.py`. 
""" from __future__ import annotations import inspect import sys import types import warnings from collections import defaultdict from functools import wraps from typing import Any, Callable, Optional, Sequence import numpy as np import jax try: from jaxlib.xla_extension import PjitFunction except ImportError: from jaxlib._jax import PjitFunction # jax >= 0.6.1 try: import colorama have_colorama = True except ImportError: have_colorama = False if have_colorama: clr_main = colorama.Fore.LIGHTRED_EX # main trace information clr_rvar = colorama.Fore.LIGHTYELLOW_EX # registered variable names clr_self = colorama.Fore.LIGHTMAGENTA_EX # type of object for which method is called clr_func = colorama.Fore.RED # function/method name clr_args = colorama.Fore.LIGHTBLUE_EX # function/method arguments clr_retv = colorama.Fore.LIGHTBLUE_EX # function/method return values clr_devc = colorama.Fore.CYAN # JAX array device and sharding clr_reset = colorama.Fore.RESET # reset color else: clr_main, clr_rvar, clr_self, clr_func = "", "", "", "" clr_args, clr_retv, clr_devc, clr_reset = "", "", "", "" def _get_hash(val: Any) -> Optional[int]: """Get a hash representing an object. Args: val: An object for which the hash is required. Returns: A hash value of ``None`` if a hash cannot be computed. """ if isinstance(val, np.ndarray): hash = val.ctypes.data # for an ndarray, hash is the memory address elif hasattr(val, "__hash__") and callable(val.__hash__): try: hash = val.__hash__() except TypeError: hash = None else: hash = None return hash def _trace_arg_repr(val: Any) -> str: """Compute string representation of function arguments. Args: val: Argument value Returns: A string representation of the argument. 
""" if val is None: return "None" elif np.isscalar(val): # a scalar value return str(val) elif isinstance(val, tuple) and len(val) < 6 and all([np.isscalar(s) for s in val]): return f"{val}" # a short sequence of scalars elif isinstance(val, np.dtype): # a numpy dtype return f"numpy.{val}" elif isinstance(val, type): # a class name return f"{val.__module__}.{val.__qualname__}" elif isinstance(val, np.ndarray) and _get_hash(val) in call_trace.instance_hash: # type: ignore return f"{clr_rvar}{call_trace.instance_hash[_get_hash(val)]}{clr_args}" # type: ignore elif isinstance(val, (np.ndarray, jax.Array)): # a jax or numpy array if val.shape == (): return str(val) else: dev_str, shard_str = "", "" if isinstance(val, jax.Array) and not isinstance( val, jax._src.interpreters.partial_eval.JaxprTracer ): if call_trace.show_jax_device: # type: ignore platform = list(val.devices())[0].platform # assume all of same type devices = ",".join(map(str, sorted([d.id for d in val.devices()]))) dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}" if call_trace.show_jax_sharding and isinstance( # type: ignore val.sharding, jax._src.sharding_impls.PositionalSharding ): shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}" return f"Array{val.shape}{dev_str}{shard_str}" else: if _get_hash(val) in call_trace.instance_hash: # type: ignore return f"{clr_rvar}{call_trace.instance_hash[val.__hash__()]}{clr_args}" # type: ignore else: return f"[{type(val).__name__}]" def register_variable(var: Any, name: str): """Register a variable name for call tracing. Any hashable object (or numpy array, with the memory address used as a hash) may be registered. JAX arrays may not be registered since they are not hashable and there is no clear mechanism for associating them with a unique memory address. Args: var: The variable to be registered. name: The name to be associated with the variable. 
""" hash = _get_hash(var) if hash is None: raise ValueError(f"Can't get hash for variable '{name}'.") call_trace.instance_hash[hash] = name # type: ignore def _call_wrapped_function(func: Callable, *args, **kwargs) -> Any: """Call a wrapped function within the wrapper. Handle different call mechanisms required for static and class methods. Args: func: Wrapped function *args: Positional arguments **kwargs: Named arguments Returns: Return value of wrapped function. """ if isinstance(func, staticmethod): # If the type of the first argument is the same as the class to # which the static method belongs, assume that it was called as # .(), which requires that the first # argument be stripped before calling the method. This is # somewhat heuristic, and may fail, but there is no obvious # mechanism for reliably determining how the method was called in # the calling scope. if inspect._findclass(func) == type(args[0]): # type: ignore call_args = args[1:] else: call_args = args ret = func(*call_args, **kwargs) elif isinstance(func, classmethod): ret = func.__func__(*args, **kwargs) else: ret = func(*args, **kwargs) return ret def call_trace(func: Callable) -> Callable: """Print log of calls to `func`. Decorator for printing a log of calls to the wrapped function. A record of call levels is maintained so that call nesting is indicated by call log indentation. """ try: method_class = inspect._findclass(func) # type: ignore except AttributeError: method_class = None @wraps(func) def wrapper(*args, **kwargs): name = f"{func.__module__}.{clr_func}{func.__qualname__}" arg_idx = 0 if ( args and hasattr(args[0], "__hash__") and callable(args[0].__hash__) and method_class and isinstance(args[0], method_class) ): # first argument is self for a method call arg_idx = 1 # skip self in handling arguments if args[0].__hash__() in call_trace.instance_hash: # self object registered using register_variable name = ( f"{clr_rvar}{call_trace.instance_hash[args[0].__hash__()]}." 
f"{clr_func}{func.__name__}" ) else: # self object not registered func_class = method_class.__name__ self_class = args[0].__class__.__name__ # If the class in which this method is defined is same as that # of the self object for which it's called, just display the # class name. Otherwise, display the name of the name defining # class followed by the name of the self object class in # square brackets. if func_class == self_class: class_name = func_class else: class_name = f"{func_class}{clr_self}[{self_class}]{clr_main}" name = f"{func.__module__}.{class_name}.{clr_func}{func.__name__}" args_repr = [_trace_arg_repr(val) for val in args[arg_idx:]] kwargs_repr = [f"{key}={_trace_arg_repr(val)}" for key, val in kwargs.items()] args_str = clr_args + ", ".join(args_repr + kwargs_repr) + clr_main print( f"{clr_main}>> {' ' * 2 * call_trace.trace_level}{name}" f"({args_str}{clr_func}){clr_reset}", file=sys.stderr, ) # call wrapped function call_trace.trace_level += 1 ret = _call_wrapped_function(func, *args, **kwargs) call_trace.trace_level -= 1 # print representation of return value if ret is not None and call_trace.show_return_value: print( f"{clr_main}<< {' ' * 2 * call_trace.trace_level}{clr_retv}" f"{_trace_arg_repr(ret)}{clr_reset}", file=sys.stderr, ) return ret # Set flag indicating that function is already wrapped wrapper._call_trace_wrap = True # type: ignore # Avoid multiple wrapper layers if hasattr(func, "_call_trace_wrap"): return func else: return wrapper # call level counter for call_trace decorator call_trace.trace_level = 0 # type: ignore # hash dict allowing association of objects with variable names call_trace.instance_hash = {} # type: ignore # flag indicating whether to show function return value call_trace.show_return_value = True # type: ignore # flag indicating whether to show JAX array devices call_trace.show_jax_device = False # type: ignore # flag indicating whether to show JAX array sharding shape call_trace.show_jax_sharding = False # type: 
ignore def _submodule_name(module, obj): if ( len(obj.__name__) > len(module.__name__) and obj.__name__[0 : len(module.__name__)] == module.__name__ ): short_name = obj.__name__[len(module.__name__) + 1 :] else: short_name = "" return short_name def _is_scico_object(obj: Any) -> bool: """Determine whether an object is defined in a scico submodule. Args: obj: Object to check. Returns: A boolean value indicating whether `obj` is defined in a scico submodule. """ return hasattr(obj, "__module__") and obj.__module__[0:5] == "scico" def _is_scico_module(mod: types.ModuleType) -> bool: """Determine whether a module is a scico submodule. Args: mod: Module to check. Returns: A boolean value indicating whether `mod` is a scico submodule. """ return hasattr(mod, "__name__") and mod.__name__[0:5] == "scico" def _in_module(mod: types.ModuleType, obj: Any) -> bool: """Determine whether an object is defined in a module. Args: mod: Module of interest. obj: Object to check. Returns: A boolean value indicating whether `obj` is defined in `mod`. """ return obj.__module__ == mod.__name__ def _is_submodule(mod: types.ModuleType, submod: types.ModuleType) -> bool: """Determine whether a module is a submodule of another module. Args: mod: Parent module of interest. submod: Possible submodule to check. Returns: A boolean value indicating whether `submod` is defined in `mod`. """ return submod.__name__[0 : len(mod.__name__)] == mod.__name__ def apply_decorator( module: types.ModuleType, decorator: Callable, recursive: bool = True, skip: Optional[Sequence] = None, seen: Optional[defaultdict[str, int]] = None, verbose: bool = False, level: int = 0, ) -> defaultdict[str, int]: """Apply a decorator function to all functions in a scico module. Apply a decorator function to all functions in a scico module, including methods of classes in that module. Args: module: The module containing the functions/methods to be decorated. 
decorator: The decorator function to apply to each module function/method. recursive: Flag indicating whether to recurse into submodules of the specified module. (Hidden modules with a name starting with an underscore are ignored.) skip: A list of class/function/method names to be skipped. seen: A :class:`defaultdict` providing a count of the number of times each function/method was seen. verbose: Flag indicating whether to print a log of functions as they are encountered. level: Counter for recursive call levels. Returns: A :class:`defaultdict` providing a count of the number of times each function/method was seen. """ indent = " " * 4 * level if skip is None: skip = [] if seen is None: seen = defaultdict(int) if verbose: print(f"{indent}Module: {module.__name__}") indent += " " * 4 # Iterate over functions in module for name, func in inspect.getmembers( module, lambda obj: isinstance(obj, (types.FunctionType, PjitFunction)) and _in_module(module, obj), ): if name in skip: continue qualname = func.__module__ + "." + func.__qualname__ if not seen[qualname]: # avoid multiple applications of decorator setattr(module, name, decorator(func)) seen[qualname] += 1 if verbose: print(f"{indent}Function: {qualname}") # Iterate over classes in module for name, cls in inspect.getmembers( module, lambda obj: inspect.isclass(obj) and _in_module(module, obj) ): qualname = cls.__module__ + "." + cls.__qualname__ # type: ignore if verbose: print(f"{indent}Class: {qualname}") # Iterate over methods in class for name, func in inspect.getmembers( cls, lambda obj: isinstance(obj, (types.FunctionType, PjitFunction)) ): if name in skip: continue qualname = func.__module__ + "." + func.__qualname__ # type: ignore if not seen[qualname]: # avoid multiple applications of decorator # Can't use cls returned by inspect.getmembers because it uses plain # getattr internally, which interferes with identification of static # methods. 
From Python 3.11 onwards one could use # inspect.getmembers_static instead of inspect.getmembers, but that # would imply incompatibility with earlier Python versions. func = inspect.getattr_static(cls, name) setattr(cls, name, decorator(func)) seen[qualname] += 1 if verbose: print(f"{indent + ' '}Method: {qualname}") # Iterate over submodules of module if recursive: for name, mod in inspect.getmembers( module, lambda obj: inspect.ismodule(obj) and _is_submodule(module, obj) ): if name[0:1] == "_": continue seen = apply_decorator( mod, decorator, recursive=recursive, skip=skip, seen=seen, verbose=verbose, level=level + 1, ) return seen def trace_scico_calls(verbose: bool = False): """Enable tracing of calls to all significant scico functions/methods. Enable tracing of calls to all significant scico functions and methods. Note that JIT should be disabled to ensure correct functioning of the tracing mechanism. """ if not jax.config.jax_disable_jit: warnings.warn( "Call tracing requested but jit is not disabled. Disable jit" " by setting the environment variable JAX_DISABLE_JIT=1, or use" " jax.config.update('jax_disable_jit', True)." ) from scico import ( function, functional, linop, loss, metric, operator, optimize, solver, ) seen = None for module in (functional, linop, loss, operator, optimize, function, metric, solver): seen = apply_decorator(module, call_trace, skip=["__repr__"], seen=seen, verbose=verbose) ================================================ FILE: scico/typing.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2021-2025 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed # with the package. 
"""Type definitions.""" from typing import Any, List, Tuple, Union try: # available in python 3.10 from types import EllipsisType # type: ignore from typing import TypeAlias # type: ignore except ImportError: from typing_extensions import TypeAlias # type: ignore EllipsisType: TypeAlias = Any # type: ignore import jax.numpy as jnp from jax import Array PRNGKey: TypeAlias = Array """A key for jax random number generators (see :mod:`jax.random`).""" DType: TypeAlias = Union[ jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, jnp.float16, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128, bool, ] """A jax dtype.""" Shape: TypeAlias = Tuple[int, ...] """A shape of a numpy or jax array.""" BlockShape: TypeAlias = Tuple[Tuple[int, ...], ...] """A shape of a :class:`.BlockArray`.""" Axes: TypeAlias = Union[int, Tuple[int, ...], List[int]] """Specification of one or more array axes.""" AxisIndex: TypeAlias = Union[slice, EllipsisType, int] """An entity suitable for indexing/slicing of a single array axis; either a slice object, Ellipsis, or int.""" ArrayIndex: TypeAlias = Union[AxisIndex, Tuple[AxisIndex, ...]] """An entity suitable for indexing/slicing of multi-dimentional arrays.""" ================================================ FILE: scico/util.py ================================================ # -*- coding: utf-8 -*- # Copyright (C) 2020-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. 
"""General utility functions.""" from __future__ import annotations import io import socket import urllib.error as urlerror import urllib.request as urlrequest from functools import reduce, wraps from timeit import default_timer as timer from typing import Any, Callable, Dict, List, Optional, Sequence, Union import jax def rgetattr(obj: object, name: str, default: Optional[Any] = None) -> Any: """Recursive version of :func:`getattr`. Args: obj: Object with the attribute to be accessed. name: Path to object in with components delimited by a "." character. default: Default value to be returned if the attribute does not exist. Returns: Attribute value of default if attribute does not exist. """ try: return reduce(getattr, name.split("."), obj) except AttributeError as e: if default is not None: return default else: raise e def rsetattr(obj: object, name: str, value: Any): """Recursive version of :func:`setattr`. Args: obj: Object with the attribute to be set. name: Path to object in with components delimited by a "." character. value: Value to which the attribute is to be set. """ # See goo.gl/BVJ7MN path = name.split(".") setattr(reduce(getattr, path[:-1], obj), path[-1], value) def partial(func: Callable, indices: Sequence, *fixargs: Any, **fixkwargs: Any) -> Callable: """Flexible partial function creation. This function is similar to :func:`functools.partial`, but allows fixing of arbitrary positional arguments rather than just some number of trailing positional arguments. Args: func: Function from which partial function is to be derived. indices: Tuple of indices of positional args of `func` that are to be fixed to the values specified in `fixargs`. *fixargs: Fixed values for specified positional arguments. **fixkwargs: Fixed values for keyword arguments. Returns: The partial function with fixed arguments. 
""" def pfunc(*freeargs, **freekwargs): numargs = len(fixargs) + len(freeargs) args = [ None, ] * numargs kfix = 0 kfree = 0 for k in range(numargs): if k in indices: args[k] = fixargs[kfix] kfix += 1 else: args[k] = freeargs[kfree] kfree += 1 kwargs = freekwargs.copy() kwargs.update(fixkwargs) return func(*args, **kwargs) posdoc = "" if indices: posdoc = f"positional arguments {','.join(map(str, indices))}" kwdoc = "" if fixkwargs: kwdoc = f"keyword arguments {','.join(fixkwargs.keys())}" pfunc.__doc__ = f"Partial function derived from function {func.__name__}" if posdoc or kwdoc: pfunc.__doc__ += " by fixing " + (" and ".join(filter(None, (posdoc, kwdoc)))) return pfunc def device_info(devid: int = 0) -> str: # pragma: no cover """Get a string describing the specified device. Args: devid: ID number of device. Returns: Device description string. """ numdev = jax.device_count() if devid >= numdev: raise RuntimeError(f"Requested information for device {devid} but only {numdev} present.") dev = jax.devices()[devid] if dev.platform == "cpu": info = "CPU" else: info = f"{dev.platform.upper()} ({dev.device_kind})" return info def check_for_tracer(func: Callable) -> Callable: """Check if positional arguments to `func` are jax tracers. This is intended to be used as a decorator for functions that call external code from within SCICO. At present, external functions cannot be jit-ed or vmap/pmaped. This decorator checks for signs of jit/vmap/pmap and raises an appropriate exception. """ @wraps(func) def wrapper(*args, **kwargs): if any([isinstance(x, jax.core.Tracer) for x in args]): raise TypeError( f"JAX tracer found in {func.__name__}; did you jit/vmap/pmap this function?" ) return func(*args, **kwargs) return wrapper def url_get( url: str, headers: Optional[dict] = None, maxtry: int = 3, timeout: int = 10 ) -> io.BytesIO: # pragma: no cover """Get content of a file via a URL. Args: url: URL of the file to be downloaded. headers: Dict of header strings for request. 
maxtry: Maximum number of download retries. timeout: Timeout in seconds for blocking operations. Returns: Buffered I/O stream. Raises: ValueError: If the maxtry parameter is not greater than zero. urllib.error.URLError: If the file cannot be downloaded. """ if maxtry <= 0: raise ValueError("Argument 'maxtry' should be greater than zero.") if headers is None: headers = {} req = urlrequest.Request(url, headers=headers) for ntry in range(maxtry): try: rspns = urlrequest.urlopen(req, timeout=timeout) cntnt = rspns.read() break except urlerror.URLError as e: if not isinstance(e.reason, socket.timeout): raise return io.BytesIO(cntnt) # Timer classes are copied from https://github.com/bwohlberg/sporco class Timer: """Timer class supporting multiple independent labeled timers. The timer is based on the relative time returned by :func:`timeit.default_timer`. """ def __init__( self, labels: Optional[Union[str, List[str]]] = None, default_label: str = "main", all_label: str = "all", ): """ Args: labels: Label(s) of the timer(s) to be initialised to zero. default_label: Default timer label to be used when methods are called without specifying a label. all_label: Label string that will be used to denote all timer labels. """ # Initialise current and accumulated time dictionaries self.t0: Dict[str, Optional[float]] = {} self.td: Dict[str, float] = {} # Record default label and string indicating all labels self.default_label = default_label self.all_label = all_label # Initialise dictionary entries for labels to be created # immediately if labels is not None: if not isinstance(labels, (list, tuple)): labels = [ labels, ] for lbl in labels: self.td[lbl] = 0.0 self.t0[lbl] = None def start(self, labels: Optional[Union[str, List[str]]] = None): """Start specified timer(s). Args: labels: Label(s) of the timer(s) to be started. If it is ``None``, start the default timer with label specified by the `default_label` parameter of :meth:`__init__`. 
""" # Default label is self.default_label if labels is None: labels = self.default_label # If label is not a list or tuple, create a singleton list # containing it if not isinstance(labels, (list, tuple)): labels = [ labels, ] # Iterate over specified label(s) t = timer() for lbl in labels: # On first call to start for a label, set its accumulator to zero if lbl not in self.td: self.td[lbl] = 0.0 self.t0[lbl] = None # Record the time at which start was called for this lbl if # it isn't already running if self.t0[lbl] is None: self.t0[lbl] = t def stop(self, labels: Optional[Union[str, List[str]]] = None): """Stop specified timer(s). Args: labels: Label(s) of the timer(s) to be stopped. If it is ``None``, stop the default timer with label specified by the `default_label` parameter of :meth:`__init__`. If it is equal to the string specified by the `all_label` parameter of :meth:`__init__`, stop all timers. """ # Get current time t = timer() # Default label is self.default_label if labels is None: labels = self.default_label # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, ] # Iterate over specified label(s) for lbl in labels: if lbl not in self.t0: raise KeyError(f"Unrecognized timer key {lbl}.") # If self.t0[lbl] is None, the corresponding timer is # already stopped, so no action is required if self.t0[lbl] is not None: # Increment time accumulator from the elapsed time # since most recent start call self.td[lbl] += t - self.t0[lbl] # type: ignore # Set start time to None to indicate timer is not running self.t0[lbl] = None def reset(self, labels: Optional[Union[str, List[str]]] = None): """Reset specified timer(s). Args: labels: Label(s) of the timer(s) to be stopped. If it is ``None``, stop the default timer with label specified by the `default_label` parameter of :meth:`__init__`. 
If it is equal to the string specified by the `all_label` parameter of :meth:`__init__`, stop all timers. """ # Default label is self.default_label if labels is None: labels = self.default_label # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, ] # Iterate over specified label(s) for lbl in labels: if lbl not in self.t0: raise KeyError(f"Unrecognized timer key {lbl}.") # Set start time to None to indicate timer is not running self.t0[lbl] = None # Set time accumulator to zero self.td[lbl] = 0.0 def elapsed(self, label: Optional[str] = None, total: bool = True) -> float: """Get elapsed time since timer start. Args: label: Label of the timer for which the elapsed time is required. If it is ``None``, the default timer with label specified by the `default_label` parameter of :meth:`__init__` is selected. total: If ``True`` return the total elapsed time since the first call of :meth:`start` for the selected timer, otherwise return the elapsed time since the most recent call of :meth:`start` for which there has not been a corresponding call to :meth:`stop`. Returns: Elapsed time. """ # Get current time t = timer() # Default label is self.default_label if label is None: label = self.default_label # Return 0.0 if default timer selected and it is not initialised if label not in self.t0: return 0.0 # Raise exception if timer with specified label does not exist if label not in self.t0: raise KeyError(f"Unrecognized timer key {label}.") # If total flag is True return sum of accumulated time from # previous start/stop calls and current start call, otherwise # return just the time since the current start call te = 0.0 if self.t0[label] is not None: te = t - self.t0[label] # type: ignore if total: te += self.td[label] return te def labels(self) -> List[str]: """Get a list of timer labels. 
Returns: List of timer labels. """ return list(self.t0.keys()) def __str__(self) -> str: """Return string representation of object. The representation consists of a table with the following columns: * Timer label. * Accumulated time from past start/stop calls. * Time since current start call, or 'Stopped' if timer is not currently running. """ # Get current time t = timer() # Length of label field, calculated from max label length fldlen = [len(lbl) for lbl in self.t0] + [ len(self.default_label), ] lfldln = max(fldlen) + 2 # Header string for table of timers s = f"{'Label':{lfldln}s} Accum. Current\n" s += "-" * (lfldln + 25) + "\n" # Construct table of timer details for lbl in sorted(self.t0): td = self.td[lbl] if self.t0[lbl] is None: ts = " Stopped" else: ts = f" {(t - self.t0[lbl]):.2e} s" % (t - self.t0[lbl]) # type: ignore s += f"{lbl:{lfldln}s} {td:.2e} s {ts}\n" return s class ContextTimer: """A wrapper class for :class:`Timer` that enables its use as a context manager. For example, instead of >>> t = Timer() >>> t.start() >>> x = sum(range(1000)) >>> t.stop() >>> elapsed = t.elapsed() one can use >>> t = Timer() >>> with ContextTimer(t): ... x = sum(range(1000)) >>> elapsed = t.elapsed() """ def __init__( self, timer: Optional[Timer] = None, label: Optional[str] = None, action: str = "StartStop", ): """ Args: timer: Timer object to be used as a context manager. If ``None``, a new class:`Timer` object is constructed. label: Label of the timer to be used. If it is ``None``, start the default timer. action: Actions to be taken on context entry and exit. If the value is 'StartStop', start the timer on entry and stop on exit; if it is 'StopStart', stop the timer on entry and start it on exit. 
""" if action not in ["StartStop", "StopStart"]: raise ValueError(f"Unrecognized action {action}.") if timer is None: self.timer = Timer() else: self.timer = timer self.label = label self.action = action def __enter__(self): """Start the timer and return this ContextTimer instance.""" if self.action == "StartStop": self.timer.start(self.label) else: self.timer.stop(self.label) return self def __exit__(self, exc_type, exc_value, traceback): """Stop the timer and return ``True`` if no exception was raised within the `with` block, otherwise return ``False``. """ if self.action == "StartStop": self.timer.stop(self.label) else: self.timer.start(self.label) return not exc_type def elapsed(self, total: bool = True) -> float: """Return the elapsed time for the timer. Args: total: If ``True`` return the total elapsed time since the first call of :meth:`start` for the selected timer, otherwise return the elapsed time since the most recent call of :meth:`start` for which there has not been a corresponding call to :meth:`stop`. Returns: Elapsed time. 
""" return self.timer.elapsed(self.label, total=total) ================================================ FILE: setup.py ================================================ """SCICO package configuration.""" import importlib.util import os import os.path import site import sys from setuptools import find_namespace_packages, setup # Import module scico._version without executing __init__.py spec = importlib.util.spec_from_file_location("_version", os.path.join("scico", "_version.py")) module = importlib.util.module_from_spec(spec) sys.modules["_version"] = module spec.loader.exec_module(module) from _version import package_version name = "scico" version = package_version() # Add argument exclude=["test", "test.*"] to exclude test subpackage packages = find_namespace_packages(where="scico") packages = ["scico"] + [f"scico.{m}" for m in packages] longdesc = """ SCICO is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization routines that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. SCICO is built on top of JAX, which provides features such as automatic gradient calculation and GPU acceleration. 
""" # Set install_requires from requirements.txt file with open("requirements.txt") as f: lines = f.readlines() install_requires = [line.strip() for line in lines] python_requires = ">=3.8" tests_require = ["pytest", "pytest-runner"] extra_require_files = [ "dev_requirements.txt", os.path.join("docs", "docs_requirements.txt"), os.path.join("examples", "examples_requirements.txt"), os.path.join("examples", "notebooks_requirements.txt"), ] extras_require = {"tests": tests_require} for require_file in extra_require_files: extras_label = os.path.basename(require_file).partition("_")[0] with open(require_file) as f: lines = f.readlines() extras_require[extras_label] = [line.strip() for line in lines if line[0:2] != "-r"] # PEP517 workaround, see https://www.scivision.dev/python-pip-devel-user-install/ site.ENABLE_USER_SITE = True setup( name=name, version=version, description="Scientific Computational Imaging COde: A Python " "package for scientific imaging problems", long_description=longdesc, keywords=[ "Computational Imaging", "Scientific Imaging", "Inverse Problems", "Plug-and-Play Priors", "Total Variation", "Optimization", "ADMM", "Linearized ADMM", "PDHG", "PGM", ], platforms="Any", license="BSD-3-Clause", url="https://github.com/lanl/scico", author="SCICO Developers", author_email="brendt@ieee.org", # Temporary packages=packages, package_data={"scico": ["data/*/*.png", "data/*/*.npz"]}, include_package_data=True, python_requires=python_requires, install_requires=install_requires, extras_require=extras_require, classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Education", "Intended Audience :: Science/Research", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Software Development :: Libraries :: Python Modules", ], zip_safe=False, )