Repository: google-deepmind/alphafold3
Branch: main
Commit: 608edb684db9
Files: 186
Total size: 7.4 MB
Directory structure:
gitextract_op0f7t9l/
├── .github/
│ └── workflows/
│ └── ci.yaml
├── CMakeLists.txt
├── CONTRIBUTING.md
├── LICENSE
├── OUTPUT_TERMS_OF_USE.md
├── README.md
├── WEIGHTS_PROHIBITED_USE_POLICY.md
├── WEIGHTS_TERMS_OF_USE.md
├── docker/
│ ├── Dockerfile
│ ├── dockerignore
│ └── jackhmmer_seq_limit.patch
├── docs/
│ ├── community_tools.md
│ ├── contributing.md
│ ├── input.md
│ ├── installation.md
│ ├── known_issues.md
│ ├── metadata_antibody_antigen.csv
│ ├── metadata_antibody_antigen.md
│ ├── model_parameters.md
│ ├── output.md
│ └── performance.md
├── fetch_databases.sh
├── legal/
│ ├── WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md
│ ├── WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md
│ ├── WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md
│ ├── WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md
│ ├── WEIGHTS_TERMS_OF_USE-Bahasa-Indonesia.md
│ ├── WEIGHTS_TERMS_OF_USE-Espanol-Latinoamerica.md
│ ├── WEIGHTS_TERMS_OF_USE-Francais-Canada.md
│ └── WEIGHTS_TERMS_OF_USE-Portugues-Brazil.md
├── pyproject.toml
├── run_alphafold.py
├── run_alphafold_data_test.py
├── run_alphafold_test.py
└── src/
└── alphafold3/
├── __init__.py
├── build_data.py
├── common/
│ ├── base_config.py
│ ├── folding_input.py
│ ├── resources.py
│ ├── safe_pickle.py
│ └── testing/
│ └── data.py
├── constants/
│ ├── atom_types.py
│ ├── chemical_component_sets.py
│ ├── chemical_components.py
│ ├── converters/
│ │ ├── ccd_pickle_gen.py
│ │ └── chemical_component_sets_gen.py
│ ├── mmcif_names.py
│ ├── periodic_table.py
│ ├── residue_names.py
│ └── side_chains.py
├── cpp.cc
├── data/
│ ├── cpp/
│ │ ├── msa_profile_pybind.cc
│ │ └── msa_profile_pybind.h
│ ├── featurisation.py
│ ├── msa.py
│ ├── msa_config.py
│ ├── msa_features.py
│ ├── msa_identifiers.py
│ ├── parsers.py
│ ├── pipeline.py
│ ├── structure_stores.py
│ ├── template_realign.py
│ ├── templates.py
│ └── tools/
│ ├── hmmalign.py
│ ├── hmmbuild.py
│ ├── hmmsearch.py
│ ├── jackhmmer.py
│ ├── msa_tool.py
│ ├── nhmmer.py
│ ├── rdkit_utils.py
│ ├── shards.py
│ └── subprocess_utils.py
├── jax/
│ └── geometry/
│ ├── __init__.py
│ ├── rigid_matrix_vector.py
│ ├── rotation_matrix.py
│ ├── struct_of_array.py
│ ├── utils.py
│ └── vector.py
├── model/
│ ├── atom_layout/
│ │ └── atom_layout.py
│ ├── components/
│ │ ├── haiku_modules.py
│ │ ├── mapping.py
│ │ └── utils.py
│ ├── confidence_types.py
│ ├── confidences.py
│ ├── data3.py
│ ├── data_constants.py
│ ├── feat_batch.py
│ ├── features.py
│ ├── merging_features.py
│ ├── mkdssp_pybind.cc
│ ├── mkdssp_pybind.h
│ ├── mmcif_metadata.py
│ ├── model.py
│ ├── model_config.py
│ ├── msa_pairing.py
│ ├── network/
│ │ ├── atom_cross_attention.py
│ │ ├── confidence_head.py
│ │ ├── diffusion_head.py
│ │ ├── diffusion_transformer.py
│ │ ├── distogram_head.py
│ │ ├── evoformer.py
│ │ ├── featurization.py
│ │ ├── modules.py
│ │ ├── noise_level_embeddings.py
│ │ └── template_modules.py
│ ├── params.py
│ ├── pipeline/
│ │ ├── inter_chain_bonds.py
│ │ ├── pipeline.py
│ │ └── structure_cleaning.py
│ ├── post_processing.py
│ ├── protein_data_processing.py
│ └── scoring/
│ ├── alignment.py
│ ├── chirality.py
│ ├── covalent_bond_cleaning.py
│ └── scoring.py
├── parsers/
│ └── cpp/
│ ├── cif_dict.pyi
│ ├── cif_dict_lib.cc
│ ├── cif_dict_lib.h
│ ├── cif_dict_pybind.cc
│ ├── cif_dict_pybind.h
│ ├── fasta_iterator.pyi
│ ├── fasta_iterator_lib.cc
│ ├── fasta_iterator_lib.h
│ ├── fasta_iterator_pybind.cc
│ ├── fasta_iterator_pybind.h
│ ├── msa_conversion.pyi
│ ├── msa_conversion_pybind.cc
│ └── msa_conversion_pybind.h
├── scripts/
│ ├── copy_to_ssd.sh
│ └── gcp_mount_ssd.sh
├── structure/
│ ├── __init__.py
│ ├── bioassemblies.py
│ ├── bonds.py
│ ├── chemical_components.py
│ ├── cpp/
│ │ ├── aggregation.pyi
│ │ ├── aggregation_pybind.cc
│ │ ├── aggregation_pybind.h
│ │ ├── membership.pyi
│ │ ├── membership_pybind.cc
│ │ ├── membership_pybind.h
│ │ ├── mmcif_altlocs.cc
│ │ ├── mmcif_altlocs.h
│ │ ├── mmcif_atom_site.pyi
│ │ ├── mmcif_atom_site_pybind.cc
│ │ ├── mmcif_atom_site_pybind.h
│ │ ├── mmcif_layout.h
│ │ ├── mmcif_layout.pyi
│ │ ├── mmcif_layout_lib.cc
│ │ ├── mmcif_layout_pybind.cc
│ │ ├── mmcif_layout_pybind.h
│ │ ├── mmcif_struct_conn.h
│ │ ├── mmcif_struct_conn.pyi
│ │ ├── mmcif_struct_conn_lib.cc
│ │ ├── mmcif_struct_conn_pybind.cc
│ │ ├── mmcif_struct_conn_pybind.h
│ │ ├── mmcif_utils.pyi
│ │ ├── mmcif_utils_pybind.cc
│ │ ├── mmcif_utils_pybind.h
│ │ ├── string_array.pyi
│ │ ├── string_array_pybind.cc
│ │ └── string_array_pybind.h
│ ├── mmcif.py
│ ├── parsing.py
│ ├── sterics.py
│ ├── structure.py
│ ├── structure_tables.py
│ ├── table.py
│ └── test_utils.py
├── test_data/
│ ├── alphafold_run_outputs/
│ │ ├── run_alphafold_test_output_bucket_1024.pkl
│ │ └── run_alphafold_test_output_bucket_default.pkl
│ ├── featurised_example.json
│ ├── featurised_example.pkl
│ ├── miniature_databases/
│ │ ├── bfd-first_non_consensus_sequences__subsampled_1000.fasta
│ │ ├── mgy_clusters__subsampled_1000.fa
│ │ ├── nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta
│ │ ├── pdb_mmcif/
│ │ │ ├── 5y2e.cif
│ │ │ ├── 6s61.cif
│ │ │ ├── 6ydw.cif
│ │ │ └── 7rye.cif
│ │ ├── pdb_seqres_2022_09_28__subsampled_1000.fasta
│ │ ├── rfam_14_4_clustered_rep_seq__subsampled_1000.fasta
│ │ ├── rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta
│ │ ├── uniprot_all__subsampled_1000.fasta
│ │ └── uniref90__subsampled_1000.fasta
│ └── model_config.json
└── version.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/ci.yaml
================================================
# Continuous integration: on pushes / pull requests to main (or a manual
# dispatch), install system and Python dependencies, build the data assets and
# run the CPU-only test suite.
name: Continuous Integration
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:
jobs:
  build:
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }})"
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        include:
          - name-prefix: "all tests"
            python-version: '3.12'
            os: ubuntu-latest
    steps:
      - uses: actions/checkout@v6
      - name: Set up uv
        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
          # Key the uv cache on the lockfile.
          cache-dependency-glob: "uv.lock"
      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}
      - name: Install dependencies
        # Refresh the apt package index before installing: the runner image's
        # index can be stale, which makes `apt-get install` fail intermittently.
        run: |
          sudo apt-get update
          sudo apt-get install -y hmmer
      - name: Install Python dependencies
        run: uv sync --frozen --all-groups
      - name: Build data
        run: uv run build_data
      - name: Run CPU-only tests
        run: uv run python run_alphafold_data_test.py
================================================
FILE: CMakeLists.txt
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
# Builds the `cpp` pybind11 extension module (installed into the `alphafold3`
# package directory) from the C++ sources under src/alphafold3/. All
# third-party C++ dependencies are fetched at configure time via FetchContent.
cmake_minimum_required(VERSION 3.28)
# This forces Git to use the 'files' backend for all FetchContent operations.
# This fixes libcifpp and dssp incompatibility with newer git versions.
set(ENV{GIT_CONFIG_PARAMETERS} "'init.defaultRefFormat=files'")
# SKBUILD_PROJECT_NAME / SKBUILD_PROJECT_VERSION are not defined in this file;
# presumably they are supplied by the Python build backend (scikit-build-core)
# when building via pip/uv — TODO confirm.
project(
"${SKBUILD_PROJECT_NAME}"
LANGUAGES CXX
VERSION "${SKBUILD_PROJECT_VERSION}")
include(FetchContent)
# C++20 is mandatory. Everything is linked into a shared Python extension
# module, so position-independent code is enabled globally (including for the
# fetched dependencies).
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)
# Ask Abseil to build with the C++ standard selected above.
set(ABSL_PROPAGATE_CXX_STD ON)
# Remove support for scan deps, which is only useful when using C++ modules.
unset(CMAKE_CXX_SCANDEP_SOURCE)
# Third-party dependencies, each pinned to an exact commit hash for
# reproducible builds (the trailing comment names the matching release).
# EXCLUDE_FROM_ALL keeps a dependency's targets out of the default build
# unless something built here depends on them.
FetchContent_Declare(
abseil-cpp
GIT_REPOSITORY https://github.com/abseil/abseil-cpp
GIT_TAG d7aaad83b488fd62bd51c81ecf16cd938532cc0a # 20240116.2
EXCLUDE_FROM_ALL)
FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11
GIT_TAG 2e0815278cb899b20870a67ca8205996ef47e70f # v2.12.0
EXCLUDE_FROM_ALL)
FetchContent_Declare(
pybind11_abseil
GIT_REPOSITORY https://github.com/pybind/pybind11_abseil
GIT_TAG bddf30141f9fec8e577f515313caec45f559d319 # HEAD @ 2024-08-07
EXCLUDE_FROM_ALL)
FetchContent_Declare(
cifpp
GIT_REPOSITORY https://github.com/pdb-redo/libcifpp
GIT_TAG ac98531a2fc8daf21131faa0c3d73766efa46180 # v7.0.3
# Don't `EXCLUDE_FROM_ALL` as necessary for build_data.
)
FetchContent_Declare(
dssp
GIT_REPOSITORY https://github.com/PDB-REDO/dssp
GIT_TAG 57560472b4260dc41f457706bc45fc6ef0bc0f10 # v4.4.7
EXCLUDE_FROM_ALL)
# Populate (download if needed) and add all declared dependencies to the build.
FetchContent_MakeAvailable(pybind11 abseil-cpp pybind11_abseil cifpp dssp)
# Python interpreter, development headers and NumPy headers; NumPy is linked
# below through the Python3::NumPy imported target.
find_package(
Python3
COMPONENTS Interpreter Development NumPy
REQUIRED)
# NOTE(review): find_package(Python3) above populates Python3_INCLUDE_DIRS, not
# PYTHON_INCLUDE_DIRS. This variable may be empty here, or may come from
# pybind11's own Python detection. Confirm whether this line is still needed.
include_directories(${PYTHON_INCLUDE_DIRS})
# Project headers are included relative to src/ (e.g. "alphafold3/...").
include_directories(src/)
# Compile every .cc file under src/alphafold3/ into the module, except tests,
# standalone mains and benchmarks. In this quoted argument `\(`, `\|` and `\)`
# are CMake escapes that reach the regex engine as plain `(`, `|` and `)`.
# The `.` before `cc` is unescaped and therefore matches any character.
# This is a plain GLOB (no CONFIGURE_DEPENDS), so newly added source files are
# only picked up after re-running CMake configure.
file(GLOB_RECURSE cpp_srcs src/alphafold3/*.cc)
list(FILTER cpp_srcs EXCLUDE REGEX ".*\(_test\|_main\|_benchmark\).cc$")
# Build against the NumPy C API with pre-1.7 deprecated APIs disabled.
add_compile_definitions(NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION)
pybind11_add_module(cpp ${cpp_srcs})
target_link_libraries(
cpp
PRIVATE absl::check
absl::flat_hash_map
absl::node_hash_map
absl::strings
absl::status
absl::statusor
absl::log
pybind11_abseil::absl_casters
Python3::NumPy
dssp::dssp
cifpp::cifpp)
# Make the project version available to the C++ sources as VERSION_INFO.
target_compile_definitions(cpp PRIVATE VERSION_INFO=${PROJECT_VERSION})
# Install the extension module plus the licence and terms-of-use files into the
# `alphafold3` package directory.
install(TARGETS cpp LIBRARY DESTINATION alphafold3)
install(
FILES LICENSE
OUTPUT_TERMS_OF_USE.md
WEIGHTS_PROHIBITED_USE_POLICY.md
WEIGHTS_TERMS_OF_USE.md
DESTINATION alphafold3)
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
We welcome small patches related to bug fixes and documentation, but we do not
plan to make any major changes to this repository.
## AI Generated Code
We welcome the use of AI tools for the generation of code, documentation and/or
Pull Request (PR) description as long as:
1. It has been transparently labelled as such. Make sure to declare it in the
PR message.
2. You have manually reviewed the code before sending the PR.
3. The change has been manually tested. We might ask you to fold a certain
input to check correctness of the PR.
Please do not submit AI generated PRs where test results have been hallucinated.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution,
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
================================================
FILE: LICENSE
================================================
Attribution-NonCommercial-ShareAlike 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at
creativecommons.org/compatiblelicenses, approved by Creative
Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
e. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name
of a Creative Commons Public License. The License Elements of this
Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
i. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
k. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
l. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. Additional offer from the Licensor -- Adapted Material.
Every recipient of Adapted Material from You
automatically receives an offer from the Licensor to
exercise the Licensed Rights in the Adapted Material
under the conditions of the Adapter's License You apply.
c. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
b. ShareAlike.
In addition to the conditions in Section 3(a), if You Share
Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons
license with the same License Elements, this version or
later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the
Adapter's License You apply. You may satisfy this condition
in any reasonable manner based on the medium, means, and
context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms
or conditions on, or apply any Effective Technological
Measures to, Adapted Material that restrict exercise of the
rights granted under the Adapter's License You apply.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material,
including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: OUTPUT_TERMS_OF_USE.md
================================================
# ALPHAFOLD 3 OUTPUT TERMS OF USE
Last Modified: 2024-11-09
By using AlphaFold 3 Output (as defined below), without having agreed to
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md),
you agree to be bound by these AlphaFold 3 Output Terms of Use between you (or
your organization, as applicable) and Google LLC (these "**Terms**").
If you are using Output on behalf of an organization, you confirm you are
authorized either explicitly or implicitly to agree to, and are agreeing to,
these Terms as an employee on behalf of, or otherwise on behalf of, your
organization.
If you have agreed to
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md),
your use of Output is governed by those terms. **If you have not agreed to
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md)
and do not agree to these Terms, do not use Output or permit any third party to
do so on your behalf.**
When we say "**you**", we mean the individual or organization using Output. When
we say "**we**", "**us**" or "**Google**", we mean the entities that belong to
the Google group of companies, which means Google LLC and its affiliates.
## Key Definitions
As used in these Terms:
"**AlphaFold 3**" means the AlphaFold 3 Code and Model Parameters.
"**AlphaFold 3 Code**" means the AlphaFold 3 source code: (a) identified at
[public GitHub repo](https://github.com/google-deepmind/alphafold3/), or such
other location in which we may make it available from time to time, regardless
of the source that it was obtained from; and (b) made available by Google to
organizations for their use in accordance with the
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md)
(not these Terms) together with (i) modifications to that code, (ii) works based
on that code, or (iii) other code or machine learning model which incorporates,
in full or in part, that code.
"**Model Parameters**" means the trained model weights and parameters made
available by Google to organizations (at its sole discretion) for their use in
accordance with the
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md)
(not these Terms), together with (a) modifications to those weights and
parameters, (b) works based on those weights and parameters, or (c) other code
or machine learning model which incorporates, in full or in part, those weights
and parameters.
"**Output**" means the structure predictions and all related information
provided by AlphaFold 3, together with any visual representations, computational
predictions, descriptions, modifications, copies, or adaptations that are
substantially derived from Output.
## Use restrictions
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
belongs to us. Output are made available free of charge, for non-commercial use
only, in accordance with the following use restrictions. You must not use nor
allow others to use Output:
1. **On behalf of a commercial organization or in connection with any
commercial activities, including research on behalf of commercial
organizations.**
1. This means that only non-commercial organizations (*i.e.*, universities,
non-profit organizations and research institutes, educational,
journalism and government bodies) may use Output for their
non-commercial activities. Output are not available for use by any other
types of organization, even if conducting non-commercial work.
2. If you are a researcher affiliated with a non-commercial organization,
provided **you are not a commercial organisation or acting on behalf of
a commercial organisation**, you can use Output for your non-commercial
affiliated research.
3. You must not share Output with any commercial organization. The only
exception is making Output publicly available (including, indirectly, to
commercial organizations) via a scientific publication or open source
release or using these to support journalism, each of which are
permitted.
2. **To misinform, misrepresent or mislead**, including:
1. providing false or inaccurate information in relation to your access to
or use of Output;
2. misrepresenting your relationship with Google - including by using
Google’s trademarks, trade names, logos or suggesting endorsement by
Google without Google’s permission to do so - nothing in these Terms
grants such permission;
3. misrepresenting the origin of Output;
4. distributing misleading claims of expertise or capability, or engaging
in the unauthorized or unlicensed practice of any profession,
particularly in sensitive areas (*e.g.*, health); or
5. making decisions in domains that affect material or individual rights or
well-being (*e.g.*, healthcare).
3. **To perform, promote or facilitate dangerous, illegal or malicious
activities**, including:
1. promoting or facilitating the sale of, or providing instructions for
synthesizing or accessing, illegal substances, goods or services;
2. abusing, harming, interfering, or disrupting any services, including
generating or distributing content for deceptive or fraudulent
activities or malware;
3. generating or distributing any content that infringes, misappropriates,
or otherwise violates any individual’s or entity’s rights (including,
but not limited to rights in copyrighted content); or
4. attempting to circumvent these Terms.
4. **To train or create machine learning models or related technology for
biomolecular structure prediction similar to AlphaFold 3 as made available
by Google ("Derived Models"),** including via distillation or other
methods. For the avoidance of doubt, the use restrictions set out in
these Terms would apply in full to any Derived Models created in breach of
these Terms.
5. **Without providing conspicuous notice that published or distributed Output
is provided under and subject to these Terms and of any modifications you
make to Output.**
1. This means if you remove, or cause to be removed (for example by using
third-party software), these Terms, or any notice of these Terms, from
Output, you must ensure further distribution or publication is
accompanied by a copy of the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
and a "*Legally Binding Terms of Use*" text file that contains the
following notice:
"*By using this information, you agree to AlphaFold 3 Output Terms of
Use found at
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*To request access to the AlphaFold 3 model parameters, follow the
process set out at https://github.com/google-deepmind/alphafold3. You
may only use these if received directly from Google. Use is subject to
terms of use available at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*"
2. You must not include any additional or different terms that conflict
with the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
6. **Distribute Output, or disclose findings arising from using AlphaFold 3
without citing our paper:** [Abramson, J et al. Accurate structure
prediction of biomolecular interactions with AlphaFold 3. *Nature*
(2024)](https://www.nature.com/articles/s41586-024-07487-w). For the
avoidance of doubt, this is an additional requirement to the notice
requirements set out above.
We grant you a non-exclusive, royalty-free, revocable, non-transferable and
non-sublicensable (except as expressly permitted in these Terms) license to any
intellectual property rights we have in Output to the extent necessary for these
purposes. You agree that your right to use and share Output is subject to your
compliance with these Terms. If you breach these Terms, Google reserves the
right to request that you delete and cease use or sharing of Output in your
possession or control and prohibit you from using the AlphaFold 3 Assets
(including as made available via
[AlphaFold Server](https://alphafoldserver.com/about)). You agree to immediately
comply with any such request.
## Disclaimers
Nothing in these Terms restricts any rights that cannot be restricted under
applicable law or limits Google’s responsibilities except as allowed by
applicable law.
**Output are provided on an "as is" basis, without warranties or conditions of
any kind, either express or implied, including any warranties or conditions of
title, non-infringement, merchantability, or fitness for a particular purpose.
You are solely responsible for determining the appropriateness of using or
distributing any of the Output and assume any and all risks associated with your
use or distribution of any Output and your exercise of rights and obligations
under these Terms. You and anyone you share Output with are solely responsible
for these and their subsequent uses.**
**Output are predictions with varying levels of confidence and should be
interpreted carefully. Use discretion before relying on, publishing, downloading
or otherwise using Output.**
**Output are for theoretical modeling only. These are not intended, validated,
or approved for clinical use. You should not use these for clinical purposes or
rely on them for medical or other professional advice. Any content regarding
those topics is provided for informational purposes only and is not a substitute
for advice from a qualified professional.**
## Liabilities
To the extent allowed by applicable law, you will indemnify Google and its
directors, officers, employees, and contractors for any third-party legal
proceedings (including actions by government authorities) arising out of or
relating to your unlawful use of Output or violation of these Terms. This
indemnity covers any liability or expense arising from claims, losses, damages,
judgments, fines, litigation costs, and legal fees, except to the extent a
liability or expense is caused by Google's breach, negligence, or willful
misconduct. If you are legally exempt from certain responsibilities, including
indemnification, then those responsibilities don’t apply to you under these
terms.
In no circumstances will Google be responsible for any indirect, special,
incidental, exemplary, consequential, or punitive damages, or lost profits of
any kind, even if Google has been advised of the possibility of such damages.
Google’s total, aggregate liability for all claims arising out of or in
connection with these Terms or Output, including for its own negligence, is
limited to $500.
## Governing law and disputes
These Terms will be governed by the laws of the State of California. The state
or federal courts of Santa Clara County, California shall have exclusive
jurisdiction of any dispute arising out of these Terms.
Given the nature of scientific research, it may take some time for any breach of
these Terms to become apparent. To the extent allowed by applicable law, any
legal claims relating to these Terms or Output can be initiated until the later
of (a) the cut-off date under applicable law for bringing the legal claim; or
(b) two years from the date you or Google (as applicable) became aware, or
should reasonably have become aware, of the facts giving rise to that claim. You
will not argue limitation, time bar, delay, waiver or the like in an attempt to
bar an action filed within that time period, and neither will we.
All rights not specifically and expressly granted to you by these Terms are
reserved to Google. No delay, act or omission by Google in exercising any right
or remedy will be deemed a waiver of any breach of these Terms and Google
expressly reserves any and all rights and remedies available under these Terms
or at law or in equity or otherwise, including the remedy of injunctive relief
against any threatened or actual breach of these Terms without the necessity of
proving actual damages.
## Miscellaneous
Google may update these Terms (1) to reflect changes in how it does business,
(2) for legal, regulatory or security reasons, or (3) to prevent abuse or harm.
The version of these Terms that were effective on the date the relevant Output
was generated will apply to your use of that Output.
If it turns out that a particular provision of these Terms is not valid or
enforceable, this will not affect any other provisions.
================================================
FILE: README.md
================================================

# AlphaFold 3
This package provides an implementation of the inference pipeline of AlphaFold
3. See below for how to access the model parameters. You may only use AlphaFold
3 model parameters if received directly from Google. Use is subject to these
[terms of use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md).
Any publication that discloses findings arising from using this source code, the
model parameters or outputs produced by those should [cite](#citing-this-work)
the
[Accurate structure prediction of biomolecular interactions with AlphaFold 3](https://doi.org/10.1038/s41586-024-07487-w)
paper.
Please also refer to the Supplementary Information for a detailed description of
the method.
AlphaFold 3 is also available at
[alphafoldserver.com](https://alphafoldserver.com) for non-commercial use,
though with a more limited set of ligands and covalent modifications.
If you have any questions, please contact the AlphaFold team at
[alphafold@google.com](mailto:alphafold@google.com).
## Obtaining Model Parameters
This repository contains all necessary code for AlphaFold 3 inference. To
request access to the AlphaFold 3 model parameters, please complete
[this form](https://forms.gle/svvpY4u2jsHEwWYS6). Access will be granted at
Google DeepMind’s sole discretion. We will aim to respond to requests within 2–3
business days. You may only use AlphaFold 3 model parameters if received
directly from Google. Use is subject to these
[terms of use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md).
## Installation and Running Your First Prediction
See the [installation documentation](docs/installation.md).
Once you have installed AlphaFold 3, you can test your setup using e.g. the
following input JSON file named `fold_input.json`:
```json
{
"name": "2PV7",
"sequences": [
{
"protein": {
"id": ["A", "B"],
"sequence": "GMRESYANENQFGFKTINSDIHKIVIVGGYGKLGGLFARYLRASGYPISILDREDWAVAESILANADVVIVSVPINLTLETIERLKPYLTENMLLADLTSVKREPLAKMLEVHTGAVLGLHPMFGADIASMAKQVVVRCDGRFPERYEWLLEQIQIWGAKIYQTNATEHDHNMTYIQALRHFSTFANGLHLSKQPINLANLLALSSPIYRLELAMIGRLFAQDAELYADIIMDKSENLAVIETLKQTYDEALTFFENNDRQGFIDAFHKVRDWFGDYSEQFLKESRQLLQQANDLKQG"
}
}
],
"modelSeeds": [1],
"dialect": "alphafold3",
"version": 1
}
```
You can then run AlphaFold 3 using the following command:
```
docker run -it \
--volume $HOME/af_input:/root/af_input \
--volume $HOME/af_output:/root/af_output \
--volume <MODEL_PARAMETERS_DIR>:/root/models \
--volume <DATABASES_DIR>:/root/public_databases \
--gpus all \
alphafold3 \
python run_alphafold.py \
--json_path=/root/af_input/fold_input.json \
--model_dir=/root/models \
--output_dir=/root/af_output
```
There are various flags that you can pass to the `run_alphafold.py` command, to
list them all, run `python run_alphafold.py --help`. Two fundamental flags that
control which parts AlphaFold 3 will run are:
* `--run_data_pipeline` (defaults to `true`): whether to run the data
pipeline, i.e. genetic and template search. This part is CPU-only,
time-consuming and could be run on a machine without a GPU.
* `--run_inference` (defaults to `true`): whether to run the inference. This
part requires a GPU.
## AlphaFold 3 Input
See the [input documentation](docs/input.md).
## AlphaFold 3 Output
See the [output documentation](docs/output.md).
## Performance
See the [performance documentation](docs/performance.md).
## Known Issues
Known issues are documented in the
[known issues documentation](docs/known_issues.md).
Please
[create an issue](https://github.com/google-deepmind/alphafold3/issues/new/choose)
if it is not already listed in [Known Issues](docs/known_issues.md) or in the
[issues tracker](https://github.com/google-deepmind/alphafold3/issues).
## Citing This Work
Any publication that discloses findings arising from using this source code, the
model parameters or outputs produced by those should cite:
```bibtex
@article{Abramson2024,
author = {Abramson, Josh and Adler, Jonas and Dunger, Jack and Evans, Richard and Green, Tim and Pritzel, Alexander and Ronneberger, Olaf and Willmore, Lindsay and Ballard, Andrew J. and Bambrick, Joshua and Bodenstein, Sebastian W. and Evans, David A. and Hung, Chia-Chun and O’Neill, Michael and Reiman, David and Tunyasuvunakool, Kathryn and Wu, Zachary and Žemgulytė, Akvilė and Arvaniti, Eirini and Beattie, Charles and Bertolli, Ottavia and Bridgland, Alex and Cherepanov, Alexey and Congreve, Miles and Cowen-Rivers, Alexander I. and Cowie, Andrew and Figurnov, Michael and Fuchs, Fabian B. and Gladman, Hannah and Jain, Rishub and Khan, Yousuf A. and Low, Caroline M. R. and Perlin, Kuba and Potapenko, Anna and Savy, Pascal and Singh, Sukhdeep and Stecula, Adrian and Thillaisundaram, Ashok and Tong, Catherine and Yakneen, Sergei and Zhong, Ellen D. and Zielinski, Michal and Žídek, Augustin and Bapst, Victor and Kohli, Pushmeet and Jaderberg, Max and Hassabis, Demis and Jumper, John M.},
journal = {Nature},
title = {Accurate structure prediction of biomolecular interactions with AlphaFold 3},
year = {2024},
volume = {630},
number = {8016},
pages = {493--500},
doi = {10.1038/s41586-024-07487-w}
}
```
## Acknowledgements
AlphaFold 3's release was made possible by the invaluable contributions of the
following people:
Andrew Cowie, Bella Hansen, Charlie Beattie, Chris Jones, Grace Margand,
Jacob Kelly, James Spencer, Josh Abramson, Kathryn Tunyasuvunakool, Kuba Perlin,
Lindsay Willmore, Max Bileschi, Molly Beck, Oleg Kovalevskiy,
Sebastian Bodenstein, Sukhdeep Singh, Tim Green, Toby Sargeant, Uchechi Okereke,
Yotam Doron, and Augustin Žídek (engineering lead).
We also extend our gratitude to our collaborators at Google and Isomorphic Labs.
AlphaFold 3 uses the following separate libraries and packages:
* [abseil-cpp](https://github.com/abseil/abseil-cpp) and
[abseil-py](https://github.com/abseil/abseil-py)
* [Docker](https://www.docker.com)
* [DSSP](https://github.com/PDB-REDO/dssp)
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)
* [Haiku](https://github.com/deepmind/dm-haiku)
* [JAX](https://github.com/jax-ml/jax/)
* [libcifpp](https://github.com/pdb-redo/libcifpp)
* [NumPy](https://github.com/numpy/numpy)
* [pybind11](https://github.com/pybind/pybind11) and
[pybind11_abseil](https://github.com/pybind/pybind11_abseil)
* [RDKit](https://github.com/rdkit/rdkit)
* [Tokamax](https://github.com/openxla/tokamax)
* [tqdm](https://github.com/tqdm/tqdm)
We thank all their contributors and maintainers!
## Get in Touch
If you have any questions not covered in this overview, please contact the
AlphaFold team at alphafold@google.com.
We would love to hear your feedback and understand how AlphaFold 3 has been
useful in your research. Share your stories with us at
[alphafold@google.com](mailto:alphafold@google.com).
## Licence and Disclaimer
This is not an officially supported Google product.
Copyright 2024 DeepMind Technologies Limited.
### AlphaFold 3 Source Code and Model Parameters
The AlphaFold 3 source code is licensed under the Creative Commons
Attribution-Non-Commercial ShareAlike International License, Version 4.0
(CC-BY-NC-SA 4.0) (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
[https://github.com/google-deepmind/alphafold3/blob/main/LICENSE](https://github.com/google-deepmind/alphafold3/blob/main/LICENSE).
The AlphaFold 3 model parameters are made available under the
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md)
(the "Terms"); you may not use these except in compliance with the Terms. You
may obtain a copy of the Terms at
[https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md).
Unless required by applicable law, AlphaFold 3 and its output are distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. You are solely responsible for determining the appropriateness of
using AlphaFold 3, or using or distributing its source code or output, and
assume any and all risks associated with such use or distribution and your
exercise of rights and obligations under the relevant terms. Output are
predictions with varying levels of confidence and should be interpreted
carefully. Use discretion before relying on, publishing, downloading or
otherwise using the AlphaFold 3 Assets.
AlphaFold 3 and its output are for theoretical modeling only. They are not
intended, validated, or approved for clinical use. You should not use
AlphaFold 3 or its output for clinical purposes or rely on them for medical or
other professional advice. Any content regarding those topics is provided for
informational purposes only and is not a substitute for advice from a qualified
professional. See the relevant terms for the specific language governing
permissions and limitations under the terms.
### Third-party Software
Use of the third-party software, libraries or code referred to in the
[Acknowledgements](#acknowledgements) section above may be governed by separate
terms and conditions or license provisions. Your use of the third-party
software, libraries or code is subject to any such terms and you should check
that you can comply with any applicable restrictions or terms and conditions
before use.
### Mirrored and Reference Databases
The following databases have been: (1) mirrored by Google DeepMind; and (2) in
part, included with the inference code package for testing purposes, and are
available with reference to the following:
* [BFD](https://bfd.mmseqs.com/) (modified), by Steinegger M. and Söding J.,
modified by Google DeepMind, available under a
[Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/deed.en).
See the Methods section of the
[AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1)
for details.
* [PDB](https://wwpdb.org) (unmodified), by H.M. Berman et al., available free
of all copyright restrictions and made fully and freely available for both
non-commercial and commercial use under
[CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
* [MGnify: v2022\_05](https://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/README.txt)
(unmodified), by Mitchell AL et al., available free of all copyright
restrictions and made fully and freely available for both non-commercial and
commercial use under
[CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
* [UniProt: 2021\_04](https://www.uniprot.org/) (unmodified), by The UniProt
Consortium, available under a
[Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/deed.en).
* [UniRef90: 2022\_05](https://www.uniprot.org/) (unmodified) by The UniProt
Consortium, available under a
[Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/deed.en).
* [NT: 2023\_02\_23](https://www.ncbi.nlm.nih.gov/nucleotide/) (modified). See
the Supplementary Information of the
[AlphaFold 3 paper](https://nature.com/articles/s41586-024-07487-w) for
details.
* [RFam: 14\_4](https://rfam.org/) (modified), by I. Kalvari et al., available
free of all copyright restrictions and made fully and freely available for
both non-commercial and commercial use under
[CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
See the Supplementary Information of the
[AlphaFold 3 paper](https://nature.com/articles/s41586-024-07487-w) for
details.
* [RNACentral: 21\_0](https://rnacentral.org/) (modified), by The RNAcentral
Consortium, available free of all copyright restrictions and made fully and
freely available for both non-commercial and commercial use under
[CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).
See the Supplementary Information of the
[AlphaFold 3 paper](https://nature.com/articles/s41586-024-07487-w) for
details.
================================================
FILE: WEIGHTS_PROHIBITED_USE_POLICY.md
================================================
# ALPHAFOLD 3 MODEL PARAMETERS PROHIBITED USE POLICY
Last Modified: 2024-11-09
AlphaFold 3 can help you accelerate scientific research by predicting the 3D
structure of biological molecules. Google makes the AlphaFold Assets available
free of charge for certain non-commercial uses in accordance with the
restrictions set out below. This policy uses the same defined terms as the
[AlphaFold 3 Model Parameters Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md).
**You must not access or use nor allow others to access or use the AlphaFold 3
Assets:**
1. **On behalf of a commercial organization or in connection with any
commercial activities, including research on behalf of commercial
organizations.**
1. This means that only non-commercial organizations (*i.e.*, universities,
non-profit organizations and research institutes, educational,
journalism and government bodies) may use the AlphaFold 3 Assets for
their non-commercial activities. The AlphaFold 3 Assets are not
available for any other types of organization, even if conducting
non-commercial work.
2. If you are a researcher affiliated with a non-commercial organization,
provided **you are not a commercial organisation or acting on behalf of
a commercial organisation,** you can use the AlphaFold 3 Assets for your
non-commercial affiliated research.
3. You must not share the AlphaFold 3 Assets with any commercial
organization or use the AlphaFold 3 Assets in a manner that will grant a
commercial organization any rights in these. The only exception is
making Output publicly available (including indirectly to commercial
organizations) via a scientific publication or open source release or
using it to support journalism, each of which is permitted.
2. **To misinform, misrepresent or mislead**, including:
1. providing false or inaccurate information in relation to your access to
or use of AlphaFold 3 or Output, including accessing or using the Model
Parameters on behalf of an organization without telling us or submitting
a request to access the Model Parameters where Google has prohibited
your use of AlphaFold 3 in full or in part (including as made available
via [AlphaFold Server](https://alphafoldserver.com/about));
2. misrepresenting your relationship with us, including by using Google’s
trademarks, trade names, logos or suggesting endorsement by Google
without Google’s permission to do so - nothing in the Terms grants such
permission;
3. misrepresenting the origin of AlphaFold 3 in full or in part;
4. distributing misleading claims of expertise or capability, or engaging
in the unauthorized or unlicensed practice of any profession,
particularly in sensitive areas (*e.g.*, health); or
5. making decisions in domains that affect material or individual rights
or well-being (*e.g.*, healthcare).
3. **To perform, promote or facilitate dangerous, illegal or malicious
activities**, including:
1. promoting or facilitating the sale of, or providing instructions for
synthesizing or accessing, illegal substances, goods or services;
2. abusing, harming, interfering, or disrupting any services, including
generating or distributing content for deceptive or fraudulent
activities or malware;
3. generating or distributing any content, including Output, that
infringes, misappropriates, or otherwise violates any individual's or
entity's rights (including, but not limited to rights in copyrighted
content); or
4. attempting to circumvent, or intentionally causing (directly or
indirectly) AlphaFold 3 to act in a manner that contravenes the Terms.
**You must not nor allow others to:**
1. **Use Output to train or create machine learning models or related
technology for biomolecular structure prediction similar to AlphaFold 3
("Derived Models"),** including via distillation or other methods. For the
avoidance of doubt, the use restrictions set out in the Terms would apply in
full to any Derived Models created in breach of the Terms.
2. **Distribute Output without providing conspicuous notice that what you
Distribute is provided under and subject to the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
and of any modifications you make.**
1. This means if you remove, or cause to be removed (for example by using
third-party software), the notices and terms we provide when you
generate Output using AlphaFold 3, you must ensure any further
Distribution of Output is accompanied by a copy of the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
and a "Legally Binding Terms of Use" text file that contains the
following notice:
"*By using this information, you agree to AlphaFold 3 Output Terms of
Use found at
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*To request access to the AlphaFold 3 model parameters, follow the
process set out at https://github.com/google-deepmind/alphafold3. You
may only use these if received directly from Google. Use is subject to
terms of use available at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*"
2. You must not include any additional or different terms that conflict
with the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
3. **Distribute Output, or disclose findings arising from using AlphaFold 3
without citing our paper:** [Abramson, J et al. Accurate structure
prediction of biomolecular interactions with AlphaFold 3. *Nature*
(2024)](https://www.nature.com/articles/s41586-024-07487-w). For the
avoidance of doubt, this is an additional requirement to the notice
requirements set out above.
4. **Circumvent access restrictions relating to the Model Parameters, including
utilising, sharing or making available the Model Parameters when you have
not been expressly authorized to do so by Google.** Google will grant access
to the Model Parameters to either:
1. you for your individual use on behalf of your organization, in which
case you cannot share your copy of Model Parameters with anyone else; or
2. an authorized representative of your organization, with full legal
authority to bind that organization to these Terms, in which case you may
share that organization’s copy of the Model Parameters with employees,
consultants, contractors and agents of the organization as authorized by
that representative.
================================================
FILE: WEIGHTS_TERMS_OF_USE.md
================================================
# ALPHAFOLD 3 MODEL PARAMETERS TERMS OF USE
Last Modified: 2024-11-09
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
is an AI model developed by [Google DeepMind](https://deepmind.google/) and
[Isomorphic Labs](https://www.isomorphiclabs.com/). It generates 3D structure
predictions of biological molecules, providing model confidence for the
structure predictions. We make the trained model parameters and output generated
using those available free of charge for certain non-commercial uses, in
accordance with these terms of use and the
[AlphaFold 3 Model Parameters Prohibited Use Policy](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_PROHIBITED_USE_POLICY.md).
**Key things to know when using the AlphaFold 3 model parameters and output**
1. The AlphaFold 3 model parameters and output are **only** available for
non-commercial use by, or on behalf of, non-commercial organizations
(*i.e.*, universities, non-profit organizations and research institutes,
educational, journalism and government bodies). If you are a researcher
affiliated with a non-commercial organization, provided **you are not a
commercial organisation or acting on behalf of a commercial organisation,**
this means you can use these for your non-commercial affiliated research.
2. You **must not** use nor allow others to use:
1. AlphaFold 3 model parameters or output in connection with **any
commercial activities, including research** **on behalf of commercial
organizations;** or
2. AlphaFold 3 output to **train machine learning models** or related
technology for **biomolecular structure prediction** similar to
AlphaFold 3.
3. You ***must not* publish or share AlphaFold 3 model parameters**, except
sharing these within your organization in accordance with these Terms.
4. You ***can* publish, share and adapt AlphaFold 3 *output*** in accordance
with these Terms, including the requirements to provide clear notice of any
modifications you make and that ongoing use of AlphaFold 3 output and
derivatives are subject to the
[AlphaFold 3 Output Terms of Use](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
By using, reproducing, modifying, performing, distributing or displaying any
portion or element of the Model Parameters (as defined below) or otherwise
accepting the terms of this agreement, you agree to be bound by (1) these terms
of use, and (2) the
[AlphaFold 3 Model Parameters Prohibited Use Policy](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_PROHIBITED_USE_POLICY.md)
which is incorporated herein by reference (together, the "**Terms**"), in each
case (a) as modified from time to time in accordance with the Terms, and (b)
between you and (i) if you are from a country in the European Economic Area or
Switzerland, Google Ireland Limited, or (ii) otherwise, Google LLC.
You confirm you are authorized either explicitly or implicitly to enter, and are
entering, into the Terms as an employee on behalf of, or otherwise on behalf of,
your organization.
Please read these Terms carefully. They establish what you can expect from us as
you access and use the AlphaFold 3 Assets (as defined below), and what Google
expects from you. When we say "**you**", we mean the individual or organization
using the AlphaFold 3 Assets. When we say "**we**", "**us**" or "**Google**", we
mean the entities that belong to the Google group of companies, which means
Google LLC and its affiliates.
## 1. Key Definitions
As used in these Terms:
"**AlphaFold 3**" means: (a) the AlphaFold 3 source code made available
[here](https://github.com/google-deepmind/alphafold3/) and licensed under the
terms of the Creative Commons Attribution-NonCommercial-Sharealike 4.0
International (CC-BY-NC-SA 4.0) license and any derivative source code, and (b)
Model Parameters.
"**AlphaFold 3 Assets**" means the Model Parameters and Output.
"**Distribution**" or "**Distribute**" means any transmission, publication, or
other sharing of Output publicly or to any other person.
"**Model Parameters**" means the trained model weights and parameters made
available by Google to organizations (at its sole discretion) for their use in
accordance with these Terms, together with (a) modifications to those weights
and parameters, (b) works based on those weights and parameters, or (c) other
code or machine learning models which incorporate, in full or in part, those
weights and parameters.
"**Output**" means the structure predictions and all ancillary and related
information provided by AlphaFold 3 or using the Model Parameters, together with
any visual representations, computational predictions, descriptions,
modifications, copies, or adaptations that are substantially derived from
Output.
"**Including**" means "**including without limitation**".
## 2. Accessing and using the AlphaFold 3 Assets
Subject to your compliance with the Terms, including the
[AlphaFold 3 Model Parameters Prohibited Use Policy](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_PROHIBITED_USE_POLICY.md),
you may access, use and modify the AlphaFold 3 Assets and Distribute the Output
as set out in these Terms. We grant you a non-exclusive, royalty-free,
revocable, non-transferable and non-sublicensable (except as expressly permitted
in these Terms) license to any intellectual property rights we have in the
AlphaFold 3 Assets to the extent necessary for these purposes. In order to verify
your access and use of AlphaFold 3, we may from time-to-time ask for additional
information from you, including verification of your name, organization, and
other identifying information.
By accessing, using, or modifying the AlphaFold 3 Assets, Distributing Output,
or requesting to access the Model Parameters, you represent and warrant that (a)
you have full power and authority to enter into these Terms (including being of
sufficient age of consent), (b) Google has never previously terminated your
access and right to use AlphaFold 3 (including as made available via
[AlphaFold Server](https://alphafoldserver.com/about)) due to your breach of
applicable terms of use, (c) entering into or performing your rights and
obligations under these Terms will not violate any agreement you have with a
third party or any third-party rights, (d) any information provided by you to
Google in relation to AlphaFold 3, including (where applicable) in order to
request access to the Model Parameters, is correct and current, and (e) you are
not (i) resident of an embargoed country, (ii) ordinarily resident in a US
embargoed country, or (iii) otherwise prohibited by applicable export controls
and sanctions programs from accessing, using, or modifying the AlphaFold 3
Assets.
If you choose to give Google feedback, such as suggestions to improve AlphaFold
3, you undertake that any such information is non-confidential and non-proprietary,
and Google may act on your feedback without obligation to you.
## 3. Use Restrictions
You must not use any of the AlphaFold 3 Assets:
1. for the restricted uses set forth in the
[AlphaFold 3 Model Parameters Prohibited Use Policy](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_PROHIBITED_USE_POLICY.md);
or
2. in violation of applicable laws and regulations.
To the maximum extent permitted by law and without limiting any of our other
rights, Google reserves the right to revoke your right to use, and (to the
extent feasible) restrict usage of any of the AlphaFold 3 Assets that Google
reasonably believes is in violation of these Terms.
## 4. Generated Output
Although you must comply with these Terms when using the AlphaFold 3 Assets, we
will not claim ownership in original Output you generate using AlphaFold 3.
However, you acknowledge that AlphaFold 3 may generate the same or similar
Output for multiple users, including Google, and we reserve all our rights in
this respect.
## 5. Changes to the AlphaFold 3 Assets or these Terms
Google may add or remove functionalities or features of the AlphaFold 3 Assets
at any time and may stop offering access to the AlphaFold 3 Assets altogether.
Google may update these Terms and the access mechanism for the Model Parameters
at any time. We'll post any modifications to the Terms
[in the AlphaFold 3 GitHub repository](https://github.com/google-deepmind/alphafold3).
Changes will generally become effective 14 days after they are posted. However,
changes addressing functionality or made for legal reasons will be effective
immediately.
You should review the Terms whenever we update them or you use the AlphaFold 3
Assets. If you do not agree to any modifications to the Terms, you must stop
using the AlphaFold 3 Assets immediately.
## 6. Suspending or terminating your right to use the AlphaFold 3 Assets
Google may at any time suspend or terminate your right to use and, as applicable
access to, the AlphaFold 3 Assets because of, among other reasons, your failure
to fully comply with the Terms. If Google suspends or terminates your right to
access or use the AlphaFold 3 Assets, you must immediately delete and cease use
and Distribution of all copies of the AlphaFold 3 Assets in your possession or
control and are prohibited from using the AlphaFold 3 Assets, including by
submitting an application to use the Model Parameters. Google will endeavour to
give you reasonable notice prior to any such suspension or termination, but no
notice or prior warning will be given if the suspension or termination is for
your failure to fully comply with the Terms or other serious grounds.
Of course, you are always free to stop using the AlphaFold 3 Assets. If you do
stop using these, we would appreciate knowing why (via
[alphafold@google.com](mailto:alphafold@google.com)) so that we can continue to
improve our technologies.
## 7. Confidentiality
You agree not to disclose or make available Google Confidential Information to
anyone without our prior written consent. "**Google Confidential Information**"
means (a) the AlphaFold 3 Model Parameters and all software, technology and
documentation relating to AlphaFold 3, except for the AlphaFold 3 source code,
and (b) any other information made available by Google that is marked
confidential or would normally be considered confidential under the
circumstances in which it is presented. Google Confidential Information does not
include (a) information that you already knew prior to your access to, or use
of, the AlphaFold 3 Assets (including via
[AlphaFold Server](https://alphafoldserver.com/about)), (b) that becomes public
through no fault of yours (for example, your breach of the Terms), (c) that was
independently developed by you without reference to Google Confidential
Information, or (d) that was lawfully given to you by a third party (without
your or their breach of the Terms).
## 8. Disclaimers
Nothing in the Terms restricts any rights that cannot be restricted under
applicable law or limits Google's responsibilities except as allowed by
applicable law.
**AlphaFold 3 and Output are provided on an "as is" basis, without warranties or
conditions of any kind, either express or implied, including any warranties or
conditions of title, non-infringement, merchantability, or fitness for a
particular purpose. You are solely responsible for determining the
appropriateness of using AlphaFold 3, or using or distributing Output, and
assume any and all risks associated with such use or distribution and your
exercise of rights and obligations under these Terms. You and anyone you share
Output with are solely responsible for these and their subsequent uses.**
**Output are predictions with varying levels of confidence and should be
interpreted carefully. Use discretion before relying on, publishing, downloading
or otherwise using AlphaFold 3.**
**AlphaFold 3 and Outputs are for theoretical modeling only. They are not
intended, validated, or approved for clinical use. You should not use AlphaFold
3 or Output for clinical purposes or rely on them for medical or other
professional advice. Any content regarding those topics is provided for
informational purposes only and is not a substitute for advice from a qualified
professional.**
## 9. Liabilities
To the extent allowed by applicable law, you will indemnify Google and its
directors, officers, employees, and contractors for any third-party legal
proceedings (including actions by government authorities) arising out of or
relating to your unlawful use of the AlphaFold 3 Assets or violation of the
Terms. This indemnity covers any liability or expense arising from claims,
losses, damages, judgments, fines, litigation costs, and legal fees, except to
the extent a liability or expense is caused by Google's breach, negligence, or
willful misconduct. If you are legally exempt from certain responsibilities,
including indemnification, then those responsibilities do not apply to you under
the Terms.
In no circumstances will Google be responsible for any indirect, special,
incidental, exemplary, consequential, or punitive damages, or lost profits of
any kind in connection with the Terms or the AlphaFold 3 Assets, even if Google
has been advised of the possibility of such damages. Google's total aggregate
liability for all claims arising out of or in connection with the Terms or the
AlphaFold 3 Assets, including for its own negligence, is limited to $500.
## 10. Miscellaneous
By law, you have certain rights that cannot be limited by a contract like the
Terms. The Terms are in no way intended to restrict those rights.
The Terms are our entire agreement relating to your use of the AlphaFold 3
Assets and supersede any prior or contemporaneous agreements on that subject.
If it turns out that a particular provision of the Terms is not enforceable, the
balance of the Terms will remain in full force and effect.
## 11. Disputes
California law will govern all disputes arising out of or relating to the Terms
or in connection to the AlphaFold 3 Assets. These disputes will be resolved
exclusively in the federal or state courts of Santa Clara County, California,
USA and you and Google consent to personal jurisdiction in those courts. To the
extent that applicable local law prevents certain disputes from being resolved
in a California court, you and Google can file those disputes in your local
courts. If applicable local law prevents your local court from applying
California law to resolve these disputes, then these disputes will be governed
by the applicable local laws of your country, state, or other place of
residence. If you are using the AlphaFold 3 Assets on behalf of a government
organization other than US federal government organizations (where the foregoing
provisions shall apply to the extent permitted by federal law), these Terms will
be silent regarding governing law and courts.
Given the nature of scientific research, it may take some time for any breach of
the Terms to become apparent. To protect you, Google and the AlphaFold 3 Assets,
to the extent allowed by applicable law you agree that:
1. any legal claims relating to the Terms or the AlphaFold 3 Assets can be
initiated until the later of:
1. the cut-off date under applicable law for bringing the legal claim; or
2. two years from the date you or Google (as applicable) became aware, or
should reasonably have become aware, of the facts giving rise to that
claim; and
2. you will not argue limitation, time bar, delay, waiver, or the like in an
attempt to bar an action filed within that time period, and neither will
Google.
All rights not specifically and expressly granted to you by the Terms are
reserved to Google. No delay, act or omission by Google in exercising any right
or remedy will be deemed a waiver of any breach of the Terms and Google
expressly reserves any and all rights and remedies available under the Terms or
at law or in equity or otherwise, including the remedy of injunctive relief
against any threatened or actual breach of the Terms without the necessity of
proving actual damages.
================================================
FILE: docker/Dockerfile
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
FROM nvidia/cuda:12.6.3-base-ubuntu24.04

# Some RUN statements are combined together to make Docker build run faster.
# Get latest package listing, install python, git, wget, compilers and libs.
# * git is required for pyproject.toml toolchain's use of CMakeLists.txt.
# * gcc, g++, make are required for compiling HMMER and AlphaFold 3 libraries.
# * zlib is a required dependency of AlphaFold 3.
# DEBIAN_FRONTEND is declared as a build-time ARG so that it applies to every
# apt-get invocation in the chain below (a `VAR=value cmd1 && cmd2` prefix only
# reaches cmd1), without being persisted into the runtime environment.
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update --quiet \
    && apt-get install --yes --quiet python3.12 python3.12-dev \
    && apt-get install --yes --quiet git wget gcc g++ make zlib1g-dev zstd

# Install uv from the official repository. The version is pinned for
# reproducibility.
COPY --from=ghcr.io/astral-sh/uv:0.9.24 /uv /uvx /bin/

# UV_COMPILE_BYTECODE=1 speeds up future container starts.
# UV_PROJECT_ENVIRONMENT explicitly sets the virtual environment location.
ENV UV_COMPILE_BYTECODE=1
ENV UV_PROJECT_ENVIRONMENT=/alphafold3_venv
RUN uv venv $UV_PROJECT_ENVIRONMENT
ENV PATH="/hmmer/bin:/alphafold3_venv/bin:$PATH"

# Install HMMER. Do so before copying the source code, so that docker can cache
# the image layer containing HMMER. Alternatively, you could also install it
# using `apt-get install hmmer` instead of building it from source, but we want
# to have control over the exact version of HMMER and also apply the sequence
# limit patch. Also note that eddylab.org unfortunately doesn't support HTTPS
# and the tar file published on GitHub is explicitly not recommended to be used
# for building from source.
# Download, check hash, and extract the HMMER source code. Every step is chained
# with `&&` so that any failure (download, checksum, extraction) fails the build.
RUN mkdir /hmmer_build /hmmer \
    && wget http://eddylab.org/software/hmmer/hmmer-3.4.tar.gz --directory-prefix /hmmer_build \
    && (cd /hmmer_build && echo "ca70d94fd0cf271bd7063423aabb116d42de533117343a9b27a65c17ff06fbf3 hmmer-3.4.tar.gz" | sha256sum --check) \
    && (cd /hmmer_build && tar zxf hmmer-3.4.tar.gz && rm hmmer-3.4.tar.gz)

# Apply the --seq_limit patch to HMMER.
COPY docker/jackhmmer_seq_limit.patch /hmmer_build/
RUN (cd /hmmer_build && patch -p0 < jackhmmer_seq_limit.patch)

# Build HMMER. The steps are chained with `&&` (not `;`): with `;` the exit
# status of the RUN would be that of the final `rm`, so a failed configure or
# compile would silently produce an image without a working jackhmmer.
# Parallelism is bounded by the number of available CPUs to avoid exhausting
# memory on large build hosts.
RUN (cd /hmmer_build/hmmer-3.4 && ./configure --prefix /hmmer) \
    && (cd /hmmer_build/hmmer-3.4 && make -j"$(nproc)") \
    && (cd /hmmer_build/hmmer-3.4 && make install) \
    && (cd /hmmer_build/hmmer-3.4/easel && make install) \
    && rm -R /hmmer_build

# Copy the AlphaFold 3 source code from the local machine to the container and
# set the working directory to there.
COPY . /app/alphafold
WORKDIR /app/alphafold

# Install the exact dependency tree using uv and cache the build artifacts.
# --frozen: do not update the lockfile during build.
# --all-groups: install development/test dependencies defined in pyproject.toml.
# --no-editable: install as a static package.
# If using this as a recipe for local installation, we recommend removing the
# --frozen and --no-editable flags.
RUN --mount=type=cache,target=/root/.cache/uv \
    UV_LINK_MODE=copy uv sync --frozen --all-groups --no-editable

# Build chemical components database (this binary was installed by uv sync).
RUN uv run build_data

# To work around a known XLA issue causing the compilation time to greatly
# increase, the following environment variable setting XLA flags must be enabled
# when running AlphaFold 3. Note that if using CUDA capability 7 GPUs, it is
# necessary to set the following XLA_FLAGS value instead:
# ENV XLA_FLAGS="--xla_disable_hlo_passes=custom-kernel-fusion-rewriter"
# (no need to disable gemm in that case as it is not supported for such GPU).
ENV XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"
# Memory settings used for folding up to 5,120 tokens on A100 80 GB.
ENV XLA_PYTHON_CLIENT_PREALLOCATE=true
ENV XLA_CLIENT_MEM_FRACTION=0.95

CMD ["uv", "run", "python3", "run_alphafold.py"]
================================================
FILE: docker/dockerignore
================================================
dockerignore
Dockerfile
================================================
FILE: docker/jackhmmer_seq_limit.patch
================================================
--- hmmer-3.4/src/jackhmmer.c
+++ hmmer-3.4/src/jackhmmer.c
@@ -73,6 +73,7 @@ static ESL_OPTIONS options[] = {
{ "--noali", eslARG_NONE, FALSE, NULL, NULL, NULL, NULL, NULL, "don't output alignments, so output is smaller", 2 },
{ "--notextw", eslARG_NONE, NULL, NULL, NULL, NULL, NULL, "--textw", "unlimit ASCII text output line width", 2 },
{ "--textw", eslARG_INT, "120", NULL, "n>=120", NULL, NULL, "--notextw", "set max width of ASCII text output lines", 2 },
+ { "--seq_limit", eslARG_INT, NULL, NULL, NULL, NULL, NULL, "--seq_limit", "if set, truncate all hits after this value is reached", 2 },
/* Control of scoring system */
{ "--popen", eslARG_REAL, "0.02", NULL, "0<=x<0.5",NULL, NULL, NULL, "gap open probability", 3 },
{ "--pextend", eslARG_REAL, "0.4", NULL, "0<=x<1", NULL, NULL, NULL, "gap extend probability", 3 },
@@ -298,6 +299,7 @@ output_header(FILE *ofp, ESL_GETOPTS *go
if (esl_opt_IsUsed(go, "--noali") && fprintf(ofp, "# show alignments in output: no\n") < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
if (esl_opt_IsUsed(go, "--notextw") && fprintf(ofp, "# max ASCII text line length: unlimited\n") < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
if (esl_opt_IsUsed(go, "--textw") && fprintf(ofp, "# max ASCII text line length: %d\n", esl_opt_GetInteger(go, "--textw")) < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
+ if (esl_opt_IsUsed(go, "--seq_limit") && fprintf(ofp, "# set max sequence hits to return: %d\n", esl_opt_GetInteger(go, "--seq_limit")) < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
if (esl_opt_IsUsed(go, "--popen") && fprintf(ofp, "# gap open probability: %f\n", esl_opt_GetReal (go, "--popen")) < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
if (esl_opt_IsUsed(go, "--pextend") && fprintf(ofp, "# gap extend probability: %f\n", esl_opt_GetReal (go, "--pextend")) < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
if (esl_opt_IsUsed(go, "--mx") && fprintf(ofp, "# subst score matrix (built-in): %s\n", esl_opt_GetString (go, "--mx")) < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
@@ -674,6 +676,13 @@ serial_master(ESL_GETOPTS *go, struct cf
/* Print the results. */
p7_tophits_SortBySortkey(info->th);
p7_tophits_Threshold(info->th, info->pli);
+ /* Limit the number of hits if specified. */
+ if (esl_opt_IsOn(go, "--seq_limit"))
+ {
+ int seq_limit = esl_opt_GetInteger(go, "--seq_limit");
+ info->th->N = ESL_MIN(info->th->N, seq_limit);
+ }
+
p7_tophits_CompareRanking(info->th, kh, &nnew_targets);
p7_tophits_Targets(ofp, info->th, info->pli, textw); if (fprintf(ofp, "\n\n") < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
p7_tophits_Domains(ofp, info->th, info->pli, textw); if (fprintf(ofp, "\n\n") < 0) ESL_EXCEPTION_SYS(eslEWRITE, "write failed");
================================================
FILE: docs/community_tools.md
================================================
# Community Tools
## JAAG: a JSON input file Assembler for AlphaFold 3 (with Glycan Integration)
JAAG is a lightweight, web-based GUI tool that helps generate AlphaFold 3 input
JSON files with integrated glycan support. It automates the creation of correct
glycan syntax (including `bondedAtomPairs` + CCD), reducing manual errors when
preparing glycoprotein or glycan–protein complexes.
* Web app: https://biofgreat.org/JAAG
* Source code: https://github.com/chinchc/JAAG
* Paper: https://doi.org/10.1093/glycob/cwaf083
Note: JAAG is compatible with standalone AlphaFold 3, but not with the AlphaFold
3 server.
## Modeling glycans with AlphaFold 3: capabilities, caveats, and limitations
A paper on modeling glycans (and other ligands) with AF3 that models and assesses
the major glycan classes and provides:
* Step-by-step tutorial for building ligand inputs (applicable beyond glycans)
* Ready-to-run scripts for each glycan class
* Comprehensive CCD table for all SNFG monosaccharides
* Discussion of caveats and limitations of AF3
* Full AF3 inputs/outputs archived on ModelArchive for reproducibility
Useful resource if your AF3 ligand models appear stereochemically off.
* Paper: https://doi.org/10.1093/glycob/cwaf048
* ModelArchive: https://doi.org/10.5452/ma-af3glycan
================================================
FILE: docs/contributing.md
================================================
# How to Contribute
We welcome small patches related to bug fixes and documentation, but we do not
plan to make any major changes to this repository.
## Before You Begin
### Sign Our Contributor License Agreement
Contributions to this project must be accompanied by a
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
You (or your employer) retain the copyright to your contribution; this simply
gives us permission to use and redistribute your contributions as part of the
project.
If you or your current employer have already signed the Google CLA (even if it
was for a different project), you probably don't need to do it again.
Visit <https://cla.developers.google.com/> to see your current agreements or to
sign a new one.
### Review Our Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
## Contribution Process
We won't accept pull requests directly, but if you send one, we will review it.
If we send a fix based on your pull request, we will make sure to credit you in
the release notes.
================================================
FILE: docs/input.md
================================================
# AlphaFold 3 Input
## Specifying Input Files
You can provide inputs to `run_alphafold.py` in one of two ways:
- Single input file: Use the `--json_path` flag followed by the path to a
single JSON file.
- Multiple input files: Use the `--input_dir` flag followed by the path to a
directory of JSON files.
## Input Format
AlphaFold 3 uses a custom JSON input format differing from the
[AlphaFold Server JSON input format](https://github.com/google-deepmind/alphafold/tree/main/server).
See [below](#alphafold-server-json-compatibility) for more information.
The custom AlphaFold 3 format allows:
* Specifying protein, RNA, and DNA chains, including modified residues.
* Specifying custom multiple sequence alignment (MSA) for protein and RNA
chains.
* Specifying custom structural templates for protein chains.
* Specifying ligands using
[Chemical Component Dictionary (CCD)](https://www.wwpdb.org/data/ccd) codes.
* Specifying ligands using SMILES.
* Specifying ligands by defining them using the CCD mmCIF format and supplying
them via the [user-provided CCD](#user-provided-ccd).
* Specifying covalent bonds between entities.
* Specifying multiple random seeds.
## AlphaFold Server JSON Compatibility
The [AlphaFold Server](https://alphafoldserver.com/) uses a separate
[JSON format](https://github.com/google-deepmind/alphafold/tree/main/server)
from the one used here in the AlphaFold 3 codebase. In particular, the JSON
format used in the AlphaFold 3 codebase offers more flexibility and control in
defining custom ligands, branched glycans, and covalent bonds between entities.
We provide a converter in `run_alphafold.py` which automatically detects the
input JSON format, denoted `dialect` in the converter code. The converter
denotes the AlphaFold Server JSON as `alphafoldserver`, and the JSON format
defined here in the AlphaFold 3 codebase as `alphafold3`. If the detected input
JSON format is `alphafoldserver`, then the converter will translate that into
the JSON format `alphafold3`.
### Multiple Inputs
The top-level of the `alphafoldserver` JSON format is a list, allowing
specification of multiple inputs in a single JSON. In contrast, the `alphafold3`
JSON format requires exactly one input per JSON file. Specifying multiple inputs
in a single `alphafoldserver` JSON is fully supported.
Note that the converter distinguishes between `alphafoldserver` and `alphafold3`
JSON formats by checking if the top-level of the JSON is a list or not. In
particular, if you pass in an `alphafoldserver`-style JSON without a top-level
list, then this is considered incorrect and `run_alphafold.py` will raise an
error.
### Glycans
If the JSON in `alphafoldserver` format specifies glycans, the converter will
raise an error. This is because translating glycans specified in the
`alphafoldserver` format to the `alphafold3` format is not currently supported.
### Random Seeds
The `alphafoldserver` JSON format allows users to specify `"modelSeeds": []`, in
which case a seed is chosen randomly for the user. On the other hand, the
`alphafold3` format requires users to specify a seed.
The converter will choose a seed randomly if `"modelSeeds": []` is set when
translating from `alphafoldserver` JSON format to `alphafold3` JSON format. If
seeds are specified in the `alphafoldserver` JSON format, then those will be
preserved in the translation to the `alphafold3` JSON format.
### Ions
While AlphaFold Server treats ions and ligands as different entity types in the
JSON format, AlphaFold 3 treats ions as ligands. Therefore, to specify e.g. a
magnesium ion, one would specify it as an entity of type `ligand` with
`ccdCodes: ["MG"]`.
### Sequence IDs
The `alphafold3` JSON format requires the user to specify a unique identifier
(`id`) for each entity. On the other hand, the `alphafoldserver` format does not allow
specification of an `id` for each entity. Thus, the converter automatically
assigns one.
The converter iterates through the list provided in the `sequences` field of the
`alphafoldserver` JSON format, assigning an `id` to each entity using the
following order ("reverse spreadsheet style"):
```
A, B, ..., Z, AA, BA, CA, ..., ZA, AB, BB, CB, ..., ZB, ...
```
For any entity with `count > 1`, an `id` is assigned arbitrarily to each "copy"
of the entity.
## Top-level Structure
The top-level structure of the input JSON is:
```json
{
"name": "Job name goes here",
"modelSeeds": [1, 2], # At least one seed required.
"sequences": [
{"protein": {...}},
{"rna": {...}},
{"dna": {...}},
{"ligand": {...}}
],
"bondedAtomPairs": [...], # Optional.
"userCCD": "...", # Optional, mutually exclusive with userCCDPath.
"userCCDPath": "...", # Optional, mutually exclusive with userCCD.
"dialect": "alphafold3", # Required.
"version": 4 # Required.
}
```
The fields specify the following:
* `name: str`: The name of the job. A sanitised version of this name is used
for naming the output files.
* `modelSeeds: list[int]`: A list of integer random seeds. The pipeline and
the model will be invoked with each of the seeds in the list. I.e. if you
provide *n* random seeds, you will get *n* predicted structures, each with
the respective random seed. You must provide at least one random seed.
* `sequences: list[Protein | RNA | DNA | Ligand]`: A list of sequence
dictionaries, each defining a molecular entity, see below.
* `bondedAtomPairs: list[Bond]`: An optional list of covalently bonded atoms.
These can link atoms within an entity, or across two entities. See more
below.
* `userCCD: str`: An optional string with user-provided chemical components
dictionary. This is an expert mode for providing custom molecules when
SMILES is not sufficient. This should also be used when you have a custom
molecule that needs to be bonded with other entities - SMILES can't be used
in such cases since it doesn't give the possibility of uniquely naming all
atoms. It can also be used to provide a reference conformer for cases where
RDKit fails to generate a conformer. See more below.
* `userCCDPath: str`: An optional path to a file that contains the
user-provided chemical components dictionary instead of providing it inline
using the `userCCD` field. The path can be either absolute, or relative to
the input JSON path. The file must be in the
[CCD mmCIF format](https://www.wwpdb.org/data/ccd#mmcifFormat), and could be
either plain text, or compressed using gzip, xz, or zstd.
* `dialect: str`: The dialect of the input JSON. This must be set to
`alphafold3`. See
[AlphaFold Server JSON Compatibility](#alphafold-server-json-compatibility)
for more information.
* `version: int`: The version of the input JSON. This must be set to 1, 2, 3, or 4.
See
[AlphaFold Server JSON Compatibility](#alphafold-server-json-compatibility)
and [versions](#versions) below for more information.
## Versions
The top-level `version` field (for the `alphafold3` dialect) can be `1`, `2`,
`3`, or `4`. The following features have been added in respective versions:
* `1`: the initial AlphaFold 3 input format.
* `2`: added the option of specifying external MSA and templates using newly
added fields `unpairedMsaPath`, `pairedMsaPath`, and `mmcifPath`.
* `3`: added the option of specifying external user-provided CCD using newly
added field `userCCDPath`.
* `4`: added the option of specifying textual `description` of protein chains,
RNA chains, DNA chains, or ligands.
## Sequences
The `sequences` section specifies the protein chains, RNA chains, DNA chains,
and ligands. Every entity in `sequences` must have a unique ID. IDs don't have
to be sorted alphabetically.
### Protein
Specifies a single protein chain.
```json
{
"protein": {
"id": "A",
"sequence": "PVLSCGEWQL",
"modifications": [
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ..., # Mutually exclusive with unpairedMsa.
"pairedMsa": ..., # Mutually exclusive with pairedMsaPath.
"pairedMsaPath": ..., # Mutually exclusive with pairedMsa.
"templates": [...]
}
}
```
The fields specify the following:
* `id: str | list[str]`: An uppercase letter or multiple letters specifying
the unique IDs for each copy of this protein chain. The IDs are then also
used in the output mmCIF file. Specifying a list of IDs (e.g. `["A", "B",
"C"]`) implies a homomeric chain with multiple copies.
* `sequence: str`: The amino-acid sequence, specified as a string that uses
the 1-letter standard amino acid codes.
* `modifications: list[ProteinModification]`: An optional list of
post-translational modifications. Each modification is specified using its
CCD code and 1-based residue position. In the example above, we see that the
first residue won't be a proline (`P`) but instead `HY3`.
* `description: str`: An optional textual description of this chain. This
field is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format (equivalent to the FASTA format, but
also allows gaps denoted by the hyphen `-` character). See more details
below.
* `unpairedMsaPath: str`: An optional path to a file that contains the
multiple sequence alignment for this chain instead of providing it inline
using the `unpairedMsa` field. The path can be either absolute, or relative
to the input JSON path. The file must be in the A3M format, and could be
either plain text, or compressed using gzip, xz, or zstd.
* `pairedMsa: str`: We recommend *not* using this optional field and using the
`unpairedMsa` for the purposes of pairing. See more details below.
* `pairedMsaPath: str`: An optional path to a file that contains the multiple
sequence alignment for this chain instead of providing it inline using the
`pairedMsa` field. The path can be either absolute, or relative to the input
JSON path. The file must be in the A3M format, and could be either plain
text, or compressed using gzip, xz, or zstd.
* `templates: list[Template]`: An optional list of structural templates. See
more details below.
### RNA
Specifies a single RNA chain.
```json
{
"rna": {
"id": "A",
"sequence": "AGCU",
"modifications": [
{"modificationType": "2MG", "basePosition": 1},
{"modificationType": "5MC", "basePosition": 4}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ... # Mutually exclusive with unpairedMsa.
}
}
```
The fields specify the following:
* `id: str | list[str]`: An uppercase letter or multiple letters specifying
the unique IDs for each copy of this RNA chain. The IDs are then also used
in the output mmCIF file. Specifying a list of IDs (e.g. `["A", "B", "C"]`)
implies a homomeric chain with multiple copies.
* `sequence: str`: The RNA sequence, specified as a string using only the
letters `A`, `C`, `G`, `U`.
* `modifications: list[RnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format. See more details below.
* `unpairedMsaPath: str`: An optional path to a file that contains the
multiple sequence alignment for this chain instead of providing it inline
using the `unpairedMsa` field. The path can be either absolute, or relative
to the input JSON path. The file must be in the A3M format, and could be
either plain text, or compressed using gzip, xz, or zstd.
### DNA
Specifies a single DNA chain.
```json
{
"dna": {
"id": "A",
"sequence": "GACCTCT",
"modifications": [
{"modificationType": "6OG", "basePosition": 1},
{"modificationType": "6MA", "basePosition": 2}
],
"description": ... # Optional.
}
}
```
The fields specify the following:
* `id: str | list[str]`: An uppercase letter or multiple letters specifying
the unique IDs for each copy of this DNA chain. The IDs are then also used
in the output mmCIF file. Specifying a list of IDs (e.g. `["A", "B", "C"]`)
implies a homomeric chain with multiple copies.
* `sequence: str`: The DNA sequence, specified as a string using only the
letters `A`, `C`, `G`, `T`.
* `modifications: list[DnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field is only used in the JSON format and serves as a comment
describing this chain.
### Ligands
Specifies a single ligand. Ligands can be specified using 3 different formats:
1. [CCD code(s)](https://www.wwpdb.org/data/ccd). This is the easiest way to
specify ligands. Supports specifying covalent bonds to other entities. CCD
from 2022-09-28 is used. If multiple CCD codes are specified, you may want
to specify a bond between these and/or a bond to some other entity. See the
[bonds](#bonds) section below.
2. [SMILES string](https://en.wikipedia.org/wiki/Simplified_Molecular_Input_Line_Entry_System).
This enables specifying ligands that are not in CCD. If using SMILES, you
cannot specify covalent bonds to other entities as these rely on specific
atom names - see the next option for what to use for this case.
3. User-provided CCD + custom ligand codes. This enables specifying ligands not
in CCD, while also supporting specification of covalent bonds to other
entities and backup reference coordinates for when RDKit fails to generate a
conformer. This offers the most flexibility, but also requires careful
attention to get all of the details right.
```json
{
"ligand": {
"id": ["G", "H", "I"],
"ccdCodes": ["ATP"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "J",
"ccdCodes": ["LIG-1337"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "K",
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2",
"description": ... # Optional.
}
}
```
The fields specify the following:
* `id: str | list[str]`: An uppercase letter (or multiple letters) specifying
the unique ID of this ligand. This ID is then also used in the output mmCIF
file. Specifying a list of IDs (e.g. `["A", "B", "C"]`) implies a ligand
that has multiple copies.
* `ccdCodes: list[str]`: An optional list of CCD codes. These could be either
standard CCD codes, or custom codes pointing to the
[user-provided CCD](#user-provided-ccd).
* `smiles: str`: An optional string defining the ligand using a SMILES string.
The SMILES string must be correctly JSON-escaped.
* `description: str`: An optional textual description of this ligand. This
field is only used in the JSON format and serves as a comment
describing this ligand.
Each ligand may be specified using CCD codes or SMILES but not both, i.e. for a
given ligand, the `ccdCodes` and `smiles` fields are mutually exclusive.
#### SMILES string JSON escaping
The SMILES string must be correctly JSON-escaped, in particular the backslash
character must be escaped as two backslashes, otherwise the JSON parser will
fail with a `JSONDecodeError`. For instance, the following SMILES string
`CCC[C@@H](O)CC\C=C\C=C\C#CC#C\C=C\CO` has to be specified as:
```json
{
"ligand": {
"id": "A",
"smiles": "CCC[C@@H](O)CC\\C=C\\C=C\\C#CC#C\\C=C\\CO"
}
}
```
You can JSON-escape the SMILES string using the
[`jq`](https://github.com/jqlang/jq) command-line tool which should be easily
installable on most Linux systems:
```bash
jq -R . <<< 'CCC[C@@H](O)CC\C=C\C=C\C#CC#C\C=C\CO' # Replace with your SMILES.
```
Alternatively, you can use this Python code:
```python
import json
smiles = r'CCC[C@@H](O)CC\C=C\C=C\C#CC#C\C=C\CO' # Replace with your SMILES.
print(json.dumps(smiles))
```
#### Reference structure construction with SMILES
For some ligands and some random seeds, RDKit might fail to generate a
conformer, indicated by the `Failed to construct RDKit reference structure`
error message. In this case, you can either provide a reference structure for
the ligand using the [user-provided CCD Format](#user-provided-ccd-format), or
try increasing the number of RDKit conformer iterations using the
`--conformer_max_iterations=...` flag.
### Ions
Ions are treated as ligands, e.g. a magnesium ion would simply be a ligand with
`ccdCodes: ["MG"]`.
## Multiple Sequence Alignment
Protein and RNA chains allow setting a custom Multiple Sequence Alignment (MSA).
If not set, the data pipeline will automatically build MSAs for protein and RNA
entities using Jackhmmer/Nhmmer search over genetic databases as described in
the paper.
### RNA Multiple Sequence Alignment
RNA `unpairedMsa` can be either:
1. Unset (or set explicitly to `null`). AlphaFold 3 will build MSA for this RNA
chain automatically. This is the recommended option.
2. Set to an empty string (`""`). AlphaFold 3 won't build the MSA for this RNA
chain and the MSA input to the model will be just the RNA chain (equivalent
to running MSA-free for this RNA chain).
3. Set to a non-empty A3M string. AlphaFold 3 will use the provided MSA for
this RNA chain.
### Protein Multiple Sequence Alignment
For protein chains, the situation is slightly more complicated due to paired and
unpaired MSA (see [MSA Pairing](#msa-pairing) below for more details).
The following combinations are valid for a given protein chain:
1. Both `unpairedMsa` and `pairedMsa` fields are unset (or set explicitly to
`null`), AlphaFold 3 will build both MSAs automatically. This is the
recommended option.
2. The `unpairedMsa` is set to a non-empty A3M string, `pairedMsa` set to an
empty string (`""`). AlphaFold 3 won't build MSA, will use the `unpairedMsa`
as is and run `pairedMSA`-free.
3. The `pairedMsa` is set to a non-empty A3M string, `unpairedMsa` set to an
empty string (`""`). AlphaFold 3 won't build MSA, will use the `pairedMsa`
and run `unpairedMSA`-free. **This option is not recommended**, see
[MSA Pairing](#msa-pairing) below.
4. Both `unpairedMsa` and `pairedMsa` fields are set to an empty string (`""`).
AlphaFold 3 will not build the MSA and the MSA input to the model will be
just the query sequence (equivalent to running completely MSA-free).
5. Both `unpairedMsa` and `pairedMsa` fields are set to a custom non-empty A3M
string, AlphaFold 3 will use the provided MSA instead of building one as
part of the data pipeline. This is considered an expert option.
Note that both `unpairedMsa` and `pairedMsa` have to either be *both* set (i.e.
non-`null`), or both unset (i.e. both `null`, explicitly or implicitly).
Typically, when setting `unpairedMsa`, you will set the `pairedMsa` to an empty
string (`""`). For example, this will run the protein chain A with the given
MSA, but without any templates (template-free):
```json
{
"protein": {
"id": "A",
"sequence": ...,
"unpairedMsa": "The A3M you want to run with",
"pairedMsa": "",
"templates": []
}
}
```
When setting your own MSA, you have to make sure that:
1. The MSA is in the A3M format. This means adhering to the FASTA format while
also allowing lowercase characters denoting inserted residues and hyphens
(`-`) denoting gaps in sequences.
2. The first sequence is exactly equal to the query sequence.
3. If all insertions are removed from MSA hits (i.e. all lowercase letters are
removed), all sequences have exactly the same length as the query (they form
an exact rectangular matrix).
### MSA Pairing
MSA pairing matters only when folding multiple chains (multimers), since we need
to find a way to concatenate MSAs for the individual chains along the sequence
dimension. If done naively, by simply concatenating the individual MSA matrices
along the sequence dimension and padding so that all MSAs have the same depth,
one can end up with rows in the concatenated MSA that are formed by sequences
from different organisms.
It may be desirable to ensure that across multiple chains, sequences in the MSA
that are from the same organism end up in the same MSA row. AlphaFold 3
internally achieves this by looking for the UniProt organism ID in the
`pairedMsa` and pairing sequences based on this information.
We recommend users do the pairing manually or use the output of appropriate
software and then provide the MSA using only the `unpairedMsa` field. This
method gives exact control over the placement of each sequence in the MSA, as
opposed to relying on name-matching post-processing heuristics used for
`pairedMsa`.
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
an empty string (`""`).
Make sure to run with `--resolve_msa_overlaps=false`. This prevents
deduplication of the unpaired MSA within each chain against the paired MSA
sequences. Even if you set `pairedMsa` to an empty string, the query sequence(s)
will still be added in there and the deduplication procedure could destroy the
carefully crafted sequence positioning in the unpaired MSA.
For instance, if there are two chains `DEEP` and `MIND` which we want to be
paired on organisms A and C, we can achieve it as follows:
```txt
> query
DEEP
> match 1 (organism A)
D--P
> match 2 (organism B)
DD-P
> match 3 (organism C)
DD-P
```
```txt
> query
MIND
> match 1 (organism A)
M--D
> Empty hit to make sure pairing is achieved
----
> match 2 (organism C)
MIN-
```
The resulting MSA when chains are concatenated will then be:
```txt
> query
DEEPMIND
> match 1 + match 1
D--PM--D
> match 2 + padding
DD-P----
> match 3 + match 2
DD-PMIN-
```
## Structural Templates
Structural templates can be specified only for protein chains:
```json
"templates": [
{
"mmcif": ..., # Mutually exclusive with mmcifPath.
"mmcifPath": ..., # Mutually exclusive with mmcif.
"queryIndices": [0, 1, 2, 4, 5, 6],
"templateIndices": [0, 1, 2, 3, 4, 8]
}
]
```
The fields specify the following:
* `mmcif: str`: A string containing the single chain protein structural
template in the mmCIF format.
* `mmcifPath: str`: An optional path to a file that contains the mmCIF with
the structural template instead of providing it inline using the `mmcif`
field. The path can be either absolute, or relative to the input JSON path.
The file must be in the mmCIF format, and could be either plain text, or
compressed using gzip, xz, or zstd.
* `queryIndices: list[int]`: 0-based indices in the query sequence, defining
the mapping from query residues to template residues.
* `templateIndices: list[int]`: 0-based indices in the template sequence,
specifying the mapping from query residues to template residues defined in
the mmCIF file. Note that unresolved mmCIF residues must be taken into
account when specifying template indices.
A template is specified as an mmCIF string containing a single chain with the
structural template together with a 0-based mapping that maps query residue
indices to the template residue indices. The mapping is specified using two
lists of the same length. E.g. to express a mapping `{0: 0, 1: 2, 2: 5, 3: 6}`,
you would specify the two indices lists as:
```json
"queryIndices": [0, 1, 2, 3],
"templateIndices": [0, 2, 5, 6]
```
Note that mmCIFs can have residues with missing atom coordinates (present in
residue tables but missing in the `_atom_site` table) – these must be taken into
account when specifying template indices. E.g. to align residues 4–7 in a
template with unresolved residues 1, 2, 3 and resolved residues 4, 5, 6, 7, you
need to set the template indices to 3, 4, 5, 6 (since 0-based indexing is used).
An example of a protein with unresolved residues 1–20 can be found here:
https://www.rcsb.org/structure/8UXY.
You can provide multiple structural templates. Note that if an mmCIF containing
more than one chain is provided, you will get an error since it is not possible
to determine which of the chains should be used as the template.
You can run template-free (but still run genetic search and build MSA) by
setting templates to `[]` and either explicitly setting both `unpairedMsa` and
`pairedMsa` to `null`:
```json
"protein": {
"id": "A",
"sequence": ...,
"pairedMsa": null,
"unpairedMsa": null,
"templates": []
}
```
Or you can simply fully omit them:
```json
"protein": {
"id": "A",
"sequence": ...,
"templates": []
}
```
You can also run with pre-computed MSA, but let AlphaFold 3 search for
templates. This can be achieved by setting `unpairedMsa` and `pairedMsa`, but
keeping templates unset (or set to `null`). The profile given as an input to
Hmmsearch when searching for templates will be built from the provided
`unpairedMsa`:
```json
"protein": {
"id": "A",
"sequence": ...,
"unpairedMsa": ...,
"pairedMsa": ...,
"templates": null
}
```
Or you can simply fully omit the `templates` field, thus setting it implicitly to
`null`:
```json
"protein": {
"id": "A",
"sequence": ...,
"unpairedMsa": ...,
"pairedMsa": ...
}
```
## Bonds
To manually specify covalent bonds, use the `bondedAtomPairs` field. This is
intended for modelling covalent ligands, and for defining multi-CCD ligands
(e.g. glycans). Defining covalent bonds between or within polymer entities is
not currently supported.
Bonds are specified as pairs of (source atom, destination atom), with each atom
being uniquely addressed using 3 fields:
* **Entity ID** (`str`): this corresponds to the `id` field for that entity.
* **Residue ID** (`int`): this is 1-based residue index *within* the chain.
For single-residue ligands, this is simply set to 1.
* **Atom name** (`str`): this is the unique atom name *within* the given
residue. The atom name for protein/RNA/DNA residues or CCD ligands can be
looked up in the CCD for the given chemical component. This also explains
why SMILES ligands don't support bonds: there is no atom name that could be
used to define the bond. This shortcoming can be addressed by using the
user-provided CCD format (see below).
The example below shows two bonds:
```json
"bondedAtomPairs": [
[["A", 145, "SG"], ["L", 1, "C04"]],
[["J", 1, "O6"], ["J", 2, "C1"]]
]
```
The first bond is between chain A, residue 145, atom SG and chain L, residue 1,
atom C04. This is a typical example for a covalent ligand. The second bond is
between chain J, residue 1, atom O6 and chain J, residue 2, atom C1. This bond
is within the same entity and is a typical example when defining a glycan.
All bonds are implicitly assumed to be covalent bonds. Other bond types are not
supported.
### Defining Glycans
Glycans are bound to a protein residue, and they are typically formed of
multiple chemical components. To define a glycan, define a new ligand with all
of the chemical components of the glycan. Then define a bond that links the
glycan to the protein residue, and all bonds that are within the glycan between
its individual chemical components.
For example, to define the following glycan composed of 4 components (CMP1,
CMP2, CMP3, CMP4) bound to an asparagine in a protein chain A:
```
⋮
ALA CMP4
| |
ASN ―― CMP1 ―― CMP2
| |
ALA CMP3
⋮
```
You will need to specify:
1. Protein chain A.
2. Ligand chain B with the 4 components.
3. Bonds ASN-CMP1, CMP1-CMP2, CMP2-CMP3, CMP2-CMP4.
## User-provided CCD
There are two approaches to model a custom ligand not defined in the CCD:
1. If the ligand is not bonded to other entities, it can be defined using a
[SMILES string](https://en.wikipedia.org/wiki/Simplified_Molecular_Input_Line_Entry_System).
2. If it is bonded to other entities, or to be able to customise relevant
features (such as bond orders, atom names and ideal coordinates used when
conformer generation fails), it is necessary to define that particular
ligand using the
[CCD mmCIF format](https://www.wwpdb.org/data/ccd#mmcifFormat).
Note that if a full CCD mmCIF is provided, any SMILES string input as part of
that mmCIF is ignored.
Once defined, this ligand needs to be assigned a name that doesn't clash with
existing CCD ligand names (e.g. `LIG-1`). Avoid underscores (`_`) in the name,
as it could cause issues in the mmCIF format.
The newly defined ligand can then be used as a standard CCD ligand using its
custom name, and bonds can be linked to it using its named atom scheme.
### Conformer Generation
The data pipeline attempts to generate a conformer for ligands using RDKit. The
`Mol` used to generate the conformer is constructed either from the information
provided in the CCD mmCIF, or from the SMILES string if that is the only
information provided.
If conformer generation fails, the model will fall back to using the ideal
coordinates in the CCD mmCIF if these are provided. If they are not provided,
the model will use the reference coordinates if the last modification date given
in the CCD mmCIF is prior to the training cutoff date. If no coordinates can be
found in this way, all conformer coordinates are set to zero and the model will
output `NaN` (`null` in the output JSON) confidences for the ligand.
Note that sometimes conformer generation failures can be resolved by
increasing the number of RDKit conformer iterations using the
`--conformer_max_iterations=...` flag.
### User-provided CCD Format
The user-provided CCD must be passed either:
* In the `userCCD` field (in the root of the input JSON) as a string. Note
that JSON doesn't allow newlines within strings, so newline characters
(`\n`) must be used to delimit lines. Single rather than double quotes
should also be used around strings like the chemical formula.
* In the `userCCDPath` field, as a path to a file that contains the
user-provided chemical components dictionary. The path can be either
absolute, or relative to the input JSON path. The file must be in the
[CCD mmCIF format](https://www.wwpdb.org/data/ccd#mmcifFormat), and could be
either plain text, or compressed using gzip, xz, or zstd.
The main pieces of information used are the atom names and elements, bonds, and
also the ideal coordinates (`pdbx_model_Cartn_{x,y,z}_ideal`) which essentially
serve as a structural template for the ligand if RDKit fails to generate
conformers for that ligand.
The user-provided CCD can also be used to redefine standard chemical components
in the CCD. This can be useful if you need to redefine the ideal coordinates.
Below is an example user-provided CCD redefining component X7F, which serves to
illustrate the required sections. For readability purposes, newlines have not
been replaced by `\n`.
```
data_MY-X7F
#
_chem_comp.id MY-X7F
_chem_comp.name '5,8-bis(oxidanyl)naphthalene-1,4-dione'
_chem_comp.type non-polymer
_chem_comp.formula 'C10 H6 O4'
_chem_comp.mon_nstd_parent_comp_id ?
_chem_comp.pdbx_synonyms ?
_chem_comp.formula_weight 190.152
#
loop_
_chem_comp_atom.comp_id
_chem_comp_atom.atom_id
_chem_comp_atom.type_symbol
_chem_comp_atom.charge
_chem_comp_atom.pdbx_leaving_atom_flag
_chem_comp_atom.pdbx_model_Cartn_x_ideal
_chem_comp_atom.pdbx_model_Cartn_y_ideal
_chem_comp_atom.pdbx_model_Cartn_z_ideal
MY-X7F C02 C 0 N -1.418 -1.260 0.018
MY-X7F C03 C 0 N -0.665 -2.503 -0.247
MY-X7F C04 C 0 N 0.677 -2.501 -0.235
MY-X7F C05 C 0 N 1.421 -1.257 0.043
MY-X7F C06 C 0 N 0.706 0.032 0.008
MY-X7F C07 C 0 N -0.706 0.030 -0.004
MY-X7F C08 C 0 N -1.397 1.240 -0.037
MY-X7F C10 C 0 N -0.685 2.443 -0.057
MY-X7F C11 C 0 N 0.679 2.445 -0.045
MY-X7F C12 C 0 N 1.394 1.243 -0.013
MY-X7F O01 O 0 N -2.611 -1.301 0.247
MY-X7F O09 O 0 N -2.752 1.249 -0.049
MY-X7F O13 O 0 N 2.750 1.257 -0.001
MY-X7F O14 O 0 N 2.609 -1.294 0.298
MY-X7F H1 H 0 N -1.199 -3.419 -0.452
MY-X7F H2 H 0 N 1.216 -3.416 -0.429
MY-X7F H3 H 0 N -1.221 3.381 -0.082
MY-X7F H4 H 0 N 1.212 3.384 -0.062
MY-X7F H5 H 0 N -3.154 1.271 0.830
MY-X7F H6 H 0 N 3.151 1.241 -0.880
#
loop_
_chem_comp_bond.atom_id_1
_chem_comp_bond.atom_id_2
_chem_comp_bond.value_order
_chem_comp_bond.pdbx_aromatic_flag
O01 C02 DOUB N
O09 C08 SING N
C02 C03 SING N
C02 C07 SING N
C03 C04 DOUB N
C08 C07 DOUB Y
C08 C10 SING Y
C07 C06 SING Y
C10 C11 DOUB Y
C04 C05 SING N
C06 C05 SING N
C06 C12 DOUB Y
C11 C12 SING Y
C05 O14 DOUB N
C12 O13 SING N
C03 H1 SING N
C04 H2 SING N
C10 H3 SING N
C11 H4 SING N
O09 H5 SING N
O13 H6 SING N
#
```
### Mandatory fields
Parsing the user-provided CCD needs only a subset of the fields that CCD uses.
The mandatory fields are described below. Refer to
[CCD documentation](https://www.wwpdb.org/data/ccd#mmcifFormat) for more
detailed explanation of each field. Note that not all of these fields are input
to the model, but they are necessary for the data pipeline to run – see the
[Model input fields](#model-input-fields) section below.
**Singular fields (containing just a single value)**
* `_chem_comp.id`: The ID of the component. Must match the `_data` record and
must not contain special CIF characters (like `_` or `#`).
* `_chem_comp.name`: Optional full name of the component. If unknown, set to
`?`.
* `_chem_comp.type`: Type of the component, typically `non-polymer`.
* `_chem_comp.formula`: Optional component formula. If unknown, set to `?`.
* `_chem_comp.mon_nstd_parent_comp_id`: Optional parent component ID. If
unknown, set to `?`.
* `_chem_comp.pdbx_synonyms`: Optional synonym IDs. If unknown, set to `?`.
* `_chem_comp.formula_weight`: Optional weight of the component. If unknown,
set to `?`.
**Per-atom fields (containing one record per atom)**
* `_chem_comp_atom.comp_id`: Component ID.
* `_chem_comp_atom.atom_id`: Atom ID.
* `_chem_comp_atom.type_symbol`: Atom element type.
* `_chem_comp_atom.charge`: Atom charge.
* `_chem_comp_atom.pdbx_leaving_atom_flag`: Optional flag determining whether
this is a leaving atom. If unset, assumed to be no (`N`) for all atoms.
* `_chem_comp_atom.pdbx_model_Cartn_x_ideal`: Ideal x coordinate.
* `_chem_comp_atom.pdbx_model_Cartn_y_ideal`: Ideal y coordinate.
* `_chem_comp_atom.pdbx_model_Cartn_z_ideal`: Ideal z coordinate.
**Per-bond fields (containing one record per bond)**
* `_chem_comp_bond.atom_id_1`: The ID of the first of the two atoms that
define the bond.
* `_chem_comp_bond.atom_id_2`: The ID of the second of the two atoms that
define the bond.
* `_chem_comp_bond.value_order`: The bond order of the chemical bond
associated with the specified atoms.
* `_chem_comp_bond.pdbx_aromatic_flag`: Whether the bond is aromatic.
### Model input fields
The following fields are used to generate input for the model:
* `_chem_comp_atom.atom_id`: Atom ID.
* `_chem_comp_atom.type_symbol`: Atom element type.
* `_chem_comp_atom.charge`: Atom charge.
* `_chem_comp_atom.pdbx_model_Cartn_x_ideal`: Ideal x coordinate. Only used if
conformer generation fails.
* `_chem_comp_atom.pdbx_model_Cartn_y_ideal`: Ideal y coordinate. Only used if
conformer generation fails.
* `_chem_comp_atom.pdbx_model_Cartn_z_ideal`: Ideal z coordinate. Only used if
conformer generation fails.
* `_chem_comp_bond.atom_id_1`: The ID of the first of the two atoms that
define the bond.
* `_chem_comp_bond.atom_id_2`: The ID of the second of the two atoms that
define the bond.
## Full Example
An example illustrating all the aspects of the input format is provided below.
Note that AlphaFold 3 won't run this input out of the box as it abbreviates
certain fields and the sequences are not biologically meaningful.
```json
{
"name": "Hello fold",
"modelSeeds": [10, 42],
"sequences": [
{
"protein": {
"id": "A",
"sequence": "PVLSCGEWQL",
"modifications": [
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": "10-residue protein with 2 modifications",
"unpairedMsa": ...,
"pairedMsa": ""
}
},
{
"protein": {
"id": "B",
"sequence": "RPACQLW",
"templates": [
{
"mmcif": ...,
"queryIndices": [0, 1, 2, 4, 5, 6],
"templateIndices": [0, 1, 2, 3, 4, 8]
}
]
}
},
{
"dna": {
"id": "C",
"sequence": "GACCTCT",
"modifications": [
{"modificationType": "6OG", "basePosition": 1},
{"modificationType": "6MA", "basePosition": 2}
]
}
},
{
"rna": {
"id": "E",
"sequence": "AGCU",
"modifications": [
{"modificationType": "2MG", "basePosition": 1},
{"modificationType": "5MC", "basePosition": 4}
],
"unpairedMsa": ...
}
},
{
"ligand": {
"id": ["F", "G", "H"],
"ccdCodes": ["ATP"]
}
},
{
"ligand": {
"id": "I",
"ccdCodes": ["NAG", "FUC"]
}
},
{
"ligand": {
"id": "Z",
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2"
}
}
],
"bondedAtomPairs": [
[["A", 1, "CA"], ["G", 1, "CHA"]],
[["I", 1, "O6"], ["I", 2, "C1"]]
],
"userCCD": ...,
"dialect": "alphafold3",
"version": 4
}
```
================================================
FILE: docs/installation.md
================================================
# Installation and Running Your First Prediction
You will need a machine running Linux; AlphaFold 3 does not support other
operating systems. Full installation requires up to 1 TB of disk space to keep
genetic databases (SSD storage is recommended) and an NVIDIA GPU with Compute
Capability 8.0 or greater (GPUs with more memory can predict larger protein
structures). We have verified that inputs with up to 5,120 tokens can fit on a
single NVIDIA A100 80 GB, or a single NVIDIA H100 80 GB. We have verified
numerical accuracy on both NVIDIA A100 and H100 GPUs.
Especially for long targets, the genetic search stage can consume a lot of RAM –
we recommend running with at least 64 GB of RAM.
We provide installation instructions for a machine with an NVIDIA A100 80 GB GPU
and a clean Ubuntu 22.04 LTS installation, and expect that these instructions
should aid others with different setups. If you are installing locally outside
of a Docker container, please ensure CUDA, cuDNN, and JAX are correctly
installed; the
[JAX installation documentation](https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu)
is a useful reference for this case. Please note that the Docker container
requires that the host machine has CUDA 12.6 installed.
The instructions provided below describe how to:
1. Provision a machine on GCP.
1. Install Docker.
1. Install NVIDIA drivers for an A100.
1. Obtain genetic databases.
1. Obtain model parameters.
1. Build the AlphaFold 3 Docker container or Singularity image.
## Provisioning a Machine
Clean Ubuntu images are available on Google Cloud, AWS, Azure, and other major
platforms.
Using an existing Google Cloud project, we provisioned a new machine:
* We recommend using `--machine-type a2-ultragpu-1g` but feel free to use
`--machine-type a2-highgpu-1g` for smaller predictions.
* If desired, replace `--zone us-central1-a` with a zone that has quota for
the machine you have selected. See
[gpu-regions-zones](https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).
```sh
gcloud compute instances create alphafold3 \
--machine-type a2-ultragpu-1g \
--zone us-central1-a \
--image-family ubuntu-2204-lts \
--image-project ubuntu-os-cloud \
--maintenance-policy TERMINATE \
--boot-disk-size 1000 \
--boot-disk-type pd-balanced
```
This provisions a bare Ubuntu 22.04 LTS image on an
[A2 Ultra](https://cloud.google.com/compute/docs/accelerator-optimized-machines#a2-vms)
machine with 12 CPUs, 170 GB RAM, a 1 TB disk, and an NVIDIA A100 80 GB GPU
attached.
We verified the following installation steps from this point.
## Installing Docker
These instructions are for rootless Docker.
### Installing Docker on Host
Note that these instructions only apply to Ubuntu 22.04 LTS images (see above).
Add Docker's official GPG key. Official Docker instructions are
[here](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository).
The commands we ran are:
```sh
sudo apt-get update
sudo apt-get install ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc
```
Add the repository to apt sources:
```sh
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
sudo docker run hello-world
```
### Enabling Rootless Docker
Official Docker instructions are
[here](https://docs.docker.com/engine/security/rootless/#distribution-specific-hint).
The commands we ran are:
```sh
sudo apt-get install -y uidmap systemd-container
sudo machinectl shell $(whoami)@ /bin/bash -c 'dockerd-rootless-setuptool.sh install && sudo loginctl enable-linger $(whoami) && DOCKER_HOST=unix:///run/user/$(id -u)/docker.sock docker context use rootless'
```
## Installing GPU Support
### Installing NVIDIA Drivers
Official Ubuntu instructions are
[here](https://documentation.ubuntu.com/server/how-to/graphics/install-nvidia-drivers/).
The commands we ran are:
```sh
sudo apt-get -y install alsa-utils ubuntu-drivers-common
sudo ubuntu-drivers install
sudo nvidia-smi --gpu-reset
nvidia-smi # Check that the drivers are installed.
```
Accept the "Pending kernel upgrade" dialog if it appears.
You will need to reboot the instance with `sudo reboot now` to reset the GPU if
you see the following warning:
```text
NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver.
Make sure that the latest NVIDIA driver is installed and running.
```
Proceed only if `nvidia-smi` has a sensible output.
### Installing NVIDIA Support for Docker
Official NVIDIA instructions are
[here](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
The commands we ran are:
```sh
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
&& curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
nvidia-ctk runtime configure --runtime=docker --config=$HOME/.config/docker/daemon.json
systemctl --user restart docker
sudo nvidia-ctk config --set nvidia-container-cli.no-cgroups --in-place
```
Check that your container can see the GPU:
```sh
docker run --rm --gpus all nvidia/cuda:12.6.0-base-ubuntu22.04 nvidia-smi
```
Example output:
```text
Mon Nov 11 12:00:00 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-SXM4-80GB Off | 00000000:00:05.0 Off | 0 |
| N/A 34C P0 51W / 400W | 1MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
```
## Obtaining AlphaFold 3 Source Code
Install `git` and download the AlphaFold 3 repository:
```sh
git clone https://github.com/google-deepmind/alphafold3.git
```
## Obtaining Genetic Databases
This step requires `wget` and `zstd` to be installed on your machine. On
Debian-based systems install them by running `sudo apt install wget zstd`.
AlphaFold 3 needs multiple genetic (sequence) protein and RNA databases to run:
* [BFD small](https://bfd.mmseqs.com/)
* [MGnify](https://www.ebi.ac.uk/metagenomics/)
* [PDB](https://www.rcsb.org/) (structures in the mmCIF format)
* [PDB seqres](https://www.rcsb.org/)
* [UniProt](https://www.uniprot.org/uniprot/)
* [UniRef90](https://www.uniprot.org/help/uniref)
* [NT](https://www.ncbi.nlm.nih.gov/nucleotide/)
* [RFam](https://rfam.org/)
* [RNACentral](https://rnacentral.org/)
We provide a bash script `fetch_databases.sh` that can be used to download and
set up all of these databases. This process takes around 45 minutes when not
installing on local SSD. We recommend running the following in a `screen` or
`tmux` session as downloading and decompressing the databases takes some time.
```sh
cd alphafold3 # Navigate to the directory with cloned AlphaFold 3 repository.
./fetch_databases.sh [<DB_DIR>]
```
This script downloads the databases from a mirror hosted on GCS, with all
versions being the same as used in the AlphaFold 3 paper, to the directory
`<DB_DIR>`. If not specified, the default `<DB_DIR>` is
`$HOME/public_databases`.
:ledger: **Note: The download directory `<DB_DIR>` should *not* be a
subdirectory in the AlphaFold 3 repository directory.** If it is, the Docker
build will be slow as the large databases will be copied during the image
creation.
:ledger: **Note: The total download size for the full databases is around 252 GB
and the total size when unzipped is 630 GB. Please make sure you have sufficient
hard drive space, bandwidth, and time to download. We recommend using an SSD for
better genetic search performance.**
:ledger: **Note: If the download directory and datasets don't have full read and
write permissions, it can cause errors with the MSA tools, with opaque
(external) error messages. Please ensure the required permissions are applied,
e.g. with the `sudo chmod 755 --recursive <DB_DIR>` command.**
Once the script has finished, you should have the following directory structure:
```sh
mmcif_files/ # Directory containing ~200k PDB mmCIF files.
bfd-first_non_consensus_sequences.fasta
mgy_clusters_2022_05.fa
nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta
pdb_seqres_2022_09_28.fasta
rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta
rnacentral_active_seq_id_90_cov_80_linclust.fasta
uniprot_all_2021_04.fa
uniref90_2022_05.fa
```
Optionally, after the script finishes, you may want to copy the databases to an
SSD. You can use these two scripts:
* `src/scripts/gcp_mount_ssd.sh [<SSD_MOUNT_PATH>]` formats an unmounted GCP
SSD drive and mounts it at the specified path. It skips formatting if the disk
is already formatted, and skips mounting if it is already mounted. The default
`<SSD_MOUNT_PATH>` is `/mnt/disks/ssd`.
* `src/scripts/copy_to_ssd.sh [<DB_DIR>] [<SSD_DB_DIR>]` copies as many of the
database files as will fit onto the SSD. The default `<DB_DIR>` is
`$HOME/public_databases` and must match the path used in the
`fetch_databases.sh` command above; the default `<SSD_DB_DIR>` is
`/mnt/disks/ssd/public_databases`.
## Obtaining Model Parameters
To request access to the AlphaFold 3 model parameters, please complete
[this form](https://forms.gle/svvpY4u2jsHEwWYS6). Access will be granted at
Google DeepMind’s sole discretion. We will aim to respond to requests within 2–3
business days. You may only use AlphaFold 3 model parameters if received
directly from Google. Use is subject to these
[terms of use](https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md).
Once access has been granted, download the model parameters to a directory of
your choosing, referred to as `<MODEL_PARAMETERS_DIR>` in the following
instructions. As with the databases, this should *not* be a subdirectory in the
AlphaFold 3 repository directory.
## Building the Docker Container That Will Run AlphaFold 3
Then, build the Docker container. This builds a container with all the right
Python dependencies:
```sh
docker build -t alphafold3 -f docker/Dockerfile .
```
If you hit `No file descriptors available (os error 24)` on systems like
AlmaLinux/Rocky/RHEL, you need to manually expand the file descriptor limits
during the build by appending `--ulimit nofile=65535:65535`:
```sh
docker build --ulimit nofile=65535:65535 -t alphafold3 -f docker/Dockerfile .
```
Create an input JSON file, using either the example in the
[README](https://github.com/google-deepmind/alphafold3?tab=readme-ov-file#installation-and-running-your-first-prediction)
or a
[custom input](https://github.com/google-deepmind/alphafold3/blob/main/docs/input.md),
and place it in a directory, e.g. `$HOME/af_input`. You can now run AlphaFold 3!
```sh
docker run -it \
--volume $HOME/af_input:/root/af_input \
--volume $HOME/af_output:/root/af_output \
--volume <MODEL_PARAMETERS_DIR>:/root/models \
--volume <DATABASES_DIR>:/root/public_databases \
--gpus all \
alphafold3 \
python run_alphafold.py \
--json_path=/root/af_input/fold_input.json \
--model_dir=/root/models \
--output_dir=/root/af_output
```
where `$HOME/af_input` is the directory containing the input JSON file;
`$HOME/af_output` is the directory where the output will be written to; and
`<DATABASES_DIR>` and `<MODEL_PARAMETERS_DIR>` are the directories containing
the databases and model parameters, respectively. The values of these
directories must match the directories used in previous steps for downloading
databases and model weights, and for the input file.
:ledger: Note: You may also need to create the output directory,
`$HOME/af_output`, before running the `docker` command, and make both it and
the input directory writable from the Docker container, e.g. by running
`chmod 755 $HOME/af_input $HOME/af_output`. In most cases `docker` and
`run_alphafold.py` will create the output directory if it does not exist.
:ledger: **Note: In the example above the databases have been placed on the
persistent disk, which is slow.** If you want better genetic and template search
performance, make sure all databases are placed on a local SSD.
If you have some databases on an SSD in the `<SSD_DATABASES_DIR>` directory and
some databases on a slower disk in the `<HDD_DATABASES_DIR>` directory, you can
mount both directories and specify `db_dir` multiple times. This enables fast
access to the databases on the SSD, with a fallback to the larger, slower disk:
```sh
docker run -it \
--volume $HOME/af_input:/root/af_input \
--volume $HOME/af_output:/root/af_output \
--volume <MODEL_PARAMETERS_DIR>:/root/models \
--volume <SSD_DATABASES_DIR>:/root/public_databases \
--volume <HDD_DATABASES_DIR>:/root/public_databases_fallback \
--gpus all \
alphafold3 \
python run_alphafold.py \
--json_path=/root/af_input/fold_input.json \
--model_dir=/root/models \
--db_dir=/root/public_databases \
--db_dir=/root/public_databases_fallback \
--output_dir=/root/af_output
```
If you get an error like the following, make sure the models and databases are
in the correct locations, i.e. at the host paths passed to the `--volume` flags
above.
```
docker: Error response from daemon: error while creating mount source path '/srv/alphafold3_data/models': mkdir /srv/alphafold3_data/models: permission denied.
```
`run_alphafold.py` supports many flags for controlling performance, running on
multiple input files, specifying external binary paths, and more. See
```sh
docker run alphafold3 python run_alphafold.py --help
```
for more information.
## Running Using Singularity Instead of Docker
You may prefer to run AlphaFold 3 within Singularity. You'll still need to
*build* the Singularity image from the Docker container. Afterwards, you will
not have to depend on Docker (at structure prediction time).
### Install Singularity
Official Singularity instructions are
[here](https://docs.sylabs.io/guides/3.3/user-guide/installation.html). The
commands we ran are:
```sh
wget https://github.com/sylabs/singularity/releases/download/v4.2.1/singularity-ce_4.2.1-jammy_amd64.deb
sudo dpkg --install singularity-ce_4.2.1-jammy_amd64.deb
sudo apt-get install -f
```
### Build the Singularity Container From the Docker Image
After building the *Docker* container above with `docker build -t`, start a
local Docker registry and upload your image `alphafold3` to it. Singularity's
instructions are [here](https://github.com/apptainer/singularity/issues/1537).
The commands we ran are:
```sh
docker run -d -p 5000:5000 --restart=always --name registry registry:2
docker tag alphafold3 localhost:5000/alphafold3
docker push localhost:5000/alphafold3
```
Then build the Singularity container:
```sh
SINGULARITY_NOHTTPS=1 singularity build alphafold3.sif docker://localhost:5000/alphafold3:latest
```
You can confirm your build by starting a shell and inspecting the environment.
For example, you may want to ensure the Singularity image can access your GPU.
You may want to restart your computer if you have issues with this.
```sh
singularity exec --nv alphafold3.sif sh -c 'nvidia-smi'
```
You can now run AlphaFold 3!
```sh
singularity exec --nv alphafold3.sif <<input-command>>
```
For example:
```sh
singularity exec \
--nv \
--bind $HOME/af_input:/root/af_input \
--bind $HOME/af_output:/root/af_output \
--bind <MODEL_PARAMETERS_DIR>:/root/models \
--bind <DATABASES_DIR>:/root/public_databases \
alphafold3.sif \
python run_alphafold.py \
--json_path=/root/af_input/fold_input.json \
--model_dir=/root/models \
--db_dir=/root/public_databases \
--output_dir=/root/af_output
```
Or with some databases on an SSD in `<SSD_DATABASES_DIR>` and the rest on a slower disk in `<HDD_DATABASES_DIR>`:
```sh
singularity exec \
--nv \
--bind $HOME/af_input:/root/af_input \
--bind $HOME/af_output:/root/af_output \
--bind <MODEL_PARAMETERS_DIR>:/root/models \
--bind <SSD_DATABASES_DIR>:/root/public_databases \
--bind <HDD_DATABASES_DIR>:/root/public_databases_fallback \
alphafold3.sif \
python run_alphafold.py \
--json_path=/root/af_input/fold_input.json \
--model_dir=/root/models \
--db_dir=/root/public_databases \
--db_dir=/root/public_databases_fallback \
--output_dir=/root/af_output
```
================================================
FILE: docs/known_issues.md
================================================
# Known Issues
## Numerical performance for CUDA Capability 7.x GPUs
All CUDA Capability 7.x GPUs (e.g. V100) produce obviously bad output, with lots
of clashing residues (the clashes cause a ranking score of -99 or lower), unless
the environment variable `XLA_FLAGS` is set to include
`--xla_disable_hlo_passes=custom-kernel-fusion-rewriter`.
## Incorrect handling of two-letter atoms in SMILES ligands
Between commits https://github.com/google-deepmind/alphafold3/commit/f8df1c7 and
https://github.com/google-deepmind/alphafold3/commit/4e4023c, AlphaFold 3
incorrectly handled any two-letter atoms (e.g. Cl, Br) in ligands defined using
SMILES strings.
## MSA discrepancy between AlphaFold 3 and AlphaFold Server
### The root cause of the problem
The released AlphaFold 3 and AlphaFold Server use the same model weights and
equivalent featurisation and model code. However, the way they run genetic
search is slightly different. The released AlphaFold 3 searches each database in
one go, while AlphaFold Server has a sharded version of each database (split
into multiple smaller FASTA files) and searches all of the shards in parallel.
The results of these parallel searches are then merged together at the end.
The discrepancy is caused by a different (deeper) MSA on AlphaFold Server in
some cases. We discovered that the issue is caused by running sharded Jackhmmer
in AlphaFold Server without the `--domZ` flag (which has to be set together with
the `--Z` flag, and to the same value). This means that AlphaFold Server
effectively runs with a roughly 100× more permissive `--domE` filter, so more
sequences are sometimes included in the MSA.
We are keeping behaviour unchanged in both the released AlphaFold 3 and
AlphaFold Server; however, we are giving users with local installs an option to
replicate AlphaFold Server behaviour locally. In our large-scale tests the
difference did not matter; only very specific inputs get better accuracy with
the deeper MSA.
See https://github.com/google-deepmind/alphafold3/issues/492 for an example
input where a protein-DNA complex gets significantly higher ipTM and pTM with
AlphaFold Server compared to a local run.
### Replicating AlphaFold Server behaviour locally
If you want to replicate AlphaFold Server behaviour (i.e. better folding
accuracy in some cases), you can increase the value of the Jackhmmer/Nhmmer
`--domE` flag by 100× compared to its default value.
Alternatively, you can run the sharded MSA search while not setting the `--domZ`
value – you would have to modify the code to do it. We added support for
searching against sharded databases in AlphaFold 3 in
https://github.com/google-deepmind/alphafold3/commit/805adc3863841d83d631ccd18136ad58ce3ecb34
and the way to run AlphaFold 3 with sharded databases is documented in
https://github.com/google-deepmind/alphafold3/blob/main/docs/performance.md#sharded-genetic-databases.
Sharded search can speed up the genetic search by 10–30× (potentially even
more, depending on hardware).
In general, we recommend experimenting with MSA if you are seeing a prediction
with low predicted confidence. Typically adding more *relevant* sequences in the
MSA will increase AlphaFold prediction accuracy and model confidence scores.
================================================
FILE: docs/metadata_antibody_antigen.csv
================================================
pdb_id,chain_id_1,chain_id_2,cluster_key_chain_1,cluster_key_chain_2,interface_cluster_key
7fci,A,B,5581,5964,5581|5964
7fci,A,C,5581,17640,17640|5581
7mnl,A,C,8677,17640,17640|8677
7n0a,A,B,33602,5964,33602|5964
7n0a,A,C,33602,17640,17640|33602
7ox1,A,G,17640,41184,17640|41184
7ox1,B,G,5964,41184,41184|5964
7ox2,A,C,17640,41184,17640|41184
7ox2,B,C,5964,41184,41184|5964
7ox3,A,C,5964,41184,41184|5964
7ox3,B,C,17640,41184,17640|41184
7ox4,A,C,17640,41184,17640|41184
7ox4,B,C,5964,41184,41184|5964
7q6c,A,B,15496,17640,15496|17640
7q6c,A,D,15496,5964,15496|5964
7r58,A,B,30790,17640,17640|30790
7r58,A,C,30790,5964,30790|5964
7ru6,A,B,7068,17640,17640|7068
7sbd,A,C,17640,20692,17640|20692
7sbd,B,C,5964,20692,20692|5964
7sbg,A,C,17640,20692,17640|20692
7sbg,B,C,5964,20692,20692|5964
7sjo,A,F,7390,17640,17640|7390
7sjo,A,I,7390,5964,5964|7390
7sjo,B,G,7390,17640,17640|7390
7sjo,B,H,7390,5964,5964|7390
7sjo,C,D,7390,17640,17640|7390
7sjo,C,E,7390,5964,5964|7390
7sk3,A,C,45640,5964,45640|5964
7sk3,A,D,45640,17640,17640|45640
7sk3,A,E,45640,5964,45640|5964
7sk3,A,F,45640,17640,17640|45640
7sk4,A,C,45640,5964,45640|5964
7sk4,A,D,45640,17640,17640|45640
7sk4,A,E,45640,5964,45640|5964
7sk4,A,F,45640,17640,17640|45640
7sk5,A,B,45640,17640,17640|45640
7sk5,A,D,45640,5964,45640|5964
7sk6,A,C,45640,5964,45640|5964
7sk6,A,D,45640,17640,17640|45640
7sk7,A,C,45640,5964,45640|5964
7sk7,A,D,45640,17640,17640|45640
7sk8,A,C,45640,5964,45640|5964
7sk8,A,D,45640,17640,17640|45640
7sk8,A,E,45640,5964,45640|5964
7sk8,A,F,45640,17640,17640|45640
7sk9,A,B,45640,17640,17640|45640
7sk9,A,C,45640,5964,45640|5964
7st8,A,C,17640,41188,17640|41188
7st8,B,C,5964,41188,41188|5964
7t6x,E,H,5964,24273,24273|5964
7t82,A,C,9703,5964,5964|9703
7t82,A,D,9703,17640,17640|9703
7t9m,A,C,17640,13210,13210|17640
7t9m,B,C,5964,13210,13210|5964
7t9n,A,D,5964,13210,13210|5964
7t9n,B,D,17640,13210,13210|17640
7tuf,A,B,22549,17640,17640|22549
7tuf,A,C,22549,5964,22549|5964
7tuf,A,E,22549,17640,17640|22549
7tuf,B,D,17640,22549,17640|22549
7tuf,D,E,22549,17640,17640|22549
7tuf,D,F,22549,5964,22549|5964
7tug,A,B,22549,17640,17640|22549
7tug,A,C,22549,5964,22549|5964
7u8c,A,B,20081,17640,17640|20081
7u8c,A,C,20081,5964,20081|5964
7u8g,A,C,29632,17640,17640|29632
7u8g,A,D,29632,5964,29632|5964
7uih,A,B,11223,5964,11223|5964
7uih,A,C,11223,17640,11223|17640
7uih,A,D,11223,5964,11223|5964
7uih,A,E,11223,17640,11223|17640
7um3,A,E,17640,33649,17640|33649
7um3,B,E,5964,33649,33649|5964
7ura,A,B,44530,5964,44530|5964
7ura,A,C,44530,17640,17640|44530
7urc,A,B,44530,5964,44530|5964
7urc,A,C,44530,17640,17640|44530
7urd,A,C,44530,5964,44530|5964
7urd,A,D,44530,17640,17640|44530
7ure,A,B,44530,5964,44530|5964
7ure,A,C,44530,17640,17640|44530
7uvf,A,C,23558,17640,17640|23558
7uvf,A,D,23558,5964,23558|5964
7uvf,B,E,23558,5964,23558|5964
7uvf,B,F,23558,17640,17640|23558
7vad,A,B,5581,17640,17640|5581
7vad,A,C,5581,5964,5581|5964
7vae,A,B,5581,17640,17640|5581
7vae,A,C,5581,5964,5581|5964
7vaf,A,C,17640,5581,17640|5581
7vaf,B,C,5964,5581,5581|5964
7vag,A,B,5581,17640,17640|5581
7vag,A,C,5581,5964,5581|5964
7vgr,A,E,5964,33673,33673|5964
7vgr,A,F,5964,33673,33673|5964
7vgr,B,E,17640,33673,17640|33673
7vgr,B,F,17640,33673,17640|33673
7vgr,C,E,5964,33673,33673|5964
7vgr,C,F,5964,33673,33673|5964
7vgr,D,E,17640,33673,17640|33673
7vgr,D,F,17640,33673,17640|33673
7vgs,A,B,33673,5964,33673|5964
7vgs,A,C,33673,17640,17640|33673
7vgs,A,E,33673,5964,33673|5964
7vgs,B,D,5964,33673,33673|5964
7vgs,D,E,33673,5964,33673|5964
7vgs,D,F,33673,17640,17640|33673
7vn9,A,C,17640,20046,17640|20046
7vn9,B,C,5964,20046,20046|5964
7vng,A,B,20046,17640,17640|20046
7vng,A,C,20046,5964,20046|5964
7w71,A,E,24335,17640,17640|24335
7w71,A,F,24335,5964,24335|5964
7wsi,A,B,5581,17640,17640|5581
7wsi,A,C,5581,5964,5581|5964
7xq8,A,B,26372,5964,26372|5964
7xq8,C,D,26372,5964,26372|5964
7zlg,A,D,17640,29547,17640|29547
7zlg,C,D,5964,29547,29547|5964
7zlh,A,D,17640,29547,17640|29547
7zlh,C,D,5964,29547,29547|5964
7zli,A,D,17640,29547,17640|29547
7zli,C,D,5964,29547,29547|5964
7zlj,A,D,17640,29547,17640|29547
7zlj,C,D,5964,29547,29547|5964
7zwi,A,C,7003,5964,5964|7003
7zxf,A,C,7003,5964,5964|7003
7zxf,A,E,7003,5964,5964|7003
7zxg,A,C,7003,5964,5964|7003
7zxk,A,F,26707,5964,26707|5964
7zxk,A,G,26707,17640,17640|26707
7zyi,A,B,5581,17640,17640|5581
7zyi,A,C,5581,5964,5581|5964
8cz5,A,B,24059,17640,17640|24059
8cz5,A,C,24059,5964,24059|5964
8dcy,A,C,17640,23342,17640|23342
8dcy,B,C,5964,23342,23342|5964
8ddk,A,C,17640,23342,17640|23342
8ddk,B,C,5964,23342,23342|5964
8djk,A,E,15456,17640,15456|17640
8djk,B,D,15455,5964,15455|5964
8djk,B,E,15455,17640,15455|17640
8djm,A,E,15456,17640,15456|17640
8djm,B,D,15455,5964,15455|5964
8djm,B,E,15455,17640,15455|17640
8dke,A,B,13979,17640,13979|17640
8dke,A,C,13979,5964,13979|5964
8dki,A,B,13979,17640,13979|17640
8dki,A,C,13979,5964,13979|5964
8dkm,A,C,17640,13979,13979|17640
8dkm,B,C,5964,13979,13979|5964
8dkw,A,C,17640,13979,13979|17640
8dkw,B,C,5964,13979,13979|5964
8dkx,A,C,17640,13979,13979|17640
8dkx,B,C,5964,13979,13979|5964
8hii,A,B,21158,17640,17640|21158
8hii,A,D,21158,5964,21158|5964
8hij,A,B,21158,17640,17640|21158
8hij,A,D,21158,5964,21158|5964
8hik,A,B,21158,17640,17640|21158
8hik,A,D,21158,5964,21158|5964
7so7,A,F,3006,5964,3006|5964
7xy8,A,C,2517,17640,17640|2517
7xy8,A,E,2517,5964,2517|5964
================================================
FILE: docs/metadata_antibody_antigen.md
================================================
# Metadata for Antibody-Antigen pairs used to create figure 5a
Figure 5a in the AlphaFold 3 paper was created using 71 antibody–antigen
complexes, containing 166 antibody–antigen interfaces spanning 65 interface
clusters. Scores were averaged within each interface cluster then across
clusters. Note that the first bioassembly is used in all cases.
We provide metadata for these complexes and the associated clusters in this CSV
file:
https://github.com/google-deepmind/alphafold3/blob/main/docs/metadata_antibody_antigen.csv
================================================
FILE: docs/model_parameters.md
================================================
# Model Parameters
AlphaFold 3 layer names, shapes, and dtypes are listed in the schema below.
This can be used, for example, to generate random parameters for AlphaFold 3
performance optimisation on new accelerators without having to obtain the
official parameters. It is important not to generate all-zero parameters for
performance optimisation, as accelerators often have shortcuts for all-zero
arguments (e.g. `0 * tensor` can be optimised to a no-op).
Producing random parameters could be done similarly to the following snippet:
```py
from alphafold3.model import params
import numpy as np
import zstandard

# Iterable of (scope_name, shape, dtype) entries taken from the Parameters
# Schema below, e.g. ('__meta__:__identifier__', (64,), 'uint8').
# NOTE: plain numpy has no bfloat16 dtype; for the schema's bfloat16 entries
# pass a dtype object such as `ml_dtypes.bfloat16` rather than a string.
parameters = ...

with zstandard.open('random_weights.bin.zst', 'wb') as compressed:
  for scope_name, shape, dtype in parameters:
    if scope_name == '__meta__:__identifier__':
      # The identifier can be all zeros.
      arr = np.zeros(shape=shape, dtype=dtype)
    else:
      # Do not use all-zero params; instead, sample uniformly between -1 and 1.
      arr = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
    # Split 'scope:name' into its parts, which encode_record receives as
    # separate positional arguments ahead of the array.
    compressed.write(params.encode_record(*scope_name.split(':'), arr))
```
## Parameters Schema
```
name=__meta__:__identifier__ dtype=uint8 shape=(64,)
name=diffuser/~/diffusion_head/diffusion_atom_features_layer_norm:scale dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/diffusion_atom_features_to_position_update:weights dtype=float32 shape=(128, 3)
name=diffuser/~/diffusion_head/diffusion_atom_positions_to_features:weights dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderadaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderadaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_adaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_adaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_single_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_single_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_single_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_single_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_transition1:weights dtype=float32 shape=(3, 128, 512)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderffw_transition2:weights dtype=float32 shape=(3, 256, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decodergating_query:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderk_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderksingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderksingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderksingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderksingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderq_projection:bias dtype=float32 shape=(3, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderq_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderqsingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderqsingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderqsingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderqsingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decodertransition2:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/__layer_stack_with_per_layer/diffusion_atom_transformer_decoderv_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/pair_input_layer_norm:scale dtype=float32 shape=(16,)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_decoder/pair_logits_projection:weights dtype=float32 shape=(16, 3, 4)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderadaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderadaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_adaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_adaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_single_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_single_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_single_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_single_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_transition1:weights dtype=float32 shape=(3, 128, 512)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderffw_transition2:weights dtype=float32 shape=(3, 256, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encodergating_query:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderk_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderksingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderksingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderksingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderksingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderq_projection:bias dtype=float32 shape=(3, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderq_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderqsingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderqsingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderqsingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderqsingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encodertransition2:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/__layer_stack_with_per_layer/diffusion_atom_transformer_encoderv_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/pair_input_layer_norm:scale dtype=float32 shape=(16,)
name=diffuser/~/diffusion_head/diffusion_atom_transformer_encoder/pair_logits_projection:weights dtype=float32 shape=(16, 3, 4)
name=diffuser/~/diffusion_head/diffusion_embed_pair_distances_1:weights dtype=float32 shape=(1, 16)
name=diffuser/~/diffusion_head/diffusion_embed_pair_distances:weights dtype=float32 shape=(1, 16)
name=diffuser/~/diffusion_head/diffusion_embed_pair_offsets_1:weights dtype=float32 shape=(3, 16)
name=diffuser/~/diffusion_head/diffusion_embed_pair_offsets_valid:weights dtype=float32 shape=(1, 16)
name=diffuser/~/diffusion_head/diffusion_embed_pair_offsets:weights dtype=float32 shape=(3, 16)
name=diffuser/~/diffusion_head/diffusion_embed_ref_atom_name:weights dtype=float32 shape=(256, 128)
name=diffuser/~/diffusion_head/diffusion_embed_ref_charge:weights dtype=float32 shape=(1, 128)
name=diffuser/~/diffusion_head/diffusion_embed_ref_element:weights dtype=float32 shape=(128, 128)
name=diffuser/~/diffusion_head/diffusion_embed_ref_mask:weights dtype=float32 shape=(1, 128)
name=diffuser/~/diffusion_head/diffusion_embed_ref_pos:weights dtype=float32 shape=(3, 128)
name=diffuser/~/diffusion_head/diffusion_embed_trunk_pair_cond:weights dtype=float32 shape=(128, 16)
name=diffuser/~/diffusion_head/diffusion_embed_trunk_single_cond:weights dtype=float32 shape=(384, 128)
name=diffuser/~/diffusion_head/diffusion_lnorm_trunk_pair_cond:scale dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/diffusion_lnorm_trunk_single_cond:scale dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/diffusion_pair_mlp_1:weights dtype=float32 shape=(16, 16)
name=diffuser/~/diffusion_head/diffusion_pair_mlp_2:weights dtype=float32 shape=(16, 16)
name=diffuser/~/diffusion_head/diffusion_pair_mlp_3:weights dtype=float32 shape=(16, 16)
name=diffuser/~/diffusion_head/diffusion_project_atom_features_for_aggr:weights dtype=float32 shape=(128, 768)
name=diffuser/~/diffusion_head/diffusion_project_token_features_for_broadcast:weights dtype=float32 shape=(768, 128)
name=diffuser/~/diffusion_head/diffusion_single_to_pair_cond_col_1:weights dtype=float32 shape=(128, 16)
name=diffuser/~/diffusion_head/diffusion_single_to_pair_cond_col:weights dtype=float32 shape=(128, 16)
name=diffuser/~/diffusion_head/diffusion_single_to_pair_cond_row_1:weights dtype=float32 shape=(128, 16)
name=diffuser/~/diffusion_head/diffusion_single_to_pair_cond_row:weights dtype=float32 shape=(128, 16)
name=diffuser/~/diffusion_head/noise_embedding_initial_norm:scale dtype=float32 shape=(256,)
name=diffuser/~/diffusion_head/noise_embedding_initial_projection:weights dtype=float32 shape=(256, 384)
name=diffuser/~/diffusion_head/output_norm:scale dtype=float32 shape=(768,)
name=diffuser/~/diffusion_head/pair_cond_initial_norm:scale dtype=float32 shape=(267,)
name=diffuser/~/diffusion_head/pair_cond_initial_projection:weights dtype=float32 shape=(267, 128)
name=diffuser/~/diffusion_head/pair_transition_0ffw_layer_norm:offset dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/pair_transition_0ffw_layer_norm:scale dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/pair_transition_0ffw_transition1:weights dtype=float32 shape=(128, 512)
name=diffuser/~/diffusion_head/pair_transition_0ffw_transition2:weights dtype=float32 shape=(256, 128)
name=diffuser/~/diffusion_head/pair_transition_1ffw_layer_norm:offset dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/pair_transition_1ffw_layer_norm:scale dtype=float32 shape=(128,)
name=diffuser/~/diffusion_head/pair_transition_1ffw_transition1:weights dtype=float32 shape=(128, 512)
name=diffuser/~/diffusion_head/pair_transition_1ffw_transition2:weights dtype=float32 shape=(256, 128)
name=diffuser/~/diffusion_head/single_cond_embedding_norm:scale dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/single_cond_embedding_projection:weights dtype=float32 shape=(384, 768)
name=diffuser/~/diffusion_head/single_cond_initial_norm:scale dtype=float32 shape=(831,)
name=diffuser/~/diffusion_head/single_cond_initial_projection:weights dtype=float32 shape=(831, 384)
name=diffuser/~/diffusion_head/single_transition_0ffw_layer_norm:offset dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/single_transition_0ffw_layer_norm:scale dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/single_transition_0ffw_transition1:weights dtype=float32 shape=(384, 1536)
name=diffuser/~/diffusion_head/single_transition_0ffw_transition2:weights dtype=float32 shape=(768, 384)
name=diffuser/~/diffusion_head/single_transition_1ffw_layer_norm:offset dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/single_transition_1ffw_layer_norm:scale dtype=float32 shape=(384,)
name=diffuser/~/diffusion_head/single_transition_1ffw_transition1:weights dtype=float32 shape=(384, 1536)
name=diffuser/~/diffusion_head/single_transition_1ffw_transition2:weights dtype=float32 shape=(768, 384)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformeradaptive_zero_cond:bias dtype=float32 shape=(6, 4, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformeradaptive_zero_cond:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_adaptive_zero_cond:bias dtype=float32 shape=(6, 4, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_adaptive_zero_cond:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_single_cond_bias:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_single_cond_layer_norm:scale dtype=float32 shape=(6, 4, 384)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_single_cond_scale:bias dtype=float32 shape=(6, 4, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_single_cond_scale:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_transition1:weights dtype=float32 shape=(6, 4, 768, 3072)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerffw_transition2:weights dtype=float32 shape=(6, 4, 1536, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformergating_query:weights dtype=float32 shape=(6, 4, 768, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerk_projection:weights dtype=float32 shape=(6, 4, 768, 16, 48)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerq_projection:bias dtype=float32 shape=(6, 4, 16, 48)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerq_projection:weights dtype=float32 shape=(6, 4, 768, 16, 48)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformersingle_cond_bias:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformersingle_cond_layer_norm:scale dtype=float32 shape=(6, 4, 384)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformersingle_cond_scale:bias dtype=float32 shape=(6, 4, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformersingle_cond_scale:weights dtype=float32 shape=(6, 4, 384, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformertransition2:weights dtype=float32 shape=(6, 4, 768, 768)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformerv_projection:weights dtype=float32 shape=(6, 4, 768, 16, 48)
name=diffuser/~/diffusion_head/transformer/__layer_stack_with_per_layer/pair_logits_projection:weights dtype=float32 shape=(6, 128, 4, 16)
name=diffuser/~/diffusion_head/transformer/pair_input_layer_norm:scale dtype=float32 shape=(128,)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/act_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/act_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/gating_query:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/k_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/pair_bias_projection:weights dtype=bfloat16 shape=(4, 128, 4)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/q_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention1/v_projection:weights dtype=bfloat16 shape=(4, 128, 4, 32)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/act_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/act_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/gating_query:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/k_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/pair_bias_projection:weights dtype=bfloat16 shape=(4, 128, 4)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/q_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_attention2/v_projection:weights dtype=bfloat16 shape=(4, 128, 4, 32)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_transition/input_layer_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_transition/input_layer_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_transition/transition1:weights dtype=bfloat16 shape=(4, 128, 1024)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/pair_transition/transition2:weights dtype=bfloat16 shape=(4, 512, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_gating_query:weights dtype=bfloat16 shape=(4, 384, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_k_projection:weights dtype=bfloat16 shape=(4, 384, 16, 24)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_layer_norm:offset dtype=float32 shape=(4, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_layer_norm:scale dtype=float32 shape=(4, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_q_projection:bias dtype=bfloat16 shape=(4, 16, 24)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_q_projection:weights dtype=bfloat16 shape=(4, 384, 16, 24)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_transition2:weights dtype=bfloat16 shape=(4, 384, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_attention_v_projection:weights dtype=bfloat16 shape=(4, 384, 16, 24)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_pair_logits_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_pair_logits_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_pair_logits_projection:weights dtype=bfloat16 shape=(4, 128, 16)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_transition/input_layer_norm:offset dtype=float32 shape=(4, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_transition/input_layer_norm:scale dtype=float32 shape=(4, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_transition/transition1:weights dtype=bfloat16 shape=(4, 384, 3072)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/single_transition/transition2:weights dtype=bfloat16 shape=(4, 1536, 384)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/center_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/center_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/gate:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/gating_linear:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/left_norm_input:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/left_norm_input:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_incoming/projection:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/center_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/center_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/gate:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/gating_linear:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/left_norm_input:offset dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/left_norm_input:scale dtype=float32 shape=(4, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/confidence_head/__layer_stack_no_per_layer/confidence_pairformer/triangle_multiplication_outgoing/projection:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/confidence_head/~_embed_features/distogram_feat_project:weights dtype=bfloat16 shape=(39, 128)
name=diffuser/confidence_head/~_embed_features/left_target_feat_project:weights dtype=bfloat16 shape=(447, 128)
name=diffuser/confidence_head/~_embed_features/right_target_feat_project:weights dtype=bfloat16 shape=(447, 128)
name=diffuser/confidence_head/experimentally_resolved_ln:offset dtype=float32 shape=(384,)
name=diffuser/confidence_head/experimentally_resolved_ln:scale dtype=float32 shape=(384,)
name=diffuser/confidence_head/experimentally_resolved_logits:weights dtype=float32 shape=(384, 24, 2)
name=diffuser/confidence_head/left_half_distance_logits:weights dtype=float32 shape=(128, 64)
name=diffuser/confidence_head/logits_ln:offset dtype=float32 shape=(128,)
name=diffuser/confidence_head/logits_ln:scale dtype=float32 shape=(128,)
name=diffuser/confidence_head/pae_logits_ln:offset dtype=float32 shape=(128,)
name=diffuser/confidence_head/pae_logits_ln:scale dtype=float32 shape=(128,)
name=diffuser/confidence_head/pae_logits:weights dtype=float32 shape=(128, 64)
name=diffuser/confidence_head/plddt_logits_ln:offset dtype=float32 shape=(384,)
name=diffuser/confidence_head/plddt_logits_ln:scale dtype=float32 shape=(384,)
name=diffuser/confidence_head/plddt_logits:weights dtype=float32 shape=(384, 24, 50)
name=diffuser/distogram_head/half_logits:weights dtype=float32 shape=(128, 64)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderadaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderadaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_adaptive_zero_cond:bias dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_adaptive_zero_cond:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_single_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_single_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_single_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_single_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_transition1:weights dtype=float32 shape=(3, 128, 512)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderffw_transition2:weights dtype=float32 shape=(3, 256, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encodergating_query:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderk_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderksingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderksingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderksingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderksingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderq_projection:bias dtype=float32 shape=(3, 4, 32)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderq_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderqsingle_cond_bias:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderqsingle_cond_layer_norm:scale dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderqsingle_cond_scale:bias dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderqsingle_cond_scale:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encodertransition2:weights dtype=float32 shape=(3, 128, 128)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/__layer_stack_with_per_layer/evoformer_conditioning_atom_transformer_encoderv_projection:weights dtype=float32 shape=(3, 128, 4, 32)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/pair_input_layer_norm:scale dtype=float32 shape=(16,)
name=diffuser/evoformer_conditioning_atom_transformer_encoder/pair_logits_projection:weights dtype=float32 shape=(16, 3, 4)
name=diffuser/evoformer_conditioning_embed_pair_distances_1:weights dtype=float32 shape=(1, 16)
name=diffuser/evoformer_conditioning_embed_pair_distances:weights dtype=float32 shape=(1, 16)
name=diffuser/evoformer_conditioning_embed_pair_offsets_1:weights dtype=float32 shape=(3, 16)
name=diffuser/evoformer_conditioning_embed_pair_offsets_valid:weights dtype=float32 shape=(1, 16)
name=diffuser/evoformer_conditioning_embed_pair_offsets:weights dtype=float32 shape=(3, 16)
name=diffuser/evoformer_conditioning_embed_ref_atom_name:weights dtype=float32 shape=(256, 128)
name=diffuser/evoformer_conditioning_embed_ref_charge:weights dtype=float32 shape=(1, 128)
name=diffuser/evoformer_conditioning_embed_ref_element:weights dtype=float32 shape=(128, 128)
name=diffuser/evoformer_conditioning_embed_ref_mask:weights dtype=float32 shape=(1, 128)
name=diffuser/evoformer_conditioning_embed_ref_pos:weights dtype=float32 shape=(3, 128)
name=diffuser/evoformer_conditioning_pair_mlp_1:weights dtype=float32 shape=(16, 16)
name=diffuser/evoformer_conditioning_pair_mlp_2:weights dtype=float32 shape=(16, 16)
name=diffuser/evoformer_conditioning_pair_mlp_3:weights dtype=float32 shape=(16, 16)
name=diffuser/evoformer_conditioning_project_atom_features_for_aggr:weights dtype=float32 shape=(128, 384)
name=diffuser/evoformer_conditioning_single_to_pair_cond_col_1:weights dtype=float32 shape=(128, 16)
name=diffuser/evoformer_conditioning_single_to_pair_cond_col:weights dtype=float32 shape=(128, 16)
name=diffuser/evoformer_conditioning_single_to_pair_cond_row_1:weights dtype=float32 shape=(128, 16)
name=diffuser/evoformer_conditioning_single_to_pair_cond_row:weights dtype=float32 shape=(128, 16)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/act_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/act_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/gating_query:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/k_projection:weights dtype=bfloat16 shape=(48, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/output_projection:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/pair_bias_projection:weights dtype=bfloat16 shape=(48, 128, 4)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/q_projection:weights dtype=bfloat16 shape=(48, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention1/v_projection:weights dtype=bfloat16 shape=(48, 128, 4, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/act_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/act_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/gating_query:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/k_projection:weights dtype=bfloat16 shape=(48, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/output_projection:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/pair_bias_projection:weights dtype=bfloat16 shape=(48, 128, 4)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/q_projection:weights dtype=bfloat16 shape=(48, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_attention2/v_projection:weights dtype=bfloat16 shape=(48, 128, 4, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_transition/input_layer_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_transition/input_layer_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_transition/transition1:weights dtype=bfloat16 shape=(48, 128, 1024)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/pair_transition/transition2:weights dtype=bfloat16 shape=(48, 512, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_gating_query:weights dtype=bfloat16 shape=(48, 384, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_k_projection:weights dtype=bfloat16 shape=(48, 384, 16, 24)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_layer_norm:offset dtype=float32 shape=(48, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_layer_norm:scale dtype=float32 shape=(48, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_q_projection:bias dtype=bfloat16 shape=(48, 16, 24)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_q_projection:weights dtype=bfloat16 shape=(48, 384, 16, 24)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_transition2:weights dtype=bfloat16 shape=(48, 384, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_attention_v_projection:weights dtype=bfloat16 shape=(48, 384, 16, 24)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_pair_logits_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_pair_logits_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_pair_logits_projection:weights dtype=bfloat16 shape=(48, 128, 16)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_transition/input_layer_norm:offset dtype=float32 shape=(48, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_transition/input_layer_norm:scale dtype=float32 shape=(48, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_transition/transition1:weights dtype=bfloat16 shape=(48, 384, 3072)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/single_transition/transition2:weights dtype=bfloat16 shape=(48, 1536, 384)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/center_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/center_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/gate:weights dtype=bfloat16 shape=(48, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/gating_linear:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/left_norm_input:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/left_norm_input:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/output_projection:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_incoming/projection:weights dtype=bfloat16 shape=(48, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/center_norm:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/center_norm:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/gate:weights dtype=bfloat16 shape=(48, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/gating_linear:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/left_norm_input:offset dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/left_norm_input:scale dtype=float32 shape=(48, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/output_projection:weights dtype=bfloat16 shape=(48, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer_1/trunk_pairformer/triangle_multiplication_outgoing/projection:weights dtype=bfloat16 shape=(48, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/act_norm:offset dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/act_norm:scale dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/gating_query:weights dtype=bfloat16 shape=(4, 64, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/output_projection:weights dtype=bfloat16 shape=(4, 64, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/pair_logits:weights dtype=bfloat16 shape=(4, 128, 8)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/pair_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/pair_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_attention1/v_projection:weights dtype=bfloat16 shape=(4, 64, 8, 8)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_transition/input_layer_norm:offset dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_transition/input_layer_norm:scale dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_transition/transition1:weights dtype=bfloat16 shape=(4, 64, 512)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/msa_transition/transition2:weights dtype=bfloat16 shape=(4, 256, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean:output_b dtype=bfloat16 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean:output_w dtype=bfloat16 shape=(4, 32, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean/layer_norm_input:offset dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean/layer_norm_input:scale dtype=float32 shape=(4, 64)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean/left_projection:weights dtype=bfloat16 shape=(4, 64, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/outer_product_mean/right_projection:weights dtype=bfloat16 shape=(4, 64, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/act_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/act_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/gating_query:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/k_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/pair_bias_projection:weights dtype=bfloat16 shape=(4, 128, 4)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/q_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention1/v_projection:weights dtype=bfloat16 shape=(4, 128, 4, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/act_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/act_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/gating_query:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/k_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/pair_bias_projection:weights dtype=bfloat16 shape=(4, 128, 4)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/q_projection:weights dtype=bfloat16 shape=(4, 4, 32, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_attention2/v_projection:weights dtype=bfloat16 shape=(4, 128, 4, 32)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_transition/input_layer_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_transition/input_layer_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_transition/transition1:weights dtype=bfloat16 shape=(4, 128, 1024)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/pair_transition/transition2:weights dtype=bfloat16 shape=(4, 512, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/center_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/center_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/gate:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/gating_linear:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/left_norm_input:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/left_norm_input:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_incoming/projection:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/center_norm:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/center_norm:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/gate:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/gating_linear:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/left_norm_input:offset dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/left_norm_input:scale dtype=float32 shape=(4, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/output_projection:weights dtype=bfloat16 shape=(4, 128, 128)
name=diffuser/evoformer/__layer_stack_no_per_layer/msa_stack/triangle_multiplication_outgoing/projection:weights dtype=bfloat16 shape=(4, 128, 256)
name=diffuser/evoformer/~_relative_encoding/position_activations:weights dtype=bfloat16 shape=(139, 128)
name=diffuser/evoformer/bond_embedding:weights dtype=bfloat16 shape=(1, 128)
name=diffuser/evoformer/extra_msa_target_feat:weights dtype=bfloat16 shape=(447, 64)
name=diffuser/evoformer/left_single:weights dtype=bfloat16 shape=(447, 128)
name=diffuser/evoformer/msa_activations:weights dtype=bfloat16 shape=(34, 64)
name=diffuser/evoformer/prev_embedding_layer_norm:offset dtype=float32 shape=(128,)
name=diffuser/evoformer/prev_embedding_layer_norm:scale dtype=float32 shape=(128,)
name=diffuser/evoformer/prev_embedding:weights dtype=bfloat16 shape=(128, 128)
name=diffuser/evoformer/prev_single_embedding_layer_norm:offset dtype=float32 shape=(384,)
name=diffuser/evoformer/prev_single_embedding_layer_norm:scale dtype=float32 shape=(384,)
name=diffuser/evoformer/prev_single_embedding:weights dtype=bfloat16 shape=(384, 384)
name=diffuser/evoformer/right_single:weights dtype=bfloat16 shape=(447, 128)
name=diffuser/evoformer/single_activations:weights dtype=bfloat16 shape=(447, 384)
name=diffuser/evoformer/template_embedding/output_linear:weights dtype=bfloat16 shape=(64, 128)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/act_norm:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/act_norm:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/gating_query:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/k_projection:weights dtype=bfloat16 shape=(2, 4, 16, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/output_projection:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/pair_bias_projection:weights dtype=bfloat16 shape=(2, 64, 4)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/q_projection:weights dtype=bfloat16 shape=(2, 4, 16, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention1/v_projection:weights dtype=bfloat16 shape=(2, 64, 4, 16)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/act_norm:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/act_norm:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/gating_query:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/k_projection:weights dtype=bfloat16 shape=(2, 4, 16, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/output_projection:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/pair_bias_projection:weights dtype=bfloat16 shape=(2, 64, 4)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/q_projection:weights dtype=bfloat16 shape=(2, 4, 16, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_attention2/v_projection:weights dtype=bfloat16 shape=(2, 64, 4, 16)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_transition/input_layer_norm:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_transition/input_layer_norm:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_transition/transition1:weights dtype=bfloat16 shape=(2, 64, 256)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/pair_transition/transition2:weights dtype=bfloat16 shape=(2, 128, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/center_norm:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/center_norm:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/gate:weights dtype=bfloat16 shape=(2, 64, 128)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/gating_linear:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/left_norm_input:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/left_norm_input:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/output_projection:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_incoming/projection:weights dtype=bfloat16 shape=(2, 64, 128)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/center_norm:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/center_norm:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/gate:weights dtype=bfloat16 shape=(2, 64, 128)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/gating_linear:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/left_norm_input:offset dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/left_norm_input:scale dtype=float32 shape=(2, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/output_projection:weights dtype=bfloat16 shape=(2, 64, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/__layer_stack_no_per_layer/template_embedding_iteration/triangle_multiplication_outgoing/projection:weights dtype=bfloat16 shape=(2, 64, 128)
name=diffuser/evoformer/template_embedding/single_template_embedding/output_layer_norm:offset dtype=float32 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/output_layer_norm:scale dtype=float32 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/query_embedding_norm:offset dtype=float32 shape=(128,)
name=diffuser/evoformer/template_embedding/single_template_embedding/query_embedding_norm:scale dtype=float32 shape=(128,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_0:weights dtype=bfloat16 shape=(39, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_1:weights dtype=bfloat16 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_2:weights dtype=bfloat16 shape=(31, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_3:weights dtype=bfloat16 shape=(31, 64)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_4:weights dtype=bfloat16 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_5:weights dtype=bfloat16 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_6:weights dtype=bfloat16 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_7:weights dtype=bfloat16 shape=(64,)
name=diffuser/evoformer/template_embedding/single_template_embedding/template_pair_embedding_8:weights dtype=bfloat16 shape=(128, 64)
```
================================================
FILE: docs/output.md
================================================
# AlphaFold 3 Output
## Output Directory Structure
For every input job, AlphaFold 3 writes all its outputs to a directory named
after the sanitized version of the job name. For example, for the job name
"My first fold (TEST)", AlphaFold 3 will write its outputs to a directory called
`My_first_fold_TEST` (the case is respected). If such a directory already
exists, AlphaFold 3 will append a timestamp to the directory name to avoid
overwriting existing data, unless `--force_output_dir` is passed.
The following structure is used within the output directory:
* Sub-directories with results for each sample and seed. There will be
*num\_seeds* \* *num\_samples* such sub-directories. The naming pattern is
`seed-<seed>_sample-<sample number>`. Each of these directories
contains a confidence JSON, a summary confidence JSON, and the mmCIF with the
predicted structure.
* Distogram for each seed:
`seed-<seed>_distogram/<job_name>_seed-<seed>_distogram.npz`. The
Numpy zip file contains a single key: `distogram`. The distogram can be
large: its shape is `(num_tokens, num_tokens, 64)` and its dtype is `np.float16`
(almost 3 GiB for a 5,000-token input). Only saved if AlphaFold 3 is run
with `--save_distogram=true`.
* Embeddings for each seed:
`seed-<seed>_embeddings/<job_name>_seed-<seed>_embeddings.npz`. The
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
The embeddings can be large: their shapes are `(num_tokens, 384)` for
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for
`pair_embeddings`. Their dtype is `np.float16` (almost 6 GiB for a
5,000-token input). Only saved if AlphaFold 3 is run with
`--save_embeddings=true`.
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the
predicted coordinates and should be compatible with most structural biology
tools. We do not provide the output in the PDB format; the CIF file can be
easily converted into one if needed.
* Top-ranking prediction confidence JSON: `<job_name>_confidences.json`.
* Top-ranking prediction summary confidence JSON:
`<job_name>_summary_confidences.json`.
* Job input JSON file with the MSA and template data added by the data
pipeline: `<job_name>_data.json`.
* Ranking scores for all predictions: `<job_name>_ranking_scores.csv`. The
prediction with the highest ranking score is the one included in the root
directory.
* Output terms of use: `TERMS_OF_USE.md`.
Below is an example AlphaFold 3 output directory listing for a job called
"hello_fold" that has been run with 1 seed and 5 samples:
```txt
hello_fold/
├── seed-1234_distogram/ # Only if --save_distogram=true.
│ └── hello_fold_seed-1234_distogram.npz # Only if --save_distogram=true.
├── seed-1234_embeddings/ # Only if --save_embeddings=true.
│ └── hello_fold_seed-1234_embeddings.npz # Only if --save_embeddings=true.
├── seed-1234_sample-0/
│ ├── hello_fold_seed-1234_sample-0_confidences.json
│ ├── hello_fold_seed-1234_sample-0_model.cif
│ └── hello_fold_seed-1234_sample-0_summary_confidences.json
├── seed-1234_sample-1/
│ ├── hello_fold_seed-1234_sample-1_confidences.json
│ ├── hello_fold_seed-1234_sample-1_model.cif
│ └── hello_fold_seed-1234_sample-1_summary_confidences.json
├── seed-1234_sample-2/
│ ├── hello_fold_seed-1234_sample-2_confidences.json
│ ├── hello_fold_seed-1234_sample-2_model.cif
│ └── hello_fold_seed-1234_sample-2_summary_confidences.json
├── seed-1234_sample-3/
│ ├── hello_fold_seed-1234_sample-3_confidences.json
│ ├── hello_fold_seed-1234_sample-3_model.cif
│ └── hello_fold_seed-1234_sample-3_summary_confidences.json
├── seed-1234_sample-4/
│ ├── hello_fold_seed-1234_sample-4_confidences.json
│ ├── hello_fold_seed-1234_sample-4_model.cif
│ └── hello_fold_seed-1234_sample-4_summary_confidences.json
├── TERMS_OF_USE.md
├── hello_fold_confidences.json
├── hello_fold_data.json
├── hello_fold_model.cif
├── hello_fold_ranking_scores.csv
└── hello_fold_summary_confidences.json
```
## Confidence Metrics
Similar to AlphaFold 2 and AlphaFold-Multimer, AlphaFold 3 outputs include
confidence metrics. The main metrics are:
* **pLDDT:** a per-atom confidence estimate on a 0-100 scale where a higher
value indicates higher confidence. pLDDT aims to predict a modified LDDT
score that only considers distances to polymers. For proteins this is
similar to the
[lDDT-Cα metric](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3799472/) but
with more granularity as it can vary per atom not just per residue. For
ligand atoms, the modified LDDT considers the errors only between the ligand
atom and polymers, not other ligand atoms. For DNA/RNA a wider radius of 30
Å is used for the modified LDDT instead of 15 Å.
* **PAE (predicted aligned error)**: an estimate of the error in the relative
position and orientation between two tokens in the predicted structure.
Higher values indicate higher predicted error and therefore lower
confidence. For proteins and nucleic acids, PAE score is essentially the
same as AlphaFold 2, where the error is measured relative to frames
constructed from the protein backbone. For small molecules and
post-translational modifications, a frame is constructed for each atom from
its closest neighbors from a reference conformer.
* **pTM and ipTM scores**: the predicted template modeling (pTM) score and the
interface predicted template modeling (ipTM) score are both derived from a
measure called the template modeling (TM) score. This measures the accuracy
of the entire structure
([Zhang and Skolnick, 2004](https://doi.org/10.1002/prot.20264);
[Xu and Zhang, 2010](https://doi.org/10.1093/bioinformatics/btq066)). A pTM
score above 0.5 means the overall predicted fold for the complex might be
similar to the true structure. ipTM measures the accuracy of the predicted
relative positions of the subunits within the complex. Values higher than
0.8 represent confident high-quality predictions, while values below 0.6
suggest a failed prediction. ipTM values between 0.6 and 0.8 are a gray zone
where predictions could be correct or incorrect. The TM score is very strict
for small structures or short chains, so pTM assigns values less than 0.05
when fewer than 20 tokens are involved; for these cases PAE or pLDDT may be
more indicative of prediction quality.
For a detailed description of these confidence metrics, see the
[AlphaFold 3 paper](https://www.nature.com/articles/s41586-024-07487-w). For
protein components, the
[AlphaFold: A Practical guide](https://www.ebi.ac.uk/training/online/courses/alphafold/inputs-and-outputs/evaluating-alphafolds-predicted-structures-using-confidence-scores/)
course for structures provides additional tutorials on the confidence metrics.
If you are interested in a specific entity or interaction, then there are
confidences available in the outputs which are specific to each chain or
chain-pair, as opposed to the full complex. See below for more details on all
the confidence metrics that are returned.
## Multi-Seed and Multi-Sample Results
By default, the model samples five predictions per seed. The top-ranked
prediction across all samples and seeds is available at the top-level of the
output directory. All samples along with their associated confidences are
available in subdirectories of the output directory.
For ranking of the full complex use the `ranking_score` (higher is better). This
score uses overall structure confidences (pTM and ipTM), but also includes terms
that penalize clashes and encourage disordered regions not to have spurious
helices – these extra terms mean the score should only be used to rank
structures.
If you are interested in a specific entity or interaction, you may want to rank
by a metric specific to that chain or chain-pair, as opposed to the full
complex. In that case, use the per chain or per chain-pair confidence metrics
described below for ranking.
## Metrics in Confidences JSON
For each predicted sample we provide two JSON files. One contains summary
metrics – summaries for either the whole structure, per chain or per chain-pair
– and the other contains full 1D or 2D arrays.
Summary outputs:
* `ptm`: A scalar in the range 0-1 indicating the predicted TM-score for the
full structure.
* `iptm`: A scalar in the range 0-1 indicating predicted interface TM-score
(confidence in the predicted interfaces) for all interfaces in the
structure.
* `fraction_disordered`: A scalar in the range 0-1 that indicates what
fraction of the predicted structure is disordered, as measured by
accessible surface area, see our
[paper](https://www.nature.com/articles/s41586-024-07487-w) for details.
* `has_clash`: A boolean indicating if the structure has a significant number
of clashing atoms (more than 50% of a chain, or a chain with more than 100
clashing atoms).
* `ranking_score`: A scalar in the range \[-100, 1.5\] that can be used for
ranking predictions, it incorporates `ptm`, `iptm`, `fraction_disordered`
and `has_clash` into a single number with the following equation: 0.8 × ipTM
\+ 0.2 × pTM \+ 0.5 × fraction_disordered − 100 × has_clash.
* `chain_pair_pae_min`: A \[num_chains, num_chains\] array. Element (i, j) of
the array contains the lowest PAE value across rows restricted to chain i
and columns restricted to chain j. This has been found to correlate with
whether two chains interact or not, and in some cases can be used to
distinguish binders from non-binders.
* `chain_pair_iptm`: A \[num_chains, num_chains\] array. Off-diagonal element
(i, j) of the array contains the ipTM restricted to tokens from chains i and
j. Diagonal element (i, i) contains the pTM restricted to chain i. Can be
used for ranking a specific interface between two chains, when you know that
they interact, e.g. for antibody-antigen interactions.
* `chain_ptm`: A \[num_chains\] array. Element i contains the pTM restricted
to chain i. Can be used for ranking individual chains when the structure of
that chain is most of interest, rather than the cross-chain interactions it
is involved with.
* `chain_iptm`: A \[num_chains\] array that gives the average confidence
(interface pTM) in the interface between each chain and all other chains.
Can be used for ranking a specific chain, when you care about where the
chain binds to the rest of the complex and you do not know which other
chains you expect it to interact with. This is often the case with ligands.
Full array outputs:
* `pae`: A \[num\_tokens, num\_tokens\] array. Element (i, j) indicates the
predicted error in the position of token j, when the prediction is aligned
to the ground truth using the frame of token i.
* `atom_plddts`: A \[num_atoms\] array, element i indicates the predicted
local distance difference test (pLDDT) for atom i in the prediction.
* `contact_probs`: A \[num_tokens, num_tokens\] array. Element (i, j)
indicates the predicted probability that token i and token j are in contact
(i.e. their representative atoms are within 8 Å of each other), see
[paper](https://www.nature.com/articles/s41586-024-07487-w) for details.
* `token_chain_ids`: A \[num_tokens\] array indicating the chain ids
corresponding to each token in the prediction.
* `atom_chain_ids`: A \[num_atoms\] array indicating the chain ids
corresponding to each atom in the prediction.
## Embeddings
AlphaFold 3 can be run with `--save_embeddings=true` to save the embeddings for
each seed. The file is in the
[compressed Numpy `.npz` format](https://numpy.org/doc/stable/reference/generated/numpy.savez_compressed.html)
and can be loaded using `numpy.load` as a dictionary-like object with two
arrays:
* `single_embeddings`: A \[num\_tokens, 384\] array containing the embeddings
for each token.
* `pair_embeddings`: A \[num\_tokens, num\_tokens, 128\] array containing the
pairwise embeddings between all tokens.
You can use for instance the following Python code to load the embeddings:
```py
import numpy as np
with open('embeddings.npz', 'rb') as f:
    embeddings = np.load(f)
    single_embeddings = embeddings['single_embeddings']
    pair_embeddings = embeddings['pair_embeddings']
```
## Chirality checks
In the PoseBusters results of the AlphaFold 3 paper, a penalty was applied to the
ranking score if the ligand of interest contained chiral errors. By running
multiple seeds and using this chiral aware ranking, chiral error rates were
greatly reduced.
We provide the method `compare_chirality` in
[`model/scoring/chirality.py`](https://github.com/google-deepmind/alphafold3/blob/main/src/alphafold3/model/scoring/chirality.py)
to replicate these chiral checks. Chirality is checked against CCD structures if
available, otherwise users can supply custom RDKit Mol objects for comparison.
================================================
FILE: docs/performance.md
================================================
# Performance
## Running the Pipeline in Stages
The `run_alphafold.py` script can be executed in stages to optimise resource
utilisation. This can be useful for:
1. Splitting the CPU-only data pipeline from model inference (which requires a
GPU), to optimise cost and resource usage.
1. Generating the JSON output file from a data-pipeline-only run and then
using it for multiple different inference-only runs across seeds or across
variations of other features (e.g. a ligand or a partner chain).
1. Generating the JSON output for multiple individual monomer chains (e.g. for
chains A, B, C, D), then running the inference on all possible chain pairs
(AB, AC, AD, BC, BD, CD) by creating dimer JSONs by merging the monomer
JSONs. By doing this, the MSA and template search need to be run just 4
times (once for each chain), instead of 12 times.
### Data Pipeline Only
Launch `run_alphafold.py` with `--norun_inference` to generate Multiple Sequence
Alignments (MSAs) and templates, without running featurisation and model
inference. This stage can be quite costly in terms of runtime, CPU, and RAM use.
The output will be JSON files augmented with MSAs and templates that can then be
directly used as input for running inference.
### Pre-computing and reusing MSA and templates
When folding multiple candidate chains with a set of fixed chains (i.e. chains
that are the same for all the runs), you can optimize the process by computing
the MSA and templates for the fixed chains only once. The computations for the
changing candidate chains will still be performed for each run:
1. Run the AlphaFold 3 data pipeline for the fixed chains using the
`--run_inference=false` flag. This step generates a JSON file containing the
MSA and template data for these chains.
2. When constructing your multimer input JSONs, populate the entries for the
fixed chains using the data generated in the previous step.
* For the fixed chains: Specifically, copy the `unpairedMsa`, `pairedMsa`,
and `templates` fields from the pre-computed JSON into the multimer
input JSON. This prevents these fields from being recomputed.
* For the candidate chains: Leave these fields unset (or `null`) in the
multimer input JSON. This will signal the pipeline to compute them
dynamically for each run.
This technique can also be extended to efficiently process all combinations of
*n* first chains and *m* second chains. Instead of performing *n* × *m* full
computations, you can reduce this to *n* + *m* data pipeline runs.
In this scenario:
1. Run the data pipeline (step 1 above, with `--run_inference=false`) for all
*n* individual first chains and all *m* individual second chains.
2. Assemble the dimer input JSONs for each desired pair by combining their
respective pre-computed monomer JSONs.
3. Run only the inference step on these assembled JSONs using the
`--run_data_pipeline=false` flag.
This approach has been discussed in multiple GitHub issues, such as:
https://github.com/google-deepmind/alphafold3/issues/171 (which links to other
similar issues).
### Featurisation and Model Inference Only
Launch `run_alphafold.py` with `--norun_data_pipeline` to skip the data pipeline
and run only featurisation and model inference. This stage requires the input
JSON file to contain pre-computed MSAs and templates (or they must be explicitly
set to empty if you want to run MSA and template free).
## Data Pipeline
The runtime of the data pipeline (i.e. genetic sequence search and template
search) can vary significantly depending on the size of the input and the number
of homologous sequences found, as well as the available hardware – the disk
speed can influence genetic search speed in particular.
If you would like to improve performance, it's recommended to increase the disk
speed (e.g. by leveraging a RAM-backed filesystem), or increase the available
CPU cores and add more parallelisation. This can help because AlphaFold 3 runs
genetic search against 4 databases in parallel, so the optimal number of cores
is the number of cores used for each Jackhmmer process times 4. Also note that
for sequences with deep MSAs, Jackhmmer or Nhmmer may need a substantial amount
of RAM beyond the recommended 64 GB of RAM.
### Sharded genetic databases
The run time of the genetic database search can be *significantly* sped up by
splitting the genetic databases if a machine with many CPU cores is used and the
databases are on very fast SSD or in a RAM-backed filesystem. With this
technique you can make Jackhmmer/Nhmmer genetic search fully utilize your
hardware and take advantage of multi-core systems.
Each genetic database with *n* sequences is split into *s* shards, each
containing roughly *n* / *s* sequences. We recommend splitting the sequences
between shards randomly to make sure each shard has a similar sequence length
distribution. This can be achieved using standard tools:
1. Shuffle the sequences in the fasta. This can be done for example by running:
`seqkit shuffle --two-pass <input_fasta> > <shuffled_fasta>`
2. Split the shuffled fasta in *s* shards. This can be done for example by
running: `seqkit split2 --by-part <s> <shuffled_fasta>`
Make sure the shard names follow this pattern:
`prefix-<shard_index>-of-<total_shards>`, with both `shard_index` and
`total_shards` always having 5 digits, with leading zeros as needed. The
`shard_index` goes from 0 to `total_shards - 1`. A file "path" (spec) for a
sharded file is `prefix@<total_shards>`.
E.g. for a file named `uniprot.fasta` split into 3 shards, the names of the
shards should be:
* `uniprot.fasta-00000-of-00003`
* `uniprot.fasta-00001-of-00003`
* `uniprot.fasta-00002-of-00003`
The file spec for these files is `uniprot.fasta@3`.
Save the total number of sequences in the protein databases, and the total
number of nucleic bases in the RNA databases – these will be needed later as a
flag to Jackhmmer/Nhmmer to correctly scale e-values across all shards.
Save the sharded databases on a fast SSD or in a RAM-backed filesystem, then
launch AlphaFold with the sharded paths instead of normal paths and set the
Z-values.
For instance, with each database split into between 16 and 512 shards
(depending on its size) and up to 16 shards searched in parallel:
```bash
python run_alphafold.py \
--small_bfd_database_path="bfd-first_non_consensus_sequences.fasta@64" \
--small_bfd_z_value=65984053 \
--mgnify_database_path="mgy_clusters_2022_05.fa@512" \
--mgnify_z_value=623796864 \
--uniprot_cluster_annot_database_path="uniprot_cluster_annot_2021_04.fasta@256" \
--uniprot_cluster_annot_z_value=225619586 \
--uniref90_database_path="uniref90_2022_05.fasta@128" \
--uniref90_z_value=153742194 \
--ntrna_database_path="nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta@256" \
--ntrna_z_value=76752.808514 \
--rfam_database_path="rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta@16" \
--rfam_z_value=138.115553 \
--rna_central_database_path="rnacentral_active_seq_id_90_cov_80_linclust.fasta@64" \
--rna_central_z_value=13271.415730 \
--jackhmmer_n_cpu=2 \
--jackhmmer_max_parallel_shards=16 \
--nhmmer_n_cpu=2 \
--nhmmer_max_parallel_shards=16
```
This run will utilize (2 CPUs) × (16 max parallel shards) × (4 protein dbs
searched in parallel) = 128 cores for each protein chain, and (2 CPUs) × (16 max
parallel shards) × (3 RNA dbs searched in parallel) = 96 cores for each RNA
chain. Make sure to tune:
* the Jackhmmer/Nhmmer number of CPUs,
* the maximum number of shards searched in parallel,
* and the number of shards for each database
so that the memory bandwidth and CPUs on your machine are optimally utilized.
You should aim for consistent shard sizes across all databases (so e.g. if
database A is split into 16 shards and is 3× smaller than database B, database B
should be split into 3 × 16 = 48 shards).
## Model Inference
Table 8 in the Supplementary Information of the
[AlphaFold 3 paper](https://nature.com/articles/s41586-024-07487-w) provides
compile-free inference timings for AlphaFold 3 when configured to run on 16
NVIDIA A100s, with 40 GB of memory per device. In contrast, this repository
supports running AlphaFold 3 on a single NVIDIA A100 with 80 GB of memory in a
configuration optimised to maximise throughput.
We compare compile-free inference timings of these two setups in the table below
using GPU seconds (i.e. multiplying by 16 when using 16 A100s). The setup in
this repository is more efficient (by at least 2×) across all token sizes,
indicating its suitability for high-throughput applications.
Num Tokens | 1 A100 80 GB (GPU secs) | 16 A100 40 GB (GPU secs) | Improvement
:--------- | ----------------------: | -----------------------: | ----------:
1024 | 62 | 352 | 5.7×
2048 | 275 | 1136 | 4.1×
3072 | 703 | 2016 | 2.9×
4096 | 1434 | 3648 | 2.5×
5120 | 2547 | 5552 | 2.2×
## Accelerator Hardware Requirements
We officially support the following configurations, and have extensively tested
them for numerical accuracy and throughput efficiency:
- 1 NVIDIA A100 (80 GB)
- 1 NVIDIA H100 (80 GB)
We compare compile-free inference timings of both configurations in the
following table:
Num Tokens | 1 A100 80 GB (seconds) | 1 H100 80 GB (seconds)
:--------- | ---------------------: | ---------------------:
1024 | 62 | 34
2048 | 275 | 144
3072 | 703 | 367
4096 | 1434 | 774
5120 | 2547 | 1416
### Other Hardware Configurations
#### NVIDIA A100 (40 GB)
AlphaFold 3 can run on inputs of size up to 4,352 tokens on a single NVIDIA A100
(40 GB) with the following configuration changes:
1. Enabling [unified memory](#unified-memory).
1. Adjusting `pair_transition_shard_spec` in `model_config.py`:
```py
pair_transition_shard_spec: Sequence[_Shape2DType] = (
(2048, None),
(3072, 1024),
(None, 512),
)
```
The format of entries in `pair_transition_shard_spec` is
`(num_tokens_upper_bound, shard_size)`. Setting `num_tokens_upper_bound=None`
means there is no upper bound, and setting `shard_size=None` means no sharding.
For the example above:
* `(2048, None)`: for sequences up to 2,048 tokens, do not shard
* `(3072, 1024)`: for sequences up to 3,072 tokens, shard in chunks of 1,024
* `(None, 512)`: for all other sequences, shard in chunks of 512
While numerically accurate, this configuration will have lower throughput
compared to the setup on the NVIDIA A100 (80 GB), due to less available memory.
#### NVIDIA V100
There are known numerical issues with CUDA Capability 7.x devices. To work
around the issue, set the environment variable `XLA_FLAGS` to include
`--xla_disable_hlo_passes=custom-kernel-fusion-rewriter`.
With the above flag set, AlphaFold 3 can run on inputs of size up to 1,280
tokens on a single NVIDIA V100 using [unified memory](#unified-memory).
#### NVIDIA P100
AlphaFold 3 can run on inputs of size up to 1,024 tokens on a single NVIDIA P100
with no configuration changes needed.
#### Other devices
Large-scale numerical tests have not been performed on any other devices but
they are believed to be numerically accurate.
There are known numerical issues with CUDA Capability 7.x devices. To work
around the issue, set the environment variable `XLA_FLAGS` to include
`--xla_disable_hlo_passes=custom-kernel-fusion-rewriter`.
## Compilation Buckets
To avoid excessive re-compilation of the model, AlphaFold 3 implements
compilation buckets: ranges of input sizes using a single compilation of the
model.
When featurising an input, AlphaFold 3 determines the smallest bucket the input
fits into, then adds any necessary padding. This may avoid re-compiling the
model when running inference on the input if it belongs to the same bucket as a
previously processed input.
The configuration of bucket sizes involves a trade-off: more buckets lead to
more re-compilations of the model, but less padding.
By default, the largest bucket size is 5,120 tokens. Processing inputs larger
than this maximum bucket size triggers the creation of a new bucket for exactly
that input size, and a re-compilation of the model. In this case, you may wish
to redefine the compilation bucket sizes via the `--buckets` flag in
`run_alphafold.py` to add additional larger bucket sizes. For example, suppose
you are running inference on inputs with token sizes: `5132, 5280, 5342`. Using
the default bucket sizes configured in `run_alphafold.py` will trigger three
separate model compilations, one for each unique token size. If instead you pass
in the following flag to `run_alphafold.py`
```
--buckets 256,512,768,1024,1280,1536,2048,2560,3072,3584,4096,4608,5120,5376
```
when running inference on the above three input sizes, the model will be
compiled only once for the bucket size `5376`. **Note:** for this specific
example with input sizes `5132, 5280, 5342`, passing in `--buckets 5376` is
sufficient to achieve the desired compilation behaviour. The provided example
with multiple buckets illustrates a more general solution suitable for diverse
input sizes.
## Additional Flags
### Compilation Time Workaround with XLA Flags
To work around a known XLA issue causing the compilation time to greatly
increase, the following environment variable must be set (it is set by default
in the provided `Dockerfile`).
```sh
ENV XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"
```
### CUDA Capability 7.x GPUs
For all CUDA Capability 7.x GPUs (e.g. V100) the environment variable
`XLA_FLAGS` must be changed to include
`--xla_disable_hlo_passes=custom-kernel-fusion-rewriter`. Disabling the Triton
GEMM kernels is not necessary as they are not supported for such GPUs.
```sh
ENV XLA_FLAGS="--xla_disable_hlo_passes=custom-kernel-fusion-rewriter"
```
### GPU Memory
The following environment variables (set by default in the `Dockerfile`) enable
folding a single input of size up to 5,120 tokens on a single A100 (80 GB) or a
single H100 (80 GB):
```sh
ENV XLA_PYTHON_CLIENT_PREALLOCATE=true
ENV XLA_CLIENT_MEM_FRACTION=0.95
```
#### Unified Memory
If you would like to run AlphaFold 3 on inputs larger than 5,120 tokens, or on a
GPU with less memory (an A100 with 40 GB of memory, for instance), we recommend
enabling unified memory. Enabling unified memory allows the program to spill GPU
memory to host memory if there isn't enough space. This prevents an OOM, at the
cost of making the program slower by accessing host memory instead of device
memory. To learn more, check out the
[NVIDIA blog post](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/).
You can enable unified memory by setting the following environment variables in
your `Dockerfile`:
```sh
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false
ENV TF_FORCE_UNIFIED_MEMORY=true
ENV XLA_CLIENT_MEM_FRACTION=3.2
```
### JAX Persistent Compilation Cache
You may also want to make use of the JAX persistent compilation cache, to avoid
unnecessary recompilation of the model between runs. You can enable the
compilation cache with the `--jax_compilation_cache_dir <YOUR_DIRECTORY>` flag
in `run_alphafold.py`.
More detailed instructions are available in the
[JAX documentation](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html#persistent-compilation-cache),
and more specifically the instructions for use on
[Google Cloud](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html#persistent-compilation-cache).
In particular, note that if you would like to make use of a non-local
filesystem, such as Google Cloud Storage, you will need to install
[`etils`](https://github.com/google/etils) (this is not included by default in
the AlphaFold 3 Docker container).
================================================
FILE: fetch_databases.sh
================================================
#!/bin/bash
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
set -euo pipefail

# Destination directory: first positional argument, or ~/public_databases.
readonly db_dir=${1:-$HOME/public_databases}

# Fail fast if any tool used below is missing.
for cmd in wget tar zstd ; do
  if ! command -v "${cmd}" > /dev/null 2>&1; then
    echo "${cmd} is not installed. Please install it."
    exit 1
  fi
done

echo "Fetching databases to ${db_dir}"
mkdir -p "${db_dir}"

readonly SOURCE=https://storage.googleapis.com/alphafold-databases/v3.0

# PIDs of the background download jobs. A bare `wait` always returns 0, which
# would hide failed downloads, so every job is waited on individually below.
pids=()

echo "Start Fetching and Untarring 'pdb_2022_09_28_mmcif_files.tar'"
# Each pipeline runs in its own subshell (which inherits pipefail), so the
# job's exit status is non-zero if any stage of the pipeline fails.
(
  wget --quiet --output-document=- \
    "${SOURCE}/pdb_2022_09_28_mmcif_files.tar.zst" | \
    tar --no-same-owner --no-same-permissions \
      --use-compress-program=zstd -xf - --directory="${db_dir}"
) &
pids+=("$!")

for NAME in mgy_clusters_2022_05.fa \
            bfd-first_non_consensus_sequences.fasta \
            uniref90_2022_05.fa uniprot_all_2021_04.fa \
            pdb_seqres_2022_09_28.fasta \
            rnacentral_active_seq_id_90_cov_80_linclust.fasta \
            nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta \
            rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta ; do
  echo "Start Fetching '${NAME}'"
  (
    wget --quiet --output-document=- "${SOURCE}/${NAME}.zst" | \
      zstd --decompress > "${db_dir}/${NAME}"
  ) &
  pids+=("$!")
done

# Wait for every job and count failures instead of discarding their statuses.
failed=0
for pid in "${pids[@]}"; do
  if ! wait "${pid}"; then
    failed=$((failed + 1))
  fi
done

if (( failed > 0 )); then
  echo "${failed} download(s) failed." >&2
  exit 1
fi

echo "Complete"
================================================
FILE: legal/WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md
================================================
# KEBIJAKAN PENGGUNAAN TERLARANG UNTUK PARAMETER MODEL ALPHAFOLD 3
Terakhir diubah: 2024-11-09
AlphaFold 3 dapat membantu Anda mempercepat riset ilmiah dengan memprediksi
struktur 3D molekul biologis. Google menyediakan Aset AlphaFold tanpa biaya
untuk penggunaan non-komersial tertentu, sesuai dengan pembatasan yang
ditetapkan di bawah. Kebijakan ini menggunakan persyaratan yang sama dengan
[Persyaratan Penggunaan Parameter Model AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_TERMS_OF_USE-Bahasa-Indonesia.md).
**Anda tidak boleh mengakses atau menggunakan, atau mengizinkan orang lain
mengakses atau menggunakan Aset AlphaFold 3:**
1. **Atas nama organisasi komersial atau sehubungan dengan aktivitas komersial
apa pun, termasuk riset atas nama organisasi komersial.**
1. Artinya, hanya organisasi non-komersial (*yaitu*, universitas,
organisasi non-profit dan institusi riset, serta lembaga pendidikan,
jurnalistik, dan pemerintah) yang dapat menggunakan Aset AlphaFold 3
untuk aktivitas non-komersial mereka. Aset AlphaFold 3 tidak tersedia
untuk jenis organisasi lainnya, meskipun organisasi tersebut melakukan
pekerjaan non-komersial.
2. Jika Anda adalah peneliti yang berafiliasi dengan organisasi
non-komersial, Anda dapat menggunakan Aset AlphaFold 3 untuk riset
terafiliasi non-komersial Anda, dengan syarat Anda bukan organisasi
komersial atau bertindak atas nama organisasi komersial.
3. Anda tidak boleh membagikan Aset AlphaFold 3 kepada organisasi komersial
mana pun atau menggunakan Aset AlphaFold 3 dengan cara yang akan memberi
organisasi komersial hak apa pun atas Aset ini. Satu-satunya
pengecualian adalah menyediakan Output secara publik (termasuk secara
tidak langsung kepada organisasi komersial) melalui publikasi ilmiah
atau rilis open source atau menggunakannya untuk mendukung jurnalisme,
yang masing-masing diizinkan.
2. **Untuk menyebarkan misinformasi, memberikan pernyataan tidak benar, atau
menyesatkan pengguna**, termasuk:
1. menyediakan informasi palsu atau tidak akurat sehubungan dengan akses ke
atau penggunaan AlphaFold 3 atau Output oleh Anda, termasuk mengakses
atau menggunakan Parameter Model atas nama organisasi tanpa memberi tahu
kami atau mengirimkan permintaan untuk mengakses Parameter Model di mana
Google telah melarang penggunaan AlphaFold 3 oleh Anda secara
keseluruhan atau sebagian (termasuk yang disediakan melalui
[Server AlphaFold](https://alphafoldserver.com/about));
2. memberikan pernyataan tidak benar tentang hubungan Anda dengan kami;
termasuk dengan menggunakan merek dagang, nama dagang, atau logo Google,
atau menyiratkan dukungan oleh Google tanpa seizin Google - Tidak ada di
dalam Persyaratan memberikan izin semacam itu;
3. memberikan pernyataan tidak benar tentang asal AlphaFold 3 secara
keseluruhan atau sebagian;
4. menyebarkan klaim menyesatkan tentang keahlian atau kemampuan, atau
terlibat dalam praktik profesional yang tidak sah atau tanpa lisensi,
khususnya di bidang yang sensitif (misalnya, kesehatan); atau
5. membuat keputusan dalam ranah yang memengaruhi hak atau kesejahteraan
individu atau material (misalnya, layanan kesehatan).
3. **Untuk melakukan, mempromosikan, atau memfasilitasi aktivitas berbahaya,
ilegal, atau jahat, termasuk:**
1. mempromosikan atau memfasilitasi penjualan, ataupun memberikan petunjuk
untuk membuat atau mengakses, zat, barang, atau layanan ilegal;
2. menyalahgunakan, merugikan, mengganggu, atau mengacaukan layanan apa
pun, termasuk membuat atau mendistribusikan konten untuk aktivitas
penipuan atau penyebaran malware;
3. membuat atau mendistribusikan konten, termasuk Output, yang menyalahi,
menyalahgunakan, atau melanggar hak individu atau entitas apa pun
(termasuk, tetapi tidak terbatas pada hak atas konten yang dilindungi
hak cipta); atau
4. mencoba mengakali, atau dengan sengaja menyebabkan (secara langsung atau
tidak langsung) AlphaFold 3 untuk bertindak dengan cara yang melanggar
Persyaratan.
**Anda tidak boleh atau mengizinkan orang lain:**
1. **Menggunakan Output guna melatih atau membuat model machine learning atau
teknologi terkait untuk prediksi struktur biomolekuler yang mirip dengan
AlphaFold 3 ("Model Turunan"),** termasuk melalui distilasi atau metode
lainnya. Untuk menegaskan, pembatasan penggunaan yang ditetapkan dalam
Persyaratan akan berlaku sepenuhnya untuk semua Model Turunan yang dibuat
dengan melanggar Persyaratan.
2. **Mendistribusikan Output tanpa memberikan pemberitahuan yang jelas bahwa
apa yang Anda Distribusikan disediakan berdasarkan dan tunduk pada
[Persyaratan Penggunaan Output AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
serta tentang modifikasi apa pun yang Anda buat.**
1. Artinya, jika Anda menghapus, atau menyebabkan penghapusan (misalnya
dengan menggunakan perangkat lunak pihak ketiga), pemberitahuan dan
syarat yang kami berikan saat Anda menghasilkan Output menggunakan
AlphaFold 3, Anda harus memastikan Distribusi Output berikutnya
menyertakan salinan
[Persyaratan Penggunaan Output AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
dan file teks "Persyaratan Penggunaan yang Mengikat secara Hukum" yang
berisi pemberitahuan berikut:
"*Dengan menggunakan informasi ini, Anda menyetujui Persyaratan
Penggunaan Output AlphaFold 3 yang terdapat di
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*Untuk meminta akses ke parameter model AlphaFold 3, ikuti proses yang
ditetapkan di https://github.com/google-deepmind/alphafold3. Anda hanya
dapat menggunakan parameter model ini jika menerimanya langsung dari
Google. Penggunaannya tunduk pada persyaratan penggunaan yang tersedia
di
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*"
1. Anda tidak boleh menyertakan persyaratan tambahan atau berbeda yang
bertentangan dengan
[Persyaratan Penggunaan Output AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
3. **Mendistribusikan Output, atau mengungkapkan temuan yang didapatkan dari
penggunaan AlphaFold 3 tanpa mengutip makalah kami**: [Abramson, J et al.
Accurate structure prediction of biomolecular interactions with AlphaFold
3.](https://www.nature.com/articles/s41586-024-07487-w). Untuk menegaskan,
hal ini merupakan persyaratan tambahan selain persyaratan pemberitahuan yang
ditetapkan di atas.
4. **Mengakali pembatasan akses terkait Parameter Model, termasuk menggunakan,
membagikan, atau menyediakan Parameter Model ketika Google belum mengizinkan
Anda secara tegas untuk melakukan hal tersebut.** Google akan memberikan
akses ke Parameter Model kepada:
1. Anda untuk penggunaan pribadi Anda atas nama organisasi Anda, dimana
Anda tidak dapat membagikan salinan Parameter Model Anda kepada siapa
pun; atau
2. perwakilan resmi organisasi Anda, dengan kewenangan hukum penuh untuk
mengikat organisasi tersebut pada Persyaratan ini. Dalam hal ini, Anda
dapat membagikan salinan Parameter Model milik organisasi tersebut
kepada karyawan, konsultan, kontraktor, serta agen organisasi
sebagaimana diizinkan oleh perwakilan tersebut.
================================================
FILE: legal/WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md
================================================
# POLÍTICA DE USO PROHIBIDO DE PARÁMETROS DEL MODELO ALPHAFOLD 3
Última modificación: 9 de noviembre de 2024
AlphaFold 3 puede ayudar a acelerar la investigación científica, ya que predice
la estructura 3D de moléculas biológicas. Google pone a disposición los Recursos
de AlphaFold sin costo para determinados usos no comerciales de conformidad con
las restricciones que se establecen a continuación. Esta política usa los mismos
términos definidos que en las
[Condiciones de Uso de los Parámetros del Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_TERMS_OF_USE-Espanol-Latinoamerica.md).
**No debe acceder o utilizar ni permitir que otros accedan o utilicen los
Recursos de AlphaFold 3 en los siguientes casos:**
1. **En nombre de una organización comercial o en conexión con cualquier
actividad comercial, incluida la investigación en nombre de organizaciones
comerciales.**
1. Esto significa que solo las organizaciones no comerciales (*es
decir*, universidades, institutos de investigación y organizaciones
sin fines de lucro, y organismos educativos, gubernamentales y
periodísticos) pueden usar los Recursos de AlphaFold 3 para sus
actividades no comerciales. Los Recursos de AlphaFold 3 no están
disponibles para ningún otro tipo de organización, aunque realicen
trabajos no comerciales.
2. Si usted es un investigador afiliado de una organización no comercial,
siempre que usted no sea una organización comercial ni actúe en nombre de
una organización comercial, puede usar los Recursos de AlphaFold 3 para
su investigación de afiliación no comercial.
3. No debe compartir los Recursos de AlphaFold 3 con ninguna organización
comercial ni usarlos de manera que otorgue a una organización comercial
algún derecho sobre estos. La única excepción es poner los Resultados a
disposición del público (lo que incluye indirectamente a las
organizaciones comerciales) a través de una publicación científica o una
publicación de código abierto, o utilizarlos para apoyar la actividad
periodística, opciones que están todas permitidas.
2. **Para desinformar, tergiversar o engañar, entre lo que se incluye lo
siguiente:**
1. proporcionar información falsa o errónea en relación con su acceso o uso
de AlphaFold 3 o los Resultados, incluido el uso o acceso a los
Parámetros del Modelo en nombre de una organización sin informarnos o
enviarnos una solicitud para acceder a los Parámetros del Modelo cuando
Google le ha prohibido el uso de AlphaFold 3 de forma parcial o total
(incluido como se pone a disposición a través de
[AlphaFold Server](https://alphafoldserver.com/about)),
2. tergiversar su relación con nosotros, incluido el uso de marcas
comerciales, nombres comerciales o logotipos de Google, o sugerir
recomendación por parte de Google sin el permiso de Google para hacerlo
(ningún punto de las Condiciones otorga ese permiso),
3. tergiversar el origen de AlphaFold 3 de forma parcial o total,
4. distribuir declaraciones engañosas sobre experiencia o capacidad, o
participar en la práctica de cualquier profesión sin autorización o
licencia, en particular si se trata de áreas sensibles (*p. ej.*, la de
la salud), o
5. tomar decisiones en ámbitos que afectan el bienestar o los derechos
materiales o individuales (*p. ej.*, atención médica).
3. **Para realizar, promover o facilitar actividades peligrosas, ilegales o
maliciosas, entre lo que se incluye lo siguiente:**
1. promover o facilitar la venta de sustancias, bienes o servicios
ilegales, o bien proporcionar instrucciones para sintetizarlos o acceder
a ellos,
2. abusar, interferir, dañar o interrumpir servicios, lo que incluye
generar o distribuir contenido para actividades engañosas o fraudulentas
o software malicioso,
3. generar o distribuir contenido, incluidos los Resultados, que incumpla,
se apropie indebidamente o infrinja los derechos de un individuo o una
entidad (incluidos, sin limitaciones, los derechos de contenido
protegido por derechos de autor), o
4. intentar eludir o causar de forma intencional (directa o indirectamente)
que AlphaFold 3 actúe de manera que incumpla las Condiciones.
**No debe utilizar ni permitir que otros:**
1. **Utilicen los Resultados para entrenar o crear modelos de aprendizaje
automático o tecnología relacionada para la predicción de estructura
biomolecular similar a la de AlphaFold 3 ("Modelos Derivados"),** lo que
incluye métodos a través de destilación o de otro tipo. En aras de evitar
dudas, las restricciones de uso establecidas en las Condiciones se aplicarán
en su totalidad a cualquier Modelo Derivado que se cree incumpliendo las
Condiciones.
2. **Distribuir los Resultados sin brindar un aviso claro de que lo que usted
Distribuye se proporciona de acuerdo con las
[Condiciones de Uso de los Resultados de AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
y cualquier modificación que usted haga.**
1. Esto quiere decir que si usted quita o hace que se quiten (por ejemplo,
con software de terceros) los avisos y las condiciones que
proporcionamos cuando genera Resultados usando AlphaFold 3, debe
asegurarse de que cualquier Distribución adicional de los Resultados
esté acompañada por una copia de las
[Condiciones de Uso de Resultados de AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
y un archivo de texto llamado "Condiciones de Uso Legalmente
Vinculantes" que contenga el siguiente aviso:
"*Si utiliza esta información, usted acepta las Condiciones de Uso de
Resultados de AlphaFold 3, que se encuentran en
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*Para solicitar acceso a los parámetros del modelo AlphaFold 3, siga el
proceso que se establece en
https://github.com/google-deepmind/alphafold3. Solo puede usarlos si los
recibe directamente de Google. El uso está sujeto a las Condiciones de
Uso disponibles en
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*"
2. No debe incluir ninguna condición adicional o diferente que entre en
conflicto con las
[Condiciones de Uso de Resultados de AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
3. **Distribuir Resultados, o divulgar descubrimientos que surjan del uso de
AlphaFold 3 sin citar nuestro artículo** [Abramson, J et al. Accurate
structure prediction of biomolecular interactions with AlphaFold 3
(Predicción precisa de la estructura de las interacciones biomoleculares con
AlphaFold 3). *Nature*
(2024)](https://www.nature.com/articles/s41586-024-07487-w). En aras de
evitar dudas, este es un requisito adicional a los requisitos de aviso que
se establecen más arriba.
4. **Eludir las restricciones de acceso relacionadas con los Parámetros del
Modelo, lo que incluye utilizar, compartir o poner a disponibilidad los
Parámetros del Modelo cuando no recibió autorización expresa por parte de
Google para hacerlo.** Google otorgará acceso a los Parámetros del Modelo a:
1. Usted, para su uso individual en nombre de su organización, en cuyo
caso no puede compartir su copia de los Parámetros del Modelo con nadie
más, o
2. Un representante autorizado de su organización con autoridad legal total
para obligar a esa organización con estas Condiciones (en cuyo caso
usted podrá compartir la copia de los Parámetros del Modelo de esa
organización con empleados, consultores, contratistas y agentes de la
organización, según lo autorizado por ese representante).
================================================
FILE: legal/WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md
================================================
# POLITIQUE D'UTILISATION INTERDITE DES PARAMÈTRES DU MODÈLE ALPHAFOLD 3
Dernière modification: 2024-11-09
AlphaFold 3 peut vous aider à accélérer la recherche scientifique en prévoyant
la structure 3D des molécules biologiques. Pour certaines utilisations non
commerciales, Google met gratuitement à disposition les Éléments d'AlphaFold
dans le respect des restrictions énoncées ci-dessous. Cette politique utilise
les mêmes termes définis que les
[Conditions d'utilisation des paramètres du modèle AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_TERMS_OF_USE-Francais-Canada.md).
**Vous ne devez pas accéder aux Éléments d'AlphaFold 3 ni les utiliser ou
permettre à d'autres personnes de le faire:**
1. **Au nom d'une organisation commerciale ou en connexion avec des activités
commerciales, y compris la recherche au nom d'organisations commerciales.**
1. Cela signifie que seules les organisations non commerciales (*c.-à-d*.
universités, organismes sans but lucratif, instituts de recherche et
organismes éducatifs, journalistiques et gouvernementaux) peuvent
utiliser les Éléments d'AlphaFold 3 dans le cadre de leurs activités non
commerciales. Les Éléments AlphaFold 3 ne sont pas offerts à d'autres
types d'organisations, même si elles effectuent des travaux non
commerciaux.
2. Si vous êtes un chercheur affilié à une organisation non commerciale, à
la condition que **vous ne soyez pas une organisation commerciale et que
vous n'agissiez pas au nom d'une organisation commerciale**, vous pouvez
utiliser les Éléments AlphaFold 3 pour vos recherches affiliées non
commerciales.
3. Vous ne devez pas partager les Éléments d'AlphaFold 3 avec une
organisation commerciale ni les utiliser d'une manière qui confère à une
organisation commerciale des droits sur ces éléments. La seule exception
est la mise à disposition publique des Résultats (y compris
indirectement à des organisations commerciales) par le biais d'une
publication scientifique ou d'une version open source, ou l'utilisation
de ces résultats pour soutenir le journalisme, qui sont toutes
autorisées.
2. **Pour désinformer ou déformer ou induire en erreur**, y compris:
1. fournir des informations fausses ou inexactes concernant votre accès à
AlphaFold 3 ou à ses Résultats, ou à l'utilisation de ceux-ci, y compris
l'accès aux Paramètres du modèle ou l'utilisation de ceux-ci au nom
d'une organisation sans nous en informer ou sans soumettre une demande
d'accès aux Paramètres du modèle lorsque Google a interdit l'utilisation
d'AlphaFold 3 en totalité ou en partie (y compris tel que mis à
disposition par le biais du
[Serveur d'AlphaFold](https://alphafoldserver.com/about));
2. présenter de manière inexacte votre relation avec nous, y compris en
utilisant les marques de commerce, les noms commerciaux et les logos de
Google ou en suggérant l'approbation de Google sans son autorisation.
Rien dans les présentes Conditions ne permet d'accorder une telle
autorisation;
3. présenter de manière inexacte l'origine d'AlphaFold 3, en tout ou en
partie;
4. distribuer des déclarations trompeuses quant au savoir-faire ou aux
capacités, ou exercer une activité professionnelle sans autorisation ou
sans licence, en particulier dans des domaines sensibles (*p. ex.* les
soins de santé); ou
5. prendre des décisions dans des domaines qui touchent les droits
matériels ou individuels ou le bien-être (*p. ex.* les soins de santé).
3. **Pour effectuer ou faciliter des activités dangereuses, illégales ou
malveillantes**, y compris:
1. la promotion ou l'aide à la vente, ou la fourniture d'instructions pour
synthétiser ou accéder à des substances, des biens ou des services
illégaux, ou l'accès à ces derniers;
2. abuser, nuire, interférer ou perturber tout service, y compris en
générant ou en distribuant du contenu pour des activités trompeuses ou
frauduleuses ou pour des logiciels malveillants;
3. générer ou distribuer tout contenu, y compris des Résultats, qui
enfreigne, détourne ou viole de toute autre manière les droits d'un
individu ou d'une entité (y compris, mais sans s'y limiter, les droits
sur les contenus protégés par des droits d'auteur); ou
4. tenter de contourner, ou amener intentionnellement (directement ou
indirectement) AlphaFold 3 à agir d'une manière qui contrevient aux
Conditions.
**Vous ne devez pas, et vous ne devez pas permettre aux autres:**
1. **D'utiliser les Résultats pour entraîner ou créer des modèles
d'apprentissage automatique ou une technologie connexe pour la prédiction de
la structure biomoléculaire semblable à AlphaFold 3 (« Modèles dérivés »)**,
y compris par distillation ou d'autres méthodes. Pour éviter le doute, les
restrictions d'utilisation énoncées dans les présentes Conditions
s'appliquent intégralement à tout Modèle dérivé créé en violation des
présentes Conditions.
2. **De Distribuer les Résultats sans indiquer clairement que ce que vous
Distribuez est fourni dans le cadre et sous réserve des
[Conditions d'utilisation des résultats d'AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
et de toutes les modifications que vous y apportez.**
1. Cela signifie que si vous retirez, ou faites retirer (par exemple en
utilisant un logiciel tiers), les avis et les conditions d'utilisation
que nous fournissons lorsque vous générez des Résultats à l'aide
d'AlphaFold 3, vous devez vous assurer que toute Distribution ultérieure
de Résultats est accompagnée d'une copie des
[Conditions d'utilisation des résultats d'AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
et d'un fichier texte des « Conditions d'utilisation légalement
contraignantes » qui contient l'avis suivant:
« *En utilisant cette information, vous acceptez les Conditions
d'utilisation des résultats d'AlphaFold 3 qui se trouvent à l'adresse
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*Pour demander l'accès aux paramètres du modèle AlphaFold 3, suivez le
processus décrit à l'adresse
https://github.com/google-deepmind/alphafold3. Vous ne pouvez les
utiliser que si vous les recevez directement de Google. L'utilisation
est soumise aux conditions d'utilisation disponibles à l'adresse
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*
»
2. Vous ne devez pas inclure de conditions d'utilisation supplémentaires ou
différentes qui seraient en contradiction avec les
[Conditions d'utilisation des résultats d'AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
3. **De Distribuer les Résultats ou de divulguer les résultats découlant de
l'utilisation d'AlphaFold 3 sans citer notre article:** « [Abramson, J et
al. Accurate structure prediction of biomolecular interactions with
AlphaFold 3. *Nature*
(2024)](https://www.nature.com/articles/s41586-024-07487-w) ». Pour éviter
toute ambiguïté, il s'agit d'une exigence supplémentaire par rapport aux
exigences de notification énoncées ci-dessus.
4. **De contourner les restrictions d'accès relatives aux Paramètres du modèle,
y compris l'utilisation, le partage ou la mise à disposition des Paramètres
du modèle alors que vous n'y avez pas été expressément autorisé par
Google.** Google accordera l'accès aux Paramètres du modèle à soit:
1. vous, pour votre utilisation individuelle au nom de votre organisation,
auquel cas vous ne pouvez pas partager votre copie des Paramètres du
modèle avec quelqu'un d'autre; ou
2. un représentant autorisé de votre organisation, disposant de la pleine
autorité légale pour lier cette organisation aux présentes Conditions,
auquel cas vous pouvez partager la copie des Paramètres du modèle de
cette organisation avec les employés, les consultants, les entrepreneurs
et les agents de l'organisation, tel qu'autorisé par ce représentant.
================================================
FILE: legal/WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md
================================================
# POLÍTICA DE USO PROIBIDO DOS PARÂMETROS DO MODELO ALPHAFOLD 3
Última modificação: 2024-11-09
O AlphaFold 3 ajuda você a acelerar pesquisas científicas ao prever a estrutura
3D de moléculas biológicas. O Google disponibiliza os Recursos do AlphaFold sem
custo financeiro para certos usos não comerciais, de acordo com as restrições
abaixo. Esta política usa os mesmos termos definidos nos
[Termos de Uso dos Parâmetros do Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_TERMS_OF_USE-Portugues-Brazil.md).
**Você não deve acessar, usar nem permitir que outras pessoas acessem ou usem os
Recursos do AlphaFold 3 nos seguintes casos:**
1. **Em nome de uma organização comercial ou em associação a atividades
comerciais, incluindo pesquisas em nome de organizações comerciais.**
1. Isso significa que apenas organizações não comerciais (*ou seja*,
universidades, organizações sem fins lucrativos, institutos de pesquisa
e órgãos governamentais, educacionais e de notícias) podem usar os
Recursos do AlphaFold 3 para suas atividades não comerciais. Os Recursos
do AlphaFold 3 não estão disponíveis para qualquer outro tipo de
organização, mesmo as que conduzem trabalhos não comerciais.
2. Se você for um pesquisador afiliado a uma organização não comercial,
você tem permissão para usar esses recursos em sua pesquisa afiliada
não comercial, desde que você não seja uma
organização comercial nem esteja agindo em nome de uma.
3. É proibido compartilhar os Recursos do AlphaFold 3 com qualquer
organização comercial ou usar os Recursos do AlphaFold 3 de modo a
conceder a uma organização comercial qualquer direito em relação a eles.
A única exceção é a disponibilização da Saída para o público (incluindo
indiretamente para organizações comerciais) mediante uma publicação
científica, versão de código aberto ou em apoio ao jornalismo, o que é
permitido.
2. **Para gerar desinformação, deturpar ou enganar**, incluindo:
1. fornecer informações falsas ou imprecisas em relação ao seu acesso ou
uso do AlphaFold 3 ou da Saída gerada, incluindo acessar ou usar os
Parâmetros do Modelo em nome de uma organização sem nos informar ou
solicitar o acesso aos Parâmetros do Modelo caso o Google tenha proibido
totalmente ou parcialmente seu uso do AlphaFold 3 (incluindo conforme
disponibilizado pelo
[Servidor do AlphaFold](https://alphafoldserver.com/about));
2. deturpar sua relação conosco, incluindo ao usar marcas registradas,
nomes comerciais e logotipos do Google, ou sugerir o endosso do Google
sem a nossa permissão – nada nestes Termos concede tal permissão;
3. deturpar a origem do AlphaFold 3 total ou parcialmente;
4. distribuir declarações enganosas sobre conhecimento ou capacidade, ou
participar do exercício não autorizado ou não licenciado de qualquer
profissão, especialmente em áreas sensíveis (*por exemplo*, saúde); ou
5. tomar decisões em áreas que afetam o bem-estar ou direitos materiais ou
individuais (*por exemplo*, saúde).
3. **Para realizar, promover ou facilitar atividades perigosas, ilegais ou
maliciosas**, incluindo:
1. promover ou facilitar a venda ou fornecer instruções para sintetizar ou
ter acesso a substâncias, produtos ou serviços ilegais;
2. abusar, prejudicar, interferir ou interromper quaisquer serviços,
incluindo gerar ou distribuir conteúdo para atividades enganosas ou
fraudulentas ou malware;
3. gerar ou distribuir qualquer conteúdo, incluindo a Saída, que infrinja,
se aproprie indevidamente ou viole de outra forma os direitos de
qualquer indivíduo ou entidade (incluindo, mas não se limitando a
direitos autorais do conteúdo); ou
4. tentar burlar ou levar intencionalmente (direta ou indiretamente) o
AlphaFold 3 a agir de maneira que viole os Termos.
**Não é permitido que você nem outras pessoas:**
1. **Usem os Resultados para treinar ou criar modelos de aprendizado de máquina
ou tecnologias relacionadas para previsão de estrutura biomolecular
semelhante ao AlphaFold 3 ("Modelos Derivados"),** incluindo pela destilação
ou outros métodos. Para evitar dúvidas, as restrições de uso definidas nos
Termos são totalmente válidas para quaisquer Modelos Derivados criados em
violação dos Termos.
2. **Distribuam a Saída sem apresentar aviso evidente de que o que você
Distribui é oferecido de acordo com e sujeito aos
[Termos de Uso dos Resultados do AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
e quaisquer modificações realizadas.**
1. Isso significa que, se você remover ou causar a remoção (por exemplo,
usando um software de terceiros) dos avisos e termos que fornecemos
quando você gera Resultados usando o AlphaFold 3, você precisa garantir
que qualquer Distribuição posterior da Saída esteja acompanhada de uma cópia
dos
[Termos de Uso dos Resultados do AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md)
e de um arquivo de texto "Termos de Uso Juridicamente Vinculativos" com
o seguinte aviso:
"*Ao usar estas informações, você concorda com os Termos de Uso da Saída
do AlphaFold 3 disponíveis em
https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md.*
*Para solicitar acesso aos parâmetros do modelo AlphaFold 3, siga o
processo descrito em https://github.com/google-deepmind/alphafold3. Você
só pode usar os parâmetros se os receber diretamente do Google. O uso
está sujeito aos Termos de Uso disponíveis em
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.*"
2. É proibido incluir quaisquer termos adicionais ou diferentes que entrem
em conflito com os
[Termos de Uso da Saída do AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
3. **Distribuam a Saída ou divulguem descobertas provenientes do uso do
AlphaFold 3 sem citar nosso artigo:** [Abramson, J et al. Accurate structure
prediction of biomolecular interactions with AlphaFold 3. *Nature*
(2024)](https://www.nature.com/articles/s41586-024-07487-w). Para evitar
dúvidas, esse é um requisito adicional às exigências de aviso definidas
acima.
4. **Burlem as restrições de acesso relacionadas aos Parâmetros do Modelo,
incluindo usar, compartilhar ou disponibilizar os Parâmetros do Modelo sem
autorização explícita do Google.** O Google concederá acesso aos Parâmetros
do Modelo a:
1. você, para uso individual em nome da sua organização, sendo proibido
compartilhar sua cópia dos Parâmetros do Modelo com qualquer indivíduo;
ou
2. um representante autorizado da sua organização, com autoridade legal
total para vincular tal organização a estes Termos, sendo permitido
compartilhar a cópia dos Parâmetros do Modelo pertencente a essa
organização com funcionários, consultores, prestadores de serviço e
agentes da organização, conforme autorizado por esse representante.
================================================
FILE: legal/WEIGHTS_TERMS_OF_USE-Bahasa-Indonesia.md
================================================
# PERSYARATAN PENGGUNAAN PARAMETER MODEL ALPHAFOLD 3
Terakhir diubah: 09-11-2024
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
adalah model AI yang dikembangkan oleh
[Google DeepMind](https://deepmind.google/) dan
[Isomorphic Labs](https://www.isomorphiclabs.com/). Program ini membuat prediksi
struktur 3D molekul biologis, serta memberikan keyakinan model untuk prediksi
struktur tersebut. Kami menyediakan parameter model terlatih dan output yang
dihasilkan menggunakan parameter tersebut tanpa biaya untuk penggunaan
non-komersial tertentu, sesuai dengan persyaratan penggunaan ini dan
[Kebijakan Penggunaan Terlarang untuk Parameter Model AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md).
**Hal penting yang perlu diketahui saat menggunakan parameter dan output model
AlphaFold 3**
1. Parameter dan output model AlphaFold 3 hanya tersedia untuk penggunaan
non-komersial oleh, atau atas nama, organisasi non-komersial (*yaitu*
universitas, organisasi nonprofit dan institusi riset, serta lembaga
pendidikan, jurnalistik, dan pemerintah). Jika Anda adalah peneliti yang
berafiliasi dengan organisasi non-komersial, dengan syarat Anda bukan
organisasi komersial dan tidak bertindak atas nama organisasi komersial,
Anda dapat menggunakannya untuk riset terafiliasi non-komersial Anda.
2. Anda tidak boleh menggunakan atau mengizinkan orang lain menggunakan:
1. Parameter atau Output model AlphaFold 3 sehubungan dengan aktivitas
komersial apa pun, termasuk riset atas nama organisasi komersial; atau
2. Output AlphaFold 3 untuk melatih model machine learning atau teknologi
terkait untuk prediksi struktur biomolekuler yang mirip dengan AlphaFold
3.
3. Anda *tidak boleh* mempublikasikan atau membagikan parameter model AlphaFold
3, kecuali membagikannya dalam organisasi Anda sesuai dengan Persyaratan
ini.
4. Anda *dapat* mempublikasikan, membagikan, dan mengadaptasi *output*
AlphaFold 3 sesuai dengan Persyaratan ini, termasuk persyaratan untuk
memberikan pemberitahuan yang jelas atas setiap modifikasi yang Anda buat
dan bahwa penggunaan yang sedang berlangsung atas output AlphaFold 3 dan
turunannya tunduk pada
[Persyaratan Penggunaan Output AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
Dengan menggunakan, mereproduksi, memodifikasi, menjalankan, mendistribusikan,
atau menampilkan bagian atau elemen apa pun dari Parameter Model (sebagaimana
didefinisikan di bawah) atau menyetujui persyaratan perjanjian ini, Anda setuju
untuk terikat oleh (1) persyaratan penggunaan ini, dan (2)
[Kebijakan Penggunaan Terlarang untuk Parameter Model AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md)
yang disertakan di sini sebagai referensi (secara kolektif disebut
"**Persyaratan**"), dalam setiap kasus (a) sebagaimana diubah dari waktu ke
waktu sesuai dengan Persyaratan, serta (b) antara Anda dan (i) Google Ireland
Limited, jika Anda berasal dari negara di Wilayah Ekonomi Eropa atau Swiss, atau
(ii) Google LLC, jika Anda berasal dari wilayah lain.
Anda mengonfirmasi bahwa Anda berwenang baik secara eksplisit maupun implisit
untuk masuk, dan sedang memasuki, ke dalam Persyaratan ini sebagai karyawan yang
mewakili, atau atas nama, organisasi Anda.
Harap baca Persyaratan ini dengan cermat. Persyaratan ini menetapkan apa yang
dapat Anda harapkan dari kami saat Anda mengakses dan menggunakan Aset AlphaFold
3 (sebagaimana didefinisikan di bawah), dan apa yang Google harapkan dari Anda.
Penyebutan "**Anda**" di sini mengacu pada individu atau organisasi yang
menggunakan Aset AlphaFold 3. Penyebutan "**kami**", "**kita**", atau
"**Google**" di sini mengacu pada entitas milik grup perusahaan Google, yaitu
Google LLC beserta afiliasinya.
## 1. Definisi Penting
Sebagaimana digunakan dalam Persyaratan ini:
"**AlphaFold 3**" adalah: (a) kode sumber AlphaFold 3 yang disediakan
[di sini](https://github.com/google-deepmind/alphafold3/) dan yang dilisensikan
berdasarkan persyaratan lisensi Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC-BY-NC-SA 4.0) dan
kode sumber turunan apa pun, serta (b) Parameter Model.
"**Aset AlphaFold 3**" adalah Parameter Model dan Output.
"**Distribusi**" atau "**Mendistribusikan**" adalah mengirimkan,
mempublikasikan, atau membagikan Output secara publik atau kepada orang lain.
"**Parameter Model**" adalah bobot dan parameter model terlatih yang disediakan
oleh Google bagi organisasi (atas pertimbangannya sendiri) untuk digunakan
sesuai dengan Persyaratan ini, bersama dengan (a) modifikasi pada bobot dan
parameter tersebut, (b) pekerjaan yang didasarkan pada bobot dan parameter
tersebut, atau (c) kode atau model machine learning lainnya yang menggabungkan,
seluruh atau sebagian, bobot dan parameter tersebut.
"**Output**" adalah prediksi struktur serta semua informasi tambahan dan
informasi terkait yang disediakan oleh AlphaFold 3 atau penggunaan Parameter
Model, bersama dengan representasi visual, prediksi komputasional, deskripsi,
modifikasi, salinan, atau adaptasi apa pun yang secara substansial berasal dari
Output.
"**Termasuk**" adalah "**termasuk, tetapi tidak terbatas pada**".
## 2. Mengakses dan menggunakan Aset AlphaFold 3
Dengan tunduk pada kepatuhan Anda terhadap Persyaratan, termasuk
[Kebijakan Penggunaan Terlarang untuk Parameter Model AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md),
Anda dapat mengakses, menggunakan, dan memodifikasi Aset AlphaFold 3 serta
Mendistribusikan Output sebagaimana ditetapkan dalam Persyaratan ini. Kami
memberi Anda lisensi non-eksklusif, bebas royalti, dapat dibatalkan, tidak dapat
dipindahtangankan, dan tidak dapat disublisensikan (kecuali secara tegas
diizinkan dalam Persyaratan ini) untuk hak atas kekayaan intelektual apa pun
yang kami miliki dalam Aset AlphaFold 3 sejauh diperlukan untuk tujuan ini. Untuk
memverifikasi akses dan penggunaan AlphaFold 3 oleh Anda, kami dapat meminta
Anda memberikan informasi tambahan dari waktu ke waktu, termasuk verifikasi
nama, organisasi, serta informasi identitas Anda lainnya.
Dengan mengakses, menggunakan, atau memodifikasi Aset AlphaFold 3,
Mendistribusikan Output, atau meminta akses ke Parameter Model, Anda menyatakan
dan menjamin bahwa (a) Anda memiliki kuasa dan wewenang penuh untuk menyetujui
Persyaratan ini (termasuk telah berusia dewasa), (b) Google sebelumnya tidak
pernah menghentikan akses dan hak Anda untuk menggunakan AlphaFold 3 (termasuk
yang disediakan melalui [Server AlphaFold](https://alphafoldserver.com/about))
karena pelanggaran Anda terhadap persyaratan penggunaan yang berlaku, (c)
menyetujui atau menjalankan hak dan kewajiban Anda berdasarkan Persyaratan ini
tidak akan melanggar hak pihak ketiga mana pun atau perjanjian yang Anda
sepakati dengan pihak ketiga, (d) informasi apa pun yang Anda berikan ke Google
sehubungan dengan AlphaFold 3, termasuk (jika berlaku) untuk meminta akses ke
Parameter Model, sudah benar dan aktual, serta (e) Anda bukan (i) berstatus
warga dari negara yang diembargo, (ii) berstatus menetap di negara yang
diembargo Amerika Serikat, atau (iii) dinyatakan dilarang oleh program sanksi
dan kontrol ekspor yang berlaku untuk mengakses, menggunakan, atau memodifikasi
Aset AlphaFold 3.
Jika Anda memilih untuk memberikan masukan ke Google, seperti saran untuk
meningkatkan kualitas AlphaFold 3, Anda setuju bahwa informasi tersebut tidak
bersifat rahasia dan tidak eksklusif, serta Google dapat menindaklanjuti masukan Anda
tanpa kewajiban kepada Anda.
## 3. Pembatasan Penggunaan
Anda tidak boleh menggunakan Aset AlphaFold 3 apa pun:
1. untuk penggunaan terbatas yang ditetapkan dalam
[Kebijakan Penggunaan Terlarang untuk Parameter Model AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Bahasa-Indonesia.md);
atau
2. dengan cara yang melanggar hukum dan peraturan yang berlaku.
Selama diizinkan oleh hukum dan tanpa membatasi hak kami lainnya, Google berhak
mencabut hak penggunaan Anda, dan (selama memungkinkan) membatasi penggunaan
Aset AlphaFold 3 apa pun yang menurut Google secara wajar melanggar Persyaratan
ini.
## 4. Output yang Dihasilkan
Meskipun Anda harus mematuhi Persyaratan ini saat menggunakan Aset AlphaFold 3,
kami tidak akan mengklaim kepemilikan atas Output orisinal yang Anda hasilkan
menggunakan AlphaFold 3. Namun, Anda memahami bahwa AlphaFold 3 dapat
menghasilkan Output yang sama atau mirip untuk beberapa pengguna, termasuk
Google, dan kami mempertahankan semua hak kami terkait Output tersebut.
## 5. Perubahan pada Aset AlphaFold 3 atau Persyaratan ini
Google dapat menambahkan atau menghapus fungsi atau fitur Aset AlphaFold 3 kapan
saja dan dapat berhenti menawarkan akses ke Aset AlphaFold 3 sepenuhnya.
Google dapat memperbarui Persyaratan ini dan mekanisme akses untuk Parameter
Model kapan saja. Kami akan memposting setiap perubahan pada Persyaratan
[di repositori GitHub AlphaFold 3](https://github.com/google-deepmind/alphafold3).
Perubahan umumnya akan berlaku 14 hari setelah diposting. Namun, perubahan yang
berkaitan dengan fungsi atau yang dibuat karena alasan hukum akan langsung
berlaku.
Anda harus meninjau Persyaratan ini setiap kali kami memperbaruinya atau saat
Anda menggunakan Aset AlphaFold 3. Jika Anda tidak menyetujui perubahan pada
Persyaratan, Anda harus segera menghentikan penggunaan Aset AlphaFold 3.
## 6. Menangguhkan atau menghentikan hak Anda untuk menggunakan Aset AlphaFold 3
Google dapat sewaktu-waktu menangguhkan atau menghentikan hak Anda untuk
menggunakan dan mengakses Aset AlphaFold 3 sebagaimana berlaku karena, antara
lain, kegagalan Anda untuk sepenuhnya mematuhi Persyaratan. Jika Google
menangguhkan atau menghentikan hak Anda untuk mengakses atau menggunakan Aset
AlphaFold 3, Anda harus segera menghapus dan menghentikan penggunaan serta
Distribusi semua salinan Aset AlphaFold 3 yang Anda miliki atau kontrol, dan
Anda dilarang menggunakan Aset AlphaFold 3, termasuk mengajukan permohonan untuk
menggunakan Parameter Model. Google akan berupaya memberikan pemberitahuan
sewajarnya kepada Anda sebelum penangguhan atau penghentian tersebut. Namun,
Anda tidak akan menerima pemberitahuan atau peringatan sebelumnya jika
penangguhan atau penghentian tersebut terjadi karena Anda tidak sepenuhnya
mematuhi Persyaratan atau karena alasan serius lainnya.
Anda tentunya dapat menghentikan penggunaan Aset AlphaFold 3 kapan saja. Jika
Anda berhenti menggunakannya, harap beri tahu kami alasannya (melalui
alphafold@google.com) sehingga kami dapat terus meningkatkan kualitas teknologi
kami.
## 7. Kerahasiaan
Anda setuju untuk tidak mengungkapkan atau menyediakan Informasi Rahasia Google
kepada siapa pun tanpa izin tertulis dari kami sebelumnya. "**Informasi Rahasia
Google**" berarti (a) Parameter Model AlphaFold 3 dan semua software, teknologi,
serta dokumentasi yang terkait dengan AlphaFold 3, kecuali kode sumber AlphaFold
3, dan (b) informasi lain apa pun yang disediakan oleh Google yang ditandai
sebagai rahasia atau umumnya dianggap rahasia berdasarkan penyajian informasi
tersebut. Informasi Rahasia Google tidak mencakup (a) informasi yang sudah Anda
ketahui sebelum Anda mengakses atau menggunakan Aset AlphaFold 3 (termasuk
melalui [Server AlphaFold](https://alphafoldserver.com/about)), (b) yang
terungkap ke publik bukan karena kesalahan Anda (misalnya, pelanggaran Anda
terhadap Persyaratan ini), (c) yang Anda kembangkan sendiri tanpa mengacu pada
Informasi Rahasia Google, atau (d) yang diberikan kepada Anda oleh pihak ketiga
sesuai hukum yang berlaku (tanpa Anda atau pihak ketiga tersebut melanggar
Persyaratan).
## 8. Pernyataan penyangkalan
Tidak ada di dalam Persyaratan membatasi hak apa pun yang tidak dapat dibatasi
berdasarkan hukum yang berlaku atau membatasi tanggung jawab Google kecuali
sebagaimana diizinkan oleh hukum yang berlaku.
**AlphaFold 3 dan Output disediakan "apa adanya", tanpa jaminan atau ketentuan
apa pun, baik tersurat maupun tersirat, termasuk jaminan atau ketentuan tentang
kepemilikan, ketiadaan pelanggaran, kelayakan untuk diperdagangkan, atau
kesesuaian untuk tujuan tertentu. Anda bertanggung jawab sepenuhnya untuk
menentukan kesesuaian penggunaan AlphaFold 3, atau penggunaan atau
pendistribusian Output, dan menanggung semua risiko yang terkait dengan
penggunaan atau pendistribusian tersebut serta pelaksanaan hak dan kewajiban
oleh Anda berdasarkan Persyaratan ini. Anda dan siapa pun yang Anda beri Output
bertanggung jawab sepenuhnya atas Output tersebut serta penggunaannya
selanjutnya.**
**Output merupakan prediksi dengan tingkat keyakinan yang berbeda-beda dan harus
ditafsirkan dengan cermat. Gunakan pertimbangan sebelum mengandalkan,
memublikasikan, mendownload, atau menggunakan AlphaFold 3.**
**AlphaFold 3 dan Output hanya ditujukan untuk pemodelan teoretis. Aset tersebut
tidak dimaksudkan, divalidasi, atau disetujui untuk penggunaan klinis. Anda
tidak boleh menggunakan AlphaFold 3 atau Output untuk tujuan klinis atau
mengandalkannya untuk saran medis atau profesional lainnya. Konten apa pun
terkait topik tersebut hanya diberikan untuk tujuan informasi dan bukan
merupakan pengganti saran dari profesional yang berkualifikasi.**
## 9. Kewajiban
Selama diizinkan hukum yang berlaku, Anda akan melindungi Google serta direktur,
petugas, karyawan, dan kontraktornya terhadap kerugian dari proses hukum pihak
ketiga (termasuk tindakan oleh otoritas pemerintah) yang timbul dari atau
berkaitan dengan penggunaan Aset AlphaFold 3 oleh Anda yang melanggar hukum atau
pelanggaran Anda terhadap Persyaratan. Perlindungan terhadap kerugian ini
mencakup kewajiban atau pengeluaran yang timbul dari klaim, kerugian, kerusakan,
putusan pengadilan, denda, biaya proses pengadilan, dan biaya hukum, kecuali
jika kewajiban atau pengeluaran disebabkan oleh pelanggaran, kelalaian, atau
perbuatan tidak pantas yang disengaja oleh Google. Jika Anda dikecualikan secara
hukum dari tanggung jawab tertentu, termasuk perlindungan terhadap kerugian,
tanggung jawab tersebut tidak berlaku bagi Anda berdasarkan Persyaratan.
Dalam keadaan apa pun, Google tidak akan bertanggung jawab atas ganti rugi tidak
langsung, ganti rugi khusus, ganti rugi insidental, ganti rugi sebagai
peringatan, ganti rugi sebagai akibat, atau ganti rugi penghukuman, atau
hilangnya keuntungan dalam bentuk apa pun sehubungan dengan Persyaratan atau
Aset AlphaFold 3, meskipun Google telah diberi tahu tentang kemungkinan adanya
ganti rugi tersebut. Total kewajiban kumulatif Google untuk semua klaim yang
timbul dari atau sehubungan dengan Persyaratan atau Aset AlphaFold 3, termasuk
karena kelalaiannya sendiri, dibatasi hingga $500.
## 10. Ketentuan lainnya
Secara hukum, Anda memiliki hak tertentu yang tidak dapat dibatasi oleh kontrak
seperti Persyaratan. Persyaratan sama sekali tidak dimaksudkan untuk membatasi
hak tersebut.
Persyaratan merupakan keseluruhan perjanjian kami terkait penggunaan Aset
AlphaFold 3 oleh Anda dan menggantikan perjanjian sebelumnya atau pada saat yang
sama yang menyangkut penggunaan tersebut.
Jika ternyata ada ketentuan dalam Persyaratan yang tidak memiliki kekuatan
hukum, ketentuan lainnya dalam Persyaratan akan tetap berlaku dan memiliki
kekuatan hukum penuh.
## 11. Sengketa
Hukum California akan mengatur semua sengketa yang timbul dari atau berkaitan
dengan Persyaratan atau sehubungan dengan Aset AlphaFold 3. Sengketa ini akan
diselesaikan secara eksklusif di pengadilan federal atau negara bagian Santa
Clara County, California, Amerika Serikat dan Anda serta Google menyetujui
wilayah hukum pribadi di pengadilan tersebut. Jika hukum setempat yang berlaku
mencegah sengketa tertentu diselesaikan di pengadilan California, Anda dan
Google dapat mengajukan sengketa tersebut di pengadilan setempat Anda. Jika
hukum setempat yang berlaku mencegah pengadilan setempat Anda menerapkan hukum
California untuk menyelesaikan sengketa ini, sengketa ini akan diatur oleh hukum
setempat yang berlaku dari negara, negara bagian, atau tempat tinggal Anda yang
lain. Jika Anda menggunakan Aset AlphaFold 3 atas nama organisasi pemerintah
selain organisasi pemerintah federal Amerika Serikat (dengan ketentuan yang
disebutkan sebelumnya akan berlaku selama diizinkan oleh hukum federal),
Persyaratan ini tidak akan berlaku untuk pengadilan dan hukum yang mengatur.
Mengingat sifat riset ilmiah, mungkin perlu waktu beberapa saat hingga
pelanggaran terhadap Persyaratan terlihat jelas. Untuk melindungi Anda, Google,
dan Aset AlphaFold 3, selama diizinkan hukum yang berlaku, Anda setuju bahwa:
1. klaim hukum apa pun terkait Persyaratan atau Aset AlphaFold 3 dapat diajukan
hingga tanggal mana pun yang lebih akhir di antara:
1. tanggal batas waktu berdasarkan hukum yang berlaku untuk mengajukan
klaim hukum; atau
2. dua tahun sejak tanggal Anda atau Google (sebagaimana berlaku)
mengetahui, atau seharusnya secara wajar mengetahui, fakta yang
menimbulkan klaim tersebut; dan
2. Anda dan Google tidak akan memperdebatkan pembatasan, batas waktu,
penundaan, pelepasan hak, atau sejenisnya dalam upaya untuk menghalangi
gugatan yang diajukan dalam jangka waktu tersebut.
Semua hak yang tidak secara khusus dan tegas diberikan kepada Anda oleh
Persyaratan menjadi hak milik Google. Penundaan, tindakan, atau kelalaian oleh
Google dalam melaksanakan hak atau upaya hukum apa pun tidak akan dianggap
sebagai pelepasan hak atas pelanggaran terhadap Persyaratan, dan Google secara
tegas memiliki semua hak dan upaya hukum yang tersedia berdasarkan Persyaratan,
hukum, ekuitas, atau lainnya, termasuk upaya hukum yang menyangkut penyelesaian
dengan perintah pengadilan atas setiap ancaman atau pelanggaran nyata terhadap
Persyaratan tanpa perlu membuktikan kerugian yang sebenarnya.
================================================
FILE: legal/WEIGHTS_TERMS_OF_USE-Espanol-Latinoamerica.md
================================================
# CONDICIONES DE USO DE LOS PARÁMETROS DEL MODELO ALPHAFOLD 3
Última modificación: 9 de noviembre de 2024
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
es un modelo de IA desarrollado por [Google DeepMind](https://deepmind.google/)
y por [Isomorphic Labs](https://www.isomorphiclabs.com/). Genera predicciones de
estructuras 3D de moléculas biológicas, lo que proporciona confianza del modelo
para las predicciones de estructuras. Ponemos a disposición sin costo los
parámetros del modelo entrenado y los resultados generados con ellos para
determinados usos no comerciales de conformidad con las condiciones de uso y la
[Política de Uso Prohibido de los Parámetros del Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md).
**Puntos clave para tener en cuenta al usar los parámetros y los resultados del
modelo AlphaFold 3**
1. Los parámetros y los resultados del modelo AlphaFold 3 solo están
disponibles para usos no comerciales de organizaciones no comerciales (*es
decir*, universidades, organizaciones sin fines de lucro, instituciones de
investigación y organismos educativos, periodísticos y gubernamentales), o
bien en su nombre. Si usted es un investigador afiliado de una organización
no comercial, en la medida en que no pertenezca a una organización comercial
ni actúe en nombre de una, puede usar estos recursos para su investigación
de afiliación no comercial.
2. No debe utilizar ni permitir que otros utilicen AlphaFold 3 ni sus
parámetros o resultados en los siguientes casos:
1. En conexión con cualquier actividad comercial, incluidas investigaciones
en nombre de organizaciones comerciales; o
2. Para entrenar modelos de aprendizaje automático, o bien tecnologías
relacionadas para la predicción de estructuras biomoleculares, similares a
AlphaFold 3
3. No *debe* publicar ni compartir los parámetros del modelo AlphaFold 3,
excepto dentro de su organización, de acuerdo con estas Condiciones.
4. Puede publicar, compartir y adaptar los *resultados* de AlphaFold 3 de
conformidad con estas Condiciones, que incluyen el requisito de brindar un
aviso claro de que cualquier modificación que haga y el uso continuo de los
resultados de AlphaFold 3 y sus derivaciones están sujetos a las
[Condiciones de Uso de los Resultados de AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
Al usar, reproducir, modificar, realizar, distribuir o mostrar cualquier porción
o elemento de los Parámetros del Modelo (como se definen a continuación), o bien
al aceptar las condiciones de este acuerdo, usted se compromete a cumplir con lo
siguiente: (1) estas Condiciones de Uso y (2) la
[Política de Uso Prohibido de los Parámetros del Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md),
que se incorpora por referencia en este documento (en conjunto, las
"**Condiciones**"), en cada caso, (a) según las modificaciones ocasionales que
se hagan de acuerdo con las Condiciones y (b) entre usted y (i) si es de un país
del Espacio Económico Europeo o Suiza, Google Ireland Limited o (ii) en
cualquier otro caso, Google LLC.
Confirma que tiene la autorización explícita o implícita para celebrar, y está
celebrando, las Condiciones como empleado o de otra manera en nombre de su
organización.
Lea cuidadosamente estas Condiciones. En ellas, se establece lo que usted puede
esperar de nosotros cuando usa los Recursos de AlphaFold 3, como se describen a
continuación, y lo que Google espera de usted. Cuando decimos "**usted**",
hacemos referencia al individuo o la organización que usa los Recursos de
AlphaFold 3. Cuando decimos "**nosotros**" o "**Google**", hacemos referencia a
las entidades que pertenecen al grupo de empresas de Google, que comprende a
Google LLC y sus afiliadas.
## 1. Definiciones clave
Según su uso en estas Condiciones:
"**AlphaFold 3**" significa: (a) el código fuente de AlphaFold 3 disponible
[aquí](https://github.com/google-deepmind/alphafold3/) y con licencia en virtud
de las condiciones de la Atribución/Reconocimiento-NoComercial-CompartirIgual
4.0 Internacional (CC-BY-NC-SA 4.0) de Creative Commons, y cualquier código
fuente derivado, y (b) los Parámetros del Modelo.
"**Recursos de AlphaFold 3**" hace referencia a los Resultados y los Parámetros
del Modelo.
"**Distribución**" o "**Distribuir**" incluye cualquier transmisión, publicación y
otras instancias en las que se comparten los Resultados de manera pública o a
otra persona.
"**Parámetros del Modelo**" hace referencia a las ponderaciones y los parámetros
del modelo entrenado, que Google pone a disposición para las organizaciones (a
su entera discreción) para su uso de acuerdo con estas Condiciones, junto con
(a) las modificaciones a esas ponderaciones y parámetros, (b) los trabajos
basados en esas ponderaciones y parámetros, o bien (c) otros modelos de
aprendizaje automático y código que incorporan, en su totalidad o en parte,
estos parámetros y ponderaciones.
"**Resultados**" hace referencia a las predicciones de estructura y toda la
información adicional y relacionada que brinda AlphaFold 3 o el uso de los
Parámetros del Modelo, además de toda representación visual, predicción
computacional, descripción, modificación, copia o adaptación que esté
sustancialmente derivada de los Resultados.
"**Lo que incluye**" significa "**incluido, sin limitarse a ello**".
## 2. Acceso y uso de los Recursos de AlphaFold 3
Sujeto al cumplimiento de estas Condiciones, lo que incluye la
[Política de Uso Prohibido de los Parámetros del Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md),
puede acceder a los Recursos de AlphaFold 3, usarlos y modificarlos, y
Distribuir los Resultados como se define en estas Condiciones. Le otorgamos una
licencia no exclusiva, libre de regalías, revocable, no transferible y no
susceptible de someterse a otras licencias (excepto como se indica expresamente
en estas Condiciones) respecto de cualquier derecho de propiedad intelectual que
tengamos sobre los Recursos de AlphaFold en la medida necesaria para estos
propósitos. Para verificar su acceso a AlphaFold 3 y el uso correspondiente,
podríamos solicitarle ocasionalmente información adicional sobre usted, ya sea
que verifique su nombre, su organización o cualquier otra información
identificatoria.
Al acceder a los Recursos de AlphaFold 3 y usarlos o modificarlos, así como al
Distribuir Resultados o solicitar acceso a los Parámetros del Modelo, manifiesta
y garantiza que (a) tiene plenas facultades y atribuciones para celebrar estas
Condiciones (lo que incluye tener la edad de consentimiento), (b) Google nunca
rescindió en el pasado su acceso a AlphaFold 3 ni su derecho de uso (lo que
incluye su disponibilidad a través de
[AlphaFold Server](https://alphafoldserver.com/about)) debido a su
incumplimiento de las Condiciones de Uso correspondientes, (c) el cumplimiento
de estas Condiciones o el ejercicio de sus derechos y obligaciones no infringirá
ningún acuerdo que tenga con un tercero ni ningún derecho de terceros, (d)
cualquier información que usted proporcione a Google en relación con AlphaFold
3, incluida la necesaria (cuando corresponda) para solicitar acceso a los
Parámetros del Modelo, es correcta y actual, y (e) usted no (i) es residente de
un país bajo embargo, (ii) es residente de un país bajo el embargo de EE.UU. ni
(iii) tiene prohibido, en virtud de los controles de exportación y los programas
de sanciones aplicables, acceder a los Recursos de AlphaFold 3, usarlos o
modificarlos.
Si decide enviarle comentarios a Google, como sugerencias para mejorar AlphaFold
3, asegura que esa información no es confidencial ni de su propiedad, y que
Google puede actuar respecto de sus comentarios sin tener ninguna obligación con
usted.
## 3. Restricciones de uso
No debe usar ninguno de los Recursos de AlphaFold 3 en los siguientes casos:
1. Para los usos restringidos establecidos en la
[Política de Uso Prohibido de los Parámetros del Modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Espanol-Latinoamerica.md); o
2. En incumplimiento de las leyes y reglamentaciones aplicables.
En el sentido más amplio permitido por la ley y sin limitar ninguno de nuestros
otros derechos, Google se reserva el derecho de revocar su derecho de usar y (en
la medida que sea viable) restringir el uso de cualquiera de los Recursos de
AlphaFold 3 que Google razonablemente cree que infringe estas Condiciones.
## 4. Resultados generados
Aunque debe cumplir con estas Condiciones cuando usa los Recursos de AlphaFold
3, no reclamaremos la propiedad de los Resultados originales que genere usando
AlphaFold 3. Sin embargo, usted reconoce que AlphaFold 3 puede generar los
mismos Resultados, o bien otros similares, para varios usuarios, incluido
Google, y nos reservamos todos nuestros derechos al respecto.
## 5. Cambios en los Recursos de AlphaFold 3 o estas Condiciones
Google podría agregar o quitar funciones de los Recursos de AlphaFold 3 en
cualquier momento, y también podría quitar por completo el acceso a los Recursos
de AlphaFold 3.
Google podría actualizar estas Condiciones y el mecanismo de acceso a los
Parámetros del Modelo en cualquier momento. Publicaremos cualquier modificación
a las Condiciones
[en el repositorio de GitHub de AlphaFold 3](https://github.com/google-deepmind/alphafold3).
En general, los cambios entrarán en vigencia 14 días después de su publicación.
Sin embargo, los cambios relacionados con funciones o realizados por motivos
legales entrarán en vigencia de inmediato.
Debería revisar las Condiciones siempre que realicemos actualizaciones o que use
los Recursos de AlphaFold 3. Si no está de acuerdo con las modificaciones de las
Condiciones, debe dejar de usar los Recursos de AlphaFold 3 de inmediato.
## 6. Suspensión o rescisión de su derecho de uso de los Recursos de AlphaFold 3
Google puede, en cualquier momento, suspender o rescindir su derecho de uso y,
según corresponda, acceso a los Recursos de AlphaFold 3 debido a, entre otros
motivos, su incumplimiento de estas Condiciones. Si Google suspende o rescinde
su derecho de acceso o uso de los Recursos de AlphaFold 3, debe borrarlos y
dejar de usar y Distribuir todas las copias correspondientes que tenga en su
posesión o control, y se le prohibirá usar los Recursos de AlphaFold 3, lo que
incluye el envío de solicitudes para usar los Parámetros del Modelo. Google
tratará de darle un aviso con una antelación razonable antes de cualquier
suspensión o rescisión, pero no se le dará ningún aviso ni advertencia previos
si la suspensión o rescisión se debe a su incumplimiento de las Condiciones o
alguna otra razón grave.
Tenga en cuenta que puede dejar de usar los Recursos de AlphaFold 3 cuando lo
desee. Si los deja de usar, le agradeceríamos saber el motivo (a través de
alphafold@google.com) para que podamos continuar mejorando nuestras tecnologías.
## 7. Confidencialidad
Usted acepta no divulgar ni poner a disposición Información Confidencial de
Google sin obtener nuestro previo consentimiento por escrito. "Información
Confidencial de Google" hace referencia a (a) los Parámetros del Modelo
AlphaFold 3 y todo el software, la tecnología y la documentación relacionada con
AlphaFold 3, excepto el código fuente de AlphaFold 3, y (b) cualquier otra
información que Google ponga a disposición y se marque como confidencial o que
normalmente se consideraría confidencial en las circunstancias en las que se
presenta. La Información Confidencial de Google no incluye (a) información que
usted ya conocía antes de acceder a los Recursos de AlphaFold 3 o de usarlos
(incluso a través de
[AlphaFold Server](https://alphafoldserver.com/about)), (b) información que se
comparte de manera pública por una razón que no lo responsabiliza (por ejemplo,
su incumplimiento de las Condiciones), (c) información que usted desarrolló de
manera independiente sin hacer referencia a la Información Confidencial de
Google, o (d) información que recibió de manera legal de parte de un tercero
(sin que usted o ese tercero hayan incumplido las Condiciones).
## 8. Renuncias de responsabilidad
Ninguna disposición de las Condiciones restringe ningún derecho que no pueda
restringirse en función de la ley aplicable ni limita las responsabilidades de
Google, excepto según lo que permite la ley aplicable.
**AlphaFold 3 y los Resultados se brindan "tal cual son", sin garantías de
ningún tipo, ya sean explícitas o implícitas, lo que incluye garantías o
condiciones de titularidad, no incumplimiento, comerciabilidad o adecuación para
un propósito particular. Usted es el único responsable de determinar la
idoneidad del uso de AlphaFold 3, o bien del uso o la distribución de los
Resultados, y asume todos los riesgos asociados con ese uso o distribución y su
ejercicio de los derechos y las obligaciones según estas Condiciones. Usted y
todas las personas con quienes comparta los Resultados serán los únicos
responsables de estos usos y sus usos posteriores.**
**Los Resultados son predicciones con diversos niveles de confianza y deberían
interpretarse con cuidado. Sea prudente antes de basarse en el contenido de
AlphaFold 3, o bien publicarlo, descargarlo o usarlo de cualquier otro modo.**
**AlphaFold 3 y los Resultados deben usarse únicamente para el modelado teórico.
No están pensados, validados ni aprobados para uso clínico. No debe usar
AlphaFold 3 ni los Resultados con fines clínicos, ni basarse en ellos para dar
consejos médicos ni de índole profesional. Cualquier contenido relacionado con
esos temas se proporciona solo con fines informativos y no sustituye el
asesoramiento de un profesional calificado.**
## 9. Responsabilidades
En la medida en que lo permita la legislación aplicable, usted indemnizará a
Google y sus directores, funcionarios, empleados y contratistas por cualquier
procedimiento legal de terceros (incluidas las acciones de las autoridades
gubernamentales) que surja de su uso ilegal de los Recursos de AlphaFold 3 o del
incumplimiento de estas Condiciones. Esta indemnización cubrirá cualquier
responsabilidad o gasto que surja a partir de reclamos, pérdidas, daños,
juicios, multas, costos de litigios y honorarios legales, excepto en la medida
en que una responsabilidad o un gasto sean causados por un incumplimiento,
negligencia o conducta inapropiada voluntaria por parte de Google. Si en su caso
se aplica una exención legal de ciertas responsabilidades, lo que incluye la
indemnización, no deberá hacerse cargo de estas responsabilidades según estas
Condiciones.
En ningún caso Google será responsable de daños indirectos, especiales,
incidentales, ejemplares, resultantes ni punitivos, ni de la pérdida de
ganancias de ningún tipo en conexión con estas Condiciones o los Recursos de
AlphaFold 3, incluso si se advirtió a Google sobre la posibilidad de dichos
daños. La responsabilidad total acumulada de Google por todos los reclamos que
surjan en conexión
con estas Condiciones o los Recursos de AlphaFold 3, lo que incluye los que
surjan de su propia negligencia, se limita a USD 500.
## 10. Varios
Por ley, tiene ciertos derechos que no pueden estar limitados por un contrato,
como estas Condiciones. Las Condiciones no tienen la intención de restringir
esos derechos.
Las Condiciones son nuestro acuerdo completo relacionado con su uso de los
Recursos de AlphaFold 3 y sustituyen cualquier acuerdo anterior o contemporáneo sobre la
materia.
Si cualquier disposición de estas Condiciones resultase inejecutable, el resto
seguirá plenamente en vigencia.
## 11. Disputas
La ley de California regirá todas las disputas que surjan de las Condiciones o
en conexión con los Recursos de AlphaFold 3. Estas disputas se resolverán
exclusivamente en los tribunales federales o estatales del Condado de Santa
Clara, California, EE.UU., y usted y Google aceptan someterse a la jurisdicción
personal de dichos tribunales. En la medida en que la ley local aplicable impida
que ciertas disputas se resuelvan en un tribunal de California, usted y Google
pueden presentarlas en los tribunales locales de su jurisdicción. Si la ley
local aplicable impide que su tribunal local aplique la ley de California para
resolver las disputas, estas se regirán por las leyes locales aplicables de su
país, estado o lugar de residencia. Si usa los Recursos de AlphaFold 3 en
nombre de una organización gubernamental que no sea del gobierno federal de
Estados Unidos (donde se aplican las disposiciones mencionadas anteriormente en
la medida en que la ley federal lo permita), estas Condiciones no se aplicarán
en relación con la ley aplicable y los tribunales.
Dada la naturaleza de la investigación científica, el incumplimiento de las
Condiciones puede tardar algún tiempo en hacerse evidente. Para protegerlo a
usted, y proteger a Google y a los Recursos de AlphaFold 3, en la medida en que
lo permita la ley aplicable, usted acepta lo siguiente:
1. Cualquier demanda legal relacionada con las Condiciones o los Recursos de
AlphaFold 3 podrá iniciarse hasta la fecha posterior de lo siguiente:
1. la fecha límite que establece la ley aplicable para interponer una
demanda legal; o
2. dos años a partir de la fecha en que usted o Google (según corresponda)
tomaron conocimiento, o debieron haber tomado conocimiento
razonablemente, de los hechos que dieron lugar a dicha demanda, y
2. Ni usted ni Google alegarán prescripción, caducidad, demora, renuncia o
similares para intentar impedir una acción presentada dentro de ese período.
Todos los derechos que no se le otorguen específica y expresamente en las
Condiciones quedan reservados a Google. Ninguna demora, omisión o acto de Google
en el ejercicio de cualquier derecho o recurso se considerará una renuncia de
cualquier incumplimiento de las Condiciones y Google se reserva expresamente
todos los derechos y recursos disponibles según las Condiciones, la ley, por
acuerdo implícito o de cualquier otro modo, lo que incluye el recurso de medida
cautelar contra cualquier amenaza o hecho de infracción de las Condiciones sin
la necesidad de mostrar daños reales.
================================================
FILE: legal/WEIGHTS_TERMS_OF_USE-Francais-Canada.md
================================================
# CONDITIONS D'UTILISATION DES PARAMÈTRES DU MODÈLE ALPHAFOLD 3
Dernière modification: 2024-11-09
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
est un modèle d'IA développé par [Google DeepMind](https://deepmind.google/) et
[Isomorphic Labs](https://www.isomorphiclabs.com/). Il génère des prédictions de
structures 3D de molécules biologiques en fournissant la confiance du modèle
pour les prédictions de structures. Pour certaines utilisations non
commerciales, nous mettons gratuitement à disposition les paramètres du modèle
entraîné et les résultats générés à l'aide de ces paramètres, conformément aux
présentes conditions d'utilisation et à la
[Politique d'utilisation interdite des paramètres du modèle AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md).
**Éléments clés à connaître lors de l'utilisation des paramètres du modèle
AlphaFold 3 et des résultats**
1. Les paramètres du modèle AlphaFold 3 et les résultats sont **uniquement**
disponibles pour un usage non commercial par des organisations non
commerciales ou au nom de celles-ci (*c.-à-d.* universités, organismes sans
but lucratif, instituts de recherche et organismes éducatifs,
journalistiques et gouvernementaux). Si vous êtes un chercheur affilié à une
organisation non commerciale, **à la condition que vous ne soyez pas une
organisation commerciale et que vous n'agissiez pas au nom d'une
organisation commerciale**, cela signifie que vous pouvez les utiliser pour
votre recherche affiliée non commerciale.
2. Vous **ne devez pas** utiliser ni permettre à d'autres personnes d'utiliser:
1. les paramètres du modèle AlphaFold 3 ou les résultats dans le cadre de
**toute activité commerciale, y compris la recherche au nom
d'organisations commerciales**; ou
2. les résultats d'AlphaFold 3 pour **entraîner des modèles d'apprentissage
automatique** ou une technologie connexe de **prédiction de structures
biomoléculaires** semblable à AlphaFold 3.
3. Vous ***ne devez pas* publier ni partager les paramètres du modèle AlphaFold
3**, sauf si vous les partagez au sein de votre organisation conformément
aux présentes Conditions.
4. Vous ***pouvez* publier, partager ou adapter les *résultats* d'AlphaFold 3**
conformément aux présentes Conditions, y compris à l'exigence de fournir un
préavis clair de toute modification que vous apportez et à celle stipulant
que l'utilisation continue des résultats et des œuvres dérivées d'AlphaFold
3 est soumise aux
[Conditions d'utilisation des résultats d'AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
En utilisant, reproduisant, modifiant, exécutant, distribuant ou affichant toute
portion ou tout élément des Paramètres du modèle (comme défini ci-dessous) ou en
acceptant autrement les conditions de ce contrat, vous acceptez d'être lié par
(1) ces conditions d'utilisation et (2) la
[Politique d'utilisation interdite des paramètres du modèle AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md)
qui est incorporée aux présentes par référence (collectivement, les « Conditions
»), dans chaque cas (a) tel que modifié de temps à autre conformément aux
Conditions et (b) entre vous et (i) si vous êtes d'un pays de l'Espace
économique européen ou de la Suisse, Google Ireland Limited, ou (ii) autrement,
Google LLC.
Vous confirmez que vous êtes autorisé, soit explicitement, soit implicitement, à
accepter les Conditions et que vous les acceptez, en tant qu'employé ou
autrement, au nom de votre organisation.
Veuillez lire ces Conditions attentivement. Elles établissent ce à quoi vous
pouvez vous attendre de nous lorsque vous accédez aux Éléments d'AlphaFold 3 et
que vous les utilisez (comme défini ci-dessous), et ce à quoi Google s'attend de
vous. Par « **vous** », nous entendons l’individu ou l'organisation qui utilise
les Éléments d'AlphaFold 3. Par « **nous** », « **notre** » ou « **Google** »,
nous entendons les entités qui appartiennent au groupe d'entreprises Google,
c'est-à-dire Google LLC et ses filiales.
## 1. Définitions clés
Tels qu'utilisés dans ces Conditions, les termes suivants ont le sens indiqué: \
« **AlphaFold 3** » désigne: (a) le code source d'AlphaFold 3 rendu accessible
[ici](https://github.com/google-deepmind/alphafold3/) et sous les conditions de
la licence « Creative Commons Attribution-NonCommercial-ShareAlike 4.0
International (CC-BY-NC-SA 4.0) » ainsi que tout code source d'œuvres dérivées
et (b) les Paramètres du modèle.
« **Éléments d'AlphaFold 3** » signifie les Paramètres du modèle et les
Résultats.
« **Distribution** » ou « **Distribuer** » signifient toute transmission,
publication ou tout autre partage de Résultats effectués publiquement ou avec
une autre personne.
« **Paramètres du modèle** » désigne les poids et paramètres du modèle entraîné
mis à disposition par Google pour les organisations (à sa seule discrétion) pour
leur utilisation conformément à ces Conditions, ainsi que (a) les modifications
apportées à ces poids et paramètres (b) les travaux basés sur ces poids et
paramètres ou (c) tout autre code ou tout autre modèle d'apprentissage
automatique qui intègre, en totalité ou en partie, ces poids et paramètres.
« **Résultats** » désigne les prédictions de structures et toutes les
informations auxiliaires et connexes fournies par AlphaFold 3 ou utilisant les
Paramètres du modèle ainsi que toutes les représentations visuelles, les
prédictions informatiques, les descriptions, les modifications, les copies ou
les adaptations qui sont substantiellement dérivées des Résultats.
« **Y compris** » signifie « **y compris, sans s'y limiter** ».
## 2. Accéder aux Éléments d'AlphaFold 3 et les utiliser
Sous réserve de votre conformité aux Conditions, y compris la
[Politique d'utilisation interdite des paramètres du modèle AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md),
vous pouvez accéder aux Éléments d'AlphaFold 3, les utiliser et les modifier, et
Distribuer les Résultats comme indiqué dans ces Conditions. Nous vous accordons
une licence non exclusive, libre de redevances, révocable, non transférable et
non susceptible de faire l'objet d'une sous-licence (sauf si expressément permis
dans ces Conditions) sur tout droit de propriété intellectuelle que nous
détenons sur les Éléments d'AlphaFold 3, dans la mesure nécessaire à ces fins.
Afin de vérifier votre accès à AlphaFold 3 et votre utilisation de celui-ci,
nous pouvons de temps à autre vous demander des informations supplémentaires, y
compris la validation de votre nom, de votre organisation et d'autres
informations d'identification.
En accédant aux Éléments d'AlphaFold 3, en les utilisant ou en les modifiant, en
Distribuant des Résultats ou en demandant l'accès aux Paramètres du modèle, vous
déclarez et garantissez que (a) vous avez les pleins pouvoirs et l'autorité
nécessaire pour accepter ces Conditions (y compris avoir l'âge de consentement
requis) (b) Google n'a jamais précédemment résilié votre accès à AlphaFold 3 ni
votre droit de l'utiliser (y compris au moyen du
[Serveur AlphaFold](https://alphafoldserver.com/about)) en raison de votre
violation des conditions d'utilisation applicables (c) l'acceptation de ces
Conditions ou l'exécution de vos droits et obligations en vertu de ces
Conditions ne violera aucun contrat que vous avez avec un tiers ni aucun droit
d'un tiers (d) toute information que vous fournissez à Google en relation avec
AlphaFold 3, y compris (le cas échéant) pour demander l'accès aux Paramètres du
modèle, est correcte et à jour, et (e) vous n'êtes pas (i) résident d'un pays
soumis à un embargo (ii) habituellement résident d'un pays sous embargo
américain ni (iii) autrement soumis à une interdiction, par les contrôles
d'exportation et les programmes de sanctions applicables, d'accéder aux Éléments
d'AlphaFold 3, de les utiliser ou de les modifier.
Si vous choisissez de donner des commentaires à Google, comme des suggestions
pour améliorer AlphaFold 3, vous vous engagez à ce que ces informations soient
non confidentielles et non propriétaires, et Google pourra agir en fonction de
vos commentaires sans aucune obligation envers vous.
## 3. Restrictions d'utilisation
Vous ne devez pas utiliser les Éléments d'AlphaFold 3:
1. pour les utilisations restreintes énoncées dans la
[Politique d'utilisation interdite des paramètres du modèle AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Francais-Canada.md);
2. en violation des lois et règlements applicables.
Dans toute la mesure permise par la loi et sans limiter aucun de nos autres
droits, Google se réserve le droit de révoquer votre droit d'utilisation et
(dans la mesure du possible) de restreindre l'utilisation de tout Élément
d'AlphaFold 3 que Google estime raisonnablement être en violation de ces
Conditions.
## 4. Résultats générés
Bien que vous deviez respecter ces Conditions lors de l'utilisation des Éléments
d'AlphaFold 3, nous ne revendiquerons pas la propriété des Résultats d'origine
que vous générez en utilisant AlphaFold 3. Cependant, vous reconnaissez
qu'AlphaFold 3 peut générer les mêmes Résultats ou des Résultats semblables pour
plusieurs utilisateurs, y compris Google, et nous nous réservons tous nos droits
à cet égard.
## 5. Modifications aux Éléments d'AlphaFold 3 ou aux présentes Conditions
Google peut ajouter ou retirer des fonctions ou fonctionnalités des Éléments
d'AlphaFold 3 à tout moment et peut cesser d'offrir l'accès aux Éléments
d'AlphaFold 3.
Google peut mettre à jour ces Conditions et le mécanisme d'accès aux Paramètres
du modèle à tout moment. Nous allons publier toute modification apportée aux
Conditions
[dans le référentiel GitHub d'AlphaFold 3](https://github.com/google-deepmind/alphafold3).
Les modifications entreront généralement en vigueur 14 jours après leur
publication. Cependant, les modifications concernant la fonctionnalité ou celles
apportées pour des raisons juridiques entreront en vigueur immédiatement.
Vous devriez revoir les Conditions chaque fois que nous les mettons à jour ou
que vous utilisez les Éléments d'AlphaFold 3. Si vous n'acceptez pas les
modifications apportées aux Conditions, vous devez cesser d'utiliser les
Éléments d'AlphaFold 3 immédiatement.
## 6. Suspendre ou résilier votre droit d'utiliser les Éléments d'AlphaFold 3
Google peut à tout moment suspendre ou résilier votre droit d'utiliser les
Éléments d'AlphaFold 3 et, le cas échéant, d'y accéder pour différentes raisons,
notamment votre manquement à respecter entièrement les présentes Conditions. Si
Google suspend ou résilie votre droit d'accéder aux Éléments d'AlphaFold 3 ou de
les utiliser, vous devez immédiatement supprimer toutes les copies des Éléments
d'AlphaFold 3 en votre possession ou sous votre contrôle et cesser de les
utiliser et de les Distribuer, et il vous est interdit d'utiliser les Éléments
d'AlphaFold 3, y compris en soumettant une demande pour utiliser les Paramètres
du modèle. Google s'efforcera de vous donner un préavis raisonnable avant toute
suspension ou résiliation, mais aucun avis ni avertissement préalable ne sera
donné si la suspension ou la résiliation est due à votre manquement à respecter
entièrement les présentes Conditions ou à d'autres motifs sérieux.
Bien entendu, vous êtes toujours libre de cesser d'utiliser les Éléments
d'AlphaFold 3. Si vous cessez de les utiliser, nous aimerions savoir pourquoi (à
l'adresse alphafold@google.com) afin de pouvoir continuer à améliorer nos
technologies.
## 7. Confidentialité
Vous acceptez de ne pas divulguer ni rendre disponibles les renseignements
confidentiels de Google à quiconque sans notre consentement écrit préalable. «
**Renseignements confidentiels de Google** » désigne (a) les Paramètres du
modèle AlphaFold 3 et tous les logiciels, la technologie et la documentation en
lien avec AlphaFold 3, excepté le code source d'AlphaFold 3, et (b) toute autre
information mise à disposition par Google qui est marquée comme confidentielle
ou qui serait normalement considérée comme confidentielle dans les circonstances
dans lesquelles elle est présentée. Les Renseignements confidentiels de Google
n'incluent pas (a) les informations que vous connaissiez déjà avant d'accéder
aux Éléments d'AlphaFold 3 ou de les utiliser (y compris au moyen du
[Serveur AlphaFold](https://alphafoldserver.com/about)) (b) qui deviennent
publiques sans que vous en soyez responsable (par exemple, par votre violation
des Conditions) (c) qui ont été développées indépendamment par vous sans
référence aux Renseignements confidentiels de Google ou (d) qui vous ont été
légalement fournies par un tiers (sans violation des Conditions par vous-même ou
par le tiers).
## 8. Clauses de non-responsabilité
Rien dans les Conditions ne restreint les droits qui ne peuvent pas être
restreints en vertu de la loi applicable ni ne limite les responsabilités de
Google, sauf si cela est permis par la loi applicable.
AlphaFold 3 et les Résultats sont fournis « tels quels », sans garantie ni
condition de quelque nature que ce soit, explicite ou implicite, y compris toute
garantie ou condition de titre, d'absence de violation, de qualité marchande ou
d'adéquation avec un usage particulier. Vous êtes seul responsable de déterminer
la légitimité de l'utilisation d'AlphaFold 3 ou celle de l'utilisation et de la
distribution des Résultats, et vous assumez tous les risques liés à une telle
utilisation ou distribution ainsi qu'à l'exercice de vos droits et obligations
en vertu de ces Conditions. Vous et toute personne avec qui vous partagez des
Résultats êtes les seuls responsables de ces utilisations et de celles qui
s’ensuivent.
Les Résultats sont des prédictions avec des niveaux de confiance variables et
doivent être interprétés avec prudence. Faites preuve de discernement avant de
vous fier à AlphaFold 3, de le publier, de le télécharger ou de l’utiliser d'une
autre manière.
AlphaFold 3 et les Résultats sont uniquement destinés à la modélisation
théorique. Ils ne sont pas prévus, validés, ni approuvés pour une utilisation
clinique. Vous ne devez pas utiliser AlphaFold 3 ni les Résultats à des fins
cliniques ni les considérer comme des conseils médicaux ou professionnels. Tout
contenu concernant ces sujets est fourni à titre informatif uniquement et ne
remplace pas les conseils d'un professionnel qualifié.
## 9. Responsabilités
Dans la mesure permise par la loi applicable, vous indemniserez Google et ses
administrateurs, dirigeants, employés et sous-traitants pour toutes poursuites
judiciaires intentées par des tiers (y compris des actions menées par des
autorités gouvernementales) découlant de ou en rapport avec votre utilisation
illégale des Éléments d'AlphaFold 3 ou de votre violation des présentes
Conditions. Cette indemnité couvre toute responsabilité ou charge financière
résultant de réclamations, de pertes, de dommages, de jugements, d'amendes, de
débours et de frais juridiques, sauf dans la mesure où une responsabilité ou une
charge financière est causée par une violation, une négligence ou une inconduite
intentionnelle de Google. Si vous êtes légalement exempté de certaines
responsabilités, y compris l'indemnisation, alors ces responsabilités ne
s'appliquent pas à vous en vertu des présentes Conditions.
Google n'est pas responsable, en aucun cas, des dommages-intérêts indirects,
spéciaux, accessoires, exemplaires, consécutifs ou punitifs ni des pertes de
profits de quelque nature que ce soit en rapport avec les Conditions ou les
Éléments d'AlphaFold 3, même si Google a été informée de la possibilité de tels
dommages. L'obligation globale et totale de Google pour toutes les réclamations
découlant des Conditions ou des Éléments d'AlphaFold 3 ou en lien avec ceux-ci,
y compris pour sa propre négligence, est limitée à 500,00 USD.
## 10. Divers
Selon la loi, vous avez certains droits qui ne peuvent pas être limités par un
contrat tel que les Conditions. Les présentes Conditions ne visent aucunement à
restreindre ces droits.
Les Conditions constituent l'intégralité de notre contrat concernant votre
utilisation des Éléments d'AlphaFold 3 et remplacent tous les contrats
antérieurs ou contemporains sur ce sujet.
Si une disposition particulière des présentes Conditions s'avère inapplicable,
le reste des Conditions restera en vigueur.
## 11. Contestations
Les lois de la Californie régiront toutes les contestations découlant de ou en
rapport avec ces Conditions ou en lien avec les Éléments d'AlphaFold 3. Ces
contestations seront résolues exclusivement par les tribunaux fédéraux ou
étatiques du comté de Santa Clara, en Californie, aux États-Unis, et vous et
Google consentez à la compétence territoriale de ces tribunaux. Dans la mesure
où la loi locale applicable s'oppose à ce que certaines contestations soient
résolues devant un tribunal de la Californie, vous et Google pouvez les
soumettre à vos tribunaux locaux. Si la loi locale applicable s'oppose à ce que
votre tribunal local applique la loi californienne pour résoudre ces
contestations, elles seront régies par les lois locales applicables de votre
pays, de votre État ou de votre autre lieu de résidence. Si vous utilisez les
Éléments d'AlphaFold 3 au nom d'une organisation gouvernementale autre que les
organisations gouvernementales fédérales américaines (où les dispositions
précédentes s'appliquent dans la mesure permise par la loi fédérale), ces
Conditions seront silencieuses en ce qui concerne la loi applicable et les
tribunaux.
Considérant la nature de la recherche scientifique, il peut s'écouler un certain
temps avant que toute violation des présentes Conditions devienne évidente. \
Dans la mesure permise par la loi applicable, afin de vous protéger, ainsi que
de protéger Google et les Éléments d'AlphaFold 3, vous acceptez que:
1. toute réclamation légale liée aux présentes Conditions ou aux Éléments
d'AlphaFold 3 peut être intentée jusqu'à la date la plus tardive entre:
1. la date limite prévue par la loi applicable pour intenter la réclamation
légale; ou
2. deux années à partir de la date à laquelle vous ou Google (selon le cas)
avez pris connaissance ou auriez dû raisonnablement prendre connaissance
des faits à l'origine de cette réclamation; et
2. vous n'invoquerez pas la limitation, la prescription, le retard, la
renonciation ou des arguments semblables pour tenter de faire obstacle à une
action intentée dans ce délai et Google non plus.
Tous les droits qui ne vous sont pas précisément et expressément accordés par
les présentes Conditions sont réservés à Google. Aucun retard, acte ni aucune
omission de la part de Google dans l'exercice d'un droit ou d'un recours ne sera
considéré comme une renonciation à une violation des Conditions, et Google se
réserve expressément tous les droits et recours disponibles en vertu des
Conditions ou de la loi, en équité ou autrement, y compris le recours à une
injonction contre toute menace de violation ou violation réelle des Conditions
sans qu'il soit nécessaire de prouver des dommages réels.
================================================
FILE: legal/WEIGHTS_TERMS_OF_USE-Portugues-Brazil.md
================================================
# TERMOS DE USO DOS PARÂMETROS DO MODELO ALPHAFOLD 3
Última modificação: 2024-11-09
O
[AlphaFold 3](https://blog.google/technology/ai/google-deepmind-isomorphic-alphafold-3-ai-model/)
é um modelo de IA desenvolvido pelo [Google DeepMind](https://deepmind.google/)
e pela [Isomorphic Labs](https://www.isomorphiclabs.com/). Ele gera previsões
sobre a estrutura 3D de moléculas biológicas, apresentando a confiança do
modelo. Disponibilizamos os parâmetros do modelo treinado e as saídas geradas
por ele sem custo financeiro para determinados usos não comerciais, de acordo
com estes Termos de Uso e com a
[Política de uso proibido dos parâmetros do modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md).
**Observações importantes sobre o uso dos parâmetros e das saídas do modelo
AlphaFold 3**
1. Os parâmetros e as saídas do modelo AlphaFold 3 só estão disponíveis para
uso não comercial por organizações não comerciais ou em nome delas (*por
exemplo*, universidades, organizações sem fins lucrativos, institutos de
pesquisa e órgãos governamentais, educacionais e de notícias). Se você for
um pesquisador afiliado a uma organização não comercial, você tem permissão
para usar esses recursos em sua pesquisa afiliada a organizações sem fins
lucrativos, desde que você não seja uma organização comercial nem esteja
agindo em nome de uma.
2. Não use nem permita que outras pessoas usem:
1. os parâmetros ou as saídas do modelo AlphaFold 3 em relação a qualquer
atividade comercial, incluindo a pesquisa em nome de organizações
comerciais; ou
2. a saída do AlphaFold 3 para treinar modelos de aprendizado de máquina ou
tecnologia relacionada na previsão de estrutura biomolecular semelhante
ao AlphaFold 3.
3. Você *não tem permissão* para publicar ou compartilhar os parâmetros do
modelo AlphaFold 3, exceto compartilhar dentro da sua organização de acordo
com estes Termos.
4. Você *tem permissão* para publicar, compartilhar e adaptar as *saídas* do
AlphaFold 3 de acordo com estes Termos, incluindo os requisitos de oferecer
aviso claro de quaisquer modificações e de que o uso contínuo das saídas e
derivados do modelo está sujeito aos
[Termos de Uso das saídas do AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md).
Ao usar, reproduzir, modificar, realizar, distribuir ou exibir qualquer parte ou
elemento dos Parâmetros do modelo (conforme definido abaixo) ou aceitar de outra
forma os termos deste contrato, você concorda em se vincular (1) a estes Termos
de Uso e (2) à
[Política de uso proibido dos parâmetros do modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md),
incorporada aqui como referência (em conjunto, os "**Termos**"), em cada caso
(a) conforme modificado periodicamente de acordo com os Termos e (b) entre você
e (i), se você for de um país no Espaço Econômico Europeu ou da Suíça, a Google
Ireland Limited ou (ii), caso contrário, a Google LLC.
Você confirma que tem autorização explícita ou implícita para celebrar, e está
celebrando, estes Termos como funcionário ou de outro modo em nome da sua
organização.
Leia estes Termos com atenção. Eles definem o que você pode esperar de nós ao
acessar e usar os Recursos do AlphaFold 3 (conforme definido abaixo) e o que o
Google espera de você. "**Você**" significa o indivíduo ou a organização que
está usando os Recursos do AlphaFold 3. "**Nós**", "**nos**" ou "**Google**"
significam as entidades que pertencem ao grupo de empresas do Google, ou seja, a
Google LLC e suas afiliadas.
## 1. Principais definições
Conforme usado nestes Termos:
"**AlphaFold 3**" significa: (a) o código-fonte do AlphaFold 3 disponível
[neste link](https://github.com/google-deepmind/alphafold3/) e licenciado nos
termos da licença Creative Commons Attribution-NonCommercial-Sharealike 4.0
International (CC-BY-NC-SA 4.0), bem como qualquer código-fonte derivado, e (b)
Parâmetros do modelo.
"**Recursos do AlphaFold 3**" significam as Saídas e os Parâmetros do modelo.
"**Distribuição**" ou "**Distribuir**" significam qualquer transmissão,
publicação ou outra forma de compartilhamento das Saídas publicamente ou com
qualquer outra pessoa.
"**Parâmetros do modelo**" significam os pesos e os parâmetros do modelo
treinado disponibilizados pelo Google às organizações (a critério próprio) para
uso de acordo com estes Termos, com (a) modificações nesses pesos e parâmetros,
com (b) trabalhos baseados nesses pesos e parâmetros ou (c) com outros códigos
ou modelos de aprendizado de máquina que incorporam esses pesos e parâmetros na
íntegra ou em partes.
"**Saída**" significa as previsões de estrutura e todas as informações
adicionais e relacionadas que são fornecidas pelo AlphaFold 3 ou usam os
Parâmetros do modelo, com quaisquer representações visuais, previsões
computacionais, descrições, modificações, cópias ou adaptações derivadas
consideravelmente da Saída.
"**Incluindo**" significa "**incluindo, sem limitação**".
## 2. Acesso e uso dos Recursos do AlphaFold 3
Sujeito ao seu cumprimento dos Termos, incluindo a
[Política de uso proibido dos parâmetros do modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md),
você pode acessar, usar e modificar os Recursos do AlphaFold 3 e Distribuir as
Saídas conforme definido nestes Termos. Concedemos a você uma licença não
exclusiva, livre de royalties, revogável, intransferível e não sublicenciável
(exceto conforme expressamente permitido nestes Termos) para todos os nossos
direitos de propriedade intelectual dos Recursos do AlphaFold 3 na medida
necessária para esses fins. Para verificar seu acesso e uso do AlphaFold 3,
podemos solicitar informações adicionais periodicamente, incluindo a verificação
do seu nome, organização e outras informações de identificação.
Ao acessar, usar ou modificar os Recursos do AlphaFold 3, Distribuir a Saída ou
solicitar acesso aos Parâmetros do modelo, você declara e garante que (a) tem
total capacidade legal para celebrar estes Termos (incluindo a idade mínima de
consentimento), (b) o Google nunca rescindiu seu acesso e direito de usar o
AlphaFold 3 (incluindo conforme disponibilizado pelo
[Servidor do AlphaFold](https://alphafoldserver.com/about)) devido à sua
violação dos Termos de Uso relevantes, (c) celebrar ou exercer seus direitos e
obrigações de acordo com estes Termos não violará nenhum contrato firmado entre
você e um terceiro ou quaisquer direitos de terceiros, (d) quaisquer informações
que você fornecer ao Google em relação ao AlphaFold 3, incluindo solicitar
acesso aos Parâmetros do modelo (quando aplicável), são verdadeiras e atuais, e
(e) você (i) não é residente de um país embargado, (ii) não é residente
ordinário de um país embargado pelos EUA ou (iii) não tem nenhuma outra
proibição de acessar, usar ou modificar os Recursos do AlphaFold 3 pelos
programas de sanções e controles de exportação aplicáveis.
Ao optar por dar feedback ao Google, como sugestões para melhorar o AlphaFold 3,
você concorda que essas informações não são confidenciais nem reservadas, e
poderemos agir de acordo com seu feedback sem qualquer compromisso com você.
## 3. Restrições de uso
Você não tem permissão para usar qualquer Recurso do AlphaFold 3:
1. para os usos restritos estabelecidos na
[Política de uso proibido dos Parâmetros do modelo AlphaFold 3](https://github.com/google-deepmind/alphafold3/blob/main/legal/WEIGHTS_PROHIBITED_USE_POLICY-Portugues-Brazil.md);
ou
2. em violação das leis e regulamentações aplicáveis.
Até o limite permitido pela legislação e sem limitação de quaisquer outros
direitos, o Google se reserva o direito de revogar e (até onde possível) restringir
seu uso de qualquer Recurso do AlphaFold 3 que acreditamos razoavelmente violar
estes Termos.
## 4. Saída gerada
Embora você precise cumprir com estes Termos ao usar os Recursos do AlphaFold 3,
não reivindicaremos propriedade da Saída original que você gerar usando o
AlphaFold 3. No entanto, você reconhece que o AlphaFold 3 pode gerar uma Saída
igual ou semelhante para vários usuários, incluindo o Google, e reservamos todos
os direitos nesse sentido.
## 5. Mudanças nos Recursos do AlphaFold 3 ou nestes Termos
O Google pode adicionar ou remover funcionalidades ou funções dos Recursos do
AlphaFold 3 a qualquer momento e parar completamente de oferecer acesso a esses Recursos.
O Google pode atualizar estes Termos e o mecanismo de acesso aos Parâmetros do
modelo a qualquer momento. Quaisquer modificações nestes Termos serão postadas
[no repositório GitHub do AlphaFold 3](https://github.com/google-deepmind/alphafold3).
Geralmente, as alterações entrarão em vigor 14 dias após a postagem. No entanto,
as alterações relacionadas à funcionalidade ou feitas por motivos jurídicos
serão aplicadas imediatamente.
Consulte os Termos sempre que forem atualizados ou você usar os Recursos do
AlphaFold 3. Se você não concordar com quaisquer modificações nos Termos, pare
de usar os Recursos do AlphaFold 3 imediatamente.
## 6. Suspensão ou encerramento do seu direito de usar os Recursos do AlphaFold 3
O Google pode suspender ou encerrar a qualquer momento seu direito de usar e,
conforme aplicável, acessar os Recursos do AlphaFold 3 devido ao não cumprimento
dos Termos, entre outros motivos. Se o Google suspender ou encerrar seu direito
de acessar ou usar os Recursos do AlphaFold 3, você precisará excluir e parar de
usar e Distribuir imediatamente todas as cópias dos Recursos do AlphaFold 3 em
sua posse ou controle. Você não poderá usar os Recursos do AlphaFold 3,
incluindo o envio de uma solicitação para usar os Parâmetros do modelo. O Google
fará o possível para fornecer aviso prévio razoável antes de qualquer suspensão
ou encerramento, mas não daremos nenhum aviso ou alerta com antecedência se a
suspensão ou o encerramento for por não obedecer totalmente aos Termos ou por
outras justificativas graves.
Você pode parar de usar os Recursos do AlphaFold 3 a qualquer momento. Nesse
caso, queremos saber o motivo (via alphafold@google.com) para continuarmos
melhorando nossas tecnologias.
## 7. Confidencialidade
Você concorda em não divulgar nem disponibilizar Informações confidenciais do
Google a qualquer pessoa sem nosso consentimento prévio por escrito.
"**Informações confidenciais do Google**" significam (a) os Parâmetros do modelo
e todo software, tecnologia e documentação associados ao AlphaFold 3, exceto
para o código-fonte do AlphaFold 3, e (b) quaisquer outras informações
disponibilizadas pelo Google que foram marcadas como confidenciais ou que seriam
normalmente consideradas assim nas circunstâncias em que são apresentadas. As
Informações confidenciais do Google não incluem (a) informações que você já
sabia antes do seu acesso ou uso dos Recursos do AlphaFold 3 (incluindo pelo
[Servidor do AlphaFold](https://alphafoldserver.com/about)), (b) que se tornaram
públicas sem sua culpa (por exemplo, sua violação dos Termos), (c) que foram
desenvolvidas de maneira independente por você sem referência às Informações
confidenciais do Google ou (d) que foram fornecidas legalmente a você por um
terceiro (sem que você nem o terceiro violassem os Termos).
## 8. Exoneração de responsabilidade
Os Termos não restringem quaisquer direitos que não possam ser restritos de
acordo com a legislação aplicável nem limitam as responsabilidades do Google,
exceto conforme permitido pela legislação aplicável.
**O AlphaFold 3 e as Saídas são fornecidos no estado em que se encontram, sem
garantias ou condições de qualquer tipo, sejam explícitas ou implícitas,
incluindo quaisquer garantias ou condições de título, comercialidade, adequação
para uma finalidade específica e não violação. Você é a única pessoa responsável
por determinar se o uso do AlphaFold 3, ou uso/distribuição das Saídas, é
adequado e assume qualquer e todo risco associado a esse uso ou distribuição e
ao exercício dos seus direitos e obrigações de acordo com estes Termos. Você e
qualquer pessoa com quem compartilhar as Saídas são exclusivamente responsáveis
por elas e pelos usos subsequentes delas.**
**As Saídas são previsões com níveis variados de confiança e devem ser
interpretadas com cuidado. Tenha cautela antes de confiar, publicar, baixar ou
usar de outra forma o AlphaFold 3.**
**O AlphaFold 3 e as Saídas servem apenas para modelagem teórica. Eles não são
destinados, validados nem aprovados para uso clínico. Não os use para
finalidades clínicas nem conte com eles para aconselhamento médico ou de outra
natureza. Todo conteúdo sobre esses assuntos é fornecido somente para fins
informativos e não substitui a orientação de um profissional qualificado.**
## 9. Responsabilidades
Na medida permitida pela lei, você indenizará o Google e os diretores,
executivos, funcionários e prestadores de serviço dele por qualquer processo
judicial de terceiros (incluindo ações de órgãos do governo) decorrente ou
relacionado ao uso ilegal dos Recursos do AlphaFold 3 ou a violações dos Termos.
Essa indenização cobre qualquer responsabilidade ou despesa decorrente de ações
judiciais, perdas, danos, julgamentos, multas, custos de litígios e honorários
jurídicos, exceto se a responsabilidade ou despesa for causada por violação,
negligência ou má conduta intencional do Google. Se você for passível de isenção
legal de certas responsabilidades, incluindo indenização, essas
responsabilidades não se aplicarão a você de acordo com os Termos.
Em hipótese alguma o Google será responsável por quaisquer danos indiretos,
especiais, incidentais, exemplares, emergentes ou punitivos ou por perdas de
lucros de qualquer tipo em relação aos Termos ou aos Recursos do AlphaFold 3,
mesmo se o Google tiver sido advertido da possibilidade de tais danos. A
responsabilidade agregada total do Google para todas as ações judiciais
decorrentes de ou relacionadas aos Termos ou aos Recursos do AlphaFold 3,
incluindo pela nossa negligência, é limitada a US$ 500.
## 10. Disposições gerais
Por lei, você tem certos direitos que não podem ser limitados por um contrato
como os Termos. Os Termos não têm, de forma alguma, o objetivo de restringir
esses direitos.
Os Termos constituem a integralidade do nosso contrato relacionado ao seu uso
dos Recursos do AlphaFold 3 e substituem quaisquer contratos anteriores ou
contemporâneos sobre esse assunto.
Se uma disposição específica dos Termos não for aplicável, o restante dos Termos
permanecerá vigente.
## 11. Disputas
As leis da Califórnia vão reger todas as disputas que surgirem com relação aos
Termos ou em relação aos Recursos do AlphaFold 3. Essas disputas serão
resolvidas exclusivamente nos tribunais federais ou estaduais do condado de
Santa Clara, Califórnia, EUA, e você e o Google concordam com a jurisdição
pessoal nesses tribunais. Se a legislação local aplicável impedir que alguma
disputa seja tratada em um tribunal na Califórnia, você e o Google podem entrar
com a petição no seu foro local. Da mesma forma, se a legislação local aplicável
impedir que o tribunal local aplique a lei da Califórnia para resolver essas
disputas, elas serão regidas pelas leis do seu país, estado ou outro local de
residência. Se você usar os Recursos do AlphaFold 3 em nome de uma organização
governamental que não seja do governo federal dos EUA (onde as disposições acima
se aplicam até onde permitido pela legislação federal), estes Termos serão
omissos quanto à legislação aplicável e aos tribunais.
Considerando a natureza das pesquisas científicas, pode levar algum tempo para
qualquer violação dos Termos se tornar aparente. Para proteger você, o Google e
os Recursos do AlphaFold 3, até onde permitido pela legislação aplicável, você
concorda que:
1. qualquer ação judicial relacionada aos Termos ou Recursos do AlphaFold 3
pode ser iniciada até o que ocorrer por último:
1. a data-limite de acordo com a legislação aplicável para iniciar a ação
judicial; ou
2. dois anos após a data em que você ou o Google (conforme aplicável) tomou
conhecimento, ou deve ter tomado conhecimento de forma razoável, dos
fatos que deram origem a essa ação; e
2. você não alegará limitação, prazo de prescrição, atraso, renúncia ou
semelhantes para tentar impedir uma ação registrada nesse período, e o
Google também não.
Todos os direitos que não forem concedidos a você de maneira específica e
explícita pelos Termos são reservados ao Google. Nenhum atraso, ação ou omissão
do Google em exercer qualquer direito ou recurso será considerado uma renúncia
a qualquer violação dos Termos, e o Google se reserva expressamente todos e
quaisquer direitos e recursos disponíveis de acordo com os Termos ou com base
na lei, na equidade ou de outra forma, incluindo a tutela de urgência (medida
liminar) contra qualquer violação real dos Termos ou ameaça disso sem precisar
comprovar danos reais.
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = [
"scikit_build_core",
"pybind11",
"cmake>=3.28",
"ninja",
"numpy",
]
build-backend = "scikit_build_core.build"
[project]
name = "alphafold3"
version = "3.0.1"
requires-python = ">=3.12"
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
"absl-py>=2.3.1",
"dm-haiku==0.0.16",
"jax==0.9.1",
"jax[cuda12]==0.9.1",
"numpy",
"rdkit==2025.9.4",
"tokamax==0.0.11",
"tqdm",
"zstandard",
]
[dependency-groups]
dev = [
"pytest>=6.0",
]
[tool.uv]
package = true
environments = [
"sys_platform == 'linux' and platform_machine == 'x86_64'",
"sys_platform == 'linux' and platform_machine == 'aarch64'",
]
[tool.scikit-build]
wheel.exclude = [
"**.pyx",
"**/CMakeLists.txt",
"**.cc",
"**.h"
]
sdist.include = [
"LICENSE",
"OUTPUT_TERMS_OF_USE.md",
"WEIGHTS_PROHIBITED_USE_POLICY.md",
"WEIGHTS_TERMS_OF_USE.md",
]
[tool.cibuildwheel]
build = "cp3*-manylinux_x86_64"
manylinux-x86_64-image = "manylinux_2_28"
[project.scripts]
build_data = "alphafold3.build_data:build_data"
================================================
FILE: run_alphafold.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""AlphaFold 3 structure prediction script.
AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
To request access to the AlphaFold 3 model parameters, follow the process set
out at https://github.com/google-deepmind/alphafold3. You may only use these
if received directly from Google. Use is subject to terms of use available at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""
from collections.abc import Callable, Sequence
import csv
import dataclasses
import datetime
import functools
import os
import pathlib
import shutil
import string
import textwrap
import time
import typing
from typing import overload
from absl import app
from absl import flags
from alphafold3.common import folding_input
from alphafold3.common import resources
from alphafold3.constants import chemical_components
import alphafold3.cpp
from alphafold3.data import featurisation
from alphafold3.data import pipeline
from alphafold3.data.tools import shards
from alphafold3.model import features
from alphafold3.model import model
from alphafold3.model import params
from alphafold3.model import post_processing
from alphafold3.model.components import utils
import haiku as hk
import jax
from jax import numpy as jnp
import numpy as np
import tokamax
_HOME_DIR = pathlib.Path(os.environ.get('HOME'))
_DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
# Input and output paths. Exactly one of --json_path / --input_dir is expected
# (enforced in main()); --output_dir is marked required at program start.
_JSON_PATH = flags.DEFINE_string(
    name='json_path',
    default=None,
    help='Path to the input JSON file.',
)
_INPUT_DIR = flags.DEFINE_string(
    name='input_dir',
    default=None,
    help='Path to the directory containing input JSON files.',
)
_OUTPUT_DIR = flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Path to a directory where the results will be saved.',
)
MODEL_DIR = flags.DEFINE_string(
    name='model_dir',
    default=_DEFAULT_MODEL_DIR.as_posix(),
    help='Path to the model to use for inference.',
)
# Control which stages to run. main() rejects running with both disabled.
_RUN_DATA_PIPELINE = flags.DEFINE_bool(
    name='run_data_pipeline',
    default=True,
    help='Whether to run the data pipeline on the fold inputs.',
)
_RUN_INFERENCE = flags.DEFINE_bool(
    name='run_inference',
    default=True,
    help='Whether to run inference on the fold inputs.',
)
# Binary paths. Each default is looked up on $PATH at import time and is None
# when the tool cannot be found there.
_JACKHMMER_BINARY_PATH = flags.DEFINE_string(
    name='jackhmmer_binary_path',
    default=shutil.which('jackhmmer'),
    help='Path to the Jackhmmer binary.',
)
_NHMMER_BINARY_PATH = flags.DEFINE_string(
    name='nhmmer_binary_path',
    default=shutil.which('nhmmer'),
    help='Path to the Nhmmer binary.',
)
_HMMALIGN_BINARY_PATH = flags.DEFINE_string(
    name='hmmalign_binary_path',
    default=shutil.which('hmmalign'),
    help='Path to the Hmmalign binary.',
)
_HMMSEARCH_BINARY_PATH = flags.DEFINE_string(
    name='hmmsearch_binary_path',
    default=shutil.which('hmmsearch'),
    help='Path to the Hmmsearch binary.',
)
_HMMBUILD_BINARY_PATH = flags.DEFINE_string(
    name='hmmbuild_binary_path',
    default=shutil.which('hmmbuild'),
    help='Path to the Hmmbuild binary.',
)
# Database paths. Defaults contain a `${DB_DIR}` placeholder that main()
# expands via replace_db_dir(), trying each --db_dir in order.
DB_DIR = flags.DEFINE_multi_string(
    name='db_dir',
    default=(_DEFAULT_DB_DIR.as_posix(),),
    help=(
        'Path to the directory containing the databases. Can be specified multiple'
        ' times to search multiple directories in order.'
    ),
)

# Help texts shared by the Z-value flags: protein databases are sized in
# sequences, RNA databases in megabases.
_Z_VALUE_NUM_SEQUENCES_HELP = (
    'The Z-value representing the database size in number of sequences for'
    ' E-value calculation. Must be set for sharded databases.'
)
_Z_VALUE_MEGABASES_HELP = (
    'The Z-value representing the database size in megabases for E-value'
    ' calculation. Must be set for sharded databases.'
)

_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
    name='small_bfd_database_path',
    default='${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
    help='Small BFD database path, used for protein MSA search.',
)
_SMALL_BFD_Z_VALUE = flags.DEFINE_integer(
    name='small_bfd_z_value',
    default=None,
    help=_Z_VALUE_NUM_SEQUENCES_HELP,
    lower_bound=0,
)
_MGNIFY_DATABASE_PATH = flags.DEFINE_string(
    name='mgnify_database_path',
    default='${DB_DIR}/mgy_clusters_2022_05.fa',
    help='Mgnify database path, used for protein MSA search.',
)
_MGNIFY_Z_VALUE = flags.DEFINE_integer(
    name='mgnify_z_value',
    default=None,
    help=_Z_VALUE_NUM_SEQUENCES_HELP,
    lower_bound=0,
)
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string(
    name='uniprot_cluster_annot_database_path',
    default='${DB_DIR}/uniprot_all_2021_04.fa',
    help='UniProt database path, used for protein paired MSA search.',
)
_UNIPROT_CLUSTER_ANNOT_Z_VALUE = flags.DEFINE_integer(
    name='uniprot_cluster_annot_z_value',
    default=None,
    help=_Z_VALUE_NUM_SEQUENCES_HELP,
    lower_bound=0,
)
_UNIREF90_DATABASE_PATH = flags.DEFINE_string(
    name='uniref90_database_path',
    default='${DB_DIR}/uniref90_2022_05.fa',
    help=(
        'UniRef90 database path, used for MSA search. The MSA obtained by '
        'searching it is used to construct the profile for template search.'
    ),
)
_UNIREF90_Z_VALUE = flags.DEFINE_integer(
    name='uniref90_z_value',
    default=None,
    help=_Z_VALUE_NUM_SEQUENCES_HELP,
    lower_bound=0,
)
_NTRNA_DATABASE_PATH = flags.DEFINE_string(
    name='ntrna_database_path',
    default='${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta',
    help='NT-RNA database path, used for RNA MSA search.',
)
_NTRNA_Z_VALUE = flags.DEFINE_float(
    name='ntrna_z_value',
    default=None,
    help=_Z_VALUE_MEGABASES_HELP,
    lower_bound=0.0,
)
_RFAM_DATABASE_PATH = flags.DEFINE_string(
    name='rfam_database_path',
    default='${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta',
    help='Rfam database path, used for RNA MSA search.',
)
_RFAM_Z_VALUE = flags.DEFINE_float(
    name='rfam_z_value',
    default=None,
    help=_Z_VALUE_MEGABASES_HELP,
    lower_bound=0.0,
)
_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string(
    name='rna_central_database_path',
    default='${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta',
    help='RNAcentral database path, used for RNA MSA search.',
)
_RNA_CENTRAL_Z_VALUE = flags.DEFINE_float(
    name='rna_central_z_value',
    default=None,
    help=_Z_VALUE_MEGABASES_HELP,
    lower_bound=0.0,
)
_PDB_DATABASE_PATH = flags.DEFINE_string(
    name='pdb_database_path',
    default='${DB_DIR}/mmcif_files',
    help='PDB database directory with mmCIF files path, used for template search.',
)
_SEQRES_DATABASE_PATH = flags.DEFINE_string(
    name='seqres_database_path',
    default='${DB_DIR}/pdb_seqres_2022_09_28.fasta',
    help='PDB sequence database path, used for template search.',
)
# Number of CPUs to use for MSA tools.
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+, so
# the CPUs available to this process are counted via its affinity mask. The
# default is capped at 8 because going higher gives very little speedup.
_DEFAULT_MSA_TOOL_N_CPU = min(len(os.sched_getaffinity(0)), 8)

_JACKHMMER_N_CPU = flags.DEFINE_integer(
    name='jackhmmer_n_cpu',
    default=_DEFAULT_MSA_TOOL_N_CPU,
    help=(
        'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
        ' above 8 CPUs provides very little additional speedup.'
    ),
    lower_bound=0,
)
_JACKHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
    name='jackhmmer_max_parallel_shards',
    default=None,
    help=(
        'Maximum number of shards to search against in parallel. If unset, one'
        ' Jackhmmer instance will be run per shard. Only applicable if the'
        ' database is sharded.'
    ),
    lower_bound=1,
)
_NHMMER_N_CPU = flags.DEFINE_integer(
    name='nhmmer_n_cpu',
    default=_DEFAULT_MSA_TOOL_N_CPU,
    help=(
        'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
        ' above 8 CPUs provides very little additional speedup.'
    ),
    lower_bound=0,
)
_NHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
    name='nhmmer_max_parallel_shards',
    default=None,
    help=(
        'Maximum number of shards to search against in parallel. If unset, one'
        ' Nhmmer instance will be run per shard. Only applicable if the'
        ' database is sharded.'
    ),
    lower_bound=1,
)
# Data pipeline configuration.
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
    name='resolve_msa_overlaps',
    default=True,
    help=(
        'Whether to deduplicate unpaired MSA against paired MSA. The default'
        ' behaviour matches the method described in the AlphaFold 3 paper. Set this'
        ' to false if providing custom paired MSA using the unpaired MSA field to'
        ' keep it exactly as is as deduplication against the paired MSA could break'
        ' the manually crafted pairing between MSA sequences.'
    ),
)
# The default mirrors the template cut-off date used in the AlphaFold 3 paper.
# main() parses this with datetime.date.fromisoformat().
_MAX_TEMPLATE_DATE = flags.DEFINE_string(
    name='max_template_date',
    default='2021-09-30',
    help=(
        'Maximum template release date to consider. Format: YYYY-MM-DD. All'
        ' templates released after this date will be ignored. Controls also whether'
        ' to allow use of model coordinates for a chemical component from the CCD'
        ' if RDKit conformer generation fails and the component does not have ideal'
        ' coordinates set. Only for components that have been released before this'
        ' date the model coordinates can be used as a fallback.'
    ),
)
# None means "use RDKit's own default iteration count".
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
    name='conformer_max_iterations',
    default=None,
    help=(
        'Optional override for maximum number of iterations to run for RDKit '
        'conformer search.'
    ),
    lower_bound=0,
)
# JAX inference performance tuning.
_JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
    'jax_compilation_cache_dir',
    None,
    'Path to a directory for the JAX compilation cache.',
)
_GPU_DEVICE = flags.DEFINE_integer(
    'gpu_device',
    0,
    'Optional override for the GPU device to use for inference, uses zero-based'
    ' indexing. Defaults to the 0th GPU on the system. Useful on multi-GPU'
    ' systems to pin each run to a specific GPU. Note that if GPUs are already'
    ' pre-filtered by the environment (e.g. by using CUDA_VISIBLE_DEVICES),'
    ' this flag refers to the GPU index after the filtering has been done.',
    # The value is used directly as a list index into the local GPU devices;
    # without a lower bound a negative value would silently select a GPU
    # counted from the end, contradicting the documented zero-based indexing.
    lower_bound=0,
)
# Token-count buckets; each input is padded up to the smallest bucket that
# fits so that compiled programs can be reused across inputs.
_BUCKETS = flags.DEFINE_list(
    'buckets',
    # pyformat: disable
    ['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
     '3584', '4096', '4608', '5120'],
    # pyformat: enable
    'Strictly increasing order of token sizes for which to cache compilations.'
    ' For any input with more tokens than the largest bucket size, a new bucket'
    ' is created for exactly that number of tokens.',
)
_FLASH_ATTENTION_IMPLEMENTATION = flags.DEFINE_enum(
    'flash_attention_implementation',
    default='triton',
    enum_values=['triton', 'cudnn', 'xla'],
    help=(
        "Flash attention implementation to use. 'triton' and 'cudnn' uses a"
        ' Triton and cuDNN flash attention implementation, respectively. The'
        ' Triton kernel is fastest and has been tested more thoroughly. The'
        " Triton and cuDNN kernels require Ampere GPUs or later. 'xla' uses an"
        ' XLA attention implementation (no flash attention) and is portable'
        ' across GPU devices.'
    ),
)
_NUM_RECYCLES = flags.DEFINE_integer(
    'num_recycles',
    10,
    'Number of recycles to use during inference.',
    lower_bound=1,
)
_NUM_DIFFUSION_SAMPLES = flags.DEFINE_integer(
    'num_diffusion_samples',
    5,
    'Number of diffusion samples to generate.',
    lower_bound=1,
)
_NUM_SEEDS = flags.DEFINE_integer(
    'num_seeds',
    None,
    'Number of seeds to use for inference. If set, only a single seed must be'
    ' provided in the input JSON. AlphaFold 3 will then generate random seeds'
    ' in sequence, starting from the single seed specified in the input JSON.'
    ' The full input JSON produced by AlphaFold 3 will include the generated'
    ' random seeds. If not set, AlphaFold 3 will use the seeds as provided in'
    ' the input JSON.',
    lower_bound=1,
)
# Output controls.
_SAVE_EMBEDDINGS = flags.DEFINE_bool(
    name='save_embeddings',
    default=False,
    help=(
        'Whether to save the final trunk single and pair embeddings in the output.'
        ' Note that the embeddings are large float16 arrays: num_tokens * 384'
        ' + num_tokens * num_tokens * 128.'
    ),
)
_SAVE_DISTOGRAM = flags.DEFINE_bool(
    name='save_distogram',
    default=False,
    help=(
        'Whether to save the final distogram in the output. Note that the distogram'
        ' is a large float16 array: num_tokens * num_tokens * 64.'
    ),
)
_FORCE_OUTPUT_DIR = flags.DEFINE_bool(
    name='force_output_dir',
    default=False,
    help=(
        'Whether to force the output directory to be used even if it already exists'
        ' and is non-empty. Useful to set this to True to run the data pipeline and'
        ' the inference separately, but use the same output directory.'
    ),
)
_COMPRESS_LARGE_OUTPUT_FILES = flags.DEFINE_bool(
    name='compress_large_output_files',
    default=False,
    help=(
        'If True, compresses the output mmCIF and confidences JSON files (the two'
        ' largest files) using zstandard. Note that embeddings and distogram, if'
        ' saved, are already stored in a compressed format.'
    ),
)
def make_model_config(
    *,
    flash_attention_implementation: tokamax.DotProductAttentionImplementation = 'triton',
    num_diffusion_samples: int = 5,
    num_recycles: int = 10,
    return_embeddings: bool = False,
    return_distogram: bool = False,
) -> model.Model.Config:
  """Builds a default `model.Model.Config` with selected fields overridden.

  Args:
    flash_attention_implementation: Attention implementation stored in
      `global_config.flash_attention_implementation`.
    num_diffusion_samples: Number of diffusion samples drawn at evaluation
      time (`heads.diffusion.eval.num_samples`).
    num_recycles: Number of trunk recycling iterations.
    return_embeddings: Whether the model should return trunk embeddings.
    return_distogram: Whether the model should return the distogram.

  Returns:
    The populated model config.
  """
  cfg = model.Model.Config()
  cfg.global_config.flash_attention_implementation = (
      flash_attention_implementation
  )
  diffusion_eval_cfg = cfg.heads.diffusion.eval
  diffusion_eval_cfg.num_samples = num_diffusion_samples
  cfg.num_recycles = num_recycles
  cfg.return_embeddings = return_embeddings
  cfg.return_distogram = return_distogram
  return cfg
class ModelRunner:
  """Helper class to run structure prediction stages.

  Holds a model config, the JAX device to run on and the directory holding the
  model parameters. The parameters and the jitted forward pass are built
  lazily on first access and cached on the instance.
  """

  def __init__(
      self,
      config: model.Model.Config,
      device: jax.Device,
      model_dir: pathlib.Path,
  ):
    self._model_config = config
    self._device = device
    self._model_dir = model_dir

  @functools.cached_property
  def model_params(self) -> hk.Params:
    """Loads model parameters from the model directory (cached after first use)."""
    return params.get_model_haiku_params(model_dir=self._model_dir)

  @functools.cached_property
  def _model(
      self,
  ) -> Callable[[jnp.ndarray, features.BatchDict], model.ModelResult]:
    """Loads model parameters and returns a jitted model forward pass.

    The returned callable takes `(rng_key, batch)`; the parameters are bound
    via functools.partial.
    """

    @hk.transform
    def forward_fn(batch):
      return model.Model(self._model_config)(batch)

    return functools.partial(
        jax.jit(forward_fn.apply, device=self._device), self.model_params
    )

  def run_inference(
      self, featurised_example: features.BatchDict, rng_key: jnp.ndarray
  ) -> model.ModelResult:
    """Computes a forward pass of the model on a featurised example.

    Args:
      featurised_example: Featurised batch; features rejected by
        `utils.remove_invalidly_typed_feats` are dropped before transfer.
      rng_key: JAX PRNG key for the forward pass.

    Returns:
      A dict of host-side numpy outputs (bfloat16 leaves upcast to float32)
      plus an `__identifier__` entry read from the parameters' metadata.
    """
    featurised_example = jax.device_put(
        jax.tree_util.tree_map(
            jnp.asarray, utils.remove_invalidly_typed_feats(featurised_example)
        ),
        self._device,
    )

    result = self._model(rng_key, featurised_example)
    # Move results to host memory as numpy arrays.
    result = jax.tree.map(np.asarray, result)
    # Upcast bfloat16 outputs to float32 for downstream numpy processing.
    result = jax.tree.map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        result,
    )
    result = dict(result)
    identifier = self.model_params['__meta__']['__identifier__'].tobytes()
    result['__identifier__'] = identifier
    return result

  def extract_inference_results(
      self,
      batch: features.BatchDict,
      result: model.ModelResult,
      target_name: str,
  ) -> list[model.InferenceResult]:
    """Extracts inference results (one per diffusion sample) from model outputs."""
    return list(
        model.Model.get_inference_result(
            batch=batch, result=result, target_name=target_name
        )
    )

  def extract_embeddings(
      self, result: model.ModelResult, num_tokens: int
  ) -> dict[str, np.ndarray] | None:
    """Extracts embeddings from model outputs.

    Embeddings are cropped to the first `num_tokens` tokens (dropping bucket
    padding) and cast to float16. Returns None if no embeddings are present.
    """
    embeddings = {}
    if 'single_embeddings' in result:
      embeddings['single_embeddings'] = result['single_embeddings'][
          :num_tokens
      ].astype(np.float16)
    if 'pair_embeddings' in result:
      embeddings['pair_embeddings'] = result['pair_embeddings'][
          :num_tokens, :num_tokens
      ].astype(np.float16)
    return embeddings or None

  def extract_distogram(
      self, result: model.ModelResult, num_tokens: int
  ) -> np.ndarray | None:
    """Extracts the distogram from model outputs, if present.

    The distogram is cropped to the first `num_tokens` tokens in both token
    dimensions (dropping bucket padding).

    Returns:
      The cropped distogram, or None if the model output contains no
      distogram (presumably when `return_distogram` is disabled). Like
      `extract_embeddings`, a missing key yields None rather than KeyError,
      including when the whole `distogram` head output is absent.
    """
    distogram_head = result.get('distogram')
    if distogram_head is None or 'distogram' not in distogram_head:
      return None
    return distogram_head['distogram'][:num_tokens, :num_tokens, :]
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResultsForSeed:
  """Immutable bundle of the inference outputs produced for one RNG seed.

  Attributes:
    seed: RNG seed the diffusion samples were generated with.
    inference_results: One inference result per diffusion sample.
    full_fold_input: The fold input that was fed to the model; it must already
      include the data pipeline outputs (MSA and templates).
    embeddings: Final trunk single and pair embeddings, or None if they were
      not requested.
    distogram: Token distance histogram, or None if it was not requested.
  """

  seed: int
  inference_results: Sequence[model.InferenceResult]
  full_fold_input: folding_input.Input
  embeddings: dict[str, np.ndarray] | None = dataclasses.field(default=None)
  distogram: np.ndarray | None = dataclasses.field(default=None)
def predict_structure(
    fold_input: folding_input.Input,
    model_runner: ModelRunner,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
) -> Sequence[ResultsForSeed]:
  """Runs the full inference pipeline to predict structures for each seed.

  Featurises `fold_input` (one featurised example per RNG seed), then runs the
  model once per seed and extracts the sampled structures, plus embeddings
  and distogram when the model returned them.

  Args:
    fold_input: The fold input, including data pipeline outputs.
    model_runner: Runner used for the forward passes and output extraction.
    buckets: Token-count buckets used for padding during featurisation.
    ref_max_modified_date: Cut-off date controlling the CCD model-coordinate
      fallback during featurisation.
    conformer_max_iterations: Optional override for RDKit conformer search.
    resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
      MSA during featurisation.

  Returns:
    One `ResultsForSeed` per seed, in seed order.

  Raises:
    ValueError: If featurisation produced a different number of examples than
      there are seeds (previously such a mismatch silently dropped seeds).
  """
  print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
  featurisation_start_time = time.time()
  ccd = chemical_components.Ccd(user_ccd=fold_input.user_ccd)
  featurised_examples = featurisation.featurise_input(
      fold_input=fold_input,
      buckets=buckets,
      ccd=ccd,
      verbose=True,
      ref_max_modified_date=ref_max_modified_date,
      conformer_max_iterations=conformer_max_iterations,
      resolve_msa_overlaps=resolve_msa_overlaps,
  )
  print(
      f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
      f' {time.time() - featurisation_start_time:.2f} seconds.'
  )
  print(
      'Running model inference and extracting output structure samples with'
      f' {len(fold_input.rng_seeds)} seed(s)...'
  )
  all_inference_start_time = time.time()
  all_inference_results = []
  # strict=True: each seed must have exactly one featurised example; a
  # mismatch is a bug and must not silently skip seeds.
  for seed, example in zip(
      fold_input.rng_seeds, featurised_examples, strict=True
  ):
    print(f'Running model inference with seed {seed}...')
    inference_start_time = time.time()
    rng_key = jax.random.PRNGKey(seed)
    result = model_runner.run_inference(example, rng_key)
    print(
        f'Running model inference with seed {seed} took'
        f' {time.time() - inference_start_time:.2f} seconds.'
    )
    print(f'Extracting inference results with seed {seed}...')
    extract_structures = time.time()
    inference_results = model_runner.extract_inference_results(
        batch=example, result=result, target_name=fold_input.name
    )
    # The unpadded token count, used to crop padded per-token outputs.
    num_tokens = len(inference_results[0].metadata['token_chain_ids'])
    embeddings = model_runner.extract_embeddings(
        result=result, num_tokens=num_tokens
    )
    distogram = model_runner.extract_distogram(
        result=result, num_tokens=num_tokens
    )
    print(
        f'Extracting {len(inference_results)} inference samples with'
        f' seed {seed} took {time.time() - extract_structures:.2f} seconds.'
    )

    all_inference_results.append(
        ResultsForSeed(
            seed=seed,
            inference_results=inference_results,
            full_fold_input=fold_input,
            embeddings=embeddings,
            distogram=distogram,
        )
    )
  print(
      'Running model inference and extracting output structures with'
      f' {len(fold_input.rng_seeds)} seed(s) took'
      f' {time.time() - all_inference_start_time:.2f} seconds.'
  )
  return all_inference_results
def write_fold_input_json(
    fold_input: folding_input.Input,
    output_dir: os.PathLike[str] | str,
) -> None:
  """Serialises `fold_input` as JSON into `output_dir`.

  The file is named `<sanitised name>_data.json`; `output_dir` is created
  first if it does not exist yet, and an existing file is overwritten.
  """
  os.makedirs(output_dir, exist_ok=True)
  json_file_name = f'{fold_input.sanitised_name()}_data.json'
  json_path = os.path.join(output_dir, json_file_name)
  print(f'Writing model input JSON to {json_path}')
  with open(json_path, 'wt') as json_file:
    json_file.write(fold_input.to_json())
def write_outputs(
    all_inference_results: Sequence[ResultsForSeed],
    output_dir: os.PathLike[str] | str,
    job_name: str,
    compress_large_output_files: bool = False,
) -> None:
  """Writes outputs to the specified output directory.

  Layout written under `output_dir`:
    * `seed-<seed>_sample-<idx>/`: output for every diffusion sample.
    * `seed-<seed>_embeddings/`: embeddings for a seed, if present.
    * `seed-<seed>_distogram/<job_name>_seed-<seed>_distogram.npz`: float16
      distogram for a seed, if present.
    * Top level: the highest-`ranking_score` sample together with the output
      terms of use, plus `<job_name>_ranking_scores.csv` listing every
      (seed, sample, ranking_score).

  Args:
    all_inference_results: Per-seed inference results to write.
    output_dir: Directory to write into; created if missing.
    job_name: Prefix used for output file names.
    compress_large_output_files: Forwarded to `post_processing.write_output`
      as `compress`.
  """
  ranking_scores = []
  max_ranking_score = None
  max_ranking_result = None

  # The terms of use are read from the directory containing the
  # alphafold3.cpp module.
  output_terms = (
      pathlib.Path(alphafold3.cpp.__file__).parent / 'OUTPUT_TERMS_OF_USE.md'
  ).read_text()

  os.makedirs(output_dir, exist_ok=True)
  for results_for_seed in all_inference_results:
    seed = results_for_seed.seed
    for sample_idx, result in enumerate(results_for_seed.inference_results):
      sample_dir = os.path.join(output_dir, f'seed-{seed}_sample-{sample_idx}')
      os.makedirs(sample_dir, exist_ok=True)
      post_processing.write_output(
          inference_result=result,
          output_dir=sample_dir,
          name=f'{job_name}_seed-{seed}_sample-{sample_idx}',
          compress=compress_large_output_files,
      )
      ranking_score = float(result.metadata['ranking_score'])
      ranking_scores.append((seed, sample_idx, ranking_score))
      if max_ranking_score is None or ranking_score > max_ranking_score:
        max_ranking_score = ranking_score
        max_ranking_result = result

    if embeddings := results_for_seed.embeddings:
      embeddings_dir = os.path.join(output_dir, f'seed-{seed}_embeddings')
      os.makedirs(embeddings_dir, exist_ok=True)
      post_processing.write_embeddings(
          embeddings=embeddings,
          output_dir=embeddings_dir,
          name=f'{job_name}_seed-{seed}',
      )

    if (distogram := results_for_seed.distogram) is not None:
      distogram_dir = os.path.join(output_dir, f'seed-{seed}_distogram')
      os.makedirs(distogram_dir, exist_ok=True)
      distogram_path = os.path.join(
          distogram_dir, f'{job_name}_seed-{seed}_distogram.npz'
      )
      with open(distogram_path, 'wb') as f:
        np.savez_compressed(f, distogram=distogram.astype(np.float16))

  if max_ranking_result is not None:  # True iff ranking_scores non-empty.
    post_processing.write_output(
        inference_result=max_ranking_result,
        output_dir=output_dir,
        # The output terms of use are the same for all seeds/samples.
        terms_of_use=output_terms,
        name=job_name,
        compress=compress_large_output_files,
    )
    # Save csv of ranking scores with seeds and sample indices, to allow easier
    # comparison of ranking scores across different runs. newline='' is what
    # the csv module requires for file objects, so that csv.writer controls
    # line endings itself.
    with open(
        os.path.join(output_dir, f'{job_name}_ranking_scores.csv'),
        'wt',
        newline='',
    ) as f:
      writer = csv.writer(f)
      writer.writerow(['seed', 'sample', 'ranking_score'])
      writer.writerows(ranking_scores)
def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str:
"""Replaces the DB_DIR placeholder in a path with the given DB_DIR."""
template = string.Template(path_with_db_dir)
if 'DB_DIR' in template.get_identifiers():
for db_dir in db_dirs:
path = template.substitute(DB_DIR=db_dir)
if os.path.exists(path):
return path
raise FileNotFoundError(
f'{path_with_db_dir} with ${{DB_DIR}} not found in any of {db_dirs}.'
)
if (sharded_paths := shards.get_sharded_paths(path_with_db_dir)) is not None:
db_exists = all(os.path.exists(p) for p in sharded_paths)
else:
db_exists = os.path.exists(path_with_db_dir)
if not db_exists:
raise FileNotFoundError(f'{path_with_db_dir} does not exist.')
return path_with_db_dir
@overload
def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    *,
    model_runner: None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
    compress_large_output_files: bool = False,
) -> folding_input.Input:
  ...


@overload
def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    *,
    model_runner: ModelRunner,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
    compress_large_output_files: bool = False,
) -> Sequence[ResultsForSeed]:
  ...


def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    *,
    model_runner: ModelRunner | None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
    compress_large_output_files: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]:
  """Runs data pipeline and/or inference on a single fold input.

  Args:
    fold_input: Fold input to process.
    data_pipeline_config: Data pipeline config to use. If None, skip the data
      pipeline.
    model_runner: Model runner to use. If None, skip inference.
    output_dir: Output directory to write to.
    buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation
      of the model. If None, calculate the appropriate bucket size from the
      number of tokens. If not None, must be a sequence of at least one integer,
      in strictly increasing order. Will raise an error if the number of tokens
      is more than the largest bucket size.
    ref_max_modified_date: Optional maximum date that controls whether to allow
      use of model coordinates for a chemical component from the CCD if RDKit
      conformer generation fails and the component does not have ideal
      coordinates set. Only for components that have been released before this
      date the model coordinates can be used as a fallback.
    conformer_max_iterations: Optional override for maximum number of iterations
      to run for RDKit conformer search.
    resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
      MSA. The default behaviour matches the method described in the AlphaFold 3
      paper. Set this to false if providing custom paired MSA using the unpaired
      MSA field to keep it exactly as is as deduplication against the paired MSA
      could break the manually crafted pairing between MSA sequences.
    force_output_dir: If True, do not create a new output directory even if the
      existing one is non-empty. Instead use the existing output directory and
      potentially overwrite existing files. If False, create a new timestamped
      output directory instead if the existing one is non-empty.
    compress_large_output_files: If True, compress large output files (mmCIF and
      confidences JSON) using zstandard.

  Returns:
    The processed fold input, or the inference results for each seed.

  Raises:
    ValueError: If the fold input has no chains.
  """
  print(f'\nRunning fold job {fold_input.name}...')

  if not fold_input.chains:
    raise ValueError('Fold input has no chains.')

  if (
      not force_output_dir
      and os.path.exists(output_dir)
      and os.listdir(output_dir)
  ):
    # Drop trailing separators before appending the timestamp: for 'out/' the
    # plain f-string would give 'out/_<timestamp>', i.e. a directory nested
    # inside the non-empty directory we are trying to avoid.
    base_output_dir = os.fspath(output_dir).rstrip(os.sep) or os.sep
    new_output_dir = (
        f'{base_output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
    )
    print(
        f'Output will be written in {new_output_dir} since {output_dir} is'
        ' non-empty.'
    )
    output_dir = new_output_dir
  else:
    print(f'Output will be written in {output_dir}')

  if data_pipeline_config is None:
    print('Skipping data pipeline...')
  else:
    print('Running data pipeline...')
    fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)

  # Always persist the (possibly pipeline-augmented) input alongside outputs.
  write_fold_input_json(fold_input, output_dir)
  if model_runner is None:
    print('Skipping model inference...')
    output = fold_input
  else:
    print(
        f'Predicting 3D structure for {fold_input.name} with'
        f' {len(fold_input.rng_seeds)} seed(s)...'
    )
    all_inference_results = predict_structure(
        fold_input=fold_input,
        model_runner=model_runner,
        buckets=buckets,
        ref_max_modified_date=ref_max_modified_date,
        conformer_max_iterations=conformer_max_iterations,
        resolve_msa_overlaps=resolve_msa_overlaps,
    )
    print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
    write_outputs(
        all_inference_results=all_inference_results,
        output_dir=output_dir,
        job_name=fold_input.sanitised_name(),
        compress_large_output_files=compress_large_output_files,
    )
    output = all_inference_results

  print(f'Fold job {fold_input.name} done, output written to {output_dir}\n')
  return output
def main(_):
  """Runs the data pipeline and/or inference for every input fold job.

  Validates the command-line flags, loads the fold inputs from --json_path or
  --input_dir, optionally builds the data pipeline config and the model
  runner, then processes each fold input into a per-job subdirectory of
  --output_dir.

  Raises:
    ValueError: On invalid flag combinations, an out-of-range --gpu_device or
      an unsupported GPU.
    OSError: If the output directory cannot be created.
  """
  if _JAX_COMPILATION_CACHE_DIR.value is not None:
    jax.config.update(
        'jax_compilation_cache_dir', _JAX_COMPILATION_CACHE_DIR.value
    )

  # The parentheses are essential: unparenthesised, `a is None == b is None`
  # is a chained comparison (`a is None and None == b and b is None`) that is
  # only true when neither flag is set, so passing both was silently accepted.
  if (_JSON_PATH.value is None) == (_INPUT_DIR.value is None):
    raise ValueError(
        'Exactly one of --json_path or --input_dir must be specified.'
    )

  if not _RUN_INFERENCE.value and not _RUN_DATA_PIPELINE.value:
    raise ValueError(
        'At least one of --run_inference or --run_data_pipeline must be'
        ' set to true.'
    )

  if _INPUT_DIR.value is not None:
    fold_inputs = folding_input.load_fold_inputs_from_dir(
        pathlib.Path(_INPUT_DIR.value)
    )
  elif _JSON_PATH.value is not None:
    fold_inputs = folding_input.load_fold_inputs_from_path(
        pathlib.Path(_JSON_PATH.value)
    )
  else:
    # Unreachable given the validation above; kept as a safety net.
    raise AssertionError(
        'Exactly one of --json_path or --input_dir must be specified.'
    )

  # Make sure we can create the output directory before running anything.
  try:
    os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
  except OSError as e:
    print(f'Failed to create output directory {_OUTPUT_DIR.value}: {e}')
    raise

  if _RUN_INFERENCE.value:
    # Fail early on incompatible devices, but only if we're running inference.
    gpu_devices = jax.local_devices(backend='gpu')
    # Validate the index up front so that a bad --gpu_device gives a clear
    # error instead of an IndexError (or, for negative values, silently
    # selecting a GPU counted from the end of the list).
    if not 0 <= _GPU_DEVICE.value < len(gpu_devices):
      raise ValueError(
          f'--gpu_device={_GPU_DEVICE.value} is out of range: found'
          f' {len(gpu_devices)} local GPU device(s) {gpu_devices}. Note that'
          ' --gpu_device uses zero-based indexing.'
      )
    compute_capability = float(
        gpu_devices[_GPU_DEVICE.value].compute_capability
    )
    if compute_capability < 6.0:
      raise ValueError(
          'AlphaFold 3 requires at least GPU compute capability 6.0 (see'
          ' https://developer.nvidia.com/cuda-gpus).'
      )
    elif 7.0 <= compute_capability < 8.0:
      xla_flags = os.environ.get('XLA_FLAGS')
      required_flag = '--xla_disable_hlo_passes=custom-kernel-fusion-rewriter'
      if not xla_flags or required_flag not in xla_flags:
        raise ValueError(
            'For devices with GPU compute capability 7.x (see'
            ' https://developer.nvidia.com/cuda-gpus) the ENV XLA_FLAGS must'
            f' include "{required_flag}".'
        )
      if _FLASH_ATTENTION_IMPLEMENTATION.value != 'xla':
        raise ValueError(
            'For devices with GPU compute capability 7.x (see'
            ' https://developer.nvidia.com/cuda-gpus) the'
            ' --flash_attention_implementation must be set to "xla".'
        )

  notice = textwrap.wrap(
      'Running AlphaFold 3. Please note that standard AlphaFold 3 model'
      ' parameters are only available under terms of use provided at'
      ' https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.'
      ' If you do not agree to these terms and are using AlphaFold 3 derived'
      ' model parameters, cancel execution of AlphaFold 3 inference with'
      ' CTRL-C, and do not use the model parameters.',
      break_long_words=False,
      break_on_hyphens=False,
      width=80,
  )
  print('\n' + '\n'.join(notice) + '\n')

  max_template_date = datetime.date.fromisoformat(_MAX_TEMPLATE_DATE.value)
  if _RUN_DATA_PIPELINE.value:
    # Resolve ${DB_DIR} placeholders against --db_dir, checking existence.
    expand_path = lambda x: replace_db_dir(x, DB_DIR.value)
    data_pipeline_config = pipeline.DataPipelineConfig(
        jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value,
        nhmmer_binary_path=_NHMMER_BINARY_PATH.value,
        hmmalign_binary_path=_HMMALIGN_BINARY_PATH.value,
        hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value,
        hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value,
        small_bfd_database_path=expand_path(_SMALL_BFD_DATABASE_PATH.value),
        small_bfd_z_value=_SMALL_BFD_Z_VALUE.value,
        mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value),
        mgnify_z_value=_MGNIFY_Z_VALUE.value,
        uniprot_cluster_annot_database_path=expand_path(
            _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value
        ),
        uniprot_cluster_annot_z_value=_UNIPROT_CLUSTER_ANNOT_Z_VALUE.value,
        uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value),
        uniref90_z_value=_UNIREF90_Z_VALUE.value,
        ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value),
        ntrna_z_value=_NTRNA_Z_VALUE.value,
        rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value),
        rfam_z_value=_RFAM_Z_VALUE.value,
        rna_central_database_path=expand_path(_RNA_CENTRAL_DATABASE_PATH.value),
        rna_central_z_value=_RNA_CENTRAL_Z_VALUE.value,
        pdb_database_path=expand_path(_PDB_DATABASE_PATH.value),
        seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value),
        jackhmmer_n_cpu=_JACKHMMER_N_CPU.value,
        jackhmmer_max_parallel_shards=_JACKHMMER_MAX_PARALLEL_SHARDS.value,
        nhmmer_n_cpu=_NHMMER_N_CPU.value,
        nhmmer_max_parallel_shards=_NHMMER_MAX_PARALLEL_SHARDS.value,
        max_template_date=max_template_date,
    )
  else:
    data_pipeline_config = None

  if _RUN_INFERENCE.value:
    devices = jax.local_devices(backend='gpu')
    print(
        f'Found local devices: {devices}, using device {_GPU_DEVICE.value}:'
        f' {devices[_GPU_DEVICE.value]}'
    )

    print('Building model from scratch...')
    model_runner = ModelRunner(
        config=make_model_config(
            flash_attention_implementation=typing.cast(
                tokamax.DotProductAttentionImplementation,
                _FLASH_ATTENTION_IMPLEMENTATION.value,
            ),
            num_diffusion_samples=_NUM_DIFFUSION_SAMPLES.value,
            num_recycles=_NUM_RECYCLES.value,
            return_embeddings=_SAVE_EMBEDDINGS.value,
            return_distogram=_SAVE_DISTOGRAM.value,
        ),
        device=devices[_GPU_DEVICE.value],
        model_dir=pathlib.Path(MODEL_DIR.value),
    )
    # Check we can load the model parameters before launching anything.
    print('Checking that model parameters can be loaded...')
    _ = model_runner.model_params
  else:
    model_runner = None

  num_fold_inputs = 0
  for fold_input in fold_inputs:
    if _NUM_SEEDS.value is not None:
      print(f'Expanding fold job {fold_input.name} to {_NUM_SEEDS.value} seeds')
      fold_input = fold_input.with_multiple_seeds(_NUM_SEEDS.value)
    process_fold_input(
        fold_input=fold_input,
        data_pipeline_config=data_pipeline_config,
        model_runner=model_runner,
        output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()),
        buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
        ref_max_modified_date=max_template_date,
        conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
        resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
        force_output_dir=_FORCE_OUTPUT_DIR.value,
        compress_large_output_files=_COMPRESS_LARGE_OUTPUT_FILES.value,
    )
    num_fold_inputs += 1

  print(f'Done running {num_fold_inputs} fold jobs.')
if __name__ == '__main__':
  # --output_dir defaults to None, so have absl reject runs that omit it.
  flags.mark_flag_as_required('output_dir')
  app.run(main)
================================================
FILE: run_alphafold_data_test.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Tests the AlphaFold 3 data pipeline."""
import contextlib
import datetime
import difflib
import functools
import hashlib
import json
import os
import pathlib
import pickle
from typing import Any
from absl.testing import absltest
from absl.testing import parameterized
from alphafold3 import structure
from alphafold3.common import folding_input
from alphafold3.common import resources
from alphafold3.common.testing import data as testing_data
from alphafold3.constants import chemical_components
from alphafold3.data import featurisation
from alphafold3.data import pipeline
from alphafold3.model.atom_layout import atom_layout
import jax
import numpy as np
import run_alphafold
import shutil
_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer')
_NHMMER_BINARY_PATH = shutil.which('nhmmer')
_HMMALIGN_BINARY_PATH = shutil.which('hmmalign')
_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch')
_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild')
@contextlib.contextmanager
def _output(name: str):
  """Opens `<TEST_TMPDIR>/<name>` for binary writing.

  Args:
    name: File name, relative to absl's test temporary directory.

  Yields:
    A `(result_path, file)` tuple. The file is closed when the context exits.
  """
  result_path = f'{absltest.TEST_TMPDIR.value}/{name}'
  with open(result_path, 'wb') as f:
    yield result_path, f
@functools.singledispatch
def _hash_data(x: Any, /) -> str:
if x is None:
return '<>'
return _hash_data(json.dumps(x).encode('utf-8'))
@_hash_data.register
def _(x: bytes, /) -> str:
return hashlib.sha256(x).hexdigest()
@_hash_data.register
def _(x: jax.Array) -> str:
return _hash_data(jax.device_get(x))
@_hash_data.register
def _(x: np.ndarray) -> str:
if x.dtype == object:
return ';'.join(map(_hash_data, x.ravel().tolist()))
return _hash_data(x.tobytes())
@_hash_data.register(atom_layout.AtomLayout)
@_hash_data.register(structure.Structure)
def _hash_opaque(_, /) -> str:
  """Returns the placeholder '<>' for structures and atom layouts.

  These objects are not hashed by content; every instance maps to the same
  placeholder digest.
  """
  return '<>'
def _generate_diff(actual: str, expected: str) -> str:
return '\n'.join(
difflib.unified_diff(
expected.split('\n'),
actual.split('\n'),
fromfile='expected',
tofile='actual',
lineterm='',
)
)
class DataPipelineTest(parameterized.TestCase):
  """Tests the AlphaFold 3 data pipeline and run_alphafold's data helpers."""

  @staticmethod
  def _test_data_path(relative_path: str):
    """Returns the resolved path of `test_data/<relative_path>`."""
    return testing_data.Data(
        resources.ROOT / f'test_data/{relative_path}'
    ).path()

  def setUp(self):
    super().setUp()
    # Miniature databases shipped with the test data.
    small_bfd_database_path = self._test_data_path(
        'miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta'
    )
    mgnify_database_path = self._test_data_path(
        'miniature_databases/mgy_clusters__subsampled_1000.fa'
    )
    uniprot_cluster_annot_database_path = self._test_data_path(
        'miniature_databases/uniprot_all__subsampled_1000.fasta'
    )
    uniref90_database_path = self._test_data_path(
        'miniature_databases/uniref90__subsampled_1000.fasta'
    )
    ntrna_database_path = self._test_data_path(
        'miniature_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta'
    )
    rfam_database_path = self._test_data_path(
        'miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta'
    )
    rna_central_database_path = self._test_data_path(
        'miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta'
    )
    pdb_database_path = self._test_data_path('miniature_databases/pdb_mmcif')
    seqres_database_path = self._test_data_path(
        'miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta'
    )
    self._data_pipeline_config = pipeline.DataPipelineConfig(
        jackhmmer_binary_path=_JACKHMMER_BINARY_PATH,
        nhmmer_binary_path=_NHMMER_BINARY_PATH,
        hmmalign_binary_path=_HMMALIGN_BINARY_PATH,
        hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH,
        hmmbuild_binary_path=_HMMBUILD_BINARY_PATH,
        small_bfd_database_path=small_bfd_database_path,
        mgnify_database_path=mgnify_database_path,
        uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path,
        uniref90_database_path=uniref90_database_path,
        ntrna_database_path=ntrna_database_path,
        rfam_database_path=rfam_database_path,
        rna_central_database_path=rna_central_database_path,
        pdb_database_path=pdb_database_path,
        seqres_database_path=seqres_database_path,
        max_template_date=datetime.date(2021, 9, 30),
    )
    # One protein chain plus one ligand; MSAs are left unset (None) so the
    # data pipeline has to fill them in.
    test_input = {
        'name': '5tgy',
        'modelSeeds': [1234],
        'sequences': [
            {
                'protein': {
                    'id': 'P',
                    'sequence': (
                        'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
                    ),
                    'modifications': [],
                    'unpairedMsa': None,
                    'pairedMsa': None,
                }
            },
            {'ligand': {'id': 'LL', 'ccdCodes': ['7BU']}},
        ],
        'dialect': folding_input.JSON_DIALECT,
        'version': folding_input.JSON_VERSION,
    }
    self._test_input_json = json.dumps(test_input)

  def compare_golden(self, result_path: str) -> None:
    """Asserts the file at `result_path` matches its golden counterpart.

    The golden file is looked up by the result file's basename under
    `test_data/`. On mismatch, the failure message contains a unified diff.

    Args:
      result_path: Path of the (already closed) output file to check.
    """
    filename = os.path.basename(result_path)
    golden_path = self._test_data_path(filename)
    with open(golden_path, 'r') as golden_file:
      golden_text = golden_file.read()
    with open(result_path, 'r') as result_file:
      result_text = result_file.read()
    diff = _generate_diff(result_text, golden_text)
    self.assertEqual(diff, '', f'Result differs from golden:\n{diff}')

  def test_config(self):
    """The default model config serialises to the golden JSON."""
    model_config = run_alphafold.make_model_config()
    model_config_as_str = json.dumps(
        model_config.as_dict(), sort_keys=True, indent=2
    )
    with _output('model_config.json') as (result_path, output):
      output.write(model_config_as_str.encode('utf-8'))
    # Compare only after the `with` block has closed (and flushed) the file.
    self.compare_golden(result_path)

  def test_featurisation(self):
    """Run featurisation and assert that the output is as expected."""
    fold_input = folding_input.Input.from_json(self._test_input_json)
    data_pipeline = pipeline.DataPipeline(self._data_pipeline_config)
    full_fold_input = data_pipeline.process(fold_input)
    featurised_example = featurisation.featurise_input(
        full_fold_input,
        ccd=chemical_components.Ccd(),
        buckets=None,
    )
    del featurised_example[0]['ref_pos']  # Depends on specific RDKit version.
    # Keep the raw pickled features as a test artefact.
    with _output('featurised_example.pkl') as (_, output):
      output.write(pickle.dumps(featurised_example))
    # Replace every leaf by its string digest before comparing to the golden.
    featurised_example = jax.tree_util.tree_map(_hash_data, featurised_example)
    with _output('featurised_example.json') as (result_path, output):
      output.write(
          json.dumps(featurised_example, sort_keys=True, indent=2).encode(
              'utf-8'
          )
      )
    self.compare_golden(result_path)

  def test_write_input_json(self):
    """A written `<name>_data.json` round-trips to an equal fold input."""
    fold_input = folding_input.Input.from_json(self._test_input_json)
    output_dir = self.create_tempdir().full_path
    run_alphafold.write_fold_input_json(fold_input, output_dir)
    with open(
        os.path.join(output_dir, f'{fold_input.sanitised_name()}_data.json'),
        'rt',
    ) as f:
      actual_fold_input = folding_input.Input.from_json(f.read())
    self.assertEqual(actual_fold_input, fold_input)

  def test_process_fold_input_runs_only_data_pipeline(self):
    """With no model runner, only the data pipeline output is produced."""
    fold_input = folding_input.Input.from_json(self._test_input_json)
    output_dir = self.create_tempdir().full_path
    run_alphafold.process_fold_input(
        fold_input=fold_input,
        data_pipeline_config=self._data_pipeline_config,
        model_runner=None,
        output_dir=output_dir,
    )
    with open(
        os.path.join(output_dir, f'{fold_input.sanitised_name()}_data.json'),
        'rt',
    ) as f:
      actual_fold_input = folding_input.Input.from_json(f.read())
    featurisation.validate_fold_input(actual_fold_input)

  @parameterized.product(num_db_dirs=tuple(range(1, 3)))
  def test_replace_db_dir(self, num_db_dirs: int) -> None:
    """Test that the db_dir is replaced correctly.

    Directory `i` holds `filename0.txt` .. `filename{i}.txt`, each containing
    'hello world {i}', so `filename{i}.txt` first appears in directory `i`.
    The test expects `${DB_DIR}/filename{i}.txt` to resolve to that
    directory, and a file present in no directory to raise FileNotFoundError.
    """
    db_dirs = [pathlib.Path(self.create_tempdir()) for _ in range(num_db_dirs)]
    db_dirs_posix = [db_dir.as_posix() for db_dir in db_dirs]
    for i, db_dir in enumerate(db_dirs):
      for j in range(i + 1):
        (db_dir / f'filename{j}.txt').write_text(f'hello world {i}')
    for i in range(num_db_dirs):
      self.assertEqual(
          pathlib.Path(
              run_alphafold.replace_db_dir(
                  f'${{DB_DIR}}/filename{i}.txt', db_dirs_posix
              )
          ).read_text(),
          f'hello world {i}',
      )
    with self.assertRaises(FileNotFoundError):
      run_alphafold.replace_db_dir(
          f'${{DB_DIR}}/filename{num_db_dirs}.txt', db_dirs_posix
      )
if __name__ == '__main__':
  # Entry point: run this module's tests through absltest.
  absltest.main()
================================================
FILE: run_alphafold_test.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Tests end-to-end running of AlphaFold 3."""
import contextlib
import csv
import dataclasses
import datetime
import difflib
import json
import os
import pathlib
import pickle
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from alphafold3.common import folding_input
from alphafold3.common import resources
from alphafold3.common.testing import data as testing_data
from alphafold3.data import pipeline
from alphafold3.model.scoring import alignment
import jax
import numpy as np
import run_alphafold
import shutil
_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer')
_NHMMER_BINARY_PATH = shutil.which('nhmmer')
_HMMALIGN_BINARY_PATH = shutil.which('hmmalign')
_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch')
_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild')
@contextlib.contextmanager
def _output(name: str):
  """Opens `<TEST_TMPDIR>/<name>` for binary writing.

  Args:
    name: File name, relative to absl's test temporary directory.

  Yields:
    A `(result_path, file)` tuple. The file is closed when the context exits.
  """
  result_path = f'{absltest.TEST_TMPDIR.value}/{name}'
  with open(result_path, 'wb') as f:
    yield result_path, f
# Turn off JAX's persistent compilation cache for these tests (presumably so
# every run compiles from scratch rather than reusing cached executables).
jax.config.update('jax_enable_compilation_cache', False)
def _generate_diff(actual: str, expected: str) -> str:
return '\n'.join(
difflib.unified_diff(
expected.split('\n'),
actual.split('\n'),
fromfile='expected',
tofile='actual',
lineterm='',
)
)
class InferenceTest(parameterized.TestCase):
  """Test AlphaFold 3 inference."""

  @staticmethod
  def _test_data_path(relative_path: str):
    """Returns the resolved path of `test_data/<relative_path>`."""
    return testing_data.Data(
        resources.ROOT / f'test_data/{relative_path}'
    ).path()

  def setUp(self):
    super().setUp()
    # Miniature databases shipped with the test data.
    small_bfd_database_path = self._test_data_path(
        'miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta'
    )
    mgnify_database_path = self._test_data_path(
        'miniature_databases/mgy_clusters__subsampled_1000.fa'
    )
    uniprot_cluster_annot_database_path = self._test_data_path(
        'miniature_databases/uniprot_all__subsampled_1000.fasta'
    )
    uniref90_database_path = self._test_data_path(
        'miniature_databases/uniref90__subsampled_1000.fasta'
    )
    ntrna_database_path = self._test_data_path(
        'miniature_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta'
    )
    rfam_database_path = self._test_data_path(
        'miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta'
    )
    rna_central_database_path = self._test_data_path(
        'miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta'
    )
    pdb_database_path = self._test_data_path('miniature_databases/pdb_mmcif')
    seqres_database_path = self._test_data_path(
        'miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta'
    )
    self._data_pipeline_config = pipeline.DataPipelineConfig(
        jackhmmer_binary_path=_JACKHMMER_BINARY_PATH,
        nhmmer_binary_path=_NHMMER_BINARY_PATH,
        hmmalign_binary_path=_HMMALIGN_BINARY_PATH,
        hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH,
        hmmbuild_binary_path=_HMMBUILD_BINARY_PATH,
        small_bfd_database_path=small_bfd_database_path,
        mgnify_database_path=mgnify_database_path,
        uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path,
        uniref90_database_path=uniref90_database_path,
        ntrna_database_path=ntrna_database_path,
        rfam_database_path=rfam_database_path,
        rna_central_database_path=rna_central_database_path,
        pdb_database_path=pdb_database_path,
        seqres_database_path=seqres_database_path,
        max_template_date=datetime.date(2021, 9, 30),
    )
    # One protein chain plus one ligand; MSAs are left unset (None) so the
    # data pipeline has to fill them in.
    test_input = {
        'name': '5tgy',
        'modelSeeds': [1234],
        'sequences': [
            {
                'protein': {
                    'id': 'P',
                    'sequence': (
                        'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
                    ),
                    'modifications': [],
                    'unpairedMsa': None,
                    'pairedMsa': None,
                }
            },
            {'ligand': {'id': 'LL', 'ccdCodes': ['7BU']}},
        ],
        'dialect': folding_input.JSON_DIALECT,
        'version': folding_input.JSON_VERSION,
    }
    self._test_input_json = json.dumps(test_input)
    self._model_config = run_alphafold.make_model_config(
        flash_attention_implementation='triton',
        return_embeddings=True,
        return_distogram=True,
    )
    self._runner = run_alphafold.ModelRunner(
        config=self._model_config,
        device=jax.local_devices()[0],
        model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
    )

  def test_model_inference(self):
    """Run model inference and assert that output exists."""
    featurised_examples = pickle.loads(
        (resources.ROOT / 'test_data' / 'featurised_example.pkl').read_bytes()
    )
    self.assertLen(featurised_examples, 1)
    featurised_example = featurised_examples[0]
    result = self._runner.run_inference(
        featurised_example, jax.random.PRNGKey(0)
    )
    self.assertIsNotNone(result)
    inference_results = self._runner.extract_inference_results(
        batch=featurised_example, result=result, target_name='target'
    )
    embeddings = self._runner.extract_embeddings(
        result=result,
        num_tokens=len(inference_results[0].metadata['token_chain_ids']),
    )
    self.assertLen(embeddings, 2)

  def test_process_fold_input_runs_only_inference(self):
    """Without a data pipeline config, inputs lacking MSAs are rejected."""
    with self.assertRaisesRegex(ValueError, 'missing unpaired MSA.'):
      run_alphafold.process_fold_input(
          fold_input=folding_input.Input.from_json(self._test_input_json),
          # No data pipeline config, so featurisation will run first, and fail
          # since the input is missing MSAs.
          data_pipeline_config=None,
          model_runner=self._runner,
          output_dir=self.create_tempdir().full_path,
      )

  @parameterized.named_parameters(
      {
          'testcase_name': 'default_bucket',
          'bucket': None,
          'seed': 1,
      },
      {
          'testcase_name': 'bucket_1024',
          'bucket': 1024,
          'seed': 42,
      },
  )
  def test_inference(self, bucket, seed):
    """Run AlphaFold 3 inference."""
    ### Prepare inputs with modified seed.
    fold_input = folding_input.Input.from_json(self._test_input_json)
    fold_input = dataclasses.replace(fold_input, rng_seeds=[seed])
    output_dir = self.create_tempdir().full_path
    actual = run_alphafold.process_fold_input(
        fold_input,
        self._data_pipeline_config,
        model_runner=run_alphafold.ModelRunner(
            config=self._model_config,
            device=jax.local_devices(backend='gpu')[0],
            model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
        ),
        output_dir=output_dir,
        buckets=None if bucket is None else [bucket],
    )
    logging.info('finished get_inference_result')

    name = fold_input.sanitised_name()
    num_samples = 5
    expected_model_cif_filename = f'{name}_model.cif'
    expected_summary_confidences_filename = f'{name}_summary_confidences.json'
    expected_confidences_filename = f'{name}_confidences.json'
    expected_data_json_filename = f'{name}_data.json'
    ranking_scores_filename = f'{name}_ranking_scores.csv'
    prefix = f'seed-{seed}'
    self.assertSameElements(
        os.listdir(output_dir),
        [
            # Subdirectories: one per sample, plus one each for the embeddings
            # and the distogram.
            *(f'{prefix}_sample-{i}' for i in range(num_samples)),
            f'{prefix}_embeddings',
            f'{prefix}_distogram',
            # Top ranking result.
            expected_confidences_filename,
            expected_model_cif_filename,
            expected_summary_confidences_filename,
            # Ranking scores for all samples.
            ranking_scores_filename,
            # The input JSON defining the job.
            expected_data_json_filename,
            # The output terms of use.
            'TERMS_OF_USE.md',
        ],
    )
    for sample_index in range(num_samples):
      sample_dir = os.path.join(output_dir, f'{prefix}_sample-{sample_index}')
      sample_prefix = f'{name}_seed-{seed}_sample-{sample_index}'
      self.assertSameElements(
          os.listdir(sample_dir),
          [
              f'{sample_prefix}_confidences.json',
              f'{sample_prefix}_model.cif',
              f'{sample_prefix}_summary_confidences.json',
          ],
      )

    # Ligand 7BU has 41 tokens.
    num_tokens = len(fold_input.protein_chains[0].sequence) + 41

    embeddings_dir = os.path.join(output_dir, f'{prefix}_embeddings')
    embeddings_filename = f'{name}_{prefix}_embeddings.npz'
    self.assertSameElements(os.listdir(embeddings_dir), [embeddings_filename])
    with open(os.path.join(embeddings_dir, embeddings_filename), 'rb') as f:
      # np.load on an .npz returns a lazily-read NpzFile, so the arrays must be
      # read while the underlying file object is still open.
      embeddings = np.load(f)
      self.assertSameElements(
          embeddings.keys(), ['single_embeddings', 'pair_embeddings']
      )
      single_embeddings = embeddings['single_embeddings']
      pair_embeddings = embeddings['pair_embeddings']
    self.assertEqual(single_embeddings.shape, (num_tokens, 384))
    self.assertEqual(single_embeddings.dtype, np.float16)
    self.assertEqual(pair_embeddings.shape, (num_tokens, num_tokens, 128))
    self.assertEqual(pair_embeddings.dtype, np.float16)

    distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
    distogram_filename = f'{name}_{prefix}_distogram.npz'
    self.assertSameElements(os.listdir(distogram_dir), [distogram_filename])
    with open(os.path.join(distogram_dir, distogram_filename), 'rb') as f:
      distogram = np.load(f)['distogram']
    self.assertEqual(distogram.shape, (num_tokens, num_tokens, 64))
    self.assertEqual(distogram.dtype, np.float16)

    with open(os.path.join(output_dir, expected_data_json_filename), 'rt') as f:
      actual_input_json = json.load(f)
    self.assertEqual(
        actual_input_json['sequences'][0]['protein']['sequence'],
        fold_input.protein_chains[0].sequence,
    )
    self.assertSequenceEqual(
        actual_input_json['sequences'][1]['ligand']['ccdCodes'],
        fold_input.ligands[0].ccd_ids,
    )
    self.assertNotEmpty(
        actual_input_json['sequences'][0]['protein']['unpairedMsa']
    )
    self.assertNotEmpty(
        actual_input_json['sequences'][0]['protein']['pairedMsa']
    )
    self.assertIsNotNone(
        actual_input_json['sequences'][0]['protein']['templates']
    )

    with open(os.path.join(output_dir, ranking_scores_filename), 'rt') as f:
      ranking_rows = list(csv.DictReader(f))
    self.assertLen(ranking_rows, num_samples)
    self.assertEqual(
        [int(row['seed']) for row in ranking_rows], [seed] * num_samples
    )
    self.assertEqual(
        [int(row['sample']) for row in ranking_rows], list(range(num_samples))
    )
    # Ranking score should be in the expected range for all samples.
    ranking_scores = [float(row['ranking_score']) for row in ranking_rows]
    lower = 0.66
    upper = 0.78
    scores_ok = [lower <= score <= upper for score in ranking_scores]
    if not all(scores_ok):
      printable_scores = [f'{score:.2f}' for score in ranking_scores]
      self.fail(
          f'Ranking scores {printable_scores} not in expected range '
          f'[{lower:.2f}, {upper:.2f}]'
      )

    with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'rt') as f:
      actual_terms_of_use = f.read()
    self.assertStartsWith(
        actual_terms_of_use, '# ALPHAFOLD 3 OUTPUT TERMS OF USE'
    )

    bucket_label = 'default' if bucket is None else bucket
    output_filename = f'run_alphafold_test_output_bucket_{bucket_label}.pkl'
    # Convert to dict to enable simple serialization.
    actual_dict = [
        dict(
            seed=actual_inf.seed,
            inference_results=actual_inf.inference_results,
            full_fold_input=actual_inf.full_fold_input,
        )
        for actual_inf in actual
    ]
    with _output(output_filename) as (_, output):
      output.write(pickle.dumps(actual_dict))

    logging.info('Comparing inference results with expected values.')
    ### Assert that output is as expected.
    expected_dict = pickle.loads(
        (
            resources.ROOT
            / 'test_data'
            / 'alphafold_run_outputs'
            / output_filename
        ).read_bytes()
    )
    expected = [
        run_alphafold.ResultsForSeed(**expected_inf)
        for expected_inf in expected_dict
    ]
    actual_rmsds = []
    mask_proportions = []
    actual_masked_rmsds = []
    # Outer loop: per-seed results; inner loop: per-sample inference results.
    # Distinct names keep the inner loop from shadowing the outer variables.
    for actual_seed_result, expected_seed_result in zip(
        actual, expected, strict=True
    ):
      for actual_inf, expected_inf in zip(
          actual_seed_result.inference_results,
          expected_seed_result.inference_results,
          strict=True,
      ):
        # Make sure the token chain IDs are the same as the input chain IDs.
        self.assertEqual(
            actual_inf.metadata['token_chain_ids'],
            ['P'] * len(fold_input.protein_chains[0].sequence) + ['LL'] * 41,
        )
        # All atom occupancies should be 1.0.
        np.testing.assert_array_equal(
            actual_inf.predicted_structure.atom_occupancy,
            [1.0] * actual_inf.predicted_structure.num_atoms,
        )
        actual_rmsds.append(
            alignment.rmsd_from_coords(
                decoy_coords=actual_inf.predicted_structure.coords,
                gt_coords=expected_inf.predicted_structure.coords,
            )
        )
        # Keep only atoms with b_factor > 80.0, i.e. drop atoms with
        # b_factor <= 80.0 (lower confidence regions).
        mask = actual_inf.predicted_structure.atom_b_factor > 80.0
        mask_proportions.append(
            np.sum(mask) / actual_inf.predicted_structure.num_atoms
        )
        actual_masked_rmsds.append(
            alignment.rmsd_from_coords(
                decoy_coords=actual_inf.predicted_structure.coords,
                gt_coords=expected_inf.predicted_structure.coords,
                include_idxs=mask,
            )
        )
    # 5tgy is stably predicted, samples should be all within 3.0 RMSD
    # regardless of seed, bucket, device type, etc.
    if any(rmsd > 3.0 for rmsd in actual_rmsds):
      self.fail(f'Full RMSD too high: {actual_rmsds=}')
    # Check proportion of atoms with b_factor > 80 is at least 70%.
    if any(prop < 0.7 for prop in mask_proportions):
      self.fail(f'Too many residues with low pLDDT: {mask_proportions=}')
    # Check masked RMSD is within tolerance (lower than full RMSD due to masking
    # of lower confidence regions).
    if any(rmsd > 1.4 for rmsd in actual_masked_rmsds):
      self.fail(f'Masked RMSD too high: {actual_masked_rmsds=}')
if __name__ == '__main__':
  # Entry point: run this module's tests through absltest.
  absltest.main()
================================================
FILE: src/alphafold3/__init__.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""An implementation of the inference pipeline of AlphaFold 3."""
================================================
FILE: src/alphafold3/build_data.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Script for building intermediate data."""
from importlib import resources
import os
import pathlib
import site
import alphafold3.constants.converters
from alphafold3.constants.converters import ccd_pickle_gen
from alphafold3.constants.converters import chemical_component_sets_gen
def build_data():
  """Builds intermediate data.

  Locates libcifpp's `components.cif`, then writes `ccd.pickle` (via
  `ccd_pickle_gen`) and `chemical_component_sets.pickle` (via
  `chemical_component_sets_gen`) into the
  `alphafold3.constants.converters` package resources.

  The CIF file is `$LIBCIFPP_DATA_DIR/components.cif` when that environment
  variable is set and non-empty. Otherwise it is the first existing
  `share/libcifpp/components.cif` under any of `site.getsitepackages()`.

  Raises:
    ValueError: If `components.cif` cannot be found.
  """
  libcifpp_data_dir = os.environ.get('LIBCIFPP_DATA_DIR')
  if libcifpp_data_dir:
    cif_path = pathlib.Path(libcifpp_data_dir) / 'components.cif'
    # Fail early with a clear message instead of deep inside the converter.
    if not cif_path.exists():
      raise ValueError(
          f'Could not find components.cif in LIBCIFPP_DATA_DIR'
          f' ({libcifpp_data_dir!r}). LIBCIFPP_DATA_DIR must be the directory'
          ' that contains components.cif.'
      )
  else:
    candidate_paths = (
        pathlib.Path(site_path) / 'share/libcifpp/components.cif'
        for site_path in site.getsitepackages()
    )
    cif_path = next((path for path in candidate_paths if path.exists()), None)
    if cif_path is None:
      raise ValueError(
          'Could not find components.cif. If libcifpp is installed in a'
          ' non-standard location, please set the LIBCIFPP_DATA_DIR environment'
          ' variable to the directory that contains components.cif (e.g.'
          ' <libcifpp prefix>/share/libcifpp).'
      )

  out_root = resources.files(alphafold3.constants.converters)
  ccd_pickle_path = out_root.joinpath('ccd.pickle')
  chemical_component_sets_pickle_path = out_root.joinpath(
      'chemical_component_sets.pickle'
  )
  # The converters take argv-style lists; element 0 stands in for argv[0].
  ccd_pickle_gen.main(['', str(cif_path), str(ccd_pickle_path)])
  chemical_component_sets_gen.main(
      ['', str(chemical_component_sets_pickle_path)]
  )
================================================
FILE: src/alphafold3/common/base_config.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Config for the protein folding model and experiment."""
from collections.abc import Mapping
import copy
import dataclasses
import types
import typing
from typing import Any, ClassVar, TypeVar
# Generic type variable (used, e.g., in `_clone_field`'s signature).
_T = TypeVar('_T')
# Type variable bound to `BaseConfig`. The bound is a forward-reference string
# because `BaseConfig` is defined later in this module.
_ConfigT = TypeVar('_ConfigT', bound='BaseConfig')
def _strip_optional(t: type[Any]) -> type[Any]:
"""Transforms type annotations of the form `T | None` to `T`."""
if typing.get_origin(t) in (typing.Union, types.UnionType):
args = set(typing.get_args(t)) - {types.NoneType}
if len(args) == 1:
return args.pop()
return t
# Sentinel for `_clone_field`'s `new_default`, meaning "keep the field as is"
# (a shallow copy is returned). Compared by identity.
_NO_UPDATE = object()
class _Autocreate:
def __init__(self, **defaults: Any):
self.defaults = defaults
def autocreate(**defaults: Any) -> Any:
"""Marks a field as having a default factory derived from its type."""
return _Autocreate(**defaults)
def _clone_field(
    field: dataclasses.Field[_T], new_default: _T
) -> dataclasses.Field[_T]:
  """Returns a copy of `field`, optionally with a new default value.

  Args:
    field: The dataclass field to clone.
    new_default: The new default value, or the `_NO_UPDATE` sentinel.

  Returns:
    A shallow copy of `field` if `new_default` is `_NO_UPDATE`. Otherwise a
    new keyword-only, init-enabled field whose default is `new_default` and
    whose `repr`, `hash`, `compare` and `metadata` are taken from `field`.
  """
  if new_default is _NO_UPDATE:
    return copy.copy(field)
  inherited = {
      attr: getattr(field, attr)
      for attr in ('repr', 'hash', 'compare', 'metadata')
  }
  return dataclasses.field(
      default=new_default, init=True, kw_only=True, **inherited
  )
# `dataclass_transform` tells static type checkers that classes created with
# this metaclass behave like (keyword-only) dataclasses.
@typing.dataclass_transform()
class ConfigMeta(type):
  """Metaclass that synthesizes a __post_init__ that coerces dicts to Config subclass instances.

  Every class created with this metaclass gets:
    * a `_coercable_fields` property listing its Config-typed fields, and
    * a `__post_init__` that converts `_Autocreate` markers and mappings held
      by those fields into instances of the field's Config type, then calls
      any previously defined `__post_init__`.
  The class is finally turned into a `dataclasses.dataclass(kw_only=True)`.
  """

  def __new__(mcs, name, bases, classdict):
    cls = super().__new__(mcs, name, bases, classdict)

    def _coercable_fields(self) -> Mapping[str, tuple[ConfigMeta, Any]]:
      # Maps field name -> (field type with `| None` stripped, declared
      # default) for every field whose type is itself a ConfigMeta class.
      # Use get_type_hints instead of Field.type to ensure that forward
      # references are resolved. Note: recomputed on every access.
      type_hints = typing.get_type_hints(self.__class__)
      fields = dataclasses.fields(self.__class__)
      field_to_type_and_default = {
          field.name: (_strip_optional(type_hints[field.name]), field.default)
          for field in fields
      }
      coercable_fields = {
          f: t
          for f, t in field_to_type_and_default.items()
          if issubclass(type(t[0]), ConfigMeta)
      }
      return coercable_fields

    cls._coercable_fields = property(_coercable_fields)

    # A __post_init__ defined on the class or inherited from a base (which may
    # itself be a previously synthesized one); it is chained after coercion.
    old_post_init = getattr(cls, '__post_init__', None)

    def _post_init(self) -> None:
      # Coerce each Config-typed field value; None values are left untouched.
      for field_name, (
          field_type,
          field_default,
      ) in self._coercable_fields.items():  # pylint: disable=protected-access
        field_value = getattr(self, field_name)
        if field_value is None:
          continue
        try:
          match field_value:
            case _Autocreate():
              # Construct from field defaults.
              setattr(self, field_name, field_type(**field_value.defaults))
            case Mapping():
              # Field value is not yet a `Config` instance; Assume we can create
              # one by splatting keys and values.
              args = {}
              # Apply default args first, if present.
              if isinstance(field_default, _Autocreate):
                args.update(field_default.defaults)
              # Explicitly given values override the autocreate defaults.
              args.update(field_value)
              setattr(self, field_name, field_type(**args))
            case _:
              # Already an instance (or another non-mapping value): keep it.
              pass
        except TypeError as e:
          raise TypeError(
              f'Failure while coercing field {field_name!r} of'
              f' {self.__class__.__qualname__}'
          ) from e
      if old_post_init:
        old_post_init(self)

    # Installed before `dataclass()` runs so the generated __init__ calls it.
    cls.__post_init__ = _post_init

    return dataclasses.dataclass(kw_only=True)(cls)
class BaseConfig(metaclass=ConfigMeta):
  """Config base class.

  Through `ConfigMeta`, every subclass becomes a keyword-only dataclass whose
  `__post_init__` converts mapping values of Config-typed fields into
  instances of the declared Config subclass.
  """

  # Set when `ConfigMeta` applies `dataclasses.dataclass` to the class; declared
  # here for type checkers.
  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]

  # Placeholder for type checkers; `ConfigMeta` replaces this property on every
  # class it creates.
  @property
  def _coercable_fields(self) -> Mapping[str, tuple[type['BaseConfig'], Any]]:
    return {}

  def as_dict(self) -> Mapping[str, Any]:
    """Returns the config as a dict.

    Starts from `dataclasses.asdict` and then re-serialises every nested
    Config-typed field through that field's own `as_dict`.
    """
    serialised = dataclasses.asdict(self)
    for name in self._coercable_fields:
      nested = getattr(self, name, None)
      if not isinstance(nested, BaseConfig):
        continue
      serialised[name] = nested.as_dict()
    return serialised
================================================
FILE: src/alphafold3/common/folding_input.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Model input dataclass."""
from collections.abc import Collection, Iterator, Mapping, Sequence
import dataclasses
import gzip
import json
import logging
import lzma
import os
import pathlib
import random
import re
import string
from typing import Any, Final, Self, TypeAlias, cast
from alphafold3 import structure
from alphafold3.constants import chemical_components
from alphafold3.constants import mmcif_names
from alphafold3.constants import residue_names
from alphafold3.cpp import cif_dict
from alphafold3.structure import mmcif as mmcif_lib
import rdkit.Chem as rd_chem
import zstandard as zstd
# Identifies one atom of a bond. NOTE(review): presumably
# (chain ID, residue index, atom name) — confirm against the bond parsing code.
BondAtomId: TypeAlias = tuple[str, int, str]

# Dialect name and known versions of the native AlphaFold 3 input JSON.
JSON_DIALECT: Final[str] = 'alphafold3'
JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3, 4)
# The last (highest) entry of JSON_VERSIONS.
JSON_VERSION: Final[int] = JSON_VERSIONS[-1]

# Dialect name and version of the AlphaFold Server input JSON.
ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver'
ALPHAFOLDSERVER_JSON_VERSION: Final[int] = 1
def _validate_keys(actual: Collection[str], expected: Collection[str]):
"""Validates that the JSON doesn't contain any extra unwanted keys."""
if bad_keys := set(actual) - set(expected):
raise ValueError(f'Unexpected JSON keys in: {", ".join(sorted(bad_keys))}')
def _read_file(path: pathlib.Path, json_path: pathlib.Path | None) -> str:
"""Reads a maybe compressed (gzip, xz, zstd) file from the given path.
Args:
path: The path to the file to read. This can be either absolute path, or a
path relative to the JSON file path.
json_path: The path to the JSON file. If None, the path must be absolute.
Returns:
The contents of the file.
"""
if not path.is_absolute():
if json_path is None:
raise ValueError('json_path must be specified if path is not absolute.')
path = (json_path.parent / path).resolve()
with open(path, 'rb') as f:
first_six_bytes = f.read(6)
f.seek(0)
# Detect the compression type using the magic number in the header.
if first_six_bytes[:2] == b'\x1f\x8b':
with gzip.open(f, 'rt') as gzip_f:
return cast(str, gzip_f.read())
elif first_six_bytes == b'\xfd\x37\x7a\x58\x5a\x00':
with lzma.open(f, 'rt') as xz_f:
return cast(str, xz_f.read())
elif first_six_bytes[:4] == b'\x28\xb5\x2f\xfd':
with zstd.open(f, 'rt') as zstd_f:
return cast(str, zstd_f.read())
else:
return f.read().decode('utf-8')
class Template:
  """Structural template input."""

  __slots__ = ('_mmcif', '_query_to_template')

  def __init__(self, *, mmcif: str, query_to_template_map: Mapping[int, int]):
    """Initializes the template.

    Args:
      mmcif: The structural template in mmCIF format. The mmCIF should have only
        one protein chain.
      query_to_template_map: A mapping from query residue index to template
        residue index.
    """
    self._mmcif = mmcif
    # Stored as a tuple of (query, template) pairs so the class is hashable.
    self._query_to_template = tuple(query_to_template_map.items())

  @property
  def query_to_template_map(self) -> Mapping[int, int]:
    """A new dict mapping query residue index to template residue index."""
    return dict(self._query_to_template)

  @property
  def mmcif(self) -> str:
    """The template structure in mmCIF format."""
    return self._mmcif

  def __hash__(self) -> int:
    # Pairs are sorted so the hash ignores the mapping's insertion order,
    # consistent with __eq__.
    return hash((self._mmcif, tuple(sorted(self._query_to_template))))

  def __eq__(self, other: object) -> bool:
    # Let Python fall back to its default handling for foreign types instead of
    # raising AttributeError on the slot access below.
    if not isinstance(other, Template):
      return NotImplemented
    mmcifs_equal = self._mmcif == other._mmcif
    maps_equal = sorted(self._query_to_template) == sorted(
        other._query_to_template
    )
    return mmcifs_equal and maps_equal
class ProteinChain:
  """Protein chain input.

  The ptms and templates containers are stored as tuples so that instances are
  hashable and can be deduplicated.
  """

  __slots__ = (
      '_id',
      '_sequence',
      '_ptms',
      '_description',
      '_paired_msa',
      '_unpaired_msa',
      '_templates',
  )

  def __init__(
      self,
      *,
      id: str,  # pylint: disable=redefined-builtin
      sequence: str,
      ptms: Sequence[tuple[str, int]],
      description: str | None = None,
      paired_msa: str | None = None,
      unpaired_msa: str | None = None,
      templates: Sequence[Template] | None = None,
  ):
    """Initializes a single protein chain input.

    Args:
      id: Unique protein chain identifier.
      sequence: The amino acid sequence of the chain.
      ptms: A list of tuples containing the post-translational modification type
        and the (1-based) residue index where the modification is applied.
      description: An optional textual description of the protein chain.
      paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not
        deduplicated and will be used to compute paired features. If None, this
        field is unset and must be filled in by the data pipeline before
        featurisation. If set to an empty string, it will be treated as a custom
        MSA with no sequences.
      unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be
        deduplicated and used to compute unpaired features. If None, this field
        is unset and must be filled in by the data pipeline before
        featurisation. If set to an empty string, it will be treated as a custom
        MSA with no sequences.
      templates: A list of structural templates for this chain. If None, this
        field is unset and must be filled in by the data pipeline before
        featurisation. The list can be empty or contain up to 20 templates.

    Raises:
      ValueError: If the sequence contains non-letter characters, a PTM index
        is outside [1, len(sequence)], or a PTM type has the "CCD_" prefix.
    """
    if not all(res.isalpha() for res in sequence):
      raise ValueError(f'Protein must contain only letters, got "{sequence}"')
    if any(not 0 < mod[1] <= len(sequence) for mod in ptms):
      raise ValueError(f'Invalid protein modification index: {ptms}')
    if any(mod[0].startswith('CCD_') for mod in ptms):
      raise ValueError(
          f'Protein ptms must not contain the "CCD_" prefix, got {ptms}'
      )
    # Use hashable containers for ptms and templates.
    self._id = id
    self._sequence = sequence
    self._ptms = tuple(ptms)
    self._description = description
    self._paired_msa = paired_msa
    self._unpaired_msa = unpaired_msa
    self._templates = tuple(templates) if templates is not None else None

  @property
  def id(self) -> str:
    return self._id

  @property
  def sequence(self) -> str:
    """Returns a single-letter sequence, taking modifications into account.

    Uses 'X' for all unknown residues.
    """
    return ''.join([
        residue_names.letters_three_to_one(r, default='X')
        for r in self.to_ccd_sequence()
    ])

  @property
  def ptms(self) -> Sequence[tuple[str, int]]:
    return self._ptms

  @property
  def description(self) -> str | None:
    return self._description

  @property
  def paired_msa(self) -> str | None:
    return self._paired_msa

  @property
  def unpaired_msa(self) -> str | None:
    return self._unpaired_msa

  @property
  def templates(self) -> Sequence[Template] | None:
    return self._templates

  def __len__(self) -> int:
    return len(self._sequence)

  def __eq__(self, other: object) -> bool:
    # Other chain types lack some of these slots (e.g. RnaChain has no `_ptms`),
    # so defer to Python's default comparison instead of raising.
    if not isinstance(other, ProteinChain):
      return NotImplemented
    return (
        self._id == other._id
        and self._sequence == other._sequence
        and self._ptms == other._ptms
        and self._description == other._description
        and self._paired_msa == other._paired_msa
        and self._unpaired_msa == other._unpaired_msa
        and self._templates == other._templates
    )

  def __hash__(self) -> int:
    return hash((
        self._id,
        self._sequence,
        self._ptms,
        self._description,
        self._paired_msa,
        self._unpaired_msa,
        self._templates,
    ))

  def hash_without_id(self) -> int:
    """Returns a hash ignoring the ID - useful for deduplication."""
    return hash((
        self._sequence,
        self._ptms,
        self._description,
        self._paired_msa,
        self._unpaired_msa,
        self._templates,
    ))

  @classmethod
  def from_alphafoldserver_dict(
      cls, json_dict: Mapping[str, Any], seq_id: str
  ) -> Self:
    """Constructs ProteinChain from the AlphaFoldServer JSON dict.

    Raises:
      ValueError: If `glycans` or `maxTemplateDate` is specified.
    """
    _validate_keys(
        json_dict.keys(),
        {
            'sequence',
            'glycans',
            'modifications',
            'count',
            'maxTemplateDate',
            'useStructureTemplate',
        },
    )
    sequence = json_dict['sequence']

    if 'glycans' in json_dict:
      raise ValueError(
          f'Specifying glycans in the `{ALPHAFOLDSERVER_JSON_DIALECT}` format'
          ' is not supported.'
      )

    if 'maxTemplateDate' in json_dict:
      raise ValueError(
          f'Specifying maxTemplateDate in the `{ALPHAFOLDSERVER_JSON_DIALECT}`'
          ' format is not supported, use the --max_template_date flag instead.'
      )

    templates = None  # Search for templates unless explicitly disabled.
    if not json_dict.get('useStructureTemplate', True):
      templates = []  # Do not use any templates.

    ptms = [
        (mod['ptmType'].removeprefix('CCD_'), mod['ptmPosition'])
        for mod in json_dict.get('modifications', [])
    ]
    return cls(id=seq_id, sequence=sequence, ptms=ptms, templates=templates)

  @classmethod
  def from_dict(
      cls,
      json_dict: Mapping[str, Any],
      json_path: pathlib.Path | None = None,
      seq_id: str | None = None,
  ) -> Self:
    """Constructs ProteinChain from the AlphaFold JSON dict.

    Args:
      json_dict: A dict with a single `protein` entry holding the chain fields.
      json_path: Path of the JSON file, used to resolve relative MSA/template
        paths.
      seq_id: If given, overrides the `id` field of the JSON.

    Returns:
      The parsed protein chain.

    Raises:
      ValueError: On conflicting inline/path fields, file paths given in the
        inline fields, or malformed templates.
    """
    json_dict = json_dict['protein']
    _validate_keys(
        json_dict.keys(),
        {
            'id',
            'sequence',
            'modifications',
            'description',
            'unpairedMsa',
            'unpairedMsaPath',
            'pairedMsa',
            'pairedMsaPath',
            'templates',
        },
    )

    sequence = json_dict['sequence']
    ptms = [
        (mod['ptmType'], mod['ptmPosition'])
        for mod in json_dict.get('modifications', [])
    ]

    unpaired_msa = json_dict.get('unpairedMsa', None)
    unpaired_msa_path = json_dict.get('unpairedMsaPath', None)
    if unpaired_msa and unpaired_msa_path:
      raise ValueError('Only one of unpairedMsa/unpairedMsaPath can be set.')
    # A short string naming an existing file was most likely meant as a path.
    if (
        unpaired_msa
        and len(unpaired_msa) < 256
        and os.path.exists(unpaired_msa)
    ):
      raise ValueError(
          'Set the unpaired MSA path using the "unpairedMsaPath" field.'
      )
    elif unpaired_msa_path:
      unpaired_msa = _read_file(pathlib.Path(unpaired_msa_path), json_path)

    paired_msa = json_dict.get('pairedMsa', None)
    paired_msa_path = json_dict.get('pairedMsaPath', None)
    if paired_msa and paired_msa_path:
      raise ValueError('Only one of pairedMsa/pairedMsaPath can be set.')
    if paired_msa and len(paired_msa) < 256 and os.path.exists(paired_msa):
      raise ValueError(
          'Set the paired MSA path using the "pairedMsaPath" field.'
      )
    elif paired_msa_path:
      paired_msa = _read_file(pathlib.Path(paired_msa_path), json_path)

    raw_templates = json_dict.get('templates', None)
    if raw_templates is None:
      templates = None
    else:
      templates = [
          cls._template_from_dict(raw_template, json_path)
          for raw_template in raw_templates
      ]

    return cls(
        id=seq_id or json_dict['id'],
        sequence=sequence,
        ptms=ptms,
        description=json_dict.get('description', None),
        paired_msa=paired_msa,
        unpaired_msa=unpaired_msa,
        templates=templates,
    )

  @staticmethod
  def _template_from_dict(
      raw_template: Mapping[str, Any], json_path: pathlib.Path | None
  ) -> Template:
    """Parses one entry of a protein's `templates` JSON list into a Template."""
    _validate_keys(
        raw_template.keys(),
        {'mmcif', 'mmcifPath', 'queryIndices', 'templateIndices'},
    )
    mmcif = raw_template.get('mmcif', None)
    mmcif_path = raw_template.get('mmcifPath', None)
    if mmcif and mmcif_path:
      raise ValueError('Only one of mmcif/mmcifPath can be set.')
    if mmcif and len(mmcif) < 256 and os.path.exists(mmcif):
      raise ValueError('Set the template path using the "mmcifPath" field.')
    if mmcif_path:
      mmcif = _read_file(pathlib.Path(mmcif_path), json_path)
    if mmcif is None:
      raise ValueError('Template must have one of mmcif/mmcifPath set.')

    # to_dict() emits `templateIndices: None` for an empty mapping, so treat a
    # None index list as empty to keep the JSON round trip working.
    query_indices = raw_template['queryIndices'] or []
    template_indices = raw_template['templateIndices'] or []
    # zip() would silently drop the unmatched tail of the longer list.
    if len(query_indices) != len(template_indices):
      raise ValueError(
          'Template queryIndices and templateIndices must have the same'
          f' length, got {len(query_indices)} and {len(template_indices)}.'
      )
    query_to_template_map = dict(zip(query_indices, template_indices))
    return Template(mmcif=mmcif, query_to_template_map=query_to_template_map)

  def to_dict(
      self, seq_id: str | Sequence[str] | None = None
  ) -> Mapping[str, Mapping[str, Any]]:
    """Converts ProteinChain to an AlphaFold JSON dict."""
    if self._templates is None:
      templates = None
    else:
      templates = [
          {
              'mmcif': template.mmcif,
              'queryIndices': list(template.query_to_template_map.keys()),
              # An empty mapping is written as None; from_dict reads it back
              # as an empty list.
              'templateIndices': (
                  list(template.query_to_template_map.values()) or None
              ),
          }
          for template in self._templates
      ]
    contents = {
        'id': seq_id or self._id,
        'sequence': self._sequence,
        'modifications': [
            {'ptmType': ptm[0], 'ptmPosition': ptm[1]} for ptm in self._ptms
        ],
        'unpairedMsa': self._unpaired_msa,
        'pairedMsa': self._paired_msa,
        'templates': templates,
    }
    if self._description is not None:
      contents['description'] = self._description
    return {'protein': contents}

  def to_ccd_sequence(self) -> Sequence[str]:
    """Converts to a sequence of CCD codes, applying PTMs at their positions."""
    ccd_coded_seq = [
        residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(res, residue_names.UNK)
        for res in self._sequence
    ]
    for ptm_code, ptm_index in self._ptms:
      ccd_coded_seq[ptm_index - 1] = ptm_code
    return ccd_coded_seq

  def fill_missing_fields(self) -> Self:
    """Fill missing MSA and template fields with default values."""
    return ProteinChain(
        id=self.id,
        sequence=self._sequence,
        ptms=self._ptms,
        description=self._description,
        unpaired_msa=self._unpaired_msa or '',
        paired_msa=self._paired_msa or '',
        templates=self._templates or [],
    )
class RnaChain:
"""RNA chain input."""
__slots__ = (
'_id',
'_sequence',
'_modifications',
'_description',
'_unpaired_msa',
)
def __init__(
self,
*,
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
unpaired_msa: str | None = None,
):
"""Initializes a single strand RNA chain input.
Args:
id: Unique RNA chain identifier.
sequence: The RNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the RNA chain.
unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be
deduplicated and used to compute unpaired features. If None, this field
is unset and must be filled in by the data pipeline before
featurisation. If set to an empty string, it will be treated as a custom
MSA with no sequences.
"""
if not all(res.isalpha() for res in sequence):
raise ValueError(f'RNA must contain only letters, got "{sequence}"')
if any(not 0 < mod[1] <= len(sequence) for mod in modifications):
raise ValueError(f'Invalid RNA modification index: {modifications}')
if any(mod[0].startswith('CCD_') for mod in modifications):
raise ValueError(
'RNA modifications must not contain the "CCD_" prefix, got'
f' {modifications}'
)
self._id = id
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
self._unpaired_msa = unpaired_msa
@property
def id(self) -> str:
return self._id
@property
def sequence(self) -> str:
"""Returns a single-letter sequence, taking modifications into account.
Uses 'N' for all unknown residues.
"""
return ''.join([
residue_names.letters_three_to_one(r, default='N')
for r in self.to_ccd_sequence()
])
@property
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
@property
def description(self) -> str | None:
return self._description
@property
def unpaired_msa(self) -> str | None:
return self._unpaired_msa
def __len__(self) -> int:
return len(self._sequence)
def __eq__(self, other: Self) -> bool:
return (
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
and self._unpaired_msa == other._unpaired_msa
)
def __hash__(self) -> int:
return hash((
self._id,
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
@classmethod
def from_alphafoldserver_dict(
cls, json_dict: Mapping[str, Any], seq_id: str
) -> Self:
"""Constructs RnaChain from the AlphaFoldServer JSON dict."""
_validate_keys(json_dict.keys(), {'sequence', 'modifications', 'count'})
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'].removeprefix('CCD_'), mod['basePosition'])
for mod in json_dict.get('modifications', [])
]
return cls(id=seq_id, sequence=sequence, modifications=modifications)
@classmethod
def from_dict(
cls,
json_dict: Mapping[str, Any],
json_path: pathlib.Path | None = None,
seq_id: str | None = None,
) -> Self:
"""Constructs RnaChain from the AlphaFold JSON dict."""
json_dict = json_dict['rna']
_validate_keys(
json_dict.keys(),
{
'id',
'sequence',
'modifications',
'description',
'unpairedMsa',
'unpairedMsaPath',
},
)
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'], mod['basePosition'])
for mod in json_dict.get('modifications', [])
]
unpaired_msa = json_dict.get('unpairedMsa', None)
unpaired_msa_path = json_dict.get('unpairedMsaPath', None)
if unpaired_msa and unpaired_msa_path:
raise ValueError('Only one of unpairedMsa/unpairedMsaPath can be set.')
if (
unpaired_msa
and len(unpaired_msa) < 256
and os.path.exists(unpaired_msa)
):
raise ValueError(
'Set the unpaired MSA path using the "unpairedMsaPath" field.'
)
elif unpaired_msa_path:
unpaired_msa = _read_file(pathlib.Path(unpaired_msa_path), json_path)
return cls(
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
unpaired_msa=unpaired_msa,
)
def to_dict(
self, seq_id: str | Sequence[str] | None = None
) -> Mapping[str, Mapping[str, Any]]:
"""Converts RnaChain to an AlphaFold JSON dict."""
contents = {
'id': seq_id or self._id,
'sequence': self._sequence,
'modifications': [
{'modificationType': mod[0], 'basePosition': mod[1]}
for mod in self._modifications
],
'unpairedMsa': self._unpaired_msa,
}
if self._description is not None:
contents['description'] = self._description
return {'rna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
"""Converts to a sequence of CCD codes."""
mapping = {r: r for r in residue_names.RNA_TYPES} # Same 1-letter and CCD.
ccd_coded_seq = [
mapping.get(res, residue_names.UNK_RNA) for res in self._sequence
]
for ccd_code, modification_index in self._modifications:
ccd_coded_seq[modification_index - 1] = ccd_code
return ccd_coded_seq
def fill_missing_fields(self) -> Self:
"""Fill missing MSA fields with default values."""
return RnaChain(
id=self.id,
sequence=self.sequence,
modifications=self.modifications,
unpaired_msa=self._unpaired_msa or '',
)
class DnaChain:
"""Single strand DNA chain input."""
__slots__ = ('_id', '_sequence', '_modifications', '_description')
def __init__(
self,
*,
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
):
"""Initializes a single strand DNA chain input.
Args:
id: Unique DNA chain identifier.
sequence: The DNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the DNA chain.
"""
if not all(res.isalpha() for res in sequence):
raise ValueError(f'DNA must contain only letters, got "{sequence}"')
if any(not 0 < mod[1] <= len(sequence) for mod in modifications):
raise ValueError(f'Invalid DNA modification index: {modifications}')
if any(mod[0].startswith('CCD_') for mod in modifications):
raise ValueError(
'DNA modifications must not contain the "CCD_" prefix, got'
f' {modifications}'
)
self._id = id
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
@property
def id(self) -> str:
return self._id
@property
def sequence(self) -> str:
"""Returns a single-letter sequence, taking modifications into account.
Uses 'N' for all unknown residues.
"""
return ''.join([
residue_names.letters_three_to_one(r, default='N')
for r in self.to_ccd_sequence()
])
@property
def description(self) -> str | None:
return self._description
def __len__(self) -> int:
return len(self._sequence)
def __eq__(self, other: Self) -> bool:
return (
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
)
def __hash__(self) -> int:
return hash(
(self._id, self._sequence, self._modifications, self._description)
)
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self._sequence, self._modifications, self._description))
@classmethod
def from_alphafoldserver_dict(
cls, json_dict: Mapping[str, Any], seq_id: str
) -> Self:
"""Constructs DnaChain from the AlphaFoldServer JSON dict."""
_validate_keys(json_dict.keys(), {'sequence', 'modifications', 'count'})
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'].removeprefix('CCD_'), mod['basePosition'])
for mod in json_dict.get('modifications', [])
]
return cls(id=seq_id, sequence=sequence, modifications=modifications)
@classmethod
def from_dict(
cls, json_dict: Mapping[str, Any], seq_id: str | None = None
) -> Self:
"""Constructs DnaChain from the AlphaFold JSON dict."""
json_dict = json_dict['dna']
_validate_keys(
json_dict.keys(), {'id', 'sequence', 'modifications', 'description'}
)
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'], mod['basePosition'])
for mod in json_dict.get('modifications', [])
]
return cls(
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
)
def to_dict(
self, seq_id: str | Sequence[str] | None = None
) -> Mapping[str, Mapping[str, Any]]:
"""Converts DnaChain to an AlphaFold JSON dict."""
contents = {
'id': seq_id or self._id,
'sequence': self._sequence,
'modifications': [
{'modificationType': mod[0], 'basePosition': mod[1]}
for mod in self._modifications
],
}
if self._description is not None:
contents['description'] = self._description
return {'dna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
"""Converts to a sequence of CCD codes."""
ccd_coded_seq = [
residue_names.DNA_COMMON_ONE_TO_TWO.get(res, residue_names.UNK_DNA)
for res in self._sequence
]
for ccd_code, modification_index in self._modifications:
ccd_coded_seq[modification_index - 1] = ccd_code
return ccd_coded_seq
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Ligand:
"""Ligand input.
Attributes:
id: Unique ligand "chain" identifier.
ccd_ids: The Chemical Component Dictionary or user-defined CCD IDs of the
chemical components of the ligand. Typically, this is just a single ID,
but some ligands are composed of multiple components. If that is the case,
a bond linking these components should be added to the bonded_atom_pairs
Input field.
smiles: The SMILES representation of the ligand.
description: An optional textual description of the ligand.
"""
id: str
ccd_ids: Sequence[str] | None = None
smiles: str | None = None
description: str | None = None
def __post_init__(self):
if (self.ccd_ids is None) == (self.smiles is None):
raise ValueError('Ligand must have one of CCD ID or SMILES set.')
if self.smiles is not None:
mol = rd_chem.MolFromSmiles(self.smiles)
if not mol:
raise ValueError(f'Unable to make RDKit Mol from SMILES: {self.smiles}')
# Use hashable types for ccd_ids.
if self.ccd_ids is not None:
object.__setattr__(self, 'ccd_ids', tuple(self.ccd_ids))
def __len__(self) -> int:
if self.ccd_ids is not None:
return len(self.ccd_ids)
else:
return 1
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self.ccd_ids, self.smiles, self.description))
@classmethod
def from_alphafoldserver_dict(
cls, json_dict: Mapping[str, Any], seq_id: str
) -> Self:
"""Constructs Ligand from the AlphaFoldServer JSON dict."""
# Ligand can be specified either as a ligand, or ion (special-case).
_validate_keys(json_dict.keys(), {'ligand', 'ion', 'count'})
if 'ligand' in json_dict:
return cls(id=seq_id, ccd_ids=[json_dict['ligand'].removeprefix('CCD_')])
elif 'ion' in json_dict:
return cls(id=seq_id, ccd_ids=[json_dict['ion']])
else:
raise ValueError(f'Unknown ligand type: {json_dict}')
@classmethod
def from_dict(
cls, json_dict: Mapping[str, Any], seq_id: str | None = None
) -> Self:
"""Constructs Ligand from the AlphaFold JSON dict."""
json_dict = json_dict['ligand']
_validate_keys(
json_dict.keys(), {'id', 'ccdCodes', 'smiles', 'description'}
)
if json_dict.get('ccdCodes') and json_dict.get('smiles'):
raise ValueError(
'Ligand cannot have both CCD code and SMILES set at the same time, '
f'got CCD: {json_dict["ccdCodes"]} and SMILES: {json_dict["smiles"]}'
)
if 'ccdCodes' in json_dict:
ccd_codes = json_dict['ccdCodes']
if not isinstance(ccd_codes, (list, tuple)):
raise ValueError(
'CCD codes must be a list of strings, got '
f'{type(ccd_codes).__name__} instead: {ccd_codes}'
)
return cls(
id=seq_id or json_dict['id'],
ccd_ids=ccd_codes,
description=json_dict.get('description', None),
)
elif 'smiles' in json_dict:
return cls(
id=seq_id or json_dict['id'],
smiles=json_dict['smiles'],
description=json_dict.get('description', None),
)
else:
raise ValueError(f'Unknown ligand type: {json_dict}')
def to_dict(
self, seq_id: str | Sequence[str] | None = None
) -> Mapping[str, Mapping[str, Any]]:
"""Converts Ligand to an AlphaFold JSON dict."""
contents = {'id': seq_id or self.id}
if self.ccd_ids is not None:
contents['ccdCodes'] = self.ccd_ids
if self.smiles is not None:
contents['smiles'] = self.smiles
if self.description is not None:
contents['description'] = self.description
return {'ligand': contents}
def _sample_rng_seed() -> int:
"""Sample a random seed for AlphaFoldServer job."""
# See https://alphafoldserver.com/faq#what-are-seeds-and-how-are-they-set.
return random.randint(0, 2**32 - 1)
def _validate_user_ccd_keys(keys: Sequence[str], component_name: str) -> None:
"""Validates the keys of the user-defined CCD dictionary."""
mandatory_keys = (
'_chem_comp.id',
'_chem_comp.name',
'_chem_comp.type',
'_chem_comp.formula',
'_chem_comp.mon_nstd_parent_comp_id',
'_chem_comp.pdbx_synonyms',
'_chem_comp.formula_weight',
'_chem_comp_atom.comp_id',
'_chem_comp_atom.atom_id',
'_chem_comp_atom.type_symbol',
'_chem_comp_atom.charge',
'_chem_comp_atom.pdbx_model_Cartn_x_ideal',
'_chem_comp_atom.pdbx_model_Cartn_y_ideal',
'_chem_comp_atom.pdbx_model_Cartn_z_ideal',
'_chem_comp_bond.atom_id_1',
'_chem_comp_bond.atom_id_2',
'_chem_comp_bond.value_order',
'_chem_comp_bond.pdbx_aromatic_flag',
)
if missing_keys := set(mandatory_keys) - set(keys):
raise ValueError(
f'Component {component_name} in the user-defined CCD is missing these'
f' keys: {missing_keys}'
)
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Input:
"""AlphaFold input.
Attributes:
name: The name of the target.
chains: Protein chains, RNA chains, DNA chains, or ligands.
protein_chains: Protein chains.
rna_chains: RNA chains.
dna_chains: Single strand DNA chains.
ligands: Ligand (including ion) inputs.
rng_seeds: Random number generator seeds, one for each model execution.
bonded_atom_pairs: A list of tuples of atoms that are bonded to each other.
Each atom is defined by a tuple of (chain_id, res_id, atom_name). Chain
IDs must be set if there are any bonded atoms. Residue IDs are 1-indexed.
Atoms in ligands defined by SMILES can't be bonded since SMILES doesn't
define unique atom names.
user_ccd: Optional user-defined chemical component dictionary in the CIF
format. This can be used to provide additional CCD entries that are not
present in the default CCD and thus define arbitrary new ligands. This is
more expressive than SMILES since it allows to name all atoms within the
ligand which in turn makes it possible to define bonds using those atoms.
"""
name: str
chains: Sequence[ProteinChain | RnaChain | DnaChain | Ligand]
rng_seeds: Sequence[int]
bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]] | None = None
user_ccd: str | None = None
def __post_init__(self):
  """Validates the input and converts containers to hashable tuples.

  Raises:
    ValueError: If there is no RNG seed, the name has no usable characters,
      a chain ID is not made of upper case letters, chain IDs repeat, or a
      user-defined CCD component lacks mandatory items.
  """
  if not self.rng_seeds:
    raise ValueError('Input must have at least one RNG seed.')

  if not self.name.strip() or not self.sanitised_name():
    raise ValueError(
        'Input name must be non-empty and contain at least one valid'
        ' character (letters, numbers, dots, dashes, underscores).'
    )

  chain_ids = [c.id for c in self.chains]
  # Require letters that are all upper case: `isupper()` is False for mixed
  # case IDs such as 'Ab', which an `islower()`-only check would let through.
  if any(not (c.id.isalpha() and c.id.isupper()) for c in self.chains):
    raise ValueError(f'IDs must be upper case letters, got: {chain_ids}')
  if len(set(chain_ids)) != len(chain_ids):
    raise ValueError('Input JSON contains sequences with duplicate IDs.')

  # Use hashable types for chains, rng_seeds, and bonded_atom_pairs.
  # object.__setattr__ bypasses the frozen-dataclass guard.
  object.__setattr__(self, 'chains', tuple(self.chains))
  object.__setattr__(self, 'rng_seeds', tuple(self.rng_seeds))
  if self.bonded_atom_pairs is not None:
    object.__setattr__(
        self, 'bonded_atom_pairs', tuple(self.bonded_atom_pairs)
    )

  if self.user_ccd is not None:
    for component_name, component_cif in cif_dict.parse_multi_data_cif(
        self.user_ccd
    ).items():
      _validate_user_ccd_keys(component_cif.keys(), component_name)
@property
def protein_chains(self) -> Sequence[ProteinChain]:
  """Returns the protein chains of this input, in input order."""
  return list(filter(lambda entity: isinstance(entity, ProteinChain), self.chains))
@property
def rna_chains(self) -> Sequence[RnaChain]:
  """Returns the RNA chains of this input, in input order."""
  selected = []
  for entity in self.chains:
    if isinstance(entity, RnaChain):
      selected.append(entity)
  return selected
@property
def dna_chains(self) -> Sequence[DnaChain]:
  """Returns the single strand DNA chains of this input, in input order."""
  return list(filter(lambda entity: isinstance(entity, DnaChain), self.chains))
@property
def ligands(self) -> Sequence[Ligand]:
  """Returns the ligand (including ion) entities of this input, in order."""
  found = []
  for entity in self.chains:
    if isinstance(entity, Ligand):
      found.append(entity)
  return found
def sanitised_name(self) -> str:
  """Returns the name reduced to characters that are safe in a filename.

  Spaces become underscores; any other character that is not an ASCII
  letter, digit, '_', '-' or '.' is dropped.
  """
  permitted = frozenset(string.ascii_letters + string.digits + '_-.')
  kept = []
  for char in self.name:
    if char == ' ':
      char = '_'
    if char in permitted:
      kept.append(char)
  return ''.join(kept)
@classmethod
def from_alphafoldserver_fold_job(cls, fold_job: Mapping[str, Any]) -> Self:
  """Constructs Input from an AlphaFoldServer fold job.

  Args:
    fold_job: The parsed AlphaFold Server fold job. `dialect` and `version`
      must either both be present or both be absent (defaults are assumed).

  Returns:
    The Input. Chains get sequential IDs, with each entry's `count` expanded
    into repeated chains. A random seed is sampled when `modelSeeds` is
    missing or empty.

  Raises:
    ValueError: On a dialect/version mismatch or an unknown sequence type.
  """
  _validate_keys(
      fold_job.keys(),
      {'name', 'modelSeeds', 'sequences', 'dialect', 'version'},
  )
  has_dialect = 'dialect' in fold_job
  has_version = 'version' in fold_job
  if has_dialect != has_version:
    raise ValueError(
        'AlphaFold Server input JSON must either contain both `dialect` and'
        ' `version` fields, or neither. If neither is specified, it is'
        f' assumed that `dialect="{ALPHAFOLDSERVER_JSON_DIALECT}"` and'
        f' `version="{ALPHAFOLDSERVER_JSON_VERSION}"`.'
    )
  if has_dialect:
    dialect, version = fold_job['dialect'], fold_job['version']
  else:
    dialect = ALPHAFOLDSERVER_JSON_DIALECT
    version = ALPHAFOLDSERVER_JSON_VERSION

  if dialect != ALPHAFOLDSERVER_JSON_DIALECT:
    raise ValueError(
        f'AlphaFold Server input JSON has unsupported dialect: {dialect}, '
        f'expected {ALPHAFOLDSERVER_JSON_DIALECT}.'
    )

  # For now, there is only one AlphaFold Server JSON version.
  if version != ALPHAFOLDSERVER_JSON_VERSION:
    raise ValueError(
        f'AlphaFold Server input JSON has unsupported version: {version}, '
        f'expected {ALPHAFOLDSERVER_JSON_VERSION}.'
    )

  # The tuple order sets precedence if an entry has several of these keys.
  parsers = (
      ('proteinChain', ProteinChain.from_alphafoldserver_dict),
      ('rnaSequence', RnaChain.from_alphafoldserver_dict),
      ('dnaSequence', DnaChain.from_alphafoldserver_dict),
      ('ion', Ligand.from_alphafoldserver_dict),
      ('ligand', Ligand.from_alphafoldserver_dict),
  )
  chains = []
  for entry in fold_job['sequences']:
    for entity_key, parse in parsers:
      if entity_key in entry:
        break
    else:
      raise ValueError(f'Unknown sequence type: {entry}')
    spec = entry[entity_key]
    for _ in range(spec.get('count', 1)):
      new_id = mmcif_lib.int_id_to_str_id(len(chains) + 1)
      chains.append(parse(spec, seq_id=new_id))

  requested_seeds = fold_job.get('modelSeeds')
  if requested_seeds:
    rng_seeds = [int(seed) for seed in requested_seeds]
  else:
    rng_seeds = [_sample_rng_seed()]

  return cls(name=fold_job['name'], chains=chains, rng_seeds=rng_seeds)
@classmethod
def from_json(
    cls, json_str: str, json_path: pathlib.Path | None = None
) -> Self:
  """Loads the input from the AlphaFold JSON string.

  Args:
    json_str: The AlphaFold 3 input JSON document.
    json_path: Path of the JSON file, used to resolve relative MSA, template
      and user CCD paths. If None, such paths must be absolute.

  Returns:
    The parsed Input.

  Raises:
    ValueError: If the JSON does not follow the AlphaFold 3 dialect. Examples
      are an unsupported dialect or version, missing sequences or seeds, empty
      sequence entries, unset chain IDs, or invalid/duplicate bonded atom
      pairs.
  """
  raw_json = json.loads(json_str)
  _validate_keys(
      raw_json.keys(),
      {
          'dialect',
          'version',
          'name',
          'modelSeeds',
          'sequences',
          'bondedAtomPairs',
          'userCCD',
          'userCCDPath',
      },
  )
  if 'dialect' not in raw_json or 'version' not in raw_json:
    raise ValueError(
        'AlphaFold 3 input JSON must contain `dialect` and `version` fields.'
    )

  if raw_json['dialect'] != JSON_DIALECT:
    raise ValueError(
        'AlphaFold 3 input JSON has unsupported dialect:'
        f' {raw_json["dialect"]}, expected {JSON_DIALECT}.'
    )

  if raw_json['version'] not in JSON_VERSIONS:
    raise ValueError(
        'AlphaFold 3 input JSON has unsupported version:'
        f' {raw_json["version"]}, expected one of {JSON_VERSIONS}.'
    )

  if 'sequences' not in raw_json:
    raise ValueError('AlphaFold 3 input JSON does not contain any sequences.')

  if 'modelSeeds' not in raw_json or not raw_json['modelSeeds']:
    raise ValueError(
        'AlphaFold 3 input JSON must specify at least one rng seed in'
        ' `modelSeeds`.'
    )

  sequences = raw_json['sequences']

  # Reject empty entries before peeking at their single value: next(iter({}))
  # would otherwise leak a bare StopIteration out of this method.
  for sequence in sequences:
    if not sequence:
      raise ValueError(
          'AlphaFold 3 input JSON contains an empty sequence entry.'
      )

  # Make sure sequence IDs are all set.
  raw_sequence_ids = [next(iter(s.values())).get('id') for s in sequences]
  if not all(raw_sequence_ids):
    raise ValueError(
        'AlphaFold 3 input JSON contains sequences with unset IDs.'
    )
  # An ID is either a single string or a list of IDs (one chain per ID).
  sequence_ids = [
      sequence_id if isinstance(sequence_id, list) else [sequence_id]
      for sequence_id in raw_sequence_ids
  ]
  known_chain_ids = {
      seq_id for seq_ids in sequence_ids for seq_id in seq_ids
  }

  chains = []
  for seq_ids, sequence in zip(sequence_ids, sequences, strict=True):
    if len(sequence) != 1:
      raise ValueError(f'Chain {seq_ids} has more than 1 sequence.')
    for seq_id in seq_ids:
      if 'protein' in sequence:
        chains.append(ProteinChain.from_dict(sequence, json_path, seq_id))
      elif 'rna' in sequence:
        chains.append(RnaChain.from_dict(sequence, json_path, seq_id))
      elif 'dna' in sequence:
        chains.append(DnaChain.from_dict(sequence, seq_id=seq_id))
      elif 'ligand' in sequence:
        chains.append(Ligand.from_dict(sequence, seq_id=seq_id))
      else:
        raise ValueError(f'Unknown sequence type: {sequence}')

  smiles_ligand_ids = {
      c.id for c in chains if isinstance(c, Ligand) and c.smiles is not None
  }
  chain_lengths = {chain.id: len(chain) for chain in chains}

  def check_atom(atom, bond) -> None:
    """Raises unless `atom` is a (chain_id: str, res_id: int, atom_name: str)."""
    if (
        len(atom) != 3
        or not isinstance(atom[0], str)
        or not isinstance(atom[1], int)
        or not isinstance(atom[2], str)
    ):
      raise ValueError(
          f'Atom {atom} in bond {bond} must have 3 components: '
          '(chain_id: str, res_id: int, atom_name: str).'
      )

  bonded_atom_pairs = None
  if bonds := raw_json.get('bondedAtomPairs'):
    bonded_atom_pairs = []
    for bond in bonds:
      if len(bond) != 2:
        raise ValueError(f'Bond {bond} must have 2 atoms, got {len(bond)}.')
      bond_beg, bond_end = bond
      check_atom(bond_beg, bond)
      check_atom(bond_end, bond)
      if (
          bond_beg[0] not in known_chain_ids
          or bond_end[0] not in known_chain_ids
      ):
        raise ValueError(f'Invalid chain ID(s) in bond {bond}')
      # Residue IDs are 1-indexed.
      if (
          not 0 < bond_beg[1] <= chain_lengths[bond_beg[0]]
          or not 0 < bond_end[1] <= chain_lengths[bond_end[0]]
      ):
        raise ValueError(f'Invalid residue ID(s) in bond {bond}')
      # SMILES ligands have no unique atom names, so they cannot be bonded.
      for atom in (bond_beg, bond_end):
        if atom[0] in smiles_ligand_ids:
          raise ValueError(
              f'Bond {bond} involves an unsupported SMILES ligand {atom[0]}'
          )
      bonded_atom_pairs.append((tuple(bond_beg), tuple(bond_end)))
    if len(bonded_atom_pairs) != len(set(bonded_atom_pairs)):
      raise ValueError(f'Bonds are not unique: {bonded_atom_pairs}')

  user_ccd = raw_json.get('userCCD')
  user_ccd_path = raw_json.get('userCCDPath')
  if user_ccd and user_ccd_path:
    raise ValueError('Only one of userCCD/userCCDPath can be set.')
  # A short string naming an existing file was most likely meant as a path.
  if user_ccd and len(user_ccd) < 256 and os.path.exists(user_ccd):
    raise ValueError('Set the user CCD path using the "userCCDPath" field.')
  elif user_ccd_path:
    user_ccd = _read_file(pathlib.Path(user_ccd_path), json_path)

  return cls(
      name=raw_json['name'],
      chains=chains,
      rng_seeds=[int(seed) for seed in raw_json['modelSeeds']],
      bonded_atom_pairs=bonded_atom_pairs,
      user_ccd=user_ccd,
  )
@classmethod
def from_mmcif(cls, mmcif_str: str, ccd: chemical_components.Ccd) -> Self:
  """Loads the input from an mmCIF string.

  WARNING: Since rng seeds are not stored in mmCIFs, an rng seed is sampled
  in the returned `Input`.

  Non-polymer chains become ligands. If every residue is in `ccd`, the ligand
  uses CCD IDs. Otherwise it must be single-component, and it uses the SMILES
  from the mmCIF chemical components data.

  Protein, RNA and DNA chains record as modifications each residue whose name
  differs from its `fix_non_standard_polymer_res=True` name, as
  (original name, 1-based position). Other polymer chain types are skipped.

  A bond is kept only when both of its atoms lie in retained chains.

  Args:
    mmcif_str: The mmCIF string.
    ccd: The chemical components dictionary.

  Returns:
    The input in an Input format.

  Raises:
    ValueError: If a non-CCD ligand has more than one component, or if the
      mmCIF lacks the chemical components data needed for its SMILES.
  """
  struc = structure.from_mmcif(
      mmcif_str,
      # Change MSE residues to MET residues.
      fix_mse_residues=True,
      # Fix arginine atom names. This is not needed since the input discards
      # any atom-level data, but kept for consistency with the paper.
      fix_arginines=True,
      # Fix unknown DNA residues to the correct unknown DNA residue type.
      fix_unknown_dna=True,
      # Do not include water molecules.
      include_water=False,
      # Do not include things like DNA/RNA hybrids. This will be changed once
      # we have a way of handling these in the AlphaFold 3 input format.
      include_other=False,
      # Include the specific bonds defined in the mmCIF bond table, e.g.
      # covalent bonds for PTMs.
      include_bonds=True,
  )

  # Create default bioassembly, expanding structures implied by stoichiometry.
  struc = struc.generate_bioassembly(None)

  sequences = struc.chain_single_letter_sequence(
      include_missing_residues=True
  )
  # These mappings cover every chain. Compute them once here instead of once
  # per chain inside the loop below.
  res_names_by_chain = struc.chain_res_name_sequence()
  fixed_res_names_by_chain = struc.chain_res_name_sequence(
      fix_non_standard_polymer_res=True
  )

  chains = []
  for chain_id, chain_type in zip(
      struc.group_by_chain.chain_id, struc.group_by_chain.chain_type
  ):
    sequence = sequences[chain_id]

    if chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
      residues = list(res_names_by_chain[chain_id])
      if all(ccd.get(res) is not None for res in residues):
        chains.append(Ligand(id=chain_id, ccd_ids=residues))
      elif len(residues) == 1:
        comp_name = residues[0]
        comps = struc.chemical_components_data
        if comps is None:
          raise ValueError(
              'Missing mmCIF chemical components data - this is required for '
              f'a non-CCD ligand {comp_name} defined using SMILES string.'
          )
        chains.append(
            Ligand(id=chain_id, smiles=comps.chem_comp[comp_name].pdbx_smiles)
        )
      else:
        raise ValueError(
            'Multi-component ligand must be defined using CCD IDs, defining'
            ' using SMILES is supported only for single-component ligands. '
            f'Got {residues}'
        )
    else:
      residues = res_names_by_chain[chain_id]
      fixed_residues = fixed_res_names_by_chain[chain_id]
      # (original residue name, 1-based position) for every residue whose
      # name changes under `fix_non_standard_polymer_res=True`.
      modifications = [
          (orig, i + 1)
          for i, (orig, fixed) in enumerate(
              zip(residues, fixed_residues, strict=True)
          )
          if orig != fixed
      ]
      if chain_type == mmcif_names.PROTEIN_CHAIN:
        chains.append(
            ProteinChain(id=chain_id, sequence=sequence, ptms=modifications)
        )
      elif chain_type == mmcif_names.RNA_CHAIN:
        chains.append(
            RnaChain(
                id=chain_id, sequence=sequence, modifications=modifications
            )
        )
      elif chain_type == mmcif_names.DNA_CHAIN:
        chains.append(
            DnaChain(
                id=chain_id, sequence=sequence, modifications=modifications
            )
        )
      # Any other polymer chain type is not representable and is skipped.

  # Keep only bonds whose two atoms both belong to chains retained above.
  bonded_atom_pairs = []
  chain_ids = {c.id for c in chains}
  for atom_a, atom_b, _ in struc.iter_bonds():
    if atom_a['chain_id'] in chain_ids and atom_b['chain_id'] in chain_ids:
      beg = (atom_a['chain_id'], int(atom_a['res_id']), atom_a['atom_name'])
      end = (atom_b['chain_id'], int(atom_b['res_id']), atom_b['atom_name'])
      bonded_atom_pairs.append((beg, end))

  return cls(
      name=struc.name,
      chains=chains,
      # mmCIFs don't store rng seeds, so we need to sample one here.
      rng_seeds=[_sample_rng_seed()],
      bonded_atom_pairs=bonded_atom_pairs or None,
  )
def to_structure(self, ccd: chemical_components.Ccd) -> structure.Structure:
"""Converts Input to a Structure.
WARNING: This method does not preserve the rng seeds.
Args:
ccd: The chemical components dictionary.
Returns:
The input in a structure.Structure format.
"""
ids: list[str] = []
sequences: list[str] = []
poly_types: list[str] = []
formats: list[structure.SequenceFormat] = []
for chain in self.chains:
ids.append(chain.id)
match chain:
case ProteinChain():
sequences.append('(' + ')('.join(chain.to_ccd_sequence()) + ')')
poly_types.append(mmcif_names.PROTEIN_CHAIN)
formats.append(structure.SequenceFormat.CCD_CODES)
case RnaChain():
sequences.append('(' + ')('.join(chain.to_ccd_sequence()) + ')')
poly_types.append(mmcif_names.RNA_CHAIN)
formats.append(structure.SequenceFormat.CCD_CODES)
case DnaChain():
sequences.append('(' + ')('.join(chain.to_ccd_sequence()) + ')')
poly_types.append(mmcif_names.DNA_CHAIN)
formats.append(structure.SequenceFormat.CCD_CODES)
case Ligand():
if chain.ccd_ids is not None:
sequences.append('(' + ')('.join(chain.ccd_ids) + ')')
if len(chain.ccd_ids) == 1:
poly_types.append(mmcif_names.NON_POLYMER_CHAIN)
else:
poly_types.append(mmcif_names.BRANCHED_CHAIN)
formats.append(structure.SequenceFormat.CCD_CODES)
elif chain.smiles is not None:
# Convert to `:` format that is expected
# by structure.from_sequences_and_bonds.
sequences.append(f'LIG_{chain.id}:{chain.smiles}')
poly_types.append(mmcif_names.NON_POLYMER_CHAIN)
formats.append(structure.SequenceFormat.LIGAND_SMILES)
else:
raise ValueError('Ligand must have one of CCD ID or SMILES set.')
# Remap bond chain IDs from chain IDs to chain indices and convert to
# 0-based residue indexing.
bonded_atom_pairs = []
chain_indices = {cid: i for i, cid in enumerate(ids)}
if self.bonded_atom_pairs is not None:
for bond_beg, bond_end in self.bonded_atom_pairs:
bonded_atom_pairs.append((
(chain_indices[bond_beg[0]], bond_beg[1] - 1, bond_beg[2]),
(chain_indices[bond_end[0]], bond_end[1] - 1, bond_end[2]),
))
return structure.from_sequences_and_bonds(
sequences=sequences,
chain_types=poly_types,
sequence_formats=formats,
chain_ids=ids,
bonded_atom_pairs=bonded_atom_pairs,
ccd=ccd,
name=self.sanitised_name(),
bond_type=mmcif_names.COVALENT_BOND,
release_date=None,
)
def to_json(self) -> str:
  """Serialises this Input as an AlphaFold 3 dialect JSON string.

  Chains with identical content (as determined by `hash_without_id`) are
  merged into one `sequences` entry. That entry's `id` is a list of all their
  chain IDs, or a single ID when only one chain has that content.
  """
  # Dict insertion order keeps each content group at the position of its
  # first chain.
  chain_by_content = {}
  chain_ids_by_content = {}
  for chain in self.chains:
    content_hash = chain.hash_without_id()
    chain_by_content[content_hash] = chain
    chain_ids_by_content.setdefault(content_hash, []).append(chain.id)

  sequences = [
      chain_by_content[content_hash].to_dict(
          seq_id=chain_ids[0] if len(chain_ids) == 1 else chain_ids
      )
      for content_hash, chain_ids in chain_ids_by_content.items()
  ]

  payload = {
      'dialect': JSON_DIALECT,
      'version': JSON_VERSION,
      'name': self.name,
      'sequences': sequences,
      'modelSeeds': self.rng_seeds,
      'bondedAtomPairs': self.bonded_atom_pairs,
      'userCCD': self.user_ccd,
  }
  alphafold_json = json.dumps(payload, indent=2)

  def collapse_index_list(mtch):
    # Join the pretty-printed index list back onto a single line.
    return mtch[1] + re.sub(r'\n\s+', ' ', mtch[2].strip()) + mtch[3]

  # Put queryIndices/templateIndices arrays on one line each. The key is
  # matched with a non-capturing group; the bracketed body may contain only
  # whitespace, digits and commas.
  return re.sub(
      r'("(?:queryIndices|templateIndices)": \[)([\s\n\d,]+)(\],?)',
      collapse_index_list,
      alphafold_json,
  )
def fill_missing_fields(self) -> Self:
  """Returns a copy with default MSA and template fields filled in.

  Only protein and RNA chains are updated. Every other chain is carried over
  unchanged.
  """
  filled_chains = []
  for chain in self.chains:
    if isinstance(chain, (ProteinChain, RnaChain)):
      chain = chain.fill_missing_fields()
    filled_chains.append(chain)
  return dataclasses.replace(self, chains=filled_chains)
def with_multiple_seeds(self, num_seeds: int) -> Self:
"""Returns a copy of the input with num_seeds rng seeds."""
if num_seeds <= 1:
raise ValueError('Number of seeds must be greater than 1.')
if len(self.rng_seeds) != 1:
raise ValueError('Input must have one rng seed to set multiple seeds.')
return dataclasses.replace(
self,
rng_seeds=list(range(self.rng_seeds[0], self.rng_seeds[0] + num_seeds)),
)
def load_fold_inputs_from_path(json_path: pathlib.Path) -> Iterator[Input]:
  """Loads multiple fold inputs from a JSON file.

  A top-level JSON list is treated as AlphaFold Server dialect, with one fold
  job per element. Anything else is parsed as a single AlphaFold 3 dialect
  input.

  Args:
    json_path: Path to the JSON file to load.

  Yields:
    The fold inputs in the file, in file order.

  Raises:
    ValueError: If the file is not valid JSON (`json.JSONDecodeError`), or if
      an input fails to load. In the latter case the message names the file
      and the dialect, plus the job index for the AlphaFold Server dialect.
  """
  # JSON text is UTF-8 (RFC 8259). Do not depend on the platform's locale
  # encoding.
  with open(json_path, 'r', encoding='utf-8') as f:
    json_str = f.read()

  # Parse the JSON string, so we can detect its format.
  raw_json = json.loads(json_str)

  if isinstance(raw_json, list):
    # AlphaFold Server JSON.
    logging.info('Loading %d fold jobs from %s', len(raw_json), json_path)
    for fold_job_idx, fold_job in enumerate(raw_json):
      try:
        yield Input.from_alphafoldserver_fold_job(fold_job)
      except ValueError as e:
        raise ValueError(
            f'Failed to load fold job {fold_job_idx} from {json_path}'
            f' (AlphaFold Server dialect): {e}'
        ) from e
  else:
    # AlphaFold 3 JSON.
    try:
      yield Input.from_json(json_str, json_path)
    except ValueError as e:
      raise ValueError(
          f'Failed to load input from {json_path} (AlphaFold 3 dialect): {e}'
      ) from e
def load_fold_inputs_from_dir(input_dir: pathlib.Path) -> Iterator[Input]:
  """Yields the fold inputs of every `*.json` file directly in `input_dir`.

  Files are visited in sorted path order. Matches that are not regular files
  (e.g. directories named `*.json`) are ignored.

  Args:
    input_dir: The directory containing the JSON files.

  Yields:
    The fold inputs from all JSON files in the input directory.
  """
  json_files = sorted(input_dir.glob('*.json'))
  for json_file in json_files:
    if json_file.is_file():
      yield from load_fold_inputs_from_path(json_file)
================================================
FILE: src/alphafold3/common/resources.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Load external resources, such as external tools or data resources."""
from collections.abc import Iterator
import os
import pathlib
import typing
from typing import BinaryIO, Final, Literal, TextIO
from importlib import resources
import alphafold3.common
# Resource root: the parent directory of the `alphafold3.common` package,
# i.e. the `alphafold3` package directory. Resource names passed to the
# helpers in this module are joined onto it.
_DATA_ROOT: Final[pathlib.Path] = (
resources.files(alphafold3.common).joinpath('..').resolve()
)
# Public alias of `_DATA_ROOT`.
ROOT = _DATA_ROOT
def filename(name: str | os.PathLike[str]) -> str:
  """Returns the absolute path to an external resource.

  The path is built by joining `name` onto the resource root (the
  `alphafold3` package directory). No file is opened, unpacked or checked
  for existence.

  Args:
    name: the path of the resource relative to the resource root. An
      absolute `name` is returned as-is (in POSIX form), because `pathlib`
      joining discards the root in that case.

  Returns:
    The resource path as a POSIX-style string.
  """
  return (_DATA_ROOT / name).as_posix()
# The overloads mirror the runtime default of `mode='rb'`. A bare
# `open_resource(name)` therefore type-checks as BinaryIO, which is what it
# actually returns.
@typing.overload
def open_resource(
    name: str | os.PathLike[str], mode: Literal['r', 'rt']
) -> TextIO:
  ...


@typing.overload
def open_resource(
    name: str | os.PathLike[str], mode: Literal['rb'] = 'rb'
) -> BinaryIO:
  ...


def open_resource(
    name: str | os.PathLike[str], mode: str = 'rb'
) -> TextIO | BinaryIO:
  """Returns an open file object for the named resource.

  The caller owns the returned file object and must close it, e.g. by using
  it in a `with` statement.

  Args:
    name: the path of the resource relative to the resource root (the
      `alphafold3` package directory).
    mode: the mode to use when opening the file. Defaults to binary ('rb').

  Returns:
    A text file object for 'r'/'rt', otherwise a binary file object.
  """
  return (_DATA_ROOT / name).open(mode)
def get_resource_dir(path: str | os.PathLike[str]) -> os.PathLike[str]:
  """Returns `path` joined onto the resource root, as a path object.

  Args:
    path: Directory path relative to the resource root.
  """
  resource_dir = _DATA_ROOT.joinpath(path)
  return resource_dir
def walk(path: str) -> Iterator[tuple[str, list[str], list[str]]]:
  """Walks the directory tree of resources similar to os.walk.

  Args:
    path: Directory path relative to the resource root.

  Returns:
    The `os.walk` iterator over the resolved directory, rooted at its POSIX
    path string.
  """
  top = _DATA_ROOT.joinpath(path).as_posix()
  return os.walk(top)
================================================
FILE: src/alphafold3/common/safe_pickle.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Restricted-safe wrapper around pickle for loading trusted data.
This prevents arbitrary object instantiation during unpickling by only
allowing a small allowlist of built-in, innocuous types.
Intended for loading pickled constant data that ships with the repository.
If the pickle is tampered with, an UnpicklingError will be raised instead
of silently executing attacker-controlled bytecode.
"""
from collections.abc import Collection
import pickle
from typing import Any, BinaryIO, Final
# Builtin types expected from AlphaFold 3 generated data.
# `_RestrictedUnpickler.find_class` consults this set only for globals from
# the `builtins` module. Any other global is rejected.
# NOTE(review): the `builtins` module has no `NoneType` attribute, so if that
# name were ever looked up, `super().find_class` would fail. None itself is
# pickled with a dedicated opcode, so this entry is likely never consulted —
# confirm before relying on it.
_ALLOWED_BUILTINS: Final[Collection[str]] = frozenset({
"NoneType",
"bool",
"bytes",
"dict",
"float",
"frozenset",
"int",
"list",
"set",
"str",
"tuple",
})
class _RestrictedUnpickler(pickle.Unpickler):
"""A pickle `Unpickler` that forbids loading arbitrary global classes."""
def find_class(self, module: str, name: str) -> Any:
"""Returns the class for `module` and `name` if allowed."""
if module == "builtins" and name in _ALLOWED_BUILTINS:
return super().find_class(module, name)
raise pickle.UnpicklingError(f"Can't unpickle disallowed '{module}.{name}'")
def load(file_obj: BinaryIO) -> Any:
  """Unpickles one object from `file_obj` using the restricted unpickler.

  Only the `builtins` types named in `_ALLOWED_BUILTINS` may be resolved as
  globals. Any other global causes `pickle.UnpicklingError`.

  Args:
    file_obj: A binary file-like object open for reading.

  Returns:
    The unpickled data.
  """
  unpickler = _RestrictedUnpickler(file_obj)
  return unpickler.load()
================================================
FILE: src/alphafold3/common/testing/data.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Module that provides an abstraction for accessing test data."""
import os
import pathlib
from typing import Literal, overload
from absl.testing import absltest
class Data:
"""Provides an abstraction for accessing test data."""
def __init__(self, data_dir: os.PathLike[str] | str):
"""Initiailizes data wrapper, providing users with high level data access.
Args:
data_dir: Directory containing test data.
"""
self._data_dir = pathlib.Path(data_dir)
def path(self, data_name: str | os.PathLike[str] | None = None) -> str:
"""Returns the path to a given test data.
Args:
data_name: the name of the test data file relative to data_dir. If not
set, this will return the absolute path to the data directory.
"""
data_dir_path = (
pathlib.Path(absltest.get_default_test_srcdir()) / self._data_dir
)
if data_name:
return str(data_dir_path / data_name)
return str(data_dir_path)
@overload
def load(
self, data_name: str | os.PathLike[str], mode: Literal['rt'] = 'rt'
) -> str:
...
@overload
def load(
self, data_name: str | os.PathLike[str], mode: Literal['rb'] = 'rb'
) -> bytes:
...
def load(
self, data_name: str | os.PathLike[str], mode: str = 'rt'
) -> str | bytes:
"""Returns the contents of a given test data.
Args:
data_name: the name of the test data file relative to data_dir.
mode: the mode in which to read the data file. Defaults to text ('rt').
"""
with open(self.path(data_name), mode=mode) as f:
return f.read()
================================================
FILE: src/alphafold3/constants/atom_types.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""List of atom types with reverse look-up."""
from collections.abc import Mapping, Sequence, Set
import itertools
import sys
from typing import Final
from alphafold3.constants import residue_names
# Note:
# `sys.intern` adds each atom name to Python's table of interned strings. Every
# use of one of the constants below therefore refers to one shared string
# object, which is intended to make lookups on these names fast.
# 37 common residue atoms, declared in the same order as ATOM37 below.
N = sys.intern('N')
CA = sys.intern('CA')
C = sys.intern('C')
CB = sys.intern('CB')
O = sys.intern('O')
CG = sys.intern('CG')
CG1 = sys.intern('CG1')
CG2 = sys.intern('CG2')
OG = sys.intern('OG')
OG1 = sys.intern('OG1')
SG = sys.intern('SG')
CD = sys.intern('CD')
CD1 = sys.intern('CD1')
CD2 = sys.intern('CD2')
ND1 = sys.intern('ND1')
ND2 = sys.intern('ND2')
OD1 = sys.intern('OD1')
OD2 = sys.intern('OD2')
SD = sys.intern('SD')
CE = sys.intern('CE')
CE1 = sys.intern('CE1')
CE2 = sys.intern('CE2')
CE3 = sys.intern('CE3')
NE = sys.intern('NE')
NE1 = sys.intern('NE1')
NE2 = sys.intern('NE2')
OE1 = sys.intern('OE1')
OE2 = sys.intern('OE2')
CH2 = sys.intern('CH2')
NH1 = sys.intern('NH1')
NH2 = sys.intern('NH2')
OH = sys.intern('OH')
CZ = sys.intern('CZ')
CZ2 = sys.intern('CZ2')
CZ3 = sys.intern('CZ3')
NZ = sys.intern('NZ')
OXT = sys.intern('OXT')
# 29 common nucleic acid atoms, declared in alphabetical order of atom name.
# In the Python identifiers, `PRIME` stands for the `'` in the atom name, e.g.
# C1PRIME is "C1'".
C1PRIME = sys.intern("C1'")
C2 = sys.intern('C2')
C2PRIME = sys.intern("C2'")
C3PRIME = sys.intern("C3'")
C4 = sys.intern('C4')
C4PRIME = sys.intern("C4'")
C5 = sys.intern('C5')
C5PRIME = sys.intern("C5'")
C6 = sys.intern('C6')
C7 = sys.intern('C7')
C8 = sys.intern('C8')
N1 = sys.intern('N1')
N2 = sys.intern('N2')
N3 = sys.intern('N3')
N4 = sys.intern('N4')
N6 = sys.intern('N6')
N7 = sys.intern('N7')
N9 = sys.intern('N9')
O2 = sys.intern('O2')
O2PRIME = sys.intern("O2'")
O3PRIME = sys.intern("O3'")
O4 = sys.intern('O4')
O4PRIME = sys.intern("O4'")
O5PRIME = sys.intern("O5'")
O6 = sys.intern('O6')
OP1 = sys.intern('OP1')
OP2 = sys.intern('OP2')
OP3 = sys.intern('OP3')
P = sys.intern('P')
# The atoms (excluding hydrogen) of each of the 20 standard amino acid types,
# keyed by residue name and using the PDB atom naming convention. OXT is not
# listed for any residue.
RESIDUE_ATOMS: Mapping[str, tuple[str, ...]] = {
residue_names.ALA: (C, CA, CB, N, O),
residue_names.ARG: (C, CA, CB, CG, CD, CZ, N, NE, O, NH1, NH2),
residue_names.ASN: (C, CA, CB, CG, N, ND2, O, OD1),
residue_names.ASP: (C, CA, CB, CG, N, O, OD1, OD2),
residue_names.CYS: (C, CA, CB, N, O, SG),
residue_names.GLN: (C, CA, CB, CG, CD, N, NE2, O, OE1),
residue_names.GLU: (C, CA, CB, CG, CD, N, O, OE1, OE2),
residue_names.GLY: (C, CA, N, O),
residue_names.HIS: (C, CA, CB, CG, CD2, CE1, N, ND1, NE2, O),
residue_names.ILE: (C, CA, CB, CG1, CG2, CD1, N, O),
residue_names.LEU: (C, CA, CB, CG, CD1, CD2, N, O),
residue_names.LYS: (C, CA, CB, CG, CD, CE, N, NZ, O),
residue_names.MET: (C, CA, CB, CG, CE, N, O, SD),
residue_names.PHE: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O),
residue_names.PRO: (C, CA, CB, CG, CD, N, O),
residue_names.SER: (C, CA, CB, N, O, OG),
residue_names.THR: (C, CA, CB, CG2, N, O, OG1),
residue_names.TRP:
(C, CA, CB, CG, CD1, CD2, CE2, CE3, CZ2, CZ3, CH2, N, NE1, O),
residue_names.TYR: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O, OH),
residue_names.VAL: (C, CA, CB, CG1, CG2, N, O),
} # pyformat: disable
# The protein backbone atoms (N, CA, C). Used to identify the backbone for
# alignment and for distance calculations in steric checks.
PROTEIN_BACKBONE_ATOMS: tuple[str, ...] = (N, CA, C)
# Naming swaps for ambiguous atom names. Because of symmetries in the amino
# acids, atom naming is ambiguous in 4 of the 20 amino acids. (The LDDT paper
# lists 7 amino acids as ambiguous, but the naming ambiguities in LEU, VAL and
# ARG can be resolved by using the 3D constellations of the 'ambiguous' atoms
# and their neighbours.)
# Each inner mapping pairs an atom name with the name it may be swapped with.
AMBIGUOUS_ATOM_NAMES: Mapping[str, Mapping[str, str]] = {
residue_names.ASP: {OD1: OD2},
residue_names.GLU: {OE1: OE2},
residue_names.PHE: {CD1: CD2, CE1: CE2},
residue_names.TYR: {CD1: CD2, CE1: CE2},
}
# Used when we need to store atom data in a format that requires fixed atom data
# size for every protein residue (e.g. a numpy array).
# Each protein atom name has one fixed slot: its index in ATOM37.
ATOM37: tuple[str, ...] = (
N, CA, C, CB, O, CG, CG1, CG2, OG, OG1, SG, CD, CD1, CD2, ND1, ND2, OD1,
OD2, SD, CE, CE1, CE2, CE3, NE, NE1, NE2, OE1, OE2, CH2, NH1, NH2, OH, CZ,
CZ2, CZ3, NZ, OXT) # pyformat: disable
# Atom name -> slot index in ATOM37.
ATOM37_ORDER: Mapping[str, int] = {name: i for i, name in enumerate(ATOM37)}
ATOM37_NUM: Final[int] = len(ATOM37) # := 37.
# Used when we need to store protein atom data in a format that requires fixed
# atom data size for any residue, but that takes less space than ATOM37. It has
# 14 fields, which is enough for the atoms of every protein residue (e.g. in a
# numpy array). Slot order is per residue; UNK has no atoms.
ATOM14: Mapping[str, tuple[str, ...]] = {
residue_names.ALA: (N, CA, C, O, CB),
residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2),
residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2),
residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2),
residue_names.CYS: (N, CA, C, O, CB, SG),
residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2),
residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2),
residue_names.GLY: (N, CA, C, O),
residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2),
residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1),
residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2),
residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ),
residue_names.MET: (N, CA, C, O, CB, CG, SD, CE),
residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ),
residue_names.PRO: (N, CA, C, O, CB, CG, CD),
residue_names.SER: (N, CA, C, O, CB, OG),
residue_names.THR: (N, CA, C, O, CB, OG1, CG2),
residue_names.TRP:
(N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2),
residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH),
residue_names.VAL: (N, CA, C, O, CB, CG1, CG2),
residue_names.UNK: (),
} # pyformat: disable
# Number of slots in the compact encoding: the most atoms of any residue in
# ATOM14 (14, for TRP).
ATOM14_NUM: Final[int] = max(len(v) for v in ATOM14.values())
# A compact atom encoding with ATOM14_NUM columns, padded with '' in empty
# slots. The width is derived from ATOM14 instead of being hard-coded as 14, so
# the padding cannot drift from ATOM14_NUM.
ATOM14_PADDED: Mapping[str, Sequence[str]] = {
    k: list(values) + [''] * (ATOM14_NUM - len(values))
    for k, values in ATOM14.items()
}
# Residue name -> (atom name -> slot index in that residue's ATOM14 entry).
ATOM14_ORDER: Mapping[str, Mapping[str, int]] = {
    k: {name: i for i, name in enumerate(v)} for k, v in ATOM14.items()
}
# Used when we need to store atoms of protein and nucleic acid residues
# together. Gives the atom names of each protein, RNA and DNA residue in a
# fixed per-residue order. The unknown residue types (UNK, UNK_RNA, UNK_DNA)
# have no atoms.
DENSE_ATOM: Mapping[str, tuple[str, ...]] = {
# Protein.
residue_names.ALA: (N, CA, C, O, CB),
residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2),
residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2),
residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2),
residue_names.CYS: (N, CA, C, O, CB, SG),
residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2),
residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2),
residue_names.GLY: (N, CA, C, O),
residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2),
residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1),
residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2),
residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ),
residue_names.MET: (N, CA, C, O, CB, CG, SD, CE),
residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ),
residue_names.PRO: (N, CA, C, O, CB, CG, CD),
residue_names.SER: (N, CA, C, O, CB, OG),
residue_names.THR: (N, CA, C, O, CB, OG1, CG2),
residue_names.TRP:
(N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2),
residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH),
residue_names.VAL: (N, CA, C, O, CB, CG1, CG2),
residue_names.UNK: (),
# RNA.
residue_names.A:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4),
residue_names.C:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6),
residue_names.G:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4),
residue_names.U:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C6),
residue_names.UNK_RNA: (),
# DNA. Same as the RNA entries but without O2PRIME.
residue_names.DA:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4),
residue_names.DC:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6),
residue_names.DG:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4),
residue_names.DT:
(OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME,
C2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C7, C6),
# Unknown nucleic.
residue_names.UNK_DNA: (),
} # pyformat: disable
# Residue name -> (atom name -> slot index in that residue's DENSE_ATOM entry).
DENSE_ATOM_ORDER: Mapping[str, Mapping[str, int]] = {
    res_name: {atom_name: idx for idx, atom_name in enumerate(atom_names)}
    for res_name, atom_names in DENSE_ATOM.items()
}
# Width of the dense layout: the most atoms of any residue in DENSE_ATOM.
DENSE_ATOM_NUM: Final[int] = max(map(len, DENSE_ATOM.values()))
# Used when we need to store atom data in a format that requires fixed atom data
# size for every nucleic molecule (e.g. a numpy array).
# Built from the interned constants above, like every other table in this
# module. Previously it used raw string literals with the same values.
ATOM29: tuple[str, ...] = (
    C1PRIME, C2, C2PRIME, C3PRIME, C4, C4PRIME, C5, C5PRIME, C6, C7, C8, N1,
    N2, N3, N4, N6, N7, N9, OP3, O2, O2PRIME, O3PRIME, O4, O4PRIME,
    O5PRIME, O6, OP1, OP2, P)  # pyformat: disable
# Atom name -> slot index in ATOM29.
ATOM29_ORDER: Mapping[str, int] = {
    atom_type: i for i, atom_type in enumerate(ATOM29)
}
ATOM29_NUM: Final[int] = len(ATOM29)  # := 29
# Hydrogens whose presence depends on the protonation state of the residue,
# keyed by residue name.
# Extracted from third_party/py/openmm/app/data/hydrogens.xml
PROTONATION_HYDROGENS: Mapping[str, Set[str]] = {
'ASP': {'HD2'},
'CYS': {'HG'},
'GLU': {'HE2'},
'HIS': {'HD1', 'HE2'},
'LYS': {'HZ3'},
}
================================================
FILE: src/alphafold3/constants/chemical_component_sets.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Sets of chemical components."""
from typing import Final
from alphafold3.common import resources
from alphafold3.common import safe_pickle
# Precomputed sets of CCD component IDs. They are loaded once at import time
# through the restricted unpickler (`safe_pickle.load`).
# NOTE(review): the `Final[frozenset[str]]` annotations below are not checked
# at runtime. The actual types are whatever the pickle stores — confirm it
# stores frozensets.
_CCD_SETS_CCD_PICKLE_FILE = resources.filename(
resources.ROOT / 'constants/converters/chemical_component_sets.pickle'
)
with open(_CCD_SETS_CCD_PICKLE_FILE, 'rb') as f:
_CCD_SET = safe_pickle.load(f)
# Glycan (or 'Saccharide') ligands.
# _chem_comp.type containing 'saccharide' and 'linking' (when lower-case).
GLYCAN_LINKING_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_linking']
# _chem_comp.type containing 'saccharide' and not 'linking' (when lower-case).
GLYCAN_OTHER_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_other']
# Each of these molecules appears in over 1k PDB structures and is used to
# facilitate crystallization, but has no biological relevance.
COMMON_CRYSTALLIZATION_AIDS: Final[frozenset[str]] = frozenset({
'SO4', 'GOL', 'EDO', 'PO4', 'ACT', 'PEG', 'DMS', 'TRS', 'PGE', 'PG4', 'FMT',
'EPE', 'MPD', 'MES', 'CD', 'IOD',
}) # pyformat: disable
================================================
FILE: src/alphafold3/constants/chemical_components.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Chemical Components found in PDB (CCD) constants."""
from collections.abc import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView
import dataclasses
import functools
import os
from alphafold3.common import resources
from alphafold3.common import safe_pickle
from alphafold3.cpp import cif_dict
# Default CCD pickle shipped with the package. `Ccd` uses it when no
# `ccd_pickle_path` is given.
_CCD_PICKLE_FILE = resources.filename(
resources.ROOT / 'constants/converters/ccd.pickle'
)
@functools.cache
def _load_ccd_pickle_cached(
    path: os.PathLike[str],
) -> dict[str, Mapping[str, Sequence[str]]]:
  """Loads the CCD pickle at `path`, reading each distinct path only once.

  Because the result is cached, every call with the same `path` returns the
  same dict object. Callers must treat it as read-only.
  """
  with open(path, 'rb') as pickle_file:
    ccd_dict = safe_pickle.load(pickle_file)
  return ccd_dict
class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
"""Chemical Components found in PDB (CCD) constants.
See https://academic.oup.com/bioinformatics/article/31/8/1274/212200 for CCD
CIF format documentation.
Wraps the dict to prevent accidental mutation.
"""
__slots__ = ('_dict', '_ccd_pickle_path')
def __init__(
self,
ccd_pickle_path: os.PathLike[str] | None = None,
user_ccd: str | None = None,
):
"""Initialises the chemical components dictionary.
Args:
ccd_pickle_path: Path to the CCD pickle file. If None, uses the default
CCD pickle file included in the source code.
user_ccd: A string containing the user-provided CCD. This has to conform
to the same format as the CCD, see https://www.wwpdb.org/data/ccd. If
provided, takes precedence over the CCD for the the same key. This can
be used to override specific entries in the CCD if desired.
"""
self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
self._dict = _load_ccd_pickle_cached(self._ccd_pickle_path)
if user_ccd is not None:
if not user_ccd:
raise ValueError('User CCD cannot be an empty string.')
user_ccd_cifs = {
key: value.to_dict()
for key, value in cif_dict.parse_multi_data_cif(user_ccd).items()
}
self._dict.update(user_ccd_cifs)
def __getitem__(self, key: str) -> Mapping[str, Sequence[str]]:
return self._dict[key]
def __contains__(self, key: str) -> bool:
return key in self._dict
def __iter__(self) -> Iterator[str]:
return self._dict.__iter__()
def __len__(self) -> int:
return len(self._dict)
def __hash__(self) -> int:
return id(self) # Ok since this is immutable.
def get(
self, key: str, default: None | Mapping[str, Sequence[str]] = None
) -> Mapping[str, Sequence[str]] | None:
return self._dict.get(key, default)
def items(self) -> ItemsView[str, Mapping[str, Sequence[str]]]:
return self._dict.items()
def values(self) -> ValuesView[Mapping[str, Sequence[str]]]:
return self._dict.values()
def keys(self) -> KeysView[str]:
return self._dict.keys()
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ComponentInfo:
"""Selected fields of one CCD component, all stored as strings.

Instances are built by `mmcif_to_info`.
"""
# `_chem_comp.name`.
name: str
# `_chem_comp.type`, e.g. 'non-polymer'.
type: str
# `_chem_comp.pdbx_synonyms`.
pdbx_synonyms: str
# `_chem_comp.formula`.
formula: str
# `_chem_comp.formula_weight`, kept as a string.
formula_weight: str
# `_chem_comp.mon_nstd_parent_comp_id`; '?' for standard components.
mon_nstd_parent_comp_id: str
# '.', 'y' or 'n', as derived by `mmcif_to_info`.
mon_nstd_flag: str
# SMILES chosen from `_pdbx_chem_comp_descriptor` by `mmcif_to_info`, or ''.
pdbx_smiles: str
def mmcif_to_info(mmcif: Mapping[str, Sequence[str]]) -> ComponentInfo:
  """Converts CCD mmCIFs to component info. Missing fields are left empty.

  Args:
    mmcif: One CCD component, as a mapping from mmCIF field name (e.g.
      `_chem_comp.name`) to its list of values.

  Returns:
    The component info. A `_chem_comp` field that is absent from `mmcif`, or
    present with no values, becomes ''. `pdbx_smiles` is '' when no SMILES
    descriptor is present.
  """

  def first_value(field: str) -> str:
    # Absent keys (e.g. in a sparse user-provided CCD) and empty value lists
    # are both treated as missing.
    values = mmcif.get(field, ())
    return values[0] if values else ''

  type_ = first_value('_chem_comp.type')
  mon_nstd_parent_comp_id = first_value('_chem_comp.mon_nstd_parent_comp_id')
  if type_.lower() == 'non-polymer':
    # Unset for non-polymers, e.g. water or ions.
    mon_nstd_flag = '.'
  elif mon_nstd_parent_comp_id == '?':
    # A standard component - it doesn't have a standard parent, e.g. MET.
    mon_nstd_flag = 'y'
  else:
    # A non-standard component, e.g. MSE.
    mon_nstd_flag = 'n'

  # Default SMILES is the canonical SMILES, but we fall back to the SMILES if a
  # canonical SMILES is not available. Of canonical SMILES, we prefer ones from
  # the OpenEye OEToolkits program.
  canonical_pdbx_smiles = ''
  fallback_pdbx_smiles = ''
  descriptor_types = mmcif.get('_pdbx_chem_comp_descriptor.type', [])
  descriptors = mmcif.get('_pdbx_chem_comp_descriptor.descriptor', [])
  programs = mmcif.get('_pdbx_chem_comp_descriptor.program', [])
  for descriptor_type, descriptor, program in zip(
      descriptor_types, descriptors, programs
  ):
    if descriptor_type == 'SMILES_CANONICAL':
      if (not canonical_pdbx_smiles) or program == 'OpenEye OEToolkits':
        canonical_pdbx_smiles = descriptor
    if not fallback_pdbx_smiles and descriptor_type == 'SMILES':
      fallback_pdbx_smiles = descriptor
  pdbx_smiles = canonical_pdbx_smiles or fallback_pdbx_smiles

  return ComponentInfo(
      name=first_value('_chem_comp.name'),
      type=type_,
      pdbx_synonyms=first_value('_chem_comp.pdbx_synonyms'),
      formula=first_value('_chem_comp.formula'),
      formula_weight=first_value('_chem_comp.formula_weight'),
      mon_nstd_parent_comp_id=mon_nstd_parent_comp_id,
      mon_nstd_flag=mon_nstd_flag,
      pdbx_smiles=pdbx_smiles,
  )
@functools.lru_cache(maxsize=128)
def component_name_to_info(ccd: Ccd, res_name: str) -> ComponentInfo | None:
  """Returns the ComponentInfo for `res_name`, or None if `ccd` lacks it.

  Results are memoised (up to 128 entries) keyed on `(ccd, res_name)`, so
  `ccd` must be hashable.
  """
  if (component := ccd.get(res_name)) is not None:
    return mmcif_to_info(component)
  return None
def type_symbol(ccd: Ccd, res_name: str, atom_name: str) -> str:
  """Looks up the element symbol of one atom of a CCD component.

  Args:
    ccd: The chemical components dictionary.
    res_name: The component name, e.g. ARG.
    atom_name: The atom name, e.g. CB, OXT, or NH1.

  Returns:
    Element type, e.g. C for (ARG, CB), O for (ARG, OXT), N for (ARG, NH1).
    '?' is returned when the component is not in `ccd`, when the atom is not
    listed, or when the `_chem_comp_atom` columns are missing or too short.
  """
  unknown_symbol = '?'
  component = ccd.get(res_name)
  if component is None:
    return unknown_symbol
  try:
    symbols = component['_chem_comp_atom.type_symbol']
    atom_index = component['_chem_comp_atom.atom_id'].index(atom_name)
    return symbols[atom_index]
  except (KeyError, ValueError, IndexError):
    return unknown_symbol
================================================
FILE: src/alphafold3/constants/converters/ccd_pickle_gen.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Reads Chemical Components gz file and generates a CCD pickle file."""
from collections.abc import Sequence
import gzip
import pickle
import sys
from alphafold3.cpp import cif_dict
import tqdm
def main(argv: Sequence[str]) -> None:
  """Parses a CCD components.cif file and pickles it as a dict.

  The pickle maps each data block name to `value.to_dict()` of the parsed
  block.

  Args:
    argv: Command-line arguments `[program, input_file, output_file]`. The
      input is read through `gzip.open` when its name ends with `.gz`.

  Raises:
    ValueError: If the wrong number of arguments is given, or if the number of
      parsed data blocks differs from the number of `data_` block headers in
      the input.
  """
  if len(argv) != 3:
    raise ValueError('Must specify input_file components.cif and output_file')
  _, input_file, output_file = argv

  print(f'Parsing {input_file}', flush=True)
  opener = gzip.open if input_file.endswith('.gz') else open
  with opener(input_file, 'rb') as f:
    whole_file = f.read()

  result = {
      key: value.to_dict()
      for key, value in tqdm.tqdm(
          cif_dict.parse_multi_data_cif(whole_file).items(), disable=None
      )
  }

  # Sanity check that every data block was parsed. Only `data_` at the start
  # of a line is a block header; occurrences inside values must not be
  # counted. This is an explicit check rather than an `assert`, so it still
  # runs under `python -O`.
  num_data_blocks = whole_file.count(b'\ndata_') + int(
      whole_file.startswith(b'data_')
  )
  if len(result) != num_data_blocks:
    raise ValueError(
        f'Parsed {len(result)} data blocks but found {num_data_blocks}'
        ' data_ block headers in the input file.'
    )

  print(f'Writing {output_file}', flush=True)
  with open(output_file, 'wb') as f:
    pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)
  print('Done', flush=True)


if __name__ == '__main__':
  main(sys.argv)
================================================
FILE: src/alphafold3/constants/converters/chemical_component_sets_gen.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Script for updating chemical_component_sets.py."""
from collections.abc import Mapping, Sequence
import pathlib
import pickle
import re
import sys
from alphafold3.common import resources
from alphafold3.common import safe_pickle
import tqdm
# Path of the pickled CCD that main() loads (component name -> mmCIF items).
_CCD_PICKLE_FILE = resources.filename('constants/converters/ccd.pickle')
def find_ions_and_glycans_in_ccd(
    ccd: Mapping[str, Mapping[str, Sequence[str]]],
) -> dict[str, frozenset[str]]:
  """Finds glycans and ions in a version of the CCD.

  Args:
    ccd: Mapping from component name to that component's mmCIF items. Each
      component must have non-empty `_chem_comp.type` and `_chem_comp.name`
      items.

  Returns:
    A dict with three frozensets of component names:
      * 'glycans_linking': components whose type contains the word
        'saccharide' and the substring 'linking'.
      * 'glycans_other': other components whose type contains the word
        'saccharide'.
      * 'ions': components whose name contains the word 'ion'.
    The 'UNX' component is never included.
  """
  glycans_linking = []
  glycans_other = []
  ions = []
  for name, comp in tqdm.tqdm(ccd.items(), disable=None):
    if name == 'UNX':
      continue  # Skip "unknown atom or ion".
    comp_type = comp['_chem_comp.type'][0].lower()
    # Glycans have the type 'saccharide'. Only existence of a match matters,
    # so re.search is used instead of collecting all matches.
    if re.search(r'\bsaccharide\b', comp_type):
      # Separate out linking glycans from others.
      if 'linking' in comp_type:
        glycans_linking.append(name)
      else:
        glycans_other.append(name)
    # Ions have the word 'ion' in their name.
    comp_name = comp['_chem_comp.name'][0].lower()
    if re.search(r'\bion\b', comp_name):
      ions.append(name)
  return dict(
      glycans_linking=frozenset(glycans_linking),
      glycans_other=frozenset(glycans_other),
      ions=frozenset(ions),
  )
def main(argv: Sequence[str]) -> None:
  """Computes ion and glycan sets from the CCD pickle and pickles the result.

  Args:
    argv: Command-line arguments `[program, output_file]`. `output_file` is
      the path of the pickle to write; missing parent directories are
      created.

  Raises:
    ValueError: If exactly one argument (the output file) is not given.
  """
  if len(argv) != 2:
    raise ValueError(
        'Output file to write to must be specified as a command-line argument.'
    )
  print(f'Loading {_CCD_PICKLE_FILE}', flush=True)
  with open(_CCD_PICKLE_FILE, 'rb') as f:
    ccd: Mapping[str, Mapping[str, Sequence[str]]] = safe_pickle.load(f)

  output_path = pathlib.Path(argv[1])
  # parents=True so that a nested, not-yet-existing output directory is
  # created instead of failing with FileNotFoundError.
  output_path.parent.mkdir(parents=True, exist_ok=True)

  print('Finding ions and glycans', flush=True)
  result = find_ions_and_glycans_in_ccd(ccd)

  print(f'writing to {output_path}', flush=True)
  with output_path.open('wb') as f:
    pickle.dump(result, f)
  print('Done', flush=True)


if __name__ == '__main__':
  main(sys.argv)
================================================
FILE: src/alphafold3/constants/mmcif_names.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Names of things in mmCIF format.
See https://www.iucr.org/__data/iucr/cifdic_html/2/cif_mm.dic/index.html
"""
from collections.abc import Mapping, Sequence, Set
from typing import Final
from alphafold3.constants import atom_types
from alphafold3.constants import residue_names
# Possible values of the mmCIF "_entity.type" item, see
# https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html
BRANCHED_CHAIN: Final[str] = 'branched'
MACROLIDE_CHAIN: Final[str] = 'macrolide'
NON_POLYMER_CHAIN: Final[str] = 'non-polymer'
POLYMER_CHAIN: Final[str] = 'polymer'
WATER: Final[str] = 'water'

# Possible values of the "_entity_poly.type" item (the sub-type of a polymer
# entity).
CYCLIC_PSEUDO_PEPTIDE_CHAIN: Final[str] = 'cyclic-pseudo-peptide'
DNA_CHAIN: Final[str] = 'polydeoxyribonucleotide'
DNA_RNA_HYBRID_CHAIN: Final[str] = (
    'polydeoxyribonucleotide/polyribonucleotide hybrid'
)
OTHER_CHAIN: Final[str] = 'other'
PEPTIDE_NUCLEIC_ACID_CHAIN: Final[str] = 'peptide nucleic acid'
POLYPEPTIDE_D_CHAIN: Final[str] = 'polypeptide(D)'
PROTEIN_CHAIN: Final[str] = 'polypeptide(L)'
RNA_CHAIN: Final[str] = 'polyribonucleotide'

# The three most common _entity_poly.type values.
STANDARD_POLYMER_CHAIN_TYPES: Final[Set[str]] = {
    PROTEIN_CHAIN, DNA_CHAIN, RNA_CHAIN
}

# _entity.type values of ligands, i.e. everything except polymer and water.
LIGAND_CHAIN_TYPES: Final[Set[str]] = {
    BRANCHED_CHAIN, MACROLIDE_CHAIN, NON_POLYMER_CHAIN
}

# _entity.type values of everything that is not a polymer.
NON_POLYMER_CHAIN_TYPES: Final[Set[str]] = LIGAND_CHAIN_TYPES | {WATER}

# _entity_poly.type values of peptide-like polymers.
PEPTIDE_CHAIN_TYPES: Final[Set[str]] = {
    CYCLIC_PSEUDO_PEPTIDE_CHAIN,
    POLYPEPTIDE_D_CHAIN,
    PROTEIN_CHAIN,
    PEPTIDE_NUCLEIC_ACID_CHAIN,
}

# _entity_poly.type values of nucleic-acid polymers.
NUCLEIC_ACID_CHAIN_TYPES: Final[Set[str]] = {
    RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN
}

# Every _entity_poly.type value.
POLYMER_CHAIN_TYPES: Final[Set[str]] = (
    NUCLEIC_ACID_CHAIN_TYPES | PEPTIDE_CHAIN_TYPES | {OTHER_CHAIN}
)

# Name of the terminal oxygen atom for each standard polymer chain type.
TERMINAL_OXYGENS: Final[Mapping[str, str]] = {
    PROTEIN_CHAIN: 'OXT',
    **dict.fromkeys((DNA_CHAIN, RNA_CHAIN), 'OP3'),
}
# Atom used as the single representative position of each residue, keyed by
# standard polymer chain type.
RESIDUE_REPRESENTATIVE_ATOMS: Final[Mapping[str, str]] = {
    PROTEIN_CHAIN: atom_types.CA,
    **dict.fromkeys((DNA_CHAIN, RNA_CHAIN), atom_types.C1PRIME),
}

# _exptl.method values of experimental methods that involve crystallization.
# All experimental methods are documented at
# mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_exptl.method.html
CRYSTALLIZATION_METHODS: Final[Set[str]] = set((
    'X-RAY DIFFRACTION',
    'NEUTRON DIFFRACTION',
    'ELECTRON CRYSTALLOGRAPHY',
    'POWDER CRYSTALLOGRAPHY',
    'FIBER DIFFRACTION',
))

# Possible bond (connection) types.
COVALENT_BOND: Final[str] = 'covale'
HYDROGEN_BOND: Final[str] = 'hydrog'
METAL_COORDINATION: Final[str] = 'metalc'
DISULFIDE_BRIDGE: Final[str] = 'disulf'
def is_standard_polymer_type(chain_type: str) -> bool:
  """Checks whether a chain type is one of the standard polymer types.

  Args:
    chain_type: The type of the chain.

  Returns:
    True iff `chain_type` is protein (polypeptide(L)), DNA or RNA.
  """
  is_standard = chain_type in STANDARD_POLYMER_CHAIN_TYPES
  return is_standard
# Residue name -> the polymer chain type that residue indicates. Built once at
# import time instead of on every guess_polymer_type call.
_RESIDUE_NAME_TO_POLYMER_TYPE: Final[Mapping[str, str]] = {
    **{r: RNA_CHAIN for r in residue_names.RNA_TYPES},
    **{r: DNA_CHAIN for r in residue_names.DNA_TYPES},
    **{r: PROTEIN_CHAIN for r in residue_names.PROTEIN_TYPES_WITH_UNKNOWN},
    residue_names.MSE: PROTEIN_CHAIN,
}

# Tie-break priority used by guess_polymer_type when counts are equal.
_POLYMER_TYPE_TIE_BREAK_PRIORITY: Final[Mapping[str, int]] = {
    PROTEIN_CHAIN: 3,
    RNA_CHAIN: 2,
    DNA_CHAIN: 1,
    OTHER_CHAIN: 0,
}


def guess_polymer_type(chain_residues: Sequence[str]) -> str:
  """Guess the polymer type (protein/rna/dna/other) based on the residues.

  The polymer type is guessed by first checking for any of the standard
  protein residues. If one is present then the chain is considered to be a
  polypeptide. Otherwise we decide by counting residue types and deciding by
  majority voting (e.g. mostly DNA residues -> DNA). If there is a tie between
  the counts, the ordering is protein > rna > dna > other. In particular, an
  empty `chain_residues` yields polypeptide(L).

  Note that we count MSE and UNK as protein residues.

  Args:
    chain_residues: A sequence of full residue name (1-letter for DNA, 2-letters
      for RNA, 3 for protein). The _atom_site.label_comp_id column in mmCIF.

  Returns:
    The most probable chain type as set in the _entity_poly mmCIF table:
    protein - polypeptide(L), rna - polyribonucleotide,
    dna - polydeoxyribonucleotide or other.
  """
  counts = {PROTEIN_CHAIN: 0, RNA_CHAIN: 0, DNA_CHAIN: 0, OTHER_CHAIN: 0}
  for residue in chain_residues:
    residue_type = _RESIDUE_NAME_TO_POLYMER_TYPE.get(residue, OTHER_CHAIN)
    # If we ever see a protein residue we'll consider this a polypeptide(L).
    if residue_type == PROTEIN_CHAIN:
      return residue_type
    counts[residue_type] += 1
  # Highest count wins; ties are broken by protein > rna > dna > other.
  return max(
      counts,
      key=lambda chain_type: (
          counts[chain_type],
          _POLYMER_TYPE_TIE_BREAK_PRIORITY[chain_type],
      ),
  )
def fix_non_standard_polymer_res(*, res_name: str, chain_type: str) -> str:
  """Returns the res_name of the closest standard protein/RNA/DNA residue.

  Optimized for the case where a single residue needs to be converted.
  If res_name is already a standard type, it is returned unaltered. When no
  match is found, 'UNK' is returned for peptide-like and OTHER_CHAIN chains,
  and 'N' for RNA, DNA and DNA/RNA hybrid chains.

  Args:
    res_name: A residue_name (monomer code from the CCD).
    chain_type: The type of the chain. One of PEPTIDE_CHAIN_TYPES or
      OTHER_CHAIN (both handled as protein), RNA_CHAIN, DNA_CHAIN or
      DNA_RNA_HYBRID_CHAIN.

  Returns:
    An element from PROTEIN_TYPES_WITH_UNKNOWN | RNA_TYPES | DNA_TYPES | {'N'}.

  Raises:
    ValueError: If chain_type not in PEPTIDE_CHAIN_TYPES or
      {OTHER_CHAIN, RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN}.
  """
  # Round-trip through the one-letter code to land on a common residue name.
  one_letter = residue_names.letters_three_to_one(res_name, default='X')

  if chain_type == OTHER_CHAIN or chain_type in PEPTIDE_CHAIN_TYPES:
    return residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(one_letter, 'UNK')

  if chain_type == RNA_CHAIN:
    # CCD monomer codes for RNA residues are already a single letter.
    if one_letter in residue_names.RNA_TYPES:
      return one_letter
    return 'N'

  if chain_type == DNA_CHAIN:
    return residue_names.DNA_COMMON_ONE_TO_TWO.get(one_letter, 'N')

  if chain_type == DNA_RNA_HYBRID_CHAIN:
    # Hybrids keep the original name when it is any known nucleic residue.
    if res_name in residue_names.NUCLEIC_TYPES_WITH_UNKNOWN:
      return res_name
    return 'N'

  raise ValueError(f'Expected a protein/DNA/RNA chain but got {chain_type}')
================================================
FILE: src/alphafold3/constants/periodic_table.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Periodic table of elements."""
from collections.abc import Mapping, Sequence
import dataclasses
from typing import Final
import numpy as np
@dataclasses.dataclass(frozen=True, kw_only=True)
class Element:
name: str
number: int
symbol: str
weight: float
# Weights taken from rdkit/Code/GraphMol/atomic_data.cpp for compatibility.
# One module-level constant per element, named by its symbol; PERIODIC_TABLE
# below collects them in order of atomic number.
# pylint: disable=invalid-name
# X is an unknown element that can be present in the CCD,
# https://www.rcsb.org/ligand/UNX.
X: Final[Element] = Element(name='Unknown', number=0, symbol='X', weight=0.0)
# Period 1.
H: Final[Element] = Element(name='Hydrogen', number=1, symbol='H', weight=1.008)
He: Final[Element] = Element(name='Helium', number=2, symbol='He', weight=4.003)
# Period 2.
Li: Final[Element] = Element(
    name='Lithium', number=3, symbol='Li', weight=6.941
)
Be: Final[Element] = Element(
    name='Beryllium', number=4, symbol='Be', weight=9.012
)
B: Final[Element] = Element(name='Boron', number=5, symbol='B', weight=10.812)
C: Final[Element] = Element(name='Carbon', number=6, symbol='C', weight=12.011)
N: Final[Element] = Element(
    name='Nitrogen', number=7, symbol='N', weight=14.007
)
O: Final[Element] = Element(name='Oxygen', number=8, symbol='O', weight=15.999)
F: Final[Element] = Element(
    name='Fluorine', number=9, symbol='F', weight=18.998
)
Ne: Final[Element] = Element(name='Neon', number=10, symbol='Ne', weight=20.18)
# Period 3.
Na: Final[Element] = Element(
    name='Sodium', number=11, symbol='Na', weight=22.99
)
Mg: Final[Element] = Element(
    name='Magnesium', number=12, symbol='Mg', weight=24.305
)
Al: Final[Element] = Element(
    name='Aluminium', number=13, symbol='Al', weight=26.982
)
Si: Final[Element] = Element(
    name='Silicon', number=14, symbol='Si', weight=28.086
)
P: Final[Element] = Element(
    name='Phosphorus', number=15, symbol='P', weight=30.974
)
S: Final[Element] = Element(name='Sulfur', number=16, symbol='S', weight=32.067)
Cl: Final[Element] = Element(
    name='Chlorine', number=17, symbol='Cl', weight=35.453
)
Ar: Final[Element] = Element(
    name='Argon', number=18, symbol='Ar', weight=39.948
)
# Period 4.
K: Final[Element] = Element(
    name='Potassium', number=19, symbol='K', weight=39.098
)
Ca: Final[Element] = Element(
    name='Calcium', number=20, symbol='Ca', weight=40.078
)
Sc: Final[Element] = Element(
    name='Scandium', number=21, symbol='Sc', weight=44.956
)
Ti: Final[Element] = Element(
    name='Titanium', number=22, symbol='Ti', weight=47.867
)
V: Final[Element] = Element(
    name='Vanadium', number=23, symbol='V', weight=50.942
)
Cr: Final[Element] = Element(
    name='Chromium', number=24, symbol='Cr', weight=51.996
)
Mn: Final[Element] = Element(
    name='Manganese', number=25, symbol='Mn', weight=54.938
)
Fe: Final[Element] = Element(name='Iron', number=26, symbol='Fe', weight=55.845)
Co: Final[Element] = Element(
    name='Cobalt', number=27, symbol='Co', weight=58.933
)
Ni: Final[Element] = Element(
    name='Nickel', number=28, symbol='Ni', weight=58.693
)
Cu: Final[Element] = Element(
    name='Copper', number=29, symbol='Cu', weight=63.546
)
Zn: Final[Element] = Element(name='Zinc', number=30, symbol='Zn', weight=65.39)
Ga: Final[Element] = Element(
    name='Gallium', number=31, symbol='Ga', weight=69.723
)
Ge: Final[Element] = Element(
    name='Germanium', number=32, symbol='Ge', weight=72.61
)
As: Final[Element] = Element(
    name='Arsenic', number=33, symbol='As', weight=74.922
)
Se: Final[Element] = Element(
    name='Selenium', number=34, symbol='Se', weight=78.96
)
Br: Final[Element] = Element(
    name='Bromine', number=35, symbol='Br', weight=79.904
)
Kr: Final[Element] = Element(
    name='Krypton', number=36, symbol='Kr', weight=83.8
)
# Period 5.
Rb: Final[Element] = Element(
    name='Rubidium', number=37, symbol='Rb', weight=85.468
)
Sr: Final[Element] = Element(
    name='Strontium', number=38, symbol='Sr', weight=87.62
)
Y: Final[Element] = Element(
    name='Yttrium', number=39, symbol='Y', weight=88.906
)
Zr: Final[Element] = Element(
    name='Zirconium', number=40, symbol='Zr', weight=91.224
)
# Niobium (period 5). The name was previously misspelled as 'Niobiu'.
Nb: Final[Element] = Element(
    name='Niobium', number=41, symbol='Nb', weight=92.906
)
# Period 5 (continued).
Mo: Final[Element] = Element(
    name='Molybdenum', number=42, symbol='Mo', weight=95.94
)
Tc: Final[Element] = Element(
    name='Technetium', number=43, symbol='Tc', weight=98
)
Ru: Final[Element] = Element(
    name='Ruthenium', number=44, symbol='Ru', weight=101.07
)
Rh: Final[Element] = Element(
    name='Rhodium', number=45, symbol='Rh', weight=102.906
)
Pd: Final[Element] = Element(
    name='Palladium', number=46, symbol='Pd', weight=106.42
)
Ag: Final[Element] = Element(
    name='Silver', number=47, symbol='Ag', weight=107.868
)
Cd: Final[Element] = Element(
    name='Cadmium', number=48, symbol='Cd', weight=112.412
)
In: Final[Element] = Element(
    name='Indium', number=49, symbol='In', weight=114.818
)
Sn: Final[Element] = Element(name='Tin', number=50, symbol='Sn', weight=118.711)
Sb: Final[Element] = Element(
    name='Antimony', number=51, symbol='Sb', weight=121.76
)
Te: Final[Element] = Element(
    name='Tellurium', number=52, symbol='Te', weight=127.6
)
I: Final[Element] = Element(
    name='Iodine', number=53, symbol='I', weight=126.904
)
Xe: Final[Element] = Element(
    name='Xenon', number=54, symbol='Xe', weight=131.29
)
# Period 6.
Cs: Final[Element] = Element(
    name='Caesium', number=55, symbol='Cs', weight=132.905
)
Ba: Final[Element] = Element(
    name='Barium', number=56, symbol='Ba', weight=137.328
)
# Lanthanides (La-Lu).
La: Final[Element] = Element(
    name='Lanthanum', number=57, symbol='La', weight=138.906
)
Ce: Final[Element] = Element(
    name='Cerium', number=58, symbol='Ce', weight=140.116
)
Pr: Final[Element] = Element(
    name='Praseodymium', number=59, symbol='Pr', weight=140.908
)
Nd: Final[Element] = Element(
    name='Neodymium', number=60, symbol='Nd', weight=144.24
)
Pm: Final[Element] = Element(
    name='Promethium', number=61, symbol='Pm', weight=145
)
Sm: Final[Element] = Element(
    name='Samarium', number=62, symbol='Sm', weight=150.36
)
Eu: Final[Element] = Element(
    name='Europium', number=63, symbol='Eu', weight=151.964
)
Gd: Final[Element] = Element(
    name='Gadolinium', number=64, symbol='Gd', weight=157.25
)
Tb: Final[Element] = Element(
    name='Terbium', number=65, symbol='Tb', weight=158.925
)
Dy: Final[Element] = Element(
    name='Dysprosium', number=66, symbol='Dy', weight=162.5
)
Ho: Final[Element] = Element(
    name='Holmium', number=67, symbol='Ho', weight=164.93
)
Er: Final[Element] = Element(
    name='Erbium', number=68, symbol='Er', weight=167.26
)
Tm: Final[Element] = Element(
    name='Thulium', number=69, symbol='Tm', weight=168.934
)
Yb: Final[Element] = Element(
    name='Ytterbium', number=70, symbol='Yb', weight=173.04
)
Lu: Final[Element] = Element(
    name='Lutetium', number=71, symbol='Lu', weight=174.967
)
Hf: Final[Element] = Element(
    name='Hafnium', number=72, symbol='Hf', weight=178.49
)
Ta: Final[Element] = Element(
    name='Tantalum', number=73, symbol='Ta', weight=180.948
)
W: Final[Element] = Element(
    name='Tungsten', number=74, symbol='W', weight=183.84
)
Re: Final[Element] = Element(
    name='Rhenium', number=75, symbol='Re', weight=186.207
)
Os: Final[Element] = Element(
    name='Osmium', number=76, symbol='Os', weight=190.23
)
Ir: Final[Element] = Element(
    name='Iridium', number=77, symbol='Ir', weight=192.217
)
Pt: Final[Element] = Element(
    name='Platinum', number=78, symbol='Pt', weight=195.078
)
Au: Final[Element] = Element(
    name='Gold', number=79, symbol='Au', weight=196.967
)
Hg: Final[Element] = Element(
    name='Mercury', number=80, symbol='Hg', weight=200.59
)
Tl: Final[Element] = Element(
    name='Thallium', number=81, symbol='Tl', weight=204.383
)
Pb: Final[Element] = Element(name='Lead', number=82, symbol='Pb', weight=207.2)
Bi: Final[Element] = Element(
    name='Bismuth', number=83, symbol='Bi', weight=208.98
)
Po: Final[Element] = Element(
    name='Polonium', number=84, symbol='Po', weight=209
)
At: Final[Element] = Element(
    name='Astatine', number=85, symbol='At', weight=210
)
Rn: Final[Element] = Element(name='Radon', number=86, symbol='Rn', weight=222)
# Period 7.
Fr: Final[Element] = Element(
    name='Francium', number=87, symbol='Fr', weight=223
)
Ra: Final[Element] = Element(name='Radium', number=88, symbol='Ra', weight=226)
# Actinides (Ac-Lr).
Ac: Final[Element] = Element(
    name='Actinium', number=89, symbol='Ac', weight=227
)
Th: Final[Element] = Element(
    name='Thorium', number=90, symbol='Th', weight=232.038
)
Pa: Final[Element] = Element(
    name='Protactinium', number=91, symbol='Pa', weight=231.036
)
U: Final[Element] = Element(
    name='Uranium', number=92, symbol='U', weight=238.029
)
Np: Final[Element] = Element(
    name='Neptunium', number=93, symbol='Np', weight=237
)
Pu: Final[Element] = Element(
    name='Plutonium', number=94, symbol='Pu', weight=244
)
Am: Final[Element] = Element(
    name='Americium', number=95, symbol='Am', weight=243
)
Cm: Final[Element] = Element(name='Curium', number=96, symbol='Cm', weight=247)
Bk: Final[Element] = Element(
    name='Berkelium', number=97, symbol='Bk', weight=247
)
Cf: Final[Element] = Element(
    name='Californium', number=98, symbol='Cf', weight=251
)
Es: Final[Element] = Element(
    name='Einsteinium', number=99, symbol='Es', weight=252
)
Fm: Final[Element] = Element(
    name='Fermium', number=100, symbol='Fm', weight=257
)
Md: Final[Element] = Element(
    name='Mendelevium', number=101, symbol='Md', weight=258
)
No: Final[Element] = Element(
    name='Nobelium', number=102, symbol='No', weight=259
)
Lr: Final[Element] = Element(
    name='Lawrencium', number=103, symbol='Lr', weight=262
)
Rf: Final[Element] = Element(
    name='Rutherfordium', number=104, symbol='Rf', weight=267
)
Db: Final[Element] = Element(
    name='Dubnium', number=105, symbol='Db', weight=268
)
Sg: Final[Element] = Element(
    name='Seaborgium', number=106, symbol='Sg', weight=269
)
Bh: Final[Element] = Element(
    name='Bohrium', number=107, symbol='Bh', weight=270
)
Hs: Final[Element] = Element(
    name='Hassium', number=108, symbol='Hs', weight=269
)
Mt: Final[Element] = Element(
    name='Meitnerium', number=109, symbol='Mt', weight=278
)
Ds: Final[Element] = Element(
    name='Darmstadtium', number=110, symbol='Ds', weight=281
)
Rg: Final[Element] = Element(
    name='Roentgenium', number=111, symbol='Rg', weight=281
)
Cn: Final[Element] = Element(
    name='Copernicium', number=112, symbol='Cn', weight=285
)
Nh: Final[Element] = Element(
    name='Nihonium', number=113, symbol='Nh', weight=284
)
Fl: Final[Element] = Element(
    name='Flerovium', number=114, symbol='Fl', weight=289
)
Mc: Final[Element] = Element(
    name='Moscovium', number=115, symbol='Mc', weight=288
)
Lv: Final[Element] = Element(
    name='Livermorium', number=116, symbol='Lv', weight=293
)
Ts: Final[Element] = Element(
    name='Tennessine', number=117, symbol='Ts', weight=292
)
Og: Final[Element] = Element(
    name='Oganesson', number=118, symbol='Og', weight=294
)
# pylint: enable=invalid-name

# fmt: off
# Lanthanides (La-Lu, atomic numbers 57-71), spliced into period 6 below.
_L: Final[Sequence[Element]] = (
    La, Ce, Pr, Nd, Pm, Sm, Eu, Gd, Tb, Dy, Ho, Er, Tm, Yb, Lu)
# Actinides (Ac-Lr, atomic numbers 89-103), spliced into period 7 below.
_A: Final[Sequence[Element]] = (
    Ac, Th, Pa, U, Np, Pu, Am, Cm, Bk, Cf, Es, Fm, Md, No, Lr)
# pylint: disable=bad-whitespace
# All elements in order of atomic number, one period per line, so that
# PERIODIC_TABLE[i].number == i (index 0 holds the unknown element X).
PERIODIC_TABLE: Final[Sequence[Element]] = (
    X,  # Unknown
    H, He,
    Li, Be, B, C, N, O, F, Ne,
    Na, Mg, Al, Si, P, S, Cl, Ar,
    K, Ca, Sc, Ti, V, Cr, Mn, Fe, Co, Ni, Cu, Zn, Ga, Ge, As, Se, Br, Kr,
    Rb, Sr, Y, Zr, Nb, Mo, Tc, Ru, Rh, Pd, Ag, Cd, In, Sn, Sb, Te, I, Xe,
    Cs, Ba, *_L, Hf, Ta, W, Re, Os, Ir, Pt, Au, Hg, Tl, Pb, Bi, Po, At, Rn,
    Fr, Ra, *_A, Rf, Db, Sg, Bh, Hs, Mt, Ds, Rg, Cn, Nh, Fl, Mc, Lv, Ts, Og
)
# pylint: enable=bad-whitespace
# fmt: on
# Atomic number -> element symbol (0 maps to 'X', the unknown element).
ATOMIC_SYMBOL: Mapping[int, str] = {e.number: e.symbol for e in PERIODIC_TABLE}

# Element symbol -> atomic number. Deuterium ('D') is kept as an alias of
# hydrogen because the previous version of this table contained it.
ATOMIC_NUMBER: Mapping[str, int] = {
    **{e.symbol: e.number for e in PERIODIC_TABLE},
    'D': 1,
}

# Atomic weights indexed by atomic number, made read-only so the shared table
# cannot be modified by accident. Filled with a fancy-index assignment (the
# comprehension variables do not leak) instead of a module-level `for` loop,
# which left a stray module attribute `e` behind.
ATOMIC_WEIGHT: np.ndarray = np.zeros(len(PERIODIC_TABLE), dtype=np.float64)
ATOMIC_WEIGHT[[e.number for e in PERIODIC_TABLE]] = [
    e.weight for e in PERIODIC_TABLE
]
ATOMIC_WEIGHT.setflags(write=False)
================================================
FILE: src/alphafold3/constants/residue_names.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Constants associated with residue names."""
from collections.abc import Mapping
import functools
import sys
# pyformat: disable
# common_typos_disable
# Maps chemical component dictionary (CCD) names to a one-letter residue code.
# Keys include standard residues ('ALA', 'A', 'DA') as well as modified
# components (e.g. 'MSE' -> 'M'). A one-letter code does not identify the
# polymer type on its own: 'CYS' and 'DC' both map to 'C'. Use
# letters_three_to_one() for lookups with a fallback value.
CCD_NAME_TO_ONE_LETTER: Mapping[str, str] = {
    '00C': 'C', '01W': 'X', '02K': 'A', '03Y': 'C', '07O': 'C', '08P': 'C',
    '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0AA': 'V', '0AB': 'V',
    '0AC': 'G', '0AD': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D',
    '0AM': 'A', '0AP': 'C', '0AU': 'U', '0AV': 'A', '0AZ': 'P', '0BN': 'F',
    '0C': 'C', '0CS': 'A', '0DC': 'C', '0DG': 'G', '0DT': 'T', '0FL': 'A',
    '0G': 'G', '0NC': 'A', '0SP': 'A', '0U': 'U', '10C': 'C', '125': 'U',
    '126': 'U', '127': 'U', '128': 'N', '12A': 'A', '143': 'C', '193': 'X',
    '1AP': 'A', '1MA': 'A', '1MG': 'G', '1PA': 'F', '1PI': 'A', '1PR': 'N',
    '1SC': 'C', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '200': 'F', '23F': 'F',
    '23S': 'X', '26B': 'T', '2AD': 'X', '2AG': 'A', '2AO': 'X', '2AR': 'A',
    '2AS': 'X', '2AT': 'T', '2AU': 'U', '2BD': 'I', '2BT': 'T', '2BU': 'A',
    '2CO': 'C', '2DA': 'A', '2DF': 'N', '2DM': 'N', '2DO': 'X', '2DT': 'T',
    '2EG': 'G', '2FE': 'N', '2FI': 'N', '2FM': 'M', '2GT': 'T', '2HF': 'H',
    '2LU': 'L', '2MA': 'A', '2MG': 'G', '2ML': 'L', '2MR': 'R', '2MT': 'P',
    '2MU': 'U', '2NT': 'T', '2OM': 'U', '2OT': 'T', '2PI': 'X', '2PR': 'G',
    '2SA': 'N', '2SI': 'X', '2ST': 'T', '2TL': 'T', '2TY': 'Y', '2VA': 'V',
    '2XA': 'C', '32S': 'X', '32T': 'X', '3AH': 'H', '3AR': 'X', '3CF': 'F',
    '3DA': 'A', '3DR': 'N', '3GA': 'A', '3MD': 'D', '3ME': 'U', '3NF': 'Y',
    '3QN': 'K', '3TY': 'X', '3XH': 'G', '4AC': 'N', '4BF': 'Y', '4CF': 'F',
    '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HT': 'W', '4IN': 'W',
    '4MF': 'N', '4MM': 'X', '4OC': 'C', '4PC': 'C', '4PD': 'C', '4PE': 'C',
    '4PH': 'F', '4SC': 'C', '4SU': 'U', '4TA': 'N', '4U7': 'A', '56A': 'H',
    '5AA': 'A', '5AB': 'A', '5AT': 'T', '5BU': 'U', '5CG': 'G', '5CM': 'C',
    '5CS': 'C', '5FA': 'A', '5FC': 'C', '5FU': 'U', '5HP': 'E', '5HT': 'T',
    '5HU': 'U', '5IC': 'C', '5IT': 'T', '5IU': 'U', '5MC': 'C', '5MD': 'N',
    '5MU': 'U', '5NC': 'C', '5PC': 'C', '5PY': 'T', '5SE': 'U', '64T': 'T',
    '6CL': 'K', '6CT': 'T', '6CW': 'W', '6HA': 'A', '6HC': 'C', '6HG': 'G',
    '6HN': 'K', '6HT': 'T', '6IA': 'A', '6MA': 'A', '6MC': 'A', '6MI': 'N',
    '6MT': 'A', '6MZ': 'N', '6OG': 'G', '70U': 'U', '7DA': 'A', '7GU': 'G',
    '7JA': 'I', '7MG': 'G', '8AN': 'A', '8FG': 'G', '8MG': 'G', '8OG': 'G',
    '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', 'A': 'A', 'A1P': 'N',
    'A23': 'A', 'A2L': 'A', 'A2M': 'A', 'A34': 'A', 'A35': 'A', 'A38': 'A',
    'A39': 'A', 'A3A': 'A', 'A3P': 'A', 'A40': 'A', 'A43': 'A', 'A44': 'A',
    'A47': 'A', 'A5L': 'A', 'A5M': 'C', 'A5N': 'N', 'A5O': 'A', 'A66': 'X',
    'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'AB7': 'X', 'ABA': 'A', 'ABR': 'A',
    'ABS': 'A', 'ABT': 'N', 'ACB': 'D', 'ACL': 'R', 'AD2': 'A', 'ADD': 'X',
    'ADX': 'N', 'AEA': 'X', 'AEI': 'D', 'AET': 'A', 'AFA': 'N', 'AFF': 'N',
    'AFG': 'G', 'AGM': 'R', 'AGT': 'C', 'AHB': 'N', 'AHH': 'X', 'AHO': 'A',
    'AHP': 'A', 'AHS': 'X', 'AHT': 'X', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D',
    'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALQ': 'X',
    'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AN8': 'A', 'AP7': 'A',
    'APE': 'X', 'APH': 'A', 'API': 'K', 'APK': 'K', 'APM': 'X', 'APP': 'X',
    'AR2': 'R', 'AR4': 'E', 'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R',
    'ARV': 'X', 'AS': 'A', 'AS2': 'D', 'AS9': 'X', 'ASA': 'D', 'ASB': 'D',
    'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASM': 'X', 'ASN': 'N', 'ASP': 'D',
    'ASQ': 'D', 'ASU': 'N', 'ASX': 'B', 'ATD': 'T', 'ATL': 'T', 'ATM': 'T',
    'AVC': 'A', 'AVN': 'X', 'AYA': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y',
    'B1F': 'F', 'B1P': 'N', 'B2A': 'A', 'B2F': 'F', 'B2I': 'I', 'B2V': 'V',
    'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3L': 'X', 'B3M': 'X',
    'B3Q': 'X', 'B3S': 'S', 'B3T': 'X', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y',
    'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C',
    'BE2': 'X', 'BFD': 'D', 'BG1': 'S', 'BGM': 'G', 'BH2': 'D', 'BHD': 'D',
    'BIF': 'F', 'BIL': 'X', 'BIU': 'I', 'BJH': 'X', 'BLE': 'L', 'BLY': 'K',
    'BMP': 'N', 'BMT': 'T', 'BNN': 'F', 'BNO': 'X', 'BOE': 'T', 'BOR': 'R',
    'BPE': 'C', 'BRU': 'U', 'BSE': 'S', 'BT5': 'N', 'BTA': 'L', 'BTC': 'C',
    'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'BVP': 'U', 'BZG': 'N', 'C': 'C',
    'C1X': 'K', 'C25': 'C', 'C2L': 'C', 'C2S': 'C', 'C31': 'C', 'C32': 'C',
    'C34': 'C', 'C36': 'C', 'C37': 'C', 'C38': 'C', 'C3Y': 'C', 'C42': 'C',
    'C43': 'C', 'C45': 'C', 'C46': 'C', 'C49': 'C', 'C4R': 'C', 'C4S': 'C',
    'C5C': 'C', 'C66': 'X', 'C6C': 'C', 'CAF': 'C', 'CAL': 'X', 'CAR': 'C',
    'CAS': 'C', 'CAV': 'X', 'CAY': 'C', 'CB2': 'C', 'CBR': 'C', 'CBV': 'C',
    'CCC': 'C', 'CCL': 'K', 'CCS': 'C', 'CDE': 'X', 'CDV': 'X', 'CDW': 'C',
    'CEA': 'C', 'CFL': 'C', 'CG1': 'G', 'CGA': 'E', 'CGU': 'E', 'CH': 'C',
    'CHF': 'X', 'CHG': 'X', 'CHP': 'G', 'CHS': 'X', 'CIR': 'R', 'CLE': 'L',
    'CLG': 'K', 'CLH': 'K', 'CM0': 'N', 'CME': 'C', 'CMH': 'C', 'CML': 'C',
    'CMR': 'C', 'CMT': 'C', 'CNU': 'U', 'CP1': 'C', 'CPC': 'X', 'CPI': 'X',
    'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CS8': 'N',
    'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSF': 'C', 'CSI': 'G',
    'CSJ': 'C', 'CSL': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C',
    'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTG': 'T',
    'CTH': 'T', 'CUC': 'X', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C',
    'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C',
    'CYJ': 'X', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CZ2': 'C',
    'CZZ': 'C', 'D11': 'T', 'D1P': 'N', 'D3': 'N', 'D33': 'N', 'D3P': 'G',
    'D3T': 'T', 'D4M': 'T', 'D4P': 'X', 'DA': 'A', 'DA2': 'X', 'DAB': 'A',
    'DAH': 'F', 'DAL': 'A', 'DAR': 'R', 'DAS': 'D', 'DBB': 'T', 'DBM': 'N',
    'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC': 'C', 'DC2': 'C',
    'DCG': 'G', 'DCI': 'X', 'DCL': 'X', 'DCT': 'C', 'DCY': 'C', 'DDE': 'H',
    'DDG': 'G', 'DDN': 'U', 'DDX': 'N', 'DFC': 'C', 'DFG': 'G', 'DFI': 'X',
    'DFO': 'X', 'DFT': 'N', 'DG': 'G', 'DGH': 'G', 'DGI': 'G', 'DGL': 'E',
    'DGN': 'Q', 'DHA': 'S', 'DHI': 'H', 'DHL': 'X', 'DHN': 'V', 'DHP': 'X',
    'DHU': 'U', 'DHV': 'V', 'DI': 'I', 'DIL': 'I', 'DIR': 'R', 'DIV': 'V',
    'DLE': 'L', 'DLS': 'K', 'DLY': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D',
    'DMT': 'X', 'DN': 'N', 'DNE': 'L', 'DNG': 'L', 'DNL': 'K', 'DNM': 'L',
    'DNP': 'A', 'DNR': 'C', 'DNS': 'K', 'DOA': 'X', 'DOC': 'C', 'DOH': 'D',
    'DON': 'L', 'DPB': 'T', 'DPH': 'F', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y',
    'DPR': 'P', 'DPY': 'N', 'DRM': 'U', 'DRP': 'N', 'DRT': 'T', 'DRZ': 'N',
    'DSE': 'S', 'DSG': 'N', 'DSN': 'S', 'DSP': 'D', 'DT': 'T', 'DTH': 'T',
    'DTR': 'W', 'DTY': 'Y', 'DU': 'U', 'DVA': 'V', 'DXD': 'N', 'DXN': 'N',
    'DYS': 'C', 'DZM': 'A', 'E': 'A', 'E1X': 'A', 'ECC': 'Q', 'EDA': 'A',
    'EFC': 'C', 'EHP': 'F', 'EIT': 'T', 'ENP': 'N', 'ESB': 'Y', 'ESC': 'M',
    'EXB': 'X', 'EXY': 'L', 'EY5': 'N', 'EYS': 'X', 'F2F': 'F', 'FA2': 'A',
    'FA5': 'N', 'FAG': 'N', 'FAI': 'N', 'FB5': 'A', 'FB6': 'A', 'FCL': 'F',
    'FFD': 'N', 'FGA': 'E', 'FGL': 'G', 'FGP': 'S', 'FHL': 'X', 'FHO': 'K',
    'FHU': 'U', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FMG': 'G',
    'FMU': 'N', 'FOE': 'C', 'FOX': 'G', 'FP9': 'P', 'FPA': 'F', 'FRD': 'X',
    'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'G': 'G',
    'G25': 'G', 'G2L': 'G', 'G2S': 'G', 'G31': 'G', 'G32': 'G', 'G33': 'G',
    'G36': 'G', 'G38': 'G', 'G42': 'G', 'G46': 'G', 'G47': 'G', 'G48': 'G',
    'G49': 'G', 'G4P': 'N', 'G7M': 'G', 'GAO': 'G', 'GAU': 'E', 'GCK': 'C',
    'GCM': 'X', 'GDP': 'G', 'GDR': 'G', 'GFL': 'G', 'GGL': 'E', 'GH3': 'G',
    'GHG': 'Q', 'GHP': 'G', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E',
    'GLM': 'X', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLX': 'Z', 'GLY': 'G',
    'GLZ': 'G', 'GMA': 'E', 'GMS': 'G', 'GMU': 'U', 'GN7': 'G', 'GND': 'X',
    'GNE': 'N', 'GOM': 'G', 'GPL': 'K', 'GS': 'G', 'GSC': 'G', 'GSR': 'G',
    'GSS': 'G', 'GSU': 'E', 'GT9': 'C', 'GTP': 'G', 'GVL': 'X', 'H2U': 'U',
    'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCS': 'X', 'HDP': 'U',
    'HEU': 'U', 'HFA': 'X', 'HGL': 'X', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H',
    'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R',
    'HOL': 'N', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A',
    'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSL': 'S',
    'HSO': 'H', 'HTI': 'C', 'HTN': 'N', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V',
    'HY3': 'P', 'HYP': 'P', 'HZP': 'P', 'I': 'I', 'I2M': 'I', 'I58': 'K',
    'I5C': 'C', 'IAM': 'A', 'IAR': 'R', 'IAS': 'D', 'IC': 'C', 'IEL': 'K',
    'IG': 'G', 'IGL': 'G', 'IGU': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E',
    'ILX': 'I', 'IMC': 'C', 'IML': 'I', 'IOY': 'F', 'IPG': 'G', 'IPN': 'N',
    'IRN': 'N', 'IT1': 'K', 'IU': 'U', 'IYR': 'Y', 'IYT': 'T', 'IZO': 'M',
    'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JW5': 'N', 'K1R': 'C', 'KAG': 'G',
    'KCX': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KST': 'K',
    'KYQ': 'K', 'L2A': 'X', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K',
    'LC': 'C', 'LCA': 'A', 'LCC': 'N', 'LCG': 'G', 'LCH': 'N', 'LCK': 'K',
    'LCX': 'K', 'LDH': 'K', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEI': 'V',
    'LEM': 'L', 'LEN': 'L', 'LET': 'X', 'LEU': 'L', 'LEX': 'L', 'LG': 'G',
    'LGP': 'G', 'LHC': 'X', 'LHU': 'U', 'LKC': 'N', 'LLP': 'K', 'LLY': 'K',
    'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LMS': 'N', 'LP6': 'K', 'LPD': 'P',
    'LPG': 'G', 'LPL': 'X', 'LPS': 'S', 'LSO': 'X', 'LTA': 'X', 'LTR': 'W',
    'LVG': 'G', 'LVN': 'V', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K',
    'LYR': 'K', 'LYS': 'K', 'LYX': 'K', 'LYZ': 'K', 'M0H': 'C', 'M1G': 'G',
    'M2G': 'G', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M5M': 'C',
    'MA': 'A', 'MA6': 'A', 'MA7': 'A', 'MAA': 'A', 'MAD': 'A', 'MAI': 'R',
    'MBQ': 'Y', 'MBZ': 'N', 'MC1': 'S', 'MCG': 'X', 'MCL': 'K', 'MCS': 'C',
    'MCY': 'C', 'MD3': 'C', 'MD6': 'G', 'MDH': 'X', 'MDR': 'N', 'MEA': 'F',
    'MED': 'M', 'MEG': 'E', 'MEN': 'N', 'MEP': 'U', 'MEQ': 'Q', 'MET': 'M',
    'MEU': 'G', 'MF3': 'X', 'MG1': 'G', 'MGG': 'R', 'MGN': 'Q', 'MGQ': 'A',
    'MGV': 'G', 'MGY': 'G', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MIA': 'A',
    'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K',
    'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MMT': 'T', 'MND': 'N', 'MNL': 'L',
    'MNU': 'U', 'MNV': 'V', 'MOD': 'X', 'MP8': 'P', 'MPH': 'X', 'MPJ': 'X',
    'MPQ': 'G', 'MRG': 'G', 'MSA': 'G', 'MSE': 'M', 'MSL': 'M', 'MSO': 'M',
    'MSP': 'X', 'MT2': 'M', 'MTR': 'T', 'MTU': 'A', 'MTY': 'Y', 'MVA': 'V',
    'N': 'N', 'N10': 'S', 'N2C': 'X', 'N5I': 'N', 'N5M': 'C', 'N6G': 'G',
    'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NB8': 'N', 'NBQ': 'Y',
    'NC1': 'S', 'NCB': 'A', 'NCX': 'N', 'NCY': 'X', 'NDF': 'F', 'NDN': 'U',
    'NEM': 'H', 'NEP': 'H', 'NF2': 'N', 'NFA': 'F', 'NHL': 'E', 'NIT': 'X',
    'NIY': 'Y', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q',
    'NMC': 'G', 'NMM': 'R', 'NMS': 'T', 'NMT': 'T', 'NNH': 'R', 'NP3': 'N',
    'NPH': 'C', 'NPI': 'A', 'NSK': 'X', 'NTY': 'Y', 'NVA': 'V', 'NYM': 'N',
    'NYS': 'C', 'NZH': 'H', 'O12': 'X', 'O2C': 'N', 'O2G': 'G', 'OAD': 'N',
    'OAS': 'S', 'OBF': 'X', 'OBS': 'X', 'OCS': 'C', 'OCY': 'C', 'ODP': 'N',
    'OHI': 'H', 'OHS': 'D', 'OIC': 'X', 'OIP': 'I', 'OLE': 'X', 'OLT': 'T',
    'OLZ': 'S', 'OMC': 'C', 'OMG': 'G', 'OMT': 'M', 'OMU': 'U', 'ONE': 'U',
    'ONH': 'A', 'ONL': 'X', 'OPR': 'R', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S',
    'OTB': 'X', 'OTH': 'T', 'OTY': 'Y', 'OXX': 'D', 'P': 'G', 'P1L': 'C',
    'P1P': 'N', 'P2T': 'T', 'P2U': 'U', 'P2Y': 'P', 'P5P': 'A', 'PAQ': 'Y',
    'PAS': 'D', 'PAT': 'W', 'PAU': 'A', 'PBB': 'C', 'PBF': 'F', 'PBT': 'N',
    'PCA': 'E', 'PCC': 'P', 'PCE': 'X', 'PCS': 'F', 'PDL': 'X', 'PDU': 'U',
    'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PFX': 'X', 'PG1': 'S', 'PG7': 'G',
    'PG9': 'G', 'PGL': 'X', 'PGN': 'G', 'PGP': 'G', 'PGY': 'G', 'PHA': 'F',
    'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PIV': 'X',
    'PLE': 'L', 'PM3': 'F', 'PMT': 'C', 'POM': 'P', 'PPN': 'F', 'PPU': 'A',
    'PPW': 'G', 'PQ1': 'N', 'PR3': 'C', 'PR5': 'A', 'PR9': 'P', 'PRN': 'A',
    'PRO': 'P', 'PRS': 'P', 'PSA': 'F', 'PSH': 'H', 'PST': 'T', 'PSU': 'U',
    'PSW': 'C', 'PTA': 'X', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PU': 'A',
    'PUY': 'N', 'PVH': 'H', 'PVL': 'X', 'PYA': 'A', 'PYO': 'U', 'PYX': 'C',
    'PYY': 'N', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'QUO': 'G', 'R': 'A',
    'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RIA': 'A', 'RMP': 'A',
    'RON': 'X', 'RT': 'T', 'RTP': 'N', 'S1H': 'S', 'S2C': 'C', 'S2D': 'A',
    'S2M': 'T', 'S2P': 'A', 'S4A': 'A', 'S4C': 'C', 'S4G': 'G', 'S4U': 'U',
    'S6G': 'G', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBL': 'S', 'SC': 'C',
    'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD2': 'X', 'SDG': 'G', 'SDP': 'S',
    'SEB': 'S', 'SEC': 'A', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S',
    'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G',
    'SHR': 'K', 'SIB': 'C', 'SLA': 'P', 'SLR': 'P', 'SLZ': 'K', 'SMC': 'C',
    'SME': 'M', 'SMF': 'F', 'SMP': 'A', 'SMT': 'T', 'SNC': 'C', 'SNN': 'N',
    'SOC': 'C', 'SOS': 'N', 'SOY': 'S', 'SPT': 'T', 'SRA': 'A', 'SSU': 'U',
    'STY': 'Y', 'SUB': 'X', 'SUN': 'S', 'SUR': 'U', 'SVA': 'S', 'SVV': 'S',
    'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'X', 'SYS': 'C', 'T': 'T',
    'T11': 'F', 'T23': 'T', 'T2S': 'T', 'T2T': 'N', 'T31': 'U', 'T32': 'T',
    'T36': 'T', 'T37': 'T', 'T38': 'T', 'T39': 'T', 'T3P': 'T', 'T41': 'T',
    'T48': 'T', 'T49': 'T', 'T4S': 'T', 'T5O': 'U', 'T5S': 'T', 'T66': 'X',
    'T6A': 'A', 'TA3': 'T', 'TA4': 'X', 'TAF': 'T', 'TAL': 'N', 'TAV': 'D',
    'TBG': 'V', 'TBM': 'T', 'TC1': 'C', 'TCP': 'T', 'TCQ': 'Y', 'TCR': 'W',
    'TCY': 'A', 'TDD': 'L', 'TDY': 'T', 'TFE': 'T', 'TFO': 'A', 'TFQ': 'F',
    'TFT': 'T', 'TGP': 'G', 'TH6': 'T', 'THC': 'T', 'THO': 'X', 'THR': 'T',
    'THX': 'N', 'THZ': 'R', 'TIH': 'A', 'TLB': 'N', 'TLC': 'T', 'TLN': 'U',
    'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TOX': 'W', 'TP1': 'T',
    'TPC': 'C', 'TPG': 'G', 'TPH': 'X', 'TPL': 'W', 'TPO': 'T', 'TPQ': 'Y',
    'TQI': 'W', 'TQQ': 'W', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W',
    'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TS': 'N', 'TST': 'X',
    'TT': 'N', 'TTD': 'T', 'TTI': 'U', 'TTM': 'T', 'TTQ': 'W', 'TTS': 'Y',
    'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TYB': 'Y', 'TYI': 'Y',
    'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y',
    'TYT': 'Y', 'TYU': 'N', 'TYW': 'Y', 'TYX': 'X', 'TYY': 'Y', 'TZB': 'X',
    'TZO': 'X', 'U': 'U', 'U25': 'U', 'U2L': 'U', 'U2N': 'U', 'U2P': 'U',
    'U31': 'U', 'U33': 'U', 'U34': 'U', 'U36': 'U', 'U37': 'U', 'U8U': 'U',
    'UAR': 'U', 'UCL': 'U', 'UD5': 'U', 'UDP': 'N', 'UFP': 'N', 'UFR': 'U',
    'UFT': 'U', 'UMA': 'A', 'UMP': 'U', 'UMS': 'U', 'UN1': 'X', 'UN2': 'X',
    'UNK': 'X', 'UR3': 'U', 'URD': 'U', 'US1': 'U', 'US2': 'U', 'US3': 'T',
    'US5': 'U', 'USM': 'U', 'VAD': 'V', 'VAF': 'V', 'VAL': 'V', 'VB1': 'K',
    'VDL': 'X', 'VLL': 'X', 'VLM': 'X', 'VMS': 'X', 'VOL': 'X', 'X': 'G',
    'X2W': 'E', 'X4A': 'N', 'XAD': 'A', 'XAE': 'N', 'XAL': 'A', 'XAR': 'N',
    'XCL': 'C', 'XCN': 'C', 'XCP': 'X', 'XCR': 'C', 'XCS': 'N', 'XCT': 'C',
    'XCY': 'C', 'XGA': 'N', 'XGL': 'G', 'XGR': 'G', 'XGU': 'G', 'XPR': 'P',
    'XSN': 'N', 'XTH': 'T', 'XTL': 'T', 'XTR': 'T', 'XTS': 'G', 'XTY': 'N',
    'XUA': 'A', 'XUG': 'G', 'XX1': 'K', 'Y': 'A', 'YCM': 'C', 'YG': 'G',
    'YOF': 'Y', 'YRR': 'N', 'YYG': 'G', 'Z': 'C', 'Z01': 'A', 'ZAD': 'A',
    'ZAL': 'A', 'ZBC': 'C', 'ZBU': 'U', 'ZCL': 'F', 'ZCY': 'C', 'ZDU': 'U',
    'ZFB': 'X', 'ZGU': 'G', 'ZHP': 'N', 'ZTH': 'T', 'ZU0': 'T', 'ZZJ': 'A',
}
# common_typos_enable
# pyformat: enable
@functools.lru_cache(maxsize=64)
def letters_three_to_one(restype: str, *, default: str) -> str:
  """Looks up the one-letter code of a CCD residue name.

  Results are memoised with functools.lru_cache (at most 64 entries).

  Args:
    restype: CCD component name, e.g. 'ALA', 'MSE' or 'DA'.
    default: Value to return when `restype` has no entry in
      CCD_NAME_TO_ONE_LETTER.

  Returns:
    The mapped one-letter code, or `default` when `restype` is not mapped.
  """
  if restype in CCD_NAME_TO_ONE_LETTER:
    return CCD_NAME_TO_ONE_LETTER[restype]
  return default
# Three-letter CCD names of the 20 standard amino acids. sys.intern makes every
# use of these constants share one string object.
ALA = sys.intern('ALA')
ARG = sys.intern('ARG')
ASN = sys.intern('ASN')
ASP = sys.intern('ASP')
CYS = sys.intern('CYS')
GLN = sys.intern('GLN')
GLU = sys.intern('GLU')
GLY = sys.intern('GLY')
HIS = sys.intern('HIS')
ILE = sys.intern('ILE')
LEU = sys.intern('LEU')
LYS = sys.intern('LYS')
MET = sys.intern('MET')
PHE = sys.intern('PHE')
PRO = sys.intern('PRO')
SER = sys.intern('SER')
THR = sys.intern('THR')
TRP = sys.intern('TRP')
TYR = sys.intern('TYR')
VAL = sys.intern('VAL')
# Unknown amino acid.
UNK = sys.intern('UNK')
# Gap character; appended to some of the residue vocabularies in this module.
GAP = sys.intern('-')
# Unknown ligand.
UNL = sys.intern('UNL')
# Non-standard version of MET (with Se instead of S), but often appears in PDB.
MSE = sys.intern('MSE')
# 20 standard protein amino acids (no unknown).
PROTEIN_TYPES: tuple[str, ...] = (
    ALA, ARG, ASN, ASP, CYS, GLN, GLU, GLY, HIS, ILE, LEU, LYS, MET, PHE, PRO,
    SER, THR, TRP, TYR, VAL,
)  # pyformat: disable
# 20 standard protein amino acids plus the unknown (UNK) amino acid.
PROTEIN_TYPES_WITH_UNKNOWN: tuple[str, ...] = PROTEIN_TYPES + (UNK,)
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
# For legacy reasons this only refers to protein residues.
# Position i here is the one-letter code of PROTEIN_TYPES[i].
PROTEIN_TYPES_ONE_LETTER: tuple[str, ...] = (
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
    'S', 'T', 'W', 'Y', 'V',
)  # pyformat: disable
# One-letter vocabulary plus 'X' for an unknown residue (21 entries).
PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = (
    PROTEIN_TYPES_ONE_LETTER + ('X',)
)
# As above, plus the gap character '-' (22 entries).
PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP: tuple[str, ...] = (
    PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN + (GAP,)
)
# Letter -> integer code for each protein one-letter vocabulary. A letter's code
# is its position in the corresponding tuple.
PROTEIN_TYPES_ONE_LETTER_TO_INT: Mapping[str, int] = dict(
    zip(PROTEIN_TYPES_ONE_LETTER, range(len(PROTEIN_TYPES_ONE_LETTER)))
)
PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = dict(
    zip(
        PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN,
        range(len(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN)),
    )
)
PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT: Mapping[str, int] = dict(
    zip(
        PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP,
        range(len(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP)),
    )
)
# Standard one-letter amino-acid code -> interned three-letter CCD name.
PROTEIN_COMMON_ONE_TO_THREE: Mapping[str, str] = dict(
    A=ALA, R=ARG, N=ASN, D=ASP, C=CYS, Q=GLN, E=GLU, G=GLY, H=HIS, I=ILE,
    L=LEU, K=LYS, M=MET, F=PHE, P=PRO, S=SER, T=THR, W=TRP, Y=TYR, V=VAL,
)
# Inverse of PROTEIN_COMMON_ONE_TO_THREE: three-letter name -> one-letter code.
PROTEIN_COMMON_THREE_TO_ONE: Mapping[str, str] = {
    three_letter: one_letter
    for one_letter, three_letter in PROTEIN_COMMON_ONE_TO_THREE.items()
}
# One-letter nucleotide names. A, G, C and U form the RNA alphabet; T is only
# used by the one-letter DNA alphabet (DNA_TYPES_ONE_LETTER).
A = sys.intern('A')
G = sys.intern('G')
C = sys.intern('C')
U = sys.intern('U')
T = sys.intern('T')
# Two-letter CCD names of the four standard DNA nucleotides.
DA = sys.intern('DA')
DG = sys.intern('DG')
DC = sys.intern('DC')
DT = sys.intern('DT')
UNK_NUCLEIC_ONE_LETTER = sys.intern('N')  # Unknown nucleic acid single letter.
UNK_RNA = sys.intern('N')  # Unknown RNA (same string as UNK_NUCLEIC_ONE_LETTER).
UNK_DNA = sys.intern('DN')  # Unknown DNA residue (differs from N).
RNA_TYPES: tuple[str, ...] = (A, G, C, U)
DNA_TYPES: tuple[str, ...] = (DA, DG, DC, DT)
# RNA names followed by DNA names (8 entries).
NUCLEIC_TYPES: tuple[str, ...] = RNA_TYPES + DNA_TYPES
# Without UNK DNA.
NUCLEIC_TYPES_WITH_UNKNOWN: tuple[str, ...] = NUCLEIC_TYPES + (
    UNK_NUCLEIC_ONE_LETTER,
)
# With both unknown tokens: 'N' (unknown RNA) and 'DN' (unknown DNA).
NUCLEIC_TYPES_WITH_2_UNKS: tuple[str, ...] = NUCLEIC_TYPES + (
    UNK_RNA,
    UNK_DNA,
)
RNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = RNA_TYPES + (UNK_RNA,)
# RNA letter -> position in RNA_TYPES_ONE_LETTER_WITH_UNKNOWN.
RNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = {
    r: i for i, r in enumerate(RNA_TYPES_ONE_LETTER_WITH_UNKNOWN)
}
DNA_TYPES_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES + (UNK_DNA,)
# One-letter DNA alphabet; reuses the interned A/G/C strings also used for RNA.
DNA_TYPES_ONE_LETTER: tuple[str, ...] = (A, G, C, T)
DNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES_ONE_LETTER + (
    UNK_NUCLEIC_ONE_LETTER,
)
# DNA letter -> position in DNA_TYPES_ONE_LETTER_WITH_UNKNOWN.
DNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = {
    r: i for i, r in enumerate(DNA_TYPES_ONE_LETTER_WITH_UNKNOWN)
}
# One-letter DNA code -> two-letter CCD name.
DNA_COMMON_ONE_TO_TWO: Mapping[str, str] = {
    'A': 'DA',
    'G': 'DG',
    'C': 'DC',
    'T': 'DT',
}
# Polymer residue vocabularies: protein names first, then nucleic names. Each
# constant's suffix says which unknown / gap tokens it contains.
STANDARD_POLYMER_TYPES: tuple[str, ...] = (*PROTEIN_TYPES, *NUCLEIC_TYPES)
POLYMER_TYPES: tuple[str, ...] = (*PROTEIN_TYPES_WITH_UNKNOWN, *NUCLEIC_TYPES)
POLYMER_TYPES_WITH_UNKNOWN: tuple[str, ...] = (
    *PROTEIN_TYPES_WITH_UNKNOWN,
    *NUCLEIC_TYPES_WITH_UNKNOWN,
)
POLYMER_TYPES_WITH_GAP: tuple[str, ...] = (*PROTEIN_TYPES, GAP, *NUCLEIC_TYPES)
POLYMER_TYPES_WITH_UNKNOWN_AND_GAP: tuple[str, ...] = (
    *PROTEIN_TYPES_WITH_UNKNOWN,
    GAP,
    *NUCLEIC_TYPES_WITH_UNKNOWN,
)
POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP: tuple[str, ...] = (
    *PROTEIN_TYPES_WITH_UNKNOWN,
    GAP,
    *NUCLEIC_TYPES_WITH_2_UNKS,
)

# Residue name -> integer index within the matching vocabulary above.
POLYMER_TYPES_ORDER = dict(zip(POLYMER_TYPES, range(len(POLYMER_TYPES))))
POLYMER_TYPES_ORDER_WITH_UNKNOWN = dict(
    zip(POLYMER_TYPES_WITH_UNKNOWN, range(len(POLYMER_TYPES_WITH_UNKNOWN)))
)
POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP = dict(
    zip(
        POLYMER_TYPES_WITH_UNKNOWN_AND_GAP,
        range(len(POLYMER_TYPES_WITH_UNKNOWN_AND_GAP)),
    )
)
POLYMER_TYPES_ORDER_WITH_ALL_UNKS_AND_GAP = dict(
    zip(
        POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP,
        range(len(POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP)),
    )
)

# Vocabulary sizes (values in the trailing comments).
POLYMER_TYPES_NUM = len(POLYMER_TYPES)  # == 29 (21 protein + 8 nucleic).
POLYMER_TYPES_NUM_WITH_UNKNOWN = len(POLYMER_TYPES_WITH_UNKNOWN)  # == 30.
POLYMER_TYPES_NUM_WITH_GAP = len(POLYMER_TYPES_WITH_GAP)  # == 29.
POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP = len(
    POLYMER_TYPES_WITH_UNKNOWN_AND_GAP
)  # == 31.
POLYMER_TYPES_NUM_ORDER_WITH_ALL_UNKS_AND_GAP = len(
    POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP
)  # == 32.

# CCD names treated as water.
WATER_TYPES: tuple[str, ...] = ('HOH', 'DOD')
# Every "unknown" residue token: protein, RNA, DNA and ligand.
UNKNOWN_TYPES: tuple[str, ...] = (UNK, UNK_RNA, UNK_DNA, UNL)
================================================
FILE: src/alphafold3/constants/side_chains.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Constants associated with side chains."""
from collections.abc import Mapping, Sequence
import itertools
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
# Residue name -> ordered list of chi angles. Each chi angle is given as the
# four atom names that define that side-chain torsion.
CHI_ANGLES_ATOMS: Mapping[str, Sequence[tuple[str, ...]]] = {
    'ALA': [],
    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
    'ARG': [
        ('N', 'CA', 'CB', 'CG'),
        ('CA', 'CB', 'CG', 'CD'),
        ('CB', 'CG', 'CD', 'NE'),
        ('CG', 'CD', 'NE', 'CZ'),
    ],
    'ASN': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')],
    'ASP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')],
    'CYS': [('N', 'CA', 'CB', 'SG')],
    'GLN': [
        ('N', 'CA', 'CB', 'CG'),
        ('CA', 'CB', 'CG', 'CD'),
        ('CB', 'CG', 'CD', 'OE1'),
    ],
    'GLU': [
        ('N', 'CA', 'CB', 'CG'),
        ('CA', 'CB', 'CG', 'CD'),
        ('CB', 'CG', 'CD', 'OE1'),
    ],
    'GLY': [],
    'HIS': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'ND1')],
    'ILE': [('N', 'CA', 'CB', 'CG1'), ('CA', 'CB', 'CG1', 'CD1')],
    'LEU': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')],
    'LYS': [
        ('N', 'CA', 'CB', 'CG'),
        ('CA', 'CB', 'CG', 'CD'),
        ('CB', 'CG', 'CD', 'CE'),
        ('CG', 'CD', 'CE', 'NZ'),
    ],
    'MET': [
        ('N', 'CA', 'CB', 'CG'),
        ('CA', 'CB', 'CG', 'SD'),
        ('CB', 'CG', 'SD', 'CE'),
    ],
    'PHE': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')],
    'PRO': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD')],
    'SER': [('N', 'CA', 'CB', 'OG')],
    'THR': [('N', 'CA', 'CB', 'OG1')],
    'TRP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')],
    'TYR': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')],
    'VAL': [('N', 'CA', 'CB', 'CG1')],
}
# Built as a plain dict first, then re-bound below with a read-only Mapping
# annotation. NOTE(review): because these loops run at module level, res_name,
# chi_angle_atoms_for_res, chi_group_i, chi_group, atom_i and atom are left
# behind as module attributes.
CHI_GROUPS_FOR_ATOM = {}
for res_name, chi_angle_atoms_for_res in CHI_ANGLES_ATOMS.items():
  for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
    for atom_i, atom in enumerate(chi_group):
      # An atom can take part in several chi angles (e.g. ARG 'CB' is in chi1
      # and chi2), so each key accumulates a list of pairs.
      CHI_GROUPS_FOR_ATOM.setdefault((res_name, atom), []).append(
          (chi_group_i, atom_i)
      )
# Mapping from (residue_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
CHI_GROUPS_FOR_ATOM: Mapping[tuple[str, str], Sequence[tuple[int, int]]] = (
    CHI_GROUPS_FOR_ATOM
)
# Most chi angles any residue has (ARG and LYS use all four).
MAX_NUM_CHI_ANGLES: int = 4
# Each chi angle is defined by exactly four atoms.
ATOMS_PER_CHI_ANGLE: int = 4
# Residue name -> every atom name that appears in at least one of the residue's
# chi angles (empty for ALA and GLY).
CHI_ATOM_SETS: Mapping[str, set[str]] = {
    residue_name: {atom for chi_atoms in chi_angles for atom in chi_atoms}
    for residue_name, chi_angles in CHI_ANGLES_ATOMS.items()
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. Rows follow the alphabetical three-letter-code order
# used by residue_names.PROTEIN_TYPES (ALA, ARG, ..., VAL). Row entry k is 1.0
# when the residue defines chi angle k+1 in CHI_ANGLES_ATOMS, else 0.0.
CHI_ANGLES_MASK: Sequence[Sequence[float]] = (
    (0.0, 0.0, 0.0, 0.0),  # ALA
    (1.0, 1.0, 1.0, 1.0),  # ARG
    (1.0, 1.0, 0.0, 0.0),  # ASN
    (1.0, 1.0, 0.0, 0.0),  # ASP
    (1.0, 0.0, 0.0, 0.0),  # CYS
    (1.0, 1.0, 1.0, 0.0),  # GLN
    (1.0, 1.0, 1.0, 0.0),  # GLU
    (0.0, 0.0, 0.0, 0.0),  # GLY
    (1.0, 1.0, 0.0, 0.0),  # HIS
    (1.0, 1.0, 0.0, 0.0),  # ILE
    (1.0, 1.0, 0.0, 0.0),  # LEU
    (1.0, 1.0, 1.0, 1.0),  # LYS
    (1.0, 1.0, 1.0, 0.0),  # MET
    (1.0, 1.0, 0.0, 0.0),  # PHE
    (1.0, 1.0, 0.0, 0.0),  # PRO
    (1.0, 0.0, 0.0, 0.0),  # SER
    (1.0, 0.0, 0.0, 0.0),  # THR
    (1.0, 1.0, 0.0, 0.0),  # TRP
    (1.0, 1.0, 0.0, 0.0),  # TYR
    (1.0, 0.0, 0.0, 0.0),  # VAL
)
================================================
FILE: src/alphafold3/cpp.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include "alphafold3/data/cpp/msa_profile_pybind.h"
#include "alphafold3/model/mkdssp_pybind.h"
#include "alphafold3/parsers/cpp/cif_dict_pybind.h"
#include "alphafold3/parsers/cpp/fasta_iterator_pybind.h"
#include "alphafold3/parsers/cpp/msa_conversion_pybind.h"
#include "alphafold3/structure/cpp/aggregation_pybind.h"
#include "alphafold3/structure/cpp/membership_pybind.h"
#include "alphafold3/structure/cpp/mmcif_atom_site_pybind.h"
#include "alphafold3/structure/cpp/mmcif_layout_pybind.h"
#include "alphafold3/structure/cpp/mmcif_struct_conn_pybind.h"
#include "alphafold3/structure/cpp/mmcif_utils_pybind.h"
#include "alphafold3/structure/cpp/string_array_pybind.h"
#include "pybind11/pybind11.h"
namespace alphafold3 {
namespace {

// Include all modules as submodules to simplify building.
//
// Defines the single Python extension module `cpp`. Each binding file exposes a
// RegisterModule* function that populates its own submodule created with
// def_submodule, so the bindings appear as attributes such as `cpp.cif_dict`
// or `cpp.msa_profile`.
// NOTE(review): keep the registration order unless you have confirmed that no
// module's bindings depend on types registered by an earlier one.
PYBIND11_MODULE(cpp, m) {
  RegisterModuleCifDict(m.def_submodule("cif_dict"));
  RegisterModuleFastaIterator(m.def_submodule("fasta_iterator"));
  RegisterModuleMsaConversion(m.def_submodule("msa_conversion"));
  RegisterModuleMmcifLayout(m.def_submodule("mmcif_layout"));
  RegisterModuleMmcifStructConn(m.def_submodule("mmcif_struct_conn"));
  RegisterModuleMembership(m.def_submodule("membership"));
  RegisterModuleMmcifUtils(m.def_submodule("mmcif_utils"));
  RegisterModuleAggregation(m.def_submodule("aggregation"));
  RegisterModuleStringArray(m.def_submodule("string_array"));
  RegisterModuleMmcifAtomSite(m.def_submodule("mmcif_atom_site"));
  RegisterModuleMkdssp(m.def_submodule("mkdssp"));
  RegisterModuleMsaProfile(m.def_submodule("msa_profile"));
}

}  // namespace
}  // namespace alphafold3
================================================
FILE: src/alphafold3/data/cpp/msa_profile_pybind.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include <algorithm>

#include "absl/strings/str_cat.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace {

namespace py = pybind11;

// Computes the per-column residue frequency profile of an integer-coded MSA.
//
// `msa` is declared as a C-contiguous int array with `forcecast`, so pybind11
// converts any compatible NumPy input (other integer dtypes, non-contiguous
// views) into a dense row-major int buffer before this function runs. That is
// what makes the linear pointer walk over `msa.data()` below valid.
//
// Returns a float32 array of shape (sequence_length, num_residue_types). Entry
// (i, r) is the fraction of MSA rows whose residue code in column i equals r.
//
// Throws py::value_error (ValueError in Python) if the MSA is empty or not
// 2-dimensional, if num_residue_types is not positive, or if a residue code is
// outside [0, num_residue_types).
py::array_t<float> ComputeMsaProfile(
    const py::array_t<int, py::array::c_style | py::array::forcecast>& msa,
    int num_residue_types) {
  if (msa.size() == 0) {
    throw py::value_error("The MSA must be non-empty.");
  }
  if (msa.ndim() != 2) {
    throw py::value_error(absl::StrCat("The MSA must be rectangular, got ",
                                       msa.ndim(), "-dimensional MSA array."));
  }
  // Reject a non-positive type count up front. Otherwise a negative value would
  // only fail later, inside NumPy, when the output array is allocated.
  if (num_residue_types <= 0) {
    throw py::value_error(absl::StrCat(
        "num_residue_types must be positive, got ", num_residue_types));
  }

  // py::ssize_t is NumPy's index type; storing dimensions in int could
  // overflow for very large arrays.
  const py::ssize_t msa_depth = msa.shape(0);
  const py::ssize_t sequence_length = msa.shape(1);
  py::array_t<float> profile(
      {sequence_length, static_cast<py::ssize_t>(num_residue_types)});
  std::fill(profile.mutable_data(), profile.mutable_data() + profile.size(),
            0.0f);
  auto profile_unchecked = profile.mutable_unchecked<2>();

  // Pass 1: accumulate raw counts. Whole numbers up to 2^24 are exact in
  // float32. Repeatedly adding 1/depth would instead round after every
  // addition, so that error grows with MSA depth.
  const int* msa_it = msa.data();
  for (py::ssize_t row_index = 0; row_index < msa_depth; ++row_index) {
    for (py::ssize_t column_index = 0; column_index < sequence_length;
         ++column_index) {
      const int residue_code = *(msa_it++);
      // Valid codes are 0 <= code < num_residue_types (zero is allowed, despite
      // the wording of the message, which is kept for compatibility).
      if (residue_code < 0 || residue_code >= num_residue_types) {
        throw py::value_error(
            absl::StrCat("All residue codes must be positive and smaller than "
                         "num_residue_types ",
                         num_residue_types, ", got ", residue_code));
      }
      profile_unchecked(column_index, residue_code) += 1.0f;
    }
  }

  // Pass 2: convert counts to frequencies with a single division per entry, so
  // a fully conserved column comes out as exactly 1.0.
  const double depth = static_cast<double>(msa_depth);
  float* const profile_data = profile.mutable_data();
  for (py::ssize_t i = 0; i < profile.size(); ++i) {
    profile_data[i] = static_cast<float>(profile_data[i] / depth);
  }
  return profile;
}

constexpr char kComputeMsaProfileDoc[] = R"(
Computes MSA profile for the given encoded MSA.
Args:
msa: A Numpy array of shape (num_msa, num_res) with the integer coded MSA.
num_residue_types: Integer that determines the number of unique residue types.
This will determine the shape of the output profile.
Returns:
A float Numpy array of shape (num_res, num_residue_types) with residue
frequency (residue type count normalized by MSA depth) for every column of the
MSA.
)";

}  // namespace
namespace alphafold3 {

// Adds `compute_msa_profile(msa, num_residue_types)` to the Python module `m`.
void RegisterModuleMsaProfile(pybind11::module m) {
  // The raw-string docstring starts with a newline; skip that first character.
  const char* const docstring = &kComputeMsaProfileDoc[1];
  m.def("compute_msa_profile", &ComputeMsaProfile,
        pybind11::arg("msa"), pybind11::arg("num_residue_types"),
        pybind11::doc(docstring));
}

}  // namespace alphafold3
================================================
FILE: src/alphafold3/data/cpp/msa_profile_pybind.h
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
// NOTE(review): the guard name says DATA_PYTHON although this header lives in
// data/cpp; harmless, but worth aligning if the guard is ever touched.
#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_
#define ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_
#include "pybind11/pybind11.h"
namespace alphafold3 {
// Registers the MSA profile bindings (`compute_msa_profile`) on module `m`.
// The top-level `cpp` extension module calls this with its `msa_profile`
// submodule.
void RegisterModuleMsaProfile(pybind11::module m);
}  // namespace alphafold3
#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_
================================================
FILE: src/alphafold3/data/featurisation.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""AlphaFold 3 featurisation pipeline."""
from collections.abc import Sequence
import datetime
import time
from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.model import features
from alphafold3.model.pipeline import pipeline
import numpy as np
def validate_fold_input(fold_input: folding_input.Input):
  """Checks that the chains needing MSA/template data have it attached.

  Protein chains must carry an unpaired MSA, a paired MSA and templates. RNA
  chains must carry an unpaired MSA. No other chain collections are checked.

  Args:
    fold_input: The folding input about to be featurised.

  Raises:
    ValueError: For the first missing field found. The message gives the
      chain's 1-based index within its chain type.
  """
  required_protein_fields = (
      ('unpaired_msa', 'unpaired MSA'),
      ('paired_msa', 'paired MSA'),
      ('templates', 'Templates'),
  )
  for chain_number, protein_chain in enumerate(
      fold_input.protein_chains, start=1
  ):
    for field_name, description in required_protein_fields:
      if getattr(protein_chain, field_name) is None:
        raise ValueError(
            f'Protein chain {chain_number} is missing {description}.'
        )
  for chain_number, rna_chain in enumerate(fold_input.rna_chains, start=1):
    if rna_chain.unpaired_msa is None:
      raise ValueError(f'RNA chain {chain_number} is missing unpaired MSA.')
def featurise_input(
    fold_input: folding_input.Input,
    ccd: chemical_components.Ccd,
    buckets: Sequence[int] | None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    verbose: bool = False,
) -> Sequence[features.BatchDict]:
  """Runs the data pipeline on `fold_input` once for every RNG seed.

  Args:
    fold_input: The input to featurise. Must pass `validate_fold_input`.
    ccd: The chemical components dictionary.
    buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation
      of the model. If None, calculate the appropriate bucket size from the
      number of tokens. If not None, must be a sequence of at least one integer,
      in strictly increasing order. Will raise an error if the number of tokens
      is more than the largest bucket size.
    ref_max_modified_date: Optional maximum date that controls whether to allow
      use of model coordinates for a chemical component from the CCD if RDKit
      conformer generation fails and the component does not have ideal
      coordinates set. Only for components that have been released before this
      date the model coordinates can be used as a fallback.
    conformer_max_iterations: Optional override for maximum number of iterations
      to run for RDKit conformer search.
    resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
      MSA. The default behaviour matches the method described in the AlphaFold 3
      paper. Set this to false if providing custom paired MSA using the unpaired
      MSA field to keep it exactly as is as deduplication against the paired MSA
      could break the manually crafted pairing between MSA sequences.
    verbose: Whether to print progress and timing messages.

  Returns:
    A list holding one featurised batch per entry of `fold_input.rng_seeds`,
    in the same order.

  Raises:
    ValueError: If `fold_input` is missing MSAs or templates.
  """
  validate_fold_input(fold_input)

  # The pipeline is configured for, and used by, this call only.
  config = pipeline.WholePdbPipeline.Config(
      buckets=buckets,
      ref_max_modified_date=ref_max_modified_date,
      conformer_max_iterations=conformer_max_iterations,
      resolve_msa_overlaps=resolve_msa_overlaps,
  )
  data_pipeline = pipeline.WholePdbPipeline(config=config)

  featurised = []
  for seed in fold_input.rng_seeds:
    started_at = time.time()
    if verbose:
      print(f'Featurising data with seed {seed}.')
    batch = data_pipeline.process_item(
        fold_input=fold_input,
        ccd=ccd,
        random_state=np.random.RandomState(seed),
        random_seed=seed,
    )
    if verbose:
      elapsed = time.time() - started_at
      print(f'Featurising data with seed {seed} took {elapsed:.2f} seconds.')
    featurised.append(batch)
  return featurised
================================================
FILE: src/alphafold3/data/msa.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for getting MSA and calculating alignment features."""
from collections.abc import MutableMapping, Sequence
import string
from typing import Self
from absl import logging
from alphafold3.constants import mmcif_names
from alphafold3.data import msa_config
from alphafold3.data import msa_features
from alphafold3.data import parsers
from alphafold3.data.tools import jackhmmer
from alphafold3.data.tools import msa_tool
from alphafold3.data.tools import nhmmer
import numpy as np
class Error(Exception):
  """Error indicating a problem with MSA search or MSA featurisation.

  `Msa.featurize` raises it, wrapping the underlying ValueError, when MSA or
  deletion features cannot be extracted.
  """
def _featurize(seq: str, chain_poly_type: str) -> str | list[int]:
  """Returns a comparable representation of `seq` for `chain_poly_type`.

  For standard polymer types this is the list of MSA token ids of `seq`, so
  that different letters that map to the same id compare equal. For any other
  polymer type the raw string is returned, so only an exact match is equal.
  """
  if not mmcif_names.is_standard_polymer_type(chain_poly_type):
    # For anything else simply require an identical match.
    return seq
  token_rows, _ = msa_features.extract_msa_features(
      msa_sequences=[seq], chain_poly_type=chain_poly_type
  )
  return token_rows[0].tolist()
def sequences_are_feature_equivalent(
    sequence1: str,
    sequence2: str,
    chain_poly_type: str,
) -> bool:
  """Tells whether two sequences featurise identically for `chain_poly_type`.

  Args:
    sequence1: First sequence (featurised first).
    sequence2: Second sequence.
    chain_poly_type: Polymer type used to choose the featurisation.

  Returns:
    True if `_featurize` yields equal results for both sequences.
  """
  first = _featurize(sequence1, chain_poly_type)
  return first == _featurize(sequence2, chain_poly_type)
class Msa:
  """Multiple Sequence Alignment container with methods for manipulating it.

  Attributes:
    query_sequence: The sequence that was used to search for the MSA.
    chain_poly_type: Polymer type of the query sequence, see mmcif_names.
    sequences: The MSA rows. Never empty; the first row is feature-equivalent
      to `query_sequence`.
    descriptions: One description per row of `sequences`.
  """

  def __init__(
      self,
      query_sequence: str,
      chain_poly_type: str,
      sequences: Sequence[str],
      descriptions: Sequence[str],
      deduplicate: bool = True,
  ):
    """Raw constructor, prefer using the from_{a3m,multiple_msas} class methods.

    The first sequence must be equal (in featurised form) to the query sequence.
    If sequences/descriptions are empty, they will be initialised to the query.

    Args:
      query_sequence: The sequence that was used to search for MSA.
      chain_poly_type: Polymer type of the query sequence, see mmcif_names.
      sequences: The sequences returned by the MSA search tool.
      descriptions: Metadata for the sequences returned by the MSA search tool.
      deduplicate: If True, the MSA sequences will be deduplicated in the input
        order. Lowercase letters (insertions) are ignored when deduplicating.

    Raises:
      ValueError: If `sequences` and `descriptions` differ in length, or if the
        first MSA sequence is not feature-equivalent to `query_sequence`.
    """
    if len(sequences) != len(descriptions):
      raise ValueError('The number of sequences and descriptions must match.')

    self.query_sequence = query_sequence
    self.chain_poly_type = chain_poly_type

    if not deduplicate:
      self.sequences = sequences
      self.descriptions = descriptions
    else:
      self.sequences = []
      self.descriptions = []
      # A replacement table that removes all lowercase characters.
      deletion_table = str.maketrans('', '', string.ascii_lowercase)
      unique_sequences = set()
      for seq, desc in zip(sequences, descriptions, strict=True):
        # Using string.translate is faster than re.sub('[a-z]+', '').
        sequence_no_deletions = seq.translate(deletion_table)
        if sequence_no_deletions not in unique_sequences:
          unique_sequences.add(sequence_no_deletions)
          self.sequences.append(seq)
          self.descriptions.append(desc)

    # Make sure the MSA always has at least the query.
    self.sequences = self.sequences or [query_sequence]
    self.descriptions = self.descriptions or ['Original query']

    # Check if the 1st MSA sequence matches the query sequence. Since it may be
    # mutated by the search tool (jackhmmer) check using the featurized version.
    if not sequences_are_feature_equivalent(
        self.sequences[0], query_sequence, chain_poly_type
    ):
      raise ValueError(
          f'First MSA sequence {self.sequences[0]} is not the {query_sequence=}'
      )

  @classmethod
  def from_multiple_msas(
      cls, msas: Sequence[Self], deduplicate: bool = True
  ) -> Self:
    """Initializes the MSA from multiple MSAs.

    Args:
      msas: A sequence of Msa objects representing individual MSAs produced by
        different tools/dbs.
      deduplicate: If True, the MSA sequences will be deduplicated in the input
        order. Lowercase letters (insertions) are ignored when deduplicating.

    Returns:
      An Msa object created by merging multiple MSAs.

    Raises:
      ValueError: If `msas` is empty, or if the MSAs disagree on the query
        sequence or the chain poly type.
    """
    if not msas:
      raise ValueError('At least one MSA must be provided.')

    query_sequence = msas[0].query_sequence
    chain_poly_type = msas[0].chain_poly_type
    sequences = []
    descriptions = []

    for msa in msas:
      if msa.query_sequence != query_sequence:
        raise ValueError(
            f'Query sequences must match: {[m.query_sequence for m in msas]}'
        )
      if msa.chain_poly_type != chain_poly_type:
        raise ValueError(
            f'Chain poly types must match: {[m.chain_poly_type for m in msas]}'
        )
      sequences.extend(msa.sequences)
      descriptions.extend(msa.descriptions)

    return cls(
        query_sequence=query_sequence,
        chain_poly_type=chain_poly_type,
        sequences=sequences,
        descriptions=descriptions,
        deduplicate=deduplicate,
    )

  @classmethod
  def from_multiple_a3ms(
      cls, a3ms: Sequence[str], chain_poly_type: str, deduplicate: bool = True
  ) -> Self:
    """Initializes the MSA from multiple A3M strings.

    The first sequence of each A3M is taken as that A3M's query, and all
    queries must be identical.

    Args:
      a3ms: A sequence of A3M strings representing individual MSAs produced by
        different tools/dbs.
      chain_poly_type: Polymer type of the query sequence, see mmcif_names.
      deduplicate: If True, the MSA sequences will be deduplicated in the input
        order. Lowercase letters (insertions) are ignored when deduplicating.

    Returns:
      An Msa object created by merging multiple A3Ms.

    Raises:
      ValueError: If `a3ms` is empty, if any A3M contains no sequences, or if
        the A3Ms do not share the same first (query) sequence.
    """
    if not a3ms:
      raise ValueError('At least one A3M must be provided.')

    query_sequence = None
    all_sequences = []
    all_descriptions = []

    for a3m in a3ms:
      sequences, descriptions = parsers.parse_fasta(a3m)
      # Without this guard, an A3M with no records would fail below with an
      # uninformative IndexError on `sequences[0]`.
      if not sequences:
        raise ValueError('Each A3M must contain at least one sequence.')
      if query_sequence is None:
        query_sequence = sequences[0]

      if sequences[0] != query_sequence:
        raise ValueError(
            f'Query sequences must match: {sequences[0]=} != {query_sequence=}'
        )
      all_sequences.extend(sequences)
      all_descriptions.extend(descriptions)

    return cls(
        query_sequence=query_sequence,
        chain_poly_type=chain_poly_type,
        sequences=all_sequences,
        descriptions=all_descriptions,
        deduplicate=deduplicate,
    )

  @classmethod
  def from_a3m(
      cls,
      query_sequence: str,
      chain_poly_type: str,
      a3m: str,
      max_depth: int | None = None,
      deduplicate: bool = True,
  ) -> Self:
    """Parses the single A3M and builds the Msa object.

    Args:
      query_sequence: The sequence that was used to search for the MSA.
      chain_poly_type: Polymer type of the query sequence, see mmcif_names.
      a3m: The MSA in A3M format.
      max_depth: If set to a positive value smaller than the number of parsed
        sequences, only the first `max_depth` sequences are kept. Cropping
        happens before deduplication.
      deduplicate: If True, the MSA sequences will be deduplicated in the input
        order. Lowercase letters (insertions) are ignored when deduplicating.

    Returns:
      The parsed Msa.
    """
    sequences, descriptions = parsers.parse_fasta(a3m)

    if max_depth is not None and 0 < max_depth < len(sequences):
      logging.info(
          'MSA cropped from depth of %d to %d for %s.',
          len(sequences),
          max_depth,
          query_sequence,
      )
      sequences = sequences[:max_depth]
      descriptions = descriptions[:max_depth]

    return cls(
        query_sequence=query_sequence,
        chain_poly_type=chain_poly_type,
        sequences=sequences,
        descriptions=descriptions,
        deduplicate=deduplicate,
    )

  @classmethod
  def from_empty(cls, query_sequence: str, chain_poly_type: str) -> Self:
    """Creates an empty Msa containing just the query sequence."""
    return cls(
        query_sequence=query_sequence,
        chain_poly_type=chain_poly_type,
        sequences=[],
        descriptions=[],
        deduplicate=False,
    )

  @property
  def depth(self) -> int:
    """Number of rows (sequences) in the MSA, including the query."""
    return len(self.sequences)

  def __repr__(self) -> str:
    return f'Msa({self.depth} sequences, {self.chain_poly_type})'

  def to_a3m(self) -> str:
    """Returns the MSA in the A3M format, terminated by a newline."""
    a3m_lines = []
    for desc, seq in zip(self.descriptions, self.sequences, strict=True):
      a3m_lines.append(f'>{desc}')
      a3m_lines.append(seq)
    return '\n'.join(a3m_lines) + '\n'

  def featurize(self) -> MutableMapping[str, np.ndarray]:
    """Featurises the MSA and returns a map of feature names to features.

    Returns:
      A dictionary with keys 'msa_species_identifiers', 'num_alignments',
      'msa' and 'deletion_matrix'.

    Raises:
      Error:
        * If the sequences in the MSA don't have the same length after deletions
          (lower case letters) are removed.
        * If the MSA contains an unknown amino acid code.
        * If there are no sequences after aligning.
    """
    try:
      msa, deletion_matrix = msa_features.extract_msa_features(
          msa_sequences=self.sequences, chain_poly_type=self.chain_poly_type
      )
    except ValueError as e:
      raise Error(f'Error extracting MSA or deletion features: {e}') from e

    if msa.shape == (0, 0):
      raise Error(f'Empty MSA feature for {self}')

    species_ids = msa_features.extract_species_ids(self.descriptions)

    return {
        'msa_species_identifiers': np.array(species_ids, dtype=object),
        'num_alignments': np.array(self.depth, dtype=np.int32),
        'msa': msa,
        'deletion_matrix': deletion_matrix,
    }
def get_msa_tool(
    msa_tool_config: msa_config.JackhmmerConfig | msa_config.NhmmerConfig,
) -> msa_tool.MsaTool:
  """Instantiates the MSA search tool described by `msa_tool_config`.

  Args:
    msa_tool_config: A jackhmmer or nhmmer configuration.

  Returns:
    A `jackhmmer.Jackhmmer` or `nhmmer.Nhmmer` built from the config.

  Raises:
    ValueError: If the config is of neither supported type.
  """
  # NOTE(review): JackhmmerConfig.dom_z_value, NhmmerConfig.z_value and both
  # configs' max_parallel_shards are not forwarded here — confirm intended.
  cfg = msa_tool_config
  if isinstance(cfg, msa_config.JackhmmerConfig):
    return jackhmmer.Jackhmmer(
        binary_path=cfg.binary_path,
        database_path=cfg.database_config.path,
        n_cpu=cfg.n_cpu,
        n_iter=cfg.n_iter,
        e_value=cfg.e_value,
        z_value=cfg.z_value,
        max_sequences=cfg.max_sequences,
    )
  if isinstance(cfg, msa_config.NhmmerConfig):
    return nhmmer.Nhmmer(
        binary_path=cfg.binary_path,
        hmmalign_binary_path=cfg.hmmalign_binary_path,
        hmmbuild_binary_path=cfg.hmmbuild_binary_path,
        database_path=cfg.database_config.path,
        n_cpu=cfg.n_cpu,
        e_value=cfg.e_value,
        max_sequences=cfg.max_sequences,
        alphabet=cfg.alphabet,
    )
  raise ValueError(f'Unknown MSA tool: {msa_tool_config}.')
def get_msa(
    target_sequence: str,
    run_config: msa_config.RunConfig,
    chain_poly_type: str,
    deduplicate: bool = False,
) -> Msa:
  """Runs the configured MSA tool on `target_sequence` and wraps the result.

  Args:
    target_sequence: The target amino-acid sequence.
    run_config: MSA run configuration; `run_config.crop_size` caps the depth.
    chain_poly_type: The type of chain for which to get an MSA.
    deduplicate: If True, the MSA sequences will be deduplicated in the input
      order. Lowercase letters (insertions) are ignored when deduplicating.

  Returns:
    Aligned MSA sequences.
  """
  search_tool = get_msa_tool(run_config.config)
  search_result = search_tool.query(target_sequence)
  return Msa.from_a3m(
      query_sequence=target_sequence,
      chain_poly_type=chain_poly_type,
      a3m=search_result.a3m,
      max_depth=run_config.crop_size,
      deduplicate=deduplicate,
  )
================================================
FILE: src/alphafold3/data/msa_config.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Genetic search config settings for data pipelines."""
import dataclasses
import datetime
from typing import Self
from alphafold3.constants import mmcif_names
def _validate_chain_poly_type(chain_poly_type: str) -> None:
  """Raises ValueError unless `chain_poly_type` is a standard polymer type."""
  allowed = mmcif_names.STANDARD_POLYMER_CHAIN_TYPES
  if chain_poly_type in allowed:
    return
  raise ValueError(
      f'chain_poly_type must be one of {allowed}: {chain_poly_type}'
  )
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class DatabaseConfig:
"""Configuration for a database."""
name: str
path: str
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class JackhmmerConfig:
  """Settings for a jackhmmer search of a protein sequence database.

  Attributes:
    binary_path: Filesystem path of the jackhmmer executable.
    database_config: The database to search.
    n_cpu: Number of CPUs jackhmmer may use.
    n_iter: Number of iterative search rounds.
    e_value: E-value threshold for the database lookup.
    z_value: Database size, in number of sequences, used for both E-value and
      domain E-value calculation. Required for sharded databases.
    dom_z_value: Database size, in number of sequences, used for domain E-value
      calculation. Required for sharded databases.
    max_sequences: Upper bound on the number of sequences in the returned MSA.
    max_parallel_shards: Cap on the number of shards searched concurrently.
      When None, one jackhmmer instance is run per shard. Only meaningful for
      sharded databases.
  """

  binary_path: str
  database_config: DatabaseConfig
  n_cpu: int
  n_iter: int
  e_value: float
  z_value: int | None
  dom_z_value: int | None
  max_sequences: int
  max_parallel_shards: int | None = None
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class NhmmerConfig:
  """Settings for an nhmmer search of a nucleic-acid sequence database.

  Attributes:
    binary_path: Filesystem path of the nhmmer executable.
    hmmalign_binary_path: Filesystem path of the hmmalign executable.
    hmmbuild_binary_path: Filesystem path of the hmmbuild executable.
    database_config: The database to search.
    n_cpu: Number of CPUs nhmmer may use.
    e_value: E-value threshold for the database lookup.
    z_value: Database size in megabases used for E-value calculation; may be
      fractional. Required for sharded databases.
    max_sequences: Upper bound on the number of sequences in the returned MSA.
    alphabet: Alphabet passed when building a profile with hmmbuild.
    max_parallel_shards: Cap on the number of shards searched concurrently.
      When None, one nhmmer instance is run per shard. Only meaningful for
      sharded databases.
  """

  binary_path: str
  hmmalign_binary_path: str
  hmmbuild_binary_path: str
  database_config: DatabaseConfig
  n_cpu: int
  e_value: float
  z_value: float | None
  max_sequences: int
  alphabet: str | None
  max_parallel_shards: int | None = None
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class RunConfig:
  """Describes one MSA search run.

  Attributes:
    config: Settings for the search tool (jackhmmer or nhmmer).
    chain_poly_type: Chain type the run is for; must be a standard polymer type.
    crop_size: Maximum MSA depth to keep, or None to keep every sequence. The
      query counts towards the depth, so values below 2 are rejected.
  """

  config: JackhmmerConfig | NhmmerConfig
  chain_poly_type: str
  crop_size: int | None

  def __post_init__(self):
    # Validate the crop size first, then the chain type.
    crop = self.crop_size
    if crop is not None and crop < 2:
      raise ValueError(f'crop_size must be None or >= 2: {self.crop_size}')
    _validate_chain_poly_type(self.chain_poly_type)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class HmmsearchConfig:
"""Configuration for a hmmsearch."""
hmmsearch_binary_path: str
hmmbuild_binary_path: str
e_value: float
inc_e: float
dom_e: float
incdom_e: float
alphabet: str = 'amino'
filter_f1: float | None = None
filter_f2: float | None = None
filter_f3: float | None = None
filter_max: bool = False
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class TemplateToolConfig:
  """Settings for the template search tool.

  Attributes:
    database_path: Filesystem path of the template database.
    chain_poly_type: Chain type being searched; validated on construction.
    hmmsearch_config: Settings for the underlying hmmsearch run.
    max_a3m_query_sequences: Defaults to 300. NOTE(review): presumably caps how
      many A3M sequences are used to build the query — confirm with the tool.
  """

  database_path: str
  chain_poly_type: str
  hmmsearch_config: HmmsearchConfig
  max_a3m_query_sequences: int | None = 300

  def __post_init__(self):
    # Reject chain types that are not standard polymers.
    _validate_chain_poly_type(self.chain_poly_type)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class TemplateFilterConfig:
"""Configuration for a template filter."""
max_subsequence_ratio: float | None
min_align_ratio: float | None
min_hit_length: int | None
deduplicate_sequences: bool
max_hits: int | None
max_template_date: datetime.date
@classmethod
def no_op_filter(cls) -> Self:
"""Returns a config for filter that keeps everything."""
return cls(
max_subsequence_ratio=None,
min_align_ratio=None,
min_hit_length=None,
deduplicate_sequences=False,
max_hits=None,
max_template_date=datetime.date(3000, 1, 1), # Very far in the future.
)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class TemplatesConfig:
  """Bundles the template search tool settings with the hit filter settings.

  Attributes:
    template_tool_config: How to run the template search.
    filter_config: Which template hits to keep.
  """

  template_tool_config: TemplateToolConfig
  filter_config: TemplateFilterConfig
================================================
FILE: src/alphafold3/data/msa_features.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Utilities for computing MSA features."""
from collections.abc import Sequence
import re
from alphafold3.constants import mmcif_names
import numpy as np
# Character -> MSA token id lookup tables, selected by chain type in
# `extract_msa_features`. Characters missing from a table are treated there as
# deletions (if lowercase) or as unknown residues (otherwise).

# Protein table: every uppercase letter has an id. Ambiguous or rare codes are
# folded onto a standard residue or onto X (unknown, id 20), as noted per entry.
# '-' (gap) is 21.
_PROTEIN_TO_ID = {
    'A': 0,
    'B': 3,  # Same as D.
    'C': 4,
    'D': 3,
    'E': 6,
    'F': 13,
    'G': 7,
    'H': 8,
    'I': 9,
    'J': 20,  # Same as unknown (X).
    'K': 11,
    'L': 10,
    'M': 12,
    'N': 2,
    'O': 20,  # Same as unknown (X).
    'P': 14,
    'Q': 5,
    'R': 1,
    'S': 15,
    'T': 16,
    'U': 4,  # Same as C.
    'V': 19,
    'W': 17,
    'X': 20,
    'Y': 18,
    'Z': 6,  # Same as E.
    '-': 21,
}

# RNA table: the unpacked comprehension first maps every uppercase letter to 30;
# the later keys then override it for the gap and the four standard bases.
_RNA_TO_ID = {
    # Map non-standard residues to UNK_NUCLEIC (N) -> 30
    **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)},
    # Continue the RNA indices from where Protein indices left off.
    '-': 21,
    'A': 22,
    'G': 23,
    'C': 24,
    'U': 25,
}

# DNA table: same construction as the RNA table, with DNA-specific ids.
_DNA_TO_ID = {
    # Map non-standard residues to UNK_NUCLEIC (N) -> 30
    **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)},
    # Continue the DNA indices from where the RNA indices left off.
    '-': 21,
    'A': 26,
    'G': 27,
    'C': 28,
    'T': 29,
}
def extract_msa_features(
    msa_sequences: Sequence[str], chain_poly_type: str
) -> tuple[np.ndarray, np.ndarray]:
  """Converts raw MSA rows into token-id and deletion-count matrices.

  Characters present in the chain type's lookup table (uppercase letters and
  '-') form match columns. Lowercase characters are insertions: they are not
  emitted as columns, but are counted towards the deletion value of the next
  match column. Insertions after the last match column are dropped.

  Example:
    The input raw MSA is: `[["AAAAAA"], ["Ai-CiDiiiEFa"]]`
    The output MSA will be: `[["AAAAAA"], ["A-CDEF"]]`
    The deletions will be: `[[0, 0, 0, 0, 0, 0], [0, 1, 0, 1, 3, 0]]`

  Args:
    msa_sequences: One string per MSA row. Every row must have the same number
      of match (non-lowercase) characters as the first row.
    chain_poly_type: Either 'polypeptide(L)' (protein), 'polyribonucleotide'
      (RNA), or 'polydeoxyribonucleotide' (DNA). Use the appropriate string
      constant from mmcif_names.py.

  Returns:
    A pair of int32 arrays, both of shape (num_seq, num_res):
      * the token id of each match character;
      * the number of insertions immediately to the left of each match
        character.
    For an empty `msa_sequences` both arrays have shape (0, 0).

  Raises:
    ValueError: If `chain_poly_type` is not protein/RNA/DNA, if a row contains
      a character that is neither in the lookup table nor lowercase, or if a
      row's number of match characters differs from the first row's.
  """
  # Pick the lookup table for this chain type (checked in RNA, DNA, protein
  # order).
  for poly_type, mapping in (
      (mmcif_names.RNA_CHAIN, _RNA_TO_ID),
      (mmcif_names.DNA_CHAIN, _DNA_TO_ID),
      (mmcif_names.PROTEIN_CHAIN, _PROTEIN_TO_ID),
  ):
    if chain_poly_type == poly_type:
      char_map = mapping
      break
  else:
    raise ValueError(f'{chain_poly_type=} invalid.')

  if not msa_sequences:
    return (
        np.zeros((0, 0), dtype=np.int32),
        np.zeros((0, 0), dtype=np.int32),
    )

  # The first row defines how many match columns every row must have.
  num_cols = sum(ch in char_map for ch in msa_sequences[0])
  shape = (len(msa_sequences), num_cols)
  msa_arr = np.zeros(shape, dtype=np.int32)
  deletions_arr = np.zeros(shape, dtype=np.int32)

  for row, sequence in enumerate(msa_sequences):
    unknown = []
    pending = 0  # Insertions seen since the previous match column.
    col = 0
    for pos, char in enumerate(sequence):
      token = char_map.get(char)
      if token is None:
        if not char.islower():
          unknown.append(f'({row}, {pos}):{char}')
        pending += 1
        continue
      # A row with too many match columns is rejected below; meanwhile skip
      # writes that would fall outside the arrays.
      if col < num_cols:
        msa_arr[row, col] = token
        deletions_arr[row, col] = pending
      pending = 0
      col += 1

    if unknown:
      raise ValueError(
          f"Unknown residues in MSA: {', '.join(unknown)}. "
          f'target_sequence: {msa_sequences[0]}'
      )
    if col != num_cols:
      raise ValueError(
          'Invalid shape all strings must have the same number '
          'of non-lowercase characters; First string has '
          f"{num_cols} non-lowercase characters but '{sequence}' has "
          f'{col}. target_sequence: {msa_sequences[0]}'
      )

  return msa_arr, deletions_arr
# UniProtKB SwissProt/TrEMBL dbs have the following description format:
# `db|UniqueIdentifier|EntryName`, e.g. `sp|P0C2L1|A3X1_LOXLA` or
# `tr|A0A146SKV9|A0A146SKV9_FUNHE`.
_UNIPROT_ENTRY_NAME_REGEX = re.compile(
    # UniProtKB TrEMBL or SwissProt database.
    r'(?:tr|sp)\|'
    # A primary accession number of the UniProtKB entry.
    r'(?:[A-Z0-9]{6,10})'
    # Occasionally there is an isoform suffix (e.g. _1 or _10) which we ignore.
    r'(?:_\d+)?\|'
    # TrEMBL: Same as AccessionId (6-10 characters).
    # SwissProt: A mnemonic protein identification code (1-5 characters).
    r'(?:[A-Z0-9]{1,10}_)'
    # A mnemonic species identification code, captured as `SpeciesId`.
    # (A bare `(?P[...]` is invalid syntax: the group needs a `<name>`.)
    r'(?P<SpeciesId>[A-Z0-9]{1,5})'
)


def extract_species_ids(msa_descriptions: Sequence[str]) -> Sequence[str]:
  """Extracts species ID from MSA UniProtKB sequence identifiers.

  Each description is stripped of surrounding whitespace and matched from its
  start against the UniProtKB entry-name pattern.

  Args:
    msa_descriptions: The descriptions (the FASTA/A3M comment line) for each of
      the sequences.

  Returns:
    A list with one entry per description: the UniProtKB species ID (e.g.
    'LOXLA') when the description matches, otherwise ''.
  """
  species_ids = []
  for msa_description in msa_descriptions:
    match = _UNIPROT_ENTRY_NAME_REGEX.match(msa_description.strip())
    # Non-UniProtKB descriptions get '' so the output stays aligned with the
    # input rows.
    species_ids.append(match.group('SpeciesId') if match else '')
  return species_ids
================================================
FILE: src/alphafold3/data/msa_identifiers.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE,
)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(species_id=matches.group('SpeciesIdentifier'))
return Identifiers()
def _extract_sequence_identifier(description: str) -> str | None:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
================================================
FILE: src/alphafold3/data/parsers.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for parsing various file formats."""
from collections.abc import Iterable, Sequence
from typing import IO, TypeAlias
from alphafold3.cpp import fasta_iterator
from alphafold3.cpp import msa_conversion
# Type alias for a nested sequence of ints, named after an MSA deletion matrix.
# NOTE(review): presumably one row of per-column deletion counts per MSA
# sequence — confirm against its users.
DeletionMatrix: TypeAlias = Sequence[Sequence[int]]
def lazy_parse_fasta_string(fasta_string: str) -> Iterable[tuple[str, str]]:
  """Lazily parses a FASTA/A3M string and yields (sequence, description) tuples.

  This implementation is more memory friendly than `parse_fasta` (which builds
  full lists of sequences and descriptions) while offering comparable
  performance. The underlying implementation is in C++ and is therefore faster
  than a pure Python implementation.

  Use this method when parsing FASTA files where you already have the FASTA
  string, but need to control how far you iterate through its sequences.

  Args:
    fasta_string: A string with the contents of FASTA/A3M file.

  Returns:
    Iterator of (sequence, description). In the description, the leading ">" is
    stripped.

  Raises:
    ValueError: If the FASTA/A3M file is invalid, e.g. empty.
  """
  # The lifetime of the FastaStringIterator is tied to the lifetime of
  # fasta_string - fasta_string must be kept while the iterator is in use.
  return fasta_iterator.FastaStringIterator(fasta_string)
def parse_fasta(fasta_string: str) -> tuple[Sequence[str], Sequence[str]]:
  """Eagerly parses a FASTA/A3M string into sequences and their descriptions.

  Args:
    fasta_string: The string contents of a FASTA file.

  Returns:
    A `(sequences, descriptions)` pair in which `descriptions[i]` is the
    comment line that belongs to `sequences[i]`.
  """
  parsed = fasta_iterator.parse_fasta_include_descriptions(fasta_string)
  return parsed
def convert_a3m_to_stockholm(a3m: str, max_seqs: int | None = None) -> str:
  """Converts MSA in the A3M format to the Stockholm format.

  Args:
    a3m: The MSA as an A3M string. Its first sequence is treated as the query.
    max_seqs: If not None, only the first `max_seqs` sequences are converted.

  Returns:
    The Stockholm-formatted MSA: a header, one `#=GS <name> DE <desc>` line per
    sequence, the aligned sequences, a `#=GC RF` reference annotation derived
    from the query, and a closing `//` line (no trailing newline). Sequence
    names get a `_<index>` suffix so that they are unique.

  Raises:
    ValueError: If there are no sequences to convert.
  """
  sequences, descriptions = parse_fasta(a3m)
  if max_seqs is not None:
    sequences = sequences[:max_seqs]
    descriptions = descriptions[:max_seqs]
  if not sequences:
    # Fail early with a clear message rather than on `max()` of an empty
    # sequence or on `sequences[0]` further down.
    raise ValueError('Cannot convert an MSA with no sequences to Stockholm.')

  stockholm = ['# STOCKHOLM 1.0', '']

  # Add the Stockholm header with the sequence metadata.
  names = []
  for i, description in enumerate(descriptions):
    name, _, rest = description.replace('\t', ' ').partition(' ')
    # Ensure that the names are unique - stockholm format requires that
    # the sequence names are unique.
    name = f'{name}_{i}'
    names.append(name)
    # Avoid zero-length description due to historic hmmbuild parsing bug: an
    # empty description is replaced by a non-empty placeholder.
    desc = rest.strip() or '<EMPTY>'
    stockholm.append(f'#=GS {name.strip()} DE {desc}')
  stockholm.append('')

  # Convert insertions in a sequence into gaps in all other sequences that don't
  # have an insertion in that column as well.
  sequences = msa_conversion.convert_a3m_to_stockholm(sequences)

  # Add the MSA data.
  max_name_width = max(len(name) for name in names)
  for name, sequence in zip(names, sequences, strict=True):
    # Align the names to the left and pad with spaces to the maximum length.
    stockholm.append(f'{name:<{max_name_width}s} {sequence}')

  # Add the reference annotation for the query (the first sequence).
  ref_annotation = ''.join('.' if c == '-' else 'x' for c in sequences[0])
  stockholm.append(f'{"#=GC RF":<{max_name_width}s} {ref_annotation}')
  stockholm.append('//')

  return '\n'.join(stockholm)
def convert_stockholm_to_a3m(
stockholm: IO[str],
max_sequences: int | None = None,
remove_first_row_gaps: bool = True,
linewidth: int | None = None,
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
if linewidth is not None and linewidth <= 0:
raise ValueError('linewidth must be > 0 or None')
for line in stockholm:
reached_max_sequences = max_sequences and len(sequences) >= max_sequences
line = line.strip()
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
if not line or line.startswith(('#', '//')):
continue
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] += aligned_seq
if not sequences:
return ''
stockholm.seek(0)
for line in stockholm:
line = line.strip()
if line[:4] == '#=GS':
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break
assert len(descriptions) <= len(sequences)
# Convert sto format to a3m line by line
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
for seqname, sto_sequence in sequences.items():
if remove_first_row_gaps:
a3m_sequences[seqname] = msa_conversion.align_sequence_to_gapless_query(
sequence=sto_sequence, query_sequence=query_sequence
).replace('.', '')
else:
a3m_sequences[seqname] = sto_sequence.replace('.', '')
fasta_chunks = []
for seqname, seq in a3m_sequences.items():
fasta_chunks.append(f'>{seqname} {descriptions.get(seqname, "")}')
if linewidth:
fasta_chunks.extend(
seq[i : linewidth + i] for i in range(0, len(seq), linewidth)
)
else:
fasta_chunks.append(seq)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
================================================
FILE: src/alphafold3/data/pipeline.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for running the MSA and template tools for the AlphaFold model."""
from concurrent import futures
import dataclasses
import datetime
import functools
import logging
import time
from alphafold3.common import folding_input
from alphafold3.constants import mmcif_names
from alphafold3.data import msa
from alphafold3.data import msa_config
from alphafold3.data import structure_stores
from alphafold3.data import templates as templates_lib
# Cache to avoid re-running template search for the same sequence in homomers.
@functools.cache
def _get_protein_templates(
    sequence: str,
    input_msa_a3m: str,
    run_template_search: bool,
    templates_config: msa_config.TemplatesConfig,
    pdb_database_path: str,
) -> templates_lib.Templates:
  """Returns the templates for one protein chain, searching only on request.

  Results are memoised per argument tuple, so identical chains in a homomer
  share one search.

  Args:
    sequence: The protein query sequence.
    input_msa_a3m: The A3M MSA used as input to the template search.
    run_template_search: If False, no search is run and an empty set of hits
      is returned.
    templates_config: Template search tool and filtering configuration.
    pdb_database_path: Location of the mmCIF structures backing the hits.

  Returns:
    A `templates_lib.Templates` object (with no hits if search was skipped).
  """
  filter_config = templates_config.filter_config

  if not run_template_search:
    logging.info('Skipping template search for sequence %s', sequence)
    return templates_lib.Templates(
        query_sequence=sequence,
        hits=[],
        max_template_date=filter_config.max_template_date,
        structure_store=structure_stores.StructureStore(pdb_database_path),
    )

  search_started = time.time()
  logging.info('Getting protein templates for sequence %s', sequence)
  tool_config = templates_config.template_tool_config
  found_templates = templates_lib.Templates.from_seq_and_a3m(
      query_sequence=sequence,
      msa_a3m=input_msa_a3m,
      max_template_date=filter_config.max_template_date,
      database_path=tool_config.database_path,
      hmmsearch_config=tool_config.hmmsearch_config,
      max_a3m_query_sequences=None,
      chain_poly_type=mmcif_names.PROTEIN_CHAIN,
      structure_store=structure_stores.StructureStore(pdb_database_path),
      filter_config=filter_config,
  )
  logging.info(
      'Getting %d protein templates took %.2f seconds for sequence %s',
      found_templates.num_hits,
      time.time() - search_started,
      sequence,
  )
  return found_templates
# Cache to avoid re-running the MSA tools for the same sequence in homomers.
@functools.cache
def _get_protein_msa_and_templates(
    sequence: str,
    run_template_search: bool,
    uniref90_msa_config: msa_config.RunConfig,
    mgnify_msa_config: msa_config.RunConfig,
    small_bfd_msa_config: msa_config.RunConfig,
    uniprot_msa_config: msa_config.RunConfig,
    templates_config: msa_config.TemplatesConfig,
    pdb_database_path: str,
) -> tuple[msa.Msa, msa.Msa, templates_lib.Templates]:
  """Searches MSAs and templates for a single protein chain.

  The four MSA searches run concurrently. UniRef90, small BFD and MGnify hits
  are merged with deduplication into the unpaired MSA; UniProt hits form the
  paired MSA without deduplication. The unpaired MSA (as A3M) is the input to
  the template search.

  Args:
    sequence: The protein query sequence.
    run_template_search: Whether to run the template search.
    uniref90_msa_config: Run configuration for the UniRef90 search.
    mgnify_msa_config: Run configuration for the MGnify search.
    small_bfd_msa_config: Run configuration for the small BFD search.
    uniprot_msa_config: Run configuration for the UniProt search.
    templates_config: Template search and filtering configuration.
    pdb_database_path: Location of the mmCIF structures for templates.

  Returns:
    A tuple `(unpaired_msa, paired_msa, templates)`.
  """
  logging.info('Getting protein MSAs for sequence %s', sequence)
  search_started = time.time()
  run_configs = {
      'uniref90': uniref90_msa_config,
      'mgnify': mgnify_msa_config,
      'small_bfd': small_bfd_msa_config,
      'uniprot': uniprot_msa_config,
  }
  # The searches shell out to external binaries, so threads are not held back
  # by the GIL.
  with futures.ThreadPoolExecutor(max_workers=4) as pool:
    jobs = {
        db_name: pool.submit(
            msa.get_msa,
            target_sequence=sequence,
            run_config=run_config,
            chain_poly_type=mmcif_names.PROTEIN_CHAIN,
        )
        for db_name, run_config in run_configs.items()
    }
    found = {db_name: job.result() for db_name, job in jobs.items()}
  logging.info(
      'Getting protein MSAs took %.2f seconds for sequence %s',
      time.time() - search_started,
      sequence,
  )

  logging.info('Deduplicating MSAs for sequence %s', sequence)
  dedupe_started = time.time()
  with futures.ThreadPoolExecutor() as pool:
    unpaired_job = pool.submit(
        msa.Msa.from_multiple_msas,
        msas=[found['uniref90'], found['small_bfd'], found['mgnify']],
        deduplicate=True,
    )
    paired_job = pool.submit(
        msa.Msa.from_multiple_msas, msas=[found['uniprot']], deduplicate=False
    )
    unpaired_msa = unpaired_job.result()
    paired_msa = paired_job.result()
  logging.info(
      'Deduplicating MSAs took %.2f seconds for sequence %s, found %d unpaired'
      ' sequences, %d paired sequences',
      time.time() - dedupe_started,
      sequence,
      unpaired_msa.depth,
      paired_msa.depth,
  )

  chain_templates = _get_protein_templates(
      sequence=sequence,
      input_msa_a3m=unpaired_msa.to_a3m(),
      run_template_search=run_template_search,
      templates_config=templates_config,
      pdb_database_path=pdb_database_path,
  )
  return unpaired_msa, paired_msa, chain_templates
# Cache to avoid re-running the Nhmmer for the same sequence in homomers.
@functools.cache
def _get_rna_msa(
    sequence: str,
    nt_rna_msa_config: msa_config.RunConfig,
    rfam_msa_config: msa_config.RunConfig,
    rnacentral_msa_config: msa_config.RunConfig,
) -> msa.Msa:
  """Searches the RNA databases for a single RNA chain and merges the hits.

  The three searches run concurrently. Their results are merged with
  deduplication in the order Rfam, RNAcentral, NT-RNA. Results are memoised
  per argument tuple, so identical chains share one search.

  Args:
    sequence: The RNA query sequence.
    nt_rna_msa_config: Run configuration (wrapping an Nhmmer config) for the
      NT-RNA search.
    rfam_msa_config: Run configuration for the Rfam search.
    rnacentral_msa_config: Run configuration for the RNAcentral search.

  Returns:
    The merged, deduplicated unpaired RNA MSA.
  """
  logging.info('Getting RNA MSAs for sequence %s', sequence)
  search_started = time.time()
  run_configs = (nt_rna_msa_config, rfam_msa_config, rnacentral_msa_config)
  # The searches shell out to external binaries, so threads are not held back
  # by the GIL.
  with futures.ThreadPoolExecutor() as pool:
    jobs = [
        pool.submit(
            msa.get_msa,
            target_sequence=sequence,
            run_config=run_config,
            chain_poly_type=mmcif_names.RNA_CHAIN,
        )
        for run_config in run_configs
    ]
    nt_rna_msa, rfam_msa, rnacentral_msa = [job.result() for job in jobs]

  rna_msa = msa.Msa.from_multiple_msas(
      msas=[rfam_msa, rnacentral_msa, nt_rna_msa],
      deduplicate=True,
  )
  logging.info(
      'Getting RNA MSAs took %.2f seconds for sequence %s, found %d unpaired'
      ' sequences',
      time.time() - search_started,
      sequence,
      rna_msa.depth,
  )
  return rna_msa
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class DataPipelineConfig:
  """The configuration for the data pipeline.

  All fields are keyword-only (`kw_only=True`), which is what allows fields
  with defaults to be interleaved with required fields below.

  Attributes:
    jackhmmer_binary_path: Jackhmmer binary path, used for protein MSA search.
    nhmmer_binary_path: Nhmmer binary path, used for RNA MSA search.
    hmmalign_binary_path: Hmmalign binary path, used to align hits to the query
      profile.
    hmmsearch_binary_path: Hmmsearch binary path, used for template search.
    hmmbuild_binary_path: Hmmbuild binary path, used to build HMM profile from
      raw MSA in template search.
    small_bfd_database_path: Small BFD database path, used for protein MSA
      search (unpaired MSA).
    small_bfd_z_value: The Z-value representing the database size in number of
      sequences for E-value calculation. Must be set for sharded databases.
    mgnify_database_path: Mgnify database path, used for protein MSA search
      (unpaired MSA).
    mgnify_z_value: The Z-value representing the database size in number of
      sequences for E-value calculation. Must be set for sharded databases.
    uniprot_cluster_annot_database_path: Uniprot database path, used for protein
      paired MSA search.
    uniprot_cluster_annot_z_value: The Z-value representing the database size in
      number of sequences for E-value calculation. Must be set for sharded
      databases.
    uniref90_database_path: UniRef90 database path, used for protein MSA search.
      Its hits are merged with the small BFD and MGnify hits into the unpaired
      MSA, which is also the input used to build the template search profile.
    uniref90_z_value: The Z-value representing the database size in number of
      sequences for E-value calculation. Must be set for sharded databases.
    ntrna_database_path: NT-RNA database path, used for RNA MSA search.
    ntrna_z_value: The Z-value representing the database size in megabases for
      E-value calculation. Must be set for sharded databases.
    rfam_database_path: Rfam database path, used for RNA MSA search.
    rfam_z_value: The Z-value representing the database size in megabases for
      E-value calculation. Must be set for sharded databases.
    rna_central_database_path: RNAcentral database path, used for RNA MSA
      search.
    rna_central_z_value: The Z-value representing the database size in megabases
      for E-value calculation. Must be set for sharded databases.
    seqres_database_path: PDB sequence database path, used for template search.
    pdb_database_path: PDB database directory with mmCIF files path, used for
      template search.
    jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer (default 8).
    jackhmmer_max_parallel_shards: Maximum number of shards to search against in
      parallel. If None, one Jackhmmer instance will be run per shard. Only
      applicable if the database is sharded.
    nhmmer_n_cpu: Number of CPUs to use for Nhmmer (default 8).
    nhmmer_max_parallel_shards: Maximum number of shards to search against in
      parallel. If None, one Nhmmer instance will be run per shard. Only
      applicable if the database is sharded.
    max_template_date: The latest date of templates to use.
  """

  # Binary paths.
  jackhmmer_binary_path: str
  nhmmer_binary_path: str
  hmmalign_binary_path: str
  hmmsearch_binary_path: str
  hmmbuild_binary_path: str

  # Jackhmmer databases. A None Z-value leaves the tool's default in place
  # (passed through unchanged as z_value/dom_z_value by DataPipeline).
  small_bfd_database_path: str
  small_bfd_z_value: int | None = None
  mgnify_database_path: str
  mgnify_z_value: int | None = None
  uniprot_cluster_annot_database_path: str
  uniprot_cluster_annot_z_value: int | None = None
  uniref90_database_path: str
  uniref90_z_value: int | None = None

  # Nhmmer databases.
  ntrna_database_path: str
  ntrna_z_value: int | None = None
  rfam_database_path: str
  rfam_z_value: int | None = None
  rna_central_database_path: str
  rna_central_z_value: int | None = None

  # Template search databases.
  seqres_database_path: str
  pdb_database_path: str

  # Optional configuration for MSA tools.
  jackhmmer_n_cpu: int = 8
  jackhmmer_max_parallel_shards: int | None = None
  nhmmer_n_cpu: int = 8
  nhmmer_max_parallel_shards: int | None = None

  # Required despite following defaulted fields; legal because of kw_only=True.
  max_template_date: datetime.date
class DataPipeline:
  """Runs the alignment tools and assembles the input features.

  One run configuration is built per sequence database at construction time.
  `process` then fills in MSAs (and, for proteins, templates) for every chain
  of a fold input that does not already carry them.
  """

  def __init__(self, data_pipeline_config: DataPipelineConfig):
    """Builds every search tool configuration from `data_pipeline_config`.

    Args:
      data_pipeline_config: Binary paths, database paths, Z-values and CPU /
        sharding settings shared by all search tools.
    """
    cfg = data_pipeline_config

    # Protein MSA searches (Jackhmmer).
    self._uniref90_msa_config = self._jackhmmer_run_config(
        cfg,
        db_name='uniref90',
        db_path=cfg.uniref90_database_path,
        z_value=cfg.uniref90_z_value,
        max_sequences=10_000,
    )
    self._mgnify_msa_config = self._jackhmmer_run_config(
        cfg,
        db_name='mgnify',
        db_path=cfg.mgnify_database_path,
        z_value=cfg.mgnify_z_value,
        max_sequences=5_000,
    )
    # Set z_value=138_515_945 to match the z_value used in the paper.
    # In practice, this has minimal impact on predicted structures.
    self._small_bfd_msa_config = self._jackhmmer_run_config(
        cfg,
        db_name='small_bfd',
        db_path=cfg.small_bfd_database_path,
        z_value=cfg.small_bfd_z_value,
        max_sequences=5_000,
    )
    self._uniprot_msa_config = self._jackhmmer_run_config(
        cfg,
        db_name='uniprot_cluster_annot',
        db_path=cfg.uniprot_cluster_annot_database_path,
        z_value=cfg.uniprot_cluster_annot_z_value,
        max_sequences=50_000,
    )

    # RNA MSA searches (Nhmmer).
    self._nt_rna_msa_config = self._nhmmer_run_config(
        cfg,
        db_name='nt_rna',
        db_path=cfg.ntrna_database_path,
        z_value=cfg.ntrna_z_value,
    )
    self._rfam_msa_config = self._nhmmer_run_config(
        cfg,
        db_name='rfam_rna',
        db_path=cfg.rfam_database_path,
        z_value=cfg.rfam_z_value,
    )
    self._rnacentral_msa_config = self._nhmmer_run_config(
        cfg,
        db_name='rna_central_rna',
        db_path=cfg.rna_central_database_path,
        z_value=cfg.rna_central_z_value,
    )

    # Protein template search (Hmmsearch against PDB seqres) and hit filtering.
    hmmsearch_config = msa_config.HmmsearchConfig(
        hmmsearch_binary_path=cfg.hmmsearch_binary_path,
        hmmbuild_binary_path=cfg.hmmbuild_binary_path,
        filter_f1=0.1,
        filter_f2=0.1,
        filter_f3=0.1,
        e_value=100,
        inc_e=100,
        dom_e=100,
        incdom_e=100,
        alphabet='amino',
    )
    template_tool_config = msa_config.TemplateToolConfig(
        database_path=cfg.seqres_database_path,
        chain_poly_type=mmcif_names.PROTEIN_CHAIN,
        hmmsearch_config=hmmsearch_config,
    )
    template_filter_config = msa_config.TemplateFilterConfig(
        max_subsequence_ratio=0.95,
        min_align_ratio=0.1,
        min_hit_length=10,
        deduplicate_sequences=True,
        max_hits=4,
        max_template_date=cfg.max_template_date,
    )
    self._templates_config = msa_config.TemplatesConfig(
        template_tool_config=template_tool_config,
        filter_config=template_filter_config,
    )
    self._pdb_database_path = cfg.pdb_database_path

  @staticmethod
  def _jackhmmer_run_config(
      cfg: DataPipelineConfig,
      *,
      db_name: str,
      db_path: str,
      z_value: int | None,
      max_sequences: int,
  ) -> msa_config.RunConfig:
    """Returns a single-iteration protein Jackhmmer run config for one DB."""
    return msa_config.RunConfig(
        config=msa_config.JackhmmerConfig(
            binary_path=cfg.jackhmmer_binary_path,
            database_config=msa_config.DatabaseConfig(
                name=db_name,
                path=db_path,
            ),
            n_cpu=cfg.jackhmmer_n_cpu,
            n_iter=1,
            e_value=1e-4,
            z_value=z_value,
            dom_z_value=z_value,
            max_sequences=max_sequences,
            max_parallel_shards=cfg.jackhmmer_max_parallel_shards,
        ),
        chain_poly_type=mmcif_names.PROTEIN_CHAIN,
        crop_size=None,
    )

  @staticmethod
  def _nhmmer_run_config(
      cfg: DataPipelineConfig,
      *,
      db_name: str,
      db_path: str,
      z_value: int | None,
  ) -> msa_config.RunConfig:
    """Returns an RNA Nhmmer run config (up to 10k sequences) for one DB."""
    return msa_config.RunConfig(
        config=msa_config.NhmmerConfig(
            binary_path=cfg.nhmmer_binary_path,
            hmmalign_binary_path=cfg.hmmalign_binary_path,
            hmmbuild_binary_path=cfg.hmmbuild_binary_path,
            database_config=msa_config.DatabaseConfig(
                name=db_name,
                path=db_path,
            ),
            n_cpu=cfg.nhmmer_n_cpu,
            e_value=1e-3,
            alphabet='rna',
            z_value=z_value,
            max_sequences=10_000,
            max_parallel_shards=cfg.nhmmer_max_parallel_shards,
        ),
        chain_poly_type=mmcif_names.RNA_CHAIN,
        crop_size=None,
    )

  @staticmethod
  def _empty_protein_a3m(sequence: str) -> str:
    """Returns an A3M MSA that contains only the query `sequence`."""
    return msa.Msa.from_empty(
        query_sequence=sequence,
        chain_poly_type=mmcif_names.PROTEIN_CHAIN,
    ).to_a3m()

  @staticmethod
  def _templates_from_hits(
      template_hits: templates_lib.Templates,
  ) -> list[folding_input.Template]:
    """Converts template search hits into fold-input `Template` objects."""
    return [
        folding_input.Template(
            mmcif=hit_structure.to_mmcif(),
            query_to_template_map=hit.query_to_hit_mapping,
        )
        for hit, hit_structure in template_hits.get_hits_with_structures()
    ]

  def process_protein_chain(
      self, chain: folding_input.ProteinChain
  ) -> folding_input.ProteinChain:
    """Fills in MSAs and templates for a single protein chain.

    Supported input states:
      * No MSAs, templates None or []: search MSAs; search templates only if
        templates is None.
      * Both MSAs set, templates None: search templates only.
      * Both MSAs and templates set: search nothing; empty MSA strings are
        replaced by a query-only MSA.
    Any other combination raises ValueError.
    """
    unpaired_given = chain.unpaired_msa is not None
    paired_given = chain.paired_msa is not None
    templates_given = chain.templates is not None

    if not (unpaired_given or paired_given or chain.templates):
      # MSA None - search. Templates either [] - don't search, or None - search.
      found_unpaired, found_paired, template_hits = (
          _get_protein_msa_and_templates(
              sequence=chain.sequence,
              run_template_search=not templates_given,
              uniref90_msa_config=self._uniref90_msa_config,
              mgnify_msa_config=self._mgnify_msa_config,
              small_bfd_msa_config=self._small_bfd_msa_config,
              uniprot_msa_config=self._uniprot_msa_config,
              templates_config=self._templates_config,
              pdb_database_path=self._pdb_database_path,
          )
      )
      unpaired_msa = found_unpaired.to_a3m()
      paired_msa = found_paired.to_a3m()
      templates = self._templates_from_hits(template_hits)
    elif unpaired_given and paired_given and not templates_given:
      # Custom MSAs but no templates: search for templates only.
      empty_msa = self._empty_protein_a3m(chain.sequence)
      unpaired_msa = chain.unpaired_msa or empty_msa
      paired_msa = chain.paired_msa or empty_msa
      template_hits = _get_protein_templates(
          sequence=chain.sequence,
          input_msa_a3m=unpaired_msa,
          run_template_search=True,
          templates_config=self._templates_config,
          pdb_database_path=self._pdb_database_path,
      )
      templates = self._templates_from_hits(template_hits)
    else:
      # Everything must be provided; nothing is searched.
      if not (unpaired_given and paired_given and templates_given):
        raise ValueError(
            f'Protein chain {chain.id} has unpaired MSA, paired MSA, or'
            ' templates set only partially. If you want to run the pipeline'
            ' with custom MSA/templates, you need to set all of them. You can'
            ' set MSA to empty string and templates to empty list to signify'
            ' that they should not be used and searched for.'
        )
      logging.info(
          'Skipping MSA and template search for protein chain %s because it '
          'already has MSAs and templates.',
          chain.id,
      )
      if not chain.unpaired_msa:
        logging.info('Using empty unpaired MSA for protein chain %s', chain.id)
      if not chain.paired_msa:
        logging.info('Using empty paired MSA for protein chain %s', chain.id)
      if not chain.templates:
        logging.info('Using no templates for protein chain %s', chain.id)
      empty_msa = self._empty_protein_a3m(chain.sequence)
      unpaired_msa = chain.unpaired_msa or empty_msa
      paired_msa = chain.paired_msa or empty_msa
      templates = chain.templates

    return folding_input.ProteinChain(
        id=chain.id,
        sequence=chain.sequence,
        ptms=chain.ptms,
        unpaired_msa=unpaired_msa,
        paired_msa=paired_msa,
        templates=templates,
    )

  def process_rna_chain(
      self, chain: folding_input.RnaChain
  ) -> folding_input.RnaChain:
    """Fills in the unpaired MSA of a single RNA chain if it is missing.

    A None MSA triggers the RNA database search; any provided MSA is kept, with
    an empty string replaced by a query-only MSA.
    """
    if chain.unpaired_msa is None:
      unpaired_msa = _get_rna_msa(
          sequence=chain.sequence,
          nt_rna_msa_config=self._nt_rna_msa_config,
          rfam_msa_config=self._rfam_msa_config,
          rnacentral_msa_config=self._rnacentral_msa_config,
      ).to_a3m()
    else:
      logging.info(
          'Skipping MSA search for RNA chain %s because it already has MSA.',
          chain.id,
      )
      if not chain.unpaired_msa:
        logging.info('Using empty unpaired MSA for RNA chain %s', chain.id)
      empty_msa = msa.Msa.from_empty(
          query_sequence=chain.sequence, chain_poly_type=mmcif_names.RNA_CHAIN
      ).to_a3m()
      unpaired_msa = chain.unpaired_msa or empty_msa

    return folding_input.RnaChain(
        id=chain.id,
        sequence=chain.sequence,
        modifications=chain.modifications,
        unpaired_msa=unpaired_msa,
    )

  def process(self, fold_input: folding_input.Input) -> folding_input.Input:
    """Runs MSA and template tools and returns a new Input with the results.

    Protein and RNA chains are processed; all other chain types are passed
    through unchanged. Per-chain timing is printed to stdout.
    """
    processed_chains = []
    for chain in fold_input.chains:
      print(f'Running data pipeline for chain {chain.id}...')
      chain_started = time.time()
      if isinstance(chain, folding_input.ProteinChain):
        processed = self.process_protein_chain(chain)
      elif isinstance(chain, folding_input.RnaChain):
        processed = self.process_rna_chain(chain)
      else:
        processed = chain
      processed_chains.append(processed)
      print(
          f'Running data pipeline for chain {chain.id} took'
          f' {time.time() - chain_started:.2f} seconds',
      )
    return dataclasses.replace(fold_input, chains=processed_chains)
================================================
FILE: src/alphafold3/data/structure_stores.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Library for loading structure data from various sources."""
from collections.abc import Mapping, Sequence
import functools
import os
import pathlib
import tarfile
class NotFoundError(KeyError):
  """Raised when the structure store doesn't contain the requested target.

  Subclasses KeyError, so callers that catch KeyError also catch this.
  """
class StructureStore:
"""Handles the retrieval of mmCIF files from a filesystem."""
def __init__(
self,
structures: str | os.PathLike[str] | Mapping[str, str],
):
"""Initialises the instance.
Args:
structures: Path of the directory where the mmCIF files are or a Mapping
from target name to mmCIF string.
"""
if isinstance(structures, Mapping):
self._structure_mapping = structures
self._structure_path = None
self._structure_tar = None
else:
self._structure_mapping = None
path_str = os.fspath(structures)
if path_str.endswith('.tar'):
self._structure_tar = tarfile.open(path_str, 'r')
self._structure_path = None
else:
self._structure_path = pathlib.Path(structures)
self._structure_tar = None
@functools.cached_property
def _tar_members(self) -> Mapping[str, tarfile.TarInfo]:
assert self._structure_tar is not None
return {
path.stem: tarinfo
for tarinfo in self._structure_tar.getmembers()
if tarinfo.isfile()
and (path := pathlib.Path(tarinfo.path.lower())).suffix == '.cif'
}
def get_mmcif_str(self, target_name: str) -> str:
"""Returns an mmCIF for a given `target_name`.
Args:
target_name: Name specifying the target mmCIF.
Raises:
NotFoundError: If the target is not found.
"""
if self._structure_mapping is not None:
try:
return self._structure_mapping[target_name]
except KeyError as e:
raise NotFoundError(f'{target_name=} not found') from e
if self._structure_tar is not None:
try:
member = self._tar_members[target_name]
if struct_file := self._structure_tar.extractfile(member):
return struct_file.read().decode()
else:
raise NotFoundError(f'{target_name=} not found')
except KeyError:
raise NotFoundError(f'{target_name=} not found') from None
filepath = self._structure_path / f'{target_name}.cif'
try:
return filepath.read_text()
except FileNotFoundError as e:
raise NotFoundError(f'{target_name=} not found at {filepath=}') from e
def target_names(self) -> Sequence[str]:
"""Returns all targets in the store."""
if self._structure_mapping is not None:
return [*self._structure_mapping.keys()]
elif self._structure_tar is not None:
return sorted(self._tar_members.keys())
elif self._structure_path is not None:
return sorted([path.stem for path in self._structure_path.glob('*.cif')])
return ()
================================================
FILE: src/alphafold3/data/template_realign.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Realign sequences found in PDB seqres to the actual CIF sequences."""
from collections.abc import Mapping
class AlignmentError(Exception):
  """Failed alignment between the hit sequence and the actual mmCIF sequence.

  Raised by `realign_hit_to_structure` when the Structure sequence is longer
  than the full PDB seqres sequence, or when the hit start/end indices do not
  match the hit sequence length.
  """
def realign_hit_to_structure(
    *,
    hit_sequence: str,
    hit_start_index: int,
    hit_end_index: int,
    full_length: int,
    structure_sequence: str,
    query_to_hit_mapping: Mapping[int, int],
) -> Mapping[int, int]:
  """Realigns the hit sequence to the Structure sequence.

  For example, for the given input:
    query_sequence : ABCDEFGHIJKL
    hit_sequence   : ---DEFGHIJK-
    struc_sequence : XDEFGHKL
  the mapping is {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7}. However, the
  actual Structure sequence has an extra X at the start as well as no IJ. So the
  alignment from the query to the Structure sequence will be:
    hit_sequence   : ---DEFGHIJK-
    struc_aligned  : --XDEFGH--KL
  and the new mapping will therefore be: {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 10: 6}.

  Every possible number of leading gaps (up to the length deficit of the
  Structure sequence and the hit start) is tried; the candidate with the most
  matching residues wins, ties going to the one with more leading gaps.

  Args:
    hit_sequence: The PDB seqres hit sequence obtained from Hmmsearch, but
      without any gaps. This is not the full PDB seqres template sequence but
      rather just its subsequence from hit_start_index to hit_end_index.
    hit_start_index: The start index of the hit sequence in the full PDB seqres
      template sequence (inclusive).
    hit_end_index: The end index of the hit sequence in the full PDB seqres
      template sequence (exclusive).
    full_length: The length of the full PDB seqres template sequence.
    structure_sequence: The actual sequence extracted from the Structure
      corresponding to this template. In vast majority of cases this is the same
      as the PDB seqres sequence, but this function handles the cases when not.
    query_to_hit_mapping: The mapping from the query sequence to the
      hit_sequence.

  Raises:
    AlignmentError: If the Structure sequence is longer than the full PDB
      seqres sequence, or the hit indices disagree with the hit length.

  Returns:
    A mapping from the query sequence to the actual Structure sequence.
  """
  length_deficit = full_length - len(structure_sequence)
  if length_deficit < 0:
    raise AlignmentError(
        f'The Structure sequence ({len(structure_sequence)}) '
        f'must be shorter than the PDB seqres sequence ({full_length}):\n'
        f'Structure sequence : {structure_sequence}\n'
        f'PDB seqres sequence: {hit_sequence}'
    )

  if len(hit_sequence) != hit_end_index - hit_start_index:
    raise AlignmentError(
        f'The difference of {hit_end_index=} and {hit_start_index=} does not '
        f'equal to the length of the {hit_sequence}: {len(hit_sequence)}'
    )

  # (score, start offset into structure_sequence, query -> subsequence map).
  best = (-1, 0, query_to_hit_mapping)
  for leading_gaps in range(min(hit_start_index, length_deficit) + 1):
    window_start = hit_start_index - leading_gaps
    window_end = hit_end_index - leading_gaps
    candidate_mapping, candidate_score = _remap_to_struc_seq(
        hit_seq=hit_sequence,
        struc_seq=structure_sequence[window_start:window_end],
        max_num_gaps=length_deficit - leading_gaps,
        mapping=query_to_hit_mapping,
    )
    # `>=` lets a later candidate (more leading gaps) win ties.
    if candidate_score >= best[0]:
      best = (candidate_score, window_start, candidate_mapping)

  _, best_start, best_mapping = best
  return {q: h + best_start for q, h in best_mapping.items()}
def _remap_to_struc_seq(
*,
hit_seq: str,
struc_seq: str,
max_num_gaps: int,
mapping: Mapping[int, int],
) -> tuple[Mapping[int, int], int]:
"""Remaps the query -> hit mapping to match the actual Structure sequence.
Args:
hit_seq: The hit sequence - a subsequence of the PDB seqres sequence without
any Hmmsearch modifications like inserted gaps or lowercased residues.
struc_seq: The actual sequence obtained from the corresponding Structure.
max_num_gaps: The maximum number of gaps that can be inserted in the
Structure sequence. In practice, this is the length difference between the
PDB seqres sequence and the actual Structure sequence.
mapping: The mapping from the query residues to the hit residues. This will
be remapped to point to the actual Structure sequence using a simple
realignment algorithm.
Returns:
A tuple of (mapping, score):
* Mapping from the query to the actual Structure sequence.
* Score which is the number of matching aligned residues.
Raises:
ValueError if the structure sequence isn't shorter than the seqres sequence.
ValueError if the alignment fails.
"""
hit_seq_idx = 0
struc_seq_idx = 0
hit_to_struc_seq_mapping = {}
score = 0
# This while loop is guaranteed to terminate since we increase both
# struc_seq_idx and hit_seq_idx by at least 1 in each iteration.
remaining_num_gaps = max_num_gaps
while hit_seq_idx < len(hit_seq) and struc_seq_idx < len(struc_seq):
if hit_seq[hit_seq_idx] != struc_seq[struc_seq_idx]:
# Explore which alignment aligns the next residue (if present).
best_shift = 0
for shift in range(0, remaining_num_gaps + 1):
next_hit_res = hit_seq[hit_seq_idx + shift : hit_seq_idx + shift + 1]
next_struc_res = struc_seq[struc_seq_idx : struc_seq_idx + 1]
if next_hit_res == next_struc_res:
best_shift = shift
break
hit_seq_idx += best_shift
remaining_num_gaps -= best_shift
hit_to_struc_seq_mapping[hit_seq_idx] = struc_seq_idx
score += hit_seq[hit_seq_idx] == struc_seq[struc_seq_idx]
hit_seq_idx += 1
struc_seq_idx += 1
fixed_mapping = {}
for query_idx, original_hit_idx in mapping.items():
fixed_hit_idx = hit_to_struc_seq_mapping.get(original_hit_idx)
if fixed_hit_idx is not None:
fixed_mapping[query_idx] = fixed_hit_idx
return fixed_mapping, score
================================================
FILE: src/alphafold3/data/templates.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""API for retrieving and manipulating template search results."""
from collections.abc import Iterable, Iterator, Mapping, Sequence
import dataclasses
import datetime
import functools
import os
import re
from typing import Any, Final, Self, TypeAlias
from absl import logging
from alphafold3 import structure
from alphafold3.common import resources
from alphafold3.constants import atom_types
from alphafold3.constants import mmcif_names
from alphafold3.constants import residue_names
from alphafold3.data import msa_config
from alphafold3.data import parsers
from alphafold3.data import structure_stores
from alphafold3.data import template_realign
from alphafold3.data.tools import hmmsearch
from alphafold3.structure import mmcif
import numpy as np
_POLYMER_FEATURES: Final[Mapping[str, np.float64 | np.int32 | object]] = {
'template_aatype': np.int32,
'template_all_atom_masks': np.float64,
'template_all_atom_positions': np.float64,
'template_domain_names': object,
'template_release_date': object,
'template_sequence': object,
}
_LIGAND_FEATURES: Final[Mapping[str, Any]] = {
'ligand_features': Mapping[str, Any]
}
TemplateFeatures: TypeAlias = Mapping[
str, np.ndarray | bytes | Mapping[str, np.ndarray | bytes]
]
_REQUIRED_METADATA_COLUMNS: Final[Sequence[str]] = (
'seq_release_date',
'seq_unresolved_res_num',
'seq_author_chain_id',
'seq_sequence',
)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class _Polymer:
"""Container for alphabet specific (dna, rna, protein) atom information."""
min_atoms: int
num_atom_types: int
atom_order: Mapping[str, int]
# Atom layout per supported polymer chain type, keyed by mmCIF chain type.
# Protein chains use the ATOM37 layout; DNA and RNA chains share the ATOM29
# layout. Looked up by `get_polymer_features`.
_POLYMERS: Final[Mapping[str, _Polymer]] = {
    mmcif_names.PROTEIN_CHAIN: _Polymer(
        min_atoms=5,
        num_atom_types=atom_types.ATOM37_NUM,
        atom_order=atom_types.ATOM37_ORDER,
    ),
    mmcif_names.DNA_CHAIN: _Polymer(
        min_atoms=21,
        num_atom_types=atom_types.ATOM29_NUM,
        atom_order=atom_types.ATOM29_ORDER,
    ),
    mmcif_names.RNA_CHAIN: _Polymer(
        min_atoms=20,
        num_atom_types=atom_types.ATOM29_NUM,
        atom_order=atom_types.ATOM29_ORDER,
    ),
}
def _encode_restype(
    chain_poly_type: str,
    sequence: str,
) -> Sequence[int]:
  """Maps every residue character of `sequence` to its integer residue type.

  The sequence is consumed one character at a time, so each residue must be
  given by a single-letter code.

  Args:
    chain_poly_type: Polymer chain type (protein, RNA or DNA) selecting the
      encoding table.
    sequence: Residues as single-letter codes; '-' marks a gap. For proteins,
      the rare codes listed in `_STANDARDIZED_AA` are first replaced by their
      standard counterparts. For RNA and DNA, unrecognised characters are
      encoded as the generic unknown nucleic residue.

  Returns:
    A list with one integer per character of `sequence`.

  Raises:
    NotImplementedError: If `chain_poly_type` is not protein, RNA or DNA.
    KeyError: For a protein character missing from the protein encoding table.
  """
  if chain_poly_type == mmcif_names.PROTEIN_CHAIN:
    protein_to_int = (
        residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT
    )
    return [
        protein_to_int[_STANDARDIZED_AA.get(letter, letter)]
        for letter in sequence
    ]

  nucleic_order = residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP
  unknown_letter = residue_names.UNK_NUCLEIC_ONE_LETTER
  unknown_index = nucleic_order[unknown_letter]

  if chain_poly_type == mmcif_names.RNA_CHAIN:
    return [nucleic_order.get(letter, unknown_index) for letter in sequence]

  if chain_poly_type == mmcif_names.DNA_CHAIN:
    # DNA letters go through the one-to-two letter table first; anything it
    # does not know becomes the generic nucleic unknown (shared with RNA).
    dna_one_to_two = residue_names.DNA_COMMON_ONE_TO_TWO
    return [
        nucleic_order.get(
            dna_one_to_two.get(letter, unknown_letter), unknown_index
        )
        for letter in sequence
    ]

  raise NotImplementedError(f'"{chain_poly_type}" unsupported.')
_DAYS_BEFORE_QUERY_DATE: Final[int] = 60
_HIT_DESCRIPTION_REGEX = re.compile(
r'(?P[a-z0-9]{4,})_(?P\w+)/(?P\d+)-(?P\d+) '
r'.* length:(?P\d+)\b.*'
)
_STANDARDIZED_AA = {'B': 'D', 'J': 'X', 'O': 'X', 'U': 'C', 'Z': 'E'}
class Error(Exception):
  """Root of the exception hierarchy raised by this module."""


class HitDateError(Error):
  """Raised when a template hit is released after the allowed cutoff date."""


class InvalidTemplateError(Error):
  """Raised when a template hit cannot be used, e.g. it was never filtered."""
@dataclasses.dataclass(frozen=True, kw_only=True)
class Hit:
"""Template hit metrics derived from the MSA for filtering and featurising.
Attributes:
pdb_id: The PDB ID of the hit.
auth_chain_id: The author chain ID of the hit.
hmmsearch_sequence: Hit sequence as given in hmmsearch a3m output.
structure_sequence: Hit sequence as given in PDB structure.
unresolved_res_indices: Indices of unresolved residues in the structure
sequence. 0-based.
query_sequence: The query nucleotide/amino acid sequence.
start_index: The start index of the sequence relative to the full PDB seqres
sequence. Inclusive and uses 0-based indexing.
end_index: The end index of the sequence relative to the full PDB seqres
sequence. Exclusive and uses 0-based indexing.
full_length: Length of the full PDB seqres sequence. This can be different
from the length from the actual sequence we get from the mmCIF and we use
this to detect whether we need to realign or not.
release_date: The release date of the PDB corresponding to this hit.
chain_poly_type: The polymer type of the selected hit structure.
"""
pdb_id: str
auth_chain_id: str
hmmsearch_sequence: str
structure_sequence: str
unresolved_res_indices: Sequence[int] | None
query_sequence: str
start_index: int
end_index: int
full_length: int
release_date: datetime.date
chain_poly_type: str
@functools.cached_property
def query_to_hit_mapping(self) -> Mapping[int, int]:
"""0-based query index to hit index mapping."""
query_to_hit_mapping = {}
hit_index = 0
query_index = 0
for residue in self.hmmsearch_sequence:
# Gap inserted in the template
if residue == '-':
query_index += 1
# Deleted residue in the template (would be a gap in the query).
elif residue.islower():
hit_index += 1
# Normal aligned residue, in both query and template. Add to mapping.
elif residue.isupper():
query_to_hit_mapping[query_index] = hit_index
query_index += 1
hit_index += 1
structure_subseq = self.structure_sequence[
self.start_index : self.end_index
]
if self.matching_sequence != structure_subseq:
# The seqres sequence doesn't match the structure sequence. Two cases:
# 1. The sequences have the same length. The sequences are different
# because our 3->1 residue code mapping is different from the one PDB
# uses. We don't do anything in this case as both sequences have the
# same length, so the original query to hit mapping stays valid.
# 2. The sequences don't have the same length, the one in structure is
# shorter. In this case we change the mapping to match the actual
# structure sequence using a simple realignment algorithm.
# This procedure was validated on all PDB seqres (2023_01_12) sequences
# and handles all cases that can happen.
if self.full_length != len(self.structure_sequence):
return template_realign.realign_hit_to_structure(
hit_sequence=self.matching_sequence,
hit_start_index=self.start_index,
hit_end_index=self.end_index,
full_length=self.full_length,
structure_sequence=self.structure_sequence,
query_to_hit_mapping=query_to_hit_mapping,
)
# Hmmsearch returns a subsequence and so far indices have been relative to
# the subsequence. Add an offset to index relative to the full structure
# sequence.
return {q: h + self.start_index for q, h in query_to_hit_mapping.items()}
@property
def matching_sequence(self) -> str:
"""Returns the matching hit sequence including insertions.
Make deleted residues uppercase and remove gaps ("-").
"""
return self.hmmsearch_sequence.upper().replace('-', '')
@functools.cached_property
def output_templates_sequence(self) -> str:
"""Returns the final template sequence."""
result_seq = ['-'] * len(self.query_sequence)
for query_index, template_index in self.query_to_hit_mapping.items():
result_seq[query_index] = self.structure_sequence[template_index]
return ''.join(result_seq)
@property
def length_ratio(self) -> float:
"""Ratio of the length of the hit sequence to the query."""
return len(self.matching_sequence) / len(self.query_sequence)
@property
def align_ratio(self) -> float:
"""Ratio of the number of aligned residues to the query length."""
return len(self.query_to_hit_mapping) / len(self.query_sequence)
@functools.cached_property
def is_valid(self) -> bool:
"""Whether hit can be used as a template."""
if self.unresolved_res_indices is None:
return False
return bool(
set(self.query_to_hit_mapping.values())
- set(self.unresolved_res_indices)
)
@property
def full_name(self) -> str:
"""A full name of the hit."""
return f'{self.pdb_id}_{self.auth_chain_id}'
def __post_init__(self):
if not self.pdb_id.islower() and not self.pdb_id.isdigit():
raise ValueError(f'pdb_id must be lowercase {self.pdb_id}')
if not (0 <= self.start_index <= self.end_index):
raise ValueError(
'Start must be non-negative and less than or equal to end index. '
f'Range: {self.start_index}-{self.end_index}'
)
if len(self.matching_sequence) != (self.end_index - self.start_index):
raise ValueError(
'Sequence length must be equal to end_index - start_index. '
f'{len(self.matching_sequence)} != {self.end_index} - '
f'{self.start_index}'
)
if self.full_length < 0:
raise ValueError(f'Full length must be non-negative: {self.full_length}')
def keep(
self,
*,
release_date_cutoff: datetime.date | None,
max_subsequence_ratio: float | None,
min_hit_length: int | None,
min_align_ratio: float | None,
) -> bool:
"""Returns whether the hit should be kept.
In addition to filtering on all of the provided parameters, this method also
excludes hits with unresolved residues.
Args:
release_date_cutoff: Maximum release date of the template.
max_subsequence_ratio: If set, excludes hits which are an exact
subsequence of the query sequence, and longer than this ratio. Useful to
avoid ground truth leakage.
min_hit_length: If set, excludes hits which have fewer residues than this.
min_align_ratio: If set, excludes hits where the number of residues
aligned to the query is less than this proportion of the template
length.
"""
# Exclude hits which are too recent.
if (
release_date_cutoff is not None
and self.release_date > release_date_cutoff
):
return False
# Exclude hits which are large duplicates of the query_sequence.
if (
max_subsequence_ratio is not None
and self.length_ratio > max_subsequence_ratio
):
if self.matching_sequence in self.query_sequence:
return False
# Exclude hits which are too short.
if (
min_hit_length is not None
and len(self.matching_sequence) < min_hit_length
):
return False
# Exclude hits with unresolved residues.
if not self.is_valid:
return False
# Exclude hits with too few alignments.
try:
if min_align_ratio is not None and self.align_ratio <= min_align_ratio:
return False
except template_realign.AlignmentError as e:
logging.warning('Failed to align %s: %s', self, str(e))
return False
return True
def _filter_hits(
    hits: Iterable[Hit],
    release_date_cutoff: datetime.date,
    max_subsequence_ratio: float | None,
    min_align_ratio: float | None,
    min_hit_length: int | None,
    deduplicate_sequences: bool,
    max_hits: int | None,
) -> Sequence[Hit]:
  """Returns, in input order, the hits that pass `Hit.keep`.

  Args:
    hits: Candidate hits. Consumed lazily: iteration stops as soon as
      `max_hits` hits have been kept.
    release_date_cutoff: Forwarded to `Hit.keep`.
    max_subsequence_ratio: Forwarded to `Hit.keep`.
    min_align_ratio: Forwarded to `Hit.keep`.
    min_hit_length: Forwarded to `Hit.keep`.
    deduplicate_sequences: If True, a hit whose `output_templates_sequence`
      equals that of an earlier kept hit is dropped.
    max_hits: Maximum number of hits to keep. A falsy value (None or 0) means
      no limit.

  Returns:
    A list with the kept hits.
  """
  kept = []
  seen_sequences = set()
  for candidate in hits:
    passes_filters = candidate.keep(
        max_subsequence_ratio=max_subsequence_ratio,
        min_align_ratio=min_align_ratio,
        min_hit_length=min_hit_length,
        release_date_cutoff=release_date_cutoff,
    )
    if not passes_filters:
      continue
    if deduplicate_sequences:
      # Only the first hit for each output template sequence survives.
      template_sequence = candidate.output_templates_sequence
      if template_sequence in seen_sequences:
        continue
      seen_sequences.add(template_sequence)
    kept.append(candidate)
    if max_hits and len(kept) == max_hits:
      break
  return kept
# NOTE(review): with `init=False` and no annotated fields, the generated
# `__eq__` compares empty field tuples, so any two Templates instances compare
# equal (and instances are unhashable) - confirm this is intended.
@dataclasses.dataclass(init=False)
class Templates:
  """A container for templates that were found for the given query sequence.

  The structure store used to fetch template mmCIFs must be supplied by the
  caller. Reusing one store across instances avoids the cost of repeated
  construction and metadata loading.
  """

  def __init__(
      self,
      *,
      query_sequence: str,
      hits: Sequence[Hit],
      max_template_date: datetime.date,
      structure_store: structure_stores.StructureStore,
      query_release_date: datetime.date | None = None,
  ):
    """Initialises the container and validates the hits.

    Args:
      query_sequence: The polymer sequence of the target query.
      hits: Template hits; every hit must carry `query_sequence` and all hits
        must share one chain_poly_type.
      max_template_date: Latest allowed template release date.
      structure_store: Structure store to fetch template structures from.
      query_release_date: Optional query release date; templates must then be
        released at least `_DAYS_BEFORE_QUERY_DATE` days earlier.

    Raises:
      ValueError: If a hit has a different query sequence, or the hits have
        mixed chain_poly_types.
    """
    self._query_sequence = query_sequence
    self._hits = tuple(hits)
    self._max_template_date = max_template_date
    self._query_release_date = query_release_date
    # NOTE(review): not read anywhere in this module.
    self._hit_structures = {}
    self._structure_store = structure_store

    if any(h.query_sequence != self._query_sequence for h in self.hits):
      raise ValueError('All hits must match the query sequence.')

    if self._hits:
      chain_poly_type = self._hits[0].chain_poly_type
      if any(h.chain_poly_type != chain_poly_type for h in self.hits):
        raise ValueError('All hits must have the same chain_poly_type.')

  @classmethod
  def from_seq_and_a3m(
      cls,
      *,
      query_sequence: str,
      msa_a3m: str,
      max_template_date: datetime.date,
      database_path: os.PathLike[str] | str,
      hmmsearch_config: msa_config.HmmsearchConfig,
      max_a3m_query_sequences: int | None,
      structure_store: structure_stores.StructureStore,
      filter_config: msa_config.TemplateFilterConfig | None = None,
      query_release_date: datetime.date | None = None,
      chain_poly_type: str = mmcif_names.PROTEIN_CHAIN,
  ) -> Self:
    """Creates templates from a run of hmmsearch tool against a custom a3m.

    Args:
      query_sequence: The polymer sequence of the target query.
      msa_a3m: An a3m of related polymers aligned to the query sequence, this is
        used to create an HMM for the hmmsearch run.
      max_template_date: This is used to filter templates for training, ensuring
        that they do not leak ground truth information used in testing sets.
      database_path: A path to the sequence database to search for templates.
      hmmsearch_config: Config with Hmmsearch settings.
      max_a3m_query_sequences: The maximum number of input MSA sequences to use
        to construct the profile which is then used to search for templates.
      structure_store: Structure store to fetch template structures from.
      filter_config: Optional config that controls which and how many hits to
        keep. More performant than constructing and then filtering. If not
        provided, no filtering is done.
      query_release_date: The release_date of the template query, this is used
        to filter templates for training, ensuring that they do not leak
        structure information from the future.
      chain_poly_type: The polymer type of the templates.

    Returns:
      Templates object containing a list of Hits initialised from the
      structure_store metadata and a3m alignments.
    """
    hmmsearch_a3m = run_hmmsearch_with_a3m(
        database_path=database_path,
        hmmsearch_config=hmmsearch_config,
        max_a3m_query_sequences=max_a3m_query_sequences,
        a3m=msa_a3m,
    )
    return cls.from_hmmsearch_a3m(
        query_sequence=query_sequence,
        a3m=hmmsearch_a3m,
        max_template_date=max_template_date,
        query_release_date=query_release_date,
        chain_poly_type=chain_poly_type,
        structure_store=structure_store,
        filter_config=filter_config,
    )

  @classmethod
  def from_hmmsearch_a3m(
      cls,
      *,
      query_sequence: str,
      a3m: str,
      max_template_date: datetime.date,
      structure_store: structure_stores.StructureStore,
      filter_config: msa_config.TemplateFilterConfig | None = None,
      query_release_date: datetime.date | None = None,
      chain_poly_type: str = mmcif_names.PROTEIN_CHAIN,
  ) -> Self:
    """Creates Templates from a Hmmsearch A3M.

    Args:
      query_sequence: The polymer sequence of the target query.
      a3m: Results of Hmmsearch in A3M format. This provides a list of potential
        template alignments and pdb codes.
      max_template_date: This is used to filter templates for training, ensuring
        that they do not leak ground truth information used in testing sets.
      structure_store: Structure store to fetch template structures from.
      filter_config: Optional config that controls which and how many hits to
        keep. More performant than constructing and then filtering. If not
        provided, no filtering is done. The release date filter uses the
        earliest of `filter_config.max_template_date`, `max_template_date` and
        (if set) `query_release_date - _DAYS_BEFORE_QUERY_DATE` days, so kept
        hits also satisfy the cutoff enforced by `structures`.
      query_release_date: The release_date of the template query, this is used
        to filter templates for training, ensuring that they do not leak
        structure information from the future.
      chain_poly_type: The polymer type of the templates.

    Returns:
      Templates object containing a list of Hits initialised from the
      structure_store metadata and a3m alignments.
    """

    def hit_generator(a3m: str):
      if not a3m:
        return  # Hmmsearch could return an empty string if there are no hits.

      for hit_seq, hit_desc in parsers.lazy_parse_fasta_string(a3m):
        pdb_id, auth_chain_id, start, end, full_length = _parse_hit_description(
            hit_desc
        )

        release_date, sequence, unresolved_res_ids = _parse_hit_metadata(
            structure_store, pdb_id, auth_chain_id
        )
        if unresolved_res_ids is None:
          continue

        # seq_unresolved_res_num are 1-based, setting to 0-based indices.
        unresolved_indices = [i - 1 for i in unresolved_res_ids]

        yield Hit(
            pdb_id=pdb_id,
            auth_chain_id=auth_chain_id,
            hmmsearch_sequence=hit_seq,
            structure_sequence=sequence,
            query_sequence=query_sequence,
            unresolved_res_indices=unresolved_indices,
            start_index=start - 1,  # Raw value is residue number, not index.
            end_index=end,
            full_length=full_length,
            release_date=datetime.date.fromisoformat(release_date),
            chain_poly_type=chain_poly_type,
        )

    if filter_config is None:
      hits = tuple(hit_generator(a3m))
    else:
      # Use the same cutoff the returned object enforces in `structures`
      # (see `release_date_cutoff`), otherwise hits kept here could later
      # raise HitDateError during featurisation.
      candidate_cutoffs = [filter_config.max_template_date, max_template_date]
      if query_release_date is not None:
        candidate_cutoffs.append(
            query_release_date
            - datetime.timedelta(days=_DAYS_BEFORE_QUERY_DATE)
        )
      release_date_cutoff = min(c for c in candidate_cutoffs if c is not None)
      hits = _filter_hits(
          hit_generator(a3m),
          release_date_cutoff=release_date_cutoff,
          max_subsequence_ratio=filter_config.max_subsequence_ratio,
          min_align_ratio=filter_config.min_align_ratio,
          min_hit_length=filter_config.min_hit_length,
          deduplicate_sequences=filter_config.deduplicate_sequences,
          max_hits=filter_config.max_hits,
      )

    return cls(
        query_sequence=query_sequence,
        query_release_date=query_release_date,
        hits=hits,
        max_template_date=max_template_date,
        structure_store=structure_store,
    )

  @property
  def query_sequence(self) -> str:
    return self._query_sequence

  @property
  def hits(self) -> tuple[Hit, ...]:
    return self._hits

  @property
  def query_release_date(self) -> datetime.date | None:
    return self._query_release_date

  @property
  def num_hits(self) -> int:
    return len(self._hits)

  @functools.cached_property
  def release_date_cutoff(self) -> datetime.date:
    """Latest allowed template release date for this query."""
    if self.query_release_date is None:
      return self._max_template_date
    return min(
        self._max_template_date,
        self.query_release_date
        - datetime.timedelta(days=_DAYS_BEFORE_QUERY_DATE),
    )

  def __repr__(self) -> str:
    return f'Templates({self.num_hits} hits)'

  def filter(
      self,
      *,
      max_subsequence_ratio: float | None,
      min_align_ratio: float | None,
      min_hit_length: int | None,
      deduplicate_sequences: bool,
      max_hits: int | None,
  ) -> Self:
    """Returns a new Templates object with only the hits that pass all filters.

    This also filters on query_release_date and max_template_date.

    Args:
      max_subsequence_ratio: If set, excludes hits which are an exact
        subsequence of the query sequence, and longer than this ratio. Useful to
        avoid ground truth leakage.
      min_align_ratio: If set, excludes hits where the number of query residues
        aligned to the hit is at most this proportion of the query length.
      min_hit_length: If set, excludes hits which have fewer residues than this.
      deduplicate_sequences: Whether to exclude duplicate template sequences,
        keeping only the first. This can be useful in increasing the diversity
        of hits especially in the case of homomer hits.
      max_hits: If set, excludes any hits which exceed this count.
    """
    filtered_hits = _filter_hits(
        hits=self._hits,
        release_date_cutoff=self.release_date_cutoff,
        max_subsequence_ratio=max_subsequence_ratio,
        min_align_ratio=min_align_ratio,
        min_hit_length=min_hit_length,
        deduplicate_sequences=deduplicate_sequences,
        max_hits=max_hits,
    )
    return type(self)(
        query_sequence=self.query_sequence,
        query_release_date=self.query_release_date,
        hits=filtered_hits,
        max_template_date=self._max_template_date,
        structure_store=self._structure_store,
    )

  def get_hits_with_structures(
      self,
  ) -> Sequence[tuple[Hit, structure.Structure]]:
    """Returns hits + Structures, Structures filtered to the hit's chain.

    Raises:
      InvalidTemplateError: If any hit is not valid (hits must be filtered
        first).
    """
    results = []
    structures = {struc.name.lower(): struc for struc in self.structures}
    for hit in self.hits:
      if not hit.is_valid:
        raise InvalidTemplateError(
            'Hits must be filtered before calling get_hits_with_structures.'
        )
      struc = structures[hit.pdb_id]
      # NOTE(review): `.get` yields None when the author chain is missing and
      # then calls `filter(chain_id=None)`; `featurize` indexes instead (and
      # would raise KeyError). Confirm which behaviour is intended.
      label_chain_id = struc.polymer_auth_asym_id_to_label_asym_id().get(
          hit.auth_chain_id
      )
      results.append((hit, struc.filter(chain_id=label_chain_id)))
    return results

  def featurize(
      self,
      include_ligand_features: bool = True,
  ) -> TemplateFeatures:
    """Featurises the templates and returns a map of feature names to features.

    NB: If you don't do any prefiltering, this method might be slow to run
    as it has to fetch many CIFs and featurize them all.

    Args:
      include_ligand_features: Whether to compute ligand features.

    Returns:
      Template features: A mapping of template feature labels to features, which
      may be numpy arrays, bytes objects, or for the special case of label
      `ligand_features` (if `include_ligand_features` is True), a nested
      feature map of labels to numpy arrays.

    Raises:
      InvalidTemplateError: If hits haven't been filtered before featurization.
    """
    hits_by_pdb_id = {}
    for idx, hit in enumerate(self.hits):
      if not hit.is_valid:
        raise InvalidTemplateError(
            f'Hits must be filtered before featurizing, got unprocessed {hit=}'
        )
      hits_by_pdb_id.setdefault(hit.pdb_id, []).append((idx, hit))

    unsorted_features = []
    for struc in self.structures:
      # Assumes the lowercased structure name equals the hit's pdb_id.
      pdb_id = str(struc.name).lower()
      for idx, hit in hits_by_pdb_id[pdb_id]:
        try:
          label_chain_id = struc.polymer_auth_asym_id_to_label_asym_id()[
              hit.auth_chain_id
          ]
          hit_features = {
              **get_polymer_features(
                  chain=struc.filter(chain_id=label_chain_id),
                  chain_poly_type=hit.chain_poly_type,
                  query_sequence_length=len(hit.query_sequence),
                  query_to_hit_mapping=hit.query_to_hit_mapping,
              ),
          }
          if include_ligand_features:
            hit_features['ligand_features'] = _get_ligand_features(struc)
          unsorted_features.append((idx, hit_features))
        except Error as e:
          raise type(e)(f'Failed to featurise {hit=}') from e

    # Restore the original hit order.
    sorted_features = sorted(unsorted_features, key=lambda x: x[0])
    sorted_features = [feat for _, feat in sorted_features]
    return package_template_features(
        hit_features=sorted_features,
        include_ligand_features=include_ligand_features,
    )

  @property
  def structures(self) -> Iterator[structure.Structure]:
    """Yields template structures for each unique PDB ID among hits.

    If there are multiple hits in the same Structure, the Structure will be
    included only once by this method. Structures are yielded in the order in
    which their PDB IDs first appear among the hits.

    Yields:
      A Structure object for each unique PDB ID among hits.

    Raises:
      HitDateError: If template's release date exceeds max cutoff date. Raised
        when iteration starts, since this property returns a generator.
    """

    for hit in self.hits:
      if hit.release_date > self.release_date_cutoff:  # pylint: disable=comparison-with-callable
        raise HitDateError(
            f'Invalid release date for hit {hit.pdb_id=}, when release date '
            f'cutoff is {self.release_date_cutoff}.'
        )

    # Get the PDB IDs to load without duplicates. `dict.fromkeys` keeps the
    # first-occurrence order, unlike a set whose order depends on hashing.
    targets_to_load = tuple(dict.fromkeys(hit.pdb_id for hit in self.hits))

    for target_name in targets_to_load:
      yield structure.from_mmcif(
          mmcif_string=self._structure_store.get_mmcif_str(target_name),
          fix_mse_residues=True,
          fix_arginines=True,
          include_water=False,
          include_bonds=False,
          include_other=True,  # For non-standard polymer chains.
      )
def _parse_hit_description(description: str) -> tuple[str, str, int, int, int]:
  """Splits an hmmsearch A3M description line into its hit fields.

  Example lines (protein, nucleic, no description):
    >4pqx_A/2-217 [subseq from] mol:protein length:217  Free text
    >4pqx_A/2-217 [subseq from] mol:na length:217  Free text
    >5g3r_A/1-55 [subseq from] mol:protein length:352

  Args:
    description: The description line to parse.

  Returns:
    `(pdb_id, chain_id, start, end, length)` with the last three as ints.

  Raises:
    ValueError: If `description` does not fully match `_HIT_DESCRIPTION_REGEX`.
  """
  match = _HIT_DESCRIPTION_REGEX.fullmatch(description)
  if match is None:
    raise ValueError(f'Could not parse description "{description}"')
  pdb_id, chain_id = match.group('pdb_id', 'chain_id')
  start, end, length = (int(match[name]) for name in ('start', 'end', 'length'))
  return pdb_id, chain_id, start, end, length
def _parse_hit_metadata(
    structure_store: structure_stores.StructureStore,
    pdb_id: str,
    auth_chain_id: str,
) -> tuple[Any, str | None, Sequence[int] | None]:
  """Parses hit metadata by parsing the hit's mmCIF from the structure store.

  Args:
    structure_store: Store to fetch the mmCIF string from.
    pdb_id: PDB ID of the hit, used as the structure store key.
    auth_chain_id: Author chain ID of the hit chain.

  Returns:
    A tuple `(release_date, sequence, unresolved_res_ids)`: the value returned
    by `mmcif.get_release_date` (the caller parses it with
    `datetime.date.fromisoformat`), the chain's single-letter sequence including
    missing residues, and the ids of the chain's unresolved residues. If the
    mmCIF is not in the store, or it has no polymer chain with
    `auth_chain_id`, a warning is logged and `(None, None, None)` is returned so
    the caller can skip the hit.
  """
  try:
    cif = mmcif.from_string(structure_store.get_mmcif_str(pdb_id))
  except structure_stores.NotFoundError:
    logging.warning(
        'Failed to get mmCIF for %s (author chain %s).', pdb_id, auth_chain_id
    )
    return None, None, None
  release_date = mmcif.get_release_date(cif)

  try:
    struc = structure.from_parsed_mmcif(
        cif,
        model_id=structure.ModelID.ALL,
        include_water=True,
        include_other=True,
        include_bonds=False,
    )
  except ValueError:
    # Loading all models failed; fall back to the first model only.
    struc = structure.from_parsed_mmcif(
        cif,
        model_id=structure.ModelID.FIRST,
        include_water=True,
        include_other=True,
        include_bonds=False,
    )

  chain_sequences = struc.polymer_author_chain_single_letter_sequence(
      include_missing_residues=True,
      protein=True,
      dna=True,
      rna=True,
      other=True,
  )
  if auth_chain_id not in chain_sequences:
    # Handle a chain that is absent from the mmCIF like a missing mmCIF: skip
    # the hit instead of aborting the whole template search with a KeyError.
    logging.warning(
        'Author chain %s not found in mmCIF for %s.', auth_chain_id, pdb_id
    )
    return None, None, None
  sequence = chain_sequences[auth_chain_id]

  unresolved_res_ids = struc.filter(
      chain_auth_asym_id=auth_chain_id
  ).unresolved_residues.id

  return release_date, sequence, unresolved_res_ids
def get_polymer_features(
    *,
    chain: structure.Structure,
    chain_poly_type: str,
    query_sequence_length: int,
    query_to_hit_mapping: Mapping[int, int],
) -> Mapping[str, Any]:
  """Computes the template features of a single polymer chain.

  Args:
    chain: Template structure, already filtered to exactly one polymer chain.
    chain_poly_type: Polymer type of the chain (protein, DNA or RNA); selects
      the atom layout from `_POLYMERS`.
    query_sequence_length: Length of the query sequence; the leading dimension
      of every per-residue feature.
    query_to_hit_mapping: 0-based query index to hit residue index mapping.
      Query positions absent from the mapping stay zero (coordinates, masks)
      or '-' (sequence).

  Returns:
    A dictionary with `template_all_atom_positions` (float64),
    `template_all_atom_masks` (int64), `template_sequence` (bytes),
    `template_aatype` (int32), `template_domain_names` and
    `template_release_date` (object arrays holding bytes) for the chain.

  Raises:
    ValueError: If the structure has no name or release date, or it does not
      contain exactly one polymer chain.
  """
  if chain.name is None:
    raise ValueError('Template structure must have a name.')
  if chain.release_date is None:
    raise ValueError(
        f'Template structure {chain.name} must have a release date. You can do'
        ' this by setting "_pdbx_audit_revision_history.revision_date" in the'
        ' template mmCIF to a date in the ISO-8601 format (e.g. 1989-11-17).'
    )

  auth_to_label = chain.polymer_auth_asym_id_to_label_asym_id()
  if len(auth_to_label) != 1:
    raise ValueError(
        f'Template structure {chain.name} must be filtered to a single polymer'
        f' chain but got a structure with {len(auth_to_label)} polymer chains.'
    )
  [(auth_chain_id, label_chain_id)] = auth_to_label.items()

  chain_sequence = chain.chain_single_letter_sequence()[label_chain_id]
  polymer = _POLYMERS[chain_poly_type]
  res_arrays = chain.to_res_arrays(
      include_missing_residues=True, atom_order=polymer.atom_order
  )
  source_positions = res_arrays.atom_positions
  source_masks = res_arrays.atom_mask

  per_atom_shape = (query_sequence_length, polymer.num_atom_types)
  all_atom_positions = np.zeros((*per_atom_shape, 3), dtype=np.float64)
  # NOTE(review): int64 here, while `_POLYMER_FEATURES` lists float64 for the
  # masks (used for the no-hit case) - confirm which dtype is intended.
  all_atom_masks = np.zeros(per_atom_shape, dtype=np.int64)
  aligned_residues = ['-'] * query_sequence_length

  for query_idx, hit_idx in query_to_hit_mapping.items():
    all_atom_positions[query_idx] = source_positions[hit_idx]
    all_atom_masks[query_idx] = source_masks[hit_idx]
    aligned_residues[query_idx] = chain_sequence[hit_idx]

  template_sequence = ''.join(aligned_residues)
  aatype = _encode_restype(chain_poly_type, template_sequence)
  domain_name = f'{chain.name.lower()}_{auth_chain_id}'
  release_date_str = chain.release_date.strftime('%Y-%m-%d')

  return {
      'template_all_atom_positions': all_atom_positions,
      'template_all_atom_masks': all_atom_masks,
      'template_sequence': template_sequence.encode(),
      'template_aatype': np.array(aatype, dtype=np.int32),
      'template_domain_names': np.array(domain_name.encode(), dtype=object),
      'template_release_date': np.array(
          release_date_str.encode(), dtype=object
      ),
  }
def _get_ligand_features(
    struc: structure.Structure,
) -> Mapping[str, Mapping[str, np.ndarray | bytes]]:
  """Collects per-chain ligand atom features from a template structure.

  Args:
    struc: Template structure; only its ligand entities are used.

  Returns:
    Mapping from ligand chain id to a dict with `ligand_atom_positions`
    (float32), `ligand_atom_names` (object), `ligand_atom_occupancies`
    (float32) and `ccd_id` (bytes, the residue name of the chain's first
    atom). Chains without atoms are omitted.
  """
  ligands = struc.filter_to_entity_type(ligand=True)
  assert ligands.coords is not None
  assert ligands.atom_name is not None
  assert ligands.atom_occupancy is not None

  features_by_chain = {}
  for chain_id in ligands.chains:
    (atom_indices,) = np.where(ligands.chain_id == chain_id)
    if not atom_indices.shape[0]:
      continue
    positions = ligands.coords[atom_indices, :].astype(np.float32)
    names = ligands.atom_name[atom_indices].astype(object)
    occupancies = ligands.atom_occupancy[atom_indices].astype(np.float32)
    features_by_chain[chain_id] = {
        'ligand_atom_positions': positions,
        'ligand_atom_names': names,
        'ligand_atom_occupancies': occupancies,
        'ccd_id': ligands.res_name[atom_indices][0].encode(),
    }
  return features_by_chain
def package_template_features(
    *,
    hit_features: Sequence[Mapping[str, Any]],
    include_ligand_features: bool,
) -> Mapping[str, Any]:
  """Stacks polymer features, adds empty and keeps ligand features unstacked.

  Polymer features (keys of `_POLYMER_FEATURES`) are stacked along a new
  leading hit axis. With no hits, each becomes an empty 1-D array of the dtype
  listed in `_POLYMER_FEATURES`. Ligand features, when requested, stay as a
  list with one entry per hit. Keys are emitted in a fixed order (polymer
  features, then ligand features) rather than in hash-dependent set order.

  Args:
    hit_features: One feature mapping per hit, in hit order.
    include_ligand_features: Whether to include the `_LIGAND_FEATURES` keys.

  Returns:
    The packaged feature mapping.

  Raises:
    KeyError: If a hit's mapping lacks one of the requested features.
  """
  feature_names = list(_POLYMER_FEATURES)
  if include_ligand_features:
    feature_names.extend(
        name for name in _LIGAND_FEATURES if name not in _POLYMER_FEATURES
    )

  packaged = {}
  for name in feature_names:
    values = [single_hit[name] for single_hit in hit_features]
    if name not in _POLYMER_FEATURES:
      packaged[name] = values
    elif values:
      packaged[name] = np.stack(values, axis=0)
    else:
      packaged[name] = np.array([], dtype=_POLYMER_FEATURES[name])
  return packaged
def _resolve_path(path: os.PathLike[str] | str) -> str:
  """Returns a usable filesystem path for `path`.

  If `resources.filename` maps `path` to an existing location (a data
  dependency baked into the binary), that location is returned. Otherwise, for
  example for a local path, `str(path)` is returned.
  """
  data_dep_path = resources.filename(path)
  return data_dep_path if os.path.exists(data_dep_path) else str(path)
def run_hmmsearch_with_a3m(
    *,
    database_path: os.PathLike[str] | str,
    hmmsearch_config: msa_config.HmmsearchConfig,
    max_a3m_query_sequences: int | None,
    a3m: str | None,
) -> str:
  """Searches the template database with a profile built from `a3m`.

  Args:
    database_path: Sequence database to search; resolved via `_resolve_path`.
    hmmsearch_config: Binaries, thresholds and alphabet for the search.
    max_a3m_query_sequences: Forwarded to `parsers.convert_a3m_to_stockholm`
      when converting `a3m`.
    a3m: Input alignment used to build the search profile.

  Returns:
    The string returned by `Hmmsearch.query_with_sto` (the hits as A3M).
  """
  cfg = hmmsearch_config
  searcher = hmmsearch.Hmmsearch(
      binary_path=cfg.hmmsearch_binary_path,
      hmmbuild_binary_path=cfg.hmmbuild_binary_path,
      database_path=_resolve_path(database_path),
      e_value=cfg.e_value,
      inc_e=cfg.inc_e,
      dom_e=cfg.dom_e,
      incdom_e=cfg.incdom_e,
      alphabet=cfg.alphabet,
      filter_f1=cfg.filter_f1,
      filter_f2=cfg.filter_f2,
      filter_f3=cfg.filter_f3,
      filter_max=cfg.filter_max,
  )
  # Stockholm format lets the query's non-gap columns be annotated as
  # reference columns, which the 'hand' model construction then relies on.
  stockholm_msa = parsers.convert_a3m_to_stockholm(a3m, max_a3m_query_sequences)
  return searcher.query_with_sto(stockholm_msa, model_construction='hand')
================================================
FILE: src/alphafold3/data/tools/hmmalign.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""A Python wrapper for hmmalign from the HMMER Suite."""
from collections.abc import Mapping, Sequence
import os
import tempfile
from alphafold3.data import parsers
from alphafold3.data.tools import subprocess_utils
def _to_a3m(sequences: Sequence[str], name_prefix: str = 'sequence') -> str:
a3m = ''
for i, sequence in enumerate(sequences, 1):
a3m += f'> {name_prefix} {i}\n{sequence}\n'
return a3m
class Hmmalign:
"""Python wrapper of the hmmalign binary."""
def __init__(self, binary_path: str):
  """Creates a Python wrapper around the hmmalign executable.

  Args:
    binary_path: Location of the hmmalign binary.

  Raises:
    RuntimeError: If hmmalign binary not found within the path.
  """
  self._binary_path = binary_path
  # Fail fast at construction time rather than on the first alignment.
  subprocess_utils.check_binary_exists(path=self._binary_path, name='hmmalign')
def align_sequences(
    self,
    sequences: Sequence[str],
    profile: str,
    extra_flags: Mapping[str, str] | None = None,
) -> str:
  """Aligns plain sequences to an HMM profile.

  The sequences are wrapped into an A3M string with records named
  `query 1`, `query 2`, ... and handed to `align`.

  Args:
    sequences: Sequences to align.
    profile: HMM profile text to align against.
    extra_flags: Extra hmmalign flags, forwarded to `align`.

  Returns:
    The alignment in A3M format.
  """
  query_a3m = _to_a3m(sequences, name_prefix='query')
  return self.align(a3m_str=query_a3m, profile=profile, extra_flags=extra_flags)
def align(
self,
a3m_str: str,
profile: str,
extra_flags: Mapping[str, str] | None = None,
) -> str:
"""Aligns sequences in A3M to the profile and returns the alignment in A3M.
Args:
a3m_str: A list of sequence strings.
profile: A hmm file with the hmm profile to align the sequences to.
extra_flags: Dictionary with extra flags, flag_name: flag_value, that are
added to hmmalign.
Returns:
An A3M string with the aligned sequences.
Raises:
RuntimeError: If hmmalign fails.
"""
with tempfile.TemporaryDirectory() as query_tmp_dir:
input_profile = os.path.join(query_tmp_dir, 'profile.hmm')
input_sequences = os.path.join(query_tmp_dir, 'sequences.a3m')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with open(input_profile, 'w') as f:
f.write(profile)
with open(input_sequences, 'w') as f:
f.write(a3m_str)
cmd = [
self._binary_path,
*('-o', output_a3m_path),
*('--outformat', 'A2M'), # A2M is A3M in the HMMER suite.
]
if extra_flags:
for flag_name, flag_value in extra_flags.items():
cmd.extend([flag_name, flag_value])
cmd.extend([input_profile, input_sequences])
subprocess_utils.run(
cmd=cmd,
cmd_name='hmmalign',
log_stdout=False,
log_stderr=True,
log_on_process_error=True,
)
with open(output_a3m_path, encoding='utf-8') as f:
a3m = f.read()
return a3m
def align_sequences_to_profile(self, profile: str, sequences_a3m: str) -> str:
"""Aligns the sequences to profile and returns the alignment in A3M string.
Uses hmmalign to align the sequences to the profile, then ouputs the
sequence contatenated at the beginning of the sequences in the A3M format.
As the sequences are represented by an alignment with possible gaps ('-')
and insertions (lowercase characters), the method first removes the gaps,
then uppercases the insertions to prepare the sequences for realignment.
Sequences with gaps cannot be aligned, as '-'s are not a valid symbol to
align; lowercase characters must be uppercased to preserve the original
sequences before realignment.
Args:
profile: The Hmmbuild profile to align the sequences to.
sequences_a3m: Sequences in A3M format to align to the profile.
Returns:
An A3M string with the aligned sequences.
Raises:
RuntimeError: If hmmalign fails.
"""
deletion_table = str.maketrans('', '', '-')
sequences_no_gaps_a3m = []
for seq, desc in parsers.lazy_parse_fasta_string(sequences_a3m):
sequences_no_gaps_a3m.append(f'>{desc}')
sequences_no_gaps_a3m.append(seq.translate(deletion_table))
sequences_no_gaps_a3m = '\n'.join(sequences_no_gaps_a3m)
aligned_sequences = self.align(sequences_no_gaps_a3m, profile)
return aligned_sequences
================================================
FILE: src/alphafold3/data/tools/hmmbuild.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import os
import re
import tempfile
from typing import Literal
from alphafold3.data import parsers
from alphafold3.data.tools import subprocess_utils
class Hmmbuild(object):
"""Python wrapper of the hmmbuild binary."""
def __init__(
self,
*,
binary_path: str,
singlemx: bool = False,
alphabet: str | None = None,
):
"""Initializes the Python hmmbuild wrapper.
Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.
alphabet: The alphabet to assert when building a profile. Useful when
hmmbuild cannot guess the alphabet. If None, no alphabet is asserted.
Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self._binary_path = binary_path
self._singlemx = singlemx
self._alphabet = alphabet
subprocess_utils.check_binary_exists(
path=self._binary_path, name='hmmbuild'
)
def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
return self._build_profile(
sto, informat='stockholm', model_construction=model_construction
)
def build_profile_from_a3m(self, a3m: str) -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
a3m: A string with the aligned sequences in the A3M format.
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
lines = []
for sequence, description in parsers.lazy_parse_fasta_string(a3m):
sequence = re.sub('[a-z]+', '', sequence) # Remove inserted residues.
lines.append(f'>{description}\n{sequence}\n')
msa = ''.join(lines)
return self._build_profile(msa, informat='afa')
def _build_profile(
self,
msa: str,
informat: Literal['afa', 'stockholm'],
model_construction: str = 'fast',
) -> str:
"""Builds a HMM for the aligned sequences given as an MSA string.
Args:
msa: A string with the aligned sequences, in A3M or STO format.
informat: One of 'afa' (aligned FASTA) or 'sto' (Stockholm).
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if model_construction not in {'hand', 'fast'}:
raise ValueError(f'Bad {model_construction=}. Only hand or fast allowed.')
with tempfile.TemporaryDirectory() as query_tmp_dir:
input_msa_path = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
with open(input_msa_path, 'w') as f:
f.write(msa)
# Specify the format as we don't specify the input file extension. See
# https://github.com/EddyRivasLab/hmmer/issues/321 for more details.
cmd_flags = ['--informat', informat]
# If adding flags, we have to do so before the output and input:
if model_construction == 'hand':
cmd_flags.append(f'--{model_construction}')
if self._singlemx:
cmd_flags.append('--singlemx')
if self._alphabet:
cmd_flags.append(f'--{self._alphabet}')
cmd_flags.extend([output_hmm_path, input_msa_path])
cmd = [self._binary_path, *cmd_flags]
subprocess_utils.run(
cmd=cmd,
cmd_name='Hmmbuild',
log_stdout=False,
log_stderr=True,
log_on_process_error=True,
)
with open(output_hmm_path) as f:
hmm = f.read()
return hmm
================================================
FILE: src/alphafold3/data/tools/hmmsearch.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import tempfile
from absl import logging
from alphafold3.data import parsers
from alphafold3.data.tools import hmmbuild
from alphafold3.data.tools import subprocess_utils
class Hmmsearch(object):
  """Python wrapper of the hmmsearch binary."""

  def __init__(
      self,
      *,
      binary_path: str,
      hmmbuild_binary_path: str,
      database_path: str,
      alphabet: str = 'amino',
      filter_f1: float | None = None,
      filter_f2: float | None = None,
      filter_f3: float | None = None,
      e_value: float | None = None,
      inc_e: float | None = None,
      dom_e: float | None = None,
      incdom_e: float | None = None,
      filter_max: bool = False,
      n_cpu: int = 8,
  ):
    """Initializes the Python hmmsearch wrapper.

    Args:
      binary_path: The path to the hmmsearch executable.
      hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
        an hmm from an input a3m.
      database_path: The path to the hmmsearch database (FASTA format).
      alphabet: Chain type e.g. amino, rna, dna.
      filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
      filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
      filter_f3: Forward pre-filter, set to >1.0 to turn off.
      e_value: E-value criteria for inclusion in tblout.
      inc_e: E-value criteria for inclusion in MSA/next round.
      dom_e: Domain e-value criteria for inclusion in tblout.
      incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
        round.
      filter_max: Remove all filters, will ignore all filter_f* settings.
      n_cpu: The number of CPUs to give hmmsearch (its --cpu flag). Defaults
        to 8.

    Raises:
      RuntimeError: If the hmmsearch or hmmbuild binary is not found within
        the path.
      ValueError: If the hmmsearch database does not exist.
    """
    self._binary_path = binary_path
    self._hmmbuild_runner = hmmbuild.Hmmbuild(
        alphabet=alphabet, binary_path=hmmbuild_binary_path
    )
    self._database_path = database_path
    self._n_cpu = n_cpu

    flags = []
    if filter_max:
      # --max turns off all filters, so the filter_f* values are ignored.
      flags.append('--max')
    else:
      if filter_f1 is not None:
        flags.extend(('--F1', filter_f1))
      if filter_f2 is not None:
        flags.extend(('--F2', filter_f2))
      if filter_f3 is not None:
        flags.extend(('--F3', filter_f3))

    if e_value is not None:
      flags.extend(('-E', e_value))
    if inc_e is not None:
      flags.extend(('--incE', inc_e))
    if dom_e is not None:
      flags.extend(('--domE', dom_e))
    if incdom_e is not None:
      flags.extend(('--incdomE', incdom_e))

    self._flags = tuple(map(str, flags))

    subprocess_utils.check_binary_exists(
        path=self._binary_path, name='hmmsearch'
    )

    if not os.path.exists(self._database_path):
      logging.error('Could not find hmmsearch database %s', database_path)
      raise ValueError(f'Could not find hmmsearch database {database_path}')

  def query_with_hmm(self, hmm: str) -> str:
    """Queries the database with hmmsearch using the given HMM profile.

    Args:
      hmm: The HMM profile, as the text contents of an .hmm file.

    Returns:
      The hits as an A3M string, converted from hmmsearch's Stockholm output
      without removing gaps from the first row.
    """
    with tempfile.TemporaryDirectory() as query_tmp_dir:
      hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
      sto_out_path = os.path.join(query_tmp_dir, 'output.sto')
      with open(hmm_input_path, 'w') as f:
        f.write(hmm)

      cmd = [
          self._binary_path,
          '--noali',  # Don't include the alignment in stdout.
          *('--cpu', str(self._n_cpu)),
      ]
      # If adding flags, we have to do so before the output and input:
      if self._flags:
        cmd.extend(self._flags)
      cmd.extend([
          *('-A', sto_out_path),
          hmm_input_path,
          self._database_path,
      ])

      subprocess_utils.run(
          cmd=cmd,
          cmd_name=f'Hmmsearch ({os.path.basename(self._database_path)})',
          log_stdout=False,
          log_stderr=True,
          log_on_process_error=True,
      )

      with open(sto_out_path) as f:
        a3m_out = parsers.convert_stockholm_to_a3m(
            f, remove_first_row_gaps=False, linewidth=60
        )

    return a3m_out

  def query_with_a3m(self, a3m_in: str) -> str:
    """Queries the database with hmmsearch using a profile built from an A3M."""
    # Only the "fast" model construction makes sense with A3M, as it doesn't
    # have any way to annotate reference columns.
    hmm = self._hmmbuild_runner.build_profile_from_a3m(a3m_in)
    return self.query_with_hmm(hmm)

  def query_with_sto(
      self, msa_sto: str, model_construction: str = 'fast'
  ) -> str:
    """Queries the database with hmmsearch using a profile built from Stockholm.

    Args:
      msa_sto: The MSA in Stockholm format.
      model_construction: 'hand' to use the MSA's reference annotation to pick
        consensus columns, or 'fast' (default).

    Returns:
      The hits as an A3M string.
    """
    hmm = self._hmmbuild_runner.build_profile_from_sto(
        msa_sto, model_construction=model_construction
    )
    return self.query_with_hmm(hmm)
================================================
FILE: src/alphafold3/data/tools/jackhmmer.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Library to run Jackhmmer from Python."""
from collections.abc import Iterable, Sequence
from concurrent import futures
import heapq
import os
import pathlib
import shutil
import tempfile
import time
from absl import logging
from alphafold3.data import parsers
from alphafold3.data.tools import msa_tool
from alphafold3.data.tools import shards
from alphafold3.data.tools import subprocess_utils
class Jackhmmer(msa_tool.MsaTool):
  """Python wrapper of the Jackhmmer binary."""

  def __init__(
      self,
      *,
      binary_path: str,
      database_path: str,
      n_cpu: int = 8,
      n_iter: int = 3,
      e_value: float | None = 1e-3,
      z_value: float | int | None = None,
      dom_e: float | None = None,
      dom_z_value: float | int | None = None,
      max_sequences: int = 5000,
      filter_f1: float = 5e-4,
      filter_f2: float = 5e-5,
      filter_f3: float = 5e-7,
      max_threads: int | None = None,
      **unused_kwargs,
  ):
    """Initializes the Python Jackhmmer wrapper.

    NOTE: The MSA obtained by running against sharded dbs won't be always
    exactly the same as the MSA obtained by running against an unsharded db.
    This is because of Jackhmmer deduplication logic, which won't spot duplicate
    hits across multiple shards. Usually this means that the sharded search
    finds more hits (likely bounded by the number of shards), but this should
    not pose an issue given how the results are used downstream. The problem is
    more pronounced with deep MSAs and lower in the hit list (higher e-values).

    Make sure to set the Z and domZ values when searching against a sharded
    database, otherwise the results won't match the normal unsharded search.

    Args:
      binary_path: The path to the jackhmmer executable.
      database_path: The path to the jackhmmer database (FASTA format). Sharded
        file specs, e.g. `@`, are supported.
      n_cpu: The number of CPUs to give Jackhmmer.
      n_iter: The number of Jackhmmer iterations. Must be 1 for a sharded db.
      e_value: The E-value, see Jackhmmer docs for more details. Also used as
        the inclusion E-value (--incE). If None, neither flag is passed.
      z_value: The Z-value representing the number of comparisons done (i.e.
        correct database size) for E-value calculation. Make sure to set this
        when searching against a sharded database, otherwise the e-values will
        be incorrectly scaled.
      dom_e: Domain e-value criteria for inclusion in tblout.
      dom_z_value: Domain z-value representing the number of comparisons done
        (i.e. correct database size) for domain E-value calculation. Make sure
        to set this when searching against a sharded database, otherwise the
        domain e-values will be incorrectly scaled.
      max_sequences: Maximum number of sequences to return in the MSA.
      filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
      filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
      filter_f3: Forward pre-filter, set to >1.0 to turn off.
      max_threads: If given, the maximum number of threads used when running
        sharded databases.

    Raises:
      RuntimeError: If Jackhmmer binary not found within the path.
      ValueError: If an invalid configuration is provided in the args.
    """
    self._database_path = database_path
    if shard_paths := shards.get_sharded_paths(self._database_path):
      if n_iter != 1:
        raise ValueError('For a sharded db, only n_iter=1 is supported.')
      if z_value is None:
        raise ValueError(
            'The Z-value must be set when searching against a sharded database '
            'to correctly scale e-values.'
        )
      if max_sequences <= 1:
        raise ValueError(
            'max_sequences must be greater than 1 when running in sharded '
            'mode, because each shard would return only the query sequence.'
        )
      self._shard_paths = shard_paths
      self._max_threads = len(self._shard_paths)
      if max_threads is not None:
        self._max_threads = min(max_threads, self._max_threads)
      logging.info('Jackhmmer running with max_threads = %d', self._max_threads)
    else:
      self._shard_paths = None
      self._max_threads = None

    self._binary_path = binary_path
    subprocess_utils.check_binary_exists(
        path=self._binary_path, name='Jackhmmer'
    )

    self._n_cpu = n_cpu
    self._n_iter = n_iter
    self._e_value = e_value
    self._z_value = z_value
    self._dom_e = dom_e
    self._dom_z_value = dom_z_value
    self._max_sequences = max_sequences
    self._filter_f1 = filter_f1
    self._filter_f2 = filter_f2
    self._filter_f3 = filter_f3

    # If Jackhmmer supports the --seq_limit flag (via our patch), use it to
    # prevent writing out redundant sequences and increasing peak memory usage.
    # If not, the Jackhmmer will be run without the --seq_limit flag.
    self._supports_seq_limit = subprocess_utils.jackhmmer_seq_limit_supported(
        self._binary_path
    )

  def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
    """Query the database (sharded or unsharded) using Jackhmmer."""
    # Long queries are abbreviated in the logs.
    query_for_log = (
        target_sequence
        if len(target_sequence) <= 16
        else f'{target_sequence[:16]}... (len {len(target_sequence)})'
    )

    if not self._shard_paths:
      # Non-sharded case, run the query against the whole database.
      logging.info('Query sequence (non-sharded db): %s', query_for_log)
      return self._query_db_shard(
          target_sequence=target_sequence,
          db_shard_path=self._database_path,
          get_tblout=False,
      )

    # Sharded case, run the query against each database shard in parallel.
    logging.info('Query sequence (sharded db): %s', query_for_log)
    # Per-shard scratch directories are created under one shared directory.
    # The context manager removes it even if a shard query raises; cleanup
    # errors are ignored (best-effort removal).
    with tempfile.TemporaryDirectory(
        ignore_cleanup_errors=True
    ) as global_temp_dir:

      def _query_shard_fn(
          shard_path: str,
      ) -> tuple[msa_tool.MsaToolResult, float]:
        t_start = time.time()
        result = self._query_db_shard(
            target_sequence=target_sequence,
            db_shard_path=shard_path,
            get_tblout=True,  # Tblout contains e-values needed for merging.
            global_temp_dir=global_temp_dir,
        )
        return result, time.time() - t_start

      with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
        tool_outputs, timings = zip(
            *ex.map(_query_shard_fn, self._shard_paths)
        )

    logging.info(
        'Finished query for %d shards, shard timings (seconds): %s',
        len(tool_outputs),
        ', '.join(f'{t:.1f}' for t in timings),
    )
    return _merge_jackhmmer_results(tool_outputs, self._max_sequences)

  def _query_db_shard(
      self,
      *,
      target_sequence: str,
      db_shard_path: str,
      get_tblout: bool,
      global_temp_dir: str | None = None,
  ) -> msa_tool.MsaToolResult:
    """Query the database shard using Jackhmmer.

    Args:
      target_sequence: The query sequence.
      db_shard_path: Path of the database (or database shard) to search.
      get_tblout: Whether to also collect Jackhmmer's --tblout output.
      global_temp_dir: Directory in which to create the scratch directory. If
        None, the default temporary directory is used.

    Returns:
      The A3M (limited to max_sequences) and, if requested, the tblout text.
    """
    with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir:
      input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta')
      subprocess_utils.create_query_fasta_file(
          sequence=target_sequence, path=input_fasta_path
      )

      output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
      pathlib.Path(output_sto_path).touch()

      # The F1/F2/F3 are the expected proportion to pass each of the filtering
      # stages (which get progressively more expensive), reducing these
      # speeds up the pipeline at the expense of sensitivity. They are
      # currently set very low to make querying Mgnify run in a reasonable
      # amount of time.
      cmd_flags = [
          *('-o', '/dev/null'),  # Don't pollute stdout with Jackhmmer output.
          *('-A', output_sto_path),
          '--noali',
          *('--F1', str(self._filter_f1)),
          *('--F2', str(self._filter_f2)),
          *('--F3', str(self._filter_f3)),
          *('--cpu', str(self._n_cpu)),
          *('-N', str(self._n_iter)),
      ]

      if get_tblout:
        output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt')
        output_tblout_path.touch()
        cmd_flags.extend(['--tblout', str(output_tblout_path)])
      else:
        output_tblout_path = None

      # Report only sequences with E-values <= x in per-sequence output.
      if self._e_value is not None:
        cmd_flags.extend(['-E', str(self._e_value)])
        # Use the same value as the reporting e-value (`-E` flag).
        cmd_flags.extend(['--incE', str(self._e_value)])

      if self._z_value is not None:
        cmd_flags.extend(['-Z', str(self._z_value)])

      if self._dom_z_value is not None:
        cmd_flags.extend(['--domZ', str(self._dom_z_value)])

      if self._dom_e is not None:
        cmd_flags.extend(['--domE', str(self._dom_e)])

      if self._max_sequences is not None and self._supports_seq_limit:
        cmd_flags.extend(['--seq_limit', str(self._max_sequences)])

      # The input FASTA and the input db are the last two arguments.
      cmd = [self._binary_path] + cmd_flags + [input_fasta_path, db_shard_path]

      subprocess_utils.run(
          cmd=cmd,
          cmd_name=f'Jackhmmer ({os.path.basename(db_shard_path)})',
          log_stdout=False,
          log_stderr=True,
          log_on_process_error=True,
      )

      with open(output_sto_path) as f:
        a3m = parsers.convert_stockholm_to_a3m(
            f, max_sequences=self._max_sequences
        )

      # Get the tabular output which has e.g. e-value for each target.
      tbl = '' if output_tblout_path is None else output_tblout_path.read_text()

    return msa_tool.MsaToolResult(
        target_sequence=target_sequence,
        a3m=a3m,
        e_value=self._e_value,
        tblout=tbl,
    )
def _merge_jackhmmer_results(
    jh_results: Sequence[msa_tool.MsaToolResult], max_sequences: int
) -> msa_tool.MsaToolResult:
  """Merges per-shard Jackhmmer results into a single result.

  Hits from all shards are merged in order of increasing e-value, then
  decreasing bit score, then name, using the scores in each shard's tblout.
  Hits without a tblout entry are dropped. The merged A3M starts with the query
  and holds at most `max_sequences` sequences in total.

  Args:
    jh_results: Results of querying each shard. They must all share the same
      target sequence and e-value, and each must carry a tblout.
    max_sequences: Maximum number of sequences in the merged A3M, including
      the query.

  Returns:
    The merged result. Its tblout is None, as it is no longer needed.
  """
  assert len(set(jh_res.target_sequence for jh_res in jh_results)) == 1
  assert len(set(jh_res.e_value for jh_res in jh_results)) == 1

  # Parse the TBL output, create a mapping from hit name to TBL line.
  parsed_tbl = {}
  for jh_result in jh_results:
    assert jh_result.tblout is not None
    for line in jh_result.tblout.splitlines():
      if not line.startswith('#'):
        parsed_tbl[line.partition(' ')[0]] = line

  # Create an iterator and merge a3m info with tbl info.
  def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
    # Don't parse the entire a3m, lazily parse only as many sequences as needed.
    iterator = iter(parsers.lazy_parse_fasta_string(a3m))
    # Skip the query which isn't present in tblout. The default value avoids a
    # StopIteration inside this generator (a RuntimeError under PEP 479) when a
    # shard returned an empty A3M.
    next(iterator, None)
    for sequence, description in iterator:
      name = description.partition(' ')[0].partition('/')[0]
      # Skip sequences for which we don't have tbl information.
      if tbl_info := parsed_tbl.get(name):
        yield sequence, description, tbl_info, name

  def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, float, str]:
    unused_seq, unused_description, tbl_info, name = seq_data
    # Tblout lines have 19 whitespace delimited columns. "-" used if no value
    # present. We want e-value (column 5) and bit score (column 6), so do only 6
    # splits. E-value and bit score are equivalent, but bit score might have
    # higher resolution. Use the name in case of a tie.
    e_value, bit_score = tbl_info.split(maxsplit=6)[4:6]
    return float(e_value), -float(bit_score), name

  # A3M/TBL is sorted by e-value and name, hence we can merge them efficiently.
  merged_a3m_and_tblout = heapq.merge(
      *[_merged_a3m_tbl_iter(res.a3m) for res in jh_results],
      key=sort_key,
  )

  # Truncate the a3m to max_sequences (query included), but count every hit.
  num_hits = 0
  merged_a3m = [f'>query\n{jh_results[0].target_sequence}']
  for seq, description, _, _ in merged_a3m_and_tblout:
    num_hits += 1
    if len(merged_a3m) < max_sequences:
      merged_a3m.append(f'>{description}\n{seq}')

  num_sequences = num_hits + 1  # Hits plus the query.
  if num_sequences > max_sequences:
    logging.info(
        'Limiting merged MSA depth from %d to %d',
        num_sequences,
        max_sequences,
    )

  return msa_tool.MsaToolResult(
      target_sequence=jh_results[0].target_sequence,
      a3m='\n'.join(merged_a3m),
      e_value=jh_results[0].e_value,
      tblout=None,  # We no longer need the tblout.
  )
================================================
FILE: src/alphafold3/data/tools/msa_tool.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Defines protocol for MSA tools."""
import dataclasses
from typing import Protocol
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class MsaToolResult:
"""The result of a MSA tool query.
Attributes:
target_sequence: The sequence that was used to query the MSA tool.
e_value: The e-value that was used to filter the MSA tool results.
a3m: The MSA output of the tool in the A3M format.
tblout: The optional tblout output of the MSA tool (needed for merging
results of queries against a sharded database).
"""
target_sequence: str
e_value: float
a3m: str
tblout: str | None = None
class MsaTool(Protocol):
  """Structural interface shared by the MSA search tools."""

  def query(self, target_sequence: str) -> MsaToolResult:
    """Runs the MSA search for `target_sequence` and returns its result.

    Args:
      target_sequence: The query sequence.

    Returns:
      The tool's result for this query.
    """
    ...
================================================
FILE: src/alphafold3/data/tools/nhmmer.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Library to run Nhmmer from Python."""
from collections.abc import Iterable, Sequence
from concurrent import futures
import heapq
import os
import pathlib
import shutil
import tempfile
import time
from typing import Final
from absl import logging
from alphafold3.data import parsers
from alphafold3.data.tools import hmmalign
from alphafold3.data.tools import hmmbuild
from alphafold3.data.tools import msa_tool
from alphafold3.data.tools import shards
from alphafold3.data.tools import subprocess_utils
# Queries shorter than this many residues use a relaxed nhmmer F3 filter (0.02)
# when the alphabet is 'rna', as recommended by RNAcentral for short sequences.
_SHORT_SEQUENCE_CUTOFF: Final[int] = 50
class Nhmmer(msa_tool.MsaTool):
"""Python wrapper of the Nhmmer binary."""
def __init__(
    self,
    binary_path: str,
    hmmalign_binary_path: str,
    hmmbuild_binary_path: str,
    database_path: str,
    n_cpu: int = 8,
    e_value: float = 1e-3,
    z_value: float | int | None = None,
    max_sequences: int = 5000,
    filter_f3: float = 1e-5,
    alphabet: str | None = None,
    strand: str | None = None,
    max_threads: int | None = None,
):
  """Initializes the Python Nhmmer wrapper.

  NOTE: The MSA obtained by running against sharded dbs won't be always
  exactly the same as the MSA obtained by running against an unsharded db.
  This is because of Nhmmer deduplication logic, which won't spot duplicate
  hits across multiple shards. Usually this means that the sharded search
  finds more hits (likely bounded by the number of shards), but this should
  not pose an issue given how the results are used downstream. The problem is
  more pronounced with deep MSAs and lower in the hit list (higher e-values).

  Make sure to set the Z value when searching against a sharded database,
  otherwise the results won't match the normal unsharded search.

  Args:
    binary_path: Path to the Nhmmer binary.
    hmmalign_binary_path: Path to the Hmmalign binary.
    hmmbuild_binary_path: Path to the Hmmbuild binary.
    database_path: MSA database path to search against. This can be either a
      FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb
      binary. The HMMERDB is ~10x faster but experimental. Sharded file
      specs, e.g. @, are supported, but not for HMMERDB databases.
    n_cpu: The number of CPUs to give Nhmmer.
    e_value: The E-value, see Nhmmer docs for more details.
    z_value: The Z-value representing the number of comparisons done (i.e.
      correct database size) for E-value calculation. Make sure to set this
      when searching against a sharded database, otherwise the e-values will
      be incorrectly scaled.
    max_sequences: Maximum number of sequences to return in the MSA.
    filter_f3: Forward pre-filter, set to >1.0 to turn off.
    alphabet: The alphabet to assert when building a profile with hmmbuild.
      This must be 'rna', 'dna', or None.
    strand: "watson" searches query sequence, "crick" searches
      reverse-complement and default is None which means searching for both.
    max_threads: If given, the maximum number of threads used when running
      sharded databases.

  Raises:
    RuntimeError: If Nhmmer binary not found within the path.
    ValueError: If an invalid configuration is provided in the args.
  """
  self._database_path = database_path
  if shard_paths := shards.get_sharded_paths(self._database_path):
    if z_value is None:
      raise ValueError(
          'The Z-value must be set when searching against a sharded database '
          'to correctly scale e-values.'
      )
    if 'hmmerdb' in self._database_path:
      raise ValueError('HMMERDB is not supported in sharded mode.')
    if max_sequences <= 1:
      raise ValueError(
          'max_sequences must be greater than 1 when running in sharded '
          'mode, because each shard would return only the query sequence.'
      )
    self._shard_paths = shard_paths
    self._max_threads = len(self._shard_paths)
    if max_threads is not None:
      self._max_threads = min(max_threads, self._max_threads)
    logging.info('Nhmmer running with max_threads = %d', self._max_threads)
  else:
    self._shard_paths = None
    self._max_threads = None

  self._binary_path = binary_path
  self._hmmalign_binary_path = hmmalign_binary_path
  self._hmmbuild_binary_path = hmmbuild_binary_path
  subprocess_utils.check_binary_exists(path=self._binary_path, name='Nhmmer')

  if strand and strand not in {'watson', 'crick'}:
    raise ValueError(f'Invalid {strand=}. only "watson" or "crick" supported')

  if alphabet and alphabet not in {'rna', 'dna'}:
    raise ValueError(f'Invalid {alphabet=}, only "rna" or "dna" supported')

  self._e_value = e_value
  self._n_cpu = n_cpu
  self._z_value = z_value
  self._max_sequences = max_sequences
  self._filter_f3 = filter_f3
  self._alphabet = alphabet
  self._strand = strand
def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
  """Queries the database (sharded or unsharded) using Nhmmer.

  Args:
    target_sequence: The query sequence.

  Returns:
    The search result. For a sharded database, the per-shard results are
    combined with `_merge_nhmmer_results`.
  """
  # Long queries are abbreviated in the logs.
  query_for_log = (
      target_sequence
      if len(target_sequence) <= 16
      else f'{target_sequence[:16]}... (len {len(target_sequence)})'
  )

  if not self._shard_paths:
    # Non-sharded case, run the query against the whole database.
    logging.info('Query sequence (non-sharded db): %s', query_for_log)
    return self._query_db_shard(
        target_sequence=target_sequence,
        db_shard_path=self._database_path,
        get_tblout=False,
    )

  # Sharded case, run the query against each database shard in parallel.
  logging.info('Query sequence (sharded db): %s', query_for_log)
  # Per-shard scratch directories are created under one shared directory.
  # The context manager removes it even if a shard query raises; cleanup
  # errors are ignored (best-effort removal).
  with tempfile.TemporaryDirectory(
      ignore_cleanup_errors=True
  ) as global_temp_dir:

    def _query_shard_fn(
        shard_path: str,
    ) -> tuple[msa_tool.MsaToolResult, float]:
      t_start = time.time()
      result = self._query_db_shard(
          target_sequence=target_sequence,
          db_shard_path=shard_path,
          get_tblout=True,  # Tblout contains e-values needed for merging.
          global_temp_dir=global_temp_dir,
      )
      return result, time.time() - t_start

    with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
      tool_outputs, timings = zip(
          *ex.map(_query_shard_fn, self._shard_paths)
      )

  logging.info(
      'Finished query for %d shards, shard timings (seconds): %s',
      len(tool_outputs),
      ', '.join(f'{t:.1f}' for t in timings),
  )
  return _merge_nhmmer_results(tool_outputs, self._max_sequences)
def _query_db_shard(
    self,
    *,
    target_sequence: str,
    db_shard_path: str,
    get_tblout: bool,
    global_temp_dir: str | None = None,
) -> msa_tool.MsaToolResult:
  """Query the database shard using Nhmmer.

  Runs Nhmmer in a scratch directory (created under `global_temp_dir` when
  given). The Stockholm hits are converted to A3M, keeping at most
  `self._max_sequences - 1` hits so that the query still fits. The hits are
  then realigned to an HMM profile built from the query with hmmbuild and
  hmmalign, and the query is prepended.

  Args:
    target_sequence: Query sequence.
    db_shard_path: Path to the database (or database shard) to search.
    get_tblout: Whether to request and return Nhmmer's tabular output.
    global_temp_dir: Optional parent directory for the scratch directory.

  Returns:
    An MsaToolResult whose A3M starts with the query. When Nhmmer finds no
    hits, the A3M holds only the query. `tblout` is '' unless `get_tblout`.
  """
  with tempfile.TemporaryDirectory(dir=global_temp_dir) as scratch_dir:
    query_path = os.path.join(scratch_dir, 'query.a3m')
    sto_path = os.path.join(scratch_dir, 'output.sto')
    pathlib.Path(sto_path).touch()
    subprocess_utils.create_query_fasta_file(
        sequence=target_sequence, path=query_path
    )

    # Don't pollute stdout with nhmmer output and skip alignments in stdout.
    flags = ['-o', '/dev/null', '--noali', '--cpu', str(self._n_cpu)]
    tblout_path = None
    if get_tblout:
      tblout_path = pathlib.Path(scratch_dir, 'tblout.txt')
      tblout_path.touch()
      flags += ['--tblout', str(tblout_path)]
    flags += ['-E', str(self._e_value)]
    if self._z_value is not None:
      flags += ['-Z', str(self._z_value)]
    if self._alphabet:
      flags.append(f'--{self._alphabet}')
    if self._strand is not None:
      flags.append(f'--{self._strand}')
    flags += ['-A', sto_path]
    # As recommended by RNAcentral, short RNA queries use a looser F3 filter.
    is_short_rna = (
        self._alphabet == 'rna'
        and len(target_sequence) < _SHORT_SEQUENCE_CUTOFF
    )
    f3_threshold = 0.02 if is_short_rna else self._filter_f3
    flags += ['--F3', str(f3_threshold)]
    # The query file and the database must be the final two arguments.
    flags += [query_path, db_shard_path]

    subprocess_utils.run(
        cmd=[self._binary_path, *flags],
        cmd_name=f'Nhmmer ({os.path.basename(db_shard_path)})',
        log_stdout=False,
        log_stderr=True,
        log_on_process_error=True,
    )

    if os.path.getsize(sto_path) == 0:
      # Nhmmer leaves the Stockholm file empty when there are no hits; the
      # MSA is then just the query.
      a3m = f'>query\n{target_sequence}'
    else:
      with open(sto_path) as sto_file:
        hits_a3m = parsers.convert_stockholm_to_a3m(
            sto_file, max_sequences=self._max_sequences - 1  # Query excluded.
        )
      # Nhmmer hits are usually shorter than the query; realigning them to a
      # query profile yields an MSA as wide as the query.
      logging.info('Aligning output a3m of size %d bytes', len(hits_a3m))
      aligner = hmmalign.Hmmalign(self._hmmalign_binary_path)
      query_fasta = f'>query\n{target_sequence}\n'
      builder = hmmbuild.Hmmbuild(
          binary_path=self._hmmbuild_binary_path, alphabet=self._alphabet
      )
      query_profile = builder.build_profile_from_a3m(query_fasta)
      aligned_hits = aligner.align_sequences_to_profile(
          profile=query_profile, sequences_a3m=hits_a3m
      )
      # Re-emit each record on one line to drop wrapped sequence lines.
      records = parsers.lazy_parse_fasta_string(query_fasta + aligned_hits)
      a3m = '\n'.join(f'>{name}\n{seq}' for seq, name in records)

    # The tabular output carries per-target details such as e-values.
    tbl = '' if tblout_path is None else tblout_path.read_text()

  return msa_tool.MsaToolResult(
      target_sequence=target_sequence,
      e_value=self._e_value,
      a3m=a3m,
      tblout=tbl,
  )
def _merge_nhmmer_results(
    nhmmer_results: Sequence[msa_tool.MsaToolResult],
    max_sequences: int,
) -> msa_tool.MsaToolResult:
  """Merges per-shard nhmmer results into a single result.

  Hits from all shards are interleaved in ascending (e-value, name) order.
  Only hits that appear in some shard's tblout are kept. The merged A3M is
  truncated to `max_sequences` rows, query included.

  Args:
    nhmmer_results: Per-shard results for the same query and e-value. Each
      must have its tblout populated, and its A3M must start with the query
      and be sorted by e-value and name.
    max_sequences: Maximum number of A3M rows in the output, query included.

  Returns:
    A result holding the merged A3M, with `tblout=None`.
  """
  assert len({nh_res.target_sequence for nh_res in nhmmer_results}) == 1
  assert len({nh_res.e_value for nh_res in nhmmer_results}) == 1

  # Parse the TBL output, create a mapping from unique hit ID to TBL line.
  parsed_tbl = {}
  for nhmmer_result in nhmmer_results:
    assert nhmmer_result.tblout is not None
    for line in nhmmer_result.tblout.splitlines():
      # Skip comment lines and blank lines (the latter have no fields).
      if line.startswith('#') or not line.strip():
        continue
      line_fields = line.split(maxsplit=15)
      accession = line_fields[0]
      alignment_from = line_fields[6]
      alignment_to = line_fields[7]
      # This is the unique ID that is used in the output A3M.
      unique_id = f'{accession}/{alignment_from}-{alignment_to}'
      parsed_tbl[unique_id] = line

  def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
    # Don't parse the entire a3m, lazily parse only as many sequences as needed.
    iterator = iter(parsers.lazy_parse_fasta_string(a3m))
    # Skip the query, which isn't present in tblout. The default stops an empty
    # a3m from raising StopIteration inside this generator.
    next(iterator, None)
    for sequence, description in iterator:
      name = description.partition(' ')[0]
      # Skip sequences for which we don't have tbl information.
      if tbl_info := parsed_tbl.get(name):
        yield sequence, description, tbl_info, name

  def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]:
    unused_seq, unused_description, tbl_info, name = seq_data
    # Nucleic tblout has 16 space delimited columns. "-" used if no value
    # present. We want e-value in column 12, so do only 13 splits. Use the name
    # in case of an e-value tie.
    return float(tbl_info.split(maxsplit=13)[12]), name

  # A3M/TBL is sorted by e-value and name, hence we can merge them efficiently.
  merged_a3m_and_tblout = heapq.merge(
      *[_merged_a3m_tbl_iter(res.a3m) for res in nhmmer_results],
      key=sort_key,
  )

  # Truncate the a3m to max_sequences rows; all hits are still counted.
  num_hits = 0
  merged_a3m = [f'>query\n{nhmmer_results[0].target_sequence}']
  for seq, description, _, _ in merged_a3m_and_tblout:
    num_hits += 1
    if len(merged_a3m) < max_sequences:
      merged_a3m.append(f'>{description}\n{seq}')

  # Rows are the query plus every hit; only report when rows were dropped.
  if num_hits + 1 > max_sequences:
    logging.info(
        'Limiting merged MSA depth from %d to %d',
        num_hits + 1,
        max_sequences,
    )

  return msa_tool.MsaToolResult(
      target_sequence=nhmmer_results[0].target_sequence,
      a3m='\n'.join(merged_a3m),
      e_value=nhmmer_results[0].e_value,
      tblout=None,  # We no longer need the tblout.
  )
================================================
FILE: src/alphafold3/data/tools/rdkit_utils.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Tools for calculating features for ligands."""
import collections
from collections.abc import Mapping, Sequence
from absl import logging
from alphafold3.cpp import cif_dict
import numpy as np
import rdkit.Chem as rd_chem
from rdkit.Chem import AllChem as rd_all_chem
# mmCIF `_chem_comp_bond.value_order` codes and the RDKit bond types they
# correspond to.
_RDKIT_MMCIF_TO_BOND_TYPE: Mapping[str, rd_chem.BondType] = {
    'SING': rd_chem.BondType.SINGLE,
    'DOUB': rd_chem.BondType.DOUBLE,
    'TRIP': rd_chem.BondType.TRIPLE,
}
# Inverse lookup: RDKit bond type -> mmCIF value_order code.
_RDKIT_BOND_TYPE_TO_MMCIF: Mapping[rd_chem.BondType, str] = dict(
    zip(_RDKIT_MMCIF_TO_BOND_TYPE.values(), _RDKIT_MMCIF_TO_BOND_TYPE.keys())
)
# RDKit bond stereo -> `_chem_comp_bond.pdbx_stereo_config` value. CIS and
# TRANS are written as Z and E respectively.
_RDKIT_BOND_STEREO_TO_MMCIF: Mapping[rd_chem.BondStereo, str] = {
    rd_chem.BondStereo.STEREONONE: 'N',
    rd_chem.BondStereo.STEREOE: 'E',
    rd_chem.BondStereo.STEREOZ: 'Z',
    rd_chem.BondStereo.STEREOCIS: 'Z',
    rd_chem.BondStereo.STEREOTRANS: 'E',
}
class MolFromMmcifError(Exception):
  """Signals that an mmCIF chemical component could not become an RDKit Mol."""


class UnsupportedMolBondError(Exception):
  """Signals an RDKit bond whose type or stereo has no mmCIF representation."""
def _populate_atoms_in_mol(
    mol: rd_chem.Mol,
    atom_names: Sequence[str],
    atom_types: Sequence[str],
    atom_charges: Sequence[int],
    implicit_hydrogens: bool,
    ligand_name: str,
    atom_leaving_flags: Sequence[str],
):
  """Populate the atoms of a Mol given atom features.

  Atoms are appended in the order of `atom_names`. For each atom this stores
  the 'atom_name' and 'atom_leaving_flag' properties, sets the formal charge,
  and attaches PDB residue info (hetero atom, residue `ligand_name`, residue
  number 1).

  Args:
    mol: Mol object. Atoms are expected to land at indices 0..n-1, so this
      presumably must be empty on entry.
    atom_names: Names of the atoms. Must be unique.
    atom_types: Types of the atoms. 'X' is passed to RDKit as '*'.
    atom_charges: Charges of the atoms.
    implicit_hydrogens: Whether to mark the atoms to allow implicit Hs.
    ligand_name: Name of the ligand which the atoms are in.
    atom_leaving_flags: Whether the atom is possibly a leaving atom. Values from
      the CCD column `_chem_comp_atom.pdbx_leaving_atom_flag`. The expected
      values are 'Y' (yes), 'N' (no), '?' (unknown/unset, interpreted as no).

  Raises:
    ValueError: If an atom type is invalid, atom names are duplicated, an atom
      name cannot be formatted into four characters, or the per-atom
      sequences have different lengths.
  """
  # Atoms are later looked up by name (e.g. when adding bonds), so duplicate
  # names would silently resolve to the wrong atom. Report them as bad input.
  duplicate_names = sorted(
      name
      for name, count in collections.Counter(atom_names).items()
      if count > 1
  )
  if duplicate_names:
    raise ValueError(f'Duplicate atom names: {duplicate_names}')

  per_atom_features = zip(
      atom_names, atom_types, atom_charges, atom_leaving_flags, strict=True
  )
  for expected_index, (
      atom_name,
      atom_type,
      atom_charge,
      atom_leaving_flag,
  ) in enumerate(per_atom_features):
    if atom_type == 'X':
      # RDKit has no 'X' element; '*' is its wildcard/dummy atom symbol.
      atom_type = '*'
    try:
      atom = rd_chem.Atom(atom_type)
    except RuntimeError as e:
      raise ValueError(f'Failed to use atom type: {str(e)}') from e
    if not implicit_hydrogens:
      atom.SetNoImplicit(True)
    atom.SetProp('atom_name', atom_name)
    atom.SetProp('atom_leaving_flag', atom_leaving_flag)
    atom.SetFormalCharge(atom_charge)
    residue_info = rd_chem.AtomPDBResidueInfo()
    residue_info.SetName(_format_atom_name(atom_name, atom_type))
    residue_info.SetIsHeteroAtom(True)
    residue_info.SetResidueName(ligand_name)
    residue_info.SetResidueNumber(1)
    atom.SetPDBResidueInfo(residue_info)
    atom_index = mol.AddAtom(atom)
    # AddAtom returns the new atom's index; atoms must keep the input order.
    assert atom_index == expected_index
def _populate_bonds_in_mol(
    mol: rd_chem.Mol,
    atom_names: Sequence[str],
    bond_begins: Sequence[str],
    bond_ends: Sequence[str],
    bond_orders: Sequence[str],
    bond_is_aromatics: Sequence[bool],
):
  """Adds bonds to `mol`, resolving bond endpoints by atom name.

  Atom `i` of `mol` is assumed to be named `atom_names[i]`.

  Args:
    mol: Mol object to add bonds to.
    atom_names: Names of atoms in the molecule, in atom-index order.
    bond_begins: Names of atoms at the beginning of the bond.
    bond_ends: Names of atoms at the end of the bond.
    bond_orders: Bond type of each bond, passed straight to AddBond.
    bond_is_aromatics: Whether each bond is aromatic.

  Raises:
    KeyError: If a bond refers to a name not in `atom_names`.
    ValueError: If the bond sequences have different lengths.
  """
  index_of = {name: index for index, name in enumerate(atom_names)}
  bond_specs = zip(
      bond_begins, bond_ends, bond_orders, bond_is_aromatics, strict=True
  )
  for begin_atom, end_atom, order, aromatic in bond_specs:
    begin_index = index_of[begin_atom]
    end_index = index_of[end_atom]
    num_bonds = mol.AddBond(begin_index, end_index, order)
    # AddBond returns the updated bond count, so the new bond is the last one.
    mol.GetBondWithIdx(num_bonds - 1).SetIsAromatic(aromatic)
def sanitize_mol(mol, sort_alphabetically, remove_hydrogens) -> rd_chem.Mol:
  """Sanitizes `mol`, then optionally reorders atoms and strips hydrogens.

  RDKit's SanitizeMol kekulizes, checks valencies, and sets aromaticity,
  conjugation and hybridization; this can repair e.g. incorrect aromatic
  flags. See
  https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.SanitizeMol

  Args:
    mol: The molecule to sanitize. SanitizeMol modifies it in place.
    sort_alphabetically: If True, reorder atoms by their 'atom_name' property.
    remove_hydrogens: If True, drop hydrogens via RDKit's RemoveHs.

  Returns:
    The sanitized molecule, reordered and/or hydrogen-stripped as requested.
  """
  rd_chem.SanitizeMol(mol)
  result = sort_atoms_by_name(mol) if sort_alphabetically else mol
  if remove_hydrogens:
    result = rd_chem.RemoveHs(result)
  return result
def _add_conformer_to_mol(mol, conformer, force_parse) -> rd_chem.Mol:
  """Attaches `conformer` to `mol` and assigns stereochemistry from it.

  `mol` is modified in place. It is also returned, which matches the annotated
  return type; the previous version implicitly returned None.

  Args:
    mol: Molecule to update.
    conformer: Conformer to attach, or None to leave `mol` unchanged.
    force_parse: If True, a ValueError raised while attaching the conformer or
      assigning stereochemistry is logged and ignored instead of re-raised.

  Returns:
    `mol`.

  Raises:
    ValueError: If attaching the conformer or assigning stereochemistry fails
      and `force_parse` is False.
  """
  if conformer is None:
    return mol
  try:
    mol.AddConformer(conformer)
    rd_chem.AssignStereochemistryFrom3D(mol)
  except ValueError as e:
    logging.warning('Failed to parse conformer: %s', e)
    if not force_parse:
      raise
  return mol
def mol_from_ccd_cif(
    mol_cif: cif_dict.CifDict,
    *,
    force_parse: bool = False,
    sort_alphabetically: bool = True,
    remove_hydrogens: bool = True,
    implicit_hydrogens: bool = False,
) -> rd_chem.Mol:
  """Creates an rdkit Mol object from a CCD mmcif data block.

  The atoms are renumbered so that their names are in alphabetical order and
  these names are placed on the atoms under property 'atom_name'.

  Only hydrogens which are not required to define the molecule are removed.
  For example, hydrogens that define stereochemistry around a double bond are
  retained.
  See this link for more details.
  https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.RemoveHs

  Args:
    mol_cif: An mmcif object representing a molecule.
    force_parse: If True, assumes missing aromatic flags are false, substitutes
      deuterium for hydrogen, assumes missing charges are 0 and ignores missing
      conformer / stereochemistry information.
    sort_alphabetically: True: sort atom alphabetically; False: keep CCD order
    remove_hydrogens: if True, remove non-important hydrogens
    implicit_hydrogens: Sets a marker on the atom that allows implicit Hs.

  Returns:
    An rdkit molecule, with the atoms sorted by name if `sort_alphabetically`.

  Raises:
    MolFromMmcifError: If conversion from mmcif to rdkit Mol fails. More
      detailed error is available as this error's cause.
  """
  # Read data fields.
  try:
    atom_names, atom_types, atom_charges, atom_leaving_flags = parse_atom_data(
        mol_cif, force_parse
    )
    bond_begins, bond_ends, bond_orders, bond_is_aromatics = parse_bond_data(
        mol_cif, force_parse
    )
    lig_name = mol_cif['_chem_comp.id'][0].rjust(3)
  except (KeyError, ValueError) as e:
    raise MolFromMmcifError from e

  # Build Rdkit molecule.
  mol = rd_chem.RWMol()

  # Per atom features.
  try:
    _populate_atoms_in_mol(
        mol=mol,
        atom_names=atom_names,
        atom_types=atom_types,
        atom_charges=atom_charges,
        implicit_hydrogens=implicit_hydrogens,
        ligand_name=lig_name,
        atom_leaving_flags=atom_leaving_flags,
    )
  except (ValueError, RuntimeError) as e:
    raise MolFromMmcifError from e

  # Bonds. The possible errors are:
  #   KeyError: a bond names an atom that is absent from the atom table.
  #   ValueError: the bond columns have different lengths.
  #   RuntimeError: presumably RDKit rejecting the bond (e.g. a duplicate).
  # All of them are invalid input and are reported like the errors above.
  try:
    _populate_bonds_in_mol(
        mol, atom_names, bond_begins, bond_ends, bond_orders, bond_is_aromatics
    )
  except (KeyError, ValueError, RuntimeError) as e:
    raise MolFromMmcifError from e

  try:
    conformer = _parse_ideal_conformer(mol_cif)
  except (KeyError, ValueError) as e:
    logging.warning('Failed to parse ideal conformer: %s', e)
    if not force_parse:
      raise MolFromMmcifError from e
    conformer = None

  # Non-strict: compute cached properties without raising on odd valences;
  # sanitization below performs the valence checks.
  mol.UpdatePropertyCache(strict=False)

  try:
    _add_conformer_to_mol(mol, conformer, force_parse)
    mol = sanitize_mol(mol, sort_alphabetically, remove_hydrogens)
  except (
      ValueError,
      rd_chem.KekulizeException,
      rd_chem.AtomValenceException,
  ) as e:
    raise MolFromMmcifError from e

  return mol
def mol_to_ccd_cif(
    mol: rd_chem.Mol,
    component_id: str,
    pdbx_smiles: str | None = None,
    include_hydrogens: bool = True,
) -> cif_dict.CifDict:
  """Creates a CCD-like mmcif data block from an rdkit Mol object.

  Only a subset of associated mmcif fields is populated, but that is
  sufficient for further usage, e.g. in featurization code.

  Atom names can be specified via `atom_name` property. For atoms with
  unspecified value of that property, the name is assigned based on element type
  and the order in the Mol object.

  If the Mol object has associated conformers, atom positions from the first of
  them will be populated in the resulting mmcif file, formatted with three
  decimals.

  The input molecule is copied, not modified. The copy is kekulized, so bond
  orders are written as SING/DOUB/TRIP, with aromaticity reported through the
  aromatic flag. Dative bonds are written as single bonds.

  Args:
    mol: An rdkit molecule.
    component_id: Name of the molecule to use in the resulting mmcif. That is
      equivalent to CCD code.
    pdbx_smiles: If specified, the value will be used to populate
      `_chem_comp.pdbx_smiles`.
    include_hydrogens: Whether to include atom and bond data involving
      hydrogens.

  Returns:
    An mmcif data block corresponding to the given rdkit molecule.

  Raises:
    UnsupportedMolBondError: When a molecule contains a bond that can't be
      represented with mmcif.
  """
  mol = rd_chem.Mol(mol)
  if include_hydrogens:
    # NOTE(review): AddHs is called without addCoords, so presumably the added
    # hydrogens get zero coordinates in the conformer below; confirm.
    mol = rd_chem.AddHs(mol)
  rd_chem.Kekulize(mol)

  if mol.GetNumConformers() > 0:
    ideal_conformer = mol.GetConformer(0).GetPositions()
    ideal_conformer = np.vectorize(lambda x: f'{x:.3f}')(ideal_conformer)
  else:
    # No data will be populated in the resulting mmcif if the molecule doesn't
    # have any conformers attached to it.
    ideal_conformer = None

  mol_cif = collections.defaultdict(list)
  mol_cif['data_'] = [component_id]
  mol_cif['_chem_comp.id'] = [component_id]
  if pdbx_smiles:
    mol_cif['_chem_comp.pdbx_smiles'] = [pdbx_smiles]

  # Fills in 'atom_name' for atoms lacking it; atom order is unchanged, so
  # conformer rows still line up with atom indices.
  mol = assign_atom_names_from_graph(mol, keep_existing_names=True)

  for atom_idx, atom in enumerate(mol.GetAtoms()):
    element = atom.GetSymbol()
    if not include_hydrogens and element in ('H', 'D'):
      continue

    mol_cif['_chem_comp_atom.comp_id'].append(component_id)
    mol_cif['_chem_comp_atom.atom_id'].append(atom.GetProp('atom_name'))
    mol_cif['_chem_comp_atom.type_symbol'].append(element.upper())
    mol_cif['_chem_comp_atom.charge'].append(str(atom.GetFormalCharge()))

    if ideal_conformer is not None:
      coords = ideal_conformer[atom_idx]
      mol_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'].append(coords[0])
      mol_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'].append(coords[1])
      mol_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'].append(coords[2])

  for bond in mol.GetBonds():
    begin_atom = bond.GetBeginAtom()
    end_atom = bond.GetEndAtom()
    if not include_hydrogens and (
        begin_atom.GetSymbol() in ('H', 'D')
        or end_atom.GetSymbol() in ('H', 'D')
    ):
      continue
    mol_cif['_chem_comp_bond.comp_id'].append(component_id)
    mol_cif['_chem_comp_bond.atom_id_1'].append(begin_atom.GetProp('atom_name'))
    mol_cif['_chem_comp_bond.atom_id_2'].append(end_atom.GetProp('atom_name'))
    try:
      bond_type = bond.GetBondType()
      # Older versions of RDKit did not have a DATIVE bond type. Convert it to
      # SINGLE to match the AF3 training setup.
      if bond_type == rd_chem.BondType.DATIVE:
        bond_type = rd_chem.BondType.SINGLE
      mol_cif['_chem_comp_bond.value_order'].append(
          _RDKIT_BOND_TYPE_TO_MMCIF[bond_type]
      )
      mol_cif['_chem_comp_bond.pdbx_stereo_config'].append(
          _RDKIT_BOND_STEREO_TO_MMCIF[bond.GetStereo()]
      )
    except KeyError as e:
      raise UnsupportedMolBondError from e
    mol_cif['_chem_comp_bond.pdbx_aromatic_flag'].append(
        'Y' if bond.GetIsAromatic() else 'N'
    )

  return cif_dict.CifDict(mol_cif)
def _format_atom_name(atom_name: str, atom_type: str) -> str:
"""Formats an atom name to fit in the four characters specified in PDB.
See for example the following note on atom name formatting in PDB files:
https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html#note1
Args:
atom_name: The unformatted atom name.
atom_type: The atom element symbol.
Returns:
formatted_atom_name: The formatted 4-character atom name.
"""
atom_name = atom_name.strip()
atom_type = atom_type.strip().upper()
if len(atom_name) == 1:
return atom_name.rjust(2).ljust(4)
elif len(atom_name) == 2:
if atom_name == atom_type:
return atom_name.ljust(4)
return atom_name.center(4)
elif len(atom_name) == 3:
if atom_name[:2] == atom_type:
return atom_name.ljust(4)
return atom_name.rjust(4)
elif len(atom_name) == 4:
return atom_name
else:
raise ValueError(
f'Atom name `{atom_name}` has more than four characters '
'or is an empty string.'
)
def parse_atom_data(
    mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool
) -> tuple[Sequence[str], Sequence[str], Sequence[int], Sequence[str]]:
  """Parses atoms. If force_parse is True, fix deuterium and missing charge.

  Args:
    mol_cif: Mapping of mmCIF column names to column values.
    force_parse: If True, '?' charges become 0 and 'D' atom types become 'H'.

  Returns:
    A tuple (atom_names, atom_types, atom_charges, atom_leaving_flags). Types
    are capitalized element symbols (e.g. 'CL' -> 'Cl') and charges are ints.
    When `_chem_comp_atom.pdbx_leaving_atom_flag` is absent, every leaving
    flag is '?'.

  Raises:
    KeyError: If the type symbol, atom id or charge column is missing (for
      mapping-like inputs).
    ValueError: If a charge is not an integer, e.g. '?' without force_parse.
  """
  atom_types = [
      symbol.capitalize() for symbol in mol_cif['_chem_comp_atom.type_symbol']
  ]
  atom_names = mol_cif['_chem_comp_atom.atom_id']
  raw_charges = mol_cif['_chem_comp_atom.charge']

  leaving_flag_column = '_chem_comp_atom.pdbx_leaving_atom_flag'
  if leaving_flag_column in mol_cif:
    atom_leaving_flags = mol_cif[leaving_flag_column]
  else:
    atom_leaving_flags = ['?'] * len(atom_names)

  if force_parse:
    raw_charges = ['0' if charge == '?' else charge for charge in raw_charges]
    atom_types = ['H' if type_ == 'D' else type_ for type_ in atom_types]

  atom_charges = [int(charge) for charge in raw_charges]
  return atom_names, atom_types, atom_charges, atom_leaving_flags
def parse_bond_data(
    mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool
) -> tuple[
    Sequence[str], Sequence[str], Sequence[rd_chem.BondType], Sequence[bool]
]:
  """Parses bond data. If force_parse is True, ignore missing aromatic flags.

  The bond table is absent when a component has no bonds, in which case all
  returned sequences are empty.

  Args:
    mol_cif: Mapping of mmCIF column names to column values.
    force_parse: If True, bonds are treated as non-aromatic when aromatic
      flags are unparseable (not 'Y'/'N') or the flag column is absent.

  Returns:
    A tuple (begin_atoms, end_atoms, bond_types, is_aromatic).

  Raises:
    KeyError: If a bond order is not SING/DOUB/TRIP, or, without force_parse,
      an aromatic flag is neither 'Y' nor 'N'.
  """
  # The bond table isn't present if there are no bonds. Use [] in that case.
  begin_atoms = mol_cif.get('_chem_comp_bond.atom_id_1', [])
  end_atoms = mol_cif.get('_chem_comp_bond.atom_id_2', [])
  orders = mol_cif.get('_chem_comp_bond.value_order', [])
  bond_types = [_RDKIT_MMCIF_TO_BOND_TYPE[order] for order in orders]

  aromatic_flags = mol_cif.get('_chem_comp_bond.pdbx_aromatic_flag', [])
  try:
    is_aromatic = [{'Y': True, 'N': False}[flag] for flag in aromatic_flags]
  except KeyError:
    if not force_parse:
      raise
    # Set them all to not aromatic.
    is_aromatic = [False for _ in begin_atoms]

  if force_parse and begin_atoms and not aromatic_flags:
    # The flag column is missing entirely. An empty list would not line up
    # with the bonds, so treat every bond as non-aromatic instead.
    is_aromatic = [False for _ in begin_atoms]

  return begin_atoms, end_atoms, bond_types, is_aromatic
def _parse_ideal_conformer(mol_cif: cif_dict.CifDict) -> rd_chem.Conformer:
  """Builds an RDKit conformer from the CCD ideal coordinates.

  Args:
    mol_cif: An mmcif object representing a molecule.

  Returns:
    A conformer with one position per atom, taken from the
    `pdbx_model_Cartn_{x,y,z}_ideal` columns.

  Raises:
    ValueError: If a coordinate is not a float (e.g. '?') or the three
      coordinate columns have different lengths.
  """
  coordinate_columns = (
      '_chem_comp_atom.pdbx_model_Cartn_x_ideal',
      '_chem_comp_atom.pdbx_model_Cartn_y_ideal',
      '_chem_comp_atom.pdbx_model_Cartn_z_ideal',
  )
  xs, ys, zs = (
      [float(value) for value in mol_cif[column]]
      for column in coordinate_columns
  )
  conformer = rd_chem.Conformer(len(xs))
  for index, position in enumerate(zip(xs, ys, zs, strict=True)):
    conformer.SetAtomPosition(index, position)
  return conformer
def sort_atoms_by_name(mol: rd_chem.Mol) -> rd_chem.Mol:
  """Sorts the atoms in the molecule by their names.

  Args:
    mol: Molecule whose atoms all carry an 'atom_name' property.

  Returns:
    The molecule returned by RDKit's RenumberAtoms, with atoms ordered by
    'atom_name'. An empty molecule is handled too.

  Raises:
    ValueError: If two atoms share a name, which makes the order ambiguous.
  """
  atom_names = [atom.GetProp('atom_name') for atom in mol.GetAtoms()]
  if len(set(atom_names)) != len(atom_names):
    raise ValueError(f'Cannot sort atoms with duplicate names: {atom_names}')
  # Argsort of the names: new_order[k] is the current index of the atom that
  # becomes atom k.
  new_order = sorted(range(len(atom_names)), key=atom_names.__getitem__)
  return rd_chem.RenumberAtoms(mol, new_order)
def assign_atom_names_from_graph(
    mol: rd_chem.Mol,
    keep_existing_names: bool = False,
) -> rd_chem.Mol:
  """Assigns atom names from the molecular graph.

  The atom name is stored as the atom property 'atom_name', accessible with
  atom.GetProp('atom_name'). When keep_existing_names is True, atoms that
  already have the property keep it.

  Atoms are visited in RDKit atom-index order. Each renamed atom gets
  '{ELEMENT}{COUNT}', with the element upper-cased as in the CCD (e.g. 'Cl'
  becomes 'CL1'). COUNT is a per-element counter that skips values whose name
  is already taken by a kept existing name.

  NOTE: A new mol is returned, the original is not changed in place.

  Args:
    mol: Mol object.
    keep_existing_names: If True, atoms that already have the atom_name property
      will keep their assigned names.

  Returns:
    A new mol, with potentially new 'atom_name' properties.
  """
  mol = rd_chem.Mol(mol)
  if keep_existing_names:
    taken_names = {
        atom.GetProp('atom_name')
        for atom in mol.GetAtoms()
        if atom.HasProp('atom_name')
    }
  else:
    taken_names = set()

  counters = collections.Counter()
  for atom in mol.GetAtoms():
    if keep_existing_names and atom.HasProp('atom_name'):
      continue
    symbol = atom.GetSymbol()
    counters[symbol] += 1
    while f'{symbol.upper()}{counters[symbol]}' in taken_names:
      counters[symbol] += 1
    atom.SetProp('atom_name', f'{symbol.upper()}{counters[symbol]}')
  return mol
def get_random_conformer(
    mol: rd_chem.Mol,
    random_seed: int,
    max_iterations: int | None,
    logging_name: str,
) -> rd_chem.Conformer | None:
  """Stochastic conformer search method using V3 ETK.

  Embeds a copy of `mol` with RDKit's ETKDGv3 parameters, so the input
  molecule is left untouched.

  Args:
    mol: Molecule to embed.
    random_seed: Seed for the embedding.
    max_iterations: If not None, overrides ETKDGv3's default maxIterations.
    logging_name: Identifier used in the warning logged on failure.

  Returns:
    The embedded conformer, or None if embedding failed.
  """
  params = rd_all_chem.ETKDGv3()
  params.randomSeed = random_seed
  if max_iterations is not None:  # Override default value.
    params.maxIterations = max_iterations
  mol_copy = rd_chem.Mol(mol)
  try:
    conformer_id = rd_all_chem.EmbedMolecule(mol_copy, params)
  except ValueError:
    conformer_id = -1
  # EmbedMolecule returns -1 on failure. Check this explicitly rather than
  # relying on GetConformer(-1) raising: -1 is also RDKit's "default
  # conformer" id, which would resolve to any conformer still present.
  if conformer_id < 0:
    logging.warning('Failed to generate conformer for: %s', logging_name)
    return None
  return mol_copy.GetConformer(conformer_id)
================================================
FILE: src/alphafold3/data/tools/shards.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""A library to handle shards of the format file_path@NUM_SHARDS.
For instance, /path/to/file@20 will generate the following shards:
- /path/to/file-00000-of-00020
- /path/to/file-00001-of-00020
- ...
- /path/to/file-00019-of-00020
This also supports @* pattern, which will determine the number of shards based
on the filesystem content.
"""
from collections.abc import Sequence
import dataclasses
import pathlib
import re
_MAX_NUM_SHARDS = 99_999
_SHARD_RE = re.compile(
r"""
^(?P[^\?\],\*]+)@
(?P(\d{1,5})|\*)
(?P[\._][^\?\]@\*\/]*)?
$""",
re.X,
)
@dataclasses.dataclass(frozen=True)
class ShardSpec:
prefix: str
num_shards: int
suffix: str
def parse_shard_spec(path: str) -> ShardSpec | None:
"""Returns the shard spec or None if the path is not a shard spec.
For instance, if the shard spec is '/path/to/file@20', the output will be
('/path/to/file', 20).
Args:
path: the path to parse, e.g. /path/to/file@20 or /path/to/file@*.
"""
parsed = re.fullmatch(_SHARD_RE, path)
if not parsed:
return None
prefix = parsed.group('prefix')
shards = parsed.group('shards')
suffix = parsed.group('suffix') or ''
if shards != '*':
return ShardSpec(prefix=prefix, num_shards=int(shards), suffix=suffix)
shard_slice = slice(len(prefix) + 10, len(prefix) + 15)
shard_path = pathlib.Path(f'{prefix}-00000-of-?????{suffix}')
for shard in sorted(shard_path.parent.glob(shard_path.name), reverse=True):
try:
num_shards = int(str(shard)[shard_slice])
return ShardSpec(prefix=prefix, num_shards=num_shards, suffix=suffix)
except ValueError:
continue
return None
def get_sharded_paths(shard_spec: str) -> Sequence[str] | None:
"""Returns a list of file path or None if the input is not a shard spec.
Args:
shard_spec: the specifications of the shard, e.g. /path/to/file@20.
"""
parsed_spec = parse_shard_spec(shard_spec)
if not parsed_spec:
return None
prefix = parsed_spec.prefix
num_shards = parsed_spec.num_shards
suffix = parsed_spec.suffix
if num_shards > _MAX_NUM_SHARDS:
raise ValueError(f'Shard count for {shard_spec} exceeds {_MAX_NUM_SHARDS}')
return [
f'{prefix}-{i:05d}-of-{num_shards:05d}{suffix}' for i in range(num_shards)
]
================================================
FILE: src/alphafold3/data/tools/subprocess_utils.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Helper functions for launching external tools."""
from collections.abc import Sequence
import os
import subprocess
import time
from typing import Any
from absl import logging
def create_query_fasta_file(sequence: str, path: str, linewidth: int = 80):
  """Writes `sequence` to `path` as a single FASTA record named 'query'.

  Args:
    sequence: The sequence to write. An empty sequence yields only the header.
    path: Destination file path; it is overwritten.
    linewidth: Maximum number of characters per sequence line. Must be > 0.

  Raises:
    ValueError: If `linewidth` is not positive.
  """
  if linewidth <= 0:
    # A non-positive width never advances through the sequence, which would
    # loop forever while writing to the file.
    raise ValueError(f'linewidth must be positive, got {linewidth}.')
  with open(path, 'w') as f:
    f.write('>query\n')
    for start in range(0, len(sequence), linewidth):
      f.write(f'{sequence[start:start + linewidth]}\n')
def check_binary_exists(path: str, name: str) -> None:
  """Raises RuntimeError unless something exists at `path`.

  Args:
    path: Filesystem path expected to hold the binary.
    name: Human-readable tool name used in the error message.
  """
  if os.path.exists(path):
    return
  raise RuntimeError(f'{name} binary not found at {path}')
def jackhmmer_seq_limit_supported(jackhmmer_path: str) -> bool:
  """Checks whether Jackhmmer accepts the `--seq_limit` flag.

  Runs `<jackhmmer_path> -h --seq_limit 1` with output discarded. A zero exit
  status means the flag is supported. Errors other than a non-zero exit
  (e.g. a missing binary) propagate to the caller.

  Args:
    jackhmmer_path: Path to the Jackhmmer binary.

  Returns:
    True if the probe exits with status 0, False otherwise.
  """
  probe = subprocess.run(
      [jackhmmer_path, '-h', '--seq_limit', '1'],
      stdout=subprocess.DEVNULL,
      stderr=subprocess.DEVNULL,
      check=False,
  )
  return probe.returncode == 0
def run(
cmd: Sequence[str],
cmd_name: str,
log_on_process_error: bool = False,
log_stderr: bool = False,
log_stdout: bool = False,
max_out_streams_len: int | None = 500_000,
**run_kwargs,
) -> subprocess.CompletedProcess[Any]:
"""Launches a subprocess, times it, and checks for errors.
Args:
cmd: Command to launch.
cmd_name: Human-readable command name to be used in logs.
log_on_process_error: Whether to use `logging.error` to log the process'
stderr on failure.
log_stderr: Whether to log the stderr of the command.
log_stdout: Whether to log the stdout of the command.
max_out_streams_len: Max length of prefix of stdout and stderr included in
the exception message. Set to `None` to disable truncation.
**run_kwargs: Any other kwargs for `subprocess.run`.
Returns:
The completed process object.
Raises:
RuntimeError: if the process completes with a non-zero return code.
"""
logging.info('Launching subprocess "%s"', ' '.join(cmd))
start_time = time.time()
try:
completed_process = subprocess.run(
cmd,
check=True,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
text=True,
**run_kwargs,
)
except subprocess.CalledProcessError as e:
if log_on_process_error:
# Logs have a 15k character limit, so log the error line by line.
logging.error('%s failed. %s stderr begin:', cmd_name, cmd_name)
for error_line in e.stderr.splitlines():
if stripped_error_line := error_line.strip():
logging.error(stripped_error_line)
logging.error('%s stderr end.', cmd_name)
error_msg = (
f'{cmd_name} failed'
f'\nstdout:\n{e.stdout[:max_out_streams_len]}\n'
f'\nstderr:\n{e.stderr[:max_out_streams_len]}'
)
raise RuntimeError(error_msg) from e
end_time = time.time()
logging.info('Finished %s in %.3f seconds', cmd_name, end_time - start_time)
stdout, stderr = completed_process.stdout, completed_process.stderr
if log_stdout and stdout:
logging.info('%s stdout:\n%s', cmd_name, stdout)
if log_stderr and stderr:
logging.info('%s stderr:\n%s', cmd_name, stderr)
return completed_process
================================================
FILE: src/alphafold3/jax/geometry/__init__.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Geometry Module."""
from alphafold3.jax.geometry import rigid_matrix_vector
from alphafold3.jax.geometry import rotation_matrix
from alphafold3.jax.geometry import struct_of_array
from alphafold3.jax.geometry import vector
# Geometry types, re-exported at package level from their defining submodules.
Vec3Array = vector.Vec3Array
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray

# Vector helper functions, re-exported from `vector`.
dot = vector.dot
cross = vector.cross
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
================================================
FILE: src/alphafold3/jax/geometry/rigid_matrix_vector.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from typing import Any, Final, Self, TypeAlias
from alphafold3.jax.geometry import rotation_matrix
from alphafold3.jax.geometry import struct_of_array
from alphafold3.jax.geometry import utils
from alphafold3.jax.geometry import vector
import jax
import jax.numpy as jnp
Float: TypeAlias = float | jnp.ndarray
VERSION: Final[str] = '0.1'
# Disabling name in pylint, since the relevant variables in math are typically
# referred to as X, Y in mathematical literature.
def _compute_covariance_matrix(
    row_values: vector.Vec3Array,
    col_values: vector.Vec3Array,
    weights: jnp.ndarray,
    epsilon: float = 1e-6,
) -> jnp.ndarray:
  """Compute covariance matrix.

  The quantity computed is

    cov_xy = weighted_avg_i(row_values[i, x] * col_values[i, y]),

  where i runs over the points and x, y run over the xyz coordinates.
  This is used to construct frames when aligning points.

  Args:
    row_values: Values used for rows of covariance matrix, shape [..., n_point].
    col_values: Values used for columns of covariance matrix, shape
      [..., n_point].
    weights: Weights to weight points by, broadcastable to [..., n_point].
    epsilon: Small value added to the sum of weights (the denominator) to
      avoid NaNs when all weights are 0.

  Returns:
    Covariance Matrix as [..., 3, 3] array.
  """
  weights = jnp.asarray(weights)
  weights = jnp.broadcast_to(weights, row_values.shape)
  # Normalize once so each entry is a plain weighted sum over the point axis.
  normalized_weights = weights / (weights.sum(axis=-1, keepdims=True) + epsilon)

  def weighted_average(x):
    return jnp.sum(normalized_weights * x, axis=-1)

  col_components = (col_values.x, col_values.y, col_values.z)
  rows = []
  for row_component in (row_values.x, row_values.y, row_values.z):
    rows.append(
        jnp.stack(
            [
                weighted_average(row_component * col_component)
                for col_component in col_components
            ],
            axis=-1,
        )
    )
  return jnp.stack(rows, axis=-2)
@struct_of_array.StructOfArray(same_dtype=True)
class Rigid3Array:
  """Rigid Transformation, i.e. element of special euclidean group.

  A point p is mapped to `rotation.apply_to_point(p) + translation`.
  """

  rotation: rotation_matrix.Rot3Array
  translation: vector.Vec3Array

  def __matmul__(self, other: Self) -> Self:
    """Composes transforms: `self @ other` applies `other` first, then `self`."""
    new_rotation = self.rotation @ other.rotation
    new_translation = self.apply_to_point(other.translation)
    return Rigid3Array(new_rotation, new_translation)

  def inverse(self) -> Self:
    """Return Rigid3Array corresponding to inverse transform."""
    inv_rotation = self.rotation.inverse()
    inv_translation = inv_rotation.apply_to_point(-self.translation)
    return Rigid3Array(inv_rotation, inv_translation)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply Rigid3Array transform to point."""
    return self.rotation.apply_to_point(point) + self.translation

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply inverse Rigid3Array transform to point."""
    new_point = point - self.translation
    return self.rotation.apply_inverse_to_point(new_point)

  def compose_rotation(self, other_rotation: rotation_matrix.Rot3Array) -> Self:
    """Returns a transform with rotation `self.rotation @ other_rotation`.

    The translation is kept unchanged, broadcast to the new rotation's shape.
    """
    rot = self.rotation @ other_rotation
    trans = jax.tree.map(
        lambda x: jnp.broadcast_to(x, rot.shape), self.translation
    )
    return Rigid3Array(rot, trans)

  @classmethod
  def identity(cls, shape: Any, dtype: jnp.dtype = jnp.float32) -> Self:
    """Return identity Rigid3Array of given shape."""
    return cls(
        rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
        vector.Vec3Array.zeros(shape, dtype=dtype),
    )  # pytype: disable=wrong-arg-count  # trace-all-classes

  def scale_translation(self, factor: Float) -> Self:
    """Scale translation in Rigid3Array by 'factor'."""
    return Rigid3Array(self.rotation, self.translation * factor)

  def to_array(self):
    """Returns a [..., 3, 4] array: rotation matrix with translation column."""
    rot_array = self.rotation.to_array()
    vec_array = self.translation.to_array()
    return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)

  @classmethod
  def from_array(cls, array):
    """Constructs Rigid3Array from a [..., 3, 4] array (inverse of `to_array`).

    Args:
      array: Array whose first three columns hold the rotation matrix and whose
        last column holds the translation.

    Returns:
      Rigid3Array of shape [...].

    Raises:
      ValueError: if `array` does not have shape [..., 3, 4].
    """
    # Without this check a [..., 3, 3] input is silently accepted and its last
    # rotation column is reinterpreted as the translation.
    if array.shape[-2:] != (3, 4):
      raise ValueError(f'array.shape({array.shape}) must be [..., 3, 4]')
    rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
    vec = vector.Vec3Array.from_array(array[..., -1])
    return cls(rot, vec)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_array4x4(cls, array: jnp.ndarray) -> Self:
    """Construct Rigid3Array from homogeneous 4x4 array."""
    if array.shape[-2:] != (4, 4):
      raise ValueError(f'array.shape({array.shape}) must be [..., 4, 4]')
    rotation = rotation_matrix.Rot3Array(
        *(array[..., 0, 0], array[..., 0, 1], array[..., 0, 2]),
        *(array[..., 1, 0], array[..., 1, 1], array[..., 1, 2]),
        *(array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]),
    )
    translation = vector.Vec3Array(
        array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
    )
    return cls(rotation, translation)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_point_alignment(
      cls,
      points_to: vector.Vec3Array,
      points_from: vector.Vec3Array,
      weights: Float | None = None,
      epsilon: float = 1e-6,
  ) -> Self:
    """Constructs Rigid3Array by finding transform aligning points.

    This constructs the optimal Rigid Transform taking points_from to the
    arrangement closest to points_to.

    Args:
      points_to: Points to align to, shape [..., n_point].
      points_from: Points to align from, shape [..., n_point].
      weights: Weights for points, broadcastable to [..., n_point]. None means
        uniform weights.
      epsilon: Added to the sum of weights when normalizing the covariance
        matrix, to avoid NaNs when all weights are 0.

    Returns:
      Rigid Transform of shape [...].
    """
    if weights is None:
      weights = 1.0

    def compute_center(value):
      return utils.weighted_mean(value=value, weights=weights, axis=-1)

    points_to_center = jax.tree.map(compute_center, points_to)
    points_from_center = jax.tree.map(compute_center, points_from)
    centered_points_to = points_to - points_to_center[..., None]
    centered_points_from = points_from - points_from_center[..., None]
    cov_mat = _compute_covariance_matrix(
        centered_points_to,
        centered_points_from,
        weights=weights,
        epsilon=epsilon,
    )
    rots = rotation_matrix.Rot3Array.from_svd(
        jnp.reshape(cov_mat, cov_mat.shape[:-2] + (9,))
    )
    translations = points_to_center - rots.apply_to_point(points_from_center)
    return cls(rots, translations)  # pytype: disable=wrong-arg-count  # trace-all-classes

  def __getstate__(self):
    # Pickled as (version tag, (rotation, translation)).
    return (VERSION, (self.rotation, self.translation))

  def __setstate__(self, state):
    version, (rot, trans) = state
    del version
    # The dataclass is frozen, so bypass its __setattr__.
    object.__setattr__(self, 'rotation', rot)
    object.__setattr__(self, 'translation', trans)
================================================
FILE: src/alphafold3/jax/geometry/rotation_matrix.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Rot3Array Matrix Class."""
import dataclasses
from typing import Any, Final, Self
from alphafold3.jax.geometry import struct_of_array
from alphafold3.jax.geometry import utils
from alphafold3.jax.geometry import vector
import jax
import jax.numpy as jnp
import numpy as np
# Rot3Array field names in row-major order (row x, then y, then z); used by
# `Rot3Array.__getstate__`/`__setstate__` to (un)pickle the entries.
COMPONENTS: Final[tuple[str, ...]] = (
*('xx', 'xy', 'xz'),
*('yx', 'yy', 'yz'),
*('zx', 'zy', 'zz'),
)
# Version tag written by `Rot3Array.__getstate__`; discarded on unpickling.
VERSION: Final[str] = '0.1'
def make_matrix_svd_factors() -> np.ndarray:
  """Generates factors for converting 3x3 matrix to symmetric 4x4 matrix.

  Row `4 * r + c` of the returned (16, 9) float32 array holds the coefficients
  expressing entry (r, c) of the symmetric 4x4 matrix as a signed sum of the
  entries of a row-major flattened 3x3 matrix (xx, xy, xz, yx, ..., zz).
  """
  flat_index = {
      name: position
      for position, name in enumerate(
          ('xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz')
      )
  }
  # Upper triangle (diagonal included) of the 4x4 matrix; each entry is then
  # mirrored into the lower triangle to make the result symmetric.
  upper_triangle = {
      (0, 0): {'xx': 1.0, 'yy': 1.0, 'zz': 1.0},
      (0, 1): {'yz': 1.0, 'zy': -1.0},
      (0, 2): {'zx': 1.0, 'xz': -1.0},
      (0, 3): {'xy': 1.0, 'yx': -1.0},
      (1, 1): {'xx': 1.0, 'yy': -1.0, 'zz': -1.0},
      (1, 2): {'xy': 1.0, 'yx': 1.0},
      (1, 3): {'xz': 1.0, 'zx': 1.0},
      (2, 2): {'yy': 1.0, 'xx': -1.0, 'zz': -1.0},
      (2, 3): {'yz': 1.0, 'zy': 1.0},
      (3, 3): {'zz': 1.0, 'xx': -1.0, 'yy': -1.0},
  }
  factors = np.zeros((16, 9), dtype=np.float32)
  for (row, col), terms in upper_triangle.items():
    for name, coefficient in terms.items():
      factors[4 * row + col, flat_index[name]] = coefficient
      factors[4 * col + row, flat_index[name]] = coefficient
  return factors
@jax.custom_jvp
def largest_evec(m):
  """Returns the eigenvector of symmetric `m` with the largest eigenvalue.

  `jnp.linalg.eigh` returns eigenvalues in ascending order, so the last column
  of its eigenvector matrix is the one selected. The derivative rule is
  attached separately via `defjvp`.
  """
  eigenvectors = jnp.linalg.eigh(m)[1]
  return eigenvectors[..., -1]
def largest_evec_jvp(primals, tangents):
  """jvp for largest eigenvector.

  For symmetric m with eigenpairs (l_j, v_j) and largest pair (l_n, v_n), the
  directional derivative of v_n along a tangent t is computed as

    sum_{j != n} (v_j^T t v_n) / (l_n - l_j) * v_j,

  with the eigenvalue gap clamped from below at 1e-6 to avoid division by 0.

  Args:
    primals: 1-tuple holding the matrix m.
    tangents: 1-tuple holding the tangent t of m.

  Returns:
    (largest eigenvector of m, its directional derivative along t).
  """
  (matrix,), (tangent,) = primals, tangents
  eigenvalues, eigenvectors = jnp.linalg.eigh(matrix)
  top_vector = eigenvectors[..., -1]
  top_value = eigenvalues[..., -1]
  rest_values = eigenvalues[..., :-1]
  rest_vectors = eigenvectors[..., :-1]
  # projected[..., b, j] = sum_a rest_vectors[..., a, j] * tangent[..., a, b]
  projected = jnp.einsum(
      '...ik,...il -> ...lk',
      rest_vectors,
      tangent,
      precision=jax.lax.Precision.HIGHEST,
  )
  # coupling[..., j] = v_j^T t v_n
  coupling = jnp.einsum(
      '...lk,...l -> ...k',
      projected,
      top_vector,
      precision=jax.lax.Precision.HIGHEST,
  )
  gap = jnp.maximum(top_value[..., None] - rest_values, 1e-6)
  coefficients = coupling / gap
  derivative = jnp.sum(coefficients[..., None, :] * rest_vectors, axis=-1)
  return top_vector, derivative
# Attach the analytic derivative rule to largest_evec.
largest_evec.defjvp(largest_evec_jvp)
# (16, 9) matrix mapping a flattened 3x3 matrix to the flattened symmetric 4x4
# matrix whose largest eigenvector is used by Rot3Array.from_svd.
MATRIX_SVD_QUAT_FACTORS = make_matrix_svd_factors()
@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
  """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.

  Field `ab` holds the matrix entry in row `a`, column `b` (row-major, see
  `to_array`).
  """

  # The 'dtype' metadata makes the StructOfArray post-init validation require
  # xx to be float32; with same_dtype=True the other entries must match xx.
  xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  xy: jnp.ndarray
  xz: jnp.ndarray
  yx: jnp.ndarray
  yy: jnp.ndarray
  yz: jnp.ndarray
  zx: jnp.ndarray
  zy: jnp.ndarray
  zz: jnp.ndarray

  __array_ufunc__ = None

  def inverse(self) -> Self:
    """Returns inverse of Rot3Array (the transpose of the rotation matrix)."""
    return Rot3Array(
        *(self.xx, self.yx, self.zx),
        *(self.xy, self.yy, self.zy),
        *(self.xz, self.yz, self.zz),
    )

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies Rot3Array to point (matrix-vector product)."""
    return vector.Vec3Array(
        self.xx * point.x + self.xy * point.y + self.xz * point.z,
        self.yx * point.x + self.yy * point.y + self.yz * point.z,
        self.zx * point.x + self.zy * point.y + self.zz * point.z,
    )

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies inverse Rot3Array to point."""
    return self.inverse().apply_to_point(point)

  def __matmul__(self, other: Self) -> Self:
    """Composes two Rot3Arrays (matrix product `self @ other`)."""
    # Rotate each column of `other` by `self`.
    c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
    c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
    c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
    return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

  @classmethod
  def identity(cls, shape: Any, dtype: jnp.dtype = jnp.float32) -> Self:
    """Returns identity of given shape."""
    ones = jnp.ones(shape, dtype=dtype)
    zeros = jnp.zeros(shape, dtype=dtype)
    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Self:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
      e0: Vector
      e1: Vector

    Returns:
      Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0.
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    # e0, e1, e2 become the columns of the matrix.
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_array(cls, array: jnp.ndarray) -> Self:
    """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
    rows = utils.unstack(array, axis=-2)
    # Flatten row-major: all entries of row 0, then row 1, then row 2.
    components = [
        entry for row in rows for entry in utils.unstack(row, axis=-1)
    ]
    return cls(*components)

  def to_array(self) -> jnp.ndarray:
    """Convert Rot3Array to array of shape [..., 3, 3]."""
    return jnp.stack(
        [
            jnp.stack([self.xx, self.xy, self.xz], axis=-1),
            jnp.stack([self.yx, self.yy, self.yz], axis=-1),
            jnp.stack([self.zx, self.zy, self.zz], axis=-1),
        ],
        axis=-2,
    )

  @classmethod
  def from_quaternion(
      cls,
      w: jnp.ndarray,
      x: jnp.ndarray,
      y: jnp.ndarray,
      z: jnp.ndarray,
      normalize: bool = True,
      epsilon: float = 1e-6,
  ) -> Self:
    """Construct Rot3Array from components of quaternion.

    Args:
      w: Scalar part of the quaternion.
      x: First vector component.
      y: Second vector component.
      z: Third vector component.
      normalize: If True, divide the quaternion by its norm first.
      epsilon: Lower bound on the squared norm used for normalization.

    Returns:
      Rot3Array with the broadcast shape of the inputs.
    """
    if normalize:
      inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
      # Rebind instead of using `*=`, which would modify caller-owned NumPy
      # arrays in place.
      w = w * inv_norm
      x = x * inv_norm
      y = y * inv_norm
      z = z * inv_norm
    xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
    xy = 2 * (x * y - w * z)
    xz = 2 * (x * z + w * y)
    yx = 2 * (x * y + w * z)
    yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
    yz = 2 * (y * z - w * x)
    zx = 2 * (x * z - w * y)
    zy = 2 * (y * z + w * x)
    zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_svd(cls, mat: jnp.ndarray, use_quat_formula: bool = True) -> Self:
    """Constructs Rot3Array from arbitrary array of shape [..., 3 * 3] using SVD.

    The input (a row-major flattened 3x3 matrix per batch element) is
    projected to a rotation matrix.

    When 'use_quat_formula' is True (the default), the projection is rephrased
    as finding the largest eigenvector (a quaternion) of a symmetric 4x4 matrix
    built linearly from `mat`. This has the advantage of having fewer
    numerical issues. This approach follows:
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.971&rep=rep1&type=pdf

    When it is False, the rotation is constructed via SVD following
    https://arxiv.org/pdf/2006.14616.pdf
    In that case [∂L/∂M] is large if the two smallest singular values are close
    to each other, or if they are close to 0.

    Args:
      mat: Array of shape [..., 3 * 3]
      use_quat_formula: Whether to construct matrix via 4x4 eigenvalue problem.

    Returns:
      Rot3Array of shape [...]
    """
    assert mat.shape[-1] == 9
    if use_quat_formula:
      symmetric_4by4 = jnp.einsum(
          'ji, ...i -> ...j',
          MATRIX_SVD_QUAT_FACTORS,
          mat,
          precision=jax.lax.Precision.HIGHEST,
      )
      symmetric_4by4 = jnp.reshape(symmetric_4by4, mat.shape[:-1] + (4, 4))
      largest_eigvec = largest_evec(symmetric_4by4)
      return cls.from_quaternion(
          *utils.unstack(largest_eigvec, axis=-1)
      ).inverse()

    else:
      mat = jnp.reshape(mat, mat.shape[:-1] + (3, 3))
      u, _, v_t = jnp.linalg.svd(mat, full_matrices=False)
      # Flip the last singular direction if needed so det(result) == +1.
      det_uv_t = jnp.linalg.det(
          jnp.matmul(u, v_t, precision=jax.lax.Precision.HIGHEST)
      )
      ones = jnp.ones_like(det_uv_t)
      diag_array = jnp.stack([ones, ones, det_uv_t], axis=-1)
      # This is equivalent to making diag_array into a diagonal array and matrix
      # multiplying
      diag_times_v_t = diag_array[..., None] * v_t
      out = jnp.matmul(u, diag_times_v_t, precision=jax.lax.Precision.HIGHEST)
      return cls.from_array(out)

  @classmethod
  def random_uniform(cls, key, shape, dtype=jnp.float32) -> Self:
    """Samples uniform random Rot3Array according to Haar Measure.

    A standard-normal quaternion is drawn and normalized by from_quaternion.
    """
    quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
    quats = utils.unstack(quat_array)
    return cls.from_quaternion(*quats)

  def __getstate__(self):
    # Pickled as (version tag, [entries as NumPy arrays in COMPONENTS order]).
    return (VERSION, [np.asarray(getattr(self, field)) for field in COMPONENTS])

  def __setstate__(self, state):
    version, state = state
    del version
    # The dataclass is frozen, so bypass its __setattr__.
    for i, field in enumerate(COMPONENTS):
      object.__setattr__(self, field, state[i])
================================================
FILE: src/alphafold3/jax/geometry/struct_of_array.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Class decorator to represent (nested) struct of arrays."""
import dataclasses
import jax
def get_item(instance, key):
  """Indexes every array field of a struct-of-arrays instance with `key`.

  Fields declaring 'num_trailing_dims' metadata keep those trailing dims: when
  `key` is a tuple containing Ellipsis, full slices are appended for them so
  the Ellipsis only expands over the leading (struct) dimensions.

  Args:
    instance: Dataclass instance produced by the StructOfArray decorator.
    key: Any NumPy-style index.

  Returns:
    A new instance (via dataclasses.replace) whose array fields are sliced;
    metadata fields are carried over unchanged.
  """
  # Use identity checks: `Ellipsis in key` compares each element with `==`,
  # which is ambiguous (raises) for NumPy index arrays inside the key.
  has_ellipsis = isinstance(key, tuple) and any(k is Ellipsis for k in key)
  sliced = {}
  for field in get_array_fields(instance):
    num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
    this_key = key
    if has_ellipsis:
      this_key += (slice(None),) * num_trailing_dims
    sliced[field.name] = jax.tree.map(
        lambda x: x[this_key],  # pylint: disable=cell-var-from-loop
        getattr(instance, field.name),
    )
  return dataclasses.replace(instance, **sliced)
@property
def get_shape(instance):
  """Shape of the struct: the first field's shape minus its trailing dims.

  Trailing dims are only stripped when the first field's metadata sets a
  truthy 'num_trailing_dims'.
  """
  leading_field = dataclasses.fields(instance)[0]
  trailing = leading_field.metadata.get('num_trailing_dims', None)
  value = getattr(instance, leading_field.name)
  return value.shape[:-trailing] if trailing else value.shape
def get_len(instance):
  """Length of the leading dimension of `instance.shape`.

  Raises:
    TypeError: for an instance with empty (scalar) shape.
  """
  shape = instance.shape
  if not shape:
    raise TypeError('len() of unsized object')  # Match jax.numpy behavior.
  return shape[0]
@property
def get_dtype(instance):
  """Returns Dtype for given instance of dataclass.

  The dtype is read from the single field whose metadata sets 'sets_dtype', if
  any; otherwise from the first field when the class uses same_dtype=True.

  Raises:
    AttributeError: if no field determines the dtype, or the chosen field value
      has no `dtype`. post_init catches AttributeError to skip validation, so
      the exception type must stay AttributeError.
  """
  fields = dataclasses.fields(instance)
  sets_dtype = [
      field.name for field in fields if field.metadata.get('sets_dtype', False)
  ]
  if sets_dtype:
    assert len(sets_dtype) == 1, 'at most one field can set dtype'
    field_value = getattr(instance, sets_dtype[0])
  elif instance.same_dtype:
    field_value = getattr(instance, fields[0].name)
  else:
    # Should this be Value Error?
    raise AttributeError(
        'Trying to access Dtype on Struct of Array without '
        'either "same_dtype" or field setting dtype'
    )
  if hasattr(field_value, 'dtype'):
    return field_value.dtype
  # Should this be Value Error?
  raise AttributeError(f'field_value {field_value} does not have dtype')
def replace(instance, **kwargs):
  """Returns a copy of `instance` with the given fields replaced.

  Thin wrapper over dataclasses.replace; installed as the `replace` method of
  StructOfArray-decorated classes.
  """
  updated = dataclasses.replace(instance, **kwargs)
  return updated
def post_init(instance):
  """Validate instance has same shapes & dtypes.

  Installed by the StructOfArray decorator as the class's `__post_init__`.
  For every array field it asserts that:
    * the field's shape, with any 'num_trailing_dims' stripped, equals
      `instance.shape`;
    * its dtype is listed in the field's 'allowed_dtypes' metadata, if given;
    * its dtype equals the field's 'dtype' metadata when present, otherwise
      the instance dtype.
  Validation is skipped when the instance dtype cannot be determined.
  """
  array_fields = get_array_fields(instance)
  arrays = list(get_array_fields(instance, return_values=True).values())
  first_field = array_fields[0]
  # These slightly weird constructions about checking whether the leaves are
  # actual arrays is since e.g. vmap internally relies on being able to
  # construct pytree's with object() as leaves, this would break the checking
  # as such we are only validating the object when the entries in the dataclass
  # Are arrays or other dataclasses of arrays.
  try:
    dtype = instance.dtype
  except AttributeError:
    dtype = None
  if dtype is not None:
    first_shape = instance.shape
    for array, field in zip(arrays, array_fields, strict=True):
      num_trailing_dims = field.metadata.get('num_trailing_dims', None)
      if num_trailing_dims:
        array_shape = array.shape
        field_shape = array_shape[:-num_trailing_dims]
        msg = (
            f'field {field} should have number of trailing dims'
            f' {num_trailing_dims}'
        )
        assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
      else:
        field_shape = array.shape

      shape_msg = (
          f"Stripped Shape {field_shape} of field {field} doesn't "
          f'match shape {first_shape} of field {first_field}'
      )
      assert field_shape == first_shape, shape_msg

      field_dtype = array.dtype

      allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
      if allowed_metadata_dtypes:
        msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
        assert field_dtype in allowed_metadata_dtypes, msg

      # A per-field 'dtype' pin takes precedence over the instance dtype.
      if 'dtype' in field.metadata:
        target_dtype = field.metadata['dtype']
      else:
        target_dtype = dtype

      msg = f'Dtype is {field_dtype} but must be {target_dtype}'
      assert field_dtype == target_dtype, msg
def flatten(instance):
  """Flatten Struct of Array instance.

  Returns:
    (leaves, aux): `leaves` concatenates the pytree leaves of every array field
    in field order; `aux` is (per-field treedefs, a metadata-class instance
    holding the metadata fields, per-field leaf counts).
  """
  leaves = []
  treedefs = []
  leaf_counts = []
  for field_value in get_array_fields(instance, return_values=True).values():
    field_leaves, field_treedef = jax.tree_util.tree_flatten(field_value)
    treedefs.append(field_treedef)
    leaves.extend(field_leaves)
    leaf_counts.append(len(field_leaves))
  metadata_values = get_metadata_fields(instance, return_values=True)
  metadata = type(instance).metadata_cls(**metadata_values)
  return leaves, (treedefs, metadata, leaf_counts)
def make_metadata_class(cls):
  """Builds a frozen, eq=True dataclass 'Meta<ClassName>' with cls's metadata fields.

  Only fields whose metadata sets 'is_metadata' are copied over.
  """

  def is_metadata(field):
    return field.metadata.get('is_metadata', False)

  field_specs = [
      (field.name, field.type, field) for field in get_fields(cls, is_metadata)
  ]
  return dataclasses.make_dataclass(
      cls_name='Meta' + cls.__name__,
      fields=field_specs,
      frozen=True,
      eq=True,
  )
def get_fields(cls_or_instance, filterfn, return_values=False):
  """Returns the dataclass fields accepted by `filterfn`.

  With return_values=True a {field name: value} dict is returned instead;
  values are read with getattr, so pass an instance in that case.
  """
  selected = [
      field
      for field in dataclasses.fields(cls_or_instance)
      if filterfn(field)
  ]
  if not return_values:
    return selected
  return {field.name: getattr(cls_or_instance, field.name) for field in selected}
def get_array_fields(cls, return_values=False):
  """Returns the fields (or, with return_values, their values) not marked 'is_metadata'."""

  def is_array_field(field):
    return not field.metadata.get('is_metadata', False)

  return get_fields(cls, is_array_field, return_values=return_values)
def get_metadata_fields(cls, return_values=False):
  """Returns the fields (or, with return_values, their values) marked 'is_metadata'."""

  def is_metadata_field(field):
    return field.metadata.get('is_metadata', False)

  return get_fields(cls, is_metadata_field, return_values=return_values)
class StructOfArray:
  """Class Decorator for Struct Of Arrays.

  Applying an instance of this class to a class turns it into a frozen
  dataclass (eq=False) registered as a JAX pytree. The decorated class gains
  `shape` and `dtype` properties, `len()`, NumPy-style indexing, `replace`,
  and a validating `__post_init__` (overwriting any `__post_init__` defined in
  the class body). Fields whose metadata sets 'is_metadata' are stored in the
  pytree aux data via a generated `metadata_cls`; all other fields are
  flattened as pytree children.
  """

  def __init__(self, same_dtype=True):
    # Copied onto the decorated class; when True, get_dtype reports the first
    # field's dtype as the instance dtype.
    self.same_dtype = same_dtype

  def __call__(self, cls):
    # Attach helpers before dataclasses.dataclass runs, so that its generated
    # __init__ calls the installed __post_init__.
    helpers = {
        '__array_ufunc__': None,
        'replace': replace,
        'same_dtype': self.same_dtype,
        'dtype': get_dtype,
        'shape': get_shape,
        '__len__': get_len,
        '__getitem__': get_item,
        '__post_init__': post_init,
    }
    for attr_name, attr_value in helpers.items():
      setattr(cls, attr_name, attr_value)
    decorated = dataclasses.dataclass(cls, frozen=True, eq=False)  # pytype: disable=wrong-keyword-args
    # pytree claims to require metadata to be hashable, not sure why,
    # But making derived dataclass that can just hold metadata
    decorated.metadata_cls = make_metadata_class(decorated)

    def unflatten(aux, data):
      treedefs, metadata, leaf_counts = aux
      field_names = [field.name for field in get_array_fields(decorated)]
      kwargs = {}
      offset = 0
      for count, treedef, name in zip(
          leaf_counts, treedefs, field_names, strict=True
      ):
        kwargs[name] = jax.tree_util.tree_unflatten(
            treedef, data[offset : offset + count]
        )
        offset += count
      for field in get_metadata_fields(decorated):
        kwargs[field.name] = getattr(metadata, field.name)
      return decorated(**kwargs)

    jax.tree_util.register_pytree_node(
        nodetype=decorated, flatten_func=flatten, unflatten_func=unflatten
    )
    return decorated
================================================
FILE: src/alphafold3/jax/geometry/utils.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Utils for geometry library."""
from collections.abc import Iterable
import numbers
import jax.numpy as jnp
def unstack(value: jnp.ndarray, axis: int = -1) -> list[jnp.ndarray]:
  """Splits `value` along `axis` into a list of arrays with that axis removed."""
  pieces = jnp.split(value, value.shape[axis], axis=axis)
  result = []
  for piece in pieces:
    result.append(jnp.squeeze(piece, axis=axis))
  return result
def angdiff(alpha: jnp.ndarray, beta: jnp.ndarray) -> jnp.ndarray:
  """Computes the signed difference `alpha - beta` wrapped to [-pi, pi).

  Note this is not an absolute value: e.g. angdiff(0, 0.5) is -0.5.

  Args:
    alpha: Angles in radians.
    beta: Angles in radians, broadcastable with `alpha`.

  Returns:
    The wrapped difference, in radians.
  """
  d = alpha - beta
  # Shift by pi, wrap into [0, 2 * pi) with the modulo, then shift back.
  d = (d + jnp.pi) % (2 * jnp.pi) - jnp.pi
  return d
def weighted_mean(
*,
weights: jnp.ndarray,
value: jnp.ndarray,
axis: int | Iterable[int] | None = None,
eps: float = 1e-10,
) -> jnp.ndarray:
"""Computes weighted mean in a safe way that avoids NaNs.
This is equivalent to jnp.average for the case eps=0.0, but adds a small
constant to the denominator of the weighted average to avoid NaNs.
'weights' should be broadcastable to the shape of value.
Args:
weights: Weights to weight value by.
value: Values to average
axis: Axes to average over.
eps: Epsilon to add to the denominator.
Returns:
Weighted average.
"""
weights = jnp.asarray(weights, dtype=value.dtype)
weights = jnp.broadcast_to(weights, value.shape)
weights_shape = weights.shape
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(weights_shape)))
return jnp.sum(weights * value, axis=axis) / (
jnp.sum(weights, axis=axis) + eps
)
================================================
FILE: src/alphafold3/jax/geometry/vector.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Vec3Array Class."""
import dataclasses
from typing import Final, Self, TypeAlias
from alphafold3.jax.geometry import struct_of_array
from alphafold3.jax.geometry import utils
import jax
import jax.numpy as jnp
import numpy as np
# Scalar-like type: a Python float or a JAX array (e.g. the factor accepted by
# `Vec3Array.__mul__`).
Float: TypeAlias = float | jnp.ndarray
# Version tag written by `Vec3Array.__getstate__`; discarded on unpickling.
VERSION: Final[str] = '0.1'
@struct_of_array.StructOfArray(same_dtype=True)
class Vec3Array:
  """Vec3Array in 3 dimensional Space implemented as struct of arrays.

  This is done in order to improve performance and precision.
  On TPU small matrix multiplications are very suboptimal and will waste large
  compute resources, furthermore any matrix multiplication on TPU happens in
  mixed bfloat16/float32 precision, which is often undesirable when handling
  physical coordinates.

  In most cases this will also be faster on CPUs/GPUs since it allows for easier
  use of vector instructions.

  Validation happens in the `__post_init__` installed by the StructOfArray
  decorator (which overwrites any `__post_init__` defined here). That check
  asserts that x, y and z share one shape, that x is float32 (from its field
  metadata), and that y and z have the same dtype as x.
  """

  x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  y: jnp.ndarray
  z: jnp.ndarray

  def __add__(self, other: Self) -> Self:
    return jax.tree.map(lambda x, y: x + y, self, other)

  def __sub__(self, other: Self) -> Self:
    return jax.tree.map(lambda x, y: x - y, self, other)

  def __mul__(self, other: Float) -> Self:
    return jax.tree.map(lambda x: x * other, self)

  def __rmul__(self, other: Float) -> Self:
    return self * other

  def __truediv__(self, other: Float) -> Self:
    return jax.tree.map(lambda x: x / other, self)

  def __neg__(self) -> Self:
    return jax.tree.map(lambda x: -x, self)

  def __pos__(self) -> Self:
    return jax.tree.map(lambda x: x, self)

  def cross(self, other: Self) -> Self:
    """Compute cross product between 'self' and 'other'."""
    new_x = self.y * other.z - self.z * other.y
    new_y = self.z * other.x - self.x * other.z
    new_z = self.x * other.y - self.y * other.x
    return Vec3Array(new_x, new_y, new_z)

  def dot(self, other: Self) -> Float:
    """Compute dot product between 'self' and 'other'."""
    return self.x * other.x + self.y * other.y + self.z * other.z

  def norm(self, epsilon: float = 1e-6) -> Float:
    """Compute Norm of Vec3Array, clipped to epsilon."""
    # To avoid NaN on the backward pass, we must use maximum before the sqrt
    norm2 = self.dot(self)
    if epsilon:
      norm2 = jnp.maximum(norm2, epsilon**2)
    return jnp.sqrt(norm2)

  def norm2(self):
    """Squared norm (no clipping)."""
    return self.dot(self)

  def normalized(self, epsilon: float = 1e-6) -> Self:
    """Return unit vector with optional clipping."""
    return self / self.norm(epsilon)

  @classmethod
  def zeros(cls, shape, dtype=jnp.float32):
    """Return Vec3Array corresponding to zeros of given shape."""
    return cls(
        jnp.zeros(shape, dtype),
        jnp.zeros(shape, dtype),
        jnp.zeros(shape, dtype),
    )  # pytype: disable=wrong-arg-count  # trace-all-classes

  def to_array(self) -> jnp.ndarray:
    """Returns a [..., 3] array stacking x, y, z on the last axis."""
    return jnp.stack([self.x, self.y, self.z], axis=-1)

  @classmethod
  def from_array(cls, array):
    """Constructs Vec3Array from an array whose last axis has size 3."""
    return cls(*utils.unstack(array))

  def __getstate__(self):
    # Pickled as (version tag, [x, y, z] as NumPy arrays).
    return (
        VERSION,
        [np.asarray(self.x), np.asarray(self.y), np.asarray(self.z)],
    )

  def __setstate__(self, state):
    version, state = state
    del version
    # The dataclass is frozen, so bypass its __setattr__.
    for i, letter in enumerate('xyz'):
      object.__setattr__(self, letter, state[i])
def square_euclidean_distance(
    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
  """Computes square of euclidean distance between 'vec1' and 'vec2'.

  Args:
    vec1: Vec3Array to compute distance to
    vec2: Vec3Array to compute distance from, should be broadcast compatible
      with 'vec1'
    epsilon: the squared distance is clipped from below to be at least epsilon
      (no clipping when epsilon is 0)

  Returns:
    Array of square euclidean distances;
    shape will be result of broadcasting 'vec1' and 'vec2'
  """
  delta = vec1 - vec2
  squared = delta.dot(delta)
  return jnp.maximum(squared, epsilon) if epsilon else squared
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
  """Functional form of `Vec3Array.dot`."""
  product = vector1.dot(vector2)
  return product
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
  """Functional form of `Vec3Array.cross`; returns a Vec3Array."""
  return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
  """Functional form of `Vec3Array.norm` (clipped below at a non-zero epsilon)."""
  clipped_norm = vector.norm(epsilon)
  return clipped_norm
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
  """Functional form of `Vec3Array.normalized`."""
  unit = vector.normalized(epsilon)
  return unit
def euclidean_distance(
    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
  """Computes euclidean distance between 'vec1' and 'vec2'.

  Args:
    vec1: Vec3Array to compute euclidean distance to
    vec2: Vec3Array to compute euclidean distance from, should be broadcast
      compatible with 'vec1'
    epsilon: distance is clipped from below to be at least epsilon

  Returns:
    Array of euclidean distances;
    shape will be result of broadcasting 'vec1' and 'vec2'
  """
  # Clipping the squared distance at epsilon**2 clips the distance at epsilon.
  return jnp.sqrt(square_euclidean_distance(vec1, vec2, epsilon**2))
def dihedral_angle(
    a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array
) -> Float:
  """Computes torsion angle for a quadruple of points.

  For points (a, b, c, d), this is the angle between the planes defined by
  points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

  Arguments:
    a: A Vec3Array of coordinates.
    b: A Vec3Array of coordinates.
    c: A Vec3Array of coordinates.
    d: A Vec3Array of coordinates.

  Returns:
    A tensor of angles in radians: [-pi, pi].
  """
  ba = a - b
  bond = b - c  # Shared by both planes.
  dc = d - c
  normal_abc = ba.cross(bond)
  normal_bcd = dc.cross(bond)
  normals_cross = normal_bcd.cross(normal_abc)
  bond_length = bond.norm()
  sin_term = normals_cross.dot(bond)
  cos_term = bond_length * normal_abc.dot(normal_bcd)
  return jnp.arctan2(sin_term, cos_term)
def random_gaussian_vector(shape, key, dtype=jnp.float32) -> Vec3Array:
  """Samples a Vec3Array of the given shape with i.i.d. standard-normal xyz.

  Args:
    shape: Batch shape of the result (any sequence of ints, e.g. tuple or
      list).
    key: JAX PRNG key.
    dtype: Floating dtype of the samples.

  Returns:
    Vec3Array of shape `shape`.
  """
  # tuple() allows list shapes, matching Rot3Array.random_uniform.
  vec_array = jax.random.normal(key, tuple(shape) + (3,), dtype)
  return Vec3Array.from_array(vec_array)
================================================
FILE: src/alphafold3/model/atom_layout/atom_layout.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Helper functions for different atom layouts and conversion between them."""
import collections
from collections.abc import Mapping, Sequence
import dataclasses
import types
from typing import Any, TypeAlias
from alphafold3 import structure
from alphafold3.constants import atom_types
from alphafold3.constants import chemical_component_sets
from alphafold3.constants import chemical_components
from alphafold3.constants import mmcif_names
from alphafold3.constants import residue_names
from alphafold3.data.tools import rdkit_utils
from alphafold3.structure import chemical_components as struc_chem_comps
import jax.numpy as jnp
import numpy as np
from rdkit import Chem
# Union type for values that may be either a NumPy or a JAX array.
xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray  # pylint: disable=invalid-name
# Alias for values used as NumPy indices; intentionally untyped (Any).
NumpyIndex: TypeAlias = Any
@dataclasses.dataclass(frozen=True)
class AtomLayout:
  """Atom layout in a fixed shape (usually 1-dim or 2-dim).

  Examples for atom layouts are atom37, atom14, and similar.
  All members are np.ndarrays with the same shape, e.g.
  - [num_atoms]
  - [num_residues, max_atoms_per_residue]
  - [num_fragments, max_fragments_per_residue]
  All string arrays should have dtype=object to avoid pitfalls with Numpy's
  fixed-size strings.

  Attributes:
    atom_name: np.ndarray of str: atom names (e.g. 'CA', 'NE2'), padding
      elements have an empty string (''), None or any other value, that maps to
      False for .astype(bool). mmCIF field: _atom_site.label_atom_id.
    res_id: np.ndarray of int: residue index (usually starting from 1) padding
      elements can have an arbitrary value. mmCIF field:
      _atom_site.label_seq_id.
    chain_id: np.ndarray of str: chain names (e.g. 'A', 'B') padding elements
      can have an arbitrary value. mmCIF field: _atom_site.label_asym_id.
    atom_element: np.ndarray of str: atom elements (e.g. 'C', 'N', 'O'), padding
      elements have an empty string (''), None or any other value, that maps to
      False for .astype(bool). mmCIF field: _atom_site.type_symbol.
    res_name: np.ndarray of str: residue names (e.g. 'ARG', 'TRP') padding
      elements can have an arbitrary value. mmCIF field:
      _atom_site.label_comp_id.
    chain_type: np.ndarray of str: chain types (e.g. 'polypeptide(L)'). padding
      elements can have an arbitrary value. mmCIF field: _entity_poly.type OR
      _entity.type (for non-polymers).
    shape: shape of the layout (just returns atom_name.shape)
  """

  atom_name: np.ndarray
  res_id: np.ndarray
  chain_id: np.ndarray
  atom_element: np.ndarray | None = None
  res_name: np.ndarray | None = None
  chain_type: np.ndarray | None = None

  def __post_init__(self):
    """Assert all arrays have the same shape and atom_name has dtype object.

    Raises:
      ValueError: if a non-None field's shape differs from atom_name.shape, or
        if atom_name does not have dtype object.
    """
    attribute_names = (
        'atom_name',
        'atom_element',
        'res_name',
        'res_id',
        'chain_id',
        'chain_type',
    )
    _assert_all_arrays_have_same_shape(
        obj=self,
        expected_shape=self.atom_name.shape,
        attribute_names=attribute_names,
    )
    # atom_name must have dtype object, such that we can convert it to bool to
    # obtain the mask (padding values map to False).
    if self.atom_name.dtype != object:
      raise ValueError(
          'atom_name must have dtype object, such that it can '
          'be converted to bool to obtain the mask'
      )

  def __getitem__(self, key: NumpyIndex) -> 'AtomLayout':
    """Indexes every present field with the same numpy-style key."""
    return AtomLayout(
        atom_name=self.atom_name[key],
        res_id=self.res_id[key],
        chain_id=self.chain_id[key],
        atom_element=(
            self.atom_element[key] if self.atom_element is not None else None
        ),
        res_name=(self.res_name[key] if self.res_name is not None else None),
        chain_type=(
            self.chain_type[key] if self.chain_type is not None else None
        ),
    )

  def __eq__(self, other: object) -> bool:
    """Compares two layouts on their non-padding atoms.

    atom_name must match exactly (including padding). res_id and chain_id are
    compared only where atom_name is non-padding. The optional fields are
    compared the same way, but only when both layouts have them.
    """
    if not isinstance(other, AtomLayout):
      return NotImplemented
    if not np.array_equal(self.atom_name, other.atom_name):
      return False
    mask = self.atom_name.astype(bool)
    # Check essential fields.
    for field in ('res_id', 'chain_id'):
      my_arr = getattr(self, field)
      other_arr = getattr(other, field)
      if not np.array_equal(my_arr[mask], other_arr[mask]):
        return False
    # Check optional fields; a field missing on either side is not compared.
    for field in ('atom_element', 'res_name', 'chain_type'):
      my_arr = getattr(self, field)
      other_arr = getattr(other, field)
      if (
          my_arr is not None
          and other_arr is not None
          and not np.array_equal(my_arr[mask], other_arr[mask])
      ):
        return False
    return True

  def copy_and_pad_to(self, shape: tuple[int, ...]) -> 'AtomLayout':
    """Copies and pads the layout to the requested shape.

    Padding is appended at the end of every axis; res_id is padded with 0 and
    all string fields with ''.

    Args:
      shape: new shape for the atom layout

    Returns:
      a copy of the atom layout padded to the requested shape

    Raises:
      ValueError: incompatible shapes.
    """
    if len(shape) != len(self.atom_name.shape):
      raise ValueError(
          f'Incompatible shape {shape}. Current layout has shape {self.shape}.'
      )
    if any(new < old for old, new in zip(self.atom_name.shape, shape)):
      raise ValueError(
          "Can't pad to a smaller shape. Current layout has shape "
          f'{self.shape} and you requested shape {shape}.'
      )
    pad_width = [
        (0, new - old) for old, new in zip(self.atom_name.shape, shape)
    ]
    pad_val = np.array('', dtype=object)
    return AtomLayout(
        atom_name=np.pad(self.atom_name, pad_width, constant_values=pad_val),
        res_id=np.pad(self.res_id, pad_width, constant_values=0),
        chain_id=np.pad(self.chain_id, pad_width, constant_values=pad_val),
        atom_element=(
            np.pad(self.atom_element, pad_width, constant_values=pad_val)
            if self.atom_element is not None
            else None
        ),
        res_name=(
            np.pad(self.res_name, pad_width, constant_values=pad_val)
            if self.res_name is not None
            else None
        ),
        chain_type=(
            np.pad(self.chain_type, pad_width, constant_values=pad_val)
            if self.chain_type is not None
            else None
        ),
    )

  def to_array(self) -> np.ndarray:
    """Stacks the fields to a numpy array with shape (6, *self.shape).

    Creates a pure numpy array of type `object` by stacking the 6 fields of the
    AtomLayout in dataclass field order, i.e. (atom_name, res_id, chain_id,
    atom_element, res_name, chain_type). This method together with
    from_array() provides an easy way to apply pure numpy methods like
    np.concatenate() to `AtomLayout`s.

    Returns:
      np.ndarray of object with shape (6, *self.shape), e.g.
      array([['N', 'CA', 'C', ..., 'CB', 'CG', 'CD'],
             [1, 1, 1, ..., 403, 403, 403],
             ['A', 'A', 'A', ..., 'D', 'D', 'D'],
             ['N', 'C', 'C', ..., 'C', 'C', 'C'],
             ['LEU', 'LEU', 'LEU', ..., 'PRO', 'PRO', 'PRO'],
             ['polypeptide(L)', 'polypeptide(L)', ..., 'polypeptide(L)']],
            dtype=object)

    Raises:
      ValueError: if any of the optional fields is None.
    """
    if (
        self.atom_element is None
        or self.res_name is None
        or self.chain_type is None
    ):
      raise ValueError('All optional fields need to be present.')
    # Collect the fields directly instead of via dataclasses.astuple(), which
    # deep-copies every array before np.stack copies it again. The order is the
    # dataclass field order, which from_array() relies on via cls(*arr).
    return np.stack(
        [getattr(self, field.name) for field in dataclasses.fields(self)],
        axis=0,
    )

  @classmethod
  def from_array(cls, arr: np.ndarray) -> 'AtomLayout':
    """Creates an AtomLayout object from a numpy array with shape (6, ...).

    Inverse of to_array(): the rows are taken in dataclass field order
    (atom_name, res_id, chain_id, atom_element, res_name, chain_type).

    Args:
      arr: np.ndarray of object with shape (6, ...)

    Returns:
      AtomLayout object with shape arr.shape[1:]

    Raises:
      ValueError: if arr.shape[0] != 6.
    """
    if arr.shape[0] != 6:
      raise ValueError(
          'Given array must have shape (6, ...) to match the 6 fields of '
          'AtomLayout (atom_name, res_id, chain_id, atom_element, res_name, '
          f'chain_type). Your array has {arr.shape=}'
      )
    return cls(*arr)

  @property
  def shape(self) -> tuple[int, ...]:
    return self.atom_name.shape
@dataclasses.dataclass(frozen=True)
class Residues:
  """List of residues with metadata.

  Attributes:
    res_name: np.ndarray of str [num_res], e.g. 'ARG', 'TRP'
    res_id: np.ndarray of int [num_res]
    chain_id: np.ndarray of str [num_res], e.g. 'A', 'B'
    chain_type: np.ndarray of str [num_res], e.g. 'polypeptide(L)'
    is_start_terminus: np.ndarray of bool [num_res]
    is_end_terminus: np.ndarray of bool [num_res]
    deprotonation: (optional) np.ndarray of set() [num_res], e.g. {'HD1', 'HE2'}
    smiles_string: (optional) np.ndarray of str [num_res], e.g. 'Cc1ccccc1'
    shape: shape of the layout (just returns res_name.shape)
  """

  res_name: np.ndarray
  res_id: np.ndarray
  chain_id: np.ndarray
  chain_type: np.ndarray
  is_start_terminus: np.ndarray
  is_end_terminus: np.ndarray
  deprotonation: np.ndarray | None = None
  smiles_string: np.ndarray | None = None

  def __post_init__(self):
    """Assert all non-None arrays are 1-dim with shape (num_res,).

    Raises:
      ValueError: if any non-None field's shape is not (len(res_name),).
    """
    attribute_names = (
        'res_name',
        'res_id',
        'chain_id',
        'chain_type',
        'is_start_terminus',
        'is_end_terminus',
        'deprotonation',
        'smiles_string',
    )
    _assert_all_arrays_have_same_shape(
        obj=self,
        expected_shape=(self.res_name.shape[0],),
        attribute_names=attribute_names,
    )

  def __getitem__(self, key: NumpyIndex) -> 'Residues':
    """Indexes every present field with the same numpy-style key."""
    return Residues(
        res_name=self.res_name[key],
        res_id=self.res_id[key],
        chain_id=self.chain_id[key],
        chain_type=self.chain_type[key],
        is_start_terminus=self.is_start_terminus[key],
        is_end_terminus=self.is_end_terminus[key],
        deprotonation=(
            self.deprotonation[key] if self.deprotonation is not None else None
        ),
        smiles_string=(
            self.smiles_string[key] if self.smiles_string is not None else None
        ),
    )

  def __eq__(self, other: object) -> bool:
    """True iff every field (including optional ones) is array-equal."""
    if not isinstance(other, Residues):
      return NotImplemented
    return all(
        np.array_equal(getattr(self, field.name), getattr(other, field.name))
        for field in dataclasses.fields(self)
    )

  @property
  def shape(self) -> tuple[int, ...]:
    return self.res_name.shape
@dataclasses.dataclass(frozen=True)
class GatherInfo:
  """Gather indices that translate an array from one atom layout to another.

  `gather_idxs` and `gather_mask` are np or jnp ndarrays (usually 1-dim or
  2-dim) sharing the output shape, e.g.
  - [num_atoms]
  - [num_residues, max_atoms_per_residue]
  - [num_fragments, max_fragments_per_residue]
  `input_shape` instead holds the shape of the unflattened source array and is
  not required to match.

  Attributes:
    gather_idxs: np or jnp ndarray of int: gather indices into a flattened array
    gather_mask: np or jnp ndarray of bool: mask for resulting array
    input_shape: np or jnp ndarray of int: the shape of the unflattened input
      array
    shape: output shape. Just returns gather_idxs.shape
  """

  gather_idxs: xnp_ndarray
  gather_mask: xnp_ndarray
  input_shape: xnp_ndarray

  def __post_init__(self):
    # Only the two output-shaped members have to agree.
    mask_shape = self.gather_mask.shape
    idxs_shape = self.gather_idxs.shape
    if mask_shape == idxs_shape:
      return
    raise ValueError(
        'All arrays must have the same shape. Got\n'
        f'gather_idxs.shape = {self.gather_idxs.shape}\n'
        f'gather_mask.shape = {self.gather_mask.shape}\n'
    )

  def __getitem__(self, key: NumpyIndex) -> 'GatherInfo':
    # Indexing selects output positions; the source shape stays the same.
    sliced_idxs = self.gather_idxs[key]
    sliced_mask = self.gather_mask[key]
    return GatherInfo(
        gather_idxs=sliced_idxs,
        gather_mask=sliced_mask,
        input_shape=self.input_shape,
    )

  @property
  def shape(self) -> tuple[int, ...]:
    return self.gather_idxs.shape

  def as_np_or_jnp(self, xnp: types.ModuleType) -> 'GatherInfo':
    """Returns a copy whose members are converted with `xnp.array`."""
    converted = {
        field.name: xnp.array(getattr(self, field.name))
        for field in dataclasses.fields(self)
    }
    return GatherInfo(**converted)

  def as_dict(
      self,
      key_prefix: str | None = None,
  ) -> dict[str, xnp_ndarray]:
    """Returns the members keyed by field name, optionally as 'prefix:name'."""
    prefix = f'{key_prefix}:' if key_prefix else ''
    return {
        prefix + field.name: getattr(self, field.name)
        for field in dataclasses.fields(self)
    }

  @classmethod
  def from_dict(
      cls,
      d: Mapping[str, xnp_ndarray],
      key_prefix: str | None = None,
  ) -> 'GatherInfo':
    """Creates GatherInfo from a dictionary laid out like `as_dict()` output."""
    prefix = f'{key_prefix}:' if key_prefix else ''
    members = {
        field.name: d[prefix + field.name] for field in dataclasses.fields(cls)
    }
    return cls(**members)
def fill_in_optional_fields(
    minimal_atom_layout: AtomLayout,
    reference_atoms: AtomLayout,
) -> AtomLayout:
  """Fill in the optional fields (atom_element, res_name, chain_type).

  Extracts the optional fields (atom_element, res_name, chain_type) from a
  flat reference layout and fills them into the fields from this layout.
  Atoms are matched by (chain_id, res_id, atom_name); positions without a match
  (e.g. padding) receive ''.

  Args:
    minimal_atom_layout: An AtomLayout that only contains the essential fields
      (atom_name, res_id, chain_id).
    reference_atoms: A flat layout that contains all fields for all atoms.

  Returns:
    An AtomLayout that contains all fields.

  Raises:
    ValueError: Reference atoms layout is not flat.
    ValueError: Reference atoms layout lacks an optional field.
    ValueError: Missing atoms in reference.
  """
  if len(reference_atoms.shape) > 1:
    raise ValueError('Only flat layouts are supported as reference.')
  if (
      reference_atoms.atom_element is None
      or reference_atoms.res_name is None
      or reference_atoms.chain_type is None
  ):
    raise ValueError(
        'All optional fields (atom_element, res_name, chain_type) must be '
        'present in the reference layout.'
    )
  ref_to_self = compute_gather_idxs(
      source_layout=reference_atoms, target_layout=minimal_atom_layout
  )
  atom_mask = minimal_atom_layout.atom_name.astype(bool)
  missing_atoms_mask = atom_mask & ~ref_to_self.gather_mask
  if np.any(missing_atoms_mask):
    raise ValueError(
        f'{np.sum(missing_atoms_mask)} missing atoms in reference: '
        f'{minimal_atom_layout[missing_atoms_mask]}'
    )

  def _convert_str_array(gather: GatherInfo, arr: np.ndarray):
    # Start from all-'' and gather only the matched positions. Unmatched
    # entries carry just the fill index, so they must never index `arr`
    # (otherwise an empty reference fails for an all-padding target).
    output = np.full(gather.shape, '', dtype=arr.dtype)
    output[gather.gather_mask] = arr[gather.gather_idxs[gather.gather_mask]]
    return output

  return dataclasses.replace(
      minimal_atom_layout,
      atom_element=_convert_str_array(
          ref_to_self, reference_atoms.atom_element
      ),
      res_name=_convert_str_array(ref_to_self, reference_atoms.res_name),
      chain_type=_convert_str_array(ref_to_self, reference_atoms.chain_type),
  )
def guess_deprotonation(residues: Residues) -> Residues:
  """Convenience function to create a plausible deprotonation field.

  Assumes a pH of 7. ASP loses HD2, GLU loses HE2 and HIS loses HD1 (i.e. HE2
  is always kept for HIS); every end-terminus residue additionally loses HXT.

  Args:
    residues: a Residues object without a deprotonation field

  Returns:
    a copy of `residues` whose deprotonation field holds one set of removed
    hydrogen names per residue
  """
  removed_hydrogen_at_ph7 = {
      'ASP': 'HD2',
      'GLU': 'HE2',
      'HIS': 'HD1',
  }
  count = residues.res_name.shape[0]
  deprotonation = np.empty(count, dtype=object)
  for i in range(count):
    removed = set()
    name = residues.res_name[i]
    if name in removed_hydrogen_at_ph7:
      removed.add(removed_hydrogen_at_ph7[name])
    if residues.is_end_terminus[i]:
      removed.add('HXT')
    deprotonation[i] = removed
  return dataclasses.replace(residues, deprotonation=deprotonation)
def atom_layout_from_structure(
    struct: structure.Structure,
    *,
    fix_non_standard_polymer_res: bool = False,
) -> AtomLayout:
  """Extract AtomLayout from a Structure.

  Args:
    struct: the structure whose atoms are copied.
    fix_non_standard_polymer_res: if True, residue names of atoms in standard
      polymer chains are mapped through
      `mmcif_names.fix_non_standard_polymer_res`; all other fields are copied
      unchanged.

  Returns:
    A flat AtomLayout with one entry per atom of `struct`.
  """
  if fix_non_standard_polymer_res:
    atom_names = []
    atom_elements = []
    res_ids = []
    res_names = []
    chain_ids = []
    chain_types = []
    for atom in struct.iter_atoms():
      atom_names.append(atom['atom_name'])
      atom_elements.append(atom['atom_element'])
      res_ids.append(atom['res_id'])
      chain_ids.append(atom['chain_id'])
      this_chain_type = atom['chain_type']
      chain_types.append(this_chain_type)
      # Only residues in standard polymer chains get their names fixed.
      this_res_name = atom['res_name']
      if mmcif_names.is_standard_polymer_type(this_chain_type):
        this_res_name = mmcif_names.fix_non_standard_polymer_res(
            res_name=this_res_name, chain_type=this_chain_type
        )
      res_names.append(this_res_name)
    return AtomLayout(
        atom_name=np.array(atom_names, dtype=object),
        atom_element=np.array(atom_elements, dtype=object),
        res_name=np.array(res_names, dtype=object),
        res_id=np.array(res_ids, dtype=int),
        chain_id=np.array(chain_ids, dtype=object),
        chain_type=np.array(chain_types, dtype=object),
    )

  # No renaming requested: copy the structure's per-atom arrays directly.
  return AtomLayout(
      atom_name=np.array(struct.atom_name, dtype=object),
      atom_element=np.array(struct.atom_element, dtype=object),
      res_name=np.array(struct.res_name, dtype=object),
      res_id=np.array(struct.res_id, dtype=int),
      chain_id=np.array(struct.chain_id, dtype=object),
      chain_type=np.array(struct.chain_type, dtype=object),
  )
def residues_from_structure(
    struct: structure.Structure,
    *,
    include_missing_residues: bool = True,
    fix_non_standard_polymer_res: bool = False,
) -> Residues:
  """Create a Residues object from a Structure object.

  Args:
    struct: the structure to extract residues from.
    include_missing_residues: if True, residue ids come from
      `struct.all_residues` and the last residue of every polymer chain is
      marked as end terminus; otherwise residue ids come from
      `struct.iter_residues()`.
    fix_non_standard_polymer_res: forwarded to
      `struct.chain_res_name_sequence()`.

  Returns:
    A Residues object. `deprotonation` is only inferred (from missing
    hydrogens) when the structure has hydrogens in protein chains; otherwise it
    is None.
  """

  def _get_smiles(res_name):
    """Get SMILES string from chemical components, or None if unavailable."""
    chem_comps_data = struct.chemical_components_data
    if (
        chem_comps_data is not None
        and chem_comps_data.chem_comp is not None
        and chem_comps_data.chem_comp.get(res_name)
    ):
      return chem_comps_data.chem_comp[res_name].pdbx_smiles
    return None

  res_names_per_chain = struct.chain_res_name_sequence(
      include_missing_residues=include_missing_residues,
      fix_non_standard_polymer_res=fix_non_standard_polymer_res,
  )

  # Without missing residues, group the resolved residue ids per chain in a
  # single pass instead of re-scanning every residue once per chain.
  resolved_res_ids_per_chain = collections.defaultdict(list)
  if not include_missing_residues:
    for r in struct.iter_residues():
      resolved_res_ids_per_chain[r['chain_id']].append(r['res_id'])

  res_name = []
  res_id = []
  chain_id = []
  chain_type = []
  smiles = []
  is_start_terminus = []
  for c in struct.iter_chains():
    if include_missing_residues:
      this_res_ids = [rid for (_, rid) in struct.all_residues[c['chain_id']]]
    else:
      this_res_ids = resolved_res_ids_per_chain[c['chain_id']]
    fixed_res_names = res_names_per_chain[c['chain_id']]
    assert len(this_res_ids) == len(
        fixed_res_names
    ), f'{len(this_res_ids)} != {len(fixed_res_names)}'
    # Candidate start residue: the chain's lowest res_id if it is <= 1; if all
    # ids are > 1 no residue matches. `default` guards chains without residues.
    this_start_res_id = min(min(this_res_ids, default=1), 1)
    this_is_start_terminus = [r == this_start_res_id for r in this_res_ids]
    smiles.extend([_get_smiles(name) for name in fixed_res_names])
    num_res = len(fixed_res_names)
    res_name.extend(fixed_res_names)
    res_id.extend(this_res_ids)
    chain_id.extend([c['chain_id']] * num_res)
    chain_type.extend([c['chain_type']] * num_res)
    is_start_terminus.extend(this_is_start_terminus)

  res_name = np.array(res_name, dtype=object)
  res_id = np.array(res_id, dtype=int)
  chain_id = np.array(chain_id, dtype=object)
  chain_type = np.array(chain_type, dtype=object)
  smiles = np.array(smiles, dtype=object)
  is_start_terminus = np.array(is_start_terminus, dtype=bool)

  res_uid_to_idx = {
      uid: idx for idx, uid in enumerate(zip(chain_id, res_id, strict=True))
  }

  # Start terminus from numbering only applies to polymer chains.
  is_polymer = np.isin(chain_type, tuple(mmcif_names.POLYMER_CHAIN_TYPES))
  is_start_terminus = is_start_terminus & is_polymer

  # A protein residue is also a start terminus if it carries H2, or H for PRO.
  start_terminus_atom_index = np.nonzero(
      (struct.chain_type == mmcif_names.PROTEIN_CHAIN)
      & (
          (struct.atom_name == 'H2')
          | ((struct.atom_name == 'H') & (struct.res_name == 'PRO'))
      )
  )[0]

  # Translate atom idx to residue idx to assign start terminus.
  for atom_idx in start_terminus_atom_index:
    res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx])
    is_start_terminus[res_uid_to_idx[res_uid]] = True

  # Infer end terminus: Check for OXT, or in case of
  # include_missing_residues==True for the last residue of the chain.
  num_all_residues = res_name.shape[0]
  is_end_terminus = np.zeros(num_all_residues, dtype=bool)

  end_term_atom_idxs = np.nonzero(struct.atom_name == 'OXT')[0]
  for atom_idx in end_term_atom_idxs:
    res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx])
    is_end_terminus[res_uid_to_idx[res_uid]] = True

  if include_missing_residues and num_all_residues > 0:
    # A residue is the last of its chain if the next residue belongs to a
    # different chain, or if it is the very last residue.
    is_last_in_chain = np.append(chain_id[:-1] != chain_id[1:], True)
    is_end_terminus |= is_last_in_chain & is_polymer

  # Infer (de-)protonation: Only if hydrogens are given.
  num_hydrogens = np.sum(
      (struct.atom_element == 'H')
      & (struct.chain_type == mmcif_names.PROTEIN_CHAIN)
  )
  if num_hydrogens > 0:
    deprotonation = np.empty(num_all_residues, dtype=object)
    all_atom_uids = set(
        zip(struct.chain_id, struct.res_id, struct.atom_name, strict=True)
    )
    for idx in range(num_all_residues):
      deprotonation[idx] = set()
      check_hydrogens = set()
      if is_end_terminus[idx]:
        check_hydrogens.add('HXT')
      if res_name[idx] in atom_types.PROTONATION_HYDROGENS:
        check_hydrogens.update(atom_types.PROTONATION_HYDROGENS[res_name[idx]])
      # A checked hydrogen that is absent from the structure is recorded as
      # removed.
      for hydrogen in check_hydrogens:
        if (chain_id[idx], res_id[idx], hydrogen) not in all_atom_uids:
          deprotonation[idx].add(hydrogen)
  else:
    deprotonation = None

  return Residues(
      res_name=res_name,
      res_id=res_id,
      chain_id=chain_id,
      chain_type=chain_type,
      is_start_terminus=is_start_terminus,
      is_end_terminus=is_end_terminus,
      deprotonation=deprotonation,
      smiles_string=smiles,
  )
def get_link_drop_atoms(
    res_name: str,
    chain_type: str,
    *,
    is_start_terminus: bool,
    is_end_terminus: bool,
    bonded_atoms: set[str],
    drop_ligand_leaving_atoms: bool = False,
) -> set[str]:
  """Returns the set of atoms removed when this residue is linked.

  - Protein: non-start residues drop 'H2', 'H3' (PRO also drops 'H');
    non-end residues drop 'OXT', 'HXT'.
  - Nucleic acids: non-start residues drop 'OP3'.
  - Ligands (only if drop_ligand_leaving_atoms): glycans drop 'O1' unless
    'O1' is itself one of the bonded atoms.

  Args:
    res_name: residue name, e.g. 'ARG'
    chain_type: chain_type, e.g. 'polypeptide(L)'
    is_start_terminus: whether the residue is the n-terminus
    is_end_terminus: whether the residue is the c-terminus
    bonded_atoms: Names of atoms coming off this residue.
    drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands.

  Returns:
    A new set of atom names to drop.
  """
  if chain_type == mmcif_names.PROTEIN_CHAIN:
    dropped = set()
    if not is_start_terminus:
      dropped.update({'H2', 'H3'})
      if res_name == 'PRO':
        dropped.add('H')
    if not is_end_terminus:
      dropped.update({'OXT', 'HXT'})
    return dropped

  if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES:
    return set() if is_start_terminus else {'OP3'}

  ligand_rules_apply = (
      drop_ligand_leaving_atoms
      and chain_type in mmcif_names.LIGAND_CHAIN_TYPES
  )
  if not ligand_rules_apply:
    return set()

  glycan_names = {
      *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
      *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
  }
  if res_name in glycan_names and 'O1' not in bonded_atoms:
    return {'O1'}
  return set()
def _bonded_atom_names_of_residue(
    bonds: AtomLayout,
    res_id: int,
    chain_id: str,
) -> set[str]:
  """Returns this residue's atom names over all bonds in `bonds` touching it.

  Assumes `bonds` has shape (num_bonds, 2), one column per bond partner, as
  implied by the `.any(axis=1)` and `[0]`/`[1]` indexing below.
  """
  # Filter before the Python-level loop to speed this up.
  touches_residue = np.logical_and(
      bonds.res_id == res_id,
      bonds.chain_id == chain_id,
  ).any(axis=1)
  relevant_bonds = bonds[touches_residue]
  names = set()
  for atom_names, res_ids, chain_ids in zip(
      relevant_bonds.atom_name,
      relevant_bonds.res_id,
      relevant_bonds.chain_id,
  ):
    if (res_ids[0], chain_ids[0]) == (res_id, chain_id):
      names.add(atom_names[0])
    elif (res_ids[1], chain_ids[1]) == (res_id, chain_id):
      names.add(atom_names[1])
  return names


def get_bonded_atoms(
    polymer_ligand_bonds: AtomLayout | None,
    ligand_ligand_bonds: AtomLayout | None,
    res_id: int,
    chain_id: str,
) -> set[str]:
  """Finds the atoms of a residue that take part in the given bonds.

  Args:
    polymer_ligand_bonds: Bond information for polymer-ligand pairs, or None.
    ligand_ligand_bonds: Bond information for ligand-ligand pairs, or None.
    res_id: residue id in question.
    chain_id: chain id of residue in question.

  Returns:
    Names of the atoms of residue (res_id, chain_id) that appear as a partner
    in any of the given bonds.
  """
  bonded_atoms = set()
  for bonds in (polymer_ligand_bonds, ligand_ligand_bonds):
    # AtomLayout defines neither __bool__ nor __len__, so a truthiness test
    # would only ever filter out None; say so explicitly.
    if bonds is not None:
      bonded_atoms |= _bonded_atom_names_of_residue(bonds, res_id, chain_id)
  return bonded_atoms
def _atom_names_and_elements_for_residue(
    res_name: str,
    smiles_string: str | None,
    ccd: chemical_components.Ccd,
) -> list[tuple[str, str]]:
  """Returns (atom_name, element) pairs from the CCD, falling back to SMILES.

  Raises:
    ValueError: if the SMILES string cannot be parsed by RDKit, or if the
      residue is neither in the CCD nor has a SMILES string.
  """
  if ccd.get(res_name):
    res_atoms = struc_chem_comps.get_all_atoms_in_entry(
        ccd=ccd, res_name=res_name
    )
    return list(
        zip(
            res_atoms['_chem_comp_atom.atom_id'],
            res_atoms['_chem_comp_atom.type_symbol'],
            strict=True,
        )
    )
  if smiles_string:
    # Get atoms from RDKit via SMILES.
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
      raise ValueError(
          f'Failed to construct RDKit Mol for {res_name} from'
          f' SMILES string: {smiles_string} . This is likely'
          ' due to an issue with the SMILES string. Note that the userCCD'
          ' input format provides an alternative way to define custom'
          ' molecules directly without RDKit or SMILES.'
      )
    mol = rdkit_utils.assign_atom_names_from_graph(mol)
    return [(a.GetProp('atom_name'), a.GetSymbol()) for a in mol.GetAtoms()]
  raise ValueError(f'{res_name} not found in CCD and no SMILES string')


def make_flat_atom_layout(
    residues: Residues,
    ccd: chemical_components.Ccd,
    polymer_ligand_bonds: AtomLayout | None = None,
    ligand_ligand_bonds: AtomLayout | None = None,
    *,
    with_hydrogens: bool = False,
    skip_unk_residues: bool = True,
    drop_ligand_leaving_atoms: bool = False,
) -> AtomLayout:
  """Make a flat atom layout for given residues.

  Create a flat layout from a `Residues` object. The required atoms for each
  amino acid type are taken from the CCD (or, for residues not in the CCD,
  from the residue's SMILES string); hydrogens and oxygens are dropped to
  make the linked residues. Terminal OXT's and protonation state for the
  hydrogens come from the `Residues` object.

  Args:
    residues: a `Residues` object.
    ccd: The chemical components dictionary.
    polymer_ligand_bonds: Bond information for polymer-ligand pairs.
    ligand_ligand_bonds: Bond information for ligand-ligand pairs.
    with_hydrogens: whether to create hydrogens
    skip_unk_residues: whether to skip 'UNK' resides -- default is True to be
      compatible with the rest of AlphaFold that does not predict atoms for
      unknown residues
    drop_ligand_leaving_atoms: Flag to switch on/ off leaving atoms for ligands.

  Returns:
    an `AtomLayout` object

  Raises:
    ValueError: if a residue is not in the CCD and has no (valid) SMILES
      string. This also applies when `residues.smiles_string` is None.
  """
  num_res = residues.res_name.shape[0]

  # Target lists.
  target_atom_names = []
  target_atom_elements = []
  target_res_ids = []
  target_res_names = []
  target_chain_ids = []
  target_chain_types = []

  for idx in range(num_res):
    res_name = residues.res_name[idx]
    # skip 'UNK' residues if requested
    if skip_unk_residues and res_name in residue_names.UNKNOWN_TYPES:
      continue

    # smiles_string is an optional field of Residues and may be None.
    smiles_string = (
        residues.smiles_string[idx]
        if residues.smiles_string is not None
        else None
    )
    atom_names_elements = _atom_names_and_elements_for_residue(
        res_name, smiles_string, ccd
    )

    # Remove hydrogens (and deuterium) if requested.
    if not with_hydrogens:
      atom_names_elements = [
          (n, e) for n, e in atom_names_elements if e not in ('H', 'D')
      ]

    bonded_atoms = get_bonded_atoms(
        polymer_ligand_bonds,
        ligand_ligand_bonds,
        residues.res_id[idx],
        residues.chain_id[idx],
    )
    # Connect the amino-acids, i.e. remove OXT, HXT and H2.
    drop_atoms = get_link_drop_atoms(
        res_name=res_name,
        chain_type=residues.chain_type[idx],
        is_start_terminus=residues.is_start_terminus[idx],
        is_end_terminus=residues.is_end_terminus[idx],
        bonded_atoms=bonded_atoms,
        drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
    )

    # If deprotonation info is available, remove the specific atoms.
    if residues.deprotonation is not None:
      drop_atoms.update(residues.deprotonation[idx])

    atom_names_elements = [
        (n, e) for n, e in atom_names_elements if n not in drop_atoms
    ]

    # Append the found atoms to the target lists.
    target_atom_names.extend([n for n, _ in atom_names_elements])
    target_atom_elements.extend([e for _, e in atom_names_elements])
    num_atoms = len(atom_names_elements)
    target_res_names.extend([res_name] * num_atoms)
    target_res_ids.extend([residues.res_id[idx]] * num_atoms)
    target_chain_ids.extend([residues.chain_id[idx]] * num_atoms)
    target_chain_types.extend([residues.chain_type[idx]] * num_atoms)

  return AtomLayout(
      atom_name=np.array(target_atom_names, dtype=object),
      atom_element=np.array(target_atom_elements, dtype=object),
      res_name=np.array(target_res_names, dtype=object),
      res_id=np.array(target_res_ids, dtype=int),
      chain_id=np.array(target_chain_ids, dtype=object),
      chain_type=np.array(target_chain_types, dtype=object),
  )
def compute_gather_idxs(
    *,
    source_layout: AtomLayout,
    target_layout: AtomLayout,
    fill_value: int = 0,
) -> GatherInfo:
  """Produce gather indices and mask to convert from source layout to target.

  Atoms are matched by their (chain_id, res_id, atom_name) triple. Indices
  refer to the flattened source layout; if a triple occurs more than once in
  the source, its last occurrence is used. Target entries without a match get
  index `fill_value` and mask False.
  """

  def _atom_uids(layout):
    return zip(
        layout.chain_id.ravel(),
        layout.res_id.ravel(),
        layout.atom_name.ravel(),
        strict=True,
    )

  flat_idx_by_uid = {uid: i for i, uid in enumerate(_atom_uids(source_layout))}
  matches = [flat_idx_by_uid.get(uid) for uid in _atom_uids(target_layout)]

  out_shape = target_layout.atom_name.shape
  idxs = np.array(
      [fill_value if m is None else m for m in matches], dtype=int
  ).reshape(out_shape)
  mask = np.array([m is not None for m in matches], dtype=bool).reshape(
      out_shape
  )
  return GatherInfo(
      gather_idxs=idxs,
      gather_mask=mask,
      input_shape=np.array(source_layout.atom_name.shape),
  )
def convert(
    gather_info: 'GatherInfo',
    arr: 'xnp_ndarray',
    *,
    layout_axes: tuple[int, ...] = (0,),
) -> 'xnp_ndarray':
  """Convert an array from one atom layout to another.

  Args:
    gather_info: gather indices and mask (e.g. from `compute_gather_idxs`)
      mapping the source layout to the target layout.
    arr: np or jnp array whose axes `layout_axes` hold the source layout.
    layout_axes: the continuous axes of `arr` that hold the source layout;
      negative values count from the end. Any number of leading (batch) axes
      is supported.

  Returns:
    An array of shape
    arr.shape[:layout_axes[0]] + gather_info.shape
    + arr.shape[layout_axes[-1] + 1:], where entries with gather_mask False are
    multiplied by False (i.e. zeroed).

  Raises:
    ValueError: if layout_axes is empty or not continuous, or if the layout
      axes of `arr` are incompatible with gather_info.input_shape.
  """
  if not layout_axes:
    raise ValueError('layout_axes must not be empty.')
  # Translate negative indices to the corresponding positives.
  layout_axes = tuple(i if i >= 0 else i + arr.ndim for i in layout_axes)

  # Ensure that layout_axes are continuous.
  layout_axes_begin = layout_axes[0]
  layout_axes_end = layout_axes[-1] + 1

  if layout_axes != tuple(range(layout_axes_begin, layout_axes_end)):
    raise ValueError(f'layout_axes must be continuous. Got {layout_axes}.')
  layout_shape = arr.shape[layout_axes_begin:layout_axes_end]

  # Ensure that the layout shape is compatible
  # with the gather_info. I.e. the first axis size must be equal or greater
  # than the gather_info.input_shape, and all subsequent axes sizes must match.
  if (len(layout_shape) != gather_info.input_shape.size) or (
      isinstance(gather_info.input_shape, np.ndarray)
      and (
          (layout_shape[0] < gather_info.input_shape[0])
          or (np.any(layout_shape[1:] != gather_info.input_shape[1:]))
      )
  ):
    raise ValueError(
        'Input array layout axes are incompatible. You specified layout '
        f'axes {layout_axes} with an input array of shape {arr.shape}, but '
        f'the gather info expects shape {gather_info.input_shape}. '
        'Your first axis size must be equal or greater than the '
        'gather_info.input_shape, and all subsequent axes sizes must '
        'match.'
    )

  # Compute the shape of the input array with flattened layout.
  batch_shape = arr.shape[:layout_axes_begin]
  features_shape = arr.shape[layout_axes_end:]
  arr_flattened_shape = batch_shape + (np.prod(layout_shape),) + features_shape

  # Flatten input array and perform the gather. One full slice per batch axis
  # followed by the indices is equivalent to arr_flattened[:, ..., :, idxs, ...]
  # for any number of batch axes.
  arr_flattened = arr.reshape(arr_flattened_shape)
  gather_index = (slice(None),) * layout_axes_begin + (
      gather_info.gather_idxs,
      Ellipsis,
  )
  out_arr = arr_flattened[gather_index]

  # Broadcast the mask and apply it.
  broadcasted_mask_shape = (
      (1,) * len(batch_shape)
      + gather_info.gather_mask.shape
      + (1,) * len(features_shape)
  )
  out_arr *= gather_info.gather_mask.reshape(broadcasted_mask_shape)
  return out_arr
def make_structure(
    flat_layout: AtomLayout,
    atom_coords: np.ndarray,
    name: str,
    *,
    atom_b_factors: np.ndarray | None = None,
    all_physical_residues: Residues | None = None,
) -> structure.Structure:
  """Returns a Structure from a flat layout and atom coordinates.

  The provided flat_layout must be 1-dim and must not contain any padding
  elements. The flat_layout.atom_name must conform to the OpenMM/CCD standard
  and must not contain deuterium.

  Args:
    flat_layout: flat 1-dim AtomLayout without padding elements
    atom_coords: np.ndarray of float, shape (num_atoms, 3)
    name: str: the name (usually PDB id), e.g. '1uao'
    atom_b_factors: np.ndarray of float, shape (num_atoms,) or None. If None,
      they will be set to all zeros.
    all_physical_residues: a Residues object that contains all physically
      existing residues, i.e. also those residues that have no resolved atoms.
      This is common in experimental structures, but also appears in predicted
      structures for 'UNK' or other non-standard residue types, where the model
      does not predict coordinates. This will be used to create the
      `all_residues` field of the structure object.

  Returns:
    The Structure created by `structure.from_atom_arrays` from the layout
    fields, the x/y/z columns of atom_coords and the B-factors.

  Raises:
    ValueError: if flat_layout is not 1-dim, contains padding elements, or is
      missing any of the optional fields.
  """
  if flat_layout.atom_name.ndim != 1 or not np.all(
      flat_layout.atom_name.astype(bool)
  ):
    raise ValueError(
        'flat_layout must be 1-dim and must not contain any padding element'
    )
  if (
      flat_layout.atom_element is None
      or flat_layout.res_name is None
      or flat_layout.chain_type is None
  ):
    raise ValueError('All optional fields must be present.')

  if atom_b_factors is None:
    atom_b_factors = np.zeros(atom_coords.shape[:-1])

  all_residues = collections.defaultdict(list)
  if all_physical_residues is not None:
    # Create the all_residues field from a Residues object
    # (unfortunately there is no central place to keep the chain_types in
    # the structure class, so we drop it here)
    for chain_id, res_id, res_name in zip(
        all_physical_residues.chain_id,
        all_physical_residues.res_id,
        all_physical_residues.res_name,
        strict=True,
    ):
      all_residues[chain_id].append((res_name, res_id))
  else:
    # Create the all_residues field from the flat_layout: a new residue starts
    # whenever chain_id, res_name or res_id differs from the previous atom.
    previous_residue = None
    for residue in zip(
        flat_layout.chain_id,
        flat_layout.res_name,
        flat_layout.res_id,
        strict=True,
    ):
      if residue != previous_residue:
        chain_id, res_name, res_id = residue
        all_residues[chain_id].append((res_name, res_id))
      previous_residue = residue

  return structure.from_atom_arrays(
      name=name,
      all_residues=dict(all_residues),
      chain_id=flat_layout.chain_id,
      chain_type=flat_layout.chain_type,
      res_id=flat_layout.res_id.astype(np.int32),
      res_name=flat_layout.res_name,
      atom_name=flat_layout.atom_name,
      atom_element=flat_layout.atom_element,
      atom_x=atom_coords[..., 0],
      atom_y=atom_coords[..., 1],
      atom_z=atom_coords[..., 2],
      atom_b_factor=atom_b_factors,
  )
def _assert_all_arrays_have_same_shape(
    *,
    obj: AtomLayout | Residues | GatherInfo,
    expected_shape: tuple[int, ...],
    attribute_names: Sequence[str],
) -> None:
  """Checks that given attributes of the object have the expected shape.

  Attributes whose value is None are not checked, but they still appear in the
  error report.

  Args:
    obj: Object whose attributes are inspected via `getattr`.
    expected_shape: Shape every non-None attribute must have.
    attribute_names: Names of the attributes to check.

  Raises:
    ValueError: If any non-None attribute's `.shape` differs from
      `expected_shape`. The message lists the shape of every checked attribute.
  """
  report_lines = []
  mismatch_found = False
  for name_of_attr in attribute_names:
    value = getattr(obj, name_of_attr)
    shape = None if value is None else value.shape
    if shape is not None and expected_shape != shape:
      mismatch_found = True
    label = name_of_attr + '.shape'
    report_lines.append(f'{label:25} = {shape}')

  if mismatch_found:
    raise ValueError(
        f'All arrays must have the same shape ({expected_shape=}). Got\n'
        + '\n'.join(report_lines)
    )
================================================
FILE: src/alphafold3/model/components/haiku_modules.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Common Haiku modules."""
from collections.abc import Sequence
import contextlib
import numbers
from typing import TypeAlias
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# Type of the `precision` argument that `Linear` forwards to `jnp.einsum`:
# None, a precision name, a `jax.lax.Precision`, or a pair of either
# (presumably one entry per einsum operand, as JAX allows).
PRECISION: TypeAlias = (
    None
    | str
    | jax.lax.Precision
    | tuple[str, str]
    | tuple[jax.lax.Precision, jax.lax.Precision]
)

# Precision used by `Linear` when its `precision` argument is None.
# Useful for mocking in tests.
DEFAULT_PRECISION = None

# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.), i.e.
# the stddev of a unit normal truncated to [-2, 2]. Initializers divide their
# target stddev by this factor to adjust for the truncation.
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(
    0.87962566103423978, dtype=np.float32
)
class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with an extra 'upcast' option that casts
  (b)float16 inputs to float32 before computing the layer norm, and then casts
  the output back to the input type.

  The learnable parameter shapes are also different from Haiku: they are always
  vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keeping the model weight-compatible.
  """

  def __init__(
      self,
      *,
      axis: int = -1,
      create_scale: bool = True,
      create_offset: bool = True,
      eps: float = 1e-5,
      scale_init: hk.initializers.Initializer | None = None,
      offset_init: hk.initializers.Initializer | None = None,
      use_fast_variance: bool = True,
      name: str,
      param_axis: int | None = None,
      upcast: bool = True,
  ):
    """Constructs the module.

    Args:
      axis: Axis (or axes) to normalise over.
      create_scale: Whether to create a learnable vector `scale` parameter.
      create_offset: Whether to create a learnable vector `offset` parameter.
      eps: Small constant added to the variance.
      scale_init: Initializer for `scale`. None uses the parent's default.
      offset_init: Initializer for `offset`. None uses the parent's default.
      use_fast_variance: Forwarded to hk.LayerNorm.
      name: Module name.
      param_axis: Axis along which the parameter vectors run (defaults to the
        last axis).
      upcast: Whether to compute in float32 for (b)float16 inputs.
    """
    # The parent is built with create_scale/create_offset=False and without
    # initialisers: this class creates the parameters itself, as vectors, in
    # __call__.
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis,
    )
    self.upcast = upcast
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset
    # Because None was forwarded above, the parent set self.scale_init and
    # self.offset_init to its defaults. __call__ reads these attributes, so
    # store any caller-supplied initialisers here instead of dropping them.
    if scale_init is not None:
      self.scale_init = scale_init
    if offset_init is not None:
      self.offset_init = offset_init

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies layer norm to `x` using vector-shaped scale/offset parameters.

    Args:
      x: Input array. If `upcast` is set and `x` is bfloat16/float16, the norm
        (and parameter creation) happens in float32.

    Returns:
      The normalised array. It is cast back to `x`'s original dtype when
      upcasting was applied.
    """
    dtype = x.dtype
    is_16bit = x.dtype in [jnp.bfloat16, jnp.float16]
    if self.upcast and is_16bit:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)

    # Parameters are stored as vectors and reshaped so they broadcast against
    # `x` along `param_axis` only.
    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init
      )
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init
      )
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if self.upcast and is_16bit:
      out = out.astype(dtype)

    return out
def haiku_linear_get_params(
    inputs: jax.Array | jax.ShapeDtypeStruct,
    *,
    num_output: int | Sequence[int],
    use_bias: bool = False,
    num_input_dims: int = 1,
    initializer: str = 'linear',
    bias_init: float = 0.0,
    transpose_weights: bool = False,
    name: str | None = None,
) -> tuple[jax.Array, jax.Array | None]:
  """Gets (creating if needed) the parameters of a linear layer.

  Parameters are requested with `inputs.dtype`. NOTE(review): an earlier
  docstring claimed parameters are always at least float32; that is not
  enforced here (presumably a custom getter handles it). Confirm.

  Args:
    inputs: The input to the Linear layer, either a JAX array or a
      jax.ShapeDtypeStruct. Only its shape and dtype are used.
    num_output: Number of output channels, as an int or a sequence of ints.
    use_bias: Whether to also create a bias array.
    num_input_dims: How many trailing input dims are treated as channel dims.
    initializer: Name of the weight initializer ('zeros', 'relu', or anything
      else for linear fan-in scaling).
    bias_init: Constant used to initialise the bias.
    transpose_weights: If True, weights have shape `output + input` instead of
      `input + output`.
    name: Optional Haiku name scope for the weight and bias.

  Returns:
    `(weights, bias)`. `bias` is None when `use_bias` is False.

  Raises:
    ValueError: If `num_input_dims` is negative.
  """
  output_shape = (
      (num_output,)
      if isinstance(num_output, numbers.Integral)
      else tuple(num_output)
  )

  if num_input_dims < 0:
    raise ValueError('num_input_dims must be >= 0.')
  in_shape = inputs.shape[-num_input_dims:] if num_input_dims > 0 else ()

  weight_init = _get_initializer_scale(initializer, in_shape)
  weight_shape = (
      output_shape + in_shape if transpose_weights else in_shape + output_shape
  )

  scope = hk.name_scope(name) if name else contextlib.nullcontext()
  with scope:
    weights = hk.get_parameter(
        'weights', shape=weight_shape, dtype=inputs.dtype, init=weight_init
    )
    bias = (
        hk.get_parameter(
            name='bias',
            shape=output_shape,
            dtype=inputs.dtype,
            init=hk.initializers.Constant(bias_init),
        )
        if use_bias
        else None
    )
  return weights, bias
class Linear(hk.Module):
  """Custom Linear Module.

  This differs from the standard Linear in a few ways:
    * It supports inputs of arbitrary rank
    * It allows to use ntk parametrization
    * Initializers are specified by strings
    * It allows to explicitly specify which dimension of the input will map to
      the tpu sublane/lane dimensions.
  """

  def __init__(
      self,
      num_output: int | Sequence[int],
      *,
      initializer: str = 'linear',
      num_input_dims: int = 1,
      use_bias: bool = False,
      bias_init: float = 0.0,
      precision: PRECISION = None,
      fast_scalar_mode: bool = True,
      transpose_weights: bool = False,
      name: str,
  ):
    """Constructs Linear Module.

    Args:
      num_output: number of output channels. Can be tuple when outputting
        multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}.
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias (False by default).
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults to
        None (which falls back to DEFAULT_PRECISION).
      fast_scalar_mode: Whether to use optimized path for num_input_dims = 0.
      transpose_weights: decides whether weights have shape [input, output] or
        [output, input], True means [output, input], this is helpful to avoid
        padding on the tensors holding the weights.
      name: name of module, used for name scopes.
    """
    super().__init__(name=name)
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision if precision is not None else DEFAULT_PRECISION
    self.fast_scalar_mode = fast_scalar_mode
    self.transpose_weights = transpose_weights

  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Connects Module.

    Args:
      inputs: Tensor of shape [..., num_channel]

    Returns:
      output of shape [..., num_output]

    Raises:
      ValueError: On the einsum path, if `num_input_dims` is negative or if
        there are more than 5 input or output dims. The einsum equation draws
        its axis labels from fixed 5-letter alphabets.
    """
    num_input_dims = self.num_input_dims

    # Adds specialized path for scalar inputs in Linear layer,
    # this means the linear Layer does not use the matmul units on the tpu,
    # which is more efficient and gives compiler more flexibility over layout.
    if num_input_dims == 0 and self.fast_scalar_mode:
      weight_shape = self.output_shape
      if self.initializer == 'zeros':
        w_init = hk.initializers.Constant(0.0)
      else:
        distribution_stddev = jnp.array(1 / TRUNCATED_NORMAL_STDDEV_FACTOR)
        w_init = hk.initializers.TruncatedNormal(
            mean=0.0, stddev=distribution_stddev
        )

      weights = hk.get_parameter('weights', weight_shape, inputs.dtype, w_init)

      # Append one trailing singleton axis per output dim so the elementwise
      # product broadcasts the scalar inputs against the weights.
      inputs = jnp.expand_dims(
          inputs, tuple(range(-1, -self.num_output_dims - 1, -1))
      )
      output = inputs * weights
    else:
      in_alphabet = 'abcde'
      out_alphabet = 'hijkl'
      # Reject configurations the label alphabets cannot express, instead of
      # letting slicing silently truncate them into a malformed equation.
      if num_input_dims < 0:
        raise ValueError('num_input_dims must be >= 0.')
      if num_input_dims > len(in_alphabet):
        raise ValueError(
            f'num_input_dims must be <= {len(in_alphabet)}, got'
            f' {num_input_dims}.'
        )
      if self.num_output_dims > len(out_alphabet):
        raise ValueError(
            f'At most {len(out_alphabet)} output dims are supported, got'
            f' {self.num_output_dims}.'
        )

      if num_input_dims > 0:
        in_shape = inputs.shape[-num_input_dims:]
      else:
        in_shape = ()
      weight_init = _get_initializer_scale(self.initializer, in_shape)

      in_letters = in_alphabet[:num_input_dims]
      out_letters = out_alphabet[: self.num_output_dims]

      if self.transpose_weights:
        weight_shape = self.output_shape + in_shape
        weights = hk.get_parameter(
            'weights', weight_shape, inputs.dtype, weight_init
        )
        equation = (
            f'...{in_letters}, {out_letters}{in_letters}->...{out_letters}'
        )
      else:
        weight_shape = in_shape + self.output_shape
        weights = hk.get_parameter(
            'weights', weight_shape, inputs.dtype, weight_init
        )
        equation = (
            f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
        )

      output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter(
          'bias',
          self.output_shape,
          inputs.dtype,
          hk.initializers.Constant(self.bias_init),
      )
      output += bias

    return output
def _get_initializer_scale(initializer_name, input_shape):
  """Returns the Haiku weight initializer named by `initializer_name`.

  'zeros' gives a constant-zero initializer. Any other name gives a truncated
  normal with fan-in scaling. Its variance is 1 / prod(input_shape), doubled
  for 'relu'. The stddev is divided by TRUNCATED_NORMAL_STDDEV_FACTOR to adjust
  for the truncation.
  """
  if initializer_name == 'zeros':
    return hk.initializers.Constant(0.0)

  # Fan-in scaling, dividing by one input dimension at a time.
  variance = 1.0
  for fan_in_dim in input_shape:
    variance /= fan_in_dim
  if initializer_name == 'relu':
    variance *= 2

  # Adjust stddev for truncation.
  stddev = np.sqrt(variance) / TRUNCATED_NORMAL_STDDEV_FACTOR
  return hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
================================================
FILE: src/alphafold3/model/components/mapping.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Specialized mapping functions."""
from collections.abc import Callable, Sequence
import functools
from typing import Any, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
# Aliases used purely for type annotations; both are `Any`.
Pytree = Any
PytreeJaxArray = Any

partial = functools.partial
# Sentinel for an unmapped axis. `_expand_axes` substitutes it for None entries
# of in_axes/out_axes, and `_maybe_slice`/`_maybe_get_size` check for it.
PROXY = object()

# Generic type variable used by `_set_docstring`.
T = TypeVar("T")
def _maybe_slice(array, i, slice_size, axis):
  """Slices `slice_size` elements from `array` along `axis`, starting at `i`.

  Inputs whose axis is the PROXY sentinel (i.e. not mapped over) are returned
  unchanged.
  """
  if axis is not PROXY:
    return jax.lax.dynamic_slice_in_dim(
        array, i, slice_size=slice_size, axis=axis
    )
  return array
def _maybe_get_size(array, axis):
  """Returns the size of `array` along `axis`, or -1 for an unmapped input.

  Args:
    array: An array-like object with a `.shape`.
    axis: The mapped axis index, or the PROXY sentinel for inputs that are not
      mapped over.

  Returns:
    `array.shape[axis]`, or -1 when `axis` is PROXY, so that (presumably) the
    size of a mapped input wins when callers take the max over inputs.
  """
  # Compare the sentinel by identity, as `_maybe_slice` does. `==` would defer
  # to the axis object's own __eq__.
  if axis is PROXY:
    return -1
  return array.shape[axis]
def _expand_axes(axes, values, name="sharded_apply"):
  """Broadcasts `axes` to one axis entry per leaf of `values`.

  `axes` is flattened against the pytree structure of `values` (presumably
  accepting an int or a prefix tree, as for jax.vmap's in_axes). The result has
  the structure of `values`, with every None entry replaced by PROXY.
  """
  treedef = jax.tree_util.tree_structure(values)
  per_leaf_axes = jax.api_util.flatten_axes(name, treedef, axes)
  return jax.tree_util.tree_unflatten(
      treedef, [PROXY if ax is None else ax for ax in per_leaf_axes]
  )
def sharded_map(
    fun: Callable[..., PytreeJaxArray],
    shard_size: int | None = 1,
    in_axes: int | Pytree = 0,
    out_axes: int | Pytree = 0,
) -> Callable[..., PytreeJaxArray]:
  """Sharded vmap.

  Maps `fun` over axes like hk.vmap, but evaluates the mapped axis in shards of
  `shard_size`. This allows a smooth trade-off between memory usage (as in a
  plain map) and higher throughput (as in a vmap).

  During Haiku initialisation a plain hk.vmap (split_rng=False) is returned
  instead, so parameter creation does not depend on `shard_size`.

  Args:
    fun: Function to apply smap transform to.
    shard_size: Integer denoting shard size.
    in_axes: Either integer or pytree describing which axis to map over for each
      input to `fun`, None denotes broadcasting.
    out_axes: Integer or pytree denoting to what axis in the output the mapped
      over axis maps.

  Returns:
    Function with smap applied.
  """
  if not hk.running_init():
    per_shard_fun = hk.vmap(fun, in_axes, out_axes, split_rng=True)
    return sharded_apply(per_shard_fun, shard_size, in_axes, out_axes)

  # Guarantees initialisation independent of shard_size. Doesn't incur a high
  # memory cost, as long as large concrete tensors are not encountered.
  return hk.vmap(fun, in_axes=in_axes, out_axes=out_axes, split_rng=False)
def _set_docstring(docstr: str) -> Callable[[T], T]:
  """Returns a decorator that sets a function's docstring from a template.

  Every `{fun}` placeholder in `docstr` is replaced with the decorated object's
  `__name__`, or its repr() if it has no `__name__`. The object itself is
  returned unchanged apart from `__doc__`.
  """

  def decorator(target: T) -> T:
    target_name = getattr(target, "__name__", repr(target))
    target.__doc__ = docstr.format(fun=target_name)
    return target

  return decorator
def sharded_apply(
    fun: Callable[..., PytreeJaxArray],
    shard_size: int | None = 1,
    in_axes: int | Pytree = 0,
    out_axes: int | Pytree = 0,
    new_out_axes: bool = False,
) -> Callable[..., PytreeJaxArray]:
  """Sharded apply.

  Applies `fun` over shards to axes, in a way similar to vmap,
  but does so in shards of `shard_size`. Shards are stacked after.
  This allows a smooth trade-off between
  memory usage (as in a plain map) vs higher throughput (as in a vmap).

  `fun` is not vmapped here. It is called directly on slices of its mapped
  inputs, each `slice_size` long along the mapped axis, so it must handle that
  leading batch itself (sharded_map achieves this by passing an hk.vmap'ed fun).
  Full shards are processed with hk.scan. A smaller final remainder shard, if
  any, gets one extra call. Results are written into preallocated zero buffers
  along `out_axes`.

  NOTE(review): the loop size is the max size over mapped inputs. Behaviour
  looks undefined if no input is mapped or the mapped size is 0; confirm
  callers never do this.

  Args:
    fun: Function to apply smap transform to.
    shard_size: Integer denoting shard size. None will return `fun` unchanged.
    in_axes: Either integer or pytree describing which axis to map over for each
      input to `fun`, None denotes broadcasting.
    out_axes: Integer or pytree denoting to what axis in the output the mapped
      over axis maps.
    new_out_axes: Whether to stack outputs on new axes. This assumes that the
      output sizes for each shard (including the possible remainder shard) are
      the same. Not implemented: True raises NotImplementedError.

  Returns:
    Function with smap applied.

  Raises:
    NotImplementedError: If `new_out_axes` is True.
  """
  docstr = (
      "Mapped version of {fun}. Takes similar arguments to {fun} "
      "but with additional array axes over which {fun} is mapped."
  )
  if new_out_axes:
    raise NotImplementedError("New output axes not yet implemented.")

  # shard_size of None disables sharding entirely.
  if shard_size is None:
    return fun

  # functools.wraps copies fun's metadata; _set_docstring (applied last) then
  # overwrites __doc__ with the template above.
  @_set_docstring(docstr)
  @functools.wraps(fun)
  def mapped_fn(*args, **kwargs):
    # Expand in axes and determine loop range.
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree.map(_maybe_get_size, args, in_axes_)
    in_size = max(jax.tree_util.tree_leaves(in_sizes))
    # Number of shards before the last one (full shards, except the last).
    num_extra_shards = (in_size - 1) // shard_size

    # Fix if necessary.
    # When in_size divides evenly, the last shard is a full shard.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree.map(
          lambda array, axis: _maybe_slice(
              array, slice_start, slice_size, axis
          ),
          args,
          in_axes_,
      )
      return fun(*input_slice, **kwargs)

    # Trace (without computing) the last shard to learn output dtypes/shapes.
    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size)
    )
    out_dtypes = jax.tree.map(lambda x: x.dtype, remainder_shape_dtype)
    out_shapes = jax.tree.map(lambda x: x.shape, remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size)
      )
      shard_shapes = jax.tree.map(lambda x: x.shape, regular_shard_shape_dtype)

      # Full output length along `axis`: every full shard plus the remainder.
      def make_output_shape(axis, shard_shape, remainder_shape):
        return (
            shard_shape[:axis]
            + (shard_shape[axis] * num_extra_shards + remainder_shape[axis],)
            + shard_shape[axis + 1 :]
        )

      out_shapes = jax.tree.map(
          make_output_shape, out_axes_, shard_shapes, out_shapes
      )

    # Calls dynamic Update slice with different argument order.
    # This is here since tree_map only works with positional arguments.
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    # Runs fun on one shard and writes its outputs into the buffers at
    # slice_start along each output axis.
    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree.map(update_slice, outputs, slice_out, out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

    # Start offsets of every full-size shard.
    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    outputs = jax.tree.map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    # A partial last shard is not covered by slice_starts; handle it separately.
    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs

  return mapped_fn
def inference_subbatch(
    module: Callable[..., PytreeJaxArray],
    subbatch_size: int,
    batched_args: Sequence[PytreeJaxArray],
    nonbatched_args: Sequence[PytreeJaxArray],
    input_subbatch_dim: int = 0,
    output_subbatch_dim: int | None = None,
) -> PytreeJaxArray:
  """Run through subbatches (like batch apply but with split and concat).

  Calls `module(*batched_args, *nonbatched_args)` one subbatch at a time. The
  batched args are sliced along `input_subbatch_dim` in chunks of
  `subbatch_size`, and the results are assembled along `output_subbatch_dim`
  (defaulting to `input_subbatch_dim`). During Haiku initialisation the module
  is called once on the full inputs.

  Args:
    module: Callable taking batched args followed by non-batched args.
    subbatch_size: Subbatch (shard) size.
    batched_args: Non-empty sequence of args sliced into subbatches.
    nonbatched_args: Args passed whole to every call.
    input_subbatch_dim: Axis of the batched args to split along.
    output_subbatch_dim: Axis of the outputs to assemble along.

  Returns:
    The module output assembled over all subbatches.
  """
  assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test

  if hk.running_init():
    return module(*batched_args, *nonbatched_args)

  out_dim = (
      input_subbatch_dim if output_subbatch_dim is None else output_subbatch_dim
  )

  def apply_to_subbatch(*subbatch):
    return module(*subbatch, *nonbatched_args)

  subbatched_module = sharded_apply(
      apply_to_subbatch,
      shard_size=subbatch_size,
      in_axes=input_subbatch_dim,
      out_axes=out_dim,
  )
  return subbatched_module(*batched_args)
================================================
FILE: src/alphafold3/model/components/utils.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Utility functions for training AlphaFold and similar models."""
from collections import abc
import contextlib
import numbers
from alphafold3.model import features
import haiku as hk
import jax.numpy as jnp
import numpy as np
# Dtypes kept by `remove_invalidly_typed_feats`. Features with any other dtype,
# or with no `dtype` attribute at all (e.g. Python strings), are dropped.
VALID_DTYPES = [np.float32, np.float64, np.int8, np.int32, np.int64, bool]
def remove_invalidly_typed_feats(
    batch: features.BatchDict,
) -> features.BatchDict:
  """Remove features of types we don't want to send to the TPU e.g. strings.

  Only values that have a `dtype` contained in VALID_DTYPES are kept.

  Args:
    batch: Mapping of feature name to feature value.

  Returns:
    A new dict holding only the retained features, in the original order.
  """
  kept = {}
  for feat_name, feat in batch.items():
    if not hasattr(feat, 'dtype'):
      continue
    if feat.dtype in VALID_DTYPES:
      kept[feat_name] = feat
  return kept
def bfloat16_getter(next_getter, value, context):
  """Ensures that a bfloat16 parameter is provided by casting if necessary.

  Haiku custom getter: if the parameter was declared as bfloat16
  (`context.original_dtype`) but `value` has another dtype, `value` is cast to
  bfloat16 before being passed on to `next_getter`.
  """
  needs_cast = (
      context.original_dtype == jnp.bfloat16 and value.dtype != jnp.bfloat16
  )
  return next_getter(value.astype(jnp.bfloat16) if needs_cast else value)
@contextlib.contextmanager
def bfloat16_context():
  """Context manager that installs `bfloat16_getter` as a Haiku custom getter.

  Within the `with` block, parameters declared as bfloat16 are served as
  bfloat16 even if stored with another dtype.
  """
  getter_scope = hk.custom_getter(bfloat16_getter)
  with getter_scope:
    yield
def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
  """Masked mean.

  Computes sum(mask * value) / max(sum(mask) * broadcast_factor, eps) over
  `axis`. A mask dimension of size 1 broadcasts against `value`.
  `broadcast_factor` is the product of the corresponding `value` sizes, so the
  denominator still counts `value` elements.

  Args:
    mask: Array with the same rank as `value`; each reduced dim is either equal
      to `value`'s or 1.
    value: Array to average.
    axis: An int, an iterable of ints (any iterable, including one-shot
      iterators), or None for all axes.
    keepdims: Forwarded to jnp.sum.
    eps: Lower bound on the denominator, to avoid division by zero.

  Returns:
    The masked mean reduced over `axis`.

  Raises:
    AssertionError: If ranks differ, `axis` has an unsupported type, or a
      non-broadcast reduced dim differs between `mask` and `value`.
  """
  mask_shape = mask.shape
  value_shape = value.shape

  assert len(mask_shape) == len(
      value_shape
  ), 'Shapes are not compatible, shapes: {}, {}'.format(mask_shape, value_shape)

  if isinstance(axis, numbers.Integral):
    axis = [axis]
  elif axis is None:
    axis = list(range(len(mask_shape)))
  assert isinstance(
      axis, abc.Iterable
  ), 'axis needs to be either an iterable, integer or "None"'
  # Materialise once: `axis` is iterated below and then passed to jnp.sum, so
  # a one-shot iterator (e.g. a generator) would otherwise be exhausted and the
  # reduction would silently run over no axes at all.
  axis = tuple(axis)

  broadcast_factor = 1.0
  for axis_ in axis:
    value_size = value_shape[axis_]
    mask_size = mask_shape[axis_]
    if mask_size == 1:
      broadcast_factor *= value_size
    else:
      error = f'Shapes are not compatible, shapes: {mask_shape}, {value_shape}'
      assert mask_size == value_size, error

  return jnp.sum(mask * value, keepdims=keepdims, axis=axis) / (
      jnp.maximum(
          jnp.sum(mask, keepdims=keepdims, axis=axis) * broadcast_factor, eps
      )
  )
================================================
FILE: src/alphafold3/model/confidence_types.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Confidence categories for predictions."""
import dataclasses
import enum
import json
from typing import Any, Self
from absl import logging
from alphafold3.model import model
import jax
import numpy as np
class StructureConfidenceFullEncoder(json.JSONEncoder):
"""JSON encoder for serializing confidence types."""
def __init__(self, **kwargs):
super().__init__(**(kwargs | dict(separators=(',', ':'))))
def encode(self, o: 'StructureConfidenceFull'):
# Cast to np.float64 before rounding, since casting to Python float will
# cast to a 64 bit float, potentially undoing np.float32 rounding.
atom_plddts = np.round(
np.clip(np.asarray(o.atom_plddts, dtype=np.float64), 0.0, 99.99), 2
).astype(float)
contact_probs = np.round(
np.clip(np.asarray(o.contact_probs, dtype=np.float64), 0.0, 1.0), 2
).astype(float)
pae = np.round(
np.clip(np.asarray(o.pae, dtype=np.float64), 0.0, 99.9), 1
).astype(float)
return """\
{
"atom_chain_ids": %s,
"atom_plddts": %s,
"contact_probs": %s,
"pae": %s,
"token_chain_ids": %s,
"token_res_ids": %s
}""" % (
super().encode(o.atom_chain_ids),
super().encode(list(atom_plddts)).replace('NaN', 'null'),
super().encode([list(x) for x in contact_probs]).replace('NaN', 'null'),
super().encode([list(x) for x in pae]).replace('NaN', 'null'),
super().encode(o.token_chain_ids),
super().encode(o.token_res_ids),
)
def _dump_json(data: Any, indent: int | None = None) -> str:
"""Dumps a json string with JSON compatible NaN representation."""
json_str = json.dumps(
data,
sort_keys=True,
indent=indent,
separators=(',', ': '),
)
return json_str.replace('NaN', 'null')
@enum.unique
class ConfidenceCategory(enum.Enum):
"""Confidence categories for AlphaFold predictions."""
HIGH = 0
MEDIUM = 1
LOW = 2
DISORDERED = 3
@classmethod
def from_char(cls, char: str) -> Self:
match char:
case 'H':
return cls.HIGH
case 'M':
return cls.MEDIUM
case 'L':
return cls.LOW
case 'D':
return cls.DISORDERED
case _:
raise ValueError(
f'Unknown character. Expected one of H, M, L or D; got: {char}'
)
def to_char(self) -> str:
match self:
case self.HIGH:
return 'H'
case self.MEDIUM:
return 'M'
case self.LOW:
return 'L'
case self.DISORDERED:
return 'D'
@classmethod
def from_confidence_score(cls, confidence: float) -> Self:
if 90 <= confidence <= 100:
return cls.HIGH
if 70 <= confidence < 90:
return cls.MEDIUM
if 50 <= confidence < 70:
return cls.LOW
if 0 <= confidence < 50:
return cls.DISORDERED
raise ValueError(f'Confidence score out of range [0, 100]: {confidence}')
@dataclasses.dataclass()
class AtomConfidence:
  """Dataclass for 1D per-atom confidences from AlphaFold.

  All four lists are parallel (one entry per atom) and must have equal length.

  Attributes:
    chain_id: Chain ID of each atom.
    atom_number: Atom index. When built by `from_inference_result`, this is the
      0-based position in the structure's atom iteration order.
    confidence: Per-atom confidence. When built by `from_inference_result`,
      this is the atom b-factor rounded to 2 decimals.
    confidence_category: Per-atom ConfidenceCategory.
  """

  chain_id: list[str]
  atom_number: list[int]
  confidence: list[float]
  confidence_category: list[ConfidenceCategory]

  def __post_init__(self):
    num_atoms = len(self.atom_number)
    if any(
        len(field) != num_atoms
        for field in (self.chain_id, self.confidence, self.confidence_category)
    ):
      raise ValueError('All confidence fields must have the same length.')

  @classmethod
  def from_inference_result(
      cls, inference_result: model.InferenceResult
  ) -> Self:
    """Instantiates an AtomConfidence from a structure.

    Args:
      inference_result: Inference result from AlphaFold. Per-atom confidences
        are read from its predicted structure's atom b-factors.

    Returns:
      Scores in AtomConfidence dataclass.
    """
    struc = inference_result.predicted_structure
    chain_ids, atom_numbers, scores, categories = [], [], [], []
    for idx, atom in enumerate(struc.iter_atoms()):
      score = float(struc.atom_b_factor[idx])
      chain_ids.append(atom['chain_id'])
      atom_numbers.append(idx)
      scores.append(round(score, 2))
      # The category is derived from the unrounded score.
      categories.append(ConfidenceCategory.from_confidence_score(score))
    return cls(
        chain_id=chain_ids,
        atom_number=atom_numbers,
        confidence=scores,
        confidence_category=categories,
    )

  @classmethod
  def from_json(cls, json_string: str) -> Self:
    """Instantiates a AtomConfidence from a json string."""
    fields = json.loads(json_string)
    fields['confidence_category'] = list(
        map(ConfidenceCategory.from_char, fields['confidence_category'])
    )
    return cls(**fields)

  def to_json(self) -> str:
    """Serialises to JSON, with categories encoded as one-letter codes."""
    payload = dataclasses.asdict(self)
    payload['confidence_category'] = [
        category.to_char() for category in payload['confidence_category']
    ]
    payload['atom_number'] = list(map(int, payload['atom_number']))
    return _dump_json(payload)
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class StructureConfidenceSummary:
  """Dataclass for the summary of structure scores from AlphaFold.

  Attributes:
    ptm: Predicted TM global score.
    iptm: Interface predicted TM global score.
    ranking_score: Ranking score extracted from CIF metadata.
    fraction_disordered: Fraction disordered, measured with RASA.
    has_clash: Has significant clashing.
    chain_pair_pae_min: [num_chains, num_chains] Minimum cross chain PAE.
    chain_pair_iptm: [num_chains, num_chains] Chain pair ipTM.
    chain_ptm: [num_chains] Chain pTM.
    chain_iptm: [num_chains] Mean cross chain ipTM for a chain.
  """

  ptm: float
  iptm: float
  ranking_score: float
  fraction_disordered: float
  has_clash: float
  chain_pair_pae_min: np.ndarray
  chain_pair_iptm: np.ndarray
  chain_ptm: np.ndarray
  chain_iptm: np.ndarray

  @classmethod
  def from_inference_result(
      cls, inference_result: model.InferenceResult
  ) -> Self:
    """Returns a new instance based on a given inference result.

    Scalars are converted to Python floats. `chain_ptm` and `chain_iptm` are
    read from the 'iptm_ichain' and 'iptm_xchain' metadata entries.
    """
    metadata = inference_result.metadata
    return cls(
        ptm=float(metadata['ptm']),
        iptm=float(metadata['iptm']),
        ranking_score=float(metadata['ranking_score']),
        fraction_disordered=float(metadata['fraction_disordered']),
        has_clash=float(metadata['has_clash']),
        chain_pair_pae_min=metadata['chain_pair_pae_min'],
        chain_pair_iptm=metadata['chain_pair_iptm'],
        chain_ptm=metadata['iptm_ichain'],
        chain_iptm=metadata['iptm_xchain'],
    )

  @classmethod
  def from_json(cls, json_string: str) -> Self:
    """Returns a new instance from a given json string."""
    fields = json.loads(json_string)
    return cls(**fields)

  def to_json(self) -> str:
    """Serialises to indented JSON with all values rounded to 2 decimals."""

    def round_leaf(leaf):
      if not isinstance(leaf, np.ndarray):
        return np.round(leaf, decimals=2)
      # Cast to np.float64 before rounding, since casting to Python float will
      # cast to a 64 bit float, potentially undoing np.float32 rounding.
      return np.round(leaf.astype(np.float64), decimals=2).tolist()

    rounded_fields = jax.tree.map(round_leaf, dataclasses.asdict(self))
    return _dump_json(rounded_fields, indent=1)
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class StructureConfidenceFull:
  """Dataclass for full structure data from AlphaFold.

  Attributes:
    pae: Full predicted aligned error matrix.
    token_chain_ids: Chain ID of each token.
    token_res_ids: Residue ID of each token.
    atom_plddts: Per-atom pLDDT (the structure's atom b-factors).
    atom_chain_ids: Chain ID of each atom.
    contact_probs: [num_tokens, num_tokens] contact probabilities.
  """

  pae: np.ndarray
  token_chain_ids: list[str]
  token_res_ids: list[int]
  atom_plddts: list[float]
  atom_chain_ids: list[str]
  contact_probs: np.ndarray  # [num_tokens, num_tokens]

  @classmethod
  def from_inference_result(
      cls, inference_result: model.InferenceResult
  ) -> Self:
    """Returns a new instance based on a given inference result.

    Raises:
      TypeError: If 'full_pae' or 'contact_probs' in the numerical data is not
        a numpy array.
    """
    numerical_data = inference_result.numerical_data

    pae = numerical_data['full_pae']
    if not isinstance(pae, np.ndarray):
      logging.info('%s', type(pae))
      raise TypeError('pae should be a numpy array.')

    contact_probs = numerical_data['contact_probs']
    if not isinstance(contact_probs, np.ndarray):
      logging.info('%s', type(contact_probs))
      raise TypeError('contact_probs should be a numpy array.')

    struc = inference_result.predicted_structure
    atom_chain_ids = struc.chain_id.tolist()
    atom_plddts = struc.atom_b_factor.tolist()
    metadata = inference_result.metadata
    token_chain_ids = [str(chain) for chain in metadata['token_chain_ids']]
    token_res_ids = [int(res) for res in metadata['token_res_ids']]

    return cls(
        pae=pae,
        token_chain_ids=token_chain_ids,
        token_res_ids=token_res_ids,
        atom_plddts=atom_plddts,
        atom_chain_ids=atom_chain_ids,
        contact_probs=contact_probs,
    )

  @classmethod
  def from_json(cls, json_string: str) -> Self:
    """Returns a new instance from a given json string."""
    fields = json.loads(json_string)
    return cls(**fields)

  def to_json(self) -> str:
    """Converts StructureConfidenceFull to json string."""
    return json.dumps(self, cls=StructureConfidenceFullEncoder)
================================================
FILE: src/alphafold3/model/confidences.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for extracting and processing confidences from model outputs."""
import warnings
from absl import logging
from alphafold3 import structure
from alphafold3.constants import residue_names
from alphafold3.cpp import mkdssp
import jax.numpy as jnp
import numpy as np
from scipy import spatial
# From Sander & Rost 1994 https://doi.org/10.1002/prot.340200303
# Maximum accessible surface area per residue type (presumably in square
# Angstroms, as in that paper). `windowed_solvent_accessible_area` divides each
# residue's DSSP accessibility by this value to get relative accessibility.
MAX_ACCESSIBLE_SURFACE_AREA = {
    'ALA': 106.0,
    'ARG': 248.0,
    'ASN': 157.0,
    'ASP': 163.0,
    'CYS': 135.0,
    'GLN': 198.0,
    'GLU': 194.0,
    'GLY': 84.0,
    'HIS': 184.0,
    'ILE': 169.0,
    'LEU': 164.0,
    'LYS': 205.0,
    'MET': 188.0,
    'PHE': 197.0,
    'PRO': 136.0,
    'SER': 130.0,
    'THR': 142.0,
    'TRP': 227.0,
    'TYR': 222.0,
    'VAL': 142.0,
}

# Weights for ranking confidence.
# _IPTM_WEIGHT blends ipTM and pTM in get_ranking_score:
#   weight * iptm + (1 - weight) * ptm.
# NOTE(review): the other two weights presumably scale the disorder and clash
# terms of the ranking score; their use is not visible here, so confirm.
_IPTM_WEIGHT = 0.8
_FRACTION_DISORDERED_WEIGHT = 0.5
_CLASH_PENALIZATION_WEIGHT = 100.0
def windowed_solvent_accessible_area(cif: str, window: int = 25) -> np.ndarray:
  """Implementation of AlphaFold-RSA.

  AlphaFold-RSA defined in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767.

  Runs DSSP on `cif`, then normalises each residue's accessible surface area by
  the maximum for its residue type (capped at 1.0). Finally it averages over a
  sliding window, using reflect padding at the ends.

  Args:
    cif: Raw cif string.
    window: Width of the sliding window over which to average accessible
      surface area. Must be >= 1. Even widths are supported; the extra padding
      element goes on the right.

  Returns:
    An array of size num_res that predicts disorder by using windowed solvent
    accessible surface area.

  Raises:
    ValueError: If `window` < 1. numpy may also raise ValueError when the DSSP
      output contains no residues.
  """
  if window < 1:
    raise ValueError(f'window must be >= 1, got {window}.')

  result = mkdssp.get_dssp(cif, calculate_surface_accessibility=True)
  in_residue_section = False
  rasa = []
  for row in result.splitlines():
    if in_residue_section:
      aa = row[13:14]
      # '!' rows are DSSP break markers, not residues.
      if aa == '!':
        continue
      aa3 = residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(aa, 'ALA')
      max_acc = MAX_ACCESSIBLE_SURFACE_AREA[aa3]
      acc = int(row[34:38])
      rasa.append(min(acc / max_acc, 1.0))
    # Residue rows follow the '#  RESIDUE ...' column header. Match on tokens
    # rather than exact spacing, so detection does not depend on the header's
    # column padding.
    if row.split()[:2] == ['#', 'RESIDUE']:
      in_residue_section = True

  # Pad so that the 'valid' convolution returns exactly one value per residue:
  # total padding must be window - 1. Splitting it (window - 1) // 2 on the
  # left and the rest on the right also covers even windows, which a symmetric
  # split would leave one element short.
  left_pad = (window - 1) // 2
  right_pad = window - 1 - left_pad
  padded_rasa = np.pad(rasa, (left_pad, right_pad), 'reflect')
  return np.convolve(padded_rasa, np.ones(window), 'valid') / window
def fraction_disordered(
    struc: structure.Structure, rasa_disorder_cutoff: float = 0.581
) -> float:
  """Compute fraction of protein residues that are disordered.

  Only protein chains are considered. Chains with an identical single-letter
  sequence reuse the first computed per-residue rASA values. Chains whose rASA
  computation raises ValueError or RuntimeError are logged and skipped.

  Args:
    struc: A structure to compute rASA metrics on.
    rasa_disorder_cutoff: The threshold at which residues are considered
      disordered. Default value taken from
      https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767.

  Returns:
    The fraction of protein residues that are disordered
    (rasa > rasa_disorder_cutoff), or 0.0 if no rASA values were computed.
  """
  protein = struc.filter_to_entity_type(protein=True)
  all_rasa = []
  rasa_by_sequence = {}

  for chain_id, sequence in protein.chain_single_letter_sequence().items():
    if sequence in rasa_by_sequence:
      # We assume that identical sequences have approximately similar rasa
      # values to speed up the computation.
      all_rasa.extend(rasa_by_sequence[sequence])
      continue

    # Rename the chain to 'A' as MKDSSP supports only single letter chain IDs.
    single_chain = protein.filter(chain_id=chain_id).rename_chain_ids(
        new_id_by_old_id={chain_id: 'A'}
    )
    try:
      per_residue_rasa = windowed_solvent_accessible_area(
          single_chain.to_mmcif()
      )
      rasa_by_sequence[sequence] = per_residue_rasa
      all_rasa.extend(per_residue_rasa)
    except (ValueError, RuntimeError) as e:
      logging.warning('%s: rasa calculation failed: %s', protein.name, e)

  if not all_rasa:
    return 0.0
  return np.mean(np.array(all_rasa) > rasa_disorder_cutoff)
def has_clash(
    struc: structure.Structure,
    cutoff_radius: float = 1.1,
    min_clashes_for_overlap: int = 100,
    min_fraction_for_overlap: float = 0.5,
) -> bool:
  """Determines whether the structure has at least one clashing polymer chain.

  Only protein, RNA and DNA entities are considered. An atom is clashing if
  another polymer atom lies within `cutoff_radius` of it and that atom is
  either in a different chain or more than one residue id away. A chain is
  clashing if it has more than `min_clashes_for_overlap` clashing atoms, or if
  more than `min_fraction_for_overlap` of its atoms are clashing.

  Args:
    struc: A structure to get clash metrics for.
    cutoff_radius: atom distances under this threshold are considered a clash.
    min_clashes_for_overlap: The minimum number of atom-atom clashes for a chain
      to be considered overlapping.
    min_fraction_for_overlap: The minimum fraction of atoms within a chain that
      are clashing for the chain to be considered overlapping.

  Returns:
    True if the structure has at least one clashing chain.
  """
  polymer = struc.filter_to_entity_type(protein=True, rna=True, dna=True)
  if not polymer.chains:
    return False

  coords = polymer.coords
  neighbours_per_atom = spatial.cKDTree(coords).query_ball_point(
      coords, p=2.0, r=cutoff_radius
  )
  res_ids = polymer.res_id
  chain_ids = polymer.chain_id
  is_clashing = np.zeros(len(coords), dtype=np.int32)
  for atom_idx, neighbours in enumerate(neighbours_per_atom):
    # Neighbours in the same chain with the same or an adjacent residue id
    # (including the atom itself) are not counted as clashes.
    is_clashing[atom_idx] = any(
        np.abs(res_ids[atom_idx] - res_ids[other]) > 1
        or chain_ids[atom_idx] != chain_ids[other]
        for other in neighbours
    )

  for chain_id in polymer.chains:
    in_chain = chain_ids == chain_id
    num_atoms = np.sum(in_chain)
    if num_atoms == 0:
      continue
    num_clashes = np.sum(is_clashing * in_chain)
    too_many = num_clashes > min_clashes_for_overlap
    too_dense = num_clashes / num_atoms > min_fraction_for_overlap
    if too_many or too_dense:
      return True
  return False
def get_ranking_score(
    ptm: float, iptm: float, fraction_disordered_: float, has_clash_: bool
) -> float:
  """Combines confidence metrics into a single ranking score.

  Args:
    ptm: Predicted TM-score.
    iptm: Interface predicted TM-score; NaN for single-chain structures.
    fraction_disordered_: Fraction of disordered protein residues.
    has_clash_: Whether the structure has a clashing chain.

  Returns:
    A weighted pTM/ipTM average (plain pTM when ipTM is NaN), plus
    `_FRACTION_DISORDERED_WEIGHT * fraction_disordered_`, minus
    `_CLASH_PENALIZATION_WEIGHT` when `has_clash_` is True.
  """
  # ipTM is NaN for single chain structures; fall back to pTM in that case.
  if np.isnan(iptm):
    base_score = ptm
  else:
    base_score = _IPTM_WEIGHT * iptm + (1.0 - _IPTM_WEIGHT) * ptm
  disorder_term = _FRACTION_DISORDERED_WEIGHT * fraction_disordered_
  clash_term = _CLASH_PENALIZATION_WEIGHT * has_clash_
  return base_score + disorder_term - clash_term
def rank_metric(
full_pde: jnp.ndarray | np.ndarray, contact_probs: jnp.ndarray | np.ndarray
) -> jnp.ndarray | np.ndarray:
"""Compute the metric that will be used to rank predictions, higher is better.
Args:
full_pde: A [num_samples, num_tokens,num_tokens] matrix of predicted
distance errors between pairs of tokens.
contact_probs: A [num_tokens, num_tokens] matrix consisting of the
probability of contact (<8A) that is returned from the distogram head.
Returns:
A scalar that can be used to rank (higher is better).
"""
if not isinstance(full_pde, type(contact_probs)):
raise ValueError('full_pde and contact_probs must be of the same type.')
if isinstance(full_pde, np.ndarray):
sum_fn = np.sum
elif isinstance(full_pde, jnp.ndarray):
sum_fn = jnp.sum
else:
raise ValueError('full_pde must be a numpy array or a jax array.')
# It was found that taking the contact_map weighted average was better than
# just the predicted distance error on its own.
return -sum_fn(full_pde * contact_probs[None, :, :], axis=(-2, -1)) / (
sum_fn(contact_probs) + 1e-6
)
def weighted_mean(mask, value, axis):
  """Mean of `value` along `axis`, weighted by `mask`.

  Computed as mean(mask * value) / mean(mask); a 1e-8 term in the denominator
  avoids division by zero when the mask is all zero along `axis`.
  """
  numerator = np.mean(mask * value, axis=axis)
  denominator = np.mean(mask, axis=axis) + 1e-8
  return numerator / denominator
def pde_single(
    num_tokens: int,
    asym_ids: np.ndarray,
    full_pde: np.ndarray,
    contact_probs: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Compute 1D PDE summaries.

  Args:
    num_tokens: The number of tokens (not including padding).
    asym_ids: The asym_ids (array of shape num_tokens).
    full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted
      distance errors.
    contact_probs: A [num_tokens, num_tokens] matrix consisting of the
      probability of contact (<8A) that is returned from the distogram head.

  Returns:
    A tuple (ichain, xchain, full_chain) where:
      `ichain` is a [num_samples, num_chains] matrix where the value assigned
      to each chain is an average of the full PDE matrix over all its
      within-chain interactions, weighted by `contact_probs`.
      `xchain` is a [num_samples, num_chains] matrix where the value assigned
      to each chain is an average of the full PDE matrix over all its
      cross-chain interactions, weighted by `contact_probs`.
      `full_chain` is a [num_samples, num_tokens] matrix where the value
      assigned to each token is an average of its PDE against all tokens,
      weighted by `contact_probs`.
  """
  # Strip padding.
  full_pde = full_pde[:, :num_tokens, :num_tokens]
  contact_probs = contact_probs[:num_tokens, :num_tokens]
  asym_ids = asym_ids[:num_tokens]

  chains = np.unique(asym_ids)
  num_samples = full_pde.shape[0]
  # Leading singleton axis so masks broadcast against the sample axis.
  asym_ids = asym_ids[None]
  contact_probs = contact_probs[None]

  ichain = np.zeros((num_samples, len(chains)))
  xchain = np.zeros((num_samples, len(chains)))
  for chain_idx, chain in enumerate(chains):
    in_chain = asym_ids == chain
    rows_in_chain = in_chain[:, :, None]
    within_mask = rows_in_chain * in_chain[:, None, :] * contact_probs
    cross_mask = rows_in_chain * ~in_chain[:, None, :] * contact_probs
    ichain[:, chain_idx] = weighted_mean(
        mask=within_mask, value=full_pde, axis=(-2, -1)
    )
    xchain[:, chain_idx] = weighted_mean(
        mask=cross_mask, value=full_pde, axis=(-2, -1)
    )

  full_chain = weighted_mean(mask=contact_probs, value=full_pde, axis=(-1,))
  return ichain, xchain, full_chain
def chain_pair_pde(
    num_tokens: int, asym_ids: np.ndarray, full_pde: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
  """Compute predicted distance errors for all pairs of chains.

  Args:
    num_tokens: The number of tokens (not including padding).
    asym_ids: The asym_ids (array of shape num_tokens).
    full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted
      distance errors.

  Returns:
    A tuple (chain_pair_pred_err_mean, chain_pair_pred_err_min), each a
    [num_samples, num_chains, num_chains] array holding the mean (resp. min)
    predicted distance error over all token pairs (row in chain i, column in
    chain j). Chains are ordered as in np.unique(asym_ids).
  """
  full_pde = full_pde[:, :num_tokens, :num_tokens]
  asym_ids = asym_ids[:num_tokens]
  chain_masks = [asym_ids == chain for chain in np.unique(asym_ids)]
  num_chains = len(chain_masks)
  result_shape = (full_pde.shape[0], num_chains, num_chains)
  pair_mean = np.zeros(result_shape)
  pair_min = np.zeros(result_shape)
  for i, row_mask in enumerate(chain_masks):
    rows = full_pde[:, row_mask, :]
    for j, col_mask in enumerate(chain_masks):
      block = rows[:, :, col_mask]
      pair_mean[:, i, j] = np.mean(block, axis=(1, 2))
      pair_min[:, i, j] = np.min(block, axis=(1, 2))
  return pair_mean, pair_min
def weighted_nanmean(
    value: np.ndarray, mask: np.ndarray, axis: int
) -> np.ndarray:
  """Weighted mean of `value` along `axis`, ignoring NaN entries of `value`.

  Entries where `value` is NaN are excluded from both the weighted sum and the
  weight normaliser. Empty slices (no non-NaN entries) evaluate to NaN.

  Args:
    value: Array to average; may contain NaNs.
    mask: Weights with the same shape as `value`. Boolean and integer masks are
      promoted to float64 so that NaN can be written into a working copy.
      The caller's mask is never modified.
    axis: Axis along which to reduce.

  Returns:
    The weighted NaN-mean of `value` along `axis`.
  """
  assert mask.shape == value.shape
  assert not np.isnan(mask).all()
  # Need to NaN the mask to get the correct denominator weighting. Non-float
  # masks must be promoted first: writing NaN into an integer array raises,
  # and writing it into a bool array silently stores True (weight 1).
  if np.issubdtype(mask.dtype, np.inexact):
    mask_with_nan = mask.copy()
  else:
    mask_with_nan = mask.astype(np.float64)
  mask_with_nan[np.isnan(value)] = np.nan
  with warnings.catch_warnings():
    # Mean of empty slice is ok and should return a NaN.
    warnings.filterwarnings(action='ignore', message='Mean of empty slice')
    warnings.filterwarnings(
        action='ignore', message='invalid value encountered in (scalar )?divide'
    )
    return np.nanmean(value * mask_with_nan, axis=axis) / np.nanmean(
        mask_with_nan, axis=axis
    )
def chain_pair_pae(
    *,
    num_tokens: int,
    asym_ids: np.ndarray,
    full_pae: np.ndarray,
    mask: np.ndarray | None = None,
    contact_probs: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Compute predicted errors for all pairs of chains.

  Args:
    num_tokens: The number of tokens (not including padding).
    asym_ids: The asym_ids (array of shape num_tokens).
    full_pae: A [num_samples, num_tokens, num_tokens] matrix of predicted
      errors.
    mask: A [num_tokens, num_tokens] mask matrix; entries > 0 are valid.
      Defaults to all-valid.
    contact_probs: A [num_tokens, num_tokens] matrix consisting of the
      probability of contact (<8A) that is returned from the distogram head.
      Used as weights for the mean. Defaults to all ones (unweighted mean).

  Returns:
    A tuple (chain_pair_pred_err_mean, chain_pair_pred_err_min,
    unique_asym_ids):
      chain_pair_pred_err_mean: [num_samples, num_chains, num_chains]
        contact-probability-weighted mean predicted error over valid token
        pairs of each chain pair; NaN where a chain pair has no valid entry.
      chain_pair_pred_err_min: [num_samples, num_chains, num_chains] minimum
        predicted error over valid token pairs; NaN likewise.
      unique_asym_ids: The chain asym ids (from np.unique) that index the
        chain axes.
  """
  pair_shape = full_pae.shape[1:]
  if mask is None:
    mask = np.ones(shape=pair_shape, dtype=bool)
  if contact_probs is None:
    contact_probs = np.ones(shape=pair_shape, dtype=float)
  assert mask.shape == pair_shape

  # Strip padding.
  full_pae = full_pae[:, :num_tokens, :num_tokens]
  mask = mask[:num_tokens, :num_tokens]
  asym_ids = asym_ids[:num_tokens]
  contact_probs = contact_probs[:num_tokens, :num_tokens]

  unique_asym_ids = np.unique(asym_ids)
  num_chains = len(unique_asym_ids)
  num_samples = full_pae.shape[0]
  pair_mean = np.zeros((num_samples, num_chains, num_chains))
  pair_min = np.zeros((num_samples, num_chains, num_chains))
  for i, chain_i in enumerate(unique_asym_ids):
    rows = asym_ids == chain_i
    row_pae = full_pae[:, rows, :]
    row_mask = mask[rows, :]
    row_probs = contact_probs[rows, :]
    for j, chain_j in enumerate(unique_asym_ids):
      cols = asym_ids == chain_j
      block_pae = row_pae[:, :, cols].reshape([num_samples, -1])
      block_probs = row_probs[:, cols].ravel()
      (valid,) = np.nonzero(row_mask[:, cols].ravel() > 0)
      # A ligand chain will have no valid frames if it contains fewer than
      # three non-colinear atoms (e.g. a sodium ion).
      if valid.size == 0:
        pair_mean[:, i, j] = np.nan
        pair_min[:, i, j] = np.nan
        continue
      valid_pae = block_pae[:, valid]
      pair_min[:, i, j] = np.min(valid_pae, axis=1)
      pair_mean[:, i, j] = weighted_mean(
          mask=block_probs[valid], value=valid_pae, axis=-1
      )
  return pair_mean, pair_min, unique_asym_ids
def reduce_chain_pair(
    *,
    chain_pair_met: np.ndarray,
    num_chain_tokens: np.ndarray,
    agg_over_col: bool,
    agg_type: str,
    weight_method: str,
) -> tuple[np.ndarray, np.ndarray]:
  """Compute 1D per-chain summaries from a chain-pair summary.

  Args:
    chain_pair_met: A [num_samples, num_chains, num_chains] aggregate matrix.
    num_chain_tokens: A [num_chains] array of number of tokens for each chain.
      Used for 'per_token' weighting.
    agg_over_col: If True, aggregate along the last axis (over columns j for
      each row i); otherwise aggregate along the second-to-last axis (over
      rows i for each column j).
    agg_type: The type of aggregation to use, 'mean' or 'min'.
    weight_method: The method to use for weighting the 'mean' aggregation,
      'per_token' or 'per_chain'. Validated but unused for 'min'.

  Returns:
    A tuple (ichain, xchain) where:
      `ichain` is a [num_samples, num_chains] array holding the diagonal of
      `chain_pair_met` (each chain's within-chain value).
      `xchain` is a [num_samples, num_chains] array holding, for each chain,
      the NaN-aware weighted mean ('mean') or minimum ('min') of its
      off-diagonal (cross-chain) entries. It is NaN when a chain has no
      non-NaN cross-chain entry (e.g. a single-chain input).

  Raises:
    ValueError: If `weight_method` or `agg_type` is unknown.
  """
  num_samples, num_chains, _ = chain_pair_met.shape
  ichain = chain_pair_met.diagonal(axis1=-2, axis2=-1)

  if weight_method == 'per_chain':
    chain_weight = np.ones((num_chains,), dtype=float)
  elif weight_method == 'per_token':
    chain_weight = num_chain_tokens
  else:
    raise ValueError(f'Unknown weight method: {weight_method}')

  agg_axis = -1 if agg_over_col else -2

  if agg_type == 'mean':
    # Zero weight on the diagonal so only cross-chain pairs contribute.
    weight = np.ones((num_samples, num_chains, num_chains), dtype=float)
    weight -= np.eye(num_chains, dtype=float)
    weight *= chain_weight[None] * chain_weight[:, None]
    xchain = weighted_nanmean(chain_pair_met, mask=weight, axis=agg_axis)
  elif agg_type == 'min':
    # NaN out self-pairs rather than offsetting them by a large constant, so a
    # chain without any valid partner yields NaN instead of a bogus value.
    is_self = np.eye(num_chains, dtype=bool)
    cross_only = np.where(is_self, np.nan, chain_pair_met)
    with warnings.catch_warnings():
      # Min over empty slice is ok and should return a NaN.
      warnings.filterwarnings('ignore', message='All-NaN slice encountered')
      xchain = np.nanmin(cross_only, axis=agg_axis)
  else:
    raise ValueError(f'Unknown aggregation method: {agg_type}')
  return ichain, xchain
def pae_metrics(
    num_tokens: int,
    asym_ids: np.ndarray,
    full_pae: np.ndarray,
    mask: np.ndarray,
    contact_probs: np.ndarray,
    tm_adjusted_pae: np.ndarray,
):
  """PAE and ipTM aggregate metrics.

  Args:
    num_tokens: The number of tokens (not including padding).
    asym_ids: The asym_ids per token.
    full_pae: A [num_samples, num_tokens, num_tokens] matrix of predicted
      errors.
    mask: A [num_tokens, num_tokens] pair mask (same shape as full_pae[0]).
    contact_probs: A [num_tokens, num_tokens] contact probability matrix.
    tm_adjusted_pae: Per-sample TM-adjusted PAE, iterated over its first axis.

  Returns:
    A dict with keys:
      'chain_pair_pae_mean', 'chain_pair_pae_min': unweighted chain-pair PAE
        mean/min, [num_samples, num_chains, num_chains].
      'chain_pair_iptm': per-sample chain-pair pTM/ipTM matrices stacked to
        [num_samples, num_chains, num_chains].
      'iptm_ichain', 'iptm_xchain', 'pae_ichain', 'pae_xchain': per-chain
        within-chain (diagonal) and cross-chain summaries,
        [num_samples, num_chains]. The PAE summaries use the
        contact-probability-weighted chain-pair PAE.
  """
  assert mask.shape == full_pae.shape[1:]
  assert contact_probs.shape == full_pae.shape[1:]

  contact_weighted_pae, _, unique_asym_ids = chain_pair_pae(
      num_tokens=num_tokens,
      asym_ids=asym_ids,
      full_pae=full_pae,
      mask=mask,
      contact_probs=contact_probs,
  )
  pae_mean, pae_min, _ = chain_pair_pae(
      num_tokens=num_tokens,
      asym_ids=asym_ids,
      full_pae=full_pae,
      mask=mask,
  )

  token_asym_ids = asym_ids[:num_tokens]
  token_pair_mask = mask[:num_tokens, :num_tokens]
  per_sample_iptm = [
      chain_pairwise_predicted_tm_scores(
          tm_adjusted_pae=sample[:num_tokens],
          asym_id=token_asym_ids,
          pair_mask=token_pair_mask,
      )
      for sample in tm_adjusted_pae
  ]
  chain_pair_iptm = np.stack(per_sample_iptm, axis=0)

  num_chain_tokens = np.array(
      [np.sum(asym_ids == chain) for chain in unique_asym_ids]
  )

  def _within_and_cross(chain_pair: np.ndarray):
    """Diagonal values plus the mean of row- and column-wise cross summaries."""
    reductions = [
        reduce_chain_pair(
            num_chain_tokens=num_chain_tokens,
            chain_pair_met=chain_pair,
            agg_over_col=over_col,
            agg_type='mean',
            weight_method='per_chain',
        )
        for over_col in (False, True)
    ]
    within = reductions[0][0]
    cross_per_axis = np.stack([reductions[0][1], reductions[1][1]], axis=0)
    with warnings.catch_warnings():
      # Mean of empty slice is ok and should return a NaN.
      warnings.filterwarnings(action='ignore', message='Mean of empty slice')
      cross = np.nanmean(cross_per_axis, axis=0)
    return within, cross

  pae_ichain, pae_xchain = _within_and_cross(contact_weighted_pae)
  iptm_ichain, iptm_xchain = _within_and_cross(chain_pair_iptm)
  return {
      'chain_pair_pae_mean': pae_mean,
      'chain_pair_pae_min': pae_min,
      'chain_pair_iptm': chain_pair_iptm,
      'iptm_ichain': iptm_ichain,
      'iptm_xchain': iptm_xchain,
      'pae_ichain': pae_ichain,
      'pae_xchain': pae_xchain,
  }
def get_iptm_xchain(chain_pair_iptm: np.ndarray) -> np.ndarray:
  """Cross chain aggregate ipTM.

  For each chain, takes the NaN-aware mean of its off-diagonal ipTM entries
  along rows and along columns separately, then averages the two (NaN-aware).

  Args:
    chain_pair_iptm: A [num_samples, num_chains, num_chains] ipTM matrix.

  Returns:
    A [num_samples, num_chains] array of cross-chain ipTM values.
  """
  num_samples, num_chains, _ = chain_pair_iptm.shape
  # Zero weight on the diagonal so only cross-chain entries contribute.
  off_diagonal = np.ones((num_samples, num_chains, num_chains), dtype=float)
  off_diagonal -= np.eye(num_chains, dtype=float)
  per_axis = [
      weighted_nanmean(chain_pair_iptm, mask=off_diagonal, axis=axis)
      for axis in (-2, -1)
  ]
  with warnings.catch_warnings():
    # Mean of empty slice is ok and should return a NaN.
    warnings.filterwarnings(action='ignore', message='Mean of empty slice')
    return np.nanmean(np.stack(per_axis, axis=0), axis=0)
def predicted_tm_score(
    tm_adjusted_pae: np.ndarray,
    pair_mask: np.ndarray,
    asym_id: np.ndarray,
    interface: bool = False,
) -> float:
  """Computes predicted TM alignment or predicted interface TM alignment score.

  Args:
    tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore
      values.
    pair_mask: A [num_res, num_res] boolean mask. The TM score will only
      aggregate over masked-on entries.
    asym_id: [num_res] asymmetric unit ID (the chain ID). Only needed for ipTM
      calculation, i.e. when interface=True.
    interface: If True, the interface predicted TM score is computed. If False,
      the predicted TM score without any residue pair restrictions is computed.

  Returns:
    score: pTM or ipTM score, or NaN if no residue pair is masked on (e.g. for
      ions and other ligands with colinear atoms, whose frames are
      ill-defined).

  Raises:
    ValueError: If tm_adjusted_pae is not a square 2D array, or pair_mask or
      asym_id have inconsistent shapes.
    TypeError: If pair_mask is not boolean.
  """
  if tm_adjusted_pae.ndim != 2:
    raise ValueError(
        'Bad tm_adjusted_pae shape, expected a 2D (num_tokens, num_tokens) '
        f'array, got {tm_adjusted_pae.shape}.'
    )
  num_tokens = tm_adjusted_pae.shape[0]
  expected_shape = (num_tokens, num_tokens)
  if tm_adjusted_pae.shape != expected_shape:
    raise ValueError(
        f'Bad tm_adjusted_pae shape, expected {expected_shape}, got '
        f'{tm_adjusted_pae.shape}.'
    )
  if pair_mask.shape != expected_shape:
    raise ValueError(
        f'Bad pair_mask shape, expected {expected_shape}, got '
        f'{pair_mask.shape}.'
    )
  if pair_mask.dtype != bool:
    raise TypeError(f'Bad pair mask type, expected bool, got {pair_mask.dtype}')
  if asym_id.shape[0] != num_tokens:
    raise ValueError(
        f'Bad asym_id shape, expected ({num_tokens},), got {asym_id.shape}.'
    )

  # Create pair mask: for ipTM, only cross-chain pairs count.
  if interface:
    pair_mask = pair_mask * (asym_id[:, None] != asym_id[None, :])

  # Ions and other ligands with colinear atoms have ill-defined frames.
  if pair_mask.sum() == 0:
    return np.nan

  # Row-normalise the mask (epsilon guards rows with no valid pairs), average
  # per aligned frame, then take the best frame.
  normed_residue_mask = pair_mask / (
      1e-8 + np.sum(pair_mask, axis=-1, keepdims=True)
  )
  per_alignment = np.sum(tm_adjusted_pae * normed_residue_mask, axis=-1)
  return per_alignment.max()
def chain_pairwise_predicted_tm_scores(
    tm_adjusted_pae: np.ndarray,
    pair_mask: np.ndarray,
    asym_id: np.ndarray,
) -> np.ndarray:
  """Compute predicted TM (pTM) between each pair of chains independently.

  For each unordered chain pair (i, j), the tokens of both chains are
  extracted and scored with `predicted_tm_score`. Diagonal entries are the
  single-chain pTM; off-diagonal entries are the interface score (ipTM).

  Args:
    tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore
      values.
    pair_mask: A [num_res, num_res] mask specifying which frames are valid.
      Invalid frames can be the result of chains with not enough atoms (e.g.
      ions).
    asym_id: [num_res] asymmetric unit ID (the chain ID).

  Returns:
    A symmetric [num_chains, num_chains] matrix, where row i, column j
    indicates the predicted TM-score for the interface between chain i and
    chain j.
  """
  chains = list(np.unique(asym_id))
  num_chains = len(chains)
  scores = np.zeros((num_chains, num_chains))
  for i in range(num_chains):
    in_chain_i = asym_id == chains[i]
    for j in range(i, num_chains):
      selected = in_chain_i | (asym_id == chains[j])
      (token_idx,) = np.where(selected)
      submatrix = np.ix_(token_idx, token_idx)
      score = predicted_tm_score(
          tm_adjusted_pae=tm_adjusted_pae[submatrix],
          pair_mask=pair_mask[submatrix],
          asym_id=asym_id[selected],
          interface=chains[i] != chains[j],
      )
      scores[i, j] = score
      scores[j, i] = score
  return scores
================================================
FILE: src/alphafold3/model/data3.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Protein features that are computed from parsed mmCIF objects."""
from collections.abc import Mapping
import datetime
from typing import TypeAlias
from alphafold3.constants import residue_names
from alphafold3.cpp import msa_profile
from alphafold3.model import protein_data_processing
import numpy as np
# Read-only mapping from feature name to its numpy array value.
FeatureDict: TypeAlias = Mapping[str, np.ndarray]
def get_profile_features(
    msa: np.ndarray, deletion_matrix: np.ndarray
) -> FeatureDict:
  """Returns the MSA profile and deletion_mean features.

  Args:
    msa: MSA passed to `msa_profile.compute_msa_profile`; presumably an
      integer-encoded [num_seq, num_res] array — TODO confirm with callers.
    deletion_matrix: Deletion counts, averaged over axis 0.

  Returns:
    A dict with 'profile' (the computed profile cast to float32) and
    'deletion_mean' (mean of `deletion_matrix` over axis 0).
  """
  profile = msa_profile.compute_msa_profile(
      msa=msa,
      num_residue_types=residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP,
  )
  deletion_mean = np.mean(deletion_matrix, axis=0)
  return {'profile': profile.astype(np.float32), 'deletion_mean': deletion_mean}
def fix_template_features(
    template_features: FeatureDict, num_res: int
) -> FeatureDict:
  """Convert template features to AlphaFold 3 format.

  Atom features are gathered from the atom37 layout into the dense-atom layout
  using `protein_data_processing.PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37`, and
  release dates are converted to POSIX timestamps. When there are no
  templates, a single fully masked-out template is returned instead.

  Args:
    template_features: Template features for the protein. Axis 2 of
      'template_all_atom_masks' / 'template_all_atom_positions' is gathered
      over, so it is presumably the atom37 axis — TODO confirm.
    num_res: The length of the amino acid sequence of the protein.

  Returns:
    Updated template_features for the chain.
  """
  aatype = template_features['template_aatype']
  if not aatype.shape[0]:
    return empty_template_features(num_res)

  release_timestamps = [
      _get_timestamp(raw_date.decode('utf-8'))
      for raw_date in template_features['template_release_date']
  ]

  # Per-residue indices selecting, for each dense-atom slot, its atom37 slot.
  dense_to_atom37 = np.take(
      protein_data_processing.PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37,
      aatype,
      axis=0,
  )
  atom_mask = np.take_along_axis(
      template_features['template_all_atom_masks'], dense_to_atom37, axis=2
  )
  atom_positions = np.take_along_axis(
      template_features['template_all_atom_positions'],
      dense_to_atom37[..., None],
      axis=2,
  )
  # Zero the coordinates of masked-out atoms.
  atom_positions *= atom_mask[..., None]

  return {
      'template_aatype': aatype,
      'template_atom_mask': atom_mask.astype(np.int32),
      'template_atom_positions': atom_positions.astype(np.float32),
      'template_domain_names': np.array(
          template_features['template_domain_names'], dtype=object
      ),
      'template_release_timestamp': np.array(
          release_timestamps, dtype=np.float32
      ),
  }
def empty_template_features(num_res: int) -> FeatureDict:
  """Creates a fully masked out template features to allow padding to work.

  Args:
    num_res: The length of the target chain.

  Returns:
    Features for a single all-zero template: aatype [1, num_res], atom mask
    [1, num_res, NUM_DENSE], atom positions [1, num_res, NUM_DENSE, 3], an
    empty domain name and a zero release timestamp.
  """
  num_dense = protein_data_processing.NUM_DENSE
  return {
      'template_aatype': np.zeros((1, num_res), dtype=np.int32),
      'template_atom_mask': np.zeros((1, num_res, num_dense), dtype=np.int32),
      'template_atom_positions': np.zeros(
          (1, num_res, num_dense, 3), dtype=np.float32
      ),
      'template_domain_names': np.array([b''], dtype=object),
      'template_release_timestamp': np.array([0.0], dtype=np.float32),
  }
def _get_timestamp(date_str: str):
dt = datetime.datetime.fromisoformat(date_str)
dt = dt.replace(tzinfo=datetime.timezone.utc)
return dt.timestamp()
================================================
FILE: src/alphafold3/model/data_constants.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Constants shared across modules in the AlphaFold data pipeline."""
from alphafold3.constants import residue_names
# Index of the gap symbol '-' in the one-letter protein alphabet that includes
# the unknown and gap symbols; it is the 'msa' entry of MSA_PAD_VALUES below.
MSA_GAP_IDX = (
    residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP.index('-')
)

# Feature groups.
# MSA features presumably indexed by sequence and residue (per the name).
NUM_SEQ_NUM_RES_MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix')
# MSA features presumably indexed by sequence only (per the name).
NUM_SEQ_MSA_FEATURES = ('msa_species_identifiers',)
TEMPLATE_FEATURES = (
    'template_aatype',
    'template_atom_positions',
    'template_atom_mask',
)

# Value associated with each NUM_SEQ_NUM_RES_MSA_FEATURES key, presumably used
# as the constant when padding that feature.
MSA_PAD_VALUES = dict(msa=MSA_GAP_IDX, msa_mask=1, deletion_matrix=0)
================================================
FILE: src/alphafold3/model/feat_batch.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Batch dataclass."""
import dataclasses
from typing import Self
from alphafold3.model import features
import jax
@dataclasses.dataclass(frozen=True)
class Batch:
  """Immutable bundle of all feature groups making up one model batch.

  `from_data_dict` parses every group out of a flat feature dictionary and
  `as_data_dict` flattens them back into one.
  """

  msa: features.MSA
  templates: features.Templates
  token_features: features.TokenFeatures
  ref_structure: features.RefStructure
  predicted_structure_info: features.PredictedStructureInfo
  polymer_ligand_bond_info: features.PolymerLigandBondInfo
  ligand_ligand_bond_info: features.LigandLigandBondInfo
  pseudo_beta_info: features.PseudoBetaInfo
  atom_cross_att: features.AtomCrossAtt
  convert_model_output: features.ConvertModelOutput
  frames: features.Frames

  @property
  def num_res(self) -> int:
    """Size of the last axis of `token_features.aatype`."""
    return self.token_features.aatype.shape[-1]

  @classmethod
  def from_data_dict(cls, batch: features.BatchDict) -> Self:
    """Construct batch object from dictionary, one feature group per field."""
    msa = features.MSA.from_data_dict(batch)
    templates = features.Templates.from_data_dict(batch)
    token_features = features.TokenFeatures.from_data_dict(batch)
    ref_structure = features.RefStructure.from_data_dict(batch)
    predicted_structure_info = (
        features.PredictedStructureInfo.from_data_dict(batch)
    )
    polymer_ligand_bond_info = (
        features.PolymerLigandBondInfo.from_data_dict(batch)
    )
    ligand_ligand_bond_info = (
        features.LigandLigandBondInfo.from_data_dict(batch)
    )
    pseudo_beta_info = features.PseudoBetaInfo.from_data_dict(batch)
    atom_cross_att = features.AtomCrossAtt.from_data_dict(batch)
    convert_model_output = features.ConvertModelOutput.from_data_dict(batch)
    frames = features.Frames.from_data_dict(batch)
    return cls(
        msa=msa,
        templates=templates,
        token_features=token_features,
        ref_structure=ref_structure,
        predicted_structure_info=predicted_structure_info,
        polymer_ligand_bond_info=polymer_ligand_bond_info,
        ligand_ligand_bond_info=ligand_ligand_bond_info,
        pseudo_beta_info=pseudo_beta_info,
        atom_cross_att=atom_cross_att,
        convert_model_output=convert_model_output,
        frames=frames,
    )

  def as_data_dict(self) -> features.BatchDict:
    """Converts batch object to a single flat dictionary.

    Feature groups are merged in field-declaration order, so a key present in
    several groups takes the value from the last of them.
    """
    output = {}
    for field in dataclasses.fields(Batch):
      output.update(getattr(self, field.name).as_data_dict())
    return output
# Register Batch as a JAX pytree node. Every dataclass field is declared a data
# field; there are no static metadata fields.
jax.tree_util.register_dataclass(
    Batch,
    data_fields=[f.name for f in dataclasses.fields(Batch)],
    meta_fields=[],
)
================================================
FILE: src/alphafold3/model/features.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Data-side of the input features processing."""
import dataclasses
import datetime
import itertools
from typing import Any, Self, TypeAlias
from absl import logging
from alphafold3 import structure
from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.constants import mmcif_names
from alphafold3.constants import periodic_table
from alphafold3.constants import residue_names
from alphafold3.cpp import cif_dict
from alphafold3.data import msa as msa_module
from alphafold3.data import templates
from alphafold3.data.tools import rdkit_utils
from alphafold3.model import data3
from alphafold3.model import data_constants
from alphafold3.model import merging_features
from alphafold3.model import msa_pairing
from alphafold3.model.atom_layout import atom_layout
from alphafold3.structure import chemical_components as struc_chem_comps
import jax
import jax.numpy as jnp
import numpy as np
from rdkit import Chem
# Array type accepted throughout this module: a numpy or a jax.numpy array.
xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray  # pylint: disable=invalid-name
# Flat mapping from feature name to array.
BatchDict: TypeAlias = dict[str, xnp_ndarray]
# Union of residue_names.PROTEIN_TYPES_WITH_UNKNOWN and
# residue_names.NUCLEIC_TYPES_WITH_2_UNKS.
_STANDARD_RESIDUES = frozenset({
    *residue_names.PROTEIN_TYPES_WITH_UNKNOWN,
    *residue_names.NUCLEIC_TYPES_WITH_2_UNKS,
})
@dataclasses.dataclass(frozen=True)
class PaddingShapes:
  """Immutable set of target sizes, presumably used when padding features.

  NOTE(review): field meanings are inferred from their names only — confirm
  against callers.
  """
  num_tokens: int  # Presumably the padded number of tokens.
  msa_size: int  # Presumably the padded number of MSA rows.
  num_chains: int  # Presumably the padded number of chains.
  num_templates: int  # Presumably the padded number of templates.
  num_atoms: int  # Presumably the padded number of atoms.
def _pad_to(
arr: np.ndarray, shape: tuple[int | None, ...], **kwargs
) -> np.ndarray:
"""Pads an array to a given shape. Wrapper around np.pad().
Args:
arr: numpy array to pad
shape: target shape, use None for axes that should stay the same
**kwargs: additional args for np.pad, e.g. constant_values=-1
Returns:
the padded array
Raises:
ValueError if arr and shape have a different number of axes.
"""
if arr.ndim != len(shape):
raise ValueError(
f'arr and shape have different number of axes. {arr.shape=}, {shape=}'
)
num_pad = []
for axis, width in enumerate(shape):
if width is None:
num_pad.append((0, 0))
else:
if width >= arr.shape[axis]:
num_pad.append((0, width - arr.shape[axis]))
else:
raise ValueError(
f'Can not pad to a smaller shape. {arr.shape=}, {shape=}'
)
padded_arr = np.pad(arr, pad_width=num_pad, **kwargs)
return padded_arr
def _unwrap(obj):
"""Unwrap an object from a zero-dim np.ndarray."""
if isinstance(obj, np.ndarray) and obj.ndim == 0:
return obj.item()
else:
return obj
@dataclasses.dataclass(frozen=True)
class Chains:
  """Parallel arrays of chain identifiers.

  As built by `_compute_asym_entity_and_sym_id`, each array holds one entry
  per chain, in order of first appearance.
  """
  chain_id: np.ndarray
  asym_id: np.ndarray
  entity_id: np.ndarray
  sym_id: np.ndarray


# Register Chains as a JAX pytree node with every field as a data field.
jax.tree_util.register_dataclass(
    Chains,
    data_fields=[f.name for f in dataclasses.fields(Chains)],
    meta_fields=[],
)
def _compute_asym_entity_and_sym_id(
    all_tokens: atom_layout.AtomLayout,
) -> Chains:
  """Compute asym_id, entity_id and sym_id for every chain.

  Chains are visited in order of first appearance in `all_tokens.chain_id`:
    * asym_id numbers chains 1, 2, ... in that order.
    * Chains whose tokens have an identical residue-name sequence share an
      entity_id; entity_ids are numbered 1, 2, ... in order of first
      appearance of each distinct sequence.
    * sym_id counts copies of an entity, starting at 1.

  Args:
    all_tokens: atom layout containing a representative atom for each token.

  Returns:
    A Chains object with one entry per chain (not per token).
  """
  chain_ids, asym_ids, entity_ids, sym_ids = [], [], [], []
  # Comma-joined residue-name sequence -> (entity_id, latest sym_id).
  entity_by_sequence = {}
  unique_chain_ids = dict.fromkeys(all_tokens.chain_id)
  for asym_id, chain_id in enumerate(unique_chain_ids, start=1):
    sequence = ','.join(all_tokens.res_name[all_tokens.chain_id == chain_id])
    if sequence in entity_by_sequence:
      entity_id, previous_sym_id = entity_by_sequence[sequence]
      sym_id = previous_sym_id + 1
    else:
      entity_id = len(entity_by_sequence) + 1
      sym_id = 1
    entity_by_sequence[sequence] = (entity_id, sym_id)
    chain_ids.append(chain_id)
    asym_ids.append(asym_id)
    entity_ids.append(entity_id)
    sym_ids.append(sym_id)
  return Chains(
      chain_id=np.array(chain_ids),
      asym_id=np.array(asym_ids),
      entity_id=np.array(entity_ids),
      sym_id=np.array(sym_ids),
  )
def tokenizer(
    flat_output_layout: atom_layout.AtomLayout,
    ccd: chemical_components.Ccd,
    max_atoms_per_token: int,
    flatten_non_standard_residues: bool,
    logging_name: str,
) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout, np.ndarray]:
  """Maps a flat atom layout to tokens for evoformer.

  Creates the evoformer tokens as one token per standard polymer residue and
  one token per ligand atom (and, when flattening, one token per atom of a
  non-standard polymer residue). The tokens are represented as the AtomLayouts
  all_tokens (1 representative atom per token) and all_token_atoms_layout
  (num_tokens, max_atoms_per_token). The atoms in a multi-atom residue token
  use the atom order of the corresponding CCD entry.

  Args:
    flat_output_layout: flat AtomLayout containing all atoms that the model
      wants to predict.
    ccd: The chemical components dictionary.
    max_atoms_per_token: number of slots per token.
    flatten_non_standard_residues: whether to flatten non-standard residues,
      i.e. whether to use one token per atom for non-standard residues.
    logging_name: logging name for debugging (usually the mmcif_id).

  Returns:
    A tuple (all_tokens, all_token_atoms_layout, standard_token_idxs) with
      all_tokens: AtomLayout shape (num_tokens,) containing one representative
        atom per token.
      all_token_atoms_layout: AtomLayout with shape
        (num_tokens, max_atoms_per_token) containing all atoms per token.
      standard_token_idxs: int32 array with one entry per token giving the
        token index that each token would have if non-standard residues were
        not flattened.
  """
  # Select the representative atom for each token.
  token_idxs = []
  single_atom_token = []
  standard_token_idxs = []
  current_standard_token_id = 0
  # Iterate over residues, and provide a group_iter over the atoms of each
  # residue.
  # NOTE: groupby only merges *consecutive* atoms with the same
  # (chain_type, chain_id, res_id), so atoms of a residue are assumed to be
  # contiguous in flat_output_layout.
  for key, group_iter in itertools.groupby(
      zip(
          flat_output_layout.chain_type,
          flat_output_layout.chain_id,
          flat_output_layout.res_id,
          flat_output_layout.res_name,
          flat_output_layout.atom_name,
          np.arange(flat_output_layout.shape[0]),
      ),
      key=lambda x: x[:3],
  ):
    # Get chain type and chain id of this residue
    chain_type, chain_id, _ = key

    # Get names and global idxs for all atoms of this residue
    _, _, _, res_names, atom_names, idxs = zip(*group_iter)

    # As of March 2023, all OTHER CHAINs in pdb are artificial nucleics.
    is_nucleic_backbone = (
        chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES
        or chain_type == mmcif_names.OTHER_CHAIN
    )
    if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES:
      res_name = res_names[0]
      if (
          flatten_non_standard_residues
          and res_name not in residue_names.PROTEIN_TYPES_WITH_UNKNOWN
          and res_name != residue_names.MSE
      ):
        # For non-standard protein residues take all atoms.
        # NOTE: This may get very large if we include hydrogens.
        token_idxs.extend(idxs)
        single_atom_token += [True] * len(idxs)
        # All atoms of the flattened residue share one standard token id.
        standard_token_idxs.extend([current_standard_token_id] * len(idxs))
      else:
        # For standard protein residues take 'CA' if it exists, else first atom.
        if 'CA' in atom_names:
          token_idxs.append(idxs[atom_names.index('CA')])
        else:
          token_idxs.append(idxs[0])
        single_atom_token += [False]
        standard_token_idxs.append(current_standard_token_id)
      # One standard token per polymer residue, flattened or not.
      current_standard_token_id += 1
    elif is_nucleic_backbone:
      res_name = res_names[0]
      if (
          flatten_non_standard_residues
          and res_name not in residue_names.NUCLEIC_TYPES_WITH_2_UNKS
      ):
        # For non-standard nucleic residues take all atoms.
        token_idxs.extend(idxs)
        single_atom_token += [True] * len(idxs)
        standard_token_idxs.extend([current_standard_token_id] * len(idxs))
      else:
        # For standard nucleic residues take C1' if it exists, else first atom.
        if "C1'" in atom_names:
          token_idxs.append(idxs[atom_names.index("C1'")])
        else:
          token_idxs.append(idxs[0])
        single_atom_token += [False]
        standard_token_idxs.append(current_standard_token_id)
      current_standard_token_id += 1
    elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
      # For non-polymers take all atoms
      token_idxs.extend(idxs)
      single_atom_token += [True] * len(idxs)
      standard_token_idxs.extend([current_standard_token_id] * len(idxs))
      # Ligand atoms are one standard token each, so advance by the atom count.
      current_standard_token_id += len(idxs)
    else:
      # Chain type that we don't handle yet.
      logging.warning(
          '%s: ignoring chain %s with chain type %s.',
          logging_name,
          chain_id,
          chain_type,
      )

  assert len(token_idxs) == len(single_atom_token)
  assert len(token_idxs) == len(standard_token_idxs)
  standard_token_idxs = np.array(standard_token_idxs, dtype=np.int32)

  # Create the list of all tokens, represented as a flat AtomLayout with 1
  # representative atom per token.
  all_tokens = flat_output_layout[token_idxs]

  # Create the 2D atoms_per_token layout
  num_tokens = all_tokens.shape[0]

  # Target lists.
  target_atom_names = []
  target_atom_elements = []
  target_res_ids = []
  target_res_names = []
  target_chain_ids = []
  target_chain_types = []

  # uids of all atoms in the flat layout, to check whether the dense atoms
  # exist -- This is necessary for terminal atoms (e.g. 'OP3' or 'OXT')
  all_atoms_uids = set(
      zip(
          flat_output_layout.chain_id,
          flat_output_layout.res_id,
          flat_output_layout.atom_name,
      )
  )

  for idx, single_atom in enumerate(single_atom_token):
    if not single_atom:
      # Standard protein and nucleic residues have many atoms per token
      chain_id = all_tokens.chain_id[idx]
      res_id = all_tokens.res_id[idx]
      res_name = all_tokens.res_name[idx]
      atom_names = []
      atom_elements = []

      res_atoms = struc_chem_comps.get_all_atoms_in_entry(
          ccd=ccd, res_name=res_name
      )
      atom_names_elements = list(
          zip(
              res_atoms['_chem_comp_atom.atom_id'],
              res_atoms['_chem_comp_atom.type_symbol'],
              strict=True,
          )
      )

      for atom_name, atom_element in atom_names_elements:
        # Remove hydrogens if they are not in flat layout.
        if atom_element in ['H', 'D'] and (
            (chain_id, res_id, atom_name) not in all_atoms_uids
        ):
          continue
        elif (chain_id, res_id, atom_name) in all_atoms_uids:
          atom_names.append(atom_name)
          atom_elements.append(atom_element)
        # Leave spaces for OXT etc.
        else:
          atom_names.append('')
          atom_elements.append('')

      if len(atom_names) > max_atoms_per_token:
        logging.warning(
            'Atom list for chain %s '
            'residue %s %s is too long and will be truncated: '
            '%s to the max atoms limit %s. Dropped atoms: %s',
            chain_id,
            res_id,
            res_name,
            len(atom_names),
            max_atoms_per_token,
            list(
                zip(
                    atom_names[max_atoms_per_token:],
                    atom_elements[max_atoms_per_token:],
                    strict=True,
                )
            ),
        )
        atom_names = atom_names[:max_atoms_per_token]
        atom_elements = atom_elements[:max_atoms_per_token]

      # Pad the per-token atom slots with empty names up to max_atoms_per_token.
      num_pad = max_atoms_per_token - len(atom_names)
      atom_names.extend([''] * num_pad)
      atom_elements.extend([''] * num_pad)

    else:
      # ligands have only 1 atom per token
      padding = [''] * (max_atoms_per_token - 1)
      atom_names = [all_tokens.atom_name[idx]] + padding
      atom_elements = [all_tokens.atom_element[idx]] + padding

    # Append the atoms to the target lists.
    target_atom_names.append(atom_names)
    target_atom_elements.append(atom_elements)
    target_res_names.append([all_tokens.res_name[idx]] * max_atoms_per_token)
    target_res_ids.append([all_tokens.res_id[idx]] * max_atoms_per_token)
    target_chain_ids.append([all_tokens.chain_id[idx]] * max_atoms_per_token)
    target_chain_types.append(
        [all_tokens.chain_type[idx]] * max_atoms_per_token
    )

  # Make sure to get the right shape also for 0 tokens
  trg_shape = (num_tokens, max_atoms_per_token)
  all_token_atoms_layout = atom_layout.AtomLayout(
      atom_name=np.array(target_atom_names, dtype=object).reshape(trg_shape),
      atom_element=np.array(target_atom_elements, dtype=object).reshape(
          trg_shape
      ),
      res_name=np.array(target_res_names, dtype=object).reshape(trg_shape),
      res_id=np.array(target_res_ids, dtype=int).reshape(trg_shape),
      chain_id=np.array(target_chain_ids, dtype=object).reshape(trg_shape),
      chain_type=np.array(target_chain_types, dtype=object).reshape(trg_shape),
  )

  return all_tokens, all_token_atoms_layout, standard_token_idxs
@dataclasses.dataclass(frozen=True)
class MSA:
  """Dataclass containing MSA."""

  # Residue types of the MSA rows (cast to int8 in compute_features).
  rows: xnp_ndarray
  # True for real (non-padding) MSA entries.
  mask: xnp_ndarray
  # Deletion counts, clipped to the int8 range in compute_features.
  deletion_matrix: xnp_ndarray
  # Occurrence of each residue type along the sequence, averaged over MSA rows.
  profile: xnp_ndarray
  # Occurrence of deletions along the sequence, averaged over MSA rows.
  deletion_mean: xnp_ndarray
  # Number of MSA alignments.
  num_alignments: xnp_ndarray

  @classmethod
  def compute_features(
      cls,
      *,
      all_tokens: atom_layout.AtomLayout,
      standard_token_idxs: np.ndarray,
      padding_shapes: PaddingShapes,
      fold_input: folding_input.Input,
      logging_name: str,
      max_paired_sequence_per_species: int,
      resolve_msa_overlaps: bool = True,
  ) -> Self:
    """Compute the msa features.

    For every chain, builds unpaired and paired MSA features from the chain's
    MSAs in `fold_input` (standard polymer chains) or from empty MSAs (other
    chains). It then adds profile features and pairs across chains when there
    is more than one distinct protein sequence. Next it crops MSA rows to
    `padding_shapes.msa_size` and merges chains. Finally it gathers MSA columns
    per token with `standard_token_idxs` and pads.

    Args:
      all_tokens: flat AtomLayout with one representative atom per token.
      standard_token_idxs: for each token, the MSA column it reads. Tokens of
        flattened non-standard residues repeat the same column.
      padding_shapes: padding shapes; `msa_size` and `num_tokens` are used.
      fold_input: folding input providing per-chain MSAs.
      logging_name: name used for the intermediate structure.
      max_paired_sequence_per_species: passed to MSA pairing as
        `max_hits_per_species`.
      resolve_msa_overlaps: whether to call
        `msa_pairing.deduplicate_unpaired_sequences` after pairing.

    Returns:
      An MSA object. rows, mask and deletion_matrix are padded to
      (msa_size, num_tokens), profile and deletion_mean are padded along
      tokens, and num_alignments is the row count before padding.
    """
    seen_entities = {}

    substruct = atom_layout.make_structure(
        flat_layout=all_tokens,
        atom_coords=np.zeros(all_tokens.shape + (3,)),
        name=logging_name,
    )
    # Pairing is only needed when there are at least two distinct protein
    # sequences.
    prot = substruct.filter_to_entity_type(protein=True)
    num_unique_chains = len(set(prot.chain_single_letter_sequence().values()))
    need_msa_pairing = num_unique_chains > 1

    np_chains_list = []
    input_chains_by_id = {chain.id: chain for chain in fold_input.chains}
    nonempty_chain_ids = set(all_tokens.chain_id)
    for asym_id, chain_info in enumerate(substruct.iter_chains(), start=1):
      b_chain_id = chain_info['chain_id']
      chain_type = chain_info['chain_type']
      chain = input_chains_by_id[b_chain_id]

      # Generalised "sequence" for ligands (can't trust residue name)
      chain_tokens = all_tokens[all_tokens.chain_id == b_chain_id]
      assert chain_tokens.res_name is not None
      three_letter_sequence = ','.join(chain_tokens.res_name.tolist())
      chain_num_tokens = len(chain_tokens.atom_name)
      if chain_type in mmcif_names.POLYMER_CHAIN_TYPES:
        sequence = substruct.chain_single_letter_sequence()[b_chain_id]
        if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES:
          # Only allow nucleic residue types for nucleic chains (can have some
          # protein residues in e.g. tRNA, but that causes MSA search failures).
          # Replace non nucleic residue types by UNK_NUCLEIC.
          nucleic_types_one_letter = (
              residue_names.DNA_TYPES_ONE_LETTER
              + residue_names.RNA_TYPES_ONE_LETTER_WITH_UNKNOWN
          )
          sequence = ''.join([
              base
              if base in nucleic_types_one_letter
              else residue_names.UNK_NUCLEIC_ONE_LETTER
              for base in sequence
          ])
      else:
        # Non-polymers: one placeholder residue per token.
        sequence = 'X' * chain_num_tokens

      skip_chain = (
          chain_type not in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES
          or len(sequence) <= 4
          or b_chain_id not in nonempty_chain_ids
      )
      # NOTE(review): `seen_entities` is never written in this method, so
      # this always yields entity_id == 1 (Templates.compute_features does
      # record entities). Confirm whether consumers of feats['entity_id']
      # depend on this before changing it.
      if three_letter_sequence in seen_entities:
        entity_id = seen_entities[three_letter_sequence]
      else:
        entity_id = len(seen_entities) + 1

      if chain_type in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES:
        unpaired_a3m = ''
        paired_a3m = ''
        if not skip_chain:
          # Only protein chains contribute paired MSAs.
          if need_msa_pairing and isinstance(chain, folding_input.ProteinChain):
            paired_a3m = chain.paired_msa
          if isinstance(
              chain, folding_input.RnaChain | folding_input.ProteinChain
          ):
            unpaired_a3m = chain.unpaired_msa
        # If we generated the MSA ourselves, it is already deduplicated. If it
        # is user-provided, keep it as is to prevent destroying desired pairing.
        unpaired_msa = msa_module.Msa.from_a3m(
            query_sequence=sequence,
            chain_poly_type=chain_type,
            a3m=unpaired_a3m,
            deduplicate=False,
        )
        paired_msa = msa_module.Msa.from_a3m(
            query_sequence=sequence,
            chain_poly_type=chain_type,
            a3m=paired_a3m,
            deduplicate=False,
        )
      else:
        # Non-standard chains get an empty, all-gap MSA of matching length.
        unpaired_msa = msa_module.Msa.from_empty(
            query_sequence='-' * len(sequence),
            chain_poly_type=mmcif_names.PROTEIN_CHAIN,
        )
        paired_msa = msa_module.Msa.from_empty(
            query_sequence='-' * len(sequence),
            chain_poly_type=mmcif_names.PROTEIN_CHAIN,
        )

      msa_features = unpaired_msa.featurize()
      all_seqs_msa_features = paired_msa.featurize()

      # Paired features are stored under the same names with an '_all_seq'
      # suffix.
      msa_features = msa_features | {
          f'{k}_all_seq': v for k, v in all_seqs_msa_features.items()
      }
      feats = msa_features
      feats['chain_id'] = b_chain_id
      feats['asym_id'] = np.full(chain_num_tokens, asym_id)
      feats['entity_id'] = entity_id
      np_chains_list.append(feats)

    # Add profile features to each chain.
    for chain in np_chains_list:
      chain.update(
          data3.get_profile_features(chain['msa'], chain['deletion_matrix'])
      )

    # Allow 50% of the MSA to come from MSA pairing.
    max_paired_sequences = padding_shapes.msa_size // 2
    if need_msa_pairing:
      np_chains_list = list(map(dict, np_chains_list))
      np_chains_list = msa_pairing.create_paired_features(
          np_chains_list,
          max_paired_sequences=max_paired_sequences,
          nonempty_chain_ids=nonempty_chain_ids,
          max_hits_per_species=max_paired_sequence_per_species,
      )
      if resolve_msa_overlaps:
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
            np_chains_list
        )

    # Remove all gapped rows from all seqs.
    nonempty_asym_ids = []
    for chain in np_chains_list:
      if chain['chain_id'] in nonempty_chain_ids:
        nonempty_asym_ids.append(chain['asym_id'][0])
    if 'msa_all_seq' in np_chains_list[0]:
      np_chains_list = msa_pairing.remove_all_gapped_rows_from_all_seqs(
          np_chains_list, asym_ids=nonempty_asym_ids
      )

    # Crop MSA rows.
    cropped_chains_list = []
    for chain in np_chains_list:
      unpaired_msa_size, paired_msa_size = (
          msa_pairing.choose_paired_unpaired_msa_crop_sizes(
              unpaired_msa=chain['msa'],
              paired_msa=chain.get('msa_all_seq'),
              total_msa_crop_size=padding_shapes.msa_size,
              max_paired_sequences=max_paired_sequences,
          )
      )
      cropped_chain = {
          'asym_id': chain['asym_id'],
          'chain_id': chain['chain_id'],
          'profile': chain['profile'],
          'deletion_mean': chain['deletion_mean'],
      }
      for feat in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES:
        if feat in chain:
          cropped_chain[feat] = chain[feat][:unpaired_msa_size]
        if feat + '_all_seq' in chain:
          cropped_chain[feat + '_all_seq'] = chain[feat + '_all_seq'][
              :paired_msa_size
          ]
      cropped_chains_list.append(cropped_chain)

    # Merge Chains.
    # Make sure the chain order is unaltered before slicing with tokens.
    curr_chain_order = [chain['chain_id'] for chain in cropped_chains_list]
    orig_chain_order = [chain['chain_id'] for chain in substruct.iter_chains()]
    assert curr_chain_order == orig_chain_order
    np_example = {
        'asym_id': np.concatenate(
            [c['asym_id'] for c in cropped_chains_list], axis=0
        ),
    }
    for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES:
      for feat in [feature, feature + '_all_seq']:
        if feat in cropped_chains_list[0]:
          np_example[feat] = merging_features.merge_msa_features(
              feat, cropped_chains_list
          )
    for feature in ['profile', 'deletion_mean']:
      feature_list = [c[feature] for c in cropped_chains_list]
      np_example[feature] = np.concatenate(feature_list, axis=0)

    # Crop MSA rows to maximum size given by chains participating in the crop.
    max_allowed_unpaired = max([
        len(chain['msa'])
        for chain in cropped_chains_list
        if chain['asym_id'][0] in nonempty_asym_ids
    ])
    np_example['msa'] = np_example['msa'][:max_allowed_unpaired]
    if 'msa_all_seq' in np_example:
      max_allowed_paired = max([
          len(chain['msa_all_seq'])
          for chain in cropped_chains_list
          if chain['asym_id'][0] in nonempty_asym_ids
      ])
      np_example['msa_all_seq'] = np_example['msa_all_seq'][:max_allowed_paired]

    np_example = merging_features.merge_paired_and_unpaired_msa(np_example)

    # Crop MSA residues. Need to use the standard token indices, since msa does
    # not expand non-standard residues. This means that for expanded residues,
    # we get repeated msa columns.
    new_cropping_idxs = standard_token_idxs
    for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES:
      if feature in np_example:
        np_example[feature] = np_example[feature][:, new_cropping_idxs].copy()
    for feature in ['profile', 'deletion_mean']:
      np_example[feature] = np_example[feature][new_cropping_idxs]

    # Make MSA mask.
    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)

    # Count MSA size before padding.
    num_alignments = np_example['msa'].shape[0]

    # Pad:
    msa_size, num_tokens = padding_shapes.msa_size, padding_shapes.num_tokens

    def safe_cast_int8(x):
      # Clip to the int8 range first so the cast cannot wrap around.
      return np.clip(x, np.iinfo(np.int8).min, np.iinfo(np.int8).max).astype(
          np.int8
      )

    return MSA(
        rows=_pad_to(safe_cast_int8(np_example['msa']), (msa_size, num_tokens)),
        mask=_pad_to(
            np_example['msa_mask'].astype(bool), (msa_size, num_tokens)
        ),
        # deletion_matrix may be out of int8 range, but we mostly care about
        # small values since we arctan it in the model.
        deletion_matrix=_pad_to(
            safe_cast_int8(np_example['deletion_matrix']),
            (msa_size, num_tokens),
        ),
        profile=_pad_to(np_example['profile'], (num_tokens, None)),
        deletion_mean=_pad_to(np_example['deletion_mean'], (num_tokens,)),
        num_alignments=np.array(num_alignments, dtype=np.int32),
    )

  def index_msa_rows(self, indices: xnp_ndarray) -> Self:
    """Returns a new MSA keeping only the MSA rows selected by 1-D `indices`.

    Per-token features (profile, deletion_mean) and num_alignments are shared
    unchanged.
    """
    assert indices.ndim == 1
    return MSA(
        rows=self.rows[indices, :],
        mask=self.mask[indices, :],
        deletion_matrix=self.deletion_matrix[indices, :],
        profile=self.profile,
        deletion_mean=self.deletion_mean,
        num_alignments=self.num_alignments,
    )

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Builds an MSA from a batch dict; inverse of `as_data_dict`."""
    output = cls(
        rows=batch['msa'],
        mask=batch['msa_mask'],
        deletion_matrix=batch['deletion_matrix'],
        profile=batch['profile'],
        deletion_mean=batch['deletion_mean'],
        num_alignments=batch['num_alignments'],
    )
    return output

  def as_data_dict(self) -> BatchDict:
    """Returns the fields keyed by their batch-dict names."""
    return {
        'msa': self.rows,
        'msa_mask': self.mask,
        'deletion_matrix': self.deletion_matrix,
        'profile': self.profile,
        'deletion_mean': self.deletion_mean,
        'num_alignments': self.num_alignments,
    }


# Register MSA as a JAX pytree with every field treated as data.
jax.tree_util.register_dataclass(
    MSA,
    data_fields=[f.name for f in dataclasses.fields(MSA)],
    meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class Templates:
  """Dataclass containing templates."""

  # aatype of templates, int32 w shape [num_templates, num_res]
  aatype: xnp_ndarray
  # atom positions of templates, float32 w shape [num_templates, num_res, 24, 3]
  atom_positions: xnp_ndarray
  # atom mask of templates, bool w shape [num_templates, num_res, 24]
  atom_mask: xnp_ndarray

  @classmethod
  def compute_features(
      cls,
      all_tokens: atom_layout.AtomLayout,
      standard_token_idxs: np.ndarray,
      padding_shapes: PaddingShapes,
      fold_input: folding_input.Input,
      max_templates: int,
      logging_name: str,
  ) -> Self:
    """Compute the template features.

    Protein chains with more than 4 tokens get features from their input
    templates. Every other chain gets empty template features. Features are
    computed once per entity (chains with the same residue-name sequence) and
    padded to `max_templates`. They are then concatenated over chains along the
    residue axis, gathered per token with `standard_token_idxs`, and padded to
    `padding_shapes.num_tokens` tokens.

    Args:
      all_tokens: flat AtomLayout with one representative atom per token.
      standard_token_idxs: for each token, the index of the (unflattened)
        residue column it reads. Tokens of flattened non-standard residues
        share a column.
      padding_shapes: padding shapes; `num_tokens` is used.
      fold_input: folding input providing per-chain templates.
      max_templates: number of template slots to keep/pad to.
      logging_name: name used for the intermediate structure.

    Returns:
      A Templates object padded along the token dimension.
    """
    seen_entities = {}
    polymer_entity_features = {True: {}, False: {}}

    substruct = atom_layout.make_structure(
        flat_layout=all_tokens,
        atom_coords=np.zeros(all_tokens.shape + (3,)),
        name=logging_name,
    )
    np_chains_list = []

    input_chains_by_id = {chain.id: chain for chain in fold_input.chains}

    nonempty_chain_ids = set(all_tokens.chain_id)
    for chain_info in substruct.iter_chains():
      chain_id = chain_info['chain_id']
      chain_type = chain_info['chain_type']
      chain = input_chains_by_id[chain_id]

      # Generalised "sequence" for ligands (can't trust residue name)
      chain_mask = all_tokens.chain_id == chain_id
      chain_tokens = all_tokens[chain_mask]
      assert chain_tokens.res_name is not None
      three_letter_sequence = ','.join(chain_tokens.res_name.tolist())
      chain_num_tokens = len(chain_tokens.atom_name)
      # Number of columns this chain occupies in the index space addressed by
      # `standard_token_idxs`. This is smaller than chain_num_tokens when
      # non-standard residues of the chain were flattened into one token per
      # atom, because those tokens share a single standard index. Using
      # chain_num_tokens here would add extra columns and shift the template
      # features of every following chain.
      chain_num_standard_tokens = len(np.unique(standard_token_idxs[chain_mask]))

      # Don't compute features for chains not included in the crop, or ligands.
      skip_chain = (
          chain_type != mmcif_names.PROTEIN_CHAIN
          or chain_num_tokens <= 4  # not cache filled
          or chain_id not in nonempty_chain_ids
      )

      if three_letter_sequence in seen_entities:
        entity_id = seen_entities[three_letter_sequence]
      else:
        entity_id = len(seen_entities) + 1

      if entity_id not in polymer_entity_features[skip_chain]:
        if skip_chain:
          template_features = data3.empty_template_features(
              chain_num_standard_tokens
          )
        else:
          assert isinstance(chain, folding_input.ProteinChain)

          sorted_features = []
          for template in chain.templates:
            struc = structure.from_mmcif(
                template.mmcif,
                fix_mse_residues=True,
                fix_arginines=True,
                include_bonds=False,
                include_water=False,
                include_other=True,  # For non-standard polymer chains.
            )
            hit_features = templates.get_polymer_features(
                chain=struc,
                chain_poly_type=mmcif_names.PROTEIN_CHAIN,
                query_sequence_length=len(chain.sequence),
                query_to_hit_mapping=dict(template.query_to_template_map),
            )
            sorted_features.append(hit_features)

          template_features = templates.package_template_features(
              hit_features=sorted_features,
              include_ligand_features=False,
          )

          template_features = data3.fix_template_features(
              template_features=template_features, num_res=len(chain.sequence)
          )

        template_features = _reduce_template_features(
            template_features, max_templates
        )
        polymer_entity_features[skip_chain][entity_id] = template_features

      seen_entities[three_letter_sequence] = entity_id
      feats = polymer_entity_features[skip_chain][entity_id].copy()
      feats['chain_id'] = chain_id
      np_chains_list.append(feats)

    # We pad the num_templates dimension before merging, so that different
    # chains can be concatenated on the num_res dimension. Masking will be
    # applied so that each chains templates can't see each other.
    for chain in np_chains_list:
      chain['template_aatype'] = _pad_to(
          chain['template_aatype'], (max_templates, None)
      )
      chain['template_atom_positions'] = _pad_to(
          chain['template_atom_positions'], (max_templates, None, None, None)
      )
      chain['template_atom_mask'] = _pad_to(
          chain['template_atom_mask'], (max_templates, None, None)
      )

    # Merge on token dimension.
    np_example = {
        ft: np.concatenate([c[ft] for c in np_chains_list], axis=1)
        for ft in np_chains_list[0]
        if ft in data_constants.TEMPLATE_FEATURES
    }

    # Crop template data. Need to use the standard token indices, since msa does
    # not expand non-standard residues. This means that for expanded residues,
    # we get repeated template information.
    for feature_name, v in np_example.items():
      np_example[feature_name] = v[:max_templates, standard_token_idxs, ...]

    # Pad along the token dimension.
    templates_features = Templates(
        aatype=_pad_to(
            np_example['template_aatype'], (None, padding_shapes.num_tokens)
        ),
        atom_positions=_pad_to(
            np_example['template_atom_positions'],
            (None, padding_shapes.num_tokens, None, None),
        ),
        atom_mask=_pad_to(
            np_example['template_atom_mask'].astype(bool),
            (None, padding_shapes.num_tokens, None),
        ),
    )
    return templates_features

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Make Template from batch dictionary."""
    return cls(
        aatype=batch['template_aatype'],
        atom_positions=batch['template_atom_positions'],
        atom_mask=batch['template_atom_mask'],
    )

  def as_data_dict(self) -> BatchDict:
    """Returns the fields keyed by their batch-dict names."""
    return {
        'template_aatype': self.aatype,
        'template_atom_positions': self.atom_positions,
        'template_atom_mask': self.atom_mask,
    }


jax.tree_util.register_dataclass(
    Templates,
    data_fields=[f.name for f in dataclasses.fields(Templates)],
    meta_fields=[],
)
def _reduce_template_features(
    template_features: data3.FeatureDict,
    max_templates: int,
) -> data3.FeatureDict:
  """Keeps at most `max_templates` templates and only the template features.

  Args:
    template_features: feature dict whose 'template_aatype' leading dimension
      is the number of templates.
    max_templates: number of leading templates to keep.

  Returns:
    A new dict containing only keys in data_constants.TEMPLATE_FEATURES or
    'template_release_timestamp'. Each value is boolean-masked along its
    leading axis to the first `max_templates` templates.
  """
  wanted_keys = data_constants.TEMPLATE_FEATURES + (
      'template_release_timestamp',
  )
  total_templates = template_features['template_aatype'].shape[0]
  keep = np.arange(total_templates) < max_templates

  reduced = {}
  for name, values in template_features.items():
    if name in wanted_keys:
      reduced[name] = values[keep]
  return reduced
@dataclasses.dataclass(frozen=True)
class TokenFeatures:
  """Per-token features: indices, residue types, chain ids and token types."""

  residue_index: xnp_ndarray
  token_index: xnp_ndarray
  aatype: xnp_ndarray
  mask: xnp_ndarray
  seq_length: xnp_ndarray

  # Chain symmetry identifiers
  # for an A3B2 stoichiometry the meaning of these features is as follows:
  # asym_id:    1 2 3 4 5
  # entity_id:  1 1 1 2 2
  # sym_id:     1 2 3 1 2
  asym_id: xnp_ndarray
  entity_id: xnp_ndarray
  sym_id: xnp_ndarray

  # token type features
  is_protein: xnp_ndarray
  is_rna: xnp_ndarray
  is_dna: xnp_ndarray
  is_ligand: xnp_ndarray
  is_nonstandard_polymer_chain: xnp_ndarray
  is_water: xnp_ndarray

  @classmethod
  def compute_features(
      cls,
      all_tokens: atom_layout.AtomLayout,
      padding_shapes: PaddingShapes,
  ) -> Self:
    """Compute the per-token features.

    Args:
      all_tokens: flat AtomLayout with one representative atom per token.
      padding_shapes: padding shapes; per-token arrays are padded to
        `padding_shapes.num_tokens`.

    Returns:
      A TokenFeatures object.

    Raises:
      ValueError: if a token's chain type is neither a polymer nor a
        non-polymer chain type.
    """

    def pad_tokens(values):
      # Pads a 1-D per-token array to the padded token count.
      return _pad_to(values, (padding_shapes.num_tokens,))

    def token_aatype(res_name, chain_type):
      # Maps a token's residue name to its aatype index. Non-polymer tokens
      # are treated as unknown residues.
      if chain_type in mmcif_names.POLYMER_CHAIN_TYPES:
        res_name = mmcif_names.fix_non_standard_polymer_res(
            res_name=res_name, chain_type=chain_type
        )
        if (
            chain_type == mmcif_names.DNA_CHAIN
            and res_name == residue_names.UNK_DNA
        ):
          res_name = residue_names.UNK_NUCLEIC_ONE_LETTER
      elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
        res_name = residue_names.UNK
      else:
        raise ValueError(f'Chain type {chain_type} not polymer or ligand.')
      return residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP[res_name]

    residue_index = all_tokens.res_id.astype(np.int32)
    # 1-based token position.
    token_index = np.arange(1, len(all_tokens.atom_name) + 1).astype(np.int32)
    aatype = np.array(
        [
            token_aatype(res_name, chain_type)
            for res_name, chain_type in zip(
                all_tokens.res_name, all_tokens.chain_type
            )
        ],
        dtype=np.int32,
    )
    mask = np.ones(all_tokens.shape[0], dtype=bool)

    chains = _compute_asym_entity_and_sym_id(all_tokens)

    def per_token(per_chain_values):
      # Broadcasts a per-chain value to every token of that chain.
      lookup = dict(zip(chains.chain_id, per_chain_values))
      return np.array(
          [lookup[token_chain] for token_chain in all_tokens.chain_id],
          dtype=np.int32,
      )

    asym_id = per_token(chains.asym_id)
    entity_id = per_token(chains.entity_id)
    sym_id = per_token(chains.sym_id)

    seq_length = np.array(all_tokens.shape[0], dtype=np.int32)

    token_chain_types = all_tokens.chain_type
    is_protein = token_chain_types == mmcif_names.PROTEIN_CHAIN
    is_rna = token_chain_types == mmcif_names.RNA_CHAIN
    is_dna = token_chain_types == mmcif_names.DNA_CHAIN
    is_ligand = np.isin(
        token_chain_types, list(mmcif_names.LIGAND_CHAIN_TYPES)
    )
    # Everything that is neither a non-polymer nor a standard polymer chain.
    non_polymer_or_standard_types = list(
        mmcif_names.NON_POLYMER_CHAIN_TYPES
    ) + list(mmcif_names.STANDARD_POLYMER_CHAIN_TYPES)
    is_nonstandard_polymer_chain = np.isin(
        token_chain_types, non_polymer_or_standard_types, invert=True
    )
    is_water = token_chain_types == mmcif_names.WATER

    return TokenFeatures(
        residue_index=pad_tokens(residue_index),
        token_index=pad_tokens(token_index),
        aatype=pad_tokens(aatype),
        mask=pad_tokens(mask),
        asym_id=pad_tokens(asym_id),
        entity_id=pad_tokens(entity_id),
        sym_id=pad_tokens(sym_id),
        seq_length=seq_length,
        is_protein=pad_tokens(is_protein),
        is_rna=pad_tokens(is_rna),
        is_dna=pad_tokens(is_dna),
        is_ligand=pad_tokens(is_ligand),
        is_nonstandard_polymer_chain=pad_tokens(is_nonstandard_polymer_chain),
        is_water=pad_tokens(is_water),
    )

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Builds TokenFeatures from a batch dict; inverse of `as_data_dict`."""
    return cls(
        residue_index=batch['residue_index'],
        token_index=batch['token_index'],
        aatype=batch['aatype'],
        mask=batch['seq_mask'],
        entity_id=batch['entity_id'],
        asym_id=batch['asym_id'],
        sym_id=batch['sym_id'],
        seq_length=batch['seq_length'],
        is_protein=batch['is_protein'],
        is_rna=batch['is_rna'],
        is_dna=batch['is_dna'],
        is_ligand=batch['is_ligand'],
        is_nonstandard_polymer_chain=batch['is_nonstandard_polymer_chain'],
        is_water=batch['is_water'],
    )

  def as_data_dict(self) -> BatchDict:
    """Returns the fields keyed by their batch-dict names."""
    return dict(
        residue_index=self.residue_index,
        token_index=self.token_index,
        aatype=self.aatype,
        seq_mask=self.mask,
        entity_id=self.entity_id,
        asym_id=self.asym_id,
        sym_id=self.sym_id,
        seq_length=self.seq_length,
        is_protein=self.is_protein,
        is_rna=self.is_rna,
        is_dna=self.is_dna,
        is_ligand=self.is_ligand,
        is_nonstandard_polymer_chain=self.is_nonstandard_polymer_chain,
        is_water=self.is_water,
    )


jax.tree_util.register_dataclass(
    TokenFeatures,
    data_fields=[field.name for field in dataclasses.fields(TokenFeatures)],
    meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class PredictedStructureInfo:
  """Contains information necessary to work with predicted structure."""

  # Per-token, per-slot mask of non-empty atom slots, padded along tokens.
  atom_mask: xnp_ndarray
  # Slot index of each token's representative atom (0 when not found).
  residue_center_index: xnp_ndarray

  @classmethod
  def compute_features(
      cls,
      all_tokens: atom_layout.AtomLayout,
      all_token_atoms_layout: atom_layout.AtomLayout,
      padding_shapes: PaddingShapes,
  ) -> Self:
    """Compute the PredictedStructureInfo features.

    Args:
      all_tokens: flat AtomLayout with 1 representative atom per token, shape
        (num_tokens,)
      all_token_atoms_layout: AtomLayout for all atoms per token, shape
        (num_tokens, max_atoms_per_token)
      padding_shapes: padding shapes.

    Returns:
      A PredictedStructureInfo object.
    """
    # Empty atom-name slots are padding and become False in the mask.
    atom_mask = _pad_to(
        all_token_atoms_layout.atom_name.astype(bool),
        (padding_shapes.num_tokens, None),
    )

    residue_center_index = np.zeros(padding_shapes.num_tokens, dtype=np.int32)
    for token_idx in range(all_tokens.shape[0]):
      center_atom_name = all_tokens.atom_name[token_idx]
      slot_names = list(all_token_atoms_layout.atom_name[token_idx, :])
      try:
        center_slot = slot_names.index(center_atom_name)
      except ValueError:
        # Representative atoms can be missing if cropping the number of atoms
        # per residue; fall back to the first slot.
        logging.warning(
            'The representative atom in all_tokens (%s) is not in '
            'all_token_atoms_layout (%s)',
            all_tokens[token_idx : token_idx + 1],
            all_token_atoms_layout[token_idx, :],
        )
        center_slot = 0
      residue_center_index[token_idx] = center_slot

    return cls(atom_mask=atom_mask, residue_center_index=residue_center_index)

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Builds the object from a batch dict; inverse of `as_data_dict`."""
    mask = batch['pred_dense_atom_mask']
    center_index = batch['residue_center_index']
    return cls(atom_mask=mask, residue_center_index=center_index)

  def as_data_dict(self) -> BatchDict:
    """Returns the fields keyed by their batch-dict names."""
    return dict(
        pred_dense_atom_mask=self.atom_mask,
        residue_center_index=self.residue_center_index,
    )


jax.tree_util.register_dataclass(
    PredictedStructureInfo,
    data_fields=[
        field.name for field in dataclasses.fields(PredictedStructureInfo)
    ],
    meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo:
"""Contains information about polymer-ligand bonds."""
tokens_to_polymer_ligand_bonds: atom_layout.GatherInfo
# Gather indices to convert from cropped dense atom layout to bonds layout
# (num_tokens, 2)
token_atoms_to_bonds: atom_layout.GatherInfo
@classmethod
def compute_features(
cls,
all_tokens: atom_layout.AtomLayout,
all_token_atoms_layout: atom_layout.AtomLayout,
bond_layout: atom_layout.AtomLayout | None,
padding_shapes: PaddingShapes,
) -> Self:
"""Computes the InterChainBondInfo features.
Args:
all_tokens: AtomLayout for tokens; shape (num_tokens,).
all_token_atoms_layout: Atom Layout for all atoms (num_tokens,
max_atoms_per_token)
bond_layout: Bond layout for polymer-ligand bonds.
padding_shapes: Padding shapes.
Returns:
A PolymerLigandBondInfo object.
"""
if bond_layout is not None:
# Must convert to list before calling np.isin, will not work raw.
peptide_types = list(mmcif_names.PEPTIDE_CHAIN_TYPES)
nucleic_types = list(mmcif_names.NUCLEIC_ACID_CHAIN_TYPES) + [
mmcif_names.OTHER_CHAIN
]
# These atom renames are so that we can use the atom layout code with
# all_tokens, which only has a single atom per token.
atom_names = bond_layout.atom_name.copy()
atom_names[np.isin(bond_layout.chain_type, peptide_types)] = 'CA'
atom_names[np.isin(bond_layout.chain_type, nucleic_types)] = "C1'"
adjusted_bond_layout = atom_layout.AtomLayout(
atom_name=atom_names,
res_id=bond_layout.res_id,
chain_id=bond_layout.chain_id,
chain_type=bond_layout.chain_type,
)
# Remove bonds that are not in the crop.
cropped_tokens_to_bonds = atom_layout.compute_gather_idxs(
source_layout=all_tokens, target_layout=adjusted_bond_layout
)
bond_is_in_crop = np.all(
cropped_tokens_to_bonds.gather_mask, axis=1
).astype(bool)
adjusted_bond_layout = adjusted_bond_layout[bond_is_in_crop, :]
else:
# Create layout with correct shape when bond_layout is None.
s = (0, 2)
adjusted_bond_layout = atom_layout.AtomLayout(
atom_name=np.array([], dtype=object).reshape(s),
res_id=np.array([], dtype=int).reshape(s),
chain_id=np.array([], dtype=object).reshape(s),
)
adjusted_bond_layout = adjusted_bond_layout.copy_and_pad_to(
(padding_shapes.num_tokens, 2)
)
tokens_to_polymer_ligand_bonds = atom_layout.compute_gather_idxs(
source_layout=all_tokens, target_layout=adjusted_bond_layout
)
# Stuff for computing the bond loss.
if bond_layout is not None:
# Pad to num_tokens (hoping that there are never more bonds than tokens).
padded_bond_layout = bond_layout.copy_and_pad_to(
(padding_shapes.num_tokens, 2)
)
token_atoms_to_bonds = atom_layout.compute_gather_idxs(
source_layout=all_token_atoms_layout, target_layout=padded_bond_layout
)
else:
token_atoms_to_bonds = atom_layout.GatherInfo(
gather_idxs=np.zeros((padding_shapes.num_tokens, 2), dtype=int),
gather_mask=np.zeros((padding_shapes.num_tokens, 2), dtype=bool),
input_shape=np.array((
padding_shapes.num_tokens,
all_token_atoms_layout.shape[1],
)),
)
return cls(
tokens_to_polymer_ligand_bonds=tokens_to_polymer_ligand_bonds,
token_atoms_to_bonds=token_atoms_to_bonds,
)
@classmethod
def from_data_dict(cls, batch: BatchDict) -> Self:
return cls(
tokens_to_polymer_ligand_bonds=atom_layout.GatherInfo.from_dict(
batch, key_prefix='tokens_to_polymer_ligand_bonds'
),
token_atoms_to_bonds=atom_layout.GatherInfo.from_dict(
batch, key_prefix='token_atoms_to_polymer_ligand_bonds'
),
)
def as_data_dict(self) -> BatchDict:
  """Flattens both GatherInfo fields into one batch dictionary.

  The token-level gather info is stored under the
  'tokens_to_polymer_ligand_bonds' prefix and the token-atom gather info under
  'token_atoms_to_polymer_ligand_bonds'. On a key collision the latter wins.
  """
  flat = dict(
      self.tokens_to_polymer_ligand_bonds.as_dict(
          key_prefix='tokens_to_polymer_ligand_bonds'
      )
  )
  flat.update(
      self.token_atoms_to_bonds.as_dict(
          key_prefix='token_atoms_to_polymer_ligand_bonds'
      )
  )
  return flat
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    PolymerLigandBondInfo,
    meta_fields=[],
    data_fields=[
        field.name for field in dataclasses.fields(PolymerLigandBondInfo)
    ],
)
@dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo:
  """Contains information about the location of ligand-ligand bonds."""

  tokens_to_ligand_ligand_bonds: atom_layout.GatherInfo

  @classmethod
  def compute_features(
      cls,
      all_tokens: atom_layout.AtomLayout,
      bond_layout: atom_layout.AtomLayout | None,
      padding_shapes: PaddingShapes,
  ) -> Self:
    """Computes the LigandLigandBondInfo features.

    A bond is kept only if both of its atoms appear in `all_tokens`. Bonds in
    which either atom name starts with 'H' are then dropped. The result is
    padded to `10 * padding_shapes.num_tokens` rows.

    Args:
      all_tokens: AtomLayout for tokens; shape (num_tokens,).
      bond_layout: Bond layout for ligand-ligand bonds, or None. It is indexed
        as (num_bonds, 2) below: one row per bond, one column per end atom.
      padding_shapes: Padding shapes.

    Returns:
      A LigandLigandBondInfo object.
    """
    if bond_layout is not None:
      # Discard any bonds that do not join two atoms present in `all_tokens`.
      known_atom_ids = set(
          zip(
              all_tokens.chain_id,
              all_tokens.res_id,
              all_tokens.atom_name,
              strict=True,
          )
      )
      keep_mask = np.array(
          [
              (chain_id[0], res_id[0], atom_name[0]) in known_atom_ids
              and (chain_id[1], res_id[1], atom_name[1]) in known_atom_ids
              for chain_id, res_id, atom_name in zip(
                  bond_layout.chain_id,
                  bond_layout.res_id,
                  bond_layout.atom_name,
                  strict=True,
              )
          ],
          dtype=bool,
      )
      bond_layout = bond_layout[keep_mask]
      # Remove any bonds to hydrogen atoms, identified by atom-name prefix.
      # NOTE(review): a name-prefix test also matches non-hydrogen atoms whose
      # name starts with 'H' (e.g. a mercury atom named 'HG'); confirm intent.
      touches_hydrogen = np.char.startswith(
          bond_layout.atom_name.astype(str), 'H'
      ).any(axis=1)
      bond_layout = bond_layout[~touches_hydrogen]
      adjusted_bond_layout = atom_layout.AtomLayout(
          atom_name=bond_layout.atom_name,
          res_id=bond_layout.res_id,
          chain_id=bond_layout.chain_id,
          chain_type=bond_layout.chain_type,
      )
    else:
      # Create layout with correct shape when bond_layout is None.
      empty_shape = (0, 2)
      adjusted_bond_layout = atom_layout.AtomLayout(
          atom_name=np.array([], dtype=object).reshape(empty_shape),
          res_id=np.array([], dtype=int).reshape(empty_shape),
          chain_id=np.array([], dtype=object).reshape(empty_shape),
      )
    # 10 x num_tokens as max_inter_bonds_ratio + max_intra_bonds_ratio = 2.061,
    # leaving generous headroom.
    adjusted_bond_layout = adjusted_bond_layout.copy_and_pad_to(
        (padding_shapes.num_tokens * 10, 2)
    )
    gather_idx = atom_layout.compute_gather_idxs(
        source_layout=all_tokens, target_layout=adjusted_bond_layout
    )
    return cls(tokens_to_ligand_ligand_bonds=gather_idx)

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Restores the object from keys prefixed 'tokens_to_ligand_ligand_bonds'."""
    return cls(
        tokens_to_ligand_ligand_bonds=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='tokens_to_ligand_ligand_bonds'
        )
    )

  def as_data_dict(self) -> BatchDict:
    """Flattens the gather info under the 'tokens_to_ligand_ligand_bonds' prefix."""
    return {
        **self.tokens_to_ligand_ligand_bonds.as_dict(
            key_prefix='tokens_to_ligand_ligand_bonds'
        )
    }
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    LigandLigandBondInfo,
    meta_fields=[],
    data_fields=[
        field.name for field in dataclasses.fields(LigandLigandBondInfo)
    ],
)
@dataclasses.dataclass(frozen=True)
class PseudoBetaInfo:
  """Contains information for extracting pseudo-beta and equivalent atoms."""

  token_atoms_to_pseudo_beta: atom_layout.GatherInfo

  @classmethod
  def compute_features(
      cls,
      all_token_atoms_layout: atom_layout.AtomLayout,
      ccd: chemical_components.Ccd,
      padding_shapes: PaddingShapes,
      logging_name: str,
  ) -> Self:
    """Compute the PseudoBetaInfo features.

    One representative atom is chosen per token:
      * protein: CB, falling back to CA;
      * nucleic acid / other chains: C4 for A, G, DA, DG (or a residue whose
        CCD parent is one of these), C2 otherwise;
      * non-polymer: the token's first atom;
      * unknown chain type: the first atom, with a warning.
    If the preferred atom is absent, the first atom with a non-empty name is
    used (index 0 if there is none) and a warning is logged.

    Args:
      all_token_atoms_layout: AtomLayout for all atoms per token, shape
        (num_tokens, max_atoms_per_token)
      ccd: The chemical components dictionary.
      padding_shapes: padding shapes.
      logging_name: logging name for debugging (usually the mmcif_id)

    Returns:
      A PseudoBetaInfo object.
    """
    token_idxs = []
    atom_idxs = []
    for token_idx in range(all_token_atoms_layout.shape[0]):
      chain_type = all_token_atoms_layout.chain_type[token_idx, 0]
      atom_names = list(all_token_atoms_layout.atom_name[token_idx, :])
      atom_idx = None
      is_nucleic_backbone = (
          chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES
          or chain_type == mmcif_names.OTHER_CHAIN
      )
      if chain_type == mmcif_names.PROTEIN_CHAIN:
        # Protein chains
        if 'CB' in atom_names:
          atom_idx = atom_names.index('CB')
        elif 'CA' in atom_names:
          atom_idx = atom_names.index('CA')
      elif is_nucleic_backbone:
        # RNA / DNA chains: map modified residues onto their standard parent.
        res_name = all_token_atoms_layout.res_name[token_idx, 0]
        cifdict = ccd.get(res_name)
        if cifdict:
          # Entries without a parent field are treated as having no parent
          # ('?') instead of raising KeyError.
          parent = cifdict.get('_chem_comp.mon_nstd_parent_comp_id', ['?'])[0]
          if parent != '?':
            res_name = parent
        if res_name in {'A', 'G', 'DA', 'DG'}:
          if 'C4' in atom_names:
            atom_idx = atom_names.index('C4')
        else:
          if 'C2' in atom_names:
            atom_idx = atom_names.index('C2')
      elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
        # Ligands: there is only one atom per token
        atom_idx = 0
      else:
        logging.warning(
            '%s: Unknown chain type for token %i. (%s)',
            logging_name,
            token_idx,
            all_token_atoms_layout[token_idx : token_idx + 1],
        )
        atom_idx = 0
      if atom_idx is None:
        # Preferred atom missing: fall back to the first non-empty atom name.
        (valid_atom_idxs,) = np.nonzero(
            all_token_atoms_layout.atom_name[token_idx, :]
        )
        if valid_atom_idxs.shape[0] > 0:
          atom_idx = valid_atom_idxs[0]
        else:
          atom_idx = 0
        logging.warning(
            '%s token %i (%s), does not contain a pseudo-beta atom. '
            'Using first valid atom (%s) instead.',
            logging_name,
            token_idx,
            all_token_atoms_layout[token_idx : token_idx + 1],
            all_token_atoms_layout.atom_name[token_idx, atom_idx],
        )
      token_idxs.append(token_idx)
      atom_idxs.append(atom_idx)

    pseudo_beta_layout = all_token_atoms_layout[token_idxs, atom_idxs]
    pseudo_beta_layout = pseudo_beta_layout.copy_and_pad_to((
        padding_shapes.num_tokens,
    ))
    token_atoms_to_pseudo_beta = atom_layout.compute_gather_idxs(
        source_layout=all_token_atoms_layout, target_layout=pseudo_beta_layout
    )
    return cls(
        token_atoms_to_pseudo_beta=token_atoms_to_pseudo_beta,
    )

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Restores the object from keys prefixed 'token_atoms_to_pseudo_beta'."""
    return cls(
        token_atoms_to_pseudo_beta=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='token_atoms_to_pseudo_beta'
        ),
    )

  def as_data_dict(self) -> BatchDict:
    """Flattens the gather info under the 'token_atoms_to_pseudo_beta' prefix."""
    return {
        **self.token_atoms_to_pseudo_beta.as_dict(
            key_prefix='token_atoms_to_pseudo_beta'
        ),
    }
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    PseudoBetaInfo,
    meta_fields=[],
    data_fields=[field.name for field in dataclasses.fields(PseudoBetaInfo)],
)
_DEFAULT_BLANK_REF = {
'positions': np.zeros(3),
'mask': 0,
'element': 0,
'charge': 0,
'atom_name_chars': np.zeros(4),
}
def random_rotation(random_state: np.random.RandomState) -> np.ndarray:
  """Draws a random right-handed 3x3 rotation matrix.

  Two standard-normal 3-vectors are orthonormalised with Gram-Schmidt. The
  third axis is their cross product, so the determinant is +1.

  Args:
    random_state: Source of randomness. Exactly one `normal(size=(2, 3))` draw
      is consumed.

  Returns:
    A (3, 3) array whose rows are the orthonormal axes.
  """
  eps = 1e-10
  first, second = random_state.normal(size=(2, 3))
  axis_x = first / np.maximum(eps, np.linalg.norm(first))
  # Remove the component of `second` along `axis_x`.
  second = second - axis_x * np.dot(second, axis_x)
  axis_y = second / np.maximum(eps, np.linalg.norm(second))
  axis_z = np.cross(axis_x, axis_y)
  return np.stack([axis_x, axis_y, axis_z])
def random_augmentation(
    positions: np.ndarray,
    random_state: np.random.RandomState,
) -> np.ndarray:
  """Centres `positions`, then applies a random rotation and translation.

  Randomness is consumed in a fixed order: first the rotation (via
  `random_rotation`), then a standard-normal translation of shape (3,).

  Args:
    positions: Coordinates of shape (num_points, 3); they are centred on their
      mean over axis 0.
    random_state: NumPy RandomState used for both draws.

  Returns:
    The transformed coordinates, shape (num_points, 3).
  """
  centred = positions - np.mean(positions, axis=0)
  rotation = random_rotation(random_state)
  # rotated[k, i] = sum_j rotation[i, j] * centred[k, j]: rotate every point.
  rotated = np.einsum('ij,kj->ki', rotation, centred)
  shift = random_state.normal(size=(3,))
  return rotated + shift
def _get_reference_positions_from_ccd_cif(
    ccd_cif: cif_dict.CifDict,
    ref_max_modified_date: datetime.date,
    logging_name: str,
) -> np.ndarray:
  """Creates reference positions from a CCD mmcif data block.

  Ideal coordinates (`pdbx_model_Cartn_*_ideal`) are used when present. If
  any coordinate is unknown ('?'), the block has `_chem_comp.pdbx_modified_date`,
  and its latest modified date is before `ref_max_modified_date`, then the
  model coordinates (`model_Cartn_*`) are used instead. Coordinates that are
  still unknown are set to 0 and a warning is logged.

  Args:
    ccd_cif: CCD data block for a single chemical component.
    ref_max_modified_date: Entries last modified on or after this date may not
      use their model coordinates.
    logging_name: Name used in log messages.

  Returns:
    float32 array of shape (num_rows, 3), where num_rows is the number of
    coordinate triples. A component with no atoms yields shape (0, 3) rather
    than a 1-D empty array.
  """

  def _stack_xyz(xs, ys, zs) -> np.ndarray:
    # Always 2-D with three columns, including the zero-atom case.
    return np.array(
        [[x, y, z] for x, y, z in zip(xs, ys, zs)], dtype=str
    ).reshape(-1, 3)

  num_atoms = len(ccd_cif['_chem_comp_atom.atom_id'])
  if '_chem_comp_atom.pdbx_model_Cartn_x_ideal' in ccd_cif:
    atom_x = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal']
    atom_y = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal']
    atom_z = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal']
  else:
    atom_x = np.array(['?'] * num_atoms)
    atom_y = np.array(['?'] * num_atoms)
    atom_z = np.array(['?'] * num_atoms)
  pos = _stack_xyz(atom_x, atom_y, atom_z)
  # Unknown reference coordinates are specified by '?' in chem comp dict.
  if '?' in pos and '_chem_comp.pdbx_modified_date' in ccd_cif:
    # Use model coordinates if the modified date is before the cutoff.
    modified_dates = [
        datetime.date.fromisoformat(date)
        for date in ccd_cif['_chem_comp.pdbx_modified_date']
    ]
    max_modified_date = max(modified_dates)
    if max_modified_date < ref_max_modified_date:
      atom_x = ccd_cif['_chem_comp_atom.model_Cartn_x']
      atom_y = ccd_cif['_chem_comp_atom.model_Cartn_y']
      atom_z = ccd_cif['_chem_comp_atom.model_Cartn_z']
      pos = _stack_xyz(atom_x, atom_y, atom_z)
  # Replace any remaining unknown reference coords with 0.
  if '?' in pos:
    if np.all(pos == '?'):
      logging.warning('All ref positions unknown for: %s', logging_name)
    else:
      logging.warning('Some ref positions unknown for: %s', logging_name)
    pos[pos == '?'] = 0
  return np.array(pos, dtype=np.float32)
def get_reference(
    res_name: str,
    chemical_components_data: struc_chem_comps.ChemicalComponentsData,
    ccd: chemical_components.Ccd,
    random_state: np.random.RandomState,
    ref_max_modified_date: datetime.date,
    conformer_max_iterations: int | None,
) -> tuple[dict[str, Any], Any, Any]:
  """Reference structure for residue from CCD or SMILES.

  Uses CCD entry if available, otherwise uses SMILES from chemical components
  data. Conformer generation is done using RDKit, with a fallback to CCD ideal
  or reference coordinates if RDKit fails and those coordinates are supplied.

  Args:
    res_name: ccd code of the residue.
    chemical_components_data: ChemicalComponentsData for making ref structure.
    ccd: The chemical components dictionary.
    random_state: Numpy RandomState
    ref_max_modified_date: date beyond which reference structures must not be
      modified to be allowed to use reference coordinates.
    conformer_max_iterations: Optional override for maximum number of iterations
      to run for RDKit conformer search.

  Returns:
    A tuple `(features, from_atom, dest_atom)`:
      * `features` maps each atom name to a dict with keys 'positions', 'mask',
        'element', 'charge' and 'atom_name_chars'. Positions are centred, then
        randomly rotated and translated.
      * `from_atom` and `dest_atom` are the CIF's `_chem_comp_bond.atom_id_1`
        and `_chem_comp_bond.atom_id_2` columns, or None if absent.

  Raises:
    ValueError: If there is neither a CCD entry nor a SMILES string for
      `res_name`, or RDKit cannot parse the SMILES string.
  """
  ccd_cif = ccd.get(res_name)
  mol = None
  if ccd_cif:
    try:
      mol = rdkit_utils.mol_from_ccd_cif(ccd_cif, remove_hydrogens=False)
    except rdkit_utils.MolFromMmcifError:
      logging.warning('Failed to construct mol from ccd_cif for: %s', res_name)
  else:  # No CCD entry, use SMILES from chemical components data.
    if not (
        chemical_components_data.chem_comp
        and res_name in chemical_components_data.chem_comp
        and chemical_components_data.chem_comp[res_name].pdbx_smiles
    ):
      raise ValueError(f'No CCD entry or SMILES for {res_name}.')
    smiles_string = chemical_components_data.chem_comp[res_name].pdbx_smiles
    logging.info('Using SMILES for: %s - %s', res_name, smiles_string)
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
      # In this case the model will not have any information about this molecule
      # and will not be able to predict anything about it.
      raise ValueError(
          f'Failed to construct RDKit Mol for {res_name} from SMILES string: '
          f'{smiles_string} . This is likely due to an issue with the SMILES '
          'string. Note that the userCCD input format provides an alternative '
          'way to define custom molecules directly without RDKit or SMILES.'
      )
    mol = Chem.AddHs(mol)
    # No existing names, we assign them from the graph.
    mol = rdkit_utils.assign_atom_names_from_graph(mol)
    # Temporary CCD cif with just atom and bond information, no coordinates.
    ccd_cif = rdkit_utils.mol_to_ccd_cif(mol, component_id='fake_cif')

  conformer = None
  atom_names = []
  elements = []
  charges = []
  pos = []

  # If mol is not None (must be True for SMILES case), then we try and generate
  # an RDKit conformer.
  if mol is not None:
    conformer_random_seed = int(random_state.randint(1, 1 << 31))
    conformer = rdkit_utils.get_random_conformer(
        mol=mol,
        random_seed=conformer_random_seed,
        max_iterations=conformer_max_iterations,
        logging_name=res_name,
    )
    if conformer:
      for idx, atom in enumerate(mol.GetAtoms()):
        atom_names.append(atom.GetProp('atom_name'))
        elements.append(atom.GetAtomicNum())
        charges.append(atom.GetFormalCharge())
        coords = conformer.GetAtomPosition(idx)
        pos.append([coords.x, coords.y, coords.z])
      pos = np.array(pos, dtype=np.float32)

  # If no mol could be generated (can only happen when using CCD), or no
  # conformer could be generated from the mol (can happen in either case), then
  # use CCD cif instead (which will have zero coordinates for SMILES case).
  if conformer is None:
    atom_names = ccd_cif['_chem_comp_atom.atom_id']
    charges = ccd_cif['_chem_comp_atom.charge']
    type_symbols = ccd_cif['_chem_comp_atom.type_symbol']
    elements = [
        periodic_table.ATOMIC_NUMBER.get(elem_type.capitalize(), 0)
        for elem_type in type_symbols
    ]
    pos = _get_reference_positions_from_ccd_cif(
        ccd_cif=ccd_cif,
        ref_max_modified_date=ref_max_modified_date,
        logging_name=res_name,
    )

  # Augment reference positions.
  pos = random_augmentation(pos, random_state)

  # Extract atom and bond information from CCD cif.
  from_atom = ccd_cif.get('_chem_comp_bond.atom_id_1', None)
  dest_atom = ccd_cif.get('_chem_comp_bond.atom_id_2', None)

  features = {}
  # enumerate() replaces a per-atom `atom_names.index(atom_name)` lookup
  # (O(n^2)). Skipping already-seen names keeps the first-occurrence data for
  # duplicate atom names, which is exactly what the index lookup produced.
  for idx, atom_name in enumerate(atom_names):
    if atom_name in features:
      continue
    charge = 0 if charges[idx] == '?' else int(charges[idx])
    atom_name_chars = np.array([ord(c) - 32 for c in atom_name], dtype=int)
    atom_name_chars = _pad_to(atom_name_chars, (4,))
    features[atom_name] = {
        'positions': pos[idx],
        'mask': 1,
        'element': elements[idx],
        'charge': charge,
        'atom_name_chars': atom_name_chars,
    }
  return features, from_atom, dest_atom
@dataclasses.dataclass(frozen=True)
class RefStructure:
"""Contains per-atom reference-conformer information.

All arrays follow the dense token-atom layout padded to
(padding_shapes.num_tokens, max_atoms_per_token); see compute_features.
"""
# Array with positions, float32, shape [num_tokens, max_atoms_per_token, 3]
positions: xnp_ndarray
# Array with masks, bool, shape [num_tokens, max_atoms_per_token]
mask: xnp_ndarray
# Array with elements (atomic numbers), int32,
# shape [num_tokens, max_atoms_per_token]
element: xnp_ndarray
# Array with formal charges, float32, shape [num_tokens, max_atoms_per_token]
charge: xnp_ndarray
# Array with atom name characters (ord(c) - 32, padded to 4), int32,
# [num_tokens, max_atoms_per_token, 4]
atom_name_chars: xnp_ndarray
# Array with reference space uids (one per (chain_id, res_id) pair), int32,
# [num_tokens, max_atoms_per_token]
ref_space_uid: xnp_ndarray
@classmethod
def compute_features(
cls,
all_token_atoms_layout: atom_layout.AtomLayout,
ccd: chemical_components.Ccd,
padding_shapes: PaddingShapes,
chemical_components_data: struc_chem_comps.ChemicalComponentsData,
random_state: np.random.RandomState,
ref_max_modified_date: datetime.date,
conformer_max_iterations: int | None,
ligand_ligand_bonds: atom_layout.AtomLayout | None = None,
) -> tuple[Self, Any]:
"""Reference structure information for each residue.

Args:
  all_token_atoms_layout: Dense atom layout, shape
    (num_tokens, max_atoms_per_token). Empty atom names mark unused slots.
  ccd: The chemical components dictionary.
  padding_shapes: Padding shapes; the outputs are padded to num_tokens.
  chemical_components_data: Passed to get_reference (SMILES fallback).
  random_state: Used by get_reference, once per (chain_id, res_id), in
    iteration order.
  ref_max_modified_date: Passed to get_reference.
  conformer_max_iterations: Passed to get_reference.
  ligand_ligand_bonds: Optional existing ligand-ligand bond layout.

Returns:
  A tuple (RefStructure, adjusted_ligand_ligand_bonds). If any intra-residue
  bonds of non-standard residues were collected, the second element is
  ligand_ligand_bonds with those bonds appended (or a layout of just those
  bonds when ligand_ligand_bonds is None). Otherwise it is
  ligand_ligand_bonds unchanged.
"""
# Get features per atom
padded_shape = (padding_shapes.num_tokens, all_token_atoms_layout.shape[1])
result = {
'positions': np.zeros((*padded_shape, 3), 'float32'),
'mask': np.zeros(padded_shape, 'bool'),
'element': np.zeros(padded_shape, 'int32'),
'charge': np.zeros(padded_shape, 'float32'),
'atom_name_chars': np.zeros((*padded_shape, 4), 'int32'),
'ref_space_uid': np.zeros((*padded_shape,), 'int32'),
}
# Intra-residue bond graphs of non-standard residues, collected per residue.
atom_names_all = []
chain_ids_all = []
res_ids_all = []
# Cache reference conformations for each residue.
conformations = {}
ref_space_uids = {}
# Visit every (token, atom-slot) index of the unpadded dense layout.
for idx in np.ndindex(all_token_atoms_layout.shape):
chain_id = all_token_atoms_layout.chain_id[idx]
res_id = all_token_atoms_layout.res_id[idx]
res_name = all_token_atoms_layout.res_name[idx]
is_non_standard = res_name not in _STANDARD_RESIDUES
atom_name = all_token_atoms_layout.atom_name[idx]
# An empty atom name marks an unused slot: use the blank reference.
if not atom_name:
ref = _DEFAULT_BLANK_REF
else:
# Build the reference conformer only once per (chain_id, res_id).
if (chain_id, res_id) not in conformations:
conf, from_atom, dest_atom = get_reference(
res_name=res_name,
chemical_components_data=chemical_components_data,
ccd=ccd,
random_state=random_state,
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
)
conformations[(chain_id, res_id)] = conf
if (
is_non_standard
and (from_atom is not None)
and (dest_atom is not None)
):
# Add intra-ligand bond graph
atom_names_ligand = np.stack(
[from_atom, dest_atom], axis=1, dtype=object
)
atom_names_all.append(atom_names_ligand)
res_ids_all.append(
np.full_like(atom_names_ligand, res_id, dtype=int)
)
chain_ids_all.append(
np.full_like(atom_names_ligand, chain_id, dtype=object)
)
# NOTE(review): the key is stored just above before this lookup, so the
# default here looks unreachable; confirm before relying on it.
conformation = conformations.get(
(chain_id, res_id), {atom_name: _DEFAULT_BLANK_REF}
)
if atom_name not in conformation:
logging.warning(
'Missing atom "%s" for CCD "%s"',
atom_name,
all_token_atoms_layout.res_name[idx],
)
ref = conformation.get(atom_name, _DEFAULT_BLANK_REF)
for k in ref:
result[k][idx] = ref[k]
# Assign a unique reference space id to each component, to determine which
# reference positions live in the same reference space.
space_str_id = (
all_token_atoms_layout.chain_id[idx],
all_token_atoms_layout.res_id[idx],
)
if space_str_id not in ref_space_uids:
ref_space_uids[space_str_id] = len(ref_space_uids)
result['ref_space_uid'][idx] = ref_space_uids[space_str_id]
# Merge the collected intra-residue bonds into the ligand-ligand bonds.
if atom_names_all:
atom_names_all = np.concatenate(atom_names_all, axis=0)
res_ids_all = np.concatenate(res_ids_all, axis=0)
chain_ids_all = np.concatenate(chain_ids_all, axis=0)
if ligand_ligand_bonds is not None:
adjusted_ligand_ligand_bonds = atom_layout.AtomLayout(
atom_name=np.concatenate(
[ligand_ligand_bonds.atom_name, atom_names_all], axis=0
),
chain_id=np.concatenate(
[ligand_ligand_bonds.chain_id, chain_ids_all], axis=0
),
res_id=np.concatenate(
[ligand_ligand_bonds.res_id, res_ids_all], axis=0
),
)
else:
adjusted_ligand_ligand_bonds = atom_layout.AtomLayout(
atom_name=atom_names_all,
chain_id=chain_ids_all,
res_id=res_ids_all,
)
else:
# No intra-residue bonds collected: pass the input through (may be None).
adjusted_ligand_ligand_bonds = ligand_ligand_bonds
return cls(**result), adjusted_ligand_ligand_bonds
@classmethod
def from_data_dict(cls, batch: BatchDict) -> Self:
"""Restores the object from the 'ref_*' / 'ref_space_uid' batch keys."""
return cls(
positions=batch['ref_pos'],
mask=batch['ref_mask'],
element=batch['ref_element'],
charge=batch['ref_charge'],
atom_name_chars=batch['ref_atom_name_chars'],
ref_space_uid=batch['ref_space_uid'],
)
def as_data_dict(self) -> BatchDict:
"""Exports the fields under the 'ref_*' / 'ref_space_uid' batch keys."""
return {
'ref_pos': self.positions,
'ref_mask': self.mask,
'ref_element': self.element,
'ref_charge': self.charge,
'ref_atom_name_chars': self.atom_name_chars,
'ref_space_uid': self.ref_space_uid,
}
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    RefStructure,
    meta_fields=[],
    data_fields=[field.name for field in dataclasses.fields(RefStructure)],
)
@dataclasses.dataclass(frozen=True)
class ConvertModelOutput:
  """Structures and atom layouts needed to convert model output back.

  Every field is an opaque Python object. `as_data_dict` wraps each one in an
  object-dtype NumPy array under a key equal to the field name, and
  `from_data_dict` passes each key's value (or None) through `_unwrap`.
  """

  cleaned_struc: structure.Structure
  token_atoms_layout: atom_layout.AtomLayout
  flat_output_layout: atom_layout.AtomLayout
  empty_output_struc: structure.Structure
  polymer_ligand_bonds: atom_layout.AtomLayout
  ligand_ligand_bonds: atom_layout.AtomLayout

  @classmethod
  def compute_features(
      cls,
      all_token_atoms_layout: atom_layout.AtomLayout,
      padding_shapes: PaddingShapes,
      cleaned_struc: structure.Structure,
      flat_output_layout: atom_layout.AtomLayout,
      empty_output_struc: structure.Structure,
      polymer_ligand_bonds: atom_layout.AtomLayout,
      ligand_ligand_bonds: atom_layout.AtomLayout,
  ) -> Self:
    """Pads `all_token_atoms_layout` to num_tokens and stores everything else."""
    atoms_per_token = all_token_atoms_layout.shape[1]
    padded_token_atoms = all_token_atoms_layout.copy_and_pad_to(
        (padding_shapes.num_tokens, atoms_per_token)
    )
    return cls(
        cleaned_struc=cleaned_struc,
        token_atoms_layout=padded_token_atoms,
        flat_output_layout=flat_output_layout,
        empty_output_struc=empty_output_struc,
        polymer_ligand_bonds=polymer_ligand_bonds,
        ligand_ligand_bonds=ligand_ligand_bonds,
    )

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Construct atom layout object from dictionary."""
    return cls(**{
        field.name: _unwrap(batch.get(field.name, None))
        for field in dataclasses.fields(cls)
    })

  def as_data_dict(self) -> BatchDict:
    """Wraps every field in an object-dtype array keyed by the field name."""
    return {
        field.name: np.array(getattr(self, field.name), object)
        for field in dataclasses.fields(self)
    }
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    ConvertModelOutput,
    meta_fields=[],
    data_fields=[
        field.name for field in dataclasses.fields(ConvertModelOutput)
    ],
)
@dataclasses.dataclass(frozen=True)
class AtomCrossAtt:
  """Operate on flat atoms.

  The dense per-token atom layout is flattened and padded to
  `padding_shapes.num_atoms` atoms. It is then split into consecutive query
  subsets of `queries_subset_size` atoms, and each query subset gets a window
  of `keys_subset_size` key atoms centred on it.
  """

  token_atoms_to_queries: atom_layout.GatherInfo
  tokens_to_queries: atom_layout.GatherInfo
  tokens_to_keys: atom_layout.GatherInfo
  queries_to_keys: atom_layout.GatherInfo
  queries_to_token_atoms: atom_layout.GatherInfo

  @classmethod
  def compute_features(
      cls,
      all_token_atoms_layout: atom_layout.AtomLayout,  # (num_tokens, num_dense)
      queries_subset_size: int,
      keys_subset_size: int,
      padding_shapes: PaddingShapes,
  ) -> Self:
    """Computes gather indices and meta data to work with a flat atom list.

    Args:
      all_token_atoms_layout: Dense atom layout, shape (num_tokens, num_dense).
      queries_subset_size: Atoms per query subset. Must divide
        `padding_shapes.num_atoms`.
      keys_subset_size: Key atoms gathered per query subset.
      padding_shapes: Padding shapes.

    Returns:
      An AtomCrossAtt object.

    Raises:
      ValueError: If `padding_shapes.num_atoms` is not a multiple of
        `queries_subset_size`.
    """
    token_atoms_layout = all_token_atoms_layout.copy_and_pad_to(
        (padding_shapes.num_tokens, all_token_atoms_layout.shape[1])
    )
    token_atoms_mask = token_atoms_layout.atom_name.astype(bool)
    flat_layout = token_atoms_layout[token_atoms_mask]
    num_atoms = flat_layout.shape[0]

    padded_flat_layout = flat_layout.copy_and_pad_to((
        padding_shapes.num_atoms,
    ))

    # Create the layout for queries
    num_subsets, remainder = divmod(
        padding_shapes.num_atoms, queries_subset_size
    )
    if remainder:
      raise ValueError(
          f'padding_shapes.num_atoms ({padding_shapes.num_atoms}) must be a '
          f'multiple of queries_subset_size ({queries_subset_size}).'
      )
    lay_arr = padded_flat_layout.to_array()
    # Keep the leading axis as-is. It presumably holds the stacked AtomLayout
    # fields (previously hard-coded as 6). Only the flat atom axis is split
    # into (num_subsets, queries_subset_size).
    queries_layout = atom_layout.AtomLayout.from_array(
        lay_arr.reshape((lay_arr.shape[0], num_subsets, queries_subset_size))
    )

    # Create the layout for the keys (the key subsets are centered around the
    # query subsets)
    # Create initial gather indices (contain out-of-bound indices)
    subset_centers = np.arange(
        queries_subset_size / 2, padding_shapes.num_atoms, queries_subset_size
    )
    flat_to_key_gathers = (
        subset_centers[:, None]
        + np.arange(-keys_subset_size / 2, keys_subset_size / 2)[None, :]
    )
    flat_to_key_gathers = flat_to_key_gathers.astype(int)
    # Shift subsets with out-of-bound indices, such that they are fully within
    # the bounds.
    for row in range(flat_to_key_gathers.shape[0]):
      if flat_to_key_gathers[row, 0] < 0:
        flat_to_key_gathers[row, :] -= flat_to_key_gathers[row, 0]
      elif flat_to_key_gathers[row, -1] > num_atoms - 1:
        overflow = flat_to_key_gathers[row, -1] - (num_atoms - 1)
        flat_to_key_gathers[row, :] -= overflow

    # Create the keys layout.
    keys_layout = padded_flat_layout[flat_to_key_gathers]

    # Create gather indices for conversion between token atoms layout,
    # queries layout and keys layout.
    token_atoms_to_queries = atom_layout.compute_gather_idxs(
        source_layout=token_atoms_layout, target_layout=queries_layout
    )
    token_atoms_to_keys = atom_layout.compute_gather_idxs(
        source_layout=token_atoms_layout, target_layout=keys_layout
    )
    queries_to_keys = atom_layout.compute_gather_idxs(
        source_layout=queries_layout, target_layout=keys_layout
    )
    queries_to_token_atoms = atom_layout.compute_gather_idxs(
        source_layout=queries_layout, target_layout=token_atoms_layout
    )

    # Create gather indices for conversion of tokens layout to
    # queries and keys layout
    token_idxs = np.arange(padding_shapes.num_tokens).astype(np.int64)
    token_idxs = np.broadcast_to(token_idxs[:, None], token_atoms_layout.shape)
    tokens_to_queries = atom_layout.GatherInfo(
        gather_idxs=atom_layout.convert(
            token_atoms_to_queries, token_idxs, layout_axes=(0, 1)
        ),
        gather_mask=atom_layout.convert(
            token_atoms_to_queries, token_atoms_mask, layout_axes=(0, 1)
        ),
        input_shape=np.array((padding_shapes.num_tokens,)),
    )
    tokens_to_keys = atom_layout.GatherInfo(
        gather_idxs=atom_layout.convert(
            token_atoms_to_keys, token_idxs, layout_axes=(0, 1)
        ),
        gather_mask=atom_layout.convert(
            token_atoms_to_keys, token_atoms_mask, layout_axes=(0, 1)
        ),
        input_shape=np.array((padding_shapes.num_tokens,)),
    )

    return cls(
        token_atoms_to_queries=token_atoms_to_queries,
        tokens_to_queries=tokens_to_queries,
        tokens_to_keys=tokens_to_keys,
        queries_to_keys=queries_to_keys,
        queries_to_token_atoms=queries_to_token_atoms,
    )

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Restores all five gather infos from their prefixed batch keys."""
    return cls(
        token_atoms_to_queries=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='token_atoms_to_queries'
        ),
        tokens_to_queries=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='tokens_to_queries'
        ),
        tokens_to_keys=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='tokens_to_keys'
        ),
        queries_to_keys=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='queries_to_keys'
        ),
        queries_to_token_atoms=atom_layout.GatherInfo.from_dict(
            batch, key_prefix='queries_to_token_atoms'
        ),
    )

  def as_data_dict(self) -> BatchDict:
    """Flattens all five gather infos, each under its field-name prefix."""
    return {
        **self.token_atoms_to_queries.as_dict(
            key_prefix='token_atoms_to_queries'
        ),
        **self.tokens_to_queries.as_dict(key_prefix='tokens_to_queries'),
        **self.tokens_to_keys.as_dict(key_prefix='tokens_to_keys'),
        **self.queries_to_keys.as_dict(key_prefix='queries_to_keys'),
        **self.queries_to_token_atoms.as_dict(
            key_prefix='queries_to_token_atoms'
        ),
    }
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    AtomCrossAtt,
    meta_fields=[],
    data_fields=[field.name for field in dataclasses.fields(AtomCrossAtt)],
)
@dataclasses.dataclass(frozen=True)
class Frames:
  """Features for backbone frames."""

  # Per-token frame validity mask, padded to (padding_shapes.num_tokens,).
  mask: xnp_ndarray

  @classmethod
  def compute_features(
      cls,
      all_tokens: atom_layout.AtomLayout,
      all_token_atoms_layout: atom_layout.AtomLayout,
      ref_structure: RefStructure,
      padding_shapes: PaddingShapes,
  ) -> Self:
    """Computes features for backbone frames.

    Peptide and nucleic-acid tokens always get a frame. A non-polymer token
    gets a frame made of itself and the two closest tokens of the same chain
    and residue, by reference-coordinate distance, with masked tokens
    penalised. It has no frame if the residue has fewer than three tokens or
    if those three points are (near-)colinear (< 25 degrees deviation). All
    other chain types get no frame.

    Args:
      all_tokens: Token-level AtomLayout, one entry per token.
      all_token_atoms_layout: Dense per-token atom layout. It is padded here to
        `padding_shapes.num_tokens` tokens.
      ref_structure: Reference structure. Its `positions` and `mask` are
        presumably laid out like the padded `all_token_atoms_layout`.
      padding_shapes: Padding shapes.

    Returns:
      A Frames object.
    """
    num_tokens = padding_shapes.num_tokens
    all_token_atoms_layout = all_token_atoms_layout.copy_and_pad_to(
        (num_tokens, all_token_atoms_layout.shape[1])
    )
    all_token_atoms_to_all_tokens = atom_layout.compute_gather_idxs(
        source_layout=all_token_atoms_layout, target_layout=all_tokens
    )
    ref_coordinates = atom_layout.convert(
        all_token_atoms_to_all_tokens,
        ref_structure.positions.astype(np.float32),
        layout_axes=(0, 1),
    )
    ref_mask = atom_layout.convert(
        all_token_atoms_to_all_tokens,
        ref_structure.mask.astype(bool),
        layout_axes=(0, 1),
    )
    ref_mask = ref_mask & all_token_atoms_to_all_tokens.gather_mask.astype(bool)
    all_frame_mask = []

    # Iterate over tokens. Membership is tested on the chain-type constants
    # directly instead of materialising a new list for every token.
    for idx, (chain_type, chain_id, res_id) in enumerate(
        zip(all_tokens.chain_type, all_tokens.chain_id, all_tokens.res_id)
    ):
      if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES:
        frame_mask = True
      elif chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES:
        frame_mask = True
      elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
        # For ligands, build frames from closest atoms from the same molecule.
        (local_token_idxs,) = np.where(
            (all_tokens.chain_type == chain_type)
            & (all_tokens.chain_id == chain_id)
            & (all_tokens.res_id == res_id)
        )

        if len(local_token_idxs) < 3:
          frame_mask = False
        else:
          # [local_tokens]
          local_dist = np.linalg.norm(
              ref_coordinates[idx] - ref_coordinates[local_token_idxs], axis=-1
          )
          local_mask = ref_mask[local_token_idxs]
          cost = local_dist + 1e8 * ~local_mask
          # Never pick the current token as one of its own neighbours.
          cost = cost + 1e8 * (idx == local_token_idxs)
          # [local_tokens]
          closest_idxs = np.argsort(cost, axis=0)

          # The closest indices index an array of local tokens. Convert this
          # to indices of the full (num_tokens,) array.
          global_closest_idxs = local_token_idxs[closest_idxs]

          # Construct frame by placing the current token at the origin and two
          # nearest atoms on either side.
          global_frame_idxs = np.array(
              (global_closest_idxs[0], idx, global_closest_idxs[1])
          )

          # Check that the frame atoms are not colinear.
          a, b, c = ref_coordinates[global_frame_idxs]
          vec1 = a - b
          vec2 = c - b

          # Reference coordinates can be all zeros, in which case we have
          # to explicitly set colinearity.
          if np.isclose(np.linalg.norm(vec1, axis=-1), 0) or np.isclose(
              np.linalg.norm(vec2, axis=-1), 0
          ):
            is_colinear = True
            logging.info('Found identical coordinates: Assigning as colinear.')
          else:
            vec1 = vec1 / np.linalg.norm(vec1, axis=-1)
            vec2 = vec2 / np.linalg.norm(vec2, axis=-1)
            cos_angle = np.einsum('...k,...k->...', vec1, vec2)
            # <25 degree deviation is considered colinear.
            is_colinear = 1 - np.abs(cos_angle) < 0.0937

          frame_mask = not is_colinear
      else:
        # No frame for other chain types.
        frame_mask = False

      all_frame_mask.append(frame_mask)

    all_frame_mask = np.array(all_frame_mask, dtype=bool)

    mask = _pad_to(all_frame_mask, (padding_shapes.num_tokens,))

    return cls(mask=mask)

  @classmethod
  def from_data_dict(cls, batch: BatchDict) -> Self:
    """Restores the object from the 'frames_mask' batch key."""
    return cls(mask=batch['frames_mask'])

  def as_data_dict(self) -> BatchDict:
    """Exports the mask under the 'frames_mask' batch key."""
    return {'frames_mask': self.mask}
# Register as a JAX pytree: every dataclass field is a data field, none are
# static metadata.
jax.tree_util.register_dataclass(
    Frames,
    meta_fields=[],
    data_fields=[field.name for field in dataclasses.fields(Frames)],
)
================================================
FILE: src/alphafold3/model/merging_features.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Methods for merging existing features to create a new example.
Covers:
- Merging features across chains.
- Merging the paired and unpaired parts of the MSA.
"""
from typing import TypeAlias
from alphafold3.model import data_constants
import jax.numpy as jnp
import numpy as np
# Either a NumPy or a JAX array.
xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray  # pylint: disable=invalid-name
# Feature name -> feature array.
BatchDict: TypeAlias = dict[str, xnp_ndarray]

# Local aliases of the MSA feature constants defined in data_constants.
NUM_SEQ_NUM_RES_MSA_FEATURES = data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES
NUM_SEQ_MSA_FEATURES = data_constants.NUM_SEQ_MSA_FEATURES
MSA_PAD_VALUES = data_constants.MSA_PAD_VALUES
def _pad_features_to_max(feat_name: str, chains: list[BatchDict], axis: int):
  """Pads one feature of every chain to the largest size along `axis`.

  Padding is appended at the end of `axis` with the constant
  MSA_PAD_VALUES[feat_name]. All other axes are left untouched.

  Args:
    feat_name: The feature name to pad.
    chains: A list of chains with associated features.
    axis: Which axis to pad to the max.

  Returns:
    A list of features, all with the same size on the given axis.
  """
  target_size = np.max([chain[feat_name].shape[axis] for chain in chains])

  def _pad_one(feat):
    # (before, after) widths per axis; only the end of `axis` grows.
    widths = [(0, 0)] * len(feat.shape)  # pytype: disable=attribute-error
    widths[axis] = (0, target_size - feat.shape[axis])  # pytype: disable=attribute-error
    return np.pad(
        feat,
        widths,
        mode='constant',
        constant_values=MSA_PAD_VALUES[feat_name],
    )

  return [_pad_one(chain[feat_name]) for chain in chains]
def merge_msa_features(feat_name: str, chains: list[BatchDict]) -> np.ndarray:
  """Merges MSA features with shape (NUM_SEQ, NUM_RES) across chains.

  The per-chain arrays are joined along the residue axis (axis 1).
  * Paired features (names containing '_all_seq') are concatenated as-is. A
    chain lacking the feature contributes an empty array of the first chain's
    dtype.
  * Unpaired features are first padded to a common number of rows.

  NOTE(review): the empty fallback is 1-D, so concatenating it with 2-D
  features along axis 1 would fail; confirm whether that path is reachable.
  """
  # Read up front so that an empty `chains`, or a first chain missing the
  # feature, fails here on either path.
  fallback_dtype = chains[0][feat_name].dtype
  if '_all_seq' not in feat_name:
    # Each MSA can have a different depth, so pad them all to the deepest
    # one before concatenating.
    return np.concatenate(
        _pad_features_to_max(feat_name, chains, axis=0), axis=1
    )
  per_chain = [
      chain.get(feat_name, np.array([], fallback_dtype)) for chain in chains
  ]
  return np.concatenate(per_chain, axis=1)
def merge_paired_and_unpaired_msa(example: BatchDict) -> BatchDict:
  """Concatenates the paired (all_seq) MSA features with the unpaired ones.

  For every MSA feature present in both forms, the paired rows are placed
  first, followed by the unpaired rows (axis 0). The input dict is not
  modified. 'num_alignments' is set to the resulting number of MSA rows as an
  int32 scalar array.
  """
  merged = dict(example)
  for name in NUM_SEQ_NUM_RES_MSA_FEATURES + NUM_SEQ_MSA_FEATURES:
    paired_name = f'{name}_all_seq'
    if name not in example or paired_name not in example:
      continue
    merged[name] = np.concatenate(
        [example[paired_name], example[name]], axis=0
    )
  merged['num_alignments'] = np.array(merged['msa'].shape[0], dtype=np.int32)
  return merged
================================================
FILE: src/alphafold3/model/mkdssp_pybind.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include "alphafold3/model/mkdssp_pybind.h"
#include
#include
#include
#include
#include
#include "absl/strings/string_view.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
namespace alphafold3 {

namespace py = pybind11;

// Registers `get_dssp` on the Python module `m`.
//
// libcifpp locates its components.cif dictionary through LIBCIFPP_DATA_DIR.
// If that variable is unset, the Python site-packages directories are
// searched for share/libcifpp/components.cif and the variable is pointed at
// the first directory that has it; a Python TypeError is raised otherwise.
void RegisterModuleMkdssp(pybind11::module m) {
  if (!getenv("LIBCIFPP_DATA_DIR")) {
    py::module site = py::module::import("site");
    py::list paths = py::cast<py::list>(site.attr("getsitepackages")());
    // Find the first path that contains the libcifpp components.cif file.
    bool found = false;
    for (const auto& py_path : paths) {
      auto path_str = std::filesystem::path(py::cast<std::string>(py_path)) /
                      "share/libcifpp/components.cif";
      // Use the non-throwing overload: an inaccessible site-packages entry
      // should simply be skipped instead of aborting module import with a
      // std::filesystem::filesystem_error.
      std::error_code error;
      if (std::filesystem::exists(path_str, error)) {
        setenv("LIBCIFPP_DATA_DIR", path_str.parent_path().c_str(), 0);
        found = true;
        break;
      }
    }
    if (!found) {
      throw py::type_error("Could not find the libcifpp components.cif file.");
    }
  }
  m.def(
      "get_dssp",
      [](absl::string_view mmcif, int model_no,
         int min_poly_proline_stretch_length,
         bool calculate_surface_accessibility) {
        cif::file cif_file(mmcif.data(), mmcif.size());
        // Calling front() on a cif::file without any data block (e.g. empty
        // or comment-only input) is undefined behaviour, so reject it with a
        // Python ValueError instead.
        if (cif_file.empty()) {
          throw py::value_error("The mmCIF string contains no data blocks.");
        }
        dssp result(cif_file.front(), model_no, min_poly_proline_stretch_length,
                    calculate_surface_accessibility);
        std::stringstream sstream;
        result.write_legacy_output(sstream);
        return sstream.str();
      },
      py::arg("mmcif"), py::arg("model_no") = 1,
      py::arg("min_poly_proline_stretch_length") = 3,
      py::arg("calculate_surface_accessibility") = false,
      py::doc("Gets secondary structure from an mmCIF file."));
}

}  // namespace alphafold3
================================================
FILE: src/alphafold3/model/mkdssp_pybind.h
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
#define ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
#include "pybind11/pybind11.h"
namespace alphafold3 {

// Registers the `get_dssp` function (mmCIF string -> legacy DSSP output) on
// the Python module `m`.
//
// If the LIBCIFPP_DATA_DIR environment variable is unset, this first searches
// the Python site-packages directories for libcifpp's components.cif and sets
// LIBCIFPP_DATA_DIR to the directory containing it, raising a Python TypeError
// if it cannot be found.
void RegisterModuleMkdssp(pybind11::module m);

}  // namespace alphafold3
#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
================================================
FILE: src/alphafold3/model/mmcif_metadata.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Adds mmCIF metadata (to be ModelCIF-conformant) and author and legal info."""
from typing import Final
from alphafold3.structure import mmcif
import numpy as np
# URL of the terms of use for AlphaFold 3 output. Written to
# `_pdbx_data_usage.url` and to the legal comment header of output mmCIFs.
_LICENSE_URL: Final[str] = (
    'https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md'
)
# License text stored as the 'license' entry of `_pdbx_data_usage.details`.
_LICENSE: Final[str] = f"""
Non-commercial use only, by using this file you agree to the terms of use found
at {_LICENSE_URL}.
To request access to the AlphaFold 3 model parameters, follow the process set
out at https://github.com/google-deepmind/alphafold3. You may only use these if
received directly from Google. Use is subject to terms of use available at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.
""".strip()
# Disclaimer text stored as the 'disclaimer' entry of
# `_pdbx_data_usage.details`.
_DISCLAIMER: Final[str] = """\
AlphaFold 3 and its output are not intended for, have not been validated for,
and are not approved for clinical use. They are provided "as-is" without any
warranty of any kind, whether expressed or implied. No warranty is given that
use shall not infringe the rights of any third party.
""".strip()
# Authors written to the `_citation_author` category (the paper authors).
_MMCIF_PAPER_AUTHORS: Final[tuple[str, ...]] = (
    'Google DeepMind',
    'Isomorphic Labs',
)
# Authors of the mmCIF - we set them to be equal to the authors of the paper.
_MMCIF_AUTHORS: Final[tuple[str, ...]] = _MMCIF_PAPER_AUTHORS
def add_metadata_to_mmcif(
    old_cif: mmcif.Mmcif, version: str, model_id: bytes
) -> mmcif.Mmcif:
  """Adds metadata to a mmCIF to make it ModelCIF-conformant.

  Args:
    old_cif: The mmCIF to annotate. Its `_struct_asym` categories are read to
      describe the target entities, and its `_atom_site` B-factor column is
      read as per-atom pLDDT.
    version: Version string written into `_ma_model_list.model_group_name`.
    model_id: ASCII-encoded model identifier written into `_software.version`.

  Returns:
    A copy of `old_cif` updated with ModelCIF conformance, licensing,
    citation, software, protocol and pLDDT (global and per-residue) metadata.
  """
  cif = {}

  # ModelCIF conformation dictionary.
  cif['_audit_conform.dict_name'] = ['mmcif_ma.dic']
  cif['_audit_conform.dict_version'] = ['1.4.5']
  cif['_audit_conform.dict_location'] = [
      'https://raw.githubusercontent.com/ihmwg/ModelCIF/master/dist/mmcif_ma.dic'
  ]

  cif['_pdbx_data_usage.id'] = ['1', '2']
  cif['_pdbx_data_usage.type'] = ['license', 'disclaimer']
  cif['_pdbx_data_usage.details'] = [_LICENSE, _DISCLAIMER]
  cif['_pdbx_data_usage.url'] = [_LICENSE_URL, '?']

  # Structure author details.
  cif['_audit_author.name'] = []
  cif['_audit_author.pdbx_ordinal'] = []
  for author_index, author_name in enumerate(_MMCIF_AUTHORS, start=1):
    cif['_audit_author.name'].append(author_name)
    cif['_audit_author.pdbx_ordinal'].append(str(author_index))

  # Paper author details.
  cif['_citation_author.citation_id'] = []
  cif['_citation_author.name'] = []
  cif['_citation_author.ordinal'] = []
  for author_index, author_name in enumerate(_MMCIF_PAPER_AUTHORS, start=1):
    cif['_citation_author.citation_id'].append('primary')
    cif['_citation_author.name'].append(author_name)
    cif['_citation_author.ordinal'].append(str(author_index))

  # Paper citation details.
  cif['_citation.id'] = ['primary']
  cif['_citation.title'] = [
      'Accurate structure prediction of biomolecular interactions with'
      ' AlphaFold 3'
  ]
  cif['_citation.journal_full'] = ['Nature']
  cif['_citation.journal_volume'] = ['630']
  cif['_citation.page_first'] = ['493']
  cif['_citation.page_last'] = ['500']
  cif['_citation.year'] = ['2024']
  cif['_citation.journal_id_ASTM'] = ['NATUAS']
  cif['_citation.country'] = ['UK']
  cif['_citation.journal_id_ISSN'] = ['0028-0836']
  cif['_citation.journal_id_CSD'] = ['0006']
  cif['_citation.book_publisher'] = ['?']
  cif['_citation.pdbx_database_id_PubMed'] = ['38718835']
  cif['_citation.pdbx_database_id_DOI'] = ['10.1038/s41586-024-07487-w']

  # Type of data in the dataset including data used in the model generation.
  cif['_ma_data.id'] = ['1']
  cif['_ma_data.name'] = ['Model']
  cif['_ma_data.content_type'] = ['model coordinates']

  # Description of number of instances for each entity (one row per asym).
  cif['_ma_target_entity_instance.asym_id'] = old_cif['_struct_asym.id']
  cif['_ma_target_entity_instance.entity_id'] = old_cif[
      '_struct_asym.entity_id'
  ]
  cif['_ma_target_entity_instance.details'] = ['.'] * len(
      cif['_ma_target_entity_instance.entity_id']
  )

  # Details about the target entities. `_ma_target_entity` is keyed by
  # entity_id, but several asyms (e.g. the copies of a homomer) can map to the
  # same entity, so deduplicate while keeping first-seen order.
  cif['_ma_target_entity.entity_id'] = list(
      dict.fromkeys(cif['_ma_target_entity_instance.entity_id'])
  )
  cif['_ma_target_entity.data_id'] = ['1'] * len(
      cif['_ma_target_entity.entity_id']
  )
  cif['_ma_target_entity.origin'] = ['.'] * len(
      cif['_ma_target_entity.entity_id']
  )

  # Details of the models being deposited.
  cif['_ma_model_list.ordinal_id'] = ['1']
  cif['_ma_model_list.model_id'] = ['1']
  cif['_ma_model_list.model_group_id'] = ['1']
  cif['_ma_model_list.model_name'] = ['Top ranked model']
  cif['_ma_model_list.model_group_name'] = [
      f'AlphaFold-beta-20231127 ({version})'
  ]
  cif['_ma_model_list.data_id'] = ['1']
  cif['_ma_model_list.model_type'] = ['Ab initio model']

  # Software used.
  cif['_software.pdbx_ordinal'] = ['1']
  cif['_software.name'] = ['AlphaFold']
  cif['_software.version'] = [
      f'AlphaFold-beta-20231127 ({model_id.decode("ascii")})'
  ]
  cif['_software.type'] = ['package']
  cif['_software.description'] = ['Structure prediction']
  cif['_software.classification'] = ['other']
  cif['_software.date'] = ['?']

  # Collection of software into groups.
  cif['_ma_software_group.ordinal_id'] = ['1']
  cif['_ma_software_group.group_id'] = ['1']
  cif['_ma_software_group.software_id'] = ['1']

  # Method description to conform with ModelCIF.
  cif['_ma_protocol_step.ordinal_id'] = ['1', '2', '3']
  cif['_ma_protocol_step.protocol_id'] = ['1', '1', '1']
  cif['_ma_protocol_step.step_id'] = ['1', '2', '3']
  cif['_ma_protocol_step.method_type'] = [
      'coevolution MSA',
      'template search',
      'modeling',
  ]

  # Details of the metrics used to assess model confidence.
  cif['_ma_qa_metric.id'] = ['1', '2']
  cif['_ma_qa_metric.name'] = ['pLDDT', 'pLDDT']
  # ModelCIF restricts this item to an enumeration of metric types.
  # NOTE(review): confirm 'pLDDT' is an accepted value in the mmcif_ma.dic
  # version declared above.
  cif['_ma_qa_metric.type'] = ['pLDDT', 'pLDDT']
  cif['_ma_qa_metric.mode'] = ['global', 'local']
  cif['_ma_qa_metric.software_group_id'] = ['1', '1']

  # Global model confidence pLDDT value.
  cif['_ma_qa_metric_global.ordinal_id'] = ['1']
  cif['_ma_qa_metric_global.model_id'] = ['1']
  cif['_ma_qa_metric_global.metric_id'] = ['1']
  # Mean over all atoms, since AlphaFold 3 outputs pLDDT per-atom.
  global_plddt = np.mean(
      [float(v) for v in old_cif['_atom_site.B_iso_or_equiv']]
  )
  cif['_ma_qa_metric_global.metric_value'] = [f'{global_plddt:.2f}']

  # Local (per residue) model confidence pLDDT value.
  cif['_ma_qa_metric_local.ordinal_id'] = []
  cif['_ma_qa_metric_local.model_id'] = []
  cif['_ma_qa_metric_local.label_asym_id'] = []
  cif['_ma_qa_metric_local.label_seq_id'] = []
  cif['_ma_qa_metric_local.label_comp_id'] = []
  cif['_ma_qa_metric_local.metric_id'] = []
  cif['_ma_qa_metric_local.metric_value'] = []

  # Group atom pLDDTs by (label_asym_id, label_seq_id, label_comp_id); dicts
  # preserve insertion order, so residues keep their atom_site order.
  plddt_grouped_by_res = {}
  for *res, atom_plddt in zip(
      old_cif['_atom_site.label_asym_id'],
      old_cif['_atom_site.label_seq_id'],
      old_cif['_atom_site.label_comp_id'],
      old_cif['_atom_site.B_iso_or_equiv'],
  ):
    plddt_grouped_by_res.setdefault(tuple(res), []).append(float(atom_plddt))

  for ordinal_id, ((chain_id, res_id, res_name), res_plddts) in enumerate(
      plddt_grouped_by_res.items(), start=1
  ):
    res_plddt = np.mean(res_plddts)
    cif['_ma_qa_metric_local.ordinal_id'].append(str(ordinal_id))
    cif['_ma_qa_metric_local.model_id'].append('1')
    cif['_ma_qa_metric_local.label_asym_id'].append(chain_id)
    cif['_ma_qa_metric_local.label_seq_id'].append(res_id)
    cif['_ma_qa_metric_local.label_comp_id'].append(res_name)
    cif['_ma_qa_metric_local.metric_id'].append('2')  # See _ma_qa_metric.id.
    cif['_ma_qa_metric_local.metric_value'].append(f'{res_plddt:.2f}')

  cif['_atom_type.symbol'] = sorted(set(old_cif['_atom_site.type_symbol']))

  return old_cif.copy_and_update(cif)
def add_legal_comment(cif: str) -> str:
  """Prepends the legal terms-of-use comment block to an mmCIF string.

  Args:
    cif: The mmCIF file contents.

  Returns:
    The `#`-prefixed legal header, a newline, and then `cif` unchanged.
  """
  # fmt: off
  # pylint: disable=line-too-long
  header = ''.join((
      '# By using this file you agree to the legally binding terms of use found at\n',
      f'# {_LICENSE_URL}.\n',
      '# To request access to the AlphaFold 3 model parameters, follow the process set\n',
      '# out at https://github.com/google-deepmind/alphafold3. You may only use these if\n',
      '# received directly from Google. Use is subject to terms of use available at\n',
      '# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.',
  ))
  # pylint: enable=line-too-long
  # fmt: on
  return f'{header}\n{cif}'
================================================
FILE: src/alphafold3/model/model.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""AlphaFold3 model."""
from collections.abc import Iterable, Mapping
import concurrent
import concurrent.futures
import dataclasses
import functools
from typing import Any, TypeAlias
from absl import logging
from alphafold3 import structure
from alphafold3.common import base_config
from alphafold3.model import confidences
from alphafold3.model import feat_batch
from alphafold3.model import features
from alphafold3.model import model_config
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import mapping
from alphafold3.model.components import utils
from alphafold3.model.network import atom_cross_attention
from alphafold3.model.network import confidence_head
from alphafold3.model.network import diffusion_head
from alphafold3.model.network import distogram_head
from alphafold3.model.network import evoformer as evoformer_network
from alphafold3.model.network import featurization
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# Raw output of a `Model` forward pass: a mapping from output name (e.g.
# 'diffusion_samples', 'distogram') to its value.
ModelResult: TypeAlias = Mapping[str, Any]
@dataclasses.dataclass(frozen=True, kw_only=True)
class InferenceResult:
  """Post-processed result for a single predicted sample.

  Attributes:
    predicted_structure: The predicted structure.
    numerical_data: Numerical values (scalars or arrays) worth saving at
      inference time.
    metadata: Smaller, usually scalar, values saved as inference metadata.
    debug_outputs: Extra entries for debugging, e.g. raw outputs of a model
      forward pass.
    model_id: Identifier of the model that produced the result.
  """

  predicted_structure: structure.Structure
  numerical_data: Mapping[str, float | int | np.ndarray] = dataclasses.field(
      default_factory=dict
  )
  metadata: Mapping[str, float | int | np.ndarray] = dataclasses.field(
      default_factory=dict
  )
  debug_outputs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
  model_id: bytes = b''
def get_predicted_structure(
    result: ModelResult, batch: feat_batch.Batch
) -> structure.Structure:
  """Builds the predicted structure from the model output.

  Predicted atom positions (and per-atom predicted LDDT, when present) are
  gathered from the model's token-atom layout into the batch's flat output
  layout and written into a copy of the batch's empty output structure.

  Args:
    result: Model output in a model-specific layout.
    batch: Model input batch.

  Returns:
    The predicted structure. B-factors hold the predicted LDDT (zeros if the
    model has no 'predicted_lddt' output), occupancy is 1.0 and the release
    date is cleared.
  """
  sampled_coords = result['diffusion_samples']['atom_positions']
  conversion = batch.convert_model_output

  # Gather indices from the model's token-atom layout to the flat layout.
  to_flat = atom_layout.compute_gather_idxs(
      source_layout=conversion.token_atoms_layout,
      target_layout=conversion.flat_output_layout,
  )
  flat_coords = atom_layout.convert(
      gather_info=to_flat,
      arr=sampled_coords,
      layout_axes=(-3, -2),
  )
  per_atom_shape = flat_coords.shape[:-1]

  plddt = result.get('predicted_lddt')
  if plddt is None:
    # Models without a predicted_lddt output get all-zero B-factors.
    flat_b_factors = np.zeros(per_atom_shape)
  else:
    flat_b_factors = atom_layout.convert(
        gather_info=to_flat,
        arr=plddt,
        layout_axes=(-2, -1),
    )

  (missing_idxs,) = np.nonzero(to_flat.gather_mask == 0)
  if missing_idxs.size:
    missing_layout = conversion.flat_output_layout[missing_idxs]
    missing_uids = list(
        zip(
            missing_layout.chain_id,
            missing_layout.res_id,
            missing_layout.res_name,
            missing_layout.atom_name,
        )
    )
    logging.warning(
        'Target %s: warning: %s atoms were not predicted by the '
        'model, setting their coordinates to (0, 0, 0). '
        'Missing atoms: %s',
        conversion.empty_output_struc.name,
        missing_idxs.shape[0],
        missing_uids,
    )

  predicted = conversion.empty_output_struc.copy_and_update_atoms(
      atom_x=flat_coords[..., 0],
      atom_y=flat_coords[..., 1],
      atom_z=flat_coords[..., 2],
      atom_b_factor=flat_b_factors,
      atom_occupancy=np.ones(per_atom_shape),  # Always 1.0.
  )
  # The release date is set separately (if at all) when metadata is added.
  return predicted.copy_and_update_globals(release_date=None)
def create_target_feat_embedding(
    batch: feat_batch.Batch,
    config: evoformer_network.Evoformer.Config,
    global_config: model_config.GlobalConfig,
) -> jnp.ndarray:
  """Builds the per-token target features consumed by the Evoformer.

  The token-level target features are concatenated, along the last axis, with
  the token activations of the `evoformer_conditioning` atom cross-attention
  encoder.

  Args:
    batch: Model input batch.
    config: Evoformer config; its `per_atom_conditioning` configures the
      encoder.
    global_config: Global model config.

  Returns:
    The concatenated features, in bfloat16 when `global_config.bfloat16` is
    'all' and in float32 otherwise.
  """
  use_bf16_everywhere = global_config.bfloat16 == 'all'
  dtype = jnp.bfloat16 if use_bf16_everywhere else jnp.float32

  with utils.bfloat16_context():
    token_feat = featurization.create_target_feat(
        batch, append_per_atom_features=False
    )
    token_feat = token_feat.astype(dtype)

    encoder_output = atom_cross_attention.atom_cross_att_encoder(
        token_atoms_act=None,
        trunk_single_cond=None,
        trunk_pair_cond=None,
        config=config.per_atom_conditioning,
        global_config=global_config,
        batch=batch,
        name='evoformer_conditioning',
    )
    combined = jnp.concatenate(
        [token_feat, encoder_output.token_act], axis=-1
    )
    combined = combined.astype(dtype)

  return combined
def _compute_ptm(
    result: ModelResult,
    num_tokens: int,
    asym_id: np.ndarray,
    pae_single_mask: np.ndarray,
    interface: bool,
) -> np.ndarray:
  """Computes pTM (or ipTM when `interface` is True) for every sample.

  Args:
    result: Model output; its 'tmscore_adjusted_pae_global' entry is iterated
      over its leading (per-sample) axis.
    num_tokens: Number of real tokens; matrices are cropped to this size.
    asym_id: Per-token chain (asym) ids, already cropped to `num_tokens`.
    pae_single_mask: Pair mask, cropped here to `num_tokens` x `num_tokens`.
    interface: Whether to compute the interface variant (ipTM).

  Returns:
    The per-sample scores stacked along a new leading axis.
  """
  pair_mask = pae_single_mask[:num_tokens, :num_tokens]
  scores = []
  for sample_pae in result['tmscore_adjusted_pae_global']:
    scores.append(
        confidences.predicted_tm_score(
            tm_adjusted_pae=sample_pae[:num_tokens, :num_tokens],
            asym_id=asym_id,
            pair_mask=pair_mask,
            interface=interface,
        )
    )
  return np.stack(scores, axis=0)
def _compute_chain_pair_iptm(
    num_tokens: int,
    asym_ids: np.ndarray,
    mask: np.ndarray,
    tm_adjusted_pae: np.ndarray,
) -> np.ndarray:
  """Computes the chain-pair ipTM matrix for every sample.

  Args:
    num_tokens: Number of real tokens.
    asym_ids: Per-token chain (asym) ids; cropped here to `num_tokens`.
    mask: Pair mask; cropped here to `num_tokens` x `num_tokens`.
    tm_adjusted_pae: TM-adjusted PAE, iterated over its leading (per-sample)
      axis.

  Returns:
    The per-sample chain-pair scores stacked along a new leading axis.
  """
  token_asym_ids = asym_ids[:num_tokens]
  pair_mask = mask[:num_tokens, :num_tokens]
  # Only the row axis of each PAE matrix is cropped here, unlike
  # `_compute_ptm`, which crops both axes.
  # NOTE(review): confirm chain_pairwise_predicted_tm_scores ignores the
  # padded columns.
  per_sample = [
      confidences.chain_pairwise_predicted_tm_scores(
          tm_adjusted_pae=sample_pae[:num_tokens],
          asym_id=token_asym_ids,
          pair_mask=pair_mask,
      )
      for sample_pae in tm_adjusted_pae
  ]
  return np.stack(per_sample, axis=0)
class Model(hk.Module):
  """Full model. Takes in data batch and returns model outputs."""

  class HeadsConfig(base_config.BaseConfig):
    """Configs of the diffusion, confidence and distogram heads."""

    diffusion: diffusion_head.DiffusionHead.Config = base_config.autocreate()
    confidence: confidence_head.ConfidenceHead.Config = base_config.autocreate()
    distogram: distogram_head.DistogramHead.Config = base_config.autocreate()

  class Config(base_config.BaseConfig):
    """Top-level model config.

    Attributes:
      evoformer: Config of the Evoformer trunk.
      global_config: Settings shared by all modules.
      heads: Configs of the output heads.
      num_recycles: Number of additional trunk passes; outside Haiku init the
        trunk runs `num_recycles + 1` times.
      return_embeddings: Whether `__call__` also returns the final single and
        pair embeddings.
      return_distogram: Forwarded to the distogram head as `return_distogram`.
    """

    evoformer: evoformer_network.Evoformer.Config = base_config.autocreate()
    global_config: model_config.GlobalConfig = base_config.autocreate()
    heads: 'Model.HeadsConfig' = base_config.autocreate()
    num_recycles: int = 10
    return_embeddings: bool = False
    return_distogram: bool = False

  def __init__(self, config: Config, name: str = 'diffuser'):
    """Initialises the model.

    Args:
      config: Model config.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config
    # Created once here; `_sample_diffusion` wraps it as the denoising step.
    self.diffusion_module = diffusion_head.DiffusionHead(
        self.config.heads.diffusion, self.global_config
    )

  @hk.transparent
  def _sample_diffusion(
      self,
      batch: feat_batch.Batch,
      embeddings: dict[str, jnp.ndarray],
      *,
      sample_config: diffusion_head.SampleConfig,
  ) -> dict[str, jnp.ndarray]:
    """Samples structures with the diffusion head.

    Args:
      batch: Model input batch.
      embeddings: Trunk embeddings used to condition the diffusion head.
      sample_config: Sampling settings passed to `diffusion_head.sample`.

    Returns:
      The output of `diffusion_head.sample`; `__call__` reads its
      'atom_positions' entry.
    """
    # Bind the conditioning inputs so the sampler only has to supply the
    # remaining arguments of the diffusion head.
    denoising_step = functools.partial(
        self.diffusion_module,
        batch=batch,
        embeddings=embeddings,
        use_conditioning=True,
    )

    # Note: this draws a fresh key from Haiku's RNG sequence; it does not use
    # the `key` argument of `__call__`.
    sample = diffusion_head.sample(
        denoising_step=denoising_step,
        batch=batch,
        key=hk.next_rng_key(),
        config=sample_config,
    )
    return sample

  def __call__(
      self, batch: features.BatchDict, key: jax.Array | None = None
  ) -> ModelResult:
    """Runs the recycled trunk, diffusion sampling and the output heads.

    Args:
      batch: Input feature dict; converted to a `feat_batch.Batch`.
      key: RNG key for the recycling loop. Defaults to `hk.next_rng_key()`.

    Returns:
      Dict with 'diffusion_samples', 'distogram', the confidence head outputs
      and, if `config.return_embeddings`, 'single_embeddings' and
      'pair_embeddings'.
    """
    if key is None:
      key = hk.next_rng_key()

    batch = feat_batch.Batch.from_data_dict(batch)

    embedding_module = evoformer_network.Evoformer(
        self.config.evoformer, self.global_config
    )
    target_feat = create_target_feat_embedding(
        batch=batch,
        config=embedding_module.config,
        global_config=self.global_config,
    )

    # One trunk pass: `prev` holds the previous pass's embeddings and a new
    # subkey is split off the carried key for every pass.
    def recycle_body(_, args):
      prev, key = args
      key, subkey = jax.random.split(key)
      embeddings = embedding_module(
          batch=batch,
          prev=prev,
          target_feat=target_feat,
          key=subkey,
      )
      # Keep the recycled state in float32 regardless of compute precision.
      embeddings['pair'] = embeddings['pair'].astype(jnp.float32)
      embeddings['single'] = embeddings['single'].astype(jnp.float32)
      return embeddings, key

    num_res = batch.num_res

    # Zero-initialised embeddings fed to the first trunk pass as `prev`.
    embeddings = {
        'pair': jnp.zeros(
            [num_res, num_res, self.config.evoformer.pair_channel],
            dtype=jnp.float32,
        ),
        'single': jnp.zeros(
            [num_res, self.config.evoformer.seq_channel], dtype=jnp.float32
        ),
        'target_feat': target_feat,
    }
    if hk.running_init():
      # During Haiku init a single pass is run (presumably enough to create
      # all trunk parameters).
      embeddings, _ = recycle_body(None, (embeddings, key))
    else:
      # Number of recycles is number of additional forward trunk passes.
      num_iter = self.config.num_recycles + 1
      embeddings, _ = hk.fori_loop(0, num_iter, recycle_body, (embeddings, key))

    samples = self._sample_diffusion(
        batch,
        embeddings,
        sample_config=self.config.heads.diffusion.eval,
    )

    # Compute dist_error_fn over all samples for distance error logging.
    # The confidence head is mapped over the leading axis of the sampled atom
    # positions.
    confidence_output = mapping.sharded_map(
        lambda dense_atom_positions: confidence_head.ConfidenceHead(
            self.config.heads.confidence, self.global_config
        )(
            dense_atom_positions=dense_atom_positions,
            embeddings=embeddings,
            seq_mask=batch.token_features.mask,
            token_atoms_to_pseudo_beta=batch.pseudo_beta_info.token_atoms_to_pseudo_beta,
            asym_id=batch.token_features.asym_id,
        ),
        in_axes=0,
    )(samples['atom_positions'])

    distogram = distogram_head.DistogramHead(
        self.config.heads.distogram, self.global_config
    )(batch, embeddings, return_distogram=self.config.return_distogram)

    output = {
        'diffusion_samples': samples,
        'distogram': distogram,
        **confidence_output,
    }
    if self.config.return_embeddings:
      output['single_embeddings'] = embeddings['single']
      output['pair_embeddings'] = embeddings['pair']
    return output

  @classmethod
  def get_inference_result(
      cls,
      batch: features.BatchDict,
      result: ModelResult,
      target_name: str = '',
  ) -> Iterable[InferenceResult]:
    """Get the predicted structure, scalars, and arrays for inference.

    This function also computes any inference-time quantities, which are not a
    part of the forward-pass, e.g. additional confidence scores. Note that this
    function is not serialized, so it should be slim if possible.

    Args:
      batch: data batch used for model inference, incl. TPU invalid types.
      result: output dict from the model's forward pass.
      target_name: currently unused (it is deleted on entry); kept for
        interface compatibility.

    Yields:
      inference_result: dataclass object that contains a predicted structure,
      important inference-time scalars and arrays, as well as a slightly trimmed
      dictionary of raw model result from the forward pass (for debugging).
      One result is yielded per sample of the unstacked predicted structure.
    """
    del target_name
    batch = feat_batch.Batch.from_data_dict(batch)

    # Retrieve structure and construct a predicted structure.
    pred_structure = get_predicted_structure(result=result, batch=batch)

    num_tokens = batch.token_features.seq_length.item()

    # Pair mask with entry [i, j] equal to frames.mask[i] (rows are masked,
    # columns are not).
    pae_single_mask = np.tile(
        batch.frames.mask[:, None],
        [1, batch.frames.mask.shape[0]],
    )
    ptm = _compute_ptm(
        result=result,
        num_tokens=num_tokens,
        asym_id=batch.token_features.asym_id[:num_tokens],
        pae_single_mask=pae_single_mask,
        interface=False,
    )
    iptm = _compute_ptm(
        result=result,
        num_tokens=num_tokens,
        asym_id=batch.token_features.asym_id[:num_tokens],
        pae_single_mask=pae_single_mask,
        interface=True,
    )
    ptm_iptm_average = 0.8 * iptm + 0.2 * ptm

    asym_ids = batch.token_features.asym_id[:num_tokens]
    # Map asym IDs back to chain IDs. Asym IDs are constructed from chain IDs by
    # iterating over the chain IDs, and for each unique chain ID incrementing
    # the asym ID by 1 and mapping it to the particular chain ID. Asym IDs are
    # 1-indexed, so subtract 1 to get back to the chain ID.
    chain_ids = [pred_structure.chains[asym_id - 1] for asym_id in asym_ids]
    res_ids = batch.token_features.residue_index[:num_tokens]

    if len(np.unique(asym_ids[:num_tokens])) > 1:
      # There is more than one chain, hence interface pTM (i.e. ipTM) defined,
      # so use it.
      ranking_confidence = ptm_iptm_average
    else:
      # There is only one chain, hence ipTM=NaN, so use just pTM.
      ranking_confidence = ptm

    contact_probs = result['distogram']['contact_probs']
    # Compute PAE related summaries.
    _, chain_pair_pae_min, _ = confidences.chain_pair_pae(
        num_tokens=num_tokens,
        asym_ids=batch.token_features.asym_id,
        full_pae=result['full_pae'],
        mask=pae_single_mask,
    )
    chain_pair_pde_mean, chain_pair_pde_min = confidences.chain_pair_pde(
        num_tokens=num_tokens,
        asym_ids=batch.token_features.asym_id,
        full_pde=result['full_pde'],
    )
    intra_chain_single_pde, cross_chain_single_pde, _ = confidences.pde_single(
        num_tokens,
        batch.token_features.asym_id,
        result['full_pde'],
        contact_probs,
    )
    pae_metrics = confidences.pae_metrics(
        num_tokens=num_tokens,
        asym_ids=batch.token_features.asym_id,
        full_pae=result['full_pae'],
        mask=pae_single_mask,
        contact_probs=contact_probs,
        tm_adjusted_pae=result['tmscore_adjusted_pae_interface'],
    )
    ranking_confidence_pae = confidences.rank_metric(
        result['full_pae'],
        contact_probs * batch.frames.mask[:, None].astype(float),
    )
    chain_pair_iptm = _compute_chain_pair_iptm(
        num_tokens=num_tokens,
        asym_ids=batch.token_features.asym_id,
        mask=pae_single_mask,
        tm_adjusted_pae=result['tmscore_adjusted_pae_interface'],
    )
    # iptm_ichain is a vector of per-chain ptm values. iptm_ichain[0],
    # for example, is just the zeroth diagonal entry of the chain pair iptm
    # matrix:
    # [[x, , ],
    #  [ , , ],
    #  [ , , ]]
    iptm_ichain = chain_pair_iptm.diagonal(axis1=-2, axis2=-1)
    # iptm_xchain is a vector of cross-chain interactions for each chain.
    # iptm_xchain[0], for example, is an average of chain 0's interactions with
    # other chains:
    # [[ ,x,x],
    #  [x, , ],
    #  [x, , ]]
    iptm_xchain = confidences.get_iptm_xchain(chain_pair_iptm)

    predicted_distance_errors = result['average_pde']

    # Computing solvent accessible area with dssp can be slow for large
    # structures with lots of chains, so we parallelize the call.
    pred_structures = pred_structure.unstack()
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=min(len(pred_structures), 32)
    ) as executor:
      has_clash = list(executor.map(confidences.has_clash, pred_structures))
      fraction_disordered = list(
          executor.map(confidences.fraction_disordered, pred_structures)
      )

    # Note: the loop variable rebinds `pred_structure` to each per-sample
    # structure; the unstacked full structure is not used after this point.
    for idx, pred_structure in enumerate(pred_structures):
      ranking_score = confidences.get_ranking_score(
          ptm=ptm[idx],
          iptm=iptm[idx],
          fraction_disordered_=fraction_disordered[idx],
          has_clash_=has_clash[idx],
      )
      yield InferenceResult(
          predicted_structure=pred_structure,
          numerical_data={
              'full_pde': result['full_pde'][idx, :num_tokens, :num_tokens],
              'full_pae': result['full_pae'][idx, :num_tokens, :num_tokens],
              'contact_probs': contact_probs[:num_tokens, :num_tokens],
          },
          metadata={
              'predicted_distance_error': predicted_distance_errors[idx],
              'ranking_score': ranking_score,
              'fraction_disordered': fraction_disordered[idx],
              'has_clash': has_clash[idx],
              'predicted_tm_score': ptm[idx],
              'interface_predicted_tm_score': iptm[idx],
              'chain_pair_pde_mean': chain_pair_pde_mean[idx],
              'chain_pair_pde_min': chain_pair_pde_min[idx],
              'chain_pair_pae_min': chain_pair_pae_min[idx],
              'ptm': ptm[idx],
              'iptm': iptm[idx],
              'ptm_iptm_average': ptm_iptm_average[idx],
              'intra_chain_single_pde': intra_chain_single_pde[idx],
              'cross_chain_single_pde': cross_chain_single_pde[idx],
              'pae_ichain': pae_metrics['pae_ichain'][idx],
              'pae_xchain': pae_metrics['pae_xchain'][idx],
              'ranking_confidence': ranking_confidence[idx],
              'ranking_confidence_pae': ranking_confidence_pae[idx],
              'chain_pair_iptm': chain_pair_iptm[idx],
              'iptm_ichain': iptm_ichain[idx],
              'iptm_xchain': iptm_xchain[idx],
              'token_chain_ids': chain_ids,
              'token_res_ids': res_ids,
          },
          model_id=result['__identifier__'],
          debug_outputs={},
      )
================================================
FILE: src/alphafold3/model/model_config.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Global config for the model."""
from collections.abc import Sequence
from typing import Literal, TypeAlias
from alphafold3.common import base_config
import tokamax
# A two-dimensional size pair in which either entry may be None.
_Shape2DType: TypeAlias = tuple[int | None, int | None]


class GlobalConfig(base_config.BaseConfig):
  """Global configuration for the AlphaFold3 model."""

  # Floating-point precision mode. `create_target_feat_embedding` (model.py)
  # uses bfloat16 when this is 'all' and float32 otherwise.
  # NOTE(review): how 'none' and 'intermediate' differ is decided by other
  # consumers of this config; confirm there.
  bfloat16: Literal['all', 'none', 'intermediate'] = 'all'
  # Initialisation scheme selector, presumably for final layers; confirm
  # against the layer code that reads it.
  final_init: Literal['zeros', 'linear'] = 'zeros'
  # Chunking spec for pair attention as a sequence of 2-D size pairs; None
  # entries presumably mean "unbounded". TODO: confirm the format with the
  # consumer.
  pair_attention_chunk_size: Sequence[_Shape2DType] = ((1536, 128), (None, 32))
  # Sharding spec for the pair transition, in the same 2-D pair format.
  pair_transition_shard_spec: Sequence[_Shape2DType] = (
      (2048, None),
      (None, 1024),
  )
  # Note: flash_attention_implementation = 'xla' means no flash attention.
  flash_attention_implementation: tokamax.DotProductAttentionImplementation = (
      'triton'
  )
================================================
FILE: src/alphafold3/model/msa_pairing.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for producing "paired" and "unpaired" MSA features for each chain.
The paired MSA:
- Is made from the result of the all_seqs MSA query.
- Is ordered such that you can concatenate features across chains and related
sequences will end up on the same row. Related here means "from the same
species". Gaps are added to facilitate this whenever a sequence has no
suitable pair.
The unpaired MSA:
- Is made from the results of the remaining MSA queries.
- Has no special ordering properties.
- Is deduplicated such that it doesn't contain any sequences in the paired MSA.
"""
from typing import Mapping, MutableMapping, Sequence
from alphafold3.model import data_constants
import numpy as np
def _align_species(
all_species: Sequence[bytes],
chains_species_to_rows: Sequence[Mapping[bytes, np.ndarray]],
min_hits_per_species: Mapping[bytes, int],
) -> np.ndarray:
"""Aligns MSA row indices based on species.
Within a species, MSAs are aligned based on their original order (the first
sequence for a species in the first chain's MSA is aligned to the first
sequence for the same species in the second chain's MSA).
Args:
all_species: A list of all unique species identifiers.
chains_species_to_rows: A dictionary for each chain, that maps species to
the set of MSA row indices from that species in that chain.
min_hits_per_species: A mapping from species id, to the minimum MSA size
across chains for that species (ignoring chains with zero hits).
Returns:
A matrix of size [num_msa_rows, num_chains], where the i,j element is an
index into the jth chains MSA. Each row consists of sequences from each
chain for the same species (or -1 if that chain has no sequences for that
species).
"""
# Each species block is of size [num_seqs x num_chains] and consists of
# indices into the respective MSAs that have been aligned and are all for the
# same species.
species_blocks = []
for species in all_species:
chain_row_indices = []
for species_to_rows in chains_species_to_rows:
min_msa_size = min_hits_per_species[species]
if species not in species_to_rows:
# If a given chain has no hits for a species then we pad it with -1's,
# later on these values are used to make sure each feature is padded
# with its appropriate pad value.
row_indices = np.full(min_msa_size, fill_value=-1, dtype=np.int32)
else:
# We crop down to the smallest MSA for a given species across chains.
row_indices = species_to_rows[species][:min_msa_size]
chain_row_indices.append(row_indices)
species_block = np.stack(chain_row_indices, axis=1)
species_blocks.append(species_block)
aligned_matrix = np.concatenate(species_blocks, axis=0)
return aligned_matrix
def create_paired_features(
    chains: Sequence[MutableMapping[str, np.ndarray]],
    max_paired_sequences: int,
    nonempty_chain_ids: set[str],
    max_hits_per_species: int,
) -> Sequence[MutableMapping[str, np.ndarray]]:
  """Creates per-chain MSA features where the MSAs have been aligned.

  Args:
    chains: A list of feature dicts, one for each chain.
    max_paired_sequences: No more than this many paired sequences will be
      returned from this function.
    nonempty_chain_ids: A set of chain ids (str) that are included in the crop;
      chains whose id is not in this set are not paired and only receive the
      query row plus padding rows.
    max_hits_per_species: No more than this number of sequences will be returned
      for a given species.

  Returns:
    An updated feature dictionary for each chain, where the {}_all_seq features
    have been aligned so that the nth row in chain 1 is aligned to the nth row
    in chain 2's features.
  """
  # The number of chains that the given species appears in - we rank hits
  # across more chains higher.
  species_num_chains = {}

  # For each chain we keep a mapping from species to the row indices in the
  # original MSA for that chain.
  chains_species_to_rows = []

  # Keep track of the minimum number of hits across chains for a given species.
  min_hits_per_species = {}

  for chain in chains:
    species_ids = chain['msa_species_identifiers_all_seq']

    # The query gets an empty species_id, so no pairing happens for this row.
    if (
        species_ids.size == 0
        or (species_ids.size == 1 and not species_ids[0])
        or chain['chain_id'] not in nonempty_chain_ids
    ):
      chains_species_to_rows.append({})
      continue

    # For each species keep track of which row indices in the original MSA are
    # from this species.
    row_indices = np.arange(len(species_ids))
    # The grouping np.split code requires that the input is already clustered
    # by species id. The sort must be stable so that, within a species, row
    # indices stay in their original MSA order: _align_species pairs rows by
    # that order and crops each species to its first rows.
    sort_idxs = species_ids.argsort(kind='stable')
    species_ids = species_ids[sort_idxs]
    row_indices = row_indices[sort_idxs]

    species, unique_row_indices = np.unique(species_ids, return_index=True)
    grouped_row_indices = np.split(row_indices, unique_row_indices[1:])
    species_to_rows = dict(zip(species, grouped_row_indices, strict=True))
    chains_species_to_rows.append(species_to_rows)

    for s in species:
      species_num_chains[s] = species_num_chains.get(s, 0) + 1

    for species, row_indices in species_to_rows.items():
      min_hits_per_species[species] = min(
          min_hits_per_species.get(species, max_hits_per_species),
          len(row_indices),
      )

  # Construct a mapping from the number of chains a species appears in to
  # the list of species with that count. Species seen in a single chain (and
  # the empty query species) cannot be paired and are skipped.
  num_chains_to_species = {}
  for species, num_chains in species_num_chains.items():
    if not species or num_chains <= 1:
      continue
    if num_chains not in num_chains_to_species:
      num_chains_to_species[num_chains] = []
    num_chains_to_species[num_chains].append(species)

  num_rows_seen = 0
  # We always keep the first row as it is the query sequence.
  all_rows = [np.array([[0] * len(chains)], dtype=np.int32)]

  # We prioritize species that have hits across more chains.
  for num_chains in sorted(num_chains_to_species, reverse=True):
    all_species = num_chains_to_species[num_chains]

    # Align all the per-chain row indices by species, so every paired row is
    # for a single species.
    rows = _align_species(
        all_species, chains_species_to_rows, min_hits_per_species
    )
    # Sort rows by the product of the original indices in the respective chain
    # MSAS, so as to rank hits that appear higher in the original MSAs higher.
    rank_metric = np.abs(np.prod(rows.astype(np.float32), axis=1))
    sorted_rows = rows[np.argsort(rank_metric), :]
    all_rows.append(sorted_rows)
    num_rows_seen += rows.shape[0]
    if num_rows_seen >= max_paired_sequences:
      break

  all_rows = np.concatenate(all_rows, axis=0)
  all_rows = all_rows[:max_paired_sequences, :]

  # Now we just have to select the relevant rows from the original msa and
  # deletion matrix features
  paired_chains = []
  for chain_idx, chain in enumerate(chains):
    out_chain = {k: v for k, v in chain.items() if 'all_seq' not in k}
    selected_row_indices = all_rows[:, chain_idx]
    for feat_name in {'msa', 'deletion_matrix'}:
      all_seq_name = f'{feat_name}_all_seq'
      feat_value = chain[all_seq_name]

      # The selected row indices are padded to be the same shape for each chain,
      # they are padded with -1's, so we add a single row onto the feature with
      # the appropriate pad value. This has the effect that we correctly pad
      # each feature since all padded indices will select this padding row.
      pad_value = data_constants.MSA_PAD_VALUES[feat_name]
      feat_value = np.concatenate([
          feat_value,
          np.full((1, feat_value.shape[1]), pad_value, feat_value.dtype),
      ])

      feat_value = feat_value[selected_row_indices, :]
      out_chain[all_seq_name] = feat_value
    out_chain['num_alignments_all_seq'] = np.array(
        out_chain['msa_all_seq'].shape[0]
    )
    paired_chains.append(out_chain)
  return paired_chains
def deduplicate_unpaired_sequences(
    np_chains: Sequence[MutableMapping[str, np.ndarray]],
) -> Sequence[MutableMapping[str, np.ndarray]]:
  """Deduplicates unpaired sequences based on paired sequences.

  For each chain, every row of the unpaired MSA (`msa`) whose int8 contents
  exactly match some row of the paired MSA (`msa_all_seq`) is removed from all
  per-sequence MSA features, and `num_alignments` is updated to the new row
  count.

  Args:
    np_chains: Per-chain feature dicts. They are modified in place.

  Returns:
    The same `np_chains` object (empty input is returned unchanged).
  """
  if not np_chains:
    return np_chains
  feature_names = np_chains[0].keys()
  msa_features = (
      data_constants.NUM_SEQ_MSA_FEATURES
      + data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES
  )
  for chain in np_chains:
    # Store the raw row bytes rather than their hash(): comparing hashes could
    # drop a unique unpaired sequence on a collision.
    paired_rows = {
        row.tobytes() for row in chain['msa_all_seq'].astype(np.int8)
    }
    # Go through unpaired MSA seqs and keep only rows that are not already
    # present in the paired MSA.
    keep_rows = [
        row_num
        for row_num, seq in enumerate(chain['msa'].astype(np.int8))
        if seq.tobytes() not in paired_rows
    ]
    for feature_name in feature_names:
      if feature_name in msa_features:
        chain[feature_name] = chain[feature_name][keep_rows]
    chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
  return np_chains
def choose_paired_unpaired_msa_crop_sizes(
unpaired_msa: np.ndarray,
paired_msa: np.ndarray | None,
total_msa_crop_size: int,
max_paired_sequences: int,
) -> tuple[int, int | None]:
"""Returns the sizes of the MSA crop and MSA_all_seq crop.
NOTE: Unpaired + paired MSA sizes can exceed total_msa_size when
there are lots of gapped rows. Through the pairing logic another chain(s)
will have fewer than total_msa_size.
Args:
unpaired_msa: The unpaired MSA array (not all_seq).
paired_msa: The paired MSA array (all_seq).
total_msa_crop_size: The maximum total number of sequences to crop to.
max_paired_sequences: The maximum number of sequences that can come from
MSA pairing.
Returns:
A tuple of:
The size of the reduced MSA crop (not all_seq features).
The size of the unreduced MSA crop (for all_seq features) or None, if
paired_msa is None.
"""
if paired_msa is not None:
paired_crop_size = np.minimum(paired_msa.shape[0], max_paired_sequences)
# We reduce the number of un-paired sequences, by the number of times a
# sequence from this chains MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
cropped_all_seq_msa = paired_msa[:max_paired_sequences]
num_non_gapped_pairs = cropped_all_seq_msa.shape[0]
assert num_non_gapped_pairs <= max_paired_sequences
unpaired_crop_size = np.minimum(
unpaired_msa.shape[0], total_msa_crop_size - num_non_gapped_pairs
)
assert unpaired_crop_size >= 0
else:
unpaired_crop_size = np.minimum(unpaired_msa.shape[0], total_msa_crop_size)
paired_crop_size = None
return unpaired_crop_size, paired_crop_size
def remove_all_gapped_rows_from_all_seqs(
    chains_list: Sequence[dict[str, np.ndarray]], asym_ids: Sequence[float]
) -> Sequence[dict[str, np.ndarray]]:
  """Removes all gapped rows from all_seq feat based on selected asym_ids.

  A paired row is kept if it contains at least one non-gap residue in any of
  the chains whose first `asym_id` is in `asym_ids`. The same row selection is
  then applied to the per-sequence `*_all_seq` features of every chain, and
  `num_alignments_all_seq` is set to the number of kept rows.

  Args:
    chains_list: Per-chain feature dicts; modified in place.
    asym_ids: The asym ids of the chains that decide which rows are gapped.

  Returns:
    The same `chains_list` object.
  """
  selected_msas = [
      chain['msa_all_seq']
      for chain in chains_list
      if chain['asym_id'][0] in asym_ids
  ]
  merged_msa_all_seq = np.concatenate(selected_msas, axis=1)
  non_gapped_keep_rows = np.any(
      merged_msa_all_seq != data_constants.MSA_GAP_IDX, axis=1
  )
  # The first chain's keys decide which features are cropped for every chain.
  reference_chain = list(chains_list)[0]
  per_sequence_features = (
      data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES
      + data_constants.NUM_SEQ_MSA_FEATURES
  )
  for chain in chains_list:
    for feat_name in reference_chain:
      if '_all_seq' not in feat_name:
        continue
      if feat_name.split('_all_seq')[0] not in per_sequence_features:
        continue
      # For consistency we do this for all chains even though the gapped rows
      # are based on a selected set of asym_ids.
      chain[feat_name] = chain[feat_name][non_gapped_keep_rows]
    chain['num_alignments_all_seq'] = np.sum(non_gapped_keep_rows)
  return chains_list
================================================
FILE: src/alphafold3/model/network/atom_cross_attention.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Per-atom cross attention."""
import dataclasses
from alphafold3.common import base_config
from alphafold3.model import feat_batch
from alphafold3.model import model_config
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils
from alphafold3.model.network import diffusion_transformer
import jax
import jax.numpy as jnp
class AtomCrossAttEncoderConfig(base_config.BaseConfig):
  """Configuration for `atom_cross_att_encoder` / `_per_atom_conditioning`."""

  # Width of the per-token features produced by aggregating over atoms.
  per_token_channels: int = 768
  # Width of the per-atom single activations and conditioning.
  per_atom_channels: int = 128
  # Cross-attention transformer run over the atom subsets.
  atom_transformer: diffusion_transformer.CrossAttTransformer.Config = (
      base_config.autocreate(num_intermediate_factor=2, num_blocks=3)
  )
  # Width of the per-atom-pair conditioning.
  per_atom_pair_channels: int = 16
def _per_atom_conditioning(
    config: AtomCrossAttEncoderConfig, batch: feat_batch.Batch, name: str
) -> tuple[jnp.ndarray, jnp.ndarray]:
  """Computes single and pair conditioning for all atoms in each token.

  Args:
    config: Encoder config providing the per-atom and per-atom-pair widths.
    batch: Feature batch; only `batch.ref_structure` is read.
    name: Prefix for the names of all Haiku modules created here.

  Returns:
    A tuple `(single_cond, pair_cond)`. `single_cond` has shape
    (num_tokens, num_dense, per_atom_channels) and is zeroed where the
    reference mask is off. `pair_cond` has shape
    (num_tokens, num_dense, num_dense, per_atom_pair_channels) and combines
    projections of `single_cond` with embeddings of reference-position
    offsets and inverse squared distances.
  """
  ref = batch.ref_structure
  atom_width = config.per_atom_channels
  pair_width = config.per_atom_pair_channels
  mask_f32 = ref.mask.astype(jnp.float32)[:, :, None]

  # --- Single conditioning: sum of linear embeddings of reference features.
  single = hm.Linear(
      atom_width, precision='highest', name=f'{name}_embed_ref_pos'
  )(ref.positions)
  single += hm.Linear(atom_width, name=f'{name}_embed_ref_mask')(mask_f32)
  # Elements are encoded by atomic number, so 128 classes cover the periodic
  # table.
  single += hm.Linear(atom_width, name=f'{name}_embed_ref_element')(
      jax.nn.one_hot(ref.element, 128)
  )
  single += hm.Linear(atom_width, name=f'{name}_embed_ref_charge')(
      jnp.arcsinh(ref.charge)[:, :, None]
  )
  # Atom-name characters are ASCII codes minus 32, so 64 classes cover the
  # characters from 32 up to 96. The per-character one-hots are flattened
  # into one feature vector per atom.
  name_chars_one_hot = jax.nn.one_hot(ref.atom_name_chars, 64)
  num_token, num_dense, _ = single.shape
  single += hm.Linear(atom_width, name=f'{name}_embed_ref_atom_name')(
      name_chars_one_hot.reshape(num_token, num_dense, -1)
  )
  single *= mask_f32

  # --- Pair conditioning: outer sum of projected single features ...
  single_relu = jax.nn.relu(single)
  row_proj = hm.Linear(pair_width, name=f'{name}_single_to_pair_cond_row')(
      single_relu
  )
  col_proj = hm.Linear(pair_width, name=f'{name}_single_to_pair_cond_col')(
      single_relu
  )
  pair = row_proj[:, :, None, :] + col_proj[:, None, :, :]

  # ... plus embeddings of pairwise reference offsets ...
  offsets = ref.positions[:, :, None, :] - ref.positions[:, None, :, :]
  pair += hm.Linear(
      pair_width,
      precision='highest',
      name=f'{name}_embed_pair_offsets',
  )(offsets)

  # ... and of inverse (1 + squared distance).
  squared_distances = jnp.sum(jnp.square(offsets), axis=-1)
  pair += hm.Linear(pair_width, name=f'{name}_embed_pair_distances')(
      1.0 / (1 + squared_distances[:, :, :, None])
  )
  return single, pair
@dataclasses.dataclass(frozen=True)
class AtomCrossAttEncoderOutput:
  """Outputs of `atom_cross_att_encoder`, passed to `atom_cross_att_decoder`.

  Besides the aggregated per-token activations, it carries the queries-layout
  skip connection, masks and conditioning tensors so the decoder can reuse
  them without recomputation.
  """

  token_act: jnp.ndarray  # (num_tokens, ch)
  skip_connection: jnp.ndarray  # (num_subsets, num_queries, ch)
  queries_mask: jnp.ndarray  # (num_subsets, num_queries)
  queries_single_cond: jnp.ndarray  # (num_subsets, num_queries, ch)
  keys_mask: jnp.ndarray  # (num_subsets, num_keys)
  keys_single_cond: jnp.ndarray  # (num_subsets, num_keys, ch)
  pair_cond: jnp.ndarray  # (num_subsets, num_queries, num_keys, ch)


# Register the dataclass as a JAX pytree: every field is an array (data) leaf
# and there are no static metadata fields.
jax.tree_util.register_dataclass(
    AtomCrossAttEncoderOutput,
    data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)],
    meta_fields=[],
)
def atom_cross_att_encoder(
    token_atoms_act: jnp.ndarray | None,  # (num_tokens, max_atoms_per_token, 3)
    trunk_single_cond: jnp.ndarray | None,  # (num_tokens, ch)
    trunk_pair_cond: jnp.ndarray | None,  # (num_tokens, num_tokens, ch)
    config: AtomCrossAttEncoderConfig,
    global_config: model_config.GlobalConfig,
    batch: feat_batch.Batch,
    name: str,
) -> AtomCrossAttEncoderOutput:
  """Cross-attention on flat atom subsets and mapping to per-token features.

  Builds per-atom single and atom-pair conditioning (optionally enriched with
  trunk single/pair conditioning), runs a cross-attention transformer over
  atom subsets in queries layout, and mean-aggregates the projected result
  over each token's atoms.

  Args:
    token_atoms_act: Optional per-atom input in token-atoms layout (e.g.
      positions). If None, the queries start from the static conditioning.
    trunk_single_cond: Optional per-token single conditioning from the trunk.
    trunk_pair_cond: Optional token-pair conditioning from the trunk.
    config: Encoder configuration.
    global_config: Global model configuration (provides `final_init`).
    batch: Feature batch providing reference-structure features, the atom
      mask and the atom-layout gather infos.
    name: Prefix for the names of all Haiku modules created here.

  Returns:
    An `AtomCrossAttEncoderOutput` with the per-token activations, the masked
    transformer output in queries layout (skip connection), and the masks and
    conditioning tensors consumed by the decoder.
  """
  c = config

  # Compute single conditioning from atom meta data and convert to queries
  # layout.
  # (num_subsets, num_queries, channels)
  token_atoms_single_cond, _ = _per_atom_conditioning(config, batch, name)
  token_atoms_mask = batch.predicted_structure_info.atom_mask
  queries_single_cond = atom_layout.convert(
      batch.atom_cross_att.token_atoms_to_queries,
      token_atoms_single_cond,
      layout_axes=(-3, -2),
  )
  queries_mask = atom_layout.convert(
      batch.atom_cross_att.token_atoms_to_queries,
      token_atoms_mask,
      layout_axes=(-2, -1),
  )

  # If provided, broadcast single conditioning from trunk to all queries
  if trunk_single_cond is not None:
    trunk_single_cond = hm.Linear(
        c.per_atom_channels,
        precision='highest',
        initializer=global_config.final_init,
        name=f'{name}_embed_trunk_single_cond',
    )(
        hm.LayerNorm(
            use_fast_variance=False,
            create_offset=False,
            name=f'{name}_lnorm_trunk_single_cond',
        )(trunk_single_cond)
    )
    queries_single_cond += atom_layout.convert(
        batch.atom_cross_att.tokens_to_queries,
        trunk_single_cond,
        layout_axes=(-2,),
    )

  if token_atoms_act is None:
    # if no token_atoms_act is given (e.g. begin of evoformer), we use the
    # static conditioning only
    queries_act = queries_single_cond
  else:
    # Convert token_atoms_act to queries layout and map to per_atom_channels
    # (num_subsets, num_queries, channels)
    queries_act = atom_layout.convert(
        batch.atom_cross_att.token_atoms_to_queries,
        token_atoms_act,
        layout_axes=(-3, -2),
    )
    queries_act = hm.Linear(
        c.per_atom_channels,
        precision='highest',
        name=f'{name}_atom_positions_to_features',
    )(queries_act)
    queries_act *= queries_mask[..., None]
    queries_act += queries_single_cond

  # Gather the keys from the queries.
  keys_single_cond = atom_layout.convert(
      batch.atom_cross_att.queries_to_keys,
      queries_single_cond,
      layout_axes=(-3, -2),
  )
  keys_mask = atom_layout.convert(
      batch.atom_cross_att.queries_to_keys, queries_mask, layout_axes=(-2, -1)
  )

  # Embed single features into the pair conditioning.
  # shape (num_subsets, num_queries, num_keys, ch)
  row_act = hm.Linear(
      c.per_atom_pair_channels, name=f'{name}_single_to_pair_cond_row'
  )(jax.nn.relu(queries_single_cond))

  # This is the same gather (same inputs) as `keys_single_cond` above.
  pair_cond_keys_input = atom_layout.convert(
      batch.atom_cross_att.queries_to_keys,
      queries_single_cond,
      layout_axes=(-3, -2),
  )
  col_act = hm.Linear(
      c.per_atom_pair_channels, name=f'{name}_single_to_pair_cond_col'
  )(jax.nn.relu(pair_cond_keys_input))
  pair_act = row_act[:, :, None, :] + col_act[:, None, :, :]

  if trunk_pair_cond is not None:
    # If provided, broadcast the pair conditioning for the trunk (evoformer
    # pairs) to the atom pair activations. This should boost ligands, but also
    # help for cross attention within proteins, because we always have atoms
    # from multiple residues in a subset.
    # Map trunk pair conditioning to per_atom_pair_channels
    # (num_tokens, num_tokens, per_atom_pair_channels)
    trunk_pair_cond = hm.Linear(
        c.per_atom_pair_channels,
        precision='highest',
        initializer=global_config.final_init,
        name=f'{name}_embed_trunk_pair_cond',
    )(
        hm.LayerNorm(
            use_fast_variance=False,
            create_offset=False,
            name=f'{name}_lnorm_trunk_pair_cond',
        )(trunk_pair_cond)
    )

    # Create the GatherInfo into a flattened trunk_pair_cond from the
    # queries and keys gather infos.
    num_tokens = trunk_pair_cond.shape[0]
    # (num_subsets, num_queries)
    tokens_to_queries = batch.atom_cross_att.tokens_to_queries
    # (num_subsets, num_keys)
    tokens_to_keys = batch.atom_cross_att.tokens_to_keys
    # (num_subsets, num_queries, num_keys)
    # Flat index of token pair (q, k) in a row-major (num_tokens, num_tokens)
    # array is num_tokens * q + k; the pair is valid only if both are valid.
    trunk_pair_to_atom_pair = atom_layout.GatherInfo(
        gather_idxs=(
            num_tokens * tokens_to_queries.gather_idxs[:, :, None]
            + tokens_to_keys.gather_idxs[:, None, :]
        ),
        gather_mask=(
            tokens_to_queries.gather_mask[:, :, None]
            & tokens_to_keys.gather_mask[:, None, :]
        ),
        input_shape=jnp.array((num_tokens, num_tokens)),
    )
    # Gather the conditioning and add it to the atom-pair activations.
    pair_act += atom_layout.convert(
        trunk_pair_to_atom_pair, trunk_pair_cond, layout_axes=(-3, -2)
    )

  # Embed pairwise offsets
  queries_ref_pos = atom_layout.convert(
      batch.atom_cross_att.token_atoms_to_queries,
      batch.ref_structure.positions,
      layout_axes=(-3, -2),
  )
  queries_ref_space_uid = atom_layout.convert(
      batch.atom_cross_att.token_atoms_to_queries,
      batch.ref_structure.ref_space_uid,
      layout_axes=(-2, -1),
  )
  keys_ref_pos = atom_layout.convert(
      batch.atom_cross_att.queries_to_keys,
      queries_ref_pos,
      layout_axes=(-3, -2),
  )
  # NOTE(review): unlike `keys_ref_pos` (gathered from `queries_ref_pos`),
  # this gathers with `queries_to_keys` directly from
  # `batch.ref_structure.ref_space_uid`, which the code above converts from
  # token-atoms layout with `token_atoms_to_queries`. `queries_ref_space_uid`
  # looks like the layout-consistent input. Confirm whether this is intended
  # (e.g. kept for trained-weights compatibility) before changing it.
  keys_ref_space_uid = atom_layout.convert(
      batch.atom_cross_att.queries_to_keys,
      batch.ref_structure.ref_space_uid,
      layout_axes=(-2, -1),
  )

  # Offsets are only meaningful between atoms sharing a reference space uid.
  offsets_valid = (
      queries_ref_space_uid[:, :, None] == keys_ref_space_uid[:, None, :]
  )
  offsets = queries_ref_pos[:, :, None, :] - keys_ref_pos[:, None, :, :]
  pair_act += (
      hm.Linear(
          c.per_atom_pair_channels,
          precision='highest',
          name=f'{name}_embed_pair_offsets',
      )(offsets)
      * offsets_valid[:, :, :, None]
  )

  # Embed pairwise inverse squared distances
  sq_dists = jnp.sum(jnp.square(offsets), axis=-1)
  pair_act += (
      hm.Linear(c.per_atom_pair_channels, name=f'{name}_embed_pair_distances')(
          1.0 / (1 + sq_dists[:, :, :, None])
      )
      * offsets_valid[:, :, :, None]
  )
  # Embed offsets valid mask
  pair_act += hm.Linear(
      c.per_atom_pair_channels, name=f'{name}_embed_pair_offsets_valid'
  )(offsets_valid[:, :, :, None].astype(jnp.float32))

  # Run a small MLP on the pair activations (residual update of pair_act).
  pair_act2 = hm.Linear(
      c.per_atom_pair_channels, initializer='relu', name=f'{name}_pair_mlp_1'
  )(jax.nn.relu(pair_act))
  pair_act2 = hm.Linear(
      c.per_atom_pair_channels, initializer='relu', name=f'{name}_pair_mlp_2'
  )(jax.nn.relu(pair_act2))
  pair_act += hm.Linear(
      c.per_atom_pair_channels,
      initializer=global_config.final_init,
      name=f'{name}_pair_mlp_3',
  )(jax.nn.relu(pair_act2))

  # Run the atom cross attention transformer.
  queries_act = diffusion_transformer.CrossAttTransformer(
      c.atom_transformer, global_config, name=f'{name}_atom_transformer_encoder'
  )(
      queries_act=queries_act,
      queries_mask=queries_mask,
      queries_to_keys=batch.atom_cross_att.queries_to_keys,
      keys_mask=keys_mask,
      queries_single_cond=queries_single_cond,
      keys_single_cond=keys_single_cond,
      pair_cond=pair_act,
  )
  queries_act *= queries_mask[..., None]
  skip_connection = queries_act

  # Convert back to token-atom layout and aggregate to tokens
  queries_act = hm.Linear(
      c.per_token_channels, name=f'{name}_project_atom_features_for_aggr'
  )(queries_act)
  token_atoms_act = atom_layout.convert(
      batch.atom_cross_att.queries_to_token_atoms,
      queries_act,
      layout_axes=(-3, -2),
  )
  # Masked mean over each token's atoms.
  token_act = utils.mask_mean(
      token_atoms_mask[..., None], jax.nn.relu(token_atoms_act), axis=-2
  )

  return AtomCrossAttEncoderOutput(
      token_act=token_act,
      skip_connection=skip_connection,
      queries_mask=queries_mask,
      queries_single_cond=queries_single_cond,
      keys_mask=keys_mask,
      keys_single_cond=keys_single_cond,
      pair_cond=pair_act,
  )
class AtomCrossAttDecoderConfig(base_config.BaseConfig):
  """Configuration for `atom_cross_att_decoder`."""

  # Width of the per-atom activations the token features are projected to.
  per_atom_channels: int = 128
  # Cross-attention transformer run over the atom subsets.
  atom_transformer: diffusion_transformer.CrossAttTransformer.Config = (
      base_config.autocreate(num_intermediate_factor=2, num_blocks=3)
  )
def atom_cross_att_decoder(
    token_act: jnp.ndarray,  # (num_tokens, ch)
    enc: AtomCrossAttEncoderOutput,
    config: AtomCrossAttDecoderConfig,
    global_config: model_config.GlobalConfig,
    batch: feat_batch.Batch,
    name: str,
):  # (num_tokens, max_atoms_per_token, 3)
  """Mapping to per-atom features and self-attention on subsets.

  Per-token activations are projected to the per-atom width, copied onto
  every atom slot of their token, and combined with the encoder's skip
  connection in queries layout. The result is refined by a cross-attention
  transformer that reuses the encoder's masks and conditioning, then
  layer-normed and projected to a 3-vector per atom.

  Args:
    token_act: Per-token activations.
    enc: Output of the matching `atom_cross_att_encoder` call.
    config: Decoder configuration.
    global_config: Global model configuration (provides `final_init`).
    batch: Feature batch providing the atom-layout gather infos.
    name: Prefix for the names of all Haiku modules created here.

  Returns:
    The per-atom position update in token-atoms layout.
  """
  layouts = batch.atom_cross_att
  atom_width = config.per_atom_channels
  query_mask = enc.queries_mask[..., None]

  # Project per-token features down to the per-atom width.
  projected_tokens = hm.Linear(
      atom_width, name=f'{name}_project_token_features_for_broadcast'
  )(token_act)

  # Replicate each token's features onto all of its atom slots, then move
  # them into the queries layout.
  num_token, max_atoms_per_token = layouts.queries_to_token_atoms.shape
  per_atom_features = jnp.broadcast_to(
      projected_tokens[:, None, :],
      (num_token, max_atoms_per_token, atom_width),
  )
  queries = atom_layout.convert(
      layouts.token_atoms_to_queries,
      per_atom_features,
      layout_axes=(-3, -2),
  )
  queries = (queries + enc.skip_connection) * query_mask

  # Run the atom cross attention transformer.
  transformer = diffusion_transformer.CrossAttTransformer(
      config.atom_transformer,
      global_config,
      name=f'{name}_atom_transformer_decoder',
  )
  queries = transformer(
      queries_act=queries,
      queries_mask=enc.queries_mask,
      queries_to_keys=layouts.queries_to_keys,
      keys_mask=enc.keys_mask,
      queries_single_cond=enc.queries_single_cond,
      keys_single_cond=enc.keys_single_cond,
      pair_cond=enc.pair_cond,
  )
  queries = queries * query_mask

  normed = hm.LayerNorm(
      use_fast_variance=False,
      create_offset=False,
      name=f'{name}_atom_features_layer_norm',
  )(queries)
  per_query_update = hm.Linear(
      3,
      initializer=global_config.final_init,
      precision='highest',
      name=f'{name}_atom_features_to_position_update',
  )(normed)

  # Back to token-atoms layout.
  return atom_layout.convert(
      layouts.queries_to_token_atoms,
      per_query_update,
      layout_axes=(-3, -2),
  )
================================================
FILE: src/alphafold3/model/network/confidence_head.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Confidence Head."""
from alphafold3.common import base_config
from alphafold3.model import model_config
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils
from alphafold3.model.network import modules
from alphafold3.model.network import template_modules
import haiku as hk
import jax
import jax.numpy as jnp
def _safe_norm(x, keepdims, axis, eps=1e-8):
return jnp.sqrt(eps + jnp.sum(jnp.square(x), axis=axis, keepdims=keepdims))
class ConfidenceHead(hk.Module):
  """Head to predict the distance errors in a prediction.

  From the trunk embeddings and a predicted structure it produces predicted
  distance error (PDE) and predicted aligned error (PAE) over token pairs,
  TM-score adjusted PAE terms, per-atom pLDDT, and per-atom
  experimentally-resolved probabilities.
  """

  class PAEConfig(base_config.BaseConfig):
    """Binning of the predicted aligned error (PAE) head."""

    # Largest bin break; bin centers sit half a step above each break, plus
    # one extra catch-all bin beyond the last break.
    max_error_bin: float = 31.0
    # Total number of PAE bins, including the catch-all bin.
    num_bins: int = 64

  class Config(base_config.BaseConfig):
    """Configuration for ConfidenceHead."""

    pairformer: modules.PairFormerIteration.Config = base_config.autocreate(
        single_attention=base_config.autocreate(),
        single_transition=base_config.autocreate(),
        num_layer=4,
    )
    # Binning of the predicted distance error (PDE) head, same scheme as
    # PAEConfig.
    max_error_bin: float = 31.0
    # Number of pLDDT bins, evenly spanning [0, 1].
    num_plddt_bins: int = 50
    num_bins: int = 64
    # Not read by the code in this module.
    no_embedding_prob: float = 0.2
    pae: 'ConfidenceHead.PAEConfig' = base_config.autocreate()
    dgram_features: template_modules.DistogramFeaturesConfig = (
        base_config.autocreate()
    )

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='confidence_head',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def _embed_features(
      self,
      dense_atom_positions,
      token_atoms_to_pseudo_beta,
      pair_mask,
      pair_act,
      target_feat,
  ):
    """Builds an additive update for the pair activations.

    The update is the sum of two projections of `target_feat` broadcast along
    rows and columns respectively, plus a projection of the distogram of the
    pseudo-beta positions (masked by `pair_mask`).

    Args:
      dense_atom_positions: Atom positions in token-atoms layout.
      token_atoms_to_pseudo_beta: Gather info selecting one position per
        token.
      pair_mask: Pair mask multiplied into the distogram.
      pair_act: Pair activations; only its channel count and dtype are used.
      target_feat: Per-token target features.

    Returns:
      An array with the pair activations' channel count, cast to their dtype
      before summation.
    """
    out = hm.Linear(pair_act.shape[-1], name='left_target_feat_project')(
        target_feat
    ).astype(pair_act.dtype)
    out += hm.Linear(pair_act.shape[-1], name='right_target_feat_project')(
        target_feat
    ).astype(pair_act.dtype)[:, None]
    positions = atom_layout.convert(
        token_atoms_to_pseudo_beta,
        dense_atom_positions,
        layout_axes=(-3, -2),
    )
    dgram = template_modules.dgram_from_positions(
        positions, self.config.dgram_features
    )
    dgram *= pair_mask[..., None]

    out += hm.Linear(pair_act.shape[-1], name='distogram_feat_project')(
        dgram.astype(pair_act.dtype)
    )
    return out

  def __call__(
      self,
      dense_atom_positions: jnp.ndarray,
      embeddings: dict[str, jnp.ndarray],
      seq_mask: jnp.ndarray,
      token_atoms_to_pseudo_beta: atom_layout.GatherInfo,
      asym_id: jnp.ndarray,
  ) -> dict[str, jnp.ndarray]:
    """Builds ConfidenceHead module.

    Args:
      dense_atom_positions: [N_res, N_atom, 3] array of positions.
      embeddings: Dictionary of representations; the 'pair', 'single' and
        'target_feat' entries are read.
      seq_mask: Sequence mask.
      token_atoms_to_pseudo_beta: Pseudo beta info for atom tokens.
      asym_id: Asym ID token features.

    Returns:
      Dictionary with 'predicted_lddt', 'predicted_experimentally_resolved',
      'full_pde', 'average_pde', 'full_pae', 'tmscore_adjusted_pae_global' and
      'tmscore_adjusted_pae_interface'.
    """
    dtype = (
        jnp.bfloat16 if self.global_config.bfloat16 == 'all' else jnp.float32
    )
    # The feature embedding and the pairformer stack run in `dtype`; the
    # output heads below run on float32 activations.
    with utils.bfloat16_context():
      seq_mask_cast = seq_mask.astype(dtype)
      pair_mask = seq_mask_cast[:, None] * seq_mask_cast[None, :]
      pair_mask = pair_mask.astype(dtype)

      pair_act = embeddings['pair'].astype(dtype)
      single_act = embeddings['single'].astype(dtype)
      target_feat = embeddings['target_feat'].astype(dtype)

      num_residues = seq_mask.shape[0]
      num_pair_channels = pair_act.shape[2]

      pair_act += self._embed_features(
          dense_atom_positions,
          token_atoms_to_pseudo_beta,
          pair_mask,
          pair_act,
          target_feat,
      )

      # One pairformer iteration; stacked `num_layer` times below.
      def pairformer_fn(act):
        pair_act, single_act = act
        return modules.PairFormerIteration(
            self.config.pairformer,
            self.global_config,
            with_single=True,
            name='confidence_pairformer',
        )(
            act=pair_act,
            single_act=single_act,
            pair_mask=pair_mask,
            seq_mask=seq_mask,
        )

      pairformer_stack = hk.experimental.layer_stack(
          self.config.pairformer.num_layer
      )(pairformer_fn)

      pair_act, single_act = pairformer_stack((pair_act, single_act))
      pair_act = pair_act.astype(jnp.float32)
      assert pair_act.shape == (num_residues, num_residues, num_pair_channels)

    # Produce logits to predict a distogram of pairwise distance errors
    # between the input prediction and the ground truth.

    # Shape (num_res, num_res, num_bins)
    left_distance_logits = hm.Linear(
        self.config.num_bins,
        initializer=self.global_config.final_init,
        name='left_half_distance_logits',
    )(hm.LayerNorm(name='logits_ln')(pair_act))
    right_distance_logits = left_distance_logits
    distance_logits = left_distance_logits + jnp.swapaxes(  # Symmetrize.
        right_distance_logits, -2, -3
    )
    # Shape (num_bins,)
    distance_breaks = jnp.linspace(
        0.0, self.config.max_error_bin, self.config.num_bins - 1
    )

    step = distance_breaks[1] - distance_breaks[0]

    # Add half-step to get the center
    bin_centers = distance_breaks + step / 2
    # Add a catch-all bin at the end.
    bin_centers = jnp.concatenate(
        [bin_centers, bin_centers[-1:] + step], axis=0
    )

    # Expected distance error per pair under the predicted bin distribution.
    distance_probs = jax.nn.softmax(distance_logits, axis=-1)

    pred_distance_error = (
        jnp.sum(distance_probs * bin_centers, axis=-1) * pair_mask
    )
    # NOTE(review): the denominator is zero if every pair is masked out.
    average_pred_distance_error = jnp.sum(
        pred_distance_error, axis=[-2, -1]
    ) / jnp.sum(pair_mask, axis=[-2, -1])

    # Predicted aligned error
    pae_outputs = {}
    # Shape (num_res, num_res, num_bins)
    pae_logits = hm.Linear(
        self.config.pae.num_bins,
        initializer=self.global_config.final_init,
        name='pae_logits',
    )(hm.LayerNorm(name='pae_logits_ln')(pair_act))
    # Shape (num_bins,)
    pae_breaks = jnp.linspace(
        0.0, self.config.pae.max_error_bin, self.config.pae.num_bins - 1
    )
    step = pae_breaks[1] - pae_breaks[0]
    # Add half-step to get the center
    bin_centers = pae_breaks + step / 2
    # Add a catch-all bin at the end.
    bin_centers = jnp.concatenate(
        [bin_centers, bin_centers[-1:] + step], axis=0
    )
    pae_probs = jax.nn.softmax(pae_logits, axis=-1)

    # Boolean pair mask; the integer token counts in the TM-score helper are
    # derived from it.
    seq_mask_bool = seq_mask.astype(bool)
    pair_mask_bool = seq_mask_bool[:, None] * seq_mask_bool[None, :]

    pae = jnp.sum(pae_probs * bin_centers, axis=-1) * pair_mask_bool
    pae_outputs.update({
        'full_pae': pae,
    })

    # The pTM is computed outside of bfloat16 context.
    tmscore_adjusted_pae_global, tmscore_adjusted_pae_interface = (
        self._get_tmscore_adjusted_pae(
            asym_id=asym_id,
            seq_mask=seq_mask,
            pair_mask=pair_mask_bool,
            bin_centers=bin_centers,
            pae_probs=pae_probs,
        )
    )

    pae_outputs.update({
        'tmscore_adjusted_pae_global': tmscore_adjusted_pae_global,
        'tmscore_adjusted_pae_interface': tmscore_adjusted_pae_interface,
    })
    single_act = single_act.astype('float32')

    # pLDDT
    # Shape (num_res, num_atom, num_bins)
    plddt_logits = hm.Linear(
        (dense_atom_positions.shape[-2], self.config.num_plddt_bins),
        initializer=self.global_config.final_init,
        name='plddt_logits',
    )(hm.LayerNorm(name='plddt_logits_ln')(single_act))

    # Expected value over evenly spaced bin centers in [0, 1], scaled to 0-100.
    bin_width = 1.0 / self.config.num_plddt_bins
    bin_centers = jnp.arange(0.5 * bin_width, 1.0, bin_width)
    predicted_lddt = jnp.sum(
        jax.nn.softmax(plddt_logits, axis=-1) * bin_centers, axis=-1
    )
    predicted_lddt = predicted_lddt * 100.0

    # Experimentally resolved
    # Shape (num_res, num_atom, 2)
    experimentally_resolved_logits = hm.Linear(
        (dense_atom_positions.shape[-2], 2),
        initializer=self.global_config.final_init,
        name='experimentally_resolved_logits',
    )(hm.LayerNorm(name='experimentally_resolved_ln')(single_act))

    # Probability of class 1 of the 2-way softmax.
    predicted_experimentally_resolved = jax.nn.softmax(
        experimentally_resolved_logits, axis=-1
    )[..., 1]

    return {
        'predicted_lddt': predicted_lddt,
        'predicted_experimentally_resolved': predicted_experimentally_resolved,
        'full_pde': pred_distance_error,
        'average_pde': average_pred_distance_error,
        **pae_outputs,
    }

  def _get_tmscore_adjusted_pae(
      self,
      asym_id: jnp.ndarray,
      seq_mask: jnp.ndarray,
      pair_mask: jnp.ndarray,
      bin_centers: jnp.ndarray,
      pae_probs: jnp.ndarray,
  ):
    """Computes per-pair expected TM-score terms from the PAE distribution.

    Args:
      asym_id: Per-token chain identifiers.
      seq_mask: Per-token mask; its sum is the global token count.
      pair_mask: Boolean per-pair mask.
      bin_centers: PAE bin centers, shape (num_bins,).
      pae_probs: PAE bin probabilities, shape
        (num_tokens, num_tokens, num_bins).

    Returns:
      A tuple (global_apae, interface_apae) of per-pair expected TM-score
      terms. For global_apae, d0 is derived from the total token count. For
      interface_apae, it is derived from the token count of the chain(s)
      involved in each pair.
    """

    def get_tmscore_adjusted_pae(num_interface_tokens, bin_centers, pae_probs):
      # Clip to avoid negative/undefined d0.
      clipped_num_res = jnp.maximum(num_interface_tokens, 19)

      # Compute d_0(num_res) as defined by TM-score, eqn. (5) in
      # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
      # Yang & Skolnick "Scoring function for automated
      # assessment of protein structure template quality" 2004.
      d0 = 1.24 * (clipped_num_res - 15) ** (1.0 / 3) - 1.8

      # Make compatible with [num_tokens, num_tokens, num_bins]
      d0 = d0[:, :, None]
      bin_centers = bin_centers[None, None, :]

      # TM-Score term for every bin.
      tm_per_bin = 1.0 / (1 + jnp.square(bin_centers) / jnp.square(d0))
      # E_distances tm(distance).
      predicted_tm_term = jnp.sum(pae_probs * tm_per_bin, axis=-1)
      return predicted_tm_term

    # Interface version
    # x[i, j] is True when tokens i and j belong to the same chain.
    x = asym_id[None, :] == asym_id[:, None]
    # Number of valid tokens in each token's chain.
    num_chain_tokens = jnp.sum(x * pair_mask, axis=-1)
    num_interface_tokens = num_chain_tokens[None, :] + num_chain_tokens[:, None]
    # Don't double-count within a single chain
    num_interface_tokens -= x * (num_interface_tokens // 2)
    num_interface_tokens = num_interface_tokens * pair_mask

    num_global_tokens = jnp.full(
        shape=pair_mask.shape, fill_value=seq_mask.sum()
    )

    assert num_global_tokens.dtype == 'int32'
    assert num_interface_tokens.dtype == 'int32'
    global_apae = get_tmscore_adjusted_pae(
        num_global_tokens, bin_centers, pae_probs
    )
    interface_apae = get_tmscore_adjusted_pae(
        num_interface_tokens, bin_centers, pae_probs
    )
    return global_apae, interface_apae
================================================
FILE: src/alphafold3/model/network/diffusion_head.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Diffusion Head."""
from collections.abc import Callable
from alphafold3.common import base_config
from alphafold3.model import feat_batch
from alphafold3.model import model_config
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils
from alphafold3.model.network import atom_cross_attention
from alphafold3.model.network import diffusion_transformer
from alphafold3.model.network import featurization
from alphafold3.model.network import noise_level_embeddings
import haiku as hk
import jax
import jax.numpy as jnp
# Carefully measured by averaging over the multimer training set.
# Used as the data scale (sigma_data) both in `noise_schedule` and in the
# input / skip / output scalings applied by `DiffusionHead.__call__`.
SIGMA_DATA = 16.0
def random_rotation(key):
  """Samples a random 3x3 rotation matrix.

  Two standard-normal 3-vectors are orthonormalised with Gram-Schmidt. The third
  basis vector is their cross product, so the rows form a right-handed
  orthonormal basis.

  Args:
    key: JAX PRNG key.

  Returns:
    A (3, 3) array whose rows are the three orthonormal basis vectors.
  """
  first_raw, second_raw = jax.random.normal(key, shape=(2, 3))
  row0 = first_raw / jnp.maximum(1e-10, jnp.linalg.norm(first_raw))
  # Remove the component of the second vector that lies along row0.
  overlap = jnp.dot(second_raw, row0, precision=jax.lax.Precision.HIGHEST)
  second_orth = second_raw - row0 * overlap
  row1 = second_orth / jnp.maximum(1e-10, jnp.linalg.norm(second_orth))
  row2 = jnp.cross(row0, row1)
  return jnp.stack([row0, row1, row2])
def random_augmentation(
    rng_key: jnp.ndarray,
    positions: jnp.ndarray,
    mask: jnp.ndarray,
) -> jnp.ndarray:
  """Apply a random rigid-body augmentation (rotation plus translation).

  The steps are:
  1. Move the masked centroid of `positions` to the origin.
  2. Rotate by a matrix from `random_rotation`.
  3. Add a standard-normal 3-vector translation.

  Masked-out atoms are zeroed in the result.

  Args:
    rng_key: random key; split into a rotation key and a translation key.
    positions: atom positions of shape (..., 3).
    mask: per-atom mask of shape (...,), matching the leading axes of
      `positions`. The centroid is averaged over the last two axes of `mask`.

  Returns:
    Transformed positions with the same shape as input positions.
  """
  rot_key, shift_key = jax.random.split(rng_key)
  expanded_mask = mask[..., None]
  centroid = utils.mask_mean(
      expanded_mask, positions, axis=(-2, -3), keepdims=True, eps=1e-6
  )
  rotation = random_rotation(rot_key)
  shift = jax.random.normal(shift_key, shape=(3,))
  centered = positions - centroid
  # Row-vector convention: each position p becomes p @ rotation.
  rotated = jnp.einsum(
      '...i,ij->...j',
      centered,
      rotation,
      precision=jax.lax.Precision.HIGHEST,
  )
  return (rotated + shift) * expanded_mask
def noise_schedule(t, smin=0.0004, smax=160.0, p=7):
  """Maps t in [0, 1] to a noise level (EDM / Karras et al. 2022 schedule).

  The schedule interpolates linearly in t between smax**(1/p) and
  smin**(1/p), raises the result to the power p, and scales by SIGMA_DATA.
  So t=0 gives SIGMA_DATA * smax and t=1 gives SIGMA_DATA * smin.

  Args:
    t: scalar or array of interpolation positions.
    smin: minimum (relative) noise level, reached at t=1.
    smax: maximum (relative) noise level, reached at t=0.
    p: exponent controlling how steps are distributed between the endpoints.

  Returns:
    Noise level(s) with the shape of `t`.
  """
  inv_p = 1 / p
  start = smax**inv_p
  end = smin**inv_p
  return SIGMA_DATA * (start + t * (end - start)) ** p
class ConditioningConfig(base_config.BaseConfig):
  """Conditioning settings for `DiffusionHead._conditioning`."""

  # Output channels of the pair conditioning projection.
  pair_channel: int
  # Output channels of the single conditioning projection.
  seq_channel: int
  # NOTE(review): `prob` is not read anywhere in this file; presumably a
  # conditioning-dropout probability used elsewhere -- confirm.
  prob: float
class SampleConfig(base_config.BaseConfig):
  """Settings for the stochastic diffusion sampler in `sample`."""

  # Number of denoising steps; `steps + 1` noise levels are drawn from
  # `noise_schedule`.
  steps: int
  # Noise "churn": the noise level is raised by a factor (1 + gamma_0) on
  # steps whose target noise level exceeds gamma_min, and not raised otherwise.
  gamma_0: float = 0.8
  gamma_min: float = 1.0
  # Multiplier on the standard deviation of the injected churn noise.
  noise_scale: float = 1.003
  # Multiplier on the Euler update applied at each step.
  step_scale: float = 1.5
  # Number of independent samples generated in parallel.
  num_samples: int = 1
class DiffusionHead(hk.Module):
  """Denoising Diffusion Head.

  Maps noisy atom positions and a noise level to denoised atom positions,
  conditioned on the trunk embeddings. The network has three stages:
  1. an atom cross-attention encoder;
  2. a token-level `diffusion_transformer.Transformer`;
  3. an atom cross-attention decoder.

  The network is wrapped in input / skip / output scalings that depend on the
  noise level and SIGMA_DATA (see `__call__`).
  """

  class Config(
      atom_cross_attention.AtomCrossAttEncoderConfig,
      atom_cross_attention.AtomCrossAttDecoderConfig,
  ):
    """Configuration for DiffusionHead."""

    # NOTE(review): eval_batch_size and eval_batch_dim_shard_size are not read
    # in this file; presumably used by whichever caller batches sampling.
    eval_batch_size: int = 5
    eval_batch_dim_shard_size: int = 5
    # Output widths of the pair / single projections built in `_conditioning`.
    conditioning: ConditioningConfig = base_config.autocreate(
        prob=0.8, pair_channel=128, seq_channel=384
    )
    # Evaluation-time sampler settings; presumably passed to `sample`.
    eval: SampleConfig = base_config.autocreate(
        num_samples=5,
        steps=200,
    )
    # Token-level transformer run between the atom encoder and decoder.
    transformer: diffusion_transformer.Transformer.Config = (
        base_config.autocreate()
    )

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='diffusion_head',
  ):
    self.config = config
    self.global_config = global_config
    super().__init__(name=name)

  @hk.transparent
  def _conditioning(
      self,
      batch: feat_batch.Batch,
      embeddings: dict[str, jnp.ndarray],
      noise_level: jnp.ndarray,
      use_conditioning: bool,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Builds the single and pair conditioning used by the denoiser.

    Args:
      batch: input batch; its token features feed the relative encoding.
      embeddings: trunk outputs; the 'single', 'pair' and 'target_feat' entries
        are read.
      noise_level: current noise level; it is divided by SIGMA_DATA before
        being embedded.
      use_conditioning: the trunk 'single' and 'pair' embeddings are
        multiplied by this flag, so False zeroes them. The relative encoding,
        target_feat and noise embedding still contribute.

    Returns:
      A tuple (single_cond, pair_cond).
    """
    single_embedding = use_conditioning * embeddings['single']
    pair_embedding = use_conditioning * embeddings['pair']

    # Pair conditioning: trunk pair embedding concatenated with a relative
    # position encoding, normalised and projected, then refined by two
    # residual transition blocks.
    rel_features = featurization.create_relative_encoding(
        seq_features=batch.token_features,
        max_relative_idx=32,
        max_relative_chain=2,
    ).astype(pair_embedding.dtype)
    features_2d = jnp.concatenate([pair_embedding, rel_features], axis=-1)
    pair_cond = hm.Linear(
        self.config.conditioning.pair_channel,
        precision='highest',
        name='pair_cond_initial_projection',
    )(
        hm.LayerNorm(
            use_fast_variance=False,
            create_offset=False,
            name='pair_cond_initial_norm',
        )(features_2d)
    )
    for idx in range(2):
      pair_cond += diffusion_transformer.transition_block(
          pair_cond, 2, self.global_config, name=f'pair_transition_{idx}'
      )

    # Single conditioning: trunk single embedding concatenated with
    # target_feat, normalised and projected. A projected embedding of the
    # noise level is added, then two residual transition blocks are applied.
    target_feat = embeddings['target_feat']
    features_1d = jnp.concatenate([single_embedding, target_feat], axis=-1)
    single_cond = hm.LayerNorm(
        use_fast_variance=False,
        create_offset=False,
        name='single_cond_initial_norm',
    )(features_1d)
    single_cond = hm.Linear(
        self.config.conditioning.seq_channel,
        precision='highest',
        name='single_cond_initial_projection',
    )(single_cond)
    noise_embedding = noise_level_embeddings.noise_embeddings(
        sigma_scaled_noise_level=noise_level / SIGMA_DATA
    )
    single_cond += hm.Linear(
        self.config.conditioning.seq_channel,
        precision='highest',
        name='noise_embedding_initial_projection',
    )(
        hm.LayerNorm(
            use_fast_variance=False,
            create_offset=False,
            name='noise_embedding_initial_norm',
        )(noise_embedding)
    )
    for idx in range(2):
      single_cond += diffusion_transformer.transition_block(
          single_cond, 2, self.global_config, name=f'single_transition_{idx}'
      )
    return single_cond, pair_cond

  def __call__(
      self,
      # positions_noisy.shape: (num_token, max_atoms_per_token, 3)
      positions_noisy: jnp.ndarray,
      noise_level: jnp.ndarray,
      batch: feat_batch.Batch,
      embeddings: dict[str, jnp.ndarray],
      use_conditioning: bool,
  ) -> jnp.ndarray:
    """Predicts denoised atom positions from noisy ones.

    The inputs are scaled by 1 / sqrt(noise_level**2 + SIGMA_DATA**2). The
    result is skip_scaling * positions_noisy + out_scaling * network_output,
    which is the EDM preconditioning of Karras et al. (2022).

    Args:
      positions_noisy: noisy atom positions, (num_token, max_atoms_per_token,
        3).
      noise_level: noise level (sigma) of `positions_noisy`.
      batch: input batch; supplies masks and the atom layout.
      embeddings: trunk outputs ('single', 'pair', 'target_feat').
      use_conditioning: whether the trunk embeddings condition the denoiser.

    Returns:
      Denoised positions, same shape as `positions_noisy`, with masked atoms
      zeroed.
    """
    with utils.bfloat16_context():
      # Get conditioning
      trunk_single_cond, trunk_pair_cond = self._conditioning(
          batch=batch,
          embeddings=embeddings,
          noise_level=noise_level,
          use_conditioning=use_conditioning,
      )

      # Extract features
      sequence_mask = batch.token_features.mask
      atom_mask = batch.predicted_structure_info.atom_mask

      # Position features: masked positions with the EDM input scaling.
      act = positions_noisy * atom_mask[..., None]
      act = act / jnp.sqrt(noise_level**2 + SIGMA_DATA**2)

      # NOTE(review): the encoder receives the raw trunk embeddings['single'],
      # not the conditioned `trunk_single_cond`; confirm this is intended.
      enc = atom_cross_attention.atom_cross_att_encoder(
          token_atoms_act=act,
          trunk_single_cond=embeddings['single'],
          trunk_pair_cond=trunk_pair_cond,
          config=self.config,
          global_config=self.global_config,
          batch=batch,
          name='diffusion',
      )
      act = enc.token_act

      # Token-token attention
      act = jnp.asarray(act, dtype=jnp.float32)
      act += hm.Linear(
          act.shape[-1],
          precision='highest',
          initializer=self.global_config.final_init,
          name='single_cond_embedding_projection',
      )(
          hm.LayerNorm(
              use_fast_variance=False,
              create_offset=False,
              name='single_cond_embedding_norm',
          )(trunk_single_cond)
      )

      # The token transformer is run on float32 inputs.
      act = jnp.asarray(act, dtype=jnp.float32)
      trunk_single_cond = jnp.asarray(trunk_single_cond, dtype=jnp.float32)
      trunk_pair_cond = jnp.asarray(trunk_pair_cond, dtype=jnp.float32)
      sequence_mask = jnp.asarray(sequence_mask, dtype=jnp.float32)

      transformer = diffusion_transformer.Transformer(
          self.config.transformer, self.global_config
      )
      act = transformer(
          act=act,
          single_cond=trunk_single_cond,
          mask=sequence_mask,
          pair_cond=trunk_pair_cond,
      )
      act = hm.LayerNorm(
          use_fast_variance=False, create_offset=False, name='output_norm'
      )(act)
      # (n_tokens, per_token_channels)

      # (Possibly) atom-granularity decoder
      assert isinstance(enc, atom_cross_attention.AtomCrossAttEncoderOutput)
      position_update = atom_cross_attention.atom_cross_att_decoder(
          token_act=act,
          enc=enc,
          config=self.config,
          global_config=self.global_config,
          batch=batch,
          name='diffusion',
      )

      # EDM skip / output scalings.
      skip_scaling = SIGMA_DATA**2 / (noise_level**2 + SIGMA_DATA**2)
      out_scaling = (
          noise_level * SIGMA_DATA / jnp.sqrt(noise_level**2 + SIGMA_DATA**2)
      )
    # End `with utils.bfloat16_context()`.

    return (
        skip_scaling * positions_noisy + out_scaling * position_update
    ) * atom_mask[..., None]
def sample(
    denoising_step: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    batch: feat_batch.Batch,
    key: jnp.ndarray,
    config: SampleConfig,
) -> dict[str, jnp.ndarray]:
  """Sample using denoiser on batch.

  Sampling starts from Gaussian noise scaled to the first level of
  `noise_schedule`. It then runs `config.steps` stochastic denoising steps, one
  per consecutive pair of noise levels. All `config.num_samples` samples are
  processed in parallel, each with its own random key.

  Args:
    denoising_step: the denoising function. It is called as
      `denoising_step(positions_noisy, noise_level)` and must return denoised
      positions of the same shape.
    batch: the batch; only `batch.predicted_structure_info.atom_mask` is read.
    key: random key
    config: config for the sampling process (e.g. number of denoising steps,
      etc.)

  Returns:
    a dict
      {
          'atom_positions': jnp.array(...)  # shape (<common_axes>, 3)
          'mask': jnp.array(...)  # shape (<common_axes>,)
      }
    where <common_axes> are
      (num_samples, num_tokens, max_atoms_per_token)
  """

  mask = batch.predicted_structure_info.atom_mask

  def apply_denoising_step(carry, noise_level):
    # One sampler step for a single sample. The carry holds the key, the
    # current positions and the previous noise level; `noise_level` is the
    # target level for this step.
    key, positions, noise_level_prev = carry
    key, key_noise, key_aug = jax.random.split(key, 3)

    positions = random_augmentation(
        rng_key=key_aug, positions=positions, mask=mask
    )

    # Churn: temporarily raise the noise level to t_hat and add matching
    # Gaussian noise. gamma is 0 once the target level is <= gamma_min.
    gamma = config.gamma_0 * (noise_level > config.gamma_min)
    t_hat = noise_level_prev * (1 + gamma)

    noise_scale = config.noise_scale * jnp.sqrt(t_hat**2 - noise_level_prev**2)
    noise = noise_scale * jax.random.normal(key_noise, positions.shape)
    positions_noisy = positions + noise

    # Euler step from t_hat towards `noise_level`, scaled by step_scale.
    positions_denoised = denoising_step(positions_noisy, t_hat)
    grad = (positions_noisy - positions_denoised) / t_hat

    d_t = noise_level - t_hat
    positions_out = positions_noisy + config.step_scale * d_t * grad

    return (key, positions_out, noise_level), positions_out

  num_samples = config.num_samples

  noise_levels = noise_schedule(jnp.linspace(0, 1, config.steps + 1))

  key, noise_key = jax.random.split(key)
  positions = jax.random.normal(noise_key, (num_samples,) + mask.shape + (3,))
  positions *= noise_levels[0]

  # Per-sample carry: (key, positions, previous noise level).
  init = (
      jax.random.split(key, num_samples),
      positions,
      jnp.tile(noise_levels[None, 0], (num_samples,)),
  )

  # Map over the sample axis of the carry; the noise level is shared.
  apply_denoising_step = hk.vmap(
      apply_denoising_step, in_axes=(0, None), split_rng=(not hk.running_init())
  )
  result, _ = hk.scan(apply_denoising_step, init, noise_levels[1:], unroll=4)
  _, positions_out, _ = result

  final_dense_atom_mask = jnp.tile(mask[None], (num_samples, 1, 1))

  return {'atom_positions': positions_out, 'mask': final_dense_atom_mask}
================================================
FILE: src/alphafold3/model/network/diffusion_transformer.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Diffusion transformer model."""
from alphafold3.common import base_config
from alphafold3.model import model_config
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm
import haiku as hk
import jax
from jax import numpy as jnp
import tokamax
def adaptive_layernorm(x, single_cond, name):
  """Adaptive LayerNorm (AdaLN).

  The design follows "Scalable Diffusion Models with Transformers",
  https://arxiv.org/abs/2212.09748.

  Without conditioning this is a plain LayerNorm with learned scale and offset.
  With conditioning, `x` is normalised without any learned affine parameters,
  and the modulation is predicted from the (normalised) conditioning:
  sigmoid(scale(cond)) * x + bias(cond).

  Args:
    x: activations to normalise.
    single_cond: conditioning activations, or None for an unconditioned norm.
    name: prefix for the names of all created modules.

  Returns:
    The normalised (and, if conditioned, modulated) activations.
  """
  if single_cond is None:
    return hm.LayerNorm(name=f'{name}layer_norm', use_fast_variance=False)(x)

  normed_x = hm.LayerNorm(
      name=f'{name}layer_norm',
      use_fast_variance=False,
      create_scale=False,
      create_offset=False,
  )(x)
  normed_cond = hm.LayerNorm(
      name=f'{name}single_cond_layer_norm',
      use_fast_variance=False,
      create_offset=False,
  )(single_cond)
  width = normed_x.shape[-1]
  # Both projections start at zero, so the modulation is initially
  # sigmoid(0) * x + 0 = 0.5 * x.
  gain_logits = hm.Linear(
      width,
      initializer='zeros',
      use_bias=True,
      name=f'{name}single_cond_scale',
  )(normed_cond)
  shift = hm.Linear(
      width, initializer='zeros', name=f'{name}single_cond_bias'
  )(normed_cond)
  return jax.nn.sigmoid(gain_logits) * normed_x + shift
def adaptive_zero_init(
    x, num_channels, single_cond, global_config: model_config.GlobalConfig, name
):
  """Output projection with AdaLN-Zero style gating.

  Without conditioning, this is a linear projection to `num_channels`
  initialised with `global_config.final_init`. With conditioning, the
  projection is multiplied element-wise by sigmoid(gate(single_cond)).

  Args:
    x: activations to project.
    num_channels: output width.
    single_cond: conditioning activations, or None.
    global_config: global model config (supplies `final_init`).
    name: prefix for the names of all created modules.

  Returns:
    Projected (and possibly gated) activations with `num_channels` channels.
  """
  projection_name = f'{name}transition2'
  if single_cond is None:
    return hm.Linear(
        num_channels,
        initializer=global_config.final_init,
        name=projection_name,
    )(x)

  projected = hm.Linear(num_channels, name=projection_name)(x)
  # Zero weights and a bias of -2 make the initial gate sigmoid(-2) ~= 0.12
  # everywhere, i.e. a small gain regardless of the conditioning.
  gate_logits = hm.Linear(
      projected.shape[-1],
      initializer='zeros',
      use_bias=True,
      bias_init=-2.0,
      name=f'{name}adaptive_zero_cond',
  )(single_cond)
  return jax.nn.sigmoid(gate_logits) * projected
def transition_block(
    x: jnp.ndarray,
    num_intermediate_factor: int,
    global_config: model_config.GlobalConfig,
    single_cond: jnp.ndarray | None = None,
    use_glu_kernel: bool = True,
    name: str = '',
) -> jnp.ndarray:
  """Gated (SwiGLU-style) feed-forward transition.

  The block has three stages:
  1. Normalise the input with `adaptive_layernorm`.
  2. Project to 2 * num_intermediate_factor * channels and combine the two
     halves as swish(a) * b.
  3. Project back to the input width with `adaptive_zero_init`.

  Args:
    x: input activations; the last axis holds the channels.
    num_intermediate_factor: hidden width as a multiple of the input width.
    global_config: global model config.
    single_cond: optional conditioning for the adaptive norm and output gate.
    use_glu_kernel: if True, the gated projection is computed by
      `tokamax.gated_linear_unit` using weights fetched with
      `hm.haiku_linear_get_params`. Otherwise a plain `hm.Linear` is applied
      and its output is split and gated explicitly.
    name: prefix for the names of all created modules.

  Returns:
    Activations with the same last-axis width as `x`.
  """
  channels = x.shape[-1]
  hidden = num_intermediate_factor * channels
  normed = adaptive_layernorm(x, single_cond, name=f'{name}ffw_')
  projection_name = f'{name}ffw_transition1'

  if not use_glu_kernel:
    doubled = hm.Linear(hidden * 2, initializer='relu', name=projection_name)(
        normed
    )
    gate_half, value_half = jnp.split(doubled, 2, axis=-1)
    gated = jax.nn.swish(gate_half) * value_half
  else:
    kernel, _ = hm.haiku_linear_get_params(
        normed,
        num_output=hidden * 2,
        initializer='relu',
        name=projection_name,
    )
    # Regroup the fused kernel as (in_channels, 2, hidden) for the GLU kernel.
    kernel = jnp.reshape(kernel, (len(kernel), 2, hidden))
    gated = tokamax.gated_linear_unit(
        x=normed, weights=kernel, activation=jax.nn.swish
    )

  return adaptive_zero_init(
      gated, channels, single_cond, global_config, f'{name}ffw_'
  )
class SelfAttentionConfig(base_config.BaseConfig):
  """Head count and total key/value widths for `self_attention`."""

  num_head: int = 16
  # Total (summed over heads) key / value widths. None means "use the input
  # channel count". Each must be divisible by num_head.
  key_dim: int | None = None
  value_dim: int | None = None
def self_attention(
    x: jnp.ndarray,  # (num_tokens, ch)
    mask: jnp.ndarray,  # (num_tokens,)
    pair_logits: jnp.ndarray | None,  # (num_heads, num_tokens, num_tokens)
    config: SelfAttentionConfig,
    global_config: model_config.GlobalConfig,
    single_cond: jnp.ndarray | None = None,  # (num_tokens, ch)
    name: str = '',
) -> jnp.ndarray:
  """Gated multi-head self-attention with optional pair bias and AdaLN.

  Args:
    x: input activations.
    mask: token mask; padded keys (mask == 0) receive a -1e9 logit bias.
    pair_logits: optional additive per-head attention bias.
    config: head count and key/value widths.
    global_config: global model config.
    single_cond: optional conditioning for the adaptive norm and output gate.
    name: prefix for the names of all created modules.

  Returns:
    Attention output with the same last-axis width as `x`.
  """
  assert len(mask.shape) == len(x.shape) - 1, f'{mask.shape}, {x.shape}'
  # Additive key mask laid out as (..., heads=1, queries=1, keys): 0 for valid
  # keys and -1e9 for padded ones. It is kept in float32 like the logits.
  key_bias = (1e9 * (mask - 1.0))[..., None, None, :].astype(jnp.float32)

  x = adaptive_layernorm(x, single_cond, name=name)
  num_channels = x.shape[-1]

  # Fall back to the input width when the config leaves a dim unset.
  key_dim = num_channels if config.key_dim is None else config.key_dim
  value_dim = num_channels if config.value_dim is None else config.value_dim
  num_head = config.num_head
  assert key_dim % num_head == 0, f'{key_dim=} % {num_head=} != 0'
  assert value_dim % num_head == 0, f'{value_dim=} % {num_head=} != 0'
  # Convert the totals into per-head widths.
  key_dim //= num_head
  value_dim //= num_head

  head_qk_shape = (num_head, key_dim)
  q = hm.Linear(head_qk_shape, use_bias=True, name=f'{name}q_projection')(x)
  k = hm.Linear(head_qk_shape, use_bias=False, name=f'{name}k_projection')(x)

  # The logits einsum runs in float32: gradient norms can blow up otherwise.
  q32 = q.astype(jnp.float32)
  k32 = k.astype(jnp.float32)
  logits = (
      jnp.einsum('...qhc,...khc->...hqk', q32 * key_dim ** (-0.5), k32)
      + key_bias
  )
  if pair_logits is not None:
    logits += pair_logits  # (num_heads, seq_len, seq_len)
  attn = jnp.asarray(jax.nn.softmax(logits, axis=-1), dtype=x.dtype)

  v = hm.Linear(
      (num_head, value_dim), use_bias=False, name=f'{name}v_projection'
  )(x)
  attended = jnp.einsum('...hqk,...khc->...qhc', attn, v)
  # Merge the head axis into the channel axis.
  attended = jnp.reshape(attended, attended.shape[:-2] + (-1,))

  gate_logits = hm.Linear(
      num_head * value_dim,
      bias_init=1.0,
      initializer='zeros',
      name=f'{name}gating_query',
  )(x)
  attended *= jax.nn.sigmoid(gate_logits)

  return adaptive_zero_init(
      attended, num_channels, single_cond, global_config, name
  )
class Transformer(hk.Module):
  """Simple transformer stack.

  Runs `num_blocks` residual blocks; each block is self-attention followed by a
  transition, both conditioned on `single_cond`. Blocks are grouped into
  super-blocks of `super_block_size`. A single linear projection of the
  layer-normed `pair_cond` produces the pair-bias logits for every block of a
  super-block at once.
  """

  class Config(base_config.BaseConfig):
    attention: SelfAttentionConfig = base_config.autocreate()
    num_blocks: int = 24
    # NOTE(review): block_remat is not read in this class.
    block_remat: bool = False
    # Blocks per shared pair-logits projection; must divide num_blocks
    # (asserted in __call__).
    super_block_size: int = 4
    num_intermediate_factor: int = 2

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name: str = 'transformer',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      act: jnp.ndarray,
      mask: jnp.ndarray,
      single_cond: jnp.ndarray,
      pair_cond: jnp.ndarray | None,
  ) -> jnp.ndarray:
    """Applies the transformer stack.

    Args:
      act: token activations; presumably (num_tokens, channels) -- confirm.
      mask: token mask passed to every self-attention layer.
      single_cond: conditioning for the adaptive norms and output gates.
      pair_cond: pair activations used to produce attention biases, or None
        for no pair bias.

    Returns:
      Updated activations with the shape of `act`.
    """

    def block(act, pair_logits):
      # One residual self-attention + transition block. The second output is
      # the (empty) per-layer output required by layer_stack.
      act += self_attention(
          act,
          mask,
          pair_logits,
          self.config.attention,
          self.global_config,
          single_cond,
          name=self.name,
      )
      act += transition_block(
          act,
          self.config.num_intermediate_factor,
          self.global_config,
          single_cond,
          name=self.name,
      )
      return act, None

    # Precompute pair logits for performance
    if pair_cond is None:
      pair_act = None
    else:
      pair_act = hm.LayerNorm(
          name='pair_input_layer_norm',
          use_fast_variance=False,
          create_offset=False,
      )(pair_cond)

    assert self.config.num_blocks % self.config.super_block_size == 0
    num_super_blocks = self.config.num_blocks // self.config.super_block_size

    def super_block(act):
      # Project logits for all `super_block_size` blocks in one go, then move
      # the block axis to the front so layer_stack feeds one slice per block:
      # (Q, K, blocks, heads) -> (blocks, heads, Q, K).
      if pair_act is None:
        pair_logits = None
      else:
        pair_logits = hm.Linear(
            (self.config.super_block_size, self.config.attention.num_head),
            name='pair_logits_projection',
        )(pair_act)
        pair_logits = jnp.transpose(pair_logits, [2, 3, 0, 1])
      return hk.experimental.layer_stack(
          self.config.super_block_size, with_per_layer_inputs=True
      )(block)(act, pair_logits)

    return hk.experimental.layer_stack(
        num_super_blocks, with_per_layer_inputs=True
    )(super_block)(act)[0]
class CrossAttentionConfig(base_config.BaseConfig):
  """Head count and total key/value widths for `cross_attention`."""

  num_head: int = 4
  # Total (summed over heads) key / value widths; each must be divisible by
  # num_head.
  key_dim: int = 128
  value_dim: int = 128
def cross_attention(
    x_q: jnp.ndarray,  # (..., Q, C)
    x_k: jnp.ndarray,  # (..., K, C)
    mask_q: jnp.ndarray,  # (..., Q)
    mask_k: jnp.ndarray,  # (..., K)
    config: CrossAttentionConfig,
    global_config: model_config.GlobalConfig,
    pair_logits: jnp.ndarray | None = None,  # (..., Q, K)
    single_cond_q: jnp.ndarray | None = None,  # (..., Q, C)
    single_cond_k: jnp.ndarray | None = None,  # (..., K, C)
    name: str = '',
) -> jnp.ndarray:
  """Multihead cross-attention from queries `x_q` to keys/values `x_k`.

  Queries and keys are normalised separately with `adaptive_layernorm`
  (conditioned on `single_cond_q` / `single_cond_k` when given). Values are
  projected from the keys. The attention output is gated by a sigmoid of a
  projection of the queries and then passed through `adaptive_zero_init`.

  Args:
    x_q: query activations.
    x_k: key activations.
    mask_q: query mask.
    mask_k: key mask.
    config: head count and total key/value widths.
    global_config: global model config.
    pair_logits: optional additive per-head attention bias.
    single_cond_q: optional query-side conditioning.
    single_cond_k: optional key-side conditioning.
    name: prefix for the names of all created modules.

  Returns:
    Attention output with the same last-axis width as `x_q`.
  """
  assert len(mask_q.shape) == len(x_q.shape) - 1, f'{mask_q.shape}, {x_q.shape}'
  assert len(mask_k.shape) == len(x_k.shape) - 1, f'{mask_k.shape}, {x_k.shape}'
  # bias: ... x heads (1) x query x key
  # NOTE(review): (mask_q - 1) * (mask_k - 1) equals 1 only when BOTH the query
  # and the key are padding. So this adds +1e9 to padded-query/padded-key
  # pairs and 0 everywhere else, which means padded keys are NOT masked out for
  # valid queries. This differs from `self_attention`, which adds -1e9 for
  # every padded key. It is left unchanged because trained parameters were
  # presumably fit with this exact bias; confirm before altering it.
  bias = (
      1e9
      * (mask_q - 1.0)[..., None, :, None]
      * (mask_k - 1.0)[..., None, None, :]
  )

  x_q = adaptive_layernorm(x_q, single_cond_q, name=f'{name}q')
  x_k = adaptive_layernorm(x_k, single_cond_k, name=f'{name}k')

  assert config.key_dim % config.num_head == 0
  assert config.value_dim % config.num_head == 0
  # Per-head widths.
  key_dim = config.key_dim // config.num_head
  value_dim = config.value_dim // config.num_head

  q = hm.Linear(
      (config.num_head, key_dim), use_bias=True, name=f'{name}q_projection'
  )(x_q)
  k = hm.Linear(
      (config.num_head, key_dim), use_bias=False, name=f'{name}k_projection'
  )(x_k)

  # In some situations the gradient norms can blow up without running this
  # einsum in float32.
  q = q.astype(jnp.float32)
  k = k.astype(jnp.float32)
  bias = bias.astype(jnp.float32)
  logits = jnp.einsum('...qhc,...khc->...hqk', q * key_dim ** (-0.5), k) + bias
  if pair_logits is not None:
    logits += pair_logits
  weights = jax.nn.softmax(logits, axis=-1)
  weights = jnp.asarray(weights, dtype=x_q.dtype)

  v = hm.Linear(
      (config.num_head, value_dim), use_bias=False, name=f'{name}v_projection'
  )(x_k)
  weighted_avg = jnp.einsum('...hqk,...khc->...qhc', weights, v)
  # Merge the head axis into the channel axis.
  weighted_avg = jnp.reshape(weighted_avg, weighted_avg.shape[:-2] + (-1,))

  gate_logits = hm.Linear(
      config.num_head * value_dim,
      bias_init=1.0,
      initializer='zeros',
      name=f'{name}gating_query',
  )(x_q)
  weighted_avg *= jax.nn.sigmoid(gate_logits)

  output = adaptive_zero_init(
      weighted_avg, x_q.shape[-1], single_cond_q, global_config, name
  )
  return output
class CrossAttTransformer(hk.Module):
  """Transformer that applies cross attention between two sets of subsets.

  Each of `num_blocks` residual blocks has two steps. First, the current query
  activations are gathered into the keys layout and cross-attended to. Second,
  a transition is applied. The pair-bias logits for all blocks are produced by
  one linear projection of the layer-normed `pair_cond`.
  """

  class Config(base_config.BaseConfig):
    num_intermediate_factor: int
    num_blocks: int
    attention: CrossAttentionConfig = base_config.autocreate()

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name: str = 'transformer',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      queries_act: jnp.ndarray,  # (num_subsets, num_queries, ch)
      queries_mask: jnp.ndarray,  # (num_subsets, num_queries)
      queries_to_keys: atom_layout.GatherInfo,  # (num_subsets, num_keys)
      keys_mask: jnp.ndarray,  # (num_subsets, num_keys)
      queries_single_cond: jnp.ndarray,  # (num_subsets, num_queries, ch)
      keys_single_cond: jnp.ndarray,  # (num_subsets, num_keys, ch)
      pair_cond: jnp.ndarray,  # (num_subsets, num_queries, num_keys, ch)
  ) -> jnp.ndarray:
    """Applies the cross-attention transformer stack.

    Returns:
      Updated query activations with the shape of `queries_act`.
    """

    def block(queries_act, pair_logits):
      # One residual cross-attention + transition block. The second output is
      # the (empty) per-layer output required by layer_stack.
      # copy the queries activations to the keys layout
      keys_act = atom_layout.convert(
          queries_to_keys, queries_act, layout_axes=(-3, -2)
      )
      # cross attention
      queries_act += cross_attention(
          x_q=queries_act,
          x_k=keys_act,
          mask_q=queries_mask,
          mask_k=keys_mask,
          config=self.config.attention,
          global_config=self.global_config,
          pair_logits=pair_logits,
          single_cond_q=queries_single_cond,
          single_cond_k=keys_single_cond,
          name=self.name,
      )
      queries_act += transition_block(
          queries_act,
          self.config.num_intermediate_factor,
          self.global_config,
          queries_single_cond,
          name=self.name,
      )
      return queries_act, None

    # Precompute pair logits for performance
    pair_act = hm.LayerNorm(
        name='pair_input_layer_norm',
        use_fast_variance=False,
        create_offset=False,
    )(pair_cond)

    # (num_subsets, num_queries, num_keys, num_blocks, num_heads)
    pair_logits = hm.Linear(
        (self.config.num_blocks, self.config.attention.num_head),
        name='pair_logits_projection',
    )(pair_act)

    # (num_block, num_subsets, num_heads, num_queries, num_keys)
    # The block axis is moved first so layer_stack feeds one slice per block.
    pair_logits = jnp.transpose(pair_logits, [3, 0, 4, 1, 2])

    return hk.experimental.layer_stack(
        self.config.num_blocks, with_per_layer_inputs=True
    )(block)(queries_act, pair_logits)[0]
================================================
FILE: src/alphafold3/model/network/distogram_head.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Distogram head."""
from typing import Final
from alphafold3.common import base_config
from alphafold3.model import feat_batch
from alphafold3.model import model_config
from alphafold3.model.components import haiku_modules as hm
import haiku as hk
import jax
import jax.numpy as jnp
# A distogram bin counts towards `contact_probs` in DistogramHead when its
# upper edge is <= _CONTACT_THRESHOLD + _CONTACT_EPSILON.
_CONTACT_THRESHOLD: Final[float] = 8.0
# Slack added to the threshold; presumably so that a bin edge lying exactly at
# the threshold is still counted despite floating-point error.
_CONTACT_EPSILON: Final[float] = 1e-3
class DistogramHead(hk.Module):
  """Distogram head: binned token-pair distance logits from pair embeddings."""

  class Config(base_config.BaseConfig):
    # `num_bins - 1` evenly spaced breaks from first_break to last_break
    # (inclusive) separate the `num_bins` bins.
    first_break: float = 2.3125
    last_break: float = 21.6875
    num_bins: int = 64

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='distogram_head',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      batch: feat_batch.Batch,
      embeddings: dict[str, jnp.ndarray],
      return_distogram: bool = False,
  ) -> dict[str, jnp.ndarray]:
    """Computes bin edges and contact probabilities.

    Args:
      batch: input batch; only the token mask is read.
      embeddings: trunk outputs; only 'pair' is read.
      return_distogram: if True, the symmetric logits are also returned under
        'distogram'.

    Returns:
      A dict with:
        'bin_edges': the `num_bins - 1` breaks.
        'contact_probs': probability mass in bins whose upper edge is within
          the contact threshold, zeroed for padded pairs.
        'distogram' (optional): the logits.
    """
    cfg = self.config
    token_mask = batch.token_features.mask.astype(bool)
    pair_mask = token_mask[:, None] * token_mask[None, :]

    half_logits = hm.Linear(
        cfg.num_bins,
        initializer=self.global_config.final_init,
        name='half_logits',
    )(embeddings['pair'])
    # Symmetrise over the two token axes: logits[i, j] = h[i, j] + h[j, i].
    logits = half_logits + jnp.swapaxes(half_logits, -2, -3)
    probs = jax.nn.softmax(logits, axis=-1)

    breaks = jnp.linspace(cfg.first_break, cfg.last_break, cfg.num_bins - 1)
    # Upper edge of every bin. The last bin gets one more break spacing.
    final_top = breaks[-1] + (breaks[-1] - breaks[-2])
    bin_tops = jnp.append(breaks, final_top)
    in_contact = 1.0 * (bin_tops <= _CONTACT_THRESHOLD + _CONTACT_EPSILON)
    contact_probs = pair_mask * jnp.einsum(
        'ijk,k->ij', probs, in_contact, precision=jax.lax.Precision.HIGHEST
    )

    result = {'bin_edges': breaks, 'contact_probs': contact_probs}
    if return_distogram:
      result['distogram'] = logits
    return result
================================================
FILE: src/alphafold3/model/network/evoformer.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Evoformer network."""
import functools
from alphafold3.common import base_config
from alphafold3.model import feat_batch
from alphafold3.model import features
from alphafold3.model import model_config
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils
from alphafold3.model.network import atom_cross_attention
from alphafold3.model.network import featurization
from alphafold3.model.network import modules
from alphafold3.model.network import template_modules
import haiku as hk
import jax
import jax.numpy as jnp
class Evoformer(hk.Module):
  """Creates 'single' and 'pair' embeddings.

  The initial pair embedding is built from, in order:
  1. target_feat;
  2. the previous recycling outputs (`prev`);
  3. relative position encodings;
  4. bond features;
  5. templates;
  6. the MSA stack.

  A stack of PairFormer iterations then jointly refines the pair and single
  embeddings.
  """

  class PairformerConfig(modules.PairFormerIteration.Config):  # pytype: disable=invalid-function-definition
    # NOTE(review): block_remat and remat_block_size are not read in this file;
    # presumably they control rematerialisation elsewhere -- confirm.
    block_remat: bool = False
    remat_block_size: int = 8

  class Config(base_config.BaseConfig):
    """Configuration for Evoformer."""

    # Passed to featurization.create_relative_encoding.
    max_relative_chain: int = 2
    # Width of the MSA activations fed to the MSA stack.
    msa_channel: int = 64
    # Width of the single activations.
    seq_channel: int = 384
    # Passed to featurization.create_relative_encoding.
    max_relative_idx: int = 32
    # MSA rows kept after shuffling (see `_embed_process_msa`).
    num_msa: int = 1024
    # Width of the pair activations.
    pair_channel: int = 128
    pairformer: 'Evoformer.PairformerConfig' = base_config.autocreate(
        single_transition=base_config.autocreate(),
        single_attention=base_config.autocreate(),
        num_layer=48,
    )
    per_atom_conditioning: atom_cross_attention.AtomCrossAttEncoderConfig = (
        base_config.autocreate(
            per_token_channels=384,
            per_atom_channels=128,
            atom_transformer=base_config.autocreate(
                num_intermediate_factor=2,
                num_blocks=3,
            ),
            per_atom_pair_channels=16,
        )
    )
    template: template_modules.TemplateEmbedding.Config = (
        base_config.autocreate()
    )
    msa_stack: modules.EvoformerIteration.Config = base_config.autocreate()

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='evoformer',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def _relative_encoding(
      self, batch: feat_batch.Batch, pair_activations: jnp.ndarray
  ) -> jnp.ndarray:
    """Add relative position encodings.

    The relative encoding from featurization.create_relative_encoding is
    linearly projected to pair_channel and added to `pair_activations`.
    """
    rel_feat = featurization.create_relative_encoding(
        seq_features=batch.token_features,
        max_relative_idx=self.config.max_relative_idx,
        max_relative_chain=self.config.max_relative_chain,
    )
    rel_feat = rel_feat.astype(pair_activations.dtype)

    pair_activations += hm.Linear(
        self.config.pair_channel, name='position_activations'
    )(rel_feat)
    return pair_activations

  @hk.transparent
  def _seq_pair_embedding(
      self,
      token_features: features.TokenFeatures,
      target_feat: jnp.ndarray,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Generates the initial pair embedding and pair mask from the sequence.

    The pair embedding is the outer sum of two linear projections of
    `target_feat`: left[i] + right[j].

    Returns:
      pair_activations of shape (num_residues, num_residues, pair_channel) and
      pair_mask of shape (num_residues, num_residues), in the activation dtype.
    """
    left_single = hm.Linear(self.config.pair_channel, name='left_single')(
        target_feat
    )[:, None]
    right_single = hm.Linear(self.config.pair_channel, name='right_single')(
        target_feat
    )[None]
    dtype = left_single.dtype
    pair_activations = left_single + right_single
    num_residues = pair_activations.shape[0]
    assert pair_activations.shape == (
        num_residues,
        num_residues,
        self.config.pair_channel,
    )
    mask = token_features.mask
    pair_mask = (mask[:, None] * mask[None, :]).astype(dtype)
    assert pair_mask.shape == (num_residues, num_residues)
    return pair_activations, pair_mask  # pytype: disable=bad-return-type  # jax-ndarray

  @hk.transparent
  def _embed_bonds(
      self,
      batch: feat_batch.Batch,
      pair_activations: jnp.ndarray,
  ) -> jnp.ndarray:
    """Embeds bond features and merges into pair activations.

    Polymer-ligand and ligand-ligand bonds are written as 1.0 into a
    (num_tokens, num_tokens) contact matrix. The matrix is linearly projected
    to pair_channel and added to `pair_activations`.
    """
    # Construct contact matrix.
    num_tokens = batch.token_features.token_index.shape[0]
    contact_matrix = jnp.zeros((num_tokens, num_tokens))

    tokens_to_polymer_ligand_bonds = (
        batch.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds
    )
    gather_idxs_polymer_ligand = tokens_to_polymer_ligand_bonds.gather_idxs
    # A bond is valid only if the mask is set for all of its entries
    # (product over axis 1).
    gather_mask_polymer_ligand = (
        tokens_to_polymer_ligand_bonds.gather_mask.prod(axis=1).astype(
            gather_idxs_polymer_ligand.dtype
        )[:, None]
    )
    # If valid mask then it will be all 1's, so idxs should be unchanged.
    # Invalid bonds get indices (0, 0), which are cleared below.
    gather_idxs_polymer_ligand = (
        gather_idxs_polymer_ligand * gather_mask_polymer_ligand
    )

    tokens_to_ligand_ligand_bonds = (
        batch.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds
    )
    gather_idxs_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_idxs
    gather_mask_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_mask.prod(
        axis=1
    ).astype(gather_idxs_ligand_ligand.dtype)[:, None]
    gather_idxs_ligand_ligand = (
        gather_idxs_ligand_ligand * gather_mask_ligand_ligand
    )

    gather_idxs = jnp.concatenate(
        [gather_idxs_polymer_ligand, gather_idxs_ligand_ligand]
    )
    contact_matrix = contact_matrix.at[
        gather_idxs[:, 0], gather_idxs[:, 1]
    ].set(1.0)

    # Because all the padded indices are 0's.
    contact_matrix = contact_matrix.at[0, 0].set(0.0)

    bonds_act = hm.Linear(self.config.pair_channel, name='bond_embedding')(
        contact_matrix[:, :, None].astype(pair_activations.dtype)
    )
    return pair_activations + bonds_act

  @hk.transparent
  def _embed_template_pair(
      self,
      batch: feat_batch.Batch,
      pair_activations: jnp.ndarray,
      pair_mask: jnp.ndarray,
      key: jnp.ndarray,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Embeds Templates and merges into pair activations.

    Returns:
      The updated pair activations and a fresh random key. A split subkey is
      consumed by the template module.
    """
    dtype = pair_activations.dtype
    key, subkey = jax.random.split(key)
    template_module = template_modules.TemplateEmbedding(
        self.config.template, self.global_config
    )
    templates = batch.templates
    asym_id = batch.token_features.asym_id
    # Construct a mask such that only intra-chain template features are
    # computed, since all templates are for each chain individually.
    multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype)

    template_fn = functools.partial(template_module, key=subkey)
    template_act = template_fn(
        query_embedding=pair_activations,
        templates=templates,
        multichain_mask_2d=multichain_mask,
        padding_mask_2d=pair_mask,
    )
    return pair_activations + template_act, key

  @hk.transparent
  def _embed_process_msa(
      self,
      msa_batch: features.MSA,
      pair_activations: jnp.ndarray,
      pair_mask: jnp.ndarray,
      key: jnp.ndarray,
      target_feat: jnp.ndarray,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Processes MSA and returns updated pair activations.

    The processing steps are:
    1. Shuffle the MSA and truncate it to `num_msa` rows.
    2. Embed it, adding a projection of target_feat that is broadcast over all
       MSA rows.
    3. Run `msa_stack.num_layer` EvoformerIterations jointly on the MSA and
       pair activations.

    Only the pair output and the updated key are returned.
    """
    dtype = pair_activations.dtype
    msa_batch, key = featurization.shuffle_msa(key, msa_batch)
    msa_batch = featurization.truncate_msa_batch(msa_batch, self.config.num_msa)
    msa_feat = featurization.create_msa_feat(msa_batch).astype(dtype)

    msa_activations = hm.Linear(
        self.config.msa_channel, name='msa_activations'
    )(msa_feat)

    msa_activations += hm.Linear(
        self.config.msa_channel, name='extra_msa_target_feat'
    )(target_feat)[None]
    msa_mask = msa_batch.mask.astype(dtype)

    # Evoformer MSA stack.
    evoformer_input = {'msa': msa_activations, 'pair': pair_activations}
    masks = {'msa': msa_mask, 'pair': pair_mask}

    def evoformer_fn(x):
      return modules.EvoformerIteration(
          self.config.msa_stack, self.global_config, name='msa_stack'
      )(
          activations=x,
          masks=masks,
      )

    evoformer_stack = hk.experimental.layer_stack(
        self.config.msa_stack.num_layer
    )(evoformer_fn)

    evoformer_output = evoformer_stack(evoformer_input)

    return evoformer_output['pair'], key

  def __call__(
      self,
      batch: feat_batch.Batch,
      prev: dict[str, jnp.ndarray],
      target_feat: jnp.ndarray,
      key: jnp.ndarray,
  ) -> dict[str, jnp.ndarray]:
    """Computes the single and pair embeddings.

    Args:
      batch: input batch.
      prev: previous-iteration outputs; the 'pair' and 'single' entries are
        layer-normed, projected and added to the new embeddings.
      target_feat: per-token input features; its leading axis is num_residues.
      key: random key used for template embedding and MSA shuffling.

    Returns:
      A dict with 'single' (num_residues, seq_channel), 'pair'
      (num_residues, num_residues, pair_channel), and the input 'target_feat'.
    """

    assert self.global_config.bfloat16 in {'all', 'none'}

    num_residues = target_feat.shape[0]
    assert batch.token_features.aatype.shape == (num_residues,)

    dtype = (
        jnp.bfloat16 if self.global_config.bfloat16 == 'all' else jnp.float32
    )

    with utils.bfloat16_context():
      pair_activations, pair_mask = self._seq_pair_embedding(
          batch.token_features, target_feat
      )

      # Add the previous pair embedding (from `prev`).
      pair_activations += hm.Linear(
          pair_activations.shape[-1],
          name='prev_embedding',
          initializer=self.global_config.final_init,
      )(
          hm.LayerNorm(name='prev_embedding_layer_norm')(
              prev['pair'].astype(pair_activations.dtype)
          )
      )

      pair_activations = self._relative_encoding(batch, pair_activations)

      pair_activations = self._embed_bonds(
          batch=batch, pair_activations=pair_activations
      )

      pair_activations, key = self._embed_template_pair(
          batch=batch,
          pair_activations=pair_activations,
          pair_mask=pair_mask,
          key=key,
      )
      pair_activations, key = self._embed_process_msa(
          msa_batch=batch.msa,
          pair_activations=pair_activations,
          pair_mask=pair_mask,
          key=key,
          target_feat=target_feat,
      )
      del key  # Unused after this point.

      single_activations = hm.Linear(
          self.config.seq_channel, name='single_activations'
      )(target_feat)

      # Add the previous single embedding (from `prev`).
      single_activations += hm.Linear(
          single_activations.shape[-1],
          name='prev_single_embedding',
          initializer=self.global_config.final_init,
      )(
          hm.LayerNorm(name='prev_single_embedding_layer_norm')(
              prev['single'].astype(single_activations.dtype)
          )
      )

      def pairformer_fn(x):
        # One PairFormer iteration over a (pair, single) activation tuple.
        pairformer_iteration = modules.PairFormerIteration(
            self.config.pairformer,
            self.global_config,
            with_single=True,
            name='trunk_pairformer',
        )
        pair_act, single_act = x
        return pairformer_iteration(
            act=pair_act,
            single_act=single_act,
            pair_mask=pair_mask,
            seq_mask=batch.token_features.mask.astype(dtype),
        )

      pairformer_stack = hk.experimental.layer_stack(
          self.config.pairformer.num_layer
      )(pairformer_fn)

      pair_activations, single_activations = pairformer_stack(
          (pair_activations, single_activations)
      )

    assert pair_activations.shape == (
        num_residues,
        num_residues,
        self.config.pair_channel,
    )
    assert single_activations.shape == (num_residues, self.config.seq_channel)
    assert len(target_feat.shape) == 2
    assert target_feat.shape[0] == num_residues
    output = {
        'single': single_activations,
        'pair': pair_activations,
        'target_feat': target_feat,
    }

    return output
================================================
FILE: src/alphafold3/model/network/featurization.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Model-side of the input features processing."""
import functools
from alphafold3.constants import residue_names
from alphafold3.model import feat_batch
from alphafold3.model import features
from alphafold3.model.components import utils
import jax
import jax.numpy as jnp
def _grid_keys(key, shape):
"""Generate a grid of rng keys that is consistent with different padding.
Generate random keys such that the keys will be identical, regardless of
how much padding is added to any dimension.
Args:
key: A PRNG key.
shape: The shape of the output array of keys that will be generated.
Returns:
An array of shape `shape` consisting of random keys.
"""
if not shape:
return key
new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
jnp.arange(shape[0])
)
return jax.vmap(functools.partial(_grid_keys, shape=shape[1:]))(new_keys)
def _padding_consistent_rng(f):
  """Modify any element-wise random function to be consistent with padding.

  Normally, if you take a function like jax.random.normal and generate an
  array of, say, size (10, 10), you get a different set of random numbers than
  if you add padding and take the first (10, 10) sub-array.

  This function makes a random function that is consistent regardless of the
  amount of padding added: every element is drawn with its own key derived
  from its grid coordinates (see `_grid_keys`).

  Note: The padding-consistent function is likely to be slower to compile and
  run than the function it is wrapping, but these slowdowns are likely to be
  negligible in a large network.

  Args:
    f: Any element-wise function that takes (PRNG key, shape) as the first 2
      arguments.

  Returns:
    An equivalent function to f, that is now consistent for different amounts
    of padding.
  """

  def inner(key, shape, **kwargs):
    keys = _grid_keys(key, shape)
    if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key):
      # Typed keys are scalar elements, so each grid cell is a 0-d key.
      signature = '()->()'
    else:
      # Raw (uint32) keys keep their key data in a trailing axis whose length
      # depends on the PRNG implementation (2 for threefry, 4 for rbg). A named
      # core dimension accepts any length instead of hard-coding 2.
      signature = '(k)->()'
    return jnp.vectorize(
        functools.partial(f, shape=(), **kwargs), signature=signature
    )(keys)

  return inner
def gumbel_argsort_sample_idx(
    key: jnp.ndarray, logits: jnp.ndarray
) -> jnp.ndarray:
  """Samples without replacement from a distribution given by 'logits'.

  This uses the Gumbel trick to implement the sampling efficiently. Gumbel
  noise is added to `logits`, and the indices are then sorted by the perturbed
  values in descending order along the last axis. For a distribution over k
  items this samples k times without replacement, so it effectively samples a
  random permutation with probabilities over the permutations derived from
  the logprobs.

  Args:
    key: prng key
    logits: logarithm of probabilities to sample from, probabilities can be
      unnormalized. Sampling is independent along the last axis.

  Returns:
    Integer indices with the same shape as `logits`. Along the last axis they
    give the sampled order of item indices, with the first drawn item first.
  """
  gumbel = _padding_consistent_rng(jax.random.gumbel)
  z = gumbel(key, logits.shape)
  # This construction is equivalent to jnp.argsort, but uses a non-stable sort,
  # since stable sorts aren't supported by jax2tf.
  axis = len(logits.shape) - 1
  iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
  _, perm = jax.lax.sort_key_val(
      logits + z, iota, dimension=-1, is_stable=False
  )
  # sort_key_val sorts ascending along the last axis, so reverse that same axis
  # to get descending order. `perm[::-1]` would reverse the leading axis
  # instead, which only coincides for 1-D logits.
  return perm[..., ::-1]
def create_msa_feat(msa: features.MSA) -> jax.Array:
  """Builds the MSA feature tensor by concatenating on the last axis.

  Features, in order:
    * one-hot of `msa.rows` over
      `residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1` classes;
    * `has_deletion`: `msa.deletion_matrix` clipped to [0, 1];
    * `deletion_value`: `arctan(deletion / 3) * 2 / pi`, which maps
      non-negative deletion values into [0, 1).

  Args:
    msa: MSA whose `rows` and `deletion_matrix` are read.

  Returns:
    The concatenated per-entry features.
  """
  num_row_classes = residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1
  deletions = msa.deletion_matrix
  return jnp.concatenate(
      [
          jax.nn.one_hot(msa.rows, num_row_classes),
          jnp.clip(deletions, 0.0, 1.0)[..., None],
          (jnp.arctan(deletions / 3.0) * (2.0 / jnp.pi))[..., None],
      ],
      axis=-1,
  )
def truncate_msa_batch(msa: features.MSA, num_msa: int) -> features.MSA:
  """Selects MSA rows `0 .. num_msa - 1`.

  Passes `jnp.arange(num_msa)` to `msa.index_msa_rows`, which presumably keeps
  the first `num_msa` rows.

  Args:
    msa: The MSA to truncate.
    num_msa: Number of leading rows to select.

  Returns:
    The MSA returned by `msa.index_msa_rows`.
  """
  return msa.index_msa_rows(jnp.arange(num_msa))
def create_target_feat(
    batch: feat_batch.Batch,
    append_per_atom_features: bool,
) -> jax.Array:
  """Builds the per-token target feature matrix.

  Pieces concatenated on the last axis, in order:
    * one-hot of `batch.token_features.aatype` over
      `residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP` classes;
    * `batch.msa.profile`;
    * `batch.msa.deletion_mean` with a trailing singleton axis;
    * only if `append_per_atom_features`: the masked mean over axis -2
      (presumably the per-token atom axis) of a 128-class one-hot of
      `batch.ref_structure.element`, the reference positions flattened to
      `[num_tokens, -1]`, and the reference mask.

  Args:
    batch: Featurised batch to read from.
    append_per_atom_features: Whether to append the reference-structure pieces.

  Returns:
    The concatenated target features.
  """
  tokens = batch.token_features
  msa = batch.msa
  pieces = [
      jax.nn.one_hot(
          tokens.aatype,
          residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP,
      ),
      msa.profile,
      msa.deletion_mean[..., None],
  ]

  # Reference structure features
  if append_per_atom_features:
    ref = batch.ref_structure
    atom_mask = ref.mask
    mean_element = utils.mask_mean(
        mask=atom_mask[..., None],
        value=jax.nn.one_hot(ref.element, 128),
        axis=-2,
        eps=1e-6,
    )
    positions = ref.positions
    flat_positions = positions.reshape([positions.shape[0], -1])
    pieces.extend([mean_element, flat_positions, atom_mask])

  return jnp.concatenate(pieces, axis=-1)
def create_relative_encoding(
    seq_features: features.TokenFeatures,
    max_relative_idx: int,
    max_relative_chain: int,
) -> jax.Array:
  """Builds pairwise relative-position features.

  Concatenated on the last axis, in order:
    * residue-index offset i-j, clipped to [-max_relative_idx,
      max_relative_idx] and one-hot encoded. Pairs in different chains
      (asym_id) get a dedicated extra class.
    * token-index offset, encoded the same way. Pairs that are not in the same
      chain and residue get the extra class.
    * a 0/1 flag for "same entity_id".
    * sym_id offset clipped to [-max_relative_chain, max_relative_chain] and
      one-hot encoded. Pairs with different entity_id get the extra class.

  Args:
    seq_features: Token features providing token_index, residue_index,
      asym_id, entity_id and sym_id.
    max_relative_idx: Clipping range for residue/token offsets.
    max_relative_chain: Clipping range for sym_id offsets.

  Returns:
    [num_tokens, num_tokens, 4 * max_relative_idx + 2 * max_relative_chain + 7]
    features.
  """

  def as_pair(values):
    # Row-broadcast and column-broadcast views of a per-token array.
    return values[:, None], values[None, :]

  def offset_one_hot(delta, valid, max_value):
    # Shift into [0, 2 * max_value] by clipping; invalid pairs use the extra
    # class 2 * max_value + 1.
    clipped = jnp.clip(delta + max_value, min=0, max=2 * max_value)
    classes = jnp.where(
        valid,
        clipped,
        (2 * max_value + 1) * jnp.ones_like(clipped),
    )
    return jax.nn.one_hot(classes, 2 * max_value + 2)

  asym_i, asym_j = as_pair(seq_features.asym_id)
  residue_i, residue_j = as_pair(seq_features.residue_index)
  token_i, token_j = as_pair(seq_features.token_index)
  entity_i, entity_j = as_pair(seq_features.entity_id)
  sym_i, sym_j = as_pair(seq_features.sym_id)

  same_chain = asym_i == asym_j
  same_residue = same_chain & (residue_i == residue_j)
  same_entity = entity_i == entity_j

  # Embed relative positions using a one-hot embedding of distance along chain.
  rel_pos = offset_one_hot(residue_i - residue_j, same_chain, max_relative_idx)
  # Embed relative token index as a one-hot embedding of distance along
  # residue.
  rel_token = offset_one_hot(token_i - token_j, same_residue, max_relative_idx)
  # Embed relative chain ID inside each symmetry class.
  rel_chain = offset_one_hot(sym_i - sym_j, same_entity, max_relative_chain)

  return jnp.concatenate(
      [
          rel_pos,
          rel_token,
          same_entity.astype(rel_pos.dtype)[..., None],
          rel_chain,
      ],
      axis=-1,
  )
def shuffle_msa(
    key: jax.Array, msa: features.MSA
) -> tuple[features.MSA, jax.Array]:
  """Randomly permutes the rows of an MSA.

  Rows with at least one non-masked position get logit 0. Fully masked rows
  get a -1e6 penalty, so after the Gumbel sort they end up after all other
  rows. Among the unmasked rows the order is uniformly random.

  Args:
    key: rng key for random number generation.
    msa: MSA object whose rows are shuffled.

  Returns:
    A tuple `(shuffled_msa, key)`, where `key` is a new key split from the
    input for the caller's subsequent random operations.
  """
  next_key, permutation_key = jax.random.split(key)
  row_has_data = jnp.clip(jnp.sum(msa.mask, axis=-1), 0.0, 1.0)
  logits = (row_has_data - 1.0) * 1e6
  row_order = gumbel_argsort_sample_idx(permutation_key, logits)
  return msa.index_msa_rows(row_order), next_key
================================================
FILE: src/alphafold3/model/network/modules.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Haiku modules for the Diffuser model."""
from collections.abc import Sequence
from typing import Literal
from alphafold3.common import base_config
from alphafold3.model import model_config
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import mapping
from alphafold3.model.network import diffusion_transformer
import haiku as hk
import jax
import jax.numpy as jnp
import tokamax
def get_shard_size(
num_residues: int, shard_spec: Sequence[tuple[int | None, int | None]]
) -> int | None:
shard_size = shard_spec[0][-1]
for num_residues_upper_bound, num_residues_shard_size in shard_spec:
shard_size = num_residues_shard_size
if (
num_residues_upper_bound is None
or num_residues <= num_residues_upper_bound
):
break
return shard_size
class TransitionBlock(hk.Module):
  """Pre-LayerNorm feed-forward block with a swish-gated linear unit.

  The normalised input is projected to `2 * hidden` channels, split into a
  gate half and a linear half, combined as `swish(gate) * linear`, and
  projected back to the input channel count.
  """

  class Config(base_config.BaseConfig):
    num_intermediate_factor: int = 4
    use_glu_kernel: bool = True

  def __init__(
      self, config: Config, global_config: model_config.GlobalConfig, *, name
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, broadcast_dim=0):
    """Applies the transition.

    Args:
      act: Activations whose last axis is the channel axis.
      broadcast_dim: Accepted for call-site compatibility; not read here.

    Returns:
      Array with the same leading shape as `act` and `act.shape[-1]` channels.
    """
    in_channels = act.shape[-1]
    hidden = int(in_channels * self.config.num_intermediate_factor)
    normed = hm.LayerNorm(name='input_layer_norm')(act)

    if not self.config.use_glu_kernel:
      projected = hm.Linear(
          hidden * 2, initializer='relu', name='transition1'
      )(normed)
      gate_half, linear_half = jnp.split(projected, 2, axis=-1)
      gated = jax.nn.swish(gate_half) * linear_half
    else:
      # Fused kernel path: split the weight columns into [in, 2, hidden].
      w, _ = hm.haiku_linear_get_params(
          normed,
          num_output=hidden * 2,
          initializer='relu',
          name='transition1',
      )
      w = jnp.reshape(w, (len(w), 2, hidden))
      gated = tokamax.gated_linear_unit(
          x=normed, weights=w, activation=jax.nn.swish
      )

    output_layer = hm.Linear(
        in_channels,
        initializer=self.global_config.final_init,
        name='transition2',
    )
    return output_layer(gated)
class MSAAttention(hk.Module):
  """Gated MSA attention whose weights come from the pair representation.

  The attention logits are a projection of the pair activations only, so every
  MSA row shares the same per-head attention pattern. The MSA rows supply the
  values and the output gates.
  """

  class Config(base_config.BaseConfig):
    num_head: int = 8

  def __init__(
      self, config: Config, global_config: model_config.GlobalConfig, *, name
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, pair_act):
    """Applies pair-weighted averaging across residues of each MSA row.

    Args:
      act: MSA activations, [num_seq, num_res, channels] per the 'bkhc'
        einsum below.
      mask: MSA mask, [num_seq, num_res]. A key column is penalised when its
        maximum over sequences is below 1 (assuming a 0/1 mask).
      pair_act: Pair activations, [num_res, num_res, pair_channels].

    Returns:
      Update with the same shape as `act`.
    """
    num_head = self.config.num_head
    normed_msa = hm.LayerNorm(name='act_norm')(act)
    normed_pair = hm.LayerNorm(name='pair_norm')(pair_act)

    # [q, k, h] -> [h, q, k] logits, with a -1e9 penalty on padded keys.
    head_logits = hm.Linear(num_head, use_bias=False, name='pair_logits')(
        normed_pair
    )
    head_logits = jnp.transpose(head_logits, [2, 0, 1])
    key_penalty = 1e9 * (jnp.max(mask, axis=0) - 1.0)
    attn_weights = jax.nn.softmax(head_logits + key_penalty, axis=-1)

    channels = normed_msa.shape[-1]
    head_dim = channels // num_head
    values = hm.Linear(
        [num_head, head_dim], use_bias=False, name='v_projection'
    )(normed_msa)
    averaged = jnp.einsum('hqk, bkhc -> bqhc', attn_weights, values)
    averaged = jnp.reshape(averaged, averaged.shape[:-2] + (-1,))

    gates = hm.Linear(
        num_head * head_dim,
        bias_init=1.0,
        initializer='zeros',
        name='gating_query',
    )(normed_msa)
    gated = averaged * jax.nn.sigmoid(gates)

    output_layer = hm.Linear(
        channels,
        initializer=self.global_config.final_init,
        name='output_projection',
    )
    return output_layer(gated)
class GridSelfAttention(hk.Module):
  """Gated self-attention along one axis of a rank-3 activation grid.

  With `transpose=False`, each row `act[i]` attends over its entries along
  axis 1. With `transpose=True`, the first two axes are swapped before and
  after the attention, so attention runs along the other axis. A per-head bias
  projected from the normalised, untransposed input is shared by all rows.
  """

  class Config(base_config.BaseConfig):
    num_head: int = 4

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      transpose: bool,
      *,
      name: str,
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    # Whether to attend along axis 0 (True) instead of axis 1 (False).
    self.transpose = transpose

  @hk.transparent
  def _attention(
      self,
      act,
      mask,
      bias,
  ):
    """Gated multi-head attention for one (sub)batch of rows.

    Args:
      act: [batch, length, channels] activations; channels must be divisible
        by `num_head` (asserted).
      mask: Boolean key mask broadcastable to the attention logits.
      bias: [num_head, length, length] attention bias without a batch axis.

    Returns:
      [batch, length, channels] gated attention output.
    """
    num_channels = act.shape[-1]
    assert num_channels % self.config.num_head == 0
    # Triton requires a minimum dimension of 16 for doing matmul.
    qkv_dim = max(num_channels // self.config.num_head, 16)
    qkv_shape = (self.config.num_head, qkv_dim)
    q = hm.Linear(
        qkv_shape, use_bias=False, name='q_projection', transpose_weights=True
    )(act)
    k = hm.Linear(
        qkv_shape, use_bias=False, name='k_projection', transpose_weights=True
    )(act)
    v = hm.Linear(qkv_shape, use_bias=False, name='v_projection')(act)
    # Dot product attention requires the bias term to have a batch dimension.
    bias = jnp.expand_dims(bias, 0)
    weighted_avg = tokamax.dot_product_attention(
        q,
        k,
        v,
        mask=mask,
        bias=bias,
        implementation=self.global_config.flash_attention_implementation,
    )
    # Merge heads: [..., num_head, qkv_dim] -> [..., num_head * qkv_dim].
    weighted_avg = jnp.reshape(weighted_avg, weighted_avg.shape[:-2] + (-1,))
    gate_values = hm.Linear(
        self.config.num_head * qkv_dim,
        bias_init=1.0,
        initializer='zeros',
        transpose_weights=True,
        name='gating_query',
    )(act)
    weighted_avg *= jax.nn.sigmoid(gate_values)
    return hm.Linear(
        num_channels,
        initializer=self.global_config.final_init,
        name='output_projection',
    )(weighted_avg)

  def __call__(self, act, pair_mask):
    """Applies gated grid self-attention.

    Arguments:
      act: [num_rows, num_cols, channels] activations tensor. In this file it
        is always called with pair activations.
      pair_mask: [num_rows, num_cols] mask of non-padded regions in the
        tensor. Its last two axes are swapped and it is then used as the
        per-row key mask of the attention.

    Returns:
      Result of the self-attention operation, same shape as `act`.
    """
    assert len(act.shape) == 3
    assert len(pair_mask.shape) == 2
    # NOTE(review): the mask is swapped regardless of `self.transpose`. For
    # transpose=False, row b then masks key k with pair_mask[k, b], which is
    # only equivalent to pair_mask[b, k] for symmetric masks — confirm callers
    # always pass symmetric pair masks.
    pair_mask = jnp.swapaxes(pair_mask, -1, -2)
    act = hm.LayerNorm(name='act_norm')(act)
    # Per-head bias [num_head, rows, cols], projected from the normalised,
    # untransposed input and shared by every attended row.
    nonbatched_bias = hm.Linear(
        self.config.num_head, use_bias=False, name='pair_bias_projection'
    )(act)
    nonbatched_bias = jnp.transpose(nonbatched_bias, [2, 0, 1])
    num_residues = act.shape[0]
    chunk_size = get_shard_size(
        num_residues, self.global_config.pair_attention_chunk_size
    )
    if self.transpose:
      act = jnp.swapaxes(act, -2, -3)
    # [rows, cols] -> [rows, 1, 1, cols] boolean key mask for each row.
    pair_mask = pair_mask[:, None, None, :].astype(jnp.bool_)
    # Presumably evaluates `_attention` on chunks of `chunk_size` rows of the
    # batched args to bound memory — see mapping.inference_subbatch.
    act = mapping.inference_subbatch(
        self._attention,
        chunk_size,
        batched_args=[act, pair_mask],
        nonbatched_args=[nonbatched_bias],
    )
    if self.transpose:
      act = jnp.swapaxes(act, -2, -3)
    return act
class TriangleMultiplication(hk.Module):
  """Gated triangle multiplicative update of pair activations.

  `config.equation` selects the contraction. 'ikc,jkc->ijc' is the
  `triangle_multiplication_outgoing` default in this file's configs, and
  'kjc,kic->ijc' is the `triangle_multiplication_incoming` default.
  """

  class Config(base_config.BaseConfig):
    equation: Literal['ikc,jkc->ijc', 'kjc,kic->ijc']
    # If True, compute the sigmoid-gated input projection with the fused
    # `tokamax.gated_linear_unit` kernel instead of two `hm.Linear` layers.
    use_glu_kernel: bool = True

  def __init__(
      self, config: Config, global_config: model_config.GlobalConfig, *, name
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask):
    """Computes the triangle multiplicative update.

    Args:
      act: [N, N, C] pair activations (channel-last).
      mask: [N, N] pair mask, multiplied into the projected edges.

    Returns:
      [N, N, C] update with the same shape as `act`.
    """
    # Leading axis so the mask broadcasts over the channel-first projections.
    mask = mask[None, ...]
    num_channels = act.shape[-1]
    # Channel-first equivalent of the configured channel-last equation; the
    # projections are transposed to [channels, i, j] before contracting.
    equation = {
        'ikc,jkc->ijc': 'cik,cjk->cij',
        'kjc,kic->ijc': 'ckj,cki->cij',
    }[self.config.equation]

    act = hm.LayerNorm(name='left_norm_input')(act)
    # The normalised input also feeds the output gate at the end.
    input_act = act

    if self.config.use_glu_kernel:
      # Fused path: sigmoid(gate) * projection in a single kernel call.
      # NOTE(review): the second value returned by haiku_linear_get_params
      # (presumably the bias) is discarded, whereas the non-fused 'gate'
      # Linear below uses bias_init=1.0 — confirm the paths are meant to
      # differ.
      weights_projection, _ = hm.haiku_linear_get_params(
          act, num_output=num_channels * 2, name='projection'
      )
      weights_gate, _ = hm.haiku_linear_get_params(
          act,
          num_output=num_channels * 2,
          initializer=self.global_config.final_init,
          name='gate',
      )
      weights_glu = jnp.stack([weights_gate, weights_projection], axis=1)
      projection = tokamax.gated_linear_unit(
          act, weights_glu, activation=jax.nn.sigmoid
      )
      projection = jnp.transpose(projection, (2, 0, 1))
      projection *= mask
    else:
      projection = hm.Linear(num_channels * 2, name='projection')(act)
      projection = jnp.transpose(projection, (2, 0, 1))
      projection *= mask

      gate = hm.Linear(
          num_channels * 2,
          name='gate',
          bias_init=1.0,
          initializer=self.global_config.final_init,
      )(act)
      gate = jnp.transpose(gate, (2, 0, 1))
      projection *= jax.nn.sigmoid(gate)

    # Split the 2*C projected channels into two [C, N, N] edge tensors. The
    # reshape pairs channel 2c with 2c+1, so `a` gets the even channels and
    # `b` the odd ones.
    projection = projection.reshape(num_channels, 2, *projection.shape[1:])
    a, b = jnp.split(projection, 2, axis=1)
    a, b = jnp.squeeze(a, axis=1), jnp.squeeze(b, axis=1)
    act = jnp.einsum(equation, a, b)
    # Normalise over the channel axis (axis 0 in this channel-first layout).
    act = hm.LayerNorm(name='center_norm', axis=0, param_axis=0)(act)

    # Back to channel-last [N, N, C].
    act = jnp.transpose(act, (1, 2, 0))
    act = hm.Linear(
        num_channels,
        initializer=self.global_config.final_init,
        name='output_projection',
    )(act)

    gate_out = hm.Linear(
        num_channels,
        name='gating_linear',
        bias_init=1.0,
        initializer=self.global_config.final_init,
    )(input_act)
    act *= jax.nn.sigmoid(gate_out)

    return act
class OuterProductMean(hk.Module):
  """Computes a pair update from the masked mean of MSA outer products.

  Left and right projections of every sequence are combined by an outer
  product over residue pairs and summed over sequences. The sum is contracted
  with a learned weight (plus bias) to `num_output_channel` channels, then
  divided by the number of sequences in which both residues are unmasked
  (plus a small epsilon).
  """

  class Config(base_config.BaseConfig):
    # Chunk size for `mapping.inference_subbatch` over the left projection's
    # residue axis.
    chunk_size: int = 128
    # Channels of each of the left/right projections.
    num_outer_channel: int = 32

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      num_output_channel,
      *,
      name,
  ):
    super().__init__(name=name)
    self.global_config = global_config
    self.config = config
    self.num_output_channel = num_output_channel

  def __call__(self, act, mask):
    """Computes the outer product mean.

    Args:
      act: [num_seq, num_res, channels] MSA activations (layout implied by
        the einsums in `compute_chunk`).
      mask: [num_seq, num_res] MSA mask.

    Returns:
      [num_res, num_res, num_output_channel] pair update.
    """
    # Trailing singleton axis so the mask broadcasts over projection channels.
    mask = mask[..., None]
    act = hm.LayerNorm(name='layer_norm_input')(act)
    left_act = mask * hm.Linear(
        self.config.num_outer_channel,
        initializer='linear',
        name='left_projection',
    )(act)
    right_act = mask * hm.Linear(
        self.config.num_outer_channel,
        initializer='linear',
        name='right_projection',
    )(act)
    # 'zeros' makes the output weight start at zero; otherwise use fan-in
    # variance scaling with scale 2.0.
    if self.global_config.final_init == 'zeros':
      w_init = hk.initializers.Constant(0.0)
    else:
      w_init = hk.initializers.VarianceScaling(scale=2.0, mode='fan_in')
    output_w = hk.get_parameter(
        'output_w',
        shape=(
            self.config.num_outer_channel,
            self.config.num_outer_channel,
            self.num_output_channel,
        ),
        dtype=act.dtype,
        init=w_init,
    )
    output_b = hk.get_parameter(
        'output_b',
        shape=(self.num_output_channel,),
        dtype=act.dtype,
        init=hk.initializers.Constant(0.0),
    )

    def compute_chunk(left_act):
      # Make sure that the 'b' dimension is the most minor batch like dimension
      # so it will be treated as the real batch by XLA (both during the forward
      # and the backward pass)
      left_act = jnp.transpose(left_act, [0, 2, 1])
      # Sum over sequences 'a' of left/right outer products -> [d, c, e, b].
      act = jnp.einsum('acb,ade->dceb', left_act, right_act)
      # Contract both projection channels with output_w -> [d, b, f].
      act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b
      # -> [b (left residue chunk), d (right residue), f].
      return jnp.transpose(act, [1, 0, 2])

    act = mapping.inference_subbatch(
        compute_chunk,
        self.config.chunk_size,
        batched_args=[left_act],
        nonbatched_args=[],
        input_subbatch_dim=1,
        output_subbatch_dim=0,
    )
    epsilon = 1e-3
    # norm[b, d] = sum over sequences of mask[:, b] * mask[:, d], i.e. the
    # count of sequences where both residues are unmasked (for a 0/1 mask).
    norm = jnp.einsum('abc,adc->bdc', mask, mask)
    return act / (epsilon + norm)
class PairFormerIteration(hk.Module):
  """Single Iteration of Pair Former.

  Applies residual triangle multiplications, row/column pair attention and a
  pair transition to the pair activations. With `with_single=True`, it then
  also updates single activations with pair-biased self-attention and a
  transition.
  """

  class Config(base_config.BaseConfig):
    """Config for PairFormerIteration."""

    num_layer: int
    pair_attention: GridSelfAttention.Config = base_config.autocreate()
    pair_transition: TransitionBlock.Config = base_config.autocreate()
    # Required (non-None) when the module is built with with_single=True.
    single_attention: diffusion_transformer.SelfAttentionConfig | None = None
    single_transition: TransitionBlock.Config | None = None
    triangle_multiplication_incoming: TriangleMultiplication.Config = (
        base_config.autocreate(equation='kjc,kic->ijc')
    )
    triangle_multiplication_outgoing: TriangleMultiplication.Config = (
        base_config.autocreate(equation='ikc,jkc->ijc')
    )
    # If True, the pair transition is applied through mapping.sharded_apply
    # with a size taken from global_config.pair_transition_shard_spec.
    shard_transition_blocks: bool = True

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      with_single=False,
      *,
      name,
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.with_single = with_single

  def __call__(
      self,
      act,
      pair_mask,
      single_act=None,
      seq_mask=None,
  ):
    """Build a single iteration of the pair former.

    Args:
      act: [num_res, num_res, num_channel] Input pairwise activations.
      pair_mask: [num_res, num_res] padding mask.
      single_act: [num_res, single_channel] Single Input activations, optional
        (required when `with_single`).
      seq_mask: [num_res] Sequence Mask, optional (required when
        `with_single`).

    Returns:
      [num_res, num_res, num_channel] tensor of activations, or, when the
      module was built with `with_single=True`, the tuple
      `(pair_activations, single_activations)`.
    """
    num_residues = act.shape[0]
    act += TriangleMultiplication(
        self.config.triangle_multiplication_outgoing,
        self.global_config,
        name='triangle_multiplication_outgoing',
    )(act, pair_mask)
    act += TriangleMultiplication(
        self.config.triangle_multiplication_incoming,
        self.global_config,
        name='triangle_multiplication_incoming',
    )(act, pair_mask)
    # Attention along one pair axis, then along the other.
    act += GridSelfAttention(
        self.config.pair_attention,
        self.global_config,
        name='pair_attention1',
        transpose=False,
    )(act, pair_mask)
    act += GridSelfAttention(
        self.config.pair_attention,
        self.global_config,
        name='pair_attention2',
        transpose=True,
    )(act, pair_mask)
    transition_block = TransitionBlock(
        self.config.pair_transition, self.global_config, name='pair_transition'
    )
    if self.config.shard_transition_blocks:
      transition_block = mapping.sharded_apply(
          transition_block,
          get_shard_size(
              num_residues, self.global_config.pair_transition_shard_spec
          ),
      )
    act += transition_block(act)
    if self.with_single:
      assert self.config.single_attention is not None
      # Per-head attention bias for the single stream, projected from the
      # updated pair activations: [num_head, num_res, num_res].
      pair_logits = hm.Linear(
          self.config.single_attention.num_head,
          name='single_pair_logits_projection',
      )(hm.LayerNorm(name='single_pair_logits_norm')(act))
      pair_logits = jnp.transpose(pair_logits, [2, 0, 1])
      single_act += diffusion_transformer.self_attention(
          single_act,
          seq_mask,
          pair_logits=pair_logits,
          config=self.config.single_attention,
          global_config=self.global_config,
          name='single_attention_',
      )
      single_act += TransitionBlock(
          self.config.single_transition,
          self.global_config,
          name='single_transition',
      )(single_act, broadcast_dim=None)
      return act, single_act
    else:
      return act
class EvoformerIteration(hk.Module):
  """Single Iteration of Evoformer Main Stack.

  Updates pair activations with the MSA outer product mean, then updates the
  MSA with pair-weighted MSA attention and a transition. Finally it applies
  triangle multiplications, row/column pair attention and a pair transition
  to the pair activations. All updates are residual.
  """

  class Config(base_config.BaseConfig):
    """Configuration for EvoformerIteration."""

    # Not read inside this class.
    num_layer: int = 4
    msa_attention: MSAAttention.Config = base_config.autocreate()
    outer_product_mean: OuterProductMean.Config = base_config.autocreate()
    msa_transition: TransitionBlock.Config = base_config.autocreate()
    pair_attention: GridSelfAttention.Config = base_config.autocreate()
    pair_transition: TransitionBlock.Config = base_config.autocreate()
    triangle_multiplication_incoming: TriangleMultiplication.Config = (
        base_config.autocreate(equation='kjc,kic->ijc')
    )
    triangle_multiplication_outgoing: TriangleMultiplication.Config = (
        base_config.autocreate(equation='ikc,jkc->ijc')
    )
    # If True, the pair transition is applied through mapping.sharded_apply.
    shard_transition_blocks: bool = True

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='evoformer_iteration',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, activations, masks):
    """Runs one Evoformer iteration.

    Args:
      activations: Dict with 'msa' ([num_seq, num_res, c_m]) and 'pair'
        ([num_res, num_res, c_z]) activations.
      masks: Dict with matching 'msa' and 'pair' masks.

    Returns:
      Dict with the updated 'msa' and 'pair' activations.
    """
    msa_act, pair_act = activations['msa'], activations['pair']
    num_residues = pair_act.shape[0]
    msa_mask, pair_mask = masks['msa'], masks['pair']
    # MSA -> pair communication.
    pair_act += OuterProductMean(
        config=self.config.outer_product_mean,
        global_config=self.global_config,
        num_output_channel=int(pair_act.shape[-1]),
        name='outer_product_mean',
    )(msa_act, msa_mask)
    # Pair -> MSA communication (attention weights from the pair stream).
    msa_act += MSAAttention(
        self.config.msa_attention, self.global_config, name='msa_attention1'
    )(msa_act, msa_mask, pair_act=pair_act)
    msa_act += TransitionBlock(
        self.config.msa_transition, self.global_config, name='msa_transition'
    )(msa_act)
    pair_act += TriangleMultiplication(
        self.config.triangle_multiplication_outgoing,
        self.global_config,
        name='triangle_multiplication_outgoing',
    )(pair_act, pair_mask)
    pair_act += TriangleMultiplication(
        self.config.triangle_multiplication_incoming,
        self.global_config,
        name='triangle_multiplication_incoming',
    )(pair_act, pair_mask)
    pair_act += GridSelfAttention(
        self.config.pair_attention,
        self.global_config,
        name='pair_attention1',
        transpose=False,
    )(pair_act, pair_mask)
    pair_act += GridSelfAttention(
        self.config.pair_attention,
        self.global_config,
        name='pair_attention2',
        transpose=True,
    )(pair_act, pair_mask)
    transition_block = TransitionBlock(
        self.config.pair_transition, self.global_config, name='pair_transition'
    )
    if self.config.shard_transition_blocks:
      transition_block = mapping.sharded_apply(
          transition_block,
          get_shard_size(
              num_residues, self.global_config.pair_transition_shard_spec
          ),
      )
    pair_act += transition_block(pair_act)
    return {'msa': msa_act, 'pair': pair_act}
================================================
FILE: src/alphafold3/model/network/noise_level_embeddings.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Fourier embeddings for given noise levels.
We supply fixed weights and biases for the Fourier embeddings. These were
initially generated by the following code, but are stored as constants to
future-proof them against changes in JAX RNG generation:
```
dim = 256
w_key, b_key = jax.random.split(jax.random.PRNGKey(42))
weight = jax.random.normal(w_key, shape=[dim])
bias = jax.random.uniform(b_key, shape=[dim])
```
"""
import jax.numpy as jnp
# pyformat: disable
# pylint: disable=bad-whitespace
# pylint: disable=bad-continuation
# Fixed Fourier frequencies (256 values) used by `noise_embeddings`. They were
# originally drawn with `jax.random.normal`, per the module docstring.
_WEIGHT = [
0.45873642, 0.06516238, -0.07278306, -0.26992258, 0.64292115,
-0.40763968, 3.60116863, 0.54461384, -0.32644904, 2.10888267,
1.30805349, 1.19838560, -1.37745857, 1.99475312, -1.64120293,
1.07823789, -0.02288206, 0.88305283, 0.48099944, 0.17655374,
0.30281949, 0.80646873, 0.62605333, -0.23965347, -1.02609432,
0.75006109, -0.19913037, 0.07466396, 0.66431236, -0.60990530,
-0.69709194, -0.44453633, -1.77656078, 0.02299878, 0.04095552,
0.35485864, -0.47602659, -0.98820388, -0.24106771, -1.07254291,
-0.99741757, 0.22697604, 1.41390419, 1.54984057, -0.12237291,
0.20156337, 0.61767143, 0.23959029, 0.92454034, 1.84082258,
0.89030224, 0.39598912, -1.52224910, 0.29669049, 1.52356744,
-0.33968377, 0.24155144, -0.52308381, -0.23622665, 0.92825454,
-0.63864607, -0.62169307, 0.78623551, -0.80352145, -0.45496067,
1.30877995, -0.06686528, 1.00248849, -0.63593471, 0.16372502,
-1.46133232, 1.10562658, -0.01693927, 0.28684548, -0.72843230,
0.66133535, -1.92225552, 0.70241231, -0.96868867, -0.47309339,
-1.66894221, 0.46018723, -0.56806105, 0.32694784, -0.46529883,
1.02299964, 0.84688205, 1.19581807, -1.82454145, 0.05999713,
-0.59530073, 1.44862521, -0.34933713, -0.46564487, -0.55005538,
-1.61170268, 0.17502306, 0.38670063, -1.12133658, -0.29343036,
-0.52527446, -1.26285112, 1.07982683, 0.51215219, 1.48963666,
1.09847653, -0.01563358, 0.32574457, 1.94779706, -1.29198587,
1.06249654, -0.86965990, 0.22975266, -0.27182648, -0.21130897,
-0.41773933, -0.02329035, 1.31049252, 0.05579265, -1.23127055,
-0.99691105, 0.27058721, -0.72509319, -0.14421797, -1.48605061,
1.35041201, 1.29619241, -1.01022530, -0.79787987, -0.16166858,
0.87210685, 1.69248152, 1.42469788, -0.72325104, -1.24823737,
0.07051118, 0.71332991, -0.07360429, -0.91955227, -2.68856549,
-0.44033936, 0.35482934, -0.57933813, 0.97468042, -0.31050494,
-0.88454425, -2.08785224, 0.47322822, -0.02400172, 0.26644820,
-0.19147627, -2.10538960, -1.27962470, -1.35999286, 2.09867334,
0.65099514, 0.21604492, -0.45951018, 0.15994427, -0.31420693,
-0.65202618, -0.61077976, -1.06100249, -1.47254968, 1.18165290,
-0.78656220, 1.28182006, 1.80323684, 1.09196901, 0.26118696,
-0.30168581, 0.39749333, 0.26812574, -1.51995814, -0.46909946,
0.03874255, -1.36774313, 2.30143976, 2.06959820, -0.41647521,
1.85624206, 0.49019700, -0.06726539, 0.00457313, 0.23915423,
-1.84971249, -0.20482327, -0.34097880, -0.57933033, -1.10541213,
-0.30269983, -0.16430426, -0.82371718, 0.10345812, 1.78753936,
0.04786763, 1.86778629, -0.65214992, 0.81544143, -0.28214937,
0.31187257, 0.57661986, 1.21938801, -1.56046617, 0.38046429,
-0.18235965, 0.81794524, -0.40474343, 0.46538028, -1.15558851,
0.59625793, -1.07801270, 0.07310858, 0.61526084, 0.55518496,
-0.49787554, 0.92703879, -1.27780271, -0.83373469, -0.43015575,
0.41877759, -1.03987372, -1.46055734, 0.61282396, 0.15590595,
-0.34269521, 0.56509072, -1.17904210, 0.11374855, -1.83310866,
0.38734794, -0.58623004, 0.77931106, 1.53930688, -0.70299625,
-0.11389336, -1.14818096, -0.44400632, 1.21887410, 0.64066756,
-0.70249403, -0.27244881, 0.38586098, -1.07925785, 0.12448707,
-1.28286278, 0.37827531, 0.68812364, 1.65695465, 0.12440517,
-0.03689830, 1.10224664, -0.28323629, -0.47939169, 0.70120829,
-0.67204583
]
# Fixed per-channel phase offsets (256 values) used by `noise_embeddings` inside
# cos(2 * pi * (... + bias)). They were originally drawn with
# `jax.random.uniform`, per the module docstring.
_BIAS = [
0.00465965, 0.21738243, 0.22277749, 0.68463874, 0.84596848, 0.17337036,
0.39573753, 0.78153563, 0.86311185, 0.21782327, 0.24377882, 0.42310703,
0.19887352, 0.10486019, 0.48707581, 0.22205460, 0.97263455, 0.29714966,
0.11244559, 0.53020525, 0.36796236, 0.37294638, 0.80261672, 0.04669094,
0.86319661, 0.75907171, 0.77297020, 0.01114798, 0.55850804, 0.91799915,
0.23032320, 0.12154722, 0.26701927, 0.42934716, 0.47951782, 0.96782577,
0.86785042, 0.61985648, 0.05743814, 0.41800117, 0.68881893, 0.60575199,
0.21058667, 0.64412105, 0.63958526, 0.89390790, 0.69755554, 0.89345169,
0.53330755, 0.56985939, 0.30724049, 0.00984561, 0.91407037, 0.92118979,
0.94153070, 0.81097460, 0.70537627, 0.32810748, 0.47227263, 0.11821401,
0.44983089, 0.30767226, 0.31756389, 0.62969446, 0.69892538, 0.16949117,
0.06207097, 0.46717727, 0.95348179, 0.62363589, 0.49018729, 0.06920040,
0.39333904, 0.41299903, 0.52514863, 0.61197245, 0.56871891, 0.65053988,
0.22203422, 0.46748531, 0.86931503, 0.87050021, 0.40208721, 0.32084906,
0.55084610, 0.94584596, 0.76279902, 0.36250532, 0.74272907, 0.66682065,
0.96452832, 0.64768302, 0.88070846, 0.56995463, 0.06395614, 0.69499350,
0.44494808, 0.39775658, 0.20280898, 0.33363521, 0.05999005, 0.44414878,
0.65227020, 0.01199079, 0.71995056, 0.19045687, 0.48342144, 0.25127733,
0.66515994, 0.22465158, 0.22313106, 0.06302810, 0.55783665, 0.93625581,
0.58800840, 0.72525370, 0.52879298, 0.77195418, 0.15548682, 0.01028740,
0.39325142, 0.45401239, 0.71494079, 0.33011997, 0.05050695, 0.26381660,
0.63064706, 0.47604024, 0.08593416, 0.00383425, 0.06352687, 0.05510247,
0.03552997, 0.35810637, 0.56094289, 0.60922170, 0.88599777, 0.45419788,
0.40486634, 0.71297824, 0.34976673, 0.97825217, 0.12915993, 0.09566259,
0.64318919, 0.16717327, 0.82308614, 0.32672071, 0.81688786, 0.84857118,
0.99922776, 0.07551706, 0.18766022, 0.13051236, 0.39136350, 0.08768725,
0.92048228, 0.87185788, 0.39158428, 0.79224777, 0.17492688, 0.68902445,
0.81980729, 0.70458186, 0.59489477, 0.93324888, 0.49986637, 0.40705478,
0.89202917, 0.20673239, 0.39339757, 0.20996964, 0.02923799, 0.53992438,
0.40119815, 0.10366607, 0.08044600, 0.95551598, 0.20518017, 0.68826210,
0.90159297, 0.69008791, 0.86880815, 0.16246438, 0.89628279, 0.11481643,
0.61353648, 0.41545081, 0.92478311, 0.78212476, 0.48292696, 0.79621077,
0.11947489, 0.01747024, 0.22928023, 0.87387264, 0.86349785, 0.89526737,
0.58904779, 0.13896775, 0.68194926, 0.55824125, 0.44428205, 0.55422378,
0.28189969, 0.27923775, 0.09979951, 0.66994715, 0.45943546, 0.71207762,
0.17300689, 0.83434916, 0.02573085, 0.45858085, 0.55934799, 0.30676675,
0.52219367, 0.34544575, 0.19280875, 0.26937950, 0.07147646, 0.06295013,
0.76382887, 0.38737607, 0.58825982, 0.17423475, 0.05509448, 0.97228825,
0.94380617, 0.91664016, 0.18800116, 0.41771865, 0.59420645, 0.77371931,
0.64687788, 0.27284670, 0.22310913, 0.15663862, 0.45573199, 0.50386798,
0.66712272, 0.71649647, 0.28475654, 0.83415413, 0.75261366, 0.61517799,
0.93544555, 0.76141870, 0.85474241, 0.74766934, 0.33459592, 0.78477907,
0.07250881, 0.10174239, 0.95332730, 0.80793905
]
# pyformat: enable
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
def noise_embeddings(sigma_scaled_noise_level: jnp.ndarray) -> jnp.ndarray:
  """Computes fixed Fourier-feature embeddings of a diffusion noise level.

  The noise level is moved to log-space (scaled by 1/4), projected with the
  module-level `_WEIGHT` frequencies, shifted by the `_BIAS` phases, and passed
  through `cos(2 * pi * .)`.

  Args:
    sigma_scaled_noise_level: Noise level(s). Must be positive for the log to
      be finite. Any leading shape is accepted.

  Returns:
    The embedding, with one new trailing axis that broadcasts against
    `_WEIGHT` / `_BIAS` (presumably both are 1-D of equal length — confirm).
  """
  frequencies = jnp.array(_WEIGHT, dtype=jnp.float32)
  phases = jnp.array(_BIAS, dtype=jnp.float32)
  log_noise = 0.25 * jnp.log(sigma_scaled_noise_level)
  angles = jnp.expand_dims(log_noise, axis=-1) * frequencies + phases
  return jnp.cos(2 * jnp.pi * angles)
================================================
FILE: src/alphafold3/model/network/template_modules.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Modules for embedding templates."""
from alphafold3.common import base_config
from alphafold3.constants import residue_names
from alphafold3.jax import geometry
from alphafold3.model import features
from alphafold3.model import model_config
from alphafold3.model import protein_data_processing
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.network import modules
from alphafold3.model.scoring import scoring
import haiku as hk
import jax
import jax.numpy as jnp
class DistogramFeaturesConfig(base_config.BaseConfig):
  """Bin layout for template distograms.

  The bin lower edges are `num_bins` evenly spaced values running from
  `min_bin` to `max_bin` inclusive (see `dgram_from_positions`).
  """

  # Lower edge of the first distogram bin.
  min_bin: float = 3.25
  # Lower edge of the last bin. That bin has no upper edge, so it collects
  # every distance beyond `max_bin`.
  max_bin: float = 50.75
  # Total number of distogram bins.
  num_bins: int = 39
def dgram_from_positions(positions, config: DistogramFeaturesConfig):
  """One-hot distogram of pairwise distances between positions.

  Squared distances are compared with squared bin edges, so no square root is
  taken. Bin `k` holds `lower[k] < d < lower[k + 1]`. The last bin's upper edge
  is 1e8 in squared units. A distance at or below `config.min_bin` (including
  each point to itself) lands in no bin. So does a distance exactly equal to an
  interior edge. Such pairs get an all-zero row.

  Args:
    positions: (num_res, 3) Position coordinates.
    config: Distogram bin configuration.

  Returns:
    A float32 (num_res, num_res, num_bins) array of 0/1 bin indicators.
  """
  squared_lower = jnp.square(
      jnp.linspace(config.min_bin, config.max_bin, config.num_bins)
  )
  squared_upper = jnp.concatenate(
      [squared_lower[1:], jnp.array([1e8], dtype=jnp.float32)], axis=-1
  )

  # offsets[..., i, j, :] = positions[..., i, :] - positions[..., j, :]
  offsets = positions[..., :, None, :] - positions[..., None, :, :]
  squared_dist = jnp.sum(jnp.square(offsets), axis=-1, keepdims=True)

  above_lower = (squared_dist > squared_lower).astype(jnp.float32)
  below_upper = (squared_dist < squared_upper).astype(jnp.float32)
  return above_lower * below_upper
def make_backbone_rigid(
    positions: geometry.Vec3Array,
    mask: jnp.ndarray,
    group_indices: jnp.ndarray,
) -> tuple[geometry.Rigid3Array, jnp.ndarray]:
  """Builds one backbone frame per residue from atom positions.

  Rigid group 0 of `group_indices` supplies the three frame atoms. Backbone
  frames use the opposite atom ordering to side-chain frames ((N, CA, C) vs
  (C, CA, N)), so the stored indices are unpacked as (c, b, a). The frame
  origin is atom `b`. The rotation comes from
  `Rot3Array.from_two_vectors(c - b, a - b)`.

  Args:
    positions: (num_res, num_atoms) of atom positions as Vec3Array.
    mask: (num_res, num_atoms) for atom mask.
    group_indices: (num_res, num_group, 3) for atom indices forming groups.

  Returns:
    Tuple of the backbone Rigid3Array and a float32 (num_res,) mask. The mask
    is the product of the three frame atoms' mask values.
  """
  # (num_res, 3) atom indices of rigid group 0.
  stored = group_indices[:, 0]
  idx_c = stored[..., 0]
  idx_b = stored[..., 1]
  idx_a = stored[..., 2]

  # Per-residue gather: row r of `values` is indexed with idx[r].
  gather = jax.vmap(lambda values, idx: values[idx])

  frame_mask = (
      gather(mask, idx_a) * gather(mask, idx_b) * gather(mask, idx_c)
  ).astype(jnp.float32)

  def gather_atoms(idx):
    return jax.tree.map(lambda values: gather(values, idx), positions)

  atom_a = gather_atoms(idx_a)
  atom_b = gather_atoms(idx_b)
  atom_c = gather_atoms(idx_c)

  rotation = geometry.Rot3Array.from_two_vectors(
      atom_c - atom_b, atom_a - atom_b
  )
  return geometry.Rigid3Array(rotation, atom_b), frame_mask
class TemplateEmbedding(hk.Module):
  """Embeds a set of templates into the pair representation.

  One shared `SingleTemplateEmbedding` embeds every template independently.
  The results are summed and divided by (num_templates + 1e-7). The average
  then goes through a ReLU and a linear projection back to the query
  embedding's channel count.
  """

  class Config(base_config.BaseConfig):
    num_channels: int = 64
    template_stack: modules.PairFormerIteration.Config = base_config.autocreate(
        num_layer=2,
        pair_transition=base_config.autocreate(num_intermediate_factor=2),
    )
    dgram_features: DistogramFeaturesConfig = base_config.autocreate()

  def __init__(
      self,
      config: Config,
      global_config: model_config.GlobalConfig,
      name='template_embedding',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      query_embedding: jnp.ndarray,
      templates: features.Templates,
      padding_mask_2d: jnp.ndarray,
      multichain_mask_2d: jnp.ndarray,
      key: jnp.ndarray,
  ) -> jnp.ndarray:
    """Generate an embedding for a set of templates.

    Args:
      query_embedding: [num_res, num_res, num_channel] query pair embedding.
        It is passed unchanged to the embedder of every template.
      templates: A 'Templates' object with a leading num_templates axis.
      padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
      multichain_mask_2d: [num_res, num_res] Pair mask for multichain.
      key: Random key generator, split into one subkey per template.

    Returns:
      An embedding of shape [num_res, num_res, num_channel], where num_channel
      matches `query_embedding`.
    """
    num_residues = query_embedding.shape[0]
    num_templates = templates.aatype.shape[0]
    query_num_channels = query_embedding.shape[2]
    num_atoms = 24

    # Validate input shapes.
    assert query_embedding.shape == (
        num_residues,
        num_residues,
        query_num_channels,
    )
    assert templates.aatype.shape == (num_templates, num_residues)
    assert templates.atom_positions.shape == (
        num_templates,
        num_residues,
        num_atoms,
        3,
    )
    assert templates.atom_mask.shape == (num_templates, num_residues, num_atoms)
    assert padding_mask_2d.shape == (num_residues, num_residues)

    # A single embedder instance, so every template shares its parameters.
    per_template_embedder = SingleTemplateEmbedding(
        self.config, self.global_config
    )
    template_keys = jnp.array(jax.random.split(key, num_templates))

    def add_template(running_total, inputs):
      one_template, template_key = inputs
      embedded = per_template_embedder(
          query_embedding,
          one_template,
          padding_mask_2d,
          multichain_mask_2d,
          template_key,
      )
      return running_total + embedded, None

    zeros = jnp.zeros(
        (num_residues, num_residues, self.config.num_channels),
        dtype=query_embedding.dtype,
    )
    total, _ = hk.scan(add_template, zeros, (templates, template_keys))

    # The epsilon keeps the division finite when there are no templates.
    mean_embedding = total / (1e-7 + num_templates)
    activated = jax.nn.relu(mean_embedding)
    output = hm.Linear(
        query_num_channels, initializer='relu', name='output_linear'
    )(activated)

    assert output.shape == (num_residues, num_residues, query_num_channels)
    return output
class SingleTemplateEmbedding(hk.Module):
  """Embed a single template.

  The template is turned into several pairwise features:
    * a masked distogram of pseudo-beta positions and the pseudo-beta pair mask;
    * one-hot residue types broadcast along rows and along columns;
    * unit vectors between backbone frames and the backbone pair mask;
    * a LayerNorm-ed copy of the query embedding.
  Each feature is projected to `num_channels` by its own linear layer, and the
  projections are summed. If `template_stack.num_layer` is non-zero, the sum is
  refined by a layer stack of `PairFormerIteration`. A final LayerNorm is
  applied either way.
  """

  def __init__(
      self,
      config: TemplateEmbedding.Config,
      global_config: model_config.GlobalConfig,
      name='single_template_embedding',
  ):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      query_embedding: jnp.ndarray,
      templates: features.Templates,
      padding_mask_2d: jnp.ndarray,
      multichain_mask_2d: jnp.ndarray,
      key: jnp.ndarray,
  ) -> jnp.ndarray:
    """Build the single template embedding graph.

    Args:
      query_embedding: (num_res, num_res, num_channels) - embedding of the query
        sequence/msa.
      templates: 'Templates' object containing single Template.
      padding_mask_2d: Padding mask (Note: this doesn't care if a template
        exists, unlike the template_pseudo_beta_mask). Must have the same dtype
        as `query_embedding` (asserted below).
      multichain_mask_2d: A mask indicating intra-chain residue pairs, used to
        mask out between chain distances/features when templates are for single
        chains.
      key: Random key generator. NOTE(review): not read anywhere in this
        method body.

    Returns:
      A template embedding (num_res, num_res, num_channels).
    """
    gc = self.global_config
    c = self.config
    assert padding_mask_2d.dtype == query_embedding.dtype
    dtype = query_embedding.dtype
    num_channels = self.config.num_channels

    def construct_input(
        query_embedding, templates: features.Templates, multichain_mask_2d
    ):
      """Returns the sum of per-feature linear projections of the template."""

      # Compute distogram feature for the template.
      aatype = templates.aatype
      dense_atom_mask = templates.atom_mask

      dense_atom_positions = templates.atom_positions
      # Zero the coordinates of atoms whose mask is 0.
      dense_atom_positions *= dense_atom_mask[..., None]

      pseudo_beta_positions, pseudo_beta_mask = scoring.pseudo_beta_fn(
          templates.aatype, dense_atom_positions, dense_atom_mask
      )
      pseudo_beta_mask_2d = (
          pseudo_beta_mask[:, None] * pseudo_beta_mask[None, :]
      )
      pseudo_beta_mask_2d *= multichain_mask_2d
      dgram = dgram_from_positions(
          pseudo_beta_positions, self.config.dgram_features
      )
      dgram *= pseudo_beta_mask_2d[..., None]
      dgram = dgram.astype(dtype)
      pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
      # Each entry is (feature, num_input_dims for hm.Linear). Use 1 for
      # features with a trailing channel axis and 0 for per-pair scalars.
      # The position in this list determines the layer name
      # `template_pair_embedding_{i}` below, so the append order must not
      # change.
      to_concat = [(dgram, 1), (pseudo_beta_mask_2d, 0)]

      aatype = jax.nn.one_hot(
          aatype,
          residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP,
          axis=-1,
          dtype=dtype,
      )
      # One-hot residue type of the column residue, then of the row residue.
      to_concat.append((aatype[None, :, :], 1))
      to_concat.append((aatype[:, None, :], 1))

      # Compute a feature representing the normalized vector between each
      # backbone affine - i.e. in each residues local frame, what direction are
      # each of the other residues.
      template_group_indices = jnp.take(
          protein_data_processing.RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX,
          templates.aatype,
          axis=0,
      )
      rigid, backbone_mask = make_backbone_rigid(
          geometry.Vec3Array.from_array(dense_atom_positions),
          dense_atom_mask,
          template_group_indices.astype(jnp.int32),
      )
      points = rigid.translation
      # rigid_vec[i, j]: origin of frame j expressed in frame i's local
      # coordinates.
      rigid_vec = rigid[:, None].inverse().apply_to_point(points)
      unit_vector = rigid_vec.normalized()
      unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

      unit_vector = [x.astype(dtype) for x in unit_vector]
      backbone_mask = backbone_mask.astype(dtype)

      backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
      backbone_mask_2d *= multichain_mask_2d
      unit_vector = [x * backbone_mask_2d for x in unit_vector]

      # Note that the backbone_mask takes into account C, CA and N (unlike
      # pseudo beta mask which just needs CB) so we add both masks as features.
      to_concat.extend([(x, 0) for x in unit_vector])
      to_concat.append((backbone_mask_2d, 0))

      query_embedding = hm.LayerNorm(name='query_embedding_norm')(
          query_embedding
      )
      # Allow the template embedder to see the query embedding.  Note this
      # contains the position relative feature, so this is how the network knows
      # which residues are next to each other.
      to_concat.append((query_embedding, 1))

      act = 0

      # Sum of separate linear projections, one layer per feature.
      for i, (x, n_input_dims) in enumerate(to_concat):
        act += hm.Linear(
            num_channels,
            num_input_dims=n_input_dims,
            initializer='relu',
            name=f'template_pair_embedding_{i}',
        )(x)
      return act

    act = construct_input(query_embedding, templates, multichain_mask_2d)

    # Optional PairFormer refinement; skipped when num_layer is 0.
    if c.template_stack.num_layer:

      def template_iteration_fn(x):
        return modules.PairFormerIteration(
            c.template_stack, gc, name='template_embedding_iteration'
        )(act=x, pair_mask=padding_mask_2d)

      template_stack = hk.experimental.layer_stack(c.template_stack.num_layer)(
          template_iteration_fn
      )
      act = template_stack(act)

    act = hm.LayerNorm(name='output_layer_norm')(act)
    return act
================================================
FILE: src/alphafold3/model/params.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Model param loading."""
import bisect
import collections
from collections.abc import Iterator
import contextlib
import io
import os
import pathlib
import re
import struct
import sys
from typing import IO
import haiku as hk
import jax.numpy as jnp
import numpy as np
import zstandard
class RecordError(Exception):
  """Raised when a parameter record cannot be read, e.g. it is truncated."""
def encode_record(scope: str, name: str, arr: np.ndarray) -> bytes:
  """Encodes a single haiku param as bytes, preserving non-numpy dtypes.

  Layout (every integer is a little-endian int32):
    header:  len(scope), len(name), len(dtype), ndim, len(array bytes)
    payload: scope, name, dtype string (all UTF-8), shape, then the array
             bytes in C order and little-endian byte order.

  Args:
    scope: Haiku module scope of the parameter.
    name: Parameter name within the scope.
    arr: Parameter value. Its dtype is stored by name (`str(arr.dtype)`), so
      a reader can reconstruct any dtype numpy can resolve by that name.

  Returns:
    The encoded record, in the format read by `_read_record`.
  """
  scope_bytes = scope.encode('utf-8')
  name_bytes = name.encode('utf-8')
  dtype_bytes = str(arr.dtype).encode('utf-8')
  shape = arr.shape
  arr = np.ascontiguousarray(arr)
  if sys.byteorder == 'big':
    # Array data is always serialized little-endian.
    arr = arr.byteswap()
  arr_buffer = arr.tobytes('C')
  header = struct.pack(
      '<5i',
      len(scope_bytes),
      len(name_bytes),
      len(dtype_bytes),
      len(shape),
      len(arr_buffer),
  )
  # The shape must carry the explicit '<' prefix like the header and the
  # reader's format. A native-order pack would be misread on big-endian hosts.
  shape_bytes = struct.pack(f'<{len(shape)}i', *shape)
  return header + b''.join(
      (scope_bytes, name_bytes, dtype_bytes, shape_bytes, arr_buffer)
  )
def _read_record(stream: IO[bytes]) -> tuple[str, str, np.ndarray] | None:
"""Reads a record encoded by `_encode_record` from a byte stream."""
header_size = struct.calcsize('<5i')
header = stream.read(header_size)
if not header:
return None
if len(header) < header_size:
raise RecordError(f'Incomplete header: {len(header)=} < {header_size=}')
(scope_len, name_len, dtype_len, shape_len, arr_buffer_len) = struct.unpack(
'<5i', header
)
fmt = f'<{scope_len}s{name_len}s{dtype_len}s{shape_len}i'
payload_size = struct.calcsize(fmt) + arr_buffer_len
payload = stream.read(payload_size)
if len(payload) < payload_size:
raise RecordError(f'Incomplete payload: {len(payload)=} < {payload_size=}')
scope, name, dtype, *shape = struct.unpack_from(fmt, payload)
scope = scope.decode('utf-8')
name = name.decode('utf-8')
dtype = dtype.decode('utf-8')
arr = np.frombuffer(payload[-arr_buffer_len:], dtype=dtype)
arr = np.reshape(arr, shape)
if sys.byteorder == 'big':
arr = arr.byteswap()
return scope, name, arr
def read_records(stream: IO[bytes]) -> Iterator[tuple[str, str, np.ndarray]]:
  """Yields `(scope, name, array)` records until the stream is exhausted.

  Args:
    stream: Binary stream of records in the `encode_record` format.

  Yields:
    Each decoded record, in stream order.
  """
  while True:
    record = _read_record(stream)
    if not record:
      return
    yield record
class _MultiFileIO(io.RawIOBase):
"""A file-like object that presents a concatenated view of multiple files."""
def __init__(self, files: list[pathlib.Path]):
self._files = files
self._stack = contextlib.ExitStack()
self._handles = [
self._stack.enter_context(file.open('rb')) for file in files
]
self._sizes = []
for handle in self._handles:
handle.seek(0, os.SEEK_END)
self._sizes.append(handle.tell())
self._length = sum(self._sizes)
self._offsets = [0]
for s in self._sizes[:-1]:
self._offsets.append(self._offsets[-1] + s)
self._abspos = 0
self._relpos = (0, 0)
def _abs_to_rel(self, pos: int) -> tuple[int, int]:
idx = bisect.bisect_right(self._offsets, pos) - 1
return idx, pos - self._offsets[idx]
def close(self):
self._stack.close()
def closed(self) -> bool:
return all(handle.closed for handle in self._handles)
def fileno(self) -> int:
return -1
def readable(self) -> bool:
return True
def tell(self) -> int:
return self._abspos
def seek(self, pos: int, whence: int = os.SEEK_SET, /):
match whence:
case os.SEEK_SET:
pass
case os.SEEK_CUR:
pos += self._abspos
case os.SEEK_END:
pos = self._length - pos
case _:
raise ValueError(f'Invalid whence: {whence}')
self._abspos = pos
self._relpos = self._abs_to_rel(pos)
def readinto(self, b: bytearray | memoryview) -> int:
result = 0
mem = memoryview(b)
while mem:
self._handles[self._relpos[0]].seek(self._relpos[1])
count = self._handles[self._relpos[0]].readinto(mem)
result += count
self._abspos += count
self._relpos = self._abs_to_rel(self._abspos)
mem = mem[count:]
if self._abspos == self._length:
break
return result
@contextlib.contextmanager
def open_for_reading(model_files: list[pathlib.Path], is_compressed: bool):
  """Opens `model_files` as one concatenated binary stream.

  Args:
    model_files: Files to read back to back, in the given order.
    is_compressed: Whether the concatenated bytes are a zstd stream that
      should be decompressed on the fly.

  Yields:
    A readable binary stream. The underlying files are closed when the
    context exits.
  """
  with contextlib.closing(_MultiFileIO(model_files)) as raw_stream:
    yield (
        zstandard.ZstdDecompressor().stream_reader(raw_stream)
        if is_compressed
        else raw_stream
    )
def _match_model(
    paths: list[pathlib.Path], pattern: re.Pattern[str]
) -> dict[str, list[pathlib.Path]]:
  """Match files in a directory with a pattern, and group by model name.

  Args:
    paths: Candidate file paths. Only `path.name` is matched.
    pattern: Regex with a named group `model_name`, applied with `fullmatch`.

  Returns:
    Mapping from model name to its matching paths. Paths within a group are in
    natural order, with runs of digits compared as integers. So shard `.10.`
    follows `.9.` instead of `.1.`. The shards are later read back to back in
    this order, so a lexicographic sort would corrupt the stream.
  """

  def natural_key(path: pathlib.Path):
    # re.split with a capturing group alternates text and digit chunks, so
    # every list position compares str-to-str or int-to-int.
    chunks = re.split(r'([0-9]+)', path.name)
    return (
        [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)],
        path,
    )

  models = collections.defaultdict(list)
  for path in paths:
    match = pattern.fullmatch(path.name)
    if match:
      models[match.group('model_name')].append(path)
  return {k: sorted(v, key=natural_key) for k, v in models.items()}
def select_model_files(
    model_dir: pathlib.Path, model_name: str | None = None
) -> tuple[list[pathlib.Path], bool]:
  """Select the model files from a model directory.

  File-name patterns are tried in order: sharded before single-file, and
  compressed before uncompressed within each of those. With `model_name`, the
  first pattern that matches that name wins. Without it, the first pattern
  that matches anything wins, and it must match exactly one model.

  Args:
    model_dir: Directory to search (non-recursive; regular files only).
    model_name: Optional model name to select.

  Returns:
    `(files, is_compressed)`: the model's files in read order, and whether
    their concatenated contents are zstd-compressed.

  Raises:
    RuntimeError: If `model_name` is None and several models match.
    FileNotFoundError: If no model matches.
  """
  files = [file for file in model_dir.iterdir() if file.is_file()]

  for pattern, is_compressed in (
      (r'(?P<model_name>.*)\.[0-9]+\.bin\.zst$', True),
      (r'(?P<model_name>.*)\.bin\.zst\.[0-9]+$', True),
      (r'(?P<model_name>.*)\.[0-9]+\.bin$', False),
      # Uncompressed shards named `<model>.bin.<N>`.
      (r'(?P<model_name>.*)\.bin\.[0-9]+$', False),
      (r'(?P<model_name>.*)\.bin\.zst$', True),
      (r'(?P<model_name>.*)\.bin$', False),
  ):
    models = _match_model(files, re.compile(pattern))
    if model_name is not None:
      if model_name in models:
        return models[model_name], is_compressed
    else:
      if models:
        if len(models) > 1:
          raise RuntimeError(f'Multiple models matched in {model_dir}')
        _, model_files = models.popitem()
        return model_files, is_compressed
  raise FileNotFoundError(f'No models matched in {model_dir}')
def get_model_haiku_params(model_dir: pathlib.Path) -> hk.Params:
  """Loads all Haiku parameters of the single model stored in `model_dir`.

  Args:
    model_dir: Directory holding one model's parameter file(s).

  Returns:
    Nested mapping `{scope: {name: jnp array}}`.

  Raises:
    FileNotFoundError: If no model files match or they contain no records.
    RuntimeError: If more than one model matches in `model_dir`.
  """
  model_files, is_compressed = select_model_files(model_dir)
  params: dict[str, dict[str, jnp.ndarray]] = {}
  with open_for_reading(model_files, is_compressed) as stream:
    for scope, name, value in read_records(stream):
      if scope not in params:
        params[scope] = {}
      params[scope][name] = jnp.array(value)
  if not params:
    raise FileNotFoundError(f'Model missing from "{model_dir}"')
  return params
================================================
FILE: src/alphafold3/model/pipeline/inter_chain_bonds.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions for handling inter-chain bonds."""
from collections.abc import Collection
import functools
from typing import Final, NamedTuple
from alphafold3 import structure
from alphafold3.constants import chemical_component_sets
from alphafold3.constants import mmcif_names
from alphafold3.model.atom_layout import atom_layout
import numpy as np
# Maximum bond length (Angstrom) applied when only glycan ligands are kept.
BOND_THRESHOLD_GLYCANS_ANGSTROM: Final[float] = 1.7
# Maximum bond length (Angstrom) applied to all ligands. It is looser so that
# long bonds such as P-P are kept; see
# https://pubs.acs.org/doi/10.1021/ja010331r for P-P atom bond distances.
BOND_THRESHOLD_ALL_ANGSTROM: Final[float] = 2.4
class BondAtomArrays(NamedTuple):
  """Attributes of the atoms at one end of each bond, indexed by bond.

  `coords` keeps any leading axes of the structure's coordinates, with bonds
  on the second-to-last axis.
  """

  chain_id: np.ndarray
  chain_type: np.ndarray
  res_id: np.ndarray
  res_name: np.ndarray
  atom_name: np.ndarray
  coords: np.ndarray
def _get_bond_atom_arrays(
    struc: structure.Structure, bond_atom_indices: np.ndarray
) -> BondAtomArrays:
  """Gathers attributes of `struc` for the atoms at `bond_atom_indices`.

  Args:
    struc: Structure to read per-atom attributes from.
    bond_atom_indices: Atom indices, one per bond.

  Returns:
    A BondAtomArrays with per-bond entries. Coordinates are indexed on their
    second-to-last axis, so any leading axes are preserved.
  """
  per_atom_fields = {
      field: getattr(struc, field)[bond_atom_indices]
      for field in ('chain_id', 'chain_type', 'res_id', 'res_name', 'atom_name')
  }
  return BondAtomArrays(
      **per_atom_fields,
      coords=struc.coords[..., bond_atom_indices, :],
  )
@functools.lru_cache(maxsize=1)
def get_polymer_ligand_and_ligand_ligand_bonds(
    struct: structure.Structure,
    only_glycan_ligands: bool,
    allow_multiple_bonds_per_atom: bool,
) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout]:
  """Return polymer-ligand & ligand-ligand inter-residue bonds.

  Only the most recent result is cached (`lru_cache(maxsize=1)`), and all
  arguments must be hashable.

  Args:
    struct: Structure object to extract bonds from.
    only_glycan_ligands: Whether to only include glycans in ligand category.
      This also selects the tighter glycan bond-length threshold.
    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
      bond seen per atom and discard the remaining on each atom.

  Returns:
    `(polymer_ligand_bonds, ligand_ligand_bonds)`, each an AtomLayout of shape
    [num_bonds, 2] for the bond-defining atoms. A bond counts as
    polymer-ligand if one end is on a ligand chain and one on a polymer chain.
    It counts as ligand-ligand if both ends are on ligand chains.
  """
  if only_glycan_ligands:
    res_name_filter = list({
        *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
        *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
    })
    threshold = BOND_THRESHOLD_GLYCANS_ANGSTROM
  else:
    res_name_filter = None
    threshold = BOND_THRESHOLD_ALL_ANGSTROM

  all_bonds = get_bond_layout(
      bond_threshold=threshold,
      struct=struct,
      allowed_chain_types1=list({
          *mmcif_names.LIGAND_CHAIN_TYPES,
          *mmcif_names.POLYMER_CHAIN_TYPES,
      }),
      allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES),
      allowed_res_names=res_name_filter,
      allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom,
  )

  # Per-bond, per-end flags of shape [num_bonds, 2].
  end_is_ligand = np.isin(
      all_bonds.chain_type, list(mmcif_names.LIGAND_CHAIN_TYPES)
  )
  end_is_polymer = np.isin(
      all_bonds.chain_type, list(mmcif_names.POLYMER_CHAIN_TYPES)
  )
  both_ends_ligand = end_is_ligand.all(axis=1)
  ligand_and_polymer = np.logical_and(
      end_is_ligand.any(axis=1), end_is_polymer.any(axis=1)
  )

  ligand_ligand_bonds = all_bonds[both_ends_ligand]
  polymer_ligand_bonds = all_bonds[ligand_and_polymer]
  return polymer_ligand_bonds, ligand_ligand_bonds
def _remove_multi_bonds(
    bond_layout: atom_layout.AtomLayout,
) -> atom_layout.AtomLayout:
  """Greedily keeps at most one bond per atom, scanning bonds in order.

  Atoms are identified by (chain_id, res_id, atom_name). A bond is kept only
  if neither of its atoms appeared in any earlier bond. Atoms of a discarded
  bond are still marked as seen, so they also block later bonds.
  """
  seen = set()
  keep = []
  for chain_ids, res_ids, atom_names in zip(
      bond_layout.chain_id,
      bond_layout.res_id,
      bond_layout.atom_name,
      strict=True,
  ):
    first_atom = (chain_ids[0], res_ids[0], atom_names[0])
    second_atom = (chain_ids[1], res_ids[1], atom_names[1])
    keep.append(first_atom not in seen and second_atom not in seen)
    seen.update((first_atom, second_atom))
  return bond_layout[np.array(keep, dtype=bool)]
@functools.lru_cache(maxsize=1)
def get_ligand_ligand_bonds(
    struct: structure.Structure,
    only_glycan_ligands: bool,
    allow_multiple_bonds_per_atom: bool = False,
) -> atom_layout.AtomLayout:
  """Return ligand-ligand inter-residue bonds.

  Only the most recent result is cached (`lru_cache(maxsize=1)`).

  Args:
    struct: Structure object to extract bonds from.
    only_glycan_ligands: Whether to only include glycans in ligand category.
      This also selects the tighter glycan bond-length threshold.
    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
      bond seen per atom and discard the remaining on each atom.

  Returns:
    bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms.
  """
  res_name_filter = None
  threshold = BOND_THRESHOLD_ALL_ANGSTROM
  if only_glycan_ligands:
    res_name_filter = list({
        *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
        *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
    })
    threshold = BOND_THRESHOLD_GLYCANS_ANGSTROM
  return get_bond_layout(
      bond_threshold=threshold,
      struct=struct,
      allowed_chain_types1=list(mmcif_names.LIGAND_CHAIN_TYPES),
      allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES),
      allowed_res_names=res_name_filter,
      allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom,
  )
@functools.lru_cache(maxsize=1)
def get_polymer_ligand_bonds(
    struct: structure.Structure,
    only_glycan_ligands: bool,
    allow_multiple_bonds_per_atom: bool = False,
    bond_threshold: float | None = None,
) -> atom_layout.AtomLayout:
  """Return polymer-ligand bonds.

  Only the most recent result is cached (`lru_cache(maxsize=1)`).

  Args:
    struct: Structure object to extract bonds from.
    only_glycan_ligands: Whether to only include glycans in ligand category.
    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
      bond seen per atom and discard the remaining on each atom.
    bond_threshold: Euclidean distance of max allowed bond. If None, this is
      `BOND_THRESHOLD_GLYCANS_ANGSTROM` when `only_glycan_ligands` is set and
      `BOND_THRESHOLD_ALL_ANGSTROM` otherwise.

  Returns:
    bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms.
  """
  res_name_filter = None
  default_threshold = BOND_THRESHOLD_ALL_ANGSTROM
  if only_glycan_ligands:
    res_name_filter = list({
        *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
        *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
    })
    default_threshold = BOND_THRESHOLD_GLYCANS_ANGSTROM
  threshold = default_threshold if bond_threshold is None else bond_threshold
  return get_bond_layout(
      bond_threshold=threshold,
      struct=struct,
      allowed_chain_types1=list(mmcif_names.POLYMER_CHAIN_TYPES),
      allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES),
      allowed_res_names=res_name_filter,
      allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom,
  )
def get_bond_layout(
    bond_threshold: float = BOND_THRESHOLD_ALL_ANGSTROM,
    *,
    struct: structure.Structure,
    allowed_chain_types1: Collection[str],
    allowed_chain_types2: Collection[str],
    include_bond_types: Collection[str] = ('covale',),
    allowed_res_names: Collection[str] | None = None,
    allow_multiple_bonds_per_atom: bool,
) -> atom_layout.AtomLayout:
  """Get bond_layout for all bonds between two sets of chain types.

  Every bond in `struct.bonds` must pass all of the following filters. They
  are combined into a single boolean mask, `all_mask`:
    * chain types: one end in `allowed_chain_types1` and the other in
      `allowed_chain_types2` (either direction);
    * residue names: if `allowed_res_names` is non-empty, at least one end has
      one of these residue names;
    * bond type: the bond's type is in `include_bond_types`;
    * length: strictly shorter than `bond_threshold`;
    * topology: if either end is on a ligand chain, the bond must join
      different residues or different chains; otherwise it must join
      different chains.
  No element-based filtering (e.g. of oxygen or water bonds) is done here.

  Args:
    bond_threshold: Maximum bond distance in Angstrom.
    struct: Structure object to extract bonds from.
    allowed_chain_types1: One end of the bonds must be an atom with one of these
      chain types.
    allowed_chain_types2: The other end of the bond must be an atom with one of
      these chain types.
    include_bond_types: Only include bonds with specified type e.g. hydrog,
      metalc, covale, disulf.
    allowed_res_names: Further restricts from chain_types. Either end of the
      bonds must be an atom part of these res_names. If None or empty, all
      bonds are accepted after chain and bond type filtering.
    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
      bond seen per atom and discard the remaining on each atom.

  Returns:
    bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms.
    Optional fields are filled in from `struct`, except when `struct.bonds` is
    empty; then an empty layout is returned directly.
  """
  if not struct.bonds:
    return atom_layout.AtomLayout(
        atom_name=np.empty((0, 2), dtype=object),
        res_id=np.empty((0, 2), dtype=int),
        res_name=np.empty((0, 2), dtype=object),
        chain_id=np.empty((0, 2), dtype=object),
        chain_type=np.empty((0, 2), dtype=object),
        atom_element=np.empty((0, 2), dtype=object),
    )
  from_atom_idxs, dest_atom_idxs = struct.bonds.get_atom_indices(
      struct.atom_key
  )
  from_atoms = _get_bond_atom_arrays(struct, from_atom_idxs)
  dest_atoms = _get_bond_atom_arrays(struct, dest_atom_idxs)

  # Chain type: (type1, type2) in either direction.
  chain_mask = np.logical_or(
      np.logical_and(
          np.isin(from_atoms.chain_type, allowed_chain_types1),
          np.isin(dest_atoms.chain_type, allowed_chain_types2),
      ),
      np.logical_and(
          np.isin(from_atoms.chain_type, allowed_chain_types2),
          np.isin(dest_atoms.chain_type, allowed_chain_types1),
      ),
  )
  if allowed_res_names:
    # Res type: at least one end must have an allowed residue name.
    res_mask = np.logical_or(
        np.isin(from_atoms.res_name, allowed_res_names),
        np.isin(dest_atoms.res_name, allowed_res_names),
    )
    all_mask = np.logical_and(chain_mask, res_mask)
  else:
    all_mask = chain_mask

  # Bond type mask.
  type_mask = np.isin(struct.bonds.type, list(include_bond_types))
  np.logical_and(all_mask, type_mask, out=all_mask)

  # Bond length check. Work in square length to avoid taking many square roots.
  # Reduce over the trailing xyz axis explicitly.
  bond_length_squared = np.square(from_atoms.coords - dest_atoms.coords).sum(
      axis=-1
  )
  bond_threshold_squared = bond_threshold * bond_threshold
  np.logical_and(
      all_mask, bond_length_squared < bond_threshold_squared, out=all_mask
  )

  # Inter-chain and inter-residue bonds for ligands.
  ligand_types = list(mmcif_names.LIGAND_CHAIN_TYPES)
  is_ligand = np.logical_or(
      np.isin(from_atoms.chain_type, ligand_types),
      np.isin(dest_atoms.chain_type, ligand_types),
  )
  res_id_differs = from_atoms.res_id != dest_atoms.res_id
  chain_id_differs = from_atoms.chain_id != dest_atoms.chain_id
  is_inter_res = np.logical_or(res_id_differs, chain_id_differs)
  is_inter_ligand_res = np.logical_and(is_inter_res, is_ligand)
  is_inter_chain_not_ligand = np.logical_and(chain_id_differs, ~is_ligand)
  # If ligand then inter-res & inter-chain bonds, otherwise inter-chain only.
  combined_allowed_bonds = np.logical_or(
      is_inter_chain_not_ligand, is_inter_ligand_res
  )
  np.logical_and(all_mask, combined_allowed_bonds, out=all_mask)

  bond_layout = atom_layout.AtomLayout(
      atom_name=np.stack(
          [
              from_atoms.atom_name[all_mask],
              dest_atoms.atom_name[all_mask],
          ],
          axis=1,
          dtype=object,
      ),
      res_id=np.stack(
          [from_atoms.res_id[all_mask], dest_atoms.res_id[all_mask]],
          axis=1,
          dtype=int,
      ),
      chain_id=np.stack(
          [
              from_atoms.chain_id[all_mask],
              dest_atoms.chain_id[all_mask],
          ],
          axis=1,
          dtype=object,
      ),
  )
  if not allow_multiple_bonds_per_atom:
    bond_layout = _remove_multi_bonds(bond_layout)
  return atom_layout.fill_in_optional_fields(
      bond_layout,
      reference_atoms=atom_layout.atom_layout_from_structure(struct),
  )
================================================
FILE: src/alphafold3/model/pipeline/pipeline.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""The main featurizer."""
import bisect
from collections.abc import Sequence
import datetime
import itertools
from absl import logging
from alphafold3.common import base_config
from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.model import feat_batch
from alphafold3.model import features
from alphafold3.model.pipeline import inter_chain_bonds
from alphafold3.model.pipeline import structure_cleaning
from alphafold3.structure import chemical_components as struc_chem_comps
import numpy as np
# Fixed random seed; presumably used to make frame selection during
# featurisation reproducible (its use is outside this view — confirm).
_DETERMINISTIC_FRAMES_RANDOM_SEED = 12_312_837
def calculate_bucket_size(
num_tokens: int, buckets: Sequence[int] | None
) -> int:
"""Calculates the bucket size to pad the data to."""
if buckets is None:
return num_tokens
if not buckets:
raise ValueError('Buckets must be non-empty.')
if not all(prev < curr for prev, curr in itertools.pairwise(buckets)):
raise ValueError(
f'Buckets must be in strictly increasing order. Got {buckets=}.'
)
bucket_idx = bisect.bisect_left(buckets, num_tokens)
if bucket_idx == len(buckets):
logging.warning(
'Creating a new bucket of size %d since the input has more tokens than'
' the largest bucket size %d. This may trigger a re-compilation of the'
' model. Consider additional large bucket sizes to avoid excessive'
' re-compilation.',
num_tokens,
buckets[-1],
)
return num_tokens
return buckets[bucket_idx]
class NanDataError(Exception):
  """Raised when the data pipeline output contains NaN values."""
class TotalNumResOutOfRangeError(Exception):
  """Raised when the residue count summed over all chains is out of range."""
class MmcifNumChainsError(Exception):
  """Raised when an mmCIF file has too many or too few chains."""
class WholePdbPipeline:
  """Processes an entire mmcif entity and merges the content.

  `process_item` converts a `folding_input.Input` into a structure, cleans
  and tokenises it, computes every feature group of a `feat_batch.Batch` and
  returns the batch as a dict of numpy arrays.
  """

  class Config(base_config.BaseConfig):
    """Configuration object for `WholePdbPipeline`.

    Properties:
      max_atoms_per_token: number of atom slots in one token (was called
        num_dense, and semi-hardcoded to 24 before)
      pad_num_chains: Size to pad NUM_CHAINS feature dimensions to, only for
        protein chains.
      buckets: Bucket sizes to pad the data to, to avoid excessive
        re-compilation of the model. If None, no bucketing is applied and the
        data is padded to exactly the number of tokens. If not None, must be a
        sequence of at least one integer, in strictly increasing order (a
        ValueError is raised otherwise). If the number of tokens is more than
        the largest bucket size, a warning is logged and the data is padded to
        exactly the number of tokens, which may trigger a re-compilation.
      max_total_residues: Any input with more total tokens will be rejected
        with `TotalNumResOutOfRangeError`. If None (or 0), then no limit is
        applied.
      min_total_residues: Any input with fewer total tokens will be rejected
        with `TotalNumResOutOfRangeError`. If None (or 0), then no limit is
        applied.
      msa_crop_size: Maximum size of MSA to take across all chains.
      max_template_date: Optional max template date to prevent data leakage in
        validation. NOTE(review): not read by `process_item` in this class.
      ref_max_modified_date: Optional maximum date that controls whether to
        allow use of model coordinates for a chemical component from the CCD if
        RDKit conformer generation fails and the component does not have ideal
        coordinates set. Only for components that have been released before this
        date the model coordinates can be used as a fallback.
      max_templates: The maximum number of templates to send through the
        network; set to 0 to switch off templates.
      filter_clashes: If true then will remove clashing chains.
      filter_crystal_aids: If true ligands in the crystal aid list are removed.
      max_paired_sequence_per_species: The maximum number of sequences per
        species that will be used for MSA pairing.
      drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
      average_num_atoms_per_token: Target average number of atoms per token to
        compute the padding size for flat atoms.
      atom_cross_att_queries_subset_size: queries subset size in atom cross
        attention; the flat-atom padding size is rounded up to a multiple of
        this value.
      atom_cross_att_keys_subset_size: keys subset size in atom cross attention
      flatten_non_standard_residues: Whether to expand non-standard polymer
        residues into flat-atom format.
      remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from
        symmetric polymer chains.
      deterministic_frames: Whether to use fixed-seed reference positions to
        construct deterministic frames.
      conformer_max_iterations: Optional value forwarded as
        `conformer_max_iterations` to `features.RefStructure.compute_features`
        for the sampled reference structure (the deterministic-frames reference
        structure always receives None). Presumably caps conformer-generation
        iterations; confirm in `features`.
      resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
        MSA. The default behaviour matches the method described in the AlphaFold
        3 paper. Set this to false if providing custom paired MSA using the
        unpaired MSA field to keep it exactly as is as deduplication against
        the paired MSA could break the manually crafted pairing between MSA
        sequences.
    """

    max_atoms_per_token: int = 24
    pad_num_chains: int = 1000
    buckets: list[int] | None = None
    max_total_residues: int | None = None
    min_total_residues: int | None = None
    msa_crop_size: int = 16384
    max_template_date: datetime.date | None = None
    ref_max_modified_date: datetime.date | None = None
    max_templates: int = 4
    filter_clashes: bool = False
    filter_crystal_aids: bool = False
    max_paired_sequence_per_species: int = 600
    drop_ligand_leaving_atoms: bool = True
    average_num_atoms_per_token: int = 24
    atom_cross_att_queries_subset_size: int = 32
    atom_cross_att_keys_subset_size: int = 128
    flatten_non_standard_residues: bool = True
    remove_nonsymmetric_bonds: bool = False
    deterministic_frames: bool = True
    conformer_max_iterations: int | None = None
    resolve_msa_overlaps: bool = True

  def __init__(self, *, config: Config):
    """Initializes WholePdb data pipeline.

    Args:
      config: Pipeline configuration.
    """
    self._config = config

  def process_item(
      self,
      fold_input: folding_input.Input,
      random_state: np.random.RandomState,
      ccd: chemical_components.Ccd,
      random_seed: int | None = None,
  ) -> features.BatchDict:
    """Featurises a single fold input into a padded batch dict.

    Args:
      fold_input: The input to featurise.
      random_state: Only used to draw `random_seed` when that is not given.
      ccd: The chemical components dictionary.
      random_seed: Optional seed. All randomness in this method comes from
        `np.random.RandomState(seed=random_seed)`.

    Returns:
      The output of `feat_batch.Batch.as_data_dict()`, without the
      `num_iter_recycling` entry.

    Raises:
      MmcifNumChainsError: If no chains remain after cleaning.
      TotalNumResOutOfRangeError: If the token count is above
        `max_total_residues` or below `min_total_residues`.
      NanDataError: If any numeric output feature contains NaN.
    """
    if random_seed is None:
      random_seed = random_state.randint(2**31)

    # Re-seed so that the whole featurisation is reproducible from the seed
    # that is included in `logging_name`.
    random_state = np.random.RandomState(seed=random_seed)

    logging_name = f'{fold_input.name}, random_seed={random_seed}'
    logging.info('processing %s', logging_name)
    struct = fold_input.to_structure(ccd=ccd)

    # Clean structure.
    cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure(
        struct,
        ccd=ccd,
        drop_non_standard_atoms=True,
        drop_missing_sequence=True,
        filter_clashes=self._config.filter_clashes,
        filter_crystal_aids=self._config.filter_crystal_aids,
        filter_waters=True,
        filter_hydrogens=True,
        filter_leaving_atoms=self._config.drop_ligand_leaving_atoms,
        only_glycan_ligands_for_leaving_atoms=True,
        covalent_bonds_only=True,
        remove_polymer_polymer_bonds=True,
        remove_bad_bonds=True,
        remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds,
    )

    num_clashing_chains_removed = cleaning_metadata[
        'num_clashing_chains_removed'
    ]

    if num_clashing_chains_removed:
      logging.info(
          'Removed %d clashing chains from %s',
          num_clashing_chains_removed,
          logging_name,
      )

    # No chains after fixes
    if cleaned_struc.num_chains == 0:
      raise MmcifNumChainsError(f'{logging_name}: No chains in structure!')

    polymer_ligand_bonds, ligand_ligand_bonds = (
        inter_chain_bonds.get_polymer_ligand_and_ligand_ligand_bonds(
            cleaned_struc,
            only_glycan_ligands=False,
            allow_multiple_bonds_per_atom=True,
        )
    )

    # If empty replace with None as this causes errors downstream.
    if ligand_ligand_bonds and not ligand_ligand_bonds.atom_name.size:
      ligand_ligand_bonds = None
    if polymer_ligand_bonds and not polymer_ligand_bonds.atom_name.size:
      polymer_ligand_bonds = None

    # Create the flat output AtomLayout
    empty_output_struc, flat_output_layout = (
        structure_cleaning.create_empty_output_struc_and_layout(
            struc=cleaned_struc,
            ccd=ccd,
            polymer_ligand_bonds=polymer_ligand_bonds,
            ligand_ligand_bonds=ligand_ligand_bonds,
            drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms,
        )
    )

    # Select the tokens for Evoformer.
    # Each token (e.g. a residue) is encoded as one representative atom. This
    # is flexible enough to allow the 1-token-per-atom ligand representation
    # in the future.
    all_tokens, all_token_atoms_layout, standard_token_idxs = (
        features.tokenizer(
            flat_output_layout,
            ccd=ccd,
            max_atoms_per_token=self._config.max_atoms_per_token,
            flatten_non_standard_residues=self._config.flatten_non_standard_residues,
            logging_name=logging_name,
        )
    )
    total_tokens = len(all_tokens.atom_name)
    # NOTE: the "residue" limits below are compared against the token count.
    # A limit of None (or 0) is falsy and therefore disables the check.
    if (
        self._config.max_total_residues
        and total_tokens > self._config.max_total_residues
    ):
      raise TotalNumResOutOfRangeError(
          'Total Number of Residues > max_total_residues: '
          f'({total_tokens} > {self._config.max_total_residues})'
      )

    if (
        self._config.min_total_residues
        and total_tokens < self._config.min_total_residues
    ):
      raise TotalNumResOutOfRangeError(
          'Total Number of Residues < min_total_residues: '
          f'({total_tokens} < {self._config.min_total_residues})'
      )

    logging.info(
        'Calculating bucket size for input with %d tokens.', total_tokens
    )
    padded_token_length = calculate_bucket_size(
        total_tokens, self._config.buckets
    )
    logging.info(
        'Got bucket size %d for input with %d tokens, resulting in %d padded'
        ' tokens.',
        padded_token_length,
        total_tokens,
        padded_token_length - total_tokens,
    )

    # Padding shapes for all features.
    num_atoms = padded_token_length * self._config.average_num_atoms_per_token
    # Round up to next multiple of subset size.
    num_atoms = int(
        np.ceil(num_atoms / self._config.atom_cross_att_queries_subset_size)
        * self._config.atom_cross_att_queries_subset_size
    )
    padding_shapes = features.PaddingShapes(
        num_tokens=padded_token_length,
        msa_size=self._config.msa_crop_size,
        num_chains=self._config.pad_num_chains,
        num_templates=self._config.max_templates,
        num_atoms=num_atoms,
    )

    # Create the atom layouts for flat atom cross attention
    batch_atom_cross_att = features.AtomCrossAtt.compute_features(
        all_token_atoms_layout=all_token_atoms_layout,
        queries_subset_size=self._config.atom_cross_att_queries_subset_size,
        keys_subset_size=self._config.atom_cross_att_keys_subset_size,
        padding_shapes=padding_shapes,
    )

    # Extract per-token features
    batch_token_features = features.TokenFeatures.compute_features(
        all_tokens=all_tokens,
        padding_shapes=padding_shapes,
    )

    # Create reference structure features
    chemical_components_data = struc_chem_comps.populate_missing_ccd_data(
        ccd=ccd,
        chemical_components_data=cleaned_struc.chemical_components_data,
        populate_pdbx_smiles=True,
    )

    # Add smiles info to empty_output_struc.
    empty_output_struc = empty_output_struc.copy_and_update_globals(
        chemical_components_data=chemical_components_data
    )
    # Create layouts and store structures for model output conversion.
    batch_convert_model_output = features.ConvertModelOutput.compute_features(
        all_token_atoms_layout=all_token_atoms_layout,
        padding_shapes=padding_shapes,
        cleaned_struc=cleaned_struc,
        flat_output_layout=flat_output_layout,
        empty_output_struc=empty_output_struc,
        polymer_ligand_bonds=polymer_ligand_bonds,
        ligand_ligand_bonds=ligand_ligand_bonds,
    )

    # Create the PredictedStructureInfo
    batch_predicted_structure_info = (
        features.PredictedStructureInfo.compute_features(
            all_tokens=all_tokens,
            all_token_atoms_layout=all_token_atoms_layout,
            padding_shapes=padding_shapes,
        )
    )

    # Create MSA features
    batch_msa = features.MSA.compute_features(
        all_tokens=all_tokens,
        standard_token_idxs=standard_token_idxs,
        padding_shapes=padding_shapes,
        fold_input=fold_input,
        logging_name=logging_name,
        max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
        resolve_msa_overlaps=self._config.resolve_msa_overlaps,
    )

    # Create template features
    batch_templates = features.Templates.compute_features(
        all_tokens=all_tokens,
        standard_token_idxs=standard_token_idxs,
        padding_shapes=padding_shapes,
        fold_input=fold_input,
        max_templates=self._config.max_templates,
        logging_name=logging_name,
    )

    ref_max_modified_date = self._config.ref_max_modified_date
    conformer_max_iterations = self._config.conformer_max_iterations
    # RefStructure also returns a ligand-ligand bond layout that replaces the
    # previous one for every feature computed after this point.
    batch_ref_structure, ligand_ligand_bonds = (
        features.RefStructure.compute_features(
            all_token_atoms_layout=all_token_atoms_layout,
            ccd=ccd,
            padding_shapes=padding_shapes,
            chemical_components_data=chemical_components_data,
            random_state=random_state,
            ref_max_modified_date=ref_max_modified_date,
            conformer_max_iterations=conformer_max_iterations,
            ligand_ligand_bonds=ligand_ligand_bonds,
        )
    )
    deterministic_ref_structure = None
    if self._config.deterministic_frames:
      # A second reference structure built from a fixed seed, so that frame
      # construction does not depend on `random_seed`. Its returned bond
      # layout is discarded.
      deterministic_ref_structure, _ = features.RefStructure.compute_features(
          all_token_atoms_layout=all_token_atoms_layout,
          ccd=ccd,
          padding_shapes=padding_shapes,
          chemical_components_data=chemical_components_data,
          random_state=(
              np.random.RandomState(_DETERMINISTIC_FRAMES_RANDOM_SEED)
          ),
          ref_max_modified_date=ref_max_modified_date,
          conformer_max_iterations=None,
          ligand_ligand_bonds=ligand_ligand_bonds,
      )

    # Create ligand-polymer bond features.
    polymer_ligand_bond_info = features.PolymerLigandBondInfo.compute_features(
        all_tokens=all_tokens,
        all_token_atoms_layout=all_token_atoms_layout,
        bond_layout=polymer_ligand_bonds,
        padding_shapes=padding_shapes,
    )

    # Create ligand-ligand bond features.
    ligand_ligand_bond_info = features.LigandLigandBondInfo.compute_features(
        all_tokens,
        ligand_ligand_bonds,
        padding_shapes,
    )

    # Create the Pseudo-beta layout for distogram head and distance error head.
    batch_pseudo_beta_info = features.PseudoBetaInfo.compute_features(
        all_token_atoms_layout=all_token_atoms_layout,
        ccd=ccd,
        padding_shapes=padding_shapes,
        logging_name=logging_name,
    )

    # Frame construction.
    batch_frames = features.Frames.compute_features(
        all_tokens=all_tokens,
        all_token_atoms_layout=all_token_atoms_layout,
        ref_structure=(
            deterministic_ref_structure
            if self._config.deterministic_frames
            else batch_ref_structure
        ),
        padding_shapes=padding_shapes,
    )

    # Assemble the Batch object.
    batch = feat_batch.Batch(
        msa=batch_msa,
        templates=batch_templates,
        token_features=batch_token_features,
        ref_structure=batch_ref_structure,
        predicted_structure_info=batch_predicted_structure_info,
        polymer_ligand_bond_info=polymer_ligand_bond_info,
        ligand_ligand_bond_info=ligand_ligand_bond_info,
        pseudo_beta_info=batch_pseudo_beta_info,
        atom_cross_att=batch_atom_cross_att,
        convert_model_output=batch_convert_model_output,
        frames=batch_frames,
    )

    np_example = batch.as_data_dict()
    if 'num_iter_recycling' in np_example:
      del np_example['num_iter_recycling']  # that does not belong here

    # NaN check. String and object arrays are skipped because isnan is not
    # defined for them. np.sum is a cheap reduction: any NaN element makes the
    # sum NaN. A sum of +inf and -inf is also NaN and would be reported too.
    for name, value in np_example.items():
      if (
          value.dtype.kind not in {'U', 'S'}
          and value.dtype.name != 'object'
          and np.isnan(np.sum(value))
      ):
        raise NanDataError(
            f'Data pipeline output for {logging_name=} contains NaNs. NaN'
            f' feature: {name}'
        )

    return np_example
================================================
FILE: src/alphafold3/model/pipeline/structure_cleaning.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Prepare PDB structure for training or inference."""
from typing import Any
from absl import logging
from alphafold3 import structure
from alphafold3.constants import chemical_component_sets
from alphafold3.constants import chemical_components
from alphafold3.constants import mmcif_names
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.pipeline import inter_chain_bonds
from alphafold3.model.scoring import covalent_bond_cleaning
from alphafold3.structure import sterics
import numpy as np
def _get_leaving_atom_mask(
    struc: structure.Structure,
    polymer_ligand_bonds: atom_layout.AtomLayout | None,
    ligand_ligand_bonds: atom_layout.AtomLayout | None,
    chain_id: str,
    chain_type: str,
    res_id: int,
    res_name: str,
) -> np.ndarray:
  """Returns a new per-atom boolean mask of leaving atoms for one residue.

  The mask is True for atoms of `struc` that belong to residue
  (`chain_id`, `res_id`) and whose name is one of the link drop atoms reported
  by `atom_layout.get_link_drop_atoms` for this residue and its bonded atoms.
  All other atoms are False (kept). The caller's mask is not modified.

  Args:
    struc: Structure whose atoms the mask indexes.
    polymer_ligand_bonds: Polymer-ligand bond layout, or None.
    ligand_ligand_bonds: Ligand-ligand bond layout, or None.
    chain_id: Chain of the residue.
    chain_type: Chain type of the residue.
    res_id: Residue id within the chain.
    res_name: Residue (CCD) name.

  Returns:
    Boolean array with one entry per atom of `struc`.
  """
  bonded_atoms = atom_layout.get_bonded_atoms(
      polymer_ligand_bonds, ligand_ligand_bonds, res_id, chain_id
  )
  # Connect the amino-acids, i.e. remove OXT, HXT and H2.
  leaving_names = atom_layout.get_link_drop_atoms(
      res_name=res_name,
      chain_type=chain_type,
      is_start_terminus=False,
      is_end_terminus=False,
      bonded_atoms=bonded_atoms,
      drop_ligand_leaving_atoms=True,
  )
  # `list(...)` because np.isin does not expand sets into their elements.
  has_leaving_name = np.isin(struc.atom_name, list(leaving_names))
  in_residue = np.logical_and(
      struc.chain_id == chain_id, struc.res_id == res_id
  )
  return np.logical_and(has_leaving_name, in_residue)
def clean_structure(
    struc: structure.Structure,
    ccd: chemical_components.Ccd,
    *,
    drop_missing_sequence: bool,
    filter_clashes: bool,
    drop_non_standard_atoms: bool,
    filter_crystal_aids: bool,
    filter_waters: bool,
    filter_hydrogens: bool,
    filter_leaving_atoms: bool,
    only_glycan_ligands_for_leaving_atoms: bool,
    covalent_bonds_only: bool,
    remove_polymer_polymer_bonds: bool,
    remove_bad_bonds: bool,
    remove_nonsymmetric_bonds: bool,
) -> tuple[structure.Structure, dict[str, Any]]:
  """Cleans structure.

  Steps run in this order: crystallization-aid removal, removal of chains with
  unknown sequence, clash filtering, non-standard atom removal, chain sorting,
  hydrogen removal, water removal, leaving-atom removal and finally the bond
  filters (covalent-only, polymer-polymer / bad bonds, nonsymmetric bonds).

  Args:
    struc: Structure to clean.
    ccd: The chemical components dictionary.
    drop_missing_sequence: Whether to drop chains without specified sequences.
    filter_clashes: Whether to drop clashing chains (only when the structure
      has more than one chain).
    drop_non_standard_atoms: Whether to drop non CCD standard atoms.
    filter_crystal_aids: Whether to drop ligands in the crystal aid set. Only
      applied to structures solved with a crystallization method.
    filter_waters: Whether to drop water chains.
    filter_hydrogens: Whether to drop hydrogen atoms.
    filter_leaving_atoms: Whether to drop leaving atoms based on heuristics.
    only_glycan_ligands_for_leaving_atoms: Whether to only include glycan
      ligands when filtering leaving atoms.
    covalent_bonds_only: Only keep bonds of type 'covale'.
    remove_polymer_polymer_bonds: Remove bonds whose both atoms are in polymer
      chains.
    remove_bad_bonds: Whether to remove bonds whose atoms are more than 2.4
      apart (coordinate units of the structure, presumably Angstrom).
    remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand
      bonds from symmetric polymer chains.

  Returns:
    Tuple of structure and metadata dict. The metadata dict has the keys
    'num_with_unk_sequence', 'num_clashing_chains_removed' and
    'chains_removed' describing what was cleaned from the original.
  """
  metadata = {}
  # Crop crystallization aids.
  if (
      filter_crystal_aids
      and struc.structure_method in mmcif_names.CRYSTALLIZATION_METHODS
  ):
    struc = struc.filter_out(
        res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS
    )

  # Drop chains without specified sequences.
  if drop_missing_sequence:
    chains_with_unk_sequence = struc.find_chains_with_unknown_sequence()
    num_with_unk_sequence = len(chains_with_unk_sequence)
    if chains_with_unk_sequence:
      struc = struc.filter_out(chain_id=chains_with_unk_sequence)
  else:
    num_with_unk_sequence = 0
  metadata['num_with_unk_sequence'] = num_with_unk_sequence

  # Remove intersecting chains.
  if filter_clashes and struc.num_chains > 1:
    clashing_chains = sterics.find_clashing_chains(struc)
    if clashing_chains:
      struc = struc.filter_out(chain_id=clashing_chains)
  else:
    clashing_chains = []
  metadata['num_clashing_chains_removed'] = len(clashing_chains)
  metadata['chains_removed'] = clashing_chains

  # Drop non-standard atoms
  if drop_non_standard_atoms:
    struc = struc.drop_non_standard_atoms(
        ccd=ccd, drop_unk=False, drop_non_ccd=False
    )

  # Sort chains in "reverse-spreadsheet" order.
  struc = struc.with_sorted_chains

  if filter_hydrogens:
    struc = struc.without_hydrogen()

  if filter_waters:
    struc = struc.filter_out(chain_type=mmcif_names.WATER)

  if filter_leaving_atoms:
    # Per-atom drop mask, initially all False (everything kept).
    drop_leaving_atoms_all = struc.chain_id != struc.chain_id
    polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds(
        struc,
        only_glycan_ligands=only_glycan_ligands_for_leaving_atoms,
    )
    ligand_ligand_bonds = inter_chain_bonds.get_ligand_ligand_bonds(
        struc,
        only_glycan_ligands=only_glycan_ligands_for_leaving_atoms,
    )
    all_glycans = {
        *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
        *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
    }
    # If only glycan ligands and no O1 atoms, we can do parallel drop.
    if (
        only_glycan_ligands_for_leaving_atoms
        and (not (ligand_ligand_bonds.atom_name == 'O1').any())
        and (not (polymer_ligand_bonds.atom_name == 'O1').any())
    ):
      drop_leaving_atoms_all = np.logical_and(
          np.isin(struc.atom_name, 'O1'),
          np.isin(struc.res_name, list(all_glycans)),
      )
    else:
      substruct = struc.group_by_residue
      glycan_mask = np.isin(substruct.res_name, list(all_glycans))
      substruct = substruct.filter(glycan_mask)
      # We need to iterate over all glycan residues for this.
      for res in substruct.iter_residues():
        # Only need to do drop leaving atoms for glycans depending on bonds.
        if (res_name := res['res_name']) in all_glycans:
          drop_atom_filter = _get_leaving_atom_mask(
              struc=struc,
              polymer_ligand_bonds=polymer_ligand_bonds,
              ligand_ligand_bonds=ligand_ligand_bonds,
              chain_id=res['chain_id'],
              chain_type=res['chain_type'],
              res_id=res['res_id'],
              res_name=res_name,
          )
          drop_leaving_atoms_all = np.logical_or(
              drop_leaving_atoms_all, drop_atom_filter
          )

    num_atoms_before = struc.num_atoms
    # Record which atoms are dropped *before* filtering: after `filter_out`
    # only the remaining atoms are available, so logging `struc.chain_id`
    # etc. afterwards would report the whole remaining structure instead.
    dropped_chain_ids = struc.chain_id[drop_leaving_atoms_all]
    dropped_res_ids = struc.res_id[drop_leaving_atoms_all]
    dropped_res_names = struc.res_name[drop_leaving_atoms_all]
    struc = struc.filter_out(drop_leaving_atoms_all)
    num_atoms_after = struc.num_atoms

    if num_atoms_before > num_atoms_after:
      logging.error(
          'Dropped %s atoms from GT struc: chain_id %s res_id %s res_name %s',
          num_atoms_before - num_atoms_after,
          dropped_chain_ids,
          dropped_res_ids,
          dropped_res_names,
      )

  # Can filter by bond type without having to iterate over bonds.
  if struc.bonds and covalent_bonds_only:
    is_covalent = np.isin(struc.bonds.type, ['covale'])
    if np.any(is_covalent):
      new_bonds = struc.bonds[is_covalent]
    else:
      new_bonds = structure.Bonds.make_empty()
    struc = struc.copy_and_update(bonds=new_bonds)

  # Other bond filters require iterating over individual bonds.
  if struc.bonds and (remove_bad_bonds or remove_polymer_polymer_bonds):
    include_bond = []
    num_pp_bonds = 0
    num_bad_bonds = 0
    # Bonds longer than 2.4 are considered bad; compare squared distances.
    squared_threshold = 2.4 * 2.4
    for bond in struc.iter_bonds():
      dest_atom = bond.dest_atom
      from_atom = bond.from_atom
      if remove_polymer_polymer_bonds:
        if (
            from_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES
            and dest_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES
        ):
          num_pp_bonds += 1
          include_bond.append(False)
          continue
      if remove_bad_bonds:
        dest_coords = np.array(
            [dest_atom['atom_x'], dest_atom['atom_y'], dest_atom['atom_z']]
        )
        from_coords = np.array(
            [from_atom['atom_x'], from_atom['atom_y'], from_atom['atom_z']]
        )
        squared_dist = np.sum(np.square(dest_coords - from_coords))
        if squared_dist > squared_threshold:
          num_bad_bonds += 1
          include_bond.append(False)
          continue
      include_bond.append(True)

    num_kept_bonds = sum(include_bond)
    if num_kept_bonds < len(struc.bonds):
      logging.info(
          'Reducing number of bonds for %s from %s to %s, of which %s are'
          ' polymer-polymer bonds and %s are bad bonds.',
          struc.name,
          len(struc.bonds),
          num_kept_bonds,
          num_pp_bonds,
          num_bad_bonds,
      )
      if num_kept_bonds > 0:
        # Need to index bonds with bond keys or arrays of bools with same length
        # as num bonds. In this case, we use array of bools (as elsewhere in the
        # cleaning code).
        new_bonds = struc.bonds[np.array(include_bond, dtype=bool)]
      else:
        new_bonds = structure.Bonds.make_empty()
      struc = struc.copy_and_update(bonds=new_bonds)

  if struc.bonds and remove_nonsymmetric_bonds:
    # Check for asymmetric polymer-ligand bonds and remove if these exist.
    polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds(
        struc,
        only_glycan_ligands=False,
    )
    if polymer_ligand_bonds:
      if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains(
          struc, polymer_ligand_bonds
      ):
        # Drop every bond that touches a polymer atom on either side.
        from_atom_idxs, dest_atom_idxs = struc.bonds.get_atom_indices(
            struc.atom_key
        )
        poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES)
        is_polymer_bond = np.logical_or(
            np.isin(struc.chain_type[from_atom_idxs], poly_chain_types),
            np.isin(struc.chain_type[dest_atom_idxs], poly_chain_types),
        )
        struc = struc.copy_and_update(bonds=struc.bonds[~is_polymer_bond])

  return struc, metadata
def create_empty_output_struc_and_layout(
    struc: structure.Structure,
    ccd: chemical_components.Ccd,
    *,
    with_hydrogens: bool = False,
    skip_unk: bool = False,
    polymer_ligand_bonds: atom_layout.AtomLayout | None = None,
    ligand_ligand_bonds: atom_layout.AtomLayout | None = None,
    drop_ligand_leaving_atoms: bool = False,
) -> tuple[structure.Structure, atom_layout.AtomLayout]:
  """Builds a zero-coordinate copy of `struc` plus its flat atom layout.

  All physical residues (including missing ones) are laid out with
  `atom_layout.make_flat_atom_layout`. A structure with every coordinate set to
  0 is created from that layout, and the bonds from both bond layouts are added
  to it as covalent bonds.

  Args:
    struc: Structure object.
    ccd: The chemical components dictionary.
    with_hydrogens: Whether to keep hydrogen atoms in structure.
    skip_unk: Whether to remove unknown residues from structure.
    polymer_ligand_bonds: Bond information for polymer-ligand pairs.
    ligand_ligand_bonds: Bond information for ligand-ligand pairs.
    drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.

  Returns:
    (empty output structure, flat atom layout of that structure).
  """
  # ((chain, res_id, atom), (chain, res_id, atom)) pairs; polymer-ligand bonds
  # first, then ligand-ligand bonds.
  bonded_atom_pairs = []
  for bond_layout in (polymer_ligand_bonds, ligand_ligand_bonds):
    if not bond_layout:
      continue
    bonded_atom_pairs.extend(
        (
            (chain_ids[0], res_ids[0], atom_names[0]),
            (chain_ids[1], res_ids[1], atom_names[1]),
        )
        for chain_ids, res_ids, atom_names in zip(
            bond_layout.chain_id,
            bond_layout.res_id,
            bond_layout.atom_name,
            strict=True,
        )
    )

  residues = atom_layout.residues_from_structure(
      struc, include_missing_residues=True
  )
  flat_output_layout = atom_layout.make_flat_atom_layout(
      residues,
      ccd=ccd,
      with_hydrogens=with_hydrogens,
      skip_unk_residues=skip_unk,
      polymer_ligand_bonds=polymer_ligand_bonds,
      ligand_ligand_bonds=ligand_ligand_bonds,
      drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
  )
  zero_coords = np.zeros((flat_output_layout.shape[0], 3))
  empty_output_struc = atom_layout.make_structure(
      flat_layout=flat_output_layout,
      atom_coords=zero_coords,
      name=struc.name,
      atom_b_factors=None,
      all_physical_residues=residues,
  )
  if not bonded_atom_pairs:
    return empty_output_struc, flat_output_layout
  empty_output_struc = empty_output_struc.add_bonds(
      bonded_atom_pairs, bond_type=mmcif_names.COVALENT_BOND
  )
  return empty_output_struc, flat_output_layout
================================================
FILE: src/alphafold3/model/post_processing.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Post-processing utilities for AlphaFold inference results."""
import dataclasses
import datetime
import os
from alphafold3 import version
from alphafold3.model import confidence_types
from alphafold3.model import mmcif_metadata
from alphafold3.model import model
import numpy as np
import zstandard
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ProcessedInferenceResult:
"""Stores attributes of a processed inference result.
Attributes:
cif: CIF file containing an inference result.
mean_confidence_1d: Mean 1D confidence calculated from confidence_1d.
ranking_score: Ranking score extracted from CIF metadata.
structure_confidence_summary_json: Content of JSON file with structure
confidences summary calculated from CIF file.
structure_full_data_json: Content of JSON file with structure full
confidences calculated from CIF file.
model_id: Identifier of the model that produced the inference result.
"""
cif: bytes
mean_confidence_1d: float
ranking_score: float
structure_confidence_summary_json: bytes
structure_full_data_json: bytes
model_id: bytes
def post_process_inference_result(
    inference_result: model.InferenceResult,
) -> ProcessedInferenceResult:
  """Converts an inference result into serialised output artefacts.

  Args:
    inference_result: The model inference result to process.

  Returns:
    A `ProcessedInferenceResult` holding:
      * the predicted structure as UTF-8 encoded mmCIF, with metadata (version,
        timestamp, model id) and the legal comment added;
      * the mean per-atom confidence as a Python float;
      * the ranking score from `inference_result.metadata['ranking_score']`;
      * the summary and full confidence JSONs, UTF-8 encoded;
      * the model id.
  """
  # Add mmCIF metadata fields.
  timestamp = datetime.datetime.now().isoformat(sep=' ', timespec='seconds')
  cif_with_metadata = mmcif_metadata.add_metadata_to_mmcif(
      old_cif=inference_result.predicted_structure.to_mmcif_dict(),
      version=f'{version.__version__} @ {timestamp}',
      model_id=inference_result.model_id,
  )
  cif = mmcif_metadata.add_legal_comment(cif_with_metadata.to_string())
  cif = cif.encode('utf-8')
  confidence_1d = confidence_types.AtomConfidence.from_inference_result(
      inference_result
  )
  # Cast to a plain Python float to match the dataclass annotation and the
  # `float(...)` applied to `ranking_score` below.
  mean_confidence_1d = float(np.mean(confidence_1d.confidence))
  structure_confidence_summary_json = (
      confidence_types.StructureConfidenceSummary.from_inference_result(
          inference_result
      )
      .to_json()
      .encode('utf-8')
  )
  structure_full_data_json = (
      confidence_types.StructureConfidenceFull.from_inference_result(
          inference_result
      )
      .to_json()
      .encode('utf-8')
  )
  return ProcessedInferenceResult(
      cif=cif,
      mean_confidence_1d=mean_confidence_1d,
      ranking_score=float(inference_result.metadata['ranking_score']),
      structure_confidence_summary_json=structure_confidence_summary_json,
      structure_full_data_json=structure_full_data_json,
      model_id=inference_result.model_id,
  )
def write_output(
    inference_result: model.InferenceResult,
    output_dir: os.PathLike[str] | str,
    terms_of_use: str | None = None,
    name: str | None = None,
    compress: bool = False,
) -> None:
  """Post-processes `inference_result` and writes its artefacts to disk.

  Files written to `output_dir`, each prefixed with `{name}_` when `name` is
  given:
    * `model.cif` and `confidences.json`, zstd-compressed with an extra `.zst`
      suffix when `compress` is True;
    * `summary_confidences.json`, always written uncompressed;
    * `TERMS_OF_USE.md` (never prefixed), only when `terms_of_use` is given.
  """
  processed_result = post_process_inference_result(inference_result)
  prefix = '' if name is None else f'{name}_'

  def _write_maybe_compressed(path: str, payload: bytes) -> None:
    if compress:
      with zstandard.open(f'{path}.zst', 'wb') as f:
        f.write(payload)
    else:
      with open(path, 'wb') as f:
        f.write(payload)

  _write_maybe_compressed(
      os.path.join(output_dir, f'{prefix}model.cif'), processed_result.cif
  )
  _write_maybe_compressed(
      os.path.join(output_dir, f'{prefix}confidences.json'),
      processed_result.structure_full_data_json,
  )

  # The summary is written with plain `open` regardless of `compress`.
  summary_path = os.path.join(output_dir, f'{prefix}summary_confidences.json')
  with open(summary_path, 'wb') as f:
    f.write(processed_result.structure_confidence_summary_json)

  if terms_of_use is None:
    return
  with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'wt') as f:
    f.write(terms_of_use)
def write_embeddings(
embeddings: dict[str, np.ndarray],
output_dir: os.PathLike[str] | str,
name: str | None = None,
) -> None:
"""Writes embeddings to a directory."""
prefix = f'{name}_' if name is not None else ''
with open(os.path.join(output_dir, f'{prefix}embeddings.npz'), 'wb') as f:
np.savez_compressed(f, **embeddings)
================================================
FILE: src/alphafold3/model/protein_data_processing.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Process Structure Data."""
from alphafold3.constants import atom_types
from alphafold3.constants import residue_names
from alphafold3.constants import side_chains
import numpy as np
# Number of dense atom slots per residue.
NUM_DENSE = atom_types.DENSE_ATOM_NUM
# Number of entries in `residue_names.PROTEIN_TYPES`.
NUM_AA = len(residue_names.PROTEIN_TYPES)
# Protein one-letter types plus the unknown and gap entries.
NUM_AA_WITH_UNK_AND_GAP = len(
    residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP
)
# Size of the residue-type axis of the lookup tables built below.
NUM_RESTYPES_WITH_UNK_AND_GAP = (
    residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP
)
def _make_restype_rigidgroup_dense_atom_idx():
  """Builds the (restype, rigid group, 3) table of dense-atom indices.

  Protein residue types fill group 0 (backbone: C, CA, N), group 3 (psi: CA,
  C, O) and, for each chi angle present according to CHI_ANGLES_MASK, groups
  4-7 (the last three atoms of that chi angle). Nucleic residue types, which
  start at row NUM_AA_WITH_UNK_AND_GAP, only fill group 0 with the dense-atom
  indices of C1', C3', C4' taken from the 'A' layout. Everything else is 0.
  """
  # shape (num_restypes, num_rigidgroups, 3_atoms), e.g. (31, 8, 3).
  table = np.zeros((NUM_RESTYPES_WITH_UNK_AND_GAP, 8, 3), dtype=np.int32)

  for restype, one_letter in enumerate(residue_names.PROTEIN_TYPES_ONE_LETTER):
    three_letter = residue_names.PROTEIN_COMMON_ONE_TO_THREE[one_letter]
    slot_names = atom_types.ATOM14[three_letter]
    table[restype, 0, :] = [slot_names.index(a) for a in ['C', 'CA', 'N']]
    table[restype, 3, :] = [slot_names.index(a) for a in ['CA', 'C', 'O']]
    for chi in range(4):
      if not side_chains.CHI_ANGLES_MASK[restype][chi]:
        continue
      chi_atoms = side_chains.CHI_ANGLES_ATOMS[three_letter][chi]
      table[restype, chi + 4, :] = [
          slot_names.index(a) for a in chi_atoms[1:]
      ]

  adenine_slots = atom_types.DENSE_ATOM['A']
  nucleic_backbone = [adenine_slots.index(a) for a in ["C1'", "C3'", "C4'"]]
  # Nucleic rows come after the protein types plus unknown and gap.
  for row, _ in enumerate(
      residue_names.NUCLEIC_TYPES, start=NUM_AA_WITH_UNK_AND_GAP
  ):
    table[row, 0, :] = nucleic_backbone
  return table


RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX = _make_restype_rigidgroup_dense_atom_idx()
def _make_restype_pseudobeta_idx():
  """Returns, per residue type, the dense-atom index of its pseudo-beta atom.

  Protein types use CB (CA for GLY). Nucleic types use C4 for 'A', 'G', 'DA'
  and 'DG' and C2 otherwise. All other rows are 0.
  """
  pseudobeta = np.zeros((NUM_RESTYPES_WITH_UNK_AND_GAP,), dtype=np.int32)

  for restype, one_letter in enumerate(residue_names.PROTEIN_TYPES_ONE_LETTER):
    three_letter = residue_names.PROTEIN_COMMON_ONE_TO_THREE[one_letter]
    slot_names = list(atom_types.ATOM14[three_letter])
    anchor = 'CA' if three_letter == 'GLY' else 'CB'
    pseudobeta[restype] = slot_names.index(anchor)

  # Nucleic rows come after the protein types plus unknown and gap.
  for restype, resname in enumerate(
      residue_names.NUCLEIC_TYPES, start=NUM_AA_WITH_UNK_AND_GAP
  ):
    slot_names = list(atom_types.DENSE_ATOM[resname])
    anchor = 'C4' if resname in {'A', 'G', 'DA', 'DG'} else 'C2'
    pseudobeta[restype] = slot_names.index(anchor)
  return pseudobeta


RESTYPE_PSEUDOBETA_INDEX = _make_restype_pseudobeta_idx()
def _make_aatype_dense_atom_to_atom37():
  """Builds the (restype, dense_atom) -> atom37 index lookup table.

  Each protein residue type maps its padded ATOM14 atom names to ATOM37
  indices, and empty padding names map to 0. After the protein rows come
  all-zero dummy rows: two (for UNK and the gap token) plus one per entry of
  NUCLEIC_TYPES_WITH_UNKNOWN.

  Returns:
    int32 array with one row of NUM_DENSE entries per residue type.
  """
  rows = []
  for one_letter in residue_names.PROTEIN_TYPES_ONE_LETTER:
    three_letter = residue_names.PROTEIN_COMMON_ONE_TO_THREE[one_letter]
    padded_names = list(atom_types.ATOM14_PADDED[three_letter])
    padded_names += [''] * (NUM_DENSE - len(padded_names))
    rows.append(
        [atom_types.ATOM37_ORDER[name] if name else 0 for name in padded_names]
    )

  num_dummy_rows = 2 + len(residue_names.NUCLEIC_TYPES_WITH_UNKNOWN)
  rows.extend([0] * NUM_DENSE for _ in range(num_dummy_rows))

  return np.array(rows, dtype=np.int32)


PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37 = _make_aatype_dense_atom_to_atom37()
================================================
FILE: src/alphafold3/model/scoring/alignment.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Alignment based metrics."""
import numpy as np
def transform_ls(
    x: np.ndarray,
    b: np.ndarray,
    *,
    allow_reflection: bool = False,
) -> np.ndarray:
  """Returns the least-squares optimal orthogonal map taking x onto b.

  Solves A x^T ~= b^T for an orthogonal A, using the SVD of the scaled
  cross-covariance matrix b^T x / N. No translation is estimated, so callers
  should centre both point sets first.

  Args:
    x: NxD array of source coordinates. Usually D is 3.
    b: NxD array of target coordinates. Usually D is 3.
    allow_reflection: If False (default), the sign of the last singular
      direction is flipped when needed so the result has determinant +1. If
      True, the unconstrained orthogonal solution is returned, which may
      include a reflection.

  Returns:
    DxD matrix A such that A x^T best approximates b^T.
  """
  assert x.shape[1] >= b.shape[1]
  assert b.shape[0] == x.shape[0], '%d, %d' % (b.shape[0], x.shape[0])
  # Post-multiplying A x^T = b^T by x gives A x^T x = b^T x.
  cross_cov = b.T @ x / b.shape[0]
  left, _, right_t = np.linalg.svd(cross_cov)
  rotation = left @ right_t
  if allow_reflection:
    return rotation
  dim = right_t.shape[1]
  signs = np.ones((dim, 1))
  signs[dim - 1, 0] = np.sign(np.linalg.det(rotation))
  return left @ (right_t * signs)
def align(
    *,
    x: np.ndarray,
    y: np.ndarray,
    x_indices: np.ndarray,
    y_indices: np.ndarray,
) -> np.ndarray:
  """Rigidly superimposes x onto y, fitting only the selected points.

  Args:
    x: NxD array of coordinates to be moved.
    y: Array of reference coordinates with the same dimension D as `x`.
    x_indices: Indices into `x` of the points used for the fit. Must have the
      same length as `y_indices`; x[x_indices[i]] is paired with
      y[y_indices[i]].
    y_indices: Indices into `y` of the points used for the fit. Must have the
      same length as `x_indices`.

  Returns:
    NxD array: all points of `x` after applying the rigid transformation
    (rotation plus translation) that optimally superimposes x[x_indices] onto
    y[y_indices].

  Raises:
    ValueError: If `x_indices` and `y_indices` have different lengths.
  """
  n_x, n_y = len(x_indices), len(y_indices)
  if n_x != n_y:
    raise ValueError(
        'Number of included indices must be the same for both input arrays,'
        f' but got for x: {n_x}, and for y: {n_y}.'
    )
  x_centroid = x[x_indices, :].mean(axis=0)
  y_centroid = y[y_indices, :].mean(axis=0)
  x_shifted = x - x_centroid
  y_shifted = y - y_centroid
  rotation = transform_ls(x_shifted[x_indices, :], y_shifted[y_indices, :])
  return x_shifted @ rotation.T + y_centroid
def deviations_from_coords(
decoy_coords: np.ndarray,
gt_coords: np.ndarray,
align_idxs: np.ndarray | None = None,
include_idxs: np.ndarray | None = None,
) -> np.ndarray:
"""Returns the raw per-atom deviations used in RMSD computation."""
if decoy_coords.shape != gt_coords.shape:
raise ValueError(
'decoy_coords.shape and gt_coords.shape must match.Found: %s and %s.'
% (decoy_coords.shape, gt_coords.shape)
)
# Include and align all residues unless specified otherwise.
if include_idxs is None:
include_idxs = np.arange(decoy_coords.shape[0])
if align_idxs is None:
align_idxs = include_idxs
aligned_decoy_coords = align(
x=decoy_coords,
y=gt_coords,
x_indices=align_idxs,
y_indices=align_idxs,
)
deviations = np.linalg.norm(
aligned_decoy_coords[include_idxs] - gt_coords[include_idxs], axis=1
)
return deviations
def rmsd_from_coords(
    decoy_coords: np.ndarray,
    gt_coords: np.ndarray,
    align_idxs: np.ndarray | None = None,
    include_idxs: np.ndarray | None = None,
) -> float:
  """Computes the *aligned* RMSD between two [M, 3] coordinate arrays.

  Args:
    decoy_coords: [M, 3] array of decoy atom coordinates.
    gt_coords: [M, 3] array of ground-truth atom coordinates.
    align_idxs: Indices of the coordinates used for the alignment. If None,
      the `include_idxs` selection is used.
    include_idxs: Indices of the coordinates that are scored. If None, all
      coordinates are scored.

  Returns:
    Root-mean-square of the per-atom deviations after alignment.
  """
  per_atom = deviations_from_coords(
      decoy_coords, gt_coords, align_idxs, include_idxs
  )
  mean_square = np.mean(np.square(per_atom))
  return np.sqrt(mean_square)
================================================
FILE: src/alphafold3/model/scoring/chirality.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Chirality detection and comparison."""
from collections.abc import Mapping
from absl import logging
from alphafold3 import structure
from alphafold3.constants import chemical_components
from alphafold3.data.tools import rdkit_utils
import rdkit.Chem as rd_chem
# Element symbols whose atoms may be reported as chiral centres by
# _find_chiral_centres; compared case-insensitively against the leading
# characters of each atom's 'atom_name' property.
_CHIRAL_ELEMENTS = frozenset({'C', 'S'})
def _find_chiral_centres(mol: rd_chem.Mol) -> dict[str, str]:
  """Detects the chiral centres of `mol` together with their chirality labels.

  RDKit is queried with force=True, includeUnassigned=False and the legacy
  implementation. The result is then restricted to atoms whose 'atom_name'
  starts (case-insensitively) with one of the symbols in _CHIRAL_ELEMENTS.

  Args:
    mol: The molecule to inspect; every atom must carry an 'atom_name' prop.

  Returns:
    Map from chiral-centre atom name to the chirality label RDKit reports.
  """
  centres = rd_chem.FindMolChiralCenters(
      mol, force=True, includeUnassigned=False, useLegacyImplementation=True
  )
  name_of = {a.GetIdx(): a.GetProp('atom_name') for a in mol.GetAtoms()}
  chirality_by_name = {name_of[idx]: label for idx, label in centres}

  def _starts_with_chiral_element(atom_name: str) -> bool:
    return any(
        atom_name[: len(element)].upper() == element
        for element in _CHIRAL_ELEMENTS
    )

  return {
      atom_name: label
      for atom_name, label in chirality_by_name.items()
      if _starts_with_chiral_element(atom_name)
  }
def _chiral_match(mol1: rd_chem.Mol, mol2: rd_chem.Mol) -> bool:
  """Returns whether every chiral centre of mol1 agrees with mol2.

  mol1 may describe a subset of mol2's atoms. A mol1 centre that mol2 labels
  '?' is not compared.

  Args:
    mol1: Molecule under test.
    mol2: Reference molecule.

  Returns:
    False if mol1 has a chiral centre that mol2 lacks, or if any compared
    centre has a different label; True otherwise.

  Raises:
    ValueError: If mol1's atom names are not a subset of mol2's.
  """
  names1 = {a.GetProp('atom_name') for a in mol1.GetAtoms()}
  names2 = {a.GetProp('atom_name') for a in mol2.GetAtoms()}
  if not names1 <= names2:
    raise ValueError('Mol1 atoms are not a subset of mol2 atoms.')

  centres1 = _find_chiral_centres(mol1)
  centres2 = _find_chiral_centres(mol2)
  if not set(centres1) <= set(centres2):
    return False

  return all(
      label1 == centres2[atom_name]
      for atom_name, label1 in centres1.items()
      if centres2[atom_name] != '?'
  )
def _mol_from_ligand_struc(
    ligand_struc: structure.Structure,
    ref_mol: rd_chem.Mol,
) -> rd_chem.Mol | None:
  """Creates a Mol object from a ligand structure and reference mol.

  The reference mol is sanitized with hydrogens removed. Reference atoms whose
  'atom_name' does not occur in `ligand_struc` are dropped. A single conformer
  is then built from the ligand coordinates (matched by atom name), and
  stereochemistry is assigned from that 3D conformer.

  Args:
    ligand_struc: Structure containing exactly one residue.
    ref_mol: Reference molecule whose atoms carry an 'atom_name' property.

  Returns:
    The new Mol, or None if RDKit's 3D stereo assignment raises a
    'Cannot normalize a zero length vector' RuntimeError.

  Raises:
    ValueError: If `ligand_struc` has more than one residue.
  """
  if ligand_struc.num_residues(count_unresolved=True) > 1:
    # The placeholder was previously never filled in.
    raise ValueError(
        'ligand_struc %s has more than one residue.' % ligand_struc.name
    )
  coords_by_atom_name = dict(zip(ligand_struc.atom_name, ligand_struc.coords))
  ref_mol = rdkit_utils.sanitize_mol(
      ref_mol,
      sort_alphabetically=False,
      remove_hydrogens=True,
  )
  mol = rd_chem.Mol(ref_mol)
  mol.RemoveAllConformers()
  # Reference atoms with no same-named atom in the ligand are removed.
  atom_indices_to_remove = [
      a.GetIdx()
      for a in mol.GetAtoms()
      if a.GetProp('atom_name') not in coords_by_atom_name
  ]
  editable_mol = rd_chem.EditableMol(mol)
  # Remove indices from the largest to smallest, to avoid invalidating.
  for atom_idx in atom_indices_to_remove[::-1]:
    editable_mol.RemoveAtom(atom_idx)
  mol = editable_mol.GetMol()
  conformer = rd_chem.Conformer(mol.GetNumAtoms())
  for atom_idx, atom in enumerate(mol.GetAtoms()):
    atom_name = atom.GetProp('atom_name')
    coords = coords_by_atom_name[atom_name]
    conformer.SetAtomPosition(atom_idx, coords.tolist())
  mol.AddConformer(conformer)
  try:
    rd_chem.AssignStereochemistryFrom3D(mol)
  except RuntimeError as e:
    # Catch only this specific rdkit error.
    if 'Cannot normalize a zero length vector' in str(e):
      return None
    else:
      raise
  return mol
def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None:
  """Returns a sanitized, hydrogen-free Mol for `res_name` built from the CCD.

  Args:
    res_name: Residue name to look up in the CCD.

  Returns:
    The Mol, or None (after logging a warning) when the CCD has no entry for
    `res_name` or RDKit raises MolFromMmcifError while parsing it.

  Raises:
    ValueError: If parsing returns None without raising.
  """
  ccd_cif = chemical_components.Ccd().get(res_name)
  if not ccd_cif:
    logging.warning('No ccd information for residue %s.', res_name)
    return None

  try:
    parsed = rdkit_utils.mol_from_ccd_cif(ccd_cif, force_parse=False)
  except rdkit_utils.MolFromMmcifError as e:
    logging.warning('Failed to create mol from ccd for %s: %s', res_name, e)
    return None

  if parsed is None:
    raise ValueError('Failed to create mol from ccd for %s.' % res_name)

  return rdkit_utils.sanitize_mol(
      parsed, sort_alphabetically=False, remove_hydrogens=True
  )
def compare_chirality(
    test_struc: structure.Structure,
    ref_mol_by_chain: Mapping[str, rd_chem.Mol] | None = None,
) -> dict[str, bool]:
  """Compares chirality of ligands in a structure with reference molecules.

  Ligand atoms do not have to match exactly. The ligand atoms and chiral
  centres only need to be a subset of those in the reference mol.

  Args:
    test_struc: The structure whose ligands' chirality is checked.
    ref_mol_by_chain: Optional mapping from chain ID to a reference Mol for
      that chain. A chain without an entry (or every chain, if this is not
      given) is compared against the CCD entry for its residue name, when the
      CCD has one.

  Returns:
    Dictionary mapping chain ID to True if the ligand's chirality matches its
    reference mol and False otherwise. Only single-residue ligand chains for
    which both a reference mol and a test mol could be built appear.
  """
  refs = ref_mol_by_chain or {}
  ligands = test_struc.filter_to_entity_type(ligand=True)
  struc_name = ligands.name
  matches = {}
  for chain_id in ligands.chains:
    chain = ligands.filter(chain_id=chain_id)

    # Only compare single-residue ligands.
    if chain.num_residues(count_unresolved=True) > 1:
      logging.warning(
          '%s: Chain %s has >1 residues. Skipping.', struc_name, chain_id
      )
      continue

    if chain_id in refs:
      ref_mol = refs[chain_id]
    else:
      ref_mol = _maybe_mol_from_ccd(chain.res_name[0])
    if ref_mol is None:
      logging.warning(
          '%s: Ref mol is None for chain %s. Skipping.', struc_name, chain_id
      )
      continue

    test_mol = _mol_from_ligand_struc(ligand_struc=chain, ref_mol=ref_mol)
    if test_mol is None:
      logging.warning(
          '%s: Failed to create mol for chain %s. Skipping.',
          struc_name,
          chain_id,
      )
      continue

    matches[chain_id] = _chiral_match(test_mol, ref_mol)
  return matches
================================================
FILE: src/alphafold3/model/scoring/covalent_bond_cleaning.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Some methods to compute metrics for PTMs."""
import collections
from collections.abc import Mapping
import dataclasses
from alphafold3 import structure
from alphafold3.constants import mmcif_names
from alphafold3.model.atom_layout import atom_layout
import numpy as np
@dataclasses.dataclass(frozen=True)
class ResIdMapping:
  """Res_ids of a ligand before and after it is merged into a polymer chain.

  Built by _combine_polymer_ligand_ptm_chains. There, new_res_ids are the
  ligand's res_ids shifted by len(struc.all_residues[polymer_chain_id]).
  """

  # Ligand res_ids as they were in the separate ligand chain.
  old_res_ids: np.ndarray
  # Matching res_ids inside the combined polymer-ligand chain.
  new_res_ids: np.ndarray
def _count_symmetric_chains(struc: structure.Structure) -> Mapping[str, int]:
  """Maps each chain ID to the number of chains with the same sequence.

  Sequences are the residue-name sequences returned by
  `chain_res_name_sequence(include_missing_residues=True,
  fix_non_standard_polymer_res=False)`.
  """
  seq_by_chain = struc.chain_res_name_sequence(
      include_missing_residues=True, fix_non_standard_polymer_res=False
  )
  copies_of_seq = collections.Counter(seq_by_chain.values())
  return {
      chain_id: copies_of_seq[seq] for chain_id, seq in seq_by_chain.items()
  }
def has_nonsymmetric_bonds_on_symmetric_polymer_chains(
    struc: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout
) -> bool:
  """Returns True if the PTM bonds break symmetry between polymer chains.

  True is returned immediately in three cases: the polymer end is not the
  same for every bond (_get_polymer_dim raises ValueError), some bond is not
  polymer-ligand, or a ligand chain appears in more than one bond. Otherwise,
  bonded ligand chains are merged into their polymer chains, and the result
  is whether the per-chain symmetric counts of polymer chains differ before
  and after the merge.
  """
  try:
    _get_polymer_dim(polymer_ligand_bonds)
  except ValueError:
    return True

  if _has_non_polymer_ligand_ptm_bonds(
      polymer_ligand_bonds
  ) or _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
    return True

  merged, _ = _combine_polymer_ligand_ptm_chains(struc, polymer_ligand_bonds)
  polymer_types = mmcif_names.POLYMER_CHAIN_TYPES
  counts_before = _count_symmetric_chains(
      struc.filter(chain_type=polymer_types)
  )
  counts_after = _count_symmetric_chains(
      merged.filter(chain_type=polymer_types)
  )
  return counts_before != counts_after
def _has_non_polymer_ligand_ptm_bonds(
    polymer_ligand_bonds: atom_layout.AtomLayout,
):
  """Returns True if any bond is not between a polymer and a ligand chain.

  A bond counts as polymer-ligand when one end's chain type is in
  mmcif_names.POLYMER_CHAIN_TYPES and the other end's is in
  mmcif_names.LIGAND_CHAIN_TYPES, in either order. Returns False when there
  are no bonds.
  """
  polymer_types = mmcif_names.POLYMER_CHAIN_TYPES
  ligand_types = mmcif_names.LIGAND_CHAIN_TYPES

  def _is_polymer_ligand(start_type, end_type):
    return (start_type in polymer_types and end_type in ligand_types) or (
        start_type in ligand_types and end_type in polymer_types
    )

  return any(
      not _is_polymer_ligand(start_type, end_type)
      for start_type, end_type in polymer_ligand_bonds.chain_type
  )
def _combine_polymer_ligand_ptm_chains(
    struc: structure.Structure,
    polymer_ligand_bonds: atom_layout.AtomLayout,
) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]:
  """Combines the ptm polymer-ligand chains together.

  This will prevent them from being permuted away from each other when chains
  are matched to the ground truth. This function also returns the res_id mapping
  from the separate ligand res_ids to their res_ids in the combined
  polymer-ligand chain; this information is needed to later separate the
  combined polymer-ligand chain.

  Bonds are processed one at a time and `struc` is rebuilt with merge_chains
  after each one. Later bonds therefore see the chains produced by earlier
  merges, including the residue count used to offset ligand res_ids.

  Args:
    struc: Structure to be modified.
    polymer_ligand_bonds: AtomLayout with polymer-ligand bond info.

  Returns:
    A tuple of a Structure with each ptm polymer-ligand chain relabelled as one
    chain and a dict from (polymer chain ID, ligand chain ID) to the res_id
    mapping.

  Raises:
    ValueError: If some chain ID occurs in more than one bond end and this is
      not because several ligands are bonded to one polymer chain, or (from
      _get_polymer_and_ligand_chain_ids_and_types) if a bond is not between a
      polymer chain and a ligand chain.
  """
  if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds):
    if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
      # For structures where a polymer chain is connected to multiple ligands,
      # we need to sort the multiple bonds from the same chain by res_id to
      # ensure that the combined polymer-ligand chain will always be the same
      # when you have repeated symmetric polymer-ligand chains.
      polymer_ligand_bonds = (
          _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
              polymer_ligand_bonds
          )
      )
    else:
      raise ValueError(
          'Code cannot handle multiple bonds from one chain unless'
          ' its several ligands bonded to a polymer.'
      )
  res_id_mappings_for_bond_chain_pair = dict()
  for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip(
      polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type
  ):
    poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types(
        start_chain_id, end_chain_id, start_chain_type, end_chain_type
    )
    polymer_chain_id, polymer_chain_type = poly_info
    ligand_chain_id, _ = ligand_info
    # Join the ligand chain to the polymer chain.
    # The ligand's res_ids are shifted by the polymer chain's current entry
    # count in struc.all_residues (which includes earlier merges).
    ligand_res_ids = struc.filter(chain_id=ligand_chain_id).res_id
    new_res_ids = ligand_res_ids + len(struc.all_residues[polymer_chain_id])
    res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = (
        ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids)
    )
    # Build the merge_chains grouping: the polymer chain's group also absorbs
    # the ligand chain (which is skipped as a standalone group). Every other
    # chain keeps its own group, ID and type.
    chain_groups = []
    chain_group_ids = []
    chain_group_types = []
    for chain_id, chain_type in zip(
        struc.chains_table.id, struc.chains_table.type
    ):
      if chain_id == ligand_chain_id:
        continue
      elif chain_id == polymer_chain_id:
        chain_groups.append([polymer_chain_id, ligand_chain_id])
        chain_group_ids.append(polymer_chain_id)
        chain_group_types.append(polymer_chain_type)
      else:
        chain_groups.append([chain_id])
        chain_group_ids.append(chain_id)
        chain_group_types.append(chain_type)
    struc = struc.merge_chains(
        chain_groups=chain_groups,
        chain_group_ids=chain_group_ids,
        chain_group_types=chain_group_types,
    )
  return struc, res_id_mappings_for_bond_chain_pair
def _has_only_single_bond_from_each_chain(
    polymer_ligand_bonds: atom_layout.AtomLayout,
) -> bool:
  """Returns True if no chain ID occurs more than once across all bond ends.

  Both ends of every bond are pooled together. A bond whose two ends are on
  the same chain therefore also makes this return False.
  """
  bond_end_chain_ids = [
      chain_id
      for bond_chains in polymer_ligand_bonds.chain_id
      for chain_id in bond_chains
  ]
  return len(bond_end_chain_ids) == len(set(bond_end_chain_ids))
def _get_polymer_and_ligand_chain_ids_and_types(
    start_chain_id: str,
    end_chain_id: str,
    start_chain_type: str,
    end_chain_type: str,
) -> tuple[tuple[str, str], tuple[str, str]]:
  """Orders a bond's two (chain ID, chain type) ends as (polymer, ligand).

  Raises:
    ValueError: If the bond is not between a polymer and a ligand chain.
  """
  start_end = (start_chain_id, start_chain_type)
  end_end = (end_chain_id, end_chain_type)
  polymer_types = mmcif_names.POLYMER_CHAIN_TYPES
  ligand_types = mmcif_names.LIGAND_CHAIN_TYPES
  if start_chain_type in polymer_types and end_chain_type in ligand_types:
    return start_end, end_end
  if start_chain_type in ligand_types and end_chain_type in polymer_types:
    return end_end, start_end
  raise ValueError(
      'This code only handles PTM-bonds from polymer chain to ligands.'
  )
def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int:
  """Returns which bond end holds the polymer: 0 for start, 1 for end.

  Returns 0 when there are no bonds.

  Raises:
    ValueError: If the polymer/ligand ends are not consistent across bonds.
  """
  start_types = set()
  end_types = set()
  for start_type, end_type in polymer_ligand_bonds.chain_type:
    start_types.add(start_type)
    end_types.add(end_type)

  polymer_types = set(mmcif_names.POLYMER_CHAIN_TYPES)
  ligand_types = set(mmcif_names.LIGAND_CHAIN_TYPES)
  if start_types <= polymer_types and end_types <= ligand_types:
    return 0
  if start_types <= ligand_types and end_types <= polymer_types:
    return 1
  raise ValueError(
      'Polymer and ligand dimensions are not consistent within the structure.'
  )
def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
  """Returns True if some polymer chain ID appears in more than one bond."""
  polymer_end = _get_polymer_dim(polymer_ligand_bonds)
  polymer_ids = [bond[polymer_end] for bond in polymer_ligand_bonds.chain_id]
  return len(set(polymer_ids)) < len(polymer_ids)
def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
  """Returns True if some ligand chain ID appears in more than one bond."""
  ligand_end = 1 - _get_polymer_dim(polymer_ligand_bonds)
  ligand_ids = [bond[ligand_end] for bond in polymer_ligand_bonds.chain_id]
  return len(set(ligand_ids)) < len(ligand_ids)
def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
    polymer_ligand_bonds,
):
  """Reorders bonds by (polymer chain ID, polymer res_id).

  The sort is stable, so bonds with equal keys keep their relative order.
  Used when a polymer chain has several bonded ligands.
  """
  polymer_end = _get_polymer_dim(polymer_ligand_bonds)
  sort_keys = [
      (bond_chains[polymer_end], bond_res_ids[polymer_end])
      for bond_chains, bond_res_ids in zip(
          polymer_ligand_bonds.chain_id, polymer_ligand_bonds.res_id
      )
  ]
  order = sorted(range(len(sort_keys)), key=sort_keys.__getitem__)
  return polymer_ligand_bonds[order]
================================================
FILE: src/alphafold3/model/scoring/scoring.py
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Library of scoring methods of the model outputs."""
from alphafold3.model import protein_data_processing
import jax.numpy as jnp
import numpy as np
# Array type accepted by the scoring helpers: a JAX array or a NumPy array.
Array = jnp.ndarray | np.ndarray
def pseudo_beta_fn(
    aatype: Array,
    dense_atom_positions: Array,
    dense_atom_masks: Array,
    is_ligand: Array | None = None,
    use_jax: bool | None = True,
) -> tuple[Array, Array] | Array:
  """Gathers the pseudo-beta atom position and mask for every residue.

  Each residue's pseudo-beta dense-atom index is looked up by `aatype` in
  protein_data_processing.RESTYPE_PSEUDOBETA_INDEX. Residues flagged in
  `is_ligand` use dense-atom index 0 instead.

  Args:
    aatype: [num_res] residue types.
    dense_atom_positions: [num_res, NUM_DENSE, 3] atom positions.
    dense_atom_masks: [num_res, NUM_DENSE] atom mask.
    is_ligand: [num_res] ligand flags; None means no residue is a ligand.
    use_jax: Use jax.numpy when truthy, otherwise NumPy (None selects NumPy).

  Returns:
    A (positions, mask) tuple: [num_res, 3] pseudo-beta positions and the
    matching [num_res] float32 mask.
  """
  xnp = jnp if use_jax else np
  if is_ligand is None:
    is_ligand = xnp.zeros_like(aatype)

  polymer_atom_idx = xnp.take(
      protein_data_processing.RESTYPE_PSEUDOBETA_INDEX, aatype, axis=0
  ).astype(xnp.int32)
  atom_idx = xnp.where(
      is_ligand, xnp.zeros_like(polymer_atom_idx), polymer_atom_idx
  )

  positions = xnp.squeeze(
      xnp.take_along_axis(
          dense_atom_positions, atom_idx[..., None, None], axis=-2
      ),
      axis=-2,
  )
  mask = xnp.squeeze(
      xnp.take_along_axis(
          dense_atom_masks, atom_idx[..., None], axis=-1
      ).astype(xnp.float32),
      axis=-1,
  )
  return positions, mask
================================================
FILE: src/alphafold3/parsers/cpp/cif_dict.pyi
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
from typing import Any, ClassVar, Iterable, Iterator, TypeVar, overload
import numpy as np
# Type of the fallback value accepted and returned by CifDict.get.
_T = TypeVar('_T')
class CifDict:
  """Dict-like mapping from CIF data item names to lists of string values.

  NOTE(review): this is a type stub for a compiled extension; the descriptions
  here come from the signatures and docstrings only.
  """

  class ItemView:
    """Iterable of (key, values) pairs; returned by CifDict.items()."""

    def __iter__(self) -> Iterator[tuple[str, list[str]]]: ...
    def __len__(self) -> int: ...

  class KeyView:
    """Iterable of keys supporting membership tests; from CifDict.keys()."""

    @overload
    def __contains__(self, key: str) -> bool: ...
    @overload
    def __contains__(self, key: object) -> bool: ...
    def __iter__(self) -> Iterator[str]: ...
    def __len__(self) -> int: ...

  class ValueView:
    """Iterable of value lists; returned by CifDict.values()."""

    def __iter__(self) -> Iterator[list[str]]: ...
    def __len__(self) -> int: ...

  # Builds a CifDict from a mapping of item names to their values.
  def __init__(self, d: dict[str, Iterable[str]]) -> None: ...
  # Returns a new CifDict combining this one with the entries of `d`.
  # NOTE(review): presumably `d` overrides existing keys and self is left
  # unchanged — confirm against the implementation.
  def copy_and_update(self, d: dict[str, Iterable[str]]) -> CifDict: ...
  def extract_loop_as_dict(self, prefix: str, index: str) -> dict:
    """Extracts loop associated with a prefix from mmCIF data as a dict.

    For instance for an mmCIF with these fields:
      '_a.ix': ['1', '2', '3']
      '_a.1': ['a.1.1', 'a.1.2', 'a.1.3']
      '_a.2': ['a.2.1', 'a.2.2', 'a.2.3']

    this function called with prefix='_a.', index='_a.ix' extracts:
      {'1': {'a.ix': '1', 'a.1': 'a.1.1', 'a.2': 'a.2.1'},
       '2': {'a.ix': '2', 'a.1': 'a.1.2', 'a.2': 'a.2.2'},
       '3': {'a.ix': '3', 'a.1': 'a.1.3', 'a.2': 'a.2.3'}}

    NOTE(review): the inner keys in this example lack the leading underscore,
    while the extract_loop_as_list example keeps it — confirm which is right.

    Args:
      prefix: Prefix shared by each of the data items in the loop. The prefix
        should include the trailing period.
      index: Which item of loop data should serve as the key.

    Returns:
      Dict of dicts; each dict represents 1 entry from an mmCIF loop,
      indexed by the index column.
    """
  def extract_loop_as_list(self, prefix: str) -> list:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
      http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    For instance for an mmCIF with these fields:
      '_a.1': ['a.1.1', 'a.1.2', 'a.1.3']
      '_a.2': ['a.2.1', 'a.2.2', 'a.2.3']

    this function called with prefix='_a.' extracts:
      [{'_a.1': 'a.1.1', '_a.2': 'a.2.1'},
       {'_a.1': 'a.1.2', '_a.2': 'a.2.2'},
       {'_a.1': 'a.1.3', '_a.2': 'a.2.3'}]

    Args:
      prefix: Prefix shared by each of the data items in the loop. The prefix
        should include the trailing period.

    Returns:
      A list of dicts; each dict represents 1 entry from an mmCIF loop.
    """
  # Returns the values for `key`, or `default_value` when `key` is absent.
  def get(self, key: str, default_value: _T = ...) -> list[str] | _T: ...
  def get_array(
      self, key: str, dtype: object = ..., gather: object = ...
  ) -> np.ndarray:
    """Returns values looked up in dict converted to a NumPy array.

    Args:
      key: Key in dictionary.
      dtype: Optional (default `object`) Specifies output dtype of array. One of
        [object, np.{int,uint}{8,16,32,64} np.float{32,64}]. As with NumPy use
        `object` to return a NumPy array of strings.
      gather: Optional one of [slice, np.{int,uint}{32,64}] non-intermediate
        version of get_array(key, dtype)[gather].

    Returns:
      A NumPy array of given dtype. An optimised equivalent to
      np.array(cif[key]).astype(dtype). With support of '.' being treated
      as np.nan if dtype is one of np.float{32,64}.
      Identical strings will all reference the same object to save space.

    Raises:
      KeyError - if key is not found.
      TypeError - if dtype is not valid or supported.
      ValueError - if string cannot convert to dtype.
    """
  # NOTE(review): presumably the name of the CIF data block — confirm.
  def get_data_name(self) -> str: ...
  def items(self) -> CifDict.ItemView: ...
  def keys(self) -> CifDict.KeyView: ...
  # NOTE(review): presumably serialises the dictionary back to CIF text.
  def to_string(self) -> str: ...
  # Returns the contents as a plain Python dict.
  def to_dict(self) -> dict[str, list[str]]: ...
  # NOTE(review): presumably the number of values stored under `key`.
  def value_length(self, key: str) -> int: ...
  def values(self) -> CifDict.ValueView: ...
  def __bool__(self) -> bool: ...
  def __contains__(self, key: str) -> bool: ...
  def __getitem__(self, key: str) -> list[str]: ...
  # Pickle-protocol support.
  def __getstate__(self) -> tuple: ...
  def __iter__(self) -> Iterator[str]: ...
  def __len__(self) -> int: ...
  def __setstate__(self, state: tuple) -> None: ...
# NOTE(review): module-level bindings whose implementations are not part of
# this stub; descriptions below are inferred from names and signatures.
# Splits CIF text into a flat list of tokens.
def tokenize(cif_string: str) -> list[str]: ...
# Splits a single CIF line into tokens.
def split_line(line: str) -> list[str]: ...
# Parses CIF text into a CifDict.
def from_string(mmcif_string: str | bytes) -> CifDict: ...
# Parses CIF text that may hold several data blocks; the keys are presumably
# the data block names — confirm.
def parse_multi_data_cif(cif_string: str | bytes) -> dict[str, CifDict]: ...
================================================
FILE: src/alphafold3/parsers/cpp/cif_dict_lib.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include "alphafold3/parsers/cpp/cif_dict_lib.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
namespace alphafold3 {
namespace {
// Returns true for the two characters that can open a quoted CIF token:
// the single quote and the double quote.
bool IsQuote(const char symbol) {
  switch (symbol) {
    case '\'':
    case '"':
      return true;
    default:
      return false;
  }
}
// Returns true for the in-line token separators: space and horizontal tab.
// Newlines and carriage returns are not treated as whitespace here.
bool IsWhitespace(const char symbol) {
  switch (symbol) {
    case ' ':
    case '\t':
      return true;
    default:
      return false;
  }
}
// Splits a single CIF line into tokens, appending them to `tokens`.
//
// Tokens are separated by spaces or tabs. A '#' at the start of a token starts
// a comment that runs to the end of the line. A token starting with ' or " is
// quoted: it ends at the next matching quote that is followed by whitespace or
// the end of the line, and the surrounding quotes are not part of the token.
//
// Returns false if the line ends while a quoted token is still open (tokens
// completed before that point have already been appended); true otherwise.
bool SplitLineInline(absl::string_view line,
                     std::vector* tokens) {
  // See https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax
  for (int i = 0, line_length = line.length(); i < line_length;) {
    // Skip whitespace (spaces or tabs).
    while (IsWhitespace(line[i])) {
      if (++i == line_length) {
        break;
      }
    }
    if (i == line_length) {
      break;
    }
    // Skip comments (from # until the end of the line). '#' is only checked
    // at the start of a token, so a '#' later inside a token (quoted or not)
    // is kept as part of that token.
    if (line[i] == '#') {
      break;
    }
    int start_index;
    int end_index;
    if (IsQuote(line[i])) {
      // Token in single or double quotes. CIF v1.1 specification considers a
      // quote to be an opening quote only if it is at the beginning of a token.
      // So e.g. A' B has tokens A' and B. Also, ""A" is a token "A.
      const char quote_char = line[i++];
      start_index = i;
      // Find matching quote. The double loop is not strictly necessary, but
      // optimises a bit better.
      while (true) {
        while (i < line_length && line[i] != quote_char) {
          ++i;
        }
        if (i == line_length) {
          // Reached the end of the line while still being inside a token.
          return false;
        }
        // A matching quote closes the token only when followed by whitespace
        // or the end of the line; otherwise it is token content.
        if (i + 1 == line_length || IsWhitespace(line[i + 1])) {
          break;
        }
        ++i;
      }
      // The token ends before the closing quote; continue after it.
      end_index = i++;
    } else {
      // Non-quoted token. Read until reaching whitespace.
      start_index = i++;
      while (i < line_length && !IsWhitespace(line[i])) {
        ++i;
      }
      end_index = i;
    }
    tokens->push_back(line.substr(start_index, end_index - start_index));
  }
  return true;
}
// Owns strings that cannot be viewed directly in the input CIF text (the
// re-joined multi-line tokens), so string_views into them remain valid while
// the container is alive. Storing unique_ptrs keeps each string's address
// stable when the vector reallocates.
using HeapStrings = std::vector<std::unique_ptr<std::string>>;
// Splits `cif_string` into CIF tokens.
//
// Most tokens are string_views into `cif_string` itself. Multi-line
// (semicolon-delimited) tokens are rebuilt by joining their lines, with
// trailing whitespace stripped from each line, using "\n". Those rebuilt
// strings are owned by `heap_strings`, so every returned view stays valid only
// while both `cif_string` and `heap_strings` are alive.
//
// Returns InvalidArgumentError if a multi-line token is never closed by a line
// starting with ';', or if a line ends inside a quoted token.
absl::StatusOr<std::vector<absl::string_view>> TokenizeInternal(
    absl::string_view cif_string, HeapStrings* heap_strings) {
  const std::vector<absl::string_view> lines = absl::StrSplit(cif_string, '\n');
  std::vector<absl::string_view> tokens;
  // Heuristic: Most lines in an mmCIF are _atom_site lines with 21 tokens.
  tokens.reserve(lines.size() * 21);
  int line_num = 0;
  while (line_num < lines.size()) {
    auto line = absl::StripSuffix(lines[line_num], "\r");
    line_num++;
    if (line.empty() || line[0] == '#') {
      // Skip empty lines or lines that contain only comments.
      continue;
    } else if (line[0] == ';') {
      // Leading whitespace on each line must be preserved while trailing
      // whitespace may be stripped.
      std::vector<absl::string_view> multiline_tokens;
      // Strip the leading ";".
      multiline_tokens.push_back(
          absl::StripTrailingAsciiWhitespace(line.substr(1)));
      bool terminated = false;
      while (line_num < lines.size()) {
        auto multiline = absl::StripTrailingAsciiWhitespace(lines[line_num]);
        line_num++;
        if (!multiline.empty() && multiline[0] == ';') {
          terminated = true;
          break;
        }
        multiline_tokens.push_back(multiline);
      }
      // Checked after the loop so that a field opened on the very last line
      // (where the loop body never runs) is rejected too.
      if (!terminated) {
        return absl::InvalidArgumentError(
            "Last multiline token is not terminated by a semicolon.");
      }
      heap_strings->push_back(std::make_unique<std::string>(
          absl::StrJoin(multiline_tokens, "\n")));
      tokens.emplace_back(*heap_strings->back());
    } else {
      if (!SplitLineInline(line, &tokens)) {
        return absl::InvalidArgumentError(
            absl::StrCat("Line ended with quote open: ", line));
      }
    }
  }
  return tokens;
}
// Returns whether `value` can be written as a bare token with no quoting: it
// must be non-empty and consist only of ASCII letters, digits, '.', '?' and
// '-'.
bool IsTrivialToken(const absl::string_view value) {
  if (value.empty()) return false;
  for (const char c : value) {
    const bool allowed =
        absl::ascii_isalnum(c) || c == '.' || c == '?' || c == '-';
    if (!allowed) return false;
  }
  return true;
}
// A token must be written as a multi-line (semicolon-delimited) text field if
// it contains a newline, or if it contains both kinds of quote characters so
// that neither '...' nor "..." quoting can enclose it.
bool IsMultiLineToken(const absl::string_view value) {
  if (value.find('\n') != absl::string_view::npos) {
    return true;
  }
  const bool has_single_quotes = value.find('\'') != absl::string_view::npos;
  const bool has_double_quotes = value.find('"') != absl::string_view::npos;
  return has_single_quotes && has_double_quotes;
}
// Returns the quote character to wrap `value` in when writing it inline, or an
// empty string when the value can be written unquoted.
//
// Quoting is needed for empty values, for values containing spaces, tabs or
// quote characters, for values starting with a CIF reserved word (data_,
// loop_, save_, stop_, global_; case-insensitive), and for values starting with
// one of _ # $ [ ] ;. The chosen delimiter never occurs inside the value: in
// CIF a quote followed by whitespace ends a quoted string, so reusing a quote
// that appears in the value could split it into two tokens.
//
// A value containing both ' and " cannot be quoted inline at all; callers must
// detect that case with IsMultiLineToken and write a multi-line text field.
absl::string_view GetEscapeQuote(const absl::string_view value) {
  // Empty values should not happen, but if so, they should be quoted.
  if (value.empty()) {
    return "\"";
  }
  // Reserved words and special leading characters force quoting even when the
  // value contains no whitespace or quotes.
  const char first = value.front();
  const bool quoting_forced =
      absl::StartsWithIgnoreCase(value, "data_") ||
      absl::StartsWithIgnoreCase(value, "loop_") ||
      absl::StartsWithIgnoreCase(value, "save_") ||
      absl::StartsWithIgnoreCase(value, "stop_") ||
      absl::StartsWithIgnoreCase(value, "global_") || first == '_' ||
      first == '#' || first == '$' || first == '[' || first == ']' ||
      first == ';';
  bool has_whitespace = false;
  bool has_double_quote = false;
  bool has_single_quote = false;
  for (const char c : value) {
    if (c == ' ' || c == '\t') {
      has_whitespace = true;
    } else if (c == '"') {
      has_double_quote = true;
    } else if (c == '\'') {
      has_single_quote = true;
    }
  }
  if (!quoting_forced && !has_whitespace && !has_double_quote &&
      !has_single_quote) {
    return "";
  }
  // Pick a delimiter that does not appear in the value, also for values whose
  // quoting is forced by a reserved word or special first character.
  if (!has_double_quote) {
    return "\"";
  }
  if (!has_single_quote) {
    return "'";
  }
  // Both quote kinds present: no inline delimiter is safe; such values are
  // expected to be written as multi-line text fields instead.
  return quoting_forced ? "\"" : "";
}
// Sort rank of a key-group name, used by RecordOrder: the "_entry" group ranks
// first, the "_atom_site" group last, and every other group in between.
int RecordIndex(absl::string_view record) {
  constexpr int kFirst = 0;
  constexpr int kMiddle = 1;
  constexpr int kLast = 2;
  if (record == "_entry") {
    return kFirst;
  }
  return record == "_atom_site" ? kLast : kMiddle;
}
struct RecordOrder {
using is_transparent = void; // Enable heterogeneous lookup.
bool operator()(absl::string_view lhs, absl::string_view rhs) const {
std::size_t lhs_index = RecordIndex(lhs);
std::size_t rhs_index = RecordIndex(rhs);
return std::tie(lhs_index, lhs) < std::tie(rhs_index, rhs);
}
};
// Column order used when writing the _atom_site loop, so that its columns are
// sorted in the PDB-standard way. Columns not listed here sort after all
// listed ones.
constexpr absl::string_view kAtomSiteSortOrder[] = {
    "_atom_site.group_PDB",
    "_atom_site.id",
    "_atom_site.type_symbol",
    "_atom_site.label_atom_id",
    "_atom_site.label_alt_id",
    "_atom_site.label_comp_id",
    "_atom_site.label_asym_id",
    "_atom_site.label_entity_id",
    "_atom_site.label_seq_id",
    "_atom_site.pdbx_PDB_ins_code",
    "_atom_site.Cartn_x",
    "_atom_site.Cartn_y",
    "_atom_site.Cartn_z",
    "_atom_site.occupancy",
    "_atom_site.B_iso_or_equiv",
    "_atom_site.pdbx_formal_charge",
    "_atom_site.auth_seq_id",
    "_atom_site.auth_comp_id",
    "_atom_site.auth_asym_id",
    "_atom_site.auth_atom_id",
    "_atom_site.pdbx_PDB_model_num",
};
// Returns the position of `atom_site` in kAtomSiteSortOrder, or the number of
// entries in kAtomSiteSortOrder when the column is not listed.
size_t AtomSiteIndex(absl::string_view atom_site) {
  size_t position = 0;
  for (const absl::string_view column : kAtomSiteSortOrder) {
    if (column == atom_site) {
      break;
    }
    ++position;
  }
  return position;
}
// Orders _atom_site column keys by their position in kAtomSiteSortOrder, and
// alphabetically among columns with the same position (i.e. unlisted ones).
struct AtomSiteOrder {
  bool operator()(absl::string_view lhs, absl::string_view rhs) const {
    const size_t lhs_rank = AtomSiteIndex(lhs);
    const size_t rhs_rank = AtomSiteIndex(rhs);
    return lhs_rank != rhs_rank ? lhs_rank < rhs_rank : lhs < rhs;
  }
};
// Read-only view of one CIF column (its key and values), together with the
// formatting facts needed to write it: which values need a multi-line text
// field, which need quotes (and which quote character), and the width of the
// widest inline value including its quotes.
class Column {
 public:
  // Stores `values` by pointer; the vector must outlive this Column.
  Column(absl::string_view key, const std::vector<std::string>* values)
      : key_(key), values_(values), max_value_length_(0) {
    for (size_t index = 0; index < values->size(); ++index) {
      const absl::string_view value = (*values)[index];
      if (IsTrivialToken(value)) {
        // Most common case: written verbatim without quoting.
        max_value_length_ = std::max<int>(max_value_length_, value.size());
      } else if (IsMultiLineToken(value)) {
        // Written as a separate text field; does not affect the inline width.
        values_with_newlines_.insert(index);
      } else {
        const absl::string_view quote = GetEscapeQuote(value);
        if (!quote.empty()) {
          values_with_quotes_[index] = quote;
        }
        max_value_length_ = std::max<int>(max_value_length_,
                                          value.size() + 2 * quote.size());
      }
    }
  }

  absl::string_view key() const { return key_; }
  const std::vector<std::string>* values() const { return values_; }
  int max_value_length() const { return max_value_length_; }

  // Whether the value at `index` must be written as a multi-line text field.
  bool has_newlines(size_t index) const {
    return values_with_newlines_.count(index) > 0;
  }

  // Quote character for the value at `index`, or "" if it needs none.
  absl::string_view quote(size_t index) const {
    const auto it = values_with_quotes_.find(index);
    return it == values_with_quotes_.end() ? absl::string_view("")
                                           : it->second;
  }

 private:
  absl::string_view key_;
  const std::vector<std::string>* values_;
  int max_value_length_;
  // Values with newlines or quotes are very rare in a typical CIF file, so
  // they are tracked sparsely by index.
  absl::flat_hash_set<size_t> values_with_newlines_;
  absl::flat_hash_map<size_t, absl::string_view> values_with_quotes_;
};
struct GroupedKeys {
std::vector grouped_columns;
int max_key_length;
int value_size;
};
// Verifies that a loop's values fill whole rows, i.e. that the number of values
// is a multiple of the number of columns. A loop without columns always passes.
absl::Status CheckLoopColumnSizes(int num_loop_keys, int num_loop_values) {
  const bool has_partial_row =
      num_loop_keys > 0 && num_loop_values % num_loop_keys != 0;
  if (!has_partial_row) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(absl::StrFormat(
      "The number of values (%d) in a loop is not a multiple of the "
      "number of the loop's columns (%d)",
      num_loop_values, num_loop_keys));
}
} // namespace
// Parses CIF text into a CifDict.
//
// The first token must be "data_<name>"; the name is stored under the special
// "data_" key. After it, tokens are either a plain key (starting with '_')
// followed by its value, the "loop_" keyword, or part of a loop: first the
// loop's column keys, then its values, which are distributed over the columns
// in row-major order.
//
// Returns InvalidArgumentError on tokenization errors, empty input, a missing
// data block header or name, duplicate keys, keys not starting with '_', keys
// without a value, and loops whose value count is not a multiple of their
// column count.
absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
  CifDict::Dict cif;
  bool loop_flag = false;
  // Plain (non-loop) key still waiting for its value; empty when none.
  absl::string_view key;
  // Owns multi-line tokens; `tokens` may point into it.
  HeapStrings heap_strings;
  auto tokens = TokenizeInternal(cif_string, &heap_strings);
  if (!tokens.ok()) {
    return tokens.status();
  }
  if (tokens->empty()) {
    return absl::InvalidArgumentError("The CIF file must not be empty.");
  }
  // The first token should be data_XXX. Split into key = data, value = XXX.
  absl::string_view first_token = tokens->front();
  if (!absl::ConsumePrefix(&first_token, "data_")) {
    return absl::InvalidArgumentError(
        "The CIF file does not start with the data_ field.");
  }
  if (first_token.empty()) {
    return absl::InvalidArgumentError(
        "The CIF file does not contain a data block name.");
  }
  cif["data_"].emplace_back(first_token);
  // Counters for CIF loop_ regions.
  int loop_token_index = 0;
  int num_loop_keys = 0;
  // Loops have usually O(10) columns but could have up to O(10^6) rows. It is
  // therefore wasteful to look up the cif vector where to add a loop value
  // since that means doing `columns * rows` map lookups. If we save pointers to
  // these loop column fields instead, we need only 1 cif lookup per column.
  // The pointers stay valid because node_hash_map has pointer stability.
  std::vector<std::vector<std::string>*> loop_column_values;
  // Skip the first element since we already processed it above.
  for (auto token_itr = tokens->begin() + 1; token_itr != tokens->end();
       ++token_itr) {
    auto token = *token_itr;
    if (absl::EqualsIgnoreCase(token, "loop_")) {
      // A plain key must receive its value before a loop starts; otherwise the
      // first key after the loop would silently become that value.
      if (!key.empty()) {
        return absl::InvalidArgumentError(
            absl::StrCat("Key '", key, "' has no value."));
      }
      // A new loop started, check the previous loop and get rid of its data.
      absl::Status loop_status =
          CheckLoopColumnSizes(num_loop_keys, loop_token_index);
      if (!loop_status.ok()) {
        return loop_status;
      }
      loop_flag = true;
      loop_column_values.clear();
      loop_token_index = 0;
      num_loop_keys = 0;
      continue;
    } else if (loop_flag) {
      // The second condition checks we are in the first column. Some mmCIF
      // files (e.g. 4q9r) have values in later columns starting with an
      // underscore and we don't want to read these as keys.
      int token_column_index =
          num_loop_keys == 0 ? 0 : loop_token_index % num_loop_keys;
      if (token_column_index == 0 && !token.empty() && token[0] == '_') {
        if (loop_token_index > 0) {
          // We are out of the loop.
          loop_flag = false;
        } else {
          // We are in the keys (column names) section of the loop.
          auto [it, inserted] = cif.try_emplace(token);
          if (!inserted) {
            return absl::InvalidArgumentError(
                absl::StrCat("Duplicate loop key: '", token, "'"));
          }
          auto& columns = it->second;
          columns.clear();
          // Heuristic: _atom_site is typically the largest table in an mmCIF
          // with ~16 columns. Make sure we reserve enough space for its values.
          if (absl::StartsWith(token, "_atom_site.")) {
            columns.reserve(tokens->size() / 20);
          }
          // Save the pointer to the loop column values.
          loop_column_values.push_back(&columns);
          num_loop_keys += 1;
          continue;
        }
      } else {
        // We are in the values section of the loop. We have a pointer to the
        // loops' values, add the new token in there.
        if (static_cast<size_t>(token_column_index) >=
            loop_column_values.size()) {
          return absl::InvalidArgumentError(
              absl::StrCat("Too many columns at: '", token,
                           "' at column index: ", token_column_index,
                           " expected at most: ", loop_column_values.size()));
        }
        loop_column_values[token_column_index]->emplace_back(token);
        loop_token_index++;
        continue;
      }
    }
    if (key.empty()) {
      key = token;
      if (!absl::StartsWith(key, "_")) {
        return absl::InvalidArgumentError(
            absl::StrCat("Key '", key, "' does not start with an underscore."));
      }
    } else {
      auto [it, inserted] = cif.try_emplace(key);
      if (!inserted) {
        return absl::InvalidArgumentError(
            absl::StrCat("Duplicate key: '", key, "'"));
      }
      (it->second).emplace_back(token);
      key = "";
    }
  }
  absl::Status loop_status =
      CheckLoopColumnSizes(num_loop_keys, loop_token_index);
  if (!loop_status.ok()) {
    return loop_status;
  }
  // A trailing key without a value would otherwise be dropped silently.
  if (!key.empty()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Key '", key, "' has no value."));
  }
  return CifDict(std::move(cif));
}
// Serializes this data block into CIF text.
//
// Layout: a "data_<name>" line, then one section per key prefix (the part of
// the key before the first '.'), each section followed by a "#" line.
// Sections come out in btree_map order: alphabetical, but with some prefixes
// such as _atom_site placed at the end. Keys within a section are sorted
// alphabetically, except _atom_site columns, which follow kAtomSiteSortOrder.
// Sections whose columns each hold exactly one value are written as aligned
// key-value pairs. All other sections, and _atom_site always, are written as
// loop_ tables.
//
// Returns InvalidArgumentError if the data_ entry is missing or has no values,
// if the block name contains whitespace, or if columns sharing a key prefix
// have different numbers of values.
//
// NOTE(review): an empty block name ("data_" mapped to [""]) passes these
// checks, but FromString rejects a "data_" header without a name, so such
// output would not round-trip.
absl::StatusOr CifDict::ToString() const {
std::string output;
absl::string_view data_name;
// Check that the data_ field exists.
if (auto name_it = (*dict_).find("data_");
name_it == (*dict_).end() || name_it->second.empty()) {
return absl::InvalidArgumentError(
"The CIF must contain a valid name for this data block in the special "
"data_ field.");
} else {
data_name = name_it->second.front();
}
if (absl::c_any_of(data_name,
[](char i) { return absl::ascii_isspace(i); })) {
return absl::InvalidArgumentError(absl::StrFormat(
"The CIF data block name must not contain any whitespace characters, "
"got '%s'.",
data_name));
}
absl::StrAppend(&output, "data_", data_name, "\n#\n");
// Group keys by their prefix. Use btree_map to iterate in alphabetical order,
// but with some keys being placed at the end (e.g. _atom_site).
absl::btree_map grouped_keys;
for (const auto& [key, values] : *dict_) {
if (key == "data_") {
continue; // Skip the special data_ key, we are already done with it.
}
// The prefix is everything before the first '.', or the whole key if it
// contains no '.'.
const std::pair key_parts =
absl::StrSplit(key, absl::MaxSplits('.', 1));
const absl::string_view key_prefix = key_parts.first;
auto [it, inserted] = grouped_keys.emplace(key_prefix, GroupedKeys{});
GroupedKeys& grouped_key = it->second;
grouped_key.grouped_columns.push_back(Column(key, &values));
if (inserted) {
grouped_key.max_key_length = key.length();
grouped_key.value_size = values.size();
} else {
grouped_key.max_key_length =
std::max(key.length(), grouped_key.max_key_length);
// All columns in one group are written as one table, so they must have
// the same number of values.
if (grouped_key.value_size != values.size()) {
return absl::InvalidArgumentError(
absl::StrFormat("Values for key %s have different length (%d) than "
"the other values with the same key prefix (%d).",
key, values.size(), grouped_key.value_size));
}
}
}
for (auto& [key_prefix, group_info] : grouped_keys) {
if (key_prefix == "_atom_site") {
// Make sure we sort the _atom_site loop in the standard way.
absl::c_sort(group_info.grouped_columns,
[](const Column& lhs, const Column& rhs) {
return AtomSiteOrder{}(lhs.key(), rhs.key());
});
} else {
// Make the key ordering within a key group deterministic.
absl::c_sort(group_info.grouped_columns,
[](const Column& lhs, const Column& rhs) {
return lhs.key() < rhs.key();
});
}
// Force `_atom_site` field to always be a loop. This resolves issues with
// third party mmCIF parsers such as OpenBabel which always expect a loop
// even when there is only a single atom present.
if (group_info.value_size == 1 && key_prefix != "_atom_site") {
// Plain key-value pairs, output them as they are.
for (const Column& grouped_column : group_info.grouped_columns) {
// Append max_key_length + 1 spaces and copy the key over their start,
// so the key is left-aligned and all values in the group line up.
int width = group_info.max_key_length + 1;
size_t start_pos = output.size();
output.append(width, ' ');
auto out_it = output.begin() + start_pos;
absl::c_copy(grouped_column.key(), out_it);
// Append the value, handle multi-line/quoting.
absl::string_view value = grouped_column.values()->front();
if (grouped_column.has_newlines(0)) {
absl::StrAppend(&output, "\n;", value, "\n;\n"); // Multi-line value.
} else {
const absl::string_view quote_char = grouped_column.quote(0);
absl::StrAppend(&output, quote_char, value, quote_char, "\n");
}
}
} else {
// CIF loop. Output the column names, then the rows with data.
absl::StrAppend(&output, "loop_\n");
for (Column& grouped_column : group_info.grouped_columns) {
absl::StrAppend(&output, grouped_column.key(), "\n");
}
// Write the loop values, line by line. This is the most expensive part
// since this path is taken to write the entire atom site table which has
// about 20 columns, but thousands of rows.
for (int i = 0; i < group_info.value_size; i++) {
for (int column_index = 0;
column_index < group_info.grouped_columns.size(); ++column_index) {
const Column& grouped_column =
group_info.grouped_columns[column_index];
const absl::string_view value = (*grouped_column.values())[i];
if (grouped_column.has_newlines(i)) {
// Multi-line. This is very rarely taken path.
if (column_index == 0) {
// No extra newline before leading ;, already inserted.
absl::StrAppend(&output, ";", value, "\n;\n");
} else if (column_index == group_info.grouped_columns.size() - 1) {
// No extra newline after trailing ;, will be inserted.
absl::StrAppend(&output, "\n;", value, "\n;");
} else {
absl::StrAppend(&output, "\n;", value, "\n;\n");
}
} else {
// Reserve the column width (widest inline value + 1 separator space)
// as spaces, then overwrite its start with the optionally quoted
// value. This keeps the columns aligned without a formatting pass.
size_t start_pos = output.size();
output.append(grouped_column.max_value_length() + 1, ' ');
auto out_it = output.begin() + start_pos;
absl::string_view quote = grouped_column.quote(i);
if (!quote.empty()) {
out_it = absl::c_copy(quote, out_it);
out_it = absl::c_copy(value, out_it);
absl::c_copy(quote, out_it);
} else {
absl::c_copy(value, out_it);
}
}
}
absl::StrAppend(&output, "\n");
}
}
absl::StrAppend(&output, "#\n"); // Comment token after every key group.
}
return output;
}
// Returns the loop whose column keys start with `prefix` as a list of rows,
// each mapping column key -> value. The string_views point into this
// CifDict's data and are only valid while that data is alive.
//
// Returns an empty list if no key matches, and InvalidArgumentError if the
// matching columns do not all have the same number of values.
absl::StatusOr<
    std::vector<absl::flat_hash_map<absl::string_view, absl::string_view>>>
CifDict::ExtractLoopAsList(absl::string_view prefix) const {
  // Refer to the matching columns in place instead of copying them into a
  // temporary table; the column vectors are owned by *dict_.
  std::vector<absl::string_view> column_names;
  std::vector<const std::vector<std::string>*> column_values;
  for (const auto& [name, values] : *dict_) {
    if (absl::StartsWith(name, prefix)) {
      column_names.emplace_back(name);
      column_values.push_back(&values);
    }
  }
  // Make sure all columns have the same number of rows.
  const std::size_t num_rows =
      column_values.empty() ? 0 : column_values.front()->size();
  for (const std::vector<std::string>* column : column_values) {
    if (column->size() != num_rows) {
      return absl::InvalidArgumentError(absl::StrCat(
          GetDataName(),
          ": Columns do not have the same number of rows for prefix: '", prefix,
          "'. One possible reason could be not including the trailing dot, "
          "e.g. '_atom_site.'."));
    }
  }
  std::vector<absl::flat_hash_map<absl::string_view, absl::string_view>> result;
  result.reserve(num_rows);
  for (std::size_t row_index = 0; row_index < num_rows; ++row_index) {
    auto& row_dict = result.emplace_back();
    row_dict.reserve(column_names.size());
    for (std::size_t col_index = 0; col_index < column_names.size();
         ++col_index) {
      row_dict[column_names[col_index]] =
          (*column_values[col_index])[row_index];
    }
  }
  return result;
}
// Returns the loop whose column keys start with `prefix` as a map from the
// value of the `index` column to the full row (column key -> value). All
// string_views point into this CifDict's data. If several rows share an index
// value, the last one wins.
//
// Returns InvalidArgumentError if `index` does not start with `prefix`, if the
// loop's columns have different lengths, or if a row has no `index` column.
absl::StatusOr<absl::flat_hash_map<
    absl::string_view,
    absl::flat_hash_map<absl::string_view, absl::string_view>>>
CifDict::ExtractLoopAsDict(absl::string_view prefix,
                           absl::string_view index) const {
  if (!absl::StartsWith(index, prefix)) {
    return absl::InvalidArgumentError(
        absl::StrCat(GetDataName(), ": The loop index '", index,
                     "' must start with the loop prefix '", prefix, "'."));
  }
  absl::flat_hash_map<absl::string_view,
                      absl::flat_hash_map<absl::string_view, absl::string_view>>
      result;
  auto loop_as_list = ExtractLoopAsList(prefix);
  if (!loop_as_list.ok()) {
    return loop_as_list.status();
  }
  result.reserve(loop_as_list->size());
  for (auto& entry : *loop_as_list) {
    const auto it = entry.find(index);
    if (it == entry.end()) {
      return absl::InvalidArgumentError(absl::StrCat(
          GetDataName(), ": The index column '", index,
          "' could not be found in the loop with prefix '", prefix, "'."));
    }
    // The row key views the CIF data, not `entry`, so it remains valid after
    // `entry` is moved. Moving avoids copying every row's hash map.
    const absl::string_view row_key = it->second;
    result[row_key] = std::move(entry);
  }
  return result;
}
// Public tokenizer: like TokenizeInternal, but returns owned copies of the
// tokens, because the multi-line token storage is local to this call.
absl::StatusOr<std::vector<std::string>> Tokenize(
    absl::string_view cif_string) {
  HeapStrings owned_strings;
  absl::StatusOr<std::vector<absl::string_view>> views =
      TokenizeInternal(cif_string, &owned_strings);
  if (!views.ok()) {
    return views.status();
  }
  // Copy out before `owned_strings` goes out of scope.
  std::vector<std::string> result;
  result.reserve(views->size());
  for (const absl::string_view view : *views) {
    result.emplace_back(view);
  }
  return result;
}
// Tokenizes a single CIF line; the returned views point into `line`. Fails if
// the line ends while a quoted token is still open.
absl::StatusOr<std::vector<absl::string_view>> SplitLine(
    absl::string_view line) {
  std::vector<absl::string_view> tokens;
  const bool quotes_closed = SplitLineInline(line, &tokens);
  if (quotes_closed) {
    return tokens;
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Line ended with quote open: ", line));
}
// Parses CIF text containing one or more data blocks and returns a map from
// each block's name to its parsed CifDict.
//
// A data block starts at a "data_" that begins a token: either at the start of
// the input or immediately after whitespace. Occurrences of "data_" inside
// other tokens (e.g. in a key such as _foo_data_bar or a value such as
// metadata_x) do not start a new block.
//
// Returns InvalidArgumentError if the input does not start with "data_" or if
// any block fails to parse. An empty input yields an empty map. If two blocks
// share a name, the later one wins.
absl::StatusOr<absl::flat_hash_map<std::string, CifDict>> ParseMultiDataCifDict(
    absl::string_view cif_string) {
  absl::flat_hash_map<std::string, CifDict> mapping;
  constexpr absl::string_view delimitor = "data_";
  // Check cif_string starts with correct offset.
  if (!cif_string.empty() && !absl::StartsWith(cif_string, delimitor)) {
    return absl::InvalidArgumentError(
        "Invalid format. MultiDataCifDict must start with 'data_'");
  }
  // Invariant: cif_string[block_start] is the start of a "data_" header.
  size_t block_start = 0;
  while (block_start < cif_string.size()) {
    // Find the next header: the next "data_" preceded by whitespace. Searching
    // past the current header guarantees progress and keeps next_start >= 1.
    size_t next_start =
        cif_string.find(delimitor, block_start + delimitor.size());
    while (next_start != absl::string_view::npos &&
           !absl::ascii_isspace(cif_string[next_start - 1])) {
      next_start = cif_string.find(delimitor, next_start + 1);
    }
    if (next_start == absl::string_view::npos) {
      next_start = cif_string.size();
    }
    absl::StatusOr<CifDict> parsed_block = CifDict::FromString(
        cif_string.substr(block_start, next_start - block_start));
    if (!parsed_block.ok()) {
      return parsed_block.status();
    }
    absl::string_view data_name = parsed_block->GetDataName();
    mapping[data_name] = *std::move(parsed_block);
    block_start = next_start;
  }
  return mapping;
}
} // namespace alphafold3
================================================
FILE: src/alphafold3/parsers/cpp/cif_dict_lib.h
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
// A C++ implementation of a CIF parser. For the format specification see
// https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax
#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_
#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_
#include
#include
#include
#include
#include
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
namespace alphafold3 {
// A single CIF data block: maps each CIF key (e.g. "_atom_site.id") to the
// list of its values; the block name is stored under the special "data_" key.
// The map is held through a shared_ptr, so copying a CifDict is cheap and the
// copies share one map. No method of this class modifies the map;
// CopyAndUpdate returns a new CifDict instead.
class CifDict {
public:
// Use absl::node_hash_map since it guarantees pointer stability.
using Dict = absl::node_hash_map>;
// NOTE(review): the defaulted constructor leaves dict_ null, and the
// accessors below dereference dict_ without a null check, so a
// default-constructed CifDict must be assigned to before it is used.
CifDict() = default;
// Takes ownership of `dict`.
explicit CifDict(Dict dict)
: dict_(std::make_shared(std::move(dict))) {}
// Converts a CIF string into a dictionary mapping each CIF field to a list of
// values that field contains.
static absl::StatusOr FromString(absl::string_view cif_string);
// Converts the CIF into a string that is a valid CIF file.
absl::StatusOr ToString() const;
// Extracts loop associated with a prefix from mmCIF data as a list.
// Reference for loop_ in mmCIF:
// http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
// Args:
// prefix: Prefix shared by each of the data items in the loop.
// e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
// _entity_poly_seq.mon_id. Should include the trailing period.
//
// Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
// Lifetime of string_views tied to this.
absl::StatusOr<
std::vector>>
ExtractLoopAsList(absl::string_view prefix) const;
// Extracts loop associated with a prefix from mmCIF data as a dictionary.
// Args:
// prefix: Prefix shared by each of the data items in the loop.
// e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
// _entity_poly_seq.mon_id. Should include the trailing period.
// index: Which item of loop data should serve as the key.
//
// Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
// indexed by the index column.
// Lifetime of string_views tied to this.
absl::StatusOr>>
ExtractLoopAsDict(absl::string_view prefix, absl::string_view index) const;
// Returns value at key if present or an empty list.
absl::Span operator[](absl::string_view key) const {
auto it = dict_->find(key);
if (it != dict_->end()) {
return it->second;
}
return {};
}
// Returns boolean of whether dict contains key.
bool Contains(absl::string_view key) const { return dict_->contains(key); }
// Returns number of values for the given key if present, 0 otherwise.
size_t ValueLength(absl::string_view key) const {
return (*this)[key].size();
}
// Returns the size of the underlying dictionary.
// NOTE(review): unlike the other accessors, this is not const-qualified.
std::size_t Length() { return dict_->size(); }
// Creates a copy of this CifDict object that will contain the original values
// but only if not updated by the given dictionary.
// E.g. if the CifDict = {a: [a1, a2], b: [b1]} and other = {a: [x], c: [z]},
// you will get {a: [x], b: [b1], c: [z]}.
// (insert() skips keys that are already in `other`, which is why the values
// from `other` take precedence.)
CifDict CopyAndUpdate(Dict other) const {
other.insert(dict_->begin(), dict_->end());
return CifDict(std::move(other));
}
// Returns the value of the special CIF data_ field, or "" if it is absent
// or empty.
absl::string_view GetDataName() const {
// The data_ element has to be present by construction.
if (auto it = dict_->find("data_");
it != dict_->end() && !it->second.empty()) {
return it->second.front();
} else {
return "";
}
}
// Returns the shared underlying dictionary.
const std::shared_ptr& dict() const { return dict_; }
private:
std::shared_ptr dict_;
};
// Tokenizes a CIF string into a list of string tokens. This is more involved
// than just a simple split on whitespace as CIF allows comments and quoting.
// Returns an InvalidArgumentError on malformed input, e.g. a line that ends
// with a quote still open.
absl::StatusOr> Tokenize(absl::string_view cif_string);
// Tokenizes a single line of a CIF string.
// Returns an InvalidArgumentError if the line ends with a quote still open.
absl::StatusOr> SplitLine(
absl::string_view line);
// Parses a CIF string with multiple data records and returns a mapping from
// record names to CifDict objects. For instance, the following CIF string:
//
// data_001
// _foo bar
//
// data_002
// _foo baz
//
// will be parsed as:
// {'001': CifDict({'_foo': ['bar']}),
// '002': CifDict({'_foo': ['baz']})}
// (each CifDict additionally holds its record name under the special data_
// key).
//
// Returns an InvalidArgumentError if the string does not start with "data_"
// or if any of the data records fails to parse.
absl::StatusOr> ParseMultiDataCifDict(
absl::string_view cif_string);
} // namespace alphafold3
#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_
================================================
FILE: src/alphafold3/parsers/cpp/cif_dict_pybind.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "numpy/ndarrayobject.h"
#include "numpy/ndarraytypes.h"
#include "numpy/npy_common.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "alphafold3/parsers/cpp/cif_dict_lib.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/gil.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
namespace alphafold3 {
namespace {
namespace py = pybind11;
template
bool GatherArray(size_t num_dims, npy_intp* shape_array, npy_intp* stride_array,
const char* data, absl::Span values,
ForEach&& for_each_cb) {
if (num_dims == 1) {
const npy_intp shape = shape_array[0];
const npy_intp stride = stride_array[0];
for (size_t i = 0; i < shape; ++i) {
Item index;
std::memcpy(&index, data + stride * i, sizeof(Item));
if (index < 0 || index >= values.size()) {
PyErr_SetString(PyExc_IndexError,
absl::StrCat("index ", index,
" is out of bounds for column with size ",
values.size())
.c_str());
return false;
}
if (!for_each_cb(values[index])) {
return false;
}
}
} else if (num_dims == 0) {
Item index;
std::memcpy(&index, data, sizeof(Item));
if (index < 0 || index >= values.size()) {
PyErr_SetString(
PyExc_IndexError,
absl::StrCat("index ", index,
" is out of bounds for column with size ", values.size())
.c_str());
return false;
}
if (!for_each_cb(values[index])) {
return false;
}
} else {
const npy_intp shape = shape_array[0];
const npy_intp stride = stride_array[0];
for (size_t i = 0; i < shape; ++i) {
if (!GatherArray- (num_dims - 1, shape_array + 1, stride_array + 1,
data + stride * i, values, for_each_cb)) {
return false;
}
}
}
return true;
}
template
bool Gather(PyObject* gather, absl::Span values,
Size&& size_cb, ForEach&& for_each_cb) {
if (gather == Py_None) {
npy_intp dim = static_cast(values.size());
if (!size_cb(absl::MakeSpan(&dim, 1))) {
return false;
}
for (const std::string& v : values) {
if (!for_each_cb(v)) {
return false;
}
}
return true;
}
if (PySlice_Check(gather)) {
Py_ssize_t start, stop, step, slice_length;
if (PySlice_GetIndicesEx(gather, values.size(), &start, &stop, &step,
&slice_length) != 0) {
return false;
}
npy_intp dim = static_cast(slice_length);
if (!size_cb(absl::MakeSpan(&dim, 1))) {
return false;
}
for (size_t i = 0; i < slice_length; ++i) {
if (!for_each_cb(values[start + i * step])) {
return false;
}
}
return true;
}
if (PyArray_Check(gather)) {
PyArrayObject* gather_array = reinterpret_cast(gather);
auto shape =
absl::MakeSpan(PyArray_DIMS(gather_array), PyArray_NDIM(gather_array));
switch (PyArray_TYPE(gather_array)) {
case NPY_INT16:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
case NPY_UINT16:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
case NPY_INT32:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
case NPY_UINT32:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
case NPY_INT64:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
case NPY_UINT64:
if (!size_cb(shape)) {
return false;
}
return GatherArray(shape.size(), shape.data(),
PyArray_STRIDES(gather_array),
PyArray_BYTES(gather_array), values,
std::forward(for_each_cb));
default:
PyErr_SetString(PyExc_TypeError, "Unsupported NumPy array type.");
return false;
}
}
PyErr_Format(PyExc_TypeError, "Invalid gather %R", gather);
return false;
}
// Creates a NumPy object array holding Python str objects for the gathered
// values. Equal strings share a single interned str object.
//
// Takes ownership of the `type` reference in all cases. Returns a new
// reference, or nullptr with a Python exception set (e.g. UnicodeDecodeError
// for a value that is not valid UTF-8).
PyObject* ConvertStrings(PyObject* gather, PyArray_Descr* type,
                         absl::Span<const std::string> values) {
  absl::flat_hash_map<absl::string_view, PyObject*> existing;
  PyObject* ret = nullptr;
  PyObject** dst = nullptr;
  // PyArray_NewFromDescr steals `type` (even when it fails). If Gather fails
  // before reaching it, the reference has to be released here instead.
  bool type_consumed = false;
  const bool ok = Gather(
      gather, values,
      [&dst, &ret, &type_consumed, type](absl::Span<npy_intp> size) {
        type_consumed = true;
        ret = PyArray_NewFromDescr(
            /*subtype=*/&PyArray_Type,
            /*type=*/type,
            /*nd=*/size.size(),
            /*dims=*/size.data(),
            /*strides=*/nullptr,
            /*data=*/nullptr,
            /*flags=*/0,
            /*obj=*/nullptr);
        if (ret == nullptr) {
          return false;  // NumPy has set the Python error.
        }
        dst = static_cast<PyObject**>(
            PyArray_DATA(reinterpret_cast<PyArrayObject*>(ret)));
        return true;
      },
      [&dst, &existing](absl::string_view value) {
        auto [it, inserted] = existing.emplace(value, nullptr);
        if (inserted) {
          PyObject* str =
              PyUnicode_FromStringAndSize(value.data(), value.size());
          if (str == nullptr) {
            // Python error (e.g. invalid UTF-8) is already set. The cache is
            // discarded on failure, so the nullptr placeholder is harmless.
            return false;
          }
          PyUnicode_InternInPlace(&str);
          it->second = str;
        } else {
          Py_INCREF(it->second);
        }
        // The array slot owns one reference to the (possibly shared) string.
        *dst++ = it->second;
        return true;
      });
  if (ok) {
    return ret;
  }
  Py_XDECREF(ret);
  if (!type_consumed) {
    Py_DECREF(type);
  }
  return nullptr;
}
// Creates NumPy array with given dtype given specified converter.
// `converter` shall have the following signature:
// bool converter(const std::string& value, T* result);
// It must return whether conversion is successful and store conversion in
// result.
template
inline PyObject* Convert(PyObject* gather, PyArray_Descr* type,
absl::Span values, C&& converter) {
py::object ret;
T* dst;
if (Gather(
gather, values,
[&dst, &ret, type](absl::Span size) {
// Construct uninitialised NumPy array of type T.
ret = py::reinterpret_steal(PyArray_NewFromDescr(
/*subtype=*/&PyArray_Type,
/*type=*/type,
/*nd=*/size.size(),
/*dims=*/size.data(),
/*strides=*/nullptr,
/*data=*/nullptr,
/*flags=*/0,
/*obj=*/nullptr));
dst = static_cast(
PyArray_DATA(reinterpret_cast(ret.ptr())));
return true;
},
[&dst, &converter](const std::string& value) {
if (!converter(value, dst++)) {
PyErr_SetString(PyExc_ValueError, value.c_str());
return false;
}
return true;
})) {
return ret.release().ptr();
}
return nullptr;
}
// Converts the column stored under `key` into a NumPy array of `dtype` (None
// means an object array of Python str), selecting elements according to
// `gather` (None, a slice or an integer index array). For float dtypes the CIF
// value "." becomes NaN; bool accepts "y"/"yes"/"n"/"no". Integer values must
// fit the target type. Returns a new reference, or nullptr with a Python
// exception set (KeyError for a missing key, TypeError for an unsupported
// dtype, ValueError for an unconvertible value).
PyObject* CifDictGetArray(const CifDict& self, absl::string_view key,
                          PyObject* dtype, PyObject* gather) {
  import_array();
  PyArray_Descr* type = nullptr;
  if (dtype == Py_None) {
    type = PyArray_DescrFromType(NPY_OBJECT);
  } else if (PyArray_DescrConverter(dtype, &type) == NPY_FAIL || !type) {
    PyErr_Format(PyExc_TypeError, "Invalid dtype %R", dtype);
    Py_XDECREF(type);
    return nullptr;
  }
  auto entry = self.dict()->find(key);
  if (entry == self.dict()->end()) {
    Py_DECREF(type);
    // PyErr_SetObject does not steal the reference, so release it afterwards.
    // If creating the key object fails, that error is already set.
    PyObject* py_key = PyUnicode_FromStringAndSize(key.data(), key.size());
    if (py_key != nullptr) {
      PyErr_SetObject(PyExc_KeyError, py_key);
      Py_DECREF(py_key);
    }
    return nullptr;
  }
  auto int_convert = [](absl::string_view str, auto* value) {
    return absl::SimpleAtoi(str, value);
  };
  // For types narrower than SimpleAtoi supports: parse as int64 and range-check.
  auto int_convert_bounded = [](absl::string_view str, auto* value) {
    int64_t v;
    if (absl::SimpleAtoi(str, &v)) {
      using limits =
          std::numeric_limits<std::remove_pointer_t<decltype(value)>>;
      if (limits::min() <= v && v <= limits::max()) {
        *value = v;
        return true;
      }
    }
    return false;
  };
  absl::Span<const std::string> values = entry->second;
  switch (type->type_num) {
    case NPY_DOUBLE:
      return Convert<double>(
          gather, type, values, [](absl::string_view str, double* value) {
            if (str == ".") {
              *value = std::numeric_limits<double>::quiet_NaN();
              return true;
            }
            return absl::SimpleAtod(str, value);
          });
    case NPY_FLOAT:
      return Convert<float>(
          gather, type, values, [](absl::string_view str, float* value) {
            if (str == ".") {
              *value = std::numeric_limits<float>::quiet_NaN();
              return true;
            }
            return absl::SimpleAtof(str, value);
          });
    case NPY_INT8:
      return Convert<int8_t>(gather, type, values, int_convert_bounded);
    case NPY_INT16:
      return Convert<int16_t>(gather, type, values, int_convert_bounded);
    case NPY_INT32:
      return Convert<int32_t>(gather, type, values, int_convert);
    case NPY_INT64:
      return Convert<int64_t>(gather, type, values, int_convert);
    case NPY_UINT8:
      return Convert<uint8_t>(gather, type, values, int_convert_bounded);
    case NPY_UINT16:
      return Convert<uint16_t>(gather, type, values, int_convert_bounded);
    case NPY_UINT32:
      return Convert<uint32_t>(gather, type, values, int_convert);
    case NPY_UINT64:
      return Convert<uint64_t>(gather, type, values, int_convert);
    case NPY_BOOL:
      return Convert<bool>(gather, type, values,
                           [](absl::string_view str, bool* value) {
                             if (str == "n" || str == "no") {
                               *value = false;
                               return true;
                             }
                             if (str == "y" || str == "yes") {
                               *value = true;
                               return true;
                             }
                             return false;
                           });
    case NPY_OBJECT:
      return ConvertStrings(gather, type, values);
    default: {
      PyErr_Format(PyExc_TypeError, "Unsupported dtype %R", dtype);
      Py_XDECREF(type);
      return nullptr;
    }
  }
}
} // namespace
// Registers the CifDict Python bindings on module `m`.
//
// Module-level functions: from_string, tokenize, split_line and
// parse_multi_data_cif. Each converts a non-OK absl::Status into a Python
// ValueError. Also registers the CifDict class (dict-like read API, get_array,
// loop-extraction helpers, pickling) and the nested KeyView, ValueView and
// ItemView classes returned by keys()/values()/items().
void RegisterModuleCifDict(pybind11::module m) {
// Element type returned by __getitem__ (see its `const Value&` return below).
using Value = std::vector;
// NOTE(review): `empty_values` is not referenced anywhere in this function;
// it looks like dead code — confirm and remove.
static absl::NoDestructor> empty_values;
// from_string(s) -> CifDict. Parse failures are raised as ValueError. The
// parse runs under a py::call_guard.
m.def(
"from_string",
[](absl::string_view s) {
absl::StatusOr dict = CifDict::FromString(s);
if (!dict.ok()) {
throw py::value_error(dict.status().ToString());
}
return *dict;
},
py::call_guard());
// tokenize(cif_string) -> tokens produced by Tokenize(); errors -> ValueError.
m.def(
"tokenize",
[](absl::string_view cif_string) {
absl::StatusOr> tokens = Tokenize(cif_string);
if (!tokens.ok()) {
throw py::value_error(tokens.status().ToString());
}
return *std::move(tokens);
},
py::arg("cif_string"));
// split_line(line) -> tokens of a single line via SplitLine(); errors ->
// ValueError.
m.def("split_line", [](absl::string_view line) {
absl::StatusOr> tokens = SplitLine(line);
if (!tokens.ok()) {
throw py::value_error(tokens.status().ToString());
}
return *std::move(tokens);
});
// parse_multi_data_cif(cif_string) -> Python dict built from the (key, value)
// entries returned by ParseMultiDataCifDict(), each converted with py::cast.
// Errors -> ValueError.
m.def(
"parse_multi_data_cif",
[](absl::string_view cif_string) {
auto result = ParseMultiDataCifDict(cif_string);
if (!result.ok()) {
throw py::value_error(result.status().ToString());
}
py::dict dict;
for (auto& [key, value] : *result) {
dict[py::cast(key)] = py::cast(value);
}
return dict;
},
py::arg("cif_string"));
// The CifDict class. The constructor, copy_and_update and __setstate__ all
// convert each Python key/value pair with py::cast into a CifDict::Dict entry.
auto cif_dict =
py::class_(m, "CifDict")
.def(py::init<>([](py::dict dict) {
CifDict::Dict result;
for (const auto& [key, value] : dict) {
result.emplace(py::cast(key),
py::cast>(value));
}
return CifDict(std::move(result));
}),
"Initialise with a map")
// Conversion from Python objects happens with the GIL held; only the
// C++ CopyAndUpdate() call runs with the GIL released.
.def("copy_and_update",
[](const CifDict& self, py::dict dict) {
CifDict::Dict result;
for (const auto& [key, value] : dict) {
result.emplace(py::cast(key),
py::cast>(value));
}
{
py::gil_scoped_release gil_release;
return self.CopyAndUpdate(std::move(result));
}
})
// __str__ and to_string share the same body: serialise via ToString(),
// raising ValueError on failure.
.def(
"__str__",
[](const CifDict& self) {
absl::StatusOr result = self.ToString();
if (!result.ok()) {
throw py::value_error(result.status().ToString());
}
return *result;
},
"Serialize to a string", py::call_guard())
.def(
"to_string",
[](const CifDict& self) {
absl::StatusOr result = self.ToString();
if (!result.ok()) {
throw py::value_error(result.status().ToString());
}
return *result;
},
"Serialize to a string", py::call_guard())
.def(
"to_dict",
[](const CifDict& self) {
py::dict result;
for (const auto& [key, value] : *self.dict()) {
result[py::cast(key)] = py::cast(value);
}
return result;
},
"Returns the CIF data as a Python dict[str, list[str]].")
// value_length(key): delegates to CifDict::ValueLength.
.def("value_length", &CifDict::ValueLength, py::arg("key"),
"Num elements in value")
// __len__ is the number of keys; __bool__ is True when there is any key.
.def("__len__",
[](const CifDict& self) { return self.dict()->size(); })
.def(
"__bool__",
[](const CifDict& self) { return !self.dict()->empty(); },
"Check whether the map is nonempty")
// Key membership. Only a string_view overload is registered here, so
// unlike KeyView.__contains__ there is no fallback overload for
// non-string keys.
.def(
"__contains__",
[](const CifDict& self, absl::string_view k) {
return self.dict()->find(k) != self.dict()->end();
},
py::arg("key"), py::call_guard())
.def("get_data_name", &CifDict::GetDataName)
// get(key, default_value=None): a fresh Python list copy of the values
// for `key`, or `default_value` when the key is absent.
.def(
"get",
[](const CifDict& self, absl::string_view k,
py::object default_value) -> py::object {
auto it = self.dict()->find(k);
if (it == self.dict()->end()) return default_value;
py::list result(it->second.size());
size_t index = 0;
for (const std::string& v : it->second) {
result[index++] = py::cast(v);
}
return result;
},
py::arg("key"), py::arg("default_value") = py::none())
// get_array(key, dtype=None, gather=None): delegates to CifDictGetArray.
// A nullptr result means a Python error has already been set (for
// example its unsupported-dtype branch calls PyErr_Format), so it is
// re-raised via error_already_set. A non-null result is treated as an
// owned reference and adopted with reinterpret_steal.
.def(
"get_array",
[](const CifDict& self, absl::string_view key, py::handle dtype,
py::handle gather) -> py::object {
PyObject* obj =
CifDictGetArray(self, key, dtype.ptr(), gather.ptr());
if (obj == nullptr) {
throw py::error_already_set();
}
return py::reinterpret_steal(obj);
},
py::arg("key"), py::arg("dtype") = py::none(),
py::arg("gather") = py::none())
// __getitem__(key): the values stored for `key`. A missing key raises
// KeyError whose message is the key text.
.def(
"__getitem__",
[](const CifDict& self, absl::string_view k) -> const Value& {
auto it = self.dict()->find(k);
if (it == self.dict()->end()) {
throw py::key_error(std::string(k).c_str());
}
return it->second;
},
py::arg("key"), py::call_guard())
// extract_loop_as_dict(prefix, index): the C++ extraction runs with the
// GIL released. A failure is thrown inside that scope; unwinding runs
// the gil_scoped_release destructor, which reacquires the GIL before
// pybind11 translates the exception. The result is converted into a
// dict of dicts with the GIL held.
.def(
"extract_loop_as_dict",
[](const CifDict& self, absl::string_view prefix,
absl::string_view index) {
absl::StatusOr>>
dict;
{
py::gil_scoped_release gil_release;
dict = self.ExtractLoopAsDict(prefix, index);
if (!dict.ok()) {
throw py::value_error(dict.status().ToString());
}
}
py::dict key_value_dict;
for (const auto& [key, value] : *dict) {
py::dict value_dict;
for (const auto& [key2, value2] : value) {
value_dict[py::cast(key2)] = py::cast(value2);
}
key_value_dict[py::cast(key)] = std::move(value_dict);
}
return key_value_dict;
},
py::arg("prefix"), py::arg("index"))
// extract_loop_as_list(prefix): same GIL handling as
// extract_loop_as_dict; returns a list with one dict per loop row.
// NOTE(review): in the inner loop the structured binding `value` shadows
// the outer loop variable `value`. The range expression is evaluated
// before the binding is introduced, so it still iterates the outer
// `value` and the code is correct. Renaming one of them would avoid
// -Wshadow warnings and reader confusion.
.def(
"extract_loop_as_list",
[](const CifDict& self, absl::string_view prefix) {
absl::StatusOr>>
list_dict;
{
py::gil_scoped_release gil_release;
list_dict = self.ExtractLoopAsList(prefix);
if (!list_dict.ok()) {
throw py::value_error(list_dict.status().ToString());
}
}
py::list list_obj(list_dict->size());
size_t index = 0;
for (const auto& value : *list_dict) {
py::dict value_dict;
for (const auto& [key, value] : value) {
value_dict[py::cast(key)] = py::cast(value);
}
list_obj[index++] = std::move(value_dict);
}
return list_obj;
},
py::arg("prefix"))
// Pickle support. __getstate__ returns a 1-tuple holding a dict of
// key -> values; __setstate__ rebuilds a CifDict from element 0.
.def(py::pickle(
[](const CifDict& self) { // __getstate__.
py::tuple result_tuple(1);
py::dict result;
for (const auto& [key, value] : *self.dict()) {
result[py::cast(key)] = py::cast(value);
}
result_tuple[0] = std::move(result);
return result_tuple;
},
[](py::tuple t) { // __setstate__.
py::dict dict = t[0].cast();
CifDict::Dict result;
for (const auto& [key, value] : dict) {
result.emplace(py::cast(key),
py::cast>(value));
}
return CifDict(std::move(result));
}));
// Item, value, and key views
// Each view stores its own CifDict object (`map`). Whether that copy shares
// the underlying data depends on CifDict's copy semantics, which are not
// visible here.
struct KeyView {
CifDict map;
};
struct ValueView {
CifDict map;
};
struct ItemView {
CifDict map;
};
// ItemView: iterates (key, values) pairs. keep_alive<0, 1> keeps the view
// alive for as long as the returned iterator exists.
py::class_(cif_dict, "ItemView")
.def("__len__", [](const ItemView& v) { return v.map.dict()->size(); })
.def(
"__iter__",
[](const ItemView& v) {
return py::make_iterator(v.map.dict()->begin(),
v.map.dict()->end());
},
py::keep_alive<0, 1>());
// KeyView: the first __contains__ overload handles string keys. pybind11
// tries overloads in registration order, so any other argument type falls
// through to the second overload, which returns False.
py::class_(cif_dict, "KeyView")
.def(
"__contains__",
[](const KeyView& v, absl::string_view k) {
return v.map.dict()->find(k) != v.map.dict()->end();
},
py::call_guard())
.def("__contains__", [](const KeyView&, py::handle) { return false; })
.def("__len__", [](const KeyView& v) { return v.map.dict()->size(); })
.def(
"__iter__",
[](const KeyView& v) {
return py::make_key_iterator(v.map.dict()->begin(),
v.map.dict()->end());
},
py::keep_alive<0, 1>());
// ValueView: iterates the per-key value lists.
py::class_(cif_dict, "ValueView")
.def("__len__", [](const ValueView& v) { return v.map.dict()->size(); })
.def(
"__iter__",
[](const ValueView& v) {
return py::make_value_iterator(v.map.dict()->begin(),
v.map.dict()->end());
},
py::keep_alive<0, 1>());
// Iterating a CifDict directly yields its keys. keys()/values()/items()
// return the view objects defined above.
cif_dict
.def(
"__iter__",
[](const CifDict& self) {
return py::make_key_iterator(self.dict()->begin(),
self.dict()->end());
},
py::keep_alive<0, 1>())
.def(
"keys", [](const CifDict& self) { return KeyView{self}; },
"Returns an iterable view of the map's keys.")
.def(
"values", [](const CifDict& self) { return ValueView{self}; },
"Returns an iterable view of the map's values.")
.def(
"items", [](const CifDict& self) { return ItemView{self}; },
"Returns an iterable view of the map's items.");
}
} // namespace alphafold3
================================================
FILE: src/alphafold3/parsers/cpp/cif_dict_pybind.h
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
#include "pybind11/pybind11.h"
namespace alphafold3 {
// Registers the CifDict Python bindings on the pybind11 module `m`. These are
// the module-level CIF parsing functions (e.g. from_string, tokenize) and the
// CifDict class with its key/value/item view helpers.
void RegisterModuleCifDict(pybind11::module m);
}
#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
================================================
FILE: src/alphafold3/parsers/cpp/fasta_iterator.pyi
================================================
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
# Iterator over the records of the FASTA file at `fasta_path`.
# Presumably each item is a (sequence, description) tuple, mirroring the pair
# order returned by the C++ FastaFileIterator::Next — confirm against the
# binding.
class FastaFileIterator:
  def __init__(self, fasta_path: str) -> None: ...
  def __iter__(self) -> FastaFileIterator: ...
  def __next__(self) -> tuple[str,str]: ...
# Iterator over the records of an in-memory FASTA string (str or bytes).
# Presumably each item is a (sequence, description) tuple, mirroring the pair
# order returned by the C++ FastaStringIterator::Next — confirm against the
# binding.
class FastaStringIterator:
  def __init__(self, fasta_string: str | bytes) -> None: ...
  def __iter__(self) -> FastaStringIterator: ...
  def __next__(self) -> tuple[str,str]: ...
# Returns only the sequences contained in `fasta_string`, one per '>' record,
# in input order; header text is not returned. Presumably backed by the C++
# ParseFasta — confirm against the binding.
def parse_fasta(fasta_string: str | bytes) -> list[str]: ...
# Returns (sequences, descriptions): two parallel lists where index i holds the
# sequence and the header text of the i-th '>' record. Presumably backed by the
# C++ ParseFastaIncludeDescriptions — confirm against the binding.
def parse_fasta_include_descriptions(fasta_string: str | bytes) -> tuple[list[str],list[str]]: ...
================================================
FILE: src/alphafold3/parsers/cpp/fasta_iterator_lib.cc
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
#include "alphafold3/parsers/cpp/fasta_iterator_lib.h"
#include
#include
#include
#include
#include
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
namespace alphafold3 {
// Extracts the sequences from a FASTA-formatted string.
//
// Every line is stripped of surrounding ASCII whitespace. A line beginning with
// '>' starts a new, initially empty, record. Each later non-empty line is
// concatenated onto the most recent record. Text before the first '>' header is
// ignored, and header text is discarded. Only the sequences are returned, in
// input order.
std::vector<std::string> ParseFasta(absl::string_view fasta_string) {
  std::vector<std::string> records;
  for (absl::string_view raw : absl::StrSplit(fasta_string, '\n')) {
    absl::string_view stripped = absl::StripAsciiWhitespace(raw);
    if (absl::ConsumePrefix(&stripped, ">")) {
      records.emplace_back();
      continue;
    }
    // Blank lines contribute nothing; residues seen before any header have no
    // record to belong to.
    if (stripped.empty() || records.empty()) continue;
    absl::StrAppend(&records.back(), stripped);
  }
  return records;
}
// Parse FASTA string and return list of strings with amino acid sequences.
// Returns two lists: The first one with amino acid sequences, the second with
// the descriptions associated with each sequence.
std::pair, std::vector>
ParseFastaIncludeDescriptions(absl::string_view fasta_string) {
std::pair, std::vector> result;
auto& [sequences, descriptions] = result;
std::string* sequence = nullptr;
for (absl::string_view line_raw : absl::StrSplit(fasta_string, '\n')) {
absl::string_view line = absl::StripAsciiWhitespace(line_raw);
if (absl::ConsumePrefix(&line, ">")) {
descriptions.emplace_back(line);
sequence = &sequences.emplace_back();
} else if (!line.empty() && sequence != nullptr) {
absl::StrAppend(sequence, line);
}
}
return result;
}
// Returns the next (sequence, description) record from the FASTA file.
//
// Lines are read one at a time and stripped of surrounding ASCII whitespace.
// A line starting with '>' is a header, and its remaining text becomes the
// record description. Every other line seen after the first header is appended
// to the current sequence; lines before the first header are ignored. When a
// second header is reached, the buffered record is returned and the new header
// becomes the pending description.
//
// On end of file, has_next_ is cleared, the reader is closed, and the last
// buffered record is returned. If no header was ever seen, an
// InvalidArgumentError naming the file is returned instead.
absl::StatusOr<std::pair<std::string, std::string>> FastaFileIterator::Next() {
  std::string line_str;
  while (std::getline(reader_, line_str)) {
    absl::string_view line = line_str;
    line = absl::StripAsciiWhitespace(line);
    if (absl::ConsumePrefix(&line, ">")) {
      if (!description_.has_value()) {
        description_ = line;
      } else {
        // Hand the completed record to the caller by moving its buffers out
        // instead of copying them; both are overwritten immediately anyway.
        std::pair output(std::move(sequence_), std::move(*description_));
        description_ = line;
        sequence_.clear();  // Moved-from: reset to a known empty state.
        return output;
      }
    } else if (description_.has_value()) {
      absl::StrAppend(&sequence_, line);
    }
  }
  has_next_ = false;
  reader_.close();
  if (description_.has_value()) {
    // Copied (not moved) so that a repeated call after exhaustion keeps
    // returning the same final record, as before.
    return std::pair(sequence_, *description_);
  } else {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid FASTA file: ", filename_));
  }
}
// Returns the next (sequence, description) record from the FASTA string.
//
// Same record rules as FastaFileIterator::Next: lines are whitespace-stripped,
// '>' starts a header, and later lines are appended to the current sequence.
// When the next header is reached, the completed record is returned and every
// line up to and including that header is removed from `fasta_string_`, so the
// following call resumes after it.
//
// When the input is exhausted, has_next_ is cleared and the last buffered
// record is returned. If no header was ever seen, an InvalidArgumentError is
// returned.
absl::StatusOr<std::pair<std::string, std::string>>
FastaStringIterator::Next() {
  size_t consumed = 0;
  for (absl::string_view line_raw : absl::StrSplit(fasta_string_, '\n')) {
    consumed += line_raw.size() + 1;  // +1 for the newline character.
    absl::string_view line = absl::StripAsciiWhitespace(line_raw);
    if (absl::ConsumePrefix(&line, ">")) {
      if (!description_.has_value()) {
        description_ = line;
      } else {
        // Move the completed record out instead of copying; both buffers are
        // reset right after.
        std::pair output(std::move(sequence_), std::move(*description_));
        description_ = line;
        sequence_.clear();  // Moved-from: reset to a known empty state.
        // If this header is the final line and has no trailing '\n', the "+1"
        // above overshoots by one. remove_prefix requires n <= size(), so
        // clamp.
        fasta_string_.remove_prefix(consumed < fasta_string_.size()
                                        ? consumed
                                        : fasta_string_.size());
        return output;
      }
    } else if (description_.has_value()) {
      absl::StrAppend(&sequence_, line);
    }
  }
  has_next_ = false;
  if (description_.has_value()) {
    // Copied (not moved) so a repeated call after exhaustion behaves as before.
    return std::pair(sequence_, *description_);
  } else {
    return absl::InvalidArgumentError("Invalid FASTA string");
  }
}
} // namespace alphafold3
================================================
FILE: src/alphafold3/parsers/cpp/fasta_iterator_lib.h
================================================
// Copyright 2024 DeepMind Technologies Limited
//
// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
//
// To request access to the AlphaFold 3 model parameters, follow the process set
// out at https://github.com/google-deepmind/alphafold3. You may only use these
// if received directly from Google. Use is subject to terms of use available at
// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
// A C++ implementation of a FASTA parser.
#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_
#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_
#include
#include
#include
#include
#include