[
  {
    "path": ".codecov.yml",
    "content": "# Codecov configuration to make it a bit less noisy\ncoverage:\n  status:\n    patch: false\n    project:\n      default:\n        threshold: 50%\ncomment:\n  layout: \"header\"\n  require_changes: false\n  branches: null\n  behavior: default\n  flags: null\n  paths: null\n"
  },
  {
    "path": ".gitattributes",
    "content": "*.ipynb linguist-documentation\n*.html linguist-documentation\nespaloma/_version.py export-subst\n"
  },
  {
    "path": ".github/workflows/CI.yaml",
    "content": "name: CI\n\non:\n  pull_request:\n    branches:\n      - main\n  push:\n    branches:\n      - main \n\n  schedule:\n    # Nightly tests run on master by default:\n    #   Scheduled workflows run on the latest commit on the default or base branch.\n    #   (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)\n    - cron: \"0 0 * * *\"\n\nconcurrency:\n  group: \"${{ github.workflow }}-${{ github.ref }}\"\n  cancel-in-progress: true\n\ndefaults:\n  run:\n    shell: bash -leo pipefail {0}\n\njobs:\n  test:\n    name: ${{ matrix.os }}, Python ${{ matrix.python-version }}\n    runs-on: ${{ matrix.os }}-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        os: ['ubuntu','macos']\n        python-version:\n          - \"3.12\"\n          - \"3.11\"\n          - \"3.10\"\n\n    env:\n      OPENMM: ${{ matrix.cfg.openmm }}\n      OE_LICENSE: ${{ github.workspace }}/oe_license.txt\n\n    steps:\n      - uses: actions/checkout@v3\n      - name: Get current date\n        id: date\n        run: echo \"date=$(date +%Y-%m-%d)\" >> \"${GITHUB_OUTPUT}\"  \n      - uses: mamba-org/setup-micromamba@v1\n        with:\n          environment-file: devtools/conda-envs/espaloma.yaml\n          cache-environment: true\n          cache-downloads: true\n          cache-environment-key: environment-${{ steps.date.outputs.date }}\n          cache-downloads-key: downloads-${{ steps.date.outputs.date }}\n          create-args: >-\n            python=${{ matrix.python-version }}\n\n      - name: Additional info about the build\n        shell: bash\n        run: |\n          uname -a\n          df -h\n          ulimit -a\n\n      - name: Environment Information\n        run: |\n          micromamba info\n          micromamba list\n          micromamba --version\n\n      - name: Install package\n        run: |\n          python -m pip install --no-deps -e .\n\n      - name: Run tests\n        run: |\n          pytest -v --cov=espaloma --cov-report=xml --color=yes espaloma/\n\n      - name: CodeCov\n        uses: codecov/codecov-action@v3\n        if: ${{ github.repository == 'choderalab/espaloma'\n                && github.event_name == 'pull_request' }} \n        with:\n          token: ${{ secrets.CODECOV_TOKEN }}\n          file: ./coverage.xml\n          flags: unittests\n          yml: ./.codecov.yml\n          fail_ci_if_error: False\n          verbose: True\n"
  },
  {
    "path": ".github/workflows/clean_cache.yaml",
    "content": "# from https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries\nname: cleanup caches by a branch\non:\n  pull_request:\n    types:\n      - closed\n\njobs:\n  cleanup:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check out code\n        uses: actions/checkout@v3\n        \n      - name: Cleanup\n        run: |\n          gh extension install actions/gh-actions-cache\n          \n          REPO=${{ github.repository }}\n          BRANCH=\"refs/pull/${{ github.event.pull_request.number }}/merge\"\n\n          echo \"Fetching list of cache key\"\n          cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 )\n\n          ## Setting this to not fail the workflow while deleting cache keys. \n          set +e\n          echo \"Deleting caches...\"\n          for cacheKey in $cacheKeysForPR\n          do\n              gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm\n          done\n          echo \"Done\"\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/docker.yaml",
    "content": "# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n# separate terms of service, privacy policy, and support\n# documentation.\n\n# GitHub recommends pinning actions to a commit SHA.\n# To get a newer version, you will need to update the SHA.\n# You can also reference a tag or branch, but the action may change without warning.\n\nname: Create and publish a Docker image\n\non:\n  workflow_dispatch:\n\ndefaults:\n  run:\n    shell: bash -l {0}\n\nenv:\n  REGISTRY: ghcr.io\n  IMAGE_NAME: choderalab/espaloma\n\njobs:\n  build-and-push-image:\n    runs-on: ubuntu-latest\n    permissions:\n      contents: read\n      packages: write\n\n    steps:\n      - name: Free disk space\n        run: |\n          sudo docker rmi $(docker image ls -aq) >/dev/null 2>&1 || true\n          sudo rm -rf \\\n            /usr/share/dotnet /usr/local/lib/android /opt/ghc \\\n            /usr/local/share/powershell /usr/share/swift /usr/local/.ghcup \\\n            /usr/lib/jvm || true\n          echo \"some directories deleted\"\n          sudo apt install aptitude -y >/dev/null 2>&1\n          sudo aptitude purge aria2 ansible azure-cli shellcheck rpm xorriso zsync \\\n            esl-erlang firefox gfortran-8 gfortran-9 google-chrome-stable \\\n            google-cloud-sdk imagemagick \\\n            libmagickcore-dev libmagickwand-dev libmagic-dev ant ant-optional kubectl \\\n            mercurial apt-transport-https mono-complete libmysqlclient \\\n            unixodbc-dev yarn chrpath libssl-dev libxft-dev \\\n            libfreetype6 libfreetype6-dev libfontconfig1 libfontconfig1-dev \\\n            snmp pollinate libpq-dev postgresql-client powershell ruby-full \\\n            sphinxsearch subversion mongodb-org azure-cli microsoft-edge-stable \\\n            -y -f >/dev/null 2>&1\n          sudo aptitude purge google-cloud-sdk -f -y >/dev/null 2>&1\n          sudo aptitude purge microsoft-edge-stable -f -y >/dev/null 2>&1 || true\n          sudo apt purge microsoft-edge-stable -f -y >/dev/null 2>&1 || true\n          sudo aptitude purge '~n ^mysql' -f -y >/dev/null 2>&1\n          sudo aptitude purge '~n ^php' -f -y >/dev/null 2>&1\n          sudo aptitude purge '~n ^dotnet' -f -y >/dev/null 2>&1\n          sudo apt-get autoremove -y >/dev/null 2>&1\n          sudo apt-get autoclean -y >/dev/null 2>&1\n          echo \"some packages purged\"\n\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          fetch-depth: 0\n\n      - name: Get Latest Version\n        id: latest-version\n        run: |\n          LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1))\n          echo $LATEST_TAG\n          echo \"LATEST_TAG=$LATEST_TAG\" >> $GITHUB_OUTPUT\n          VERSION=$LATEST_TAG\n          echo $VERSION\n          echo \"VERSION=$VERSION\" >> $GITHUB_OUTPUT\n\n      - name: Print Latest Version\n        run: echo ${{ steps.latest-version.outputs.VERSION }}\n\n      # Now that we got the version, we don't need the .git folder\n      - name: Get more space\n        run: |\n          df . -h\n          sudo rm -rf ${GITHUB_WORKSPACE}/.git\n          df . 
-h\n\n      - name: Create fully qualified image registry path\n        id: fqirp\n        run: |\n          FQIRP=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.latest-version.outputs.VERSION }}\n          echo \"FQIRP=$FQIRP\" >> $GITHUB_OUTPUT\n\n      - name: Print FQIRP\n        run: echo ${{ steps.fqirp.outputs.FQIRP  }}\n\n      - name: Log in to the Container registry\n        uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9\n        with:\n          registry: ${{ env.REGISTRY }}\n          username: ${{ github.actor }}\n          password: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Extract metadata (tags, labels) for Docker\n        id: meta\n        uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38\n        with:\n          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}\n          tags: |\n            type=schedule,pattern=nightly,enable=true,priority=1000\n            type=ref,event=branch,enable=true,priority=600\n            type=ref,event=tag,enable=true,priority=600\n            type=ref,event=pr,prefix=pr-,enable=true,priority=600\n            type=semver,pattern={{major}}.{{minor}}\n            type=semver,pattern={{version}}\n            type=sha\n            ${{ steps.latest-version.outputs.VERSION }}\n\n      - name: Build and export to Docker\n        uses: docker/build-push-action@v4\n        with:\n          context: .\n          file: docker/Dockerfile\n          load: true\n          push: false\n          tags: ${{ steps.meta.outputs.tags }}\n          labels: ${{ steps.meta.outputs.labels }}\n          build-args: |\n            VERSION=${{ steps.latest-version.outputs.VERSION }}\n\n      - name: Test image\n        run: |\n          docker run --rm ${{ steps.fqirp.outputs.FQIRP }} python -c \"import espaloma; print(espaloma.__version__)\"\n          docker run --rm ${{ steps.fqirp.outputs.FQIRP }} pytest --pyargs espaloma -v\n\n      - name: Push Docker image\n        uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc\n        with:\n          context: .\n          file: docker/Dockerfile\n          push: true\n          tags: ${{ steps.meta.outputs.tags }}\n          labels: ${{ steps.meta.outputs.labels }}\n          build-args: |\n            VERSION=${{ steps.latest-version.outputs.VERSION }}\n\n      - name: Setup Apptainer\n        uses: eWaterCycle/setup-apptainer@v2\n        with:\n          apptainer-version: 1.1.2\n\n      - name: Build Apptainer Image\n        run: singularity build espaloma_${{ steps.latest-version.outputs.VERSION }}.sif docker-daemon:${{ steps.fqirp.outputs.FQIRP }}\n\n      - name: Test & Push Apptainer Image\n        run: |\n          mkdir test_apptainer\n          cd test_apptainer\n          singularity run ../espaloma_${{ steps.latest-version.outputs.VERSION }}.sif pytest --pyargs espaloma -v\n          echo ${{ secrets.GITHUB_TOKEN }} | singularity remote login -u ${{ secrets.GHCR_USERNAME }} --password-stdin oras://ghcr.io\n          singularity push ../espaloma_${{ steps.latest-version.outputs.VERSION }}.sif oras://${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.latest-version.outputs.VERSION }}-apptainer\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# data\n*.sdf\n*.csv\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# Parm@Frosst download\nparm_at_Frosst.tgz\n\n# misc\n.DS_Store\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "version: 2\n\nbuild:\n  os: \"ubuntu-20.04\"\n  tools:\n    python: \"mambaforge-4.10\"\n\nsphinx:\n   configuration: docs/conf.py\n   fail_on_warning: false\n\nconda:\n  environment: devtools/conda-envs/espaloma.yaml\n\npython:\n  # Install our python package before building the docs\n  install:\n    - method: pip\n      path: .\n"
  },
  {
    "path": "LICENSE",
    "content": "\nMIT License\n\nCopyright (c) 2020 Yuanqing Wang @ choderalab // MSKCC\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include LICENSE\ninclude MANIFEST.in\ninclude versioneer.py\n\ngraft espaloma\nglobal-exclude *.py[cod] __pycache__ *.so"
  },
  {
    "path": "README.md",
    "content": "espaloma: **E**xtensible **S**urrogate **P**otenti**al** **O**ptimized by **M**essage-passing **A**lgorithms 🍹\n==============================\n[//]: # (Badges)\n[![CI](https://github.com/choderalab/espaloma/actions/workflows/CI.yaml/badge.svg?branch=main)](https://github.com/choderalab/espaloma/actions/workflows/CI.yaml)\n[![Documentation Status](https://readthedocs.org/projects/espaloma/badge/?version=latest)](https://espaloma.readthedocs.io/en/latest/?badge=latest)\n\nSource code for [Wang Y, Fass J, and Chodera JD \"End-to-End Differentiable Construction of Molecular Mechanics Force Fields.\"](https://arxiv.org/abs/2010.01196)\n\n![abstract](docs/_static/espaloma_abstract_v2-2.png)\n\n#\nDocumentation: https://docs.espaloma.org\n\n# Paper Abstract\nMolecular mechanics (MM) potentials have long been a workhorse of computational chemistry.\nLeveraging accuracy and speed, these functional forms find use in a wide variety of applications in biomolecular modeling and drug discovery, from rapid virtual screening to detailed free energy calculations.\nTraditionally, MM potentials have relied on human-curated, inflexible, and poorly extensible discrete chemical perception rules _atom types_ for applying parameters to small molecules or biopolymers, making it difficult to optimize both types and parameters to fit quantum chemical or physical property data.\nHere, we propose an alternative approach that uses _graph neural networks_ to perceive chemical environments, producing continuous atom embeddings from which valence and nonbonded parameters can be predicted using invariance-preserving layers.\nSince all stages are built from smooth neural functions, the entire process---spanning chemical perception to parameter assignment---is modular and end-to-end differentiable with respect to model parameters, allowing new force fields to be easily constructed, extended, and applied to arbitrary molecules.\nWe show that this approach is not only sufficiently expressive to reproduce legacy atom types, but that it can learn and extend existing molecular mechanics force fields, construct entirely new force fields applicable to both biopolymers and small molecules from quantum chemical calculations, and even learn to accurately predict free energies from experimental observables.\n\n\n# Installation\n\nWe recommend using [`mamba`](https://mamba.readthedocs.io/en/latest/mamba-installation.html#mamba-installation) which is a drop-in replacement for `conda` and is much faster.   
\n\n```bash\n$ mamba create --name espaloma -c conda-forge \"espaloma=0.3.2\"\n```\n\n# Example: Deploy espaloma 0.3.2 pretrained force field to arbitrary MM system\n\n```python  \n# imports\nimport os\nimport torch\nimport espaloma as esp\n\n# define or load a molecule of interest via the Open Force Field toolkit\nfrom openff.toolkit.topology import Molecule\nmolecule = Molecule.from_smiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n\n# create an Espaloma Graph object to represent the molecule of interest\nmolecule_graph = esp.Graph(molecule)\n\n# load pretrained model\nespaloma_model = esp.get_model(\"latest\")\n\n# apply a trained espaloma model to assign parameters\nespaloma_model(molecule_graph.heterograph)\n\n# create an OpenMM System for the specified molecule\nopenmm_system = esp.graphs.deploy.openmm_system_from_graph(molecule_graph)\n```\n\nIf using espaloma from a local `.pt` file, say for example `espaloma-0.3.2.pt`,\nthen you would need to run the `eval` method of the model to get the correct\ninference/predictions, as follows:\n\n```python\nimport torch\n...\n# load local pretrained model\nespaloma_model = torch.load(\"espaloma-0.3.2.pt\")\nespaloma_model.eval()\n...\n```\n\nThe rest of the code should be the same as in the previous code block example.\n\n# Compatible models\n\nBelow is a compatibility matrix for different versions of `espaloma` code and `espaloma` models (the `.pt` file).\n\n| Model 🧪             | DOI 📝 | Supported Espaloma version 💻 | Release Date 🗓️ | Espaloma architecture change 📐? |\n|---------------------|-------|------------------------------|----------------|----------------------------------|\n| `espaloma-0.3.2.pt` |       | 0.3.1, 0.3.2, 0.4.0          | Sep 22, 2023   | ✅ No                            |\n| `espaloma-0.3.1.pt` |       | 0.3.1, 0.3.2, 0.4.0          | Jul 17, 2023   | ⚠️ Yes                           |\n| `espaloma-0.3.0.pt` |       | 0.3.0                        | Apr 26, 2023   | ⚠️Yes                            |\n\n> [!NOTE]  \n> `espaloma-0.3.1.pt` and `espaloma-0.3.2.pt` are the same model.\n\n# Using espaloma to parameterize small molecules in relative free energy calculations\n\nAn example of using espaloma to parameterize small molecules in relative alchemical free energy calculations is provided in the `scripts/perses-benchmark/` directory.\n\n# Manifest\n\n* `espaloma/` core code for graph-parametrized potential energy functions.\n    * `graphs/` data objects that contain various level of information we need.\n        * `graph.py` base modules for graphs.\n        * `molecule_graph.py` provide APIs to various molecular modelling toolkits.\n        * `homogeneous_graph.py` simplest graph representation of a molecule.\n        * `heterogeneous_graph.py` graph representation of a molecule that contains information regarding membership of lower-level nodes to higher-level nodes.\n        * `parametrized_graph.py` graph representation of a molecule with all parameters needed for energy evaluation.\n    * `nn/` neural network models that facilitates translation between graphs.\n        * `dgl_legacy.py` API to dgl models for atom-level message passing.\n    * `mm/` molecular mechanics functionalities for energy evaluation.\n        * `i/` energy terms used in Class-I force field.\n            * `bond.py` bond energy\n            * `angle.py` angle energy\n            * `torsion.py` torsion energy\n            * `nonbonded.py` nonbonded energy\n        * `ii/` energy terms used in Class-II force field.\n            * `coupling.py` 
coupling terms\n            * `polynomial.py` higher order polynomials.\n\n# License\n\nThis software is licensed under [MIT license](https://opensource.org/licenses/MIT).\n\n# Copyright\n\nCopyright (c) 2020, Chodera Lab at Memorial Sloan Kettering Cancer Center and Authors:\nAuthors:\n- [Yuanqing Wang](http://www.wangyq.net)\n- Josh Fass\n- John D. Chodera\n"
  },
  {
    "path": "devtools/README.md",
    "content": "# Development, testing, and deployment tools\n\nThis directory contains a collection of tools for running Continuous Integration (CI) tests, \nconda installation, and other development tools not directly related to the coding process.\n\n\n## Manifest\n\n### Continuous Integration\n\nYou should test your code, but do not feel compelled to use these specific programs. You also may not need Unix and \nWindows testing if you only plan to deploy on specific platforms. These are just to help you get started\n\n* `travis-ci`: Linux and OSX based testing through [Travis-CI](https://about.travis-ci.com/) \n  * `before_install.sh`: Pip/Miniconda pre-package installation script for Travis \n* `appveyor`: Windows based testing through [AppVeyor](https://www.appveyor.com/) (there are no files directly related to this)\n\n### Conda Environment:\n\nThis directory contains the files to setup the Conda environment for testing purposes\n\n* `conda-envs`: directory containing the YAML file(s) which fully describe Conda Environments, their dependencies, and those dependency provenance's\n  * `test_env.yaml`: Simple test environment file with base dependencies. Channels are not specified here and therefore respect global Conda configuration\n  \n### Additional Scripts:\n\nThis directory contains OS agnostic helper scripts which don't fall in any of the previous categories\n* `scripts`\n  * `create_conda_env.py`: Helper program for spinning up new conda environments based on a starter file with Python Version and Env. Name command-line options\n\n\n## How to contribute changes\n- Clone the repository if you have write access to the main repo, fork the repository if you are a collaborator.\n- Make a new branch with `git checkout -b {your branch name}`\n- Make changes and test your code\n- Ensure that the test environment dependencies (`conda-envs`) line up with the build and deploy dependencies (`conda-recipe/meta.yaml`)\n- Push the branch to the repo (either the main or your fork) with `git push -u origin {your branch name}`\n  * Note that `origin` is the default name assigned to the remote, yours may be different\n- Make a PR on GitHub with your changes\n- We'll review the changes and get your code into the repo after lively discussion!\n\n\n## Checklist for updates\n- [ ] Make sure there is an/are issue(s) opened for your specific update\n- [ ] Create the PR, referencing the issue\n- [ ] Debug the PR as needed until tests pass\n- [ ] Tag the final, debugged version \n   *  `git tag -a X.Y.Z [latest pushed commit] && git push --follow-tags`\n- [ ] Get the PR merged in\n\n## Versioneer Auto-version\n[Versioneer](https://github.com/warner/python-versioneer) will automatically infer what version \nis installed by looking at the `git` tags and how many commits ahead this version is. The format follows \n[PEP 440](https://www.python.org/dev/peps/pep-0440/) and has the regular expression of:\n```regexp\n\\d+.\\d+.\\d+(?\\+\\d+-[a-z0-9]+)\n```\nIf the version of this commit is the same as a `git` tag, the installed version is the same as the tag, \ne.g. `espaloma-0.1.2`, otherwise it will be appended with `+X` where `X` is the number of commits \nahead from the last tag, and then `-YYYYYY` where the `Y`'s are replaced with the `git` commit hash.\n"
  },
  {
    "path": "devtools/conda-envs/espaloma.yaml",
    "content": "name: espaloma-test\nchannels:\n  - conda-forge\n  - openeye\ndependencies:\n  # Base dependencies\n  - python\n  - pip\n  # 3rd party\n  - openeye-toolkits\n  - numpy\n  - matplotlib\n  - scipy\n  - openff-toolkit >=0.12\n  - openff-forcefields\n  - openff-units\n  - smirnoff99frosst>=1.1.0.1  #https://github.com/openforcefield/smirnoff99Frosst/issues/109\n  - openmm\n  - openmmforcefields >=0.11.2\n  - tqdm\n  - pydantic <2  # We need our deps to fix this\n  - qcportal >=0.50\n  - dgl =2.3.0\n  - torchdata <= 0.10.0\n  # Testing\n  - pytest\n  - pytest-cov\n  - pytest-xdist\n  - pytest-randomly\n  - codecov\n  - nose\n  - nose-timer\n  - coverage\n  - sphinx\n  - sphinx_rtd_theme\n"
  },
  {
    "path": "devtools/conda-recipe/build.sh",
    "content": "pip install .\n"
  },
  {
    "path": "devtools/conda-recipe/meta.yml",
    "content": "package:\n  name: espaloma\n  version: !!str 0.0.0\n\nsource:\n  path: ../../\n\nbuild:\n  preserve_egg_dir: True\n  number: 0\n\nrequirements:\n  build:\n    - python\n    - setuptools\n    - numpy >=1.14\n\n  run:\n    - python\n    - pip\n    - openeye-toolkits\n    - numpy\n    - matplotlib\n    - scipy\n    - openff-toolkit\n    - openff-forcefields\n    - smirnoff99Frosst\n    - openmm\n    - openmmforcefields\n    - pytorch\n    - dgl\n    - pytest\n    - pytest-cov\n    - codecov\n    - nose\n    - nose-timer\n    - coverage\n    - qcportal\n    - torchdata <= 0.10.0\n\nabout:\n  home: https://github.com/choderalab/perses\n  license: MIT\n  license_file: LICENSE\n"
  },
  {
    "path": "devtools/gh-actions/initialize_conda.sh",
    "content": "case $CI_OS in\n    windows*)\n        eval \"$(${CONDA}/condabin/conda.bat shell.bash hook)\";;\n    *)\n        eval \"$(${CONDA}/condabin/conda shell.bash hook)\";;\nesac"
  },
  {
    "path": "devtools/scripts/create_conda_env.py",
    "content": "import argparse\nimport glob\nimport os\nimport re\nimport shutil\nimport subprocess as sp\nfrom contextlib import contextmanager\nfrom tempfile import TemporaryDirectory\n\n# YAML imports\ntry:\n    import yaml  # PyYAML\n    loader = yaml.load\nexcept ImportError:\n    try:\n        import ruamel_yaml as yaml  # Ruamel YAML\n    except ImportError:\n        try:\n            # Load Ruamel YAML from the base conda environment\n            from importlib import util as import_util\n            CONDA_BIN = os.path.dirname(os.environ['CONDA_EXE'])\n            ruamel_yaml_path = glob.glob(os.path.join(CONDA_BIN, '..',\n                                                      'lib', 'python*.*', 'site-packages',\n                                                      'ruamel_yaml', '__init__.py'))[0]\n            # Based on importlib example, but only needs to load_module since its the whole package, not just\n            # a module\n            spec = import_util.spec_from_file_location('ruamel_yaml', ruamel_yaml_path)\n            yaml = spec.loader.load_module()\n        except (KeyError, ImportError, IndexError):\n            raise ImportError(\"No YAML parser could be found in this or the conda environment. \"\n                              \"Could not find PyYAML or Ruamel YAML in the current environment, \"\n                              \"AND could not find Ruamel YAML in the base conda environment through CONDA_EXE path. \" \n                              \"Environment not created!\")\n    loader = yaml.YAML(typ=\"safe\").load  # typ=\"safe\" avoids odd typing on output\n\n\n@contextmanager\ndef temp_cd():\n    \"\"\"Temporary CD Helper\"\"\"\n    cwd = os.getcwd()\n    with TemporaryDirectory() as td:\n        try:\n            os.chdir(td)\n            yield\n        finally:\n            os.chdir(cwd)\n\n\n# Args\nparser = argparse.ArgumentParser(description='Creates a conda environment from file for a given Python version.')\nparser.add_argument('-n', '--name', type=str,\n                    help='The name of the created Python environment')\nparser.add_argument('-p', '--python', type=str,\n                    help='The version of the created Python environment')\nparser.add_argument('conda_file',\n                    help='The file for the created Python environment')\n\nargs = parser.parse_args()\n\n# Open the base file\nwith open(args.conda_file, \"r\") as handle:\n    yaml_script = loader(handle.read())\n\npython_replacement_string = \"python {}*\".format(args.python)\n\ntry:\n    for dep_index, dep_value in enumerate(yaml_script['dependencies']):\n        if re.match('python([ ><=*]+[0-9.*]*)?$', dep_value):  # Match explicitly 'python' and its formats\n            yaml_script['dependencies'].pop(dep_index)\n            break  # Making the assumption there is only one Python entry, also avoids need to enumerate in reverse\nexcept (KeyError, TypeError):\n    # Case of no dependencies key, or dependencies: None\n    yaml_script['dependencies'] = []\nfinally:\n    # Ensure the python version is added in. 
Even if the code does not need it, we assume the env does\n    yaml_script['dependencies'].insert(0, python_replacement_string)\n\n# Figure out conda path\nif \"CONDA_EXE\" in os.environ:\n    conda_path = os.environ[\"CONDA_EXE\"]\nelse:\n    conda_path = shutil.which(\"conda\")\nif conda_path is None:\n    raise RuntimeError(\"Could not find a conda binary in CONDA_EXE variable or in executable search path\")\n\nprint(\"CONDA ENV NAME  {}\".format(args.name))\nprint(\"PYTHON VERSION  {}\".format(args.python))\nprint(\"CONDA FILE NAME {}\".format(args.conda_file))\nprint(\"CONDA PATH      {}\".format(conda_path))\n\n# Write to a temp directory which will always be cleaned up\nwith temp_cd():\n    temp_file_name = \"temp_script.yaml\"\n    with open(temp_file_name, 'w') as f:\n        f.write(yaml.dump(yaml_script))\n    sp.call(\"{} env create -n {} -f {}\".format(conda_path, args.name, temp_file_name), shell=True)\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM mambaorg/micromamba:1.4.9\n\nLABEL org.opencontainers.image.source=https://github.com/choderalab/espaloma\nLABEL org.opencontainers.image.description=\"Extensible Surrogate Potential of Ab initio Learned and Optimized by Message-passing Algorithm\"\nLABEL org.opencontainers.image.licenses=MIT\n# OpenFE Version we want to build\nARG VERSION\n\n# Don't buffer stdout & stderr streams, so if there is a crash no partial buffer output is lost\n# https://docs.python.org/3/using/cmdline.html#cmdoption-u\nENV PYTHONUNBUFFERED=1\n\nRUN micromamba install -y -n base -c conda-forge -c dglteam pytest \"dgl<1\" git \"espaloma==$VERSION\" && \\\n    micromamba clean --all --yes\n\n# Ensure that conda environment is automatically activated\n# https://github.com/mamba-org/micromamba-docker#running-commands-in-dockerfile-within-the-conda-environment\nARG MAMBA_DOCKERFILE_ACTIVATE=1\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = espaloma\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/README.md",
    "content": "# Compiling espaloma's Documentation\n\nThe docs for this project are built with [Sphinx](http://www.sphinx-doc.org/en/master/).\nTo compile the docs, first ensure that Sphinx and the ReadTheDocs theme are installed.\n\n\n```bash\nconda install sphinx sphinx_rtd_theme \n```\n\n\nOnce installed, you can use the `Makefile` in this directory to compile static HTML pages by\n```bash\nmake html\n```\n\nThe compiled docs will be in the `_build` directory and can be viewed by opening `index.html` (which may itself \nbe inside a directory called `html/` depending on what version of Sphinx is installed)."
  },
  {
    "path": "docs/_static/README.md",
    "content": "# Static Doc Directory\n\nAdd any paths that contain custom static files (such as style sheets) here,\nrelative to the `conf.py` file's directory. \nThey are copied after the builtin static files,\nso a file named \"default.css\" will overwrite the builtin \"default.css\".\n\nThe path to this folder is set in the Sphinx `conf.py` file in the line: \n```python\ntemplates_path = ['_static']\n```\n\n## Examples of file to add to this directory\n* Custom Cascading Style Sheets\n* Custom JavaScript code\n* Static logo images\n"
  },
  {
    "path": "docs/_templates/README.md",
    "content": "# Templates Doc Directory\n\nAdd any paths that contain templates here, relative to  \nthe `conf.py` file's directory.\nThey are copied after the builtin template files,\nso a file named \"page.html\" will overwrite the builtin \"page.html\".\n\nThe path to this folder is set in the Sphinx `conf.py` file in the line: \n```python\nhtml_static_path = ['_templates']\n```\n\n## Examples of file to add to this directory\n* HTML extensions of stock pages like `page.html` or `layout.html`\n"
  },
  {
    "path": "docs/_templates/custom-class-template.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   {% block methods %}\n   .. automethod:: __init__\n\n   {% if methods %}\n   .. rubric:: {{ _('Methods') }}\n\n   .. autosummary::\n   {% for item in methods %}\n      ~{{ name }}.{{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block attributes %}\n   {% if attributes %}\n   .. rubric:: {{ _('Attributes') }}\n\n   .. autosummary::\n   {% for item in attributes %}\n      ~{{ name }}.{{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n"
  },
  {
    "path": "docs/_templates/custom-module-template.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. automodule:: {{ fullname }}\n  \n   {% block attributes %}\n   {% if attributes %}\n   .. rubric:: Module Attributes\n\n   .. autosummary::\n      :toctree:\n   {% for item in attributes %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block functions %}\n   {% if functions %}\n   .. rubric:: {{ _('Functions') }}\n\n   .. autosummary::\n      :toctree:\n   {% for item in functions %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block classes %}\n   {% if classes %}\n   .. rubric:: {{ _('Classes') }}\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   {% for item in classes %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block exceptions %}\n   {% if exceptions %}\n   .. rubric:: {{ _('Exceptions') }}\n\n   .. autosummary::\n      :toctree:\n   {% for item in exceptions %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n{% block modules %}\n{% if modules %}\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n{% for item in modules %}\n   {{ item }}\n{%- endfor %}\n{% endif %}\n{% endblock %}\n"
  },
  {
    "path": "docs/api.rst",
    "content": "API Documentation\n=================\n\n.. autosummary::\n   :toctree: autosummary\n   :template: custom-module-template.rst\n   :recursive:\n\n   espaloma.mm\n   espaloma.nn\n   espaloma.graphs\n   espaloma.data \n"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.alkethoh.rst",
    "content": "espaloma.data.collection.alkethoh\n=================================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autofunction:: alkethoh"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.esol.rst",
    "content": "espaloma.data.collection.esol\n=============================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autofunction:: esol"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.md17_new.rst",
    "content": "espaloma.data.collection.md17\\_new\n==================================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autofunction:: md17_new"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.md17_old.rst",
    "content": "espaloma.data.collection.md17\\_old\n==================================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autofunction:: md17_old"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.qca.rst",
    "content": "espaloma.data.collection.qca\n============================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autoclass:: qca\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~qca.__init__\n      ~qca.bayer\n      ~qca.benchmark\n      ~qca.coverage\n      ~qca.emolecules\n      ~qca.fda\n      ~qca.pfizer\n      ~qca.roche\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.rst",
    "content": "espaloma.data.collection\n========================\n\n.. automodule:: espaloma.data.collection\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      alkethoh\n      esol\n      md17_new\n      md17_old\n      zinc\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      qca\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.collection.zinc.rst",
    "content": "espaloma.data.collection.zinc\n=============================\n\n.. currentmodule:: espaloma.data.collection\n\n.. autofunction:: zinc"
  },
  {
    "path": "docs/autosummary/espaloma.data.dataset.Dataset.rst",
    "content": "espaloma.data.dataset.Dataset\n=============================\n\n.. currentmodule:: espaloma.data.dataset\n\n.. autoclass:: Dataset\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~Dataset.__init__\n      ~Dataset.apply\n      ~Dataset.load\n      ~Dataset.save\n      ~Dataset.shuffle\n      ~Dataset.split\n      ~Dataset.subsample\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.dataset.GraphDataset.rst",
    "content": "espaloma.data.dataset.GraphDataset\n==================================\n\n.. currentmodule:: espaloma.data.dataset\n\n.. autoclass:: GraphDataset\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~GraphDataset.__init__\n      ~GraphDataset.apply\n      ~GraphDataset.batch\n      ~GraphDataset.load\n      ~GraphDataset.save\n      ~GraphDataset.shuffle\n      ~GraphDataset.split\n      ~GraphDataset.subsample\n      ~GraphDataset.view\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.dataset.rst",
    "content": "espaloma.data.dataset\n=====================\n\n.. automodule:: espaloma.data.dataset\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      Dataset\n      GraphDataset\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.md.MoleculeVacuumSimulation.rst",
    "content": "espaloma.data.md.MoleculeVacuumSimulation\n=========================================\n\n.. currentmodule:: espaloma.data.md\n\n.. autoclass:: MoleculeVacuumSimulation\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~MoleculeVacuumSimulation.__init__\n      ~MoleculeVacuumSimulation.run\n      ~MoleculeVacuumSimulation.simulation_from_graph\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.md.rst",
    "content": "espaloma.data.md\n================\n\n.. automodule:: espaloma.data.md\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      subtract_nonbonded_force\n      subtract_nonbonded_force_except_14\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      MoleculeVacuumSimulation\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.md.subtract_nonbonded_force.rst",
    "content": "espaloma.data.md.subtract\\_nonbonded\\_force\n===========================================\n\n.. currentmodule:: espaloma.data.md\n\n.. autofunction:: subtract_nonbonded_force"
  },
  {
    "path": "docs/autosummary/espaloma.data.md.subtract_nonbonded_force_except_14.rst",
    "content": "espaloma.data.md.subtract\\_nonbonded\\_force\\_except\\_14\n=======================================================\n\n.. currentmodule:: espaloma.data.md\n\n.. autofunction:: subtract_nonbonded_force_except_14"
  },
  {
    "path": "docs/autosummary/espaloma.data.md17_utils.get_molecule.rst",
    "content": "espaloma.data.md17\\_utils.get\\_molecule\n=======================================\n\n.. currentmodule:: espaloma.data.md17_utils\n\n.. autofunction:: get_molecule"
  },
  {
    "path": "docs/autosummary/espaloma.data.md17_utils.realize_molecule.rst",
    "content": "espaloma.data.md17\\_utils.realize\\_molecule\n===========================================\n\n.. currentmodule:: espaloma.data.md17_utils\n\n.. autofunction:: realize_molecule"
  },
  {
    "path": "docs/autosummary/espaloma.data.md17_utils.rst",
    "content": "espaloma.data.md17\\_utils\n=========================\n\n.. automodule:: espaloma.data.md17_utils\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      get_molecule\n      realize_molecule\n      sum_offsets\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.md17_utils.sum_offsets.rst",
    "content": "espaloma.data.md17\\_utils.sum\\_offsets\n======================================\n\n.. currentmodule:: espaloma.data.md17_utils\n\n.. autofunction:: sum_offsets"
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.BaseNormalize.rst",
    "content": "espaloma.data.normalize.BaseNormalize\n=====================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: BaseNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~BaseNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.DatasetLogNormalNormalize.rst",
    "content": "espaloma.data.normalize.DatasetLogNormalNormalize\n=================================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: DatasetLogNormalNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~DatasetLogNormalNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.DatasetNormalNormalize.rst",
    "content": "espaloma.data.normalize.DatasetNormalNormalize\n==============================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: DatasetNormalNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~DatasetNormalNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.ESOL100LogNormalNormalize.rst",
    "content": "espaloma.data.normalize.ESOL100LogNormalNormalize\n=================================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: ESOL100LogNormalNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~ESOL100LogNormalNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.ESOL100NormalNormalize.rst",
    "content": "espaloma.data.normalize.ESOL100NormalNormalize\n==============================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: ESOL100NormalNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~ESOL100NormalNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.NotNormalize.rst",
    "content": "espaloma.data.normalize.NotNormalize\n====================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: NotNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~NotNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.PositiveNotNormalize.rst",
    "content": "espaloma.data.normalize.PositiveNotNormalize\n============================================\n\n.. currentmodule:: espaloma.data.normalize\n\n.. autoclass:: PositiveNotNormalize\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~PositiveNotNormalize.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.normalize.rst",
    "content": "espaloma.data.normalize\n=======================\n\n.. automodule:: espaloma.data.normalize\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      BaseNormalize\n      DatasetLogNormalNormalize\n      DatasetNormalNormalize\n      ESOL100LogNormalNormalize\n      ESOL100NormalNormalize\n      NotNormalize\n      PositiveNotNormalize\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.MolWithTargets.rst",
    "content": "espaloma.data.qcarchive\\_utils.MolWithTargets\n=============================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autoclass:: MolWithTargets\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~MolWithTargets.__init__\n      ~MolWithTargets.count\n      ~MolWithTargets.index\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~MolWithTargets.energies\n      ~MolWithTargets.gradients\n      ~MolWithTargets.offmol\n      ~MolWithTargets.xyz\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.breakdown_along_time_axis.rst",
    "content": "espaloma.data.qcarchive\\_utils.breakdown\\_along\\_time\\_axis\n===========================================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: breakdown_along_time_axis"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.fetch_td_record.rst",
    "content": "espaloma.data.qcarchive\\_utils.fetch\\_td\\_record\n================================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: fetch_td_record"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.get_client.rst",
    "content": "espaloma.data.qcarchive\\_utils.get\\_client\n==========================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: get_client"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.get_collection.rst",
    "content": "espaloma.data.qcarchive\\_utils.get\\_collection\n==============================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: get_collection"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.get_energy_and_gradient.rst",
    "content": "espaloma.data.qcarchive\\_utils.get\\_energy\\_and\\_gradient\n=========================================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: get_energy_and_gradient"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.get_graph.rst",
    "content": "espaloma.data.qcarchive\\_utils.get\\_graph\n=========================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: get_graph"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.h5_to_dataset.rst",
    "content": "espaloma.data.qcarchive\\_utils.h5\\_to\\_dataset\n==============================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: h5_to_dataset"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.make_batch_size_consistent.rst",
    "content": "espaloma.data.qcarchive\\_utils.make\\_batch\\_size\\_consistent\n============================================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: make_batch_size_consistent"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.rst",
    "content": "espaloma.data.qcarchive\\_utils\n==============================\n\n.. automodule:: espaloma.data.qcarchive_utils\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      breakdown_along_time_axis\n      fetch_td_record\n      get_client\n      get_collection\n      get_energy_and_gradient\n      get_graph\n      h5_to_dataset\n      make_batch_size_consistent\n      weight_by_snapshots\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      MolWithTargets\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.qcarchive_utils.weight_by_snapshots.rst",
    "content": "espaloma.data.qcarchive\\_utils.weight\\_by\\_snapshots\n====================================================\n\n.. currentmodule:: espaloma.data.qcarchive_utils\n\n.. autofunction:: weight_by_snapshots"
  },
  {
    "path": "docs/autosummary/espaloma.data.rst",
    "content": "﻿espaloma.data\n=============\n\n.. automodule:: espaloma.data\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.data.collection\n   espaloma.data.dataset\n   espaloma.data.md\n   espaloma.data.md17_utils\n   espaloma.data.normalize\n   espaloma.data.qcarchive_utils\n   espaloma.data.utils\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.batch.rst",
    "content": "espaloma.data.utils.batch\n=========================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: batch"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.collate_fn.rst",
    "content": "espaloma.data.utils.collate\\_fn\n===============================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: collate_fn"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.from_csv.rst",
    "content": "espaloma.data.utils.from\\_csv\n=============================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: from_csv"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.infer_mol_from_coordinates.rst",
    "content": "espaloma.data.utils.infer\\_mol\\_from\\_coordinates\n=================================================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: infer_mol_from_coordinates"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.make_temp_directory.rst",
    "content": "espaloma.data.utils.make\\_temp\\_directory\n=========================================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: make_temp_directory"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.normalize.rst",
    "content": "espaloma.data.utils.normalize\n=============================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: normalize"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.rst",
    "content": "espaloma.data.utils\n===================\n\n.. automodule:: espaloma.data.utils\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      batch\n      collate_fn\n      from_csv\n      infer_mol_from_coordinates\n      make_temp_directory\n      normalize\n      split\n      sum_offsets\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.split.rst",
    "content": "espaloma.data.utils.split\n=========================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: split"
  },
  {
    "path": "docs/autosummary/espaloma.data.utils.sum_offsets.rst",
    "content": "espaloma.data.utils.sum\\_offsets\n================================\n\n.. currentmodule:: espaloma.data.utils\n\n.. autofunction:: sum_offsets"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.deploy.load_forcefield.rst",
    "content": "espaloma.graphs.deploy.load\\_forcefield\n=======================================\n\n.. currentmodule:: espaloma.graphs.deploy\n\n.. autofunction:: load_forcefield"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.deploy.openmm_system_from_graph.rst",
    "content": "espaloma.graphs.deploy.openmm\\_system\\_from\\_graph\n==================================================\n\n.. currentmodule:: espaloma.graphs.deploy\n\n.. autofunction:: openmm_system_from_graph"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.deploy.rst",
    "content": "espaloma.graphs.deploy\n======================\n\n.. automodule:: espaloma.graphs.deploy\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      load_forcefield\n      openmm_system_from_graph\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.graph.BaseGraph.rst",
    "content": "espaloma.graphs.graph.BaseGraph\n===============================\n\n.. currentmodule:: espaloma.graphs.graph\n\n.. autoclass:: BaseGraph\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~BaseGraph.__init__\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.graphs.graph.Graph.rst",
    "content": "espaloma.graphs.graph.Graph\n===========================\n\n.. currentmodule:: espaloma.graphs.graph\n\n.. autoclass:: Graph\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~Graph.__init__\n      ~Graph.get_heterograph_from_graph_and_mol\n      ~Graph.get_homograph_from_mol\n      ~Graph.load\n      ~Graph.save\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~Graph.edata\n      ~Graph.ndata\n      ~Graph.nodes\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.graphs.graph.rst",
    "content": "espaloma.graphs.graph\n=====================\n\n.. automodule:: espaloma.graphs.graph\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      BaseGraph\n      Graph\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.legacy_force_field.LegacyForceField.rst",
    "content": "espaloma.graphs.legacy\\_force\\_field.LegacyForceField\n=====================================================\n\n.. currentmodule:: espaloma.graphs.legacy_force_field\n\n.. autoclass:: LegacyForceField\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~LegacyForceField.__init__\n      ~LegacyForceField.baseline_energy\n      ~LegacyForceField.multi_typing\n      ~LegacyForceField.parametrize\n      ~LegacyForceField.typing\n   \n   \n\n   \n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.graphs.legacy_force_field.rst",
    "content": "espaloma.graphs.legacy\\_force\\_field\n====================================\n\n.. automodule:: espaloma.graphs.legacy_force_field\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      LegacyForceField\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.rst",
    "content": "﻿espaloma.graphs\n===============\n\n.. automodule:: espaloma.graphs\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.graphs.deploy\n   espaloma.graphs.graph\n   espaloma.graphs.legacy_force_field\n   espaloma.graphs.utils\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.angle_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices.angle\\_indices\n====================================================\n\n.. currentmodule:: espaloma.graphs.utils.offmol_indices\n\n.. autofunction:: angle_indices"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.atom_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices.atom\\_indices\n===================================================\n\n.. currentmodule:: espaloma.graphs.utils.offmol_indices\n\n.. autofunction:: atom_indices"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.bond_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices.bond\\_indices\n===================================================\n\n.. currentmodule:: espaloma.graphs.utils.offmol_indices\n\n.. autofunction:: bond_indices"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.improper_torsion_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices.improper\\_torsion\\_indices\n================================================================\n\n.. currentmodule:: espaloma.graphs.utils.offmol_indices\n\n.. autofunction:: improper_torsion_indices"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.proper_torsion_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices.proper\\_torsion\\_indices\n==============================================================\n\n.. currentmodule:: espaloma.graphs.utils.offmol_indices\n\n.. autofunction:: proper_torsion_indices"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.offmol_indices.rst",
    "content": "espaloma.graphs.utils.offmol\\_indices\n=====================================\n\n.. automodule:: espaloma.graphs.utils.offmol_indices\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      angle_indices\n      atom_indices\n      bond_indices\n      improper_torsion_indices\n      proper_torsion_indices\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_heterogeneous_graph.duplicate_index_ordering.rst",
    "content": "espaloma.graphs.utils.read\\_heterogeneous\\_graph.duplicate\\_index\\_ordering\n===========================================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_heterogeneous_graph\n\n.. autofunction:: duplicate_index_ordering"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_heterogeneous_graph.from_homogeneous_and_mol.rst",
    "content": "espaloma.graphs.utils.read\\_heterogeneous\\_graph.from\\_homogeneous\\_and\\_mol\n============================================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_heterogeneous_graph\n\n.. autofunction:: from_homogeneous_and_mol"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_heterogeneous_graph.relationship_indices_from_offmol.rst",
    "content": "espaloma.graphs.utils.read\\_heterogeneous\\_graph.relationship\\_indices\\_from\\_offmol\n====================================================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_heterogeneous_graph\n\n.. autofunction:: relationship_indices_from_offmol"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_heterogeneous_graph.rst",
    "content": "espaloma.graphs.utils.read\\_heterogeneous\\_graph\n================================================\n\n.. automodule:: espaloma.graphs.utils.read_heterogeneous_graph\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      duplicate_index_ordering\n      from_homogeneous_and_mol\n      relationship_indices_from_offmol\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.fp_oe.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph.fp\\_oe\n=====================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_homogeneous_graph\n\n.. autofunction:: fp_oe"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.fp_rdkit.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph.fp\\_rdkit\n========================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_homogeneous_graph\n\n.. autofunction:: fp_rdkit"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.from_oemol.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph.from\\_oemol\n==========================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_homogeneous_graph\n\n.. autofunction:: from_oemol"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.from_openff_toolkit_mol.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph.from\\_openff\\_toolkit\\_mol\n=========================================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_homogeneous_graph\n\n.. autofunction:: from_openff_toolkit_mol"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.from_rdkit_mol.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph.from\\_rdkit\\_mol\n===============================================================\n\n.. currentmodule:: espaloma.graphs.utils.read_homogeneous_graph\n\n.. autofunction:: from_rdkit_mol"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.read_homogeneous_graph.rst",
    "content": "espaloma.graphs.utils.read\\_homogeneous\\_graph\n==============================================\n\n.. automodule:: espaloma.graphs.utils.read_homogeneous_graph\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      fp_oe\n      fp_rdkit\n      from_oemol\n      from_openff_toolkit_mol\n      from_rdkit_mol\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.graphs.utils.rst",
    "content": "espaloma.graphs.utils\n=====================\n\n.. automodule:: espaloma.graphs.utils\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.graphs.utils.offmol_indices\n   espaloma.graphs.utils.read_heterogeneous_graph\n   espaloma.graphs.utils.read_homogeneous_graph\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.angle_high.rst",
    "content": "espaloma.mm.angle.angle\\_high\n=============================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: angle_high"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.bond_angle.rst",
    "content": "espaloma.mm.angle.bond\\_angle\n=============================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: bond_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.bond_bond.rst",
    "content": "espaloma.mm.angle.bond\\_bond\n============================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: bond_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.harmonic_angle.rst",
    "content": "espaloma.mm.angle.harmonic\\_angle\n=================================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: harmonic_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.linear_mixture_angle.rst",
    "content": "espaloma.mm.angle.linear\\_mixture\\_angle\n========================================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: linear_mixture_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.rst",
    "content": "espaloma.mm.angle\n=================\n\n.. automodule:: espaloma.mm.angle\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      angle_high\n      bond_angle\n      bond_bond\n      harmonic_angle\n      linear_mixture_angle\n      urey_bradley\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.angle.urey_bradley.rst",
    "content": "espaloma.mm.angle.urey\\_bradley\n===============================\n\n.. currentmodule:: espaloma.mm.angle\n\n.. autofunction:: urey_bradley"
  },
  {
    "path": "docs/autosummary/espaloma.mm.bond.bond_high.rst",
    "content": "espaloma.mm.bond.bond\\_high\n===========================\n\n.. currentmodule:: espaloma.mm.bond\n\n.. autofunction:: bond_high"
  },
  {
    "path": "docs/autosummary/espaloma.mm.bond.gaussian_bond.rst",
    "content": "espaloma.mm.bond.gaussian\\_bond\n===============================\n\n.. currentmodule:: espaloma.mm.bond\n\n.. autofunction:: gaussian_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.bond.harmonic_bond.rst",
    "content": "espaloma.mm.bond.harmonic\\_bond\n===============================\n\n.. currentmodule:: espaloma.mm.bond\n\n.. autofunction:: harmonic_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.bond.linear_mixture_bond.rst",
    "content": "espaloma.mm.bond.linear\\_mixture\\_bond\n======================================\n\n.. currentmodule:: espaloma.mm.bond\n\n.. autofunction:: linear_mixture_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.bond.rst",
    "content": "espaloma.mm.bond\n================\n\n.. automodule:: espaloma.mm.bond\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      bond_high\n      gaussian_bond\n      harmonic_bond\n      linear_mixture_bond\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.CarryII.rst",
    "content": "espaloma.mm.energy.CarryII\n==========================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autoclass:: CarryII\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~CarryII.__init__\n      ~CarryII.add_module\n      ~CarryII.apply\n      ~CarryII.bfloat16\n      ~CarryII.buffers\n      ~CarryII.children\n      ~CarryII.cpu\n      ~CarryII.cuda\n      ~CarryII.double\n      ~CarryII.eval\n      ~CarryII.extra_repr\n      ~CarryII.float\n      ~CarryII.forward\n      ~CarryII.half\n      ~CarryII.load_state_dict\n      ~CarryII.modules\n      ~CarryII.named_buffers\n      ~CarryII.named_children\n      ~CarryII.named_modules\n      ~CarryII.named_parameters\n      ~CarryII.parameters\n      ~CarryII.register_backward_hook\n      ~CarryII.register_buffer\n      ~CarryII.register_forward_hook\n      ~CarryII.register_forward_pre_hook\n      ~CarryII.register_parameter\n      ~CarryII.requires_grad_\n      ~CarryII.share_memory\n      ~CarryII.state_dict\n      ~CarryII.to\n      ~CarryII.train\n      ~CarryII.type\n      ~CarryII.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~CarryII.T_destination\n      ~CarryII.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.EnergyInGraph.rst",
    "content": "espaloma.mm.energy.EnergyInGraph\n================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autoclass:: EnergyInGraph\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~EnergyInGraph.__init__\n      ~EnergyInGraph.add_module\n      ~EnergyInGraph.apply\n      ~EnergyInGraph.bfloat16\n      ~EnergyInGraph.buffers\n      ~EnergyInGraph.children\n      ~EnergyInGraph.cpu\n      ~EnergyInGraph.cuda\n      ~EnergyInGraph.double\n      ~EnergyInGraph.eval\n      ~EnergyInGraph.extra_repr\n      ~EnergyInGraph.float\n      ~EnergyInGraph.forward\n      ~EnergyInGraph.half\n      ~EnergyInGraph.load_state_dict\n      ~EnergyInGraph.modules\n      ~EnergyInGraph.named_buffers\n      ~EnergyInGraph.named_children\n      ~EnergyInGraph.named_modules\n      ~EnergyInGraph.named_parameters\n      ~EnergyInGraph.parameters\n      ~EnergyInGraph.register_backward_hook\n      ~EnergyInGraph.register_buffer\n      ~EnergyInGraph.register_forward_hook\n      ~EnergyInGraph.register_forward_pre_hook\n      ~EnergyInGraph.register_parameter\n      ~EnergyInGraph.requires_grad_\n      ~EnergyInGraph.share_memory\n      ~EnergyInGraph.state_dict\n      ~EnergyInGraph.to\n      ~EnergyInGraph.train\n      ~EnergyInGraph.type\n      ~EnergyInGraph.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~EnergyInGraph.T_destination\n      ~EnergyInGraph.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.EnergyInGraphII.rst",
    "content": "espaloma.mm.energy.EnergyInGraphII\n==================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autoclass:: EnergyInGraphII\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~EnergyInGraphII.__init__\n      ~EnergyInGraphII.add_module\n      ~EnergyInGraphII.apply\n      ~EnergyInGraphII.bfloat16\n      ~EnergyInGraphII.buffers\n      ~EnergyInGraphII.children\n      ~EnergyInGraphII.cpu\n      ~EnergyInGraphII.cuda\n      ~EnergyInGraphII.double\n      ~EnergyInGraphII.eval\n      ~EnergyInGraphII.extra_repr\n      ~EnergyInGraphII.float\n      ~EnergyInGraphII.forward\n      ~EnergyInGraphII.half\n      ~EnergyInGraphII.load_state_dict\n      ~EnergyInGraphII.modules\n      ~EnergyInGraphII.named_buffers\n      ~EnergyInGraphII.named_children\n      ~EnergyInGraphII.named_modules\n      ~EnergyInGraphII.named_parameters\n      ~EnergyInGraphII.parameters\n      ~EnergyInGraphII.register_backward_hook\n      ~EnergyInGraphII.register_buffer\n      ~EnergyInGraphII.register_forward_hook\n      ~EnergyInGraphII.register_forward_pre_hook\n      ~EnergyInGraphII.register_parameter\n      ~EnergyInGraphII.requires_grad_\n      ~EnergyInGraphII.share_memory\n      ~EnergyInGraphII.state_dict\n      ~EnergyInGraphII.to\n      ~EnergyInGraphII.train\n      ~EnergyInGraphII.type\n      ~EnergyInGraphII.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~EnergyInGraphII.T_destination\n      ~EnergyInGraphII.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_angle.rst",
    "content": "espaloma.mm.energy.apply\\_angle\n===============================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_angle_ii.rst",
    "content": "espaloma.mm.energy.apply\\_angle\\_ii\n===================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_angle_ii"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_angle_linear_mixture.rst",
    "content": "espaloma.mm.energy.apply\\_angle\\_linear\\_mixture\n================================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_angle_linear_mixture"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_bond.rst",
    "content": "espaloma.mm.energy.apply\\_bond\n==============================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_bond_gaussian.rst",
    "content": "espaloma.mm.energy.apply\\_bond\\_gaussian\n========================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_bond_gaussian"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_bond_ii.rst",
    "content": "espaloma.mm.energy.apply\\_bond\\_ii\n==================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_bond_ii"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_bond_linear_mixture.rst",
    "content": "espaloma.mm.energy.apply\\_bond\\_linear\\_mixture\n===============================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_bond_linear_mixture"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_improper_torsion.rst",
    "content": "espaloma.mm.energy.apply\\_improper\\_torsion\n===========================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_improper_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_nonbonded.rst",
    "content": "espaloma.mm.energy.apply\\_nonbonded\n===================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_nonbonded"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_torsion.rst",
    "content": "espaloma.mm.energy.apply\\_torsion\n=================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.apply_torsion_ii.rst",
    "content": "espaloma.mm.energy.apply\\_torsion\\_ii\n=====================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: apply_torsion_ii"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.energy_in_graph.rst",
    "content": "espaloma.mm.energy.energy\\_in\\_graph\n====================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: energy_in_graph"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.energy_in_graph_ii.rst",
    "content": "espaloma.mm.energy.energy\\_in\\_graph\\_ii\n========================================\n\n.. currentmodule:: espaloma.mm.energy\n\n.. autofunction:: energy_in_graph_ii"
  },
  {
    "path": "docs/autosummary/espaloma.mm.energy.rst",
    "content": "espaloma.mm.energy\n==================\n\n.. automodule:: espaloma.mm.energy\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      apply_angle\n      apply_angle_ii\n      apply_angle_linear_mixture\n      apply_bond\n      apply_bond_gaussian\n      apply_bond_ii\n      apply_bond_linear_mixture\n      apply_improper_torsion\n      apply_nonbonded\n      apply_torsion\n      apply_torsion_ii\n      energy_in_graph\n      energy_in_graph_ii\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      CarryII\n      EnergyInGraph\n      EnergyInGraphII\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.gaussian.rst",
    "content": "espaloma.mm.functional.gaussian\n===============================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: gaussian"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.harmonic.rst",
    "content": "espaloma.mm.functional.harmonic\n===============================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: harmonic"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.harmonic_harmonic_coupled.rst",
    "content": "espaloma.mm.functional.harmonic\\_harmonic\\_coupled\n==================================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: harmonic_harmonic_coupled"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.harmonic_harmonic_periodic_coupled.rst",
    "content": "espaloma.mm.functional.harmonic\\_harmonic\\_periodic\\_coupled\n============================================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: harmonic_harmonic_periodic_coupled"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.harmonic_periodic_coupled.rst",
    "content": "espaloma.mm.functional.harmonic\\_periodic\\_coupled\n==================================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: harmonic_periodic_coupled"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.linear_mixture.rst",
    "content": "espaloma.mm.functional.linear\\_mixture\n======================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: linear_mixture"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.linear_mixture_to_original.rst",
    "content": "espaloma.mm.functional.linear\\_mixture\\_to\\_original\n====================================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: linear_mixture_to_original"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.lj.rst",
    "content": "espaloma.mm.functional.lj\n=========================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: lj"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.periodic.rst",
    "content": "espaloma.mm.functional.periodic\n===============================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: periodic"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.periodic_fixed_phases.rst",
    "content": "espaloma.mm.functional.periodic\\_fixed\\_phases\n==============================================\n\n.. currentmodule:: espaloma.mm.functional\n\n.. autofunction:: periodic_fixed_phases"
  },
  {
    "path": "docs/autosummary/espaloma.mm.functional.rst",
    "content": "espaloma.mm.functional\n======================\n\n.. automodule:: espaloma.mm.functional\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      gaussian\n      harmonic\n      harmonic_harmonic_coupled\n      harmonic_harmonic_periodic_coupled\n      harmonic_periodic_coupled\n      linear_mixture\n      linear_mixture_to_original\n      lj\n      periodic\n      periodic_fixed_phases\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.GeometryInGraph.rst",
    "content": "espaloma.mm.geometry.GeometryInGraph\n====================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autoclass:: GeometryInGraph\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~GeometryInGraph.__init__\n      ~GeometryInGraph.add_module\n      ~GeometryInGraph.apply\n      ~GeometryInGraph.bfloat16\n      ~GeometryInGraph.buffers\n      ~GeometryInGraph.children\n      ~GeometryInGraph.cpu\n      ~GeometryInGraph.cuda\n      ~GeometryInGraph.double\n      ~GeometryInGraph.eval\n      ~GeometryInGraph.extra_repr\n      ~GeometryInGraph.float\n      ~GeometryInGraph.forward\n      ~GeometryInGraph.half\n      ~GeometryInGraph.load_state_dict\n      ~GeometryInGraph.modules\n      ~GeometryInGraph.named_buffers\n      ~GeometryInGraph.named_children\n      ~GeometryInGraph.named_modules\n      ~GeometryInGraph.named_parameters\n      ~GeometryInGraph.parameters\n      ~GeometryInGraph.register_backward_hook\n      ~GeometryInGraph.register_buffer\n      ~GeometryInGraph.register_forward_hook\n      ~GeometryInGraph.register_forward_pre_hook\n      ~GeometryInGraph.register_parameter\n      ~GeometryInGraph.requires_grad_\n      ~GeometryInGraph.share_memory\n      ~GeometryInGraph.state_dict\n      ~GeometryInGraph.to\n      ~GeometryInGraph.train\n      ~GeometryInGraph.type\n      ~GeometryInGraph.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~GeometryInGraph.T_destination\n      ~GeometryInGraph.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.angle.rst",
    "content": "espaloma.mm.geometry.angle\n==========================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.apply_angle.rst",
    "content": "espaloma.mm.geometry.apply\\_angle\n=================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: apply_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.apply_bond.rst",
    "content": "espaloma.mm.geometry.apply\\_bond\n================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: apply_bond"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.apply_torsion.rst",
    "content": "espaloma.mm.geometry.apply\\_torsion\n===================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: apply_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.copy_src.rst",
    "content": "espaloma.mm.geometry.copy\\_src\n==============================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: copy_src"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.dihedral.rst",
    "content": "espaloma.mm.geometry.dihedral\n=============================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: dihedral"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.distance.rst",
    "content": "espaloma.mm.geometry.distance\n=============================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: distance"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.geometry_in_graph.rst",
    "content": "espaloma.mm.geometry.geometry\\_in\\_graph\n========================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: geometry_in_graph"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.reduce_stack.rst",
    "content": "espaloma.mm.geometry.reduce\\_stack\n==================================\n\n.. currentmodule:: espaloma.mm.geometry\n\n.. autofunction:: reduce_stack"
  },
  {
    "path": "docs/autosummary/espaloma.mm.geometry.rst",
    "content": "espaloma.mm.geometry\n====================\n\n.. automodule:: espaloma.mm.geometry\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      angle\n      apply_angle\n      apply_bond\n      apply_torsion\n      copy_src\n      dihedral\n      distance\n      geometry_in_graph\n      reduce_stack\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      GeometryInGraph\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.arithmetic_mean.rst",
    "content": "espaloma.mm.nonbonded.arithmetic\\_mean\n======================================\n\n.. currentmodule:: espaloma.mm.nonbonded\n\n.. autofunction:: arithmetic_mean"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.geometric_mean.rst",
    "content": "espaloma.mm.nonbonded.geometric\\_mean\n=====================================\n\n.. currentmodule:: espaloma.mm.nonbonded\n\n.. autofunction:: geometric_mean"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.lj_12_6.rst",
    "content": "espaloma.mm.nonbonded.lj\\_12\\_6\n===============================\n\n.. currentmodule:: espaloma.mm.nonbonded\n\n.. autofunction:: lj_12_6"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.lj_9_6.rst",
    "content": "espaloma.mm.nonbonded.lj\\_9\\_6\n==============================\n\n.. currentmodule:: espaloma.mm.nonbonded\n\n.. autofunction:: lj_9_6"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.lorentz_berthelot.rst",
    "content": "espaloma.mm.nonbonded.lorentz\\_berthelot\n========================================\n\n.. currentmodule:: espaloma.mm.nonbonded\n\n.. autofunction:: lorentz_berthelot"
  },
  {
    "path": "docs/autosummary/espaloma.mm.nonbonded.rst",
    "content": "espaloma.mm.nonbonded\n=====================\n\n.. automodule:: espaloma.mm.nonbonded\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      arithmetic_mean\n      geometric_mean\n      lj_12_6\n      lj_9_6\n      lorentz_berthelot\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.rst",
    "content": "﻿espaloma.mm\n===========\n\n.. automodule:: espaloma.mm\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.mm.angle\n   espaloma.mm.bond\n   espaloma.mm.energy\n   espaloma.mm.functional\n   espaloma.mm.geometry\n   espaloma.mm.nonbonded\n   espaloma.mm.torsion\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.angle_angle.rst",
    "content": "espaloma.mm.torsion.angle\\_angle\n================================\n\n.. currentmodule:: espaloma.mm.torsion\n\n.. autofunction:: angle_angle"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.angle_angle_torsion.rst",
    "content": "espaloma.mm.torsion.angle\\_angle\\_torsion\n=========================================\n\n.. currentmodule:: espaloma.mm.torsion\n\n.. autofunction:: angle_angle_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.angle_torsion.rst",
    "content": "espaloma.mm.torsion.angle\\_torsion\n==================================\n\n.. currentmodule:: espaloma.mm.torsion\n\n.. autofunction:: angle_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.bond_torsion.rst",
    "content": "espaloma.mm.torsion.bond\\_torsion\n=================================\n\n.. currentmodule:: espaloma.mm.torsion\n\n.. autofunction:: bond_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.periodic_torsion.rst",
    "content": "espaloma.mm.torsion.periodic\\_torsion\n=====================================\n\n.. currentmodule:: espaloma.mm.torsion\n\n.. autofunction:: periodic_torsion"
  },
  {
    "path": "docs/autosummary/espaloma.mm.torsion.rst",
    "content": "espaloma.mm.torsion\n===================\n\n.. automodule:: espaloma.mm.torsion\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      angle_angle\n      angle_angle_torsion\n      angle_torsion\n      bond_torsion\n      periodic_torsion\n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.baselines.FreeParameterBaseline.rst",
    "content": "espaloma.nn.baselines.FreeParameterBaseline\n===========================================\n\n.. currentmodule:: espaloma.nn.baselines\n\n.. autoclass:: FreeParameterBaseline\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~FreeParameterBaseline.__init__\n      ~FreeParameterBaseline.add_module\n      ~FreeParameterBaseline.apply\n      ~FreeParameterBaseline.bfloat16\n      ~FreeParameterBaseline.buffers\n      ~FreeParameterBaseline.children\n      ~FreeParameterBaseline.cpu\n      ~FreeParameterBaseline.cuda\n      ~FreeParameterBaseline.double\n      ~FreeParameterBaseline.eval\n      ~FreeParameterBaseline.extra_repr\n      ~FreeParameterBaseline.float\n      ~FreeParameterBaseline.forward\n      ~FreeParameterBaseline.half\n      ~FreeParameterBaseline.load_state_dict\n      ~FreeParameterBaseline.modules\n      ~FreeParameterBaseline.named_buffers\n      ~FreeParameterBaseline.named_children\n      ~FreeParameterBaseline.named_modules\n      ~FreeParameterBaseline.named_parameters\n      ~FreeParameterBaseline.parameters\n      ~FreeParameterBaseline.register_backward_hook\n      ~FreeParameterBaseline.register_buffer\n      ~FreeParameterBaseline.register_forward_hook\n      ~FreeParameterBaseline.register_forward_pre_hook\n      ~FreeParameterBaseline.register_parameter\n      ~FreeParameterBaseline.requires_grad_\n      ~FreeParameterBaseline.share_memory\n      ~FreeParameterBaseline.state_dict\n      ~FreeParameterBaseline.to\n      ~FreeParameterBaseline.train\n      ~FreeParameterBaseline.type\n      ~FreeParameterBaseline.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~FreeParameterBaseline.T_destination\n      ~FreeParameterBaseline.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.baselines.FreeParameterBaselineInitMean.rst",
    "content": "espaloma.nn.baselines.FreeParameterBaselineInitMean\n===================================================\n\n.. currentmodule:: espaloma.nn.baselines\n\n.. autoclass:: FreeParameterBaselineInitMean\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~FreeParameterBaselineInitMean.__init__\n      ~FreeParameterBaselineInitMean.add_module\n      ~FreeParameterBaselineInitMean.apply\n      ~FreeParameterBaselineInitMean.bfloat16\n      ~FreeParameterBaselineInitMean.buffers\n      ~FreeParameterBaselineInitMean.children\n      ~FreeParameterBaselineInitMean.cpu\n      ~FreeParameterBaselineInitMean.cuda\n      ~FreeParameterBaselineInitMean.double\n      ~FreeParameterBaselineInitMean.eval\n      ~FreeParameterBaselineInitMean.extra_repr\n      ~FreeParameterBaselineInitMean.float\n      ~FreeParameterBaselineInitMean.forward\n      ~FreeParameterBaselineInitMean.half\n      ~FreeParameterBaselineInitMean.load_state_dict\n      ~FreeParameterBaselineInitMean.modules\n      ~FreeParameterBaselineInitMean.named_buffers\n      ~FreeParameterBaselineInitMean.named_children\n      ~FreeParameterBaselineInitMean.named_modules\n      ~FreeParameterBaselineInitMean.named_parameters\n      ~FreeParameterBaselineInitMean.parameters\n      ~FreeParameterBaselineInitMean.register_backward_hook\n      ~FreeParameterBaselineInitMean.register_buffer\n      ~FreeParameterBaselineInitMean.register_forward_hook\n      ~FreeParameterBaselineInitMean.register_forward_pre_hook\n      ~FreeParameterBaselineInitMean.register_parameter\n      ~FreeParameterBaselineInitMean.requires_grad_\n      ~FreeParameterBaselineInitMean.share_memory\n      ~FreeParameterBaselineInitMean.state_dict\n      ~FreeParameterBaselineInitMean.to\n      ~FreeParameterBaselineInitMean.train\n      ~FreeParameterBaselineInitMean.type\n      ~FreeParameterBaselineInitMean.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~FreeParameterBaselineInitMean.T_destination\n      ~FreeParameterBaselineInitMean.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.baselines.rst",
    "content": "espaloma.nn.baselines\n=====================\n\n.. automodule:: espaloma.nn.baselines\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      FreeParameterBaseline\n      FreeParameterBaselineInitMean\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.layers.dgl_legacy.GN.rst",
    "content": "﻿espaloma.nn.layers.dgl\\_legacy.gn\n=================================\n\n.. currentmodule:: espaloma.nn.layers.dgl_legacy\n\n.. autofunction:: gn"
  },
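  {
    "path": "docs/autosummary/espaloma.nn.layers.dgl_legacy.gn.rst",
    "content": "espaloma.nn.layers.dgl\\_legacy.gn\n=================================\n\n.. currentmodule:: espaloma.nn.layers.dgl_legacy\n\n.. autofunction:: gn"
  },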
  {
    "path": "docs/autosummary/espaloma.nn.layers.dgl_legacy.rst",
    "content": "espaloma.nn.layers.dgl\\_legacy\n==============================\n\n.. automodule:: espaloma.nn.layers.dgl_legacy\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      gn\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      GN\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.layers.rst",
    "content": "espaloma.nn.layers\n==================\n\n.. automodule:: espaloma.nn.layers\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.nn.layers.dgl_legacy\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.base_readout.BaseReadout.rst",
    "content": "espaloma.nn.readout.base\\_readout.BaseReadout\n=============================================\n\n.. currentmodule:: espaloma.nn.readout.base_readout\n\n.. autoclass:: BaseReadout\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~BaseReadout.__init__\n      ~BaseReadout.add_module\n      ~BaseReadout.apply\n      ~BaseReadout.bfloat16\n      ~BaseReadout.buffers\n      ~BaseReadout.children\n      ~BaseReadout.cpu\n      ~BaseReadout.cuda\n      ~BaseReadout.double\n      ~BaseReadout.eval\n      ~BaseReadout.extra_repr\n      ~BaseReadout.float\n      ~BaseReadout.forward\n      ~BaseReadout.half\n      ~BaseReadout.load_state_dict\n      ~BaseReadout.modules\n      ~BaseReadout.named_buffers\n      ~BaseReadout.named_children\n      ~BaseReadout.named_modules\n      ~BaseReadout.named_parameters\n      ~BaseReadout.parameters\n      ~BaseReadout.register_backward_hook\n      ~BaseReadout.register_buffer\n      ~BaseReadout.register_forward_hook\n      ~BaseReadout.register_forward_pre_hook\n      ~BaseReadout.register_parameter\n      ~BaseReadout.requires_grad_\n      ~BaseReadout.share_memory\n      ~BaseReadout.state_dict\n      ~BaseReadout.to\n      ~BaseReadout.train\n      ~BaseReadout.type\n      ~BaseReadout.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~BaseReadout.T_destination\n      ~BaseReadout.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.base_readout.rst",
    "content": "espaloma.nn.readout.base\\_readout\n=================================\n\n.. automodule:: espaloma.nn.readout.base_readout\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      BaseReadout\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.charge_equilibrium.ChargeEquilibrium.rst",
    "content": "espaloma.nn.readout.charge\\_equilibrium.ChargeEquilibrium\n=========================================================\n\n.. currentmodule:: espaloma.nn.readout.charge_equilibrium\n\n.. autoclass:: ChargeEquilibrium\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~ChargeEquilibrium.__init__\n      ~ChargeEquilibrium.add_module\n      ~ChargeEquilibrium.apply\n      ~ChargeEquilibrium.bfloat16\n      ~ChargeEquilibrium.buffers\n      ~ChargeEquilibrium.children\n      ~ChargeEquilibrium.cpu\n      ~ChargeEquilibrium.cuda\n      ~ChargeEquilibrium.double\n      ~ChargeEquilibrium.eval\n      ~ChargeEquilibrium.extra_repr\n      ~ChargeEquilibrium.float\n      ~ChargeEquilibrium.forward\n      ~ChargeEquilibrium.half\n      ~ChargeEquilibrium.load_state_dict\n      ~ChargeEquilibrium.modules\n      ~ChargeEquilibrium.named_buffers\n      ~ChargeEquilibrium.named_children\n      ~ChargeEquilibrium.named_modules\n      ~ChargeEquilibrium.named_parameters\n      ~ChargeEquilibrium.parameters\n      ~ChargeEquilibrium.register_backward_hook\n      ~ChargeEquilibrium.register_buffer\n      ~ChargeEquilibrium.register_forward_hook\n      ~ChargeEquilibrium.register_forward_pre_hook\n      ~ChargeEquilibrium.register_parameter\n      ~ChargeEquilibrium.requires_grad_\n      ~ChargeEquilibrium.share_memory\n      ~ChargeEquilibrium.state_dict\n      ~ChargeEquilibrium.to\n      ~ChargeEquilibrium.train\n      ~ChargeEquilibrium.type\n      ~ChargeEquilibrium.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~ChargeEquilibrium.T_destination\n      ~ChargeEquilibrium.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.charge_equilibrium.get_charges.rst",
    "content": "espaloma.nn.readout.charge\\_equilibrium.get\\_charges\n====================================================\n\n.. currentmodule:: espaloma.nn.readout.charge_equilibrium\n\n.. autofunction:: get_charges"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.charge_equilibrium.rst",
    "content": "espaloma.nn.readout.charge\\_equilibrium\n=======================================\n\n.. automodule:: espaloma.nn.readout.charge_equilibrium\n  \n   \n   \n   \n\n   \n   \n   .. rubric:: Functions\n\n   .. autosummary::\n      :toctree:\n   \n      get_charges\n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      ChargeEquilibrium\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.graph_level_readout.GraphLevelReadout.rst",
    "content": "espaloma.nn.readout.graph\\_level\\_readout.GraphLevelReadout\n===========================================================\n\n.. currentmodule:: espaloma.nn.readout.graph_level_readout\n\n.. autoclass:: GraphLevelReadout\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~GraphLevelReadout.__init__\n      ~GraphLevelReadout.add_module\n      ~GraphLevelReadout.apply\n      ~GraphLevelReadout.bfloat16\n      ~GraphLevelReadout.buffers\n      ~GraphLevelReadout.children\n      ~GraphLevelReadout.cpu\n      ~GraphLevelReadout.cuda\n      ~GraphLevelReadout.double\n      ~GraphLevelReadout.eval\n      ~GraphLevelReadout.extra_repr\n      ~GraphLevelReadout.float\n      ~GraphLevelReadout.forward\n      ~GraphLevelReadout.half\n      ~GraphLevelReadout.load_state_dict\n      ~GraphLevelReadout.modules\n      ~GraphLevelReadout.named_buffers\n      ~GraphLevelReadout.named_children\n      ~GraphLevelReadout.named_modules\n      ~GraphLevelReadout.named_parameters\n      ~GraphLevelReadout.parameters\n      ~GraphLevelReadout.register_backward_hook\n      ~GraphLevelReadout.register_buffer\n      ~GraphLevelReadout.register_forward_hook\n      ~GraphLevelReadout.register_forward_pre_hook\n      ~GraphLevelReadout.register_parameter\n      ~GraphLevelReadout.requires_grad_\n      ~GraphLevelReadout.share_memory\n      ~GraphLevelReadout.state_dict\n      ~GraphLevelReadout.to\n      ~GraphLevelReadout.train\n      ~GraphLevelReadout.type\n      ~GraphLevelReadout.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~GraphLevelReadout.T_destination\n      ~GraphLevelReadout.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.graph_level_readout.rst",
    "content": "espaloma.nn.readout.graph\\_level\\_readout\n=========================================\n\n.. automodule:: espaloma.nn.readout.graph_level_readout\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      GraphLevelReadout\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.ExpCoefficients.rst",
    "content": "espaloma.nn.readout.janossy.ExpCoefficients\n===========================================\n\n.. currentmodule:: espaloma.nn.readout.janossy\n\n.. autoclass:: ExpCoefficients\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~ExpCoefficients.__init__\n      ~ExpCoefficients.add_module\n      ~ExpCoefficients.apply\n      ~ExpCoefficients.bfloat16\n      ~ExpCoefficients.buffers\n      ~ExpCoefficients.children\n      ~ExpCoefficients.cpu\n      ~ExpCoefficients.cuda\n      ~ExpCoefficients.double\n      ~ExpCoefficients.eval\n      ~ExpCoefficients.extra_repr\n      ~ExpCoefficients.float\n      ~ExpCoefficients.forward\n      ~ExpCoefficients.half\n      ~ExpCoefficients.load_state_dict\n      ~ExpCoefficients.modules\n      ~ExpCoefficients.named_buffers\n      ~ExpCoefficients.named_children\n      ~ExpCoefficients.named_modules\n      ~ExpCoefficients.named_parameters\n      ~ExpCoefficients.parameters\n      ~ExpCoefficients.register_backward_hook\n      ~ExpCoefficients.register_buffer\n      ~ExpCoefficients.register_forward_hook\n      ~ExpCoefficients.register_forward_pre_hook\n      ~ExpCoefficients.register_parameter\n      ~ExpCoefficients.requires_grad_\n      ~ExpCoefficients.share_memory\n      ~ExpCoefficients.state_dict\n      ~ExpCoefficients.to\n      ~ExpCoefficients.train\n      ~ExpCoefficients.type\n      ~ExpCoefficients.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~ExpCoefficients.T_destination\n      ~ExpCoefficients.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.JanossyPooling.rst",
    "content": "espaloma.nn.readout.janossy.JanossyPooling\n==========================================\n\n.. currentmodule:: espaloma.nn.readout.janossy\n\n.. autoclass:: JanossyPooling\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~JanossyPooling.__init__\n      ~JanossyPooling.add_module\n      ~JanossyPooling.apply\n      ~JanossyPooling.bfloat16\n      ~JanossyPooling.buffers\n      ~JanossyPooling.children\n      ~JanossyPooling.cpu\n      ~JanossyPooling.cuda\n      ~JanossyPooling.double\n      ~JanossyPooling.eval\n      ~JanossyPooling.extra_repr\n      ~JanossyPooling.float\n      ~JanossyPooling.forward\n      ~JanossyPooling.half\n      ~JanossyPooling.load_state_dict\n      ~JanossyPooling.modules\n      ~JanossyPooling.named_buffers\n      ~JanossyPooling.named_children\n      ~JanossyPooling.named_modules\n      ~JanossyPooling.named_parameters\n      ~JanossyPooling.parameters\n      ~JanossyPooling.register_backward_hook\n      ~JanossyPooling.register_buffer\n      ~JanossyPooling.register_forward_hook\n      ~JanossyPooling.register_forward_pre_hook\n      ~JanossyPooling.register_parameter\n      ~JanossyPooling.requires_grad_\n      ~JanossyPooling.share_memory\n      ~JanossyPooling.state_dict\n      ~JanossyPooling.to\n      ~JanossyPooling.train\n      ~JanossyPooling.type\n      ~JanossyPooling.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~JanossyPooling.T_destination\n      ~JanossyPooling.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.JanossyPoolingImproper.rst",
    "content": "espaloma.nn.readout.janossy.JanossyPoolingImproper\n==================================================\n\n.. currentmodule:: espaloma.nn.readout.janossy\n\n.. autoclass:: JanossyPoolingImproper\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~JanossyPoolingImproper.__init__\n      ~JanossyPoolingImproper.add_module\n      ~JanossyPoolingImproper.apply\n      ~JanossyPoolingImproper.bfloat16\n      ~JanossyPoolingImproper.buffers\n      ~JanossyPoolingImproper.children\n      ~JanossyPoolingImproper.cpu\n      ~JanossyPoolingImproper.cuda\n      ~JanossyPoolingImproper.double\n      ~JanossyPoolingImproper.eval\n      ~JanossyPoolingImproper.extra_repr\n      ~JanossyPoolingImproper.float\n      ~JanossyPoolingImproper.forward\n      ~JanossyPoolingImproper.half\n      ~JanossyPoolingImproper.load_state_dict\n      ~JanossyPoolingImproper.modules\n      ~JanossyPoolingImproper.named_buffers\n      ~JanossyPoolingImproper.named_children\n      ~JanossyPoolingImproper.named_modules\n      ~JanossyPoolingImproper.named_parameters\n      ~JanossyPoolingImproper.parameters\n      ~JanossyPoolingImproper.register_backward_hook\n      ~JanossyPoolingImproper.register_buffer\n      ~JanossyPoolingImproper.register_forward_hook\n      ~JanossyPoolingImproper.register_forward_pre_hook\n      ~JanossyPoolingImproper.register_parameter\n      ~JanossyPoolingImproper.requires_grad_\n      ~JanossyPoolingImproper.share_memory\n      ~JanossyPoolingImproper.state_dict\n      ~JanossyPoolingImproper.to\n      ~JanossyPoolingImproper.train\n      ~JanossyPoolingImproper.type\n      ~JanossyPoolingImproper.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~JanossyPoolingImproper.T_destination\n      ~JanossyPoolingImproper.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.JanossyPoolingNonbonded.rst",
    "content": "espaloma.nn.readout.janossy.JanossyPoolingNonbonded\n===================================================\n\n.. currentmodule:: espaloma.nn.readout.janossy\n\n.. autoclass:: JanossyPoolingNonbonded\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~JanossyPoolingNonbonded.__init__\n      ~JanossyPoolingNonbonded.add_module\n      ~JanossyPoolingNonbonded.apply\n      ~JanossyPoolingNonbonded.bfloat16\n      ~JanossyPoolingNonbonded.buffers\n      ~JanossyPoolingNonbonded.children\n      ~JanossyPoolingNonbonded.cpu\n      ~JanossyPoolingNonbonded.cuda\n      ~JanossyPoolingNonbonded.double\n      ~JanossyPoolingNonbonded.eval\n      ~JanossyPoolingNonbonded.extra_repr\n      ~JanossyPoolingNonbonded.float\n      ~JanossyPoolingNonbonded.forward\n      ~JanossyPoolingNonbonded.half\n      ~JanossyPoolingNonbonded.load_state_dict\n      ~JanossyPoolingNonbonded.modules\n      ~JanossyPoolingNonbonded.named_buffers\n      ~JanossyPoolingNonbonded.named_children\n      ~JanossyPoolingNonbonded.named_modules\n      ~JanossyPoolingNonbonded.named_parameters\n      ~JanossyPoolingNonbonded.parameters\n      ~JanossyPoolingNonbonded.register_backward_hook\n      ~JanossyPoolingNonbonded.register_buffer\n      ~JanossyPoolingNonbonded.register_forward_hook\n      ~JanossyPoolingNonbonded.register_forward_pre_hook\n      ~JanossyPoolingNonbonded.register_parameter\n      ~JanossyPoolingNonbonded.requires_grad_\n      ~JanossyPoolingNonbonded.share_memory\n      ~JanossyPoolingNonbonded.state_dict\n      ~JanossyPoolingNonbonded.to\n      ~JanossyPoolingNonbonded.train\n      ~JanossyPoolingNonbonded.type\n      ~JanossyPoolingNonbonded.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~JanossyPoolingNonbonded.T_destination\n      ~JanossyPoolingNonbonded.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.LinearMixtureToOriginal.rst",
    "content": "espaloma.nn.readout.janossy.LinearMixtureToOriginal\n===================================================\n\n.. currentmodule:: espaloma.nn.readout.janossy\n\n.. autoclass:: LinearMixtureToOriginal\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~LinearMixtureToOriginal.__init__\n      ~LinearMixtureToOriginal.add_module\n      ~LinearMixtureToOriginal.apply\n      ~LinearMixtureToOriginal.bfloat16\n      ~LinearMixtureToOriginal.buffers\n      ~LinearMixtureToOriginal.children\n      ~LinearMixtureToOriginal.cpu\n      ~LinearMixtureToOriginal.cuda\n      ~LinearMixtureToOriginal.double\n      ~LinearMixtureToOriginal.eval\n      ~LinearMixtureToOriginal.extra_repr\n      ~LinearMixtureToOriginal.float\n      ~LinearMixtureToOriginal.forward\n      ~LinearMixtureToOriginal.half\n      ~LinearMixtureToOriginal.load_state_dict\n      ~LinearMixtureToOriginal.modules\n      ~LinearMixtureToOriginal.named_buffers\n      ~LinearMixtureToOriginal.named_children\n      ~LinearMixtureToOriginal.named_modules\n      ~LinearMixtureToOriginal.named_parameters\n      ~LinearMixtureToOriginal.parameters\n      ~LinearMixtureToOriginal.register_backward_hook\n      ~LinearMixtureToOriginal.register_buffer\n      ~LinearMixtureToOriginal.register_forward_hook\n      ~LinearMixtureToOriginal.register_forward_pre_hook\n      ~LinearMixtureToOriginal.register_parameter\n      ~LinearMixtureToOriginal.requires_grad_\n      ~LinearMixtureToOriginal.share_memory\n      ~LinearMixtureToOriginal.state_dict\n      ~LinearMixtureToOriginal.to\n      ~LinearMixtureToOriginal.train\n      ~LinearMixtureToOriginal.type\n      ~LinearMixtureToOriginal.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~LinearMixtureToOriginal.T_destination\n      ~LinearMixtureToOriginal.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.janossy.rst",
    "content": "espaloma.nn.readout.janossy\n===========================\n\n.. automodule:: espaloma.nn.readout.janossy\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      ExpCoefficients\n      JanossyPooling\n      JanossyPoolingImproper\n      JanossyPoolingNonbonded\n      LinearMixtureToOriginal\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.node_typing.NodeTyping.rst",
    "content": "espaloma.nn.readout.node\\_typing.NodeTyping\n===========================================\n\n.. currentmodule:: espaloma.nn.readout.node_typing\n\n.. autoclass:: NodeTyping\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~NodeTyping.__init__\n      ~NodeTyping.add_module\n      ~NodeTyping.apply\n      ~NodeTyping.bfloat16\n      ~NodeTyping.buffers\n      ~NodeTyping.children\n      ~NodeTyping.cpu\n      ~NodeTyping.cuda\n      ~NodeTyping.double\n      ~NodeTyping.eval\n      ~NodeTyping.extra_repr\n      ~NodeTyping.float\n      ~NodeTyping.forward\n      ~NodeTyping.half\n      ~NodeTyping.load_state_dict\n      ~NodeTyping.modules\n      ~NodeTyping.named_buffers\n      ~NodeTyping.named_children\n      ~NodeTyping.named_modules\n      ~NodeTyping.named_parameters\n      ~NodeTyping.parameters\n      ~NodeTyping.register_backward_hook\n      ~NodeTyping.register_buffer\n      ~NodeTyping.register_forward_hook\n      ~NodeTyping.register_forward_pre_hook\n      ~NodeTyping.register_parameter\n      ~NodeTyping.requires_grad_\n      ~NodeTyping.share_memory\n      ~NodeTyping.state_dict\n      ~NodeTyping.to\n      ~NodeTyping.train\n      ~NodeTyping.type\n      ~NodeTyping.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~NodeTyping.T_destination\n      ~NodeTyping.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.node_typing.rst",
    "content": "espaloma.nn.readout.node\\_typing\n================================\n\n.. automodule:: espaloma.nn.readout.node_typing\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      NodeTyping\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.readout.rst",
    "content": "espaloma.nn.readout\n===================\n\n.. automodule:: espaloma.nn.readout\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.nn.readout.base_readout\n   espaloma.nn.readout.charge_equilibrium\n   espaloma.nn.readout.graph_level_readout\n   espaloma.nn.readout.janossy\n   espaloma.nn.readout.node_typing\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.rst",
    "content": "﻿espaloma.nn\n===========\n\n.. automodule:: espaloma.nn\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   \n\n\n\n.. rubric:: Modules\n\n.. autosummary::\n   :toctree:\n   :template: custom-module-template.rst\n   :recursive:\n\n\n   espaloma.nn.baselines\n   espaloma.nn.layers\n   espaloma.nn.readout\n   espaloma.nn.sequential\n\n"
  },
  {
    "path": "docs/autosummary/espaloma.nn.sequential.Sequential.rst",
    "content": "espaloma.nn.sequential.Sequential\n=================================\n\n.. currentmodule:: espaloma.nn.sequential\n\n.. autoclass:: Sequential\n   :members:\n   :show-inheritance:\n   :inherited-members:\n\n   \n   .. automethod:: __init__\n\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n   \n      ~Sequential.__init__\n      ~Sequential.add_module\n      ~Sequential.apply\n      ~Sequential.bfloat16\n      ~Sequential.buffers\n      ~Sequential.children\n      ~Sequential.cpu\n      ~Sequential.cuda\n      ~Sequential.double\n      ~Sequential.eval\n      ~Sequential.extra_repr\n      ~Sequential.float\n      ~Sequential.forward\n      ~Sequential.half\n      ~Sequential.load_state_dict\n      ~Sequential.modules\n      ~Sequential.named_buffers\n      ~Sequential.named_children\n      ~Sequential.named_modules\n      ~Sequential.named_parameters\n      ~Sequential.parameters\n      ~Sequential.register_backward_hook\n      ~Sequential.register_buffer\n      ~Sequential.register_forward_hook\n      ~Sequential.register_forward_pre_hook\n      ~Sequential.register_parameter\n      ~Sequential.requires_grad_\n      ~Sequential.share_memory\n      ~Sequential.state_dict\n      ~Sequential.to\n      ~Sequential.train\n      ~Sequential.type\n      ~Sequential.zero_grad\n   \n   \n\n   \n   \n   .. rubric:: Attributes\n\n   .. autosummary::\n   \n      ~Sequential.T_destination\n      ~Sequential.dump_patches\n   \n   "
  },
  {
    "path": "docs/autosummary/espaloma.nn.sequential.rst",
    "content": "espaloma.nn.sequential\n======================\n\n.. automodule:: espaloma.nn.sequential\n  \n   \n   \n   \n\n   \n   \n   \n\n   \n   \n   .. rubric:: Classes\n\n   .. autosummary::\n      :toctree:\n      :template: custom-class-template.rst\n   \n      Sequential\n   \n   \n\n   \n   \n   \n\n\n\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/stable/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\n# Incase the project was not installed\nimport os\nimport sys\nimport subprocess\n\nsys.path.insert(0, os.path.abspath('..'))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'espaloma'\ncopyright = (\"2020, Yuanqing Wang @ choderalab // MSKCC.\")\nauthor = 'Yuanqing Wang'\ngithub_url = \"https://github.com/choderalab/espaloma\"\n\n# The short X.Y version\nversion = ''\n# The full version, including alpha/beta/rc tags\nrelease = ''\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autosummary',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.mathjax',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.extlinks',\n    'sphinx.ext.coverage',\n    # 'numpydoc',\n]\n\nautosummary_generate = True\nnapoleon_google_docstring = False\nnapoleon_use_param = False\nnapoleon_use_ivar = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = '.rst'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = \"en\"\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path .\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'default'\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n#\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'espalomadoc'\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'espaloma.tex', 'espaloma Documentation',\n     'espaloma', 'manual'),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    (master_doc, 'espaloma', 'espaloma Documentation',\n     [author], 1)\n]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'espaloma', 'espaloma Documentation',\n     author, 'espaloma', 'Extensible Surrogate Potential of Ab initio Learned and Optimized by Message-passing Algorithm',\n     'Miscellaneous'),\n]\n\n\n# -- Extension configuration -------------------------------------------------\n"
  },
  {
    "path": "docs/deploy.rst",
    "content": "Deploy espaloma 0.3.2 force field to parametrize your MM system\n===============================================================\nPretrained espaloma force field could be deployed on arbitrary small molecule\nsystems in a few lines::\n\n    # imports\n    import os\n    import torch\n    import espaloma as esp\n    \n    # define or load a molecule of interest via the Open Force Field toolkit\n    from openff.toolkit.topology import Molecule\n    molecule = Molecule.from_smiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    \n    # create an Espaloma Graph object to represent the molecule of interest\n    molecule_graph = esp.Graph(molecule)\n    \n    # load pretrained model\n    espaloma_model = esp.get_model(\"latest\")\n    \n    # apply a trained espaloma model to assign parameters\n    espaloma_model(molecule_graph.heterograph)\n    \n    # create an OpenMM System for the specified molecule\n    openmm_system = esp.graphs.deploy.openmm_system_from_graph(molecule_graph)\n\nIf using espaloma from a local ``.pt`` file, say for example ``espaloma-0.3.2.pt``,\nthen you would need to run the ``eval`` method of the model to get the correct\ninference/predictions, as follows::\n\n    # load local pretrained model\n    espaloma_model = torch.load(\"espaloma-0.3.2.pt\")\n    espaloma_model.eval()\n\nThe rest of the code should be the same as in the previous example.\n"
  },
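  {
    "path": "docs/examples/deploy_local_checkpoint.py",
    "content": "# Illustrative sketch, not part of the original espaloma docs: an end-to-end\n# version of the local-checkpoint deployment path described in docs/deploy.rst.\n# This file's path and the checkpoint name ``espaloma-0.3.2.pt`` are\n# assumptions for illustration; the API calls mirror the documented example.\nimport torch\nimport espaloma as esp\nfrom openff.toolkit.topology import Molecule\n\n# define the molecule of interest (caffeine, as in the docs)\nmolecule = Molecule.from_smiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\nmolecule_graph = esp.Graph(molecule)\n\n# load the local checkpoint; eval() is required for correct inference\nespaloma_model = torch.load(\"espaloma-0.3.2.pt\")\nespaloma_model.eval()\n\n# assign parameters and build an OpenMM System\nespaloma_model(molecule_graph.heterograph)\nopenmm_system = esp.graphs.deploy.openmm_system_from_graph(molecule_graph)\nprint(openmm_system)\n"
  },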
  {
    "path": "docs/download_experiments.sh",
    "content": "export fileid=1qdHEypk3uMhZEYCStWTU8u1uIDHzH3Qy\nwget -O typing.ipynb 'https://docs.google.com/uc?export=download&id='$fileid\nipython nbconvert typing.ipynb --to rst --TagRemovePreprocessor.remove_all_outputs_tags='{\"remove_output\"}'\nmv typing.rst experiments/typing.rst\n\nexport fileid=1krhwGHKoqL5-_P0G89fDB7Iw3ENHW2G_\nwget -O mm_fitting_small.ipynb 'https://docs.google.com/uc?export=download&id='$fileid\nipython nbconvert mm_fitting_small.ipynb --to rst --TagRemovePreprocessor.remove_all_outputs_tags='{\"remove_output\"}'\nmv mm_fitting_small.rst experiments/mm_fitting_small.rst\nmv mm_fitting_small_files experiments/mm_fitting_small_files\n\nexport fileid=1i_z0b0-m_91bMww1hY5Kdc76VHmtHsWD\nwget -O qm_fitting.ipynb 'https://docs.google.com/uc?export=download&id='$fileid\nipython nbconvert qm_fitting.ipynb --to rst --TagRemovePreprocessor.remove_all_outputs_tags='{\"remove_output\"}'\ncp qm_fitting.rst experiments/qm_fitting.rst\n\nrm *.ipynb\n"
  },
  {
    "path": "docs/experiments/index.rst",
    "content": "To reproduce experiments in paper https://arxiv.org/abs/2010.01196\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Contents:\n\n   typing\n   mm_fitting_small\n   qm_fitting\n"
  },
  {
    "path": "docs/experiments/mm_fitting_small.rst",
    "content": "Toy experiment: Molecular mechanics (MM) fitting on subsampled PhAlkEthOH dataset.\n==================================================================================\n\n**Open in Google Colab**:\nhttp://data.wangyq.net/esp_notebooks/phalkethoh_mm_small.ipynb\n\nThis notebook is intended to recover the MM fitting behavior in\nhttps://arxiv.org/abs/2010.01196\n\nTo assess how well Espaloma can learn to reproduce an MM force field\nfrom a limited amount of data, we selected a chemical dataset of limited\ncomplexity—PhAlkEthOH—which consists of linear and cyclic molecules\ncontaining phenyl rings, small alkanes, ethers, and alcohols composed of\nonly the elements carbon, oxygen, and hydrogen. We generated a set of\nconformational snapshots for each molecule using short high-temperature\nmolecular dynamics simulations at 300~K initiated from multiple\nconformations to ensure adequate sampling of conformers. The AlkEthOH\ndataset was randomly partitioned (by molecules) into 80% training, 10%\nvalidation, and 10% test molecules, with 100 snapshots/molecule, and an\nEspaloma model was trained with early stopping via monitoring for a\ndecrease in accuracy in the validation set.\n\n.. image:: https://pbs.twimg.com/media/FBL0qACXIBkJLQZ?format=png&name=4096x4096\n\nInstallation and imports\n------------------------\n\n.. code:: python\n\n    # install conda\n    ! pip install -q condacolab\n    import condacolab\n    condacolab.install()\n\n\n.. parsed-literal::\n\n    ⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...\n    📦 Installing...\n    📌 Adjusting configuration...\n    🩹 Patching environment...\n    ⏲ Done in 0:00:34\n    🔁 Restarting kernel...\n\n\n.. code:: python\n\n    %%capture\n    ! mamba install --yes --strict-channel-priority --channel jaimergp/label/unsupported-cudatoolkit-shim --channel omnia --channel omnia/label/cuda100 --channel dglteam --channel numpy openmm openmmtools openmmforcefields rdkit openff-toolkit dgl-cuda10.0 qcportal\n\n.. code:: python\n\n    ! git clone https://github.com/choderalab/espaloma.git\n\n\n.. parsed-literal::\n\n    Cloning into 'espaloma'...\n    remote: Enumerating objects: 7812, done.\u001b[K\n    remote: Counting objects: 100% (3634/3634), done.\u001b[K\n    remote: Compressing objects: 100% (1649/1649), done.\u001b[K\n    remote: Total 7812 (delta 2714), reused 2639 (delta 1900), pack-reused 4178\u001b[K\n    Receiving objects: 100% (7812/7812), 13.50 MiB | 11.77 MiB/s, done.\n    Resolving deltas: 100% (5538/5538), done.\n\n\n.. code:: python\n\n    import torch\n    import sys\n    sys.path.append(\"/content/espaloma\")\n    import espaloma as esp\n\n\n.. parsed-literal::\n\n    Warning: Unable to load toolkit 'OpenEye Toolkit'. The Open Force Field Toolkit does not require the OpenEye Toolkits, and can use RDKit/AmberTools instead. However, if you have a valid license for the OpenEye Toolkits, consider installing them for faster performance and additional file format support: https://docs.eyesopen.com/toolkits/python/quickstart-python/linuxosx.html OpenEye offers free Toolkit licenses for academics: https://www.eyesopen.com/academic-licensing\n\n\nLoad dataset\n------------\n\nHere we load the PhAlKeThoh dataset and shuffle before splitting into\ntraining, validation, and test (80%:10%:10%)\n\n.. code:: python\n\n    %%capture\n    ! wget http://data.wangyq.net/esp_dataset/phalkethoh_mm_small.zip\n    ! unzip phalkethoh_mm_small.zip\n\n.. 
code:: python\n\n    ds = esp.data.dataset.GraphDataset.load(\"phalkethoh\")\n    ds.shuffle(seed=2666)\n    ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])\n\n\n.. parsed-literal::\n\n    DGL backend not selected or invalid.  Assuming PyTorch for now.\n    Using backend: pytorch\n\n\n.. parsed-literal::\n\n    Setting the default backend to \"pytorch\". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)\n\n\nA training dataloader is constructed with ``batch_size=100``.\n\n.. code:: python\n\n    ds_tr_loader = ds_tr.view(batch_size=100, shuffle=True)\n\n.. code:: python\n\n    g_tr = next(iter(ds_tr.view(batch_size=len(ds_tr))))\n    g_vl = next(iter(ds_vl.view(batch_size=len(ds_vl))))\n\n\n.. parsed-literal::\n\n    /usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely replace dgl.batch_hetero with dgl.batch\n      return warnings.warn(message, category=category, stacklevel=1)\n\n\nDefine model\n------------\n\nDefine Espaloma stage I: graph -> atom latent representation\n\n.. code:: python\n\n    representation = esp.nn.Sequential(\n        layer=esp.nn.layers.dgl_legacy.gn(\"SAGEConv\"), # use SAGEConv implementation in DGL\n        config=[128, \"relu\", 128, \"relu\", 128, \"relu\"], # 3 layers, 128 units, ReLU activation\n    )\n\nDefine Espaloma stages II and III: atom latent representation -> bond,\nangle, and torsion representation and parameters. Then compose all three\nEspaloma stages into an end-to-end model.\n\n.. code:: python\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=128, config=[128, \"relu\", 128, \"relu\", 128, \"relu\"],\n        out_features={              # define modular MM parameters Espaloma will assign\n            1: {\"e\": 1, \"s\": 1}, # atom hardness and electronegativity\n            2: {\"log_coefficients\": 2}, # bond linear combination, enforce positive\n            3: {\"log_coefficients\": 2}, # angle linear combination, enforce positive\n            4: {\"k\": 6}, # torsion barrier heights (can be positive or negative)\n        },\n    )\n    \n    espaloma_model = torch.nn.Sequential(\n                     representation, readout, esp.nn.readout.janossy.ExpCoefficients(),\n                     esp.mm.geometry.GeometryInGraph(), \n                     esp.mm.energy.EnergyInGraph(),\n                     esp.mm.energy.EnergyInGraph(suffix=\"_ref\"),\n                     esp.nn.readout.charge_equilibrium.ChargeEquilibrium(),\n    )\n\n\n.. code:: python\n\n    if torch.cuda.is_available():\n        espaloma_model = espaloma_model.cuda()\n\nThe loss function is specified as the MSE between predicted and reference\nenergies.\n\n.. code:: python\n\n    loss_fn = esp.metrics.GraphMetric(\n            base_metric=torch.nn.MSELoss(), # use mean-squared error loss\n            between=['u', \"u_ref\"],         # between predicted and reference energies\n            level=\"g\", # compare on graph level\n    )\n\nDefine optimizer\n----------------\n\n.. code:: python\n\n    optimizer = torch.optim.Adam(espaloma_model.parameters(), 1e-4)\n\nTrain it!\n---------\n\n.. 
code:: python\n\n    for idx_epoch in range(10000):\n        for g in ds_tr_loader:\n            optimizer.zero_grad()\n            if torch.cuda.is_available():\n                g = g.to(\"cuda:0\")\n            g = espaloma_model(g)\n            loss = loss_fn(g)\n            loss.backward()\n            optimizer.step()\n        torch.save(espaloma_model.state_dict(), \"%s.th\" % idx_epoch)\n\n\n.. parsed-literal::\n\n    /usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely replace dgl.batch_hetero with dgl.batch\n      return warnings.warn(message, category=category, stacklevel=1)\n    /usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: dgl.to_homo is deprecated. Please use dgl.to_homogeneous\n      return warnings.warn(message, category=category, stacklevel=1)\n\n\nInspect\n-------\n\n.. code:: python\n\n    inspect_metric = esp.metrics.GraphMetric(\n            base_metric=torch.nn.L1Loss(), # use mean absolute error (L1) loss\n            between=['u', \"u_ref\"],         # between predicted and reference energies\n            level=\"g\", # compare on graph level\n    )\n\n.. code:: python\n\n    if torch.cuda.is_available():\n        g_vl = g_vl.to(\"cuda:0\")\n        g_tr = g_tr.to(\"cuda:0\")\n\n.. code:: python\n\n    loss_tr = []\n    loss_vl = []\n\n.. code:: python\n\n    for idx_epoch in range(10000):\n        espaloma_model.load_state_dict(\n            torch.load(\"%s.th\" % idx_epoch)\n        )\n    \n        espaloma_model(g_tr)\n        loss_tr.append(inspect_metric(g_tr).item())\n    \n        espaloma_model(g_vl)\n        loss_vl.append(inspect_metric(g_vl).item())\n\n\n\n.. parsed-literal::\n\n    /usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: dgl.to_homo is deprecated. Please use dgl.to_homogeneous\n      return warnings.warn(message, category=category, stacklevel=1)\n\n\n.. code:: python\n\n    import numpy as np\n    loss_tr = np.array(loss_tr) * 627.5\n    loss_vl = np.array(loss_vl) * 627.5\n\n.. code:: python\n\n    from matplotlib import pyplot as plt\n    plt.plot(loss_tr, label=\"train\")\n    plt.plot(loss_vl, label=\"valid\")\n    plt.yscale(\"log\")\n    plt.legend()\n\n\n\n\n.. parsed-literal::\n\n    <matplotlib.legend.Legend at 0x7fd8f0eebd90>\n\n\n\n\n.. image:: mm_fitting_small_files/mm_fitting_small_31_1.png\n\n\n"
  },
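  {
    "path": "docs/examples/early_stopping_sketch.py",
    "content": "# Illustrative sketch, not from the original notebook: the early stopping on\n# validation performance mentioned in docs/experiments/mm_fitting_small.rst,\n# which the notebook describes but does not spell out. ``model``, ``loss_fn``,\n# ``train_loader``, and ``g_val`` are assumed to be the objects defined there;\n# ``patience`` is a hypothetical hyperparameter.\nimport copy\n\nimport torch\n\n\ndef train_with_early_stopping(model, loss_fn, train_loader, g_val,\n                              n_epochs=10000, patience=100, lr=1e-4):\n    optimizer = torch.optim.Adam(model.parameters(), lr)\n    best_val, best_state, epochs_since_best = float(\"inf\"), None, 0\n    for epoch in range(n_epochs):\n        for g in train_loader:\n            optimizer.zero_grad()\n            loss = loss_fn(model(g))\n            loss.backward()\n            optimizer.step()\n        # monitor validation loss; stop once it has not improved for a while\n        with torch.no_grad():\n            val_loss = loss_fn(model(g_val)).item()\n        if val_loss < best_val:\n            best_val, epochs_since_best = val_loss, 0\n            best_state = copy.deepcopy(model.state_dict())\n        else:\n            epochs_since_best += 1\n            if epochs_since_best > patience:\n                break\n    # restore the parameters that performed best on the validation set\n    model.load_state_dict(best_state)\n    return model\n"
  },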
  {
    "path": "docs/experiments/qm_fitting.rst",
    "content": "Quantum mechanics (QM) fitting experiment.\n==========================================\n\n**Open in Google Colab:**\nhttp://data.wangyq.net/esp_notesbooks/qm_fitting.ipynb\n\nThis notebook recovers the QM fitting experiment in\nhttps://arxiv.org/abs/2010.01196\n\n|image1| **Table 2\u0010\u0010\u0010\u0010\u0010:** Espaloma can directly fit quantum chemical\nenergies to produce a new molecular mechanics force fields with better\naccuracy than traditional force fields based on atom typing or direct\nchemical perception. Espaloma was fit to quantum chemical potential\nenergies for conformations generated by optimization trajectories from\nmultiple conformers in various datasets from QCArchive.All datasets were\npartitioned by molecules 80:10:10 into train:validate:test sets. We\nreport the RMSE on training and test sets, as well as the performance of\nlegacy force fields on the test set. All statistics are computed with\npredicted and reference energies centered to have zero mean for each\nmolecule to focus on errors in relative conformational energetics,\nrather than on errors in predicting the heats of formation of chemical\nspecies (which the MM functional form used here is incapable of). The\n95% confidence intervals annotated are calculated by via bootstrapping\nmolecules with replacement using 1000 replicates. \\*: Six cyclic\npeptides that cannot be parametrized using OpenForceField toolkit\nengine~:raw-latex:`\\cite{openff-toolkit-0.10.0}` and is not included.\n\nSince Espaloma can derive a force field solely by fitting to energies\n(and optionally gradients), we repeat the end-to-end fitting experiment\n(See notebook\nhttp://data.wangyq.net/esp_notebooks/phalkethoh_mm_small.ipynb) directly\nusing a quantum chemical (QM) datasets used to build and evaluate MM\nforce fields. We assessed the ability of Espaloma to learn several\ndistinct quantum chemical datasets generated by the Open Force Field\nInitiativeand deposited in the MolSSI QCArchive: - **PhAlkEthOH** is a\ncollection of compounds containing only the elements carbon, hydrogen,\nand oxygen in compounds containing phenyl rings, alkanes, ketones, and\nalcohols. Limited in elemental and chemical diversity, this dataset is\nchosen as a proof-of-concept to demonstrate the capability of Espaloma\nto fit and generalize quantum chemical energies when training data is\nsufficient to exhaustively cover the breadth of chemical environments. -\n**OpenFF Gen2 Optimization** consists of druglike molecules used in the\nparametrization of the Open Force Field 1.2.0 (“Parsley”) small molecule\nforce field. This set was constructed by the Open Force Field Consortium\nfrom challenging molecule structures provided by Pfizer, Bayer, and\nRoche, along with diverse molecules selected from eMolecules to achieve\nuseful coverage of chemical space. - **VEHICLe**, or *virtual\nexploratory heterocyclic library*, is a set of heteroaromatic ring\nsystems of interest to drug discovery. The atoms in the molecules in\nthis dataset have interesting chemical environments in heteroarmatic\nrings that present a challenge to traditional atom typing schemes, which\ncannot easily accomodate the nuanced distinctions in chemical\nenvironments that lead to perturbations in heterocycle structure.We use\nthis dataset to illustrate that Espaloma performs in situations\nchallenging to traditional force fields. 
- **PepConf** contains a variety of short peptides, including capped,\n  cyclic, and disulfide-bonded peptides. This dataset, regenerated using\n  the Open Force Field QCSubmit tool, explores the applicability of\n  Espaloma to biopolymers, such as proteins.\n\nSince nonbonded terms are generally optimized to fit other\ncondensed-phase properties, we focused here on optimizing only the\nvalence parameters (bond, angle, and proper and improper torsion) to fit\nthese gas-phase quantum chemical datasets, fixing the non-bonded\nenergies using a legacy force field. Because we are learning an MM force\nfield that is incapable of reproducing quantum chemical heats of\nformation (which appear as an additive offset in the quantum chemical\nenergy targets), snapshot energies for each molecule are shifted to have\nzero mean in both training and test sets. All datasets are randomly\nshuffled and split (by molecules) into training (80%), validation (10%),\nand test (10%) sets.\n\n.. |image1| image:: https://pbs.twimg.com/media/FBL1Gb0WEAYkUhM?format=png&name=4096x4096\n\nInstallation and imports\n------------------------\n\n.. code:: python\n\n    # install conda\n    ! pip install -q condacolab\n    import condacolab\n    condacolab.install()\n\n.. code:: python\n\n    %%capture\n    ! mamba install --yes --strict-channel-priority --channel jaimergp/label/unsupported-cudatoolkit-shim --channel omnia --channel omnia/label/cuda100 --channel dglteam --channel numpy openmm openmmtools openmmforcefields rdkit openff-toolkit dgl-cuda10.0 qcportal\n\n.. code:: python\n\n    ! git clone https://github.com/choderalab/espaloma.git\n\n.. code:: python\n\n    import torch\n    import sys\n    sys.path.append(\"/content/espaloma\")\n    import espaloma as esp\n\nLoad dataset\n------------\n\nChoose a dataset from ``[\"gen2\", \"pepconf\", \"vehicle\", \"phalkethoh\"]``.\n\n.. code:: python\n\n    dataset_name = \"gen2\"\n    # dataset_name = \"pepconf\"\n    # dataset_name = \"vehicle\"\n    # dataset_name = \"phalkethoh\"\n\n.. code:: python\n\n    %%capture\n    ! wget \"data.wangyq.net/esp_dataset/\"$dataset_name\".zip\"\n    ! unzip $dataset_name\".zip\"\n\n.. code:: python\n\n    ds = esp.data.dataset.GraphDataset.load(dataset_name)\n    ds.shuffle(seed=2666)\n    ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])\n\nDefine model\n------------\n\nDefine Espaloma stage I: graph -> atom latent representation\n\n.. code:: python\n\n    representation = esp.nn.Sequential(\n        layer=esp.nn.layers.dgl_legacy.gn(\"SAGEConv\"), # use SAGEConv implementation in DGL\n        config=[128, \"relu\", 128, \"relu\", 128, \"relu\"], # 3 layers, 128 units, ReLU activation\n    )\n\nDefine Espaloma stages II and III: atom latent representation -> bond,\nangle, and torsion representation and parameters. Then compose all three\nEspaloma stages into an end-to-end model.\n\n.. 
code:: python\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=128, config=[128, \"relu\", 128, \"relu\", 128, \"relu\"],\n        out_features={              # define modular MM parameters Espaloma will assign\n            1: {\"e\": 1, \"s\": 1}, # atom hardness and electronegativity\n            2: {\"log_coefficients\": 2}, # bond linear combination, enforce positive\n            3: {\"log_coefficients\": 2}, # angle linear combination, enforce positive\n            4: {\"k\": 6}, # torsion barrier heights (can be positive or negative)\n        },\n    )\n    \n    espaloma_model = torch.nn.Sequential(\n                     representation, readout, esp.nn.readout.janossy.ExpCoefficients(),\n                     esp.mm.geometry.GeometryInGraph(), \n                     esp.mm.energy.EnergyInGraph(),\n    )\n\n\n.. code:: python\n\n    if torch.cuda.is_available():\n        espaloma_model = espaloma_model.cuda()\n\nThe loss function is specified as the MSE between predicted and reference\nenergies.\n\n.. code:: python\n\n    loss_fn = esp.metrics.GraphMetric(\n            base_metric=torch.nn.MSELoss(), # use mean-squared error loss\n            between=['u', \"u_ref\"],         # between predicted and QM energies\n            level=\"g\", # compare on graph level\n    )\n\nDefine optimizer\n----------------\n\n.. code:: python\n\n    optimizer = torch.optim.Adam(espaloma_model.parameters(), 1e-4)\n\nTrain it!\n---------\n\n.. code:: python\n\n    for idx_epoch in range(10000):\n        for g in ds_tr:\n            optimizer.zero_grad()\n            if torch.cuda.is_available():\n                g.heterograph = g.heterograph.to(\"cuda:0\")\n            g = espaloma_model(g.heterograph)\n            loss = loss_fn(g)\n            loss.backward()\n            optimizer.step()\n        torch.save(espaloma_model.state_dict(), \"%s.th\" % idx_epoch)\n\nInspect\n-------\n\n.. code:: python\n\n    inspect_metric = esp.metrics.center(torch.nn.L1Loss()) # mean absolute error (L1) on centered energies\n\n.. code:: python\n\n    loss_tr = []\n    loss_vl = []\n\n.. code:: python\n\n    with torch.no_grad():\n        for idx_epoch in range(10000):\n            espaloma_model.load_state_dict(\n                torch.load(\"%s.th\" % idx_epoch)\n            )\n    \n            # training set performance\n            u = []\n            u_ref = []\n            for g in ds_tr:\n                if torch.cuda.is_available():\n                    g.heterograph = g.heterograph.to(\"cuda:0\")\n                espaloma_model(g.heterograph)\n                u.append(g.nodes['g'].data['u'])\n                u_ref.append(g.nodes['g'].data['u_ref'])\n            u = torch.cat(u, dim=0)\n            u_ref = torch.cat(u_ref, dim=0)\n            loss_tr.append(inspect_metric(u, u_ref).item())\n    \n    \n            # validation set performance\n            u = []\n            u_ref = []\n            for g in ds_vl:\n                if torch.cuda.is_available():\n                    g.heterograph = g.heterograph.to(\"cuda:0\")\n                espaloma_model(g.heterograph)\n                u.append(g.nodes['g'].data['u'])\n                u_ref.append(g.nodes['g'].data['u_ref'])\n            u = torch.cat(u, dim=0)\n            u_ref = torch.cat(u_ref, dim=0)\n            loss_vl.append(inspect_metric(u, u_ref).item())\n\n\n.. code:: python\n\n    import numpy as np\n    loss_tr = np.array(loss_tr) * 627.5\n    loss_vl = np.array(loss_vl) * 627.5\n\n.. 
code:: python\n\n    from matplotlib import pyplot as plt \n    plt.plot(loss_tr, label=\"train\")\n    plt.plot(loss_vl, label=\"valid\")\n    plt.yscale(\"log\")\n    plt.legend()\n"
  },
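  {
    "path": "docs/examples/center_energies_sketch.py",
    "content": "# Illustrative sketch, not from the original notebook: the per-molecule\n# mean-centering used in docs/experiments/qm_fitting.rst before comparing\n# espaloma and QM energies, with synthetic tensors standing in for the\n# per-molecule snapshot energies.\nimport torch\n\n\ndef center(u):\n    # shift each molecule's snapshot energies to zero mean, removing the\n    # additive heat-of-formation offset the MM functional form cannot capture\n    return u - u.mean(dim=-1, keepdim=True)\n\n\n# synthetic example: 2 molecules x 5 conformations, in hartree, where the\n# reference energies carry large per-molecule offsets\nu_pred = torch.randn(2, 5)\nu_ref = u_pred + 0.01 * torch.randn(2, 5) + torch.tensor([[100.0], [200.0]])\n\n# RMSE on centered energies, converted from hartree to kcal/mol\nrmse = (center(u_pred) - center(u_ref)).pow(2).mean().sqrt() * 627.5\nprint(f\"centered RMSE: {rmse.item():.3f} kcal/mol\")\n"
  },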
  {
    "path": "docs/experiments/typing.rst",
    "content": "Atom typing recovery experiment.\n================================\n\n**Open in Google Colab**:\nhttp://data.wangyq.net/esp_notebooks/typing.ipynb\n\n(GPU preferred)\n\nIn this notebook, we reproduce the atom typing recovery experiment in\n`Wang Y, Fass J, and Chodera JD “End-to-End Differentiable Construction\nof Molecular Mechanics Force\nFields <https://arxiv.org/abs/2010.01196>`__\n\n(Section 3: Graph neural networks can learn to reproduce human-defined\nlegacy atom types with high accuracy; Figure 3. Graph neural networks\ncan reproduce legacy atom types with high accuracy.)\n\n.. image:: https://pbs.twimg.com/media/FBLz_6sWUAM2iHa?format=jpg&name=4096x4096\n\nGraph neural networks can reproduce legacy atom types with high\naccuracy.\n\nThe Stage 1 graph neural network of Espaloma chained to a discrete atom\ntype readout was fit to GAFF 1.81 atom types on a subset of ZINC\ndistributed with parm Frosst as a validation set .\n\nThe 7529 molecules in this set were partitioned 80:10:10 into\ntraining:test:validation sets for this experiment. The overall test set\naccuracy was :math:`99.07\\%_{98.93\\%}^{99.22\\%}`, with 1000 bootstrap\nreplicates used to estimate the confidence intervals arising from finite\ntest set size effects. (a) The distribution of the number of atom type\ndiscrepancies on the test set demonstrates that only a minority of atoms\nare incorrectly typed. (b) The error rate per element is primarily\nconcentrated within carbon, nitrogen, and sulfur types. (c) Examining\natom type failures in detail on molecules with the largest numbers of\ndiscrepancies shows that the atom types are easily confused by a human,\nsince they represent qualities that are difficult to precisely define.\n(d) The distribution of predicted atom types for each reference atom\ntype for carbon types are shown; on-diagonal values indicate agreement.\nThe percentages annotated under x-axis denote the relative abundance\nwithin the test set.\n\nInstallation and Imports\n------------------------\n\nFirst, we install espaloma after all of its dependencies. Note that this\nis going to be significantly simplified.\n\n.. code:: python\n\n    %%capture\n    ! wget -c https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n    ! bash Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local\n    ! conda config --add channels conda-forge --add channels omnia --add channels omnia/label/cuda100 --add channels dglteam\n    ! conda update --yes --all\n    ! conda create --yes -n openmm python=3.6 numpy openmm openmmtools rdkit openforcefield==0.7.0 dgl-cuda10.0 qcportal\n    ! git clone https://github.com/choderalab/espaloma.git\n\n.. code:: python\n\n    import torch\n    import dgl\n    import numpy as np\n\nGet dataset\n-----------\n\n.. code:: python\n\n    import os\n    if not os.path.exists(\"zinc\"):\n        os.system(\"wget data.wangyq.net/esp_datasets/zinc\")\n    ds = esp.data.dataset.GraphDataset.load(\"zinc\")\n\nAssign legacy typing\n--------------------\n\nNext, we assign legacy typings using `GAFF-1.81 force\nfield. <https://github.com/openmm/openmmforcefields/blob/master/amber/gaff/dat/gaff-1.81.dat#L20-L32>`__\n\n.. code:: python\n\n    typing = esp.graphs.legacy_force_field.LegacyForceField('gaff-1.81')\n    ds.apply(typing, in_place=True) # this modify the original data\n\nData massaging\n--------------\n\nWe then split the data into training, test, and validatoin (80:10:10)\nand batch the the datasets.\n\n.. 
code:: python\n\n    ds_tr, ds_te, ds_vl = ds.split([8, 1, 1])\n\n.. code:: python\n\n    ds_tr = ds_tr.view('graph', batch_size=100, shuffle=True)\n    ds_te = ds_te.view('graph', batch_size=100)\n    ds_vl = ds_vl.view('graph', batch_size=100)\n\nDefining model\n--------------\n\nWe define a graph neural network (GNN) model using\n`SAGEConv <https://arxiv.org/abs/1706.02216>`__ with 128 units, three\nlayers, and ReLU activation functions.\n\n.. code:: python\n\n    # define a layer\n    layer = esp.nn.layers.dgl_legacy.gn(\"SAGEConv\")\n    \n    # define a representation\n    representation = esp.nn.Sequential(\n            layer,\n            [128, \"relu\", 128, \"relu\", 128, \"relu\"],\n    )\n    \n    # define a readout\n    readout = esp.nn.readout.node_typing.NodeTyping(\n            in_features=128,\n            n_classes=100\n    )\n    \n    net = torch.nn.Sequential(\n        representation,\n        readout\n    )\n\nDefine graph-level loss function\n--------------------------------\n\n.. code:: python\n\n    loss_fn = esp.metrics.TypingAccuracy()\n\nTrain the model\n---------------\n\n.. code:: python\n\n    # define optimizer\n    optimizer = torch.optim.Adam(net.parameters(), 1e-5)\n    \n    # train the model\n    for _ in range(3000):\n        for g in ds_tr:\n            optimizer.zero_grad()\n            net(g.heterograph)\n            loss = loss_fn(g.heterograph)\n            loss.backward()\n            optimizer.step()\n\n"
  },
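  {
    "path": "docs/examples/bootstrap_accuracy_sketch.py",
    "content": "# Illustrative sketch, not from the original notebook: bootstrapping molecules\n# with replacement to estimate the 95% confidence interval on test-set typing\n# accuracy described in docs/experiments/typing.rst. The per-molecule counts\n# below are synthetic stand-ins for the real test-set results.\nimport numpy as np\n\nrng = np.random.default_rng(2666)\n\n# synthetic per-molecule atom counts and numbers of correctly typed atoms\nn_atoms = rng.integers(10, 40, size=753)\nn_correct = n_atoms - rng.binomial(n_atoms, 0.01)\n\n# resample molecules with replacement, 1000 replicates as in the experiment\nreplicates = []\nfor _ in range(1000):\n    idx = rng.integers(0, len(n_atoms), size=len(n_atoms))\n    replicates.append(n_correct[idx].sum() / n_atoms[idx].sum())\n\nlow, high = np.percentile(replicates, [2.5, 97.5])\nprint(f\"typing accuracy 95% CI: [{100 * low:.2f}%, {100 * high:.2f}%]\")\n"
  },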
  {
    "path": "docs/index.rst",
    "content": ".. espaloma documentation master file, created by\n   sphinx-quickstart on Thu Mar 15 13:55:56 2018.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nespaloma: Extensible Surrogate Potential Optimized by Message-passing Algorithms\n=========================================================\n\nSource code for Wang Y, Fass J, and Chodera JD \"End-to-End Differentiable Construction of Molecular Mechanics Force Fields. https://arxiv.org/abs/2010.01196\n\n.. image:: _static/espaloma_abstract_v2-2.png\n\nPaper Abstract\n--------------\nMolecular mechanics (MM) potentials have long been a workhorse of computational chemistry.\nLeveraging accuracy and speed, these functional forms find use in a wide variety of applications in biomolecular modeling and drug discovery, from rapid virtual screening to detailed free energy calculations.\nTraditionally, MM potentials have relied on human-curated, inflexible, and poorly extensible discrete chemical perception rules *atom types* for applying parameters to small molecules or biopolymers, making it difficult to optimize both types and parameters to fit quantum chemical or physical property data.\nHere, we propose an alternative approach that uses *graph neural networks* to perceive chemical environments, producing continuous atom embeddings from which valence and nonbonded parameters can be predicted using invariance-preserving layers.\nSince all stages are built from smooth neural functions, the entire process---spanning chemical perception to parameter assignment---is modular and end-to-end differentiable with respect to model parameters, allowing new force fields to be easily constructed, extended, and applied to arbitrary molecules.\nWe show that this approach is not only sufficiently expressive to reproduce legacy atom types, but that it can learn and extend existing molecular mechanics force fields, construct entirely new force fields applicable to both biopolymers and small molecules from quantum chemical calculations, and even learn to accurately predict free energies from experimental observables.\n\n\nLab Meeting\n-----------\n.. raw:: html\n\n    <iframe width=\"600\" height=\"450\" src=\"https://www.youtube.com/embed/OC210nUuXHk\"></iframe>\n\nFull video: https://youtu.be/OC210nUuXHk\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Contents:\n\n   install\n   deploy\n   experiments/index\n   api\n\n\nIndices and tables\n------------------\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/install.rst",
    "content": "Installation\n============\n\nmamba\n-----\n\nWe recommend using `mamba <https://mamba.readthedocs.io/en/latest/mamba-installation.html#mamba-installation>`_ which is a drop-in replacement for ``conda`` and is much faster.\n\n.. code-block:: bash\n\n   $ mamba create --name espaloma -c conda-forge \"espaloma=0.3.2\""
  },
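  {
    "path": "docs/examples/check_install_sketch.py",
    "content": "# Illustrative sketch, not part of the documented install steps: a quick\n# post-install sanity check for the mamba environment described in\n# docs/install.rst. The SMILES string is an arbitrary example.\nimport espaloma as esp\nfrom openff.toolkit.topology import Molecule\n\nprint(\"espaloma version:\", esp.__version__)\n\n# build an espaloma Graph from a small molecule to confirm the toolkit stack works\nmolecule = Molecule.from_smiles(\"c1ccccc1O\")  # phenol\ngraph = esp.Graph(molecule)\nprint(\"graph constructed:\", graph.heterograph)\n"
  },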
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\nset SPHINXPROJ=malt\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "espaloma/__init__.py",
    "content": "\"\"\"\nespaloma\nExtensible Surrogate Potential of Ab initio Learned and Optimized by Message-passing Algorithm\n\"\"\"\n\nfrom . import metrics, units, data, app, graphs, mm, nn\nfrom .app.experiment import *\nfrom .graphs.graph import Graph\nfrom .metrics import GraphMetric\nfrom .mm.geometry import *\nfrom .utils.model_fetch import get_model, get_model_path\n\n# Add imports here\n# import espaloma\n\n\n# Handle versioneer\nfrom ._version import get_versions\n\n#\n# from openff.toolkit.utils.toolkits import ToolkitRegistry, OpenEyeToolkitWrapper, RDKitToolkitWrapper, AmberToolsToolkitWrapper\n# toolkit_registry = ToolkitRegistry()\n# toolkit_precedence = [ RDKitToolkitWrapper ] # , OpenEyeToolkitWrapper, AmberToolsToolkitWrapper]\n# [ toolkit_registry.register_toolkit(toolkit) for toolkit in toolkit_precedence if toolkit.is_available() ]\n#\n\n\nversions = get_versions()\n__version__ = versions[\"version\"]\n__git_revision__ = versions[\"full-revisionid\"]\ndel get_versions, versions\n\nfrom . import _version\n__version__ = _version.get_versions()['version']\n"
  },
  {
    "path": "espaloma/_version.py",
    "content": "\n# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain.\n# Generated by versioneer-0.29\n# https://github.com/python-versioneer/python-versioneer\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\nfrom typing import Any, Callable, Dict, List, Optional, Tuple\nimport functools\n\n\ndef get_keywords() -> Dict[str, str]:\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. _version.py will just call\n    # get_keywords().\n    git_refnames = \"$Format:%d$\"\n    git_full = \"$Format:%H$\"\n    git_date = \"$Format:%ci$\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n    VCS: str\n    style: str\n    tag_prefix: str\n    parentdir_prefix: str\n    versionfile_source: str\n    verbose: bool\n\n\ndef get_config() -> VersioneerConfig:\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"pep440\"\n    cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = \"None\"\n    cfg.versionfile_source = \"espaloma/_version.py\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs: str, method: str) -> Callable:  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n    def decorate(f: Callable) -> Callable:\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(\n    commands: List[str],\n    args: List[str],\n    cwd: Optional[str] = None,\n    verbose: bool = False,\n    hide_stderr: bool = False,\n    env: Optional[Dict[str, str]] = None,\n) -> Tuple[Optional[str], Optional[int]]:\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs: Dict[str, Any] = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen([command] + args, cwd=cwd, env=env,\n                                       stdout=subprocess.PIPE,\n                                    
   stderr=(subprocess.PIPE if hide_stderr\n                                               else None), **popen_kwargs)\n            break\n        except OSError as e:\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, process.returncode\n    return stdout, process.returncode\n\n\ndef versions_from_parentdir(\n    parentdir_prefix: str,\n    root: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" %\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs: str) -> Dict[str, str]:\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. This function is not used from\n    # _version.py.\n    keywords: Dict[str, str] = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(\n    keywords: Dict[str, str],\n    tag_prefix: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  
Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r'\\d', r)}\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r'\\d', r):\n                continue\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(\n    tag_prefix: str,\n    root: str,\n    verbose: bool,\n    runner: Callable = run_command\n) -> Dict[str, Any]:\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                   hide_stderr=not verbose)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(GITS, [\n        \"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\",\n        \"--match\", f\"{tag_prefix}[[:digit:]]*\"\n    ], cwd=root)\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces: Dict[str, Any] = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"],\n                             cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. 
If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparsable. Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%s'\"\n                               % describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%s' doesn't start with prefix '%s'\"\n                               % (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces: Dict[str, Any]) -> str:\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces: Dict[str, Any]) -> str:\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or -1 if no post-release segment is present).\n    \"\"\"\n    vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%d.dev%d\" % (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%d\" % (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always -long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions() -> Dict[str, Any]:\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. Invert\n        # this to find the root from __file__.\n        for _ in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n"
  },
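  {
    "path": "devtools/examples/query_version.py",
    "content": "# Illustrative usage sketch (hypothetical example file, not part of espaloma).\n# It shows how the versioneer-generated espaloma/_version.py above is queried;\n# get_versions() and the result keys come from that file, everything else here\n# is an assumption.\nfrom espaloma._version import get_versions\n\ninfo = get_versions()\n\n# render() above guarantees these keys:\n# \"version\", \"full-revisionid\", \"dirty\", \"error\", \"date\"\nprint(info[\"version\"])  # e.g. a pep440 string: TAG[+DISTANCE.gHEX[.dirty]]\nprint(info[\"full-revisionid\"])  # full commit hash, or None\n"
  },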
  {
    "path": "espaloma/app/__init__.py",
    "content": "from . import experiment, report\n"
  },
  {
    "path": "espaloma/app/experiment.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport abc\nimport copy\nimport torch\n\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass Experiment(abc.ABC):\n    \"\"\"Base class for espaloma experiment.\"\"\"\n\n    def __init__(self):\n        super(Experiment, self).__init__()\n\n\nclass Train(Experiment):\n    \"\"\"Training experiment.\n\n    Parameters\n    ----------\n    net : `torch.nn.Module`\n        Neural networks that inputs graph representation and outputs\n        parameterized or typed graph for molecular mechanics.\n\n    data : `espaloma.data.dataset.Dataset`\n        or `torch.utils.data.DataLoader`\n        Dataset.\n\n    metrics : `List` of `callable`\n        List of loss functions to be used (summed) in training.\n\n    optimizer : `torch.optim.Optimizer`\n        Optimizer used for training.\n\n    n_epochs : `int`\n        Number of epochs.\n\n    record_interval : `int`\n        Interval at which states are recorded.\n\n    Methods\n    -------\n    train_once : Train the network for exactly once.\n\n    train : Execute `train_once` for `n_epochs` times and record states\n        every `record_interval`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        net,\n        data,\n        metrics=[esp.metrics.TypingCrossEntropy()],\n        optimizer=lambda net: torch.optim.Adam(net.parameters(), 1e-3),\n        n_epochs=100,\n        record_interval=1,\n        normalize=esp.data.normalize.ESOL100LogNormalNormalize,\n        scheduler=None,\n        device=torch.device(\"cpu\"),\n    ):\n        super(Train, self).__init__()\n\n        # bookkeeping\n        self.device = device\n        if isinstance(net, torch.nn.DataParallel):\n            self.net = net\n        elif isinstance(net, torch.nn.parallel.DistributedDataParallel):\n            self.net = net\n        else:\n            self.net = net.to(self.device)\n        self.data = data\n        self.metrics = metrics\n        self.n_epochs = n_epochs\n        self.record_interval = record_interval\n        self.normalize = normalize()\n        self.states = {}\n        self.scheduler = scheduler\n\n        # make optimizer\n        if callable(optimizer):\n            self.optimizer = optimizer(net)\n        else:\n            self.optimizer = optimizer\n\n        # compose loss function\n        def loss(g):\n            _loss = 0.0\n            for metric in self.metrics:\n                _loss += metric(g)\n\n            return _loss\n\n        self.loss = loss\n\n    def train_once(self):\n        \"\"\"Train the model for one batch.\"\"\"\n        for idx, g in enumerate(\n            self.data\n        ):  # TODO: does this have to be a single g?\n\n            if isinstance(self.optimizer, torch.optim.LBFGS):\n                retain_graph = True\n            else:\n                retain_graph = False\n\n            g = g.to(self.device)\n            self.net.train()\n\n            def closure(g=g):\n                self.optimizer.zero_grad()\n                g = self.net(g)\n                g = self.normalize.unnorm(g)\n\n                loss = self.loss(g)\n                loss.backward(retain_graph=retain_graph)\n                if idx == 0:\n                    if 
torch.isnan(loss).cpu().numpy().item() is True:\n                        raise RuntimeError(\"Loss is Nan.\")\n                return loss\n\n            loss = closure()\n            self.optimizer.step()\n\n            if self.scheduler is not None:\n                self.scheduler.step(loss)\n\n    def train(self):\n        \"\"\"Train the model for multiple steps and\n        record the weights once every `record_interval`\n\n        \"\"\"\n\n        for epoch_idx in range(int(self.n_epochs)):\n\n            self.train_once()\n\n            # record when `record_interval` is hit\n            if epoch_idx % self.record_interval == 0:\n                self.states[epoch_idx] = copy.deepcopy(self.net.state_dict())\n\n        # record final state\n        self.states[\"final\"] = copy.deepcopy(self.net.state_dict())\n\n        return self.net\n\n\nclass Test(Experiment):\n    \"\"\"Test experiment.\n\n    Parameters\n    ----------\n    net : `torch.nn.Module`\n        Neural networks that inputs graph representation and outputs\n        parameterized or typed graph for molecular mechanics.\n\n    data : `espaloma.data.dataset.Dataset`\n        or `torch.utils.data.DataLoader`\n        Dataset.\n\n    metrics : `List` of `callable`\n        List of loss functions to be used (summed) in training.\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        net,\n        data,\n        states,\n        metrics=[esp.metrics.TypingCrossEntropy()],\n        normalize=esp.data.normalize.NotNormalize,\n        sampler=None,\n        device=torch.device(\"cpu\"),  # it should cpu\n    ):\n        # bookkeeping\n        self.device = device\n        self.net = net.to(self.device)\n        self.data = data\n        self.states = states\n        self.metrics = metrics\n        self.sampler = sampler\n        self.normalize = normalize()\n\n    def test(self):\n        \"\"\"Run tests.\"\"\"\n\n        results = {}\n\n        # loop through the metrics\n        for metric in self.metrics:\n            results[metric.__name__] = {}\n\n        # NOTE: we are not doing this here since this will lead to OOM\n        # from time to time\n        # make it just one giant graph\n        # g = list(self.data)\n        # g = dgl.batch(g)\n        # g = g.to(self.device)\n\n        if self.states is None:\n            self.states = {\"final\": None}\n\n        for state_name, state in self.states.items():  # loop through states\n            if state is not None:\n                # load the state dict\n                self.net.load_state_dict(state)\n\n            self.net.eval()\n\n            for metric in self.metrics:\n                assert isinstance(metric, esp.metrics.Metric)\n                input_fn, target_fn = metric.between\n\n                inputs = []\n                targets = []\n\n                for g in self.data:\n                    with g.local_scope():\n                        g = g.to(self.device)\n                        g_input = self.normalize.unnorm(self.net(g))\n                        inputs.append(input_fn(g_input).detach())\n                        targets.append(target_fn(g_input).detach())\n\n                inputs = torch.cat(inputs, dim=0)\n                targets = torch.cat(targets, dim=0)\n\n                # loop through the metrics\n                results[metric.__name__][state_name] = (\n                    metric.base_metric(inputs, targets).detach().cpu().numpy()\n                )\n\n        self.ref_g = self.normalize.unnorm(self.net(g)).to(\n            
\n            torch.device(\"cpu\")\n        )\n\n        for term in self.ref_g.ntypes:\n            for param in self.ref_g.nodes[term].data.keys():\n                self.ref_g.nodes[term].data[param] = (\n                    self.ref_g.nodes[term].data[param].detach()\n                )\n\n        # attach the results to self\n        self.results = results\n\n\nclass TrainAndTest(Experiment):\n    \"\"\"Train a model and then test it.\"\"\"\n\n    def __init__(\n        self,\n        net,\n        ds_tr,\n        ds_te,\n        ds_vl=None,\n        metrics_tr=[esp.metrics.TypingCrossEntropy()],\n        metrics_te=[esp.metrics.TypingCrossEntropy()],\n        optimizer=lambda net: torch.optim.Adam(net.parameters(), 1e-2),\n        normalize=esp.data.normalize.NotNormalize,\n        n_epochs=100,\n        record_interval=1,\n        device=torch.device(\"cpu\"),\n        scheduler=None,\n    ):\n\n        # bookkeeping\n        self.device = device\n        self.net = net\n        self.ds_tr = ds_tr\n        self.ds_te = ds_te\n        self.ds_vl = ds_vl\n        self.optimizer = optimizer\n        self.n_epochs = n_epochs\n        self.metrics_tr = metrics_tr\n        self.metrics_te = metrics_te\n        self.normalize = normalize\n        self.record_interval = record_interval\n        self.scheduler = scheduler\n\n    def __str__(self):\n        _str = \"\"\n        _str += \"# model\"\n        _str += \"\\n\"\n        _str += str(self.net)\n        _str += \"\\n\"\n        if hasattr(self.net, \"noise_model\"):\n            _str += \"# noise model\"\n            _str += \"\\n\"\n            _str += str(self.net.noise_model)\n            _str += \"\\n\"\n        _str += \"# optimizer\"\n        _str += \"\\n\"\n        _str += str(self.optimizer)\n        _str += \"\\n\"\n        _str += \"# n_epochs\"\n        _str += \"\\n\"\n        _str += str(self.n_epochs)\n        _str += \"\\n\"\n        return _str\n\n    def run(self):\n        \"\"\"Run train and test.\"\"\"\n        train = Train(\n            net=self.net,\n            data=self.ds_tr,\n            optimizer=self.optimizer,\n            n_epochs=self.n_epochs,\n            metrics=self.metrics_tr,\n            normalize=self.normalize,\n            device=self.device,\n            record_interval=self.record_interval,\n            scheduler=self.scheduler,\n        )\n\n        train.train()\n\n        self.states = train.states\n\n        test = Test(\n            net=self.net,\n            data=self.ds_te,\n            metrics=self.metrics_te,\n            states=self.states,\n            normalize=self.normalize,\n            device=self.device,\n        )\n\n        test.test()\n\n        self.ref_g_test = test.ref_g\n\n        self.results_te = test.results\n\n        test = Test(\n            net=self.net,\n            data=self.ds_tr,\n            metrics=self.metrics_te,\n            states=self.states,\n            normalize=self.normalize,\n            device=self.device,\n        )\n\n        test.test()\n        self.ref_g_training = test.ref_g\n\n        self.results_tr = test.results\n\n        if self.ds_vl is not None:\n\n            test = Test(\n                net=self.net,\n                data=self.ds_vl,\n                metrics=self.metrics_te,\n                states=self.states,\n                normalize=self.normalize,\n                device=self.device,\n            )\n\n            test.test()\n            self.ref_g_validation = test.ref_g\n\n            self.results_vl = test.results\n\n            return {\n                \"test\": self.results_te,
\n                \"train\": self.results_tr,\n                \"validate\": self.results_vl,\n            }\n\n        return {\"test\": self.results_te, \"train\": self.results_tr}\n"
  },
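  {
    "path": "devtools/examples/train_and_test_sketch.py",
    "content": "# Illustrative usage sketch (hypothetical example file, not part of espaloma).\n# It wires esp.app.experiment.TrainAndTest together the same way the fixtures\n# in espaloma/app/tests/test_experiment.py below do; dataset size, layer\n# choice, and epoch count are arbitrary assumptions.\nimport torch\n\nimport espaloma as esp\n\n# label a small dataset with legacy atom types\nds = esp.data.esol(first=20)\ntyping = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\nds.apply(typing, in_place=True)\nds_tr, ds_te = ds.split([4, 1])\n\n# representation -> readout, mirroring espaloma/app/train.py\nlayer = esp.nn.layers.dgl_legacy.gn(\"GraphConv\")\nrepresentation = esp.nn.Sequential(layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"])\nreadout = esp.nn.readout.node_typing.NodeTyping(in_features=32, n_classes=100)\nnet = torch.nn.Sequential(representation, readout)\n\nexp = esp.app.experiment.TrainAndTest(\n    net=net,\n    ds_tr=ds_tr.view(\"graph\", batch_size=10),\n    ds_te=ds_te.view(\"graph\", batch_size=10),\n    n_epochs=1,\n)\nresults = exp.run()  # {\"test\": ..., \"train\": ...}; see report.py below\n"
  },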
  {
    "path": "espaloma/app/report.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport numpy as np\nimport pandas as pd\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef dataframe(results_dict):\n    # get all the results\n    metrics = list(list(results_dict.values())[0].keys())\n    ds_names = list(results_dict.keys())\n    df = pd.DataFrame(\n        [\n            [value[\"final\"].round(4) for metric, value in results.items()]\n            for ds_name, results in results_dict.items()\n        ],\n        columns=metrics,\n        index=ds_names,\n    )\n    return df\n\n\ndef curve(results_dict):\n    curve_dict = {}\n\n    # get all the results\n    metrics = list(list(results_dict.values())[0].keys())\n\n    # loop through metrics\n    for idx_metric, metric in enumerate(metrics):\n\n        # loop through the results\n        for ds_name, results in results_dict.items():\n\n            # get all the recorded indices\n            idxs = list(\n                [\n                    key\n                    for key in results[metric].keys()\n                    if isinstance(key, int)\n                ]\n            )\n\n            curve_dict[(metric, ds_name)] = np.array(\n                [results[metric][idx] for idx in idxs]\n            )\n\n    return curve_dict\n\n\ndef markdown(results_dict):\n    df = dataframe(results_dict)\n    return df.transpose().to_markdown()\n\n\ndef visual(results_dict):\n    # make plots less ugly\n    from matplotlib import pyplot as plt\n\n    plt.rc(\"font\", size=14)\n    plt.rc(\"lines\", linewidth=6)\n\n    # initialize the figure\n    fig = plt.figure(figsize=(8, 3))\n\n    # get all the results\n    metrics = list(list(results_dict.values())[0].keys())\n    n_metrics = len(metrics)\n\n    # loop through metrics\n    for idx_metric, metric in enumerate(metrics):\n        ax = plt.subplot(1, n_metrics, idx_metric + 1)\n\n        # loop through the results\n        for ds_name, results in results_dict.items():\n\n            # get all the recorded indices\n            idxs = list(\n                [\n                    key\n                    for key in results[metric].keys()\n                    if isinstance(key, int)\n                ]\n            )\n\n            # sort it ascending\n            idxs.sort()\n\n            ax.plot(\n                idxs, [results[metric][idx] for idx in idxs], label=ds_name\n            )\n\n        ax.set_xlabel(\"epochs\")\n        ax.set_ylabel(metric)\n\n    plt.tight_layout()\n    plt.legend()\n\n    return fig\n\n\ndef visual_multiple(results_dicts):\n    from matplotlib import cm as cm\n    from matplotlib import pyplot as plt\n\n    plt.rc(\"font\", size=14)\n    plt.rc(\"lines\", linewidth=4)\n\n    # initialize the figure\n    fig = plt.figure()\n\n    # get all the results\n    metrics = list(list(results_dicts[0][1].values())[0].keys())\n    n_metrics = len(metrics)\n\n    # loop through metrics\n    for idx_metric, metric in enumerate(metrics):\n        ax = plt.subplot(n_metrics, 1, idx_metric + 1)\n\n        # loop through results\n        for idx_result, config_and_results_dict in enumerate(results_dicts):\n\n            config, results_dict = config_and_results_dict\n\n            for ds_name, results in results_dict.items():\n\n               
 # get all the recorded indices\n                idxs = list(\n                    [\n                        key\n                        for key in results[metric].keys()\n                        if isinstance(key, int)\n                    ]\n                )\n\n                # sort it ascending\n                idxs.sort()\n\n                label = None\n                linestyle = \"dotted\"\n\n                if ds_name == \"training\":\n                    label = config[\"#\"]\n                    linestyle = \"solid\"\n\n                ax.plot(\n                    idxs,\n                    [results[metric][idx] for idx in idxs],\n                    label=label,\n                    c=cm.gist_rainbow(\n                        (float(idx_result) / len(results_dicts))\n                    ),\n                    linestyle=linestyle,\n                    alpha=0.8,\n                )\n\n        ax.set_xlabel(\"epochs\")\n        ax.set_ylabel(metric)\n\n    plt.legend(bbox_to_anchor=(1.04, 0), loc=\"lower left\")\n    plt.tight_layout()\n\n    return fig\n\n\ndef visual_base64(results_dict):\n    fig = visual(results_dict)\n    import base64\n    import io\n\n    img = io.BytesIO()\n    fig.savefig(img, format=\"png\", dpi=50)\n    img.seek(0)\n    img = base64.b64encode(img.read()).decode(\"utf-8\")\n    # img = \"![img](data:image/png;base64%s)\" % img\n    return img\n\n\ndef html(results_dict):\n    html_string = \"\"\n\n    if isinstance(results_dict, dict):\n        results_dict = [results_dict]\n\n    for _results_dict in results_dict:\n\n        html_string += \"\"\"\n        <p>\n        <div style='height:15%%;width:100%%;'>\n            <div style='float:left'>\n                <img src='data:image/png;base64, %s'/>\n            </div>\n            <div style='float:left'>\n                %s\n            </div>\n        </div>\n        <br><br><br>\n        </p>\n        \"\"\" % (\n            visual_base64(_results_dict)[:-1],\n            dataframe(_results_dict).to_html(),\n        )\n\n    return html_string\n\n\ndef html_multiple_train_and_test(results):\n    html_string = \"\"\n    for param, result in results:\n        html_string += \"<p><br><br><br>\" + str(param) + \"</p>\"\n        html_string += html(result)\n        html_string += \"<br><br><br>\"\n\n    return html_string\n\n\ndef html_multiple_train_and_test_2d_grid(results):\n    # make sure there are only two parameter types\n    import copy\n\n    results = copy.deepcopy(results)\n\n    for result in results:\n        result[0].pop(\"#\")\n\n    param_names = list(results[0][0].keys())\n    assert len(param_names) == 2\n    param_col_name, param_row_name = param_names\n\n    param_col_values = list(\n        set([result[0][param_col_name] for result in results])\n    )\n    param_row_values = list(\n        set([result[0][param_row_name] for result in results])\n    )\n\n    param_col_values.sort()\n    param_row_values.sort()\n\n    # initialize giant table in nested lists\n    table = [[\"NA\" for _ in param_col_values] for _ in param_row_values]\n\n    # populate this table\n    for idx_col, param_col in enumerate(param_col_values):\n        for idx_row, param_row in enumerate(param_row_values):\n            param_dict = {\n                param_col_name: param_col,\n                param_row_name: param_row,\n            }\n\n            # TODO:\n            # make this less ugly\n\n            for result in results:\n\n                if result[0] == param_dict:
\n                    table[idx_row][idx_col] = html(result[1])\n\n    html_string = \"\"\n    html_string += \"<table style='border: 1px solid black'>\"\n\n    # first row\n    html_string += \"<thead><tr style='border: 1px solid black'>\"\n    html_string += (\n        \"<th style='border: 1px solid black'>\"\n        + param_row_name\n        + \"/\"\n        + param_col_name\n        + \"</th>\"\n    )\n\n    for param_col in param_col_values:\n        html_string += (\n            \"<th style='border: 1px solid black'>\" + str(param_col) + \"</th>\"\n        )\n\n    html_string += \"</tr></thead>\"\n\n    # the rest of the rows\n    for idx_row, param_row in enumerate(param_row_values):\n        html_string += \"<tr style='border: 1px solid black'>\"\n\n        # html_string += \"<td></td>\"\n\n        html_string += (\n            \"<th style='border: 1px solid black'>\" + str(param_row) + \"</th>\"\n        )\n\n        for idx_col, param_col in enumerate(param_col_values):\n            html_string += (\n                \"<td style='border: 1px solid black'>\"\n                + table[idx_row][idx_col]\n                + \"</td>\"\n            )\n\n        html_string += \"</tr>\"\n\n    html_string += \"</table>\"\n    return html_string\n"
  },
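  {
    "path": "devtools/examples/report_sketch.py",
    "content": "# Illustrative usage sketch (hypothetical example file, not part of espaloma).\n# It spells out the nested layout that the report helpers in\n# espaloma/app/report.py above expect: results_dict[dataset][metric] maps\n# integer epoch indices (plus the key \"final\") to recorded values. The\n# numbers below are made up.\nimport numpy as np\n\nimport espaloma as esp\n\nresults_dict = {\n    \"train\": {\n        \"TypingCrossEntropy\": {0: np.float64(2.31), 1: np.float64(1.87), \"final\": np.float64(1.87)}\n    },\n    \"test\": {\n        \"TypingCrossEntropy\": {0: np.float64(2.40), 1: np.float64(2.01), \"final\": np.float64(2.01)}\n    },\n}\n\n# one row per metric, one column per dataset\nprint(esp.app.report.markdown(results_dict))\n\n# {(metric, dataset): np.ndarray of the per-epoch values}\ncurves = esp.app.report.curve(results_dict)\n"
  },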
  {
    "path": "espaloma/app/tests/test_experiment.py",
    "content": "import pytest\nimport torch\n\n\ndef test_import():\n    import espaloma as esp\n\n    esp.app.experiment\n\n\n@pytest.fixture\ndef data():\n    import espaloma as esp\n\n    esol = esp.data.esol(first=20)\n\n    # do some typing\n    typing = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\n    esol.apply(typing, in_place=True)  # this modify the original data\n\n    return esol.view(\"graph\", batch_size=10)\n\n\n@pytest.fixture\ndef net():\n    import espaloma as esp\n\n    # define a layer\n    layer = esp.nn.layers.dgl_legacy.gn(\"GraphConv\")\n\n    # define a representation\n    representation = esp.nn.Sequential(\n        layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    # define a readout\n    readout = esp.nn.readout.node_typing.NodeTyping(\n        in_features=32, n_classes=100\n    )  # not too many elements here I think?\n\n    net = torch.nn.Sequential(representation, readout)\n\n    return net\n\n\ndef test_data_and_net(data, net):\n    data\n    net\n\n\n@pytest.fixture\ndef train(data, net):\n    import espaloma as esp\n\n    train = esp.app.experiment.Train(\n        net=net,\n        data=data,\n        n_epochs=1,\n        metrics=[\n            esp.metrics.GraphMetric(\n                base_metric=torch.nn.CrossEntropyLoss(),\n                between=[\"nn_typing\", \"legacy_typing\"],\n            )\n        ],\n    )\n\n    return train\n\n\ndef test_train(train):\n    train.train()\n\n\ndef test_test(train, net, data):\n    import espaloma as esp\n\n    train.train()\n    test = esp.app.experiment.Test(net=net, data=data, states=train.states)\n\n\ndef test_train_and_test(net, data):\n    import espaloma as esp\n\n    train_and_test = esp.app.experiment.TrainAndTest(\n        net=net, n_epochs=1, ds_tr=data, ds_te=data\n    )\n"
  },
  {
    "path": "espaloma/app/train.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport argparse\nimport os\n\nimport numpy as np\nimport torch\n\nimport espaloma as esp\n\n\ndef run(args):\n    # define data\n    data = getattr(esp.data, args.data)(first=args.first)\n\n    # get force field\n    forcefield = esp.graphs.legacy_force_field.LegacyForceField(\n        args.forcefield\n    )\n\n    # param / typing\n    operation = getattr(forcefield, args.operation)\n\n    # apply to dataset\n    data = data.apply(operation, in_place=True)\n\n    # split\n    partition = [int(x) for x in args.partition.split(\":\")]\n    ds_tr, ds_te = data.split(partition)\n\n    # batch\n    ds_tr = ds_tr.view(\"graph\", batch_size=args.batch_size)\n    ds_te = ds_te.view(\"graph\", batch_size=args.batch_size)\n\n    # layer\n    layer = esp.nn.layers.dgl_legacy.gn(args.layer)\n\n    # representation\n    representation = esp.nn.Sequential(layer, config=args.config)\n\n    # get the last bit of units\n    units = [x for x in args.config if isinstance(x, int)][-1]\n\n    # readout\n    if args.readout == \"node_typing\":\n        readout = esp.nn.readout.node_typing.NodeTyping(\n            in_features=units, n_classes=args.n_classes\n        )\n\n    if args.readout == \"janossy\":\n        readout = esp.nn.readout.janossy.JanossyPooling(\n            in_features=units, config=args.janossy_config\n        )\n\n    net = torch.nn.Sequential(representation, readout)\n\n    training_metrics = [\n        getattr(esp.metrics, metric)() for metric in args.training_metrics\n    ]\n\n    test_metrics = [\n        getattr(esp.metrics, metric)() for metric in args.test_metrics\n    ]\n\n    exp = esp.TrainAndTest(\n        ds_tr=ds_tr,\n        ds_te=ds_te,\n        net=net,\n        metrics_tr=[\n            getattr(esp.metrics, metric)() for metric in args.training_metrics\n        ],\n        metrics_te=[\n            getattr(esp.metrics, metric)() for metric in args.test_metrics\n        ],\n        n_epochs=args.n_epochs,\n    )\n\n    results = exp.run()\n\n    print(esp.app.report.markdown(results))\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data\", default=\"esol\", type=str)\n    parser.add_argument(\"--first\", default=-1, type=int)\n    parser.add_argument(\"--readout\", default=\"node_typing\", type=str)\n    parser.add_argument(\"--partition\", default=\"4:1\", type=str)\n    parser.add_argument(\"--batch_size\", default=8, type=int)\n    parser.add_argument(\"--forcefield\", default=\"gaff-1.81\", type=str)\n    parser.add_argument(\"--operation\", default=\"typing\", type=str)\n    parser.add_argument(\"--layer\", default=\"GraphConv\", type=str)\n    parser.add_argument(\"--n_classes\", default=100, type=int)\n    parser.add_argument(\n        \"--config\", nargs=\"*\", default=[32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    parser.add_argument(\n        \"--training_metrics\", nargs=\"*\", default=[\"TypingCrossEntropy\"]\n    )\n    parser.add_argument(\n        \"--test_metrics\", nargs=\"*\", default=[\"TypingAccuracy\"]\n    )\n\n    parser.add_argument(\"--janossy_config\", nargs=\"*\", default=[32, \"tanh\"])\n\n    parser.add_argument(\"--n_epochs\", default=10, type=int)\n\n    args = parser.parse_args()\n\n    run(args)\n"
  },
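  {
    "path": "devtools/examples/run_train_sketch.py",
    "content": "# Illustrative usage sketch (hypothetical example file, not part of espaloma).\n# espaloma/app/train.py above is a CLI, but run() only reads attributes off\n# the parsed args, so it can also be driven with an argparse.Namespace that\n# mirrors the parser defaults; the values below are assumptions chosen to\n# keep the run small.\nimport argparse\n\nfrom espaloma.app.train import run\n\nargs = argparse.Namespace(\n    data=\"esol\",\n    first=16,  # only read the first 16 molecules\n    readout=\"node_typing\",\n    partition=\"4:1\",\n    batch_size=8,\n    forcefield=\"gaff-1.81\",\n    operation=\"typing\",\n    layer=\"GraphConv\",\n    n_classes=100,\n    config=[32, \"tanh\", 32, \"tanh\", 32, \"tanh\"],\n    training_metrics=[\"TypingCrossEntropy\"],\n    test_metrics=[\"TypingAccuracy\"],\n    janossy_config=[32, \"tanh\"],\n    n_epochs=1,\n)\n\nrun(args)  # prints a markdown table of train/test metrics\n"
  },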
  {
    "path": "espaloma/app/train_all_params.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport argparse\nimport numpy as np\nimport torch\n\nimport espaloma as esp\n\n\ndef run(args):\n    # define data\n    data = getattr(esp.data, args.data)(first=args.first)\n\n    # get force field\n    forcefield = esp.graphs.legacy_force_field.LegacyForceField(\n        args.forcefield\n    )\n\n    # param / typing\n    operation = forcefield.parametrize\n\n    # apply to dataset\n    data = data.apply(operation, in_place=True)\n\n    # split\n    partition = [int(x) for x in args.partition.split(\":\")]\n    ds_tr, ds_te = data.split(partition)\n\n    # batch\n    ds_tr = ds_tr.view(\"graph\", batch_size=args.batch_size)\n    ds_te = ds_te.view(\"graph\", batch_size=args.batch_size)\n\n    # layer\n    layer = esp.nn.layers.dgl_legacy.gn(args.layer)\n\n    # representation\n    representation = esp.nn.Sequential(layer, config=args.config)\n\n    # get the last bit of units\n    units = [x for x in args.config if isinstance(x, int)][-1]\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=units,\n        config=args.janossy_config,\n        out_features={\n            2: [\"k\", \"eq\"],\n            3: [\"k\", \"eq\"],\n        },\n    )\n\n    net = torch.nn.Sequential(representation, readout)\n\n    metrics_tr = [\n        esp.metrics.GraphMetric(\n            base_metric=torch.nn.L1Loss(),\n            between=[param, param + \"_ref\"],\n            level=term,\n        )\n        for param in [\"k\", \"eq\"]\n        for term in [\"n2\", \"n3\"]\n    ]\n\n    metrics_te = [\n        esp.metrics.GraphMetric(\n            base_metric=base_metric,\n            between=[param, param + \"_ref\"],\n            level=term,\n        )\n        for param in [\"k\", \"eq\"]\n        for term in [\"n2\", \"n3\"]\n        for base_metric in [esp.metrics.rmse, esp.metrics.r2]\n    ]\n\n    exp = esp.TrainAndTest(\n        ds_tr=ds_tr,\n        ds_te=ds_te,\n        net=net,\n        metrics_tr=metrics_tr,\n        metrics_te=metrics_te,\n        n_epochs=args.n_epochs,\n    )\n\n    results = exp.run()\n\n    print(esp.app.report.markdown(results))\n\n    import os\n\n    os.mkdir(args.out)\n\n    with open(args.out + \"/architecture.txt\", \"w\") as f_handle:\n        f_handle.write(str(exp))\n\n    with open(args.out + \"/result_table.md\", \"w\") as f_handle:\n        f_handle.write(esp.app.report.markdown(results))\n\n    curves = esp.app.report.curve(results)\n\n    for spec, curve in curves.items():\n        np.save(args.out + \"/\" + \"_\".join(spec) + \".npy\", curve)\n\n    import pickle\n\n    with open(args.out + \"/ref_g_test.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_test, f_handle)\n\n    with open(args.out + \"/ref_g_training.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_training, f_handle)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data\", default=\"alkethoh\", type=str)\n    parser.add_argument(\"--out\", default=\"results\", type=str)\n    parser.add_argument(\"--first\", default=-1, type=int)\n    parser.add_argument(\"--partition\", default=\"4:1\", type=str)\n    parser.add_argument(\"--batch_size\", default=8, type=int)\n    parser.add_argument(\n        \"--forcefield\", default=\"smirnoff99Frosst-1.1.0\", type=str\n    )\n    
parser.add_argument(\"--layer\", default=\"GraphConv\", type=str)\n    parser.add_argument(\"--n_classes\", default=100, type=int)\n    parser.add_argument(\n        \"--config\", nargs=\"*\", default=[32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    parser.add_argument(\n        \"--training_metrics\", nargs=\"*\", default=[\"TypingCrossEntropy\"]\n    )\n    parser.add_argument(\n        \"--test_metrics\", nargs=\"*\", default=[\"TypingAccuracy\"]\n    )\n\n    parser.add_argument(\"--janossy_config\", nargs=\"*\", default=[32, \"tanh\"])\n\n    parser.add_argument(\"--n_epochs\", default=10, type=int)\n\n    args = parser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "espaloma/app/train_bonded_energy.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport argparse\nimport os\n\nimport numpy as np\nimport torch\n\nimport espaloma as esp\n\n\ndef run(args):\n    # define data\n    data = getattr(esp.data, args.data)(first=args.first)\n\n    # get force field\n    forcefield = esp.graphs.legacy_force_field.LegacyForceField(\n        args.forcefield\n    )\n\n    # param / typing\n    operation = forcefield.parametrize\n\n    # apply to dataset\n    data = data.apply(operation, in_place=True)\n\n    # apply simulation\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(\n        n_samples=1000, n_steps_per_sample=10\n    )\n\n    data = data.apply(simulation.run, in_place=True)\n\n    # split\n    partition = [int(x) for x in args.partition.split(\":\")]\n    ds_tr, ds_te = data.split(partition)\n\n    # batch\n    ds_tr = ds_tr.view(\"graph\", batch_size=args.batch_size)\n    ds_te = ds_te.view(\"graph\", batch_size=args.batch_size)\n\n    # layer\n    layer = esp.nn.layers.dgl_legacy.gn(args.layer)\n\n    # representation\n    representation = esp.nn.Sequential(layer, config=args.config)\n\n    # get the last bit of units\n    units = [x for x in args.config if isinstance(x, int)][-1]\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=units,\n        config=args.janossy_config,\n    )\n\n    net = torch.nn.Sequential(\n        representation,\n        readout,\n        esp.mm.geometry.GeometryInGraph(),\n        esp.mm.energy.EnergyInGraph(),\n        esp.mm.energy.EnergyInGraph(suffix=\"_ref\"),\n    )\n\n    metrics_tr = [\n        esp.metrics.GraphMetric(\n            base_metric=torch.nn.L1Loss(), between=[\"u\", \"u_ref\"], level=\"g\"\n        )\n    ]\n\n    metrics_te = [\n        esp.metrics.GraphMetric(\n            base_metric=base_metric,\n            between=[param, param + \"_ref\"],\n            level=term,\n        )\n        for param in [\"u\"]\n        for term in [\"g\"]\n        for base_metric in [esp.metrics.rmse, esp.metrics.r2]\n    ]\n\n    optimizer = getattr(torch.optim, args.optimizer)(\n        net.parameters(), lr=args.lr\n    )\n\n    exp = esp.TrainAndTest(\n        ds_tr=ds_tr,\n        ds_te=ds_te,\n        net=net,\n        metrics_tr=metrics_tr,\n        metrics_te=metrics_te,\n        n_epochs=args.n_epochs,\n        normalize=esp.data.normalize.PositiveNotNormalize,\n    )\n\n    results = exp.run()\n\n    print(esp.app.report.markdown(results))\n\n    import os\n\n    os.mkdir(args.out)\n\n    with open(args.out + \"/architecture.txt\", \"w\") as f_handle:\n        f_handle.write(str(exp))\n\n    with open(args.out + \"/result_table.md\", \"w\") as f_handle:\n        f_handle.write(esp.app.report.markdown(results))\n\n    curves = esp.app.report.curve(results)\n\n    for spec, curve in curves.items():\n        np.save(args.out + \"/\" + \"_\".join(spec) + \".npy\", curve)\n\n    import pickle\n\n    with open(args.out + \"/ref_g_test.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_test, f_handle)\n\n    with open(args.out + \"/ref_g_training.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_training, f_handle)\n\n    print(esp.app.report.markdown(results))\n\n    import pickle\n\n    with open(args.out + \"/ref_g_test.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_test, f_handle)\n\n 
   with open(args.out + \"/ref_g_training.th\", \"wb\") as f_handle:\n        pickle.dump(exp.ref_g_training, f_handle)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data\", default=\"alkethoh\", type=str)\n    parser.add_argument(\"--first\", default=-1, type=int)\n    parser.add_argument(\"--partition\", default=\"4:1\", type=str)\n    parser.add_argument(\"--batch_size\", default=8, type=int)\n    parser.add_argument(\n        \"--forcefield\", default=\"smirnoff99Frosst-1.1.0\", type=str\n    )\n    parser.add_argument(\"--layer\", default=\"GraphConv\", type=str)\n    parser.add_argument(\"--n_classes\", default=100, type=int)\n    parser.add_argument(\n        \"--config\", nargs=\"*\", default=[32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    parser.add_argument(\n        \"--training_metrics\", nargs=\"*\", default=[\"TypingCrossEntropy\"]\n    )\n    parser.add_argument(\n        \"--test_metrics\", nargs=\"*\", default=[\"TypingAccuracy\"]\n    )\n    parser.add_argument(\"--out\", default=\"results\", type=str)\n    parser.add_argument(\"--janossy_config\", nargs=\"*\", default=[32, \"tanh\"])\n\n    parser.add_argument(\"--n_epochs\", default=10, type=int)\n\n    parser.add_argument(\"--optimizer\", default=\"Adam\", type=str)\n    parser.add_argument(\"--lr\", default=1e-3, type=float)\n\n    args = parser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "espaloma/app/train_multi_typing.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport argparse\nimport torch\n\nimport espaloma as esp\n\n\ndef run(args):\n    # define data\n    data = getattr(esp.data, args.data)(first=args.first)\n\n    # get force field\n    forcefield = esp.graphs.legacy_force_field.LegacyForceField(\n        args.forcefield\n    )\n\n    # param / typing\n    operation = forcefield.multi_typing\n\n    # apply to dataset\n    data = data.apply(operation, in_place=True)\n\n    # split\n    partition = [int(x) for x in args.partition.split(\":\")]\n    ds_tr, ds_te = data.split(partition)\n\n    # batch\n    ds_tr = ds_tr.view(\"graph\", batch_size=args.batch_size)\n    ds_te = ds_te.view(\"graph\", batch_size=args.batch_size)\n\n    # layer\n    layer = esp.nn.layers.dgl_legacy.gn(args.layer)\n\n    # representation\n    representation = esp.nn.Sequential(layer, config=args.config)\n\n    # get the last bit of units\n    units = [x for x in args.config if isinstance(x, int)][-1]\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=units,\n        config=args.janossy_config,\n        out_features={\n            1: {\"nn_typing\": 100},\n            2: {\"nn_typing\": 100},\n            3: {\"nn_typing\": 100},\n        },\n    )\n\n    net = torch.nn.Sequential(representation, readout)\n\n    metrics_tr = [\n        esp.metrics.GraphMetric(\n            base_metric=torch.nn.CrossEntropyLoss(),\n            between=[\"nn_typing\", \"legacy_typing\"],\n            level=term,\n        )\n        for term in [\"n1\", \"n2\", \"n3\"]\n    ]\n\n    metrics_te = [\n        esp.metrics.GraphMetric(\n            base_metric=esp.metrics.accuracy,\n            between=[\"nn_typing\", \"legacy_typing\"],\n            level=term,\n        )\n        for term in [\"n1\", \"n2\", \"n3\"]\n    ]\n\n    exp = esp.TrainAndTest(\n        ds_tr=ds_tr,\n        ds_te=ds_te,\n        net=net,\n        metrics_tr=metrics_tr,\n        metrics_te=metrics_te,\n        n_epochs=args.n_epochs,\n    )\n\n    results = exp.run()\n\n    print(esp.app.report.markdown(results))\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data\", default=\"alkethoh\", type=str)\n    parser.add_argument(\"--first\", default=-1, type=int)\n    parser.add_argument(\"--partition\", default=\"4:1\", type=str)\n    parser.add_argument(\"--batch_size\", default=8, type=int)\n    parser.add_argument(\n        \"--forcefield\", default=\"smirnoff99Frosst-1.1.0\", type=str\n    )\n    parser.add_argument(\"--layer\", default=\"GraphConv\", type=str)\n    parser.add_argument(\"--n_classes\", default=100, type=int)\n    parser.add_argument(\n        \"--config\", nargs=\"*\", default=[32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    parser.add_argument(\n        \"--training_metrics\", nargs=\"*\", default=[\"TypingCrossEntropy\"]\n    )\n    parser.add_argument(\n        \"--test_metrics\", nargs=\"*\", default=[\"TypingAccuracy\"]\n    )\n\n    parser.add_argument(\"--janossy_config\", nargs=\"*\", default=[32, \"tanh\"])\n\n    parser.add_argument(\"--n_epochs\", default=10, type=int)\n\n    args = parser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "espaloma/data/__init__.py",
    "content": "\"\"\" Handles the dataset and collections of espaloma. \"\"\"\nfrom . import dataset, md, normalize, utils, qcarchive_utils, md17_utils\nfrom .collection import *\n"
  },
  {
    "path": "espaloma/data/collection.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\ndef esol(*args, **kwargs):\n    \"\"\"ESOL collection.\n\n    ..[1] ESOL:  Estimating Aqueous Solubility Directly from Molecular Structure\n        John S. Delaney\n        Journal of Chemical Information and Computer Sciences\n        2004 44 (3), 1000-1005\n        DOI: 10.1021/ci034243x\n    \"\"\"\n    import os\n\n    import pandas as pd\n\n    path = os.path.dirname(esp.__file__) + \"/data/esol.csv\"\n    df = pd.read_csv(path)\n    smiles = df.iloc[:, -1]\n    return esp.data.dataset.GraphDataset(smiles, *args, **kwargs)\n\n\ndef alkethoh(*args, **kwargs):\n    \"\"\"AlkEthOH collection.\n\n    ..[1] Open Force Field Consortium: Escaping atom types using direct chemical\n    perception with SMIRNOFF v0.1\n    David L. Mobley, Caitlin C. Bannan, Andrea Rizzi, Christopher I. Bayly,\n    John D. Chodera, Victoria T. Lim, Nathan M. Lim, Kyle A. Beauchamp,\n    Michael R. Shirts, Michael K. Gilson, Peter K. Eastman\n    bioRxiv 286542; doi: https://doi.org/10.1101/286542\n\n    \"\"\"\n    import os\n\n    import pandas as pd\n\n    df = pd.concat(\n        [\n            pd.read_csv(\n                \"https://raw.githubusercontent.com/openff.toolkit/\"\n                \"open-forcefield-data/master/Model-Systems/AlkEthOH_distrib/\"\n                \"AlkEthOH_rings.smi\",\n                header=None,\n            ),\n            pd.read_csv(\n                \"https://raw.githubusercontent.com/openff.toolkit/\"\n                \"open-forcefield-data/master/Model-Systems/AlkEthOH_distrib/\"\n                \"AlkEthOH_chain.smi\",\n                header=None,\n            ),\n        ],\n        axis=0,\n    )\n\n    smiles = df.iloc[:, 0].values\n    return esp.data.dataset.GraphDataset(smiles, *args, **kwargs)\n\n\ndef zinc(first=-1, *args, **kwargs):\n    \"\"\"ZINC collection.\n\n    ..[1] Irwin, John J, and Brian K Shoichet.\n    “ZINC\n    --a free database of commercially available compounds for virtual screening.”\n    Journal of chemical information and modeling\n    vol. 45,1 (2005): 177-82. 
doi:10.1021/ci049714+\n    \"\"\"\n    import tarfile\n    from os.path import exists\n    from openff.toolkit.topology import Molecule\n    from rdkit import Chem\n\n    fname = \"parm_at_Frosst.tgz\"\n    url = \"http://www.ccl.net/cca/data/parm_at_Frosst/parm_at_Frosst.tgz\"\n\n    if not exists(fname):\n        import urllib.request\n\n        urllib.request.urlretrieve(url, fname)\n\n    archive = tarfile.open(fname)\n    zinc_file = archive.extractfile(\"parm_at_Frosst/zinc.sdf\")\n    _mols = Chem.ForwardSDMolSupplier(zinc_file, removeHs=False)\n\n    count = 0\n    gs = []\n\n    for mol in _mols:\n        try:\n            gs.append(\n                esp.Graph(\n                    Molecule.from_rdkit(mol, allow_undefined_stereo=True)\n                )\n            )\n\n            count += 1\n\n        except Exception:\n            # skip molecules that RDKit or the OpenFF toolkit cannot convert\n            pass\n\n        if first != -1 and count >= first:\n            break\n\n    return esp.data.dataset.GraphDataset(gs, *args, **kwargs)\n\n\ndef md17_old(*args, **kwargs):\n    return [\n        esp.data.md17_utils.get_molecule(name, *args, **kwargs)\n        for name in [\n            \"benzene\",\n            \"uracil\",\n            \"naphthalene\",\n            \"aspirin\",\n            \"salicylic\",\n            \"malonaldehyde\",\n            \"ethanol\",\n            \"toluene\",\n            \"paracetamol\",\n            \"azobenzene\",\n        ]\n    ]\n\n\ndef md17_new(*args, **kwargs):\n    return [\n        esp.data.md17_utils.get_molecule(name, *args, **kwargs).heterograph\n        for name in [\n            \"paracetamol\",\n            \"azobenzene\",\n            \"benzene\",\n            \"ethanol\",\n        ]\n    ]\n\n\nclass qca(object):\n    pass\n\n\ndf_names = [\n    \"Bayer\",\n    \"Coverage\",\n    \"eMolecules\",\n    \"Pfizer\",\n    \"Roche\",\n    \"Benchmark\",\n    \"fda\",\n]\n\n\ndef _get_ds(cls, df_name):\n    import os\n    import pandas as pd\n\n    path = os.path.dirname(esp.__file__) + \"/../data/qca/%s.h5\" % df_name\n    df = pd.read_hdf(path)\n    ds = esp.data.qcarchive_utils.h5_to_dataset(df)\n    return ds\n\n\nfrom functools import partial\n\nfor df_name in df_names:\n    setattr(\n        qca,\n        df_name.lower(),\n        classmethod(partial(_get_ds, df_name=df_name)),\n    )\n"
  },
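As a usage sketch (the loader names and the bundled `esol.csv` path come from this file; sizes are arbitrary):

```python
import espaloma as esp

# ESOL ships with the package, so no download is needed; `first` is
# forwarded to GraphDataset and caps how many SMILES become graphs.
ds = esp.data.esol(first=8)
print(len(ds))      # 8
print(type(ds[0]))  # espaloma.graphs.graph.Graph
```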
  {
    "path": "espaloma/data/dataset.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport abc\n\nimport torch\n\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass Dataset(abc.ABC, torch.utils.data.Dataset):\n    \"\"\"The base class of map-style dataset.\n\n    Parameters\n    ----------\n    graphs : List\n        objects in the dataset\n\n    Methods\n    -------\n    shuffle\n        Randomly shuffle the graphs in the dataset.\n\n    apply(fn, in_place=True)\n        Apply a function to every graph in the dataset.\n        If `in_place=True`, modify the graph in-place.\n\n    split(partitions)\n        Split the dataset into partitions\n\n    subsample(ratio, seed=None)\n        Subsample the dataset.\n\n    save(path)\n        Save the dataset to a local path.\n\n    load(path)\n        Load a dataset from local path.\n\n    Note\n    ----\n    This also supports iterative-style dataset by deleting `__getitem__`\n    and `__len__` function.\n\n    Attributes\n    ----------\n    transforms : an iterable of callables that transforms the input.\n        the `__getiem__` method applies these transforms later.\n\n    Examples\n    --------\n    >>> data = Dataset([esp.Graph(\"C\")])\n\n    \"\"\"\n\n    def __init__(self, graphs=None):\n        super(Dataset, self).__init__()\n        self.graphs = graphs\n        self.transforms = None\n\n    def __len__(self):\n        # 0 len if no graphs\n        if self.graphs is None:\n            return 0\n\n        else:\n            return len(self.graphs)\n\n    def __getitem__(self, idx):\n        if self.graphs is None:\n            raise RuntimeError(\"Empty molecule dataset.\")\n\n        if isinstance(idx, int):  # sinlge element\n            if self.transforms is None:  # when no transform act like list\n                return self.graphs[idx]\n\n            else:\n                graph = self.graphs[idx]\n\n                # nested transforms\n                for transform in self.transforms:\n                    graph = transform(graph)\n\n                return graph\n\n        elif isinstance(idx, slice):\n            # implement slicing\n            if self.transforms is None:\n                # return a Dataset object rather than list\n                return self.__class__(graphs=self.graphs[idx])\n            else:\n                graphs = []\n                for graph in self.graphs[idx]:\n\n                    # nested transforms\n                    for transform in self.transforms:\n                        graph = transform(graph)\n                    graphs.append(graph)\n\n                return self.__class__(graphs=graphs)\n\n        elif isinstance(idx, list):\n            # implement slicing\n            if self.transforms is None:\n                # return a Dataset object rather than list\n                return self.__class__(\n                    graphs=[self.graphs[_idx] for _idx in idx]\n                )\n            else:\n                graphs = []\n                for _idx in idx:\n                    graph = self[_idx]\n                    # nested transforms\n                    for transform in self.transforms:\n                        graph = transform(graph)\n                    graphs.append(graph)\n\n                return 
self.__class__(graphs=graphs)\n\n    def __iter__(self):\n        if self.transforms is None:\n            return iter(self.graphs)\n\n        else:\n            # TODO:\n            # is this efficient?\n            graphs = iter(self.graphs)\n            for transform in self.transforms:\n                graphs = map(transform, graphs)\n\n            return graphs\n\n    def shuffle(self, seed=None):\n        import random\n        from random import shuffle\n\n        if seed is not None:\n            random.seed(seed)\n\n        shuffle(self.graphs)\n        return self\n\n    def apply(self, fn, in_place=False):\n        r\"\"\"Apply functions to the elements of the dataset.\n\n        Parameters\n        ----------\n        fn : callable\n\n        Note\n        ----\n        If in_place is False, `fn` is added to the `transforms` else it is applied\n        to elements and modifies them.\n\n        \"\"\"\n        assert callable(fn)\n        assert isinstance(in_place, bool)\n\n        if in_place is False:  # add to list of transforms\n            if self.transforms is None:\n                self.transforms = []\n\n            self.transforms.append(fn)\n\n        else:  # modify in-place\n            # self.graphs = list(map(fn, self.graphs))\n            _graphs = []\n            for graph in self.graphs:\n                try:\n                    _graphs.append(fn(graph))\n                except:\n                    pass\n            self.graphs = _graphs\n\n        return self  # to allow grammar: ds = ds.apply(...)\n\n    def split(self, partition):\n        \"\"\"Split the dataset according to some partition.\n\n        Parameters\n        ----------\n        partition : sequence of integers or floats\n\n        \"\"\"\n        n_data = len(self)\n        p_sizes = []\n        for i, _partition in enumerate(partition):\n            p_size = int((n_data - sum(p_sizes)) * _partition / sum(partition[i:]))\n            p_sizes.append(p_size)\n        assert sum(p_sizes) == n_data, f\"{p_sizes}, {sum(p_sizes)}\"\n        ds = []\n        idx = 0\n        for p_size in p_sizes:\n            ds.append(self[idx : idx + p_size])\n            idx += p_size\n\n        return ds\n\n    def subsample(self, ratio, seed=None):\n        \"\"\"Subsample the dataset according to some ratio.\n\n        Parameters\n        ----------\n        ratio : float\n            Ratio between the size of the subsampled dataset and the\n            original dataset.\n\n        \"\"\"\n        n_data = len(self)\n        idxs = list(range(n_data))\n        import random\n\n        random.seed(seed)\n        _idxs = random.choices(idxs, k=int(n_data * ratio))\n        return self[_idxs]\n\n    def save(self, path):\n        \"\"\"Save dataset to path.\n\n        Parameters\n        ----------\n        path : path-like object\n        \"\"\"\n        import pickle\n\n        with open(path, \"wb\") as f_handle:\n            pickle.dump(self.graphs, f_handle)\n\n    def regenerate_impropers(self, improper_def=\"smirnoff\"):\n        \"\"\"\n        Regenerate the improper nodes for all graphs.\n\n        Parameters\n        ----------\n        improper_def : str\n            Which convention to use for permuting impropers.\n        \"\"\"\n        from espaloma.graphs.utils.regenerate_impropers import (\n            regenerate_impropers,\n        )\n\n        for g in self.graphs:\n            regenerate_impropers(g, improper_def)\n\n    @classmethod\n    def load(cls, path):\n        \"\"\"Load path to 
dataset.\n\n        Parameters\n        ----------\n        path : path-like object\n        \"\"\"\n        import pickle\n\n        with open(path, \"rb\") as f_handle:\n            graphs = pickle.load(f_handle)\n\n        return cls(graphs)\n\n    def __add__(self, x):\n        return self.__class__(self.graphs + x.graphs)\n\n\nclass GraphDataset(Dataset):\n    \"\"\"Dataset with additional support for only viewing\n    certain attributes as `torch.utils.data.DataLoader`\n\n    Methods\n    -------\n    view(collate_fn, *args, **kwargs)\n        Provide a `torch.utils.data.DataLoader` view of the dataset.\n\n    \"\"\"\n\n    def __init__(self, graphs=[], first=None):\n        super(GraphDataset, self).__init__()\n        from openff.toolkit.topology import Molecule\n\n        if all(\n            isinstance(graph, Molecule) or isinstance(graph, str)\n            for graph in graphs\n        ):\n\n            if first is None or first == -1:\n                graphs = [esp.Graph(graph) for graph in graphs]\n\n            else:\n                graphs = [esp.Graph(graph) for graph in graphs[:first]]\n\n        self.graphs = graphs\n\n    @staticmethod\n    def batch(graphs):\n        import dgl\n\n        if all(isinstance(graph, esp.graphs.graph.Graph) for graph in graphs):\n            return dgl.batch([graph.heterograph for graph in graphs])\n\n        elif all(isinstance(graph, dgl.DGLGraph) for graph in graphs):\n            return dgl.batch(graphs)\n\n        elif all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs):\n            return dgl.batch(graphs)\n\n        else:\n            raise RuntimeError(\n                \"Can only batch DGLGraph or DGLHeteroGraph, \"\n                \"now have %s\" % type(graphs[0])\n            )\n\n    def view(self, collate_fn=\"graph\", *args, **kwargs):\n        \"\"\"Provide a data loader.\n\n        Parameters\n        ----------\n        collate_fn : callable or string\n            see `collate_fn` argument for `torch.utils.data.DataLoader`\n\n\n        \"\"\"\n        if collate_fn == \"graph\":\n            collate_fn = self.batch\n\n        elif collate_fn == \"homograph\":\n\n            def collate_fn(graphs):\n                graph = self.batch([g.homograph for g in graphs])\n\n                return graph\n\n        elif collate_fn == \"graph-typing\":\n\n            def collate_fn(graphs):\n                graph = self.batch(graphs)\n                y = graph.ndata[\"legacy_typing\"]\n                return graph, y\n\n        elif collate_fn == \"graph-typing-loss\":\n            loss_fn = torch.nn.CrossEntropyLoss()\n\n            def collate_fn(graphs):\n                graph = self.batch(graphs)\n                loss = lambda _graph: loss_fn(\n                    _graph.ndata[\"nn_typing\"], graph.ndata[\"legacy_typing\"]\n                )\n                return graph, loss\n\n        return torch.utils.data.DataLoader(\n            dataset=self, collate_fn=collate_fn, *args, **kwargs\n        )\n\n    def save(self, path):\n        import os\n\n        os.mkdir(path)\n        for idx, graph in enumerate(self.graphs):\n            graph.save(path + \"/\" + str(idx))\n\n    @classmethod\n    def load(cls, path):\n        import os\n\n        paths = os.listdir(path)\n\n        graphs = []\n        for _path in paths:\n            graphs.append(esp.Graph.load(path + \"/\" + _path))\n\n        return cls(graphs)\n"
  },
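A sketch of the slicing / split / view workflow defined above (the 4:1 split and the "graph" collate mirror the training script's defaults; the dataset choice is arbitrary):

```python
import espaloma as esp

ds = esp.data.esol(first=10)     # GraphDataset of 10 molecules
ds_tr, ds_te = ds.split([4, 1])  # proportional split: 8 and 2 graphs

# `view` wraps the dataset in a DataLoader whose collate_fn batches
# the underlying DGL heterographs.
loader = ds_tr.view("graph", batch_size=4)
for batch in loader:
    print(batch.batch_size)      # number of molecules in this DGL batch
    break
```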
  {
    "path": "espaloma/data/md.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport numpy as np\nimport torch\n\nfrom openmmforcefields.generators import SystemGenerator\nimport openmm\nfrom openmm import unit\nfrom openmm.app import Simulation\nfrom openmm.unit import Quantity\n\nfrom espaloma.units import *\nimport espaloma as esp\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\n# simulation specs\nTEMPERATURE = 350 * unit.kelvin\nSTEP_SIZE = 1.0 * unit.femtosecond\nCOLLISION_RATE = 1.0 / unit.picosecond\nEPSILON_MIN = 0.05 * unit.kilojoules_per_mole\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef add_nonbonded_force(\n    g,\n    forcefield=\"gaff-1.81\",\n    add_charges=True,\n):\n\n    # parameterize topology\n    topology = g.mol.to_topology().to_openmm()\n\n    generator = SystemGenerator(\n        small_molecule_forcefield=forcefield,\n        molecules=[g.mol],\n        forcefield_kwargs={\"constraints\": None, \"removeCMMotion\": False},\n    )\n\n    # create openmm system\n    system = generator.create_system(\n        topology,\n    )\n\n    # use langevin integrator, although it's not super useful here\n    integrator = openmm.LangevinIntegrator(\n        TEMPERATURE, COLLISION_RATE, STEP_SIZE\n    )\n\n    # create simulation\n    simulation = Simulation(\n        topology=topology, system=system, integrator=integrator\n    )\n\n    # get forces\n    forces = list(system.getForces())\n\n    # loop through forces\n    for force in forces:\n        name = force.__class__.__name__\n\n        # turn off angle\n        if \"Angle\" in name:\n            for idx in range(force.getNumAngles()):\n                id1, id2, id3, angle, k = force.getAngleParameters(idx)\n                force.setAngleParameters(idx, id1, id2, id3, angle, 0.0)\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Bond\" in name:\n            for idx in range(force.getNumBonds()):\n                id1, id2, length, k = force.getBondParameters(idx)\n                force.setBondParameters(\n                    idx,\n                    id1,\n                    id2,\n                    length,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Torsion\" in name:\n            for idx in range(force.getNumTorsions()):\n                (\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    k,\n                ) = force.getTorsionParameters(idx)\n                force.setTorsionParameters(\n                    idx,\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Nonbonded\" in name:\n            if add_charges == False:\n                for idx in range(force.getNumParticles()):\n                    q, sigma, epsilon = 
force.getParticleParameters(idx)\n                    force.setParticleParameters(idx, q * 1e-8, sigma, epsilon)\n                for idx in range(force.getNumExceptions()):\n                    (\n                        idx0,\n                        idx1,\n                        q,\n                        sigma,\n                        epsilon,\n                    ) = force.getExceptionParameters(idx)\n                    force.setExceptionParameters(\n                        idx, idx0, idx1, q * 1e-8, sigma, epsilon\n                    )\n\n                force.updateParametersInContext(simulation.context)\n\n    # the snapshots\n    xs = (\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n            esp.units.DISTANCE_UNIT,\n        )\n        .value_in_unit(unit.nanometer)\n        .transpose((1, 0, 2))\n    )\n\n    # loop through the snapshots\n    energies = []\n    derivatives = []\n\n    for x in xs:\n        simulation.context.setPositions(x)\n\n        state = simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            getForces=True,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT,\n        )\n\n        derivative = state.getForces(asNumpy=True).value_in_unit(\n            esp.units.FORCE_UNIT,\n        ) * -1\n\n        energies.append(energy)\n        derivatives.append(derivative)\n\n    # put energies to a tensor\n    energies = torch.tensor(\n        energies,\n        dtype=torch.get_default_dtype(),\n    ).flatten()[None, :]\n    derivatives = torch.tensor(\n        np.stack(derivatives, axis=1),\n        dtype=torch.get_default_dtype(),\n    )\n\n    # add the energies\n    g.heterograph.apply_nodes(\n        lambda node: {\"u\": node.data[\"u\"] + energies},\n        ntype=\"g\",\n    )\n    return g\n\n\ndef get_coulomb_force(\n    g,\n    forcefield=\"gaff-1.81\",\n):\n    # parameterize topology\n    topology = g.mol.to_topology().to_openmm()\n\n    generator = SystemGenerator(\n        small_molecule_forcefield=forcefield,\n        molecules=[g.mol],\n        forcefield_kwargs={\"constraints\": None, \"removeCMMotion\": False},\n    )\n\n    # create openmm system\n    system = generator.create_system(\n        topology,\n    )\n\n    # use langevin integrator, although it's not super useful here\n    integrator = openmm.LangevinIntegrator(\n        TEMPERATURE, COLLISION_RATE, STEP_SIZE\n    )\n\n    # create simulation\n    simulation = Simulation(\n        topology=topology, system=system, integrator=integrator\n    )\n\n    # the snapshots\n    xs = (\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n            esp.units.DISTANCE_UNIT,\n        )\n        .value_in_unit(unit.nanometer)\n        .transpose((1, 0, 2))\n    )\n\n    # loop through the snapshots\n    energies = []\n    derivatives = []\n\n    for x in xs:\n        simulation.context.setPositions(x)\n\n        state = simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            getForces=True,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT,\n        )\n\n        derivative = state.getForces(asNumpy=True).value_in_unit(\n            esp.units.FORCE_UNIT,\n        ) * -1\n\n        energies.append(energy)\n        derivatives.append(derivative)\n\n    # put energies to a tensor\n    energies = torch.tensor(\n        
energies,\n        dtype=torch.get_default_dtype(),\n    ).flatten()[None, :]\n    derivatives = torch.tensor(\n        np.stack(derivatives, axis=1),\n        dtype=torch.get_default_dtype(),\n    )\n\n    # loop through forces\n    forces = list(system.getForces())\n    for force in forces:\n        name = force.__class__.__name__\n        if \"Nonbonded\" in name:\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n\n            for idx in range(force.getNumParticles()):\n                q, sigma, epsilon = force.getParticleParameters(idx)\n                force.setParticleParameters(idx, q * 1e-8, sigma, epsilon)\n            for idx in range(force.getNumExceptions()):\n                idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n                    idx\n                )\n                force.setExceptionParameters(\n                    idx, idx0, idx1, q * 1e-8, sigma, epsilon\n                )\n            force.updateParametersInContext(simulation.context)\n\n    # the snapshots\n    xs = (\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n            esp.units.DISTANCE_UNIT,\n        )\n        .value_in_unit(unit.nanometer)\n        .transpose((1, 0, 2))\n    )\n\n    # loop through the snapshots\n    new_energies = []\n    new_derivatives = []\n\n    for x in xs:\n        simulation.context.setPositions(x)\n\n        state = simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            getForces=True,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT,\n        )\n\n        derivative = state.getForces(asNumpy=True).value_in_unit(\n            esp.units.FORCE_UNIT,\n        ) * -1\n\n        new_energies.append(energy)\n        new_derivatives.append(derivative)\n\n    # put energies to a tensor\n    new_energies = torch.tensor(\n        new_energies,\n        dtype=torch.get_default_dtype(),\n    ).flatten()[None, :]\n\n    new_derivatives = torch.tensor(\n        np.stack(new_derivatives, axis=1),\n        dtype=torch.get_default_dtype(),\n    )\n\n    return energies - new_energies, derivatives - new_derivatives\n\n\ndef subtract_coulomb_force(\n    g,\n    forcefield=\"gaff-1.81\",\n):\n\n    delta_energies, delta_derivatives = get_coulomb_force(\n        g, forcefield=forcefield\n    )\n\n    # subtract the energies\n    g.heterograph.apply_nodes(\n        lambda node: {\"u_ref\": node.data[\"u_ref\"] - delta_energies},\n        ntype=\"g\",\n    )\n\n    if \"u_ref_prime\" in g.nodes[\"n1\"].data:\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"u_ref_prime\": node.data[\"u_ref_prime\"] - delta_derivatives\n            },\n            ntype=\"n1\",\n        )\n\n    return g\n\n\ndef subtract_nonbonded_force(\n    g,\n    forcefield=\"gaff-1.81\",\n    subtract_charges=True,\n):\n\n    # parameterize topology\n    topology = g.mol.to_topology().to_openmm()\n\n    generator = SystemGenerator(\n        small_molecule_forcefield=forcefield,\n        molecules=[g.mol],\n        forcefield_kwargs={\"constraints\": None, \"removeCMMotion\": False},\n    )\n\n    # create openmm system\n    system = generator.create_system(\n        topology,\n    )\n\n    # use langevin integrator, although it's not super useful here\n    integrator = openmm.LangevinIntegrator(\n        TEMPERATURE, COLLISION_RATE, STEP_SIZE\n    )\n\n    # create simulation\n    simulation = Simulation(\n        
topology=topology, system=system, integrator=integrator\n    )\n\n    # get forces\n    forces = list(system.getForces())\n\n    # loop through forces\n    for force in forces:\n        name = force.__class__.__name__\n\n        # turn off angle\n        if \"Angle\" in name:\n            for idx in range(force.getNumAngles()):\n                id1, id2, id3, angle, k = force.getAngleParameters(idx)\n                force.setAngleParameters(idx, id1, id2, id3, angle, 0.0)\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Bond\" in name:\n            for idx in range(force.getNumBonds()):\n                id1, id2, length, k = force.getBondParameters(idx)\n                force.setBondParameters(\n                    idx,\n                    id1,\n                    id2,\n                    length,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Torsion\" in name:\n            for idx in range(force.getNumTorsions()):\n                (\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    k,\n                ) = force.getTorsionParameters(idx)\n                force.setTorsionParameters(\n                    idx,\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Nonbonded\" in name:\n            # only handle LJ potentials here; the Coulomb interaction is\n            # subtracted separately with the NoCutoff method if subtract_charges==True\n            for idx in range(force.getNumParticles()):\n                q, sigma, epsilon = force.getParticleParameters(idx)\n                force.setParticleParameters(idx, q * 1e-8, sigma, epsilon)\n            for idx in range(force.getNumExceptions()):\n                idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n                    idx\n                )\n                force.setExceptionParameters(\n                    idx, idx0, idx1, q * 1e-8, sigma, epsilon\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n    # the snapshots\n    xs = (\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n            esp.units.DISTANCE_UNIT,\n        )\n        .value_in_unit(unit.nanometer)\n        .transpose((1, 0, 2))\n    )\n\n    # loop through the snapshots\n    energies = []\n    derivatives = []\n\n    for x in xs:\n        simulation.context.setPositions(x)\n\n        state = simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            getForces=True,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT,\n        )\n\n        derivative = state.getForces(asNumpy=True).value_in_unit(\n            esp.units.FORCE_UNIT,\n        ) * -1\n\n        energies.append(energy)\n        derivatives.append(derivative)\n\n    # put energies to a tensor\n    energies = torch.tensor(\n        energies,\n        dtype=torch.get_default_dtype(),\n    ).flatten()[None, :]\n    derivatives = torch.tensor(\n        np.stack(derivatives, axis=1),\n        dtype=torch.get_default_dtype(),\n    )\n\n    # subtract the 
energies\n    g.heterograph.apply_nodes(\n        lambda node: {\"u_ref\": node.data[\"u_ref\"] - energies},\n        ntype=\"g\",\n    )\n\n    if \"u_ref_prime\" in g.nodes[\"n1\"].data:\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"u_ref_prime\": node.data[\"u_ref_prime\"] - derivatives\n            },\n            ntype=\"n1\",\n        )\n\n    if subtract_charges:\n        g = subtract_coulomb_force(g)\n\n    return g\n\n\ndef subtract_nonbonded_force_except_14(\n    g,\n    forcefield=\"gaff-1.81\",\n):\n\n    # parameterize topology\n    topology = g.mol.to_topology().to_openmm()\n\n    generator = SystemGenerator(\n        small_molecule_forcefield=forcefield,\n        molecules=[g.mol],\n    )\n\n    # create openmm system\n    system = generator.create_system(\n        topology,\n    )\n\n    # use langevin integrator, although it's not super useful here\n    integrator = openmm.LangevinIntegrator(\n        TEMPERATURE, COLLISION_RATE, STEP_SIZE\n    )\n\n    # create simulation\n    simulation = Simulation(\n        topology=topology, system=system, integrator=integrator\n    )\n\n    # get forces\n    forces = list(system.getForces())\n\n    # loop through forces\n    for force in forces:\n        name = force.__class__.__name__\n\n        # turn off angle\n        if \"Angle\" in name:\n            for idx in range(force.getNumAngles()):\n                id1, id2, id3, angle, k = force.getAngleParameters(idx)\n                force.setAngleParameters(idx, id1, id2, id3, angle, 0.0)\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Bond\" in name:\n            for idx in range(force.getNumBonds()):\n                id1, id2, length, k = force.getBondParameters(idx)\n                force.setBondParameters(\n                    idx,\n                    id1,\n                    id2,\n                    length,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Torsion\" in name:\n            for idx in range(force.getNumTorsions()):\n                (\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    k,\n                ) = force.getTorsionParameters(idx)\n                force.setTorsionParameters(\n                    idx,\n                    id1,\n                    id2,\n                    id3,\n                    id4,\n                    periodicity,\n                    phase,\n                    0.0,\n                )\n\n            force.updateParametersInContext(simulation.context)\n\n        elif \"Nonbonded\" in name:\n            for idx in range(force.getNumExceptions()):\n                idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n                    idx\n                )\n                force.setExceptionParameters(\n                    idx, idx0, idx1, q, sigma, epsilon * 1e-8\n                )\n            force.updateParametersInContext(simulation.context)\n\n    # the snapshots\n    xs = (\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n            esp.units.DISTANCE_UNIT,\n        )\n        .value_in_unit(unit.nanometer)\n        .transpose((1, 0, 2))\n    )\n\n    # loop through the snapshots\n    energies = []\n    derivatives = []\n\n    for x in xs:\n        simulation.context.setPositions(x)\n\n        
state = simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            getForces=True,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT,\n        )\n\n        derivative = state.getForces(asNumpy=True).value_in_unit(\n            esp.units.FORCE_UNIT,\n        ) * -1\n\n        energies.append(energy)\n        derivatives.append(derivative)\n\n    # put energies to a tensor\n    energies = torch.tensor(\n        energies,\n        dtype=torch.get_default_dtype(),\n    ).flatten()[None, :]\n    derivatives = torch.tensor(\n        np.stack(derivatives, axis=1),\n        dtype=torch.get_default_dtype(),\n    )\n\n    # subtract the energies\n    g.heterograph.apply_nodes(\n        lambda node: {\"u_ref\": node.data[\"u_ref\"] - energies},\n        ntype=\"g\",\n    )\n\n    if \"u_ref_prime\" in g.nodes[\"n1\"].data:\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"u_ref_prime\": node.data[\"u_ref_prime\"] - derivatives\n            },\n            ntype=\"n1\",\n        )\n\n    return g\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass MoleculeVacuumSimulation(object):\n    \"\"\"Simulate a single-molecule system in vacuum.\n\n    Parameters\n    ----------\n    g : `espaloma.Graph`\n        Input molecular graph.\n\n    n_samples : `int`\n        Number of samples to collect.\n\n    n_steps_per_sample : `int`\n        Number of steps between each sample.\n\n    temperature : `float * unit.kelvin`\n        Temperature for the simulation.\n\n    collision_rate : `float / unit.picosecond`\n        Collision rate.\n\n    step_size : `float * unit.femtosecond`\n        Time step.\n\n    Methods\n    -------\n    simulation_from_graph : Create a simulation from a molecule.\n\n    run : Run the simulation.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        forcefield=\"gaff-1.81\",\n        n_samples=100,\n        n_conformers=10,\n        n_steps_per_sample=1000,\n        temperature=TEMPERATURE,\n        collision_rate=COLLISION_RATE,\n        step_size=STEP_SIZE,\n        charge_method=None,\n    ):\n\n        self.n_samples = n_samples\n        self.n_steps_per_sample = n_steps_per_sample\n        self.temperature = temperature\n        self.collision_rate = collision_rate\n        self.step_size = step_size\n        self.forcefield = forcefield\n        self.n_conformers = n_conformers\n        self.charge_method = charge_method\n\n    def simulation_from_graph(self, g):\n        \"\"\"Create a simulation from a molecule\"\"\"\n        # assign partial charge\n        if self.charge_method is not None:\n            g.mol.assign_partial_charges(self.charge_method)\n\n        # parameterize topology\n        topology = g.mol.to_topology().to_openmm()\n\n        generator = SystemGenerator(\n            small_molecule_forcefield=self.forcefield,\n            molecules=[g.mol],\n        )\n\n        # create openmm system\n        system = generator.create_system(\n            topology,\n        )\n\n        # set epsilon minimum to 0.05 kJ/mol\n        for force in system.getForces():\n            if \"Nonbonded\" in force.__class__.__name__:\n                force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n                for particle_index in range(force.getNumParticles()):\n                    charge, sigma, 
epsilon = force.getParticleParameters(\n                        particle_index\n                    )\n                    if epsilon < EPSILON_MIN:\n                        force.setParticleParameters(\n                            particle_index, charge, sigma, EPSILON_MIN\n                        )\n\n        # use langevin integrator\n        integrator = openmm.LangevinIntegrator(\n            self.temperature, self.collision_rate, self.step_size\n        )\n\n        # initialize simulation\n        simulation = Simulation(\n            topology=topology,\n            system=system,\n            integrator=integrator,\n            platform=openmm.Platform.getPlatformByName(\"Reference\"),\n        )\n\n        return simulation\n\n    def run(self, g, in_place=True):\n        \"\"\"Collect samples from simulation.\n\n        Parameters\n        ----------\n        g : `esp.Graph`\n            Input graph.\n\n        in_place : `bool`\n            If True, write the samples into the graph and return the\n            graph; otherwise return the samples.\n\n        Returns\n        -------\n        samples : `torch.Tensor`, `shape=(n_samples, n_nodes, 3)`\n            Samples (returned when `in_place=False`).\n\n        graph : `esp.Graph`\n            Modified graph (returned when `in_place=True`).\n\n        \"\"\"\n        # build simulation\n        simulation = self.simulation_from_graph(g)\n\n        import openff.toolkit\n\n        # get conformer\n        g.mol.generate_conformers(\n            toolkit_registry=openff.toolkit.utils.RDKitToolkitWrapper(),\n            n_conformers=self.n_conformers,\n        )\n\n        # get number of actual conformers\n        true_n_conformers = len(g.mol.conformers)\n\n        samples = []\n        for idx in range(true_n_conformers):\n            # put conformer in simulation\n            simulation.context.setPositions(g.mol.conformers[idx].to_openmm())\n\n            # set velocities\n            simulation.context.setVelocitiesToTemperature(self.temperature)\n\n            # minimize\n            simulation.minimizeEnergy()\n\n            # loop through number of samples\n            for _ in range(self.n_samples // self.n_conformers):\n\n                # run MD for `self.n_steps_per_sample` steps\n                simulation.step(self.n_steps_per_sample)\n\n                # append samples to `samples`\n                samples.append(\n                    simulation.context.getState(getPositions=True)\n                    .getPositions(asNumpy=True)\n                    .value_in_unit(DISTANCE_UNIT)\n                )\n\n        # if the `samples` array is not filled,\n        # pick a random conformer to do it again\n        if len(samples) < self.n_samples:\n            len_samples = len(samples)\n            import random\n\n            idx = random.choice(list(range(true_n_conformers)))\n            simulation.context.setPositions(g.mol.conformers[idx].to_openmm())\n\n            # set velocities\n            simulation.context.setVelocitiesToTemperature(self.temperature)\n\n            # minimize\n            simulation.minimizeEnergy()\n\n            # loop through number of samples\n            for _ in range(self.n_samples - len_samples):\n\n                # run MD for `self.n_steps_per_sample` steps\n                simulation.step(self.n_steps_per_sample)\n\n                # append samples to `samples`\n                samples.append(\n                    simulation.context.getState(getPositions=True)\n                    .getPositions(asNumpy=True)\n                    .value_in_unit(DISTANCE_UNIT)\n                )\n\n        assert len(samples) == 
self.n_samples\n\n        # put samples into an array\n        samples = np.array(samples)\n\n        # put samples into tensor\n        samples = torch.tensor(samples, dtype=torch.float32)\n\n        if in_place is True:\n            g.heterograph.nodes[\"n1\"].data[\"xyz\"] = samples.permute(1, 0, 2)\n\n            # require gradient for force matching\n            g.heterograph.nodes[\"n1\"].data[\"xyz\"].requires_grad = True\n\n            return g\n\n        return samples\n"
  },
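A usage sketch for `MoleculeVacuumSimulation` (needs OpenMM and openmmforcefields at runtime; the tiny sample counts are only to keep the example cheap):

```python
import espaloma as esp
from espaloma.data.md import MoleculeVacuumSimulation

g = esp.Graph("CCO")  # ethanol
simulation = MoleculeVacuumSimulation(
    n_samples=10, n_conformers=2, n_steps_per_sample=10
)
# With in_place=True the sampled coordinates are written to the graph
# as nodes["n1"].data["xyz"] with shape (n_atoms, n_samples, 3).
g = simulation.run(g, in_place=True)
print(g.nodes["n1"].data["xyz"].shape)
```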
  {
    "path": "espaloma/data/md17_utils.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport numpy as np\nimport torch\nimport espaloma as esp\nfrom openmm import unit\nfrom openmm.unit import Quantity\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\nMOLECULES = {\n    \"benzene\": \"C1=CC=CC=C1\",\n    \"uracil\": \"O=C1NC=CC(=O)N1\",\n    \"naphthalene\": \"C1=CC=C2C=CC=CC2=C1\",\n    \"aspirin\": \"CC(=O)OC1=CC=CC=C1C(=O)O\",\n    \"salicylic\": \"C1=CC=C(C(=C1)C(=O)O)O\",\n    \"malonaldehyde\": \"C(C=O)C=O\",\n    \"ethanol\": \"CCO\",\n    \"toluene\": \"CC1=CC=CC=C1\",\n    \"paracetamol\": \"CC(=O)NC1=CC=C(C=C1)O\",\n    \"azobenzene\": \"C1=CC=C(C=C1)N=NC2=CC=CC=C2\",\n}\n\nOFFSETS = {\n    1: -0.500607632585,\n    6: -37.8302333826,\n    7: -54.5680045287,\n    8: -75.0362229210,\n}\n\n# ==============================================================================\n# UTILITY FUNCTIONS\n# ==============================================================================\ndef sum_offsets(elements):\n    return sum([OFFSETS[element] for element in elements])\n\n\ndef realize_molecule(\n    data, name, smiles=None, first=-1, subtract_nonbonded=True\n):\n    elements = data[\"z\"].tolist()\n\n    offset = sum_offsets(elements)\n\n    g = esp.data.utils.infer_mol_from_coordinates(\n        data[\"R\"][0],\n        elements,\n        smiles,\n    )\n\n    g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n        Quantity(\n            data[\"R\"].transpose(1, 0, 2),\n            unit.angstrom,\n        ).value_in_unit(esp.units.DISTANCE_UNIT),\n        requires_grad=True,\n    )[:, :first, :]\n\n    g.nodes[\"g\"].data[\"u_ref\"] = (\n        torch.tensor(\n            Quantity(\n                data[\"E\"],\n                unit.kilocalorie_per_mole,\n            ).value_in_unit(esp.units.ENERGY_UNIT)\n        ).transpose(1, 0)[:, :first]\n        - offset\n    )\n\n    g.nodes[\"n1\"].data[\"u_ref_prime\"] = torch.tensor(\n        Quantity(\n            data[\"F\"],\n            unit.kilocalorie_per_mole / unit.angstrom,\n        ).value_in_unit(esp.units.FORCE_UNIT)\n    ).transpose(1, 0)[:, :first, :]\n\n    if subtract_nonbonded is True:\n        g = esp.data.md.subtract_nonbonded_force(g)\n\n    return g\n\n\ndef get_molecule(name, *args, **kwargs):\n    if name == \"benzene\":\n        file_name = \"benzene_old_dft.npz\"\n    else:\n        file_name = \"%s_dft.npz\" % name\n\n    from os.path import exists\n\n    if not exists(file_name):\n        url = \"http://www.quantum-machine.org/gdml/data/npz/%s\" % file_name\n        print(url)\n        import urllib.request\n\n        urllib.request.urlretrieve(url, file_name)\n\n    data = np.load(file_name)\n\n    smiles = MOLECULES[name]\n\n    g = realize_molecule(data, name, smiles, *args, **kwargs)\n\n    return g\n"
  },
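For example (the npz archive is fetched from quantum-machine.org on first use, exactly as `get_molecule` does above; `subtract_nonbonded=False` skips the OpenMM-dependent post-processing):

```python
import espaloma as esp

# Downloads ethanol_dft.npz on the first call, infers the molecule from
# the first frame, and attaches coordinates, energies, and forces.
g = esp.data.md17_utils.get_molecule(
    "ethanol", first=100, subtract_nonbonded=False
)
print(g.nodes["n1"].data["xyz"].shape)   # (n_atoms, 100, 3)
print(g.nodes["g"].data["u_ref"].shape)  # (1, 100)
```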
  {
    "path": "espaloma/data/normalize.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport abc\n\nimport torch\n\nimport espaloma as esp\n\n\n# =============================================================================\n# BASE CLASSES\n# =============================================================================\nclass BaseNormalize(abc.ABC):\n    \"\"\"Base class for normalizing operation.\"\"\"\n\n    def __init__(self):\n        super(BaseNormalize, self).__init__()\n\n    @abc.abstractmethod\n    def _prepare(self):\n        # NOTE:\n        # `_norm` and `_unnorm` are assigned here\n        raise NotImplementedError\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass DatasetNormalNormalize(BaseNormalize):\n    \"\"\"Normalizing operation based on a dataset of molecules,\n    assuming parameters having normal distribution.\n\n    Parameters\n    ----------\n    dataset : `espaloma.data.dataset.Dataset`\n        The dataset we base on to calculate the statistics of parameter\n        distributions.\n\n    Attributes\n    ----------\n    norm : normalize function\n\n    unnorm : unnormalize function\n\n    \"\"\"\n\n    def __init__(self, dataset):\n        super(DatasetNormalNormalize, self).__init__()\n        self.dataset = dataset\n        self._prepare()\n\n    def _prepare(self):\n        \"\"\"Calculate the statistics from dataset\"\"\"\n        # grab the collection of graphs in the dataset, batched\n        g = self.dataset.batch(self.dataset.graphs)\n\n        self.statistics = {term: {} for term in [\"n1\", \"n2\", \"n3\", \"n4\"]}\n\n        # calculate statistics\n        for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n            for key in g.nodes[term].data.keys():  # loop through parameters\n                if not key.endswith(\"ref\"):  # pass non-parameters\n                    continue\n\n                self.statistics[term][\n                    key.replace(\"_ref\", \"_mean\")\n                ] = torch.mean(g.nodes[term].data[key], axis=0)\n\n                self.statistics[term][\n                    key.replace(\"_ref\", \"_std\")\n                ] = torch.std(g.nodes[term].data[key], axis=0)\n\n        # get normalize and unnormalize functions\n        def norm(g):\n            for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n                    if not key.endswith(\"ref\"):  # pass non-parameters\n                        continue\n\n                    g.nodes[term].data[key] = (\n                        g.nodes[term].data[key]\n                        - self.statistics[term][key.replace(\"_ref\", \"_mean\")]\n                    ) / self.statistics[term][key.replace(\"_ref\", \"_std\")]\n\n            return g\n\n        def unnorm(g):\n            for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n\n                    if key + \"_mean\" in self.statistics[term]:\n\n                        g.nodes[term].data[key] = (\n                            g.nodes[term].data[key]\n                            * 
self.statistics[term][key + \"_std\"]\n                            + self.statistics[term][key + \"_mean\"]\n                        )\n                    #\n                    # elif '_ref' in key \\\n                    #     and key.replace('_ref', '_mean')\\\n                    #     in self.statistics[term]:\n                    #\n                    #     g.nodes[term].data[key]\\\n                    #         = g.nodes[term].data[key]\\\n                    #             * self.statistics[term][\n                    #                 key.replace('_ref', '_std')]\\\n                    #             + self.statistics[term][\n                    #                 key.replace('_ref', '_mean')]\n\n            return g\n\n        # point normalize and unnormalize functions to `self`\n        self.norm = norm\n        self.unnorm = unnorm\n\n\nclass DatasetLogNormalNormalize(BaseNormalize):\n    \"\"\"Normalizing operation based on a dataset of molecules,\n    assuming parameters having log normal distribution.\n\n    Parameters\n    ----------\n    dataset : `espaloma.data.dataset.Dataset`\n        The dataset we base on to calculate the statistics of parameter\n        distributions.\n\n    Attributes\n    ----------\n    norm : normalize function\n\n    unnorm : unnormalize function\n\n    \"\"\"\n\n    def __init__(self, dataset):\n        super(DatasetLogNormalNormalize, self).__init__()\n        self.dataset = dataset\n        self._prepare()\n\n    def _prepare(self):\n        \"\"\"Calculate the statistics from dataset\"\"\"\n        # grab the collection of graphs in the dataset, batched\n        g = self.dataset.batch(self.dataset.graphs)\n\n        self.statistics = {term: {} for term in [\"n1\", \"n2\", \"n3\", \"n4\"]}\n\n        # calculate statistics\n        for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n            for key in g.nodes[term].data.keys():  # loop through parameters\n                if not key.endswith(\"ref\"):  # pass non-parameters\n                    continue\n\n                self.statistics[term][\n                    key.replace(\"_ref\", \"_mean\")\n                ] = torch.mean(g.nodes[term].data[key].log(), axis=0)\n\n                self.statistics[term][\n                    key.replace(\"_ref\", \"_std\")\n                ] = torch.std(g.nodes[term].data[key].log(), axis=0)\n\n        # get normalize and unnormalize functions\n        def norm(g):\n            for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n                    if not key.endswith(\"ref\"):  # pass non-parameters\n                        continue\n\n                    g.nodes[term].data[key] = (\n                        g.nodes[term].data[key].log()\n                        - self.statistics[term][key.replace(\"_ref\", \"_mean\")]\n                    ) / self.statistics[term][key.replace(\"_ref\", \"_std\")]\n\n            return g\n\n        def unnorm(g):\n            for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n\n                    if key + \"_mean\" in self.statistics[term]:\n\n                        g.nodes[term].data[key] = torch.exp(\n                            g.nodes[term].data[key]\n                            * self.statistics[term][key + 
\"_std\"].to(\n                                g.nodes[term].data[key].device\n                            )\n                            + self.statistics[term][key + \"_mean\"].to(\n                                g.nodes[term].data[key].device\n                            )\n                        )\n                    #\n                    # elif '_ref' in key \\\n                    #     and key.replace('_ref', '_mean')\\\n                    #     in self.statistics[term]:\n                    #\n                    #     g.nodes[term].data[key]\\\n                    #         = torch.exp(\n                    #             g.nodes[term].data[key]\\\n                    #                 * self.statistics[term][\n                    #                     key.replace('_ref', '_std')]\\\n                    #                 + self.statistics[term][\n                    #                     key.replace('_ref', '_mean')])\n\n            return g\n\n        # point normalize and unnormalize functions to `self`\n        self.norm = norm\n        self.unnorm = unnorm\n\n\n# =============================================================================\n# PRESETS\n# =============================================================================\nclass ESOL100NormalNormalize(DatasetNormalNormalize):\n    def __init__(self):\n        super(ESOL100NormalNormalize, self).__init__(\n            dataset=esp.data.esol(first=100).apply(\n                esp.graphs.legacy_force_field.LegacyForceField(\n                    \"smirnoff99Frosst-1.1.0\"\n                ).parametrize,\n                in_place=True,\n            )\n        )\n\n\nclass ESOL100LogNormalNormalize(DatasetLogNormalNormalize):\n    def __init__(self):\n        super(ESOL100LogNormalNormalize, self).__init__(\n            dataset=esp.data.esol(first=100).apply(\n                esp.graphs.legacy_force_field.LegacyForceField(\n                    \"smirnoff99Frosst-1.1.0\"\n                ).parametrize,\n                in_place=True,\n            )\n        )\n\n\nclass NotNormalize(BaseNormalize):\n    def __init__(self):\n        super(NotNormalize).__init__()\n        self._prepare()\n\n    def _prepare(self):\n        self.norm = lambda x: x\n        self.unnorm = lambda x: x\n\n\nclass PositiveNotNormalize(BaseNormalize):\n    def __init__(self):\n        super(PositiveNotNormalize, self).__init__()\n        self._prepare()\n\n    def _prepare(self):\n\n        # get normalize and unnormalize functions\n        def norm(g):\n            for term in [\"n1\", \"n2\", \"n3\", \"n4\"]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n                    if not key.endswith(\"ref\"):  # pass non-parameters\n                        continue\n\n                    g.nodes[term].data[key] = g.nodes[term].data[key].log()\n\n            return g\n\n        def unnorm(g):\n            for term in [\n                \"n2\",\n                \"n3\",\n            ]:  # loop through terms\n                for key in g.nodes[\n                    term\n                ].data.keys():  # loop through parameters\n                    if key == \"k\" or key == \"eq\":\n\n                        g.nodes[term].data[key] = torch.exp(\n                            g.nodes[term].data[key]\n                        )\n\n            return g\n\n        # point normalize and unnormalize functions to `self`\n        self.norm = norm\n        self.unnorm = 
unnorm\n"
  },
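A sketch of how the normalizers are wired up, mirroring the ESOL presets above; parametrizing with the legacy force field supplies the `*_ref` entries the statistics are computed from:

```python
import espaloma as esp
from espaloma.data.normalize import DatasetNormalNormalize

ds = esp.data.esol(first=10).apply(
    esp.graphs.legacy_force_field.LegacyForceField(
        "smirnoff99Frosst-1.1.0"
    ).parametrize,
    in_place=True,
)

normalize = DatasetNormalNormalize(ds)
g = normalize.norm(ds[0].heterograph)  # z-scores every *_ref parameter
# `unnorm` inverts the transform for keys without the `_ref` suffix
# (e.g. a network's predicted `k`), using the same dataset statistics.
```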
  {
    "path": "espaloma/data/off-mol_0_10_6.json",
    "content": "\"{\\\"name\\\": \\\"\\\", \\\"atoms\\\": [{\\\"atomic_number\\\": 8, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 7, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 7, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 7, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 7, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 16, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 17, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 6, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": true, 
\\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}, {\\\"atomic_number\\\": 1, \\\"formal_charge\\\": 0, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"name\\\": \\\"\\\"}], \\\"virtual_sites\\\": [], \\\"bonds\\\": [{\\\"atom1\\\": 0, \\\"atom2\\\": 1, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 1, \\\"atom2\\\": 2, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 2, \\\"atom2\\\": 3, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 3, \\\"atom2\\\": 4, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 4, \\\"atom2\\\": 5, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 5, \\\"atom2\\\": 6, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 6, \\\"atom2\\\": 7, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 7, \\\"atom2\\\": 8, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 8, \\\"atom2\\\": 9, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, 
\\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 9, \\\"atom2\\\": 10, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 10, \\\"atom2\\\": 11, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 1, \\\"atom2\\\": 12, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 12, \\\"atom2\\\": 13, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 13, \\\"atom2\\\": 14, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 14, \\\"atom2\\\": 15, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 15, \\\"atom2\\\": 16, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 16, \\\"atom2\\\": 17, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 17, \\\"atom2\\\": 18, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 12, \\\"atom2\\\": 19, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 19, \\\"atom2\\\": 20, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 20, \\\"atom2\\\": 21, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 21, \\\"atom2\\\": 22, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 22, \\\"atom2\\\": 23, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 22, \\\"atom2\\\": 24, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 24, \\\"atom2\\\": 25, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 11, \\\"atom2\\\": 3, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 18, \\\"atom2\\\": 14, \\\"bond_order\\\": 2, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 25, \\\"atom2\\\": 19, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 11, \\\"atom2\\\": 6, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": true, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 2, \\\"atom2\\\": 26, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 2, \\\"atom2\\\": 27, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 7, \\\"atom2\\\": 
28, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 8, \\\"atom2\\\": 29, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 9, \\\"atom2\\\": 30, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 10, \\\"atom2\\\": 31, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 13, \\\"atom2\\\": 32, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 13, \\\"atom2\\\": 33, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 15, \\\"atom2\\\": 34, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 16, \\\"atom2\\\": 35, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 18, \\\"atom2\\\": 36, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 20, \\\"atom2\\\": 37, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 21, \\\"atom2\\\": 38, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 24, \\\"atom2\\\": 39, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}, {\\\"atom1\\\": 25, \\\"atom2\\\": 40, \\\"bond_order\\\": 1, \\\"is_aromatic\\\": false, \\\"stereochemistry\\\": null, \\\"fractional_bond_order\\\": null}], \\\"properties\\\": {}, \\\"conformers\\\": null, \\\"partial_charges\\\": null, \\\"partial_charges_unit\\\": null}\""
  },
  {
    "path": "espaloma/data/qcarchive_utils.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nfrom collections import namedtuple\nfrom typing import Tuple\n\nimport numpy as np\nimport qcportal\nimport torch\nfrom openmm import unit\nfrom openmm.unit import Quantity\n\nimport espaloma as esp\n\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\n\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\ndef get_client(url: str = \"api.qcarchive.molssi.org\") -> qcportal.client.PortalClient:\n    \"\"\"\n    Returns a instance of the qcportal client.\n\n    Parameters\n    ----------\n    url: str, default=\"api.qcarchive.molssi.org\"\n        qcportal instance to connect\n\n    Returns\n    -------\n    qcportal.client.PortalClient\n        qcportal client instance.\n    \"\"\"\n    # Note, this may need to be modified to include username/password for non-public servers\n    return qcportal.PortalClient(url)\n\n\ndef get_collection(\n        client,\n        collection_type=\"optimization\",\n        name=\"OpenFF Full Optimization Benchmark 1\",\n):\n    \"\"\"\n    Connects to a specific dataset on qcportal\n\n    Parameters\n    ----------\n    client: qcportal.client, required\n        The qcportal client instance\n    collection_type: str, default=\"optimization\"\n        The type of qcarchive collection, options are\n        \"torsiondrive\", \"optimization\", \"gridoptimization\", \"reaction\", \"singlepoint\" \"manybody\"\n    name: str, default=\"OpenFF Full Optimization Benchmark 1\"\n        Name of the dataset\n\n    Returns\n    -------\n    (qcportal dataset, list(str))\n        Tuple with an instance of qcportal dataset and list of record names\n\n    \"\"\"\n    collection = client.get_dataset(\n        dataset_type=collection_type,\n        dataset_name=name,\n    )\n\n    record_names = collection.entry_names\n\n    return collection, record_names\n\n\ndef process_record(record, entry):\n    \"\"\"\n    Processes a given record/entry pair from a dataset and returns the graph\n\n    Parameters\n    ----------\n    record: qcportal.optimization.record_models.OptimizationRecord\n        qcportal record\n    entry: cportal.optimization.dataset_models.OptimizationDatasetEntry\n        qcportal entry\n\n    Returns\n    -------\n        esp.Graph\n    \"\"\"\n\n    from openff.toolkit.topology import Molecule\n\n    if record.record_type == \"optimization\":\n        trajectory = record.trajectory\n        if trajectory is None:\n            return None\n    else:\n        raise Exception(\n            f\"{record.record_type} is not supported: only optimization datasets can be processed.\"\n        )\n    mol = Molecule.from_qcschema(entry.dict())\n    g = esp.Graph(mol)\n\n    # energy is already hartree\n    g.nodes[\"g\"].data[\"u_ref\"] = torch.tensor(\n        [\n            Quantity(\n                snapshot.properties[\"scf_total_energy\"],\n                esp.units.HARTREE_PER_PARTICLE,\n            ).value_in_unit(esp.units.ENERGY_UNIT)\n            for snapshot in trajectory\n        ],\n        dtype=torch.get_default_dtype(),\n    )[None, :]\n\n    g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n        np.stack(\n         
def process_record(record, entry):\n    \"\"\"\n    Processes a given record/entry pair from a dataset and returns the graph\n\n    Parameters\n    ----------\n    record: qcportal.optimization.record_models.OptimizationRecord\n        qcportal record\n    entry: qcportal.optimization.dataset_models.OptimizationDatasetEntry\n        qcportal entry\n\n    Returns\n    -------\n        esp.Graph\n    \"\"\"\n\n    from openff.toolkit.topology import Molecule\n\n    if record.record_type == \"optimization\":\n        trajectory = record.trajectory\n        if trajectory is None:\n            return None\n    else:\n        raise Exception(\n            f\"{record.record_type} is not supported: only optimization datasets can be processed.\"\n        )\n    mol = Molecule.from_qcschema(entry.dict())\n    g = esp.Graph(mol)\n\n    # energy is already hartree\n    g.nodes[\"g\"].data[\"u_ref\"] = torch.tensor(\n        [\n            Quantity(\n                snapshot.properties[\"scf_total_energy\"],\n                esp.units.HARTREE_PER_PARTICLE,\n            ).value_in_unit(esp.units.ENERGY_UNIT)\n            for snapshot in trajectory\n        ],\n        dtype=torch.get_default_dtype(),\n    )[None, :]\n\n    g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n        np.stack(\n            [\n                Quantity(\n                    snapshot.molecule.geometry,\n                    unit.bohr,\n                ).value_in_unit(esp.units.DISTANCE_UNIT)\n                for snapshot in trajectory\n            ],\n            axis=1,\n        ),\n        requires_grad=True,\n        dtype=torch.get_default_dtype(),\n    )\n\n    g.nodes[\"n1\"].data[\"u_ref_prime\"] = torch.stack(\n        [\n            torch.tensor(\n                Quantity(\n                    np.array(snapshot.properties[\"return_result\"]).reshape((-1, 3)),\n                    esp.units.HARTREE_PER_PARTICLE / unit.bohr,\n                ).value_in_unit(esp.units.FORCE_UNIT),\n                dtype=torch.get_default_dtype(),\n            )\n            for snapshot in trajectory\n        ],\n        dim=1,\n    )\n\n    return g\n\n\ndef get_graph(collection, record_name, spec_name=\"default\"):\n    \"\"\"\n    Processes the qcportal data for a given record name.\n\n    This supports optimization and singlepoint datasets.\n\n    Parameters\n    ----------\n    collection, qcportal dataset, required\n        The instance of the qcportal dataset\n    record_name, str, required\n        The name of a given record\n    spec_name, str, default=\"default\"\n        Retrieve data for a given qcportal specification.\n    Returns\n    -------\n        Graph\n    \"\"\"\n    # get record and trajectory\n    record = collection.get_record(record_name, specification_name=spec_name)\n    entry = collection.get_entry(record_name)\n\n    g = process_record(record, entry)\n\n    return g\n\n\ndef get_graphs(collection, record_names, spec_name=\"default\"):\n    \"\"\"\n    Processes the qcportal data for a given set of record names.\n    This uses the qcportal iteration functions, which are faster than processing\n    records one at a time.\n\n    This supports optimization and singlepoint datasets.\n\n\n    Parameters\n    ----------\n    collection, qcportal dataset, required\n        The instance of the qcportal dataset\n    record_names, List[str], required\n        A list of record names to process\n    spec_name, str, default=\"default\"\n        Retrieve data for a given qcportal specification.\n    Returns\n    -------\n    list(graph)\n        Returns a list of the corresponding graph for each record name\n    \"\"\"\n    g_list = []\n    for record, entry in zip(\n            collection.iterate_records(record_names, specification_names=[spec_name]),\n            collection.iterate_entries(record_names),\n    ):\n        # note iterate_records returns a tuple of length 3 (name, spec_name, actual record information)\n\n        g = process_record(record[2], entry)\n        g_list.append(g)\n\n    return g_list\n\n\ndef fetch_td_record(record: qcportal.torsiondrive.record_models.TorsiondriveRecord):\n    \"\"\"\n    Fetches configuration, energy, and gradients for a given torsiondrive record as a function of different angles.\n\n    Parameters\n    ----------\n    record: qcportal.torsiondrive.record_models.TorsiondriveRecord, required\n        Torsiondrive record of interest\n    Returns\n    -------\n    tuple, (numpy.array, numpy.array, numpy.array, numpy.array)\n        Returned data is a tuple of numpy arrays.\n        The first array contains the angles; the subsequent arrays contain the\n        molecule coordinates, energies, and gradients associated with each angle.\n\n    \"\"\"\n    molecule_optimization = record.optimizations\n\n    angle_keys = list(molecule_optimization.keys())\n\n    xyzs = []\n    energies = []\n    gradients = []\n\n    for angle in angle_keys:\n        # NOTE: this is calling the first index of the optimization array\n        # this gives the same value as the prior implementation.\n        # however, it seems that this contains multiple different initial configurations\n        # that have been optimized.  Should all conformers and energies/gradients be considered?\n        mol = molecule_optimization[angle][0].final_molecule\n        result = molecule_optimization[angle][0].trajectory[-1].properties\n\n        \"\"\"Note: force = - gradient\"\"\"\n\n        # TODO: attach units here? or later?\n\n        e = result[\"current energy\"]\n        g = np.array(result[\"current gradient\"]).reshape(-1, 3)\n\n        xyzs.append(mol.geometry)\n        energies.append(e)\n        gradients.append(g)\n\n    # to arrays\n    xyz = np.array(xyzs)\n    energies = np.array(energies)\n    gradients = np.array(gradients)\n\n    # assume each angle key is a tuple -- sort by first angle in tuple\n\n    # NOTE: (for now making the assumption that these torsion drives are 1D)\n    for k in angle_keys:\n        assert len(k) == 1\n\n    to_ordered = np.argsort([k[0] for k in angle_keys])\n    angles_in_order = [angle_keys[i_] for i_ in to_ordered]\n    flat_angles = np.array(angles_in_order).flatten()\n\n    # put the xyz's, energies, and gradients in the same order as the angles\n    xyz_in_order = xyz[to_ordered]\n    energies_in_order = energies[to_ordered]\n    gradients_in_order = gradients[to_ordered]\n\n    # TODO: put this return blob into a better struct\n    return flat_angles, xyz_in_order, energies_in_order, gradients_in_order\n\n\n
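# Example (editorial sketch): for a torsiondrive record \"rec\" obtained from a\n# torsiondrive collection,\n#\n#     angles, xyz, energies, gradients = fetch_td_record(rec)\n#\n# returns arrays ordered by increasing torsion angle.\n\n\n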
MolWithTargets = namedtuple(\n    \"MolWithTargets\", [\"offmol\", \"xyz\", \"energies\", \"gradients\"]\n)\n\n\ndef h5_to_dataset(df):\n    def get_smiles(x):\n        try:\n            return x[\"offmol\"].to_smiles()\n        except Exception:\n            return np.nan\n\n    df[\"smiles\"] = df.apply(get_smiles, axis=1)\n    df = df.dropna()\n    groups = df.groupby(\"smiles\")\n    gs = []\n    for name, group in groups:\n        mol_ref = group[\"offmol\"].iloc[0]\n        assert all(mol_ref == entry for entry in group[\"offmol\"])\n        g = esp.Graph(mol_ref)\n\n        u_ref = np.concatenate(group[\"energies\"].values)\n        u_ref_prime = np.concatenate(group[\"gradients\"].values, axis=0).transpose(\n            1, 0, 2\n        )\n        xyz = np.concatenate(group[\"xyz\"].values, axis=0).transpose(1, 0, 2)\n\n        assert u_ref_prime.shape[0] == xyz.shape[0] == mol_ref.n_atoms\n        assert u_ref.shape[0] == u_ref_prime.shape[1] == xyz.shape[1]\n\n        # energy is already hartree\n        g.nodes[\"g\"].data[\"u_ref\"] = torch.tensor(\n            Quantity(u_ref, esp.units.HARTREE_PER_PARTICLE).value_in_unit(\n                esp.units.ENERGY_UNIT\n            ),\n            dtype=torch.get_default_dtype(),\n        )[None, :]\n\n        g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n            Quantity(\n                xyz,\n                unit.bohr,\n            ).value_in_unit(esp.units.DISTANCE_UNIT),\n            requires_grad=True,\n            dtype=torch.get_default_dtype(),\n        )\n\n        g.nodes[\"n1\"].data[\"u_ref_prime\"] = torch.tensor(\n            Quantity(\n                u_ref_prime,\n                esp.units.HARTREE_PER_PARTICLE / unit.bohr,\n            ).value_in_unit(esp.units.FORCE_UNIT),\n            dtype=torch.get_default_dtype(),\n        )\n\n        gs.append(g)\n\n    return esp.data.dataset.GraphDataset(gs)\n\n\ndef breakdown_along_time_axis(g, batch_size=32):\n    n_snapshots = g.nodes[\"g\"].data[\"u_ref\"].flatten().shape[0]\n    idxs = list(range(n_snapshots))\n    from random import shuffle\n\n    shuffle(idxs)\n    chunks = [\n        idxs[_idx * batch_size: (_idx + 1) * batch_size]\n        for _idx in range(n_snapshots // batch_size)\n    ]\n\n    _gs = []\n    for chunk in chunks:\n        _g = esp.Graph(g.mol)\n        _g.nodes[\"g\"].data[\"u_ref\"] = (\n            g.nodes[\"g\"].data[\"u_ref\"][:, chunk].detach().clone()\n        )\n        _g.nodes[\"n1\"].data[\"xyz\"] = (\n            g.nodes[\"n1\"].data[\"xyz\"][:, chunk, :].detach().clone()\n        )\n        _g.nodes[\"n1\"].data[\"u_ref_prime\"] = (\n            g.nodes[\"n1\"].data[\"u_ref_prime\"][:, chunk, :].detach().clone()\n        )\n\n        _g.nodes[\"n1\"].data[\"xyz\"].requires_grad = True\n\n        _gs.append(_g)\n\n    return _gs\n\n\ndef make_batch_size_consistent(ds, batch_size=32):\n    import itertools\n\n    return esp.data.dataset.GraphDataset(\n        list(\n            itertools.chain.from_iterable(\n                [breakdown_along_time_axis(g, batch_size=batch_size) for g in ds]\n            )\n        )\n    )\n\n\ndef weight_by_snapshots(g, key=\"weight\"):\n    n_snapshots = g.nodes[\"n1\"].data[\"xyz\"].shape[1]\n    # a 0-d tensor cannot be sliced with [None, :]; build the (1, 1) tensor directly\n    g.nodes[\"g\"].data[key] = torch.tensor([[1.0 / n_snapshots]])\n"
  },
  {
    "path": "espaloma/data/tests/test_collection.py",
    "content": "import pytest\n\n\n@pytest.fixture\ndef esol():\n    import espaloma as esp\n\n    return esp.data.esol(first=16)\n\n\ndef test_view(esol):\n    view = esol.view(batch_size=4)\n    import dgl\n\n    graphs = list(view)\n    assert len(graphs) == 4\n    assert all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs)\n\n\ndef test_typing(esol):\n    import espaloma as esp\n\n    typing = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\n    esol = esol.apply(typing, in_place=True)\n    view = esol.view(batch_size=4)\n    for g in view:\n        assert g.nodes[\"n1\"].data[\"legacy_typing\"].shape[\n            0\n        ] == g.number_of_nodes(ntype=\"n1\")\n"
  },
  {
    "path": "espaloma/data/tests/test_dataset.py",
    "content": "import pytest\n\n\ndef test_tiny_dataset():\n    import espaloma as esp\n\n    xs = list(range(5))\n    ds = esp.data.dataset.Dataset(xs)\n\n\n@pytest.fixture\ndef ds():\n    xs = list(range(5))\n    import espaloma as esp\n\n    return esp.data.dataset.Dataset(xs)\n\n\ndef test_get(ds):\n    assert ds[0] == 0\n\n\ndef test_len(ds):\n    assert len(ds) == 5\n\n\ndef test_iter(ds):\n    assert all(x == x_ for (x, x_) in zip(ds, range(5)))\n\n\ndef test_slice(ds):\n    import espaloma as esp\n\n    sub_ds = ds[:2]\n    assert isinstance(ds, esp.data.dataset.Dataset)\n    assert len(sub_ds) == 2\n\n\ndef test_split(ds):\n    a, b = ds.split([1, 4])\n    assert len(a) == 1\n    assert len(b) == 4\n\n\n@pytest.fixture\ndef ds_new(ds):\n    fn = lambda x: x + 1\n    return ds.apply(fn)\n\n\ndef test_no_change(ds_new):\n    assert all(x == x_ for (x, x_) in zip(ds_new.graphs, range(5)))\n\n\ndef test_get_new(ds_new):\n    assert ds_new[0] == 1\n\n\ndef test_len_new(ds_new):\n    assert len(ds_new) == 5\n\n\ndef test_iter_new(ds_new):\n    assert all(x == x_ + 1 for (x, x_) in zip(ds_new, range(5)))\n\n\n@pytest.fixture\ndef ds_newer(ds):\n    fn = lambda x: x + 1\n    return ds.apply(fn).apply(fn)\n\n\ndef test_iter_newer(ds_newer):\n    assert all(x == x_ + 2 for (x, x_) in zip(ds_newer, range(5)))\n\n\ndef test_no_return(ds):\n    fn = lambda x: x + 1\n    ds.apply(fn).apply(fn)\n    assert all(x == x_ + 2 for (x, x_) in zip(ds, range(5)))\n\n\ndef test_subsample(ds):\n    _ds = ds.subsample(0.2)\n    assert len(_ds) == 1\n"
  },
  {
    "path": "espaloma/data/tests/test_md.py",
    "content": "import pytest\nimport torch\n\n\ndef test_init():\n    import espaloma.data.md\n\n\n@pytest.fixture\ndef graph():\n    import espaloma as esp\n\n    graph = esp.Graph(\"c1ccccc1\")\n    return graph\n\n\n@pytest.fixture\ndef ds():\n    import espaloma as esp\n\n    ds = esp.data.esol(first=10)\n    return ds\n\n\ndef test_system(graph):\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation()\n\n\ndef test_run(graph):\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n\n    samples = simulation.run(graph, in_place=False)\n\n    assert samples.shape == torch.Size([10, 12, 3])\n\n\ndef test_run_in_place(graph):\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n\n    graph = simulation.run(graph, in_place=True)\n\n    assert graph.nodes[\"n1\"].data[\"xyz\"].shape == torch.Size([12, 10, 3])\n\n\ndef test_apply(ds):\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(\n        n_samples=1, n_steps_per_sample=1\n    ).run\n\n    ds.apply(simulation, in_place=True)\n\n    assert ds.graphs[0].nodes[\"n1\"].data[\"xyz\"].shape[-1] == 3\n    assert ds.graphs[0].nodes[\"n1\"].data[\"xyz\"].shape[-2] == 1\n"
  },
  {
    "path": "espaloma/data/tests/test_normalize.py",
    "content": "import numpy.testing as npt\nimport pytest\n\n\ndef test_import():\n    from espaloma.data.normalize import BaseNormalize\n\n\ndef test_normalize_esol():\n    import espaloma as esp\n\n    normalize = esp.data.normalize.DatasetNormalNormalize(\n        dataset=esp.data.esol(first=10).apply(\n            esp.graphs.legacy_force_field.LegacyForceField(\n                \"smirnoff99Frosst-1.1.0\"\n            ).parametrize,\n            in_place=True,\n        )\n    )\n\n\ndef test_log_normalize_esol():\n    import espaloma as esp\n\n    normalize = esp.data.normalize.DatasetLogNormalNormalize(\n        dataset=esp.data.esol(first=10).apply(\n            esp.graphs.legacy_force_field.LegacyForceField(\n                \"smirnoff99Frosst-1.1.0\"\n            ).parametrize,\n            in_place=True,\n        )\n    )\n\n\ndef test_normal_normalize_reproduce():\n    import espaloma as esp\n\n    normalize = esp.data.normalize.DatasetNormalNormalize(\n        dataset=esp.data.esol(first=10).apply(\n            esp.graphs.legacy_force_field.LegacyForceField(\n                \"smirnoff99Frosst-1.1.0\"\n            ).parametrize,\n            in_place=True,\n        )\n    )\n\n    esol = esp.data.esol(first=1)\n\n    # do some typing\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    ).parametrize\n    esol.apply(param, in_place=True)  # this modify the original data\n\n    g = esol[0]\n\n    import copy\n\n    g_ = copy.deepcopy(g)\n\n    g = normalize.norm(g)\n\n    g.nodes[\"n2\"].data[\"k\"] = g.nodes[\"n2\"].data[\"k_ref\"]\n    g.nodes[\"n2\"].data[\"eq\"] = g.nodes[\"n2\"].data[\"eq_ref\"]\n\n    g = normalize.unnorm(g)\n\n    npt.assert_almost_equal(\n        g.nodes[\"n2\"].data[\"k\"].detach().numpy(),\n        g_.nodes[\"n2\"].data[\"k_ref\"].detach().numpy(),\n    )\n\n    npt.assert_almost_equal(\n        g.nodes[\"n2\"].data[\"eq\"].detach().numpy(),\n        g_.nodes[\"n2\"].data[\"eq_ref\"].detach().numpy(),\n    )\n\n\ndef test_log_normal_normalize_reproduce():\n    import espaloma as esp\n\n    normalize = esp.data.normalize.DatasetLogNormalNormalize(\n        dataset=esp.data.esol(first=10).apply(\n            esp.graphs.legacy_force_field.LegacyForceField(\n                \"smirnoff99Frosst-1.1.0\"\n            ).parametrize,\n            in_place=True,\n        )\n    )\n\n    esol = esp.data.esol(first=1)\n\n    # do some typing\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    ).parametrize\n    esol.apply(param, in_place=True)  # this modify the original data\n\n    g = esol[0]\n\n    import copy\n\n    g_ = copy.deepcopy(g)\n\n    g = normalize.norm(g)\n\n    g.nodes[\"n2\"].data[\"k\"] = g.nodes[\"n2\"].data[\"k_ref\"]\n    g.nodes[\"n2\"].data[\"eq\"] = g.nodes[\"n2\"].data[\"eq_ref\"]\n\n    g = normalize.unnorm(g)\n\n    npt.assert_almost_equal(\n        g.nodes[\"n2\"].data[\"k\"].detach().numpy(),\n        g_.nodes[\"n2\"].data[\"k_ref\"].detach().numpy(),\n        decimal=1,\n    )\n\n    npt.assert_almost_equal(\n        g.nodes[\"n2\"].data[\"eq\"].detach().numpy(),\n        g_.nodes[\"n2\"].data[\"eq_ref\"].detach().numpy(),\n        decimal=1,\n    )\n"
  },
  {
    "path": "espaloma/data/tests/test_qcarchive.py",
    "content": "import pytest\n\n\ndef test_import():\n    import espaloma.data.qcarchive_utils\n\n\ndef test_get_graph():\n    from espaloma.data import qcarchive_utils\n\n    client = qcarchive_utils.get_client()\n    collection, record_names = qcarchive_utils.get_collection(client)\n    # The order records are received is not guaranteed, and can change if,\n    # e.g., the underlying database ends up being replaced by a copy during a database migration.\n    # as such we need to use a specific record name.\n    records_names_for_testing = ['c1c2c(c(c(c1f)n3cc(c3)o)cl)n(cc(c2=o)c(=o)[o-])c4c(cc(c(n4)n)f)f-3', 'c1c2c(cc(c1f)n3ccncc3)n(cc(c2=o)c(=o)[o-])c4cc4-0']\n\n    record_name = records_names_for_testing[0]\n    assert record_name in record_names\n\n    graph = qcarchive_utils.get_graph(collection, record_name)\n    assert graph is not None\n\n\n    graphs = qcarchive_utils.get_graphs(collection, records_names_for_testing)\n    assert len(graphs) == 2\n    assert graphs[0] is not None\n\n\ndef test_notsupported_dataset():\n    from espaloma.data import qcarchive_utils\n\n    name = \"DBH24\"\n    collection_type = \"reaction\"\n    collection, record_names = qcarchive_utils.get_collection(\n        qcarchive_utils.get_client(\"ml.qcarchive.molssi.org\"), collection_type, name\n    )\n    record_name = record_names[0]\n\n    with pytest.raises(Exception):\n        graph = qcarchive_utils.get_graph(collection, record_name, spec_name=\"spec_2\")\n\n\ndef test_get_torsiondrive():\n    from espaloma.data import qcarchive_utils\n    import numpy as np\n\n    record_name = \"[h]c1c(c(c(c([c:1]1[n:2]([c:3](=[o:4])c(=c([h])[h])[h])c([h])([h])[h])[h])[h])n(=o)=o)[h]\"\n\n    # example dataset \n    name = \"OpenFF Amide Torsion Set v1.0\"\n    collection_type = \"torsiondrive\"\n\n    collection, record_names = qcarchive_utils.get_collection(\n        qcarchive_utils.get_client(), collection_type, name\n    )\n    record_info = collection.get_record(record_name, specification_name=\"default\")\n\n    (\n        flat_angles,\n        xyz_in_order,\n        energies_in_order,\n        gradients_in_order,\n    ) = qcarchive_utils.fetch_td_record(record_info)\n\n    assert flat_angles.shape == (24,)\n    assert energies_in_order.shape == (24,)\n    assert gradients_in_order.shape == (24, 25, 3)\n    assert xyz_in_order.shape == (24, 25, 3)\n\n    assert np.isclose(energies_in_order[0], -722.2850260791969)\n    assert np.all(\n        flat_angles\n        == np.array(\n            [\n                -165,\n                -150,\n                -135,\n                -120,\n                -105,\n                -90,\n                -75,\n                -60,\n                -45,\n                -30,\n                -15,\n                0,\n                15,\n                30,\n                45,\n                60,\n                75,\n                90,\n                105,\n                120,\n                135,\n                150,\n                165,\n                180,\n            ]\n        )\n    )\n    assert np.allclose(\n        xyz_in_order[0][0], np.array([-0.66407807, -8.59922225, -0.02685972])\n    )\n"
  },
  {
    "path": "espaloma/data/tests/test_save_and_load.py",
    "content": "import pytest\n\n\ndef test_save_and_load():\n    import espaloma as esp\n\n    g = esp.Graph(\"C\")\n    ds = esp.data.dataset.GraphDataset([g])\n\n    # Temporary directory will be automatically cleaned up\n    from espaloma.data.utils import make_temp_directory\n\n    with make_temp_directory() as tmpdir:\n        import os\n\n        filename = os.path.join(tmpdir, \"ds\")\n\n        ds.save(filename)\n        new_ds = esp.data.dataset.GraphDataset.load(filename)\n"
  },
  {
    "path": "espaloma/data/utils.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport random\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport contextlib\n\nimport espaloma as esp\n\nOFFSETS = {\n    1: -0.500607632585,\n    6: -37.8302333826,\n    7: -54.5680045287,\n    8: -75.0362229210,\n}\n\n# ==============================================================================\n# UTILITY FUNCTIONS\n# ==============================================================================\n\n\n@contextlib.contextmanager\ndef make_temp_directory():\n    import tempfile, shutil\n\n    temp_dir = tempfile.mkdtemp()\n    try:\n        yield temp_dir\n    finally:\n        shutil.rmtree(temp_dir)\n\n\ndef sum_offsets(elements):\n    return sum([OFFSETS[element] for element in elements])\n\n\ndef from_csv(path, toolkit=\"rdkit\", smiles_col=-1, y_cols=[-2], seed=2666):\n    \"\"\"Read csv from file.\"\"\"\n\n    def _from_csv():\n        df = pd.read_csv(path)\n        df_smiles = df.iloc[:, smiles_col]\n        df_y = df.iloc[:, y_cols]\n\n        if toolkit == \"rdkit\":\n            from rdkit import Chem\n\n            mols = [Chem.MolFromSmiles(smiles) for smiles in df_smiles]\n            gs = [esp.HomogeneousGraph(mol) for mol in mols]\n\n        elif toolkit == \"openeye\":\n            from openeye import oechem\n\n            mols = [\n                oechem.OESmilesToMol(oechem.OEGraphMol(), smiles)\n                for smiles in df_smiles\n            ]\n            gs = [esp.HomogeneousGraph(mol) for mol in mols]\n\n        ds = list(zip(gs, list(torch.tensor(df_y.values))))\n\n        random.seed(seed)\n        random.shuffle(ds)\n\n        return ds\n\n    return _from_csv\n\n\ndef normalize(ds):\n    \"\"\"Get mean and std.\"\"\"\n\n    gs, ys = tuple(zip(*ds))\n    y_mean = np.mean(ys)\n    y_std = np.std(ys)\n\n    def norm(y):\n        return (y - y_mean) / y_std\n\n    def unnorm(y):\n        return y * y_std + y_mean\n\n    return y_mean, y_std, norm, unnorm\n\n\ndef split(ds, partition):\n    \"\"\"Split the dataset according to some partition.\"\"\"\n    n_data = len(ds)\n\n    # get the actual size of partition\n    partition = [int(n_data * x / sum(partition)) for x in partition]\n\n    ds_batched = []\n    idx = 0\n    for p_size in partition:\n        ds_batched.append(ds[idx : idx + p_size])\n        idx += p_size\n\n    return ds_batched\n\n\ndef batch(ds, batch_size, seed=2666):\n    \"\"\"Batch graphs and values after shuffling.\"\"\"\n    import dgl\n\n    # get the numebr of data\n    n_data_points = len(ds)\n    n_batches = n_data_points // batch_size  # drop the rest\n\n    random.seed(seed)\n    random.shuffle(ds)\n    gs, ys = tuple(zip(*ds))\n\n    gs_batched = [\n        dgl.batch(gs[idx * batch_size : (idx + 1) * batch_size])\n        for idx in range(n_batches)\n    ]\n\n    ys_batched = [\n        torch.stack(ys[idx * batch_size : (idx + 1) * batch_size], dim=0)\n        for idx in range(n_batches)\n    ]\n\n    return list(zip(gs_batched, ys_batched))\n\n\ndef collate_fn(graphs):\n    import dgl\n\n    return esp.HomogeneousGraph(dgl.batch(graphs))\n\n\ndef infer_mol_from_coordinates(\n    coordinates,\n    species,\n    smiles_ref=None,\n    coordinates_unit=\"angstrom\",\n):\n\n    # local import\n    from openeye import oechem\n    from openmm import unit\n    from openmm.unit import Quantity\n\n    if isinstance(coordinates_unit, str):\n  
      coordinates_unit = getattr(unit, coordinates_unit)\n\n    # make sure we have the coordinates\n    # in the unit system\n    coordinates = Quantity(coordinates, coordinates_unit).value_in_unit(\n        unit.angstrom  # to make openeye happy\n    )\n\n    # initialize molecule\n    mol = oechem.OEGraphMol()\n\n    if all(isinstance(symbol, str) for symbol in species):\n        [\n            mol.NewAtom(getattr(oechem, \"OEElemNo_\" + symbol))\n            for symbol in species\n        ]\n\n    elif all(isinstance(symbol, int) for symbol in species):\n        [\n            mol.NewAtom(\n                getattr(\n                    oechem, \"OEElemNo_\" + oechem.OEGetAtomicSymbol(symbol)\n                )\n            )\n            for symbol in species\n        ]\n\n    else:\n        raise RuntimeError(\n            \"The species can only be all strings or all integers.\"\n        )\n\n    mol.SetCoords(coordinates.reshape([-1]))\n    mol.SetDimension(3)\n    oechem.OEDetermineConnectivity(mol)\n    oechem.OEFindRingAtomsAndBonds(mol)\n    oechem.OEPerceiveBondOrders(mol)\n\n    if smiles_ref is not None:\n        smiles_can = oechem.OECreateCanSmiString(mol)\n        ims = oechem.oemolistream()\n        ims.SetFormat(oechem.OEFormat_SMI)\n        ims.openstring(smiles_ref)\n        mol_ref = next(ims.GetOEMols())\n        smiles_ref = oechem.OECreateCanSmiString(mol_ref)\n        assert (\n            smiles_ref == smiles_can\n        ), \"SMILES different. Input is %s, ref is %s\" % (\n            smiles_can,\n            smiles_ref,\n        )\n\n    from openff.toolkit.topology import Molecule\n\n    _mol = Molecule.from_openeye(mol, allow_undefined_stereo=True)\n    g = esp.Graph(_mol)\n\n    return g\n"
  },
  {
    "path": "espaloma/graphs/__init__.py",
    "content": "\"\"\"The basic data structure of espaloma---graph is represent a molecular system\nand provide access to `dgl.DGLHeteroGraph` and `openff.toolkit.topology.Molecule.\n\n\"\"\"\nfrom . import deploy, utils\nfrom .legacy_force_field import *\n"
  },
  {
    "path": "espaloma/graphs/deploy.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport numpy as np\nimport rdkit\nimport torch\nfrom openff.toolkit.typing.engines.smirnoff import ForceField\nimport espaloma as esp\nfrom openmm import unit\nfrom openmm.unit import Quantity\nimport math\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\nOPENMM_LENGTH_UNIT = unit.nanometer\nOPENMM_ANGLE_UNIT = unit.radian\nOPENMM_ENERGY_UNIT = unit.kilojoule_per_mole\n\nOPENMM_BOND_EQ_UNIT = OPENMM_LENGTH_UNIT\nOPENMM_ANGLE_EQ_UNIT = OPENMM_ANGLE_UNIT\nOPENMM_TORSION_K_UNIT = OPENMM_ENERGY_UNIT\nOPENMM_TORSION_PHASE_UNIT = OPENMM_ANGLE_UNIT\nOPENMM_BOND_K_UNIT = OPENMM_ENERGY_UNIT / (OPENMM_LENGTH_UNIT**2)\nOPENMM_ANGLE_K_UNIT = OPENMM_ENERGY_UNIT / (OPENMM_ANGLE_UNIT**2)\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\n\n\ndef load_forcefield(forcefield=\"openff_unconstrained-2.2.1\"):\n    # get a forcefield\n    try:\n        ff = ForceField(\"%s.offxml\" % forcefield)\n    except Exception as e:\n        print(e)\n        raise NotImplementedError\n    return ff\n\n\ndef openmm_system_from_graph(\n    g,\n    forcefield=\"openff_unconstrained-2.1.1\",\n    suffix=\"\",\n    charge_method=\"nn\",\n    create_system_kwargs={},\n):\n    \"\"\"Construct an openmm system from `espaloma.Graph`.\n\n    Parameters\n    ----------\n    g : `espaloma.Graph`\n        Input graph.\n\n    forcefield : `str`, optional, default='openff_unconstrained-2.1.1'\n        Name of the force field. 
def openmm_system_from_graph(\n    g,\n    forcefield=\"openff_unconstrained-2.1.1\",\n    suffix=\"\",\n    charge_method=\"nn\",\n    create_system_kwargs={},\n):\n    \"\"\"Construct an openmm system from `espaloma.Graph`.\n\n    Parameters\n    ----------\n    g : `espaloma.Graph`\n        Input graph.\n\n    forcefield : `str`, optional, default='openff_unconstrained-2.1.1'\n        Name of the force field; must be an Open Force Field (SMIRNOFF) force field.\n        (this forcefield will be used to assign nonbonded parameters, but all of its valence parameters will be overwritten)\n\n    suffix : `str`\n        Suffix for the force terms.\n\n    charge_method : str, optional, default='nn'\n        Method to use for assigning partial charges:\n        'nn' : Assign partial charges from the espaloma graph net model\n        'am1-bcc' : Allow the OpenFF toolkit to assign AM1-BCC charges using default backend\n        'gasteiger' : Assign Gasteiger partial charges (not recommended)\n        'from-molecule' : Use partial charges provided in the original `Molecule` object\n\n    Returns\n    -------\n    sys : `openmm.System`\n        Constructed single-molecule OpenMM system.\n\n\n    \"\"\"\n    ff = load_forcefield(forcefield)\n\n    # get the mapping between position and indices\n    bond_lookup = {\n        tuple(idxs.detach().numpy()): position\n        for position, idxs in enumerate(g.nodes[\"n2\"].data[\"idxs\"])\n    }\n\n    angle_lookup = {\n        tuple(idxs.detach().numpy()): position\n        for position, idxs in enumerate(g.nodes[\"n3\"].data[\"idxs\"])\n    }\n\n    if charge_method == \"gasteiger\":\n        # from rdkit.Chem.AllChem import ComputeGasteigerCharges\n        # rdkit_mol = g.mol.to_rdkit()\n        # ComputeGasteigerCharges(rdkit_mol)\n        # charges = [atom.GetDoubleProp(\"_GasteigerCharge\") for atom in rdkit_mol.GetAtoms()]\n        g.mol.assign_partial_charges(\"gasteiger\")\n        sys = ff.create_openmm_system(\n            g.mol.to_topology(), charge_from_molecules=[g.mol]\n        )\n\n    elif charge_method == \"am1-bcc\":\n        g.mol.assign_partial_charges(\"am1bcc\")\n        sys = ff.create_openmm_system(\n            g.mol.to_topology(), charge_from_molecules=[g.mol]\n        )\n\n    elif charge_method == \"from-molecule\":\n        sys = ff.create_openmm_system(\n            g.mol.to_topology(), charge_from_molecules=[g.mol]\n        )\n\n    elif charge_method == \"nn\":\n        g.mol.partial_charges = unit.elementary_charge * g.nodes[\"n1\"].data[\n            \"q\"\n        ].flatten().detach().cpu().numpy().astype(\n            np.float64,\n        )\n        sys = ff.create_openmm_system(\n            g.mol.to_topology(),\n            charge_from_molecules=[g.mol],\n            allow_nonintegral_charges=True,\n        )\n\n    else:\n        raise RuntimeError(\n            \"Charge method %s is not supported.\" % charge_method\n        )\n\n    for force in sys.getForces():\n        name = force.__class__.__name__\n        if \"HarmonicBondForce\" in name:\n            assert force.getNumBonds() * 2 == g.heterograph.number_of_nodes(\n                \"n2\"\n            )\n\n            for idx in range(force.getNumBonds()):\n                idx0, idx1, eq, k = force.getBondParameters(idx)\n                position = bond_lookup[(idx0, idx1)]\n                _eq = (\n                    g.nodes[\"n2\"]\n                    .data[\"eq%s\" % suffix][position]\n                    .detach()\n                    .numpy()\n                    .item()\n                )\n                _k = (\n                    g.nodes[\"n2\"]\n                    .data[\"k%s\" % suffix][position]\n                    .detach()\n                    .numpy()\n                    .item()\n                )\n\n                _eq = Quantity(  # bond length\n                    _eq,\n                    esp.units.DISTANCE_UNIT,\n                ).value_in_unit(OPENMM_BOND_EQ_UNIT)\n\n                _k = Quantity(  # bond force constant:\n                    # since everything is enumerated twice in espaloma\n                    # and once in OpenMM,\n                    # we insert a coefficient of 2.0\n                    _k,\n                    esp.units.FORCE_CONSTANT_UNIT,\n                ).value_in_unit(OPENMM_BOND_K_UNIT)\n\n                force.setBondParameters(idx, idx0, idx1, _eq, _k)\n\n        if \"HarmonicAngleForce\" in name:\n            assert force.getNumAngles() * 2 == g.heterograph.number_of_nodes(\n                \"n3\"\n            )\n\n            for idx in range(force.getNumAngles()):\n                idx0, idx1, idx2, eq, k = force.getAngleParameters(idx)\n                position = angle_lookup[(idx0, idx1, idx2)]\n                _eq = (\n                    g.nodes[\"n3\"]\n                    .data[\"eq%s\" % suffix][position]\n                    .detach()\n                    .numpy()\n                    .item()\n                )\n                _k = (\n                    g.nodes[\"n3\"]\n                    .data[\"k%s\" % suffix][position]\n                    .detach()\n                    .numpy()\n                    .item()\n                )\n\n                _eq = Quantity(\n                    _eq,\n                    esp.units.ANGLE_UNIT,\n                ).value_in_unit(OPENMM_ANGLE_EQ_UNIT)\n\n                _k = Quantity(  # force constant\n                    # since everything is enumerated twice in espaloma\n                    # and once in OpenMM,\n                    # we insert a coefficient of 2.0\n                    _k,\n                    esp.units.ANGLE_FORCE_CONSTANT_UNIT,\n                ).value_in_unit(OPENMM_ANGLE_K_UNIT)\n\n                force.setAngleParameters(idx, idx0, idx1, idx2, _eq, _k)\n\n        if \"PeriodicTorsionForce\" in name:\n            number_of_torsions = force.getNumTorsions()\n            if (\n                \"periodicity%s\" % suffix not in g.nodes[\"n4\"].data\n                or \"phases%s\" % suffix not in g.nodes[\"n4\"].data\n            ):\n\n                g.nodes[\"n4\"].data[\"periodicity%s\" % suffix] = torch.arange(\n                    1, 7\n                )[None, :].repeat(g.heterograph.number_of_nodes(\"n4\"), 1)\n\n                g.nodes[\"n4\"].data[\"phases%s\" % suffix] = torch.zeros(\n                    g.heterograph.number_of_nodes(\"n4\"), 6\n                )\n\n
g.nodes[\"n4_improper\"].data[\n                    \"periodicity%s\" % suffix\n                ] = torch.arange(1, 7)[None, :].repeat(\n                    g.heterograph.number_of_nodes(\"n4_improper\"), 1\n                )\n\n                g.nodes[\"n4_improper\"].data[\n                    \"phases%s\" % suffix\n                ] = torch.zeros(\n                    g.heterograph.number_of_nodes(\"n4_improper\"), 6\n                )\n\n            count_idx = 0\n            for idx in range(g.heterograph.number_of_nodes(\"n4\")):\n                idx0 = g.nodes[\"n4\"].data[\"idxs\"][idx, 0].item()\n                idx1 = g.nodes[\"n4\"].data[\"idxs\"][idx, 1].item()\n                idx2 = g.nodes[\"n4\"].data[\"idxs\"][idx, 2].item()\n                idx3 = g.nodes[\"n4\"].data[\"idxs\"][idx, 3].item()\n\n                # assuming both (a,b,c,d) and (d,c,b,a) are listed for every torsion, only pick one of the orderings\n                if idx0 < idx3:\n                    periodicities = g.nodes[\"n4\"].data[\n                        \"periodicity%s\" % suffix\n                    ][idx]\n                    phases = g.nodes[\"n4\"].data[\"phases%s\" % suffix][idx]\n                    ks = g.nodes[\"n4\"].data[\"k%s\" % suffix][idx]\n                    for sub_idx in range(ks.flatten().shape[0]):\n                        k = ks[sub_idx].item()\n                        if k != 0.0:\n                            _periodicity = periodicities[sub_idx].item()\n                            _phase = phases[sub_idx].item()\n\n                            if k < 0:\n                                k = -k\n                                _phase = math.pi - _phase\n\n                            k = Quantity(\n                                k,\n                                esp.units.ENERGY_UNIT,\n                            ).value_in_unit(\n                                OPENMM_ENERGY_UNIT,\n                            )\n\n                            if count_idx < number_of_torsions:\n                                force.setTorsionParameters(\n                                    # since everything is enumerated\n                                    # twice in espaloma\n                                    # and once in OpenMM,\n                                    # we insert a coefficient of 2.0\n                                    count_idx,\n                                    idx0,\n                                    idx1,\n                                    idx2,\n                                    idx3,\n                                    _periodicity,\n                                    _phase,\n                                    k,\n                                )\n\n                            else:\n                                force.addTorsion(\n                                    # since everything is enumerated\n                                    # twice in espaloma\n                                    # and once in OpenMM,\n                                    # we insert a coefficient of 2.0\n                                    idx0,\n                                    idx1,\n                                    idx2,\n                                    idx3,\n                                    _periodicity,\n                                    _phase,\n                                    k,\n                                )\n\n                            count_idx += 1\n\n            if \"k%s\" % suffix in g.nodes[\"n4_improper\"].data:\n                for 
                for idx in range(\n                    g.heterograph.number_of_nodes(\"n4_improper\")\n                ):\n                    idx0 = g.nodes[\"n4_improper\"].data[\"idxs\"][idx, 0].item()\n                    idx1 = g.nodes[\"n4_improper\"].data[\"idxs\"][idx, 1].item()\n                    idx2 = g.nodes[\"n4_improper\"].data[\"idxs\"][idx, 2].item()\n                    idx3 = g.nodes[\"n4_improper\"].data[\"idxs\"][idx, 3].item()\n\n                    periodicities = g.nodes[\"n4_improper\"].data[\n                        \"periodicity%s\" % suffix\n                    ][idx]\n                    phases = g.nodes[\"n4_improper\"].data[\"phases%s\" % suffix][\n                        idx\n                    ]\n                    ks = g.nodes[\"n4_improper\"].data[\"k%s\" % suffix][idx]\n                    for sub_idx in range(ks.flatten().shape[0]):\n                        k = ks[sub_idx].item()\n                        if k != 0.0:\n                            _periodicity = periodicities[sub_idx].item()\n                            _phase = phases[sub_idx].item()\n\n                            if k < 0:\n                                k = -k\n                                _phase = math.pi - _phase\n\n                            k = Quantity(\n                                k,\n                                esp.units.ENERGY_UNIT,\n                            ).value_in_unit(\n                                OPENMM_ENERGY_UNIT,\n                            )\n\n                            if count_idx < number_of_torsions:\n                                force.setTorsionParameters(\n                                    # since everything is enumerated\n                                    # twice in espaloma\n                                    # and once in OpenMM,\n                                    # we insert a coefficient of 2.0\n                                    count_idx,\n                                    idx0,\n                                    idx1,\n                                    idx2,\n                                    idx3,\n                                    _periodicity,\n                                    _phase,\n                                    0.5 * k,\n                                )\n\n                            else:\n                                force.addTorsion(\n                                    # since everything is enumerated\n                                    # twice in espaloma\n                                    # and once in OpenMM,\n                                    # we insert a coefficient of 2.0\n                                    idx0,\n                                    idx1,\n                                    idx2,\n                                    idx3,\n                                    _periodicity,\n                                    _phase,\n                                    0.5 * k,\n                                )\n\n                            count_idx += 1\n\n    return sys\n"
  },
  {
    "path": "espaloma/graphs/graph.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport abc\nimport io\nimport openff.toolkit\n\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass BaseGraph(abc.ABC):\n    \"\"\"Base class of graph.\"\"\"\n\n    def __init__(self):\n        super(BaseGraph, self).__init__()\n\n\nclass Graph(BaseGraph):\n    \"\"\"A unified graph object that support translation to and from\n    message-passing graphs and MM factor graph.\n\n    Methods\n    -------\n    save(path)\n        Save graph to file.\n\n    load(path)\n        Load a graph from path.\n\n    Note\n    ----\n    This object provides access to popular attributes of homograph and\n    heterograph.\n\n    This object also provides access to `ndata` and `edata` from the heterograph.\n\n    Examples\n    --------\n    >>> g0 = esp.Graph(\"C\")\n    >>> g1 = esp.Graph(Molecule.from_smiles(\"C\"))\n    >>> assert g0 == g1\n\n    \"\"\"\n\n    def __init__(self, mol=None, homograph=None, heterograph=None):\n        # TODO : more pythonic way allow multiple constructors:\n        #   Graph.from_smiles(...), Graph.from_mol(...), Graph.from_homograph(...), ...\n        #   rather than Graph(mol=None, homograph=None, ...)\n\n        # input molecule\n        if isinstance(mol, str):\n            from openff.toolkit.topology import Molecule\n\n            mol = Molecule.from_smiles(mol, allow_undefined_stereo=True)\n        if mol is not None and homograph is None and heterograph is None:\n            homograph = self.get_homograph_from_mol(mol)\n\n        if homograph is not None and heterograph is None:\n            heterograph = self.get_heterograph_from_graph_and_mol(\n                homograph, mol\n            )\n\n        self.mol = mol\n        self.homograph = homograph\n        self.heterograph = heterograph\n\n    def save(self, path):\n        import os\n        import json\n        import dgl\n\n        os.mkdir(path)\n        dgl.save_graphs(path + \"/homograph.bin\", [self.homograph])\n        dgl.save_graphs(path + \"/heterograph.bin\", [self.heterograph])\n        with open(path + \"/mol.json\", \"w\") as f_handle:\n            json.dump(self.mol.to_json(), f_handle)\n\n    @classmethod\n    def load(cls, path):\n        import json\n        import dgl\n\n        homograph = dgl.load_graphs(path + \"/homograph.bin\")[0][0]\n        heterograph = dgl.load_graphs(path + \"/heterograph.bin\")[0][0]\n\n        with open(path + \"/mol.json\", \"r\") as f_handle:\n            mol = json.load(f_handle)\n        from openff.toolkit.topology import Molecule\n\n        # With OFF toolkit >=0.11, from_json requires the \"hierarchy_schemes\" key\n        # which is not created with previous toolkit versions. 
    @classmethod\n    def load(cls, path):\n        import json\n        import dgl\n\n        homograph = dgl.load_graphs(path + \"/homograph.bin\")[0][0]\n        heterograph = dgl.load_graphs(path + \"/heterograph.bin\")[0][0]\n\n        with open(path + \"/mol.json\", \"r\") as f_handle:\n            mol = json.load(f_handle)\n        from openff.toolkit.topology import Molecule\n\n        # With OFF toolkit >=0.11, from_json requires the \"hierarchy_schemes\" key,\n        # which is not created by previous toolkit versions. That means from_json\n        # errors out when loading molecules that were JSON-serialized with older\n        # toolkit versions.\n        try:\n            mol = Molecule.from_json(mol)\n        except KeyError:\n            # this probably means the hierarchy_schemes key wasn't found\n            mol_dict = json.load(io.StringIO(mol))\n            if \"hierarchy_schemes\" not in mol_dict.keys():\n                mol_dict[\"hierarchy_schemes\"] = dict()  # Default to empty dict if not present\n            mol = Molecule.from_dict(mol_dict)\n\n        g = cls(mol=mol, homograph=homograph, heterograph=heterograph)\n        return g\n\n    @staticmethod\n    def get_homograph_from_mol(mol):\n        assert isinstance(\n            mol, openff.toolkit.topology.Molecule\n        ), \"mol can only be an OFF Molecule object.\"\n\n        # TODO:\n        # rewrite this using OFF-generic grammar\n        # graph = esp.graphs.utils.read_homogeneous_graph.from_rdkit_mol(\n        #     mol.to_rdkit()\n        # )\n\n        graph = (\n            esp.graphs.utils.read_homogeneous_graph.from_openff_toolkit_mol(\n                mol\n            )\n        )\n        return graph\n\n    @staticmethod\n    def get_heterograph_from_graph_and_mol(graph, mol):\n        import dgl\n\n        assert isinstance(\n            graph, dgl.DGLGraph\n        ), \"graph can only be a dgl Graph object.\"\n\n        heterograph = esp.graphs.utils.read_heterogeneous_graph.from_homogeneous_and_mol(\n            graph, mol\n        )\n\n        return heterograph\n\n    #\n    # @property\n    # def mol(self):\n    #     return self._mol\n    #\n    # @property\n    # def homograph(self):\n    #     return self._homograph\n    #\n    # @property\n    # def heterograph(self):\n    #     return self._heterograph\n\n    @property\n    def ndata(self):\n        return self.homograph.ndata\n\n    @property\n    def edata(self):\n        return self.homograph.edata\n\n    @property\n    def nodes(self):\n        return self.heterograph.nodes\n"
  },
  {
    "path": "espaloma/graphs/legacy_force_field.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport rdkit\nimport torch\nfrom openff.toolkit import Molecule\nimport espaloma as esp\n\nfrom openmmforcefields.generators import SystemGenerator\nimport openmm\nfrom openmm import unit\nfrom openmm.app import Simulation\nfrom openmm.unit import Quantity\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\nREDUNDANT_TYPES = {\n    \"cd\": \"cc\",\n    \"cf\": \"ce\",\n    \"cq\": \"cp\",\n    \"pd\": \"pc\",\n    \"pf\": \"pe\",\n    \"nd\": \"nc\",\n}\n\n# simulation specs\nTEMPERATURE = 350 * unit.kelvin\nSTEP_SIZE = 1.0 * unit.femtosecond\nCOLLISION_RATE = 1.0 / unit.picosecond\nEPSILON_MIN = 0.05 * unit.kilojoules_per_mole\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass LegacyForceField:\n    \"\"\"Class to hold legacy forcefield for typing and parameter assignment.\n\n    Parameters\n    ----------\n    forcefield : string\n        name and version of the forcefield.\n\n    Methods\n    -------\n    parametrize()\n        Parametrize a molecular system.\n\n    typing()\n        Provide legacy typing for a molecular system.\n\n    \"\"\"\n\n    def __init__(self, forcefield=\"gaff-1.81\"):\n        self.forcefield = forcefield\n        self._prepare_forcefield()\n\n    @staticmethod\n    def _convert_to_off(mol):\n        if isinstance(mol, esp.Graph):\n            return mol.mol\n\n        elif isinstance(mol, Molecule):\n            return mol\n        elif isinstance(mol, rdkit.Chem.rdchem.Mol):\n            return Molecule.from_rdkit(mol)\n        elif \"openeye\" in str(\n            type(mol)\n        ):  # because we don't want to depend on OE\n            return Molecule.from_openeye(mol)\n\n    def _prepare_forcefield(self):\n\n        if \"gaff\" in self.forcefield:\n            self._prepare_gaff()\n\n        elif \"smirnoff\" in self.forcefield:\n            # do nothing for now\n            self._prepare_smirnoff()\n\n        elif \"openff\" in self.forcefield:\n            self._prepare_openff()\n\n        else:\n            raise NotImplementedError\n\n    def _prepare_openff(self):\n\n        from openff.toolkit import ForceField\n\n        self.FF = ForceField(\"%s.offxml\" % self.forcefield)\n\n    def _prepare_smirnoff(self):\n\n        from openff.toolkit import ForceField\n\n        self.FF = ForceField(\"%s.offxml\" % self.forcefield)\n\n    def _prepare_gaff(self):\n        import os\n        import xml.etree.ElementTree as ET\n\n        import openmmforcefields\n\n        # get the openff.toolkits path\n        openmmforcefields_path = os.path.dirname(openmmforcefields.__file__)\n\n        # get the xml path\n        ffxml_path = (\n            openmmforcefields_path\n            + \"/ffxml/amber/gaff/ffxml/\"\n            + self.forcefield\n            + \".xml\"\n        )\n\n        # parse xml\n        tree = ET.parse(ffxml_path)\n        root = tree.getroot()\n        nonbonded = root.find(\"NonbondedForce\")\n        atom_types = [atom.get(\"class\") for atom in nonbonded.findall(\"Atom\")]\n\n        # remove redundant types\n        [atom_types.remove(bad_type) for bad_type in REDUNDANT_TYPES.keys()]\n\n   
        # compose the translation dictionaries\n        str_2_idx = dict(zip(atom_types, range(len(atom_types))))\n        idx_2_str = dict(zip(range(len(atom_types)), atom_types))\n\n        # provide mapping for redundant types\n        for bad_type, good_type in REDUNDANT_TYPES.items():\n            str_2_idx[bad_type] = str_2_idx[good_type]\n\n        # make translation dictionaries attributes of self\n        self._str_2_idx = str_2_idx\n        self._idx_2_str = idx_2_str\n\n    def _type_gaff(self, g):\n        \"\"\"Type a molecular graph using GAFF force fields.\"\"\"\n        # assert the forcefield is indeed of gaff family\n        assert \"gaff\" in self.forcefield\n\n        # grab the openff.toolkit molecule from the graph\n        mol = g.mol\n\n        # import template generator\n        from openmmforcefields.generators import GAFFTemplateGenerator\n\n        gaff = GAFFTemplateGenerator(\n            molecules=mol, forcefield=self.forcefield\n        )\n\n        # create temporary directory for running antechamber\n        import os\n        import shutil\n        import tempfile\n\n        tempdir = tempfile.mkdtemp()\n        prefix = \"molecule\"\n        input_sdf_filename = os.path.join(tempdir, prefix + \".sdf\")\n        gaff_mol2_filename = os.path.join(tempdir, prefix + \".gaff.mol2\")\n        frcmod_filename = os.path.join(tempdir, prefix + \".frcmod\")\n\n        # write sdf for input\n        mol.to_file(input_sdf_filename, file_format=\"sdf\")\n\n        # run antechamber\n        gaff._run_antechamber(\n            molecule_filename=input_sdf_filename,\n            input_format=\"mdl\",\n            gaff_mol2_filename=gaff_mol2_filename,\n            frcmod_filename=frcmod_filename,\n        )\n\n        gaff._read_gaff_atom_types_from_mol2(gaff_mol2_filename, mol)\n        gaff_types = [atom.gaff_type for atom in mol.atoms]\n        shutil.rmtree(tempdir)\n\n        # put types into graph object\n        g.nodes[\"n1\"].data[\"legacy_typing\"] = torch.tensor(\n            [self._str_2_idx[atom] for atom in gaff_types]\n        )\n\n        return g\n\n    def _parametrize_gaff(self, g, n_max_phases=6):\n        from openmmforcefields.generators import SystemGenerator\n\n        # define a system generator\n        system_generator = SystemGenerator(\n            small_molecule_forcefield=self.forcefield,\n        )\n\n        mol = g.mol\n        # mol.assign_partial_charges(\"formal_charge\")\n        # create system\n        sys = system_generator.create_system(\n            topology=mol.to_topology().to_openmm(),\n            molecules=mol,\n        )\n\n        bond_lookup = {\n            tuple(idxs.detach().numpy()): position\n            for position, idxs in enumerate(g.nodes[\"n2\"].data[\"idxs\"])\n        }\n\n        angle_lookup = {\n            tuple(idxs.detach().numpy()): position\n            for position, idxs in enumerate(g.nodes[\"n3\"].data[\"idxs\"])\n        }\n\n        torsion_lookup = {\n            tuple(idxs.detach().numpy()): position\n            for position, idxs in enumerate(g.nodes[\"n4\"].data[\"idxs\"])\n        }\n\n        improper_lookup = {\n            tuple(idxs.detach().numpy()): position\n            for position, idxs in enumerate(\n                g.nodes[\"n4_improper\"].data[\"idxs\"]\n            )\n        }\n\n
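        # Pre-allocate (n_torsions, n_max_phases) parameter tables for the\n        # proper and improper torsions; a slot with k == 0 counts as empty\n        # when the tables are filled from the OpenMM system below.\n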
        torsion_phases = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4\"),\n            n_max_phases,\n        )\n\n        torsion_periodicities = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4\"),\n            n_max_phases,\n        )\n\n        torsion_ks = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4\"),\n            n_max_phases,\n        )\n\n        improper_phases = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4_improper\"),\n            n_max_phases,\n        )\n\n        improper_periodicities = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4_improper\"),\n            n_max_phases,\n        )\n\n        improper_ks = torch.zeros(\n            g.heterograph.number_of_nodes(\"n4_improper\"),\n            n_max_phases,\n        )\n\n        for force in sys.getForces():\n            name = force.__class__.__name__\n            if \"HarmonicBondForce\" in name:\n                assert (\n                    force.getNumBonds() * 2\n                    == g.heterograph.number_of_nodes(\"n2\")\n                )\n\n                g.nodes[\"n2\"].data[\"eq_ref\"] = torch.zeros(\n                    force.getNumBonds() * 2, 1\n                )\n\n                g.nodes[\"n2\"].data[\"k_ref\"] = torch.zeros(\n                    force.getNumBonds() * 2, 1\n                )\n\n                for idx in range(force.getNumBonds()):\n                    idx0, idx1, eq, k = force.getBondParameters(idx)\n\n                    position = bond_lookup[(idx0, idx1)]\n                    g.nodes[\"n2\"].data[\"eq_ref\"][position] = eq.value_in_unit(\n                        esp.units.DISTANCE_UNIT,\n                    )\n                    g.nodes[\"n2\"].data[\"k_ref\"][position] = k.value_in_unit(\n                        esp.units.FORCE_CONSTANT_UNIT,\n                    )\n\n                    position = bond_lookup[(idx1, idx0)]\n                    g.nodes[\"n2\"].data[\"eq_ref\"][position] = eq.value_in_unit(\n                        esp.units.DISTANCE_UNIT,\n                    )\n                    g.nodes[\"n2\"].data[\"k_ref\"][position] = k.value_in_unit(\n                        esp.units.FORCE_CONSTANT_UNIT,\n                    )\n\n            if \"HarmonicAngleForce\" in name:\n                assert (\n                    force.getNumAngles() * 2\n                    == g.heterograph.number_of_nodes(\"n3\")\n                )\n\n                g.nodes[\"n3\"].data[\"eq_ref\"] = torch.zeros(\n                    force.getNumAngles() * 2, 1\n                )\n\n                g.nodes[\"n3\"].data[\"k_ref\"] = torch.zeros(\n                    force.getNumAngles() * 2, 1\n                )\n\n                for idx in range(force.getNumAngles()):\n                    idx0, idx1, idx2, eq, k = force.getAngleParameters(idx)\n\n                    position = angle_lookup[(idx0, idx1, idx2)]\n                    g.nodes[\"n3\"].data[\"eq_ref\"][position] = eq.value_in_unit(\n                        esp.units.ANGLE_UNIT,\n                    )\n                    g.nodes[\"n3\"].data[\"k_ref\"][position] = k.value_in_unit(\n                        esp.units.ANGLE_FORCE_CONSTANT_UNIT,\n                    )\n\n                    position = angle_lookup[(idx2, idx1, idx0)]\n                    g.nodes[\"n3\"].data[\"eq_ref\"][position] = eq.value_in_unit(\n                        esp.units.ANGLE_UNIT,\n                    )\n                    g.nodes[\"n3\"].data[\"k_ref\"][position] = k.value_in_unit(\n                        esp.units.ANGLE_FORCE_CONSTANT_UNIT,\n                    )\n\n            if \"PeriodicTorsionForce\" in name:\n
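                # Each OpenMM torsion term is slotted into the first empty\n                # phase column, and its barrier height k is split evenly\n                # (0.5 * k each) between the forward and reverse orderings\n                # of the same torsion.\n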
                for idx in range(force.getNumTorsions()):\n                    (\n                        idx0,\n                        idx1,\n                        idx2,\n                        idx3,\n                        periodicity,\n                        phase,\n                        k,\n                    ) = force.getTorsionParameters(idx)\n\n                    if (idx0, idx1, idx2, idx3) in torsion_lookup:\n                        position = torsion_lookup[(idx0, idx1, idx2, idx3)]\n                        for sub_idx in range(n_max_phases):\n                            if torsion_ks[position, sub_idx] == 0:\n                                torsion_ks[\n                                    position, sub_idx\n                                ] = 0.5 * k.value_in_unit(\n                                    esp.units.ENERGY_UNIT\n                                )\n                                torsion_phases[\n                                    position, sub_idx\n                                ] = phase.value_in_unit(esp.units.ANGLE_UNIT)\n                                torsion_periodicities[\n                                    position, sub_idx\n                                ] = periodicity\n\n                                position = torsion_lookup[\n                                    (idx3, idx2, idx1, idx0)\n                                ]\n                                torsion_ks[\n                                    position, sub_idx\n                                ] = 0.5 * k.value_in_unit(\n                                    esp.units.ENERGY_UNIT\n                                )\n                                torsion_phases[\n                                    position, sub_idx\n                                ] = phase.value_in_unit(esp.units.ANGLE_UNIT)\n                                torsion_periodicities[\n                                    position, sub_idx\n                                ] = periodicity\n                                break\n\n        # write the assembled torsion tables into the graph once, after all\n        # forces have been visited\n        g.heterograph.apply_nodes(\n            lambda nodes: {\n                \"k_ref\": torsion_ks,\n                \"periodicity_ref\": torsion_periodicities,\n                \"phases_ref\": torsion_phases,\n            },\n            ntype=\"n4\",\n        )\n\n        # TODO: GAFF improper torsions are collected above but not yet\n        # written into the graph.\n        \"\"\"\n        g.heterograph.apply_nodes(\n            lambda nodes: {\n                \"k_ref\": improper_ks,\n                \"periodicity_ref\": improper_periodicities,\n                \"phases_ref\": improper_phases,\n            },\n            ntype=\"n4_improper\"\n        )\n        \"\"\"\n\n        return g\n\n    def _parametrize_smirnoff(self, g):\n        forces = self.FF.label_molecules(g.mol.to_topology())[0]\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"k_ref\": 2.0\n                * torch.Tensor(  # OpenFF records 1/2k as param\n                    [\n                        forces[\"Bonds\"][\n                            tuple(node.data[\"idxs\"][idx].numpy())\n                        ].k.to_openmm().value_in_unit(esp.units.FORCE_CONSTANT_UNIT)\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                )[:, None]\n            },\n            ntype=\"n2\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"eq_ref\": torch.Tensor(\n                    [\n                        forces[\"Bonds\"][\n                            tuple(node.data[\"idxs\"][idx].numpy())\n                        ].length.to_openmm().value_in_unit(esp.units.DISTANCE_UNIT)\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                )[:, None]\n            },\n            ntype=\"n2\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"k_ref\": 2.0\n                * torch.Tensor(  # OpenFF records 1/2k as param\n                    [\n                        forces[\"Angles\"][\n                            tuple(node.data[\"idxs\"][idx].numpy())\n                        ].k.to_openmm().value_in_unit(esp.units.ANGLE_FORCE_CONSTANT_UNIT)\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                )[:, None]\n            },\n            ntype=\"n3\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"eq_ref\": torch.Tensor(\n                    [\n                        forces[\"Angles\"][\n                            tuple(node.data[\"idxs\"][idx].numpy())\n                        ].angle.to_openmm().value_in_unit(esp.units.ANGLE_UNIT)\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                )[:, None]\n            },\n            ntype=\"n3\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"epsilon_ref\": torch.Tensor(\n                    [\n                        forces[\"vdW\"][(idx,)].epsilon.to_openmm().value_in_unit(\n                            esp.units.ENERGY_UNIT\n                        )\n                        for idx in range(g.heterograph.number_of_nodes(\"n1\"))\n                    ]\n                )[:, None]\n            },\n            ntype=\"n1\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda 
node: {\n                \"sigma_ref\": torch.Tensor(\n                    [\n                        forces[\"vdW\"][(idx,)].rmin_half.to_openmm().value_in_unit(\n                            esp.units.DISTANCE_UNIT\n                        )\n                        for idx in range(g.heterograph.number_of_nodes(\"n1\"))\n                    ]\n                )[:, None]\n            },\n            ntype=\"n1\",\n        )\n\n        def apply_torsion(node, n_max_phases=6):\n            phases = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4\"),\n                n_max_phases,\n            )\n\n            periodicity = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4\"),\n                n_max_phases,\n            )\n\n            k = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4\"),\n                n_max_phases,\n            )\n\n            force = forces[\"ProperTorsions\"]\n\n            for idx in range(g.heterograph.number_of_nodes(\"n4\")):\n                idxs = tuple(node.data[\"idxs\"][idx].numpy())\n                if idxs in force:\n                    _force = force[idxs]\n                    for sub_idx in range(len(_force.periodicity)):\n                        if hasattr(_force, \"k%s\" % sub_idx):\n                            k[idx, sub_idx] = getattr(\n                                _force, \"k%s\" % sub_idx\n                            ).to_openmm().value_in_unit(esp.units.ENERGY_UNIT)\n\n                            phases[idx, sub_idx] = getattr(\n                                _force, \"phase%s\" % sub_idx\n                            ).to_openmm().value_in_unit(esp.units.ANGLE_UNIT)\n\n                            periodicity[idx, sub_idx] = getattr(\n                                _force, \"periodicity%s\" % sub_idx\n                            )\n\n            return {\n                \"k_ref\": k,\n                \"periodicity_ref\": periodicity,\n                \"phases_ref\": phases,\n            }\n\n        def apply_improper_torsion(node, n_max_phases=6):\n            phases = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4_improper\"),\n                n_max_phases,\n            )\n\n            periodicity = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4_improper\"),\n                n_max_phases,\n            )\n\n            k = torch.zeros(\n                g.heterograph.number_of_nodes(\"n4_improper\"),\n                n_max_phases,\n            )\n\n            force = forces[\"ImproperTorsions\"]\n\n            for idx in range(g.heterograph.number_of_nodes(\"n4_improper\")):\n                idxs = tuple(node.data[\"idxs\"][idx].numpy())\n                if idxs in force:\n                    _force = force[idxs]\n                    for sub_idx in range(len(_force.periodicity)):\n\n                        if hasattr(_force, \"k%s\" % sub_idx):\n                            k[idx, sub_idx] = getattr(\n                                _force, \"k%s\" % sub_idx\n                            ).to_openmm().value_in_unit(esp.units.ENERGY_UNIT)\n\n                            phases[idx, sub_idx] = getattr(\n                                _force, \"phase%s\" % sub_idx\n                            ).to_openmm().value_in_unit(esp.units.ANGLE_UNIT)\n\n                            periodicity[idx, sub_idx] = getattr(\n                                _force, \"periodicity%s\" % sub_idx\n                            )\n\n            return {\n                \"k_ref\": 
k,\n                \"periodicity_ref\": periodicity,\n                \"phases_ref\": phases,\n            }\n\n        g.heterograph.apply_nodes(apply_torsion, ntype=\"n4\")\n        g.heterograph.apply_nodes(apply_improper_torsion, ntype=\"n4_improper\")\n\n        return g\n\n    def baseline_energy(self, g, suffix=None):\n        if suffix is None:\n            suffix = \"_\" + self.forcefield\n\n        from openmmforcefields.generators import SystemGenerator\n\n        # define a system generator\n        system_generator = SystemGenerator(\n            small_molecule_forcefield=self.forcefield,\n        )\n\n        mol = g.mol\n        # mol.assign_partial_charges(\"formal_charge\")\n        # create system\n        system = system_generator.create_system(\n            topology=mol.to_topology().to_openmm(),\n            molecules=mol,\n        )\n\n        # parameterize topology\n        topology = g.mol.to_topology().to_openmm()\n\n        integrator = openmm.LangevinIntegrator(\n            TEMPERATURE, COLLISION_RATE, STEP_SIZE\n        )\n\n        # create simulation\n        simulation = Simulation(\n            topology=topology, system=system, integrator=integrator\n        )\n\n        us = []\n\n        xs = (\n            Quantity(\n                g.nodes[\"n1\"].data[\"xyz\"].detach().numpy(),\n                esp.units.DISTANCE_UNIT,\n            )\n            .value_in_unit(unit.nanometer)\n            .transpose((1, 0, 2))\n        )\n\n        for x in xs:\n            simulation.context.setPositions(x)\n            us.append(\n                simulation.context.getState(getEnergy=True)\n                .getPotentialEnergy()\n                .value_in_unit(esp.units.ENERGY_UNIT)\n            )\n\n        g.nodes[\"g\"].data[\"u%s\" % suffix] = torch.tensor(us)[None, :]\n\n        return g\n\n    def _multi_typing_smirnoff(self, g):\n        # mol = self._convert_to_off(mol)\n\n        forces = self.FF.label_molecules(g.mol.to_topology())[0]\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"legacy_typing\": torch.Tensor(\n                    [\n                        int(\n                            forces[\"Bonds\"][\n                                tuple(node.data[\"idxs\"][idx].numpy())\n                            ].id[1:]\n                        )\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                ).long()\n            },\n            ntype=\"n2\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"legacy_typing\": torch.Tensor(\n                    [\n                        int(\n                            forces[\"Angles\"][\n                                tuple(node.data[\"idxs\"][idx].numpy())\n                            ].id[1:]\n                        )\n                        for idx in range(node.data[\"idxs\"].shape[0])\n                    ]\n                ).long()\n            },\n            ntype=\"n3\",\n        )\n\n        g.heterograph.apply_nodes(\n            lambda node: {\n                \"legacy_typing\": torch.Tensor(\n                    [\n                        int(forces[\"vdW\"][(idx,)].id[1:])\n                        for idx in range(g.heterograph.number_of_nodes(\"n1\"))\n                    ]\n                ).long()\n            },\n            ntype=\"n1\",\n        )\n\n        return g\n\n    def parametrize(self, g):\n        \"\"\"Parametrize a molecular 
graph.\"\"\"\n        if \"smirnoff\" in self.forcefield or \"openff\" in self.forcefield:\n            return self._parametrize_smirnoff(g)\n\n        elif \"gaff\" in self.forcefield:\n            return self._parametrize_gaff(g)\n\n        else:\n            raise NotImplementedError\n\n    def typing(self, g):\n        \"\"\"Type a molecular graph.\"\"\"\n        if \"gaff\" in self.forcefield:\n            return self._type_gaff(g)\n\n        else:\n            raise NotImplementedError\n\n    def multi_typing(self, g):\n        \"\"\"Type a molecular graph for hetero nodes.\"\"\"\n        if \"smirnoff\" in self.forcefield:\n            return self._multi_typing_smirnoff(g)\n\n        else:\n            raise NotImplementedError\n\n    def __call__(self, *args, **kwargs):\n        return self.typing(*args, **kwargs)\n"
  },
  {
    "path": "espaloma/graphs/tests/test_deploy.py",
    "content": "import openmm\nimport urllib.request\nimport numpy.testing as npt\nimport espaloma as esp\nfrom openmm import unit\n\nomm_angle_unit = unit.radian\nomm_energy_unit = unit.kilojoule_per_mole\nfrom openmm.unit import Quantity\n\n\ndef test_butane_charge_am1bcc():\n    \"\"\"check that esp.graphs.deploy.openmm_system_from_graph runs without error on butane using\n    am1-bcc charge method\"\"\"\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"openff-1.2.0\")\n    g = esp.Graph(\"CCCC\")\n    g = ff.parametrize(g)\n    esp.graphs.deploy.openmm_system_from_graph(g, suffix=\"_ref\", charge_method=\"am1-bcc\")\n\ndef test_butane_charge_nn():\n    \"\"\"check that esp.graphs.deploy.openmm_system_from_graph runs without error on butane using\n    the nn charge method\"\"\"\n    import torch\n    # Download serialized espaloma model\n    url = f'https://github.com/choderalab/espaloma/releases/download/0.3.0/espaloma-0.3.0rc1.pt'\n    espaloma_model_filepath = f'espaloma-0.3.0rc1.pt'\n    urllib.request.urlretrieve(url, filename=espaloma_model_filepath)\n    # Test deployment\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"openff-1.2.0\")\n    g = esp.Graph(\"CCCC\")\n    g = ff.parametrize(g)\n    # apply a trained espaloma model to assign parameters\n    net = torch.load(espaloma_model_filepath, map_location=torch.device('cpu'))\n    net.eval()\n    net(g.heterograph)\n    esp.graphs.deploy.openmm_system_from_graph(g, suffix=\"_ref\", charge_method=\"nn\")\n\ndef test_caffeine():\n    \"\"\"Test Openmm system deployment of caffeine method using the charges from the molecule runs\n    without error.\"\"\"\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"openff-1.2.0\")\n    g = esp.Graph(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    g = ff.parametrize(g)\n    g.mol.assign_partial_charges(\"am1bcc\")  # Assign charges after parametrizing\n    esp.graphs.deploy.openmm_system_from_graph(g, suffix=\"_ref\", charge_method=\"from-molecule\")\n\n\ndef test_parameter_consistent_caffeine():\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"openff-1.2.0\")\n    g = esp.Graph(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    g = ff.parametrize(g)\n    system = esp.graphs.deploy.openmm_system_from_graph(g, suffix=\"_ref\", charge_method=\"am1-bcc\")\n    forces = list(system.getForces())\n    openff_forces = ff.FF.label_molecules(g.mol.to_topology())[0]\n    for idx, force in enumerate(forces):\n        force.setForceGroup(idx)\n        name = force.__class__.__name__\n        if \"HarmonicBondForce\" in name:\n            for _idx in range(force.getNumBonds()):\n                start, end, eq, k_openmm = force.getBondParameters(_idx)\n\n                k_openff = openff_forces[\"Bonds\"][(start, end)].k.to_openmm()\n\n                npt.assert_almost_equal(\n                    k_openmm / k_openff,\n                    2.0,\n                    decimal=3,\n                )\n\n\ndef test_energy_consistent_caffeine():\n    \"\"\"Deploy a caffeine molecule parametrized by a traditional force field\n    and deployed by espaloma, make sure the energies computed using espaloma\n    and OpenMM are same or close.\n\n    \"\"\"\n    # grab a force field\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"openff-1.2.0\")\n\n    # parametrize caffeine molecule using the parametrization\n    ## Should there be a second test for SMIRNOFF impropers?\n    g = esp.Graph(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    g = ff.parametrize(g)\n    system = 
esp.graphs.deploy.openmm_system_from_graph(g, suffix=\"_ref\", charge_method=\"am1-bcc\")\n\n    # compute energies using espaloma\n    import torch\n\n    g.nodes[\"n1\"].data[\"xyz\"] = torch.randn(\n        g.heterograph.number_of_nodes(\"n1\"), 1, 3\n    )\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n    esp.mm.energy.energy_in_graph(\n        g.heterograph, terms=[\"n2\", \"n3\", \"n4\", \"n4_improper\"], suffix=\"_ref\"\n    )\n\n    # compute energies using OpenMM with bond, angle, and torsion breakdown\n    forces = list(system.getForces())\n\n    energies = {}\n\n    for idx, force in enumerate(forces):\n        force.setForceGroup(idx)\n\n        name = force.__class__.__name__\n\n        if \"Nonbonded\" in name:\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n\n            # epsilons = {}\n            # sigmas = {}\n\n            # for _idx in range(force.getNumParticles()):\n            #     q, sigma, epsilon = force.getParticleParameters(_idx)\n\n            #     # record parameters\n            #     epsilons[_idx] = epsilon\n            #     sigmas[_idx] = sigma\n\n            #     force.setParticleParameters(_idx, 0., sigma, epsilon)\n\n            # def sigma_combining_rule(sig1, sig2):\n            #     return (sig1 + sig2) / 2\n\n            # def eps_combining_rule(eps1, eps2):\n            #     return np.sqrt(np.abs(eps1 * eps2))\n\n            # for _idx in range(force.getNumExceptions()):\n            #     idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n            #         _idx)\n            #     force.setExceptionParameters(\n            #         _idx,\n            #         idx0,\n            #         idx1,\n            #         0.0,\n            #         sigma_combining_rule(sigmas[idx0], sigmas[idx1]),\n            #         eps_combining_rule(epsilons[idx0], epsilons[idx1])\n            #     )\n\n            # force.updateParametersInContext(_simulation.context)\n\n    # create new simulation\n    _simulation = openmm.app.Simulation(\n        g.mol.to_topology().to_openmm(),\n        system,\n        openmm.VerletIntegrator(0.0),\n    )\n\n    _simulation.context.setPositions(\n        Quantity(\n            g.nodes[\"n1\"].data[\"xyz\"][:, 0, :].numpy(),\n            unit=esp.units.DISTANCE_UNIT,\n        ).value_in_unit(unit.nanometer)\n    )\n\n    for idx, force in enumerate(forces):\n        name = force.__class__.__name__\n\n        state = _simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            groups=2**idx,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT\n        )\n\n        energies[name] = energy\n\n    # test if bond energies are equal\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n2_ref\"].numpy(),\n        energies[\"HarmonicBondForce\"],\n        decimal=3,\n    )\n\n    # test if angle energies are equal\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n3_ref\"].numpy(),\n        energies[\"HarmonicAngleForce\"],\n        decimal=3,\n    )\n\n    # test if torsion energies are equal\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n4_ref\"].numpy()\n        + g.nodes[\"g\"].data[\"u_n4_improper_ref\"].numpy(),\n        energies[\"PeriodicTorsionForce\"],\n        decimal=3,\n    )\n\n\n# TODO: test that desired parameters are assigned\n"
  },
  {
    "path": "espaloma/graphs/tests/test_gaff_parametrize.py",
    "content": "import pytest\nimport espaloma as esp\n\n\ndef test_gaff_parametrize():\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\n    g = esp.Graph(\n        \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\",\n    )\n    ff.parametrize(g)\n\n    print(g.nodes[\"n2\"].data)\n    print(g.nodes[\"n3\"].data)\n    print(g.nodes[\"n4\"].data)\n    print(g.nodes[\"n4_improper\"].data)\n"
  },
  {
    "path": "espaloma/graphs/tests/test_graph.py",
    "content": "import io\nimport json\nimport pytest\nimport shutil\nimport importlib_resources\nimport espaloma as esp\n\n\ndef test_graph():\n    import espaloma as esp\n\n    g = esp.Graph(\"c1ccccc1\")\n\n    print(g.heterograph)\n\n\n@pytest.fixture\ndef graph():\n    import espaloma as esp\n\n    return esp.Graph(\"c1ccccc1\")\n\n\ndef test_ndata_consistency(graph):\n    import torch\n\n    assert torch.equal(graph.ndata[\"h0\"], graph.nodes[\"n1\"].data[\"h0\"])\n\n\n@pytest.mark.parametrize(\n    \"molecule, charge\",\n    [\n        pytest.param(\"C\", 0, id=\"methane\"),\n        pytest.param(\"[NH4+]\", 1, id=\"Ammonium\"),\n        pytest.param(\"CC(=O)[O-]\", -1, id=\"Acetate\"),\n    ],\n)\ndef test_formal_charge(molecule, charge):\n    import espaloma as esp\n\n    graph = esp.Graph(molecule)\n    assert graph.nodes[\"g\"].data[\"sum_q\"].numpy()[0] == charge\n\n\ndef test_save_and_load(graph):\n    import tempfile\n\n    with tempfile.TemporaryDirectory() as tempdir:\n        graph.save(tempdir + \"/g.esp\")\n        new_graph = esp.Graph.load(tempdir + \"/g.esp\")\n\n    assert graph.homograph.number_of_nodes() == new_graph.homograph.number_of_nodes()\n\n    assert graph.homograph.number_of_edges() == new_graph.homograph.number_of_edges()\n\ndef test_load_from_older_openff(tmp_path_factory):\n    \"\"\"Tests creating a graph from a json-serialized mol with older openff-toolkit\n    version (0.10.x)\n\n    This checks that the serialized molecule doesn't have the expected hierarchy_schemes\n    key, which will be created on the fly when loaded as a graph.\n\n    This tests creates a graph with\n    \"\"\"\n    # Load json serialized off 0.10.6 molecule and save it in path\n    from openff.toolkit import Molecule\n    mol_json_path = importlib_resources.files('espaloma.data') / 'off-mol_0_10_6.json'\n    with open(str(mol_json_path), \"r\") as json_file:\n        # This loads it as a string -- seems like an off toolkit limitation\n        mol_json_str = json.load(json_file)\n    mol_dict = json.load(io.StringIO(mol_json_str))\n    assert \"hierarchy_schemes\" not in mol_dict, \"Serialized json mol contains unexpected key.\"\n    # Save json molecule in path\n    out_esp_dir_1 = tmp_path_factory.mktemp(\"esp1\")\n    shutil.copy(mol_json_path, out_esp_dir_1 / \"mol.json\")\n\n    # update dicitonary and create espaloma graph with the same molecule\n    mol_dict[\"hierarchy_schemes\"] = dict()\n    off_molecule = Molecule.from_dict(mol_dict)\n    smiles = off_molecule.to_smiles()\n    g = esp.Graph(smiles)\n    # Save the graph\n    out_esp_dir_2 = tmp_path_factory.mktemp(\"esp2\") / \"esp-test\"\n    g.save(str(out_esp_dir_2))\n    # copy homo/hetero-graphs to original dir\n    shutil.copy(out_esp_dir_2 / \"homograph.bin\", out_esp_dir_1)\n    shutil.copy(out_esp_dir_2 / \"heterograph.bin\", out_esp_dir_1)\n\n    # Load espaloma from original directory -- with mol serialized from off 0.10.6\n    esp_graph = esp.Graph.load(str(out_esp_dir_1))\n\n    assert esp_graph.mol == g.mol, f\"Read molecule from esp graph, {esp_graph.mol} is not \" \\\n                                   f\"the same as the expected molecule {off_molecule}.\"\n\n\n# TODO: test offmol_indices\n# TODO: test relationship_indices_from_offmol\n"
  },
  {
    "path": "espaloma/graphs/tests/test_smirnoff.py",
    "content": "import pytest\n\nimport espaloma as esp\n\n\ndef test_smirnoff_esol_first():\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    )\n    g = esp.data.esol(first=1)[0]\n    g = ff.parametrize(g)\n\n\n# def test_smirnoff_strange_mol():\n#     ff = esp.graphs.legacy_force_field.LegacyForceField(\"smirnoff99Frosst-1.1.0\")\n#     g = esp.Graph(\n#         \"[H]c1c(nc(n(=O)c1N([H])[H])N([H])[H])N2C(C(C(C(C2([H])[H])([H])[H])([H])[H])([H])[H])([H])[H]\"\n#     )\n#     g = ff.parametrize(g)\n#\n#\n# def test_multi_typing():\n#     ff = esp.graphs.legacy_force_field.LegacyForceField(\"smirnoff99Frosst-1.1.0\")\n#     g = esp.data.esol(first=1)[0]\n#     g = ff.multi_typing(g)\n"
  },
  {
    "path": "espaloma/graphs/utils/__init__.py",
    "content": "import espaloma.graphs.utils.read_heterogeneous_graph\nimport espaloma.graphs.utils.read_homogeneous_graph\n"
  },
  {
    "path": "espaloma/graphs/utils/offmol_indices.py",
    "content": "import numpy as np\nfrom openff.toolkit.topology import Molecule\n\n\ndef atom_indices(offmol: Molecule) -> np.ndarray:\n    return np.array([a.molecule_atom_index for a in offmol.atoms])\n\n\ndef bond_indices(offmol: Molecule) -> np.ndarray:\n    return np.array([(b.atom1_index, b.atom2_index) for b in offmol.bonds])\n\n\ndef angle_indices(offmol: Molecule) -> np.ndarray:\n    return np.array(\n        sorted(\n            [\n                tuple([atom.molecule_atom_index for atom in angle])\n                for angle in offmol.angles\n            ]\n        )\n    )\n\n\ndef proper_torsion_indices(offmol: Molecule) -> np.ndarray:\n    return np.array(\n        sorted(\n            [\n                tuple([atom.molecule_atom_index for atom in proper])\n                for proper in offmol.propers\n            ]\n        )\n    )\n\n\ndef _all_improper_torsion_indices(offmol: Molecule) -> np.ndarray:\n    \"\"\" \"[*:1]~[*:2](~[*:3])~[*:4]\" matches\"\"\"\n\n    return np.array(\n        sorted(\n            [\n                tuple([atom.molecule_atom_index for atom in improper])\n                for improper in offmol.impropers\n            ]\n        )\n    )\n\n\ndef improper_torsion_indices(\n    offmol: Molecule, improper_def=\"espaloma\"\n) -> np.ndarray:\n    \"\"\" \"[*:1]~[X3:2](~[*:3])~[*:4]\" matches (_all_improper_torsion_indices returns \"[*:1]~[*:2](~[*:3])~[*:4]\" matches)\n\n    improper_def allows for choosing which atom will be the central atom in the\n    permutations:\n    smirnoff: central atom is listed first\n    espaloma: central atom is listed second\n\n    Addtionally, for smirnoff, only take the subset of atoms that corresponds\n    to the ccw traversal of connected atoms.\n\n    Notes\n    -----\n    Motivation: offmol.impropers returns a large number of impropers, and we may wish to restrict this number.\n    May update this filter definition based on discussion in https://github.com/openff.toolkit/openff.toolkit/issues/746\n    \"\"\"\n\n    ## Find all atoms bound to exactly 3 other atoms\n    if improper_def == \"espaloma\":\n        ## This finds all orderings, which is what we want for the espaloma case\n        ##  but not for smirnoff\n        improper_smarts = \"[*:1]~[X3:2](~[*:3])~[*:4]\"\n        mol_idxs = offmol.chemical_environment_matches(improper_smarts)\n        return np.array(mol_idxs)\n    elif improper_def == \"smirnoff\":\n        improper_smarts = \"[*:2]~[X3:1](~[*:3])~[*:4]\"\n        ## For smirnoff ordering, we only want to find the unique combinations\n        ##  of atoms forming impropers so we can permute them the way we want\n        mol_idxs = offmol.chemical_environment_matches(\n            improper_smarts, unique=True\n        )\n\n        ## Get all ccw orderings\n        # feels like there should be some good way to do this with itertools...\n        idx_permuts = []\n        for c, *other_atoms in mol_idxs:\n            for i in range(3):\n                idx = [c]\n                for j in range(3):\n                    idx.append(other_atoms[(i + j) % 3])\n                idx_permuts.append(tuple(idx))\n\n        return np.array(idx_permuts)\n    else:\n        raise ValueError(f\"Unknown value for improper_def: {improper_def}\")\n"
  },
  {
    "path": "espaloma/graphs/utils/read_heterogeneous_graph.py",
    "content": "\"\"\" Build heterogeneous graph from homogeneous ones.\n\n\"\"\"\n# =============================================================================\n# IMPORTS\n# =============================================================================\nimport numpy as np\nimport torch\nfrom espaloma.graphs.utils import offmol_indices\nfrom openff.toolkit.topology import Molecule\nfrom typing import Dict\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\n\n\ndef duplicate_index_ordering(indices: np.ndarray) -> np.ndarray:\n    \"\"\"For every (a,b,c,d) add a (d,c,b,a)\n\n    TODO: is there a way to avoid this duplication?\n\n    >>> indices = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])\n    >>> duplicate_index_ordering(indices)\n    array([[0, 1, 2, 3],\n           [1, 2, 3, 4],\n           [3, 2, 1, 0],\n           [4, 3, 2, 1]])\n    \"\"\"\n    return np.concatenate([indices, np.flip(indices, axis=-1)], axis=0)\n\n\ndef relationship_indices_from_offmol(\n    offmol: Molecule,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Construct a dictionary that maps node names (like \"n2\") to torch tensors of indices\n\n    Notes\n    -----\n    * introduces 2x redundant indices (including (d,c,b,a) for every (a,b,c,d)) for compatibility with later processing\n    \"\"\"\n    idxs = dict()\n    idxs[\"n1\"] = offmol_indices.atom_indices(offmol)\n    idxs[\"n2\"] = offmol_indices.bond_indices(offmol)\n    idxs[\"n3\"] = offmol_indices.angle_indices(offmol)\n    idxs[\"n4\"] = offmol_indices.proper_torsion_indices(offmol)\n    idxs[\"n4_improper\"] = offmol_indices.improper_torsion_indices(offmol)\n\n    if len(idxs[\"n4\"]) == 0:\n        idxs[\"n4\"] = np.empty((0, 4))\n\n    if len(idxs[\"n4_improper\"]) == 0:\n        idxs[\"n4_improper\"] = np.empty((0, 4))\n\n    # TODO: enumerate indices for coupling-term nodes also\n    # TODO: big refactor of term names from \"n4\" to \"proper_torsion\", \"improper_torsion\", \"angle_angle_coupling\", etc.\n\n    # TODO (discuss with YW) : I think \"n1\" and \"n4_improper\" shouldn't be 2x redundant in current scheme\n    #   (also, unclear why we need \"n2\", \"n3\", \"n4\" to be 2x redundant, but that's something to consider changing later)\n    for key in [\"n2\", \"n3\", \"n4\"]:\n        idxs[key] = duplicate_index_ordering(idxs[key])\n\n    # make them all torch.Tensors\n    for key in idxs:\n        idxs[key] = torch.from_numpy(idxs[key])\n\n    return idxs\n\n\ndef from_homogeneous_and_mol(g, offmol):\n    r\"\"\"Build heterogeneous graph from homogeneous ones.\n\n\n    Note\n    ----\n    For now we name single node, two-, three, and four-,\n    hypernodes as `n1`, `n2`, `n3`, and `n4`. 
These correspond\n    to atom, bond, angle, and torsion nodes in chemical graphs.\n\n\n    Parameters\n    ----------\n    g : `espaloma.HomogeneousGraph` object\n        the homogeneous graph to be translated.\n\n    Returns\n    -------\n    hg : `espaloma.HeterogeneousGraph` object\n        the resulting heterogeneous graph.\n\n    \"\"\"\n\n    # initialize empty dictionary\n    hg = {}\n\n    # get adjacency matrix\n    a = g.adjacency_matrix()\n\n    # get all the indices\n    idxs = relationship_indices_from_offmol(offmol)\n\n    # make them all numpy\n    idxs = {key: value.numpy() for key, value in idxs.items()}\n\n    # also include n1\n    idxs[\"n1\"] = np.arange(g.number_of_nodes())[:, None]\n\n    # =========================\n    # neighboring relationships\n    # =========================\n    # NOTE:\n    # here we only define the neighboring relationship\n    # on atom level\n    hg[(\"n1\", \"n1_neighbors_n1\", \"n1\")] = idxs[\"n2\"]\n\n    # build a mapping between indices and the ordering\n    idxs_to_ordering = {}\n\n    for term in [\"n1\", \"n2\", \"n3\", \"n4\", \"n4_improper\"]:\n        idxs_to_ordering[term] = {\n            tuple(subgraph_idxs): ordering\n            for (ordering, subgraph_idxs) in enumerate(list(idxs[term]))\n        }\n\n    # ===============================================\n    # relationships between nodes of different levels\n    # ===============================================\n    # NOTE:\n    # here we define all the possible\n    # 'has' and 'in' relationships.\n    # TODO:\n    # we'll test later to see if this adds too much overhead\n    #\n\n    for small_idx in range(1, 5):\n        for big_idx in range(small_idx + 1, 5):\n            for pos_idx in range(big_idx - small_idx + 1):\n\n                hg[\n                    (\n                        \"n%s\" % small_idx,\n                        \"n%s_as_%s_in_n%s\" % (small_idx, pos_idx, big_idx),\n                        \"n%s\" % big_idx,\n                    )\n                ] = np.stack(\n                    [\n                        np.array(\n                            [\n                                idxs_to_ordering[\"n%s\" % small_idx][tuple(x)]\n                                for x in idxs[\"n%s\" % big_idx][\n                                    :, pos_idx : pos_idx + small_idx\n                                ]\n                            ]\n                        ),\n                        np.arange(idxs[\"n%s\" % big_idx].shape[0]),\n                    ],\n                    axis=1,\n                )\n\n                hg[\n                    (\n                        \"n%s\" % big_idx,\n                        \"n%s_has_%s_n%s\" % (big_idx, pos_idx, small_idx),\n                        \"n%s\" % small_idx,\n                    )\n                ] = np.stack(\n                    [\n                        np.arange(idxs[\"n%s\" % big_idx].shape[0]),\n                        np.array(\n                            [\n                                idxs_to_ordering[\"n%s\" % small_idx][tuple(x)]\n                                for x in idxs[\"n%s\" % big_idx][\n                                    :, pos_idx : pos_idx + small_idx\n                                ]\n                            ]\n                        ),\n                    ],\n                    axis=1,\n                )\n\n    # ======================================\n    # nonbonded terms\n    # ======================================\n    # NOTE: everything is counted twice 
here\n    # nonbonded is where\n    # $A = AA = AAA = AAAA = 0$\n\n    # make dense\n    a_ = a.to_dense().detach().numpy()\n\n    idxs[\"nonbonded\"] = np.stack(\n        np.where(np.equal(a_ + a_ @ a_ + a_ @ a_ @ a_, 0.0)),\n        axis=-1,\n    )\n\n    # onefour is the two ends of torsion\n    # idxs[\"onefour\"] = np.stack(\n    #     [\n    #         idxs[\"n4\"][:, 0],\n    #         idxs[\"n4\"][:, 3],\n    #     ],\n    #     axis=1,\n    # )\n\n    idxs[\"onefour\"] = np.stack(\n        np.where(\n            np.equal(a_ + a_ @ a_, 0.0) * np.greater(a_ @ a_ @ a_, 0.0),\n        ),\n        axis=-1,\n    )\n\n    # membership\n    for term in [\"nonbonded\", \"onefour\"]:\n        for pos_idx in [0, 1]:\n            hg[(term, \"%s_has_%s_n1\" % (term, pos_idx), \"n1\")] = np.stack(\n                [np.arange(idxs[term].shape[0]), idxs[term][:, pos_idx]],\n                axis=-1,\n            )\n\n            hg[(\"n1\", \"n1_as_%s_in_%s\" % (pos_idx, term), term)] = np.stack(\n                [\n                    idxs[term][:, pos_idx],\n                    np.arange(idxs[term].shape[0]),\n                ],\n                axis=-1,\n            )\n\n    # membership of n1 in n4_improper\n    for term in [\"n4_improper\"]:\n        for pos_idx in [0, 1, 2, 3]:\n            hg[(term, \"%s_has_%s_n1\" % (term, pos_idx), \"n1\")] = np.stack(\n                [np.arange(idxs[term].shape[0]), idxs[term][:, pos_idx]],\n                axis=-1,\n            )\n\n            hg[(\"n1\", \"n1_as_%s_in_%s\" % (pos_idx, term), term)] = np.stack(\n                [\n                    idxs[term][:, pos_idx],\n                    np.arange(idxs[term].shape[0]),\n                ],\n                axis=-1,\n            )\n\n    # ======================================\n    # relationships between nodes and graphs\n    # ======================================\n    for term in [\n        \"n1\",\n        \"n2\",\n        \"n3\",\n        \"n4\",\n        \"n4_improper\",\n        \"nonbonded\",\n        \"onefour\",\n    ]:\n        hg[(term, \"%s_in_g\" % term, \"g\",)] = np.stack(\n            [np.arange(len(idxs[term])), np.zeros(len(idxs[term]))],\n            axis=1,\n        )\n\n        hg[(\"g\", \"g_has_%s\" % term, term)] = np.stack(\n            [\n                np.zeros(len(idxs[term])),\n                np.arange(len(idxs[term])),\n            ],\n            axis=1,\n        )\n\n    import dgl\n\n    hg = dgl.heterograph(\n        {key: value.astype(np.int32).tolist() for key, value in hg.items()}\n    )\n\n    hg.nodes[\"n1\"].data[\"h0\"] = g.ndata[\"h0\"]\n    hg.nodes[\"g\"].data[\"sum_q\"] = g.ndata[\"sum_q\"][0].reshape(1, 1)\n    # include indices in the nodes themselves\n    for term in [\n        \"n1\",\n        \"n2\",\n        \"n3\",\n        \"n4\",\n        \"n4_improper\",\n        \"onefour\",\n        \"nonbonded\",\n    ]:\n        hg.nodes[term].data[\"idxs\"] = torch.tensor(idxs[term])\n\n    return hg\n"
  },
  {
    "path": "espaloma/graphs/utils/read_homogeneous_graph.py",
    "content": "\"\"\" Build simple graph from OpenEye or RDKit molecule object.\n\n\"\"\"\n# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\n\n\ndef fp_oe(atom):\n    from openeye import oechem\n\n    HYBRIDIZATION_OE = {\n        oechem.OEHybridization_sp: torch.tensor(\n            [1, 0, 0, 0, 0], dtype=torch.get_default_dtype()\n        ),\n        oechem.OEHybridization_sp2: torch.tensor(\n            [0, 1, 0, 0, 0], dtype=torch.get_default_dtype()\n        ),\n        oechem.OEHybridization_sp3: torch.tensor(\n            [0, 0, 1, 0, 0], dtype=torch.get_default_dtype()\n        ),\n        oechem.OEHybridization_sp3d: torch.tensor(\n            [0, 0, 0, 1, 0], dtype=torch.get_default_dtype()\n        ),\n        oechem.OEHybridization_sp3d2: torch.tensor(\n            [0, 0, 0, 0, 1], dtype=torch.get_default_dtype()\n        ),\n        oechem.OEHybridization_Unknown: torch.tensor(\n            [0, 0, 0, 0, 0], dtype=torch.get_default_dtype()\n        ),\n    }\n    return torch.cat(\n        [\n            torch.tensor(\n                [\n                    atom.GetDegree(),\n                    # Note: Discard resonance-variant features\n                    # Issue related to https://github.com/choderalab/espaloma_charge/issues/18\n                    # atom.GetValence(),\n                    # atom.GetExplicitValence(),\n                    # atom.GetFormalCharge(),\n                    atom.IsAromatic() * 1.0,\n                    atom.GetIsotope(),  # TODO: is this a good idea?\n                    oechem.OEAtomIsInRingSize(atom, 3) * 1.0,\n                    oechem.OEAtomIsInRingSize(atom, 4) * 1.0,\n                    oechem.OEAtomIsInRingSize(atom, 5) * 1.0,\n                    oechem.OEAtomIsInRingSize(atom, 6) * 1.0,\n                    oechem.OEAtomIsInRingSize(atom, 7) * 1.0,\n                    oechem.OEAtomIsInRingSize(atom, 8) * 1.0,\n                ],\n                dtype=torch.float32,\n            ),\n            HYBRIDIZATION_OE[atom.GetHyb()],\n        ],\n        dim=0,\n    )\n\n\ndef fp_rdkit(atom):\n    from rdkit import Chem\n\n    HYBRIDIZATION_RDKIT = {\n        Chem.rdchem.HybridizationType.SP: torch.tensor(\n            [1, 0, 0, 0, 0],\n            dtype=torch.get_default_dtype(),\n        ),\n        Chem.rdchem.HybridizationType.SP2: torch.tensor(\n            [0, 1, 0, 0, 0],\n            dtype=torch.get_default_dtype(),\n        ),\n        Chem.rdchem.HybridizationType.SP3: torch.tensor(\n            [0, 0, 1, 0, 0],\n            dtype=torch.get_default_dtype(),\n        ),\n        Chem.rdchem.HybridizationType.SP3D: torch.tensor(\n            [0, 0, 0, 1, 0],\n            dtype=torch.get_default_dtype(),\n        ),\n        Chem.rdchem.HybridizationType.SP3D2: torch.tensor(\n            [0, 0, 0, 0, 1],\n            dtype=torch.get_default_dtype(),\n        ),\n        Chem.rdchem.HybridizationType.S: torch.tensor(\n            [0, 0, 0, 0, 0],\n            dtype=torch.get_default_dtype(),\n        ),\n    }\n    return torch.cat(\n        [\n            torch.tensor(\n                [\n                    atom.GetTotalDegree(),\n                    # Note: Discard resonance-variant features\n                    # 
Issue related to https://github.com/choderalab/espaloma_charge/issues/18\n                    # atom.GetTotalValence(),\n                    # atom.GetExplicitValence(),\n                    # atom.GetFormalCharge(),\n                    atom.GetIsAromatic() * 1.0,\n                    atom.GetMass(),\n                    atom.IsInRingSize(3) * 1.0,\n                    atom.IsInRingSize(4) * 1.0,\n                    atom.IsInRingSize(5) * 1.0,\n                    atom.IsInRingSize(6) * 1.0,\n                    atom.IsInRingSize(7) * 1.0,\n                    atom.IsInRingSize(8) * 1.0,\n                ],\n                dtype=torch.get_default_dtype(),\n            ),\n            HYBRIDIZATION_RDKIT[atom.GetHybridization()],\n        ],\n        dim=0,\n    )\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef from_openff_toolkit_mol(mol, use_fp=True):\n    import dgl\n\n    # initialize graph\n    g = dgl.DGLGraph()\n\n    # enter nodes\n    n_atoms = mol.n_atoms\n    g.add_nodes(n_atoms)\n    g.ndata[\"type\"] = torch.Tensor(\n        [[atom.atomic_number] for atom in mol.atoms]\n    )\n    total_charge = mol.total_charge.magnitude\n    g.ndata[\"sum_q\"] = torch.Tensor(\n        [[total_charge] for _ in range(mol.n_atoms)]\n    )\n    h_v = torch.zeros(\n        g.ndata[\"type\"].shape[0], 100, dtype=torch.get_default_dtype()\n    )\n\n    h_v[\n        torch.arange(g.ndata[\"type\"].shape[0]),\n        torch.squeeze(g.ndata[\"type\"]).long(),\n    ] = 1.0\n\n    h_v_fp = torch.stack(\n        [fp_rdkit(atom) for atom in mol.to_rdkit().GetAtoms()], axis=0\n    )\n\n    if use_fp == True:\n        h_v = torch.cat([h_v, h_v_fp], dim=-1)  # (n_atoms, 117)\n\n    g.ndata[\"h0\"] = h_v\n\n    # enter bonds\n    bonds = list(mol.bonds)\n    bonds_begin_idxs = [bond.atom1_index for bond in bonds]\n    bonds_end_idxs = [bond.atom2_index for bond in bonds]\n    bonds_types = [bond.bond_order for bond in bonds]\n\n    # NOTE: dgl edges are directional\n    g.add_edges(bonds_begin_idxs, bonds_end_idxs)\n    g.add_edges(bonds_end_idxs, bonds_begin_idxs)\n\n    # g.edata[\"type\"] = torch.Tensor(bonds_types)[:, None].repeat(2, 1)\n\n    return g\n\n\ndef from_oemol(mol, use_fp=True):\n    import dgl\n\n    # initialize graph\n    g = dgl.DGLGraph()\n\n    # enter nodes\n    n_atoms = mol.NumAtoms()\n    g.add_nodes(n_atoms)\n    g.ndata[\"type\"] = torch.Tensor(\n        [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]\n    )\n    g.ndata[\"formal_charge\"] = torch.Tensor(\n        [[atom.GetFormalCharge()] for atom in mol.GetAtoms()]\n    )\n    h_v = torch.zeros(g.ndata[\"type\"].shape[0], 100, dtype=torch.float32)\n\n    h_v[\n        torch.arange(g.ndata[\"type\"].shape[0]),\n        torch.squeeze(g.ndata[\"type\"]).long(),\n    ] = 1.0\n\n    h_v_fp = torch.stack([fp_oe(atom) for atom in mol.GetAtoms()], axis=0)\n\n    if use_fp == True:\n        h_v = torch.cat([h_v, h_v_fp], dim=-1)  # (n_atoms, 117)\n\n    g.ndata[\"h0\"] = h_v\n\n    # enter bonds\n    bonds = list(mol.GetBonds())\n    bonds_begin_idxs = [bond.GetBgnIdx() for bond in bonds]\n    bonds_end_idxs = [bond.GetEndIdx() for bond in bonds]\n    bonds_types = [bond.GetOrder() for bond in bonds]\n\n    # NOTE: dgl edges are directional\n    g.add_edges(bonds_begin_idxs, bonds_end_idxs)\n    g.add_edges(bonds_end_idxs, bonds_begin_idxs)\n\n    # g.edata[\"type\"] = 
torch.Tensor(bonds_types)[:, None].repeat(2, 1)\n\n    return g\n\n\ndef from_rdkit_mol(mol, use_fp=True):\n    import dgl\n\n    # initialize graph\n    g = dgl.DGLGraph()\n\n    # enter nodes\n    n_atoms = mol.GetNumAtoms()\n    g.add_nodes(n_atoms)\n    g.ndata[\"type\"] = torch.Tensor(\n        [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]\n    )\n    g.ndata[\"formal_charge\"] = torch.Tensor(\n        [[atom.GetFormalCharge()] for atom in mol.GetAtoms()]\n    )\n    h_v = torch.zeros(g.ndata[\"type\"].shape[0], 100, dtype=torch.float32)\n\n    h_v[\n        torch.arange(g.ndata[\"type\"].shape[0]),\n        torch.squeeze(g.ndata[\"type\"]).long(),\n    ] = 1.0\n\n    h_v_fp = torch.stack([fp_rdkit(atom) for atom in mol.GetAtoms()], axis=0)\n\n    if use_fp == True:\n        h_v = torch.cat([h_v, h_v_fp], dim=-1)  # (n_atoms, 117)\n\n    g.ndata[\"h0\"] = h_v\n\n    # enter bonds\n    bonds = list(mol.GetBonds())\n    bonds_begin_idxs = [bond.GetBeginAtomIdx() for bond in bonds]\n    bonds_end_idxs = [bond.GetEndAtomIdx() for bond in bonds]\n    bonds_types = [bond.GetBondType().real for bond in bonds]\n\n    # NOTE: dgl edges are directional\n    g.add_edges(bonds_begin_idxs, bonds_end_idxs)\n    g.add_edges(bonds_end_idxs, bonds_begin_idxs)\n\n    # g.edata[\"type\"] = torch.Tensor(bonds_types)[:, None].repeat(2, 1)\n\n    return g\n"
  },
  {
    "path": "espaloma/graphs/utils/regenerate_impropers.py",
    "content": "import dgl\nimport numpy as np\nimport torch\n\nfrom .offmol_indices import improper_torsion_indices\nfrom ..graph import Graph\n\n\ndef regenerate_impropers(g: Graph, improper_def=\"smirnoff\"):\n    \"\"\"\n    Method to regenerate the improper nodes according to the specified\n    method of permuting the impropers. Modifies the esp.Graph's heterograph\n    in place and returns the new heterograph.\n    NOTE: This will clear the data on all n4_improper nodes, including\n    previously generated improper from JanossyPoolingImproper.\n    \"\"\"\n\n    ## First get rid of the old nodes/edges\n    hg = g.heterograph\n    hg = dgl.remove_nodes(hg, hg.nodes(\"n4_improper\"), \"n4_improper\")\n\n    ## Generate new improper torsion permutations\n    idxs = improper_torsion_indices(g.mol, improper_def)\n    if len(idxs) == 0:\n        return g\n\n    ## Add new nodes of type n4_improper (one for each permut)\n    hg = dgl.add_nodes(hg, idxs.shape[0], ntype=\"n4_improper\")\n\n    ## New edges b/n improper permuts and n1 nodes\n    permut_ids = np.arange(idxs.shape[0])\n    for i in range(4):\n        n1_ids = idxs[:, i]\n\n        # edge from improper node to n1 node\n        outgoing_etype = (\"n4_improper\", f\"n4_improper_has_{i}_n1\", \"n1\")\n        hg = dgl.add_edges(hg, permut_ids, n1_ids, etype=outgoing_etype)\n\n        # edge from n1 to improper\n        incoming_etype = (\"n1\", f\"n1_as_{i}_in_n4_improper\", \"n4_improper\")\n        hg = dgl.add_edges(hg, n1_ids, permut_ids, etype=incoming_etype)\n\n    ## New edges b/n improper permuts and the graph (for global pooling)\n    # edge from improper node to graph\n    outgoing_etype = (\"n4_improper\", f\"n4_improper_in_g\", \"g\")\n    hg = dgl.add_edges(\n        hg, permut_ids, np.zeros_like(permut_ids), etype=outgoing_etype\n    )\n\n    # edge from graph to improper nodes\n    incoming_etype = (\"g\", \"g_has_n4_improper\", \"n4_improper\")\n    hg = dgl.add_edges(\n        hg, np.zeros_like(permut_ids), permut_ids, etype=incoming_etype\n    )\n\n    hg.nodes[\"n4_improper\"].data[\"idxs\"] = torch.tensor(idxs)\n\n    g.heterograph = hg\n\n    return g  # hg\n"
  },
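  {
    "path": "devtools/sketches/sketch_regenerate_impropers.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: regenerating\nimproper torsion nodes on an esp.Graph, assuming espaloma and its dependencies\nare installed. The file path is an illustrative assumption.\n\"\"\"\nimport espaloma as esp\nfrom espaloma.graphs.utils.regenerate_impropers import regenerate_impropers\n\n# benzene has trivalent carbons, so smirnoff-style impropers exist\ng = esp.Graph(\"c1ccccc1\")\n\ng = regenerate_impropers(g, improper_def=\"smirnoff\")\nprint(g.heterograph.number_of_nodes(\"n4_improper\"))\nprint(g.heterograph.nodes[\"n4_improper\"].data[\"idxs\"].shape)  # (n_permutations, 4)\n"
  },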
  {
    "path": "espaloma/metrics.py",
    "content": "\"\"\" Metrics to evaluate and train models.\n\n\"\"\"\nimport abc\n\n# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\nimport numpy as np\nfrom .units import GAS_CONSTANT\n\n# =============================================================================\n# HELPER FUNCTIONS\n# =============================================================================\ndef center(metric, dim=1, reduction=\"none\"):\n    def _centered(input, target, metric=metric, dim=dim):\n        # center input\n        input = input - input.mean(dim=dim, keepdim=True)\n\n        # center target\n        target = target - target.mean(dim=dim, keepdim=True)\n\n        if reduction == \"none\":\n            return metric(input, target)\n        else:\n            return getattr(torch, reduction)(metric(input, target))\n\n    return _centered\n\n\ndef boltzmann_weighted(metric, reduction=\"mean\", temperature=300.0):\n    def _weighted(input, target, metric=metric, reduction=reduction):\n        _loss = metric(input, target)\n\n        min_target, _ = torch.min(target, dim=-1, keepdims=True)\n        delta_target = target - min_target\n\n        weight_target = torch.softmax(\n            -delta_target / (GAS_CONSTANT * temperature),\n            dim=-1,\n        )\n\n        _loss = _loss * weight_target\n\n        return getattr(torch, reduction)(_loss)\n\n    return _weighted\n\n\ndef std(metric, weight=1.0, dim=1):\n    def _std(input, target, metric=metric, weight=weight, dim=dim):\n        return weight * metric(input, target).std(dim=dim).sum()\n\n    return _std\n\n\ndef weighted(metric, weight, reduction=\"mean\"):\n    def _weighted(\n        input, target, metric=metric, weight=weight, reduction=reduction\n    ):\n        _loss = metric(input, target)\n        for _ in range(_loss.dims() - 1):\n            weight = weight.unsqueeze(-1)\n        return getattr(torch, reduction)(weight)\n\n    return _weighted\n\n\ndef weighted_with_key(metric, key=\"weight\", reduction=\"mean\"):\n    def _weighted(input, target, metric=metric, key=key, reduction=reduction):\n        weight = target.nodes[\"g\"].data[key].flatten()\n        _loss = metric(input, target)\n        for _ in range(_loss.dims() - 1):\n            weight = weight.unsqueeze(-1)\n        return getattr(torch, reduction)(weight)\n\n    return _weighted\n\n\ndef bootstrap(metric, n_samples=100, ci=0.95):\n    def _bootstrap(input, target, metric=metric, n_samples=n_samples, ci=ci):\n        original = metric(input=input, target=target)\n\n        idxs_all = np.arange(input.shape[0])\n        results = []\n        for _ in range(n_samples):\n            idxs = np.random.choice(\n                idxs_all,\n                len(idxs_all),\n                replace=True,\n            )\n\n            _metric = (\n                metric(input=input[idxs], target=target[idxs])\n                .detach()\n                .cpu()\n                .numpy()\n                .item()\n            )\n\n            results.append(\n                _metric,\n            )\n\n        results = np.array(results)\n\n        low = np.percentile(results, 100.0 * 0.5 * (1 - ci))\n        high = np.percentile(results, (1 - ((1 - ci) * 0.5)) * 100.0)\n\n        return original.detach().cpu().numpy().item(), low, high\n\n    return _bootstrap\n\n\ndef latex_format_ci(original, low, high):\n    return \"$%.4f_{%.4f}^{%.4f}$\" % 
(original, low, high)\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef mse(input, target):\n    return torch.nn.functional.mse_loss(target, input)\n\n\ndef mape(input, target):\n    return ((input - target).abs() / target.abs()).mean()\n\n\ndef rmse(input, target):\n    return torch.sqrt(torch.nn.functional.mse_loss(target, input))\n\n\ndef mae_of_log(input, target):\n    return torch.nn.L1Loss()(torch.log(input), torch.log(target))\n\n\ndef cross_entropy(input, target, reduction=\"mean\"):\n    loss_fn = torch.nn.CrossEntropyLoss(reduction=reduction)\n    return loss_fn(input=input, target=target)  # prediction first, logit\n\n\ndef r2(input, target):\n    target = target.flatten()\n    input = input.flatten()\n    ss_tot = (target - target.mean()).pow(2).sum()\n    ss_res = (input - target).pow(2).sum()\n    return 1 - torch.div(ss_res, ss_tot)\n\n\ndef accuracy(input, target):\n    # check if this is logit\n    if input.dim() == 2 and input.shape[-1] > 1:\n        input = input.argmax(dim=-1)\n\n    return torch.div((input == target).sum().double(), target.shape[0])\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass Metric(torch.nn.modules.loss._Loss):\n    \"\"\"Base function for loss.\"\"\"\n\n    def __init__(self, size_average=None, reduce=None, reduction=\"mean\"):\n        super(Metric, self).__init__(size_average, reduce, reduction)\n\n    @abc.abstractmethod\n    def forward(self, *args, **kwargs):\n        raise NotImplementedError\n\n\nclass GraphMetric(Metric):\n    \"\"\"Loss between nodes attributes of graph or graphs.\n\n    Parameters\n    ----------\n    base_metric : callable\n        Metric on fixed dimensional space.\n\n    between : List[str]\n        Names of quantities to compare.\n\n    level : str\n        Level of nodes to compare with.\n\n    Returns\n    -------\n    torch.Tensor\n    \"\"\"\n\n    def __init__(self, base_metric, between, level=\"n1\", *args, **kwargs):\n        super(GraphMetric, self).__init__(*args, **kwargs)\n\n        # between could be tuple of two strings or two functions\n        assert len(between) == 2\n\n        self.between = (\n            self._translation(between[0], level),\n            self._translation(between[1], level),\n        )\n\n        self.base_metric = base_metric\n\n        # get name\n        if hasattr(base_metric, \"__name__\"):\n            base_name = base_metric.__name__\n        else:\n            base_name = base_metric.__class__.__name__\n\n        self.__name__ = \"%s_between_%s_and_%s_on_%s\" % (\n            base_name,\n            between[0],\n            between[1],\n            level,\n        )\n\n    @staticmethod\n    def _translation(string, level):\n        return lambda g: g.nodes[level].data[string]\n\n    def forward(self, g_input, g_target=None):\n        \"\"\"Forward function of loss.\"\"\"\n        # allow loss within self\n        if g_target is None:\n            g_target = g_input\n\n        # get input and output transform function\n        input_fn, target_fn = self.between\n\n        # compute loss using base loss\n        # NOTE:\n        # use keyward argument here since torch is bad with the order with args\n        return self.base_metric(\n            input=input_fn(g_input), 
target=target_fn(g_target)\n        )\n\n\nclass GraphDerivativeMetric(Metric):\n    \"\"\"Loss between nodes attributes of graph or graphs.\"\"\"\n\n    def __init__(\n        self,\n        base_metric,\n        between,\n        weight=1.0,\n        level=\"n1\",\n        d=\"xyz\",\n        d_level=\"n1\",\n        *args,\n        **kwargs\n    ):\n        super(GraphDerivativeMetric, self).__init__(*args, **kwargs)\n\n        # between could be tuple of two strings or two functions\n        assert len(between) == 2\n\n        self.between = (\n            self._translation(between[0], level),\n            self._translation(between[1], level),\n        )\n\n        self.d = self._translation(d, d_level)\n\n        self.base_metric = base_metric\n        self.weight = weight\n        # get name\n        if hasattr(base_metric, \"__name__\"):\n            base_name = base_metric.__name__\n        else:\n            base_name = base_metric.__class__.__name__\n\n        self.__name__ = \"%s_between_d_%s_d_%s_and_d_%s_d_%s_on_%s\" % (\n            base_name,\n            between[0],\n            d,\n            between[1],\n            d,\n            level,\n        )\n\n    @staticmethod\n    def _translation(string, level):\n        return lambda g: g.nodes[level].data[string]\n\n    def forward(self, g_input, g_target=None):\n        \"\"\"Forward function of loss.\"\"\"\n        # allow loss within self\n        if g_target is None:\n            g_target = g_input\n\n        # get input and output transform function\n        input_fn, target_fn = self.between\n\n        # calculate the derivatives of input\n        input_prime = torch.autograd.grad(\n            input_fn(g_input).sum(),\n            self.d(g_input),\n            create_graph=True,\n            retain_graph=True,\n            allow_unused=True,\n        )[0]\n\n        target_prime = torch.autograd.grad(\n            target_fn(g_target).sum(),\n            self.d(g_target),\n            create_graph=True,\n            retain_graph=True,\n            allow_unused=True,\n        )[0]\n\n        # compute loss using base loss\n        # NOTE:\n        # use keyward argument here since torch is bad with the order with args\n        return self.weight * self.base_metric(\n            input=input_prime,\n            target=target_prime,\n        )\n\n\nclass GraphHalfDerivativeMetric(Metric):\n    \"\"\"Loss between nodes attributes of graph or graphs.\"\"\"\n\n    def __init__(\n        self,\n        base_metric,\n        input_level=\"g\",\n        input_name=\"u\",\n        target_prime_level=\"n1\",\n        target_prime_name=\"u_ref_prime\",\n        d=\"xyz\",\n        d_level=\"n1\",\n        weight=1.0,\n        *args,\n        **kwargs\n    ):\n        super(GraphHalfDerivativeMetric, self).__init__(*args, **kwargs)\n\n        # define query functions\n        self.d = self._translation(d, d_level)\n        self.input_fn = self._translation(input_name, input_level)\n        self.target_prime_fn = self._translation(\n            target_prime_name, target_prime_level\n        )\n\n        self.base_metric = base_metric\n        self.weight = weight\n        # get name\n        if hasattr(base_metric, \"__name__\"):\n            base_name = base_metric.__name__\n        else:\n            base_name = base_metric.__class__.__name__\n\n        self.__name__ = \"%s_between_%s_d_%s_on_%s_and_%s_on_%s\" % (\n            base_name,\n            input_name,\n            d,\n            input_level,\n            
target_prime_name,\n            target_prime_level,\n        )\n\n    @staticmethod\n    def _translation(string, level):\n        return lambda g: g.nodes[level].data[string]\n\n    def forward(self, g_input, g_target=None):\n        \"\"\"Forward function of loss.\"\"\"\n        # allow loss within self\n        if g_target is None:\n            g_target = g_input\n\n        # calculate the derivatives of input\n        input_prime = torch.autograd.grad(\n            self.input_fn(g_input).sum(),\n            self.d(g_input),\n            create_graph=True,\n            retain_graph=True,\n        )[0]\n\n        target_prime = self.target_prime_fn(g_target)\n\n        # compute loss using base loss\n        # NOTE:\n        # use keyward argument here since torch is bad with the order with args\n        return self.weight * self.base_metric(\n            input=input_prime,\n            target=target_prime,\n        )\n\n\n# =============================================================================\n# PRESETS\n# =============================================================================\n\n\nclass TypingCrossEntropy(GraphMetric):\n    def __init__(self):\n        super(TypingCrossEntropy, self).__init__(\n            base_metric=torch.nn.CrossEntropyLoss(),\n            between=[\"nn_typing\", \"legacy_typing\"],\n        )\n\n        self.__name__ = \"TypingCrossEntropy\"\n\n\nclass TypingAccuracy(GraphMetric):\n    def __init__(self):\n        super(TypingAccuracy, self).__init__(\n            base_metric=accuracy, between=[\"nn_typing\", \"legacy_typing\"]\n        )\n\n        self.__name__ = \"TypingAccuracy\"\n\n\nclass BondKMSE(GraphMetric):\n    def __init__(self):\n        super(BondKMSE, self).__init__(\n            between=[\"k_ref\", \"k\"], level=\"n2\", base_metric=torch.nn.MSELoss()\n        )\n\n        self.__name__ = \"BondKMSE\"\n\n\nclass BondKRMSE(GraphMetric):\n    def __init__(self):\n        super(BondKRMSE, self).__init__(\n            between=[\"k_ref\", \"k\"], level=\"n2\", base_metric=rmse\n        )\n\n        self.__name__ = \"BondKRMSE\"\n"
  },
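  {
    "path": "devtools/sketches/sketch_metrics.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: composing the\nmetric helpers in espaloma/metrics.py on plain tensors. The file path and the\nrandom data are illustrative assumptions.\n\"\"\"\nimport torch\nimport espaloma as esp\n\ninput = torch.randn(8, 50)  # e.g. predicted energies, (n_mols, n_snapshots)\ntarget = input + 0.1 * torch.randn(8, 50)\n\n# centered RMSE: subtract per-molecule means before comparing, useful for\n# snapshot energies that are only defined up to an additive constant\ncentered_rmse = esp.metrics.center(esp.metrics.rmse)\nprint(centered_rmse(input, target))\n\n# bootstrap a metric to get (value, low, high) for a 95% confidence interval\nboot_rmse = esp.metrics.bootstrap(esp.metrics.rmse, n_samples=50)\nprint(boot_rmse(input=input, target=target))\n"
  },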
  {
    "path": "espaloma/mm/__init__.py",
    "content": "from . import angle, bond, energy, functional, geometry, nonbonded, torsion\n"
  },
  {
    "path": "espaloma/mm/angle.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef harmonic_angle(x, k, eq):\n    \"\"\"Harmonic angle energy.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape = (batch_size, 1)`\n        angle value\n    k : `torch.Tensor`, `shape = (batch_size, 1)`\n        force constant\n    eq : `torch.Tensor`, `shape = (batch_size, 1)`\n        equilibrium angle\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape = (batch_size, 1)`\n        energy\n\n    \"\"\"\n    # NOTE:\n    # the constant 0.5 is included here but not in the functional forms\n\n    # NOTE:\n    # 0.25 because all angles are calculated twice\n    return 0.5 * esp.mm.functional.harmonic(x=x, k=k, eq=eq)\n\n\ndef linear_mixture_angle(x, coefficients, phases):\n    \"\"\"Angle energy with Linear basis function.\n\n    Parameters\n    ----------\n    coefficients : torch.Tensor\n        Coefficients of the linear mixuture.\n\n    phases : torch.Tensor\n        Phases of the linear mixture.\n\n    \"\"\"\n\n    return 0.5 * esp.mm.functional.linear_mixture(\n        x=x, coefficients=coefficients, phases=phases\n    )\n\n\ndef urey_bradley(x_between, coefficients, phases):\n    return esp.mm.functional.linear_mixture(\n        x=x_between,\n        coefficients=coefficients,\n        phases=phases,\n    )\n\n\ndef bond_bond(u_left, u_right, k_bond_bond):\n    u_left = u_left - u_left.min(dim=-1, keepdims=True)[0]\n    u_right = u_right - u_right.min(dim=-1, keepdims=True)[0]\n    return k_bond_bond * (u_left**0.5) * (u_right**0.5)\n\n\ndef bond_angle(\n    u_left,\n    u_right,\n    u_angle,\n    k_bond_angle,\n):\n\n    u_left = u_left - u_left.min(dim=-1, keepdims=True)[0]\n    u_right = u_right - u_right.min(dim=-1, keepdims=True)[0]\n    u_angle = u_angle - u_angle.min(dim=-1, keepdims=True)[0]\n    return k_bond_angle * (u_left**0.5) * (\n        u_angle**0.5\n    ) + k_bond_angle * (u_right**0.5) * (u_angle**0.5)\n\n\ndef angle_high(\n    u_angle,\n    k3,\n    k4,\n):\n    u_angle = u_angle - u_angle.min(dim=-1, keepdims=True)[0]\n    return k3 * u_angle**1.5 + k4 * u_angle**2\n"
  },
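  {
    "path": "devtools/sketches/sketch_harmonic_angle.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: evaluating the\nharmonic angle term on plain tensors with espaloma's (n_entities, n_snapshots)\nshape convention. The file path and the parameter values are assumptions.\n\"\"\"\nimport math\nimport torch\nimport espaloma as esp\n\n# (n_angles, n_snapshots) angle values; (n_angles, 1) parameters\nx = torch.full((4, 10), math.radians(109.5))\nk = torch.full((4, 1), 100.0)\neq = torch.full((4, 1), math.radians(120.0))\n\nu = esp.mm.angle.harmonic_angle(x=x, k=k, eq=eq)\nprint(u.shape)  # (4, 10); net prefactor is 0.25 * k * (x - eq) ** 2\n"
  },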
  {
    "path": "espaloma/mm/bond.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef harmonic_bond(x, k, eq):\n    \"\"\"Harmonic bond energy.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape = (batch_size, 1)`\n        bond length\n    k : `torch.Tensor`, `shape = (batch_size, 1)`\n        force constant\n    eq : `torch.Tensor`, `shape = (batch_size, 1)`\n        equilibrium bond length\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape = (batch_size, 1)`\n        energy\n\n    \"\"\"\n    # NOTE:\n    # the constant is included here but not in the functional forms\n\n    # NOTE:\n    # 0.25 because all bonds are calculated twice\n    return 0.5 * esp.mm.functional.harmonic(x=x, k=k, eq=eq)\n\n\ndef gaussian_bond(x, coefficients):\n    \"\"\"Bond energy with Gaussian basis function.\"\"\"\n    return esp.mm.functional.gaussian(\n        x=x,\n        coefficients=coefficients,\n    )\n\n\ndef linear_mixture_bond(x, coefficients, phases):\n    \"\"\"Bond energy with Linear basis function.\n\n    Parameters\n    ----------\n    coefficients : torch.Tensor\n        Coefficients of the linear mixuture.\n\n    phases : torch.Tensor\n        Phases of the linear mixture.\n\n    \"\"\"\n    return 0.5 * esp.mm.functional.linear_mixture(\n        x=x, coefficients=coefficients, phases=phases\n    )\n\n\ndef bond_high(u_bond, k3, k4):\n    u_bond = u_bond - u_bond.min(dim=-1, keepdims=True)[0]\n    return k3 * u_bond**1.5 + k4 * u_bond**2\n"
  },
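  {
    "path": "devtools/sketches/sketch_linear_mixture_bond.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: the\nlinear-mixture bond basis and its relation to the usual (k, eq)\nparameterization. The file path and all values are assumptions.\n\"\"\"\nimport torch\nimport espaloma as esp\n\nx = torch.linspace(0.5, 2.5, 20).repeat(3, 1)  # (n_bonds, n_snapshots)\ncoefficients = torch.rand(3, 2) * 100.0  # (n_bonds, 2): k1, k2\nphases = [0.0, 1.0]  # b1, b2: the two anchor lengths\n\nu = esp.mm.bond.linear_mixture_bond(x, coefficients, phases=phases)\nprint(u.shape)  # (3, 20)\n\n# recover the equivalent harmonic parameters: k = k1 + k2 and eq is the\n# k-weighted mean of the two anchors\nk1, k2 = coefficients[:, 0][:, None], coefficients[:, 1][:, None]\nk, eq = esp.mm.functional.linear_mixture_to_original(k1, k2, *phases)\nprint(k.shape, eq.shape)  # (3, 1) (3, 1)\n"
  },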
  {
    "path": "espaloma/mm/energy.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\nimport espaloma as esp\n\n\n# =============================================================================\n# ENERGY IN HYPERNODES---BONDED\n# =============================================================================\ndef apply_bond(nodes, suffix=\"\"):\n    \"\"\"Bond energy in nodes.\"\"\"\n    # if suffix == '_ref':\n    return {\n        \"u%s\"\n        % suffix: esp.mm.bond.harmonic_bond(\n            x=nodes.data[\"x\"],\n            k=nodes.data[\"k%s\" % suffix],\n            eq=nodes.data[\"eq%s\" % suffix],\n        )\n    }\n\n    # else:\n    #     return {\n    #         'u%s' % suffix: esp.mm.bond.harmonic_bond_re(\n    #             x=nodes.data['x'],\n    #             k=nodes.data['k%s' % suffix],\n    #             eq=nodes.data['eq%s' % suffix],\n    #         )\n    #     }\n\n\ndef apply_angle(nodes, suffix=\"\"):\n    \"\"\"Angle energy in nodes.\"\"\"\n    return {\n        \"u%s\"\n        % suffix: esp.mm.angle.harmonic_angle(\n            x=nodes.data[\"x\"],\n            k=nodes.data[\"k%s\" % suffix],\n            eq=nodes.data[\"eq%s\" % suffix],\n        )\n    }\n\n\ndef apply_angle_ii(nodes, suffix=\"\"):\n    return {\n        # \"u_angle_high%s\"\n        # % suffix: esp.mm.angle.angle_high(\n        #     u_angle=nodes.data[\"u\"],\n        #     k3=nodes.data[\"k3\"],\n        #     k4=nodes.data[\"k4\"],\n        # ),\n        \"u_urey_bradley%s\"\n        % suffix: esp.mm.angle.urey_bradley(\n            x_between=nodes.data[\"x_between\"],\n            coefficients=nodes.data[\"coefficients_urey_bradley\"],\n            phases=[0.0, 12.0],\n        ),\n        \"u_bond_bond%s\"\n        % suffix: esp.mm.angle.bond_bond(\n            u_left=nodes.data[\"u_left\"],\n            u_right=nodes.data[\"u_right\"],\n            k_bond_bond=nodes.data[\"k_bond_bond\"],\n        ),\n        \"u_bond_angle%s\"\n        % suffix: esp.mm.angle.bond_angle(\n            u_left=nodes.data[\"u_left\"],\n            u_right=nodes.data[\"u_right\"],\n            u_angle=nodes.data[\"u\"],\n            k_bond_angle=nodes.data[\"k_bond_angle\"],\n        ),\n    }\n\n\ndef apply_bond_ii(nodes, suffix=\"\"):\n    return {\n        \"u_bond_high%s\"\n        % suffix: esp.mm.bond.bond_high(\n            u_bond=nodes.data[\"u\"],\n            k3=nodes.data[\"k3\"],\n            k4=nodes.data[\"k4\"],\n        )\n    }\n\n\ndef apply_torsion_ii(nodes, suffix=\"\"):\n    \"\"\"Torsion energy in nodes.\"\"\"\n    return {\n        \"u_angle_angle%s\"\n        % suffix: esp.mm.torsion.angle_angle(\n            u_angle_left=nodes.data[\"u_angle_left\"],\n            u_angle_right=nodes.data[\"u_angle_right\"],\n            k_angle_angle=nodes.data[\"k_angle_angle\"],\n        ),\n        \"u_angle_torsion%s\"\n        % suffix: esp.mm.torsion.angle_torsion(\n            u_angle_left=nodes.data[\"u_angle_left\"],\n            u_angle_right=nodes.data[\"u_angle_right\"],\n            u_torsion=nodes.data[\"u\"],\n            k_angle_torsion=nodes.data[\"k_angle_torsion\"],\n        ),\n        \"u_angle_angle_torsion%s\"\n        % suffix: esp.mm.torsion.angle_angle_torsion(\n            u_angle_left=nodes.data[\"u_angle_left\"],\n            u_angle_right=nodes.data[\"u_angle_right\"],\n            u_torsion=nodes.data[\"u\"],\n            
k_angle_angle_torsion=nodes.data[\"k_angle_angle_torsion\"],\n        ),\n        \"u_bond_torsion%s\"\n        % suffix: esp.mm.torsion.bond_torsion(\n            u_bond_left=nodes.data[\"u_bond_left\"],\n            u_bond_right=nodes.data[\"u_bond_right\"],\n            u_bond_center=nodes.data[\"u_bond_center\"],\n            u_torsion=nodes.data[\"u\"],\n            k_side_torsion=nodes.data[\"k_side_torsion\"],\n            k_center_torsion=nodes.data[\"k_center_torsion\"],\n        ),\n    }\n\n\ndef apply_torsion(nodes, suffix=\"\"):\n    \"\"\"Torsion energy in nodes.\"\"\"\n    if (\n        \"phases%s\" % suffix in nodes.data\n        and \"periodicity%s\" % suffix in nodes.data\n    ):\n        return {\n            \"u%s\"\n            % suffix: esp.mm.torsion.periodic_torsion(\n                x=nodes.data[\"x\"],\n                k=nodes.data[\"k%s\" % suffix],\n                phases=nodes.data[\"phases%s\" % suffix],\n                periodicity=nodes.data[\"periodicity%s\" % suffix],\n            )\n        }\n\n    else:\n        return {\n            \"u%s\"\n            % suffix: esp.mm.torsion.periodic_torsion(\n                x=nodes.data[\"x\"],\n                k=nodes.data[\"k%s\" % suffix],\n            )\n        }\n\n\ndef apply_improper_torsion(nodes, suffix=\"\"):\n    \"\"\"Improper torsion energy in nodes.\"\"\"\n    if (\n        \"phases%s\" % suffix in nodes.data\n        and \"periodicity%s\" % suffix in nodes.data\n    ):\n        return {\n            \"u%s\"\n            % suffix: esp.mm.torsion.periodic_torsion(\n                x=nodes.data[\"x\"],\n                k=nodes.data[\"k%s\" % suffix],\n                phases=nodes.data[\"phases%s\" % suffix],\n                periodicity=nodes.data[\"periodicity%s\" % suffix],\n            )\n        }\n\n    else:\n        n_multi = nodes.data[\"k%s\" % suffix].shape[-1]\n        periodicity=list(range(1, n_multi+1))\n        phases=[0.0 for _ in range(n_multi)]\n        return {\n            \"u%s\"\n            % suffix: esp.mm.torsion.periodic_torsion(\n                x=nodes.data[\"x\"],\n                k=nodes.data[\"k%s\" % suffix],\n                phases=phases,\n                periodicity=periodicity,\n            )\n        }\n\n\ndef apply_bond_gaussian(nodes, suffix=\"\"):\n    \"\"\"Bond energy in nodes.\"\"\"\n    # if suffix == '_ref':\n    return {\n        \"u%s\"\n        % suffix: esp.mm.bond.gaussian_bond(\n            x=nodes.data[\"x\"],\n            coefficients=nodes.data[\"coefficients%s\" % suffix],\n        )\n    }\n\n\ndef apply_bond_linear_mixture(nodes, suffix=\"\", phases=[0.0, 1.0]):\n    \"\"\"Bond energy in nodes.\"\"\"\n    # if suffix == '_ref':\n    return {\n        \"u%s\"\n        % suffix: esp.mm.bond.linear_mixture_bond(\n            x=nodes.data[\"x\"],\n            coefficients=nodes.data[\"coefficients%s\" % suffix],\n            phases=phases,\n        )\n    }\n\n\ndef apply_angle_linear_mixture(nodes, suffix=\"\", phases=[0.0, 1.0]):\n    \"\"\"Bond energy in nodes.\"\"\"\n    # if suffix == '_ref':\n    return {\n        \"u%s\"\n        % suffix: esp.mm.angle.linear_mixture_angle(\n            x=nodes.data[\"x\"],\n            coefficients=nodes.data[\"coefficients%s\" % suffix],\n            phases=phases,\n        )\n    }\n\n\n# =============================================================================\n# ENERGY IN HYPERNODES---NONBONDED\n# =============================================================================\ndef 
apply_nonbonded(nodes, scaling=1.0, suffix=\"\"):\n    \"\"\"Nonbonded in nodes.\"\"\"\n    # TODO: should this be 9-6 or 12-6?\n    return {\n        \"u%s\"\n        % suffix: scaling\n        * esp.mm.nonbonded.lj_12_6(\n            x=nodes.data[\"x\"],\n            sigma=nodes.data[\"sigma%s\" % suffix],\n            epsilon=nodes.data[\"epsilon%s\" % suffix],\n        )\n    }\n\n\ndef apply_coulomb(nodes, scaling=1.0, suffix=\"\"):\n    return {\n        \"u%s\"\n        % suffix: scaling\n        * esp.mm.nonbonded.coulomb(\n            x=nodes.data[\"x\"],\n            q=nodes.data[\"q\"],\n        )\n    }\n\n\n# =============================================================================\n# ENERGY IN GRAPH\n# =============================================================================\ndef energy_in_graph(\n    g,\n    suffix=\"\",\n    terms=[\"n2\", \"n3\", \"n4\"],\n):  # \"onefour\", \"nonbonded\"]):\n    \"\"\"Calculate the energy of a small molecule given parameters and geometry.\n\n    Parameters\n    ----------\n    g : `dgl.DGLHeteroGraph`\n        Input graph.\n\n    Returns\n    -------\n    g : `dgl.DGLHeteroGraph`\n        Output graph.\n\n    Notes\n    -----\n    This function modifies graphs in-place.\n\n    \"\"\"\n    # TODO: this is all very restricted for now\n    # we need to make this better\n    import dgl\n\n    if \"n2\" in terms:\n        # apply energy function\n\n        if \"coefficients%s\" % suffix in g.nodes[\"n2\"].data:\n            g.apply_nodes(\n                lambda node: apply_bond_linear_mixture(\n                    node, suffix=suffix, phases=[1.5, 6.0]\n                ),\n                ntype=\"n2\",\n            )\n        else:\n            g.apply_nodes(\n                lambda node: apply_bond(node, suffix=suffix),\n                ntype=\"n2\",\n            )\n\n    if \"n3\" in terms:\n        if \"coefficients%s\" % suffix in g.nodes[\"n3\"].data:\n            import math\n\n            g.apply_nodes(\n                lambda node: apply_angle_linear_mixture(\n                    node, suffix=suffix, phases=[0.0, math.pi]\n                ),\n                ntype=\"n3\",\n            )\n        else:\n            g.apply_nodes(\n                lambda node: apply_angle(node, suffix=suffix),\n                ntype=\"n3\",\n            )\n\n    if g.number_of_nodes(\"n4\") > 0 and \"n4\" in terms:\n        g.apply_nodes(\n            lambda node: apply_torsion(node, suffix=suffix),\n            ntype=\"n4\",\n        )\n\n    if g.number_of_nodes(\"n4_improper\") > 0 and \"n4_improper\" in terms:\n        g.apply_nodes(\n            lambda node: apply_improper_torsion(node, suffix=suffix),\n            ntype=\"n4_improper\",\n        )\n\n    # if g.number_of_nodes(\"nonbonded\") > 0 and \"nonbonded\" in terms:\n    #     g.apply_nodes(\n    #         lambda node: apply_nonbonded(node, suffix=suffix),\n    #         ntype=\"nonbonded\",\n    #     )\n\n    # if g.number_of_nodes(\"onefour\") > 0 and \"onefour\" in terms:\n    #     g.apply_nodes(\n    #         lambda node: apply_nonbonded(\n    #             node,\n    #             suffix=suffix,\n    #             scaling=0.5,\n    #         ),\n    #         ntype=\"onefour\",\n    #     )\n\n    if \"nonbonded\" in terms or \"onefour\" in terms:\n        esp.mm.nonbonded.multiply_charges(g)\n\n    if \"nonbonded\" in terms and g.number_of_nodes(\"nonbonded\") > 0:\n        g.apply_nodes(\n            lambda node: apply_coulomb(\n                node,\n                
suffix=suffix,\n                scaling=1.0,\n            ),\n            ntype=\"nonbonded\",\n        )\n\n    if \"onefour\" in terms and g.number_of_nodes(\"onefour\") > 0:\n        g.apply_nodes(\n            lambda node: apply_coulomb(\n                node,\n                suffix=suffix,\n                # scaling=0.5,\n                scaling=0.8333333333333334,\n            ),\n            ntype=\"onefour\",\n        )\n\n    # sum up energy\n    # bonded\n    g.multi_update_all(\n        {\n            \"%s_in_g\"\n            % term: (\n                dgl.function.copy_u(u=\"u%s\" % suffix, out=\"m_%s\" % term),\n                dgl.function.sum(\n                    msg=\"m_%s\" % term, out=\"u_%s%s\" % (term, suffix)\n                ),\n            )\n            for term in terms\n            if \"u%s\" % suffix in g.nodes[term].data\n        },\n        cross_reducer=\"sum\",\n    )\n\n    g.apply_nodes(\n        lambda node: {\n            \"u%s\"\n            % suffix: sum(\n                node.data[\"u_%s%s\" % (term, suffix)]\n                for term in terms\n                if \"u_%s%s\" % (term, suffix) in node.data\n            )\n        },\n        ntype=\"g\",\n    )\n\n    if \"u0\" in g.nodes[\"g\"].data:\n        g.apply_nodes(\n            lambda node: {\"u\": node.data[\"u\"] + node.data[\"u0\"]},\n            ntype=\"g\",\n        )\n\n    return g\n\n\ndef energy_in_graph_ii(\n    g,\n    suffix=\"\",\n):\n    if g.number_of_nodes(\"n3\") > 0:\n\n        g.apply_nodes(\n            lambda node: apply_angle_ii(node, suffix=suffix),\n            ntype=\"n3\",\n        )\n\n        g.apply_nodes(\n            lambda node: {\n                \"u%s\" % suffix: node.data[\"u%s\" % suffix]\n                + node.data[\"u_urey_bradley%s\" % suffix]\n                + node.data[\"u_bond_bond%s\" % suffix]\n                + node.data[\"u_bond_angle%s\" % suffix]\n            },\n            ntype=\"n3\",\n        )\n\n    if g.number_of_nodes(\"n4\") > 0:\n        g.apply_nodes(\n            lambda node: apply_torsion_ii(node, suffix=suffix),\n            ntype=\"n4\",\n        )\n\n        g.apply_nodes(\n            lambda node: {\n                \"u%s\" % suffix: node.data[\"u%s\" % suffix]\n                + node.data[\"u_angle_angle%s\" % suffix]\n                + node.data[\"u_angle_torsion%s\" % suffix]\n                + node.data[\"u_angle_angle_torsion%s\" % suffix]\n                + node.data[\"u_bond_torsion%s\" % suffix]\n            },\n            ntype=\"n4\",\n        )\n\n    return g\n\n\nclass EnergyInGraph(torch.nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(EnergyInGraph, self).__init__()\n        self.args = args\n        self.kwargs = kwargs\n\n    def forward(self, g):\n        return energy_in_graph(g, *self.args, **self.kwargs)\n\n\nclass EnergyInGraphII(torch.nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(EnergyInGraphII, self).__init__()\n        self.args = args\n        self.kwargs = kwargs\n\n    def forward(self, g):\n        return energy_in_graph_ii(g, *self.args, **self.kwargs)\n\n\nclass CarryII(torch.nn.Module):\n    def forward(self, g):\n        import math\n        import dgl\n\n        g.multi_update_all(\n            {\n                \"n2_as_0_in_n3\": (\n                    dgl.function.copy_u(\"u\", \"m_u_0\"),\n                    dgl.function.sum(\"m_u_0\", \"u_left\"),\n                ),\n                \"n2_as_1_in_n3\": (\n                    
dgl.function.copy_u(\"u\", \"m_u_1\"),\n                    dgl.function.sum(\"m_u_1\", \"u_right\"),\n                ),\n                \"n2_as_0_in_n4\": (\n                    dgl.function.copy_u(\"u\", \"m_u_0\"),\n                    dgl.function.sum(\"m_u_0\", \"u_bond_left\"),\n                ),\n                \"n2_as_1_in_n4\": (\n                    dgl.function.copy_u(\"u\", \"m_u_1\"),\n                    dgl.function.sum(\"m_u_1\", \"u_bond_center\"),\n                ),\n                \"n2_as_2_in_n4\": (\n                    dgl.function.copy_u(\"u\", \"m_u_2\"),\n                    dgl.function.sum(\"m_u_2\", \"u_bond_right\"),\n                ),\n                \"n3_as_0_in_n4\": (\n                    dgl.function.copy_u(\"u\", \"m3_u_0\"),\n                    dgl.function.sum(\"m3_u_0\", \"u_angle_left\"),\n                ),\n                \"n3_as_1_in_n4\": (\n                    dgl.function.copy_u(\"u\", \"m3_u_1\"),\n                    dgl.function.sum(\"m3_u_1\", \"u_angle_right\"),\n                ),\n            },\n            cross_reducer=\"sum\",\n        )\n\n        return g\n"
  },
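  {
    "path": "devtools/sketches/sketch_energy_in_graph.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: the call order\nfor the MM energy pipeline, with hand-assigned dummy parameters standing in\nfor a trained model. The file path, shapes, and values are assumptions based\non the code above.\n\"\"\"\nimport torch\nimport espaloma as esp\n\ng = esp.Graph(\"CCO\")\nhg = g.heterograph\n\n# one snapshot of (random) coordinates plus dummy harmonic parameters\nhg.nodes[\"n1\"].data[\"xyz\"] = torch.randn(hg.number_of_nodes(\"n1\"), 1, 3)\nhg.nodes[\"n2\"].data[\"k\"] = 100.0 * torch.ones(hg.number_of_nodes(\"n2\"), 1)\nhg.nodes[\"n2\"].data[\"eq\"] = 1.5 * torch.ones(hg.number_of_nodes(\"n2\"), 1)\nhg.nodes[\"n3\"].data[\"k\"] = 50.0 * torch.ones(hg.number_of_nodes(\"n3\"), 1)\nhg.nodes[\"n3\"].data[\"eq\"] = 2.0 * torch.ones(hg.number_of_nodes(\"n3\"), 1)\nhg.nodes[\"n4\"].data[\"k\"] = torch.zeros(hg.number_of_nodes(\"n4\"), 6)\n\n# geometry first (bond lengths, angles, dihedrals), then energies\nhg = esp.mm.geometry.geometry_in_graph(hg)\nhg = esp.mm.energy.energy_in_graph(hg, terms=[\"n2\", \"n3\", \"n4\"])\nprint(hg.nodes[\"g\"].data[\"u\"])  # total MM energy, one entry per snapshot\n"
  },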
  {
    "path": "espaloma/mm/functional.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport math\nimport torch\nimport espaloma as esp\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\nfrom openmm import unit\nfrom openmm.unit import Quantity\n\nLJ_SWITCH = Quantity(1.0, unit.angstrom).value_in_unit(\n    esp.units.DISTANCE_UNIT\n)\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\ndef linear_mixture_to_original(k1, k2, b1, b2):\n    \"\"\"Translating linear mixture coefficients back to original\n    parameterization.\n    \"\"\"\n    # (batch_size, )\n    k = k1 + k2\n\n    # (batch_size, )\n    b = (k1 * b1 + k2 * b2) / (k + 1e-7)\n\n    return k, b\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef harmonic(x, k, eq, order=[2]):\n    \"\"\"Harmonic term.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape=(batch_size, 1)`\n    k : `torch.Tensor`, `shape=(batch_size, len(order))`\n    eq : `torch.Tensor`, `shape=(batch_size, len(order))`\n    order : `int` or `List` of `int`\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape=(batch_size, 1)`\n    \"\"\"\n\n    if isinstance(order, list):\n        order = torch.tensor(order, device=x.device)\n\n    return (\n        0.5\n        * k\n        * ((x - eq)).pow(order[:, None, None]).permute(1, 2, 0).sum(dim=-1)\n    )\n\n\ndef periodic_fixed_phases(\n    dihedrals: torch.Tensor, ks: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"Periodic torsion term with n_phases = 6, periodicities = 1..n_phases, phases = zeros\n\n    Parameters\n    ----------\n    dihedrals : torch.Tensor, shape=(n_snapshots, n_dihedrals)\n        dihedral angles -- TODO: confirm in radians?\n    ks : torch.Tensor, shape=(n_dihedrals, n_phases)\n        force constants -- TODO: confirm in esp.unit.ENERGY_UNIT ?\n\n    Returns\n    -------\n    u : torch.Tensor, shape=(n_snapshots, 1)\n        potential energy of each snapshot\n\n    Notes\n    -----\n    TODO: is there a way to annotate / type-hint tensor shapes? 
(currently adding many assert statements)\n    TODO: merge with esp.mm.functional.periodic -- adding this because I was having difficulty debugging runtime tensor\n      shape errors in esp.mm.functional.periodic, which allows for a more flexible mix of input shapes and types\n    \"\"\"\n\n    # periodicity = 1..n_phases\n    n_phases = 6\n    periodicity = torch.arange(n_phases) + 1\n\n    # assert input shape consistency\n    n_snapshots, n_dihedrals = dihedrals.shape\n    n_dihedrals_, n_phases_ = ks.shape\n    assert n_dihedrals == n_dihedrals_\n    assert n_phases == n_phases_\n\n    # promote everything to this shape\n    stacked_shape = (n_snapshots, n_dihedrals, n_phases)\n\n    # duplicate ks n_snapshots times\n    ks_stacked = torch.stack([ks] * n_snapshots, dim=0)\n    assert ks_stacked.shape == stacked_shape\n\n    # duplicate dihedral angles n_phases times\n    dihedrals_stacked = torch.stack([dihedrals] * n_phases, dim=2)\n    assert dihedrals_stacked.shape == stacked_shape\n\n    # duplicate periodicity n_snapshots * n_dihedrals times\n    ns = torch.stack(\n        [torch.stack([periodicity] * n_snapshots)] * n_dihedrals, dim=1\n    )\n    assert ns.shape == stacked_shape\n\n    # compute k_n * cos(n * theta) for n in 1..n_phases, for each dihedral in each snapshot\n    energy_terms = ks_stacked * torch.cos(ns * dihedrals_stacked)\n    assert energy_terms.shape == stacked_shape\n\n    # sum over n_dihedrals and n_phases\n    energy_sums = energy_terms.sum(dim=(1, 2))\n    assert energy_sums.shape == (n_snapshots,)\n\n    return energy_sums.reshape((n_snapshots, 1))\n\n\ndef periodic(\n    x, k, periodicity=list(range(1, 7)), phases=[0.0 for _ in range(6)]\n):\n    \"\"\"Periodic term.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape=(batch_size, 1)`\n    k : `torch.Tensor`, `shape=(batch_size, number_of_phases)`\n    periodicity: either list of length number_of_phases, or\n        `torch.Tensor`, `shape=(batch_size, number_of_phases)`\n    phases : either list of length number_of_phases, or\n        `torch.Tensor`, `shape=(batch_size, number_of_phases)`\n    \"\"\"\n\n    if isinstance(phases, list):\n        phases = torch.tensor(phases, device=x.device)\n\n    if isinstance(periodicity, list):\n        periodicity = torch.tensor(\n            periodicity,\n            device=x.device,\n            dtype=torch.get_default_dtype(),\n        )\n\n    if periodicity.ndim == 1:\n        periodicity = periodicity[None, None, :].repeat(\n            x.shape[0], x.shape[1], 1\n        )\n\n    elif periodicity.ndim == 2:\n        periodicity = periodicity[:, None, :].repeat(1, x.shape[1], 1)\n\n    if phases.ndim == 1:\n        phases = phases[None, None, :].repeat(\n            x.shape[0],\n            x.shape[1],\n            1,\n        )\n\n    elif phases.ndim == 2:\n        phases = phases[:, None, :].repeat(\n            1,\n            x.shape[1],\n            1,\n        )\n\n    n_theta = periodicity * x[:, :, None]\n\n    n_theta_minus_phases = n_theta - phases\n\n    cos_n_theta_minus_phases = n_theta_minus_phases.cos()\n\n    k = k[:, None, :].repeat(1, x.shape[1], 1)\n\n    # energy = (k * (1.0 + cos_n_theta_minus_phases)).sum(dim=-1)\n\n    energy = (\n        torch.nn.functional.relu(k) * (cos_n_theta_minus_phases + 1.0)\n        - torch.nn.functional.relu(0.0 - k) * (cos_n_theta_minus_phases - 1.0)\n    ).sum(dim=-1)\n\n    return energy\n\n\n# simple implementation\n# def harmonic(x, k, eq):\n#     return k * (x - eq) ** 2\n#\n# def 
harmonic_re(x, k, eq, a=0.0, b=0.3):\n#     # temporary\n#     ka = k\n#     kb = eq\n#\n#     c = ((ka * a + kb * b) / (ka + kb)) ** 2 - a ** 2 - b ** 2\n#\n#     return ka * (x - a) ** 2 + kb * (x - b) ** 2\n\n\ndef lj(\n    x,\n    epsilon,\n    sigma,\n    order=[12, 6],\n    coefficients=[1.0, 1.0],\n    switch=LJ_SWITCH,\n):\n    r\"\"\"Lennard-Jones term.\n\n    Notes\n    -----\n    .. math::\n        E = \\epsilon ((\\sigma / r) ^ {12} - (\\sigma / r) ^ 6)\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape=(batch_size, 1)`\n    epsilon : `torch.Tensor`, `shape=(batch_size, len(order))`\n    sigma : `torch.Tensor`, `shape=(batch_size, len(order))`\n    order : `int` or `List` of `int`\n    coefficients : torch.tensor or list\n    switch : unitless switch width (distance)\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape=(batch_size, 1)`\n    \"\"\"\n    if isinstance(order, list):\n        order = torch.tensor(order, device=x.device)\n\n    if isinstance(coefficients, list):\n        coefficients = torch.tensor(coefficients, device=x.device)\n\n    assert order.shape[0] == 2\n    assert order.dim() == 1\n\n    # TODO:\n    # for experiments only\n    # erase later\n\n    # compute sigma over x\n    sigma_over_x = sigma / x\n\n    # erase values under switch\n    sigma_over_x = torch.where(\n        torch.lt(x, switch),\n        torch.zeros_like(sigma_over_x),\n        sigma_over_x,\n    )\n\n    return epsilon * (\n        coefficients[0] * sigma_over_x ** order[0]\n        - coefficients[1] * sigma_over_x ** order[1]\n    )\n\n\ndef gaussian(x, coefficients, phases=[idx * 0.001 for idx in range(200)]):\n    r\"\"\"Gaussian basis function.\n\n    Parameters\n    ----------\n    x : torch.Tensor\n    coefficients : list or torch.Tensor of length n_phases\n    phases : list or torch.Tensor of length n_phases\n    \"\"\"\n\n    if isinstance(phases, list):\n        # (number_of_phases, )\n        phases = torch.tensor(phases, device=x.device)\n\n    # broadcasting\n    # (number_of_hypernodes, number_of_snapshots, number_of_phases)\n    phases = phases[None, None, :].repeat(x.shape[0], x.shape[1], 1)\n    x = x[:, :, None].repeat(1, 1, phases.shape[-1])\n    coefficients = coefficients[:, None, :].repeat(1, x.shape[1], 1)\n\n    return (coefficients * torch.exp(-0.5 * (x - phases) ** 2)).sum(-1)\n\n\ndef linear_mixture(x, coefficients, phases=[0.0, 1.0]):\n    r\"\"\"Linear mixture basis function.\n\n    x : torch.Tensor\n    coefficients : list or torch.Tensor of length 2\n    phases : list of length 2\n    \"\"\"\n\n    assert len(phases) == 2, \"Only two phases now.\"\n    assert coefficients.shape[-1] == 2\n\n    # partition the dimensions\n    # (, )\n    b1 = phases[0]\n    b2 = phases[1]\n\n    # (batch_size, 1)\n    k1 = coefficients[:, 0][:, None]\n    k2 = coefficients[:, 1][:, None]\n\n    # get the original parameters\n    # (batch_size, )\n    # k, b = linear_mixture_to_original(k1, k2, b1, b2)\n\n    # (batch_size, 1)\n    u1 = k1 * (x - b1) ** 2\n    u2 = k2 * (x - b2) ** 2\n\n    u = 0.5 * (u1 + u2)  # - k1 * b1 ** 2 - k2 * b2 ** 2 + b ** 2\n\n    return u\n\n\ndef harmonic_periodic_coupled(\n    x_harmonic,\n    x_periodic,\n    k,\n    eq,\n    periodicity=list(range(1, 3)),\n):\n\n    if isinstance(periodicity, list):\n        periodicity = torch.tensor(\n            periodicity,\n            device=x_harmonic.device,\n            dtype=torch.get_default_dtype(),\n        )\n\n    n_theta = (\n        periodicity[None, None, :].repeat(\n            x_periodic.shape[0], x_periodic.shape[1], 1\n        )\n        * x_periodic[:, :, None]\n    )\n\n    cos_n_theta = n_theta.cos()\n\n    k = k[:, None, :].repeat(1, x_periodic.shape[1], 1)\n\n    sum_k_cos_n_theta = (k * cos_n_theta).sum(dim=-1)\n\n    x_minus_eq = x_harmonic - eq\n\n    energy = x_minus_eq * sum_k_cos_n_theta\n\n    return energy\n\n\ndef harmonic_harmonic_coupled(\n    x0,\n    x1,\n    eq0,\n    eq1,\n    k,\n):\n    energy = k * (x0 - eq0) * (x1 - eq1)\n    return energy\n\n\ndef harmonic_harmonic_periodic_coupled(\n    theta0,\n    theta1,\n    eq0,\n    eq1,\n    phi,\n    k,\n):\n    energy = k * (theta0 - eq0) * (theta1 - eq1) * phi.cos()\n    return energy\n"
  },
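  {
    "path": "devtools/sketches/sketch_periodic.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: the flexible\nperiodic torsion term and its fixed-phase counterpart. The file path and the\nrandom force constants are assumptions.\n\"\"\"\nimport math\nimport torch\nimport espaloma as esp\n\nx = torch.linspace(-math.pi, math.pi, 36).repeat(5, 1)  # (n_torsions, n_snapshots)\nk = 0.1 * torch.randn(5, 6)  # six force constants per torsion\n\n# defaults: periodicity 1..6, all phases zero\nu = esp.mm.functional.periodic(x, k)\nprint(u.shape)  # (5, 36)\n\n# the stricter fixed-phase variant takes (n_snapshots, n_dihedrals) input and\n# returns the sum over torsions\nu_fixed = esp.mm.functional.periodic_fixed_phases(x.T, k)\nprint(u_fixed.shape)  # (36, 1)\n"
  },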
  {
    "path": "espaloma/mm/geometry.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\ndef reduce_stack(msg, out):\n    \"\"\"Copy massage and stack.\"\"\"\n\n    def _reduce_stack(nodes, msg=msg, out=out):\n        return {out: nodes.mailbox[msg]}\n\n    return _reduce_stack\n\n\ndef copy_src(src, out):\n    \"\"\"Copy source of an edge.\"\"\"\n\n    def _copy_src(edges, src=src, out=out):\n        return {out: edges.src[src].clone()}\n\n    return _copy_src\n\n\n# =============================================================================\n# SINGLE GEOMETRY ENTITY\n# =============================================================================\ndef distance(x0, x1):\n    \"\"\"Distance.\"\"\"\n    return torch.norm(x0 - x1, p=2, dim=-1)\n\n\ndef _angle(r0, r1):\n    \"\"\"Angle between vectors.\"\"\"\n\n    angle = torch.atan2(\n        torch.norm(torch.cross(r0, r1), p=2, dim=-1),\n        torch.sum(torch.mul(r0, r1), dim=-1),\n    )\n\n    return angle\n\n\ndef angle(x0, x1, x2):\n    \"\"\"Angle between three points.\"\"\"\n    left = x1 - x0\n    right = x1 - x2\n    return _angle(left, right)\n\n\ndef _dihedral(r0, r1):\n    \"\"\"Dihedral between normal vectors.\"\"\"\n    return _angle(r0, r1)\n\n\ndef dihedral(\n    x0: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"Dihedral between four points.\n\n    Reference\n    ---------\n    Closely follows implementation in Yutong Zhao's timemachine:\n        https://github.com/proteneer/timemachine/blob/1a0ab45e605dc1e28c44ea90f38cb0dedce5c4db/timemachine/potentials/bonded.py#L152-L199\n    \"\"\"\n    # check input shapes\n\n    assert x0.shape == x1.shape == x2.shape == x3.shape\n\n    # compute displacements 0->1, 2->1, 2->3\n    r01 = x1 - x0 + torch.randn_like(x0) * 1e-5\n    r21 = x1 - x2 + torch.randn_like(x0) * 1e-5\n    r23 = x3 - x2 + torch.randn_like(x0) * 1e-5\n\n    # compute normal planes\n    n1 = torch.cross(r01, r21)\n    n2 = torch.cross(r21, r23)\n\n    rkj_normed = r21 / torch.norm(r21, dim=-1, keepdim=True)\n\n    y = torch.sum(torch.mul(torch.cross(n1, n2), rkj_normed), dim=-1)\n    x = torch.sum(torch.mul(n1, n2), dim=-1)\n\n    # choose quadrant correctly\n    theta = torch.atan2(y, x)\n\n    return theta\n\n\n# =============================================================================\n# GEOMETRY IN HYPERNODES\n# =============================================================================\ndef apply_bond(nodes):\n    \"\"\"Bond length in nodes.\"\"\"\n\n    return {\"x\": distance(x0=nodes.data[\"xyz0\"], x1=nodes.data[\"xyz1\"])}\n\n\ndef apply_angle(nodes):\n    \"\"\"Angle values in nodes.\"\"\"\n    return {\n        \"x\": angle(\n            x0=nodes.data[\"xyz0\"],\n            x1=nodes.data[\"xyz1\"],\n            x2=nodes.data[\"xyz2\"],\n        ),\n        \"x_left\": distance(\n            x0=nodes.data[\"xyz1\"],\n            x1=nodes.data[\"xyz0\"],\n        ),\n        \"x_right\": distance(\n            x0=nodes.data[\"xyz1\"],\n            x1=nodes.data[\"xyz2\"],\n        ),\n        \"x_between\": distance(\n            x0=nodes.data[\"xyz0\"],\n            x1=nodes.data[\"xyz2\"],\n        ),\n    }\n\n\ndef apply_torsion(nodes):\n    
\"\"\"Torsion dihedrals in nodes.\"\"\"\n    return {\n        \"x\": dihedral(\n            x0=nodes.data[\"xyz0\"],\n            x1=nodes.data[\"xyz1\"],\n            x2=nodes.data[\"xyz2\"],\n            x3=nodes.data[\"xyz3\"],\n        ),\n        \"x_bond_left\": distance(\n            x0=nodes.data[\"xyz0\"],\n            x1=nodes.data[\"xyz1\"],\n        ),\n        \"x_bond_center\": distance(\n            x0=nodes.data[\"xyz1\"],\n            x1=nodes.data[\"xyz2\"],\n        ),\n        \"x_bond_right\": distance(\n            x0=nodes.data[\"xyz2\"],\n            x1=nodes.data[\"xyz3\"],\n        ),\n        \"x_angle_left\": angle(\n            x0=nodes.data[\"xyz0\"],\n            x1=nodes.data[\"xyz1\"],\n            x2=nodes.data[\"xyz2\"],\n        ),\n        \"x_angle_right\": angle(\n            x0=nodes.data[\"xyz1\"],\n            x1=nodes.data[\"xyz2\"],\n            x2=nodes.data[\"xyz3\"],\n        ),\n    }\n\n\n# =============================================================================\n# GEOMETRY IN GRAPH\n# =============================================================================\n# NOTE:\n# The following functions modify graphs in-place.\n\n\ndef geometry_in_graph(g):\n    \"\"\"Assign values to geometric entities in graphs.\n\n    Parameters\n    ----------\n    g : `dgl.DGLHeteroGraph`\n        Input graph.\n\n    Returns\n    -------\n    g : `dgl.DGLHeteroGraph`\n        Output graph.\n\n    Notes\n    -----\n    This function modifies graphs in-place.\n\n    \"\"\"\n    import dgl\n\n    # Copy coordinates to higher-order nodes.\n    g.multi_update_all(\n        {\n            **{\n                \"n1_as_%s_in_n%s\"\n                % (pos_idx, big_idx): (\n                    dgl.function.copy_u(u=\"xyz\", out=\"m_xyz%s\" % pos_idx),\n                    dgl.function.sum(\n                        msg=\"m_xyz%s\" % pos_idx, out=\"xyz%s\" % pos_idx\n                    ),\n                )\n                for big_idx in range(2, 5)\n                for pos_idx in range(big_idx)\n            },\n            **{\n                \"n1_as_%s_in_%s\"\n                % (pos_idx, term): (\n                    dgl.function.copy_u(u=\"xyz\", out=\"m_xyz%s\" % pos_idx),\n                    dgl.function.sum(\n                        msg=\"m_xyz%s\" % pos_idx, out=\"xyz%s\" % pos_idx\n                    ),\n                )\n                for term in [\"nonbonded\", \"onefour\"]\n                for pos_idx in [0, 1]\n            },\n            **{\n                \"n1_as_%s_in_%s\"\n                % (pos_idx, term): (\n                    dgl.function.copy_u(u=\"xyz\", out=\"m_xyz%s\" % pos_idx),\n                    dgl.function.sum(\n                        msg=\"m_xyz%s\" % pos_idx, out=\"xyz%s\" % pos_idx\n                    ),\n                )\n                for term in [\"n4_improper\"]\n                for pos_idx in [0, 1, 2, 3]\n            },\n        },\n        cross_reducer=\"sum\",\n    )\n\n    # apply geometry functions\n    g.apply_nodes(apply_bond, ntype=\"n2\")\n    g.apply_nodes(apply_angle, ntype=\"n3\")\n\n    if g.number_of_nodes(\"n4\") > 0:\n        g.apply_nodes(apply_torsion, ntype=\"n4\")\n\n    # copy coordinates to nonbonded\n    if g.number_of_nodes(\"nonbonded\") > 0:\n        g.apply_nodes(apply_bond, ntype=\"nonbonded\")\n\n    if g.number_of_nodes(\"onefour\") > 0:\n        g.apply_nodes(apply_bond, ntype=\"onefour\")\n\n    if g.number_of_nodes(\"n4_improper\") > 0:\n        g.apply_nodes(apply_torsion, 
ntype=\"n4_improper\")\n\n    return g\n\n\nclass GeometryInGraph(torch.nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(GeometryInGraph, self).__init__()\n        self.args = args\n        self.kwargs = kwargs\n\n    def forward(self, g):\n        return geometry_in_graph(g, *self.args, **self.kwargs)\n"
  },
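  {
    "path": "devtools/sketches/sketch_geometry.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: the geometry\nprimitives on explicit coordinates. The file path and the coordinates are\nassumptions.\n\"\"\"\nimport torch\nimport espaloma as esp\n\n# four points in the xy-plane in a cis arrangement\nx0 = torch.tensor([[1.0, 1.0, 0.0]])\nx1 = torch.tensor([[1.0, 0.0, 0.0]])\nx2 = torch.tensor([[0.0, 0.0, 0.0]])\nx3 = torch.tensor([[0.0, 1.0, 0.0]])\n\nprint(esp.mm.geometry.distance(x0, x1))  # tensor([1.])\nprint(esp.mm.geometry.angle(x0, x1, x2))  # ~pi/2\nprint(esp.mm.geometry.dihedral(x0, x1, x2, x3))  # ~0, up to the 1e-5 jitter\n"
  },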
  {
    "path": "espaloma/mm/nonbonded.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\nimport espaloma as esp\nfrom openmm import unit\n\n# CODATA 2018\n# ref https://en.wikipedia.org/wiki/Coulomb_constant\n# Coulomb constant\nK_E = (\n    8.9875517923\n    * 1e9\n    * unit.newton\n    * unit.meter**2\n    * unit.coulomb ** (-2)\n    * esp.units.PARTICLE ** (-1)\n).value_in_unit(esp.units.COULOMB_CONSTANT_UNIT)\n\n# =============================================================================\n# UTILITY FUNCTIONS FOR COMBINATION RULES FOR NONBONDED\n# =============================================================================\ndef geometric_mean(msg=\"m\", out=\"epsilon\"):\n    def _geometric_mean(nodes):\n        return {out: torch.prod(nodes.mailbox[msg], dim=1).pow(0.5)}\n\n    return _geometric_mean\n\n\ndef arithmetic_mean(msg=\"m\", out=\"sigma\"):\n    def _arithmetic_mean(nodes):\n        return {out: torch.sum(nodes.mailbox[msg], dim=1).mul(0.5)}\n\n    return _arithmetic_mean\n\n\n# =============================================================================\n# COMBINATION RULES FOR NONBONDED\n# =============================================================================\ndef lorentz_berthelot(g, suffix=\"\"):\n    import dgl\n\n    g.multi_update_all(\n        {\n            \"n1_as_%s_in_%s\"\n            % (pos_idx, term): (\n                dgl.function.copy_u(\n                    u=\"epsilon%s\" % suffix, out=\"m_epsilon\"\n                ),\n                geometric_mean(msg=\"m_epsilon\", out=\"epsilon%s\" % suffix),\n            )\n            for pos_idx in [0, 1]\n            for term in [\"nonbonded\", \"onefour\"]\n        },\n        cross_reducer=\"sum\",\n    )\n\n    g.multi_update_all(\n        {\n            \"n1_as_%s_in_%s\"\n            % (pos_idx, term): (\n                dgl.function.copy_u(u=\"sigma%s\" % suffix, out=\"m_sigma\"),\n                arithmetic_mean(msg=\"m_sigma\", out=\"sigma%s\" % suffix),\n            )\n            for pos_idx in [0, 1]\n            for term in [\"nonbonded\", \"onefour\"]\n        },\n        cross_reducer=\"sum\",\n    )\n\n    return g\n\n\ndef multiply_charges(g, suffix=\"\"):\n    \"\"\"Multiply the charges of atoms into nonbonded and onefour terms.\n\n    Parameters\n    ----------\n    g : dgl.HeteroGraph\n        Input graph.\n\n    Returns\n    -------\n    dgl.HeteroGraph : The modified graph with charges.\n\n    \"\"\"\n    import dgl\n\n    g.multi_update_all(\n        {\n            \"n1_as_%s_in_%s\"\n            % (pos_idx, term): (\n                dgl.function.copy_u(u=\"q%s\" % suffix, out=\"m_q\"),\n                dgl.function.sum(msg=\"m_q\", out=\"_q\")\n                # lambda node: {\"q%s\" % suffix: node.mailbox[\"m_q\"].prod(dim=1)}\n            )\n            for pos_idx in [0, 1]\n            for term in [\"nonbonded\", \"onefour\"]\n        },\n        cross_reducer=\"stack\",\n        apply_node_func=lambda node: {\"q\": node.data[\"_q\"].prod(dim=1)},\n    )\n\n    return g\n\n\n# =============================================================================\n# ENERGY FUNCTIONS\n# =============================================================================\ndef lj_12_6(x, sigma, epsilon):\n   
 \"\"\"Lennard-Jones 12-6.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    sigma : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    epsilon : `torch.Tensor`,\n        `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    \"\"\"\n\n    return esp.mm.functional.lj(x=x, sigma=sigma, epsilon=epsilon)\n\n\ndef lj_9_6(x, sigma, epsilon):\n    \"\"\"Lennard-Jones 9-6.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    sigma : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    epsilon : `torch.Tensor`,\n        `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n\n    \"\"\"\n\n    return esp.mm.functional.lj(\n        x=x, sigma=sigma, epsilon=epsilon, order=[9, 6], coefficients=[2, 3]\n    )\n\n\ndef coulomb(x, q, k_e=K_E):\n    \"\"\"Columb interaction without cutoff.\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, shape=`(batch_size, 1)` or `(batch_size, batch_size, 1)`\n        Distance between atoms.\n\n    q : `torch.Tensor`,\n        `shape=(batch_size, 1) or `(batch_size, batch_size, 1)`\n        Product of charge.\n\n    Returns\n    -------\n    torch.Tensor : `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)`\n        Coulomb energy.\n\n    Notes\n    -----\n    This computes half Coulomb energy to count for the duplication in onefour\n        and nonbonded enumerations.\n\n    \"\"\"\n    return 0.5 * k_e * q / x\n"
  },
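  {
    "path": "devtools/sketches/sketch_nonbonded.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the espaloma sources: the pairwise\nnonbonded terms on plain tensors. The file path and the parameter values are\nassumptions; espaloma's internal unit system is assumed for x, sigma, and q.\n\"\"\"\nimport torch\nimport espaloma as esp\n\nx = torch.linspace(0.3, 1.0, 10).repeat(4, 1)  # (n_pairs, n_snapshots) distances\nsigma = torch.full((4, 1), 0.35)\nepsilon = torch.full((4, 1), 0.5)\nq = torch.full((4, 1), -0.25)  # product of the two partial charges\n\nu_lj = esp.mm.nonbonded.lj_12_6(x, sigma, epsilon)\nu_coulomb = esp.mm.nonbonded.coulomb(x, q)  # includes the 0.5 double-counting factor\nprint(u_lj.shape, u_coulomb.shape)  # (4, 10) (4, 10)\n"
  },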
  {
    "path": "espaloma/mm/tests/system.xml",
    "content": "<?xml version=\"1.0\" ?>\n<System openmmVersion=\"7.7\" type=\"System\" version=\"1\">\n\t<PeriodicBoxVectors>\n\t\t<A x=\"2\" y=\"0\" z=\"0\"/>\n\t\t<B x=\"0\" y=\"2\" z=\"0\"/>\n\t\t<C x=\"0\" y=\"0\" z=\"2\"/>\n\t</PeriodicBoxVectors>\n\t<Particles>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"12.01\"/>\n\t\t<Particle mass=\"16\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t\t<Particle mass=\"1.008\"/>\n\t</Particles>\n\t<Constraints/>\n\t<Forces>\n\t\t<Force forceGroup=\"0\" name=\"HarmonicBondForce\" type=\"HarmonicBondForce\" usesPeriodic=\"0\" version=\"2\">\n\t\t\t<Bonds>\n\t\t\t\t<Bond d=\".15095000000000003\" k=\"273466.23999999993\" p1=\"0\" p2=\"1\"/>\n\t\t\t\t<Bond d=\".15095000000000003\" k=\"273466.23999999993\" p1=\"1\" p2=\"2\"/>\n\t\t\t\t<Bond d=\".13343000000000002\" k=\"476473.91999999987\" p1=\"1\" p2=\"3\"/>\n\t\t\t\t<Bond d=\".15095000000000003\" k=\"273466.23999999993\" p1=\"3\" p2=\"4\"/>\n\t\t\t\t<Bond d=\".15375000000000003\" k=\"251793.11999999994\" p1=\"4\" p2=\"5\"/>\n\t\t\t\t<Bond d=\".15095000000000003\" k=\"273466.23999999993\" p1=\"5\" p2=\"6\"/>\n\t\t\t\t<Bond d=\".15095000000000003\" k=\"273466.23999999993\" p1=\"6\" p2=\"7\"/>\n\t\t\t\t<Bond d=\".13461\" k=\"457980.6399999999\" p1=\"6\" p2=\"8\"/>\n\t\t\t\t<Bond d=\".14825\" k=\"296645.5999999999\" p1=\"8\" p2=\"9\"/>\n\t\t\t\t<Bond d=\".12183\" k=\"533627.36\" p1=\"9\" p2=\"10\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"0\" p2=\"11\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"0\" p2=\"12\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"0\" p2=\"13\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"2\" p2=\"14\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"2\" p2=\"15\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"2\" p2=\"16\"/>\n\t\t\t\t<Bond d=\".10879000000000001\" k=\"287106.07999999996\" p1=\"3\" p2=\"17\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"4\" p2=\"18\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"4\" p2=\"19\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"5\" p2=\"20\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"5\" p2=\"21\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"7\" p2=\"22\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"7\" p2=\"23\"/>\n\t\t\t\t<Bond d=\".10969000000000001\" k=\"276646.07999999996\" p1=\"7\" p2=\"24\"/>\n\t\t\t\t<Bond d=\".10883000000000001\" k=\"286603.99999999994\" p1=\"8\" p2=\"25\"/>\n\t\t\t\t<Bond d=\".11121000000000002\" k=\"259993.75999999995\" p1=\"9\" p2=\"26\"/>\n\t\t\t</Bonds>\n\t\t</Force>\n\t\t<Force forceGroup=\"0\" 
name=\"HarmonicAngleForce\" type=\"HarmonicAngleForce\" usesPeriodic=\"0\" version=\"2\">\n\t\t\t<Angles>\n\t\t\t\t<Angle a=\"2.018473279931442\" k=\"526.3472\" p1=\"0\" p2=\"1\" p3=\"2\"/>\n\t\t\t\t<Angle a=\"2.1577505542405895\" k=\"536.3888\" p1=\"0\" p2=\"1\" p3=\"3\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"0\" p3=\"11\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"0\" p3=\"12\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"0\" p3=\"13\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"2\" p3=\"14\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"2\" p3=\"15\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"1\" p2=\"2\" p3=\"16\"/>\n\t\t\t\t<Angle a=\"2.1577505542405895\" k=\"536.3888\" p1=\"1\" p2=\"3\" p3=\"4\"/>\n\t\t\t\t<Angle a=\"2.1019000181767713\" k=\"417.5632\" p1=\"1\" p2=\"3\" p3=\"17\"/>\n\t\t\t\t<Angle a=\"2.1577505542405895\" k=\"536.3888\" p1=\"2\" p2=\"1\" p3=\"3\"/>\n\t\t\t\t<Angle a=\"1.9470893135248741\" k=\"530.5312\" p1=\"3\" p2=\"4\" p3=\"5\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"3\" p2=\"4\" p3=\"18\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"3\" p2=\"4\" p3=\"19\"/>\n\t\t\t\t<Angle a=\"2.0189968787070405\" k=\"384.0912\" p1=\"4\" p2=\"3\" p3=\"17\"/>\n\t\t\t\t<Angle a=\"1.9470893135248741\" k=\"530.5312\" p1=\"4\" p2=\"5\" p3=\"6\"/>\n\t\t\t\t<Angle a=\"1.9163715186897738\" k=\"387.4384\" p1=\"4\" p2=\"5\" p3=\"20\"/>\n\t\t\t\t<Angle a=\"1.9163715186897738\" k=\"387.4384\" p1=\"4\" p2=\"5\" p3=\"21\"/>\n\t\t\t\t<Angle a=\"1.9163715186897738\" k=\"387.4384\" p1=\"5\" p2=\"4\" p3=\"18\"/>\n\t\t\t\t<Angle a=\"1.9163715186897738\" k=\"387.4384\" p1=\"5\" p2=\"4\" p3=\"19\"/>\n\t\t\t\t<Angle a=\"2.018473279931442\" k=\"526.3472\" p1=\"5\" p2=\"6\" p3=\"7\"/>\n\t\t\t\t<Angle a=\"2.1493729738310168\" k=\"535.552\" p1=\"5\" p2=\"6\" p3=\"8\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"6\" p2=\"5\" p3=\"20\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"6\" p2=\"5\" p3=\"21\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"6\" p2=\"7\" p3=\"22\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"6\" p2=\"7\" p3=\"23\"/>\n\t\t\t\t<Angle a=\"1.9261453625009421\" k=\"393.296\" p1=\"6\" p2=\"7\" p3=\"24\"/>\n\t\t\t\t<Angle a=\"2.101725485251572\" k=\"548.104\" p1=\"6\" p2=\"8\" p3=\"9\"/>\n\t\t\t\t<Angle a=\"2.0933479048419987\" k=\"415.05280000000005\" p1=\"6\" p2=\"8\" p3=\"25\"/>\n\t\t\t\t<Angle a=\"2.1493729738310168\" k=\"535.552\" p1=\"7\" p2=\"6\" p3=\"8\"/>\n\t\t\t\t<Angle a=\"2.150245638457014\" k=\"575.7184\" p1=\"8\" p2=\"9\" p3=\"10\"/>\n\t\t\t\t<Angle a=\"2.0052087776162852\" k=\"390.78560000000004\" p1=\"8\" p2=\"9\" p3=\"26\"/>\n\t\t\t\t<Angle a=\"2.0326104468725963\" k=\"389.112\" p1=\"9\" p2=\"8\" p3=\"25\"/>\n\t\t\t\t<Angle a=\"2.1066124071571557\" k=\"453.54560000000004\" p1=\"10\" p2=\"9\" p3=\"26\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"11\" p2=\"0\" p3=\"12\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"11\" p2=\"0\" p3=\"13\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"12\" p2=\"0\" p3=\"13\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"14\" p2=\"2\" p3=\"15\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"14\" p2=\"2\" p3=\"16\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"15\" p2=\"2\" 
p3=\"16\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"18\" p2=\"4\" p3=\"19\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"20\" p2=\"5\" p3=\"21\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"22\" p2=\"7\" p3=\"23\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"22\" p2=\"7\" p3=\"24\"/>\n\t\t\t\t<Angle a=\"1.8776252092954997\" k=\"329.6992\" p1=\"23\" p2=\"7\" p3=\"24\"/>\n\t\t\t</Angles>\n\t\t</Force>\n\t\t<Force forceGroup=\"0\" name=\"PeriodicTorsionForce\" type=\"PeriodicTorsionForce\" usesPeriodic=\"0\" version=\"2\">\n\t\t\t<Torsions>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"0\" p2=\"1\" p3=\"3\" p4=\"4\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"7.9496\" p1=\"0\" p2=\"1\" p3=\"3\" p4=\"4\" periodicity=\"1\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"0\" p2=\"1\" p3=\"3\" p4=\"17\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"1\" p2=\"3\" p3=\"4\" p4=\"18\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"1\" p2=\"3\" p3=\"4\" p4=\"18\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"1\" p2=\"3\" p3=\"4\" p4=\"19\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"1\" p2=\"3\" p3=\"4\" p4=\"19\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"2\" p2=\"1\" p3=\"3\" p4=\"4\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"7.9496\" p1=\"2\" p2=\"1\" p3=\"3\" p4=\"4\" periodicity=\"1\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"2\" p2=\"1\" p3=\"3\" p4=\"17\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"11\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"11\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"12\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"12\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"13\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"0\" p4=\"13\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"14\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"14\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"15\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"15\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"1.58992\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"16\" periodicity=\"3\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.811599999999999\" p1=\"3\" p2=\"1\" p3=\"2\" p4=\"16\" periodicity=\"1\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6508444444444444\" p1=\"3\" p2=\"4\" p3=\"5\" p4=\"6\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6508444444444444\" p1=\"3\" p2=\"4\" p3=\"5\" p4=\"20\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6508444444444444\" p1=\"3\" p2=\"4\" p3=\"5\" p4=\"21\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion 
k=\"27.823600000000003\" p1=\"5\" p2=\"6\" p3=\"8\" p4=\"9\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"5\" p2=\"6\" p3=\"8\" p4=\"25\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\".6508444444444444\" p1=\"6\" p2=\"5\" p3=\"4\" p4=\"18\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6508444444444444\" p1=\"6\" p2=\"5\" p3=\"4\" p4=\"19\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"9.1002\" p1=\"6\" p2=\"8\" p3=\"9\" p4=\"10\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"9.1002\" p1=\"6\" p2=\"8\" p3=\"9\" p4=\"26\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"7\" p2=\"6\" p3=\"8\" p4=\"9\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"27.823600000000003\" p1=\"7\" p2=\"6\" p3=\"8\" p4=\"25\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"9.1002\" p1=\"10\" p2=\"9\" p3=\"8\" p4=\"25\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\".6276\" p1=\"18\" p2=\"4\" p3=\"5\" p4=\"20\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6276\" p1=\"18\" p2=\"4\" p3=\"5\" p4=\"21\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6276\" p1=\"19\" p2=\"4\" p3=\"5\" p4=\"20\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\".6276\" p1=\"19\" p2=\"4\" p3=\"5\" p4=\"21\" periodicity=\"3\" phase=\"0\"/>\n\t\t\t\t<Torsion k=\"9.1002\" p1=\"25\" p2=\"8\" p3=\"9\" p4=\"26\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.6024\" p1=\"0\" p2=\"3\" p3=\"1\" p4=\"2\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.6024\" p1=\"1\" p2=\"4\" p3=\"3\" p4=\"17\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.6024\" p1=\"5\" p2=\"7\" p3=\"6\" p4=\"8\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"4.6024\" p1=\"6\" p2=\"9\" p3=\"8\" p4=\"25\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t\t<Torsion k=\"43.932\" p1=\"8\" p2=\"26\" p3=\"9\" p4=\"10\" periodicity=\"2\" phase=\"3.141592653589793\"/>\n\t\t\t</Torsions>\n\t\t</Force>\n\t\t<Force alpha=\"0\" cutoff=\"1\" dispersionCorrection=\"1\" ewaldTolerance=\".0005\" exceptionsUsePeriodic=\"0\" forceGroup=\"0\" includeDirectSpace=\"1\" ljAlpha=\"0\" ljnx=\"0\" ljny=\"0\" ljnz=\"0\" method=\"0\" name=\"NonbondedForce\" nx=\"0\" ny=\"0\" nz=\"0\" recipForceGroup=\"-1\" rfDielectric=\"78.3\" switchingDistance=\"-1\" type=\"NonbondedForce\" useSwitchingFunction=\"0\" version=\"4\">\n\t\t\t<GlobalParameters/>\n\t\t\t<ParticleOffsets/>\n\t\t\t<ExceptionOffsets/>\n\t\t\t<Particles>\n\t\t\t\t<Particle eps=\".4577296\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".359824\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".4577296\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".359824\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".4577296\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".4577296\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".359824\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".4577296\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".359824\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".359824\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Particle eps=\".87864\" q=\"0\" sig=\".2959921901149463\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" 
sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06276\" q=\"0\" sig=\".25996424595335105\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06568879999999999\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Particle eps=\".06276\" q=\"0\" sig=\".25996424595335105\"/>\n\t\t\t\t<Particle eps=\".06276\" q=\"0\" sig=\".2510552587719476\"/>\n\t\t\t</Particles>\n\t\t\t<Exceptions>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"1\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"2\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"2\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"3\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"3\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"2\" p2=\"3\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".2288648\" p1=\"0\" p2=\"4\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"4\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".2288648\" p1=\"2\" p2=\"4\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"3\" p2=\"4\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".20291752979375635\" p1=\"1\" p2=\"5\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"3\" p2=\"5\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"5\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".179912\" p1=\"3\" p2=\"6\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"6\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"6\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".2288648\" p1=\"4\" p2=\"7\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"7\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"7\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".20291752979375635\" p1=\"4\" p2=\"8\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"8\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"8\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"7\" p2=\"8\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".20291752979375635\" p1=\"5\" p2=\"9\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"9\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".20291752979375635\" p1=\"7\" p2=\"9\" q=\"0\" sig=\".3399669508423535\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"8\" p2=\"9\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".28113864878383404\" p1=\"6\" p2=\"10\" q=\"0\" sig=\".3179795704786499\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"8\" p2=\"10\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"9\" p2=\"10\" q=\"0\" 
sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"11\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"11\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"2\" p2=\"11\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"11\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"12\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"12\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"2\" p2=\"12\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"12\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"11\" p2=\"12\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"0\" p2=\"13\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"13\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"2\" p2=\"13\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"13\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"11\" p2=\"13\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"12\" p2=\"13\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"0\" p2=\"14\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"14\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"2\" p2=\"14\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"14\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"0\" p2=\"15\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"15\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"2\" p2=\"15\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"15\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"14\" p2=\"15\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"0\" p2=\"16\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"16\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"2\" p2=\"16\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"16\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"14\" p2=\"16\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"15\" p2=\"16\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08474536815661372\" p1=\"0\" p2=\"17\" q=\"0\" sig=\".2999655983978523\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"1\" p2=\"17\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08474536815661372\" p1=\"2\" p2=\"17\" q=\"0\" sig=\".2999655983978523\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"3\" p2=\"17\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"17\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08474536815661372\" p1=\"5\" p2=\"17\" q=\"0\" sig=\".2999655983978523\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"1\" p2=\"18\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"3\" p2=\"18\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"18\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"18\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"6\" p2=\"18\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".03210385135774211\" p1=\"17\" p2=\"18\" q=\"0\" sig=\".262458762364144\"/>\n\t\t\t\t<Exception 
eps=\".07687068162049819\" p1=\"1\" p2=\"19\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"3\" p2=\"19\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"19\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"19\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"6\" p2=\"19\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".03210385135774211\" p1=\"17\" p2=\"19\" q=\"0\" sig=\".262458762364144\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"18\" p2=\"19\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"20\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"20\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"20\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"20\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"7\" p2=\"20\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"8\" p2=\"20\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".032844399999999996\" p1=\"18\" p2=\"20\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Exception eps=\".032844399999999996\" p1=\"19\" p2=\"20\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"3\" p2=\"21\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"4\" p2=\"21\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"5\" p2=\"21\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"21\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"7\" p2=\"21\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"8\" p2=\"21\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".032844399999999996\" p1=\"18\" p2=\"21\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Exception eps=\".032844399999999996\" p1=\"19\" p2=\"21\" q=\"0\" sig=\".2649532787749369\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"20\" p2=\"21\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"5\" p2=\"22\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"22\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"7\" p2=\"22\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"8\" p2=\"22\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"5\" p2=\"23\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"23\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"7\" p2=\"23\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"8\" p2=\"23\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"22\" p2=\"23\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08670021359327784\" p1=\"5\" p2=\"24\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"24\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"7\" p2=\"24\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".07687068162049819\" p1=\"8\" p2=\"24\" q=\"0\" sig=\".3024601148086452\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"22\" p2=\"24\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"23\" p2=\"24\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08474536815661372\" p1=\"5\" p2=\"25\" q=\"0\" sig=\".2999655983978523\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"6\" p2=\"25\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".08474536815661372\" p1=\"7\" p2=\"25\" 
q=\"0\" sig=\".2999655983978523\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"8\" p2=\"25\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"9\" p2=\"25\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".11741320879696628\" p1=\"10\" p2=\"25\" q=\"0\" sig=\".2779782180341487\"/>\n\t\t\t\t<Exception eps=\".07513746442354838\" p1=\"6\" p2=\"26\" q=\"0\" sig=\".2955111048071506\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"8\" p2=\"26\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"9\" p2=\"26\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\"0\" p1=\"10\" p2=\"26\" q=\"0\" sig=\"1\"/>\n\t\t\t\t<Exception eps=\".03138\" p1=\"25\" p2=\"26\" q=\"0\" sig=\".25550975236264933\"/>\n\t\t\t</Exceptions>\n\t\t</Force>\n\t\t<Force forceGroup=\"0\" frequency=\"1\" name=\"CMMotionRemover\" type=\"CMMotionRemover\" version=\"1\"/>\n\t</Forces>\n</System>\n"
  },
  {
    "path": "espaloma/mm/tests/test_angle.py",
    "content": "import numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\n\n\ndef test_angle_random_vectors():\n    import espaloma as esp\n\n    distribution = torch.distributions.normal.Normal(\n        loc=torch.zeros(\n            3,\n        ),\n        scale=torch.ones(\n            3,\n        ),\n    )\n\n    left = distribution.sample()\n    right = distribution.sample()\n\n    cos_ref = (left * right).sum(dim=-1) / (\n        torch.norm(left) * torch.norm(right)\n    )\n\n    cos_hat = torch.cos(esp.mm.geometry._angle(left, right))\n\n    npt.assert_almost_equal(cos_ref.numpy(), cos_hat.numpy(), decimal=3)\n\n\ndef test_angle_random_points():\n    import espaloma as esp\n\n    distribution = torch.distributions.normal.Normal(\n        loc=torch.zeros(5, 3), scale=torch.ones(5, 3)\n    )\n\n    x0 = distribution.sample()\n    x1 = distribution.sample()\n    x2 = distribution.sample()\n\n    left = x1 - x0\n    right = x1 - x2\n\n    cos_ref = (left * right).sum(dim=-1) / (\n        torch.norm(left, dim=-1) * torch.norm(right, dim=-1)\n    )\n\n    cos_hat = torch.cos(esp.angle(x0, x1, x2))\n\n    npt.assert_almost_equal(cos_ref.numpy(), cos_hat.numpy(), decimal=3)\n\n\ndef test_zero():\n    import espaloma as esp\n\n    x0 = torch.zeros(5, 3)\n\n    npt.assert_almost_equal(esp.angle(x0, x0, x0).numpy(), 0.0)\n"
  },
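  {
    "path": "espaloma/mm/tests/sketch_angle_formulation.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. The tests above probe esp.mm.geometry._angle only through its\ncosine; the atan2 formulation below is one standard way to get the angle\nitself, better conditioned than acos of the normalized dot product when\nthe angle is near 0 or pi.\n\"\"\"\nimport torch\n\n\ndef stable_angle(left, right):\n    # tan(theta) = |u x v| / (u . v); atan2 recovers theta in [0, pi]\n    # without the gradient blow-up acos suffers at theta = 0 or pi\n    cross = torch.cross(left, right, dim=-1).norm(dim=-1)\n    dot = (left * right).sum(dim=-1)\n    return torch.atan2(cross, dot)\n\n\nif __name__ == \"__main__\":\n    u = torch.tensor([1.0, 0.0, 0.0])\n    v = torch.tensor([0.0, 1.0, 0.0])\n    print(stable_angle(u, v))  # ~pi/2\n"
  },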
  {
    "path": "espaloma/mm/tests/test_angle_energy.py",
    "content": "import numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\nimport openmm\nfrom openmm import unit\n\nfrom espaloma.utils.geometry import _sample_four_particle_torsion_scan\n\nomm_angle_unit = unit.radian\nomm_energy_unit = unit.kilojoule_per_mole\n\nfrom openmm import app\n\nimport espaloma as esp\n\n\ndef test_energy_angle_and_bond():\n    g = esp.Graph(\"C\")\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    # get simulation\n    esp_simulation = MoleculeVacuumSimulation(\n        n_samples=1, n_steps_per_sample=10, forcefield=\"gaff-1.81\"\n    )\n\n    simulation = esp_simulation.simulation_from_graph(g)\n    system = simulation.system\n    esp_simulation.run(g)\n\n    forces = list(system.getForces())\n\n    energies = {}\n\n    for idx, force in enumerate(forces):\n        force.setForceGroup(idx)\n\n        name = force.__class__.__name__\n\n        if \"Nonbonded\" in name:\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n\n    # create new simulation\n    _simulation = openmm.app.Simulation(\n        simulation.topology,\n        system,\n        openmm.VerletIntegrator(0.0),\n    )\n\n    _simulation.context.setPositions(\n        g.nodes[\"n1\"].data[\"xyz\"][:, 0, :].detach().numpy() * unit.nanometer\n    )\n\n    for idx, force in enumerate(forces):\n        name = force.__class__.__name__\n\n        state = _simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            groups=2**idx,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT\n        )\n\n        energies[name] = energy\n\n    for idx, force in enumerate(forces):\n        name = force.__class__.__name__\n        if \"HarmonicAngleForce\" in name:\n            print(\"openmm thinks there are %s angles\" % force.getNumAngles())\n\n            for _idx in range(force.getNumAngles()):\n                _, __, ___, eq, k = force.getAngleParameters(_idx)\n                eq = eq.value_in_unit(esp.units.ANGLE_UNIT)\n                k = k.value_in_unit(esp.units.ANGLE_FORCE_CONSTANT_UNIT)\n                print(eq, k)\n\n    # parametrize\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\n    g = ff.parametrize(g)\n\n    # n2 : bond, n3: angle, n1: nonbonded?\n    # n1 : sigma (k), epsilon (eq), and charge (not included yet)\n    for term in [\"n2\", \"n3\"]:\n        g.nodes[term].data[\"k\"] = g.nodes[term].data[\"k_ref\"]\n        g.nodes[term].data[\"eq\"] = g.nodes[term].data[\"eq_ref\"]\n\n    print(\n        \"espaloma thinks there are %s angles\"\n        % g.heterograph.number_of_nodes(\"n3\")\n    )\n    print(g.nodes[\"n3\"].data[\"k\"])\n    print(g.nodes[\"n3\"].data[\"eq\"])\n\n    # for each atom, store n_snapshots x 3\n    # g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n    #     simulation.context.getState(getPositions=True)\n    #     .getPositions(asNumpy=True)\n    #     .value_in_unit(esp.units.DISTANCE_UNIT),\n    #     dtype=torch.float32,\n    # )[None, :, :].permute(1, 0, 2)\n\n    # print(g.nodes['n2'].data)\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n    esp.mm.energy.energy_in_graph(g.heterograph, terms=[\"n2\", \"n3\", \"n4\"])\n\n    n_decimals = 3\n\n    # test angles\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n3\"].detach().numpy(),\n        energies[\"HarmonicAngleForce\"],\n        decimal=n_decimals,\n    )\n\n\nif __name__ == \"__main__\":\n    
test_energy_angle_and_bond()\n"
  },
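  {
    "path": "espaloma/mm/tests/sketch_harmonic_convention.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. The angle-energy test above copies k_ref/eq_ref into the\ngraph and asserts agreement with OpenMM, which only works if both sides\nshare the harmonic convention OpenMM documents for HarmonicBondForce and\nHarmonicAngleForce: u(x) = 0.5 * k * (x - eq)**2.\n\"\"\"\nimport torch\n\n\ndef harmonic_energy(x, k, eq):\n    # OpenMM convention: the factor of one half lives in the functional\n    # form, not in the force constant\n    return 0.5 * k * (x - eq) ** 2\n\n\nif __name__ == \"__main__\":\n    # k = 2 energy/length**2 displaced by one length unit gives 1 energy unit\n    print(harmonic_energy(torch.tensor(1.0), 2.0, 0.0))\n"
  },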
  {
    "path": "espaloma/mm/tests/test_bond_energy.py",
    "content": "import pytest\n\n\ndef test_multiple_conformation():\n    import espaloma as esp\n\n    g = esp.Graph(\"c1ccccc1\")\n\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n    g = simulation.run(g, in_place=True)\n\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    ).parametrize\n\n    g = param(g)\n\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n\n    esp.mm.energy.energy_in_graph(g.heterograph, suffix=\"_ref\")\n"
  },
  {
    "path": "espaloma/mm/tests/test_charge_energy_consistency.py",
    "content": "import pytest\nimport espaloma as esp\nimport numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\n\n\n@pytest.mark.parametrize(\n    \"g\",\n    esp.data.esol(first=10),  # use a subset of ESOL dataset to test\n    # [esp.Graph(\"c1ccccc1\")],\n)\ndef test_coulomb_energy_consistency(g):\n    \"\"\"We use both `esp.mm` and OpenMM to compute the Coulomb energy of\n    some molecules with generated geometries and see if the resulting Columb\n    energy matches.\n\n\n    \"\"\"\n    from openff.units import unit as openff_unit\n\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    print(g.mol)\n\n    # get simulation\n    esp_simulation = MoleculeVacuumSimulation(\n        n_samples=10,\n        n_steps_per_sample=10,\n        forcefield=\"gaff-1.81\",\n        charge_method=\"gasteiger\",\n    )\n\n    simulation = esp_simulation.simulation_from_graph(g)\n    charges = g.mol.partial_charges.m_as(openff_unit.elementary_charge).flatten()\n    system = simulation.system\n\n    esp_simulation.run(g, in_place=True)\n\n    # if MD blows up, forget about it\n    if g.nodes[\"n1\"].data[\"xyz\"].abs().max() > 100:\n        pytest.skip(\n            \"MD simulation blew up, skipping test. \"\n        )\n\n    g.nodes[\"n1\"].data[\"q\"] = torch.tensor(charges).unsqueeze(-1)\n    esp.mm.nonbonded.multiply_charges(g.heterograph)\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n    esp.mm.energy.energy_in_graph(\n        g.heterograph, terms=[\"nonbonded\", \"onefour\"]\n    )\n\n    print(g.nodes[\"g\"].data[\"u\"].detach())\n    print(esp.data.md.get_coulomb_force(g)[0])\n\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u\"].detach().numpy(),\n        esp.data.md.get_coulomb_force(g)[0].numpy(),\n        decimal=3,\n    )\n"
  },
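  {
    "path": "espaloma/mm/tests/sketch_vacuum_coulomb.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. Writes out the pairwise Coulomb sum that the consistency test\nabove checks, assuming nm / elementary-charge / kJ-per-mol units. The\nscale matrix is an assumption standing in for the exclusion bookkeeping\nespaloma does internally: 0 for 1-2 and 1-3 pairs, about 1/1.2 for 1-4\npairs (the AMBER convention), and 1 elsewhere.\n\"\"\"\nimport torch\n\n# Coulomb constant in kJ/mol * nm / e**2, the value OpenMM uses\nONE_4PI_EPS0 = 138.935456\n\n\ndef vacuum_coulomb(xyz, q, scale):\n    # xyz: (n_atoms, 3); q: (n_atoms,); scale: (n_atoms, n_atoms)\n    r = torch.cdist(xyz, xyz).clamp(min=1e-10)  # keep the diagonal finite\n    qq = q[:, None] * q[None, :]\n    pair = ONE_4PI_EPS0 * scale * qq / r\n    return pair.triu(diagonal=1).sum()  # each pair once, no self term\n"
  },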
  {
    "path": "espaloma/mm/tests/test_charge_energy_consistency_hardcode.py",
    "content": "import pytest\nimport espaloma as esp\nimport numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\nimport openmm\nfrom openmm import unit\n\n\n@pytest.mark.parametrize(\n    \"g\",\n    esp.data.esol(first=1),  # use a subset of ESOL dataset to test\n)\ndef test_coulomb_energy_consistency(g):\n    \"\"\"We use both `esp.mm` and OpenMM to compute the Coulomb energy of\n    some molecules with generated geometries and see if the resulting Columb\n    energy matches.\n\n\n    \"\"\"\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    # get simulation\n    esp_simulation = MoleculeVacuumSimulation(\n        n_samples=1,\n        n_steps_per_sample=10,\n        forcefield=\"gaff-1.81\",\n        charge_method=\"gasteiger\",\n    )\n\n    simulation = esp_simulation.simulation_from_graph(g)\n    charges = g.mol.partial_charges.flatten()\n    system = simulation.system\n\n    esp_simulation.run(g, in_place=True)\n\n    # if MD blows up, forget about it\n    if g.nodes[\"n1\"].data[\"xyz\"].abs().max() > 100:\n        pytest.skip(\n            \"MD simulation blew up, skipping test. \"\n        )\n\n    _simulation = openmm.app.Simulation(\n        simulation.topology,\n        system,\n        openmm.VerletIntegrator(0.0),\n    )\n\n    forces = list(system.getForces())\n    for force in forces:\n        name = force.__class__.__name__\n        if \"Nonbonded\" in name:\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n            force.updateParametersInContext(_simulation.context)\n\n    _simulation.context.setPositions(\n        g.nodes[\"n1\"].data[\"xyz\"][:, 0, :].detach().numpy() * unit.bohr\n    )\n\n    state = _simulation.context.getState(\n        getEnergy=True,\n        getParameters=True,\n    )\n\n    energy_old = state.getPotentialEnergy().value_in_unit(\n        esp.units.ENERGY_UNIT\n    )\n\n    forces = list(system.getForces())\n\n    print(forces)\n    for force in forces:\n        name = force.__class__.__name__\n        print(name)\n        if name == \"NonbondedForce\":\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n            print(force.getNumExceptions())\n            for idx in range(force.getNumParticles()):\n                q, sigma, epsilon = force.getParticleParameters(idx)\n                force.setParticleParameters(idx, 0.0, sigma, epsilon)\n\n            for idx in range(force.getNumExceptions()):\n                idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n                    idx\n                )\n                force.setExceptionParameters(\n                    idx, idx0, idx1, 0.0, sigma, epsilon\n                )\n\n            force.updateParametersInContext(_simulation.context)\n\n    state = _simulation.context.getState(\n        getEnergy=True,\n        getParameters=True,\n    )\n\n    energy_new = state.getPotentialEnergy().value_in_unit(\n        esp.units.ENERGY_UNIT\n    )\n\n    g.nodes[\"n1\"].data[\"q\"] = torch.tensor(charges).unsqueeze(-1)\n    esp.mm.nonbonded.multiply_charges(g.heterograph)\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n    esp.mm.energy.energy_in_graph(\n        g.heterograph, terms=[\"nonbonded\", \"onefour\"]\n    )\n\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u\"].item(),\n        energy_old - energy_new,\n        decimal=3,\n    )\n"
  },
  {
    "path": "espaloma/mm/tests/test_dihedral.py",
    "content": "import numpy.testing as npt\nimport torch\n\nimport espaloma as esp\nfrom espaloma.utils.geometry import (\n    _sample_four_particle_torsion_scan,\n    _timemachine_signed_torsion_angle,\n)\n\n\ndef test_dihedral_vectors():\n    import espaloma as esp\n\n    distribution = torch.distributions.normal.Normal(\n        loc=torch.zeros(5, 3), scale=torch.ones(5, 3)\n    )\n\n    left = distribution.sample()\n    right = distribution.sample()\n\n    npt.assert_almost_equal(\n        esp.mm.geometry._angle(left, right).numpy(),\n        esp.mm.geometry._dihedral(left, right).numpy(),\n        decimal=3,\n    )\n\n\ndef test_dihedral_points():\n    n_samples = 1000\n\n    # get geometries\n    xyz_np = _sample_four_particle_torsion_scan(n_samples)\n\n    # compute dihedrals using timemachine (numpy / JAX)\n    ci, cj, ck, cl = (\n        xyz_np[:, 0, :],\n        xyz_np[:, 1, :],\n        xyz_np[:, 2, :],\n        xyz_np[:, 3, :],\n    )\n    theta_timemachine = _timemachine_signed_torsion_angle(ci, cj, ck, cl)\n\n    # compute dihedrals using espaloma (PyTorch)\n    xyz = torch.tensor(xyz_np)\n    x0, x1, x2, x3 = xyz[:, 0, :], xyz[:, 1, :], xyz[:, 2, :], xyz[:, 3, :]\n    theta_espaloma = esp.dihedral(x0, x1, x2, x3).numpy()\n\n    npt.assert_almost_equal(\n        theta_timemachine,\n        theta_espaloma,\n        decimal=3,\n    )\n"
  },
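  {
    "path": "espaloma/mm/tests/sketch_signed_dihedral.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. A minimal signed dihedral in pure PyTorch using the atan2\nformulation; this is the same quantity the test above compares between\nesp.dihedral and timemachine's reference implementation.\n\"\"\"\nimport torch\n\n\ndef signed_dihedral(x0, x1, x2, x3):\n    b1, b2, b3 = x1 - x0, x2 - x1, x3 - x2\n    # normals of the planes spanned by consecutive bond vectors\n    n1 = torch.cross(b1, b2, dim=-1)\n    n2 = torch.cross(b2, b3, dim=-1)\n    # signed angle: atan2 of (n1 x n2) . b2_hat over n1 . n2\n    b2_hat = b2 / b2.norm(dim=-1, keepdim=True)\n    y = (torch.cross(n1, n2, dim=-1) * b2_hat).sum(-1)\n    x = (n1 * n2).sum(-1)\n    return torch.atan2(y, x)\n\n\nif __name__ == \"__main__\":\n    pts = torch.tensor(\n        [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, -1.0, 0.0]]\n    )\n    print(signed_dihedral(*pts))  # trans arrangement: ~pi\n"
  },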
  {
    "path": "espaloma/mm/tests/test_distance.py",
    "content": "import numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\n\n\ndef test_distance():\n    import espaloma as esp\n\n    distribution = torch.distributions.normal.Normal(\n        loc=torch.zeros(5, 3), scale=torch.ones(5, 3)\n    )\n\n    x0 = distribution.sample()\n    x1 = distribution.sample()\n\n    npt.assert_almost_equal(\n        esp.distance(x0, x1).numpy(),\n        torch.sqrt((x0 - x1).pow(2).sum(dim=-1)).numpy(),\n        decimal=3,\n    )\n\n    npt.assert_almost_equal(esp.distance(x0, x0).numpy(), 0.0)\n"
  },
  {
    "path": "espaloma/mm/tests/test_energy.py",
    "content": "import pytest\nimport torch\n\nimport espaloma as esp\n\n\ndef test_import():\n    esp.mm.energy\n\n\ndef test_energy():\n    g = esp.Graph(\"c1ccccc1\")\n\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n    g = simulation.run(g, in_place=True)\n\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    ).parametrize\n\n    g = param(g)\n\n    # parametrize\n    layer = esp.nn.dgl_legacy.gn()\n    net = torch.nn.Sequential(\n        esp.nn.Sequential(layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]),\n        esp.nn.readout.janossy.JanossyPooling(\n            in_features=32,\n            config=[32, \"tanh\"],\n            out_features={\n                1: [\"epsilon\", \"sigma\"],\n                2: [\"k\", \"eq\"],\n                3: [\"k\", \"eq\"],\n                4: [\"k\"],\n            },\n        ),\n        esp.nn.readout.janossy.JanossyPoolingImproper(\n            in_features=32,\n            config=[32, \"tanh\"],\n            out_features={\n                \"k\": 6,\n            },\n        ),\n    )\n\n    g = net(g.heterograph)\n\n    # print(g.nodes['n2'].data)\n    esp.mm.geometry.geometry_in_graph(g)\n    # esp.mm.energy.energy_in_graph(g)\n\n    esp.mm.energy.energy_in_graph(g, terms=[\"n2\", \"n3\", \"n4\", \"n4_improper\"])\n\n\n# def test_energy_consistent():\n#     g = esp.Graph(\"c1ccccc1\")\n#\n#     # make simulation\n#     from espaloma.data.md import MoleculeVacuumSimulation\n#\n#     simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n#     g = simulation.run(g, in_place=True)\n#\n#     param = esp.graphs.legacy_force_field.LegacyForceField(\n#         \"smirnoff99Frosst-1.1.0\"\n#     ).parametrize\n#\n#     g = param(g)\n#\n#     for node in [\"n1\", \"n2\", \"n3\"]:\n#         _dict = {}\n#         for data in g.nodes[node].data.keys():\n#             if data.endswith(\"_ref\"):\n#                 _dict[data.replace(\"_ref\", \"\")] = g.nodes[node].data[data]\n#         for key, value in _dict.items():\n#             g.nodes[node].data[key] = value\n#\n#     # print(g.nodes['n2'].data)\n#     esp.mm.geometry.geometry_in_graph(g.heterograph)\n#     esp.mm.energy.energy_in_graph(g.heterograph)\n#\n#     esp.mm.energy.energy_in_graph(g.heterograph, suffix=\"_ref\")\n"
  },
  {
    "path": "espaloma/mm/tests/test_energy_gaussian.py",
    "content": "import pytest\n\n\n\"\"\"\ndef test_energy():\n    import espaloma as esp\n    import torch\n\n    g = esp.Graph(\"c1ccccc1\")\n\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n    g = simulation.run(g, in_place=True)\n\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    ).parametrize\n\n    g = param(g)\n\n    # parametrize\n    layer = esp.nn.dgl_legacy.gn()\n    net = torch.nn.Sequential(\n        esp.nn.Sequential(layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]),\n        esp.nn.readout.janossy.JanossyPooling(\n            in_features=32, config=[32, \"tanh\"],\n            out_features={\n                1: {'sigma': 1, 'epsilon': 1},\n                2: {'coefficients': 200},\n                3: {'k':1, 'eq': 1},\n            },\n        ),\n    )\n\n    g = net(g.heterograph)\n\n    # print(g.nodes['n2'].data)\n    esp.mm.geometry.geometry_in_graph(g)\n    esp.mm.energy.energy_in_graph(g)\n\n    esp.mm.energy.energy_in_graph(g, suffix=\"_ref\")\n\"\"\"\n"
  },
  {
    "path": "espaloma/mm/tests/test_energy_ii.py",
    "content": "import pytest\nimport espaloma as esp\nimport torch\nimport dgl\n\n\ndef test_energy():\n    g = esp.Graph(\"c1ccccc1\")\n\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n    g = simulation.run(g, in_place=True)\n\n    param = esp.graphs.legacy_force_field.LegacyForceField(\n        \"gaff-1.81\"\n    ).parametrize\n\n    g = param(g)\n\n    # parametrize\n\n    # layer\n    layer = esp.nn.layers.dgl_legacy.gn()\n\n    # representation\n    representation = esp.nn.Sequential(\n        layer, config=[32, \"relu\", 32, \"relu\", 32, \"relu\"]\n    )\n\n    # get the last bit of units\n    units = 32\n\n    janossy_config = [32, \"relu\"]\n\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        in_features=units,\n        config=janossy_config,\n        out_features={\n            2: {\"log_coefficients\": 2},\n            3: {\n                \"log_coefficients\": 2,\n                \"coefficients_urey_bradley\": 2,\n                \"k_bond_bond\": 1,\n                \"k_bond_angle\": 1,\n                \"k_bond_angle\": 1,\n            },\n            4: {\n                \"k\": 6,\n                \"k_angle_angle\": 1,\n                \"k_angle_angle_torsion\": 1,\n                \"k_angle_torsion\": 1,\n                \"k_side_torsion\": 1,\n                \"k_center_torsion\": 1,\n            },\n        },\n    )\n\n    readout_improper = esp.nn.readout.janossy.JanossyPoolingImproper(\n        in_features=units, config=janossy_config\n    )\n\n    class ExpCoeff(torch.nn.Module):\n        def forward(self, g):\n            g.nodes[\"n2\"].data[\"coefficients\"] = (\n                g.nodes[\"n2\"].data[\"log_coefficients\"].exp()\n            )\n            g.nodes[\"n3\"].data[\"coefficients\"] = (\n                g.nodes[\"n3\"].data[\"log_coefficients\"].exp()\n            )\n            return g\n\n    class CarryII(torch.nn.Module):\n        def forward(self, g):\n            import math\n\n            g.multi_update_all(\n                {\n                    \"n2_as_0_in_n3\": (\n                        dgl.function.copy_u(\"u\", \"m_u_0\"),\n                        dgl.function.sum(\"m_u_0\", \"u_left\"),\n                    ),\n                    \"n2_as_1_in_n3\": (\n                        dgl.function.copy_u(\"u\", \"m_u_1\"),\n                        dgl.function.sum(\"m_u_1\", \"u_right\"),\n                    ),\n                    \"n2_as_0_in_n4\": (\n                        dgl.function.copy_u(\"u\", \"m_u_0\"),\n                        dgl.function.sum(\"m_u_0\", \"u_bond_left\"),\n                    ),\n                    \"n2_as_1_in_n4\": (\n                        dgl.function.copy_u(\"u\", \"m_u_1\"),\n                        dgl.function.sum(\"m_u_1\", \"u_bond_center\"),\n                    ),\n                    \"n2_as_2_in_n4\": (\n                        dgl.function.copy_u(\"u\", \"m_u_2\"),\n                        dgl.function.sum(\"m_u_2\", \"u_bond_right\"),\n                    ),\n                    \"n3_as_0_in_n4\": (\n                        dgl.function.copy_u(\"u\", \"m3_u_0\"),\n                        dgl.function.sum(\"m3_u_0\", \"u_angle_left\"),\n                    ),\n                    \"n3_as_1_in_n4\": (\n                        dgl.function.copy_u(\"u\", \"m3_u_1\"),\n                        dgl.function.sum(\"m3_u_1\", \"u_angle_right\"),\n                    ),\n           
     },\n                cross_reducer=\"sum\",\n            )\n\n            return g\n\n    net = torch.nn.Sequential(\n        representation,\n        readout,\n        readout_improper,\n        ExpCoeff(),\n        esp.mm.geometry.GeometryInGraph(),\n        esp.mm.energy.EnergyInGraph(terms=[\"n2\", \"n3\", \"n4\", \"n4_improper\"]),\n        CarryII(),\n        esp.mm.energy.EnergyInGraphII(),\n    )\n\n    torch.nn.init.normal_(\n        net[1].f_out_2_to_log_coefficients.bias,\n        mean=-5,\n    )\n    torch.nn.init.normal_(\n        net[1].f_out_3_to_log_coefficients.bias,\n        mean=-5,\n    )\n\n    for name, module in net[1].named_modules():\n        if \"k\" in name:\n            torch.nn.init.normal(module.bias, mean=0.0, std=1e-4)\n            torch.nn.init.normal(module.weight, mean=0.0, std=1e-4)\n\n    g = net(g.heterograph)\n\n    print(g.nodes[\"n3\"].data)\n    print(g.nodes[\"n4\"].data)\n\n    # print(g.nodes['n2'].data)\n    esp.mm.geometry.geometry_in_graph(g)\n    esp.mm.energy.energy_in_graph(g)\n"
  },
  {
    "path": "espaloma/mm/tests/test_geometry.py",
    "content": "import pytest\nimport torch\n\nimport espaloma as esp\nfrom espaloma.graphs.utils.regenerate_impropers import regenerate_impropers\n\n\ndef test_import():\n    esp.mm.geometry\n\n\n# later, if we want to do multiple molecules, group these into a struct\nsmiles = \"c1ccccc1\"\nn_samples = 2\n## Different number of expected terms for different improper permutations\nexpected_n_terms = {\n    \"none\": dict(n2=24, n3=36, n4=48, n4_improper=36),\n    \"espaloma\": dict(n2=24, n3=36, n4=48, n4_improper=36),\n    \"smirnoff\": dict(n2=24, n3=36, n4=48, n4_improper=18),\n}\n\n\n@pytest.fixture\ndef all_g():\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    all_g = {}\n    for improper_def in expected_n_terms.keys():\n        g = esp.Graph(smiles)\n        if improper_def != \"none\":\n            regenerate_impropers(g, improper_def)\n\n        simulation = MoleculeVacuumSimulation(\n            n_samples=n_samples, n_steps_per_sample=1\n        )\n        g = simulation.run(g, in_place=True)\n        all_g[improper_def] = g\n    return all_g\n\n\ndef test_geometry_can_be_computed_without_exceptions(all_g):\n    for g in all_g.values():\n        g = esp.mm.geometry.geometry_in_graph(g.heterograph)\n\n\ndef test_geometry_n_terms(all_g):\n    for improper_def, g in all_g.items():\n        g = esp.mm.geometry.geometry_in_graph(g.heterograph)\n\n        for term, n_terms in expected_n_terms[improper_def].items():\n            assert g.nodes[term].data[\"x\"].shape == torch.Size(\n                [n_terms, n_samples]\n            )\n"
  },
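  {
    "path": "espaloma/mm/tests/sketch_term_counts.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. Shows where the expected_n_terms numbers for benzene in the\ntest above come from: the counts indicate espaloma stores each bonded\nterm in both directions, so n2 is directed bonds, n3 directed 3-atom\npaths, and n4 directed 4-atom paths (impropers are counted by a separate\nscheme and are not reproduced here).\n\"\"\"\nfrom itertools import product\n\n\ndef directed_term_counts(adj):\n    n2 = sum(len(nbrs) for nbrs in adj.values())\n    n3 = sum(\n        1\n        for nbrs in adj.values()\n        for a, c in product(nbrs, nbrs)\n        if a != c\n    )\n    n4 = sum(\n        1\n        for b, nbrs in adj.items()\n        for c in nbrs\n        for a in adj[b]\n        for d in adj[c]\n        if a != c and d != b and a != d\n    )\n    return n2, n3, n4\n\n\nif __name__ == \"__main__\":\n    # benzene: ring carbons 0-5, one hydrogen (6-11) on each carbon\n    adj = {i: [(i - 1) % 6, (i + 1) % 6, i + 6] for i in range(6)}\n    adj.update({i + 6: [i] for i in range(6)})\n    print(directed_term_counts(adj))  # (24, 36, 48)\n"
  },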
  {
    "path": "espaloma/mm/tests/test_linear_combination.py",
    "content": "import pytest\n\n\ndef test_linear_combination():\n    import torch\n    import espaloma as esp\n\n    assert (\n        esp.mm.functional.linear_mixture(\n            0.0,\n            torch.tensor([[0.0, 0.0]]),\n        )\n        == 0.0\n    )\n    assert (\n        esp.mm.functional.linear_mixture(\n            1.0,\n            torch.tensor([[1.0, 1.0]]),\n            [0.0, 2.0],\n        )\n        == 1.0\n    )\n\n\ndef test_consistency():\n    import torch\n    import espaloma as esp\n\n    g = esp.Graph(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    simulation = MoleculeVacuumSimulation(n_samples=10, n_steps_per_sample=10)\n    g = simulation.run(g, in_place=True)\n\n    g.nodes[\"n2\"].data[\"coefficients\"] = torch.randn(\n        g.heterograph.number_of_nodes(\"n2\"), 2\n    ).exp()\n\n    g.nodes[\"n3\"].data[\"coefficients\"] = torch.randn(\n        g.heterograph.number_of_nodes(\"n3\"), 2\n    ).exp()\n\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n\n    esp.mm.energy.energy_in_graph(g.heterograph, terms=[\"n2\", \"n3\"])\n\n    u0_2 = g.nodes[\"n2\"].data[\"u\"] - g.nodes[\"n2\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n    u0_3 = g.nodes[\"n3\"].data[\"u\"] - g.nodes[\"n3\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n    u0 = g.nodes[\"g\"].data[\"u\"] - g.nodes[\"g\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n\n    (\n        g.nodes[\"n2\"].data[\"k\"],\n        g.nodes[\"n2\"].data[\"eq\"],\n    ) = esp.mm.functional.linear_mixture_to_original(\n        g.nodes[\"n2\"].data[\"coefficients\"][:, 0][:, None],\n        g.nodes[\"n2\"].data[\"coefficients\"][:, 1][:, None],\n        1.5,\n        6.0,\n    )\n\n    import math\n\n    (\n        g.nodes[\"n3\"].data[\"k\"],\n        g.nodes[\"n3\"].data[\"eq\"],\n    ) = esp.mm.functional.linear_mixture_to_original(\n        g.nodes[\"n3\"].data[\"coefficients\"][:, 0][:, None],\n        g.nodes[\"n3\"].data[\"coefficients\"][:, 1][:, None],\n        0.0,\n        math.pi,\n    )\n\n    g.nodes[\"n2\"].data.pop(\"coefficients\")\n    g.nodes[\"n3\"].data.pop(\"coefficients\")\n\n    esp.mm.energy.energy_in_graph(g.heterograph, terms=[\"n2\", \"n3\"])\n\n    u1_2 = g.nodes[\"n2\"].data[\"u\"] - g.nodes[\"n2\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n    u1_3 = g.nodes[\"n3\"].data[\"u\"] - g.nodes[\"n3\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n    u1 = g.nodes[\"g\"].data[\"u\"] - g.nodes[\"g\"].data[\"u\"].mean(\n        dim=1, keepdims=True\n    )\n\n    import numpy.testing as npt\n\n    npt.assert_almost_equal(\n        u0_2.detach().numpy(),\n        u1_2.detach().numpy(),\n        decimal=3,\n    )\n\n    npt.assert_almost_equal(\n        u0_3.detach().numpy(),\n        u1_3.detach().numpy(),\n        decimal=3,\n    )\n\n    npt.assert_almost_equal(\n        u0.detach().numpy(),\n        u1.detach().numpy(),\n        decimal=3,\n    )\n"
  },
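  {
    "path": "espaloma/mm/tests/sketch_linear_mixture_identity.py",
    "content": "\"\"\"Illustrative sketch; a hypothetical file, not part of the original\ntest suite. The completing-the-square identity the consistency test above\nappears to rely on: two harmonic wells pinned at fixed endpoints b1, b2\nwith nonnegative coefficients k1, k2 sum to a single harmonic well. The\nidentity holds only up to an additive constant, which is why the test\nsubtracts per-term means before comparing energies.\n\"\"\"\nimport torch\n\n\ndef mixture_to_k_eq(k1, k2, b1, b2):\n    # 0.5*k1*(x-b1)**2 + 0.5*k2*(x-b2)**2\n    #     = 0.5*(k1+k2)*(x-eq)**2 + const\n    # with eq the k-weighted mean of the endpoints; the same mapping holds\n    # for any common prefactor convention\n    k = k1 + k2\n    eq = (k1 * b1 + k2 * b2) / k\n    return k, eq\n\n\nif __name__ == \"__main__\":\n    k, eq = mixture_to_k_eq(2.0, 1.0, 1.5, 6.0)  # bounds used for n2 above\n    x = torch.linspace(0.0, 8.0, 5)\n    u_mix = 0.5 * 2.0 * (x - 1.5) ** 2 + 0.5 * 1.0 * (x - 6.0) ** 2\n    u_single = 0.5 * k * (x - eq) ** 2\n    print(u_mix - u_single)  # constant in x\n"
  },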
  {
    "path": "espaloma/mm/tests/test_openmm_consistency.py",
    "content": "import numpy as np\nimport numpy.testing as npt\nimport pytest\nimport torch\nimport openmm\nfrom openmm import unit\n\nfrom espaloma.utils.geometry import _sample_four_particle_torsion_scan\n\nomm_angle_unit = unit.radian\nomm_energy_unit = unit.kilojoule_per_mole\n\nfrom openmm import app\n\nimport espaloma as esp\n\ndecimal_threshold = 2\n\n\ndef _create_torsion_sim(\n    periodicity: int = 2, phase=0 * omm_angle_unit, k=10.0 * omm_energy_unit\n) -> app.Simulation:\n    \"\"\"Create a 4-particle OpenMM Simulation containing only a PeriodicTorsionForce\"\"\"\n    system = openmm.System()\n\n    # add 4 particles of unit mass\n    for _ in range(4):\n        system.addParticle(1)\n\n    # add torsion force to system\n    force = openmm.PeriodicTorsionForce()\n    force.addTorsion(0, 1, 2, 3, periodicity, phase, k)\n    system.addForce(force)\n\n    # create openmm Simulation, which requires a Topology and Integrator\n    topology = app.Topology()\n    chain = topology.addChain()\n    residue = topology.addResidue(\"torsion\", chain)\n    for name in [\"a\", \"b\", \"c\", \"d\"]:\n        topology.addAtom(name, \"C\", residue)\n    integrator = openmm.VerletIntegrator(1.0)\n    sim = app.Simulation(topology, system, integrator)\n\n    return sim\n\n\n# TODO: mark this properly: want to test periodicities 1..6, +ve, -ve k\n# @pytest.mark.parametrize(periodicity=[1,2,3,4,5,6], k=[-10 * omm_energy_unit, +10 * omm_energy_unit])\ndef test_periodic_torsion(\n    periodicity=4, k=10 * omm_energy_unit, n_samples=100\n):\n    \"\"\"Using simulated torsion scan, test if espaloma torsion energies and\n    OpenMM torsion energies agree.\n\n    \"\"\"\n    phase = 0 * omm_angle_unit  # all zero phases\n\n    # create torsion simulation\n    sim = _create_torsion_sim(periodicity=periodicity, phase=phase, k=k)\n\n    # grab snapshots from torsion scan\n    xyz_np = _sample_four_particle_torsion_scan(n_samples)\n\n    # compute energies using OpenMM\n    openmm_energies = np.zeros(n_samples)\n    for i, pos in enumerate(xyz_np):\n        sim.context.setPositions(pos)\n        openmm_energies[i] = (\n            sim.context.getState(getEnergy=True).getPotentialEnergy()\n            / omm_energy_unit\n        )\n\n    # compute energies using espaloma\n    xyz = torch.tensor(xyz_np)\n    x0, x1, x2, x3 = xyz[:, 0, :], xyz[:, 1, :], xyz[:, 2, :], xyz[:, 3, :]\n    theta = esp.mm.geometry.dihedral(x0, x1, x2, x3).reshape((n_samples, 1))\n    ks = torch.zeros(n_samples, 6)\n    ks[:, periodicity - 1] = k.value_in_unit(esp.units.ENERGY_UNIT)\n\n    espaloma_energies = (\n        esp.mm.functional.periodic(theta, ks).numpy().flatten()\n        * esp.units.ENERGY_UNIT\n    )\n    espaloma_energies_in_omm_units = espaloma_energies.value_in_unit(\n        omm_energy_unit\n    )\n\n    np.testing.assert_almost_equal(\n        actual=espaloma_energies_in_omm_units,\n        desired=openmm_energies,\n        decimal=decimal_threshold,\n    )\n\n\n# TODO: parameterize on the individual energy terms also\n@pytest.mark.parametrize(\n    \"g\",\n    esp.data.esol(first=10),\n)\ndef test_energy_angle_and_bond(g):\n    # make simulation\n    from espaloma.data.md import MoleculeVacuumSimulation\n\n    # get simulation\n    esp_simulation = MoleculeVacuumSimulation(\n        n_samples=1,\n        n_steps_per_sample=1000,\n        forcefield=\"gaff-1.81\",\n        charge_method=\"gasteiger\",\n    )\n\n    simulation = esp_simulation.simulation_from_graph(g)\n    system = simulation.system\n    
esp_simulation.run(g, in_place=True)\n\n    # if MD blows up, forget about it\n    if g.nodes[\"n1\"].data[\"xyz\"].abs().max() > 100:\n        pytest.skip(\"MD simulation blew up, skipping test.\")\n\n    forces = list(system.getForces())\n\n    energies = {}\n\n    for idx, force in enumerate(forces):\n        force.setForceGroup(idx)\n\n        name = force.__class__.__name__\n\n        if \"Nonbonded\" in name:\n            force.setNonbondedMethod(openmm.NonbondedForce.NoCutoff)\n\n            # epsilons = {}\n            # sigmas = {}\n\n            # for _idx in range(force.getNumParticles()):\n            #     q, sigma, epsilon = force.getParticleParameters(_idx)\n\n            #     # record parameters\n            #     epsilons[_idx] = epsilon\n            #     sigmas[_idx] = sigma\n\n            #     force.setParticleParameters(_idx, 0., sigma, epsilon)\n\n            # def sigma_combining_rule(sig1, sig2):\n            #     return (sig1 + sig2) / 2\n\n            # def eps_combining_rule(eps1, eps2):\n            #     return np.sqrt(np.abs(eps1 * eps2))\n\n            # for _idx in range(force.getNumExceptions()):\n            #     idx0, idx1, q, sigma, epsilon = force.getExceptionParameters(\n            #         _idx)\n            #     force.setExceptionParameters(\n            #         _idx,\n            #         idx0,\n            #         idx1,\n            #         0.0,\n            #         sigma_combining_rule(sigmas[idx0], sigmas[idx1]),\n            #         eps_combining_rule(epsilons[idx0], epsilons[idx1])\n            #     )\n\n            # force.updateParametersInContext(_simulation.context)\n\n    # create new simulation\n    _simulation = openmm.app.Simulation(\n        simulation.topology,\n        system,\n        openmm.VerletIntegrator(0.0),\n    )\n\n    _simulation.context.setPositions(\n        g.nodes[\"n1\"].data[\"xyz\"][:, 0, :].detach().numpy() * unit.bohr\n    )\n\n    for idx, force in enumerate(forces):\n        name = force.__class__.__name__\n\n        state = _simulation.context.getState(\n            getEnergy=True,\n            getParameters=True,\n            groups=2**idx,\n        )\n\n        energy = state.getPotentialEnergy().value_in_unit(\n            esp.units.ENERGY_UNIT\n        )\n\n        energies[name] = energy\n\n    # parametrize\n    ff = esp.graphs.legacy_force_field.LegacyForceField(\"gaff-1.81\")\n    g = ff.parametrize(g)\n\n    # n2 : bond, n3: angle, n1: nonbonded?\n    # n1 : sigma (k), epsilon (eq), and charge (not included yet)\n    for term in [\"n2\", \"n3\"]:\n        g.nodes[term].data[\"k\"] = g.nodes[term].data[\"k_ref\"]\n        g.nodes[term].data[\"eq\"] = g.nodes[term].data[\"eq_ref\"]\n\n    \"\"\"\n    for term in [\"n1\"]:\n        g.nodes[term].data[\"sigma\"] = g.nodes[term].data[\"sigma_ref\"]\n        g.nodes[term].data[\"epsilon\"] = g.nodes[term].data[\"epsilon_ref\"]\n        # g.nodes[term].data['q'] = g.nodes[term].data['q_ref']\n    \"\"\"\n\n    for term in [\"n4\"]:\n        g.nodes[term].data[\"phases\"] = g.nodes[term].data[\"phases_ref\"]\n        g.nodes[term].data[\"periodicity\"] = g.nodes[term].data[\n            \"periodicity_ref\"\n        ]\n        g.nodes[term].data[\"k\"] = g.nodes[term].data[\"k_ref\"]\n\n    # for each atom, store n_snapshots x 3\n    # g.nodes[\"n1\"].data[\"xyz\"] = torch.tensor(\n    #     simulation.context.getState(getPositions=True)\n    #     .getPositions(asNumpy=True)\n    #     .value_in_unit(esp.units.DISTANCE_UNIT),\n    #     
dtype=torch.float32,\n    # )[None, :, :].permute(1, 0, 2)\n\n    # print(g.nodes['n2'].data)\n    esp.mm.geometry.geometry_in_graph(g.heterograph)\n    esp.mm.energy.energy_in_graph(g.heterograph, terms=[\"n2\", \"n3\", \"n4\"])\n    # writes into nodes\n    # .data['u_nonbonded'], .data['u_onefour'], .data['u2'], .data['u3'],\n\n    # TODO: consider more carefully how many decimals of precision are needed\n    n_decimals = 3\n\n    # test bonds\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n2\"].detach().numpy(),\n        energies[\"HarmonicBondForce\"],\n        decimal=n_decimals,\n    )\n\n    # test angles\n    npt.assert_almost_equal(\n        g.nodes[\"g\"].data[\"u_n3\"].detach().numpy(),\n        energies[\"HarmonicAngleForce\"],\n        decimal=n_decimals,\n    )\n\n    # propers = g.nodes[\"g\"].data[\"u_n4\"].detach().numpy()\n    # impropers =  g.nodes[\"g\"].data[\"u_n4_improper\"].detach().numpy()\n    # all_torsions = propers + impropers\n    # npt.assert_almost_equal(\n    #     all_torsions,\n    #     energies[\"PeriodicTorsionForce\"],\n    #     decimal=n_decimals,\n    # )\n\n    # print(all_torsions)\n    # print(energies[\"PeriodicTorsionForce\"])\n\n    # TODO:\n    # This is not working now, matching OpenMM nonbonded.\n    # test nonbonded\n    # TODO: must set all charges to zero in _simulation for this to pass currently, since g doesn't have any charges\n    # npt.assert_almost_equal(\n    #     g.nodes['g'].data['u_nonbonded'].numpy()\\\n    #     + g.nodes['g'].data['u_onefour'].numpy(),\n    #     energies['NonbondedForce'],\n    #     decimal=3,\n    # )\n"
  },
  {
    "path": "espaloma/mm/tests/test_recoverability.py",
    "content": "# Check whether we can recover a molecular mechanics model containing just one kind of term\n# Initially, interested in recovering a molecular mechanics model containing only improper torsion terms\n\nimport numpy as np\nfrom openff.toolkit.topology import Molecule, Topology\nfrom openff.toolkit.typing.engines.smirnoff import ForceField\nimport openmm as mm\nimport pytest\nimport espaloma as esp\n\nimport torch\n\n\ndef _create_impropers_only_system(\n    smiles: str = \"CC1=C(C(=O)C2=C(C1=O)N3CC4C(C3(C2COC(=O)N)OC)N4)N\",\n) -> mm.System:\n    \"\"\"Create a simulation that contains only improper torsion terms,\n    by parameterizing with openff-1.2.0 and deleting  all terms but impropers\n    \"\"\"\n\n    molecule = Molecule.from_smiles(smiles, allow_undefined_stereo=True)\n    g = esp.Graph(molecule)\n\n    topology = Topology.from_molecules(molecule)\n    forcefield = ForceField(\"openff-1.2.0.offxml\")\n    openmm_system = forcefield.create_openmm_system(topology)\n\n    # delete all forces except PeriodicTorsionForce\n    is_torsion = (\n        lambda force: \"PeriodicTorsionForce\" in force.__class__.__name__\n    )\n    for i in range(openmm_system.getNumForces())[::-1]:\n        if not is_torsion(openmm_system.getForce(i)):\n            openmm_system.removeForce(i)\n    assert openmm_system.getNumForces() == 1\n    torsion_force = openmm_system.getForce(0)\n    assert is_torsion(torsion_force)\n\n    # set k = 0 for any torsion that's not an improper\n    indices = set(\n        map(\n            tuple,\n            esp.graphs.utils.offmol_indices.improper_torsion_indices(\n                molecule\n            ),\n        )\n    )\n    num_impropers_retained = 0\n    for i in range(torsion_force.getNumTorsions()):\n        (\n            p1,\n            p2,\n            p3,\n            p4,\n            periodicity,\n            phase,\n            k,\n        ) = torsion_force.getTorsionParameters(i)\n\n        if (p1, p2, p3, p4) in indices:\n            num_impropers_retained += 1\n        else:\n            torsion_force.setTorsionParameters(\n                i, p1, p2, p3, p4, periodicity, phase, 0.0\n            )\n\n    assert (\n        num_impropers_retained > 0\n    )  # otherwise this molecule is not a useful test case!\n\n    return openmm_system, topology, g\n\n\n@pytest.mark.skip(reason=\"too slow\")\ndef test_improper_recover():\n    import openmm\n    from openmm import unit\n    from openmm.app import Simulation\n    from openmm.unit import Quantity\n\n    TEMPERATURE = 500 * unit.kelvin\n    STEP_SIZE = 1 * unit.femtosecond\n    COLLISION_RATE = 1 / unit.picosecond\n\n    system, topology, g = _create_impropers_only_system()\n\n    # use langevin integrator, although it's not super useful here\n    integrator = openmm.LangevinIntegrator(\n        TEMPERATURE, COLLISION_RATE, STEP_SIZE\n    )\n\n    # initialize simulation\n    simulation = Simulation(\n        topology=topology, system=system, integrator=integrator\n    )\n\n    import openff.toolkit\n\n    # get conformer\n    g.mol.generate_conformers(\n        toolkit_registry=openff.toolkit.utils.RDKitToolkitWrapper(),\n    )\n\n    # put conformer in simulation\n    simulation.context.setPositions(g.mol.conformers[0])\n\n    # minimize energy\n    simulation.minimizeEnergy()\n\n    # set velocities\n    simulation.context.setVelocitiesToTemperature(TEMPERATURE)\n\n    samples = []\n    us = []\n\n    # loop through number of samples\n    for _ in range(10):\n\n        # run MD for 
`self.n_steps_per_sample` steps\n        simulation.step(10)\n\n        # append samples to `samples`\n        samples.append(\n            simulation.context.getState(getPositions=True)\n            .getPositions(asNumpy=True)\n            .value_in_unit(esp.units.DISTANCE_UNIT)\n        )\n\n        us.append(\n            simulation.context.getState(getEnergy=True)\n            .getPotentialEnergy()\n            .value_in_unit(esp.units.ENERGY_UNIT)\n        )\n\n    # put samples into an array\n    samples = np.array(samples)\n    us = np.array(us)\n\n    # put samples into tensor\n    samples = torch.tensor(samples, dtype=torch.float32)\n    us = torch.tensor(us, dtype=torch.float32)[None, :, None]\n\n    g.heterograph.nodes[\"n1\"].data[\"xyz\"] = samples.permute(1, 0, 2)\n\n    # require gradient for force matching\n    g.heterograph.nodes[\"n1\"].data[\"xyz\"].requires_grad = True\n\n    g.heterograph.nodes[\"g\"].data[\"u_ref\"] = us\n\n    # parametrize\n    layer = esp.nn.dgl_legacy.gn()\n    net = torch.nn.Sequential(\n        esp.nn.Sequential(layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]),\n        esp.nn.readout.janossy.JanossyPoolingImproper(\n            in_features=32,\n            config=[32, \"tanh\"],\n            out_features={\n                \"k\": 6,\n            },\n        ),\n        esp.mm.geometry.GeometryInGraph(),\n        esp.mm.energy.EnergyInGraph(terms=[\"n4_improper\"]),\n    )\n\n    optimizer = torch.optim.Adam(net.parameters(), 1e-3)\n\n    for _ in range(1500):\n        optimizer.zero_grad()\n\n        net(g.heterograph)\n        u_ref = g.nodes[\"g\"].data[\"u\"]\n        u = g.nodes[\"g\"].data[\"u_ref\"]\n        loss = torch.nn.MSELoss()(u_ref, u)\n        loss.backward()\n        print(loss)\n        optimizer.step()\n\n    assert loss.detach().numpy().item() < 0.1\n\n\n# caffeine_smiles = 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C'\n#\n#\n# def _create_random_impropers_only_system(smiles: str = caffeine_smiles, k_stddev: float = 10.0) -> mm.System:\n#     \"\"\"Create an OpenMM system that contains only a large number of improper torsion terms,\n#     assigning random coefficients ~ N(0, k_stddev) kJ/mol\"\"\"\n#\n#     molecule = Molecule.from_smiles(smiles, allow_undefined_stereo=True)\n#\n#     topology = Topology.from_molecules(molecule)\n#     forcefield = ForceField('openff-1.2.0.offxml')\n#     openmm_system = forcefield.create_openmm_system(topology)\n#\n#     # delete all forces\n#     while openmm_system.getNumForces() > 0:\n#         openmm_system.removeForce(0)\n#\n#     # add a torsion force\n#     torsion_force = mm.PeriodicTorsionForce()\n#\n#     # for each improper torsion abcd, sample a periodicity, phase, and k, then add 3 terms to torsion_force\n#     # with different indices abcd, acdb, adbc but identical periodicity, phase, and k\n#     indices = esp.graphs.utils.offmol_indices.improper_torsion_indices(molecule)\n#     improper_perms = [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]\n#\n#     for inds in indices:\n#         periodicity = np.random.randint(1, 7)\n#         phase = 0\n#         k = np.random.randn() * k_stddev\n#         for perm in improper_perms:\n#             p1, p2, p3, p4 = [int(inds[p]) for p in perm]  # careful to pass python ints rather than np ints to openmm\n#             torsion_force.addTorsion(p1, p2, p3, p4, periodicity, phase, k)\n#\n#     openmm_system.addForce(torsion_force)\n#\n#     return openmm_system\n\n# TODO: integration test where we recover this molecular mechanics system from 
energies/forces\n"
  },
  {
    "path": "espaloma/mm/torsion.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\ndef periodic_torsion(\n    x, k, periodicity=list(range(1, 7)), phases=[0.0 for _ in range(6)]\n):\n    \"\"\"Periodic torsion potential\n\n    Parameters\n    ----------\n    x : `torch.Tensor`, `shape = (batch_size, 1)`\n        Dihedral value.\n    k : `torch.Tensor`, `shape = (batch_size, n_phases)`\n        Force constants.\n    periodicity : `torch.Tensor`, `shape = (batch_size, n_phases)`\n        Periodicities\n    phases : `torch.Tensor`, `shape = (batch_size, n_phases)`\n        Phase offsets\n\n    Returns\n    -------\n    u : `torch.Tensor`, `shape = (batch_size, 1)`\n        Energy.\n\n    \"\"\"\n\n    # NOTE:\n    # 0.5 because all torsions are calculated twice\n    out = 0.5 * esp.mm.functional.periodic(\n        x=x,\n        k=k,\n        periodicity=periodicity,\n        phases=phases,\n    )\n    # assert(out.shape == (len(x), 1))\n    return out\n\n\ndef angle_angle(\n    u_angle_left,\n    u_angle_right,\n    k_angle_angle,\n):\n\n    u_angle_left = u_angle_left - u_angle_left.min(dim=-1, keepdims=True)[0]\n    u_angle_right = (\n        u_angle_right - u_angle_right.min(dim=-1, keepdims=True)[0]\n    )\n    return k_angle_angle * (u_angle_left**0.5) * (u_angle_right**0.5)\n\n\ndef angle_torsion(\n    u_angle_left,\n    u_angle_right,\n    u_torsion,\n    k_angle_torsion,\n):\n    u_angle_left = u_angle_left - u_angle_left.min(dim=-1, keepdims=True)[0]\n    u_angle_right = (\n        u_angle_right - u_angle_right.min(dim=-1, keepdims=True)[0]\n    )\n    return (\n        k_angle_torsion * (u_angle_left**0.5) * u_torsion\n        + k_angle_torsion * (u_angle_right**0.5) * u_torsion\n    )\n\n\ndef angle_angle_torsion(\n    u_angle_left,\n    u_angle_right,\n    u_torsion,\n    k_angle_angle_torsion,\n):\n    u_angle_left = u_angle_left - u_angle_left.min(dim=-1, keepdims=True)[0]\n    u_angle_right = (\n        u_angle_right - u_angle_right.min(dim=-1, keepdims=True)[0]\n    )\n    return (\n        k_angle_angle_torsion\n        * (u_angle_left**0.5)\n        * (u_angle_right**0.5)\n        * u_torsion\n    )\n\n\ndef bond_torsion(\n    u_bond_left,\n    u_bond_right,\n    u_bond_center,\n    u_torsion,\n    k_side_torsion,\n    k_center_torsion,\n):\n\n    u_bond_left = u_bond_left - u_bond_left.min(dim=-1, keepdims=True)[0]\n    u_bond_right = u_bond_right - u_bond_right.min(dim=-1, keepdims=True)[0]\n    u_bond_center = (\n        u_bond_center - u_bond_center.min(dim=-1, keepdims=True)[0]\n    )\n    return (\n        k_side_torsion * u_torsion * (u_bond_left**0.5)\n        + k_side_torsion * u_torsion * (u_bond_right**0.5)\n        + k_center_torsion * u_torsion * (u_bond_center**0.5)\n    )\n"
  },
  {
    "path": "espaloma/nn/__init__.py",
    "content": "from . import baselines, layers, readout, sequential\nfrom .layers import dgl_legacy\nfrom .sequential import Sequential\n"
  },
  {
    "path": "espaloma/nn/baselines.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass FreeParameterBaseline(torch.nn.Module):\n    \"\"\"Parametrize a graph by populating the parameters with free\n    `torch.nn.Parameter`.\n\n\n    \"\"\"\n\n    def __init__(self, g_ref):\n        super(FreeParameterBaseline, self).__init__()\n        self.g_ref = g_ref\n\n        # whenever there is a reference parameter,\n        # assign a `torch.nn.Parameter`\n        for term in self.g_ref.ntypes:\n            for param, param_value in self.g_ref.nodes[term].data.items():\n                if param.endswith(\"_ref\") and \"u\" not in param:\n                    setattr(\n                        self,\n                        \"%s_%s\" % (term, param.replace(\"_ref\", \"\")),\n                        torch.nn.Parameter(\n                            torch.zeros_like(\n                                param_value.clone().detach(),\n                            )\n                        ),\n                    )\n\n    def forward(self, g):\n        update_dicts = {node: {} for node in self.g_ref.ntypes}\n\n        for term in self.g_ref.ntypes:\n            for param, param_value in self.g_ref.nodes[term].data.items():\n                if param.endswith(\"_ref\"):\n                    if hasattr(\n                        self, \"%s_%s\" % (term, param.replace(\"_ref\", \"\"))\n                    ):\n\n                        update_dicts[term][\n                            param.replace(\"_ref\", \"\")\n                        ] = getattr(\n                            self,\n                            \"%s_%s\" % (term, param.replace(\"_ref\", \"\")),\n                        )\n\n        for node, update_dict in update_dicts.items():\n            for param, param_value in update_dict.items():\n                g.nodes[node].data[param] = param_value\n\n        return g\n\n\nclass FreeParameterBaselineInitMean(torch.nn.Module):\n    \"\"\"Parametrize a graph by populating the parameters with free\n    `torch.nn.Parameter`.\n\n    \"\"\"\n\n    def __init__(self, g_ref):\n        super(FreeParameterBaselineInitMean, self).__init__()\n        self.g_ref = g_ref\n\n        # whenever there is a reference parameter,\n        # assign a `torch.nn.Parameter`\n        for term in self.g_ref.ntypes:\n            for param, param_value in self.g_ref.nodes[term].data.items():\n                if param.endswith(\"_ref\") and \"u\" not in param:\n                    setattr(\n                        self,\n                        \"%s_%s\" % (term, param.replace(\"_ref\", \"\")),\n                        torch.nn.Parameter(\n                            torch.ones_like(\n                                param_value.clone().detach(),\n                            )\n                            * param_value.clone().detach().mean()\n                        ),\n                    )\n\n    def forward(self, g):\n        update_dicts = {node: {} for node in self.g_ref.ntypes}\n\n        for term in self.g_ref.ntypes:\n            for param, param_value in self.g_ref.nodes[term].data.items():\n                if param.endswith(\"_ref\"):\n                    if hasattr(\n                        self, \"%s_%s\" % (term, 
param.replace(\"_ref\", \"\"))\n                    ):\n\n                        update_dicts[term][\n                            param.replace(\"_ref\", \"\")\n                        ] = getattr(\n                            self,\n                            \"%s_%s\" % (term, param.replace(\"_ref\", \"\")),\n                        )\n\n        for node, update_dict in update_dicts.items():\n            for param, param_value in update_dict.items():\n                g.nodes[node].data[param] = param_value\n\n        return g\n"
  },
  {
    "path": "espaloma/nn/layers/__init__.py",
    "content": "import espaloma.nn.layers.dgl_legacy\n"
  },
  {
    "path": "espaloma/nn/layers/dgl_legacy.py",
    "content": "\"\"\" Legacy models from DGL.\n\n\"\"\"\n\n# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n# =============================================================================\n# CONSTANT\n# =============================================================================\nDEFAULT_MODEL_KWARGS = {\n    \"SAGEConv\": {\"aggregator_type\": \"mean\"},\n    \"GATConv\": {\"num_heads\": 4},\n    \"TAGConv\": {\"k\": 2},\n}\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass GN(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        model_name=\"GraphConv\",\n        kwargs={},\n    ):\n        super(GN, self).__init__()\n        from dgl.nn import pytorch as dgl_pytorch\n\n        if kwargs == {}:\n            if model_name in DEFAULT_MODEL_KWARGS:\n                kwargs = DEFAULT_MODEL_KWARGS[model_name]\n\n        self.gn = getattr(dgl_pytorch.conv, model_name)(\n            in_features, out_features, **kwargs\n        )\n\n        # register these properties here for downstream handling\n        self.in_features = in_features\n        self.out_features = out_features\n\n    def forward(self, g, x):\n        return self.gn(g, x)\n\n\n# =============================================================================\n# MODULE FUNCTIONS\n# =============================================================================\n\n\ndef gn(model_name=\"GraphConv\", kwargs={}):\n    from dgl.nn import pytorch as dgl_pytorch\n\n    if model_name == \"GINConv\":\n        return lambda in_features, out_features: dgl_pytorch.conv.GINConv(\n            apply_func=torch.nn.Linear(in_features, out_features),\n            aggregator_type=\"sum\",\n        )\n\n    else:\n        return lambda in_features, out_features: GN(\n            in_features=in_features,\n            out_features=out_features,\n            model_name=model_name,\n            kwargs=kwargs,\n        )\n"
  },
  {
    "path": "espaloma/nn/readout/__init__.py",
    "content": "from . import janossy, graph_level_readout, node_typing, charge_equilibrium\n"
  },
  {
    "path": "espaloma/nn/readout/base_readout.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport abc\n\nimport torch\n\n\n# =============================================================================\n# BASE CLASSES\n# =============================================================================\nclass BaseReadout(abc.ABC, torch.nn.Module):\n    \"\"\"Base class for readout function.\"\"\"\n\n    def __init__(self):\n        super(BaseReadout, self).__init__()\n\n    @abc.abstractmethod\n    def forward(self, g, x=None, *args, **kwargs):\n        raise NotImplementedError\n\n    def _forward(self, g, x, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "espaloma/nn/readout/charge_equilibrium.py",
    "content": "\"\"\" Charge equilibrium.ß\n\n\"\"\"\n# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\n# =============================================================================\n# UTILITY FUNCTIONS\n# =============================================================================\ndef get_charges(node):\n    \"\"\" Solve the function to get the absolute charges of atoms in a\n    molecule from parameters.\n    Parameters\n    ----------\n    e : tf.Tensor, dtype = tf.float32,\n        electronegativity.\n    s : tf.Tensor, dtype = tf.float32,\n        hardness.\n    Q : tf.Tensor, dtype = tf.float32, shape=(),\n        total charge of a molecule.\n    We use Lagrange multipliers to analytically give the solution.\n    $$\n    U({\\bf q})\n    &= \\sum_{i=1}^N \\left[ e_i q_i +  \\frac{1}{2}  s_i q_i^2\\right]\n        - \\lambda \\, \\left( \\sum_{j=1}^N q_j - Q \\right) \\\\\n    &= \\sum_{i=1}^N \\left[\n        (e_i - \\lambda) q_i +  \\frac{1}{2}  s_i q_i^2 \\right\n        ] + Q\n    $$\n    This gives us:\n    $$\n    q_i^*\n    &= - e_i s_i^{-1}\n    + \\lambda s_i^{-1} \\\\\n    &= - e_i s_i^{-1}\n    + s_i^{-1} \\frac{\n        Q +\n         \\sum\\limits_{i=1}^N e_i \\, s_i^{-1}\n        }{\\sum\\limits_{j=1}^N s_j^{-1}}\n    $$\n    \"\"\"\n    e = node.data[\"e\"]\n    s = node.data[\"s\"]\n    sum_e_s_inv = node.data[\"sum_e_s_inv\"]\n    sum_s_inv = node.data[\"sum_s_inv\"]\n    sum_q = node.data[\"sum_q\"]\n\n    return {\n        \"q\": -e * s**-1\n        + (s**-1) * torch.div(sum_q + sum_e_s_inv, sum_s_inv)\n    }\n\n\n# =============================================================================\n# MODULE CLASS\n# =============================================================================\nclass ChargeEquilibrium(torch.nn.Module):\n    \"\"\"Charge equilibrium within batches of molecules.\"\"\"\n\n    def __init__(self):\n        super(ChargeEquilibrium, self).__init__()\n\n    def forward(self, g, total_charge=0.0):\n        \"\"\"apply charge equilibrium to all molecules in batch\"\"\"\n        # calculate $s ^ {-1}$ and $ es ^ {-1}$\n        import dgl\n\n        g.apply_nodes(\n            lambda node: {\"s_inv\": node.data[\"s\"] ** -1}, ntype=\"n1\"\n        )\n\n        g.apply_nodes(\n            lambda node: {\"e_s_inv\": node.data[\"e\"] * node.data[\"s\"] ** -1},\n            ntype=\"n1\",\n        )\n\n        if \"sum_q\" not in g.nodes[\"g\"].data:\n            if \"q_ref\" in g.nodes[\"n1\"].data:\n                # get total charge\n                g.update_all(\n                    dgl.function.copy_u(u=\"q_ref\", out=\"m_q\"),\n                    dgl.function.sum(msg=\"m_q\", out=\"sum_q\"),\n                    etype=\"n1_in_g\",\n                )\n            else:\n                g.nodes[\"g\"].data[\"sum_q\"] = (\n                    torch.ones(\n                        g.batch_size,\n                        1,\n                        device=g.nodes[\"n1\"].data[\"s\"].device,\n                    )\n                    * total_charge\n                )\n\n        g.update_all(\n            dgl.function.copy_u(u=\"sum_q\", out=\"m_sum_q\"),\n            dgl.function.sum(msg=\"m_sum_q\", out=\"sum_q\"),\n            etype=\"g_has_n1\",\n        )\n\n        # get the sum of $s^{-1}$ and $m_s^{-1}$\n        g.update_all(\n            dgl.function.copy_u(u=\"s_inv\", out=\"m_s_inv\"),\n            
dgl.function.sum(msg=\"m_s_inv\", out=\"sum_s_inv\"),\n            etype=\"n1_in_g\",\n        )\n\n        g.update_all(\n            dgl.function.copy_u(u=\"e_s_inv\", out=\"m_e_s_inv\"),\n            dgl.function.sum(msg=\"m_e_s_inv\", out=\"sum_e_s_inv\"),\n            etype=\"n1_in_g\",\n        )\n\n        g.update_all(\n            dgl.function.copy_u(u=\"sum_s_inv\", out=\"m_sum_s_inv\"),\n            dgl.function.sum(msg=\"m_sum_s_inv\", out=\"sum_s_inv\"),\n            etype=\"g_has_n1\",\n        )\n\n        g.update_all(\n            dgl.function.copy_u(u=\"sum_e_s_inv\", out=\"m_sum_e_s_inv\"),\n            dgl.function.sum(msg=\"m_sum_e_s_inv\", out=\"sum_e_s_inv\"),\n            etype=\"g_has_n1\",\n        )\n\n        g.apply_nodes(get_charges, ntype=\"n1\")\n\n        return g\n"
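\n\n# Illustrative sanity check (a sketch, not part of the API): with per-atom\n# electronegativity `e` and hardness `s`, the closed-form solution above\n# yields charges that sum to the total charge Q.\n#\n# >>> import torch\n# >>> e = torch.tensor([[0.3], [-0.1], [0.2]])\n# >>> s = torch.tensor([[1.0], [2.0], [4.0]])\n# >>> Q = 0.0\n# >>> q = -e / s + (1 / s) * (Q + (e / s).sum()) / (1 / s).sum()\n# >>> bool(torch.isclose(q.sum(), torch.tensor(Q)))\n# True\n"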
  },
  {
    "path": "espaloma/nn/readout/graph_level_readout.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\nimport espaloma as esp\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass GraphLevelReadout(torch.nn.Module):\n    \"\"\"Readout from graph level.\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        config_local,\n        config_global,\n        out_name,\n        pool=None,\n    ):\n\n        super(GraphLevelReadout, self).__init__()\n        import dgl\n\n        if pool is None:\n            pool = dgl.function.sum\n        self.in_features = in_features\n        self.config_local = config_local\n        self.config_global = config_global\n        self.d_local = esp.nn.sequential._Sequential(\n            in_features=in_features,\n            config=config_local,\n            layer=torch.nn.Linear,\n        )\n\n        mid_features = [x for x in config_local if isinstance(x, int)][-1]\n\n        self.d_global = esp.nn.sequential._Sequential(\n            in_features=mid_features,\n            config=config_global,\n            layer=torch.nn.Linear,\n        )\n\n        self.pool = pool\n        self.out_name = out_name\n\n    def forward(self, g):\n        import dgl\n\n        g.apply_nodes(\n            lambda node: {\"h_global\": self.d_local(None, node.data[\"h\"])},\n            ntype=\"n1\",\n        )\n\n        g.update_all(\n            dgl.function.copy_u(\"h_global\", \"m\"),\n            self.pool(\"m\", \"h_global\"),\n            etype=\"n1_in_g\",\n        )\n\n        g.apply_nodes(\n            lambda node: {\n                self.out_name: self.d_global(None, node.data[\"h_global\"])\n            },\n            ntype=\"g\",\n        )\n\n        return g\n"
  },
  {
    "path": "espaloma/nn/readout/janossy.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\nimport espaloma as esp\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass JanossyPooling(torch.nn.Module):\n    \"\"\"Janossy pooling (arXiv:1811.01900) to average node representation\n    for higher-order nodes.\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        in_features,\n        out_features={\n            1: [\"sigma\", \"epsilon\", \"q\"],\n            2: [\"k\", \"eq\"],\n            3: [\"k\", \"eq\"],\n            4: [\"k\", \"eq\"],\n        },\n        out_features_dimensions=-1,\n        pool=torch.add,\n    ):\n        super(JanossyPooling, self).__init__()\n\n        # if users specify out features as lists,\n        # assume dimensions to be all zero\n        for level in out_features.keys():\n            if isinstance(out_features[level], list):\n                out_features[level] = dict(\n                    zip(out_features[level], [1 for _ in out_features[level]])\n                )\n\n        # bookkeeping\n        self.out_features = out_features\n        self.levels = [key for key in out_features.keys() if key != 1]\n        self.pool = pool\n\n        # get output features\n        mid_features = [x for x in config if isinstance(x, int)][-1]\n\n        # set up networks\n        for level in self.levels:\n\n            # set up individual sequential networks\n            setattr(\n                self,\n                \"sequential_%s\" % level,\n                esp.nn.sequential._Sequential(\n                    in_features=in_features * level,\n                    config=config,\n                    layer=torch.nn.Linear,\n                ),\n            )\n\n            for feature, dimension in self.out_features[level].items():\n                setattr(\n                    self,\n                    \"f_out_%s_to_%s\" % (level, feature),\n                    torch.nn.Linear(\n                        mid_features,\n                        dimension,\n                    ),\n                )\n\n        if 1 not in self.out_features:\n            return\n\n        # atom level\n        self.sequential_1 = esp.nn.sequential._Sequential(\n            in_features=in_features, config=config, layer=torch.nn.Linear\n        )\n\n        for feature, dimension in self.out_features[1].items():\n            setattr(\n                self,\n                \"f_out_1_to_%s\" % feature,\n                torch.nn.Linear(\n                    mid_features,\n                    dimension,\n                ),\n            )\n\n    def forward(self, g):\n        \"\"\"Forward pass.\n\n        Parameters\n        ----------\n        g : dgl.DGLHeteroGraph,\n            input graph.\n        \"\"\"\n        import dgl\n\n        # copy\n        g.multi_update_all(\n            {\n                \"n1_as_%s_in_n%s\"\n                % (relationship_idx, big_idx): (\n                    dgl.function.copy_u(\"h\", \"m%s\" % relationship_idx),\n                    dgl.function.mean(\n                        \"m%s\" % relationship_idx, \"h%s\" % relationship_idx\n                    ),\n                )\n                for big_idx in self.levels\n                for relationship_idx in range(big_idx)\n            },\n 
           cross_reducer=\"sum\",\n        )\n\n        # pool\n        for big_idx in self.levels:\n\n            if g.number_of_nodes(\"n%s\" % big_idx) == 0:\n                continue\n\n            g.apply_nodes(\n                func=lambda nodes: {\n                    feature: getattr(\n                        self, \"f_out_%s_to_%s\" % (big_idx, feature)\n                    )(\n                        self.pool(\n                            getattr(self, \"sequential_%s\" % big_idx)(\n                                None,\n                                torch.cat(\n                                    [\n                                        nodes.data[\"h%s\" % relationship_idx]\n                                        for relationship_idx in range(big_idx)\n                                    ],\n                                    dim=1,\n                                ),\n                            ),\n                            getattr(self, \"sequential_%s\" % big_idx)(\n                                None,\n                                torch.cat(\n                                    [\n                                        nodes.data[\"h%s\" % relationship_idx]\n                                        for relationship_idx in range(\n                                            big_idx - 1, -1, -1\n                                        )\n                                    ],\n                                    dim=1,\n                                ),\n                            ),\n                        ),\n                    )\n                    for feature in self.out_features[big_idx].keys()\n                },\n                ntype=\"n%s\" % big_idx,\n            )\n\n        if 1 not in self.out_features:\n            return g\n\n        # atom level\n        g.apply_nodes(\n            func=lambda nodes: {\n                feature: getattr(self, \"f_out_1_to_%s\" % feature)(\n                    self.sequential_1(g=None, x=nodes.data[\"h\"])\n                )\n                for feature in self.out_features[1].keys()\n            },\n            ntype=\"n1\",\n        )\n\n        return g\n\n\nclass JanossyPoolingImproper(torch.nn.Module):\n    \"\"\"Janossy pooling (arXiv:1811.01900) to average node representation\n    for improper torsions.\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        in_features,\n        out_features={\n            \"k\": 2,\n        },\n        out_features_dimensions=-1,\n    ):\n        super(JanossyPoolingImproper, self).__init__()\n\n        # if users specify out features as lists,\n        # assume dimensions to be all zero\n\n        # bookkeeping\n        self.out_features = out_features\n        self.levels = [\"n4_improper\"]\n\n        # get output features\n        mid_features = [x for x in config if isinstance(x, int)][-1]\n\n        # set up networks\n        for level in self.levels:\n\n            # set up individual sequential networks\n            setattr(\n                self,\n                \"sequential_%s\" % level,\n                esp.nn.sequential._Sequential(\n                    in_features=4 * in_features,\n                    config=config,\n                    layer=torch.nn.Linear,\n                ),\n            )\n\n            for feature, dimension in self.out_features.items():\n                setattr(\n                    self,\n                    \"f_out_%s_to_%s\" % (level, feature),\n                    torch.nn.Linear(\n                        
mid_features,\n                        dimension,\n                    ),\n                )\n\n    def forward(self, g):\n        \"\"\"Forward pass.\n\n        Parameters\n        ----------\n        g : dgl.DGLHeteroGraph,\n            input graph.\n        \"\"\"\n        import dgl\n\n        # copy\n        g.multi_update_all(\n            {\n                \"n1_as_%s_in_%s\"\n                % (relationship_idx, big_idx): (\n                    dgl.function.copy_u(\"h\", \"m%s\" % relationship_idx),\n                    dgl.function.mean(\n                        \"m%s\" % relationship_idx, \"h%s\" % relationship_idx\n                    ),\n                )\n                for big_idx in self.levels\n                for relationship_idx in range(4)\n            },\n            cross_reducer=\"sum\",\n        )\n\n        if g.number_of_nodes(\"n4_improper\") == 0:\n            return g\n\n        # pool\n        #   sum over three cyclic permutations of \"h0\", \"h2\", \"h3\", assuming \"h1\" is the central atom in the improper\n        #   following the smirnoff trefoil convention [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]\n        #   https://github.com/openff.toolkit/openff.toolkit/blob/166c9864de3455244bd80b2c24656bd7dda3ae2d/openff.toolkit/typing/engines/smirnoff/parameters.py#L3326-L3360\n\n        ## Set different permutations based on which definition of impropers\n        ##  are being used\n        permuts = [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]\n        stack_permuts = lambda nodes, p: torch.cat(\n            [nodes.data[f\"h{i}\"] for i in p], dim=1\n        )\n\n        for big_idx in self.levels:\n            inner_net = getattr(self, f\"sequential_{big_idx}\")\n\n            g.apply_nodes(\n                func=lambda nodes: {\n                    feature: getattr(self, f\"f_out_{big_idx}_to_{feature}\")(\n                        torch.sum(\n                            torch.stack(\n                                [\n                                    inner_net(\n                                        g=None, x=stack_permuts(nodes, p)\n                                    )\n                                    for p in permuts\n                                ],\n                                dim=0,\n                            ),\n                            dim=0,\n                        )\n                    )\n                    for feature in self.out_features.keys()\n                },\n                ntype=big_idx,\n            )\n\n        return g\n\n\nclass JanossyPoolingWithSmirnoffImproper(torch.nn.Module):\n    \"\"\"Janossy pooling (arXiv:1811.01900) to average node representation\n    for improper torsions.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        in_features,\n        out_features={\n            \"k\": 2,\n        },\n        out_features_dimensions=-1,\n    ):\n        super(JanossyPoolingWithSmirnoffImproper, self).__init__()\n\n        # if users specify out features as lists,\n        # assume dimensions to be all zero\n\n        # bookkeeping\n        self.out_features = out_features\n        self.levels = [\"n4_improper\"]\n\n        # get output features\n        mid_features = [x for x in config if isinstance(x, int)][-1]\n\n        # set up networks\n        for level in self.levels:\n\n            # set up individual sequential networks\n            setattr(\n                self,\n                \"sequential_%s\" % level,\n                esp.nn.sequential._Sequential(\n                    
in_features=4 * in_features,\n                    config=config,\n                    layer=torch.nn.Linear,\n                ),\n            )\n\n            for feature, dimension in self.out_features.items():\n                setattr(\n                    self,\n                    \"f_out_%s_to_%s\" % (level, feature),\n                    torch.nn.Linear(\n                        mid_features,\n                        dimension,\n                    ),\n                )\n\n    def forward(self, g):\n        \"\"\"Forward pass.\n\n        Parameters\n        ----------\n        g : dgl.DGLHeteroGraph,\n            input graph.\n        \"\"\"\n        import dgl\n\n        # copy\n        g.multi_update_all(\n            {\n                \"n1_as_%s_in_%s\"\n                % (relationship_idx, big_idx): (\n                    dgl.function.copy_u(\"h\", \"m%s\" % relationship_idx),\n                    dgl.function.mean(\n                        \"m%s\" % relationship_idx, \"h%s\" % relationship_idx\n                    ),\n                )\n                for big_idx in self.levels\n                for relationship_idx in range(4)\n            },\n            cross_reducer=\"sum\",\n        )\n\n        if g.number_of_nodes(\"n4_improper\") == 0:\n            return g\n\n        # pool\n        #   sum over three cyclic permutations of \"h0\", \"h2\", \"h3\", assuming \"h1\" is the central atom in the improper\n        #   following the smirnoff trefoil convention [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]\n        #   https://github.com/openff.toolkit/openff.toolkit/blob/166c9864de3455244bd80b2c24656bd7dda3ae2d/openff.toolkit/typing/engines/smirnoff/parameters.py#L3326-L3360\n\n        ## Set different permutations based on which definition of impropers\n        ##  are being used\n        permuts = [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]\n        stack_permuts = lambda nodes, p: torch.cat(\n            [nodes.data[f\"h{i}\"] for i in p], dim=1\n        )\n\n        for big_idx in self.levels:\n            inner_net = getattr(self, f\"sequential_{big_idx}\")\n\n            g.apply_nodes(\n                func=lambda nodes: {\n                    feature: getattr(self, f\"f_out_{big_idx}_to_{feature}\")(\n                        torch.sum(\n                            torch.stack(\n                                [\n                                    inner_net(\n                                        g=None, x=stack_permuts(nodes, p)\n                                    )\n                                    for p in permuts\n                                ],\n                                dim=0,\n                            ),\n                            dim=0,\n                        )\n                    )\n                    for feature in self.out_features.keys()\n                },\n                ntype=big_idx,\n            )\n\n        return g\n\n\nclass JanossyPoolingNonbonded(torch.nn.Module):\n    \"\"\"Janossy pooling (arXiv:1811.01900) to average node representation\n    for nonbonded interactions.\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        in_features,\n        out_features={\"sigma\": 1, \"epsilon\": 1},\n        out_features_dimensions=-1,\n    ):\n        super(JanossyPoolingNonbonded, self).__init__()\n\n        # if users specify out features as lists,\n        # assume dimensions to be all zero\n\n        # bookkeeping\n        self.out_features = out_features\n        self.levels = [\"onefour\", 
\"nonbonded\"]\n\n        # get output features\n        mid_features = [x for x in config if isinstance(x, int)][-1]\n\n        # set up networks\n        for level in self.levels:\n\n            # set up individual sequential networks\n            setattr(\n                self,\n                \"sequential_%s\" % level,\n                esp.nn.sequential._Sequential(\n                    in_features=2 * in_features,\n                    config=config,\n                    layer=torch.nn.Linear,\n                ),\n            )\n\n            for feature, dimension in self.out_features.items():\n                setattr(\n                    self,\n                    \"f_out_%s_to_%s\" % (level, feature),\n                    torch.nn.Linear(\n                        mid_features,\n                        dimension,\n                    ),\n                )\n\n    def forward(self, g):\n        \"\"\"Forward pass.\n\n        Parameters\n        ----------\n        g : dgl.DGLHeteroGraph,\n            input graph.\n        \"\"\"\n\n        # copy\n        g.multi_update_all(\n            {\n                \"n1_as_%s_in_%s\"\n                % (relationship_idx, big_idx): (\n                    dgl.function.copy_u(\"h\", \"m%s\" % relationship_idx),\n                    dgl.function.mean(\n                        \"m%s\" % relationship_idx, \"h%s\" % relationship_idx\n                    ),\n                )\n                for big_idx in self.levels\n                for relationship_idx in range(2)\n            },\n            cross_reducer=\"sum\",\n        )\n\n        for big_idx in self.levels:\n\n            g.apply_nodes(\n                func=lambda nodes: {\n                    feature: getattr(\n                        self, \"f_out_%s_to_%s\" % (big_idx, feature)\n                    )(\n                        torch.sum(\n                            torch.stack(\n                                [\n                                    getattr(self, \"sequential_%s\" % big_idx)(\n                                        g=None,\n                                        x=torch.cat(\n                                            [\n                                                nodes.data[\"h0\"],\n                                                nodes.data[\"h1\"],\n                                            ],\n                                            dim=1,\n                                        ),\n                                    ),\n                                    getattr(self, \"sequential_%s\" % big_idx)(\n                                        g=None,\n                                        x=torch.cat(\n                                            [\n                                                nodes.data[\"h1\"],\n                                                nodes.data[\"h0\"],\n                                            ],\n                                            dim=1,\n                                        ),\n                                    ),\n                                ],\n                                dim=0,\n                            ),\n                            dim=0,\n                        )\n                    )\n                    for feature in self.out_features.keys()\n                },\n                ntype=big_idx,\n            )\n\n        return g\n\n\nclass ExpCoefficients(torch.nn.Module):\n    def forward(self, g):\n        import math\n\n        g.nodes[\"n2\"].data[\"coefficients\"] = (\n         
   g.nodes[\"n2\"].data[\"log_coefficients\"].exp()\n        )\n        g.nodes[\"n3\"].data[\"coefficients\"] = (\n            g.nodes[\"n3\"].data[\"log_coefficients\"].exp()\n        )\n        return g\n\n\nclass LinearMixtureToOriginal(torch.nn.Module):\n    def forward(self, g):\n        import math\n\n        (\n            g.nodes[\"n2\"].data[\"k\"],\n            g.nodes[\"n2\"].data[\"eq\"],\n        ) = esp.mm.functional.linear_mixture_to_original(\n            g.nodes[\"n2\"].data[\"coefficients\"][:, 0][:, None],\n            g.nodes[\"n2\"].data[\"coefficients\"][:, 1][:, None],\n            1.5,\n            6.0,\n        )\n\n        (\n            g.nodes[\"n3\"].data[\"k\"],\n            g.nodes[\"n3\"].data[\"eq\"],\n        ) = esp.mm.functional.linear_mixture_to_original(\n            g.nodes[\"n3\"].data[\"coefficients\"][:, 0][:, None],\n            g.nodes[\"n3\"].data[\"coefficients\"][:, 1][:, None],\n            0.0,\n            math.pi,\n        )\n\n        g.nodes[\"n3\"].data.pop(\"coefficients\")\n        g.nodes[\"n2\"].data.pop(\"coefficients\")\n        return g\n"
  },
  {
    "path": "espaloma/nn/readout/node_typing.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nimport torch\n\nfrom espaloma.nn.readout.base_readout import BaseReadout\n\n\n# =============================================================================\n# MODULE CLASSES\n# =============================================================================\nclass NodeTyping(BaseReadout):\n    \"\"\"Simple typing on homograph.\"\"\"\n\n    def __init__(self, in_features, n_classes=100):\n        super(NodeTyping, self).__init__()\n        self.c = torch.nn.Linear(in_features, n_classes)\n\n    def forward(self, g):\n        g.apply_nodes(\n            ntype=\"n1\",\n            func=lambda node: {\"nn_typing\": self.c(node.data[\"h\"])},\n        )\n        return g\n"
  },
  {
    "path": "espaloma/nn/sequential.py",
    "content": "\"\"\" Chain mutiple layers of GN together.\n\"\"\"\nimport torch\n\n\nclass _Sequential(torch.nn.Module):\n    \"\"\"Sequentially staggered neural networks.\"\"\"\n\n    def __init__(\n        self,\n        layer,\n        config,\n        in_features,\n        model_kwargs={},\n    ):\n        super(_Sequential, self).__init__()\n\n        self.exes = []\n\n        # init dim\n        dim = in_features\n\n        # parse the config\n        for idx, exe in enumerate(config):\n\n            try:\n                exe = float(exe)\n\n                if exe >= 1:\n                    exe = int(exe)\n            except BaseException:\n                pass\n\n            # int -> feedfoward\n            if isinstance(exe, int):\n                setattr(self, \"d\" + str(idx), layer(dim, exe, **model_kwargs))\n\n                dim = exe\n                self.exes.append(\"d\" + str(idx))\n\n            # str -> activation\n            elif isinstance(exe, str):\n                if exe == \"bn\":\n                    setattr(self, \"a\" + str(idx), torch.nn.BatchNorm1d(dim))\n\n                else:\n                    activation = getattr(torch.nn.functional, exe)\n                    setattr(self, \"a\" + str(idx), activation)\n\n                self.exes.append(\"a\" + str(idx))\n\n            # float -> dropout\n            elif isinstance(exe, float):\n                dropout = torch.nn.Dropout(exe)\n                setattr(self, \"o\" + str(idx), dropout)\n\n                self.exes.append(\"o\" + str(idx))\n\n    def forward(self, g, x):\n        for exe in self.exes:\n            if exe.startswith(\"d\"):\n                if g is not None:\n                    x = getattr(self, exe)(g, x)\n                else:\n                    x = getattr(self, exe)(x)\n            else:\n                x = getattr(self, exe)(x)\n\n        return x\n\n\nclass Sequential(torch.nn.Module):\n    \"\"\"Sequential neural network with input layers.\n\n    Parameters\n    ----------\n    layer : torch.nn.Module\n        DGL graph convolution layers.\n\n    config : List\n        A sequence of numbers (for units) and strings (for activation functions)\n        denoting the configuration of the sequential model.\n\n    feature_units : int(default=114)\n        The number of input channels.\n\n    Methods\n    -------\n    forward(g, x)\n        Forward pass.\n    \"\"\"\n\n    def __init__(\n        self,\n        layer,\n        config,\n        feature_units=114,\n        input_units=128,\n        model_kwargs={},\n    ):\n        super(Sequential, self).__init__()\n\n        # initial featurization\n        self.f_in = torch.nn.Sequential(\n            torch.nn.Linear(feature_units, input_units), torch.nn.Tanh()\n        )\n\n        self._sequential = _Sequential(\n            layer, config, in_features=input_units, model_kwargs=model_kwargs\n        )\n\n    def _forward(self, g, x):\n        \"\"\"Forward pass with graph and features.\"\"\"\n        for exe in self.exes:\n            if exe.startswith(\"d\"):\n                x = getattr(self, exe)(g, x)\n            else:\n                x = getattr(self, exe)(x)\n\n        return x\n\n    def forward(self, g, x=None):\n        \"\"\"Forward pass.\n\n        Parameters\n        ----------\n        g : `dgl.DGLHeteroGraph`,\n            input graph\n\n        Returns\n        -------\n        g : `dgl.DGLHeteroGraph`\n            output graph\n        \"\"\"\n        import dgl\n\n        # get homogeneous subgraph\n        g_ = 
dgl.to_homogeneous(g.edge_type_subgraph([\"n1_neighbors_n1\"]))\n\n        if x is None:\n            # get node attributes\n            x = g.nodes[\"n1\"].data[\"h0\"]\n            x = self.f_in(x)\n\n        # message passing on homo graph\n        x = self._sequential(g_, x)\n\n        # put attribute back in the graph\n        g.nodes[\"n1\"].data[\"h\"] = x\n\n        return g\n"
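\n\n# Illustrative usage (a sketch, not exercised by the tests here): `config`\n# mixes ints (hidden widths), strings (activations, or \"bn\" for batch norm),\n# and floats in (0, 1) (dropout rates), as parsed by `_Sequential` above.\n#\n# >>> import espaloma as esp\n# >>> layer = esp.nn.layers.dgl_legacy.gn(\"GraphConv\")\n# >>> net = Sequential(layer, config=[128, \"relu\", 0.1, 128, \"tanh\"])\n"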
  },
  {
    "path": "espaloma/nn/tests/test_baseline.py",
    "content": "import pytest\n\n\n@pytest.fixture\ndef baseline():\n    import espaloma as esp\n\n    g = esp.Graph(\"c1ccccc1\")\n\n    # get force field\n    forcefield = esp.graphs.legacy_force_field.LegacyForceField(\n        \"smirnoff99Frosst-1.1.0\"\n    )\n\n    # param / typing\n    operation = forcefield.parametrize\n\n    operation(g)\n\n    baseline = esp.nn.baselines.FreeParameterBaseline(g_ref=g.heterograph)\n\n    return baseline\n\n\ndef test_init(baseline):\n    baseline\n\n\ndef test_parameter(baseline):\n    print(list(baseline.parameters()))\n\n    assert len(list(baseline.parameters())) > 0\n"
  },
  {
    "path": "espaloma/nn/tests/test_janossy.py",
    "content": "import pytest\n\n\ndef test_small_net():\n    import torch\n\n    import espaloma as esp\n\n    # define a layer\n    layer = esp.nn.layers.dgl_legacy.gn(\"GraphConv\")\n\n    # define a representation\n    representation = esp.nn.Sequential(\n        layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"]\n    )\n\n    # define a readout\n    readout = esp.nn.readout.janossy.JanossyPooling(\n        config=[32, \"tanh\"], in_features=32\n    )\n\n    net = torch.nn.Sequential(representation, readout)\n\n    g = esp.Graph(\"c1ccccc1\")\n"
  },
  {
    "path": "espaloma/nn/tests/test_simple_net.py",
    "content": "import pytest\n\n\ndef test_small_net():\n    import torch\n\n    import espaloma as esp\n\n    layer = esp.nn.dgl_legacy.gn()\n    net = esp.nn.Sequential(layer, [32, \"tanh\", 32, \"tanh\", 32, \"tanh\"])\n"
  },
  {
    "path": "espaloma/units.py",
    "content": "# =============================================================================\n# IMPORTS\n# =============================================================================\nfrom openmm import unit\n\n# =============================================================================\n# CONSTANTS\n# =============================================================================\n\n# scaled units\nPARTICLE = unit.mole.create_unit(\n    6.02214076e23**-1,\n    \"particle\",\n    \"particle\",\n)\n\nHARTREE_PER_PARTICLE = unit.hartree / PARTICLE\n\n# basic units\nDISTANCE_UNIT = unit.bohr\nENERGY_UNIT = HARTREE_PER_PARTICLE\nFORCE_UNIT = ENERGY_UNIT / DISTANCE_UNIT\nANGLE_UNIT = unit.radian\nCHARGE_UNIT = unit.elementary_charge\n\n# compose units\nFORCE_CONSTANT_UNIT = ENERGY_UNIT / (DISTANCE_UNIT**2)\nANGLE_FORCE_CONSTANT_UNIT = ENERGY_UNIT / (ANGLE_UNIT**2)\nCOULOMB_CONSTANT_UNIT = (\n    ENERGY_UNIT * DISTANCE_UNIT / ((unit.elementary_charge**2))\n)\n\nGAS_CONSTANT = (\n    8.31446261815324 * unit.joule * (unit.kelvin**-1) * (unit.mole**-1)\n).value_in_unit(HARTREE_PER_PARTICLE / unit.kelvin)\n"
  },
  {
    "path": "espaloma/utils/geometry.py",
    "content": "import numpy as np\n\n\ndef _sample_unit_circle(n_samples: int = 1) -> np.ndarray:\n    \"\"\"\n    >>> np.isclose(np.linalg.norm(_sample_unit_circle(1)), 1)\n    True\n\n    \"\"\"\n    theta = np.random.rand(n_samples) * 2 * np.pi\n    x = np.cos(theta)\n    y = np.sin(theta)\n    xy = np.array([x, y]).T\n    assert xy.shape == (n_samples, 2)\n    return xy\n\n\ndef _sample_four_particle_torsion_scan(n_samples: int = 1) -> np.ndarray:\n    \"\"\"Generate n_samples random configurations of a 4-particle system abcd where\n    * distances ab, bc, cd are constant,\n    * angles abc, bcd are constant\n    * dihedral angle abcd is uniformly distributed in [0, 2pi]\n\n    Returns\n    -------\n    xyz : np.ndarray, shape = (n_samples, 4, 3)\n\n    Notes\n    -----\n    * Positions of a,b,c are constant, and x-coordinate of d is constant.\n        To be more exacting, could add random displacements and rotations.\n    \"\"\"\n    a = (-3, -1, 0)\n    b = (-2, 0, 0)\n    c = (-1, 0, 0)\n    d = (0, 1, 0)\n\n    # form one 3D configuration\n    conf = np.array([a, b, c, d])\n    assert conf.shape == (4, 3)\n\n    # make n_samples copies\n    xyz = np.array([conf] * n_samples, dtype=float)\n    assert xyz.shape == (n_samples, 4, 3)\n\n    # assign y and z coordinates of particle d to unit-circle samples\n    xyz[:, 3, 1:] = _sample_unit_circle(n_samples)\n\n    return xyz\n\n\ndef _timemachine_signed_torsion_angle(ci, cj, ck, cl):\n    \"\"\"Reference implementation from Yutong Zhao's timemachine\n\n    Copied directly from\n    https://github.com/proteneer/timemachine/blob/1a0ab45e605dc1e28c44ea90f38cb0dedce5c4db/timemachine/potentials/bonded.py#L152-L199\n    (but with 3 lines of dead code removed, and delta_r inlined)\n    \"\"\"\n\n    rij = cj - ci\n    rkj = cj - ck\n    rkl = cl - ck\n\n    n1 = np.cross(rij, rkj)\n    n2 = np.cross(rkj, rkl)\n\n    y = np.sum(\n        np.multiply(\n            np.cross(n1, n2),\n            rkj / np.linalg.norm(rkj, axis=-1, keepdims=True),\n        ),\n        axis=-1,\n    )\n    x = np.sum(np.multiply(n1, n2), -1)\n\n    return np.arctan2(y, x)\n"
  },
  {
    "path": "espaloma/utils/model_fetch.py",
    "content": "from pathlib import Path\nfrom typing import Any, Union\n\nimport requests\nimport torch.utils.model_zoo\nfrom tqdm import tqdm\n\n\ndef _get_model_url(version: str) -> str:\n    \"\"\"\n    Get the URL of the espaloma model from GitHub releases.\n\n    Parameters:\n        version (str): Version of the model. If set to \"latest\", the URL for the latest version will be returned.\n\n    Returns:\n        str: The URL of the espaloma model.\n\n    Note:\n        - If version is set to \"latest\", the URL for the latest version of the model will be returned.\n        - The URL is obtained from the GitHub releases of the espaloma repository.\n\n    Example:\n        >>> url = _get_model_url(version=\"0.3.0\")\n    \"\"\"\n\n    if version == \"latest\":\n        url = \"https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt\"\n    else:\n        # TODO: This scheme requires the version string of the model to match the\n        # release version\n        url = f\"https://github.com/choderalab/espaloma/releases/download/{version}/espaloma-{version}.pt\"\n\n    return url\n\n\ndef get_model_path(\n    model_dir: Union[str, Path] = \".espaloma/\",\n    version: str = \"latest\",\n    disable_progress_bar: bool = False,\n    overwrite: bool = False,\n) -> Path:\n    \"\"\"\n    Download a model for espaloma.\n\n    Parameters:\n        model_dir (str or Path): Directory path where the model will be saved. Default is ``.espaloma/``.\n        version (str): Version of the model to download. Default is \"latest\".\n        disable_progress_bar (bool): Whether to disable the progress bar during the download. Default is False.\n        overwrite (bool): Whether to overwrite the existing model file if it exists. Default is False.\n\n    Returns:\n        Path: The path to the downloaded model file.\n\n    Raises:\n        FileExistsError: If the model file already exists and overwrite is set to False.\n\n    Note:\n        - If version is set to \"latest\", the latest version of the model will be downloaded.\n        - The model will be downloaded from GitHub releases.\n        - The model file will be saved in the specified model directory.\n\n    Example:\n        >>> model_path = get_model(model_dir=\".espaloma/\", version=\"0.3.0\", disable_progress_bar=True)\n    \"\"\"\n\n    url = _get_model_url(version)\n\n    # This will work as long as we never have a \"/\" in the version string\n    file_name = Path(url.split(\"/\")[-1])\n    model_dir = Path(model_dir)\n    model_path = Path(model_dir / file_name)\n\n    if not overwrite and model_path.exists():\n        raise FileExistsError(\n            f\"File '{model_path}' exiits, use overwrite=True to overwrite file\"\n        )\n    model_dir.mkdir(parents=True, exist_ok=True)\n\n    request = requests.get(url, stream=True)\n    request_lenght = int(request.headers.get(\"content-length\", 0))\n    with open(model_path, \"wb\") as file, tqdm(\n        total=request_lenght,\n        unit=\"iB\",\n        unit_scale=True,\n        unit_divisor=1024,\n        disable=disable_progress_bar,\n    ) as progress:\n        for data in request.iter_content(chunk_size=1024):\n            size = file.write(data)\n            progress.update(size)\n\n    return model_path\n\n\ndef get_model(version: str = \"latest\") -> dict[str, Any]:\n    \"\"\"\n        Load an espaloma model from GitHub releases.\n\n    Parameters:\n        version (str): Version of the model to load. 
Default is \"latest\".\n\n    Returns:\n        dict[str, Any]: The loaded espaloma model.\n\n    Note:\n        - If version is set to \"latest\", the latest version of the model will be loaded.\n        - The model will be loaded from GitHub releases.\n        - The model will be loaded onto the CPU.\n\n    Example:\n        >>> model = get_model(version=\"0.3.0\")\n    \"\"\"\n\n    url = _get_model_url(version)\n    model = torch.utils.model_zoo.load_url(url, map_location=\"cpu\")\n    model.eval()  # type: ignore\n\n    return model\n"
  },
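A minimal usage sketch for the two helpers above; the version tag `"0.3.0"` and the `.espaloma/` cache directory are illustrative assumptions, not pinned by the module:

```python
# Sketch: fetch a released espaloma model and apply it to a molecule.
# Assumes espaloma and the OpenFF toolkit are installed; version "0.3.0"
# and the ".espaloma/" directory are illustrative choices.
import espaloma as esp
from openff.toolkit.topology import Molecule

# Download the checkpoint (a second call raises FileExistsError unless
# overwrite=True), then load the same release onto the CPU in eval mode.
model_path = esp.get_model_path(model_dir=".espaloma/", version="0.3.0")
model = esp.get_model(version="0.3.0")

molecule = Molecule.from_smiles("c1ccccc1")  # benzene
graph = esp.Graph(molecule)
model(graph.heterograph)  # forward pass over the molecular graph
```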
  {
    "path": "espaloma/utils/tests/test_model_fetch.py",
    "content": "import espaloma as esp\nimport torch\nfrom openff.toolkit.topology import Molecule\n\n\ndef test_get_model_path(tmp_path):\n    model_dir = tmp_path / \"latest\"\n    model_path = esp.get_model_path(model_dir=model_dir, disable_progress_bar=True)\n\n    molecule = Molecule.from_smiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    molecule_graph = esp.Graph(molecule)\n\n    espaloma_model = torch.load(model_path)\n    espaloma_model.eval()\n    espaloma_model(molecule_graph.heterograph)\n\n\ndef test_get_model(tmp_path):\n    espaloma_model = esp.get_model()\n\n    molecule = Molecule.from_smiles(\"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\")\n    molecule_graph = esp.Graph(molecule)\n    espaloma_model(molecule_graph.heterograph)\n"
  },
  {
    "path": "requirements.txt",
    "content": "dgl\ntorch\nmatplotlib\npandas\nnumpy\nqcportal\n"
  },
  {
    "path": "scripts/README.md",
    "content": "# Miscellaneous auxiliary scripts for demonstrating espaloma\n\n* `perses-benchmark/` - relative alchemical free energy calculations with [perses](http://github.com/choderalab/perses) using espaloma to parameterize small molecules via []`openmmforcefields`](https://github.com/openmm/openmmforcefields)\n"
  },
  {
    "path": "scripts/perses-benchmark/README.md",
    "content": "# Relative alchemical free energy calculations\n\nThis is an example of using [perses](http://github.com/choderalab/perses) using espaloma to parameterize small molecules via []`openmmforcefields`](https://github.com/openmm/openmmforcefields)\n\n* `tyk2/` - JACS tyk2 system\n\n## Installing perses and espaloma\n\nTo install perses and espaloma together:\n```bash\nconda env create -n espaloma-perses -f espaloma-perses.yaml\n```\nTo reproduce environment used in paper (on linux-64)\n```bash\nconda env create -n espaloma-perses -f espaloma-perses.export.yaml\n```\n"
  },
  {
    "path": "scripts/perses-benchmark/espaloma-perses.export.yaml",
    "content": "name: espaloma-perses\nchannels:\n  - dglteam\n  - psi4\n  - conda-forge\n  - openeye\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=conda_forge\n  - _openmp_mutex=4.5=1_gnu\n  - alabaster=0.7.12=py_0\n  - ambertools=21.9=py39h69e27f8_0\n  - argon2-cffi=21.3.0=pyhd8ed1ab_0\n  - argon2-cffi-bindings=21.2.0=py39h3811e60_1\n  - arpack=3.7.0=hdefa2d7_2\n  - arrow-cpp=2.0.0=py39h5894ca3_15_cpu\n  - arsenic=0.2.1=py39hf3d152e_0\n  - asttokens=2.0.5=pyhd8ed1ab_0\n  - astunparse=1.6.3=pyhd8ed1ab_0\n  - attrs=21.4.0=pyhd8ed1ab_0\n  - aws-c-common=0.4.59=h36c2ea0_1\n  - aws-c-event-stream=0.1.6=had2084c_6\n  - aws-checksums=0.1.10=h4e93380_0\n  - aws-sdk-cpp=1.8.70=h57dc084_1\n  - babel=2.9.1=pyh44b312d_0\n  - backcall=0.2.0=pyh9f0ad1d_0\n  - backports=1.1=pyhd3eb1b0_0\n  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0\n  - beautifulsoup4=4.10.0=pyha770c72_0\n  - blas=1.0=mkl\n  - bleach=5.0.0=pyhd8ed1ab_0\n  - blosc=1.21.0=h9c3ff4c_0\n  - bokeh=2.4.2=py39hf3d152e_0\n  - boost=1.74.0=py39h5472131_5\n  - boost-cpp=1.74.0=hc6e9bd1_3\n  - brotli=1.0.9=h166bdaf_7\n  - brotli-bin=1.0.9=h166bdaf_7\n  - brotlipy=0.7.0=py39hb9d737c_1004\n  - bzip2=1.0.8=h7f98852_4\n  - c-ares=1.18.1=h7f98852_0\n  - ca-certificates=2022.3.29=h06a4308_0\n  - cached-property=1.5.2=hd8ed1ab_1\n  - cached_property=1.5.2=pyha770c72_1\n  - cairo=1.16.0=h6cf1ce9_1008\n  - certifi=2021.10.8=py39hf3d152e_2\n  - cffi=1.15.0=py39h4bc2ebd_0\n  - cftime=1.6.0=py39hd257fcd_0\n  - charset-normalizer=2.0.12=pyhd8ed1ab_0\n  - click=8.1.2=py39hf3d152e_0\n  - cloudpickle=2.0.0=pyhd8ed1ab_0\n  - codecov=2.1.11=pyhd3deb0d_0\n  - colorama=0.4.4=pyh9f0ad1d_0\n  - coverage=6.3.2=py39hb9d737c_2\n  - cryptography=36.0.0=py39h9ce1e76_0\n  - cudatoolkit=10.2.89=h8f6ccaa_10\n  - curl=7.82.0=h2283fc2_0\n  - cycler=0.11.0=pyhd8ed1ab_0\n  - cython=0.29.28=py39h5a03fae_2\n  - cytoolz=0.11.2=py39hb9d737c_2\n  - dask=2022.4.0=pyhd8ed1ab_0\n  - dask-core=2022.4.0=pyhd8ed1ab_0\n  - dask-jobqueue=0.7.3=pyhd8ed1ab_0\n  - debugpy=1.5.1=py39he80948d_0\n  - decorator=5.1.1=pyhd8ed1ab_0\n  - defusedxml=0.7.1=pyhd8ed1ab_0\n  - dgl=0.8.0post2=py39_0\n  - dicttoxml=1.7.4=pyhd8ed1ab_2\n  - distributed=2022.4.0=pyhd8ed1ab_0\n  - docutils=0.17.1=py39hf3d152e_1\n  - entrypoints=0.4=pyhd8ed1ab_0\n  - executing=0.8.3=pyhd8ed1ab_0\n  - expat=2.4.8=h27087fc_0\n  - fftw=3.3.10=nompi_h77c792f_102\n  - fire=0.4.0=pyh44b312d_0\n  - flit-core=3.7.1=pyhd8ed1ab_0\n  - fontconfig=2.14.0=h8e229c2_0\n  - freetype=2.11.0=h70c0345_0\n  - fsspec=2022.3.0=pyhd8ed1ab_0\n  - future=0.18.2=py39hf3d152e_5\n  - gettext=0.21.0=hf68c758_0\n  - gflags=2.2.2=he1b5a44_1004\n  - giflib=5.2.1=h516909a_2\n  - glog=0.4.0=h49b9bf7_3\n  - greenlet=1.1.2=py39h5a03fae_2\n  - grpc-cpp=1.34.1=h2157cd5_4\n  - h5py=3.6.0=nompi_py39h7e08c79_100\n  - hdf4=4.2.15=h10796ff_3\n  - hdf5=1.12.1=nompi_h4df4325_104\n  - heapdict=1.0.1=py_0\n  - icu=68.2=h9c3ff4c_0\n  - idna=3.3=pyhd8ed1ab_0\n  - imagesize=1.3.0=pyhd8ed1ab_0\n  - importlib-metadata=4.11.3=py39hf3d152e_1\n  - importlib_resources=5.6.0=pyhd8ed1ab_0\n  - iniconfig=1.1.1=pyh9f0ad1d_0\n  - intel-openmp=2021.4.0=h06a4308_3561\n  - ipykernel=6.12.0=py39hef51801_0\n  - ipython=8.2.0=py39hf3d152e_0\n  - ipython_genutils=0.2.0=py_1\n  - ipywidgets=7.7.0=pyhd8ed1ab_0\n  - jedi=0.18.1=py39hf3d152e_1\n  - jinja2=3.1.1=pyhd8ed1ab_0\n  - joblib=1.1.0=pyhd8ed1ab_0\n  - jpeg=9e=h7f98852_0\n  - jsonschema=4.4.0=pyhd8ed1ab_0\n  - jupyter_client=7.2.2=pyhd8ed1ab_1\n  - jupyter_core=4.9.2=py39hf3d152e_0\n  - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0\n  - 
jupyterlab_widgets=1.1.0=pyhd8ed1ab_0\n  - keyutils=1.6.1=h166bdaf_0\n  - kiwisolver=1.4.2=py39hf939315_1\n  - krb5=1.19.3=h08a2579_0\n  - lcms2=2.12=hddcbb42_0\n  - ld_impl_linux-64=2.36.1=hea4e1c9_2\n  - libblas=3.9.0=12_linux64_mkl\n  - libbrotlicommon=1.0.9=h166bdaf_7\n  - libbrotlidec=1.0.9=h166bdaf_7\n  - libbrotlienc=1.0.9=h166bdaf_7\n  - libcblas=3.9.0=12_linux64_mkl\n  - libcurl=7.82.0=h2283fc2_0\n  - libedit=3.1.20210910=h7f8727e_0\n  - libev=4.33=h516909a_1\n  - libevent=2.1.10=h28343ad_4\n  - libffi=3.4.2=h7f98852_5\n  - libgcc-ng=11.2.0=h1d223b6_15\n  - libgfortran-ng=11.2.0=h69a702a_15\n  - libgfortran5=11.2.0=h5c6108e_15\n  - libglib=2.70.2=h174f98d_4\n  - libgomp=11.2.0=h1d223b6_15\n  - libiconv=1.16=h516909a_0\n  - liblapack=3.9.0=12_linux64_mkl\n  - libllvm10=10.0.1=he513fc3_3\n  - libnetcdf=4.8.1=nompi_hb3fd0d9_101\n  - libnghttp2=1.47.0=he49606f_0\n  - libnsl=2.0.0=h7f98852_0\n  - libpng=1.6.37=hed695b0_2\n  - libprotobuf=3.14.0=h780b84a_0\n  - libsodium=1.0.18=h516909a_1\n  - libssh2=1.10.0=ha35d2d1_2\n  - libstdcxx-ng=11.2.0=he4da1e4_15\n  - libthrift=0.13.0=hfb8234f_6\n  - libtiff=4.2.0=hbd63e13_2\n  - libutf8proc=2.7.0=h7f98852_0\n  - libuuid=2.32.1=h14c3975_1000\n  - libwebp=1.2.2=h55f646e_0\n  - libwebp-base=1.2.2=h7f98852_1\n  - libxcb=1.14=h7b6447c_0\n  - libxml2=2.9.12=h72842e0_0\n  - libxslt=1.1.33=h15afd5d_2\n  - libzip=1.8.0=h1c5bbd1_1\n  - libzlib=1.2.11=h166bdaf_1014\n  - llvmlite=0.36.0=py39h1bbdace_0\n  - locket=0.2.1=py39h06a4308_2\n  - lxml=4.8.0=py39hb9d737c_1\n  - lz4=4.0.0=py39h029007f_1\n  - lz4-c=1.9.3=h9c3ff4c_1\n  - lzo=2.10=h516909a_1000\n  - markupsafe=2.1.1=py39hb9d737c_1\n  - matplotlib=3.3.2=0\n  - matplotlib-base=3.3.2=py39h98787fa_1\n  - matplotlib-inline=0.1.3=pyhd8ed1ab_0\n  - mdtraj=1.9.7=py39h138c130_1\n  - mistune=0.8.4=py39h3811e60_1005\n  - mkl=2021.4.0=h06a4308_640\n  - mkl-service=2.4.0=py39h3811e60_0\n  - mpiplus=v0.0.1=py39hde42818_1002\n  - msgpack-python=1.0.3=py39hf939315_1\n  - nbclient=0.5.13=pyhd8ed1ab_0\n  - nbconvert=6.4.5=py39hf3d152e_0\n  - nbformat=5.3.0=pyhd8ed1ab_0\n  - ncurses=6.3=h27087fc_1\n  - nest-asyncio=1.5.5=pyhd8ed1ab_0\n  - netcdf-fortran=4.5.4=nompi_h2b6e579_100\n  - netcdf4=1.5.8=nompi_py39h64b754b_101\n  - networkx=2.7.1=pyhd8ed1ab_0\n  - nglview=3.0.3=pyh8a188c0_0\n  - ninja=1.10.2=h4bd325d_1\n  - nose=1.3.7=py_1006\n  - nose-timer=1.0.1=pyhd8ed1ab_0\n  - notebook=6.4.10=pyha770c72_0\n  - numba=0.53.1=py39h56b8d98_1\n  - numexpr=2.8.1=py39h6abb31d_0\n  - numpy=1.22.3=py39h18676bf_1\n  - numpydoc=1.2.1=pyhd8ed1ab_0\n  - ocl-icd=2.3.1=h7f98852_0\n  - ocl-icd-system=1.0.0=1\n  - openeye-toolkits=2021.2.0=py39_0\n  - openff-forcefields=2.0.0=pyh6c4a22f_0\n  - openff-toolkit=0.10.3=pyhd8ed1ab_0\n  - openff-toolkit-base=0.10.3=pyhd8ed1ab_0\n  - openmm=7.7.0=py39h9717219_1\n  - openmmtools=0.21.2=pyhd8ed1ab_0\n  - openmoltools=0.8.8=pyhd8ed1ab_1\n  - openssl=3.0.2=h166bdaf_1\n  - orc=1.6.6=h7950760_1\n  - packaging=21.3=pyhd8ed1ab_0\n  - packmol=20.010=h86c2bf4_0\n  - pandas=1.4.2=py39h1832856_0\n  - pandoc=2.17.1.1=ha770c72_0\n  - pandocfilters=1.5.0=pyhd8ed1ab_0\n  - parmed=3.4.3=py39he80948d_1\n  - parquet-cpp=1.5.1=1\n  - parso=0.8.3=pyhd8ed1ab_0\n  - partd=1.2.0=pyhd8ed1ab_0\n  - patsy=0.5.2=pyhd8ed1ab_0\n  - pcre=8.45=h9c3ff4c_0\n  - pdbfixer=1.8.1=pyh6c4a22f_0\n  - perl=5.32.1=2_h7f98852_perl5\n  - perses=0.9.5=pyh8a188c0_0\n  - pexpect=4.8.0=pyh9f0ad1d_2\n  - pickleshare=0.7.5=py39hde42818_1002\n  - pillow=9.0.1=py39h22f2fdc_0\n  - pint=0.19.1=pyhd8ed1ab_0\n  - pip=22.0.4=pyhd8ed1ab_0\n  - 
pixman=0.40.0=h36c2ea0_0\n  - plotly=5.7.0=pyhd8ed1ab_0\n  - pluggy=1.0.0=py39hf3d152e_3\n  - prometheus_client=0.14.0=pyhd8ed1ab_0\n  - prompt-toolkit=3.0.29=pyha770c72_0\n  - psutil=5.9.0=py39hb9d737c_1\n  - ptyprocess=0.7.0=pyhd3deb0d_0\n  - pure_eval=0.2.2=pyhd8ed1ab_0\n  - py=1.11.0=pyh6c4a22f_0\n  - pyarrow=2.0.0=py39h3ebc44c_15_cpu\n  - pycairo=1.21.0=py39h0934665_1\n  - pycparser=2.21=pyhd8ed1ab_0\n  - pydantic=1.9.0=py39hb9d737c_1\n  - pygments=2.11.2=pyhd8ed1ab_0\n  - pymbar=3.0.6=py39hd257fcd_0\n  - pyopenssl=22.0.0=pyhd8ed1ab_0\n  - pyparsing=3.0.7=pyhd8ed1ab_0\n  - pyrsistent=0.18.1=py39hb9d737c_1\n  - pysocks=1.7.1=py39hf3d152e_5\n  - pytables=3.7.0=py39h2669a42_0\n  - pytest=7.1.1=py39hf3d152e_1\n  - pytest-cov=3.0.0=pyhd8ed1ab_0\n  - python=3.9.12=h2660328_1_cpython\n  - python-dateutil=2.8.2=pyhd8ed1ab_0\n  - python-fastjsonschema=2.15.3=pyhd8ed1ab_0\n  - python_abi=3.9=2_cp39\n  - pytorch=1.10.2=cpu_py39hfa7516b_0\n  - pytz=2022.1=pyhd8ed1ab_0\n  - pyyaml=6.0=py39hb9d737c_4\n  - pyzmq=22.3.0=py39headdf64_2\n  - qcelemental=0.24.0=pyhd8ed1ab_0\n  - qcportal=0.15.8=pyhd8ed1ab_0\n  - rdkit=2022.03.1=py39h89e00b9_0\n  - re2=2020.11.01=h58526e2_0\n  - readline=8.1.2=h7f8727e_1\n  - reportlab=3.5.68=py39he59360d_1\n  - requests=2.27.1=pyhd8ed1ab_0\n  - scikit-learn=1.0.2=py39h4dfa638_0\n  - scipy=1.8.0=py39hee8e79c_1\n  - seaborn=0.11.2=hd8ed1ab_0\n  - seaborn-base=0.11.2=pyhd8ed1ab_0\n  - send2trash=1.8.0=pyhd8ed1ab_0\n  - setuptools=62.0.0=py39hf3d152e_0\n  - six=1.16.0=pyh6c4a22f_0\n  - smirnoff99frosst=1.1.0=pyh44b312d_0\n  - snappy=1.1.8=he1b5a44_3\n  - snowballstemmer=2.2.0=pyhd8ed1ab_0\n  - sortedcontainers=2.4.0=pyhd8ed1ab_0\n  - soupsieve=2.3.1=pyhd8ed1ab_0\n  - sphinx=4.5.0=pyh6c4a22f_0\n  - sphinx_rtd_theme=1.0.0=pyhd8ed1ab_0\n  - sphinxcontrib-applehelp=1.0.2=py_0\n  - sphinxcontrib-devhelp=1.0.2=py_0\n  - sphinxcontrib-htmlhelp=2.0.0=pyhd8ed1ab_0\n  - sphinxcontrib-jsmath=1.0.1=py_0\n  - sphinxcontrib-qthelp=1.0.3=py_0\n  - sphinxcontrib-serializinghtml=1.1.5=pyhd8ed1ab_1\n  - sqlalchemy=1.4.35=py39hb9d737c_0\n  - sqlite=3.38.2=hc218d9a_0\n  - stack_data=0.2.0=pyhd8ed1ab_0\n  - statsmodels=0.13.2=py39hce5d2b2_0\n  - tblib=1.7.0=pyhd8ed1ab_0\n  - tenacity=8.0.1=pyhd8ed1ab_0\n  - termcolor=1.1.0=py_2\n  - terminado=0.13.3=py39hf3d152e_1\n  - testpath=0.6.0=pyhd8ed1ab_0\n  - threadpoolctl=3.1.0=pyh8a188c0_0\n  - tinydb=4.7.0=pyhd8ed1ab_0\n  - tk=8.6.12=h27826a3_0\n  - toml=0.10.2=pyhd8ed1ab_0\n  - tomli=2.0.1=pyhd8ed1ab_0\n  - toolz=0.11.2=pyhd8ed1ab_0\n  - tornado=6.1=py39hb9d737c_3\n  - tqdm=4.64.0=pyhd8ed1ab_0\n  - traitlets=5.1.1=pyhd8ed1ab_0\n  - typing-extensions=4.1.1=hd8ed1ab_0\n  - typing_extensions=4.1.1=pyha770c72_0\n  - tzdata=2022a=h191b570_0\n  - urllib3=1.26.9=pyhd8ed1ab_0\n  - validators=0.18.2=pyhd3deb0d_0\n  - wcwidth=0.2.5=pyh9f0ad1d_2\n  - webencodings=0.5.1=py_1\n  - wheel=0.37.1=pyhd8ed1ab_0\n  - widgetsnbextension=3.6.0=py39hf3d152e_0\n  - xmltodict=0.12.0=py_0\n  - xorg-kbproto=1.0.7=h14c3975_1002\n  - xorg-libice=1.0.10=h516909a_0\n  - xorg-libsm=1.2.3=hd9c2040_1000\n  - xorg-libx11=1.7.2=h7f98852_0\n  - xorg-libxext=1.3.4=h7f98852_1\n  - xorg-libxrender=0.9.10=h7f98852_1003\n  - xorg-libxt=1.2.1=h7f98852_2\n  - xorg-renderproto=0.11.1=h14c3975_1002\n  - xorg-xextproto=7.3.0=h14c3975_1002\n  - xorg-xproto=7.0.31=h14c3975_1007\n  - xz=5.2.5=h516909a_1\n  - yaml=0.2.5=h7f98852_2\n  - zeromq=4.3.4=h9c3ff4c_1\n  - zict=2.1.0=pyhd8ed1ab_0\n  - zipp=3.8.0=pyhd8ed1ab_0\n  - zlib=1.2.11=h166bdaf_1014\n  - zstd=1.4.9=ha95c52a_0\n  - pip:\n    - 
amberlite==16.0\n    - amberutils==21.0\n    - espaloma==0.2.2\n    - mmpbsa-py==16.0\n    - openmmforcefields==0.10.0+27.g1fabf43\n    - packmol-memgen==1.2.1rc0\n    - pdb4amber==20.1\n    - pytraj==2.0.6\n    - sander==16.0\nprefix: /lila/home/chodera/miniconda/envs/espaloma-perses\n"
  },
  {
    "path": "scripts/perses-benchmark/espaloma-perses.yaml",
    "content": "name: espaloma-perses\nchannels:\n  - conda-forge\n  - dglteam\n  - openeye\n  - defaults\n  - anaconda\ndependencies:\n  # Base dependencies\n  - python\n  - pip\n  # 3rd party\n  - openeye-toolkits\n  - numpy\n  - matplotlib\n  - scipy\n  - openff-toolkit\n  - openff-forcefields\n  - smirnoff99Frosst\n  - openmm\n  - openmmforcefields\n  - tqdm\n  # Pytorch\n  - pytorch>=1.8.0\n  - dgl\n  # Testing\n  - pytest\n  - pytest-cov\n  - codecov\n  - nose\n  - nose-timer\n  - coverage\n  - qcportal>=0.15.0\n  - sphinx\n  - sphinx_rtd_theme\n  # perses\n  - perses\n  # will be added to openmmforcefields conda-forge recipe\n  - validators\n  - pip:\n    # espaloma\n    - git+https://github.com/choderalab/espaloma.git@0.2.2\n    # openmmforcefield\n    - git+https://github.com/openmm/openmmforcefields.git\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/README.md",
    "content": "# tyk2 benchmarks with perses and espaloma\n\n* `openff-1.2.0/` - scripts to use Open Force Field (\"Parsley\") `openff-1.2.0` small molecule force field\n* `espaloma-0.2.2/` - scripts to use Espaloma `espaloma-0.2.2` small molecule force field\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/espaloma-0.2.2/LSF-job-template.sh",
    "content": "#!/bin/bash\n#BSUB -P \"tyk2-benchmark\"\n#BSUB -J \"perses-benchmark-[1-24]\"\n#BSUB -n 1\n#BSUB -R rusage[mem=8]\n#BSUB -R span[hosts=1]\n#BSUB -q gpuqueue\n#BSUB -sp 1 # low priority. default is 12, max is 25\n#BSUB -gpu num=1:j_exclusive=yes:mode=shared\n#BSUB -W  24:00\n#BSUB -o out_%J_%I.stdout\n#BSUB -eo out_%J_%I.stderr\n#BSUB -L /bin/bash\n\nsource ~/.bashrc\nOPENMM_CPU_THREADS=1\n\necho \"changing directory to ${LS_SUBCWD}\"\ncd $LS_SUBCWD\nconda activate espaloma-perses\n\n# Report node in use\nhostname\n\n# Report CUDA info\nenv | sort | grep 'CUDA'\n\n# launching a benchmark pair (target, edge) per job (0-based thus substract 1)\npython run_benchmarks.py --target tyk2 --edge $(( $LSB_JOBINDEX - 1 ))\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/espaloma-0.2.2/README.md",
    "content": "# Perses benchmarks\n\nThis subdirectory exposes a CLI tool for running automated benchmarks from\n[OpenFF's protein ligand benchmark dataset](https://github.com/openforcefield/protein-ligand-benchmark) using perses.\n\n## Running all edges\n\nA script to run all transformations in an LSF batch scheduler is provided, but will likely need to be modified for your batch queue system:\n```bash\nbsub < LSF-job-template.sh\n```\n\n## Running single edges\n\nAssuming you have a clone of the perses code repository and you are standing in the `benchmarks` subdirectory\n(where this file lives). Then the benchmarks can be run using the following command syntax:\n```bash\npython run_benchmarks.py --target [protein-name] --edge [edge-index]\n```\n\nFor example, for running the seventh edge (zero-based, according to [plbenchmark data](https://github.com/openforcefield/protein-ligand-benchmark) )\nfor `tyk2` protein, you would run:\n```bash\n# Set up and run edge 6\npython run_benchmarks.py --target tyk2 --edge 6\n```\nShould the calculation for an edge fail, you can simply re-run the same command-line and the calculation will resume:\n```bash\n# Resume failed edge 6\npython run_benchmarks.py --target tyk2 --edge 6\n```\nFor more information on how to use the tool, you can run `python run_benchmarks.py -h`.\n\n## Analyzing benchmarks\n\nTo analyze the simulations a script called `benchmark_analysis.py` is used as follows:\n```bash\npython benchmark_analysis.py --target [protein-name]\n```\n\nFor example, for tyk2 results:\n```bash\npython benchmark_analysis.py --target tyk2\n```\nThis will generate an output CSV file for [`arsenic`](https://github.com/openforcefield/arsenic) and corresponding absolute and relative free energy plots as PNG files produced according to best practices.)\n\nFor more information on how to use the cli analysis tool use `python benchmark_analysis.py -h`.\n"
  },
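The `--edge` index is the position of an edge in the dataset's `edges.yml`. A small sketch, mirroring the lookups hard-coded in `run_benchmarks.py`, to print the available edges and their indices; the `tyk2` key comes from the upstream `targets.yml`:

```python
# Sketch: enumerate plbenchmark edges for a target so you can choose an
# --edge index. The URL layout mirrors run_benchmarks.py.
import urllib.request
import yaml

base_repo_url = "https://github.com/openforcefield/protein-ligand-benchmark"

with urllib.request.urlopen(f"{base_repo_url}/raw/master/data/targets.yml") as response:
    targets_dict = yaml.safe_load(response.read())

target_dir = targets_dict["tyk2"]["dir"]
with urllib.request.urlopen(f"{base_repo_url}/raw/master/data/{target_dir}/00_data/edges.yml") as response:
    edges_dict = yaml.safe_load(response.read())

# dicts preserve insertion order (py>=3.7), matching the script's indexing
for index, edge in enumerate(edges_dict.values()):
    print(index, edge["ligand_a"], "->", edge["ligand_b"])
```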
  {
    "path": "scripts/perses-benchmark/tyk2/espaloma-0.2.2/benchmark_analysis.py",
    "content": "\"\"\"\nScript to perform analysis of perses simulations executed using run_benchmarks.py script.\n\nIntended to be used on systems from https://github.com/openforcefield/protein-ligand-benchmark\n\"\"\"\n\nimport argparse\nimport glob\nimport itertools\nimport re\nimport warnings\n\nimport numpy as np\nimport urllib.request\nimport yaml\n\nfrom openmmtools.constants import kB\nfrom perses.analysis.load_simulations import Simulation\n\nfrom openmm import unit\n\nfrom openff.arsenic import plotting, wrangle\n\n# global variables\nbase_repo_url = \"https://github.com/openforcefield/protein-ligand-benchmark\"\n\n\n# Helper functions\n\ndef get_simdir_list(base_dir='.', is_reversed=False):\n    \"\"\"\n    Get list of directories to extract simulation data.\n\n    Attributes\n    ----------\n    base_dir: str, optional, default='.'\n        Base directory where to search for simulations results. Defaults to current directory.\n    is_reversed: bool, optional, default=False\n        Whether to consider the reversed simulations or not. Meant for testing purposes.\n\n    Returns\n    -------\n    dir_list: list\n        List of directories paths for simulation results.\n    \"\"\"\n    # Load all expected simulation from directories\n    out_dirs = ['/'.join(filepath.split('/')[:-1]) for filepath in glob.glob(f'{base_dir}/out*/*complex.nc')]\n    reg = re.compile(r'out_[0-9]+_[0-9]+_reversed')  # regular expression to deal with reversed directories\n    if is_reversed:\n        # Choose only reversed directories\n        out_dirs = list(filter(reg.search, out_dirs))\n    else:\n        # Filter out reversed directories\n        out_dirs = list(itertools.filterfalse(reg.search, out_dirs))\n    return out_dirs\n\n\ndef get_simulations_data(simulation_dirs):\n    \"\"\"Generates a list of simulation data objects given the simulation directories paths.\"\"\"\n    simulations = []\n    for out_dir in simulation_dirs:\n        # Load complete or fully working simulations\n        # TODO: Try getting better exceptions from openmmtools -- use non-generic exceptions\n        try:\n            simulation = Simulation(out_dir)\n            simulations.append(simulation)\n        except Exception:\n            warnings.warn(f\"Edge in {out_dir} could not be loaded. Check simulation output is complete.\")\n    return simulations\n\n\ndef to_arsenic_csv(experimental_data: dict, simulation_data: list, out_csv: str = 'out_benchmark.csv'):\n    \"\"\"\n    Generates a csv file to be used with openff-arsenic. Energy units in kcal/mol.\n\n    .. 
warning:: To be deprecated once arsenic object model is improved.\n\n    Parameters\n    ----------\n        experimental_data: dict\n            Python nested dictionary with experimental data in micromolar or nanomolar units.\n            Example of entry:\n\n                {'lig_ejm_31': {'measurement': {'comment': 'Table 4, entry 31',\n                  'doi': '10.1016/j.ejmech.2013.03.070',\n                  'error': -1,\n                  'type': 'ki',\n                  'unit': 'uM',\n                  'value': 0.096},\n                  'name': 'lig_ejm_31',\n                  'smiles': '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]'}\n\n        simulation_data: list or iterable\n            Python iterable object with perses Simulation objects as entries.\n        out_csv: str\n            Path to output csv file to be generated.\n    \"\"\"\n    # Ligand information\n    ligands_names = list(ligands_dict.keys())\n    lig_id_to_name = dict(enumerate(ligands_names))\n    kBT = kB * 300 * unit.kelvin  # useful when converting to kcal/mol\n    # Write csv file\n    with open(out_csv, 'w') as csv_file:\n        # Experimental block\n        # print header for block\n        csv_file.write(\"# Experimental block\\n\")\n        csv_file.write(\"# Ligand, expt_DG, expt_dDG\\n\")\n        # Extract ligand name, expt_DG and expt_dDG from ligands dictionary\n        for ligand_name, ligand_data in experimental_data.items():\n            # TODO: Handle multiple measurement types\n            unit_symbol = ligand_data['measurement']['unit']\n            measurement_value = ligand_data['measurement']['value']\n            measurement_error = ligand_data['measurement']['error']\n            # Unit conversion\n            # TODO: Let's persuade PLBenchmarks to use pint units\n            unit_conversions = { 'M' : 1.0, 'mM' : 1e-3, 'uM' : 1e-6, 'nM' : 1e-9, 'pM' : 1e-12, 'fM' : 1e-15 }\n            if unit_symbol not in unit_conversions:\n                raise ValueError(f'Unknown units \"{unit_symbol}\"')\n            value_to_molar= unit_conversions[unit_symbol]\n            # Handle unknown errors\n            # TODO: We should be able to ensure that all entries have more reasonable errors.\n            if measurement_error == -1:\n                # TODO: For now, we use a relative_error from the Tyk2 system 10.1016/j.ejmech.2013.03.070\n                relative_error = 0.3\n            else:\n                relative_error = measurement_error / measurement_value\n            # Convert to free eneriges\n            expt_DG = kBT.value_in_unit(unit.kilocalorie_per_mole) * np.log(measurement_value * value_to_molar)\n            expt_dDG = kBT.value_in_unit(unit.kilocalorie_per_mole) * relative_error\n            csv_file.write(f\"{ligand_name}, {expt_DG}, {expt_dDG}\\n\")\n\n        # Calculated block\n        # print header for block\n        csv_file.write(\"# Calculated block\\n\")\n        csv_file.write(\"# Ligand1,Ligand2, calc_DDG, calc_dDDG(MBAR), calc_dDDG(additional)\\n\")\n        # Loop through simulation, extract ligand1 and ligand2 indices, convert to names, create string with\n        # ligand1, ligand2, calc_DDG, calc_dDDG(MBAR), calc_dDDG(additional)\n        # write string in csv file\n        for simulation in simulation_data:\n            out_dir = simulation.directory.split('/')[-1]\n            # getting integer indices\n            ligand1_id, ligand2_id = int(out_dir.split('_')[-1]), int(out_dir.split('_')[-2])  # CHECK 
ORDER!\n            # getting names of ligands\n            ligand1, ligand2 = lig_id_to_name[ligand1_id], lig_id_to_name[ligand2_id]\n            # getting calc_DDG in kcal/mol\n            calc_DDG = simulation.bindingdg.value_in_unit(unit.kilocalorie_per_mole)\n            # getting calc_dDDG in kcal/mol\n            calc_dDDG = simulation.bindingddg.value_in_unit(unit.kilocalorie_per_mole)\n            csv_file.write(\n                f\"{ligand1}, {ligand2}, {calc_DDG}, {calc_dDDG}, 0.0\\n\")  # hardcoding additional error as 0.0\n\n\n# Defining command line arguments\n# fetching targets from github repo\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntargets_url = f\"{base_repo_url}/raw/master/data/targets.yml\"\nwith urllib.request.urlopen(targets_url) as response:\n    targets_dict = yaml.safe_load(response.read())\n# get the possible choices from targets yaml file\ntarget_choices = targets_dict.keys()\n\narg_parser = argparse.ArgumentParser(description='CLI tool for running perses protein-ligand benchmarks analysis.')\narg_parser.add_argument(\n    \"--target\",\n    type=str,\n    help=\"Target biomolecule, use openff's plbenchmark names.\",\n    choices=target_choices,\n    required=True\n)\narg_parser.add_argument(\n    \"--reversed\",\n    action='store_true',\n    help=\"Analyze reversed edge simulations. Helpful for testing/consistency checks.\"\n)\nargs = arg_parser.parse_args()\ntarget = args.target\n\n# Download experimental data\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\n# TODO: Let's cache this data when we set up the initial simulations in case it changes in between setting up and running the calculations and analysis.\n# TODO: Let's also be sure to use a specific release tag rather than 'master'\ntarget_dir = targets_dict[target]['dir']\nligands_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/ligands.yml\"\nwith urllib.request.urlopen(ligands_url) as response:\n    yaml_contents = response.read()\n    print(yaml_contents)\n    ligands_dict = yaml.safe_load(yaml_contents)\n\n# DEBUG\nprint('')\nprint(yaml.dump(ligands_dict))\n\n# Get paths for simulation output directories\nout_dirs = get_simdir_list(is_reversed=args.reversed)\n\n# Generate list with simulation objects\nsimulations = get_simulations_data(out_dirs)\n\n# Generate csv file\ncsv_path = f'./{target}_arsenic.csv'\nto_arsenic_csv(ligands_dict, simulations, out_csv=csv_path)\n\n\n# TODO: Separate plotting in a different file\n# Make plots and store\nfe = wrangle.FEMap(csv_path)\n# Relative plot\nplotting.plot_DDGs(fe.graph,\n                   target_name=f'{target}',\n                   title=f'Relative binding energies - {target}',\n                   figsize=5,\n                   filename='./plot_relative.pdf'\n                   )\n# Absolute plot, with experimental data shifted to correct mean\nexperimental_mean_dg = np.asarray([node[1][\"exp_DG\"] for node in fe.graph.nodes(data=True)]).mean()\nplotting.plot_DGs(fe.graph,\n                  target_name=f'{target}',\n                  title=f'Absolute binding energies - {target}',\n                  figsize=5,\n                  filename='./plot_absolute.pdf',\n                  shift=experimental_mean_dg,\n                  )\n\n\n"
  },
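The experimental block above converts an affinity into a free energy as expt_DG = kBT * ln(value in molar), with a hard-coded 30% relative error whenever the dataset reports error = -1. A standalone check of that conversion using the `lig_ejm_31` numbers from the docstring (Ki = 0.096 uM at 300 K):

```python
# Sketch: reproduce to_arsenic_csv's expt_DG / expt_dDG conversion for the
# docstring's lig_ejm_31 entry (Ki = 0.096 uM, error = -1 -> 30% relative).
import numpy as np
from openmm import unit
from openmmtools.constants import kB

kBT = kB * 300 * unit.kelvin
ki_molar = 0.096 * 1e-6  # 'uM' -> M, per the unit_conversions table
expt_DG = kBT.value_in_unit(unit.kilocalorie_per_mole) * np.log(ki_molar)
expt_dDG = kBT.value_in_unit(unit.kilocalorie_per_mole) * 0.3
print(f"{expt_DG:.2f} +- {expt_dDG:.2f} kcal/mol")  # roughly -9.6 +- 0.2
```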
  {
    "path": "scripts/perses-benchmark/tyk2/espaloma-0.2.2/run_benchmarks.py",
    "content": "#!/usr/bin/env python\n\n\"\"\"\nCLI utility to automatically run benchmarks using data from the open force field protein-ligand benchmark at\nhttps://github.com/openforcefield/protein-ligand-benchmark\n\nIt requires internet connection to function properly, by connecting to the mentioned repository.\n\"\"\"\n# TODO: Use plbenchmarks when conda package is available.\n\nimport argparse\nimport logging\nimport os\nimport yaml\n\nfrom perses.app.setup_relative_calculation import run\nfrom perses.utils.url_utils import retrieve_file_url\nfrom perses.utils.url_utils import fetch_url_contents\n\n# Setting logging level config\nLOGLEVEL = os.environ.get(\"LOGLEVEL\", \"DEBUG\").upper()\nlogging.basicConfig(\n    format='%(asctime)s %(levelname)-8s %(message)s',\n    level=LOGLEVEL,\n    datefmt='%Y-%m-%d %H:%M:%S')\n_logger = logging.getLogger()\n_logger.setLevel(LOGLEVEL)\n\n# global variables\nbase_repo_url = \"https://github.com/openforcefield/protein-ligand-benchmark\"\n\n\ndef concatenate_files(input_files, output_file):\n    \"\"\"\n    Concatenate files given in input_files iterator into output_file.\n    \"\"\"\n    with open(output_file, 'w') as outfile:\n        for filename in input_files:\n            with open(filename) as infile:\n                for line in infile:\n                    outfile.write(line)\n\n\ndef run_relative_perturbation(lig_a_idx, lig_b_idx, reverse=False, tidy=True):\n    \"\"\"\n    Perform relative free energy simulation using perses CLI.\n\n    Parameters\n    ----------\n        lig_a_idx : int\n            Index for first ligand (ligand A)\n        lig_b_idx : int\n            Index for second ligand (ligand B)\n        reverse: bool\n            Run the edge in reverse direction. Swaps the ligands.\n        tidy : bool, optional\n            remove auto-generated yaml files.\n\n    Expects the target/protein pdb file in the same directory to be called 'target.pdb', and ligands file\n    to be called 'ligands.sdf'.\n    \"\"\"\n    _logger.info(f'Starting relative calculation of ligand {lig_a_idx} to {lig_b_idx}')\n    trajectory_directory = f'out_{lig_a_idx}_{lig_b_idx}'\n    new_yaml = f'relative_{lig_a_idx}_{lig_b_idx}.yaml'\n\n    # read base template yaml file\n    # TODO: template.yaml file is configured for Tyk2, check if the same options work for others.\n    with open(f'template.yaml', \"r\") as yaml_file:\n        options = yaml.load(yaml_file, Loader=yaml.FullLoader)\n\n    # TODO: add a step to perform some minimization - should help with NaNs\n    # generate yaml file from template\n    options['protein_pdb'] = 'target.pdb'\n    options['ligand_file'] = 'ligands.sdf'\n    if reverse:\n        # Do the other direction of ligands\n        options['old_ligand_index'] = lig_b_idx\n        options['new_ligand_index'] = lig_a_idx\n        # mark the output directory with reversed\n        trajectory_directory = f'{trajectory_directory}_reversed'\n        # mark new yaml file with reversed\n        temp_path = new_yaml.split('.')\n        new_yaml = f'{temp_path[0]}_reversed.{temp_path[1]}'\n    else:\n        options['old_ligand_index'] = lig_a_idx\n        options['new_ligand_index'] = lig_b_idx\n    options['trajectory_directory'] = f'{trajectory_directory}'\n    with open(new_yaml, 'w') as outfile:\n        yaml.dump(options, outfile)\n\n    # run the simulation - using API point to respect logging level\n    run(new_yaml)\n\n    _logger.info(f'Relative calculation of ligand {lig_a_idx} to {lig_b_idx} complete')\n\n    if tidy:\n  
      os.remove(new_yaml)\n\n\n# Defining command line arguments\n# fetching targets from github repo\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntargets_url = f\"{base_repo_url}/raw/master/data/targets.yml\"\nwith fetch_url_contents(targets_url) as response:\n    targets_dict = yaml.safe_load(response.read())\n# get the possible choices from targets yaml file\ntarget_choices = targets_dict.keys()\n\narg_parser = argparse.ArgumentParser(description='CLI tool for running perses protein-ligand benchmarks.')\narg_parser.add_argument(\n    \"--target\",\n    type=str,\n    help=\"Target biomolecule, use openff's plbenchmark names.\",\n    choices=target_choices,\n    required=True\n)\narg_parser.add_argument(\n    \"--edge\",\n    type=int,\n    help=\"Edge index (0-based) according to edges yaml file in dataset. Ex. --edge 5 (for sixth edge)\",\n    required=True\n)\narg_parser.add_argument(\n    \"--reversed\",\n    action='store_true',\n    help=\"Whether to run the edge in reverse direction. Helpful for consistency checks.\"\n)\nargs = arg_parser.parse_args()\ntarget = args.target\nis_reversed = args.reversed\n\n# Fetch protein pdb file\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntarget_dir = targets_dict[target]['dir']\npdb_url = f\"{base_repo_url}/raw/master/data/{target_dir}/01_protein/crd/protein.pdb\"\npdb_file = retrieve_file_url(pdb_url)\n\n# Fetch cofactors crystalwater pdb file\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ncofactors_url = f\"{base_repo_url}/raw/master/data/{target_dir}/01_protein/crd/cofactors_crystalwater.pdb\"\ncofactors_file = retrieve_file_url(cofactors_url)\n\n# Concatenate protein with cofactors pdbs\nconcatenate_files((pdb_file, cofactors_file), 'target.pdb')\n\n# Fetch ligands sdf files and concatenate them in one\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\nligands_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/ligands.yml\"\nwith fetch_url_contents(ligands_url) as response:\n    ligands_dict = yaml.safe_load(response.read())\nligand_files = []\nfor ligand in ligands_dict.keys():\n    ligand_url = f\"{base_repo_url}/raw/master/data/{target_dir}/02_ligands/{ligand}/crd/{ligand}.sdf\"\n    ligand_file = retrieve_file_url(ligand_url)\n    ligand_files.append(ligand_file)\n# concatenate sdfs\nconcatenate_files(ligand_files, 'ligands.sdf')\n\n# run simulation\n# fetch edges information\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\nedges_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/edges.yml\"\nwith fetch_url_contents(edges_url) as response:\n    edges_dict = yaml.safe_load(response.read())\nedges_list = list(edges_dict.values())  # suscriptable edges object - note dicts are ordered for py>=3.7\n# edge list to access by index\nedge_index = args.edge  # read from cli arguments\nedge = edges_list[edge_index]\nligand_a_name = edge['ligand_a']\nligand_b_name = edge['ligand_b']\n# ligands list to get indices -- preserving same order as upstream yaml file\nligands_list = list(ligands_dict.keys())\nlig_a_index = ligands_list.index(ligand_a_name)\nlig_b_index = ligands_list.index(ligand_b_name)\n# Perform the simulation\nrun_relative_perturbation(lig_a_index, lig_b_index, reverse=is_reversed)\n"
  },
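Each job specializes `template.yaml` into a per-edge perses input and hands it to the API entry point. A condensed sketch of that step, with ligand indices 0 and 1 as illustrative stand-ins for a real edge:

```python
# Sketch: derive a per-edge perses input from template.yaml, as
# run_relative_perturbation does. Indices 0 -> 1 are illustrative.
import yaml

with open("template.yaml") as yaml_file:
    options = yaml.safe_load(yaml_file)

options["protein_pdb"] = "target.pdb"
options["ligand_file"] = "ligands.sdf"
options["old_ligand_index"] = 0
options["new_ligand_index"] = 1
options["trajectory_directory"] = "out_0_1"

with open("relative_0_1.yaml", "w") as outfile:
    yaml.dump(options, outfile)

# The script then calls perses' API entry point on the derived file:
# from perses.app.setup_relative_calculation import run
# run("relative_0_1.yaml")
```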
  {
    "path": "scripts/perses-benchmark/tyk2/espaloma-0.2.2/template.yaml",
    "content": "# Path to protein file\nprotein_pdb: null\n# Path to ligand SDF file\nligand_file: null\n# Indices of old and new ligands within SDF file\nold_ligand_index: null\nnew_ligand_index: null\n\n#\n# Force fields\n#\n\n# OpenMM ffxml force field files installed via the openmm-forcefields package\n# for biopolymers and solvents.\n# Note that small molecule force field files should NOT be included here.\nforcefield_files:\n    - amber/ff14SB.xml # ff14SB protein force field\n    - amber/tip3p_standard.xml # TIP3P and recommended monovalent ion parameters\n    - amber/tip3p_HFE_multivalent.xml # for divalent ions\n    - amber/phosaa10.xml # HANDLES THE TPO\n\n# Small molecule force field\n# Options include anything allowed by the openmmforcefields SystemGenerator\n# e.g. one of ['openff-2.0.0', 'gaff-2.11']\nsmall_molecule_forcefield: espaloma-0.2.2\n\n#\n# Simulation conditions\n#\n\n# Simulation setup options\nsolvent_padding: 9.0 # angstroms\n\n# Use geometry-derived mapping\nuse_given_geometries: true\ngiven_geometries_tolerance: 0.2 # angstroms\n\n# Atom mapping specification\natom_expression:\n    - IntType\nbond_expession:\n    - DefaultBonds\n\n# Multi-state sampling scheme\n# One of ['repex', 'nonequilibrium', 'sams']\nfe_type: repex\n\n# Checkpoint interval\ncheckpoint_interval: 50 # number of iterations\n\n# Number of equilibration iterations\nn_equilibration_iterations: 0\n\n# Number of iterations to run\nn_cycles: 5000\n\n# Number of alchemical intermediate states to use\nn_states: 12\n\npressure: 1.0 # atmsopheres\ntemperature: 300.0 # kelvin\ntimestep: 4.0 # femtoseconds\n# remove_constraints: false\n\n# Number of integration stpes per iteration\nn_steps_per_move_application: 250\n\n# Location for storing trajectories\ntrajectory_directory: null\n\n# Prefix for trajectory files (project-specific name)\ntrajectory_prefix: out\n\n# Atoms to store in NetCDF files (MDTraj selection syntax)\natom_selection: not water\n\n# Calculation phases to run\n# Permitted phases: ['complex', 'solvent', 'vacuum']\nphases:\n    - complex\n    - solvent\n\n\n"
  },
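The sampling settings above fix the total simulation time per alchemical state: iterations times MD steps per iteration times timestep. A quick arithmetic check of what the template implies:

```python
# Sketch: per-state sampling implied by template.yaml's repex settings.
n_cycles = 5000                      # iterations
n_steps_per_move_application = 250   # MD steps per iteration
timestep_fs = 4.0                    # femtoseconds

total_ns = n_cycles * n_steps_per_move_application * timestep_fs * 1e-6
print(f"{total_ns:g} ns per alchemical state")  # 5 ns
```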
  {
    "path": "scripts/perses-benchmark/tyk2/openff-1.2.0/LSF-job-template.sh",
    "content": "#!/bin/bash\n#BSUB -P \"tyk2-benchmark\"\n#BSUB -J \"perses-benchmark-[1-24]\"\n#BSUB -n 1\n#BSUB -R rusage[mem=8]\n#BSUB -R span[hosts=1]\n#BSUB -q gpuqueue\n#BSUB -sp 1 # low priority. default is 12, max is 25\n#BSUB -gpu num=1:j_exclusive=yes:mode=shared\n#BSUB -W  24:00\n#BSUB -o out_%J_%I.stdout\n#BSUB -eo out_%J_%I.stderr\n#BSUB -L /bin/bash\n\nsource ~/.bashrc\nOPENMM_CPU_THREADS=1\n\necho \"changing directory to ${LS_SUBCWD}\"\ncd $LS_SUBCWD\nconda activate espaloma-perses\n\n# Report node in use\nhostname\n\n# Report CUDA info\nenv | sort | grep 'CUDA'\n\n# launching a benchmark pair (target, edge) per job (0-based thus substract 1)\npython run_benchmarks.py --target tyk2 --edge $(( $LSB_JOBINDEX - 1 ))\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/openff-1.2.0/README.md",
    "content": "# Perses benchmarks\n\nThis subdirectory exposes a CLI tool for running automated benchmarks from\n[OpenFF's protein ligand benchmark dataset](https://github.com/openforcefield/protein-ligand-benchmark) using perses.\n\n## Running all edges\n\nA script to run all transformations in an LSF batch scheduler is provided, but will likely need to be modified for your batch queue system:\n```bash\nbsub < LSF-job-template.sh\n```\n\n## Running single edges\n\nAssuming you have a clone of the perses code repository and you are standing in the `benchmarks` subdirectory\n(where this file lives). Then the benchmarks can be run using the following command syntax:\n```bash\npython run_benchmarks.py --target [protein-name] --edge [edge-index]\n```\n\nFor example, for running the seventh edge (zero-based, according to [plbenchmark data](https://github.com/openforcefield/protein-ligand-benchmark) )\nfor `tyk2` protein, you would run:\n```bash\n# Set up and run edge 6\npython run_benchmarks.py --target tyk2 --edge 6\n```\nShould the calculation for an edge fail, you can simply re-run the same command-line and the calculation will resume:\n```bash\n# Resume failed edge 6\npython run_benchmarks.py --target tyk2 --edge 6\n```\nFor more information on how to use the tool, you can run `python run_benchmarks.py -h`.\n\n## Analyzing benchmarks\n\nTo analyze the simulations a script called `benchmark_analysis.py` is used as follows:\n```bash\npython benchmark_analysis.py --target [protein-name]\n```\n\nFor example, for tyk2 results:\n```bash\npython benchmark_analysis.py --target tyk2\n```\nThis will generate an output CSV file for [`arsenic`](https://github.com/openforcefield/arsenic) and corresponding absolute and relative free energy plots as PNG files produced according to best practices.)\n\nFor more information on how to use the cli analysis tool use `python benchmark_analysis.py -h`.\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/openff-1.2.0/benchmark_analysis.py",
    "content": "\"\"\"\nScript to perform analysis of perses simulations executed using run_benchmarks.py script.\n\nIntended to be used on systems from https://github.com/openforcefield/protein-ligand-benchmark\n\"\"\"\n\nimport argparse\nimport glob\nimport itertools\nimport re\nimport warnings\n\nimport numpy as np\nimport urllib.request\nimport yaml\n\nfrom openmmtools.constants import kB\nfrom perses.analysis.load_simulations import Simulation\n\nfrom openmm import unit\n\nfrom openff.arsenic import plotting, wrangle\n\n# global variables\nbase_repo_url = \"https://github.com/openforcefield/protein-ligand-benchmark\"\n\n\n# Helper functions\n\ndef get_simdir_list(base_dir='.', is_reversed=False):\n    \"\"\"\n    Get list of directories to extract simulation data.\n\n    Attributes\n    ----------\n    base_dir: str, optional, default='.'\n        Base directory where to search for simulations results. Defaults to current directory.\n    is_reversed: bool, optional, default=False\n        Whether to consider the reversed simulations or not. Meant for testing purposes.\n\n    Returns\n    -------\n    dir_list: list\n        List of directories paths for simulation results.\n    \"\"\"\n    # Load all expected simulation from directories\n    out_dirs = ['/'.join(filepath.split('/')[:-1]) for filepath in glob.glob(f'{base_dir}/out*/*complex.nc')]\n    reg = re.compile(r'out_[0-9]+_[0-9]+_reversed')  # regular expression to deal with reversed directories\n    if is_reversed:\n        # Choose only reversed directories\n        out_dirs = list(filter(reg.search, out_dirs))\n    else:\n        # Filter out reversed directories\n        out_dirs = list(itertools.filterfalse(reg.search, out_dirs))\n    return out_dirs\n\n\ndef get_simulations_data(simulation_dirs):\n    \"\"\"Generates a list of simulation data objects given the simulation directories paths.\"\"\"\n    simulations = []\n    for out_dir in simulation_dirs:\n        # Load complete or fully working simulations\n        # TODO: Try getting better exceptions from openmmtools -- use non-generic exceptions\n        try:\n            simulation = Simulation(out_dir)\n            simulations.append(simulation)\n        except Exception:\n            warnings.warn(f\"Edge in {out_dir} could not be loaded. Check simulation output is complete.\")\n    return simulations\n\n\ndef to_arsenic_csv(experimental_data: dict, simulation_data: list, out_csv: str = 'out_benchmark.csv'):\n    \"\"\"\n    Generates a csv file to be used with openff-arsenic. Energy units in kcal/mol.\n\n    .. 
warning:: To be deprecated once arsenic object model is improved.\n\n    Parameters\n    ----------\n        experimental_data: dict\n            Python nested dictionary with experimental data in micromolar or nanomolar units.\n            Example of entry:\n\n                {'lig_ejm_31': {'measurement': {'comment': 'Table 4, entry 31',\n                  'doi': '10.1016/j.ejmech.2013.03.070',\n                  'error': -1,\n                  'type': 'ki',\n                  'unit': 'uM',\n                  'value': 0.096},\n                  'name': 'lig_ejm_31',\n                  'smiles': '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]'}\n\n        simulation_data: list or iterable\n            Python iterable object with perses Simulation objects as entries.\n        out_csv: str\n            Path to output csv file to be generated.\n    \"\"\"\n    # Ligand information\n    ligands_names = list(ligands_dict.keys())\n    lig_id_to_name = dict(enumerate(ligands_names))\n    kBT = kB * 300 * unit.kelvin  # useful when converting to kcal/mol\n    # Write csv file\n    with open(out_csv, 'w') as csv_file:\n        # Experimental block\n        # print header for block\n        csv_file.write(\"# Experimental block\\n\")\n        csv_file.write(\"# Ligand, expt_DG, expt_dDG\\n\")\n        # Extract ligand name, expt_DG and expt_dDG from ligands dictionary\n        for ligand_name, ligand_data in experimental_data.items():\n            # TODO: Handle multiple measurement types\n            unit_symbol = ligand_data['measurement']['unit']\n            measurement_value = ligand_data['measurement']['value']\n            measurement_error = ligand_data['measurement']['error']\n            # Unit conversion\n            # TODO: Let's persuade PLBenchmarks to use pint units\n            unit_conversions = { 'M' : 1.0, 'mM' : 1e-3, 'uM' : 1e-6, 'nM' : 1e-9, 'pM' : 1e-12, 'fM' : 1e-15 }\n            if unit_symbol not in unit_conversions:\n                raise ValueError(f'Unknown units \"{unit_symbol}\"')\n            value_to_molar= unit_conversions[unit_symbol]\n            # Handle unknown errors\n            # TODO: We should be able to ensure that all entries have more reasonable errors.\n            if measurement_error == -1:\n                # TODO: For now, we use a relative_error from the Tyk2 system 10.1016/j.ejmech.2013.03.070\n                relative_error = 0.3\n            else:\n                relative_error = measurement_error / measurement_value\n            # Convert to free eneriges\n            expt_DG = kBT.value_in_unit(unit.kilocalorie_per_mole) * np.log(measurement_value * value_to_molar)\n            expt_dDG = kBT.value_in_unit(unit.kilocalorie_per_mole) * relative_error\n            csv_file.write(f\"{ligand_name}, {expt_DG}, {expt_dDG}\\n\")\n\n        # Calculated block\n        # print header for block\n        csv_file.write(\"# Calculated block\\n\")\n        csv_file.write(\"# Ligand1,Ligand2, calc_DDG, calc_dDDG(MBAR), calc_dDDG(additional)\\n\")\n        # Loop through simulation, extract ligand1 and ligand2 indices, convert to names, create string with\n        # ligand1, ligand2, calc_DDG, calc_dDDG(MBAR), calc_dDDG(additional)\n        # write string in csv file\n        for simulation in simulation_data:\n            out_dir = simulation.directory.split('/')[-1]\n            # getting integer indices\n            ligand1_id, ligand2_id = int(out_dir.split('_')[-1]), int(out_dir.split('_')[-2])  # CHECK 
ORDER!\n            # getting names of ligands\n            ligand1, ligand2 = lig_id_to_name[ligand1_id], lig_id_to_name[ligand2_id]\n            # getting calc_DDG in kcal/mol\n            calc_DDG = simulation.bindingdg.value_in_unit(unit.kilocalorie_per_mole)\n            # getting calc_dDDG in kcal/mol\n            calc_dDDG = simulation.bindingddg.value_in_unit(unit.kilocalorie_per_mole)\n            csv_file.write(\n                f\"{ligand1}, {ligand2}, {calc_DDG}, {calc_dDDG}, 0.0\\n\")  # hardcoding additional error as 0.0\n\n\n# Defining command line arguments\n# fetching targets from github repo\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntargets_url = f\"{base_repo_url}/raw/master/data/targets.yml\"\nwith urllib.request.urlopen(targets_url) as response:\n    targets_dict = yaml.safe_load(response.read())\n# get the possible choices from targets yaml file\ntarget_choices = targets_dict.keys()\n\narg_parser = argparse.ArgumentParser(description='CLI tool for running perses protein-ligand benchmarks analysis.')\narg_parser.add_argument(\n    \"--target\",\n    type=str,\n    help=\"Target biomolecule, use openff's plbenchmark names.\",\n    choices=target_choices,\n    required=True\n)\narg_parser.add_argument(\n    \"--reversed\",\n    action='store_true',\n    help=\"Analyze reversed edge simulations. Helpful for testing/consistency checks.\"\n)\nargs = arg_parser.parse_args()\ntarget = args.target\n\n# Download experimental data\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\n# TODO: Let's cache this data when we set up the initial simulations in case it changes in between setting up and running the calculations and analysis.\n# TODO: Let's also be sure to use a specific release tag rather than 'master'\ntarget_dir = targets_dict[target]['dir']\nligands_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/ligands.yml\"\nwith urllib.request.urlopen(ligands_url) as response:\n    yaml_contents = response.read()\n    print(yaml_contents)\n    ligands_dict = yaml.safe_load(yaml_contents)\n\n# DEBUG\nprint('')\nprint(yaml.dump(ligands_dict))\n\n# Get paths for simulation output directories\nout_dirs = get_simdir_list(is_reversed=args.reversed)\n\n# Generate list with simulation objects\nsimulations = get_simulations_data(out_dirs)\n\n# Generate csv file\ncsv_path = f'./{target}_arsenic.csv'\nto_arsenic_csv(ligands_dict, simulations, out_csv=csv_path)\n\n\n# TODO: Separate plotting in a different file\n# Make plots and store\nfe = wrangle.FEMap(csv_path)\n# Relative plot\nplotting.plot_DDGs(fe.graph,\n                   target_name=f'{target}',\n                   title=f'Relative binding energies - {target}',\n                   figsize=5,\n                   filename='./plot_relative.pdf'\n                   )\n# Absolute plot, with experimental data shifted to correct mean\nexperimental_mean_dg = np.asarray([node[1][\"exp_DG\"] for node in fe.graph.nodes(data=True)]).mean()\nplotting.plot_DGs(fe.graph,\n                  target_name=f'{target}',\n                  title=f'Absolute binding energies - {target}',\n                  figsize=5,\n                  filename='./plot_absolute.pdf',\n                  shift=experimental_mean_dg,\n                  )\n\n\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/openff-1.2.0/run_benchmarks.py",
    "content": "#!/usr/bin/env python\n\n\"\"\"\nCLI utility to automatically run benchmarks using data from the open force field protein-ligand benchmark at\nhttps://github.com/openforcefield/protein-ligand-benchmark\n\nIt requires internet connection to function properly, by connecting to the mentioned repository.\n\"\"\"\n# TODO: Use plbenchmarks when conda package is available.\n\nimport argparse\nimport logging\nimport os\nimport yaml\n\nfrom perses.app.setup_relative_calculation import run\nfrom perses.utils.url_utils import retrieve_file_url\nfrom perses.utils.url_utils import fetch_url_contents\n\n# Setting logging level config\nLOGLEVEL = os.environ.get(\"LOGLEVEL\", \"DEBUG\").upper()\nlogging.basicConfig(\n    format='%(asctime)s %(levelname)-8s %(message)s',\n    level=LOGLEVEL,\n    datefmt='%Y-%m-%d %H:%M:%S')\n_logger = logging.getLogger()\n_logger.setLevel(LOGLEVEL)\n\n# global variables\nbase_repo_url = \"https://github.com/openforcefield/protein-ligand-benchmark\"\n\n\ndef concatenate_files(input_files, output_file):\n    \"\"\"\n    Concatenate files given in input_files iterator into output_file.\n    \"\"\"\n    with open(output_file, 'w') as outfile:\n        for filename in input_files:\n            with open(filename) as infile:\n                for line in infile:\n                    outfile.write(line)\n\n\ndef run_relative_perturbation(lig_a_idx, lig_b_idx, reverse=False, tidy=True):\n    \"\"\"\n    Perform relative free energy simulation using perses CLI.\n\n    Parameters\n    ----------\n        lig_a_idx : int\n            Index for first ligand (ligand A)\n        lig_b_idx : int\n            Index for second ligand (ligand B)\n        reverse: bool\n            Run the edge in reverse direction. Swaps the ligands.\n        tidy : bool, optional\n            remove auto-generated yaml files.\n\n    Expects the target/protein pdb file in the same directory to be called 'target.pdb', and ligands file\n    to be called 'ligands.sdf'.\n    \"\"\"\n    _logger.info(f'Starting relative calculation of ligand {lig_a_idx} to {lig_b_idx}')\n    trajectory_directory = f'out_{lig_a_idx}_{lig_b_idx}'\n    new_yaml = f'relative_{lig_a_idx}_{lig_b_idx}.yaml'\n\n    # read base template yaml file\n    # TODO: template.yaml file is configured for Tyk2, check if the same options work for others.\n    with open(f'template.yaml', \"r\") as yaml_file:\n        options = yaml.load(yaml_file, Loader=yaml.FullLoader)\n\n    # TODO: add a step to perform some minimization - should help with NaNs\n    # generate yaml file from template\n    options['protein_pdb'] = 'target.pdb'\n    options['ligand_file'] = 'ligands.sdf'\n    if reverse:\n        # Do the other direction of ligands\n        options['old_ligand_index'] = lig_b_idx\n        options['new_ligand_index'] = lig_a_idx\n        # mark the output directory with reversed\n        trajectory_directory = f'{trajectory_directory}_reversed'\n        # mark new yaml file with reversed\n        temp_path = new_yaml.split('.')\n        new_yaml = f'{temp_path[0]}_reversed.{temp_path[1]}'\n    else:\n        options['old_ligand_index'] = lig_a_idx\n        options['new_ligand_index'] = lig_b_idx\n    options['trajectory_directory'] = f'{trajectory_directory}'\n    with open(new_yaml, 'w') as outfile:\n        yaml.dump(options, outfile)\n\n    # run the simulation - using API point to respect logging level\n    run(new_yaml)\n\n    _logger.info(f'Relative calculation of ligand {lig_a_idx} to {lig_b_idx} complete')\n\n    if tidy:\n  
      os.remove(new_yaml)\n\n\n# Defining command line arguments\n# fetching targets from github repo\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntargets_url = f\"{base_repo_url}/raw/master/data/targets.yml\"\nwith fetch_url_contents(targets_url) as response:\n    targets_dict = yaml.safe_load(response.read())\n# get the possible choices from targets yaml file\ntarget_choices = targets_dict.keys()\n\narg_parser = argparse.ArgumentParser(description='CLI tool for running perses protein-ligand benchmarks.')\narg_parser.add_argument(\n    \"--target\",\n    type=str,\n    help=\"Target biomolecule, use openff's plbenchmark names.\",\n    choices=target_choices,\n    required=True\n)\narg_parser.add_argument(\n    \"--edge\",\n    type=int,\n    help=\"Edge index (0-based) according to edges yaml file in dataset. Ex. --edge 5 (for sixth edge)\",\n    required=True\n)\narg_parser.add_argument(\n    \"--reversed\",\n    action='store_true',\n    help=\"Whether to run the edge in reverse direction. Helpful for consistency checks.\"\n)\nargs = arg_parser.parse_args()\ntarget = args.target\nis_reversed = args.reversed\n\n# Fetch protein pdb file\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ntarget_dir = targets_dict[target]['dir']\npdb_url = f\"{base_repo_url}/raw/master/data/{target_dir}/01_protein/crd/protein.pdb\"\npdb_file = retrieve_file_url(pdb_url)\n\n# Fetch cofactors crystalwater pdb file\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\ncofactors_url = f\"{base_repo_url}/raw/master/data/{target_dir}/01_protein/crd/cofactors_crystalwater.pdb\"\ncofactors_file = retrieve_file_url(cofactors_url)\n\n# Concatenate protein with cofactors pdbs\nconcatenate_files((pdb_file, cofactors_file), 'target.pdb')\n\n# Fetch ligands sdf files and concatenate them in one\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\nligands_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/ligands.yml\"\nwith fetch_url_contents(ligands_url) as response:\n    ligands_dict = yaml.safe_load(response.read())\nligand_files = []\nfor ligand in ligands_dict.keys():\n    ligand_url = f\"{base_repo_url}/raw/master/data/{target_dir}/02_ligands/{ligand}/crd/{ligand}.sdf\"\n    ligand_file = retrieve_file_url(ligand_url)\n    ligand_files.append(ligand_file)\n# concatenate sdfs\nconcatenate_files(ligand_files, 'ligands.sdf')\n\n# run simulation\n# fetch edges information\n# TODO: This part should be done using plbenchmarks API - once there is a conda pkg\nedges_url = f\"{base_repo_url}/raw/master/data/{target_dir}/00_data/edges.yml\"\nwith fetch_url_contents(edges_url) as response:\n    edges_dict = yaml.safe_load(response.read())\nedges_list = list(edges_dict.values())  # suscriptable edges object - note dicts are ordered for py>=3.7\n# edge list to access by index\nedge_index = args.edge  # read from cli arguments\nedge = edges_list[edge_index]\nligand_a_name = edge['ligand_a']\nligand_b_name = edge['ligand_b']\n# ligands list to get indices -- preserving same order as upstream yaml file\nligands_list = list(ligands_dict.keys())\nlig_a_index = ligands_list.index(ligand_a_name)\nlig_b_index = ligands_list.index(ligand_b_name)\n# Perform the simulation\nrun_relative_perturbation(lig_a_index, lig_b_index, reverse=is_reversed)\n"
  },
  {
    "path": "scripts/perses-benchmark/tyk2/openff-1.2.0/template.yaml",
    "content": "# Path to protein file\nprotein_pdb: null\n# Path to ligand SDF file\nligand_file: null\n# Indices of old and new ligands within SDF file\nold_ligand_index: null\nnew_ligand_index: null\n\n#\n# Force fields\n#\n\n# OpenMM ffxml force field files installed via the openmm-forcefields package\n# for biopolymers and solvents.\n# Note that small molecule force field files should NOT be included here.\nforcefield_files:\n    - amber/ff14SB.xml # ff14SB protein force field\n    - amber/tip3p_standard.xml # TIP3P and recommended monovalent ion parameters\n    - amber/tip3p_HFE_multivalent.xml # for divalent ions\n    - amber/phosaa10.xml # HANDLES THE TPO\n\n# Small molecule force field\n# Options include anything allowed by the openmmforcefields SystemGenerator\n# e.g. one of ['openff-2.0.0', 'gaff-2.11']\nsmall_molecule_forcefield: openff-1.2.0\n\n#\n# Simulation conditions\n#\n\n# Simulation setup options\nsolvent_padding: 9.0 # angstroms\n\n# Use geometry-derived mapping\nuse_given_geometries: true\ngiven_geometries_tolerance: 0.2 # angstroms\n\n# Atom mapping specification\natom_expression:\n    - IntType\nbond_expession:\n    - DefaultBonds\n\n# Multi-state sampling scheme\n# One of ['repex', 'nonequilibrium', 'sams']\nfe_type: repex\n\n# Checkpoint interval\ncheckpoint_interval: 50 # number of iterations\n\n# Number of equilibration iterations\nn_equilibration_iterations: 0\n\n# Number of iterations to run\nn_cycles: 5000\n\n# Number of alchemical intermediate states to use\nn_states: 12\n\npressure: 1.0 # atmsopheres\ntemperature: 300.0 # kelvin\ntimestep: 4.0 # femtoseconds\n# remove_constraints: false\n\n# Number of integration stpes per iteration\nn_steps_per_move_application: 250\n\n# Location for storing trajectories\ntrajectory_directory: null\n\n# Prefix for trajectory files (project-specific name)\ntrajectory_prefix: out\n\n# Atoms to store in NetCDF files (MDTraj selection syntax)\natom_selection: not water\n\n# Calculation phases to run\n# Permitted phases: ['complex', 'solvent', 'vacuum']\nphases:\n    - complex\n    - solvent\n\n\n"
  },
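  {
    "path": "scripts/perses-benchmark/examples/fill_template_example.py",
    "content": "\"\"\"Illustrative sketch (added for exposition; not part of the original\nrepository): shows how the null placeholders in template.yaml above could be\nfilled to produce a run-specific perses input file. The ligand indices and\nthe output name 'lig0to1.yaml' are assumptions.\n\"\"\"\nimport yaml\n\nwith open('template.yaml') as f:\n    options = yaml.safe_load(f)\n\n# fill the per-edge placeholders left as null in the template\noptions['protein_pdb'] = 'target.pdb'\noptions['ligand_file'] = 'ligands.sdf'\noptions['old_ligand_index'] = 0\noptions['new_ligand_index'] = 1\noptions['trajectory_directory'] = 'lig0to1'\n\nwith open('lig0to1.yaml', 'w') as f:\n    yaml.safe_dump(options, f)\n"
  },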
  {
    "path": "setup.cfg",
    "content": "# Helper file to handle all configs\n\n[coverage:run]\n# .coveragerc to control coverage.py and pytest-cov\nomit =\n    # Omit the tests\n    */tests/*\n    # Omit generated versioneer\n    espaloma/_version.py\n\n[yapf]\n# YAPF, in .style.yapf files this shows up as \"[style]\" header\nCOLUMN_LIMIT = 119\nINDENT_WIDTH = 4\nUSE_TABS = False\n\n[flake8]\n# Flake8, PyFlakes, etc\nmax-line-length = 119\n\n[versioneer]\n# Automatic version numbering scheme\nVCS = git\nstyle = pep440\nversionfile_source = espaloma/_version.py\nversionfile_build = espaloma/_version.py\ntag_prefix = ''\n\n[aliases]\ntest = pytest\n"
  },
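  {
    "path": "examples/setup_cfg_versioneer_example.py",
    "content": "\"\"\"Illustrative sketch (added for exposition; not part of the original\nrepository): reads the [versioneer] section of setup.cfg above with\nconfigparser, mirroring the setup.cfg fallback that versioneer.py itself\nuses when no pyproject.toml configuration is present. Note the quoted\ntag_prefix ('') is read literally here; versioneer normalizes \"''\" to an\nempty string.\n\"\"\"\nimport configparser\n\nparser = configparser.ConfigParser()\nwith open('setup.cfg') as f:\n    parser.read_file(f)\n\nsection = parser['versioneer']\nprint(section['VCS'])                     # git\nprint(section.get('style'))               # pep440\nprint(section.get('versionfile_source'))  # espaloma/_version.py\n"
  },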
  {
    "path": "setup.py",
    "content": "\"\"\"\nespaloma\nExtensible Surrogate Potential of Ab initio Learned and Optimized by Message-passing Algorithm\n\"\"\"\nimport sys\n\nfrom setuptools import find_packages, setup\n\nimport versioneer\n\nshort_description = __doc__.split(\"\\n\")\n\n# from https://github.com/pytest-dev/pytest-runner#conditional-requirement\nneeds_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)\npytest_runner = ['pytest-runner'] if needs_pytest else []\n\ntry:\n    with open(\"README.md\", \"r\") as handle:\n        long_description = handle.read()\nexcept:\n    long_description = \"\\n\".join(short_description[2:])\n\n\nsetup(\n    # Self-descriptive entries which should always be present\n    name='espaloma',\n    author='Yuanqing Wang @ choderalab // MSKCC',\n    author_email='wangyq@wangyq.net',\n    description=short_description[0],\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    version=versioneer.get_version(),\n    cmdclass=versioneer.get_cmdclass(),\n    license='MIT',\n\n    # Which Python importable modules should be included when your package is installed\n    # Handled automatically by setuptools. Use 'exclude' to prevent some specific\n    # subpackage(s) from being added, if needed\n    packages=find_packages(),\n\n    # Optional include package data to ship with your package\n    # Customize MANIFEST.in if the general case does not suit your needs\n    # Comment out this line to prevent the files from being packaged with your software\n    include_package_data=True,\n\n    # Allows `setup.py test` to work correctly with pytest\n    setup_requires=[] + pytest_runner,\n\n    # Additional entries you may want simply uncomment the lines you want and fill in the data\n    # url='http://www.my_package.com',  # Website\n    # install_requires=[],              # Required packages, pulls from pip if needed; do not use for Conda deployment\n    # platforms=['Linux',\n    #            'Mac OS-X',\n    #            'Unix',\n    #            'Windows'],            # Valid platforms your code works on, adjust to your flavor\n    # python_requires=\">=3.5\",          # Python version restrictions\n\n    # Manual control if final package is compressible or not, set False to prevent the .egg from being made\n    # zip_safe=False,\n\n)\n"
  },
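  {
    "path": "examples/version_lookup_example.py",
    "content": "\"\"\"Illustrative sketch (added for exposition; not part of the original\nrepository): the two standard ways to query the versioneer-managed version,\nas described in the versioneer.py docstring below. Assumes the package has\nbeen installed (e.g. 'pip install -e .') and, for the second lookup, that\nthe script runs from the project root where versioneer.py lives.\n\"\"\"\nimport espaloma\n\n# set in espaloma/__init__.py from espaloma/_version.py by the installer\nprint(espaloma.__version__)\n\n# the \"outside\" view used by setup.py\nimport versioneer\n\nprint(versioneer.get_version())\n"
  },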
  {
    "path": "versioneer.py",
    "content": "\n# Version: 0.29\n\n\"\"\"The Versioneer - like a rocketeer, but for versions.\n\nThe Versioneer\n==============\n\n* like a rocketeer, but for versions!\n* https://github.com/python-versioneer/python-versioneer\n* Brian Warner\n* License: Public Domain (Unlicense)\n* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3\n* [![Latest Version][pypi-image]][pypi-url]\n* [![Build Status][travis-image]][travis-url]\n\nThis is a tool for managing a recorded version number in setuptools-based\npython projects. The goal is to remove the tedious and error-prone \"update\nthe embedded version string\" step from your release process. Making a new\nrelease should be as easy as recording a new tag in your version-control\nsystem, and maybe making new tarballs.\n\n\n## Quick Install\n\nVersioneer provides two installation modes. The \"classic\" vendored mode installs\na copy of versioneer into your repository. The experimental build-time dependency mode\nis intended to allow you to skip this step and simplify the process of upgrading.\n\n### Vendored mode\n\n* `pip install versioneer` to somewhere in your $PATH\n   * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is\n     available, so you can also use `conda install -c conda-forge versioneer`\n* add a `[tool.versioneer]` section to your `pyproject.toml` or a\n  `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md))\n   * Note that you will need to add `tomli; python_version < \"3.11\"` to your\n     build-time dependencies if you use `pyproject.toml`\n* run `versioneer install --vendor` in your source tree, commit the results\n* verify version information with `python setup.py version`\n\n### Build-time dependency mode\n\n* `pip install versioneer` to somewhere in your $PATH\n   * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is\n     available, so you can also use `conda install -c conda-forge versioneer`\n* add a `[tool.versioneer]` section to your `pyproject.toml` or a\n  `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md))\n* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`)\n  to the `requires` key of the `build-system` table in `pyproject.toml`:\n  ```toml\n  [build-system]\n  requires = [\"setuptools\", \"versioneer[toml]\"]\n  build-backend = \"setuptools.build_meta\"\n  ```\n* run `versioneer install --no-vendor` in your source tree, commit the results\n* verify version information with `python setup.py version`\n\n## Version Identifiers\n\nSource trees come from a variety of places:\n\n* a version-control system checkout (mostly used by developers)\n* a nightly tarball, produced by build automation\n* a snapshot tarball, produced by a web-based VCS browser, like github's\n  \"tarball from tag\" feature\n* a release tarball, produced by \"setup.py sdist\", distributed through PyPI\n\nWithin each source tree, the version identifier (either a string or a number,\nthis tool is format-agnostic) can come from a variety of places:\n\n* ask the VCS tool itself, e.g. \"git describe\" (for checkouts), which knows\n  about recent \"tags\" and an absolute revision-id\n* the name of the directory into which the tarball was unpacked\n* an expanded VCS keyword ($Id$, etc)\n* a `_version.py` created by some earlier build step\n\nFor released software, the version identifier is closely related to a VCS\ntag. Some projects use tag names that include more than just the version\nstring (e.g. 
\"myproject-1.2\" instead of just \"1.2\"), in which case the tool\nneeds to strip the tag prefix to extract the version identifier. For\nunreleased software (between tags), the version identifier should provide\nenough information to help developers recreate the same tree, while also\ngiving them an idea of roughly how old the tree is (after version 1.2, before\nversion 1.3). Many VCS systems can report a description that captures this,\nfor example `git describe --tags --dirty --always` reports things like\n\"0.7-1-g574ab98-dirty\" to indicate that the checkout is one revision past the\n0.7 tag, has a unique revision id of \"574ab98\", and is \"dirty\" (it has\nuncommitted changes).\n\nThe version identifier is used for multiple purposes:\n\n* to allow the module to self-identify its version: `myproject.__version__`\n* to choose a name and prefix for a 'setup.py sdist' tarball\n\n## Theory of Operation\n\nVersioneer works by adding a special `_version.py` file into your source\ntree, where your `__init__.py` can import it. This `_version.py` knows how to\ndynamically ask the VCS tool for version information at import time.\n\n`_version.py` also contains `$Revision$` markers, and the installation\nprocess marks `_version.py` to have this marker rewritten with a tag name\nduring the `git archive` command. As a result, generated tarballs will\ncontain enough information to get the proper version.\n\nTo allow `setup.py` to compute a version too, a `versioneer.py` is added to\nthe top level of your source tree, next to `setup.py` and the `setup.cfg`\nthat configures it. This overrides several distutils/setuptools commands to\ncompute the version when invoked, and changes `setup.py build` and `setup.py\nsdist` to replace `_version.py` with a small static file that contains just\nthe generated version data.\n\n## Installation\n\nSee [INSTALL.md](./INSTALL.md) for detailed installation instructions.\n\n## Version-String Flavors\n\nCode which uses Versioneer can learn about its version string at runtime by\nimporting `_version` from your main `__init__.py` file and running the\n`get_versions()` function. From the \"outside\" (e.g. in `setup.py`), you can\nimport the top-level `versioneer.py` and run `get_versions()`.\n\nBoth functions return a dictionary with different flavors of version\ninformation:\n\n* `['version']`: A condensed version string, rendered using the selected\n  style. This is the most commonly used value for the project's version\n  string. The default \"pep440\" style yields strings like `0.11`,\n  `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the \"Styles\" section\n  below for alternative styles.\n\n* `['full-revisionid']`: detailed revision identifier. For Git, this is the\n  full SHA1 commit id, e.g. \"1076c978a8d3cfc70f408fe5974aa6c092c949ac\".\n\n* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the\n  commit date in ISO 8601 format. This will be None if the date is not\n  available.\n\n* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that\n  this is only accurate if run in a VCS checkout, otherwise it is likely to\n  be False or None\n\n* `['error']`: if the version string could not be computed, this will be set\n  to a string describing the problem, otherwise it will be None. It may be\n  useful to throw an exception in setup.py if this is set, to avoid e.g.\n  creating tarballs with a version string of \"unknown\".\n\nSome variants are more useful than others. 
Including `full-revisionid` in a\nbug report should allow developers to reconstruct the exact code being tested\n(or indicate the presence of local changes that should be shared with the\ndevelopers). `version` is suitable for display in an \"about\" box or a CLI\n`--version` output: it can be easily compared against release notes and lists\nof bugs fixed in various releases.\n\nThe installer adds the following text to your `__init__.py` to place a basic\nversion in `YOURPROJECT.__version__`:\n\n    from ._version import get_versions\n    __version__ = get_versions()['version']\n    del get_versions\n\n## Styles\n\nThe setup.cfg `style=` configuration controls how the VCS information is\nrendered into a version string.\n\nThe default style, \"pep440\", produces a PEP440-compliant string, equal to the\nun-prefixed tag name for actual releases, and containing an additional \"local\nversion\" section with more detail for in-between builds. For Git, this is\nTAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags\n--dirty --always`. For example \"0.11+2.g1076c97.dirty\" indicates that the\ntree is like the \"1076c97\" commit but has uncommitted changes (\".dirty\"), and\nthat this commit is two revisions (\"+2\") beyond the \"0.11\" tag. For released\nsoftware (exactly equal to a known tag), the identifier will only contain the\nstripped tag, e.g. \"0.11\".\n\nOther styles are available. See [details.md](details.md) in the Versioneer\nsource tree for descriptions.\n\n## Debugging\n\nVersioneer tries to avoid fatal errors: if something goes wrong, it will tend\nto return a version of \"0+unknown\". To investigate the problem, run `setup.py\nversion`, which will run the version-lookup code in a verbose mode, and will\ndisplay the full contents of `get_versions()` (including the `error` string,\nwhich may help identify what went wrong).\n\n## Known Limitations\n\nSome situations are known to cause problems for Versioneer. This section details the\nmost significant ones. More can be found on the GitHub\n[issues page](https://github.com/python-versioneer/python-versioneer/issues).\n\n### Subprojects\n\nVersioneer has limited support for source trees in which `setup.py` is not in\nthe root directory (e.g. `setup.py` and `.git/` are *not* siblings). There are\ntwo common reasons why `setup.py` might not be in the root:\n\n* Source trees which contain multiple subprojects, such as\n  [Buildbot](https://github.com/buildbot/buildbot), which contains both\n  \"master\" and \"slave\" subprojects, each with their own `setup.py`,\n  `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI\n  distributions (and upload multiple independently-installable tarballs).\n* Source trees whose main purpose is to contain a C library, but which also\n  provide bindings to Python (and perhaps other languages) in subdirectories.\n\nVersioneer will look for `.git` in parent directories, and most operations\nshould get the right version string. However `pip` and `setuptools` have bugs\nand implementation details which frequently cause `pip install .` from a\nsubproject directory to fail to find a correct version string (so it usually\ndefaults to `0+unknown`).\n\n`pip install --editable .` should work correctly. `setup.py install` might\nwork too.\n\nPip-8.1.1 is known to have this problem, but hopefully it will get fixed in\nsome later version.\n\n[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking\nthis issue. 
The discussion in\n[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the\nissue from the Versioneer side in more detail.\n[pip PR#3176](https://github.com/pypa/pip/pull/3176) and\n[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve\npip to let Versioneer work correctly.\n\nVersioneer-0.16 and earlier only looked for a `.git` directory next to the\n`setup.cfg`, so subprojects were completely unsupported with those releases.\n\n### Editable installs with setuptools <= 18.5\n\n`setup.py develop` and `pip install --editable .` allow you to install a\nproject into a virtualenv once, then continue editing the source code (and\ntest) without re-installing after every change.\n\n\"Entry-point scripts\" (`setup(entry_points={\"console_scripts\": ..})`) are a\nconvenient way to specify executable scripts that should be installed along\nwith the python package.\n\nThese both work as expected when using modern setuptools. When using\nsetuptools-18.5 or earlier, however, certain operations will cause\n`pkg_resources.DistributionNotFound` errors when running the entrypoint\nscript, which must be resolved by re-installing the package. This happens\nwhen the install happens with one version, then the egg_info data is\nregenerated while a different version is checked out. Many setup.py commands\ncause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into\na different virtualenv), so this can be surprising.\n\n[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes\nthis one, but upgrading to a newer version of setuptools should probably\nresolve it.\n\n\n## Updating Versioneer\n\nTo upgrade your project to a new release of Versioneer, do the following:\n\n* install the new Versioneer (`pip install -U versioneer` or equivalent)\n* edit `setup.cfg` and `pyproject.toml`, if necessary,\n  to include any new configuration settings indicated by the release notes.\n  See [UPGRADING](./UPGRADING.md) for details.\n* re-run `versioneer install --[no-]vendor` in your source tree, to replace\n  `SRC/_version.py`\n* commit any changed files\n\n## Future Directions\n\nThis tool is designed to make it easily extended to other version-control\nsystems: all VCS-specific components are in separate directories like\nsrc/git/ . The top-level `versioneer.py` script is assembled from these\ncomponents by running make-versioneer.py . In the future, make-versioneer.py\nwill take a VCS name as an argument, and will construct a version of\n`versioneer.py` that is specific to the given VCS. It might also take the\nconfiguration arguments that are currently provided manually during\ninstallation by editing setup.py . Alternatively, it might go the other\ndirection and include code from all supported VCS systems, reducing the\nnumber of intermediate scripts.\n\n## Similar projects\n\n* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time\n  dependency\n* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of\n  versioneer\n* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools\n  plugin\n\n## License\n\nTo make Versioneer easier to embed, all its code is dedicated to the public\ndomain. 
The `_version.py` that it creates is also in the public domain.\nSpecifically, both are released under the \"Unlicense\", as described in\nhttps://unlicense.org/.\n\n[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg\n[pypi-url]: https://pypi.python.org/pypi/versioneer/\n[travis-image]:\nhttps://img.shields.io/travis/com/python-versioneer/python-versioneer.svg\n[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer\n\n\"\"\"\n# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring\n# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements\n# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error\n# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with\n# pylint:disable=attribute-defined-outside-init,too-many-arguments\n\nimport configparser\nimport errno\nimport json\nimport os\nimport re\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union\nfrom typing import NoReturn\nimport functools\n\nhave_tomllib = True\nif sys.version_info >= (3, 11):\n    import tomllib\nelse:\n    try:\n        import tomli as tomllib\n    except ImportError:\n        have_tomllib = False\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n    VCS: str\n    style: str\n    tag_prefix: str\n    versionfile_source: str\n    versionfile_build: Optional[str]\n    parentdir_prefix: Optional[str]\n    verbose: Optional[bool]\n\n\ndef get_root() -> str:\n    \"\"\"Get the project root directory.\n\n    We require that all commands are run from the project root, i.e. the\n    directory that contains setup.py, setup.cfg, and versioneer.py .\n    \"\"\"\n    root = os.path.realpath(os.path.abspath(os.getcwd()))\n    setup_py = os.path.join(root, \"setup.py\")\n    pyproject_toml = os.path.join(root, \"pyproject.toml\")\n    versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (\n        os.path.exists(setup_py)\n        or os.path.exists(pyproject_toml)\n        or os.path.exists(versioneer_py)\n    ):\n        # allow 'python path/to/setup.py COMMAND'\n        root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))\n        setup_py = os.path.join(root, \"setup.py\")\n        pyproject_toml = os.path.join(root, \"pyproject.toml\")\n        versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (\n        os.path.exists(setup_py)\n        or os.path.exists(pyproject_toml)\n        or os.path.exists(versioneer_py)\n    ):\n        err = (\"Versioneer was unable to find the project root directory. \"\n               \"Versioneer requires setup.py to be executed from \"\n               \"its immediate directory (like 'python setup.py COMMAND'), \"\n               \"or in a way that lets it use sys.argv[0] to find the root \"\n               \"(like 'python path/to/setup.py COMMAND').\")\n        raise VersioneerBadRootError(err)\n    try:\n        # Certain runtime workflows (setup.py install/develop in a setuptools\n        # tree) execute all dependencies in a single python process, so\n        # \"versioneer\" may be imported multiple times, and python's shared\n        # module-import table will cache the first one. 
So we can't use\n        # os.path.dirname(__file__), as that will find whichever\n        # versioneer.py was first imported, even in later projects.\n        my_path = os.path.realpath(os.path.abspath(__file__))\n        me_dir = os.path.normcase(os.path.splitext(my_path)[0])\n        vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])\n        if me_dir != vsr_dir and \"VERSIONEER_PEP518\" not in globals():\n            print(\"Warning: build in %s is using versioneer.py from %s\"\n                  % (os.path.dirname(my_path), versioneer_py))\n    except NameError:\n        pass\n    return root\n\n\ndef get_config_from_root(root: str) -> VersioneerConfig:\n    \"\"\"Read the project setup.cfg file to determine Versioneer config.\"\"\"\n    # This might raise OSError (if setup.cfg is missing), or\n    # configparser.NoSectionError (if it lacks a [versioneer] section), or\n    # configparser.NoOptionError (if it lacks \"VCS=\"). See the docstring at\n    # the top of versioneer.py for instructions on writing your setup.cfg .\n    root_pth = Path(root)\n    pyproject_toml = root_pth / \"pyproject.toml\"\n    setup_cfg = root_pth / \"setup.cfg\"\n    section: Union[Dict[str, Any], configparser.SectionProxy, None] = None\n    if pyproject_toml.exists() and have_tomllib:\n        try:\n            with open(pyproject_toml, 'rb') as fobj:\n                pp = tomllib.load(fobj)\n            section = pp['tool']['versioneer']\n        except (tomllib.TOMLDecodeError, KeyError) as e:\n            print(f\"Failed to load config from {pyproject_toml}: {e}\")\n            print(\"Trying to load it from setup.cfg\")\n    if not section:\n        parser = configparser.ConfigParser()\n        with open(setup_cfg) as cfg_file:\n            parser.read_file(cfg_file)\n        parser.get(\"versioneer\", \"VCS\")  # raise error if missing\n\n        section = parser[\"versioneer\"]\n\n    # `cast` really shouldn't be used, but it's simplest for the\n    # common VersioneerConfig users at the moment. 
We verify against\n    # `None` values elsewhere where it matters\n\n    cfg = VersioneerConfig()\n    cfg.VCS = section['VCS']\n    cfg.style = section.get(\"style\", \"\")\n    cfg.versionfile_source = cast(str, section.get(\"versionfile_source\"))\n    cfg.versionfile_build = section.get(\"versionfile_build\")\n    cfg.tag_prefix = cast(str, section.get(\"tag_prefix\"))\n    if cfg.tag_prefix in (\"''\", '\"\"', None):\n        cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = section.get(\"parentdir_prefix\")\n    if isinstance(section, configparser.SectionProxy):\n        # Make sure configparser translates to bool\n        cfg.verbose = section.getboolean(\"verbose\")\n    else:\n        cfg.verbose = section.get(\"verbose\")\n\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\n# these dictionaries contain VCS-specific tools\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs: str, method: str) -> Callable:  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n    def decorate(f: Callable) -> Callable:\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        HANDLERS.setdefault(vcs, {})[method] = f\n        return f\n    return decorate\n\n\ndef run_command(\n    commands: List[str],\n    args: List[str],\n    cwd: Optional[str] = None,\n    verbose: bool = False,\n    hide_stderr: bool = False,\n    env: Optional[Dict[str, str]] = None,\n) -> Tuple[Optional[str], Optional[int]]:\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs: Dict[str, Any] = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen([command] + args, cwd=cwd, env=env,\n                                       stdout=subprocess.PIPE,\n                                       stderr=(subprocess.PIPE if hide_stderr\n                                               else None), **popen_kwargs)\n            break\n        except OSError as e:\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, process.returncode\n    return stdout, process.returncode\n\n\nLONG_VERSION_PY['git'] = r'''\n# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). 
Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain.\n# Generated by versioneer-0.29\n# https://github.com/python-versioneer/python-versioneer\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\nfrom typing import Any, Callable, Dict, List, Optional, Tuple\nimport functools\n\n\ndef get_keywords() -> Dict[str, str]:\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. _version.py will just call\n    # get_keywords().\n    git_refnames = \"%(DOLLAR)sFormat:%%d%(DOLLAR)s\"\n    git_full = \"%(DOLLAR)sFormat:%%H%(DOLLAR)s\"\n    git_date = \"%(DOLLAR)sFormat:%%ci%(DOLLAR)s\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n    VCS: str\n    style: str\n    tag_prefix: str\n    parentdir_prefix: str\n    versionfile_source: str\n    verbose: bool\n\n\ndef get_config() -> VersioneerConfig:\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"%(STYLE)s\"\n    cfg.tag_prefix = \"%(TAG_PREFIX)s\"\n    cfg.parentdir_prefix = \"%(PARENTDIR_PREFIX)s\"\n    cfg.versionfile_source = \"%(VERSIONFILE_SOURCE)s\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs: str, method: str) -> Callable:  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n    def decorate(f: Callable) -> Callable:\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(\n    commands: List[str],\n    args: List[str],\n    cwd: Optional[str] = None,\n    verbose: bool = False,\n    hide_stderr: bool = False,\n    env: Optional[Dict[str, str]] = None,\n) -> Tuple[Optional[str], Optional[int]]:\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs: Dict[str, Any] = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen([command] + args, cwd=cwd, env=env,\n                                       stdout=subprocess.PIPE,\n                                       stderr=(subprocess.PIPE if hide_stderr\n                                               
else None), **popen_kwargs)\n            break\n        except OSError as e:\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %%s\" %% dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %%s\" %% (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %%s (error)\" %% dispcmd)\n            print(\"stdout was %%s\" %% stdout)\n        return None, process.returncode\n    return stdout, process.returncode\n\n\ndef versions_from_parentdir(\n    parentdir_prefix: str,\n    root: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %%s but none started with prefix %%s\" %%\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs: str) -> Dict[str, str]:\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. This function is not used from\n    # _version.py.\n    keywords: Dict[str, str] = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(\n    keywords: Dict[str, str],\n    tag_prefix: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  
Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. However we prefer \"%%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %%d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r'\\d', r)}\n        if verbose:\n            print(\"discarding '%%s', no digits\" %% \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %%s\" %% \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r'\\d', r):\n                continue\n            if verbose:\n                print(\"picking %%s\" %% r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(\n    tag_prefix: str,\n    root: str,\n    verbose: bool,\n    runner: Callable = run_command\n) -> Dict[str, Any]:\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                   hide_stderr=not verbose)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %%s not under git control\" %% root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(GITS, [\n        \"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\",\n        \"--match\", f\"{tag_prefix}[[:digit:]]*\"\n    ], cwd=root)\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces: Dict[str, Any] = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"],\n                             cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. 
If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparsable. Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%%s'\"\n                               %% describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%%s' doesn't start with prefix '%%s'\"\n                print(fmt %% (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%%s' doesn't start with prefix '%%s'\"\n                               %% (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces: Dict[str, Any]) -> str:\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces: Dict[str, Any]) -> str:\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%%d.g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%%d.g%%s\" %% (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%%d.g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%%d.g%%s\" %% (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or None if no post-release segment is present).\n    \"\"\"\n    vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%%d.dev%%d\" %% (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%%d\" %% (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%%d\" %% pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyway.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%%s\" %% pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%%s\" %% pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%%s\" %% pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%%s\" %% pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always --long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%%s'\" %% style)\n\n    return {\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions() -> Dict[str, Any]:\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. 
Invert\n        # this to find the root from __file__.\n        for _ in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n'''\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs: str) -> Dict[str, str]:\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. This function is not used from\n    # _version.py.\n    keywords: Dict[str, str] = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(\n    keywords: Dict[str, str],\n    tag_prefix: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r'\\d', r)}\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. \"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r'\\d', r):\n                continue\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(\n    tag_prefix: str,\n    root: str,\n    verbose: bool,\n    runner: Callable = run_command\n) -> Dict[str, Any]:\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our 
version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                   hide_stderr=not verbose)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(GITS, [\n        \"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\",\n        \"--match\", f\"{tag_prefix}[[:digit:]]*\"\n    ], cwd=root)\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces: Dict[str, Any] = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"],\n                             cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparsable. 
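(A well-formed string looks like\n            # \"v0.3.2-3-g1234567\" -- hypothetical tag -- where the regex above\n            # captures \"v0.3.2\", \"3\", and \"1234567\".) 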
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%s'\"\n                               % describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%s' doesn't start with prefix '%s'\"\n                               % (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None:\n    \"\"\"Git-specific installation logic for Versioneer.\n\n    For Git, this means creating/changing .gitattributes to mark _version.py\n    for export-subst keyword substitution.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n    files = [versionfile_source]\n    if ipy:\n        files.append(ipy)\n    if \"VERSIONEER_PEP518\" not in globals():\n        try:\n            my_path = __file__\n            if my_path.endswith((\".pyc\", \".pyo\")):\n                my_path = os.path.splitext(my_path)[0] + \".py\"\n            versioneer_file = os.path.relpath(my_path)\n        except NameError:\n            versioneer_file = \"versioneer.py\"\n        files.append(versioneer_file)\n    present = False\n    try:\n        with open(\".gitattributes\", \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(versionfile_source):\n                    if \"export-subst\" in line.strip().split()[1:]:\n                        present = True\n                        break\n    except OSError:\n        pass\n    if not present:\n        with open(\".gitattributes\", \"a+\") as fobj:\n            fobj.write(f\"{versionfile_source} export-subst\\n\")\n        files.append(\".gitattributes\")\n    run_command(GITS, [\"add\", \"--\"] + files)\n\n\ndef versions_from_parentdir(\n    parentdir_prefix: str,\n    root: str,\n    verbose: bool,\n) -> Dict[str, Any]:\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. 
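(E.g. with parentdir_prefix\n    \"myproject-\", an unpacked \"myproject-1.2.0/\" directory yields version\n    \"1.2.0\"; hypothetical names.) 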
We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" %\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\nSHORT_VERSION_PY = \"\"\"\n# This file was generated by 'versioneer.py' (0.29) from\n# revision-control system data, or from the parent directory name of an\n# unpacked source archive. Distribution tarballs contain a pre-generated copy\n# of this file.\n\nimport json\n\nversion_json = '''\n%s\n'''  # END VERSION_JSON\n\n\ndef get_versions():\n    return json.loads(version_json)\n\"\"\"\n\n\ndef versions_from_file(filename: str) -> Dict[str, Any]:\n    \"\"\"Try to determine the version from _version.py if present.\"\"\"\n    try:\n        with open(filename) as f:\n            contents = f.read()\n    except OSError:\n        raise NotThisMethod(\"unable to read _version.py\")\n    mo = re.search(r\"version_json = '''\\n(.*)'''  # END VERSION_JSON\",\n                   contents, re.M | re.S)\n    if not mo:\n        mo = re.search(r\"version_json = '''\\r\\n(.*)'''  # END VERSION_JSON\",\n                       contents, re.M | re.S)\n    if not mo:\n        raise NotThisMethod(\"no version_json in _version.py\")\n    return json.loads(mo.group(1))\n\n\ndef write_to_version_file(filename: str, versions: Dict[str, Any]) -> None:\n    \"\"\"Write the given version number to the given _version.py file.\"\"\"\n    contents = json.dumps(versions, sort_keys=True,\n                          indent=1, separators=(\",\", \": \"))\n    with open(filename, \"w\") as f:\n        f.write(SHORT_VERSION_PY % contents)\n\n    print(\"set %s to '%s'\" % (filename, versions[\"version\"]))\n\n\ndef plus_or_dot(pieces: Dict[str, Any]) -> str:\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces: Dict[str, Any]) -> str:\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
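(Worked example with hypothetical values:\n    with DISTANCE 42, HEX \"1234567\", and a clean tree, this renders as\n    \"0+untagged.42.g1234567\".) The format is: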
\n    0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or None if no post-release segment is present).\n    \"\"\"\n    vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%d.dev%d\" % (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%d\" % (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyway.\n\n    Exceptions:\n    1: no tags. 
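(Worked example with hypothetical values:\n    42 total commits, short hash \"1234567\", and a clean tree render as\n    \"0.post42+g1234567\".) The format is: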
\n    0.postDISTANCE[.dev0]+gHEX\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces: Dict[str, Any]) -> str:\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always --long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. 
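(Worked example with a hypothetical short hash:\n    \"1234567\" on a dirty tree renders as \"1234567-dirty\".) The format is: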
\n    HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\nclass VersioneerBadRootError(Exception):\n    \"\"\"The project root directory is unknown or missing key files.\"\"\"\n\n\ndef get_versions(verbose: bool = False) -> Dict[str, Any]:\n    \"\"\"Get the project version from whatever source is available.\n\n    Returns a dict with keys 'version', 'full-revisionid', 'dirty', 'error',\n    and 'date'.\n    \"\"\"\n    if \"versioneer\" in sys.modules:\n        # see the discussion in cmdclass.py:get_cmdclass()\n        del sys.modules[\"versioneer\"]\n\n    root = get_root()\n    cfg = get_config_from_root(root)\n\n    assert cfg.VCS is not None, \"please set [versioneer]VCS= in setup.cfg\"\n    handlers = HANDLERS.get(cfg.VCS)\n    assert handlers, \"unrecognized VCS '%s'\" % cfg.VCS\n    verbose = verbose or bool(cfg.verbose)  # `bool()` used to avoid `None`\n    assert cfg.versionfile_source is not None, \\\n        \"please set versioneer.versionfile_source\"\n    assert cfg.tag_prefix is not None, \"please set versioneer.tag_prefix\"\n\n    versionfile_abs = os.path.join(root, cfg.versionfile_source)\n\n    # extract version from first of: _version.py, VCS command (e.g. 'git\n    # describe'), parentdir. 
This is meant to work for developers using a\n    # source checkout, for users of a tarball created by 'setup.py sdist',\n    # and for users of a tarball/zipball created by 'git archive' or github's\n    # download-from-tag feature or the equivalent in other VCSes.\n\n    get_keywords_f = handlers.get(\"get_keywords\")\n    from_keywords_f = handlers.get(\"keywords\")\n    if get_keywords_f and from_keywords_f:\n        try:\n            keywords = get_keywords_f(versionfile_abs)\n            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)\n            if verbose:\n                print(\"got version from expanded keyword %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        ver = versions_from_file(versionfile_abs)\n        if verbose:\n            print(\"got version from file %s %s\" % (versionfile_abs, ver))\n        return ver\n    except NotThisMethod:\n        pass\n\n    from_vcs_f = handlers.get(\"pieces_from_vcs\")\n    if from_vcs_f:\n        try:\n            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)\n            ver = render(pieces, cfg.style)\n            if verbose:\n                print(\"got version from VCS %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        if cfg.parentdir_prefix:\n            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n            if verbose:\n                print(\"got version from parentdir %s\" % ver)\n            return ver\n    except NotThisMethod:\n        pass\n\n    if verbose:\n        print(\"unable to compute version\")\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None, \"error\": \"unable to compute version\",\n            \"date\": None}\n\n\ndef get_version() -> str:\n    \"\"\"Get the short version string for this project.\"\"\"\n    return get_versions()[\"version\"]\n\n\ndef get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None):\n    \"\"\"Get the custom setuptools subclasses used by Versioneer.\n\n    If the package uses a different cmdclass (e.g. one from numpy), it\n    should be provided as an argument.\n    \"\"\"\n    if \"versioneer\" in sys.modules:\n        del sys.modules[\"versioneer\"]\n        # this fixes the \"python setup.py develop\" case (also 'install' and\n        # 'easy_install .'), in which subdependencies of the main project are\n        # built (using setup.py bdist_egg) in the same python process. Assume\n        # a main project A and a dependency B, which use different versions\n        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in\n        # sys.modules by the time B's setup.py is executed, causing B to run\n        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a\n        # sandbox that restores sys.modules to its pre-build state, so the\n        # parent is protected against the child's \"import versioneer\". 
By\n        # removing ourselves from sys.modules here, before the child build\n        # happens, we protect the child from the parent's versioneer too.\n        # Also see https://github.com/python-versioneer/python-versioneer/issues/52\n\n    cmds = {} if cmdclass is None else cmdclass.copy()\n\n    # we add \"version\" to setuptools\n    from setuptools import Command\n\n    class cmd_version(Command):\n        description = \"report generated version string\"\n        user_options: List[Tuple[str, str, str]] = []\n        boolean_options: List[str] = []\n\n        def initialize_options(self) -> None:\n            pass\n\n        def finalize_options(self) -> None:\n            pass\n\n        def run(self) -> None:\n            vers = get_versions(verbose=True)\n            print(\"Version: %s\" % vers[\"version\"])\n            print(\" full-revisionid: %s\" % vers.get(\"full-revisionid\"))\n            print(\" dirty: %s\" % vers.get(\"dirty\"))\n            print(\" date: %s\" % vers.get(\"date\"))\n            if vers[\"error\"]:\n                print(\" error: %s\" % vers[\"error\"])\n    cmds[\"version\"] = cmd_version\n\n    # we override \"build_py\" in setuptools\n    #\n    # most invocation pathways end up running build_py:\n    #  distutils/build -> build_py\n    #  distutils/install -> distutils/build ->..\n    #  setuptools/bdist_wheel -> distutils/install ->..\n    #  setuptools/bdist_egg -> distutils/install_lib -> build_py\n    #  setuptools/install -> bdist_egg ->..\n    #  setuptools/develop -> ?\n    #  pip install:\n    #   copies source tree to a tempdir before running egg_info/etc\n    #   if .git isn't copied too, 'git describe' will fail\n    #   then does setup.py bdist_wheel, or sometimes setup.py install\n    #  setup.py egg_info -> ?\n\n    # pip install -e . 
and setuptools/editable_wheel will invoke build_py\n    # but the build_py command is not expected to copy any files.\n\n    # we override different \"build_py\" commands for both environments\n    if 'build_py' in cmds:\n        _build_py: Any = cmds['build_py']\n    else:\n        from setuptools.command.build_py import build_py as _build_py\n\n    class cmd_build_py(_build_py):\n        def run(self) -> None:\n            root = get_root()\n            cfg = get_config_from_root(root)\n            versions = get_versions()\n            _build_py.run(self)\n            if getattr(self, \"editable_mode\", False):\n                # During editable installs `.py` and data files are\n                # not copied to build_lib\n                return\n            # now locate _version.py in the new build/ directory and replace\n            # it with an updated value\n            if cfg.versionfile_build:\n                target_versionfile = os.path.join(self.build_lib,\n                                                  cfg.versionfile_build)\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n    cmds[\"build_py\"] = cmd_build_py\n\n    if 'build_ext' in cmds:\n        _build_ext: Any = cmds['build_ext']\n    else:\n        from setuptools.command.build_ext import build_ext as _build_ext\n\n    class cmd_build_ext(_build_ext):\n        def run(self) -> None:\n            root = get_root()\n            cfg = get_config_from_root(root)\n            versions = get_versions()\n            _build_ext.run(self)\n            if self.inplace:\n                # build_ext --inplace will only build extensions in\n                # build/lib<..> dir with no _version.py to write to.\n                # As in place builds will already have a _version.py\n                # in the module dir, we do not need to write one.\n                return\n            # now locate _version.py in the new build/ directory and replace\n            # it with an updated value\n            if not cfg.versionfile_build:\n                return\n            target_versionfile = os.path.join(self.build_lib,\n                                              cfg.versionfile_build)\n            if not os.path.exists(target_versionfile):\n                print(f\"Warning: {target_versionfile} does not exist, skipping \"\n                      \"version update. 
This can happen if you are running build_ext \"\n                      \"without first running build_py.\")\n                return\n            print(\"UPDATING %s\" % target_versionfile)\n            write_to_version_file(target_versionfile, versions)\n    cmds[\"build_ext\"] = cmd_build_ext\n\n    if \"cx_Freeze\" in sys.modules:  # cx_freeze enabled?\n        from cx_Freeze.dist import build_exe as _build_exe  # type: ignore\n        # nczeczulin reports that py2exe won't like the pep440-style string\n        # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.\n        # setup(console=[{\n        #   \"version\": versioneer.get_version().split(\"+\", 1)[0], # FILEVERSION\n        #   \"product_version\": versioneer.get_version(),\n        #   ...\n\n        class cmd_build_exe(_build_exe):\n            def run(self) -> None:\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _build_exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(LONG %\n                            {\"DOLLAR\": \"$\",\n                             \"STYLE\": cfg.style,\n                             \"TAG_PREFIX\": cfg.tag_prefix,\n                             \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                             \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                             })\n        cmds[\"build_exe\"] = cmd_build_exe\n        del cmds[\"build_py\"]\n\n    if 'py2exe' in sys.modules:  # py2exe enabled?\n        try:\n            from py2exe.setuptools_buildexe import py2exe as _py2exe  # type: ignore\n        except ImportError:\n            from py2exe.distutils_buildexe import py2exe as _py2exe  # type: ignore\n\n        class cmd_py2exe(_py2exe):\n            def run(self) -> None:\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _py2exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(LONG %\n                            {\"DOLLAR\": \"$\",\n                             \"STYLE\": cfg.style,\n                             \"TAG_PREFIX\": cfg.tag_prefix,\n                             \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                             \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                             })\n        cmds[\"py2exe\"] = cmd_py2exe\n\n    # sdist farms its file list building out to egg_info\n    if 'egg_info' in cmds:\n        _egg_info: Any = cmds['egg_info']\n    else:\n        from setuptools.command.egg_info import egg_info as _egg_info\n\n    class cmd_egg_info(_egg_info):\n        def find_sources(self) -> None:\n            # egg_info.find_sources builds the manifest list and writes it\n            # in one shot\n            super().find_sources()\n\n 
           # Modify the filelist and normalize it\n            root = get_root()\n            cfg = get_config_from_root(root)\n            self.filelist.append('versioneer.py')\n            if cfg.versionfile_source:\n                # There are rare cases where versionfile_source might not be\n                # included by default, so we must be explicit\n                self.filelist.append(cfg.versionfile_source)\n            self.filelist.sort()\n            self.filelist.remove_duplicates()\n\n            # The write method is hidden in the manifest_maker instance that\n            # generated the filelist and was thrown away\n            # We will instead replicate their final normalization (to unicode,\n            # and POSIX-style paths)\n            from setuptools import unicode_utils\n            normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')\n                          for f in self.filelist.files]\n\n            manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')\n            with open(manifest_filename, 'w') as fobj:\n                fobj.write('\\n'.join(normalized))\n\n    cmds['egg_info'] = cmd_egg_info\n\n    # we override different \"sdist\" commands for both environments\n    if 'sdist' in cmds:\n        _sdist: Any = cmds['sdist']\n    else:\n        from setuptools.command.sdist import sdist as _sdist\n\n    class cmd_sdist(_sdist):\n        def run(self) -> None:\n            versions = get_versions()\n            self._versioneer_generated_versions = versions\n            # unless we update this, the command will keep using the old\n            # version\n            self.distribution.metadata.version = versions[\"version\"]\n            return _sdist.run(self)\n\n        def make_release_tree(self, base_dir: str, files: List[str]) -> None:\n            root = get_root()\n            cfg = get_config_from_root(root)\n            _sdist.make_release_tree(self, base_dir, files)\n            # now locate _version.py in the new base_dir directory\n            # (remembering that it may be a hardlink) and replace it with an\n            # updated value\n            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)\n            print(\"UPDATING %s\" % target_versionfile)\n            write_to_version_file(target_versionfile,\n                                  self._versioneer_generated_versions)\n    cmds[\"sdist\"] = cmd_sdist\n\n    return cmds\n\n\nCONFIG_ERROR = \"\"\"\nsetup.cfg is missing the necessary Versioneer configuration. You need\na section like:\n\n [versioneer]\n VCS = git\n style = pep440\n versionfile_source = src/myproject/_version.py\n versionfile_build = myproject/_version.py\n tag_prefix =\n parentdir_prefix = myproject-\n\nYou will also need to edit your setup.py to use the results:\n\n import versioneer\n setup(version=versioneer.get_version(),\n       cmdclass=versioneer.get_cmdclass(), ...)\n\nPlease read the docstring in ./versioneer.py for configuration instructions,\nedit setup.cfg, and re-run the installer or 'python versioneer.py setup'.\n\"\"\"\n\nSAMPLE_CONFIG = \"\"\"\n# See the docstring in versioneer.py for instructions. 
Note that you must\n# re-run 'versioneer.py setup' after changing this section, and commit the\n# resulting files.\n\n[versioneer]\n#VCS = git\n#style = pep440\n#versionfile_source =\n#versionfile_build =\n#tag_prefix =\n#parentdir_prefix =\n\n\"\"\"\n\nOLD_SNIPPET = \"\"\"\nfrom ._version import get_versions\n__version__ = get_versions()['version']\ndel get_versions\n\"\"\"\n\nINIT_PY_SNIPPET = \"\"\"\nfrom . import {0}\n__version__ = {0}.get_versions()['version']\n\"\"\"\n\n\ndef do_setup() -> int:\n    \"\"\"Do main VCS-independent setup function for installing Versioneer.\"\"\"\n    root = get_root()\n    try:\n        cfg = get_config_from_root(root)\n    except (OSError, configparser.NoSectionError,\n            configparser.NoOptionError) as e:\n        if isinstance(e, (OSError, configparser.NoSectionError)):\n            print(\"Adding sample versioneer config to setup.cfg\",\n                  file=sys.stderr)\n            with open(os.path.join(root, \"setup.cfg\"), \"a\") as f:\n                f.write(SAMPLE_CONFIG)\n        print(CONFIG_ERROR, file=sys.stderr)\n        return 1\n\n    print(\" creating %s\" % cfg.versionfile_source)\n    with open(cfg.versionfile_source, \"w\") as f:\n        LONG = LONG_VERSION_PY[cfg.VCS]\n        f.write(LONG % {\"DOLLAR\": \"$\",\n                        \"STYLE\": cfg.style,\n                        \"TAG_PREFIX\": cfg.tag_prefix,\n                        \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                        \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                        })\n\n    ipy = os.path.join(os.path.dirname(cfg.versionfile_source),\n                       \"__init__.py\")\n    maybe_ipy: Optional[str] = ipy\n    if os.path.exists(ipy):\n        try:\n            with open(ipy, \"r\") as f:\n                old = f.read()\n        except OSError:\n            old = \"\"\n        module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]\n        snippet = INIT_PY_SNIPPET.format(module)\n        if OLD_SNIPPET in old:\n            print(\" replacing boilerplate in %s\" % ipy)\n            with open(ipy, \"w\") as f:\n                f.write(old.replace(OLD_SNIPPET, snippet))\n        elif snippet not in old:\n            print(\" appending to %s\" % ipy)\n            with open(ipy, \"a\") as f:\n                f.write(snippet)\n        else:\n            print(\" %s unmodified\" % ipy)\n    else:\n        print(\" %s doesn't exist, ok\" % ipy)\n        maybe_ipy = None\n\n    # Make VCS-specific changes. 
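(With versionfile_source \"src/myproject/_version.py\" --\n    # hypothetical path -- do_vcs_install() appends\n    # \"src/myproject/_version.py export-subst\" to .gitattributes.) 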
For git, this means creating/changing\n    # .gitattributes to mark _version.py for export-subst keyword\n    # substitution.\n    do_vcs_install(cfg.versionfile_source, maybe_ipy)\n    return 0\n\n\ndef scan_setup_py() -> int:\n    \"\"\"Validate the contents of setup.py against Versioneer's expectations.\"\"\"\n    found = set()\n    setters = False\n    errors = 0\n    with open(\"setup.py\", \"r\") as f:\n        for line in f.readlines():\n            if \"import versioneer\" in line:\n                found.add(\"import\")\n            if \"versioneer.get_cmdclass()\" in line:\n                found.add(\"cmdclass\")\n            if \"versioneer.get_version()\" in line:\n                found.add(\"get_version\")\n            if \"versioneer.VCS\" in line:\n                setters = True\n            if \"versioneer.versionfile_source\" in line:\n                setters = True\n    if len(found) != 3:\n        print(\"\")\n        print(\"Your setup.py appears to be missing some important items\")\n        print(\"(but I might be wrong). Please make sure it has something\")\n        print(\"roughly like the following:\")\n        print(\"\")\n        print(\" import versioneer\")\n        print(\" setup( version=versioneer.get_version(),\")\n        print(\"        cmdclass=versioneer.get_cmdclass(),  ...)\")\n        print(\"\")\n        errors += 1\n    if setters:\n        print(\"You should remove lines like 'versioneer.VCS = ' and\")\n        print(\"'versioneer.versionfile_source = ' . This configuration\")\n        print(\"now lives in setup.cfg, and should be removed from setup.py\")\n        print(\"\")\n        errors += 1\n    return errors\n\n\ndef setup_command() -> NoReturn:\n    \"\"\"Set up Versioneer and exit with appropriate error code.\"\"\"\n    errors = do_setup()\n    errors += scan_setup_py()\n    sys.exit(1 if errors else 0)\n\n\nif __name__ == \"__main__\":\n    cmd = sys.argv[1]\n    if cmd == \"setup\":\n        setup_command()\n"
  }
]