Showing preview only (424K chars total). Download the full file or copy to clipboard to get everything.
Repository: kuleshov-group/caduceus
Branch: main
Commit: 0060a6d8079b
Files: 104
Total size: 394.9 KB
Directory structure:
gitextract_kw76odef/
├── .gitignore
├── LICENSE
├── README.md
├── caduceus/
│ ├── __init__.py
│ ├── configuration_caduceus.py
│ ├── modeling_caduceus.py
│ ├── modeling_rcps.py
│ ├── tests/
│ │ └── test_rcps.py
│ └── tokenization_caduceus.py
├── caduceus_env.yml
├── configs/
│ ├── callbacks/
│ │ ├── base.yaml
│ │ ├── checkpoint.yaml
│ │ ├── gpu_affinity.yaml
│ │ ├── rich.yaml
│ │ ├── val_every_n_global_steps.yaml
│ │ └── wandb.yaml
│ ├── config.yaml
│ ├── dataset/
│ │ ├── genomic_benchmark.yaml
│ │ ├── hg38.yaml
│ │ └── nucleotide_transformer.yaml
│ ├── experiment/
│ │ └── hg38/
│ │ ├── genomic_benchmark.yaml
│ │ ├── genomic_benchmark_cnn.yaml
│ │ ├── hg38.yaml
│ │ └── nucleotide_transformer.yaml
│ ├── loader/
│ │ └── default.yaml
│ ├── model/
│ │ ├── caduceus.yaml
│ │ ├── genomics_benchmark_cnn.yaml
│ │ ├── hyena.yaml
│ │ ├── layer/
│ │ │ └── hyena.yaml
│ │ └── mamba.yaml
│ ├── optimizer/
│ │ ├── adam.yaml
│ │ ├── adamw.yaml
│ │ └── sgd.yaml
│ ├── pipeline/
│ │ ├── genomic_benchmark.yaml
│ │ ├── hg38.yaml
│ │ └── nucleotide_transformer.yaml
│ ├── scheduler/
│ │ ├── constant.yaml
│ │ ├── constant_warmup.yaml
│ │ ├── cosine.yaml
│ │ ├── cosine_warmup.yaml
│ │ ├── cosine_warmup_timm.yaml
│ │ ├── linear_warmup.yaml
│ │ ├── multistep.yaml
│ │ ├── plateau.yaml
│ │ └── step.yaml
│ ├── task/
│ │ ├── lm.yaml
│ │ ├── multiclass_classification.yaml
│ │ ├── multilabel_classification.yaml
│ │ └── regression.yaml
│ └── trainer/
│ ├── debug.yaml
│ ├── default.yaml
│ ├── full.yaml
│ └── lm.yaml
├── setup_env.sh
├── slurm_scripts/
│ ├── dump_vep_embeddings.sh
│ ├── run_genomics_benchmark.sh
│ ├── run_genomics_benchmark_cnn.sh
│ ├── run_nucleotide_transformer.sh
│ ├── run_pretrain_caduceus.sh
│ ├── run_pretrain_hyena.sh
│ ├── run_pretrain_mamba.sh
│ ├── wrapper_run_genomics.sh
│ ├── wrapper_run_genomics_cnn.sh
│ └── wrapper_run_nucleotide_transformer.sh
├── src/
│ ├── __init__.py
│ ├── callbacks/
│ │ ├── params.py
│ │ ├── timer.py
│ │ └── validation.py
│ ├── dataloaders/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── datasets/
│ │ │ ├── genomic_bench_dataset.py
│ │ │ ├── hg38_char_tokenizer.py
│ │ │ ├── hg38_dataset.py
│ │ │ └── nucleotide_transformer_dataset.py
│ │ ├── fault_tolerant_sampler.py
│ │ ├── genomics.py
│ │ └── utils/
│ │ ├── mlm.py
│ │ └── rc.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── baseline/
│ │ │ ├── __init__.py
│ │ │ └── genomics_benchmark_cnn.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── activation.py
│ │ │ ├── adaptive_softmax.py
│ │ │ └── utils.py
│ │ └── sequence/
│ │ ├── __init__.py
│ │ ├── dna_embedding.py
│ │ ├── hyena.py
│ │ └── long_conv_lm.py
│ ├── ops/
│ │ └── fftconv.py
│ ├── tasks/
│ │ ├── decoders.py
│ │ ├── encoders.py
│ │ ├── metrics.py
│ │ ├── tasks.py
│ │ └── torchmetrics.py
│ └── utils/
│ ├── __init__.py
│ ├── config.py
│ ├── optim/
│ │ └── schedulers.py
│ ├── optim_groups.py
│ ├── registry.py
│ └── train.py
├── train.py
├── vep_embeddings.py
└── vep_svm.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data.tar.gz
*.tsf
*.ckpt
.ipynb_checkpoints
*/.ipynb_checkpoints/*
*.lprof
.DS_Store
.idea/
outputs/
# slurm log files
watch_folder/
data
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# End of https://www.gitignore.io/api/python
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<p align="center">
<img src="assets/Caduceus_image.png" alt="Caduceus" width="200"/>
</p>
# Caduceus ☤: Bi-Directional Equivariant Long-Range DNA Sequence Modeling
[[Blog]](https://caduceus-dna.github.io/) | [[arXiv]](https://arxiv.org/abs/2403.03234) | [[HuggingFace 🤗]](https://huggingface.co/collections/kuleshov-group/caducues-65dcb89b4f54e416ef61c350)
This repository contains code for reproducing the results in the paper "Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling," [Schiff et al. (2024)](https://arxiv.org/abs/2403.03234).
## Using Caduceus with 🤗
<a name="HF"></a>
We have uploaded pre-trained Caduceus models to the Hugging Face hub.
The available models are:
- Caduceus-Ph: [kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16)
- Trained on sequences of length 131k, with a model size of 256 and 16 layers.
- Trained for 50k steps with a batch size of 8.
- Trained with reverse-complement (RC) data augmentation.
- Caduceus-PS: [kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16)
- Trained on sequences of length 131k, with a model size of 256 and 16 layers.
- Trained for 50k steps with a batch size of 8.
- Model is RC equivariant, hence no RC data augmentation is required.
You can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model.
To use the pre-trained model for masked language modeling, use the following snippet:
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
# See the `Caduceus` collection page on the hub for list of available models.
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
```
Alternatively, you can instantiate a model from scratch to train on your own data as follows:
```python
from transformers import AutoConfig, AutoModelForMaskedLM
# Add any config overrides here, see the `config.json` file on the hub for details.
config_overrides = {}
# See the `Caduceus` collection page on the hub for list of available models.
config = AutoConfig.from_pretrained(
"kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
**config_overrides,
)
model = AutoModelForMaskedLM.from_config(config)
```
## Getting started in this repository
<a name="getting_started"></a>
To get started, create a conda environment containing the required dependencies.
```bash
conda env create -f caduceus_env.yml
```
Activate the environment.
```bash
conda activate caduceus_env
```
Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```
## Reproducing Experiments
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`train.py`](./train.py) script.
We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`slurm_scripts/`](./slurm_scripts) directory.
### Pretraining on Human Reference Genome
<a name="pretraining"></a>
(Data downloading instructions are copied from [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome))
First, download the Human Reference Genome data.
It consists of 2 files: one with all the sequences (the `.fasta` file) and one with the intervals we use (the `.bed` file).
The file structure should look like
```
data
|-- hg38/
|-- hg38.ml.fa
|-- human-sequences.bed
```
Download fasta (.fa format) file (of the entire human genome) into `./data/hg38`.
The whole genome has ~24 chromosomes (merged into one file); each chromosome is essentially one continuous sequence.
Then download the .bed file with the sequence intervals. Each interval lists the chromosome name, start, end, and split, which together let you retrieve the sequence from the fasta file.
```bash
mkdir -p data/hg38/
curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz
gunzip data/hg38/hg38.ml.fa.gz # unzip the fasta file
curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed
```
Launch pretraining run using the command line
```bash
python -m train \
experiment=hg38/hg38 \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
dataset.max_length=1024 \
dataset.batch_size=1024 \
dataset.mlm=true \
dataset.mlm_probability=0.15 \
dataset.rc_aug=false \
model=caduceus \
model.config.d_model=128 \
model.config.n_layer=4 \
model.config.bidirectional=true \
model.config.bidirectional_strategy=add \
model.config.bidirectional_weight_tie=true \
model.config.rcps=true \
optimizer.lr="8e-3" \
train.global_batch_size=1024 \
trainer.max_steps=10000 \
+trainer.val_check_interval=10000 \
wandb=null
```
or alternatively, if using a cluster that has `slurm` installed, adapt the scripts below:
```
slurm_scripts
|-- run_pretrain_caduceus.sh
|-- run_pretrain_hyena.sh
|-- run_pretrain_mamba.sh
```
and run the training as a batch job:
```bash
cd slurm_scripts
sbatch run_pretrain_caduceus.sh
```
### GenomicBenchmarks
<a name="genomicbenchmarks"></a>
The [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) suite presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8) comprises 8 classification tasks.
We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
experiment=hg38/genomic_benchmark \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
dataset.train_val_split_seed=1 \
dataset.batch_size=256 \
dataset.rc_aug=false \
+dataset.conjoin_train=false \
+dataset.conjoin_test=false \
loader.num_workers=2 \
model=caduceus \
model._name_=dna_embedding_caduceus \
+model.config_path="<path to model_config.json>" \
+model.conjoin_test=false \
+decoder.conjoin_train=true \
+decoder.conjoin_test=false \
optimizer.lr="1e-3" \
trainer.max_epochs=10 \
train.pretrained_model_path="<path to .ckpt file>" \
wandb=null
```
This sample run will fine-tune a pre-trained Caduceus-PS model on the `dummy_mouse_enhancers_ensembl` task.
Note some of the additional arguments present here, relative to the pre-training command from [above](#pretraining):
- `model.config_path` contains the path to the model config that was saved during pre-training.
This will be saved to the run directory of the pre-training experiment.
- `train.pretrained_model_path` contains the path to the pre-trained model checkpoint.
- `dataset.conjoin_train` determines whether the dataset will return a single sequence (`dataset.conjoin_train=false`) or the concatenation of a sequence and its reverse complement along `dim=-1`, during downstream fine-tuning training.
- `dataset.conjoin_test` is the same as above, but for inference (e.g., validation / test).
- `decoder.conjoin_train` determines whether the prediction head (a mean pooling and linear projection in the case of the Genomics Benchmark) is expecting an input tensor of shape `(batch_size, seq_len, d_model)` or `(batch_size, seq_len, d_model, 2)` during downstream fine-tuning training.
When set to `true` the decoder is run on `input[..., 0]` and `input[..., 1]` and the results are averaged to produce the final prediction.
- `decoder.conjoin_test` is the same as above, but for inference (e.g., validation / test).
Note this benchmark only contains a training and test split for each task.
Therefore, to have a more principled evaluation, we randomly split the training data into training and validation sets (90/10) using the `dataset.train_val_split_seed` argument.
We perform early stopping on validation metric (accuracy) and repeat this for 5 random seeds.
As with [pre-training](#pretraining), we can also launch the fine-tuning run as a batch job using the provided [`run_genomics_benchmark.sh`](./slurm_scripts/run_genomics_benchmark.sh) script.
We also provide a helper shell script [`wrapper_run_genomics.sh`](./slurm_scripts/wrapper_run_genomics.sh) that can be used to launch multiple fine-tuning runs in parallel.
Finally, the [`run_genomics_benchmark_cnn.sh`](./slurm_scripts/run_genomics_benchmark_cnn.sh) script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks.
### Nucleotide Transformer datasets
<a name="nucleotidetransformer"></a>
The Nucleotide Transformer suite of tasks was proposed in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1).
The data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).
We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
experiment=hg38/nucleotide_transformer \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="${task}" \
dataset.train_val_split_seed=${seed} \
dataset.batch_size=${batch_size} \
dataset.rc_aug="${rc_aug}" \
+dataset.conjoin_test="${CONJOIN_TEST}" \
loader.num_workers=2 \
model._name_=dna_embedding_caduceus \
+model.config_path="<path to model_config.json>" \
+model.conjoin_test=false \
+decoder.conjoin_train=true \
+decoder.conjoin_test=false \
optimizer.lr="1e-3" \
train.pretrained_model_path="<path to .ckpt file>" \
trainer.max_epochs=20 \
wandb=null
```
We can also launch as batch jobs (see [`run_nucleotide_transformer.sh`](./slurm_scripts/run_nucleotide_transformer.sh) and [`wrapper_run_nucleotide_transformer.sh`](./slurm_scripts/wrapper_run_nucleotide_transformer.sh) for details).
### eQTL SNP Variant Effect Prediction
<a name="vep"></a>
This task comes from the recently proposed Long Range Benchmark (LRB) in [Kao et al., 2023](https://llms4science-community.github.io/papers/LLMs4Bio24_paper_12.pdf).
The data is available on HuggingFace: [InstaDeepAI/genomics-long-range-benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark).
For this task we fit a model to the pre-trained and frozen embeddings of the DNA language models.
Therefore, to perform the evaluation, we proceed in 2 steps:
- **Step 1: Extract the embeddings** from the pre-trained model:
Run the [`vep_embeddings.py`](./vep_embeddings.py) script to extract the embeddings from the pre-trained model.
See the example below:
```bash
torchrun \
--standalone \
--nnodes=1 \
--nproc-per-node=8 \
vep_embeddings.py \
--num_workers=2 \
--seq_len=131072 \
--bp_per_token=1 \
--embed_dump_batch_size=1 \
--name="caduceus-ps_downstream-seqlen=131k" \
--model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" \
--rcps
```
The `--rcps` flag is used to indicate that the model is reverse-complement equivariant.
When using other models, set this flag to false with `--no-rcps`.
To speed this step up, this script utilizes torch distributed data parallelism.
Please refer to the slurm script provided in [`slurm_scripts/dump_vep_embeddings.sh`](./slurm_scripts/dump_vep_embeddings.sh)
to launch this step as a batch job.
- **Step 2: Fit an SVM model to the embeddings** using this notebook: [`vep_svm.ipynb`](./vep_svm.ipynb).
## Citation
<a name="citation"></a>
If you find our work useful, please cite our paper using the following:
```
@article{schiff2024caduceus,
title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling},
author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr},
journal={arXiv preprint arXiv:2403.03234},
year={2024}
}
```
## Acknowledgements
<a name="acknowledgements"></a>
This repository is adapted from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna) and leverages much of the training, data loading, and logging infrastructure defined there.
HyenaDNA was originally derived from the [S4](https://github.com/state-spaces/s4) and [Safari](https://github.com/HazyResearch/safari) repositories.
We would like to thank Evan Trop and the [InstaDeep](https://www.instadeep.com/) team for useful discussions about the [Nucleotide Transformer leaderboard](https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark)
and the Long Range Benchmark task.
Finally, we would like to thank [MosaicML](https://www.mosaicml.com/) for providing compute resources for some of the pre-training experiments.
================================================
FILE: caduceus/__init__.py
================================================
"""Hugging Face config, model, and tokenizer for Caduceus.
"""
from .configuration_caduceus import CaduceusConfig
from .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification
from .tokenization_caduceus import CaduceusTokenizer
================================================
FILE: caduceus/configuration_caduceus.py
================================================
"""Caduceus config for Hugging Face.
"""
from typing import Optional, Union
from transformers import PretrainedConfig
class CaduceusConfig(PretrainedConfig):
    """Hugging Face config for Caduceus.

    Carries the fields of the original ``MambaConfig`` plus the options that
    control bi-directionality and reverse-complement (RC) equivariance.
    """

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        **kwargs,
    ):
        """Record every option as an attribute of the same name.

        Args:
            d_model: Hidden dimension of the model.
            n_layer: Number of stacked blocks.
            vocab_size: Size of the token vocabulary.
            ssm_cfg: Optional extra keyword arguments for the Mamba mixer.
            rms_norm: Use RMSNorm instead of LayerNorm.
            residual_in_fp32: Keep the residual stream in fp32.
            fused_add_norm: Use the fused add + norm kernel.
            pad_vocab_size_multiple: Kept from ``MambaConfig``; presumably the
                vocabulary is padded up to a multiple of this value by the
                model code (not visible here).
            norm_epsilon: Epsilon used by the normalization layers.
            initializer_cfg: Optional options consumed by ``init_weights``.
            bidirectional: Whether the mixer also processes the reversed sequence.
            bidirectional_strategy: How the two directions are combined.
            bidirectional_weight_tie: Whether the two directions share weights.
            rcps: Whether to build reverse-complement equivariant (RCPS) modules.
            complement_map: Token-id complement mapping used by the RCPS
                embedding / LM head.
            **kwargs: Forwarded unchanged to ``PretrainedConfig``.
        """
        # Generic Hugging Face options are handled by the base class first.
        super().__init__(**kwargs)

        # Core backbone dimensions (same attribute order as MambaConfig).
        self.d_model, self.n_layer, self.vocab_size = d_model, n_layer, vocab_size
        self.ssm_cfg, self.rms_norm = ssm_cfg, rms_norm
        self.residual_in_fp32, self.fused_add_norm = residual_in_fp32, fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon, self.initializer_cfg = norm_epsilon, initializer_cfg

        # Caduceus additions: bi-directionality and RC equivariance.
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps, self.complement_map = rcps, complement_map
================================================
FILE: caduceus/modeling_caduceus.py
================================================
"""Caduceus model for Hugging Face.
"""
import inspect
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
from mamba_ssm.modules.mamba_simple import Mamba
try:
from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
except ImportError:
from mamba_ssm.modules.block import Block # mambav2 file structure
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
except ImportError:
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from .configuration_caduceus import CaduceusConfig
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
def create_block(
        d_model,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        rms_norm=False,
        residual_in_fp32=False,
        fused_add_norm=False,
        layer_idx=None,
        bidirectional=True,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=False,
        device=None,
        dtype=None,
):
    """Build one Caduceus layer: a Mamba `Block` (or `RCPSMambaBlock`) around a `BiMambaWrapper` mixer.

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

    Args:
        d_model: Width passed to the block constructor.
        ssm_cfg: Extra keyword arguments forwarded to `BiMambaWrapper`; `None` means no extras.
        norm_epsilon: `eps` of the block's norm.
        rms_norm: Use `RMSNorm` instead of `nn.LayerNorm`.
        residual_in_fp32: Forwarded to the block constructor.
        fused_add_norm: Forwarded to the block constructor.
        layer_idx: Passed to the mixer and also stored as `block.layer_idx`.
        bidirectional, bidirectional_strategy, bidirectional_weight_tie: `BiMambaWrapper` options.
        rcps: Build an `RCPSMambaBlock` instead of the standard `Block`.
        device, dtype: Factory kwargs for the mixer and the norm.

    Returns:
        The constructed block.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(
        BiMambaWrapper,
        layer_idx=layer_idx,
        **(ssm_cfg if ssm_cfg is not None else {}),
        bidirectional=bidirectional,
        bidirectional_strategy=bidirectional_strategy,
        bidirectional_weight_tie=bidirectional_weight_tie,
        **factory_kwargs,
    )
    norm_cls = partial(RMSNorm if rms_norm else nn.LayerNorm, eps=norm_epsilon, **factory_kwargs)
    block_cls = RCPSMambaBlock if rcps else Block

    block_kwargs = {
        "norm_cls": norm_cls,
        "fused_add_norm": fused_add_norm,
        "residual_in_fp32": residual_in_fp32,
    }
    # mambav2 compatibility: newer `Block` signatures take an `mlp_cls`; these layers use an identity MLP.
    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
        block_kwargs["mlp_cls"] = nn.Identity

    block = block_cls(d_model, mixer_cls, **block_kwargs)
    block.layer_idx = layer_idx
    return block
class BiMambaWrapper(nn.Module):
    """Thin wrapper around Mamba to support bi-directionality.

    A forward `Mamba` runs over the sequence. When `bidirectional` is set, a second `Mamba` also runs
    over the sequence reversed along dim 1. Its output is reversed back and combined with the forward
    output by addition ("add") or element-wise product ("ew_multiply").
    """

    def __init__(
            self,
            d_model: int,
            bidirectional: bool = True,
            bidirectional_strategy: Optional[str] = "add",
            bidirectional_weight_tie: bool = True,
            **mamba_kwargs,
    ):
        """
        Args:
            d_model: Model width given to each `Mamba`.
            bidirectional: Also run the reverse-direction `Mamba`.
            bidirectional_strategy: "add" or "ew_multiply"; `None` falls back to "add" when bidirectional.
            bidirectional_weight_tie: Share `in_proj` / `out_proj` weights and biases between directions.
            **mamba_kwargs: Forwarded to both `Mamba` constructors.

        Raises:
            NotImplementedError: For an unknown strategy when `bidirectional` is set.
        """
        super().__init__()
        if bidirectional:
            if bidirectional_strategy is None:
                bidirectional_strategy = "add"  # Default strategy: `add`
            if bidirectional_strategy not in ["add", "ew_multiply"]:
                raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy

        self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)
        if not bidirectional:
            self.mamba_rev = None
            return

        self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)
        if bidirectional_weight_tie:
            # The in/out projections hold most of the parameters, so sharing them roughly halves the cost.
            for proj_name in ("in_proj", "out_proj"):
                source_proj = getattr(self.mamba_fwd, proj_name)
                target_proj = getattr(self.mamba_rev, proj_name)
                target_proj.weight = source_proj.weight
                target_proj.bias = source_proj.bias

    def forward(self, hidden_states, inference_params=None):
        """Bidirectional-enabled forward pass.

        Args:
            hidden_states: (B, L, D) input; dim 1 is treated as the sequence dimension.
            inference_params: Forwarded to each `Mamba` call.

        Returns:
            Tensor with the same shape as `hidden_states`.
        """
        forward_out = self.mamba_fwd(hidden_states, inference_params=inference_params)
        if not self.bidirectional:
            return forward_out

        reversed_in = hidden_states.flip(dims=(1,))
        reverse_out = self.mamba_rev(reversed_in, inference_params=inference_params).flip(dims=(1,))

        strategy = self.bidirectional_strategy
        if strategy == "add":
            return forward_out + reverse_out
        if strategy == "ew_multiply":
            return forward_out * reverse_out
        raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
class CaduceusEmbeddings(nn.Module):
    """Token embedding layer.

    Uses an `RCPSEmbedding` when `config.rcps` is set, otherwise a plain `nn.Embedding`.
    """

    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.word_embeddings = (
            RCPSEmbedding(config.vocab_size, config.d_model, config.complement_map, **factory_kwargs)
            if config.rcps
            else nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
        )

    def forward(self, input_ids):
        """Embed token ids.

        Args:
            input_ids: (batch, seqlen) token ids.
        """
        embed = self.word_embeddings
        return embed(input_ids)
class CaduceusMixerModel(nn.Module):
    """Caduceus backbone: token embeddings, `config.n_layer` blocks built by `create_block`, and a final norm.

    When `config.rcps` is set, hidden states carry two channel halves. At the final norm, the second half
    is flipped along dims (-2, -1), normalized, and flipped back (via RCPSAddNormWrapper or the fused branch).
    """
    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.fused_add_norm = config.fused_add_norm
        self.rcps = config.rcps
        self.residual_in_fp32 = config.residual_in_fp32

        self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)

        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        if config.fused_add_norm:
            # The fused path calls the Triton kernels directly, so fail early if they could not be imported.
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    config.d_model,
                    ssm_cfg=config.ssm_cfg,
                    norm_epsilon=config.norm_epsilon,
                    rms_norm=config.rms_norm,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=config.fused_add_norm,
                    layer_idx=i,
                    bidirectional=config.bidirectional,
                    bidirectional_strategy=config.bidirectional_strategy,
                    bidirectional_weight_tie=config.bidirectional_weight_tie,
                    rcps=config.rcps,
                    **factory_kwargs,
                )
                for i in range(config.n_layer)
            ]
        )

        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
            config.d_model, eps=config.norm_epsilon, **factory_kwargs
        )
        # The fused branch of `forward` reads the raw norm's weight/bias/eps directly, so only the
        # non-fused RCPS configuration needs the RC-aware add+norm wrapper.
        self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        """Mixer forward.

        Args:
            input_ids: Token ids fed to `self.embeddings`; ignored when `inputs_embeds` is given.
            inputs_embeds: Precomputed embeddings used instead of `self.embeddings(input_ids)`.
            output_hidden_states: If True, record the input of every layer plus the final normalized output.

        Returns:
            Tuple `(hidden_states, all_hidden_states)`. `all_hidden_states` is an empty list unless
            `output_hidden_states` is set.
        """
        all_hidden_states = []
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids)
        residual = None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=None
            )

        # Final residual add + norm.
        if not self.fused_add_norm:
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
            else:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                # Cast to the norm's parameter dtype (e.g. when the residual is kept in fp32).
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            if self.rcps:
                # NOTE(review): assumes at least one layer ran; with n_layer == 0 `residual` is still None
                # here and slicing it raises.
                # Set prenorm=False here since we don't need the residual
                # First channel half: normalized as-is.
                hidden_states_fwd = fused_add_norm_fn(
                    hidden_states[..., :hidden_states.shape[-1] // 2],
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., :hidden_states.shape[-1] // 2],
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                # Second channel half: flipped along length and channels, normalized, flipped back below.
                hidden_states_rc = fused_add_norm_fn(
                    hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        return hidden_states, all_hidden_states
def cross_entropy(logits, y, ignore_index=-100):
    """Cross entropy loss averaged over all non-ignored targets.

    Args:
        logits: Scores of shape (..., vocab_size); all leading dimensions are flattened.
        y: Integer targets whose element count equals the product of `logits`' leading dimensions.
        ignore_index: Target value excluded from both the loss and the averaging denominator.

    Returns:
        Scalar loss tensor.
    """
    # `reshape` instead of `view`: accepts non-contiguous inputs (e.g. sliced or transposed labels),
    # while still returning a view whenever the memory layout allows it.
    logits = logits.reshape(-1, logits.shape[-1])
    y = y.reshape(-1)
    return F.cross_entropy(logits, y, ignore_index=ignore_index)
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).

    Per-token losses are weighted by `loss_weights` (set to zero for ignored targets), with the weights
    normalized to sum to one.

    Args:
        logits: Scores of shape (..., vocab_size); all leading dimensions are flattened.
        y: Integer targets with one entry per flattened logit row.
        loss_weights: Per-target weights with the same number of elements as `y`. Not modified.
        ignore_index: Target value whose weight is forced to zero.

    Returns:
        Scalar loss tensor; 0 when every target is ignored or has zero weight.
    """
    logits = logits.reshape(-1, logits.shape[-1])
    y = y.reshape(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    # Mask out-of-place so the caller's `loss_weights` tensor is not written through a view.
    weights = loss_weights.reshape(-1).masked_fill(y == ignore_index, 0.0)
    total = weights.sum()
    # Avoid 0 / 0 = NaN when nothing contributes: all weights are zero then, so the loss is 0.
    total = torch.where(total != 0, total, torch.ones_like(total))
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (weights / total)).sum()
class CaduceusPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for Caduceus backbone.

    Provides the shared HF configuration class, base-model prefix and weight initialization for all
    Caduceus heads.
    """
    config_class = CaduceusConfig
    base_model_prefix = "caduceus"
    # Gradient checkpointing is not implemented (see the TODO in CaduceusMixerModel.forward).
    supports_gradient_checkpointing = False
    _no_split_modules = ["BiMambaWrapper"]

    def _init_weights(
            self,
            module,
            initializer_range=0.02,  # Now only used for embedding layer.
            **kwargs,
    ):
        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

        HF weight-initialization hook, presumably invoked once per submodule via `post_init()` (confirm
        against the installed transformers version).

        - `nn.Linear`: only the bias is zeroed (unless it carries `_no_reinit`); the weight keeps its default init.
        - `nn.Embedding`: weight drawn from a normal distribution with std `initializer_range`.
        - If `rescale_prenorm_residual` (default True): any parameter named exactly `out_proj.weight` or
          `fc2.weight` relative to `module` is re-drawn with Kaiming-uniform and divided by
          sqrt(n_residuals_per_layer * n_layer).

        `config.initializer_cfg` may override `rescale_prenorm_residual`, `initializer_range` and
        `n_residuals_per_layer` (default 1).
        """

        n_layer = self.config.n_layer
        initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
        rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
        initializer_range = initialized_cfg.get("initializer_range", initializer_range)
        n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                # Biases flagged `_no_reinit` (e.g. set up specially elsewhere) are left untouched.
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
            #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
            #   residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            # Names come from `module.named_parameters()`, so only parameters whose relative name is exactly
            # `out_proj.weight` / `fc2.weight` match (i.e. a direct child named out_proj / fc2 of `module`).
            for name, p in module.named_parameters():
                if name in ["out_proj.weight", "fc2.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # (NOTE: the divisor actually used is sqrt(n_residuals_per_layer * n_layer); default is 1 per layer.)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(n_residuals_per_layer * n_layer)
class Caduceus(CaduceusPreTrainedModel):
    """Caduceus model that can be instantiated using HF patterns.

    Wraps `CaduceusMixerModel` as `self.backbone`. The constructor mutates `config` in place:
    - `vocab_size` is padded up to a multiple of `pad_vocab_size_multiple`;
    - `complement_map` is extended so that every padded token id is its own complement.
    """

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config)
        if config.rcps:
            assert config.complement_map is not None, "Complement map must be provided for RCPS."

        # Adjust vocab size and complement maps if vocab padding is set.
        multiple = config.pad_vocab_size_multiple
        remainder = config.vocab_size % multiple
        if remainder != 0:
            config.vocab_size += multiple - remainder
        cmap = config.complement_map
        if cmap is not None and config.vocab_size > len(cmap):
            for token_id in range(len(cmap), config.vocab_size):
                cmap[token_id] = token_id

        self.config = config
        self.backbone = CaduceusMixerModel(config, device=device, dtype=dtype, **kwargs)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
        """HF-compatible forward method.

        Unset flags fall back to `config.output_hidden_states` / `config.use_return_dict`.

        Returns:
            - a `BaseModelOutputWithNoAttention` when `return_dict` is truthy;
            - else `(last_hidden_state, all_hidden_states)` when hidden states are requested;
            - else the bare `last_hidden_state` tensor (not a 1-tuple).
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, layer_states = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
        )

        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=last_hidden,
                hidden_states=layer_states if output_hidden_states else None,
            )
        if output_hidden_states:
            return last_hidden, layer_states
        return last_hidden
class CaduceusForMaskedLM(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for masked language modeling."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        if config.rcps:
            self.lm_head = RCPSLMHead(
                complement_map=self.config.complement_map,  # Use caduceus config as it might have been updated
                vocab_size=self.config.vocab_size,  # Use caduceus config as it might have been updated
                true_dim=config.d_model,
                **factory_kwargs,  # Device as well as dtype, like the backbone and the non-RCPS head
            )
        else:
            self.lm_head = nn.Linear(
                config.d_model,
                self.config.vocab_size,  # Use caduceus config as it might have been updated
                bias=False,
                **factory_kwargs
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Overrides output embeddings."""
        if self.config.rcps:
            raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights, accounting for RCPS."""
        if self.config.rcps:
            self.lm_head.set_weight(self.get_input_embeddings().weight)
        else:
            super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.caduceus

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.caduceus = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method.

        Args:
            labels: Target token ids. Positions equal to `config.pad_token_id` (or -100 when that is None)
                are ignored in the loss.
            loss_weights: Optional per-token weights; when given, `weighted_cross_entropy` is used.

        Returns:
            `MaskedLMOutput` when `return_dict` is truthy, else `(loss?, logits, hidden_states?)`.
        """

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the backbone. With `return_dict=False` and no hidden states it
        # returns a bare tensor, and `outputs[0]` would then select the first batch element.
        outputs = self.caduceus(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # `F.cross_entropy` rejects `ignore_index=None`; fall back to PyTorch's default.
            ignore_index = self.config.pad_token_id if self.config.pad_token_id is not None else -100
            if loss_weights is not None:
                loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=ignore_index)
            else:
                loss = cross_entropy(logits, labels, ignore_index=ignore_index)

        if not return_dict:
            # `ModelOutput` slicing drops `None` fields, so this only appends hidden states when present.
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for sequence classification / regression.

    Backbone hidden states are pooled along the sequence-length dimension and projected to `num_labels`
    logits by a bias-free linear `score` head.
    """

    def __init__(
            self,
            config: CaduceusConfig,
            pooling_strategy: str = "mean",
            conjoin_train: bool = False,
            conjoin_eval: bool = False,
            device=None,
            dtype=None,
            **kwargs):
        """
        Args:
            config: Caduceus configuration.
            pooling_strategy: One of "mean", "max", "first", "last".
            conjoin_train: For non-RCPS models, always run the backbone on both strands. The strands are
                channels 0 and 1 of a 3D `input_ids`, and their logits are averaged.
            conjoin_eval: Same as `conjoin_train`, but only while the module is in eval mode.
            device, dtype: Factory kwargs for the backbone and the score head.
            **kwargs: `num_labels` (overrides `config.num_labels`) is consumed here; the rest is forwarded.

        Raises:
            NotImplementedError: For an unknown pooling strategy.
        """
        # `num_labels` must not reach the backbone: `CaduceusMixerModel.__init__` does not accept it.
        num_labels = kwargs.pop("num_labels", config.num_labels)
        super().__init__(config, **kwargs)
        if pooling_strategy not in ["mean", "max", "first", "last"]:
            raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
        self.pooling_strategy = pooling_strategy
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_labels = num_labels
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        # Same device / dtype as the backbone so pooled features and head weights match.
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False, **factory_kwargs)

        self.conjoin_train = conjoin_train
        self.conjoin_eval = conjoin_eval

        # Initialize weights and apply final processing
        self.post_init()
        self.init_scorer()

    def init_scorer(self, initializer_range=0.02):
        """Re-draw `score.weight` from N(0, std), using `initializer_cfg["initializer_range"]` if configured."""
        initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
            if self.config.initializer_cfg is not None else initializer_range
        self.score.weight.data.normal_(std=initializer_range)

    def get_input_embeddings(self):
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
        """Pools hidden states along sequence length dimension.

        Works for 3D (batch, seq_len, dim) and 4D (batch, seq_len, dim, 2) inputs; the pooled dimension
        is removed.
        """
        if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
            return hidden_states.max(dim=sequence_length_dim).values
        # `torch.moveaxis` (module function): the tensor method takes only (source, destination).
        if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
        raise NotImplementedError(f"Pooling strategy `{self.pooling_strategy}` not implemented.")

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden representations from the backbone. Always request a ModelOutput: with `return_dict=False`
        # and no hidden states the backbone returns a bare tensor, and `[0]` would pick the first batch element.
        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            # Split the channel halves (un-flipping the rc half along dims 1, 2) and stack on a new last dim.
            hidden_states = torch.stack(
                [
                    transformer_outputs[0][..., :self.config.d_model],
                    torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
                 ],
                dim=-1
            )
        elif self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining
            assert input_ids is not None, "`input_ids` must be provided for conjoining."
            assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
            transformer_outputs = self.caduceus(
                input_ids[..., 0],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            transformer_outputs_rc = self.caduceus(
                input_ids[..., 1],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            # Stack along channel dimension (dim=-1)
            hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
        else:
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            hidden_states = transformer_outputs[0]

        # Pool and get logits
        pooled_hidden_states = self.pool_hidden_states(hidden_states)
        # Potentially run `score` twice (with parameters shared) for conjoining
        if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
            logits_fwd = self.score(pooled_hidden_states[..., 0])
            logits_rc = self.score(pooled_hidden_states[..., 1])
            logits = (logits_fwd + logits_rc) / 2
        else:
            logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)
        if not return_dict:
            # `ModelOutput` slicing drops `None` fields, so hidden states are appended only when present.
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )
================================================
FILE: caduceus/modeling_rcps.py
================================================
"""Reverse-complement equivariant modules.
"""
from collections import OrderedDict
from typing import Optional
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
except ImportError:
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
class RCPSEmbedding(nn.Module):
    """Embedding layer that supports reverse-complement equivariance."""

    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
        """
        Args:
            vocab_size: Size of vocabulary.
            d_model: Width of the underlying embedding matrix; `forward` outputs 2 * d_model channels.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints or
                numeric strings (e.g. after a JSON round-trip of the config).
        """
        super().__init__()
        # Order the lookup table by token id, so entry i is the complement of token i regardless of the
        # dict's insertion order.
        complement_ids = [complement_map[key] for key in sorted(complement_map, key=int)]
        self.register_buffer(
            "complement_map",
            torch.tensor(complement_ids, dtype=torch.long)
        )
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    @property
    def weight(self):
        """Embedding weights."""
        return self.embedding.weight

    def set_weight(self, value):
        """Set embedding weights."""
        self.embedding.weight = value

    def rc(self, x):
        """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids.

        Args:
            x: (batch_size, seq_len) token ids.
        """
        return torch.gather(
            self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
            dim=1,
            index=torch.flip(x, dims=[-1])
        )

    def forward(self, input_ids):
        """Reverse-complement equivariant forward pass.

        This embedding module doubles the output dimensionality to support reverse-complement equivariance.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
        Returns:
            Embedding tensor of shape (batch_size, seq_len, d_model * 2)
        """
        fwd_out = self.embedding(input_ids)
        rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])

        return torch.cat([fwd_out, rc_out], dim=-1)
class RCPSWrapper(nn.Module):
    """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.

    The wrapped submodule is applied to the first channel half as-is. It is also applied to the second
    channel half after that half is flipped along the length and channel dimensions; that result is
    flipped back.

    See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
    Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
    """

    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule

    @staticmethod
    def rc(x):
        """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
        return x.flip(dims=[-2, -1])

    def forward(self, x, **kwargs):
        """Reverse-complement equivariant forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels); `kwargs` are passed to both submodule calls.
        Returns:
            Output tensor of shape (batch_size, seq_len, channels)
        """
        half = x.shape[-1] // 2
        forward_half = x[..., :half]
        reverse_half = x[..., half:]

        forward_result = self.submodule(forward_half, **kwargs)
        reverse_result = self.submodule(self.rc(reverse_half), **kwargs)

        return torch.cat((forward_result, self.rc(reverse_result)), dim=-1)
class RCPSAddNormWrapper(RCPSWrapper):
    """RC equivariant AddNorm layer.

    Adds the residual (if any) to `x` half by half, then applies the wrapped norm to the first half
    and to the length/channel-flipped second half. Inputs are cast to the norm's weight dtype.
    """

    def __init__(self, submodule: nn.Module):
        super().__init__(submodule)

    def forward(self, x, residual=None, prenorm=False):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
            prenorm: Whether to return residual.

        Returns:
            The normalized tensor, or `(normalized, residual)` when `prenorm` is set. Without an incoming
            residual, the returned residual is `x` itself.
        """
        half = x.shape[-1] // 2
        norm_dtype = self.submodule.weight.dtype

        if residual is not None:
            summed_fwd = x[..., :half] + residual[..., :half]
            summed_rc = self.rc(x[..., half:]) + self.rc(residual[..., half:])
            normed_fwd = self.submodule(summed_fwd.to(dtype=norm_dtype))
            normed_rc = self.submodule(summed_rc.to(dtype=norm_dtype))
            residual = torch.cat([summed_fwd, self.rc(summed_rc)], dim=-1)
        else:
            residual = x
            normed_fwd = self.submodule(x[..., :half].to(dtype=norm_dtype))
            normed_rc = self.submodule(self.rc(x[..., half:]).to(dtype=norm_dtype))

        out = torch.cat([normed_fwd, self.rc(normed_rc)], dim=-1)
        if prenorm:
            return out, residual
        return out
class RCPSMambaBlock(nn.Module):
    """RC-equivariant Mamba block: (residual add +) norm on each channel half, then an RCPSWrapper-wrapped mixer.

    Hidden states carry 2 * dim channels; the norm and mixer are built for `dim` channels and applied per half.
    """
    def __init__(
            self,
            dim,
            mixer_cls,
            norm_cls=nn.LayerNorm,
            fused_add_norm=False,
            residual_in_fp32=False,
            device=None,  # Keep for consistency with original Mamba Block
            dtype=None,  # Keep for consistency with original Mamba Block
    ):
        """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

        Args:
            dim: Per-half width used to build the mixer (`mixer_cls(dim)`) and norm (`norm_cls(dim)`).
            mixer_cls: Factory for the mixer; the result is wrapped in RCPSWrapper.
            norm_cls: Factory for the norm.
            fused_add_norm: Use the Triton fused add+norm kernels instead of RCPSAddNormWrapper.
            residual_in_fp32: Keep the residual stream in float32.
            device, dtype: Unused; accepted only for signature compatibility.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = RCPSWrapper(mixer_cls(dim))
        norm_f = norm_cls(dim)
        # The fused path calls the kernels with the raw norm's weight/bias/eps; otherwise wrap it for RC equivariance.
        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual)).
            inference_params: inference parameters for mixer.

        Returns:
            `(hidden_states, residual)`: the mixer output and the updated residual stream.
        """
        if not self.fused_add_norm:
            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # NOTE(review): this branch normalizes the *second* channel half as "fwd" and the flipped *first*
            # half as "rc". The non-fused branch (RCPSAddNormWrapper) and the fused final norm in
            # CaduceusMixerModel use the first half as "fwd". The result appears to remain RC-equivariant, but
            # the halves are swapped relative to the non-fused branch. The same weights may therefore give
            # different outputs with fused_add_norm on vs. off. Changing this could alter models trained with
            # this path — confirm intent before touching it.
            hidden_states_fwd, residual_fwd = fused_add_norm_fn(
                hidden_states[..., hidden_states.shape[-1] // 2:],
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states_rc, residual_rc = fused_add_norm_fn(
                hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            # Un-flip the rc results before concatenating the halves back together.
            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for mixer.

        Keep for compatibility with original Mamba Block.
        """
        # NOTE(review): `self.mixer` is an RCPSWrapper, which (as defined in this file) has no
        # `allocate_inference_cache`; this call likely raises AttributeError — confirm before relying on it.
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class RCPSLMHead(nn.Module):
    """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""

    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
        """
        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
        equivariant, i.e. 0.5 times the actual input dim.

        Args:
            true_dim: Width of each channel half of the input.
            vocab_size: Number of output logits.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints or
                numeric strings (e.g. after a JSON round-trip of the config).
            **factory_kwargs: Forwarded to the underlying `nn.Linear`.
        """
        super().__init__()
        # Order the lookup table by token id, so entry i is the complement of token i regardless of the
        # dict's insertion order.
        complement_ids = [complement_map[key] for key in sorted(complement_map, key=int)]
        self.register_buffer(
            "complement_map",
            torch.tensor(complement_ids, dtype=torch.long)
        )

        self.true_dim = true_dim
        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)

    @property
    def weight(self):
        """LM head weights."""
        return self.lm_head.weight

    def set_weight(self, value):
        """Set LM head weights."""
        self.lm_head.weight = value

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.

        Returns:
            Logits of shape (batch_size, seq_len, vocab_size): the first half projected by the weights plus
            the channel-flipped second half projected by the complement-permuted weights.
        """
        n_channels = x.shape[-1]
        assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
        fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
        rc_logits = F.linear(
            torch.flip(x[..., n_channels // 2:], dims=[-1]),
            self.weight[self.complement_map, :],
            bias=self.lm_head.bias
        )
        return fwd_logits + rc_logits
================================================
FILE: caduceus/tests/test_rcps.py
================================================
"""Tests for RCPS modules.
"""
import pytest
import torch
from torch import nn
from torch.nn import functional as F
try: # Legacy mambav1 file structure
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
try: # mambav2 file structure
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from caduceus.modeling_rcps import (
RCPSEmbedding, RCPSAddNormWrapper, RCPSLMHead, RCPSWrapper
)
from caduceus.modeling_caduceus import (
CaduceusConfig, CaduceusMixerModel, CaduceusForMaskedLM, create_block
)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [512])
@pytest.mark.parametrize("d_model", [256])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_rcps_embedding(batch_size, seq_len, d_model, dtype):
    """Embedding the reverse-complemented ids equals the RC (flip of length and channels) of the embedding."""
    device = torch.device("cpu")
    # Tolerances per dtype
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    elif dtype == torch.float32:
        rtol, atol = 6e-4, 2e-3
    else:
        rtol, atol = 3e-3, 5e-3

    torch.random.manual_seed(0)

    # Token ids and their complements; non-nucleotide tokens complement to themselves.
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    base_pairs = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        token_id: str_to_id[base_pairs[token]] if token in base_pairs else token_id
        for token, token_id in str_to_id.items()
    }

    # Pad the vocabulary up to a multiple of 8; padded ids are self-complementary.
    pad_multiple = 8
    vocab_size = 12
    vocab_size += (-vocab_size) % pad_multiple
    complement_map.update({token_id: token_id for token_id in range(len(complement_map), vocab_size)})

    # Random sequences and their reverse complements
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    complement_lookup = torch.tensor([complement_map[i] for i in range(vocab_size)], device=device)
    rc_input_ids = complement_lookup[torch.flip(input_ids, dims=[-1])]

    factory_kwargs = {"device": device, "dtype": dtype}
    embedding = RCPSEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        complement_map=complement_map,
        **factory_kwargs
    ).to(device)

    out_embed = embedding(input_ids)
    rc_out_embed = torch.flip(embedding(rc_input_ids), dims=[-2, -1])

    expected_shape = (batch_size, seq_len, d_model * 2)
    assert tuple(out_embed.size()) == expected_shape
    assert tuple(rc_out_embed.size()) == expected_shape
    assert torch.allclose(out_embed.detach(), rc_out_embed.detach(), rtol=rtol, atol=atol)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_wrapper(batch_size, seq_len, d_model, dtype):
    """RCPSWrapper around an arbitrary module is RC-equivariant: f(rc(x)) == rc(f(x)).

    Skipped (rather than erroring) on machines without CUDA.
    """
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2

    # Set seed
    torch.random.manual_seed(0)

    # Generate random sequence with 2 * d_model channels
    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
    rc_x = torch.flip(x, dims=[-2, -1])

    factory_kwargs = {"device": device, "dtype": dtype}
    module = nn.Sequential(
        nn.Linear(d_model, d_model, bias=False, **factory_kwargs),
        nn.ReLU(),
        nn.Linear(d_model, d_model*2, bias=True, **factory_kwargs),
        nn.ReLU(),
        nn.Linear(d_model * 2, d_model, bias=True, **factory_kwargs)
    )

    # Test RC equivariance of wrapper
    rcps_module = RCPSWrapper(module).to(device)
    out = rcps_module(x)
    rc_out = torch.flip(rcps_module(rc_x), dims=[-2, -1])
    assert out.size() == x.size()
    assert rc_out.size() == x.size()
    assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("prenorm", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dtype):
    """RCPSAddNormWrapper around RMSNorm must be reverse-complement equivariant.

    With ``prenorm`` the wrapper returns a tuple and every element is checked;
    otherwise a single tensor is returned and checked.
    """
    cuda = torch.device("cuda")
    # Per-dtype (rtol, atol); any dtype not listed uses the fp16 tolerances.
    tolerance_table = {
        torch.float32: (6e-4, 2e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }
    rtol, atol = tolerance_table.get(dtype, (3e-3, 5e-3))

    torch.random.manual_seed(0)
    fwd_x = torch.randn(batch_size, seq_len, d_model * 2, device=cuda, dtype=dtype)
    rev_x = torch.flip(fwd_x, dims=[-2, -1])

    norm = RMSNorm(d_model, eps=1e-5, device=cuda, dtype=dtype)
    wrapped = RCPSAddNormWrapper(norm).to(cuda)

    fwd_out = wrapped(fwd_x, prenorm=prenorm)
    rev_out = wrapped(rev_x, prenorm=prenorm)

    # Pair every forward output with its RC counterpart re-aligned to the forward frame.
    if prenorm:
        pairs = list(zip(fwd_out, [torch.flip(t, dims=[-2, -1]) for t in rev_out]))
    else:
        pairs = [(fwd_out, torch.flip(rev_out, dims=[-2, -1]))]

    for fwd, rev in pairs:
        assert fwd.size() == fwd_x.size()
        assert rev.size() == fwd_x.size()
        assert torch.allclose(fwd.detach(), rev.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirectional, fused_add_norm, dtype):
    """An RCPS Mamba block from ``create_block`` must be reverse-complement equivariant.

    Equivariance is checked twice: once with no incoming residual, and once with
    the input itself (and its RC for the RC run) as the residual.
    """
    cuda = torch.device("cuda")
    # Per-dtype (rtol, atol); any dtype not listed uses the fp16 tolerances.
    tolerance_table = {
        torch.float32: (6e-4, 2e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }
    rtol, atol = tolerance_table.get(dtype, (3e-3, 5e-3))

    torch.random.manual_seed(0)
    fwd_x = torch.randn(batch_size, seq_len, d_model * 2, device=cuda, dtype=dtype)
    rev_x = torch.flip(fwd_x, dims=[-2, -1])

    mamba_block = create_block(
        d_model,
        ssm_cfg={
            "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1,
            "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False,
            "use_fast_path": True,
        },
        norm_epsilon=1e-5,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=fused_add_norm,
        layer_idx=0,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=True,
        device=cuda,
        dtype=dtype,
    )

    # (forward residual, RC residual): first without a residual stream, then with one.
    for fwd_residual, rev_residual in ((None, None), (fwd_x, rev_x)):
        fwd_outputs = mamba_block(fwd_x, residual=fwd_residual)
        rev_outputs = [torch.flip(t, dims=[-2, -1]) for t in mamba_block(rev_x, residual=rev_residual)]
        for fwd, rev in zip(fwd_outputs, rev_outputs):
            assert fwd.size() == fwd_x.size()
            assert rev.size() == fwd_x.size()
            assert torch.allclose(fwd.detach(), rev.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1, 1024, 2048])
@pytest.mark.parametrize("d_model", [2, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_lm_head(batch_size, seq_len, d_model, dtype):
    """RCPSLMHead logits must be reverse-complement equivariant.

    Logits for the RC input, after complementing the vocabulary axis and
    reversing the sequence axis, must match the forward logits. The same holds
    for the softmax probabilities.
    """
    cuda = torch.device("cuda")
    # Per-dtype (rtol, atol); any dtype not listed uses the fp16 tolerances.
    tolerance_table = {
        torch.float32: (6e-4, 2e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }
    rtol, atol = tolerance_table.get(dtype, (3e-3, 5e-3))

    torch.random.manual_seed(0)

    token_ids = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    base_pairs = {"A": "T", "C": "G", "G": "C", "T": "A"}
    # Token-id complement map; tokens without a base-pair partner map to themselves.
    complement_map = {
        idx: (token_ids[base_pairs[tok]] if tok in base_pairs else idx)
        for tok, idx in token_ids.items()
    }
    vocab_size = 12
    # Extra ids beyond the named tokens are their own complement.
    complement_map.update({i: i for i in range(len(complement_map), vocab_size)})

    # Head is built before the input so weight init consumes the RNG first.
    lm_head = RCPSLMHead(
        complement_map=complement_map,
        vocab_size=vocab_size,
        true_dim=d_model,
        device=cuda,
        dtype=dtype,
    )

    fwd_x = torch.randn(batch_size, seq_len, d_model * 2, device=cuda, dtype=dtype)
    rev_x = torch.flip(fwd_x, dims=[-2, -1])

    fwd_logits = lm_head(fwd_x)
    rev_logits = lm_head(rev_x)

    expected_shape = (batch_size, seq_len, vocab_size)
    assert tuple(fwd_logits.size()) == expected_shape
    assert tuple(rev_logits.size()) == expected_shape

    def _to_forward_frame(scores):
        # Complement the vocabulary axis, then reverse the sequence axis.
        return torch.flip(scores.detach()[..., lm_head.complement_map], dims=[1])

    assert torch.allclose(fwd_logits.detach(), _to_forward_frame(rev_logits), rtol=rtol, atol=atol)
    assert torch.allclose(
        F.softmax(fwd_logits, dim=-1).detach(),
        _to_forward_frame(F.softmax(rev_logits, dim=-1)),
        rtol=rtol,
        atol=atol,
    )
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 2, 3])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fused_add_norm,
                       bidirectional, bidirectional_weight_tie):
    """The RCPS CaduceusMixerModel hidden states must be reverse-complement equivariant.

    Hidden states for the RC token sequence, flipped along the sequence and
    channel axes, must match the forward hidden states. The hidden width is
    2 * d_model.
    """
    cuda = torch.device("cuda")
    # Per-dtype (rtol, atol); any dtype not listed uses the fp16 tolerances.
    tolerance_table = {
        torch.float32: (6e-4, 2e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }
    rtol, atol = tolerance_table.get(dtype, (3e-3, 5e-3))

    torch.random.manual_seed(0)

    token_ids = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    base_pairs = {"A": "T", "C": "G", "G": "C", "T": "A"}
    # Token-id complement map; tokens without a base-pair partner map to themselves.
    complement_map = {
        idx: (token_ids[base_pairs[tok]] if tok in base_pairs else idx)
        for tok, idx in token_ids.items()
    }

    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg={
            "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1,
            "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False,
            "use_fast_path": True,
        },
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=fused_add_norm,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg={"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1},
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    backbone = CaduceusMixerModel(config, device=cuda, dtype=dtype).to(cuda)

    # Ids 1..6 only ([CLS] is excluded); the RC is built on CPU because apply_ is CPU-only.
    input_ids = torch.randint(low=1, high=len(token_ids), size=(batch_size, seq_len), device=cuda)
    rc_input_ids = (
        torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(cuda)
    )

    fwd_hidden = backbone(input_ids)[0]
    rev_hidden = backbone(rc_input_ids)[0]
    expected_shape = (batch_size, seq_len, d_model * 2)

    if isinstance(rev_hidden, tuple):
        rev_aligned = tuple(torch.flip(t, dims=[1, 2]) for t in rev_hidden)
        for fwd, rev in zip(fwd_hidden, rev_aligned):
            assert fwd.size() == expected_shape
            assert rev.size() == expected_shape
            assert torch.allclose(fwd.detach(), rev.detach(), rtol=rtol, atol=atol)
    else:
        # RCPS doubles the hidden width.
        assert tuple(fwd_hidden.size()) == expected_shape
        assert tuple(rev_hidden.size()) == expected_shape
        assert torch.allclose(
            fwd_hidden.detach(), torch.flip(rev_hidden.detach(), dims=[1, 2]), rtol=rtol, atol=atol
        )
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 3, 4])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
    """CaduceusForMaskedLM logits must be reverse-complement equivariant.

    Logits for the RC token sequence, after complementing the vocabulary axis
    and reversing the sequence axis, must match the forward logits. The same
    holds for the softmax probabilities.
    """
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # Set seed
    torch.random.manual_seed(0)
    # Define complement map over token ids; tokens without a base-pair partner map to themselves
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map else v
        for k, v in str_to_id.items()
    }
    # Setup CaduceusConfig
    initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
    ssm_cfg = {
        "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
        "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
    }
    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg=ssm_cfg,
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg=initializer_cfg,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    factory_kwargs = {"device": device, "dtype": dtype}
    # Instantiate model
    mamba_lm = CaduceusForMaskedLM(
        config=config,
        **factory_kwargs,
    ).to(device)
    # Generate random sequences (ids 1..6); the RC is built on CPU because Tensor.apply_ is CPU-only
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)
    # Test RC equivariance of the LM
    out = mamba_lm(input_ids)
    rc_out = mamba_lm(rc_input_ids)
    # Expected logits width: vocab_size rounded up to a multiple of pad_vocab_size_multiple.
    # Computed in a local instead of mutating `config`, which was handed to the model under
    # test and may still be referenced by it.
    expected_vocab_size = config.vocab_size
    if expected_vocab_size % config.pad_vocab_size_multiple != 0:
        expected_vocab_size += config.pad_vocab_size_multiple - (expected_vocab_size % config.pad_vocab_size_multiple)
    assert tuple(out.logits.size()) == (batch_size, seq_len, expected_vocab_size)
    assert tuple(rc_out.logits.size()) == (batch_size, seq_len, expected_vocab_size)
    # Re-align RC logits to the forward frame: complement the vocab axis, reverse the sequence axis
    assert torch.allclose(
        out.logits.detach(),
        torch.flip(rc_out.logits.detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )
    assert torch.allclose(
        F.softmax(out.logits, dim=-1).detach(),
        torch.flip(F.softmax(rc_out.logits, dim=-1).detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("n_layer", [2])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("bidirectional_weight_tie", [True])
def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
    """Collapsing the RCPS backbone output to d_model channels is RC invariant.

    The collapse averages the forward half of the hidden state with the RC half
    flipped along the sequence and channel axes. Applied to a sequence and to its
    reverse complement, it must give the same tensor.
    """
    cuda = torch.device("cuda")
    # Per-dtype (rtol, atol); any dtype not listed uses the fp16 tolerances.
    tolerance_table = {
        torch.float32: (6e-4, 2e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }
    rtol, atol = tolerance_table.get(dtype, (3e-3, 5e-3))

    torch.random.manual_seed(0)

    token_ids = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    base_pairs = {"A": "T", "C": "G", "G": "C", "T": "A"}
    # Token-id complement map; tokens without a base-pair partner map to themselves.
    complement_map = {
        idx: (token_ids[base_pairs[tok]] if tok in base_pairs else idx)
        for tok, idx in token_ids.items()
    }

    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg={
            "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1,
            "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False,
            "use_fast_path": True,
        },
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg={"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1},
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    backbone = CaduceusMixerModel(config, device=cuda, dtype=dtype).to(cuda)

    # Ids 1..6 only ([CLS] is excluded); the RC is built on CPU because apply_ is CPU-only.
    input_ids = torch.randint(low=1, high=len(token_ids), size=(batch_size, seq_len), device=cuda)
    rc_input_ids = (
        torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(cuda)
    )

    def _collapse(hidden):
        # Average the forward half with the RC half re-aligned to the forward frame.
        return (hidden[..., :d_model] + torch.flip(hidden[..., d_model:], dims=[1, 2])) / 2

    fwd_collapsed = _collapse(backbone(input_ids)[0])
    rev_collapsed = _collapse(backbone(rc_input_ids)[0])

    # After collapsing, the hidden width is back to d_model.
    expected_shape = (batch_size, seq_len, d_model)
    assert tuple(fwd_collapsed.size()) == expected_shape
    assert tuple(rev_collapsed.size()) == expected_shape
    assert torch.allclose(fwd_collapsed.detach(), rev_collapsed.detach(), rtol=rtol, atol=atol)
================================================
FILE: caduceus/tokenization_caduceus.py
================================================
"""Character tokenizer for Hugging Face.
"""
from typing import List, Optional, Dict, Sequence, Tuple
from transformers import PreTrainedTokenizer
class CaduceusTokenizer(PreTrainedTokenizer):
    """Character-level DNA tokenizer with a token-id reverse-complement map.

    Every character of the input becomes one token. Special tokens occupy ids
    0-6 and the configured characters follow from id 7. The ``complement_map``
    property maps each token id to the id of its complement; ids without a
    complement map to themselves.
    """

    model_input_names = ["input_ids"]

    def __init__(self,
                 model_max_length: int,
                 characters: Sequence[str] = ("A", "C", "G", "T", "N"),
                 complement_map: Optional[Dict[str, str]] = None,
                 bos_token="[BOS]",
                 eos_token="[SEP]",
                 sep_token="[SEP]",
                 cls_token="[CLS]",
                 pad_token="[PAD]",
                 mask_token="[MASK]",
                 unk_token="[UNK]",
                 **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py

        Args:
            model_max_length (int): Model maximum sequence length.
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following is a list of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            complement_map (Optional[Dict[str, str]]): Dictionary with string complements
                for each character. Defaults to A<->T, C<->G, N->N. For every vocabulary
                entry that has a key here, the complement must also be in the vocabulary,
                otherwise a KeyError is raised.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``. ``add_prefix_space``
                defaults to False and ``padding_side`` defaults to "left".
        """
        if complement_map is None:
            complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
        self.characters = characters
        self.model_max_length = model_max_length
        # Fixed vocabulary: special tokens first, then one id per character.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "left")

        # Map each token id to its complement's id. Tokens without an entry in
        # `complement_map` (e.g. the special tokens) map to themselves.
        self._complement_map = {}
        for k, v in self._vocab_str_to_int.items():
            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map else v
            self._complement_map[self._vocab_str_to_int[k]] = complement_id

        # NOTE(review): the vocab is built before super().__init__(), presumably because the
        # base class may consult it during initialization — confirm before reordering.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the fixed vocabulary (special tokens + characters)."""
        return len(self._vocab_str_to_int)

    @property
    def complement_map(self) -> Dict[int, int]:
        """Mapping from each token id to the id of its complement."""
        return self._complement_map

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` into single characters, upper-casing them first."""
        return list(text.upper())  # Convert all base pairs to uppercase

    def _convert_token_to_id(self, token: str) -> int:
        """Return ``token``'s id, or the [UNK] id if the token is not in the vocabulary."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or the unk token for ids outside the vocabulary.

        Without the fallback, decoding an id beyond the fixed vocabulary raises
        KeyError. Such ids can come, for example, from a model whose vocabulary
        was padded beyond ``vocab_size``.
        """
        return self._vocab_int_to_str.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Concatenate character tokens back into a single string."""
        return "".join(tokens)  # Note: this operation has lost info about which base pairs were originally lowercase

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 0/1 mask marking special tokens.

        When special tokens are not yet present, the mask matches the layout produced by
        ``build_inputs_with_special_tokens``: each sequence followed by one [SEP] token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append a [SEP] token after each sequence; no [CLS] token is prepended."""
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token-to-id vocabulary.

        A copy is returned so that callers cannot mutate the tokenizer's internal mapping.
        """
        return dict(self._vocab_str_to_int)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
        """Write nothing: the vocabulary is fixed and needs no vocab file."""
        return ()
================================================
FILE: caduceus_env.yml
================================================
name: caduceus_env
channels:
- pytorch
- anaconda
- nvidia
- defaults
dependencies:
- cuda-nvcc=11.7.99
- pip=23.3.1
- python=3.8
- pytorch=2.2.0
- torchaudio=2.2.0
- torchdata=0.7.1
- torchmetrics=1.2.1
- torchtext=0.17.0
- torchvision=0.17.0
- pytorch-cuda=12.1
- pip:
- biopython==1.81
- datasets==2.15.0
- einops==0.7.0
- enformer-pytorch==0.8.8
- fsspec==2023.10.0
- genomic-benchmarks==0.0.9
- git-lfs==1.6
- h5py==3.10.0
- huggingface-hub==0.24.7
- hydra-core==1.3.2
- ipdb==0.13.13
- matplotlib==3.7.4
- notebook==7.1.1
- nvitop==1.3.2
- omegaconf==2.3.0
- pandas==2.0.3
- pyfaidx==0.8.1.1
- pysam==0.22.0
- pytest==8.0.2
- pytorch-lightning==1.8.6
- rich==13.7.0
- seaborn==0.13.2
- scikit-learn==1.3.2
- timm==0.9.16
- tqdm==4.66.1
- transformers==4.38.1
- triton==2.2.0
- wandb==0.13.5
- flash-attn==2.5.6
- causal-conv1d===1.2.0.post2
- mamba-ssm==1.2.0.post1
================================================
FILE: configs/callbacks/base.yaml
================================================
learning_rate_monitor:
# _target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: ${train.interval}
timer:
# _target_: callbacks.timer.Timer
step: True
inter_step: False
epoch: True
val: True
params:
# _target_: callbacks.params.ParamsLog
total: True
trainable: True
fixed: True
================================================
FILE: configs/callbacks/checkpoint.yaml
================================================
model_checkpoint:
monitor: ${train.monitor} # name of the logged metric which determines when model is improving
mode: ${train.mode} # can be "max" or "min"
save_top_k: 1 # save k best models (determined by above metric)
save_last: False # True = additionally always save model from last epoch
dirpath: "checkpoints/"
filename: ${train.monitor}
auto_insert_metric_name: False
verbose: True
model_checkpoint_every_n_steps:
monitor: train/loss # name of the logged metric which determines when model is improving
mode: min # can be "max" or "min"
save_top_k: 0 # Do not save any "best" models; this callback is being used to save every n train steps
save_last: True # additionally always save model from last epoch
dirpath: "checkpoints/"
filename: train/loss
auto_insert_metric_name: False
verbose: True
every_n_train_steps: 100
#model_checkpoint_every_epoch:
# monitor: trainer/epoch # name of the logged metric which determines when model is improving
# mode: max # can be "max" or "min"
# save_top_k: 1 # Do not save any "best" models; this callback is being used to save every n train steps
# save_last: False # additionally always save model from last epoch
# dirpath: "checkpoints/"
# filename: null
# auto_insert_metric_name: False
# verbose: True
# every_n_epochs: 1
================================================
FILE: configs/callbacks/gpu_affinity.yaml
================================================
gpu_affinity:
_name_: gpu_affinity
================================================
FILE: configs/callbacks/rich.yaml
================================================
rich_model_summary:
max_depth: 2
rich_progress_bar:
refresh_rate_per_second: 1.0
================================================
FILE: configs/callbacks/val_every_n_global_steps.yaml
================================================
val_every_n_global_steps:
every_n: 10000
================================================
FILE: configs/callbacks/wandb.yaml
================================================
defaults:
- default
watch_model:
_target_: src.callbacks.wandb_callbacks.WatchModel
log: "all"
log_freq: 100
upload_code_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
code_dir: ${work_dir}/src
upload_ckpts_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
ckpt_dir: "checkpoints/"
upload_best_only: True
log_f1_precision_recall_heatmap:
_target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
log_confusion_matrix:
_target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
log_image_predictions:
_target_: src.callbacks.wandb_callbacks.LogImagePredictions
num_samples: 8
================================================
FILE: configs/config.yaml
================================================
# @package _global_
defaults:
- _self_
- experiment: ???
# - model: ??? # Model backbone
# - pipeline: ??? # Specifies collection of configs, equivalent to next 5 lines
# Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)
# # - loader: default # Dataloader (e.g. handles batches)
# # - dataset: cifar # Defines the data (x and y pairs)
# # - task: multiclass_classification # Defines loss and metrics
# # - encoder: null # Interface between data and model
# # - decoder: null # Interface between model and targets
# Additional arguments used to configure the training loop
# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer
train:
seed: 0
# These three options are used by callbacks (checkpoint, monitor) and scheduler
# Most of them are task dependent and are set by the pipeline
interval: ??? # Should be specified by scheduler. Also used by LR monitor
monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
ema: 0.0 # Moving average model for validation
test: True # Test after training
debug: False # Special settings to make debugging more convenient
ignore_warnings: False # Disable python warnings
optimizer_param_grouping:
bias_weight_decay: False
normalization_weight_decay: False
# These control state passing between batches
state:
mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context
n_context_eval: ${.n_context} # Context at evaluation time
# Convenience keys to allow grouping runs
ckpt: checkpoints/last.ckpt # Resume training
disable_dataset: False # Disable dataset loading
validate_at_start: false
pretrained_model_path: null # Path to pretrained model
pretrained_model_strict_load: true # Whether to require the pretrained checkpoint to match the model exactly; set false to allow loading a partially compatible checkpoint
pretrained_model_state_hook: # Hook called on the loaded model's state_dict
_name_: null
post_init_hook: # After initializing model, call method on model
_name_: null
layer_decay: # Used for ImageNet finetuning
_name_: null
decay: 0.7
# We primarily use wandb so this is moved to top level in the config for convenience
# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging
# If other loggers are added, it would make sense to put this one level lower under train/ or logger/
wandb:
project: dna
group: ""
job_type: training
mode: online # choices=['online', 'offline', 'disabled']
name: null
save_dir: "."
id: ${.name} # pass correct id to resume experiment!
# Below options should not need to be specified
# entity: "" # set to name of your wandb team or just remove it
# log_model: False
# prefix: ""
# job_type: "train"
# tags: []
hydra:
run:
dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}
job:
chdir: true
================================================
FILE: configs/dataset/genomic_benchmark.yaml
================================================
_name_: genomic_benchmark
train_val_split_seed: ${train.seed} # Used for train/validation splitting
dataset_name: dummy_mouse_enhancers_ensembl
dest_path: null
max_length: ${.${.dataset_name}.max_length}
max_length_val: ${.max_length}
max_length_test: ${.max_length}
d_output: ${.${.dataset_name}.classes}
use_padding: True
padding_side: 'left'
add_eos: False
batch_size: 128
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true # set this as default!
# these are used to find the right attributes automatically for each dataset
dummy_mouse_enhancers_ensembl:
train_len: 1210
classes: 2
max_length: 1024
demo_coding_vs_intergenomic_seqs:
train_len: 100_000
classes: 2
max_length: 200
demo_human_or_worm:
train_len: 100_000
classes: 2
max_length: 200
human_enhancers_cohn:
train_len: 27791
classes: 2
max_length: 500
human_enhancers_ensembl:
train_len: 154842
classes: 2
max_length: 512
human_ensembl_regulatory:
train_len: 289061
classes: 3
max_length: 512
human_nontata_promoters:
train_len: 36131
classes: 2
max_length: 251
human_ocr_ensembl:
train_len: 174756
classes: 2
max_length: 512
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
================================================
FILE: configs/dataset/hg38.yaml
================================================
_name_: hg38
bed_file: null
fasta_file: null
dataset_name: hg38
tokenizer_name: null
cache_dir: null
max_length: 1024
add_eos: True
batch_size: 8 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 4 # For preprocessing only
shuffle: True
__train_len: 34021
__l_max: ${.max_length}
================================================
FILE: configs/dataset/nucleotide_transformer.yaml
================================================
_name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets
train_val_split_seed: ${train.seed} # Used for train/validation splitting
dataset_name: enhancers # this specifies which dataset in nuc trx
dest_path: null # path to overall nuc trx datasets
max_length: ${.${.dataset_name}.max_length}
d_output: ${.${.dataset_name}.classes}
use_padding: True
padding_side: left
add_eos: False
batch_size: 256
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true # set this as default!
metric: ${.${.dataset_name}.metric}
# these are used to find the right attributes automatically for each dataset
enhancers:
train_len: 14968
classes: 2
max_length: 200
metric: mcc
enhancers_types:
train_len: 14968
classes: 3
max_length: 200
metric: mcc
H3:
train_len: 13468
classes: 2
max_length: 500
metric: mcc
H3K4me1:
train_len: 28509
classes: 2
max_length: 500
metric: mcc
H3K4me2:
train_len: 27614
classes: 2
max_length: 500
metric: mcc
H3K4me3:
train_len: 33119
classes: 2
max_length: 500
metric: mcc
H3K9ac:
train_len: 25003
classes: 2
max_length: 500
metric: mcc
H3K14ac:
train_len: 29743
classes: 2
max_length: 500
metric: mcc
H3K36me3:
train_len: 31392
classes: 2
max_length: 500
metric: mcc
H3K79me3:
train_len: 25953
classes: 2
max_length: 500
metric: mcc
H4:
train_len: 13140
classes: 2
max_length: 500
metric: mcc
H4ac:
train_len: 30685
classes: 2
max_length: 500
metric: mcc
promoter_all:
train_len: 53276
classes: 2
max_length: 300
metric: f1_binary
promoter_no_tata:
train_len: 47767
classes: 2
max_length: 300
metric: f1_binary
promoter_tata:
train_len: 5517
classes: 2
max_length: 300
metric: f1_binary
splice_sites_acceptors:
train_len: 19961
classes: 2
max_length: 600
metric: f1_binary
splice_sites_all:
train_len: 27000
classes: 3
max_length: 400
metric: accuracy
splice_sites_donors:
train_len: 19775
classes: 2
max_length: 600
metric: f1_binary
# name maxlen classes samples metric
# enhancers 200 2 14968 MCC
# enhancers_types 200 3 14968 MCC
# H3 500 2 13468 MCC
# H3K4me1 500 2 28509 MCC
# H3K4me2 500 2 27614 MCC
# H3K4me3 500 2 33119 MCC
# H3K9ac 500 2 25003 MCC
# H3K14ac 500 2 29743 MCC
# H3K36me3 500 2 31392 MCC
# H3K79me3 500 2 25953 MCC
# H4 500 2 13140 MCC
# H4ac 500 2 30685 MCC
# promoter_all 300 2 53276 F1
# promoter_no_tata 300 2 47759 F1
# promoter_tata 300 2 5517 F1
# splice_sites_acceptors 600 2 19961 F1
# splice_sites_all 400 3 27000 Accuracy
# splice_sites_donors 600 2 19775 F1
================================================
FILE: configs/experiment/hg38/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
- /pipeline: genomic_benchmark
- /model: ???
- override /scheduler: cosine_warmup_timm
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
task:
loss:
_name_: cross_entropy
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
model:
_name_: dna_embedding
dataset:
# optional, default is max_length
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
# COSINE TIMM
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 6e-4
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true # test only at the end of training
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
# for loading backbone and not head, requires both of these flags below
pretrained_model_path: ???
pretrained_model_state_hook:
_name_: load_backbone
freeze_backbone: false
================================================
FILE: configs/experiment/hg38/genomic_benchmark_cnn.yaml
================================================
# @package _global_
defaults:
- /model: genomics_benchmark_cnn
- /pipeline: genomic_benchmark
- override /scheduler: cosine_warmup_timm
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
task:
loss:
_name_: cross_entropy
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
encoder: id
decoder: id
dataset:
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 6e-4
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
================================================
FILE: configs/experiment/hg38/hg38.yaml
================================================
# @package _global_
defaults:
- /pipeline: hg38
- /model: ??? # Specify a model, e.g. model=mamba or model=hyena
- override /scheduler: cosine_warmup_timm
task:
_name_: lm
loss:
_name_: cross_entropy
ignore_index: 4
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: null
max_steps: 10000
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
limit_val_batches: 0.125
dataset:
batch_size: ${eval:1024//${trainer.devices}}
max_length: 1024
# optional, default is max_length
max_length_val: ${dataset.max_length}
max_length_test: ${dataset.max_length}
tokenizer_name: char
pad_max_length: null # needed for bpe tokenizer
add_eos: true
rc_aug: false
num_workers: 12
use_fixed_len_val: false # placing a fixed length val here, but it's really the test
mlm: false
mlm_probability: 0.0
scheduler:
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-6
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-4
optimizer:
lr: 6e-4
weight_decay: 0.1
betas: [0.9, 0.95]
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: 256 # affects gradient accumulation and the scheduler; must be set properly
================================================
FILE: configs/experiment/hg38/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
- /pipeline: nucleotide_transformer
- /model: ???
- override /scheduler: cosine_warmup_timm
model:
_name_: dna_embedding
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
dataset:
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 1e-3
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true # test only at the end of training
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
# for loading backbone and not head, requires both of these flags below
pretrained_model_path: ???
pretrained_model_state_hook:
_name_: load_backbone
freeze_backbone: false
================================================
FILE: configs/loader/default.yaml
================================================
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
drop_last: True
================================================
FILE: configs/model/caduceus.yaml
================================================
# Caduceus model config: the original MambaConfig fields plus Caduceus-specific bidirectional / RCPS options
_name_: caduceus_lm
config:
_target_: caduceus.configuration_caduceus.CaduceusConfig
# From original MambaConfig
d_model: 128
n_layer: 2
vocab_size: 12
ssm_cfg:
d_state: 16
d_conv: 4
expand: 2
dt_rank: "auto"
dt_min: 0.001
dt_max: 0.1
dt_init: "random"
dt_scale: 1.0
dt_init_floor: 1e-4
conv_bias: true
bias: false
use_fast_path: true
rms_norm: true
fused_add_norm: true
residual_in_fp32: false
pad_vocab_size_multiple: 8
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: 1e-5
# Used in init_weights
initializer_cfg:
initializer_range: 0.02
rescale_prenorm_residual: true
n_residuals_per_layer: 1
# Caduceus-specific params
# NOTE: no trailing comma here -- YAML would read `true,` as the (always truthy) string "true,"
bidirectional: true
bidirectional_strategy: "add"
bidirectional_weight_tie: true
rcps: false
# Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer)
complement_map: null
================================================
FILE: configs/model/genomics_benchmark_cnn.yaml
================================================
# CNN baseline from the Genomic Benchmarks suite (see the embedding_dim link below)
_name_: genomics_benchmark_cnn
number_of_classes: ${dataset.d_output}
vocab_size: 12
embedding_dim: 100 # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
input_len: ${dataset.__l_max}
================================================
FILE: configs/model/hyena.yaml
================================================
_name_: hyena_lm
d_model: 128
n_layer: 2
d_inner: ${eval:4 * ${.d_model}}
vocab_size: 12
resid_dropout: 0.0
embed_dropout: 0.1
fused_mlp: False
fused_dropout_add_ln: False
checkpoint_mixer: False # set true for memory reduction
checkpoint_mlp: False # set true for memory reduction
residual_in_fp32: True
pad_vocab_size_multiple: 8
layer:
_name_: hyena
emb_dim: 5
filter_order: 64
local_order: 3
l_max: ${eval:${dataset.max_length}+2}
modulate: True
w: 10
lr: ${optimizer.lr}
wd: 0.0
lr_pos_emb: 0.0
================================================
FILE: configs/model/layer/hyena.yaml
================================================
_name_: hyena
l_max: 1024
order: 2
filter_order: 64
num_heads: 1
inner_factor: 1
num_blocks: 1
fused_bias_fc: false
outer_mixing: false
dropout: 0.0
filter_dropout: 0.0
filter_cls: 'hyena-filter'
post_order_ffn: false
jit_filter: false
short_filter_order: 3
activation: "id"
================================================
FILE: configs/model/mamba.yaml
================================================
# Use open-source version of Mamba
_name_: mamba_lm
config:
_target_: mamba_ssm.models.config_mamba.MambaConfig
d_model: 128 # Will be overwritten by CL in the scaling exps
n_layer: 2 # Will be overwritten by CL in the scaling exps
vocab_size: 12
pad_vocab_size_multiple: 8
rms_norm: true
fused_add_norm: true
residual_in_fp32: false
ssm_cfg:
d_state: 16
d_conv: 4
expand: 2
dt_rank: "auto"
dt_min: 0.001
dt_max: 0.1
dt_init: "random"
dt_scale: 1.0
dt_init_floor: 1e-4
conv_bias: true
bias: false
use_fast_path: true
initializer_cfg:
initializer_range: 0.02
rescale_prenorm_residual: true
n_residuals_per_layer: 1
#norm_epsilon: 1e-5 # Default arg in mamba create_block
================================================
FILE: configs/optimizer/adam.yaml
================================================
# _target_: torch.optim.Adam
_name_: adam
lr: 0.001 # Initial learning rate
# weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired
betas: [0.9, 0.999]
================================================
FILE: configs/optimizer/adamw.yaml
================================================
# _target_: torch.optim.AdamW
_name_: adamw
lr: 0.001 # Initial learning rate
weight_decay: 0.00 # Weight decay
betas: [0.9, 0.999]
================================================
FILE: configs/optimizer/sgd.yaml
================================================
# _target_: torch.optim.SGD
_name_: sgd
lr: 0.001 # Initial learning rate
momentum: 0.9
weight_decay: 0.0 # Weight decay (L2 penalty) for SGD
================================================
FILE: configs/pipeline/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: default
- /dataset: genomic_benchmark
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /callbacks: [base, checkpoint]
train:
monitor: val/accuracy # Needed for plateau scheduler
mode: max
encoder: id
# we need this for classification!
decoder:
_name_: sequence
mode: pool
================================================
FILE: configs/pipeline/hg38.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: null
- /dataset: hg38
- /optimizer: adamw
- /scheduler: cosine_warmup
- /callbacks: [base, checkpoint]
train:
monitor: test/loss
mode: min
task:
_name_: lm
loss:
_name_: cross_entropy
ignore_index: 4 # Bake in tokenizer value for padding / EOS tokens
torchmetrics: ['perplexity', 'num_tokens']
encoder: null
decoder: null
loader:
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
drop_last: True # There's enough data and epochs, ignore the edge case
# shuffle: True
================================================
FILE: configs/pipeline/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: default
- /dataset: nucleotide_transformer
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /callbacks: [base, checkpoint]
task:
loss:
_name_: cross_entropy
metrics:
- ${dataset.metric}
train:
monitor: val/${dataset.metric}
mode: max
encoder: id
# we need this for classification!
decoder:
_name_: sequence
mode: pool
================================================
FILE: configs/scheduler/constant.yaml
================================================
# @package _global_
train:
interval: epoch
scheduler:
# _target_: transformers.get_constant_schedule
_name_: constant
================================================
FILE: configs/scheduler/constant_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_constant_schedule_with_warmup
_name_: constant_warmup
num_warmup_steps: 1000 # Number of iterations for LR warmup
================================================
FILE: configs/scheduler/cosine.yaml
================================================
# @package _global_
train:
interval: epoch
scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
_name_: cosine
T_max: 100 # Number of scheduler steps (epochs, since interval is epoch) over which the LR is annealed
eta_min: 1e-6 # Min learning rate for cosine scheduler
================================================
FILE: configs/scheduler/cosine_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_cosine_schedule_with_warmup
_name_: cosine_warmup
num_warmup_steps: 1000
num_training_steps: 40000
================================================
FILE: configs/scheduler/cosine_warmup_timm.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: timm.scheduler.CosineLRScheduler
_name_: cosine_warmup_timm
t_in_epochs: False
t_initial: 300
lr_min: 1e-5
warmup_lr_init: 1e-6
warmup_t: 10
================================================
FILE: configs/scheduler/linear_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_linear_schedule_with_warmup
_name_: linear_warmup
num_warmup_steps: 1000
num_training_steps: 40000
================================================
FILE: configs/scheduler/multistep.yaml
================================================
# @package _global_
train:
interval: epoch
# _target_: torch.optim.lr_scheduler.MultiStepLR
scheduler:
_name_: multistep
milestones: [80,140,180]
gamma: 0.2
================================================
FILE: configs/scheduler/plateau.yaml
================================================
# @package _global_
train:
interval: epoch
monitor: ??? # must be specified
scheduler:
# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_name_: plateau
mode: ${train.mode} # Whether the monitored metric (train.monitor) should be minimized or maximized
factor: 0.2 # Decay factor when ReduceLROnPlateau is used
patience: 20
min_lr: 0.0 # Minimum learning rate during annealing
================================================
FILE: configs/scheduler/step.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: torch.optim.lr_scheduler.StepLR
_name_: step
step_size: 1
gamma: 0.99
================================================
FILE: configs/task/lm.yaml
================================================
_name_: lm
# loss: cross_entropy # Handled by task: cross entropy loss
metrics: ppl
================================================
FILE: configs/task/multiclass_classification.yaml
================================================
# _target_: tasks.tasks.MultiClass
_name_: multiclass
loss: cross_entropy
metrics:
- accuracy
torchmetrics: null
================================================
FILE: configs/task/multilabel_classification.yaml
================================================
# _target_:
_name_: base
loss: binary_cross_entropy
metrics: null
torchmetrics:
- MultilabelAUROC # AUROC
- MultilabelAveragePrecision # Precision
# - Recall # not supported in torchmetrics
# - F1 # not supported in torchmetrics
================================================
FILE: configs/task/regression.yaml
================================================
# _target_: tasks.tasks.BaseTask
_name_: base
loss: mse
metrics: mse
torchmetrics: null
================================================
FILE: configs/trainer/debug.yaml
================================================
defaults:
- default
gpus: 1
min_epochs: 1
max_epochs: 10
# prints
progress_bar_refresh_rate: null
weights_summary: full
profiler: null
# debugs
fast_dev_run: False
num_sanity_val_steps: 2
overfit_batches: 0
limit_train_batches: 0.1
limit_val_batches: 0.1
limit_test_batches: 0.1
track_grad_norm: -1
terminate_on_nan: False
================================================
FILE: configs/trainer/default.yaml
================================================
_target_: pytorch_lightning.Trainer
devices: 1
accelerator: gpu
accumulate_grad_batches: 1 # Gradient accumulation every n batches
max_epochs: 200
# accelerator: ddp # Automatically set if gpus > 1
gradient_clip_val: 0.0
log_every_n_steps: 10
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
num_sanity_val_steps: 2 # default value: 2; override to 0 to skip sanity checking
================================================
FILE: configs/trainer/full.yaml
================================================
_target_: pytorch_lightning.Trainer
# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_summary: "top"
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False
================================================
FILE: configs/trainer/lm.yaml
================================================
accumulate_grad_batches: 1
# accelerator: null # set to 'ddp' for distributed
# amp_backend: native # 'native' | 'apex'
gpus: 8
max_epochs: 50
gradient_clip_val: 0.0 # Gradient clipping
log_every_n_steps: 10
precision: 16
progress_bar_refresh_rate: 1
weights_summary: top # Set to 'full' to see every layer
track_grad_norm: -1 # Set to 2 to track norms of gradients
limit_train_batches: 1.0
limit_val_batches: 1.0
# We use the dataloader from Transformer-XL to ensure adjacent minibatches
# are from text that are next to each other.
# So that dataloader has to deal with DDP, and we don't want PL to handle
# that.
replace_sampler_ddp: False
================================================
FILE: setup_env.sh
================================================
#!/bin/bash
# Shell script to set environment variables when running code in this repository.
# Usage:
# source setup_env.sh
# This file must be *sourced* (not executed) so the conda activation and the exported
# variables persist in the calling shell; hence no `exit` is used below.

# Activate conda env.
# CONDA_SHELL should point at a script that initializes conda for this shell (the shellcheck
# directive below suggests ~/.bashrc -- confirm for your setup). If it is unset, skip it and
# assume conda is already initialized, instead of running `source ""`.
if [ -n "${CONDA_SHELL:-}" ]; then
  # shellcheck source=${HOME}/.bashrc disable=SC1091
  source "${CONDA_SHELL}"
else
  echo "WARNING: CONDA_SHELL is not set; assuming conda is already initialized in this shell." >&2
fi

if [ -z "${CONDA_PREFIX}" ]; then
  # No env active yet.
  conda activate caduceus_env
elif [[ "${CONDA_PREFIX}" != *"/caduceus_env" ]]; then
  # A different env (e.g. base) is active: switch to ours.
  conda deactivate
  conda activate caduceus_env
fi

# Add root directory to PYTHONPATH to enable module imports.
# Prepend rather than overwrite so pre-existing entries are kept; the repo root still wins.
export PYTHONPATH="${PWD}${PYTHONPATH:+:${PYTHONPATH}}"
================================================
FILE: slurm_scripts/dump_vep_embeddings.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=vep_embed # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Output file name
#SBATCH --open-mode=append # Do not overwrite logs
# Dump embeddings for one pretrained model with vep_embeddings.py, one process per GPU via torchrun.
NUM_WORKERS=2
NUM_DEVICES=8
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export CUDA_LAUNCH_BLOCKING=1
export CUBLAS_WORKSPACE_CONFIG=:4096:8 # Needed for setting deterministic functions for reproducibility
#####################################################################################
# Choose from one of the following:
## Enformer
#seq_len=196608
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="EleutherAI/enformer-official-rough"
#name="enformer-seqlen=196k"
#rcps_flag="no-rcps"
## NTv2
#seq_len=12288 # 2048 (seq len) * 6 (kmers)
#bp_per_token=6
#embed_dump_batch_size=1
#model_name_or_path="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species"
#name="NTv2_downstream-seqlen=12k"
#rcps_flag="no-rcps"
## Hyena
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="LongSafari/hyenadna-medium-160k-seqlen-hf"
#name="hyena_downstream-seqlen=131k"
#rcps_flag="no-rcps"
## Caduceus-Ph
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ph_downstream-seqlen=131k"
#rcps_flag="no-rcps"
## Caduceus-PS
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ps_downstream-seqlen=131k"
#rcps_flag="rcps"
#####################################################################################
# Fail fast if no model section above was uncommented. Otherwise torchrun would be launched
# with empty values (e.g. `--seq_len=`) and a bare `--` instead of --rcps / --no-rcps.
for required_var in seq_len bp_per_token embed_dump_batch_size model_name_or_path name rcps_flag; do
  if [ -z "${!required_var:-}" ]; then
    echo "ERROR: '${required_var}' is not set; uncomment one of the model configurations in this script." >&2
    exit 1
  fi
done
torchrun \
  --standalone \
  --nnodes=1 \
  --nproc-per-node=${NUM_DEVICES} \
  vep_embeddings.py \
  --num_workers=${NUM_WORKERS} \
  --seq_len=${seq_len} \
  --bp_per_token=${bp_per_token} \
  --embed_dump_batch_size=${embed_dump_batch_size} \
  --name="${name}" \
  --model_name_or_path="${model_name_or_path}" \
  --"${rcps_flag}"
================================================
FILE: slurm_scripts/run_genomics_benchmark.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64000M # RAM
#SBATCH --gres=gpu:1 # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Fine-tune a model on one Genomic Benchmarks task, 5 runs (seed also sets the train/val split seed).
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG
# Fail fast on missing args used for run names / core Hydra overrides, instead of creating
# output directories and W&B runs from empty values. CONFIG_PATH, PRETRAINED_PATH and the
# CONJOIN_* flags are passed through as-is and are not checked here.
for required_var in DISPLAY_NAME MODEL MODEL_NAME TASK LR BATCH_SIZE RC_AUG; do
  if [ -z "${!required_var:-}" ]; then
    echo "ERROR: required environment variable '${required_var}' is not set." >&2
    exit 1
  fi
done
# Run script
# shellcheck disable=SC2154
WANDB_NAME="${DISPLAY_NAME}_lr-${LR}_batch_size-${BATCH_SIZE}_rc_aug-${RC_AUG}"
for seed in $(seq 1 5); do
  # shellcheck disable=SC2154
  HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running GenomicsBenchmark model: ${DISPLAY_NAME}, task: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, rc_aug: ${RC_AUG}, SEED: ${seed}"
  # shellcheck disable=SC2086
  python -m train \
    experiment=hg38/genomic_benchmark \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    +dataset.conjoin_train=false \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    model="${MODEL}" \
    model._name_="${MODEL_NAME}" \
    +model.config_path="${CONFIG_PATH}" \
    +model.conjoin_test="${CONJOIN_TEST}" \
    +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
    +decoder.conjoin_test="${CONJOIN_TEST}" \
    optimizer.lr="${LR}" \
    trainer.max_epochs=10 \
    train.pretrained_model_path="${PRETRAINED_PATH}" \
    wandb.group="downstream/gb_cv5" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_genomics_benchmark_cnn.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 48:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64G # RAM
#SBATCH --gres=gpu:1 # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Train the CNN baseline on one Genomic Benchmarks task, 5 runs (seed also sets the train/val split seed).
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
# Expected args:
# - TASK
# - RC_AUG
# Fail fast on missing args instead of creating output directories / W&B runs from empty values.
for required_var in TASK RC_AUG; do
  if [ -z "${!required_var:-}" ]; then
    echo "ERROR: required environment variable '${required_var}' is not set." >&2
    exit 1
  fi
done
# LR: 1e-3 -- in https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks, Adam optimizer is used with default lr=1e-3
LR="1e-3"
# Batch size: 64 -- See https://arxiv.org/abs/2306.15794 and https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
BATCH_SIZE=64
# Run script
WANDB_NAME="CNN-LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 5); do
  HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running GenomicsBenchmark TASK: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
  python -m train \
    experiment=hg38/genomic_benchmark_cnn \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    optimizer.lr="${LR}" \
    trainer.max_epochs=10 \
    wandb.group="downstream/gb_cv5" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_nucleotide_transformer.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64G # RAM
#SBATCH --gres=gpu:2 # Number of GPUs
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=4
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Fine-tune a model on one Nucleotide Transformer task, 10 runs of 20 epochs
# (seed also sets the train/val split seed).
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1
# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG
# Fail fast on missing args used for run names / core Hydra overrides, instead of creating
# output directories and W&B runs from empty values. CONFIG_PATH, PRETRAINED_PATH and the
# CONJOIN_* flags are passed through as-is and are not checked here.
for required_var in DISPLAY_NAME MODEL MODEL_NAME TASK LR BATCH_SIZE RC_AUG; do
  if [ -z "${!required_var:-}" ]; then
    echo "ERROR: required environment variable '${required_var}' is not set." >&2
    exit 1
  fi
done
# Run script
WANDB_NAME="${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 10); do
  # Same directory layout as before: the run-name segment is exactly WANDB_NAME.
  HYDRA_RUN_DIR="./outputs/downstream/nt_cv10_ep20/${TASK}/${WANDB_NAME}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running NT model: ${DISPLAY_NAME}, TASK: ${TASK}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
  python -m train \
    experiment=hg38/nucleotide_transformer \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    model="${MODEL}" \
    model._name_="${MODEL_NAME}" \
    +model.config_path="${CONFIG_PATH}" \
    +model.conjoin_test="${CONJOIN_TEST}" \
    +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
    +decoder.conjoin_test="${CONJOIN_TEST}" \
    optimizer.lr="${LR}" \
    train.pretrained_model_path="${PRETRAINED_PATH}" \
    trainer.max_epochs=20 \
    wandb.group="downstream/nt_cv10_ep20" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="nt_cv10_ep-20_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_pretrain_caduceus.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=caduceus_ps # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
#SBATCH --open-mode=append # Do not overwrite logs
# Pretrain Caduceus-PS (RCPS) with masked language modeling on hg38.
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1
NUM_DEVICES=8
# Run script
SEQLEN=131072
MAX_STEPS=50000
D_MODEL=256
N_LAYER=8
LR="8e-3"
BIDIRECTIONAL_STRATEGY="add"
BIDIRECTIONAL_WEIGHT_TIE="true"
RCPS="true"
RC_AUG="false"
# Global batch size in sequences, so one global batch holds ~2^20 (1048576) tokens.
BATCH_SIZE=$(( 1048576 / SEQLEN ))
PER_DEVICE_BATCH_SIZE=$(( BATCH_SIZE / NUM_DEVICES ))
# Integer division silently yields 0 once SEQLEN grows past 1048576 / NUM_DEVICES.
if (( PER_DEVICE_BATCH_SIZE < 1 )); then
  echo "ERROR: global batch size ${BATCH_SIZE} (SEQLEN=${SEQLEN}) is smaller than NUM_DEVICES=${NUM_DEVICES}." >&2
  exit 1
fi
# Sequence length in thousands for display, e.g. 131072 -> "131k" (truncating, like bc scale=0).
SEQLEN_DIS="$(( SEQLEN / 1000 ))k"
WANDB_NAME="caduceus-ps_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=${PER_DEVICE_BATCH_SIZE} \
  dataset.mlm=true \
  dataset.mlm_probability=0.15 \
  dataset.rc_aug="${RC_AUG}" \
  model="caduceus" \
  model.config.d_model=${D_MODEL} \
  model.config.n_layer=${N_LAYER} \
  model.config.bidirectional=true \
  model.config.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \
  model.config.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \
  model.config.rcps=${RCPS} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/run_pretrain_hyena.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=hyena # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
# Pretrain HyenaDNA with next-token prediction (mlm=false) and RC augmentation on hg38.
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
NUM_DEVICES=8
# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=4
LR="6e-4"
RC_AUG="true"
# Global batch size in sequences, so one global batch holds ~2^20 (1048576) tokens.
BATCH_SIZE=$(( 1048576 / SEQLEN ))
PER_DEVICE_BATCH_SIZE=$(( BATCH_SIZE / NUM_DEVICES ))
# Integer division silently yields 0 once SEQLEN grows past 1048576 / NUM_DEVICES.
if (( PER_DEVICE_BATCH_SIZE < 1 )); then
  echo "ERROR: global batch size ${BATCH_SIZE} (SEQLEN=${SEQLEN}) is smaller than NUM_DEVICES=${NUM_DEVICES}." >&2
  exit 1
fi
# Sequence length in thousands for display, e.g. 1024 -> "1k" (truncating, like bc scale=0).
SEQLEN_DIS="$(( SEQLEN / 1000 ))k"
WANDB_NAME="hyena_rc_aug_seqlen-${SEQLEN_DIS}_dmodel-${D_MODEL}_nlayer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=${PER_DEVICE_BATCH_SIZE} \
  dataset.mlm=false \
  dataset.mlm_probability=0.0 \
  dataset.rc_aug="${RC_AUG}" \
  model=hyena \
  model.d_model=${D_MODEL} \
  model.n_layer=${N_LAYER} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/run_pretrain_mamba.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the users login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
#SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=mamba_ntp # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
# Pretrain Mamba with next-token prediction (mlm=false) and RC augmentation on hg38.
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
NUM_DEVICES=8
# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=8
LR="8e-3"
RC_AUG="true"
# Global batch size in sequences, so one global batch holds ~2^20 (1048576) tokens.
BATCH_SIZE=$(( 1048576 / SEQLEN ))
PER_DEVICE_BATCH_SIZE=$(( BATCH_SIZE / NUM_DEVICES ))
# Integer division silently yields 0 once SEQLEN grows past 1048576 / NUM_DEVICES.
if (( PER_DEVICE_BATCH_SIZE < 1 )); then
  echo "ERROR: global batch size ${BATCH_SIZE} (SEQLEN=${SEQLEN}) is smaller than NUM_DEVICES=${NUM_DEVICES}." >&2
  exit 1
fi
# Sequence length in thousands for display, e.g. 1024 -> "1k" (truncating, like bc scale=0).
SEQLEN_DIS="$(( SEQLEN / 1000 ))k"
WANDB_NAME="mamba_ntp_rc_aug_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=${PER_DEVICE_BATCH_SIZE} \
  dataset.mlm=false \
  dataset.mlm_probability=0.0 \
  dataset.rc_aug="${RC_AUG}" \
  model=mamba \
  model.config.d_model=${D_MODEL} \
  model.config.n_layer=${N_LAYER} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/wrapper_run_genomics.sh
================================================
#!/bin/bash
# Choose one from below
## Hyena
## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen to ../outputs/hyena_hf/hyenadna-tiny-1k-seqlen
#LOG_DIR="../watch_folder/gb_cv5/hyena"
#CONFIG_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/config.json")
#PRETRAINED_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/weights.ckpt")
#DISPLAY_NAME="hyena"
#MODEL="hyena"
#MODEL_NAME="dna_embedding"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "false" "true" )
#LRS=( "6e-4" )
## Mamba NTP
#LOG_DIR="../watch_folder/gb_cv5/mamba"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="mamba_uni"
#MODEL="mamba"
#MODEL_NAME="dna_embedding_mamba"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3" )
## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "2e-3")
## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true" # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
# Abort early if none of the model configurations above was uncommented.
if [[ -z "${LOG_DIR:-}" ]]; then
  echo "No model configuration selected: uncomment one of the option blocks above." >&2
  exit 1
fi
mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 256; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Build a fresh per-job string: appending to export_str itself would carry the
        # TASK/LR/BATCH_SIZE/RC_AUG entries of every previously submitted job.
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="gb_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_genomics_benchmark.sh"
      done
    done
  done
done
================================================
FILE: slurm_scripts/wrapper_run_genomics_cnn.sh
================================================
#!/bin/bash
LOG_DIR="../watch_folder/gb_cv5/cnn_baseline"
mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for RC_AUG in "false"; do
    # Build a fresh per-job string: appending to export_str itself would carry the
    # TASK/RC_AUG entries of every previously submitted job.
    job_export_str="${export_str},TASK=${TASK},RC_AUG=${RC_AUG}"
    job_name="gb_${TASK}_CNN_RC_AUG-${RC_AUG}"
    sbatch \
      --job-name="${job_name}" \
      --output="${LOG_DIR}/%x_%j.log" \
      --export="${job_export_str}" \
      "run_genomics_benchmark_cnn.sh"
  done
done
================================================
FILE: slurm_scripts/wrapper_run_nucleotide_transformer.sh
================================================
#!/bin/bash
# Choose one from below
## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3")
## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true" # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
# Abort early if none of the model configurations above was uncommented.
if [[ -z "${LOG_DIR:-}" ]]; then
  echo "No model configuration selected: uncomment one of the option blocks above." >&2
  exit 1
fi
mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "enhancers" "enhancers_types" "H3" "H3K4me1" "H3K4me2" "H3K4me3" "H3K9ac" "H3K14ac" "H3K36me3" "H3K79me3" "H4" "H4ac" "promoter_all" "promoter_no_tata" "promoter_tata" "splice_sites_all" "splice_sites_acceptors" "splice_sites_donors"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 512; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Build a fresh per-job string: appending to export_str itself would carry the
        # TASK/LR/BATCH_SIZE/RC_AUG entries of every previously submitted job.
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="nt_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_nucleotide_transformer.sh"
      done
    done
  done
done
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/callbacks/params.py
================================================
"""Callback to log the number of parameters of the model.
"""
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
class ParamsLog(pl.Callback):
    """Log the parameter counts of the model as hyperparameters when fitting starts.

    Args:
        total: report the number of all parameters as ``params/total``.
        trainable: report the number of parameters with ``requires_grad`` as ``params/trainable``.
        fixed: report the number of parameters without ``requires_grad`` as ``params/fixed``.
    """

    def __init__(
        self,
        total: bool = True,
        trainable: bool = True,
        fixed: bool = True,
    ):
        super().__init__()
        flags = {
            'total_params_log': total,
            'trainable_params_log': trainable,
            'non_trainable_params_log': fixed,
        }
        self._log_stats = AttributeDict(flags)

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Single pass over the parameters; the frozen count is the remainder.
        n_total = 0
        n_trainable = 0
        for param in pl_module.parameters():
            count = param.numel()
            n_total += count
            if param.requires_grad:
                n_trainable += count

        stats = self._log_stats
        logs = {}
        if stats.total_params_log:
            logs["params/total"] = n_total
        if stats.trainable_params_log:
            logs["params/trainable"] = n_trainable
        if stats.non_trainable_params_log:
            logs["params/fixed"] = n_total - n_trainable

        if trainer.logger:
            trainer.logger.log_hyperparams(logs)
================================================
FILE: src/callbacks/timer.py
================================================
"""Callback to monitor the speed of each step and each epoch.
https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py
Adapted from:
https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
"""
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
class Timer(Callback):
    """Monitor the speed of each step and each epoch.

    Logged metrics (wall-clock seconds):
        timer/step:        from ``on_train_batch_start`` to ``on_train_batch_end``.
        timer/inter_step:  from the end of one training batch to the start of the next.
        timer/epoch:       duration of a training epoch.
        timer/validation:  duration of a validation epoch.
    Step-level metrics are only emitted on steps selected by ``_should_log``.

    Args:
        step, inter_step, epoch, val: enable the corresponding metric.
    """

    def __init__(
        self,
        step: bool = True,
        inter_step: bool = True,
        epoch: bool = True,
        val: bool = True,
    ):
        super().__init__()
        self._log_stats = AttributeDict({
            'step_time': step,
            'inter_step_time': inter_step,
            'epoch_time': epoch,
            'val_time': val,
        })
        # Start-time snapshots (None = not started). Initialized here so that every hook
        # can read them even if it runs before its paired "start" hook.
        self._snap_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = None
        self._snap_val_time = None

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_epoch_time = None

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = time.time()

    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.step_time:
            self._snap_step_time = time.time()

        if not self._should_log(trainer):
            return

        logs = {}
        if self._log_stats.inter_step_time and self._snap_inter_step_time is not None:
            # Only available from the second step of the epoch onwards.
            logs["timer/inter_step"] = time.time() - self._snap_inter_step_time
        self._log_metrics(trainer, logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()

        if not self._should_log(trainer):
            return

        logs = {}
        if self._log_stats.step_time and self._snap_step_time is not None:
            logs["timer/step"] = time.time() - self._snap_step_time
        self._log_metrics(trainer, logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:
        logs = {}
        if self._log_stats.epoch_time and self._snap_epoch_time is not None:
            logs["timer/epoch"] = time.time() - self._snap_epoch_time
        self._log_metrics(trainer, logs, step=trainer.global_step)

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_val_time = time.time()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:
        logs = {}
        if self._log_stats.val_time and self._snap_val_time is not None:
            logs["timer/validation"] = time.time() - self._snap_val_time
        # Logged without an explicit step (as before).
        self._log_metrics(trainer, logs)

    @staticmethod
    def _log_metrics(trainer, logs, **kwargs) -> None:
        """Forward ``logs`` to the trainer's logger, skipping empty payloads."""
        if logs and trainer.logger:
            trainer.logger.log_metrics(logs, **kwargs)

    @staticmethod
    def _should_log(trainer) -> bool:
        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
================================================
FILE: src/callbacks/validation.py
================================================
"""Check validation every n **global** steps.
Pytorch Lightning has a `val_check_interval` parameter that checks validation every n batches, but does not support
checking every n **global** steps.
"""
from typing import Any
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
class ValEveryNGlobalSteps(Callback):
    """Check validation every n **global** steps.

    Args:
        every_n: run validation whenever ``trainer.global_step`` is a positive multiple
            of this value. Must be positive.

    Raises:
        ValueError: if ``every_n`` is not positive.
    """

    def __init__(self, every_n):
        super().__init__()
        if every_n <= 0:
            raise ValueError(f"`every_n` must be a positive integer, got {every_n}.")
        self.every_n = every_n
        # Global step at which validation last ran. Several batches can end on the same
        # global step under gradient accumulation; this prevents repeated validation runs.
        self.last_run = None

    def on_train_batch_end(self, trainer, *_: Any):
        """Check if we should run validation.

        Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529
        """
        step = trainer.global_step
        if step == self.last_run:
            return
        if step == 0 or step % self.every_n != 0:
            return

        # Note: `stage` is read after `training` is switched off, as in the original code.
        trainer.training = False
        stage = trainer.state.stage
        trainer.state.stage = RunningStage.VALIDATING
        try:
            # Private Lightning API; tied to the Lightning version this repo pins.
            trainer._run_evaluate()
        finally:
            # Restore training state even if evaluation raises.
            trainer.state.stage = stage
            trainer.training = True
            trainer._logger_connector._epoch_end_reached = False
        self.last_run = step
================================================
FILE: src/dataloaders/__init__.py
================================================
from . import genomics
from .base import SequenceDataset
================================================
FILE: src/dataloaders/base.py
================================================
""" Datasets for core experimental results.
"""
import os
from functools import partial
from pathlib import Path
import torch
# Default data path: the DATA_PATH environment variable if set, otherwise <repo_root_dir>/data
default_data_path = os.getenv("DATA_PATH")
if default_data_path is None:
    default_data_path = Path(__file__).parent.parent.parent.absolute() / "data"
else:
    default_data_path = Path(default_data_path).absolute()
class DefaultCollateMixin:
    """Controls collating in the DataLoader

    The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader
    arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args
    list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments,
    constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor.
    """

    @classmethod
    def _collate_callback(cls, x, *args, **kwargs):
        """
        Modify the behavior of the default _collate method (identity by default).
        """
        return x

    # Names for the auxiliary items a dataset returns after (x, y); see _return_callback.
    _collate_arg_names = []

    @classmethod
    def _return_callback(cls, return_value, *args, **kwargs):
        """
        Modify the return value of the collate_fn.
        Assign a name to each element of the returned tuple beyond the (x, y) pairs
        See InformerSequenceDataset for an example of this being used
        """
        x, y, *z = return_value
        assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
        return x, y, dict(zip(cls._collate_arg_names, z))

    @classmethod
    def _collate(cls, batch, *args, **kwargs):
        """Stack a sequence of tensors along a new leading dim (or tensorize non-tensors)."""
        # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
        elem = batch[0]
        if not isinstance(elem, torch.Tensor):
            return torch.tensor(batch)

        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            # Give the shared buffer the stacked shape up front (as torch's default_collate
            # does); a flat 1-D `out` makes torch.stack resize a non-empty output tensor,
            # which PyTorch deprecates.
            out = elem.new(storage).resize_(len(batch), *elem.size())
        x = torch.stack(batch, dim=0, out=out)
        # Insert custom functionality into the collate_fn
        return cls._collate_callback(x, *args, **kwargs)

    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs):
        """
        Default collate function.
        Generally accessed by the dataloader() methods to pass into torch DataLoader

        Arguments:
            batch: list of (x, y) pairs (optionally followed by auxiliary items)
            args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback
        """
        x, y, *z = zip(*batch)
        x = cls._collate(x, *args, **kwargs)
        y = cls._collate(y)
        z = [cls._collate(z_) for z_ in z]
        return_value = (x, y, *z)
        return cls._return_callback(return_value, *args, **kwargs)

    # List of loader arguments to pass into collate_fn
    collate_args = []

    def _dataloader(self, dataset, **loader_args):
        """Build a loader: keys in `collate_args` go to collate_fn, the rest to the loader class."""
        collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args}
        loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}
        loader_cls = loader_registry[loader_args.pop("_name_", None)]
        return loader_cls(
            dataset=dataset,
            collate_fn=partial(self._collate_fn, **collate_args),
            **loader_args,
        )
# class SequenceDataset(LightningDataModule):
# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just
# provide our own class with the same core methods as LightningDataModule (e.g. setup)
class SequenceDataset(DefaultCollateMixin):
    """Base class for datasets: registers subclasses by `_name_` and builds train/val/test loaders."""

    # `_name_` -> subclass; populated automatically by __init_subclass__.
    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    # Since subclasses do not specify __init__ which is instead handled by this class
    # Subclasses can provide a list of default arguments which are automatically registered as attributes
    # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features
    # of this class such as the _name_ and d_input/d_output
    @property
    def init_defaults(self):
        return {}

    # https://www.python.org/dev/peps/pep-0487/#subclass-registration
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        assert _name_ == self._name_
        self.data_dir = None if data_dir is None else Path(data_dir).absolute()

        # Defaults first, then explicit config overrides; every entry becomes an attribute.
        merged_args = {**self.init_defaults, **dataset_cfg}
        for attr_name, attr_value in merged_args.items():
            setattr(self, attr_name, attr_value)

        # The train, val, test datasets must be set by `setup()`
        self.dataset_train = None
        self.dataset_val = None
        self.dataset_test = None

        self.init()

    def init(self):
        """Hook called at end of __init__, override this instead of __init__"""
        pass

    def setup(self):
        """This method should set self.dataset_train, self.dataset_val, and self.dataset_test."""
        raise NotImplementedError

    def split_train_val(self, val_split):
        """
        Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.
        """
        n_total = len(self.dataset_train)
        n_train = int(n_total * (1.0 - val_split))
        # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us
        split_generator = torch.Generator().manual_seed(getattr(self, "seed", 42))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset_train,
            (n_train, n_total - n_train),
            generator=split_generator,
        )

    def train_dataloader(self, **kwargs):
        """Return a DataLoader for the training dataset."""
        return self._train_dataloader(self.dataset_train, **kwargs)

    def _train_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # shuffle cant be True if we have custom sampler
        kwargs['shuffle'] = 'sampler' not in kwargs
        return self._dataloader(dataset, **kwargs)

    def val_dataloader(self, **kwargs):
        """Return a DataLoader for the validation dataset."""
        return self._eval_dataloader(self.dataset_val, **kwargs)

    def test_dataloader(self, **kwargs):
        """Return a DataLoader for the test dataset."""
        return self._eval_dataloader(self.dataset_test, **kwargs)

    def _eval_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # Note that shuffle=False by default
        return self._dataloader(dataset, **kwargs)

    def __str__(self):
        return self._name_
# Registry mapping a loader `_name_` to the DataLoader class used by `_dataloader`.
loader_registry = {}
loader_registry[None] = torch.utils.data.DataLoader  # default case
================================================
FILE: src/dataloaders/datasets/genomic_bench_dataset.py
================================================
"""Genomic Benchmarks Dataset.
From: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
"""
from pathlib import Path
import torch
from genomic_benchmarks.data_check import is_downloaded
from genomic_benchmarks.loc2seq import download_dataset
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
class GenomicBenchmarkDataset(torch.utils.data.Dataset):
    """
    Sequence-classification dataset over a Genomic Benchmarks split.

    Expects ``<dest_path>/<dataset_name>/<split>/<class_name>/<file>`` where each file holds one
    sequence. Class directories are sorted by name to assign label ids, so the mapping is
    deterministic and identical across splits. Items are ``(input_ids, target)``, plus a
    ``{"mask": ...}`` dict when ``return_mask`` is set.
    """

    def __init__(
        self,
        split,
        max_length,
        dataset_name="human_nontata_promoters",
        d_output=2,  # default binary classification
        dest_path=None,
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        conjoin_train=False,
        conjoin_test=False,
        return_augs=False,
        return_mask=False,
    ):
        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
        if (conjoin_train or conjoin_test) and rc_aug:
            print("When using conjoin, we turn off rc_aug.")
            rc_aug = False
        self.rc_aug = rc_aug
        self.conjoin_train = conjoin_train
        self.conjoin_test = conjoin_test
        self.return_mask = return_mask

        if not is_downloaded(dataset_name, cache_path=dest_path):
            print("downloading {} to {}".format(dataset_name, dest_path))
            download_dataset(dataset_name, version=0, dest_path=dest_path)
        else:
            print("already downloaded {}-{}".format(split, dataset_name))

        self.split = split

        # use Path object
        base_path = Path(dest_path) / dataset_name / split
        self.all_seqs = []
        self.all_labels = []
        # `iterdir()` order is filesystem-dependent: sort the class directories so label ids
        # agree between train and test, and sort files for a reproducible sample order.
        label_dirs = sorted(p for p in base_path.iterdir() if p.is_dir())
        for label, label_dir in enumerate(label_dirs):
            for path in sorted(label_dir.iterdir()):
                with open(path, "r") as f:
                    self.all_seqs.append(f.read())
                self.all_labels.append(label)

    def __len__(self):
        return len(self.all_labels)

    def _tokenize_ids(self, seq_str):
        """Tokenize `seq_str` (no special tokens; EOS appended if `add_eos`).

        Returns the raw tokenizer output and its (possibly EOS-extended) `input_ids` list.
        """
        encoded = self.tokenizer(
            seq_str,
            add_special_tokens=False,
            # `False` (not None) is the tokenizer's "do not pad" value.
            padding="max_length" if self.use_padding else False,
            max_length=self.max_length,
            truncation=True,
        )
        ids = encoded["input_ids"]
        if self.add_eos:
            # append list seems to be faster than append tensor
            ids.append(self.tokenizer.sep_token_id)
        return encoded, ids

    def __getitem__(self, idx):
        x = self.all_seqs[idx]
        y = self.all_labels[idx]

        # Random reverse-complement: with rc_aug, and on the train split when conjoin_test
        # is set (rc_aug itself is forced off in that case).
        if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
            x = string_reverse_complement(x)

        seq, seq_ids = self._tokenize_ids(x)

        if self.conjoin_train or (self.conjoin_test and self.split != "train"):
            _, seq_rc_ids = self._tokenize_ids(string_reverse_complement(x))
            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)
        else:
            seq_ids = torch.LongTensor(seq_ids)

        # need to wrap in list
        target = torch.LongTensor([y])

        # `seq_ids` has shape:
        # - (seq_len,) if not conjoining
        # - (seq_len, 2) for conjoining
        if self.return_mask:
            return seq_ids, target, {"mask": torch.BoolTensor(seq["attention_mask"])}
        return seq_ids, target
================================================
FILE: src/dataloaders/datasets/hg38_char_tokenizer.py
================================================
"""
From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
CharacterTokenizer for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following is the list of all the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
            padding_side (str): Side on which padding is applied. It is forwarded to the base
                class and persisted by `get_config` / `save_pretrained`.
        """
        # Set before super().__init__(). NOTE(review): presumably because the base
        # constructor may query the vocabulary; confirm before reordering.
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        # TODO: This should be a parameter passed to __init__
        complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
        # Token id -> id of its complement; tokens without a complement map to themselves.
        self.complement_map = {}
        for k, v in self._vocab_str_to_int.items():
            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
            self.complement_map[self._vocab_str_to_int[k]] = complement_id

        super().__init__(
            bos_token=bos_token,
            # EOS is bound to [PAD]: the vocabulary has no "[EOS]" entry, so a dedicated
            # "[EOS]" token would only resolve to [UNK] via _convert_token_to_id.
            eos_token=pad_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        # One token per character.
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown characters map to [UNK].
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # [CLS] A [SEP] (B [SEP])
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # Mirrors the layout produced by build_inputs_with_special_tokens.
        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def get_vocab(self) -> Dict[str, int]:
        return self._vocab_str_to_int

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        """Serializable constructor arguments (characters stored as code points)."""
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
            "padding_side": self.padding_side,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        """Rebuild a tokenizer from `get_config()` output."""
        cfg = {
            "characters": [chr(i) for i in config["char_ords"]],
            "model_max_length": config["model_max_length"],
        }
        # Configs written before padding_side was persisted fall back to the constructor default.
        if "padding_side" in config:
            cfg["padding_side"] = config["padding_side"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write `tokenizer_config.json` into `save_directory`, creating it if needed."""
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        cfg_file = save_directory / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        """Load a tokenizer saved by `save_pretrained` (extra kwargs are ignored)."""
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
================================================
FILE: src/dataloaders/datasets/hg38_dataset.py
================================================
"""Dataset for sampling arbitrary intervals from the human genome.
"""
import math
from pathlib import Path
import pandas as pd
import torch
from pyfaidx import Fasta
from src.dataloaders.utils.mlm import mlm_getitem
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
# Largest supported window: 2**20 = 1,048,576 bases.
MAX_ALLOWED_LENGTH = 1 << 20
class FastaInterval:
    """Retrieves sequences from a fasta file given a chromosome and start/end indices."""

    def __init__(
        self,
        *,
        fasta_file,
        return_seq_indices=False,
        rc_aug=False,
    ):
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), "Path to fasta file must exist!"
        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        self.rc_aug = rc_aug

        # calc len of each chromosome in fasta file, store in dict
        self.chr_lens = {chr_name: len(self.seqs[chr_name]) for chr_name in self.seqs.keys()}

    @staticmethod
    def _compute_interval(start, end, max_length, i_shift):
        """Select the `i_shift`-th window of size `max_length` inside [start, end)."""
        if max_length == MAX_ALLOWED_LENGTH:
            return start, end
        if max_length < MAX_ALLOWED_LENGTH:
            assert MAX_ALLOWED_LENGTH % max_length == 0
            return start + i_shift * max_length, start + (i_shift + 1) * max_length
        else:
            raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!")

    @staticmethod
    def _fit_to_chromosome(start, end, max_length, chromosome_length):
        """Shift a window of length `max_length` so it lies within [0, chromosome_length).

        If the chromosome is shorter than the window, the whole chromosome is returned.
        """
        if end > chromosome_length:
            # Shift interval down
            start = start - (end - chromosome_length)
            end = chromosome_length
            assert start == chromosome_length - max_length

        if start < 0:
            # Shift interval up
            end = end - start
            start = 0
            assert end == max_length

        if end > chromosome_length:
            # Only reachable when the chromosome is shorter than `max_length`: take the whole
            # chromosome rather than slicing from a negative start.
            start = 0
            end = chromosome_length
        return start, end

    def __call__(
        self,
        chr_name,
        start,
        end,
        max_length,
        i_shift,
        return_augs=False,  # unused; kept for interface compatibility
    ):
        """
        Return the `i_shift`-th `max_length` window of [start, end) on `chr_name`,
        clipped to the chromosome and optionally reverse-complemented (rc_aug).

        max_length passed from dataset, not from init
        """
        chromosome = self.seqs[chr_name]
        chromosome_length = self.chr_lens[chr_name]

        start, end = self._compute_interval(start, end, max_length, i_shift)
        start, end = self._fit_to_chromosome(start, end, max_length, chromosome_length)

        seq = str(chromosome[start:end])

        if self.rc_aug and coin_flip():
            seq = string_reverse_complement(seq)

        return seq
class HG38Dataset(torch.utils.data.Dataset):
    """Loop through bed file, retrieve (chr, start, end), query fasta file for sequence.

    Every bed interval is widened to ``MAX_ALLOWED_LENGTH`` bases. It is then split into
    ``MAX_ALLOWED_LENGTH // max_length`` consecutive windows, and each window is one item.
    """

    def __init__(
        self,
        split,
        bed_file,
        fasta_file,
        max_length,
        mlm=False,
        mlm_probability=0.15,
        pad_max_length=None,
        tokenizer=None,
        tokenizer_name=None,
        add_eos=False,
        return_seq_indices=False,
        rc_aug=False,
        return_augs=False,
    ):
        """
        Args:
            split: value of the bed file's ``split`` column to keep.
            bed_file: tab-separated file read with columns chr_name, start, end, split.
            fasta_file: fasta file queried through ``FastaInterval``.
            max_length: window length; must be a positive divisor of ``MAX_ALLOWED_LENGTH``.
            mlm: if True, items are built with ``mlm_getitem``; otherwise next-token targets.
            mlm_probability: masking probability; must be > 0 when ``mlm`` is True.
            pad_max_length: tokenizer padding length; defaults to ``max_length``.
            tokenizer, tokenizer_name: tokenizer object and its kind ("char" or "bpe").
            add_eos: append an end-of-sequence token to each item.
            return_seq_indices, rc_aug: forwarded to ``FastaInterval``.
            return_augs: forwarded to each ``FastaInterval`` call.

        Raises:
            ValueError: if ``mlm_probability`` is not positive with ``mlm``, or ``max_length``
                is not positive or exceeds ``MAX_ALLOWED_LENGTH``.
        """
        self.mlm = mlm
        self.mlm_probability = mlm_probability
        if self.mlm and self.mlm_probability <= 0.0:
            raise ValueError(f"`mlm_probability` has to be > 0.0, got {self.mlm_probability}.")
        # TODO: see if restricting random MLM replacements to ACGT helps, e.g.
        # torch.tensor(tokenizer("ACGT", add_special_tokens=False)["input_ids"], dtype=torch.long)
        self.eligible_replacements = None

        self.max_length = max_length
        self.pad_max_length = pad_max_length if pad_max_length is not None else max_length
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos

        if max_length <= 0:
            # Otherwise the modulo below divides by zero (0) or yields a negative number of shifts.
            raise ValueError(f"`max_length` must be positive, got {max_length}.")
        if max_length > MAX_ALLOWED_LENGTH:
            raise ValueError(
                f"`max_length` {max_length} (> 2^{int(math.log2(MAX_ALLOWED_LENGTH))}) is too large!"
            )
        assert MAX_ALLOWED_LENGTH % max_length == 0, "`max_length` must be a power of 2!"
        self.shifts = MAX_ALLOWED_LENGTH // max_length

        bed_path = Path(bed_file)
        assert bed_path.exists(), "Path to .bed file must exist!"
        # read bed file
        df_raw = pd.read_csv(str(bed_path), sep="\t", names=["chr_name", "start", "end", "split"])
        # Select only this split. Take an explicit copy so that the `.loc` assignment below
        # modifies an independent frame (avoids SettingWithCopyWarning on a filtered slice).
        self.df = df_raw[df_raw["split"] == split].copy()
        # Update end points so that sequences are all length == MAX_ALLOWED_LENGTH
        self.df.loc[:, "end"] = self.df["start"] + MAX_ALLOWED_LENGTH

        self.fasta = FastaInterval(
            fasta_file=fasta_file,
            return_seq_indices=return_seq_indices,
            rc_aug=rc_aug,
        )

    @staticmethod
    def replace_value(x, old_value, new_value):
        """Helper for replacing values in a tensor."""
        return torch.where(x == old_value, new_value, x)

    def __len__(self):
        # Each bed row contributes `self.shifts` windows.
        return len(self.df) * self.shifts

    def __getitem__(self, idx):
        """Return ``(data, target)`` LongTensors for the window at flat index ``idx``.

        ``idx`` is split into a bed row and a window offset within that row. With ``mlm`` the
        pair comes from ``mlm_getitem``. Otherwise ``target`` is ``data`` shifted by one token.

        Raises:
            NotImplementedError: if ``tokenizer_name`` is neither "char" nor "bpe".
        """
        row_idx, shift_idx = idx // self.shifts, idx % self.shifts
        row = self.df.iloc[row_idx]
        chr_name, start, end = (row.iloc[0], row.iloc[1], row.iloc[2])

        seq = self.fasta(
            chr_name,
            start,
            end,
            max_length=self.max_length,
            i_shift=shift_idx,
            return_augs=self.return_augs,
        )
        if end - start != MAX_ALLOWED_LENGTH:
            # Sanity check: `__init__` widened every interval to MAX_ALLOWED_LENGTH.
            print(row, "\nLength: ", end - start)

        if self.tokenizer_name == "char":
            seq = self.tokenizer(
                seq,
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
                add_special_tokens=False,
            )
            seq = seq["input_ids"]  # get input_ids
            # need to handle eos here
            if self.add_eos:
                # append list seems to be faster than append tensor
                seq.append(self.tokenizer.sep_token_id)
        elif self.tokenizer_name == "bpe":
            seq = self.tokenizer(
                seq,
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
            )
            # NOTE(review): if the tokenizer pads on the right, the last id here may be a pad
            # token rather than EOS — confirm the slicing below against the bpe tokenizer.
            if self.add_eos:
                seq = seq["input_ids"][1:]  # remove the bos, keep the eos token
            else:
                seq = seq["input_ids"][1:-1]  # remove both special tokens
        else:
            # Fail with a clear message instead of handing a raw string to torch.LongTensor.
            raise NotImplementedError(f"Tokenizer {self.tokenizer_name} not implemented.")

        # convert to tensor
        seq = torch.LongTensor(seq)

        # replace N token with a pad token, so we can ignore it in the loss
        seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int["N"], self.tokenizer.pad_token_id)

        if self.mlm:
            data, target = mlm_getitem(
                seq,
                mlm_probability=self.mlm_probability,
                contains_eos=self.add_eos,
                tokenizer=self.tokenizer,
                eligible_replacements=self.eligible_replacements,
            )
        else:
            # Next-token prediction: target is the input shifted left by one position.
            data = seq[:-1].clone()
            target = seq[1:].clone()

        return data, target
================================================
FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py
================================================
"""Nucleotide Transformer Benchmarks Dataset.
From: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks
"""
import torch
from datasets import load_dataset
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
class NucleotideTransformerDataset(torch.utils.data.Dataset):
    """Map-style dataset over one task of InstaDeepAI/nucleotide_transformer_downstream_tasks.

    Each item is ``(token_ids, label)``. ``token_ids`` is a LongTensor of shape ``(seq_len,)``,
    or ``(seq_len, 2)`` when the forward and reverse-complement sequences are conjoined.
    ``label`` is a LongTensor of shape ``(1,)``.
    """

    def __init__(
        self,
        split,
        max_length,
        dataset_name=None,
        d_output=2,  # default binary classification
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        conjoin_train=False,
        conjoin_test=False,
        return_augs=False
    ):
        """
        Args:
            split: dataset split passed to ``load_dataset``.
            max_length: tokenizer truncation (and, with ``use_padding``, padding) length.
            dataset_name: task name passed as ``name`` to ``load_dataset``.
            d_output: number of output classes; stored for the decoder.
            tokenizer, tokenizer_name: tokenizer object and its name.
            use_padding: pad every sequence to ``max_length`` when truthy.
            add_eos: append ``tokenizer.sep_token_id`` to each sequence.
            rc_aug: random reverse-complement augmentation (disabled when conjoining).
            conjoin_train: always return forward + RC ids stacked.
            conjoin_test: RC-augment the train split; stack forward + RC on other splits.
            return_augs: stored; not used by this class.
        """
        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
        if (conjoin_train or conjoin_test) and rc_aug:
            print("When using conjoin, we turn off rc_aug.")
            rc_aug = False
        self.rc_aug = rc_aug
        self.conjoin_train = conjoin_train
        self.conjoin_test = conjoin_test
        self.split = split

        # For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks
        self.seqs = load_dataset(
            "InstaDeepAI/nucleotide_transformer_downstream_tasks",
            name=dataset_name,
            split=split
        )

    def __len__(self):
        return len(self.seqs)

    def _tokenize_ids(self, sequence):
        """Tokenize one DNA string into a list of ids, appending EOS when ``add_eos`` is set."""
        encoded = self.tokenizer(
            sequence,
            add_special_tokens=False,
            # Hugging Face tokenizers accept a bool or a strategy name for `padding`; `None` is
            # not a valid value, so use False ("do not pad") when padding is disabled.
            padding="max_length" if self.use_padding else False,
            max_length=self.max_length,
            truncation=True,
        )
        ids = encoded["input_ids"]  # get input_ids
        if self.add_eos:
            # append list seems to be faster than append tensor
            ids.append(self.tokenizer.sep_token_id)
        return ids

    def __getitem__(self, idx):
        x = self.seqs[idx]["sequence"]  # only one sequence
        y = self.seqs[idx]["label"]

        # With conjoin_test the train split gets RC augmentation; other splits are conjoined below.
        if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
            x = string_reverse_complement(x)

        seq_ids = self._tokenize_ids(x)

        if self.conjoin_train or (self.conjoin_test and self.split != "train"):
            seq_rc_ids = self._tokenize_ids(string_reverse_complement(x))
            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)
        else:
            # convert to tensor
            seq_ids = torch.LongTensor(seq_ids)

        # need to wrap in list
        target = torch.LongTensor([y])

        # `seq_ids` has shape:
        #     - (seq_len,) if not conjoining
        #     - (seq_len, 2) for conjoining
        return seq_ids, target
================================================
FILE: src/dataloaders/fault_tolerant_sampler.py
================================================
# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
from typing import Iterator
import math
import torch
from torch.utils.data import RandomSampler, DistributedSampler
class RandomFaultTolerantSampler(RandomSampler):
    """``RandomSampler`` whose position within an epoch can be checkpointed and resumed.

    ``state_dict`` records the generator state used to draw the current permutation and the
    number of indices already yielded. After ``load_state_dict``, the next iteration regenerates
    the same permutation and skips the indices that were already consumed.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called before hand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False
        # Generator state at the start of the current/next epoch. Initialized here so that
        # `state_dict()` also works before the first call to `__iter__`.
        self.state = self.generator.get_state()

    def state_dict(self):
        """Return the epoch-start generator state and the number of indices already yielded."""
        return {"random_state": self.state, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore a ``state_dict``; the next ``__iter__`` resumes after ``counter`` indices."""
        self.generator.set_state(state_dict.get("random_state"))
        self.counter = state_dict["counter"]
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches. So `__len__` is not overridden.

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)

        # Remember the state the permutation is drawn from, so it can be replayed on resume.
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
class FaultTolerantDistributedSampler(DistributedSampler):
    """``DistributedSampler`` that can checkpoint and resume its position within an epoch.

    The per-epoch ordering is a deterministic function of ``seed + epoch``. The resumable state
    therefore only needs the epoch and the number of indices this replica has already yielded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        """Return this replica's progress as ``{"epoch": ..., "counter": ...}``."""
        return {"epoch": self.epoch, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore progress; the next ``__iter__`` skips the first ``counter`` indices."""
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches. So `__len__` is not overridden.

    def _replica_indices(self):
        """Build the full epoch ordering and return the slice owned by this rank."""
        dataset_size = len(self.dataset)  # type: ignore[arg-type]
        if not self.shuffle:
            order = list(range(dataset_size))
        else:
            # deterministically shuffle based on epoch and seed
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(dataset_size, generator=rng).tolist()

        if self.drop_last:
            # remove tail of data to make it evenly divisible.
            order = order[:self.total_size]
        else:
            # add extra samples to make it evenly divisible
            missing = self.total_size - len(order)
            if missing > len(order):
                order += (order * math.ceil(missing / len(order)))[:missing]
            else:
                order += order[:missing]
        assert len(order) == self.total_size

        # subsample: every `num_replicas`-th index starting at this rank
        owned = order[self.rank:self.total_size:self.num_replicas]
        assert len(owned) == self.num_samples
        return owned

    def __iter__(self):
        owned = self._replica_indices()

        if self.restarting:
            owned = owned[self.counter:]
            self.restarting = False
        else:
            self.counter = 0

        for position in owned:
            self.counter += 1
            yield position

        self.counter = 0
================================================
FILE: src/dataloaders/genomics.py
================================================
"""Dataloaders for genomics datasets, including pretraining and downstream tasks.
- Adapted from:
https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
- Adapted from:
https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
"""
import copy
from typing import Any, List, Union
import torch
from datasets import Dataset
from torch.utils.data.dataloader import DataLoader
from caduceus.tokenization_caduceus import CaduceusTokenizer
import src.utils.train
from src.dataloaders.base import SequenceDataset, default_data_path
from src.dataloaders.datasets.genomic_bench_dataset import GenomicBenchmarkDataset
from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer
from src.dataloaders.datasets.hg38_dataset import HG38Dataset
from src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset
from src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler
logger = src.utils.train.get_logger(__name__)  # module-level logger; used e.g. by HG38.setup for tokenizer messages
class HG38(SequenceDataset):
"""
Base class, other dataloaders can inherit from this class.
You must implement the following functions:
- __init__
- setup
You can then use (already have access to) the following functions:
- train_dataloader
- val_dataloader
- test_dataloader
"""
_name_ = "hg38" # this name is how the dataset config finds the right dataloader
def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_name=None, max_length=1024,
             d_output=2, rc_aug=False, max_length_val=None, max_length_test=None, val_ratio=0.0005,
             val_split_seed=2357, add_eos=True, detokenize=False, val_only=False, batch_size=32,
             batch_size_eval=None, shuffle=False, num_workers=1, fault_tolerant=False, ddp=False,
             fast_forward_epochs=None, fast_forward_batches=None, mlm=False, mlm_probability=0.15,
             *args, **kwargs):
    """Store the data-module configuration; the tokenizer and datasets are built in ``setup``.

    ``bed_file`` / ``fasta_file`` default to files under ``default_data_path / self._name_`` when
    None. ``max_length_val`` / ``max_length_test`` default to ``max_length``, and
    ``batch_size_eval`` defaults to ``batch_size``. Extra ``*args`` / ``**kwargs`` are ignored.
    """
    # --- sequence / task options ---
    self.dataset_config_name = dataset_config_name
    self.tokenizer_name = tokenizer_name
    self.d_output = d_output
    self.rc_aug = rc_aug  # reverse complement augmentation

    # --- lengths: evaluation splits fall back to the training length ---
    self.max_length = max_length
    self.max_length_val = max_length if max_length_val is None else max_length_val
    self.max_length_test = max_length if max_length_test is None else max_length_test

    # --- validation split ---
    self.val_ratio = val_ratio
    self.val_split_seed = val_split_seed
    self.val_only = val_only

    # --- token options ---
    self.add_eos = add_eos
    self.detokenize = detokenize

    # --- loader options ---
    self.batch_size = batch_size
    self.batch_size_eval = self.batch_size if batch_size_eval is None else batch_size_eval
    self.shuffle = shuffle
    self.num_workers = num_workers

    # --- data files, falling back to the default locations ---
    self.bed_file = bed_file if bed_file is not None else default_data_path / self._name_ / "human-sequences.bed"
    self.fasta_file = fasta_file if fasta_file is not None else default_data_path / self._name_ / "hg38.ml.fa"

    # --- fault tolerance: needs shuffling; DDP needs fault tolerance ---
    if fault_tolerant:
        assert self.shuffle
    self.fault_tolerant = fault_tolerant
    if ddp:
        assert fault_tolerant
    self.ddp = ddp

    # --- fast-forwarding needs both DDP and fault tolerance ---
    self.fast_forward_epochs = fast_forward_epochs
    self.fast_forward_batches = fast_forward_batches
    if fast_forward_epochs is not None or fast_forward_batches is not None:
        assert ddp and fault_tolerant

    # --- masked language modelling ---
    self.mlm = mlm
    self.mlm_probability = mlm_probability

    # To be instantiated in `setup`
    self.tokenizer = None
    self.vocab_size = 0
def setup(self, stage=None):
    """Build the tokenizer selected by ``self.tokenizer_name``, then (re)create the datasets.

    Only the character-level tokenizer (``"char"``, backed by ``CaduceusTokenizer``) is supported.

    Raises:
        NotImplementedError: for any other ``tokenizer_name``.
    """
    # TODO instantiate with registry
    if self.tokenizer_name != "char":
        raise NotImplementedError(f"Tokenizer {self.tokenizer_name} not implemented.")

    logger.info("**Using Char-level tokenizer**")
    self.tokenizer = CaduceusTokenizer(
        model_max_length=self.max_length,
        add_special_tokens=False,
    )
    self.vocab_size = len(self.tokenizer)

    # Creates the datasets (kept separate from the tokenizer so it can be re-run on its own).
    self.init_datasets()
def init_datasets(self):
    """Create the train/valid/test ``HG38Dataset`` splits (separate from the tokenizer).

    If datasets from an earlier call exist, their fasta handles are closed and dropped first.
    This covers all three splits, so repeated calls do not keep stale file handles open.
    """
    # Release the fasta handles of previously created datasets (train, val and test) to free memory.
    for attr_name in ("dataset_train", "dataset_val", "dataset_test"):
        old_fasta = getattr(getattr(self, attr_name, None), "fasta", None)
        if old_fasta is not None and hasattr(old_fasta, "seqs"):
            old_fasta.seqs.close()
            del old_fasta.seqs

    # Create all splits: torch datasets
    splits = ("train", "valid", "test")
    max_lengths = (self.max_length, self.max_length_val, self.max_length_test)
    self.dataset_train, self.dataset_val, self.dataset_test = [
        HG38Dataset(
            split=split,
            bed_file=self.bed_file,
            fasta_file=self.fasta_file,
            max_length=max_len,
            tokenizer=self.tokenizer,  # pass the tokenize wrapper
            tokenizer_name=self.tokenizer_name,
            add_eos=self.add_eos,
            return_seq_indices=False,
            rc_aug=self.rc_aug,
            return_augs=False,
            mlm=self.mlm,
            mlm_probability=self.mlm_probability,
        )
        for split, max_len in zip(splits, max_lengths)
    ]
def train_dataloader(self, **kwargs: Any) -> DataLoader:
    """Return the training ``DataLoader``.

    If ``shuffle`` and ``fault_tolerant`` are both set, shuffling is delegated to a resumable
    sampler and the loader's own ``shuffle`` flag is turned off. The sampler is
    ``FaultTolerantDistributedSampler`` under DDP and ``RandomFaultTolerantSampler`` otherwise.
    """
    use_resumable_sampler = self.shuffle and self.fault_tolerant
    if not use_resumable_sampler:
        return self._data_loader(self.dataset_train, batch_size=self.batch_size,
                                 shuffle=self.shuffle, sampler=None, **kwargs)

    # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel:
    # In that case the number of replicas and the data parallel rank are more complicated.
    dist_kwargs = self.trainer.distributed_sampler_kwargs
    if not self.ddp:
        sampler = RandomFaultTolerantSampler(self.dataset_train)
    else:
        sampler = FaultTolerantDistributedSampler(self.dataset_train, **dist_kwargs)
        # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
        # We assume that it's being resumed with the same number of GPUs
        if self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
            sampler.load_state_dict({
                "epoch": self.fast_forward_epochs,
                "counter": self.fast_forward_batches * self.batch_size,
            })

    return self._data_loader(self.dataset_train, batch_size=self.batch_size,
                             shuffle=False, sampler=sampler, **kwargs)
def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]
gitextract_kw76odef/ ├── .gitignore ├── LICENSE ├── README.md ├── caduceus/ │ ├── __init__.py │ ├── configuration_caduceus.py │ ├── modeling_caduceus.py │ ├── modeling_rcps.py │ ├── tests/ │ │ └── test_rcps.py │ └── tokenization_caduceus.py ├── caduceus_env.yml ├── configs/ │ ├── callbacks/ │ │ ├── base.yaml │ │ ├── checkpoint.yaml │ │ ├── gpu_affinity.yaml │ │ ├── rich.yaml │ │ ├── val_every_n_global_steps.yaml │ │ └── wandb.yaml │ ├── config.yaml │ ├── dataset/ │ │ ├── genomic_benchmark.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── experiment/ │ │ └── hg38/ │ │ ├── genomic_benchmark.yaml │ │ ├── genomic_benchmark_cnn.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── loader/ │ │ └── default.yaml │ ├── model/ │ │ ├── caduceus.yaml │ │ ├── genomics_benchmark_cnn.yaml │ │ ├── hyena.yaml │ │ ├── layer/ │ │ │ └── hyena.yaml │ │ └── mamba.yaml │ ├── optimizer/ │ │ ├── adam.yaml │ │ ├── adamw.yaml │ │ └── sgd.yaml │ ├── pipeline/ │ │ ├── genomic_benchmark.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── scheduler/ │ │ ├── constant.yaml │ │ ├── constant_warmup.yaml │ │ ├── cosine.yaml │ │ ├── cosine_warmup.yaml │ │ ├── cosine_warmup_timm.yaml │ │ ├── linear_warmup.yaml │ │ ├── multistep.yaml │ │ ├── plateau.yaml │ │ └── step.yaml │ ├── task/ │ │ ├── lm.yaml │ │ ├── multiclass_classification.yaml │ │ ├── multilabel_classification.yaml │ │ └── regression.yaml │ └── trainer/ │ ├── debug.yaml │ ├── default.yaml │ ├── full.yaml │ └── lm.yaml ├── setup_env.sh ├── slurm_scripts/ │ ├── dump_vep_embeddings.sh │ ├── run_genomics_benchmark.sh │ ├── run_genomics_benchmark_cnn.sh │ ├── run_nucleotide_transformer.sh │ ├── run_pretrain_caduceus.sh │ ├── run_pretrain_hyena.sh │ ├── run_pretrain_mamba.sh │ ├── wrapper_run_genomics.sh │ ├── wrapper_run_genomics_cnn.sh │ └── wrapper_run_nucleotide_transformer.sh ├── src/ │ ├── __init__.py │ ├── callbacks/ │ │ ├── params.py │ │ ├── timer.py │ │ └── validation.py │ ├── dataloaders/ │ │ ├── 
__init__.py │ │ ├── base.py │ │ ├── datasets/ │ │ │ ├── genomic_bench_dataset.py │ │ │ ├── hg38_char_tokenizer.py │ │ │ ├── hg38_dataset.py │ │ │ └── nucleotide_transformer_dataset.py │ │ ├── fault_tolerant_sampler.py │ │ ├── genomics.py │ │ └── utils/ │ │ ├── mlm.py │ │ └── rc.py │ ├── models/ │ │ ├── __init__.py │ │ ├── baseline/ │ │ │ ├── __init__.py │ │ │ └── genomics_benchmark_cnn.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── adaptive_softmax.py │ │ │ └── utils.py │ │ └── sequence/ │ │ ├── __init__.py │ │ ├── dna_embedding.py │ │ ├── hyena.py │ │ └── long_conv_lm.py │ ├── ops/ │ │ └── fftconv.py │ ├── tasks/ │ │ ├── decoders.py │ │ ├── encoders.py │ │ ├── metrics.py │ │ ├── tasks.py │ │ └── torchmetrics.py │ └── utils/ │ ├── __init__.py │ ├── config.py │ ├── optim/ │ │ └── schedulers.py │ ├── optim_groups.py │ ├── registry.py │ └── train.py ├── train.py ├── vep_embeddings.py └── vep_svm.ipynb
SYMBOL INDEX (434 symbols across 36 files)
FILE: caduceus/configuration_caduceus.py
class CaduceusConfig (line 10) | class CaduceusConfig(PretrainedConfig):
method __init__ (line 14) | def __init__(
FILE: caduceus/modeling_caduceus.py
function create_block (line 33) | def create_block(
class BiMambaWrapper (line 87) | class BiMambaWrapper(nn.Module):
method __init__ (line 90) | def __init__(
method forward (line 122) | def forward(self, hidden_states, inference_params=None):
class CaduceusEmbeddings (line 143) | class CaduceusEmbeddings(nn.Module):
method __init__ (line 144) | def __init__(
method forward (line 159) | def forward(self, input_ids):
class CaduceusMixerModel (line 166) | class CaduceusMixerModel(nn.Module):
method __init__ (line 167) | def __init__(
method forward (line 216) | def forward(self, input_ids, inputs_embeds=None, output_hidden_states=...
function cross_entropy (line 279) | def cross_entropy(logits, y, ignore_index=-100):
function weighted_cross_entropy (line 286) | def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
class CaduceusPreTrainedModel (line 297) | class CaduceusPreTrainedModel(PreTrainedModel):
method _init_weights (line 304) | def _init_weights(
class Caduceus (line 344) | class Caduceus(CaduceusPreTrainedModel):
method __init__ (line 346) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **...
method forward (line 363) | def forward(
class CaduceusForMaskedLM (line 392) | class CaduceusForMaskedLM(CaduceusPreTrainedModel):
method __init__ (line 395) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **...
method get_input_embeddings (line 417) | def get_input_embeddings(self):
method set_input_embeddings (line 420) | def set_input_embeddings(self, value):
method get_output_embeddings (line 425) | def get_output_embeddings(self):
method set_output_embeddings (line 428) | def set_output_embeddings(self, new_embeddings):
method tie_weights (line 434) | def tie_weights(self):
method get_decoder (line 441) | def get_decoder(self):
method set_decoder (line 445) | def set_decoder(self, decoder):
method forward (line 449) | def forward(
class CaduceusForSequenceClassification (line 495) | class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
method __init__ (line 496) | def __init__(
method init_scorer (line 521) | def init_scorer(self, initializer_range=0.02):
method get_input_embeddings (line 526) | def get_input_embeddings(self):
method set_input_embeddings (line 529) | def set_input_embeddings(self, value):
method pool_hidden_states (line 534) | def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
method forward (line 545) | def forward(
FILE: caduceus/modeling_rcps.py
class RCPSEmbedding (line 21) | class RCPSEmbedding(nn.Module):
method __init__ (line 23) | def __init__(self, vocab_size: int, d_model: int, complement_map: dict...
method weight (line 38) | def weight(self):
method set_weight (line 42) | def set_weight(self, value):
method rc (line 46) | def rc(self, x):
method forward (line 54) | def forward(self, input_ids):
class RCPSWrapper (line 70) | class RCPSWrapper(nn.Module):
method __init__ (line 76) | def __init__(self, submodule: nn.Module):
method rc (line 81) | def rc(x):
method forward (line 85) | def forward(self, x, **kwargs):
class RCPSAddNormWrapper (line 102) | class RCPSAddNormWrapper(RCPSWrapper):
method __init__ (line 104) | def __init__(self, submodule: nn.Module):
method forward (line 107) | def forward(self, x, residual=None, prenorm=False):
class RCPSMambaBlock (line 133) | class RCPSMambaBlock(nn.Module):
method __init__ (line 134) | def __init__(
method forward (line 160) | def forward(
method allocate_inference_cache (line 201) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
class RCPSLMHead (line 209) | class RCPSLMHead(nn.Module):
method __init__ (line 211) | def __init__(self, true_dim: int, vocab_size: int, complement_map: dic...
method weight (line 225) | def weight(self):
method set_weight (line 229) | def set_weight(self, value):
method forward (line 233) | def forward(self, x):
FILE: caduceus/tests/test_rcps.py
function test_rcps_embedding (line 31) | def test_rcps_embedding(batch_size, seq_len, d_model, dtype):
function test_rcps_wrapper (line 80) | def test_rcps_wrapper(batch_size, seq_len, d_model, dtype):
function test_rcps_add_norm_wrapper (line 116) | def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dt...
function test_rcps_mamba_block_wrapper (line 155) | def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirect...
function test_rcps_lm_head (line 209) | def test_rcps_lm_head(batch_size, seq_len, d_model, dtype):
function test_rcps_backbone (line 271) | def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fus...
function test_rcps_mamba_lm (line 348) | def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bid...
function test_collapse_invariance (line 429) | def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtyp...
FILE: caduceus/tokenization_caduceus.py
class CaduceusTokenizer (line 10) | class CaduceusTokenizer(PreTrainedTokenizer):
method __init__ (line 13) | def __init__(self,
method vocab_size (line 83) | def vocab_size(self) -> int:
method complement_map (line 87) | def complement_map(self) -> Dict[int, int]:
method _tokenize (line 90) | def _tokenize(self, text: str, **kwargs) -> List[str]:
method _convert_token_to_id (line 93) | def _convert_token_to_id(self, token: str) -> int:
method _convert_id_to_token (line 96) | def _convert_id_to_token(self, index: int) -> str:
method convert_tokens_to_string (line 99) | def convert_tokens_to_string(self, tokens):
method get_special_tokens_mask (line 102) | def get_special_tokens_mask(
method build_inputs_with_special_tokens (line 120) | def build_inputs_with_special_tokens(
method get_vocab (line 130) | def get_vocab(self) -> Dict[str, int]:
method save_vocabulary (line 134) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...
FILE: src/callbacks/params.py
class ParamsLog (line 10) | class ParamsLog(pl.Callback):
method __init__ (line 12) | def __init__(
method on_fit_start (line 28) | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningMod...
FILE: src/callbacks/timer.py
class Timer (line 18) | class Timer(Callback):
method __init__ (line 21) | def __init__(
method on_train_start (line 36) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule)...
method on_train_epoch_start (line 39) | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningM...
method on_train_batch_start (line 44) | def on_train_batch_start(
method on_train_batch_end (line 65) | def on_train_batch_end(
method on_train_epoch_end (line 86) | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningMod...
method on_validation_epoch_start (line 92) | def on_validation_epoch_start(self, trainer: Trainer, pl_module: Light...
method on_validation_epoch_end (line 96) | def on_validation_epoch_end(self, trainer: Trainer, pl_module: Lightni...
method _should_log (line 103) | def _should_log(trainer) -> bool:
FILE: src/callbacks/validation.py
class ValEveryNGlobalSteps (line 13) | class ValEveryNGlobalSteps(Callback):
method __init__ (line 15) | def __init__(self, every_n):
method on_train_batch_end (line 19) | def on_train_batch_end(self, trainer, *_: Any):
FILE: src/dataloaders/base.py
class DefaultCollateMixin (line 20) | class DefaultCollateMixin:
method _collate_callback (line 30) | def _collate_callback(cls, x, *args, **kwargs):
method _return_callback (line 39) | def _return_callback(cls, return_value, *args, **kwargs):
method _collate (line 50) | def _collate(cls, batch, *args, **kwargs):
method _collate_fn (line 71) | def _collate_fn(cls, batch, *args, **kwargs):
method _dataloader (line 92) | def _dataloader(self, dataset, **loader_args):
class SequenceDataset (line 106) | class SequenceDataset(DefaultCollateMixin):
method init_defaults (line 115) | def init_defaults(self):
method __init_subclass__ (line 119) | def __init_subclass__(cls, **kwargs):
method __init__ (line 123) | def __init__(self, _name_, data_dir=None, **dataset_cfg):
method init (line 138) | def init(self):
method setup (line 142) | def setup(self):
method split_train_val (line 146) | def split_train_val(self, val_split):
method train_dataloader (line 159) | def train_dataloader(self, **kwargs):
method _train_dataloader (line 163) | def _train_dataloader(self, dataset, **kwargs):
method val_dataloader (line 169) | def val_dataloader(self, **kwargs):
method test_dataloader (line 173) | def test_dataloader(self, **kwargs):
method _eval_dataloader (line 177) | def _eval_dataloader(self, dataset, **kwargs):
method __str__ (line 183) | def __str__(self):
FILE: src/dataloaders/datasets/genomic_bench_dataset.py
class GenomicBenchmarkDataset (line 15) | class GenomicBenchmarkDataset(torch.utils.data.Dataset):
method __init__ (line 21) | def __init__(
method __len__ (line 80) | def __len__(self):
method __getitem__ (line 83) | def __getitem__(self, idx):
FILE: src/dataloaders/datasets/hg38_char_tokenizer.py
class CharacterTokenizer (line 15) | class CharacterTokenizer(PreTrainedTokenizer):
method __init__ (line 16) | def __init__(self, characters: Sequence[str], model_max_length: int, p...
method vocab_size (line 77) | def vocab_size(self) -> int:
method _tokenize (line 80) | def _tokenize(self, text: str) -> List[str]:
method _convert_token_to_id (line 83) | def _convert_token_to_id(self, token: str) -> int:
method _convert_id_to_token (line 86) | def _convert_id_to_token(self, index: int) -> str:
method convert_tokens_to_string (line 89) | def convert_tokens_to_string(self, tokens):
method build_inputs_with_special_tokens (line 92) | def build_inputs_with_special_tokens(
method get_special_tokens_mask (line 102) | def get_special_tokens_mask(
method get_vocab (line 120) | def get_vocab(self) -> Dict[str, int]:
method create_token_type_ids_from_sequences (line 123) | def create_token_type_ids_from_sequences(
method get_config (line 134) | def get_config(self) -> Dict:
method from_config (line 141) | def from_config(cls, config: Dict) -> "CharacterTokenizer":
method save_pretrained (line 147) | def save_pretrained(self, save_directory: Union[str, os.PathLike], **k...
method from_pretrained (line 154) | def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kw...
FILE: src/dataloaders/datasets/hg38_dataset.py
class FastaInterval (line 18) | class FastaInterval:
method __init__ (line 20) | def __init__(
method _compute_interval (line 41) | def _compute_interval(start, end, max_length, i_shift):
method __call__ (line 50) | def __call__(
class HG38Dataset (line 92) | class HG38Dataset(torch.utils.data.Dataset):
method __init__ (line 95) | def __init__(
method replace_value (line 153) | def replace_value(x, old_value, new_value):
method __len__ (line 157) | def __len__(self):
method __getitem__ (line 160) | def __getitem__(self, idx):
FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py
class NucleotideTransformerDataset (line 12) | class NucleotideTransformerDataset(torch.utils.data.Dataset):
method __init__ (line 19) | def __init__(
method __len__ (line 59) | def __len__(self):
method __getitem__ (line 62) | def __getitem__(self, idx):
FILE: src/dataloaders/fault_tolerant_sampler.py
class RandomFaultTolerantSampler (line 9) | class RandomFaultTolerantSampler(RandomSampler):
method __init__ (line 11) | def __init__(self, *args, generator=None, **kwargs):
method state_dict (line 26) | def state_dict(self):
method load_state_dict (line 29) | def load_state_dict(self, state_dict):
method __iter__ (line 43) | def __iter__(self) -> Iterator[int]:
class FaultTolerantDistributedSampler (line 64) | class FaultTolerantDistributedSampler(DistributedSampler):
method __init__ (line 66) | def __init__(self, *args, **kwargs):
method state_dict (line 72) | def state_dict(self):
method load_state_dict (line 75) | def load_state_dict(self, state_dict):
method __iter__ (line 86) | def __iter__(self):
FILE: src/dataloaders/genomics.py
class HG38 (line 29) | class HG38(SequenceDataset):
method __init__ (line 45) | def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_...
method setup (line 97) | def setup(self, stage=None):
method init_datasets (line 119) | def init_datasets(self):
method train_dataloader (line 152) | def train_dataloader(self, **kwargs: Any) -> DataLoader:
method val_dataloader (line 177) | def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Data...
method test_dataloader (line 182) | def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Dat...
method _data_loader (line 189) | def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = Fa...
method load_state_dict (line 198) | def load_state_dict(self, checkpoint):
class GenomicBenchmark (line 208) | class GenomicBenchmark(HG38):
method __init__ (line 212) | def __init__(
method setup (line 262) | def setup(self, stage=None):
class NucleotideTransformer (line 308) | class NucleotideTransformer(HG38):
method __init__ (line 312) | def __init__(self, dataset_name, train_val_split_seed,
method setup (line 357) | def setup(self, stage=None):
FILE: src/dataloaders/utils/mlm.py
function mlm_getitem (line 4) | def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer...
FILE: src/dataloaders/utils/rc.py
function coin_flip (line 12) | def coin_flip(p=0.5):
function string_reverse_complement (line 17) | def string_reverse_complement(seq):
FILE: src/models/baseline/genomics_benchmark_cnn.py
class GenomicsBenchmarkCNN (line 10) | class GenomicsBenchmarkCNN(nn.Module):
method __init__ (line 11) | def __init__(self, number_of_classes, vocab_size, input_len, embedding...
method count_flatten_size (line 42) | def count_flatten_size(self, input_len):
method forward (line 49) | def forward(self, x, state=None): # Adding `state` to be consistent w...
FILE: src/models/nn/activation.py
function Activation (line 9) | def Activation(activation=None, size=None, dim=-1):
class GLU (line 45) | class GLU(nn.Module):
method __init__ (line 46) | def __init__(self, dim=-1, activation='sigmoid'):
method forward (line 52) | def forward(self, x):
class ModReLU (line 57) | class ModReLU(nn.Module):
method __init__ (line 60) | def __init__(self, features):
method reset_parameters (line 67) | def reset_parameters(self):
method forward (line 70) | def forward(self, inputs):
class SquaredReLU (line 79) | class SquaredReLU(nn.Module):
method forward (line 80) | def forward(self, x):
function laplace (line 85) | def laplace(x, mu=0.707107, sigma=0.282095):
class Laplace (line 90) | class Laplace(nn.Module):
method __init__ (line 91) | def __init__(self, mu=0.707107, sigma=0.282095):
method forward (line 96) | def forward(self, x):
FILE: src/models/nn/adaptive_softmax.py
class OptionalParameterList (line 23) | class OptionalParameterList(nn.ParameterList):
method extra_repr (line 24) | def extra_repr(self):
class ProjectedAdaptiveLogSoftmax (line 37) | class ProjectedAdaptiveLogSoftmax(nn.Module):
method __init__ (line 38) | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
method _compute_logit (line 128) | def _compute_logit(self, hidden, weight, bias, proj):
method get_out_proj (line 142) | def get_out_proj(self, i):
method forward (line 153) | def forward(self, hidden, target, keep_order=False, key_padding_mask=N...
method compute_logits (line 237) | def compute_logits(self, hidden):
class AdaptiveEmbedding (line 300) | class AdaptiveEmbedding(nn.Module):
method __init__ (line 305) | def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_...
method forward (line 342) | def forward(self, inp):
function _init_weight (line 395) | def _init_weight(weight, d : int, init_scale : Optional[float], default=...
FILE: src/models/nn/utils.py
function wrap_kwargs (line 8) | def wrap_kwargs(f):
function discard_kwargs (line 84) | def discard_kwargs(f):
function PassthroughSequential (line 92) | def PassthroughSequential(*modules):
FILE: src/models/sequence/dna_embedding.py
class DNAEmbeddingModel (line 27) | class DNAEmbeddingModel(nn.Module, GenerationMixin):
method __init__ (line 34) | def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_siz...
method forward (line 82) | def forward(self, input_ids, position_ids=None, inference_params=None,...
method d_output (line 90) | def d_output(self):
class DNAEmbeddingModelMamba (line 99) | class DNAEmbeddingModelMamba(DNAEmbeddingModel):
method __init__ (line 102) | def __init__(
method forward (line 149) | def forward(self, input_ids, position_ids=None, inference_params=None,...
class DNAEmbeddingModelCaduceus (line 156) | class DNAEmbeddingModelCaduceus(DNAEmbeddingModel):
method __init__ (line 159) | def __init__(
method forward (line 179) | def forward(self, input_ids, position_ids=None, inference_params=None,...
function load_backbone (line 198) | def load_backbone(model, state_dict, freeze_backbone=False, ignore_head=...
FILE: src/models/sequence/hyena.py
class FFTConvFuncv2 (line 25) | class FFTConvFuncv2(torch.autograd.Function):
method forward (line 27) | def forward(ctx, u, k):
method backward (line 40) | def backward(ctx, dout):
function fftconv_ref (line 55) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
function mul_sum (line 79) | def mul_sum(q, y):
class Sin (line 83) | class Sin(nn.Module):
method __init__ (line 84) | def __init__(self, dim, w=10, train_freq=True):
method forward (line 92) | def forward(self, x):
class PositionalEmbedding (line 96) | class PositionalEmbedding(OptimModule):
method __init__ (line 97) | def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-...
method forward (line 117) | def forward(self, L):
class ExponentialModulation (line 121) | class ExponentialModulation(OptimModule):
method __init__ (line 122) | def __init__(
method forward (line 139) | def forward(self, t, x):
class HyenaFilter (line 145) | class HyenaFilter(OptimModule):
method __init__ (line 146) | def __init__(
method filter (line 214) | def filter(self, L, *args, **kwargs):
method forward (line 225) | def forward(self, x, L, k=None, bias=None, *args, **kwargs):
class HyenaOperator (line 255) | class HyenaOperator(nn.Module):
method __init__ (line 256) | def __init__(
method setup_projections (line 330) | def setup_projections(self, fused_bias_fc, inner_factor):
method setup_filters (line 343) | def setup_filters(self, filter_cls, filter_args):
method recurrence (line 369) | def recurrence(self, u, state):
method forward (line 373) | def forward(self, u, *args, **kwargs):
method d_output (line 432) | def d_output(self):
FILE: src/models/sequence/long_conv_lm.py
class CheckpointedModule (line 33) | class CheckpointedModule(torch.nn.Module):
method __init__ (line 34) | def __init__(self, layer):
method forward (line 38) | def forward(self, x):
function create_mixer_cls (line 42) | def create_mixer_cls(
function create_mlp_cls (line 93) | def create_mlp_cls(
function create_block (line 130) | def create_block(
function _init_weights (line 195) | def _init_weights(
class LMBackbone (line 240) | class LMBackbone(nn.Module):
method __init__ (line 241) | def __init__(
method tie_weights (line 344) | def tie_weights(self):
method forward (line 348) | def forward(self, input_ids, position_ids=None, inference_params=None):
class ConvLMHeadModel (line 391) | class ConvLMHeadModel(nn.Module, GenerationMixin):
method __init__ (line 392) | def __init__(
method tie_weights (line 473) | def tie_weights(self):
method forward (line 478) | def forward(
FILE: src/ops/fftconv.py
function _mul_sum (line 11) | def _mul_sum(y, q):
function fftconv_ref (line 15) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
function fftconv_h3_ref (line 38) | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=No...
class FFTConvFunc (line 58) | class FFTConvFunc(torch.autograd.Function):
method forward (line 61) | def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_out...
method backward (line 88) | def backward(ctx, dout):
function fftconv_func (line 105) | def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_outpu...
FILE: src/tasks/decoders.py
class Decoder (line 16) | class Decoder(nn.Module):
method forward (line 21) | def forward(self, x, **kwargs):
method step (line 33) | def step(self, x):
class SequenceDecoder (line 40) | class SequenceDecoder(Decoder):
method __init__ (line 41) | def __init__(
method forward (line 70) | def forward(self, x, state=None, lengths=None, l_output=None):
method step (line 156) | def step(self, x, state=None):
function _instantiate (line 197) | def _instantiate(decoder, model=None, dataset=None):
function instantiate (line 217) | def instantiate(decoder, model=None, dataset=None):
FILE: src/tasks/encoders.py
class Encoder (line 7) | class Encoder(nn.Module):
method forward (line 16) | def forward(self, x, **kwargs):
function _instantiate (line 64) | def _instantiate(encoder, dataset=None, model=None):
function instantiate (line 84) | def instantiate(encoder, dataset=None, model=None):
FILE: src/tasks/metrics.py
class CorrectAggregatedMetric (line 13) | class CorrectAggregatedMetric(Metric):
method __init__ (line 16) | def __init__(self, class_idx: int, dist_sync_on_step=False):
method _update (line 25) | def _update(self, numerator, denominator, preds, y) -> tuple:
method update (line 28) | def update(self, logits: torch.Tensor, y: torch.Tensor):
method compute (line 36) | def compute(self):
method reset (line 41) | def reset(self):
class AccuracyPerClass (line 45) | class AccuracyPerClass(CorrectAggregatedMetric):
method _update (line 48) | def _update(self, numerator, denominator, preds, y) -> tuple:
class PrecisionPerClass (line 59) | class PrecisionPerClass(CorrectAggregatedMetric):
method _update (line 62) | def _update(self, numerator, denominator, preds, y) -> tuple:
class RecallPerClass (line 71) | class RecallPerClass(CorrectAggregatedMetric):
method _update (line 74) | def _update(self, numerator, denominator, preds, y) -> tuple:
function mcc (line 83) | def mcc(logits, y):
function last_k_ppl (line 90) | def last_k_ppl(logits, y, seq_len=1024, k=None):
function _student_t_map (line 122) | def _student_t_map(mu, sigma, nu):
function student_t_loss (line 127) | def student_t_loss(outs, y):
function gaussian_ll_loss (line 144) | def gaussian_ll_loss(outs, y):
function binary_cross_entropy (line 155) | def binary_cross_entropy(logits, y):
function binary_accuracy (line 161) | def binary_accuracy(logits, y):
function padded_cross_entropy (line 164) | def padded_cross_entropy(logits, y, pad_mask, pad_value=-1):
function cross_entropy (line 181) | def cross_entropy(logits, y, ignore_index=-100):
function soft_cross_entropy (line 187) | def soft_cross_entropy(logits, y, label_smoothing=0.0):
function accuracy (line 193) | def accuracy(logits, y):
function accuracy_ignore_index (line 203) | def accuracy_ignore_index(logits, y, ignore_index=-100):
function accuracy_at_k (line 212) | def accuracy_at_k(logits, y, k=1):
function f1_binary (line 221) | def f1_binary(logits, y):
function f1_macro (line 228) | def f1_macro(logits, y):
function f1_micro (line 235) | def f1_micro(logits, y):
function roc_auc_macro (line 242) | def roc_auc_macro(logits, y):
function roc_auc_micro (line 252) | def roc_auc_micro(logits, y):
function mse (line 260) | def mse(outs, y, len_batch=None):
function forecast_rmse (line 278) | def forecast_rmse(outs, y, len_batch=None):
function mae (line 282) | def mae(outs, y, len_batch=None):
function loss (line 301) | def loss(x, y, loss_fn):
function bpb (line 306) | def bpb(x, y, loss_fn):
function ppl (line 311) | def ppl(x, y, loss_fn):
FILE: src/tasks/tasks.py
class BaseTask (line 16) | class BaseTask:
method __init__ (line 27) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,...
method _init_torchmetrics (line 55) | def _init_torchmetrics(self):
method _reset_torchmetrics (line 83) | def _reset_torchmetrics(self, prefix=None):
method get_torchmetrics (line 96) | def get_torchmetrics(self, prefix):
method torchmetrics (line 105) | def torchmetrics(self, x, y, prefix, loss=None):
method get_torchmetrics (line 124) | def get_torchmetrics(self, prefix):
method metrics (line 127) | def metrics(self, x, y, **kwargs):
method forward (line 143) | def forward(self, batch, encoder, model, decoder, _state):
class Scalar (line 161) | class Scalar(nn.Module):
method __init__ (line 162) | def __init__(self, c=1):
method forward (line 166) | def forward(self, x):
class LMTask (line 170) | class LMTask(BaseTask):
method forward (line 171) | def forward(self, batch, encoder, model, decoder, _state):
class MultiClass (line 199) | class MultiClass(BaseTask):
method __init__ (line 201) | def __init__(self, *args, **kwargs):
method metrics (line 211) | def metrics(self, x, y, **kwargs):
method _reset_torchmetrics (line 236) | def _reset_torchmetrics(self, prefix=None):
class HG38Task (line 244) | class HG38Task(LMTask):
method __init__ (line 246) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,...
method metrics (line 303) | def metrics(self, x, y, **kwargs):
class AdaptiveLMTask (line 335) | class AdaptiveLMTask(BaseTask):
method __init__ (line 336) | def __init__(
FILE: src/tasks/torchmetrics.py
class Perplexity (line 24) | class Perplexity(Metric):
method __init__ (line 46) | def __init__(self, **kwargs: Dict[str, Any]):
method update (line 54) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
method compute (line 68) | def compute(self) -> Tensor:
class NumTokens (line 75) | class NumTokens(Metric):
method __init__ (line 88) | def __init__(self, **kwargs: Dict[str, Any]):
method update (line 97) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
method compute (line 100) | def compute(self) -> Tensor:
method reset (line 103) | def reset(self):
method _forward_reduce_state_update (line 109) | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
FILE: src/utils/config.py
function is_list (line 13) | def is_list(x):
function is_dict (line 17) | def is_dict(x):
function to_dict (line 21) | def to_dict(x, recursive=True):
function to_list (line 37) | def to_list(x, recursive=False):
function extract_attrs_from_obj (line 56) | def extract_attrs_from_obj(obj, *attrs):
function auto_assign_attrs (line 63) | def auto_assign_attrs(cls, **kwargs):
function instantiate (line 68) | def instantiate(registry, config, *args, partial=False, wrap=None, **kwa...
function get_class (line 112) | def get_class(registry, _name_):
function omegaconf_filter_keys (line 116) | def omegaconf_filter_keys(d, fn=None):
FILE: src/utils/optim/schedulers.py
class CosineWarmup (line 11) | class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR):
method __init__ (line 13) | def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs):
method get_lr (line 19) | def get_lr(self):
function InvSqrt (line 40) | def InvSqrt(optimizer, warmup_step):
function Constant (line 54) | def Constant(optimizer, warmup_step):
class TimmCosineLRScheduler (line 65) | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler....
method __init__ (line 70) | def __init__(self, *args, **kwargs):
method step (line 75) | def step(self, epoch=None):
FILE: src/utils/optim_groups.py
function add_optimizer_hooks (line 14) | def add_optimizer_hooks(
function group_parameters_for_optimizer (line 41) | def group_parameters_for_optimizer(
FILE: src/utils/train.py
class LoggingContext (line 20) | class LoggingContext:
method __init__ (line 21) | def __init__(self, logger, level=None, handler=None, close=True):
method __enter__ (line 27) | def __enter__(self):
method __exit__ (line 34) | def __exit__(self, et, ev, tb):
function get_logger (line 44) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
function process_config (line 58) | def process_config(config: DictConfig) -> DictConfig: # TODO because of...
function print_config (line 101) | def print_config(
function log_optimizer (line 143) | def log_optimizer(logger, optimizer, keys):
class OptimModule (line 154) | class OptimModule(nn.Module):
method register (line 157) | def register(self, name, tensor, lr=None, wd=0.0):
FILE: train.py
class DummyExperiment (line 43) | class DummyExperiment:
method nop (line 46) | def nop(self, *args, **kw):
method __getattr__ (line 49) | def __getattr__(self, _):
method __getitem__ (line 52) | def __getitem__(self, idx) -> "DummyExperiment":
method __setitem__ (line 56) | def __setitem__(self, *args, **kwargs) -> None:
function rank_zero_experiment (line 60) | def rank_zero_experiment(fn: Callable) -> Callable:
class CustomWandbLogger (line 74) | class CustomWandbLogger(WandbLogger):
method __init__ (line 76) | def __init__(self, *args, **kwargs):
method experiment (line 83) | def experiment(self):
class SequenceLightningModule (line 126) | class SequenceLightningModule(pl.LightningModule):
method __init__ (line 127) | def __init__(self, config):
method setup (line 159) | def setup(self, stage=None):
method load_state_dict (line 240) | def load_state_dict(self, state_dict, strict=False):
method _check_config (line 255) | def _check_config(self):
method _initialize_state (line 268) | def _initialize_state(self):
method _reset_state (line 273) | def _reset_state(self, batch, device=None):
method _detach_state (line 278) | def _detach_state(self, state):
method _process_state (line 292) | def _process_state(self, batch, batch_idx, training=True):
method forward (line 326) | def forward(self, batch):
method step (line 329) | def step(self, x_t):
method _shared_step (line 336) | def _shared_step(self, batch, batch_idx, prefix="train"):
method on_train_epoch_start (line 379) | def on_train_epoch_start(self):
method training_epoch_end (line 383) | def training_epoch_end(self, outputs):
method on_validation_epoch_start (line 387) | def on_validation_epoch_start(self):
method validation_epoch_end (line 392) | def validation_epoch_end(self, outputs):
method on_test_epoch_start (line 396) | def on_test_epoch_start(self):
method test_epoch_end (line 401) | def test_epoch_end(self, outputs):
method training_step (line 405) | def training_step(self, batch, batch_idx, dataloader_idx=0):
method validation_step (line 438) | def validation_step(self, batch, batch_idx, dataloader_idx=0):
method test_step (line 455) | def test_step(self, batch, batch_idx, dataloader_idx=0):
method configure_optimizers (line 460) | def configure_optimizers(self):
method train_dataloader (line 543) | def train_dataloader(self):
method _eval_dataloaders_names (line 546) | def _eval_dataloaders_names(self, loaders, prefix):
method _eval_dataloaders (line 557) | def _eval_dataloaders(self):
method val_dataloader (line 584) | def val_dataloader(self):
method test_dataloader (line 589) | def test_dataloader(self):
function create_trainer (line 596) | def create_trainer(config, **kwargs):
function fsspec_exists (line 649) | def fsspec_exists(filename):
function train (line 654) | def train(config):
function main (line 701) | def main(config: OmegaConf):
FILE: vep_embeddings.py
class DNAEmbeddingModel (line 30) | class DNAEmbeddingModel(nn.Module):
method __init__ (line 36) | def __init__(
method forward (line 56) | def forward(self, input_ids):
class EnformerTokenizer (line 63) | class EnformerTokenizer:
method encode (line 70) | def encode(
method batch_encode_plus (line 83) | def batch_encode_plus(
function setup_distributed (line 92) | def setup_distributed():
function cleanup_distributed (line 97) | def cleanup_distributed():
function fsspec_exists (line 102) | def fsspec_exists(filename):
function fsspec_listdir (line 108) | def fsspec_listdir(dirname):
function recast_chromosome_tissue_dist2TSS (line 115) | def recast_chromosome_tissue_dist2TSS(examples):
function tokenize_variants (line 124) | def tokenize_variants(examples, tokenizer, max_length: int):
function find_variant_idx (line 172) | def find_variant_idx(examples):
function prepare_dataset (line 198) | def prepare_dataset(args, tokenizer):
function get_backbone_model (line 260) | def get_backbone_model(args, device):
function concat_storage_dict_values (line 270) | def concat_storage_dict_values(storage_dict):
function dump_embeddings (line 275) | def dump_embeddings(args, dataset, model, device):
function combine_embeddings (line 407) | def combine_embeddings(embeds_path):
function main (line 433) | def main(args):
Condensed preview — 104 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (427K chars).
[
{
"path": ".gitignore",
"chars": 1815,
"preview": "data.tar.gz\n*.tsf\n*.ckpt\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\n*.lprof\n\n.DS_Store\n.idea/\noutputs/\n\n# slurm log files"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 13053,
"preview": "<p align=\"center\">\n <img src=\"assets/Caduceus_image.png\" alt=\"Caduceus\" width=\"200\"/>\n</p>\n\n\n# Caduceus ☤: Bi-D"
},
{
"path": "caduceus/__init__.py",
"chars": 265,
"preview": "\"\"\"Hugging Face config, model, and tokenizer for Caduceus.\n\n\"\"\"\n\nfrom .configuration_caduceus import CaduceusConfig\nfrom"
},
{
"path": "caduceus/configuration_caduceus.py",
"chars": 1964,
"preview": "\"\"\"Caduceus config for Hugging Face.\n\n\"\"\"\n\nfrom typing import Optional, Union\n\nfrom transformers import PretrainedConfig"
},
{
"path": "caduceus/modeling_caduceus.py",
"chars": 27326,
"preview": "\"\"\"Caduceus model for Hugging Face.\n\n\"\"\"\n\nimport inspect\nimport math\nfrom functools import partial\nfrom typing import Op"
},
{
"path": "caduceus/modeling_rcps.py",
"chars": 9977,
"preview": "\"\"\"Reverse-complement equivariant modules.\n\n\"\"\"\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport "
},
{
"path": "caduceus/tests/test_rcps.py",
"chars": 19337,
"preview": "\"\"\"Tests for RCPS modules.\n\n\"\"\"\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nt"
},
{
"path": "caduceus/tokenization_caduceus.py",
"chars": 4966,
"preview": "\"\"\"Character tokenizer for Hugging Face.\n\n\"\"\"\n\nfrom typing import List, Optional, Dict, Sequence, Tuple\n\nfrom transforme"
},
{
"path": "caduceus_env.yml",
"chars": 1109,
"preview": "name: caduceus_env\nchannels:\n - pytorch\n - anaconda\n - nvidia\n - defaults\ndependencies:\n - cuda-nvcc=11.7.99\n - pi"
},
{
"path": "configs/callbacks/base.yaml",
"chars": 322,
"preview": "learning_rate_monitor:\n # _target_: pytorch_lightning.callbacks.LearningRateMonitor\n logging_interval: ${train.interva"
},
{
"path": "configs/callbacks/checkpoint.yaml",
"chars": 1320,
"preview": "model_checkpoint:\n monitor: ${train.monitor} # name of the logged metric which determines when model is improving\n mod"
},
{
"path": "configs/callbacks/gpu_affinity.yaml",
"chars": 37,
"preview": "gpu_affinity:\n _name_: gpu_affinity\n"
},
{
"path": "configs/callbacks/rich.yaml",
"chars": 86,
"preview": "rich_model_summary:\n max_depth: 2\n\nrich_progress_bar:\n refresh_rate_per_second: 1.0\n"
},
{
"path": "configs/callbacks/val_every_n_global_steps.yaml",
"chars": 43,
"preview": "val_every_n_global_steps:\n every_n: 10000\n"
},
{
"path": "configs/callbacks/wandb.yaml",
"chars": 667,
"preview": "defaults:\n - default\n\nwatch_model:\n _target_: src.callbacks.wandb_callbacks.WatchModel\n log: \"all\"\n log_freq: 100\n\nu"
},
{
"path": "configs/config.yaml",
"chars": 3115,
"preview": "# @package _global_\ndefaults:\n - _self_\n - experiment: ???\n # - model: ??? # Model backbone\n # - pipeline: ??? # S"
},
{
"path": "configs/dataset/genomic_benchmark.yaml",
"chars": 2054,
"preview": "_name_: genomic_benchmark\ntrain_val_split_seed: ${train.seed} # Used for train/validation splitting\ndataset_name: dummy"
},
{
"path": "configs/dataset/hg38.yaml",
"chars": 299,
"preview": "_name_: hg38\nbed_file: null\nfasta_file: null\ndataset_name: hg38\ntokenizer_name: null\ncache_dir: null\nmax_length: 1024\nad"
},
{
"path": "configs/dataset/nucleotide_transformer.yaml",
"chars": 2703,
"preview": "_name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets\ntrain"
},
{
"path": "configs/experiment/hg38/genomic_benchmark.yaml",
"chars": 2476,
"preview": "# @package _global_\ndefaults:\n - /pipeline: genomic_benchmark\n - /model: ???\n - override /scheduler: cosine_warmup_ti"
},
{
"path": "configs/experiment/hg38/genomic_benchmark_cnn.yaml",
"chars": 2213,
"preview": "# @package _global_\ndefaults:\n - /model: genomics_benchmark_cnn\n - /pipeline: genomic_benchmark\n - override /schedule"
},
{
"path": "configs/experiment/hg38/hg38.yaml",
"chars": 1530,
"preview": "# @package _global_\ndefaults:\n - /pipeline: hg38\n - /model: ??? # Specify a model, e.g. model=mamba or model=hyena\n "
},
{
"path": "configs/experiment/hg38/nucleotide_transformer.yaml",
"chars": 1502,
"preview": "# @package _global_\ndefaults:\n - /pipeline: nucleotide_transformer\n - /model: ???\n - override /scheduler: cosine_warm"
},
{
"path": "configs/loader/default.yaml",
"chars": 99,
"preview": "num_workers: ${eval:\"len(__import__('os').sched_getaffinity(0))\"}\npin_memory: True\ndrop_last: True\n"
},
{
"path": "configs/model/caduceus.yaml",
"chars": 1055,
"preview": "# Use open-source version of Mamba\n_name_: caduceus_lm\nconfig:\n _target_: caduceus.configuration_caduceus.CaduceusConfi"
},
{
"path": "configs/model/genomics_benchmark_cnn.yaml",
"chars": 277,
"preview": "# Use open-source version of Mamba\n_name_: genomics_benchmark_cnn\nnumber_of_classes: ${dataset.d_output}\nvocab_size: 12\n"
},
{
"path": "configs/model/hyena.yaml",
"chars": 522,
"preview": "_name_: hyena_lm\nd_model: 128\nn_layer: 2\nd_inner: ${eval:4 * ${.d_model}}\nvocab_size: 12\nresid_dropout: 0.0\nembed_dropou"
},
{
"path": "configs/model/layer/hyena.yaml",
"chars": 275,
"preview": "_name_: hyena\nl_max: 1024\norder: 2\nfilter_order: 64\nnum_heads: 1\ninner_factor: 1\nnum_blocks: 1\nfused_bias_fc: false\noute"
},
{
"path": "configs/model/mamba.yaml",
"chars": 745,
"preview": "# Use open-source version of Mamba\n_name_: mamba_lm\nconfig:\n _target_: mamba_ssm.models.config_mamba.MambaConfig\n d_mo"
},
{
"path": "configs/optimizer/adam.yaml",
"chars": 184,
"preview": "# _target_: torch.optim.Adam\n_name_: adam\nlr: 0.001 # Initial learning rate\n# weight_decay: 0.0 # Weight decay for ada"
},
{
"path": "configs/optimizer/adamw.yaml",
"chars": 132,
"preview": "# _target_: torch.optim.AdamW\n_name_: adamw\nlr: 0.001 # Initial learning rate\nweight_decay: 0.00 # Weight decay\nbetas: ["
},
{
"path": "configs/optimizer/sgd.yaml",
"chars": 137,
"preview": "# _target_: torch.optim.SGD\n_name_: sgd\nlr: 0.001 # Initial learning rate\nmomentum: 0.9\nweight_decay: 0.0 # Weight dec"
},
{
"path": "configs/pipeline/genomic_benchmark.yaml",
"chars": 388,
"preview": "# @package _global_\ndefaults:\n - /trainer: default\n - /loader: default\n - /dataset: genomic_benchmark\n - /task: mult"
},
{
"path": "configs/pipeline/hg38.yaml",
"chars": 605,
"preview": "# @package _global_\ndefaults:\n - /trainer: default\n - /loader: null\n - /dataset: hg38\n - /optimizer: adamw\n - /sche"
},
{
"path": "configs/pipeline/nucleotide_transformer.yaml",
"chars": 446,
"preview": "# @package _global_\ndefaults:\n - /trainer: default\n - /loader: default\n - /dataset: nucleotide_transformer\n - /task:"
},
{
"path": "configs/scheduler/constant.yaml",
"chars": 124,
"preview": "# @package _global_\ntrain:\n interval: epoch\nscheduler:\n # _target_: transformers.get_constant_schedule\n _name_: const"
},
{
"path": "configs/scheduler/constant_warmup.yaml",
"chars": 205,
"preview": "# @package _global_\ntrain:\n interval: step\nscheduler:\n # _target_: transformers.get_constant_schedule_with_warmup\n _n"
},
{
"path": "configs/scheduler/cosine.yaml",
"chars": 248,
"preview": "# @package _global_\ntrain:\n interval: epoch\nscheduler:\n # _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n _name"
},
{
"path": "configs/scheduler/cosine_warmup.yaml",
"chars": 191,
"preview": "# @package _global_\ntrain:\n interval: step\nscheduler:\n # _target_: transformers.get_cosine_schedule_with_warmup\n _nam"
},
{
"path": "configs/scheduler/cosine_warmup_timm.yaml",
"chars": 234,
"preview": "# @package _global_\ntrain:\n interval: step\nscheduler:\n # _target_: transformers.get_cosine_schedule_with_warmup\n _nam"
},
{
"path": "configs/scheduler/linear_warmup.yaml",
"chars": 191,
"preview": "# @package _global_\ntrain:\n interval: step\nscheduler:\n # _target_: transformers.get_linear_schedule_with_warmup\n _nam"
},
{
"path": "configs/scheduler/multistep.yaml",
"chars": 165,
"preview": "# @package _global_\ntrain:\n interval: epoch\n# _target_: torch.optim.lr_scheduler.MultiStepLR\nscheduler:\n _name_: multi"
},
{
"path": "configs/scheduler/plateau.yaml",
"chars": 346,
"preview": "# @package _global_\ntrain:\n interval: epoch\n monitor: ??? # must be specified\nscheduler:\n # _target_: torch.optim.lr_"
},
{
"path": "configs/scheduler/step.yaml",
"chars": 145,
"preview": "# @package _global_\ntrain:\n interval: step\nscheduler:\n # _target_: torch.optim.lr_scheduler.StepLR\n _name_: step\n st"
},
{
"path": "configs/task/lm.yaml",
"chars": 84,
"preview": "_name_: lm\n# loss: cross_entropy # Handled by task: cross entropy loss\nmetrics: ppl\n"
},
{
"path": "configs/task/multiclass_classification.yaml",
"chars": 115,
"preview": "# _target_: tasks.tasks.MultiClass\n_name_: multiclass\nloss: cross_entropy\nmetrics:\n - accuracy\ntorchmetrics: null\n"
},
{
"path": "configs/task/multilabel_classification.yaml",
"chars": 235,
"preview": "# _target_:\n_name_: base\nloss: binary_cross_entropy\nmetrics: null\ntorchmetrics:\n - MultilabelAUROC # AUROC\n - Multilab"
},
{
"path": "configs/task/regression.yaml",
"chars": 88,
"preview": "# _target_: tasks.tasks.BaseTask\n_name_: base\nloss: mse\nmetrics: mse\ntorchmetrics: null\n"
},
{
"path": "configs/trainer/debug.yaml",
"chars": 328,
"preview": "defaults:\n - default\n\ngpus: 1\nmin_epochs: 1\nmax_epochs: 10\n\n# prints\nprogress_bar_refresh_rate: null\nweights_summary: f"
},
{
"path": "configs/trainer/default.yaml",
"chars": 527,
"preview": "_target_: pytorch_lightning.Trainer\n\ndevices: 1\naccelerator: gpu\naccumulate_grad_batches: 1 # Gradient accumulation ever"
},
{
"path": "configs/trainer/full.yaml",
"chars": 1076,
"preview": "_target_: pytorch_lightning.Trainer\n\n# default values for all trainer parameters\ncheckpoint_callback: True\ndefault_root_"
},
{
"path": "configs/trainer/lm.yaml",
"chars": 643,
"preview": "accumulate_grad_batches: 1\n# accelerator: null # set to 'ddp' for distributed\n# amp_backend: native # 'native' | 'apex'\n"
},
{
"path": "setup_env.sh",
"chars": 489,
"preview": "#!/bin/bash\n\n# Shell script to set environment variables when running code in this repository.\n# Usage:\n# source set"
},
{
"path": "slurm_scripts/dump_vep_embeddings.sh",
"chars": 2653,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/run_genomics_benchmark.sh",
"chars": 2337,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/run_genomics_benchmark_cnn.sh",
"chars": 1949,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 48:00:00 "
},
{
"path": "slurm_scripts/run_nucleotide_transformer.sh",
"chars": 2302,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/run_pretrain_caduceus.sh",
"chars": 2204,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/run_pretrain_hyena.sh",
"chars": 1734,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/run_pretrain_mamba.sh",
"chars": 1833,
"preview": "#!/bin/bash\n#SBATCH --get-user-env # Retrieve the users login environment\n#SBATCH -t 96:00:00 "
},
{
"path": "slurm_scripts/wrapper_run_genomics.sh",
"chars": 3595,
"preview": "#!/bin/bash\n\n# Choose one from below\n\n## Hyena\n## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadn"
},
{
"path": "slurm_scripts/wrapper_run_genomics_cnn.sh",
"chars": 629,
"preview": "#!/bin/bash\n\nLOG_DIR=\"../watch_folder/gb_cv5/cnn_baseline\"\nmkdir -p \"${LOG_DIR}\"\nexport_str=\"ALL\"\nfor TASK in \"dummy_mou"
},
{
"path": "slurm_scripts/wrapper_run_nucleotide_transformer.sh",
"chars": 2642,
"preview": "#!/bin/bash\n\n# Choose one from below\n\n## Caduceus NO POST HOC\n#LOG_DIR=\"../watch_folder/nt_cv10_ep20/caduceus\"\n#CONFIG_P"
},
{
"path": "src/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/callbacks/params.py",
"chars": 1416,
"preview": "\"\"\"Callback to log the number of parameters of the model.\n\n\"\"\"\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.ut"
},
{
"path": "src/callbacks/timer.py",
"chars": 3685,
"preview": "\"\"\"Callback to monitor the speed of each step and each epoch.\n\nhttps://github.com/HazyResearch/transformers/blob/master/"
},
{
"path": "src/callbacks/validation.py",
"chars": 1369,
"preview": "\"\"\"Check validation every n **global** steps.\n\nPytorch Lightning has a `val_check_interval` parameter that checks valida"
},
{
"path": "src/dataloaders/__init__.py",
"chars": 57,
"preview": "from . import genomics\nfrom .base import SequenceDataset\n"
},
{
"path": "src/dataloaders/base.py",
"chars": 7296,
"preview": "\"\"\" Datasets for core experimental results.\n\n\"\"\"\n\nimport os\nfrom functools import partial\nfrom pathlib import Path\n\nimpo"
},
{
"path": "src/dataloaders/datasets/genomic_bench_dataset.py",
"chars": 4566,
"preview": "\"\"\"Genomic Benchmarks Dataset.\n\nFrom: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks\n\"\"\"\n\nfrom pathlib import P"
},
{
"path": "src/dataloaders/datasets/hg38_char_tokenizer.py",
"chars": 5939,
"preview": "\"\"\" \nFrom: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py\n\nCharacterToken"
},
{
"path": "src/dataloaders/datasets/hg38_dataset.py",
"chars": 7514,
"preview": "\"\"\"Dataset for sampling arbitrary intervals from the human genome.\n\n\"\"\"\n\nimport math\nfrom pathlib import Path\n\nimport pa"
},
{
"path": "src/dataloaders/datasets/nucleotide_transformer_dataset.py",
"chars": 3584,
"preview": "\"\"\"Nucleotide Transformer Benchmarks Dataset.\n\nFrom: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_"
},
{
"path": "src/dataloaders/fault_tolerant_sampler.py",
"chars": 4660,
"preview": "# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytor"
},
{
"path": "src/dataloaders/genomics.py",
"chars": 17628,
"preview": "\"\"\"Dataloaders for genomics datasets, including pretraining and downstream tasks.\n\n - Adapted from:\n https://g"
},
{
"path": "src/dataloaders/utils/mlm.py",
"chars": 1913,
"preview": "import torch\n\n\ndef mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None"
},
{
"path": "src/dataloaders/utils/rc.py",
"chars": 661,
"preview": "\"\"\"Utility functions for reverse complementing DNA sequences.\n\n\"\"\"\n\nfrom random import random\n\nSTRING_COMPLEMENT_MAP = {"
},
{
"path": "src/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/models/baseline/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/models/baseline/genomics_benchmark_cnn.py",
"chars": 1966,
"preview": "\"\"\"Genomics Benchmark CNN model.\n\nAdapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/gen"
},
{
"path": "src/models/nn/__init__.py",
"chars": 35,
"preview": "from .activation import Activation\n"
},
{
"path": "src/models/nn/activation.py",
"chars": 2876,
"preview": "\"\"\"Utilities for activation functions.\"\"\"\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as "
},
{
"path": "src/models/nn/adaptive_softmax.py",
"chars": 16332,
"preview": "# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 ("
},
{
"path": "src/models/nn/utils.py",
"chars": 3722,
"preview": "\"\"\" Utility wrappers around modules to let them handle Args and extra arguments \"\"\"\n\nimport inspect\nfrom functools impor"
},
{
"path": "src/models/sequence/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/models/sequence/dna_embedding.py",
"chars": 10395,
"preview": "\"\"\"DNA Embedding Model.\n\nBackbones from LM pre-training models, used for downstream tasks.\n\"\"\"\n\nfrom functools import pa"
},
{
"path": "src/models/sequence/hyena.py",
"chars": 14719,
"preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\ntry:\n f"
},
{
"path": "src/models/sequence/long_conv_lm.py",
"chars": 17796,
"preview": "import copy\nimport math\nimport re\nfrom collections import namedtuple\nfrom functools import partial\n\nimport torch\nimport "
},
{
"path": "src/ops/fftconv.py",
"chars": 4672,
"preview": "import math\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\n\nfrom fftconv import fftconv_fwd"
},
{
"path": "src/tasks/decoders.py",
"chars": 6977,
"preview": "\"\"\"Decoder heads.\n\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport src.models.nn.utils a"
},
{
"path": "src/tasks/encoders.py",
"chars": 2398,
"preview": "from torch import nn\n\nimport src.models.nn.utils as U\nimport src.utils as utils\n\n\nclass Encoder(nn.Module):\n \"\"\"Encod"
},
{
"path": "src/tasks/metrics.py",
"chars": 12663,
"preview": "import math\nfrom functools import partial\n\nimport torch\nimport torch.nn.functional as F\nimport torchmetrics.functional a"
},
{
"path": "src/tasks/tasks.py",
"chars": 15398,
"preview": "import inspect\nfrom typing import List\n\nimport torch.nn as nn\nfrom einops import rearrange\n\nimport src.models.nn.utils a"
},
{
"path": "src/tasks/torchmetrics.py",
"chars": 4532,
"preview": "# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py\n# But we compute th"
},
{
"path": "src/utils/__init__.py",
"chars": 79,
"preview": "from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate\n"
},
{
"path": "src/utils/config.py",
"chars": 3548,
"preview": "\"\"\"Utilities for dealing with collection objects (lists, dicts) and configs.\n\n\"\"\"\n\nimport functools\nfrom typing import S"
},
{
"path": "src/utils/optim/schedulers.py",
"chars": 3534,
"preview": "\"\"\"Custom learning rate schedulers\"\"\"\n\nimport math\nimport warnings\nimport torch\n\nfrom timm.scheduler import CosineLRSche"
},
{
"path": "src/utils/optim_groups.py",
"chars": 7052,
"preview": "\"\"\"Utilities for special optimizer hyperparameters.\n\ngroup_parameters_for_optimizer is a modification of timm's optimize"
},
{
"path": "src/utils/registry.py",
"chars": 2648,
"preview": "\"\"\"Class registry for models, layers, optimizers, and schedulers.\n\n\"\"\"\n\noptimizer = {\n \"adam\": \"torch.optim.Adam\",\n "
},
{
"path": "src/utils/train.py",
"chars": 6135,
"preview": "\"\"\" Utils for the training loop.\n\nCopied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.p"
},
{
"path": "train.py",
"chars": 28722,
"preview": "\"\"\"Main training entry point for pre-training and downstream fine-tuning.\n\n\"\"\"\n\nimport json\nimport os\nimport random\nimpo"
},
{
"path": "vep_embeddings.py",
"chars": 20804,
"preview": "\"\"\"Dump model embeddings for VEP classification task.\n\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\nfrom"
},
{
"path": "vep_svm.ipynb",
"chars": 15382,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"db878bc1\",\n \"metadata\": {},\n \"source\": [\n \"## Imports and"
}
]
About this extraction
This page contains the full source code of the kuleshov-group/caduceus GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 104 files (394.9 KB), approximately 102.2k tokens, and a symbol index with 434 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.