Repository: jcmgray/cotengra
Branch: main
Commit: e0f164f452e6
Files: 118
Total size: 8.3 MB
Directory structure:
gitextract_lc2zkv6z/
├── .codecov.yml
├── .gitattributes
├── .github/
│ ├── dependabot.yml
│ └── workflows/
│ ├── pypi-release.yml
│ └── test.yml
├── .gitignore
├── .readthedocs.yml
├── LICENSE.md
├── MANIFEST.in
├── README.md
├── cotengra/
│ ├── __init__.py
│ ├── contract.py
│ ├── core.py
│ ├── core_multi.py
│ ├── experimental/
│ │ ├── __init__.py
│ │ ├── hyper_de.py
│ │ ├── hyper_pe.py
│ │ ├── hyper_pymoo.py
│ │ ├── hyper_scipy.py
│ │ ├── hyper_smac.py
│ │ ├── multi.ipynb
│ │ ├── path_compressed_branchbound.py
│ │ ├── path_compressed_mcts.py
│ │ └── scoring.py
│ ├── hypergraph.py
│ ├── hyperoptimizers/
│ │ ├── __init__.py
│ │ ├── _param_mapping.py
│ │ ├── hyper.py
│ │ ├── hyper_cmaes.py
│ │ ├── hyper_es.py
│ │ ├── hyper_neldermead.py
│ │ ├── hyper_nevergrad.py
│ │ ├── hyper_optuna.py
│ │ ├── hyper_random.py
│ │ ├── hyper_sbplx.py
│ │ └── hyper_skopt.py
│ ├── interface.py
│ ├── nodeops.py
│ ├── oe.py
│ ├── parallel.py
│ ├── pathfinders/
│ │ ├── __init__.py
│ │ ├── kahypar_profiles/
│ │ │ ├── cut_kKaHyPar_sea20.ini
│ │ │ ├── cut_rKaHyPar_sea20.ini
│ │ │ ├── km1_kKaHyPar_sea20.ini
│ │ │ ├── km1_rKaHyPar_sea20.ini
│ │ │ └── old/
│ │ │ ├── cut_kKaHyPar_sea20.ini
│ │ │ ├── cut_rKaHyPar_sea20.ini
│ │ │ ├── km1_kKaHyPar_sea20.ini
│ │ │ └── km1_rKaHyPar_sea20.ini
│ │ ├── path_basic.py
│ │ ├── path_compressed.py
│ │ ├── path_compressed_greedy.py
│ │ ├── path_edgesort.py
│ │ ├── path_flowcutter.py
│ │ ├── path_greedy.py
│ │ ├── path_igraph.py
│ │ ├── path_kahypar.py
│ │ ├── path_labels.py
│ │ ├── path_quickbb.py
│ │ ├── path_random.py
│ │ ├── path_simulated_annealing.py
│ │ └── treedecomp.py
│ ├── plot.py
│ ├── presets.py
│ ├── reusable.py
│ ├── schematic.py
│ ├── scoring.py
│ ├── slicer.py
│ └── utils.py
├── docs/
│ ├── Makefile
│ ├── _pygments/
│ │ ├── _pygments_dark.py
│ │ └── _pygments_light.py
│ ├── _static/
│ │ └── my-styles.css
│ ├── advanced.ipynb
│ ├── basics.ipynb
│ ├── changelog.md
│ ├── conf.py
│ ├── contraction.ipynb
│ ├── examples/
│ │ ├── ex_compressed_contraction.ipynb
│ │ ├── ex_large_output_lazy.ipynb
│ │ └── ex_trace_contraction_to_matmuls.ipynb
│ ├── high-level-interface.ipynb
│ ├── index.md
│ ├── index_examples.md
│ ├── installation.md
│ ├── make.bat
│ ├── trees.ipynb
│ └── visualization.ipynb
├── examples/
│ ├── Example - Reproducing 2005.06787.ipynb
│ ├── Example - Reproducing 2103-03074.ipynb
│ ├── Quantum Circuit Example Old.ipynb
│ ├── Quantum Circuit Example.ipynb
│ ├── benchmarks/
│ │ ├── cubic_6x6x10.json
│ │ ├── mps_mpo_L100_chi64_D5.json
│ │ ├── peps_cluster_r2_D10_a.json
│ │ ├── qucirc_rrzz_n56_d13.json
│ │ ├── rand_50_5_a.json
│ │ ├── randreg_200_3_a.json
│ │ ├── rtree_100_a.json
│ │ └── sycamore_n53_m20_s0_e0_pABCDCDAB.json
│ ├── circuit_n53_m10_s0_e0_pABCDCDAB.qsim
│ ├── circuit_n53_m12_s0_e0_pABCDCDAB.qsim
│ ├── circuit_n53_m20_s0_e0_pABCDCDAB.qsim
│ ├── ex_jax.py
│ ├── ex_mpi_executor.py
│ └── ex_mpi_spmd.py
├── pyproject.toml
└── tests/
├── __init__.py
├── test_backends.py
├── test_compressed.py
├── test_compute.py
├── test_hypergraph.py
├── test_interface.py
├── test_optimizers.py
├── test_parallel.py
├── test_paths_basic.py
├── test_slicer.py
└── test_tree.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .codecov.yml
================================================
# Codecov configuration.
# Booleans written canonically (`true`/`false`) — `yes`/`off` are YAML 1.1
# truthy forms that differ between parsers.
codecov:
  require_ci_to_pass: true

coverage:
  # acceptable coverage band shown on badges/reports
  range: "50..100"
  status:
    project:
      default:
        # report project coverage but never fail the check
        informational: true
    patch:
      default:
        # report patch coverage but never fail the check
        informational: true
    changes: false

# disable PR comments from the codecov bot
comment: false
================================================
FILE: .gitattributes
================================================
# Auto detect text files and perform LF normalization
* text=auto
# Standard to msysgit
*.doc diff=astextplain
*.DOC diff=astextplain
*.docx diff=astextplain
*.DOCX diff=astextplain
*.dot diff=astextplain
*.DOT diff=astextplain
*.pdf diff=astextplain
*.PDF diff=astextplain
*.rtf diff=astextplain
*.RTF diff=astextplain
# include the version number in git archive
cotengra/_version.py export-subst
# make cotengra appear as a python project on github
*.ipynb linguist-language=Python
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
================================================
FILE: .github/dependabot.yml
================================================
# Dependabot version-update configuration.
# Reference:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
version: 2
updates:
  # keep the GitHub Actions used in workflows up to date
  - package-ecosystem: "github-actions"
    # workflows are discovered relative to the repository root
    directory: "/"
    schedule:
      # check for action updates once a day
      interval: "daily"
================================================
FILE: .github/workflows/pypi-release.yml
================================================
name: Build and Upload cotengra to PyPI

# run on published releases (final PyPI upload) and on version tags
# (TestPyPI dry-run upload)
on:
  release:
    types:
      - published
  push:
    tags:
      - 'v*'

