Full Code of kuleshov-group/caduceus for AI

Repository: kuleshov-group/caduceus
Branch: main
Commit: 0060a6d8079b
Files: 104
Total size: 394.9 KB

Directory structure:
gitextract_kw76odef/

├── .gitignore
├── LICENSE
├── README.md
├── caduceus/
│   ├── __init__.py
│   ├── configuration_caduceus.py
│   ├── modeling_caduceus.py
│   ├── modeling_rcps.py
│   ├── tests/
│   │   └── test_rcps.py
│   └── tokenization_caduceus.py
├── caduceus_env.yml
├── configs/
│   ├── callbacks/
│   │   ├── base.yaml
│   │   ├── checkpoint.yaml
│   │   ├── gpu_affinity.yaml
│   │   ├── rich.yaml
│   │   ├── val_every_n_global_steps.yaml
│   │   └── wandb.yaml
│   ├── config.yaml
│   ├── dataset/
│   │   ├── genomic_benchmark.yaml
│   │   ├── hg38.yaml
│   │   └── nucleotide_transformer.yaml
│   ├── experiment/
│   │   └── hg38/
│   │       ├── genomic_benchmark.yaml
│   │       ├── genomic_benchmark_cnn.yaml
│   │       ├── hg38.yaml
│   │       └── nucleotide_transformer.yaml
│   ├── loader/
│   │   └── default.yaml
│   ├── model/
│   │   ├── caduceus.yaml
│   │   ├── genomics_benchmark_cnn.yaml
│   │   ├── hyena.yaml
│   │   ├── layer/
│   │   │   └── hyena.yaml
│   │   └── mamba.yaml
│   ├── optimizer/
│   │   ├── adam.yaml
│   │   ├── adamw.yaml
│   │   └── sgd.yaml
│   ├── pipeline/
│   │   ├── genomic_benchmark.yaml
│   │   ├── hg38.yaml
│   │   └── nucleotide_transformer.yaml
│   ├── scheduler/
│   │   ├── constant.yaml
│   │   ├── constant_warmup.yaml
│   │   ├── cosine.yaml
│   │   ├── cosine_warmup.yaml
│   │   ├── cosine_warmup_timm.yaml
│   │   ├── linear_warmup.yaml
│   │   ├── multistep.yaml
│   │   ├── plateau.yaml
│   │   └── step.yaml
│   ├── task/
│   │   ├── lm.yaml
│   │   ├── multiclass_classification.yaml
│   │   ├── multilabel_classification.yaml
│   │   └── regression.yaml
│   └── trainer/
│       ├── debug.yaml
│       ├── default.yaml
│       ├── full.yaml
│       └── lm.yaml
├── setup_env.sh
├── slurm_scripts/
│   ├── dump_vep_embeddings.sh
│   ├── run_genomics_benchmark.sh
│   ├── run_genomics_benchmark_cnn.sh
│   ├── run_nucleotide_transformer.sh
│   ├── run_pretrain_caduceus.sh
│   ├── run_pretrain_hyena.sh
│   ├── run_pretrain_mamba.sh
│   ├── wrapper_run_genomics.sh
│   ├── wrapper_run_genomics_cnn.sh
│   └── wrapper_run_nucleotide_transformer.sh
├── src/
│   ├── __init__.py
│   ├── callbacks/
│   │   ├── params.py
│   │   ├── timer.py
│   │   └── validation.py
│   ├── dataloaders/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── datasets/
│   │   │   ├── genomic_bench_dataset.py
│   │   │   ├── hg38_char_tokenizer.py
│   │   │   ├── hg38_dataset.py
│   │   │   └── nucleotide_transformer_dataset.py
│   │   ├── fault_tolerant_sampler.py
│   │   ├── genomics.py
│   │   └── utils/
│   │       ├── mlm.py
│   │       └── rc.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── baseline/
│   │   │   ├── __init__.py
│   │   │   └── genomics_benchmark_cnn.py
│   │   ├── nn/
│   │   │   ├── __init__.py
│   │   │   ├── activation.py
│   │   │   ├── adaptive_softmax.py
│   │   │   └── utils.py
│   │   └── sequence/
│   │       ├── __init__.py
│   │       ├── dna_embedding.py
│   │       ├── hyena.py
│   │       └── long_conv_lm.py
│   ├── ops/
│   │   └── fftconv.py
│   ├── tasks/
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── metrics.py
│   │   ├── tasks.py
│   │   └── torchmetrics.py
│   └── utils/
│       ├── __init__.py
│       ├── config.py
│       ├── optim/
│       │   └── schedulers.py
│       ├── optim_groups.py
│       ├── registry.py
│       └── train.py
├── train.py
├── vep_embeddings.py
└── vep_svm.ipynb

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data.tar.gz
*.tsf
*.ckpt
.ipynb_checkpoints
*/.ipynb_checkpoints/*
*.lprof

.DS_Store
.idea/
outputs/

# slurm log files
watch_folder/

data

# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# End of https://www.gitignore.io/api/python


================================================
FILE: LICENSE
================================================
                                Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.



================================================
FILE: README.md
================================================
<p align="center">
    <img src="assets/Caduceus_image.png" alt="Caduceus" width="200"/>
</p>


# Caduceus &#9764;: Bi-Directional Equivariant Long-Range DNA Sequence Modeling
[[Blog]](https://caduceus-dna.github.io/) &nbsp; | &nbsp; [[arXiv]](https://arxiv.org/abs/2403.03234) &nbsp; | &nbsp; [[HuggingFace 🤗]](https://huggingface.co/collections/kuleshov-group/caducues-65dcb89b4f54e416ef61c350)

This repository contains code for reproducing the results in the paper "Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling," [Schiff et al. (2024)](https://arxiv.org/abs/2403.03234).

## Using Caduceus with 🤗
<a name="HF"></a>
We have uploaded pre-trained Caduceus models to the Hugging Face Hub.
The available models are:
- Caduceus-Ph: [kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16)
  - Trained on sequences of length 131k, with a model size of 256 and 16 layers.
  - Trained for 50k steps and batch size of 8.
  - Trained with reverse-complement (RC) data augmentation.
- Caduceus-PS: [kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16)
  - Trained on sequences of length 131k, with a model size of 256 and 16 layers.
  - Trained for 50k steps and batch size of 8.
  - Model is RC equivariant, hence no RC data augmentation is required.

You can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model.

To use the pre-trained model for masked language modeling, use the following snippet:
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# See the `Caduceus` collection page on the hub for a list of available models.
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
```

Alternatively, you can instantiate a model from scratch to train on your own data as follows:
```python
from transformers import AutoConfig, AutoModelForMaskedLM

# Add any config overrides here, see the `config.json` file on the hub for details.
config_overrides = {}
# See the `Caduceus` collection page on the hub for a list of available models.
config = AutoConfig.from_pretrained(
 "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
 **config_overrides,
)
model = AutoModelForMaskedLM.from_config(config)
```

## Getting started in this repository
<a name="getting_started"></a>

To get started, create a conda environment containing the required dependencies.

```bash
conda env create -f caduceus_env.yml
```

Activate the environment.

```bash
conda activate caduceus_env
```

Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```

## Reproducing Experiments

Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`train.py`](./train.py) script.
We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`slurm_scripts/`](./slurm_scripts) directory.

### Pretraining on Human Reference Genome
<a name="pretraining"></a>
(Data downloading instructions are copied from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome).)

First, download the Human Reference Genome data.
It consists of two files: one with all the sequences (the `.fa` fasta file) and one with the intervals we use (the `.bed` file).

The file structure should look like

```
data
|-- hg38/
    |-- hg38.ml.fa
    |-- human-sequences.bed
```

Download the fasta (`.fa`) file of the entire human genome into `./data/hg38`.
The genome contains ~24 chromosomes merged into one file; each chromosome is essentially one continuous sequence.
Then download the `.bed` file with the sequence intervals. Each row contains a chromosome name, start, end, and split, which together allow you to retrieve the corresponding sequence from the fasta file.
```bash
mkdir -p data/hg38/
curl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz
gunzip data/hg38/hg38.ml.fa.gz  # unzip the fasta file
curl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed
```
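
As a rough illustration of how the `.bed` intervals relate to the fasta file, the sketch below parses tab-separated rows of the form chromosome, start, end, split and slices the matching sequence out of an in-memory dictionary. The `Interval` class, the `parse_bed` and `fetch` helpers, and the toy `genome` dictionary are hypothetical names used only for illustration; the repository's own loader (see `src/dataloaders/datasets/hg38_dataset.py`) may differ. The 0-based, half-open coordinates are the standard BED convention and are assumed here.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List


@dataclass(frozen=True)
class Interval:
    chrom: str
    start: int  # 0-based, inclusive (standard BED convention, assumed here)
    end: int    # exclusive
    split: str  # e.g. "train", "valid", or "test"


def parse_bed(lines: Iterable[str]) -> List[Interval]:
    """Parse tab-separated rows: chromosome, start, end, split."""
    intervals = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        chrom, start, end, split = line.split("\t")[:4]
        intervals.append(Interval(chrom, int(start), int(end), split))
    return intervals


def fetch(genome: Dict[str, str], interval: Interval) -> str:
    """Slice the interval out of a chromosome stored as one continuous string."""
    return genome[interval.chrom][interval.start:interval.end]


# Toy stand-ins for the fasta and bed contents (not real data).
genome = {"chr1": "ACGTACGTAC", "chr2": "TTTTGGGGCC"}
bed_rows = ["chr1\t2\t6\ttrain", "chr2\t4\t8\tvalid"]

intervals = parse_bed(bed_rows)
train_seqs = [fetch(genome, iv) for iv in intervals if iv.split == "train"]
```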

Launch a pre-training run from the command line:

```bash
python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=1024 \
  dataset.batch_size=1024 \
  dataset.mlm=true \
  dataset.mlm_probability=0.15 \
  dataset.rc_aug=false \
  model=caduceus \
  model.config.d_model=128 \
  model.config.n_layer=4 \
  model.config.bidirectional=true \
  model.config.bidirectional_strategy=add \
  model.config.bidirectional_weight_tie=true \
  model.config.rcps=true \
  optimizer.lr="8e-3" \
  train.global_batch_size=1024 \
  trainer.max_steps=10000 \
  +trainer.val_check_interval=10000 \
  wandb=null
```
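
The `dataset.mlm=true` and `dataset.mlm_probability=0.15` overrides above enable masked language modeling. As a simplified sketch of the idea, the hypothetical `mask_tokens` helper below selects each position independently with the given probability, replaces it with a mask symbol, and records the original token as the prediction target. This is not the repository's implementation (see `src/dataloaders/utils/mlm.py`, which may treat selected positions differently); the `MASK` and `IGNORE` placeholders are assumptions made for illustration.

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"  # hypothetical mask symbol
IGNORE = None    # hypothetical marker for positions that are not scored


def mask_tokens(
    tokens: List[str], mlm_probability: float, rng: random.Random
) -> Tuple[List[str], List[Optional[str]]]:
    """Return (masked inputs, targets) for a simplified MLM objective."""
    inputs: List[str] = []
    targets: List[Optional[str]] = []
    for tok in tokens:
        if rng.random() < mlm_probability:
            inputs.append(MASK)      # the model must reconstruct this position
            targets.append(tok)      # the original token is the target
        else:
            inputs.append(tok)       # left unchanged
            targets.append(IGNORE)   # excluded from the loss
    return inputs, targets


rng = random.Random(0)
tokens = list("ACGT" * 250)  # toy sequence of 1000 nucleotides
inputs, targets = mask_tokens(tokens, mlm_probability=0.15, rng=rng)
n_masked = sum(t is not IGNORE for t in targets)
```

With `mlm_probability=0.15`, roughly 15% of positions end up masked on average; the exact count depends on the random seed.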

or alternatively, if using a cluster that has `slurm` installed, adapt the scripts below:
```
slurm_scripts
|-- run_pretrain_caduceus.sh
|-- run_pretrain_hyena.sh
|-- run_pretrain_mamba.sh
```

and run the training as a batch job:
```bash
cd slurm_scripts
sbatch run_pretrain_caduceus.sh
```

### GenomicBenchmarks
<a name="genomicbenchmarks"></a>

The [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) suite presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8) comprises 8 classification tasks.

We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
    experiment=hg38/genomic_benchmark \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="dummy_mouse_enhancers_ensembl" \
    dataset.train_val_split_seed=1 \
    dataset.batch_size=256 \
    dataset.rc_aug=false \
    +dataset.conjoin_train=false \
    +dataset.conjoin_test=false \
    loader.num_workers=2 \
    model=caduceus \
    model._name_=dna_embedding_caduceus \
    +model.config_path="<path to model_config.json>" \
    +model.conjoin_test=false \
    +decoder.conjoin_train=true \
    +decoder.conjoin_test=false \
    optimizer.lr="1e-3" \
    trainer.max_epochs=10 \
    train.pretrained_model_path="<path to .ckpt file>" \
    wandb=null
```

This sample run will fine-tune a pre-trained Caduceus-PS model on the `dummy_mouse_enhancers_ensembl` task.
Note some of the additional arguments present here, relative to the pre-training command from [above](#pretraining):
- `model.config_path` contains the path to the model config that was saved during pre-training.
This file is saved to the run directory of the pre-training experiment.
- `train.pretrained_model_path` contains the path to the pre-trained model checkpoint.
- `dataset.conjoin_train` determines whether the dataset will return a single sequence (`dataset.conjoin_train=false`) or the concatenation of a sequence and its reverse complement along `dim=-1`, during downstream fine-tuning training.
- `dataset.conjoin_test` is the same as above, but for inference (e.g., validation / test).
- `decoder.conjoin_train` determines whether the prediction head (a mean pooling and linear projection in the case of the Genomics Benchmark) is expecting an input tensor of shape `(batch_size, seq_len, d_model)` or `(batch_size, seq_len, d_model, 2)` during downstream fine-tuning training.
When set to `true` the decoder is run on `input[..., 0]` and `input[..., 1]` and the results are averaged to produce the final prediction.
- `decoder.conjoin_test` is the same as above, but for inference (e.g., validation / test).
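
To make the `conjoin` options more concrete, here is a minimal pure-Python sketch. It replaces the backbone with a toy per-position feature and uses hypothetical helper names (`reverse_complement`, `toy_features`, `mean_pool_head`, `conjoined_prediction`); the real model operates on tensors of shape `(batch_size, seq_len, d_model, 2)` and may differ in detail. The sketch runs the same mean-pooling head on the forward-strand and reverse-complement representations and averages the two results, mirroring the behavior described for `decoder.conjoin_train=true`.

```python
from typing import List

COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}


def reverse_complement(seq: str) -> str:
    """Reverse the sequence and swap each base for its complement."""
    return "".join(COMPLEMENT[base] for base in reversed(seq))


def toy_features(seq: str) -> List[float]:
    """Hypothetical stand-in for backbone output: one scalar per position."""
    score = {"A": 0.0, "C": 1.0, "G": 2.0, "T": 3.0, "N": 1.5}
    return [score[base] for base in seq]


def mean_pool_head(features: List[float], weight: float = 0.5, bias: float = 0.1) -> float:
    """Mean pooling over positions followed by a scalar linear projection."""
    return weight * (sum(features) / len(features)) + bias


def conjoined_prediction(seq: str) -> float:
    """Run the head on both strands and average the two outputs."""
    fwd = mean_pool_head(toy_features(seq))
    rc = mean_pool_head(toy_features(reverse_complement(seq)))
    return (fwd + rc) / 2


seq = "AACGT"
pred = conjoined_prediction(seq)
pred_rc = conjoined_prediction(reverse_complement(seq))
```

Because the average is symmetric in the two strands, `pred` and `pred_rc` coincide in this sketch: the prediction does not depend on which strand is presented.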

Note this benchmark only contains a training and test split for each task.
Therefore, to have a more principled evaluation, we randomly split the training data into training and validation sets (90/10) using the `dataset.train_val_split_seed` argument.
We perform early stopping on the validation metric (accuracy) and repeat this for 5 random seeds.
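
A seeded 90/10 split of this kind can be sketched as follows. The `train_val_split` helper is hypothetical and only illustrates how a fixed seed makes the split reproducible; the repository's dataset classes handle this internally and may differ in detail.

```python
import random
from typing import List, Tuple


def train_val_split(
    n_examples: int, seed: int, val_fraction: float = 0.1
) -> Tuple[List[int], List[int]]:
    """Shuffle example indices with a fixed seed and hold out a validation fraction."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    n_val = int(n_examples * val_fraction)
    # The first n_val shuffled indices form the validation set; the rest are for training.
    return indices[n_val:], indices[:n_val]


# One split per seed, e.g. for five repeated runs.
splits = {seed: train_val_split(1000, seed) for seed in range(1, 6)}
train_idx, val_idx = splits[1]
```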

As with [pre-training](#pretraining), we can also launch the fine-tuning run as a batch job using the provided [`run_genomics_benchmark.sh`](./slurm_scripts/run_genomics_benchmark.sh) script.
We also provide a helper shell script [`wrapper_run_genomics.sh`](./slurm_scripts/wrapper_run_genomics.sh) that can be used to launch multiple fine-tuning runs in parallel.

Finally, the [`run_genomics_benchmark_cnn.sh`](./slurm_scripts/run_genomics_benchmark_cnn.sh) script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks.

### Nucleotide Transformer datasets
<a name="nucleotidetransformer"></a>

The Nucleotide Transformer suite of tasks was proposed in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1).
The data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).

We can launch a downstream fine-tuning run on one of the tasks using the sample command below:
```bash
python -m train \
    experiment=hg38/nucleotide_transformer \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${task}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${batch_size} \
    dataset.rc_aug="${rc_aug}" \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    loader.num_workers=2 \
    model._name_=dna_embedding_caduceus \
    +model.config_path="<path to model_config.json>" \
    +model.conjoin_test=false \
    +decoder.conjoin_train=true \
    +decoder.conjoin_test=false \
    optimizer.lr="1e-3" \
    train.pretrained_model_path="<path to .ckpt file>" \
    trainer.max_epochs=20 \
    wandb=null
```

We can also launch as batch jobs (see [`run_nucleotide_transformer.sh`](./slurm_scripts/run_nucleotide_transformer.sh) and [`wrapper_run_nucleotide_transformer.sh`](./slurm_scripts/wrapper_run_nucleotide_transformer.sh) for details).

### eQTL SNP Variant Effect Prediction
<a name="vep"></a>
This task comes from the recently proposed Long Range Benchmark (LRB) in [Kao et al., 2023](https://llms4science-community.github.io/papers/LLMs4Bio24_paper_12.pdf).
The data is available on HuggingFace: [InstaDeepAI/genomics-long-range-benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark).
For this task we fit a model to the pre-trained and frozen embeddings of the DNA language models.
Therefore, to perform the evaluation, we proceed in 2 steps:
- **Step 1: Extract the embeddings** from the pre-trained model:
Run the [`vep_embeddings.py`](./vep_embeddings.py) script to extract the embeddings from the pre-trained model.
See the example below:
```bash
torchrun \
    --standalone \
    --nnodes=1 \
    --nproc-per-node=8 \
    vep_embeddings.py \
      --num_workers=2 \
      --seq_len=131072  \
      --bp_per_token=1  \
      --embed_dump_batch_size=1 \
      --name="caduceus-ps_downstream-seqlen=131k"  \
      --model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16" \
      --rcps
```

The `--rcps` flag is used to indicate that the model is reverse-complement equivariant.
When using other models, set this flag to false with `--no-rcps`.
To speed this step up, this script utilizes torch distributed data parallelism.

Please refer to the slurm script provided in [`slurm_scripts/dump_vep_embeddings.sh`](./slurm_scripts/dump_vep_embeddings.sh)
to launch this step as a batch job.

- **Step 2: Fit an SVM model to the embeddings** using this notebook: [`vep_svm.ipynb`](./vep_svm.ipynb).

## Citation
<a name="citation"></a>

If you find our work useful, please cite our paper using the following:
```
@article{schiff2024caduceus,
  title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling},
  author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr},
  journal={arXiv preprint arXiv:2403.03234},
  year={2024}
}
```

## Acknowledgements
<a name="acknowledgements"></a>
This repository is adapted from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna) and leverages much of the training, data loading, and logging infrastructure defined there.
HyenaDNA was originally derived from the [S4](https://github.com/state-spaces/s4) and [Safari](https://github.com/HazyResearch/safari) repositories.

We would like to thank Evan Trop and the [InstaDeep](https://www.instadeep.com/) team for useful discussions about the [Nucleotide Transformer leaderboard](https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark)
and the Long Range Benchmark task.

Finally, we would like to thank [MosaicML](https://www.mosaicml.com/) for providing compute resources for some of the pre-training experiments.


================================================
FILE: caduceus/__init__.py
================================================
"""Hugging Face config, model, and tokenizer for Caduceus.

"""

from .configuration_caduceus import CaduceusConfig
from .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification
from .tokenization_caduceus import CaduceusTokenizer


================================================
FILE: caduceus/configuration_caduceus.py
================================================
"""Caduceus config for Hugging Face.

"""

from typing import Optional, Union

from transformers import PretrainedConfig


class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
    model_type = "caduceus"

    def __init__(
            self,
            # From original MambaConfig
            d_model: int = 2560,
            n_layer: int = 64,
            vocab_size: int = 50277,
            ssm_cfg: Optional[dict] = None,
            rms_norm: bool = True,
            residual_in_fp32: bool = True,
            fused_add_norm: bool = True,
            pad_vocab_size_multiple: int = 8,

            # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
            norm_epsilon: float = 1e-5,

            # Used in init_weights
            initializer_cfg: Optional[dict] = None,

            # Caduceus-specific params
            bidirectional: bool = True,
            bidirectional_strategy: Union[str, None] = "add",
            bidirectional_weight_tie: bool = True,
            rcps: bool = False,
            complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map


================================================
FILE: caduceus/modeling_caduceus.py
================================================
"""Caduceus model for Hugging Face.

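The bidirectional mixing performed by `BiMambaWrapper` can be illustrated without `mamba_ssm`. The sketch below is a minimal pure-Python illustration, not the module's implementation: `causal_scan` is a hypothetical stand-in for a causal mixer (a running sum), and `bidirectional` is a helper name chosen for this example. It shows the flip, scan, flip-back pattern and the `add` / `ew_multiply` combination strategies.

```python
def causal_scan(xs):
    # Toy causal mixer (NOT Mamba): position t only sees inputs at positions <= t.
    out, total = [], 0.0
    for x in xs:
        total += x
        out.append(total)
    return out


def bidirectional(xs, strategy="add"):
    # Forward pass over the sequence as given.
    fwd = causal_scan(xs)
    # Reverse pass: flip the sequence, scan, then flip the result back so that
    # index t lines up with index t of the forward output.
    rev = causal_scan(xs[::-1])[::-1]
    if strategy == "add":
        return [f + r for f, r in zip(fwd, rev)]
    if strategy == "ew_multiply":
        return [f * r for f, r in zip(fwd, rev)]
    raise NotImplementedError(strategy)


xs = [1.0, 2.0, 3.0]
print(bidirectional(xs))                 # [7.0, 8.0, 9.0]
print(bidirectional(xs, "ew_multiply"))  # [6.0, 15.0, 18.0]
```

With a causal mixer, the forward output at position t depends only on inputs up to t and the flipped-back reverse output depends only on inputs from t onward, so the combined output at every position depends on the whole sequence.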
"""

import inspect
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
from mamba_ssm.modules.mamba_simple import Mamba
try:
    from mamba_ssm.modules.mamba_simple import Block  # Legacy mambav1 file structure
except ImportError:
    from mamba_ssm.modules.block import Block  # mambav2 file structure
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure
except ImportError:
    try:
        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure
    except ImportError:
        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from .configuration_caduceus import CaduceusConfig
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock


def create_block(
        d_model,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        rms_norm=False,
        residual_in_fp32=False,
        fused_add_norm=False,
        layer_idx=None,
        bidirectional=True,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=False,
        device=None,
        dtype=None,
):
    """Create Caduceus block.

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
    """
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    bidirectional_kwargs = {
        "bidirectional": bidirectional,
        "bidirectional_strategy": bidirectional_strategy,
        "bidirectional_weight_tie": bidirectional_weight_tie,
    }
    mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block_cls = RCPSMambaBlock if rcps else Block
    # mambav2 compatibility
    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
        block = block_cls(
            d_model,
            mixer_cls,
            mlp_cls=nn.Identity,
            norm_cls=norm_cls,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
        )
    else:
        block = block_cls(
            d_model,
            mixer_cls,
            norm_cls=norm_cls,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
        )
    block.layer_idx = layer_idx
    return block


class BiMambaWrapper(nn.Module):
    """Thin wrapper around Mamba to support bi-directionality."""

    def __init__(
            self,
            d_model: int,
            bidirectional: bool = True,
            bidirectional_strategy: Optional[str] = "add",
            bidirectional_weight_tie: bool = True,
            **mamba_kwargs,
    ):
        super().__init__()
        if bidirectional and bidirectional_strategy is None:
            bidirectional_strategy = "add"  # Default strategy: `add`
        if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
            raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.mamba_fwd = Mamba(
            d_model=d_model,
            **mamba_kwargs
        )
        if bidirectional:
            self.mamba_rev = Mamba(
                d_model=d_model,
                **mamba_kwargs
            )
            if bidirectional_weight_tie:  # Tie in and out projections (where most of param count lies)
                self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
                self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
                self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
                self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
        else:
            self.mamba_rev = None

    def forward(self, hidden_states, inference_params=None):
        """Bidirectional-enabled forward pass

        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        out = self.mamba_fwd(hidden_states, inference_params=inference_params)
        if self.bidirectional:
            out_rev = self.mamba_rev(
                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
                inference_params=inference_params
            ).flip(dims=(1,))  # Flip back for combining with forward hidden states
            if self.bidirectional_strategy == "add":
                out = out + out_rev
            elif self.bidirectional_strategy == "ew_multiply":
                out = out * out_rev
            else:
                raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
        return out


class CaduceusEmbeddings(nn.Module):
    """Token embeddings for Caduceus; uses `RCPSEmbedding` when `config.rcps` is set."""
    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        if config.rcps:
            self.word_embeddings = RCPSEmbedding(
                config.vocab_size, config.d_model, config.complement_map, **factory_kwargs
            )
        else:
            self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)

    def forward(self, input_ids):
        """
            input_ids: (batch, seqlen)
        """
        return self.word_embeddings(input_ids)


class CaduceusMixerModel(nn.Module):
    """Caduceus backbone: token embeddings, a stack of Caduceus blocks, and a final norm layer."""
    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.fused_add_norm = config.fused_add_norm
        self.rcps = config.rcps
        self.residual_in_fp32 = config.residual_in_fp32

        self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)

        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reasons: we can fuse add + layer_norm.
        if config.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    config.d_model,
                    ssm_cfg=config.ssm_cfg,
                    norm_epsilon=config.norm_epsilon,
                    rms_norm=config.rms_norm,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=config.fused_add_norm,
                    layer_idx=i,
                    bidirectional=config.bidirectional,
                    bidirectional_strategy=config.bidirectional_strategy,
                    bidirectional_weight_tie=config.bidirectional_weight_tie,
                    rcps=config.rcps,
                    **factory_kwargs,
                )
                for i in range(config.n_layer)
            ]
        )

        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
            config.d_model, eps=config.norm_epsilon, **factory_kwargs
        )
        self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        """Mixer forward."""
        all_hidden_states = []
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids)

        residual = None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=None
            )

        if not self.fused_add_norm:
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
            else:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states_fwd = fused_add_norm_fn(
                    hidden_states[..., :hidden_states.shape[-1] // 2],
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., :hidden_states.shape[-1] // 2],
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states_rc = fused_add_norm_fn(
                    hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
        # Record the final (normalized) hidden states regardless of which norm path ran above
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        return hidden_states, all_hidden_states


def cross_entropy(logits, y, ignore_index=-100):
    """Cross entropy loss."""
    logits = logits.view(-1, logits.shape[-1])
    y = y.view(-1)
    return F.cross_entropy(logits, y, ignore_index=ignore_index)


def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
    logits = logits.view(-1, logits.shape[-1])
    y = y.view(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    # Zero out ignored positions without mutating the caller's `loss_weights` tensor in place
    loss_weights = loss_weights.view(-1).masked_fill(y == ignore_index, 0.0)
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (loss_weights / loss_weights.sum())).sum()


class CaduceusPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for Caduceus backbone."""
    config_class = CaduceusConfig
    base_model_prefix = "caduceus"
    supports_gradient_checkpointing = False
    _no_split_modules = ["BiMambaWrapper"]

    def _init_weights(
            self,
            module,
            initializer_range=0.02,  # Now only used for embedding layer.
            **kwargs,
    ):
        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""

        n_layer = self.config.n_layer
        initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
        rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
        initializer_range = initialized_cfg.get("initializer_range", initializer_range)
        n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
            #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
            #   residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight", "fc2.weight"]:
                    # Special scaled initialization: follow the PyTorch default init, then
                    # scale by 1/sqrt(n_residuals_per_layer * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(n_residuals_per_layer * n_layer)


class Caduceus(CaduceusPreTrainedModel):
    """Caduceus model that can be instantiated using HF patterns."""
    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config)

        if config.rcps:
            assert config.complement_map is not None, "Complement map must be provided for RCPS."

        # Adjust vocab size and complement maps if vocab padding is set.
        if config.vocab_size % config.pad_vocab_size_multiple != 0:
            config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
        if config.complement_map is not None and config.vocab_size > len(config.complement_map):
            for i in range(len(config.complement_map), config.vocab_size):
                config.complement_map[i] = i

        self.config = config
        factory_kwargs = {"device": device, "dtype": dtype}
        self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
        """HF-compatible forward method."""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states, all_hidden_states = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states
        )
        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states if output_hidden_states else None
            )
        elif output_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states


class CaduceusForMaskedLM(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for masked language modeling."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        if config.rcps:
            self.lm_head = RCPSLMHead(
                complement_map=self.config.complement_map,  # Use caduceus config as it might have been updated
                vocab_size=self.config.vocab_size,  # Use caduceus config as it might have been updated
                true_dim=config.d_model,
                dtype=dtype
            )
        else:
            self.lm_head = nn.Linear(
                config.d_model,
                self.config.vocab_size,  # Use caduceus config as it might have been updated
                bias=False,
                **factory_kwargs
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Overrides output embeddings."""
        if self.config.rcps:
            raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights, accounting for RCPS."""
        if self.config.rcps:
            self.lm_head.set_weight(self.get_input_embeddings().weight)
        else:
            super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.caduceus

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.caduceus = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method."""

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Backbone outputs: the last hidden state and, if requested, the per-layer hidden states
        outputs = self.caduceus(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            if loss_weights is not None:
                loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
            else:
                loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for sequence classification, with optional RC conjoining."""
    def __init__(
            self,
            config: CaduceusConfig,
            pooling_strategy: str = "mean",
            conjoin_train: bool = False,
            conjoin_eval: bool = False,
            device=None,
            dtype=None,
            **kwargs):
        super().__init__(config, **kwargs)
        if pooling_strategy not in ["mean", "max", "first", "last"]:
            raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
        self.pooling_strategy = pooling_strategy
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)

        self.conjoin_train = conjoin_train
        self.conjoin_eval = conjoin_eval

        # Initialize weights and apply final processing
        self.post_init()
        self.init_scorer()

    def init_scorer(self, initializer_range=0.02):
        initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
            if self.config.initializer_cfg is not None else initializer_range
        self.score.weight.data.normal_(std=initializer_range)

    def get_input_embeddings(self):
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
        """Pools hidden states along sequence length dimension."""
        if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
            return hidden_states.max(dim=sequence_length_dim).values
        if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden representations from the backbone
        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = torch.stack(
                [
                    transformer_outputs[0][..., :self.config.d_model],
                    torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
                 ],
                dim=-1
            )
        elif self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining
            assert input_ids is not None, "`input_ids` must be provided for conjoining."
            assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels correspond to forward and rc strands."
            transformer_outputs = self.caduceus(
                input_ids[..., 0],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            transformer_outputs_rc = self.caduceus(
                input_ids[..., 1],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Stack along channel dimension (dim=-1)
            hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
        else:
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]

        # Pool and get logits
        pooled_hidden_states = self.pool_hidden_states(hidden_states)
        # Potentially run `score` twice (with parameters shared) for conjoining
        if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
            logits_fwd = self.score(pooled_hidden_states[..., 0])
            logits_rc = self.score(pooled_hidden_states[..., 1])
            logits = (logits_fwd + logits_rc) / 2
        else:
            logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)
        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )


================================================
FILE: caduceus/modeling_rcps.py
================================================
"""Reverse-complement equivariant modules.

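The wrapping pattern used by `RCPSWrapper` can be checked on plain nested lists. The sketch below is a minimal pure-Python illustration, not the module's implementation: sequences are lists of rows (length x channels), and `rcps_wrap`, `split`, `concat`, and `toy_submodule` are hypothetical names chosen for this example. The submodule runs once on the first half of the channels and once on the reverse complement of the second half, and the second result is flipped back before concatenation.

```python
def rc(x):
    # Flip the length axis (outer list) and the channel axis (inner lists).
    return [row[::-1] for row in x[::-1]]


def split(x):
    # Split the channels of every row into a first and a second half.
    half = len(x[0]) // 2
    return [row[:half] for row in x], [row[half:] for row in x]


def concat(a, b):
    # Concatenate two sequences along the channel axis.
    return [row_a + row_b for row_a, row_b in zip(a, b)]


def rcps_wrap(f):
    def wrapped(x):
        fwd, rev = split(x)
        fwd_out = f(fwd)        # submodule on the forward half
        rc_out = f(rc(rev))     # same submodule on the RC of the other half
        return concat(fwd_out, rc(rc_out))
    return wrapped


def toy_submodule(x):
    # Hypothetical submodule: causal running sum with a per-channel scale.
    # On its own it is neither flip- nor channel-symmetric.
    out, totals = [], [0] * len(x[0])
    for row in x:
        totals = [t + (c + 1) * v for c, (t, v) in enumerate(zip(totals, row))]
        out.append(totals)
    return out


x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
wrapped = rcps_wrap(toy_submodule)
print(wrapped(rc(x)) == rc(wrapped(x)))              # True
print(toy_submodule(rc(x)) == rc(toy_submodule(x)))  # False
```

The first check holds for any submodule `f`, because reverse-complementing the input swaps which half each call sees; the second shows that the unwrapped toy submodule does not commute with `rc`.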
"""
from collections import OrderedDict
from typing import Optional

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure
except ImportError:
    try:
        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure
    except ImportError:
        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class RCPSEmbedding(nn.Module):
    """Embedding layer that supports reverse-complement equivariance."""
    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
        """
        Args:
            vocab_size: Size of vocabulary.
            d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).
            complement_map: Dictionary mapping each token id to its complement.
        """
        super().__init__()
        self.register_buffer(
            "complement_map",
            torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
        )
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    @property
    def weight(self):
        """Embedding weights."""
        return self.embedding.weight

    def set_weight(self, value):
        """Set embedding weights."""
        self.embedding.weight = value

    def rc(self, x):
        """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
        return torch.gather(
            self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
            dim=1,
            index=torch.flip(x, dims=[-1])
        )

    def forward(self, input_ids):
        """Reverse-complement equivariant forward pass.

        This embedding module doubles the output dimensionality to support reverse-complement equivariance.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
        Returns:
            Embedding tensor of shape (batch_size, seq_len, d_model * 2)
        """
        fwd_out = self.embedding(input_ids)
        rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])

        return torch.cat([fwd_out, rc_out], dim=-1)


class RCPSWrapper(nn.Module):
    """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.

    See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
    Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
    """
    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule

    @staticmethod
    def rc(x):
        """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
        return torch.flip(x, dims=[-2, -1])

    def forward(self, x, **kwargs):
        """Reverse-complement equivariant forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
        Returns:
            Output tensor of shape (batch_size, seq_len, channels)
        """
        n_channels = x.shape[-1]
        # Run submodule along sequence
        fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)
        # Run submodule along rc-sequence
        rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)
        # Concatenate along channel dimension (dim=-1)
        return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)


class RCPSAddNormWrapper(RCPSWrapper):
    """RC equivariant AddNorm layer."""
    def __init__(self, submodule: nn.Module):
        super().__init__(submodule)

    def forward(self, x, residual=None, prenorm=False):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
            prenorm: If True, return the tuple (x, residual) instead of x alone.
        """
        n_channels = x.shape[-1]
        if residual is None:
            residual = x
            x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))
            x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))
            x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
        else:
            residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]
            x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))

            residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])
            x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))

            residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
            x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)

        return x if not prenorm else (x, residual)


class RCPSMambaBlock(nn.Module):
    def __init__(
            self,
            dim,
            mixer_cls,
            norm_cls=nn.LayerNorm,
            fused_add_norm=False,
            residual_in_fp32=False,
            device=None,  # Keep for consistency with original Mamba Block
            dtype=None,  # Keep for consistency with original Mamba Block
    ):
        """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = RCPSWrapper(mixer_cls(dim))
        norm_f = norm_cls(dim)
        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: running residual stream; the block computes hidden_states = Mixer(LN(residual)).
            inference_params: inference parameters for mixer.
        """
        if not self.fused_add_norm:
            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn

            hidden_states_fwd, residual_fwd = fused_add_norm_fn(
                hidden_states[..., hidden_states.shape[-1] // 2:],
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

            hidden_states_rc, residual_rc = fused_add_norm_fn(
                hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for mixer.

        Keep for compatibility with original Mamba Block.
        """
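        # self.mixer is an RCPSWrapper, which does not define this method; delegate to the wrapped mixer.
        return self.mixer.submodule.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)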


class RCPSLMHead(nn.Module):
    """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
        """
        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
        equivariant, i.e. 0.5 times the actual input dim.
        """
        super().__init__()
        self.register_buffer(
            "complement_map",
            torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
        )
        self.true_dim = true_dim
        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)

    @property
    def weight(self):
        """LM head weights."""
        return self.lm_head.weight

    def set_weight(self, value):
        """Set LM head weights."""
        self.lm_head.weight = value

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
        Returns:
            Logits tensor of shape (batch_size, seq_len, vocab_size).
        """
        n_channels = x.shape[-1]
        assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
        fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
        rc_logits = F.linear(
            torch.flip(x[..., n_channels // 2:], dims=[-1]),
            self.weight[self.complement_map, :],
            bias=self.lm_head.bias
        )
        return fwd_logits + rc_logits


================================================
FILE: caduceus/tests/test_rcps.py
================================================
"""Tests for RCPS modules.

"""

import pytest
import torch
from torch import nn
from torch.nn import functional as F

try:  # Legacy mambav1 file structure
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    try:  # mambav2 file structure
        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
    except ImportError:
        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from caduceus.modeling_rcps import (
    RCPSEmbedding, RCPSAddNormWrapper, RCPSLMHead, RCPSWrapper
)

from caduceus.modeling_caduceus import (
    CaduceusConfig, CaduceusMixerModel, CaduceusForMaskedLM, create_block
)


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [512])
@pytest.mark.parametrize("d_model", [256])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_rcps_embedding(batch_size, seq_len, d_model, dtype):
    # Set tolerance
    device = torch.device("cpu")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # Set seed
    torch.random.manual_seed(0)

    # Define complement map
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
        for k, v in str_to_id.items()
    }
    vocab_size = 12
    pad_vocab_size_multiple = 8
    if vocab_size % pad_vocab_size_multiple != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    if vocab_size > len(complement_map):
        for i in range(len(complement_map), vocab_size):
            complement_map[i] = i

    # Generate random sequences
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)

    # Test RC equivariance of embedding layer
    factory_kwargs = {"device": device, "dtype": dtype}
    embedding = RCPSEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        complement_map=complement_map,
        **factory_kwargs
    ).to(device)
    out_embed = embedding(input_ids)
    rc_out_embed = torch.flip(embedding(rc_input_ids), dims=[-2, -1])
    # Test that channels are 2 * d_model
    assert tuple(out_embed.size()) == (batch_size, seq_len, d_model * 2)
    assert tuple(rc_out_embed.size()) == (batch_size, seq_len, d_model * 2)
    # Test that RC equivariance holds
    assert torch.allclose(out_embed.detach(), rc_out_embed.detach(), rtol=rtol, atol=atol)
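

# Illustrative sketch, not part of the original suite: the vocabulary padding and identity
# extension of the complement map from `test_rcps_embedding`, factored into a CPU-only helper.
# `_pad_complement_map` is a hypothetical name, and the rounding rule is assumed to mirror the
# `pad_vocab_size_multiple` logic above. Like that logic, it assumes the map's keys are the
# contiguous ids 0..len(complement_map) - 1.
def _pad_complement_map(complement_map, vocab_size, pad_vocab_size_multiple=8):
    """Round vocab_size up to the next multiple and map every extra id to itself."""
    if vocab_size % pad_vocab_size_multiple != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    padded = dict(complement_map)  # Copy so the caller's map is left untouched
    for i in range(len(padded), vocab_size):
        padded[i] = i
    return padded, vocab_size


def test_pad_complement_map_sketch():
    base = {0: 0, 1: 1, 2: 5, 3: 4, 4: 3, 5: 2, 6: 6}
    padded, padded_vocab_size = _pad_complement_map(base, vocab_size=12)
    # 12 is rounded up to the next multiple of 8
    assert padded_vocab_size == 16
    assert len(padded) == 16
    # Ids beyond the original map are their own complement
    assert all(padded[i] == i for i in range(len(base), padded_vocab_size))
    # Original entries are preserved
    assert all(padded[k] == v for k, v in base.items())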


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_wrapper(batch_size, seq_len, d_model, dtype):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # Set seed
    torch.random.manual_seed(0)

    # Generate random sequence with 2 * d_model channels
    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
    rc_x = torch.flip(x, dims=[-2, -1])

    factory_kwargs = {"device": device, "dtype": dtype}
    module = nn.Sequential(
        nn.Linear(d_model, d_model, bias=False, **factory_kwargs),
        nn.ReLU(),
        nn.Linear(d_model, d_model*2, bias=True, **factory_kwargs),
        nn.ReLU(),
        nn.Linear(d_model * 2, d_model, bias=True, **factory_kwargs)
    )

    # Test RC equivariance of wrapper
    rcps_module = RCPSWrapper(module).to(device)
    out = rcps_module(x)
    rc_out = torch.flip(rcps_module(rc_x), dims=[-2, -1])
    assert out.size() == x.size()
    assert rc_out.size() == x.size()
    assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("prenorm", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dtype):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # Set seed
    torch.random.manual_seed(0)

    # Generate random sequence with 2 * d_model channels
    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
    rc_x = torch.flip(x, dims=[-2, -1])

    factory_kwargs = {"device": device, "dtype": dtype}
    norm = RMSNorm(d_model, eps=1e-5, **factory_kwargs)

    # Test RC equivariance of wrapper
    rcps_module = RCPSAddNormWrapper(norm).to(device)
    out = rcps_module(x, prenorm=prenorm)
    if prenorm:  # returns tuple
        rc_out = tuple([torch.flip(r, dims=[-2, -1])
                        for r in rcps_module(rc_x, prenorm=prenorm)])
        for f, r in zip(out, rc_out):
            assert f.size() == x.size()
            assert r.size() == x.size()
            assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
    else:
        rc_out = torch.flip(rcps_module(rc_x, prenorm=prenorm), dims=[-2, -1])
        assert out.size() == x.size()
        assert rc_out.size() == x.size()
        assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirectional, fused_add_norm, dtype):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # Set seed
    torch.random.manual_seed(0)

    # Generate random sequence with 2 * d_model channels
    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
    rc_x = torch.flip(x, dims=[-2, -1])

    ssm_cfg = {
        "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
        "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
    }
    factory_kwargs = {"device": device, "dtype": dtype}

    mamba_block = create_block(
        d_model,
        ssm_cfg=ssm_cfg,
        norm_epsilon=1e-5,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=fused_add_norm,
        layer_idx=0,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=True,
        **factory_kwargs
    )

    # Test RC equivariance of wrapper
    out = mamba_block(x, residual=None)
    rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=None)])
    for f, r in zip(out, rc_out):
        assert f.size() == x.size()
        assert r.size() == x.size()
        assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)

    out = mamba_block(x, residual=x)
    rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=rc_x)])
    for f, r in zip(out, rc_out):
        assert f.size() == x.size()
        assert r.size() == x.size()
        assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)


@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1, 1024, 2048])
@pytest.mark.parametrize("d_model", [2, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_rcps_lm_head(batch_size, seq_len, d_model, dtype):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2

    # Set seed
    torch.random.manual_seed(0)

    # Define complement map
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
        for k, v in str_to_id.items()
    }
    factory_kwargs = {"device": device, "dtype": dtype}
    vocab_size = 12
    if vocab_size > len(complement_map):
        for i in range(len(complement_map), vocab_size):
            complement_map[i] = i

    # Instantiate LM head
    lm_head = RCPSLMHead(
        complement_map=complement_map,
        vocab_size=vocab_size,
        true_dim=d_model,
        **factory_kwargs
    )

    # Generate random sequence with 2 * d_model channels
    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)
    rc_x = torch.flip(x, dims=[-2, -1])

    # Test RC equivariance of LM head
    out = lm_head(x)
    rc_out = lm_head(rc_x)
    assert tuple(out.size()) == (batch_size, seq_len, vocab_size)
    assert tuple(rc_out.size()) == (batch_size, seq_len, vocab_size)
    assert torch.allclose(
        out.detach(),
        torch.flip(rc_out.detach()[..., lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )
    assert torch.allclose(
        F.softmax(out, dim=-1).detach(),
        torch.flip(F.softmax(rc_out, dim=-1).detach()[..., lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )


@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 2, 3])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("fused_add_norm", [True, False])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fused_add_norm,
                       bidirectional, bidirectional_weight_tie):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2

    # Set seed
    torch.random.manual_seed(0)

    # Define complement map
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
        for k, v in str_to_id.items()
    }

    # Setup CaduceusConfig
    initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
    ssm_cfg = {
        "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
        "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
    }
    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg=ssm_cfg,
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=fused_add_norm,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg=initializer_cfg,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    factory_kwargs = {"device": device, "dtype": dtype}

    # Instantiate model
    backbone = CaduceusMixerModel(
        config,
        **factory_kwargs,
    ).to(device)

    # Generate random sequences
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)

    # Test RC equivariance of rc backbone
    out = backbone(input_ids)[0]
    rc_out = backbone(rc_input_ids)[0]
    if isinstance(rc_out, tuple):
        rc_out = tuple([torch.flip(r, dims=[1, 2]) for r in rc_out])
        for f, r in zip(out, rc_out):
            assert f.size() == (batch_size, seq_len, d_model * 2)
            assert r.size() == (batch_size, seq_len, d_model * 2)
            assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)
    else:
        # Hidden state size should double
        assert tuple(out.size()) == (batch_size, seq_len, d_model * 2)
        assert tuple(rc_out.size()) == (batch_size, seq_len, d_model * 2)
        assert torch.allclose(out.detach(), torch.flip(rc_out.detach(), dims=[1, 2]), rtol=rtol, atol=atol)


@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("seq_len", [1024, 2048])
@pytest.mark.parametrize("n_layer", [1, 3, 4])
@pytest.mark.parametrize("d_model", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("bidirectional", [False, True])
@pytest.mark.parametrize("bidirectional_weight_tie", [False, True])
def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2

    # Set seed
    torch.random.manual_seed(0)

    # Define complement map
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
        for k, v in str_to_id.items()
    }

    # Setup CaduceusConfig
    initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
    ssm_cfg = {
        "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
        "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
    }
    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg=ssm_cfg,
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg=initializer_cfg,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    factory_kwargs = {"device": device, "dtype": dtype}

    # Instantiate model
    mamba_lm = CaduceusForMaskedLM(
        config=config,
        **factory_kwargs,
    ).to(device)

    # Generate random sequences
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)

    # Test RC equivariance of the masked LM logits
    out = mamba_lm(input_ids)
    rc_out = mamba_lm(rc_input_ids)
    if config.vocab_size % config.pad_vocab_size_multiple != 0:
        config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
    assert tuple(out.logits.size()) == (batch_size, seq_len, config.vocab_size)
    assert tuple(rc_out.logits.size()) == (batch_size, seq_len, config.vocab_size)
    assert torch.allclose(
        out.logits.detach(),
        torch.flip(rc_out.logits.detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )
    assert torch.allclose(
        F.softmax(out.logits, dim=-1).detach(),
        torch.flip(F.softmax(rc_out.logits, dim=-1).detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),
        rtol=rtol,
        atol=atol
    )


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [1024])
@pytest.mark.parametrize("n_layer", [2])
@pytest.mark.parametrize("d_model", [128])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("bidirectional_weight_tie", [True])
def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):
    # Set tolerance
    device = torch.device("cuda")
    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2

    # Set seed
    torch.random.manual_seed(0)

    # Define complement map
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v
        for k, v in str_to_id.items()
    }

    # Setup CaduceusConfig
    initializer_cfg = {"initializer_range": 0.02, "rescale_prenorm_residual": True, "n_residuals_per_layer": 1}
    ssm_cfg = {
        "d_state": 16, "d_conv": 4, "expand": 2, "dt_rank": "auto", "dt_min": 0.001, "dt_max": 0.1, "dt_init": "random",
        "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, "bias": False, "use_fast_path": True
    }
    config = CaduceusConfig(
        d_model=d_model,
        n_layer=n_layer,
        vocab_size=12,
        ssm_cfg=ssm_cfg,
        rms_norm=True,
        residual_in_fp32=False,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        norm_epsilon=1e-5,
        initializer_cfg=initializer_cfg,
        bidirectional=bidirectional,
        bidirectional_strategy="add",
        bidirectional_weight_tie=bidirectional_weight_tie,
        rcps=True,
        complement_map=complement_map,
    )
    factory_kwargs = {"device": device, "dtype": dtype}

    # Instantiate model
    backbone = CaduceusMixerModel(
        config,
        **factory_kwargs,
    ).to(device)

    # Generate random sequences
    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)
    rc_input_ids = torch.flip(input_ids, dims=[-1]).to("cpu").apply_(lambda t: complement_map[t]).to(device)

    # Test RC Invariance when collapsing output of backbone
    out = backbone(input_ids)[0]
    out_collapse = (out[..., :d_model] + torch.flip(out[..., d_model:], dims=[1, 2])) / 2
    rc_out = backbone(rc_input_ids)[0]
    rc_out_collapse = (rc_out[..., :d_model] + torch.flip(rc_out[..., d_model:], dims=[1, 2])) / 2
    # Hidden state size should be d_model
    assert tuple(out_collapse.size()) == (batch_size, seq_len, d_model)
    assert tuple(rc_out_collapse.size()) == (batch_size, seq_len, d_model)
    assert torch.allclose(out_collapse.detach(), rc_out_collapse.detach(), rtol=rtol, atol=atol)
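

# Illustrative sketch, not part of the original suite: a dependency-free check of the
# reverse-complement operation on token ids, which the tests above build with
# `torch.flip(...).apply_(...)`. `_rc_ids` is a hypothetical helper name; the vocabulary is
# the same toy `str_to_id` mapping used throughout this file.
def _rc_ids(ids, complement_map):
    """Reverse a list of token ids and replace each id by its complement."""
    return [complement_map[i] for i in reversed(ids)]


def test_rc_ids_sketch():
    str_to_id = {"[CLS]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}
    complement = {"A": "T", "C": "G", "G": "C", "T": "A"}
    complement_map = {
        v: str_to_id[complement[k]] if k in complement else v
        for k, v in str_to_id.items()
    }
    ids = [str_to_id[ch] for ch in "AACGTN"]
    # Applying the reverse complement twice recovers the input
    assert _rc_ids(_rc_ids(ids, complement_map), complement_map) == ids
    # "ACGT" is its own reverse complement
    palindrome = [str_to_id[ch] for ch in "ACGT"]
    assert _rc_ids(palindrome, complement_map) == palindrome
    # Tokens without a complement, such as [MASK] and N, map to themselves
    assert complement_map[str_to_id["[MASK]"]] == str_to_id["[MASK]"]
    assert complement_map[str_to_id["N"]] == str_to_id["N"]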


================================================
FILE: caduceus/tokenization_caduceus.py
================================================
"""Character tokenizer for Hugging Face.

"""

from typing import List, Optional, Dict, Sequence, Tuple

from transformers import PreTrainedTokenizer


class CaduceusTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids"]

    def __init__(self,
                 model_max_length: int,
                 characters: Sequence[str] = ("A", "C", "G", "T", "N"),
                 complement_map=None,
                 bos_token="[BOS]",
                 eos_token="[SEP]",
                 sep_token="[SEP]",
                 cls_token="[CLS]",
                 pad_token="[PAD]",
                 mask_token="[MASK]",
                 unk_token="[UNK]",
                 **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py
        Args:
            model_max_length (int): Model maximum sequence length.
            characters (Sequence[str]): Characters to include in the vocabulary. Any character
                not in this list is mapped to the special token [UNK] (id=6). The special
                tokens and their ids are:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                Each character is then assigned an id starting at 7, in the order given.
            complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character.
        """
        if complement_map is None:
            complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
        self.characters = characters
        self.model_max_length = model_max_length

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "left")

        self._complement_map = {}
        for k, v in self._vocab_str_to_int.items():
            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
            self._complement_map[self._vocab_str_to_int[k]] = complement_id

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    @property
    def complement_map(self) -> Dict[int, int]:
        return self._complement_map

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        return list(text.upper())  # Convert all base pairs to uppercase

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)  # Note: this operation has lost info about which base pairs were originally lowercase

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        # cls = [self.cls_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_vocab(self) -> Dict[str, int]:
        return self._vocab_str_to_int

    # Fixed vocabulary with no vocab file
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
        return ()


================================================
FILE: caduceus_env.yml
================================================
name: caduceus_env
channels:
  - pytorch
  - anaconda
  - nvidia
  - defaults
dependencies:
  - cuda-nvcc=11.7.99
  - pip=23.3.1
  - python=3.8
  - pytorch=2.2.0
  - torchaudio=2.2.0
  - torchdata=0.7.1
  - torchmetrics=1.2.1
  - torchtext=0.17.0
  - torchvision=0.17.0
  - pytorch-cuda=12.1
  - pip:
      - biopython==1.81
      - datasets==2.15.0
      - einops==0.7.0
      - enformer-pytorch==0.8.8
      - fsspec==2023.10.0
      - genomic-benchmarks==0.0.9
      - git-lfs==1.6
      - h5py==3.10.0
      - huggingface-hub==0.24.7
      - hydra-core==1.3.2
      - ipdb==0.13.13
      - matplotlib==3.7.4
      - notebook==7.1.1
      - nvitop==1.3.2
      - omegaconf==2.3.0
      - pandas==2.0.3
      - pyfaidx==0.8.1.1
      - pysam==0.22.0
      - pytest==8.0.2
      - pytorch-lightning==1.8.6
      - rich==13.7.0
      - seaborn==0.13.2
      - scikit-learn==1.3.2
      - timm==0.9.16
      - tqdm==4.66.1
      - transformers==4.38.1
      - triton==2.2.0
      - wandb==0.13.5
      - flash-attn==2.5.6
      - causal-conv1d==1.2.0.post2
      - mamba-ssm==1.2.0.post1


================================================
FILE: configs/callbacks/base.yaml
================================================
learning_rate_monitor:
  # _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: ${train.interval}

timer:
  # _target_: callbacks.timer.Timer
  step: True
  inter_step: False
  epoch: True
  val: True

params:
  # _target_: callbacks.params.ParamsLog
  total: True
  trainable: True
  fixed: True


================================================
FILE: configs/callbacks/checkpoint.yaml
================================================
model_checkpoint:
  monitor: ${train.monitor} # name of the logged metric which determines when model is improving
  mode: ${train.mode} # can be "max" or "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: False # True = additionally always save model from last epoch
  dirpath: "checkpoints/"
  filename: ${train.monitor}
  auto_insert_metric_name: False
  verbose: True

model_checkpoint_every_n_steps:
  monitor: train/loss # name of the logged metric which determines when model is improving
  mode: min # can be "max" or "min"
  save_top_k: 0 # Do not save any "best" models; this callback is being used to save every n train steps
  save_last: True # additionally always save model from last epoch
  dirpath: "checkpoints/"
  filename: train/loss
  auto_insert_metric_name: False
  verbose: True
  every_n_train_steps: 100

#model_checkpoint_every_epoch:
#  monitor: trainer/epoch  # name of the logged metric which determines when model is improving
#  mode: max # can be "max" or "min"
#  save_top_k: 1 # Do not save any "best" models; this callback is being used to save every n train steps
#  save_last: False # additionally always save model from last epoch
#  dirpath: "checkpoints/"
#  filename: null
#  auto_insert_metric_name: False
#  verbose: True
#  every_n_epochs: 1


================================================
FILE: configs/callbacks/gpu_affinity.yaml
================================================
gpu_affinity:
  _name_: gpu_affinity


================================================
FILE: configs/callbacks/rich.yaml
================================================
rich_model_summary:
  max_depth: 2

rich_progress_bar:
  refresh_rate_per_second: 1.0


================================================
FILE: configs/callbacks/val_every_n_global_steps.yaml
================================================
val_every_n_global_steps:
  every_n: 10000


================================================
FILE: configs/callbacks/wandb.yaml
================================================
defaults:
  - default

watch_model:
  _target_: src.callbacks.wandb_callbacks.WatchModel
  log: "all"
  log_freq: 100

upload_code_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
  code_dir: ${work_dir}/src

upload_ckpts_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
  ckpt_dir: "checkpoints/"
  upload_best_only: True

log_f1_precision_recall_heatmap:
  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap

log_confusion_matrix:
  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix

log_image_predictions:
  _target_: src.callbacks.wandb_callbacks.LogImagePredictions
  num_samples: 8


================================================
FILE: configs/config.yaml
================================================
# @package _global_
defaults:
  - _self_
  - experiment: ???
  # - model: ???  # Model backbone
  # - pipeline: ???  # Specifies collection of configs, equivalent to next 5 lines
  # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)
  # # - loader: default # Dataloader (e.g. handles batches)
  # # - dataset: cifar # Defines the data (x and y pairs)
  # # - task: multiclass_classification # Defines loss and metrics
  # # - encoder: null # Interface between data and model
  # # - decoder: null # Interface between model and targets

# Additional arguments used to configure the training loop
# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer
train:
  seed: 0
  # These three options are used by callbacks (checkpoint, monitor) and scheduler
  # Most of them are task dependent and are set by the pipeline
  interval: ??? # Should be specified by scheduler. Also used by LR monitor
  monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
  mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer
  ema: 0.0 # Moving average model for validation
  test: True # Test after training
  debug: False # Special settings to make debugging more convenient
  ignore_warnings: False # Disable python warnings

  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False

  # These control state passing between batches
  state:
    mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
    n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context
    n_context_eval: ${.n_context} # Context at evaluation time
  # Convenience keys to allow grouping runs

  ckpt: checkpoints/last.ckpt # Resume training

  disable_dataset: False # Disable dataset loading
  validate_at_start: false

  pretrained_model_path: null # Path to pretrained model
  pretrained_model_strict_load: true # If true, require the checkpoint's state_dict keys to match the model exactly; set false to allow partial loading
  pretrained_model_state_hook: # Hook called on the loaded model's state_dict
    _name_: null
  post_init_hook: # After initializing model, call method on model
    _name_: null

  layer_decay: # Used for ImageNet finetuning
    _name_: null
    decay: 0.7

# We primarily use wandb so this is moved to top level in the config for convenience
# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging
# If other loggers are added, it would make sense to put this one level lower under train/ or logger/
wandb:
  project: dna
  group: ""
  job_type: training
  mode: online # choices=['online', 'offline', 'disabled']
  name: null
  save_dir: "."
  id: ${.name} # pass correct id to resume experiment!
  # Below options should not need to be specified
  # entity: ""  # set to name of your wandb team or just remove it
  # log_model: False
  # prefix: ""
  # job_type: "train"
  # tags: []

hydra:
  run:
    dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}
  job:
    chdir: true


================================================
FILE: configs/dataset/genomic_benchmark.yaml
================================================
_name_: genomic_benchmark
train_val_split_seed: ${train.seed}  # Used for train/validation splitting
dataset_name: dummy_mouse_enhancers_ensembl
dest_path: null
max_length: ${.${.dataset_name}.max_length}
max_length_val: ${.max_length}
max_length_test: ${.max_length}
d_output: ${.${.dataset_name}.classes}
use_padding: True
padding_side: 'left'
add_eos: False
batch_size: 128
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true  # set this as default!
# these are used to find the right attributes automatically for each dataset
dummy_mouse_enhancers_ensembl:
  train_len: 1210
  classes: 2
  max_length: 1024
demo_coding_vs_intergenomic_seqs:
  train_len: 100_000
  classes: 2
  max_length: 200
demo_human_or_worm:
  train_len: 100_000
  classes: 2
  max_length: 200
human_enhancers_cohn:
  train_len: 27791
  classes: 2
  max_length: 500
human_enhancers_ensembl:
  train_len: 154842
  classes: 2
  max_length: 512
human_ensembl_regulatory:
  train_len: 289061
  classes: 3
  max_length: 512
human_nontata_promoters:
  train_len: 36131
  classes: 2
  max_length: 251
human_ocr_ensembl:
  train_len: 174756
  classes: 2
  max_length: 512

# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name                                num_seqs        num_classes     median len    std
# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4  
# demo_coding_vs_intergenomic_seqs    100_000         2               200           0
# demo_human_or_worm                  100_000         2               200           0
# human_enhancers_cohn                27791           2               500           0
# human_enhancers_ensembl             154842          2               269           122.6
# human_ensembl_regulatory            289061          3               401           184.3
# human_nontata_promoters             36131           2               251           0
# human_ocr_ensembl                   174756          2               315           108.1


================================================
FILE: configs/dataset/hg38.yaml
================================================
_name_: hg38
bed_file: null
fasta_file: null
dataset_name: hg38
tokenizer_name: null
cache_dir: null
max_length: 1024
add_eos: True
batch_size: 8  # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 4  # For preprocessing only
shuffle: True
__train_len: 34021
__l_max: ${.max_length}


================================================
FILE: configs/dataset/nucleotide_transformer.yaml
================================================
_name_: nucleotide_transformer  # this links to the overall SequenceDataset of all nucleotide transformer datasets
train_val_split_seed: ${train.seed}  # Used for train/validation splitting
dataset_name: enhancers  # this specifies which dataset in nuc trx
dest_path: null  # path to overall nuc trx datasets
max_length: ${.${.dataset_name}.max_length}
d_output: ${.${.dataset_name}.classes} 
use_padding: True
padding_side: left
add_eos: False
batch_size: 256
train_len: ${.${.dataset_name}.train_len}
__l_max: ${.max_length}
shuffle: true  # set this as default!
metric: ${.${.dataset_name}.metric}
# these are used to find the right attributes automatically for each dataset
enhancers:
  train_len: 14968
  classes: 2
  max_length: 200
  metric: mcc
enhancers_types:
  train_len: 14968
  classes: 3
  max_length: 200
  metric: mcc
H3:
  train_len: 13468
  classes: 2
  max_length: 500
  metric: mcc
H3K4me1:
  train_len: 28509
  classes: 2
  max_length: 500
  metric: mcc
H3K4me2:
  train_len: 27614
  classes: 2
  max_length: 500
  metric: mcc
H3K4me3:
  train_len: 33119
  classes: 2
  max_length: 500
  metric: mcc
H3K9ac:
  train_len: 25003
  classes: 2
  max_length: 500
  metric: mcc
H3K14ac:
  train_len: 29743
  classes: 2
  max_length: 500
  metric: mcc
H3K36me3:
  train_len: 31392
  classes: 2
  max_length: 500
  metric: mcc
H3K79me3:
  train_len: 25953
  classes: 2
  max_length: 500
  metric: mcc
H4:
  train_len: 13140
  classes: 2
  max_length: 500
  metric: mcc
H4ac:
  train_len: 30685
  classes: 2
  max_length: 500
  metric: mcc
promoter_all:
  train_len: 53276
  classes: 2
  max_length: 300
  metric: f1_binary
promoter_no_tata:
  train_len: 47759
  classes: 2
  max_length: 300
  metric: f1_binary
promoter_tata:
  train_len: 5517
  classes: 2
  max_length: 300
  metric: f1_binary
splice_sites_acceptors:
  train_len: 19961
  classes: 2
  max_length: 600
  metric: f1_binary
splice_sites_all:
  train_len: 27000
  classes: 3
  max_length: 400
  metric: accuracy
splice_sites_donors:
  train_len: 19775
  classes: 2
  max_length: 600
  metric: f1_binary

# name                    maxlen  classes  samples  metric
# enhancers               200     2        14968    MCC
# enhancers_types         200     3        14968    MCC
# H3                      500     2        13468    MCC
# H3K4me1                 500     2        28509    MCC
# H3K4me2                 500     2        27614    MCC
# H3K4me3                 500     2        33119    MCC
# H3K9ac                  500     2        25003    MCC
# H3K14ac                 500     2        29743    MCC
# H3K36me3                500     2        31392    MCC
# H3K79me3                500     2        25953    MCC
# H4                      500     2        13140    MCC
# H4ac                    500     2        30685    MCC
# promoter_all            300     2        53276    F1
# promoter_no_tata        300     2        47759    F1
# promoter_tata           300     2        5517     F1
# splice_sites_acceptors  600     2        19961    F1
# splice_sites_all        400     3        27000    accuracy
# splice_sites_donors     600     2        19775    F1


================================================
FILE: configs/experiment/hg38/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
  - /pipeline: genomic_benchmark
  - /model: ???
  - override /scheduler: cosine_warmup_timm

# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name                                num_seqs        num_classes     median len    std
# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4  
# demo_coding_vs_intergenomic_seqs    100_000         2               200           0
# demo_human_or_worm                  100_000         2               200           0
# human_enhancers_cohn                27791           2               500           0
# human_enhancers_ensembl             154842          2               269           122.6
# human_ensembl_regulatory            289061          3               401           184.3
# human_nontata_promoters             36131           2               251           0
# human_ocr_ensembl                   174756          2               315           108.1

task:
  loss:
    _name_: cross_entropy

trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
  max_epochs: 100
  precision: 16  # bf16 only a100
  gradient_clip_val: 1.0

model:
  _name_: dna_embedding

dataset:
  # optional, default is max_length
  tokenizer_name: char
  rc_aug: false  # reverse complement augmentation

scheduler:
# COSINE TIMM
  t_in_epochs: False
  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
  warmup_lr_init: 1e-6
  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
  lr_min: ${eval:0.1 * ${optimizer.lr}}


optimizer:
  lr: 6e-4
  weight_decay: 0.1

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  seed: 2222
  global_batch_size: ${dataset.batch_size}
  cross_validation: true
  remove_test_loader_in_eval: true  # test only at the end of training
  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it
  # for loading backbone and not head, requires both of these flags below
  pretrained_model_path: ???
  pretrained_model_state_hook:
    _name_: load_backbone
    freeze_backbone: false
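
As a worked example, the interpolated trainer and scheduler values above can be resolved by hand. This sketch assumes that `div_up` is a ceiling-division resolver and that `eval` evaluates the arithmetic expression; both are custom OmegaConf resolvers that this config relies on but that are not defined in this file. The inputs are the `dummy_mouse_enhancers_ensembl` defaults from `configs/dataset/genomic_benchmark.yaml`.

```python
import math


def div_up(x: int, y: int) -> int:
    """Ceiling division; assumed behaviour of the repo's `div_up` resolver."""
    return math.ceil(x / y)


# Values copied from the configs in this dump.
train_len = 1210                 # dataset.train_len for dummy_mouse_enhancers_ensembl
batch_size = 128                 # dataset.batch_size
global_batch_size = batch_size   # train.global_batch_size: ${dataset.batch_size}
devices, num_nodes = 1, 1        # trainer.devices, trainer.num_nodes
max_epochs = 100                 # trainer.max_epochs
lr = 6e-4                        # optimizer.lr

# trainer.accumulate_grad_batches
accumulate_grad_batches = div_up(global_batch_size, devices * batch_size * num_nodes)

# scheduler.t_initial, scheduler.warmup_t, scheduler.lr_min
steps_per_epoch = div_up(train_len, global_batch_size)
t_initial = steps_per_epoch * max_epochs
warmup_t = steps_per_epoch * max_epochs * 0.01
lr_min = 0.1 * lr

print(accumulate_grad_batches, steps_per_epoch, t_initial, warmup_t, lr_min)
```

Under these assumptions the schedule spans every optimizer step of the run, with warmup taking the first 1% of those steps and the learning rate decaying to a tenth of its peak.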


================================================
FILE: configs/experiment/hg38/genomic_benchmark_cnn.yaml
================================================
# @package _global_
defaults:
  - /model: genomics_benchmark_cnn
  - /pipeline: genomic_benchmark
  - override /scheduler: cosine_warmup_timm

# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings
# name                                num_seqs        num_classes     median len    std
# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4
# demo_coding_vs_intergenomic_seqs    100_000         2               200           0
# demo_human_or_worm                  100_000         2               200           0
# human_enhancers_cohn                27791           2               500           0
# human_enhancers_ensembl             154842          2               269           122.6
# human_ensembl_regulatory            289061          3               401           184.3
# human_nontata_promoters             36131           2               251           0
# human_ocr_ensembl                   174756          2               315           108.1

task:
  loss:
    _name_: cross_entropy

trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
  max_epochs: 100
  precision: 16  # bf16 only a100
  gradient_clip_val: 1.0

encoder: id
decoder: id

dataset:
  tokenizer_name: char
  rc_aug: false  # reverse complement augmentation

scheduler:
  t_in_epochs: False
  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
  warmup_lr_init: 1e-6
  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
  lr_min: ${eval:0.1 * ${optimizer.lr}}


optimizer:
  lr: 6e-4
  weight_decay: 0.1

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  seed: 2222
  global_batch_size: ${dataset.batch_size}
  cross_validation: true
  remove_test_loader_in_eval: true
  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it


================================================
FILE: configs/experiment/hg38/hg38.yaml
================================================
# @package _global_
defaults:
  - /pipeline: hg38
  - /model: ???  # Specify a model, e.g. model=mamba or model=hyena
  - override /scheduler: cosine_warmup_timm

task:
  _name_: lm
  loss:
    _name_: cross_entropy
    ignore_index: 4

trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
  max_epochs: null
  max_steps: 10000
  precision: 16  # bf16 only a100
  gradient_clip_val: 1.0
  limit_val_batches: 0.125

dataset:
  batch_size: ${eval:1024//${trainer.devices}}
  max_length: 1024
  # optional, default is max_length
  max_length_val: ${dataset.max_length}
  max_length_test: ${dataset.max_length}
  tokenizer_name: char
  pad_max_length: null  # needed for bpe tokenizer
  add_eos: true
  rc_aug: false
  num_workers: 12
  use_fixed_len_val: false  # placing a fixed length val here, but it's really the test
  mlm: false
  mlm_probability: 0.0

scheduler:
  t_in_epochs: False
  t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
  warmup_prefix: True
  warmup_lr_init: 1e-6
  warmup_t: ${eval:0.1*${trainer.max_steps}}
  lr_min: 1e-4

optimizer:
  lr: 6e-4
  weight_decay: 0.1
  betas: [0.9, 0.95]

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  seed: 2222
  global_batch_size: 256  # affects the scheduler, must be set correctly


================================================
FILE: configs/experiment/hg38/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
  - /pipeline: nucleotide_transformer
  - /model: ???
  - override /scheduler: cosine_warmup_timm

model:
  _name_: dna_embedding

trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}
  max_epochs: 100
  precision: 16  # bf16 only a100
  gradient_clip_val: 1.0

dataset:
  tokenizer_name: char
  rc_aug: false  # reverse complement augmentation

scheduler:
  t_in_epochs: False
  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}
  warmup_lr_init: 1e-6
  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}
  lr_min: ${eval:0.1 * ${optimizer.lr}}

optimizer:
  lr: 1e-3
  weight_decay: 0.1

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  seed: 2222
  global_batch_size: ${dataset.batch_size}
  cross_validation: true
  remove_test_loader_in_eval: true  # test only at the end of training
  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it
  # for loading backbone and not head, requires both of these flags below
  pretrained_model_path: ???
  pretrained_model_state_hook:
    _name_: load_backbone
    freeze_backbone: false


================================================
FILE: configs/loader/default.yaml
================================================
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
drop_last: True


================================================
FILE: configs/model/caduceus.yaml
================================================
# Use open-source version of Mamba
_name_: caduceus_lm
config:
  _target_: caduceus.configuration_caduceus.CaduceusConfig
  # From original MambaConfig
  d_model: 128
  n_layer: 2
  vocab_size: 12
  ssm_cfg:
    d_state: 16
    d_conv: 4
    expand: 2
    dt_rank: "auto"
    dt_min: 0.001
    dt_max: 0.1
    dt_init: "random"
    dt_scale: 1.0
    dt_init_floor: 1e-4
    conv_bias: true
    bias: false
    use_fast_path: true
  rms_norm: true
  fused_add_norm: true
  residual_in_fp32: false
  pad_vocab_size_multiple: 8
  # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
  norm_epsilon: 1e-5

  # Used in init_weights
  initializer_cfg:
    initializer_range: 0.02
    rescale_prenorm_residual: true
    n_residuals_per_layer: 1

  # Caduceus-specific params
  bidirectional: true
  bidirectional_strategy: "add"
  bidirectional_weight_tie: true
  rcps: false

  # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer)
  complement_map: null


================================================
FILE: configs/model/genomics_benchmark_cnn.yaml
================================================
# CNN baseline for the Genomic Benchmarks tasks
_name_: genomics_benchmark_cnn
number_of_classes: ${dataset.d_output}
vocab_size: 12
embedding_dim: 100  # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
input_len: ${dataset.__l_max}


================================================
FILE: configs/model/hyena.yaml
================================================
_name_: hyena_lm
d_model: 128
n_layer: 2
d_inner: ${eval:4 * ${.d_model}}
vocab_size: 12
resid_dropout: 0.0
embed_dropout: 0.1
fused_mlp: False
fused_dropout_add_ln: False
checkpoint_mixer: False  # set true for memory reduction
checkpoint_mlp: False  # set true for memory reduction
residual_in_fp32: True
pad_vocab_size_multiple: 8
layer:
  _name_: hyena
  emb_dim: 5
  filter_order: 64
  local_order: 3
  l_max: ${eval:${dataset.max_length}+2}
  modulate: True
  w: 10
  lr: ${optimizer.lr}
  wd: 0.0
  lr_pos_emb: 0.0


================================================
FILE: configs/model/layer/hyena.yaml
================================================
_name_: hyena
l_max: 1024
order: 2
filter_order: 64
num_heads: 1
inner_factor: 1
num_blocks: 1
fused_bias_fc: false
outer_mixing: false
dropout: 0.0 
filter_dropout: 0.0
filter_cls: 'hyena-filter'
post_order_ffn: false
jit_filter: false
short_filter_order: 3
activation: "id"

================================================
FILE: configs/model/mamba.yaml
================================================
# Use open-source version of Mamba
_name_: mamba_lm
config:
  _target_: mamba_ssm.models.config_mamba.MambaConfig
  d_model: 128  # Will be overwritten by CL in the scaling exps
  n_layer: 2  # Will be overwritten by CL in the scaling exps
  vocab_size: 12
  pad_vocab_size_multiple: 8
  rms_norm: true
  fused_add_norm: true
  residual_in_fp32: false
  ssm_cfg:
    d_state: 16
    d_conv: 4
    expand: 2
    dt_rank: "auto"
    dt_min: 0.001
    dt_max: 0.1
    dt_init: "random"
    dt_scale: 1.0
    dt_init_floor: 1e-4
    conv_bias: true
    bias: false
    use_fast_path: true
initializer_cfg:
  initializer_range: 0.02
  rescale_prenorm_residual: true
  n_residuals_per_layer: 1
#norm_epsilon: 1e-5  # Default arg in mamba create_block


================================================
FILE: configs/optimizer/adam.yaml
================================================
# _target_: torch.optim.Adam
_name_: adam
lr: 0.001  # Initial learning rate
# weight_decay: 0.0  # Weight decay for adam|lamb; should use AdamW instead if desired
betas: [0.9, 0.999]


================================================
FILE: configs/optimizer/adamw.yaml
================================================
# _target_: torch.optim.AdamW
_name_: adamw
lr: 0.001 # Initial learning rate
weight_decay: 0.00 # Weight decay
betas: [0.9, 0.999]


================================================
FILE: configs/optimizer/sgd.yaml
================================================
# _target_: torch.optim.SGD
_name_: sgd
lr: 0.001  # Initial learning rate
momentum: 0.9
weight_decay: 0.0  # Weight decay


================================================
FILE: configs/pipeline/genomic_benchmark.yaml
================================================
# @package _global_
defaults:
  - /trainer: default
  - /loader: default
  - /dataset: genomic_benchmark
  - /task: multiclass_classification
  - /optimizer: adamw
  - /scheduler: plateau
  - /callbacks: [base, checkpoint]

train:
  monitor: val/accuracy # Needed for plateau scheduler
  mode: max

encoder: id

# we need this for classification!
decoder:
  _name_: sequence
  mode: pool


================================================
FILE: configs/pipeline/hg38.yaml
================================================
# @package _global_
defaults:
  - /trainer: default
  - /loader: null
  - /dataset: hg38
  - /optimizer: adamw
  - /scheduler: cosine_warmup
  - /callbacks: [base, checkpoint]

train:
  monitor: test/loss
  mode: min

task:
  _name_: lm
  loss:
    _name_: cross_entropy
    ignore_index: 4  # Bake in tokenizer value for padding / EOS tokens
  torchmetrics: ['perplexity', 'num_tokens']

encoder: null
decoder: null

loader:
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True
  drop_last: True  # There's enough data and epochs, ignore the edge case
  # shuffle: True


================================================
FILE: configs/pipeline/nucleotide_transformer.yaml
================================================
# @package _global_
defaults:
  - /trainer: default
  - /loader: default
  - /dataset: nucleotide_transformer
  - /task: multiclass_classification
  - /optimizer: adamw
  - /scheduler: plateau
  - /callbacks: [base, checkpoint]

task:
  loss:
    _name_: cross_entropy
  metrics:
    - ${dataset.metric}

train:
  monitor: val/${dataset.metric}
  mode: max

encoder: id

# we need this for classification!
decoder:
  _name_: sequence
  mode: pool

================================================
FILE: configs/scheduler/constant.yaml
================================================
# @package _global_
train:
  interval: epoch
scheduler:
  # _target_: transformers.get_constant_schedule
  _name_: constant


================================================
FILE: configs/scheduler/constant_warmup.yaml
================================================
# @package _global_
train:
  interval: step
scheduler:
  # _target_: transformers.get_constant_schedule_with_warmup
  _name_: constant_warmup
  num_warmup_steps: 1000  # Number of iterations for LR warmup


================================================
FILE: configs/scheduler/cosine.yaml
================================================
# @package _global_
train:
  interval: epoch
scheduler:
  # _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  _name_: cosine
  T_max: 100  # Max number of epochs for the LR scheduler
  eta_min: 1e-6  # Min learning rate for cosine scheduler


================================================
FILE: configs/scheduler/cosine_warmup.yaml
================================================
# @package _global_
train:
  interval: step
scheduler:
  # _target_: transformers.get_cosine_schedule_with_warmup
  _name_: cosine_warmup
  num_warmup_steps: 1000
  num_training_steps: 40000


================================================
FILE: configs/scheduler/cosine_warmup_timm.yaml
================================================
# @package _global_
train:
  interval: step
scheduler:
  # _target_: transformers.get_cosine_schedule_with_warmup
  _name_: cosine_warmup_timm
  t_in_epochs: False
  t_initial: 300
  lr_min: 1e-5
  warmup_lr_init: 1e-6
  warmup_t: 10


================================================
FILE: configs/scheduler/linear_warmup.yaml
================================================
# @package _global_
train:
  interval: step
scheduler:
  # _target_: transformers.get_linear_schedule_with_warmup
  _name_: linear_warmup
  num_warmup_steps: 1000
  num_training_steps: 40000


================================================
FILE: configs/scheduler/multistep.yaml
================================================
# @package _global_
train:
  interval: epoch
# _target_: torch.optim.lr_scheduler.MultiStepLR
scheduler:
  _name_: multistep
  milestones: [80,140,180]
  gamma: 0.2


================================================
FILE: configs/scheduler/plateau.yaml
================================================
# @package _global_
train:
  interval: epoch
  monitor: ??? # must be specified
scheduler:
  # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _name_: plateau
  mode: ${train.mode} # "min" or "max": whether the monitored metric should decrease or increase
  factor: 0.2  # Decay factor when ReduceLROnPlateau is used
  patience: 20
  min_lr: 0.0  # Minimum learning rate during annealing


================================================
FILE: configs/scheduler/step.yaml
================================================
# @package _global_
train:
  interval: step
scheduler:
  # _target_: torch.optim.lr_scheduler.StepLR
  _name_: step
  step_size: 1
  gamma: 0.99


================================================
FILE: configs/task/lm.yaml
================================================
_name_: lm
# loss: cross_entropy # Handled by task: cross entropy loss
metrics: ppl


================================================
FILE: configs/task/multiclass_classification.yaml
================================================
# _target_: tasks.tasks.MultiClass
_name_: multiclass
loss: cross_entropy
metrics:
  - accuracy
torchmetrics: null


================================================
FILE: configs/task/multilabel_classification.yaml
================================================
# _target_:
_name_: base
loss: binary_cross_entropy
metrics: null
torchmetrics:
  - MultilabelAUROC # AUROC
  - MultilabelAveragePrecision # Precision
#  - Recall # not supported in torchmetrics
#  - F1 # not supported in torchmetrics


================================================
FILE: configs/task/regression.yaml
================================================
# _target_: tasks.tasks.BaseTask
_name_: base
loss: mse
metrics: mse
torchmetrics: null


================================================
FILE: configs/trainer/debug.yaml
================================================
defaults:
  - default

gpus: 1
min_epochs: 1
max_epochs: 10

# prints
progress_bar_refresh_rate: null
weights_summary: full
profiler: null

# debugs
fast_dev_run: False
num_sanity_val_steps: 2
overfit_batches: 0
limit_train_batches: 0.1
limit_val_batches: 0.1
limit_test_batches: 0.1
track_grad_norm: -1
terminate_on_nan: False


================================================
FILE: configs/trainer/default.yaml
================================================
_target_: pytorch_lightning.Trainer

devices: 1
accelerator: gpu
accumulate_grad_batches: 1 # Gradient accumulation every n batches
max_epochs: 200
# accelerator: ddp # Automatically set if gpus > 1
gradient_clip_val: 0.0
log_every_n_steps: 10
limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run
limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run
num_sanity_val_steps: 2    # default value: 2; override to 0 to skip sanity checking


================================================
FILE: configs/trainer/full.yaml
================================================
_target_: pytorch_lightning.Trainer

# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_summary: "top"
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False


================================================
FILE: configs/trainer/lm.yaml
================================================
accumulate_grad_batches: 1
# accelerator: null # set to 'ddp' for distributed
# amp_backend: native # 'native' | 'apex'
gpus: 8
max_epochs: 50
gradient_clip_val: 0.0 # Gradient clipping
log_every_n_steps: 10
precision: 16
progress_bar_refresh_rate: 1
weights_summary: top # Set to 'full' to see every layer
track_grad_norm: -1 # Set to 2 to track norms of gradients
limit_train_batches: 1.0
limit_val_batches: 1.0
# We use the dataloader from Transformer-XL to ensure that adjacent minibatches
# come from adjacent segments of text.
# That dataloader therefore has to handle DDP itself, and we don't want PL to
# replace its sampler.
replace_sampler_ddp: False


================================================
FILE: setup_env.sh
================================================
#!/bin/bash

# Shell script to set environment variables when running code in this repository.
# Usage:
#     source setup_env.sh

# Activate conda env
# shellcheck source=${HOME}/.bashrc disable=SC1091
source "${CONDA_SHELL}"
if [ -z "${CONDA_PREFIX}" ]; then
  conda activate caduceus_env
elif [[ "${CONDA_PREFIX}" != *"/caduceus_env" ]]; then
  conda deactivate
  conda activate caduceus_env
fi

# Prepend root directory to PYTHONPATH (keeping any existing entries) to enable module imports
export PYTHONPATH="${PWD}${PYTHONPATH:+:${PYTHONPATH}}"


================================================
FILE: slurm_scripts/dump_vep_embeddings.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                      # Retrieve the user's login environment
#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)
#SBATCH --mem=100G                          # RAM
#SBATCH --gres=gpu:8                        # Number of GPUs
#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4                  # Number of CPU cores per task
#SBATCH -N 1                                # Number of nodes
#SBATCH --requeue                           # Requeue job if it fails
#SBATCH --job-name=vep_embed                # Job name
#SBATCH --output=../watch_folder/%x_%j.log  # Output file name
#SBATCH --open-mode=append                  # Do not overwrite logs

NUM_WORKERS=2
NUM_DEVICES=8

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh
export CUDA_LAUNCH_BLOCKING=1
export CUBLAS_WORKSPACE_CONFIG=:4096:8  # Needed for setting deterministic functions for reproducibility

#####################################################################################
# Choose from one of the following:

## Enformer
#seq_len=196608
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="EleutherAI/enformer-official-rough"
#name="enformer-seqlen=196k"
#rcps_flag="no-rcps"

## NTv2
#seq_len=12288  # 2048 (seq len) * 6 (kmers)
#bp_per_token=6
#embed_dump_batch_size=1
#model_name_or_path="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species"
#name="NTv2_downstream-seqlen=12k"
#rcps_flag="no-rcps"

## Hyena
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="LongSafari/hyenadna-medium-160k-seqlen-hf"
#name="hyena_downstream-seqlen=131k"
#rcps_flag="no-rcps"

## Caduceus-Ph
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ph_downstream-seqlen=131k"
#rcps_flag="no-rcps"

## Caduceus-PS
#seq_len=131072
#bp_per_token=1
#embed_dump_batch_size=1
#model_name_or_path="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
#name="caduceus-ps_downstream-seqlen=131k"
#rcps_flag="rcps"
#####################################################################################

torchrun \
    --standalone \
    --nnodes=1 \
    --nproc-per-node=${NUM_DEVICES} \
    vep_embeddings.py \
      --num_workers=${NUM_WORKERS} \
      --seq_len=${seq_len}  \
      --bp_per_token=${bp_per_token}  \
      --embed_dump_batch_size=${embed_dump_batch_size} \
      --name="${name}"  \
      --model_name_or_path="${model_name_or_path}" \
      --"${rcps_flag}"


================================================
FILE: slurm_scripts/run_genomics_benchmark.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                   # Retrieve the user's login environment
#SBATCH -t 96:00:00                      # Time limit (hh:mm:ss)
#SBATCH --mem=64000M                     # RAM
#SBATCH --gres=gpu:1                     # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1                             # Number of nodes
#SBATCH --requeue                        # Requeue job if it fails
#SBATCH --open-mode=append               # Do not overwrite logs

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh

# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG


# Run script
# shellcheck disable=SC2154
WANDB_NAME="${DISPLAY_NAME}_lr-${LR}_batch_size-${BATCH_SIZE}_rc_aug-${RC_AUG}"
for seed in $(seq 1 5); do
  # shellcheck disable=SC2154
  HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running GenomicsBenchmark model: ${DISPLAY_NAME}, task: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, rc_aug: ${RC_AUG}, SEED: ${seed}"
  # shellcheck disable=SC2086
  python -m train \
    experiment=hg38/genomic_benchmark \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    +dataset.conjoin_train=false \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    model="${MODEL}" \
    model._name_="${MODEL_NAME}" \
    +model.config_path="${CONFIG_PATH}" \
    +model.conjoin_test="${CONJOIN_TEST}" \
    +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
    +decoder.conjoin_test="${CONJOIN_TEST}" \
    optimizer.lr="${LR}" \
    trainer.max_epochs=10 \
    train.pretrained_model_path="${PRETRAINED_PATH}" \
    wandb.group="downstream/gb_cv5" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done


================================================
FILE: slurm_scripts/run_genomics_benchmark_cnn.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                   # Retrieve the user's login environment
#SBATCH -t 48:00:00                      # Time limit (hh:mm:ss)
#SBATCH --mem=64G                        # RAM
#SBATCH --gres=gpu:1                     # Number of GPUs
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
#SBATCH -N 1                             # Number of nodes
#SBATCH --requeue                        # Requeue job if it fails
#SBATCH --open-mode=append               # Do not overwrite logs

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh

# Expected args:
# - TASK
# - RC_AUG


# LR: 1e-3 -- in https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks, Adam optimizer is used with default lr=1e-3
LR="1e-3"
# Batch size: 64 -- See https://arxiv.org/abs/2306.15794 and https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
BATCH_SIZE=64

# Run script
WANDB_NAME="CNN-LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 5); do
  HYDRA_RUN_DIR="./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running GenomicsBenchmark TASK: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
  python -m train \
    experiment=hg38/genomic_benchmark_cnn \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    optimizer.lr="${LR}" \
    trainer.max_epochs=10 \
    wandb.group="downstream/gb_cv5" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done


================================================
FILE: slurm_scripts/run_nucleotide_transformer.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                   # Retrieve the user's login environment
#SBATCH -t 96:00:00                      # Time limit (hh:mm:ss)
#SBATCH --mem=64G                        # RAM
#SBATCH --gres=gpu:2                     # Number of GPUs
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=4
#SBATCH -N 1                             # Number of nodes
#SBATCH --requeue                        # Requeue job if it fails
#SBATCH --open-mode=append               # Do not overwrite logs

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1

# Expected args:
# - CONFIG_PATH
# - PRETRAINED_PATH
# - DISPLAY_NAME
# - MODEL
# - MODEL_NAME
# - CONJOIN_TRAIN_DECODER
# - CONJOIN_TEST
# - TASK
# - LR
# - BATCH_SIZE
# - RC_AUG

# Run script
WANDB_NAME="${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
for seed in $(seq 1 10); do
  HYDRA_RUN_DIR="./outputs/downstream/nt_cv10_ep20/${TASK}/${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}/seed-${seed}"
  mkdir -p "${HYDRA_RUN_DIR}"
  echo "*****************************************************"
  echo "Running NT model: ${DISPLAY_NAME}, TASK: ${TASK}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}"
  python -m train \
    experiment=hg38/nucleotide_transformer \
    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \
    dataset.dataset_name="${TASK}" \
    dataset.train_val_split_seed=${seed} \
    dataset.batch_size=${BATCH_SIZE} \
    dataset.rc_aug="${RC_AUG}" \
    +dataset.conjoin_test="${CONJOIN_TEST}" \
    model="${MODEL}" \
    model._name_="${MODEL_NAME}" \
    +model.config_path="${CONFIG_PATH}" \
    +model.conjoin_test="${CONJOIN_TEST}" \
    +decoder.conjoin_train="${CONJOIN_TRAIN_DECODER}" \
    +decoder.conjoin_test="${CONJOIN_TEST}" \
    optimizer.lr="${LR}" \
    train.pretrained_model_path="${PRETRAINED_PATH}" \
    trainer.max_epochs=20 \
    wandb.group="downstream/nt_cv10_ep20" \
    wandb.job_type="${TASK}" \
    wandb.name="${WANDB_NAME}" \
    wandb.id="nt_cv10_ep-20_${TASK}_${WANDB_NAME}_seed-${seed}" \
    +wandb.tags=\["seed-${seed}"\] \
    hydra.run.dir="${HYDRA_RUN_DIR}"
  echo "*****************************************************"
done


================================================
FILE: slurm_scripts/run_pretrain_caduceus.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                      # Retrieve the user's login environment
#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)
#SBATCH --mem=100G                          # RAM
#SBATCH --gres=gpu:8                        # Number of GPUs
#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)
##SBATCH --cpus-per-task=4                  # Number of CPU cores per task
#SBATCH -N 1                                # Number of nodes
#SBATCH --requeue                           # Requeue job if it fails
#SBATCH --job-name=caduceus_ps              # Job name
#SBATCH --output=../watch_folder/%x_%j.log  # Log file
#SBATCH --open-mode=append                  # Do not overwrite logs

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh
export HYDRA_FULL_ERROR=1

NUM_DEVICES=8

# Run script
SEQLEN=131072
MAX_STEPS=50000
D_MODEL=256
N_LAYER=8
LR="8e-3"
BIDIRECTIONAL_STRATEGY="add"
BIDIRECTIONAL_WEIGHT_TIE="true"
RCPS="true"
RC_AUG="false"

BATCH_SIZE=$(( 1048576 / SEQLEN ))
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="caduceus-ps_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"
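# Worked example of the arithmetic above, using the values set in this script:
#   BATCH_SIZE = 1048576 / 131072 = 8 sequences per global batch (2^20 sequence positions in total)
#   SEQLEN_DIS = "131k" (with scale=0, bc truncates 131072 / 1000 to 131)
#   per-device batch size passed below = BATCH_SIZE / NUM_DEVICES = 8 / 8 = 1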

mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
  dataset.mlm=true \
  dataset.mlm_probability=0.15 \
  dataset.rc_aug="${RC_AUG}" \
  model="caduceus" \
  model.config.d_model=${D_MODEL} \
  model.config.n_layer=${N_LAYER} \
  model.config.bidirectional=true \
  model.config.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \
  model.config.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \
  model.config.rcps=${RCPS} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"


================================================
FILE: slurm_scripts/run_pretrain_hyena.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                      # Retrieve the user's login environment
#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)
#SBATCH --mem=100G                          # RAM
#SBATCH --gres=gpu:8                        # Number of GPUs
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=4                   # Number of CPU cores per task
#SBATCH -N 1                                # Number of nodes
#SBATCH --requeue                           # Requeue job if it fails
#SBATCH --job-name=hyena                    # Job name
#SBATCH --output=../watch_folder/%x_%j.log  # Log file

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh

NUM_DEVICES=8

# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=4
LR="6e-4"
RC_AUG="true"

BATCH_SIZE=$(( 1048576 / SEQLEN ))
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="hyena_rc_aug_seqlen-${SEQLEN_DIS}_dmodel-${D_MODEL}_nlayer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"

mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
  dataset.mlm=false \
  dataset.mlm_probability=0.0 \
  dataset.rc_aug="${RC_AUG}" \
  model=hyena \
  model.d_model=${D_MODEL} \
  model.n_layer=${N_LAYER} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"


================================================
FILE: slurm_scripts/run_pretrain_mamba.sh
================================================
#!/bin/bash
#SBATCH --get-user-env                      # Retrieve the user's login environment
#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)
#SBATCH --mem=100G                          # RAM
#SBATCH --gres=gpu:8                        # Number of GPUs
#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)
#SBATCH --cpus-per-task=4                   # Number of CPU cores per task
#SBATCH -N 1                                # Number of nodes
#SBATCH --requeue                           # Requeue job if it fails
#SBATCH --job-name=mamba_ntp                # Job name
#SBATCH --output=../watch_folder/%x_%j.log  # Log file

# Setup environment
cd ../ || exit  # Go to the root directory of the repo
source setup_env.sh

NUM_DEVICES=8

# Run script
SEQLEN=1024
MAX_STEPS=10000
D_MODEL=256
N_LAYER=8
LR="8e-3"
RC_AUG="true"

BATCH_SIZE=$(( 1048576 / SEQLEN ))
SEQLEN_DIS="$(echo "scale=0; ${SEQLEN} / 1000" | bc)k"
WANDB_NAME="mamba_ntp_rc_aug_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}"
HYDRA_RUN_DIR="./outputs/pretrain/hg38/${WANDB_NAME}"

mkdir -p "${HYDRA_RUN_DIR}"
srun python -m train \
  experiment=hg38/hg38 \
  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \
  dataset.max_length=${SEQLEN} \
  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \
  dataset.mlm=false \
  dataset.mlm_probability=0.0 \
  dataset.rc_aug="${RC_AUG}" \
  model=mamba \
  model.config.d_model=${D_MODEL} \
  model.config.n_layer=${N_LAYER} \
  optimizer.lr="${LR}" \
  train.global_batch_size=${BATCH_SIZE} \
  trainer.max_steps=${MAX_STEPS} \
  trainer.devices=${NUM_DEVICES} \
  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \
  wandb.group=pretrain_hg38 \
  wandb.name="${WANDB_NAME}" \
  hydra.run.dir="${HYDRA_RUN_DIR}"


================================================
FILE: slurm_scripts/wrapper_run_genomics.sh
================================================
#!/bin/bash

# Choose one from below

## Hyena
## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen to ../outputs/hyena_hf/hyenadna-tiny-1k-seqlen
#LOG_DIR="../watch_folder/gb_cv5/hyena"
#CONFIG_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/config.json")
#PRETRAINED_PATH=$(realpath "../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/weights.ckpt")
#DISPLAY_NAME="hyena"
#MODEL="hyena"
#MODEL_NAME="dna_embedding"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "false" "true" )
#LRS=( "6e-4" )

## Mamba NTP
#LOG_DIR="../watch_folder/gb_cv5/mamba"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="mamba_uni"
#MODEL="mamba"
#MODEL_NAME="dna_embedding_mamba"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3" )

## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "2e-3")

## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/gb_cv5/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true"  # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

mkdir -p "${LOG_DIR}"
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 256; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Build a per-job export string so that variables do not accumulate across loop iterations
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="gb_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_genomics_benchmark.sh"
      done
    done
  done
done

================================================
FILE: slurm_scripts/wrapper_run_genomics_cnn.sh
================================================
#!/bin/bash

LOG_DIR="../watch_folder/gb_cv5/cnn_baseline"
mkdir -p "${LOG_DIR}"
export_str="ALL"
for TASK in "dummy_mouse_enhancers_ensembl" "demo_coding_vs_intergenomic_seqs" "demo_human_or_worm" "human_enhancers_cohn" "human_enhancers_ensembl" "human_ensembl_regulatory" "human_nontata_promoters" "human_ocr_ensembl"; do
  for RC_AUG in "false"; do
    # Build a per-job export string so that variables do not accumulate across loop iterations
    job_export_str="${export_str},TASK=${TASK},RC_AUG=${RC_AUG}"
    job_name="gb_${TASK}_CNN_RC_AUG-${RC_AUG}"
    sbatch \
      --job-name="${job_name}" \
      --output="${LOG_DIR}/%x_%j.log" \
      --export="${job_export_str}" \
      "run_genomics_benchmark_cnn.sh"
  done
done


================================================
FILE: slurm_scripts/wrapper_run_nucleotide_transformer.sh
================================================
#!/bin/bash

# Choose one from below

## Caduceus NO POST HOC
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_NO_PH"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="false"
#RC_AUGS=( "true" )
#LRS=( "1e-3" "2e-3")

## Caduceus Post-Hoc
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ph"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="false"
#CONJOIN_TEST="true"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

## Caduceus Parameter Sharing
#LOG_DIR="../watch_folder/nt_cv10_ep20/caduceus"
#CONFIG_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json")
#PRETRAINED_PATH=$(realpath "../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt")
#DISPLAY_NAME="caduceus_ps"
#MODEL="caduceus"
#MODEL_NAME="dna_embedding_caduceus"
#CONJOIN_TRAIN_DECODER="true"  # Use this in decoder to always combine forward and reverse complement channels
#CONJOIN_TEST="false"
#RC_AUGS=( "false" )
#LRS=( "1e-3" "2e-3" )

mkdir -p "${LOG_DIR}"
export_str="ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}"
for TASK in "enhancers" "enhancers_types" "H3" "H3K4me1" "H3K4me2" "H3K4me3" "H3K9ac" "H3K14ac" "H3K36me3" "H3K79me3" "H4" "H4ac" "promoter_all" "promoter_no_tata" "promoter_tata" "splice_sites_all" "splice_sites_acceptors" "splice_sites_donors"; do
  for LR in "${LRS[@]}"; do
    for BATCH_SIZE in 128 512; do
      for RC_AUG in "${RC_AUGS[@]}"; do
        # Build a per-job export string so that variables do not accumulate across loop iterations
        job_export_str="${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}"
        job_name="nt_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}"
        sbatch \
          --job-name="${job_name}" \
          --output="${LOG_DIR}/%x_%j.log" \
          --export="${job_export_str}" \
          "run_nucleotide_transformer.sh"
      done
    done
  done
done


================================================
FILE: src/__init__.py
================================================


================================================
FILE: src/callbacks/params.py
================================================
"""Callback to log the number of parameters of the model.

"""

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict


class ParamsLog(pl.Callback):
    """ Log the number of parameters of the model """
    def __init__(
        self,
        total: bool = True,
        trainable: bool = True,
        fixed: bool = True,
    ):
        super().__init__()
        self._log_stats = AttributeDict(
            {
                'total_params_log': total,
                'trainable_params_log': trainable,
                'non_trainable_params_log': fixed,
            }
        )

    @rank_zero_only
    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        logs = {}
        if self._log_stats.total_params_log:
            logs["params/total"] = sum(p.numel() for p in pl_module.parameters())
        if self._log_stats.trainable_params_log:
            logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters()
                                           if p.requires_grad)
        if self._log_stats.non_trainable_params_log:
            logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters()
                                       if not p.requires_grad)
        if trainer.logger:
            trainer.logger.log_hyperparams(logs)


================================================
FILE: src/callbacks/timer.py
================================================
"""Callback to monitor the speed of each step and each epoch.

https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py
Adapted from:
    https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
"""

# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any

from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Timer(Callback):
    """Monitor the speed of each step and each epoch.
    """
    def __init__(
        self,
        step: bool = True,
        inter_step: bool = True,
        epoch: bool = True,
        val: bool = True,
    ):
        super().__init__()
        self._log_stats = AttributeDict({
            'step_time': step,
            'inter_step_time': inter_step,
            'epoch_time': epoch,
            'val_time': val,
        })

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_epoch_time = None

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = time.time()

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.step_time:
            self._snap_step_time = time.time()

        if not self._should_log(trainer):
            return

        logs = {}
        if self._log_stats.inter_step_time and self._snap_inter_step_time:
            # First log at beginning of second step
            logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000

        if trainer.logger:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()

        if not self._should_log(trainer):
            return

        logs = {}
        if self._log_stats.step_time and self._snap_step_time:
            logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000

        if trainer.logger:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logs = {}
        if self._log_stats.epoch_time and self._snap_epoch_time:
            logs["timer/epoch"] = time.time() - self._snap_epoch_time
        if trainer.logger:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._snap_val_time = time.time()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logs = {}
        if self._log_stats.val_time and self._snap_val_time:
            logs["timer/validation"] = time.time() - self._snap_val_time
        if trainer.logger:
            trainer.logger.log_metrics(logs)  # , step=trainer.global_step)

    @staticmethod
    def _should_log(trainer) -> bool:
        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop


================================================
FILE: src/callbacks/validation.py
================================================
"""Check validation every n **global** steps.

PyTorch Lightning has a `val_check_interval` parameter that checks validation every n batches, but it does not support
checking every n **global** steps.
"""

from typing import Any

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage


class ValEveryNGlobalSteps(Callback):
    """Check validation every n **global** steps."""
    def __init__(self, every_n):
        self.every_n = every_n
        self.last_run = None

    def on_train_batch_end(self, trainer, *_: Any):
        """Check if we should run validation.

        Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529
        """
        # Prevent running validation multiple times during gradient accumulation,
        # when several consecutive batches share the same global step
        if trainer.global_step == self.last_run:
            return
        else:
            self.last_run = None
        if trainer.global_step % self.every_n == 0 and trainer.global_step != 0:
            trainer.training = False
            stage = trainer.state.stage
            trainer.state.stage = RunningStage.VALIDATING
            trainer._run_evaluate()
            trainer.state.stage = stage
            trainer.training = True
            trainer._logger_connector._epoch_end_reached = False
            self.last_run = trainer.global_step
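

# Illustrative sketch (not part of the original callback): a pure-Python simulation of the
# trigger logic in `on_train_batch_end` above. It assumes that `global_step` advances once
# every `accumulate` batches; the helper name and its arguments are hypothetical.
def _simulate_val_triggers(every_n, num_batches, accumulate=1):
    """Return the global steps at which the callback above would run validation."""
    last_run = None
    triggered = []
    for batch_idx in range(num_batches):
        # Assumption: the optimizer steps after every `accumulate` batches
        global_step = (batch_idx + 1) // accumulate
        if global_step == last_run:
            continue  # validation already ran for this global step
        last_run = None
        if global_step % every_n == 0 and global_step != 0:
            triggered.append(global_step)
            last_run = global_step
    return triggered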


================================================
FILE: src/dataloaders/__init__.py
================================================
from . import genomics
from .base import SequenceDataset


================================================
FILE: src/dataloaders/base.py
================================================
""" Datasets for core experimental results.

"""

import os
from functools import partial
from pathlib import Path

import torch


# Default data path is environment variable or <repo_root_dir>/data
if (default_data_path := os.getenv("DATA_PATH")) is None:
    default_data_path = Path(__file__).parent.parent.parent.absolute()
    default_data_path = default_data_path / "data"
else:
    default_data_path = Path(default_data_path).absolute()


class DefaultCollateMixin:
    """Controls collating in the DataLoader

    The CollateMixin classes instantiate a dataloader by separating collate arguments from the rest of the dataloader
    arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args
    list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments,
    constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor.
    """

    @classmethod
    def _collate_callback(cls, x, *args, **kwargs):
        """
        Modify the behavior of the default _collate method.
        """
        return x

    _collate_arg_names = []

    @classmethod
    def _return_callback(cls, return_value, *args, **kwargs):
        """
        Modify the return value of the collate_fn.
        Assign a name to each element of the returned tuple beyond the (x, y) pairs
        See InformerSequenceDataset for an example of this being used
        """
        x, y, *z = return_value
        assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
        return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)}

    @classmethod
    def _collate(cls, batch, *args, **kwargs):
        # From https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
        elem = batch[0]
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum(x.numel() for x in batch)
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            x = torch.stack(batch, dim=0, out=out)

            # Insert custom functionality into the collate_fn
            x = cls._collate_callback(x, *args, **kwargs)

            return x
        else:
            return torch.tensor(batch)

    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs):
        """
        Default collate function.
        Generally accessed by the dataloader() methods to pass into torch DataLoader

        Arguments:
            batch: list of (x, y) pairs
            args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback
        """
        x, y, *z = zip(*batch)

        x = cls._collate(x, *args, **kwargs)
        y = cls._collate(y)
        z = [cls._collate(z_) for z_ in z]

        return_value = (x, y, *z)
        return cls._return_callback(return_value, *args, **kwargs)

    # List of loader arguments to pass into collate_fn
    collate_args = []

    def _dataloader(self, dataset, **loader_args):
        collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args}
        loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}
        loader_cls = loader_registry[loader_args.pop("_name_", None)]
        return loader_cls(
            dataset=dataset,
            collate_fn=partial(self._collate_fn, **collate_args),
            **loader_args,
        )
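

# Illustrative sketch (not part of the original module): how `_dataloader` above partitions
# its keyword arguments between the collate function and the loader constructor.
# `_split_loader_args` is a hypothetical helper introduced only for this example.
def _split_loader_args(loader_args, collate_arg_names):
    """Split kwargs into (collate kwargs, remaining loader kwargs)."""
    collate = {k: v for k, v in loader_args.items() if k in collate_arg_names}
    rest = {k: v for k, v in loader_args.items() if k not in collate_arg_names}
    return collate, rest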


# class SequenceDataset(LightningDataModule):
# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just
# provide our own class with the same core methods as LightningDataModule (e.g. setup)
class SequenceDataset(DefaultCollateMixin):
    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    # Subclasses do not define __init__; it is handled by this class instead.
    # Subclasses can provide a dict of default arguments, which are automatically registered as attributes
    # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features
    #  of this class such as the _name_ and d_input/d_output
    @property
    def init_defaults(self):
        return {}

    # https://www.python.org/dev/peps/pep-0487/#subclass-registration
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        assert _name_ == self._name_
        self.data_dir = Path(data_dir).absolute() if data_dir is not None else None

        # Add all arguments to self
        init_args = self.init_defaults.copy()
        init_args.update(dataset_cfg)
        for k, v in init_args.items():
            setattr(self, k, v)

        # The train, val, test datasets must be set by `setup()`
        self.dataset_train = self.dataset_val = self.dataset_test = None

        self.init()

    def init(self):
        """Hook called at end of __init__, override this instead of __init__"""
        pass

    def setup(self):
        """This method should set self.dataset_train, self.dataset_val, and self.dataset_test."""
        raise NotImplementedError

    def split_train_val(self, val_split):
        """
        Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.
        """
        train_len = int(len(self.dataset_train) * (1.0 - val_split))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset_train,
            (train_len, len(self.dataset_train) - train_len),
            generator=torch.Generator().manual_seed(
                getattr(self, "seed", 42)
            ),  # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us
        )

    def train_dataloader(self, **kwargs):
        """Return a DataLoader for the training dataset."""
        return self._train_dataloader(self.dataset_train, **kwargs)

    def _train_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return
        kwargs['shuffle'] = 'sampler' not in kwargs  # shuffle cannot be True if we have a custom sampler
        return self._dataloader(dataset, **kwargs)

    def val_dataloader(self, **kwargs):
        """Return a DataLoader for the validation dataset."""
        return self._eval_dataloader(self.dataset_val, **kwargs)

    def test_dataloader(self, **kwargs):
        """Return a DataLoader for the test dataset."""
        return self._eval_dataloader(self.dataset_test, **kwargs)

    def _eval_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return
        # Note that shuffle=False by default
        return self._dataloader(dataset, **kwargs)

    def __str__(self):
        return self._name_


# Registry for dataloader class
loader_registry = {
    None: torch.utils.data.DataLoader,  # default case
}


================================================
FILE: src/dataloaders/datasets/genomic_bench_dataset.py
================================================
"""Genomic Benchmarks Dataset.

From: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks
"""

from pathlib import Path

import torch
from genomic_benchmarks.data_check import is_downloaded
from genomic_benchmarks.loc2seq import download_dataset

from src.dataloaders.utils.rc import coin_flip, string_reverse_complement


class GenomicBenchmarkDataset(torch.utils.data.Dataset):
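    """
    Load sequences from the Genomic Benchmarks directory layout: one sub-directory per class label under
    <dest_path>/<dataset_name>/<split>, with one text file per sequence.
    Each item is a tokenized sequence paired with its integer class label.
    """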

    def __init__(
            self,
            split,
            max_length,
            dataset_name="human_nontata_promoters",
            d_output=2,  # default binary classification
            dest_path=None,
            tokenizer=None,
            tokenizer_name=None,
            use_padding=None,
            add_eos=False,
            rc_aug=False,
            conjoin_train=False,
            conjoin_test=False,
            return_augs=False,
            return_mask=False,
    ):

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
        if (conjoin_train or conjoin_test) and rc_aug:
            print("When using conjoin, we turn off rc_aug.")
            rc_aug = False
        self.rc_aug = rc_aug
        self.conjoin_train = conjoin_train
        self.conjoin_test = conjoin_test
        self.return_mask = return_mask

        if not is_downloaded(dataset_name, cache_path=dest_path):
            print("downloading {} to {}".format(dataset_name, dest_path))
            download_dataset(dataset_name, version=0, dest_path=dest_path)
        else:
            print("already downloaded {}-{}".format(split, dataset_name))

        self.split = split

        # use Path object
        base_path = Path(dest_path) / dataset_name / split

        self.all_seqs = []
        self.all_labels = []
        label_mapper = {}

        for i, x in enumerate(base_path.iterdir()):
            label_mapper[x.stem] = i

        for label_type in label_mapper.keys():
            for path in (base_path / label_type).iterdir():
                with open(path, "r") as f:
                    content = f.read()
                self.all_seqs.append(content)
                self.all_labels.append(label_mapper[label_type])

    def __len__(self):
        return len(self.all_labels)

    def __getitem__(self, idx):
        x = self.all_seqs[idx]
        y = self.all_labels[idx]

        if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
            x = string_reverse_complement(x)

        seq = self.tokenizer(
            x,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )
        seq_ids = seq["input_ids"]  # get input_ids

        # need to handle eos here
        if self.add_eos:
            # append list seems to be faster than append tensor
            seq_ids.append(self.tokenizer.sep_token_id)

        if self.conjoin_train or (self.conjoin_test and self.split != "train"):
            x_rc = string_reverse_complement(x)
            seq_rc = self.tokenizer(
                x_rc,
                add_special_tokens=False,
                padding="max_length" if self.use_padding else None,
                max_length=self.max_length,
                truncation=True,
            )
            seq_rc_ids = seq_rc["input_ids"]  # get input_ids
            # need to handle eos here
            if self.add_eos:
                # append list seems to be faster than append tensor
                seq_rc_ids.append(self.tokenizer.sep_token_id)
            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)

        else:
            # convert to tensor
            seq_ids = torch.LongTensor(seq_ids)

        # need to wrap in list
        target = torch.LongTensor([y])

        # `seq_ids` has shape:
        #     - (seq_len,) if not conjoining
        #     - (seq_len, 2) for conjoining
        if self.return_mask:
            return seq_ids, target, {"mask": torch.BoolTensor(seq["attention_mask"])}
        else:
            return seq_ids, target
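

# Illustrative sketch (not part of the original dataset): the (seq_len, 2) layout produced by
# the `torch.stack(..., dim=1)` call above, shown on raw characters instead of token ids.
# `_conjoin_pairs` is hypothetical, and its reverse-complement logic is an assumption; the
# real `string_reverse_complement` lives in src/dataloaders/utils/rc.py and may treat
# non-ACGT characters differently.
def _conjoin_pairs(seq):
    """Pair each position of `seq` with the same position of its reverse complement."""
    complement = {"A": "T", "C": "G", "G": "C", "T": "A"}
    rc = "".join(complement.get(base, base) for base in reversed(seq))
    return list(zip(seq, rc))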


================================================
FILE: src/dataloaders/datasets/hg38_char_tokenizer.py
================================================
""" 
From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py

CharacterTokenizer for Hugging Face Transformers.
This is heavily inspired by CanineTokenizer in the transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following is the list of all the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)  # the vocab has no dedicated "[EOS]" entry
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        # TODO: This should be a parameter passed to __init__
        complement_map = {"A": "T", "C": "G", "G": "C", "T": "A"}
        self.complement_map = {}
        for k, v in self._vocab_str_to_int.items():
            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v
            self.complement_map[self._vocab_str_to_int[k]] = complement_id
        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,  # EOS reuses [SEP], matching the datasets that append `sep_token_id` as EOS
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None,
            already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def get_vocab(self) -> Dict[str, int]:
        return self._vocab_str_to_int

    def create_token_type_ids_from_sequences(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
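

# Illustrative sketch (not part of the original tokenizer): rebuilds, in plain Python, the id
# layout and the id-level complement map that `CharacterTokenizer.__init__` constructs above.
# `_build_vocab_and_complement` is a hypothetical helper introduced only for this example.
def _build_vocab_and_complement(characters):
    """Return (token -> id, id -> complement id) for the given character list."""
    special = ["[CLS]", "[SEP]", "[BOS]", "[MASK]", "[PAD]", "[RESERVED]", "[UNK]"]
    vocab = {tok: i for i, tok in enumerate(special)}
    # Character ids start right after the 7 special tokens
    vocab.update({ch: i + len(special) for i, ch in enumerate(characters)})
    base_complement = {"A": "T", "C": "G", "G": "C", "T": "A"}
    # Tokens without a complement (special tokens, "N", ...) map to themselves
    complement = {
        idx: vocab[base_complement[tok]] if tok in base_complement else idx
        for tok, idx in vocab.items()
    }
    return vocab, complement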


================================================
FILE: src/dataloaders/datasets/hg38_dataset.py
================================================
"""Dataset for sampling arbitrary intervals from the human genome.

"""

import math
from pathlib import Path

import pandas as pd
import torch
from pyfaidx import Fasta

from src.dataloaders.utils.mlm import mlm_getitem
from src.dataloaders.utils.rc import coin_flip, string_reverse_complement

MAX_ALLOWED_LENGTH = 2 ** 20


class FastaInterval:
    """Retrieves sequences from a fasta file given a chromosome and start/end indices."""
    def __init__(
            self,
            *,
            fasta_file,
            return_seq_indices=False,
            rc_aug=False,
    ):
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), "Path to fasta file must exist!"

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        self.rc_aug = rc_aug

        # calc len of each chromosome in fasta file, store in dict
        self.chr_lens = {}

        for chr_name in self.seqs.keys():
            self.chr_lens[chr_name] = len(self.seqs[chr_name])

    @staticmethod
    def _compute_interval(start, end, max_length, i_shift):
        if max_length == MAX_ALLOWED_LENGTH:
            return start, end
        if max_length < MAX_ALLOWED_LENGTH:
            assert MAX_ALLOWED_LENGTH % max_length == 0
            return start + i_shift * max_length, start + (i_shift + 1) * max_length
        else:
            raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!")

    def __call__(
            self,
            chr_name,
            start,
            end,
            max_length,
            i_shift,
            return_augs=False,
    ):
        """
        max_length passed from dataset, not from init
        """
        chromosome = self.seqs[chr_name]
        chromosome_length = self.chr_lens[chr_name]

        start, end = self._compute_interval(start, end, max_length, i_shift)

        if end > chromosome_length:
            # Shift interval down
            start = start - (end - chromosome_length)
            end = chromosome_length
            assert start == chromosome_length - max_length

        if start < 0:
            # Shift interval up
            end = end - start
            start = 0
            assert end == max_length

        if end > chromosome_length:
            # This may occur if start + MAX_ALLOWED_LENGTH extends beyond the end of the chromosome
            start = chromosome_length - max_length
            end = chromosome_length

        seq = str(chromosome[start:end])

        if self.rc_aug and coin_flip():
            seq = string_reverse_complement(seq)

        return seq


class HG38Dataset(torch.utils.data.Dataset):
    """Loop through bed file, retrieve (chr, start, end), query fasta file for sequence."""

    def __init__(
            self,
            split,
            bed_file,
            fasta_file,
            max_length,
            mlm=False,
            mlm_probability=0.15,
            pad_max_length=None,
            tokenizer=None,
            tokenizer_name=None,
            add_eos=False,
            return_seq_indices=False,
            rc_aug=False,
            return_augs=False,
    ):
        self.mlm = mlm
        self.mlm_probability = mlm_probability
        if self.mlm and self.mlm_probability <= 0.0:
            raise ValueError(f"`mlm_probability` has to be > 0.0, got {self.mlm_probability}.")
        if self.mlm:
            # TODO: see if this helps
            # self.eligible_replacements = torch.tensor(
            #     tokenizer("ACGT", add_special_tokens=False)["input_ids"], dtype=torch.long
            # )
            self.eligible_replacements = None
        else:
            self.eligible_replacements = None
        self.max_length = max_length
        self.pad_max_length = pad_max_length if pad_max_length is not None else max_length
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos

        if max_length <= MAX_ALLOWED_LENGTH:
            assert MAX_ALLOWED_LENGTH % max_length == 0, "`max_length` must be a power of 2!"
            self.shifts = MAX_ALLOWED_LENGTH // max_length
        else:
            raise ValueError(f"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!")

        bed_path = Path(bed_file)
        assert bed_path.exists(), "Path to .bed file must exist!"

        # read bed file
        df_raw = pd.read_csv(str(bed_path), sep="\t", names=["chr_name", "start", "end", "split"])
        # select only split df (copy so the assignment below does not write into a view of df_raw)
        self.df = df_raw[df_raw["split"] == split].copy()
        # Update end points so that sequences are all length == MAX_ALLOWED_LENGTH
        self.df.loc[:, "end"] = self.df["start"] + MAX_ALLOWED_LENGTH

        self.fasta = FastaInterval(
            fasta_file=fasta_file,
            return_seq_indices=return_seq_indices,
            rc_aug=rc_aug
        )

    @staticmethod
    def replace_value(x, old_value, new_value):
        """Helper for replacing values in a tensor."""
        return torch.where(x == old_value, new_value, x)

    def __len__(self):
        return len(self.df) * self.shifts

    def __getitem__(self, idx):
        """Returns a (data, target) pair for the sequence window at flat index `idx`."""
        # map the flat index to a bed-file row and a window (shift) within that row
        row_idx, shift_idx = idx // self.shifts, idx % self.shifts
        row = self.df.iloc[row_idx]
        chr_name, start, end = (row.iloc[0], row.iloc[1], row.iloc[2])

        seq = self.fasta(
            chr_name,
            start,
            end,
            max_length=self.max_length,
            i_shift=shift_idx,
            return_augs=self.return_augs,
        )
        if end - start != MAX_ALLOWED_LENGTH:
            print(row, "\nLength: ", end - start)

        if self.tokenizer_name == "char":
            seq = self.tokenizer(
                seq,
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
                add_special_tokens=False
            )

            seq = seq["input_ids"]  # get input_ids

            # need to handle eos here
            if self.add_eos:
                # append list seems to be faster than append tensor
                seq.append(self.tokenizer.sep_token_id)

        elif self.tokenizer_name == "bpe":
            seq = self.tokenizer(
                seq,
                # add_special_tokens=False,
                padding="max_length",
                max_length=self.pad_max_length,
                truncation=True,
            )
            # get input_ids
            if self.add_eos:
                seq = seq["input_ids"][1:]  # remove the bos, keep the eos token
            else:
                seq = seq["input_ids"][1:-1]  # remove both special tokens

        # convert to tensor
        seq = torch.LongTensor(seq)

        # replace N token with a pad token, so we can ignore it in the loss
        seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int["N"], self.tokenizer.pad_token_id)

        if self.mlm:
            data, target = mlm_getitem(
                seq,
                mlm_probability=self.mlm_probability,
                contains_eos=self.add_eos,
                tokenizer=self.tokenizer,
                eligible_replacements=self.eligible_replacements,
            )

        else:
            data = seq[:-1].clone()
            target = seq[1:].clone()

        return data, target
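

# Illustrative sketch (not part of the original dataset): how a flat dataset index maps to a
# bed-file row and a sub-window, mirroring `HG38Dataset.__getitem__` and
# `FastaInterval._compute_interval` above (before any clipping at chromosome boundaries).
# `_index_to_window` and `row_starts` are hypothetical names introduced for this example.
def _index_to_window(idx, row_starts, max_length, max_allowed_length=2 ** 20):
    """Return (row_idx, start, end) for flat index `idx`."""
    assert max_allowed_length % max_length == 0
    shifts = max_allowed_length // max_length  # windows per bed-file row
    row_idx, shift_idx = idx // shifts, idx % shifts
    start = row_starts[row_idx] + shift_idx * max_length
    return row_idx, start, start + max_length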


================================================
FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py
================================================
"""Nucleotide Transformer Benchmarks Dataset.

From: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks
"""

import torch
from datasets import load_dataset

from src.dataloaders.utils.rc import coin_flip, string_reverse_complement


class NucleotideTransformerDataset(torch.utils.data.Dataset):

    """
    Wrap a task from the InstaDeepAI/nucleotide_transformer_downstream_tasks Hugging Face dataset.
    Each item is a tokenized sequence paired with its integer label.
    """

    def __init__(
        self,
        split,
        max_length,
        dataset_name=None,
        d_output=2,  # default binary classification
        tokenizer=None,
        tokenizer_name=None,
        use_padding=None,
        add_eos=False,
        rc_aug=False,
        conjoin_train=False,
        conjoin_test=False,
        return_augs=False
    ):

        self.max_length = max_length
        self.use_padding = use_padding
        self.tokenizer_name = tokenizer_name
        self.tokenizer = tokenizer
        self.return_augs = return_augs
        self.add_eos = add_eos
        self.d_output = d_output  # needed for decoder to grab
        assert not (conjoin_train and conjoin_test), "conjoin_train and conjoin_test cannot both be True"
        if (conjoin_train or conjoin_test) and rc_aug:
            print("When using conjoin, we turn off rc_aug.")
            rc_aug = False
        self.rc_aug = rc_aug
        self.conjoin_train = conjoin_train
        self.conjoin_test = conjoin_test

        self.split = split

        # For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks
        self.seqs = load_dataset(
            "InstaDeepAI/nucleotide_transformer_downstream_tasks",
            name=dataset_name,
            split=split
        )

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        x = self.seqs[idx]["sequence"]  # only one sequence
        y = self.seqs[idx]["label"]

        if (self.rc_aug or (self.conjoin_test and self.split == "train")) and coin_flip():
            x = string_reverse_complement(x)

        seq = self.tokenizer(
            x,
            add_special_tokens=False,
            padding="max_length" if self.use_padding else None,
            max_length=self.max_length,
            truncation=True,
        )
        seq_ids = seq["input_ids"]  # get input_ids

        # need to handle eos here
        if self.add_eos:
            # append list seems to be faster than append tensor
            seq_ids.append(self.tokenizer.sep_token_id)

        if self.conjoin_train or (self.conjoin_test and self.split != "train"):
            x_rc = string_reverse_complement(x)
            seq_rc = self.tokenizer(
                x_rc,
                add_special_tokens=False,
                padding="max_length" if self.use_padding else None,
                max_length=self.max_length,
                truncation=True,
            )
            seq_rc_ids = seq_rc["input_ids"]  # get input_ids
            # need to handle eos here
            if self.add_eos:
                # append list seems to be faster than append tensor
                seq_rc_ids.append(self.tokenizer.sep_token_id)
            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)

        else:
            # convert to tensor
            seq_ids = torch.LongTensor(seq_ids)

        # need to wrap in list
        target = torch.LongTensor([y])

        # `seq_ids` has shape:
        #     - (seq_len,) if not conjoining
        #     - (seq_len, 2) for conjoining
        return seq_ids, target


================================================
FILE: src/dataloaders/fault_tolerant_sampler.py
================================================
# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
from typing import Iterator
import math

import torch
from torch.utils.data import RandomSampler, DistributedSampler


class RandomFaultTolerantSampler(RandomSampler):

    def __init__(self, *args, generator=None, **kwargs):
        # generator = torch.Generator().manual_seed(seed)
        # super().__init__(*args, generator=generator, **kwargs)
        # TD [2022-07-17]: We don't force the seed to be zero. We generate a random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        super().__init__(*args, generator=generator, **kwargs)
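        self.counter = 0
        # self.start_counter = 0
        self.restarting = False
        # Record the generator state up front so that `state_dict()` also works
        # when it is called before the first `__iter__`.
        self.state = self.generator.get_state()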

    def state_dict(self):
        return {"random_state": self.state, "counter": self.counter}

    def load_state_dict(self, state_dict):
        self.generator.set_state(state_dict.get("random_state"))
        self.counter = state_dict["counter"]
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epochs will have very few batches.
    # def __len__(self):
    #     # We need a separate self.start_counter because PL seems to call len repeatedly.
    #     # If we use len(self.data_source) - self.counter then PL will think the epoch ends
    #     # when we're only half way through.
    #     return len(self.data_source) - self.start_counter

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)

        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False
        # self.start_counter = self.counter

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
        # self.start_counter = self.counter


class FaultTolerantDistributedSampler(DistributedSampler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        # self.start_counter = 0
        self.restarting = False

    def state_dict(self):
        return {"epoch": self.epoch, "counter": self.counter}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epochs will have very few batches.
    # def __len__(self) -> int:
    #     return self.num_samples - self.start_counter

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False
        # self.start_counter = self.counter

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
        # self.start_counter = self.counter
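

# --- Illustrative sketch (not part of the original module) ---
# Both samplers above resume the same way: rebuild the epoch's permutation deterministically,
# then skip the first `counter` indices that were already consumed before the checkpoint.
# The stdlib-only class below isolates that pattern. `_ResumableOrderSketch` and its `seed`
# argument are hypothetical, and `random.Random` is swapped in for `torch.Generator`.
import random


class _ResumableOrderSketch:

    def __init__(self, n, seed=0):
        self.n = n
        self.seed = seed
        self.epoch = 0
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {"epoch": self.epoch, "counter": self.counter}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    def __iter__(self):
        # The same (seed + epoch) always yields the same permutation,
        # so slicing off `counter` items resumes exactly where iteration stopped.
        indices = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(indices)

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            # Incremented before the yield, so the index in flight counts as consumed.
            self.counter += 1
            yield index

        self.counter = 0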

================================================
FILE: src/dataloaders/genomics.py
================================================
"""Dataloaders for genomics datasets, including pretraining and downstream tasks.

    - Adapted from:
        https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
    - Adapted from:
        https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
"""

import copy
from typing import Any, List, Union

import torch
from datasets import Dataset
from torch.utils.data.dataloader import DataLoader

from caduceus.tokenization_caduceus import CaduceusTokenizer
import src.utils.train
from src.dataloaders.base import SequenceDataset, default_data_path
from src.dataloaders.datasets.genomic_bench_dataset import GenomicBenchmarkDataset
from src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer
from src.dataloaders.datasets.hg38_dataset import HG38Dataset
from src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset
from src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler

logger = src.utils.train.get_logger(__name__)


class HG38(SequenceDataset):
    """
    Base class; other dataloaders can inherit from it.

    You must implement the following functions:
        - __init__
        - setup

    You can then use (already have access to) the following functions:
        - train_dataloader
        - val_dataloader
        - test_dataloader

    """
    _name_ = "hg38"  # this name is how the dataset config finds the right dataloader

    def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_name=None, max_length=1024, d_output=2,
                 rc_aug=False,
                 max_length_val=None, max_length_test=None, val_ratio=0.0005, val_split_seed=2357,
                 add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, shuffle=False,
                 num_workers=1,
                 fault_tolerant=False, ddp=False,
                 fast_forward_epochs=None, fast_forward_batches=None,
                 mlm=False, mlm_probability=0.15,
                 *args, **kwargs):
        self.dataset_config_name = dataset_config_name
        self.tokenizer_name = tokenizer_name
        self.d_output = d_output
        self.rc_aug = rc_aug  # reverse complement augmentation
        self.max_length = max_length
        self.max_length_val = max_length_val if max_length_val is not None else max_length
        self.max_length_test = max_length_test if max_length_test is not None else max_length
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        self.val_only = val_only
        self.add_eos = add_eos
        self.detokenize = detokenize
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.bed_file = bed_file
        self.fasta_file = fasta_file

        # handle if file paths are None (default paths)
        if self.bed_file is None:
            self.bed_file = default_data_path / self._name_ / "human-sequences.bed"
        if self.fasta_file is None:
            self.fasta_file = default_data_path / self._name_ / "hg38.ml.fa"

        if fault_tolerant:
            assert self.shuffle
        self.fault_tolerant = fault_tolerant
        if ddp:
            assert fault_tolerant
        self.ddp = ddp
        self.fast_forward_epochs = fast_forward_epochs
        self.fast_forward_batches = fast_forward_batches
        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
            assert ddp and fault_tolerant

        self.mlm = mlm
        self.mlm_probability = mlm_probability

        # To be instantiated in `setup`
        self.tokenizer = None
        self.vocab_size = 0

    def setup(self, stage=None):
        """Set up the tokenizer and init the datasets."""
        # TODO instantiate with registry

        if self.tokenizer_name == "char":
            logger.info("**Using Char-level tokenizer**")
            # self.tokenizer = CharacterTokenizer(
            #     characters=["A", "C", "G", "T", "N"],
            #     model_max_length=self.max_length,
            #     add_special_tokens=False,
            # )
            self.tokenizer = CaduceusTokenizer(
                model_max_length=self.max_length,
                add_special_tokens=False
            )
        else:
            raise NotImplementedError(f"Tokenizer {self.tokenizer_name} not implemented.")

        self.vocab_size = len(self.tokenizer)

        self.init_datasets()  # creates the datasets.  You can also just create this inside the setup() here.

    def init_datasets(self):
        """Init the datasets (separate from the tokenizer)"""

        # delete old datasets to free memory
        if hasattr(self, "dataset_train"):
            self.dataset_train.fasta.seqs.close()
            del self.dataset_train.fasta.seqs

        # delete old datasets to free memory
        if hasattr(self, "dataset_test"):
            self.dataset_test.fasta.seqs.close()
            del self.dataset_test.fasta.seqs

        # Create all splits: torch datasets
        self.dataset_train, self.dataset_val, self.dataset_test = [
            HG38Dataset(split=split,
                        bed_file=self.bed_file,
                        fasta_file=self.fasta_file,
                        max_length=max_len,
                        tokenizer=self.tokenizer,  # pass the tokenize wrapper
                        tokenizer_name=self.tokenizer_name,
                        add_eos=self.add_eos,
                        return_seq_indices=False,
                        rc_aug=self.rc_aug,
                        return_augs=False,
                        mlm=self.mlm,
                        mlm_probability=self.mlm_probability, )
            for split, max_len in
            zip(["train", "valid", "test"], [self.max_length, self.max_length_val, self.max_length_test])
        ]

        return

    def train_dataloader(self, **kwargs: Any) -> DataLoader:
        """ The train dataloader """
        if self.shuffle and self.fault_tolerant:
            shuffle = False
            # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel:
            # In that case the number of replicas and the data parallel rank are more complicated.
            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
            sampler = (FaultTolerantDistributedSampler(
                self.dataset_train,
                **distributed_sampler_kwargs
            ) if self.ddp else RandomFaultTolerantSampler(self.dataset_train))
            # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
            # We assume that it's being resumed with the same number of GPUs
            if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
                sampler.load_state_dict({
                    "epoch": self.fast_forward_epochs,
                    "counter": self.fast_forward_batches * self.batch_size
                })
        else:
            shuffle = self.shuffle
            sampler = None
        loader = self._data_loader(self.dataset_train, batch_size=self.batch_size,
                                   shuffle=shuffle, sampler=sampler, **kwargs)
        return loader

    def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]
SYMBOL INDEX (434 symbols across 36 files)

FILE: caduceus/configuration_caduceus.py
  class CaduceusConfig (line 10) | class CaduceusConfig(PretrainedConfig):
    method __init__ (line 14) | def __init__(

FILE: caduceus/modeling_caduceus.py
  function create_block (line 33) | def create_block(
  class BiMambaWrapper (line 87) | class BiMambaWrapper(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 122) | def forward(self, hidden_states, inference_params=None):
  class CaduceusEmbeddings (line 143) | class CaduceusEmbeddings(nn.Module):
    method __init__ (line 144) | def __init__(
    method forward (line 159) | def forward(self, input_ids):
  class CaduceusMixerModel (line 166) | class CaduceusMixerModel(nn.Module):
    method __init__ (line 167) | def __init__(
    method forward (line 216) | def forward(self, input_ids, inputs_embeds=None, output_hidden_states=...
  function cross_entropy (line 279) | def cross_entropy(logits, y, ignore_index=-100):
  function weighted_cross_entropy (line 286) | def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
  class CaduceusPreTrainedModel (line 297) | class CaduceusPreTrainedModel(PreTrainedModel):
    method _init_weights (line 304) | def _init_weights(
  class Caduceus (line 344) | class Caduceus(CaduceusPreTrainedModel):
    method __init__ (line 346) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **...
    method forward (line 363) | def forward(
  class CaduceusForMaskedLM (line 392) | class CaduceusForMaskedLM(CaduceusPreTrainedModel):
    method __init__ (line 395) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **...
    method get_input_embeddings (line 417) | def get_input_embeddings(self):
    method set_input_embeddings (line 420) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 425) | def get_output_embeddings(self):
    method set_output_embeddings (line 428) | def set_output_embeddings(self, new_embeddings):
    method tie_weights (line 434) | def tie_weights(self):
    method get_decoder (line 441) | def get_decoder(self):
    method set_decoder (line 445) | def set_decoder(self, decoder):
    method forward (line 449) | def forward(
  class CaduceusForSequenceClassification (line 495) | class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    method __init__ (line 496) | def __init__(
    method init_scorer (line 521) | def init_scorer(self, initializer_range=0.02):
    method get_input_embeddings (line 526) | def get_input_embeddings(self):
    method set_input_embeddings (line 529) | def set_input_embeddings(self, value):
    method pool_hidden_states (line 534) | def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
    method forward (line 545) | def forward(

FILE: caduceus/modeling_rcps.py
  class RCPSEmbedding (line 21) | class RCPSEmbedding(nn.Module):
    method __init__ (line 23) | def __init__(self, vocab_size: int, d_model: int, complement_map: dict...
    method weight (line 38) | def weight(self):
    method set_weight (line 42) | def set_weight(self, value):
    method rc (line 46) | def rc(self, x):
    method forward (line 54) | def forward(self, input_ids):
  class RCPSWrapper (line 70) | class RCPSWrapper(nn.Module):
    method __init__ (line 76) | def __init__(self, submodule: nn.Module):
    method rc (line 81) | def rc(x):
    method forward (line 85) | def forward(self, x, **kwargs):
  class RCPSAddNormWrapper (line 102) | class RCPSAddNormWrapper(RCPSWrapper):
    method __init__ (line 104) | def __init__(self, submodule: nn.Module):
    method forward (line 107) | def forward(self, x, residual=None, prenorm=False):
  class RCPSMambaBlock (line 133) | class RCPSMambaBlock(nn.Module):
    method __init__ (line 134) | def __init__(
    method forward (line 160) | def forward(
    method allocate_inference_cache (line 201) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,...
  class RCPSLMHead (line 209) | class RCPSLMHead(nn.Module):
    method __init__ (line 211) | def __init__(self, true_dim: int, vocab_size: int, complement_map: dic...
    method weight (line 225) | def weight(self):
    method set_weight (line 229) | def set_weight(self, value):
    method forward (line 233) | def forward(self, x):

FILE: caduceus/tests/test_rcps.py
  function test_rcps_embedding (line 31) | def test_rcps_embedding(batch_size, seq_len, d_model, dtype):
  function test_rcps_wrapper (line 80) | def test_rcps_wrapper(batch_size, seq_len, d_model, dtype):
  function test_rcps_add_norm_wrapper (line 116) | def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dt...
  function test_rcps_mamba_block_wrapper (line 155) | def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirect...
  function test_rcps_lm_head (line 209) | def test_rcps_lm_head(batch_size, seq_len, d_model, dtype):
  function test_rcps_backbone (line 271) | def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fus...
  function test_rcps_mamba_lm (line 348) | def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bid...
  function test_collapse_invariance (line 429) | def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtyp...

FILE: caduceus/tokenization_caduceus.py
  class CaduceusTokenizer (line 10) | class CaduceusTokenizer(PreTrainedTokenizer):
    method __init__ (line 13) | def __init__(self,
    method vocab_size (line 83) | def vocab_size(self) -> int:
    method complement_map (line 87) | def complement_map(self) -> Dict[int, int]:
    method _tokenize (line 90) | def _tokenize(self, text: str, **kwargs) -> List[str]:
    method _convert_token_to_id (line 93) | def _convert_token_to_id(self, token: str) -> int:
    method _convert_id_to_token (line 96) | def _convert_id_to_token(self, index: int) -> str:
    method convert_tokens_to_string (line 99) | def convert_tokens_to_string(self, tokens):
    method get_special_tokens_mask (line 102) | def get_special_tokens_mask(
    method build_inputs_with_special_tokens (line 120) | def build_inputs_with_special_tokens(
    method get_vocab (line 130) | def get_vocab(self) -> Dict[str, int]:
    method save_vocabulary (line 134) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...

FILE: src/callbacks/params.py
  class ParamsLog (line 10) | class ParamsLog(pl.Callback):
    method __init__ (line 12) | def __init__(
    method on_fit_start (line 28) | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningMod...

FILE: src/callbacks/timer.py
  class Timer (line 18) | class Timer(Callback):
    method __init__ (line 21) | def __init__(
    method on_train_start (line 36) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule)...
    method on_train_epoch_start (line 39) | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningM...
    method on_train_batch_start (line 44) | def on_train_batch_start(
    method on_train_batch_end (line 65) | def on_train_batch_end(
    method on_train_epoch_end (line 86) | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningMod...
    method on_validation_epoch_start (line 92) | def on_validation_epoch_start(self, trainer: Trainer, pl_module: Light...
    method on_validation_epoch_end (line 96) | def on_validation_epoch_end(self, trainer: Trainer, pl_module: Lightni...
    method _should_log (line 103) | def _should_log(trainer) -> bool:

FILE: src/callbacks/validation.py
  class ValEveryNGlobalSteps (line 13) | class ValEveryNGlobalSteps(Callback):
    method __init__ (line 15) | def __init__(self, every_n):
    method on_train_batch_end (line 19) | def on_train_batch_end(self, trainer, *_: Any):

FILE: src/dataloaders/base.py
  class DefaultCollateMixin (line 20) | class DefaultCollateMixin:
    method _collate_callback (line 30) | def _collate_callback(cls, x, *args, **kwargs):
    method _return_callback (line 39) | def _return_callback(cls, return_value, *args, **kwargs):
    method _collate (line 50) | def _collate(cls, batch, *args, **kwargs):
    method _collate_fn (line 71) | def _collate_fn(cls, batch, *args, **kwargs):
    method _dataloader (line 92) | def _dataloader(self, dataset, **loader_args):
  class SequenceDataset (line 106) | class SequenceDataset(DefaultCollateMixin):
    method init_defaults (line 115) | def init_defaults(self):
    method __init_subclass__ (line 119) | def __init_subclass__(cls, **kwargs):
    method __init__ (line 123) | def __init__(self, _name_, data_dir=None, **dataset_cfg):
    method init (line 138) | def init(self):
    method setup (line 142) | def setup(self):
    method split_train_val (line 146) | def split_train_val(self, val_split):
    method train_dataloader (line 159) | def train_dataloader(self, **kwargs):
    method _train_dataloader (line 163) | def _train_dataloader(self, dataset, **kwargs):
    method val_dataloader (line 169) | def val_dataloader(self, **kwargs):
    method test_dataloader (line 173) | def test_dataloader(self, **kwargs):
    method _eval_dataloader (line 177) | def _eval_dataloader(self, dataset, **kwargs):
    method __str__ (line 183) | def __str__(self):

FILE: src/dataloaders/datasets/genomic_bench_dataset.py
  class GenomicBenchmarkDataset (line 15) | class GenomicBenchmarkDataset(torch.utils.data.Dataset):
    method __init__ (line 21) | def __init__(
    method __len__ (line 80) | def __len__(self):
    method __getitem__ (line 83) | def __getitem__(self, idx):

FILE: src/dataloaders/datasets/hg38_char_tokenizer.py
  class CharacterTokenizer (line 15) | class CharacterTokenizer(PreTrainedTokenizer):
    method __init__ (line 16) | def __init__(self, characters: Sequence[str], model_max_length: int, p...
    method vocab_size (line 77) | def vocab_size(self) -> int:
    method _tokenize (line 80) | def _tokenize(self, text: str) -> List[str]:
    method _convert_token_to_id (line 83) | def _convert_token_to_id(self, token: str) -> int:
    method _convert_id_to_token (line 86) | def _convert_id_to_token(self, index: int) -> str:
    method convert_tokens_to_string (line 89) | def convert_tokens_to_string(self, tokens):
    method build_inputs_with_special_tokens (line 92) | def build_inputs_with_special_tokens(
    method get_special_tokens_mask (line 102) | def get_special_tokens_mask(
    method get_vocab (line 120) | def get_vocab(self) -> Dict[str, int]:
    method create_token_type_ids_from_sequences (line 123) | def create_token_type_ids_from_sequences(
    method get_config (line 134) | def get_config(self) -> Dict:
    method from_config (line 141) | def from_config(cls, config: Dict) -> "CharacterTokenizer":
    method save_pretrained (line 147) | def save_pretrained(self, save_directory: Union[str, os.PathLike], **k...
    method from_pretrained (line 154) | def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kw...

FILE: src/dataloaders/datasets/hg38_dataset.py
  class FastaInterval (line 18) | class FastaInterval:
    method __init__ (line 20) | def __init__(
    method _compute_interval (line 41) | def _compute_interval(start, end, max_length, i_shift):
    method __call__ (line 50) | def __call__(
  class HG38Dataset (line 92) | class HG38Dataset(torch.utils.data.Dataset):
    method __init__ (line 95) | def __init__(
    method replace_value (line 153) | def replace_value(x, old_value, new_value):
    method __len__ (line 157) | def __len__(self):
    method __getitem__ (line 160) | def __getitem__(self, idx):

FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py
  class NucleotideTransformerDataset (line 12) | class NucleotideTransformerDataset(torch.utils.data.Dataset):
    method __init__ (line 19) | def __init__(
    method __len__ (line 59) | def __len__(self):
    method __getitem__ (line 62) | def __getitem__(self, idx):

FILE: src/dataloaders/fault_tolerant_sampler.py
  class RandomFaultTolerantSampler (line 9) | class RandomFaultTolerantSampler(RandomSampler):
    method __init__ (line 11) | def __init__(self, *args, generator=None, **kwargs):
    method state_dict (line 26) | def state_dict(self):
    method load_state_dict (line 29) | def load_state_dict(self, state_dict):
    method __iter__ (line 43) | def __iter__(self) -> Iterator[int]:
  class FaultTolerantDistributedSampler (line 64) | class FaultTolerantDistributedSampler(DistributedSampler):
    method __init__ (line 66) | def __init__(self, *args, **kwargs):
    method state_dict (line 72) | def state_dict(self):
    method load_state_dict (line 75) | def load_state_dict(self, state_dict):
    method __iter__ (line 86) | def __iter__(self):

FILE: src/dataloaders/genomics.py
  class HG38 (line 29) | class HG38(SequenceDataset):
    method __init__ (line 45) | def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_...
    method setup (line 97) | def setup(self, stage=None):
    method init_datasets (line 119) | def init_datasets(self):
    method train_dataloader (line 152) | def train_dataloader(self, **kwargs: Any) -> DataLoader:
    method val_dataloader (line 177) | def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Data...
    method test_dataloader (line 182) | def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Dat...
    method _data_loader (line 189) | def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = Fa...
    method load_state_dict (line 198) | def load_state_dict(self, checkpoint):
  class GenomicBenchmark (line 208) | class GenomicBenchmark(HG38):
    method __init__ (line 212) | def __init__(
    method setup (line 262) | def setup(self, stage=None):
  class NucleotideTransformer (line 308) | class NucleotideTransformer(HG38):
    method __init__ (line 312) | def __init__(self, dataset_name, train_val_split_seed,
    method setup (line 357) | def setup(self, stage=None):

FILE: src/dataloaders/utils/mlm.py
  function mlm_getitem (line 4) | def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer...

FILE: src/dataloaders/utils/rc.py
  function coin_flip (line 12) | def coin_flip(p=0.5):
  function string_reverse_complement (line 17) | def string_reverse_complement(seq):

FILE: src/models/baseline/genomics_benchmark_cnn.py
  class GenomicsBenchmarkCNN (line 10) | class GenomicsBenchmarkCNN(nn.Module):
    method __init__ (line 11) | def __init__(self, number_of_classes, vocab_size, input_len, embedding...
    method count_flatten_size (line 42) | def count_flatten_size(self, input_len):
    method forward (line 49) | def forward(self, x, state=None):  # Adding `state` to be consistent w...

FILE: src/models/nn/activation.py
  function Activation (line 9) | def Activation(activation=None, size=None, dim=-1):
  class GLU (line 45) | class GLU(nn.Module):
    method __init__ (line 46) | def __init__(self, dim=-1, activation='sigmoid'):
    method forward (line 52) | def forward(self, x):
  class ModReLU (line 57) | class ModReLU(nn.Module):
    method __init__ (line 60) | def __init__(self, features):
    method reset_parameters (line 67) | def reset_parameters(self):
    method forward (line 70) | def forward(self, inputs):
  class SquaredReLU (line 79) | class SquaredReLU(nn.Module):
    method forward (line 80) | def forward(self, x):
  function laplace (line 85) | def laplace(x, mu=0.707107, sigma=0.282095):
  class Laplace (line 90) | class Laplace(nn.Module):
    method __init__ (line 91) | def __init__(self, mu=0.707107, sigma=0.282095):
    method forward (line 96) | def forward(self, x):

FILE: src/models/nn/adaptive_softmax.py
  class OptionalParameterList (line 23) | class OptionalParameterList(nn.ParameterList):
    method extra_repr (line 24) | def extra_repr(self):
  class ProjectedAdaptiveLogSoftmax (line 37) | class ProjectedAdaptiveLogSoftmax(nn.Module):
    method __init__ (line 38) | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
    method _compute_logit (line 128) | def _compute_logit(self, hidden, weight, bias, proj):
    method get_out_proj (line 142) | def get_out_proj(self, i):
    method forward (line 153) | def forward(self, hidden, target, keep_order=False, key_padding_mask=N...
    method compute_logits (line 237) | def compute_logits(self, hidden):
  class AdaptiveEmbedding (line 300) | class AdaptiveEmbedding(nn.Module):
    method __init__ (line 305) | def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_...
    method forward (line 342) | def forward(self, inp):
  function _init_weight (line 395) | def _init_weight(weight, d : int, init_scale : Optional[float], default=...

FILE: src/models/nn/utils.py
  function wrap_kwargs (line 8) | def wrap_kwargs(f):
  function discard_kwargs (line 84) | def discard_kwargs(f):
  function PassthroughSequential (line 92) | def PassthroughSequential(*modules):

FILE: src/models/sequence/dna_embedding.py
  class DNAEmbeddingModel (line 27) | class DNAEmbeddingModel(nn.Module, GenerationMixin):
    method __init__ (line 34) | def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_siz...
    method forward (line 82) | def forward(self, input_ids, position_ids=None, inference_params=None,...
    method d_output (line 90) | def d_output(self):
  class DNAEmbeddingModelMamba (line 99) | class DNAEmbeddingModelMamba(DNAEmbeddingModel):
    method __init__ (line 102) | def __init__(
    method forward (line 149) | def forward(self, input_ids, position_ids=None, inference_params=None,...
  class DNAEmbeddingModelCaduceus (line 156) | class DNAEmbeddingModelCaduceus(DNAEmbeddingModel):
    method __init__ (line 159) | def __init__(
    method forward (line 179) | def forward(self, input_ids, position_ids=None, inference_params=None,...
  function load_backbone (line 198) | def load_backbone(model, state_dict, freeze_backbone=False, ignore_head=...

FILE: src/models/sequence/hyena.py
  class FFTConvFuncv2 (line 25) | class FFTConvFuncv2(torch.autograd.Function):
    method forward (line 27) | def forward(ctx, u, k):
    method backward (line 40) | def backward(ctx, dout):
  function fftconv_ref (line 55) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
  function mul_sum (line 79) | def mul_sum(q, y):
  class Sin (line 83) | class Sin(nn.Module):
    method __init__ (line 84) | def __init__(self, dim, w=10, train_freq=True):
    method forward (line 92) | def forward(self, x):
  class PositionalEmbedding (line 96) | class PositionalEmbedding(OptimModule):
    method __init__ (line 97) | def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-...
    method forward (line 117) | def forward(self, L):
  class ExponentialModulation (line 121) | class ExponentialModulation(OptimModule):
    method __init__ (line 122) | def __init__(
    method forward (line 139) | def forward(self, t, x):
  class HyenaFilter (line 145) | class HyenaFilter(OptimModule):
    method __init__ (line 146) | def __init__(
    method filter (line 214) | def filter(self, L, *args, **kwargs):
    method forward (line 225) | def forward(self, x, L, k=None, bias=None, *args, **kwargs):
  class HyenaOperator (line 255) | class HyenaOperator(nn.Module):
    method __init__ (line 256) | def __init__(
    method setup_projections (line 330) | def setup_projections(self, fused_bias_fc, inner_factor):
    method setup_filters (line 343) | def setup_filters(self, filter_cls, filter_args):
    method recurrence (line 369) | def recurrence(self, u, state):
    method forward (line 373) | def forward(self, u, *args, **kwargs):
    method d_output (line 432) | def d_output(self):

FILE: src/models/sequence/long_conv_lm.py
  class CheckpointedModule (line 33) | class CheckpointedModule(torch.nn.Module):
    method __init__ (line 34) | def __init__(self, layer):
    method forward (line 38) | def forward(self, x):
  function create_mixer_cls (line 42) | def create_mixer_cls(
  function create_mlp_cls (line 93) | def create_mlp_cls(
  function create_block (line 130) | def create_block(
  function _init_weights (line 195) | def _init_weights(
  class LMBackbone (line 240) | class LMBackbone(nn.Module):
    method __init__ (line 241) | def __init__(
    method tie_weights (line 344) | def tie_weights(self):
    method forward (line 348) | def forward(self, input_ids, position_ids=None, inference_params=None):
  class ConvLMHeadModel (line 391) | class ConvLMHeadModel(nn.Module, GenerationMixin):
    method __init__ (line 392) | def __init__(
    method tie_weights (line 473) | def tie_weights(self):
    method forward (line 478) | def forward(

FILE: src/ops/fftconv.py
  function _mul_sum (line 11) | def _mul_sum(y, q):
  function fftconv_ref (line 15) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
  function fftconv_h3_ref (line 38) | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=No...
  class FFTConvFunc (line 58) | class FFTConvFunc(torch.autograd.Function):
    method forward (line 61) | def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_out...
    method backward (line 88) | def backward(ctx, dout):
  function fftconv_func (line 105) | def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_outpu...

FILE: src/tasks/decoders.py
  class Decoder (line 16) | class Decoder(nn.Module):
    method forward (line 21) | def forward(self, x, **kwargs):
    method step (line 33) | def step(self, x):
  class SequenceDecoder (line 40) | class SequenceDecoder(Decoder):
    method __init__ (line 41) | def __init__(
    method forward (line 70) | def forward(self, x, state=None, lengths=None, l_output=None):
    method step (line 156) | def step(self, x, state=None):
  function _instantiate (line 197) | def _instantiate(decoder, model=None, dataset=None):
  function instantiate (line 217) | def instantiate(decoder, model=None, dataset=None):

FILE: src/tasks/encoders.py
  class Encoder (line 7) | class Encoder(nn.Module):
    method forward (line 16) | def forward(self, x, **kwargs):
  function _instantiate (line 64) | def _instantiate(encoder, dataset=None, model=None):
  function instantiate (line 84) | def instantiate(encoder, dataset=None, model=None):

FILE: src/tasks/metrics.py
  class CorrectAggregatedMetric (line 13) | class CorrectAggregatedMetric(Metric):
    method __init__ (line 16) | def __init__(self, class_idx: int, dist_sync_on_step=False):
    method _update (line 25) | def _update(self, numerator, denominator, preds, y) -> tuple:
    method update (line 28) | def update(self, logits: torch.Tensor, y: torch.Tensor):
    method compute (line 36) | def compute(self):
    method reset (line 41) | def reset(self):
  class AccuracyPerClass (line 45) | class AccuracyPerClass(CorrectAggregatedMetric):
    method _update (line 48) | def _update(self, numerator, denominator, preds, y) -> tuple:
  class PrecisionPerClass (line 59) | class PrecisionPerClass(CorrectAggregatedMetric):
    method _update (line 62) | def _update(self, numerator, denominator, preds, y) -> tuple:
  class RecallPerClass (line 71) | class RecallPerClass(CorrectAggregatedMetric):
    method _update (line 74) | def _update(self, numerator, denominator, preds, y) -> tuple:
  function mcc (line 83) | def mcc(logits, y):
  function last_k_ppl (line 90) | def last_k_ppl(logits, y, seq_len=1024, k=None):
  function _student_t_map (line 122) | def _student_t_map(mu, sigma, nu):
  function student_t_loss (line 127) | def student_t_loss(outs, y):
  function gaussian_ll_loss (line 144) | def gaussian_ll_loss(outs, y):
  function binary_cross_entropy (line 155) | def binary_cross_entropy(logits, y):
  function binary_accuracy (line 161) | def binary_accuracy(logits, y):
  function padded_cross_entropy (line 164) | def padded_cross_entropy(logits, y, pad_mask, pad_value=-1):
  function cross_entropy (line 181) | def cross_entropy(logits, y, ignore_index=-100):
  function soft_cross_entropy (line 187) | def soft_cross_entropy(logits, y, label_smoothing=0.0):
  function accuracy (line 193) | def accuracy(logits, y):
  function accuracy_ignore_index (line 203) | def accuracy_ignore_index(logits, y, ignore_index=-100):
  function accuracy_at_k (line 212) | def accuracy_at_k(logits, y, k=1):
  function f1_binary (line 221) | def f1_binary(logits, y):
  function f1_macro (line 228) | def f1_macro(logits, y):
  function f1_micro (line 235) | def f1_micro(logits, y):
  function roc_auc_macro (line 242) | def roc_auc_macro(logits, y):
  function roc_auc_micro (line 252) | def roc_auc_micro(logits, y):
  function mse (line 260) | def mse(outs, y, len_batch=None):
  function forecast_rmse (line 278) | def forecast_rmse(outs, y, len_batch=None):
  function mae (line 282) | def mae(outs, y, len_batch=None):
  function loss (line 301) | def loss(x, y, loss_fn):
  function bpb (line 306) | def bpb(x, y, loss_fn):
  function ppl (line 311) | def ppl(x, y, loss_fn):

FILE: src/tasks/tasks.py
  class BaseTask (line 16) | class BaseTask:
    method __init__ (line 27) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,...
    method _init_torchmetrics (line 55) | def _init_torchmetrics(self):
    method _reset_torchmetrics (line 83) | def _reset_torchmetrics(self, prefix=None):
    method get_torchmetrics (line 96) | def get_torchmetrics(self, prefix):
    method torchmetrics (line 105) | def torchmetrics(self, x, y, prefix, loss=None):
    method get_torchmetrics (line 124) | def get_torchmetrics(self, prefix):
    method metrics (line 127) | def metrics(self, x, y, **kwargs):
    method forward (line 143) | def forward(self, batch, encoder, model, decoder, _state):
  class Scalar (line 161) | class Scalar(nn.Module):
    method __init__ (line 162) | def __init__(self, c=1):
    method forward (line 166) | def forward(self, x):
  class LMTask (line 170) | class LMTask(BaseTask):
    method forward (line 171) | def forward(self, batch, encoder, model, decoder, _state):
  class MultiClass (line 199) | class MultiClass(BaseTask):
    method __init__ (line 201) | def __init__(self, *args, **kwargs):
    method metrics (line 211) | def metrics(self, x, y, **kwargs):
    method _reset_torchmetrics (line 236) | def _reset_torchmetrics(self, prefix=None):
  class HG38Task (line 244) | class HG38Task(LMTask):
    method __init__ (line 246) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,...
    method metrics (line 303) | def metrics(self, x, y, **kwargs):
  class AdaptiveLMTask (line 335) | class AdaptiveLMTask(BaseTask):
    method __init__ (line 336) | def __init__(

FILE: src/tasks/torchmetrics.py
  class Perplexity (line 24) | class Perplexity(Metric):
    method __init__ (line 46) | def __init__(self, **kwargs: Dict[str, Any]):
    method update (line 54) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
    method compute (line 68) | def compute(self) -> Tensor:
  class NumTokens (line 75) | class NumTokens(Metric):
    method __init__ (line 88) | def __init__(self, **kwargs: Dict[str, Any]):
    method update (line 97) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]...
    method compute (line 100) | def compute(self) -> Tensor:
    method reset (line 103) | def reset(self):
    method _forward_reduce_state_update (line 109) | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

FILE: src/utils/config.py
  function is_list (line 13) | def is_list(x):
  function is_dict (line 17) | def is_dict(x):
  function to_dict (line 21) | def to_dict(x, recursive=True):
  function to_list (line 37) | def to_list(x, recursive=False):
  function extract_attrs_from_obj (line 56) | def extract_attrs_from_obj(obj, *attrs):
  function auto_assign_attrs (line 63) | def auto_assign_attrs(cls, **kwargs):
  function instantiate (line 68) | def instantiate(registry, config, *args, partial=False, wrap=None, **kwa...
  function get_class (line 112) | def get_class(registry, _name_):
  function omegaconf_filter_keys (line 116) | def omegaconf_filter_keys(d, fn=None):

FILE: src/utils/optim/schedulers.py
  class CosineWarmup (line 11) | class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR):
    method __init__ (line 13) | def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs):
    method get_lr (line 19) | def get_lr(self):
  function InvSqrt (line 40) | def InvSqrt(optimizer, warmup_step):
  function Constant (line 54) | def Constant(optimizer, warmup_step):
  class TimmCosineLRScheduler (line 65) | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler....
    method __init__ (line 70) | def __init__(self, *args, **kwargs):
    method step (line 75) | def step(self, epoch=None):

FILE: src/utils/optim_groups.py
  function add_optimizer_hooks (line 14) | def add_optimizer_hooks(
  function group_parameters_for_optimizer (line 41) | def group_parameters_for_optimizer(

FILE: src/utils/train.py
  class LoggingContext (line 20) | class LoggingContext:
    method __init__ (line 21) | def __init__(self, logger, level=None, handler=None, close=True):
    method __enter__ (line 27) | def __enter__(self):
    method __exit__ (line 34) | def __exit__(self, et, ev, tb):
  function get_logger (line 44) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
  function process_config (line 58) | def process_config(config: DictConfig) -> DictConfig:  # TODO because of...
  function print_config (line 101) | def print_config(
  function log_optimizer (line 143) | def log_optimizer(logger, optimizer, keys):
  class OptimModule (line 154) | class OptimModule(nn.Module):
    method register (line 157) | def register(self, name, tensor, lr=None, wd=0.0):

FILE: train.py
  class DummyExperiment (line 43) | class DummyExperiment:
    method nop (line 46) | def nop(self, *args, **kw):
    method __getattr__ (line 49) | def __getattr__(self, _):
    method __getitem__ (line 52) | def __getitem__(self, idx) -> "DummyExperiment":
    method __setitem__ (line 56) | def __setitem__(self, *args, **kwargs) -> None:
  function rank_zero_experiment (line 60) | def rank_zero_experiment(fn: Callable) -> Callable:
  class CustomWandbLogger (line 74) | class CustomWandbLogger(WandbLogger):
    method __init__ (line 76) | def __init__(self, *args, **kwargs):
    method experiment (line 83) | def experiment(self):
  class SequenceLightningModule (line 126) | class SequenceLightningModule(pl.LightningModule):
    method __init__ (line 127) | def __init__(self, config):
    method setup (line 159) | def setup(self, stage=None):
    method load_state_dict (line 240) | def load_state_dict(self, state_dict, strict=False):
    method _check_config (line 255) | def _check_config(self):
    method _initialize_state (line 268) | def _initialize_state(self):
    method _reset_state (line 273) | def _reset_state(self, batch, device=None):
    method _detach_state (line 278) | def _detach_state(self, state):
    method _process_state (line 292) | def _process_state(self, batch, batch_idx, training=True):
    method forward (line 326) | def forward(self, batch):
    method step (line 329) | def step(self, x_t):
    method _shared_step (line 336) | def _shared_step(self, batch, batch_idx, prefix="train"):
    method on_train_epoch_start (line 379) | def on_train_epoch_start(self):
    method training_epoch_end (line 383) | def training_epoch_end(self, outputs):
    method on_validation_epoch_start (line 387) | def on_validation_epoch_start(self):
    method validation_epoch_end (line 392) | def validation_epoch_end(self, outputs):
    method on_test_epoch_start (line 396) | def on_test_epoch_start(self):
    method test_epoch_end (line 401) | def test_epoch_end(self, outputs):
    method training_step (line 405) | def training_step(self, batch, batch_idx, dataloader_idx=0):
    method validation_step (line 438) | def validation_step(self, batch, batch_idx, dataloader_idx=0):
    method test_step (line 455) | def test_step(self, batch, batch_idx, dataloader_idx=0):
    method configure_optimizers (line 460) | def configure_optimizers(self):
    method train_dataloader (line 543) | def train_dataloader(self):
    method _eval_dataloaders_names (line 546) | def _eval_dataloaders_names(self, loaders, prefix):
    method _eval_dataloaders (line 557) | def _eval_dataloaders(self):
    method val_dataloader (line 584) | def val_dataloader(self):
    method test_dataloader (line 589) | def test_dataloader(self):
  function create_trainer (line 596) | def create_trainer(config, **kwargs):
  function fsspec_exists (line 649) | def fsspec_exists(filename):
  function train (line 654) | def train(config):
  function main (line 701) | def main(config: OmegaConf):

FILE: vep_embeddings.py
  class DNAEmbeddingModel (line 30) | class DNAEmbeddingModel(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 56) | def forward(self, input_ids):
  class EnformerTokenizer (line 63) | class EnformerTokenizer:
    method encode (line 70) | def encode(
    method batch_encode_plus (line 83) | def batch_encode_plus(
  function setup_distributed (line 92) | def setup_distributed():
  function cleanup_distributed (line 97) | def cleanup_distributed():
  function fsspec_exists (line 102) | def fsspec_exists(filename):
  function fsspec_listdir (line 108) | def fsspec_listdir(dirname):
  function recast_chromosome_tissue_dist2TSS (line 115) | def recast_chromosome_tissue_dist2TSS(examples):
  function tokenize_variants (line 124) | def tokenize_variants(examples, tokenizer, max_length: int):
  function find_variant_idx (line 172) | def find_variant_idx(examples):
  function prepare_dataset (line 198) | def prepare_dataset(args, tokenizer):
  function get_backbone_model (line 260) | def get_backbone_model(args, device):
  function concat_storage_dict_values (line 270) | def concat_storage_dict_values(storage_dict):
  function dump_embeddings (line 275) | def dump_embeddings(args, dataset, model, device):
  function combine_embeddings (line 407) | def combine_embeddings(embeds_path):
  function main (line 433) | def main(args):

Condensed preview: 104 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1815,
    "preview": "data.tar.gz\n*.tsf\n*.ckpt\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\n*.lprof\n\n.DS_Store\n.idea/\noutputs/\n\n# slurm log files"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                Apache License\n                           Version 2.0, January 2004\n                    "
  },
  {
    "path": "README.md",
    "chars": 13053,
    "preview": "<p align=\"center\">\n    <img src=\"assets/Caduceus_image.png\" alt=\"Caduceus\" width=\"200\"/>\n</p>\n\n\n# Caduceus &#9764;: Bi-D"
  },
  {
    "path": "caduceus/__init__.py",
    "chars": 265,
    "preview": "\"\"\"Hugging Face config, model, and tokenizer for Caduceus.\n\n\"\"\"\n\nfrom .configuration_caduceus import CaduceusConfig\nfrom"
  },
  {
    "path": "caduceus/configuration_caduceus.py",
    "chars": 1964,
    "preview": "\"\"\"Caduceus config for Hugging Face.\n\n\"\"\"\n\nfrom typing import Optional, Union\n\nfrom transformers import PretrainedConfig"
  },
  {
    "path": "caduceus/modeling_caduceus.py",
    "chars": 27326,
    "preview": "\"\"\"Caduceus model for Hugging Face.\n\n\"\"\"\n\nimport inspect\nimport math\nfrom functools import partial\nfrom typing import Op"
  },
  {
    "path": "caduceus/modeling_rcps.py",
    "chars": 9977,
    "preview": "\"\"\"Reverse-complement equivariant modules.\n\n\"\"\"\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport "
  },
  {
    "path": "caduceus/tests/test_rcps.py",
    "chars": 19337,
    "preview": "\"\"\"Tests for RCPS modules.\n\n\"\"\"\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nt"
  },
  {
    "path": "caduceus/tokenization_caduceus.py",
    "chars": 4966,
    "preview": "\"\"\"Character tokenizer for Hugging Face.\n\n\"\"\"\n\nfrom typing import List, Optional, Dict, Sequence, Tuple\n\nfrom transforme"
  },
  {
    "path": "caduceus_env.yml",
    "chars": 1109,
    "preview": "name: caduceus_env\nchannels:\n  - pytorch\n  - anaconda\n  - nvidia\n  - defaults\ndependencies:\n  - cuda-nvcc=11.7.99\n  - pi"
  },
  {
    "path": "configs/callbacks/base.yaml",
    "chars": 322,
    "preview": "learning_rate_monitor:\n  # _target_: pytorch_lightning.callbacks.LearningRateMonitor\n  logging_interval: ${train.interva"
  },
  {
    "path": "configs/callbacks/checkpoint.yaml",
    "chars": 1320,
    "preview": "model_checkpoint:\n  monitor: ${train.monitor} # name of the logged metric which determines when model is improving\n  mod"
  },
  {
    "path": "configs/callbacks/gpu_affinity.yaml",
    "chars": 37,
    "preview": "gpu_affinity:\n  _name_: gpu_affinity\n"
  },
  {
    "path": "configs/callbacks/rich.yaml",
    "chars": 86,
    "preview": "rich_model_summary:\n  max_depth: 2\n\nrich_progress_bar:\n  refresh_rate_per_second: 1.0\n"
  },
  {
    "path": "configs/callbacks/val_every_n_global_steps.yaml",
    "chars": 43,
    "preview": "val_every_n_global_steps:\n  every_n: 10000\n"
  },
  {
    "path": "configs/callbacks/wandb.yaml",
    "chars": 667,
    "preview": "defaults:\n  - default\n\nwatch_model:\n  _target_: src.callbacks.wandb_callbacks.WatchModel\n  log: \"all\"\n  log_freq: 100\n\nu"
  },
  {
    "path": "configs/config.yaml",
    "chars": 3115,
    "preview": "# @package _global_\ndefaults:\n  - _self_\n  - experiment: ???\n  # - model: ???  # Model backbone\n  # - pipeline: ???  # S"
  },
  {
    "path": "configs/dataset/genomic_benchmark.yaml",
    "chars": 2054,
    "preview": "_name_: genomic_benchmark\ntrain_val_split_seed: ${train.seed}  # Used for train/validation splitting\ndataset_name: dummy"
  },
  {
    "path": "configs/dataset/hg38.yaml",
    "chars": 299,
    "preview": "_name_: hg38\nbed_file: null\nfasta_file: null\ndataset_name: hg38\ntokenizer_name: null\ncache_dir: null\nmax_length: 1024\nad"
  },
  {
    "path": "configs/dataset/nucleotide_transformer.yaml",
    "chars": 2703,
    "preview": "_name_: nucleotide_transformer  # this links to the overall SequenceDataset of all nucleotide transformer datasets\ntrain"
  },
  {
    "path": "configs/experiment/hg38/genomic_benchmark.yaml",
    "chars": 2476,
    "preview": "# @package _global_\ndefaults:\n  - /pipeline: genomic_benchmark\n  - /model: ???\n  - override /scheduler: cosine_warmup_ti"
  },
  {
    "path": "configs/experiment/hg38/genomic_benchmark_cnn.yaml",
    "chars": 2213,
    "preview": "# @package _global_\ndefaults:\n  - /model: genomics_benchmark_cnn\n  - /pipeline: genomic_benchmark\n  - override /schedule"
  },
  {
    "path": "configs/experiment/hg38/hg38.yaml",
    "chars": 1530,
    "preview": "# @package _global_\ndefaults:\n  - /pipeline: hg38\n  - /model: ???  # Specify a model, e.g. model=mamba or model=hyena\n  "
  },
  {
    "path": "configs/experiment/hg38/nucleotide_transformer.yaml",
    "chars": 1502,
    "preview": "# @package _global_\ndefaults:\n  - /pipeline: nucleotide_transformer\n  - /model: ???\n  - override /scheduler: cosine_warm"
  },
  {
    "path": "configs/loader/default.yaml",
    "chars": 99,
    "preview": "num_workers: ${eval:\"len(__import__('os').sched_getaffinity(0))\"}\npin_memory: True\ndrop_last: True\n"
  },
  {
    "path": "configs/model/caduceus.yaml",
    "chars": 1055,
    "preview": "# Use open-source version of Mamba\n_name_: caduceus_lm\nconfig:\n  _target_: caduceus.configuration_caduceus.CaduceusConfi"
  },
  {
    "path": "configs/model/genomics_benchmark_cnn.yaml",
    "chars": 277,
    "preview": "# Use open-source version of Mamba\n_name_: genomics_benchmark_cnn\nnumber_of_classes: ${dataset.d_output}\nvocab_size: 12\n"
  },
  {
    "path": "configs/model/hyena.yaml",
    "chars": 522,
    "preview": "_name_: hyena_lm\nd_model: 128\nn_layer: 2\nd_inner: ${eval:4 * ${.d_model}}\nvocab_size: 12\nresid_dropout: 0.0\nembed_dropou"
  },
  {
    "path": "configs/model/layer/hyena.yaml",
    "chars": 275,
    "preview": "_name_: hyena\nl_max: 1024\norder: 2\nfilter_order: 64\nnum_heads: 1\ninner_factor: 1\nnum_blocks: 1\nfused_bias_fc: false\noute"
  },
  {
    "path": "configs/model/mamba.yaml",
    "chars": 745,
    "preview": "# Use open-source version of Mamba\n_name_: mamba_lm\nconfig:\n  _target_: mamba_ssm.models.config_mamba.MambaConfig\n  d_mo"
  },
  {
    "path": "configs/optimizer/adam.yaml",
    "chars": 184,
    "preview": "# _target_: torch.optim.Adam\n_name_: adam\nlr: 0.001  # Initial learning rate\n# weight_decay: 0.0  # Weight decay for ada"
  },
  {
    "path": "configs/optimizer/adamw.yaml",
    "chars": 132,
    "preview": "# _target_: torch.optim.AdamW\n_name_: adamw\nlr: 0.001 # Initial learning rate\nweight_decay: 0.00 # Weight decay\nbetas: ["
  },
  {
    "path": "configs/optimizer/sgd.yaml",
    "chars": 137,
    "preview": "# _target_: torch.optim.SGD\n_name_: sgd\nlr: 0.001  # Initial learning rate\nmomentum: 0.9\nweight_decay: 0.0  # Weight dec"
  },
  {
    "path": "configs/pipeline/genomic_benchmark.yaml",
    "chars": 388,
    "preview": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: default\n  - /dataset: genomic_benchmark\n  - /task: mult"
  },
  {
    "path": "configs/pipeline/hg38.yaml",
    "chars": 605,
    "preview": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: null\n  - /dataset: hg38\n  - /optimizer: adamw\n  - /sche"
  },
  {
    "path": "configs/pipeline/nucleotide_transformer.yaml",
    "chars": 446,
    "preview": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: default\n  - /dataset: nucleotide_transformer\n  - /task:"
  },
  {
    "path": "configs/scheduler/constant.yaml",
    "chars": 124,
    "preview": "# @package _global_\ntrain:\n  interval: epoch\nscheduler:\n  # _target_: transformers.get_constant_schedule\n  _name_: const"
  },
  {
    "path": "configs/scheduler/constant_warmup.yaml",
    "chars": 205,
    "preview": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_constant_schedule_with_warmup\n  _n"
  },
  {
    "path": "configs/scheduler/cosine.yaml",
    "chars": 248,
    "preview": "# @package _global_\ntrain:\n  interval: epoch\nscheduler:\n  # _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n  _name"
  },
  {
    "path": "configs/scheduler/cosine_warmup.yaml",
    "chars": 191,
    "preview": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_cosine_schedule_with_warmup\n  _nam"
  },
  {
    "path": "configs/scheduler/cosine_warmup_timm.yaml",
    "chars": 234,
    "preview": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_cosine_schedule_with_warmup\n  _nam"
  },
  {
    "path": "configs/scheduler/linear_warmup.yaml",
    "chars": 191,
    "preview": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_linear_schedule_with_warmup\n  _nam"
  },
  {
    "path": "configs/scheduler/multistep.yaml",
    "chars": 165,
    "preview": "# @package _global_\ntrain:\n  interval: epoch\n# _target_: torch.optim.lr_scheduler.MultiStepLR\nscheduler:\n  _name_: multi"
  },
  {
    "path": "configs/scheduler/plateau.yaml",
    "chars": 346,
    "preview": "# @package _global_\ntrain:\n  interval: epoch\n  monitor: ??? # must be specified\nscheduler:\n  # _target_: torch.optim.lr_"
  },
  {
    "path": "configs/scheduler/step.yaml",
    "chars": 145,
    "preview": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: torch.optim.lr_scheduler.StepLR\n  _name_: step\n  st"
  },
  {
    "path": "configs/task/lm.yaml",
    "chars": 84,
    "preview": "_name_: lm\n# loss: cross_entropy # Handled by task: cross entropy loss\nmetrics: ppl\n"
  },
  {
    "path": "configs/task/multiclass_classification.yaml",
    "chars": 115,
    "preview": "# _target_: tasks.tasks.MultiClass\n_name_: multiclass\nloss: cross_entropy\nmetrics:\n  - accuracy\ntorchmetrics: null\n"
  },
  {
    "path": "configs/task/multilabel_classification.yaml",
    "chars": 235,
    "preview": "# _target_:\n_name_: base\nloss: binary_cross_entropy\nmetrics: null\ntorchmetrics:\n  - MultilabelAUROC # AUROC\n  - Multilab"
  },
  {
    "path": "configs/task/regression.yaml",
    "chars": 88,
    "preview": "# _target_: tasks.tasks.BaseTask\n_name_: base\nloss: mse\nmetrics: mse\ntorchmetrics: null\n"
  },
  {
    "path": "configs/trainer/debug.yaml",
    "chars": 328,
    "preview": "defaults:\n  - default\n\ngpus: 1\nmin_epochs: 1\nmax_epochs: 10\n\n# prints\nprogress_bar_refresh_rate: null\nweights_summary: f"
  },
  {
    "path": "configs/trainer/default.yaml",
    "chars": 527,
    "preview": "_target_: pytorch_lightning.Trainer\n\ndevices: 1\naccelerator: gpu\naccumulate_grad_batches: 1 # Gradient accumulation ever"
  },
  {
    "path": "configs/trainer/full.yaml",
    "chars": 1076,
    "preview": "_target_: pytorch_lightning.Trainer\n\n# default values for all trainer parameters\ncheckpoint_callback: True\ndefault_root_"
  },
  {
    "path": "configs/trainer/lm.yaml",
    "chars": 643,
    "preview": "accumulate_grad_batches: 1\n# accelerator: null # set to 'ddp' for distributed\n# amp_backend: native # 'native' | 'apex'\n"
  },
  {
    "path": "setup_env.sh",
    "chars": 489,
    "preview": "#!/bin/bash\n\n# Shell script to set environment variables when running code in this repository.\n# Usage:\n#     source set"
  },
  {
    "path": "slurm_scripts/dump_vep_embeddings.sh",
    "chars": 2653,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00      "
  },
  {
    "path": "slurm_scripts/run_genomics_benchmark.sh",
    "chars": 2337,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 96:00:00         "
  },
  {
    "path": "slurm_scripts/run_genomics_benchmark_cnn.sh",
    "chars": 1949,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 48:00:00         "
  },
  {
    "path": "slurm_scripts/run_nucleotide_transformer.sh",
    "chars": 2302,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 96:00:00         "
  },
  {
    "path": "slurm_scripts/run_pretrain_caduceus.sh",
    "chars": 2204,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00      "
  },
  {
    "path": "slurm_scripts/run_pretrain_hyena.sh",
    "chars": 1734,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00      "
  },
  {
    "path": "slurm_scripts/run_pretrain_mamba.sh",
    "chars": 1833,
    "preview": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00      "
  },
  {
    "path": "slurm_scripts/wrapper_run_genomics.sh",
    "chars": 3595,
    "preview": "#!/bin/bash\n\n# Choose one from below\n\n## Hyena\n## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadn"
  },
  {
    "path": "slurm_scripts/wrapper_run_genomics_cnn.sh",
    "chars": 629,
    "preview": "#!/bin/bash\n\nLOG_DIR=\"../watch_folder/gb_cv5/cnn_baseline\"\nmkdir -p \"${LOG_DIR}\"\nexport_str=\"ALL\"\nfor TASK in \"dummy_mou"
  },
  {
    "path": "slurm_scripts/wrapper_run_nucleotide_transformer.sh",
    "chars": 2642,
    "preview": "#!/bin/bash\n\n# Choose one from below\n\n## Caduceus NO POST HOC\n#LOG_DIR=\"../watch_folder/nt_cv10_ep20/caduceus\"\n#CONFIG_P"
  },
  {
    "path": "src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/callbacks/params.py",
    "chars": 1416,
    "preview": "\"\"\"Callback to log the number of parameters of the model.\n\n\"\"\"\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.ut"
  },
  {
    "path": "src/callbacks/timer.py",
    "chars": 3685,
    "preview": "\"\"\"Callback to monitor the speed of each step and each epoch.\n\nhttps://github.com/HazyResearch/transformers/blob/master/"
  },
  {
    "path": "src/callbacks/validation.py",
    "chars": 1369,
    "preview": "\"\"\"Check validation every n **global** steps.\n\nPytorch Lightning has a `val_check_interval` parameter that checks valida"
  },
  {
    "path": "src/dataloaders/__init__.py",
    "chars": 57,
    "preview": "from . import genomics\nfrom .base import SequenceDataset\n"
  },
  {
    "path": "src/dataloaders/base.py",
    "chars": 7296,
    "preview": "\"\"\" Datasets for core experimental results.\n\n\"\"\"\n\nimport os\nfrom functools import partial\nfrom pathlib import Path\n\nimpo"
  },
  {
    "path": "src/dataloaders/datasets/genomic_bench_dataset.py",
    "chars": 4566,
    "preview": "\"\"\"Genomic Benchmarks Dataset.\n\nFrom: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks\n\"\"\"\n\nfrom pathlib import P"
  },
  {
    "path": "src/dataloaders/datasets/hg38_char_tokenizer.py",
    "chars": 5939,
    "preview": "\"\"\" \nFrom: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py\n\nCharacterToken"
  },
  {
    "path": "src/dataloaders/datasets/hg38_dataset.py",
    "chars": 7514,
    "preview": "\"\"\"Dataset for sampling arbitrary intervals from the human genome.\n\n\"\"\"\n\nimport math\nfrom pathlib import Path\n\nimport pa"
  },
  {
    "path": "src/dataloaders/datasets/nucleotide_transformer_dataset.py",
    "chars": 3584,
    "preview": "\"\"\"Nucleotide Transformer Benchmarks Dataset.\n\nFrom: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_"
  },
  {
    "path": "src/dataloaders/fault_tolerant_sampler.py",
    "chars": 4660,
    "preview": "# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytor"
  },
  {
    "path": "src/dataloaders/genomics.py",
    "chars": 17628,
    "preview": "\"\"\"Dataloaders for genomics datasets, including pretraining and downstream tasks.\n\n    - Adapted from:\n        https://g"
  },
  {
    "path": "src/dataloaders/utils/mlm.py",
    "chars": 1913,
    "preview": "import torch\n\n\ndef mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None"
  },
  {
    "path": "src/dataloaders/utils/rc.py",
    "chars": 661,
    "preview": "\"\"\"Utility functions for reverse complementing DNA sequences.\n\n\"\"\"\n\nfrom random import random\n\nSTRING_COMPLEMENT_MAP = {"
  },
  {
    "path": "src/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/models/baseline/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/models/baseline/genomics_benchmark_cnn.py",
    "chars": 1966,
    "preview": "\"\"\"Genomics Benchmark CNN model.\n\nAdapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/gen"
  },
  {
    "path": "src/models/nn/__init__.py",
    "chars": 35,
    "preview": "from .activation import Activation\n"
  },
  {
    "path": "src/models/nn/activation.py",
    "chars": 2876,
    "preview": "\"\"\"Utilities for activation functions.\"\"\"\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as "
  },
  {
    "path": "src/models/nn/adaptive_softmax.py",
    "chars": 16332,
    "preview": "# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 ("
  },
  {
    "path": "src/models/nn/utils.py",
    "chars": 3722,
    "preview": "\"\"\" Utility wrappers around modules to let them handle Args and extra arguments \"\"\"\n\nimport inspect\nfrom functools impor"
  },
  {
    "path": "src/models/sequence/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/models/sequence/dna_embedding.py",
    "chars": 10395,
    "preview": "\"\"\"DNA Embedding Model.\n\nBackbones from LM pre-training models, used for downstream tasks.\n\"\"\"\n\nfrom functools import pa"
  },
  {
    "path": "src/models/sequence/hyena.py",
    "chars": 14719,
    "preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\ntry:\n    f"
  },
  {
    "path": "src/models/sequence/long_conv_lm.py",
    "chars": 17796,
    "preview": "import copy\nimport math\nimport re\nfrom collections import namedtuple\nfrom functools import partial\n\nimport torch\nimport "
  },
  {
    "path": "src/ops/fftconv.py",
    "chars": 4672,
    "preview": "import math\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\n\nfrom fftconv import fftconv_fwd"
  },
  {
    "path": "src/tasks/decoders.py",
    "chars": 6977,
    "preview": "\"\"\"Decoder heads.\n\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport src.models.nn.utils a"
  },
  {
    "path": "src/tasks/encoders.py",
    "chars": 2398,
    "preview": "from torch import nn\n\nimport src.models.nn.utils as U\nimport src.utils as utils\n\n\nclass Encoder(nn.Module):\n    \"\"\"Encod"
  },
  {
    "path": "src/tasks/metrics.py",
    "chars": 12663,
    "preview": "import math\nfrom functools import partial\n\nimport torch\nimport torch.nn.functional as F\nimport torchmetrics.functional a"
  },
  {
    "path": "src/tasks/tasks.py",
    "chars": 15398,
    "preview": "import inspect\nfrom typing import List\n\nimport torch.nn as nn\nfrom einops import rearrange\n\nimport src.models.nn.utils a"
  },
  {
    "path": "src/tasks/torchmetrics.py",
    "chars": 4532,
    "preview": "# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py\n# But we compute th"
  },
  {
    "path": "src/utils/__init__.py",
    "chars": 79,
    "preview": "from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate\n"
  },
  {
    "path": "src/utils/config.py",
    "chars": 3548,
    "preview": "\"\"\"Utilities for dealing with collection objects (lists, dicts) and configs.\n\n\"\"\"\n\nimport functools\nfrom typing import S"
  },
  {
    "path": "src/utils/optim/schedulers.py",
    "chars": 3534,
    "preview": "\"\"\"Custom learning rate schedulers\"\"\"\n\nimport math\nimport warnings\nimport torch\n\nfrom timm.scheduler import CosineLRSche"
  },
  {
    "path": "src/utils/optim_groups.py",
    "chars": 7052,
    "preview": "\"\"\"Utilities for special optimizer hyperparameters.\n\ngroup_parameters_for_optimizer is a modification of timm's optimize"
  },
  {
    "path": "src/utils/registry.py",
    "chars": 2648,
    "preview": "\"\"\"Class registry for models, layers, optimizers, and schedulers.\n\n\"\"\"\n\noptimizer = {\n    \"adam\": \"torch.optim.Adam\",\n  "
  },
  {
    "path": "src/utils/train.py",
    "chars": 6135,
    "preview": "\"\"\" Utils for the training loop.\n\nCopied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.p"
  },
  {
    "path": "train.py",
    "chars": 28722,
    "preview": "\"\"\"Main training entry point for pre-training and downstream fine-tuning.\n\n\"\"\"\n\nimport json\nimport os\nimport random\nimpo"
  },
  {
    "path": "vep_embeddings.py",
    "chars": 20804,
    "preview": "\"\"\"Dump model embeddings for VEP classification task.\n\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\nfrom"
  },
  {
    "path": "vep_svm.ipynb",
    "chars": 15382,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"db878bc1\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports and"
  }
]

