Repository: kuleshov-group/caduceus
Branch: main
Commit: 0060a6d8079b
Files: 104
Total size: 394.9 KB
Directory structure:
gitextract_kw76odef/
├── .gitignore
├── LICENSE
├── README.md
├── caduceus/
│   ├── __init__.py
│   ├── configuration_caduceus.py
│   ├── modeling_caduceus.py
│   ├── modeling_rcps.py
│   ├── tests/
│   │   └── test_rcps.py
│   └── tokenization_caduceus.py
├── caduceus_env.yml
├── configs/
│   ├── callbacks/
│   │   ├── base.yaml
│   │   ├── checkpoint.yaml
│   │   ├── gpu_affinity.yaml
│   │   ├── rich.yaml
│   │   ├── val_every_n_global_steps.yaml
│   │   └── wandb.yaml
│   ├── config.yaml
│   ├── dataset/
│   │   ├── genomic_benchmark.yaml
│   │   ├── hg38.yaml
│   │   └── nucleotide_transformer.yaml
│   ├── experiment/
│   │   └── hg38/
│   │       ├── genomic_benchmark.yaml
│   │       ├── genomic_benchmark_cnn.yaml
│   │       ├── hg38.yaml
│   │       └── nucleotide_transformer.yaml
│   ├── loader/
│   │   └── default.yaml
│   ├── model/
│   │   ├── caduceus.yaml
│   │   ├── genomics_benchmark_cnn.yaml
│   │   ├── hyena.yaml
│   │   ├── layer/
│   │   │   └── hyena.yaml
│   │   └── mamba.yaml
│   ├── optimizer/
│   │   ├── adam.yaml
│   │   ├── adamw.yaml
│   │   └── sgd.yaml
│   ├── pipeline/
│   │   ├── genomic_benchmark.yaml
│   │   ├── hg38.yaml
│   │   └── nucleotide_transformer.yaml
│   ├── scheduler/
│   │   ├── constant.yaml
│   │   ├── constant_warmup.yaml
│   │   ├── cosine.yaml
│   │   ├── cosine_warmup.yaml
│   │   ├── cosine_warmup_timm.yaml
│   │   ├── linear_warmup.yaml
│   │   ├── multistep.yaml
│   │   ├── plateau.yaml
│   │   └── step.yaml
│   ├── task/
│   │   ├── lm.yaml
│   │   ├── multiclass_classification.yaml
│   │   ├── multilabel_classification.yaml
│   │   └── regression.yaml
│   └── trainer/
│       ├── debug.yaml
│       ├── default.yaml
│       ├── full.yaml
│       └── lm.yaml
├── setup_env.sh
├── slurm_scripts/
│   ├── dump_vep_embeddings.sh
│   ├── run_genomics_benchmark.sh
│   ├── run_genomics_benchmark_cnn.sh
│   ├── run_nucleotide_transformer.sh
│   ├── run_pretrain_caduceus.sh
│   ├── run_pretrain_hyena.sh
│   ├── run_pretrain_mamba.sh
│   ├── wrapper_run_genomics.sh
│   ├── wrapper_run_genomics_cnn.sh
│   └── wrapper_run_nucleotide_transformer.sh
├── src/
│   ├── __init__.py
│   ├── callbacks/
│   │   ├── params.py
│   │   ├── timer.py
│   │   └── validation.py
│   ├── dataloaders/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── datasets/
│   │   │   ├── genomic_bench_dataset.py
│   │   │   ├── hg38_char_tokenizer.py
│   │   │   ├── hg38_dataset.py
│   │   │   └── nucleotide_transformer_dataset.py
│   │   ├── fault_tolerant_sampler.py
│   │   ├── genomics.py
│   │   └── utils/
│   │       ├── mlm.py
│   │       └── rc.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── baseline/
│   │   │   ├── __init__.py
│   │   │   └── genomics_benchmark_cnn.py
│   │   ├── nn/
│   │   │   ├── __init__.py
│   │   │   ├── activation.py
│   │   │   ├── adaptive_softmax.py
│   │   │   └── utils.py
│   │   └── sequence/
│   │       ├── __init__.py
│   │       ├── dna_embedding.py
│   │       ├── hyena.py
│   │       └── long_conv_lm.py
│   ├── ops/
│   │   └── fftconv.py
│   ├── tasks/
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── metrics.py
│   │   ├── tasks.py
│   │   └── torchmetrics.py
│   └── utils/
│       ├── __init__.py
│       ├── config.py
│       ├── optim/
│       │   └── schedulers.py
│       ├── optim_groups.py
│       ├── registry.py
│       └── train.py
├── train.py
├── vep_embeddings.py
└── vep_svm.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data.tar.gz
*.tsf
*.ckpt
.ipynb_checkpoints
*/.ipynb_checkpoints/*
*.lprof
.DS_Store
.idea/
outputs/
# slurm log files
watch_folder/
data
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# End of https://www.gitignore.io/api/python
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Caduceus ☤: Bi-Directional Equivariant Long-Range DNA Sequence Modeling
[[Blog]](https://caduceus-dna.github.io/) | [[arXiv]](https://arxiv.org/abs/2403.03234) | [[HuggingFace 🤗]](https://huggingface.co/collections/kuleshov-group/caducues-65dcb89b4f54e416ef61c350)
This repository contains code for reproducing the results in the paper "Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling," [Schiff et al. (2024)](https://arxiv.org/abs/2403.03234).
## Using Caduceus with 🤗
We have uploaded pre-trained Caduceus models to the HuggingFace hub.
The available models are:
- Caduceus-Ph: [kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16)
  - Trained on sequences of length 131k, with a model dimension of 256 and 16 layers.
  - Trained for 50k steps with a batch size of 8.
  - Trained with reverse-complement (RC) data augmentation.
- Caduceus-PS: [kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16)
  - Trained on sequences of length 131k, with a model dimension of 256 and 16 layers.
  - Trained for 50k steps with a batch size of 8.
  - The model is RC equivariant, so no RC data augmentation is required.
You can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model.
To use the pre-trained model for masked language modeling, use the following snippet:
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
# See the `Caduceus` collection page on the hub for list of available models.
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
```
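Continuing from the snippet above, here is a minimal sketch of running the masked-LM head on a short DNA sequence (the sequence, batch size, and device placement are illustrative):
```python
import torch

# Tokenize a short DNA sequence; the tokenizer operates at single-nucleotide resolution.
sequence = "ACGTACGTACGTACGT"
inputs = tokenizer(sequence, return_tensors="pt")

# Forward pass; `logits` has shape (batch_size, tokenized_length, vocab_size).
model.eval()
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"])
print(outputs.logits.shape)
```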
Alternatively, you can instantiate a model from scratch to train on your own data as follows:
```python
from transformers import AutoConfig, AutoModelForMaskedLM
# Add any config overrides here, see the `config.json` file on the hub for details.
config_overrides = {}
# See the `Caduceus` collection page on the hub for list of available models.
config = AutoConfig.from_pretrained(
"kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
**config_overrides,
)
model = AutoModelForMaskedLM.from_config(config)
```
## Getting started in this repository
To get started, create a conda environment containing the required dependencies.
```bash
conda env create -f caduceus_env.yml
```
Activate the environment.
```bash
conda activate caduceus_env
```
Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```
## Reproducing Experiments
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`train.py`](./train.py) script.
We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`slurm_scripts/`](./slurm_scripts) directory.
### Pretraining on Human Reference Genome
(Data downloading instructions are copied from [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome))
First, download the Human Reference Genome data.
It consists of 2 files: one with all the sequences (the `.fasta` file) and one with the intervals we use (the `.bed` file).
The file structure should look like
```
data
|-- hg38/
    |-- hg38.ml.fa
    |-- human-sequences.bed
```
Download the fasta file (`.fa` format) of the entire human genome into `./data/hg38`.
The genome contains roughly 24 chromosomes merged into a single file, and each chromosome is one continuous sequence.
Then download the `.bed` file with the sequence intervals; each row contains a chromosome name, start, end, and split, which allow sequences to be retrieved from the fasta file.
```bash
mkdir -p data/hg38/
curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz
gunzip data/hg38/hg38.ml.fa.gz # unzip the fasta file
curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed
```
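To see how the two files fit together, here is a minimal sketch of pulling one interval from the `.bed` file out of the fasta (assuming `pyfaidx` and `pandas` are available; the column names below are assumptions, and the actual dataloader logic lives in [`src/dataloaders/datasets/hg38_dataset.py`](./src/dataloaders/datasets/hg38_dataset.py)):
```python
import pandas as pd
from pyfaidx import Fasta

# Column names are assumed for illustration; the .bed file stores chromosome, start, end, and split.
intervals = pd.read_csv(
    "data/hg38/human-sequences.bed",
    sep="\t",
    names=["chr_name", "start", "end", "split"],
)
fasta = Fasta("data/hg38/hg38.ml.fa")

row = intervals.iloc[0]
seq = str(fasta[row.chr_name][row.start:row.end])  # nucleotide string of length end - start
print(row.chr_name, len(seq))
```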
Launch a pretraining run from the command line:
```bash
python -m train \
experiment=hg38/hg38 \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
dataset.max_length=1024 \
dataset.batch_size=1024 \
dataset.mlm=true \
dataset.mlm_probability=0.15 \
dataset.rc_aug=false \
model=caduceus \
model.config.d_model=128 \
model.config.n_layer=4 \
model.config.bidirectional=true \
model.config.bidirectional_strategy=add \
model.config.bidirectional_weight_tie=true \
model.config.rcps=true \
optimizer.lr="8e-3" \
train.global_batch_size=1024 \
trainer.max_steps=10000 \
+trainer.val_check_interval=10000 \
wandb=null
```
Alternatively, if using a cluster that has `slurm` installed, adapt the scripts below:
```
slurm_scripts
|-- run_pretrain_caduceus.sh
|-- run_pretrain_hyena.sh
|-- run_pretrain_mamba.sh
```
and run the training as a batch job:
```bash
cd slurm_scripts
sbatch run_pretrain_caduceus.sh
```
### GenomicBenchmarks
The [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8) is comprised of 8 classification tasks.
We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
experiment=hg38/genomic_benchmark \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
dataset.train_val_split_seed=1 \
dataset.batch_size=256 \
dataset.rc_aug=false \
+dataset.conjoin_train=false \
+dataset.conjoin_test=false \
loader.num_workers=2 \
model=caduceus \
model._name_=dna_embedding_caduceus \
+model.config_path="" \
+model.conjoin_test=false \
+decoder.conjoin_train=true \
+decoder.conjoin_test=false \
optimizer.lr="1e-3" \
trainer.max_epochs=10 \
train.pretrained_model_path="" \
wandb=null
```
This sample run will fine-tune a pre-trained Caduceus-PS model on the `dummy_mouse_enhancers_ensembl` task.
Note some of the additional arguments present here, relative to the pre-training command from [above](#pretraining-on-human-reference-genome):
- `model.config_path` contains the path to the model config that was saved during pre-training.
This config is saved to the run directory of the pre-training experiment.
- `train.pretrained_model_path` contains the path to the pre-trained model checkpoint.
- `dataset.conjoin_train` determines whether the dataset will return a single sequence (`dataset.conjoin_train=false`) or the concatenation of a sequence and its reverse complement along `dim=-1`, during downstream fine-tuning training.
- `dataset.conjoin_test` is the same as above, but for inference (e.g., validation / test).
- `decoder.conjoin_train` determines whether the prediction head (a mean pooling and linear projection in the case of the GenomicBenchmarks) expects an input tensor of shape `(batch_size, seq_len, d_model)` or `(batch_size, seq_len, d_model, 2)` during downstream fine-tuning training.
When set to `true`, the decoder is run on `input[..., 0]` and `input[..., 1]` and the results are averaged to produce the final prediction (see the sketch after this list).
- `decoder.conjoin_test` is the same as above, but for inference (e.g., validation / test).
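To make the `decoder.conjoin_*` behavior concrete, below is a minimal sketch of conjoined decoding with a hypothetical mean-pooling head (the class and its names are illustrative, not the repo's actual decoder implementation):
```python
import torch
from torch import nn

class ConjoinedMeanPoolDecoder(nn.Module):
    """Hypothetical prediction head: mean pooling + linear projection, shared across strands."""
    def __init__(self, d_model: int, num_classes: int, conjoin: bool = True):
        super().__init__()
        self.linear = nn.Linear(d_model, num_classes)
        self.conjoin = conjoin

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.conjoin:
            # hidden_states: (batch_size, seq_len, d_model, 2); last dim holds fwd / RC strands
            logits_fwd = self.linear(hidden_states[..., 0].mean(dim=1))
            logits_rc = self.linear(hidden_states[..., 1].mean(dim=1))
            return (logits_fwd + logits_rc) / 2
        # hidden_states: (batch_size, seq_len, d_model)
        return self.linear(hidden_states.mean(dim=1))

decoder = ConjoinedMeanPoolDecoder(d_model=256, num_classes=2)
x = torch.randn(4, 1024, 256, 2)  # conjoined representations for a batch of 4 sequences
print(decoder(x).shape)  # torch.Size([4, 2])
```
Because the same projection weights are applied to both strands and the two predictions are averaged, the downstream prediction is insensitive to which strand of a sequence is presented.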
Note that this benchmark only contains a training and a test split for each task.
Therefore, to enable a more principled evaluation, we randomly split the training data into training and validation sets (90/10) using the `dataset.train_val_split_seed` argument.
We perform early stopping on the validation metric (accuracy) and repeat this for 5 random seeds.
As with [pre-training](#pretraining-on-human-reference-genome), we can also launch the fine-tuning run as a batch job using the provided [`run_genomics_benchmark.sh`](./slurm_scripts/run_genomics_benchmark.sh) script.
We also provide a helper shell script [`wrapper_run_genomics.sh`](./slurm_scripts/wrapper_run_genomics.sh) that can be used to launch multiple fine-tuning runs in parallel.
Finally, the [`run_genomics_benchmark_cnn.sh`](./slurm_scripts/run_genomics_benchmark_cnn.sh) script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks.
### Nucleotide Transformer datasets
The Nucleotide Transformer suite of tasks was proposed in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1).
The data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).
We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
experiment=hg38/nucleotide_transformer \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="${task}" \
dataset.train_val_split_seed=${seed} \
dataset.batch_size=${batch_size} \
dataset.rc_aug="${rc_aug}" \
+dataset.conjoin_test="${CONJOIN_TEST}" \
loader.num_workers=2 \
model._name_=dna_embedding_caduceus \
+model.config_path="" \
+model.conjoin_test=false \
+decoder.conjoin_train=true \
+decoder.conjoin_test=false \
optimizer.lr="1e-3" \
trainer.max_epochs=20 \
train.pretrained_model_path="" \
wandb=null
```
We can also launch as batch jobs (see [`run_nucleotide_transformer.sh`](./slurm_scripts/run_nucleotide_transformer.sh) and [`wrapper_run_nucleotide_transformer.sh`](./slurm_scripts/wrapper_run_nucleotide_transformer.sh) for details).
### eQTL SNP Variant Effect Prediction
This task comes from the recently proposed Long Range Benchmark (LRB) in [Kao et al., 2023](https://llms4science-community.github.io/papers/LLMs4Bio24_paper_12.pdf).
The data is available on HuggingFace: [InstaDeepAI/genomics-long-range-benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark).
For this task, we fit a downstream model to the frozen embeddings of the pre-trained DNA language models.
Therefore, to perform the evaluation, we proceed in 2 steps:
- **Step 1: Extract the embeddings** from the pre-trained model by running the [`vep_embeddings.py`](./vep_embeddings.py) script.
See the example below:
```bash
torchrun \
--standalone \
--nnodes=1 \
--nproc-per-node=8 \
vep_embeddings.py \
--num_workers=2 \
--seq_len=131072 \
--bp_per_token=1 \
--embed_dump_batch_size=1 \
--name="caduceus-ps_downstream-seqlen=131k" \
--model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" \
--rcps
```
The `--rcps` flag indicates that the model is reverse-complement equivariant.
When using other models, disable this flag with `--no-rcps`.
To speed this step up, this script utilizes torch distributed data parallelism.
Please refer to the slurm script provided in [`slurm_scripts/dump_vep_embeddings.sh`](./slurm_scripts/dump_vep_embeddings.sh)
to launch this step as a batch job.
- **Step 2: Fit an SVM model to the embeddings** using this notebook: [`vep_svm.ipynb`](./vep_svm.ipynb).
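For orientation, below is a minimal sketch of what Step 2 looks like with scikit-learn (the file names and array keys are hypothetical; the actual analysis, including any preprocessing choices, is in the notebook):
```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Hypothetical dumps produced by Step 1; arrays "embeddings" and "labels" are assumed.
train = np.load("vep_embeddings_train.npz")
test = np.load("vep_embeddings_test.npz")

clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", probability=True))
clf.fit(train["embeddings"], train["labels"])

scores = clf.predict_proba(test["embeddings"])[:, 1]
print("AUROC:", roc_auc_score(test["labels"], scores))
```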
## Citation
If you find our work useful, please cite our paper using the following:
```
@article{schiff2024caduceus,
title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling},
author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr},
journal={arXiv preprint arXiv:2403.03234},
year={2024}
}
```
## Acknowledgements
This repository is adapted from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna) and leverages much of the training, data loading, and logging infrastructure defined there.
HyenaDNA was originally derived from the [S4](https://github.com/state-spaces/s4) and [Safari](https://github.com/HazyResearch/safari) repositories.
We would like to thank Evan Trop and the [InstaDeep](https://www.instadeep.com/) team for useful discussions about the [Nucleotide Transformer leaderboard](https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark)
and the Long Range Benchmark task.
Finally, we would like to thank [MosaicML](https://www.mosaicml.com/) for providing compute resources for some of the pre-training experiments.
================================================
FILE: caduceus/__init__.py
================================================
"""Hugging Face config, model, and tokenizer for Caduceus.
"""
from .configuration_caduceus import CaduceusConfig
from .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification
from .tokenization_caduceus import CaduceusTokenizer
================================================
FILE: caduceus/configuration_caduceus.py
================================================
"""Caduceus config for Hugging Face.
"""
from typing import Optional, Union
from transformers import PretrainedConfig
class CaduceusConfig(PretrainedConfig):
"""Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
model_type = "caduceus"
def __init__(
self,
# From original MambaConfig
d_model: int = 2560,
n_layer: int = 64,
vocab_size: int = 50277,
ssm_cfg: Optional[dict] = None,
rms_norm: bool = True,
residual_in_fp32: bool = True,
fused_add_norm: bool = True,
pad_vocab_size_multiple: int = 8,
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: float = 1e-5,
# Used in init_weights
initializer_cfg: Optional[dict] = None,
# Caduceus-specific params
bidirectional: bool = True,
bidirectional_strategy: Union[str, None] = "add",
bidirectional_weight_tie: bool = True,
rcps: bool = False,
complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.n_layer = n_layer
self.vocab_size = vocab_size
self.ssm_cfg = ssm_cfg
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.norm_epsilon = norm_epsilon
self.initializer_cfg = initializer_cfg
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.bidirectional_weight_tie = bidirectional_weight_tie
self.rcps = rcps
self.complement_map = complement_map
================================================
FILE: caduceus/modeling_caduceus.py
================================================
"""Caduceus model for Hugging Face.
"""
import inspect
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
from mamba_ssm.modules.mamba_simple import Mamba
try:
from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
except ImportError:
from mamba_ssm.modules.block import Block # mambav2 file structure
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
except ImportError:
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from .configuration_caduceus import CaduceusConfig
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
def create_block(
d_model,
ssm_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
bidirectional=True,
bidirectional_strategy="add",
bidirectional_weight_tie=True,
rcps=False,
device=None,
dtype=None,
):
"""Create Caduceus block.
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
"""
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
bidirectional_kwargs = {
"bidirectional": bidirectional,
"bidirectional_strategy": bidirectional_strategy,
"bidirectional_weight_tie": bidirectional_weight_tie,
}
mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
block_cls = RCPSMambaBlock if rcps else Block
# mambav2 compatibility
if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
block = block_cls(
d_model,
mixer_cls,
mlp_cls=nn.Identity,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
else:
block = block_cls(
d_model,
mixer_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
class BiMambaWrapper(nn.Module):
"""Thin wrapper around Mamba to support bi-directionality."""
def __init__(
self,
d_model: int,
bidirectional: bool = True,
bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
**mamba_kwargs,
):
super().__init__()
if bidirectional and bidirectional_strategy is None:
bidirectional_strategy = "add" # Default strategy: `add`
if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.mamba_fwd = Mamba(
d_model=d_model,
**mamba_kwargs
)
if bidirectional:
self.mamba_rev = Mamba(
d_model=d_model,
**mamba_kwargs
)
if bidirectional_weight_tie: # Tie in and out projections (where most of param count lies)
self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
else:
self.mamba_rev = None
def forward(self, hidden_states, inference_params=None):
"""Bidirectional-enabled forward pass
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
out = self.mamba_fwd(hidden_states, inference_params=inference_params)
if self.bidirectional:
out_rev = self.mamba_rev(
hidden_states.flip(dims=(1,)), # Flip along the sequence length dimension
inference_params=inference_params
).flip(dims=(1,)) # Flip back for combining with forward hidden states
if self.bidirectional_strategy == "add":
out = out + out_rev
elif self.bidirectional_strategy == "ew_multiply":
out = out * out_rev
else:
raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
return out
class CaduceusEmbeddings(nn.Module):
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
if config.rcps:
self.word_embeddings = RCPSEmbedding(
config.vocab_size, config.d_model, config.complement_map, **factory_kwargs
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
def forward(self, input_ids):
"""
input_ids: (batch, seqlen)
"""
return self.word_embeddings(input_ids)
class CaduceusMixerModel(nn.Module):
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.fused_add_norm = config.fused_add_norm
self.rcps = config.rcps
self.residual_in_fp32 = config.residual_in_fp32
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
# Mamba changes the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
if config.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
self.layers = nn.ModuleList(
[
create_block(
config.d_model,
ssm_cfg=config.ssm_cfg,
norm_epsilon=config.norm_epsilon,
rms_norm=config.rms_norm,
residual_in_fp32=config.residual_in_fp32,
fused_add_norm=config.fused_add_norm,
layer_idx=i,
bidirectional=config.bidirectional,
bidirectional_strategy=config.bidirectional_strategy,
bidirectional_weight_tie=config.bidirectional_weight_tie,
rcps=config.rcps,
**factory_kwargs,
)
for i in range(config.n_layer)
]
)
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
config.d_model, eps=config.norm_epsilon, **factory_kwargs
)
self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)
def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
"""Mixer forward."""
all_hidden_states = []
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids)
residual = None
for layer in self.layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
# TODO: Add support for gradient checkpointing
hidden_states, residual = layer(
hidden_states, residual, inference_params=None
)
if not self.fused_add_norm:
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
else:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states_fwd = fused_add_norm_fn(
hidden_states[..., :hidden_states.shape[-1] // 2],
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., :hidden_states.shape[-1] // 2],
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states_rc = fused_add_norm_fn(
hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
else:
# Set prenorm=False here since we don't need the residual
hidden_states = fused_add_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
def cross_entropy(logits, y, ignore_index=-100):
"""Cross entropy loss."""
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
return F.cross_entropy(logits, y, ignore_index=ignore_index)
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
"""Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
loss_weights = loss_weights.view(-1)
loss_weights[y == ignore_index] = 0.0
# TODO: Follows GPN implementation, but should we remove weight normalization?
return (ce * (loss_weights / loss_weights.sum())).sum()
class CaduceusPreTrainedModel(PreTrainedModel):
"""PreTrainedModel wrapper for Caduceus backbone."""
config_class = CaduceusConfig
base_model_prefix = "caduceus"
supports_gradient_checkpointing = False
_no_split_modules = ["BiMambaWrapper"]
def _init_weights(
self,
module,
initializer_range=0.02, # Now only used for embedding layer.
**kwargs,
):
"""Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
n_layer = self.config.n_layer
initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
initializer_range = initialized_cfg.get("initializer_range", initializer_range)
n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth.
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
# residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
class Caduceus(CaduceusPreTrainedModel):
"""Caduceus model that can be instantiated using HF patterns."""
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config)
if config.rcps:
assert config.complement_map is not None, "Complement map must be provided for RCPS."
# Adjust vocab size and complement maps if vocab padding is set.
if config.vocab_size % config.pad_vocab_size_multiple != 0:
config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
if config.complement_map is not None and config.vocab_size > len(config.complement_map):
for i in range(len(config.complement_map), config.vocab_size):
config.complement_map[i] = i
self.config = config
factory_kwargs = {"device": device, "dtype": dtype}
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states, all_hidden_states = self.backbone(
input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states
)
if return_dict:
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states if output_hidden_states else None
)
elif output_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
class CaduceusForMaskedLM(CaduceusPreTrainedModel):
"""HF-compatible Caduceus model for masked language modeling."""
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config, **kwargs)
factory_kwargs = {"device": device, "dtype": dtype}
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
if config.rcps:
self.lm_head = RCPSLMHead(
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
true_dim=config.d_model,
dtype=dtype
)
else:
self.lm_head = nn.Linear(
config.d_model,
self.config.vocab_size, # Use caduceus config as it might have been updated
bias=False,
**factory_kwargs
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
self.caduceus.backbone.embeddings.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Overrides output embeddings."""
if self.config.rcps:
raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
self.lm_head = new_embeddings
def tie_weights(self):
"""Tie weights, accounting for RCPS."""
if self.config.rcps:
self.lm_head.set_weight(self.get_input_embeddings().weight)
else:
super().tie_weights()
def get_decoder(self):
"""Get decoder (backbone) for the model."""
return self.caduceus
def set_decoder(self, decoder):
"""Set decoder (backbone) for the model."""
self.caduceus = decoder
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
loss_weights: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.caduceus(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
if loss_weights is not None:
loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
else:
loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
def __init__(
self,
config: CaduceusConfig,
pooling_strategy: str = "mean",
conjoin_train: bool = False,
conjoin_eval: bool = False,
device=None,
dtype=None,
**kwargs):
super().__init__(config, **kwargs)
if pooling_strategy not in ["mean", "max", "first", "last"]:
raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
self.pooling_strategy = pooling_strategy
factory_kwargs = {"device": device, "dtype": dtype}
self.num_labels = kwargs.get("num_labels", config.num_labels)
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
self.conjoin_train = conjoin_train
self.conjoin_eval = conjoin_eval
# Initialize weights and apply final processing
self.post_init()
self.init_scorer()
def init_scorer(self, initializer_range=0.02):
initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
if self.config.initializer_cfg is not None else initializer_range
self.score.weight.data.normal_(std=initializer_range)
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
self.caduceus.backbone.embeddings.word_embeddings = value
def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
"""Pools hidden states along sequence length dimension."""
if self.pooling_strategy == "mean": # Mean pooling along sequence length dimension
return hidden_states.mean(dim=sequence_length_dim)
if self.pooling_strategy == "max": # Max pooling along sequence length dimension
return hidden_states.max(dim=sequence_length_dim).values
if self.pooling_strategy == "last": # Use embedding of last token in the sequence
return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...]
if self.pooling_strategy == "first": # Use embedding of first token in the sequence
return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get hidden representations from the backbone
if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
transformer_outputs = self.caduceus(
input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = torch.stack(
[
transformer_outputs[0][..., :self.config.d_model],
torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
],
dim=-1
)
elif self.conjoin_train or (self.conjoin_eval and not self.training): # For conjoining / post-hoc conjoining
assert input_ids is not None, "`input_ids` must be provided for conjoining."
assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
transformer_outputs = self.caduceus(
input_ids[..., 0],
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_outputs_rc = self.caduceus(
input_ids[..., 1],
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Stack along channel dimension (dim=-1)
hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
else:
transformer_outputs = self.caduceus(
input_ids,
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Pool and get logits
pooled_hidden_states = self.pool_hidden_states(hidden_states)
# Potentially run `score` twice (with parameters shared) for conjoining
if hidden_states.ndim == 4: # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
logits_fwd = self.score(pooled_hidden_states[..., 0])
logits_rc = self.score(pooled_hidden_states[..., 1])
logits = (logits_fwd + logits_rc) / 2
else:
logits = self.score(pooled_hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
if self.num_labels == 1:
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
else:
loss = F.mse_loss(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(logits, labels)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
)
================================================
FILE: caduceus/modeling_rcps.py
================================================
"""Reverse-complement equivariant modules.
"""
from collections import OrderedDict
from typing import Optional
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
except ImportError:
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
class RCPSEmbedding(nn.Module):
"""Embedding layer that supports reverse-complement equivariance."""
def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
"""
Args:
vocab_size: Size of vocabulary.
d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).
complement_map: Dictionary mapping each token id to its complement.
"""
super().__init__()
self.register_buffer(
"complement_map",
torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
)
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
@property
def weight(self):
"""Embedding weights."""
return self.embedding.weight
def set_weight(self, value):
"""Set embedding weights."""
self.embedding.weight = value
def rc(self, x):
"""Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
return torch.gather(
self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
dim=1,
index=torch.flip(x, dims=[-1])
)
def forward(self, input_ids):
"""Reverse-complement equivariant forward pass.
This embedding module doubles the output dimensionality to support reverse-complement equivariance.
Args:
input_ids: Input tensor of shape (batch_size, seq_len)
Returns:
Embedding tensor of shape (batch_size, seq_len, d_model * 2)
"""
fwd_out = self.embedding(input_ids)
rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])
return torch.cat([fwd_out, rc_out], dim=-1)
class RCPSWrapper(nn.Module):
"""Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.
See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
"""
def __init__(self, submodule: nn.Module):
super().__init__()
self.submodule = submodule
@staticmethod
def rc(x):
"""Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
return torch.flip(x, dims=[-2, -1])
def forward(self, x, **kwargs):
"""Reverse-complement equivariant forward pass.
Args:
x: Input tensor of shape (batch_size, seq_len, channels)
Returns:
Output tensor of shape (batch_size, seq_len, channels)
"""
n_channels = x.shape[-1]
# Run submodule along sequence
fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)
# Run submodule along rc-sequence
rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)
# Concatenate along channel dimension (dim=-1)
return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)
class RCPSAddNormWrapper(RCPSWrapper):
"""RC equivariant AddNorm layer."""
def __init__(self, submodule: nn.Module):
super().__init__(submodule)
def forward(self, x, residual=None, prenorm=False):
"""
Args:
x: Input tensor of shape (batch_size, seq_len, channels)
residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
prenorm: Whether to return residual.
"""
n_channels = x.shape[-1]
if residual is None:
residual = x
x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))
x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))
x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
else:
residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]
x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))
residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])
x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))
residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
return x if not prenorm else (x, residual)
class RCPSMambaBlock(nn.Module):
def __init__(
self,
dim,
mixer_cls,
norm_cls=nn.LayerNorm,
fused_add_norm=False,
residual_in_fp32=False,
device=None, # Keep for consistency with original Mamba Block
dtype=None, # Keep for consistency with original Mamba Block
):
"""RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = RCPSWrapper(mixer_cls(dim))
norm_f = norm_cls(dim)
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual)).
inference_params: inference parameters for mixer.
"""
if not self.fused_add_norm:
hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states_fwd, residual_fwd = fused_add_norm_fn(
hidden_states[..., hidden_states.shape[-1] // 2:],
self.norm.weight,
self.norm.bias,
residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states_rc, residual_rc = fused_add_norm_fn(
hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
self.norm.weight,
self.norm.bias,
residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
return hidden_states, residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
"""Allocate inference cache for mixer.
Keep for compatibility with original Mamba Block.
"""
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
class RCPSLMHead(nn.Module):
"""LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
"""
`true_dim` is the dimensionality the input would have if it were not reverse-complement
equivariant, i.e. half of the actual input dim.
"""
super().__init__()
self.register_buffer(
"complement_map",
torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
)
self.true_dim = true_dim
self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)
@property
def weight(self):
"""LM head weights."""
return self.lm_head.weight
def set_weight(self, value):
"""Set LM head weights."""
self.lm_head.weight = value
def forward(self, x):
"""
Args:
x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
"""
n_channels = x.shape[-1]
assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
rc_logits = F.linear(
torch.flip(x[..., n_channels // 2:], dims=[-1]),
self.weight[self.complement_map, :],
bias=self.lm_head.bias
)
return fwd_logits + rc_logits
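# Shape sketch (illustration only; an identity complement map and toy sizes are
# assumed here just to show the expected tensor shapes):
#
#   head = RCPSLMHead(true_dim=128, vocab_size=16, complement_map={i: i for i in range(16)})
#   x = torch.randn(2, 100, 256)            # channels = 2 * true_dim
#   logits = head(x)                        # (2, 100, 16): forward and RC logits are summed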
================================================
FILE: caduceus/tests/test_rcps.py
================================================
"""Tests for RCPS modules.
"""
import pytest
import torch
from torch import nn
from torch.nn import functional as F
try: # Legacy mambav1 file structure
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
try: # mambav2 file structure
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from caduceus.modeling_rcps import (
RCPSEmbedding, RCPSAddNormWrapper, RCPSLMHead, RCPSWrapper
)
from caduceus.modeling_caduceus import (
CaduceusConfig, CaduceusMixerModel, CaduceusForMaskedLM, create_block
)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [512])
@pytest.mark.parametrize("d_model", [256])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_rcps_embedding(batch_size, seq_len, d_model, dtype):
# Set tolerance
device = torch.device("cpu")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Define complement map
str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
complement_map = {
str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
for k, v in str_to_id.items()
}
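# With the ids above, this resolves to {0: 0, 1: 1, 2: 5, 3: 4, 4: 3, 5: 2, 6: 6},
# i.e. A<->T and C<->G are swapped while special tokens and N map to themselves.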
vocab_size = 12
pad_vocab_size_multiple = 8
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
if vocab_size > len(complement_map):
for i in range(len(complement_map), vocab_size):
complement_map[i] = i
# Generate random sequences
input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)
# Test RC equivariance of embedding layer
factory_kwargs = {"device": device, "dtype": dtype}
embedding = RCPSEmbedding(
vocab_size=vocab_size,
d_model=d_model,
complement_map=complement_map,
**factory_kwargs
).to(device)
out_embed = embedding(input_ids)
rc_out_embed = torch.flip(embedding(rc_input_ids), dims=[-2, -1])
# Test that channels are 2 * d_model
assert tuple(out_embed.size()) == (batch_size, seq_len, d_model * 2)
assert tuple(rc_out_embed.size()) == (batch_size, seq_len, d_model * 2)
# Test that RC equivariance holds
assert torch.allclose(out_embed.detach(), rc_out_embed.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_wrapper(batch_size, seq_len, d_model, dtype):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Generate random sequence with 2 * d_model channels
x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
rc_x = torch.flip(x, dims=[-2, -1])
factory_kwargs = {"device": device, "dtype": dtype}
module = nn.Sequential(
nn.Linear(d_model, d_model, bias=False, **factory_kwargs),
nn.ReLU(),
nn.Linear(d_model, d_model*2, bias=True, **factory_kwargs),
nn.ReLU(),
nn.Linear(d_model * 2, d_model, bias=True, **factory_kwargs)
)
# Test RC equivariance of wrapper
rcps_module = RCPSWrapper(module).to(device)
out = rcps_module(x)
rc_out = torch.flip(rcps_module(rc_x), dims=[-2, -1])
assert out.size() == x.size()
assert rc_out.size() == x.size()
assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("prenorm", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dtype):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Generate random sequence with 2 * d_model channels
x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
rc_x = torch.flip(x, dims=[-2, -1])
factory_kwargs = {"device": device, "dtype": dtype}
norm = RMSNorm(d_model, eps=1e-5, **factory_kwargs)
# Test RC equivariance of wrapper
rcps_module = RCPSAddNormWrapper(norm).to(device)
out = rcps_module(x, prenorm=prenorm)
if prenorm: # returns tuple
rc_out = tuple([torch.flip(r, dims=[-2, -1])
for r in rcps_module(rc_x, prenorm=prenorm)])
for f, r in zip(out, rc_out):
assert f.size() == x.size()
assert r.size() == x.size()
assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
else:
rc_out = torch.flip(rcps_module(rc_x, prenorm=prenorm), dims=[-2, -1])
assert out.size() == x.size()
assert rc_out.size() == x.size()
assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirectional, fused_add_norm, dtype):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Generate random sequence with 2 * d_model channels
x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
rc_x = torch.flip(x, dims=[-2, -1])
ssm_cfg = {
"d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
"dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
}
factory_kwargs = {"device": device, "dtype": dtype}
mamba_block = create_block(
d_model,
ssm_cfg=ssm_cfg,
norm_epsilon=1e-5,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=fused_add_norm,
layer_idx=0,
bidirectional=bidirectional,
bidirectional_strategy="add",
bidirectional_weight_tie=True,
rcps=True,
**factory_kwargs
)
# Test RC equivariance of wrapper
out = mamba_block(x, residual=None)
rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=None)])
for f, r in zip(out, rc_out):
assert f.size() == x.size()
assert r.size() == x.size()
assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
out = mamba_block(x, residual=x)
rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=rc_x)])
for f, r in zip(out, rc_out):
assert f.size() == x.size()
assert r.size() == x.size()
assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1, 1024, 2048])
@pytest.mark.parametrize("d_model", [2, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_lm_head(batch_size, seq_len, d_model, dtype):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Define complement map
str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
complement_map = {
str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
for k, v in str_to_id.items()
}
factory_kwargs = {"device": device, "dtype": dtype}
vocab_size = 12
if vocab_size > len(complement_map):
for i in range(len(complement_map), vocab_size):
complement_map[i] = i
# Instantiate LM head
lm_head = RCPSLMHead(
complement_map=complement_map,
vocab_size=vocab_size,
true_dim=d_model,
**factory_kwargs
)
# Generate random sequence with 2 * d_model channels
x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
rc_x = torch.flip(x, dims=[-2, -1])
# Test RC equivariance of LM head
out = lm_head(x)
rc_out = lm_head(rc_x)
assert tuple(out.size()) == (batch_size, seq_len, vocab_size)
assert tuple(rc_out.size()) == (batch_size, seq_len, vocab_size)
assert torch.allclose(
out.detach(),
torch.flip(rc_out.detach()[..., lm_head.complement_map], dims=[1]),
rtol=rtol,
atol=atol
)
assert torch.allclose(
F.softmax(out, dim=-1).detach(),
torch.flip(F.softmax(rc_out, dim=-1).detach()[..., lm_head.complement_map], dims=[1]),
rtol=rtol,
atol=atol
)
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 2, 3])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fused_add_norm,
bidirectional, bidirectional_weight_tie):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Define complement map
str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
complement_map = {
str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
for k, v in str_to_id.items()
}
# Setup CaduceusConfig
initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
ssm_cfg = {
"d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
"dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
}
config = CaduceusConfig(
d_model=d_model,
n_layer=n_layer,
vocab_size=12,
ssm_cfg=ssm_cfg,
rms_norm=True,
residual_in_fp32=False,
fused_add_norm=fused_add_norm,
pad_vocab_size_multiple=8,
norm_epsilon=1e-5,
initializer_cfg=initializer_cfg,
bidirectional=bidirectional,
bidirectional_strategy="add",
bidirectional_weight_tie=bidirectional_weight_tie,
rcps=True,
complement_map=complement_map,
)
factory_kwargs = {"device": device, "dtype": dtype}
# Instantiate model
backbone = CaduceusMixerModel(
config,
**factory_kwargs,
).to(device)
# Generate random sequences
input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)
# Test RC equivariance of rc backbone
out = backbone(input_ids)[0]
rc_out = backbone(rc_input_ids)[0]
if isinstance(rc_out, tuple):
rc_out = tuple([torch.flip(r, dims=[1, 2]) for r in rc_out])
for f, r in zip(out, rc_out):
assert f.size() == (batch_size, seq_len, d_model * 2)
assert r.size() == (batch_size, seq_len, d_model * 2)
assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
else:
# Hidden state size should double
assert tuple(out.size()) == (batch_size, seq_len, d_model * 2)
assert tuple(rc_out.size()) == (batch_size, seq_len, d_model * 2)
assert torch.allclose(out.detach(), torch.flip(rc_out.detach(), dims=[1, 2]), rtol=rtol, atol=atol)
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 3, 4])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Define complement map
str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
complement_map = {
str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
for k, v in str_to_id.items()
}
# Setup CaduceusConfig
initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
ssm_cfg = {
"d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
"dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
}
config = CaduceusConfig(
d_model=d_model,
n_layer=n_layer,
vocab_size=12,
ssm_cfg=ssm_cfg,
rms_norm=True,
residual_in_fp32=False,
fused_add_norm=True,
pad_vocab_size_multiple=8,
norm_epsilon=1e-5,
initializer_cfg=initializer_cfg,
bidirectional=bidirectional,
bidirectional_strategy="add",
bidirectional_weight_tie=bidirectional_weight_tie,
rcps=True,
complement_map=complement_map,
)
factory_kwargs = {"device": device, "dtype": dtype}
# Instantiate model
mamba_lm = CaduceusForMaskedLM(
config=config,
**factory_kwargs,
).to(device)
# Generate random sequences
input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)
# Test RC equivariance of rc backbone
out = mamba_lm(input_ids)
rc_out = mamba_lm(rc_input_ids)
if config.vocab_size % config.pad_vocab_size_multiple != 0:
config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
assert tuple(out.logits.size()) == (batch_size, seq_len, config.vocab_size)
assert tuple(rc_out.logits.size()) == (batch_size, seq_len, config.vocab_size)
assert torch.allclose(
out.logits.detach(),
torch.flip(rc_out.logits.detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
rtol=rtol,
atol=atol
)
assert torch.allclose(
F.softmax(out.logits, dim=-1).detach(),
torch.flip(F.softmax(rc_out.logits, dim=-1).detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
rtol=rtol,
atol=atol
)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("n_layer", [2])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("bidirectional_weight_tie", [True])
def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
# Set tolerance
device = torch.device("cuda")
rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
# Set seed
torch.random.manual_seed(0)
# Define complement map
str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
complement_map = {
str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
for k, v in str_to_id.items()
}
# Setup CaduceusConfig
initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
ssm_cfg = {
"d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
"dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
}
config = CaduceusConfig(
d_model=d_model,
n_layer=n_layer,
vocab_size=12,
ssm_cfg=ssm_cfg,
rms_norm=True,
residual_in_fp32=False,
fused_add_norm=True,
pad_vocab_size_multiple=8,
norm_epsilon=1e-5,
initializer_cfg=initializer_cfg,
bidirectional=bidirectional,
bidirectional_strategy="add",
bidirectional_weight_tie=bidirectional_weight_tie,
rcps=True,
complement_map=complement_map,
)
factory_kwargs = {"device": device, "dtype": dtype}
# Instantiate model
backbone = CaduceusMixerModel(
config,
**factory_kwargs,
).to(device)
# Generate random sequences
input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)
# Test RC Invariance when collapsing output of backbone
out = backbone(input_ids)[0]
out_collapse = (out[..., :d_model] + torch.flip(out[..., d_model:], dims=[1, 2])) / 2
rc_out = backbone(rc_input_ids)[0]
rc_out_collapse = (rc_out[..., :d_model] + torch.flip(rc_out[..., d_model:], dims=[1, 2])) / 2
# Hidden state size should be d_model
assert tuple(out_collapse.size()) == (batch_size, seq_len, d_model)
assert tuple(rc_out_collapse.size()) == (batch_size, seq_len, d_model)
assert torch.allclose(out_collapse.detach(), rc_out_collapse.detach(), rtol=rtol, atol=atol)
================================================
FILE: caduceus/tokenization_caduceus.py
================================================
"""Character tokenizer for Hugging Face.
"""
from typing import List, Optional, Dict, Sequence, Tuple
from transformers import PreTrainedTokenizer
class CaduceusTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids"]
def __init__(self,
model_max_length: int,
characters: Sequence[str] = ("A", "C", "G", "T", "N"),
complement_map=None,
bos_token="[BOS]",
eos_token="[SEP]",
sep_token="[SEP]",
cls_token="[CLS]",
pad_token="[PAD]",
mask_token="[MASK]",
unk_token="[UNK]",
**kwargs):
"""Character tokenizer for Hugging Face transformers.
Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py
Args:
model_max_length (int): Model maximum sequence length.
characters (Sequence[str]): List of desired characters. Any character which
is not included in this list will be replaced by a special token called
[UNK] with id=6. Following is a list of the special tokens with
their corresponding ids:
"[CLS]": 0
"[SEP]": 1
"[BOS]": 2
"[MASK]": 3
"[PAD]": 4
"[RESERVED]": 5
"[UNK]": 6
an id (starting at 7) will be assigned to each character.
complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character.
"""
if complement_map is None:
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
self.characters = characters
self.model_max_length = model_max_length
self._vocab_str_to_int = {
"[CLS]": 0,
"[SEP]": 1,
"[BOS]": 2,
"[MASK]": 3,
"[PAD]": 4,
"[RESERVED]": 5,
"[UNK]": 6,
**{ch: i + 7 for i, ch in enumerate(self.characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
add_prefix_space = kwargs.pop("add_prefix_space", False)
padding_side = kwargs.pop("padding_side", "left")
self._complement_map = {}
for k, v in self._vocab_str_to_int.items():
complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
self._complement_map[self._vocab_str_to_int[k]] = complement_id
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
add_prefix_space=add_prefix_space,
model_max_length=model_max_length,
padding_side=padding_side,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
@property
def complement_map(self) -> Dict[int, int]:
return self._complement_map
def _tokenize(self, text: str, **kwargs) -> List[str]:
return list(text.upper()) # Convert all base pairs to uppercase
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return "".join(tokens) # Note: this operation has lost info about which base pairs were originally lowercase
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
result = ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
result += ([0] * len(token_ids_1)) + [1]
return result
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
# cls = [self.cls_token_id]
result = token_ids_0 + sep
if token_ids_1 is not None:
result += token_ids_1 + sep
return result
def get_vocab(self) -> Dict[str, int]:
return self._vocab_str_to_int
# Fixed vocabulary with no vocab file
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
return ()
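# Minimal usage sketch (illustration only, not shipped with this file):
#
#   tok = CaduceusTokenizer(model_max_length=1024)
#   tok("acgtn")["input_ids"]    # -> [7, 8, 9, 10, 11, 1]; lowercase is upper-cased and [SEP] (id 1) is appended
#   tok.complement_map[7]        # -> 10, i.e. A maps to T at the id level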
================================================
FILE: caduceus_env.yml
================================================
name: caduceus_env
channels:
- pytorch
- anaconda
- nvidia
- defaults
dependencies:
- cuda-nvcc=11.7.99
- pip=23.3.1
- python=3.8
- pytorch=2.2.0
- torchaudio=2.2.0
- torchdata=0.7.1
- torchmetrics=1.2.1
- torchtext=0.17.0
- torchvision=0.17.0
- pytorch-cuda=12.1
- pip:
- biopython==1.81
- datasets==2.15.0
- einops==0.7.0
- enformer-pytorch==0.8.8
- fsspec==2023.10.0
- genomic-benchmarks==0.0.9
- git-lfs==1.6
- h5py==3.10.0
- huggingface-hub==0.24.7
- hydra-core==1.3.2
- ipdb==0.13.13
- matplotlib==3.7.4
- notebook==7.1.1
- nvitop==1.3.2
- omegaconf==2.3.0
- pandas==2.0.3
- pyfaidx==0.8.1.1
- pysam==0.22.0
- pytest==8.0.2
- pytorch-lightning==1.8.6
- rich==13.7.0
- seaborn==0.13.2
- scikit-learn==1.3.2
- timm==0.9.16
- tqdm==4.66.1
- transformers==4.38.1
- triton==2.2.0
- wandb==0.13.5
- flash-attn==2.5.6
- causal-conv1d==1.2.0.post2
- mamba-ssm==1.2.0.post1
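# To build and activate this environment (standard conda workflow):
#   conda env create -f caduceus_env.yml
#   conda activate caduceus_env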
================================================
FILE: configs/callbacks/base.yaml
================================================
learning_rate_monitor:
# _target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: ${train.interval}
timer:
# _target_: callbacks.timer.Timer
step: True
inter_step: False
epoch: True
val: True
params:
# _target_: callbacks.params.ParamsLog
total: True
trainable: True
fixed: True
================================================
FILE: configs/callbacks/checkpoint.yaml
================================================
model_checkpoint:
monitor: ${train.monitor} # name of the logged metric which determines when model is improving
mode: ${train.mode} # can be "max" or "min"
save_top_k: 1 # save k best models (determined by above metric)
save_last: False # True = additionally always save model from last epoch
dirpath: "checkpoints/"
filename: ${train.monitor}
auto_insert_metric_name: False
verbose: True
model_checkpoint_every_n_steps:
monitor: train/loss # name of the logged metric which determines when model is improving
mode: min # can be "max" or "min"
save_top_k: 0 # Do not save any "best" models; this callback is being used to save every n train steps
save_last: True # additionally always save model from last epoch
dirpath: "checkpoints/"
filename: train/loss
auto_insert_metric_name: False
verbose: True
every_n_train_steps: 100
#model_checkpoint_every_epoch:
# monitor: trainer/epoch # name of the logged metric which determines when model is improving
# mode: max # can be "max" or "min"
# save_top_k: 1 # Do not save any "best" models; this callback is being used to save every n train steps
# save_last: False # additionally always save model from last epoch
# dirpath: "checkpoints/"
# filename: null
# auto_insert_metric_name: False
# verbose: True
# every_n_epochs: 1
================================================
FILE: configs/callbacks/gpu_affinity.yaml
================================================
gpu_affinity:
_name_: gpu_affinity
================================================
FILE: configs/callbacks/rich.yaml
================================================
rich_model_summary:
max_depth: 2
rich_progress_bar:
refresh_rate_per_second: 1.0
================================================
FILE: configs/callbacks/val_every_n_global_steps.yaml
================================================
val_every_n_global_steps:
every_n: 10000
================================================
FILE: configs/callbacks/wandb.yaml
================================================
defaults:
- default
watch_model:
_target_: src.callbacks.wandb_callbacks.WatchModel
log: "all"
log_freq: 100
upload_code_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
code_dir: ${work_dir}/src
upload_ckpts_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
ckpt_dir: "checkpoints/"
upload_best_only: True
log_f1_precision_recall_heatmap:
_target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
log_confusion_matrix:
_target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
log_image_predictions:
_target_: src.callbacks.wandb_callbacks.LogImagePredictions
num_samples: 8
================================================
FILE: configs/config.yaml
================================================
# @package _global_
defaults:
- _self_
- experiment: ???
# - model: ??? # Model backbone
# - pipeline: ??? # Specifies collection of configs, equivalent to next 5 lines
# Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)
# # - loader: default # Dataloader (e.g. handles batches)
# # - dataset: cifar # Defines the data (x and y pairs)
# # - task: multiclass_classification # Defines loss and metrics
# # - encoder: null # Interface between data and model
# # - decoder: null # Interface between model and targets
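# Example invocation (hypothetical overrides; any group above can be swapped on the CLI):
#   python train.py experiment=hg38/hg38 model=caduceus wandb=null train.seed=1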
# Additional arguments used to configure the training loop
# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer
train:
seed: 0
# These three options are used by callbacks (checkpoint, monitor) and scheduler
# Most of them are task dependent and are set by the pipeline
interval: ??? # Should be specified by scheduler. Also used by LR monitor
monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
ema: 0.0 # Moving average model for validation
test: True # Test after training
debug: False # Special settings to make debugging more convenient
ignore_warnings: False # Disable python warnings
optimizer_param_grouping:
bias_weight_decay: False
normalization_weight_decay: False
# These control state passing between batches
state:
mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context
n_context_eval: ${.n_context} # Context at evaluation time
# Convenience keys to allow grouping runs
ckpt: checkpoints/last.ckpt # Resume training
disable_dataset: False # Disable dataset loading
validate_at_start: false
pretrained_model_path: null # Path to pretrained model
pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible
pretrained_model_state_hook: # Hook called on the loaded model's state_dict
_name_: null
post_init_hook: # After initializing model, call method on model
_name_: null
layer_decay: # Used for ImageNet finetuning
_name_: null
decay: 0.7
# We primarily use wandb so this is moved to top level in the config for convenience
# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging
# If other loggers are added, it would make sense to put this one level lower under train/ or logger/
wandb:
project: dna
group: ""
job_type: training
mode: online # choices=['online', 'offline', 'disabled']
name: null
save_dir: "."
id: ${.name} # pass correct id to resume experiment!
# Below options should not need to be specified
# entity: "" # set to name of your wandb team or just remove it
# log_model: False
# prefix: ""
# job_type: "train"
# tags: []
hydra:
run:
dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}
job:
chdir: true
================================================
FILE: configs/dataset/genomic_benchmark.yaml
================================================
_name_: genomic_benchmark
train_val_split_seed: ${train.seed} # Used for train/validation splitting
dataset_name: dummy_mouse_enhancers_ensembl
dest_path: null
max_length: ${.${.dataset_name}.max_length}
max_length_val: ${.max_length}
max_length_test: ${.max_length}
d_output: ${.${.dataset_name}.classes}
use_padding: True
padding_side: 'left'
add_eos: False
batch_size: 128
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true # set this as default!
# these are used to find the right attributes automatically for each dataset
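# e.g. with dataset_name=human_enhancers_cohn the interpolations above resolve to
# max_length=500, d_output=2, and train_len=27791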
dummy_mouse_enhancers_ensembl:
train_len: 1210
classes: 2
max_length: 1024
demo_coding_vs_intergenomic_seqs:
train_len: 100_000
classes: 2
max_length: 200
demo_human_or_worm:
train_len: 100_000
classes: 2
max_length: 200
human_enhancers_cohn:
train_len: 27791
classes: 2
max_length: 500
human_enhancers_ensembl:
train_len: 154842
classes: 2
max_length: 512
human_ensembl_regulatory:
train_len: 289061
classes: 3
max_length: 512
human_nontata_promoters:
train_len: 36131
classes: 2
max_length: 251
human_ocr_ensembl:
train_len: 174756
classes: 2
max_length: 512
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
================================================
FILE: configs/dataset/hg38.yaml
================================================
_name_: hg38
bed_file: null
fasta_file: null
dataset_name: hg38
tokenizer_name: null
cache_dir: null
max_length: 1024
add_eos: True
batch_size: 8 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 4 # For preprocessing only
shuffle: True
__train_len: 34021
__l_max: ${.max_length}
================================================
FILE: configs/dataset/nucleotide_transformer.yaml
================================================
_name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets
train_val_split_seed: ${train.seed} # Used for train/validation splitting
dataset_name: enhancers # this specifies which dataset in nuc trx
dest_path: null # path to overall nuc trx datasets
max_length: ${.${.dataset_name}.max_length}
d_output: ${.${.dataset_name}.classes}
use_padding: True
padding_side: left
add_eos: False
batch_size: 256
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true # set this as default!
metric: ${.${.dataset_name}.metric}
# these are used to find the right attributes automatically for each dataset
enhancers:
train_len: 14968
classes: 2
max_length: 200
metric: mcc
enhancers_types:
train_len: 14968
classes: 3
max_length: 200
metric: mcc
H3:
train_len: 13468
classes: 2
max_length: 500
metric: mcc
H3K4me1:
train_len: 28509
classes: 2
max_length: 500
metric: mcc
H3K4me2:
train_len: 27614
classes: 2
max_length: 500
metric: mcc
H3K4me3:
train_len: 33119
classes: 2
max_length: 500
metric: mcc
H3K9ac:
train_len: 25003
classes: 2
max_length: 500
metric: mcc
H3K14ac:
train_len: 29743
classes: 2
max_length: 500
metric: mcc
H3K36me3:
train_len: 31392
classes: 2
max_length: 500
metric: mcc
H3K79me3:
train_len: 25953
classes: 2
max_length: 500
metric: mcc
H4:
train_len: 13140
classes: 2
max_length: 500
metric: mcc
H4ac:
train_len: 30685
classes: 2
max_length: 500
metric: mcc
promoter_all:
train_len: 53276
classes: 2
max_length: 300
metric: f1_binary
promoter_no_tata:
train_len: 47767
classes: 2
max_length: 300
metric: f1_binary
promoter_tata:
train_len: 5517
classes: 2
max_length: 300
metric: f1_binary
splice_sites_acceptors:
train_len: 19961
classes: 2
max_length: 600
metric: f1_binary
splice_sites_all:
train_len: 27000
classes: 3
max_length: 400
metric: accuracy
splice_sites_donors:
train_len: 19775
classes: 2
max_length: 600
metric: f1_binary
# name maxlen classes samples metric
# enhancers 200 2 14968 MCC
# enhancers_types 200 3 14968 MCC
# H3 500 2 13468 MCC
# H3K4me1 500 2 28509 MCC
# H3K4me2 500 2 27614 MCC
# H3K4me3 500 2 33119 MCC
# H3K9ac 500 2 25003 MCC
# H3K14ac 500 2 29743 MCC
# H3K36me3 500 2 31392 MCC
# H3K79me3 500 2 25953 MCC
# H4 500 2 13140 MCC
# H4ac 500 2 30685 MCC
# promoter_all 300 2 53276 F1
# promoter_no_tata 300 2 47759 F1
# promoter_tata 300 2 5517 F1
# splice_sites_acceptor 600 2 19961 F1
# splice_sites_all 400 2 27000 F1
# splice_sites_donor 600 2 19775 F1
================================================
FILE: configs/experiment/hg38/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
- /pipeline: genomic_benchmark
- /model: ???
- override /scheduler: cosine_warmup_timm
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
task:
loss:
_name_: cross_entropy
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
model:
_name_: dna_embedding
dataset:
# optional, default is max_length
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
# COSINE TIMM
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 6e-4
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true # test only at the end of training
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
# for loading backbone and not head, requires both of these flags below
pretrained_model_path: ???
pretrained_model_state_hook:
_name_: load_backbone
freeze_backbone: false
================================================
FILE: configs/experiment/hg38/genomic_benchmark_cnn.yaml
================================================
# @package _global_
defaults:
- /model: genomics_benchmark_cnn
- /pipeline: genomic_benchmark
- override /scheduler: cosine_warmup_timm
# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name num_seqs num_classes median len std
# dummy_mouse_enhancers_ensembl 1210 2 2381 984.4
# demo_coding_vs_intergenomic_seqs 100_000 2 200 0
# demo_human_or_worm 100_000 2 200 0
# human_enhancers_cohn 27791 2 500 0
# human_enhancers_ensembl 154842 2 269 122.6
# human_ensembl_regulatory 289061 3 401 184.3
# human_nontata_promoters 36131 2 251 0
# human_ocr_ensembl 174756 2 315 108.1
task:
loss:
_name_: cross_entropy
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
encoder: id
decoder: id
dataset:
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 6e-4
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
================================================
FILE: configs/experiment/hg38/hg38.yaml
================================================
# @package _global_
defaults:
- /pipeline: hg38
- /model: ??? # Specify a model, e.g. model=mamba or model=hyena
- override /scheduler: cosine_warmup_timm
task:
_name_: lm
loss:
_name_: cross_entropy
ignore_index: 4
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: null
max_steps: 10000
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
limit_val_batches: 0.125
dataset:
batch_size: ${eval:1024//${trainer.devices}}
max_length: 1024
# optional, default is max_length
max_length_val: ${dataset.max_length}
max_length_test: ${dataset.max_length}
tokenizer_name: char
pad_max_length: null # needed for bpe tokenizer
add_eos: true
rc_aug: false
num_workers: 12
use_fixed_len_val: false # placing a fixed length val here, but it's really the test
mlm: false
mlm_probability: 0.0
scheduler:
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-6
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-4
optimizer:
lr: 6e-4
weight_decay: 0.1
betas: [0.9, 0.95]
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: 256 # affects the scheduler; needs to be set properly
================================================
FILE: configs/experiment/hg38/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
- /pipeline: nucleotide_transformer
- /model: ???
- override /scheduler: cosine_warmup_timm
model:
_name_: dna_embedding
trainer:
accelerator: gpu
devices: 1
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
max_epochs: 100
precision: 16 # bf16 only a100
gradient_clip_val: 1.0
dataset:
tokenizer_name: char
rc_aug: false # reverse complement augmentation
scheduler:
t_in_epochs: False
t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
warmup_lr_init: 1e-6
warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
lr_min: ${eval:0.1 * ${optimizer.lr}}
optimizer:
lr: 1e-3
weight_decay: 0.1
train:
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
seed: 2222
global_batch_size: ${dataset.batch_size}
cross_validation: true
remove_test_loader_in_eval: true # test only at the end of training
pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it
# for loading backbone and not head, requires both of these flags below
pretrained_model_path: ???
pretrained_model_state_hook:
_name_: load_backbone
freeze_backbone: false
================================================
FILE: configs/loader/default.yaml
================================================
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
drop_last: True
================================================
FILE: configs/model/caduceus.yaml
================================================
# Use open-source version of Mamba
_name_: caduceus_lm
config:
_target_: caduceus.configuration_caduceus.CaduceusConfig
# From original MambaConfig
d_model: 128
n_layer: 2
vocab_size: 12
ssm_cfg:
d_state: 16
d_conv: 4
expand: 2
dt_rank: "auto"
dt_min: 0.001
dt_max: 0.1
dt_init: "random"
dt_scale: 1.0
dt_init_floor: 1e-4
conv_bias: true
bias: false
use_fast_path: true
rms_norm: true
fused_add_norm: true
residual_in_fp32: false
pad_vocab_size_multiple: 8
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: 1e-5
# Used in init_weights
initializer_cfg:
initializer_range: 0.02
rescale_prenorm_residual: true
n_residuals_per_layer: 1
# Caduceus-specific params
bidirectional: true
bidirectional_strategy: "add"
bidirectional_weight_tie: true
rcps: false
# Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer)
complement_map: null
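# Example CLI overrides (hypothetical sizes; keys match the config above):
#   model=caduceus model.config.d_model=256 model.config.n_layer=16 model.config.rcps=true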
================================================
FILE: configs/model/genomics_benchmark_cnn.yaml
================================================
# CNN baseline from the Genomic Benchmarks suite
_name_: genomics_benchmark_cnn
number_of_classes: ${dataset.d_output}
vocab_size: 12
embedding_dim: 100 # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
input_len: ${dataset.__l_max}
================================================
FILE: configs/model/hyena.yaml
================================================
_name_: hyena_lm
d_model: 128
n_layer: 2
d_inner: ${eval:4 * ${.d_model}}
vocab_size: 12
resid_dropout: 0.0
embed_dropout: 0.1
fused_mlp: False
fused_dropout_add_ln: False
checkpoint_mixer: False # set true for memory reduction
checkpoint_mlp: False # set true for memory reduction
residual_in_fp32: True
pad_vocab_size_multiple: 8
layer:
_name_: hyena
emb_dim: 5
filter_order: 64
local_order: 3
l_max: ${eval:${dataset.max_length}+2}
modulate: True
w: 10
lr: ${optimizer.lr}
wd: 0.0
lr_pos_emb: 0.0
================================================
FILE: configs/model/layer/hyena.yaml
================================================
_name_: hyena
l_max: 1024
order: 2
filter_order: 64
num_heads: 1
inner_factor: 1
num_blocks: 1
fused_bias_fc: false
outer_mixing: false
dropout: 0.0
filter_dropout: 0.0
filter_cls: 'hyena-filter'
post_order_ffn: false
jit_filter: false
short_filter_order: 3
activation: "id"
================================================
FILE: configs/model/mamba.yaml
================================================
# Use open-source version of Mamba
_name_: mamba_lm
config:
_target_: mamba_ssm.models.config_mamba.MambaConfig
d_model: 128 # Will be overwritten by CL in the scaling exps
n_layer: 2 # Will be overwritten by CL in the scaling exps
vocab_size: 12
pad_vocab_size_multiple: 8
rms_norm: true
fused_add_norm: true
residual_in_fp32: false
ssm_cfg:
d_state: 16
d_conv: 4
expand: 2
dt_rank: "auto"
dt_min: 0.001
dt_max: 0.1
dt_init: "random"
dt_scale: 1.0
dt_init_floor: 1e-4
conv_bias: true
bias: false
use_fast_path: true
initializer_cfg:
initializer_range: 0.02
rescale_prenorm_residual: true
n_residuals_per_layer: 1
#norm_epsilon: 1e-5 # Default arg in mamba create_block
================================================
FILE: configs/optimizer/adam.yaml
================================================
# _target_: torch.optim.Adam
_name_: adam
lr: 0.001 # Initial learning rate
# weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired
betas: [0.9, 0.999]
================================================
FILE: configs/optimizer/adamw.yaml
================================================
# _target_: torch.optim.AdamW
_name_: adamw
lr: 0.001 # Initial learning rate
weight_decay: 0.00 # Weight decay
betas: [0.9, 0.999]
================================================
FILE: configs/optimizer/sgd.yaml
================================================
# _target_: torch.optim.SGD
_name_: sgd
lr: 0.001 # Initial learning rate
momentum: 0.9
weight_decay: 0.0 # Weight decay for adam|lamb
================================================
FILE: configs/pipeline/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: default
- /dataset: genomic_benchmark
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /callbacks: [base, checkpoint]
train:
monitor: val/accuracy # Needed for plateau scheduler
mode: max
encoder: id
# we need this for classification!
decoder:
_name_: sequence
mode: pool
================================================
FILE: configs/pipeline/hg38.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: null
- /dataset: hg38
- /optimizer: adamw
- /scheduler: cosine_warmup
- /callbacks: [base, checkpoint]
train:
monitor: test/loss
mode: min
task:
_name_: lm
loss:
_name_: cross_entropy
ignore_index: 4 # Bake in tokenizer value for padding / EOS tokens
torchmetrics: ['perplexity', 'num_tokens']
encoder: null
decoder: null
loader:
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
drop_last: True # There's enough data and epochs, ignore the edge case
# shuffle: True
================================================
FILE: configs/pipeline/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
- /trainer: default
- /loader: default
- /dataset: nucleotide_transformer
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /callbacks: [base, checkpoint]
task:
loss:
_name_: cross_entropy
metrics:
- ${dataset.metric}
train:
monitor: val/${dataset.metric}
mode: max
encoder: id
# we need this for classification!
decoder:
_name_: sequence
mode: pool
================================================
FILE: configs/scheduler/constant.yaml
================================================
# @package _global_
train:
interval: epoch
scheduler:
# _target_: transformers.get_constant_schedule
_name_: constant
================================================
FILE: configs/scheduler/constant_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_constant_schedule_with_warmup
_name_: constant_warmup
num_warmup_steps: 1000 # Number of iterations for LR warmup
================================================
FILE: configs/scheduler/cosine.yaml
================================================
# @package _global_
train:
interval: epoch
scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
_name_: cosine
T_max: 100 # Max number of epochs for LR scheduler
eta_min: 1e-6 # Min learning rate for cosine scheduler
================================================
FILE: configs/scheduler/cosine_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_cosine_schedule_with_warmup
_name_: cosine_warmup
num_warmup_steps: 1000
num_training_steps: 40000
================================================
FILE: configs/scheduler/cosine_warmup_timm.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_cosine_schedule_with_warmup
_name_: cosine_warmup_timm
t_in_epochs: False
t_initial: 300
lr_min: 1e-5
warmup_lr_init: 1e-6
warmup_t: 10
================================================
FILE: configs/scheduler/linear_warmup.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: transformers.get_linear_schedule_with_warmup
_name_: linear_warmup
num_warmup_steps: 1000
num_training_steps: 40000
================================================
FILE: configs/scheduler/multistep.yaml
================================================
# @package _global_
train:
interval: epoch
# _target_: torch.optim.lr_scheduler.MultiStepLR
scheduler:
_name_: multistep
milestones: [80,140,180]
gamma: 0.2
================================================
FILE: configs/scheduler/plateau.yaml
================================================
# @package _global_
train:
interval: epoch
monitor: ??? # must be specified
scheduler:
# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_name_: plateau
mode: ${train.mode} # Which metric to monitor
factor: 0.2 # Decay factor when ReduceLROnPlateau is used
patience: 20
min_lr: 0.0 # Minimum learning rate during annealing
================================================
FILE: configs/scheduler/step.yaml
================================================
# @package _global_
train:
interval: step
scheduler:
# _target_: torch.optim.lr_scheduler.StepLR
_name_: step
step_size: 1
gamma: 0.99
================================================
FILE: configs/task/lm.yaml
================================================
_name_: lm
# loss: cross_entropy # Handled by task: cross entropy loss
metrics: ppl
================================================
FILE: configs/task/multiclass_classification.yaml
================================================
# _target_: tasks.tasks.MultiClass
_name_: multiclass
loss: cross_entropy
metrics:
- accuracy
torchmetrics: null
================================================
FILE: configs/task/multilabel_classification.yaml
================================================
# _target_:
_name_: base
loss: binary_cross_entropy
metrics: null
torchmetrics:
- MultilabelAUROC # AUROC
- MultilabelAveragePrecision # Precision
# - Recall # not supported in torchmetrics
# - F1 # not supported in torchmetrics
================================================
FILE: configs/task/regression.yaml
================================================
# _target_: tasks.tasks.BaseTask
_name_: base
loss: mse
metrics: mse
torchmetrics: null
================================================
FILE: configs/trainer/debug.yaml
================================================
defaults:
- default
gpus: 1
min_epochs: 1
max_epochs: 10
# prints
progress_bar_refresh_rate: null
weights_summary: full
profiler: null
# debugs
fast_dev_run: False
num_sanity_val_steps: 2
overfit_batches: 0
limit_train_batches: 0.1
limit_val_batches: 0.1
limit_test_batches: 0.1
track_grad_norm: -1
terminate_on_nan: False
================================================
FILE: configs/trainer/default.yaml
================================================
_target_: pytorch_lightning.Trainer
devices: 1
accelerator: gpu
accumulate_grad_batches: 1 # Gradient accumulation every n batches
max_epochs: 200
# accelerator: ddp # Automatically set if gpus > 1
gradient_clip_val: 0.0
log_every_n_steps: 10
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
num_sanity_val_steps: 2 # default value: 2; override to 0 to skip sanity checking
================================================
FILE: configs/trainer/full.yaml
================================================
_target_: pytorch_lightning.Trainer
# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_summary: "top"
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False
================================================
FILE: configs/trainer/lm.yaml
================================================
accumulate_grad_batches: 1
# accelerator: null # set to 'ddp' for distributed
# amp_backend: native # 'native' | 'apex'
gpus: 8
max_epochs: 50
gradient_clip_val: 0.0 # Gradient clipping
log_every_n_steps: 10
precision: 16
progress_bar_refresh_rate: 1
weights_summary: top # Set to 'full' to see every layer
track_grad_norm: -1 # Set to 2 to track norms of gradients
limit_train_batches: 1.0
limit_val_batches: 1.0
# We use the dataloader from Transformer-XL to ensure adjacent minibatches
# are from text that is next to each other.
# So that dataloader has to deal with DDP, and we don't want PL to handle
# that.
replace_sampler_ddp: False
================================================
FILE: setup_env.sh
================================================
#!/bin/bash
# Shell script to set environment variables when running code in this repository.
# Usage:
# source setup_env.sh
# Activate conda env
# shellcheck source=${HOME}/.bashrc disable=SC1091
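# CONDA_SHELL is expected to point at conda's shell hook script,
# e.g. "${HOME}/miniconda3/etc/profile.d/conda.sh" (exact path varies by install).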
source "${CONDA_SHELL}"
if [ -z "${CONDA_PREFIX}" ]; then
conda activate caduceus_env
elif [[ "${CONDA_PREFIX}" != *"/caduceus_env" ]]; then
conda deactivate
conda activate caduceus_env
fi
# Add root directory to PYTHONPATH to enable module imports
export PYTHONPATH="${PWD}"
================================================
FILE: slurm_scripts/dump_vep_embeddings.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=vep_embed # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Output file name
#SBATCH --open-mode=append # Do not overwrite logs
NUM_WORKERS=2
NUM_DEVICES=8
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export CUDA_LAUNCH_BLOCKING=1
export CUBLAS_WORKSPACE_CONFIG=:4096:8 # Needed for setting deterministic functions for reproducibility
#####################################################################################
# Choose from one of the following:
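# (Uncomment exactly one of the model blocks below; each one sets seq_len, bp_per_token,
#  embed_dump_batch_size, model_name_or_path, name, and rcps_flag, which are consumed by
#  the torchrun command at the bottom of this script.)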
## Enformer
#seq_len=196608
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="EleutherAI/enformer-official-rough"
#name="enformer-seqlen=196k"
#rcps_flag="no-rcps"
## NTv2
#seq_len=12288 # 2048 (seq len) * 6 (kmers)
#bp_per_token=6
#embed_dump_batch_size=1
#model_name_or_path="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species"
#name="NTv2_downstream-seqlen=12k"
#rcps_flag="no-rcps"
## Hyena
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="LongSafari/hyenadna-medium-160k-seqlen-hf"
#name="hyena_downstream-seqlen=131k"
#rcps_flag="no-rcps"
## Caduceus-Ph
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ph_downstream-seqlen=131k"
#rcps_flag="no-rcps"
## Caduceus-PS
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ps_downstream-seqlen=131k"
#rcps_flag="rcps"
#####################################################################################
torchrun \
--standalone \
--nnodes=1 \
--nproc-per-node=${NUM_DEVICES} \
vep_embeddings.py \
--num_workers=${NUM_WORKERS} \
--seq_len=${seq_len} \
--bp_per_token=${bp_per_token} \
--embed_dump_batch_size=${embed_dump_batch_size} \
--name="${name}" \
--model_name_or_path="${model_name_or_path}" \
--"${rcps_flag}"
================================================
FILE: slurm_scripts/run_genomics_benchmark.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64000M # RAM
#SBATCH --gres=gpu:1 # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG
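# (These variables are injected into the job environment by wrapper_run_genomics.sh
#  via its `sbatch --export=...` call; see that wrapper script below.)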
# Run script
# shellcheck disable=SC2154
WANDB_NAME="${DISPLAY_NAME}_lr-${LR}_batch_size-${BATCH_SIZE}_rc_aug-${RC_AUG}"
for seed in $(seq 1 5); do
# shellcheck disable=SC2154
HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
mkdir -p "${HYDRA_RUN_DIR}"
echo "*****************************************************"
echo "Running GenomicsBenchmark model: ${DISPLAY_NAME}, task: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, rc_aug: ${RC_AUG}, SEED: ${seed}"
# shellcheck disable=SC2086
python -m train \
experiment=hg38/genomic_benchmark \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="${TASK}" \
dataset.train_val_split_seed=${seed} \
dataset.batch_size=${BATCH_SIZE} \
dataset.rc_aug="${RC_AUG}" \
+dataset.conjoin_train=false \
+dataset.conjoin_test="${CONJOIN_TEST}" \
model="${MODEL}" \
model._name_="${MODEL_NAME}" \
+model.config_path="${CONFIG_PATH}" \
+model.conjoin_test="${CONJOIN_TEST}" \
+decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
+decoder.conjoin_test="${CONJOIN_TEST}" \
optimizer.lr="${LR}" \
trainer.max_epochs=10 \
train.pretrained_model_path="${PRETRAINED_PATH}" \
wandb.group="downstream/gb_cv5" \
wandb.job_type="${TASK}" \
wandb.name="${WANDB_NAME}" \
wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
+wandb.tags=\["seed-${seed}"\] \
hydra.run.dir="${HYDRA_RUN_DIR}"
echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_genomics_benchmark_cnn.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 48:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64G # RAM
#SBATCH --gres=gpu:1 # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
# Expected args:
# - TASK
# - RC_AUG
# LR: 1e-3 -- in https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks, Adam optimizer is used with default lr=1e-3
LR="1e-3"
# Batch size: 64 -- See https://arxiv.org/abs/2306.15794 and https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
BATCH_SIZE=64
# Run script
WANDB_NAME="CNN-LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 5); do
HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
mkdir -p "${HYDRA_RUN_DIR}"
echo "*****************************************************"
echo "Running GenomicsBenchmark TASK: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
python -m train \
experiment=hg38/genomic_benchmark_cnn \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="${TASK}" \
dataset.train_val_split_seed=${seed} \
dataset.batch_size=${BATCH_SIZE} \
dataset.rc_aug="${RC_AUG}" \
optimizer.lr="${LR}" \
trainer.max_epochs=10 \
wandb.group="downstream/gb_cv5" \
wandb.job_type="${TASK}" \
wandb.name="${WANDB_NAME}" \
wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
+wandb.tags=\["seed-${seed}"\] \
hydra.run.dir="${HYDRA_RUN_DIR}"
echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_nucleotide_transformer.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=64G # RAM
#SBATCH --gres=gpu:2 # Number of GPUs
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=4
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --open-mode=append # Do not overwrite logs
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1
# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG
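# (Exported into the job environment by wrapper_run_nucleotide_transformer.sh
#  through `sbatch --export=...`.)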
# Run script
WANDB_NAME="${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 10); do
HYDRA_RUN_DIR="./outputs/downstream/nt_cv10_ep20/${TASK}/${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}/seed-${seed}"
mkdir -p "${HYDRA_RUN_DIR}"
echo "*****************************************************"
echo "Running NT model: ${DISPLAY_NAME}, TASK: ${TASK}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
python -m train \
experiment=hg38/nucleotide_transformer \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
dataset.dataset_name="${TASK}" \
dataset.train_val_split_seed=${seed} \
dataset.batch_size=${BATCH_SIZE} \
dataset.rc_aug="${RC_AUG}" \
+dataset.conjoin_test="${CONJOIN_TEST}" \
model="${MODEL}" \
model._name_="${MODEL_NAME}" \
+model.config_path="${CONFIG_PATH}" \
+model.conjoin_test="${CONJOIN_TEST}" \
+decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
+decoder.conjoin_test="${CONJOIN_TEST}" \
optimizer.lr="${LR}" \
train.pretrained_model_path="${PRETRAINED_PATH}" \
trainer.max_epochs=20 \
wandb.group="downstream/nt_cv10_ep20" \
wandb.job_type="${TASK}" \
wandb.name="${WANDB_NAME}" \
wandb.id="nt_cv10_ep-20_${TASK}_${WANDB_NAME}_seed-${seed}" \
+wandb.tags=\["seed-${seed}"\] \
hydra.run.dir="${HYDRA_RUN_DIR}"
echo "*****************************************************"
done
================================================
FILE: slurm_scripts/run_pretrain_caduceus.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=caduceus_ps # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
#SBATCH --open-mode=append # Do not overwrite logs
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1
NUM_DEVICES=8
# Run script
SEQLEN=131072
MAX_STEPS=50000
D_MODEL=256
N_LAYER=8
LR="8e-3"
BIDIRECTIONAL_STRATEGY="add"
BIDIRECTIONAL_WEIGHT_TIE="true"
RCPS="true"
RC_AUG="false"
BATCH_SIZE=$(( 1048576 / SEQLEN ))
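# Keep the global batch at 2^20 (~1M) tokens regardless of sequence length;
# the per-device batch size is obtained below by dividing by NUM_DEVICES.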
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="caduceus-ps_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
experiment=hg38/hg38 \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
dataset.max_length=${SEQLEN} \
dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
dataset.mlm=true \
dataset.mlm_probability=0.15 \
dataset.rc_aug="${RC_AUG}" \
model="caduceus" \
model.config.d_model=${D_MODEL} \
model.config.n_layer=${N_LAYER} \
model.config.bidirectional=true \
model.config.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \
model.config.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \
model.config.rcps=${RCPS} \
optimizer.lr="${LR}" \
train.global_batch_size=${BATCH_SIZE} \
trainer.max_steps=${MAX_STEPS} \
trainer.devices=${NUM_DEVICES} \
+trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
wandb.group=pretrain_hg38 \
wandb.name="${WANDB_NAME}" \
hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/run_pretrain_hyena.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=hyena # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
NUM_DEVICES=8
# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=4
LR="6e-4"
RC_AUG="true"
BATCH_SIZE=$(( 1048576 / SEQLEN ))
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="hyena_rc_aug_seqlen-${SEQLEN_DIS}_dmodel-${D_MODEL}_nlayer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
experiment=hg38/hg38 \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
dataset.max_length=${SEQLEN} \
dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
dataset.mlm=false \
dataset.mlm_probability=0.0 \
dataset.rc_aug="${RC_AUG}" \
model=hyena \
model.d_model=${D_MODEL} \
model.n_layer=${N_LAYER} \
optimizer.lr="${LR}" \
train.global_batch_size=${BATCH_SIZE} \
trainer.max_steps=${MAX_STEPS} \
trainer.devices=${NUM_DEVICES} \
+trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
wandb.group=pretrain_hg38 \
wandb.name="${WANDB_NAME}" \
hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/run_pretrain_mamba.sh
================================================
#!/bin/bash
#SBATCH --get-user-env # Retrieve the user's login environment
#SBATCH -t 96:00:00 # Time limit (hh:mm:ss)
#SBATCH --mem=100G # RAM
#SBATCH --gres=gpu:8 # Number of GPUs
#SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU)
#SBATCH --cpus-per-task=4 # Number of CPU cores per task
#SBATCH -N 1 # Number of nodes
#SBATCH --requeue # Requeue job if it fails
#SBATCH --job-name=mamba_ntp # Job name
#SBATCH --output=../watch_folder/%x_%j.log # Log file
# Setup environment
cd ../ || exit # Go to the root directory of the repo
source setup_env.sh
NUM_DEVICES=8
# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=8
LR="8e-3"
RC_AUG="true"
BATCH_SIZE=$(( 1048576 / SEQLEN ))
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="mamba_ntp_rc_aug_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
experiment=hg38/hg38 \
callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
dataset.max_length=${SEQLEN} \
dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
dataset.mlm=false \
dataset.mlm_probability=0.0 \
dataset.rc_aug="${RC_AUG}" \
model=mamba \
model.config.d_model=${D_MODEL} \
model.config.n_layer=${N_LAYER} \
optimizer.lr="${LR}" \
train.global_batch_size=${BATCH_SIZE} \
trainer.max_steps=${MAX_STEPS} \
trainer.devices=${NUM_DEVICES} \
+trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
wandb.group=pretrain_hg38 \
wandb.name="${WANDB_NAME}" \
hydra.run.dir="${HYDRA_RUN_DIR}"
================================================
FILE: slurm_scripts/wrapper_run_genomics.sh
================================================
#!/bin/bash
# Choose one from below
## Hyena
## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen to ../outputs/hyena_hf/hyenadna-tiny-1k-seqlen
#LOG_DIR="../watch_folder/gb_cv5/hyena"
#CONFIG_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/config.json")
#PRETRAINED_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/weights.ckpt")
#DISPLAY_NAME="hyena"
#MODEL="hyena"
#MODEL_NAME="dna_embedding"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "false" "true" )
#LRS=( "6e-4" )
## Mamba NTP
#LOG_DIR="../watch_folder/gb_cv5/mamba"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="mamba_uni"
#MODEL="mamba"
#MODEL_NAME="dna_embedding_mamba"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3" )
## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "2e-3")
## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true" # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
mkdir -p "${LOG_DIR}"
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
for LR in "${LRS[@]}"; do
for BATCH_SIZE in 128 256; do
for RC_AUG in "${RC_AUGS[@]}"; do
export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
job_name="gb_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
sbatch \
--job-name="${job_name}" \
--output="${LOG_DIR}/%x_%j.log" \
--export="${export_str}" \
"run_genomics_benchmark.sh"
done
done
done
done
================================================
FILE: slurm_scripts/wrapper_run_genomics_cnn.sh
================================================
#!/bin/bash
LOG_DIR="../watch_folder/gb_cv5/cnn_baseline"
mkdir -p "${LOG_DIR}"
export_str="ALL"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
for RC_AUG in "false"; do
export_str="${export_str},TASK=${TASK},RC_AUG=${RC_AUG}"
job_name="gb_${TASK}_CNN_RC_AUG-${RC_AUG}"
sbatch \
--job-name="${job_name}" \
--output="${LOG_DIR}/%x_%j.log" \
--export="${export_str}" \
"run_genomics_benchmark_cnn.sh"
done
done
================================================
FILE: slurm_scripts/wrapper_run_nucleotide_transformer.sh
================================================
#!/bin/bash
# Choose one from below
## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3")
## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true" # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )
mkdir -p "${LOG_DIR}"
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "enhancers" "enhancers_types" "H3" "H3K4me1" "H3K4me2" "H3K4me3" "H3K9ac" "H3K14ac" "H3K36me3" "H3K79me3" "H4" "H4ac" "promoter_all" "promoter_no_tata" "promoter_tata" "splice_sites_all" "splice_sites_acceptors" "splice_sites_donors"; do
for LR in "${LRS[@]}"; do
for BATCH_SIZE in 128 512; do
for RC_AUG in "${RC_AUGS[@]}"; do
export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
job_name="nt_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
sbatch \
--job-name="${job_name}" \
--output="${LOG_DIR}/%x_%j.log" \
--export="${export_str}" \
"run_nucleotide_transformer.sh"
done
done
done
done
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/callbacks/params.py
================================================
"""Callback to log the number of parameters of the model.
"""
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
class ParamsLog(pl.Callback):
""" Log the number of parameters of the model """
def __init__(
self,
total: bool = True,
trainable: bool = True,
fixed: bool = True,
):
super().__init__()
self._log_stats = AttributeDict(
{
'total_params_log': total,
'trainable_params_log': trainable,
'non_trainable_params_log': fixed,
}
)
@rank_zero_only
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
logs = {}
if self._log_stats.total_params_log:
logs["params/total"] = sum(p.numel() for p in pl_module.parameters())
if self._log_stats.trainable_params_log:
logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters()
if p.requires_grad)
if self._log_stats.non_trainable_params_log:
logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters()
if not p.requires_grad)
if trainer.logger:
trainer.logger.log_hyperparams(logs)
================================================
FILE: src/callbacks/timer.py
================================================
"""Callback to monitor the speed of each step and each epoch.
https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py
Adapted from:
https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
"""
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
class Timer(Callback):
"""Monitor the speed of each step and each epoch.
"""
def __init__(
self,
step: bool = True,
inter_step: bool = True,
epoch: bool = True,
val: bool = True,
):
super().__init__()
self._log_stats = AttributeDict( {
'step_time': step,
'inter_step_time': inter_step,
'epoch_time': epoch,
'val_time': val,
})
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._snap_epoch_time = None
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._snap_step_time = None
self._snap_inter_step_time = None
self._snap_epoch_time = time.time()
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.step_time:
self._snap_step_time = time.time()
if not self._should_log(trainer):
return
logs = {}
if self._log_stats.inter_step_time and self._snap_inter_step_time:
# First log at beginning of second step
logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000
if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.inter_step_time:
self._snap_inter_step_time = time.time()
if not self._should_log(trainer):
return
logs = {}
if self._log_stats.step_time and self._snap_step_time:
logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000
if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:
logs = {}
if self._log_stats.epoch_time and self._snap_epoch_time:
logs["timer/epoch"] = time.time() - self._snap_epoch_time
if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)
def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._snap_val_time = time.time()
@rank_zero_only
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:
logs = {}
if self._log_stats.val_time and self._snap_val_time:
logs["timer/validation"] = time.time() - self._snap_val_time
if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step)
@staticmethod
def _should_log(trainer) -> bool:
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
================================================
FILE: src/callbacks/validation.py
================================================
"""Check validation every n **global** steps.
PyTorch Lightning has a `val_check_interval` parameter that runs validation every n batches, but it does not
support running validation every n **global** steps.
"""
from typing import Any
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
class ValEveryNGlobalSteps(Callback):
"""Check validation every n **global** steps."""
def __init__(self, every_n):
self.every_n = every_n
self.last_run = None
def on_train_batch_end(self, trainer, *_: Any):
"""Check if we should run validation.
Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529
"""
# Prevent running validation multiple times when using gradient accumulation
if trainer.global_step == self.last_run:
return
else:
self.last_run = None
if trainer.global_step % self.every_n == 0 and trainer.global_step != 0:
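# Temporarily flip the trainer into validation mode, run the evaluation loop,
# then restore the previous stage and the training flag.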
trainer.training = False
stage = trainer.state.stage
trainer.state.stage = RunningStage.VALIDATING
trainer._run_evaluate()
trainer.state.stage = stage
trainer.training = True
trainer._logger_connector._epoch_end_reached = False
self.last_run = trainer.global_step
================================================
FILE: src/dataloaders/__init__.py
================================================
from . import genomics
from .base import SequenceDataset
================================================
FILE: src/dataloaders/base.py
================================================
""" Datasets for core experimental results.
"""
import os
from functools import partial
from pathlib import Path
import torch
# Default data path is environment variable or /data
if (default_data_path := os.getenv("DATA_PATH")) is None:
default_data_path = Path(__file__).parent.parent.parent.absolute()
default_data_path = default_data_path / "data"
else:
default_data_path = Path(default_data_path).absolute()
class DefaultCollateMixin:
"""Controls collating in the DataLoader
The CollateMixin classes instantiate a dataloader by separating collate arguments from the rest of the dataloader
arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args
list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments,
constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor.
"""
@classmethod
def _collate_callback(cls, x, *args, **kwargs):
"""
Modify the behavior of the default _collate method.
"""
return x
_collate_arg_names = []
@classmethod
def _return_callback(cls, return_value, *args, **kwargs):
"""
Modify the return value of the collate_fn.
Assign a name to each element of the returned tuple beyond the (x, y) pairs
See InformerSequenceDataset for an example of this being used
"""
x, y, *z = return_value
assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)}
@classmethod
def _collate(cls, batch, *args, **kwargs):
# From https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
elem = batch[0]
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
x = torch.stack(batch, dim=0, out=out)
# Insert custom functionality into the collate_fn
x = cls._collate_callback(x, *args, **kwargs)
return x
else:
return torch.tensor(batch)
@classmethod
def _collate_fn(cls, batch, *args, **kwargs):
"""
Default collate function.
Generally accessed by the dataloader() methods to pass into torch DataLoader
Arguments:
batch: list of (x, y) pairs
args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback
"""
x, y, *z = zip(*batch)
x = cls._collate(x, *args, **kwargs)
y = cls._collate(y)
z = [cls._collate(z_) for z_ in z]
return_value = (x, y, *z)
return cls._return_callback(return_value, *args, **kwargs)
# List of loader arguments to pass into collate_fn
collate_args = []
def _dataloader(self, dataset, **loader_args):
collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args}
loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}
loader_cls = loader_registry[loader_args.pop("_name_", None)]
return loader_cls(
dataset=dataset,
collate_fn=partial(self._collate_fn, **collate_args),
**loader_args,
)
# class SequenceDataset(LightningDataModule):
# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just
# provide our own class with the same core methods as LightningDataModule (e.g. setup)
class SequenceDataset(DefaultCollateMixin):
registry = {}
_name_ = NotImplementedError("Dataset must have shorthand name")
# Subclasses do not specify __init__, which is instead handled by this class.
# They can provide a dict of default arguments, which are automatically registered as attributes.
# TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features
# of this class such as the _name_ and d_input/d_output
@property
def init_defaults(self):
return {}
# https://www.python.org/dev/peps/pep-0487/#subclass-registration
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.registry[cls._name_] = cls
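# e.g., a subclass declaring `_name_ = "hg38"` is registered automatically, so
# SequenceDataset.registry["hg38"] resolves to that class (see src/dataloaders/genomics.py).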
def __init__(self, _name_, data_dir=None, **dataset_cfg):
assert _name_ == self._name_
self.data_dir = Path(data_dir).absolute() if data_dir is not None else None
# Add all arguments to self
init_args = self.init_defaults.copy()
init_args.update(dataset_cfg)
for k, v in init_args.items():
setattr(self, k, v)
# The train, val, test datasets must be set by `setup()`
self.dataset_train = self.dataset_val = self.dataset_test = None
self.init()
def init(self):
"""Hook called at end of __init__, override this instead of __init__"""
pass
def setup(self):
"""This method should set self.dataset_train, self.dataset_val, and self.dataset_test."""
raise NotImplementedError
def split_train_val(self, val_split):
"""
Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.
"""
train_len = int(len(self.dataset_train) * (1.0 - val_split))
self.dataset_train, self.dataset_val = torch.utils.data.random_split(
self.dataset_train,
(train_len, len(self.dataset_train) - train_len),
generator=torch.Generator().manual_seed(
getattr(self, "seed", 42)
), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us
)
def train_dataloader(self, **kwargs):
"""Return a DataLoader for the training dataset."""
return self._train_dataloader(self.dataset_train, **kwargs)
def _train_dataloader(self, dataset, **kwargs):
if dataset is None:
return
kwargs['shuffle'] = 'sampler' not in kwargs # shuffle can't be True if we have a custom sampler
return self._dataloader(dataset, **kwargs)
def val_dataloader(self, **kwargs):
"""Return a DataLoader for the validation dataset."""
return self._eval_dataloader(self.dataset_val, **kwargs)
def test_dataloader(self, **kwargs):
"""Return a DataLoader for the test dataset."""
return self._eval_dataloader(self.dataset_test, **kwargs)
def _eval_dataloader(self, dataset, **kwargs):
if dataset is None:
return
# Note that shuffle=False by default
return self._dataloader(dataset, **kwargs)
def __str__(self):
return self._name_
# Registry for dataloader class
loader_registry = {
None: torch.utils.data.DataLoader, # default case
}
================================================
FILE: src/dataloaders/datasets/genomic_bench_dataset.py
================================================
"""Genomic Benchmarks Dataset.
From: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
"""
from pathlib import Path
import torch
from genomic_benchmarks.data_check import is_downloaded
from genomic_benchmarks.loc2seq import download_dataset
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
class GenomicBenchmarkDataset(torch.utils.data.Dataset):
"""
Downloads the requested Genomic Benchmarks dataset (if not already cached) and loads all
sequences and their labels for the given split into memory.
"""
def __init__(
self,
split,
max_length,
dataset_name="human_nontata_promoters",
d_output=2, # default binary classification
dest_path=None,
tokenizer=None,
tokenizer_name=None,
use_padding=None,
add_eos=False,
rc_aug=False,
conjoin_train=False,
conjoin_test=False,
return_augs=False,
return_mask=False,
):
self.max_length = max_length
self.use_padding = use_padding
self.tokenizer_name = tokenizer_name
self.tokenizer = tokenizer
self.return_augs = return_augs
self.add_eos = add_eos
self.d_output = d_output # needed for decoder to grab
assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
if (conjoin_train or conjoin_test) and rc_aug:
print("When using conjoin, we turn off rc_aug.")
rc_aug = False
self.rc_aug = rc_aug
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
self.return_mask = return_mask
if not is_downloaded(dataset_name, cache_path=dest_path):
print("downloading {} to {}".format(dataset_name, dest_path))
download_dataset(dataset_name, version=0, dest_path=dest_path)
else:
print("already downloaded {}-{}".format(split, dataset_name))
self.split = split
# use Path object
base_path = Path(dest_path) / dataset_name / split
self.all_seqs = []
self.all_labels = []
label_mapper = {}
for i, x in enumerate(base_path.iterdir()):
label_mapper[x.stem] = i
for label_type in label_mapper.keys():
for path in (base_path / label_type).iterdir():
with open(path, "r") as f:
content = f.read()
self.all_seqs.append(content)
self.all_labels.append(label_mapper[label_type])
def __len__(self):
return len(self.all_labels)
def __getitem__(self, idx):
x = self.all_seqs[idx]
y = self.all_labels[idx]
if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
x = string_reverse_complement(x)
seq = self.tokenizer(
x,
add_special_tokens=False,
padding="max_length" if self.use_padding else None,
max_length=self.max_length,
truncation=True,
)
seq_ids = seq["input_ids"] # get input_ids
# need to handle eos here
if self.add_eos:
# append list seems to be faster than append tensor
seq_ids.append(self.tokenizer.sep_token_id)
if self.conjoin_train or (self.conjoin_test and self.split != "train"):
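# Conjoining: also tokenize the reverse complement and stack it with the forward
# strand along a new trailing dimension (see the shape comment below).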
x_rc = string_reverse_complement(x)
seq_rc = self.tokenizer(
x_rc,
add_special_tokens=False,
padding="max_length" if self.use_padding else None,
max_length=self.max_length,
truncation=True,
)
seq_rc_ids = seq_rc["input_ids"] # get input_ids
# need to handle eos here
if self.add_eos:
# append list seems to be faster than append tensor
seq_rc_ids.append(self.tokenizer.sep_token_id)
seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)
else:
# convert to tensor
seq_ids = torch.LongTensor(seq_ids)
# need to wrap in list
target = torch.LongTensor([y])
# `seq` has shape:
# - (seq_len,) if not conjoining
# - (seq_len, 2) for conjoining
if self.return_mask:
return seq_ids, target, {"mask": torch.BoolTensor(seq["attention_mask"])}
else:
return seq_ids, target
================================================
FILE: src/dataloaders/datasets/hg38_char_tokenizer.py
================================================
"""
From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
CharacterTokenizer for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
class CharacterTokenizer(PreTrainedTokenizer):
def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs):
"""Character tokenizer for Hugging Face transformers.
Args:
characters (Sequence[str]): List of desired characters. Any character which
is not included in this list will be replaced by a special token called
[UNK] with id=6. Following is the list of all the special tokens with
their corresponding ids:
"[CLS]": 0
"[SEP]": 1
"[BOS]": 2
"[MASK]": 3
"[PAD]": 4
"[RESERVED]": 5
"[UNK]": 6
an id (starting at 7) will be assigned to each character.
model_max_length (int): Model maximum sequence length.
"""
self.characters = characters
self.model_max_length = model_max_length
bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
eos_token = AddedToken("[EOS]", lstrip=False, rstrip=False)
sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)
mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)
self._vocab_str_to_int = {
"[CLS]": 0,
"[SEP]": 1,
"[BOS]": 2,
"[MASK]": 3,
"[PAD]": 4,
"[RESERVED]": 5,
"[UNK]": 6,
**{ch: i + 7 for i, ch in enumerate(characters)},
}
self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
# TODO: This should be a parameter passed to __init__
complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
self.complement_map = {}
for k, v in self._vocab_str_to_int.items():
complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
self.complement_map[self._vocab_str_to_int[k]] = complement_id
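# self.complement_map now maps each token id to the id of its complement
# (A<->T, C<->G); special tokens and non-ACGT characters map to themselves.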
super().__init__(
bos_token=bos_token,
eos_token=pad_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
add_prefix_space=False,
model_max_length=model_max_length,
padding_side=padding_side,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
def _tokenize(self, text: str) -> List[str]:
return list(text)
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return "".join(tokens)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = cls + token_ids_0 + sep
if token_ids_1 is not None:
result += token_ids_1 + sep
return result
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
result = [1] + ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
result += ([0] * len(token_ids_1)) + [1]
return result
def get_vocab(self) -> Dict[str, int]:
return self._vocab_str_to_int
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = len(cls + token_ids_0 + sep) * [0]
if token_ids_1 is not None:
result += len(token_ids_1 + sep) * [1]
return result
def get_config(self) -> Dict:
return {
"char_ords": [ord(ch) for ch in self.characters],
"model_max_length": self.model_max_length,
}
@classmethod
def from_config(cls, config: Dict) -> "CharacterTokenizer":
cfg = {}
cfg["characters"] = [chr(i) for i in config["char_ords"]]
cfg["model_max_length"] = config["model_max_length"]
return cls(**cfg)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
cfg_file = Path(save_directory) / "tokenizer_config.json"
cfg = self.get_config()
with open(cfg_file, "w") as f:
json.dump(cfg, f, indent=4)
@classmethod
def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
cfg_file = Path(save_directory) / "tokenizer_config.json"
with open(cfg_file) as f:
cfg = json.load(f)
return cls.from_config(cfg)
================================================
FILE: src/dataloaders/datasets/hg38_dataset.py
================================================
"""Dataset for sampling arbitrary intervals from the human genome.
"""
import math
from pathlib import Path
import pandas as pd
import torch
from pyfaidx import Fasta
from src.dataloaders.utils.mlm import mlm_getitem
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
MAX_ALLOWED_LENGTH = 2 ** 20
class FastaInterval:
"""Retrieves sequences from a fasta file given a chromosome and start/end indices."""
def __init__(
self,
*,
fasta_file,
return_seq_indices=False,
rc_aug=False,
):
fasta_file = Path(fasta_file)
assert fasta_file.exists(), "Path to fasta file must exist!"
self.seqs = Fasta(str(fasta_file))
self.return_seq_indices = return_seq_indices
self.rc_aug = rc_aug
# calc len of each chromosome in fasta file, store in dict
self.chr_lens = {}
for chr_name in self.seqs.keys():
self.chr_lens[chr_name] = len(self.seqs[chr_name])
@staticmethod
def _compute_interval(start, end, max_length, i_shift):
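# The incoming (start, end) interval spans MAX_ALLOWED_LENGTH (2^20) bases (the dataset
# expands bed rows to this length); when max_length is smaller, the interval is split into
# MAX_ALLOWED_LENGTH / max_length non-overlapping chunks and i_shift selects which one to return.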
if max_length == MAX_ALLOWED_LENGTH:
return start, end
if max_length < MAX_ALLOWED_LENGTH:
assert MAX_ALLOWED_LENGTH % max_length == 0
return start + i_shift * max_length, start + (i_shift + 1) * max_length
else:
raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!")
def __call__(
self,
chr_name,
start,
end,
max_length,
i_shift,
return_augs=False,
):
"""
max_length passed from dataset, not from init
"""
chromosome = self.seqs[chr_name]
chromosome_length = self.chr_lens[chr_name]
start, end = self._compute_interval(start, end, max_length, i_shift)
if end > chromosome_length:
# Shift interval down
start = start - (end - chromosome_length)
end = chromosome_length
assert start == chromosome_length - max_length
if start < 0:
# Shift interval up
end = end - start
start = 0
assert end == max_length
if end > chromosome_length:
# This may occur if start + MAX_ALLOWED_LENGTH extends beyond the end of the chromosome
start = chromosome_length - max_length
end = chromosome_length
seq = str(chromosome[start:end])
if self.rc_aug and coin_flip():
seq = string_reverse_complement(seq)
return seq
class HG38Dataset(torch.utils.data.Dataset):
"""Loop through bed file, retrieve (chr, start, end), query fasta file for sequence."""
def __init__(
self,
split,
bed_file,
fasta_file,
max_length,
mlm=False,
mlm_probability=0.15,
pad_max_length=None,
tokenizer=None,
tokenizer_name=None,
add_eos=False,
return_seq_indices=False,
rc_aug=False,
return_augs=False,
):
self.mlm = mlm
self.mlm_probability = mlm_probability
if self.mlm and self.mlm_probability <= 0.0:
raise ValueError(f"`mlm_probability` has to be > 0.0, got {self.mlm_probability}.")
if self.mlm:
# TODO: see if this helps
# self.eligible_replacements = torch.tensor(
# tokenizer("ACGT", add_special_tokens=False)["input_ids"], dtype=torch.long
# )
self.eligible_replacements = None
else:
self.eligible_replacements = None
self.max_length = max_length
self.pad_max_length = pad_max_length if pad_max_length is not None else max_length
self.tokenizer_name = tokenizer_name
self.tokenizer = tokenizer
self.return_augs = return_augs
self.add_eos = add_eos
if max_length <= MAX_ALLOWED_LENGTH:
assert MAX_ALLOWED_LENGTH % max_length == 0, f"`max_length` {max_length} must evenly divide {MAX_ALLOWED_LENGTH}!"
self.shifts = MAX_ALLOWED_LENGTH // max_length
else:
raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!")
bed_path = Path(bed_file)
assert bed_path.exists(), "Path to .bed file must exist!"
# read bed file
df_raw = pd.read_csv(str(bed_path), sep="\t", names=["chr_name", "start", "end", "split"])
# select only split df
self.df = df_raw[df_raw["split"] == split]
# Update end points so that sequences are all length == MAX_ALLOWED_LENGTH
self.df.loc[:, "end"] = self.df["start"] + MAX_ALLOWED_LENGTH
self.fasta = FastaInterval(
fasta_file=fasta_file,
return_seq_indices=return_seq_indices,
rc_aug=rc_aug
)
@staticmethod
def replace_value(x, old_value, new_value):
"""Helper for replacing values in a tensor."""
return torch.where(x == old_value, new_value, x)
def __len__(self):
return len(self.df) * self.shifts
def __getitem__(self, idx):
"""Returns a sequence of specified len"""
# map the dataset index to a (bed-file row, chunk shift) pair
row_idx, shift_idx = idx // self.shifts, idx % self.shifts
row = self.df.iloc[row_idx]
chr_name, start, end = (row.iloc[0], row.iloc[1], row.iloc[2])
seq = self.fasta(
chr_name,
start,
end,
max_length=self.max_length,
i_shift=shift_idx,
return_augs=self.return_augs,
)
if end - start != MAX_ALLOWED_LENGTH:
print(row, "\nLength: ", end - start)
if self.tokenizer_name == "char":
seq = self.tokenizer(
seq,
padding="max_length",
max_length=self.pad_max_length,
truncation=True,
add_special_tokens=False
)
seq = seq["input_ids"] # get input_ids
# need to handle eos here
if self.add_eos:
# append list seems to be faster than append tensor
seq.append(self.tokenizer.sep_token_id)
elif self.tokenizer_name == "bpe":
seq = self.tokenizer(
seq,
# add_special_tokens=False,
padding="max_length",
max_length=self.pad_max_length,
truncation=True,
)
# get input_ids
if self.add_eos:
seq = seq["input_ids"][1:] # remove the bos, keep the eos token
else:
seq = seq["input_ids"][1:-1] # remove both special tokens
# convert to tensor
seq = torch.LongTensor(seq)
# replace N token with a pad token, so we can ignore it in the loss
seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int["N"], self.tokenizer.pad_token_id)
if self.mlm:
data, target = mlm_getitem(
seq,
mlm_probability=self.mlm_probability,
contains_eos=self.add_eos,
tokenizer=self.tokenizer,
eligible_replacements=self.eligible_replacements,
)
else:
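# Next-token prediction: the input is the sequence without its last token and
# the target is the sequence shifted left by one.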
data = seq[:-1].clone()
target = seq[1:].clone()
return data, target
================================================
FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py
================================================
"""Nucleotide Transformer Benchmarks Dataset.
From: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks
"""
import torch
from datasets import load_dataset
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement
class NucleotideTransformerDataset(torch.utils.data.Dataset):
"""
Loads a Nucleotide Transformer downstream task from the HuggingFace Hub and
returns (sequence, label) pairs for the requested split.
"""
def __init__(
self,
split,
max_length,
dataset_name=None,
d_output=2, # default binary classification
tokenizer=None,
tokenizer_name=None,
use_padding=None,
add_eos=False,
rc_aug=False,
conjoin_train=False,
conjoin_test=False,
return_augs=False
):
self.max_length = max_length
self.use_padding = use_padding
self.tokenizer_name = tokenizer_name
self.tokenizer = tokenizer
self.return_augs = return_augs
self.add_eos = add_eos
self.d_output = d_output # needed for decoder to grab
assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
if (conjoin_train or conjoin_test) and rc_aug:
print("When using conjoin, we turn off rc_aug.")
rc_aug = False
self.rc_aug = rc_aug
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
self.split = split
# For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks
self.seqs = load_dataset(
"InstaDeepAI/nucleotide_transformer_downstream_tasks",
name=dataset_name,
split=split
)
def __len__(self):
return len(self.seqs)
def __getitem__(self, idx):
x = self.seqs[idx]["sequence"] # only one sequence
y = self.seqs[idx]["label"]
if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
x = string_reverse_complement(x)
seq = self.tokenizer(
x,
add_special_tokens=False,
padding="max_length" if self.use_padding else None,
max_length=self.max_length,
truncation=True,
)
seq_ids = seq["input_ids"] # get input_ids
# need to handle eos here
if self.add_eos:
# append list seems to be faster than append tensor
seq_ids.append(self.tokenizer.sep_token_id)
if self.conjoin_train or (self.conjoin_test and self.split != "train"):
x_rc = string_reverse_complement(x)
seq_rc = self.tokenizer(
x_rc,
add_special_tokens=False,
padding="max_length" if self.use_padding else None,
max_length=self.max_length,
truncation=True,
)
seq_rc_ids = seq_rc["input_ids"] # get input_ids
# need to handle eos here
if self.add_eos:
# append list seems to be faster than append tensor
seq_rc_ids.append(self.tokenizer.sep_token_id)
seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)
else:
# convert to tensor
seq_ids = torch.LongTensor(seq_ids)
# need to wrap in list
target = torch.LongTensor([y])
# `seq` has shape:
# - (seq_len,) if not conjoining
# - (seq_len, 2) for conjoining
return seq_ids, target
================================================
FILE: src/dataloaders/fault_tolerant_sampler.py
================================================
# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
from typing import Iterator
import math
import torch
from torch.utils.data import RandomSampler, DistributedSampler
class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, generator=None, **kwargs):
# generator = torch.Generator().manual_seed(seed)
# super().__init__(*args, generator=generator, **kwargs)
# TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
# which should be reproducible if pl.seed_everything was called before hand.
# This means that changing the seed of the experiment will also change the
# sampling order.
if generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
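# Fault tolerance: `state_dict()` captures the RNG state and the number of samples already
# consumed; after `load_state_dict()`, `__iter__` regenerates the same permutation and skips
# the first `counter` indices so sampling resumes where it left off.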
def state_dict(self):
return {"random_state": self.state, "counter": self.counter}
def load_state_dict(self, state_dict):
self.generator.set_state(state_dict.get("random_state"))
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self):
# # We need a separate self.start_counter because PL seems to call len repeatedly.
# # If we use len(self.data_source) - self.counter then PL will think the epoch ends
# # when we're only half way through.
# return len(self.data_source) - self.start_counter
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
self.state = self.generator.get_state()
indices = torch.randperm(n, generator=self.generator).tolist()
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
class FaultTolerantDistributedSampler(DistributedSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
def state_dict(self):
return {"epoch": self.epoch, "counter": self.counter}
def load_state_dict(self, state_dict):
self.epoch = state_dict["epoch"]
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self) -> int:
# return self.num_samples - self.start_counter
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
================================================
FILE: src/dataloaders/genomics.py
================================================
"""Dataloaders for genomics datasets, including pretraining and downstream tasks.
- Adapted from:
https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
- Adapted from:
https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
"""
import copy
from typing import Any, List, Union
import torch
from datasets import Dataset
from torch.utils.data.dataloader import DataLoader
from caduceus.tokenization_caduceus import CaduceusTokenizer
import src.utils.train
from src.dataloaders.base import SequenceDataset, default_data_path
from src.dataloaders.datasets.genomic_bench_dataset import GenomicBenchmarkDataset
from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer
from src.dataloaders.datasets.hg38_dataset import HG38Dataset
from src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset
from src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler
logger = src.utils.train.get_logger(__name__)
class HG38(SequenceDataset):
"""
Base class, other dataloaders can inherit from this class.
You must implement the following functions:
- __init__
- setup
You can then use (already have access to) the following functions:
- train_dataloader
- val_dataloader
- test_dataloader
"""
_name_ = "hg38" # this name is how the dataset config finds the right dataloader
def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_name=None, max_length=1024, d_output=2,
rc_aug=False,
max_length_val=None, max_length_test=None, val_ratio=0.0005, val_split_seed=2357,
add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, shuffle=False,
num_workers=1,
fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None,
mlm=False, mlm_probability=0.15,
*args, **kwargs):
self.dataset_config_name = dataset_config_name
self.tokenizer_name = tokenizer_name
self.d_output = d_output
        self.rc_aug = rc_aug  # reverse complement augmentation
self.max_length = max_length
self.max_length_val = max_length_val if max_length_val is not None else max_length
self.max_length_test = max_length_test if max_length_test is not None else max_length
self.val_ratio = val_ratio
self.val_split_seed = val_split_seed
self.val_only = val_only
self.add_eos = add_eos
self.detokenize = detokenize
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.bed_file = bed_file
self.fasta_file = fasta_file
# handle if file paths are None (default paths)
if self.bed_file is None:
self.bed_file = default_data_path / self._name_ / "human-sequences.bed"
if self.fasta_file is None:
self.fasta_file = default_data_path / self._name_ / "hg38.ml.fa"
if fault_tolerant:
assert self.shuffle
self.fault_tolerant = fault_tolerant
if ddp:
assert fault_tolerant
self.ddp = ddp
self.fast_forward_epochs = fast_forward_epochs
self.fast_forward_batches = fast_forward_batches
if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
assert ddp and fault_tolerant
self.mlm = mlm
self.mlm_probability = mlm_probability
# To be instantiated in `setup`
self.tokenizer = None
self.vocab_size = 0
def setup(self, stage=None):
"""Set up the tokenizer and init the datasets."""
# TODO instantiate with registry
if self.tokenizer_name == "char":
logger.info("**Using Char-level tokenizer**")
# self.tokenizer = CharacterTokenizer(
# characters=["A", "C", "G", "T", "N"],
# model_max_length=self.max_length,
# add_special_tokens=False,
# )
self.tokenizer = CaduceusTokenizer(
model_max_length=self.max_length,
add_special_tokens=False
)
else:
raise NotImplementedError(f"Tokenizer {self.tokenizer_name} not implemented.")
self.vocab_size = len(self.tokenizer)
        self.init_datasets()  # creates the datasets; this could also be done directly inside setup()
def init_datasets(self):
"""Init the datasets (separate from the tokenizer)"""
# delete old datasets to free memory
if hasattr(self, "dataset_train"):
self.dataset_train.fasta.seqs.close()
del self.dataset_train.fasta.seqs
# delete old datasets to free memory
if hasattr(self, "dataset_test"):
self.dataset_test.fasta.seqs.close()
del self.dataset_test.fasta.seqs
# Create all splits: torch datasets
self.dataset_train, self.dataset_val, self.dataset_test = [
HG38Dataset(split=split,
bed_file=self.bed_file,
fasta_file=self.fasta_file,
max_length=max_len,
                        tokenizer=self.tokenizer,  # pass the tokenizer wrapper
tokenizer_name=self.tokenizer_name,
add_eos=self.add_eos,
return_seq_indices=False,
rc_aug=self.rc_aug,
return_augs=False,
mlm=self.mlm,
mlm_probability=self.mlm_probability, )
for split, max_len in
zip(["train", "valid", "test"], [self.max_length, self.max_length_val, self.max_length_test])
]
return
def train_dataloader(self, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.shuffle and self.fault_tolerant:
shuffle = False
# TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel:
# In that case the number of replicas and the data parallel rank are more complicated.
distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
sampler = (FaultTolerantDistributedSampler(
self.dataset_train,
**distributed_sampler_kwargs
) if self.ddp else RandomFaultTolerantSampler(self.dataset_train))
# TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
# We assume that it's being resumed with the same number of GPUs
if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
sampler.load_state_dict({
"epoch": self.fast_forward_epochs,
"counter": self.fast_forward_batches * self.batch_size
})
else:
shuffle = self.shuffle
sampler = None
loader = self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, sampler=sampler, **kwargs)
return loader
def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
kwargs["drop_last"] = False
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, **kwargs)
def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
kwargs["drop_last"] = False
# TODO: Should have separate train and eval loaders
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, **kwargs)
@staticmethod
def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None, **kwargs) -> DataLoader:
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
**kwargs,
)
def load_state_dict(self, checkpoint):
if self.fault_tolerant:
self.fast_forward_epochs = checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"]
# TD [2022-08-07] ["epoch_loop.batch_progress"]["total"]["completed"] is 1 iteration
            # behind, so we're using the optimizer's progress. This is set correctly in seq.py.
self.fast_forward_batches = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
"completed"]
# At this point the train loader hasn't been constructed yet
class GenomicBenchmark(HG38):
_name_ = "genomic_benchmark"
l_output = 0 # need to set this for decoder to work correctly
def __init__(
self, dataset_name, train_val_split_seed,
dest_path=None, tokenizer_name="char", d_output=None, rc_aug=False,
conjoin_train=False, conjoin_test=False,
max_length=1024, use_padding=True, max_length_val=None, max_length_test=None,
padding_side="left", val_ratio=0.0005, val_split_seed=2357, add_eos=False,
detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
shuffle=True, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs
):
self.dataset_name = dataset_name
self.train_val_split_seed = train_val_split_seed
self.dest_path = dest_path
self.tokenizer_name = tokenizer_name
self.d_output = d_output
self.rc_aug = rc_aug
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
self.max_length = max_length
self.use_padding = use_padding
self.max_length_val = max_length_val if max_length_val is not None else max_length
self.max_length_test = max_length_test if max_length_test is not None else max_length
self.padding_side = padding_side
self.val_ratio = val_ratio
self.val_split_seed = val_split_seed
self.val_only = val_only
self.add_eos = add_eos
self.detokenize = detokenize
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
if self.dest_path is None:
self.dest_path = default_data_path / self._name_
if fault_tolerant:
assert self.shuffle
self.fault_tolerant = fault_tolerant
if ddp:
assert fault_tolerant
self.ddp = ddp
self.fast_forward_epochs = fast_forward_epochs
self.fast_forward_batches = fast_forward_batches
if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
assert ddp and fault_tolerant
def setup(self, stage=None):
# TODO instantiate with registry
if self.tokenizer_name == "char":
print("**Using Char-level tokenizer**")
self.tokenizer = CharacterTokenizer(
characters=["A", "C", "G", "T", "N"],
                model_max_length=self.max_length + 2,  # add 2 for the special tokens the tokenizer appends by default; cropped later
add_special_tokens=False,
padding_side=self.padding_side,
)
# Create all splits: torch datasets (only train/test in this benchmark, val created below)
self.dataset_train, self.dataset_test = [
GenomicBenchmarkDataset(
split=split,
max_length=max_len,
dataset_name=self.dataset_name,
                tokenizer=self.tokenizer,  # pass the tokenizer wrapper
tokenizer_name=self.tokenizer_name,
use_padding=self.use_padding,
d_output=self.d_output,
add_eos=self.add_eos,
dest_path=self.dest_path,
rc_aug=self.rc_aug,
conjoin_train=self.conjoin_train,
conjoin_test=self.conjoin_test,
return_augs=False
)
for split, max_len in zip(["train", "test"], [self.max_length, self.max_length_val])
]
val_data, train_data = torch.utils.data.random_split(
list(zip(self.dataset_train.all_seqs, self.dataset_train.all_labels)),
lengths=[0.1, 0.9],
generator=torch.Generator().manual_seed(self.train_val_split_seed)
)
self.dataset_val = copy.deepcopy(self.dataset_train)
self.dataset_train.all_seqs = [train_data[i][0] for i in range(len(train_data))]
self.dataset_train.all_labels = [train_data[i][1] for i in range(len(train_data))]
self.dataset_val.all_seqs = [val_data[i][0] for i in range(len(val_data))]
self.dataset_val.all_labels = [val_data[i][1] for i in range(len(val_data))]
self.dataset_val.split = "val"
class NucleotideTransformer(HG38):
_name_ = "nucleotide_transformer"
l_output = 0 # need to set this for decoder to work correctly
def __init__(self, dataset_name, train_val_split_seed,
tokenizer_name="char", d_output=None, rc_aug=False,
conjoin_train=False, conjoin_test=False,
max_length=1024, use_padding=True, max_length_val=None, max_length_test=None,
padding_side="left", val_ratio=0.0005, val_split_seed=2357, add_eos=False,
detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
shuffle=True, shuffle_eval=None, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs):
self.dataset_name = dataset_name
self.train_val_split_seed = train_val_split_seed
self.tokenizer_name = tokenizer_name
self.d_output = d_output
self.rc_aug = rc_aug
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
self.max_length = max_length
self.use_padding = use_padding
self.max_length_val = max_length_val if max_length_val is not None else max_length
self.max_length_test = max_length_test if max_length_test is not None else max_length
self.padding_side = padding_side
self.val_ratio = val_ratio
self.val_split_seed = val_split_seed
self.val_only = val_only
self.add_eos = add_eos
self.detokenize = detokenize
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.shuffle_eval = shuffle_eval if shuffle_eval is not None else shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
if fault_tolerant:
assert self.shuffle
self.fault_tolerant = fault_tolerant
if ddp:
assert fault_tolerant
self.ddp = ddp
self.fast_forward_epochs = fast_forward_epochs
self.fast_forward_batches = fast_forward_batches
if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
assert ddp and fault_tolerant
def setup(self, stage=None):
# TODO instantiate with registry
if self.tokenizer_name == "char":
print("**Using Char-level tokenizer**")
self.tokenizer = CharacterTokenizer(
characters=["A", "C", "G", "T", "N"],
                model_max_length=self.max_length + 2,  # add 2 for the special tokens the tokenizer appends by default; cropped later
add_special_tokens=False,
padding_side=self.padding_side,
)
# Create all splits: torch datasets (only train/test in this benchmark)
# self.dataset_train, self.dataset_val = [
self.dataset_train, self.dataset_test = [
NucleotideTransformerDataset(
split=split,
max_length=max_len,
                tokenizer=self.tokenizer,  # pass the tokenizer wrapper
dataset_name=self.dataset_name,
tokenizer_name=self.tokenizer_name,
use_padding=self.use_padding,
d_output=self.d_output,
add_eos=self.add_eos,
rc_aug=self.rc_aug,
conjoin_train=self.conjoin_train,
conjoin_test=self.conjoin_test,
return_augs=False
)
for split, max_len in zip(["train", "test"], [self.max_length, self.max_length_val])
]
ds_train_val_split = self.dataset_train.seqs.train_test_split(
test_size=0.1,
seed=self.train_val_split_seed
)
self.dataset_val = copy.deepcopy(self.dataset_train)
self.dataset_train.seqs = ds_train_val_split["train"]
self.dataset_val.split = "val"
self.dataset_val.seqs = ds_train_val_split["test"]
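# Illustrative usage sketch (not part of the original module): these datamodules are normally
# instantiated from Hydra configs (see configs/dataset/*.yaml), but they can also be driven
# directly. Passing None for the bed/fasta paths falls back to `default_data_path`, which
# assumes the hg38 data has already been downloaded there.
#
#     dm = HG38(bed_file=None, fasta_file=None, tokenizer_name="char",
#               max_length=1024, batch_size=32, shuffle=True, mlm=True)
#     dm.setup()
#     batch = next(iter(dm.train_dataloader(num_workers=0)))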
================================================
FILE: src/dataloaders/utils/mlm.py
================================================
import torch
def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None):
"""Helper method for creating MLM input / target.
Adapted from:
https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L751
"""
data = seq[:-1].clone() if contains_eos else seq.clone() # remove eos, if applicable
target = data.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(target.shape, mlm_probability)
# TODO: Do we need to avoid "masking" special tokens as is done here?
# https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L760-L766
masked_indices = torch.bernoulli(probability_matrix).bool()
target[~masked_indices] = tokenizer.pad_token_id # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(target.shape, 0.8)).bool() & masked_indices
data[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(target.shape, 0.5)).bool() & masked_indices & ~indices_replaced
if eligible_replacements is not None:
rand_choice = torch.randint(eligible_replacements.shape[0], size=target.shape)
random_words = eligible_replacements[rand_choice]
else:
random_words = torch.randint(len(tokenizer), size=target.shape, dtype=torch.long)
data[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return data, target
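# Illustrative sketch (not part of the original module): a minimal smoke test for `mlm_getitem`.
# `_ToyTokenizer` and its vocabulary are hypothetical stand-ins that only provide the interface
# the helper relies on (pad_token_id, mask_token, convert_tokens_to_ids, __len__).
if __name__ == "__main__":
    class _ToyTokenizer:
        pad_token_id = 0
        mask_token = "[MASK]"
        def convert_tokens_to_ids(self, token):
            return 1 if token == self.mask_token else 0
        def __len__(self):
            return 6  # e.g. pad, mask, A, C, G, T
    tok = _ToyTokenizer()
    seq = torch.randint(2, 6, (16,))  # token ids for a length-16 DNA sequence
    data, target = mlm_getitem(seq, mlm_probability=0.15, tokenizer=tok)
    # Unmasked positions carry the pad id in the target, so loss is computed only on masked ones.
    print(data.shape, target.shape, int((target != tok.pad_token_id).sum()))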
================================================
FILE: src/dataloaders/utils/rc.py
================================================
"""Utility functions for reverse complementing DNA sequences.
"""
from random import random
STRING_COMPLEMENT_MAP = {
"A": "T", "C": "G", "G": "C", "T": "A", "a": "t", "c": "g", "g": "c", "t": "a",
"N": "N", "n": "n",
}
def coin_flip(p=0.5):
"""Flip a (potentially weighted) coin."""
return random() > p
def string_reverse_complement(seq):
"""Reverse complement a DNA sequence."""
rev_comp = ""
for base in seq[::-1]:
if base in STRING_COMPLEMENT_MAP:
rev_comp += STRING_COMPLEMENT_MAP[base]
        # if the base is not in the complement map, keep it unchanged
else:
rev_comp += base
return rev_comp
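# Illustrative sketch (not part of the original module): "ACGTN" reversed is "NTGCA", which
# complements to "NACGT". `coin_flip()` returns True with probability 1 - p (p=0.5 by default)
# and is how the dataloaders decide whether to apply reverse-complement augmentation.
if __name__ == "__main__":
    assert string_reverse_complement("ACGTN") == "NACGT"
    print(string_reverse_complement("ACGTN"), coin_flip())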
================================================
FILE: src/models/__init__.py
================================================
================================================
FILE: src/models/baseline/__init__.py
================================================
================================================
FILE: src/models/baseline/genomics_benchmark_cnn.py
================================================
"""Genomics Benchmark CNN model.
Adapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/genomic_benchmarks/models/torch.py
"""
import torch
from torch import nn
class GenomicsBenchmarkCNN(nn.Module):
def __init__(self, number_of_classes, vocab_size, input_len, embedding_dim=100):
"""Genomics Benchmark CNN model.
`embedding_dim` = 100 comes from:
https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
"""
super(GenomicsBenchmarkCNN, self).__init__()
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.cnn_model = nn.Sequential(
nn.Conv1d(in_channels=embedding_dim, out_channels=16, kernel_size=8, bias=True),
nn.BatchNorm1d(16),
nn.ReLU(),
nn.MaxPool1d(2),
nn.Conv1d(in_channels=16, out_channels=8, kernel_size=8, bias=True),
nn.BatchNorm1d(8),
nn.MaxPool1d(2),
nn.Conv1d(in_channels=8, out_channels=4, kernel_size=8, bias=True),
nn.BatchNorm1d(4),
nn.MaxPool1d(2),
nn.Flatten()
)
self.dense_model = nn.Sequential(
nn.Linear(self.count_flatten_size(input_len), 512),
# To be consistent with SSM classifier decoders, we use num_classes (even when it's binary)
nn.Linear(512, number_of_classes)
)
def count_flatten_size(self, input_len):
zeros = torch.zeros([1, input_len], dtype=torch.long)
x = self.embeddings(zeros)
x = x.transpose(1, 2)
x = self.cnn_model(x)
return x.size()[1]
def forward(self, x, state=None): # Adding `state` to be consistent with other models
x = self.embeddings(x)
x = x.transpose(1, 2)
x = self.cnn_model(x)
x = self.dense_model(x)
return x, state # Returning tuple to be consistent with other models
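# Illustrative sketch (not part of the original module): the dense head adapts to any input
# length because `count_flatten_size` infers the flattened CNN output size via a dummy forward
# pass. The vocab size, class count, and sequence length below are hypothetical placeholders.
if __name__ == "__main__":
    model = GenomicsBenchmarkCNN(number_of_classes=2, vocab_size=12, input_len=200)
    tokens = torch.randint(0, 12, (4, 200))  # (batch, length) of token ids
    logits, _ = model(tokens)
    print(logits.shape)  # expected: torch.Size([4, 2])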
================================================
FILE: src/models/nn/__init__.py
================================================
from .activation import Activation
================================================
FILE: src/models/nn/activation.py
================================================
"""Utilities for activation functions."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def Activation(activation=None, size=None, dim=-1):
"""Returns a PyTorch activation module."""
if activation in [None, 'id', 'identity', 'linear', 'none']:
return nn.Identity()
elif activation == 'tanh':
return nn.Tanh()
elif activation == 'relu':
return nn.ReLU()
elif activation == 'gelu':
return nn.GELU()
elif activation == 'elu':
return nn.ELU()
elif activation in ['swish', 'silu']:
return nn.SiLU()
elif activation == 'glu':
return nn.GLU(dim=dim)
elif activation.startswith('glu-'):
return GLU(dim=dim, activation=activation[4:])
elif activation == 'sigmoid':
return nn.Sigmoid()
elif activation == 'softplus':
return nn.Softplus()
elif activation == 'modrelu':
return ModReLU(size)
elif activation in ['sqrelu', 'relu2']:
return SquaredReLU()
elif activation == 'laplace':
return Laplace()
# Earlier experimentation with a LN in the middle of the block instead of activation
# IIRC ConvNext does something like this?
# elif activation == 'ln':
# return TransposedLN(dim)
else:
raise NotImplementedError("hidden activation '{}' is not implemented".format(activation))
class GLU(nn.Module):
def __init__(self, dim=-1, activation='sigmoid'):
super().__init__()
assert not activation.startswith('glu')
self.dim = dim
self.activation_fn = Activation(activation)
def forward(self, x):
x, g = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
return x * self.activation_fn(g)
class ModReLU(nn.Module):
# Adapted from https://github.com/Lezcano/expRNN
def __init__(self, features):
# For now we just support square layers
super().__init__()
self.features = features
self.b = nn.Parameter(torch.Tensor(self.features))
self.reset_parameters()
def reset_parameters(self):
self.b.data.uniform_(-0.01, 0.01)
def forward(self, inputs):
norm = torch.abs(inputs)
biased_norm = norm + self.b
magnitude = F.relu(biased_norm)
phase = torch.sign(inputs)
return phase * magnitude
class SquaredReLU(nn.Module):
def forward(self, x):
# return F.relu(x)**2
return torch.square(F.relu(x)) # Could this be faster?
def laplace(x, mu=0.707107, sigma=0.282095):
x = (x - mu).div(sigma * math.sqrt(2.0))
return 0.5 * (1.0 + torch.erf(x))
class Laplace(nn.Module):
def __init__(self, mu=0.707107, sigma=0.282095):
super().__init__()
self.mu = mu
self.sigma = sigma
def forward(self, x):
return laplace(x, mu=self.mu, sigma=self.sigma)
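# Illustrative sketch (not part of the original module): `Activation` is a string-to-module
# factory, so configs can select activations by name; 'glu-<name>' builds the gated variant
# defined above, and `laplace` computes 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2)))).
if __name__ == "__main__":
    x = torch.randn(2, 8)
    print(Activation("gelu")(x).shape)              # torch.Size([2, 8])
    print(Activation("glu-relu", dim=-1)(x).shape)  # gating halves the last dim -> torch.Size([2, 4])
    expected = 0.5 * (1.0 + torch.erf((x - 0.707107) / (0.282095 * math.sqrt(2.0))))
    print(torch.allclose(Laplace()(x), expected))   # True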
================================================
FILE: src/models/nn/adaptive_softmax.py
================================================
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
class OptionalParameterList(nn.ParameterList):
def extra_repr(self):
child_lines = []
for k, p in self._parameters.items():
if p is not None:
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
tie_projs=None, out_layers_weights=None, out_projs=None,
keep_order=False,
bias_scale=0.0,
dropout=0.0,
):
super().__init__()
self.n_token = n_token
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = list(cutoffs) + [n_token]
self.cutoff_ends = [0] + self.cutoffs
self.div_val = div_val
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
# bake the first False into the definition, just as [0] is built into the cutoffs
if tie_projs is None: tie_projs = []
elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs)
else: tie_projs = list(tie_projs)
tie_projs = [False] + tie_projs
self.tie_projs = tie_projs
if self.n_clusters > 0:
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
if not out_layers_weights:
self.out_layers_weights = nn.ParameterList()
else:
self.out_layers_weights = out_layers_weights
self.out_layers_biases = nn.ParameterList()
self.shared_out_projs = out_projs
self.out_projs = OptionalParameterList()
self.dropout = dropout
self.drop = nn.Dropout(dropout)
if div_val == 1:
if d_proj != d_embed:
for i in range(len(self.cutoffs)):
if tie_projs[i]:
self.out_projs.append(None)
else:
self.out_projs.append(
nn.Parameter(torch.zeros(d_proj, d_embed))
)
else:
self.out_projs.append(None)
self.out_layers_biases.append(
nn.Parameter(torch.zeros(n_token))
)
if not out_layers_weights:
self.out_layers_weights.append(
nn.Parameter(torch.zeros(n_token, d_embed))
)
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
if tie_projs[i]:
self.out_projs.append(None)
else:
self.out_projs.append(
nn.Parameter(torch.zeros(d_proj, d_emb_i))
)
self.out_layers_biases.append(
nn.Parameter(torch.zeros(r_idx - l_idx))
)
if not out_layers_weights:
self.out_layers_weights.append(
nn.Parameter(torch.zeros(r_idx - l_idx, d_emb_i))
)
for bias in self.out_layers_biases:
bound = bias_scale * d_proj ** -.5
nn.init.uniform_(bias, -bound, bound)
self.keep_order = keep_order
def _compute_logit(self, hidden, weight, bias, proj):
if proj is None:
logit = F.linear(hidden, weight, bias=bias)
else:
if self.dropout > 0.0:
logit = hidden @ proj
logit = self.drop(logit)
logit = logit @ weight.t()
else:
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
if bias is not None:
logit = logit + bias
return logit
def get_out_proj(self, i):
if self.tie_projs[i]:
if len(self.shared_out_projs) == 0:
return None
elif len(self.shared_out_projs) == 1:
return self.shared_out_projs[0]
else:
return self.shared_out_projs[i]
else:
return self.out_projs[i]
def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs):
# [21-09-15 AG]: TODO may need to handle key_padding_mask
'''
hidden :: [len*bsz x d_proj]
target :: [len*bsz]
'''
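        # Tokens in the head shortlist get their log-probability directly from the head softmax;
        # tokens in cluster i > 0 combine the head's log p(cluster | h) with that cluster tail's
        # log p(token | cluster, h), so the full-vocabulary projection is never materialized.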
hidden = hidden.reshape(-1, hidden.size(-1))
target = target.reshape(-1)
if hidden.size(0) != target.size(0):
print(hidden.shape, target.shape)
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers_weights[0],
self.out_layers_biases[0], self.get_out_proj(0))
nll = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
# construct weights and biases
weights, biases = [], []
for i in range(len(self.cutoffs)):
if self.div_val == 1:
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
weight_i = self.out_layers_weights[0][l_idx:r_idx]
bias_i = self.out_layers_biases[0][l_idx:r_idx]
else:
weight_i = self.out_layers_weights[i]
bias_i = self.out_layers_biases[i]
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0)
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < r_idx)
indices_i = mask_i.nonzero(as_tuple=False).squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i)
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
# First term accounts for cluster probabilities
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
if self.keep_order or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
                offset += logprob_i.size(0)  # TODO: This appears to be a bug inherited from the original implementation; the offset update should also happen in the `continue` case above
return nll.mean() # TODO maybe cases for length or padding_mask
def compute_logits(self, hidden):
"""Compute full vector of logits
Adapted from https://github.com/kimiyoung/transformer-xl/issues/88
"""
hidden = hidden.reshape(-1, hidden.size(-1))
if self.n_clusters == 0:
logits = self._compute_logit(hidden, self.out_layers_weights[0],
self.out_layers_biases[0], self.get_out_proj(0))
return logits
else:
# construct weights and biases
weights, biases = [], []
for i in range(len(self.cutoffs)):
if self.div_val == 1:
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
weight_i = self.out_layers_weights[0][l_idx:r_idx]
bias_i = self.out_layers_biases[0][l_idx:r_idx]
else:
weight_i = self.out_layers_weights[i]
bias_i = self.out_layers_biases[i]
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0)
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
out_full_logps = [head_logprob[:, :self.cutoffs[0]]]
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(1, len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
head_logprob_i = head_logprob # .index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i)
hidden_i = hidden # .index_select(0, indices_i)
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i
offset += logprob_i.size(0)
out_full_logps.append(logprob_i)
out_full_logps = torch.cat(out_full_logps, dim = 1)
# print(torch.sum(out_full_ps), out_full_ps.shape)
return out_full_logps
class AdaptiveEmbedding(nn.Module):
""" Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation
Initialization has been fixed for the case when d_proj = d_embed
"""
def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, init_scale=1.0, sample_softmax=False, dropout=0.0):
super().__init__()
self.n_token = n_token
self.d_embed = d_embed
self.cutoffs = list(cutoffs) + [n_token]
self.div_val = div_val
self.d_proj = d_proj
self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
self.emb_scale = d_proj ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
_init_embed(self.emb_layers[-1].weight, d_embed, init_scale)
# torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_embed ** -.5)
if d_proj != d_embed: # TODO
# self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
# torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale)
_init_proj(self.emb_projs[-1], d_proj, init_scale)
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
# torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_emb_i ** -.5)
_init_embed(self.emb_layers[-1].weight, d_emb_i, init_scale)
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
# torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale)
_init_proj(self.emb_projs[-1], d_proj, init_scale)
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
embed = self.drop(embed)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = inp.reshape(-1)
# Changes from original impl
# emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
embeddings = []
indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token
_total_tokens = 0
# emb_flat = inp.new_zeros(inp_flat.size(0), self.d_proj)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze(-1) # shape (_tokens,)
_tokens = indices_i.numel()
if _tokens == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = self.drop(emb_i)
emb_i = F.linear(emb_i, self.emb_projs[i])
# Changes
embeddings.append(emb_i)
indices.index_put_(
(indices_i,),
torch.arange(_tokens, device=inp.device) + _total_tokens
)
_total_tokens += _tokens
# emb_flat.index_copy_(0, indices_i, emb_i)
embeddings = torch.cat(embeddings, dim=0)
emb_flat = embeddings[indices]
embed_shape = inp.size() + (self.d_proj,)
embed = emb_flat.view(embed_shape)
embed.mul_(self.emb_scale)
# embed.div_(self.emb_scale)
return embed
def _init_weight(weight, d : int, init_scale : Optional[float], default=None):
assert init_scale or default
if init_scale is None:
std = default
else:
std = init_scale * (d ** -0.5)
nn.init.normal_(weight, mean=0, std=std)
_init_embed = functools.partial(_init_weight, default=0.02)
_init_proj = functools.partial(_init_weight, default=0.01)
================================================
FILE: src/models/nn/utils.py
================================================
""" Utility wrappers around modules to let them handle Args and extra arguments """
import inspect
from functools import wraps
import torch
from torch import nn
def wrap_kwargs(f):
"""
Given a callable f that can consume some named arguments,
    wrap it so that any unused keyword arguments are passed back alongside the output
EXAMPLES
--------
Basic usage:
def foo(x, y=None):
return x
wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2})
--------
The wrapped function can return its own argument dictionary,
which gets merged with the new kwargs.
def foo(x, y=None):
return x, {}
wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2})
def foo(x, y=None):
return x, {"y": y, "z": None}
wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2})
--------
The wrapped function can have its own kwargs parameter:
def foo(x, y=None, **kw_args):
return x, {}
wrap_kwargs(foo)(0, y=1, z=2) == (0, {})
--------
Partial functions and modules work automatically:
class Module:
def forward(self, x, y=0):
return x, {"y": y+1}
m = Module()
wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2})
"""
sig = inspect.signature(f)
# Check if f already has kwargs
has_kwargs = any([
param.kind == inspect.Parameter.VAR_KEYWORD
for param in sig.parameters.values()
])
if has_kwargs:
@wraps(f)
def f_kwargs(*args, **kwargs):
y = f(*args, **kwargs)
if isinstance(y, tuple) and isinstance(y[-1], dict):
return y
else:
return y, {}
else:
param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD)
sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs])
@wraps(f)
def f_kwargs(*args, **kwargs):
bound = sig_kwargs.bind(*args, **kwargs)
if "kwargs" in bound.arguments:
kwargs = bound.arguments.pop("kwargs")
else:
kwargs = {}
y = f(**bound.arguments)
if isinstance(y, tuple) and isinstance(y[-1], dict):
return *y[:-1], {**y[-1], **kwargs}
else:
return y, kwargs
return f_kwargs
def discard_kwargs(f):
if f is None: return None
f_kwargs = wrap_kwargs(f)
@wraps(f)
def f_(*args, **kwargs):
return f_kwargs(*args, **kwargs)[0]
return f_
def PassthroughSequential(*modules):
"""Special Sequential module that chains kwargs.
Semantics are the same as nn.Sequential, with extra convenience features:
- Discard None modules
- Flatten inner Sequential modules
    - In the case of 0 or 1 modules, rename the class for ease of inspection
"""
def flatten(module):
if isinstance(module, nn.Sequential):
return sum([flatten(m) for m in module], [])
else:
return [module]
modules = flatten(nn.Sequential(*modules))
    modules = [module for module in modules if module is not None]
class Sequential(nn.Sequential):
def forward(self, x, **kwargs):
for layer in self:
x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs)
return x, kwargs
def step(self, x, **kwargs):
for layer in self:
fn = getattr(layer, "step", layer.forward)
x, kwargs = wrap_kwargs(fn)(x, **kwargs)
return x, kwargs
if len(modules) == 0:
Sequential.__name__ = "Identity"
elif len(modules) == 1:
Sequential.__name__ = type(modules[0]).__name__
return Sequential(*modules)
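# Illustrative sketch (not part of the original module): `PassthroughSequential` threads unused
# keyword arguments through every layer, so extra arguments such as `state` survive the chain
# even when individual layers do not declare them. `AddOne` is a hypothetical toy module.
if __name__ == "__main__":
    class AddOne(nn.Module):
        def forward(self, x):
            return x + 1
    seq = PassthroughSequential(AddOne(), None, AddOne())  # the None module is discarded
    y, kwargs = seq(torch.zeros(3), state="carried-along")
    print(y, kwargs)  # tensor([2., 2., 2.]) {'state': 'carried-along'}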
================================================
FILE: src/models/sequence/__init__.py
================================================
================================================
FILE: src/models/sequence/dna_embedding.py
================================================
"""DNA Embedding Model.
Backbones from LM pre-training models, used for downstream tasks.
"""
from functools import partial
import torch
import torch.nn as nn
from flash_attn.utils.generation import GenerationMixin
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MixerModel
from mamba_ssm.models.mixer_seq_simple import _init_weights as _init_weights_mamba
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
ColumnParallelLinear = None
from caduceus.configuration_caduceus import CaduceusConfig
from caduceus.modeling_caduceus import Caduceus
from src.models.sequence.long_conv_lm import LMBackbone
from src.models.sequence.long_conv_lm import _init_weights
class DNAEmbeddingModel(nn.Module, GenerationMixin):
"""DNA Embedding Model.
    Same as ConvLMHeadModel (in long_conv_lm.py), except there is no decoder head; we just pass back the hidden
    states for downstream tasks.
"""
def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,
process_group=None, layer=None,
attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,
resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
checkpoint_mlp=False,
checkpoint_mixer=False,
fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False,
pad_vocab_size_multiple: int = 1, sequence_parallel=True,
device=None, dtype=None, return_hidden_state=False, **kwargs) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.d_model = d_model # for decoder
self.process_group = process_group
self.return_hidden_state = return_hidden_state
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = LMBackbone(
d_model=d_model,
n_layer=n_layer,
d_inner=d_inner,
vocab_size=vocab_size,
process_group=process_group,
layer=layer,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
max_position_embeddings=max_position_embeddings,
resid_dropout=resid_dropout,
embed_dropout=embed_dropout,
dropout_cls=dropout_cls,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln,
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel,
checkpoint_mlp=checkpoint_mlp,
checkpoint_mixer=checkpoint_mixer,
**factory_kwargs, **kwargs
)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {})))
def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface
"""DNA Embedding Model forward pass."""
hidden_states = self.backbone(input_ids, position_ids=position_ids,
inference_params=inference_params)
# we only need the last hidden state for embeddings (decoder head will predict classification task)
return hidden_states, None
@property
def d_output(self):
"""Model /embedding dimension, used for decoder mapping.
"""
if getattr(self, "d_model", None) is None:
raise NotImplementedError("SequenceModule instantiation must set d_output")
return self.d_model
class DNAEmbeddingModelMamba(DNAEmbeddingModel):
"""Custom DNA Embedding Model that is compatible with open-source Mamba repo."""
def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
conjoin_train=False,
conjoin_test=False,
device=None,
dtype=None,
):
super(DNAEmbeddingModel, self).__init__() # nn.Module.__init__()
self.config = config
d_model = config.d_model
self.d_model = d_model # for decoder
n_layer = config.n_layer
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
ssm_cfg=ssm_cfg,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
**factory_kwargs,
)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights_mamba,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface
"""Mamba backbone-specific forward pass that does not use `position_ids`."""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
# we only need the last hidden state for embeddings (decoder head will predict classification task)
return hidden_states, None
class DNAEmbeddingModelCaduceus(DNAEmbeddingModel):
"""Custom DNA Embedding Model that is compatible with Caduceus models."""
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
conjoin_train=False,
conjoin_test=False,
):
super(DNAEmbeddingModel, self).__init__() # nn.Module.__init__()
self.config = config
self.d_model = config.d_model # for decoder
factory_kwargs = {"device": device, "dtype": dtype}
self.caduceus = Caduceus(
config=config,
**factory_kwargs,
)
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface
"""Caduceus backbone-specific forward pass that does not use `position_ids`."""
if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
hidden_states = self.caduceus(input_ids, return_dict=False)
num_chan = hidden_states.shape[-1]
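            # The first half of the channels is the forward-strand representation; the second half
            # holds the reverse-complement representation, which is flipped along the length and
            # channel dimensions so it aligns position-wise with the forward strand. The two are
            # stacked on a new trailing dimension for the downstream decoder.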
return torch.stack(
[hidden_states[..., :num_chan // 2], torch.flip(hidden_states[..., num_chan // 2:], dims=[1, 2])],
dim=-1
), None
if self.conjoin_train or (self.conjoin_test and not self.training): # For conjoining / post-hoc conjoining
assert input_ids.ndim == 3, "Input must be 3D tensor, where channels corresponds to forward and rc strands"
hidden_states = self.caduceus(input_ids[..., 0], return_dict=False)
hidden_states_rc = self.caduceus(input_ids[..., 1], return_dict=False)
# Stack along channel dimension (dim=-1)
return torch.stack([hidden_states, hidden_states_rc], dim=-1), None
return self.caduceus(input_ids, return_dict=False), None
def load_backbone(model, state_dict, freeze_backbone=False, ignore_head=True):
"""
    Modifies state dict loading with a custom function. This is necessary because the head of
    an LM outputs logits over the vocabulary, but we only need the embeddings for downstream tasks.
    inputs:
        model: nn.Module, the from-scratch model
        state_dict: dict, the pretrained weights
        ignore_head: bool, whether to keep the scratch (randomly initialized) head weights instead of
            loading pretrained ones. If the number of classes changes, this must be True.
    return:
        state_dict: dict, updated with pretrained weights where applicable
"""
# consumes prefix from pretrained model, if necessary
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
state_dict, "model."
)
model_new_params_dict = model.state_dict()
updated_model_state_dict = {}
# loop through scratch model keys (pretrained may have extra stuff)
for key in sorted(model_new_params_dict.keys()):
loaded_params = state_dict.get(key, None)
if loaded_params is None:
            # This should never happen; every scratch-model key should be present in the pretrained state dict.
print("Missing key in pretrained model!", key)
raise Exception
elif ignore_head and 'head' in key:
# ignore head weights
print("found head key / parameter, load from scratch", key)
# using scratch by default, nothing needed
used_params = model_new_params_dict[key]
elif "decoder" in key:
print("found decoder key / parameter, load from scratch", key)
used_params = model_new_params_dict[key]
else:
print('key: shape MATCH, loading', key) # load matched weights
used_params = loaded_params
        # we need to pass back a state dict with the 'model.' prefix
key_with_prefix = 'model.' + key
updated_model_state_dict[key_with_prefix] = used_params
if freeze_backbone:
print("freezing model backbone params!")
# note, decoder not included in backbone
for name, param in model.named_parameters():
param.requires_grad = False
    # the new model state dict has now been updated with the pretrained weights
return updated_model_state_dict
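# Illustrative usage sketch (not part of the original module; the checkpoint path and the
# "state_dict" key are hypothetical, following the usual PyTorch Lightning checkpoint layout):
#
#     checkpoint = torch.load("last.ckpt", map_location="cpu")
#     state_dict = load_backbone(model, checkpoint["state_dict"], freeze_backbone=False)
#
# The returned keys carry the "model." prefix, so the dict is meant to be loaded into the
# wrapper module that registers the backbone under the attribute `model`, not into `model` itself.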
================================================
FILE: src/models/sequence/hyena.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
from src.ops.fftconv import fftconv_ref, fftconv_func, fftconv_heads_ref
except ImportError:
fftconv_func = None
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDense = None
import src.utils.registry as registry
from src.utils.train import OptimModule
from src.utils.config import instantiate, auto_assign_attrs
from src.models.nn import Activation
class FFTConvFuncv2(torch.autograd.Function):
@staticmethod
def forward(ctx, u, k):
seqlen = u.shape[-1]
if len(u.shape) > 3:
k = k.unsqueeze(1)
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
ctx.save_for_backward(u_f, k_f)
return y
@staticmethod
def backward(ctx, dout):
u_f, k_f = ctx.saved_tensors
seqlen = dout.shape[-1]
fft_size = 2 * seqlen
dout_f = torch.fft.rfft(dout, n=fft_size)
du = torch.fft.irfft(dout_f * k_f.conj(), n=fft_size, norm="forward")[
..., :seqlen
]
dk = torch.fft.irfft(dout_f * u_f.conj(), n=fft_size, norm="forward")[
..., :seqlen
]
return du, dk.squeeze()
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
if k_rev is not None:
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
k_f = k_f + k_rev_f.conj()
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3:
k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
out = y + u * D.unsqueeze(-1)
if gelu:
out = F.gelu(out)
if dropout_mask is not None:
return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
else:
return out.to(dtype=u.dtype)
@torch.jit.script
def mul_sum(q, y):
return (q * y).sum(dim=1)
class Sin(nn.Module):
def __init__(self, dim, w=10, train_freq=True):
super().__init__()
self.freq = (
nn.Parameter(w * torch.ones(1, dim))
if train_freq
else w * torch.ones(1, dim)
)
def forward(self, x):
return torch.sin(self.freq * x)
class PositionalEmbedding(OptimModule):
def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
"""Complex exponential positional embeddings for Hyena filters."""
super().__init__()
self.seq_len = seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
if emb_dim > 1:
bands = (emb_dim - 1) // 2
# To compute the right embeddings we use the "proper" linspace
t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
z = torch.exp(-1j * f * w)
z = torch.cat([t, z.real, z.imag], dim=-1)
self.register("z", z, lr=lr_pos_emb)
self.register("t", t, lr=0.0)
def forward(self, L):
return self.z[:, :L], self.t[:, :L]
class ExponentialModulation(OptimModule):
def __init__(
self,
d_model,
fast_decay_pct=0.3,
slow_decay_pct=1.5,
target=1e-2,
modulation_lr=0.0,
shift: float = 0.0,
**kwargs,
):
super().__init__()
self.shift = shift
max_decay = math.log(target) / fast_decay_pct
min_decay = math.log(target) / slow_decay_pct
deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
self.register("deltas", deltas, lr=modulation_lr)
def forward(self, t, x):
decay = torch.exp(-t * self.deltas.abs())
x = x * (decay + self.shift)
return x
class HyenaFilter(OptimModule):
def __init__(
self,
d_model,
emb_dim=3, # dim of input to MLP, augments with positional encoding
order=16, # width of the implicit MLP
fused_fft_conv=False,
seq_len=1024,
lr=1e-3,
lr_pos_emb=1e-5,
dropout=0.0,
w=1, # frequency of periodic activations
wd=0, # weight decay of kernel parameters
bias=True,
num_inner_mlps=2,
linear_mixer=False,
modulate: bool = True,
normalized=False,
**kwargs,
):
"""
Implicit long filter with modulation.
Args:
d_model: number of channels in the input
            emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
order: width of the FFN
num_inner_mlps: number of inner linear layers inside filter MLP
Note:
filter_dropout is not implemented
"""
super().__init__()
auto_assign_attrs(
self, d_model=d_model, emb_dim=emb_dim, seq_len=seq_len, modulate=modulate
)
self.use_bias = bias
self.fused_fft_conv = fused_fft_conv
self.bias = nn.Parameter(torch.randn(self.d_model))
self.dropout = nn.Dropout(dropout)
act = Sin(dim=order, w=w)
assert (
emb_dim % 2 != 0 and emb_dim >= 3
), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)
# uses a variable number of inner linear layers
if linear_mixer is False:
self.implicit_filter = nn.Sequential(
nn.Linear(emb_dim, order),
act,
)
for i in range(num_inner_mlps):
self.implicit_filter.append(nn.Linear(order, order))
self.implicit_filter.append(act)
# final linear layer
self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
else:
self.implicit_filter = nn.Sequential(
nn.Linear(emb_dim, d_model, bias=False),
)
self.modulation = ExponentialModulation(d_model, **kwargs)
self.normalized = normalized
for c in self.implicit_filter.children():
for name, v in c.state_dict().items():
optim = {"weight_decay": wd, "lr": lr}
setattr(getattr(c, name), "_optim", optim)
def filter(self, L, *args, **kwargs):
z, t = self.pos_emb(L)
h = self.implicit_filter(z)
if self.modulate:
h = self.modulation(t, h)
if self.normalized:
h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
return h
def forward(self, x, L, k=None, bias=None, *args, **kwargs):
if k is None:
k = self.filter(L)
# Ensure compatibility with filters that return a tuple
k = k[0] if type(k) is tuple else k
if bias is None:
bias = self.bias
bias = bias if self.use_bias else 0 * bias
if self.fused_fft_conv:
bias = bias.to(dtype=torch.float32)
y = fftconv_func(
x,
k,
bias,
dropout_mask=None,
gelu=False,
force_fp16_output=torch.is_autocast_enabled(),
)
else:
y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False)
# y = (
# FFTConvFuncv2.apply(x, k.to(dtype=torch.float32))
# + bias.unsqueeze(-1) * x
# )
return y.to(dtype=x.dtype)
class HyenaOperator(nn.Module):
def __init__(
self,
d_model,
l_max,
order=2,
filter_order=64,
num_heads=1,
inner_factor=1,
num_blocks=1,
fused_bias_fc=False,
outer_mixing=False,
dropout=0.0,
filter_dropout=0.0,
filter_cls="hyena-filter",
post_order_ffn=False,
jit_filter=False,
short_filter_order=3,
activation="id",
return_state=False,
**filter_args,
):
r"""
Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf
Args:
d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max: (int): Maximum input sequence length
order: (int): Depth of the Hyena recurrence. Defaults to 2
filter_order: (int): Width of the FFN parametrizing the implicit filter. Defaults to 64
num_heads: (int): Number of heads. Defaults to 1
inner_factor: (int): Width multiplier. Defaults to 1
num_blocks: (int): Number of blocks in sequence length. Defaults to 1
fused_bias_fc: (bool): Whether to use fused bias FC. Defaults to False
dropout: (float): Dropout probability. Defaults to 0.0
filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
post_order_ffn: (bool): Apply a dense layer between steps of the recurrence. Defaults to False
jit_filter: (bool): Whether JIT the implicit filter function. Defaults to False
short_filter_order: (int): Length of the explicit input convolutional filter. Defaults to 3
activation: (str): type of act between kernel output and FF (default identity)
return_state: (bool): whether to return a state
"""
super().__init__()
assert (
d_model % num_heads == 0
), f"Model dimension {d_model} must be divisible by num heads {num_heads}"
assert (
l_max % num_blocks == 0
), f"Maximum signal length {l_max} must be divisible by block dimension {num_blocks}"
block_dim = l_max // num_blocks
head_dim = d_model // num_heads
auto_assign_attrs(
self,
d_model=d_model,
order=order,
l_max=l_max,
num_heads=num_heads,
inner_factor=inner_factor,
block_dim=block_dim,
head_dim=head_dim,
filter_order=filter_order,
post_order_ffn=post_order_ffn,
short_filter_order=short_filter_order,
num_blocks=num_blocks,
filter_dropout=filter_dropout,
jit_filter=jit_filter,
outer_mixing=outer_mixing,
activation=activation,
return_state=return_state,
)
self.activation = Activation(activation)
self.dropout = nn.Dropout(dropout)
self.setup_projections(fused_bias_fc, inner_factor)
self.setup_filters(filter_cls, filter_args)
def setup_projections(self, fused_bias_fc, inner_factor):
"Initializes input and output projections (over the width dimension)"
if fused_bias_fc and FusedDense is None:
raise ImportError("fused_dense is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model)
self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model)
if self.post_order_ffn:
self.ord_proj_w = nn.Parameter(
torch.randn(self.order, self.num_heads, self.num_heads)
/ math.sqrt(self.head_dim)
)
def setup_filters(self, filter_cls, filter_args):
"Initializes the explicit and implicit filters"
assert self.order >= 2, f"Order must be at least 2, (got {self.order})"
total_width = self.d_model * self.inner_factor * (self.order + 1)
self.short_filter = nn.Conv1d(
in_channels=total_width,
out_channels=total_width,
kernel_size=self.short_filter_order,
groups=total_width,
padding=self.short_filter_order - 1,
)
filter_cls = instantiate(registry.layer, filter_cls, partial=True)
self.filter_fn = filter_cls(
self.head_dim * self.inner_factor * (self.order - 1),
order=self.filter_order,
seq_len=self.l_max,
channels=1,
dropout=self.filter_dropout,
**filter_args,
)
if self.jit_filter:
self.filter_fn = torch.jit.script(self.filter_fn, self.L)
def recurrence(self, u, state):
"Fast inference mode via distilled recurrence"
raise NotImplementedError("Working on it!")
def forward(self, u, *args, **kwargs):
l = u.size(-2)
l_filter = min(l, self.l_max)
u = self.in_proj(u)
u = rearrange(u, "b l d -> b d l")
uc = self.short_filter(u)[..., :l_filter]
uc = rearrange(
uc,
"b (ho v) (z l) -> b ho v z l",
z=self.num_blocks,
ho=self.num_heads,
v=self.head_dim * (self.order + 1),
)
*x, v = uc.split(self.d_model, dim=2)
k = self.filter_fn.filter(l_filter)
# `c` is always 1 by default
k = rearrange(k, "c l (v o) -> c o v l", v=self.head_dim, o=self.order - 1)[0]
bias = rearrange(
self.filter_fn.bias, "(v o) -> o v", v=self.head_dim, o=self.order - 1
)
for o, x_i in enumerate(reversed(x[1:])):
if self.outer_mixing:
v = rearrange(v, "b h v z l -> b h 1 v z l")
v = self.dropout(v * rearrange(x_i, "b h v z l -> b h v 1 z l"))
v = v.sum(dim=2)
else:
v = self.dropout(v * x_i)
# the bias term is broadcasted. Last dimension (l) is handled by fftconv
v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
if self.post_order_ffn:
w = self.ord_proj_w[o]
v = mul_sum(
rearrange(w, "h1 h2 -> 1 h1 h2 1 1 1"),
rearrange(v, "b h v z l -> b h 1 v z l"),
)
y = self.activation(
rearrange(
v * x[0],
"b h v z l -> b (z l) (h v)",
z=self.num_blocks,
h=self.num_heads,
)
)
y = self.out_proj(y)
if self.return_state:
return y, None
return y
@property
def d_output(self):
return self.d_model
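# Illustrative sketch (not part of the original module): `fftconv_ref` realizes a causal
# convolution by zero-padding to length 2L and multiplying in the Fourier domain, then adds
# the skip term `u * D`. The check below compares it against a direct O(L^2) convolution on
# tiny hypothetical shapes.
if __name__ == "__main__":
    b, d, l = 2, 3, 16
    u, k, D = torch.randn(b, d, l), torch.randn(d, l), torch.randn(d)
    y_fft = fftconv_ref(u, k, D, dropout_mask=None, gelu=False)
    y_direct = torch.zeros_like(u)
    for t in range(l):
        for s in range(t + 1):  # causal: only past and current timesteps contribute
            y_direct[..., t] += k[:, s] * u[..., t - s]
    y_direct += u * D.unsqueeze(-1)
    print(torch.allclose(y_fft, y_direct, atol=1e-4))  # True (up to FFT round-off)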
================================================
FILE: src/models/sequence/long_conv_lm.py
================================================
import copy
import math
import re
from collections import namedtuple
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.generation import GenerationMixin
from torch.utils.checkpoint import checkpoint
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
ColumnParallelLinear = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
from src.utils import instantiate
import src.utils.registry as registry
class CheckpointedModule(torch.nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, x):
return checkpoint(self.layer, x)
def create_mixer_cls(
layer=None,
process_group=None,
attn_layer_idx=None,
attn_cfg=None,
layer_idx=None,
sequence_parallel=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
parallel_kwargs = (
{"process_group": process_group, "sequence_parallel": sequence_parallel}
if process_group is not None
else {}
)
if attn_layer_idx is not None and layer_idx in attn_layer_idx:
causal = True if attn_cfg is None else attn_cfg.pop("causal", True)
fused_bias_fc = (
False if attn_cfg is None else attn_cfg.get("fused_bias_fc", False)
)
if not fused_bias_fc:
assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
mha_cls = MHA if process_group is None else ParallelMHA
# ParallelMHA doesn't take 'fused_bias_fc', it is assumed that we fuse matmul + bias
if process_group is not None:
attn_cfg = copy.deepcopy(attn_cfg) # Don't modify the original cfg
attn_cfg.pop("fused_bias_fc", None)
mixer_cls = partial(
mha_cls,
causal=causal,
layer_idx=layer_idx,
**(attn_cfg if attn_cfg is not None else {}),
**parallel_kwargs,
**factory_kwargs,
)
else:
fused_bias_fc = False if layer is None else layer.get("fused_bias_fc", False)
if process_group is not None:
assert fused_bias_fc, "TensorParallel SSM requires fused_bias_fc"
mixer_cls = instantiate(
registry.layer,
layer,
partial=True,
layer_idx=layer_idx,
**factory_kwargs,
**parallel_kwargs,
)
return mixer_cls
def create_mlp_cls(
d_model,
d_inner=None,
process_group=None,
fused_mlp=False,
sequence_parallel=True,
identity_mlp=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
inner_dim = d_inner if d_inner is not None else 4 * d_model
if process_group is not None:
assert fused_mlp, "Tensor Parallel is only implemented for FusedMLP"
if not fused_mlp and not identity_mlp:
mlp_cls = partial(
Mlp,
hidden_features=inner_dim,
activation=partial(F.gelu, approximate="tanh"),
**factory_kwargs,
)
elif fused_mlp:
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
parallel_kwargs = (
{"process_group": process_group, "sequence_parallel": sequence_parallel}
if process_group is not None
else {}
)
mlp_cls = partial(
mlp_cls, hidden_features=inner_dim, **parallel_kwargs, **factory_kwargs
)
else:
mlp_cls = nn.Identity
return mlp_cls
def create_block(
d_model,
d_inner=None,
process_group=None,
layer=None,
attn_layer_idx=None,
attn_cfg=None,
layer_norm_epsilon=1e-5,
resid_dropout1=0.0,
resid_dropout2=0.0,
residual_in_fp32=False,
fused_mlp=False,
identity_mlp=False,
fused_dropout_add_ln=False,
layer_idx=None,
sequence_parallel=True,
checkpoint_mlp=False,
checkpoint_mixer=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
mixer_cls = create_mixer_cls(
layer=layer,
process_group=process_group,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
layer_idx=layer_idx,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
mlp_cls = create_mlp_cls(
d_model,
d_inner=d_inner,
process_group=process_group,
fused_mlp=fused_mlp,
identity_mlp=identity_mlp,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs)
block = Block(
d_model,
mixer_cls,
mlp_cls,
norm_cls=norm_cls,
prenorm=True,
resid_dropout1=resid_dropout1,
resid_dropout2=resid_dropout2,
fused_dropout_add_ln=fused_dropout_add_ln,
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None,
)
block.layer_idx = layer_idx
if checkpoint_mlp:
block.mlp = CheckpointedModule(block.mlp)
if checkpoint_mixer:
block.mixer = CheckpointedModule(block.mixer)
return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02,
rescale_prenorm_residual=True,
glu_act=False,
):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(
p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)
)
# If using GLU activation for now, we scale the std by 2
elif name in ["output_linear.0.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
if not glu_act:
nn.init.normal_(
p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)
)
else:
out_features = p.shape[0]
# Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5
# on average.
nn.init.normal_(
p[: out_features // 2],
mean=0.0,
std=initializer_range / math.sqrt(2 * n_layer) * 2,
)
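# Illustrative sketch of the rescaled-prenorm init above (values are assumed, not taken from
# any config in this repo): with initializer_range=0.02 and n_layer=8, the residual output
# projections (out_proj.weight / fc2.weight) are drawn with
#     std = 0.02 / math.sqrt(2 * 8)   # ~0.0050
# so deeper stacks get proportionally smaller residual contributions at initialization.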
class LMBackbone(nn.Module):
def __init__(
self,
d_model: int,
n_layer: int,
d_inner: int,
vocab_size: int,
process_group=None,
layer=None,
attn_layer_idx=None,
attn_cfg=None,
max_position_embeddings=0,
resid_dropout: float = 0.0,
embed_dropout: float = 0.1,
dropout_cls=nn.Dropout,
layer_norm_epsilon: float = 1e-5,
initializer_cfg=None,
fused_mlp=False,
identity_mlp=False,
fused_dropout_add_ln=False,
residual_in_fp32=False,
sequence_parallel=True,
checkpoint_mlp=False,
checkpoint_mixer=False,
device=None,
dtype=None,
**kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.residual_in_fp32 = residual_in_fp32
if process_group is None:
self.embeddings = GPT2Embeddings(
d_model, vocab_size, max_position_embeddings, **factory_kwargs
)
else:
self.embeddings = ParallelGPT2Embeddings(
d_model,
vocab_size,
max_position_embeddings,
process_group=process_group,
sequence_parallel=self.sequence_parallel,
**factory_kwargs,
)
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: it lets us fuse dropout + add + layer_norm.
self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError("dropout_add_layer_norm is not installed")
self.layers = nn.ModuleList(
[
create_block(
d_model,
d_inner=d_inner,
process_group=process_group,
layer=layer,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
layer_norm_epsilon=layer_norm_epsilon,
resid_dropout1=embed_dropout if i == 0 else resid_dropout,
resid_dropout2=resid_dropout,
residual_in_fp32=residual_in_fp32,
fused_mlp=fused_mlp,
identity_mlp=identity_mlp,
fused_dropout_add_ln=fused_dropout_add_ln,
layer_idx=i,
sequence_parallel=self.sequence_parallel,
checkpoint_mlp=checkpoint_mlp,
checkpoint_mixer=checkpoint_mixer,
**factory_kwargs,
)
for i in range(n_layer)
]
)
self.drop_f = nn.Dropout(resid_dropout)
self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)
if process_group is not None:
for p in self.ln_f.parameters():
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def forward(self, input_ids, position_ids=None, inference_params=None):
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
# dimensions so that we can split on it easily, in case of small batch size.
# Only the attention/SSM layers need to know the seqlen.
embedding_kwargs = (
{"combine_batch_seqlen_dim": True}
if self.process_group is not None and self.sequence_parallel
else {}
)
hidden_states = self.embeddings(
input_ids, position_ids=position_ids, **embedding_kwargs
)
residual = None
mixer_kwargs = (
{"seqlen": input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel
else {}
)
if inference_params is not None:
mixer_kwargs["inference_params"] = inference_params
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, mixer_kwargs=mixer_kwargs
)
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states,
residual,
self.ln_f.weight,
self.ln_f.bias,
self.drop_f.p if self.training else 0.0,
self.ln_f.eps,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
return hidden_states
class ConvLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
d_inner: int,
vocab_size: int,
process_group=None,
layer=None,
attn_layer_idx=None,
attn_cfg=None,
max_position_embeddings=0,
resid_dropout: float = 0.0,
embed_dropout: float = 0.1,
dropout_cls=nn.Dropout,
layer_norm_epsilon: float = 1e-5,
initializer_cfg=None,
fused_mlp=False,
fused_dropout_add_ln=False,
residual_in_fp32=False,
pad_vocab_size_multiple: int = 1,
sequence_parallel=True,
checkpoint_mlp=False,
checkpoint_mixer=False,
device=None,
dtype=None,
**kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.process_group = process_group
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.backbone = LMBackbone(
d_model=d_model,
n_layer=n_layer,
d_inner=d_inner,
vocab_size=vocab_size,
process_group=process_group,
layer=layer,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
max_position_embeddings=max_position_embeddings,
resid_dropout=resid_dropout,
embed_dropout=embed_dropout,
dropout_cls=dropout_cls,
layer_norm_epsilon=layer_norm_epsilon,
initializer_cfg=initializer_cfg,
fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln,
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel,
checkpoint_mlp=checkpoint_mlp,
checkpoint_mixer=checkpoint_mixer,
**factory_kwargs,
**kwargs,
)
if process_group is None:
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError("fused_dense_lib is not installed")
self.lm_head = ColumnParallelLinear(
d_model,
vocab_size,
process_group,
bias=False,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def forward(
self, input_ids, position_ids=None, inference_params=None, state=None
): # state for the repo interface
hidden_states = self.backbone(
input_ids, position_ids=position_ids, inference_params=inference_params
)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if ColumnParallelLinear is not None and inference_params is not None:
if isinstance(self.lm_head, ColumnParallelLinear):
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(
lm_logits, "(n b) s d -> b s (n d)", b=hidden_states.shape[0]
)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits), None
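# Minimal usage sketch (hypothetical hyperparameters; requires flash-attn and a CUDA device,
# so kept as a comment only). Using attention for every layer avoids having to register a
# custom SSM layer in the registry:
#     model = ConvLMHeadModel(
#         d_model=128, n_layer=2, d_inner=512, vocab_size=16,
#         attn_layer_idx=[0, 1], attn_cfg={"num_heads": 8},
#     ).cuda()
#     input_ids = torch.randint(0, 16, (1, 1024), device="cuda")
#     output, _ = model(input_ids)
#     output.logits.shape  # (1, 1024, 16)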
================================================
FILE: src/ops/fftconv.py
================================================
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from fftconv import fftconv_fwd, fftconv_bwd
@torch.jit.script
def _mul_sum(y, q):
return (y * q).sum(dim=1)
# reference convolution with residual connection
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
if k_rev is not None:
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
k_f = k_f + k_rev_f.conj()
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
out = y + u * D.unsqueeze(-1)
if gelu:
out = F.gelu(out)
if dropout_mask is not None:
return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype)
else:
return out.to(dtype=u.dtype)
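# Shape sketch for fftconv_ref (assumed shapes; purely illustrative):
#     u: (B, H, L) input, k: (H, L) long-convolution kernel, D: (H,) residual weight,
#     dropout_mask: (B, H) or None. The call
#         y = fftconv_ref(u, k, D, dropout_mask=None, gelu=False)
#     returns a (B, H, L) tensor equal to the causal convolution of u with k plus D * u
#     (the 2x-padded FFT followed by truncation to the first L steps makes it causal, not circular).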
# reference H3 forward pass
def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None):
seqlen = k.shape[-1]
fft_size = 2 * seqlen
kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=head_dim)
* rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l
kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size
ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1
if ssm_kernel_rev is not None:
ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1
ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj()
y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l
out = y + kv * D.unsqueeze(-1) # b d1 d2 h l
q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim)
if head_dim > 1:
out = _mul_sum(out, q)
return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype)
else:
return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype)
class FFTConvFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False,
output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None):
seqlen = u.shape[-1]
fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)
k_f = torch.fft.rfft(k, n=fft_size)
if k_rev is not None:
k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj()
if u.stride(-1) != 1:
u = u.contiguous()
k_f = k_f.contiguous()
D = D.contiguous()
if v is not None and v.stride(-1) != 1:
v = v.contiguous()
if q is not None and q.stride(-1) != 1:
q = q.contiguous()
if dropout_mask is not None:
dropout_mask = dropout_mask.contiguous()
ctx.save_for_backward(u, k_f, D, dropout_mask, v, q)
ctx.output_hbl_layout = output_hbl_layout
ctx.head_dim = head_dim
ctx.gelu = gelu
ctx.fftfp16 = fftfp16
ctx.has_k_rev = k_rev is not None
out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16)
return out
@staticmethod
def backward(ctx, dout):
if ctx.output_hbl_layout:
dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l')
else:
dout = dout.contiguous()
u, k_f, D, dropout_mask, v, q = ctx.saved_tensors
seqlen = u.shape[-1]
fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)
du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size,
ctx.output_hbl_layout, ctx.fftfp16)
dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen]
dk_rev = (None if not ctx.has_k_rev
else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen])
if v is not None:
dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16
return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev
def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False,
output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None):
return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output,
output_hbl_layout, v, head_dim, q, fftfp16, k_rev)
================================================
FILE: src/tasks/decoders.py
================================================
"""Decoder heads.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import src.models.nn.utils as U
import src.utils as utils
import src.utils.train
log = src.utils.train.get_logger(__name__)
class Decoder(nn.Module):
"""This class doesn't do much but just signals the interface that Decoders are expected to adhere to
TODO: is there a way to enforce the signature of the forward method?
"""
def forward(self, x, **kwargs):
"""
x: (batch, length, dim) input tensor
state: additional state from the model backbone
*args, **kwargs: additional info from the dataset
Returns:
y: output tensor
*args: other arguments to pass into the loss function
"""
return x
def step(self, x):
"""
x: (batch, dim)
"""
return self.forward(x.unsqueeze(1)).squeeze(1)
class SequenceDecoder(Decoder):
def __init__(
self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last",
conjoin_train=False, conjoin_test=False
):
super().__init__()
self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)
if l_output is None:
self.l_output = None
self.squeeze = False
elif l_output == 0:
# Equivalent to getting an output of length 1 and then squeezing
self.l_output = 1
self.squeeze = True
else:
assert l_output > 0
self.l_output = l_output
self.squeeze = False
self.use_lengths = use_lengths
self.mode = mode
if mode == 'ragged':
assert not use_lengths
self.conjoin_train = conjoin_train
self.conjoin_test = conjoin_test
def forward(self, x, state=None, lengths=None, l_output=None):
"""
x: (n_batch, l_seq, d_model) or potentially (n_batch, l_seq, d_model, 2) if using rc_conjoin
Returns: (n_batch, l_output, d_output)
"""
if self.l_output is None:
if l_output is not None:
assert isinstance(l_output, int) # Override by pass in
else:
# Grab entire output
l_output = x.size(1)
squeeze = False
else:
l_output = self.l_output
squeeze = self.squeeze
if self.mode == "last":
def restrict(x_seq):
"""Use last l_output elements of sequence."""
return x_seq[..., -l_output:, :]
elif self.mode == "first":
def restrict(x_seq):
"""Use first l_output elements of sequence."""
return x_seq[..., :l_output, :]
elif self.mode == "pool":
def restrict(x_seq):
"""Pool sequence over a certain range"""
L = x_seq.size(1)
s = x_seq.sum(dim=1, keepdim=True)
if l_output > 1:
c = torch.cumsum(x_seq[..., -(l_output - 1):, ...].flip(1), dim=1)
c = F.pad(c, (0, 0, 1, 0))
s = s - c # (B, l_output, D)
s = s.flip(1)
denom = torch.arange(
L - l_output + 1, L + 1, dtype=x_seq.dtype, device=x_seq.device
)
s = s / denom
return s
elif self.mode == "sum":
# TODO use same restrict function as pool case
def restrict(x_seq):
"""Cumulative sum last l_output elements of sequence."""
return torch.cumsum(x_seq, dim=-2)[..., -l_output:, :]
elif self.mode == 'ragged':
assert lengths is not None, "lengths must be provided for ragged mode"
def restrict(x_seq):
"""Ragged aggregation."""
# remove any additional padding (beyond max length of any sequence in the batch)
return x_seq[..., : max(lengths), :]
else:
raise NotImplementedError(
"Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']"
)
# Restrict to actual length of sequence
if self.use_lengths:
assert lengths is not None
x = torch.stack(
[
restrict(out[..., :length, :])
for out, length in zip(torch.unbind(x, dim=0), lengths)
],
dim=0,
)
else:
x = restrict(x)
if squeeze:
assert x.size(1) == 1
x = x.squeeze(1)
if self.conjoin_train or (self.conjoin_test and not self.training):
x, x_rc = x.chunk(2, dim=-1)
x = self.output_transform(x.squeeze())
x_rc = self.output_transform(x_rc.squeeze())
x = (x + x_rc) / 2
else:
x = self.output_transform(x)
return x
def step(self, x, state=None):
# Ignore all length logic
x_fwd = self.output_transform(x.mean(dim=1))
x_rc = self.output_transform(x.flip(dims=[1, 2]).mean(dim=1)).flip(dims=[1])
x_out = (x_fwd + x_rc) / 2
return x_out
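# Usage sketch (assumed dimensions): with
#     dec = SequenceDecoder(d_model=128, d_output=2, l_output=0, mode="pool")
# a (B, L, 128) backbone output is mean-pooled over the length dimension (sum / L),
# squeezed to (B, 128), and projected through the linear output_transform to (B, 2).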
# For every type of encoder/decoder, specify:
# - constructor class
# - list of attributes to grab from dataset
# - list of attributes to grab from model
registry = {
"stop": Decoder,
"id": nn.Identity,
"linear": nn.Linear,
"sequence": SequenceDecoder,
}
model_attrs = {
"linear": ["d_output"],
"sequence": ["d_output"],
"nd": ["d_output"],
"retrieval": ["d_output"],
"state": ["d_state", "state_to_tensor"],
"forecast": ["d_output"],
"token": ["d_output"],
}
dataset_attrs = {
"linear": ["d_output"],
"sequence": ["d_output", "l_output"],
"nd": ["d_output"],
"retrieval": ["d_output"],
"state": ["d_output"],
"forecast": ["d_output", "l_output"],
"token": ["d_output"],
}
def _instantiate(decoder, model=None, dataset=None):
"""Instantiate a single decoder"""
if decoder is None:
return None
if isinstance(decoder, str):
name = decoder
else:
name = decoder["_name_"]
# Extract arguments from attribute names
dataset_args = utils.config.extract_attrs_from_obj(
dataset, *dataset_attrs.get(name, [])
)
model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))
# Instantiate decoder
obj = utils.instantiate(registry, decoder, *model_args, *dataset_args)
return obj
def instantiate(decoder, model=None, dataset=None):
"""Instantiate a full decoder config, e.g. handle list of configs
Note that arguments are added in reverse order compared to encoder (model first, then dataset)
"""
decoder = utils.to_list(decoder)
return U.PassthroughSequential(
*[_instantiate(d, model=model, dataset=dataset) for d in decoder]
)
================================================
FILE: src/tasks/encoders.py
================================================
from torch import nn
import src.models.nn.utils as U
import src.utils as utils
class Encoder(nn.Module):
"""Encoder abstraction
Accepts a tensor and optional kwargs. Other than the main tensor, all other arguments should be kwargs.
Returns a tensor and optional kwargs.
Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting
kwargs are accumulated and passed into the model backbone.
"""
def forward(self, x, **kwargs):
"""
x: input tensor
*args: additional info from the dataset (e.g. sequence lengths)
Returns:
y: output tensor
*args: other arguments to pass into the model backbone
"""
return x, {}
# For every type of encoder/decoder, specify:
# - constructor class
# - list of attributes to grab from dataset
# - list of attributes to grab from model
registry = {
"stop": Encoder,
"id": nn.Identity,
"embedding": nn.Embedding,
"linear": nn.Linear,
}
dataset_attrs = {
"embedding": ["n_tokens"],
"linear": ["d_input"], # TODO make this d_data?
"class": ["n_classes"],
"time": ["n_tokens_time"],
"onehot": ["n_tokens"],
"conv1d": ["d_input"],
"patch2d": ["d_input"],
}
model_attrs = {
"embedding": ["d_model"],
"linear": ["d_model"],
"position": ["d_model"],
"class": ["d_model"],
"time": ["d_model"],
"onehot": ["d_model"],
"conv1d": ["d_model"],
"patch2d": ["d_model"],
"timestamp_embedding": ["d_model"],
"layer": ["d_model"],
}
def _instantiate(encoder, dataset=None, model=None):
"""Instantiate a single encoder"""
if encoder is None:
return None
if isinstance(encoder, str):
name = encoder
else:
name = encoder["_name_"]
# Extract dataset/model arguments from attribute names
dataset_args = utils.config.extract_attrs_from_obj(
dataset, *dataset_attrs.get(name, [])
)
model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))
# Instantiate encoder
obj = utils.instantiate(registry, encoder, *dataset_args, *model_args)
return obj
def instantiate(encoder, dataset=None, model=None):
encoder = utils.to_list(encoder)
return U.PassthroughSequential(
*[_instantiate(e, dataset=dataset, model=model) for e in encoder]
)
================================================
FILE: src/tasks/metrics.py
================================================
import math
from functools import partial
import torch
import torch.nn.functional as F
import torchmetrics.functional as tm_f
from sklearn.metrics import f1_score, roc_auc_score, matthews_corrcoef
from torchmetrics.classification import MulticlassRecall, MulticlassPrecision
from torchmetrics import Metric
class CorrectAggregatedMetric(Metric):
"""This is needed to calculate some metrics b/c small batch sizes cause aggregation via a simple
average to be off, as some classes might not be present in batch but will get penalized with a 0."""
def __init__(self, class_idx: int, dist_sync_on_step=False):
# call `self.add_state`for every internal state that is needed for the metrics computations
# dist_reduce_fx indicates the function that should be used to reduce
# state from multiple processes
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.class_idx = torch.tensor(class_idx)
self.add_state("numerator", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("denominator", default=torch.tensor(0.0), dist_reduce_fx="sum")
def _update(self, numerator, denominator, preds, y) -> tuple:
        raise NotImplementedError
def update(self, logits: torch.Tensor, y: torch.Tensor):
# update metric states
preds = torch.argmax(logits, dim=-1)
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
assert preds.shape == y.shape, f"preds shape {preds.shape} != y shape {y.shape}"
self.numerator, self.denominator = self._update(self.numerator, self.denominator, preds, y)
def compute(self):
# compute final result
value = self.numerator.float() / self.denominator if self.denominator > 0 else torch.tensor(0.0)
return value
def reset(self):
self.numerator = torch.tensor(0.0)
self.denominator = torch.tensor(0.0)
class AccuracyPerClass(CorrectAggregatedMetric):
"""Calculate per class accuracy, i.e. P(y_hat = class_idx AND y = class_idx OR y_hat != class_idx AND y != class_idx)
"""
def _update(self, numerator, denominator, preds, y) -> tuple:
# Filter down to the class of interest
class_idx = self.class_idx
relevant_idxs = (y == class_idx)
numerator += (preds[relevant_idxs] == class_idx).sum()
denominator += relevant_idxs.sum()
relevant_idxs = (y != class_idx)
numerator += (preds[relevant_idxs] != class_idx).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
class PrecisionPerClass(CorrectAggregatedMetric):
"""Calculate per class precision, i.e. P(y_hat = y | y_hat = class_idx)
"""
def _update(self, numerator, denominator, preds, y) -> tuple:
# Filter down to the class of interest
class_idx = self.class_idx
relevant_idxs = (preds == class_idx)
numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
class RecallPerClass(CorrectAggregatedMetric):
"""Calculate per class recall, i.e. P(y_hat = y | y = class_idx)
"""
def _update(self, numerator, denominator, preds, y) -> tuple:
# Filter down to the class of interest
class_idx = self.class_idx
relevant_idxs = (y == class_idx)
numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum()
denominator += relevant_idxs.sum()
return numerator, denominator
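# Why the numerator/denominator bookkeeping above matters (illustrative numbers): with naive
# per-batch averaging, a batch that contains no examples of class_idx would contribute a value
# of 0 and drag the epoch average down. Accumulating
#     numerator   += correct predictions for class_idx
#     denominator += relevant examples for class_idx
# across update() calls and dividing once in compute() instead yields the exact dataset-level
# value, e.g. 3 correct out of 4 relevant examples -> 0.75 regardless of how those examples
# were split across batches.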
def mcc(logits, y):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
y_hat = torch.argmax(logits, dim=-1)
return matthews_corrcoef(y.cpu().numpy(), y_hat.cpu().numpy())
def last_k_ppl(logits, y, seq_len=1024, k=None):
'''
Calculate perplexity for last k tokens in a sequence.
logits: (batch_size * seq_len, vocab_size), note, already flattened
y: (batch_size * seq_len), note, already flattened
seq_len: int, length of each sequence in the batch
k: if None, use all tokens in sequence
returns: (batch_size,) ppl for each sequence in the batch
'''
if k is None:
k = 0 # use the entire sequence
# need to reshape logits and y to be (batch_size, seq_len, vocab_size) and (batch_size, seq_len)
# respectively
# breakpoint()
logits = logits.view(-1, seq_len, logits.shape[-1])
y = y.view(-1, seq_len)
# only use the last k values of seq dim in logits and y
logits = logits[:, -k:, :]
y = y[:, -k:]
# reshape to flatten the batch and seq_len dimensions
logits = logits.reshape(-1, logits.shape[-1])
y = y.reshape(-1)
# get avg and put on cpu
return F.cross_entropy(logits, y, reduction='none').view(y.shape[0], -1).mean().exp().cpu()
def _student_t_map(mu, sigma, nu):
sigma = F.softplus(sigma)
nu = 2.0 + F.softplus(nu)
return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1)
def student_t_loss(outs, y):
mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2]
mu, sigma, nu = _student_t_map(mu, sigma, nu)
y = y.squeeze(axis=-1)
nup1_half = (nu + 1.0) / 2.0
part1 = 1.0 / nu * torch.square((y - mu) / sigma)
Z = (
torch.lgamma(nup1_half)
- torch.lgamma(nu / 2.0)
- 0.5 * torch.log(math.pi * nu)
- torch.log(sigma)
)
ll = Z - nup1_half * torch.log1p(part1)
return -ll.mean()
def gaussian_ll_loss(outs, y):
mu, sigma = outs[..., 0], outs[..., 1]
y = y.squeeze(axis=-1)
sigma = F.softplus(sigma)
ll = -1.0 * (
torch.log(sigma)
+ 0.5 * math.log(2 * math.pi)
+ 0.5 * torch.square((y - mu) / sigma)
)
return -ll.mean()
def binary_cross_entropy(logits, y):
# BCE loss requires squeezing last dimension of logits so it has the same shape as y
# requires y to be float, since it's overloaded to represent a probability
return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float())
def binary_accuracy(logits, y):
return torch.eq(logits.squeeze(-1) >= 0, y).float().mean()
def padded_cross_entropy(logits, y, pad_mask, pad_value=-1):
"""Will ignore the pad value in label (eg, -1)
logits: (batch_size, seq_len, vocab_size)
y: (batch_size, seq_len)
pad_mask: (batch_size, seq_len)
"""
# need to apply pad mask to y
y_pad = y + pad_mask * pad_value
logits = logits.view(-1, logits.shape[-1])
y_pad = y_pad.view(-1)
return F.cross_entropy(logits, y_pad, ignore_index=pad_value)
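# Sketch of the masking trick above (assumes pad positions carry label 0 and pad_mask == 1
# exactly at those positions; this convention is an assumption, not stated in this file):
# with pad_value = -1,
#     y_pad = y + pad_mask * pad_value
# maps every padded label to -1, which F.cross_entropy then skips via ignore_index=pad_value,
# so only real tokens contribute to the loss.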
def cross_entropy(logits, y, ignore_index=-100):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
return F.cross_entropy(logits, y, ignore_index=ignore_index)
def soft_cross_entropy(logits, y, label_smoothing=0.0):
logits = logits.view(-1, logits.shape[-1])
# target is now 2d (no target flattening)
return F.cross_entropy(logits, y, label_smoothing=label_smoothing)
def accuracy(logits, y):
logits = logits.view(-1, logits.shape[-1])
preds = torch.argmax(logits, dim=-1)
if y.numel() > logits.shape[0]:
# Mixup leads to this case: use argmax class
y = y.argmax(dim=-1)
y = y.view(-1)
return torch.eq(preds, y).float().mean()
def accuracy_ignore_index(logits, y, ignore_index=-100):
num_classes = logits.shape[-1]
preds = torch.argmax(logits, dim=-1)
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
accuracy = tm_f.classification.accuracy(preds, y, 'multiclass', num_classes=num_classes, ignore_index=ignore_index, average='micro')
return accuracy
def accuracy_at_k(logits, y, k=1):
logits = logits.view(-1, logits.shape[-1])
if y.numel() > logits.shape[0]:
# Mixup leads to this case: use argmax class
y = y.argmax(dim=-1)
y = y.view(-1)
return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean()
def f1_binary(logits, y):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
y_hat = torch.argmax(logits, dim=-1)
return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="binary")
def f1_macro(logits, y):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
y_hat = torch.argmax(logits, dim=-1)
return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="macro")
def f1_micro(logits, y):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
y_hat = torch.argmax(logits, dim=-1)
return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro")
def roc_auc_macro(logits, y):
logits = logits.view(
-1, logits.shape[-1]
).detach() # KS: had to add detach to eval while training
y = y.view(-1)
return roc_auc_score(
y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="macro"
)
def roc_auc_micro(logits, y):
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
return roc_auc_score(
y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="micro"
)
def mse(outs, y, len_batch=None):
# assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1
# outs = outs.squeeze(-1)
if len(y.shape) < len(outs.shape):
assert outs.shape[-1] == 1
outs = outs.squeeze(-1)
if len_batch is None:
return F.mse_loss(outs, y)
else:
# Computes the loss of the first `lens` items in the batches
# TODO document the use case of this
mask = torch.zeros_like(outs, dtype=torch.bool)
for i, l in enumerate(len_batch):
mask[i, :l, :] = 1
outs_masked = torch.masked_select(outs, mask)
y_masked = torch.masked_select(y, mask)
return F.mse_loss(outs_masked, y_masked)
def forecast_rmse(outs, y, len_batch=None):
# TODO: generalize, currently for Monash dataset
return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean()
def mae(outs, y, len_batch=None):
# assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1
# outs = outs.squeeze(-1)
if len(y.shape) < len(outs.shape):
assert outs.shape[-1] == 1
outs = outs.squeeze(-1)
if len_batch is None:
return F.l1_loss(outs, y)
else:
# Computes the loss of the first `lens` items in the batches
mask = torch.zeros_like(outs, dtype=torch.bool)
for i, l in enumerate(len_batch):
mask[i, :l, :] = 1
outs_masked = torch.masked_select(outs, mask)
y_masked = torch.masked_select(y, mask)
return F.l1_loss(outs_masked, y_masked)
# Metrics that can depend on the loss
def loss(x, y, loss_fn):
""" This metric may be useful because the training loss may add extra regularization (e.g. weight decay implemented as L2 penalty), while adding this as a metric skips the additional losses """
return loss_fn(x, y)
def bpb(x, y, loss_fn):
""" bits per byte (image density estimation, speech generation, char LM) """
return loss_fn(x, y) / math.log(2)
def ppl(x, y, loss_fn):
return torch.exp(loss_fn(x, y))
# should have a better way to do this
output_metric_fns = {
"binary_cross_entropy": binary_cross_entropy,
"cross_entropy": cross_entropy,
"padded_cross_entropy": padded_cross_entropy,
"binary_accuracy": binary_accuracy,
# "precision": MulticlassPrecision,
# "precision_species": partial(MulticlassPrecision, task='multiclass', average=None),
"precision_species": partial(MulticlassPrecision, average=None),
# "recall_species": partial(MulticlassRecall, task='multiclass', average=None),
"recall_species": partial(MulticlassRecall, average=None),
# "precision_class": partial(MulticlassPrecision, average=None),
"precision_per_class": PrecisionPerClass,
"recall": MulticlassRecall,
"recall_per_class": RecallPerClass,
"accuracy": accuracy,
"accuracy_per_class": AccuracyPerClass,
"accuracy_ignore_index": accuracy_ignore_index,
'accuracy@3': partial(accuracy_at_k, k=3),
'accuracy@5': partial(accuracy_at_k, k=5),
'accuracy@10': partial(accuracy_at_k, k=10),
"eval_loss": loss,
"mcc": mcc,
"mse": mse,
"mae": mae,
"forecast_rmse": forecast_rmse,
"f1_binary": f1_binary,
"f1_macro": f1_macro,
"f1_micro": f1_micro,
"roc_auc_macro": roc_auc_macro,
"roc_auc_micro": roc_auc_micro,
"soft_cross_entropy": soft_cross_entropy, # only for pytorch 1.10+
"student_t": student_t_loss,
"gaussian_ll": gaussian_ll_loss,
}
loss_metric_fns = {
"loss": loss,
"bpb": bpb,
"ppl": ppl,
}
metric_fns = {**output_metric_fns, **loss_metric_fns} # TODO py3.9
================================================
FILE: src/tasks/tasks.py
================================================
import inspect
from typing import List
import torch.nn as nn
from einops import rearrange
import src.models.nn.utils as U
import src.tasks.metrics as M
import torchmetrics as tm
from src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax
from src.tasks.torchmetrics import torchmetric_fns as tm_mine
from src.utils.config import to_list, instantiate
from torchmetrics import MetricCollection
class BaseTask:
""" Abstract class that takes care of:
- loss function
- arbitrary metrics
- forward pass
- (optional) encoder module that interfaces with dataset (inputs) and model
- (optional) decoder module that interfaces with dataset (targets) and model
"""
encoder = None
decoder = None
def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None):
""" This class is allowed to grab attributes directly off a constructed dataset and model object """
self.dataset = dataset
self.model = model
if metrics is None:
metrics = []
self.metric_names = to_list(metrics)
if torchmetrics is None:
torchmetrics = []
self.torchmetric_names = to_list(torchmetrics)
self._tracked_torchmetrics = {}
# The decoder might pass through arguments that the loss needs (e.g. sequence lengths)
# but might also pass through extraneous arguments (e.g. sampling rate)
        # Wrap loss and metrics so that they accept kwargs and discard extraneous ones
# Create loss function
self.loss = instantiate(M.output_metric_fns, loss, partial=True)
self.loss = U.discard_kwargs(self.loss)
if loss_val is not None:
self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True)
self.loss_val = U.discard_kwargs(self.loss_val)
torchmetrics = MetricCollection(self._init_torchmetrics())
self.train_torchmetrics = torchmetrics.clone(prefix='train/')
self.val_torchmetrics = torchmetrics.clone(prefix='val/')
self.test_torchmetrics = torchmetrics.clone(prefix='test/')
def _init_torchmetrics(self):
"""
Instantiate torchmetrics.
"""
tracked_torchmetrics = {}
for name in self.torchmetric_names:
if name in tm_mine:
tracked_torchmetrics[name] = tm_mine[name]()
elif name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']:
tracked_torchmetrics[name] = getattr(tm, name)(
average='macro', num_classes=self.dataset.d_output, compute_on_step=False
)
elif name in ['MultilabelAUROC', 'MultilabelAveragePrecision']:
tracked_torchmetrics[name] = getattr(tm, name)(
average='macro', num_labels=self.dataset.d_output
)
elif '@' in name:
k = int(name.split('@')[1])
mname = name.split('@')[0]
tracked_torchmetrics[name] = getattr(tm, mname)(
average='macro', num_classes=self.dataset.d_output, compute_on_step=False, top_k=k
)
else:
tracked_torchmetrics[name] = getattr(tm, name)(compute_on_step=False)
return tracked_torchmetrics
def _reset_torchmetrics(self, prefix=None):
"""
Reset torchmetrics for a prefix
associated with a particular dataloader (e.g. train, val, test).
Generally do this at the start of an epoch.
"""
all_prefixes = [prefix] if prefix is not None else self._tracked_torchmetrics
for prefix in all_prefixes:
if prefix in self._tracked_torchmetrics:
self._tracked_torchmetrics[prefix].reset()
def get_torchmetrics(self, prefix):
"""
Compute torchmetrics for a prefix associated with
a particular dataloader (e.g. train, val, test).
Generally do this at the end of an epoch.
"""
return {name: self._tracked_torchmetrics[prefix][name].compute() for name in self.torchmetric_names}
def torchmetrics(self, x, y, prefix, loss=None):
"""
Update torchmetrics with new x, y .
Prefix corresponds to a particular dataloader (e.g. train, val, test).
Generally call this every batch.
"""
if prefix not in self._tracked_torchmetrics:
self._init_torchmetrics(prefix)
self._tracked_torchmetrics[prefix](x, y, loss=loss)
# for name in self.torchmetric_names:
# if name.startswith('Accuracy'):
# if len(x.shape) > 2:
# # Multi-dimensional, multi-class
# self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze())
# continue
# self._tracked_torchmetrics[prefix][name].update(x, y)
def get_torchmetrics(self, prefix):
return self._tracked_torchmetrics[prefix]
def metrics(self, x, y, **kwargs):
"""
Metrics are just functions
output metrics are a function of output and target
loss metrics are a function of loss (e.g. perplexity)
"""
output_metrics = {
name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)
for name in self.metric_names if name in M.output_metric_fns
}
loss_metrics = {
name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)
for name in self.metric_names if name in M.loss_metric_fns
}
return {**output_metrics, **loss_metrics}
def forward(self, batch, encoder, model, decoder, _state):
"""Passes a batch through the encoder, backbone, and decoder"""
# z holds arguments such as sequence length
x, y, *z = batch # z holds extra dataloader info such as resolution
if len(z) == 0:
z = {}
else:
assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments"
z = z[0]
        # w can hold model-specific constructions, such as key_padding_mask for transformers or state for RNNs
x, w = encoder(x, **z)
x, state = model(x, **w, state=_state)
self._state = state
x, w = decoder(x, state=state, **z)
return x, y, w
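# Expected batch structure for BaseTask.forward (illustrative): either
#     batch = (x, y)
# or
#     batch = (x, y, {"lengths": lengths, ...})
# where the optional trailing dict is forwarded as keyword arguments to the encoder and
# decoder (e.g. sequence lengths for ragged decoding).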
class Scalar(nn.Module):
def __init__(self, c=1):
super().__init__()
self.c = c
def forward(self, x):
return x * self.c
class LMTask(BaseTask):
def forward(self, batch, encoder, model, decoder, _state):
"""Passes a batch through the encoder, backbone, and decoder"""
# z holds arguments such as sequence length
x, y, *z = batch # z holds extra dataloader info such as resolution
if len(z) == 0:
z = {}
else:
assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments"
z = z[0]
        # w can hold model-specific constructions such as key_padding_mask for transformers or state for RNNs
x, w = encoder(x, **z)
# Needed for Mamba (open-source repo version)
if "state" in inspect.signature(model.forward).parameters.keys():
x, state = model(x, **w, state=_state)
else:
x = model(x, **w)
state = None
self._state = state
x, w = decoder(x, state=state, **z)
if hasattr(x, 'logits'):
x = x.logits
x = rearrange(x, '... C -> (...) C')
y = rearrange(y, '... -> (...)')
return x, y, w
class MultiClass(BaseTask):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.continual_metrics = {}
for name in self.metric_names:
if name.endswith('_per_class'):
for spec_idx, spec in enumerate(self.dataset.species):
self.continual_metrics[name + '_' + spec] = M.output_metric_fns[name](spec_idx)
elif name in ['precision_species', 'recall_species']:
self.continual_metrics[name] = M.output_metric_fns[name](num_classes=len(self.dataset.species))
def metrics(self, x, y, **kwargs):
output_metrics = {}
for name in self.metric_names:
if name in M.output_metric_fns:
if name.endswith('_per_class'):
for spec_idx, spec in enumerate(self.dataset.species):
self.continual_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].to(
x.device)
self.continual_metrics[name + '_' + spec].update(x, y)
output_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].compute()
elif name in ['precision_species', 'recall_species']:
self.continual_metrics[name] = self.continual_metrics[name].to(x.device)
metrics = self.continual_metrics[name](x, y)
for spec_idx, spec in enumerate(self.dataset.species):
output_metrics[name[:-7] + spec] = metrics[spec_idx]
else:
output_metrics[name] = U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)
loss_metrics = {
name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)
for name in self.metric_names if name in M.loss_metric_fns
}
return {**output_metrics, **loss_metrics}
def _reset_torchmetrics(self, prefix=None):
super()._reset_torchmetrics(prefix)
for name in self.metric_names:
if name.endswith('_per_class'):
for spec_idx, spec in enumerate(self.dataset.species):
self.continual_metrics[name + '_' + spec].reset()
class HG38Task(LMTask):
def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None,
last_k_ppl=None, per_token_ppl=None):
""" Extending LMTask to add custom metrics for HG38 task
last_k_ppl: config for custom ppl, with hparams to pass with it
per_token_ppl: config for per token ppl calc, with list of k (ppls) to track
"""
self.dataset = dataset
self.model = model
if metrics is None:
metrics = []
self.metric_names = to_list(metrics)
self.last_k_ppl = last_k_ppl
self.per_token_ppl = per_token_ppl
if torchmetrics is None:
torchmetrics = []
self.torchmetric_names = to_list(torchmetrics)
self._tracked_torchmetrics = {}
# The decoder might pass through arguments that the loss needs (e.g. sequence lengths)
# but might also pass through extraneous arguments (e.g. sampling rate)
        # Wrap loss and metrics so that they accept kwargs and discard extraneous ones
# Create loss function
self.loss = instantiate(M.output_metric_fns, loss, partial=True)
self.loss = U.discard_kwargs(self.loss)
if loss_val is not None:
self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True)
self.loss_val = U.discard_kwargs(self.loss_val)
torchmetrics = MetricCollection(self._init_torchmetrics())
self.train_torchmetrics = torchmetrics.clone(prefix='train/')
self.val_torchmetrics = torchmetrics.clone(prefix='val/')
self.test_torchmetrics = torchmetrics.clone(prefix='test/')
# Create custom metrics for last k ppl
# last_k_ppl is a list of dicts (configs), so loop through them
if self.last_k_ppl is not None:
self.custom_ppl_dict = {}
for k in self.last_k_ppl:
key_name = "last_" + str(k) + "_ppl"
# create config
custom_ppl_config = {"_name_": "last_k_ppl", "k": k, "seq_len": self.dataset.max_length}
k_ppl_fn = instantiate(M.output_metric_fns, custom_ppl_config, partial=True)
k_ppl_fn = U.discard_kwargs(k_ppl_fn)
self.custom_ppl_dict[key_name] = k_ppl_fn
# Create custom metric for per token ppl
if self.per_token_ppl is not None:
per_token_ppl_config = {"_name_": "per_token_ppl", "ks": self.per_token_ppl["ks"],
"seq_len": self.dataset.max_length}
per_token_fn = instantiate(M.output_metric_fns, per_token_ppl_config, partial=True)
per_token_fn = U.discard_kwargs(per_token_fn)
self.per_token_fn = per_token_fn
def metrics(self, x, y, **kwargs):
"""
Need to modify metrics to include custom metrics
"""
output_metrics = {
name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)
for name in self.metric_names if name in M.output_metric_fns
}
loss_metrics = {
name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)
for name in self.metric_names if name in M.loss_metric_fns
}
# loop through all custom ppls and add them to output_metrics
if self.last_k_ppl is not None:
for key_name, k_ppl_fn in self.custom_ppl_dict.items():
output_metrics[key_name] = k_ppl_fn(x, y, **kwargs)
        # add per-token ppls to output_metrics
if self.per_token_ppl is not None:
# returns k ppl values, (averaged over batch)
per_k_ppl = self.per_token_fn(x, y, **kwargs)
# loop over ks to log metric
for ind, k in enumerate(self.per_token_ppl["ks"]):
key_name = "ppl_at_{}".format(k)
output_metrics[key_name] = per_k_ppl[ind] # should be in order
return {**output_metrics, **loss_metrics}
class AdaptiveLMTask(BaseTask):
def __init__(
self,
div_val,
cutoffs: List[int],
tie_weights: bool,
tie_projs: List[bool],
init_scale=1.0,
bias_scale=0.0,
dropemb=0.0,
dropsoft=0.0,
**kwargs,
):
super().__init__(**kwargs)
n_tokens = self.dataset.n_tokens
d_model = self.model.d_model
d_output = self.model.d_output
encoder = AdaptiveEmbedding(
n_tokens,
d_model,
d_model,
cutoffs=cutoffs,
div_val=div_val,
init_scale=init_scale,
dropout=dropemb,
)
if tie_weights:
assert d_model == d_output
emb_layers = [i.weight for i in encoder.emb_layers]
else:
emb_layers = None
# Construct decoder/loss
emb_projs = encoder.emb_projs
loss = ProjectedAdaptiveLogSoftmax(
n_tokens, d_output, d_output,
cutoffs, div_val=div_val,
tie_projs=tie_projs,
out_projs=emb_projs,
out_layers_weights=emb_layers,
bias_scale=bias_scale,
dropout=dropsoft,
)
self.encoder = encoder
self.loss = loss
registry = {
'base': BaseTask,
'multiclass': MultiClass,
'lm': LMTask,
'hg38': HG38Task,
}
================================================
FILE: src/tasks/torchmetrics.py
================================================
# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation
from typing import Any, Dict, Optional
import torch
from torch import Tensor
from torchmetrics import Metric
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
CrossEntropyLoss = torch.nn.CrossEntropyLoss
try:
from apex.transformer import parallel_state
except ImportError:
parallel_state = None
class Perplexity(Metric):
r"""
    Perplexity measures how well a language model predicts a text sample. It is calculated as the exponential
    of the average negative log-likelihood per token, i.e. exp(average(nll)).
Args:
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Examples:
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> metric = Perplexity(ignore_index=-100)
>>> metric(preds, target)
tensor(5.2545)
"""
is_differentiable = True
higher_is_better = False
full_state_update = False
total_log_probs: Tensor
count: Tensor
def __init__(self, **kwargs: Dict[str, Any]):
super().__init__(**kwargs)
self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
dist_reduce_fx="sum")
self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
self.loss_fn = CrossEntropyLoss()
def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore
"""Compute and store intermediate statistics for Perplexity.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
"""
count = target.numel()
if loss is None:
loss = self.loss_fn(preds, target)
self.total_log_probs += loss.double() * count
self.count += count
def compute(self) -> Tensor:
"""Compute the Perplexity.
Returns:
Perplexity
"""
return torch.exp(self.total_log_probs / self.count)
class NumTokens(Metric):
"""Keep track of how many tokens we've seen.
"""
# TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch
# of the next epoch.
# Right now the hack is that we override reset(), which would mess up the forward method.
# We then override forward to do the right thing.
is_differentiable = False
higher_is_better = False
full_state_update = False
count: Tensor
def __init__(self, **kwargs: Dict[str, Any]):
super().__init__(**kwargs)
self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum",
persistent=True) # We want the count to be saved to state-dict
if parallel_state is not None and not parallel_state.is_unitialized():
self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size()
else:
self.tensor_parallel_world_size = 1
def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore
self.count += target.numel() // self.tensor_parallel_world_size
def compute(self) -> Tensor:
return self.count
def reset(self):
count = self.count
super().reset()
self.count = count
# Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""forward computation using single call to `update` to calculate the metric value on the current batch and
accumulate global state.
This can be done when the global metric state is a sinple reduction of batch states.
"""
self.update(*args, **kwargs)
return self.compute()
torchmetric_fns = {
"perplexity": Perplexity,
"num_tokens": NumTokens,
}
================================================
FILE: src/utils/__init__.py
================================================
from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate
================================================
FILE: src/utils/config.py
================================================
"""Utilities for dealing with collection objects (lists, dicts) and configs.
"""
import functools
from typing import Sequence, Mapping, Callable
import hydra
from omegaconf import ListConfig, DictConfig
# TODO this is usually used in a pattern where it's turned into a list, so can just do that here
def is_list(x):
return isinstance(x, Sequence) and not isinstance(x, str)
def is_dict(x):
return isinstance(x, Mapping)
def to_dict(x, recursive=True):
"""Convert Sequence or Mapping object to dict
lists get converted to {0: x[0], 1: x[1], ...}
"""
if is_list(x):
x = {i: v for i, v in enumerate(x)}
if is_dict(x):
if recursive:
return {k: to_dict(v, recursive=recursive) for k, v in x.items()}
else:
return dict(x)
else:
return x
def to_list(x, recursive=False):
"""Convert an object to list.
    If Sequence (e.g. list, tuple, ListConfig): just return it as a list
Special case: If non-recursive and not a list, wrap in list
"""
if is_list(x):
if recursive:
return [to_list(_x) for _x in x]
else:
return list(x)
else:
if recursive:
return x
else:
return [x]
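# Behaviour sketch (doctest-style, illustrative):
#     >>> to_list(3)
#     [3]
#     >>> to_list((1, 2))
#     [1, 2]
#     >>> to_dict(["a", "b"], recursive=False)
#     {0: 'a', 1: 'b'}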
def extract_attrs_from_obj(obj, *attrs):
if obj is None:
assert len(attrs) == 0
return []
return [getattr(obj, attr, None) for attr in attrs]
def auto_assign_attrs(cls, **kwargs):
for k, v in kwargs.items():
setattr(cls, k, v)
def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs):
"""
registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'})
config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor
wrap: wrap the target class (e.g. ema optimizer or tasks.wrap)
*args, **kwargs: additional arguments to override the config to pass into the target constructor
"""
# Case 1: no config
if config is None:
return None
# Case 2a: string means _name_ was overloaded
if isinstance(config, str):
_name_ = None
_target_ = registry[config]
config = {}
# Case 2b: grab the desired callable from name
else:
_name_ = config.pop("_name_")
_target_ = registry[_name_]
# Retrieve the right constructor automatically based on type
if isinstance(_target_, str):
fn = hydra.utils.get_method(path=_target_)
elif isinstance(_target_, Callable):
fn = _target_
else:
raise NotImplementedError("instantiate target must be string or callable")
# Instantiate object
if wrap is not None:
fn = wrap(fn)
obj = functools.partial(fn, *args, **config, **kwargs)
# Restore _name_
if _name_ is not None:
config["_name_"] = _name_
if partial:
return obj
else:
return obj()
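# Usage sketch for instantiate() (hypothetical registry/config; assumes torch.nn is imported):
#     registry = {"linear": torch.nn.Linear}
#     cfg = {"_name_": "linear", "out_features": 8}
#     layer = instantiate(registry, cfg, 16)                                # nn.Linear(16, out_features=8)
#     layer_fn = instantiate(registry, "linear", 16, out_features=8, partial=True)
#     layer = layer_fn()                                                    # deferred construction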
def get_class(registry, _name_):
return hydra.utils.get_class(path=registry[_name_])
def omegaconf_filter_keys(d, fn=None):
"""Only keep keys where fn(key) is True. Support nested DictConfig.
# TODO can make this inplace?
"""
if fn is None:
fn = lambda _: True
if is_list(d):
return ListConfig([omegaconf_filter_keys(v, fn) for v in d])
elif is_dict(d):
return DictConfig(
{k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)}
)
else:
return d
================================================
FILE: src/utils/optim/schedulers.py
================================================
"""Custom learning rate schedulers"""
import math
import warnings
import torch
from timm.scheduler import CosineLRScheduler
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html
class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR):
def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs):
self.warmup_step = warmup_step
        super().__init__(optimizer, T_max - warmup_step, eta_min, **kwargs)
# Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to
# self.last_epoch - self.warmup_step.
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if self.last_epoch == self.warmup_step: # also covers the case where both are 0
return self.base_lrs
elif self.last_epoch < self.warmup_step:
return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs]
elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0:
return [group['lr'] + (base_lr - self.eta_min) *
(1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)]
return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) /
(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) *
(group['lr'] - self.eta_min) + self.eta_min
for group in self.optimizer.param_groups]
_get_closed_form_lr = None
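# Usage sketch (hypothetical optimizer and hyperparameters): with
#     sched = CosineWarmup(optimizer, T_max=1000, warmup_step=100)
# the learning rate ramps up linearly for the first 100 steps and then follows a cosine decay
# over the remaining 900 steps (T_max counts warmup plus decay steps).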
def InvSqrt(optimizer, warmup_step):
""" Originally used for Transformer (in Attention is all you need)
"""
def lr_lambda(step):
# return a multiplier instead of a learning rate
if step == warmup_step: # also covers the case where both are 0
return 1.
else:
return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
def Constant(optimizer, warmup_step):
def lr_lambda(step):
if step == warmup_step: # also covers the case where both are 0
return 1.
else:
return 1. if step > warmup_step else (step + 1) / warmup_step
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
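# Illustrative sketch (not part of the original file): `InvSqrt` and `Constant` wrap the warmup
# logic in a LambdaLR whose lambda returns a *multiplier* on the base learning rate.
# A hypothetical training loop:
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=8e-3)
#   scheduler = InvSqrt(optimizer, warmup_step=1000)  # linear warmup, then lr ~ 1/sqrt(step)
#   for step in range(num_steps):
#       optimizer.step()
#       scheduler.step()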
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
""" Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch.
It supports resuming as well.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._last_epoch = -1
self.step(epoch=0)
def step(self, epoch=None):
if epoch is None:
self._last_epoch += 1
else:
self._last_epoch = epoch
# We call either step or step_update, depending on whether we're using the scheduler every
# epoch or every step.
# Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set
# scheduler interval to "step", then the learning rate update will be wrong.
if self.t_in_epochs:
super().step(epoch=self._last_epoch)
else:
super().step_update(num_updates=self._last_epoch)
================================================
FILE: src/utils/optim_groups.py
================================================
"""Utilities for special optimizer hyperparameters.
group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused
add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary
"""
import inspect
import torch.nn as nn
import hydra
def add_optimizer_hooks(
model,
bias_weight_decay=False,
normalization_weight_decay=False,
):
"""Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with
attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for
normalization parameters if normalization_weight_decay==False
"""
# Separate out all parameters to those that will and won't experience regularizing weight decay
blacklist_weight_modules = (nn.Embedding, )
if not normalization_weight_decay:
blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
# Not compatible with Pytorch 1.8.1
# nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
nn.GroupNorm, nn.SyncBatchNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.LayerNorm, nn.LocalResponseNorm)
for mn, m in model.named_modules():
for pn, p in m.named_parameters():
if (not bias_weight_decay and pn.endswith('bias')) \
or getattr(p, '_no_weight_decay', False) \
or isinstance(m, blacklist_weight_modules):
setattr(p, "_optim", {"weight_decay": 0.0})
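# Illustrative sketch (not part of the original file): after calling `add_optimizer_hooks(model)`,
# the tagged parameters carry `p._optim == {"weight_decay": 0.0}`, and the caller can split them
# into separate optimizer param groups, roughly as configure_optimizers in train.py does:
#
#   add_optimizer_hooks(model)
#   plain = [p for p in model.parameters() if not hasattr(p, "_optim")]
#   tagged = [p for p in model.parameters() if hasattr(p, "_optim")]
#   optimizer = torch.optim.AdamW(plain, lr=8e-3, weight_decay=0.1)
#   optimizer.add_param_group({"params": tagged, "weight_decay": 0.0})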
def group_parameters_for_optimizer(
model,
optimizer_cfg,
bias_weight_decay=False,
normalization_weight_decay=False,
):
"""Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with
attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for
normalization parameters if normalization_weight_decay==False
"""
# Get the weight decay from the config, or from the default value of the optimizer constructor
# if it's not specified in the config.
if 'weight_decay' in optimizer_cfg:
weight_decay = optimizer_cfg.weight_decay
else:
# https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value
signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_))
if 'weight_decay' in signature.parameters:
weight_decay = signature.parameters['weight_decay'].default
if weight_decay is inspect.Parameter.empty:
weight_decay = 0.0
else:
weight_decay = 0.0
# If none of the parameters have weight decay anyway, and there are no parameters with special
# optimization params, just return all parameters without any grouping
if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):
return model.parameters()
skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')
else set())
# Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
special = set()
whitelist_weight_modules = (nn.Linear, )
blacklist_weight_modules = (nn.Embedding, )
if not normalization_weight_decay:
blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
# Not compatible with Pytorch 1.8.1
# nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
nn.GroupNorm, nn.SyncBatchNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.LayerNorm, nn.LocalResponseNorm)
for mn, m in model.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if not p.requires_grad:
continue # frozen weights
if hasattr(p, '_optim'):
special.add(fpn)
elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):
no_decay.add(fpn)
elif getattr(p, '_no_weight_decay', False):
no_decay.add(fpn)
elif not bias_weight_decay and pn.endswith('bias'):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
# special case the position embedding parameter in the root GPT module as not decayed
if 'pos_emb' in param_dict:
no_decay.add('pos_emb')
# In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys()
decay &= param_dict.keys()
decay |= (param_dict.keys() - no_decay - special)
# validate that we considered every parameter
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!"
assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - special - union_params)} were not separated into either decay/no_decay set!"
if weight_decay == 0.0 or not no_decay:
param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))],
"weight_decay": weight_decay}]
else:
param_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
# Add parameters with special hyperparameters
# Unique dicts
hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]
for hp in hps:
params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]
param_groups.append({"params": params, **hp})
return param_groups
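# Illustrative sketch (not part of the original file; `optimizer_cfg` is a hypothetical config
# with `lr` and `weight_decay` fields): the returned groups plug directly into a torch optimizer.
#
#   param_groups = group_parameters_for_optimizer(model, optimizer_cfg)
#   optimizer = torch.optim.AdamW(param_groups, lr=optimizer_cfg.lr,
#                                 weight_decay=optimizer_cfg.weight_decay)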
================================================
FILE: src/utils/registry.py
================================================
"""Class registry for models, layers, optimizers, and schedulers.
"""
optimizer = {
"adam": "torch.optim.Adam",
"adamw": "torch.optim.AdamW",
"rmsprop": "torch.optim.RMSprop",
"sgd": "torch.optim.SGD",
"lamb": "src.utils.optim.lamb.JITLamb",
}
scheduler = {
"constant": "transformers.get_constant_schedule",
"plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau",
"step": "torch.optim.lr_scheduler.StepLR",
"multistep": "torch.optim.lr_scheduler.MultiStepLR",
"cosine": "torch.optim.lr_scheduler.CosineAnnealingLR",
"constant_warmup": "transformers.get_constant_schedule_with_warmup",
"linear_warmup": "transformers.get_linear_schedule_with_warmup",
"cosine_warmup": "transformers.get_cosine_schedule_with_warmup",
"cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler",
}
model = {
# Pre-training LM head models
"hyena_lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel",
"mamba_lm": "mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel",
"caduceus_lm": "caduceus.modeling_caduceus.CaduceusForMaskedLM",
# Downstream task embedding backbones
"dna_embedding": "src.models.sequence.dna_embedding.DNAEmbeddingModel",
"dna_embedding_mamba": "src.models.sequence.dna_embedding.DNAEmbeddingModelMamba",
"dna_embedding_caduceus": "src.models.sequence.dna_embedding.DNAEmbeddingModelCaduceus",
# Baseline for genomics benchmark
"genomics_benchmark_cnn": "src.models.baseline.genomics_benchmark_cnn.GenomicsBenchmarkCNN",
}
layer = {
"id": "src.models.sequence.base.SequenceIdentity",
"ff": "src.models.sequence.ff.FF",
"hyena": "src.models.sequence.hyena.HyenaOperator",
"hyena-filter": "src.models.sequence.hyena.HyenaFilter",
}
callbacks = {
"learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor",
"model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint",
"model_checkpoint_every_n_steps": "pytorch_lightning.callbacks.ModelCheckpoint",
"model_checkpoint_every_epoch": "pytorch_lightning.callbacks.ModelCheckpoint",
"early_stopping": "pytorch_lightning.callbacks.EarlyStopping",
"swa": "pytorch_lightning.callbacks.StochasticWeightAveraging",
"rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary",
"rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar",
"params": "src.callbacks.params.ParamsLog",
"timer": "src.callbacks.timer.Timer",
"val_every_n_global_steps": "src.callbacks.validation.ValEveryNGlobalSteps",
}
model_state_hook = {
'load_backbone': 'src.models.sequence.dna_embedding.load_backbone',
}
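# Illustrative sketch (not part of the original file; the exact import of `instantiate` and the
# scheduler kwargs are assumptions): configs refer to these entries via the `_name_` key, and the
# registered import path is resolved at instantiation time, e.g.
#
#   from src.utils import registry, instantiate
#   scheduler = instantiate(
#       registry.scheduler,
#       {"_name_": "cosine_warmup_timm", "t_initial": 10_000, "warmup_t": 1_000},
#       optimizer,
#   )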
================================================
FILE: src/utils/train.py
================================================
""" Utils for the training loop.
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
"""
import json
import logging
import warnings
import rich.syntax
import rich.tree
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from src.utils.config import omegaconf_filter_keys
# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
def __init__(self, logger, level=None, handler=None, close=True):
self.logger = logger
self.level = level
self.handler = handler
self.close = close
def __enter__(self):
if self.level is not None:
self.old_level = self.logger.level
self.logger.setLevel(self.level)
if self.handler:
self.logger.addHandler(self.handler)
def __exit__(self, et, ev, tb):
if self.level is not None:
self.logger.setLevel(self.old_level)
if self.handler:
self.logger.removeHandler(self.handler)
if self.handler and self.close:
self.handler.close()
# implicit return of None => don't swallow exceptions
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
"""Initializes multi-GPU-friendly python logger.x"""
logger = logging.getLogger(name)
logger.setLevel(level)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger
def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place
"""A couple of optional utilities, controlled by main config file:
- disabling warnings
- easier access to debug mode
- forcing debug friendly configuration
Modifies DictConfig in place.
Args:
config (DictConfig): Configuration composed by Hydra.
"""
log = get_logger()
# Filter out keys that were used just for interpolation
config = omegaconf_filter_keys(config, lambda k: not k.startswith('__'))
# enable adding new keys to config
OmegaConf.set_struct(config, False)
# disable python warnings if requested in the config (ignore_warnings=True)
if config.get("ignore_warnings"):
log.info("Disabling python warnings! ")
warnings.filterwarnings("ignore")
if config.get("debug"):
log.info("Running in debug mode! ")
config.trainer.fast_dev_run = True
# force debugger friendly configuration
log.info("Forcing debugger friendly configuration! ")
# Debuggers don't like GPUs or multiprocessing
if config.trainer.get("gpus"):
config.trainer.gpus = 0
if config.loader.get("pin_memory"):
config.loader.pin_memory = False
if config.loader.get("num_workers"):
config.loader.num_workers = 0
# disable adding new keys to config
# OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things
return config
@rank_zero_only
def print_config(
config: DictConfig,
resolve: bool = True,
save_cfg=True,
) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
Args:
config (DictConfig): Configuration composed by Hydra.
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
save_cfg (bool, optional): Whether to save the config to a file.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
fields = config.keys()
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, DictConfig):
branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
rich.print(tree)
if save_cfg:
with open("config_tree.txt", "w") as fp:
rich.print(tree, file=fp)
with open("model_config.json", "w") as fp: # Save config / model config for use in fine-tuning or testing
model_config = {
k: v
for k, v in OmegaConf.to_container(config.model, resolve=True).items()
if not k.startswith("_") or k == "config_path"
}
json.dump(model_config, fp, indent=4)
with open("config.json", "w") as fp:
json.dump(OmegaConf.to_container(config, resolve=True), fp, indent=4)
def log_optimizer(logger, optimizer, keys):
""" Log values of particular keys from the optimizers param groups """
keys = sorted(keys)
for i, g in enumerate(optimizer.param_groups):
group_hps = {k: g.get(k, None) for k in keys}
logger.info(' | '.join([
f"Optimizer group {i}",
f"{len(g['params'])} tensors",
] + [f"{k} {v}" for k, v in group_hps.items()]))
class OptimModule(nn.Module):
""" Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """
def register(self, name, tensor, lr=None, wd=0.0):
"""Register a tensor with a configurable learning rate and 0 weight decay"""
if lr == 0.0:
self.register_buffer(name, tensor)
else:
self.register_parameter(name, nn.Parameter(tensor))
optim = {}
if lr is not None:
optim["lr"] = lr
if wd is not None:
optim["weight_decay"] = wd
setattr(getattr(self, name), "_optim", optim)
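# Illustrative sketch (not part of the original file; `MyKernel` is hypothetical): a submodule
# can register a tensor with its own learning rate / weight decay, which configure_optimizers in
# train.py later reads from the `_optim` attribute:
#
#   class MyKernel(OptimModule):
#       def __init__(self, d_model):
#           super().__init__()
#           # trained with a smaller lr and no weight decay
#           self.register("log_dt", torch.zeros(d_model), lr=1e-3, wd=0.0)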
================================================
FILE: train.py
================================================
"""Main training entry point for pre-training and downstream fine-tuning.
"""
import json
import os
import random
import time
from functools import wraps
from typing import Callable, List, Sequence
import fsspec
import hydra
import pytorch_lightning as pl
import torch
import wandb
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
import src.models.nn.utils as U
import src.utils as utils
import src.utils.train
from src.dataloaders import SequenceDataset # TODO make registry
from src.tasks import decoders, encoders, tasks
from src.utils import registry
from src.utils.optim_groups import add_optimizer_hooks
log = src.utils.train.get_logger(__name__)
# Turn on TensorFloat32 (speeds up large model training substantially)
import torch.backends
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
OmegaConf.register_new_resolver('eval', eval)
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
OmegaConf.register_new_resolver('min', lambda x, y: min([x, y]))
# Lots of annoying hacks to get WandbLogger to continuously retry on failure
class DummyExperiment:
"""Dummy experiment."""
def nop(self, *args, **kw):
pass
def __getattr__(self, _):
return self.nop
def __getitem__(self, idx) -> "DummyExperiment":
# enables self.logger.experiment[0].add_image(...)
return self
def __setitem__(self, *args, **kwargs) -> None:
pass
def rank_zero_experiment(fn: Callable) -> Callable:
"""Returns the real experiment on rank 0 and otherwise the DummyExperiment."""
@wraps(fn)
def experiment(self):
@rank_zero_only
def get_experiment():
return fn(self)
return get_experiment() or DummyExperiment()
return experiment
class CustomWandbLogger(WandbLogger):
def __init__(self, *args, **kwargs):
"""Modified logger that insists on a wandb.init() call and catches wandb's error if thrown."""
super().__init__(*args, **kwargs)
@property
@rank_zero_experiment
def experiment(self):
r"""
Actual wandb object. To use wandb features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Example::
.. code-block:: python
self.logger.experiment.some_wandb_function()
"""
if self._experiment is None:
if self._offline:
os.environ["WANDB_MODE"] = "dryrun"
attach_id = getattr(self, "_attach_id", None)
if wandb.run is not None:
# wandb process already created in this instance
rank_zero_warn(
"There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
" this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
)
self._experiment = wandb.run
elif attach_id is not None and hasattr(wandb, "_attach"):
# attach to wandb process referenced
self._experiment = wandb._attach(attach_id)
else:
# create new wandb process
while True:
try:
self._experiment = wandb.init(**self._wandb_init)
break
except Exception as e:
log.error("wandb Exception:\n", e)
t = random.randint(30, 60)
log.warning(f"Sleeping for {t} seconds")
time.sleep(t)
# define default x-axis
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
return self._experiment
class SequenceLightningModule(pl.LightningModule):
def __init__(self, config):
# Disable profiling executor. This reduces memory and increases speed.
try:
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
except AttributeError:
pass
super().__init__()
# Passing in config expands it one level: access by self.hparams.train instead of self.hparams.config.train
self.save_hyperparameters(config, logger=False)
# Dataset arguments
self.dataset = SequenceDataset.registry[self.hparams.dataset._name_](
**self.hparams.dataset
)
# Check hparams
self._check_config()
# PL has some bugs, so add hooks and make sure they're only called once
self._has_setup = False
# To be set in `setup`
self.encoder, self.decoder, self.model = None, None, None
self.task, self.loss, self.loss_val = None, None, None
self.metrics, self.train_torchmetrics, self.val_torchmetrics, self.test_torchmetrics = None, None, None, None
self.setup()
self._state = None
self.val_loader_names, self.test_loader_names = None, None
def setup(self, stage=None):
if not self.hparams.train.disable_dataset:
self.dataset.setup()
# We need to set up the model in setup() because for some reason when training with DDP, one GPU uses much more
# memory than the others.
# In order to not overwrite the model multiple times during different stages, we need this hack
# TODO PL 1.5 seems to have an option to skip hooks to avoid this
# https://github.com/PyTorchLightning/pytorch-lightning/issues/5410#issuecomment-762257024
if self._has_setup:
return
else:
self._has_setup = True
# Convenience feature: if model specifies encoder, combine it with main encoder
encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list(
self.hparams.model.pop("encoder", None)
)
decoder_cfg = utils.to_list(
self.hparams.model.pop("decoder", None)
) + utils.to_list(self.hparams.decoder)
# Instantiate model
config_path = self.hparams.model.pop("config_path", None)
if config_path is not None:
with open(config_path) as f:
model_config_from_file = json.load(f)
self.hparams.model.update(model_config_from_file)
# Check if dropout_layer_norm is compiled
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
if self.hparams.model.get("fused_dropout_add_ln", None) is not None:
self.hparams.model.update({"fused_dropout_add_ln": False})
# TODO: Hacky way to get complement_map for Caduceus models; need to find a more elegant implementation
if "caduceus" in self.hparams.model.get("_name_"):
OmegaConf.update(
self.hparams.model.config, "complement_map", self.dataset.tokenizer.complement_map, force_add=True
)
# Instantiate the config class if using hydra's _target_ paradigm for the config
if self.hparams.model.get("config", None) is not None and self.hparams.model.config.get("_target_", None) is not None:
model_hparams = OmegaConf.to_container(self.hparams.model, resolve=True)
model_hparams["config"] = hydra.utils.instantiate(model_hparams["config"])
self.model = utils.instantiate(registry.model, model_hparams)
else:
self.model = utils.instantiate(registry.model, self.hparams.model)
if (name := self.hparams.train.post_init_hook['_name_']) is not None:
kwargs = self.hparams.train.post_init_hook.copy()
del kwargs['_name_']
for module in self.modules():
if hasattr(module, name):
getattr(module, name)(**kwargs)
# if self.hparams.train.get("compile_model", False):
# self.model = torch.compile(self.model, dynamic=False)
# Instantiate the task
self.task = utils.instantiate(
tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model
)
# Create encoders and decoders
encoder = encoders.instantiate(
encoder_cfg, dataset=self.dataset, model=self.model
)
decoder = decoders.instantiate(
decoder_cfg, model=self.model, dataset=self.dataset
)
# Extract the modules, so they show up in the top level parameter count
self.encoder = U.PassthroughSequential(self.task.encoder, encoder)
self.decoder = U.PassthroughSequential(decoder, self.task.decoder)
self.loss = self.task.loss
self.loss_val = self.task.loss
if hasattr(self.task, 'loss_val'):
self.loss_val = self.task.loss_val
self.metrics = self.task.metrics
self.train_torchmetrics = self.task.train_torchmetrics
self.val_torchmetrics = self.task.val_torchmetrics
self.test_torchmetrics = self.task.test_torchmetrics
def load_state_dict(self, state_dict, strict=False):
if self.hparams.train.pretrained_model_state_hook['_name_'] is not None:
model_state_hook = utils.instantiate(
registry.model_state_hook,
self.hparams.train.pretrained_model_state_hook.copy(),
partial=True,
)
state_dict = model_state_hook(self.model, state_dict)
log.info("Custom load_state_dict function is running.")
# strict==True will require all modules to match
# strict==False can allow encoder/decoder to be loaded from scratch too
return super().load_state_dict(state_dict, strict=strict)
def _check_config(self):
assert self.hparams.train.state.mode in [None, "none", "null", "reset", "bptt", "tbptt"]
assert (
(n := self.hparams.train.state.n_context) is None
or isinstance(n, int)
and n >= 0
)
assert (
(n := self.hparams.train.state.n_context_eval) is None
or isinstance(n, int)
and n >= 0
)
def _initialize_state(self):
"""Called at model setup and start of epoch to completely reset state"""
self._state = None
self._memory_chunks = []
def _reset_state(self, batch, device=None):
"""Called to construct default_state when necessary, e.g. during BPTT"""
device = device or batch[0].device
self._state = self.model.default_state(*batch[0].shape[:1], device=device)
def _detach_state(self, state):
if isinstance(state, torch.Tensor):
return state.detach()
elif isinstance(state, tuple):
return tuple(self._detach_state(s) for s in state)
elif isinstance(state, list):
return [self._detach_state(s) for s in state]
elif isinstance(state, dict):
return {k: self._detach_state(v) for k, v in state.items()}
elif state is None:
return None
else:
raise NotImplementedError
def _process_state(self, batch, batch_idx, training=True):
"""Handle logic for state context."""
# Number of context steps
key = "n_context" if training else "n_context_eval"
n_context = self.hparams.train.state.get(key)
# Don't need to do anything if 0 context steps. Make sure there is no state
if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']:
self._initialize_state()
return
# Reset state if needed
if self.hparams.train.state.mode == "reset":
if batch_idx % (n_context + 1) == 0:
self._reset_state(batch)
# Pass through memory chunks
elif self.hparams.train.state.mode == "bptt":
self._reset_state(batch)
with torch.no_grad(): # should be unnecessary because individual modules should handle this
for _batch in self._memory_chunks:
self.forward(_batch)
# Prepare for next step
self._memory_chunks.append(batch)
self._memory_chunks = self._memory_chunks[-n_context:]
elif self.hparams.train.state.mode == 'tbptt':
_, _, z = batch
reset = z["reset"]
if reset:
self._reset_state(batch)
else:
self._state = self._detach_state(self._state)
def forward(self, batch):
return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state)
def step(self, x_t):
x_t, *_ = self.encoder(x_t) # Potential edge case for encoders that expect (B, L, H)?
x_t, state = self.model.step(x_t, state=self._state)
self._state = state
x_t, *_ = self.decoder.step(x_t, state=state)
return x_t
def _shared_step(self, batch, batch_idx, prefix="train"):
"""Shared step logic between training, validation, and test"""
self._process_state(batch, batch_idx, training=(prefix == "train"))
x, y, w = self.forward(batch)
# Loss
if prefix == 'train':
loss = self.loss(x, y, **w)
else:
loss = self.loss_val(x, y, **w)
# Metrics
metrics = self.metrics(x, y, **w)
metrics["loss"] = loss
metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
# Calculate torchmetrics
torchmetrics = getattr(self, f'{prefix}_torchmetrics')
torchmetrics(x, y, loss=loss)
log_on_step = 'eval' in self.hparams and self.hparams.eval.get('log_on_step', False) and prefix == 'train'
self.log_dict(
metrics,
on_step=log_on_step,
on_epoch=True,
prog_bar=True,
add_dataloader_idx=False,
sync_dist=True,
)
# log the whole dict, otherwise lightning takes the mean to reduce it
# https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
self.log_dict(
torchmetrics,
on_step=log_on_step,
on_epoch=True,
prog_bar=True,
add_dataloader_idx=False,
sync_dist=True,
)
return loss
def on_train_epoch_start(self):
# Reset training torchmetrics
self.task._reset_torchmetrics("train")
def training_epoch_end(self, outputs):
# Log training torchmetrics
super().training_epoch_end(outputs)
def on_validation_epoch_start(self):
# Reset all validation torchmetrics
for name in self.val_loader_names:
self.task._reset_torchmetrics(name)
def validation_epoch_end(self, outputs):
# Log all validation torchmetrics
super().validation_epoch_end(outputs)
def on_test_epoch_start(self):
# Reset all test torchmetrics
for name in self.test_loader_names:
self.task._reset_torchmetrics(name)
def test_epoch_end(self, outputs):
# Log all test torchmetrics
super().test_epoch_end(outputs)
def training_step(self, batch, batch_idx, dataloader_idx=0):
loss = self._shared_step(batch, batch_idx, prefix="train")
# Log the loss explicitly so that it shows up in WandB
# Note that this currently runs into a bug in the progress bar with ddp (as of 1.4.6)
# https://github.com/PyTorchLightning/pytorch-lightning/pull/9142
# We additionally log the epochs under 'trainer' to get a consistent prefix with 'global_step'
loss_epoch = {"trainer/loss": loss, "trainer/epoch": float(self.current_epoch)}
self.log_dict(
loss_epoch,
on_step=True,
on_epoch=False,
prog_bar=False,
add_dataloader_idx=False,
sync_dist=True,
)
# Log any extra info that the models want to expose (e.g. output norms)
metrics = {}
for module in list(self.modules())[1:]:
if hasattr(module, "metrics"):
metrics.update(module.metrics)
self.log_dict(
metrics,
on_step=True,
on_epoch=False,
prog_bar=False,
add_dataloader_idx=False,
sync_dist=True,
)
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
# There's a bit of an annoying edge case with the first (0-th) epoch; it has to be excluded due to the initial
# sanity check
ema = (
self.val_loader_names[dataloader_idx].endswith("/ema")
and self.optimizers().optimizer.stepped
)
if ema:
self.optimizers().swap_ema()
loss = self._shared_step(
batch, batch_idx, prefix=self.val_loader_names[dataloader_idx]
)
if ema:
self.optimizers().swap_ema()
return loss
def test_step(self, batch, batch_idx, dataloader_idx=0):
return self._shared_step(
batch, batch_idx, prefix=self.test_loader_names[dataloader_idx]
)
def configure_optimizers(self):
# Set zero weight decay for some params
if 'optimizer_param_grouping' in self.hparams.train:
add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping)
# Normal parameters
all_params = list(self.parameters())
params = [p for p in all_params if not hasattr(p, "_optim")]
optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params)
del self.hparams.optimizer._name_
# Add parameters with special hyperparameters
hps = [getattr(p, "_optim") for p in all_params if hasattr(p, "_optim")]
hps = [
# dict(s) for s in set(frozenset(hp.items()) for hp in hps)
dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))
# dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps)
] # Unique dicts
print("Hyperparameter groups:", hps) # TODO: log.info throws error because hps is list of dicts
for hp in hps:
params = [p for p in all_params if getattr(p, "_optim", None) == hp]
optimizer.add_param_group(
{"params": params, **self.hparams.optimizer, **hp}
)
# Layer Decay
if self.hparams.train.layer_decay['_name_'] is not None:
get_num_layer = utils.instantiate(
registry.layer_decay,
self.hparams.train.layer_decay['_name_'],
partial=True,
)
# Go through all parameters and get num layer
layer_wise_groups = {}
num_max_layers = 0
for name, p in self.named_parameters():
# Get layer id for each parameter in the model
layer_id = get_num_layer(name)
# Add to layer wise group
if layer_id not in layer_wise_groups:
layer_wise_groups[layer_id] = {
'params': [],
'lr': None,
'weight_decay': self.hparams.optimizer.weight_decay
}
layer_wise_groups[layer_id]['params'].append(p)
if layer_id > num_max_layers:
num_max_layers = layer_id
# Update lr for each layer
for layer_id, group in layer_wise_groups.items():
group['lr'] = self.hparams.optimizer.lr * (
self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id))
# Reset the torch optimizers param groups
optimizer.param_groups = []
for layer_id, group in layer_wise_groups.items():
optimizer.add_param_group(group)
# Print optimizer info for debugging
keys = set([k for hp in hps for k in hp.keys()]) # Special hparams
utils.train.log_optimizer(log, optimizer, keys)
# Configure scheduler
if "scheduler" not in self.hparams:
return optimizer
lr_scheduler = utils.instantiate(
registry.scheduler, self.hparams.scheduler, optimizer
)
scheduler = {
"scheduler": lr_scheduler,
"interval": self.hparams.train.interval, # 'epoch' or 'step'
"monitor": self.hparams.train.monitor,
"name": "trainer/lr", # default is e.g. 'lr-AdamW'
}
# See documentation for how to configure the return
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
return [optimizer], [scheduler]
def train_dataloader(self):
return self.dataset.train_dataloader(**self.hparams.loader)
def _eval_dataloaders_names(self, loaders, prefix):
"""Process loaders into a list of names and loaders"""
if utils.is_dict(loaders):
return [
f"{prefix}/{k}" if k is not None else prefix for k in loaders.keys()
], list(loaders.values())
elif utils.is_list(loaders):
return [f"{prefix}/{i}" for i in range(len(loaders))], loaders
else:
return [prefix], [loaders]
def _eval_dataloaders(self):
# Return all val + test loaders
val_loaders = self.dataset.val_dataloader(**self.hparams.loader)
test_loaders = self.dataset.test_dataloader(**self.hparams.loader)
val_loader_names, val_loaders = self._eval_dataloaders_names(val_loaders, "val")
test_loader_names, test_loaders = self._eval_dataloaders_names(
test_loaders, "test"
)
# Duplicate datasets for ema
if self.hparams.train.ema > 0.0:
val_loader_names += [name + "/ema" for name in val_loader_names]
val_loaders = val_loaders + val_loaders
test_loader_names += [name + "/ema" for name in test_loader_names]
test_loaders = test_loaders + test_loaders
# adding option to only have val loader at eval (e.g., if test is duplicate)
eval_loader_names = []
eval_loaders = []
if not self.hparams.train.get("remove_val_loader_in_eval", False):
eval_loader_names += val_loader_names
eval_loaders += val_loaders
if not self.hparams.train.get("remove_test_loader_in_eval", False):
eval_loader_names += test_loader_names
eval_loaders += test_loaders
return eval_loader_names, eval_loaders
def val_dataloader(self):
val_loader_names, val_loaders = self._eval_dataloaders()
self.val_loader_names = val_loader_names
return val_loaders
def test_dataloader(self):
test_loader_names, test_loaders = self._eval_dataloaders()
self.test_loader_names = ["final/" + name for name in test_loader_names]
return test_loaders
# pytorch-lightning utils and entrypoint
def create_trainer(config, **kwargs):
callbacks: List[pl.Callback] = []
logger = None
# WandB Logging
if config.get("wandb") is not None:
# Pass in wandb.init(config=) argument to get the nice 'x.y.0.z' hparams logged
# Can pass in config_exclude_keys='wandb' to remove certain groups
import wandb
logger = CustomWandbLogger(
config=utils.to_dict(config, recursive=True),
settings=wandb.Settings(start_method="fork"),
**config.wandb,
)
# Lightning callbacks
if "callbacks" in config:
for _name_, callback in config.callbacks.items():
if config.get("wandb") is None and _name_ in ["learning_rate_monitor"]:
continue
log.info(f"Instantiating callback <{registry.callbacks[_name_]}>")
callback._name_ = _name_
callbacks.append(utils.instantiate(registry.callbacks, callback))
# Add ProgressiveResizing callback
if config.callbacks.get("progressive_resizing", None) is not None:
num_stages = len(config.callbacks.progressive_resizing.stage_params)
log.info(f"Progressive Resizing: {num_stages} stages")
for i, e in enumerate(config.callbacks.progressive_resizing.stage_params):
# Stage params are resolution and epochs, pretty print
log.info(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")
# Configure ddp automatically
n_devices = config.trainer.get('devices', 1)
if isinstance(n_devices, Sequence): # trainer.devices could be [1, 3] for example
n_devices = len(n_devices)
if n_devices > 1 and config.trainer.get('strategy', None) is None:
config.trainer.strategy = dict(
_target_='pytorch_lightning.strategies.DDPStrategy',
find_unused_parameters=False,
# https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
gradient_as_bucket_view=True,
)
# Init lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
# special processing for seqlen warmup scheduler (reload)
trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=logger)
return trainer
def fsspec_exists(filename):
fs, _ = fsspec.core.url_to_fs(filename)
return fs.exists(filename)
def train(config):
if config.train.seed is not None:
pl.seed_everything(config.train.seed, workers=True)
trainer = create_trainer(config)
model = SequenceLightningModule(config)
# Load pretrained_model if specified
if config.train.get("pretrained_model_path", None) is not None:
# PTL style. Note, method returns a new model object, and need to pass config.
model = SequenceLightningModule.load_from_checkpoint(
config.train.pretrained_model_path,
config=config,
strict=config.train.pretrained_model_strict_load,
)
# Run initial validation epoch (useful for debugging, fine-tuning)
if config.train.validate_at_start:
log.info("Running validation before training")
trainer.validate(model)
log.info(f'{config.train.ckpt=} {fsspec_exists(config.train.ckpt)=}')
# if config.train.get("compile_model", False):
# model = torch.compile(model, mode="reduce-overhead")
if config.train.ckpt is not None and fsspec_exists(config.train.ckpt):
trainer.fit(model, ckpt_path=config.train.ckpt)
else:
trainer.fit(model)
if config.train.test:
if config.train.get("cross_validation", False): # First, load the best validation model
best_val_ckpt = os.path.join(
model.hparams.callbacks.model_checkpoint.dirpath,
f"{model.hparams.callbacks.model_checkpoint.filename}.ckpt",
)
# Update config so we do not load just the backbone
config.train.pretrained_model_state_hook.update({"_name_": None})
# Remove validation loader
config.train.update({"remove_val_loader_in_eval": True})
config.train.update({"remove_test_loader_in_eval": False})
ckpt = torch.load(best_val_ckpt)
log.info(f"Loaded best validation checkpoint from epoch {ckpt['epoch']}")
trainer.validate(model, ckpt_path=best_val_ckpt)
else:
trainer.validate(model)
@hydra.main(config_path="configs", config_name="config.yaml")
def main(config: OmegaConf):
# Process config:
# - register evaluation resolver
# - filter out keys used only for interpolation
# - optional hooks, including disabling python warnings or debug friendly configuration
config = utils.train.process_config(config)
# if config.train.get("compile_model", False):
# # See: https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops
# from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
# allow_ops_in_compiled_graph()
# Pretty print config using Rich library
utils.train.print_config(config, resolve=True)
train(config)
if __name__ == "__main__":
main()
================================================
FILE: vep_embeddings.py
================================================
"""Dump model embeddings for VEP classification task.
"""
import argparse
import os
from functools import partial
from os import path as osp
from typing import Dict, Iterable, Optional
import enformer_pytorch
import fsspec
import torch
import torch.distributed as dist
import torch.nn as nn
from datasets import load_dataset, load_from_disk
from sklearn import preprocessing
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm.auto import tqdm
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, DefaultDataCollator
from src.dataloaders.utils.rc import string_reverse_complement
from src.utils.train import get_logger
WINDOW_SIZE_BP = 1536
log = get_logger(__name__)
class DNAEmbeddingModel(nn.Module):
"""Wrapper around HF model.
Args:
model_name_or_path: str, path to HF model.
"""
def __init__(
self,
model_name_or_path: str,
):
super().__init__()
self.model_name_or_path = model_name_or_path
# Enformer uses different library for loading
if "enformer" in model_name_or_path.lower():
self.backbone = enformer_pytorch.from_pretrained(
model_name_or_path,
use_tf_gamma=False,
use_checkpointing=True
)
# NT model is not compatible with AutoModel class
elif "nucleotide-transformer" in model_name_or_path.lower():
# NT LM `backbone` is under the `.esm` attribute
self.backbone = AutoModelForMaskedLM.from_pretrained(model_name_or_path, trust_remote_code=True).esm
else:
self.backbone = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
def forward(self, input_ids):
"""Backbone forward pass to retrieve last_hidden_state."""
if "enformer" in self.model_name_or_path.lower():
# Enformer forward pass has different signature
return self.backbone(input_ids, return_embeddings=True)[1]
return self.backbone(input_ids).last_hidden_state
class EnformerTokenizer:
"""Enformer tokenizer."""
# Order is important here! (See: https://github.com/lucidrains/enformer-pytorch?tab=readme-ov-file#usage)
pad_token = "P" # Padding token should be a character to avoid issues with tokenization
encode_map = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, pad_token: -1}
@classmethod
def encode(
cls, seq: str, max_length: Optional[int] = None, truncation: Optional[bool] = False
) -> Iterable[int]:
"""Convert bp to token ids."""
if max_length is not None:
assert max_length >= 0, "max_length should be a non-negative integer."
if len(seq) < max_length:
seq = seq + cls.pad_token * (max_length - len(seq))
elif truncation:
seq = seq[:max_length]
return [cls.encode_map[bp] for bp in seq.upper()]
@classmethod
def batch_encode_plus(
cls, seqs: Iterable[str], max_length: Optional[int] = None, truncation: Optional[bool] = False,
**kwargs, # ensures compatibility with HF tokenizer-like API
) -> Dict[str, Iterable[Iterable[int]]]:
"""Batch encode sequences using HF tokenizer-like API."""
input_ids = [cls.encode(seq, max_length=max_length, truncation=truncation) for seq in seqs]
return {"input_ids": input_ids}
def setup_distributed():
"""Set environment variables for distributed runs."""
dist.init_process_group("nccl")
def cleanup_distributed():
"""Clean up processes from distributed runs."""
dist.destroy_process_group()
def fsspec_exists(filename):
"""Check if file exists in manner compatible with fsspec."""
fs, _ = fsspec.core.url_to_fs(filename)
return fs.exists(filename)
def fsspec_listdir(dirname):
"""Listdir in manner compatible with fsspec."""
fs, _ = fsspec.core.url_to_fs(dirname)
return fs.ls(dirname)
# Processing functions
def recast_chromosome_tissue_dist2TSS(examples):
"""Recast chromosome to int."""
return {
"chromosome": -1 if examples["chromosome"] == "X" else int(examples["chromosome"]),
"tissue": examples["tissue"],
"distance_to_nearest_tss": examples["distance_to_nearest_tss"]
}
def tokenize_variants(examples, tokenizer, max_length: int):
"""Tokenize sequence.
Args:
examples: (batch of) items from the dataset.
tokenizer: AutoTokenizer.
max_length: int.
Returns:
dict with values as list of token ids.
"""
ref_tokenized = tokenizer.batch_encode_plus(
examples["ref_forward_sequence"],
add_special_tokens=False,
return_attention_mask=False,
max_length=max_length,
truncation=True,
)
alt_tokenized = tokenizer.batch_encode_plus(
examples["alt_forward_sequence"],
add_special_tokens=False,
return_attention_mask=False,
max_length=max_length,
truncation=True,
)
ref_rc_tokenized = tokenizer.batch_encode_plus(
[string_reverse_complement(seq) for seq in examples["ref_forward_sequence"]],
add_special_tokens=False,
return_attention_mask=False,
max_length=max_length,
truncation=True,
)
alt_rc_tokenized = tokenizer.batch_encode_plus(
[string_reverse_complement(seq) for seq in examples["alt_forward_sequence"]],
add_special_tokens=False,
return_attention_mask=False,
max_length=max_length,
truncation=True,
)
return {
"ref_input_ids": ref_tokenized["input_ids"],
"alt_input_ids": alt_tokenized["input_ids"],
"ref_rc_input_ids": ref_rc_tokenized["input_ids"],
"alt_rc_input_ids": alt_rc_tokenized["input_ids"],
}
def find_variant_idx(examples):
"""Find token location that differs between reference and variant sequence.
Args:
examples: items from the dataset (not batched).
Returns:
dict with values index of difference.
"""
# Guess that variant is at halfway point
idx = len(examples["ref_input_ids"]) // 2
if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]:
# If not, loop through the sequence to find the variant location
idx = -1
for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])):
if ref != alt:
idx = i
# Same as above, but for reverse complement
rc_idx = len(examples["ref_rc_input_ids"]) // 2 - 1
if examples["ref_rc_input_ids"][rc_idx] == examples["alt_rc_input_ids"][rc_idx]:
rc_idx = -1
for i, (ref, alt) in enumerate(zip(examples["ref_rc_input_ids"], examples["alt_rc_input_ids"])):
if ref != alt:
rc_idx = i
return {"variant_idx": idx, "rc_variant_idx": rc_idx}
def prepare_dataset(args, tokenizer):
"""Prepare or load the tokenized dataset."""
# Data Preprocessing
num_tokens = args.seq_len // args.bp_per_token
# Load data
cache_dir = osp.join(
os.getenv("HF_HOME"), "datasets", "InstaDeepAI___genomics-long-range-benchmark",
"variant_effect_gene_expression", f"seqlen={args.seq_len}"
)
if "nucleotide-transformer" in args.model_name_or_path.lower(): # NT uses 6-mers, so tokenization is different
preprocessed_cache_file = osp.join(cache_dir, "6mer_token_preprocessed")
elif "enformer" in args.model_name_or_path.lower():
# Enformer tokenization requires having vocab of just `A,C,G,T,N` (in that order)
preprocessed_cache_file = osp.join(cache_dir, "enformer_char_token_preprocessed")
else:
preprocessed_cache_file = osp.join(cache_dir, "char_token_preprocessed")
log.warning(f"Cache dir: {cache_dir}")
log.warning(f"Cache dir preprocessed: {preprocessed_cache_file}")
if not fsspec_exists(preprocessed_cache_file):
if dist.get_rank() == 0:
dataset = load_dataset(
"InstaDeepAI/genomics-long-range-benchmark",
task_name="variant_effect_gene_expression",
sequence_length=args.seq_len,
load_from_cache=False,
)
log.warning("Dataset loaded. Cached to disk:")
log.warning(osp.dirname(list(dataset.cache_files.values())[0][0]["filename"]))
try:
del dataset["validation"] # `validation` split is empty
except KeyError:
pass
# Process data
dataset = dataset.filter(
lambda example: example["ref_forward_sequence"].count('N') < 0.005 * args.seq_len,
desc="Filter N's"
)
dataset = dataset.map(
recast_chromosome_tissue_dist2TSS,
remove_columns=["chromosome", "tissue", "distance_to_nearest_tss"],
desc="Recast chromosome"
)
dataset = dataset.map(
partial(tokenize_variants, tokenizer=tokenizer, max_length=num_tokens),
batch_size=1000,
batched=True,
remove_columns=["ref_forward_sequence", "alt_forward_sequence"],
desc="Tokenize"
)
dataset = dataset.map(find_variant_idx, desc="Find variant idx")
dataset.save_to_disk(preprocessed_cache_file)
dist.barrier() # Processes need to wait for dataset to be saved to disk (if not already done)
dataset = load_from_disk(preprocessed_cache_file)
log.warning(f"Loaded preprocessed dataset from {preprocessed_cache_file}")
log.warning(dataset)
return dataset
def get_backbone_model(args, device):
"""Get the backbone model."""
model = DNAEmbeddingModel(
model_name_or_path=args.model_name_or_path,
)
model.eval()
return DDP(model.to(device))
def concat_storage_dict_values(storage_dict):
"""Helper method that combines lists of tensors in storage_dict into a single torch.Tensor."""
return {key: torch.cat(storage_dict[key], dim=0) for key in storage_dict.keys()}
def dump_embeddings(args, dataset, model, device):
"""Dump embeddings to disk."""
def extract_embeddings(item_ref, item_alt, variant_idx):
"""Extract embedding representation from last layer outputs
Args:
item_ref: torch.Tensor, shape (batch_size, seq_len, hidden_size) Ref embedding
item_alt: torch.Tensor, shape (batch_size, seq_len, hidden_size) Alt embedding
variant_idx: torch.Tensor, shape (batch_size,) Index of variant
Returns:
layer_metrics: dict, with values to save to disk
"""
layer_metrics = {}
# Compute windowed statistics
if "enformer" in args.model_name_or_path.lower():
window_size = WINDOW_SIZE_BP // 128  # Enformer outputs are binned at 128 bp resolution
# We also need to override variant_idx since Enformer model reduces to target_length of 896
variant_idx = torch.ones_like(variant_idx) * item_ref.size(1) // 2
else:
window_size = WINDOW_SIZE_BP // args.bp_per_token
# Add 1 so that the window is centered on the SNP: [window // 2 | SNP | window // 2]
start, end = -window_size // 2, window_size // 2 + 1
expanded_indices = torch.arange(start, end, device=item_ref.device).unsqueeze(0) + \
variant_idx.unsqueeze(1).to(item_ref.device)
expanded_indices = torch.clamp(expanded_indices, 0, item_ref.size(1) - 1) # Handle boundary conditions
tokens_window_ref = torch.gather(
item_ref, 1,
expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2))
).mean(dim=1)
tokens_window_alt = torch.gather(
item_alt, 1,
expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2))
).mean(dim=1)
layer_metrics["concat_avg_ws"] = torch.cat([tokens_window_ref, tokens_window_alt], dim=-1)
return layer_metrics
embeds_path = osp.join(args.downstream_save_dir, args.name)
os.makedirs(embeds_path, exist_ok=True)
dataloader_params = {
"batch_size": args.embed_dump_batch_size,
"collate_fn": DefaultDataCollator(return_tensors="pt"),
"num_workers": args.num_workers,
"pin_memory": False,
"shuffle": False,
"drop_last": True
}
# Encode tissue labels as integer ids so they can be used as an extra feature
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(dataset["test"]["tissue"])
train_tissue_embed = label_encoder.transform(dataset["train"]["tissue"])
dataset["train"] = dataset["train"].add_column("tissue_embed", train_tissue_embed)
test_tissue_embed = label_encoder.transform(dataset["test"]["tissue"])
dataset["test"] = dataset["test"].add_column("tissue_embed", test_tissue_embed)
if not all([
fsspec_exists(osp.join(embeds_path, f"{split_name}_embeds_combined.pt")) for split_name in dataset.keys()
]):
for split_name, split in dataset.items():
sampler = DistributedSampler(
split,
shuffle=dataloader_params.get("shuffle", False),
drop_last=dataloader_params.get("drop_last", True),
)
dl = DataLoader(split, **dataloader_params, sampler=sampler)
storage_dict = {
"concat_avg_ws": [],
"rc_concat_avg_ws": [],
"chromosome": [],
"labels": [],
"distance_to_nearest_tss": [],
"tissue_embed": [],
}
with torch.no_grad():
for batch_idx, batch in tqdm(
enumerate(dl), total=len(dl), desc=f"[RANK {dist.get_rank()}] Embedding {split_name}",
disable=dist.get_rank() != 0 # Only rank 0 updates pbar
):
for key in ["chromosome", "labels", "distance_to_nearest_tss", "tissue_embed"]:
storage_dict[key].append(batch[key].to("cpu", non_blocking=True))
with torch.autocast(device_type="cuda", dtype=torch.float16):
output_alt = model(batch["alt_input_ids"].to(device))
output_ref = model(batch["ref_input_ids"].to(device))
if args.rcps:
num_channels = output_alt.size(-1)
# Flip along length and channel dims to preserve RC equivariance
# i.e. output_rc(RC(inputs)) = outputs(inputs)
output_alt_rc = output_alt[..., num_channels // 2:].contiguous().flip(dims=[1, 2])
output_ref_rc = output_ref[..., num_channels // 2:].contiguous().flip(dims=[1, 2])
output_alt = output_alt[..., :num_channels // 2]
output_ref = output_ref[..., :num_channels // 2]
else:
# Flip along length dim so variant_idx aligns
output_alt_rc = model(batch["alt_rc_input_ids"].to(device)).contiguous().flip(dims=[1])
output_ref_rc = model(batch["ref_rc_input_ids"].to(device)).contiguous().flip(dims=[1])
metrics = extract_embeddings(
item_ref=output_ref,
item_alt=output_alt,
variant_idx=batch["variant_idx"],
)
for key, value in metrics.items():
storage_dict[key].append(metrics[key].to("cpu", non_blocking=True))
metrics_rc = extract_embeddings(
item_ref=output_ref_rc,
item_alt=output_alt_rc,
variant_idx=batch["variant_idx"],
)
for key, value in metrics_rc.items():
storage_dict[f"rc_{key}"].append(metrics_rc[key].to("cpu", non_blocking=True))
if batch_idx % 100 == 0:
# Every machine should print progress updates
print(f"[RANK {dist.get_rank()}] Completed index: {batch_idx}/{len(dl)}")
storage_dict_temp = concat_storage_dict_values(storage_dict)
with fsspec.open(osp.join(embeds_path, f"{split_name}_embeds_{dist.get_rank()}.pt"), "wb") as f:
torch.save(storage_dict_temp, f)
print(f"[RANK {dist.get_rank()}] Saved {split_name} to {osp.join(embeds_path, f'{split_name}_embeds_{dist.get_rank()}.pt')}")
else:
log.warning("Embeddings already exist, skipping!")
def combine_embeddings(embeds_path):
"""Combine embeddings from different files."""
# Check if combined embeddings exist, and if not, aggregate them
for split in ["train", "test"]:
if not fsspec_exists(osp.join(embeds_path, f"{split}_embeds_combined.pt")):
storage_dict = {
"concat_avg_ws": [],
"rc_concat_avg_ws": [],
"chromosome": [],
"labels": [],
"distance_to_nearest_tss": [],
"tissue_embed": [],
}
for filename in fsspec_listdir(embeds_path):
if f"{split}_embeds_" in filename:
log.warning(f"Loading data from: {filename}")
with fsspec.open(filename, "rb") as f:
tmp_data = torch.load(f)
for key in storage_dict.keys():
storage_dict[key].append(tmp_data[key])
storage_dict = concat_storage_dict_values(storage_dict)
log.warning(f"Saving combined data to: {embeds_path}/{split}_embeds_combined.pt")
with fsspec.open(osp.join(embeds_path, f"{split}_embeds_combined.pt"), "wb") as f:
torch.save(storage_dict, f)
def main(args):
"""Main entry point."""
# Reproducibility
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
# Init distributed
log.warning("Initializing distributed...")
dist.init_process_group("nccl")
print(f"[RANK {dist.get_rank()}] Distributed initialized: rank {dist.get_rank()}") # All processes print this
# Setup device
device = torch.device(f"cuda:{dist.get_rank()}")
print(f"[RANK {dist.get_rank()}] Using device: {device}.") # All processes print this
# Init tokenizer
if "enformer" in args.model_name_or_path.lower():
# Enformer tokenization requires having vocab of just `A,C,G,T,N` (in that order)
tokenizer = EnformerTokenizer()
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
# Get dataset
dist.barrier()
dataset = prepare_dataset(args, tokenizer)
# Get model
dist.barrier()
model = get_backbone_model(args, device)
log.warning("Model loaded.")
# Dump embeddings
dist.barrier()
dump_embeddings(args, dataset, model, device)
# Combine embeddings into single file
dist.barrier()
cleanup_distributed()
combine_embeddings(osp.join(args.downstream_save_dir, args.name))
if __name__ == "__main__":
torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--seq_len", type=int, default=131072,
help="Sequence length (in bp)..")
parser.add_argument("--bp_per_token", type=int, default=1,
help="Number of base pairs per token.")
parser.add_argument("--model_name_or_path", type=str, default=None)
parser.add_argument("--downstream_save_dir", type=str, default="./outputs/downstream/vep_embeddings",
help="Directory to save downstream task.")
parser.add_argument("--name", type=str, default=None, help="Embeddings model name.")
parser.add_argument("--rcps", default=False, action="store_true", help="Use RCPS.")
parser.add_argument("--no-rcps", dest="rcps", action="store_false", help="Do not use RCPS.")
parser.add_argument("--embed_dump_batch_size", type=int, default=1,
help="Batch size for embedding dump.")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers.")
opts, _ = parser.parse_known_args()
log.warning("*** Args ************************")
for k, v in vars(opts).items():
log.warning(f" - {k}: {v}")
log.warning("******************************\n")
main(opts)
================================================
FILE: vep_svm.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "db878bc1",
"metadata": {},
"source": [
"## Imports and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c79a903b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import random\n",
"import time\n",
"from os import path as osp\n",
"\n",
"import fsspec\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import torch\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4034f167",
"metadata": {},
"outputs": [],
"source": [
"DIST_TO_TSS = [[0, 30_000], [30_000, 100_000], [100_000, np.infty]]\n",
"USE_TISSUE = [True] # used as another for loop for fitting SVM, whether to use tissue embed or not\n",
"Cs = [1, 5, 10] # for loop in fitting SVM, inverse of L2 penalty (sklearn hyperparam)\n",
"PATH_TO_OUTPUTS = \"./outputs/downstream/vep_embeddings\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55c58437",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def fsspec_exists(filename: str) -> bool:\n",
" \"\"\"Check if file exists in manner compatible with fsspec.\"\"\"\n",
" fs, _ = fsspec.core.url_to_fs(filename)\n",
" return fs.exists(filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18522e17",
"metadata": {},
"outputs": [],
"source": [
"def dataset_nan_filter(data: dict, data_key: str) -> dict:\n",
" \"\"\"Filter any items that have NaN in embedding within TSS bucket\"\"\"\n",
" mask_out = torch.logical_or(\n",
" torch.any(data[data_key].isnan(), dim=1),\n",
" torch.any(data[f\"rc_{data_key}\"].isnan(), dim=1)\n",
" )\n",
" \n",
" new_data = dict()\n",
" for data_key in data.keys():\n",
" new_data[data_key] = data[data_key][~mask_out]\n",
"\n",
" return new_data\n",
"\n",
"def dataset_tss_filter(data: dict, min_distance: int, max_distance: int) -> dict:\n",
" \"\"\"Filter the data to items that fall within TSS bucket\"\"\"\n",
" distance_mask = ((data[\"distance_to_nearest_tss\"] >= min_distance) \n",
" & (data[\"distance_to_nearest_tss\"] <= max_distance))\n",
" new_data = dict()\n",
" for data_key in data.keys():\n",
" new_data[data_key] = data[data_key][distance_mask]\n",
"\n",
" return new_data"
]
},
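{
"cell_type": "code",
"execution_count": null,
"id": "3f1a2b4c",
"metadata": {},
"outputs": [],
"source": [
"# Minimal usage sketch (not part of the original analysis): a toy run of the filter helpers above.\n",
"# The key names mirror those used later in this notebook; the tensor values are made up.\n",
"_toy = {\n",
"    \"concat_avg_ws\": torch.tensor([[0.1, 0.2], [float(\"nan\"), 0.3], [0.4, 0.5]]),\n",
"    \"rc_concat_avg_ws\": torch.tensor([[0.1, 0.2], [0.2, 0.3], [0.4, 0.5]]),\n",
"    \"distance_to_nearest_tss\": torch.tensor([10_000, 50_000, 200_000]),\n",
"    \"labels\": torch.tensor([0, 1, 1]),\n",
"}\n",
"_toy = dataset_nan_filter(_toy, data_key=\"concat_avg_ws\")  # drops the row whose embedding has a NaN\n",
"_toy = dataset_tss_filter(_toy, 0, 30_000)  # keeps rows within 0-30k bp of the nearest TSS\n",
"print({k: v.shape for k, v in _toy.items()})"
]
},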
{
"cell_type": "markdown",
"id": "ef3d1006",
"metadata": {},
"source": [
"## Specify which models to test"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4629cb30",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Embeddings to test\n",
"model_dict = {\n",
" \"HyenaDNA\": dict(\n",
" embed_path=\"hyena_downstream-seqlen=131k\",\n",
" rc_aug=False,\n",
" conjoin_train=False,\n",
" conjoin_test=False,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
" \"Caduceus-Ph\": dict(\n",
" embed_path=\"caduceus-ph_downstream-seqlen=131k\",\n",
" rc_aug=False,\n",
" conjoin_train=False,\n",
" conjoin_test=True,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
" \"Caduceus w/o Equiv.\": dict(\n",
" embed_path=\"caduceus-ph_downstream-seqlen=131k\",\n",
" rc_aug=False,\n",
" conjoin_train=False,\n",
" conjoin_test=False,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
" \"Caduceus-PS\": dict(\n",
" embed_path=\"caduceus-ps_downstream-seqlen=131k\",\n",
" rc_aug=False,\n",
" conjoin_train=True,\n",
" conjoin_test=False,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
" \"Enformer\": dict(\n",
" embed_path=\"enformer-seqlen=196k\",\n",
" rc_aug=False,\n",
" conjoin_train=False,\n",
" conjoin_test=False,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
" \"NTv2\": dict(\n",
" embed_path=\"NTv2_downstream-seqlen=12k\",\n",
" rc_aug=False,\n",
" conjoin_train=False,\n",
" conjoin_test=False,\n",
" key=\"concat_avg_ws\",\n",
" ),\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "12e64367",
"metadata": {},
"source": [
"## Fit and test SVM"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6eaeb519-5c35-4fba-a09b-2d47c122320d",
"metadata": {
"scrolled": false,
"tags": []
},
"outputs": [],
"source": [
"metrics = {\n",
" \"model_name\": [],\n",
" \"bucket_id\": [],\n",
" \"use_tissue\": [],\n",
" \"C\": [],\n",
" \"seed\": [],\n",
" \"AUROC\": [],\n",
"}\n",
"\n",
"for model_name, downstream_kwargs in model_dict.items():\n",
" print(f\"********** Gathering results for: {model_name} **********\")\n",
" embed_path = downstream_kwargs[\"embed_path\"]\n",
" rc_aug = downstream_kwargs[\"rc_aug\"]\n",
" conjoin_train = downstream_kwargs[\"conjoin_train\"]\n",
" conjoin_test = downstream_kwargs[\"conjoin_test\"]\n",
" key = downstream_kwargs[\"key\"]\n",
" \n",
" if \"NT\" in model_name: assert (rc_aug == False) and (conjoin_train == False) and (conjoin_test == False)\n",
" \n",
" base_embeds_path = PATH_TO_OUTPUTS\n",
" embeds_path = osp.join(base_embeds_path, embed_path)\n",
" \n",
" print(f\"Embed Path: {embeds_path}\")\n",
" with fsspec.open(osp.join(embeds_path, \"train_embeds_combined.pt\"), \"rb\") as f:\n",
" train_val_ds_raw = torch.load(f, map_location=\"cpu\")\n",
" train_val_ds_raw = dataset_nan_filter(train_val_ds_raw, data_key=key)\n",
" with fsspec.open(osp.join(embeds_path, \"test_embeds_combined.pt\"), \"rb\") as f:\n",
" test_ds_raw = torch.load(f, map_location=\"cpu\")\n",
" test_ds_raw = dataset_nan_filter(test_ds_raw, data_key=key)\n",
" print(f\"Total Train size: {len(train_val_ds_raw[key])},\", end=\" \")\n",
" print(f\"Total Test size: {len(test_ds_raw[key])},\", end=\" \")\n",
" print(f\"Shape: {test_ds_raw[key].shape[1:]}\")\n",
"\n",
"\n",
" for bucket_id, (min_dist, max_dist) in enumerate(DIST_TO_TSS):\n",
" # Filter data to desired TSS bucket\n",
" train_val_ds_filter = dataset_tss_filter(train_val_ds_raw, min_dist, max_dist)\n",
" test_ds_filter = dataset_tss_filter(test_ds_raw, min_dist, max_dist)\n",
" print(f\"- TSS bucket: [{min_dist}, {max_dist}],\", end=\" \")\n",
" print(f\"Train size: {len(train_val_ds_filter[key])},\", end=\" \")\n",
" print(f\"Test size: {len(test_ds_filter[key])}\")\n",
" \n",
" for use_tissue in USE_TISSUE:\n",
" for C in Cs:\n",
" for seed in range(1, 6): \n",
" # Re-seed for SVM fitting\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
"\n",
" svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(C=C, random_state=seed),\n",
" )\n",
"\n",
" # Setup Train/Test dataset\n",
" if conjoin_train:\n",
" X = np.array(train_val_ds_filter[key])\n",
" X += np.array(train_val_ds_filter[f\"rc_{key}\"])\n",
" X /= 2\n",
" else:\n",
" X = np.array(train_val_ds_filter[key])\n",
" X_with_tissue = np.concatenate(\n",
" [X, np.array(train_val_ds_filter[\"tissue_embed\"])[..., None]],\n",
" axis=-1\n",
" )\n",
" y = train_val_ds_filter[\"labels\"]\n",
" if conjoin_train or conjoin_test:\n",
" X_test = np.array(test_ds_filter[key])\n",
" X_test += np.array(test_ds_filter[f\"rc_{key}\"])\n",
" X_test /= 2\n",
" else:\n",
" X_test = np.array(test_ds_filter[key])\n",
" X_test_with_tissue = np.concatenate(\n",
" [X_test, np.array(test_ds_filter[\"tissue_embed\"])[..., None]],\n",
" axis=-1\n",
" )\n",
" y_test = test_ds_filter[\"labels\"]\n",
"\n",
" print(f\"\\tFitting SVM ({use_tissue=}, {C=}, {seed=})...\", end=\" \")\n",
" \n",
" mask = np.random.choice(len(X), size=5000, replace= 5000 > len(X) )\n",
" if use_tissue: \n",
" X_train = X_with_tissue[mask]\n",
" X_test = X_test_with_tissue\n",
" else: \n",
" X_train = X[mask]\n",
" y_train = y[mask]\n",
"\n",
" start = time.time()\n",
" svm_clf.fit(X_train, y_train)\n",
" svm_y_pred = svm_clf.predict(X_test)\n",
" svm_aucroc = roc_auc_score(y_test, svm_y_pred)\n",
" end = time.time()\n",
" print(f\"Completed! ({end - start:0.3f} s) -\", end=\" \")\n",
" print(f\"AUROC: {svm_aucroc}\")\n",
" \n",
" metrics[\"model_name\"] += [model_name]\n",
" metrics[\"bucket_id\"] += [bucket_id]\n",
" metrics[\"use_tissue\"] += [use_tissue]\n",
" metrics[\"C\"] += [C]\n",
" metrics[\"seed\"] += [seed]\n",
" metrics[\"AUROC\"] += [svm_aucroc]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "597b0fe9",
"metadata": {},
"outputs": [],
"source": [
"df_metrics = pd.DataFrame.from_dict(metrics)\n",
"df_metrics.to_csv(osp.join(PATH_TO_OUTPUTS, \"SVM_results.csv\"))"
]
},
{
"cell_type": "markdown",
"id": "03e06a25",
"metadata": {},
"source": [
"## Plot results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a362d3fa",
"metadata": {},
"outputs": [],
"source": [
"model_name_replacement = {\n",
" \"Caduceus w/o Equiv.\": \"Caduceus w/o\\nEquiv. (7.7M)\",\n",
" \"Caduceus-Ph\": \"Caduceus-Ph\\n(7.7M)\",\n",
" \"Caduceus-PS\": \"Caduceus-PS\\n(7.7M)\",\n",
" \"HyenaDNA\": \"HyenaDNA\\n(6.6M)\",\n",
" \"NTv2\": \"NTv2\\n(500M)\",\n",
" \"Enformer\": \"Enformer\\n(252M)\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85b1c4fb",
"metadata": {},
"outputs": [],
"source": [
"# Formatting changes to df\n",
"df = pd.read_csv(osp.join(PATH_TO_OUTPUTS, \"SVM_results.csv\"), index_col=0)\n",
"df_display = df.rename(columns={\"bucket_id\": \"Distance to TSS\"})\n",
"df_display = df_display.replace({\"Distance to TSS\": {0: \"0 - 30k\", 1: \"30 - 100k\", 2: \"100k+\"}})\n",
"df_display = df_display.replace({\"model_name\": model_name_replacement})\n",
"\n",
"# Take average over seeds\n",
"df_display_selected = df_display.groupby(\n",
" [\"model_name\", \"Distance to TSS\", \"use_tissue\", \"C\"]\n",
").agg(AUROC=(\"AUROC\", np.mean)).reset_index()\n",
"\n",
"# Select best hyperparam by model/bucket\n",
"best_ids = df_display_selected.groupby([\"model_name\", \"Distance to TSS\"])[\"AUROC\"].idxmax()\n",
"df_display_selected = df_display_selected.loc[best_ids.reset_index()[\"AUROC\"].values]\n",
"display(\n",
" df_display_selected.pivot(\n",
" index=\"model_name\", columns=\"Distance to TSS\", values=\"AUROC\"\n",
" )[[\"0 - 30k\", \"30 - 100k\", \"100k+\"]]\n",
")\n",
"display(df_display_selected[[\"model_name\", \"Distance to TSS\", \"C\", \"use_tissue\"]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09a7f4a9",
"metadata": {},
"outputs": [],
"source": [
"# Filter results to selected hyperparams\n",
"df_plot = pd.merge(\n",
" df_display, df_display_selected,\n",
" on=[\"model_name\", \"Distance to TSS\", \"use_tissue\", \"C\"]\n",
").drop(columns=[\"AUROC_y\"]).rename(columns={\"AUROC_x\": \"AUROC\"})\n",
"\n",
"# Plot results by distance to TSS\n",
"sns.set_style(\"whitegrid\")\n",
"g = sns.catplot(\n",
" data=df_plot,\n",
" x=\"model_name\",\n",
" y=\"AUROC\",\n",
" col=\"Distance to TSS\",\n",
" hue=\"Distance to TSS\",\n",
" kind=\"bar\",\n",
" errorbar=\"sd\",\n",
" height=12,\n",
" aspect=1,\n",
" dodge=False,\n",
" order=list(model_name_replacement.values()),\n",
")\n",
"g.set_xticklabels(rotation=60, fontsize=30)\n",
"g.set(xlabel=\"\")\n",
"g.set(ylim=(0.4, 0.7))\n",
"g.set_titles(template=\"Dist. to TSS: {col_name}\", fontsize=40)\n",
"g.fig.suptitle(\"Predicting Effects of Variants on Gene Expression\", y=1.1, fontsize=40)\n",
"g._legend.remove()\n",
"# Display bar values\n",
"# (See: https://stackoverflow.com/questions/55586912/seaborn-catplot-set-values-over-the-bars)\n",
"for ax in tqdm(g.axes.ravel(), leave=False):\n",
" title = ax.title.get_text()\n",
" ax.set_title(title, fontsize=35)\n",
" for c in tqdm(ax.containers, leave=False):\n",
" labels = [f\"{v.get_height():0.3f}\" for v in c]\n",
" ax.bar_label(c, labels=labels, label_type=\"center\", color=\"white\", weight=\"bold\", fontsize=24)\n",
"plt.show()\n",
"g.savefig(osp.join(PATH_TO_OUTPUTS, \"SVM_results.png\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7858241",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}