jobs:
  # build the sdist and wheel once, then share them via an artifact
  build-artifacts:
    runs-on: ubuntu-latest
    if: github.repository == 'jcmgray/cotengra'
    steps:
      - uses: actions/checkout@v6
        with:
          # full history so the vcs-derived version is correct
          fetch-depth: 0
      - uses: actions/setup-python@v6
        name: Install Python
        with:
          python-version: "3.12"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install build twine
      - name: Build tarball and wheels
        run: |
          git clean -xdf
          git restore -SW .
          python -m build
      - name: Check built artifacts
        run: |
          python -m twine check --strict dist/*
          pwd
          if [ -f dist/cotengra-0.0.0.tar.gz ]; then
            echo "❌ INVALID VERSION NUMBER"
            exit 1
          else
            echo "✅ Looks good"
          fi
      - uses: actions/upload-artifact@v7
        with:
          name: releases
          path: dist

  # smoke-test that the built wheel actually installs
  test-built-dist:
    needs: build-artifacts
    runs-on: ubuntu-latest
    steps:
      - uses: actions/setup-python@v6
        name: Install Python
        with:
          python-version: "3.12"
      - uses: actions/download-artifact@v8
        with:
          name: releases
          path: dist
      - name: List contents of built dist
        run: |
          ls -ltrh
          ls -ltrh dist
      - name: Verify the built dist/wheel is valid
        if: github.event_name == 'push'
        run: |
          python -m pip install --upgrade pip
          python -m pip install dist/cotengra*.whl

  # tag pushes go to TestPyPI only
  upload-to-test-pypi:
    needs: test-built-dist
    if: github.event_name == 'push'
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://test.pypi.org/p/cotengra
    permissions:
      # required for PyPI trusted publishing (OIDC)
      id-token: write
    steps:
      - uses: actions/download-artifact@v8
        with:
          name: releases
          path: dist
      - name: Publish package to TestPyPI
        if: github.event_name == 'push'
        uses: pypa/gh-action-pypi-publish@v1.14.0
        with:
          repository-url: https://test.pypi.org/legacy/
          verbose: true

  # published releases go to the real PyPI
  upload-to-pypi:
    needs: test-built-dist
    if: github.event_name == 'release'
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/cotengra
    permissions:
      # required for PyPI trusted publishing (OIDC)
      id-token: write
    steps:
      - uses: actions/download-artifact@v8
        with:
          name: releases
          path: dist
      - name: Publish package to PyPI
        uses: pypa/gh-action-pypi-publish@v1.14.0
        with:
          verbose: true
================================================
FILE: .github/workflows/test.yml
================================================
name: Tests

on:
  workflow_dispatch:
  push:
  pull_request:

defaults:
  run:
    # login shell so pixi-activated environments are picked up
    shell: bash -l {0}

jobs:
  run-tests:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        # each entry pairs a pixi environment with the pixi task it runs
        include:
          - os: ubuntu-latest
            pixi_environment: testminimal
            pixi_task: test
          - os: ubuntu-latest
            pixi_environment: testpyold
            pixi_task: test
          - os: ubuntu-latest
            pixi_environment: testpynew
            pixi_task: test
          - os: macos-latest
            pixi_environment: testpymid
            pixi_task: test
          - os: windows-latest
            pixi_environment: testpymid
            pixi_task: test
          - os: ubuntu-latest
            pixi_environment: testtorch
            pixi_task: test-backends
          - os: ubuntu-latest
            pixi_environment: testjax
            pixi_task: test-backends
          - os: ubuntu-latest
            pixi_environment: testtensorflow
            pixi_task: test-backends
    steps:
      - uses: actions/checkout@v6
      - uses: prefix-dev/setup-pixi@v0.9.5
        with:
          environments: ${{ matrix.pixi_environment }}
      - name: Test with pytest
        run: pixi run -e ${{ matrix.pixi_environment }} ${{ matrix.pixi_task }}
      - name: Report to codecov
        uses: codecov/codecov-action@v6
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
.pytest_cache
.vscode
**/.ipynb_checkpoints/
# Added by cargo
/target
Cargo.lock
cotengra/_version.py
experiments
# pixi environments
.pixi/*
!.pixi/config.toml
================================================
FILE: .readthedocs.yml
================================================
# Read the Docs build configuration — docs are built via pixi rather than
# the default RTD pip install.
version: 2

sphinx:
  configuration: docs/conf.py

build:
  os: "ubuntu-24.04"
  tools:
    python: "latest"
  jobs:
    # install pixi itself via asdf
    create_environment:
      - asdf plugin add pixi
      - asdf install pixi latest
      - asdf global pixi latest
    # resolve the docs environment
    install:
      - pixi install -e docs
    # run the sphinx build through the pixi task
    build:
      html:
        - pixi run readthedocs

formats: []
================================================
FILE: LICENSE.md
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
# kahypar .ini profiles live under the pathfinders subpackage — the old
# `cotengra/kahypar_profiles/` paths matched nothing in the current layout
include cotengra/pathfinders/kahypar_profiles/*.ini
include cotengra/pathfinders/kahypar_profiles/old/*.ini
include LICENSE.md
include README.md
graft tests
================================================
FILE: README.md
================================================
[](https://github.com/jcmgray/cotengra/actions/workflows/test.yml)
[](https://codecov.io/gh/jcmgray/cotengra)
[](https://cotengra.readthedocs.io)
[](https://pypi.org/project/cotengra/)
[](https://anaconda.org/conda-forge/cotengra)
[](https://pixi.sh)
`cotengra` is a python library for contracting tensor networks or einsum
expressions involving large numbers of tensors - the main docs can be found
at [cotengra.readthedocs.io](https://cotengra.readthedocs.io/).
Some of the key features of `cotengra` include:
* drop-in ``einsum`` and ``ncon`` replacement
* an explicit **contraction tree** object that can be flexibly built, modified and visualized
* a **'hyper optimizer'** that samples trees while tuning the generating meta-parameters
* **dynamic slicing** for massive memory savings and parallelism
* **simulated annealing** as an alternative optimizing and slicing strategy
* support for **hyper** edge tensor networks and thus arbitrary einsum equations
* **paths** that can be supplied to [`numpy.einsum`](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html), [`opt_einsum`](https://dgasmith.github.io/opt_einsum/), [`quimb`](https://quimb.readthedocs.io/en/latest/) among others
* **performing contractions** with tensors from many libraries via [`autoray`](https://github.com/jcmgray/autoray),
even if they don't provide `einsum` or `tensordot` but do have (batch) matrix
multiplication
================================================
FILE: cotengra/__init__.py
================================================
"""Hyper optimized contraction trees for large tensor networks and einsums."""
# Resolve the package version with a three-step fallback chain.
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _version

try:
    # 1. preferred: the installed distribution's metadata
    __version__ = _version("cotengra")
except _PackageNotFoundError:
    try:
        # fallback for source trees where hatch-vcs has generated _version.py.
        from ._version import version as __version__
    except ImportError:
        # 3. last resort: placeholder when no version info is available
        __version__ = "0.0.0+unknown"
import functools
import warnings
from . import utils
from .core import (
ContractionTree,
ContractionTreeCompressed,
)
from .core_multi import (
ContractionTreeMulti,
)
from .hypergraph import (
HyperGraph,
get_hypergraph,
)
from .hyperoptimizers import (
hyper_cmaes,
hyper_es,
hyper_neldermead,
hyper_nevergrad,
hyper_optuna,
hyper_random,
hyper_sbplx,
hyper_skopt,
)
from .hyperoptimizers.hyper import (
HyperCompressedOptimizer,
HyperMultiOptimizer,
HyperOptimizer,
ReusableHyperCompressedOptimizer,
ReusableHyperOptimizer,
get_hyper_space,
list_hyper_functions,
)
from .interface import (
array_contract,
array_contract_expression,
array_contract_path,
array_contract_tree,
einsum,
einsum_expression,
einsum_tree,
ncon,
register_preset,
)
from .oe import PathOptimizer
from .pathfinders import (
path_basic,
path_compressed_greedy,
path_greedy,
path_igraph,
path_kahypar,
path_labels,
)
from .pathfinders.path_basic import (
GreedyOptimizer,
OptimalOptimizer,
RandomGreedyOptimizer,
ReusableRandomGreedyOptimizer,
edge_path_to_linear,
edge_path_to_ssa,
linear_to_ssa,
ssa_to_linear,
)
from .pathfinders.path_flowcutter import (
FlowCutterOptimizer,
optimize_flowcutter,
)
from .pathfinders.path_quickbb import QuickBBOptimizer, optimize_quickbb
from .pathfinders.path_random import RandomOptimizer
from .plot import (
plot_contractions,
plot_contractions_alt,
plot_scatter,
plot_scatter_alt,
plot_slicings,
plot_slicings_alt,
plot_tree,
plot_tree_ring,
plot_tree_span,
plot_tree_tent,
plot_trials,
plot_trials_alt,
)
from .presets import (
AutoHQOptimizer,
AutoOptimizer,
auto_hq_optimize,
auto_optimize,
greedy_optimize,
optimal_optimize,
optimal_outer_optimize,
)
from .reusable import (
hash_contraction,
)
from .slicer import SliceFinder
from .utils import (
get_symbol,
get_symbol_map,
)
# convenience aliases / preconfigured optimizers
UniformOptimizer = functools.partial(HyperOptimizer, optlib="random")
"""Does no gaussian process tuning by default, just randomly samples - requires
no optimization library.
"""

contract_expression = einsum_expression
"""Alias for :func:`cotengra.einsum_expression`."""

contract = einsum
"""Alias for :func:`cotengra.einsum`."""
# Public API exported by ``from cotengra import *``.
# NOTE(review): removed "QuasiRandOptimizer" — it is neither imported nor
# defined anywhere in this module, so including it broke star-imports with
# an AttributeError.
__all__ = (
    "array_contract_expression",
    "array_contract_path",
    "array_contract_tree",
    "array_contract",
    "auto_hq_optimize",
    "auto_optimize",
    "AutoHQOptimizer",
    "AutoOptimizer",
    "contract_expression",
    "contract",
    "ContractionTree",
    "ContractionTreeCompressed",
    "ContractionTreeMulti",
    "edge_path_to_linear",
    "edge_path_to_ssa",
    "einsum_expression",
    "einsum_tree",
    "einsum",
    "FlowCutterOptimizer",
    "get_hyper_space",
    "get_hypergraph",
    "get_symbol_map",
    "get_symbol",
    "greedy_optimize",
    "GreedyOptimizer",
    "hash_contraction",
    "hyper_cmaes",
    "hyper_nevergrad",
    "hyper_neldermead",
    "hyper_optimize",
    "hyper_optuna",
    "hyper_random",
    "hyper_es",
    "hyper_skopt",
    "hyper_sbplx",
    "HyperCompressedOptimizer",
    "HyperGraph",
    "HyperMultiOptimizer",
    "HyperOptimizer",
    "linear_to_ssa",
    "list_hyper_functions",
    "ncon",
    "optimal_optimize",
    "optimal_outer_optimize",
    "OptimalOptimizer",
    "optimize_flowcutter",
    "optimize_quickbb",
    "path_basic",
    "path_compressed_greedy",
    "path_greedy",
    "path_igraph",
    "path_kahypar",
    "path_labels",
    "PathOptimizer",
    "plot_contractions_alt",
    "plot_contractions",
    "plot_scatter_alt",
    "plot_scatter",
    "plot_slicings_alt",
    "plot_slicings",
    "plot_tree_ring",
    "plot_tree_span",
    "plot_tree_tent",
    "plot_tree",
    "plot_trials_alt",
    "plot_trials",
    "QuickBBOptimizer",
    "RandomGreedyOptimizer",
    "RandomOptimizer",
    "register_preset",
    "ReusableHyperCompressedOptimizer",
    "ReusableHyperOptimizer",
    "ReusableRandomGreedyOptimizer",
    "SliceFinder",
    "ssa_to_linear",
    "UniformOptimizer",
    "utils",
)
# add some presets
def hyper_optimize(
    inputs,
    output,
    size_dict,
    memory_limit=None,
    get="path",
    **opts,
):
    """Find a contraction path or tree using a fresh ``HyperOptimizer``.

    Parameters
    ----------
    inputs : sequence of sequence of hashable
        The indices of each input tensor.
    output : sequence of hashable
        The output indices.
    size_dict : dict[hashable, int]
        The size of each index.
    memory_limit : None, optional
        Not supported — accepted only for interface compatibility, a warning
        is emitted if supplied.
    get : {"path", "tree"}, optional
        Whether to return a contraction path or the contraction tree.
    opts
        Supplied to ``HyperOptimizer``.
    """
    if memory_limit is not None:
        warnings.warn(
            "`memory_limit` is not supported in hyper_optimize, ignoring."
        )
    # validate `get` *before* constructing the optimizer, so bad arguments
    # fail fast rather than after potentially expensive setup
    if get not in ("path", "tree"):
        raise ValueError(f"Unknown get option {get}")
    optimizer = HyperOptimizer(**opts)
    if get == "path":
        return optimizer(inputs, output, size_dict)
    # get == "tree"
    return optimizer.search(inputs, output, size_dict)
def hyper_compressed_optimize(
    inputs,
    output,
    size_dict,
    memory_limit=None,
    get="path",
    **opts,
):
    """Find a compressed contraction path or tree using a fresh
    ``HyperCompressedOptimizer``.

    Parameters
    ----------
    inputs : sequence of sequence of hashable
        The indices of each input tensor.
    output : sequence of hashable
        The output indices.
    size_dict : dict[hashable, int]
        The size of each index.
    memory_limit : None, optional
        Not supported — accepted only for interface compatibility with the
        sibling presets (``hyper_optimize``, ``random_greedy_optimize``), a
        warning is emitted if supplied. Previously it would have been
        swallowed into ``opts`` and passed to the optimizer constructor.
    get : {"path", "tree"}, optional
        Whether to return a contraction path or the contraction tree.
    opts
        Supplied to ``HyperCompressedOptimizer``.
    """
    if memory_limit is not None:
        warnings.warn(
            "`memory_limit` is not supported in "
            "hyper_compressed_optimize, ignoring."
        )
    # validate `get` before constructing the optimizer — fail fast
    if get not in ("path", "tree"):
        raise ValueError(f"Unknown get option {get}")
    optimizer = HyperCompressedOptimizer(**opts)
    if get == "path":
        return optimizer(inputs, output, size_dict)
    # get == "tree"
    return optimizer.search(inputs, output, size_dict)
def random_greedy_optimize(
    inputs, output, size_dict, memory_limit=None, **opts
):
    """Find a contraction path using a fresh ``RandomGreedyOptimizer``.

    Parameters
    ----------
    inputs : sequence of sequence of hashable
        The indices of each input tensor.
    output : sequence of hashable
        The output indices.
    size_dict : dict[hashable, int]
        The size of each index.
    memory_limit : None, optional
        Not supported — a warning is emitted and it is ignored.
    opts
        Supplied to ``RandomGreedyOptimizer``.
    """
    if memory_limit is not None:
        warnings.warn(
            "`memory_limit` is not supported in "
            "random_greedy_optimize, ignoring."
        )
    # construct and immediately invoke the optimizer on the problem
    return RandomGreedyOptimizer(**opts)(inputs, output, size_dict)
# Register the named optimizer presets with the interface layer. Wrapped in
# try/except because re-registering (e.g. on module reload) raises KeyError.
try:
    # hyper-optimizer presets: sample many trees, tuning meta-parameters
    register_preset(
        "hyper",
        hyper_optimize,
        optimizer_tree=functools.partial(hyper_optimize, get="tree"),
    )
    register_preset(
        "hyper-256",
        functools.partial(hyper_optimize, max_repeats=256),
        optimizer_tree=functools.partial(
            hyper_optimize, max_repeats=256, get="tree"
        ),
    )
    # single-method variants of the hyper optimizer
    register_preset(
        "hyper-greedy",
        functools.partial(hyper_optimize, methods=["greedy"]),
        optimizer_tree=functools.partial(
            hyper_optimize, methods=["greedy"], get="tree"
        ),
    )
    register_preset(
        "hyper-labels",
        functools.partial(hyper_optimize, methods=["labels"]),
        optimizer_tree=functools.partial(
            hyper_optimize, methods=["labels"], get="tree"
        ),
    )
    register_preset(
        "hyper-kahypar",
        functools.partial(hyper_optimize, methods=["kahypar"]),
        optimizer_tree=functools.partial(
            hyper_optimize, methods=["kahypar"], get="tree"
        ),
    )
    register_preset(
        "hyper-balanced",
        functools.partial(
            hyper_optimize, methods=["kahypar-balanced"], max_repeats=16
        ),
        optimizer_tree=functools.partial(
            hyper_optimize,
            methods=["kahypar-balanced"],
            max_repeats=16,
            get="tree",
        ),
    )
    # compressed-contraction preset
    register_preset(
        "hyper-compressed",
        hyper_compressed_optimize,
        optimizer_tree=functools.partial(
            hyper_compressed_optimize,
            get="tree",
        ),
        compressed=True,
    )
    register_preset(
        "hyper-spinglass",
        functools.partial(hyper_optimize, methods=["spinglass"]),
    )
    register_preset(
        "hyper-betweenness",
        functools.partial(hyper_optimize, methods=["betweenness"]),
    )
    # plain random-greedy presets
    register_preset(
        "random-greedy",
        random_greedy_optimize,
    )
    register_preset(
        "random-greedy-128",
        functools.partial(random_greedy_optimize, max_repeats=128),
    )
    # flowcutter presets, by maximum run time in seconds
    register_preset(
        "flowcutter-2",
        functools.partial(optimize_flowcutter, max_time=2),
    )
    register_preset(
        "flowcutter-10",
        functools.partial(optimize_flowcutter, max_time=10),
    )
    register_preset(
        "flowcutter-60",
        functools.partial(optimize_flowcutter, max_time=60),
    )
    # quickbb presets, by maximum run time in seconds
    register_preset(
        "quickbb-2",
        functools.partial(optimize_quickbb, max_time=2),
    )
    register_preset(
        "quickbb-10",
        functools.partial(optimize_quickbb, max_time=10),
    )
    register_preset(
        "quickbb-60",
        functools.partial(optimize_quickbb, max_time=60),
    )
    # compressed greedy presets (path fn and trial fn passed positionally)
    register_preset(
        "greedy-compressed",
        path_compressed_greedy.greedy_compressed,
        path_compressed_greedy.trial_greedy_compressed,
        compressed=True,
    )
    register_preset(
        "greedy-span",
        path_compressed_greedy.greedy_span,
        path_compressed_greedy.trial_greedy_span,
        compressed=True,
    )
except KeyError:
    # KeyError: if reloading cotengra e.g. library entries already registered
    pass
================================================
FILE: cotengra/contract.py
================================================
"""Functionality relating to actually contracting."""
import contextlib
import functools
import itertools
import operator
from autoray import do, get_namespace, infer_backend_multi, shape
# module-level default for which contraction implementation to use
DEFAULT_IMPLEMENTATION = "auto"


def set_default_implementation(impl):
    """Set the global default contraction implementation."""
    global DEFAULT_IMPLEMENTATION
    DEFAULT_IMPLEMENTATION = impl


def get_default_implementation():
    """Get the current global default contraction implementation."""
    return DEFAULT_IMPLEMENTATION


@contextlib.contextmanager
def default_implementation(impl):
    """Context manager for temporarily setting the default implementation."""
    global DEFAULT_IMPLEMENTATION
    prev = DEFAULT_IMPLEMENTATION
    DEFAULT_IMPLEMENTATION = impl
    try:
        yield
    finally:
        # always restore, even if the body raised
        DEFAULT_IMPLEMENTATION = prev
@functools.lru_cache(2**12)
def _sanitize_equation(eq):
"""Get the input and output indices of an equation, computing the output
implicitly as the sorted sequence of every index that appears exactly once
if it is not provided.
"""
# remove spaces
eq = eq.replace(" ", "")
if "..." in eq:
raise NotImplementedError("Ellipsis not supported.")
if "->" not in eq:
lhs = eq
tmp_subscripts = lhs.replace(",", "")
out = "".join(
# sorted sequence of indices
s
for s in sorted(set(tmp_subscripts))
# that appear exactly once
if tmp_subscripts.count(s) == 1
)
else:
lhs, out = eq.split("->")
return lhs, out
@functools.lru_cache(2**12)
def _parse_einsum_single(eq, shape):
    """Cached parsing of a single term einsum equation into the necessary
    sequence of arguments for axes diagonals, sums, and transposes.

    Returns ``(diag_sels, sum_axes, perm)``, each ``None`` when that step
    is not required.
    """
    lhs, out = _sanitize_equation(eq)

    # classify indices: repeated -> diagonal, absent from output -> summed
    to_diag = []
    to_sum = []
    seen = set()
    for ix in lhs:
        if ix in to_diag:
            continue
        if ix in seen:
            to_diag.append(ix)
            continue
        seen.add(ix)
        if ix not in out:
            to_sum.append(ix)

    # 1. diagonal (repeated index) reductions via advanced indexing
    if to_diag:
        diag_sels = []
        sizes = dict(zip(lhs, shape))
        while to_diag:
            ixd = to_diag.pop()
            dinds = tuple(range(sizes[ixd]))
            # advanced indexing selector picking the diagonal of `ixd`
            diag_sels.append(
                tuple(dinds if ix == ixd else slice(None) for ix in lhs)
            )
            # compute the index labels remaining after this diagonal
            contig = ixd * lhs.count(ixd)
            if contig in lhs:
                # contiguous repeats: the surviving axis stays in place
                lhs = lhs.replace(contig, ixd)
            else:
                # non-contiguous: advanced indexing moves it to the front
                lhs = ixd + lhs.replace(ixd, "")
    else:
        diag_sels = None

    # 2. sum reductions over indices not appearing in the output
    if to_sum:
        sum_axes = tuple(map(lhs.index, to_sum))
        for ix in to_sum:
            lhs = lhs.replace(ix, "")
    else:
        sum_axes = None

    # 3. transposition into the requested output order
    perm = None if lhs == out else tuple(lhs.index(ix) for ix in out)

    return diag_sels, sum_axes, perm
def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
"""If there are no contracted indices, then we can directly transpose and
insert singleton dimensions into ``a`` and ``b`` such that (broadcast)
elementwise multiplication performs the einsum.
No need to cache this as it is within the cached
``_parse_eq_to_batch_matmul``.
"""
desired_a = ""
desired_b = ""
new_shape_a = []
new_shape_b = []
for ix in out:
if ix in a_term:
desired_a += ix
new_shape_a.append(shape_a[a_term.index(ix)])
else:
new_shape_a.append(1)
if ix in b_term:
desired_b += ix
new_shape_b.append(shape_b[b_term.index(ix)])
else:
new_shape_b.append(1)
if desired_a != a_term:
eq_a = f"{a_term}->{desired_a}"
else:
eq_a = None
if desired_b != b_term:
eq_b = f"{b_term}->{desired_b}"
else:
eq_b = None
return (
eq_a,
eq_b,
new_shape_a,
new_shape_b,
None, # new_shape_ab, not needed since not fusing
None, # perm_ab, not needed as we transpose a and b first
True, # pure_multiplication=True
)
@functools.lru_cache(2**12)
def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
    """Cached parsing of a two term einsum equation into the necessary
    sequence of arguments for contraction via batched matrix multiplication.
    The steps we need to specify are:

    1. Remove repeated and trivial indices from the left and right terms,
       and transpose them, done as a single einsum.
    2. Fuse the remaining indices so we have two 3D tensors.
    3. Perform the batched matrix multiplication.
    4. Unfuse the output to get the desired final index order.

    Parameters
    ----------
    eq : str
        Two term einsum equation with explicit output, e.g. ``"ab,bc->ac"``.
    shape_a, shape_b : tuple[int]
        The shapes of the two operands.

    Returns
    -------
    eq_a, eq_b : None, tuple or str
        Per operand preparation: ``None`` (nothing to do), a tuple of axes
        (pure transposition), or a single term einsum equation (diagonals,
        sums and transposition).
    new_shape_a, new_shape_b : None or tuple[int]
        Fusing reshape for each prepared operand, if needed.
    new_shape_ab : None or tuple[int]
        Unfusing reshape for the matmul output, if needed.
    perm_ab : None or tuple[int]
        Final permutation of the output, if needed.
    pure_multiplication : bool
        If ``True`` the contraction has no contracted indices and is
        performed as a broadcast multiply instead of a matmul.

    Raises
    ------
    ValueError
        If a term length doesn't match its shape, or an index has
        mismatched sizes.
    """
    lhs, out = eq.split("->")
    a_term, b_term = lhs.split(",")
    if len(a_term) != len(shape_a):
        raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
    if len(b_term) != len(shape_b):
        raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")
    sizes = {}
    singletons = set()
    # parse left term to unique indices with size > 1
    left = {}
    for ix, d in zip(a_term, shape_a):
        if d == 1:
            # everything (including broadcasting) works nicely if we simply
            # ignore such dimensions, but we do need to track if they appear
            # in the output and thus should be reintroduced later
            singletons.add(ix)
            continue
        if sizes.setdefault(ix, d) != d:
            # set and check size
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )
        left[ix] = True
    # parse right term to unique indices with size > 1
    right = {}
    for ix, d in zip(b_term, shape_b):
        # broadcast indices (size 1 on one input and size != 1
        # on the other) should not be treated as singletons
        if d == 1:
            if ix not in left:
                singletons.add(ix)
            continue
        singletons.discard(ix)
        if sizes.setdefault(ix, d) != d:
            # set and check size
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )
        right[ix] = True
    # now we classify the unique size > 1 indices only
    bat_inds = []  # appears on A, B, O
    con_inds = []  # appears on A, B, .
    a_keep = []  # appears on A, ., O
    b_keep = []  # appears on ., B, O
    # other indices (appearing on A or B only) will
    # be summed or traced out prior to the matmul
    for ix in left:
        if right.pop(ix, False):
            # shared index: batch dimension if kept in output, else it is
            # contracted over by the matmul
            if ix in out:
                bat_inds.append(ix)
            else:
                con_inds.append(ix)
        elif ix in out:
            a_keep.append(ix)
    # now only indices unique to right remain
    for ix in right:
        if ix in out:
            b_keep.append(ix)
    if not con_inds:
        # contraction is pure multiplication, prepare inputs differently
        return _parse_eq_to_pure_multiplication(
            a_term, shape_a, b_term, shape_b, out
        )
    # only need the size one indices that appear in the output
    singletons = [ix for ix in out if ix in singletons]
    # take diagonal, remove any trivial axes and transpose left
    desired_a = "".join((*bat_inds, *a_keep, *con_inds))
    if a_term != desired_a:
        if set(a_term) == set(desired_a):
            # only need to transpose, don't invoke einsum
            eq_a = tuple(a_term.index(ix) for ix in desired_a)
        else:
            eq_a = f"{a_term}->{desired_a}"
    else:
        eq_a = None
    # take diagonal, remove any trivial axes and transpose right
    desired_b = "".join((*bat_inds, *con_inds, *b_keep))
    if b_term != desired_b:
        if set(b_term) == set(desired_b):
            # only need to transpose, don't invoke einsum
            eq_b = tuple(b_term.index(ix) for ix in desired_b)
        else:
            eq_b = f"{b_term}->{desired_b}"
    else:
        eq_b = None
    # then we want to reshape (fuse) into (batch, left, shared) matmul form
    if bat_inds:
        lgroups = (bat_inds, a_keep, con_inds)
        rgroups = (bat_inds, con_inds, b_keep)
        ogroups = (bat_inds, a_keep, b_keep)
    else:
        # avoid size 1 batch dimension if no batch indices
        lgroups = (a_keep, con_inds)
        rgroups = (con_inds, b_keep)
        ogroups = (a_keep, b_keep)
    if any(len(group) != 1 for group in lgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_a = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in lgroups
        )
    else:
        new_shape_a = None
    if any(len(group) != 1 for group in rgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_b = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in rgroups
        )
    else:
        new_shape_b = None
    if any(len(group) != 1 for group in ogroups) or singletons:
        # unfuse the matmul output, reintroducing singleton dimensions first
        new_shape_ab = (1,) * len(singletons) + tuple(
            sizes[ix] for ix_group in ogroups for ix in ix_group
        )
    else:
        new_shape_ab = None
    # then we might need to permute the matmul produced output:
    out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
    if out_produced != out:
        perm_ab = tuple(out_produced.index(ix) for ix in out)
    else:
        perm_ab = None
    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        False,  # pure_multiplication=False
    )
def _einsum_single(eq, x, backend=None):
    """Einsum on a single tensor, via three steps: diagonal selection
    (via advanced indexing), axes summations, transposition. The logic for
    each is cached based on the equation and array shape, and each step is
    only performed if necessary.

    Parameters
    ----------
    eq : str
        Single term einsum equation, e.g. ``"aab->ba"``.
    x : array_like
        The tensor to operate on.
    backend : str, optional
        Backend to dispatch array operations with, inferred from ``x`` if
        not given.

    Returns
    -------
    array_like
    """
    try:
        # prefer the backend's own native einsum, if it has one
        return do("einsum", eq, x, like=backend)
    except ImportError:
        # no native einsum for this backend (presumably why autoray raises
        # ImportError here - TODO confirm) - compose it manually below
        pass
    diag_sels, sum_axes, perm = _parse_einsum_single(eq, shape(x))
    if diag_sels is not None:
        # diagonal reduction via advanced indexing
        # e.g ababbac->abc
        for selector in diag_sels:
            x = x[selector]
    if sum_axes is not None:
        # trivial removal of axes via summation
        # e.g. abc->c
        x = do("sum", x, sum_axes, like=backend)
    if perm is not None:
        # transpose to desired output
        # e.g. abc->cba
        x = do("transpose", x, perm, like=backend)
    return x
def _do_contraction_via_bmm(
    a,
    b,
    eq_a,
    eq_b,
    new_shape_a,
    new_shape_b,
    new_shape_ab,
    perm_ab,
    pure_multiplication,
    backend,
):
    """Execute a pairwise contraction given pre-parsed steps (e.g. from
    ``_parse_eq_to_batch_matmul``): prepare each operand (transpose or
    single term einsum, then fusing reshape), batch-matmul (or broadcast
    multiply) them, then unfuse and permute the output.

    Parameters
    ----------
    a, b : array_like
        The two operands.
    eq_a, eq_b : None, tuple or str
        Per operand preparation: ``None``, a tuple of axes (transpose
        only), or a single term einsum equation.
    new_shape_a, new_shape_b : None or tuple[int]
        Fusing reshape for each prepared operand, if needed.
    new_shape_ab : None or tuple[int]
        Unfusing reshape for the output, if needed.
    perm_ab : None or tuple[int]
        Final output permutation, if needed.
    pure_multiplication : bool
        If ``True``, there are no contracted indices and the prepared
        operands are combined by (broadcast) elementwise multiplication.
    backend : str or None
        Backend to dispatch array operations with.

    Returns
    -------
    array_like
    """
    # prepare left
    if eq_a is not None:
        if isinstance(eq_a, tuple):
            # only transpose
            a = do("transpose", a, eq_a, like=backend)
        else:
            # diagonals, sums, and transpose
            # n.b. propagate the explicit backend hint for consistency with
            # the other dispatches in this function
            a = _einsum_single(eq_a, a, backend=backend)
    if new_shape_a is not None:
        a = do("reshape", a, new_shape_a, like=backend)
    # prepare right
    if eq_b is not None:
        if isinstance(eq_b, tuple):
            # only transpose
            b = do("transpose", b, eq_b, like=backend)
        else:
            # diagonals, sums, and transpose
            b = _einsum_single(eq_b, b, backend=backend)
    if new_shape_b is not None:
        b = do("reshape", b, new_shape_b, like=backend)
    if pure_multiplication:
        # no contracted indices -> broadcast elementwise multiplication
        return do("multiply", a, b, like=backend)
    # do the contraction!
    ab = do("matmul", a, b, like=backend)
    # prepare the output
    if new_shape_ab is not None:
        ab = do("reshape", ab, new_shape_ab, like=backend)
    if perm_ab is not None:
        ab = do("transpose", ab, perm_ab, like=backend)
    return ab
def einsum(eq, a, b=None, *, backend=None):
    """Perform arbitrary single and pairwise einsums using only `matmul`,
    `transpose`, `reshape` and `sum`. The logic for each is cached based on
    the equation and array shape, and each step is only performed if
    necessary.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract.
    b : array_like, optional
        The second array to contract.
    backend : str, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    if b is None:
        # single term: diagonals, sums and transpose only
        return _einsum_single(eq, a, backend=backend)
    # two terms: parse (cached) into batched matmul steps, then execute
    bmm_args = _parse_eq_to_batch_matmul(eq, shape(a), shape(b))
    return _do_contraction_via_bmm(a, b, *bmm_args, backend)
def gen_nice_inds():
    """Generate the indices from [a-z, A-Z, reasonable unicode...]."""
    # lowercase then uppercase ascii letters
    yield from (chr(c) for c in range(ord("a"), ord("z") + 1))
    yield from (chr(c) for c in range(ord("A"), ord("Z") + 1))
    # then an endless supply of printable unicode from U+00C0 onwards
    yield from map(chr, itertools.count(192))
@functools.lru_cache(2**12)
def _parse_tensordot_axes_to_matmul(axes, shape_a, shape_b):
    """Parse a tensordot specification into the necessary sequence of
    arguments for contraction via matrix multiplication. This just converts
    ``axes`` into an ``einsum`` eq string then calls
    ``_parse_eq_to_batch_matmul``.

    Parameters
    ----------
    axes : int or tuple[tuple[int], tuple[int]]
        The (hashable) tensordot axes specification.
    shape_a, shape_b : tuple[int]
        The shapes of the two operands.

    Returns
    -------
    tuple
        Same as ``_parse_eq_to_batch_matmul``.

    Raises
    ------
    ValueError
        If the two axes sequences differ in length, or contracted
        dimensions don't match.
    """
    ndim_a = len(shape_a)
    ndim_b = len(shape_b)
    if isinstance(axes, int):
        # int form: last ``axes`` axes of a with first ``axes`` axes of b
        axes_a = tuple(range(ndim_a - axes, ndim_a))
        axes_b = tuple(range(axes))
    else:
        axes_a, axes_b = axes
        num_con = len(axes_a)
        if num_con != len(axes_b):
            raise ValueError(
                f"Axes should have the same length, got {axes_a} and {axes_b}."
            )
    possible_inds = gen_nice_inds()
    # assign a distinct symbol to every axis of ``a``
    inds_a = [next(possible_inds) for _ in range(ndim_a)]
    inds_b = []
    # output is a's uncontracted indices followed by b's uncontracted
    # indices, matching numpy.tensordot axis ordering
    inds_out = inds_a.copy()
    for axb in range(ndim_b):
        if axb not in axes_b:
            # right uncontracted index
            ind = next(possible_inds)
            inds_out.append(ind)
        else:
            # contracted index: reuse the symbol of the matching axis of a
            axa = axes_a[axes_b.index(axb)]
            # check that the shapes match
            if shape_a[axa] != shape_b[axb]:
                raise ValueError(
                    f"Dimension mismatch between axes {axa} of {shape_a} and "
                    f"{axb} of {shape_b}: {shape_a[axa]} != {shape_b[axb]}."
                )
            ind = inds_a[axa]
            inds_out.remove(ind)
        # every axis of ``b`` must contribute its symbol to the b term,
        # contracted or not (previously only contracted axes were appended,
        # making the b term too short whenever b kept any axes)
        inds_b.append(ind)
    eq = f"{''.join(inds_a)},{''.join(inds_b)}->{''.join(inds_out)}"
    return _parse_eq_to_batch_matmul(eq, shape_a, shape_b)
def tensordot(a, b, axes=2, *, backend=None):
    """Perform a tensordot using only `matmul`, `transpose`, `reshape`. The
    logic for each is cached based on the equation and array shape, and each
    step is only performed if necessary.

    Parameters
    ----------
    a, b : array_like
        The arrays to contract.
    axes : int or tuple of (sequence[int], sequence[int])
        The number of axes to contract, or the axes to contract. If an int,
        the last ``axes`` axes of ``a`` and the first ``axes`` axes of ``b``
        are contracted. If a tuple, the axes to contract for ``a`` and ``b``
        respectively.
    backend : str or None, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    try:
        # ensure hashable, explicit (axes_a, axes_b) form
        axes = tuple(map(int, axes[0])), tuple(map(int, axes[1]))
    except (TypeError, IndexError):
        # single int form, e.g. the default ``axes=2``: subscripting an int
        # raises TypeError (not IndexError), so both must be caught here for
        # the fallback to trigger
        axes = int(axes)
    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_tensordot_axes_to_matmul(axes, shape(a), shape(b))
    return _do_contraction_via_bmm(
        a,
        b,
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
        backend,
    )
def extract_contractions(
    tree,
    order=None,
    prefer_einsum=False,
):
    """Extract just the information needed to perform the contraction.

    Parameters
    ----------
    tree : ContractionTree
        The (complete) contraction tree to extract the contractions from.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.

    Returns
    -------
    contractions : tuple
        A tuple of tuples, each containing the information needed to
        perform a pairwise contraction. Each tuple contains:

        - ``p``: the parent node,
        - ``l``: the left child node,
        - ``r``: the right child node,
        - ``tdot``: whether to use ``tensordot`` or ``einsum``,
        - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
          i.e. ``axes`` or ``eq``,
        - ``perm``: the permutation required after the contraction, if
          any (only applies to tensordot).

        If both ``l`` and ``r`` are ``None``, then the operation is a single
        term simplification performed with ``einsum``.
    """
    if tree.N == 1:
        # trivial 'contraction', single input maps directly to output
        # possibly with reductions/transpose, setting l but not r flags this
        pi = 1
        li = 0
        ri = None
        tdot = False
        arg = tree.get_eq_sliced()
        perm = None
        return [(pi, li, ri, tdot, arg, perm)]
    contractions = []
    # for compactness we convert nodes to ssa indices
    ssas = {leaf: i for i, leaf in enumerate(tree.gen_leaves())}
    ssa = len(ssas)
    # pairwise contractions
    for p, l, r in tree.traverse(order=order):
        # children are consumed by their contraction -> pop
        li = ssas.pop(l)
        ri = ssas.pop(r)
        pi = ssas[p] = ssa
        ssa += 1
        if prefer_einsum or not tree.get_can_dot(p):
            tdot = False
            arg = tree.get_einsum_eq(p)
            perm = None
        else:
            tdot = True
            arg = tree.get_tensordot_axes(p)
            perm = tree.get_tensordot_perm(p)
        contractions.append((pi, li, ri, tdot, arg, perm))
    if tree.preprocessing:
        # inplace single term simplifications
        # n.b. these are populated lazily when the other information is
        # computed above, so we do it after
        pre_contractions = (
            (i, None, None, False, eq, None)
            for i, eq in tree.preprocessing.items()
        )
        return (*pre_contractions, *contractions)
    return tuple(contractions)
class Contractor:
    """Default cotengra network contractor.

    Parameters
    ----------
    contractions : tuple[tuple]
        The sequence of contractions to perform. Each contraction should be a
        tuple containing:

        - ``p``: the parent node,
        - ``l``: the left child node,
        - ``r``: the right child node,
        - ``tdot``: whether to use ``tensordot`` or ``einsum``,
        - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
          i.e. ``axes`` or ``eq``,
        - ``perm``: the permutation required after the contraction, if
          any (only applies to tensordot).

        e.g. built by calling ``extract_contractions(tree)``.
    strip_exponent : bool, optional
        If ``True``, eagerly strip the exponent (in log10) from
        intermediate tensors to control numerical problems from leaving the
        range of the datatype. This method then returns the scaled
        'mantissa' output array and the exponent separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, 0.0)``.
    implementation : str or tuple[callable, callable], optional
        Which pairwise ``einsum``/``tensordot`` functions to use: one of
        "auto", "cotengra", "autoray", "pytblis", or an explicit
        ``(einsum, tensordot)`` pair of callables.
    backend : str, optional
        What library to use for ``tensordot``, ``einsum`` and
        ``transpose``, it will be automatically inferred from the input
        arrays if not given.
    progbar : bool, optional
        Whether to show a progress bar.
    """

    __slots__ = (
        "contractions",
        "strip_exponent",
        "check_zero",
        "implementation",
        "backend",
        "progbar",
        "__weakref__",
    )

    def __init__(
        self,
        contractions,
        strip_exponent=False,
        check_zero=False,
        implementation="auto",
        backend=None,
        progbar=False,
    ):
        self.contractions = contractions
        self.strip_exponent = strip_exponent
        self.check_zero = check_zero
        self.implementation = implementation
        self.backend = backend
        self.progbar = progbar

    def __call__(self, *arrays, **kwargs):
        """Contract ``arrays`` using operations listed in ``contractions``.

        Parameters
        ----------
        arrays : sequence of array-like
            The arrays to contract.
        kwargs : dict
            Override the default settings for this contraction only.

        Returns
        -------
        output : array
            The contracted output, it will be scaled if
            ``strip_exponent==True``.
        exponent : float
            The exponent of the output in base 10, returned only if
            ``strip_exponent==True``.
        """
        backend = kwargs.pop("backend", self.backend)
        progbar = kwargs.pop("progbar", self.progbar)
        check_zero = kwargs.pop("check_zero", self.check_zero)
        strip_exponent = kwargs.pop("strip_exponent", self.strip_exponent)
        implementation = kwargs.pop("implementation", self.implementation)
        if kwargs:
            raise TypeError(f"Unknown keyword arguments: {kwargs}.")
        if backend is None:
            backend = infer_backend_multi(*arrays)
        xp = get_namespace(backend)
        if implementation == "auto":
            if backend == "numpy":
                # by default only replace numpy's einsum/tensordot
                implementation = "cotengra"
            else:
                implementation = "autoray"
        if implementation == "cotengra":
            _einsum, _tensordot = einsum, tensordot
        elif implementation == "pytblis":
            import pytblis

            _einsum = pytblis.einsum
            _tensordot = pytblis.tensordot
        elif implementation == "autoray":
            try:
                _einsum = xp.einsum
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _einsum = einsum
            try:
                _tensordot = xp.tensordot
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _tensordot = tensordot
        else:
            # manually supplied
            _einsum, _tensordot = implementation
        # temporary storage for intermediates
        N = len(arrays)
        temps = dict(enumerate(arrays))
        exponent = 0.0 if (strip_exponent is not False) else None
        if progbar:
            import tqdm

            # n.b. total counts only the N - 1 pairwise contractions, not
            # any preprocessing steps
            contractions = tqdm.tqdm(self.contractions, total=N - 1)
        else:
            contractions = self.contractions
        for pi, li, ri, tdot, arg, perm in contractions:
            if ri is None:
                if li is None:
                    # single term simplification, perform inplace with einsum
                    temps[pi] = _einsum(arg, temps[pi])
                    continue
                else:
                    # trivial 'contraction', single input maps directly to
                    # output, possibly with reductions/transpose via einsum
                    p_array = _einsum(arg, temps[li])
                    if strip_exponent:
                        # nothing has been rescaled -> exponent is zero
                        return p_array, 0.0
                    return p_array
            # get input arrays for this contraction
            l_array = temps.pop(li)
            r_array = temps.pop(ri)
            if tdot:
                p_array = _tensordot(l_array, r_array, arg)
                if perm:
                    # axes order tensordot can't produce directly
                    p_array = xp.transpose(p_array, perm)
            else:
                p_array = _einsum(arg, l_array, r_array)
            if exponent is not None:
                # rescale the new intermediate, accumulating the exponent
                factor = xp.max(xp.abs(p_array))
                if check_zero and float(factor) == 0.0:
                    # the full contraction is exactly zero: return the
                    # documented ``(0.0, 0.0)`` sentinel (previously this
                    # returned -inf, contradicting the docstring and
                    # producing nan in downstream exponent arithmetic)
                    return 0.0, 0.0
                exponent = exponent + xp.log10(factor)
                if backend == "tensorflow":
                    factor = xp.astype(factor, p_array.dtype)
                    # TODO:
                    # currently special case tensorflow
                    # autoray needs fix for autojit and astype to use generally
                p_array = p_array / factor
            # insert the new intermediate array
            temps[pi] = p_array
        if exponent is not None:
            return p_array, exponent
        return p_array
class CuQuantumContractor:
    """Contractor that delegates the whole contraction (not just the
    pairwise operations) to cuQuantum's ``Network`` object, which is built
    and planned lazily on the first call then reused.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree supplying the equation, shapes and
        (if complete) the path and slicing.
    handle_slicing : bool, optional
        If ``True``, pass the full unsliced equation and let cuQuantum
        perform the slicing, else use the per-slice equation.
    autotune : bool or int, optional
        How many autotuning iterations to run after planning, ``True``
        maps to 3.
    kwargs
        Passed on to ``Network.contract_path``.
    """

    def __init__(
        self,
        tree,
        handle_slicing=False,
        autotune=False,
        **kwargs,
    ):
        # fail or warn eagerly on unsupported options rather than at
        # contraction time
        if kwargs.pop("strip_exponent", None):
            raise ValueError(
                "strip_exponent=True not supported with cuQuantum"
            )
        if tree.has_preprocessing():
            raise ValueError("Preprocessing not supported with cuQuantum yet.")
        if kwargs.pop("progbar", None):
            import warnings
            warnings.warn("Progress bar not supported with cuQuantum yet.")
        # NOTE(review): a ``check_zero`` entry in kwargs is *not* popped
        # here and so is forwarded to ``contract_path`` below - verify
        # cuQuantum tolerates it
        if handle_slicing:
            # give cuquantum the full problem, including slicing
            self.eq = tree.get_eq()
            self.shapes = tree.get_shapes()
        else:
            self.eq = tree.get_eq_sliced()
            self.shapes = tree.get_shapes_sliced()
        if tree.is_complete():
            # seed the cuquantum optimizer with this tree's own path (and
            # slicing), unless explicitly overridden by the caller
            kwargs.setdefault("optimize", {})
            kwargs["optimize"].setdefault("path", tree.get_path())
            if handle_slicing and tree.sliced_inds:
                kwargs["optimize"].setdefault(
                    "slicing",
                    [(ix, tree.size_dict[ix] - 1) for ix in tree.sliced_inds],
                )
        self.kwargs = kwargs
        self.autotune = 3 if autotune is True else autotune
        self.handle = None
        self.network = None

    def setup(self, *arrays):
        """Lazily construct, plan and optionally autotune the cuQuantum
        ``Network`` using the first set of operand arrays.
        """
        import cuquantum
        if hasattr(cuquantum, "bindings"):
            # cuquantum-python >= 25.03
            from cuquantum.tensornet import Network
        else:
            # for cuquantum < 25.03
            from cuquantum import Network
        self.network = Network(
            self.eq,
            *arrays,
        )
        self.network.contract_path(**self.kwargs)
        if self.autotune:
            self.network.autotune(iterations=self.autotune)

    def __call__(
        self,
        *arrays,
        check_zero=False,
        backend=None,
        progbar=False,
    ):
        """Perform the contraction of ``arrays``, building the network on
        the first call and resetting its operands on subsequent calls.
        """
        # can't handle these yet
        assert not check_zero
        assert not progbar
        assert backend is None
        if self.network is None:
            self.setup(*arrays)
        else:
            # reuse the already planned network with the new operands
            self.network.reset_operands(*arrays)
        return self.network.contract()

    def __del__(self):
        # explicitly release cuquantum-held resources
        if self.network is not None:
            self.network.free()
def make_contractor(
    tree,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    implementation=None,
    autojit=False,
    progbar=False,
):
    """Get a reusable function which performs the contraction corresponding
    to ``tree``. The various options provide defaults that can also be
    overrode when calling the standard contractor.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`, the order in which
        to perform the pairwise contractions given by the tree.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        If ``True``, the function will strip the exponent from the output
        array and return it separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, 0.0)``.
    implementation : str or tuple[callable, callable], optional
        What library to use to actually perform the contractions. Options
        are:

        - "auto": let cotengra choose
        - "autoray": dispatch with autoray, using the ``tensordot`` and
          ``einsum`` implementation of the backend
        - "cotengra": use the ``tensordot`` and ``einsum`` implementation of
          cotengra, which is based on batch matrix multiplication. This is
          faster for some backends like numpy, and also enables libraries
          which don't yet provide ``tensordot`` and ``einsum`` to be used.
        - "cuquantum": use the cuquantum library to perform the whole
          contraction (not just individual contractions).
        - tuple[callable, callable]: manually supply the ``tensordot`` and
          ``einsum`` implementations to use.

    autojit : bool, optional
        If ``True``, use :func:`autoray.autojit` to compile the contraction
        function.
    progbar : bool, optional
        Whether to show progress through the contraction by default.

    Returns
    -------
    fn : callable
        The contraction function, with signature ``fn(*arrays)``.
    """
    if implementation is None:
        implementation = get_default_implementation()
    if implementation == "cuquantum":
        # hand the entire contraction over to cuquantum
        contract_fn = CuQuantumContractor(
            tree,
            strip_exponent=strip_exponent,
            check_zero=check_zero,
            progbar=progbar,
        )
    else:
        # default pairwise contraction driver
        contract_fn = Contractor(
            contractions=extract_contractions(tree, order, prefer_einsum),
            strip_exponent=strip_exponent,
            check_zero=check_zero,
            implementation=implementation,
            progbar=progbar,
        )
    if autojit:
        # compile the whole contraction with autoray
        from autoray import autojit as _autojit
        contract_fn = _autojit(contract_fn)
    return contract_fn
================================================
FILE: cotengra/core.py
================================================
"""Core contraction tree data structure and methods."""
import collections
import functools
import itertools
import math
import warnings
from dataclasses import dataclass
from typing import Optional
from autoray import do, get_namespace, infer_backend
from .contract import make_contractor
from .hypergraph import get_hypergraph
from .nodeops import get_nodeops
from .parallel import (
can_scatter,
maybe_leave_pool,
maybe_rejoin_pool,
parse_parallel_arg,
scatter,
submit,
)
from .pathfinders.path_simulated_annealing import (
parallel_temper_tree,
simulated_anneal_tree,
)
from .plot import (
plot_contractions,
plot_contractions_alt,
plot_hypergraph,
plot_tree_circuit,
plot_tree_flat,
plot_tree_ring,
plot_tree_rubberband,
plot_tree_span,
plot_tree_tent,
)
from .scoring import (
DEFAULT_COMBO_FACTOR,
CompressedStatsTracker,
get_score_fn,
)
from .utils import (
MaxCounter,
compute_size_by_dict,
deprecated,
get_rng,
get_symbol,
groupby,
inputs_output_to_eq,
interleave,
oset,
prod,
unique,
)
def cached_node_property(name):
    """Decorator factory for caching a node-info method's result under the
    key ``name`` in ``self.info[node]``, computing it only on first access.
    """
    def wrapper(meth):
        @functools.wraps(meth)
        def getter(self, node):
            try:
                # fast path: already computed and cached
                return self.info[node][name]
            except KeyError:
                value = meth(self, node)
                self.info[node][name] = value
                return value
        return getter
    return wrapper
def legs_union(legs_seq):
    """Combine a sequence of legs into a single set of legs, summing their
    appearances.
    """
    first, *rest = legs_seq
    # don't mutate the first legs mapping
    merged = first.copy()
    for legs in rest:
        for ix, cnt in legs.items():
            merged[ix] = merged.get(ix, 0) + cnt
    return merged
def legs_without(legs, ind):
    """Discard ``ind`` from legs to create a new set of legs."""
    # rebuild without the discarded index, preserving insertion order
    return {ix: cnt for ix, cnt in legs.items() if ix != ind}
def get_with_default(k, obj, default):
    """Look up ``k`` in the mapping ``obj``, returning ``default`` if it is
    missing.
    """
    if k in obj:
        return obj[k]
    return default
@dataclass(order=True, frozen=True)
class SliceInfo:
    """Frozen record describing a single sliced index.
    N.B. with ``order=True`` the field order below defines the sort order.
    """
    # whether this is an 'inner' sliced index, as opposed to one appearing
    # in the output (see e.g. ``ContractionTree.nchunks``)
    inner: bool
    # the name of the sliced index
    ind: str
    # the full size of the index
    size: int
    # if not None, the index is projected (fixed) to this single value
    project: Optional[int]
    @property
    def sliced_range(self):
        """The sequence of values this index takes when sliced: the full
        ``range(size)``, or just ``[project]`` if projected.
        """
        if self.project is None:
            return range(self.size)
        else:
            return [self.project]
def get_slice_strides(sliced_inds):
    """Compute the 'strides' given the (ordered) dictionary of sliced
    indices, i.e. the backwards cumulative product of their sizes.
    """
    sizes = [si.size for si in sliced_inds.values()]
    # accumulate the product from the rightmost index leftwards
    strides = []
    acc = 1
    for d in reversed(sizes):
        strides.append(acc)
        acc *= d
    strides.reverse()
    return strides
class AdderWithMaybeExponentStripped:
    """Object that adds two arrays, or tuples of (array, exponent), together
    in a stable and branchless way. It also internally caches the backend on
    the first call.
    """
    __slots__ = ("backend", "namespace", "need_to_cast")
    def __init__(self):
        # all lazily set on the first call
        self.backend = None
        self.namespace = None
        self.need_to_cast = False
    def __call__(self, x, y):
        """Add ``x`` and ``y``, each either a plain array or a
        ``(mantissa, base-10 exponent)`` tuple. Returns a plain sum if
        neither carried an exponent, else a ``(mantissa, exponent)`` tuple.
        """
        xistup = isinstance(x, tuple)
        yistup = isinstance(y, tuple)
        if not (xistup or yistup):
            # simple sum without exponent
            return x + y
        if xistup:
            xm, xe = x
        else:
            xm = x
            xe = 0.0
        if yistup:
            ym, ye = y
        else:
            ym = y
            ye = 0.0
        if self.backend is None:
            # cache dispatch info on first use
            self.backend = infer_backend(xm)
            self.namespace = get_namespace(self.backend)
            # tensorflow needs the coefficients explicitly cast to the
            # mantissa dtype
            self.need_to_cast = self.backend == "tensorflow"
        # perform branchless for jit etc.: rescale both mantissas to the
        # larger exponent before summing
        e = max(xe, ye)
        if self.need_to_cast:
            xcoeff = self.namespace.astype(10.0 ** (xe - e), xm.dtype)
            ycoeff = self.namespace.astype(10.0 ** (ye - e), ym.dtype)
            m = xm * xcoeff + ym * ycoeff
        else:
            m = xm * 10 ** (xe - e) + ym * 10 ** (ye - e)
        return (m, e)
class ContractionTree:
"""Binary tree representing a tensor network contraction.
Parameters
----------
inputs : sequence of str
The list of input tensor's indices.
output : str
The output indices.
size_dict : dict[str, int]
The size of each index.
track_childless : bool, optional
Whether to dynamically keep track of which nodes are childless. Useful
if you are 'divisively' building the tree.
track_flops : bool, optional
Whether to dynamically keep track of the total number of flops. If
``False`` You can still compute this once the tree is complete.
track_write : bool, optional
Whether to dynamically keep track of the total number of elements
written. If ``False`` You can still compute this once the tree is
complete.
track_size : bool, optional
Whether to dynamically keep track of the largest tensor so far. If
``False`` You can still compute this once the tree is complete.
objective : str or Objective, optional
A default objective function to use for further optimization and
scoring, for example reconfiguring or computing the combo cost. If not
supplied the default is to create a flops objective when needed.
Attributes
----------
children : dict[node, tuple[node]]
Mapping of each node to two children.
info : dict[node, dict]
Information about the tree nodes. The key is the set of inputs (a
set of inputs indices) the node contains. Or in other words, the
subgraph of the node. The value is a dictionary to cache information
about effective 'leg' indices, size, flops of formation etc.
"""
def __init__(
self,
inputs,
output,
size_dict,
track_childless=False,
track_flops=False,
track_write=False,
track_size=False,
objective=None,
nodeops="auto",
):
self.inputs = inputs
self.output = output
if isinstance(self.inputs[0], set) or isinstance(self.output, set):
warnings.warn(
"The inputs or output of this tree are not ordered."
"Costs will be accurate but actually contracting requires "
"ordered indices corresponding to array axes."
)
if not isinstance(next(iter(size_dict.values()), 1), int):
# make sure we are working with python integers to avoid overflow
# comparison errors with inf etc.
self.size_dict = {k: int(v) for k, v in size_dict.items()}
else:
self.size_dict = size_dict
self.N = len(self.inputs)
# the index representation for each input is an ordered mapping of
# each index to the number of times it has appeared on children. By
# also tracking the total number of appearances one can efficiently
# and locally compute which indices should be kept or contracted
self.appearances = {}
for term in self.inputs:
for ix in term:
self.appearances[ix] = self.appearances.get(ix, 0) + 1
# adding output appearances ensures these are never contracted away,
# N.B. if after this step every appearance count is exactly 2,
# then there are no 'hyper' indices in the contraction
for ix in self.output:
self.appearances[ix] = self.appearances.get(ix, 0) + 1
# this stores potentialy preprocessing steps that are not part of the
# main contraction tree, but assumed to have been applied, for example
# tracing or summing over indices that appear only once
self.preprocessing = {}
# mapping of parents to children - the core binary tree object
self.children = {}
# information about all the nodes
self.nodeops = get_nodeops(nodeops, self.N)
self.info = {}
# add constant nodes: the leaves
for leaf in self.gen_leaves():
self._add_node(leaf, extent=1) # leaf extent is always 1
# and the root or top node
self.root = self.nodeops.node_supremum(self.N)
self._add_node(self.root, extent=self.N) # root extent is always N
if self.N == 1:
# trivial 'contraction', single input maps directly to output,
self.children[self.root] = (leaf,)
# whether to keep track of dangling nodes/subgraphs
self.track_childless = track_childless
if self.track_childless:
# the set of dangling nodes
self.childless = oset([self.root])
# running largest_intermediate and total flops
self._track_flops = track_flops
if track_flops:
self._flops = 0
self._track_write = track_write
if track_write:
self._write = 0
self._track_size = track_size
if track_size:
self._sizes = MaxCounter()
# container for caching subtree reconfiguration condidates
self.already_optimized = dict()
# info relating to slicing (base constructor is always unsliced)
self.multiplicity = 1
self.sliced_inds = {}
self.sliced_inputs = frozenset()
# cache for compiled contraction cores
self.contraction_cores = {}
# a default objective function useful for
# further optimization and scoring
self._default_objective = objective
    def set_state_from(self, other):
        """Set the internal state of this tree to that of ``other``.
        Immutable attributes are shared directly, mutable containers are
        (shallow) copied so the two trees can safely diverge afterwards.
        """
        # immutable or never mutated properties
        for attr in (
            "appearances",
            "inputs",
            "multiplicity",
            "N",
            "output",
            "root",
            "size_dict",
            "sliced_inputs",
            "_default_objective",
        ):
            setattr(self, attr, getattr(other, attr))
        # mutable properties
        for attr in (
            "children",
            "contraction_cores",
            "nodeops",
            "sliced_inds",
            "preprocessing",
        ):
            setattr(self, attr, getattr(other, attr).copy())
        # dicts of mutable - copy one level deeper so per-node entries are
        # independent too
        for attr in ("info", "already_optimized"):
            setattr(
                self,
                attr,
                {k: v.copy() for k, v in getattr(other, attr).items()},
            )
        # optional tracking state, only present if it was being tracked
        self.track_childless = other.track_childless
        if other.track_childless:
            self.childless = other.childless.copy()
        self._track_flops = other._track_flops
        if other._track_flops:
            self._flops = other._flops
        self._track_write = other._track_write
        if other._track_write:
            self._write = other._write
        self._track_size = other._track_size
        if other._track_size:
            self._sizes = other._sizes.copy()
def copy(self):
"""Create a copy of this ``ContractionTree``."""
tree = object.__new__(self.__class__)
tree.set_state_from(self)
return tree
    def set_default_objective(self, objective):
        """Set the objective function for this tree.

        Parameters
        ----------
        objective : str or Objective
            Coerced via ``get_score_fn`` before being stored.
        """
        self._default_objective = get_score_fn(objective)
    def get_default_objective(self):
        """Get the objective function for this tree, lazily creating and
        caching a default 'flops' objective if none has been set.
        """
        if self._default_objective is None:
            self._default_objective = get_score_fn("flops")
        return self._default_objective
def get_default_combo_factor(self):
"""Get the default combo factor for this tree."""
objective = self.get_default_objective()
try:
return objective.factor
except AttributeError:
return DEFAULT_COMBO_FACTOR
def get_score(self, objective=None):
"""Score this tree using the default objective function."""
from .scoring import get_score_fn
if objective is None:
objective = self.get_default_objective()
objective = get_score_fn(objective)
return objective({"tree": self})
    @property
    def nslices(self):
        """Simple alias for how many independent contractions this tree
        represents overall.
        """
        # 1 when unsliced (the base constructor default)
        return self.multiplicity
@property
def nchunks(self):
"""The number of 'chunks' - determined by the number of sliced output
indices.
"""
return prod(
si.size for si in self.sliced_inds.values() if not si.inner
)
    def input_to_node(self, i):
        """Create a node from a single input index, i.e. the subgraph that
        only contains the input tensor ``i``.
        Parameters
        ----------
        i : int
            The input index.
        Returns
        -------
        node : node_type
        """
        # delegate to the node operations backend
        return self.nodeops.node_from_single(i)
    def node_to_input(self, node):
        """Assuming ``node`` has one element, i.e. is a leaf, return the
        corresponding input index.
        Parameters
        ----------
        node : node_type
            The node to convert.
        Returns
        -------
        i : int
        """
        # inverse of input_to_node, delegated to the node operations backend
        return self.nodeops.node_get_single_el(node)
    def node_to_terms(self, node):
        """Turn a node into the corresponding terms a sequence of leaf legs,
        corresponding to input indices.

        Returns
        -------
        Generator of the legs of each input tensor in the node's subgraph.
        """
        return (
            self.get_legs(self.input_to_node(i))
            for i in self.get_subgraph(node)
        )
def gen_leaves(self):
"""Generate the nodes representing leaves of the contraction tree, i.e.
of size 1 each corresponding to a single input tensor.
"""
return map(self.input_to_node, range(self.N))
    def get_incomplete_nodes(self):
        """Get the set of current nodes that have no children and the set of
        nodes that have no parents. These are the 'childless' and 'parentless'
        nodes respectively, that need to be contracted to complete the tree.
        The parentless nodes are grouped into the childless nodes that contain
        them as subgraphs.
        Returns
        -------
        groups : dict[node_type, list[node_type]]
            A mapping of childless nodes to the list of parentless nodes that
            are beneath them.
        See Also
        --------
        autocomplete
        """
        # use dicts as insertion-ordered sets
        childless = dict.fromkeys(
            node
            for node in self.info
            # start with all but leaves
            if not self.is_leaf(node)
        )
        parentless = dict.fromkeys(
            node
            for node in self.info
            # start with all but root
            if not self.is_root(node)
        )
        # every recorded pairwise contraction gives l, r parents and p children
        for p, (l, r) in self.children.items():
            parentless.pop(l)
            parentless.pop(r)
            childless.pop(p)
        groups = {node: [] for node in childless}
        for node in parentless:
            # get the smallest node that contains this node
            ancestor = min(
                (
                    possible_parent
                    for possible_parent in childless
                    if set(self.get_subgraph(node)).issubset(
                        self.get_subgraph(possible_parent)
                    )
                    # XXX: for non-ssa node types could do:
                    # if self.is_descendant(node, possible_parent)
                ),
                key=self.get_extent,
            )
            groups[ancestor].append(node)
        return groups
    def autocomplete(self, **contract_opts):
        """Contract all remaining node groups (as computed by
        ``tree.get_incomplete_nodes``) in the tree to complete it.
        Parameters
        ----------
        contract_opts
            Options to pass to ``tree.contract_nodes``.
        See Also
        --------
        get_incomplete_nodes, contract_nodes
        """
        groups = self.get_incomplete_nodes()
        # fill in each childless node from the parentless nodes beneath it
        for grandparent, parentless_subnodes in groups.items():
            self.contract_nodes(
                parentless_subnodes, grandparent=grandparent, **contract_opts
            )
@classmethod
def from_path(
cls,
inputs,
output,
size_dict,
*,
path=None,
ssa_path=None,
edge_path=None,
optimize="auto",
autocomplete="auto",
check=False,
**kwargs,
):
"""Create a (completed) ``ContractionTree`` from the usual inputs plus
a standard contraction path or 'ssa_path' - you need to supply one.
Parameters
----------
inputs : Sequence[Sequence[str]]
The input indices of each tensor, as single unicode characters.
output : Sequence[str]
The output indices.
size_dict : dict[str, int]
The size of each index.
path : Sequence[Sequence[int]], optional
The contraction path, a sequence of pairs of tensor ids to
contract. The ids are linear indices into the list of temporary
tensors, which are recycled as each contraction pops a pair and
appends the result. One of ``path``, ``ssa_path`` or ``edge_path``
must be supplied.
ssa_path : Sequence[Sequence[int]], optional
The contraction path, a sequence of pairs of indices to contract.
The indices are single use, as if the result of each contraction is
appended to the end of the list of temporary tensors without
popping. One of ``path``, ``ssa_path`` or ``edge_path`` must be
supplied.
edge_path : Sequence[str], optional
The contraction path, a sequence of indices to contract in order.
One of ``path``, ``ssa_path`` or ``edge_path`` must be supplied.
optimize : str, optional
If a contraction within the path contains 3 or more tensors, how to
optimize this subcontraction into a binary tree.
autocomplete : "auto" or bool, optional
Whether to automatically complete the path, i.e. contract all
remaining nodes. If "auto" then a warning is issued if the path is
not complete.
check : bool, optional
Whether to perform some basic checks while creating the contraction
nodes.
Returns
-------
ContractionTree
"""
if (path is None) + (ssa_path is None) + (edge_path is None) != 2:
raise ValueError(
"Exactly one of ``path`` or ``ssa_path`` must be supplied."
)
contract_opts = {"optimize": optimize, "check": check}
if edge_path is not None:
from .pathfinders.path_basic import edge_path_to_ssa
ssa_path = edge_path_to_ssa(edge_path, inputs)
if ssa_path is not None:
path = ssa_path
tree = cls(inputs, output, size_dict, **kwargs)
if ssa_path is not None:
# ssa path ('single static assignment' ids)
nodes = dict(enumerate(tree.gen_leaves()))
ssa = len(nodes)
for p in path:
merge = [nodes.pop(i) for i in p]
nodes[ssa] = tree.contract_nodes(merge, **contract_opts)
ssa += 1
nodes = nodes.values()
else:
# regular path ('recycled' ids)
nodes = list(tree.gen_leaves())
for p in path:
merge = [nodes.pop(i) for i in sorted(p, reverse=True)]
nodes.append(tree.contract_nodes(merge, **contract_opts))
if len(nodes) > 1 and autocomplete:
if autocomplete == "auto":
# warn that we are completing
warnings.warn(
"Path was not complete - contracting all remaining. "
"You can silence this warning with `autocomplete=True`."
"Or produce an incomplete tree with `autocomplete=False`."
)
tree.contract_nodes(nodes, grandparent=tree.root, **contract_opts)
return tree
    @classmethod
    def from_info(cls, info, **kwargs):
        """Create a ``ContractionTree`` from an ``opt_einsum.PathInfo`` object."""
        # unpack the opt_einsum PathInfo fields into the from_path signature
        return cls.from_path(
            inputs=info.input_subscripts.split(","),
            output=info.output_subscript,
            size_dict=info.size_dict,
            path=info.path,
            **kwargs,
        )
    @classmethod
    def from_eq(cls, eq, size_dict, **kwargs):
        """Create an empty ``ContractionTree`` directly from an equation and
        set of shapes.
        Parameters
        ----------
        eq : str
            The einsum string equation.
        size_dict : dict[str, int]
            The size of each index.
        """
        # split 'ab,bc->ac' style equation into inputs and output
        lhs, output = eq.split("->")
        inputs = lhs.split(",")
        return cls(inputs, output, size_dict, **kwargs)
    def get_eq(self):
        """Get the einsum equation corresponding to this tree. Note that this
        is the total (or original) equation, so includes indices which have
        been sliced.
        Returns
        -------
        eq : str
        """
        # self.inputs / self.output are the full, unsliced index sets
        return inputs_output_to_eq(self.inputs, self.output)
def get_shapes(self):
"""Get the shapes of the input tensors corresponding to this tree.
Returns
-------
shapes : tuple[tuple[int]]
"""
return tuple(
tuple(self.size_dict[ix] for ix in term) for term in self.inputs
)
def get_inputs_sliced(self):
"""Get the input indices corresponding to a single slice of this tree,
i.e. with sliced indices removed.
Returns
-------
inputs : tuple[tuple[str]]
"""
return tuple(
tuple(ix for ix in term if ix not in self.sliced_inds)
for term in self.inputs
)
def get_output_sliced(self):
"""Get the output indices corresponding to a single slice of this tree,
i.e. with sliced indices removed.
Returns
-------
output : tuple[str]
"""
return tuple(ix for ix in self.output if ix not in self.sliced_inds)
    def get_eq_sliced(self):
        """Get the einsum equation corresponding to a single slice of this
        tree, i.e. with sliced indices removed.
        Returns
        -------
        eq : str
        """
        # combine the sliced inputs and output into an equation string
        return inputs_output_to_eq(
            self.get_inputs_sliced(), self.get_output_sliced()
        )
def get_shapes_sliced(self):
"""Get the shapes of the input tensors corresponding to a single slice
of this tree, i.e. with sliced indices removed.
Returns
-------
shapes : tuple[tuple[int]]
"""
return tuple(
tuple(
self.size_dict[ix] for ix in term if ix not in self.sliced_inds
)
for term in self.inputs
)
    @classmethod
    def from_edge_path(
        cls,
        edge_path,
        inputs,
        output,
        size_dict,
        optimize="auto",
        autocomplete="auto",
        check=False,
        **kwargs,
    ):
        """Create a ``ContractionTree`` from an edge elimination ordering.

        .. deprecated::
            Use ``ContractionTree.from_path(edge_path=...)`` instead.
        """
        warnings.warn(
            "ContractionTree.from_edge_path(edge_path, ...) is deprecated. Use"
            " ContractionTree.from_path(edge_path=edge_path, ...) instead."
        )
        # thin shim: forward everything to from_path
        return cls.from_path(
            inputs,
            output,
            size_dict,
            edge_path=edge_path,
            optimize=optimize,
            autocomplete=autocomplete,
            check=check,
            **kwargs,
        )
    def _add_node(self, node, check=False, **kwargs):
        """Add a node to this tree, specified either directly as a existing
        node type, or as a subgraph (i.e. a sequence of input positions) which
        is then converted to a node with the corresponding extent and subgraph
        information.
        Note if "ssa" nodes are used, then adding two equivalent subgraphs
        will result in *two* new nodes, since the node labels do not
        themselves encode the subgraph information.
        Parameters
        ----------
        node : node_type or Sequence[int]
            The node to add, either directly as a node type, or as a subgraph
            specified by the sequence of input positions it contains.
        check : bool, optional
            Whether to perform some basic checks on the node and tree state
            before adding the node.
        kwargs : dict, optional
            Additional information to cache about this node, for example its
            'extent' or 'subgraph'. If it is being specified as a sequence of
            input positions, these two will be injected automatically.
        Returns
        -------
        node : node_type
            The node that was added, which may be different from the input if
            the input was specified as a sequence of input positions.
        """
        # first we possibly convert from subgraph spec to node
        if not isinstance(node, self.nodeops.node_type):
            # assume node *has* been specified as sequence of input positions
            subgraph = tuple(node)
            if len(subgraph) == 1:
                # leaf node, for ssa we don't want to generate a new node
                (i,) = subgraph
                node = self.nodeops.node_from_single(i)
            elif len(subgraph) == self.N:
                # root node, for ssa we don't want to generate a new node
                node = self.root
            else:
                # intermediate, assume we can generate new node identifier
                node = self.nodeops.new_node_for_seq(subgraph)
            # cache extent/subgraph since we know them already
            kwargs.setdefault("extent", len(subgraph))
            kwargs.setdefault("subgraph", subgraph)
        if check:
            # a complete binary tree has at most 2N - 1 nodes / N - 1 branches
            if len(self.info) > 2 * self.N - 1:
                raise ValueError("There are too many children already.")
            if len(self.children) > self.N - 1:
                raise ValueError("There are too many branches already.")
            if not self.nodeops.is_valid_node(node):
                raise ValueError("{} is not a valid node.".format(node))
        # get or create this node's info dict
        try:
            d = self.info[node]
        except KeyError:
            d = self.info[node] = {}
        if kwargs:
            d.update(kwargs)
        return node
    def _remove_node(self, node):
        """Remove ``node`` from this tree and update the flops and maximum size
        if tracking them respectively, as well as input pre-processing.
        """
        node_extent = self.get_extent(node)
        if node_extent == 1:
            # leaf nodes should always exist
            self.info[node].clear()
            self.info[node]["extent"] = 1  # leaf extent is always 1
            # input: remove any associated preprocessing
            self.preprocessing.pop(self.nodeops.node_get_single_el(node), None)
        else:
            # only non-leaf nodes contribute to size, flops and write
            # -> subtract this node's contributions from any tracked totals
            if self._track_size:
                self._sizes.discard(self.get_size(node))
            if self._track_flops:
                self._flops -= self.get_flops(node)
            if self._track_write:
                self._write -= self.get_size(node)
            del self.children[node]
            if node_extent == self.N:
                # root node should always exist
                self.info[node].clear()
                self.info[node]["extent"] = self.N  # root extent is always N
            else:
                del self.info[node]
    def compute_leaf_legs(self, i):
        """Compute the effective 'outer' indices for the ith input tensor. This
        is not always simply the ith input indices, due to A) potential slicing
        and B) potential preprocessing.

        Parameters
        ----------
        i : int
            The input tensor position.

        Returns
        -------
        legs : dict[str, int]
            Mapping of index to the number of times it appears on this input.
        """
        # indices of input tensor (after slicing which is done immediately)
        if self.sliced_inds:
            term = tuple(
                ix for ix in self.inputs[i] if ix not in self.sliced_inds
            )
        else:
            term = self.inputs[i]
        # count index repetitions within this single term
        legs = {}
        for ix in term:
            legs[ix] = legs.get(ix, 0) + 1
        # check for single term simplifications, these are treated as a simple
        # preprocessing step that only is taken into account during actual
        # contraction, and are not represented in the binary tree
        # N.B. need to compute simplifiability *after* slicing
        is_simplifiable = (
            # repeated indices (diag or traces)
            (len(term) != len(legs))
            or
            # reduced indices (are summed immediately)
            any(
                ix_count == self.appearances[ix]
                for ix, ix_count in legs.items()
            )
        )
        if is_simplifiable:
            # compute the simplified legs -> the new effective input legs
            legs = {
                ix: ix_count
                for ix, ix_count in legs.items()
                if ix_count != self.appearances[ix]
            }
            # add a preprocessing step to the list of contractions
            eq = inputs_output_to_eq((term,), legs, canonicalize=True)
            self.preprocessing[i] = eq
        return legs
    def has_preprocessing(self):
        """Check whether any input tensors have an associated single term
        simplification ('preprocessing') step.

        Returns
        -------
        bool
        """
        # touch all inputs legs, since preprocessing is lazily computed
        for node in self.gen_leaves():
            self.get_legs(node)
        return bool(self.preprocessing)
def has_hyper_indices(self):
"""Check if there are any 'hyper' indices in the contraction, i.e.
indices that don't appear exactly twice, when considering the inputs
and output.
"""
return any(ix_count != 2 for ix_count in self.appearances.values())
    @cached_node_property("extent")
    def get_extent(self, node):
        """Get the number of input tensors contained in the subgraph
        represented by ``node``.
        Parameters
        ----------
        node : node_type
            The node to compute the extent of.
        Returns
        -------
        extent : int
        """
        if node in self.children:
            # recurse: extent is the sum of the two child extents
            l, r = self.children[node]
            return self.get_extent(l) + self.get_extent(r)
        else:
            # childless -> fall back to the node's own size
            return self.nodeops.node_size(node)
    @cached_node_property("subgraph")
    def get_subgraph(self, node) -> tuple[int, ...]:
        """Get the sequence of input tensors contained in subgraph represented
        by ``node``.
        Parameters
        ----------
        node : node_type
            The node to compute the subgraph of.
        Returns
        -------
        subgraph : tuple[int]
            The input tensor indices contained in this subgraph.
        """
        node_extent = self.get_extent(node)
        if node_extent == 1:
            # leaf -> single input position
            return (self.nodeops.node_get_single_el(node),)
        elif node_extent == self.N:
            # root -> all input positions
            return tuple(range(self.N))
        else:
            try:
                # recurse: concatenate the children's subgraphs
                left, right = self.children[node]
                return self.get_subgraph(left) + self.get_subgraph(right)
            except KeyError:
                # this should only happen if directly creating
                # incomplete nodes e.g. not in a bottom up fashion
                # ssa nodes e.g. will not support this operation
                return tuple(node)
    @cached_node_property("legs")
    def get_legs(self, node):
        """Get the effective 'outer' indices for the collection of tensors
        in ``node``.

        Returns
        -------
        legs : dict[str, int]
            Mapping of index to the number of appearances within the subgraph.
        """
        # should this comparison be with self.N for efficiency?
        if node == self.root:
            # root legs are output, after slicing
            # n.b. the index counts are irrelevant for the output
            return {ix: 0 for ix in self.output if ix not in self.sliced_inds}
        node_extent = self.get_extent(node)
        if node_extent == 1:
            # leaf legs are inputs
            return self.compute_leaf_legs(
                self.nodeops.node_get_single_el(node)
            )
        try:
            involved = self.get_involved(node)
        except KeyError:
            # this should only happen if directly creating
            # incomplete nodes e.g. not in a bottom up fashion
            involved = legs_union(self.node_to_terms(node))
        # an index is still 'outer' if it appears somewhere outside this
        # subgraph, i.e. its count here is less than its total appearances
        return {
            ix: ix_count
            for ix, ix_count in involved.items()
            if ix_count < self.appearances[ix]
        }
    @cached_node_property("involved")
    def get_involved(self, node):
        """Get all the indices involved in the formation of subgraph ``node``."""
        if self.is_leaf(node):
            # leaves are not formed by a contraction -> nothing involved
            return {}
        sub_legs = map(self.get_legs, self.children[node])
        return legs_union(sub_legs)
    @cached_node_property("size")
    def get_size(self, node):
        """Get the tensor size of ``node``, i.e. the product of its leg
        dimensions.
        """
        return compute_size_by_dict(self.get_legs(node), self.size_dict)
    @cached_node_property("flops")
    def get_flops(self, node):
        """Get the FLOPs for the pairwise contraction that will create
        ``node``.
        """
        if self.is_leaf(node):
            # leaves are inputs, no contraction needed
            return 0
        # cost ~ product of all involved index dimensions
        involved = self.get_involved(node)
        return compute_size_by_dict(involved, self.size_dict)
    @cached_node_property("can_dot")
    def get_can_dot(self, node):
        """Get whether this contraction can be performed as a dot product (i.e.
        with ``tensordot``), or else requires ``einsum``, as it has indices
        that don't appear exactly twice in either the inputs or the output.
        """
        l, r = self.children[node]
        sp, sl, sr = map(self.get_legs, (node, l, r))
        # tensordot-able iff the parent's indices are exactly those appearing
        # on one child but not the other (no batch / repeated indices)
        return set(sp) == set(sl).symmetric_difference(sr)
    @cached_node_property("inds")
    def get_inds(self, node):
        """Get the indices of this node - an ordered string version of
        ``get_legs`` that starts with ``tree.inputs`` and maintains the order
        they appear in each contraction 'ABC,abc->ABCabc', to match tensordot.
        """
        # NB: self.inputs and self.output contain the full (unsliced) indices
        # thus we filter even the input legs and output legs
        if self.get_extent(node) in (1, self.N):
            # leaf or root: order already determined by get_legs
            return "".join(self.get_legs(node))
        legs = self.get_legs(node)
        l_inds, r_inds = map(self.get_inds, self.children[node])
        # the filter here takes care of contracted indices
        return "".join(
            unique(filter(legs.__contains__, itertools.chain(l_inds, r_inds)))
        )
    @cached_node_property("tensordot_axes")
    def get_tensordot_axes(self, node):
        """Get the ``axes`` arg for a tensordot contraction that produces
        ``node``. The pairs are sorted in order of appearance on the left
        input.
        """
        l_inds, r_inds = map(self.get_inds, self.children[node])
        l_axes, r_axes = [], []
        # pair up every left index that also appears on the right
        for i, ind in enumerate(l_inds):
            j = r_inds.find(ind)
            if j != -1:
                l_axes.append(i)
                r_axes.append(j)
        return tuple(l_axes), tuple(r_axes)
    @cached_node_property("tensordot_perm")
    def get_tensordot_perm(self, node):
        """Get the permutation required, if any, to bring the tensordot output
        of this nodes contraction into line with ``self.get_inds(node)``.

        Returns ``None`` when no permutation is needed.
        """
        l_inds, r_inds = map(self.get_inds, self.children[node])
        # the target output inds
        p_inds = self.get_inds(node)
        # the tensordot output inds: kept indices in left-then-right order
        td_inds = "".join(sorted(p_inds, key=f"{l_inds}{r_inds}".find))
        if td_inds == p_inds:
            return None
        return tuple(map(td_inds.find, p_inds))
    @cached_node_property("einsum_eq")
    def get_einsum_eq(self, node):
        """Get the einsum string describing the contraction that produces
        ``node``, unlike ``get_inds`` the characters are mapped into [a-zA-Z],
        for compatibility with ``numpy.einsum`` for example.
        """
        l, r = self.children[node]
        l_inds, r_inds, p_inds = map(self.get_inds, (l, r, node))
        # we need to map any extended unicode characters into ascii
        char_mapping = {
            ord(ix): get_symbol(i)
            for i, ix in enumerate(unique(itertools.chain(l_inds, r_inds)))
        }
        return f"{l_inds},{r_inds}->{p_inds}".translate(char_mapping)
    def is_leaf(self, node):
        """Check if ``node`` is a leaf node in this tree.
        Parameters
        ----------
        node : node_type
            The node to check.
        Returns
        -------
        is_leaf : bool
        """
        # delegate to the node operations backend
        return self.nodeops.is_leaf(node)
    def is_root(self, node):
        """Check if ``node`` is the root node in this tree.
        Parameters
        ----------
        node : node_type
            The node to check.
        Returns
        -------
        is_root : bool
        """
        # the root is the node covering all N inputs
        return self.nodeops.is_supremum(node, self.N)
    def is_descendant(self, node, ancestor):
        """Check if ``node`` is a descendant of ``ancestor`` in this tree.
        Parameters
        ----------
        node : node_type
            The node to check.
        ancestor : node_type
            The potential ancestor node.
        Returns
        -------
        is_descendant : bool
        """
        # descendants are subsets in terms of the inputs they cover
        return self.nodeops.node_issubset(node, ancestor)
    def get_peak_size(self, node):
        """Get the peak size for all but only the contractions required to
        produce ``node``. The value for the root node will be the peak size of
        the entire contraction.
        Parameters
        ----------
        node : node_type
            The node to compute the peak size of.
        Returns
        -------
        peak_size : int
        """
        if self.is_leaf(node):
            # leaf node is input
            return self.get_size(node)
        l, r = self.children[node]
        # peak either occured while:
        # 1. we were forming left intermediate
        peakleft = self.get_peak_size(l)
        # 2. we were forming right intermediate (whilst holding left)
        peakright = self.get_size(l) + self.get_peak_size(r)
        # 3. or we were performing this contraction, including output
        peakthis = self.get_size(l) + self.get_size(r) + self.get_size(node)
        return max(peakleft, peakright, peakthis)
    def reorder_for_peak_size(self):
        """This reorders the depth first traversal of the tree to minimize
        the peak size of the contraction.

        Returns
        -------
        changed : bool
            Whether any pair of children was swapped.
        """
        changed = False
        for p, l, r in self.traverse():
            sl = self.get_size(l)
            pl = self.get_peak_size(l)
            sr = self.get_size(r)
            pr = self.get_peak_size(r)
            # peak if hold left while we form right
            plr = max(pl, sl + pr)
            # peak if hold right while we form left
            prl = max(pr, sr + pl)
            if prl < plr:
                # forming right first is strictly better -> swap the children
                self.children[p] = r, l
                changed = True
        return changed
    def get_centrality(self, node):
        """Get the cached 'centrality' of ``node``, computing centralities for
        all nodes first if not already done.
        """
        try:
            return self.info[node]["centrality"]
        except KeyError:
            # not yet computed -> compute for all nodes, then look up again
            self.compute_centralities()
            return self.info[node]["centrality"]
    def total_flops(self, dtype=None, log=None):
        """Sum the flops contribution from every node in the tree.
        Parameters
        ----------
        dtype : {'float', 'complex', None}, optional
            Scale the answer depending on the assumed data type.
        log : float, optional
            If provided, return the log of the total flops to this base.
        """
        if self._track_flops:
            C = self.multiplicity * self._flops
        else:
            # compute once then turn on tracking
            self._flops = 0
            for node, _, _ in self.traverse():
                self._flops += self.get_flops(node)
            self._track_flops = True
            C = self.multiplicity * self._flops
        if dtype is None:
            pass
        elif "float" in dtype:
            # 1 mul + 1 add per abstract scalar op
            C *= 2
        elif "complex" in dtype:
            C *= 4
        else:
            raise ValueError(f"Unknown dtype {dtype}")
        if log is not None:
            # guard against log(0) for trivial contractions
            C = math.log(max(C, 1), log)
        return C
    def total_write(self):
        """Sum the total amount of memory that will be created and operated on."""
        if not self._track_write:
            # compute once then turn on tracking
            self._write = 0
            for node, _, _ in self.traverse():
                self._write += self.get_size(node)
            self._track_write = True
        return self.multiplicity * self._write
def combo_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):
t = 0
for p in self.children:
f = self.get_flops(p)
w = self.get_size(p)
t += combine((f, factor * w))
t *= self.multiplicity
if log is not None:
t = math.log(t, log)
return t
total_cost = combo_cost
    def max_size(self, log=None):
        """The size of the largest intermediate tensor.

        Parameters
        ----------
        log : float, optional
            If provided, return the log of the size to this base.
        """
        if self.N == 1:
            # single tensor -> the 'root' (i.e. only tensor) size
            return self.get_size(self.root)
        if not self._track_size:
            # compute once then turn on tracking
            self._sizes = MaxCounter()
            for node, _, _ in self.traverse():
                self._sizes.add(self.get_size(node))
            self._track_size = True
        size = self._sizes.max()
        if log is not None:
            size = math.log(size, log)
        return size
    def max_contraction_size(self, log=None):
        """The maximum size of a single contraction in the tree. This includes
        the size of the two input tensors and the output tensor, and can be a
        more practical measure of the peak memory required.
        Parameters
        ----------
        log : float, optional
            If provided, return the log of the size to this base.
        Returns
        -------
        size : int or float
            The maximum size of a single contraction in the tree, or its log.
        """
        # parent + left + right sizes for every recorded pairwise contraction
        Y = max(
            self.get_size(p) + self.get_size(l) + self.get_size(r)
            for p, (l, r) in self.children.items()
        )
        if log is not None:
            Y = math.log(Y, log)
        return Y
    def peak_size(self, order=None, log=None):
        """Get the peak concurrent size of tensors needed - this depends on the
        traversal order, i.e. the exact contraction path, not just the
        contraction tree.
        """
        # start by holding all inputs in memory
        tot_size = sum(self.get_size(node) for node in self.gen_leaves())
        peak = tot_size
        for p, l, r in self.traverse(order=order):
            tot_size += self.get_size(p)
            # measure peak assuming we need both inputs and output
            peak = max(peak, tot_size)
            # the two children can now be freed
            tot_size -= self.get_size(l)
            tot_size -= self.get_size(r)
        if log is not None:
            peak = math.log(peak, log)
        return peak
    def contract_stats(self, force=False):
        """Simultaneously compute the total flops, write and size of the
        contraction tree. This is more efficient than calling each of the
        individual methods separately. Once computed, each quantity is then
        automatically tracked.

        Parameters
        ----------
        force : bool, optional
            Whether to recompute even if all quantities are already tracked.

        Returns
        -------
        stats : dict[str, int]
            The total flops, write and size.
        """
        if force or not (
            self._track_flops and self._track_write and self._track_size
        ):
            # (re)compute all three in a single traversal
            self._flops = self._write = 0
            self._sizes = MaxCounter()
            for node, _, _ in self.traverse():
                self._flops += self.get_flops(node)
                node_size = self.get_size(node)
                self._write += node_size
                self._sizes.add(node_size)
            self._track_flops = self._track_write = self._track_size = True
        return {
            # clip at 1 to guard trivial contractions (e.g. for taking logs)
            "flops": max(self.multiplicity * self._flops, 1),
            "write": max(self.multiplicity * self._write, 1),
            "size": max(self._sizes.max(), 1),
        }
    def arithmetic_intensity(self):
        """The ratio of total flops to total write - the higher the better for
        extracting good computational performance.
        """
        # dtype=None -> abstract scalar op count
        return self.total_flops(dtype=None) / self.total_write()
    def contraction_scaling(self):
        """This is computed simply as the maximum number of indices involved
        in any single contraction, which will match the scaling assuming that
        all dimensions are equal.
        """
        # leaves have no involved indices so don't affect the max
        return max(len(self.get_involved(node)) for node in self.info)
    def contraction_cost(self, log=None):
        """Get the total number of scalar operations ~ time complexity.

        Parameters
        ----------
        log : float, optional
            If provided, return the log of the cost to this base.
        """
        return self.total_flops(dtype=None, log=log)
def naive_cost(self, log=None):
"""Get the naive cost of performing this contraction as a single
einsum summation, without any intermediate contractions. This is given
the as product of the size of all indices.
Parameters
----------
log : float, optional
If provided, return log of the cost to this base.
"""
if log is None:
return prod(self.size_dict[ix] for ix in self.appearances)
else:
return sum(
math.log(self.size_dict[ix], log) for ix in self.appearances
)
def speedup(self, log=None):
"""Speedup compared to naive summation.
Parameters
----------
log : float, optional
If provided, return log of the speedup to this base.
"""
if log is None:
return self.naive_cost() / self.contraction_cost()
else:
logc = self.contraction_cost(log=log)
logn = self.naive_cost(log=log)
return logn - logc
    def contraction_width(self, log=2):
        """Get log2 of the size of the largest tensor.

        Parameters
        ----------
        log : float, optional
            The base of the log, default 2.
        """
        return self.max_size(log=log)
    def compressed_contract_stats(
        self,
        chi=None,
        order="surface_order",
        compress_late=None,
    ):
        """Simulate a compressed contraction of this tree, with maximum bond
        size ``chi`` and traversal ``order``, tracking flops, write and sizes.

        Parameters
        ----------
        chi : int, optional
            The maximum bond dimension, defaults to ``get_default_chi``.
        order : str, optional
            The traversal order to simulate.
        compress_late : bool, optional
            Whether to compress just before each contraction (late) or right
            after (early), defaults to ``get_default_compress_late``.

        Returns
        -------
        tracker : CompressedStatsTracker
        """
        if chi is None:
            chi = self.get_default_chi()
        if compress_late is None:
            compress_late = self.get_default_compress_late()
        hg = self.get_hypergraph(accel="auto")
        # conversion between tree nodes <-> hypergraph nodes during contraction
        tree_map = dict(zip(self.gen_leaves(), range(hg.get_num_nodes())))
        tracker = CompressedStatsTracker(hg, chi)
        for p, l, r in self.traverse(order):
            li = tree_map[l]
            ri = tree_map[r]
            tracker.update_pre_step()
            if compress_late:
                tracker.update_pre_compress(hg, li, ri)
                # compress just before we contract tensors
                hg.compress(chi=chi, edges=hg.get_node(li))
                hg.compress(chi=chi, edges=hg.get_node(ri))
                tracker.update_post_compress(hg, li, ri)
            tracker.update_pre_contract(hg, li, ri)
            pi = tree_map[p] = hg.contract(li, ri)
            tracker.update_post_contract(hg, pi)
            if not compress_late:
                # compress as soon as we can after contracting tensors
                tracker.update_pre_compress(hg, pi)
                hg.compress(chi=chi, edges=hg.get_node(pi))
                tracker.update_post_compress(hg, pi)
            tracker.update_post_step()
        return tracker
    def total_flops_compressed(
        self,
        chi=None,
        order="surface_order",
        compress_late=None,
        dtype=None,
        log=None,
    ):
        """Estimate the total flops for a compressed contraction of this tree
        with maximum bond size ``chi``. This includes basic estimates of the
        ops to perform contractions, QRs and SVDs.
        """
        if dtype is not None:
            # unlike total_flops, dtype scaling is not supported here
            raise ValueError(
                "Can only estimate cost in terms of "
                "number of abstract scalar ops."
            )
        F = self.compressed_contract_stats(
            chi=chi,
            order=order,
            compress_late=compress_late,
        ).flops
        if log is not None:
            F = math.log(F, log)
        return F
    # alias, mirroring contraction_cost ~ total_flops
    contraction_cost_compressed = total_flops_compressed
    def total_write_compressed(
        self,
        chi=None,
        order="surface_order",
        compress_late=None,
        log=None,
    ):
        """Compute the total size of all intermediate tensors when a
        compressed contraction is performed with maximum bond size ``chi``,
        ordered by ``order``. This is relevant maybe for time complexity and
        e.g. autodiff space complexity (since every intermediate is kept).

        Parameters
        ----------
        log : float, optional
            If provided, return the log of the write to this base.
        """
        W = self.compressed_contract_stats(
            chi=chi,
            order=order,
            compress_late=compress_late,
        ).write
        if log is not None:
            W = math.log(W, log)
        return W
    def combo_cost_compressed(
        self,
        chi=None,
        order="surface_order",
        compress_late=None,
        factor=None,
        log=None,
    ):
        """Get the combined cost, ``flops + factor * write``, of a compressed
        contraction of this tree with maximum bond size ``chi``.

        Parameters
        ----------
        factor : float, optional
            The weight to give the write contribution, defaults to
            ``get_default_combo_factor``.
        log : float, optional
            If provided, return the log of the cost to this base.
        """
        if factor is None:
            factor = self.get_default_combo_factor()
        C = self.total_flops_compressed(
            chi=chi, order=order, compress_late=compress_late
        ) + factor * self.total_write_compressed(
            chi=chi, order=order, compress_late=compress_late
        )
        if log is not None:
            C = math.log(C, log)
        return C
    total_cost_compressed = combo_cost_compressed
    def max_size_compressed(
        self, chi=None, order="surface_order", compress_late=None, log=None
    ):
        """Compute the maximum sized tensor produced when a compressed
        contraction is performed with maximum bond size ``chi``, ordered by
        ``order``. This is close to the ideal space complexity if only
        tensors that are being directly operated on are kept in memory.

        Parameters
        ----------
        log : float, optional
            If provided, return the log of the size to this base.
        """
        S = self.compressed_contract_stats(
            chi=chi,
            order=order,
            compress_late=compress_late,
        ).max_size
        if log is not None:
            S = math.log(S, log)
        return S
    def peak_size_compressed(
        self,
        chi=None,
        order="surface_order",
        compress_late=None,
        accel="auto",
        log=None,
    ):
        """Compute the peak size of combined intermediate tensors when a
        compressed contraction is performed with maximum bond size ``chi``,
        ordered by ``order``. This is the practical space complexity if one is
        not swapping intermediates in and out of memory.
        """
        # NOTE(review): ``accel`` is accepted but not forwarded anywhere here
        # (compressed_contract_stats calls get_hypergraph(accel="auto")
        # internally) - confirm whether it should be plumbed through
        P = self.compressed_contract_stats(
            chi=chi,
            order=order,
            compress_late=compress_late,
        ).peak_size
        if log is not None:
            P = math.log(P, log)
        return P
    def contraction_width_compressed(
        self, chi=None, order="surface_order", compress_late=None, log=2
    ):
        """Compute log2 of the maximum sized tensor produced when a compressed
        contraction is performed with maximum bond size ``chi``, ordered by
        ``order``.
        """
        # thin wrapper with a log base of 2 by default
        return self.max_size_compressed(chi, order, compress_late, log=log)
    def _update_tracked(self, node):
        """Add newly created ``node``'s contributions to any actively tracked
        running totals (flops, write and sizes).
        """
        if self._track_flops:
            self._flops += self.get_flops(node)
        if self._track_write:
            self._write += self.get_size(node)
        if self._track_size:
            self._sizes.add(self.get_size(node))
    def contract_nodes_pair(
        self,
        x,
        y,
        legs=None,
        cost=None,
        size=None,
        parent=None,
        check=False,
    ):
        """Contract node ``x`` with node ``y`` in the tree to create a new
        parent node, which is returned.
        Parameters
        ----------
        x : node_type
            The first node to contract.
        y : node_type
            The second node to contract.
        legs : dict[str, int], optional
            The effective 'legs' of the new node if already known. If not
            given, this is computed from the inputs of ``x`` and ``y``.
        cost : int, optional
            The cost of the contraction if already known. If not given, this is
            computed from the inputs of ``x`` and ``y``.
        size : int, optional
            The size of the new node if already known. If not given, this is
            computed from the inputs of ``x`` and ``y``.
        parent : node_type, optional
            The node to use as the parent. If not given, a new node is
            generated, or the root is used when ``x`` and ``y`` together cover
            all inputs.
        check : bool, optional
            Whether to check the inputs are valid.
        Returns
        -------
        parent : node_type
            The new parent node of ``x`` and ``y``.
        """
        self._add_node(x, check=check)
        self._add_node(y, check=check)
        nx, ny = self.get_extent(x), self.get_extent(y)
        if parent is None:
            if nx + ny == self.N:
                # covers all inputs -> this contraction forms the root
                parent = self.root
            else:
                parent = self.nodeops.new_node_for_union([x, y])
        self._add_node(parent, check=check)
        # enforce left ordering of 'heaviest' subtrees
        if nx == ny:
            # deterministically break ties
            sortx = self.nodeops.node_tie_breaker(x)
            sorty = self.nodeops.node_tie_breaker(y)
        else:
            sortx = nx
            sorty = ny
        if sortx > sorty:
            lr = (x, y)
        else:
            lr = (y, x)
        self.children[parent] = lr
        if self.track_childless:
            # parent now has children; x and y may themselves be incomplete
            self.childless.discard(parent)
            if x not in self.children and nx > 1:
                self.childless.add(x)
            if y not in self.children and ny > 1:
                self.childless.add(y)
        # pre-computed information
        if legs is not None:
            self.info[parent]["legs"] = legs
        if cost is not None:
            self.info[parent]["flops"] = cost
        if size is not None:
            self.info[parent]["size"] = size
        # fold the new contraction into any tracked totals
        self._update_tracked(parent)
        return parent
    def contract_nodes(
        self,
        nodes,
        optimize="auto",
        grandparent=None,
        check=False,
        extra_opts=None,
    ):
        """Contract an arbitrary number of ``nodes`` in the tree to build up a
        subtree. The root of this subtree (a new intermediate) is returned.

        Parameters
        ----------
        nodes : sequence of node_type or Sequence[int]
            The nodes (or subgraph specs) to contract together.
        optimize : str, optional
            How to find the binary subtree ordering when three or more nodes
            are given.
        grandparent : node_type, optional
            The node to use as the root of the new subtree, if already known.
        check : bool, optional
            Whether to perform basic validity checks while adding nodes.
        extra_opts : dict, optional
            Extra options passed to the path finder.

        Returns
        -------
        node_type
        """
        # possibly convert from subgraph spec to node types
        nodes = tuple(self._add_node(node, check=check) for node in nodes)
        if len(nodes) == 1:
            # nothing to contract
            return nodes[0]
        if len(nodes) == 2:
            # simple pairwise contraction
            return self.contract_nodes_pair(
                *nodes, parent=grandparent, check=check
            )
        from .interface import find_path
        # create the bottom and top nodes
        if grandparent is None:
            if sum(map(self.get_extent, nodes)) == self.N:
                # don't generate new node if root
                grandparent = self.root
            else:
                # assume we can generate new node
                grandparent = self.nodeops.new_node_for_union(nodes)
        self._add_node(grandparent, check=check)
        # if more than two nodes need to find the path to fill in between
        #         \
        #         GN            <- 'grandparent'
        #        /  \
        #      ?????????
        #    ?????????????      <- to be filled with 'temp nodes'
        #   /  \    /   / \
        #  N0  N1  N2  N3  N4   <- ``nodes``, or, subgraphs
        #  /    \  /   /    \
        legs_inputs = tuple(map(self.get_legs, nodes))
        path_inputs = tuple(map(tuple, legs_inputs))
        try:
            # output legs of the grandparent (after slicing)
            # we dont' use get_legs since we can do the shortcut below
            grand_legs = self.info[grandparent]["legs"]
        except KeyError:
            # compute legs directly from children
            if grandparent == self.root:
                # special case, need output ordering and sliced indices
                grand_legs = self.get_legs(grandparent)
            else:
                involved = legs_union(legs_inputs)
                grand_legs = {
                    ix: ix_count
                    for ix, ix_count in involved.items()
                    if ix_count < self.appearances[ix]
                }
            self.info[grandparent]["legs"] = grand_legs
        path_output = tuple(grand_legs)
        path = find_path(
            path_inputs,
            path_output,
            self.size_dict,
            optimize=optimize,
            **(extra_opts or {}),
        )
        # now we have path create the nodes in between
        temp_nodes = list(nodes)
        for p in path[:-1]:
            # recycled-id path semantics: pop in descending order
            to_contract = [temp_nodes.pop(i) for i in sorted(p, reverse=True)]
            temp_nodes.append(self.contract_nodes(to_contract, check=check))
        # want to explicitly specify the grandparent node:
        # so do the final pairwise contraction separately
        self.contract_nodes(temp_nodes, grandparent=grandparent, check=check)
        return grandparent
def is_complete(self):
"""Check every node has two children, unless it is a leaf."""
if self.N == 1:
return True
too_many_nodes = len(self.info) > 2 * self.N - 1
too_many_branches = len(self.children) > self.N - 1
if too_many_nodes or too_many_branches:
raise ValueError("Contraction tree seems to be over complete!")
queue = [self.root]
while queue:
node = queue.pop()
if self.is_leaf(node):
continue
try:
queue.extend(self.children[node])
except KeyError:
return False
return True
def get_default_order(self):
return "dfs"
def _traverse_dfs(self):
"""Traverse the tree in a depth first, non-recursive, order."""
ready = set(self.gen_leaves())
queue = [self.root]
while queue:
node = queue[-1]
l, r = self.children[node]
# both node's children are ready -> we can yield this contraction
if (l in ready) and (r in ready):
ready.add(queue.pop())
yield node, l, r
continue
if r not in ready:
queue.append(r)
if l not in ready:
queue.append(l)
    def _traverse_ordered(self, order):
        """Traverse the tree in the order that minimizes ``order(node)``, but
        still constrained to produce children before parents.

        Parameters
        ----------
        order : callable or "surface_order"
            A scoring function taking a node, lower scoring nodes are yielded
            earlier (subject to the child-before-parent constraint). The
            string "surface_order" maps to ``self.surface_order``.
        """
        from bisect import bisect

        if order == "surface_order":
            order = self.surface_order

        seen = set()
        # ``queue`` and ``scores`` are kept as parallel lists, with
        # ``scores`` sorted ascending up to each node's parent position
        queue = [self.root]
        scores = [order(self.root)]

        # keep sweeping over the queue, inserting unseen children, until
        # every intermediate (entry of ``self.children``) has been placed
        while len(seen) != len(self.children):
            i = 0
            while i < len(queue):
                node = queue[i]
                if node not in seen:
                    for child in self.children[node]:
                        if self.get_extent(child) > 1:
                            # insert child into queue by score + before parent
                            score = order(child)
                            # restrict the insertion point to positions
                            # before the parent (index i) so the child is
                            # always yielded first
                            ci = bisect(scores[:i], score)
                            scores.insert(ci, score)
                            queue.insert(ci, child)
                            # parent moves extra place to right
                            i += 1
                    seen.add(node)
                i += 1

        # queue is now ordered child-before-parent
        for node in queue:
            yield (node, *self.children[node])
def traverse(self, order=None):
"""Generate, in order, all the node merges in this tree. Non-recursive!
This ensures children are always visited before their parent.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
generator[tuple[node]]
The bottom up ordered sequence of tree merges, each a
tuple of ``(parent, left_child, right_child)``.
See Also
--------
descend
"""
if self.N == 1:
return
if order is None:
order = self.get_default_order()
if order == "dfs":
yield from self._traverse_dfs()
else:
yield from self._traverse_ordered(order=order)
def descend(self, mode="dfs"):
"""Generate, from root to leaves, all the node merges in this tree.
Non-recursive! This ensures parents are visited before their children.
Parameters
----------
mode : {'dfs', bfs}, optional
How expand from a parent.
Returns
-------
generator[tuple[node]
The top down ordered sequence of tree merges, each a
tuple of ``(parent, left_child, right_child)``.
See Also
--------
traverse
"""
queue = [self.root]
while queue:
if mode == "dfs":
parent = queue.pop(-1)
elif mode == "bfs":
parent = queue.pop(0)
l, r = self.children[parent]
yield parent, l, r
if self.get_extent(l) > 1:
queue.append(l)
if self.get_extent(r) > 1:
queue.append(r)
def get_subtree(self, node, size, search="bfs", seed=None):
"""Get a subtree spanning down from ``node`` which will have ``size``
leaves (themselves not necessarily leaves of the actual tree).
Parameters
----------
node : node_type
The node of the tree to start with.
size : int
How many subtree leaves to aim for.
search : {'bfs', 'dfs', 'random'}, optional
How to build the tree:
- 'bfs': breadth first expansion
- 'dfs': depth first expansion (largest nodes first)
- 'random': random expansion
seed : None, int or random.Random, optional
Random number generator seed, if ``search`` is 'random'.
Returns
-------
sub_leaves : tuple[node_type]
Nodes which are subtree leaves.
branches : tuple[node_type]
Nodes which are between the subtree leaves and root.
"""
# nodes which are subtree leaves
branches = []
# actual tree leaves - can't expand
real_leaves = []
# nodes to expand
queue = [node]
if search == "random":
rng = get_rng(seed)
else:
rng = None
if search == "bfs":
i = 0
elif search == "dfs":
i = -1
while (len(queue) + len(real_leaves) < size) and queue:
if rng is not None:
i = rng.randint(0, len(queue) - 1)
p = queue.pop(i)
if self.is_leaf(p):
real_leaves.append(p)
continue
# the left child is always >= in weight that right child
# if we append it last then ``.pop(-1)`` above perform the
# depth first search sorting by node subgraph size
l, r = self.children[p]
queue.append(r)
queue.append(l)
branches.append(p)
# nodes at the bottom of the subtree
sub_leaves = queue + real_leaves
return tuple(sub_leaves), tuple(branches)
    def remove_ind(self, ind, project=None, inplace=False):
        """Remove (i.e. by default slice) index ``ind`` from this contraction
        tree, taking care to update all relevant information about each node.

        Parameters
        ----------
        ind : str
            The index to remove.
        project : None or hashable, optional
            If given, the index is projected onto this fixed value (a size-1
            'slice') rather than being summed over all of its values.
        inplace : bool, optional
            Whether to perform the removal inplace or not.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()

        if ind in tree.sliced_inds:
            raise ValueError(f"Index {ind} already sliced.")

        # make sure all flops and size information has been populated
        tree.contract_stats()

        d = tree.size_dict[ind]
        if project is None:
            # we are slicing the index
            si = SliceInfo(ind not in tree.output, ind, d, None)
            # summing over d values -> d times as many contractions
            tree.multiplicity = tree.multiplicity * d
        else:
            # projecting to a single value -> multiplicity is unchanged
            si = SliceInfo(ind not in tree.output, ind, 1, project)

        # update the ordered slice information dictionary, but maintain the
        # order such that output sliced indices always appear first ->
        # enforced by the dataclass SliceInfo ordering
        tree.sliced_inds = {
            si.ind: si for si in sorted((*tree.sliced_inds.values(), si))
        }

        for node, node_info in tree.info.items():
            if self.is_leaf(node):
                # handle leaves separately
                i = self.nodeops.node_get_single_el(node)
                term = tree.inputs[i]
                if ind in term:
                    # n.b. leaves don't contribute to size, flops or write
                    # simply recalculate all information, incl. preprocessing
                    tree._remove_node(node)
                    tree.sliced_inputs = tree.sliced_inputs | frozenset([i])
            else:
                involved = tree.get_involved(node)
                if ind not in involved:
                    # if ind doesn't feature in this node (contraction)
                    # -> nothing to do
                    continue
                # else update all the relevant information about this node
                # -> flops changes for all involved indices
                node_info["involved"] = legs_without(involved, ind)
                old_flops = tree.get_flops(node)
                new_flops = old_flops // d
                node_info["flops"] = new_flops
                tree._flops += new_flops - old_flops
                # -> size and write only changes for node legs (output) indices
                legs = tree.get_legs(node)
                if ind in legs:
                    node_info["legs"] = legs_without(legs, ind)
                    old_size = tree.get_size(node)
                    tree._sizes.discard(old_size)
                    new_size = old_size // d
                    tree._sizes.add(new_size)
                    node_info["size"] = new_size
                    tree._write += new_size - old_size
                # delete info we can't change
                for k in (
                    "inds",
                    "einsum_eq",
                    "can_dot",
                    "tensordot_axes",
                    "tensordot_perm",
                ):
                    tree.info[node].pop(k, None)

        # invalidate caches that depend on the old tree structure
        tree.already_optimized.clear()
        tree.contraction_cores.clear()

        return tree
remove_ind_ = functools.partialmethod(remove_ind, inplace=True)
    def restore_ind(self, ind, inplace=False):
        """Restore (unslice or un-project) index ``ind`` to this contraction
        tree, taking care to update all relevant information about each node.

        Parameters
        ----------
        ind : str
            The index to restore.
        inplace : bool, optional
            Whether to perform the restoration inplace or not.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()

        # pop sliced index info
        si = tree.sliced_inds.pop(ind)

        # make sure all flops and size information has been populated
        tree.contract_stats()

        # undo the multiplicity increase the slice caused
        tree.multiplicity //= si.size

        # handle inputs
        for i, term in enumerate(tree.inputs):
            # this is the original term with all indices
            if ind in term:
                # recompute this leaf's info from scratch
                tree._remove_node(self.input_to_node(i))
                if all(ix not in tree.sliced_inds for ix in term):
                    # mark this input as not sliced
                    tree.sliced_inputs = tree.sliced_inputs - frozenset([i])

        # delete and re-add dependent intermediates
        for p, l, r in tree.traverse():
            if ind in tree.get_legs(l) or ind in tree.get_legs(r):
                tree._remove_node(p)
                tree.contract_nodes_pair(l, r, parent=p)

        # reset caches
        tree.already_optimized.clear()
        tree.contraction_cores.clear()

        return tree
restore_ind_ = functools.partialmethod(restore_ind, inplace=True)
def unslice_rand(self, seed=None, inplace=False):
"""Unslice (restore) a random index from this contraction tree.
Parameters
----------
seed : None, int or random.Random, optional
Random number generator seed.
inplace : bool, optional
Whether to perform the unslicing inplace or not.
Returns
-------
ContractionTree
"""
rng = get_rng(seed)
ix = rng.choice(tuple(self.sliced_inds))
return self.restore_ind(ix, inplace=inplace)
unslice_rand_ = functools.partialmethod(unslice_rand, inplace=True)
def unslice_all(self, inplace=False):
"""Unslice (restore) all sliced indices from this contraction tree.
Parameters
----------
inplace : bool, optional
Whether to perform the unslicing inplace or not.
Returns
-------
ContractionTree
"""
tree = self if inplace else self.copy()
for ind in tuple(tree.sliced_inds):
tree.restore_ind_(ind)
return tree
unslice_all_ = functools.partialmethod(unslice_all, inplace=True)
def calc_subtree_candidates(self, pwr=2, what="flops"):
# get all intermediate nodes
candidates = list(self.children)
if what == "size":
weights = [self.get_size(x) for x in candidates]
elif what == "flops":
weights = [self.get_flops(x) for x in candidates]
if pwr == "log":
weights = [math.log2(max(2, w)) for w in weights]
else:
max_weight = max(weights)
# can be bigger than numpy int/float allows
weights = [float(w / max_weight) ** (1 / pwr) for w in weights]
# sort by descending score
candidates, weights = zip(
*sorted(zip(candidates, weights), key=lambda x: -x[1])
)
return list(candidates), list(weights)
    def _subtree_remove_and_optimize(
        self,
        sub_root,
        sub_leaves,
        sub_branches,
        already_optimized,
        node_cost,
        minimize,
        opt,
        pbar,
    ):
        """Remove the intermediate nodes ``sub_branches`` lying between
        ``sub_leaves`` and ``sub_root``, then find a locally optimal
        contraction of the leaves back up to the root using ``opt``,
        recording the (frozen) leaf set in ``already_optimized``.

        Used by the ``subtree_reconfigure`` drivers.
        """
        # accumulate the current cost of this subtree - used as a cap so the
        # optimizer can prune anything worse than what we already have
        current_cost = node_cost(self, sub_root)
        for node in sub_branches:
            # these are the intermediates *between* leaves and sub-root
            if minimize == "size":
                current_cost = max(current_cost, node_cost(self, node))
            else:
                current_cost += node_cost(self, node)
            self._remove_node(node)

        # make the optimizer more efficient by supplying accurate cap
        opt.cost_cap = max(2, current_cost)

        # and reoptimize the leaves
        self.contract_nodes(sub_leaves, optimize=opt, grandparent=sub_root)
        already_optimized.add(sub_leaves)

        if pbar is not None:
            pbar.update()
            # show the new best tree stats
            pbar.set_description(_describe_tree(self), refresh=False)
    def _subtree_reconfigure_descend(
        self,
        subtree_size,
        subtree_search,
        maxiter,
        seed,
        minimize,
        opt,
        already_optimized,
        node_cost,
        pbar,
    ):
        """Reconfigure subtrees in a deterministic order: starting from the
        root and descending into children, restarting from the top whenever
        any modification was made during a full pass.
        """
        candidates = [self.root]
        any_modified = False

        def _possibly_add_children(sub_root, any_modified):
            # queue ``sub_root``'s children for later consideration, and
            # restart from the root once the queue empties after any change
            if self.get_extent(sub_root) > subtree_size:
                # possibly extend with node children, if not close to bottom
                lnode, rnode = self.children[sub_root]
                if self.get_extent(lnode) >= 2:
                    candidates.append(lnode)
                if self.get_extent(rnode) >= 2:
                    candidates.append(rnode)
            if len(candidates) == 0:
                # exhausted queue
                if any_modified:
                    # but have made *any* changes -> go again from top
                    candidates.append(self.root)
                    any_modified = False
            return any_modified

        r = 0
        while candidates and r < maxiter:
            sub_root = candidates.pop(0)

            # get a subtree to possibly reconfigure
            sub_leaves, sub_branches = self.get_subtree(
                sub_root, size=subtree_size, search=subtree_search, seed=seed
            )

            # check if its already been optimized
            sub_leaves = frozenset(sub_leaves)
            if sub_leaves in already_optimized:
                any_modified = _possibly_add_children(sub_root, any_modified)
                continue

            # else remove the branches, keeping track of current cost
            self._subtree_remove_and_optimize(
                sub_root,
                sub_leaves,
                sub_branches,
                already_optimized,
                node_cost,
                minimize,
                opt,
                pbar,
            )
            any_modified = _possibly_add_children(sub_root, True)
            r += 1
    def _subtree_reconfigure_rand_select(
        self,
        subtree_size,
        subtree_search,
        weight_what,
        weight_pwr,
        select,
        maxiter,
        seed,
        minimize,
        opt,
        already_optimized,
        node_cost,
        pbar,
    ):
        """Reconfigure subtrees rooted at nodes selected either randomly
        (weighted by score), or by maximum or minimum score.
        """
        if select == "random":
            rng = get_rng(seed)
        else:
            rng = None

        # NOTE(review): if ``select`` is none of 'random', 'max', 'min' then
        # ``i`` is never bound and the ``pop(i)`` below raises a NameError -
        # presumably validated by the caller; confirm.
        if select == "max":
            # always pick the highest scoring candidate
            i = 0
        elif select == "min":
            # always pick the lowest scoring candidate
            i = -1

        candidates, weights = self.calc_subtree_candidates(
            pwr=weight_pwr, what=weight_what
        )

        r = 0
        while candidates and r < maxiter:
            if rng is not None:
                # sample a candidate index weighted by its score
                (i,) = rng.choices(range(len(candidates)), weights=weights)
            weights.pop(i)
            sub_root = candidates.pop(i)

            # get a subtree to possibly reconfigure
            sub_leaves, sub_branches = self.get_subtree(
                sub_root, size=subtree_size, search=subtree_search, seed=seed
            )

            # check if its already been optimized
            sub_leaves = frozenset(sub_leaves)
            if sub_leaves in already_optimized:
                continue

            # else remove the branches, keeping track of current cost
            self._subtree_remove_and_optimize(
                sub_root,
                sub_leaves,
                sub_branches,
                already_optimized,
                node_cost,
                minimize,
                opt,
                pbar,
            )

            # if we have reconfigured simply re-add all candidates
            candidates, weights = self.calc_subtree_candidates(
                pwr=weight_pwr, what=weight_what
            )
            r += 1
    def subtree_reconfigure(
        self,
        subtree_size=8,
        subtree_search="bfs",
        weight_what="flops",
        weight_pwr=2,
        select="max",
        maxiter=500,
        seed=None,
        minimize=None,
        optimize=None,
        inplace=False,
        progbar=False,
    ):
        """Reconfigure subtrees of this tree with locally optimal paths.

        Parameters
        ----------
        subtree_size : int, optional
            The size of subtree to consider. Cost is exponential in this.
        subtree_search : {'bfs', 'dfs', 'random'}, optional
            How to build the subtrees:

            - 'bfs': breadth-first-search creating balanced subtrees
            - 'dfs': depth-first-search creating imbalanced subtrees
            - 'random': random subtree building

        weight_what : {'flops', 'size'}, optional
            When assessing nodes to build and optimize subtrees from whether to
            score them by the (local) contraction cost, or tensor size.
        weight_pwr : int, optional
            When assessing nodes to build and optimize subtrees from, how to
            scale their score into a probability: ``score**(1 / weight_pwr)``.
            The larger this is the more explorative the algorithm is when
            ``select='random'``.
        select : {'descend', 'max', 'min', 'random'}, optional
            What order to select node subtrees to optimize:

            - 'descend': start from the root and then descend into children. In
              this case the weights and weight_pwr are ignored since this is a
              deterministic order.
            - 'max': choose the highest score first
            - 'min': choose the lowest score first
            - 'random': choose randomly weighted on score - see ``weight_pwr``.

        maxiter : int, optional
            How many subtree optimizations to perform, the algorithm can
            terminate before this if all subtrees have been optimized.
        seed : int, optional
            A random seed (seeds python system random module).
        minimize : {'flops', 'size'}, optional
            Whether to minimize with respect to contraction flops or size.
        optimize : PathOptimizer, optional
            The optimizer to use on each subtree - defaults to an exhaustive
            ``OptimalOptimizer``.
        inplace : bool, optional
            Whether to perform the reconfiguration inplace or not.
        progbar : bool, optional
            Whether to show live progress of the reconfiguration.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()
        tree.reset_contraction_indices()

        # ensure these have been computed and thus are being tracked
        tree.contract_stats()

        if minimize is None:
            minimize = self.get_default_objective()
        scorer = get_score_fn(minimize)
        # NOTE(review): the fallback lambda takes a single argument, but
        # ``node_cost`` is later called as ``node_cost(tree, node)`` -
        # presumably every scorer defines ``cost_local_tree_node``; confirm.
        node_cost = getattr(scorer, "cost_local_tree_node", lambda _: 2)

        if optimize is None:
            from .pathfinders.path_basic import OptimalOptimizer

            # use the scorer's equivalent dynamic programming objective
            minimize = scorer.get_dynamic_programming_minimize()
            opt = OptimalOptimizer(minimize=minimize)
        else:
            opt = optimize

        # different caches as we might want to reconfigure one before other
        tree.already_optimized.setdefault(minimize, set())
        already_optimized = tree.already_optimized[minimize]

        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
            pbar.set_description(_describe_tree(tree), refresh=False)
        else:
            pbar = None

        try:
            reconf_kwargs = {
                "subtree_size": subtree_size,
                "subtree_search": subtree_search,
                "maxiter": maxiter,
                "seed": seed,
                "minimize": minimize,
                "opt": opt,
                "already_optimized": already_optimized,
                "node_cost": node_cost,
                "pbar": pbar,
            }
            if select == "descend":
                tree._subtree_reconfigure_descend(**reconf_kwargs)
            else:
                reconf_kwargs["weight_what"] = weight_what
                reconf_kwargs["weight_pwr"] = weight_pwr
                reconf_kwargs["select"] = select
                tree._subtree_reconfigure_rand_select(**reconf_kwargs)
        finally:
            if progbar:
                pbar.close()

        return tree
subtree_reconfigure_ = functools.partialmethod(
subtree_reconfigure, inplace=True
)
    def subtree_reconfigure_forest(
        self,
        num_trees=8,
        num_restarts=10,
        restart_fraction=0.5,
        subtree_maxiter=100,
        subtree_size=10,
        subtree_search=("random", "bfs"),
        subtree_select=("random",),
        subtree_weight_what=("flops", "size"),
        subtree_weight_pwr=(2,),
        parallel="auto",
        parallel_maxiter_steps=4,
        minimize=None,
        seed=None,
        progbar=False,
        inplace=False,
    ):
        """'Forested' version of ``subtree_reconfigure`` which is more
        explorative and can be parallelized. It stochastically generates
        a 'forest' reconfigured trees, then only keeps some fraction of these
        to generate the next forest.

        Parameters
        ----------
        num_trees : int, optional
            The number of trees to reconfigure at each stage.
        num_restarts : int, optional
            The number of times to halt, prune and then restart the
            tree reconfigurations.
        restart_fraction : float, optional
            The fraction of trees to keep at each stage and generate the next
            forest from.
        subtree_maxiter : int, optional
            Number of subtree reconfigurations per step.
            ``num_restarts * subtree_maxiter`` is the max number of total
            subtree reconfigurations for the final tree produced.
        subtree_size : int, optional
            The size of subtrees to search for and reconfigure.
        subtree_search : tuple[{'random', 'bfs', 'dfs'}], optional
            Tuple of options for the ``search`` kwarg of
            :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
        subtree_select : tuple[{'random', 'max', 'min'}], optional
            Tuple of options for the ``select`` kwarg of
            :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
        subtree_weight_what : tuple[{'flops', 'size'}], optional
            Tuple of options for the ``weight_what`` kwarg of
            :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
        subtree_weight_pwr : tuple[int], optional
            Tuple of options for the ``weight_pwr`` kwarg of
            :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
        parallel : 'auto', False, True, int, or distributed.Client
            Whether to parallelize the search.
        parallel_maxiter_steps : int, optional
            If parallelizing, how many steps to break each reconfiguration into
            in order to evenly saturate many processes.
        minimize : {'flops', 'size', ..., Objective}, optional
            Whether to minimize the total flops or maximum size of the
            contraction tree.
        seed : None, int or random.Random, optional
            A random seed to use.
        progbar : bool, optional
            Whether to show live progress.
        inplace : bool, optional
            Whether to perform the subtree reconfiguration inplace.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()
        tree.reset_contraction_indices()

        # candidate trees
        num_keep = max(1, int(num_trees * restart_fraction))

        # how to rank the trees
        if minimize is None:
            minimize = self.get_default_objective()
        score = get_score_fn(minimize)

        rng = get_rng(seed)

        # set up the initial 'forest' and parallel machinery
        pool = parse_parallel_arg(parallel)
        is_scatter_pool = can_scatter(pool)
        if is_scatter_pool:
            is_worker = maybe_leave_pool(pool)
            # store the trees as futures for the entire process
            forest = [scatter(pool, tree)]
            # split the reconfiguration into steps so results can be pruned
            # and resampled between submissions
            maxiter = subtree_maxiter // parallel_maxiter_steps
        else:
            forest = [tree]
            maxiter = subtree_maxiter

        if progbar:
            import tqdm

            pbar = tqdm.tqdm(total=num_restarts)
            pbar.set_description(_describe_tree(tree), refresh=False)

        try:
            for _ in range(num_restarts):
                # on the next round take only the best trees
                forest = itertools.cycle(forest[:num_keep])

                # select some random configurations
                saplings = [
                    {
                        "tree": next(forest),
                        "maxiter": maxiter,
                        "minimize": minimize,
                        "subtree_size": subtree_size,
                        "subtree_search": rng.choice(subtree_search),
                        "select": rng.choice(subtree_select),
                        "weight_pwr": rng.choice(subtree_weight_pwr),
                        "weight_what": rng.choice(subtree_weight_what),
                    }
                    for _ in range(num_trees)
                ]

                if pool is None:
                    # sequential, in-process
                    forest = [_reconfigure_tree(**s) for s in saplings]
                    res = [{"tree": t, **_get_tree_info(t)} for t in forest]
                elif not is_scatter_pool:
                    # simple pool with no pass by reference
                    forest_futures = [
                        submit(pool, _reconfigure_tree, **s) for s in saplings
                    ]
                    forest = [f.result() for f in forest_futures]
                    res = [{"tree": t, **_get_tree_info(t)} for t in forest]
                else:
                    # submit in smaller steps to saturate processes
                    for _ in range(parallel_maxiter_steps):
                        for s in saplings:
                            s["tree"] = submit(pool, _reconfigure_tree, **s)

                    # compute scores remotely then gather
                    forest_futures = [s["tree"] for s in saplings]
                    res_futures = [
                        submit(pool, _get_tree_info, t) for t in forest_futures
                    ]
                    res = [
                        {"tree": tree_future, **res_future.result()}
                        for tree_future, res_future in zip(
                            forest_futures, res_futures
                        )
                    ]

                # update the order of the new forest
                res.sort(key=score)
                forest = [r["tree"] for r in res]

                if progbar:
                    pbar.update()
                    if pool is None:
                        d = _describe_tree(forest[0])
                    else:
                        d = submit(pool, _describe_tree, forest[0]).result()
                    pbar.set_description(d, refresh=False)
        finally:
            if progbar:
                pbar.close()

        if is_scatter_pool:
            # the winning tree is a future -> materialize it
            tree.set_state_from(forest[0].result())
            maybe_rejoin_pool(is_worker, pool)
        else:
            tree.set_state_from(forest[0])

        return tree
subtree_reconfigure_forest_ = functools.partialmethod(
subtree_reconfigure_forest, inplace=True
)
simulated_anneal = simulated_anneal_tree
simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
parallel_temper = parallel_temper_tree
parallel_temper_ = functools.partialmethod(parallel_temper, inplace=True)
    def slice(
        self,
        target_size=None,
        target_overhead=None,
        target_slices=None,
        temperature=0.01,
        minimize=None,
        allow_outer=True,
        max_repeats=16,
        reslice=False,
        seed=None,
        inplace=False,
    ):
        """Slice this tree (turn some indices into indices which are explicitly
        summed over rather than being part of contractions). The indices are
        stored in ``tree.sliced_inds``, and the contraction width updated to
        take account of the slicing. Calling ``tree.contract(arrays)`` will
        then automatically perform the slicing and summation.

        Parameters
        ----------
        target_size : int, optional
            The target number of entries in the largest tensor of the sliced
            contraction. The search algorithm will terminate after this is
            reached.
        target_slices : int, optional
            The target or minimum number of 'slices' to consider - individual
            contractions after slicing indices. The search algorithm will
            terminate after this is breached. This is on top of the current
            number of slices.
        target_overhead : float, optional
            The target increase in total number of floating point operations.
            For example, a value of ``2.0`` will terminate the search just
            before the cost of computing all the slices individually breaches
            twice that of computing the original contraction all at once.
        temperature : float, optional
            How much to randomize the repeated search.
        minimize : {'flops', 'size', ..., Objective}, optional
            Which metric to score the overhead increase against.
        allow_outer : bool, optional
            Whether to allow slicing of outer indices.
        max_repeats : int, optional
            How many times to repeat the search with a slight randomization.
        reslice : bool, optional
            Whether to reslice the tree, i.e. first remove all currently
            sliced indices and start the search again. Generally any 'good'
            sliced indices will be easily found again.
        seed : None, int or random.Random, optional
            A random seed or generator to use for the search.
        inplace : bool, optional
            Whether the remove the indices from this tree inplace or not.

        Returns
        -------
        ContractionTree

        See Also
        --------
        SliceFinder, ContractionTree.slice_and_reconfigure
        """
        from .slicer import SliceFinder

        if minimize is None:
            minimize = self.get_default_objective()

        tree = self if inplace else self.copy()

        if reslice:
            if target_slices is not None:
                # scale the target up since all current slices are about to
                # be removed and searched for again
                target_slices *= tree.nslices
            tree.unslice_all_()

        sf = SliceFinder(
            tree,
            target_size=target_size,
            target_overhead=target_overhead,
            target_slices=target_slices,
            temperature=temperature,
            minimize=minimize,
            allow_outer=allow_outer,
            seed=seed,
        )

        ix_sl, _ = sf.search(max_repeats)
        # actually remove the found indices from the tree
        for ix in ix_sl:
            tree.remove_ind_(ix)

        return tree
slice_ = functools.partialmethod(slice, inplace=True)
    def slice_and_reconfigure(
        self,
        target_size,
        step_size=2,
        temperature=0.01,
        minimize=None,
        allow_outer=True,
        max_repeats=16,
        reslice=False,
        reconf_opts=None,
        progbar=False,
        inplace=False,
    ):
        """Interleave slicing (removing indices into an exterior sum) with
        subtree reconfiguration to minimize the overhead induced by this
        slicing.

        Parameters
        ----------
        target_size : int
            Slice the tree until the maximum intermediate size is this or
            smaller.
        step_size : int, optional
            The minimum size reduction to try and achieve before switching to a
            round of subtree reconfiguration.
        temperature : float, optional
            The temperature to supply to ``SliceFinder`` for searching for
            indices.
        minimize : {'flops', 'size', ..., Objective}, optional
            The metric to minimize when slicing and reconfiguring subtrees.
        allow_outer : bool, optional
            Whether to allow slicing of outer indices.
        max_repeats : int, optional
            The number of slicing attempts to perform per search.
        reslice : bool, optional
            Whether to remove all currently sliced indices and search again
            at each slicing round.
        progbar : bool, optional
            Whether to show live progress.
        inplace : bool, optional
            Whether to perform the slicing and reconfiguration inplace.
        reconf_opts : None or dict, optional
            Supplied to
            :meth:`ContractionTree.subtree_reconfigure` or
            :meth:`ContractionTree.subtree_reconfigure_forest`, depending on
            `'forested'` key value.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()

        reconf_opts = {} if reconf_opts is None else dict(reconf_opts)

        if minimize is None:
            minimize = self.get_default_objective()
        minimize = get_score_fn(minimize)
        reconf_opts.setdefault("minimize", minimize)
        forested_reconf = reconf_opts.pop("forested", False)

        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
            pbar.set_description(_describe_tree(tree), refresh=False)

        try:
            # keep alternating slicing and reconfiguration until small enough
            while tree.max_size() > target_size:
                tree.slice_(
                    temperature=temperature,
                    target_slices=step_size,
                    minimize=minimize,
                    allow_outer=allow_outer,
                    max_repeats=max_repeats,
                    reslice=reslice,
                )
                if forested_reconf:
                    tree.subtree_reconfigure_forest_(**reconf_opts)
                else:
                    tree.subtree_reconfigure_(**reconf_opts)

                if progbar:
                    pbar.update()
                    pbar.set_description(_describe_tree(tree), refresh=False)
        finally:
            if progbar:
                pbar.close()

        return tree
slice_and_reconfigure_ = functools.partialmethod(
slice_and_reconfigure, inplace=True
)
    def slice_and_reconfigure_forest(
        self,
        target_size,
        step_size=2,
        num_trees=8,
        restart_fraction=0.5,
        temperature=0.02,
        max_repeats=32,
        reslice=False,
        minimize=None,
        allow_outer=True,
        parallel="auto",
        progbar=False,
        inplace=False,
        reconf_opts=None,
    ):
        """'Forested' version of :meth:`ContractionTree.slice_and_reconfigure`.
        This maintains a 'forest' of trees with different slicing and subtree
        reconfiguration attempts, pruning the worst at each step and generating
        a new forest from the best.

        Parameters
        ----------
        target_size : int
            Slice the tree until the maximum intermediate size is this or
            smaller.
        step_size : int, optional
            The minimum size reduction to try and achieve before switching to a
            round of subtree reconfiguration.
        num_trees : int, optional
            The number of trees to maintain in the forest at each stage.
        restart_fraction : float, optional
            The fraction of trees to keep at each stage and generate the next
            forest from.
        temperature : float, optional
            The temperature at which to randomize the sliced index search.
        max_repeats : int, optional
            The number of slicing attempts to perform per search.
        reslice : bool, optional
            Whether to remove all currently sliced indices and search again
            at each slicing round.
        minimize : {'flops', 'size', ..., Objective}, optional
            The metric to rank the trees by.
        allow_outer : bool, optional
            Whether to allow slicing of outer indices.
        parallel : 'auto', False, True, int, or distributed.Client
            Whether to parallelize the search.
        progbar : bool, optional
            Whether to show live progress.
        inplace : bool, optional
            Whether to perform the slicing and reconfiguration inplace.
        reconf_opts : None or dict, optional
            Supplied to
            :meth:`ContractionTree.slice_and_reconfigure`.

        Returns
        -------
        ContractionTree
        """
        tree = self if inplace else self.copy()
        tree.reset_contraction_indices()

        # candidate trees
        num_keep = max(1, int(num_trees * restart_fraction))

        # how to rank the trees
        if minimize is None:
            minimize = self.get_default_objective()
        score = get_score_fn(minimize)

        # set up the initial 'forest' and parallel machinery
        pool = parse_parallel_arg(parallel)
        is_scatter_pool = can_scatter(pool)
        if is_scatter_pool:
            is_worker = maybe_leave_pool(pool)
            # store the trees as futures for the entire process
            forest = [scatter(pool, tree)]
        else:
            forest = [tree]

        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
            pbar.set_description(_describe_tree(tree), refresh=False)

        next_size = tree.max_size()

        try:
            # repeatedly halve (by step_size) the intermediate target size
            # until the best tree in the forest is small enough
            while True:
                next_size //= step_size

                # on the next round take only the best trees
                forest = itertools.cycle(forest[:num_keep])

                saplings = [
                    {
                        "tree": next(forest),
                        "target_size": next_size,
                        "step_size": step_size,
                        "temperature": temperature,
                        "max_repeats": max_repeats,
                        "reconf_opts": reconf_opts,
                        "allow_outer": allow_outer,
                        "reslice": reslice,
                    }
                    for _ in range(num_trees)
                ]

                if pool is None:
                    # sequential, in-process
                    forest = [
                        _slice_and_reconfigure_tree(**s) for s in saplings
                    ]
                    res = [{"tree": t, **_get_tree_info(t)} for t in forest]
                elif not is_scatter_pool:
                    # simple pool with no pass by reference
                    forest_futures = [
                        submit(pool, _slice_and_reconfigure_tree, **s)
                        for s in saplings
                    ]
                    forest = [f.result() for f in forest_futures]
                    res = [{"tree": t, **_get_tree_info(t)} for t in forest]
                else:
                    forest_futures = [
                        submit(pool, _slice_and_reconfigure_tree, **s)
                        for s in saplings
                    ]

                    # compute scores remotely then gather
                    res_futures = [
                        submit(pool, _get_tree_info, t) for t in forest_futures
                    ]
                    res = [
                        {"tree": tree_future, **res_future.result()}
                        for tree_future, res_future in zip(
                            forest_futures, res_futures
                        )
                    ]

                # we want to sort by flops, but also favour sampling as
                # many different sliced index combos as possible
                #     ~ [1, 1, 1, 2, 2, 3] -> [1, 2, 3, 1, 2, 1]
                res.sort(key=score)
                res = list(
                    interleave(
                        groupby(lambda r: r["sliced_ind_set"], res).values()
                    )
                )

                # update the order of the new forest
                forest = [r["tree"] for r in res]

                if progbar:
                    pbar.update()
                    if pool is None:
                        d = _describe_tree(forest[0])
                    else:
                        d = submit(pool, _describe_tree, forest[0]).result()
                    pbar.set_description(d, refresh=False)

                if res[0]["size"] <= target_size:
                    break
        finally:
            if progbar:
                pbar.close()

        if is_scatter_pool:
            # the winning tree is a future -> materialize it
            tree.set_state_from(forest[0].result())
            maybe_rejoin_pool(is_worker, pool)
        else:
            tree.set_state_from(forest[0])

        return tree
slice_and_reconfigure_forest_ = functools.partialmethod(
slice_and_reconfigure_forest, inplace=True
)
    def compressed_reconfigure(
        self,
        minimize=None,
        order_only=False,
        max_nodes="auto",
        max_time=None,
        local_score=None,
        exploration_power=0,
        best_score=None,
        progbar=False,
        inplace=False,
    ):
        """Reconfigure this tree according to ``peak_size_compressed``.
        Parameters
        ----------
        minimize : str or Objective, optional
            The objective to minimize, defaults to this tree's default
            objective.
        order_only : bool, optional
            Whether to only consider the ordering of the current tree
            contractions, or all possible contractions, starting with the
            current.
        max_nodes : int, optional
            Set the maximum number of contraction steps to consider. If
            "auto", this is chosen based on whether ``max_time`` is set.
        max_time : float, optional
            Set the maximum time to spend on the search.
        local_score : callable, optional
            A function that assigns a score to a potential contraction, with a
            lower score giving more priority to explore that contraction
            earlier. It should have signature::
                local_score(step, new_score, dsize, new_size)
            where ``step`` is the number of steps so far, ``new_score`` is the
            score of the contraction so far, ``dsize`` is the change in memory
            by the current step, and ``new_size`` is the new memory size after
            contraction.
        exploration_power : float, optional
            If not ``0.0``, the inverse power to which the step is raised in
            the default local score function. Higher values favor exploring
            more promising branches early on - at the cost of increased memory.
            Ignored if ``local_score`` is supplied.
        best_score : float, optional
            Manually specify an upper bound for best score found so far.
        progbar : bool, optional
            If ``True``, display a progress bar.
        inplace : bool, optional
            Whether to perform the reconfiguration inplace on this tree.
        Returns
        -------
        ContractionTree
        """
        from .experimental.path_compressed_branchbound import (
            CompressedExhaustive,
        )
        if minimize is None:
            minimize = self.get_default_objective()
        if max_nodes == "auto":
            if max_time is None:
                max_nodes = max(10_000, self.N**2)
            else:
                # rely on the time limit alone to bound the search
                max_nodes = float("inf")
        opt = CompressedExhaustive(
            minimize=minimize,
            local_score=local_score,
            max_nodes=max_nodes,
            max_time=max_time,
            exploration_power=exploration_power,
            best_score=best_score,
            progbar=progbar,
        )
        opt.setup(self.inputs, self.output, self.size_dict)
        # seed the search with this tree's current (surface ordered) path
        opt.explore_path(self.get_path_surface(), restrict=order_only)
        # rtree = opt.search(self.inputs, self.output, self.size_dict)
        opt.run(self.inputs, self.output, self.size_dict)
        ssa_path = opt.ssa_path
        # ssa_path = opt(self.inputs, self.output, self.size_dict)
        rtree = self.__class__.from_path(
            self.inputs,
            self.output,
            self.size_dict,
            ssa_path=ssa_path,
            objective=minimize,
        )
        if inplace:
            self.set_state_from(rtree)
            rtree = self
        # explicit contraction index orderings are now stale
        rtree.reset_contraction_indices()
        return rtree
compressed_reconfigure_ = functools.partialmethod(
compressed_reconfigure, inplace=True
)
    def windowed_reconfigure(
        self,
        minimize=None,
        order_only=False,
        window_size=20,
        max_iterations=100,
        max_window_tries=1000,
        score_temperature=0.0,
        queue_temperature=1.0,
        scorer=None,
        queue_scorer=None,
        seed=None,
        inplace=False,
        progbar=False,
        **kwargs,
    ):
        """Reconfigure this tree by refining windows of its contraction
        (ssa) path with a ``WindowedOptimizer``.
        Parameters
        ----------
        minimize : str or Objective, optional
            The objective to minimize, defaults to this tree's default
            objective.
        order_only, window_size, max_iterations, max_window_tries,
        score_temperature, queue_temperature, scorer, queue_scorer :
            Passed on to ``WindowedOptimizer.refine`` - see there for
            details.
        seed : int, optional
            A random seed, passed to the ``WindowedOptimizer``.
        inplace : bool, optional
            Whether to perform the reconfiguration inplace on this tree.
        progbar : bool, optional
            Whether to show a progress bar during refinement.
        kwargs
            Supplied to ``WindowedOptimizer.refine``.
        Returns
        -------
        ContractionTree
        """
        from .pathfinders.path_compressed import WindowedOptimizer
        if minimize is None:
            minimize = self.get_default_objective()
        wo = WindowedOptimizer(
            self.inputs,
            self.output,
            self.size_dict,
            minimize=minimize,
            ssa_path=self.get_ssa_path(),
            seed=seed,
        )
        wo.refine(
            window_size=window_size,
            max_iterations=max_iterations,
            order_only=order_only,
            max_window_tries=max_window_tries,
            score_temperature=score_temperature,
            queue_temperature=queue_temperature,
            scorer=scorer,
            queue_scorer=queue_scorer,
            progbar=progbar,
            **kwargs,
        )
        # rebuild a tree from the refined path
        ssa_path = wo.get_ssa_path()
        rtree = self.__class__.from_path(
            self.inputs,
            self.output,
            self.size_dict,
            ssa_path=ssa_path,
            objective=minimize,
        )
        if inplace:
            self.set_state_from(rtree)
            rtree = self
        # explicit contraction index orderings are now stale
        rtree.reset_contraction_indices()
        return rtree
windowed_reconfigure_ = functools.partialmethod(
windowed_reconfigure, inplace=True
)
def flat_tree(self, order=None):
"""Create a nested tuple representation of the contraction tree like::
((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))
Such that the contraction will progress like::
((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))
((0, 12), (34, ((5, 67), 89)))
(012, (34, (567, 89)))
(012, (34, 56789))
(012, 3456789)
0123456789
Where each integer represents a leaf (i.e. single element node).
"""
tups = dict(zip(self.gen_leaves(), range(self.N)))
for parent, l, r in self.traverse(order=order):
tups[parent] = tups[l], tups[r]
return tups[self.root]
def get_leaves_ordered(self):
"""Return the list of leaves as ordered by the contraction tree.
Returns
-------
tuple[node_type]
"""
if not self.is_complete():
raise ValueError("Can't order the leaves until tree is complete.")
return tuple(
node
for node in itertools.chain.from_iterable(self.traverse())
if self.is_leaf(node)
)
def get_path(self, order=None):
"""Generate a standard path (with linear recycled ids) from the
contraction tree.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
path: tuple[tuple[int, int]]
"""
from bisect import bisect_left
ssa = self.N
ssas = list(range(ssa))
node_to_ssa = dict(zip(self.gen_leaves(), ssas))
path = []
for parent, left, right in self.traverse(order=order):
# map nodes to ssas
lssa = node_to_ssa[left]
rssa = node_to_ssa[right]
# map ssas to linear indices, using bisection
i, j = sorted((bisect_left(ssas, lssa), bisect_left(ssas, rssa)))
# 'contract' nodes
ssas.pop(j)
ssas.pop(i)
path.append((i, j))
ssas.append(ssa)
# update mapping
node_to_ssa[parent] = ssa
ssa += 1
return tuple(path)
path = deprecated(get_path, "path", "get_path")
def get_numpy_path(self, order=None):
"""Generate a path compatible with the `optimize` kwarg of
`numpy.einsum`.
"""
return ["einsum_path", *self.get_path(order=order)]
def get_ssa_path(self, order=None):
"""Generate a single static assignment path from the contraction tree.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
ssa_path: tuple[tuple[int, int]]
"""
ssa_path = []
pos = dict(zip(self.gen_leaves(), range(self.N)))
for parent, l, r in self.traverse(order=order):
i, j = sorted((pos[l], pos[r]))
ssa_path.append((i, j))
pos[parent] = len(ssa_path) + self.N - 1
return tuple(ssa_path)
ssa_path = deprecated(get_ssa_path, "ssa_path", "get_ssa_path")
    def surface_order(self, node):
        """Default 'surface' traversal ordering key for ``node``: sort by
        extent then centrality. May be overridden per-instance by
        ``set_surface_order_from_path``.
        """
        return (self.get_extent(node), self.get_centrality(node))
def set_surface_order_from_path(self, ssa_path):
# first get dict from contractions to parents (don't usually store)
parent_map = {}
for p, l, r in self.traverse():
parent_map[frozenset([l, r])] = p
# then traverse up in given ssa_path order,
# assigning parent node ordering 'score' incrementally
parent_scores = {}
node_map = {i: n for i, n in enumerate(self.gen_leaves())}
for j, p in enumerate(ssa_path):
lr = frozenset(node_map[i] for i in p)
p = parent_map[lr]
parent_scores[p] = j
node_map[self.N + j] = p
self.surface_order = functools.partial(
get_with_default, obj=parent_scores, default=float("inf")
)
    def get_path_surface(self):
        """Get the linear contraction path using the 'surface' ordering."""
        return self.get_path(order=self.surface_order)
    # deprecated alias for ``get_path_surface``
    path_surface = deprecated(
        get_path_surface, "path_surface", "get_path_surface"
    )
    def get_ssa_path_surface(self):
        """Get the ssa contraction path using the 'surface' ordering."""
        return self.get_ssa_path(order=self.surface_order)
    # deprecated alias for ``get_ssa_path_surface``
    ssa_path_surface = deprecated(
        get_ssa_path_surface, "ssa_path_surface", "get_ssa_path_surface"
    )
    def get_spans(self):
        """Get all (which could mean none) potential embeddings of this
        contraction tree into a spanning tree of the original graph.
        Returns
        -------
        tuple[dict[frozenset[int], frozenset[int]]]
        """
        # which inputs each index appears on
        ind_to_term = collections.defaultdict(set)
        for i, term in enumerate(self.inputs):
            for ix in term:
                ind_to_term[ix].add(i)
        def boundary_pairs(node):
            """Get nodes along the boundary of the bipartition represented by
            ``node``.
            """
            pairs = set()
            involved = self.get_involved(node)
            legs = self.get_legs(node)
            # indices contracted away at this node cross the bipartition
            removed = [ix for ix in involved if ix not in legs]
            for ix in removed:
                # for every index across the contraction
                l1, l2 = ind_to_term[ix]
                # can either span from left to right or right to left
                pairs.add((l1, l2))
                pairs.add((l2, l1))
            return pairs
        # first span choice is any nodes across the top level bipart
        candidates = [
            {
                # which intermedate nodes map to which leaf nodes
                "map": {self.root: self.input_to_node(l2)},
                # the leaf nodes in the spanning tree
                "spine": {l1, l2},
            }
            for l1, l2 in boundary_pairs(self.root)
        ]
        for _, l, r in self.descend():
            for child in (r, l):
                # for each current candidate check all the possible extensions
                for _ in range(len(candidates)):
                    cand = candidates.pop(0)
                    # don't need to extend the spine for a leaf - it simply
                    # maps to itself
                    if self.is_leaf(child):
                        candidates.append(
                            {
                                "map": {child: child, **cand["map"]},
                                "spine": cand["spine"].copy(),
                            }
                        )
                    for l1, l2 in boundary_pairs(child):
                        if (l1 in cand["spine"]) or (l2 not in cand["spine"]):
                            # pair does not merge inwards into spine
                            continue
                        # valid extension of spanning tree
                        candidates.append(
                            {
                                "map": {
                                    child: self.input_to_node(l2),
                                    **cand["map"],
                                },
                                "spine": cand["spine"] | {l1, l2},
                            }
                        )
        return tuple(c["map"] for c in candidates)
def compute_centralities(self, combine="mean"):
"""Compute a centrality for every node in this contraction tree."""
hg = self.get_hypergraph(accel="auto")
cents = hg.simple_centrality()
for i, leaf in enumerate(self.gen_leaves()):
self.info[leaf]["centrality"] = cents[i]
combine = {
"mean": lambda x, y: (x + y) / 2,
"sum": lambda x, y: x + y,
"max": max,
"min": min,
}.get(combine, combine)
for p, l, r in self.traverse("dfs"):
self.info[p]["centrality"] = combine(
self.info[l]["centrality"], self.info[r]["centrality"]
)
    def get_hypergraph(self, accel=False):
        """Get a hypergraph representing the uncontracted network (i.e. the
        leaves).
        Parameters
        ----------
        accel : bool or str, optional
            Passed on to :func:`get_hypergraph`.
        """
        return get_hypergraph(self.inputs, self.output, self.size_dict, accel)
def reset_contraction_indices(self):
"""Reset all information regarding a) the explicit contraction indices
ordering and b) cached contraction expressions. This should probably be
called any time structural changes are made to the tree, e.g.
reconfiguration.
"""
# delete all derived information
# (note legs, involved, etc. are order invariant so we can keep those)
for node in self.children:
for k in (
"inds",
"einsum_eq",
"can_dot",
"tensordot_axes",
"tensordot_perm",
):
self.info[node].pop(k, None)
# invalidate any compiled contractions
self.contraction_cores.clear()
    def sort_contraction_indices(
        self,
        priority="flops",
        make_output_contig=True,
        make_contracted_contig=True,
        reset=True,
    ):
        """Set explicit orders for the contraction indices of this self to
        optimize for one of two things: contiguity in contracted ('k') indices,
        or contiguity of left and right output ('m' and 'n') indices.
        Parameters
        ----------
        priority : {'flops', 'size', 'root', 'leaves'}, optional
            Which order to process the intermediate nodes in. Later nodes
            re-sort previous nodes so are more likely to keep their ordering.
            E.g. for 'flops' the mostly costly contraction will be processed
            last and thus will be guaranteed to have its indices exactly
            sorted.
        make_output_contig : bool, optional
            When processing a pairwise contraction, sort the parent contraction
            indices so that the order of indices is the order they appear
            from left to right in the two child (input) tensors.
        make_contracted_contig : bool, optional
            When processing a pairwise contraction, sort the child (input)
            tensor indices so that all contracted indices appear contiguously.
        reset : bool, optional
            Reset all indices to the default order before sorting.
        """
        if reset:
            self.reset_contraction_indices()
        # choose the processing order of the intermediate nodes
        if priority == "flops":
            nodes = sorted(
                self.children.items(), key=lambda x: self.get_flops(x[0])
            )
        elif priority == "size":
            nodes = sorted(
                self.children.items(), key=lambda x: self.get_size(x[0])
            )
        elif priority == "root":
            nodes = ((p, (l, r)) for p, l, r in self.traverse())
        elif priority == "leaves":
            nodes = ((p, (l, r)) for p, l, r in self.descend())
        else:
            raise ValueError(priority)
        for p, (l, r) in nodes:
            p_inds, l_inds, r_inds = map(self.get_inds, (p, l, r))
            if make_output_contig and not self.is_root(p):
                # sort indices by whether they appear in the left or right
                # whether this happens before or after the sort below depends
                # on the order we are processing the nodes
                # (avoid root as don't want to modify output)
                def psort(ix):
                    # group by whether in left or right input
                    return (r_inds.find(ix), l_inds.find(ix))
                p_inds = "".join(sorted(p_inds, key=psort))
                self.info[p]["inds"] = p_inds
            if make_contracted_contig:
                # sort indices by:
                # 1. if they are going to be contracted
                # 2. what order they appear in the parent indices
                # (but ignore leaf indices)
                if not self.is_leaf(l):
                    def lsort(ix):
                        return (r_inds.find(ix), p_inds.find(ix))
                    l_inds = "".join(sorted(self.get_legs(l), key=lsort))
                    self.info[l]["inds"] = l_inds
                if not self.is_leaf(r):
                    def rsort(ix):
                        return (p_inds.find(ix), l_inds.find(ix))
                    r_inds = "".join(sorted(self.get_legs(r), key=rsort))
                    self.info[r]["inds"] = r_inds
        if not reset:
            # still need to invalidate any cached contraction expressions
            self.contraction_cores.clear()
    def print_contractions(self, sort=None, show_brackets=True):
        """Print each pairwise contraction, with colorized indices (if
        `colorama` is installed), and other information. The color codes are:
        - blue: index appears on left and is kept
        - green: index appears on right and is kept
        - red: contracted index: appears on both sides and is removed
        - pink: batch index: appears on both sides and is kept
        Any trivial indices that appear only on one term and not in the output
        are removed and shown by the preprocessing steps.
        Parameters
        ----------
        sort : {'flops', 'size'}, optional
            Sort the contractions by either the number of floating point
            operations or the size of the intermediate tensor. By default the
            contractions are shown in the order they are performed.
        show_brackets : bool, optional
            Whether to show the brackets around contiguous sections of the same
            type of indices.
        """
        try:
            from colorama import Fore
            RESET = Fore.RESET
            GREY = Fore.WHITE
            PINK = Fore.MAGENTA
            RED = Fore.RED
            BLUE = Fore.BLUE
            GREEN = Fore.GREEN
        except ImportError:
            # colorama not installed - print without colors
            RESET = GREY = PINK = RED = BLUE = GREEN = ""
        entries = []
        if self.has_preprocessing():
            for pi, eq in self.preprocessing.items():
                # eq is with canonical indices, reinsert original inputs
                replacer = dict(zip(eq.split("->")[0], self.inputs[pi]))
                eq = "".join(replacer.get(c, c) for c in eq)
                print(f"{GREY}preprocess input {pi}: {RESET}{eq}")
            print()
        for i, (p, l, r) in enumerate(self.traverse()):
            p_legs, l_legs, r_legs = map(self.get_legs, [p, l, r])
            p_inds, l_inds, r_inds = map(self.get_inds, [p, l, r])
            # print sizes and flops
            p_flops = self.get_flops(p)
            p_sz, l_sz, r_sz = (
                math.log2(self.get_size(node)) for node in [p, l, r]
            )
            # print whether tensordottable
            if self.get_can_dot(p):
                type_msg = "tensordot"
                perm = self.get_tensordot_perm(p)
                if perm is not None:
                    # and whether indices match tensordot
                    type_msg += "+perm"
            else:
                type_msg = "einsum"
            # bracket styles: () kept, [] contracted, {} batch (hyper)
            kpt_brck_l = "(" if show_brackets else ""
            kpt_brck_r = ")" if show_brackets else ""
            con_brck_l = "[" if show_brackets else ""
            con_brck_r = "]" if show_brackets else ""
            hyp_brck_l = "{" if show_brackets else ""
            hyp_brck_r = "}" if show_brackets else ""
            # colorize the parent (output) indices, merging adjacent brackets
            pa = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in l_legs) and (ix in r_legs)
                    else GREEN + f"{kpt_brck_l}{ix}{kpt_brck_r}"
                    if ix in r_legs
                    else BLUE + ix
                    for ix in p_inds
                )
                .replace(f"){GREEN}(", "")
                .replace(f"}}{PINK}{{", "")
            )
            # colorize the left input indices
            la = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in p_legs) and (ix in r_legs)
                    else RED + f"{con_brck_l}{ix}{con_brck_r}"
                    if ix in r_legs
                    else BLUE + ix
                    for ix in l_inds
                )
                .replace(f"]{RED}[", "")
                .replace(f"}}{PINK}{{", "")
            )
            # colorize the right input indices
            ra = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in p_legs) and (ix in l_legs)
                    else RED + f"{con_brck_l}{ix}{con_brck_r}"
                    if ix in l_legs
                    else GREEN + ix
                    for ix in r_inds
                )
                .replace(f"]{RED}[", "")
                .replace(f"}}{PINK}{{", "")
            )
            entries.append(
                (
                    p,
                    f"{GREY}({i}) cost: {RESET}{p_flops:.1e} "
                    f"{GREY}widths: {RESET}{l_sz:.1f},{r_sz:.1f}->{p_sz:.1f} "
                    f"{GREY}type: {RESET}{type_msg}\n"
                    f"{GREY}inputs: {la},{ra}{RESET}->\n"
                    f"{GREY}output: {pa}\n",
                )
            )
        if sort == "flops":
            entries.sort(key=lambda x: self.get_flops(x[0]), reverse=True)
        if sort == "size":
            entries.sort(key=lambda x: self.get_size(x[0]), reverse=True)
        # final entry resets the terminal color
        entries.append((None, f"{RESET}"))
        o = "\n".join(entry for _, entry in entries)
        print(o)
# --------------------- Performing the Contraction ---------------------- #
def get_contractor(
self,
order=None,
prefer_einsum=False,
strip_exponent=False,
check_zero=False,
implementation=None,
autojit=False,
progbar=False,
):
"""Get a reusable function which performs the contraction corresponding
to this tree, cached.
Parameters
----------
tree : ContractionTree
The contraction tree.
order : str or callable, optional
Supplied to :meth:`ContractionTree.traverse`, the order in which
to perform the pairwise contractions given by the tree.
prefer_einsum : bool, optional
Prefer to use ``einsum`` for pairwise contractions, even if
``tensordot`` can perform the contraction.
strip_exponent : bool, optional
If ``True``, the function will eagerly strip the exponent (in
log10) from intermediate tensors to control numerical problems from
leaving the range of the datatype. This method then returns the
scaled 'mantissa' output array and the exponent separately.
check_zero : bool, optional
If ``True``, when ``strip_exponent=True``, explicitly check for
zero-valued intermediates that would otherwise produce ``nan``,
instead terminating early if encountered and returning
``(0.0, 0.0)``.
implementation : str or tuple[callable, callable], optional
What library to use to actually perform the contractions. Options
are:
- None: let cotengra choose.
- "autoray": dispatch with autoray, using the ``tensordot`` and
``einsum`` implementation of the backend.
- "cotengra": use the ``tensordot`` and ``einsum`` implementation
of cotengra, which is based on batch matrix multiplication. This
is faster for some backends like numpy, and also enables
libraries which don't yet provide ``tensordot`` and ``einsum`` to
be used.
- "cuquantum": use the cuquantum library to perform the whole
contraction (not just individual contractions).
- tuple[callable, callable]: manually supply the ``tensordot`` and
``einsum`` implementations to use.
autojit : bool, optional
If ``True``, use :func:`autoray.autojit` to compile the contraction
function.
progbar : bool, optional
Whether to show progress through the contraction by default.
Returns
-------
fn : callable
The contraction function, with signature ``fn(*arrays)``.
"""
key = (
autojit,
order,
prefer_einsum,
strip_exponent,
check_zero,
implementation,
progbar,
)
try:
fn = self.contraction_cores[key]
except KeyError:
fn = self.contraction_cores[key] = make_contractor(
tree=self,
order=order,
prefer_einsum=prefer_einsum,
strip_exponent=strip_exponent,
check_zero=check_zero,
implementation=implementation,
autojit=autojit,
progbar=progbar,
)
return fn
    def contract_core(
        self,
        arrays,
        order=None,
        prefer_einsum=False,
        strip_exponent=False,
        check_zero=False,
        backend=None,
        implementation=None,
        autojit="auto",
        progbar=False,
    ):
        """Contract ``arrays`` with this tree. The order of the axes and
        output is assumed to be that of ``tree.inputs`` and ``tree.output``,
        but with sliced indices removed. This functon contracts the core tree
        and thus if indices have been sliced the arrays supplied need to be
        sliced as well.
        Parameters
        ----------
        arrays : sequence of array
            The arrays to contract.
        order : str or callable, optional
            Supplied to :meth:`ContractionTree.traverse`.
        prefer_einsum : bool, optional
            Prefer to use ``einsum`` for pairwise contractions, even if
            ``tensordot`` can perform the contraction.
        strip_exponent : bool, optional
            Whether to eagerly strip exponents from intermediates, passed on
            to :meth:`get_contractor`.
        check_zero : bool, optional
            Passed on to :meth:`get_contractor`.
        backend : str, optional
            What library to use for ``einsum`` and ``transpose``, will be
            automatically inferred from the arrays if not given.
        implementation : str or tuple[callable, callable], optional
            Passed on to :meth:`get_contractor`.
        autojit : "auto" or bool, optional
            Whether to use ``autoray.autojit`` to jit compile the expression.
            If "auto", then let ``cotengra`` choose.
        progbar : bool, optional
            Show progress through the contraction.
        """
        if autojit == "auto":
            # choose for the user
            autojit = backend == "jax"
        fn = self.get_contractor(
            order=order,
            # any non-False value enables exponent stripping in the cached fn
            prefer_einsum=prefer_einsum,
            strip_exponent=strip_exponent is not False,
            implementation=implementation,
            autojit=autojit,
            check_zero=check_zero,
            progbar=progbar,
        )
        return fn(*arrays, backend=backend)
def slice_key(self, i, strides=None):
"""Get the combination of sliced index values for overall slice ``i``.
Parameters
----------
i : int
The overall slice index.
Returns
-------
key : dict[str, int]
The value each sliced index takes for slice ``i``.
"""
if strides is None:
strides = get_slice_strides(self.sliced_inds)
key = {}
for (ind, info), stride in zip(self.sliced_inds.items(), strides):
if info.project is None:
key[ind] = i // stride
i %= stride
else:
# size is 1 and i doesn't change
key[ind] = info.project
return key
def slice_arrays(self, arrays, i):
"""Take ``arrays`` and slice the relevant inputs according to
``tree.sliced_inds`` and the dynary representation of ``i``.
"""
temp_arrays = list(arrays)
# e.g. {'a': 2, 'd': 7, 'z': 0}
locations = self.slice_key(i)
for c in self.sliced_inputs:
# the indexing object, e.g. [:, :, 7, :, 2, :, :, 0]
selector = tuple(
locations.get(ix, slice(None)) for ix in self.inputs[c]
)
# re-insert the sliced array
temp_arrays[c] = temp_arrays[c][selector]
return temp_arrays
    def contract_slice(self, arrays, i, **kwargs):
        """Get slices ``i`` of ``arrays`` and then contract them.
        Parameters
        ----------
        arrays : sequence of array
            The full (unsliced) arrays.
        i : int
            The overall slice index.
        kwargs
            Supplied to :meth:`contract_core`.
        """
        return self.contract_core(self.slice_arrays(arrays, i), **kwargs)
    def gather_slices(self, slices, backend=None, progbar=False):
        """Gather all the output contracted slices into a single full result.
        If none of the sliced indices appear in the output, then this is a
        simple sum - otherwise the slices need to be partially summed and
        partially stacked.
        """
        if progbar:
            import tqdm
            slices = tqdm.tqdm(slices, total=self.multiplicity)
        # position of each sliced index within the output, if it appears
        output_pos = {
            ix: i for i, ix in enumerate(self.output) if ix in self.sliced_inds
        }
        add_maybe_exponent_stripped = AdderWithMaybeExponentStripped()
        if not output_pos:
            # we can just sum everything
            return functools.reduce(add_maybe_exponent_stripped, slices)
        # first we sum over non-output sliced indices
        chunks = {}
        for i, s in enumerate(slices):
            key_slice = self.slice_key(i)
            key = tuple(key_slice[ix] for ix in output_pos)
            try:
                chunks[key] = add_maybe_exponent_stripped(chunks[key], s)
            except KeyError:
                # first slice seen for this output chunk
                chunks[key] = s
        if isinstance(next(iter(chunks.values())), tuple):
            # have stripped exponents, need to scale to largest
            emax = max(v[1] for v in chunks.values())
            chunks = {
                k: mi * 10 ** (ei - emax) for k, (mi, ei) in chunks.items()
            }
        else:
            emax = None
        # then we stack these summed chunks over output sliced indices
        def recursively_stack_chunks(loc, remaining):
            if not remaining:
                return chunks[loc]
            arrays = [
                recursively_stack_chunks(loc + (d,), remaining[1:])
                for d in self.sliced_inds[remaining[0]].sliced_range
            ]
            axes = output_pos[remaining[0]] - len(loc)
            return do("stack", arrays, axes, like=backend)
        result = recursively_stack_chunks((), tuple(output_pos))
        if emax is not None:
            # strip_exponent was True, return the exponent separately
            return result, emax
        return result
    def gen_output_chunks(
        self, arrays, with_key=False, progbar=False, **contract_opts
    ):
        """Generate each output chunk of the contraction - i.e. take care of
        summing internally sliced indices only first. This assumes that the
        ``sliced_inds`` are sorted by whether they appear in the output or not
        (the default order). Useful for performing some kind of reduction over
        the final tensor object like ``fn(x).sum()`` without constructing the
        entire thing.
        Parameters
        ----------
        arrays : sequence of array
            The arrays to contract.
        with_key : bool, optional
            Whether to yield the output index configuration key along with the
            chunk.
        progbar : bool, optional
            Show progress through the contraction chunks.
        Yields
        ------
        chunk : array
            A chunk of the contracted result.
        key : dict[str, int]
            The value each sliced output index takes for this chunk.
        """
        # consecutive slices of size ``stepsize`` all belong to the same output
        # block because the sliced indices are sorted output first
        stepsize = prod(
            si.size for si in self.sliced_inds.values() if si.inner
        )
        if progbar:
            import tqdm
            it = tqdm.trange(self.nslices // stepsize)
        else:
            it = range(self.nslices // stepsize)
        for o in it:
            # start with the first slice of this output block...
            chunk = self.contract_slice(arrays, o * stepsize, **contract_opts)
            if with_key:
                output_key = {
                    ix: x
                    for ix, x in self.slice_key(o * stepsize).items()
                    if ix in self.output
                }
            # ...then sum in the remaining (inner) slices of the block
            for j in range(1, stepsize):
                i = o * stepsize + j
                chunk = chunk + self.contract_slice(arrays, i, **contract_opts)
            if with_key:
                yield chunk, output_key
            else:
                yield chunk
    def contract(
        self,
        arrays,
        order=None,
        prefer_einsum=False,
        strip_exponent=False,
        check_zero=False,
        backend=None,
        implementation=None,
        autojit="auto",
        progbar=False,
    ):
        """Contract ``arrays`` with this tree. This function takes *unsliced*
        arrays and handles the slicing, contractions and gathering. The order
        of the axes and output is assumed to match that of ``tree.inputs`` and
        ``tree.output``.
        Parameters
        ----------
        arrays : sequence of array
            The arrays to contract.
        order : str or callable, optional
            Supplied to :meth:`ContractionTree.traverse`.
        prefer_einsum : bool, optional
            Prefer to use ``einsum`` for pairwise contractions, even if
            ``tensordot`` can perform the contraction.
        strip_exponent : bool, optional
            If ``True``, eagerly strip the exponent (in log10) from
            intermediate tensors to control numerical problems from leaving the
            range of the datatype. This method then returns the scaled
            'mantissa' output array and the exponent separately.
        check_zero : bool, optional
            If ``True``, when ``strip_exponent=True``, explicitly check for
            zero-valued intermediates that would otherwise produce ``nan``,
            instead terminating early if encountered and returning
            ``(0.0, 0.0)``.
        backend : str, optional
            What library to use for ``tensordot``, ``einsum`` and
            ``transpose``, it will be automatically inferred from the input
            arrays if not given.
        implementation : str or tuple[callable, callable], optional
            Passed on to :meth:`contract_core`.
        autojit : bool, optional
            Whether to use the 'autojit' feature of `autoray` to compile the
            contraction expression.
        progbar : bool, optional
            Whether to show a progress bar.
        Returns
        -------
        output : array
            The contracted output, it will be scaled if
            ``strip_exponent==True``.
        exponent : float
            The exponent of the output in base 10, returned only if
            ``strip_exponent==True``.
        See Also
        --------
        contract_core, contract_slice, slice_arrays, gather_slices
        """
        if not self.sliced_inds:
            # no slicing: contract the core tree directly
            return self.contract_core(
                arrays,
                order=order,
                prefer_einsum=prefer_einsum,
                strip_exponent=strip_exponent,
                check_zero=check_zero,
                backend=backend,
                implementation=implementation,
                autojit=autojit,
                progbar=progbar,
            )
        # lazily contract each slice, then gather into the full result
        slices = (
            self.contract_slice(
                arrays,
                i,
                order=order,
                prefer_einsum=prefer_einsum,
                strip_exponent=strip_exponent,
                check_zero=check_zero,
                backend=backend,
                implementation=implementation,
                autojit=autojit,
            )
            for i in range(self.multiplicity)
        )
        return self.gather_slices(slices, backend=backend, progbar=progbar)
    def contract_mpi(self, arrays, comm=None, root=None, **kwargs):
        """Contract the slices of this tree and sum them in parallel -
        *assuming* we are already running under MPI.
        Parameters
        ----------
        arrays : sequence of array
            The input (unsliced arrays)
        comm : None or mpi4py communicator
            Defaults to ``mpi4py.MPI.COMM_WORLD`` if not given.
        root : None or int, optional
            If ``root=None``, an ``Allreduce`` will be performed such that
            every process has the resulting tensor, else if an integer e.g.
            ``root=0``, the result will be exclusively gathered to that
            process using ``Reduce``, with every other process returning
            ``None``.
        kwargs
            Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.
        """
        # only a plain sum of slices is supported - output sliced indices
        # would require stacking across processes
        if not set(self.sliced_inds).isdisjoint(set(self.output)):
            raise NotImplementedError(
                "Sliced and output indices overlap - currently only a simple "
                "sum of result slices is supported currently."
            )
        if comm is None:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
        if self.multiplicity < comm.size:
            raise ValueError(
                f"Need to have more slices than MPI processes, but have "
                f"{self.multiplicity} and {comm.size} respectively."
            )
        # round robin compute each slice, eagerly summing
        result_i = None
        for i in range(comm.rank, self.multiplicity, comm.size):
            # note: fortran ordering is needed for the MPI reduce
            x = do("asfortranarray", self.contract_slice(arrays, i, **kwargs))
            if result_i is None:
                result_i = x
            else:
                result_i += x
        if root is None:
            # everyone gets the summed result
            result = do("empty_like", result_i)
            comm.Allreduce(result_i, result)
            return result
        # else we only sum reduce the result to process ``root``
        if comm.rank == root:
            result = do("empty_like", result_i)
        else:
            result = None
        comm.Reduce(result_i, result, root=root)
        return result
def benchmark(
self,
dtype="float64",
max_time=60,
min_reps=3,
max_reps=100,
warmup=True,
**contract_opts,
):
"""Benchmark the contraction of this tree.
Parameters
----------
dtype : {"float32", "float64", "complex64", "complex128"}
The datatype to use.
max_time : float, optional
The maximum time to spend benchmarking in seconds.
min_reps : int, optional
The minimum number of repetitions to perform, regardless of time.
max_reps : int, optional
The maximum number of repetitions to perform, regardless of time.
warmup : bool or int, optional
Whether to perform a warmup run before the benchmark. If an int,
the number of warmup runs to perform.
contract_opts
Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.
Returns
-------
dict
A dictionary of benchmarking results. The keys are:
- "time_per_slice" : float
The average time to contract a single slice.
- "est_time_total" : float
The estimated total time to contract all slices.
- "est_gigaflops" : float
The estimated gigaflops of the contraction.
See Also
--------
contract_slice
"""
import time
from .utils import make_arrays_from_inputs
arrays = make_arrays_from_inputs(
self.inputs, self.size_dict, dtype=dtype
)
for i in range(int(warmup)):
self.contract_slice(arrays, i % self.nslices, **contract_opts)
t0 = time.time()
ti = t0
i = 0
while (ti - t0 < max_time) or (i < min_reps):
self.contract_slice(arrays, i % self.nslices, **contract_opts)
ti = time.time()
i += 1
if i >= max_reps:
break
time_per_slice = (ti - t0) / i
est_time_total = time_per_slice * self.nslices
est_gigaflops = self.total_flops(dtype=dtype) / (1e9 * est_time_total)
return {
"time_per_slice": time_per_slice,
"est_time_total": est_time_total,
"est_gigaflops": est_gigaflops,
}
plot_ring = plot_tree_ring
plot_tent = plot_tree_tent
plot_span = plot_tree_span
plot_flat = plot_tree_flat
plot_circuit = plot_tree_circuit
plot_rubberband = plot_tree_rubberband
plot_contractions = plot_contractions
plot_contractions_alt = plot_contractions_alt
    @functools.wraps(plot_hypergraph)
    def plot_hypergraph(self, **kwargs):
        # plot this tree's underlying hypergraph (docstring/signature info
        # is copied from the module-level ``plot_hypergraph`` by ``wraps``)
        hg = self.get_hypergraph(accel=False)
        hg.plot(**kwargs)
def describe(self, info="normal", join=" "):
"""Return a string describing the contraction tree."""
self.contract_stats()
if info == "normal":
return join.join(
(
f"log10[FLOPs]={self.total_flops(log=10):.2f}",
f"log2[SIZE]={self.max_size(log=2):.2f}",
)
)
elif info == "full":
s = [
f"log10[FLOPS]={self.total_flops(log=10):.2f}",
f"log10[COMBO]={self.combo_cost(log=10):.2f}",
f"log2[SIZE]={self.max_size(log=2):.2f}",
f"log2[PEAK]={self.peak_size(log=2):.2f}",
]
if self.sliced_inds:
s.append(f"NSLICES={self.multiplicity:.2f}")
return join.join(s)
elif info == "concise":
s = [
f"F={self.total_flops(log=10):.2f}",
f"C={self.combo_cost(log=10):.2f}",
f"S={self.max_size(log=2):.2f}",
f"P={self.peak_size(log=2):.2f}",
]
if self.sliced_inds:
s.append(f"$={self.multiplicity:.2f}")
return join.join(s)
def __repr__(self):
if self.is_complete():
return f"<{self.__class__.__name__}(N={self.N})>"
else:
s = "<{}(N={}, branches={}, complete={})>"
return s.format(
self.__class__.__name__,
self.N,
len(self.children),
self.is_complete(),
)
def __str__(self):
if not self.is_complete():
return self.__repr__()
else:
d = self.describe("concise", join=", ")
return f"<{self.__class__.__name__}(N={self.N}, {d})>"
def _reconfigure_tree(tree, *args, **kwargs):
    # module level helper (e.g. so it can be pickled for parallel calls) that
    # simply forwards to ``tree.subtree_reconfigure``
    return tree.subtree_reconfigure(*args, **kwargs)
def _slice_and_reconfigure_tree(tree, *args, **kwargs):
    # module level helper (e.g. so it can be pickled for parallel calls) that
    # simply forwards to ``tree.slice_and_reconfigure``
    return tree.slice_and_reconfigure(*args, **kwargs)
def _get_tree_info(tree):
stats = tree.contract_stats()
stats["sliced_ind_set"] = frozenset(tree.sliced_inds)
return stats
def _describe_tree(tree, info="normal"):
    # module level helper (e.g. so it can be pickled for parallel calls) that
    # simply forwards to ``tree.describe``
    return tree.describe(info=info)
class ContractionTreeCompressed(ContractionTree):
    """A contraction tree for compressed contractions. Currently the only
    difference is that this defaults to the 'surface' traversal ordering.
    """

    def set_state_from(self, other):
        """Set the state of this tree from another tree, also copying over
        the surface traversal ordering implied by the other's ssa_path.
        """
        super().set_state_from(other)
        self.set_surface_order_from_path(other.get_ssa_path())

    @classmethod
    def from_path(
        cls,
        inputs,
        output,
        size_dict,
        *,
        path=None,
        ssa_path=None,
        autocomplete="auto",
        check=False,
        **kwargs,
    ):
        """Create a (completed) ``ContractionTreeCompressed`` from the usual
        inputs plus a standard contraction path or 'ssa_path' - you need to
        supply one. This also set the default 'surface' traversal ordering to
        be the initial path.

        Raises
        ------
        ValueError
            If neither or both of ``path`` and ``ssa_path`` are supplied.
        """
        if int(path is None) + int(ssa_path is None) != 1:
            raise ValueError(
                "Exactly one of ``path`` or ``ssa_path`` must be supplied."
            )

        if path is not None:
            from .pathfinders.path_basic import linear_to_ssa

            ssa_path = linear_to_ssa(path)

        tree = super().from_path(
            inputs,
            output,
            size_dict,
            ssa_path=ssa_path,
            autocomplete=False,
            check=check,
            **kwargs,
        )
        # remember the supplied path as the surface traversal ordering
        tree.set_surface_order_from_path(ssa_path)

        if (len(tree.children) < tree.N - 1) and autocomplete:
            if autocomplete == "auto":
                # warn that we are completing
                # (note trailing space after "`autocomplete=True`." so the
                # concatenated message reads correctly)
                warnings.warn(
                    "Path was not complete - contracting all remaining. "
                    "You can silence this warning with `autocomplete=True`. "
                    "Or produce an incomplete tree with `autocomplete=False`."
                )
            tree.autocomplete(optimize="greedy-compressed")

        return tree

    def get_default_order(self):
        # compressed contractions must follow the fixed 'surface' ordering
        return "surface_order"

    def get_default_objective(self):
        """Get the objective (score) function, lazily defaulting to
        compressed peak size minimization.
        """
        if self._default_objective is None:
            self._default_objective = get_score_fn("peak-compressed")
        return self._default_objective

    def get_default_chi(self):
        """Get the default bond dimension to compress to: taken from the
        objective if it specifies one, else ('auto') the square of the
        largest bond dimension appearing in ``size_dict``.
        """
        objective = self.get_default_objective()
        try:
            chi = objective.chi
        except AttributeError:
            # objective doesn't carry a chi attribute at all
            chi = "auto"

        if chi == "auto":
            chi = max(self.size_dict.values()) ** 2

        return chi

    def get_default_compress_late(self):
        """Whether compression should happen 'late', taken from the objective
        if it specifies so, defaulting to False.
        """
        objective = self.get_default_objective()
        try:
            return objective.compress_late
        except AttributeError:
            return False

    # for compressed trees the default cost/size methods are the compressed
    # estimates, with the exact (uncompressed) versions kept under ``*_exact``
    total_flops = ContractionTree.total_flops_compressed
    total_write = ContractionTree.total_write_compressed
    combo_cost = ContractionTree.combo_cost_compressed
    total_cost = ContractionTree.total_cost_compressed
    max_size = ContractionTree.max_size_compressed
    peak_size = ContractionTree.peak_size_compressed
    contraction_cost = ContractionTree.contraction_cost_compressed
    contraction_width = ContractionTree.contraction_width_compressed

    total_flops_exact = ContractionTree.total_flops
    total_write_exact = ContractionTree.total_write
    combo_cost_exact = ContractionTree.combo_cost
    total_cost_exact = ContractionTree.total_cost
    max_size_exact = ContractionTree.max_size
    peak_size_exact = ContractionTree.peak_size

    def get_contractor(self, *_, **__):
        """Compressed contraction is not implemented in ``cotengra`` itself.

        Raises
        ------
        NotImplementedError
            Always - use ``quimb`` for actually performing the contraction.
        """
        raise NotImplementedError(
            "`cotengra` doesn't implement compressed contraction itself. "
            "If you want to use compressed contractions, you need to use "
            "`quimb` and the `TensorNetwork.contract_compressed` method, "
            "with e.g. `optimize=tree.get_path()`."
        )

    def simulated_anneal(
        self,
        minimize=None,
        tfinal=0.0001,
        tstart=0.01,
        tsteps=50,
        numiter=50,
        seed=None,
        inplace=False,
        progbar=False,
        **kwargs,
    ):
        """Perform simulated annealing refinement of this *compressed*
        contraction tree.

        Parameters
        ----------
        minimize : None or Objective, optional
            The objective to minimize, defaulting to this tree's default.
        tfinal, tstart : float, optional
            Final and starting annealing temperatures.
        tsteps : int, optional
            Number of temperature steps.
        numiter : int, optional
            Number of iterations per temperature step.
        seed : None or int, optional
            Random seed.
        inplace : bool, optional
            Whether to update this tree in place.
        progbar : bool, optional
            Whether to show a progress bar.
        kwargs
            Passed on to ``WindowedOptimizer.simulated_anneal``.

        Returns
        -------
        ContractionTreeCompressed
        """
        from .pathfinders.path_compressed import WindowedOptimizer

        if minimize is None:
            minimize = self.get_default_objective()

        wo = WindowedOptimizer(
            self.inputs,
            self.output,
            self.size_dict,
            minimize=minimize,
            ssa_path=self.get_ssa_path(),
            seed=seed,
        )

        wo.simulated_anneal(
            tfinal=tfinal,
            tstart=tstart,
            tsteps=tsteps,
            numiter=numiter,
            progbar=progbar,
            **kwargs,
        )

        # rebuild a fresh tree from the annealed path
        ssa_path = wo.get_ssa_path()
        rtree = self.__class__.from_path(
            self.inputs,
            self.output,
            self.size_dict,
            ssa_path=ssa_path,
            objective=minimize,
        )

        if inplace:
            self.set_state_from(rtree)
            rtree = self

        rtree.reset_contraction_indices()

        return rtree

    simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
class PartitionTreeBuilder:
    """Function wrapper that takes a function that partitions graphs and
    uses it to build a contraction tree. ``partition_fn`` should have
    signature:

        def partition_fn(inputs, output, size_dict,
                         weight_nodes, weight_edges, **kwargs):
            ...
            return membership

    Where ``weight_nodes`` and ``weight_edges`` decsribe how to weight the
    nodes and edges of the graph respectively and ``membership`` should be a
    list of integers of length ``len(inputs)`` labelling which partition
    each input node should be put it.
    """

    def __init__(self, partition_fn):
        # the graph partitioning function used to recursively divide or
        # agglomerate the network
        self.partition_fn = partition_fn

    def build_divide(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        cutoff=10,
        parts=2,
        parts_decay=0.5,
        sub_optimize="greedy",
        super_optimize="random-greedy-128",
        check=False,
        seed=None,
        **partition_opts,
    ):
        """Build a contraction tree 'top down', by recursively partitioning
        the graph into communities and contracting each community's nodes.
        Subgraphs of size ``cutoff`` or less are handed to ``sub_optimize``
        directly.
        """
        tree = ContractionTree(
            inputs,
            output,
            size_dict,
            track_childless=True,
        )
        rng = get_rng(seed)
        # jitter the edge weights slightly to break partitioner ties
        rand_size_dict = jitter_dict(size_dict, random_strength, rng)

        dynamic_imbalance = ("imbalance" in partition_opts) and (
            "imbalance_decay" in partition_opts
        )
        if dynamic_imbalance:
            imbalance = partition_opts.pop("imbalance")
            imbalance_decay = partition_opts.pop("imbalance_decay")
        else:
            imbalance = imbalance_decay = None

        dynamic_fix = partition_opts.get("fix_output_nodes", None) == "auto"

        while tree.childless:
            # take any remaining node that has no children yet and partition
            # its subgraph further
            top_node = next(iter(tree.childless))
            subgraph = tree.get_subgraph(top_node)
            subsize = len(subgraph)

            # skip straight to better method
            if subsize <= cutoff:
                tree.contract_nodes(
                    [tree.input_to_node(x) for x in subgraph],
                    grandparent=top_node,
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            # relative subgraph size
            s = subsize / tree.N

            # let the target number of communities depend on subgraph size
            parts_s = max(int(s**parts_decay * parts), 2)

            # let the imbalance either rise or fall
            if dynamic_imbalance:
                if imbalance_decay >= 0:
                    imbalance_s = s**imbalance_decay * imbalance
                else:
                    imbalance_s = 1 - s**-imbalance_decay * (1 - imbalance)
                partition_opts["imbalance"] = imbalance_s

            if dynamic_fix:
                # for the top level subtree (s==1.0) we partition the outputs
                # nodes first into their own bi-partition
                parts_s = 2
                partition_opts["fix_output_nodes"] = s == 1.0

            # partition! get community membership list e.g.
            # [0, 0, 1, 0, 1, 0, 0, 2, 2, ...]
            inputs = tuple(map(tuple, tree.node_to_terms(top_node)))
            output = tuple(tree.get_legs(top_node))
            membership = self.partition_fn(
                inputs,
                output,
                rand_size_dict,
                parts=parts_s,
                seed=rng,
                **partition_opts,
            )

            # divide subgraph up e.g. if we enumerate the subgraph index sets
            # (0, 1, 2, 3, 4, 5, 6, 7, 8, ...) ->
            # ({0, 1, 3, 5, 6}, {2, 4}, {7, 8})
            partitions = separate(subgraph, membership)

            if len(partitions) == 1:
                # no communities found - contract all remaining leaves
                tree.contract_nodes(
                    tuple(map(tree.input_to_node, subgraph)),
                    grandparent=top_node,
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            tree.contract_nodes(
                partitions,
                grandparent=top_node,
                optimize=super_optimize,
                check=check,
            )

        if check:
            assert tree.is_complete()

        return tree

    def build_agglom(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        groupsize=4,
        sub_optimize="greedy",
        check=False,
        seed=None,
        **partition_opts,
    ):
        """Build a contraction tree 'bottom up', by repeatedly partitioning
        the current leaves into groups of approximately ``groupsize`` and
        contracting each group into a new effective leaf.
        """
        tree = ContractionTree(
            inputs,
            output,
            size_dict,
            track_childless=True,
        )
        # jitter the edge weights slightly to break partitioner ties
        rand_size_dict = jitter_dict(size_dict, random_strength, seed)
        leaves = tuple(tree.gen_leaves())
        output = tuple(tree.output)

        while len(leaves) > groupsize:
            # choose number of partitions so that each
            # has approximately ``groupsize`` nodes
            parts = max(2, len(leaves) // groupsize)

            # partition! get community membership list
            inputs = [tuple(tree.get_legs(node)) for node in leaves]
            membership = self.partition_fn(
                inputs,
                output,
                rand_size_dict,
                parts=parts,
                **partition_opts,
            )

            # group leaves according to partition label
            partitions = separate(leaves, membership)

            if len(partitions) == 1:
                # only found one group, move to final contraction
                break

            # contract each group into a new leaf
            leaves = [
                tree.contract_nodes(
                    partition,
                    check=check,
                    optimize=sub_optimize,
                )
                for partition in partitions
            ]

        if len(leaves) > 1:
            # contract any remaining leaves together
            tree.contract_nodes(
                leaves,
                check=check,
                optimize=sub_optimize,
                grandparent=tree.root,
            )

        if check:
            assert tree.is_complete()

        return tree

    def trial_fn(self, inputs, output, size_dict, **partition_opts):
        # entry point matching the hyper-optimizer trial signature, using
        # the divisive ('top down') builder
        return self.build_divide(inputs, output, size_dict, **partition_opts)

    def trial_fn_agglom(self, inputs, output, size_dict, **partition_opts):
        # entry point matching the hyper-optimizer trial signature, using
        # the agglomerative ('bottom up') builder
        return self.build_agglom(inputs, output, size_dict, **partition_opts)
def jitter(x, strength, rng):
    """Return ``x`` scaled by a random factor ``1 + strength * E``, where
    ``E`` is drawn from an exponential distribution (rate 1) using ``rng``.
    """
    factor = 1 + strength * rng.expovariate(1.0)
    return x * factor
def jitter_dict(d, strength, seed=None):
    """Return a copy of dict ``d`` with every value multiplicatively
    jittered by random strength ``strength``, seeded by ``seed``.
    """
    rng = get_rng(seed)
    return {key: jitter(val, strength, rng) for key, val in d.items()}
def separate(xs, blocks):
    """Partition ``xs`` into different lists based on the corresponding
    labels in ``blocks``, returned in sorted label order.
    """
    groups = collections.defaultdict(list)
    for item, label in zip(xs, blocks):
        groups[label].append(item)
    # one list per label, ordered by label
    return [groups[label] for label in sorted(groups)]
================================================
FILE: cotengra/core_multi.py
================================================
import math
from .core import ContractionTree, cached_node_property
class ContractionTreeMulti(ContractionTree):
    """A contraction tree which treats ``sliced_inds`` as variable indices,
    tracking which intermediate nodes depend on which of them and estimating
    the size of a cache of reusable intermediates.
    """

    def __init__(
        self,
        inputs,
        output,
        size_dict,
        sliced_inds,
        objective,
        track_cache=False,
    ):
        super().__init__(inputs, output, size_dict, objective=objective)
        # dict used as an ordered set of the variable ('sliced') indices
        self.sliced_inds = {ix: None for ix in sliced_inds}
        self._track_cache = track_cache
        if track_cache:
            # running estimate of the total cache size
            self._cache_est = 0

    def set_state_from(self, other):
        # copy the cache tracking state alongside the base tree state
        super().set_state_from(other)
        self._track_cache = other._track_cache
        if other._track_cache:
            self._cache_est = other._cache_est

    def _remove_node(self, node):
        # keep the running cache estimate in sync when nodes are removed
        if self._track_cache:
            self._cache_est -= self.get_cache_contrib(node)
        super()._remove_node(node)

    def _update_tracked(self, node):
        # keep the running cache estimate in sync when nodes are added
        if self._track_cache:
            self._cache_est += self.get_cache_contrib(node)
        super()._update_tracked(node)

    @cached_node_property("node_var_inds")
    def get_node_var_inds(self, node):
        """Get the set of variable indices that a node depends on."""
        if self.is_leaf(node):
            i = self.node_to_input(node)
            term = self.inputs[i]
            # dict-as-ordered-set of this leaf's sliced indices
            return {ix: None for ix in term if ix in self.sliced_inds}

        try:
            # union of children's variable indices
            l, r = self.children[node]
            return self.get_node_var_inds(l) | self.get_node_var_inds(r)
        except KeyError:
            # node has no children recorded yet - gather directly from its
            # constituent terms
            return {
                ix: None
                for term in self.node_to_terms(node)
                for ix in term
                if ix in self.sliced_inds
            }

    @cached_node_property("node_is_bright")
    def get_node_is_bright(self, node):
        """Get whether a node is 'bright', i.e. contains a different set of
        variable indices to either of its children, if a node is not bright
        then its children never have to be stored in the cache.
        """
        if self.is_leaf(node):
            i = self.node_to_input(node)
            term = self.inputs[i]
            return any(ix in self.sliced_inds for ix in term)

        l, r = self.children[node]
        return (self.get_node_var_inds(node) != self.get_node_var_inds(l)) or (
            self.get_node_var_inds(node) != self.get_node_var_inds(r)
        )

    @cached_node_property("node_mult")
    def get_node_mult(self, node):
        """Get the estimated 'multiplicity' of a node, i.e. the number of times
        it will have to be recomputed for different index configurations.
        """
        return self.get_default_objective().estimate_node_mult(self, node)

    def get_node_cache_mult(self, node, sliced_ind_ordering):
        """Get the estimated 'cache multiplicity' of a node, i.e. the total
        number of versions with different index configurations that must be
        stored simultaneously in the cache.
        """
        return self.get_default_objective().estimate_node_cache_mult(
            self, node, sliced_ind_ordering
        )

    # @cached_node_property("multi_flops")
    def get_flops(self, node):
        """The estimated total cost of computing a node for all index
        configurations.
        """
        # base cost multiplied by how many times the node must be recomputed
        return super().get_flops(node) * self.get_node_mult(node)

    @cached_node_property("cache_contrib")
    def get_cache_contrib(self, node):
        # estimate this node's contribution to peak cache size for both
        # child orderings, keeping the cheaper one - NOTE this mutates
        # ``self.children[node]`` (swaps the children) when reversed order
        # is better, despite being a cached getter
        l, r = self.children[node]

        lr_peak = 0
        if self.get_node_is_bright(l):
            lr_peak += self.get_size(l)
        if self.get_node_is_bright(r):
            lr_peak += self.get_size(r) * self.get_node_mult(r)

        rl_peak = 0
        if self.get_node_is_bright(r):
            rl_peak += self.get_size(r)
        if self.get_node_is_bright(l):
            rl_peak += self.get_size(l) * self.get_node_mult(l)

        if lr_peak < rl_peak:
            return lr_peak
        else:
            self.children[node] = (r, l)
            return rl_peak

    def peak_size(self, log=None):
        """Estimate the peak cache size of the contraction, optionally in a
        logarithmic base. The estimate is computed lazily on first call and
        then tracked incrementally.
        """
        if not self._track_cache:
            # first call: accumulate contributions over the whole tree then
            # switch to incremental tracking
            self._cache_est = 0
            for (
                node,
                _,
                _,
            ) in self.traverse():
                self._cache_est += self.get_cache_contrib(node)
            self._track_cache = True

        peak = self._cache_est

        if log is not None:
            peak = math.log(peak, log)

        return peak

    def reorder_contractions_for_peak_est(self):
        """Reorder the contractions to try and reduce the peak memory usage."""
        swapped = False
        for p, l, r in self.descend():
            # same left/right vs right/left comparison as get_cache_contrib,
            # but applied across the whole tree without caching
            lr_peak = 0
            if self.get_node_is_bright(l):
                lr_peak += self.get_size(l)
            if self.get_node_is_bright(r):
                lr_peak += self.get_size(r) * self.get_node_mult(r)

            rl_peak = 0
            if self.get_node_is_bright(r):
                rl_peak += self.get_size(r)
            if self.get_node_is_bright(l):
                rl_peak += self.get_size(l) * self.get_node_mult(l)

            if rl_peak < lr_peak:
                self.children[p] = (r, l)
                swapped = True

        # whether any contraction order was changed
        return swapped

    def reorder_sliced_inds(self):
        """Reorder ``self.sliced_inds`` into the order in which they are
        first encountered when traversing the tree.
        """
        sliced_ind_ordering = dict()
        for node, _, _ in self.traverse():
            sliced_ind_ordering.update(self.get_node_var_inds(node))
        self.sliced_inds = {ix: None for ix in sliced_ind_ordering}

    def exact_multi_stats(self, configs):
        """Compute exact contraction statistics (flops, write, size, peak)
        for contracting all of the given index ``configs``, accounting for
        intermediates shared (cached) between configurations.
        """
        # ragged list of lists (configs and contractions)
        cons = []

        # build this for efficiency
        plr = tuple(self.traverse())

        def to_key(node, config):
            # a node's cache identity: the node plus the values its variable
            # indices take in this config
            subconfig = tuple(
                map(config.__getitem__, self.get_node_var_inds(node))
            )
            return hash((node, subconfig))

        # iterate forward, recording only when we first need to produce a 'parent'
        seen = set()
        for config in configs:
            cons_i = []
            for p, l, r in plr:
                pkey = to_key(p, config)
                first = pkey not in seen
                if first:
                    seen.add(pkey)
                    cons_i.append(
                        {
                            "p": p,
                            "l": l,
                            "r": r,
                            "pkey": pkey,
                            "lkey": to_key(l, config),
                            "rkey": to_key(r, config),
                        }
                    )
            cons.append(cons_i)
        del seen

        # iterate backward, checking the last
        # time a 'child' is seen -> can delete
        deleted = set()
        for cons_i in reversed(cons):
            for con in cons_i:
                rkey = con["rkey"]
                rdel = rkey not in deleted
                if rdel:
                    deleted.add(rkey)
                con["rdel"] = rdel
                lkey = con["lkey"]
                ldel = lkey not in deleted
                if ldel:
                    deleted.add(lkey)
                con["ldel"] = ldel
        del deleted

        # iterate forward again if we want to compute flops and memory usage:
        # not needed if we already know these & just want to contract
        flops = 0
        # NOTE(review): ``mems`` records the memory trace but is not returned
        # - presumably kept for debugging; verify before removing
        mems = []
        mem_current = 0
        mem_peak = 0
        mem_write = 0
        for cons_i in cons:
            for con in cons_i:
                p = con["p"]
                # single-configuration cost for this contraction
                flops += super().get_flops(p)
                psize = self.get_size(p)
                mem_current += psize
                mem_write += psize
                mems.append(mem_current)
                mem_peak = max(mem_peak, mem_current)
                l, r = con["l"], con["r"]
                if con["ldel"] and not self.is_leaf(l):
                    mem_current -= self.get_size(l)
                if con["rdel"] and not self.is_leaf(r):
                    mem_current -= self.get_size(r)
            # final output of each config is always deletable
            mem_current -= self.get_size(p)

        return {
            "flops": flops,
            "write": mem_write,
            "size": self.max_size(),
            "peak": mem_peak,
        }
================================================
FILE: cotengra/experimental/__init__.py
================================================
"""Potentially useful but experimental (untested) features,"""
================================================
FILE: cotengra/experimental/hyper_de.py
================================================
"""Hyper optimization using a pure Python differential evolution strategy."""
from ..utils import get_rng
from ._param_mapping import (
LCBOptimizer,
build_params,
convert_raw,
num_params,
)
from .hyper import HyperOptLib, register_hyper_optlib
class HyperDESampler:
    """A lightweight differential evolution optimizer operating in raw
    ``[-1, 1]`` parameter space.

    Each generation maintains a population of candidate vectors. New trial
    vectors are created using ``DE/rand/1/bin`` mutation and binomial
    crossover, then kept only if they improve on their parent.

    Parameters
    ----------
    space : dict[str, dict]
        The search space for a single contraction method.
    seed : None or int, optional
        Random seed.
    population_size : int or "auto", optional
        The population size. When ``"auto"`` it is chosen based on the mapped
        parameter dimension.
    mutation : float, optional
        The differential weight (F) applied to the difference vector.
    crossover : float, optional
        The crossover probability (CR) for binomial crossover.
    mutation_decay : float, optional
        Multiplicative decay applied to ``mutation`` after each completed
        generation.
    mutation_min : float, optional
        Lower bound for ``mutation``.
    mutation_max : float, optional
        Upper bound for ``mutation``.
    exponential_param_power : float, optional
        Passed through to the shared parameter mapping for ``FLOAT_EXP``
        parameters.
    """

    def __init__(
        self,
        space,
        seed=None,
        population_size="auto",
        mutation=0.8,
        crossover=0.7,
        mutation_decay=1.0,
        mutation_min=0.1,
        mutation_max=1.5,
        exponential_param_power=None,
    ):
        self.rng = get_rng(seed)
        self.params = build_params(
            space, exponential_param_power=exponential_param_power
        )
        self.ndim = num_params(self.params)
        if population_size == "auto":
            population_size = max(8, 5 * self.ndim)
        self.population_size = population_size
        self.mutation = mutation
        self.crossover = crossover
        self.mutation_decay = mutation_decay
        self.mutation_min = mutation_min
        self.mutation_max = mutation_max
        # initialize population uniformly in [-1, 1]
        self._population = [
            tuple(self.rng.uniform(-1.0, 1.0) for _ in range(self.ndim))
            for _ in range(self.population_size)
        ]
        # best score seen for each population member (lower is better)
        self._scores = [float("inf")] * self.population_size
        # monotonically increasing id handed out per trial
        self._trial_counter = 0
        self._target_index = 0
        # per-generation bookkeeping, populated lazily by ask()
        self._generation = None
        # maps outstanding trial numbers -> slot in the current generation
        self._trial_map = {}

    def _mutate(self, target_idx):
        """Create a trial vector via DE/rand/1/bin."""
        # pick three distinct indices, all different from target
        # NOTE(review): rng.sample(..., 3) requires population_size >= 4;
        # the "auto" sizing guarantees this but an explicit smaller value
        # would raise - TODO confirm intended minimum
        indices = list(range(self.population_size))
        indices.remove(target_idx)
        r0, r1, r2 = self.rng.sample(indices, 3)
        x_r0 = self._population[r0]
        x_r1 = self._population[r1]
        x_r2 = self._population[r2]
        # mutation: v = x_r0 + F * (x_r1 - x_r2)
        v = []
        for d in range(self.ndim):
            vi = x_r0[d] + self.mutation * (x_r1[d] - x_r2[d])
            v.append(min(max(vi, -1.0), 1.0))
        # binomial crossover
        x_target = self._population[target_idx]
        j_rand = self.rng.randrange(self.ndim)
        trial = []
        for d in range(self.ndim):
            if self.rng.random() < self.crossover or d == j_rand:
                trial.append(v[d])
            else:
                trial.append(x_target[d])
        return tuple(trial)

    def _sample_generation(self):
        """Prepare trial vectors for all population members."""
        self._generation = {
            "trials": [],
            "trial_numbers": [],
            "target_indices": [],
            "scores": {},
            "next_index": 0,
        }
        for i in range(self.population_size):
            self._extend_generation(i)

    def _extend_generation(self, target_idx=None):
        """Append one more trial to the current generation."""
        if target_idx is None:
            # wrap around if we need more trials than population
            target_idx = len(self._generation["trials"]) % self.population_size
        trial_number = self._trial_counter
        self._trial_counter += 1
        trial_vec = self._mutate(target_idx)
        slot = len(self._generation["trials"])
        self._generation["trials"].append(trial_vec)
        self._generation["trial_numbers"].append(trial_number)
        self._generation["target_indices"].append(target_idx)
        self._trial_map[trial_number] = slot

    def ask(self):
        """Return the next candidate from the current generation.

        If all prepared candidates have been issued, grow the generation
        by one more sample.
        """
        if self._generation is None:
            self._sample_generation()
        if self._generation["next_index"] >= len(self._generation["trials"]):
            self._extend_generation()
        i = self._generation["next_index"]
        self._generation["next_index"] += 1
        trial_number = self._generation["trial_numbers"][i]
        x = self._generation["trials"][i]
        # map the raw [-1, 1] vector to concrete parameter values
        return trial_number, convert_raw(self.params, x)

    def tell(self, trial_number, score):
        """Record a completed trial and perform selection if the
        generation is complete.

        For each trial vector, if it scores better than (or equal to) its
        target parent, it replaces the parent in the population.
        """
        slot = self._trial_map.pop(trial_number)
        self._generation["scores"][slot] = score
        # wait until every *issued* trial has been reported
        if len(self._generation["scores"]) != self._generation["next_index"]:
            return
        # NOTE(review): trial numbers pre-sampled but never issued via ask()
        # keep their entries in self._trial_map after the generation resolves
        # - presumably harmless if ask/tell are always paired; verify
        # selection: compare each trial against its target
        # (only iterate over issued slots, not all pre-sampled ones)
        for slot_i in range(self._generation["next_index"]):
            target_idx = self._generation["target_indices"][slot_i]
            trial_score = self._generation["scores"][slot_i]
            if trial_score <= self._scores[target_idx]:
                self._population[target_idx] = self._generation["trials"][
                    slot_i
                ]
                self._scores[target_idx] = trial_score
        # decay mutation factor
        self.mutation *= self.mutation_decay
        self.mutation = min(
            max(self.mutation, self.mutation_min), self.mutation_max
        )
        self._generation = None
class DEOptLib(HyperOptLib):
    """Hyper-optimization using differential evolution."""

    def setup(
        self,
        methods,
        space,
        optimizer=None,
        population_size="auto",
        mutation=0.8,
        crossover=0.7,
        mutation_decay=1.0,
        mutation_min=0.1,
        mutation_max=1.5,
        method_exploration=1.0,
        method_temperature=1.0,
        exponential_param_power=None,
        seed=None,
        **kwargs,
    ):
        """Initialize DE optimizers for each contraction method.

        Parameters
        ----------
        methods : list[str]
            The contraction methods to optimize over.
        space : dict[str, dict[str, dict]]
            The per-method hyperparameter search space.
        optimizer : HyperOptimizer, optional
            The parent optimizer. Used to size the initial population
            large enough for parallel pre-dispatch.
        population_size : int or "auto", optional
            The population size for each method-specific DE sampler.
        mutation : float, optional
            Differential weight (F).
        crossover : float, optional
            Crossover probability (CR).
        mutation_decay, mutation_min, mutation_max : float, optional
            Parameters controlling mutation scale over generations.
        method_exploration : float, optional
            Exploration strength for the LCB-based method chooser.
        method_temperature : float, optional
            Noise temperature for the LCB-based method chooser.
        exponential_param_power : float, optional
            Passed to the shared parameter mapping for ``FLOAT_EXP``.
        seed : None or int, optional
            Random seed.
        """
        if population_size == "auto":
            # size by the largest parameter dimension across all methods,
            # while covering the parent optimizer's pre-dispatch width
            max_ndim = max(
                num_params(
                    build_params(
                        space[m],
                        exponential_param_power=exponential_param_power,
                    )
                )
                for m in methods
            )
            population_size = max(
                8,
                max(1, getattr(optimizer, "pre_dispatch", 1)),
                5 * max_ndim,
            )

        # bandit-style chooser over which contraction method to try next
        self._method_chooser = LCBOptimizer(
            options=methods,
            exploration=method_exploration,
            temperature=method_temperature,
            seed=seed,
        )
        # one independent DE sampler per contraction method
        self._optimizers = {
            method: HyperDESampler(
                space[method],
                seed=seed,
                population_size=population_size,
                mutation=mutation,
                crossover=crossover,
                mutation_decay=mutation_decay,
                mutation_min=mutation_min,
                mutation_max=mutation_max,
                exponential_param_power=exponential_param_power,
            )
            for method in methods
        }

    def get_setting(self):
        """Choose a contraction method, then request its next setting."""
        method = self._method_chooser.ask()
        params_token, params = self._optimizers[method].ask()
        return {
            "method": method,
            "params_token": params_token,
            "params": params,
        }

    def report_result(self, setting, trial, score):
        """Report a completed trial back to the method chooser and DE."""
        self._method_chooser.tell(setting["method"], score)
        self._optimizers[setting["method"]].tell(
            setting["params_token"], score
        )
register_hyper_optlib("de", DEOptLib)
register_hyper_optlib("diffev", DEOptLib)
================================================
FILE: cotengra/experimental/hyper_pe.py
================================================
"""Hyper optimization using parallel evolution with ranked sigma assignment."""
import math
from ..utils import get_rng
from ._param_mapping import (
LCBOptimizer,
build_params,
convert_raw,
num_params,
)
from .hyper import HyperOptLib, register_hyper_optlib
class HyperPESampler:
    """A parallel evolution optimizer operating in raw ``[-1, 1]`` space.

    Multiple workers each maintain their own solution. Perturbation scales
    (sigmas) are distributed across an evenly spaced range and reassigned
    by rank after each generation: the best-scoring worker gets the lowest
    sigma (exploit) and the worst gets the highest (explore).

    Parameters
    ----------
    space : dict[str, dict]
        The search space for a single contraction method.
    seed : None or int, optional
        Random seed.
    population_size : int or "auto", optional
        The number of parallel workers. When ``"auto"`` it is chosen
        based on the mapped parameter dimension.
    sigma_min : float, optional
        The smallest perturbation scale (assigned to the best worker).
    sigma_max : float, optional
        The largest perturbation scale (assigned to the worst worker).
    elite_migrate_prob : float, optional
        Probability each generation of copying the best worker's
        solution to the worst worker's slot.
    differential_prob : float, optional
        Per-sample probability of using a differential perturbation
        (``x_best - x_rand``) instead of Gaussian noise.
    patience : int or None, optional
        If a worker has not improved for this many generations,
        re-randomize its solution. ``None`` or ``0`` disables.
    exponential_param_power : float, optional
        Passed through to the shared parameter mapping for ``FLOAT_EXP``
        parameters.
    """

    def __init__(
        self,
        space,
        seed=None,
        population_size=8,
        sigma_min=0.01,
        sigma_max=0.5,
        elite_migrate_prob=0.0,
        differential_prob=0.0,
        patience=None,
        exponential_param_power=None,
    ):
        self.rng = get_rng(seed)
        self.params = build_params(
            space, exponential_param_power=exponential_param_power
        )
        self.ndim = num_params(self.params)
        if population_size == "auto":
            population_size = max(8, 4 * self.ndim)
        self.population_size = population_size
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.elite_migrate_prob = elite_migrate_prob
        self.differential_prob = differential_prob
        self.patience = patience
        # initialize each worker's solution uniformly in [-1, 1]
        self._solutions = [
            tuple(self.rng.uniform(-1.0, 1.0) for _ in range(self.ndim))
            for _ in range(self.population_size)
        ]
        # best score seen for each worker (lower is better)
        self._scores = [float("inf")] * self.population_size
        # generations since each worker last improved
        self._stagnation = [0] * self.population_size
        # evenly spaced sigmas, assigned by rank (best -> lowest)
        self._sigmas = self._make_sigmas()
        self._trial_counter = 0
        # per-generation bookkeeping, populated lazily by ask()
        self._generation = None
        # maps outstanding trial numbers -> slot in the current generation
        self._trial_map = {}

    def _make_sigmas(self):
        """Create geometrically spaced sigmas from sigma_min to sigma_max."""
        n = self.population_size
        if n == 1:
            # single worker: use the geometric mean of the two bounds
            return [math.sqrt(self.sigma_min * self.sigma_max)]
        log_min = math.log(self.sigma_min)
        log_max = math.log(self.sigma_max)
        return [
            math.exp(log_min + i * (log_max - log_min) / (n - 1))
            for i in range(n)
        ]

    def _sample_candidate(self, worker_idx, noise=None):
        """Perturb a worker's current solution with its assigned sigma."""
        sigma = self._sigmas[worker_idx]
        sol = self._solutions[worker_idx]
        if noise is None:
            # possibly use differential perturbation
            if (
                self.differential_prob > 0.0
                and self.population_size >= 3
                and self.rng.random() < self.differential_prob
            ):
                best_idx = min(
                    range(self.population_size),
                    key=lambda i: self._scores[i],
                )
                others = [
                    j
                    for j in range(self.population_size)
                    if j != worker_idx and j != best_idx
                ]
                rand_idx = self.rng.choice(others)
                # difference vector towards the best-known solution
                noise = [
                    self._solutions[best_idx][d] - self._solutions[rand_idx][d]
                    for d in range(self.ndim)
                ]
            else:
                # default: standard Gaussian noise
                noise = [self.rng.gauss(0.0, 1.0) for _ in range(self.ndim)]
        x = []
        for si, ni in zip(sol, noise):
            xi = si + sigma * ni
            # clip back into the raw [-1, 1] box
            x.append(min(max(xi, -1.0), 1.0))
        return tuple(x)

    def _sample_generation(self):
        """Start a new generation with one trial per worker."""
        self._generation = {
            "worker_indices": [],
            "xs": [],
            "trial_numbers": [],
            "scores": {},
            "next_index": 0,
        }
        for i in range(self.population_size):
            self._extend_generation(i)

    def _extend_generation(self, worker_idx=None, noise=None):
        """Append one more trial to the current generation."""
        if worker_idx is None:
            # wrap around if we need more trials than workers
            worker_idx = len(self._generation["xs"]) % self.population_size
        trial_number = self._trial_counter
        self._trial_counter += 1
        slot = len(self._generation["xs"])
        self._generation["worker_indices"].append(worker_idx)
        self._generation["xs"].append(
            self._sample_candidate(worker_idx, noise=noise)
        )
        self._generation["trial_numbers"].append(trial_number)
        self._trial_map[trial_number] = slot

    def ask(self):
        """Return the next candidate from the current generation.

        If all prepared candidates have been issued, grow the generation
        by one more sample.
        """
        if self._generation is None:
            self._sample_generation()
        if self._generation["next_index"] >= len(self._generation["xs"]):
            self._extend_generation()
        i = self._generation["next_index"]
        self._generation["next_index"] += 1
        trial_number = self._generation["trial_numbers"][i]
        x = self._generation["xs"][i]
        # map the raw [-1, 1] vector to concrete parameter values
        return trial_number, convert_raw(self.params, x)

    def tell(self, trial_number, score):
        """Record a completed trial and update workers if the generation
        is complete.

        For each trial, if it scores better than (or equal to) its
        worker's current best, the worker adopts the new solution.
        Then sigmas are reassigned by rank: best worker gets lowest
        sigma, worst gets highest.
        """
        slot = self._trial_map.pop(trial_number)
        self._generation["scores"][slot] = score
        # wait until every *issued* trial has been reported
        if len(self._generation["scores"]) != self._generation["next_index"]:
            return
        # NOTE(review): trial numbers pre-sampled but never issued via ask()
        # keep their entries in self._trial_map after the generation resolves
        # - presumably harmless if ask/tell are always paired; verify
        # greedy update: adopt better solutions and track stagnation
        improved = set()
        for slot_i in range(self._generation["next_index"]):
            worker_idx = self._generation["worker_indices"][slot_i]
            trial_score = self._generation["scores"][slot_i]
            if trial_score <= self._scores[worker_idx]:
                self._solutions[worker_idx] = self._generation["xs"][slot_i]
                self._scores[worker_idx] = trial_score
                improved.add(worker_idx)
        for i in range(self.population_size):
            if i in improved:
                self._stagnation[i] = 0
            else:
                self._stagnation[i] += 1
        # stagnation restart
        if self.patience:
            for i in range(self.population_size):
                if self._stagnation[i] >= self.patience:
                    self._solutions[i] = tuple(
                        self.rng.uniform(-1.0, 1.0) for _ in range(self.ndim)
                    )
                    self._scores[i] = float("inf")
                    self._stagnation[i] = 0
        # rank workers by score (best to worst) and reassign sigmas
        ranking = sorted(
            range(self.population_size),
            key=lambda i: self._scores[i],
        )
        # elite migration
        if self.elite_migrate_prob > 0.0:
            if self.rng.random() < self.elite_migrate_prob:
                worst = ranking[-1]
                best = ranking[0]
                self._solutions[worst] = self._solutions[best]
        new_sigmas = [0.0] * self.population_size
        base_sigmas = self._make_sigmas()
        for rank, worker_idx in enumerate(ranking):
            new_sigmas[worker_idx] = base_sigmas[rank]
        self._sigmas = new_sigmas
        self._generation = None
class PEOptLib(HyperOptLib):
    """Hyper-optimization using parallel evolution with ranked sigmas."""

    def setup(
        self,
        methods,
        space,
        optimizer=None,
        population_size="auto",
        sigma_min=0.01,
        sigma_max=0.5,
        elite_migrate_prob=0.5,
        differential_prob=0.5,
        patience=8,
        method_exploration=1.0,
        method_temperature=1.0,
        exponential_param_power=None,
        seed=None,
        **kwargs,
    ):
        """Initialize PE optimizers for each contraction method.

        Parameters
        ----------
        methods : list[str]
            The contraction methods to optimize over.
        space : dict[str, dict[str, dict]]
            The per-method hyperparameter search space.
        optimizer : HyperOptimizer, optional
            The parent optimizer. Used to size the initial population
            large enough for parallel pre-dispatch.
        population_size : int or "auto", optional
            The number of parallel workers for each method.
        sigma_min : float, optional
            Smallest perturbation scale (for the best-ranked worker).
        sigma_max : float, optional
            Largest perturbation scale (for the worst-ranked worker).
        elite_migrate_prob : float, optional
            Probability of copying best solution to worst worker.
        differential_prob : float, optional
            Per-sample probability of differential perturbation.
        patience : int or None, optional
            Generations without improvement before restart.
        method_exploration : float, optional
            Exploration strength for the LCB-based method chooser.
        method_temperature : float, optional
            Noise temperature for the LCB-based method chooser.
        exponential_param_power : float, optional
            Passed to the shared parameter mapping for ``FLOAT_EXP``.
        seed : None or int, optional
            Random seed.
        """
        if population_size == "auto":
            # size by the widest parameter space among all methods, the
            # parent's pre-dispatch width, and a sensible minimum
            ndims = [
                num_params(
                    build_params(
                        space[m],
                        exponential_param_power=exponential_param_power,
                    )
                )
                for m in methods
            ]
            pre_dispatch = max(1, getattr(optimizer, "pre_dispatch", 1))
            population_size = max(8, pre_dispatch, 4 * max(ndims))

        self._method_chooser = LCBOptimizer(
            options=methods,
            exploration=method_exploration,
            temperature=method_temperature,
            seed=seed,
        )

        # one independent PE population per contraction method
        self._optimizers = {}
        for m in methods:
            self._optimizers[m] = HyperPESampler(
                space[m],
                seed=seed,
                population_size=population_size,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                elite_migrate_prob=elite_migrate_prob,
                differential_prob=differential_prob,
                patience=patience,
                exponential_param_power=exponential_param_power,
            )

    def get_setting(self):
        """Choose a contraction method, then request its next setting."""
        chosen = self._method_chooser.ask()
        token, params = self._optimizers[chosen].ask()
        return {
            "method": chosen,
            "params_token": token,
            "params": params,
        }

    def report_result(self, setting, trial, score):
        """Report a completed trial back to the method chooser and PE."""
        chosen = setting["method"]
        self._method_chooser.tell(chosen, score)
        self._optimizers[chosen].tell(setting["params_token"], score)
# expose this optlib under both a short and a descriptive alias
register_hyper_optlib("pe", PEOptLib)
register_hyper_optlib("parallelev", PEOptLib)
================================================
FILE: cotengra/experimental/hyper_pymoo.py
================================================
"""Hyper optimization using pymoo single-objective algorithms.
This backend currently supports serial optimization only. Pymoo ask/tell
algorithms operate on generations/batches rather than individual trials, so
the integration buffers one full batch at a time and feeds it back when all
batch members have been evaluated.
"""
from ._param_mapping import (
LCBOptimizer,
build_params,
convert_raw,
num_params,
)
from .hyper import HyperOptLib, register_hyper_optlib
def _get_pymoo_algorithm(name):
if name == "de":
from pymoo.algorithms.soo.nonconvex.de import DE
return DE
if name == "ga":
from pymoo.algorithms.soo.nonconvex.ga import GA
return GA
if name == "pso":
from pymoo.algorithms.soo.nonconvex.pso import PSO
return PSO
if name == "brkga":
from pymoo.algorithms.soo.nonconvex.brkga import BRKGA
return BRKGA
if name == "es":
from pymoo.algorithms.soo.nonconvex.es import ES
return ES
if name == "sres":
from pymoo.algorithms.soo.nonconvex.sres import SRES
return SRES
if name == "isres":
from pymoo.algorithms.soo.nonconvex.isres import ISRES
return ISRES
if name == "nrbo":
from pymoo.algorithms.soo.nonconvex.nrbo import NRBO
return NRBO
raise ValueError(f"Unknown pymoo sampler {name}.")
class HyperPymooSampler:
    """Per-method ask/tell wrapper around a pymoo algorithm.

    Pymoo algorithms generate whole populations at once, so candidates
    are buffered one batch at a time: ``ask`` hands out buffered batch
    members one by one, and ``tell`` feeds the batch back to pymoo once
    every member has been scored.
    """

    def __init__(
        self,
        space,
        sampler="de",
        sampler_opts=None,
        exponential_param_power=None,
        seed=None,
    ):
        import numpy as np
        from pymoo.core.evaluator import Evaluator
        from pymoo.core.problem import Problem
        from pymoo.core.termination import NoTermination
        from pymoo.problems.static import StaticProblem

        # keep handles to the lazily imported pieces needed in tell()
        self._np = np
        self._Evaluator = Evaluator
        self._StaticProblem = StaticProblem

        self.params = build_params(
            space, exponential_param_power=exponential_param_power
        )
        self._ndim = num_params(self.params)

        # every raw dimension lives in [-1, 1]; convert_raw maps back to
        # the actual parameter values
        self._problem = Problem(
            n_var=self._ndim,
            n_obj=1,
            n_constr=0,
            xl=np.full(self._ndim, -1.0),
            xu=np.full(self._ndim, 1.0),
        )

        opts = {} if sampler_opts is None else dict(sampler_opts)
        algorithm_cls = _get_pymoo_algorithm(sampler)
        self.algorithm = algorithm_cls(**opts)
        # NoTermination: the caller decides when to stop asking
        self.algorithm.setup(
            self._problem,
            termination=NoTermination(),
            seed=seed,
            verbose=False,
        )

        self._trial_counter = 0
        self._active_batch = None

    def _start_batch(self):
        """Pull a fresh population from pymoo and buffer it."""
        pop = self.algorithm.ask()
        xs = pop.get("X")
        n = len(xs)
        first = self._trial_counter
        self._trial_counter = first + n
        self._active_batch = {
            "pop": pop,
            "trial_numbers": tuple(range(first, first + n)),
            "settings": tuple(
                convert_raw(self.params, x.copy()) for x in xs
            ),
            "scores": {},
            "next_index": 0,
        }

    def ask(self):
        """Return ``(trial_number, params_dict)`` for the next buffered
        candidate, fetching a new batch from pymoo if needed.
        """
        if self._active_batch is None:
            self._start_batch()
        batch = self._active_batch
        i = batch["next_index"]
        batch["next_index"] = i + 1
        return batch["trial_numbers"][i], batch["settings"][i]

    def tell(self, trial_number, score):
        """Record a score; once the whole batch is scored, feed it back
        to the pymoo algorithm.
        """
        batch = self._active_batch
        batch["scores"][trial_number] = score
        if len(batch["scores"]) != len(batch["trial_numbers"]):
            # still waiting on other members of this batch
            return
        objectives = self._np.asarray(
            [batch["scores"][t] for t in batch["trial_numbers"]],
            dtype=float,
        ).reshape(-1, 1)
        static = self._StaticProblem(self._problem, F=objectives)
        self._Evaluator().eval(static, batch["pop"])
        self.algorithm.tell(infills=batch["pop"])
        self._active_batch = None
class PymooOptLib(HyperOptLib):
    """Hyper-optimization using pymoo algorithms with LCB method choice."""

    def setup(
        self,
        methods,
        space,
        optimizer=None,
        sampler="de",
        sampler_opts=None,
        method_exploration=1.0,
        method_temperature=1.0,
        exponential_param_power=None,
        seed=None,
        **kwargs,
    ):
        """Build one ``HyperPymooSampler`` per contraction method and an
        LCB chooser to arbitrate between them.
        """
        # pymoo works generation-by-generation, which is incompatible
        # with the parallel driver's dispatch model
        if getattr(optimizer, "_pool", None) is not None:
            raise ValueError(
                "The 'pymoo' optlib currently only supports serial "
                "hyper-optimization (`parallel=False`)."
            )

        self._method_chooser = LCBOptimizer(
            options=methods,
            exploration=method_exploration,
            temperature=method_temperature,
            seed=seed,
        )

        self._optimizers = {}
        for m in methods:
            self._optimizers[m] = HyperPymooSampler(
                space[m],
                sampler=sampler,
                sampler_opts=sampler_opts,
                exponential_param_power=exponential_param_power,
                seed=seed,
            )

    def get_setting(self):
        """Pick a method via LCB, then draw its next candidate params."""
        chosen = self._method_chooser.ask()
        token, params = self._optimizers[chosen].ask()
        return {
            "method": chosen,
            "params_token": token,
            "params": params,
        }

    def report_result(self, setting, trial, score):
        """Feed the score back to the chooser and the method's sampler."""
        chosen = setting["method"]
        self._method_chooser.tell(chosen, score)
        self._optimizers[chosen].tell(setting["params_token"], score)
# make this optlib selectable via optlib="pymoo"
register_hyper_optlib("pymoo", PymooOptLib)
================================================
FILE: cotengra/experimental/hyper_scipy.py
================================================
"""Hyper optimization using scipy gradient-free optimizers.
Supported methods: ``differential_evolution``, ``dual_annealing``,
``direct``, ``shgo``. Since these optimizers use a callback-style
objective, a background thread is used to invert the control flow into
an ask/tell interface. Multiple workers are maintained per method to
support parallel pre-dispatch (ask-ask-...-tell-tell-...).
"""
import queue
import threading
from ._param_mapping import (
LCBOptimizer,
build_params,
convert_raw,
num_params,
)
from .hyper import HyperOptLib, register_hyper_optlib
# names of scipy.optimize global optimizers supported by this optlib
_OPTIMIZER_NAMES = {
    "differential_evolution",
    "dual_annealing",
    "direct",
    "shgo",
}
class _StopOptimization(Exception):
    """Raised inside the objective to abort the scipy optimizer."""


class ScipyAskTell:
    """Ask/tell wrapper around a scipy global optimizer.

    The optimizer runs in a background thread. Each time it needs an
    objective evaluation it posts the candidate vector to ``_ask_q`` and
    blocks on ``_tell_q``. The caller drives progress by alternating
    ``ask()`` / ``tell()`` calls from the main thread.

    Parameters
    ----------
    method : str
        One of the supported scipy optimizer names (an attribute of
        ``scipy.optimize``, e.g. ``'differential_evolution'``).
    bounds : list[tuple[float, float]]
        Bounds for every raw dimension.
    kwargs
        Forwarded to the underlying scipy optimizer.

    Attributes
    ----------
    done : bool
        Set once the background optimizer has finished for any reason.
    error : Exception or None
        If the optimizer terminated with an unexpected exception it is
        stored here (previously such failures were silently swallowed),
        so callers can inspect why a worker died.
    """

    def __init__(self, method, bounds, **kwargs):
        self.method = method
        self.bounds = bounds
        self.kwargs = kwargs
        self._ask_q = queue.Queue()
        self._tell_q = queue.Queue()
        self._thread = None
        self._stop = threading.Event()
        self.done = False
        # records an unexpected optimizer failure, if any
        self.error = None

    # ---- internal --------------------------------------------------------

    def _get_optimizer_fn(self):
        """Resolve the scipy optimizer function lazily, keeping scipy
        an optional import until first use.
        """
        from scipy import optimize

        return getattr(optimize, self.method)

    def _objective(self, x):
        """Objective called by scipy in the background thread: publish
        the candidate and block until the caller supplies its score.
        """
        if self._stop.is_set():
            raise _StopOptimization
        self._ask_q.put(x)
        val = self._tell_q.get()
        # re-check: stop() unblocks us with a dummy score
        if self._stop.is_set():
            raise _StopOptimization
        return float(val)

    def _run(self):
        """Thread target: run the optimizer to completion or abort."""
        try:
            fn = self._get_optimizer_fn()
            fn(self._objective, self.bounds, **self.kwargs)
        except _StopOptimization:
            # deliberate abort via stop() - not an error
            pass
        except Exception as exc:
            # don't let a worker die invisibly: keep the reason around
            # for inspection rather than silently swallowing it
            self.error = exc
        finally:
            self.done = True
            # unblock any pending ask()
            self._ask_q.put(None)

    # ---- public ----------------------------------------------------------

    def start(self):
        """Launch the optimizer in a background thread."""
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def ask(self):
        """Block until the optimizer requests an evaluation.

        Returns
        -------
        x : ndarray or None
            The candidate vector, or ``None`` if the optimizer finished.
        """
        x = self._ask_q.get()
        if self.done and x is None:
            return None
        return x

    def tell(self, score):
        """Provide the objective value back to the optimizer."""
        self._tell_q.put(score)

    def stop(self):
        """Signal the background thread to stop and wait for it."""
        self._stop.set()
        # unblock if the thread is waiting in _objective for a tell
        try:
            self._tell_q.put_nowait(float("inf"))
        except queue.Full:
            pass
        if self._thread is not None:
            self._thread.join(timeout=5.0)
class HyperScipySampler:
    """Per-method optimizer wrapping a pool of ``ScipyAskTell`` workers
    with the ``Param``-based space mapping.

    A single ``ScipyAskTell`` worker is strictly serial
    (ask-tell-ask-tell), so parallel pre-dispatch (several ``ask`` calls
    before any ``tell``) is supported by keeping a pool of workers:
    each ``ask`` draws a candidate from the next worker in round-robin
    order, and each ``tell`` routes the score back to whichever worker
    produced that trial.

    Parameters
    ----------
    space : dict[str, dict]
        The search space for one method.
    method : str
        Which scipy optimizer to use.
    n_workers : int, optional
        Number of concurrent ``ScipyAskTell`` threads to run.
    exponential_param_power : float, optional
        Power for ``ParamFloatExp``.
    kwargs
        Extra keyword arguments forwarded to the scipy optimizer.
    """

    def __init__(
        self,
        space,
        method="differential_evolution",
        n_workers=1,
        exponential_param_power=None,
        **kwargs,
    ):
        self.params = build_params(
            space, exponential_param_power=exponential_param_power
        )
        self._ndim = num_params(self.params)
        self._method = method
        self._scipy_opts = kwargs
        self._n_workers = n_workers

        # map each outstanding trial number to the worker that made it
        self._trial_counter = 0
        self._trial_to_worker = {}

        # round-robin pool of background optimizer threads
        self._worker_idx = 0
        self._workers = [self._make_worker() for _ in range(n_workers)]

    def _make_worker(self):
        """Create and start a fresh ``ScipyAskTell`` instance."""
        worker = ScipyAskTell(
            method=self._method,
            bounds=[(-1.0, 1.0)] * self._ndim,
            **self._scipy_opts,
        )
        worker.start()
        return worker

    def _next_worker(self):
        """Return the next worker in round-robin order, replacing a
        finished one with a fresh instance.
        """
        for _ in range(self._n_workers):
            slot = self._worker_idx
            self._worker_idx = (slot + 1) % self._n_workers
            worker = self._workers[slot]
            if not worker.done:
                return worker
            # optimizer converged / exhausted - replace it in place
            worker.stop()
            worker = self._make_worker()
            self._workers[slot] = worker
            return worker
        # should never get here
        return self._workers[0]

    def ask(self):
        """Return ``(trial_number, params_dict)``.

        Picks the next worker (round-robin), blocking until it has a
        candidate ready. A worker that converged in the meantime is
        transparently restarted.
        """
        worker = self._next_worker()
        x = worker.ask()
        if x is None:
            # worker finished between the done-check and ask(): restart
            slot = (self._worker_idx - 1) % self._n_workers
            worker = self._make_worker()
            self._workers[slot] = worker
            x = worker.ask()
        token = self._trial_counter
        self._trial_counter = token + 1
        self._trial_to_worker[token] = worker
        return token, convert_raw(self.params, x)

    def tell(self, trial_number, score):
        """Route a score back to the worker that produced this trial."""
        worker = self._trial_to_worker.pop(trial_number, None)
        if worker is not None and not worker.done:
            worker.tell(score)

    def stop(self):
        """Stop all background threads."""
        for worker in self._workers:
            worker.stop()
class ScipyOptLib(HyperOptLib):
    """Hyper-optimization using scipy gradient-free optimizers with
    an LCB method selector.
    """

    def setup(
        self,
        methods,
        space,
        optimizer=None,
        method="differential_evolution",
        method_exploration=1.0,
        method_temperature=1.0,
        exponential_param_power=None,
        **scipy_opts,
    ):
        """Initialize per-method scipy optimizers.

        Parameters
        ----------
        methods : list[str]
            The contraction methods to optimize over.
        space : dict[str, dict[str, dict]]
            The search space.
        optimizer : HyperOptimizer, optional
            The parent optimizer instance.
        method : str, optional
            Which scipy global optimizer to use. One of
            ``'differential_evolution'``, ``'dual_annealing'``,
            ``'direct'``, ``'shgo'``.
        method_exploration : float, optional
            Exploration parameter for the LCB method selector.
        method_temperature : float, optional
            Temperature parameter for the LCB method selector.
        exponential_param_power : float, optional
            Power for ``ParamFloatExp``.
        scipy_opts
            Extra keyword arguments forwarded to the scipy optimizer.
        """
        if method not in _OPTIMIZER_NAMES:
            raise ValueError(
                f"method must be one of {sorted(_OPTIMIZER_NAMES)}, "
                f"got {method!r}"
            )

        # need at least as many workers per method as there are
        # parallel pre-dispatch slots, so that asks never block on a
        # worker still waiting for its tell
        n_workers = max(
            getattr(optimizer, "_num_workers", 1),
            getattr(optimizer, "pre_dispatch", 1),
        )

        self._method_chooser = LCBOptimizer(
            options=methods,
            exploration=method_exploration,
            temperature=method_temperature,
        )

        self._optimizers = {}
        for m in methods:
            self._optimizers[m] = HyperScipySampler(
                space[m],
                method=method,
                n_workers=n_workers,
                exponential_param_power=exponential_param_power,
                **scipy_opts,
            )

    def get_setting(self):
        """Choose a contraction method, then ask it for a candidate."""
        chosen = self._method_chooser.ask()
        token, params = self._optimizers[chosen].ask()
        return {
            "method": chosen,
            "params_token": token,
            "params": params,
        }

    def report_result(self, setting, trial, score):
        """Report a finished trial to the chooser and its sampler."""
        chosen = setting["method"]
        self._method_chooser.tell(chosen, score)
        self._optimizers[chosen].tell(setting["params_token"], score)

    def cleanup(self):
        """Stop all background optimizer threads."""
        for sampler in self._optimizers.values():
            sampler.stop()
# make this optlib selectable via optlib="scipy"
register_hyper_optlib("scipy", ScipyOptLib)
================================================
FILE: cotengra/experimental/hyper_smac.py
================================================
"""Hyper parameter optimization using SMAC3.
https://automl.github.io/SMAC3/latest/
"""
from .hyper import HyperOptLib, register_hyper_optlib
def build_config_space(method, space):
    """Build a SMAC ``ConfigurationSpace`` from a cotengra space dict.

    Parameters
    ----------
    method : str
        The method name. NOTE(review): currently unused by the
        construction itself - parameter names are not prefixed with it;
        confirm whether prefixing was intended.
    space : dict[str, dict]
        The search space for a single method.

    Returns
    -------
    cs : ConfigurationSpace

    Raises
    ------
    ValueError
        If a parameter spec has an unrecognized ``"type"``.
    """
    from ConfigSpace import (
        CategoricalHyperparameter,
        UniformFloatHyperparameter,
        UniformIntegerHyperparameter,
    )
    from smac.configspace import ConfigurationSpace

    # one small constructor per cotengra parameter type
    def _float(name, p):
        return UniformFloatHyperparameter(
            name, lower=p["min"], upper=p["max"]
        )

    def _float_exp(name, p):
        # log=True samples the value exponentially
        return UniformFloatHyperparameter(
            name, lower=p["min"], upper=p["max"], log=True
        )

    def _int(name, p):
        return UniformIntegerHyperparameter(
            name, lower=p["min"], upper=p["max"]
        )

    def _string(name, p):
        return CategoricalHyperparameter(name, choices=p["options"])

    def _bool(name, p):
        return CategoricalHyperparameter(name, choices=[False, True])

    builders = {
        "FLOAT": _float,
        "FLOAT_EXP": _float_exp,
        "INT": _int,
        "STRING": _string,
        "BOOL": _bool,
    }

    cs = ConfigurationSpace()
    for name, param in space.items():
        ptype = param["type"]
        if ptype not in builders:
            raise ValueError(f"Unknown parameter type: {ptype!r}")
        cs.add_hyperparameter(builders[ptype](name, param))
    return cs
def config_to_params(config):
    """Convert a SMAC ``Configuration`` to a plain dict of parameters."""
    # Configuration behaves as a mapping of hyperparameter name -> value
    return dict(config)
class SMACOptLib(HyperOptLib):
    """Hyper-optimization using SMAC3 with per-method facades and
    a Lower Confidence Bound method selector.
    """

    def setup(
        self,
        methods,
        space,
        optimizer=None,
        facade="BlackBoxFacade",
        n_trials=10000,
        seed=0,
        method_exploration=1.0,
        method_temperature=1.0,
        **facade_opts,
    ):
        """Create one SMAC facade per contraction method plus an LCB
        chooser to arbitrate between them.
        """
        from smac import BlackBoxFacade, HyperparameterOptimizationFacade
        from smac.scenario import Scenario

        from ._param_mapping import LCBOptimizer

        self._method_chooser = LCBOptimizer(
            options=methods,
            exploration=method_exploration,
            temperature=method_temperature,
        )

        if not isinstance(facade, str):
            # already a facade class
            facade_cls = facade
        else:
            named_facades = {
                "BlackBoxFacade": BlackBoxFacade,
                "HyperparameterOptimizationFacade": (
                    HyperparameterOptimizationFacade
                ),
            }
            facade_cls = named_facades[facade]

        self._facades = {}
        self._trial_infos = {}
        for method in methods:
            scenario = Scenario(
                build_config_space(method, space[method]),
                n_trials=n_trials,
                seed=seed,
                deterministic=True,
            )
            # the target function is never actually executed: trials
            # are driven via ask/tell instead, so a dummy suffices
            self._facades[method] = facade_cls(
                scenario,
                target_function=lambda cfg, seed: 0.0,
                overwrite=True,
                logging_level=False,
                **facade_opts,
            )

    def get_setting(self):
        """Choose a method, then ask its SMAC facade for a trial."""
        method = self._method_chooser.ask()
        info = self._facades[method].ask()
        # keep the TrialInfo alive so its id() stays unique until told
        key = (method, id(info))
        self._trial_infos[key] = info
        return {
            "method": method,
            "trial_key": key,
            "params": config_to_params(info.config),
        }

    def report_result(self, setting, trial, score):
        """Feed the score to the method chooser and the SMAC facade."""
        from smac.runhistory import TrialValue

        method = setting["method"]
        info = self._trial_infos.pop(setting["trial_key"])
        self._method_chooser.tell(method, score)
        self._facades[method].tell(info, TrialValue(cost=score))
# make this optlib selectable via optlib="smac"
register_hyper_optlib("smac", SMACOptLib)
================================================
FILE: cotengra/experimental/multi.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9576af0f-adc3-44e4-88d0-7a57dc97a1d5",
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import math # noqa\n",
"import random # noqa\n",
"import itertools # noqa\n",
"import functools # noqa\n",
"import collections # noqa\n",
"import tqdm # noqa\n",
"import numpy as np # noqa\n",
"import matplotlib as mpl # noqa\n",
"import matplotlib.pyplot as plt # noqa\n",
"import quimb as qu # noqa\n",
"import xyzpy as xyz # noqa\n",
"import autoray as ar # noqa\n",
"import cotengra as ctg # noqa\n",
"import quimb.tensor as qtn # noqa"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "881a2afe-11e9-43f4-a62f-0bda257b7750",
"metadata": {},
"outputs": [],
"source": [
"tn = qtn.PEPS.rand(4, 5, 4)\n",
"# tn = qtn.MPS_rand_state(20, 7)\n",
"\n",
"# tn = qtn.TN2D_rand(32, 32, D=2)\n",
"\n",
"\n",
"inputs, output, size_dict = tn.get_inputs_output_size_dict()\n",
"symbol_map = tn.get_symbol_map()\n",
"\n",
"rng = ctg.utils.get_rng(42)\n",
"\n",
"M = 200\n",
"\n",
"sliced_inds = output\n",
"\n",
"configs = [{ix: rng.choice([0, 1]) for ix in output} for _ in range(M)]\n",
"\n",
"tree = ctg.array_contract_tree(\n",
" inputs,\n",
" output,\n",
" size_dict,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "46dd3357-18b0-4a80-afd5-935f6a4029f9",
"metadata": {},
"outputs": [],
"source": [
"# inputs, output, shapes, size_dict = ctg.utils.randreg_equation(100, 4, seed=42)\n",
"inputs, output, shapes, size_dict = ctg.utils.lattice_equation(\n",
" [28, 28], d_max=2\n",
")\n",
"\n",
"opt = ctg.HyperOptimizer(\n",
" simulated_annealing_opts=dict(target_size=2**24),\n",
" max_repeats=8,\n",
" progbar=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "92b7e150-cffd-40d7-ab26-7cf34af79476",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"F=14.05 C=14.72 S=24 P=25.02 $=8192: 100%|██████████| 16/16 [01:56<00:00, 7.31s/it] \n"
]
}
],
"source": [
"tree = opt.search(inputs, output, size_dict)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d47c757e-ecb8-43c8-91a7-d464cb146353",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8192"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sliced_inds = tree.sliced_inds\n",
"configs = [\n",
" dict(zip(sliced_inds, x))\n",
" for x in itertools.product(*[range(size_dict[i]) for i in sliced_inds])\n",
"]\n",
"len(configs)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "771de335-5945-4909-902f-0f77b13d6c7e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.573196922320367"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.contraction_cost() / tree.unslice_all().contraction_cost()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b020426e-19a8-4700-ab4a-80dc8d714680",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24.0"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.max_size(log=2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c3802312-f19a-481b-9b17-371c28b49d37",
"metadata": {},
"outputs": [],
"source": [
"from cotengra.core import (\n",
" ContractionTree,\n",
" get_score_fn,\n",
")\n",
"from cotengra.scoring import Objective\n",
"\n",
"\n",
"class MultiObjective(Objective):\n",
" __slots__ = (\"num_configs\",)\n",
"\n",
" def __init__(self, num_configs):\n",
" self.num_configs = num_configs\n",
"\n",
" def compute_mult(self, dims):\n",
" raise NotImplementedError\n",
"\n",
" def estimate_node_mult(self, tree, node):\n",
" return self.compute_mult(\n",
" [tree.size_dict[ix] for ix in tree.get_node_var_inds(node)]\n",
" )\n",
"\n",
" def estimate_node_cache_mult(self, tree, node, sliced_ind_ordering):\n",
" node_var_inds = tree.get_node_var_inds(node)\n",
"\n",
" # indices which are the first 'k' in the sliced ordering\n",
" non_heavy_inds = [\n",
" ix\n",
" for ix in tree.get_node_var_inds(node)\n",
" if ix not in sliced_ind_ordering[: len(node_var_inds)]\n",
" ]\n",
"\n",
" # each of these cycles 'out of sync' and thus must be kept\n",
" return self.compute_mult([tree.size_dict[ix] for ix in non_heavy_inds])\n",
"\n",
"\n",
"class MultiObjectiveDense(MultiObjective):\n",
" \"\"\"Number of intermediate configurations is expected to scale as if all\n",
" configurations are present.\n",
" \"\"\"\n",
"\n",
" __slots__ = (\"num_configs\",)\n",
"\n",
" def compute_mult(self, dims):\n",
" return math.prod(dims)\n",
"\n",
"\n",
"def expected_coupons(num_sub, num_total):\n",
"    \"\"\"If we draw a random 'coupon' which can take `num_sub` different values\n",
" `num_total` times, how many unique coupons will we expect?\n",
" \"\"\"\n",
" # return min(num_sub, num_total)\n",
" return num_sub * (1 - (1 - 1 / num_sub) ** num_total)\n",
"\n",
"\n",
"class MultiObjectiveUniform(MultiObjective):\n",
" \"\"\"Number of intermediate configurations is expected to scale as if all\n",
"    configurations are randomly drawn from a uniform distribution.\n",
" \"\"\"\n",
"\n",
" __slots__ = (\"num_configs\",)\n",
"\n",
" def compute_mult(self, dims):\n",
" return expected_coupons(math.prod(dims), self.num_configs)\n",
"\n",
"\n",
"class MultiObjectiveLinear(MultiObjective):\n",
" \"\"\"Number of intermediate configurations is expected to scale linearly with\n",
" respect to number of variable indices (e.g. VMC like 'locally connected'\n",
" configurations).\n",
" \"\"\"\n",
"\n",
" __slots__ = (\"num_configs\", \"coeff\")\n",
"\n",
" def __init__(self, num_configs, coeff=1):\n",
" self.coeff = coeff\n",
" super().__init__(num_configs=num_configs)\n",
"\n",
" def compute_mult(self, dims):\n",
" return min(self.coeff * len(dims), self.num_configs)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2a6d0a84-e0a0-4e4a-a92b-fd70e89ab75b",
"metadata": {},
"outputs": [],
"source": [
"M = len(configs)\n",
"# objective = MultiObjectiveUniform(num_configs=M)\n",
"# objective = MultiObjectiveDense(num_configs=M)\n",
"objective = MultiObjectiveLinear(num_configs=M, coeff=M / len(sliced_inds))\n",
"\n",
"mtree = ctg.ContractionTreeMulti.from_path(\n",
" inputs,\n",
" output,\n",
" size_dict,\n",
" sliced_inds=sliced_inds,\n",
" objective=objective,\n",
" path=tree.get_path(),\n",
")\n",
"mtree.reorder_contractions_for_peak_est()\n",
"mtree.reorder_sliced_inds()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0a89762d-f34b-4c42-b187-c291be640341",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.multiplicity"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "8482cdf3-7f5f-423d-8bff-51188fd4272e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'flops': 5273600.0, 'write': 6913, 'size': 1024}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.contract_stats()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "2b380e14-f8ab-4e18-b83e-3c6177790856",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8.5302276561314"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log10(339021824)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "d7d99d38-bf07-40f4-9230-9893ad0fd939",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.920123326290724"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log10(\n",
" sum(\n",
" ContractionTree.get_flops(mtree, node)\n",
" for node, _, _ in mtree.descend()\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f342e882-ab08-4707-86a7-ad93bfa8ab2a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.722107185681003"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log10(sum(mtree.get_flops(node) for node, _, _ in mtree.descend()))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "38609b23-ef7c-44eb-b382-977f79d98d8c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.722107185681002"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.contraction_cost(log=10)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "bb422dce-0da9-4cc0-80c0-02657691d71c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'flops': 6203392, 'write': 248008, 'size': 1024, 'peak': 111360}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats = mtree.exact_multi_stats(configs)\n",
"stats"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "84a0c184-6e73-4a39-88c3-5b0c8c7edafa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.792629225636649"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log10(stats[\"flops\"])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8449c4a8-4e1c-43dd-b614-7dba4b28e2af",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16.98299357469431"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.peak_size(log=2)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d32c40a1-19c3-4f62-9b38-aa3d067eb838",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16.764871590736092"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log2(stats[\"peak\"])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7fd0ac7d-ad07-4654-8fbc-278e9f3d07d3",
"metadata": {},
"outputs": [],
"source": [
"from cotengra.pathfinders.path_simulated_annealing import (\n",
" _describe_tree,\n",
" _slice_tree_basic,\n",
" _slice_tree_drift,\n",
" _slice_tree_reslice,\n",
" get_rng,\n",
" linspace_generator,\n",
")\n",
"\n",
"\n",
"def compute_contracted_info(\n",
" legsa,\n",
" legsb,\n",
" appearances,\n",
" size_dict,\n",
" var_inds_a,\n",
" var_inds_b,\n",
"):\n",
" \"\"\"Compute the contracted legs, cost and size of a pair of legs.\n",
"\n",
" Parameters\n",
" ----------\n",
" legsa : dict[str, int]\n",
" The legs of the first tensor.\n",
" legsb : dict[str, int]\n",
" The legs of the second tensor.\n",
" appearances : dict[str, int]\n",
" The total number of appearances of each index in the contraction.\n",
" size_dict : dict[str, int]\n",
" The size of each index.\n",
"\n",
" Returns\n",
" -------\n",
" legsab : dict[str, int]\n",
" The contracted legs.\n",
" cost : int\n",
" The cost of the contraction.\n",
" size : int\n",
" The size of the resulting tensor.\n",
" \"\"\"\n",
" legsab = {}\n",
" cost = 1\n",
" size = 1\n",
"\n",
" # handle all left indices\n",
" for ix, ix_count in legsa.items():\n",
" d = size_dict[ix]\n",
" # all involved indices contribute to cost\n",
" cost *= d\n",
" if ix in legsb:\n",
" ix_count += legsb[ix]\n",
" if ix_count < appearances[ix]:\n",
" # index appears on output\n",
" legsab[ix] = ix_count\n",
" # and so contributes to size\n",
" size *= d\n",
"\n",
" # now handle right indices that we haven't seen yet\n",
" for ix, ix_count in legsb.items():\n",
" if ix not in legsa:\n",
" d = size_dict[ix]\n",
" cost *= d\n",
" if ix_count < appearances[ix]:\n",
" legsab[ix] = ix_count\n",
" size *= d\n",
"\n",
" return legsab, cost, size, var_inds_a | var_inds_b\n",
"\n",
"\n",
"def simulated_anneal_tree(\n",
" tree,\n",
" tfinal=0.05,\n",
" tstart=2,\n",
" tsteps=50,\n",
" numiter=50,\n",
" minimize=None,\n",
" target_size=None,\n",
" target_size_initial=None,\n",
" slice_mode=\"basic\",\n",
" seed=None,\n",
" progbar=False,\n",
" inplace=False,\n",
"):\n",
" \"\"\"Perform a simulated annealing optimization of this contraction\n",
" tree, based on \"Multi-Tensor Contraction for XEB Verification of\n",
" Quantum Circuits\" by Gleb Kalachev, Pavel Panteleev, Man-Hong Yung\n",
" (arXiv:2108.05665), and the \"treesa\" implementation in\n",
" OMEinsumContractionOrders.jl by Jin-Guo Liu and Pan Zhang.\n",
"\n",
" Parameters\n",
" ----------\n",
" tfinal : float, optional\n",
" The final temperature.\n",
" tstart : float, optional\n",
" The starting temperature.\n",
" tsteps : int, optional\n",
" The number of temperature steps.\n",
" numiter : int, optional\n",
" The number of sweeps at each temperature step.\n",
" minimize : {'flops', 'combo', 'write', 'size', ...}, optional\n",
" The objective function to minimize.\n",
" target_size : int, optional\n",
" The target size to slice the contraction to. A schedule is used to\n",
" reach this only at the final temperature step.\n",
" target_size_initial : int, optional\n",
" The initial target size to use in the slicing schedule. If None, then\n",
" the current size is used.\n",
" slice_mode : {'basic', 'reslice', 'drift'}, optional\n",
" The mode for slicing the contraction tree within each annealing\n",
" iteration. 'basic' always unslices a random index and then slices to\n",
" the target size. 'reslice' unslices all indices and then slices to the\n",
" target size. 'drift' unslices a random index with probability 1/4 and\n",
" slices to the target size with probability 3/4. It is therefore not\n",
" guaranteed to reach the target size, but may be more explorative for\n",
" long annealing schedules.\n",
" seed : int, optional\n",
" A random seed.\n",
" progbar : bool, optional\n",
" Whether to show live progress.\n",
" inplace : bool, optional\n",
" Whether to perform the optimization inplace.\n",
"\n",
" Returns\n",
" -------\n",
" ContractionTree\n",
" \"\"\"\n",
"\n",
" tree = tree if inplace else tree.copy()\n",
" # ensure stats tracking is on\n",
" tree.contract_stats()\n",
"\n",
" if minimize is None:\n",
" minimize = tree.get_default_objective()\n",
" scorer = get_score_fn(minimize)\n",
" rng = get_rng(seed)\n",
"\n",
" # create a schedule for annealing temperatures\n",
" temps = linspace_generator(tstart, tfinal, tsteps, log=True)\n",
"\n",
" if target_size is not None:\n",
" # create a schedule for slicing target sizes\n",
" if target_size_initial is None:\n",
" # start with the current size\n",
" current_size = max(tree.contraction_width(log=None), target_size)\n",
" else:\n",
" current_size = max(target_size_initial, target_size)\n",
"\n",
" target_sizes = linspace_generator(\n",
" current_size,\n",
" target_size,\n",
" tsteps,\n",
" log=True,\n",
" )\n",
" _slice_tree = {\n",
" \"basic\": _slice_tree_basic,\n",
" \"reslice\": _slice_tree_reslice,\n",
" \"drift\": _slice_tree_drift,\n",
" }[slice_mode]\n",
" else:\n",
" target_sizes = itertools.repeat(None)\n",
"\n",
" def _slice_tree(tree, current_target_size, rng):\n",
" pass\n",
"\n",
" if progbar:\n",
" import tqdm\n",
"\n",
" pbar = tqdm.tqdm(total=tsteps)\n",
" pbar.set_description(_describe_tree(tree))\n",
"\n",
" for temp in temps:\n",
" # handle slicing\n",
" _slice_tree(tree, next(target_sizes), rng)\n",
"\n",
" for _ in range(numiter):\n",
" candidates = [tree.root]\n",
"\n",
" while candidates:\n",
" p = candidates.pop(0)\n",
" l, r = tree.children[p]\n",
"\n",
" # check which local moves are possible\n",
" if len(l) == 1:\n",
" if len(r) == 1:\n",
" # both are leaves\n",
" continue\n",
" else:\n",
" # left is leaf\n",
" rule = rng.randint(2, 3)\n",
" elif len(r) == 1:\n",
" # right is leaf\n",
" rule = rng.randint(0, 1)\n",
" else:\n",
" # neither are leaves\n",
" rule = rng.randint(0, 3)\n",
"\n",
" if rule < 2:\n",
" # ((AB)C)\n",
" x, c = l, r\n",
" a, b = tree.children[x]\n",
" if rule == 0:\n",
" # -> ((AC)B)\n",
" new_order = [a, c, b]\n",
" else:\n",
" # -> (A(BC))\n",
" new_order = [b, c, a]\n",
" else:\n",
" # (A(BC))\n",
" a, x = l, r\n",
" b, c = tree.children[x]\n",
" if rule == 2:\n",
" # -> (B(AC))\n",
" new_order = [a, c, b]\n",
" else:\n",
" # -> (C(AB))\n",
" new_order = [a, b, c]\n",
"\n",
" current_score = math.log2(\n",
" tree.get_flops(p) + tree.get_flops(x)\n",
" )\n",
"\n",
" # current_score = scorer.score_local(\n",
" # flops=[tree.get_flops(p), tree.get_flops(x)],\n",
" # size=[tree.get_size(p), tree.get_size(x)],\n",
" # )\n",
"\n",
" # legs0 = tree.get_legs(new_order[0])\n",
" # legs1 = tree.get_legs(new_order[1])\n",
" # if any(ix0 in legs1 for ix0 in legs0):\n",
"\n",
" # compute new intermediate\n",
" new_legs0, new_cost0, new_size0, new_var_inds0 = (\n",
" compute_contracted_info(\n",
" tree.get_legs(new_order[0]),\n",
" tree.get_legs(new_order[1]),\n",
" tree.appearances,\n",
" tree.size_dict,\n",
" tree.get_node_var_inds(new_order[0]),\n",
" tree.get_node_var_inds(new_order[1]),\n",
" )\n",
" )\n",
"\n",
" new_mult_cost0 = new_cost0 * scorer.estimate_mult(\n",
" [tree.size_dict[ix] for ix in new_var_inds0]\n",
" )\n",
"\n",
" # compute new parent costs\n",
" new_legs1, new_cost1, new_size1, new_var_inds1 = (\n",
" compute_contracted_info(\n",
" new_legs0,\n",
" tree.get_legs(new_order[2]),\n",
" tree.appearances,\n",
" tree.size_dict,\n",
" new_var_inds0,\n",
" tree.get_node_var_inds(new_order[2]),\n",
" )\n",
" )\n",
"\n",
" new_mult_cost1 = new_cost1 * scorer.estimate_mult(\n",
" [tree.size_dict[ix] for ix in new_var_inds1]\n",
" )\n",
"\n",
" proposed_score = math.log2(new_mult_cost0 + new_mult_cost1)\n",
"\n",
" # proposed_score = scorer.score_local(\n",
" # flops=[new_cost0, new_cost1],\n",
" # size=[new_size0, new_size1],\n",
" # )\n",
"\n",
" dE = proposed_score - current_score\n",
" accept = (dE <= 0) or (math.log(rng.random()) < -dE / temp)\n",
"\n",
" if accept:\n",
" tree._remove_node(p)\n",
" tree._remove_node(x)\n",
"\n",
" tree.contract_nodes_pair(\n",
" tree.contract_nodes_pair(\n",
" new_order[0],\n",
" new_order[1],\n",
" legs=new_legs0,\n",
" cost=new_cost0,\n",
" size=new_size0,\n",
" ),\n",
" new_order[2],\n",
" legs=new_legs1,\n",
" cost=new_cost1,\n",
" size=new_size1,\n",
" )\n",
"\n",
" if progbar:\n",
" pbar.set_description(\n",
" f\"T: {temp:.2e} \" + _describe_tree(tree),\n",
" refresh=False,\n",
" )\n",
"\n",
" # check which children to recurse into\n",
" l, r = tree.children[p]\n",
" if len(l) > 2:\n",
" candidates.append(l)\n",
" if len(r) > 2:\n",
" candidates.append(r)\n",
"\n",
" if progbar:\n",
" pbar.update()\n",
"\n",
" return tree"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d61c20c0-e3a4-489a-ad78-4b9e2097bc69",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"T: 5.00e-02 F=6.708 C=6.744 S=10 P=11.74 $=1: 100%|██████████| 50/50 [00:00<00:00, 55.26it/s]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simulated_anneal_tree(mtree, inplace=True, progbar=True)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "0edc1201-8e64-4023-982f-6b5df8a28a4e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'flops': 7598080, 'write': 475336, 'size': 1024, 'peak': 61953}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats = mtree.exact_multi_stats(configs)\n",
"stats"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "8b3ce36c-f79d-4304-b585-40e049ab3915",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.880703861918839"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log10(stats[\"flops\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dc812dca-b634-490b-9db9-9fb1e63f46b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MultiObjectiveLinear(num_configs=200, coeff=2.0)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.get_default_objective()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "61d8fa34-d7da-44ab-bb61-e613c4b62b98",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.3460989920591954"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mtree.estimate_multi_total_flops() / (ContractionTree.total_flops(mtree) * M)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9f2d113c-44fd-43a1-bc3c-b26739687042",
"metadata": {},
"outputs": [],
"source": [
"configs.sort(key=lambda c: tuple(c[ix] for ix in sliced_ind_ordering))"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "cd241688-eda3-4d53-9b4b-cd2712ecf6e3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.920360304532837e-21"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats[\"flops\"] / tree.total_flops()"
]
},
{
"cell_type": "code",
"execution_count": 127,
"id": "56abeb3e-a518-4faa-a7f2-087e5cb111c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.920360304532837e-21"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats[\"flops\"] / tree.unslice_all().total_flops()"
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "d5f5ce47-0408-463b-b4a6-d7980efde2d7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"104.0"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.unslice_all().peak_size(log=2)"
]
},
{
"cell_type": "code",
"execution_count": 129,
"id": "611f63a0-3f6b-448e-98ba-b0177ace9a27",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"28.0"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ContractionTree.max_size(mtree, log=2)"
]
},
{
"cell_type": "code",
"execution_count": 130,
"id": "67ddde83-1531-4a60-af64-2ff03add40a0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"29.067312900309663"
]
},
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"math.log2(stats[\"peak\"])"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "1d32e5ae-9583-4ffe-a2ed-a9116e188920",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"