Repository: kuleshov-group/caduceus Branch: main Commit: 0060a6d8079b Files: 104 Total size: 394.9 KB Directory structure: gitextract_kw76odef/ ├── .gitignore ├── LICENSE ├── README.md ├── caduceus/ │ ├── __init__.py │ ├── configuration_caduceus.py │ ├── modeling_caduceus.py │ ├── modeling_rcps.py │ ├── tests/ │ │ └── test_rcps.py │ └── tokenization_caduceus.py ├── caduceus_env.yml ├── configs/ │ ├── callbacks/ │ │ ├── base.yaml │ │ ├── checkpoint.yaml │ │ ├── gpu_affinity.yaml │ │ ├── rich.yaml │ │ ├── val_every_n_global_steps.yaml │ │ └── wandb.yaml │ ├── config.yaml │ ├── dataset/ │ │ ├── genomic_benchmark.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── experiment/ │ │ └── hg38/ │ │ ├── genomic_benchmark.yaml │ │ ├── genomic_benchmark_cnn.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── loader/ │ │ └── default.yaml │ ├── model/ │ │ ├── caduceus.yaml │ │ ├── genomics_benchmark_cnn.yaml │ │ ├── hyena.yaml │ │ ├── layer/ │ │ │ └── hyena.yaml │ │ └── mamba.yaml │ ├── optimizer/ │ │ ├── adam.yaml │ │ ├── adamw.yaml │ │ └── sgd.yaml │ ├── pipeline/ │ │ ├── genomic_benchmark.yaml │ │ ├── hg38.yaml │ │ └── nucleotide_transformer.yaml │ ├── scheduler/ │ │ ├── constant.yaml │ │ ├── constant_warmup.yaml │ │ ├── cosine.yaml │ │ ├── cosine_warmup.yaml │ │ ├── cosine_warmup_timm.yaml │ │ ├── linear_warmup.yaml │ │ ├── multistep.yaml │ │ ├── plateau.yaml │ │ └── step.yaml │ ├── task/ │ │ ├── lm.yaml │ │ ├── multiclass_classification.yaml │ │ ├── multilabel_classification.yaml │ │ └── regression.yaml │ └── trainer/ │ ├── debug.yaml │ ├── default.yaml │ ├── full.yaml │ └── lm.yaml ├── setup_env.sh ├── slurm_scripts/ │ ├── dump_vep_embeddings.sh │ ├── run_genomics_benchmark.sh │ ├── run_genomics_benchmark_cnn.sh │ ├── run_nucleotide_transformer.sh │ ├── run_pretrain_caduceus.sh │ ├── run_pretrain_hyena.sh │ ├── run_pretrain_mamba.sh │ ├── wrapper_run_genomics.sh │ ├── wrapper_run_genomics_cnn.sh │ └── wrapper_run_nucleotide_transformer.sh ├── src/ 
│ ├── __init__.py │ ├── callbacks/ │ │ ├── params.py │ │ ├── timer.py │ │ └── validation.py │ ├── dataloaders/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── datasets/ │ │ │ ├── genomic_bench_dataset.py │ │ │ ├── hg38_char_tokenizer.py │ │ │ ├── hg38_dataset.py │ │ │ └── nucleotide_transformer_dataset.py │ │ ├── fault_tolerant_sampler.py │ │ ├── genomics.py │ │ └── utils/ │ │ ├── mlm.py │ │ └── rc.py │ ├── models/ │ │ ├── __init__.py │ │ ├── baseline/ │ │ │ ├── __init__.py │ │ │ └── genomics_benchmark_cnn.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── adaptive_softmax.py │ │ │ └── utils.py │ │ └── sequence/ │ │ ├── __init__.py │ │ ├── dna_embedding.py │ │ ├── hyena.py │ │ └── long_conv_lm.py │ ├── ops/ │ │ └── fftconv.py │ ├── tasks/ │ │ ├── decoders.py │ │ ├── encoders.py │ │ ├── metrics.py │ │ ├── tasks.py │ │ └── torchmetrics.py │ └── utils/ │ ├── __init__.py │ ├── config.py │ ├── optim/ │ │ └── schedulers.py │ ├── optim_groups.py │ ├── registry.py │ └── train.py ├── train.py ├── vep_embeddings.py └── vep_svm.ipynb ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ data.tar.gz *.tsf *.ckpt .ipynb_checkpoints */.ipynb_checkpoints/* *.lprof .DS_Store .idea/ outputs/ # slurm log files watch_folder/ data # Created by https://www.gitignore.io/api/python # Edit at https://www.gitignore.io/?templates=python ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into 
it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # Mr Developer .mr.developer.cfg .project .pydevproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # End of https://www.gitignore.io/api/python ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Caduceus

# Caduceus ☤: Bi-Directional Equivariant Long-Range DNA Sequence Modeling [[Blog]](https://caduceus-dna.github.io/)   |   [[arXiv]](https://arxiv.org/abs/2403.03234)   |   [[HuggingFace 🤗]](https://huggingface.co/collections/kuleshov-group/caducues-65dcb89b4f54e416ef61c350) This repository contains code for reproducing the results in the paper "Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling," [Schiff et al. (2024)](https://arxiv.org/abs/2403.03234). ## Using Caduceus with 🤗 We have uploaded a pre-trained Caduceus model to the Huggingface hub. The available models are: - Caduceus-Ph: [kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16) - Trained on sequences of length 131k, with a model size of 256 and 16 layers. - Trained for 50k steps and batch size of 8. - Trained with reverse-complement (RC) data augmentation. - Caduceus-PS: [kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16) - Trained on sequences of length 131k, with a model size of 256 and 16 layers. - Trained for 50k steps and batch size of 8. - Model is RC equivariant, hence no RC data augmentation is required. You can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model. To use the pre-trained model for masked language modeling, use the following snippet: ```python from transformers import AutoModelForMaskedLM, AutoTokenizer # See the `Caduceus` collection page on the hub for list of available models. 
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForMaskedLM.from_pretrained(model_name) ``` Alternatively, you can instantiate a model from scratch to train on your own data as follows: ```python from transformers import AutoConfig, AutoModelForMaskedLM # Add any config overrides here, see the `config.json` file on the hub for details. config_overrides = {} # See the `Caduceus` collection page on the hub for list of available models. config = AutoConfig.from_pretrained( "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16", **config_overrides, ) model = AutoModelForMaskedLM.from_config(config) ``` ## Getting started in this repository To get started, create a conda environment containing the required dependencies. ```bash conda env create -f caduceus_env.yml ``` Activate the environment. ```bash conda activate caduceus_env ``` Create the following directories to store saved models and slurm logs: ```bash mkdir outputs mkdir watch_folder ``` ## Reproducing Experiments Below, we describe the steps required for reproducing the experiments in the paper. Throughout, the main entry point for running experiments is the [`train.py`](./train.py) script. We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`slurm_scripts/`](./slurm_scripts) directory. ### Pretraining on Human Reference Genome (Data downloading instructions are copied from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome)) First, download the Human Reference Genome data. It comprises 2 files: one with all the sequences (the `.fasta` file) and one with the intervals we use (the `.bed` file). The file structure should look like: ``` data |-- hg38/ |-- hg38.ml.fa |-- human-sequences.bed ``` Download the FASTA (`.fa` format) file of the entire human genome into `./data/hg38`.
The whole genome has ~24 chromosomes (merged into one file), each of which is essentially one continuous sequence. Then download the `.bed` file with sequence intervals (it lists the chromosome name, start, end, and split of each interval, which are used to retrieve the corresponding sequences from the FASTA file). ```bash mkdir -p data/hg38/ curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz gunzip data/hg38/hg38.ml.fa.gz # unzip the fasta file curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed ``` Launch a pretraining run from the command line: ```bash python -m train \ experiment=hg38/hg38 \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ dataset.max_length=1024 \ dataset.batch_size=1024 \ dataset.mlm=true \ dataset.mlm_probability=0.15 \ dataset.rc_aug=false \ model=caduceus \ model.config.d_model=128 \ model.config.n_layer=4 \ model.config.bidirectional=true \ model.config.bidirectional_strategy=add \ model.config.bidirectional_weight_tie=true \ model.config.rcps=true \ optimizer.lr="8e-3" \ train.global_batch_size=1024 \ trainer.max_steps=10000 \ +trainer.val_check_interval=10000 \ wandb=null ``` or alternatively, if using a cluster that has `slurm` installed, adapt the scripts below: ``` slurm_scripts |-- run_pretrain_caduceus.sh |-- run_pretrain_hyena.sh |-- run_pretrain_mamba.sh ``` and run the training as a batch job: ```bash cd slurm_scripts sbatch run_pretrain_caduceus.sh ``` ### GenomicBenchmarks The [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8) comprises 8 classification tasks.
We can launch a downstream fine-tuning run on one of the tasks using the sample command below: ```bash python -m train \ experiment=hg38/genomic_benchmark \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \ dataset.dataset_name="dummy_mouse_enhancers_ensembl" \ dataset.train_val_split_seed=1 \ dataset.batch_size=256 \ dataset.rc_aug=false \ +dataset.conjoin_train=false \ +dataset.conjoin_test=false \ loader.num_workers=2 \ model=caduceus \ model._name_=dna_embedding_caduceus \ +model.config_path="" \ +model.conjoin_test=false \ +decoder.conjoin_train=true \ +decoder.conjoin_test=false \ optimizer.lr="1e-3" \ trainer.max_epochs=10 \ train.pretrained_model_path="" \ wandb=null ``` This sample run fine-tunes a pre-trained Caduceus-PS model on the `dummy_mouse_enhancers_ensembl` task (fill in `+model.config_path` and `train.pretrained_model_path` with the config and checkpoint from your pre-training run). Note the following additional arguments, relative to the pre-training command from [above](#pretraining-on-human-reference-genome): - `model.config_path` contains the path to the model config that was saved during pre-training; this config is saved to the run directory of the pre-training experiment. - `train.pretrained_model_path` contains the path to the pre-trained model checkpoint. - `dataset.conjoin_train` determines whether, during downstream fine-tuning training, the dataset returns a single sequence (`dataset.conjoin_train=false`) or the concatenation of a sequence and its reverse complement along `dim=-1`. - `dataset.conjoin_test` is the same as above, but for inference (e.g., validation / test). - `decoder.conjoin_train` determines whether the prediction head (a mean pooling and linear projection in the case of the Genomics Benchmark) expects an input tensor of shape `(batch_size, seq_len, d_model)` or `(batch_size, seq_len, d_model, 2)` during downstream fine-tuning training. When set to `true`, the decoder is run on `input[..., 0]` and `input[..., 1]` and the results are averaged to produce the final prediction.
- `decoder.conjoin_test` is the same as above, but for inference (e.g., validation / test). Note that this benchmark only contains training and test splits for each task. Therefore, to have a more principled evaluation, we randomly split the training data into training and validation sets (90/10) using the `dataset.train_val_split_seed` argument. We perform early stopping on the validation metric (accuracy) and repeat this for 5 random seeds. As with [pre-training](#pretraining-on-human-reference-genome), we can also launch the fine-tuning run as a batch job using the provided [`run_genomics_benchmark.sh`](./slurm_scripts/run_genomics_benchmark.sh) script. We also provide a helper shell script [`wrapper_run_genomics.sh`](./slurm_scripts/wrapper_run_genomics.sh) that can be used to launch multiple fine-tuning runs in parallel. Finally, the [`run_genomics_benchmark_cnn.sh`](./slurm_scripts/run_genomics_benchmark_cnn.sh) script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks. ### Nucleotide Transformer datasets The Nucleotide Transformer suite of tasks was proposed in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1). The data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).
We can launch a downstream fine-tuning run on one of the tasks using the sample command below: ```bash python -m train \ experiment=hg38/nucleotide_transformer \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \ dataset.dataset_name="${task}" \ dataset.train_val_split_seed=${seed} \ dataset.batch_size=${batch_size} \ dataset.rc_aug="${rc_aug}" \ +dataset.conjoin_test="${CONJOIN_TEST}" \ loader.num_workers=2 \ model._name_=dna_embedding_caduceus \ +model.config_path="" \ +model.conjoin_test=false \ +decoder.conjoin_train=true \ +decoder.conjoin_test=false \ optimizer.lr="1e-3" \ trainer.max_epochs=10 \ train.pretrained_model_path="" \ trainer.max_epochs=20 \ wandb=null ``` We can also launch as batch jobs (see [`run_nucleotide_transformer.sh`](./slurm_scripts/run_nucleotide_transformer.sh) and [`wrapper_run_nucleotide_transformer.sh`](./slurm_scripts/wrapper_run_nucleotide_transformer.sh) for details). ### eQTL SNP Variant Effect Prediction This task comes from the recently proposed Long Range Benchmark (LRB) in [Kao et al., 2023](https://llms4science-community.github.io/papers/LLMs4Bio24_paper_12.pdf). The data is available on HuggingFace: [InstaDeepAI/genomics-long-range-benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark). For this task we fit a model to the pre-trained and frozen embeddings of the DNA language models. Therefore, to perform the evaluation, we proceed in 2 steps: - **Step 1: Extract the embeddings** from the pre-trained model: Run the [`vep_embeddings.py`](./vep_embeddings.py) script to extract the embeddings from the pre-trained model. 
See the example below: ```bash torchrun \ --standalone \ --nnodes=1 \ --nproc-per-node=8 \ vep_embeddings.py \ --num_workers=2 \ --seq_len=131072 \ --bp_per_token=1 \ --embed_dump_batch_size=1 \ --name="caduceus-ps_downstream-seqlen=131k" \ --model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" \ --rcps ``` The `--rcps` flag is used to indicate that the model is reverse-complement equivariant. When using other models, set this flag to false with `--no-rcps`. To speed this step up, this script utilizes torch distributed data parallelism. Please refer to the slurm script provided in [`slurm_scripts/dump_vep_embeddings.sh`](./slurm_scripts/dump_vep_embeddings.sh) to launch this step as a batch job. - **Step 2: Fit an SVM model to the embeddings** using this notebook: [`vep_svm.ipynb`](./vep_svm.ipynb). ## Citation If you find our work useful, please cite our paper using the following: ``` @article{schiff2024caduceus, title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling}, author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr}, journal={arXiv preprint arXiv:2403.03234}, year={2024} } ``` ## Acknowledgements This repository is adapted from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna) and leverages much of the training, data loading, and logging infrastructure defined there. HyenaDNA was originally derived from the [S4](https://github.com/state-spaces/s4) and [Safari](https://github.com/HazyResearch/safari) repositories. We would like to thank Evan Trop and the [InstaDeep](https://www.instadeep.com/) team for useful discussions about the [Nucleotide Transformer leaderboard](https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark) and the Long Range Benchmark task. Finally, we would like to thank [MosaicML](https://www.mosaicml.com/) for providing compute resources for some of the pre-training experiments. 
================================================ FILE: caduceus/__init__.py ================================================ """Hugging Face config, model, and tokenizer for Caduceus. """ from .configuration_caduceus import CaduceusConfig from .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification from .tokenization_caduceus import CaduceusTokenizer ================================================ FILE: caduceus/configuration_caduceus.py ================================================ """Caduceus config for Hugging Face. """ from typing import Optional, Union from transformers import PretrainedConfig class CaduceusConfig(PretrainedConfig): """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.""" model_type = "caduceus" def __init__( self, # From original MambaConfig d_model: int = 2560, n_layer: int = 64, vocab_size: int = 50277, ssm_cfg: Optional[dict] = None, rms_norm: bool = True, residual_in_fp32: bool = True, fused_add_norm: bool = True, pad_vocab_size_multiple: int = 8, # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm norm_epsilon: float = 1e-5, # Used in init_weights initializer_cfg: Optional[dict] = None, # Caduceus-specific params bidirectional: bool = True, bidirectional_strategy: Union[str, None] = "add", bidirectional_weight_tie: bool = True, rcps: bool = False, complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead **kwargs, ): super().__init__(**kwargs) self.d_model = d_model self.n_layer = n_layer self.vocab_size = vocab_size self.ssm_cfg = ssm_cfg self.rms_norm = rms_norm self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm self.pad_vocab_size_multiple = pad_vocab_size_multiple self.norm_epsilon = norm_epsilon self.initializer_cfg = initializer_cfg self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy 
self.bidirectional_weight_tie = bidirectional_weight_tie self.rcps = rcps self.complement_map = complement_map ================================================ FILE: caduceus/modeling_caduceus.py ================================================ """Caduceus model for Hugging Face. """ import inspect import math from functools import partial from typing import Optional, Tuple, Union import torch from mamba_ssm.modules.mamba_simple import Mamba try: from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure except ImportError: from mamba_ssm.modules.block import Block # mambav2 file structure from torch import nn from torch.nn import functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput try: from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure except ImportError: try: from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None from .configuration_caduceus import CaduceusConfig from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock def create_block( d_model, ssm_cfg=None, norm_epsilon=1e-5, rms_norm=False, residual_in_fp32=False, fused_add_norm=False, layer_idx=None, bidirectional=True, bidirectional_strategy="add", bidirectional_weight_tie=True, rcps=False, device=None, dtype=None, ): """Create Caduceus block. 
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py """ if ssm_cfg is None: ssm_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} bidirectional_kwargs = { "bidirectional": bidirectional, "bidirectional_strategy": bidirectional_strategy, "bidirectional_weight_tie": bidirectional_weight_tie, } mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) block_cls = RCPSMambaBlock if rcps else Block # mambav2 compatibility if "mlp_cls" in inspect.signature(block_cls.__init__).parameters: block = block_cls( d_model, mixer_cls, mlp_cls=nn.Identity, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, ) else: block = block_cls( d_model, mixer_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, ) block.layer_idx = layer_idx return block class BiMambaWrapper(nn.Module): """Thin wrapper around Mamba to support bi-directionality.""" def __init__( self, d_model: int, bidirectional: bool = True, bidirectional_strategy: Optional[str] = "add", bidirectional_weight_tie: bool = True, **mamba_kwargs, ): super().__init__() if bidirectional and bidirectional_strategy is None: bidirectional_strategy = "add" # Default strategy: `add` if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]: raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!") self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy self.mamba_fwd = Mamba( d_model=d_model, **mamba_kwargs ) if bidirectional: self.mamba_rev = Mamba( d_model=d_model, **mamba_kwargs ) if bidirectional_weight_tie: # Tie in and out projections (where most of param count lies) self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight self.mamba_rev.in_proj.bias = 
self.mamba_fwd.in_proj.bias self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias else: self.mamba_rev = None def forward(self, hidden_states, inference_params=None): """Bidirectional-enabled forward pass hidden_states: (B, L, D) Returns: same shape as hidden_states """ out = self.mamba_fwd(hidden_states, inference_params=inference_params) if self.bidirectional: out_rev = self.mamba_rev( hidden_states.flip(dims=(1,)), # Flip along the sequence length dimension inference_params=inference_params ).flip(dims=(1,)) # Flip back for combining with forward hidden states if self.bidirectional_strategy == "add": out = out + out_rev elif self.bidirectional_strategy == "ew_multiply": out = out * out_rev else: raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!") return out class CaduceusEmbeddings(nn.Module): def __init__( self, config: CaduceusConfig, device=None, dtype=None, ): super().__init__() factory_kwargs = {"device": device, "dtype": dtype} if config.rcps: self.word_embeddings = RCPSEmbedding( config.vocab_size, config.d_model, config.complement_map, **factory_kwargs ) else: self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs) def forward(self, input_ids): """ input_ids: (batch, seqlen) """ return self.word_embeddings(input_ids) class CaduceusMixerModel(nn.Module): def __init__( self, config: CaduceusConfig, device=None, dtype=None, ) -> None: super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.fused_add_norm = config.fused_add_norm self.rcps = config.rcps self.residual_in_fp32 = config.residual_in_fp32 self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) # Mamba changes the order of residual and layer norm: # Instead of LN -> Attn / MLP -> Add, we do: # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and # the main branch (output of MLP / Mixer). 
The model definition is unchanged. # This is for performance reason: we can fuse add + layer_norm. if config.fused_add_norm: if layer_norm_fn is None or rms_norm_fn is None: raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") self.layers = nn.ModuleList( [ create_block( config.d_model, ssm_cfg=config.ssm_cfg, norm_epsilon=config.norm_epsilon, rms_norm=config.rms_norm, residual_in_fp32=config.residual_in_fp32, fused_add_norm=config.fused_add_norm, layer_idx=i, bidirectional=config.bidirectional, bidirectional_strategy=config.bidirectional_strategy, bidirectional_weight_tie=config.bidirectional_weight_tie, rcps=config.rcps, **factory_kwargs, ) for i in range(config.n_layer) ] ) norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( config.d_model, eps=config.norm_epsilon, **factory_kwargs ) self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f) def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): """Mixer forward.""" all_hidden_states = [] if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids) residual = None for layer in self.layers: if output_hidden_states: all_hidden_states.append(hidden_states) # TODO: Add support for gradient checkpointing hidden_states, residual = layer( hidden_states, residual, inference_params=None ) if not self.fused_add_norm: if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False) else: residual = (hidden_states + residual) if residual is not None else hidden_states hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) else: fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states_fwd = fused_add_norm_fn( hidden_states[..., :hidden_states.shape[-1] // 2], 
self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., :hidden_states.shape[-1] // 2], prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states_rc = fused_add_norm_fn( hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]), self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]), prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1) else: # Set prenorm=False here since we don't need the residual hidden_states = fused_add_norm_fn( hidden_states, self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual, prenorm=False, residual_in_fp32=self.residual_in_fp32, ) if output_hidden_states: all_hidden_states.append(hidden_states) return hidden_states, all_hidden_states def cross_entropy(logits, y, ignore_index=-100): """Cross entropy loss.""" logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) return F.cross_entropy(logits, y, ignore_index=ignore_index) def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100): """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).""" logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none") loss_weights = loss_weights.view(-1) loss_weights[y == ignore_index] = 0.0 # TODO: Follows GPN implementation, but should we remove weight normalization? return (ce * (loss_weights / loss_weights.sum())).sum() class CaduceusPreTrainedModel(PreTrainedModel): """PreTrainedModel wrapper for Caduceus backbone.""" config_class = CaduceusConfig base_model_prefix = "caduceus" supports_gradient_checkpointing = False _no_split_modules = ["BiMambaWrapper"] def _init_weights( self, module, initializer_range=0.02, # Now only used for embedding layer. 
**kwargs, ): """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" n_layer = self.config.n_layer initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {} rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True) initializer_range = initialized_cfg.get("initializer_range", initializer_range) n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1) if isinstance(module, nn.Linear): if module.bias is not None: if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of # residual layers. # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) class Caduceus(CaduceusPreTrainedModel): """Caduceus model that can be instantiated using HF patterns.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config) if config.rcps: assert config.complement_map is not None, "Complement map must be provided for RCPS." 
# Adjust vocab size and complement maps if vocab padding is set. if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple) if config.complement_map is not None and config.vocab_size > len(config.complement_map): for i in range(len(config.complement_map), config.vocab_size): config.complement_map[i] = i self.config = config factory_kwargs = {"device": device, "dtype": dtype} self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs) def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states, all_hidden_states = self.backbone( input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states ) if return_dict: return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=all_hidden_states if output_hidden_states else None ) elif output_hidden_states: return hidden_states, all_hidden_states else: return hidden_states class CaduceusForMaskedLM(CaduceusPreTrainedModel): """HF-compatible Caduceus model for masked language modeling.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config, **kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.caduceus = Caduceus(config, **factory_kwargs, **kwargs) if config.rcps: self.lm_head = RCPSLMHead( complement_map=self.config.complement_map, # Use caduceus config as it might have been updated vocab_size=self.config.vocab_size, # Use caduceus config as it might have been 
updated true_dim=config.d_model, dtype=dtype ) else: self.lm_head = nn.Linear( config.d_model, self.config.vocab_size, # Use caduceus config as it might have been updated bias=False, **factory_kwargs ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings def set_input_embeddings(self, value): if self.config.rcps: raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.") self.caduceus.backbone.embeddings.word_embeddings = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): """Overrides output embeddings.""" if self.config.rcps: raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.") self.lm_head = new_embeddings def tie_weights(self): """Tie weights, accounting for RCPS.""" if self.config.rcps: self.lm_head.set_weight(self.get_input_embeddings().weight) else: super().tie_weights() def get_decoder(self): """Get decoder (backbone) for the model.""" return self.caduceus def set_decoder(self, decoder): """Set decoder (backbone) for the model.""" self.caduceus = decoder def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, loss_weights: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MaskedLMOutput]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.caduceus( input_ids=input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = 
outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: if loss_weights is not None: loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id) else: loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, ) class CaduceusForSequenceClassification(CaduceusPreTrainedModel): def __init__( self, config: CaduceusConfig, pooling_strategy: str = "mean", conjoin_train: bool = False, conjoin_eval: bool = False, device=None, dtype=None, **kwargs): super().__init__(config, **kwargs) if pooling_strategy not in ["mean", "max", "first", "last"]: raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.") self.pooling_strategy = pooling_strategy factory_kwargs = {"device": device, "dtype": dtype} self.num_labels = kwargs.get("num_labels", config.num_labels) self.caduceus = Caduceus(config, **factory_kwargs, **kwargs) self.score = nn.Linear(config.d_model, self.num_labels, bias=False) self.conjoin_train = conjoin_train self.conjoin_eval = conjoin_eval # Initialize weights and apply final processing self.post_init() self.init_scorer() def init_scorer(self, initializer_range=0.02): initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \ if self.config.initializer_cfg is not None else initializer_range self.score.weight.data.normal_(std=initializer_range) def get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings def set_input_embeddings(self, value): if self.config.rcps: raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.") self.caduceus.backbone.embeddings.word_embeddings = value def pool_hidden_states(self, hidden_states, sequence_length_dim=1): """Pools 
hidden states along sequence length dimension.""" if self.pooling_strategy == "mean": # Mean pooling along sequence length dimension return hidden_states.mean(dim=sequence_length_dim) if self.pooling_strategy == "max": # Max pooling along sequence length dimension return hidden_states.max(dim=sequence_length_dim).values if self.pooling_strategy == "last": # Use embedding of last token in the sequence return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...] if self.pooling_strategy == "first": # Use embedding of first token in the sequence return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...] def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Get hidden representations from the backbone if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS transformer_outputs = self.caduceus( input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = torch.stack( [ transformer_outputs[0][..., :self.config.d_model], torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2]) ], dim=-1 ) elif self.conjoin_train or (self.conjoin_eval and not self.training): # For conjoining / post-hoc conjoining assert input_ids is not None, "`input_ids` must be provided for conjoining." 
assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands." transformer_outputs = self.caduceus( input_ids[..., 0], inputs_embeds=None, output_hidden_states=output_hidden_states, return_dict=return_dict, ) transformer_outputs_rc = self.caduceus( input_ids[..., 1], inputs_embeds=None, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # Stack along channel dimension (dim=-1) hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1) else: transformer_outputs = self.caduceus( input_ids, inputs_embeds=None, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] # Pool and get logits pooled_hidden_states = self.pool_hidden_states(hidden_states) # Potentially run `score` twice (with parameters shared) for conjoining if hidden_states.ndim == 4: # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps logits_fwd = self.score(pooled_hidden_states[..., 0]) logits_rc = self.score(pooled_hidden_states[..., 1]) logits = (logits_fwd + logits_rc) / 2 else: logits = self.score(pooled_hidden_states) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = F.mse_loss(logits.squeeze(), labels.squeeze()) else: loss = F.mse_loss(logits, labels) elif self.config.problem_type == "single_label_classification": loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = F.binary_cross_entropy_with_logits(logits, labels) if not return_dict: 
output = (logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=transformer_outputs.hidden_states, ) ================================================ FILE: caduceus/modeling_rcps.py ================================================ """Reverse-complement equivariant modules. """ from collections import OrderedDict from typing import Optional import torch from torch import Tensor from torch import nn from torch.nn import functional as F try: from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure except ImportError: try: from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None class RCPSEmbedding(nn.Module): """Embedding layer that supports reverse-complement equivariance.""" def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs): """ Args: vocab_size: Size of vocabulary. d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim). complement_map: Dictionary mapping each token id to its complement. """ super().__init__() self.register_buffer( "complement_map", torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long) ) self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) @property def weight(self): """Embedding weights.""" return self.embedding.weight def set_weight(self, value): """Set embedding weights.""" self.embedding.weight = value def rc(self, x): """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids.""" return torch.gather( self.complement_map.unsqueeze(0).expand(x.shape[0], -1), dim=1, index=torch.flip(x, dims=[-1]) ) def forward(self, input_ids): """Reverse-complement equivariant forward pass. 
This embedding module doubles the output dimensionality to support reverse-complement equivariance. Args: input_ids: Input tensor of shape (batch_size, seq_len) Returns: Embedding tensor of shape (batch_size, seq_len, d_model * 2) """ fwd_out = self.embedding(input_ids) rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1]) return torch.cat([fwd_out, rc_out], dim=-1) class RCPSWrapper(nn.Module): """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module. See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details. """ def __init__(self, submodule: nn.Module): super().__init__() self.submodule = submodule @staticmethod def rc(x): """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions.""" return torch.flip(x, dims=[-2, -1]) def forward(self, x, **kwargs): """Reverse-complement equivariant forward pass. Args: x: Input tensor of shape (batch_size, seq_len, channels) Returns: Output tensor of shape (batch_size, seq_len, channels) """ n_channels = x.shape[-1] # Run submodule along sequence fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs) # Run submodule along rc-sequence rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs) # Concatenate along channel dimension (dim=-1) return torch.cat([fwd_out, self.rc(rc_out)], dim=-1) class RCPSAddNormWrapper(RCPSWrapper): """RC equivariant AddNorm layer.""" def __init__(self, submodule: nn.Module): super().__init__(submodule) def forward(self, x, residual=None, prenorm=False): """ Args: x: Input tensor of shape (batch_size, seq_len, channels) residual: Residual tensor of shape (batch_size, seq_len, channels) or None. prenorm: Whether to return residual. 
""" n_channels = x.shape[-1] if residual is None: residual = x x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype)) x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype)) x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1) else: residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2] x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype)) residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:]) x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype)) residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1) x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1) return x if not prenorm else (x, residual) class RCPSMambaBlock(nn.Module): def __init__( self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, device=None, # Keep for consistency with original Mamba Block dtype=None, # Keep for consistency with original Mamba Block ): """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py """ super().__init__() self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm self.mixer = RCPSWrapper(mixer_cls(dim)) norm_f = norm_cls(dim) self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f) if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" assert isinstance( self.norm, (nn.LayerNorm, RMSNorm) ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" def forward( self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None ): r"""Pass the input through the encoder layer. Args: hidden_states: the sequence to the encoder layer (required). residual: hidden_states = Mixer(LN(residual)). inference_params: inference parameters for mixer. 
""" if not self.fused_add_norm: hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True) if self.residual_in_fp32: residual = residual.to(torch.float32) else: fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn hidden_states_fwd, residual_fwd = fused_add_norm_fn( hidden_states[..., hidden_states.shape[-1] // 2:], self.norm.weight, self.norm.bias, residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None, prenorm=True, residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, ) hidden_states_rc, residual_rc = fused_add_norm_fn( hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]), self.norm.weight, self.norm.bias, residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None, prenorm=True, residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, ) hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1) residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1) hidden_states = self.mixer(hidden_states, inference_params=inference_params) return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): """Allocate inference cache for mixer. Keep for compatibility with original Mamba Block. """ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) class RCPSLMHead(nn.Module): """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs.""" def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs): """ `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement equivariant, i.e. 0.5 times the actual input dim. 
""" super().__init__() self.register_buffer( "complement_map", torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long) ) self.true_dim = true_dim self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs) @property def weight(self): """LM head weights.""" return self.lm_head.weight def set_weight(self, value): """Set LM head weights.""" self.lm_head.weight = value def forward(self, x): """ Args: x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim. """ n_channels = x.shape[-1] assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels." fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias) rc_logits = F.linear( torch.flip(x[..., n_channels // 2:], dims=[-1]), self.weight[self.complement_map, :], bias=self.lm_head.bias ) return fwd_logits + rc_logits ================================================ FILE: caduceus/tests/test_rcps.py ================================================ """Tests for RCPS modules. 
""" import pytest import torch from torch import nn from torch.nn import functional as F try: # Legacy mambav1 file structure from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: try: # mambav2 file structure from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None from caduceus.modeling_rcps import ( RCPSEmbedding, RCPSAddNormWrapper, RCPSLMHead, RCPSWrapper ) from caduceus.modeling_caduceus import ( CaduceusConfig, CaduceusMixerModel, CaduceusForMaskedLM, create_block ) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [512]) @pytest.mark.parametrize("d_model", [256]) @pytest.mark.parametrize("dtype", [torch.float32]) def test_rcps_embedding(batch_size, seq_len, d_model, dtype): # Set tolerance device = torch.device("cpu") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Define complement map str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6} complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} complement_map = { str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v for k, v in str_to_id.items() } vocab_size = 12 pad_vocab_size_multiple = 8 if vocab_size % pad_vocab_size_multiple != 0: vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) if vocab_size > len(complement_map): for i in range(len(complement_map), vocab_size): complement_map[i] = i # Generate random sequences input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device) rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device) # Test RC equivariance of embedding layer factory_kwargs = {"device": device, "dtype": dtype} embedding = RCPSEmbedding( 
vocab_size=vocab_size, d_model=d_model, complement_map=complement_map, **factory_kwargs ).to(device) out_embed = embedding(input_ids) rc_out_embed = torch.flip(embedding(rc_input_ids), dims=[-2, -1]) # Test that channels are 2 * d_model assert tuple(out_embed.size()) == (batch_size, seq_len, d_model * 2) assert tuple(rc_out_embed.size()) == (batch_size, seq_len, d_model * 2) # Test that RC equivariance holds assert torch.allclose(out_embed.detach(), rc_out_embed.detach(), rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [1024]) @pytest.mark.parametrize("d_model", [128]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) def test_rcps_wrapper(batch_size, seq_len, d_model, dtype): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Generate random sequence with 2 * d_model channels x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype) rc_x = torch.flip(x, dims=[-2, -1]) factory_kwargs = {"device": device, "dtype": dtype} module = nn.Sequential( nn.Linear(d_model, d_model, bias=False, **factory_kwargs), nn.ReLU(), nn.Linear(d_model, d_model*2, bias=True, **factory_kwargs), nn.ReLU(), nn.Linear(d_model * 2, d_model, bias=True, **factory_kwargs) ) # Test RC equivariance of wrapper rcps_module = RCPSWrapper(module).to(device) out = rcps_module(x) rc_out = torch.flip(rcps_module(rc_x), dims=[-2, -1]) assert out.size() == x.size() assert rc_out.size() == x.size() assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [1024]) @pytest.mark.parametrize("d_model", [128]) @pytest.mark.parametrize("prenorm", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16]) def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, 
prenorm, dtype): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Generate random sequence with 2 * d_model channels x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype) rc_x = torch.flip(x, dims=[-2, -1]) factory_kwargs = {"device": device, "dtype": dtype} norm = RMSNorm(d_model, eps=1e-5, **factory_kwargs) # Test RC equivariance of wrapper rcps_module = RCPSAddNormWrapper(norm).to(device) out = rcps_module(x, prenorm=prenorm) if prenorm: # returns tuple rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in rcps_module(rc_x, prenorm=prenorm)]) for f, r in zip(out, rc_out): assert f.size() == x.size() assert r.size() == x.size() assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol) else: rc_out = torch.flip(rcps_module(rc_x, prenorm=prenorm), dims=[-2, -1]) assert out.size() == x.size() assert rc_out.size() == x.size() assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [1024]) @pytest.mark.parametrize("d_model", [128]) @pytest.mark.parametrize("bidirectional", [True, False]) @pytest.mark.parametrize("fused_add_norm", [True, False]) @pytest.mark.parametrize("dtype", [torch.float16]) def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirectional, fused_add_norm, dtype): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Generate random sequence with 2 * d_model channels x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype) rc_x = torch.flip(x, dims=[-2, -1]) ssm_cfg = { "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": 
"random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True } factory_kwargs = {"device": device, "dtype": dtype} mamba_block = create_block( d_model, ssm_cfg=ssm_cfg, norm_epsilon=1e-5, rms_norm=True, residual_in_fp32=True, fused_add_norm=fused_add_norm, layer_idx=0, bidirectional=bidirectional, bidirectional_strategy="add", bidirectional_weight_tie=True, rcps=True, **factory_kwargs ) # Test RC equivariance of wrapper out = mamba_block(x, residual=None) rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=None)]) for f, r in zip(out, rc_out): assert f.size() == x.size() assert r.size() == x.size() assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol) out = mamba_block(x, residual=x) rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=rc_x)]) for f, r in zip(out, rc_out): assert f.size() == x.size() assert r.size() == x.size() assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [2, 4]) @pytest.mark.parametrize("seq_len", [1, 1024, 2048]) @pytest.mark.parametrize("d_model", [2, 128, 256]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) def test_rcps_lm_head(batch_size, seq_len, d_model, dtype): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Define complement map str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6} complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} complement_map = { str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v for k, v in str_to_id.items() } factory_kwargs = {"device": device, "dtype": dtype} vocab_size = 12 if vocab_size > len(complement_map): for i in range(len(complement_map), vocab_size): complement_map[i] = i # Instantiate LM head 
lm_head = RCPSLMHead( complement_map=complement_map, vocab_size=vocab_size, true_dim=d_model, **factory_kwargs ) # Generate random sequence with 2 * d_model channels x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype) rc_x = torch.flip(x, dims=[-2, -1]) # Test RC equivariance of LM head out = lm_head(x) rc_out = lm_head(rc_x) assert tuple(out.size()) == (batch_size, seq_len, vocab_size) assert tuple(rc_out.size()) == (batch_size, seq_len, vocab_size) assert torch.allclose( out.detach(), torch.flip(rc_out.detach()[..., lm_head.complement_map], dims=[1]), rtol=rtol, atol=atol ) assert torch.allclose( F.softmax(out, dim=-1).detach(), torch.flip(F.softmax(rc_out, dim=-1).detach()[..., lm_head.complement_map], dims=[1]), rtol=rtol, atol=atol ) @pytest.mark.parametrize("batch_size", [2, 4]) @pytest.mark.parametrize("seq_len", [1024, 2048]) @pytest.mark.parametrize("n_layer", [1, 2, 3]) @pytest.mark.parametrize("d_model", [128, 256]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("fused_add_norm", [True, False]) @pytest.mark.parametrize("bidirectional", [False, True]) @pytest.mark.parametrize("bidirectional_weight_tie", [False, True]) def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fused_add_norm, bidirectional, bidirectional_weight_tie): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Define complement map str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6} complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} complement_map = { str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v for k, v in str_to_id.items() } # Setup CaduceusConfig initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1} ssm_cfg = { "d_state": 16, 
"d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True } config = CaduceusConfig( d_model=d_model, n_layer=n_layer, vocab_size=12, ssm_cfg=ssm_cfg, rms_norm=True, residual_in_fp32=False, fused_add_norm=fused_add_norm, pad_vocab_size_multiple=8, norm_epsilon=1e-5, initializer_cfg=initializer_cfg, bidirectional=bidirectional, bidirectional_strategy="add", bidirectional_weight_tie=bidirectional_weight_tie, rcps=True, complement_map=complement_map, ) factory_kwargs = {"device": device, "dtype": dtype} # Instantiate model backbone = CaduceusMixerModel( config, **factory_kwargs, ).to(device) # Generate random sequences input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device) rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device) # Test RC equivariance of rc backbone out = backbone(input_ids)[0] rc_out = backbone(rc_input_ids)[0] if isinstance(rc_out, tuple): rc_out = tuple([torch.flip(r, dims=[1, 2]) for r in rc_out]) for f, r in zip(out, rc_out): assert f.size() == (batch_size, seq_len, d_model * 2) assert r.size() == (batch_size, seq_len, d_model * 2) assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol) else: # Hidden state size should double assert tuple(out.size()) == (batch_size, seq_len, d_model * 2) assert tuple(rc_out.size()) == (batch_size, seq_len, d_model * 2) assert torch.allclose(out.detach(), torch.flip(rc_out.detach(), dims=[1, 2]), rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [2, 4]) @pytest.mark.parametrize("seq_len", [1024, 2048]) @pytest.mark.parametrize("n_layer", [1, 3, 4]) @pytest.mark.parametrize("d_model", [128, 256]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("bidirectional", [False, True]) @pytest.mark.parametrize("bidirectional_weight_tie", 
[False, True]) def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Define complement map str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6} complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} complement_map = { str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v for k, v in str_to_id.items() } # Setup CaduceusConfig initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1} ssm_cfg = { "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True } config = CaduceusConfig( d_model=d_model, n_layer=n_layer, vocab_size=12, ssm_cfg=ssm_cfg, rms_norm=True, residual_in_fp32=False, fused_add_norm=True, pad_vocab_size_multiple=8, norm_epsilon=1e-5, initializer_cfg=initializer_cfg, bidirectional=bidirectional, bidirectional_strategy="add", bidirectional_weight_tie=bidirectional_weight_tie, rcps=True, complement_map=complement_map, ) factory_kwargs = {"device": device, "dtype": dtype} # Instantiate model mamba_lm = CaduceusForMaskedLM( config=config, **factory_kwargs, ).to(device) # Generate random sequences input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device) rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device) # Test RC equivariance of rc backbone out = mamba_lm(input_ids) rc_out = mamba_lm(rc_input_ids) if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % 
config.pad_vocab_size_multiple) assert tuple(out.logits.size()) == (batch_size, seq_len, config.vocab_size) assert tuple(rc_out.logits.size()) == (batch_size, seq_len, config.vocab_size) assert torch.allclose( out.logits.detach(), torch.flip(rc_out.logits.detach()[..., mamba_lm.lm_head.complement_map], dims=[1]), rtol=rtol, atol=atol ) assert torch.allclose( F.softmax(out.logits, dim=-1).detach(), torch.flip(F.softmax(rc_out.logits, dim=-1).detach()[..., mamba_lm.lm_head.complement_map], dims=[1]), rtol=rtol, atol=atol ) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seq_len", [1024]) @pytest.mark.parametrize("n_layer", [2]) @pytest.mark.parametrize("d_model", [128]) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("bidirectional", [True, False]) @pytest.mark.parametrize("bidirectional_weight_tie", [True]) def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie): # Set tolerance device = torch.device("cuda") rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3) if dtype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 # Set seed torch.random.manual_seed(0) # Define complement map str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6} complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} complement_map = { str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v for k, v in str_to_id.items() } # Setup CaduceusConfig initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1} ssm_cfg = { "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random", "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True } config = CaduceusConfig( d_model=d_model, n_layer=n_layer, vocab_size=12, ssm_cfg=ssm_cfg, rms_norm=True, residual_in_fp32=False, fused_add_norm=True, 
pad_vocab_size_multiple=8, norm_epsilon=1e-5, initializer_cfg=initializer_cfg, bidirectional=bidirectional, bidirectional_strategy="add", bidirectional_weight_tie=bidirectional_weight_tie, rcps=True, complement_map=complement_map, ) factory_kwargs = {"device": device, "dtype": dtype} # Instantiate model backbone = CaduceusMixerModel( config, **factory_kwargs, ).to(device) # Generate random sequences input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device) rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device) # Test RC Invariance when collapsing output of backbone out = backbone(input_ids)[0] out_collapse = (out[..., :d_model] + torch.flip(out[..., d_model:], dims=[1, 2])) / 2 rc_out = backbone(rc_input_ids)[0] rc_out_collapse = (rc_out[..., :d_model] + torch.flip(rc_out[..., d_model:], dims=[1, 2])) / 2 # Hidden state size should be d_model assert tuple(out_collapse.size()) == (batch_size, seq_len, d_model) assert tuple(rc_out_collapse.size()) == (batch_size, seq_len, d_model) assert torch.allclose(out_collapse.detach(), rc_out_collapse.detach(), rtol=rtol, atol=atol) ================================================ FILE: caduceus/tokenization_caduceus.py ================================================ """Character tokenizer for Hugging Face. """ from typing import List, Optional, Dict, Sequence, Tuple from transformers import PreTrainedTokenizer class CaduceusTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids"] def __init__(self, model_max_length: int, characters: Sequence[str] = ("A", "C", "G", "T", "N"), complement_map=None, bos_token="[BOS]", eos_token="[SEP]", sep_token="[SEP]", cls_token="[CLS]", pad_token="[PAD]", mask_token="[MASK]", unk_token="[UNK]", **kwargs): """Character tokenizer for Hugging Face transformers. 
Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py Args: model_max_length (int): Model maximum sequence length. characters (Sequence[str]): List of desired characters. Any character which is not included in this list will be replaced by a special token called [UNK] with id=6. Following is a list of the special tokens with their corresponding ids: "[CLS]": 0 "[SEP]": 1 "[BOS]": 2 "[MASK]": 3 "[PAD]": 4 "[RESERVED]": 5 "[UNK]": 6 an id (starting at 7) will be assigned to each character. complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character. """ if complement_map is None: complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"} self.characters = characters self.model_max_length = model_max_length self._vocab_str_to_int = { "[CLS]": 0, "[SEP]": 1, "[BOS]": 2, "[MASK]": 3, "[PAD]": 4, "[RESERVED]": 5, "[UNK]": 6, **{ch: i + 7 for i, ch in enumerate(self.characters)}, } self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} add_prefix_space = kwargs.pop("add_prefix_space", False) padding_side = kwargs.pop("padding_side", "left") self._complement_map = {} for k, v in self._vocab_str_to_int.items(): complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v self._complement_map[self._vocab_str_to_int[k]] = complement_id super().__init__( bos_token=bos_token, eos_token=eos_token, sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, mask_token=mask_token, unk_token=unk_token, add_prefix_space=add_prefix_space, model_max_length=model_max_length, padding_side=padding_side, **kwargs, ) @property def vocab_size(self) -> int: return len(self._vocab_str_to_int) @property def complement_map(self) -> Dict[int, int]: return self._complement_map def _tokenize(self, text: str, **kwargs) -> List[str]: return list(text.upper()) # Convert all base pairs to uppercase def _convert_token_to_id(self, token: str) -> int: 
return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"]) def _convert_id_to_token(self, index: int) -> str: return self._vocab_int_to_str[index] def convert_tokens_to_string(self, tokens): return "".join(tokens) # Note: this operation has lost info about which base pairs were originally lowercase def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True, ) result = ([0] * len(token_ids_0)) + [1] if token_ids_1 is not None: result += ([0] * len(token_ids_1)) + [1] return result def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: sep = [self.sep_token_id] # cls = [self.cls_token_id] result = token_ids_0 + sep if token_ids_1 is not None: result += token_ids_1 + sep return result def get_vocab(self) -> Dict[str, int]: return self._vocab_str_to_int # Fixed vocabulary with no vocab file def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple: return () ================================================ FILE: caduceus_env.yml ================================================ name: caduceus_env channels: - pytorch - anaconda - nvidia - defaults dependencies: - cuda-nvcc=11.7.99 - pip=23.3.1 - python=3.8 - pytorch=2.2.0 - torchaudio=2.2.0 - torchaudio=2.2.0 - torchdata=0.7.1 - torchmetrics=1.2.1 - torchtext=0.17.0 - torchvision=0.17.0 - pytorch-cuda=12.1 - pip: - biopython==1.81 - datasets==2.15.0 - einops==0.7.0 - enformer-pytorch==0.8.8 - fsspec==2023.10.0 - genomic-benchmarks==0.0.9 - git-lfs==1.6 - h5py==3.10.0 - huggingface-hub==0.24.7 - hydra-core==1.3.2 - ipdb==0.13.13 - matplotlib==3.7.4 - notebook==7.1.1 - nvitop==1.3.2 - omegaconf==2.3.0 - pandas==2.0.3 - pyfaidx==0.8.1.1 - pysam==0.22.0 - 
pytest==8.0.2 - pytorch-lightning==1.8.6 - rich==13.7.0 - seaborn==0.13.2 - scikit-learn==1.3.2 - timm==0.9.16 - tqdm==4.66.1 - transformers==4.38.1 - triton==2.2.0 - wandb==0.13.5 - flash-attn==2.5.6 - causal-conv1d===1.2.0.post2 - mamba-ssm==1.2.0.post1 ================================================ FILE: configs/callbacks/base.yaml ================================================ learning_rate_monitor: # _target_: pytorch_lightning.callbacks.LearningRateMonitor logging_interval: ${train.interval} timer: # _target_: callbacks.timer.Timer step: True inter_step: False epoch: True val: True params: # _target_: callbacks.params.ParamsLog total: True trainable: True fixed: True ================================================ FILE: configs/callbacks/checkpoint.yaml ================================================ model_checkpoint: monitor: ${train.monitor} # name of the logged metric which determines when model is improving mode: ${train.mode} # can be "max" or "min" save_top_k: 1 # save k best models (determined by above metric) save_last: False # True = additionally always save model from last epoch dirpath: "checkpoints/" filename: ${train.monitor} auto_insert_metric_name: False verbose: True model_checkpoint_every_n_steps: monitor: train/loss # name of the logged metric which determines when model is improving mode: min # can be "max" or "min" save_top_k: 0 # Do not save any "best" models; this callback is being used to save every n train steps save_last: True # additionally always save model from last epoch dirpath: "checkpoints/" filename: train/loss auto_insert_metric_name: False verbose: True every_n_train_steps: 100 #model_checkpoint_every_epoch: # monitor: trainer/epoch # name of the logged metric which determines when model is improving # mode: max # can be "max" or "min" # save_top_k: 1 # Do not save any "best" models; this callback is being used to save every n train steps # save_last: False # additionally always save model from last epoch # dirpath: 
"checkpoints/" # filename: null # auto_insert_metric_name: False # verbose: True # every_n_epochs: 1 ================================================ FILE: configs/callbacks/gpu_affinity.yaml ================================================ gpu_affinity: _name_: gpu_affinity ================================================ FILE: configs/callbacks/rich.yaml ================================================ rich_model_summary: max_depth: 2 rich_progress_bar: refresh_rate_per_second: 1.0 ================================================ FILE: configs/callbacks/val_every_n_global_steps.yaml ================================================ val_every_n_global_steps: every_n: 10000 ================================================ FILE: configs/callbacks/wandb.yaml ================================================ defaults: - default watch_model: _target_: src.callbacks.wandb_callbacks.WatchModel log: "all" log_freq: 100 upload_code_as_artifact: _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact code_dir: ${work_dir}/src upload_ckpts_as_artifact: _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact ckpt_dir: "checkpoints/" upload_best_only: True log_f1_precision_recall_heatmap: _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap log_confusion_matrix: _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix log_image_predictions: _target_: src.callbacks.wandb_callbacks.LogImagePredictions num_samples: 8 ================================================ FILE: configs/config.yaml ================================================ # @package _global_ defaults: - _self_ - experiment: ??? # - model: ??? # Model backbone # - pipeline: ??? # Specifies collection of configs, equivalent to next 5 lines # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order) # # - loader: default # Dataloader (e.g. 
handles batches) # # - dataset: cifar # Defines the data (x and y pairs) # # - task: multiclass_classification # Defines loss and metrics # # - encoder: null # Interface between data and model # # - decoder: null # Interface between model and targets # Additional arguments used to configure the training loop # Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer train: seed: 0 # These three options are used by callbacks (checkpoint, monitor) and scheduler # Most of them are task dependent and are set by the pipeline interval: ??? # Should be specified by scheduler. Also used by LR monitor monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer ema: 0.0 # Moving average model for validation test: True # Test after training debug: False # Special settings to make debugging more convenient ignore_warnings: False # Disable python warnings optimizer_param_grouping: bias_weight_decay: False normalization_weight_decay: False # These control state passing between batches state: mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ] n_context: 0 # How many steps to use as memory context. 
Must be >= 0 or None (null), meaning infinite context n_context_eval: ${.n_context} # Context at evaluation time # Convenience keys to allow grouping runs ckpt: checkpoints/last.ckpt # Resume training disable_dataset: False # Disable dataset loading validate_at_start: false pretrained_model_path: null # Path to pretrained model pretrained_model_strict_load: true # If true, the pretrained checkpoint must match the model exactly; set false to allow partial loading (e.g. reusing only the backbone) pretrained_model_state_hook: # Hook called on the loaded model's state_dict _name_: null post_init_hook: # After initializing model, call method on model _name_: null layer_decay: # Used for ImageNet finetuning _name_: null decay: 0.7 # We primarily use wandb so this is moved to top level in the config for convenience # Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging # If other loggers are added, it would make sense to put this one level lower under train/ or logger/ wandb: project: dna group: "" job_type: training mode: online # choices=['online', 'offline', 'disabled'] name: null save_dir: "." id: ${.name} # pass correct id to resume experiment! # Below options should not need to be specified # entity: "" # set to name of your wandb team or just remove it # log_model: False # prefix: "" # job_type: "train" # tags: [] hydra: run: dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f} job: chdir: true ================================================ FILE: configs/dataset/genomic_benchmark.yaml ================================================ _name_: genomic_benchmark train_val_split_seed: ${train.seed} # Used for train/validation splitting dataset_name: dummy_mouse_enhancers_ensembl dest_path: null max_length: ${.${.dataset_name}.max_length} max_length_val: ${.max_length} max_length_test: ${.max_length} d_output: ${.${.dataset_name}.classes} use_padding: True padding_side: 'left' add_eos: False batch_size: 128 train_len: ${.${.dataset_name}.train_len} __l_max: ${.max_length} shuffle: true # set this as default!
# these are used to find the right attributes automatically for each dataset dummy_mouse_enhancers_ensembl: train_len: 1210 classes: 2 max_length: 1024 demo_coding_vs_intergenomic_seqs: train_len: 100_000 classes: 2 max_length: 200 demo_human_or_worm: train_len: 100_000 classes: 2 max_length: 200 human_enhancers_cohn: train_len: 27791 classes: 2 max_length: 500 human_enhancers_ensembl: train_len: 154842 classes: 2 max_length: 512 human_ensembl_regulatory: train_len: 289061 classes: 3 max_length: 512 human_nontata_promoters: train_len: 36131 classes: 2 max_length: 251 human_ocr_ensembl: train_len: 174756 classes: 2 max_length: 512 # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings # name num_seqs num_classes median len std # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 # demo_human_or_worm 100_000 2 200 0 # human_enhancers_cohn 27791 2 500 0 # human_enhancers_ensembl 154842 2 269 122.6 # human_ensembl_regulatory 289061 3 401 184.3 # human_nontata_promoters 36131 2 251 0 # human_ocr_ensembl 174756 2 315 108.1 ================================================ FILE: configs/dataset/hg38.yaml ================================================ _name_: hg38 bed_file: null fasta_file: null dataset_name: hg38 tokenizer_name: null cache_dir: null max_length: 1024 add_eos: True batch_size: 8 # per GPU batch_size_eval: ${eval:${.batch_size} * 2} num_workers: 4 # For preprocessing only shuffle: True __train_len: 34021 __l_max: ${.max_length} ================================================ FILE: configs/dataset/nucleotide_transformer.yaml ================================================ _name_: nucleotide_transformer # this links to the overall SequenceDataset of all nucleotide transformer datasets train_val_split_seed: ${train.seed} # Used for train/validation splitting dataset_name: enhancers # this specifies which dataset in nuc trx dest_path: null # path to overall nuc trx datasets 
max_length: ${.${.dataset_name}.max_length} d_output: ${.${.dataset_name}.classes} use_padding: True padding_side: left add_eos: False batch_size: 256 train_len: ${.${.dataset_name}.train_len} __l_max: ${.max_length} shuffle: true # set this as default! metric: ${.${.dataset_name}.metric} # these are used to find the right attributes automatically for each dataset enhancers: train_len: 14968 classes: 2 max_length: 200 metric: mcc enhancers_types: train_len: 14968 classes: 3 max_length: 200 metric: mcc H3: train_len: 13468 classes: 2 max_length: 500 metric: mcc H3K4me1: train_len: 28509 classes: 2 max_length: 500 metric: mcc H3K4me2: train_len: 27614 classes: 2 max_length: 500 metric: mcc H3K4me3: train_len: 33119 classes: 2 max_length: 500 metric: mcc H3K9ac: train_len: 25003 classes: 2 max_length: 500 metric: mcc H3K14ac: train_len: 29743 classes: 2 max_length: 500 metric: mcc H3K36me3: train_len: 31392 classes: 2 max_length: 500 metric: mcc H3K79me3: train_len: 25953 classes: 2 max_length: 500 metric: mcc H4: train_len: 13140 classes: 2 max_length: 500 metric: mcc H4ac: train_len: 30685 classes: 2 max_length: 500 metric: mcc promoter_all: train_len: 53276 classes: 2 max_length: 300 metric: f1_binary promoter_no_tata: train_len: 47767 classes: 2 max_length: 300 metric: f1_binary promoter_tata: train_len: 5517 classes: 2 max_length: 300 metric: f1_binary splice_sites_acceptors: train_len: 19961 classes: 2 max_length: 600 metric: f1_binary splice_sites_all: train_len: 27000 classes: 3 max_length: 400 metric: accuracy splice_sites_donors: train_len: 19775 classes: 2 max_length: 600 metric: f1_binary # name maxlen classes samples metric # enhancers 200 2 14968 MCC # enhancers_types 200 3 14968 MCC # H3 500 2 13468 MCC # H3K4me1 500 2 28509 MCC # H3K4me2 500 2 27614 MCC # H3K4me3 500 2 33119 MCC # H3K9ac 500 2 25003 MCC # H3K14ac 500 2 29743 MCC # H3K36me3 500 2 31392 MCC # H3K79me3 500 2 25953 MCC # H4 500 2 13140 MCC # H4ac 500 2 30685 MCC # promoter_all 300 2 53276 
F1 # promoter_no_tata 300 2 47759 F1 # promoter_tata 300 2 5517 F1 # splice_sites_acceptor 600 2 19961 F1 # splice_sites_all 400 2 27000 F1 # splice_sites_donor 600 2 19775 F1 ================================================ FILE: configs/experiment/hg38/genomic_benchmark.yaml ================================================ # @package _global_ defaults: - /pipeline: genomic_benchmark - /model: ??? - override /scheduler: cosine_warmup_timm # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings # name num_seqs num_classes median len std # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 # demo_human_or_worm 100_000 2 200 0 # human_enhancers_cohn 27791 2 500 0 # human_enhancers_ensembl 154842 2 269 122.6 # human_ensembl_regulatory 289061 3 401 184.3 # human_nontata_promoters 36131 2 251 0 # human_ocr_ensembl 174756 2 315 108.1 task: loss: _name_: cross_entropy trainer: accelerator: gpu devices: 1 num_nodes: 1 accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} max_epochs: 100 precision: 16 # bf16 only a100 gradient_clip_val: 1.0 model: _name_: dna_embedding dataset: # optional, default is max_length tokenizer_name: char rc_aug: false # reverse complement augmentation scheduler: # COSINE TIMM t_in_epochs: False t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} warmup_lr_init: 1e-6 warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} lr_min: ${eval:0.1 * ${optimizer.lr}} optimizer: lr: 6e-4 weight_decay: 0.1 train: gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} seed: 2222 global_batch_size: ${dataset.batch_size} cross_validation: true remove_test_loader_in_eval: 
true # test only at the end of training pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it # for loading backbone and not head, requires both of these flags below pretrained_model_path: ??? pretrained_model_state_hook: _name_: load_backbone freeze_backbone: false ================================================ FILE: configs/experiment/hg38/genomic_benchmark_cnn.yaml ================================================ # @package _global_ defaults: - /model: genomics_benchmark_cnn - /pipeline: genomic_benchmark - override /scheduler: cosine_warmup_timm # there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings # name num_seqs num_classes median len std # dummy_mouse_enhancers_ensembl 1210 2 2381 984.4 # demo_coding_vs_intergenomic_seqs 100_000 2 200 0 # demo_human_or_worm 100_000 2 200 0 # human_enhancers_cohn 27791 2 500 0 # human_enhancers_ensembl 154842 2 269 122.6 # human_ensembl_regulatory 289061 3 401 184.3 # human_nontata_promoters 36131 2 251 0 # human_ocr_ensembl 174756 2 315 108.1 task: loss: _name_: cross_entropy trainer: accelerator: gpu devices: 1 num_nodes: 1 accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} max_epochs: 100 precision: 16 # bf16 only a100 gradient_clip_val: 1.0 encoder: id decoder: id dataset: tokenizer_name: char rc_aug: false # reverse complement augmentation scheduler: t_in_epochs: False t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} warmup_lr_init: 1e-6 warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} lr_min: ${eval:0.1 * ${optimizer.lr}} optimizer: lr: 6e-4 weight_decay: 0.1 train: gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 
1000)"} seed: 2222 global_batch_size: ${dataset.batch_size} cross_validation: true remove_test_loader_in_eval: true pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it ================================================ FILE: configs/experiment/hg38/hg38.yaml ================================================ # @package _global_ defaults: - /pipeline: hg38 - /model: ??? # Specify a model, e.g. model=mamba or model=hyena - override /scheduler: cosine_warmup_timm task: _name_: lm loss: _name_: cross_entropy ignore_index: 4 trainer: accelerator: gpu devices: 1 num_nodes: 1 accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} max_epochs: null max_steps: 10000 precision: 16 # bf16 only a100 gradient_clip_val: 1.0 limit_val_batches: 0.125 dataset: batch_size: ${eval:1024//${trainer.devices}} max_length: 1024 # optional, default is max_length max_length_val: ${dataset.max_length} max_length_test: ${dataset.max_length} tokenizer_name: char pad_max_length: null # needed for bpe tokenizer add_eos: true rc_aug: false num_workers: 12 use_fixed_len_val: false # placing a fixed length val here, but it's really the test mlm: false mlm_probability: 0.0 scheduler: t_in_epochs: False t_initial: ${eval:${trainer.max_steps}-${.warmup_t}} warmup_prefix: True warmup_lr_init: 1e-6 warmup_t: ${eval:0.1*${trainer.max_steps}} lr_min: 1e-4 optimizer: lr: 6e-4 weight_decay: 0.1 betas: [0.9, 0.95] train: gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} seed: 2222 global_batch_size: 256 # affects the scheduler and trainer.accumulate_grad_batches; must be set properly ================================================ FILE: configs/experiment/hg38/nucleotide_transformer.yaml ================================================ # @package _global_ defaults: - /pipeline:
nucleotide_transformer - /model: ??? - override /scheduler: cosine_warmup_timm model: _name_: dna_embedding trainer: accelerator: gpu devices: 1 num_nodes: 1 accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}} max_epochs: 100 precision: 16 # bf16 only a100 gradient_clip_val: 1.0 dataset: tokenizer_name: char rc_aug: false # reverse complement augmentation scheduler: t_in_epochs: False t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}} warmup_lr_init: 1e-6 warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01} lr_min: ${eval:0.1 * ${optimizer.lr}} optimizer: lr: 1e-3 weight_decay: 0.1 train: gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} seed: 2222 global_batch_size: ${dataset.batch_size} cross_validation: true remove_test_loader_in_eval: true # test only at the end of training pretrained_model_strict_load: false # false allows encoder/decoder to be used if new model uses it # for loading backbone and not head, requires both of these flags below pretrained_model_path: ??? 
pretrained_model_state_hook: _name_: load_backbone freeze_backbone: false ================================================ FILE: configs/loader/default.yaml ================================================ num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"} pin_memory: True drop_last: True ================================================ FILE: configs/model/caduceus.yaml ================================================ # Use open-source version of Mamba _name_: caduceus_lm config: _target_: caduceus.configuration_caduceus.CaduceusConfig # From original MambaConfig d_model: 128 n_layer: 2 vocab_size: 12 ssm_cfg: d_state: 16 d_conv: 4 expand: 2 dt_rank: "auto" dt_min: 0.001 dt_max: 0.1 dt_init: "random" dt_scale: 1.0 dt_init_floor: 1e-4 conv_bias: true bias: false use_fast_path: true rms_norm: true fused_add_norm: true residual_in_fp32: false pad_vocab_size_multiple: 8 # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm norm_epsilon: 1e-5 # Used in init_weights initializer_cfg: initializer_range: 0.02 rescale_prenorm_residual: true n_residuals_per_layer: 1 # Caduceus-specific params bidirectional: true bidirectional_strategy: "add" bidirectional_weight_tie: true rcps: false # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer) complement_map: null ================================================ FILE: configs/model/genomics_benchmark_cnn.yaml ================================================ # CNN baseline model (not Mamba), following the Genomic Benchmarks torch CNN experiments _name_: genomics_benchmark_cnn number_of_classes: ${dataset.d_output} vocab_size: 12 embedding_dim: 100 # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments input_len: ${dataset.__l_max} ================================================ FILE: configs/model/hyena.yaml ================================================ _name_: hyena_lm d_model: 128 n_layer: 2 d_inner: ${eval:4 * 
${.d_model}} vocab_size: 12 resid_dropout: 0.0 embed_dropout: 0.1 fused_mlp: False fused_dropout_add_ln: False checkpoint_mixer: False # set true for memory reduction checkpoint_mlp: False # set true for memory reduction residual_in_fp32: True pad_vocab_size_multiple: 8 layer: _name_: hyena emb_dim: 5 filter_order: 64 local_order: 3 l_max: ${eval:${dataset.max_length}+2} modulate: True w: 10 lr: ${optimizer.lr} wd: 0.0 lr_pos_emb: 0.0 ================================================ FILE: configs/model/layer/hyena.yaml ================================================ _name_: hyena l_max: 1024 order: 2 filter_order: 64 num_heads: 1 inner_factor: 1 num_blocks: 1 fused_bias_fc: false outer_mixing: false dropout: 0.0 filter_dropout: 0.0 filter_cls: 'hyena-filter' post_order_ffn: false jit_filter: false short_filter_order: 3 activation: "id" ================================================ FILE: configs/model/mamba.yaml ================================================ # Use open-source version of Mamba _name_: mamba_lm config: _target_: mamba_ssm.models.config_mamba.MambaConfig d_model: 128 # Will be overwritten by CL in the scaling exps n_layer: 2 # Will be overwritten by CL in the scaling exps vocab_size: 12 pad_vocab_size_multiple: 8 rms_norm: true fused_add_norm: true residual_in_fp32: false ssm_cfg: d_state: 16 d_conv: 4 expand: 2 dt_rank: "auto" dt_min: 0.001 dt_max: 0.1 dt_init: "random" dt_scale: 1.0 dt_init_floor: 1e-4 conv_bias: true bias: false use_fast_path: true initializer_cfg: initializer_range: 0.02 rescale_prenorm_residual: true n_residuals_per_layer: 1 #norm_epsilon: 1e-5 # Default arg in mamba create_block ================================================ FILE: configs/optimizer/adam.yaml ================================================ # _target_: torch.optim.Adam _name_: adam lr: 0.001 # Initial learning rate # weight_decay: 0.0 # Weight decay for adam|lamb; should use AdamW instead if desired betas: [0.9, 0.999] 
================================================ FILE: configs/optimizer/adamw.yaml ================================================ # _target_: torch.optim.AdamW _name_: adamw lr: 0.001 # Initial learning rate weight_decay: 0.00 # Weight decay betas: [0.9, 0.999] ================================================ FILE: configs/optimizer/sgd.yaml ================================================ # _target_: torch.optim.SGD _name_: sgd lr: 0.001 # Initial learning rate momentum: 0.9 weight_decay: 0.0 # Weight decay (L2 penalty) for SGD ================================================ FILE: configs/pipeline/genomic_benchmark.yaml ================================================ # @package _global_ defaults: - /trainer: default - /loader: default - /dataset: genomic_benchmark - /task: multiclass_classification - /optimizer: adamw - /scheduler: plateau - /callbacks: [base, checkpoint] train: monitor: val/accuracy # Needed for plateau scheduler mode: max encoder: id # we need this for classification! decoder: _name_: sequence mode: pool ================================================ FILE: configs/pipeline/hg38.yaml ================================================ # @package _global_ defaults: - /trainer: default - /loader: null - /dataset: hg38 - /optimizer: adamw - /scheduler: cosine_warmup - /callbacks: [base, checkpoint] train: monitor: test/loss mode: min task: _name_: lm loss: _name_: cross_entropy ignore_index: 4 # Bake in tokenizer value for padding / EOS tokens torchmetrics: ['perplexity', 'num_tokens'] encoder: null decoder: null loader: num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"} pin_memory: True drop_last: True # There's enough data and epochs, ignore the edge case # shuffle: True ================================================ FILE: configs/pipeline/nucleotide_transformer.yaml ================================================ # @package _global_ defaults: - /trainer: default - /loader: default - /dataset: nucleotide_transformer - /task: 
multiclass_classification - /optimizer: adamw - /scheduler: plateau - /callbacks: [base, checkpoint] task: loss: _name_: cross_entropy metrics: - ${dataset.metric} train: monitor: val/${dataset.metric} mode: max encoder: id # we need this for classification! decoder: _name_: sequence mode: pool ================================================ FILE: configs/scheduler/constant.yaml ================================================ # @package _global_ train: interval: epoch scheduler: # _target_: transformers.get_constant_schedule _name_: constant ================================================ FILE: configs/scheduler/constant_warmup.yaml ================================================ # @package _global_ train: interval: step scheduler: # _target_: transformers.get_constant_schedule_with_warmup _name_: constant_warmup num_warmup_steps: 1000 # Number of iterations for LR warmup ================================================ FILE: configs/scheduler/cosine.yaml ================================================ # @package _global_ train: interval: epoch scheduler: # _target_: torch.optim.lr_scheduler.CosineAnnealingLR _name_: cosine T_max: 100 # Max number of epochs for LR scheduler eta_min: 1e-6 # Min learning rate for cosine scheduler ================================================ FILE: configs/scheduler/cosine_warmup.yaml ================================================ # @package _global_ train: interval: step scheduler: # _target_: transformers.get_cosine_schedule_with_warmup _name_: cosine_warmup num_warmup_steps: 1000 num_training_steps: 40000 ================================================ FILE: configs/scheduler/cosine_warmup_timm.yaml ================================================ # @package _global_ train: interval: step scheduler: # _target_: timm.scheduler.CosineLRScheduler _name_: cosine_warmup_timm t_in_epochs: False t_initial: 300 lr_min: 1e-5 warmup_lr_init: 1e-6 warmup_t: 10 ================================================ FILE: 
configs/scheduler/linear_warmup.yaml ================================================ # @package _global_ train: interval: step scheduler: # _target_: transformers.get_linear_schedule_with_warmup _name_: linear_warmup num_warmup_steps: 1000 num_training_steps: 40000 ================================================ FILE: configs/scheduler/multistep.yaml ================================================ # @package _global_ train: interval: epoch # _target_: torch.optim.lr_scheduler.MultiStepLR scheduler: _name_: multistep milestones: [80,140,180] gamma: 0.2 ================================================ FILE: configs/scheduler/plateau.yaml ================================================ # @package _global_ train: interval: epoch monitor: ??? # must be specified scheduler: # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _name_: plateau mode: ${train.mode} # 'min' or 'max': whether the monitored metric should decrease or increase factor: 0.2 # Decay factor when ReduceLROnPlateau is used patience: 20 min_lr: 0.0 # Minimum learning rate during annealing ================================================ FILE: configs/scheduler/step.yaml ================================================ # @package _global_ train: interval: step scheduler: # _target_: torch.optim.lr_scheduler.StepLR _name_: step step_size: 1 gamma: 0.99 ================================================ FILE: configs/task/lm.yaml ================================================ _name_: lm # loss: cross_entropy # Handled by task: cross entropy loss metrics: ppl ================================================ FILE: configs/task/multiclass_classification.yaml ================================================ # _target_: tasks.tasks.MultiClass _name_: multiclass loss: cross_entropy metrics: - accuracy torchmetrics: null ================================================ FILE: configs/task/multilabel_classification.yaml ================================================ # _target_: _name_: base loss: binary_cross_entropy metrics: null torchmetrics: - 
MultilabelAUROC # AUROC - MultilabelAveragePrecision # Average precision # - Recall # not supported in torchmetrics # - F1 # not supported in torchmetrics ================================================ FILE: configs/task/regression.yaml ================================================ # _target_: tasks.tasks.BaseTask _name_: base loss: mse metrics: mse torchmetrics: null ================================================ FILE: configs/trainer/debug.yaml ================================================ defaults: - default gpus: 1 min_epochs: 1 max_epochs: 10 # prints progress_bar_refresh_rate: null weights_summary: full profiler: null # debugs fast_dev_run: False num_sanity_val_steps: 2 overfit_batches: 0 limit_train_batches: 0.1 limit_val_batches: 0.1 limit_test_batches: 0.1 track_grad_norm: -1 terminate_on_nan: False ================================================ FILE: configs/trainer/default.yaml ================================================ _target_: pytorch_lightning.Trainer devices: 1 accelerator: gpu accumulate_grad_batches: 1 # Gradient accumulation every n batches max_epochs: 200 # accelerator: ddp # Automatically set if gpus > 1 gradient_clip_val: 0.0 log_every_n_steps: 10 limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run num_sanity_val_steps: 2 # default value: 2; override to 0 to skip sanity checking ================================================ FILE: configs/trainer/full.yaml ================================================ _target_: pytorch_lightning.Trainer # default values for all trainer parameters checkpoint_callback: True default_root_dir: null gradient_clip_val: 0.0 process_position: 0 num_nodes: 1 num_processes: 1 gpus: null auto_select_gpus: False tpu_cores: null log_gpu_memory: null overfit_batches: 0.0 track_grad_norm: -1 check_val_every_n_epoch: 1 fast_dev_run: False accumulate_grad_batches: 1 max_epochs: 1 min_epochs: 1 
max_steps: null min_steps: null limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 val_check_interval: 1.0 flush_logs_every_n_steps: 100 log_every_n_steps: 50 accelerator: null sync_batchnorm: False precision: 32 weights_summary: "top" weights_save_path: null num_sanity_val_steps: 2 truncated_bptt_steps: null resume_from_checkpoint: null profiler: null benchmark: False deterministic: False reload_dataloaders_every_epoch: False auto_lr_find: False replace_sampler_ddp: True terminate_on_nan: False auto_scale_batch_size: False prepare_data_per_node: True plugins: null amp_backend: "native" amp_level: "O2" move_metrics_to_cpu: False ================================================ FILE: configs/trainer/lm.yaml ================================================ accumulate_grad_batches: 1 # accelerator: null # set to 'ddp' for distributed # amp_backend: native # 'native' | 'apex' gpus: 8 max_epochs: 50 gradient_clip_val: 0.0 # Gradient clipping log_every_n_steps: 10 precision: 16 progress_bar_refresh_rate: 1 weights_summary: top # Set to 'full' to see every layer track_grad_norm: -1 # Set to 2 to track norms of gradients limit_train_batches: 1.0 limit_val_batches: 1.0 # We use the dataloader from Transformer-XL to ensure adjacent minibatches # are from text that are next to each other. # So that dataloader has to deal with DDP, and we don't want PL to handle # that. replace_sampler_ddp: False ================================================ FILE: setup_env.sh ================================================ #!/bin/bash # Shell script to set environment variables when running code in this repository. 
# Usage: # source setup_env.sh # Activate conda env # shellcheck source=${HOME}/.bashrc disable=SC1091 source "${CONDA_SHELL}" if [ -z "${CONDA_PREFIX}" ]; then conda activate caduceus_env elif [[ "${CONDA_PREFIX}" != *"/caduceus_env" ]]; then conda deactivate conda activate caduceus_env fi # Add root directory to PYTHONPATH to enable module imports export PYTHONPATH="${PWD}" ================================================ FILE: slurm_scripts/dump_vep_embeddings.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=100G # RAM #SBATCH --gres=gpu:8 # Number of GPUs #SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU) ##SBATCH --cpus-per-task=4 # Number of CPU cores per task #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --job-name=vep_embed # Job name #SBATCH --output=../watch_folder/%x_%j.log # Output file name #SBATCH --open-mode=append # Do not overwrite logs NUM_WORKERS=2 NUM_DEVICES=8 # Setup environment cd ../ || exit # Go to the root directory of the repo source setup_env.sh export CUDA_LAUNCH_BLOCKING=1 export CUBLAS_WORKSPACE_CONFIG=:4096:8 # Needed for setting deterministic functions for reproducibility ##################################################################################### # Choose from one of the following: ## Enformer #seq_len=196608 #bp_per_token=1 #embed_dump_batch_size=1 #model_name_or_path="EleutherAI/enformer-official-rough" #name="enformer-seqlen=196k" #rcps_flag="no-rcps" ## NTv2 #seq_len=12288 # 2048 (seq len) * 6 (kmers) #bp_per_token=6 #embed_dump_batch_size=1 #model_name_or_path="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species" #name="NTv2_downstream-seqlen=12k" #rcps_flag="no-rcps" ## Hyena #seq_len=131072 #bp_per_token=1 #embed_dump_batch_size=1 #model_name_or_path="LongSafari/hyenadna-medium-160k-seqlen-hf" 
#name="hyena_downstream-seqlen=131k" #rcps_flag="no-rcps" ## Caduceus-Ph #seq_len=131072 #bp_per_token=1 #embed_dump_batch_size=1 #model_name_or_path="kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16" #name="caduceus-ph_downstream-seqlen=131k" #rcps_flag="no-rcps" ## Caduceus-PS #seq_len=131072 #bp_per_token=1 #embed_dump_batch_size=1 #model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" #name="caduceus-ps_downstream-seqlen=131k" #rcps_flag="rcps" ##################################################################################### torchrun \ --standalone \ --nnodes=1 \ --nproc-per-node=${NUM_DEVICES} \ vep_embeddings.py \ --num_workers=${NUM_WORKERS} \ --seq_len=${seq_len} \ --bp_per_token=${bp_per_token} \ --embed_dump_batch_size=${embed_dump_batch_size} \ --name="${name}" \ --model_name_or_path="${model_name_or_path}" \ --"${rcps_flag}" ================================================ FILE: slurm_scripts/run_genomics_benchmark.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=64000M # RAM #SBATCH --gres=gpu:1 # Number of GPUs #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=2 #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --open-mode=append # Do not overwrite logs # Setup environment cd ../ || exit # Go to the root directory of the repo source setup_env.sh # Expected args: # - CONFIG_PATH # - PRETRAINED_PATH # - DISPLAY_NAME # - MODEL # - MODEL_NAME # - CONJOIN_TRAIN_DECODER # - CONJOIN_TEST # - TASK # - LR # - BATCH_SIZE # - RC_AUG # Run script # shellcheck disable=SC2154 WANDB_NAME="${DISPLAY_NAME}_lr-${LR}_batch_size-${BATCH_SIZE}_rc_aug-${RC_AUG}" for seed in $(seq 1 5); do # shellcheck disable=SC2154 HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}" mkdir -p "${HYDRA_RUN_DIR}" echo 
"*****************************************************" echo "Running GenomicsBenchmark model: ${DISPLAY_NAME}, task: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, rc_aug: ${RC_AUG}, SEED: ${seed}" # shellcheck disable=SC2086 python -m train \ experiment=hg38/genomic_benchmark \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \ dataset.dataset_name="${TASK}" \ dataset.train_val_split_seed=${seed} \ dataset.batch_size=${BATCH_SIZE} \ dataset.rc_aug="${RC_AUG}" \ +dataset.conjoin_train=false \ +dataset.conjoin_test="${CONJOIN_TEST}" \ model="${MODEL}" \ model._name_="${MODEL_NAME}" \ +model.config_path="${CONFIG_PATH}" \ +model.conjoin_test="${CONJOIN_TEST}" \ +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \ +decoder.conjoin_test="${CONJOIN_TEST}" \ optimizer.lr="${LR}" \ trainer.max_epochs=10 \ train.pretrained_model_path="${PRETRAINED_PATH}" \ wandb.group="downstream/gb_cv5" \ wandb.job_type="${TASK}" \ wandb.name="${WANDB_NAME}" \ wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \ +wandb.tags=\["seed-${seed}"\] \ hydra.run.dir="${HYDRA_RUN_DIR}" echo "*****************************************************" done ================================================ FILE: slurm_scripts/run_genomics_benchmark_cnn.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 48:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=64G # RAM #SBATCH --gres=gpu:1 # Number of GPUs #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=2 #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --open-mode=append # Do not overwrite logs # Setup environment cd ../ || exit # Go to the root directory of the repo source setup_env.sh # Expected args: # - TASK # - RC_AUG # LR: 1e-3 -- in https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks, Adam optimizer is used with default lr=1e-3 LR="1e-3" # Batch size: 64 -- See https://arxiv.org/abs/2306.15794 and 
https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks BATCH_SIZE=64 # Run script WANDB_NAME="CNN-LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}" for seed in $(seq 1 5); do HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}" mkdir -p "${HYDRA_RUN_DIR}" echo "*****************************************************" echo "Running GenomicsBenchmark TASK: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}" python -m train \ experiment=hg38/genomic_benchmark_cnn \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \ dataset.dataset_name="${TASK}" \ dataset.train_val_split_seed=${seed} \ dataset.batch_size=${BATCH_SIZE} \ dataset.rc_aug="${RC_AUG}" \ optimizer.lr="${LR}" \ trainer.max_epochs=10 \ wandb.group="downstream/gb_cv5" \ wandb.job_type="${TASK}" \ wandb.name="${WANDB_NAME}" \ wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \ +wandb.tags=\["seed-${seed}"\] \ hydra.run.dir="${HYDRA_RUN_DIR}" echo "*****************************************************" done ================================================ FILE: slurm_scripts/run_nucleotide_transformer.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=64G # RAM #SBATCH --gres=gpu:2 # Number of GPUs #SBATCH --ntasks-per-node=2 #SBATCH --cpus-per-task=4 #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --open-mode=append # Do not overwrite logs # Setup environment cd ../ || exit # Go to the root directory of the repo source setup_env.sh export HYDRA_FULL_ERROR=1 # Expected args: # - CONFIG_PATH # - PRETRAINED_PATH # - DISPLAY_NAME # - MODEL # - MODEL_NAME # - CONJOIN_TRAIN_DECODER # - CONJOIN_TEST # - TASK # - LR # - BATCH_SIZE # - RC_AUG # Run script WANDB_NAME="${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}" for seed in $(seq 1 10); do 
HYDRA_RUN_DIR="./outputs/downstream/nt_cv10_ep20/${TASK}/${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}/seed-${seed}" mkdir -p "${HYDRA_RUN_DIR}" echo "*****************************************************" echo "Running NT model: ${DISPLAY_NAME}, TASK: ${TASK}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}" python -m train \ experiment=hg38/nucleotide_transformer \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \ dataset.dataset_name="${TASK}" \ dataset.train_val_split_seed=${seed} \ dataset.batch_size=${BATCH_SIZE} \ dataset.rc_aug="${RC_AUG}" \ +dataset.conjoin_test="${CONJOIN_TEST}" \ model="${MODEL}" \ model._name_="${MODEL_NAME}" \ +model.config_path="${CONFIG_PATH}" \ +model.conjoin_test="${CONJOIN_TEST}" \ +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \ +decoder.conjoin_test="${CONJOIN_TEST}" \ optimizer.lr="${LR}" \ train.pretrained_model_path="${PRETRAINED_PATH}" \ trainer.max_epochs=20 \ wandb.group="downstream/nt_cv10_ep20" \ wandb.job_type="${TASK}" \ wandb.name="${WANDB_NAME}" \ wandb.id="nt_cv10_ep-20_${TASK}_${WANDB_NAME}_seed-${seed}" \ +wandb.tags=\["seed-${seed}"\] \ hydra.run.dir="${HYDRA_RUN_DIR}" echo "*****************************************************" done ================================================ FILE: slurm_scripts/run_pretrain_caduceus.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=100G # RAM #SBATCH --gres=gpu:8 # Number of GPUs #SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU) ##SBATCH --cpus-per-task=4 # Number of CPU cores per task #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --job-name=caduceus_ps # Job name #SBATCH --output=../watch_folder/%x_%j.log # Log file #SBATCH --open-mode=append # Do not overwrite logs # Setup environment cd ../ || exit 
# Go to the root directory of the repo source setup_env.sh export HYDRA_FULL_ERROR=1 NUM_DEVICES=8 # Run script SEQLEN=131072 MAX_STEPS=50000 D_MODEL=256 N_LAYER=8 LR="8e-3" BIDIRECTIONAL_STRATEGY="add" BIDIRECTIONAL_WEIGHT_TIE="true" RCPS="true" RC_AUG="false" BATCH_SIZE=$(( 1048576 / SEQLEN )) SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k" WANDB_NAME="caduceus-ps_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}" HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}" mkdir -p "${HYDRA_RUN_DIR}" srun python -m train \ experiment=hg38/hg38 \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ dataset.max_length=${SEQLEN} \ dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ dataset.mlm=true \ dataset.mlm_probability=0.15 \ dataset.rc_aug="${RC_AUG}" \ model="caduceus" \ model.config.d_model=${D_MODEL} \ model.config.n_layer=${N_LAYER} \ model.config.bidirectional=true \ model.config.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \ model.config.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \ model.config.rcps=${RCPS} \ optimizer.lr="${LR}" \ train.global_batch_size=${BATCH_SIZE} \ trainer.max_steps=${MAX_STEPS} \ trainer.devices=${NUM_DEVICES} \ +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ wandb.group=pretrain_hg38 \ wandb.name="${WANDB_NAME}" \ hydra.run.dir="${HYDRA_RUN_DIR}" ================================================ FILE: slurm_scripts/run_pretrain_hyena.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=100G # RAM #SBATCH --gres=gpu:8 # Number of GPUs #SBATCH --ntasks-per-node=8 #SBATCH --cpus-per-task=4 # Number of CPU cores per task #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --job-name=hyena # Job name #SBATCH --output=../watch_folder/%x_%j.log # Log file # Setup environment cd ../ || exit # Go to the root directory of the repo 
source setup_env.sh NUM_DEVICES=8 # Run script SEQLEN=1024 MAX_STEPS=10000 D_MODEL=256 N_LAYER=4 LR="6e-4" RC_AUG="true" BATCH_SIZE=$(( 1048576 / SEQLEN )) SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k" WANDB_NAME="hyena_rc_aug_seqlen-${SEQLEN_DIS}_dmodel-${D_MODEL}_nlayer-${N_LAYER}_lr-${LR}" HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}" mkdir -p "${HYDRA_RUN_DIR}" srun python -m train \ experiment=hg38/hg38 \ callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \ dataset.max_length=${SEQLEN} \ dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \ dataset.mlm=false \ dataset.mlm_probability=0.0 \ dataset.rc_aug="${RC_AUG}" \ model=hyena \ model.d_model=${D_MODEL} \ model.n_layer=${N_LAYER} \ optimizer.lr="${LR}" \ train.global_batch_size=${BATCH_SIZE} \ trainer.max_steps=${MAX_STEPS} \ trainer.devices=${NUM_DEVICES} \ +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \ wandb.group=pretrain_hg38 \ wandb.name="${WANDB_NAME}" \ hydra.run.dir="${HYDRA_RUN_DIR}" ================================================ FILE: slurm_scripts/run_pretrain_mamba.sh ================================================ #!/bin/bash #SBATCH --get-user-env # Retrieve the users login environment #SBATCH -t 96:00:00 # Time limit (hh:mm:ss) #SBATCH --mem=100G # RAM #SBATCH --gres=gpu:8 # Number of GPUs #SBATCH --ntasks-per-node=8 # Should correspond to num devices (at least 1-1 task to GPU) #SBATCH --cpus-per-task=4 # Number of CPU cores per task #SBATCH -N 1 # Number of nodes #SBATCH --requeue # Requeue job if it fails #SBATCH --job-name=mamba_ntp # Job name #SBATCH --output=../watch_folder/%x_%j.log # Log file # Setup environment cd ../ || exit # Go to the root directory of the repo source setup_env.sh NUM_DEVICES=8 # Run script SEQLEN=1024 MAX_STEPS=10000 D_MODEL=256 N_LAYER=8 LR="8e-3" RC_AUG="true" BATCH_SIZE=$(( 1048576 / SEQLEN )) SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k" 
WANDB_NAME="mamba_ntp_rc_aug_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"

mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
  dataset.mlm=false \
  dataset.mlm_probability=0.0 \
  dataset.rc_aug="${RC_AUG}" \
  model=mamba \
  model.config.d_model=${D_MODEL} \
  model.config.n_layer=${N_LAYER} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"

================================================
FILE: slurm_scripts/wrapper_run_genomics.sh
================================================
#!/bin/bash
# Submit one Slurm job per (task, learning rate, batch size, RC-augmentation) combination of the
# Genomic Benchmarks suite, each running run_genomics_benchmark.sh.

# Choose one from below

## Hyena
## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen to ../outputs/hyena_hf/hyenadna-tiny-1k-seqlen
#LOG_DIR="../watch_folder/gb_cv5/hyena"
#CONFIG_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/config.json")
#PRETRAINED_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/weights.ckpt")
#DISPLAY_NAME="hyena"
#MODEL="hyena"
#MODEL_NAME="dna_embedding"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "false" "true" )
#LRS=( "6e-4" )

## Mamba NTP
#LOG_DIR="../watch_folder/gb_cv5/mamba"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="mamba_uni"
#MODEL="mamba"
#MODEL_NAME="dna_embedding_mamba"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3" )

## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "2e-3")

## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true"  # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

# Fail fast (instead of running `mkdir -p ""`) when no configuration block above was uncommented.
: "${LOG_DIR:?Uncomment exactly one configuration block above before running this script}"

mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 256; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Per-job exports are derived from the shared prefix. Appending to export_str itself would
        # hand every later job all earlier jobs' TASK/LR/BATCH_SIZE/RC_AUG settings as well.
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="gb_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_genomics_benchmark.sh"
      done
    done
  done
done

================================================
FILE: slurm_scripts/wrapper_run_genomics_cnn.sh
================================================
#!/bin/bash
# Submit one Slurm job per Genomic Benchmarks task for the CNN baseline (run_genomics_benchmark_cnn.sh).

LOG_DIR="../watch_folder/gb_cv5/cnn_baseline"
mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for RC_AUG in "false"; do
    # Per-job exports; do not accumulate into export_str across iterations.
    job_export_str="${export_str},TASK=${TASK},RC_AUG=${RC_AUG}"
    job_name="gb_${TASK}_CNN_RC_AUG-${RC_AUG}"
    sbatch \
      --job-name="${job_name}" \
      --output="${LOG_DIR}/%x_%j.log" \
      --export="${job_export_str}" \
      "run_genomics_benchmark_cnn.sh"
  done
done

================================================
FILE: slurm_scripts/wrapper_run_nucleotide_transformer.sh
================================================
#!/bin/bash
# Submit one Slurm job per (task, learning rate, batch size, RC-augmentation) combination of the
# Nucleotide Transformer downstream tasks, each running run_nucleotide_transformer.sh.

# Choose one from below

## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3")

## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true"  # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

# Fail fast (instead of running `mkdir -p ""`) when no configuration block above was uncommented.
: "${LOG_DIR:?Uncomment exactly one configuration block above before running this script}"

mkdir -p "${LOG_DIR}"
# Variables shared by every submitted job.
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "enhancers" "enhancers_types" "H3" "H3K4me1" "H3K4me2" "H3K4me3" "H3K9ac" "H3K14ac" "H3K36me3" "H3K79me3" "H4" "H4ac" "promoter_all" "promoter_no_tata" "promoter_tata" "splice_sites_all" "splice_sites_acceptors" "splice_sites_donors"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 512; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Per-job exports are derived from the shared prefix. Appending to export_str itself would
        # hand every later job all earlier jobs' TASK/LR/BATCH_SIZE/RC_AUG settings as well.
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="nt_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_nucleotide_transformer.sh"
      done
    done
  done
done

================================================
FILE: src/__init__.py
================================================ ================================================ FILE: src/callbacks/params.py ================================================ """Callback to log the number of parameters of the model. """ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict class ParamsLog(pl.Callback): """ Log the number of parameters of the model """ def __init__( self, total: bool = True, trainable: bool = True, fixed: bool = True, ): super().__init__() self._log_stats = AttributeDict( { 'total_params_log': total, 'trainable_params_log': trainable, 'non_trainable_params_log': fixed, } ) @rank_zero_only def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: logs = {} if self._log_stats.total_params_log: logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) if self._log_stats.trainable_params_log: logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() if p.requires_grad) if self._log_stats.non_trainable_params_log: logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() if not p.requires_grad) if trainer.logger: trainer.logger.log_hyperparams(logs) ================================================ FILE: src/callbacks/timer.py ================================================ """Callback to monitor the speed of each step and each epoch. 
https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py Adapted from: https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor """ # We only need the speed monitoring, not the GPU monitoring import time from typing import Any from pytorch_lightning import Callback, Trainer, LightningModule from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT class Timer(Callback): """Monitor the speed of each step and each epoch. """ def __init__( self, step: bool = True, inter_step: bool = True, epoch: bool = True, val: bool = True, ): super().__init__() self._log_stats = AttributeDict( { 'step_time': step, 'inter_step_time': inter_step, 'epoch_time': epoch, 'val_time': val, }) def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self._snap_epoch_time = None def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self._snap_step_time = None self._snap_inter_step_time = None self._snap_epoch_time = time.time() def on_train_batch_start( self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, ) -> None: if self._log_stats.step_time: self._snap_step_time = time.time() if not self._should_log(trainer): return logs = {} if self._log_stats.inter_step_time and self._snap_inter_step_time: # First log at beginning of second step logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000 if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only def on_train_batch_end( self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, ) -> None: if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() if not self._should_log(trainer): return logs = {} if self._log_stats.step_time and 
self._snap_step_time: logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000 if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: logs = {} if self._log_stats.epoch_time and self._snap_epoch_time: logs["timer/epoch"] = time.time() - self._snap_epoch_time if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self._snap_val_time = time.time() @rank_zero_only def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: logs = {} if self._log_stats.val_time and self._snap_val_time: logs["timer/validation"] = time.time() - self._snap_val_time if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step) @staticmethod def _should_log(trainer) -> bool: return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop ================================================ FILE: src/callbacks/validation.py ================================================ """Check validation every n **global** steps. Pytorch Lightning has a `val_check_interval` parameter that checks validation every n batches, but does not support checking every n **global** steps. """ from typing import Any from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage class ValEveryNGlobalSteps(Callback): """Check validation every n **global** steps.""" def __init__(self, every_n): self.every_n = every_n self.last_run = None def on_train_batch_end(self, trainer, *_: Any): """Check if we should run validation. 
Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529 """ # Prevent Running validation many times in gradient accumulation if trainer.global_step == self.last_run: return else: self.last_run = None if trainer.global_step % self.every_n == 0 and trainer.global_step != 0: trainer.training = False stage = trainer.state.stage trainer.state.stage = RunningStage.VALIDATING trainer._run_evaluate() trainer.state.stage = stage trainer.training = True trainer._logger_connector._epoch_end_reached = False self.last_run = trainer.global_step ================================================ FILE: src/dataloaders/__init__.py ================================================ from . import genomics from .base import SequenceDataset ================================================ FILE: src/dataloaders/base.py ================================================ """ Datasets for core experimental results. """ import os from functools import partial from pathlib import Path import torch # Default data path is environment variable or /data if (default_data_path := os.getenv("DATA_PATH")) is None: default_data_path = Path(__file__).parent.parent.parent.absolute() default_data_path = default_data_path / "data" else: default_data_path = Path(default_data_path).absolute() class DefaultCollateMixin: """Controls collating in the DataLoader The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. """ @classmethod def _collate_callback(cls, x, *args, **kwargs): """ Modify the behavior of the default _collate method. 
""" return x _collate_arg_names = [] @classmethod def _return_callback(cls, return_value, *args, **kwargs): """ Modify the return value of the collate_fn. Assign a name to each element of the returned tuple beyond the (x, y) pairs See InformerSequenceDataset for an example of this being used """ x, y, *z = return_value assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} @classmethod def _collate(cls, batch, *args, **kwargs): # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py elem = batch[0] if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum(x.numel() for x in batch) storage = elem.storage()._new_shared(numel) out = elem.new(storage) x = torch.stack(batch, dim=0, out=out) # Insert custom functionality into the collate_fn x = cls._collate_callback(x, *args, **kwargs) return x else: return torch.tensor(batch) @classmethod def _collate_fn(cls, batch, *args, **kwargs): """ Default collate function. 
Generally accessed by the dataloader() methods to pass into torch DataLoader Arguments: batch: list of (x, y) pairs args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback """ x, y, *z = zip(*batch) x = cls._collate(x, *args, **kwargs) y = cls._collate(y) z = [cls._collate(z_) for z_ in z] return_value = (x, y, *z) return cls._return_callback(return_value, *args, **kwargs) # List of loader arguments to pass into collate_fn collate_args = [] def _dataloader(self, dataset, **loader_args): collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} loader_cls = loader_registry[loader_args.pop("_name_", None)] return loader_cls( dataset=dataset, collate_fn=partial(self._collate_fn, **collate_args), **loader_args, ) # class SequenceDataset(LightningDataModule): # [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just # provide our own class with the same core methods as LightningDataModule (e.g. 
setup) class SequenceDataset(DefaultCollateMixin): registry = {} _name_ = NotImplementedError("Dataset must have shorthand name") # Since subclasses do not specify __init__ which is instead handled by this class # Subclasses can provide a list of default arguments which are automatically registered as attributes # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features # of this class such as the _name_ and d_input/d_output @property def init_defaults(self): return {} # https://www.python.org/dev/peps/pep-0487/#subclass-registration def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls.registry[cls._name_] = cls def __init__(self, _name_, data_dir=None, **dataset_cfg): assert _name_ == self._name_ self.data_dir = Path(data_dir).absolute() if data_dir is not None else None # Add all arguments to self init_args = self.init_defaults.copy() init_args.update(dataset_cfg) for k, v in init_args.items(): setattr(self, k, v) # The train, val, test datasets must be set by `setup()` self.dataset_train = self.dataset_val = self.dataset_test = None self.init() def init(self): """Hook called at end of __init__, override this instead of __init__""" pass def setup(self): """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" raise NotImplementedError def split_train_val(self, val_split): """ Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. 
""" train_len = int(len(self.dataset_train) * (1.0 - val_split)) self.dataset_train, self.dataset_val = torch.utils.data.random_split( self.dataset_train, (train_len, len(self.dataset_train) - train_len), generator=torch.Generator().manual_seed( getattr(self, "seed", 42) ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us ) def train_dataloader(self, **kwargs): """Return a DataLoader for the training dataset.""" return self._train_dataloader(self.dataset_train, **kwargs) def _train_dataloader(self, dataset, **kwargs): if dataset is None: return kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler return self._dataloader(dataset, **kwargs) def val_dataloader(self, **kwargs): """Return a DataLoader for the validation dataset.""" return self._eval_dataloader(self.dataset_val, **kwargs) def test_dataloader(self, **kwargs): """Return a DataLoader for the test dataset.""" return self._eval_dataloader(self.dataset_test, **kwargs) def _eval_dataloader(self, dataset, **kwargs): if dataset is None: return # Note that shuffle=False by default return self._dataloader(dataset, **kwargs) def __str__(self): return self._name_ # Registry for dataloader class loader_registry = { None: torch.utils.data.DataLoader, # default case } ================================================ FILE: src/dataloaders/datasets/genomic_bench_dataset.py ================================================ """Genomic Benchmarks Dataset. From: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks """ from pathlib import Path import torch from genomic_benchmarks.data_check import is_downloaded from genomic_benchmarks.loc2seq import download_dataset from src.dataloaders.utils.rc import coin_flip, string_reverse_complement class GenomicBenchmarkDataset(torch.utils.data.Dataset): """ Loop through bed file, retrieve (chr, start, end), query fasta file for sequence. Returns a generator that retrieves the sequence. 
""" def __init__( self, split, max_length, dataset_name="human_nontata_promoters", d_output=2, # default binary classification dest_path=None, tokenizer=None, tokenizer_name=None, use_padding=None, add_eos=False, rc_aug=False, conjoin_train=False, conjoin_test=False, return_augs=False, return_mask=False, ): self.max_length = max_length self.use_padding = use_padding self.tokenizer_name = tokenizer_name self.tokenizer = tokenizer self.return_augs = return_augs self.add_eos = add_eos self.d_output = d_output # needed for decoder to grab assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True" if (conjoin_train or conjoin_test) and rc_aug: print("When using conjoin, we turn off rc_aug.") rc_aug = False self.rc_aug = rc_aug self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test self.return_mask = return_mask if not is_downloaded(dataset_name, cache_path=dest_path): print("downloading {} to {}".format(dataset_name, dest_path)) download_dataset(dataset_name, version=0, dest_path=dest_path) else: print("already downloaded {}-{}".format(split, dataset_name)) self.split = split # use Path object base_path = Path(dest_path) / dataset_name / split self.all_seqs = [] self.all_labels = [] label_mapper = {} for i, x in enumerate(base_path.iterdir()): label_mapper[x.stem] = i for label_type in label_mapper.keys(): for path in (base_path / label_type).iterdir(): with open(path, "r") as f: content = f.read() self.all_seqs.append(content) self.all_labels.append(label_mapper[label_type]) def __len__(self): return len(self.all_labels) def __getitem__(self, idx): x = self.all_seqs[idx] y = self.all_labels[idx] if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip(): x = string_reverse_complement(x) seq = self.tokenizer( x, add_special_tokens=False, padding="max_length" if self.use_padding else None, max_length=self.max_length, truncation=True, ) seq_ids = seq["input_ids"] # get input_ids # need to handle 
eos here if self.add_eos: # append list seems to be faster than append tensor seq_ids.append(self.tokenizer.sep_token_id) if self.conjoin_train or (self.conjoin_test and self.split != "train"): x_rc = string_reverse_complement(x) seq_rc = self.tokenizer( x_rc, add_special_tokens=False, padding="max_length" if self.use_padding else None, max_length=self.max_length, truncation=True, ) seq_rc_ids = seq_rc["input_ids"] # get input_ids # need to handle eos here if self.add_eos: # append list seems to be faster than append tensor seq_rc_ids.append(self.tokenizer.sep_token_id) seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1) else: # convert to tensor seq_ids = torch.LongTensor(seq_ids) # need to wrap in list target = torch.LongTensor([y]) # `seq` has shape: # - (seq_len,) if not conjoining # - (seq_len, 2) for conjoining if self.return_mask: return seq_ids, target, {"mask": torch.BoolTensor(seq["attention_mask"])} else: return seq_ids, target ================================================ FILE: src/dataloaders/datasets/hg38_char_tokenizer.py ================================================ """ From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py CharacterTokenizer for Hugging Face Transformers. This is heavily inspired from CanineTokenizer in transformers package. """ import json import os from pathlib import Path from typing import Dict, List, Optional, Sequence, Union from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer class CharacterTokenizer(PreTrainedTokenizer): def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs): """Character tokenizer for Hugging Face transformers. Args: characters (Sequence[str]): List of desired characters. Any character which is not included in this list will be replaced by a special token called [UNK] with id=6. 
Following is the list of all the special tokens with their corresponding ids: "[CLS]": 0 "[SEP]": 1 "[BOS]": 2 "[MASK]": 3 "[PAD]": 4 "[RESERVED]": 5 "[UNK]": 6 an id (starting at 7) will be assigned to each character. model_max_length (int): Model maximum sequence length. """ self.characters = characters self.model_max_length = model_max_length bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False) eos_token = AddedToken("[EOS]", lstrip=False, rstrip=False) sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False) cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False) pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False) unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False) mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False) self._vocab_str_to_int = { "[CLS]": 0, "[SEP]": 1, "[BOS]": 2, "[MASK]": 3, "[PAD]": 4, "[RESERVED]": 5, "[UNK]": 6, **{ch: i + 7 for i, ch in enumerate(characters)}, } self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} # TODO: This should be a parameter passed to __init__ complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"} self.complement_map = {} for k, v in self._vocab_str_to_int.items(): complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v self.complement_map[self._vocab_str_to_int[k]] = complement_id super().__init__( bos_token=bos_token, eos_token=pad_token, sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, mask_token=mask_token, unk_token=unk_token, add_prefix_space=False, model_max_length=model_max_length, padding_side=padding_side, **kwargs, ) @property def vocab_size(self) -> int: return len(self._vocab_str_to_int) def _tokenize(self, text: str) -> List[str]: return list(text) def _convert_token_to_id(self, token: str) -> int: return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"]) def _convert_id_to_token(self, index: int) -> str: return self._vocab_int_to_str[index] def 
convert_tokens_to_string(self, tokens): return "".join(tokens) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: sep = [self.sep_token_id] cls = [self.cls_token_id] result = cls + token_ids_0 + sep if token_ids_1 is not None: result += token_ids_1 + sep return result def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True, ) result = [1] + ([0] * len(token_ids_0)) + [1] if token_ids_1 is not None: result += ([0] * len(token_ids_1)) + [1] return result def get_vocab(self) -> Dict[str, int]: return self._vocab_str_to_int def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: sep = [self.sep_token_id] cls = [self.cls_token_id] result = len(cls + token_ids_0 + sep) * [0] if token_ids_1 is not None: result += len(token_ids_1 + sep) * [1] return result def get_config(self) -> Dict: return { "char_ords": [ord(ch) for ch in self.characters], "model_max_length": self.model_max_length, } @classmethod def from_config(cls, config: Dict) -> "CharacterTokenizer": cfg = {} cfg["characters"] = [chr(i) for i in config["char_ords"]] cfg["model_max_length"] = config["model_max_length"] return cls(**cfg) def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): cfg_file = Path(save_directory) / "tokenizer_config.json" cfg = self.get_config() with open(cfg_file, "w") as f: json.dump(cfg, f, indent=4) @classmethod def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): cfg_file = Path(save_directory) / "tokenizer_config.json" with open(cfg_file) as f: cfg = json.load(f) return cls.from_config(cfg) ================================================ FILE: 
src/dataloaders/datasets/hg38_dataset.py ================================================ """Dataset for sampling arbitrary intervals from the human genome. """ import math from pathlib import Path import pandas as pd import torch from pyfaidx import Fasta from src.dataloaders.utils.mlm import mlm_getitem from src.dataloaders.utils.rc import coin_flip, string_reverse_complement MAX_ALLOWED_LENGTH = 2 ** 20 class FastaInterval: """Retrieves sequences from a fasta file given a chromosome and start/end indices.""" def __init__( self, *, fasta_file, return_seq_indices=False, rc_aug=False, ): fasta_file = Path(fasta_file) assert fasta_file.exists(), "Path to fasta file must exist!" self.seqs = Fasta(str(fasta_file)) self.return_seq_indices = return_seq_indices self.rc_aug = rc_aug # calc len of each chromosome in fasta file, store in dict self.chr_lens = {} for chr_name in self.seqs.keys(): self.chr_lens[chr_name] = len(self.seqs[chr_name]) @staticmethod def _compute_interval(start, end, max_length, i_shift): if max_length == MAX_ALLOWED_LENGTH: return start, end if max_length < MAX_ALLOWED_LENGTH: assert MAX_ALLOWED_LENGTH % max_length == 0 return start + i_shift * max_length, start + (i_shift + 1) * max_length else: raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!") def __call__( self, chr_name, start, end, max_length, i_shift, return_augs=False, ): """ max_length passed from dataset, not from init """ chromosome = self.seqs[chr_name] chromosome_length = self.chr_lens[chr_name] start, end = self._compute_interval(start, end, max_length, i_shift) if end > chromosome_length: # Shift interval down start = start - (end - chromosome_length) end = chromosome_length assert start == chromosome_length - max_length if start < 0: # Shift interval up end = end - start start = 0 assert end == max_length if end > chromosome_length: # This may occur if start + MAX_ALLOWED_LENGTH extends beyond the end of the chromosome start = 
chromosome_length - max_length end = chromosome_length seq = str(chromosome[start:end]) if self.rc_aug and coin_flip(): seq = string_reverse_complement(seq) return seq class HG38Dataset(torch.utils.data.Dataset): """Loop through bed file, retrieve (chr, start, end), query fasta file for sequence.""" def __init__( self, split, bed_file, fasta_file, max_length, mlm=False, mlm_probability=0.15, pad_max_length=None, tokenizer=None, tokenizer_name=None, add_eos=False, return_seq_indices=False, rc_aug=False, return_augs=False, ): self.mlm = mlm self.mlm_probability = mlm_probability if self.mlm and self.mlm_probability <= 0.0: raise ValueError(f"`mlm_probability` has to be > 0.0, got {self.mlm_probability}.") if self.mlm: # TODO: see if this helps # self.eligible_replacements = torch.tensor( # tokenizer("ACGT", add_special_tokens=False)["input_ids"], dtype=torch.long # ) self.eligible_replacements = None else: self.eligible_replacements = None self.max_length = max_length self.pad_max_length = pad_max_length if pad_max_length is not None else max_length self.tokenizer_name = tokenizer_name self.tokenizer = tokenizer self.return_augs = return_augs self.add_eos = add_eos if max_length <= MAX_ALLOWED_LENGTH: assert MAX_ALLOWED_LENGTH % max_length == 0, f"`max_length` must be a power of 2!" self.shifts = MAX_ALLOWED_LENGTH // max_length else: raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!") bed_path = Path(bed_file) assert bed_path.exists(), "Path to .bed file must exist!" 
# read bed file df_raw = pd.read_csv(str(bed_path), sep="\t", names=["chr_name", "start", "end", "split"]) # select only split df self.df = df_raw[df_raw["split"] == split] # Update end points so that sequences are all length == MAX_ALLOWED_LENGTH self.df.loc[:, "end"] = self.df["start"] + MAX_ALLOWED_LENGTH self.fasta = FastaInterval( fasta_file=fasta_file, return_seq_indices=return_seq_indices, rc_aug=rc_aug ) @staticmethod def replace_value(x, old_value, new_value): """Helper for replacing values in a tensor.""" return torch.where(x == old_value, new_value, x) def __len__(self): return len(self.df) * self.shifts def __getitem__(self, idx): """Returns a sequence of specified len""" # sample a random row from df row_idx, shift_idx = idx // self.shifts, idx % self.shifts row = self.df.iloc[row_idx] chr_name, start, end = (row.iloc[0], row.iloc[1], row.iloc[2]) seq = self.fasta( chr_name, start, end, max_length=self.max_length, i_shift=shift_idx, return_augs=self.return_augs, ) if end - start != MAX_ALLOWED_LENGTH: print(row, "\nLength: ", end - start) if self.tokenizer_name == "char": seq = self.tokenizer( seq, padding="max_length", max_length=self.pad_max_length, truncation=True, add_special_tokens=False ) seq = seq["input_ids"] # get input_ids # need to handle eos here if self.add_eos: # append list seems to be faster than append tensor seq.append(self.tokenizer.sep_token_id) elif self.tokenizer_name == "bpe": seq = self.tokenizer( seq, # add_special_tokens=False, padding="max_length", max_length=self.pad_max_length, truncation=True, ) # get input_ids if self.add_eos: seq = seq["input_ids"][1:] # remove the bos, keep the eos token else: seq = seq["input_ids"][1:-1] # remove both special tokens # convert to tensor seq = torch.LongTensor(seq) # replace N token with a pad token, so we can ignore it in the loss seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int["N"], self.tokenizer.pad_token_id) if self.mlm: data, target = mlm_getitem( seq, 
mlm_probability=self.mlm_probability, contains_eos=self.add_eos, tokenizer=self.tokenizer, eligible_replacements=self.eligible_replacements, ) else: data = seq[:-1].clone() target = seq[1:].clone() return data, target ================================================ FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py ================================================ """Nucleotide Transformer Benchmarks Dataset. From: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks """ import torch from datasets import load_dataset from src.dataloaders.utils.rc import coin_flip, string_reverse_complement class NucleotideTransformerDataset(torch.utils.data.Dataset): """ Loop through fasta file for sequence. Returns a generator that retrieves the sequence. """ def __init__( self, split, max_length, dataset_name=None, d_output=2, # default binary classification tokenizer=None, tokenizer_name=None, use_padding=None, add_eos=False, rc_aug=False, conjoin_train=False, conjoin_test=False, return_augs=False ): self.max_length = max_length self.use_padding = use_padding self.tokenizer_name = tokenizer_name self.tokenizer = tokenizer self.return_augs = return_augs self.add_eos = add_eos self.d_output = d_output # needed for decoder to grab assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True" if (conjoin_train or conjoin_test) and rc_aug: print("When using conjoin, we turn off rc_aug.") rc_aug = False self.rc_aug = rc_aug self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test self.split = split # For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks self.seqs = load_dataset( "InstaDeepAI/nucleotide_transformer_downstream_tasks", name=dataset_name, split=split ) def __len__(self): return len(self.seqs) def __getitem__(self, idx): x = self.seqs[idx]["sequence"] # only one sequence y = self.seqs[idx]["label"] if (self.rc_aug or (self.conjoin_test and self.split == 
"train")) and coin_flip(): x = string_reverse_complement(x) seq = self.tokenizer( x, add_special_tokens=False, padding="max_length" if self.use_padding else None, max_length=self.max_length, truncation=True, ) seq_ids = seq["input_ids"] # get input_ids # need to handle eos here if self.add_eos: # append list seems to be faster than append tensor seq_ids.append(self.tokenizer.sep_token_id) if self.conjoin_train or (self.conjoin_test and self.split != "train"): x_rc = string_reverse_complement(x) seq_rc = self.tokenizer( x_rc, add_special_tokens=False, padding="max_length" if self.use_padding else None, max_length=self.max_length, truncation=True, ) seq_rc_ids = seq_rc["input_ids"] # get input_ids # need to handle eos here if self.add_eos: # append list seems to be faster than append tensor seq_rc_ids.append(self.tokenizer.sep_token_id) seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1) else: # convert to tensor seq_ids = torch.LongTensor(seq_ids) # need to wrap in list target = torch.LongTensor([y]) # `seq` has shape: # - (seq_len,) if not conjoining # - (seq_len, 2) for conjoining return seq_ids, target ================================================ FILE: src/dataloaders/fault_tolerant_sampler.py ================================================ # Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 from typing import Iterator import math import torch from torch.utils.data import RandomSampler, DistributedSampler class RandomFaultTolerantSampler(RandomSampler): def __init__(self, *args, generator=None, **kwargs): # generator = torch.Generator().manual_seed(seed) # super().__init__(*args, generator=generator, **kwargs) # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, # which should be reproducible if pl.seed_everything was called before hand. 
# This means that changing the seed of the experiment will also change the # sampling order. if generator is None: seed = int(torch.empty((), dtype=torch.int64).random_().item()) generator = torch.Generator().manual_seed(seed) super().__init__(*args, generator=generator, **kwargs) self.counter = 0 # self.start_counter = 0 self.restarting = False def state_dict(self): return {"random_state": self.state, "counter": self.counter} def load_state_dict(self, state_dict): self.generator.set_state(state_dict.get("random_state")) self.counter = state_dict["counter"] # self.start_counter = self.counter self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epoch will have very few batches. # def __len__(self): # # We need a separate self.start_counter because PL seems to call len repeatedly. # # If we use len(self.data_source) - self.counter then PL will think the epoch ends # # when we're only half way through. 
# return len(self.data_source) - self.start_counter def __iter__(self) -> Iterator[int]: n = len(self.data_source) self.state = self.generator.get_state() indices = torch.randperm(n, generator=self.generator).tolist() if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False # self.start_counter = self.counter for index in indices: self.counter += 1 yield index self.counter = 0 # self.start_counter = self.counter class FaultTolerantDistributedSampler(DistributedSampler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.counter = 0 # self.start_counter = 0 self.restarting = False def state_dict(self): return {"epoch": self.epoch, "counter": self.counter} def load_state_dict(self, state_dict): self.epoch = state_dict["epoch"] self.counter = state_dict["counter"] # self.start_counter = self.counter self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epoch will have very few batches. # def __len__(self) -> int: # return self.num_samples - self.start_counter def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] else: indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. 
indices = indices[:self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False # self.start_counter = self.counter for index in indices: self.counter += 1 yield index self.counter = 0 # self.start_counter = self.counter ================================================ FILE: src/dataloaders/genomics.py ================================================ """Dataloaders for genomics datasets, including pretraining and downstream tasks. - Adapted from: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py - Adapted from: https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py """ import copy from typing import Any, List, Union import torch from datasets import Dataset from torch.utils.data.dataloader import DataLoader from caduceus.tokenization_caduceus import CaduceusTokenizer import src.utils.train from src.dataloaders.base import SequenceDataset, default_data_path from src.dataloaders.datasets.genomic_bench_dataset import GenomicBenchmarkDataset from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer from src.dataloaders.datasets.hg38_dataset import HG38Dataset from src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset from src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler from src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler logger = src.utils.train.get_logger(__name__) class HG38(SequenceDataset): """ Base class, other dataloaders can inherit from this class. 
You must implement the following functions: - __init__ - setup You can then use (already have access to) the following functions: - train_dataloader - val_dataloader - test_dataloader """ _name_ = "hg38" # this name is how the dataset config finds the right dataloader def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_name=None, max_length=1024, d_output=2, rc_aug=False, max_length_val=None, max_length_test=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, shuffle=False, num_workers=1, fault_tolerant=False, ddp=False, fast_forward_epochs=None, fast_forward_batches=None, mlm=False, mlm_probability=0.15, *args, **kwargs): self.dataset_config_name = dataset_config_name self.tokenizer_name = tokenizer_name self.d_output = d_output self.rc_aug = rc_aug # reverse compliment augmentation self.max_length = max_length self.max_length_val = max_length_val if max_length_val is not None else max_length self.max_length_test = max_length_test if max_length_test is not None else max_length self.val_ratio = val_ratio self.val_split_seed = val_split_seed self.val_only = val_only self.add_eos = add_eos self.detokenize = detokenize self.batch_size = batch_size self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size self.shuffle = shuffle self.num_workers = num_workers self.bed_file = bed_file self.fasta_file = fasta_file # handle if file paths are None (default paths) if self.bed_file is None: self.bed_file = default_data_path / self._name_ / "human-sequences.bed" if self.fasta_file is None: self.fasta_file = default_data_path / self._name_ / "hg38.ml.fa" if fault_tolerant: assert self.shuffle self.fault_tolerant = fault_tolerant if ddp: assert fault_tolerant self.ddp = ddp self.fast_forward_epochs = fast_forward_epochs self.fast_forward_batches = fast_forward_batches if self.fast_forward_epochs is not None or self.fast_forward_batches is not 
None: assert ddp and fault_tolerant self.mlm = mlm self.mlm_probability = mlm_probability # To be instantiated in `setup` self.tokenizer = None self.vocab_size = 0 def setup(self, stage=None): """Set up the tokenizer and init the datasets.""" # TODO instantiate with registry if self.tokenizer_name == "char": logger.info("**Using Char-level tokenizer**") # self.tokenizer = CharacterTokenizer( # characters=["A", "C", "G", "T", "N"], # model_max_length=self.max_length, # add_special_tokens=False, # ) self.tokenizer = CaduceusTokenizer( model_max_length=self.max_length, add_special_tokens=False ) else: raise NotImplementedError(f"Tokenizer {self.tokenizer_name} not implemented.") self.vocab_size = len(self.tokenizer) self.init_datasets() # creates the datasets. You can also just create this inside the setup() here. def init_datasets(self): """Init the datasets (separate from the tokenizer)""" # delete old datasets to free memory if hasattr(self, "dataset_train"): self.dataset_train.fasta.seqs.close() del self.dataset_train.fasta.seqs # delete old datasets to free memory if hasattr(self, "dataset_test"): self.dataset_test.fasta.seqs.close() del self.dataset_test.fasta.seqs # Create all splits: torch datasets self.dataset_train, self.dataset_val, self.dataset_test = [ HG38Dataset(split=split, bed_file=self.bed_file, fasta_file=self.fasta_file, max_length=max_len, tokenizer=self.tokenizer, # pass the tokenize wrapper tokenizer_name=self.tokenizer_name, add_eos=self.add_eos, return_seq_indices=False, rc_aug=self.rc_aug, return_augs=False, mlm=self.mlm, mlm_probability=self.mlm_probability, ) for split, max_len in zip(["train", "valid", "test"], [self.max_length, self.max_length_val, self.max_length_test]) ] return def train_dataloader(self, **kwargs: Any) -> DataLoader: """ The train dataloader """ if self.shuffle and self.fault_tolerant: shuffle = False # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel: # In that case the number of 
replicas and the data parallel rank are more complicated. distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs sampler = (FaultTolerantDistributedSampler( self.dataset_train, **distributed_sampler_kwargs ) if self.ddp else RandomFaultTolerantSampler(self.dataset_train)) # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now # We assume that it's being resumed with the same number of GPUs if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None: sampler.load_state_dict({ "epoch": self.fast_forward_epochs, "counter": self.fast_forward_batches * self.batch_size }) else: shuffle = self.shuffle sampler = None loader = self._data_loader(self.dataset_train, batch_size=self.batch_size, shuffle=shuffle, sampler=sampler, **kwargs) return loader def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The val dataloader """ kwargs["drop_last"] = False return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, **kwargs) def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The test dataloader """ kwargs["drop_last"] = False # TODO: Should have separate train and eval loaders return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, **kwargs) @staticmethod def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None, **kwargs) -> DataLoader: return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, **kwargs, ) def load_state_dict(self, checkpoint): if self.fault_tolerant: self.fast_forward_epochs = checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] # TD [2022-08-07] ["epoch_loop.batch_progress"]["total"]["completed"] is 1 iteration # behind, so we're using the optimizer"s progress. This is set correctly in seq.py. 
self.fast_forward_batches = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][ "completed"] # At this point the train loader hasn't been constructed yet class GenomicBenchmark(HG38): _name_ = "genomic_benchmark" l_output = 0 # need to set this for decoder to work correctly def __init__( self, dataset_name, train_val_split_seed, dest_path=None, tokenizer_name="char", d_output=None, rc_aug=False, conjoin_train=False, conjoin_test=False, max_length=1024, use_padding=True, max_length_val=None, max_length_test=None, padding_side="left", val_ratio=0.0005, val_split_seed=2357, add_eos=False, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, shuffle=True, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs ): self.dataset_name = dataset_name self.train_val_split_seed = train_val_split_seed self.dest_path = dest_path self.tokenizer_name = tokenizer_name self.d_output = d_output self.rc_aug = rc_aug self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test self.max_length = max_length self.use_padding = use_padding self.max_length_val = max_length_val if max_length_val is not None else max_length self.max_length_test = max_length_test if max_length_test is not None else max_length self.padding_side = padding_side self.val_ratio = val_ratio self.val_split_seed = val_split_seed self.val_only = val_only self.add_eos = add_eos self.detokenize = detokenize self.batch_size = batch_size self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size self.num_workers = num_workers self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last if self.dest_path is None: self.dest_path = default_data_path / self._name_ if fault_tolerant: assert self.shuffle self.fault_tolerant = fault_tolerant if ddp: assert fault_tolerant self.ddp = ddp self.fast_forward_epochs = fast_forward_epochs 
self.fast_forward_batches = fast_forward_batches if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: assert ddp and fault_tolerant def setup(self, stage=None): # TODO instantiate with registry if self.tokenizer_name == "char": print("**Using Char-level tokenizer**") self.tokenizer = CharacterTokenizer( characters=["A", "C", "G", "T", "N"], model_max_length=self.max_length + 2, # add 2 since default adds eos/eos tokens, crop later add_special_tokens=False, padding_side=self.padding_side, ) # Create all splits: torch datasets (only train/test in this benchmark, val created below) self.dataset_train, self.dataset_test = [ GenomicBenchmarkDataset( split=split, max_length=max_len, dataset_name=self.dataset_name, tokenizer=self.tokenizer, # pass the tokenize wrapper tokenizer_name=self.tokenizer_name, use_padding=self.use_padding, d_output=self.d_output, add_eos=self.add_eos, dest_path=self.dest_path, rc_aug=self.rc_aug, conjoin_train=self.conjoin_train, conjoin_test=self.conjoin_test, return_augs=False ) for split, max_len in zip(["train", "test"], [self.max_length, self.max_length_val]) ] val_data, train_data = torch.utils.data.random_split( list(zip(self.dataset_train.all_seqs, self.dataset_train.all_labels)), lengths=[0.1, 0.9], generator=torch.Generator().manual_seed(self.train_val_split_seed) ) self.dataset_val = copy.deepcopy(self.dataset_train) self.dataset_train.all_seqs = [train_data[i][0] for i in range(len(train_data))] self.dataset_train.all_labels = [train_data[i][1] for i in range(len(train_data))] self.dataset_val.all_seqs = [val_data[i][0] for i in range(len(val_data))] self.dataset_val.all_labels = [val_data[i][1] for i in range(len(val_data))] self.dataset_val.split = "val" class NucleotideTransformer(HG38): _name_ = "nucleotide_transformer" l_output = 0 # need to set this for decoder to work correctly def __init__(self, dataset_name, train_val_split_seed, tokenizer_name="char", d_output=None, rc_aug=False, 
conjoin_train=False, conjoin_test=False, max_length=1024, use_padding=True, max_length_val=None, max_length_test=None, padding_side="left", val_ratio=0.0005, val_split_seed=2357, add_eos=False, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, shuffle=True, shuffle_eval=None, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs): self.dataset_name = dataset_name self.train_val_split_seed = train_val_split_seed self.tokenizer_name = tokenizer_name self.d_output = d_output self.rc_aug = rc_aug self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test self.max_length = max_length self.use_padding = use_padding self.max_length_val = max_length_val if max_length_val is not None else max_length self.max_length_test = max_length_test if max_length_test is not None else max_length self.padding_side = padding_side self.val_ratio = val_ratio self.val_split_seed = val_split_seed self.val_only = val_only self.add_eos = add_eos self.detokenize = detokenize self.batch_size = batch_size self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size self.num_workers = num_workers self.shuffle = shuffle self.shuffle_eval = shuffle_eval if shuffle_eval is not None else shuffle self.pin_memory = pin_memory self.drop_last = drop_last if fault_tolerant: assert self.shuffle self.fault_tolerant = fault_tolerant if ddp: assert fault_tolerant self.ddp = ddp self.fast_forward_epochs = fast_forward_epochs self.fast_forward_batches = fast_forward_batches if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: assert ddp and fault_tolerant def setup(self, stage=None): # TODO instantiate with registry if self.tokenizer_name == "char": print("**Using Char-level tokenizer**") self.tokenizer = CharacterTokenizer( characters=["A", "C", "G", "T", "N"], model_max_length=self.max_length + 2, # add 2 since default adds 
eos/eos tokens, crop later add_special_tokens=False, padding_side=self.padding_side, ) # Create all splits: torch datasets (only train/test in this benchmark) # self.dataset_train, self.dataset_val = [ self.dataset_train, self.dataset_test = [ NucleotideTransformerDataset( split=split, max_length=max_len, tokenizer=self.tokenizer, # pass the tokenize wrapper dataset_name=self.dataset_name, tokenizer_name=self.tokenizer_name, use_padding=self.use_padding, d_output=self.d_output, add_eos=self.add_eos, rc_aug=self.rc_aug, conjoin_train=self.conjoin_train, conjoin_test=self.conjoin_test, return_augs=False ) for split, max_len in zip(["train", "test"], [self.max_length, self.max_length_val]) ] ds_train_val_split = self.dataset_train.seqs.train_test_split( test_size=0.1, seed=self.train_val_split_seed ) self.dataset_val = copy.deepcopy(self.dataset_train) self.dataset_train.seqs = ds_train_val_split["train"] self.dataset_val.split = "val" self.dataset_val.seqs = ds_train_val_split["test"] ================================================ FILE: src/dataloaders/utils/mlm.py ================================================ import torch def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None): """Helper method for creating MLM input / target. Adapted from: https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L751 """ data = seq[:-1].clone() if contains_eos else seq.clone() # remove eos, if applicable target = data.clone() # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) probability_matrix = torch.full(target.shape, mlm_probability) # TODO: Do we need to avoid "masking" special tokens as is done here? 
# https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L760-L766 masked_indices = torch.bernoulli(probability_matrix).bool() target[~masked_indices] = tokenizer.pad_token_id # We only compute loss on masked tokens # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = torch.bernoulli(torch.full(target.shape, 0.8)).bool() & masked_indices data[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) # 10% of the time, we replace masked input tokens with random word indices_random = torch.bernoulli(torch.full(target.shape, 0.5)).bool() & masked_indices & ~indices_replaced if eligible_replacements is not None: rand_choice = torch.randint(eligible_replacements.shape[0], size=target.shape) random_words = eligible_replacements[rand_choice] else: random_words = torch.randint(len(tokenizer), size=target.shape, dtype=torch.long) data[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input tokens unchanged return data, target ================================================ FILE: src/dataloaders/utils/rc.py ================================================ """Utility functions for reverse complementing DNA sequences. 
""" from random import random STRING_COMPLEMENT_MAP = { "A": "T", "C": "G", "G": "C", "T": "A", "a": "t", "c": "g", "g": "c", "t": "a", "N": "N", "n": "n", } def coin_flip(p=0.5): """Flip a (potentially weighted) coin.""" return random() > p def string_reverse_complement(seq): """Reverse complement a DNA sequence.""" rev_comp = "" for base in seq[::-1]: if base in STRING_COMPLEMENT_MAP: rev_comp += STRING_COMPLEMENT_MAP[base] # if bp not complement map, use the same bp else: rev_comp += base return rev_comp ================================================ FILE: src/models/__init__.py ================================================ ================================================ FILE: src/models/baseline/__init__.py ================================================ ================================================ FILE: src/models/baseline/genomics_benchmark_cnn.py ================================================ """Genomics Benchmark CNN model. Adapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/genomic_benchmarks/models/torch.py """ import torch from torch import nn class GenomicsBenchmarkCNN(nn.Module): def __init__(self, number_of_classes, vocab_size, input_len, embedding_dim=100): """Genomics Benchmark CNN model. 
`embedding_dim` = 100 comes from: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments """ super(GenomicsBenchmarkCNN, self).__init__() self.embeddings = nn.Embedding(vocab_size, embedding_dim) self.cnn_model = nn.Sequential( nn.Conv1d(in_channels=embedding_dim, out_channels=16, kernel_size=8, bias=True), nn.BatchNorm1d(16), nn.ReLU(), nn.MaxPool1d(2), nn.Conv1d(in_channels=16, out_channels=8, kernel_size=8, bias=True), nn.BatchNorm1d(8), nn.MaxPool1d(2), nn.Conv1d(in_channels=8, out_channels=4, kernel_size=8, bias=True), nn.BatchNorm1d(4), nn.MaxPool1d(2), nn.Flatten() ) self.dense_model = nn.Sequential( nn.Linear(self.count_flatten_size(input_len), 512), # To be consistent with SSM classifier decoders, we use num_classes (even when it's binary) nn.Linear(512, number_of_classes) ) def count_flatten_size(self, input_len): zeros = torch.zeros([1, input_len], dtype=torch.long) x = self.embeddings(zeros) x = x.transpose(1, 2) x = self.cnn_model(x) return x.size()[1] def forward(self, x, state=None): # Adding `state` to be consistent with other models x = self.embeddings(x) x = x.transpose(1, 2) x = self.cnn_model(x) x = self.dense_model(x) return x, state # Returning tuple to be consistent with other models ================================================ FILE: src/models/nn/__init__.py ================================================ from .activation import Activation ================================================ FILE: src/models/nn/activation.py ================================================ """Utilities for activation functions.""" import math import torch import torch.nn as nn import torch.nn.functional as F def Activation(activation=None, size=None, dim=-1): """Returns a PyTorch activation module.""" if activation in [None, 'id', 'identity', 'linear', 'none']: return nn.Identity() elif activation == 'tanh': return nn.Tanh() elif activation == 'relu': return nn.ReLU() elif activation == 'gelu': return nn.GELU() 
elif activation == 'elu': return nn.ELU() elif activation in ['swish', 'silu']: return nn.SiLU() elif activation == 'glu': return nn.GLU(dim=dim) elif activation.startswith('glu-'): return GLU(dim=dim, activation=activation[4:]) elif activation == 'sigmoid': return nn.Sigmoid() elif activation == 'softplus': return nn.Softplus() elif activation == 'modrelu': return ModReLU(size) elif activation in ['sqrelu', 'relu2']: return SquaredReLU() elif activation == 'laplace': return Laplace() # Earlier experimentation with a LN in the middle of the block instead of activation # IIRC ConvNext does something like this? # elif activation == 'ln': # return TransposedLN(dim) else: raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) class GLU(nn.Module): def __init__(self, dim=-1, activation='sigmoid'): super().__init__() assert not activation.startswith('glu') self.dim = dim self.activation_fn = Activation(activation) def forward(self, x): x, g = torch.split(x, x.size(self.dim) // 2, dim=self.dim) return x * self.activation_fn(g) class ModReLU(nn.Module): # Adapted from https://github.com/Lezcano/expRNN def __init__(self, features): # For now we just support square layers super().__init__() self.features = features self.b = nn.Parameter(torch.Tensor(self.features)) self.reset_parameters() def reset_parameters(self): self.b.data.uniform_(-0.01, 0.01) def forward(self, inputs): norm = torch.abs(inputs) biased_norm = norm + self.b magnitude = F.relu(biased_norm) phase = torch.sign(inputs) return phase * magnitude class SquaredReLU(nn.Module): def forward(self, x): # return F.relu(x)**2 return torch.square(F.relu(x)) # Could this be faster? 
def laplace(x, mu=0.707107, sigma=0.282095): x = (x - mu).div(sigma * math.sqrt(2.0)) return 0.5 * (1.0 + torch.erf(x)) class Laplace(nn.Module): def __init__(self, mu=0.707107, sigma=0.282095): super().__init__() self.mu = mu self.sigma = sigma def forward(self, x): return laplace(x, mu=self.mu, sigma=self.sigma) ================================================ FILE: src/models/nn/adaptive_softmax.py ================================================ # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional import functools import torch import torch.nn as nn import torch.nn.functional as F class OptionalParameterList(nn.ParameterList): def extra_repr(self): child_lines = [] for k, p in self._parameters.items(): if p is not None: size_str = 'x'.join(str(size) for size in p.size()) device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) parastr = 'Parameter containing: [{} of size {}{}]'.format( torch.typename(p), size_str, device_str) child_lines.append(' (' + str(k) + '): ' + parastr) tmpstr = '\n'.join(child_lines) return tmpstr class ProjectedAdaptiveLogSoftmax(nn.Module): def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, tie_projs=None, out_layers_weights=None, out_projs=None, keep_order=False, bias_scale=0.0, dropout=0.0, ): super().__init__() self.n_token = n_token self.d_embed = d_embed self.d_proj = d_proj self.cutoffs = list(cutoffs) + [n_token] self.cutoff_ends = [0] + self.cutoffs self.div_val = div_val self.shortlist_size = self.cutoffs[0] self.n_clusters = len(self.cutoffs) - 1 self.head_size = self.shortlist_size + self.n_clusters # bake the first False into the definition, just as [0] is built into the cutoffs if tie_projs is None: tie_projs = [] elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs) else: tie_projs = list(tie_projs) tie_projs = [False] + tie_projs self.tie_projs = tie_projs if self.n_clusters > 0: self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) if not out_layers_weights: self.out_layers_weights = nn.ParameterList() else: self.out_layers_weights = out_layers_weights self.out_layers_biases = nn.ParameterList() self.shared_out_projs = out_projs self.out_projs = OptionalParameterList() self.dropout = dropout self.drop = nn.Dropout(dropout) if div_val == 1: if d_proj != d_embed: for i in range(len(self.cutoffs)): if tie_projs[i]: self.out_projs.append(None) else: 
self.out_projs.append( nn.Parameter(torch.zeros(d_proj, d_embed)) ) else: self.out_projs.append(None) self.out_layers_biases.append( nn.Parameter(torch.zeros(n_token)) ) if not out_layers_weights: self.out_layers_weights.append( nn.Parameter(torch.zeros(n_token, d_embed)) ) else: for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] d_emb_i = d_embed // (div_val ** i) if tie_projs[i]: self.out_projs.append(None) else: self.out_projs.append( nn.Parameter(torch.zeros(d_proj, d_emb_i)) ) self.out_layers_biases.append( nn.Parameter(torch.zeros(r_idx - l_idx)) ) if not out_layers_weights: self.out_layers_weights.append( nn.Parameter(torch.zeros(r_idx - l_idx, d_emb_i)) ) for bias in self.out_layers_biases: bound = bias_scale * d_proj ** -.5 nn.init.uniform_(bias, -bound, bound) self.keep_order = keep_order def _compute_logit(self, hidden, weight, bias, proj): if proj is None: logit = F.linear(hidden, weight, bias=bias) else: if self.dropout > 0.0: logit = hidden @ proj logit = self.drop(logit) logit = logit @ weight.t() else: logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) if bias is not None: logit = logit + bias return logit def get_out_proj(self, i): if self.tie_projs[i]: if len(self.shared_out_projs) == 0: return None elif len(self.shared_out_projs) == 1: return self.shared_out_projs[0] else: return self.shared_out_projs[i] else: return self.out_projs[i] def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs): # [21-09-15 AG]: TODO may need to handle key_padding_mask ''' hidden :: [len*bsz x d_proj] target :: [len*bsz] ''' hidden = hidden.reshape(-1, hidden.size(-1)) target = target.reshape(-1) if hidden.size(0) != target.size(0): print(hidden.shape, target.shape) raise RuntimeError('Input and target should have the same size ' 'in the batch dimension.') if self.n_clusters == 0: logit = self._compute_logit(hidden, self.out_layers_weights[0], self.out_layers_biases[0], 
self.get_out_proj(0)) nll = -F.log_softmax(logit, dim=-1) \ .gather(1, target.unsqueeze(1)).squeeze(1) else: # construct weights and biases weights, biases = [], [] for i in range(len(self.cutoffs)): if self.div_val == 1: l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] weight_i = self.out_layers_weights[0][l_idx:r_idx] bias_i = self.out_layers_biases[0][l_idx:r_idx] else: weight_i = self.out_layers_weights[i] bias_i = self.out_layers_biases[i] if i == 0: weight_i = torch.cat( [weight_i, self.cluster_weight], dim=0) bias_i = torch.cat( [bias_i, self.cluster_bias], dim=0) weights.append(weight_i) biases.append(bias_i) head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) head_logprob = F.log_softmax(head_logit, dim=1) nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) offset = 0 cutoff_values = [0] + self.cutoffs for i in range(len(cutoff_values) - 1): l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] mask_i = (target >= l_idx) & (target < r_idx) indices_i = mask_i.nonzero(as_tuple=False).squeeze() if indices_i.numel() == 0: continue target_i = target.index_select(0, indices_i) - l_idx head_logprob_i = head_logprob.index_select(0, indices_i) if i == 0: logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) else: weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) hidden_i = hidden.index_select(0, indices_i) tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) # First term accounts for cluster probabilities logprob_i = head_logprob_i[:, -i] \ + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) if self.keep_order or keep_order: nll.index_copy_(0, indices_i, -logprob_i) else: nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) offset += logprob_i.size(0) # TODO This should be a bug in the original implementation; it should go 
into the continue case above as well return nll.mean() # TODO maybe cases for length or padding_mask def compute_logits(self, hidden): """Compute full vector of logits Adapted from https://github.com/kimiyoung/transformer-xl/issues/88 """ hidden = hidden.reshape(-1, hidden.size(-1)) if self.n_clusters == 0: logits = self._compute_logit(hidden, self.out_layers_weights[0], self.out_layers_biases[0], self.get_out_proj(0)) return logits else: # construct weights and biases weights, biases = [], [] for i in range(len(self.cutoffs)): if self.div_val == 1: l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] weight_i = self.out_layers_weights[0][l_idx:r_idx] bias_i = self.out_layers_biases[0][l_idx:r_idx] else: weight_i = self.out_layers_weights[i] bias_i = self.out_layers_biases[i] if i == 0: weight_i = torch.cat( [weight_i, self.cluster_weight], dim=0) bias_i = torch.cat( [bias_i, self.cluster_bias], dim=0) weights.append(weight_i) biases.append(bias_i) head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) head_logprob = F.log_softmax(head_logit, dim=1) out_full_logps = [head_logprob[:, :self.cutoffs[0]]] offset = 0 cutoff_values = [0] + self.cutoffs for i in range(1, len(cutoff_values) - 1): l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] head_logprob_i = head_logprob # .index_select(0, indices_i) if i == 0: logprob_i = head_logprob_i else: weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) hidden_i = hidden # .index_select(0, indices_i) tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i offset += logprob_i.size(0) out_full_logps.append(logprob_i) out_full_logps = torch.cat(out_full_logps, dim = 1) # print(torch.sum(out_full_ps), out_full_ps.shape) return out_full_logps class 
AdaptiveEmbedding(nn.Module): """ Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation Initialization has been fixed for the case when d_proj = d_embed """ def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, init_scale=1.0, sample_softmax=False, dropout=0.0): super().__init__() self.n_token = n_token self.d_embed = d_embed self.cutoffs = list(cutoffs) + [n_token] self.div_val = div_val self.d_proj = d_proj self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() self.emb_scale = d_proj ** 0.5 self.cutoff_ends = [0] + self.cutoffs self.emb_layers = nn.ModuleList() self.emb_projs = nn.ParameterList() if div_val == 1: self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)) _init_embed(self.emb_layers[-1].weight, d_embed, init_scale) # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_embed ** -.5) if d_proj != d_embed: # TODO # self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) _init_proj(self.emb_projs[-1], d_proj, init_scale) else: for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] d_emb_i = d_embed // (div_val ** i) self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_emb_i ** -.5) _init_embed(self.emb_layers[-1].weight, d_emb_i, init_scale) self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) _init_proj(self.emb_projs[-1], d_proj, init_scale) def forward(self, inp): if self.div_val == 1: embed = self.emb_layers[0](inp) embed = self.drop(embed) if self.d_proj != self.d_embed: embed = F.linear(embed, self.emb_projs[0]) 
else: param = next(self.parameters()) inp_flat = inp.reshape(-1) # Changes from original impl # emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) embeddings = [] indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token _total_tokens = 0 # emb_flat = inp.new_zeros(inp_flat.size(0), self.d_proj) for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) indices_i = mask_i.nonzero().squeeze(-1) # shape (_tokens,) _tokens = indices_i.numel() if _tokens == 0: continue inp_i = inp_flat.index_select(0, indices_i) - l_idx emb_i = self.emb_layers[i](inp_i) emb_i = self.drop(emb_i) emb_i = F.linear(emb_i, self.emb_projs[i]) # Changes embeddings.append(emb_i) indices.index_put_( (indices_i,), torch.arange(_tokens, device=inp.device) + _total_tokens ) _total_tokens += _tokens # emb_flat.index_copy_(0, indices_i, emb_i) embeddings = torch.cat(embeddings, dim=0) emb_flat = embeddings[indices] embed_shape = inp.size() + (self.d_proj,) embed = emb_flat.view(embed_shape) embed.mul_(self.emb_scale) # embed.div_(self.emb_scale) return embed def _init_weight(weight, d : int, init_scale : Optional[float], default=None): assert init_scale or default if init_scale is None: std = default else: std = init_scale * (d ** -0.5) nn.init.normal_(weight, mean=0, std=std) _init_embed = functools.partial(_init_weight, default=0.02) _init_proj = functools.partial(_init_weight, default=0.01) ================================================ FILE: src/models/nn/utils.py ================================================ """ Utility wrappers around modules to let them handle Args and extra arguments """ import inspect from functools import wraps import torch from torch import nn def wrap_kwargs(f): """ Given a callable f that can consume some named arguments, wrap it with a kwargs that passes back any unused args EXAMPLES -------- Basic 
usage: def foo(x, y=None): return x wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) -------- The wrapped function can return its own argument dictionary, which gets merged with the new kwargs. def foo(x, y=None): return x, {} wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) def foo(x, y=None): return x, {"y": y, "z": None} wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) -------- The wrapped function can have its own kwargs parameter: def foo(x, y=None, **kw_args): return x, {} wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) -------- Partial functions and modules work automatically: class Module: def forward(self, x, y=0): return x, {"y": y+1} m = Module() wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) """ sig = inspect.signature(f) # Check if f already has kwargs has_kwargs = any([ param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() ]) if has_kwargs: @wraps(f) def f_kwargs(*args, **kwargs): y = f(*args, **kwargs) if isinstance(y, tuple) and isinstance(y[-1], dict): return y else: return y, {} else: param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) @wraps(f) def f_kwargs(*args, **kwargs): bound = sig_kwargs.bind(*args, **kwargs) if "kwargs" in bound.arguments: kwargs = bound.arguments.pop("kwargs") else: kwargs = {} y = f(**bound.arguments) if isinstance(y, tuple) and isinstance(y[-1], dict): return *y[:-1], {**y[-1], **kwargs} else: return y, kwargs return f_kwargs def discard_kwargs(f): if f is None: return None f_kwargs = wrap_kwargs(f) @wraps(f) def f_(*args, **kwargs): return f_kwargs(*args, **kwargs)[0] return f_ def PassthroughSequential(*modules): """Special Sequential module that chains kwargs. 
Semantics are the same as nn.Sequential, with extra convenience features: - Discard None modules - Flatten inner Sequential modules - In case with 0 or 1 Module, rename the class for ease of inspection """ def flatten(module): if isinstance(module, nn.Sequential): return sum([flatten(m) for m in module], []) else: return [module] modules = flatten(nn.Sequential(*modules)) modules = [module for module in modules if module if not None] class Sequential(nn.Sequential): def forward(self, x, **kwargs): for layer in self: x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) return x, kwargs def step(self, x, **kwargs): for layer in self: fn = getattr(layer, "step", layer.forward) x, kwargs = wrap_kwargs(fn)(x, **kwargs) return x, kwargs if len(modules) == 0: Sequential.__name__ = "Identity" elif len(modules) == 1: Sequential.__name__ = type(modules[0]).__name__ return Sequential(*modules) ================================================ FILE: src/models/sequence/__init__.py ================================================ ================================================ FILE: src/models/sequence/dna_embedding.py ================================================ """DNA Embedding Model. Backbones from LM pre-training models, used for downstream tasks. """ from functools import partial import torch import torch.nn as nn from flash_attn.utils.generation import GenerationMixin from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.models.mixer_seq_simple import MixerModel from mamba_ssm.models.mixer_seq_simple import _init_weights as _init_weights_mamba try: from flash_attn.ops.fused_dense import ColumnParallelLinear except ImportError: ColumnParallelLinear = None from caduceus.configuration_caduceus import CaduceusConfig from caduceus.modeling_caduceus import Caduceus from src.models.sequence.long_conv_lm import LMBackbone from src.models.sequence.long_conv_lm import _init_weights class DNAEmbeddingModel(nn.Module, GenerationMixin): """DNA Embedding Model. 
Same as ConvLMHeadModel (in long_conv_lm.py), except no decoder head, we just pass back the hidden states for downstream tasks. """ def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, process_group=None, layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, norm_epsilon: float = 1e-5, rms_norm: bool = False, initializer_cfg=None, checkpoint_mlp=False, checkpoint_mixer=False, fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, pad_vocab_size_multiple: int = 1, sequence_parallel=True, device=None, dtype=None, return_hidden_state=False, **kwargs) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.d_model = d_model # for decoder self.process_group = process_group self.return_hidden_state = return_hidden_state if vocab_size % pad_vocab_size_multiple != 0: vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) self.backbone = LMBackbone( d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, process_group=process_group, layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, max_position_embeddings=max_position_embeddings, resid_dropout=resid_dropout, embed_dropout=embed_dropout, dropout_cls=dropout_cls, norm_epsilon=norm_epsilon, rms_norm=rms_norm, initializer_cfg=initializer_cfg, fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel, checkpoint_mlp=checkpoint_mlp, checkpoint_mixer=checkpoint_mixer, **factory_kwargs, **kwargs ) # Initialize weights and apply final processing self.apply(partial(_init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}))) def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface """DNA Embedding Model forward pass.""" hidden_states = 
self.backbone(input_ids, position_ids=position_ids, inference_params=inference_params) # we only need the last hidden state for embeddings (decoder head will predict classification task) return hidden_states, None @property def d_output(self): """Model /embedding dimension, used for decoder mapping. """ if getattr(self, "d_model", None) is None: raise NotImplementedError("SequenceModule instantiation must set d_output") return self.d_model class DNAEmbeddingModelMamba(DNAEmbeddingModel): """Custom DNA Embedding Model that is compatible with open-source Mamba repo.""" def __init__( self, config: MambaConfig, initializer_cfg=None, conjoin_train=False, conjoin_test=False, device=None, dtype=None, ): super(DNAEmbeddingModel, self).__init__() # nn.Module.__init__() self.config = config d_model = config.d_model self.d_model = d_model # for decoder n_layer = config.n_layer vocab_size = config.vocab_size ssm_cfg = config.ssm_cfg rms_norm = config.rms_norm residual_in_fp32 = config.residual_in_fp32 fused_add_norm = config.fused_add_norm pad_vocab_size_multiple = config.pad_vocab_size_multiple factory_kwargs = {"device": device, "dtype": dtype} if vocab_size % pad_vocab_size_multiple != 0: vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) self.backbone = MixerModel( d_model=d_model, n_layer=n_layer, vocab_size=vocab_size, ssm_cfg=ssm_cfg, rms_norm=rms_norm, initializer_cfg=initializer_cfg, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, **factory_kwargs, ) # Initialize weights and apply final processing self.apply( partial( _init_weights_mamba, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), ) ) self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface """Mamba backbone-specific forward pass that does not use `position_ids`.""" hidden_states = self.backbone(input_ids, 
inference_params=inference_params) # we only need the last hidden state for embeddings (decoder head will predict classification task) return hidden_states, None class DNAEmbeddingModelCaduceus(DNAEmbeddingModel): """Custom DNA Embedding Model that is compatible with Caduceus models.""" def __init__( self, config: CaduceusConfig, device=None, dtype=None, conjoin_train=False, conjoin_test=False, ): super(DNAEmbeddingModel, self).__init__() # nn.Module.__init__() self.config = config self.d_model = config.d_model # for decoder factory_kwargs = {"device": device, "dtype": dtype} self.caduceus = Caduceus( config=config, **factory_kwargs, ) self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface """Caduceus backbone-specific forward pass that does not use `position_ids`.""" if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS hidden_states = self.caduceus(input_ids, return_dict=False) num_chan = hidden_states.shape[-1] return torch.stack( [hidden_states[..., :num_chan // 2], torch.flip(hidden_states[..., num_chan // 2:], dims=[1, 2])], dim=-1 ), None if self.conjoin_train or (self.conjoin_test and not self.training): # For conjoining / post-hoc conjoining assert input_ids.ndim == 3, "Input must be 3D tensor, where channels corresponds to forward and rc strands" hidden_states = self.caduceus(input_ids[..., 0], return_dict=False) hidden_states_rc = self.caduceus(input_ids[..., 1], return_dict=False) # Stack along channel dimension (dim=-1) return torch.stack([hidden_states, hidden_states_rc], dim=-1), None return self.caduceus(input_ids, return_dict=False), None def load_backbone(model, state_dict, freeze_backbone=False, ignore_head=True): """ Modifies state dict loading with custom function. This is necessary because the head of a lm outputs logits for vocab, but we just need the embeddings for downstream tasks. 
inputs: model: nn.Module, the from 'scratch' model state_dict: dict, from the pretrained weights ignore_head: bool, whether to inflate weights in the head (or keep scratch weights). If number of classes changes, then you need to use this. return: state_dict: dict, update with inflated weights """ # consumes prefix from pretrained model, if necessary torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( state_dict, "model." ) model_new_params_dict = model.state_dict() updated_model_state_dict = {} # loop through scratch model keys (pretrained may have extra stuff) for key in sorted(model_new_params_dict.keys()): loaded_params = state_dict.get(key, None) if loaded_params is None: # This should never happen, it should be there! print("Missing key in pretrained model!", key) raise Exception elif ignore_head and 'head' in key: # ignore head weights print("found head key / parameter, load from scratch", key) # using scratch by default, nothing needed used_params = model_new_params_dict[key] elif "decoder" in key: print("found decoder key / parameter, load from scratch", key) used_params = model_new_params_dict[key] else: print('key: shape MATCH, loading', key) # load matched weights used_params = loaded_params # we need to pass back a state dict with the '.model' prefix!!!!! key_with_prefix = 'model.' 
+ key updated_model_state_dict[key_with_prefix] = used_params if freeze_backbone: print("freezing model backbone params!") # note, decoder not included in backbone for name, param in model.named_parameters(): param.requires_grad = False # we have updated the new model state dict with pretrained now return updated_model_state_dict ================================================ FILE: src/models/sequence/hyena.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange try: from src.ops.fftconv import fftconv_ref, fftconv_func, fftconv_heads_ref except ImportError: fftconv_func = None try: from flash_attn.ops.fused_dense import FusedDense except ImportError: FusedDense = None import src.utils.registry as registry from src.utils.train import OptimModule from src.utils.config import instantiate, auto_assign_attrs from src.models.nn import Activation class FFTConvFuncv2(torch.autograd.Function): @staticmethod def forward(ctx, u, k): seqlen = u.shape[-1] if len(u.shape) > 3: k = k.unsqueeze(1) fft_size = 2 * seqlen k_f = torch.fft.rfft(k, n=fft_size) / fft_size u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] ctx.save_for_backward(u_f, k_f) return y @staticmethod def backward(ctx, dout): u_f, k_f = ctx.saved_tensors seqlen = dout.shape[-1] fft_size = 2 * seqlen dout_f = torch.fft.rfft(dout, n=fft_size) du = torch.fft.irfft(dout_f * k_f.conj(), n=fft_size, norm="forward")[ ..., :seqlen ] dk = torch.fft.irfft(dout_f * u_f.conj(), n=fft_size, norm="forward")[ ..., :seqlen ] return du, dk.squeeze() def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): seqlen = u.shape[-1] fft_size = 2 * seqlen k_f = torch.fft.rfft(k, n=fft_size) / fft_size if k_rev is not None: k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size k_f = k_f + k_rev_f.conj() u_f = torch.fft.rfft(u.to(dtype=k.dtype), 
n=fft_size) if len(u.shape) > 3: k_f = k_f.unsqueeze(1) y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] out = y + u * D.unsqueeze(-1) if gelu: out = F.gelu(out) if dropout_mask is not None: return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) else: return out.to(dtype=u.dtype) @torch.jit.script def mul_sum(q, y): return (q * y).sum(dim=1) class Sin(nn.Module): def __init__(self, dim, w=10, train_freq=True): super().__init__() self.freq = ( nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) ) def forward(self, x): return torch.sin(self.freq * x) class PositionalEmbedding(OptimModule): def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs): """Complex exponential positional embeddings for Hyena filters.""" super().__init__() self.seq_len = seq_len # The time embedding fed to the filteres is normalized so that t_f = 1 t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 if emb_dim > 1: bands = (emb_dim - 1) // 2 # To compute the right embeddings we use the "proper" linspace t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 f = torch.linspace(1e-4, bands - 1, bands)[None, None] z = torch.exp(-1j * f * w) z = torch.cat([t, z.real, z.imag], dim=-1) self.register("z", z, lr=lr_pos_emb) self.register("t", t, lr=0.0) def forward(self, L): return self.z[:, :L], self.t[:, :L] class ExponentialModulation(OptimModule): def __init__( self, d_model, fast_decay_pct=0.3, slow_decay_pct=1.5, target=1e-2, modulation_lr=0.0, shift: float = 0.0, **kwargs, ): super().__init__() self.shift = shift max_decay = math.log(target) / fast_decay_pct min_decay = math.log(target) / slow_decay_pct deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] self.register("deltas", deltas, lr=modulation_lr) def forward(self, t, x): decay = torch.exp(-t * self.deltas.abs()) x = x * (decay + self.shift) return x class 
HyenaFilter(OptimModule): def __init__( self, d_model, emb_dim=3, # dim of input to MLP, augments with positional encoding order=16, # width of the implicit MLP fused_fft_conv=False, seq_len=1024, lr=1e-3, lr_pos_emb=1e-5, dropout=0.0, w=1, # frequency of periodic activations wd=0, # weight decay of kernel parameters bias=True, num_inner_mlps=2, linear_mixer=False, modulate: bool = True, normalized=False, **kwargs, ): """ Implicit long filter with modulation. Args: d_model: number of channels in the input emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands order: width of the FFN num_inner_mlps: number of inner linear layers inside filter MLP Note: filter_dropout is not implemented """ super().__init__() auto_assign_attrs( self, d_model=d_model, emb_dim=emb_dim, seq_len=seq_len, modulate=modulate ) self.use_bias = bias self.fused_fft_conv = fused_fft_conv self.bias = nn.Parameter(torch.randn(self.d_model)) self.dropout = nn.Dropout(dropout) act = Sin(dim=order, w=w) assert ( emb_dim % 2 != 0 and emb_dim >= 3 ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)" self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb) # uses a variable number of inner linear layers if linear_mixer is False: self.implicit_filter = nn.Sequential( nn.Linear(emb_dim, order), act, ) for i in range(num_inner_mlps): self.implicit_filter.append(nn.Linear(order, order)) self.implicit_filter.append(act) # final linear layer self.implicit_filter.append(nn.Linear(order, d_model, bias=False)) else: self.implicit_filter = nn.Sequential( nn.Linear(emb_dim, d_model, bias=False), ) self.modulation = ExponentialModulation(d_model, **kwargs) self.normalized = normalized for c in self.implicit_filter.children(): for name, v in c.state_dict().items(): optim = {"weight_decay": wd, "lr": lr} setattr(getattr(c, name), "_optim", optim) def filter(self, L, *args, **kwargs): z, t = self.pos_emb(L) h = self.implicit_filter(z) if self.modulate: 
h = self.modulation(t, h) if self.normalized: h = h / torch.norm(h, dim=-1, p=1, keepdim=True) return h def forward(self, x, L, k=None, bias=None, *args, **kwargs): if k is None: k = self.filter(L) # Ensure compatibility with filters that return a tuple k = k[0] if type(k) is tuple else k if bias is None: bias = self.bias bias = bias if self.use_bias else 0 * bias if self.fused_fft_conv: bias = bias.to(dtype=torch.float32) y = fftconv_func( x, k, bias, dropout_mask=None, gelu=False, force_fp16_output=torch.is_autocast_enabled(), ) else: y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False) # y = ( # FFTConvFuncv2.apply(x, k.to(dtype=torch.float32)) # + bias.unsqueeze(-1) * x # ) return y.to(dtype=x.dtype) class HyenaOperator(nn.Module): def __init__( self, d_model, l_max, order=2, filter_order=64, num_heads=1, inner_factor=1, num_blocks=1, fused_bias_fc=False, outer_mixing=False, dropout=0.0, filter_dropout=0.0, filter_cls="hyena-filter", post_order_ffn=False, jit_filter=False, short_filter_order=3, activation="id", return_state=False, **filter_args, ): r""" Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf Args: d_model (int): Dimension of the input and output embeddings (width of the layer) l_max: (int): Maximum input sequence length. Defaults to None order: (int): Depth of the Hyena recurrence. Defaults to 2 filter_order: (int): Width of the FFN parametrizing the implicit filter. Defaults to 64 num_heads: (int): Number of heads. Defaults to 1 inner_factor: (int): Width multiplier. Defaults to 1 num_blocks: (int): Number of blocks in sequence length. Defaults to 1 fused_bias_fc: (bool): Whether to use fused bias FC. Defaults to False dropout: (float): Dropout probability. Defaults to 0.0 filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0 post_order_ffn: (bool): Apply a dense layer between steps of the recurrence. Defaults to False jit_filter: (bool): Whether JIT the implicit filter function. 
Defaults to False short_filter_order: (int): Length of the explicit input convolutional filter. Defaults to 3 activation: (str): type of act between kernel output and FF (default identity) return_state: (bool): whether to return a state """ super().__init__() assert ( d_model % num_heads == 0 ), f"Model dimension {d_model} must be divisible by num heads {num_heads}" assert ( l_max % num_blocks == 0 ), f"Maximum signal length {l_max} must be divisible by block dimension {num_blocks}" block_dim = l_max // num_blocks head_dim = d_model // num_heads auto_assign_attrs( self, d_model=d_model, order=order, l_max=l_max, num_heads=num_heads, inner_factor=inner_factor, block_dim=block_dim, head_dim=head_dim, filter_order=filter_order, post_order_ffn=post_order_ffn, short_filter_order=short_filter_order, num_blocks=num_blocks, filter_dropout=filter_dropout, jit_filter=jit_filter, outer_mixing=outer_mixing, activation=activation, return_state=return_state, ) self.activation = Activation(activation) self.dropout = nn.Dropout(dropout) self.setup_projections(fused_bias_fc, inner_factor) self.setup_filters(filter_cls, filter_args) def setup_projections(self, fused_bias_fc, inner_factor): "Initializes input and output projections (over the width dimension)" if fused_bias_fc and FusedDense is None: raise ImportError("fused_dense is not installed") linear_cls = nn.Linear if not fused_bias_fc else FusedDense self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model) self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model) if self.post_order_ffn: self.ord_proj_w = nn.Parameter( torch.randn(self.order, self.num_heads, self.num_heads) / math.sqrt(self.head_dim) ) def setup_filters(self, filter_cls, filter_args): "Initializes the explicit and implicit filters" assert self.order >= 2, f"Order must be at least 2, (got {self.order})" total_width = self.d_model * self.inner_factor * (self.order + 1) self.short_filter = nn.Conv1d( in_channels=total_width, 
out_channels=total_width, kernel_size=self.short_filter_order, groups=total_width, padding=self.short_filter_order - 1, ) filter_cls = instantiate(registry.layer, filter_cls, partial=True) self.filter_fn = filter_cls( self.head_dim * self.inner_factor * (self.order - 1), order=self.filter_order, seq_len=self.l_max, channels=1, dropout=self.filter_dropout, **filter_args, ) if self.jit_filter: self.filter_fn = torch.jit.script(self.filter_fn, self.L) def recurrence(self, u, state): "Fast inference mode via distilled recurrence" raise NotImplementedError("Working on it!") def forward(self, u, *args, **kwargs): l = u.size(-2) l_filter = min(l, self.l_max) u = self.in_proj(u) u = rearrange(u, "b l d -> b d l") uc = self.short_filter(u)[..., :l_filter] uc = rearrange( uc, "b (ho v) (z l) -> b ho v z l", z=self.num_blocks, ho=self.num_heads, v=self.head_dim * (self.order + 1), ) *x, v = uc.split(self.d_model, dim=2) k = self.filter_fn.filter(l_filter) # `c` is always 1 by default k = rearrange(k, "c l (v o) -> c o v l", v=self.head_dim, o=self.order - 1)[0] bias = rearrange( self.filter_fn.bias, "(v o) -> o v", v=self.head_dim, o=self.order - 1 ) for o, x_i in enumerate(reversed(x[1:])): if self.outer_mixing: v = rearrange(v, "b h v z l -> b h 1 v z l") v = self.dropout(v * rearrange(x_i, "b h v z l -> b h v 1 z l")) v = v.sum(dim=2) else: v = self.dropout(v * x_i) # the bias term is broadcasted. 
Last dimension (l) is handled by fftconv v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None]) if self.post_order_ffn: w = self.ord_proj_w[o] v = mul_sum( rearrange(w, "h1 h2 -> 1 h1 h2 1 1 1"), rearrange(v, "b h v z l -> b h 1 v z l"), ) y = self.activation( rearrange( v * x[0], "b h v z l -> b (z l) (h v)", z=self.num_blocks, h=self.num_heads, ) ) y = self.out_proj(y) if self.return_state: return y, None return y @property def d_output(self): return self.d_model ================================================ FILE: src/models/sequence/long_conv_lm.py ================================================ import copy import math import re from collections import namedtuple from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from flash_attn.modules.block import Block from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP from flash_attn.utils.distributed import sync_shared_params, all_gather_raw from flash_attn.utils.generation import GenerationMixin from torch.utils.checkpoint import checkpoint try: from flash_attn.ops.fused_dense import ColumnParallelLinear except ImportError: ColumnParallelLinear = None try: from flash_attn.ops.layer_norm import dropout_add_layer_norm except ImportError: dropout_add_layer_norm = None from src.utils import instantiate import src.utils.registry as registry class CheckpointedModule(torch.nn.Module): def __init__(self, layer): super().__init__() self.layer = layer def forward(self, x): return checkpoint(self.layer, x) def create_mixer_cls( layer=None, process_group=None, attn_layer_idx=None, attn_cfg=None, layer_idx=None, sequence_parallel=True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} parallel_kwargs = ( {"process_group": process_group, "sequence_parallel": 
sequence_parallel} if process_group is not None else {} ) if attn_layer_idx is not None and layer_idx in attn_layer_idx: causal = True if attn_cfg is None else attn_cfg.pop("causal", True) fused_bias_fc = ( False if attn_cfg is None else attn_cfg.get("fused_bias_fc", False) ) if not fused_bias_fc: assert process_group is None, "TensorParallel MHA requires fused_bias_fc" mha_cls = MHA if process_group is None else ParallelMHA # ParallelMHA doesn't take 'fused_bias_fc', it is assumed that we fuse matmul + bias if process_group is not None: attn_cfg = copy.deepcopy(attn_cfg) # Don't modify the original cfg attn_cfg.pop("fused_bias_fc", None) mixer_cls = partial( mha_cls, causal=causal, layer_idx=layer_idx, **(attn_cfg if attn_cfg is not None else {}), **parallel_kwargs, **factory_kwargs, ) else: fused_bias_fc = False if layer is None else layer.get("fused_bias_fc", False) if process_group is not None: assert fused_bias_fc, "TensorParallel SSM requires fused_bias_fc" mixer_cls = instantiate( registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs, **parallel_kwargs, ) return mixer_cls def create_mlp_cls( d_model, d_inner=None, process_group=None, fused_mlp=False, sequence_parallel=True, identity_mlp=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} inner_dim = d_inner if d_inner is not None else 4 * d_model if process_group is not None: assert fused_mlp, "Tensor Parallel is only implemented for FusedMLP" if not fused_mlp and not identity_mlp: mlp_cls = partial( Mlp, hidden_features=inner_dim, activation=partial(F.gelu, approximate="tanh"), **factory_kwargs, ) elif fused_mlp: mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP parallel_kwargs = ( {"process_group": process_group, "sequence_parallel": sequence_parallel} if process_group is not None else {} ) mlp_cls = partial( mlp_cls, hidden_features=inner_dim, **parallel_kwargs, **factory_kwargs ) else: mlp_cls = nn.Identity return mlp_cls def 
create_block( d_model, d_inner=None, process_group=None, layer=None, attn_layer_idx=None, attn_cfg=None, layer_norm_epsilon=1e-5, resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, fused_mlp=False, identity_mlp=False, fused_dropout_add_ln=False, layer_idx=None, sequence_parallel=True, checkpoint_mlp=False, checkpoint_mixer=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} mixer_cls = create_mixer_cls( layer=layer, process_group=process_group, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, layer_idx=layer_idx, sequence_parallel=sequence_parallel, **factory_kwargs, ) mlp_cls = create_mlp_cls( d_model, d_inner=d_inner, process_group=process_group, fused_mlp=fused_mlp, identity_mlp=identity_mlp, sequence_parallel=sequence_parallel, **factory_kwargs, ) norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) block = Block( d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2, fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel and process_group is not None, mark_shared_params=process_group is not None, ) block.layer_idx = layer_idx if checkpoint_mlp: block.mlp = CheckpointedModule(block.mlp) if checkpoint_mixer: block.mixer = CheckpointedModule(block.mixer) return block # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 def _init_weights( module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True, glu_act=False, ): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=initializer_range) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified 
initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block nn.init.normal_( p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) ) # If using GLU activation for now, we scale the std by 2 elif name in ["output_linear.0.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block if not glu_act: nn.init.normal_( p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) ) else: out_features = p.shape[0] # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 # on average. 
nn.init.normal_( p[: out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2, ) class LMBackbone(nn.Module): def __init__( self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, process_group=None, layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, layer_norm_epsilon: float = 1e-5, initializer_cfg=None, fused_mlp=False, identity_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, sequence_parallel=True, checkpoint_mlp=False, checkpoint_mixer=False, device=None, dtype=None, **kwargs, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.process_group = process_group self.sequence_parallel = sequence_parallel self.residual_in_fp32 = residual_in_fp32 if process_group is None: self.embeddings = GPT2Embeddings( d_model, vocab_size, max_position_embeddings, **factory_kwargs ) else: self.embeddings = ParallelGPT2Embeddings( d_model, vocab_size, max_position_embeddings, process_group=process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs, ) # We change the order of dropout, residual and layer norm: # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and # the main branch (output of MLP). The model definition is unchanged, but the mapping of the # nn.Dropout probabilities are changed. # This is for performance reason: we can fuse dropout + add + layer_norm. 
self.fused_dropout_add_ln = fused_dropout_add_ln if self.fused_dropout_add_ln and dropout_add_layer_norm is None: raise ImportError("dropout_add_layer_norm is not installed") self.layers = nn.ModuleList( [ create_block( d_model, d_inner=d_inner, process_group=process_group, layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, resid_dropout1=embed_dropout if i == 0 else resid_dropout, resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32, fused_mlp=fused_mlp, identity_mlp=identity_mlp, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, sequence_parallel=self.sequence_parallel, checkpoint_mlp=checkpoint_mlp, checkpoint_mixer=checkpoint_mixer, **factory_kwargs, ) for i in range(n_layer) ] ) self.drop_f = nn.Dropout(resid_dropout) self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) if process_group is not None: for p in self.ln_f.parameters(): # Mark the norm parameters as "shared_params" so that we sync their values at init. p._shared_params = True # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. if self.sequence_parallel: p._sequence_parallel = True self.apply( partial( _init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), ) ) self.tie_weights() def tie_weights(self): if self.process_group is not None: sync_shared_params(self, self.process_group) def forward(self, input_ids, position_ids=None, inference_params=None): # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen # dimensions so that we can split on it easily, in case of small batch size. # Only the attention/SSM layers need to know the seqlen. 
embedding_kwargs = ( {"combine_batch_seqlen_dim": True} if self.process_group is not None and self.sequence_parallel else {} ) hidden_states = self.embeddings( input_ids, position_ids=position_ids, **embedding_kwargs ) residual = None mixer_kwargs = ( {"seqlen": input_ids.shape[1]} if self.process_group is not None and self.sequence_parallel else {} ) if inference_params is not None: mixer_kwargs["inference_params"] = inference_params for layer in self.layers: hidden_states, residual = layer( hidden_states, residual, mixer_kwargs=mixer_kwargs ) if not self.fused_dropout_add_ln: dropped = self.drop_f(hidden_states) residual = (dropped + residual) if residual is not None else dropped hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) else: # Set prenorm=False here since we don't need the residual hidden_states = dropout_add_layer_norm( hidden_states, residual, self.ln_f.weight, self.ln_f.bias, self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, residual_in_fp32=self.residual_in_fp32, ) return hidden_states class ConvLMHeadModel(nn.Module, GenerationMixin): def __init__( self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, process_group=None, layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, layer_norm_epsilon: float = 1e-5, initializer_cfg=None, fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, pad_vocab_size_multiple: int = 1, sequence_parallel=True, checkpoint_mlp=False, checkpoint_mixer=False, device=None, dtype=None, **kwargs, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.process_group = process_group if vocab_size % pad_vocab_size_multiple != 0: vocab_size += pad_vocab_size_multiple - ( vocab_size % pad_vocab_size_multiple ) self.backbone = LMBackbone( d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, 
process_group=process_group, layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, max_position_embeddings=max_position_embeddings, resid_dropout=resid_dropout, embed_dropout=embed_dropout, dropout_cls=dropout_cls, layer_norm_epsilon=layer_norm_epsilon, initializer_cfg=initializer_cfg, fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel, checkpoint_mlp=checkpoint_mlp, checkpoint_mixer=checkpoint_mixer, **factory_kwargs, **kwargs, ) if process_group is None: self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) else: if ColumnParallelLinear is None: raise ImportError("fused_dense_lib is not installed") self.lm_head = ColumnParallelLinear( d_model, vocab_size, process_group, bias=False, sequence_parallel=sequence_parallel, **factory_kwargs, ) # Initialize weights and apply final processing self.apply( partial( _init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), ) ) self.tie_weights() def tie_weights(self): self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight if self.process_group is not None: sync_shared_params(self, self.process_group) def forward( self, input_ids, position_ids=None, inference_params=None, state=None ): # state for the repo interface hidden_states = self.backbone( input_ids, position_ids=position_ids, inference_params=inference_params ) lm_logits = self.lm_head(hidden_states) # During inference, we want the full logit for sampling if ColumnParallelLinear is not None and inference_params is not None: if isinstance(self.lm_head, ColumnParallelLinear): lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) lm_logits = rearrange( lm_logits, "(n b) s d -> b s (n d)", b=hidden_states.shape[0] ) CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) return CausalLMOutput(logits=lm_logits), None ================================================ FILE: src/ops/fftconv.py 
================================================ import math import torch import torch.nn.functional as F from einops import rearrange from fftconv import fftconv_fwd, fftconv_bwd @torch.jit.script def _mul_sum(y, q): return (y * q).sum(dim=1) # reference convolution with residual connection def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): seqlen = u.shape[-1] fft_size = 2 * seqlen k_f = torch.fft.rfft(k, n=fft_size) / fft_size if k_rev is not None: k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size k_f = k_f + k_rev_f.conj() u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) if len(u.shape) > 3: k_f = k_f.unsqueeze(1) y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] out = y + u * D.unsqueeze(-1) if gelu: out = F.gelu(out) if dropout_mask is not None: return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) else: return out.to(dtype=u.dtype) # reference H3 forward pass def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): seqlen = k.shape[-1] fft_size = 2 * seqlen kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=head_dim) * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 if ssm_kernel_rev is not None: ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l out = y + kv * D.unsqueeze(-1) # b d1 d2 h l q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim) if head_dim > 1: out = _mul_sum(out, q) return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype) else: return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype) class FFTConvFunc(torch.autograd.Function): @staticmethod def forward(ctx, u, k, D, dropout_mask=None, gelu=True, 
force_fp16_output=False, output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): seqlen = u.shape[-1] fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) k_f = torch.fft.rfft(k, n=fft_size) if k_rev is not None: k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() if u.stride(-1) != 1: u = u.contiguous() k_f = k_f.contiguous() D = D.contiguous() if v is not None and v.stride(-1) != 1: v = v.contiguous() if q is not None and q.stride(-1) != 1: q = q.contiguous() if dropout_mask is not None: dropout_mask = dropout_mask.contiguous() ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) ctx.output_hbl_layout = output_hbl_layout ctx.head_dim = head_dim ctx.gelu = gelu ctx.fftfp16 = fftfp16 ctx.has_k_rev = k_rev is not None out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16) return out @staticmethod def backward(ctx, dout): if ctx.output_hbl_layout: dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l') else: dout = dout.contiguous() u, k_f, D, dropout_mask, v, q = ctx.saved_tensors seqlen = u.shape[-1] fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size, ctx.output_hbl_layout, ctx.fftfp16) dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen] dk_rev = (None if not ctx.has_k_rev else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]) if v is not None: dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16 return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): return FFTConvFunc.apply(u, k, D, dropout_mask, 
gelu, force_fp16_output, output_hbl_layout, v, head_dim, q, fftfp16, k_rev) ================================================ FILE: src/tasks/decoders.py ================================================ """Decoder heads. """ import torch import torch.nn as nn import torch.nn.functional as F import src.models.nn.utils as U import src.utils as utils import src.utils.train log = src.utils.train.get_logger(__name__) class Decoder(nn.Module): """This class doesn't do much but just signals the interface that Decoders are expected to adhere to TODO: is there a way to enforce the signature of the forward method? """ def forward(self, x, **kwargs): """ x: (batch, length, dim) input tensor state: additional state from the model backbone *args, **kwargs: additional info from the dataset Returns: y: output tensor *args: other arguments to pass into the loss function """ return x def step(self, x): """ x: (batch, dim) """ return self.forward(x.unsqueeze(1)).squeeze(1) class SequenceDecoder(Decoder): def __init__( self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last", conjoin_train=False, conjoin_test=False ): super().__init__() self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) if l_output is None: self.l_output = None self.squeeze = False elif l_output == 0: # Equivalent to getting an output of length 1 and then squeezing self.l_output = 1 self.squeeze = True else: assert l_output > 0 self.l_output = l_output self.squeeze = False self.use_lengths = use_lengths self.mode = mode if mode == 'ragged': assert not use_lengths self.conjoin_train = conjoin_train self.conjoin_test = conjoin_test def forward(self, x, state=None, lengths=None, l_output=None): """ x: (n_batch, l_seq, d_model) or potentially (n_batch, l_seq, d_model, 2) if using rc_conjoin Returns: (n_batch, l_output, d_output) """ if self.l_output is None: if l_output is not None: assert isinstance(l_output, int) # Override by pass in else: # Grab entire output 
l_output = x.size(1) squeeze = False else: l_output = self.l_output squeeze = self.squeeze if self.mode == "last": def restrict(x_seq): """Use last l_output elements of sequence.""" return x_seq[..., -l_output:, :] elif self.mode == "first": def restrict(x_seq): """Use first l_output elements of sequence.""" return x_seq[..., :l_output, :] elif self.mode == "pool": def restrict(x_seq): """Pool sequence over a certain range""" L = x_seq.size(1) s = x_seq.sum(dim=1, keepdim=True) if l_output > 1: c = torch.cumsum(x_seq[..., -(l_output - 1):, ...].flip(1), dim=1) c = F.pad(c, (0, 0, 1, 0)) s = s - c # (B, l_output, D) s = s.flip(1) denom = torch.arange( L - l_output + 1, L + 1, dtype=x_seq.dtype, device=x_seq.device ) s = s / denom return s elif self.mode == "sum": # TODO use same restrict function as pool case def restrict(x_seq): """Cumulative sum last l_output elements of sequence.""" return torch.cumsum(x_seq, dim=-2)[..., -l_output:, :] elif self.mode == 'ragged': assert lengths is not None, "lengths must be provided for ragged mode" def restrict(x_seq): """Ragged aggregation.""" # remove any additional padding (beyond max length of any sequence in the batch) return x_seq[..., : max(lengths), :] else: raise NotImplementedError( "Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']" ) # Restrict to actual length of sequence if self.use_lengths: assert lengths is not None x = torch.stack( [ restrict(out[..., :length, :]) for out, length in zip(torch.unbind(x, dim=0), lengths) ], dim=0, ) else: x = restrict(x) if squeeze: assert x.size(1) == 1 x = x.squeeze(1) if self.conjoin_train or (self.conjoin_test and not self.training): x, x_rc = x.chunk(2, dim=-1) x = self.output_transform(x.squeeze()) x_rc = self.output_transform(x_rc.squeeze()) x = (x + x_rc) / 2 else: x = self.output_transform(x) return x def step(self, x, state=None): # Ignore all length logic x_fwd = self.output_transform(x.mean(dim=1)) x_rc = self.output_transform(x.flip(dims=[1, 
2]).mean(dim=1)).flip(dims=[1]) x_out = (x_fwd + x_rc) / 2 return x_out # For every type of encoder/decoder, specify: # - constructor class # - list of attributes to grab from dataset # - list of attributes to grab from model registry = { "stop": Decoder, "id": nn.Identity, "linear": nn.Linear, "sequence": SequenceDecoder, } model_attrs = { "linear": ["d_output"], "sequence": ["d_output"], "nd": ["d_output"], "retrieval": ["d_output"], "state": ["d_state", "state_to_tensor"], "forecast": ["d_output"], "token": ["d_output"], } dataset_attrs = { "linear": ["d_output"], "sequence": ["d_output", "l_output"], "nd": ["d_output"], "retrieval": ["d_output"], "state": ["d_output"], "forecast": ["d_output", "l_output"], "token": ["d_output"], } def _instantiate(decoder, model=None, dataset=None): """Instantiate a single decoder""" if decoder is None: return None if isinstance(decoder, str): name = decoder else: name = decoder["_name_"] # Extract arguments from attribute names dataset_args = utils.config.extract_attrs_from_obj( dataset, *dataset_attrs.get(name, []) ) model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) # Instantiate decoder obj = utils.instantiate(registry, decoder, *model_args, *dataset_args) return obj def instantiate(decoder, model=None, dataset=None): """Instantiate a full decoder config, e.g. handle list of configs Note that arguments are added in reverse order compared to encoder (model first, then dataset) """ decoder = utils.to_list(decoder) return U.PassthroughSequential( *[_instantiate(d, model=model, dataset=dataset) for d in decoder] ) ================================================ FILE: src/tasks/encoders.py ================================================ from torch import nn import src.models.nn.utils as U import src.utils as utils class Encoder(nn.Module): """Encoder abstraction Accepts a tensor and optional kwargs. Other than the main tensor, all other arguments should be kwargs. 
Returns a tensor and optional kwargs. Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting kwargs are accumulated and passed into the model backbone. """ def forward(self, x, **kwargs): """ x: input tensor *args: additional info from the dataset (e.g. sequence lengths) Returns: y: output tensor *args: other arguments to pass into the model backbone """ return x, {} # For every type of encoder/decoder, specify: # - constructor class # - list of attributes to grab from dataset # - list of attributes to grab from model registry = { "stop": Encoder, "id": nn.Identity, "embedding": nn.Embedding, "linear": nn.Linear, } dataset_attrs = { "embedding": ["n_tokens"], "linear": ["d_input"], # TODO make this d_data? "class": ["n_classes"], "time": ["n_tokens_time"], "onehot": ["n_tokens"], "conv1d": ["d_input"], "patch2d": ["d_input"], } model_attrs = { "embedding": ["d_model"], "linear": ["d_model"], "position": ["d_model"], "class": ["d_model"], "time": ["d_model"], "onehot": ["d_model"], "conv1d": ["d_model"], "patch2d": ["d_model"], "timestamp_embedding": ["d_model"], "layer": ["d_model"], } def _instantiate(encoder, dataset=None, model=None): """Instantiate a single encoder""" if encoder is None: return None if isinstance(encoder, str): name = encoder else: name = encoder["_name_"] # Extract dataset/model arguments from attribute names dataset_args = utils.config.extract_attrs_from_obj( dataset, *dataset_attrs.get(name, []) ) model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) # Instantiate encoder obj = utils.instantiate(registry, encoder, *dataset_args, *model_args) return obj def instantiate(encoder, dataset=None, model=None): encoder = utils.to_list(encoder) return U.PassthroughSequential( *[_instantiate(e, dataset=dataset, model=model) for e in encoder] ) ================================================ FILE: src/tasks/metrics.py 
================================================ import math from functools import partial import torch import torch.nn.functional as F import torchmetrics.functional as tm_f from sklearn.metrics import f1_score, roc_auc_score, matthews_corrcoef from torchmetrics.classification import MulticlassRecall, MulticlassPrecision from torchmetrics import Metric class CorrectAggregatedMetric(Metric): """This is needed to calculate some metrics b/c small batch sizes cause aggregation via a simple average to be off, as some classes might not be present in batch but will get penalized with a 0.""" def __init__(self, class_idx: int, dist_sync_on_step=False): # call `self.add_state`for every internal state that is needed for the metrics computations # dist_reduce_fx indicates the function that should be used to reduce # state from multiple processes super().__init__(dist_sync_on_step=dist_sync_on_step) self.class_idx = torch.tensor(class_idx) self.add_state("numerator", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("denominator", default=torch.tensor(0.0), dist_reduce_fx="sum") def _update(self, numerator, denominator, preds, y) -> tuple: raise NotImplemented def update(self, logits: torch.Tensor, y: torch.Tensor): # update metric states preds = torch.argmax(logits, dim=-1) logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) assert preds.shape == y.shape, f"preds shape {preds.shape} != y shape {y.shape}" self.numerator, self.denominator = self._update(self.numerator, self.denominator, preds, y) def compute(self): # compute final result value = self.numerator.float() / self.denominator if self.denominator > 0 else torch.tensor(0.0) return value def reset(self): self.numerator = torch.tensor(0.0) self.denominator = torch.tensor(0.0) class AccuracyPerClass(CorrectAggregatedMetric): """Calculate per class accuracy, i.e. 
P(y_hat = class_idx AND y = class_idx OR y_hat != class_idx AND y != class_idx) """ def _update(self, numerator, denominator, preds, y) -> tuple: # Filter down to the class of interest class_idx = self.class_idx relevant_idxs = (y == class_idx) numerator += (preds[relevant_idxs] == class_idx).sum() denominator += relevant_idxs.sum() relevant_idxs = (y != class_idx) numerator += (preds[relevant_idxs] != class_idx).sum() denominator += relevant_idxs.sum() return numerator, denominator class PrecisionPerClass(CorrectAggregatedMetric): """Calculate per class precision, i.e. P(y_hat = y | y_hat = class_idx) """ def _update(self, numerator, denominator, preds, y) -> tuple: # Filter down to the class of interest class_idx = self.class_idx relevant_idxs = (preds == class_idx) numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum() denominator += relevant_idxs.sum() return numerator, denominator class RecallPerClass(CorrectAggregatedMetric): """Calculate per class recall, i.e. P(y_hat = y | y = class_idx) """ def _update(self, numerator, denominator, preds, y) -> tuple: # Filter down to the class of interest class_idx = self.class_idx relevant_idxs = (y == class_idx) numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum() denominator += relevant_idxs.sum() return numerator, denominator def mcc(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) y_hat = torch.argmax(logits, dim=-1) return matthews_corrcoef(y.cpu().numpy(), y_hat.cpu().numpy()) def last_k_ppl(logits, y, seq_len=1024, k=None): ''' Calculate perplexity for last k tokens in a sequence. 
logits: (batch_size * seq_len, vocab_size), note, already flattened y: (batch_size * seq_len), note, already flattened seq_len: int, length of each sequence in the batch k: if None, use all tokens in sequence returns: (batch_size,) ppl for each sequence in the batch ''' if k is None: k = 0 # use the entire sequence # need to reshape logits and y to be (batch_size, seq_len, vocab_size) and (batch_size, seq_len) # respectively # breakpoint() logits = logits.view(-1, seq_len, logits.shape[-1]) y = y.view(-1, seq_len) # only use the last k values of seq dim in logits and y logits = logits[:, -k:, :] y = y[:, -k:] # reshape to flatten the batch and seq_len dimensions logits = logits.reshape(-1, logits.shape[-1]) y = y.reshape(-1) # get avg and put on cpu return F.cross_entropy(logits, y, reduction='none').view(y.shape[0], -1).mean().exp().cpu() def _student_t_map(mu, sigma, nu): sigma = F.softplus(sigma) nu = 2.0 + F.softplus(nu) return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1) def student_t_loss(outs, y): mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2] mu, sigma, nu = _student_t_map(mu, sigma, nu) y = y.squeeze(axis=-1) nup1_half = (nu + 1.0) / 2.0 part1 = 1.0 / nu * torch.square((y - mu) / sigma) Z = ( torch.lgamma(nup1_half) - torch.lgamma(nu / 2.0) - 0.5 * torch.log(math.pi * nu) - torch.log(sigma) ) ll = Z - nup1_half * torch.log1p(part1) return -ll.mean() def gaussian_ll_loss(outs, y): mu, sigma = outs[..., 0], outs[..., 1] y = y.squeeze(axis=-1) sigma = F.softplus(sigma) ll = -1.0 * ( torch.log(sigma) + 0.5 * math.log(2 * math.pi) + 0.5 * torch.square((y - mu) / sigma) ) return -ll.mean() def binary_cross_entropy(logits, y): # BCE loss requires squeezing last dimension of logits so it has the same shape as y # requires y to be float, since it's overloaded to represent a probability return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float()) def binary_accuracy(logits, y): return torch.eq(logits.squeeze(-1) >= 0, 
y).float().mean() def padded_cross_entropy(logits, y, pad_mask, pad_value=-1): """Will ignore the pad value in label (eg, -1) logits: (batch_size, seq_len, vocab_size) y: (batch_size, seq_len) pad_mask: (batch_size, seq_len) """ # need to apply pad mask to y y_pad = y + pad_mask * pad_value logits = logits.view(-1, logits.shape[-1]) y_pad = y_pad.view(-1) return F.cross_entropy(logits, y_pad, ignore_index=pad_value) def cross_entropy(logits, y, ignore_index=-100): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) return F.cross_entropy(logits, y, ignore_index=ignore_index) def soft_cross_entropy(logits, y, label_smoothing=0.0): logits = logits.view(-1, logits.shape[-1]) # target is now 2d (no target flattening) return F.cross_entropy(logits, y, label_smoothing=label_smoothing) def accuracy(logits, y): logits = logits.view(-1, logits.shape[-1]) preds = torch.argmax(logits, dim=-1) if y.numel() > logits.shape[0]: # Mixup leads to this case: use argmax class y = y.argmax(dim=-1) y = y.view(-1) return torch.eq(preds, y).float().mean() def accuracy_ignore_index(logits, y, ignore_index=-100): num_classes = logits.shape[-1] preds = torch.argmax(logits, dim=-1) logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) accuracy = tm_f.classification.accuracy(preds, y, 'multiclass', num_classes=num_classes, ignore_index=ignore_index, average='micro') return accuracy def accuracy_at_k(logits, y, k=1): logits = logits.view(-1, logits.shape[-1]) if y.numel() > logits.shape[0]: # Mixup leads to this case: use argmax class y = y.argmax(dim=-1) y = y.view(-1) return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean() def f1_binary(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) y_hat = torch.argmax(logits, dim=-1) return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="binary") def f1_macro(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) y_hat = torch.argmax(logits, dim=-1) return 
f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="macro") def f1_micro(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) y_hat = torch.argmax(logits, dim=-1) return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro") def roc_auc_macro(logits, y): logits = logits.view( -1, logits.shape[-1] ).detach() # KS: had to add detach to eval while training y = y.view(-1) return roc_auc_score( y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="macro" ) def roc_auc_micro(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) return roc_auc_score( y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="micro" ) def mse(outs, y, len_batch=None): # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 # outs = outs.squeeze(-1) if len(y.shape) < len(outs.shape): assert outs.shape[-1] == 1 outs = outs.squeeze(-1) if len_batch is None: return F.mse_loss(outs, y) else: # Computes the loss of the first `lens` items in the batches # TODO document the use case of this mask = torch.zeros_like(outs, dtype=torch.bool) for i, l in enumerate(len_batch): mask[i, :l, :] = 1 outs_masked = torch.masked_select(outs, mask) y_masked = torch.masked_select(y, mask) return F.mse_loss(outs_masked, y_masked) def forecast_rmse(outs, y, len_batch=None): # TODO: generalize, currently for Monash dataset return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean() def mae(outs, y, len_batch=None): # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 # outs = outs.squeeze(-1) if len(y.shape) < len(outs.shape): assert outs.shape[-1] == 1 outs = outs.squeeze(-1) if len_batch is None: return F.l1_loss(outs, y) else: # Computes the loss of the first `lens` items in the batches mask = torch.zeros_like(outs, dtype=torch.bool) for i, l in enumerate(len_batch): mask[i, :l, :] = 1 outs_masked = torch.masked_select(outs, mask) y_masked = torch.masked_select(y, mask) return F.l1_loss(outs_masked, y_masked) # 
Metrics that can depend on the loss def loss(x, y, loss_fn): """ This metric may be useful because the training loss may add extra regularization (e.g. weight decay implemented as L2 penalty), while adding this as a metric skips the additional losses """ return loss_fn(x, y) def bpb(x, y, loss_fn): """ bits per byte (image density estimation, speech generation, char LM) """ return loss_fn(x, y) / math.log(2) def ppl(x, y, loss_fn): return torch.exp(loss_fn(x, y)) # should have a better way to do this output_metric_fns = { "binary_cross_entropy": binary_cross_entropy, "cross_entropy": cross_entropy, "padded_cross_entropy": padded_cross_entropy, "binary_accuracy": binary_accuracy, # "precision": MulticlassPrecision, # "precision_species": partial(MulticlassPrecision, task='multiclass', average=None), "precision_species": partial(MulticlassPrecision, average=None), # "recall_species": partial(MulticlassRecall, task='multiclass', average=None), "recall_species": partial(MulticlassRecall, average=None), # "precision_class": partial(MulticlassPrecision, average=None), "precision_per_class": PrecisionPerClass, "recall": MulticlassRecall, "recall_per_class": RecallPerClass, "accuracy": accuracy, "accuracy_per_class": AccuracyPerClass, "accuracy_ignore_index": accuracy_ignore_index, 'accuracy@3': partial(accuracy_at_k, k=3), 'accuracy@5': partial(accuracy_at_k, k=5), 'accuracy@10': partial(accuracy_at_k, k=10), "eval_loss": loss, "mcc": mcc, "mse": mse, "mae": mae, "forecast_rmse": forecast_rmse, "f1_binary": f1_binary, "f1_macro": f1_macro, "f1_micro": f1_micro, "roc_auc_macro": roc_auc_macro, "roc_auc_micro": roc_auc_micro, "soft_cross_entropy": soft_cross_entropy, # only for pytorch 1.10+ "student_t": student_t_loss, "gaussian_ll": gaussian_ll_loss, } loss_metric_fns = { "loss": loss, "bpb": bpb, "ppl": ppl, } metric_fns = {**output_metric_fns, **loss_metric_fns} # TODO py3.9 ================================================ FILE: src/tasks/tasks.py 
================================================ import inspect from typing import List import torch.nn as nn from einops import rearrange import src.models.nn.utils as U import src.tasks.metrics as M import torchmetrics as tm from src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax from src.tasks.torchmetrics import torchmetric_fns as tm_mine from src.utils.config import to_list, instantiate from torchmetrics import MetricCollection class BaseTask: """ Abstract class that takes care of: - loss function - arbitrary metrics - forward pass - (optional) encoder module that interfaces with dataset (inputs) and model - (optional) decoder module that interfaces with dataset (targets) and model """ encoder = None decoder = None def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None): """ This class is allowed to grab attributes directly off a constructed dataset and model object """ self.dataset = dataset self.model = model if metrics is None: metrics = [] self.metric_names = to_list(metrics) if torchmetrics is None: torchmetrics = [] self.torchmetric_names = to_list(torchmetrics) self._tracked_torchmetrics = {} # The decoder might pass through arguments that the loss needs (e.g. sequence lengths) # but might also pass through extraneous arguments (e.g. sampling rate) # Wrap loss and metrics so that they accept kwargs and # Create loss function self.loss = instantiate(M.output_metric_fns, loss, partial=True) self.loss = U.discard_kwargs(self.loss) if loss_val is not None: self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True) self.loss_val = U.discard_kwargs(self.loss_val) torchmetrics = MetricCollection(self._init_torchmetrics()) self.train_torchmetrics = torchmetrics.clone(prefix='train/') self.val_torchmetrics = torchmetrics.clone(prefix='val/') self.test_torchmetrics = torchmetrics.clone(prefix='test/') def _init_torchmetrics(self): """ Instantiate torchmetrics. 
""" tracked_torchmetrics = {} for name in self.torchmetric_names: if name in tm_mine: tracked_torchmetrics[name] = tm_mine[name]() elif name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']: tracked_torchmetrics[name] = getattr(tm, name)( average='macro', num_classes=self.dataset.d_output, compute_on_step=False ) elif name in ['MultilabelAUROC', 'MultilabelAveragePrecision']: tracked_torchmetrics[name] = getattr(tm, name)( average='macro', num_labels=self.dataset.d_output ) elif '@' in name: k = int(name.split('@')[1]) mname = name.split('@')[0] tracked_torchmetrics[name] = getattr(tm, mname)( average='macro', num_classes=self.dataset.d_output, compute_on_step=False, top_k=k ) else: tracked_torchmetrics[name] = getattr(tm, name)(compute_on_step=False) return tracked_torchmetrics def _reset_torchmetrics(self, prefix=None): """ Reset torchmetrics for a prefix associated with a particular dataloader (e.g. train, val, test). Generally do this at the start of an epoch. """ all_prefixes = [prefix] if prefix is not None else self._tracked_torchmetrics for prefix in all_prefixes: if prefix in self._tracked_torchmetrics: self._tracked_torchmetrics[prefix].reset() def get_torchmetrics(self, prefix): """ Compute torchmetrics for a prefix associated with a particular dataloader (e.g. train, val, test). Generally do this at the end of an epoch. """ return {name: self._tracked_torchmetrics[prefix][name].compute() for name in self.torchmetric_names} def torchmetrics(self, x, y, prefix, loss=None): """ Update torchmetrics with new x, y . Prefix corresponds to a particular dataloader (e.g. train, val, test). Generally call this every batch. 
""" if prefix not in self._tracked_torchmetrics: self._init_torchmetrics(prefix) self._tracked_torchmetrics[prefix](x, y, loss=loss) # for name in self.torchmetric_names: # if name.startswith('Accuracy'): # if len(x.shape) > 2: # # Multi-dimensional, multi-class # self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze()) # continue # self._tracked_torchmetrics[prefix][name].update(x, y) def get_torchmetrics(self, prefix): return self._tracked_torchmetrics[prefix] def metrics(self, x, y, **kwargs): """ Metrics are just functions output metrics are a function of output and target loss metrics are a function of loss (e.g. perplexity) """ output_metrics = { name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) for name in self.metric_names if name in M.output_metric_fns } loss_metrics = { name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs) for name in self.metric_names if name in M.loss_metric_fns } return {**output_metrics, **loss_metrics} def forward(self, batch, encoder, model, decoder, _state): """Passes a batch through the encoder, backbone, and decoder""" # z holds arguments such as sequence length x, y, *z = batch # z holds extra dataloader info such as resolution if len(z) == 0: z = {} else: assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" z = z[0] # w can model-specific constructions, such as key_padding_mask for transformers or state for RNNs x, w = encoder(x, **z) x, state = model(x, **w, state=_state) self._state = state x, w = decoder(x, state=state, **z) return x, y, w class Scalar(nn.Module): def __init__(self, c=1): super().__init__() self.c = c def forward(self, x): return x * self.c class LMTask(BaseTask): def forward(self, batch, encoder, model, decoder, _state): """Passes a batch through the encoder, backbone, and decoder""" # z holds arguments such as sequence length x, y, *z = batch # z holds extra dataloader info such as resolution if 
len(z) == 0: z = {} else: assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" z = z[0] # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs x, w = encoder(x, **z) # Needed for Mamba (open-source repo version) if "state" in inspect.signature(model.forward).parameters.keys(): x, state = model(x, **w, state=_state) else: x = model(x, **w) state = None self._state = state x, w = decoder(x, state=state, **z) if hasattr(x, 'logits'): x = x.logits x = rearrange(x, '... C -> (...) C') y = rearrange(y, '... -> (...)') return x, y, w class MultiClass(BaseTask): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.continual_metrics = {} for name in self.metric_names: if name.endswith('_per_class'): for spec_idx, spec in enumerate(self.dataset.species): self.continual_metrics[name + '_' + spec] = M.output_metric_fns[name](spec_idx) elif name in ['precision_species', 'recall_species']: self.continual_metrics[name] = M.output_metric_fns[name](num_classes=len(self.dataset.species)) def metrics(self, x, y, **kwargs): output_metrics = {} for name in self.metric_names: if name in M.output_metric_fns: if name.endswith('_per_class'): for spec_idx, spec in enumerate(self.dataset.species): self.continual_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].to( x.device) self.continual_metrics[name + '_' + spec].update(x, y) output_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].compute() elif name in ['precision_species', 'recall_species']: self.continual_metrics[name] = self.continual_metrics[name].to(x.device) metrics = self.continual_metrics[name](x, y) for spec_idx, spec in enumerate(self.dataset.species): output_metrics[name[:-7] + spec] = metrics[spec_idx] else: output_metrics[name] = U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) loss_metrics = { name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, 
**kwargs) for name in self.metric_names if name in M.loss_metric_fns } return {**output_metrics, **loss_metrics} def _reset_torchmetrics(self, prefix=None): super()._reset_torchmetrics(prefix) for name in self.metric_names: if name.endswith('_per_class'): for spec_idx, spec in enumerate(self.dataset.species): self.continual_metrics[name + '_' + spec].reset() class HG38Task(LMTask): def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None, last_k_ppl=None, per_token_ppl=None): """ Extending LMTask to add custom metrics for HG38 task last_k_ppl: config for custom ppl, with hparams to pass with it per_token_ppl: config for per token ppl calc, with list of k (ppls) to track """ self.dataset = dataset self.model = model if metrics is None: metrics = [] self.metric_names = to_list(metrics) self.last_k_ppl = last_k_ppl self.per_token_ppl = per_token_ppl if torchmetrics is None: torchmetrics = [] self.torchmetric_names = to_list(torchmetrics) self._tracked_torchmetrics = {} # The decoder might pass through arguments that the loss needs (e.g. sequence lengths) # but might also pass through extraneous arguments (e.g. 
sampling rate) # Wrap loss and metrics so that they accept kwargs and # Create loss function self.loss = instantiate(M.output_metric_fns, loss, partial=True) self.loss = U.discard_kwargs(self.loss) if loss_val is not None: self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True) self.loss_val = U.discard_kwargs(self.loss_val) torchmetrics = MetricCollection(self._init_torchmetrics()) self.train_torchmetrics = torchmetrics.clone(prefix='train/') self.val_torchmetrics = torchmetrics.clone(prefix='val/') self.test_torchmetrics = torchmetrics.clone(prefix='test/') # Create custom metrics for last k ppl # last_k_ppl is a list of dicts (configs), so loop through them if self.last_k_ppl is not None: self.custom_ppl_dict = {} for k in self.last_k_ppl: key_name = "last_" + str(k) + "_ppl" # create config custom_ppl_config = {"_name_": "last_k_ppl", "k": k, "seq_len": self.dataset.max_length} k_ppl_fn = instantiate(M.output_metric_fns, custom_ppl_config, partial=True) k_ppl_fn = U.discard_kwargs(k_ppl_fn) self.custom_ppl_dict[key_name] = k_ppl_fn # Create custom metric for per token ppl if self.per_token_ppl is not None: per_token_ppl_config = {"_name_": "per_token_ppl", "ks": self.per_token_ppl["ks"], "seq_len": self.dataset.max_length} per_token_fn = instantiate(M.output_metric_fns, per_token_ppl_config, partial=True) per_token_fn = U.discard_kwargs(per_token_fn) self.per_token_fn = per_token_fn def metrics(self, x, y, **kwargs): """ Need to modify metrics to include custom metrics """ output_metrics = { name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) for name in self.metric_names if name in M.output_metric_fns } loss_metrics = { name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs) for name in self.metric_names if name in M.loss_metric_fns } # loop through all custom ppls and add them to output_metrics if self.last_k_ppl is not None: for key_name, k_ppl_fn in self.custom_ppl_dict.items(): output_metrics[key_name] = 
k_ppl_fn(x, y, **kwargs) # loop through all custom ppls and add them to output_metrics if self.per_token_ppl is not None: # returns k ppl values, (averaged over batch) per_k_ppl = self.per_token_fn(x, y, **kwargs) # loop over ks to log metric for ind, k in enumerate(self.per_token_ppl["ks"]): key_name = "ppl_at_{}".format(k) output_metrics[key_name] = per_k_ppl[ind] # should be in order return {**output_metrics, **loss_metrics} class AdaptiveLMTask(BaseTask): def __init__( self, div_val, cutoffs: List[int], tie_weights: bool, tie_projs: List[bool], init_scale=1.0, bias_scale=0.0, dropemb=0.0, dropsoft=0.0, **kwargs, ): super().__init__(**kwargs) n_tokens = self.dataset.n_tokens d_model = self.model.d_model d_output = self.model.d_output encoder = AdaptiveEmbedding( n_tokens, d_model, d_model, cutoffs=cutoffs, div_val=div_val, init_scale=init_scale, dropout=dropemb, ) if tie_weights: assert d_model == d_output emb_layers = [i.weight for i in encoder.emb_layers] else: emb_layers = None # Construct decoder/loss emb_projs = encoder.emb_projs loss = ProjectedAdaptiveLogSoftmax( n_tokens, d_output, d_output, cutoffs, div_val=div_val, tie_projs=tie_projs, out_projs=emb_projs, out_layers_weights=emb_layers, bias_scale=bias_scale, dropout=dropsoft, ) self.encoder = encoder self.loss = loss registry = { 'base': BaseTask, 'multiclass': MultiClass, 'lm': LMTask, 'hg38': HG38Task, } ================================================ FILE: src/tasks/torchmetrics.py ================================================ # Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py # But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) # Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py # But we pass in the loss to avoid recomputation from typing import Any, Dict, Optional import torch from torch import Tensor from torchmetrics import Metric try: from 
flash_attn.losses.cross_entropy import CrossEntropyLoss except ImportError: CrossEntropyLoss = torch.nn.CrossEntropyLoss try: from apex.transformer import parallel_state except ImportError: parallel_state = None class Perplexity(Metric): r""" Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits per word a model needs to represent the sample. Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Examples: >>> import torch >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) >>> target[0, 6:] = -100 >>> metric = Perplexity(ignore_index=-100) >>> metric(preds, target) tensor(5.2545) """ is_differentiable = True higher_is_better = False full_state_update = False total_log_probs: Tensor count: Tensor def __init__(self, **kwargs: Dict[str, Any]): super().__init__(**kwargs) self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") self.loss_fn = CrossEntropyLoss() def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore """Compute and store intermediate statistics for Perplexity. Args: preds: Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. target: Ground truth values with a shape [batch_size, seq_len]. """ count = target.numel() if loss is None: loss = self.loss_fn(preds, target) self.total_log_probs += loss.double() * count self.count += count def compute(self) -> Tensor: """Compute the Perplexity. Returns: Perplexity """ return torch.exp(self.total_log_probs / self.count) class NumTokens(Metric): """Keep track of how many tokens we've seen. """ # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch # of the next epoch. 
# Right now the hack is that we override reset(), which would mess up the forward method. # We then override forward to do the right thing. is_differentiable = False higher_is_better = False full_state_update = False count: Tensor def __init__(self, **kwargs: Dict[str, Any]): super().__init__(**kwargs) self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", persistent=True) # We want the count to be saved to state-dict if parallel_state is not None and not parallel_state.is_unitialized(): self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size() else: self.tensor_parallel_world_size = 1 def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore self.count += target.numel() // self.tensor_parallel_world_size def compute(self) -> Tensor: return self.count def reset(self): count = self.count super().reset() self.count = count # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """forward computation using single call to `update` to calculate the metric value on the current batch and accumulate global state. This can be done when the global metric state is a sinple reduction of batch states. """ self.update(*args, **kwargs) return self.compute() torchmetric_fns = { "perplexity": Perplexity, "num_tokens": NumTokens, } ================================================ FILE: src/utils/__init__.py ================================================ from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate ================================================ FILE: src/utils/config.py ================================================ """Utilities for dealing with collection objects (lists, dicts) and configs. 
""" import functools from typing import Sequence, Mapping, Callable import hydra from omegaconf import ListConfig, DictConfig # TODO this is usually used in a pattern where it's turned into a list, so can just do that here def is_list(x): return isinstance(x, Sequence) and not isinstance(x, str) def is_dict(x): return isinstance(x, Mapping) def to_dict(x, recursive=True): """Convert Sequence or Mapping object to dict lists get converted to {0: x[0], 1: x[1], ...} """ if is_list(x): x = {i: v for i, v in enumerate(x)} if is_dict(x): if recursive: return {k: to_dict(v, recursive=recursive) for k, v in x.items()} else: return dict(x) else: return x def to_list(x, recursive=False): """Convert an object to list. If Sequence (e.g. list, tuple, Listconfig): just return it Special case: If non-recursive and not a list, wrap in list """ if is_list(x): if recursive: return [to_list(_x) for _x in x] else: return list(x) else: if recursive: return x else: return [x] def extract_attrs_from_obj(obj, *attrs): if obj is None: assert len(attrs) == 0 return [] return [getattr(obj, attr, None) for attr in attrs] def auto_assign_attrs(cls, **kwargs): for k, v in kwargs.items(): setattr(cls, k, v) def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): """ registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'}) config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor wrap: wrap the target class (e.g. 
ema optimizer or tasks.wrap) *args, **kwargs: additional arguments to override the config to pass into the target constructor """ # Case 1: no config if config is None: return None # Case 2a: string means _name_ was overloaded if isinstance(config, str): _name_ = None _target_ = registry[config] config = {} # Case 2b: grab the desired callable from name else: _name_ = config.pop("_name_") _target_ = registry[_name_] # Retrieve the right constructor automatically based on type if isinstance(_target_, str): fn = hydra.utils.get_method(path=_target_) elif isinstance(_target_, Callable): fn = _target_ else: raise NotImplementedError("instantiate target must be string or callable") # Instantiate object if wrap is not None: fn = wrap(fn) obj = functools.partial(fn, *args, **config, **kwargs) # Restore _name_ if _name_ is not None: config["_name_"] = _name_ if partial: return obj else: return obj() def get_class(registry, _name_): return hydra.utils.get_class(path=registry[_name_]) def omegaconf_filter_keys(d, fn=None): """Only keep keys where fn(key) is True. Support nested DictConfig. # TODO can make this inplace? 
""" if fn is None: fn = lambda _: True if is_list(d): return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) elif is_dict(d): return DictConfig( {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} ) else: return d ================================================ FILE: src/utils/optim/schedulers.py ================================================ """Custom learning rate schedulers""" import math import warnings import torch from timm.scheduler import CosineLRScheduler # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): self.warmup_step = warmup_step super().__init__(optimizer, T_max - warmup_step, eta_min, *kwargs) # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to # self.last_epoch - self.warmup_step. def get_lr(self): if not self._get_lr_called_within_step: warnings.warn("To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning) if self.last_epoch == self.warmup_step: # also covers the case where both are 0 return self.base_lrs elif self.last_epoch < self.warmup_step: return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs] elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0: return [group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)] return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) * (group['lr'] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups] _get_closed_form_lr = None def InvSqrt(optimizer, warmup_step): """ Originally used for Transformer (in Attention is all you need) """ def lr_lambda(step): # return a 
multiplier instead of a learning rate if step == warmup_step: # also covers the case where both are 0 return 1. else: return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) def Constant(optimizer, warmup_step): def lr_lambda(step): if step == warmup_step: # also covers the case where both are 0 return 1. else: return 1. if step > warmup_step else (step + 1) / warmup_step return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. It supports resuming as well. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._last_epoch = -1 self.step(epoch=0) def step(self, epoch=None): if epoch is None: self._last_epoch += 1 else: self._last_epoch = epoch # We call either step or step_update, depending on whether we're using the scheduler every # epoch or every step. # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set # scheduler interval to "step", then the learning rate update will be wrong. if self.t_in_epochs: super().step(epoch=self._last_epoch) else: super().step_update(num_updates=self._last_epoch) ================================================ FILE: src/utils/optim_groups.py ================================================ """Utilities for special optimizer hyperparameters. 
group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary """ import inspect import torch.nn as nn import hydra def add_optimizer_hooks( model, bias_weight_decay=False, normalization_weight_decay=False, ): """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for normalization parameters if normalization_weight_decay==False """ # Separate out all parameters to those that will and won't experience regularizing weight decay blacklist_weight_modules = (nn.Embedding, ) if not normalization_weight_decay: blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, # Not compatible with Pytorch 1.8.1 # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, nn.GroupNorm, nn.SyncBatchNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm, nn.LocalResponseNorm) for mn, m in model.named_modules(): for pn, p in m.named_parameters(): if (not bias_weight_decay and pn.endswith('bias')) \ or getattr(p, '_no_weight_decay', False) \ or isinstance(m, blacklist_weight_modules): setattr(p, "_optim", {"weight_decay": 0.0}) def group_parameters_for_optimizer( model, optimizer_cfg, bias_weight_decay=False, normalization_weight_decay=False, ): """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for normalization parameters if normalization_weight_decay==False """ # Get the weight decay from the config, or from the default value of the optimizer constructor # if it's not specified in the config. 
if 'weight_decay' in optimizer_cfg: weight_decay = optimizer_cfg.weight_decay else: # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) if 'weight_decay' in signature.parameters: weight_decay = signature.parameters['weight_decay'].default if weight_decay is inspect.Parameter.empty: weight_decay = 0.0 else: weight_decay = 0.0 # If none of the parameters have weight decay anyway, and there are no parameters with special # optimization params if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): return model.parameters() skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') else set()) # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 """ This long function is unfortunately doing something very simple and is being very defensive: We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights). We are then returning the PyTorch optimizer object. 
""" # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() special = set() whitelist_weight_modules = (nn.Linear, ) blacklist_weight_modules = (nn.Embedding, ) if not normalization_weight_decay: blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, # Not compatible with Pytorch 1.8.1 # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, nn.GroupNorm, nn.SyncBatchNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm, nn.LocalResponseNorm) for mn, m in model.named_modules(): for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name if not p.requires_grad: continue # frozen weights if hasattr(p, '_optim'): special.add(fpn) elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): no_decay.add(fpn) elif getattr(p, '_no_weight_decay', False): no_decay.add(fpn) elif not bias_weight_decay and pn.endswith('bias'): no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) elif isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} # special case the position embedding parameter in the root GPT module as not decayed if 'pos_emb' in param_dict: no_decay.add('pos_emb') # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys() decay &= param_dict.keys() decay |= (param_dict.keys() - no_decay - special) # validate that we considered every parameter inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" 
assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" if weight_decay == 0.0 or not no_decay: param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], "weight_decay": weight_decay}] else: param_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, ] # Add parameters with special hyperparameters # Unique dicts hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] for hp in hps: params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] param_groups.append({"params": params, **hp}) return param_groups ================================================ FILE: src/utils/registry.py ================================================ """Class registry for models, layers, optimizers, and schedulers. 
""" optimizer = { "adam": "torch.optim.Adam", "adamw": "torch.optim.AdamW", "rmsprop": "torch.optim.RMSprop", "sgd": "torch.optim.SGD", "lamb": "src.utils.optim.lamb.JITLamb", } scheduler = { "constant": "transformers.get_constant_schedule", "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", "step": "torch.optim.lr_scheduler.StepLR", "multistep": "torch.optim.lr_scheduler.MultiStepLR", "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", "constant_warmup": "transformers.get_constant_schedule_with_warmup", "linear_warmup": "transformers.get_linear_schedule_with_warmup", "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", } model = { # Pre-training LM head models "hyena_lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", "mamba_lm": "mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel", "caduceus_lm": "caduceus.modeling_caduceus.CaduceusForMaskedLM", # Downstream task embedding backbones "dna_embedding": "src.models.sequence.dna_embedding.DNAEmbeddingModel", "dna_embedding_mamba": "src.models.sequence.dna_embedding.DNAEmbeddingModelMamba", "dna_embedding_caduceus": "src.models.sequence.dna_embedding.DNAEmbeddingModelCaduceus", # Baseline for genomics benchmark "genomics_benchmark_cnn": "src.models.baseline.genomics_benchmark_cnn.GenomicsBenchmarkCNN", } layer = { "id": "src.models.sequence.base.SequenceIdentity", "ff": "src.models.sequence.ff.FF", "hyena": "src.models.sequence.hyena.HyenaOperator", "hyena-filter": "src.models.sequence.hyena.HyenaFilter", } callbacks = { "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", "model_checkpoint_every_n_steps": "pytorch_lightning.callbacks.ModelCheckpoint", "model_checkpoint_every_epoch": "pytorch_lightning.callbacks.ModelCheckpoint", "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", "swa": 
"pytorch_lightning.callbacks.StochasticWeightAveraging", "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", "params": "src.callbacks.params.ParamsLog", "timer": "src.callbacks.timer.Timer", "val_every_n_global_steps": "src.callbacks.validation.ValEveryNGlobalSteps", } model_state_hook = { 'load_backbone': 'src.models.sequence.dna_embedding.load_backbone', } ================================================ FILE: src/utils/train.py ================================================ """ Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """ import json import logging import warnings import rich.syntax import rich.tree import torch.nn as nn from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only from src.utils.config import omegaconf_filter_keys # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging class LoggingContext: def __init__(self, logger, level=None, handler=None, close=True): self.logger = logger self.level = level self.handler = handler self.close = close def __enter__(self): if self.level is not None: self.old_level = self.logger.level self.logger.setLevel(self.level) if self.handler: self.logger.addHandler(self.handler) def __exit__(self, et, ev, tb): if self.level is not None: self.logger.setLevel(self.old_level) if self.handler: self.logger.removeHandler(self.handler) if self.handler and self.close: self.handler.close() # implicit return of None => don't swallow exceptions def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: """Initializes multi-GPU-friendly python logger.x""" logger = logging.getLogger(name) logger.setLevel(level) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup for level 
in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): setattr(logger, level, rank_zero_only(getattr(logger, level))) return logger def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place """A couple of optional utilities, controlled by main config file: - disabling warnings - easier access to debug mode - forcing debug friendly configuration Modifies DictConfig in place. Args: config (DictConfig): Configuration composed by Hydra. """ log = get_logger() # Filter out keys that were used just for interpolation config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) # enable adding new keys to config OmegaConf.set_struct(config, False) # disable python warnings if if config.get("ignore_warnings"): log.info("Disabling python warnings! ") warnings.filterwarnings("ignore") if config.get("debug"): log.info("Running in debug mode! ") config.trainer.fast_dev_run = True # force debugger friendly configuration log.info("Forcing debugger friendly configuration! ") # Debuggers don't like GPUs or multiprocessing if config.trainer.get("gpus"): config.trainer.gpus = 0 if config.loader.get("pin_memory"): config.loader.pin_memory = False if config.loader.get("num_workers"): config.loader.num_workers = 0 # disable adding new keys to config # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things return config @rank_zero_only def print_config( config: DictConfig, resolve: bool = True, save_cfg=True, ) -> None: """Prints content of DictConfig using Rich library and its tree structure. Args: config (DictConfig): Configuration composed by Hydra. resolve (bool, optional): Whether to resolve reference fields of DictConfig. save_cfg (bool, optional): Whether to save the config to a file. 
""" style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) fields = config.keys() for field in fields: branch = tree.add(field, style=style, guide_style=style) config_section = config.get(field) branch_content = str(config_section) if isinstance(config_section, DictConfig): branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) branch.add(rich.syntax.Syntax(branch_content, "yaml")) rich.print(tree) if save_cfg: with open("config_tree.txt", "w") as fp: rich.print(tree, file=fp) with open("model_config.json", "w") as fp: # Save config / model config for use in fine-tuning or testing model_config = { k: v for k, v in OmegaConf.to_container(config.model, resolve=True).items() if not k.startswith("_") or k == "config_path" } json.dump(model_config, fp, indent=4) with open("config.json", "w") as fp: json.dump(OmegaConf.to_container(config, resolve=True), fp, indent=4) def log_optimizer(logger, optimizer, keys): """ Log values of particular keys from the optimizers param groups """ keys = sorted(keys) for i, g in enumerate(optimizer.param_groups): group_hps = {k: g.get(k, None) for k in keys} logger.info(' | '.join([ f"Optimizer group {i}", f"{len(g['params'])} tensors", ] + [f"{k} {v}" for k, v in group_hps.items()])) class OptimModule(nn.Module): """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ def register(self, name, tensor, lr=None, wd=0.0): """Register a tensor with a configurable learning rate and 0 weight decay""" if lr == 0.0: self.register_buffer(name, tensor) else: self.register_parameter(name, nn.Parameter(tensor)) optim = {} if lr is not None: optim["lr"] = lr if wd is not None: optim["weight_decay"] = wd setattr(getattr(self, name), "_optim", optim) ================================================ FILE: train.py ================================================ """Main training entry point for pre-training and downstream fine-tuning. 
""" import json import os import random import time from functools import wraps from typing import Callable, List, Sequence import fsspec import hydra import pytorch_lightning as pl import torch import wandb from omegaconf import OmegaConf from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn import src.models.nn.utils as U import src.utils as utils import src.utils.train from src.dataloaders import SequenceDataset # TODO make registry from src.tasks import decoders, encoders, tasks from src.utils import registry from src.utils.optim_groups import add_optimizer_hooks log = src.utils.train.get_logger(__name__) # Turn on TensorFloat32 (speeds up large model training substantially) import torch.backends torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True OmegaConf.register_new_resolver('eval', eval) OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) OmegaConf.register_new_resolver('min', lambda x, y: min([x, y])) # Lots of annoying hacks to get WandbLogger to continuously retry on failure class DummyExperiment: """Dummy experiment.""" def nop(self, *args, **kw): pass def __getattr__(self, _): return self.nop def __getitem__(self, idx) -> "DummyExperiment": # enables self.logger.experiment[0].add_image(...) return self def __setitem__(self, *args, **kwargs) -> None: pass def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" @wraps(fn) def experiment(self): @rank_zero_only def get_experiment(): return fn(self) return get_experiment() or DummyExperiment() return experiment class CustomWandbLogger(WandbLogger): def __init__(self, *args, **kwargs): """Modified logger that insists on a wandb.init() call and catches wandb's error if thrown.""" super().__init__(*args, **kwargs) @property @rank_zero_experiment def experiment(self): r""" Actual wandb object. 
To use wandb features in your :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. Example:: code-block:: python self.logger.experiment.some_wandb_function() """ if self._experiment is None: if self._offline: os.environ["WANDB_MODE"] = "dryrun" attach_id = getattr(self, "_attach_id", None) if wandb.run is not None: # wandb process already created in this instance rank_zero_warn( "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." ) self._experiment = wandb.run elif attach_id is not None and hasattr(wandb, "_attach"): # attach to wandb process referenced self._experiment = wandb._attach(attach_id) else: # create new wandb process while True: try: self._experiment = wandb.init(**self._wandb_init) break except Exception as e: log.error("wandb Exception:\n", e) t = random.randint(30, 60) log.warning(f"Sleeping for {t} seconds") time.sleep(t) # define default x-axis if getattr(self._experiment, "define_metric", None): self._experiment.define_metric("trainer/global_step") self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) return self._experiment class SequenceLightningModule(pl.LightningModule): def __init__(self, config): # Disable profiling executor. This reduces memory and increases speed. 
try: torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_mode(False) except AttributeError: pass super().__init__() # Passing in config expands it one level: access by self.hparams.train instead of self.hparams.config.train self.save_hyperparameters(config, logger=False) # Dataset arguments self.dataset = SequenceDataset.registry[self.hparams.dataset._name_]( **self.hparams.dataset ) # Check hparams self._check_config() # PL has some bugs, so add hooks and make sure they're only called once self._has_setup = False # To be set in `setup` self.encoder, self.decoder, self.model = None, None, None self.task, self.loss, self.loss_val = None, None, None self.metrics, self.train_torchmetrics, self.val_torchmetrics, self.test_torchmetrics = None, None, None, None self.setup() self._state = None self.val_loader_names, self.test_loader_names = None, None def setup(self, stage=None): if not self.hparams.train.disable_dataset: self.dataset.setup() # We need to set up the model in setup() because for some reason when training with DDP, one GPU uses much more # memory than the others. 
# In order to not overwrite the model multiple times during different stages, we need this hack # TODO PL 1.5 seems to have an option to skip hooks to avoid this # https://github.com/PyTorchLightning/pytorch-lightning/issues/5410#issuecomment-762257024 if self._has_setup: return else: self._has_setup = True # Convenience feature: if model specifies encoder, combine it with main encoder encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list( self.hparams.model.pop("encoder", None) ) decoder_cfg = utils.to_list( self.hparams.model.pop("decoder", None) ) + utils.to_list(self.hparams.decoder) # Instantiate model config_path = self.hparams.model.pop("config_path", None) if config_path is not None: with open(config_path) as f: model_config_from_file = json.load(f) self.hparams.model.update(model_config_from_file) # Check if dropout_layer_norm is compiled try: from flash_attn.ops.layer_norm import dropout_add_layer_norm except ImportError: if self.hparams.model.get("fused_dropout_add_ln", None) is not None: self.hparams.model.update({"fused_dropout_add_ln": False}) # TODO: Hacky way to get complement_map for Caduceus models; need to find a more elegant implementation if "caduceus" in self.hparams.model.get("_name_"): OmegaConf.update( self.hparams.model.config, "complement_map", self.dataset.tokenizer.complement_map, force_add=True ) # Instantiate the config class if using hydra's _target_ paradigm for the config if self.hparams.model.get("config", None) is not None and self.hparams.model.config.get("_target_", None) is not None: model_hparams = OmegaConf.to_container(self.hparams.model, resolve=True) model_hparams["config"] = hydra.utils.instantiate(model_hparams["config"]) self.model = utils.instantiate(registry.model, model_hparams) else: self.model = utils.instantiate(registry.model, self.hparams.model) if (name := self.hparams.train.post_init_hook['_name_']) is not None: kwargs = self.hparams.train.post_init_hook.copy() del kwargs['_name_'] for module in 
self.modules(): if hasattr(module, name): getattr(module, name)(**kwargs) # if self.hparams.train.get("compile_model", False): # self.model = torch.compile(self.model, dynamic=False) # Instantiate the task self.task = utils.instantiate( tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model ) # Create encoders and decoders encoder = encoders.instantiate( encoder_cfg, dataset=self.dataset, model=self.model ) decoder = decoders.instantiate( decoder_cfg, model=self.model, dataset=self.dataset ) # Extract the modules, so they show up in the top level parameter count self.encoder = U.PassthroughSequential(self.task.encoder, encoder) self.decoder = U.PassthroughSequential(decoder, self.task.decoder) self.loss = self.task.loss self.loss_val = self.task.loss if hasattr(self.task, 'loss_val'): self.loss_val = self.task.loss_val self.metrics = self.task.metrics self.train_torchmetrics = self.task.train_torchmetrics self.val_torchmetrics = self.task.val_torchmetrics self.test_torchmetrics = self.task.test_torchmetrics def load_state_dict(self, state_dict, strict=False): if self.hparams.train.pretrained_model_state_hook['_name_'] is not None: model_state_hook = utils.instantiate( registry.model_state_hook, self.hparams.train.pretrained_model_state_hook.copy(), partial=True, ) state_dict = model_state_hook(self.model, state_dict) log.info("Custom load_state_dict function is running.") # strict==True will require all modules to match # strict==False can allow encoder/decoder to be loaded from scratch too return super().load_state_dict(state_dict, strict=strict) def _check_config(self): assert self.hparams.train.state.mode in [None, "none", "null", "reset", "bptt", "tbptt"] assert ( (n := self.hparams.train.state.n_context) is None or isinstance(n, int) and n >= 0 ) assert ( (n := self.hparams.train.state.n_context_eval) is None or isinstance(n, int) and n >= 0 ) def _initialize_state(self): """Called at model setup and start of epoch to completely reset 
state""" self._state = None self._memory_chunks = [] def _reset_state(self, batch, device=None): """Called to construct default_state when necessary, e.g. during BPTT""" device = device or batch[0].device self._state = self.model.default_state(*batch[0].shape[:1], device=device) def _detach_state(self, state): if isinstance(state, torch.Tensor): return state.detach() elif isinstance(state, tuple): return tuple(self._detach_state(s) for s in state) elif isinstance(state, list): return [self._detach_state(s) for s in state] elif isinstance(state, dict): return {k: self._detach_state(v) for k, v in state.items()} elif state is None: return None else: raise NotImplementedError def _process_state(self, batch, batch_idx, training=True): """Handle logic for state context.""" # Number of context steps key = "n_context" if training else "n_context_eval" n_context = self.hparams.train.state.get(key) # Don't need to do anything if 0 context steps. Make sure there is no state if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']: self._initialize_state() return # Reset state if needed if self.hparams.train.state.mode == "reset": if batch_idx % (n_context + 1) == 0: self._reset_state(batch) # Pass through memory chunks elif self.hparams.train.state.mode == "bptt": self._reset_state(batch) with torch.no_grad(): # should be unnecessary because individual modules should handle this for _batch in self._memory_chunks: self.forward(_batch) # Prepare for next step self._memory_chunks.append(batch) self._memory_chunks = self._memory_chunks[-n_context:] elif self.hparams.train.state.mode == 'tbptt': _, _, z = batch reset = z["reset"] if reset: self._reset_state(batch) else: self._state = self._detach_state(self._state) def forward(self, batch): return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state) def step(self, x_t): x_t, *_ = self.encoder(x_t) # Potential edge case for encoders that expect (B, L, H)? 
x_t, state = self.model.step(x_t, state=self._state) self._state = state x_t, *_ = self.decoder.step(x_t, state=state) return x_t def _shared_step(self, batch, batch_idx, prefix="train"): """Shared step logic between training, validation, and test""" self._process_state(batch, batch_idx, training=(prefix == "train")) x, y, w = self.forward(batch) # Loss if prefix == 'train': loss = self.loss(x, y, **w) else: loss = self.loss_val(x, y, **w) # Metrics metrics = self.metrics(x, y, **w) metrics["loss"] = loss metrics = {f"{prefix}/{k}": v for k, v in metrics.items()} # Calculate torchmetrics torchmetrics = getattr(self, f'{prefix}_torchmetrics') torchmetrics(x, y, loss=loss) log_on_step = 'eval' in self.hparams and self.hparams.eval.get('log_on_step', False) and prefix == 'train' self.log_dict( metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True, ) # log the whole dict, otherwise lightning takes the mean to reduce it # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training self.log_dict( torchmetrics, on_step=log_on_step, on_epoch=True, prog_bar=True, add_dataloader_idx=False, sync_dist=True, ) return loss def on_train_epoch_start(self): # Reset training torchmetrics self.task._reset_torchmetrics("train") def training_epoch_end(self, outputs): # Log training torchmetrics super().training_epoch_end(outputs) def on_validation_epoch_start(self): # Reset all validation torchmetrics for name in self.val_loader_names: self.task._reset_torchmetrics(name) def validation_epoch_end(self, outputs): # Log all validation torchmetrics super().validation_epoch_end(outputs) def on_test_epoch_start(self): # Reset all test torchmetrics for name in self.test_loader_names: self.task._reset_torchmetrics(name) def test_epoch_end(self, outputs): # Log all test torchmetrics super().test_epoch_end(outputs) def training_step(self, batch, batch_idx, dataloader_idx=0): loss = 
self._shared_step(batch, batch_idx, prefix="train") # Log the loss explicitly so that it shows up in WandB # Note that this currently runs into a bug in the progress bar with ddp (as of 1.4.6) # https://github.com/PyTorchLightning/pytorch-lightning/pull/9142 # We additionally log the epochs under 'trainer' to get a consistent prefix with 'global_step' loss_epoch = {"trainer/loss": loss, "trainer/epoch": float(self.current_epoch)} self.log_dict( loss_epoch, on_step=True, on_epoch=False, prog_bar=False, add_dataloader_idx=False, sync_dist=True, ) # Log any extra info that the models want to expose (e.g. output norms) metrics = {} for module in list(self.modules())[1:]: if hasattr(module, "metrics"): metrics.update(module.metrics) self.log_dict( metrics, on_step=True, on_epoch=False, prog_bar=False, add_dataloader_idx=False, sync_dist=True, ) return loss def validation_step(self, batch, batch_idx, dataloader_idx=0): # There's a bit of an annoying edge case with the first (0-th) epoch; it has to be excluded due to the initial # sanity check ema = ( self.val_loader_names[dataloader_idx].endswith("/ema") and self.optimizers().optimizer.stepped ) if ema: self.optimizers().swap_ema() loss = self._shared_step( batch, batch_idx, prefix=self.val_loader_names[dataloader_idx] ) if ema: self.optimizers().swap_ema() return loss def test_step(self, batch, batch_idx, dataloader_idx=0): return self._shared_step( batch, batch_idx, prefix=self.test_loader_names[dataloader_idx] ) def configure_optimizers(self): # Set zero weight decay for some params if 'optimizer_param_grouping' in self.hparams.train: add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping) # Normal parameters all_params = list(self.parameters()) params = [p for p in all_params if not hasattr(p, "_optim")] optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params) del self.hparams.optimizer._name_ # Add parameters with special hyperparameters hps = [getattr(p, "_optim") for 
p in all_params if hasattr(p, "_optim")] hps = [ # dict(s) for s in set(frozenset(hp.items()) for hp in hps) dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps))) # dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps) ] # Unique dicts print("Hyperparameter groups:", hps) # TODO: log.info throws error because hps is list of dicts for hp in hps: params = [p for p in all_params if getattr(p, "_optim", None) == hp] optimizer.add_param_group( {"params": params, **self.hparams.optimizer, **hp} ) # Layer Decay if self.hparams.train.layer_decay['_name_'] is not None: get_num_layer = utils.instantiate( registry.layer_decay, self.hparams.train.layer_decay['_name_'], partial=True, ) # Go through all parameters and get num layer layer_wise_groups = {} num_max_layers = 0 for name, p in self.named_parameters(): # Get layer id for each parameter in the model layer_id = get_num_layer(name) # Add to layer wise group if layer_id not in layer_wise_groups: layer_wise_groups[layer_id] = { 'params': [], 'lr': None, 'weight_decay': self.hparams.optimizer.weight_decay } layer_wise_groups[layer_id]['params'].append(p) if layer_id > num_max_layers: num_max_layers = layer_id # Update lr for each layer for layer_id, group in layer_wise_groups.items(): group['lr'] = self.hparams.optimizer.lr * ( self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id)) # Reset the torch optimizers param groups optimizer.param_groups = [] for layer_id, group in layer_wise_groups.items(): optimizer.add_param_group(group) # Print optimizer info for debugging keys = set([k for hp in hps for k in hp.keys()]) # Special hparams utils.train.log_optimizer(log, optimizer, keys) # Configure scheduler if "scheduler" not in self.hparams: return optimizer lr_scheduler = utils.instantiate( registry.scheduler, self.hparams.scheduler, optimizer ) scheduler = { "scheduler": lr_scheduler, "interval": self.hparams.train.interval, # 'epoch' or 'step' "monitor": 
self.hparams.train.monitor, "name": "trainer/lr", # default is e.g. 'lr-AdamW' } # See documentation for how to configure the return # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers return [optimizer], [scheduler] def train_dataloader(self): return self.dataset.train_dataloader(**self.hparams.loader) def _eval_dataloaders_names(self, loaders, prefix): """Process loaders into a list of names and loaders""" if utils.is_dict(loaders): return [ f"{prefix}/{k}" if k is not None else prefix for k in loaders.keys() ], list(loaders.values()) elif utils.is_list(loaders): return [f"{prefix}/{i}" for i in range(len(loaders))], loaders else: return [prefix], [loaders] def _eval_dataloaders(self): # Return all val + test loaders val_loaders = self.dataset.val_dataloader(**self.hparams.loader) test_loaders = self.dataset.test_dataloader(**self.hparams.loader) val_loader_names, val_loaders = self._eval_dataloaders_names(val_loaders, "val") test_loader_names, test_loaders = self._eval_dataloaders_names( test_loaders, "test" ) # Duplicate datasets for ema if self.hparams.train.ema > 0.0: val_loader_names += [name + "/ema" for name in val_loader_names] val_loaders = val_loaders + val_loaders test_loader_names += [name + "/ema" for name in test_loader_names] test_loaders = test_loaders + test_loaders # adding option to only have val loader at eval (e.g., if test is duplicate) eval_loader_names = [] eval_loaders = [] if not self.hparams.train.get("remove_val_loader_in_eval", False): eval_loader_names += val_loader_names eval_loaders += val_loaders if not self.hparams.train.get("remove_test_loader_in_eval", False): eval_loader_names += test_loader_names eval_loaders += test_loaders return eval_loader_names, eval_loaders def val_dataloader(self): val_loader_names, val_loaders = self._eval_dataloaders() self.val_loader_names = val_loader_names return val_loaders def 
test_dataloader(self): test_loader_names, test_loaders = self._eval_dataloaders() self.test_loader_names = ["final/" + name for name in test_loader_names] return test_loaders # pytorch-lightning utils and entrypoint def create_trainer(config, **kwargs): callbacks: List[pl.Callback] = [] logger = None # WandB Logging if config.get("wandb") is not None: # Pass in wandb.init(config=) argument to get the nice 'x.y.0.z' hparams logged # Can pass in config_exclude_keys='wandb' to remove certain groups import wandb logger = CustomWandbLogger( config=utils.to_dict(config, recursive=True), settings=wandb.Settings(start_method="fork"), **config.wandb, ) # Lightning callbacks if "callbacks" in config: for _name_, callback in config.callbacks.items(): if config.get("wandb") is None and _name_ in ["learning_rate_monitor"]: continue log.info(f"Instantiating callback <{registry.callbacks[_name_]}>") callback._name_ = _name_ callbacks.append(utils.instantiate(registry.callbacks, callback)) # Add ProgressiveResizing callback if config.callbacks.get("progressive_resizing", None) is not None: num_stages = len(config.callbacks.progressive_resizing.stage_params) log.info(f"Progressive Resizing: {num_stages} stages") for i, e in enumerate(config.callbacks.progressive_resizing.stage_params): # Stage params are resolution and epochs, pretty print log.info(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs") # Configure ddp automatically n_devices = config.trainer.get('devices', 1) if isinstance(n_devices, Sequence): # trainer.devices could be [1, 3] for example n_devices = len(n_devices) if n_devices > 1 and config.trainer.get('strategy', None) is None: config.trainer.strategy = dict( _target_='pytorch_lightning.strategies.DDPStrategy', find_unused_parameters=False, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations gradient_as_bucket_view=True, ) # Init lightning trainer log.info(f"Instantiating trainer <{config.trainer._target_}>") # 
special processing for seqlen warmup scheduler (reload) trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=logger) return trainer def fsspec_exists(filename): fs, _ = fsspec.core.url_to_fs(filename) return fs.exists(filename) def train(config): if config.train.seed is not None: pl.seed_everything(config.train.seed, workers=True) trainer = create_trainer(config) model = SequenceLightningModule(config) # Load pretrained_model if specified if config.train.get("pretrained_model_path", None) is not None: # PTL style. Note, method returns a new model object, and need to pass config. model = SequenceLightningModule.load_from_checkpoint( config.train.pretrained_model_path, config=config, strict=config.train.pretrained_model_strict_load, ) # Run initial validation epoch (useful for debugging, fine-tuning) if config.train.validate_at_start: log.info("Running validation before training") trainer.validate(model) log.info(f'{config.train.ckpt=} {fsspec_exists(config.train.ckpt)=}') # if config.train.get("compile_model", False): # model = torch.compile(model, mode="reduce-overhead") if config.train.ckpt is not None and fsspec_exists(config.train.ckpt): trainer.fit(model, ckpt_path=config.train.ckpt) else: trainer.fit(model) if config.train.test: if config.train.get("cross_validation", False): # First, load the best validation model best_val_ckpt = os.path.join( model.hparams.callbacks.model_checkpoint.dirpath, f"{model.hparams.callbacks.model_checkpoint.filename}.ckpt", ) # Update config so we do not load just the backbone config.train.pretrained_model_state_hook.update({"_name_": None}) # Remove validation loader config.train.update({"remove_val_loader_in_eval": True}) config.train.update({"remove_test_loader_in_eval": False}) ckpt = torch.load(best_val_ckpt) log.info(f"Loaded best validation checkpoint from epoch {ckpt['epoch']}") trainer.validate(model, ckpt_path=best_val_ckpt) else: trainer.validate(model) @hydra.main(config_path="configs", 
config_name="config.yaml") def main(config: OmegaConf): # Process config: # - register evaluation resolver # - filter out keys used only for interpolation # - optional hooks, including disabling python warnings or debug friendly configuration config = utils.train.process_config(config) # if config.train.get("compile_model", False): # # See: https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops # from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 # allow_ops_in_compiled_graph() # Pretty print config using Rich library utils.train.print_config(config, resolve=True) train(config) if __name__ == "__main__": main() ================================================ FILE: vep_embeddings.py ================================================ """Dump model embeddings for VEP classification task. """ import argparse import os from functools import partial from os import path as osp from typing import Dict, Iterable, Optional import enformer_pytorch import fsspec import torch import torch.distributed as dist import torch.nn as nn from datasets import load_dataset, load_from_disk from sklearn import preprocessing from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler from tqdm.auto import tqdm from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, DefaultDataCollator from src.dataloaders.utils.rc import string_reverse_complement from src.utils.train import get_logger WINDOW_SIZE_BP = 1536 log = get_logger(__name__) class DNAEmbeddingModel(nn.Module): """Wrapper around HF model. Args: model_name_or_path: str, path to HF model. 
""" def __init__( self, model_name_or_path: str, ): super().__init__() self.model_name_or_path = model_name_or_path # Enformer uses different library for loading if "enformer" in model_name_or_path.lower(): self.backbone = enformer_pytorch.from_pretrained( model_name_or_path, use_tf_gamma=False, use_checkpointing=True ) # NT model is not compatible with AutoModel class elif "nucleotide-transformer" in model_name_or_path.lower(): # NT LM `backbone` is under the `.esm` attribute self.backbone = AutoModelForMaskedLM.from_pretrained(model_name_or_path, trust_remote_code=True).esm else: self.backbone = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) def forward(self, input_ids): """Backbone forward pass to retrieve last_hidden_state.""" if "enformer" in self.model_name_or_path.lower(): # Enformer forward pass has different signature return self.backbone(input_ids, return_embeddings=True)[1] return self.backbone(input_ids).last_hidden_state class EnformerTokenizer: """Enformer tokenizer.""" # Order is important here! (See: https://github.com/lucidrains/enformer-pytorch?tab=readme-ov-file#usage) pad_token = "P" # Padding token should be a character to avoid issues with tokenization encode_map = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, pad_token: -1} @classmethod def encode( cls, seq: str, max_length: Optional[int] = None, truncation: Optional[bool] = False ) -> Iterable[int]: """Convert bp to token ids.""" if max_length is not None: assert max_length >= 0, "max_length should be a positive integer." 
if len(seq) < max_length: seq = seq + cls.pad_token * (max_length - len(seq)) elif truncation: seq = seq[:max_length] return [cls.encode_map[bp] for bp in seq.upper()] @classmethod def batch_encode_plus( cls, seqs: Iterable[str], max_length: Optional[int] = None, truncation: Optional[bool] = False, **kwargs, # ensures compatibility with HF tokenizer-like API ) -> Dict[str, Iterable[Iterable[int]]]: """Batch encode sequences using HF tokenizer-like API.""" input_ids = [cls.encode(seq, max_length=max_length, truncation=truncation) for seq in seqs] return {"input_ids": input_ids} def setup_distributed(): """Set environment variables for distributed runs.""" dist.init_process_group("nccl") def cleanup_distributed(): """Clean up processes from distributed runs.""" dist.destroy_process_group() def fsspec_exists(filename): """Check if file exists in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(filename) return fs.exists(filename) def fsspec_listdir(dirname): """Listdir in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) return fs.ls(dirname) # Processing functions def recast_chromosome_tissue_dist2TSS(examples): """Recast chromosome to int.""" return { "chromosome": -1 if examples["chromosome"] == "X" else int(examples["chromosome"]), "tissue": examples["tissue"], "distance_to_nearest_tss": examples["distance_to_nearest_tss"] } def tokenize_variants(examples, tokenizer, max_length: int): """Tokenize sequence. Args: examples: (batch of) items from the dataset. tokenizer: AutoTokenizer. max_length: int. Returns: dict with values as list of token ids. 
""" ref_tokenized = tokenizer.batch_encode_plus( examples["ref_forward_sequence"], add_special_tokens=False, return_attention_mask=False, max_length=max_length, truncation=True, ) alt_tokenized = tokenizer.batch_encode_plus( examples["alt_forward_sequence"], add_special_tokens=False, return_attention_mask=False, max_length=max_length, truncation=True, ) ref_rc_tokenized = tokenizer.batch_encode_plus( [string_reverse_complement(seq) for seq in examples["ref_forward_sequence"]], add_special_tokens=False, return_attention_mask=False, max_length=max_length, truncation=True, ) alt_rc_tokenized = tokenizer.batch_encode_plus( [string_reverse_complement(seq) for seq in examples["alt_forward_sequence"]], add_special_tokens=False, return_attention_mask=False, max_length=max_length, truncation=True, ) return { "ref_input_ids": ref_tokenized["input_ids"], "alt_input_ids": alt_tokenized["input_ids"], "ref_rc_input_ids": ref_rc_tokenized["input_ids"], "alt_rc_input_ids": alt_rc_tokenized["input_ids"], } def find_variant_idx(examples): """Find token location that differs between reference and variant sequence. Args: examples: items from the dataset (not batched). Returns: dict with values index of difference. 
""" # Guess that variant is at halfway point idx = len(examples["ref_input_ids"]) // 2 if examples["ref_input_ids"][idx] == examples["alt_input_ids"][idx]: # If no, loop through sequence and find variant location idx = -1 for i, (ref, alt) in enumerate(zip(examples["ref_input_ids"], examples["alt_input_ids"])): if ref != alt: idx = i # Same as above, but for reverse complement rc_idx = len(examples["ref_rc_input_ids"]) // 2 - 1 if examples["ref_rc_input_ids"][rc_idx] == examples["alt_rc_input_ids"][rc_idx]: rc_idx = -1 for i, (ref, alt) in enumerate(zip(examples["ref_rc_input_ids"], examples["alt_rc_input_ids"])): if ref != alt: rc_idx = i return {"variant_idx": idx, "rc_variant_idx": rc_idx} def prepare_dataset(args, tokenizer): """Prepare or load the tokenized dataset.""" # Data Preprocessing num_tokens = args.seq_len // args.bp_per_token # Load data cache_dir = osp.join( os.getenv("HF_HOME"), "datasets", "InstaDeepAI___genomics-long-range-benchmark", "variant_effect_gene_expression", f"seqlen={args.seq_len}" ) if "nucleotide-transformer" in args.model_name_or_path.lower(): # NT uses 6-mers, so tokenization is different preprocessed_cache_file = osp.join(cache_dir, "6mer_token_preprocessed") elif "enformer" in args.model_name_or_path.lower(): # Enformer tokenization requires having vocab of just `A,C,G,T,N` (in that order) preprocessed_cache_file = osp.join(cache_dir, "enformer_char_token_preprocessed") else: preprocessed_cache_file = osp.join(cache_dir, "char_token_preprocessed") log.warning(f"Cache dir: {cache_dir}") log.warning(f"Cache dir preprocessed: {preprocessed_cache_file}") if not fsspec_exists(preprocessed_cache_file): if dist.get_rank() == 0: dataset = load_dataset( "InstaDeepAI/genomics-long-range-benchmark", task_name="variant_effect_gene_expression", sequence_length=args.seq_len, load_from_cache=False, ) log.warning("Dataset loaded. 
Cached to disk:") log.warning(osp.dirname(list(dataset.cache_files.values())[0][0]["filename"])) try: del dataset["validation"] # `validation` split is empty except KeyError: pass # Process data dataset = dataset.filter( lambda example: example["ref_forward_sequence"].count('N') < 0.005 * args.seq_len, desc="Filter N's" ) dataset = dataset.map( recast_chromosome_tissue_dist2TSS, remove_columns=["chromosome", "tissue", "distance_to_nearest_tss"], desc="Recast chromosome" ) dataset = dataset.map( partial(tokenize_variants, tokenizer=tokenizer, max_length=num_tokens), batch_size=1000, batched=True, remove_columns=["ref_forward_sequence", "alt_forward_sequence"], desc="Tokenize" ) dataset = dataset.map(find_variant_idx, desc="Find variant idx") dataset.save_to_disk(preprocessed_cache_file) dist.barrier() # Processes need to wait for dataset to be saved to disk (if not already done) dataset = load_from_disk(preprocessed_cache_file) log.warning(f"Loaded preprocessed dataset from {preprocessed_cache_file}") log.warning(dataset) return dataset def get_backbone_model(args, device): """Get the backbone model.""" model = DNAEmbeddingModel( model_name_or_path=args.model_name_or_path, ) model.eval() return DDP(model.to(device)) def concat_storage_dict_values(storage_dict): """Helper method that combines lists of tensors in storage_dict into a single torch.Tensor.""" return {key: torch.cat(storage_dict[key], dim=0) for key in storage_dict.keys()} def dump_embeddings(args, dataset, model, device): """Dump embeddings to disk.""" def extract_embeddings(item_ref, item_alt, variant_idx): """Extract embedding representation from last layer outputs Args: item_ref: torch.Tensor, shape (batch_size, seq_len, hidden_size) Ref embedding item_alt: torch.Tensor, shape (batch_size, seq_len, hidden_size) Alt embedding variant_idx: torch.Tensor, shape (batch_size,) Index of variant Returns: layer_metrics: dict, with values to save to disk """ layer_metrics = {} # Compute windowed statistics if 
"enformer" in args.model_name_or_path.lower(): window_size = WINDOW_SIZE_BP // 128 # Enformer's receptive field is 128 # We also need to override variant_idx since Enformer model reduces to target_length of 896 variant_idx = torch.ones_like(variant_idx) * item_ref.size(1) // 2 else: window_size = WINDOW_SIZE_BP // args.bp_per_token # Add 1 so that window is: [window // 2 - SNP - window // 2] start, end = -window_size // 2, window_size // 2 + 1 expanded_indices = torch.arange(start, end, device=item_ref.device).unsqueeze(0) + \ variant_idx.unsqueeze(1).to(item_ref.device) expanded_indices = torch.clamp(expanded_indices, 0, item_ref.size(1) - 1) # Handle boundary conditions tokens_window_ref = torch.gather( item_ref, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2)) ).mean(dim=1) tokens_window_alt = torch.gather( item_alt, 1, expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2)) ).mean(dim=1) layer_metrics["concat_avg_ws"] = torch.cat([tokens_window_ref, tokens_window_alt], dim=-1) return layer_metrics embeds_path = osp.join(args.downstream_save_dir, args.name) os.makedirs(embeds_path, exist_ok=True) dataloader_params = { "batch_size": args.embed_dump_batch_size, "collate_fn": DefaultDataCollator(return_tensors="pt"), "num_workers": args.num_workers, "pin_memory": False, "shuffle": False, "drop_last": True } # Process label_encoder = preprocessing.LabelEncoder() label_encoder = preprocessing.LabelEncoder() label_encoder.fit(dataset["test"]["tissue"]) train_tissue_embed = label_encoder.transform(dataset["train"]["tissue"]) dataset["train"] = dataset["train"].add_column("tissue_embed", train_tissue_embed) test_tissue_embed = label_encoder.transform(dataset["test"]["tissue"]) dataset["test"] = dataset["test"].add_column("tissue_embed", test_tissue_embed) if not all([ fsspec_exists(osp.join(embeds_path, f"{split_name}_embeds_combined.pt")) for split_name in dataset.keys() ]): for split_name, split in dataset.items(): sampler = 
DistributedSampler( split, shuffle=dataloader_params.get("shuffle", False), drop_last=dataloader_params.get("drop_last", True), ) dl = DataLoader(split, **dataloader_params, sampler=sampler) storage_dict = { "concat_avg_ws": [], "rc_concat_avg_ws": [], "chromosome": [], "labels": [], "distance_to_nearest_tss": [], "tissue_embed": [], } with torch.no_grad(): for batch_idx, batch in tqdm( enumerate(dl), total=len(dl), desc=f"[RANK {dist.get_rank()}] Embedding {split_name}", disable=dist.get_rank() != 0 # Only rank 0 updates pbar ): for key in ["chromosome", "labels", "distance_to_nearest_tss", "tissue_embed"]: storage_dict[key].append(batch[key].to("cpu", non_blocking=True)) with torch.autocast(device_type="cuda", dtype=torch.float16): output_alt = model(batch["alt_input_ids"].to(device)) output_ref = model(batch["ref_input_ids"].to(device)) if args.rcps: num_channels = output_alt.size(-1) # Flip along length and channel dims to preserve RC equivariance # i.e. output_rc(RC(inputs)) = outputs(inputs) output_alt_rc = output_alt[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) output_ref_rc = output_ref[..., num_channels // 2:].contiguous().flip(dims=[1, 2]) output_alt = output_alt[..., :num_channels // 2] output_ref = output_ref[..., :num_channels // 2] else: # Flip along length dim so variant_idx aligns output_alt_rc = model(batch["alt_rc_input_ids"].to(device)).contiguous().flip(dims=[1]) output_ref_rc = model(batch["ref_rc_input_ids"].to(device)).contiguous().flip(dims=[1]) metrics = extract_embeddings( item_ref=output_ref, item_alt=output_alt, variant_idx=batch["variant_idx"], ) for key, value in metrics.items(): storage_dict[key].append(metrics[key].to("cpu", non_blocking=True)) metrics_rc = extract_embeddings( item_ref=output_ref_rc, item_alt=output_alt_rc, variant_idx=batch["variant_idx"], ) for key, value in metrics_rc.items(): storage_dict[f"rc_{key}"].append(metrics_rc[key].to("cpu", non_blocking=True)) if batch_idx % 100 == 0: # Every machine should 
print progress updates print(f"[RANK {dist.get_rank()}] Completed index: {batch_idx}/{len(dl)}") storage_dict_temp = concat_storage_dict_values(storage_dict) with fsspec.open(osp.join(embeds_path, f"{split_name}_embeds_{dist.get_rank()}.pt"), "wb") as f: torch.save(storage_dict_temp, f) print(f"[RANK {dist.get_rank()}] Saved {split_name} to {osp.join(embeds_path, f'{split_name}_embeds_{dist.get_rank()}.pt')}") else: log.warning("Embeddings already exist, skipping!") def combine_embeddings(embeds_path): """Combine embeddings from different files.""" # Check if combined embeddings exist, and if not, aggregate them for split in ["train", "test"]: if not fsspec_exists(osp.join(embeds_path, f"{split}_embeds_combined.pt")): storage_dict = { "concat_avg_ws": [], "rc_concat_avg_ws": [], "chromosome": [], "labels": [], "distance_to_nearest_tss": [], "tissue_embed": [], } for filename in fsspec_listdir(embeds_path): if f"{split}_embeds_" in filename: log.warning(f"Loading data from: {filename}") with fsspec.open(filename, "rb") as f: tmp_data = torch.load(f) for key in storage_dict.keys(): storage_dict[key].append(tmp_data[key]) storage_dict = concat_storage_dict_values(storage_dict) log.warning(f"Saving combined data to: {embeds_path}/{split}_embeds_combined.pt") with fsspec.open(osp.join(embeds_path, f"{split}_embeds_combined.pt"), "wb") as f: torch.save(storage_dict, f) def main(args): """Main entry point.""" # Reproducibility torch.use_deterministic_algorithms(True) torch.backends.cudnn.benchmark = False # Init distributed log.warning("Initializing distributed...") dist.init_process_group("nccl") print(f"[RANK {dist.get_rank()}] Distributed initialized: rank {dist.get_rank()}") # All processes print this # Setup device device = torch.device(f"cuda:{dist.get_rank()}") print(f"[RANK {dist.get_rank()}] Using device: {device}.") # All processes print this # Init tokenizer if "enformer" in args.model_name_or_path.lower(): # Enformer tokenization requires having vocab of just 
`A,C,G,T,N` (in that order) tokenizer = EnformerTokenizer() else: tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) # Get dataset dist.barrier() dataset = prepare_dataset(args, tokenizer) # Get model dist.barrier() model = get_backbone_model(args, device) log.warning("Model loaded.") # Dump embeddings dist.barrier() dump_embeddings(args, dataset, model, device) # Combine embeddings into single file dist.barrier() cleanup_distributed() combine_embeddings(osp.join(args.downstream_save_dir, args.name)) if __name__ == "__main__": torch.multiprocessing.set_sharing_strategy('file_system') parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--seq_len", type=int, default=131072, help="Sequence length (in bp)..") parser.add_argument("--bp_per_token", type=int, default=1, help="Number of base pairs per token.") parser.add_argument("--model_name_or_path", type=str, default=None) parser.add_argument("--downstream_save_dir", type=str, default="./outputs/downstream/vep_embeddings", help="Directory to save downstream task.") parser.add_argument("--name", type=str, default=None, help="Embeddings model name.") parser.add_argument("--rcps", default=False, action="store_true", help="Use RCPS.") parser.add_argument("--no-rcps", dest="rcps", action="store_false", help="Do not use RCPS.") parser.add_argument("--embed_dump_batch_size", type=int, default=1, help="Batch size for embedding dump.") parser.add_argument("--num_workers", type=int, default=0, help="Number of workers.") opts, _ = parser.parse_known_args() log.warning("*** Args ************************") for k, v in vars(opts).items(): log.warning(f" - {k}: {v}") log.warning("******************************\n") main(opts) ================================================ FILE: vep_svm.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "db878bc1", "metadata": {}, "source": [ "## Imports 
and Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "c79a903b", "metadata": { "tags": [] }, "outputs": [], "source": [ "import random\n", "import time\n", "from os import path as osp\n", "\n", "import fsspec\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import torch\n", "from sklearn.metrics import roc_auc_score\n", "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.svm import SVC\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "4034f167", "metadata": {}, "outputs": [], "source": [ "DIST_TO_TSS = [[0, 30_000], [30_000, 100_000], [100_000, np.infty]]\n", "USE_TISSUE = [True] # used as another for loop for fitting SVM, whether to use tissue embed or not\n", "Cs = [1, 5, 10] # for loop in fitting SVM, inverse of L2 penalty (sklearn hyperparam)\n", "PATH_TO_OUTPUTS = \"./outputs/downstream/vep_embeddings\"" ] }, { "cell_type": "code", "execution_count": null, "id": "55c58437", "metadata": { "tags": [] }, "outputs": [], "source": [ "def fsspec_exists(filename: str) -> bool:\n", " \"\"\"Check if file exists in manner compatible with fsspec.\"\"\"\n", " fs, _ = fsspec.core.url_to_fs(filename)\n", " return fs.exists(filename)" ] }, { "cell_type": "code", "execution_count": null, "id": "18522e17", "metadata": {}, "outputs": [], "source": [ "def dataset_nan_filter(data: dict, data_key: str) -> dict:\n", " \"\"\"Filter any items that have NaN in embedding within TSS bucket\"\"\"\n", " mask_out = torch.logical_or(\n", " torch.any(data[data_key].isnan(), dim=1),\n", " torch.any(data[f\"rc_{data_key}\"].isnan(), dim=1)\n", " )\n", " \n", " new_data = dict()\n", " for data_key in data.keys():\n", " new_data[data_key] = data[data_key][~mask_out]\n", "\n", " return new_data\n", "\n", "def dataset_tss_filter(data: dict, min_distance: int, max_distance: int) -> dict:\n", " \"\"\"Filter 
the data to items that fall within TSS bucket\"\"\"\n", " distance_mask = ((data[\"distance_to_nearest_tss\"] >= min_distance) \n", " & (data[\"distance_to_nearest_tss\"] <= max_distance))\n", " new_data = dict()\n", " for data_key in data.keys():\n", " new_data[data_key] = data[data_key][distance_mask]\n", "\n", " return new_data" ] }, { "cell_type": "markdown", "id": "ef3d1006", "metadata": {}, "source": [ "## Specify which models to test" ] }, { "cell_type": "code", "execution_count": null, "id": "4629cb30", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Embeddings to test\n", "model_dict = {\n", " \"HyenaDNA\": dict(\n", " embed_path=\"hyena_downstream-seqlen=131k\",\n", " rc_aug=False,\n", " conjoin_train=False,\n", " conjoin_test=False,\n", " key=\"concat_avg_ws\",\n", " ),\n", " \"Caduceus-Ph\": dict(\n", " embed_path=\"caduceus-ph_downstream-seqlen=131k\",\n", " rc_aug=False,\n", " conjoin_train=False,\n", " conjoin_test=True,\n", " key=\"concat_avg_ws\",\n", " ),\n", " \"Caduceus w/o Equiv.\": dict(\n", " embed_path=\"caduceus-ph_downstream-seqlen=131k\",\n", " rc_aug=False,\n", " conjoin_train=False,\n", " conjoin_test=False,\n", " key=\"concat_avg_ws\",\n", " ),\n", " \"Caduceus-PS\": dict(\n", " embed_path=\"caduceus-ps_downstream-seqlen=131k\",\n", " rc_aug=False,\n", " conjoin_train=True,\n", " conjoin_test=False,\n", " key=\"concat_avg_ws\",\n", " ),\n", " \"Enformer\": dict(\n", " embed_path=\"enformer-seqlen=196k\",\n", " rc_aug=False,\n", " conjoin_train=False,\n", " conjoin_test=False,\n", " key=\"concat_avg_ws\",\n", " ),\n", " \"NTv2\": dict(\n", " embed_path=\"NTv2_downstream-seqlen=12k\",\n", " rc_aug=False,\n", " conjoin_train=False,\n", " conjoin_test=False,\n", " key=\"concat_avg_ws\",\n", " ),\n", "}" ] }, { "cell_type": "markdown", "id": "12e64367", "metadata": {}, "source": [ "## Fit and test SVM" ] }, { "cell_type": "code", "execution_count": null, "id": "6eaeb519-5c35-4fba-a09b-2d47c122320d", "metadata": { "scrolled": 
false, "tags": [] }, "outputs": [], "source": [ "metrics = {\n", " \"model_name\": [],\n", " \"bucket_id\": [],\n", " \"use_tissue\": [],\n", " \"C\": [],\n", " \"seed\": [],\n", " \"AUROC\": [],\n", "}\n", "\n", "for model_name, downstream_kwargs in model_dict.items():\n", " print(f\"********** Gathering results for: {model_name} **********\")\n", " embed_path = downstream_kwargs[\"embed_path\"]\n", " rc_aug = downstream_kwargs[\"rc_aug\"]\n", " conjoin_train = downstream_kwargs[\"conjoin_train\"]\n", " conjoin_test = downstream_kwargs[\"conjoin_test\"]\n", " key = downstream_kwargs[\"key\"]\n", " \n", " if \"NT\" in model_name: assert (rc_aug == False) and (conjoin_train == False) and (conjoin_test == False)\n", " \n", " base_embeds_path = PATH_TO_OUTPUTS\n", " embeds_path = osp.join(base_embeds_path, embed_path)\n", " \n", " print(f\"Embed Path: {embeds_path}\")\n", " with fsspec.open(osp.join(embeds_path, \"train_embeds_combined.pt\"), \"rb\") as f:\n", " train_val_ds_raw = torch.load(f, map_location=\"cpu\")\n", " train_val_ds_raw = dataset_nan_filter(train_val_ds_raw, data_key=key)\n", " with fsspec.open(osp.join(embeds_path, \"test_embeds_combined.pt\"), \"rb\") as f:\n", " test_ds_raw = torch.load(f, map_location=\"cpu\")\n", " test_ds_raw = dataset_nan_filter(test_ds_raw, data_key=key)\n", " print(f\"Total Train size: {len(train_val_ds_raw[key])},\", end=\" \")\n", " print(f\"Total Test size: {len(test_ds_raw[key])},\", end=\" \")\n", " print(f\"Shape: {test_ds_raw[key].shape[1:]}\")\n", "\n", "\n", " for bucket_id, (min_dist, max_dist) in enumerate(DIST_TO_TSS):\n", " # Filter data to desired TSS bucket\n", " train_val_ds_filter = dataset_tss_filter(train_val_ds_raw, min_dist, max_dist)\n", " test_ds_filter = dataset_tss_filter(test_ds_raw, min_dist, max_dist)\n", " print(f\"- TSS bucket: [{min_dist}, {max_dist}],\", end=\" \")\n", " print(f\"Train size: {len(train_val_ds_filter[key])},\", end=\" \")\n", " print(f\"Test size: 
{len(test_ds_filter[key])}\")\n", " \n", " for use_tissue in USE_TISSUE:\n", " for C in Cs:\n", " for seed in range(1, 6): \n", " # Re-seed for SVM fitting\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", "\n", " svm_clf = make_pipeline(\n", " StandardScaler(),\n", " SVC(C=C, random_state=seed),\n", " )\n", "\n", " # Setup Train/Test dataset\n", " if conjoin_train:\n", " X = np.array(train_val_ds_filter[key])\n", " X += np.array(train_val_ds_filter[f\"rc_{key}\"])\n", " X /= 2\n", " else:\n", " X = np.array(train_val_ds_filter[key])\n", " X_with_tissue = np.concatenate(\n", " [X, np.array(train_val_ds_filter[\"tissue_embed\"])[..., None]],\n", " axis=-1\n", " )\n", " y = train_val_ds_filter[\"labels\"]\n", " if conjoin_train or conjoin_test:\n", " X_test = np.array(test_ds_filter[key])\n", " X_test += np.array(test_ds_filter[f\"rc_{key}\"])\n", " X_test /= 2\n", " else:\n", " X_test = np.array(test_ds_filter[key])\n", " X_test_with_tissue = np.concatenate(\n", " [X_test, np.array(test_ds_filter[\"tissue_embed\"])[..., None]],\n", " axis=-1\n", " )\n", " y_test = test_ds_filter[\"labels\"]\n", "\n", " print(f\"\\tFitting SVM ({use_tissue=}, {C=}, {seed=})...\", end=\" \")\n", " \n", " mask = np.random.choice(len(X), size=5000, replace= 5000 > len(X) )\n", " if use_tissue: \n", " X_train = X_with_tissue[mask]\n", " X_test = X_test_with_tissue\n", " else: \n", " X_train = X[mask]\n", " y_train = y[mask]\n", "\n", " start = time.time()\n", " svm_clf.fit(X_train, y_train)\n", " svm_y_pred = svm_clf.predict(X_test)\n", " svm_aucroc = roc_auc_score(y_test, svm_y_pred)\n", " end = time.time()\n", " print(f\"Completed! 
({end - start:0.3f} s) -\", end=\" \")\n", " print(f\"AUROC: {svm_aucroc}\")\n", " \n", " metrics[\"model_name\"] += [model_name]\n", " metrics[\"bucket_id\"] += [bucket_id]\n", " metrics[\"use_tissue\"] += [use_tissue]\n", " metrics[\"C\"] += [C]\n", " metrics[\"seed\"] += [seed]\n", " metrics[\"AUROC\"] += [svm_aucroc]" ] }, { "cell_type": "code", "execution_count": null, "id": "597b0fe9", "metadata": {}, "outputs": [], "source": [ "df_metrics = pd.DataFrame.from_dict(metrics)\n", "df_metrics.to_csv(osp.join(PATH_TO_OUTPUTS, \"SVM_results.csv\"))" ] }, { "cell_type": "markdown", "id": "03e06a25", "metadata": {}, "source": [ "## Plot results" ] }, { "cell_type": "code", "execution_count": null, "id": "a362d3fa", "metadata": {}, "outputs": [], "source": [ "model_name_replacement = {\n", " \"Caduceus w/o Equiv.\": \"Caduceus w/o\\nEquiv. (7.7M)\",\n", " \"Caduceus-Ph\": \"Caduceus-Ph\\n(7.7M)\",\n", " \"Caduceus-PS\": \"Caduceus-PS\\n(7.7M)\",\n", " \"HyenaDNA\": \"HyenaDNA\\n(6.6M)\",\n", " \"NTv2\": \"NTv2\\n(500M)\",\n", " \"Enformer\": \"Enformer\\n(252M)\",\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "85b1c4fb", "metadata": {}, "outputs": [], "source": [ "# Formatting changes to df\n", "df = pd.read_csv(osp.join(PATH_TO_OUTPUTS, \"SVM_results.csv\"), index_col=0)\n", "df_display = df.rename(columns={\"bucket_id\": \"Distance to TSS\"})\n", "df_display = df_display.replace({\"Distance to TSS\": {0: \"0 - 30k\", 1: \"30 - 100k\", 2: \"100k+\"}})\n", "df_display = df_display.replace({\"model_name\": model_name_replacement})\n", "\n", "# Take average over seeds\n", "df_display_selected = df_display.groupby(\n", " [\"model_name\", \"Distance to TSS\", \"use_tissue\", \"C\"]\n", ").agg(AUROC=(\"AUROC\", np.mean)).reset_index()\n", "\n", "# Select best hyperparam by model/bucket\n", "best_ids = df_display_selected.groupby([\"model_name\", \"Distance to TSS\"])[\"AUROC\"].idxmax()\n", "df_display_selected = 
df_display_selected.loc[best_ids.reset_index()[\"AUROC\"].values]\n", "display(\n", " df_display_selected.pivot(\n", " index=\"model_name\", columns=\"Distance to TSS\", values=\"AUROC\"\n", " )[[\"0 - 30k\", \"30 - 100k\", \"100k+\"]]\n", ")\n", "display(df_display_selected[[\"model_name\", \"Distance to TSS\", \"C\", \"use_tissue\"]])" ] }, { "cell_type": "code", "execution_count": null, "id": "09a7f4a9", "metadata": {}, "outputs": [], "source": [ "# Filter results to selected hyperparams\n", "df_plot = pd.merge(\n", " df_display, df_display_selected,\n", " on=[\"model_name\", \"Distance to TSS\", \"use_tissue\", \"C\"]\n", ").drop(columns=[\"AUROC_y\"]).rename(columns={\"AUROC_x\": \"AUROC\"})\n", "\n", "# Plot results by distance to TSS\n", "sns.set_style(\"whitegrid\")\n", "g = sns.catplot(\n", " data=df_plot,\n", " x=\"model_name\",\n", " y=\"AUROC\",\n", " col=\"Distance to TSS\",\n", " hue=\"Distance to TSS\",\n", " kind=\"bar\",\n", " errorbar=\"sd\",\n", " height=12,\n", " aspect=1,\n", " dodge=False,\n", " order=list(model_name_replacement.values()),\n", ")\n", "g.set_xticklabels(rotation=60, fontsize=30)\n", "g.set(xlabel=\"\")\n", "g.set(ylim=(0.4, 0.7))\n", "g.set_titles(template=\"Dist. 
to TSS: {col_name}\", fontsize=40)\n", "g.fig.suptitle(\"Predicting Effects of Variants on Gene Expression\", y=1.1, fontsize=40)\n", "g._legend.remove()\n", "# Display bar values\n", "# (See: https://stackoverflow.com/questions/55586912/seaborn-catplot-set-values-over-the-bars)\n", "for ax in tqdm(g.axes.ravel(), leave=False):\n", " title = ax.title.get_text()\n", " ax.set_title(title, fontsize=35)\n", " for c in tqdm(ax.containers, leave=False):\n", " labels = [f\"{v.get_height():0.3f}\" for v in c]\n", " ax.bar_label(c, labels=labels, label_type=\"center\", color=\"white\", weight=\"bold\", fontsize=24)\n", "plt.show()\n", "g.savefig(osp.join(PATH_TO_OUTPUTS, \"SVM_results.png\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "a7858241", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }