Copy disabled (too large)
Download .txt
Showing preview only (12,506K chars total). Download the full file to get everything.
Repository: zhonge/cryodrgn
Branch: main
Commit: cb28f71b32a9
Files: 234
Total size: 11.9 MB
Directory structure:
gitextract_tk31jno4/
├── .flake8
├── .github/
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ └── workflows/
│ ├── beta_release.yml
│ ├── release.yml
│ ├── style.yml
│ └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── analysis_scripts/
│ ├── kmeans.py
│ ├── plot_loss.py
│ ├── plot_z1.py
│ ├── plot_z2.py
│ ├── plot_z_pca.py
│ ├── run_umap.py
│ └── tsne.py
├── cryodrgn/
│ ├── __init__.py
│ ├── analysis.py
│ ├── analysis_drgnai.py
│ ├── beta_schedule.py
│ ├── command_line.py
│ ├── commands/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── abinit.py
│ │ ├── abinit_het_old.py
│ │ ├── abinit_homo_old.py
│ │ ├── analyze.py
│ │ ├── analyze_landscape.py
│ │ ├── analyze_landscape_full.py
│ │ ├── backproject_voxel.py
│ │ ├── dashboard.py
│ │ ├── direct_traversal.py
│ │ ├── downsample.py
│ │ ├── eval_images.py
│ │ ├── eval_vol.py
│ │ ├── filter.py
│ │ ├── graph_traversal.py
│ │ ├── parse_ctf_csparc.py
│ │ ├── parse_ctf_star.py
│ │ ├── parse_pose_csparc.py
│ │ ├── parse_pose_star.py
│ │ ├── parse_star.py
│ │ ├── pc_traversal.py
│ │ ├── train_dec.py
│ │ ├── train_nn.py
│ │ └── train_vae.py
│ ├── commands_utils/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── add_psize.py
│ │ ├── analyze_convergence.py
│ │ ├── clean.py
│ │ ├── concat_pkls.py
│ │ ├── filter_cs.py
│ │ ├── filter_mrcs.py
│ │ ├── filter_pkl.py
│ │ ├── filter_star.py
│ │ ├── flip_hand.py
│ │ ├── fsc.py
│ │ ├── gen_mask.py
│ │ ├── invert_contrast.py
│ │ ├── make_movies.py
│ │ ├── parse_relion.py
│ │ ├── phase_flip.py
│ │ ├── plot_classes.py
│ │ ├── plot_fsc.py
│ │ ├── select_clusters.py
│ │ ├── select_random.py
│ │ ├── translate_mrcs.py
│ │ ├── view_cs_header.py
│ │ ├── view_header.py
│ │ ├── view_mrcs.py
│ │ ├── write_cs.py
│ │ └── write_star.py
│ ├── config.py
│ ├── ctf.py
│ ├── dashboard/
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── bench_plot_interfaces.py
│ │ ├── command_builder_cli_help.py
│ │ ├── command_builder_data.py
│ │ ├── context.py
│ │ ├── data.py
│ │ ├── explorer_volumes.py
│ │ ├── mpl_style.py
│ │ ├── plots.py
│ │ ├── preload.py
│ │ ├── templates/
│ │ │ ├── base.html
│ │ │ ├── command_builder.html
│ │ │ ├── index.html
│ │ │ ├── latent_3d.html
│ │ │ ├── no_images.html
│ │ │ ├── pair_grid.html
│ │ │ ├── pair_grid_need_more_cols.html
│ │ │ ├── scatter_explorer.html
│ │ │ └── trajectory_creator.html
│ │ └── trajectory.py
│ ├── dataset.py
│ ├── fft.py
│ ├── healpy_grid.json
│ ├── lattice.py
│ ├── lie_tools.py
│ ├── losses.py
│ ├── make_healpy.py
│ ├── masking.py
│ ├── metrics.py
│ ├── models.py
│ ├── models_ai.py
│ ├── mrcfile.py
│ ├── pose.py
│ ├── pose_search.py
│ ├── pose_search_ai.py
│ ├── shift_grid.py
│ ├── shift_grid3.py
│ ├── so3_grid.py
│ ├── source.py
│ ├── starfile.py
│ ├── templates/
│ │ ├── cryoDRGN_ET_viz_template.ipynb
│ │ ├── cryoDRGN_analyze_landscape_template.ipynb
│ │ ├── cryoDRGN_figures_template.ipynb
│ │ ├── cryoDRGN_filtering_template.ipynb
│ │ └── cryoDRGN_viz_template.ipynb
│ └── utils.py
├── pyproject.toml
├── sweep.sh
├── testing/
│ ├── diff_cryodrgn_pkl.py
│ ├── test_abinit.sh
│ ├── test_entropy.py
│ ├── test_pose_search_rag12_128.py
│ ├── test_pose_search_real_128.py
│ ├── test_pose_search_syn_64.py
│ ├── test_sta.sh
│ └── test_translate.py
└── tests/
├── conftest.py
├── data/
│ ├── 50S-vol.mrc
│ ├── FinalRefinement-OriginalParticles-PfCRT.star
│ ├── ay19102021_L3_position6_ribo_it09_bin8_1.82A.mrcs
│ ├── cryosparc_J2_particles_exported.cs
│ ├── cryosparc_P12_J24_001_particles.cs
│ ├── ctf1.pkl
│ ├── ctf2.pkl
│ ├── empiar_10076_7.cs
│ ├── empiar_10076_7.mrc
│ ├── empiar_10076_7.star
│ ├── hand-vol.mrc
│ ├── hand.5.mrcs
│ ├── hand.mrcs
│ ├── hand_11_particles.npy
│ ├── hand_rot.pkl
│ ├── hand_rot_trans.pkl
│ ├── hand_tilt.mrcs
│ ├── het_config.yaml
│ ├── het_weights.pkl
│ ├── im_shifted.npy
│ ├── ind100-rand.pkl
│ ├── ind100.pkl
│ ├── ind4.pkl
│ ├── ind5.pkl
│ ├── ind_39_sta_testing_bin8.pkl
│ ├── pose.cs.pkl
│ ├── pose.star.pkl
│ ├── relion31.6opticsgroups.star
│ ├── relion31.mrcs
│ ├── relion31.star
│ ├── relion31.v2.star
│ ├── relion5.star
│ ├── spike-vol.mrc
│ ├── sta_ctf.pkl
│ ├── sta_pose.pkl
│ ├── sta_testing.star
│ ├── sta_testing_bin8.star
│ ├── test_ctf.100.pkl
│ ├── test_ctf.pkl
│ ├── toy.star
│ ├── toy_angles.pkl
│ ├── toy_datadir/
│ │ ├── toy_images_a.mrcs
│ │ └── toy_images_b.mrcs
│ ├── toy_projections.mrc
│ ├── toy_projections.mrcs
│ ├── toy_projections.star
│ ├── toy_projections.txt
│ ├── toy_projections_0-999.mrcs
│ ├── toy_projections_13.star
│ ├── toy_projections_2.txt
│ ├── toy_projections_dir.star
│ ├── toy_rot_trans.pkl
│ ├── toy_rot_zerotrans.pkl
│ ├── toy_trans.pkl
│ ├── toy_trans.zero.pkl
│ ├── toymodel_small_nocenter.mrc
│ ├── zvals_het-2_1k.pkl
│ └── zvals_het-8_4k.pkl
├── quicktest.sh
├── test_add_psize.py
├── test_backprojection.py
├── test_clean.py
├── test_dashboard_core.py
├── test_dashboard_extended.py
├── test_dataset.py
├── test_direct_traversal.py
├── test_downsample.py
├── test_entropy.py
├── test_eval_images.py
├── test_fft.py
├── test_filter_mrcs.py
├── test_filter_pkl.py
├── test_flip_hand.py
├── test_fsc.py
├── test_graph_traversal.py
├── test_integration.py
├── test_invert_contrast.py
├── test_masks.py
├── test_mrc.py
├── test_parse.py
├── test_pc_traversal.py
├── test_phase_flip.py
├── test_read_filter_write.py
├── test_reconstruct_abinit.py
├── test_reconstruct_abinit_old.py
├── test_reconstruct_fixed.py
├── test_reconstruct_tilt.py
├── test_relion.py
├── test_select_clusters.py
├── test_select_random.py
├── test_source.py
├── test_translate.py
├── test_utils.py
├── test_view_cs_header.py
├── test_view_header.py
├── test_view_mrcs.py
├── test_writestar.py
└── unittest.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .flake8
================================================
[flake8]
extend-ignore = E203,E501
max-complexity = 99
max-line-length = 88
================================================
FILE: .github/CODEOWNERS
================================================
* @michal-g @zhonge
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
What is the command you used?
**Expected behavior**
A clear and concise description of what you expected to happen.
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/workflows/beta_release.yml
================================================
name: Beta Release
on:
push:
tags:
- '[0-9]+\.[0-9]+\.[0-9]+-*'
jobs:
beta-release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Upgrade setuptools/build
run: |
python3 -m venv myenv/
myenv/bin/pip install setuptools --upgrade
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.9'
- name: Release to TestPyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
myenv/bin/python -m pip install --upgrade build
myenv/bin/python -m build .
myenv/bin/python -m pip install --upgrade twine
myenv/bin/python -m pip install importlib_metadata==7.2.1
myenv/bin/twine upload --repository testpypi dist/*
================================================
FILE: .github/workflows/release.yml
================================================
name: Release
on:
push:
tags:
- '[0-9]+\.[0-9]+\.[0-9]+'
- '!*-[a-z]+[0-9]+'
jobs:
release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Upgrade setuptools/build
run: |
python3 -m venv myenv/
myenv/bin/pip install setuptools --upgrade
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.9'
- name: Release to pypi
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_MAIN_TOKEN }}
run: |
myenv/bin/python -m pip install --upgrade build
myenv/bin/python -m build .
myenv/bin/python -m pip install --upgrade twine
myenv/bin/python -m pip install importlib_metadata==7.2.1
myenv/bin/twine upload dist/*
================================================
FILE: .github/workflows/style.yml
================================================
name: Code Linting
on:
push:
branches: [ main, develop ]
tags:
- '[0-9]+\.[0-9]+\.[0-9]+'
- '[0-9]+\.[0-9]+\.[0-9]+-*'
pull_request:
branches: [ main, develop ]
jobs:
run_tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
check-latest: true
- name: Install cryoDRGN with dev dependencies
run: |
python3 -m pip install .[dev]
- name: Run pre-commit checks
run: |
pre-commit run --all-files --show-diff-on-failure
- name: Run Pyright
run: |
pyright --version
#pyright
================================================
FILE: .github/workflows/tests.yml
================================================
name: CI Testing
on:
push:
branches: [ develop ]
tags:
- '[0-9]+\.[0-9]+\.[0-9]+'
- '[0-9]+\.[0-9]+\.[0-9]+-*'
pull_request:
branches: [ main ]
jobs:
run_tests:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python: [ '3.10', '3.11', '3.12', '3.13' ]
os: [ macos-latest, ubuntu-latest ]
include:
- python: '3.10'
torch: '2.0'
- python: '3.11'
torch: '2.3'
- python: '3.12'
torch: '2.6'
- python: '3.13'
torch: '2.9'
fail-fast: false
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install cryoDRGN with pytest dependencies
run: |
python3 -m pip install --upgrade pip
python3 -m pip install pytest-xdist
python3 -m pip install .
python3 -m pip uninstall -y torch
python3 -m pip cache purge
python3 -m pip install torch==${{ matrix.torch }}
- name: Pytest
run: |
pytest -v -n2 --dist=loadscope --show-capture=stderr
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
.DS_Store
env/
build/
develop-eggs/
docs/generated/
dist/
downloads/
eggs/
.eggs/
.idea/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
cryodrgn/_version.py
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: '\.cs$|\.star$'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/pycqa/flake8
rev: '6.1.0'
hooks:
- id: flake8
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/MarcoGorelli/absolufy-imports
rev: v0.3.1
hooks:
- id: absolufy-imports
================================================
FILE: LICENSE.txt
================================================
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
================================================
FILE: MANIFEST.in
================================================
include cryodrgn/templates/*ipynb
================================================
FILE: README.md
================================================





# :snowflake::dragon: cryoDRGN: Deep Reconstructing Generative Networks for cryo-EM and cryo-ET heterogeneous reconstruction
CryoDRGN is a neural network based algorithm for heterogeneous cryo-EM reconstruction. In particular, the method models
a *continuous* distribution over 3D structures by using a neural network based representation for the volume.
## Documentation
The latest documentation for cryoDRGN is available in our [user guide](https://ez-lab.gitbook.io/cryodrgn/), including
an overview and walkthrough of cryoDRGN installation, training and analysis. A brief quick start is provided below.
For any feedback, questions, or bugs, please file a Github issue or start a Github discussion.
### Updates in Version 4.2.x
* [NEW] interactive dashboard for visualizing cryoDRGN results in a series of web apps
* [NEW] cryoDRGN-AI *ab initio* reconstruction method integrated into cryoDRGN as `cryodrgn abinit`
* former ab-initio reconstruction methods are deprecated as `cryodrgn abinit_het_old` and `cryodrgn abinit_homo_old`
* `cryodrgn analyze`, `landscape`, etc. now support cryoDRGN-AI models as well as the previous cryoDRGN models
* more memory-efficient *ab initio* reconstruction
* support for Python 3.13 and PyTorch 2.9; PyTorch <2.0 is no longer supported
A full list of cryoDRGN version updates can be found at our
[release notes](https://github.com/ml-struct-bio/cryodrgn/releases).
## Installation
`cryodrgn` may be installed via `pip`, and we recommend installing `cryodrgn` in a clean conda environment.
Our package is compatible with Python versions 3.10 through 3.13;
we recommend using the latest available Python version:
# Create and activate conda environment
(base) $ conda create --name cryodrgn python=3.13
(cryodrgn) $ conda activate cryodrgn
# install cryodrgn
(cryodrgn) $ pip install cryodrgn
You can alternatively install a newer, less stable, development version of `cryodrgn` using our beta release channel:
(cryodrgn) $ pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ cryodrgn --pre
More installation instructions are found in the [documentation](https://ez-lab.gitbook.io/cryodrgn/installation).
## Quickstart: heterogeneous reconstruction with consensus poses
### 1. Preprocess image stack
First resize your particle images using the `cryodrgn downsample` command:
<details><summary><code>$ cryodrgn downsample -h</code></summary>
usage: cryodrgn downsample [-h] -D D -o MRCS [--is-vol] [--chunk CHUNK]
[--datadir DATADIR]
mrcs
Downsample an image stack or volume by clipping fourier frequencies
positional arguments:
mrcs Input images or volume (.mrc, .mrcs, .star, .cs, or .txt)
optional arguments:
-h, --help show this help message and exit
-D D New box size in pixels, must be even
-o MRCS Output image stack (.mrcs) or volume (.mrc)
--is-vol Flag if input .mrc is a volume
--chunk CHUNK Chunksize (in # of images) to split particle stack when
saving
--relion31 Flag for relion3.1 star format
--datadir DATADIR Optionally provide path to input .mrcs if loading from a
.star or .cs file
--max-threads MAX_THREADS
Maximum number of CPU cores for parallelization (default: 16)
--ind PKL Filter image stack by these indices
</details>
We recommend first downsampling images to 128x128 since larger images can take much longer to train:
$ cryodrgn downsample [input particle stack] -D 128 -o particles.128.mrcs
The maximum recommended image size is D=256, so we also recommend downsampling your images to D=256 if your images
are larger than 256x256:
$ cryodrgn downsample [input particle stack] -D 256 -o particles.256.mrcs
The input file format can be a single `.mrcs` file, a `.txt` file containing paths to multiple `.mrcs` files, a RELION
`.star` file, or a cryoSPARC `.cs` file. For the latter two options, if the relative paths to the `.mrcs` are broken,
the argument `--datadir` can be used to supply the path to where the `.mrcs` files are located.
If there are memory issues with downsampling large particle stacks, add the `--chunk 10000` argument to
save images as separate `.mrcs` files of 10k images.
### 2. Parse image poses from a consensus homogeneous reconstruction
CryoDRGN expects image poses to be stored in a binary pickle format (`.pkl`). Use the `parse_pose_star` or
`parse_pose_csparc` command to extract the poses from a `.star` file or a `.cs` file, respectively.
Example usage to parse image poses from a RELION 3.1 starfile:
$ cryodrgn parse_pose_star particles.star -o pose.pkl
Example usage to parse image poses from a cryoSPARC homogeneous refinement particles.cs file:
$ cryodrgn parse_pose_csparc cryosparc_P27_J3_005_particles.cs -o pose.pkl -D 300
**Note:** The `-D` argument should be the box size of the consensus refinement (and not the downsampled
images from step 1) so that the units for translation shifts are parsed correctly.
### 3. Parse CTF parameters from a .star/.cs file
CryoDRGN expects CTF parameters to be stored in a binary pickle format (`.pkl`).
Use the `parse_ctf_star` or `parse_ctf_csparc` command to extract the relevant CTF parameters from a `.star` file
or a `.cs` file, respectively.
Example usage for a .star file:
$ cryodrgn parse_ctf_star particles.star -o ctf.pkl
If the box size and Angstrom/pixel values are not included in the .star file under fields `_rlnImageSize` and
`_rlnImagePixelSize` respectively, the `-D` and `--Apix` arguments to `parse_ctf_star` should be used instead to
provide the original parameters of the input file (before any downsampling):
$ cryodrgn parse_ctf_star particles.star -D 300 --Apix 1.03 -o ctf.pkl
Example usage for a .cs file:
$ cryodrgn parse_ctf_csparc cryosparc_P27_J3_005_particles.cs -o ctf.pkl
### 4. (Optional) Test pose/CTF parameters parsing
Next, test that pose and CTF parameters were parsed correctly using the voxel-based backprojection script.
The goal is to quickly verify that there are no major problems with the extracted values and that the output structure
resembles the structure from the consensus reconstruction before training.
Example usage:
$ cryodrgn backproject_voxel projections.128.mrcs \
--poses pose.pkl \
--ctf ctf.pkl \
-o backproject.128 \
--first 10000
The output structure `backproject.128/backproject.mrc` will not be identical to the consensus reconstruction because we
only used the first 10k particle images for quicker results.
If the structure is too noisy to interpret, you can use more images with `--first 25000` or use the
entire particle stack (by leaving off the `--first` flag).
**Note:** If the volume does not resemble your structure, you may need to use the flag `--uninvert-data`.
This flips the data sign (e.g. light-on-dark or dark-on-light), which may be needed depending on the
convention used in upstream processing tools.
### 5. Running cryoDRGN heterogeneous reconstruction
When the input images (.mrcs), poses (.pkl), and CTF parameters (.pkl) have been prepared, a cryoDRGN model
can be trained with the following command:
<details><summary><code>$ cryodrgn train_vae -h</code></summary>
usage: cryodrgn train_vae [-h] -o OUTDIR --zdim ZDIM --poses POSES [--ctf pkl]
[--load WEIGHTS.PKL] [--checkpoint CHECKPOINT]
[--log-interval LOG_INTERVAL] [-v] [--seed SEED]
[--ind PKL] [--uninvert-data] [--no-window]
[--window-r WINDOW_R] [--datadir DATADIR] [--lazy]
[--max-threads MAX_THREADS]
[--tilt TILT] [--tilt-deg TILT_DEG] [-n NUM_EPOCHS]
[-b BATCH_SIZE] [--wd WD] [--lr LR] [--beta BETA]
[--beta-control BETA_CONTROL] [--norm NORM NORM]
[--no-amp] [--multigpu] [--do-pose-sgd]
[--pretrain PRETRAIN] [--emb-type {s2s2,quat}]
[--pose-lr POSE_LR] [--enc-layers QLAYERS]
[--enc-dim QDIM]
[--encode-mode {conv,resid,mlp,tilt}]
[--enc-mask ENC_MASK] [--use-real]
[--dec-layers PLAYERS] [--dec-dim PDIM]
[--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
[--feat-sigma FEAT_SIGMA] [--pe-dim PE_DIM]
[--domain {hartley,fourier}]
[--activation {relu,leaky_relu}]
particles
Train a VAE for heterogeneous reconstruction with known pose
positional arguments:
particles Input particles (.mrcs, .star, .cs, or .txt)
optional arguments:
-h, --help show this help message and exit
-o OUTDIR, --outdir OUTDIR
Output directory to save model
--zdim ZDIM Dimension of latent variable
--poses POSES Image poses (.pkl)
--ctf pkl CTF parameters (.pkl)
--load WEIGHTS.PKL Initialize training from a checkpoint
--checkpoint CHECKPOINT
Checkpointing interval in N_EPOCHS (default: 1)
--log-interval LOG_INTERVAL
Logging interval in N_IMGS (default: 1000)
-v, --verbose Increases verbosity
--seed SEED Random seed
Dataset loading:
--ind PKL Filter particle stack by these indices
--uninvert-data Do not invert data sign
--no-window Turn off real space windowing of dataset
--window-r WINDOW_R Windowing radius (default: 0.85)
--datadir DATADIR Path prefix to particle stack if loading relative
paths from a .star or .cs file
--lazy Lazy loading if full dataset is too large to fit in
memory (Should copy dataset to SSD)
--max-threads MAX_THREADS
Maximum number of CPU cores for FFT parallelization
(default: 16)
Tilt series:
--tilt TILT Particles (.mrcs)
--tilt-deg TILT_DEG X-axis tilt offset in degrees (default: 45)
Training parameters:
-n NUM_EPOCHS, --num-epochs NUM_EPOCHS
Number of training epochs (default: 20)
-b BATCH_SIZE, --batch-size BATCH_SIZE
Minibatch size (default: 8)
--wd WD Weight decay in Adam optimizer (default: 0)
--lr LR Learning rate in Adam optimizer (default: 0.0001)
--beta BETA Choice of beta schedule or a constant for KLD weight
(default: 1/zdim)
--beta-control BETA_CONTROL
KL-Controlled VAE gamma. Beta is KL target. (default:
None)
--norm NORM NORM Data normalization as shift, 1/scale (default: 0, std
of dataset)
--no-amp Do not use mixed-precision training
--multigpu Parallelize training across all detected GPUs
Pose SGD:
--do-pose-sgd Refine poses with gradient descent
--pretrain PRETRAIN Number of epochs with fixed poses before pose SGD
(default: 1)
--emb-type {s2s2,quat}
SO(3) embedding type for pose SGD (default: quat)
--pose-lr POSE_LR Learning rate for pose optimizer (default: 0.0003)
Encoder Network:
--enc-layers QLAYERS Number of hidden layers (default: 3)
--enc-dim QDIM Number of nodes in hidden layers (default: 1024)
--encode-mode {conv,resid,mlp,tilt}
Type of encoder network (default: resid)
--enc-mask ENC_MASK Circular mask of image for encoder (default: D/2; -1
for no mask)
--use-real Use real space image for encoder (for convolutional
encoder)
Decoder Network:
--dec-layers PLAYERS Number of hidden layers (default: 3)
--dec-dim PDIM Number of nodes in hidden layers (default: 1024)
--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}
Type of positional encoding (default: gaussian)
--feat-sigma FEAT_SIGMA
Scale for random Gaussian features
--pe-dim PE_DIM Num features in positional encoding (default: image D)
--domain {hartley,fourier}
Decoder representation domain (default: fourier)
--activation {relu,leaky_relu}
Activation (default: relu)
</details>
Many of the parameters of this script have sensible defaults. The required arguments are:
* an input image stack (`.mrcs` or other listed file types)
* `--poses`, image poses (`.pkl`) that correspond to the input images
* `--ctf`, ctf parameters (`.pkl`), unless phase-flipped images are used
* `--zdim`, the dimension of the latent variable
* `-o`, a clean output directory for saving results
Additional parameters that may be adjusted include:
* `-n`, Number of epochs to train
* `--uninvert-data`, Used if particles are dark on light (negative stain format)
* Architecture parameters with `--enc-layers`, `--enc-dim`, `--dec-layers`, `--dec-dim`
* `--multigpu` to enable parallelized training across multiple GPUs
* `-b`, Minibatch size (affects training speed/dynamics)
### Recommended usage:
1) We highly recommend first training on downsampled images (e.g. D=128) to sanity check results and perform any particle filtering (e.g. of junk particles). If your dataset is very large (>300k particles), we also recommend training on a subset of your dataset.
Example command to train a cryoDRGN model for 25 epochs on an image dataset `particles.128.mrcs`
with poses `pose.pkl` and ctf parameters `ctf.pkl`:
# 8-D latent variable model, small images
$ cryodrgn train_vae particles.128.mrcs \
--poses pose.pkl \
--ctf ctf.pkl \
--zdim 8 -n 25 \
-o 00_cryodrgn128
2) After validating that the initial cryodrgn results are sensible (e.g. after any particle filtering or pose optimization),
then train on the full resolution images (up to D=256):
Example command to train a cryoDRGN model for 25 epochs on an image dataset `particles.256.mrcs`
with poses `pose.pkl` and ctf parameters `ctf.pkl`:
# 8-D latent variable model, larger images
$ cryodrgn train_vae particles.256.mrcs \
--poses pose.pkl \
--ctf ctf.pkl \
--zdim 8 -n 25 \
-o 01_cryodrgn256
The number of epochs `-n` refers to the number of full passes through the dataset for training, and should be modified
depending on the number of particles in the dataset. For a 100k particle dataset on 1 V100 GPU,
the above settings required ~12 min/epoch for D=128 images and ~47 min/epoch for D=256 images.
If you would like to train longer, a training job can be extended with the `--load` argument.
For example to extend the training of the previous example to 50 epochs:
$ cryodrgn train_vae particles.256.mrcs \
--poses pose.pkl \
--ctf ctf.pkl \
--zdim 8 -n 50 \
-o 01_cryodrgn256 \
--load 01_cryodrgn256/weights.25.pkl # 1-based indexing
### Accelerated training with GPU parallelization
Use cryoDRGN's `--multigpu` flag to parallelize training across all detected GPUs on the machine.
To select specific GPUs for cryoDRGN, use the environmental variable `CUDA_VISIBLE_DEVICES`, e.g.:
$ cryodrgn train_vae ... # Run on GPU 0
$ cryodrgn train_vae ... --multigpu # Run on all GPUs on the machine
$ CUDA_VISIBLE_DEVICES=0,3 cryodrgn train_vae ... --multigpu # Run on GPU 0,3
We recommend using `--multigpu` for large images, e.g. D=256.
Note that GPU computation may not be the training bottleneck for smaller images (D=128).
In this case, `--multigpu` may not speed up training (while taking up additional compute resources).
With `--multigpu`, the batch size is multiplied by the number of available GPUs to better utilize GPU resources.
We note that GPU utilization may be further improved by increasing the batch size (e.g. `-b 16`), however,
faster wall-clock time per epoch does not necessarily lead to faster *convergence* since the training dynamics
are affected (fewer model updates per epoch with larger `-b`).
Thus, using `--multigpu` may require increasing the total number of epochs. As a best practice, we recommend
first training for 25 epochs (or however many is practical for your dataset size), and then doubling to 50 epochs
to check for model convergence by inspecting if the final results have changed.
### Local pose refinement -- *beta*
Depending on the quality of the consensus reconstruction, image poses may contain errors.
Image poses may be *locally* refined using the `--do-pose-sgd` flag, however, we recommend reaching out to the
developers for recommended training settings.
For global pose optimization or *ab initio* reconstruction, please see our [cryoDRGN-AI](https://cryodrgnai.cs.princeton.edu/) method.
## 6. Analysis of results
Once the model has finished training, the output directory will contain a configuration file `config.yaml`,
neural network weights `weights.pkl`, image poses (if performing pose sgd) `pose.pkl`,
and the latent embeddings for each image `z.pkl`.
The latent embeddings are provided in the same order as the input particles.
To analyze these results, use the `cryodrgn analyze` command to visualize the latent space and generate structures.
`cryodrgn analyze` will also provide a template jupyter notebook for further interactive visualization and analysis.
### cryodrgn analyze
<details><summary><code>$ cryodrgn analyze -h</code></summary>
usage: cryodrgn analyze [-h] [--device DEVICE] [-o OUTDIR] [--skip-vol]
[--skip-umap] [--Apix APIX] [--flip] [--invert]
[-d DOWNSAMPLE] [--pc PC] [--ksample KSAMPLE]
workdir epoch
Visualize latent space and generate volumes
positional arguments:
workdir Directory with cryoDRGN results
epoch Epoch number N to analyze (1-based indexing,
corresponding to z.N.pkl, weights.N.pkl)
optional arguments:
-h, --help show this help message and exit
--device DEVICE Optionally specify CUDA device
-o OUTDIR, --outdir OUTDIR
Output directory for analysis results (default:
[workdir]/analyze.[epoch])
--skip-vol Skip generation of volumes
--skip-umap Skip running UMAP
Extra arguments for volume generation:
--Apix APIX Pixel size to add to .mrc header (default: 1 A/pix)
--flip Flip handedness of output volumes
--invert Invert contrast of output volumes
-d DOWNSAMPLE, --downsample DOWNSAMPLE
Downsample volumes to this box size (pixels)
--pc PC Number of principal component traversals to generate
(default: 2)
--ksample KSAMPLE Number of kmeans samples to generate (default: 20)
</details>
This script runs a series of standard analyses:
* PCA visualization of the latent embeddings
* UMAP visualization of the latent embeddings
* Generation of volumes. See note [1].
* Generation of trajectories along the first and second principal components of the latent embeddings
* Generation of template jupyter notebooks that may be used for further interactive analyses, visualization, and volume generation
Example usage to analyze results from the directory `01_cryodrgn256` containing results after 25 epochs of training:
$ cryodrgn analyze 01_cryodrgn256 25 --Apix 1.31 # 25 for 1-based indexing of epoch numbers
Notes:
[1] Volumes are generated after k-means clustering of the latent embeddings with k=20 by default.
Note that we use k-means clustering here not to identify clusters, but to segment the latent space and
generate structures from different regions of the latent space.
The number of structures that are generated may be increased with the option `--ksample`.
[2] The `cryodrgn analyze` command chains together a series of calls to `cryodrgn eval_vol` and other scripts
that can be run separately for more flexibility.
These scripts are located in the `analysis_scripts` directory within the source code.
[3] In particular, you may find it useful to perform filtering of particles separately from other analyses. This can
be done using our interactive interface available from the command line: `cryodrgn filter 01_cryodrgn256`.
[4] `--Apix` only needs to be given if it is not present in the CTF file that was used in training.
### Generating additional volumes
A simple way of generating additional volumes is to increase the number of k-means samples in `cryodrgn analyze`
by using the flag `--ksample 100` (for 100 structures).
For additional flexibility, `cryodrgn eval_vol` may be called directly:
<details><summary><code>$ cryodrgn eval_vol -h</code></summary>
usage: cryodrgn eval_vol [-h] -c PKL -o O [--prefix PREFIX] [-v]
[-z [Z [Z ...]]] [--z-start [Z_START [Z_START ...]]]
[--z-end [Z_END [Z_END ...]]] [-n N] [--zfile ZFILE]
[--Apix APIX] [--flip] [-d DOWNSAMPLE]
[--norm NORM NORM] [-D D] [--enc-layers QLAYERS]
[--enc-dim QDIM] [--zdim ZDIM]
[--encode-mode {conv,resid,mlp,tilt}]
[--dec-layers PLAYERS] [--dec-dim PDIM]
[--enc-mask ENC_MASK]
[--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,none}]
[--pe-dim PE_DIM] [--domain {hartley,fourier}]
[--l-extent L_EXTENT]
[--activation {relu,leaky_relu}]
weights
Evaluate the decoder at specified values of z
positional arguments:
weights Model weights
optional arguments:
-h, --help show this help message and exit
-c YAML, --config YAML CryoDRGN config.yaml file
-o O Output .mrc or directory
--prefix PREFIX Prefix when writing out multiple .mrc files (default: vol_)
-v, --verbose Increase verbosity
Specify z values:
-z [Z [Z ...]] Specify one z-value
--z-start [Z_START [Z_START ...]]
Specify a starting z-value
--z-end [Z_END [Z_END ...]]
Specify an ending z-value
-n N Number of structures between [z_start, z_end]
--zfile ZFILE Text file with z-values to evaluate
Volume arguments:
--Apix APIX Pixel size to add to .mrc header (default: 1 A/pix)
--flip Flip handedness of output volume
-d DOWNSAMPLE, --downsample DOWNSAMPLE
Downsample volumes to this box size (pixels)
Overwrite architecture hyperparameters in config.yaml:
--norm NORM NORM
-D D Box size
--enc-layers QLAYERS Number of hidden layers
--enc-dim QDIM Number of nodes in hidden layers
--zdim ZDIM Dimension of latent variable
--encode-mode {conv,resid,mlp,tilt}
Type of encoder network
--dec-layers PLAYERS Number of hidden layers
--dec-dim PDIM Number of nodes in hidden layers
--enc-mask ENC_MASK Circular mask radius for image encoder
--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,none}
Type of positional encoding
--pe-dim PE_DIM Num sinusoid features in positional encoding (default:
D/2)
--domain {hartley,fourier}
--l-extent L_EXTENT Coordinate lattice size
--activation {relu,leaky_relu}
Activation (default: relu)
</details>
**Example usage:**
To generate a volume at a single value of the latent variable:
$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.yaml -z ZVALUE -o reconstruct.mrc
The number of inputs for `-z` must match the dimension of your latent variable.
Or to generate a trajectory of structures from a defined start and ending point,
use the `--z-start` and `--z-end` arguments:
$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.yaml -o [WORKDIR]/trajectory \
--z-start -3 --z-end 3 -n 20
This example generates 20 structures at evenly spaced values between z=[-3,3],
assuming a 1-dimensional latent variable model.
Finally, a series of structures can be generated using values of z given in a file specified by the argument `--zfile`:
$ cryodrgn eval_vol [WORKDIR]/weights.pkl --config [WORKDIR]/config.yaml --zfile zvalues.txt -o [WORKDIR]/trajectory
The input to `--zfile` is expected to be an array of dimension (N_volumes x zdim), loaded with np.loadtxt.
### Making trajectories
Three additional commands can be used in conjunction with `cryodrgn eval_vol` to generate trajectories:
$ cryodrgn pc_traversal -h
$ cryodrgn graph_traversal -h
$ cryodrgn direct_traversal -h
These scripts produce a text file of z values that can be input to `cryodrgn eval_vol` to generate a series of
structures that can be visualized as a trajectory in ChimeraX (https://www.cgl.ucsf.edu/chimerax).
Documentation: https://ez-lab.gitbook.io/cryodrgn/cryodrgn-graph-traversal-for-making-long-trajectories
### cryodrgn analyze_landscape
NEW in version 1.0: There are two additional tools `cryodrgn analyze_landscape` and `cryodrgn analyze_landscape_full`
for more comprehensive and automated analyses of cryodrgn results.
Documentation: https://ez-lab.gitbook.io/cryodrgn/cryodrgn-conformational-landscape-analysis
## *Ab Initio* Reconstruction
CryoDRGN-AI is currently available through the `cryodrgn abinit` command. Please see the
[corresponding manuscript](https://cryodrgnai.cs.princeton.edu/) for details on our latest version
of *ab initio* reconstruction.
An earlier version of *ab initio* reconstruction was developed as cryoDRGN2 and is still available with the
`cryodrgn abinit_het_old` and `cryodrgn abinit_homo_old` executables.
CryoDRGN2 documentation: https://ez-lab.gitbook.io/cryodrgn/cryodrgn2-ab-initio-reconstruction
The arguments for all *ab initio* reconstruction commands are similar to `cryodrgn train_vae`,
but the `--poses` argument is not required.
## CryoDRGN-ET for subtomogram analysis
CryoDRGN-ET for heterogeneous subtomogram averaging is available in cryodrgn version 3.0+. Documentation for getting started can be found
in the [user guide](https://ez-lab.gitbook.io/cryodrgn/cryodrgn-et-subtomogram-analysis).
## References:
For a complete description of the method, see:
* CryoDRGN: reconstruction of heterogeneous cryo-EM structures using neural networks
Ellen D. Zhong, Tristan Bepler, Bonnie Berger*, Joseph H Davis*
Nature Methods, 2021, https://doi.org/10.1038/s41592-020-01049-4 [pdf](https://ezlab.princeton.edu/assets/pdf/2021_cryodrgn_nature_methods.pdf)
For a description of our extension to heterogeneous subtomogram averaging, see:
* CryoDRGN-ET: deep reconstructing generative networks for visualizing dynamic biomolecules inside cells
Ramya Rangan*, Ryan Feathers*, Sagar Khavnekar, Adam Lerer, Jake Johnston, Ron Kelley, Martin Obr, Abhay Kotecha, and Ellen D. Zhong
Nature Methods, 2024, https://doi.org/10.1038/s41592-024-02340-4 [pdf](https://ezlab.cs.princeton.edu/assets/pdf/2024_cryodrgnet.pdf)
For a description of our *ab initio* reconstruction method, see:
* CryoDRGN-AI: neural ab initio reconstruction of challenging cryo-EM and cryo-ET datasets
Axel Levy, Rishwanth Raghu, Ryan Feathers, Michal Grzadkowski, Frederic Poitevin, Jake D. Johnston, Francesca Vallese, Oliver B. Clarke, Gordon Wetzstein, and Ellen D. Zhong
Nature Methods, 2025, https://doi.org/10.1038/s41592-025-02720-4
A preliminary version of cryoDRGN was presented at ICLR 2020:
* Reconstructing continuous distributions of protein structure from cryo-EM images
Ellen D. Zhong, Tristan Bepler, Joseph H. Davis*, Bonnie Berger*
ICLR 2020, Spotlight, https://arxiv.org/abs/1909.05215
A preliminary version of *ab initio* reconstruction in cryoDRGN2 was presented at ICCV 2021:
* CryoDRGN2: Ab Initio Neural Reconstruction of 3D Protein Structures From Real Cryo-EM Images
Ellen D. Zhong, Adam Lerer, Joseph H Davis, and Bonnie Berger
International Conference on Computer Vision (ICCV) 2021, [paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Zhong_CryoDRGN2_Ab_Initio_Neural_Reconstruction_of_3D_Protein_Structures_From_ICCV_2021_paper.pdf)
A protocols paper that describes the analysis of the EMPIAR-10076 assembling ribosome dataset:
* Uncovering structural ensembles from single particle cryo-EM data using cryoDRGN
Laurel Kinman, Barrett Powell, Ellen D. Zhong*, Bonnie Berger*, Joseph H Davis*
Nature Protocols 2023, https://doi.org/10.1038/s41596-022-00763-x
## Contact
Please submit any bug reports, feature requests, or general usage feedback as a github issue or discussion! Thank you!
================================================
FILE: analysis_scripts/kmeans.py
================================================
"""K-means clustering"""
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from cryodrgn import analysis
def parse_args():
    """Build the argument parser for k-means clustering of latent encodings."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z.pkl")
    p.add_argument("-k", type=int, required=True, help="# clusters")
    p.add_argument("--stride", type=int, help="Stride the dataset")
    p.add_argument("-o", help="Output labels (.pkl)")
    p.add_argument("--out-png", help="Output image (.png)")
    p.add_argument("--out-k", help="Output cluster centers z values (.txt)")
    p.add_argument(
        "--on-data",
        action="store_true",
        help="Use nearest data point instead of cluster center",
    )
    p.add_argument("--out-k-ind", help="Output cluster center indices (.txt)")
    p.add_argument("--reorder", action="store_true", help="Reorder cluster centers")
    return p
def main(args):
    """Run k-means on latent encodings, save labels/centers, and plot clusters in PC space."""
    fig, ax = plt.subplots()
    print(args)
    # BUGFIX: the input file handle was previously never explicitly closed
    with open(args.input, "rb") as f:
        z = pickle.load(f)
    if args.stride:
        z = z[:: args.stride]
    print("{} points".format(len(z)))

    # k-means clustering
    labels, centers = analysis.cluster_kmeans(
        z, args.k, on_data=args.on_data, reorder=args.reorder
    )

    # use the nearest data point instead of cluster centroid
    if args.on_data:
        centers_zi = cdist(centers, z).argmin(axis=1)
        print(centers_zi)
        centers = z[centers_zi]
        if args.out_k_ind:
            np.savetxt(args.out_k_ind, centers_zi, fmt="%d")

    if args.o:
        with open(args.o, "wb") as f:
            pickle.dump(labels, f)
    if args.out_k:
        np.savetxt(args.out_k, centers)

    # dimensionality reduction for visualization
    pca = PCA(z.shape[1])
    pca.fit(z)
    print("PCA explained variance ratio:")
    print(pca.explained_variance_ratio_)
    pc = pca.transform(z)

    # one scatter series per cluster, plus black markers at the cluster centers
    for i in range(args.k):
        ii = np.where(labels == i)
        pc_sub = pc[ii]
        plt.scatter(
            pc_sub[:, 0], pc_sub[:, 1], s=2, alpha=0.1, label="cluster {}".format(i)
        )
    c = pca.transform(centers)
    plt.scatter(c[:, 0], c[:, 1], c="k")
    for i in range(args.k):
        ax.annotate(str(i), c[i, 0:2])

    xx, yy = 0, 1
    # BUGFIX: "{:3f}" means minimum field width 3 (full float precision);
    # "{:.3f}" gives the intended 3 decimal places.
    plt.xlabel("PC{} ({:.3f})".format(xx + 1, pca.explained_variance_ratio_[xx]))
    plt.ylabel("PC{} ({:.3f})".format(yy + 1, pca.explained_variance_ratio_[yy]))
    if args.out_png:
        plt.savefig(args.out_png)
    else:
        plt.show()
================================================
FILE: analysis_scripts/plot_loss.py
================================================
"""Plot the learning curve"""
import argparse
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cryodrgn import analysis
def parse_args():
    """Parse command-line arguments for plotting learning curves."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", nargs="+", help="Input run.log file(s)")
    p.add_argument("-o", help="Output PNG")
    return p.parse_args()
def main(args):
    """Overlay loss curves from one or more run.log files on a single plot."""
    cmap = matplotlib.cm.get_cmap("jet")
    # spread the curves evenly across the colormap
    fractions = np.arange(len(args.input)) / len(args.input)
    for idx, fname in enumerate(args.input):
        loss = analysis.parse_loss(fname)
        plt.plot(loss, label=fname, c=cmap(fractions[idx]))
        print(fname)
        print(loss)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend(loc="best")
    if args.o:
        plt.savefig(args.o)
    else:
        plt.show()
if __name__ == "__main__":
main(parse_args())
================================================
FILE: analysis_scripts/plot_z1.py
================================================
"""
"""
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def parse_args():
    """Build the argument parser for plotting a 1-D latent encoding per image."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z.pkl")
    p.add_argument("-o", help="Output PNG")
    p.add_argument(
        "--ms",
        default=2,
        type=float,
        help="Marker size for plotting (default: %(default)s)",
    )
    p.add_argument(
        "--alpha",
        default=0.1,
        type=float,
        help="Alpha value for plotting (default: %(default)s)",
    )
    p.add_argument("--ylim", nargs=2, type=float)
    p.add_argument(
        "--sample1", type=int, help="Plot z value for N randomly sampled points"
    )
    p.add_argument(
        "--sample2", type=int, help="Plot median z after chunking into N chunks"
    )
    p.add_argument(
        "--seed", default=0, type=int, help="Random seed (default: %(default)s)"
    )
    p.add_argument("--out-s", help="Save sampled z values (.txt)")
    return p
def main(args):
    """Plot per-image latent encodings (zdim=1), optional sampling overlays, and a histogram."""
    np.random.seed(args.seed)
    f = args.input
    print(f)
    # BUGFIX: the input file handle was previously never closed
    with open(f, "rb") as fi:
        x = pickle.load(fi)
    N = len(x)
    plt.scatter(np.arange(N), x, label=f, alpha=args.alpha, s=args.ms)
    plt.xlim((0, N))

    xd = None
    if args.sample1:
        # random sample of N points
        s = np.random.choice(len(x), args.sample1)
        xd = x[s]
        print(xd)
        plt.plot(s, xd, "o")
    if args.sample2:
        # median z of each contiguous chunk, plotted at the chunk's median index
        t = np.array_split(np.arange(len(x)), args.sample2)
        t = np.array([np.median(tt, axis=0) for tt in t])
        xsplit = np.array_split(x, args.sample2)
        xd = np.array([np.median(xs, axis=0) for xs in xsplit])
        print(len(xd))
        print(xd)
        plt.plot(t, xd, "o", color="k")
    if args.out_s and xd is not None:
        np.savetxt(args.out_s, xd)

    if args.ylim:
        plt.ylim(args.ylim)
    plt.xlabel("image")
    plt.ylabel("latent encoding")
    plt.legend(loc="best")
    if args.o:
        plt.savefig(args.o)

    # Plot histogram
    plt.figure()
    # BUGFIX: sns.distplot was deprecated in seaborn 0.11 and removed in 0.14;
    # histplot with a KDE overlay is the modern equivalent.
    sns.histplot(x, kde=True)
    plt.show()
if __name__ == "__main__":
main(parse_args().parse_args())
================================================
FILE: analysis_scripts/plot_z2.py
================================================
"""
Plot 2D latent space
"""
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def parse_args():
    """Build the argument parser for plotting a 2-D latent space."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z pkl")
    p.add_argument("-o", "--out-png", help="Output PNG")
    p.add_argument(
        "--ms",
        default=2,
        type=float,
        help="Marker size for plotting (default: %(default)s)",
    )
    p.add_argument(
        "--alpha",
        default=0.1,
        type=float,
        help="Alpha value for plotting (default: %(default)s)",
    )
    p.add_argument(
        "--sample1",
        type=int,
        help="Optionally plot z value for N randomly sampled points",
    )
    p.add_argument(
        "--sample2",
        type=int,
        help="Optionally plot median z after chunking into N chunks",
    )
    p.add_argument("--out-s", help="Save sampled z values (.txt)")
    p.add_argument("--color", action="store_true", help="Color points by image index")
    p.add_argument(
        "--seed", default=0, type=int, help="Random seed (default: %(default)s)"
    )
    p.add_argument(
        "--annotate", action="store_true", help="Annotate sampled points in plot"
    )
    p.add_argument("--kde", action="store_true", help="KDE plot instead of scatter")
    p.add_argument("--stride", type=int, help="Stride dataset")
    return p
def main(args):
    """Scatter or KDE plot of a 2-D latent space with optional sampling overlays."""
    np.random.seed(args.seed)
    f = args.input
    print(f)
    # BUGFIX: the input file handle was previously never closed
    with open(f, "rb") as fh:
        x = pickle.load(fh)
    if args.stride:
        x = x[:: args.stride]
    print(x.shape)

    # seaborn jointplot (KDE)
    if args.kde:
        # BUGFIX: seaborn >= 0.12 made the x/y data vectors keyword-only for
        # jointplot; passing them positionally raises a TypeError.
        g = sns.jointplot(x=x[:, 0], y=x[:, 1], kind="kde")
        ax = g.ax_joint
    # scatter plot
    else:
        fig, ax = plt.subplots()
        if args.color:
            plt.scatter(
                x[:, 0],
                x[:, 1],
                c=np.arange(len(x[:, 0])),
                label=f,
                alpha=args.alpha,
                s=args.ms,
                cmap="hsv",
            )
        else:
            plt.scatter(x[:, 0], x[:, 1], label=f, alpha=args.alpha, s=args.ms)
        plt.xlabel("z1")
        plt.ylabel("z2")
        plt.legend(loc="best")

    xd = None
    if args.sample1:
        # random sample of N points, colored by sample order
        ii = np.random.choice(len(x), args.sample1)
        print(ii)
        xd = x[ii]
        print(xd)
        plt.scatter(xd[:, 0], xd[:, 1], c=np.arange(len(xd)), cmap="hsv")
        if args.annotate:
            for i in range(args.sample1):
                ax.annotate(str(i), xd[i])
    if args.sample2:
        # median z of each contiguous chunk
        xsplit = np.array_split(x, args.sample2)
        xd = np.array([np.median(xs, axis=0) for xs in xsplit])
        print(len(xd))
        print(xd)
        plt.scatter(xd[:, 0], xd[:, 1], c="k")
    if args.out_s and xd is not None:
        np.savetxt(args.out_s, xd)

    if args.out_png:
        plt.savefig(args.out_png)
    else:
        plt.show()
================================================
FILE: analysis_scripts/plot_z_pca.py
================================================
"""
Plot PCA projection of latent space
"""
import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
def parse_args():
    """Build the argument parser for plotting the PCA projection of a latent space."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z pkl")
    p.add_argument("-o", "--out-png", help="Output PNG")
    p.add_argument(
        "--axis",
        type=int,
        nargs=2,
        default=[0, 1],
        help="Dimensions to plot (default: %(default)s)",
    )
    p.add_argument(
        "--ms",
        default=2,
        type=float,
        help="Marker size for plotting (default: %(default)s)",
    )
    p.add_argument(
        "--alpha",
        default=0.1,
        type=float,
        help="Alpha value for plotting (default: %(default)s)",
    )
    p.add_argument(
        "--sample1",
        type=int,
        help="Optionally plot z value for N randomly sampled points",
    )
    p.add_argument(
        "--sample2",
        type=int,
        help="Optionally plot median z after chunking into N chunks",
    )
    p.add_argument("--out-s", help="Save sampled z values (.txt)")
    p.add_argument("--color", action="store_true", help="Color points by image index")
    p.add_argument(
        "--seed", default=0, type=int, help="Random seed (default: %(default)s)"
    )
    p.add_argument(
        "--annotate", action="store_true", help="Annotate sampled points in plot"
    )
    return p
def main(args):
    """PCA-project latent encodings and plot two chosen principal components."""
    np.random.seed(args.seed)
    fig, ax = plt.subplots()
    print(args.input)
    # BUGFIX: the input file handle was previously never closed
    with open(args.input, "rb") as f:
        x = pickle.load(f)

    # PCA
    pca = PCA(x.shape[1])
    pca.fit(x)
    print("Explained variance ratio:")
    print(pca.explained_variance_ratio_)
    pc = pca.transform(x)

    ii, jj = args.axis
    if args.color:
        plt.scatter(
            pc[:, ii],
            pc[:, jj],
            c=np.arange(len(x)),
            label=args.input,
            alpha=args.alpha,
            s=args.ms,
            cmap="hsv",
        )
    else:
        plt.scatter(pc[:, ii], pc[:, jj], label=args.input, alpha=args.alpha, s=args.ms)
    # BUGFIX: "{:3f}" means minimum field width 3 (full float precision);
    # "{:.3f}" gives the intended 3 decimal places.
    plt.xlabel("PC{} ({:.3f})".format(ii + 1, pca.explained_variance_ratio_[ii]))
    plt.ylabel("PC{} ({:.3f})".format(jj + 1, pca.explained_variance_ratio_[jj]))

    xd = None
    if args.sample1:
        # random sample of N points, shown in PC space
        s = np.random.choice(len(x), args.sample1)
        print(s)
        xd = x[s]
        xd_pc = pca.transform(xd)
        plt.scatter(xd_pc[:, ii], xd_pc[:, jj], c=np.arange(len(xd)), cmap="hsv")
        if args.annotate:
            for i in range(args.sample1):
                ax.annotate(str(i), xd_pc[i, args.axis])
    if args.sample2:
        # median z of each contiguous chunk, shown in PC space
        xsplit = np.array_split(x, args.sample2)
        print([len(k) for k in xsplit])
        xd = np.array([np.median(xs, axis=0) for xs in xsplit])
        print(len(xd))
        xd_pc = pca.transform(xd)
        plt.scatter(xd_pc[:, ii], xd_pc[:, jj], c="k")
    if args.out_s and xd is not None:
        np.savetxt(args.out_s, xd)

    plt.legend(loc="best")
    if args.out_png:
        plt.savefig(args.out_png)
    else:
        plt.show()
================================================
FILE: analysis_scripts/run_umap.py
================================================
"""
UMAP dimensionality reduction
"""
import argparse
import pickle
import warnings
import matplotlib.pyplot as plt
import umap
warnings.filterwarnings("ignore") # ignore numba warnings from umap
def parse_args():
    """Build the argument parser for UMAP embedding of latent encodings."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z.pkl")
    p.add_argument("--stride", type=int, help="Stride the dataset")
    p.add_argument("-o", help="Output UMAP embeddings (.pkl)")
    p.add_argument("--show", action="store_true", help="Show UMAP plot")
    return p
def main(args):
    """Embed latent encodings with UMAP, optionally saving and/or showing the result."""
    # BUGFIX: file handles were previously opened inline and never explicitly closed
    with open(args.input, "rb") as f:
        z = pickle.load(f)
    if args.stride:
        z = z[:: args.stride]
    print(z.shape)
    reducer = umap.UMAP()
    z_embedded = reducer.fit_transform(z)
    if args.o:
        with open(args.o, "wb") as f:
            pickle.dump(z_embedded, f)
    if args.show:
        plt.scatter(z_embedded[:, 0], z_embedded[:, 1], s=2, alpha=0.05)
        plt.show()
================================================
FILE: analysis_scripts/tsne.py
================================================
"""tSNE dimensionality reduction"""
import argparse
import pickle
import logging
from sklearn.manifold import TSNE
logger = logging.getLogger(__name__)
def parse_args():
    """Build the argument parser for t-SNE embedding of latent encodings."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("input", help="Input z.pkl")
    p.add_argument("--stride", type=int, help="Stride the dataset")
    p.add_argument(
        "-p", default=1000.0, type=float, help="Perplexity (default: %(default)s)"
    )
    p.add_argument("-o", help="Output pickle")
    return p
def main(args):
    """Load latent encodings, optionally stride them, then fit and save a 2-D t-SNE embedding."""
    with open(args.input, "rb") as f:
        z = pickle.load(f)
    if args.stride:
        z = z[:: args.stride]
        logger.info(
            f"Loaded zdim={z.shape[1]} latent space and used "
            f"striding to reduce to {z.shape[0]} datapoints"
        )
    else:
        logger.info(
            f"Loaded zdim={z.shape[1]} latent space with {z.shape[0]} datapoints"
        )
    logger.info("Fitting t-SNE...")
    embedding = TSNE(n_components=2, perplexity=args.p).fit_transform(z)
    with open(args.o, "wb") as f:
        pickle.dump(embedding, f)
if __name__ == "__main__":
main(parse_args().parse_args())
================================================
FILE: cryodrgn/__init__.py
================================================
"""cryodrgn package initialization: warning filters, version lookup, and logging setup."""
import os
import logging.config

# Necessary to avoid deprecation warnings from datetime.strptime when using seaborn
# with Python 3.13
import warnings

# Suppress the specific CPython warning triggered by date parsing without a year.
# NOTE: the message pattern below matches CPython's warning text verbatim
# (including its "ambiguious" misspelling) and must not be "corrected".
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="Parsing dates involving a day of month without "
    "a year specified is ambiguious.*",
)

# The _version.py file is managed by setuptools-scm
# and is not in version control.
try:
    from cryodrgn._version import version as __version__  # type: ignore
except ModuleNotFoundError:
    # We're likely running as a source package without installation
    __version__ = "src"

# Absolute path of this package's directory on disk.
_ROOT = os.path.abspath(os.path.dirname(__file__))

# Configure the root logger: all messages go to stdout with a timestamped
# "(LEVEL) (file) (time) message" format at INFO level and above.
logging.config.dictConfig(
    {
        "version": 1,
        "formatters": {
            "standard": {
                "format": "(%(levelname)s) (%(filename)s) (%(asctime)s) %(message)s",
                "datefmt": "%d-%b-%Y %H:%M:%S",
            }
        },
        "handlers": {
            "default": {
                "level": "NOTSET",
                "formatter": "standard",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stdout",
            }
        },
        "loggers": {"": {"handlers": ["default"], "level": "INFO"}},
    }
)
================================================
FILE: cryodrgn/analysis.py
================================================
import argparse
import re
import logging
import warnings
import matplotlib.pyplot as plt
from matplotlib.figure import Figure, Axes
import numpy as np
import numpy.typing as npt
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture
from typing import Optional, Union, Tuple, List
from cryodrgn.commands import eval_vol
logger = logging.getLogger(__name__)
# Necessary to avoid warnings from UMAP when using newer versions of numpy/scipy
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message=".*force_all_finite.*ensure_all_finite.*",
)
def parse_loss(f: str) -> np.ndarray:
    """Parse per-epoch total loss values from a cryoDRGN run.log file.

    Inputs:
        f: Path to a run.log file

    Returns:
        np.ndarray (float32) of the total loss parsed from each epoch summary line
    """
    # epoch summary lines are marked with "===="
    with open(f) as fh:
        lines = [x for x in fh if "====" in x]
    # BUGFIX: the previous pattern "total\sloss\s=\s(\d.\d+)" left the decimal
    # point unescaped (matching any character) and allowed only a single leading
    # digit, so e.g. "total loss = 12.345" was captured as "2.345".
    regex = r"total\sloss\s=\s(\d+\.\d+)"
    matches = [re.search(regex, x) for x in lines]
    loss = [m.group(1) for m in matches if m]
    return np.asarray(loss).astype(np.float32)
# Dimensionality reduction
def run_pca(z: np.ndarray) -> Tuple[np.ndarray, PCA]:
    """Fit a full-rank PCA to the latent encodings.

    Returns the projected coordinates and the fitted PCA object.
    """
    pca = PCA(z.shape[1])
    pca.fit(z)
    logger.info("Explained variance ratio:")
    logger.info(pca.explained_variance_ratio_)
    return pca.transform(z), pca
def get_pc_traj(
    pca: PCA,
    zdim: int,
    numpoints: int,
    dim: int,
    start: Optional[float],
    end: Optional[float],
    percentiles: Optional[np.ndarray] = None,
) -> npt.NDArray[np.float32]:
    """Build a trajectory of z-values along one principal component.

    Inputs:
        pca: sklearn PCA object from run_pca
        zdim (int): Dimension of the latent space
        numpoints (int): Number of points between @start and @end
        dim (int): PC dimension for the trajectory (1-based index)
        start (float): Value of PC{dim} to start trajectory
        end (float): Value of PC{dim} to stop trajectory
        percentiles (np.array or None): Explicit values along PC{dim} used
            instead of np.linspace(start, end, numpoints)

    Returns:
        np.array (numpoints x zdim) of z values along PC
    """
    traj_pca = np.zeros((numpoints, zdim))
    if percentiles is not None:
        assert len(percentiles) == numpoints
        traj_pca[:, dim - 1] = percentiles
    else:
        assert start is not None
        assert end is not None
        traj_pca[:, dim - 1] = np.linspace(start, end, numpoints)
    # map the PC-space trajectory back into latent space
    return pca.inverse_transform(traj_pca)
def run_tsne(
    z: np.ndarray, n_components: int = 2, perplexity: float = 1000
) -> np.ndarray:
    """Embed latent encodings with t-SNE, warning when the dataset is large."""
    threshold = 10000
    if len(z) > threshold:
        logger.warning(
            "WARNING: {} datapoints > {}. This may take awhile.".format(
                len(z), threshold
            )
        )
    tsne = TSNE(n_components=n_components, perplexity=perplexity)
    return tsne.fit_transform(z)
def run_umap(z: np.ndarray, **kwargs) -> np.ndarray:
    """Embed latent encodings with UMAP; kwargs are forwarded to umap.UMAP."""
    # deferred import -- CAN GET STUCK IN INFINITE IMPORT LOOP
    import umap

    return umap.UMAP(**kwargs).fit_transform(z)
# Clustering
def cluster_kmeans(
    z: np.ndarray,
    K: int,
    on_data: bool = True,
    reorder: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Partition latent encodings into K clusters using k-means.

    Inputs:
        z: (Ndata x zdim) latent encodings
        K: Number of clusters
        on_data: Replace each centroid with its nearest data point
        reorder: Reorder clusters according to agglomerative clustering
            of the cluster centers

    Returns:
        np.array (Ndata,) of cluster labels
        np.array (K x zdim) of cluster centers
    """
    model = KMeans(n_clusters=K, random_state=0, max_iter=10)
    labels = model.fit_predict(z)
    centers = model.cluster_centers_

    centers_ind = None
    if on_data:
        centers, centers_ind = get_nearest_point(z, centers)

    if reorder:
        # use the clustermap's row dendrogram to obtain a stable cluster ordering
        g = sns.clustermap(centers)
        new_order = g.dendrogram_row.reordered_ind
        centers = centers[new_order]
        if centers_ind is not None:
            centers_ind = centers_ind[new_order]
        # relabel points to match the reordered centers
        remap = {old: new for new, old in enumerate(new_order)}
        labels = np.array([remap[lbl] for lbl in labels])

    return labels, centers
def cluster_gmm(
    z,
    K: int,
    on_data: bool = True,
    random_state: Union[int, np.random.RandomState, None] = None,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster z with a K-component full-covariance Gaussian mixture model.

    Inputs:
        z (Ndata x zdim np.array): Latent encodings
        K (int): Number of clusters
        on_data (bool): Compute cluster center as nearest point on the data manifold
        random_state (int or None): Random seed used for GMM clustering
        **kwargs: Additional keyword arguments passed to sklearn.mixture.GaussianMixture

    Returns:
        np.array (Ndata,) of cluster labels
        np.array (K x zdim) of cluster centers
    """
    gmm = GaussianMixture(
        n_components=K, covariance_type="full", random_state=random_state, **kwargs
    )
    cluster_labels = gmm.fit_predict(z)
    means = gmm.means_
    if on_data:
        # Snap each mixture mean to its nearest actual datapoint
        means, _ = get_nearest_point(z, means)
    return cluster_labels, means
def get_nearest_point(
    data: np.ndarray, query: np.ndarray
) -> Tuple[npt.NDArray[np.float32], np.ndarray]:
    """For each row of @query, find the closest row of @data.

    Returns the matched datapoints and their indices into @data.
    """
    # Pairwise Euclidean distances: one row per query point
    nearest_idx = cdist(query, data).argmin(axis=1)
    return data[nearest_idx], nearest_idx
# HELPER FUNCTIONS FOR INDEX ARRAY MANIPULATION
def convert_original_indices(
    ind: np.ndarray, N_orig: int, orig_ind: np.ndarray
) -> np.ndarray:
    """Convert index array into indices into the original particle stack.

    `orig_ind` may be either an integer index array or a boolean mask over the
    original stack of size N_orig; `ind` indexes into that filtered subset.
    """
    kept_original = np.arange(N_orig)[orig_ind]
    return kept_original[ind]
def combine_ind(
    N: int, sel1: np.ndarray, sel2: np.ndarray, kind: str = "intersection"
) -> Tuple[np.ndarray, np.ndarray]:
    """Combine two particle selections over a stack of N particles.

    Returns (selected indices, complement indices), both sorted.
    `kind` chooses set intersection or set union of the two selections.
    """
    set1, set2 = set(sel1), set(sel2)
    if kind == "intersection":
        chosen = set1 & set2
    elif kind == "union":
        chosen = set1 | set2
    else:
        raise RuntimeError(
            f"Mode {kind} not recognized. Choose either 'intersection' or 'union'"
        )
    complement = sorted(set(np.arange(N)) - chosen)
    return np.array(sorted(chosen)), np.array(complement)
def get_ind_for_cluster(
    labels: np.ndarray, selected_clusters: np.ndarray
) -> np.ndarray:
    """Return index array of the selected clusters

    Inputs:
        labels: np.array of cluster labels for each particle
        selected_clusters: list of cluster labels to select

    Return:
        ind_selected: np.array of particle indices with the desired cluster labels
            (always integer dtype, even when no particles match)

    Example usage:
        ind_keep = get_ind_for_cluster(kmeans_labels, [0,4,6,14])
    """
    # Vectorized membership test instead of a Python-level loop; this also
    # guarantees an int-dtype result when nothing is selected (the old
    # list-comprehension version returned an empty float64 array, which
    # cannot be used to index another array).
    return np.where(np.isin(labels, selected_clusters))[0]
# PLOTTING
def _get_chimerax_colors(K: int) -> List:
colors = [
"#b2b2b2",
"#ffffb2",
"#b2ffff",
"#b2b2ff",
"#ffb2ff",
"#ffb2b2",
"#b2ffb2",
"#e5bf99",
"#99bfe5",
"#cccc99",
]
colors = [colors[i % len(colors)] for i in range(K)]
return colors
def _get_colors(K: int, cmap: Optional[str] = None) -> List:
if cmap is not None:
cm = plt.get_cmap(cmap)
colors = [cm(i / float(K)) for i in range(K)]
else:
colors = ["C{}".format(i) for i in range(10)]
colors = [colors[i % len(colors)] for i in range(K)]
return colors
def scatter_annotate(
    x: np.ndarray,
    y: np.ndarray,
    centers: Optional[np.ndarray] = None,
    centers_ind: Optional[np.ndarray] = None,
    annotate: bool = True,
    labels: Optional[np.ndarray] = None,
    alpha: Union[float, np.ndarray, None] = 0.1,
    s: Union[float, np.ndarray, None] = 1,
    colors: Union[list, str, None] = None,
) -> Tuple[Figure, Axes]:
    """Scatter plot with cluster centers optionally marked and numbered.

    Provide either explicit `centers` coordinates or `centers_ind` indices
    into x/y (not both). Annotations default to 1-based center numbers.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(x, y, alpha=alpha, s=s, rasterized=True)
    # plot cluster centers
    if centers_ind is not None:
        assert centers is None
        centers = np.array([[x[i], y[i]] for i in centers_ind])
    if centers is not None:
        ax.scatter(
            centers[:, 0],
            centers[:, 1],
            c=colors if colors is not None else "k",
            edgecolor="black",
        )
    if annotate:
        assert centers is not None
        lbls = labels if labels is not None else np.arange(len(centers)) + 1
        for idx, lbl in enumerate(lbls):
            # Offset the text slightly so it does not sit on the marker
            ax.annotate(str(lbl), centers[idx, 0:2] + np.array([0.1, 0.1]))
    return fig, ax
def scatter_annotate_hex(
    x: np.ndarray,
    y: np.ndarray,
    centers: Optional[np.ndarray] = None,
    centers_ind: Optional[np.ndarray] = None,
    annotate: bool = True,
    labels: Optional[np.ndarray] = None,
    colors: Union[list, str, None] = None,
) -> sns.JointGrid:
    """Hexbin joint plot with cluster centers optionally marked and numbered.

    Provide either explicit `centers` coordinates or `centers_ind` indices
    into x/y (not both). Annotations default to 1-based center numbers.
    """
    g = sns.jointplot(x=x, y=y, kind="hex", height=4)
    # plot cluster centers
    if centers_ind is not None:
        assert centers is None
        centers = np.array([[x[i], y[i]] for i in centers_ind])
    if centers is not None:
        g.ax_joint.scatter(
            centers[:, 0],
            centers[:, 1],
            c=colors if colors is not None else "k",
            edgecolor="black",
        )
    if annotate:
        assert centers is not None
        lbls = labels if labels is not None else np.arange(len(centers)) + 1
        for idx, lbl in enumerate(lbls):
            # Semi-transparent box keeps labels legible over the hexbins
            g.ax_joint.annotate(
                str(lbl),
                centers[idx, 0:2] + np.array([0.1, 0.1]),
                color="black",
                bbox=dict(boxstyle="square,pad=.1", ec="None", fc="1", alpha=0.5),
            )
    return g
def scatter_color(
    x: np.ndarray,
    y: np.ndarray,
    c: np.ndarray,
    cmap: str = "viridis",
    s=1,
    alpha: float = 0.1,
    label: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
) -> Tuple[Figure, Axes]:
    """Scatter plot colored by a continuous per-point value, with a colorbar."""
    fig, ax = plt.subplots(figsize=figsize)
    assert len(x) == len(y) == len(c)
    points = ax.scatter(x, y, s=s, alpha=alpha, rasterized=True, cmap=cmap, c=c)
    cbar = fig.colorbar(points)
    # Redraw the colorbar fully opaque even when the scatter is transparent;
    # draw_all was made private in newer matplotlib releases
    cbar.set_alpha(1)
    if hasattr(cbar, "draw_all"):
        cbar.draw_all()
    else:
        cbar._draw_all()
    if label:
        cbar.set_label(label)
    return fig, ax
def plot_by_cluster(
    x,
    y,
    K,
    labels,
    centers=None,
    centers_ind=None,
    annotate=False,
    s=2,
    alpha=0.1,
    colors=None,
    cmap=None,
    figsize=None,
):
    """Scatter plot of (x, y) with one color per cluster.

    Inputs:
        x, y: 1-D arrays of point coordinates
        K: number of clusters (int), or an explicit list of cluster ids to plot
        labels: per-point cluster ids; compared for equality against entries of K
        centers: optional (len(K) x 2) array of cluster-center coordinates
        centers_ind: optional indices into x/y giving the centers
            (mutually exclusive with `centers`)
        annotate: label each center with its cluster id (requires centers)
        s, alpha: marker size / opacity forwarded to plt.scatter
        colors: explicit per-cluster colors; defaults to a _get_colors palette
        cmap: colormap name used when `colors` is None
        figsize: forwarded to plt.subplots

    Returns:
        (fig, ax) matplotlib handles

    NOTE(review): an int K expands to ids 1..K and colors are indexed with
    colors[i - 1], while cluster_kmeans/cluster_gmm emit 0-based labels --
    callers using those labels presumably pass an explicit list; confirm.
    """
    fig, ax = plt.subplots(figsize=figsize)
    # An int K is shorthand for cluster ids 1..K
    if type(K) is int:
        K = list(range(1, K + 1))
    if colors is None:
        colors = _get_colors(len(K), cmap)
    # scatter by cluster
    for i in K:
        ii = labels == i
        x_sub = x[ii]
        y_sub = y[ii]
        plt.scatter(
            x_sub,
            y_sub,
            s=s,
            alpha=alpha,
            label="cluster {}".format(i),
            color=colors[i - 1],
            rasterized=True,
        )
    # plot cluster centers
    if centers_ind is not None:
        assert centers is None
        centers = np.array([[x[i], y[i]] for i in centers_ind])
    if centers is not None:
        plt.scatter(centers[:, 0], centers[:, 1], c="k")
    if annotate:
        assert centers is not None
        for ii, i in enumerate(K):
            ax.annotate(str(i), centers[ii, 0:2])
    return fig, ax
def plot_by_cluster_subplot(
    x, y, K, labels, s=2, alpha=0.1, colors=None, cmap=None, figsize=None
):
    """One scatter subplot per cluster, on a shared-axis grid.

    Inputs:
        x, y: 1-D arrays of point coordinates
        K: number of clusters (int, expanded to ids 1..K), or explicit id list
        labels: per-point cluster ids compared against entries of K
        s, alpha: marker size / opacity
        colors: per-cluster colors; defaults to a _get_colors palette
        cmap: colormap used when `colors` is None
        figsize: figure size; defaults to (10, 10) as before (this parameter
            was previously accepted but silently ignored)

    Returns:
        (fig, axes) where axes is the full array of subplot axes.

    Bug fixes vs. the previous version: `figsize` is honored; the axes array
    is no longer shadowed by the loop variable (the old code returned only
    the last axis); a single-cluster grid no longer crashes on `.ravel()`.
    """
    if type(K) is int:
        K = list(range(1, K + 1))
    ncol = int(np.ceil(len(K) ** 0.5))
    nrow = int(np.ceil(len(K) / ncol))
    if figsize is None:
        figsize = (10, 10)
    fig, axes = plt.subplots(ncol, nrow, sharex=True, sharey=True, figsize=figsize)
    if colors is None:
        colors = _get_colors(len(K), cmap)
    # np.atleast_1d handles the single-subplot case, where plt.subplots
    # returns a bare Axes object rather than an ndarray
    for i, ax in zip(K, np.atleast_1d(axes).ravel()):
        ii = labels == i
        x_sub = x[ii]
        y_sub = y[ii]
        ax.scatter(x_sub, y_sub, s=s, alpha=alpha, rasterized=True, color=colors[i - 1])
        ax.set_title(i)
    return fig, axes
def plot_euler(theta, phi, psi, plot_psi=True):
    """Hexbin plot of the (theta, phi) viewing directions, plus an optional
    histogram of the in-plane rotation angle psi."""
    grid = sns.jointplot(x=theta, y=phi, kind="hex", xlim=(-180, 180), ylim=(0, 180))
    grid.set_axis_labels("theta", "phi")
    if plot_psi:
        plt.figure()
        plt.hist(psi)
        plt.xlabel("psi")
def ipy_plot_interactive_annotate(df, ind, opacity=0.3):
    """Interactive plotly widget for a cryoDRGN pandas dataframe with annotated points.

    Inputs:
        df: dataframe as produced by load_dataframe; the first two columns
            become the initial x/y axes
        ind: row indices of the points to highlight and label on top of the
            main scatter
        opacity: marker opacity for the background scatter

    Returns:
        (widget, figure): an ipywidgets control panel for choosing axes and
        coloring, and the plotly FigureWidget it drives.
    """
    import plotly.graph_objs as go
    from ipywidgets import interactive

    # Per-point hover text; include the cluster label when available
    if "labels" in df.columns:
        text = [
            f"Class {k}: index {i}" for i, k in zip(df.index, df.labels)
        ]  # hovertext
    else:
        text = [f"index {i}" for i in df.index]  # hovertext

    xaxis, yaxis = df.columns[0], df.columns[1]
    # Background scatter of all points (Scattergl for speed on large data)
    scatter = go.Scattergl(
        x=df[xaxis],
        y=df[yaxis],
        mode="markers",
        text=text,
        marker=dict(
            size=2,
            opacity=opacity,
        ),
    )
    # Overlay: the annotated subset, drawn larger with visible labels
    sub = df.loc[ind]
    text = [f"{k}){i}" for i, k in zip(sub.index, sub.labels)]
    scatter2 = go.Scatter(
        x=sub[xaxis],
        y=sub[yaxis],
        mode="markers+text",
        text=text,
        textposition="top center",
        textfont=dict(size=9, color="black"),
        marker=dict(size=5, color="black"),
    )
    f = go.FigureWidget([scatter, scatter2])
    f.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)

    def update_axes(xaxis, yaxis, color_by, colorscale):
        # Callback wired to the ipywidgets dropdowns below; mutates the
        # existing traces in place rather than rebuilding the figure
        scatter = f.data[0]
        scatter.x = df[xaxis]
        scatter.y = df[yaxis]
        scatter.marker.colorscale = colorscale
        if colorscale is None:
            scatter.marker.color = None
        else:
            scatter.marker.color = df[color_by] if color_by != "index" else df.index
        scatter2 = f.data[1]
        scatter2.x = sub[xaxis]
        scatter2.y = sub[yaxis]
        # batch_update coalesces the layout changes into a single redraw
        with f.batch_update():
            f.layout.xaxis.title = xaxis
            f.layout.yaxis.title = yaxis

    widget = interactive(
        update_axes,
        yaxis=df.select_dtypes("number").columns,
        xaxis=df.select_dtypes("number").columns,
        color_by=df.columns,
        colorscale=[None, "hsv", "plotly3", "deep", "portland", "picnic", "armyrose"],
    )
    return widget, f
def ipy_plot_interactive(df, opacity=0.3):
    """Interactive plotly widget for a cryoDRGN pandas dataframe.

    Inputs:
        df: dataframe as produced by load_dataframe; the first two columns
            become the initial x/y axes
        opacity: marker opacity for the scatter

    Returns:
        (widget, figure, table): an ipywidgets control panel, the plotly
        FigureWidget (with lasso selection enabled), and a table widget that
        lists the indices of lasso-selected points.
    """
    import plotly.graph_objs as go
    from ipywidgets import interactive

    # Per-point hover text; include the cluster label when available
    if "labels" in df.columns:
        text = [
            f"Class {k}: index {i}" for i, k in zip(df.index, df.labels)
        ]  # hovertext
    else:
        text = [f"index {i}" for i in df.index]  # hovertext

    xaxis, yaxis = df.columns[0], df.columns[1]
    # Heuristic: shrink markers as the number of points grows (floor of 1.7)
    plt_size = max(1.7, 53 / df.shape[0] ** 0.31)
    plt_mrk = dict(
        size=plt_size, opacity=opacity, color=np.arange(len(df)), colorscale="hsv"
    )
    f = go.FigureWidget(
        [
            go.Scattergl(
                x=df[xaxis], y=df[yaxis], mode="markers", text=text, marker=plt_mrk
            )
        ]
    )
    scatter = f.data[0]
    f.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    # Lasso selection feeds the table widget via selection_fn below
    f.layout.dragmode = "lasso"

    def update_axes(xaxis, yaxis, color_by, colorscale):
        # Callback wired to the ipywidgets dropdowns below; mutates the
        # existing trace in place rather than rebuilding the figure
        scatter = f.data[0]
        scatter.x = df[xaxis]
        scatter.y = df[yaxis]
        scatter.marker.colorscale = colorscale
        if colorscale is None:
            scatter.marker.color = None
        else:
            scatter.marker.color = df[color_by] if color_by != "index" else df.index
        # batch_update coalesces the layout changes into a single redraw
        with f.batch_update():
            f.layout.xaxis.title = xaxis
            f.layout.yaxis.title = yaxis

    widget = interactive(
        update_axes,
        yaxis=df.select_dtypes("number").columns,
        xaxis=df.select_dtypes("number").columns,
        color_by=df.columns,
        colorscale=[None, "hsv", "plotly3", "deep", "portland", "picnic", "armyrose"],
    )
    # Companion table showing the currently lasso-selected point indices
    t = go.FigureWidget(
        [
            go.Table(
                header=dict(values=["index"]),
                cells=dict(values=[df.index]),
            )
        ]
    )

    def selection_fn(trace, points, selector):
        # Refresh the table whenever the lasso selection changes
        t.data[0].cells.values = [df.loc[points.point_inds].index]

    scatter.on_selection(selection_fn)
    return widget, f, t
def plot_projections(imgs, labels=None, max_imgs=25):
    """Plot up to `max_imgs` grayscale images on a square-ish grid.

    Inputs:
        imgs: sequence of 2-D image arrays
        labels: optional per-image titles
        max_imgs: cap on the number of images shown

    Returns:
        (fig, axes) matplotlib handles

    Bug fixes vs. the previous version: the grid is sized to hold *all*
    images (the old floor/ceil sizing silently dropped images, e.g. only 6
    of 7 were shown), and unused trailing axes are hidden instead of being
    left as empty framed plots.
    """
    if len(imgs) > max_imgs:
        imgs = imgs[:max_imgs]
    N = len(imgs)
    ncols = int(np.ceil(N**0.5))
    # Enough rows to fit every image given ncols columns
    nrows = int(np.ceil(N / ncols)) if ncols else 0
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(1.5 * ncols, 1.5 * nrows)
    )
    # plt.subplots returns a bare Axes for a 1x1 grid; normalize to an array
    if not isinstance(axes, np.ndarray):
        axes = np.array([[axes]])
    if labels is None:
        labels = [None for _ in imgs]
    for img, ax, lbl in zip(imgs, axes.ravel(), labels):
        ax.imshow(img, cmap="Greys_r")
        if lbl is not None:
            ax.set_title(lbl)
        ax.axis("off")
    # Hide any leftover axes beyond the last image
    for ax in axes.ravel()[N:]:
        ax.axis("off")
    plt.tight_layout()
    return fig, axes
def gen_volumes(
    weights,
    config,
    zfile,
    outdir,
    device=None,
    Apix=None,
    flip=False,
    downsample=None,
    invert=None,
    low_pass=None,
    crop=None,
    vol_start_index=1,
):
    """Generate volumes at specified z values by invoking cryodrgn eval_vol.

    Input:
        weights (str): Path to model weights .pkl
        config (str): Path to config.yaml
        zfile (str): Path to .txt file of z values
        outdir (str): Path to output directory for volumes,
        device (int or None): Specify cuda device
        Apix (float or None): Apix of output volume
        flip (bool): Flag to flip chirality of output volumes
        downsample (int or None): Generate volumes at this box size
        invert (bool): Invert contrast of output volumes
        low_pass (float or None): Low-pass filter resolution in Angstroms
        crop (int or None): Crop volume to this box size after downsampling or low-pass filtering
        vol_start_index (int): Start index for generated volumes
    """
    # Assemble an eval_vol command line, adding only the options the caller set
    cmd = [weights, "--config", config, "--zfile", zfile, "-o", outdir]
    if Apix is not None:
        cmd.extend(["--Apix", f"{Apix}"])
    if flip:
        cmd.append("--flip")
    if downsample is not None:
        cmd.extend(["-d", f"{downsample}"])
    if invert:
        cmd.append("--invert")
    if low_pass is not None:
        cmd.extend(["--low-pass", f"{low_pass}"])
    if crop is not None:
        cmd.extend(["--crop", f"{crop}"])
    if device is not None:
        cmd.extend(["--device", f"{device}"])
    if vol_start_index is not None:
        cmd.extend(["--vol-start-index", f"{vol_start_index}"])
    parser = argparse.ArgumentParser()
    eval_vol.add_args(parser)
    return eval_vol.main(parser.parse_args(cmd))
def load_dataframe(
    z=None, pc=None, euler=None, trans=None, labels=None, tsne=None, umap=None, **kwargs
):
    """Assemble analysis results into a single pandas dataframe.

    Columns appear in the order of the arguments below (UMAP, t-SNE, PCs,
    labels, Euler angles, translations, z, extras); a final "index" column
    holds each row's position.
    """
    data = {}
    if umap is not None:
        data["UMAP1"], data["UMAP2"] = umap[:, 0], umap[:, 1]
    if tsne is not None:
        data["TSNE1"], data["TSNE2"] = tsne[:, 0], tsne[:, 1]
    if pc is not None:
        # One column per principal component, 1-based naming
        for col in range(pc.shape[1]):
            data[f"PC{col+1}"] = pc[:, col]
    if labels is not None:
        data["labels"] = labels
    if euler is not None:
        data["theta"], data["phi"], data["psi"] = euler[:, 0], euler[:, 1], euler[:, 2]
    if trans is not None:
        data["tx"], data["ty"] = trans[:, 0], trans[:, 1]
    if z is not None:
        # One column per latent dimension, 0-based naming
        for col in range(z.shape[1]):
            data[f"z{col}"] = z[:, col]
    data.update(kwargs)
    df = pd.DataFrame(data=data)
    df["index"] = df.index
    return df
================================================
FILE: cryodrgn/analysis_drgnai.py
================================================
"""Visualizing latent space and generating volumes for trained models."""
import os
import shutil
import logging
from types import SimpleNamespace
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import nbformat
from cryodrgn import analysis, utils
from cryodrgn import _ROOT as CRYODRGN_ROOT
from cryodrgn import models_ai as models
from cryodrgn.mrcfile import write_mrc
from cryodrgn.lattice import Lattice
matplotlib.use("Agg")
class VolumeGenerator:
    """Helper class to call analysis.gen_volumes"""

    def __init__(
        self,
        hypervolume,
        lattice,
        zdim,
        invert,
        radius_mask,
        data_norm=(0, 1),
        vol_start_index=1,
        apix=1.0,
    ):
        # Trained decoder evaluated by models.eval_volume_method
        self.hypervolume = hypervolume
        # Lattice object passed through to volume evaluation
        self.lattice = lattice
        # Latent dimensionality of the model
        self.zdim = zdim
        # If set, multiply output volumes by -1 to invert contrast
        self.invert = invert
        # Radius passed to eval_volume_method as its output mask; may be None
        self.radius_mask = radius_mask
        # (mean, std) normalization applied during training
        self.data_norm = data_norm
        # Index used for the first numbered output file
        self.vol_start_index = vol_start_index
        # Pixel size written into the output .mrc headers
        self.apix = apix

    def gen_volumes(self, outdir, z_values, suffix=None):
        """Evaluate and write one .mrc volume per row of z_values.

        z_values: [nz, zdim]

        The z values themselves are also saved to `<outdir>/z_values.txt`.
        """
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        zfile = f"{outdir}/z_values.txt"
        np.savetxt(zfile, z_values)

        for i, z in enumerate(z_values):
            if suffix is None:
                # Zero-padded sequential names, e.g. vol_001.mrc
                out_mrc = "{}/{}{:03d}.mrc".format(
                    outdir, "vol_", i + self.vol_start_index
                )
            else:
                # NOTE(review): with a suffix, every iteration formats the
                # same filename, so later volumes overwrite earlier ones --
                # confirm whether the index should also appear in the name
                out_mrc = "{}/{}{:03d}.mrc".format(outdir, "vol_", suffix)

            vol = models.eval_volume_method(
                self.hypervolume,
                self.lattice,
                self.zdim,
                self.data_norm,
                zval=z,
                radius=self.radius_mask,
            )
            if self.invert:
                vol *= -1
            write_mrc(out_mrc, vol.cpu().numpy().astype(np.float32), Apix=self.apix)
class ModelAnalyzer:
    """An engine for analyzing the output of a reconstruction model.

    Attributes
    ----------
    configs (AnalysisConfigurations): Values of all parameters that can be
        set by the user.
    train_configs (TrainingConfigurations): Parameters that were used when
        the model was trained.
    epoch (int): Which epoch will be analyzed.
    skip_umap (bool): UMAP clustering is relatively computationally intense
        so sometimes we choose not to do it
    n_per_pc (int): How many samples of the latent reconstruction space
        will be taken along each principal component axis.
    """

    @classmethod
    def get_last_cached_epoch(cls, traindir: str) -> int:
        """Return the highest epoch with a `weights.<epoch>.pkl` checkpoint
        in `traindir`, or -2 if no checkpoints are present."""
        chkpnt_files = [fl for fl in os.listdir(traindir) if fl[:8] == "weights."]
        epoch = (
            -2
            if not chkpnt_files
            else max(
                int(fl.split(".")[1])
                for fl in os.listdir(traindir)
                if fl[:8] == "weights."
            )
        )
        return epoch

    def __init__(
        self, traindir: str, config_vals: dict, train_config_vals: dict
    ) -> None:
        """Load the checkpointed model and latent coordinates for `traindir`.

        Arguments
        ---------
        traindir: training output directory containing checkpoint files
        config_vals: user-set analysis parameters (epoch, invert, apix, ...)
        train_config_vals: saved training configuration; the "training" key
            holds the training parameters, remaining keys describe the output
            (e.g. data normalization)

        Raises
        ------
        ValueError: if no training checkpoints exist in `traindir`.
        """
        self.logger = logging.getLogger(__name__)
        self.configs = SimpleNamespace(**config_vals)
        self.train_configs = SimpleNamespace(**train_config_vals["training"])
        self.traindir = traindir

        # Find how input data was normalized for training
        self.out_cfgs = {k: v for k, v in train_config_vals.items() if k != "training"}
        if "data_norm_mean" not in self.out_cfgs:
            self.out_cfgs["data_norm_mean"] = 0.0
        if "data_norm_std" not in self.out_cfgs:
            self.out_cfgs["data_norm_std"] = 1.0

        # Use last completed epoch if no epoch given specified by the user
        if self.configs.epoch == -1:
            self.epoch = self.get_last_cached_epoch(traindir)
        else:
            self.epoch = self.configs.epoch
        if self.epoch == -2:
            # NOTE(review): "drngai" below is presumably a typo for "drgnai"
            raise ValueError(
                f"Cannot perform any analyses for output directory `{self.traindir}` "
                f"which does not contain any saved drngai training checkpoints!"
            )

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if self.use_cuda else "cpu")
        self.logger.info(f"Use cuda {self.use_cuda}")

        # Load reconstruction model from the saved checkpoint file
        checkpoint_path = os.path.join(self.traindir, f"weights.{self.epoch}.pkl")
        self.logger.info(f"Loading model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, weights_only=False)
        hypervolume_params = checkpoint["hypervolume_params"]
        hypervolume = models.HyperVolume(**hypervolume_params)
        hypervolume.load_state_dict(checkpoint["hypervolume_state_dict"])
        hypervolume.eval()
        hypervolume.to(self.device)

        lattice = Lattice(
            checkpoint["hypervolume_params"]["resolution"],
            extent=0.5,
            device=self.device,
        )
        self.zdim = checkpoint["hypervolume_params"]["z_dim"]
        # Older checkpoints may not record an output mask radius
        radius_mask = (
            checkpoint["output_mask_radius"]
            if "output_mask_radius" in checkpoint
            else None
        )
        self.vg = VolumeGenerator(
            hypervolume,
            lattice,
            self.zdim,
            self.configs.invert,
            radius_mask,
            data_norm=(self.out_cfgs["data_norm_mean"], self.out_cfgs["data_norm_std"]),
            apix=self.configs.apix,
        )

        # Load the conformations if the using a heterogeneous reconstruction model
        if self.train_configs.zdim > 0:
            self.z = utils.load_pkl(os.path.join(self.traindir, f"z.{self.epoch}.pkl"))
            self.n_samples = self.z.shape[0]
        else:
            self.z = None
            self.n_samples = None

        # Create an output directory for these analyses
        self.outdir = os.path.join(self.traindir, f"analyze.{self.epoch}")
        os.makedirs(self.outdir, exist_ok=True)

    @staticmethod
    def linear_interpolation(z_0, z_1, n, exclude_last=False):
        """Return n points linearly interpolated from z_0 to z_1 (inclusive);
        with exclude_last, the path stops 1/n short of z_1."""
        delta = 0 if not exclude_last else 1.0 / n
        t = np.linspace(0, 1 - delta, n)[..., None]
        return z_0[None] * (1.0 - t) + z_1[None] * t

    def analyze(self):
        """Run the z-dimension-appropriate analyses, then copy in the
        template Jupyter notebooks with this run's paths filled in."""
        if self.zdim == 0:
            self.logger.info("No analyses available for homogeneous reconstruction!")
            return
        if self.zdim == 1:
            self.analyze_z1()
        else:
            self.analyze_zN()

        for ipynb in ["cryoDRGN_figures", "cryoDRGN_filtering"]:
            out_ipynb = os.path.join(self.outdir, f"{ipynb}.ipynb")
            template_dir = os.path.join(CRYODRGN_ROOT, "templates")
            if not os.path.exists(out_ipynb):
                self.logger.info(f"Creating demo Jupyter notebook {out_ipynb}...")
                ipynb = os.path.join(template_dir, f"{ipynb}_template.ipynb")
                shutil.copyfile(ipynb, out_ipynb)
            else:
                self.logger.info(f"{out_ipynb} already exists. Skipping")

            # Edit the notebook with the epoch to analyze
            with open(out_ipynb, "r") as f:
                filter_ntbook = nbformat.read(f, as_version=nbformat.NO_CONVERT)
            for cell in filter_ntbook["cells"]:
                cell["source"] = cell["source"].replace(
                    "WORKDIR = None", f"WORKDIR = '{self.outdir}'"
                )
                cell["source"] = cell["source"].replace(
                    "EPOCH = None", f"EPOCH = {self.epoch}"
                )
                cell["source"] = cell["source"].replace(
                    "KMEANS = None", f"KMEANS = {self.configs.ksample}"
                )
            with open(out_ipynb, "w") as f:
                nbformat.write(filter_ntbook, f)

        self.logger.info("Done")

    def analyze_z1(self) -> None:
        """Plotting and volume generation for 1D z"""
        assert self.z.shape[1] == 1
        z = self.z.reshape(-1)
        n = len(z)

        # Scatter of z per particle
        plt.figure(1)
        plt.scatter(np.arange(n), z, alpha=0.1, s=2)
        plt.xlabel("particle")
        plt.ylabel("z")
        plt.savefig(os.path.join(self.outdir, "z.png"))
        plt.close()

        # Histogram of z values
        plt.figure(2)
        sns.distplot(z)
        plt.xlabel("z")
        plt.savefig(os.path.join(self.outdir, "z_hist.png"))
        plt.close()

        # Volumes along the 5th..95th percentiles of z
        ztraj = np.percentile(z, np.linspace(5, 95, 10))
        self.vg.gen_volumes(self.outdir, ztraj)

        # Volumes at k-means cluster centers snapped to the data
        kmeans_labels, centers = analysis.cluster_kmeans(
            z[..., None], self.configs.ksample, reorder=False
        )
        centers, centers_ind = analysis.get_nearest_point(z[:, None], centers)
        volpath = os.path.join(self.outdir, f"kmeans{self.configs.ksample}")
        self.vg.gen_volumes(volpath, centers)

    def analyze_zN(self) -> None:
        """Analyses for a multi-dimensional latent space: PCA, k-means,
        optional UMAP, volume generation along PC trajectories, and plots."""
        zdim = self.z.shape[1]

        # Principal component analysis
        self.logger.info("Performing principal component analysis...")
        pc, pca = analysis.run_pca(self.z)
        self.logger.info("Generating volumes...")
        for i in range(self.configs.pc):
            # Traverse each PC between the 5th and 95th percentile of the data
            start, end = np.percentile(pc[:, i], (5, 95))
            z_pc = analysis.get_pc_traj(
                pca, self.z.shape[1], self.configs.n_per_pc, i + 1, start, end
            )
            volpath = os.path.join(self.outdir, f"pc{i + 1}")
            self.vg.gen_volumes(volpath, z_pc)

        # Kmeans clustering
        self.logger.info("K-means clustering...")
        k = min(self.configs.ksample, self.n_samples)
        if self.n_samples < self.configs.ksample:
            self.logger.warning(f"Changing ksample to # of samples: {self.n_samples}")
        kmeans_labels, centers = analysis.cluster_kmeans(self.z, k)
        centers, centers_ind = analysis.get_nearest_point(self.z, centers)
        kmean_path = os.path.join(self.outdir, f"kmeans{k}")
        os.makedirs(kmean_path, exist_ok=True)
        utils.save_pkl(kmeans_labels, os.path.join(kmean_path, "labels.pkl"))
        np.savetxt(os.path.join(kmean_path, "centers.txt"), centers)
        np.savetxt(os.path.join(kmean_path, "centers_ind.txt"), centers_ind, fmt="%d")
        self.logger.info("Generating volumes...")
        self.vg.gen_volumes(kmean_path, centers)

        # UMAP -- slow step
        umap_emb = None
        if zdim > 2 and not self.configs.skip_umap:
            self.logger.info("Running UMAP...")
            # n_neighbors must be smaller than the number of samples
            if self.n_samples and self.n_samples < 15:
                n_neighbours = self.n_samples - 1
            else:
                n_neighbours = 15
            umap_emb = analysis.run_umap(self.z, n_neighbors=n_neighbours)
            utils.save_pkl(umap_emb, os.path.join(self.outdir, "umap.pkl"))

        # Make some plots
        self.logger.info("Generating plots...")

        def plt_pc_labels(pc1=0, pc2=1):
            # Axis labels annotated with each PC's explained variance ratio
            plt.xlabel(f"PC{pc1 + 1} " f"({pca.explained_variance_ratio_[pc1]:.2f})")
            plt.ylabel(f"PC{pc2 + 1} " f"({pca.explained_variance_ratio_[pc2]:.2f})")

        def plt_pc_labels_jointplot(g, pc1=0, pc2=1):
            # Same as plt_pc_labels, but for a seaborn JointGrid
            g.ax_joint.set_xlabel(
                f"PC{pc1 + 1} ({pca.explained_variance_ratio_[pc1]:.2f})"
            )
            g.ax_joint.set_ylabel(
                f"PC{pc2 + 1} ({pca.explained_variance_ratio_[pc2]:.2f})"
            )

        def plt_umap_labels():
            # UMAP coordinates are arbitrary, so tick values are hidden
            plt.xticks([])
            plt.yticks([])
            plt.xlabel("UMAP1")
            plt.ylabel("UMAP2")

        def plt_umap_labels_jointplot(g):
            g.ax_joint.set_xlabel("UMAP1")
            g.ax_joint.set_ylabel("UMAP2")

        # PCA -- Style 1 -- Scatter
        plt.figure(figsize=(4, 4))
        plt.scatter(pc[:, 0], pc[:, 1], alpha=0.1, s=1, rasterized=True)
        plt_pc_labels()
        plt.tight_layout()
        plt.savefig(os.path.join(self.outdir, "z_pca.png"))
        plt.close()

        # PCA -- Style 2 -- Scatter, with marginals
        g = sns.jointplot(
            x=pc[:, 0], y=pc[:, 1], alpha=0.1, s=1, rasterized=True, height=4
        )
        plt_pc_labels_jointplot(g)
        plt.tight_layout()
        plt.savefig(os.path.join(self.outdir, "z_pca_marginals.png"))
        plt.close()

        # PCA -- Style 3 -- Hexbin
        g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], height=4, kind="hex")
        plt_pc_labels_jointplot(g)
        plt.tight_layout()
        plt.savefig(os.path.join(self.outdir, "z_pca_hexbin.png"))
        plt.close()

        if umap_emb is not None:
            # Style 1 -- Scatter
            plt.figure(figsize=(4, 4))
            plt.scatter(umap_emb[:, 0], umap_emb[:, 1], alpha=0.1, s=1, rasterized=True)
            plt_umap_labels()
            plt.tight_layout()
            plt.savefig(os.path.join(self.outdir, "umap.png"))
            plt.close()

            # Style 2 -- Scatter with marginal distributions
            g = sns.jointplot(
                x=umap_emb[:, 0],
                y=umap_emb[:, 1],
                alpha=0.1,
                s=1,
                rasterized=True,
                height=4,
            )
            plt_umap_labels_jointplot(g)
            plt.tight_layout()
            plt.savefig(os.path.join(self.outdir, "umap_marginals.png"))
            plt.close()

            # Style 3 -- Hexbin / heatmap
            g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind="hex", height=4)
            plt_umap_labels_jointplot(g)
            plt.tight_layout()
            plt.savefig(os.path.join(self.outdir, "umap_hexbin.png"))
            plt.close()

        # Plot kmeans sample points
        colors = analysis._get_chimerax_colors(k)
        analysis.scatter_annotate(
            pc[:, 0],
            pc[:, 1],
            centers_ind=centers_ind,
            annotate=True,
            colors=colors,
        )
        plt_pc_labels()
        plt.tight_layout()
        plt.savefig(os.path.join(kmean_path, "z_pca.png"))
        plt.close()

        g = analysis.scatter_annotate_hex(
            pc[:, 0],
            pc[:, 1],
            centers_ind=centers_ind,
            annotate=True,
            colors=colors,
        )
        plt_pc_labels_jointplot(g)
        plt.tight_layout()
        plt.savefig(os.path.join(kmean_path, "z_pca_hex.png"))
        plt.close()

        if umap_emb is not None:
            analysis.scatter_annotate(
                umap_emb[:, 0],
                umap_emb[:, 1],
                centers_ind=centers_ind,
                annotate=True,
                colors=colors,
            )
            plt_umap_labels()
            plt.tight_layout()
            plt.savefig(os.path.join(kmean_path, "umap.png"))
            plt.close()

            g = analysis.scatter_annotate_hex(
                umap_emb[:, 0],
                umap_emb[:, 1],
                centers_ind=centers_ind,
                annotate=True,
                colors=colors,
            )
            plt_umap_labels_jointplot(g)
            plt.tight_layout()
            plt.savefig(os.path.join(kmean_path, "umap_hex.png"))
            plt.close()

        # Plot PC trajectories
        for i in range(self.configs.pc):
            start, end = np.percentile(pc[:, i], (5, 95))
            pc_path = os.path.join(self.outdir, f"pc{i + 1}")
            z_pc = analysis.get_pc_traj(pca, self.z.shape[1], 10, i + 1, start, end)

            if umap_emb is not None:
                # UMAP, colored by PCX
                analysis.scatter_color(
                    umap_emb[:, 0],
                    umap_emb[:, 1],
                    pc[:, i],
                    label=f"PC{i + 1}",
                )
                plt_umap_labels()
                plt.tight_layout()
                plt.savefig(os.path.join(pc_path, "umap.png"))
                plt.close()

                # UMAP, with PC traversal
                z_pc_on_data, pc_ind = analysis.get_nearest_point(self.z, z_pc)
                dists = ((z_pc_on_data - z_pc) ** 2).sum(axis=1) ** 0.5
                # Large snap distance means the traversal leaves the data manifold
                if np.any(dists > 2):
                    self.logger.warning(
                        f"Warning: PC{i + 1} point locations "
                        "in UMAP plot may be inaccurate"
                    )
                plt.figure(figsize=(4, 4))
                plt.scatter(
                    umap_emb[:, 0], umap_emb[:, 1], alpha=0.05, s=1, rasterized=True
                )
                plt.scatter(
                    umap_emb[pc_ind, 0],
                    umap_emb[pc_ind, 1],
                    c="cornflowerblue",
                    edgecolor="black",
                )
                plt_umap_labels()
                plt.tight_layout()
                plt.savefig(os.path.join(pc_path, "umap_traversal.png"))
                plt.close()

                # UMAP, with PC traversal, connected
                plt.figure(figsize=(4, 4))
                plt.scatter(
                    umap_emb[:, 0], umap_emb[:, 1], alpha=0.05, s=1, rasterized=True
                )
                plt.plot(umap_emb[pc_ind, 0], umap_emb[pc_ind, 1], "--", c="k")
                plt.scatter(
                    umap_emb[pc_ind, 0],
                    umap_emb[pc_ind, 1],
                    c="cornflowerblue",
                    edgecolor="black",
                )
                plt_umap_labels()
                plt.tight_layout()
                plt.savefig(os.path.join(pc_path, "umap_traversal_connected.png"))
                plt.close()

            # 10 points, from 5th to 95th percentile of PC1 values
            t = np.linspace(start, end, 10, endpoint=True)
            plt.figure(figsize=(4, 4))
            # For the last PC, plot against the previous PC instead of the next
            if i > 0 and i == self.configs.pc - 1:
                plt.scatter(pc[:, i - 1], pc[:, i], alpha=0.1, s=1, rasterized=True)
                plt.scatter(np.zeros(10), t, c="cornflowerblue", edgecolor="white")
                plt_pc_labels(i - 1, i)
            else:
                plt.scatter(pc[:, i], pc[:, i + 1], alpha=0.1, s=1, rasterized=True)
                plt.scatter(t, np.zeros(10), c="cornflowerblue", edgecolor="white")
                plt_pc_labels(i, i + 1)
            plt.tight_layout()
            plt.savefig(os.path.join(pc_path, "pca_traversal.png"))
            plt.close()

            if i > 0 and i == self.configs.pc - 1:
                g = sns.jointplot(
                    x=pc[:, i - 1],
                    y=pc[:, i],
                    alpha=0.1,
                    s=1,
                    rasterized=True,
                    height=4,
                )
                g.ax_joint.scatter(
                    np.zeros(10), t, c="cornflowerblue", edgecolor="white"
                )
                plt_pc_labels_jointplot(g, i - 1, i)
            else:
                g = sns.jointplot(
                    x=pc[:, i],
                    y=pc[:, i + 1],
                    alpha=0.1,
                    s=1,
                    rasterized=True,
                    height=4,
                )
                g.ax_joint.scatter(
                    t, np.zeros(10), c="cornflowerblue", edgecolor="white"
                )
                # NOTE(review): unlike the scatter version above, this call
                # omits (i, i + 1), so the axes fall back to the default
                # PC1/PC2 labels -- confirm whether that is intended
                plt_pc_labels_jointplot(g)
            plt.tight_layout()
            plt.savefig(os.path.join(pc_path, "pca_traversal_hex.png"))
            plt.close()
================================================
FILE: cryodrgn/beta_schedule.py
================================================
import numpy as np
def get_beta_schedule(schedule):
    """Map a user-specified beta setting to a schedule object.

    Arguments
    ---------
    schedule: a number (or a string parseable as a float, e.g. from a
        command-line argument) giving a constant beta, or one of the
        preset names "a", "b", "c", "d" for a linear ramp.

    Returns
    -------
    A callable schedule: ConstantSchedule or LinearSchedule.

    Raises
    ------
    RuntimeError: if the setting is neither numeric nor a known preset.
    """
    # Generalization: accept ints and numeric strings in addition to floats
    # for a constant schedule (previously only a bare float was recognized).
    if isinstance(schedule, (int, float)) and not isinstance(schedule, bool):
        return ConstantSchedule(float(schedule))
    if isinstance(schedule, str):
        try:
            return ConstantSchedule(float(schedule))
        except ValueError:
            pass
    if schedule == "a":
        return LinearSchedule(0.001, 15, 0, 1000000)
    elif schedule == "b":
        return LinearSchedule(5, 15, 200000, 800000)
    elif schedule == "c":
        return LinearSchedule(5, 18, 200000, 800000)
    elif schedule == "d":
        return LinearSchedule(5, 18, 1000000, 5000000)
    raise RuntimeError("Wrong beta schedule. Schedule={}".format(schedule))
class ConstantSchedule:
    """Beta schedule that returns the same value at every training step."""

    def __init__(self, value):
        # Fixed beta returned for every step
        self.value = value

    def __call__(self, x):
        # The step index `x` is ignored; beta is constant
        return self.value
class LinearSchedule:
    """Beta ramped linearly from start_y (at step start_x) to end_y (at step
    end_x), clamped to the [min, max] of the two endpoint values outside
    that range. Works for both increasing and decreasing ramps."""

    def __init__(self, start_y, end_y, start_x, end_x):
        # Clamp bounds are order-independent so decreasing ramps also work
        self.min_y = min(start_y, end_y)
        self.max_y = max(start_y, end_y)
        self.start_x = start_x
        self.start_y = start_y
        # Slope of the ramp in beta units per step
        self.coef = (end_y - start_y) / (end_x - start_x)

    def __call__(self, x):
        raw = self.start_y + self.coef * (x - self.start_x)
        # np.clip yields a numpy scalar; .item(0) converts it to a Python float
        return np.clip(raw, self.min_y, self.max_y).item(0)
================================================
FILE: cryodrgn/command_line.py
================================================
"""Creating the commands installed with cryoDRGN using the package's modules.
Here we add modules under the `cryodrgn.commands` and `cryodrgn.commands_utils` folders
to the namespace of commands that are installed as part of the cryoDRGN package.
Each module in the former folder thus corresponds to a `cryodrgn <module_name>` command,
while those in the latter folder correspond to a `cryodrgn_utils <module_name>` command.
See the `[project.scripts]` entry in the `pyproject.toml` file for how this module
is used to create the commands during installation. We list the modules to use
explicitly for each folder in case the namespace is inadvertently polluted, and also
since automated scanning for command modules is computationally non-trivial.
"""
import argparse
import os
from importlib import import_module
import re
import cryodrgn
def _get_commands(cmd_dir: str, cmds: list[str], doc_str: str = "") -> None:
    """Start up a command line interface using given modules as subcommands.

    Parses sys.argv and dispatches to the chosen subcommand's `main`.

    Arguments
    ---------
    cmd_dir: path to folder containing cryoDRGN command modules
    cmds: list of commands in the above directory we want to use in the package
    doc_str: short documentation string describing this list of commands as a whole

    Raises
    ------
    RuntimeError: if a listed module does not define `add_args`.
    """
    parser = argparse.ArgumentParser(description=doc_str)
    parser.add_argument(
        "--version", action="version", version="cryoDRGN " + cryodrgn.__version__
    )

    subparsers = parser.add_subparsers(title="Choose a command")
    subparsers.required = True
    dir_lbl = os.path.basename(cmd_dir)

    # look for Python modules that have the `add_args` method defined, which is what we
    # use to mark a module in these directories as added to the command namespace
    for cmd in cmds:
        module_name = ".".join(["cryodrgn", dir_lbl, cmd])
        module = import_module(module_name)

        if not hasattr(module, "add_args"):
            raise RuntimeError(
                f"Module `{cmd}` under `{cmd_dir}` does not have the required "
                f"`add_args()` function defined; see other modules under the "
                f"same directory for examples!"
            )

        # Parse the module-level documentation appearing at the top of the file
        parsed_doc = module.__doc__.split("\n") if module.__doc__ else list()
        descr_txt = parsed_doc[0] if parsed_doc else ""
        epilog_txt = "" if len(parsed_doc) <= 1 else "\n".join(parsed_doc[1:])

        # We have to manually re-add the backslashes used to break up lines
        # for multi-line examples as these get parsed into spaces by .__doc__
        # NOTE: This means command docstrings shouldn't otherwise have
        # consecutive spaces!
        epilog_txt = re.sub(" ([ ]+)", " \\\n\\1", epilog_txt)

        # the docstring header becomes the help message "description", while
        # the rest of the docstring becomes the "epilog"
        this_parser = subparsers.add_parser(
            cmd,
            description=descr_txt,
            epilog=epilog_txt,
            formatter_class=argparse.RawTextHelpFormatter,
        )
        module.add_args(this_parser)
        this_parser.set_defaults(func=module.main)

    args = parser.parse_args()
    args.func(args)
def main_commands() -> None:
    """Primary commands installed with cryoDRGN as `cryodrgn <cmd_module_name>."""
    # Each name is a module under cryodrgn/commands/ exposing add_args/main
    cmd_names = [
        "abinit",
        "abinit_het_old",
        "abinit_homo_old",
        "analyze",
        "analyze_landscape",
        "analyze_landscape_full",
        "backproject_voxel",
        "dashboard",
        "direct_traversal",
        "downsample",
        "eval_images",
        "eval_vol",
        "filter",
        "graph_traversal",
        "parse_ctf_csparc",
        "parse_ctf_star",
        "parse_pose_csparc",
        "parse_pose_star",
        "parse_star",
        "pc_traversal",
        "train_nn",
        "train_vae",
        "train_dec",
    ]
    _get_commands(
        cmd_dir=os.path.join(os.path.dirname(__file__), "commands"),
        cmds=cmd_names,
        doc_str="Commands installed with cryoDRGN",
    )
def util_commands() -> None:
    """Utility commands installed with cryoDRGN as `cryodrgn_utils <cmd_module_name>`."""
    # Order matters: subcommands are listed in `--help` in the order given here.
    utility_cmds = [
        "analyze_convergence",
        "add_psize",
        "clean",
        "concat_pkls",
        "filter_cs",
        "filter_mrcs",
        "filter_pkl",
        "filter_star",
        "flip_hand",
        "fsc",
        "gen_mask",
        "invert_contrast",
        "make_movies",
        "parse_relion",
        "phase_flip",
        "plot_classes",
        "plot_fsc",
        "select_clusters",
        "select_random",
        "translate_mrcs",
        "view_cs_header",
        "view_header",
        "view_mrcs",
        "write_cs",
        "write_star",
    ]
    cmd_folder = os.path.join(os.path.dirname(__file__), "commands_utils")
    _get_commands(
        cmd_dir=cmd_folder,
        cmds=utility_cmds,
        doc_str="Utility commands installed with cryoDRGN",
    )
================================================
FILE: cryodrgn/commands/README.md
================================================
# cryoDRGN commands #
This folder contains the primary commands that are installed as part of the cryoDRGN package, as well as any associated
auxiliary files.
See `cryodrgn.command_line` for how the contents of this folder are parsed as part of creating the cryoDRGN command
line interface upon installation of the package.
See also the `cryodrgn/commands_utils/` folder for the utility commands that are the other part of the cryoDRGN command
line interface.
================================================
FILE: cryodrgn/commands/__init__.py
================================================
================================================
FILE: cryodrgn/commands/abinit.py
================================================
"""Reconstructing volume(s) from picked cryoEM/ET particles using cryoDRGN-AI.
Example usage
-------------
# Run with fifty total training epochs, the first three of which are for pose search
$ cryodrgn abinit particles.mrcs -o cryodrgn-outs/001_abinit --zdim 8 \
--ctf ctf.pkl -n 50
"""
import os
import sys
import argparse
import pickle
import logging
from datetime import datetime as dt
import time
from types import SimpleNamespace
from typing_extensions import Any
import contextlib
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from cryodrgn.mrcfile import write_mrc
from cryodrgn.lattice import Lattice
from cryodrgn import utils, dataset, ctf
from cryodrgn.losses import kl_divergence_conf, l1_regularizer, l2_frequency_bias
from cryodrgn.models_ai import DrgnAI, MyDataParallel
from cryodrgn.masking import CircularMask, FrequencyMarchingMask
from cryodrgn.analysis_drgnai import ModelAnalyzer
from cryodrgn.config import save as save_config
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
def add_args(parser: argparse.ArgumentParser) -> None:
    """The command-line arguments for use with the command `cryodrgn abinit`.

    Arguments are registered in place on `parser`; related options are collected
    into argparse argument groups so they are sectioned nicely in `--help`.
    """
    # Positional input and top-level (ungrouped) options
    parser.add_argument(
        "particles",
        type=os.path.abspath,
        help="Input particles (.mrcs, .star, .cs, or .txt)",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=os.path.abspath,
        required=True,
        help="Output directory to save model",
    )
    parser.add_argument(
        "--load",
        type=str,
        help="Path to a checkpoint (weights.<epoch>.pkl) to load.",
    )
    parser.add_argument(
        "--load-poses",
        type=str,
        help="Path to a pose file (pose.<epoch>.pkl) to load.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        # NOTE: the random default is drawn once, when the parser is built
        default=np.random.randint(0, 100000),
        help="Fix the random seed used by numpy and PyTorch operations",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Increase verbosity"
    )
    parser.add_argument(
        "--ctf",
        type=os.path.abspath,
        help="Path to CTF parameters (.pkl). Must be the size of the full dataset.",
    )
    parser.add_argument(
        "--datadir",
        type=os.path.abspath,
        help="When using a .star or .cs file with relative paths, "
        "path to the directory containing the .mrcs files.",
    )
    # Dataset loading
    group = parser.add_argument_group("Dataset loading")
    group.add_argument(
        "--ind",
        type=os.path.abspath,
        help="Path to indices (.pkl) or number of images to keep (first images kept). "
        "Use full dataset if None.",
    )
    group.add_argument(
        "--relion31",
        action="store_true",
        help="Flag for relion 3.1 data format.",
    )
    group.add_argument(
        "--uninvert-data",
        dest="invert_data",
        action="store_false",
        help="Flag for not inverting input data (e.g. for EMPIAR-10076).",
    )
    group.add_argument(
        "--lazy",
        action="store_true",
        help="Flag for lazy data loading.",
    )
    group.add_argument(
        "--max-threads",
        type=int,
        default=16,
        help="Number of threads (default: %(default)s).",
    )
    group = parser.add_argument_group("Logging")
    group.add_argument(
        "--log-interval",
        type=int,
        default=10000,
        help="Logging interval in N_IMGS (default: %(default)s)",
    )
    group.add_argument(
        "--checkpoint",
        type=int,
        default=5,
        help="Checkpointing interval in N_EPOCHS (default: %(default)s)",
    )
    group.add_argument(
        "--verbose-time",
        action="store_true",
        help="Print time taken for each training step",
    )
    group = parser.add_argument_group("Training parameters")
    group.add_argument(
        "--num-epochs",
        "-n",
        type=int,
        default=30,
        help="Number of total epochs to train for (default: %(default)s)",
    )
    group.add_argument(
        "--epochs-pose-search",
        type=int,
        default=None,
        help="Number of epochs to train for pose search (default: %(default)s)",
    )
    group.add_argument(
        "--n-imgs-pose-search",
        type=int,
        default=None,
        help="Number of images to train for pose search (default: %(default)s)",
    )
    group.add_argument(
        "--epochs-sgd",
        type=int,
        default=None,
        help="Number of epochs to train for SGD (default: %(default)s)",
    )
    # FIX: was `parser.add_argument`, which dropped this option out of the
    # "Training parameters" help section its neighbours belong to
    group.add_argument(
        "--n-imgs-pretrain",
        type=int,
        default=10000,
        help="Number of images to use for pre-training (default: %(default)s).",
    )
    group.add_argument(
        "--pose-only-phase",
        type=int,
        default=0,
        help="Number of epochs to train for pose only phase (default: %(default)s)",
    )
    group.add_argument(
        "--no-shuffle",
        dest="shuffle",
        action="store_false",
        help="Disable shuffling of the dataset for training batches.",
    )
    group.add_argument(
        "--num-workers",
        type=int,
        default=2,
        help="Number of subprocesses to use for data loading (default: %(default)s)",
    )
    group.add_argument(
        "--shuffler-size",
        type=int,
        default=32768,
        help="Size of the shuffler when using accelerated data loading "
        "(default: %(default)s).",
    )
    group.add_argument(
        "--multigpu",
        action="store_true",
        help="Activate multi-GPU mode if more than one GPU is available.",
    )
    group.add_argument(
        "--no-amp",
        action="store_false",
        dest="amp",
        help="Disable automatic mixed precision (torch.amp).",
    )
    group.add_argument(
        "--batch-size-hps",
        type=int,
        default=16,
        help="Training batch size used for hierarchical pose search "
        "(default: %(default)s)",
    )
    group.add_argument(
        "--batch-size-known-poses",
        type=int,
        default=64,
        help="Training batch size used for pose refinement (default: %(default)s)",
    )
    group.add_argument(
        "--batch-size-sgd",
        type=int,
        default=128,
        help="Training batch size used for stochastic gradient descent "
        "(default: %(default)s)",
    )
    group = parser.add_argument_group("Optimizers")
    group.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer (default: %(default)s)",
    )
    group.add_argument(
        "--lr-pose-table",
        type=float,
        default=1e-3,
        help="Learning rate for the pose table optimizer (default: %(default)s)",
    )
    group.add_argument(
        "--lr-conf-table",
        type=float,
        default=1e-2,
        help="Learning rate for the conf table optimizer (default: %(default)s)",
    )
    group.add_argument(
        "--lr-conf-encoder",
        type=float,
        default=1e-4,
        help="Learning rate for the conf encoder optimizer (default: %(default)s)",
    )
    group.add_argument(
        "--wd",
        type=float,
        default=0.0,
        help="Weight decay for the optimizer (default: %(default)s)",
    )
    group.add_argument(
        "--hypervolume-optimizer-type",
        choices=("adam",),
        default="adam",
        help="Optimizer type for the hypervolume (default: %(default)s).",
    )
    group.add_argument(
        "--pose-table-optimizer-type",
        choices=("adam", "lbfgs"),
        default="adam",
        help="Optimizer for the pose table (default: %(default)s).",
    )
    group.add_argument(
        "--conf-table-optimizer-type",
        choices=("adam", "lbfgs"),
        default="adam",
        help="Optimizer for the conformation table (default: %(default)s).",
    )
    group.add_argument(
        "--conf-encoder-optimizer-type",
        choices=("adam",),
        default="adam",
        help="Optimizer for the conformation encoder (default: %(default)s).",
    )
    group = parser.add_argument_group("Masking")
    group.add_argument(
        "--output-mask",
        choices=("circ", "frequency_marching"),
        default="circ",
        help="Type of output mask to use (default: %(default)s)",
    )
    group.add_argument(
        "--add-one-frequency-every",
        type=int,
        default=100000,
        help="Frequency (in images) for adding new frequencies in the output mask "
        "during HPS (default: %(default)s).",
    )
    group.add_argument(
        "--n-frequencies-per-epoch",
        type=int,
        default=10,
        help="Number of frequencies to add in the output mask at each epoch "
        "during SGD (default: %(default)s).",
    )
    group.add_argument(
        "--max-freq",
        type=int,
        help="Highest frequency to use in the loss. Use all frequencies if not set.",
    )
    group.add_argument(
        "--window-radius-gt-real",
        type=float,
        default=0.85,
        help="Radius of the circular mask applied on images in real space; "
        "maximum radius is 1 (default: %(default)s).",
    )
    group = parser.add_argument_group("Losses")
    group.add_argument(
        "--beta-conf",
        type=float,
        default=0.0,
        help="Beta term penalizing KL divergence of conformation posterior; "
        "only used in variational mode (default: %(default)s).",
    )
    group.add_argument(
        "--trans-l1-regularizer",
        type=float,
        default=0.0,
        help="Strength of the L1 regularizer applied on estimated "
        "translations (default: %(default)s).",
    )
    group.add_argument(
        "--l2-smoothness-regularizer",
        type=float,
        default=0.0,
        help="Strength of the L2 smoothness regularization "
        "(penalization of strong gradients) (default: %(default)s).",
    )
    group = parser.add_argument_group("Z / heterogeneity")
    group.add_argument(
        "--variational-het",
        action="store_true",
        help="Activate variational mode of conformation estimation.",
    )
    group.add_argument(
        "--zdim",
        type=int,
        help="Dimension of conformation latent space; required.",
        required=True,
    )
    group.add_argument(
        "--std-z-init",
        type=float,
        default=0.1,
        help="Standard deviation of initial conformations "
        "(i.i.d. centered Gaussian) (default: %(default)s).",
    )
    group.add_argument(
        "--use-conf-encoder",
        action="store_true",
        help="Use an encoder to predict conformations.",
    )
    group.add_argument(
        "--depth-cnn",
        type=int,
        default=5,
        help="Depth of the encoder (default: %(default)s).",
    )
    group.add_argument(
        "--channels-cnn",
        type=int,
        default=32,
        help="Number of channels in the encoder (default: %(default)s).",
    )
    group.add_argument(
        "--kernel-size-cnn",
        type=int,
        default=3,
        help="Size of the kernels in the encoder (default: %(default)s).",
    )
    group.add_argument(
        "--resolution-encoder",
        type=int,
        help="Resolution of images given to the encoder. "
        "Images are not downsampled if not set.",
    )
    group = parser.add_argument_group("Hypervolume")
    group.add_argument(
        "--explicit-volume",
        action="store_true",
        help="Use an explicit volume (voxel array).",
    )
    group.add_argument(
        "--layers",
        type=int,
        default=3,
        help="Number of hidden layers in the hypervolume (default: %(default)s).",
    )
    group.add_argument(
        "--dim",
        type=int,
        default=256,
        help="Dimension of hidden layers in the hypervolume (default: %(default)s).",
    )
    group.add_argument(
        "--pe-type",
        choices=("gaussian",),
        default="gaussian",
        help="Type of positional encoding for Fourier coordinates "
        "(default: %(default)s).",
    )
    group.add_argument(
        "--pe-dim",
        type=int,
        default=64,
        help="Number of frequencies used for "
        "positional encoding (default: %(default)s).",
    )
    group.add_argument(
        "--feat-sigma",
        type=float,
        default=0.5,
        help="Standard deviation of encoding frequencies (default: %(default)s).",
    )
    group.add_argument(
        "--hypervolume-domain",
        choices=("hartley",),
        default="hartley",
        help="Domain of the hypervolume (default: %(default)s).",
    )
    group.add_argument(
        "--pe-type-conf",
        choices=(None, "geom"),
        default=None,
        help="Type of positional encoding for conformations (default: None).",
    )
    group.add_argument(
        "--initial-conf",
        type=os.path.abspath,
        help="Path to initial conformations (.pkl). "
        "Conformations are randomly initialized if not set.",
    )
    group = parser.add_argument_group("Pose search")
    group.add_argument(
        "--l-start",
        type=int,
        default=12,
        help="Number of frequencies to use during the first "
        "pose search step (default: %(default)s).",
    )
    group.add_argument(
        "--l-end",
        type=int,
        default=32,
        help="Number of frequencies to use during the last pose search step "
        "(default: %(default)s).",
    )
    group.add_argument(
        "--niter",
        type=int,
        default=4,
        help="Number of pose search iterations (default: %(default)s).",
    )
    group.add_argument(
        "--t-extent",
        type=float,
        default=20.0,
        help="Extent of the translation search grid, in pixels (default: %(default)s).",
    )
    group.add_argument(
        "--t-ngrid",
        type=int,
        default=7,
        help="Number of points per dimension in the translation search grid "
        "(default: %(default)s).",
    )
    group.add_argument(
        "--t-xshift",
        type=float,
        default=0.0,
        help="X-axis shift of the translation search grid (default: %(default)s).",
    )
    group.add_argument(
        "--t-yshift",
        type=float,
        default=0.0,
        help="Y-axis shift of the translation search grid (default: %(default)s).",
    )
    group.add_argument(
        "--no-trans-search-at-pose-search",
        action="store_true",
        help="Bypass the translation search during pose search.",
    )
    group.add_argument(
        "--nkeptposes",
        type=int,
        default=8,
        help="Number of poses kept per image (default: %(default)s).",
    )
    group.add_argument(
        "--base-healpy",
        type=int,
        default=2,
        help="Base healpy index (default: %(default)s).",
    )
    group.add_argument(
        "--no-trans",
        action="store_true",
        help="Indicate that the dataset does not contain translations.",
    )
    # These last two are deliberately top-level (ungrouped) options
    parser.add_argument(
        "--norm",
        type=float,
        nargs=2,
        default=None,
        help="Data normalization as shift, 1/scale (default: mean, std of dataset)",
    )
    parser.add_argument(
        "--no-analysis",
        dest="do_analysis",
        action="store_false",
        help="Do not run analysis on the final training epoch",
    )
class ModelTrainer:
"""An engine for training the DRGN-AI reconstruction model on particle data.
The two key methods of this engine class are the `__init__()` method, in which
model parameters and data structures are initialized, and `train()`, in which the
model is trained in batches over the particle input data.
Attributes
----------
configs (TrainingConfigurations)
Values of all user-set parameters controlling the behaviour of the model.
outdir (str): Folder `out/` within the experiment working directory where
model results will be saved.
n_particles_dataset (int): The number of picked particles in the data.
pretraining (bool): Whether we are in the pretraining stage.
epoch (int): Which training epoch the model is presently in.
logger (logging.Logger): Utility for printing and writing information
about the model as it is running.
"""
# options for optimizers to use
optim_types = {"adam": torch.optim.Adam, "lbfgs": torch.optim.LBFGS}
# placeholders for runtimes
run_phases = [
"dataloading",
"to_gpu",
"ctf",
"encoder",
"decoder",
"decoder_coords",
"decoder_query",
"loss",
"backward",
"to_cpu",
]
def make_dataloader(self, batch_size: int) -> DataLoader:
    """Build a loader over this trainer's dataset with the requested batch size."""
    cfg = self.configs
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=cfg.num_workers,
        shuffle=cfg.shuffle,
        seed=cfg.seed,
    )
    return dataset.make_dataloader(self.data, **loader_kwargs)
def __init__(self, outdir: str, config_vals: dict[str, Any]) -> None:
    """Initialize model parameters and variables.

    Arguments
    ---------
    outdir: Location on file where model results will be saved.
    config_vals: Parsed model parameter values provided by the user.

    NOTE(review): several attributes read from `config_vals` here (e.g. `n_iter`,
    `t_n_grid`, `n_kept_poses`, `t_x_shift`, `t_y_shift`, `l_start_fm`,
    `hypervolume_layers`, `hypervolume_dim`, `use_gt_trans`) do not match the
    argparse dests defined in `add_args` (`niter`, `t_ngrid`, `nkeptposes`,
    `t_xshift`, `t_yshift`, `layers`, `dim`, ...) — presumably the caller remaps
    them before constructing this class; verify against `main()`.
    """
    self.logger = logging.getLogger(__name__)
    self.outdir = outdir
    # SimpleNamespace gives attribute-style access to the parsed config dict
    self.configs = SimpleNamespace(**config_vals)
    # Create the output folder for model results and log file for model training
    os.makedirs(self.outdir, exist_ok=True)
    self.logger.addHandler(
        logging.FileHandler(os.path.join(self.outdir, "run.log"))
    )
    # Record the exact command line and configuration used for this run
    self.logger.info(" ".join(sys.argv))
    self.logger.info(self.configs)
    # Parallelize training across GPUs if --multigpu config option is selected;
    # batch sizes are scaled up by the GPU count so each device keeps its
    # original per-device batch size under DataParallel
    gpu_count = torch.cuda.device_count()
    if self.configs.multigpu and gpu_count > 1:
        self.n_prcs = int(gpu_count)
        self.logger.info(f"Using {gpu_count} GPUs!")
        if self.configs.batch_size_known_poses is not None:
            new_batch_size = self.configs.batch_size_known_poses * self.n_prcs
            self.logger.info(
                f"Increasing batch size for known poses to {new_batch_size}"
            )
            self.configs.batch_size_known_poses = new_batch_size
        if self.configs.batch_size_hps is not None:
            new_batch_size = self.configs.batch_size_hps * self.n_prcs
            self.logger.info(f"Increasing batch size for HPS to {new_batch_size}")
            self.configs.batch_size_hps = new_batch_size
        if self.configs.batch_size_sgd is not None:
            new_batch_size = self.configs.batch_size_sgd * self.n_prcs
            self.logger.info(f"Increasing batch size for SGD to {new_batch_size}")
            self.configs.batch_size_sgd = new_batch_size
    elif self.configs.multigpu:
        self.n_prcs = 1
        self.logger.warning(
            f"--multigpu selected, but only {gpu_count} GPUs detected!"
        )
    elif gpu_count > 1:
        self.n_prcs = 1
        self.logger.warning(
            f"Using one GPU in spite of {gpu_count} available GPUs "
            f"because --multigpu is not being used!"
        )
    else:
        self.n_prcs = 1
    # Seed both numpy and torch RNGs for reproducibility
    np.random.seed(self.configs.seed)
    torch.manual_seed(self.configs.seed)
    # Set the compute device, using the first available GPU
    self.use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda:0" if self.use_cuda else "cpu")
    self.logger.info(f"Use cuda {self.use_cuda}")
    # Load the index used to filter particles, if given: an int keeps the first
    # N particles, a string is a path to a pickled index array
    # NOTE(review): if `ind` is set but is neither int nor str, `self.index` is
    # never assigned and later reads will raise AttributeError — confirm callers
    # only pass int/str/None
    if self.configs.ind is not None:
        if isinstance(self.configs.ind, int):
            self.logger.info(f"Keeping {self.configs.ind} particles")
            self.index = np.arange(self.configs.ind)
        elif isinstance(self.configs.ind, str):
            if not os.path.exists(self.configs.ind):
                raise ValueError(
                    "Given subset index file "
                    f"`{self.configs.ind}` does not exist!"
                )
            self.logger.info(f"Filtering dataset with {self.configs.ind}")
            self.index = utils.load_pkl(self.configs.ind)
    else:
        self.index = None
    self.logger.info("Creating dataset")
    if self.configs.norm is None:
        norm_mean, norm_std = None, None
    # If user provides two values for data norm use them as the mean and standard
    # deviation; if one number is provided, use it as the mean and set std to 1.0
    elif isinstance(self.configs.norm, (list, tuple)):
        norm_mean, norm_std = self.configs.norm
    else:
        norm_mean, norm_std = self.configs.norm, 1.0
    # Fill in whichever half of (mean, std) the user omitted
    if norm_mean is not None and norm_std is not None:
        data_norm = (norm_mean, norm_std)
    elif norm_mean is not None:
        data_norm = (norm_mean, 1.0)
    elif norm_std is not None:
        data_norm = (0.0, norm_std)
    else:
        data_norm = None
    if data_norm is not None:
        self.logger.info(
            f"Manually overriding data normalization: (mean, std) = {data_norm}"
        )
    self.data = dataset.ImageDataset(
        self.configs.particles,
        norm=data_norm,
        keepreal=True,
        invert_data=self.configs.invert_data,
        ind=self.index,
        window_r=self.configs.window_radius_gt_real,
        max_threads=self.configs.max_threads,
        lazy=self.configs.lazy,
        datadir=self.configs.datadir,
    )
    # Single-particle mode: particle and tilt counts are the same here
    self.n_particles_dataset = self.data.N
    self.n_tilts_dataset = self.data.N
    self.resolution = self.data.D
    # Load contrast transfer function parameters, if given
    if self.configs.ctf is not None:
        self.logger.info(f"Loading ctf params from {self.configs.ctf}")
        ctf_params = ctf.load_ctf_for_training(
            self.resolution - 1, self.configs.ctf
        )
        # CTF file covers the full dataset; subset it to match any index filter
        if self.index is not None:
            self.logger.info("Filtering dataset")
            ctf_params = ctf_params[self.index]
        assert ctf_params.shape == (self.n_tilts_dataset, 8)
        self.ctf_params = torch.tensor(ctf_params)
        self.ctf_params = self.ctf_params.to(self.device)
    else:
        self.ctf_params = None
    # Pixel size is taken from the first CTF row; defaults to 1 without CTF
    self.apix = self.ctf_params[0, 0] if self.ctf_params is not None else 1
    self.logger.info("Building lattice")
    self.lattice = Lattice(self.resolution, extent=0.5, device=self.device)
    # Set up the output mask: either a fixed circular mask or one whose radius
    # grows as training proceeds (frequency marching)
    if self.configs.output_mask == "circ":
        radius = (
            self.lattice.D // 2
            if self.configs.max_freq is None
            else self.configs.max_freq
        )
        self.output_mask = CircularMask(self.lattice, radius)
    elif self.configs.output_mask == "frequency_marching":
        self.output_mask = FrequencyMarchingMask(
            self.lattice,
            self.lattice.D // 2,
            radius=self.configs.l_start_fm,
            add_one_every=self.configs.add_one_frequency_every,
        )
    else:
        raise NotImplementedError
    # Set up pose search epoch scheduling based on user inputs
    # Default number of pose search epochs is 3 (warmup, training, refinement) and
    # minimum number of pose search epochs is 2
    self.num_epochs = self.configs.num_epochs
    if self.configs.n_imgs_pose_search is not None:
        self.epochs_pose_search = max(
            2, self.configs.n_imgs_pose_search // self.n_particles_dataset + 1
        )
    elif self.configs.epochs_pose_search is not None:
        self.epochs_pose_search = self.configs.epochs_pose_search
    else:
        self.epochs_pose_search = 3
    # Either derive SGD epochs from the total, or let an explicit SGD epoch
    # count extend/override the total epoch count
    if self.configs.epochs_sgd is None:
        self.configs.epochs_sgd = self.num_epochs - self.epochs_pose_search
    else:
        self.num_epochs = self.epochs_pose_search + self.configs.epochs_sgd
    # Hierarchical pose search hyperparameters passed through to the model
    ps_params = {
        "l_min": self.configs.l_start,
        "l_max": self.configs.l_end,
        "t_extent": self.configs.t_extent,
        "t_n_grid": self.configs.t_n_grid,
        "niter": self.configs.n_iter,
        "nkeptposes": self.configs.n_kept_poses,
        "base_healpy": self.configs.base_healpy,
        "t_xshift": self.configs.t_x_shift,
        "t_yshift": self.configs.t_y_shift,
        "no_trans_search_at_pose_search": self.configs.no_trans_search_at_pose_search,
        "tilting_func": None,
    }
    # CNN (conformation encoder backbone) hyperparameters
    cnn_params = {
        "conf": self.configs.use_conf_encoder,
        "depth_cnn": self.configs.depth_cnn,
        "channels_cnn": self.configs.channels_cnn,
        "kernel_size_cnn": self.configs.kernel_size_cnn,
    }
    # Conformational encoder: zdim == 0 means homogeneous reconstruction
    if self.configs.zdim > 0:
        self.logger.info(
            "Heterogeneous reconstruction with " f"zdim = {self.configs.zdim}"
        )
    else:
        self.logger.info("Homogeneous reconstruction")
    conf_regressor_params = {
        "z_dim": self.configs.zdim,
        "std_z_init": self.configs.std_z_init,
        "variational": self.configs.variational_het,
    }
    # Hypervolume (implicit volume representation) hyperparameters
    hyper_volume_params = {
        "explicit_volume": self.configs.explicit_volume,
        "n_layers": self.configs.hypervolume_layers,
        "hidden_dim": self.configs.hypervolume_dim,
        "pe_type": self.configs.pe_type,
        "pe_dim": self.configs.pe_dim,
        "feat_sigma": self.configs.feat_sigma,
        "domain": self.configs.hypervolume_domain,
        "extent": self.lattice.extent,
        "pe_type_conf": self.configs.pe_type_conf,
    }
    will_use_point_estimates = self.configs.epochs_sgd >= 1
    self.logger.info("Initializing model...")
    self.model = DrgnAI(
        self.lattice,
        self.output_mask,
        self.n_particles_dataset,
        self.n_tilts_dataset,
        cnn_params,
        conf_regressor_params,
        hyper_volume_params,
        resolution_encoder=self.configs.resolution_encoder,
        no_trans=self.configs.no_trans,
        use_gt_poses=False,
        use_gt_trans=False,
        will_use_point_estimates=will_use_point_estimates,
        ps_params=ps_params,
        verbose_time=self.configs.verbose_time,
        pretrain_with_gt_poses=False,
    )
    # Initialization from a checkpoint saved to file from a previous training run
    if self.configs.load:
        self.logger.info(f"Loading checkpoint from {self.configs.load}")
        checkpoint = torch.load(self.configs.load, weights_only=False)
        state_dict = checkpoint["model_state_dict"]
        # base_shifts is dropped so the checkpoint loads regardless of the
        # translation-search grid used in the previous run
        if "base_shifts" in state_dict:
            state_dict.pop("base_shifts")
        self.logger.info(self.model.load_state_dict(state_dict, strict=False))
        self.start_epoch = checkpoint["epoch"] + 1
        if "output_mask_radius" in checkpoint:
            self.output_mask.update_radius(checkpoint["output_mask_radius"])
    else:
        # Epoch 0 is the pretraining epoch; skip it if no pretraining requested
        self.start_epoch = 0 if self.configs.n_imgs_pretrain > 0 else 1
    # Move to GPU and parallelize the model if necessary
    self.logger.info(self.model)
    parameter_count = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    self.logger.info(f"{parameter_count} parameters in model")
    # TODO: Replace with DistributedDataParallel
    if self.n_prcs > 1:
        self.model = MyDataParallel(self.model)
    self.logger.info("Model initialized. Moving to GPU...")
    self.model.to(self.device)
    # Keep the binary mask on CPU after moving the rest of the model to GPU
    self.model.output_mask.binary_mask = self.model.output_mask.binary_mask.cpu()
    self.optimizers = dict()
    self.optimizer_types = dict()
    # Hypervolume optimizer (note: rebinds the local `hyper_volume_params` name)
    hyper_volume_params = [{"params": list(self.model.hypervolume.parameters())}]
    self.optimizers["hypervolume"] = self.optim_types[
        self.configs.hypervolume_optimizer_type
    ](hyper_volume_params, lr=self.configs.lr)
    self.optimizer_types["hypervolume"] = self.configs.hypervolume_optimizer_type
    # Pose table optimizer — only needed once training uses point estimates
    if self.configs.epochs_sgd > 0:
        pose_table_params = [{"params": list(self.model.pose_table.parameters())}]
        self.optimizers["pose_table"] = self.optim_types[
            self.configs.pose_table_optimizer_type
        ](pose_table_params, lr=self.configs.lr_pose_table)
        self.optimizer_types["pose_table"] = self.configs.pose_table_optimizer_type
    # Z-latent-space conformations: optimize either the amortized encoder or
    # the per-particle conformation table, never both
    if self.configs.zdim > 0:
        if self.configs.use_conf_encoder:
            conf_encoder_params = [
                {
                    "params": (
                        list(self.model.conf_cnn.parameters())
                        + list(self.model.conf_regressor.parameters())
                    )
                }
            ]
            self.optimizers["conf_encoder"] = self.optim_types[
                self.configs.conf_encoder_optimizer_type
            ](
                conf_encoder_params,
                lr=self.configs.lr_conf_encoder,
                weight_decay=self.configs.wd,
            )
            self.optimizer_types[
                "conf_encoder"
            ] = self.configs.conf_encoder_optimizer_type
        else:
            conf_table_params = [
                {"params": list(self.model.conf_table.parameters())}
            ]
            self.optimizers["conf_table"] = self.optim_types[
                self.configs.conf_table_optimizer_type
            ](conf_table_params, lr=self.configs.lr_conf_table)
            self.optimizer_types[
                "conf_table"
            ] = self.configs.conf_table_optimizer_type
    self.optimized_modules = []
    # Complete initialization from a previous checkpoint
    # NOTE(review): the checkpoint file is loaded a second time here — could
    # reuse the earlier `checkpoint` object instead
    if self.configs.load:
        checkpoint = torch.load(self.configs.load, weights_only=False)
        for key in self.optimizers:
            self.optimizers[key].load_state_dict(
                checkpoint["optimizers_state_dict"][key]
            )
    # Data loaders used to iterate over training batches of input images;
    # one per training phase since each phase uses its own batch size
    self.data_generator_pose_search = self.make_dataloader(
        batch_size=self.configs.batch_size_hps
    )
    self.data_generator = self.make_dataloader(
        batch_size=self.configs.batch_size_known_poses
    )
    self.data_generator_latent_optimization = self.make_dataloader(
        batch_size=self.configs.batch_size_sgd
    )
    # Save configurations within the output directory for future reference
    cfg_path = os.path.join(self.outdir, "config.yaml")
    data_norm_mean = float(self.data.norm[0])
    data_norm_std = float(self.data.norm[1])
    payload = {
        "dataset_args": dict(
            particles=self.configs.particles,
            ctf=self.configs.ctf,
            invert_data=self.configs.invert_data,
            ind=self.configs.ind,
            datadir=self.configs.datadir,
        ),
        "lattice_args": dict(
            D=self.lattice.D,
            extent=self.lattice.extent,
            ignore_DC=self.lattice.ignore_DC,
        ),
        "model_args": dict(
            zdim=self.configs.zdim,
            use_conf_encoder=self.configs.use_conf_encoder,
            pe_type=self.configs.pe_type,
            pe_dim=self.configs.pe_dim,
            feat_sigma=self.configs.feat_sigma,
            domain=self.configs.hypervolume_domain,
        ),
        "training": dict(vars(self.configs)),
        "data_norm_mean": data_norm_mean,
        "data_norm_std": data_norm_std,
    }
    save_config(payload, cfg_path)
    # Threshold for treating float-valued regularizer weights as "enabled"
    epsilon = 1e-8
    # Booleans used to track the current state of the training process
    self.log_latents = False
    self.pose_only = True
    self.pretraining = False
    self.is_in_pose_search_step = False
    self.use_point_estimates = False
    self.first_switch_to_point_estimates = True
    self.first_switch_to_point_estimates_conf = True
    # When resuming past the pose-search phase, the switch to point estimates
    # has already happened in the previous run
    if self.configs.load is not None:
        if self.start_epoch > self.epochs_pose_search:
            self.first_switch_to_point_estimates = False
            self.first_switch_to_point_estimates_conf = False
    self.use_kl_divergence = (
        not self.configs.zdim == 0
        and self.configs.variational_het
        and self.configs.beta_conf >= epsilon
    )
    # NOTE(review): `use_gt_trans` is not an argparse dest in add_args —
    # presumably injected by the caller; confirm
    self.use_trans_l1_regularizer = (
        self.configs.trans_l1_regularizer >= epsilon
        and not self.configs.use_gt_trans
        and not self.configs.no_trans
    )
    self.use_l2_smoothness_regularizer = (
        self.configs.l2_smoothness_regularizer >= epsilon
    )
    # A negative n_imgs_pretrain means "pretrain on the whole dataset"
    self.n_particles_pretrain = (
        self.configs.n_imgs_pretrain
        if self.configs.n_imgs_pretrain >= 0
        else self.n_particles_dataset
    )
    # Placeholders for predicted latent variables, last input/output batch, losses
    self.in_dict_last = None
    self.y_pred_last = None
    self.predicted_rots = np.empty((self.n_tilts_dataset, 3, 3))
    self.predicted_trans = (
        np.empty((self.n_tilts_dataset, 2)) if not self.configs.no_trans else None
    )
    self.predicted_conf = (
        np.empty((self.n_particles_dataset, self.configs.zdim))
        if self.configs.zdim > 0
        else None
    )
    self.predicted_logvar = (
        np.empty((self.n_particles_dataset, self.configs.zdim))
        if self.configs.zdim > 0 and self.configs.variational_het
        else None
    )
    self.mask_particles_seen_at_last_epoch = np.zeros(self.n_particles_dataset)
    self.mask_tilts_seen_at_last_epoch = np.zeros(self.n_tilts_dataset)
    # Counters used to track the progress of the training process
    self.epoch = None
    self.run_times = {phase: [] for phase in self.run_phases}
    self.current_epoch_particles_count = 0
    self.total_batch_count = 0
    self.total_particles_count = 0
    self.batch_idx = 0
    self.cur_loss = None
    self.norm_mean, self.norm_std = self.data.norm
    # Activating Automatic Mixed Precision (AMP) model training through `torch.amp`;
    # warn when sizes are not multiples of 8, which AMP kernels prefer
    if self.configs.amp:
        self.logger.info("Using Automatic Mixed Precision training via torch.amp")
        if self.configs.pose_table_optimizer_type == "lbfgs":
            raise ValueError("AMP is not compatible with the lbfgs optimizer!")
        if (self.data.D - 1) % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"image box size {self.data.D-1} is not a multiple of 8!"
            )
        if self.configs.zdim > 0 and self.configs.zdim % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"{self.configs.zdim=} is not a multiple of 8!"
            )
        if self.configs.batch_size_hps % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"{self.configs.batch_size_hps=} is not a multiple of 8!"
            )
        if self.configs.batch_size_sgd % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"{self.configs.batch_size_sgd=} is not a multiple of 8!"
            )
        if self.configs.batch_size_known_poses % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"{self.configs.batch_size_known_poses=} is not a multiple of 8!"
            )
        if self.configs.hypervolume_dim % 8 != 0:
            self.logger.warning(
                f"torch.amp mixed precision training is not optimized; "
                f"{self.configs.hypervolume_dim=} is not a multiple of 8!"
            )
        self.scaler = torch.cuda.amp.GradScaler()
    else:
        self.scaler = None
    def train(self):
        """Run the full training loop over all epochs.

        Epochs progress through up to three phases: a pretraining epoch
        (epoch 0), hierarchical pose search (HPS) epochs, and finally SGD
        epochs in which poses (and optionally conformations) are optimized as
        learnable table entries (see `epoch_type`).  Predicted poses and
        conformations are accumulated on the instance and checkpointed
        periodically via `save_latents`, `save_volume` and `save_model`.
        """
        self.logger.info("--- Training Starts Now ---")
        t_0 = dt.now()
        # Initialize predicted poses: either resume from a pickle on disk or
        # start from identity rotations and zero translations for every tilt.
        if self.configs.load_poses is not None:
            self.logger.info(f"Loading poses from {self.configs.load_poses}")
            self.predicted_rots, self.predicted_trans = utils.load_pkl(
                self.configs.load_poses
            )
        else:
            self.predicted_rots = (
                np.eye(3).reshape(1, 3, 3).repeat(self.n_tilts_dataset, axis=0)
            )
            self.predicted_trans = (
                np.zeros((self.n_tilts_dataset, 2))
                if not self.configs.no_trans
                else None
            )
        # Latent conformations are only tracked for heterogeneous models.
        self.predicted_conf = (
            np.zeros((self.n_particles_dataset, self.configs.zdim))
            if self.configs.zdim > 0
            else None
        )
        self.total_batch_count = 0
        self.total_particles_count = 0
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            te = dt.now()
            self.epoch = epoch
            # Reset per-epoch bookkeeping.
            self.mask_particles_seen_at_last_epoch = np.zeros(self.n_particles_dataset)
            self.mask_tilts_seen_at_last_epoch = np.zeros(self.n_tilts_dataset)
            self.current_epoch_particles_count = 0
            self.optimized_modules = ["hypervolume"]
            # Optimize poses only (no conformations) early in training, for
            # homogeneous models (zdim == 0), or during pretraining.
            self.pose_only = (
                self.total_particles_count < self.configs.pose_only_phase
                or self.configs.zdim == 0
                or epoch == 0
            )
            self.pretraining = self.epoch == 0
            self.is_in_pose_search_step = 0 < epoch <= self.epochs_pose_search
            self.use_point_estimates = epoch > self.epochs_pose_search
            n_max_particles = self.n_particles_dataset
            data_generator = self.data_generator
            # Pre-training
            if self.pretraining:
                n_max_particles = self.n_particles_pretrain
                self.logger.info(f"Will pretrain on {n_max_particles} particles")
            # HPS
            elif self.is_in_pose_search_step:
                n_max_particles = self.n_particles_dataset
                self.logger.info(f"Will use pose search on {n_max_particles} particles")
                data_generator = self.data_generator_pose_search
            # SGD
            elif self.use_point_estimates:
                # On the first SGD epoch, seed the learnable pose table with
                # the poses produced by hierarchical pose search.
                if self.first_switch_to_point_estimates:
                    self.first_switch_to_point_estimates = False
                    self.logger.info("Switched to autodecoding poses")
                    self.logger.info(
                        "Initializing pose table from " "hierarchical pose search"
                    )
                    self.model.pose_table.initialize(
                        self.predicted_rots, self.predicted_trans
                    )
                    self.model.to(self.device)
                self.logger.info(
                    "Will use latent optimization on "
                    f"{self.n_particles_dataset} particles"
                )
                data_generator = self.data_generator_latent_optimization
                self.optimized_modules.append("pose_table")
            # GT poses
            else:
                raise RuntimeError("GT poses are not supported in this mode")
            # Z-latent-space conformations
            if not self.pose_only:
                if self.configs.use_conf_encoder:
                    self.optimized_modules.append("conf_encoder")
                else:
                    # Conformation table: optionally seed it from user-given
                    # z values the first time it becomes trainable.
                    if self.first_switch_to_point_estimates_conf:
                        self.first_switch_to_point_estimates_conf = False
                        if self.configs.initial_conf is not None:
                            self.logger.info(
                                "Initializing conformation table " "from given z's"
                            )
                            self.model.conf_table.initialize(
                                utils.load_pkl(self.configs.initial_conf)
                            )
                            self.model.to(self.device)
                    self.optimized_modules.append("conf_table")
            # A "full" summary (latents + volume + model checkpoint) is made
            # every log_heavy_interval epochs and on every HPS/pretraining
            # epoch as well as the final epoch.
            will_make_summary = epoch % self.configs.log_heavy_interval == 0
            will_make_summary |= self.is_in_pose_search_step
            will_make_summary |= self.pretraining
            will_make_summary |= epoch == self.num_epochs
            self.log_latents = will_make_summary
            if will_make_summary:
                self.logger.info("Will make a full summary at the end of this epoch")
            # Reset per-phase timing accumulators for this epoch.
            for key in self.run_times.keys():
                self.run_times[key] = []
            end_time = time.time()
            self.cur_loss = 0
            # Inner loop
            for batch_idx, in_dict in enumerate(data_generator):
                self.batch_idx = batch_idx
                # with torch.autograd.detect_anomaly():
                self.train_step(in_dict, end_time=end_time)
                if self.configs.verbose_time:
                    torch.cuda.synchronize()
                end_time = time.time()
                if self.current_epoch_particles_count > n_max_particles:
                    break
            # Mean per-particle loss over the epoch.
            total_loss = self.cur_loss / self.current_epoch_particles_count
            self.logger.info(
                f"# =====> {self.epoch_type()} Epoch: {self.epoch} "
                f"finished in {dt.now() - te}; "
                f"total loss = {format(total_loss, '.6f')}"
            )
            # Image and pose summary at the end of each epoch
            if will_make_summary:
                self.save_latents()
                self.save_volume()
                self.save_model()
            # Update output mask -- epoch-based scaling
            if hasattr(self.output_mask, "update_epoch") and self.use_point_estimates:
                self.output_mask.update_epoch(self.configs.n_frequencies_per_epoch)
        t_total = dt.now() - t_0
        self.logger.info(
            f"Finished in {t_total} ({t_total / self.num_epochs} per epoch)"
        )
def get_ctfs_at(self, index):
batch_size = len(index)
ctf_params_local = (
self.ctf_params[index] if self.ctf_params is not None else None
)
if ctf_params_local is not None:
freqs = self.lattice.freqs2d.unsqueeze(0).expand(
batch_size, *self.lattice.freqs2d.shape
) / ctf_params_local[:, 0].view(batch_size, 1, 1)
ctf_local = ctf.compute_ctf(
freqs, *torch.split(ctf_params_local[:, 1:], 1, 1)
).view(batch_size, self.resolution, self.resolution)
else:
ctf_local = None
return ctf_local
    def train_step(self, in_dict, end_time):
        """Run a single optimization step on one batch.

        in_dict: batch dict from the data loader containing at least "y"
            (images) and "index" (dataset indices).  A "tilt_index" alias is
            added here, and `forward_pass` adds the "ctf" entry.
        end_time: wall-clock timestamp taken at the end of the previous step,
            used to measure dataloading time when `verbose_time` is set.
        """
        if self.configs.verbose_time:
            torch.cuda.synchronize()
            self.run_times["dataloading"].append(time.time() - end_time)
        # Update output mask -- image-based scaling
        if hasattr(self.output_mask, "update") and self.is_in_pose_search_step:
            self.output_mask.update(self.total_particles_count)
        # During pose search, bound the frequency band searched over by the
        # configured [l_start, l_end] range (and the current mask radius for
        # non-circular masks).
        if self.is_in_pose_search_step:
            self.model.ps_params["l_min"] = self.configs.l_start
            if self.configs.output_mask == "circ":
                self.model.ps_params["l_max"] = self.configs.l_end
            else:
                self.model.ps_params["l_max"] = min(
                    self.output_mask.current_radius, self.configs.l_end
                )
        y_gt = in_dict["y"]
        ind = in_dict["index"]
        # NOTE(review): tilt indices are aliased to particle indices here,
        # i.e. this code path treats each particle as its own tilt.
        in_dict["tilt_index"] = in_dict["index"]
        ind_tilt = in_dict["tilt_index"]
        self.total_batch_count += 1
        batch_size = len(y_gt)
        self.total_particles_count += batch_size
        self.current_epoch_particles_count += batch_size
        # Move to GPU
        if self.configs.verbose_time:
            torch.cuda.synchronize()
        start_time_gpu = time.time()
        for key in in_dict.keys():
            if in_dict[key] is not None:
                in_dict[key] = in_dict[key].to(self.device)
        if self.configs.verbose_time:
            torch.cuda.synchronize()
            self.run_times["to_gpu"].append(time.time() - start_time_gpu)
        # Zero grad
        for key in self.optimized_modules:
            self.optimizers[key].zero_grad()
        # Forward pass
        # Mixed precision is only used outside of pose search steps.
        if self.scaler is not None and not self.is_in_pose_search_step:
            amp_mode = torch.cuda.amp.autocast()
        else:
            amp_mode = contextlib.nullcontext()
        with amp_mode:
            latent_variables_dict, y_pred, y_gt_processed = self.forward_pass(in_dict)
            # Disable pose search on the model after the first forward pass so
            # that any further passes this step (e.g. the LBFGS closure below)
            # reuse the found poses instead of searching again.
            if self.n_prcs > 1:
                self.model.module.is_in_pose_search_step = False
            else:
                self.model.is_in_pose_search_step = False
            # Loss
            if self.configs.verbose_time:
                torch.cuda.synchronize()
            start_time_loss = time.time()
            total_loss, all_losses = self.loss(
                y_pred, y_gt_processed, latent_variables_dict
            )
            if self.configs.verbose_time:
                torch.cuda.synchronize()
                self.run_times["loss"].append(time.time() - start_time_loss)
        # Backward pass
        if self.configs.verbose_time:
            torch.cuda.synchronize()
        start_time_backward = time.time()
        # Scale gradients under AMP to avoid underflow in float16.
        if self.scaler is not None:
            self.scaler.scale(total_loss).backward()
        else:
            total_loss.backward()
        # Accumulate the epoch's loss as a sum over particles (averaged later).
        self.cur_loss += total_loss.item() * len(ind)
        for key in self.optimized_modules:
            if self.optimizer_types[key] == "adam":
                if self.scaler is not None:
                    self.scaler.step(self.optimizers[key])
                else:
                    self.optimizers[key].step()
            elif self.optimizer_types[key] == "lbfgs":
                # LBFGS re-evaluates the model multiple times per step and
                # therefore needs a closure that redoes forward + backward.
                def closure():
                    self.optimizers[key].zero_grad()
                    (
                        _latent_variables_dict,
                        _y_pred,
                        _y_gt_processed,
                    ) = self.forward_pass(in_dict)
                    _loss, _ = self.loss(
                        _y_pred, _y_gt_processed, _latent_variables_dict
                    )
                    _loss.backward()
                    return _loss.item()
                self.optimizers[key].step(closure)
            else:
                raise NotImplementedError
        if self.scaler is not None:
            self.scaler.update()
        if self.configs.verbose_time:
            torch.cuda.synchronize()
            self.run_times["backward"].append(time.time() - start_time_backward)
        # Detach from GPU
        if self.log_latents:
            self.in_dict_last = in_dict
            self.y_pred_last = y_pred
            if self.configs.verbose_time:
                torch.cuda.synchronize()
            start_time_cpu = time.time()
            rot_pred, trans_pred, conf_pred, logvar_pred = self.detach_latent_variables(
                latent_variables_dict
            )
            if self.configs.verbose_time:
                torch.cuda.synchronize()
                self.run_times["to_cpu"].append(time.time() - start_time_cpu)
            # Log
            if self.use_cuda:
                ind = ind.cpu()
                ind_tilt = ind_tilt.cpu()
            # Record which particles/tilts were seen and cache their newly
            # predicted poses and conformations.
            self.mask_particles_seen_at_last_epoch[ind] = 1
            self.mask_tilts_seen_at_last_epoch[ind_tilt] = 1
            self.predicted_rots[ind_tilt] = rot_pred.reshape(-1, 3, 3)
            if not self.configs.no_trans:
                self.predicted_trans[ind_tilt] = trans_pred.reshape(-1, 2)
            if self.configs.zdim > 0:
                self.predicted_conf[ind] = conf_pred
                if self.configs.variational_het:
                    self.predicted_logvar[ind] = logvar_pred
        else:
            self.run_times["to_cpu"].append(0.0)
        # Scalar summary
        # Log progress roughly once every `log_interval` particles.
        if self.total_particles_count % self.configs.log_interval < batch_size:
            self.make_light_summary(all_losses)
def detach_latent_variables(self, latent_variables_dict):
rot_pred = latent_variables_dict["R"].detach().cpu().numpy()
trans_pred = (
latent_variables_dict["t"].detach().cpu().numpy()
if not self.configs.no_trans
else None
)
conf_pred = (
latent_variables_dict["z"].detach().cpu().numpy()
if self.configs.zdim > 0 and "z" in latent_variables_dict
else None
)
logvar_pred = (
latent_variables_dict["z_logvar"].detach().cpu().numpy()
if self.configs.zdim > 0 and "z_logvar" in latent_variables_dict
else None
)
return rot_pred, trans_pred, conf_pred, logvar_pred
def forward_pass(self, in_dict):
if self.configs.verbose_time:
torch.cuda.synchronize()
start_time_ctf = time.time()
ctf_local = self.get_ctfs_at(in_dict["tilt_index"])
if self.configs.verbose_time:
torch.cuda.synchronize()
self.run_times["ctf"].append(time.time() - start_time_ctf)
# Forward pass
if "hypervolume" in self.optimized_modules:
self.model.hypervolume.train()
else:
self.model.hypervolume.eval()
if hasattr(self.model, "conf_cnn"):
if hasattr(self.model, "conf_regressor"):
if "conf_encoder" in self.optimized_modules:
self.model.conf_cnn.train()
self.model.conf_regressor.train()
else:
self.model.conf_cnn.eval()
self.model.conf_regressor.eval()
if hasattr(self.model, "pose_table"):
if "pose_table" in self.optimized_modules:
self.model.pose_table.train()
else:
self.model.pose_table.eval()
if hasattr(self.model, "conf_table"):
if "conf_table" in self.optimized_modules:
self.model.conf_table.train()
else:
self.model.conf_table.eval()
in_dict["ctf"] = ctf_local
if self.n_prcs > 1:
self.model.module.pose_only = self.pose_only
self.model.module.use_point_estimates = self.use_point_estimates
self.model.module.pretrain = self.pretraining
self.model.module.is_in_pose_search_step = self.is_in_pose_search_step
self.model.module.use_point_estimates_conf = (
not self.configs.use_conf_encoder
)
else:
self.model.pose_only = self.pose_only
self.model.use_point_estimates = self.use_point_estimates
self.model.pretrain = self.pretraining
self.model.is_in_pose_search_step = self.is_in_pose_search_step
self.model.use_point_estimates_conf = not self.configs.use_conf_encoder
out_dict = self.model(in_dict)
self.run_times["encoder"].append(
torch.mean(out_dict["time_encoder"].cpu())
if self.configs.verbose_time
else 0.0
)
self.run_times["decoder"].append(
torch.mean(out_dict["time_decoder"].cpu())
if self.configs.verbose_time
else 0.0
)
self.run_times["decoder_coords"].append(
torch.mean(out_dict["time_decoder_coords"].cpu())
if self.configs.verbose_time
else 0.0
)
self.run_times["decoder_query"].append(
torch.mean(out_dict["time_decoder_query"].cpu())
if self.configs.verbose_time
else 0.0
)
latent_variables_dict = out_dict
y_pred = out_dict["y_pred"]
y_gt_processed = out_dict["y_gt_processed"]
return latent_variables_dict, y_pred, y_gt_processed
def loss(self, y_pred, y_gt, latent_variables_dict):
"""
y_pred: [batch_size(, n_tilts), n_pts]
y_gt: [batch_size(, n_tilts), n_pts]
"""
all_losses = {}
# Data loss
data_loss = F.mse_loss(y_pred, y_gt)
all_losses["Data Loss"] = data_loss.item()
total_loss = data_loss
# KL divergence
if self.use_kl_divergence:
kld_conf = kl_divergence_conf(latent_variables_dict)
total_loss += self.configs.beta_conf * kld_conf / self.resolution**2
all_losses["KL Div. Conf."] = kld_conf.item()
# L1 regularization for translations
if self.use_trans_l1_regularizer and self.use_point_estimates:
trans_l1_loss = l1_regularizer(latent_variables_dict["t"])
total_loss += self.configs.trans_l1_regularizer * trans_l1_loss
all_losses["L1 Reg. Trans."] = trans_l1_loss.item()
# L2 smoothness prior
if self.use_l2_smoothness_regularizer:
smoothness_loss = l2_frequency_bias(
y_pred,
self.lattice.freqs2d,
self.output_mask.binary_mask,
self.resolution,
)
total_loss += self.configs.l2_smoothness_regularizer * smoothness_loss
all_losses["L2 Smoothness Loss"] = smoothness_loss.item()
return total_loss, all_losses
def make_light_summary(self, all_losses: dict[str, float]) -> None:
"""Creates a log describing progress within batches of a training epoch."""
self.logger.info(
f"# [Train Epoch: {self.epoch}/{self.num_epochs}] "
f"[{self.current_epoch_particles_count}"
f"/{self.n_particles_dataset} particles]"
)
if hasattr(self.output_mask, "current_radius"):
all_losses["Mask Radius"] = self.output_mask.current_radius
if self.model.trans_search_factor is not None:
all_losses["Trans. Search Factor"] = self.model.trans_search_factor
if self.configs.verbose_time:
for key in self.run_times.keys():
self.logger.info(
f"{key} time: {np.mean(np.array(self.run_times[key]))}"
)
def save_latents(self) -> None:
"""Write model's latent variables to file."""
out_pose = os.path.join(self.outdir, f"pose.{self.epoch}.pkl")
if self.configs.no_trans:
with open(out_pose, "wb") as f:
pickle.dump(self.predicted_rots, f)
else:
with open(out_pose, "wb") as f:
pickle.dump((self.predicted_rots, self.predicted_trans), f)
if self.configs.zdim > 0:
out_conf = os.path.join(self.outdir, f"z.{self.epoch}.pkl")
with open(out_conf, "wb") as f:
pickle.dump(self.predicted_conf, f)
    def save_volume(self) -> None:
        """Write a reconstructed volume for the current epoch to an .mrc file.

        All sub-modules are switched to eval mode first.  For heterogeneous
        models (zdim > 0), the volume is decoded at the latent coordinates of
        the particle whose embedding lies closest to the mean embedding;
        otherwise the single homogeneous volume is decoded.
        """
        out_mrc = os.path.join(self.outdir, f"reconstruct.{self.epoch}.mrc")
        # Put every sub-module into eval mode before decoding.
        self.model.hypervolume.eval()
        if hasattr(self.model, "conf_cnn"):
            if hasattr(self.model, "conf_regressor"):
                self.model.conf_cnn.eval()
                self.model.conf_regressor.eval()
        if hasattr(self.model, "pose_table"):
            self.model.pose_table.eval()
        if hasattr(self.model, "conf_table"):
            self.model.conf_table.eval()
        # For heterogeneous models reconstruct the volume at the latent coordinates
        # of the image whose embedding is closest to the mean of all embeddings
        if self.configs.zdim > 0:
            mean_z = np.mean(self.predicted_conf, axis=0)
            distances = np.linalg.norm(self.predicted_conf - mean_z, axis=1)
            closest_idx = np.argmin(distances)
            zval = self.predicted_conf[closest_idx].reshape(-1)
        else:
            zval = None
        # Decode the volume and write it out at the dataset's pixel size.
        vol = self.model.eval_volume(self.data.norm, zval=zval)
        write_mrc(out_mrc, vol.cpu().numpy().astype(np.float32), Apix=self.apix)
def save_model(self) -> None:
"""Write current PyTorch model state to file."""
out_weights = os.path.join(self.outdir, f"weights.{self.epoch}.pkl")
optimizers_state_dict = {}
for key in self.optimizers.keys():
optimizers_state_dict[key] = self.optimizers[key].state_dict()
saved_objects = {
"epoch": self.epoch,
"model_state_dict": (
self.model.module.state_dict()
if self.n_prcs > 1
else self.model.state_dict()
),
"hypervolume_state_dict": (
self.model.module.hypervolume.state_dict()
if self.n_prcs > 1
else self.model.hypervolume.state_dict()
),
"hypervolume_params": self.model.hypervolume.get_building_params(),
"optimizers_state_dict": optimizers_state_dict,
}
if hasattr(self.output_mask, "current_radius"):
saved_objects["output_mask_radius"] = self.output_mask.current_radius
torch.save(saved_objects, out_weights)
def epoch_type(self) -> str:
"""Returns a label for the type of epoch currently being run."""
if self.pretraining:
return "Pretrain"
elif self.is_in_pose_search_step:
return "HPS"
else:
return "SGD"
def main(args: argparse.Namespace) -> None:
# Build configs dict from args similar to TrainingConfigurations
cfg = dict(
particles=args.particles,
ctf=args.ctf,
datadir=args.datadir,
ind=args.ind,
relion31=args.relion31,
invert_data=args.invert_data,
load=args.load,
load_poses=args.load_poses,
lazy=args.lazy,
max_threads=args.max_threads,
log_interval=args.log_interval,
log_heavy_interval=args.checkpoint,
verbose_time=args.verbose_time,
shuffle=args.shuffle,
num_workers=args.num_workers,
shuffler_size=args.shuffler_size,
multigpu=args.multigpu,
amp=args.amp,
batch_size_known_poses=args.batch_size_known_poses,
batch_size_hps=args.batch_size_hps,
batch_size_sgd=args.batch_size_sgd,
hypervolume_optimizer_type=args.hypervolume_optimizer_type,
pose_table_optimizer_type=args.pose_table_optimizer_type,
conf_table_optimizer_type=args.conf_table_optimizer_type,
conf_encoder_optimizer_type=args.conf_encoder_optimizer_type,
lr=args.lr,
lr_pose_table=args.lr_pose_table,
lr_conf_table=args.lr_conf_table,
lr_conf_encoder=args.lr_conf_encoder,
wd=args.wd,
n_imgs_pose_search=args.n_imgs_pose_search,
num_epochs=args.num_epochs,
epochs_sgd=args.epochs_sgd,
epochs_pose_search=args.epochs_pose_search,
pose_only_phase=args.pose_only_phase,
output_mask=args.output_mask,
add_one_frequency_every=args.add_one_frequency_every,
n_frequencies_per_epoch=args.n_frequencies_per_epoch,
max_freq=args.max_freq,
window_radius_gt_real=args.window_radius_gt_real,
beta_conf=args.beta_conf,
trans_l1_regularizer=args.trans_l1_regularizer,
l2_smoothness_regularizer=args.l2_smoothness_regularizer,
variational_het=args.variational_het,
zdim=args.zdim,
std_z_init=args.std_z_init,
use_conf_encoder=args.use_conf_encoder,
depth_cnn=args.depth_cnn,
channels_cnn=args.channels_cnn,
kernel_size_cnn=args.kernel_size_cnn,
resolution_encoder=args.resolution_encoder,
explicit_volume=args.explicit_volume,
hypervolume_layers=args.layers,
hypervolume_dim=args.dim,
pe_type=args.pe_type,
pe_dim=args.pe_dim,
feat_sigma=args.feat_sigma,
hypervolume_domain=args.hypervolume_domain,
pe_type_conf=args.pe_type_conf,
n_imgs_pretrain=args.n_imgs_pretrain,
l_start=args.l_start,
l_end=args.l_end,
n_iter=args.niter,
t_extent=args.t_extent,
t_n_grid=args.t_ngrid,
t_x_shift=args.t_xshift,
t_y_shift=args.t_yshift,
no_trans_search_at_pose_search=args.no_trans_search_at_pose_search,
n_kept_poses=args.nkeptposes,
base_healpy=args.base_healpy,
no_trans=args.no_trans,
seed=args.seed,
norm=args.norm,
initial_conf=args.initial_conf,
)
if cfg["load"] is not None:
if cfg["load"].strip().lower() == "latest":
weights_pkl, pose_pkl = utils.get_latest_checkpoint(args.outdir)
cfg["load"] = weights_pkl
elif not os.path.exists(cfg["load"]):
raise ValueError(
f"Invalid load argument which must be a path to "
f"a .pkl file or `latest`: {cfg['load']}"
)
if cfg["load_poses"] is not None:
if not os.path.exists(cfg["load_poses"]):
raise ValueError(
f"Invalid load_poses argument which must be a path to "
f"a .pkl file or `latest`: {cfg['load_poses']}"
)
trainer = ModelTrainer(args.outdir, cfg)
trainer.train()
if args.do_analysis:
anlz_cfgs = {
"workdir": args.outdir,
gitextract_tk31jno4/
├── .flake8
├── .github/
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ └── workflows/
│ ├── beta_release.yml
│ ├── release.yml
│ ├── style.yml
│ └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── analysis_scripts/
│ ├── kmeans.py
│ ├── plot_loss.py
│ ├── plot_z1.py
│ ├── plot_z2.py
│ ├── plot_z_pca.py
│ ├── run_umap.py
│ └── tsne.py
├── cryodrgn/
│ ├── __init__.py
│ ├── analysis.py
│ ├── analysis_drgnai.py
│ ├── beta_schedule.py
│ ├── command_line.py
│ ├── commands/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── abinit.py
│ │ ├── abinit_het_old.py
│ │ ├── abinit_homo_old.py
│ │ ├── analyze.py
│ │ ├── analyze_landscape.py
│ │ ├── analyze_landscape_full.py
│ │ ├── backproject_voxel.py
│ │ ├── dashboard.py
│ │ ├── direct_traversal.py
│ │ ├── downsample.py
│ │ ├── eval_images.py
│ │ ├── eval_vol.py
│ │ ├── filter.py
│ │ ├── graph_traversal.py
│ │ ├── parse_ctf_csparc.py
│ │ ├── parse_ctf_star.py
│ │ ├── parse_pose_csparc.py
│ │ ├── parse_pose_star.py
│ │ ├── parse_star.py
│ │ ├── pc_traversal.py
│ │ ├── train_dec.py
│ │ ├── train_nn.py
│ │ └── train_vae.py
│ ├── commands_utils/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── add_psize.py
│ │ ├── analyze_convergence.py
│ │ ├── clean.py
│ │ ├── concat_pkls.py
│ │ ├── filter_cs.py
│ │ ├── filter_mrcs.py
│ │ ├── filter_pkl.py
│ │ ├── filter_star.py
│ │ ├── flip_hand.py
│ │ ├── fsc.py
│ │ ├── gen_mask.py
│ │ ├── invert_contrast.py
│ │ ├── make_movies.py
│ │ ├── parse_relion.py
│ │ ├── phase_flip.py
│ │ ├── plot_classes.py
│ │ ├── plot_fsc.py
│ │ ├── select_clusters.py
│ │ ├── select_random.py
│ │ ├── translate_mrcs.py
│ │ ├── view_cs_header.py
│ │ ├── view_header.py
│ │ ├── view_mrcs.py
│ │ ├── write_cs.py
│ │ └── write_star.py
│ ├── config.py
│ ├── ctf.py
│ ├── dashboard/
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── bench_plot_interfaces.py
│ │ ├── command_builder_cli_help.py
│ │ ├── command_builder_data.py
│ │ ├── context.py
│ │ ├── data.py
│ │ ├── explorer_volumes.py
│ │ ├── mpl_style.py
│ │ ├── plots.py
│ │ ├── preload.py
│ │ ├── templates/
│ │ │ ├── base.html
│ │ │ ├── command_builder.html
│ │ │ ├── index.html
│ │ │ ├── latent_3d.html
│ │ │ ├── no_images.html
│ │ │ ├── pair_grid.html
│ │ │ ├── pair_grid_need_more_cols.html
│ │ │ ├── scatter_explorer.html
│ │ │ └── trajectory_creator.html
│ │ └── trajectory.py
│ ├── dataset.py
│ ├── fft.py
│ ├── healpy_grid.json
│ ├── lattice.py
│ ├── lie_tools.py
│ ├── losses.py
│ ├── make_healpy.py
│ ├── masking.py
│ ├── metrics.py
│ ├── models.py
│ ├── models_ai.py
│ ├── mrcfile.py
│ ├── pose.py
│ ├── pose_search.py
│ ├── pose_search_ai.py
│ ├── shift_grid.py
│ ├── shift_grid3.py
│ ├── so3_grid.py
│ ├── source.py
│ ├── starfile.py
│ ├── templates/
│ │ ├── cryoDRGN_ET_viz_template.ipynb
│ │ ├── cryoDRGN_analyze_landscape_template.ipynb
│ │ ├── cryoDRGN_figures_template.ipynb
│ │ ├── cryoDRGN_filtering_template.ipynb
│ │ └── cryoDRGN_viz_template.ipynb
│ └── utils.py
├── pyproject.toml
├── sweep.sh
├── testing/
│ ├── diff_cryodrgn_pkl.py
│ ├── test_abinit.sh
│ ├── test_entropy.py
│ ├── test_pose_search_rag12_128.py
│ ├── test_pose_search_real_128.py
│ ├── test_pose_search_syn_64.py
│ ├── test_sta.sh
│ └── test_translate.py
└── tests/
├── conftest.py
├── data/
│ ├── 50S-vol.mrc
│ ├── FinalRefinement-OriginalParticles-PfCRT.star
│ ├── ay19102021_L3_position6_ribo_it09_bin8_1.82A.mrcs
│ ├── cryosparc_J2_particles_exported.cs
│ ├── cryosparc_P12_J24_001_particles.cs
│ ├── ctf1.pkl
│ ├── ctf2.pkl
│ ├── empiar_10076_7.cs
│ ├── empiar_10076_7.mrc
│ ├── empiar_10076_7.star
│ ├── hand-vol.mrc
│ ├── hand.5.mrcs
│ ├── hand.mrcs
│ ├── hand_11_particles.npy
│ ├── hand_rot.pkl
│ ├── hand_rot_trans.pkl
│ ├── hand_tilt.mrcs
│ ├── het_config.yaml
│ ├── het_weights.pkl
│ ├── im_shifted.npy
│ ├── ind100-rand.pkl
│ ├── ind100.pkl
│ ├── ind4.pkl
│ ├── ind5.pkl
│ ├── ind_39_sta_testing_bin8.pkl
│ ├── pose.cs.pkl
│ ├── pose.star.pkl
│ ├── relion31.6opticsgroups.star
│ ├── relion31.mrcs
│ ├── relion31.star
│ ├── relion31.v2.star
│ ├── relion5.star
│ ├── spike-vol.mrc
│ ├── sta_ctf.pkl
│ ├── sta_pose.pkl
│ ├── sta_testing.star
│ ├── sta_testing_bin8.star
│ ├── test_ctf.100.pkl
│ ├── test_ctf.pkl
│ ├── toy.star
│ ├── toy_angles.pkl
│ ├── toy_datadir/
│ │ ├── toy_images_a.mrcs
│ │ └── toy_images_b.mrcs
│ ├── toy_projections.mrc
│ ├── toy_projections.mrcs
│ ├── toy_projections.star
│ ├── toy_projections.txt
│ ├── toy_projections_0-999.mrcs
│ ├── toy_projections_13.star
│ ├── toy_projections_2.txt
│ ├── toy_projections_dir.star
│ ├── toy_rot_trans.pkl
│ ├── toy_rot_zerotrans.pkl
│ ├── toy_trans.pkl
│ ├── toy_trans.zero.pkl
│ ├── toymodel_small_nocenter.mrc
│ ├── zvals_het-2_1k.pkl
│ └── zvals_het-8_4k.pkl
├── quicktest.sh
├── test_add_psize.py
├── test_backprojection.py
├── test_clean.py
├── test_dashboard_core.py
├── test_dashboard_extended.py
├── test_dataset.py
├── test_direct_traversal.py
├── test_downsample.py
├── test_entropy.py
├── test_eval_images.py
├── test_fft.py
├── test_filter_mrcs.py
├── test_filter_pkl.py
├── test_flip_hand.py
├── test_fsc.py
├── test_graph_traversal.py
├── test_integration.py
├── test_invert_contrast.py
├── test_masks.py
├── test_mrc.py
├── test_parse.py
├── test_pc_traversal.py
├── test_phase_flip.py
├── test_read_filter_write.py
├── test_reconstruct_abinit.py
├── test_reconstruct_abinit_old.py
├── test_reconstruct_fixed.py
├── test_reconstruct_tilt.py
├── test_relion.py
├── test_select_clusters.py
├── test_select_random.py
├── test_source.py
├── test_translate.py
├── test_utils.py
├── test_view_cs_header.py
├── test_view_header.py
├── test_view_mrcs.py
├── test_writestar.py
└── unittest.sh
SYMBOL INDEX (1397 symbols across 133 files)
FILE: analysis_scripts/kmeans.py
function parse_args (line 14) | def parse_args():
function main (line 34) | def main(args):
FILE: analysis_scripts/plot_loss.py
function parse_args (line 12) | def parse_args():
function main (line 19) | def main(args):
FILE: analysis_scripts/plot_z1.py
function parse_args (line 12) | def parse_args():
function main (line 42) | def main(args):
FILE: analysis_scripts/plot_z2.py
function parse_args (line 13) | def parse_args():
function main (line 56) | def main(args):
FILE: analysis_scripts/plot_z_pca.py
function parse_args (line 13) | def parse_args():
function main (line 59) | def main(args):
FILE: analysis_scripts/run_umap.py
function parse_args (line 15) | def parse_args():
function main (line 24) | def main(args):
FILE: analysis_scripts/tsne.py
function parse_args (line 11) | def parse_args():
function main (line 22) | def main(args):
FILE: cryodrgn/analysis.py
function parse_loss (line 29) | def parse_loss(f: str) -> np.ndarray:
function run_pca (line 47) | def run_pca(z: np.ndarray) -> Tuple[np.ndarray, PCA]:
function get_pc_traj (line 56) | def get_pc_traj(
function run_tsne (line 93) | def run_tsne(
function run_umap (line 104) | def run_umap(z: np.ndarray, **kwargs) -> np.ndarray:
function cluster_kmeans (line 115) | def cluster_kmeans(
function cluster_gmm (line 147) | def cluster_gmm(
function get_nearest_point (line 179) | def get_nearest_point(
function convert_original_indices (line 193) | def convert_original_indices(
function combine_ind (line 202) | def combine_ind(
function get_ind_for_cluster (line 219) | def get_ind_for_cluster(
function _get_chimerax_colors (line 243) | def _get_chimerax_colors(K: int) -> List:
function _get_colors (line 260) | def _get_colors(K: int, cmap: Optional[str] = None) -> List:
function scatter_annotate (line 270) | def scatter_annotate(
function scatter_annotate_hex (line 304) | def scatter_annotate_hex(
function scatter_color (line 342) | def scatter_color(
function plot_by_cluster (line 368) | def plot_by_cluster(
function plot_by_cluster_subplot (line 418) | def plot_by_cluster_subplot(
function plot_euler (line 438) | def plot_euler(theta, phi, psi, plot_psi=True):
function ipy_plot_interactive_annotate (line 448) | def ipy_plot_interactive_annotate(df, ind, opacity=0.3):
function ipy_plot_interactive (line 513) | def ipy_plot_interactive(df, opacity=0.3):
function plot_projections (line 580) | def plot_projections(imgs, labels=None, max_imgs=25):
function gen_volumes (line 607) | def gen_volumes(
function load_dataframe (line 660) | def load_dataframe(
FILE: cryodrgn/analysis_drgnai.py
class VolumeGenerator (line 23) | class VolumeGenerator:
method __init__ (line 26) | def __init__(
method gen_volumes (line 46) | def gen_volumes(self, outdir, z_values, suffix=None):
class ModelAnalyzer (line 77) | class ModelAnalyzer:
method get_last_cached_epoch (line 96) | def get_last_cached_epoch(cls, traindir: str) -> int:
method __init__ (line 111) | def __init__(
method linear_interpolation (line 189) | def linear_interpolation(z_0, z_1, n, exclude_last=False):
method analyze (line 195) | def analyze(self):
method analyze_z1 (line 235) | def analyze_z1(self) -> None:
method analyze_zN (line 265) | def analyze_zN(self) -> None:
FILE: cryodrgn/beta_schedule.py
function get_beta_schedule (line 4) | def get_beta_schedule(schedule):
class ConstantSchedule (line 19) | class ConstantSchedule:
method __init__ (line 20) | def __init__(self, value):
method __call__ (line 23) | def __call__(self, x):
class LinearSchedule (line 27) | class LinearSchedule:
method __init__ (line 28) | def __init__(self, start_y, end_y, start_x, end_x):
method __call__ (line 35) | def __call__(self, x):
FILE: cryodrgn/command_line.py
function _get_commands (line 21) | def _get_commands(cmd_dir: str, cmds: list[str], doc_str: str = "") -> N...
function main_commands (line 79) | def main_commands() -> None:
function util_commands (line 112) | def util_commands() -> None:
FILE: cryodrgn/commands/abinit.py
function add_args (line 40) | def add_args(parser: argparse.ArgumentParser) -> None:
class ModelTrainer (line 531) | class ModelTrainer:
method make_dataloader (line 571) | def make_dataloader(self, batch_size: int) -> DataLoader:
method __init__ (line 580) | def __init__(self, outdir: str, config_vals: dict[str, Any]) -> None:
method train (line 1080) | def train(self):
method get_ctfs_at (line 1234) | def get_ctfs_at(self, index):
method train_step (line 1253) | def train_step(self, in_dict, end_time):
method detach_latent_variables (line 1410) | def detach_latent_variables(self, latent_variables_dict):
method forward_pass (line 1432) | def forward_pass(self, in_dict):
method loss (line 1516) | def loss(self, y_pred, y_gt, latent_variables_dict):
method make_light_summary (line 1553) | def make_light_summary(self, all_losses: dict[str, float]) -> None:
method save_latents (line 1573) | def save_latents(self) -> None:
method save_volume (line 1589) | def save_volume(self) -> None:
method save_model (line 1617) | def save_model(self) -> None:
method epoch_type (line 1646) | def epoch_type(self) -> str:
function main (line 1656) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/abinit_het_old.py
function add_args (line 41) | def add_args(parser: argparse.ArgumentParser) -> None:
function make_model (line 409) | def make_model(args, lattice, enc_mask, in_dim) -> HetOnlyVAE:
function pretrain (line 428) | def pretrain(model, lattice, optim, minibatch, tilt, zdim):
function train (line 461) | def train(
function eval_z (line 597) | def eval_z(
function save_checkpoint (line 650) | def save_checkpoint(
function save_config (line 693) | def save_config(args, dataset, lattice, model, out_config):
function sort_poses (line 731) | def sort_poses(poses):
function main (line 745) | def main(args):
FILE: cryodrgn/commands/abinit_homo_old.py
function add_args (line 30) | def add_args(parser: argparse.ArgumentParser) -> None:
function save_checkpoint (line 321) | def save_checkpoint(
function pretrain (line 342) | def pretrain(model, lattice, optim, batch, tilt=None):
function sort_poses (line 369) | def sort_poses(pose):
function sort_base_poses (line 384) | def sort_base_poses(pose):
function train (line 391) | def train(
function make_model (line 486) | def make_model(args, D: int):
function save_config (line 501) | def save_config(args, dataset, lattice, out_config):
function main (line 531) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/analyze.py
function add_args (line 37) | def add_args(parser: argparse.ArgumentParser) -> None:
function analyze_z1 (line 111) | def analyze_z1(z, outdir, vg, n_per_pc=10):
function analyze_zN (line 132) | def analyze_zN(
class VolumeGenerator (line 427) | class VolumeGenerator:
method __init__ (line 430) | def __init__(self, weights, config, vol_args={}, skip_vol=False):
method gen_volumes (line 436) | def gen_volumes(self, outdir, z_values):
function main (line 446) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/analyze_landscape.py
function add_args (line 44) | def add_args(parser: argparse.ArgumentParser) -> None:
function generate_volumes (line 163) | def generate_volumes(z, outdir, vg_list, K, vol_start_index):
function make_mask (line 203) | def make_mask(
function view_slices (line 255) | def view_slices(y: np.array, out_png: str, D: Optional[int] = None) -> N...
function choose_cmap (line 267) | def choose_cmap(M):
function get_colors_for_cmap (line 277) | def get_colors_for_cmap(cmap, M):
function analyze_volumes (line 285) | def analyze_volumes(
function make_volume_generator (line 543) | def make_volume_generator(
function main (line 578) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/analyze_landscape_full.py
function add_args (line 42) | def add_args(parser: argparse.ArgumentParser) -> None:
function train (line 152) | def train(model, device, train_loader, optimizer, epoch):
function test (line 173) | def test(model, device, test_loader):
class MyDataset (line 189) | class MyDataset(Dataset):
method __init__ (line 190) | def __init__(self, x, y):
method __len__ (line 195) | def __len__(self):
method __getitem__ (line 198) | def __getitem__(self, idx):
function generate_and_map_volumes (line 202) | def generate_and_map_volumes(zfile, cfg, weights, mask_mrc, pca_obj_pkl,...
function train_model (line 348) | def train_model(x, y, outdir, zfile, args):
function choose_cmap (line 402) | def choose_cmap(M):
function get_colors_for_cmap (line 412) | def get_colors_for_cmap(cmap, M):
function main (line 420) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/backproject_voxel.py
function add_args (line 47) | def add_args(parser: argparse.ArgumentParser) -> None:
function add_slice (line 178) | def add_slice(volume, counts, ff_coord, ff, D, ctf_mul):
function regularize_volume (line 209) | def regularize_volume(
function main (line 219) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/dashboard.py
function _configure_dashboard_logging (line 25) | def _configure_dashboard_logging(verbosity: int) -> None:
function add_args (line 63) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 176) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/direct_traversal.py
function add_args (line 16) | def add_args(parser: argparse.ArgumentParser) -> None:
function parse_anchors (line 44) | def parse_anchors(
function main (line 83) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/downsample.py
function add_args (line 57) | def add_args(parser: argparse.ArgumentParser) -> None:
function downsample_mrc_images (line 113) | def downsample_mrc_images(
function main (line 220) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/eval_images.py
function add_args (line 28) | def add_args(parser):
function eval_batch (line 159) | def eval_batch(
function main (line 179) | def main(args):
FILE: cryodrgn/commands/eval_vol.py
function add_args (line 30) | def add_args(parser: argparse.ArgumentParser) -> None:
function check_inputs (line 102) | def check_inputs(args: argparse.Namespace) -> None:
function postprocess_vol (line 110) | def postprocess_vol(vol, args):
function reset_origin (line 122) | def reset_origin(oldD, cropD, Apix):
function main (line 132) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/filter.py
function add_args (line 64) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 103) | def main(args: argparse.Namespace) -> None:
class SelectFromScatter (line 314) | class SelectFromScatter:
method __init__ (line 317) | def __init__(
method gridspec (line 396) | def gridspec(self) -> GridSpec:
method plot (line 402) | def plot(self) -> None:
method update_xaxis (line 464) | def update_xaxis(self, xlbl: str) -> None:
method update_yaxis (line 469) | def update_yaxis(self, ylbl: str) -> None:
method choose_colors (line 474) | def choose_colors(self, chosen_colors: str) -> None:
method choose_points (line 483) | def choose_points(self, verts: np.array) -> None:
method hover_points (line 493) | def hover_points(self, event: MouseEvent) -> None:
method on_click (line 530) | def on_click(self, event: MouseEvent) -> None:
method on_release (line 535) | def on_release(self, event: MouseEvent) -> None:
method save_click (line 542) | def save_click(self, event: MouseEvent) -> None:
method exit_click (line 548) | def exit_click(self, event: MouseEvent) -> None:
FILE: cryodrgn/commands/graph_traversal.py
function add_args (line 32) | def add_args(parser: argparse.ArgumentParser) -> None:
class GraphLatentTraversor (line 90) | class GraphLatentTraversor:
method __init__ (line 93) | def __init__(self, edges: List[Tuple[int, int, float]]) -> None:
method find_path (line 113) | def find_path(self, src: int, dest: int) -> Tuple[List[int], float]:
function main (line 153) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/parse_ctf_csparc.py
function add_args (line 13) | def add_args(parser):
function main (line 28) | def main(args):
FILE: cryodrgn/commands/parse_ctf_star.py
function add_args (line 32) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 51) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/parse_pose_csparc.py
function add_args (line 14) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 35) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/parse_pose_star.py
function add_args (line 25) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 47) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/parse_star.py
function add_args (line 21) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 45) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/pc_traversal.py
function add_args (line 18) | def add_args(parser: argparse.ArgumentParser) -> None:
function analyze_data_support (line 52) | def analyze_data_support(z, traj, cutoff=3):
function main (line 58) | def main(args):
FILE: cryodrgn/commands/train_dec.py
function add_args (line 23) | def add_args(parser: argparse.ArgumentParser) -> None:
function save_checkpoint (line 265) | def save_checkpoint(
function save_z (line 285) | def save_z(z, out_z):
function cat_z (line 290) | def cat_z(coords, z, zdim):
function train (line 302) | def train(
function save_config (line 355) | def save_config(args, dataset, lattice, model, out_config):
function get_latest (line 386) | def get_latest(args):
function main (line 408) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/train_nn.py
function add_args (line 38) | def add_args(parser: argparse.ArgumentParser) -> None:
function save_checkpoint (line 251) | def save_checkpoint(
function train (line 268) | def train(
function save_config (line 318) | def save_config(args, dataset, lattice, model, out_config):
function main (line 348) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands/train_vae.py
function add_args (line 45) | def add_args(parser: argparse.ArgumentParser) -> None:
function train_batch (line 366) | def train_batch(
function preprocess_input (line 420) | def preprocess_input(y, lattice, trans):
function run_batch (line 428) | def run_batch(model, lattice, y, rot, ntilts: Optional[int], ctf_params=...
function loss_function (line 462) | def loss_function(
function eval_z (line 498) | def eval_z(
function save_checkpoint (line 562) | def save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights, ou...
function save_config (line 579) | def save_config(args, dataset, lattice, model, out_config):
function main (line 624) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/add_psize.py
function add_args (line 16) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 24) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/analyze_convergence.py
function add_args (line 37) | def add_args(parser: argparse.ArgumentParser) -> None:
function find_configs (line 188) | def find_configs(workdir: str) -> str:
function plot_loss (line 199) | def plot_loss(logfile, outdir, E):
function encoder_latent_umaps (line 225) | def encoder_latent_umaps(
function encoder_latent_shifts (line 342) | def encoder_latent_shifts(workdir: str, outdir: str, E: int):
function sketch_via_umap_local_maxima (line 400) | def sketch_via_umap_local_maxima(
function follow_candidate_particles (line 624) | def follow_candidate_particles(
function generate_volumes (line 745) | def generate_volumes(workdir, outdir, epochs, Apix, flip, invert, downsa...
function mask_volume (line 769) | def mask_volume(volpath, outpath, Apix, thresh=None, dilate=3, dist=10):
function mask_volumes (line 805) | def mask_volumes(
function calculate_CCs (line 849) | def calculate_CCs(outdir, epochs, labels, chimerax_colors):
function calculate_FSCs (line 908) | def calculate_FSCs(outdir, epochs, labels, img_size, chimerax_colors):
function main (line 1032) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/clean.py
function add_args (line 57) | def add_args(parser):
function clean_dir (line 94) | def clean_dir(d: Path, args: argparse.Namespace) -> None:
function _prompt_dir (line 125) | def _prompt_dir(
function check_open_config (line 151) | def check_open_config(d: Path) -> dict:
function main (line 186) | def main(args):
FILE: cryodrgn/commands_utils/concat_pkls.py
function add_args (line 11) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 16) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/filter_cs.py
function add_args (line 18) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 30) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/filter_mrcs.py
function add_args (line 20) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 28) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/filter_pkl.py
function add_args (line 20) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 37) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/filter_star.py
function add_args (line 21) | def add_args(parser: argparse.ArgumentParser):
function main (line 41) | def main(args: argparse.Namespace):
FILE: cryodrgn/commands_utils/flip_hand.py
function add_args (line 20) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 27) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/fsc.py
function add_args (line 54) | def add_args(parser: argparse.ArgumentParser) -> None:
function get_fftn_center_dists (line 104) | def get_fftn_center_dists(box_size: int) -> np.array:
function calculate_fsc (line 116) | def calculate_fsc(
function get_fsc_curve (line 131) | def get_fsc_curve(
function get_fsc_thresholds (line 165) | def get_fsc_thresholds(
function randomize_phase (line 191) | def randomize_phase(cval: complex) -> complex:
function correct_fsc (line 199) | def correct_fsc(
function calculate_cryosparc_fscs (line 260) | def calculate_cryosparc_fscs(
function main (line 361) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/gen_mask.py
function add_args (line 22) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 61) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/invert_contrast.py
function add_args (line 20) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 27) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/make_movies.py
function generate_movie_prologue (line 25) | def generate_movie_prologue(width: int, height: int) -> list[str]:
function generate_movie_epilogue (line 33) | def generate_movie_epilogue(
function add_args (line 59) | def add_args(parser: argparse.ArgumentParser) -> None:
function check_chimerax_installation (line 119) | def check_chimerax_installation() -> bool:
function find_subdirs (line 142) | def find_subdirs(directory: str, keyword: str) -> list[Path]:
function get_vols (line 154) | def get_vols(directory: Path, postfix_regex: str = "") -> list[str]:
function record_movie (line 161) | def record_movie(
function latent_movies (line 200) | def latent_movies(
function landscape_movies (line 222) | def landscape_movies(
function main (line 253) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/parse_relion.py
function add_args (line 16) | def add_args(parser: argparse.ArgumentParser) -> None:
class Tomogram (line 38) | class Tomogram:
method __init__ (line 45) | def __init__(
method _translation_matrix (line 73) | def _translation_matrix(self, shift_3d):
method _rotation_matrix (line 79) | def _rotation_matrix(self, axis: list[float], angle_deg: float) -> np....
method _build_projection_matrices (line 88) | def _build_projection_matrices(self):
method project_point (line 131) | def project_point(self, point_3d, i_tilt):
method calculate_local_defocus_uv (line 145) | def calculate_local_defocus_uv(self, i_tilt, point_3d):
method expand_particle_to_2drows (line 170) | def expand_particle_to_2drows(
function main (line 247) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/phase_flip.py
function add_args (line 15) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 26) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/plot_classes.py
function add_args (line 53) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 100) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/plot_fsc.py
function add_args (line 26) | def add_args(parser: argparse.ArgumentParser) -> None:
function plot_fsc_vals (line 50) | def plot_fsc_vals(fsc_arr: pd.DataFrame, label: str, **plot_args) -> None:
function create_fsc_plot (line 66) | def create_fsc_plot(
function main (line 139) | def main(args):
FILE: cryodrgn/commands_utils/select_clusters.py
function add_args (line 11) | def add_args(parser):
function main (line 28) | def main(args):
FILE: cryodrgn/commands_utils/select_random.py
function add_args (line 17) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 30) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/translate_mrcs.py
function add_args (line 19) | def add_args(parser: argparse.ArgumentParser) -> None:
function plot_projections (line 38) | def plot_projections(out_png: str, imgs: np.ndarray) -> None:
function main (line 47) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/view_cs_header.py
function add_args (line 8) | def add_args(parser):
function main (line 13) | def main(args):
FILE: cryodrgn/commands_utils/view_header.py
function add_args (line 11) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 15) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/view_mrcs.py
function add_args (line 14) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 32) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/write_cs.py
function add_args (line 18) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 36) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/commands_utils/write_star.py
function add_args (line 54) | def add_args(parser: argparse.ArgumentParser) -> None:
function main (line 83) | def main(args: argparse.Namespace) -> None:
FILE: cryodrgn/config.py
function load (line 10) | def load(config: Union[str, dict]) -> dict:
function save (line 26) | def save(
function update_config_v1 (line 47) | def update_config_v1(config: Union[str, dict]) -> dict:
FILE: cryodrgn/ctf.py
function compute_ctf (line 59) | def compute_ctf(
function print_ctf_params (line 113) | def print_ctf_params(params: np.ndarray) -> None:
function plot_ctf (line 126) | def plot_ctf(D: int, Apix: float, ctf_params: np.ndarray) -> None:
function load_ctf_for_training (line 144) | def load_ctf_for_training(D: int, ctf_params_pkl: str) -> np.ndarray:
FILE: cryodrgn/dashboard/app.py
function _request_json_dict (line 116) | def _request_json_dict() -> dict:
function _filter_ui_scatter_max_points (line 122) | def _filter_ui_scatter_max_points() -> int:
function _default_xy_cols (line 131) | def _default_xy_cols(cols: list[str]) -> tuple[str, str]:
function _covariate_display_name (line 138) | def _covariate_display_name(name: str) -> str:
function _parse_preselect_rows_param (line 145) | def _parse_preselect_rows_param(raw: str | None) -> tuple[list[int] | No...
function _redirect (line 156) | def _redirect(endpoint: str):
function _trajectory_eligibility_error (line 160) | def _trajectory_eligibility_error(e: DashboardExperiment):
function _parse_pairplot_request (line 167) | def _parse_pairplot_request(
function _add_direct_anchor_pidx (line 198) | def _add_direct_anchor_pidx(payload: dict, p: dict, z_traj: np.ndarray) ...
function index (line 216) | def index():
function command_builder_page (line 235) | def command_builder_page():
function abinit_builder_redirect (line 245) | def abinit_builder_redirect():
function filter_page_redirect (line 249) | def filter_page_redirect():
function api_save_selection (line 253) | def api_save_selection():
function explorer (line 305) | def explorer():
function api_explorer_volume_media (line 338) | def api_explorer_volume_media():
function api_scatter (line 407) | def api_scatter():
function latent_3d_page (line 454) | def latent_3d_page():
function api_scatter3d_z (line 474) | def api_scatter3d_z():
function api_latent3d_preview_png (line 499) | def api_latent3d_preview_png():
function api_preview_montage (line 532) | def api_preview_montage():
function api_preload_images (line 545) | def api_preload_images():
function pairplot_page (line 645) | def pairplot_page():
function api_pairplot (line 682) | def api_pairplot():
function api_save_pairplot_png (line 709) | def api_save_pairplot_png():
function api_default_trajectory_endpoints (line 749) | def api_default_trajectory_endpoints():
function trajectory_creator_page (line 767) | def trajectory_creator_page():
function api_trajectory_save_zpath (line 797) | def api_trajectory_save_zpath():
function api_trajectory_save_volumes (line 832) | def api_trajectory_save_volumes():
function _trajectory_anchor_driven_json (line 858) | def _trajectory_anchor_driven_json(
function api_trajectory_import_anchors (line 878) | def api_trajectory_import_anchors():
function api_list_server_files (line 905) | def api_list_server_files():
function api_trajectory_kmeans_centers (line 931) | def api_trajectory_kmeans_centers():
function api_trajectory_random_indices (line 948) | def api_trajectory_random_indices():
function api_trajectory_coords (line 965) | def api_trajectory_coords():
function api_trajectory_volumes (line 996) | def api_trajectory_volumes():
function create_app (line 1074) | def create_app(
function run_server (line 1132) | def run_server(
FILE: cryodrgn/dashboard/bench_plot_interfaces.py
function main (line 20) | def main() -> int:
FILE: cryodrgn/dashboard/command_builder_cli_help.py
function _string_from_ast (line 14) | def _string_from_ast(node: ast.expr | None) -> str | None:
function _flags_and_positional_from_add_argument (line 36) | def _flags_and_positional_from_add_argument(
function _help_from_add_argument (line 57) | def _help_from_add_argument(call: ast.Call) -> str | None:
function help_map_from_command_py (line 64) | def help_map_from_command_py(path: Path) -> dict[str, str]:
function attach_help_to_groups (line 87) | def attach_help_to_groups(
function load_cli_help_maps (line 107) | def load_cli_help_maps() -> dict[str, dict[str, str]]:
FILE: cryodrgn/dashboard/command_builder_data.py
function _g (line 32) | def _g(title: str, args: list[Arg]) -> Group:
function _required_field_titles (line 925) | def _required_field_titles(
function _build_required_field_titles (line 940) | def _build_required_field_titles() -> dict[str, str]:
FILE: cryodrgn/dashboard/context.py
function clear_experiment_caches (line 37) | def clear_experiment_caches() -> None:
function _config_has_cryodrgn_cmd (line 83) | def _config_has_cryodrgn_cmd(config: object) -> bool:
function discover_cryodrgn_workdirs (line 94) | def discover_cryodrgn_workdirs(cwd: str) -> list[str]:
function _workdir_options (line 116) | def _workdir_options(abs_paths: list[str], base_cwd: str) -> list[dict[s...
function _sync_discovery_session_boot (line 132) | def _sync_discovery_session_boot() -> None:
function active_workdir (line 149) | def active_workdir(app: Flask) -> str | None:
function epochs_for_workdir (line 160) | def epochs_for_workdir(workdir: str) -> list[int]:
function resolve_epoch (line 169) | def resolve_epoch(app: Flask) -> int:
function get_dashboard_exp (line 188) | def get_dashboard_exp(app: Flask) -> DashboardExperiment:
function bind_dashboard_exp (line 200) | def bind_dashboard_exp() -> None:
function _request_json_dict (line 216) | def _request_json_dict() -> dict:
function api_set_epoch (line 221) | def api_set_epoch():
function api_set_workdir (line 240) | def api_set_workdir():
function abbrev_middle (line 278) | def abbrev_middle(text: object, maxlen: int = 30) -> str:
function _cmd_argv_for_nav_display (line 291) | def _cmd_argv_for_nav_display(cmd_parts: list[str]) -> list[str]:
function _abbrev_middle_token (line 310) | def _abbrev_middle_token(text: str, maxlen: int = 120) -> str:
function _argv_four_command_lines (line 322) | def _argv_four_command_lines(argv: list[str]) -> list[str]:
function command_builder_template_kwargs (line 400) | def command_builder_template_kwargs(
function inject_meta (line 449) | def inject_meta() -> dict:
function inject_meta_command_builder_only (line 489) | def inject_meta_command_builder_only() -> dict:
FILE: cryodrgn/dashboard/data.py
function list_z_epochs (line 19) | def list_z_epochs(workdir: str) -> list[int]:
class DashboardExperiment (line 35) | class DashboardExperiment:
method numeric_columns (line 56) | def numeric_columns(self) -> list[str]:
method can_preview_particles (line 65) | def can_preview_particles(self) -> bool:
function load_experiment (line 69) | def load_experiment(
function particle_image_array (line 204) | def particle_image_array(exp: DashboardExperiment, row_index: int) -> np...
FILE: cryodrgn/dashboard/explorer_volumes.py
function _vol_cache_evict_unlocked (line 30) | def _vol_cache_evict_unlocked(token: str) -> None:
function _vol_cache_prune_unlocked (line 36) | def _vol_cache_prune_unlocked() -> None:
function _register_vol_mrc_cache (line 50) | def _register_vol_mrc_cache(
function volume_cell_gif_from_cache (line 65) | def volume_cell_gif_from_cache(
function save_cached_volumes_to_dir (line 101) | def save_cached_volumes_to_dir(
function torch_cuda_available (line 137) | def torch_cuda_available() -> bool:
function chimerax_path (line 146) | def chimerax_path() -> str:
function run_chimerax_cmds (line 156) | def run_chimerax_cmds(
function explorer_volumes_eligible (line 176) | def explorer_volumes_eligible(exp: DashboardExperiment) -> bool:
function _config_yaml_path (line 186) | def _config_yaml_path(workdir: str) -> str:
function _is_drgnai_config (line 196) | def _is_drgnai_config(train_configs: dict) -> bool:
function _decode_z_values_classic (line 200) | def _decode_z_values_classic(
function _drgnai_volume_generator (line 225) | def _drgnai_volume_generator(exp: DashboardExperiment):
function _decode_z_values_drgnai (line 270) | def _decode_z_values_drgnai(
function _sorted_vol_mrc_paths (line 281) | def _sorted_vol_mrc_paths(mrc_dir: str, n_take: int) -> list[str]:
function _mpl_retrim_png (line 296) | def _mpl_retrim_png(out_png: str, dpi: int) -> None:
function _chimerax_render_cmds (line 310) | def _chimerax_render_cmds(
function mrc_to_static_png (line 328) | def mrc_to_static_png(mrc_path: str, out_png: str, dpi: int = 100) -> None:
function mrc_to_rotating_gif (line 335) | def mrc_to_rotating_gif(
function _decode_z_values_to_vol_paths (line 387) | def _decode_z_values_to_vol_paths(
function generate_trajectory_volume_pngs (line 397) | def generate_trajectory_volume_pngs(
function generate_montage_volume_pngs (line 430) | def generate_montage_volume_pngs(
FILE: cryodrgn/dashboard/mpl_style.py
function ezlab_matplotlib_rc (line 21) | def ezlab_matplotlib_rc():
FILE: cryodrgn/dashboard/plots.py
function _lower_color_series_is_discrete (line 35) | def _lower_color_series_is_discrete(s: pd.Series) -> bool:
function normalize_continuous_palette (line 76) | def normalize_continuous_palette(raw: str | None) -> str:
function mpl_cmap_for_palette (line 87) | def mpl_cmap_for_palette(plotly_palette: str) -> str:
function _subsample (line 118) | def _subsample(
function _axes_cell_bboxes (line 139) | def _axes_cell_bboxes(axes: np.ndarray) -> list[dict[str, float]]:
function _plotly_to_json (line 157) | def _plotly_to_json(fig: go.Figure) -> str:
function _labels_colors_and_legend_items (line 164) | def _labels_colors_and_legend_items(
function _continuous_series_stats (line 192) | def _continuous_series_stats(values: pd.Series) -> tuple[np.ndarray, flo...
function pair_grid_skeleton_placeholder_layout (line 213) | def pair_grid_skeleton_placeholder_layout(
function _pair_jointplot_hex_cmap (line 239) | def _pair_jointplot_hex_cmap(color: str | None = None):
function _pair_jointplot_hex_gridsize (line 248) | def _pair_jointplot_hex_gridsize(x: np.ndarray, y: np.ndarray) -> int:
function _refine_pair_grid_right_margin (line 255) | def _refine_pair_grid_right_margin(
function _pair_lower_triangle_pictogram_da (line 298) | def _pair_lower_triangle_pictogram_da(px: float = 28.0) -> DrawingArea:
function _lower_legend_covariate_title (line 330) | def _lower_legend_covariate_title(
function _lower_legend_entry_label (line 338) | def _lower_legend_entry_label(lower_color_col: str, u: Any) -> str:
function _draw_pair_grid_edge_labels (line 361) | def _draw_pair_grid_edge_labels(
function _draw_pair_grid_diagonal_legends (line 424) | def _draw_pair_grid_diagonal_legends(
function _attach_pair_lower_legend_caption (line 489) | def _attach_pair_lower_legend_caption(
function scatter_json (line 523) | def scatter_json(
function scatter3d_z_json (line 627) | def scatter3d_z_json(
function scatter3d_z_preview_png (line 718) | def scatter3d_z_preview_png(
function pair_grid_png (line 818) | def pair_grid_png(
FILE: cryodrgn/dashboard/preload.py
function encode_particle_batch (line 25) | def encode_particle_batch(
function montage_bytes (line 61) | def montage_bytes(exp: DashboardExperiment, row_indices: list[int]) -> b...
function particle_thumbnail_b64_from_row (line 100) | def particle_thumbnail_b64_from_row(
function _stratified_xy_row_indices (line 129) | def _stratified_xy_row_indices(
function sample_plot_df_rows_for_preload (line 180) | def sample_plot_df_rows_for_preload(
function _preload_cache_time_estimate_bounds (line 215) | def _preload_cache_time_estimate_bounds(cpus: int) -> tuple[int, int]:
function format_preload_cache_time_hint (line 228) | def format_preload_cache_time_hint(cpus: int) -> str:
function load_plot_df_rows_from_plot_inds_file (line 236) | def load_plot_df_rows_from_plot_inds_file(
FILE: cryodrgn/dashboard/trajectory.py
function has_umap_columns (line 31) | def has_umap_columns(exp: DashboardExperiment) -> bool:
function has_pc_columns (line 39) | def has_pc_columns(exp: DashboardExperiment) -> bool:
function trajectory_default_xy_cols (line 43) | def trajectory_default_xy_cols(cols: list[str], zdim: int) -> tuple[str,...
function trajectory_plot_axis_columns (line 56) | def trajectory_plot_axis_columns(e: DashboardExperiment) -> list[str]:
function _trajectory_xy_ok_for_direct (line 82) | def _trajectory_xy_ok_for_direct(xcol: str, ycol: str) -> bool:
function validate_trajectory_plot_axes (line 91) | def validate_trajectory_plot_axes(e: DashboardExperiment, xcol: str, yco...
function trajectory_axes_from_payload (line 99) | def trajectory_axes_from_payload(e: DashboardExperiment, data: dict) -> ...
function parse_int_from_dict (line 113) | def parse_int_from_dict(data: dict, key: str, *, default: int, lo: int, ...
function parse_traj_points_value (line 122) | def parse_traj_points_value(data: dict, default: int = 4) -> int:
function parse_traj_interpolation_value (line 127) | def parse_traj_interpolation_value(
function parse_traj_neighbor_value (line 134) | def parse_traj_neighbor_value(data: dict, key: str, default: int) -> int:
function trajectory_anchor_mode_params (line 139) | def trajectory_anchor_mode_params(data: dict) -> tuple[str, int, int, int]:
function z_traj_to_savetxt_str (line 155) | def z_traj_to_savetxt_str(z_traj: np.ndarray) -> str:
function _round_direct_mode_traj_xy (line 162) | def _round_direct_mode_traj_xy(traj_xy: np.ndarray) -> np.ndarray:
function parse_anchor_indices_txt (line 175) | def parse_anchor_indices_txt(raw: bytes) -> list[int]:
function _plot_row_particle_index (line 194) | def _plot_row_particle_index(exp: DashboardExperiment, row_index: int) -...
function _compute_trajectory_from_anchor_indices (line 206) | def _compute_trajectory_from_anchor_indices(
function _compute_direct_anchor_trajectory (line 229) | def _compute_direct_anchor_trajectory(
function _graph_neighbor_arrays (line 261) | def _graph_neighbor_arrays(
function _dijkstra_path_from_neighbors (line 305) | def _dijkstra_path_from_neighbors(
function _compute_graph_anchor_trajectory (line 352) | def _compute_graph_anchor_trajectory(
function parse_trajectory_request_body (line 390) | def parse_trajectory_request_body(e: DashboardExperiment, data: dict) ->...
function compute_trajectory_latent_path (line 477) | def compute_trajectory_latent_path(
function trajectory_shared_json_payload (line 541) | def trajectory_shared_json_payload(
function direct_anchor_particle_indices_payload (line 570) | def direct_anchor_particle_indices_payload(
function trajectory_anchor_payload_from_indices (line 596) | def trajectory_anchor_payload_from_indices(
function kmeans_centers_ind_path (line 659) | def kmeans_centers_ind_path(e: DashboardExperiment) -> str:
function load_kmeans_center_indices (line 668) | def load_kmeans_center_indices(e: DashboardExperiment) -> list[int]:
function random_dataset_indices (line 682) | def random_dataset_indices(e: DashboardExperiment, k: int = 10) -> list[...
function plot_df_rows_for_dataset_indices (line 691) | def plot_df_rows_for_dataset_indices(
function default_trajectory_endpoints_xy (line 702) | def default_trajectory_endpoints_xy(
FILE: cryodrgn/dataset.py
class ImageDataset (line 31) | class ImageDataset(torch.utils.data.Dataset):
method __init__ (line 32) | def __init__(
method estimate_normalization (line 79) | def estimate_normalization(self, n=1000):
method estimate_normalization_real (line 93) | def estimate_normalization_real(self, n=1000):
method _process (line 102) | def _process(self, data):
method __len__ (line 121) | def __len__(self):
method __getitem__ (line 124) | def __getitem__(self, index):
method get_slice (line 152) | def get_slice(
class TiltSeriesData (line 161) | class TiltSeriesData(ImageDataset):
method __init__ (line 166) | def __init__(
method __len__ (line 234) | def __len__(self):
method __getitem__ (line 237) | def __getitem__(self, index) -> dict[str, torch.Tensor]:
method parse_particle_tilt (line 265) | def parse_particle_tilt(
method particles_to_tilts (line 296) | def particles_to_tilts(
method tilts_to_particles (line 305) | def tilts_to_particles(cls, tilts_to_particles, tilts):
method get_tilt (line 310) | def get_tilt(self, index):
method get_tilt_particle (line 313) | def get_tilt_particle(self, index) -> int:
method get_slice (line 321) | def get_slice(self, start: int, stop: int) -> Tuple[np.ndarray, np.nda...
method critical_exposure (line 345) | def critical_exposure(self, freq):
method get_dose_filters (line 362) | def get_dose_filters(self, tilt_index, lattice, Apix):
method optimal_exposure (line 388) | def optimal_exposure(self, freq):
class DataShuffler (line 392) | class DataShuffler:
method __init__ (line 393) | def __init__(
method __iter__ (line 421) | def __iter__(self):
class _DataShufflerIterator (line 425) | class _DataShufflerIterator:
method __init__ (line 426) | def __init__(self, shuffler: DataShuffler):
method _get_next_chunk (line 464) | def _get_next_chunk(self) -> Tuple[np.ndarray, Optional[np.ndarray], n...
method __iter__ (line 482) | def __iter__(self):
method __next__ (line 485) | def __next__(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
function make_dataloader (line 554) | def make_dataloader(
FILE: cryodrgn/fft.py
function normalize (line 12) | def normalize(
function fft2_center (line 31) | def fft2_center(img: torch.Tensor) -> torch.Tensor:
function fftn_center (line 39) | def fftn_center(img: torch.Tensor) -> torch.Tensor:
function ifftn_center (line 44) | def ifftn_center(img: torch.Tensor) -> torch.Tensor:
function ht2_center (line 49) | def ht2_center(img: torch.Tensor) -> torch.Tensor:
function htn_center (line 55) | def htn_center(img: torch.Tensor) -> torch.Tensor:
function iht2_center (line 61) | def iht2_center(img: torch.Tensor) -> torch.Tensor:
function ihtn_center (line 68) | def ihtn_center(img: torch.Tensor) -> torch.Tensor:
function symmetrize_ht (line 75) | def symmetrize_ht(ht: torch.Tensor) -> torch.Tensor:
FILE: cryodrgn/lattice.py
class Lattice (line 20) | class Lattice:
method __init__ (line 21) | def __init__(
method get_downsample_coords (line 50) | def get_downsample_coords(self, d: int) -> Tensor:
method get_square_lattice (line 62) | def get_square_lattice(self, L: int) -> Tensor:
method get_square_mask (line 69) | def get_square_mask(self, side_length: int) -> Tensor:
method get_circular_mask (line 100) | def get_circular_mask(self, radius: float, verbose: bool = False) -> T...
method rotate (line 125) | def rotate(self, images: Tensor, theta: Tensor) -> Tensor:
method translate_ft (line 142) | def translate_ft(self, img, t, mask=None):
method translate_ht (line 168) | def translate_ht(self, img, t, mask=None):
class EvenLattice (line 193) | class EvenLattice(Lattice):
method __init__ (line 196) | def __init__(
method get_downsampled_coords (line 230) | def get_downsampled_coords(self, d):
FILE: cryodrgn/lie_tools.py
function map_to_lie_algebra (line 12) | def map_to_lie_algebra(v):
function map_to_lie_vector (line 37) | def map_to_lie_vector(X):
function expmap (line 45) | def expmap(v):
function logmap (line 59) | def logmap(R):
function s2s1rodrigues (line 65) | def s2s1rodrigues(s2_el, s1_el):
function s2s2_to_SO3 (line 78) | def s2s2_to_SO3(v1, v2=None):
function SO3_to_s2s2 (line 93) | def SO3_to_s2s2(r):
function SO3_to_quaternions (line 99) | def SO3_to_quaternions(r):
function rotmat_to_s2s2 (line 163) | def rotmat_to_s2s2(rotmat):
function s2s2_to_rotmat (line 172) | def s2s2_to_rotmat(s2s2):
function quaternions_to_SO3 (line 191) | def quaternions_to_SO3(q):
function random_quaternions (line 212) | def random_quaternions(n, dtype=torch.float32, device=None):
function random_rotmat (line 225) | def random_rotmat(n, dtype=torch.float32, device=None):
function logsumexp (line 229) | def logsumexp(inputs, dim=None, keepdim=False):
function so3_entropy_old (line 254) | def so3_entropy_old(w_eps, std, k=10):
function so3_entropy (line 279) | def so3_entropy(w_eps, std, k=10):
FILE: cryodrgn/losses.py
class EquivarianceLoss (line 10) | class EquivarianceLoss(nn.Module):
method __init__ (line 13) | def __init__(self, model, D):
method forward (line 18) | def forward(self, img, encoding):
method rotate (line 29) | def rotate(self, img, theta):
function l2_frequency_bias (line 42) | def l2_frequency_bias(
function kl_divergence_conf (line 57) | def kl_divergence_conf(latent_variables_dict: dict[str, torch.Tensor]) -...
function l1_regularizer (line 65) | def l1_regularizer(x: torch.Tensor) -> torch.Tensor:
FILE: cryodrgn/masking.py
function spherical_window_mask (line 13) | def spherical_window_mask(
function cosine_dilation_mask (line 73) | def cosine_dilation_mask(
class CircularMask (line 133) | class CircularMask:
method __init__ (line 136) | def __init__(self, lattice: Lattice, radius: int) -> None:
method update_radius (line 141) | def update_radius(self, radius: int) -> None:
method update_batch (line 145) | def update_batch(self, total_images_count: int) -> None:
method update_epoch (line 148) | def update_epoch(self, n_frequencies: int) -> None:
method get_lf_submask (line 151) | def get_lf_submask(self) -> torch.Tensor:
method get_hf_submask (line 156) | def get_hf_submask(self) -> torch.Tensor:
class FrequencyMarchingMask (line 160) | class FrequencyMarchingMask(CircularMask):
method __init__ (line 163) | def __init__(
method update_batch (line 175) | def update_batch(self, total_images_count) -> None:
method update_epoch (line 184) | def update_epoch(self, n_frequencies: int) -> None:
method reset (line 187) | def reset(self) -> None:
class FrequencyMarchingExpMask (line 191) | class FrequencyMarchingExpMask(FrequencyMarchingMask):
method __init__ (line 192) | def __init__(
method update_batch (line 203) | def update_batch(self, total_images_count: int) -> None:
FILE: cryodrgn/metrics.py
function get_ref_matrix (line 13) | def get_ref_matrix(r1, r2, i, flip=False):
function _flip (line 20) | def _flip(rot):
function align_rot (line 25) | def align_rot(r1, r2, i, flip=False):
function rigid_transform_3d (line 32) | def rigid_transform_3d(a, b):
function align_view_dir (line 51) | def align_view_dir(rot_gt_tensor, rot_pred_tensor):
function align_rot_best (line 82) | def align_rot_best(rot_gt_tensor, rot_pred_tensor, n_tries=100):
function frob_norm (line 125) | def frob_norm(r1, r2):
function get_angular_error (line 135) | def get_angular_error(rot_gt, rot_pred):
function get_trans_metrics (line 161) | def get_trans_metrics(trans_gt, trans_pred, rotmat, correct_global_trans...
FILE: cryodrgn/models.py
function unparallelize (line 18) | def unparallelize(model: nn.Module) -> nn.Module:
class HetOnlyVAE (line 26) | class HetOnlyVAE(nn.Module):
method __init__ (line 28) | def __init__(
method load (line 93) | def load(cls, config, weights=None, device=None):
method reparameterize (line 141) | def reparameterize(self, mu, logvar):
method encode (line 148) | def encode(self, *img) -> Tuple[Tensor, Tensor]:
method cat_z (line 155) | def cat_z(self, coords, z) -> Tensor:
method decode (line 165) | def decode(self, coords, z=None) -> torch.Tensor:
method forward (line 176) | def forward(self, *args, **kwargs):
class Decoder (line 180) | class Decoder(nn.Module):
method eval_volume (line 181) | def eval_volume(
method get_voxel_decoder (line 200) | def get_voxel_decoder(self) -> Optional["Decoder"]:
class DataParallelDecoder (line 204) | class DataParallelDecoder(Decoder):
method __init__ (line 205) | def __init__(self, decoder: Decoder):
method eval_volume (line 209) | def eval_volume(self, *args, **kwargs):
method forward (line 214) | def forward(self, *args, **kwargs):
method state_dict (line 217) | def state_dict(self, *args, **kwargs):
function load_decoder (line 221) | def load_decoder(config, weights=None, device=None) -> Tuple[Decoder, La...
class PositionalDecoder (line 260) | class PositionalDecoder(Decoder):
method __init__ (line 261) | def __init__(
method positional_encoding_geom (line 298) | def positional_encoding_geom(self, coords):
method random_fourier_encoding (line 337) | def random_fourier_encoding(self, coords):
method positional_encoding_linear (line 355) | def positional_encoding_linear(self, coords):
method forward (line 370) | def forward(self, coords: Tensor) -> Tensor:
method eval_volume (line 375) | def eval_volume(
class FTPositionalDecoder (line 421) | class FTPositionalDecoder(Decoder):
method __init__ (line 422) | def __init__(
method positional_encoding_geom (line 459) | def positional_encoding_geom(self, coords: Tensor) -> Tensor:
method random_fourier_encoding (line 498) | def random_fourier_encoding(self, coords):
method positional_encoding_linear (line 516) | def positional_encoding_linear(self, coords: Tensor) -> Tensor:
method forward (line 531) | def forward(self, lattice: Tensor) -> Tensor:
method decode (line 554) | def decode(self, lattice: Tensor):
method eval_volume (line 567) | def eval_volume(
class FTSliceDecoder (line 620) | class FTSliceDecoder(Decoder):
method __init__ (line 629) | def __init__(self, in_dim: int, D: int, nlayers: int, hidden_dim: int,...
method forward (line 658) | def forward(self, lattice):
method forward_even (line 682) | def forward_even(self, lattice):
method decode (line 693) | def decode(self, lattice):
method eval_volume (line 706) | def eval_volume(
function get_decoder (line 754) | def get_decoder(
class VAE (line 785) | class VAE(nn.Module):
method __init__ (line 786) | def __init__(
method reparameterize (line 835) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method encode (line 842) | def encode(self, img) -> Tuple[Tensor, Tensor, Optional[Tensor], Optio...
method eval_volume (line 856) | def eval_volume(self, norm) -> Tensor:
method decode (line 861) | def decode(self, rot):
method forward (line 868) | def forward(self, img: Tensor):
class TiltVAE (line 884) | class TiltVAE(nn.Module):
method __init__ (line 885) | def __init__(
method reparameterize (line 904) | def reparameterize(self, mu, logvar):
method eval_volume (line 911) | def eval_volume(self, norm) -> Tensor:
method encode (line 916) | def encode(self, img, img_tilt):
method forward (line 935) | def forward(self, img, img_tilt):
class TiltEncoder (line 959) | class TiltEncoder(nn.Module):
method __init__ (line 960) | def __init__(
method forward (line 980) | def forward(self, x):
class ResidLinearMLP (line 986) | class ResidLinearMLP(Decoder):
method __init__ (line 987) | def __init__(
method forward (line 1014) | def forward(self, x):
method eval_volume (line 1020) | def eval_volume(
function half_linear (line 1060) | def half_linear(input, weight, bias):
function single_linear (line 1065) | def single_linear(input, weight, bias):
class MyLinear (line 1072) | class MyLinear(nn.Linear):
method forward (line 1073) | def forward(self, input):
class ResidLinear (line 1084) | class ResidLinear(nn.Module):
method __init__ (line 1085) | def __init__(self, nin, nout):
method forward (line 1090) | def forward(self, x):
class MLP (line 1095) | class MLP(nn.Module):
method __init__ (line 1096) | def __init__(
method forward (line 1112) | def forward(self, x):
class ConvEncoder (line 1117) | class ConvEncoder(nn.Module):
method __init__ (line 1118) | def __init__(self, hidden_dim, out_dim):
method forward (line 1142) | def forward(self, x):
class SO3reparameterize (line 1148) | class SO3reparameterize(nn.Module):
method __init__ (line 1151) | def __init__(self, input_dims, nlayers: int, hidden_dim: int):
method sampleSO3 (line 1162) | def sampleSO3(
method forward (line 1180) | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
FILE: cryodrgn/models_ai.py
class MyDataParallel (line 17) | class MyDataParallel(nn.DataParallel):
method __getattr__ (line 18) | def __getattr__(self, name):
class DrgnAI (line 25) | class DrgnAI(nn.Module):
method __init__ (line 26) | def __init__(
method update_trans_search_factor (line 190) | def update_trans_search_factor(self, ratio):
method forward (line 194) | def forward(self, in_dict):
method process_y_real (line 258) | def process_y_real(in_dict):
method encode (line 262) | def encode(self, in_dict, ctf=None):
method decode (line 357) | def decode(self, latent_variables_dict, ctf_local, y_gt):
method eval_on_slice (line 418) | def eval_on_slice(self, x, z=None):
method apply_ctf (line 434) | def apply_ctf(self, y_pred, ctf_local):
method eval_volume (line 447) | def eval_volume(self, norm, zval=None):
method load (line 462) | def load(cls, config, weights=None, device=None):
function sample_conf (line 477) | def sample_conf(z_mu, z_logvar):
function eval_volume_method (line 492) | def eval_volume_method(hypervolume, lattice, z_dim, norm, zval=None, rad...
class SharedCNN (line 534) | class SharedCNN(nn.Module):
method __init__ (line 535) | def __init__(
method forward (line 617) | def forward(self, y_real):
class RadialAverager (line 661) | class RadialAverager(nn.Module):
method __init__ (line 662) | def __init__(self):
method forward (line 666) | def forward(y_real):
class AddCoords (line 688) | class AddCoords(nn.Module):
method __init__ (line 689) | def __init__(self, resolution, radius_channel=True):
method forward (line 734) | def forward(self, x):
class ConfTable (line 753) | class ConfTable(nn.Module):
method __init__ (line 754) | def __init__(self, n_imgs, z_dim, variational, std_z_init):
method initialize (line 770) | def initialize(self, conf):
method forward (line 778) | def forward(self, in_dict):
method reset (line 799) | def reset(self):
class PoseTable (line 805) | class PoseTable(nn.Module):
method __init__ (line 806) | def __init__(self, n_imgs, no_trans, resolution, use_gt_trans):
method initialize (line 827) | def initialize(self, rots, trans):
method forward (line 840) | def forward(self, in_dict):
class ConfRegressor (line 869) | class ConfRegressor(nn.Module):
method __init__ (line 870) | def __init__(self, channels, kernel_size, z_dim, std_z_init, variation...
method forward (line 889) | def forward(self, shared_features):
class HyperVolume (line 913) | class HyperVolume(nn.Module):
method __init__ (line 914) | def __init__(
method forward (line 973) | def forward(self, x, z):
method random_fourier_encoding (line 1003) | def random_fourier_encoding(self, x):
method geom_fourier_encoding_conf (line 1017) | def geom_fourier_encoding_conf(self, z):
method get_building_params (line 1029) | def get_building_params(self):
class VolumeExplicit (line 1044) | class VolumeExplicit(nn.Module):
method __init__ (line 1045) | def __init__(self, resolution, domain, extent):
method forward (line 1063) | def forward(self, x, z):
method get_building_params (line 1084) | def get_building_params(self):
class GaussianPyramid (line 1094) | class GaussianPyramid(nn.Module):
method __init__ (line 1095) | def __init__(self, n_layers):
method forward (line 1122) | def forward(self, x):
class ResidualLinearMLP (line 1131) | class ResidualLinearMLP(nn.Module):
method __init__ (line 1132) | def __init__(self, in_dim, n_layers, hidden_dim, out_dim, nl=nn.ReLU):
method forward (line 1150) | def forward(self, x):
class ResidualLinear (line 1162) | class ResidualLinear(nn.Module):
method __init__ (line 1163) | def __init__(self, n_in, n_out):
method forward (line 1167) | def forward(self, x):
class MyLinear (line 1172) | class MyLinear(nn.Linear):
method forward (line 1173) | def forward(self, x):
function half_linear (line 1180) | def half_linear(x, weight, bias):
function single_linear (line 1184) | def single_linear(x, weight, bias):
FILE: cryodrgn/mrcfile.py
class MRCHeader (line 24) | class MRCHeader:
method __init__ (line 109) | def __init__(self, header_values, extended_header=b""):
method __str__ (line 123) | def __str__(self):
method parse (line 127) | def parse(cls, fname: str) -> Self:
method make_default_header (line 146) | def make_default_header(
method write (line 251) | def write(self, fh):
method apix (line 258) | def apix(self) -> float:
method apix (line 262) | def apix(self, value: float) -> None:
method origin (line 268) | def origin(self) -> tuple[float, float, float]:
method origin (line 272) | def origin(self, value: tuple[float, float, float]) -> None:
function parse_mrc (line 278) | def parse_mrc(fname: str) -> Tuple[np.ndarray, MRCHeader]:
function get_mrc_header (line 296) | def get_mrc_header(
function fix_mrc_header (line 316) | def fix_mrc_header(header: MRCHeader) -> MRCHeader:
function write_mrc (line 330) | def write_mrc(
FILE: cryodrgn/pose.py
class PoseTracker (line 15) | class PoseTracker(nn.Module):
method __init__ (line 16) | def __init__(
method load (line 58) | def load(
method save (line 139) | def save(self, out_pkl: str) -> None:
method get_pose (line 161) | def get_pose(self, ind: Union[int, Tensor]) -> Tuple[Tensor, Optional[...
FILE: cryodrgn/pose_search.py
function rot_2d (line 14) | def rot_2d(angle: float, outD: int, device: torch.device) -> torch.Tensor:
function to_tensor (line 23) | def to_tensor(x: Union[np.ndarray, torch.Tensor, None]):
function interpolate (line 29) | def interpolate(img: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
class PoseSearch (line 52) | class PoseSearch:
method __init__ (line 55) | def __init__(
method eval_grid (line 108) | def eval_grid(
method mask_images (line 194) | def mask_images(self, images, L):
method translate_images (line 203) | def translate_images(
method rotate_images (line 218) | def rotate_images(
method get_neighbor_so3 (line 248) | def get_neighbor_so3(self, quat: np.ndarray, s2i: int, s1i: int, res: ...
method get_neighbor_shift (line 256) | def get_neighbor_shift(self, x, y, res):
method subdivide (line 266) | def subdivide(
method keep_matrix (line 304) | def keep_matrix(self, loss: torch.Tensor, B: int, max_poses: int) -> t...
method getL (line 331) | def getL(self, iter_: int) -> int:
method opt_theta_trans (line 336) | def opt_theta_trans(
FILE: cryodrgn/pose_search_ai.py
function get_base_shifts (line 13) | def get_base_shifts(ps_params):
function get_base_rot (line 25) | def get_base_rot(ps_params):
function get_so3_base_quat (line 30) | def get_so3_base_quat(ps_params):
function get_base_inplane (line 34) | def get_base_inplane(ps_params):
function to_tensor (line 38) | def to_tensor(x):
function get_l (line 44) | def get_l(step, res, ps_params):
function translate_images (line 54) | def translate_images(images, shifts, l_current, lattice, freqs2d):
function rot_2d (line 68) | def rot_2d(angle, out_d, device):
function rot_2d_tensor (line 77) | def rot_2d_tensor(angles, out_d, device):
function interpolate (line 86) | def interpolate(img, coords):
function rotate_images (line 100) | def rotate_images(images, angles, l_current, masked_coords, lattice):
function compute_err (line 131) | def compute_err(
function eval_grid (line 246) | def eval_grid(
function keep_matrix (line 291) | def keep_matrix(loss, batch_size, max_poses):
function get_neighbor_so3 (line 318) | def get_neighbor_so3(quat, q_ind, res, device):
function subdivide (line 330) | def subdivide(quat, q_ind, cur_res, device):
function opt_trans (line 348) | def opt_trans(model, y_gt, y_pred, lattice, ps_params, current_radius):
function opt_theta_trans (line 402) | def opt_theta_trans(
FILE: cryodrgn/shift_grid.py
function grid_1d (line 4) | def grid_1d(resol: int, extent: int, ngrid: int, shift: int = 0) -> np.n...
function grid_2d (line 11) | def grid_2d(
function base_shift_grid (line 21) | def base_shift_grid(
function get_1d_neighbor (line 30) | def get_1d_neighbor(mini, cur_res, extent, ngrid):
function get_base_ind (line 37) | def get_base_ind(ind, ngrid):
function get_neighbor (line 43) | def get_neighbor(xi, yi, cur_res, extent, ngrid):
FILE: cryodrgn/shift_grid3.py
function grid_1d (line 4) | def grid_1d(resol, extent, ngrid):
function grid_3d (line 11) | def grid_3d(resol: int, extent: int, ngrid: int) -> np.ndarray:
function base_shift_grid (line 19) | def base_shift_grid(extent: int, ngrid: int) -> np.ndarray:
function get_1d_neighbor (line 26) | def get_1d_neighbor(mini, curr_res, extent, ngrid):
function get_base_id (line 33) | def get_base_id(id_, ngrid):
function get_neighbor (line 40) | def get_neighbor(xi, yi, zi, curr_res, extent, ngrid):
FILE: cryodrgn/so3_grid.py
function grid_s1 (line 12) | def grid_s1(resol):
function grid_s2 (line 19) | def grid_s2(resol):
function hopf_to_quat (line 26) | def hopf_to_quat(theta, phi, psi):
function hopf_to_quat_tensor (line 47) | def hopf_to_quat_tensor(theta, phi, psi):
function grid_SO3 (line 71) | def grid_SO3(resol):
function s2_grid_SO3 (line 82) | def s2_grid_SO3(resol):
function get_s1_neighbor (line 91) | def get_s1_neighbor(mini, curr_res):
function get_s1_neighbor_tensor (line 108) | def get_s1_neighbor_tensor(mini, curr_res):
function get_s2_neighbor (line 128) | def get_s2_neighbor(mini, cur_res):
function get_s2_neighbor_tensor (line 137) | def get_s2_neighbor_tensor(mini, cur_res):
function get_base_ind (line 150) | def get_base_ind(ind, base):
function get_neighbor (line 160) | def get_neighbor(quat, s2i, s1i, cur_res):
function get_neighbor_tensor (line 180) | def get_neighbor_tensor(quat, q_ind, cur_res, device):
function pix2ang_tensor (line 232) | def pix2ang_tensor(n_side, i_pix, nest=False, lonlat=False):
function pix2ang (line 244) | def pix2ang(Nside, ipix, nest=False, lonlat=False):
FILE: cryodrgn/source.py
class ImageSource (line 47) | class ImageSource:
method __init__ (line 82) | def __init__(
method from_file (line 129) | def from_file(
method __len__ (line 184) | def __len__(self) -> int:
method lazy (line 188) | def lazy(self) -> bool:
method __getitem__ (line 192) | def __getitem__(self, item) -> torch.Tensor:
method __eq__ (line 195) | def __eq__(self, other):
method _convert_to_ndarray (line 198) | def _convert_to_ndarray(
method images (line 227) | def images(
method _images (line 268) | def _images(
method chunks (line 274) | def chunks(
method apix (line 292) | def apix(self) -> Union[None, float, np.ndarray]:
method write_mrc (line 296) | def write_mrc(
method get_default_mrc_header (line 347) | def get_default_mrc_header(self) -> MRCHeader:
class MRCFileSource (line 351) | class MRCFileSource(ImageSource):
method __init__ (line 354) | def __init__(
method _images (line 384) | def _images(
method write (line 436) | def write(
method apix (line 446) | def apix(self) -> float:
class _MRCDataFrameSource (line 450) | class _MRCDataFrameSource(ImageSource):
method __init__ (line 465) | def __init__(
method _images (line 499) | def _images(
method sources (line 531) | def sources(self) -> Iterator[tuple[str, MRCFileSource]]:
method parse_filename (line 534) | def parse_filename(self, filename: str) -> str:
class CsSource (line 563) | class CsSource(_MRCDataFrameSource):
method __init__ (line 566) | def __init__(
class TxtFileSource (line 593) | class TxtFileSource(_MRCDataFrameSource):
method __init__ (line 601) | def __init__(
method write (line 628) | def write(self, output_file: str):
class StarfileSource (line 634) | class StarfileSource(_MRCDataFrameSource, Starfile):
method __init__ (line 648) | def __init__(
FILE: cryodrgn/starfile.py
function parse_star (line 30) | def parse_star(starfile: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
function write_star (line 93) | def write_star(
function _write_star_block (line 112) | def _write_star_block(
class Starfile (line 130) | class Starfile:
method __init__ (line 151) | def __init__(
method load (line 177) | def load(cls, starfile: str) -> Self:
method write (line 181) | def write(self, outstar: str) -> None:
method relion31 (line 186) | def relion31(self) -> bool:
method __len__ (line 190) | def __len__(self) -> int:
method __eq__ (line 194) | def __eq__(self, other: Self) -> bool:
method get_optics_values (line 202) | def get_optics_values(
method set_optics_values (line 230) | def set_optics_values(self, fieldname: str, vals: Union[float, Iterabl...
method apix (line 310) | def apix(self) -> Union[None, np.ndarray]:
method resolution (line 315) | def resolution(self) -> Union[None, np.ndarray]:
method to_relion30 (line 323) | def to_relion30(self) -> pd.DataFrame:
FILE: cryodrgn/utils.py
function get_igraph_from_adjacency (line 21) | def get_igraph_from_adjacency(adjacency):
function meshgrid_2d (line 31) | def meshgrid_2d(lo, hi, n, endpoint=False):
class memoized (line 49) | class memoized(object):
method __init__ (line 55) | def __init__(self, func):
method __call__ (line 59) | def __call__(self, *args):
method __repr__ (line 71) | def __repr__(self):
method __get__ (line 75) | def __get__(self, obj, objtype):
function load_pkl (line 80) | def load_pkl(pkl: str):
function save_pkl (line 86) | def save_pkl(data, out_pkl: str, mode: str = "wb") -> None:
function load_yaml (line 93) | def load_yaml(yamlfile: str):
function save_yaml (line 98) | def save_yaml(data, out_yamlfile: str, mode: str = "w"):
function create_basedir (line 105) | def create_basedir(out: str) -> None:
function warn_file_exists (line 110) | def warn_file_exists(out: str) -> None:
function run_command (line 116) | def run_command(cmd: str) -> tuple[str, str]:
function R_from_eman (line 127) | def R_from_eman(a: np.ndarray, b: np.ndarray, y: np.ndarray) -> np.ndarray:
function R_from_relion (line 146) | def R_from_relion(euler: np.ndarray) -> np.ndarray:
function R_from_relion_scipy (line 188) | def R_from_relion_scipy(euler_: np.ndarray, degrees: bool = True) -> np....
function R_to_relion_scipy (line 206) | def R_to_relion_scipy(rot: np.ndarray, degrees: bool = True) -> np.ndarray:
function xrot (line 229) | def xrot(tilt_deg):
function _zero_sphere_helper (line 243) | def _zero_sphere_helper(D: int) -> Tuple[np.ndarray, np.ndarray]:
function zero_sphere (line 252) | def zero_sphere(vol: np.ndarray) -> np.ndarray:
function assert_pkl_close (line 262) | def assert_pkl_close(pkl_a: str, pkl_b: str, atol: float = 1e-4) -> None:
function low_pass_filter (line 272) | def low_pass_filter(vol, apix, low_pass_res):
function crop_real_space (line 305) | def crop_real_space(vol, D, deepcopy=False):
function get_latest_checkpoint (line 333) | def get_latest_checkpoint(outdir: str) -> tuple[str, Union[str, None]]:
FILE: testing/test_pose_search_rag12_128.py
function load_model (line 17) | def load_model(path, D):
function get_poses (line 28) | def get_poses(path, D):
function mse (line 37) | def mse(x, y):
function medse (line 44) | def medse(x, y):
function trans_offset (line 49) | def trans_offset(x, y):
function run (line 53) | def run(args):
FILE: testing/test_pose_search_real_128.py
function load_model (line 34) | def load_model(path):
function do_pose_search (line 48) | def do_pose_search(
function mse (line 68) | def mse(x, y):
function medse (line 75) | def medse(x, y):
function trans_offset (line 80) | def trans_offset(x, y):
function eval_pose_search (line 84) | def eval_pose_search(data, model, B=512, label="", **kwargs):
FILE: testing/test_pose_search_syn_64.py
function load_model (line 33) | def load_model(path):
function do_pose_search (line 48) | def do_pose_search(images, model, nkeptposes=24, Lmin=12, Lmax=24, niter...
function mse (line 65) | def mse(x, y):
function medse (line 72) | def medse(x, y):
function eval_pose_search (line 77) | def eval_pose_search(data, model, B=512, label="", **kwargs):
FILE: tests/conftest.py
function pytest_configure (line 19) | def pytest_configure():
function get_testing_datasets (line 23) | def get_testing_datasets(dataset_lbl: str) -> tuple[str, str]:
class DataFixture (line 99) | class DataFixture:
function produce_data_fixture (line 104) | def produce_data_fixture(
function particles (line 130) | def particles(request) -> Union[DataFixture, dict[str, DataFixture]]:
function poses (line 135) | def poses(request) -> Union[DataFixture, dict[str, DataFixture]]:
function ctf (line 140) | def ctf(request) -> Union[DataFixture, dict[str, DataFixture]]:
function indices (line 145) | def indices(request) -> Union[DataFixture, dict[str, DataFixture]]:
function datadir (line 150) | def datadir(request) -> Union[DataFixture, dict[str, DataFixture]]:
function trans (line 155) | def trans(request) -> Union[DataFixture, dict[str, DataFixture]]:
function volume (line 160) | def volume(request) -> Union[DataFixture, dict[str, DataFixture]]:
function weights (line 165) | def weights(request) -> Union[DataFixture, dict[str, DataFixture]]:
function configs (line 170) | def configs(request) -> Union[DataFixture, dict[str, DataFixture]]:
class TrainDir (line 174) | class TrainDir:
method __init__ (line 182) | def __init__(
method parse_request (line 224) | def parse_request(cls, req: dict[str, Any]) -> dict[str, Any]:
method out_files (line 253) | def out_files(self) -> list[str]:
method epoch_cleaned (line 256) | def epoch_cleaned(self, epoch: Union[int, None]) -> bool:
method all_files_present (line 285) | def all_files_present(self) -> bool:
method replace_files (line 291) | def replace_files(self) -> None:
method train_load_epoch (line 298) | def train_load_epoch(self, load_epoch: int, train_epochs: int) -> None:
function train_dir (line 330) | def train_dir(request, tmpdir_factory) -> Generator[TrainDir, None, None]:
function trained_dir (line 342) | def trained_dir(train_dir: TrainDir) -> Generator[TrainDir, None, None]:
function train_dirs (line 349) | def train_dirs(request) -> Generator[list[TrainDir], None, None]:
function trained_dirs (line 358) | def trained_dirs(train_dirs) -> Generator[list[TrainDir], None, None]:
class AbInitioDir (line 365) | class AbInitioDir:
method __init__ (line 368) | def __init__(
method parse_request (line 391) | def parse_request(cls, req: dict[str, Any]) -> dict[str, Any]:
method train (line 415) | def train(self, load_epoch: Optional[int] = None) -> None:
method analyze (line 443) | def analyze(self, analysis_epoch: int) -> None:
method backproject (line 450) | def backproject(self) -> None:
function abinit_dir (line 463) | def abinit_dir(request, tmpdir_factory) -> AbInitioDir:
function _dashboard_data_dir (line 481) | def _dashboard_data_dir() -> str:
function _dashboard_is_usable_workdir (line 485) | def _dashboard_is_usable_workdir(workdir: str) -> bool:
function _dashboard_run_train_and_analyze (line 495) | def _dashboard_run_train_and_analyze(workdir: str) -> None:
function dashboard_workdir (line 535) | def dashboard_workdir(tmp_path_factory: pytest.TempPathFactory) -> str:
function dashboard_experiment (line 552) | def dashboard_experiment(dashboard_workdir: str) -> DashboardExperiment:
function flask_client (line 557) | def flask_client(dashboard_workdir: str):
FILE: tests/test_add_psize.py
function test_add_psize (line 11) | def test_add_psize(tmpdir, volume, Apix):
FILE: tests/test_backprojection.py
class TestBackprojection (line 10) | class TestBackprojection:
method get_outdir (line 11) | def get_outdir(
method test_train (line 45) | def test_train(
method test_train_no_halfmaps (line 79) | def test_train_no_halfmaps(
method test_train_no_fscs (line 113) | def test_train_no_fscs(
method test_fidelity (line 150) | def test_fidelity(self, tmpdir_factory, particles, poses, ctf, indices...
method test_to_fsc (line 174) | def test_to_fsc(self, tmpdir_factory, particles, poses, ctf, indices, ...
class TestTiltBackprojection (line 194) | class TestTiltBackprojection:
method get_outdir (line 195) | def get_outdir(
method test_train (line 223) | def test_train(
method test_to_fsc (line 270) | def test_to_fsc(
FILE: tests/test_clean.py
function test_clean_here (line 10) | def test_clean_here(trained_dir, every_n: int) -> None:
function test_clean_one (line 36) | def test_clean_one(trained_dir, every_n: int) -> None:
function test_clean_two (line 70) | def test_clean_two(trained_dirs, every_n: int) -> None:
FILE: tests/test_dashboard_core.py
function _traj_flask_200_or_ineligible (line 70) | def _traj_flask_200_or_ineligible(r, experiment: DashboardExperiment) ->...
class TestParseIntFromDict (line 83) | class TestParseIntFromDict:
method test_coerce_and_clamp (line 90) | def test_coerce_and_clamp(self, raw: object, expected: int) -> None:
method test_non_numeric_falls_back_to_default (line 97) | def test_non_numeric_falls_back_to_default(self, bad: object) -> None:
method test_missing_key_returns_default (line 100) | def test_missing_key_returns_default(self) -> None:
method test_bad_value_returns_default_unclamped (line 103) | def test_bad_value_returns_default_unclamped(self) -> None:
method test_traj_points_wrapper_clamps (line 108) | def test_traj_points_wrapper_clamps(self) -> None:
method test_traj_interpolation_wrapper_allows_zero (line 113) | def test_traj_interpolation_wrapper_allows_zero(self) -> None:
method test_traj_neighbor_wrapper_bounds (line 119) | def test_traj_neighbor_wrapper_bounds(self) -> None:
class TestTrajectoryEligibilityError (line 125) | class TestTrajectoryEligibilityError:
method test_eligible_returns_none (line 128) | def test_eligible_returns_none(self, monkeypatch: pytest.MonkeyPatch) ...
method test_ineligible_returns_400_json (line 136) | def test_ineligible_returns_400_json(self, monkeypatch: pytest.MonkeyP...
class TestChimeraxRenderCmds (line 147) | class TestChimeraxRenderCmds:
method test_static_view_has_no_turn (line 150) | def test_static_view_has_no_turn(self) -> None:
method test_rotated_view_injects_turn (line 159) | def test_rotated_view_injects_turn(self) -> None:
method test_paths_with_spaces_are_quoted (line 165) | def test_paths_with_spaces_are_quoted(self) -> None:
class TestDashboardExperiment (line 183) | class TestDashboardExperiment:
method test_list_z_epochs (line 186) | def test_list_z_epochs(self, dashboard_workdir: str) -> None:
method test_load_experiment_shapes (line 189) | def test_load_experiment_shapes(
class TestDashboardPages (line 201) | class TestDashboardPages:
method test_page_renders (line 215) | def test_page_renders(self, flask_client, path: str) -> None:
class TestDashboardScatterApis (line 221) | class TestDashboardScatterApis:
method test_api_scatter_json (line 224) | def test_api_scatter_json(self, flask_client) -> None:
method test_api_scatter_no_color (line 231) | def test_api_scatter_no_color(self, flask_client) -> None:
method test_api_scatter3d_z (line 235) | def test_api_scatter3d_z(self, flask_client) -> None:
method test_api_latent3d_preview_png (line 240) | def test_api_latent3d_preview_png(self, flask_client) -> None:
method test_api_preview_montage (line 246) | def test_api_preview_montage(self, flask_client) -> None:
class TestDashboardPairPlot (line 252) | class TestDashboardPairPlot:
method test_api_pairplot (line 263) | def test_api_pairplot(
method test_pair_grid_png_is_deterministic (line 285) | def test_pair_grid_png_is_deterministic(
class TestDashboardTrajectoryCoords (line 307) | class TestDashboardTrajectoryCoords:
method test_direct_interpolation (line 310) | def test_direct_interpolation(
method test_nearest_mode (line 332) | def test_nearest_mode(
method test_anchor_driven_trajectory (line 350) | def test_anchor_driven_trajectory(
method test_kmeans_centers (line 368) | def test_kmeans_centers(
method test_random_indices (line 378) | def test_random_indices(
method test_default_endpoints (line 388) | def test_default_endpoints(
class TestDashboardZPkl (line 396) | class TestDashboardZPkl:
method test_z_pkl_matches_experiment (line 399) | def test_z_pkl_matches_experiment(
class TestParseAnchorIndicesTxt (line 417) | class TestParseAnchorIndicesTxt:
method test_parses_whitespace (line 420) | def test_parses_whitespace(self) -> None:
method test_parses_commas_and_semicolons (line 423) | def test_parses_commas_and_semicolons(self) -> None:
method test_strips_surrounding_whitespace_and_newlines (line 426) | def test_strips_surrounding_whitespace_and_newlines(self) -> None:
method test_allows_negative_and_zero (line 429) | def test_allows_negative_and_zero(self) -> None:
method test_rejects_single_index (line 434) | def test_rejects_single_index(self) -> None:
method test_rejects_empty (line 438) | def test_rejects_empty(self) -> None:
method test_rejects_non_integer_tokens (line 442) | def test_rejects_non_integer_tokens(self) -> None:
method test_rejects_floats (line 446) | def test_rejects_floats(self) -> None:
method test_rejects_non_utf8 (line 450) | def test_rejects_non_utf8(self) -> None:
class TestZTrajSavetxtRoundTrip (line 455) | class TestZTrajSavetxtRoundTrip:
method test_roundtrip_through_numpy (line 456) | def test_roundtrip_through_numpy(self) -> None:
class TestRoundDirectModeTrajXY (line 463) | class TestRoundDirectModeTrajXY:
method test_small_range_rounds_to_three_decimals (line 464) | def test_small_range_rounds_to_three_decimals(self) -> None:
method test_large_range_rounds_to_two_decimals (line 469) | def test_large_range_rounds_to_two_decimals(self) -> None:
method test_nan_entries_preserved (line 476) | def test_nan_entries_preserved(self) -> None:
class TestTrajectoryXYOkForDirect (line 484) | class TestTrajectoryXYOkForDirect:
method test_allowed_combinations (line 496) | def test_allowed_combinations(self, x: str, y: str, expected: bool) ->...
class TestTrajectoryDefaultXYCols (line 500) | class TestTrajectoryDefaultXYCols:
method test_prefers_pc_when_zdim_greater_than_2 (line 501) | def test_prefers_pc_when_zdim_greater_than_2(self) -> None:
method test_falls_back_to_umap_when_no_pc (line 505) | def test_falls_back_to_umap_when_no_pc(self) -> None:
method test_falls_back_to_first_two_when_no_embedding (line 509) | def test_falls_back_to_first_two_when_no_embedding(self) -> None:
class TestTrajectoryPlotAxisColumns (line 514) | class TestTrajectoryPlotAxisColumns:
method test_includes_all_z_pc_umap (line 515) | def test_includes_all_z_pc_umap(
method test_validate_raises_for_disallowed_pair (line 526) | def test_validate_raises_for_disallowed_pair(
method test_validate_accepts_allowed_pair (line 532) | def test_validate_accepts_allowed_pair(
class TestHasEmbeddingColumns (line 539) | class TestHasEmbeddingColumns:
method test_has_umap_and_pc_on_real_experiment (line 540) | def test_has_umap_and_pc_on_real_experiment(
class TestDijkstraFromNeighbors (line 547) | class TestDijkstraFromNeighbors:
method _line_graph (line 551) | def _line_graph() -> tuple[np.ndarray, np.ndarray]:
method test_same_node_returns_singleton (line 560) | def test_same_node_returns_singleton(self) -> None:
method test_path_across_line (line 564) | def test_path_across_line(self) -> None:
method test_disconnected_returns_none (line 569) | def test_disconnected_returns_none(self) -> None:
method test_prefers_shorter_edge (line 577) | def test_prefers_shorter_edge(self) -> None:
class TestGraphNeighborArrays (line 584) | class TestGraphNeighborArrays:
method test_shapes_and_cache (line 585) | def test_shapes_and_cache(self, dashboard_experiment: DashboardExperim...
class TestComputeDirectAnchorTrajectory (line 601) | class TestComputeDirectAnchorTrajectory:
method test_endpoints_and_interpolation_count (line 602) | def test_endpoints_and_interpolation_count(
method test_requires_two_anchors (line 618) | def test_requires_two_anchors(
method test_out_of_range_raises (line 624) | def test_out_of_range_raises(
class TestParseTrajectoryRequestBody (line 633) | class TestParseTrajectoryRequestBody:
method test_anchor_direct_happy_path (line 634) | def test_anchor_direct_happy_path(
method test_anchor_graph_happy_path (line 652) | def test_anchor_graph_happy_path(
method test_anchor_bad_mode_rejected (line 669) | def test_anchor_bad_mode_rejected(
method test_anchor_indices_not_int_rejected (line 683) | def test_anchor_indices_not_int_rejected(
method test_direct_requires_z_or_pc_axes (line 692) | def test_direct_requires_z_or_pc_axes(
method test_bad_mode_for_non_anchor_rejected (line 708) | def test_bad_mode_for_non_anchor_rejected(
method test_bad_start_end_rejected (line 723) | def test_bad_start_end_rejected(
method test_non_numeric_endpoints_rejected (line 732) | def test_non_numeric_endpoints_rejected(
method test_traj_xy_custom_overrides_start_end (line 747) | def test_traj_xy_custom_overrides_start_end(
class TestComputeTrajectoryLatentPath (line 758) | class TestComputeTrajectoryLatentPath:
method test_nearest_rows_land_on_real_particles (line 759) | def test_nearest_rows_land_on_real_particles(
method test_direct_endpoints_equal_data_points (line 779) | def test_direct_endpoints_equal_data_points(
class TestRandomDatasetIndices (line 800) | class TestRandomDatasetIndices:
method test_returns_distinct_indices_in_range (line 801) | def test_returns_distinct_indices_in_range(
method test_clips_to_available (line 810) | def test_clips_to_available(
class TestDefaultTrajectoryEndpointsXY (line 818) | class TestDefaultTrajectoryEndpointsXY:
method test_endpoints_span_long_axis (line 819) | def test_endpoints_span_long_axis(
method test_synthetic_tight_cluster_follows_pc1 (line 829) | def test_synthetic_tight_cluster_follows_pc1(self) -> None:
class TestTrajectoryAnchorModeParams (line 843) | class TestTrajectoryAnchorModeParams:
method test_direct_mode_clamps_interpolation (line 844) | def test_direct_mode_clamps_interpolation(self) -> None:
method test_graph_mode_clamps_anchor_count (line 852) | def test_graph_mode_clamps_anchor_count(self) -> None:
class TestDirectAnchorParticleIndicesPayload (line 860) | class TestDirectAnchorParticleIndicesPayload:
method test_no_interp_is_straight_passthrough (line 861) | def test_no_interp_is_straight_passthrough(self) -> None:
method test_interpolated_fills_none_between_anchors (line 867) | def test_interpolated_fills_none_between_anchors(self) -> None:
method test_total_mismatch_returns_none (line 874) | def test_total_mismatch_returns_none(self) -> None:
method test_too_few_anchors_returns_none (line 882) | def test_too_few_anchors_returns_none(self) -> None:
class TestTrajectoryAnchorPayloadFromIndices (line 891) | class TestTrajectoryAnchorPayloadFromIndices:
method test_direct_shape (line 892) | def test_direct_shape(self, dashboard_experiment: DashboardExperiment)...
class TestPlotDfRowsForDatasetIndices (line 907) | class TestPlotDfRowsForDatasetIndices:
method test_roundtrip_from_all_indices (line 908) | def test_roundtrip_from_all_indices(
method test_empty_input_returns_empty (line 915) | def test_empty_input_returns_empty(
class TestPreloadTimeHints (line 931) | class TestPreloadTimeHints:
method test_bounds_are_ordered_and_positive (line 933) | def test_bounds_are_ordered_and_positive(self, cpus: int) -> None:
method test_format_singular_core (line 938) | def test_format_singular_core(self) -> None:
method test_format_plural_cores_embeds_cpus (line 942) | def test_format_plural_cores_embeds_cpus(self) -> None:
class TestStratifiedXYRowIndices (line 948) | class TestStratifiedXYRowIndices:
method test_count_is_bounded_and_unique (line 949) | def test_count_is_bounded_and_unique(self) -> None:
method test_empty_coords_returns_empty (line 957) | def test_empty_coords_returns_empty(self) -> None:
class TestSamplePlotDfRowsForPreload (line 962) | class TestSamplePlotDfRowsForPreload:
method test_restrict_rows_returns_subset (line 963) | def test_restrict_rows_returns_subset(
method test_empty_restrict_returns_empty (line 973) | def test_empty_restrict_returns_empty(
class TestMontageBytes (line 982) | class TestMontageBytes:
method test_empty_returns_hint_png (line 983) | def test_empty_returns_hint_png(
method test_rows_returns_png (line 989) | def test_rows_returns_png(self, dashboard_experiment: DashboardExperim...
method test_too_many_rows_are_capped (line 994) | def test_too_many_rows_are_capped(
class TestParticleThumbnailB64FromRow (line 1002) | class TestParticleThumbnailB64FromRow:
method test_returns_base64_jpeg (line 1003) | def test_returns_base64_jpeg(
class TestEncodeParticleBatch (line 1013) | class TestEncodeParticleBatch:
method test_matches_request_count (line 1014) | def test_matches_request_count(
class TestLoadPlotDfRowsFromPlotIndsFile (line 1026) | class TestLoadPlotDfRowsFromPlotIndsFile:
method test_empty_path_returns_empty (line 1027) | def test_empty_path_returns_empty(
method test_missing_file_returns_empty (line 1033) | def test_missing_file_returns_empty(
method test_roundtrip_via_pickle (line 1043) | def test_roundtrip_via_pickle(
FILE: tests/test_dashboard_extended.py
function _traj_flask_200_or_ineligible (line 69) | def _traj_flask_200_or_ineligible(r, experiment: DashboardExperiment) ->...
class TestConfigHasCryodrgnCmd (line 82) | class TestConfigHasCryodrgnCmd:
method test_detects_cryodrgn (line 94) | def test_detects_cryodrgn(self, cfg: object, expected: bool) -> None:
class TestDiscoverCryodrgnWorkdirs (line 98) | class TestDiscoverCryodrgnWorkdirs:
method test_only_returns_cryodrgn_workdirs (line 99) | def test_only_returns_cryodrgn_workdirs(self, tmp_path) -> None:
method test_nonexistent_cwd_returns_empty (line 125) | def test_nonexistent_cwd_returns_empty(self, tmp_path) -> None:
class TestWorkdirOptions (line 129) | class TestWorkdirOptions:
method test_relative_labels (line 130) | def test_relative_labels(self, tmp_path) -> None:
class TestEpochsForWorkdir (line 141) | class TestEpochsForWorkdir:
method test_returns_analyzed_epochs_sorted (line 142) | def test_returns_analyzed_epochs_sorted(self, dashboard_workdir: str) ...
class TestAbbrevMiddle (line 149) | class TestAbbrevMiddle:
method test_short_unchanged (line 150) | def test_short_unchanged(self) -> None:
method test_long_uses_middle_ellipsis (line 153) | def test_long_uses_middle_ellipsis(self) -> None:
method test_none_returns_empty (line 161) | def test_none_returns_empty(self) -> None:
method test_small_maxlen_truncates_plainly (line 164) | def test_small_maxlen_truncates_plainly(self) -> None:
class TestAbbrevMiddleToken (line 168) | class TestAbbrevMiddleToken:
method test_short_unchanged (line 169) | def test_short_unchanged(self) -> None:
method test_long_has_ellipsis (line 172) | def test_long_has_ellipsis(self) -> None:
class TestCmdArgvForNavDisplay (line 179) | class TestCmdArgvForNavDisplay:
method test_python_m_cryodrgn (line 180) | def test_python_m_cryodrgn(self) -> None:
method test_entrypoint_first (line 185) | def test_entrypoint_first(self) -> None:
method test_python_wrapper_second (line 190) | def test_python_wrapper_second(self) -> None:
method test_empty_returns_empty (line 195) | def test_empty_returns_empty(self) -> None:
method test_unrecognised_is_unchanged (line 198) | def test_unrecognised_is_unchanged(self) -> None:
class TestArgvFourCommandLines (line 202) | class TestArgvFourCommandLines:
method test_empty (line 203) | def test_empty(self) -> None:
method test_single_token (line 206) | def test_single_token(self) -> None:
method test_head_is_two_tokens (line 209) | def test_head_is_two_tokens(self) -> None:
method test_long_token_is_abbreviated (line 217) | def test_long_token_is_abbreviated(self) -> None:
class TestClearExperimentCaches (line 224) | class TestClearExperimentCaches:
method test_clears_all_three_caches (line 225) | def test_clears_all_three_caches(
class TestCommandBuilderTemplateKwargs (line 240) | class TestCommandBuilderTemplateKwargs:
method test_no_experiment_uses_defaults (line 241) | def test_no_experiment_uses_defaults(self) -> None:
method test_with_experiment_uses_config (line 249) | def test_with_experiment_uses_config(
class TestActiveWorkdirAndResolveEpoch (line 260) | class TestActiveWorkdirAndResolveEpoch:
method test_active_workdir_returns_dashboard_workdir (line 261) | def test_active_workdir_returns_dashboard_workdir(
method test_no_workdir_returns_none (line 269) | def test_no_workdir_returns_none(self) -> None:
class TestNormalizeContinuousPalette (line 280) | class TestNormalizeContinuousPalette:
method test_cases (line 293) | def test_cases(self, raw: object, expected: str) -> None:
method test_mpl_cmap_for_palette (line 297) | def test_mpl_cmap_for_palette(self) -> None:
class TestContinuousSeriesStats (line 303) | class TestContinuousSeriesStats:
method test_constant_series_has_finite_span (line 304) | def test_constant_series_has_finite_span(self) -> None:
method test_all_nan_falls_back (line 309) | def test_all_nan_falls_back(self) -> None:
method test_mixed_series_uses_extrema (line 315) | def test_mixed_series_uses_extrema(self) -> None:
class TestPairGridHexAndSkeleton (line 321) | class TestPairGridHexAndSkeleton:
method test_hex_style_is_png_and_deterministic (line 322) | def test_hex_style_is_png_and_deterministic(
method test_placeholder_layout_shape (line 341) | def test_placeholder_layout_shape(self) -> None:
method test_placeholder_zdim_zero_is_empty (line 347) | def test_placeholder_zdim_zero_is_empty(self) -> None:
class TestIsDrgnaiConfig (line 356) | class TestIsDrgnaiConfig:
method test_recognises_drgnai (line 357) | def test_recognises_drgnai(self) -> None:
method test_classic_config_is_false (line 360) | def test_classic_config_is_false(self) -> None:
class TestConfigYamlPath (line 364) | class TestConfigYamlPath:
method test_returns_yaml_when_present (line 365) | def test_returns_yaml_when_present(self, dashboard_workdir: str) -> None:
method test_falls_back_to_pkl (line 368) | def test_falls_back_to_pkl(self, tmp_path) -> None:
method test_missing_raises (line 372) | def test_missing_raises(self, tmp_path) -> None:
class TestSortedVolMrcPaths (line 377) | class TestSortedVolMrcPaths:
method test_sorts_by_index_and_caps_count (line 378) | def test_sorts_by_index_and_caps_count(self, tmp_path) -> None:
method test_insufficient_volumes_raises (line 388) | def test_insufficient_volumes_raises(self, tmp_path) -> None:
class TestExplorerVolumesEligible (line 394) | class TestExplorerVolumesEligible:
method test_false_without_weights (line 395) | def test_false_without_weights(
method test_true_with_weights_and_cuda (line 414) | def test_true_with_weights_and_cuda(
method test_false_without_cuda (line 437) | def test_false_without_cuda(
class TestMplRetrimPng (line 446) | class TestMplRetrimPng:
method test_rewrites_png_in_place (line 447) | def test_rewrites_png_in_place(self, tmp_path) -> None:
class TestListZEpochs (line 466) | class TestListZEpochs:
method test_missing_workdir_returns_empty (line 467) | def test_missing_workdir_returns_empty(self, tmp_path) -> None:
method test_requires_matching_analyze_dir (line 470) | def test_requires_matching_analyze_dir(self, tmp_path) -> None:
method test_multi_epoch_sorted (line 477) | def test_multi_epoch_sorted(self, tmp_path) -> None:
class TestDashboardExperimentExtras (line 484) | class TestDashboardExperimentExtras:
method test_can_preview_particles_is_true (line 485) | def test_can_preview_particles_is_true(
method test_numeric_columns_exclude_index (line 490) | def test_numeric_columns_exclude_index(
class TestCommandBuilderSchemaIntegrity (line 503) | class TestCommandBuilderSchemaIntegrity:
method test_schema_covers_all_four_commands (line 506) | def test_schema_covers_all_four_commands(self) -> None:
method test_arg_ids_are_unique (line 515) | def test_arg_ids_are_unique(self, cmd: str) -> None:
method test_every_cli_flag_has_help_entry (line 520) | def test_every_cli_flag_has_help_entry(self, cmd: str) -> None:
class TestRequiredFieldTitles (line 535) | class TestRequiredFieldTitles:
method test_contains_expected_keys (line 536) | def test_contains_expected_keys(self) -> None:
method test_all_values_nonempty (line 555) | def test_all_values_nonempty(self) -> None:
class TestHelpMapFromCommandPy (line 560) | class TestHelpMapFromCommandPy:
method test_train_vae_has_common_flags (line 561) | def test_train_vae_has_common_flags(self) -> None:
class TestFlaskErrorPaths (line 576) | class TestFlaskErrorPaths:
method test_scatter_bad_axis (line 577) | def test_scatter_bad_axis(self, flask_client) -> None:
method test_scatter_bad_color (line 582) | def test_scatter_bad_color(self, flask_client) -> None:
method test_scatter3d_bad_color (line 586) | def test_scatter3d_bad_color(self, flask_client) -> None:
method test_latent3d_non_numeric_elev (line 590) | def test_latent3d_non_numeric_elev(self, flask_client) -> None:
method test_preview_montage_non_integer_rows (line 594) | def test_preview_montage_non_integer_rows(self, flask_client) -> None:
method test_pairplot_missing_color_col (line 598) | def test_pairplot_missing_color_col(self, flask_client) -> None:
method test_pairplot_bogus_diagonal (line 602) | def test_pairplot_bogus_diagonal(self, flask_client) -> None:
method test_pairplot_bogus_upper (line 613) | def test_pairplot_bogus_upper(self, flask_client) -> None:
method test_pairplot_z_as_color_rejected (line 624) | def test_pairplot_z_as_color_rejected(self, flask_client) -> None:
method test_save_selection_empty_rows (line 631) | def test_save_selection_empty_rows(self, flask_client) -> None:
method test_save_selection_out_of_range (line 635) | def test_save_selection_out_of_range(self, flask_client) -> None:
class TestTrajectoryImportAnchors (line 640) | class TestTrajectoryImportAnchors:
method test_import_happy_path (line 641) | def test_import_happy_path(
method test_rejects_non_txt (line 665) | def test_rejects_non_txt(self, flask_client, tmp_path) -> None:
method test_missing_file (line 674) | def test_missing_file(self, flask_client, tmp_path) -> None:
method test_missing_path_field (line 686) | def test_missing_path_field(self, flask_client) -> None:
class TestListServerFiles (line 691) | class TestListServerFiles:
method test_lists_workdir_contents (line 692) | def test_lists_workdir_contents(self, flask_client) -> None:
method test_bad_dir_returns_400 (line 699) | def test_bad_dir_returns_400(self, flask_client, tmp_path) -> None:
class TestSaveZPath (line 706) | class TestSaveZPath:
method test_roundtrip (line 707) | def test_roundtrip(
method test_non_string_txt_rejected (line 723) | def test_non_string_txt_rejected(self, flask_client) -> None:
class TestSaveSelectionRoundTrip (line 731) | class TestSaveSelectionRoundTrip:
method test_pkl_roundtrip (line 732) | def test_pkl_roundtrip(self, flask_client, tmp_path) -> None:
class TestSavePairPlotPng (line 750) | class TestSavePairPlotPng:
method test_writes_png (line 751) | def test_writes_png(self, flask_client, tmp_path) -> None:
class TestApiPreloadImages (line 769) | class TestApiPreloadImages:
method test_get_small_selection (line 770) | def test_get_small_selection(self, flask_client) -> None:
method test_post_body (line 778) | def test_post_body(self, flask_client) -> None:
method test_bad_selected_rows_rejected (line 787) | def test_bad_selected_rows_rejected(self, flask_client) -> None:
class TestSetEpochEndpoint (line 795) | class TestSetEpochEndpoint:
method test_post_same_epoch_succeeds (line 796) | def test_post_same_epoch_succeeds(self, flask_client) -> None:
method test_invalid_epoch_is_400 (line 801) | def test_invalid_epoch_is_400(self, flask_client) -> None:
method test_non_integer_epoch_is_400 (line 805) | def test_non_integer_epoch_is_400(self, flask_client) -> None:
method test_missing_epoch_is_400 (line 809) | def test_missing_epoch_is_400(self, flask_client) -> None:
class TestSetWorkdirEndpoint (line 814) | class TestSetWorkdirEndpoint:
method test_invalid_workdir_rejected_by_default_app (line 815) | def test_invalid_workdir_rejected_by_default_app(self, flask_client) -...
method test_clear_workdir_rejected_in_bound_mode (line 821) | def test_clear_workdir_rejected_in_bound_mode(self, flask_client) -> N...
method test_valid_workdir_switch_in_builder_only_mode (line 826) | def test_valid_workdir_switch_in_builder_only_mode(
method test_clear_in_builder_only_mode_succeeds (line 839) | def test_clear_in_builder_only_mode_succeeds(self, dashboard_workdir: ...
class TestRoutesTableIntegrity (line 847) | class TestRoutesTableIntegrity:
method test_every_entry_is_callable (line 848) | def test_every_entry_is_callable(self) -> None:
method test_create_app_registers_every_route (line 853) | def test_create_app_registers_every_route(self, dashboard_workdir: str...
class TestIndexTemplateNavLinks (line 860) | class TestIndexTemplateNavLinks:
method test_index_has_expected_nav_links (line 861) | def test_index_has_expected_nav_links(
class TestCommandBuilderOnlyMode (line 882) | class TestCommandBuilderOnlyMode:
method test_builder_only_index_renders (line 883) | def test_builder_only_index_renders(self, dashboard_workdir: str) -> N...
method test_builder_only_explorer_redirects_home (line 891) | def test_builder_only_explorer_redirects_home(self) -> None:
class TestDashboardCLI (line 904) | class TestDashboardCLI:
method _parse (line 906) | def _parse(argv: list[str]) -> argparse.Namespace:
method test_parses_minimal_args (line 911) | def test_parses_minimal_args(self) -> None:
method test_parses_with_outdir (line 922) | def test_parses_with_outdir(self) -> None:
method test_view_flag_aliases (line 928) | def test_view_flag_aliases(self) -> None:
method test_verbose_count_levels (line 932) | def test_verbose_count_levels(self) -> None:
method test_view_flags_are_mutually_exclusive (line 937) | def test_view_flags_are_mutually_exclusive(self) -> None:
method test_filter_max_points_sets_env (line 943) | def test_filter_max_points_sets_env(self, monkeypatch: pytest.MonkeyPa...
method test_builder_only_with_experiment_view_raises (line 952) | def test_builder_only_with_experiment_view_raises(
method test_builder_only_with_command_builder_ok (line 960) | def test_builder_only_with_command_builder_ok(
method test_main_invokes_run_server_with_outdir (line 974) | def test_main_invokes_run_server_with_outdir(
method test_main_configures_logging_from_verbose (line 990) | def test_main_configures_logging_from_verbose(
method test_default_logging_suppresses_werkzeug_internal (line 1004) | def test_default_logging_suppresses_werkzeug_internal(self) -> None:
FILE: tests/test_dataset.py
function test_particles (line 11) | def test_particles(particles):
class TestImageDatasetLoading (line 40) | class TestImageDatasetLoading:
method test_loading_slow (line 45) | def test_loading_slow(self, particles, indices, batch_size):
method test_loading_fast (line 70) | def test_loading_fast(self, particles, indices, batch_size):
class TestTiltSeriesLoading (line 116) | class TestTiltSeriesLoading:
method test_loading_slow (line 118) | def test_loading_slow(self, particles, indices, ntilts, batch_size):
function test_data_shuffler (line 187) | def test_data_shuffler(particles, indices, batch_size, buffer_size):
FILE: tests/test_direct_traversal.py
function test_fidelity_small (line 7) | def test_fidelity_small(tmpdir):
function test_fidelity_big (line 31) | def test_fidelity_big(tmpdir):
FILE: tests/test_downsample.py
class TestDownsampleToMRCS (line 40) | class TestDownsampleToMRCS:
method test_downsample (line 54) | def test_downsample(self, tmpdir, particles, datadir, indices, downsam...
method test_downsample_with_chunks (line 82) | def test_downsample_with_chunks(
function test_downsample_starout (line 145) | def test_downsample_starout(
function test_downsample_txtout (line 187) | def test_downsample_txtout(tmpdir, particles, outdir, downsample_dim):
function test_difficult_directory (line 221) | def test_difficult_directory(tmpdir, particles, datadir, newdatadir):
FILE: tests/test_entropy.py
function test_so3_entropy (line 6) | def test_so3_entropy():
FILE: tests/test_eval_images.py
function test_invert_contrast (line 12) | def test_invert_contrast(tmpdir, particles, poses, weights, configs):
FILE: tests/test_fft.py
function test_fft2 (line 11) | def test_fft2():
function test_fftn (line 15) | def test_fftn():
function test_ifftn (line 19) | def test_ifftn():
function test_fftshift (line 25) | def test_fftshift():
function test_ifftshift (line 32) | def test_ifftshift():
FILE: tests/test_filter_mrcs.py
function test_filter_mrcs (line 14) | def test_filter_mrcs(tmpdir, particles, ind_size, random_seed):
FILE: tests/test_filter_pkl.py
function test_select_clusters (line 13) | def test_select_clusters(tmpdir, total_size, ind_size, selected_size, ra...
function test_filter_ctf_pkl (line 44) | def test_filter_ctf_pkl(tmpdir, ctf, ind1, ind2):
function test_filter_pose_pkl (line 82) | def test_filter_pose_pkl(tmpdir, poses, ind1, ind2):
FILE: tests/test_flip_hand.py
function test_output (line 11) | def test_output(tmpdir, volume):
function test_mrc_file (line 40) | def test_mrc_file(tmpdir, volume):
function test_image_source (line 54) | def test_image_source(tmpdir, volume):
FILE: tests/test_fsc.py
function test_fidelity (line 14) | def test_fidelity(trained_dir) -> None:
function test_output_file (line 55) | def test_output_file(trained_dir, epochs: tuple[int, int]) -> None:
function test_apply_mask (line 85) | def test_apply_mask(trained_dir, epochs: tuple[int, int]) -> None:
function test_apply_phase_randomization (line 114) | def test_apply_phase_randomization(trained_dir) -> None:
function test_use_cryosparc_correction (line 135) | def test_use_cryosparc_correction(trained_dir) -> None:
function test_plotting (line 177) | def test_plotting(trained_dir, epochs: tuple[int, int]) -> None:
FILE: tests/test_graph_traversal.py
function test_fidelity_small (line 6) | def test_fidelity_small(tmpdir):
function test_fidelity_medium (line 41) | def test_fidelity_medium(tmpdir):
function test_no_path (line 73) | def test_no_path(tmpdir, ind1, ind2):
function test_fidelity_large (line 89) | def test_fidelity_large(tmpdir):
FILE: tests/test_integration.py
class TestIterativeFiltering (line 40) | class TestIterativeFiltering:
method get_outdir (line 41) | def get_outdir(self, tmpdir_factory, particles, indices, poses, ctf):
method test_train_model (line 50) | def test_train_model(self, tmpdir_factory, particles, poses, ctf, indi...
method test_analyze (line 82) | def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices):
method test_notebooks (line 109) | def test_notebooks(self, tmpdir_factory, particles, poses, ctf, indice...
method test_refiltering (line 122) | def test_refiltering(self, tmpdir_factory, particles, poses, ctf, indi...
class TestParseWriteStar (line 168) | class TestParseWriteStar:
method get_outdir (line 169) | def get_outdir(self, tmpdir_factory, particles, datadir):
method test_parse_ctf_star (line 176) | def test_parse_ctf_star(self, tmpdir_factory, particles, datadir):
method test_write_star_from_mrcs (line 203) | def test_write_star_from_mrcs(
method test_parse_pose (line 234) | def test_parse_pose(self, tmpdir_factory, particles, datadir, indices,...
method test_downsample_and_from_txt (line 263) | def test_downsample_and_from_txt(
method test_backproject_from_downsample_txt (line 339) | def test_backproject_from_downsample_txt(
method test_backproject_from_downsample_star (line 397) | def test_backproject_from_downsample_star(
class TestBackprojectFromChunkedDownsampled (line 442) | class TestBackprojectFromChunkedDownsampled:
method get_outpaths (line 443) | def get_outpaths(
method test_downsample_with_chunks (line 460) | def test_downsample_with_chunks(
method test_backprojection_from_chunks (line 490) | def test_backprojection_from_chunks(
class TestBackprojectTilts (line 529) | class TestBackprojectTilts:
method get_outpaths (line 530) | def get_outpaths(self, tmpdir_factory, particles, poses, ctf, datadir):
method test_backprojection_from_newind (line 543) | def test_backprojection_from_newind(
FILE: tests/test_invert_contrast.py
function test_output (line 12) | def test_output(tmpdir, volume):
function test_mrc_file (line 41) | def test_mrc_file(tmpdir, volume):
function test_image_source (line 57) | def test_image_source(tmpdir, volume):
FILE: tests/test_masks.py
function hash_file (line 10) | def hash_file(filename: str) -> str:
function test_mask_fidelity (line 25) | def test_mask_fidelity(tmpdir, volume, dist, dilate, apix) -> None:
function test_png_output_file (line 83) | def test_png_output_file(tmpdir, volume, dist_val) -> None:
FILE: tests/test_mrc.py
function mrcs_data (line 9) | def mrcs_data():
function test_lazy_loading (line 15) | def test_lazy_loading(mrcs_data):
function test_star (line 29) | def test_star(mrcs_data):
function test_txt (line 42) | def test_txt(mrcs_data):
FILE: tests/test_parse.py
function particles_starfile (line 18) | def particles_starfile():
class TestCtfStar (line 22) | class TestCtfStar:
method get_outdir (line 23) | def get_outdir(self, tmpdir_factory, resolution):
method test_parse (line 31) | def test_parse(self, tmpdir_factory, particles_starfile, resolution):
method test_fidelity (line 55) | def test_fidelity(self, tmpdir_factory, particles_starfile, resolution):
method test_write_star_from_mrcs (line 61) | def test_write_star_from_mrcs(self, tmpdir_factory, particles_starfile...
function test_parse_ctf_cs (line 81) | def test_parse_ctf_cs(tmpdir, particles):
function test_parse_pose_star (line 92) | def test_parse_pose_star(tmpdir, particles_starfile):
function test_parse_star (line 103) | def test_parse_star(tmpdir, particles_starfile, resolution):
function test_parse_pose_cs (line 138) | def test_parse_pose_cs(tmpdir, particles):
FILE: tests/test_pc_traversal.py
function test_fidelity_small (line 6) | def test_fidelity_small():
function test_fidelity_big (line 20) | def test_fidelity_big():
FILE: tests/test_phase_flip.py
function test_phase_flip (line 7) | def test_phase_flip(tmpdir):
FILE: tests/test_read_filter_write.py
function input_cs_proj_dir (line 24) | def input_cs_proj_dir():
function test_read_mrcs (line 34) | def test_read_mrcs(particles, datadir):
function test_read_starfile (line 45) | def test_read_starfile(particles, datadir):
function test_concat_indices_pkls (line 54) | def test_concat_indices_pkls(tmpdir):
function test_filter (line 81) | def test_filter(tmpdir, particles, datadir, index_pair):
class TestFilterStar (line 120) | class TestFilterStar:
method test_filter_with_indices (line 121) | def test_filter_with_indices(self, tmpdir, particles, datadir):
method test_filter_with_separate_files (line 167) | def test_filter_with_separate_files(self, tmpdir, particles, indices, ...
class TestParseCTFWriteStar (line 206) | class TestParseCTFWriteStar:
method get_outdir (line 207) | def get_outdir(self, tmpdir_factory, particles, datadir):
method test_parse_ctf_star (line 236) | def test_parse_ctf_star(self, tmpdir_factory, particles, datadir, reso...
method test_write_star_from_mrcs (line 267) | def test_write_star_from_mrcs(self, tmpdir_factory, particles, datadir):
method test_write_filter_star (line 301) | def test_write_filter_star(self, tmpdir_factory, particles, datadir, u...
function test_filter_cs (line 333) | def test_filter_cs(tmpdir, particles):
FILE: tests/test_reconstruct_abinit.py
class TestAbinitHomo (line 28) | class TestAbinitHomo:
method get_outdir (line 49) | def get_outdir(self, tmpdir_factory, particles, ctf, indices):
method test_train_model (line 56) | def test_train_model(self, tmpdir_factory, particles, ctf, indices):
method test_load_checkpoint (line 89) | def test_load_checkpoint(self, tmpdir_factory, particles, ctf, indices):
class TestAbinitHetero (line 130) | class TestAbinitHetero:
method get_outdir (line 151) | def get_outdir(self, tmpdir_factory, particles, ctf, indices):
method test_train_model (line 158) | def test_train_model(self, tmpdir_factory, particles, ctf, indices):
method test_analyze (line 196) | def test_analyze(
method test_load_checkpoint (line 226) | def test_load_checkpoint(self, tmpdir_factory, particles, ctf, indices):
method test_load_poses (line 256) | def test_load_poses(self, tmpdir_factory, particles, ctf, indices):
method test_notebooks (line 289) | def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_l...
method test_interactive_filtering (line 317) | def test_interactive_filtering(
method test_graph_traversal (line 354) | def test_graph_traversal(self, tmpdir_factory, particles, ctf, indices):
method test_analyze_landscape (line 377) | def test_analyze_landscape(self, tmpdir_factory, particles, ctf, indic...
method test_analyze_landscape_full (line 406) | def test_analyze_landscape_full(self, tmpdir_factory, particles, ctf, ...
FILE: tests/test_reconstruct_abinit_old.py
class TestAbinitHetero (line 40) | class TestAbinitHetero:
method get_outdir (line 65) | def get_outdir(self, tmpdir_factory, particles, ctf, indices):
method test_train_model (line 72) | def test_train_model(self, tmpdir_factory, particles, ctf, indices):
method test_analyze (line 106) | def test_analyze(
method test_notebooks (line 139) | def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_l...
method test_interactive_filtering (line 167) | def test_interactive_filtering(
method test_graph_traversal (line 205) | def test_graph_traversal(self, tmpdir_factory, particles, ctf, indices...
method test_analyze_landscape (line 231) | def test_analyze_landscape(self, tmpdir_factory, particles, ctf, indic...
method test_analyze_landscape_full (line 251) | def test_analyze_landscape_full(self, tmpdir_factory, particles, ctf, ...
method test_eval_volume (line 258) | def test_eval_volume(self, tmpdir_factory, particles, ctf, indices):
function test_abinit_checkpoint_analysis_and_backproject (line 281) | def test_abinit_checkpoint_analysis_and_backproject(abinit_dir):
FILE: tests/test_reconstruct_fixed.py
class TestFixedHetero (line 40) | class TestFixedHetero:
method get_outdir (line 41) | def get_outdir(self, tmpdir_factory, train_cmd, particles, poses, ctf,...
method test_train_model (line 53) | def test_train_model(
method test_train_from_checkpoint (line 109) | def test_train_from_checkpoint(
method test_analyze (line 171) | def test_analyze(
method test_notebooks (line 202) | def test_notebooks(
method test_interactive_filtering (line 231) | def test_interactive_filtering(
method test_landscape (line 303) | def test_landscape(
method test_landscape_full (line 378) | def test_landscape_full(
method test_landscape_notebook (line 403) | def test_landscape_notebook(
method test_direct_traversal (line 443) | def test_direct_traversal(
method test_graph_traversal (line 474) | def test_graph_traversal(
method test_eval_volume (line 523) | def test_eval_volume(
method test_eval_images (line 546) | def test_eval_images(
method test_plot_classes (line 599) | def test_plot_classes(
method test_clean_all (line 649) | def test_clean_all(self, tmpdir_factory, train_cmd, particles, poses, ...
function test_homogeneous_with_poses (line 666) | def test_homogeneous_with_poses(tmpdir, particles, poses, batch_size, us...
function test_frompose_train_and_from_checkpoint (line 696) | def test_frompose_train_and_from_checkpoint(trained_dir, load_epoch, tra...
class TestStarFixedHomo (line 705) | class TestStarFixedHomo:
method test_train_model (line 708) | def test_train_model(self, tmpdir, particles, indices, poses, ctf, dat...
FILE: tests/test_reconstruct_tilt.py
class TestTiltFixedHetero (line 34) | class TestTiltFixedHetero:
method get_outdir (line 41) | def get_outdir(
method test_train_model (line 58) | def test_train_model(
method test_filter_command (line 98) | def test_filter_command(
method test_analyze (line 162) | def test_analyze(
method test_backproject (line 191) | def test_backproject(
method test_notebooks (line 231) | def test_notebooks(
method test_interactive_filtering (line 249) | def test_interactive_filtering(
method test_refiltering (line 282) | def test_refiltering(
class TestTiltAbinitHomo (line 330) | class TestTiltAbinitHomo:
method test_train_model (line 331) | def test_train_model(self, tmpdir, particles, indices, ctf, datadir):
class TestTiltAbinitHetero (line 370) | class TestTiltAbinitHetero:
method get_outdir (line 371) | def get_outdir(self, tmpdir_factory, particles, ctf, indices, datadir):
method test_train_model (line 384) | def test_train_model(self, tmpdir_factory, particles, indices, ctf, da...
method test_analyze (line 420) | def test_analyze(self, tmpdir_factory, particles, indices, ctf, datadi...
method test_notebooks (line 441) | def test_notebooks(self, tmpdir_factory, particles, indices, ctf, data...
method test_interactive_filtering (line 456) | def test_interactive_filtering(
method test_refiltering (line 488) | def test_refiltering(
FILE: tests/test_relion.py
function rln_starfile (line 15) | def rln_starfile(request):
class TestFilterStar (line 29) | class TestFilterStar:
method get_outdir (line 30) | def get_outdir(self, tmpdir_factory, rln_starfile, index_seed, index_f...
method test_command (line 39) | def test_command(self, tmpdir_factory, rln_starfile, index_seed, index...
method test_relion30_consistency (line 77) | def test_relion30_consistency(
class TestParsePoseStar (line 118) | class TestParsePoseStar:
method get_outdir (line 119) | def get_outdir(self, tmpdir_factory, rln_starfile, apix, resolution):
method test_command (line 128) | def test_command(self, tmpdir_factory, rln_starfile, apix, resolution):
method test_relion30_consistency (line 170) | def test_relion30_consistency(self, tmpdir_factory, rln_starfile, apix...
class TestParseCTFStar (line 215) | class TestParseCTFStar:
method get_outdir (line 216) | def get_outdir(self, tmpdir_factory, rln_starfile, apix, resolution, k...
method test_command (line 232) | def test_command(
method test_relion30_consistency (line 287) | def test_relion30_consistency(
function test_relion50 (line 328) | def test_relion50(tmpdir, rln_starfile):
FILE: tests/test_select_clusters.py
function test_select_clusters (line 12) | def test_select_clusters(tmpdir, cluster_count, chosen_count):
function test_select_clusters_parent_ind (line 37) | def test_select_clusters_parent_ind(tmpdir, cluster_count, chosen_count):
FILE: tests/test_select_random.py
function test_select_random_n (line 11) | def test_select_random_n(tmpdir, total_count, chosen_count):
function test_select_random_frac (line 31) | def test_select_random_frac(tmpdir, total_count, chosen_frac):
FILE: tests/test_source.py
function mrcs_data (line 9) | def mrcs_data():
function test_loading_mrcs (line 15) | def test_loading_mrcs(mrcs_data):
function test_loading_starfile (line 24) | def test_loading_starfile(mrcs_data):
function test_loading_txtfile (line 33) | def test_loading_txtfile(mrcs_data):
function test_loading_csfile (line 42) | def test_loading_csfile(mrcs_data):
function test_write_mrc (line 63) | def test_write_mrc(tmpdir, particles, indices, chunksize, transform_fn):
function test_source_iteration (line 76) | def test_source_iteration(particles, chunksize):
function test_prespecified_indices (line 93) | def test_prespecified_indices(mrcs_data):
function test_prespecified_indices_eager (line 108) | def test_prespecified_indices_eager(mrcs_data):
function test_txt_prespecified_indices (line 123) | def test_txt_prespecified_indices(mrcs_data):
function test_txt_prespecified_indices_contiguous (line 141) | def test_txt_prespecified_indices_contiguous(mrcs_data):
function test_txt_prespecified_indices_contiguous_eager (line 159) | def test_txt_prespecified_indices_contiguous_eager(mrcs_data):
function test_txt_prespecified_indices_eager (line 178) | def test_txt_prespecified_indices_eager(mrcs_data):
function test_prespecified_indices_chunked (line 197) | def test_prespecified_indices_chunked(mrcs_data):
function test_prespecified_indices_eager_chunked (line 212) | def test_prespecified_indices_eager_chunked(mrcs_data):
FILE: tests/test_translate.py
function test_shifted_image (line 13) | def test_shifted_image():
class TestTranslateStack (line 34) | class TestTranslateStack:
method get_outdir (line 37) | def get_outdir(self, tmpdir_factory, particles, trans):
method test_default_translate (line 47) | def test_default_translate(self, tmpdir_factory, particles, trans):
method test_filetype_consistency (line 57) | def test_filetype_consistency(self, tmpdir_factory, trans):
method test_tscales (line 77) | def test_tscales(self, tmpdir_factory, particles, trans, tscale):
method test_tscale_consistency (line 90) | def test_tscale_consistency(self, tmpdir_factory, particles, trans):
method test_png_output (line 103) | def test_png_output(self, tmpdir_factory, particles, trans):
FILE: tests/test_utils.py
function test_convert_from_relion_scipy (line 7) | def test_convert_from_relion_scipy():
function test_convert_from_relion (line 14) | def test_convert_from_relion():
function test_convert_to_relion (line 29) | def test_convert_to_relion():
FILE: tests/test_view_cs_header.py
function test_view_cs_header (line 7) | def test_view_cs_header():
FILE: tests/test_view_header.py
function test_view_header (line 7) | def test_view_header(particles):
FILE: tests/test_view_mrcs.py
function test_invert_contrast (line 9) | def test_invert_contrast(mock_pyplot_show):
FILE: tests/test_writestar.py
function particles_starfile (line 11) | def particles_starfile():
function relion31_mrcs (line 16) | def relion31_mrcs():
class TestBasic (line 49) | class TestBasic:
method test_from_mrcs (line 59) | def test_from_mrcs(
method test_from_txt (line 118) | def test_from_txt(
function test_from_txt_with_two_files (line 194) | def test_from_txt_with_two_files(
function test_relion31 (line 254) | def test_relion31(tmpdir, relion31_mrcs, ctf, indices):
Copy disabled (too large)
Download .json
Condensed preview — 234 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (12,578K chars).
[
{
"path": ".flake8",
"chars": 76,
"preview": "[flake8]\nextend-ignore = E203,E501\nmax-complexity = 99\nmax-line-length = 88\n"
},
{
"path": ".github/CODEOWNERS",
"chars": 21,
"preview": "* \t@michal-g @zhonge\n"
},
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 382,
"preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the b"
},
{
"path": ".github/workflows/beta_release.yml",
"chars": 867,
"preview": "name: Beta Release\n\non:\n push:\n tags:\n - '[0-9]+\\.[0-9]+\\.[0-9]+-*'\n\njobs:\n beta-release:\n\n runs-on: ubuntu"
},
{
"path": ".github/workflows/release.yml",
"chars": 854,
"preview": "name: Release\n\non:\n push:\n tags:\n - '[0-9]+.[0-9]+.[0-9]+'\n - '!*-[a-z]+[0-9]+'\n\njobs:\n release:\n\n run"
},
{
"path": ".github/workflows/style.yml",
"chars": 736,
"preview": "name: Code Linting\n\non:\n push:\n branches: [ main, develop ]\n tags:\n - '[0-9]+\\.[0-9]+\\.[0-9]+'\n - '[0-9"
},
{
"path": ".github/workflows/tests.yml",
"chars": 1211,
"preview": "name: CI Testing\n\non:\n push:\n branches: [ develop ]\n tags:\n - '[0-9]+\\.[0-9]+\\.[0-9]+'\n - '[0-9]+\\.[0-9"
},
{
"path": ".gitignore",
"chars": 1211,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".pre-commit-config.yaml",
"chars": 609,
"preview": "# See https://pre-commit.com for more information\n# See https://pre-commit.com/hooks.html for more hooks\n\nexclude: '.cs$"
},
{
"path": "LICENSE.txt",
"chars": 35149,
"preview": " GNU GENERAL PUBLIC LICENSE\n Version 3, 29 June 2007\n\n Copyright (C) 2007 Free "
},
{
"path": "MANIFEST.in",
"chars": 34,
"preview": "include cryodrgn/templates/*ipynb\n"
},
{
"path": "README.md",
"chars": 30446,
"preview": ":\n if isinstance(schedule, float):\n return ConstantSchedul"
},
{
"path": "cryodrgn/command_line.py",
"chars": 5174,
"preview": "\"\"\"Creating the commands installed with cryoDRGN using the package's modules.\n\nHere we add modules under the `cryodrgn.c"
},
{
"path": "cryodrgn/commands/README.md",
"chars": 463,
"preview": "# cryoDRGN commands #\n\nThis folder contains the primary commands that are installed as part of the cryoDRGN package, as "
},
{
"path": "cryodrgn/commands/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cryodrgn/commands/abinit.py",
"chars": 63485,
"preview": "\"\"\"Reconstructing volume(s) from picked cryoEM/ET particles using cryoDRGN-AI.\n\nExample usage\n-------------\n# Run with f"
},
{
"path": "cryodrgn/commands/abinit_het_old.py",
"chars": 36889,
"preview": "\"\"\"Train a heterogeneous NN reconstruction model with hierarchical pose optimization.\n\nExample usage\n-------------\n# The"
},
{
"path": "cryodrgn/commands/abinit_homo_old.py",
"chars": 25111,
"preview": "\"\"\"Homogeneous neural net ab initio reconstruction with hierarchical pose optimization.\n\nExample usage\n-------------\n$ c"
},
{
"path": "cryodrgn/commands/analyze.py",
"chars": 19703,
"preview": "\"\"\"Visualize latent space and generate volumes from a trained cryoDRGN model.\n\nExample usage\n-------------\n$ cryodrgn an"
},
{
"path": "cryodrgn/commands/analyze_landscape.py",
"chars": 22495,
"preview": "\"\"\"Describe the latent space produced by a cryoDRGN model by directly comparing volumes.\n\nExample usage\n-------------\n# "
},
{
"path": "cryodrgn/commands/analyze_landscape_full.py",
"chars": 19987,
"preview": "\"\"\"Transform a cryoDRGN latent space to better capture differences between volumes.\n\nExample usage\n-------------\n$ cryod"
},
{
"path": "cryodrgn/commands/backproject_voxel.py",
"chars": 17485,
"preview": "\"\"\"Voxel-based backprojection to reconstruct a volume as well as half-maps.\n\nThis command performs volume reconstruction"
},
{
"path": "cryodrgn/commands/dashboard.py",
"chars": 7023,
"preview": "\"\"\"Launch a local web dashboard for cryoDRGN interactive analyses.\n\nThe dashboard opens in your browser with a particle "
},
{
"path": "cryodrgn/commands/direct_traversal.py",
"chars": 3370,
"preview": "\"\"\"Construct a path in z-latent-space interpolating directly between anchor points.\n\nExample usage\n-------------\n$ cryod"
},
{
"path": "cryodrgn/commands/downsample.py",
"chars": 12455,
"preview": "\"\"\"Downsample an image stack or volume to a lower resolution by clipping Fourier freqs.\n\nExample usage\n-------------\n# D"
},
{
"path": "cryodrgn/commands/eval_images.py",
"chars": 11253,
"preview": "\"\"\"Evaluate cryoDRGN model latent variables and loss for a stack of images.\n\nExample usage\n-------------\n\n$ cryodrgn eva"
},
{
"path": "cryodrgn/commands/eval_vol.py",
"chars": 7958,
"preview": "\"\"\"Evaluate the decoder of a heterogeneous model at given z-latent-space co-ordinates.\n\nExample usage\n-------------\n# Th"
},
{
"path": "cryodrgn/commands/filter.py",
"chars": 21220,
"preview": "\"\"\"Interactive filtering of particles plotted using various model variables.\n\nThis command opens an interactive interfac"
},
{
"path": "cryodrgn/commands/graph_traversal.py",
"chars": 9326,
"preview": "\"\"\"Construct the shortest path along a nearest neighbor graph in the latent z-space.\n\nExample usage\n-------------\n# Find"
},
{
"path": "cryodrgn/commands/parse_ctf_csparc.py",
"chars": 2532,
"preview": "\"\"\"Parse CTF parameters from a cryoSPARC particles.cs file\"\"\"\n\nimport argparse\nimport os\nimport pickle\nimport logging\nim"
},
{
"path": "cryodrgn/commands/parse_ctf_star.py",
"chars": 4051,
"preview": "\"\"\"Parse contrast transfer function values from a RELION .star file into separate file.\n\nThis command is often used as a"
},
{
"path": "cryodrgn/commands/parse_pose_csparc.py",
"chars": 2186,
"preview": "\"\"\"Parse image poses from a cryoSPARC .cs metafile\"\"\"\n\nimport argparse\nimport os\nimport pickle\nimport logging\nimport num"
},
{
"path": "cryodrgn/commands/parse_pose_star.py",
"chars": 3739,
"preview": "\"\"\"Parse image poses from RELION .star file into a separate file for use in cryoDRGN.\n\nThis command is often used as a p"
},
{
"path": "cryodrgn/commands/parse_star.py",
"chars": 2290,
"preview": "\"\"\"Parse image CTF and poses from RELION .star file into separate files for cryoDRGN.\n\nThis command is often used as a p"
},
{
"path": "cryodrgn/commands/pc_traversal.py",
"chars": 2811,
"preview": "\"\"\"Construct a path of embeddings in latent space along principal components.\n\nExample usage\n-------------\n$ cryodrgn pc"
},
{
"path": "cryodrgn/commands/train_dec.py",
"chars": 20379,
"preview": "\"\"\"Train an autodecoder\"\"\"\nimport argparse\nimport os\nimport sys\nfrom datetime import datetime as dt\nimport logging\nimpor"
},
{
"path": "cryodrgn/commands/train_nn.py",
"chars": 17632,
"preview": "\"\"\"Train a neural net to model a 3D density map given 2D images with pose assignments.\n\nExample usage\n-------------\n$ cr"
},
{
"path": "cryodrgn/commands/train_vae.py",
"chars": 32898,
"preview": "\"\"\"Train a VAE for heterogeneous reconstruction with known poses.\n\nExample usage\n-------------\n$ cryodrgn train_vae proj"
},
{
"path": "cryodrgn/commands_utils/README.md",
"chars": 472,
"preview": "# cryoDRGN utility commands #\n\nThis folder contains the supporting commands that are installed as part of the cryoDRGN p"
},
{
"path": "cryodrgn/commands_utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cryodrgn/commands_utils/add_psize.py",
"chars": 965,
"preview": "\"\"\"Add pixel size to the header of .mrc file containing a volume.\n\nExample usage\n-------------\n# Overwrite given file wi"
},
{
"path": "cryodrgn/commands_utils/analyze_convergence.py",
"chars": 45203,
"preview": "\"\"\"\nVisualize convergence and training dynamics\n(BETA -- contributed by Barrett Powell bmp@mit.edu)\n\nExample usage\n-----"
},
{
"path": "cryodrgn/commands_utils/clean.py",
"chars": 6692,
"preview": "\"\"\"Remove extraneous files from experiment output directories\n\nThis utility removes output files from cryoDRGN output di"
},
{
"path": "cryodrgn/commands_utils/concat_pkls.py",
"chars": 1115,
"preview": "\"\"\"Concatenate arrays from multiple .pkl files\"\"\"\n\nimport argparse\nimport logging\nimport numpy as np\nfrom cryodrgn.utils"
},
{
"path": "cryodrgn/commands_utils/filter_cs.py",
"chars": 1422,
"preview": "\"\"\"Create a CryoSparc .cs file from a particle stack, using poses and CTF if necessary.\n\nExample usage\n-------------\n$ c"
},
{
"path": "cryodrgn/commands_utils/filter_mrcs.py",
"chars": 1447,
"preview": "\"\"\"Filter a particle stack using given indices to produce a new stack file.\n\nExample usage\n-------------\ncryodrgn_utils "
},
{
"path": "cryodrgn/commands_utils/filter_pkl.py",
"chars": 2200,
"preview": "\"\"\"Filter cryoDRGN data stored in a .pkl file, writing to a new .pkl file.\n\nExample usage\n-------------\n$ cryodrgn_utils"
},
{
"path": "cryodrgn/commands_utils/filter_star.py",
"chars": 3189,
"preview": "\"\"\"Filter a .star file using a saved set of particle indices.\n\nExample usage\n-------------\n$ cryodrgn_utils filter_star "
},
{
"path": "cryodrgn/commands_utils/flip_hand.py",
"chars": 1170,
"preview": "\"\"\"Flip handedness of an .mrc file\n\nExample usage\n-------------\n# Writes to vol_000_flipped.mrc\n$ cryodrgn_utils flip_ha"
},
{
"path": "cryodrgn/commands_utils/fsc.py",
"chars": 16263,
"preview": "\"\"\"Compute Fourier shell correlations between two volumes, applying an optional mask.\n\nWhen using `--ref-volume`, this r"
},
{
"path": "cryodrgn/commands_utils/gen_mask.py",
"chars": 2159,
"preview": "\"\"\"Creating masking filters for 3D volumes using threshold dilation with a cosine edge.\n\nExample usage\n-------------\n$ c"
},
{
"path": "cryodrgn/commands_utils/invert_contrast.py",
"chars": 1160,
"preview": "\"\"\"Invert the contrast of an .mrc file\n\nExample usage\n-------------\n# Writes to vol_000_inverted.mrc\n$ cryodrgn_utils in"
},
{
"path": "cryodrgn/commands_utils/make_movies.py",
"chars": 8340,
"preview": "\"\"\"Make MP4 movies of .mrc volumes produced by cryodrgn analyze* commands.\n\nYou must install ChimeraX under the alias `c"
},
{
"path": "cryodrgn/commands_utils/parse_relion.py",
"chars": 13089,
"preview": "\"\"\"Parse .star files generated by RELION v5 into 2D particle coordinates.\n\nExample usage\n-------------\ncryodrgn_utils pa"
},
{
"path": "cryodrgn/commands_utils/phase_flip.py",
"chars": 2303,
"preview": "\"\"\"Phase flip images by CTF sign\"\"\"\n\nimport argparse\nimport os\nimport logging\nimport numpy as np\nimport torch\nfrom cryod"
},
{
"path": "cryodrgn/commands_utils/plot_classes.py",
"chars": 10225,
"preview": "\"\"\"Create plots of cryoDRGN model results arranged by given particle class labels.\n\nClass labels are expected to be save"
},
{
"path": "cryodrgn/commands_utils/plot_fsc.py",
"chars": 5396,
"preview": "\"\"\"Create a plot of one or more sets of computed Fourier shell correlations.\n\nExample usage\n-------------\n# Plot two cur"
},
{
"path": "cryodrgn/commands_utils/select_clusters.py",
"chars": 1586,
"preview": "\"\"\"Select particle or volume data based on (kmeans) cluster labels\"\"\"\n\nimport argparse\nimport os\nimport logging\nfrom cry"
},
{
"path": "cryodrgn/commands_utils/select_random.py",
"chars": 1727,
"preview": "\"\"\"Create an index corresponding to the selection of a random subset of particles.\n\nExample usage\n-------------\n# Sample"
},
{
"path": "cryodrgn/commands_utils/translate_mrcs.py",
"chars": 2540,
"preview": "\"\"\"Translate a particle stack by applying 2D shifts.\n\nExample usage\n-------------\n$ cryodrgn_utils translate_mrcs projec"
},
{
"path": "cryodrgn/commands_utils/view_cs_header.py",
"chars": 521,
"preview": "\"\"\"View the first row of a cryosparc .cs file\"\"\"\n\nimport argparse\n\nimport numpy as np\n\n\ndef add_args(parser):\n parser"
},
{
"path": "cryodrgn/commands_utils/view_header.py",
"chars": 640,
"preview": "\"\"\"View the header metadata of a .mrc or .mrcs file\"\"\"\n\nimport argparse\nfrom pprint import pprint\nimport logging\nfrom cr"
},
{
"path": "cryodrgn/commands_utils/view_mrcs.py",
"chars": 1481,
"preview": "\"\"\"View images in a particle stack\"\"\"\n\nimport argparse\nimport os\nimport logging\nimport matplotlib.pyplot as plt\nimport o"
},
{
"path": "cryodrgn/commands_utils/write_cs.py",
"chars": 1466,
"preview": "\"\"\"Create a CryoSparc .cs file from a particle stack, using poses and CTF if necessary.\n\nExample usage\n-------------\n$ c"
},
{
"path": "cryodrgn/commands_utils/write_star.py",
"chars": 8125,
"preview": "\"\"\"Create a Relion .star file from a given particle stack and CTF parameters.\n\nExample usage\n-------------\n# If using a "
},
{
"path": "cryodrgn/config.py",
"chars": 1762,
"preview": "\"\"\"Tools for working with cryoDRGN configuration parameters saved to .yaml files.\"\"\"\n\nimport os.path\nimport sys\nfrom dat"
},
{
"path": "cryodrgn/ctf.py",
"chars": 5384,
"preview": "from typing import Optional\nimport numpy as np\nimport torch\nimport logging\nfrom cryodrgn import utils\n\nlogger = logging."
},
{
"path": "cryodrgn/dashboard/__init__.py",
"chars": 179,
"preview": "\"\"\"Web dashboard for cryoDRGN interactive analyses.\"\"\"\n\n# Agg must be selected before any submodule imports matplotlib.p"
},
{
"path": "cryodrgn/dashboard/app.py",
"chars": 40960,
"preview": "\"\"\"Flask app for the cryoDRGN analysis dashboard.\n\nResponsibilities are split across small sibling modules:\n\n* :mod:`cry"
},
{
"path": "cryodrgn/dashboard/bench_plot_interfaces.py",
"chars": 4630,
"preview": "#!/usr/bin/env python3\n\"\"\"Time initial plot payloads for each dashboard analysis view (same paths as the Flask app).\n\nUs"
},
{
"path": "cryodrgn/dashboard/command_builder_cli_help.py",
"chars": 3655,
"preview": "\"\"\"Extract argparse ``help=`` strings from command modules without importing them.\n\n``abinit`` and training commands imp"
},
{
"path": "cryodrgn/dashboard/command_builder_data.py",
"chars": 30889,
"preview": "\"\"\"Structured optional-arg groups for the dashboard command builder.\n\nMirrors ``add_argument_group`` titles and flags fr"
},
{
"path": "cryodrgn/dashboard/context.py",
"chars": 18141,
"preview": "\"\"\"Workdir / epoch resolution, caches, and Jinja template context injectors.\n\nThis module owns the long-lived per-app st"
},
{
"path": "cryodrgn/dashboard/data.py",
"chars": 6881,
"preview": "\"\"\"Load heterogeneous-reconstruction outputs for dashboard UIs (mirrors `filter` command).\"\"\"\n\nfrom __future__ import an"
},
{
"path": "cryodrgn/dashboard/explorer_volumes.py",
"chars": 15292,
"preview": "\"\"\"On-demand decoder volumes + ChimeraX static PNGs and rotating GIFs for the particle explorer.\"\"\"\n\nfrom __future__ imp"
},
{
"path": "cryodrgn/dashboard/mpl_style.py",
"chars": 608,
"preview": "\"\"\"Matplotlib rc settings aligned with https://ezlab.princeton.edu/ (Barlow / Roboto).\"\"\"\n\nfrom __future__ import annota"
},
{
"path": "cryodrgn/dashboard/plots.py",
"chars": 39196,
"preview": "\"\"\"Plotly figures for the dashboard; pair grid uses Matplotlib + Seaborn (PNG).\"\"\"\n\nfrom __future__ import annotations\n\n"
},
{
"path": "cryodrgn/dashboard/preload.py",
"chars": 8537,
"preview": "\"\"\"Particle-thumbnail sampling, encoding, and montage helpers.\n\nThese feed the explorer's preview montage, the hover pre"
},
{
"path": "cryodrgn/dashboard/templates/base.html",
"chars": 33608,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"utf-8\"/>\n <meta name=\"viewport\" content=\"width=device-width, i"
},
{
"path": "cryodrgn/dashboard/templates/command_builder.html",
"chars": 26472,
"preview": "{% extends \"base.html\" %}\n{% block title %}Command builder · cryoDRGN{% endblock %}\n{% block nav_page_title %}<span clas"
},
{
"path": "cryodrgn/dashboard/templates/index.html",
"chars": 10928,
"preview": "{% extends \"base.html\" %}\n{% block title %}cryoDRGN dashboard{% endblock %}\n{% block head %}\n<style>\n main:has(.landing"
},
{
"path": "cryodrgn/dashboard/templates/latent_3d.html",
"chars": 17448,
"preview": "{% extends \"base.html\" %}\n{% block title %}3-D Latent Space Visualizer · cryoDRGN{% endblock %}\n{% block nav_page_title "
},
{
"path": "cryodrgn/dashboard/templates/no_images.html",
"chars": 405,
"preview": "{% extends \"base.html\" %}\n{% block title %}Thumbnails unavailable · cryoDRGN{% endblock %}\n{% block nav_page_title %}<sp"
},
{
"path": "cryodrgn/dashboard/templates/pair_grid.html",
"chars": 25964,
"preview": "{% extends \"base.html\" %}\n{# 4×4 mini-grids: row i = top→bottom, col j = left→right; matches zᵢ (row) vs zⱼ (col) #}\n{% "
},
{
"path": "cryodrgn/dashboard/templates/pair_grid_need_more_cols.html",
"chars": 1046,
"preview": "{% extends \"base.html\" %}\n{% block title %}\n{% if kind == \"z3\" %}3-D Latent Space Visualizer · cryoDRGN{% else %}Pair gr"
},
{
"path": "cryodrgn/dashboard/templates/scatter_explorer.html",
"chars": 75240,
"preview": "{% extends \"base.html\" %}\n{% block title %}Particle explorer · cryoDRGN{% endblock %}\n{% block nav_page_title %}<span cl"
},
{
"path": "cryodrgn/dashboard/templates/trajectory_creator.html",
"chars": 79820,
"preview": "{% extends \"base.html\" %}\n{% block title %}Trajectory creator · cryoDRGN{% endblock %}\n{% block nav_page_title %}<span c"
},
{
"path": "cryodrgn/dashboard/trajectory.py",
"chars": 26100,
"preview": "\"\"\"Pure-logic helpers for the trajectory-creator view.\n\nNothing in this module imports Flask — the functions take a\n:cla"
},
{
"path": "cryodrgn/dataset.py",
"chars": 21359,
"preview": "\"\"\"Classes for using particle image datasets in PyTorch learning methods.\n\nThis module contains classes that implement v"
},
{
"path": "cryodrgn/fft.py",
"chars": 2891,
"preview": "\"\"\"Utility functions used in Fast Fourier transform calculations on image tensors.\"\"\"\n\nimport logging\nimport numpy as np"
},
{
"path": "cryodrgn/healpy_grid.json",
"chars": 10286969,
"preview": "{\"2\": [[1.2309594173407747, 0.8410686705679303, 0.8410686705679303, 0.4111378623223478, 1.2309594173407747, 0.8410686705"
},
{
"path": "cryodrgn/lattice.py",
"chars": 8048,
"preview": "\"\"\"Lattices used to represent spatial co-ordinates in reconstruction methods.\n\nExample usage\n-------------\n> from cryodr"
},
{
"path": "cryodrgn/lie_tools.py",
"chars": 9174,
"preview": "\"\"\"\nTools for dealing with SO(3) group and algebra\nAdapted from https://github.com/pimdh/lie-vae\nAll functions are pytor"
},
{
"path": "cryodrgn/losses.py",
"chars": 1917,
"preview": "\"\"\"Equivariance loss for Encoder\"\"\"\n\nfrom __future__ import annotations\nimport numpy as np\nimport torch\nimport torch.nn "
},
{
"path": "cryodrgn/make_healpy.py",
"chars": 319,
"preview": "import json\n\nimport healpy\nimport numpy as np\n\nx = {}\nfor r in range(7):\n Nside = 2 ** (r + 1)\n Npix = 12 * Nside "
},
{
"path": "cryodrgn/masking.py",
"chars": 7262,
"preview": "\"\"\"Filters applied to lattice coordinates as part of training.\"\"\"\n\nimport numpy as np\nimport torch\nfrom scipy.ndimage im"
},
{
"path": "cryodrgn/metrics.py",
"chars": 5486,
"preview": "\"\"\"\nMetrics\n\"\"\"\n\nimport numpy as np\nimport logging\nimport torch\nfrom cryodrgn import lie_tools\n\nlogger = logging.getLogg"
},
{
"path": "cryodrgn/models.py",
"chars": 43048,
"preview": "\"\"\"Pytorch models\"\"\"\n\nfrom typing import Optional, Tuple, Type, Sequence, Any\nimport numpy as np\nimport torch\nfrom torch"
},
{
"path": "cryodrgn/models_ai.py",
"chars": 38211,
"preview": "\"\"\"\nModels\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\n\nfrom "
},
{
"path": "cryodrgn/mrcfile.py",
"chars": 11014,
"preview": "\"\"\"Utilities for reading and writing .mrc/.mrcs files.\n\nExample usage\n-------------\n> from cryodrgn.mrcfile import parse"
},
{
"path": "cryodrgn/pose.py",
"chars": 6403,
"preview": "\"\"\"Keeping track of poses used under different embeddings in reconstruction models.\"\"\"\n\nimport pickle\nfrom typing import"
},
{
"path": "cryodrgn/pose_search.py",
"chars": 16488,
"preview": "import logging\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom typing import Optional, Union, Tuple"
},
{
"path": "cryodrgn/pose_search_ai.py",
"chars": 19814,
"preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom itertools import repeat\n\nfrom cryodrgn import shift"
},
{
"path": "cryodrgn/shift_grid.py",
"chars": 1528,
"preview": "import numpy as np\n\n\ndef grid_1d(resol: int, extent: int, ngrid: int, shift: int = 0) -> np.ndarray:\n Npix = ngrid * "
},
{
"path": "cryodrgn/shift_grid3.py",
"chars": 1543,
"preview": "import numpy as np\n\n\ndef grid_1d(resol, extent, ngrid):\n Npix = ngrid * 2**resol\n dt = 2 * extent / Npix\n grid "
},
{
"path": "cryodrgn/so3_grid.py",
"chars": 7385,
"preview": "\"\"\"\nImplementation of Yershova et al. \"Generating uniform incremental\ngrids on SO(3) using the Hopf fribration\"\n\"\"\"\n\nimp"
},
{
"path": "cryodrgn/source.py",
"chars": 25174,
"preview": "\"\"\"Classes for reading and using particle image data from various file formats.\n\nThis module contains the class hierarch"
},
{
"path": "cryodrgn/starfile.py",
"chars": 12982,
"preview": "\"\"\"Utilities for reading/loading data from .star files.\n\nExample usage\n-------------\n# Read in the tables of a .star fil"
},
{
"path": "cryodrgn/templates/cryoDRGN_ET_viz_template.ipynb",
"chars": 21211,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"scrolled\": false\n },\n \"source\": [\n \"# CryoDRG"
},
{
"path": "cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb",
"chars": 17088,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# CryoDRGN landscape analysis\\n\",\n "
},
{
"path": "cryodrgn/templates/cryoDRGN_figures_template.ipynb",
"chars": 14106,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"12d6bdad\",\n \"metadata\": {},\n \"source\": [\n \"# CryoDRGN vis"
},
{
"path": "cryodrgn/templates/cryoDRGN_filtering_template.ipynb",
"chars": 31006,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"scrolled\": false\n },\n \"source\": [\n \"# CryoDRG"
},
{
"path": "cryodrgn/templates/cryoDRGN_viz_template.ipynb",
"chars": 14463,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"scrolled\": false\n },\n \"source\": [\n \"# CryoDRG"
},
{
"path": "cryodrgn/utils.py",
"chars": 10434,
"preview": "\"\"\"Utility functions shared between various cryoDRGN operations and commands.\"\"\"\n\nfrom collections.abc import Hashable\ni"
},
{
"path": "pyproject.toml",
"chars": 1632,
"preview": "[build-system]\nrequires = [\"setuptools>=61.0\", \"setuptools_scm>=6.2\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]"
},
{
"path": "sweep.sh",
"chars": 4604,
"preview": "#!/bin/bash\n\nset -e\n\npython setup.py develop\n\nfunction run {\n O=/checkpoint/$USER/cryodrgn\n if [ -d $O/$N ]; then\n"
},
{
"path": "testing/diff_cryodrgn_pkl.py",
"chars": 368,
"preview": "import pickle\nimport sys\n\na = sys.argv[1]\nb = sys.argv[2]\n\nwith open(a, \"rb\") as f:\n a = pickle.load(f)\nwith open(b, "
},
{
"path": "testing/test_abinit.sh",
"chars": 663,
"preview": "#!/bin/bash\nset -e\n\n# https://stackoverflow.com/questions/59895\nB=$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev"
},
{
"path": "testing/test_entropy.py",
"chars": 843,
"preview": "import numpy as np\nimport torch\n\nimport cryodrgn.lie_tools\n\navg = []\nstd = torch.tensor([2.3407, 1.0999, 1.2962])\nfor _ "
},
{
"path": "testing/test_pose_search_rag12_128.py",
"chars": 4962,
"preview": "import argparse\nimport os\nimport time\n\nimport numpy as np\nimport torch\n\nfrom cryodrgn import dataset, lattice, models, p"
},
{
"path": "testing/test_pose_search_real_128.py",
"chars": 4122,
"preview": "import time\n\nimport numpy as np\nimport torch\n\nfrom cryodrgn import dataset, lattice, models, pose_search, utils\n\nuse_cud"
},
{
"path": "testing/test_pose_search_syn_64.py",
"chars": 3266,
"preview": "import time\n\nimport numpy as np\nimport torch\n\nfrom cryodrgn import dataset, lattice, models, pose_search, utils\n\nuse_cud"
},
{
"path": "testing/test_sta.sh",
"chars": 224,
"preview": "#!/bin/bash\n\nset -e\nset -x\n\ncryodrgn train_vae data/sta_testing.star --datadir data/tilts/128 --encode-mode tilt --poses"
},
{
"path": "testing/test_translate.py",
"chars": 793,
"preview": "import matplotlib.pyplot as plt\nimport torch\nimport torch.nn as nn\nimport cryodrgn.fft\nimport cryodrgn.models\nimport cry"
},
{
"path": "tests/conftest.py",
"chars": 17712,
"preview": "\"\"\"Fixtures used across many unit test modules.\"\"\"\n\nimport pytest\nimport os\nimport argparse\nimport shutil\nfrom typing im"
},
{
"path": "tests/data/FinalRefinement-OriginalParticles-PfCRT.star",
"chars": 2667,
"preview": "# RELION; version 3.0.6\n\ndata_images\n\nloop_ \n_rlnAnglePsi #1 \n_rlnAngleRot #2 \n_rlnAngleTilt #3 \n_rlnClassNumber #4 \n_rl"
},
{
"path": "tests/data/empiar_10076_7.star",
"chars": 680,
"preview": "data_\n\nloop_\n_rlnImageName #1\n_rlnMicrographName #2\n_rlnDefocusU #3\n_rlnDefocusV #4\n_rlnDefocusAngle #5\n_rlnVoltage #6\n_"
},
{
"path": "tests/data/het_config.yaml",
"chars": 445,
"preview": "dataset_args:\n ctf: null\n datadir: null\n ind: null\n invert_data: true\n keepreal: false\n norm:\n - 0.0\n - 94.4266\n"
},
{
"path": "tests/data/relion31.6opticsgroups.star",
"chars": 40704,
"preview": "# Created 2024-05-20 17:39:07.660886\n\ndata_optics\n\nloop_\n_rlnOpticsGroup\n_rlnOpticsGroupName\n_rlnAmplitudeContrast\n_rlnS"
},
{
"path": "tests/data/relion31.star",
"chars": 3660,
"preview": "\n# version 30001\n\ndata_optics\n\nloop_ \n_rlnOpticsGroup #1 \n_rlnOpticsGroupName #2 \n_rlnAmplitudeContrast #3 \n_rlnSpherica"
},
{
"path": "tests/data/relion31.v2.star",
"chars": 3662,
"preview": "# version 30001\n\ndata_particles\n\nloop_ \n_rlnCoordinateX #1 \n_rlnCoordinateY #2 \n_rlnImageOriginalName #3 \n_rlnMicrograph"
},
{
"path": "tests/data/relion5.star",
"chars": 8877,
"preview": "\n# version 50001\ndata_optics\nloop_ \n_rlnOpticsGroupName #1 \n_rlnOpticsGroup #2 \n_rlnMicrographOriginalPixelSize #3 \n_rln"
},
{
"path": "tests/data/sta_testing.star",
"chars": 68301,
"preview": "\ndata_\n\nloop_\n_rlnMagnification #1\n_rlnDetectorPixelSize #2\n_rlnVoltage #3\n_rlnSphericalAberration #4\n_rlnAmplitudeContr"
},
{
"path": "tests/data/sta_testing_bin8.star",
"chars": 63045,
"preview": "\ndata_\n\nloop_\n_rlnMagnification #1\n_rlnDetectorPixelSize #2\n_rlnVoltage #3\n_rlnSphericalAberration #4\n_rlnAmplitudeContr"
},
{
"path": "tests/data/toy.star",
"chars": 65185,
"preview": "# Created 2019-11-11 22:12:27.005879\ndata_images\nloop_\n_rlnImageName\n_rlnDefocusU\n_rlnDefocusV\n_rlnDefocusAngle\n_rlnVolt"
},
{
"path": "tests/data/toy_projections.star",
"chars": 68192,
"preview": "# Created 2019-11-24 15:48:52.495156\ndata_images\nloop_\n_rlnImageName\n_rlnDefocusU\n_rlnDefocusV\n_rlnDefocusAngle\n_rlnVolt"
},
{
"path": "tests/data/toy_projections.txt",
"chars": 21,
"preview": "toy_projections.mrcs\n"
},
{
"path": "tests/data/toy_projections_13.star",
"chars": 1391,
"preview": "data_optics\n\nloop_\n_rlnOpticsGroup #1\n_rlnOpticsGroupName #2\n_rlnAmplitudeContrast #3\n_rlnSphericalAberration #4\n_rlnVol"
},
{
"path": "tests/data/toy_projections_2.txt",
"chars": 42,
"preview": "toy_projections.mrcs\ntoy_projections.mrcs\n"
},
{
"path": "tests/data/toy_projections_dir.star",
"chars": 65187,
"preview": "# Created 2024-07-15 14:49:04.546559\n\ndata_\n\nloop_\n_rlnImageName\n_rlnDefocusU\n_rlnDefocusV\n_rlnDefocusAngle\n_rlnVoltage\n"
},
{
"path": "tests/quicktest.sh",
"chars": 320,
"preview": "#!/usr/bin/env sh\nset -e\nset -x\ncryodrgn train_vae data/hand.mrcs -o output/toy_recon_vae --lr .0001 --seed 0 --poses d"
},
{
"path": "tests/test_add_psize.py",
"chars": 1072,
"preview": "import pytest\nimport os\nimport argparse\nimport torch\nfrom cryodrgn.source import ImageSource\nfrom cryodrgn.commands_util"
},
{
"path": "tests/test_backprojection.py",
"chars": 10747,
"preview": "import pytest\nimport os\nimport shutil\nimport argparse\nfrom cryodrgn.commands import backproject_voxel\nfrom cryodrgn.comm"
},
{
"path": "tests/test_clean.py",
"chars": 3335,
"preview": "\"\"\"Unit tests of the cryodrgn clean command.\"\"\"\n\nimport pytest\nimport os\nfrom cryodrgn.utils import run_command\n\n\n@pytes"
},
{
"path": "tests/test_dashboard_core.py",
"chars": 38803,
"preview": "\"\"\"Core dashboard tests (logic + primary API integration).\n\nSplit from ``tests/test_dashboard.py`` to keep module size m"
},
{
"path": "tests/test_dashboard_extended.py",
"chars": 36050,
"preview": "\"\"\"Extended dashboard tests (context, CLI, and extra API paths).\n\nSplit from ``tests/test_dashboard.py`` to keep module "
},
{
"path": "tests/test_dataset.py",
"chars": 8240,
"preview": "import pytest\nimport os\nimport numpy as np\nfrom torch.utils.data.sampler import BatchSampler, RandomSampler\nfrom torch.u"
},
{
"path": "tests/test_direct_traversal.py",
"chars": 1132,
"preview": "import pytest\nimport os\nimport numpy as np\nfrom cryodrgn.utils import run_command\n\n\ndef test_fidelity_small(tmpdir):\n "
},
{
"path": "tests/test_downsample.py",
"chars": 9765,
"preview": "\"\"\"Unit testing of the `cryodrgn downsample` command.\"\"\"\n\nimport pytest\nimport os\nimport shutil\nimport argparse\nimport t"
},
{
"path": "tests/test_entropy.py",
"chars": 420,
"preview": "import torch\n\nfrom cryodrgn import lie_tools\n\n\ndef test_so3_entropy():\n entropy = lie_tools.so3_entropy(\n w_ep"
},
{
"path": "tests/test_eval_images.py",
"chars": 834,
"preview": "import pytest\nimport os\nimport argparse\nfrom cryodrgn.commands import eval_images\n\n\n@pytest.mark.parametrize(\n \"parti"
},
{
"path": "tests/test_fft.py",
"chars": 825,
"preview": "import numpy as np\nimport numpy.fft\nimport torch\nimport torch.fft\n\n\nimg_np = np.random.random((100, 32, 32))\nimg_torch ="
},
{
"path": "tests/test_filter_mrcs.py",
"chars": 1206,
"preview": "import pytest\nimport os\nimport argparse\nimport numpy as np\nimport torch\nfrom cryodrgn.source import ImageSource\nfrom cry"
},
{
"path": "tests/test_filter_pkl.py",
"chars": 4477,
"preview": "import pytest\nimport os\nimport argparse\nimport numpy as np\nfrom cryodrgn.commands_utils import filter_pkl, concat_pkls\nf"
},
{
"path": "tests/test_flip_hand.py",
"chars": 2223,
"preview": "import pytest\nimport os\nimport shutil\nimport numpy as np\nfrom cryodrgn.source import ImageSource\nfrom cryodrgn.mrcfile i"
},
{
"path": "tests/test_fsc.py",
"chars": 7440,
"preview": "\"\"\"Unit tests of the cryodrgn fsc command.\"\"\"\nimport pandas as pd\nimport pytest\nimport os\nimport numpy as np\nfrom cryodr"
},
{
"path": "tests/test_graph_traversal.py",
"chars": 5366,
"preview": "import pytest\nimport os.path\nfrom cryodrgn.utils import run_command\n\n\ndef test_fidelity_small(tmpdir):\n zvals_fl = os"
},
{
"path": "tests/test_integration.py",
"chars": 20576,
"preview": "\"\"\"Integration tests of ab initio reconstruction and downstream analyses.\n\nNote that the training done here has unrealis"
},
{
"path": "tests/test_invert_contrast.py",
"chars": 2185,
"preview": "import pytest\nimport os\nimport shutil\nimport numpy as np\nimport torch\nfrom cryodrgn.source import ImageSource\nfrom cryod"
},
{
"path": "tests/test_masks.py",
"chars": 3544,
"preview": "\"\"\"Unit tests of the cryodrgn_utils gen_mask command.\"\"\"\n\nimport pytest\nimport os\nfrom hashlib import md5\nfrom cryodrgn."
},
{
"path": "tests/test_mrc.py",
"chars": 1358,
"preview": "import pytest\nimport os\nimport numpy as np\nimport torch\nfrom cryodrgn.source import ImageSource\n\n\n@pytest.fixture\ndef mr"
},
{
"path": "tests/test_parse.py",
"chars": 4821,
"preview": "import pytest\nimport argparse\nimport os\nimport shutil\n\nfrom cryodrgn.commands import (\n parse_ctf_csparc,\n parse_c"
},
{
"path": "tests/test_pc_traversal.py",
"chars": 1122,
"preview": "import pytest\nimport os.path\nfrom cryodrgn.utils import run_command\n\n\ndef test_fidelity_small():\n out, err = run_comm"
},
{
"path": "tests/test_phase_flip.py",
"chars": 427,
"preview": "import pytest\nimport os\nimport argparse\nfrom cryodrgn.commands_utils import phase_flip\n\n\ndef test_phase_flip(tmpdir):\n "
},
{
"path": "tests/test_read_filter_write.py",
"chars": 13386,
"preview": "import pandas as pd\nimport pytest\nimport argparse\nimport os\nimport shutil\nimport pickle\nimport numpy as np\nimport torch\n"
},
{
"path": "tests/test_reconstruct_abinit.py",
"chars": 14138,
"preview": "\"\"\"Running ab-initio volume reconstruction followed by downstream analyses.\"\"\"\n\nimport pytest\nimport argparse\nimport os."
},
{
"path": "tests/test_reconstruct_abinit_old.py",
"chars": 9506,
"preview": "\"\"\"Running ab-initio volume reconstruction followed by downstream analyses.\"\"\"\n\nimport pytest\nimport argparse\nimport os."
},
{
"path": "tests/test_reconstruct_fixed.py",
"chars": 23516,
"preview": "\"\"\"Running an experiment of training followed by downstream analyses.\"\"\"\n\nimport pytest\nimport argparse\nimport os.path\ni"
},
{
"path": "tests/test_reconstruct_tilt.py",
"chars": 18434,
"preview": "\"\"\"Running an experiment of training followed by downstream analyses.\"\"\"\n\nimport pytest\nimport argparse\nimport os.path\ni"
},
{
"path": "tests/test_relion.py",
"chars": 12490,
"preview": "\"\"\"Tests of compatibility with RELION formats used to produce input .star files.\"\"\"\n\nimport pytest\nimport argparse\nimpor"
},
{
"path": "tests/test_select_clusters.py",
"chars": 2641,
"preview": "import pytest\nimport os\nimport argparse\nimport numpy as np\nfrom cryodrgn.commands_utils import select_clusters\nfrom cryo"
},
{
"path": "tests/test_select_random.py",
"chars": 1771,
"preview": "import pytest\nimport os\nimport argparse\nfrom cryodrgn.commands_utils import select_random\nfrom cryodrgn.utils import loa"
},
{
"path": "tests/test_source.py",
"chars": 8894,
"preview": "import os.path\nimport numpy as np\nimport torch\nimport pytest\nfrom cryodrgn.source import ImageSource, MRCFileSource\n\n\n@p"
},
{
"path": "tests/test_translate.py",
"chars": 4555,
"preview": "import pytest\nimport os\nimport shutil\nimport argparse\nimport numpy as np\nimport torch\nfrom cryodrgn import fft\nfrom cryo"
},
{
"path": "tests/test_utils.py",
"chars": 902,
"preview": "import numpy as np\nfrom numpy.testing import assert_array_almost_equal\n\nfrom cryodrgn import utils\n\n\ndef test_convert_fr"
},
{
"path": "tests/test_view_cs_header.py",
"chars": 312,
"preview": "import os.path\nimport argparse\nimport pytest\nfrom cryodrgn.commands_utils import view_cs_header\n\n\ndef test_view_cs_heade"
},
{
"path": "tests/test_view_header.py",
"chars": 318,
"preview": "import pytest\nimport argparse\nfrom cryodrgn.commands_utils import view_header\n\n\n@pytest.mark.parametrize(\"particles\", [\""
},
{
"path": "tests/test_view_mrcs.py",
"chars": 413,
"preview": "import pytest\nimport os.path\nimport argparse\nfrom unittest.mock import patch\nfrom cryodrgn.commands_utils import view_mr"
},
{
"path": "tests/test_writestar.py",
"chars": 8737,
"preview": "import pandas as pd\nimport pytest\nimport os\nimport argparse\nfrom cryodrgn.commands_utils import write_star, filter_mrcs\n"
},
{
"path": "tests/unittest.sh",
"chars": 3896,
"preview": "# Script for running tests of cryoDRGN training and analysis methods outside of pytest\n# NOTE: must be run within the fo"
}
]
// ... and 43 more files (download for full content)
About this extraction
This page contains the full source code of the zhonge/cryodrgn GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 234 files (11.9 MB), approximately 3.1M tokens, and a symbol index with 1397 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.