[
  {
    "path": ".gitignore",
    "content": "data.tar.gz\n*.tsf\n*.ckpt\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\n*.lprof\n\n.DS_Store\n.idea/\noutputs/\n\n# slurm log files\nwatch_folder/\n\ndata\n\n# Created by https://www.gitignore.io/api/python\n# Edit at https://www.gitignore.io/?templates=python\n\n### Python ###\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# Mr Developer\n.mr.developer.cfg\n.project\n.pydevproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# End of https://www.gitignore.io/api/python\n"
  },
  {
    "path": "LICENSE",
    "content": "                                Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n    <img src=\"assets/Caduceus_image.png\" alt=\"Caduceus\" width=\"200\"/>\n</p>\n\n\n# Caduceus &#9764;: Bi-Directional Equivariant Long-Range DNA Sequence Modeling\n[[Blog]](https://caduceus-dna.github.io/) &nbsp; | &nbsp; [[arXiv]](https://arxiv.org/abs/2403.03234) &nbsp; | &nbsp; [[HuggingFace 🤗]](https://huggingface.co/collections/kuleshov-group/caducues-65dcb89b4f54e416ef61c350)\n\nThis repository contains code for reproducing the results in the paper \"Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling,\" [Schiff et al. (2024)](https://arxiv.org/abs/2403.03234).\n\n## Using Caduceus with 🤗\n<a name=\"HF\"></a>\nWe have uploaded a pre-trained Caduceus model to the Huggingface hub.\nThe available models are:\n- Caduceus-Ph: [kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16)\n  - Trained on sequences of length 131k, with a model size of 256 and 16 layers.\n  - Trained for 50k steps and batch size of 8.\n  - Trained with reverse-complement (RC) data augmentation.\n- Caduceus-PS: [kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16](https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16)\n  - Trained on sequences of length 131k, with a model size of 256 and 16 layers.\n  - Trained for 50k steps and batch size of 8.\n  - Model is RC equivariant, hence no RC data augmentation is required.\n\nYou can either use the pre-trained model directly within your trainer scripts or modify the config that initializes the model.\n\nTo use the pre-trained model for masked language modeling, use the following snippet:\n```python\nfrom transformers import AutoModelForMaskedLM, AutoTokenizer\n\n# See the `Caduceus` collection page on the hub for list of available models.\nmodel_name = \"kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16\"\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nmodel = AutoModelForMaskedLM.from_pretrained(model_name)\n```\n\nAlternatively, you can instantiate a model from scratch to train on your own data as follows:\n```python\nfrom transformers import AutoConfig, AutoModelForMaskedLM\n\n# Add any config overrides here, see the `config.json` file on the hub for details.\nconfig_overrides = {}\n# See the `Caduceus` collection page on the hub for list of available models.\nconfig = AutoConfig.from_pretrained(\n \"kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16\",\n **config_overrides,\n)\nmodel = AutoModelForMaskedLM.from_config(config)\n```\n\n## Getting started in this repository\n<a name=\"getting_started\"></a>\n\nTo get started, create a conda environment containing the required dependencies.\n\n```bash\nconda env create -f caduceus_env.yml\n```\n\nActivate the environment.\n\n```bash\nconda activate caduceus_env\n```\n\nCreate the following directories to store saved models and slurm logs:\n```bash\nmkdir outputs\nmkdir watch_folder\n```\n\n## Reproducing Experiments\n\nBelow, we describe the steps required for reproducing the experiments in the paper.\nThroughout, the main entry point for running experiments is the [`train.py`](./train.py) script.\nWe also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`slurm_scripts/`](./slurm_scripts) directory.\n\n### Pretraining on Human Reference Genome\n<a name=\"pretraining\"></a>\n(Data downloading instructions are copied from [HyenaDNA 
repo](https://github.com/HazyResearch/hyena-dna?tab=readme-ov-file#pretraining-on-human-reference-genome))\n\nFirst, download the Human Reference Genome data.\nIt comprises 2 files: one with all the sequences (the `.fasta` file) and one with the intervals we use (the `.bed` file).\n\nThe file structure should look like\n\n```\ndata\n|-- hg38/\n    |-- hg38.ml.fa\n    |-- human-sequences.bed\n```\n\nDownload the fasta (.fa format) file (the entire human genome) into `./data/hg38`.\nThe whole genome contains ~24 chromosomes (merged into 1 file); each chromosome is essentially one continuous sequence.\nThen download the `.bed` file with sequence intervals (it contains the chromosome name, start, end, and split, which allow you to retrieve sequences from the fasta file).\n```bash\nmkdir -p data/hg38/\ncurl https://storage.googleapis.com/basenji_barnyard2/hg38.ml.fa.gz > data/hg38/hg38.ml.fa.gz\ngunzip data/hg38/hg38.ml.fa.gz  # unzip the fasta file\ncurl https://storage.googleapis.com/basenji_barnyard2/sequences_human.bed > data/hg38/human-sequences.bed\n```\n\nLaunch a pretraining run from the command line:\n\n```bash\npython -m train \\\n  experiment=hg38/hg38 \\\n  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \\\n  dataset.max_length=1024 \\\n  dataset.batch_size=1024 \\\n  dataset.mlm=true \\\n  dataset.mlm_probability=0.15 \\\n  dataset.rc_aug=false \\\n  model=caduceus \\\n  model.config.d_model=128 \\\n  model.config.n_layer=4 \\\n  model.config.bidirectional=true \\\n  model.config.bidirectional_strategy=add \\\n  model.config.bidirectional_weight_tie=true \\\n  model.config.rcps=true \\\n  optimizer.lr=\"8e-3\" \\\n  train.global_batch_size=1024 \\\n  trainer.max_steps=10000 \\\n  +trainer.val_check_interval=10000 \\\n  wandb=null\n```\n\nAlternatively, if using a cluster that has `slurm` installed, adapt the scripts below:\n```\nslurm_scripts\n|-- run_pretrain_caduceus.sh\n|-- run_pretrain_hyena.sh\n|-- run_pretrain_mamba.sh\n```\n\nand run the training as a batch job:\n```bash\ncd slurm_scripts\nsbatch run_pretrain_caduceus.sh\n```\n
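\nFor reference, `dataset.mlm=true` with `dataset.mlm_probability=0.15` corresponds to the standard masked-language-modeling recipe: roughly 15% of positions are replaced by a mask token, and only those positions contribute to the cross-entropy loss (all other positions get the ignore index). The snippet below is only a minimal sketch of that idea in plain PyTorch; the mask token id and vocabulary are made up for illustration, and the actual masking is handled by this repo's data pipeline (which may, e.g., keep or randomize a fraction of the selected positions):\n```python\nimport torch\n\n\ndef mask_tokens(input_ids, mask_token_id, mlm_probability=0.15, ignore_index=-100):\n    # Sample which positions to mask.\n    mask = torch.rand(input_ids.shape) < mlm_probability\n    # Labels: only masked positions are predicted; the rest are ignored by the loss.\n    labels = input_ids.clone()\n    labels[~mask] = ignore_index\n    # Inputs: masked positions are replaced by the mask token.\n    masked_inputs = input_ids.clone()\n    masked_inputs[mask] = mask_token_id\n    return masked_inputs, labels\n\n\n# Toy batch over a hypothetical 16-token vocabulary where id 15 plays the role of the mask token.\nbatch = torch.randint(0, 10, (2, 1024))\ninputs, labels = mask_tokens(batch, mask_token_id=15)\n```\n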
\n### GenomicBenchmarks\n<a name=\"genomicbenchmarks\"></a>\n\nThe [GenomicBenchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) suite presented in [Grešová et al. (2023)](https://bmcgenomdata.biomedcentral.com/articles/10.1186/s12863-023-01123-8) comprises 8 classification tasks.\n\nWe can launch a downstream fine-tuning run on one of the tasks using the sample command below:\n```bash\npython -m train \\\n    experiment=hg38/genomic_benchmark \\\n    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \\\n    dataset.dataset_name=\"dummy_mouse_enhancers_ensembl\" \\\n    dataset.train_val_split_seed=1 \\\n    dataset.batch_size=256 \\\n    dataset.rc_aug=false \\\n    +dataset.conjoin_train=false \\\n    +dataset.conjoin_test=false \\\n    loader.num_workers=2 \\\n    model=caduceus \\\n    model._name_=dna_embedding_caduceus \\\n    +model.config_path=\"<path to model_config.json>\" \\\n    +model.conjoin_test=false \\\n    +decoder.conjoin_train=true \\\n    +decoder.conjoin_test=false \\\n    optimizer.lr=\"1e-3\" \\\n    trainer.max_epochs=10 \\\n    train.pretrained_model_path=\"<path to .ckpt file>\" \\\n    wandb=null\n```\n\nThis sample run will fine-tune a pre-trained Caduceus-PS model on the `dummy_mouse_enhancers_ensembl` task.\nNote some of the additional arguments present here, relative to the pre-training command from [above](#pretraining):\n- `model.config_path` contains the path to the model config that was saved during pre-training.\nThis config is saved to the run directory of the pre-training experiment.\n- `train.pretrained_model_path` contains the path to the pre-trained model checkpoint.\n- `dataset.conjoin_train` determines whether the dataset will return a single sequence (`dataset.conjoin_train=false`) or the concatenation of a sequence and its reverse complement along `dim=-1`, during downstream fine-tuning training.\n- `dataset.conjoin_test` is the same as above, but for inference (e.g., validation / test).\n- `decoder.conjoin_train` determines whether the prediction head (a mean pooling and linear projection in the case of the Genomics Benchmark) expects an input tensor of shape `(batch_size, seq_len, d_model)` or `(batch_size, seq_len, d_model, 2)` during downstream fine-tuning training.\nWhen set to `true`, the decoder is run on `input[..., 0]` and `input[..., 1]` and the results are averaged to produce the final prediction (see the sketch at the end of this section).\n- `decoder.conjoin_test` is the same as above, but for inference (e.g., validation / test).\n\nNote that this benchmark only contains training and test splits for each task.\nTherefore, to have a more principled evaluation, we randomly split the training data into training and validation sets (90/10) using the `dataset.train_val_split_seed` argument.\nWe perform early stopping on the validation metric (accuracy) and repeat this for 5 random seeds.\n\nAs with [pre-training](#pretraining), we can also launch the fine-tuning run as a batch job using the provided [`run_genomic_benchmark.sh`](./slurm_scripts/run_genomics_benchmark.sh) script.\nWe also provide a helper shell script [`wrapper_run_genomics.sh`](./slurm_scripts/wrapper_run_genomics.sh) that can be used to launch multiple fine-tuning runs in parallel.\n\nFinally, the [`run_genomics_benchmark_cnn.sh`](./slurm_scripts/run_genomics_benchmark_cnn.sh) script can be used to train the CNN baseline for this experiment from scratch on the downstream tasks.\n
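\nTo make the `decoder.conjoin_train` / `decoder.conjoin_test` behavior described above concrete, here is a minimal sketch (using a stand-in linear head rather than the repo's actual decoder module) of how a `(batch_size, seq_len, d_model, 2)` input can be pooled, scored with shared weights on each strand, and averaged:\n```python\nimport torch\nfrom torch import nn\n\nd_model, num_classes = 256, 2\ndecoder = nn.Linear(d_model, num_classes)  # stand-in for the mean-pool + linear projection head\n\nx = torch.randn(8, 1024, d_model, 2)  # [..., 0] = forward strand, [..., 1] = reverse complement\npooled = x.mean(dim=1)                # mean-pool over sequence length -> (8, d_model, 2)\n# Run the same head on each strand and average the two predictions.\nlogits = (decoder(pooled[..., 0]) + decoder(pooled[..., 1])) / 2\n```\n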
\n### Nucleotide Transformer datasets\n<a name=\"nucleotidetransformer\"></a>\n\nThe Nucleotide Transformer suite of tasks was proposed in [Dalla-Torre et al. (2023)](https://www.biorxiv.org/content/10.1101/2023.01.11.523679v1).\nThe data is available on HuggingFace: [InstaDeepAI/nucleotide_transformer_downstream_tasks](https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks).\n\nWe can launch a downstream fine-tuning run on one of the tasks using the sample command below:\n```bash\npython -m train \\\n    experiment=hg38/nucleotide_transformer \\\n    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \\\n    dataset.dataset_name=\"${task}\" \\\n    dataset.train_val_split_seed=${seed} \\\n    dataset.batch_size=${batch_size} \\\n    dataset.rc_aug=\"${rc_aug}\" \\\n    +dataset.conjoin_test=\"${CONJOIN_TEST}\" \\\n    loader.num_workers=2 \\\n    model._name_=dna_embedding_caduceus \\\n    +model.config_path=\"<path to model_config.json>\" \\\n    +model.conjoin_test=false \\\n    +decoder.conjoin_train=true \\\n    +decoder.conjoin_test=false \\\n    optimizer.lr=\"1e-3\" \\\n    trainer.max_epochs=20 \\\n    train.pretrained_model_path=\"<path to .ckpt file>\" \\\n    wandb=null\n```\n\nWe can also launch these as batch jobs (see [`run_nucleotide_transformer.sh`](./slurm_scripts/run_nucleotide_transformer.sh) and [`wrapper_run_nucleotide_transformer.sh`](./slurm_scripts/wrapper_run_nucleotide_transformer.sh) for details).\n\n### eQTL SNP Variant Effect Prediction\n<a name=\"vep\"></a>\nThis task comes from the recently proposed Long Range Benchmark (LRB) in [Kao et al., 2023](https://llms4science-community.github.io/papers/LLMs4Bio24_paper_12.pdf).\nThe data is available on HuggingFace: [InstaDeepAI/genomics-long-range-benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark).\nFor this task, we fit a model to the pre-trained and frozen embeddings of the DNA language models.\nTherefore, to perform the evaluation, we proceed in 2 steps:\n- **Step 1: Extract the embeddings** from the pre-trained model:\nRun the [`vep_embeddings.py`](./vep_embeddings.py) script to extract the embeddings from the pre-trained model.\nSee the example below:\n```bash\ntorchrun \\\n    --standalone \\\n    --nnodes=1 \\\n    --nproc-per-node=8 \\\n    vep_embeddings.py \\\n      --num_workers=2 \\\n      --seq_len=131072  \\\n      --bp_per_token=1  \\\n      --embed_dump_batch_size=1 \\\n      --name=\"caduceus-ps_downstream-seqlen=131k\"  \\\n      --model_name_or_path=\"kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16\" \\\n      --rcps\n```\n\nThe `--rcps` flag is used to indicate that the model is reverse-complement equivariant.\nWhen using other models, set this flag to false with `--no-rcps`.\nTo speed this step up, this script utilizes torch distributed data parallelism.\n\nPlease refer to the slurm script provided in [`slurm_scripts/dump_vep_embeddings.sh`](./slurm_scripts/dump_vep_embeddings.sh)\nto launch this step as a batch job.\n\n- **Step 2: Fit an SVM model to the embeddings** using this notebook: [`vep_svm.ipynb`](./vep_svm.ipynb).\n
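\nFor a rough picture of what Step 2 involves (the notebook is the authoritative reference), one can fit a scikit-learn SVM on features derived from the dumped embeddings along the following lines; the file names and feature construction here are hypothetical placeholders:\n```python\nimport numpy as np\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.svm import SVC\n\n# Hypothetical arrays: one feature vector per variant (e.g., derived from the reference and\n# alternate allele embeddings around the SNP) and the corresponding binary eQTL labels.\nX = np.load('vep_features.npy')\ny = np.load('vep_labels.npy')\n\nclf = make_pipeline(StandardScaler(), SVC(kernel='rbf'))\nclf.fit(X, y)\nprint(clf.score(X, y))  # training accuracy; evaluate on a held-out split in practice\n```\n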
\n## Citation\n<a name=\"citation\"></a>\n\nIf you find our work useful, please cite our paper using the following:\n```\n@article{schiff2024caduceus,\n  title={Caduceus: Bi-Directional Equivariant Long-Range DNA Sequence Modeling},\n  author={Schiff, Yair and Kao, Chia-Hsiang and Gokaslan, Aaron and Dao, Tri and Gu, Albert and Kuleshov, Volodymyr},\n  journal={arXiv preprint arXiv:2403.03234},\n  year={2024}\n}\n```\n\n## Acknowledgements\n<a name=\"acknowledgements\"></a>\nThis repository is adapted from the [HyenaDNA repo](https://github.com/HazyResearch/hyena-dna) and leverages much of the training, data loading, and logging infrastructure defined there.\nHyenaDNA was originally derived from the [S4](https://github.com/state-spaces/s4) and [Safari](https://github.com/HazyResearch/safari) repositories.\n\nWe would like to thank Evan Trop and the [InstaDeep](https://www.instadeep.com/) team for useful discussions about the [Nucleotide Transformer leaderboard](https://huggingface.co/spaces/InstaDeepAI/nucleotide_transformer_benchmark)\nand the Long Range Benchmark task.\n\nFinally, we would like to thank [MosaicML](https://www.mosaicml.com/) for providing compute resources for some of the pre-training experiments.\n"
  },
  {
    "path": "caduceus/__init__.py",
    "content": "\"\"\"Hugging Face config, model, and tokenizer for Caduceus.\n\n\"\"\"\n\nfrom .configuration_caduceus import CaduceusConfig\nfrom .modeling_caduceus import Caduceus, CaduceusForMaskedLM, CaduceusForSequenceClassification\nfrom .tokenization_caduceus import CaduceusTokenizer\n"
  },
  {
    "path": "caduceus/configuration_caduceus.py",
    "content": "\"\"\"Caduceus config for Hugging Face.\n\n\"\"\"\n\nfrom typing import Optional, Union\n\nfrom transformers import PretrainedConfig\n\n\nclass CaduceusConfig(PretrainedConfig):\n    \"\"\"Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.\"\"\"\n    model_type = \"caduceus\"\n\n    def __init__(\n            self,\n            # From original MambaConfig\n            d_model: int = 2560,\n            n_layer: int = 64,\n            vocab_size: int = 50277,\n            ssm_cfg: Optional[dict] = None,\n            rms_norm: bool = True,\n            residual_in_fp32: bool = True,\n            fused_add_norm: bool = True,\n            pad_vocab_size_multiple: int = 8,\n\n            # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm\n            norm_epsilon: float = 1e-5,\n\n            # Used in init_weights\n            initializer_cfg: Optional[dict] = None,\n\n            # Caduceus-specific params\n            bidirectional: bool = True,\n            bidirectional_strategy: Union[str, None] = \"add\",\n            bidirectional_weight_tie: bool = True,\n            rcps: bool = False,\n            complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead\n            **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.d_model = d_model\n        self.n_layer = n_layer\n        self.vocab_size = vocab_size\n        self.ssm_cfg = ssm_cfg\n        self.rms_norm = rms_norm\n        self.residual_in_fp32 = residual_in_fp32\n        self.fused_add_norm = fused_add_norm\n        self.pad_vocab_size_multiple = pad_vocab_size_multiple\n        self.norm_epsilon = norm_epsilon\n        self.initializer_cfg = initializer_cfg\n        self.bidirectional = bidirectional\n        self.bidirectional_strategy = bidirectional_strategy\n        self.bidirectional_weight_tie = bidirectional_weight_tie\n        self.rcps = rcps\n        self.complement_map = complement_map\n"
  },
  {
    "path": "caduceus/modeling_caduceus.py",
    "content": "\"\"\"Caduceus model for Hugging Face.\n\n\"\"\"\n\nimport inspect\nimport math\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom mamba_ssm.modules.mamba_simple import Mamba\ntry:\n    from mamba_ssm.modules.mamba_simple import Block  # Legacy mambav1 file structure\nexcept ImportError:\n    from mamba_ssm.modules.block import Block  # mambav2 file structure\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom transformers import PreTrainedModel\nfrom transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput\n\ntry:\n    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure\nexcept ImportError:\n    try:\n        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure\n    except ImportError:\n        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n\nfrom .configuration_caduceus import CaduceusConfig\nfrom .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock\n\n\ndef create_block(\n        d_model,\n        ssm_cfg=None,\n        norm_epsilon=1e-5,\n        rms_norm=False,\n        residual_in_fp32=False,\n        fused_add_norm=False,\n        layer_idx=None,\n        bidirectional=True,\n        bidirectional_strategy=\"add\",\n        bidirectional_weight_tie=True,\n        rcps=False,\n        device=None,\n        dtype=None,\n):\n    \"\"\"Create Caduceus block.\n\n    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py\n    \"\"\"\n    if ssm_cfg is None:\n        ssm_cfg = {}\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    bidirectional_kwargs = {\n        \"bidirectional\": bidirectional,\n        \"bidirectional_strategy\": bidirectional_strategy,\n        \"bidirectional_weight_tie\": bidirectional_weight_tie,\n    }\n    mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)\n    norm_cls = partial(\n        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs\n    )\n    block_cls = RCPSMambaBlock if rcps else Block\n    # mambav2 compatibility\n    if \"mlp_cls\" in inspect.signature(block_cls.__init__).parameters:\n        block = block_cls(\n            d_model,\n            mixer_cls,\n            mlp_cls=nn.Identity,\n            norm_cls=norm_cls,\n            fused_add_norm=fused_add_norm,\n            residual_in_fp32=residual_in_fp32,\n        )\n    else:\n        block = block_cls(\n            d_model,\n            mixer_cls,\n            norm_cls=norm_cls,\n            fused_add_norm=fused_add_norm,\n            residual_in_fp32=residual_in_fp32,\n        )\n    block.layer_idx = layer_idx\n    return block\n\n\nclass BiMambaWrapper(nn.Module):\n    \"\"\"Thin wrapper around Mamba to support bi-directionality.\"\"\"\n\n    def __init__(\n            self,\n            d_model: int,\n            bidirectional: bool = True,\n            bidirectional_strategy: Optional[str] = \"add\",\n            bidirectional_weight_tie: bool = True,\n            **mamba_kwargs,\n    ):\n        super().__init__()\n        if bidirectional and bidirectional_strategy is None:\n            bidirectional_strategy = \"add\"  # Default strategy: `add`\n        if bidirectional and bidirectional_strategy not in [\"add\", \"ew_multiply\"]:\n            raise 
NotImplementedError(f\"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!\")\n        self.bidirectional = bidirectional\n        self.bidirectional_strategy = bidirectional_strategy\n        self.mamba_fwd = Mamba(\n            d_model=d_model,\n            **mamba_kwargs\n        )\n        if bidirectional:\n            self.mamba_rev = Mamba(\n                d_model=d_model,\n                **mamba_kwargs\n            )\n            if bidirectional_weight_tie:  # Tie in and out projections (where most of param count lies)\n                self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight\n                self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias\n                self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight\n                self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias\n        else:\n            self.mamba_rev = None\n\n    def forward(self, hidden_states, inference_params=None):\n        \"\"\"Bidirectional-enabled forward pass\n\n        hidden_states: (B, L, D)\n        Returns: same shape as hidden_states\n        \"\"\"\n        out = self.mamba_fwd(hidden_states, inference_params=inference_params)\n        if self.bidirectional:\n            out_rev = self.mamba_rev(\n                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension\n                inference_params=inference_params\n            ).flip(dims=(1,))  # Flip back for combining with forward hidden states\n            if self.bidirectional_strategy == \"add\":\n                out = out + out_rev\n            elif self.bidirectional_strategy == \"ew_multiply\":\n                out = out * out_rev\n            else:\n                raise NotImplementedError(f\"`{self.bidirectional_strategy}` for bi-directionality not implemented!\")\n        return out\n\n\nclass CaduceusEmbeddings(nn.Module):\n    def __init__(\n            self,\n            config: CaduceusConfig,\n            device=None,\n            dtype=None,\n    ):\n        super().__init__()\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        if config.rcps:\n            self.word_embeddings = RCPSEmbedding(\n                config.vocab_size, config.d_model, config.complement_map, **factory_kwargs\n            )\n        else:\n            self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)\n\n    def forward(self, input_ids):\n        \"\"\"\n            input_ids: (batch, seqlen)\n        \"\"\"\n        return self.word_embeddings(input_ids)\n\n\nclass CaduceusMixerModel(nn.Module):\n    def __init__(\n            self,\n            config: CaduceusConfig,\n            device=None,\n            dtype=None,\n    ) -> None:\n        super().__init__()\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n        self.fused_add_norm = config.fused_add_norm\n        self.rcps = config.rcps\n        self.residual_in_fp32 = config.residual_in_fp32\n\n        self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)\n\n        # Mamba changes the order of residual and layer norm:\n        # Instead of LN -> Attn / MLP -> Add, we do:\n        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and\n        # the main branch (output of MLP / Mixer). 
The model definition is unchanged.\n        # This is for performance reason: we can fuse add + layer_norm.\n        if config.fused_add_norm:\n            if layer_norm_fn is None or rms_norm_fn is None:\n                raise ImportError(\"Failed to import Triton LayerNorm / RMSNorm kernels\")\n\n        self.layers = nn.ModuleList(\n            [\n                create_block(\n                    config.d_model,\n                    ssm_cfg=config.ssm_cfg,\n                    norm_epsilon=config.norm_epsilon,\n                    rms_norm=config.rms_norm,\n                    residual_in_fp32=config.residual_in_fp32,\n                    fused_add_norm=config.fused_add_norm,\n                    layer_idx=i,\n                    bidirectional=config.bidirectional,\n                    bidirectional_strategy=config.bidirectional_strategy,\n                    bidirectional_weight_tie=config.bidirectional_weight_tie,\n                    rcps=config.rcps,\n                    **factory_kwargs,\n                )\n                for i in range(config.n_layer)\n            ]\n        )\n\n        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(\n            config.d_model, eps=config.norm_epsilon, **factory_kwargs\n        )\n        self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)\n\n    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):\n        \"\"\"Mixer forward.\"\"\"\n        all_hidden_states = []\n        if inputs_embeds is not None:\n            hidden_states = inputs_embeds\n        else:\n            hidden_states = self.embeddings(input_ids)\n\n        residual = None\n        for layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states.append(hidden_states)\n            # TODO: Add support for gradient checkpointing\n            hidden_states, residual = layer(\n                hidden_states, residual, inference_params=None\n            )\n\n        if not self.fused_add_norm:\n            if self.rcps:\n                # Set prenorm=False here since we don't need the residual\n                hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)\n            else:\n                residual = (hidden_states + residual) if residual is not None else hidden_states\n                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n        else:\n            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn\n            if self.rcps:\n                # Set prenorm=False here since we don't need the residual\n                hidden_states_fwd = fused_add_norm_fn(\n                    hidden_states[..., :hidden_states.shape[-1] // 2],\n                    self.norm_f.weight,\n                    self.norm_f.bias,\n                    eps=self.norm_f.eps,\n                    residual=residual[..., :hidden_states.shape[-1] // 2],\n                    prenorm=False,\n                    residual_in_fp32=self.residual_in_fp32,\n                )\n                hidden_states_rc = fused_add_norm_fn(\n                    hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),\n                    self.norm_f.weight,\n                    self.norm_f.bias,\n                    eps=self.norm_f.eps,\n                    residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),\n                    prenorm=False,\n                    
residual_in_fp32=self.residual_in_fp32,\n                )\n                hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)\n            else:\n                # Set prenorm=False here since we don't need the residual\n                hidden_states = fused_add_norm_fn(\n                    hidden_states,\n                    self.norm_f.weight,\n                    self.norm_f.bias,\n                    eps=self.norm_f.eps,\n                    residual=residual,\n                    prenorm=False,\n                    residual_in_fp32=self.residual_in_fp32,\n                )\n            if output_hidden_states:\n                all_hidden_states.append(hidden_states)\n        return hidden_states, all_hidden_states\n\n\ndef cross_entropy(logits, y, ignore_index=-100):\n    \"\"\"Cross entropy loss.\"\"\"\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    return F.cross_entropy(logits, y, ignore_index=ignore_index)\n\n\ndef weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):\n    \"\"\"Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).\"\"\"\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction=\"none\")\n    loss_weights = loss_weights.view(-1)\n    loss_weights[y == ignore_index] = 0.0\n    # TODO: Follows GPN implementation, but should we remove weight normalization?\n    return (ce * (loss_weights / loss_weights.sum())).sum()\n\n\nclass CaduceusPreTrainedModel(PreTrainedModel):\n    \"\"\"PreTrainedModel wrapper for Caduceus backbone.\"\"\"\n    config_class = CaduceusConfig\n    base_model_prefix = \"caduceus\"\n    supports_gradient_checkpointing = False\n    _no_split_modules = [\"BiMambaWrapper\"]\n\n    def _init_weights(\n            self,\n            module,\n            initializer_range=0.02,  # Now only used for embedding layer.\n            **kwargs,\n    ):\n        \"\"\"Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py\"\"\"\n\n        n_layer = self.config.n_layer\n        initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}\n        rescale_prenorm_residual = initialized_cfg.get(\"rescale_prenorm_residual\", True)\n        initializer_range = initialized_cfg.get(\"initializer_range\", initializer_range)\n        n_residuals_per_layer = initialized_cfg.get(\"n_residuals_per_layer\", 1)\n\n        if isinstance(module, nn.Linear):\n            if module.bias is not None:\n                if not getattr(module.bias, \"_no_reinit\", False):\n                    nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            nn.init.normal_(module.weight, std=initializer_range)\n\n        if rescale_prenorm_residual:\n            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n            #   > A modified initialization which accounts for the accumulation on the residual path with model depth.\n            #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of\n            #   residual layers.\n            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n            #\n            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n            for name, p in module.named_parameters():\n                if 
name in [\"out_proj.weight\", \"fc2.weight\"]:\n                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)\n                    # We need to reinit p since this code could be called multiple times\n                    # Having just p *= scale would repeatedly scale it down\n                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))\n                    with torch.no_grad():\n                        p /= math.sqrt(n_residuals_per_layer * n_layer)\n\n\nclass Caduceus(CaduceusPreTrainedModel):\n    \"\"\"Caduceus model that can be instantiated using HF patterns.\"\"\"\n    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):\n        super().__init__(config)\n\n        if config.rcps:\n            assert config.complement_map is not None, \"Complement map must be provided for RCPS.\"\n\n        # Adjust vocab size and complement maps if vocab padding is set.\n        if config.vocab_size % config.pad_vocab_size_multiple != 0:\n            config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)\n        if config.complement_map is not None and config.vocab_size > len(config.complement_map):\n            for i in range(len(config.complement_map), config.vocab_size):\n                config.complement_map[i] = i\n\n        self.config = config\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:\n        \"\"\"HF-compatible forward method.\"\"\"\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        hidden_states, all_hidden_states = self.backbone(\n            input_ids,\n            inputs_embeds=inputs_embeds,\n            output_hidden_states=output_hidden_states\n        )\n        if return_dict:\n            return BaseModelOutputWithNoAttention(\n                last_hidden_state=hidden_states,\n                hidden_states=all_hidden_states if output_hidden_states else None\n            )\n        elif output_hidden_states:\n            return hidden_states, all_hidden_states\n        else:\n            return hidden_states\n\n\nclass CaduceusForMaskedLM(CaduceusPreTrainedModel):\n    \"\"\"HF-compatible Caduceus model for masked language modeling.\"\"\"\n\n    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):\n        super().__init__(config, **kwargs)\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)\n        if config.rcps:\n            self.lm_head = RCPSLMHead(\n                complement_map=self.config.complement_map,  # Use caduceus config as it might have been updated\n                vocab_size=self.config.vocab_size,  # Use caduceus config as it might have been updated\n                true_dim=config.d_model,\n                dtype=dtype\n       
     )\n        else:\n            self.lm_head = nn.Linear(\n                config.d_model,\n                self.config.vocab_size,  # Use caduceus config as it might have been updated\n                bias=False,\n                **factory_kwargs\n            )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.caduceus.backbone.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        if self.config.rcps:\n            raise NotImplementedError(\"Setting input embeddings for RCPS LM is not supported.\")\n        self.caduceus.backbone.embeddings.word_embeddings = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        \"\"\"Overrides output embeddings.\"\"\"\n        if self.config.rcps:\n            raise NotImplementedError(\"Setting output embeddings for RCPS LM is not supported.\")\n        self.lm_head = new_embeddings\n\n    def tie_weights(self):\n        \"\"\"Tie weights, accounting for RCPS.\"\"\"\n        if self.config.rcps:\n            self.lm_head.set_weight(self.get_input_embeddings().weight)\n        else:\n            super().tie_weights()\n\n    def get_decoder(self):\n        \"\"\"Get decoder (backbone) for the model.\"\"\"\n        return self.caduceus\n\n    def set_decoder(self, decoder):\n        \"\"\"Set decoder (backbone) for the model.\"\"\"\n        self.caduceus = decoder\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        loss_weights: Optional[torch.FloatTensor] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, MaskedLMOutput]:\n        \"\"\"HF-compatible forward method.\"\"\"\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.caduceus(\n            input_ids=input_ids,\n            inputs_embeds=inputs_embeds,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            if loss_weights is not None:\n                loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)\n            else:\n                loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return MaskedLMOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n        )\n\n\nclass CaduceusForSequenceClassification(CaduceusPreTrainedModel):\n    def __init__(\n            self,\n            config: CaduceusConfig,\n            pooling_strategy: str = \"mean\",\n            conjoin_train: bool = False,\n            conjoin_eval: bool = False,\n            
device=None,\n            dtype=None,\n            **kwargs):\n        super().__init__(config, **kwargs)\n        if pooling_strategy not in [\"mean\", \"max\", \"first\", \"last\"]:\n            raise NotImplementedError(f\"Pooling strategy `{pooling_strategy}` not implemented.\")\n        self.pooling_strategy = pooling_strategy\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.num_labels = kwargs.get(\"num_labels\", config.num_labels)\n        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)\n        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)\n\n        self.conjoin_train = conjoin_train\n        self.conjoin_eval = conjoin_eval\n\n        # Initialize weights and apply final processing\n        self.post_init()\n        self.init_scorer()\n\n    def init_scorer(self, initializer_range=0.02):\n        initializer_range = self.config.initializer_cfg.get(\"initializer_range\", initializer_range) \\\n            if self.config.initializer_cfg is not None else initializer_range\n        self.score.weight.data.normal_(std=initializer_range)\n\n    def get_input_embeddings(self):\n        return self.caduceus.backbone.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        if self.config.rcps:\n            raise NotImplementedError(\"Setting input embeddings for RCPS LM is not supported.\")\n        self.caduceus.backbone.embeddings.word_embeddings = value\n\n    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):\n        \"\"\"Pools hidden states along sequence length dimension.\"\"\"\n        if self.pooling_strategy == \"mean\":  # Mean pooling along sequence length dimension\n            return hidden_states.mean(dim=sequence_length_dim)\n        if self.pooling_strategy == \"max\":  # Max pooling along sequence length dimension\n            return hidden_states.max(dim=sequence_length_dim).values\n        if self.pooling_strategy == \"last\":  # Use embedding of last token in the sequence\n            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...]\n        if self.pooling_strategy == \"first\":  # Use embedding of first token in the sequence\n            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # Get hidden representations from the backbone\n        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS\n            transformer_outputs = self.caduceus(\n                input_ids,\n                inputs_embeds=inputs_embeds,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n            hidden_states = torch.stack(\n                [\n                    transformer_outputs[0][..., :self.config.d_model],\n                    torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])\n                 ],\n                dim=-1\n            )\n        elif self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining\n            assert input_ids is not None, \"`input_ids` must be provided for conjoining.\"\n            assert input_ids.ndim == 3, \"`input_ids` must be 3D tensor: channels corresponds to forward and rc strands.\"\n            transformer_outputs = self.caduceus(\n                input_ids[..., 0],\n                inputs_embeds=None,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n            transformer_outputs_rc = self.caduceus(\n                input_ids[..., 1],\n                inputs_embeds=None,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n            # Stack along channel dimension (dim=-1)\n            hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)\n        else:\n            transformer_outputs = self.caduceus(\n                input_ids,\n                inputs_embeds=None,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n            hidden_states = transformer_outputs[0]\n\n        # Pool and get logits\n        pooled_hidden_states = self.pool_hidden_states(hidden_states)\n        # Potentially run `score` twice (with parameters shared) for conjoining\n        if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps\n            logits_fwd = self.score(pooled_hidden_states[..., 0])\n            logits_rc = self.score(pooled_hidden_states[..., 1])\n            logits = (logits_fwd + logits_rc) / 2\n        else:\n            logits = self.score(pooled_hidden_states)\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                if self.num_labels == 1:\n                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())\n                else:\n                    loss = 
F.mse_loss(logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss = F.binary_cross_entropy_with_logits(logits, labels)\n        if not return_dict:\n            output = (logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=transformer_outputs.hidden_states,\n        )\n"
  },
  {
    "path": "caduceus/modeling_rcps.py",
    "content": "\"\"\"Reverse-complement equivariant modules.\n\n\"\"\"\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\nfrom torch.nn import functional as F\n\ntry:\n    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure\nexcept ImportError:\n    try:\n        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure\n    except ImportError:\n        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n\n\nclass RCPSEmbedding(nn.Module):\n    \"\"\"Embedding layer that supports reverse-complement equivariance.\"\"\"\n    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):\n        \"\"\"\n        Args:\n            vocab_size: Size of vocabulary.\n            d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).\n            complement_map: Dictionary mapping each token id to its complement.\n        \"\"\"\n        super().__init__()\n        self.register_buffer(\n            \"complement_map\",\n            torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)\n        )\n        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)\n\n    @property\n    def weight(self):\n        \"\"\"Embedding weights.\"\"\"\n        return self.embedding.weight\n\n    def set_weight(self, value):\n        \"\"\"Set embedding weights.\"\"\"\n        self.embedding.weight = value\n\n    def rc(self, x):\n        \"\"\"Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids.\"\"\"\n        return torch.gather(\n            self.complement_map.unsqueeze(0).expand(x.shape[0], -1),\n            dim=1,\n            index=torch.flip(x, dims=[-1])\n        )\n\n    def forward(self, input_ids):\n        \"\"\"Reverse-complement equivariant forward pass.\n\n        This embedding module doubles the output dimensionality to support reverse-complement equivariance.\n\n        Args:\n            input_ids: Input tensor of shape (batch_size, seq_len)\n        Returns:\n            Embedding tensor of shape (batch_size, seq_len, d_model * 2)\n        \"\"\"\n        fwd_out = self.embedding(input_ids)\n        rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])\n\n        return torch.cat([fwd_out, rc_out], dim=-1)\n\n\nclass RCPSWrapper(nn.Module):\n    \"\"\"Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.\n\n    See ref. \"Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory\n    Genomics\", Zhou et al. 
(2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.\n    \"\"\"\n    def __init__(self, submodule: nn.Module):\n        super().__init__()\n        self.submodule = submodule\n\n    @staticmethod\n    def rc(x):\n        \"\"\"Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions.\"\"\"\n        return torch.flip(x, dims=[-2, -1])\n\n    def forward(self, x, **kwargs):\n        \"\"\"Reverse-complement equivariant forward pass.\n\n        Args:\n            x: Input tensor of shape (batch_size, seq_len, channels)\n        Returns:\n            Output tensor of shape (batch_size, seq_len, channels)\n        \"\"\"\n        n_channels = x.shape[-1]\n        # Run submodule along sequence\n        fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)\n        # Run submodule along rc-sequence\n        rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)\n        # Concatenate along channel dimension (dim=-1)\n        return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)\n\n\nclass RCPSAddNormWrapper(RCPSWrapper):\n    \"\"\"RC equivariant AddNorm layer.\"\"\"\n    def __init__(self, submodule: nn.Module):\n        super().__init__(submodule)\n\n    def forward(self, x, residual=None, prenorm=False):\n        \"\"\"\n        Args:\n            x: Input tensor of shape (batch_size, seq_len, channels)\n            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.\n            prenorm: Whether to return residual.\n        \"\"\"\n        n_channels = x.shape[-1]\n        if residual is None:\n            residual = x\n            x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))\n            x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))\n            x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)\n        else:\n            residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]\n            x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))\n\n            residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])\n            x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))\n\n            residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)\n            x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)\n\n        return x if not prenorm else (x, residual)\n\n\nclass RCPSMambaBlock(nn.Module):\n    def __init__(\n            self,\n            dim,\n            mixer_cls,\n            norm_cls=nn.LayerNorm,\n            fused_add_norm=False,\n            residual_in_fp32=False,\n            device=None,  # Keep for consistency with original Mamba Block\n            dtype=None,  # Keep for consistency with original Mamba Block\n    ):\n        \"\"\"RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.\n\n        Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py\n        \"\"\"\n        super().__init__()\n        self.residual_in_fp32 = residual_in_fp32\n        self.fused_add_norm = fused_add_norm\n        self.mixer = RCPSWrapper(mixer_cls(dim))\n        norm_f = norm_cls(dim)\n        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)\n        if self.fused_add_norm:\n            assert RMSNorm is not None, \"RMSNorm import fails\"\n            assert isinstance(\n                self.norm, 
(nn.LayerNorm, RMSNorm)\n            ), \"Only LayerNorm and RMSNorm are supported for fused_add_norm\"\n\n    def forward(\n        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None\n    ):\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            hidden_states: the sequence to the encoder layer (required).\n            residual: hidden_states = Mixer(LN(residual)).\n            inference_params: inference parameters for mixer.\n        \"\"\"\n        if not self.fused_add_norm:\n            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)\n            if self.residual_in_fp32:\n                residual = residual.to(torch.float32)\n        else:\n            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn\n\n            hidden_states_fwd, residual_fwd = fused_add_norm_fn(\n                hidden_states[..., hidden_states.shape[-1] // 2:],\n                self.norm.weight,\n                self.norm.bias,\n                residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,\n                prenorm=True,\n                residual_in_fp32=self.residual_in_fp32,\n                eps=self.norm.eps,\n            )\n\n            hidden_states_rc, residual_rc = fused_add_norm_fn(\n                hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),\n                self.norm.weight,\n                self.norm.bias,\n                residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,\n                prenorm=True,\n                residual_in_fp32=self.residual_in_fp32,\n                eps=self.norm.eps,\n            )\n            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)\n            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)\n        hidden_states = self.mixer(hidden_states, inference_params=inference_params)\n        return hidden_states, residual\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        \"\"\"Allocate inference cache for mixer.\n\n        Keep for compatibility with original Mamba Block.\n        \"\"\"\n        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n\n\nclass RCPSLMHead(nn.Module):\n    \"\"\"LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs.\"\"\"\n    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):\n        \"\"\"\n        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement\n        equivariant, i.e. 
0.5 times the actual input dim.\n        \"\"\"\n        super().__init__()\n        self.register_buffer(\n            \"complement_map\",\n            torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)\n        )\n        self.true_dim = true_dim\n        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)\n\n    @property\n    def weight(self):\n        \"\"\"LM head weights.\"\"\"\n        return self.lm_head.weight\n\n    def set_weight(self, value):\n        \"\"\"Set LM head weights.\"\"\"\n        self.lm_head.weight = value\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.\n        \"\"\"\n        n_channels = x.shape[-1]\n        assert n_channels == 2 * self.true_dim, \"Input must have 2 * true_dim channels.\"\n        fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)\n        rc_logits = F.linear(\n            torch.flip(x[..., n_channels // 2:], dims=[-1]),\n            self.weight[self.complement_map, :],\n            bias=self.lm_head.bias\n        )\n        return fwd_logits + rc_logits\n"
  },
  {
    "path": "caduceus/tests/test_rcps.py",
    "content": "\"\"\"Tests for RCPS modules.\n\n\"\"\"\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ntry:  # Legacy mambav1 file structure\n    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn\nexcept ImportError:\n    try:  # mambav2 file structure\n        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn\n    except ImportError:\n        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n\nfrom caduceus.modeling_rcps import (\n    RCPSEmbedding, RCPSAddNormWrapper, RCPSLMHead, RCPSWrapper\n)\n\nfrom caduceus.modeling_caduceus import (\n    CaduceusConfig, CaduceusMixerModel, CaduceusForMaskedLM, create_block\n)\n\n\n@pytest.mark.parametrize(\"batch_size\", [4])\n@pytest.mark.parametrize(\"seq_len\", [512])\n@pytest.mark.parametrize(\"d_model\", [256])\n@pytest.mark.parametrize(\"dtype\", [torch.float32])\ndef test_rcps_embedding(batch_size, seq_len, d_model, dtype):\n    # Set tolerance\n    device = torch.device(\"cpu\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Define complement map\n    str_to_id = {\"[CLS]\": 0, \"[MASK]\": 1, \"A\": 2, \"C\": 3, \"G\": 4, \"T\": 5, \"N\": 6}\n    complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n    complement_map = {\n        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v\n        for k, v in str_to_id.items()\n    }\n    vocab_size = 12\n    pad_vocab_size_multiple = 8\n    if vocab_size % pad_vocab_size_multiple != 0:\n        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n    if vocab_size > len(complement_map):\n        for i in range(len(complement_map), vocab_size):\n            complement_map[i] = i\n\n    # Generate random sequences\n    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)\n    rc_input_ids = torch.flip(input_ids, dims=[-1]).to(\"cpu\").apply_(lambda t: complement_map[t]).to(device)\n\n    # Test RC equivariance of embedding layer\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    embedding = RCPSEmbedding(\n        vocab_size=vocab_size,\n        d_model=d_model,\n        complement_map=complement_map,\n        **factory_kwargs\n    ).to(device)\n    out_embed = embedding(input_ids)\n    rc_out_embed = torch.flip(embedding(rc_input_ids), dims=[-2, -1])\n    # Test that channels are 2 * d_model\n    assert tuple(out_embed.size()) == (batch_size, seq_len, d_model * 2)\n    assert tuple(rc_out_embed.size()) == (batch_size, seq_len, d_model * 2)\n    # Test that RC equivariance holds\n    assert torch.allclose(out_embed.detach(), rc_out_embed.detach(), rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [2])\n@pytest.mark.parametrize(\"seq_len\", [1024])\n@pytest.mark.parametrize(\"d_model\", [128])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16])\ndef test_rcps_wrapper(batch_size, seq_len, d_model, dtype):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Generate random sequence with 2 * d_model channels\n    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, 
dtype=dtype)\n    rc_x = torch.flip(x, dims=[-2, -1])\n\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    module = nn.Sequential(\n        nn.Linear(d_model, d_model, bias=False, **factory_kwargs),\n        nn.ReLU(),\n        nn.Linear(d_model, d_model*2, bias=True, **factory_kwargs),\n        nn.ReLU(),\n        nn.Linear(d_model * 2, d_model, bias=True, **factory_kwargs)\n    )\n\n    # Test RC equivariance of wrapper\n    rcps_module = RCPSWrapper(module).to(device)\n    out = rcps_module(x)\n    rc_out = torch.flip(rcps_module(rc_x), dims=[-2, -1])\n    assert out.size() == x.size()\n    assert rc_out.size() == x.size()\n    assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [2])\n@pytest.mark.parametrize(\"seq_len\", [1024])\n@pytest.mark.parametrize(\"d_model\", [128])\n@pytest.mark.parametrize(\"prenorm\", [False, True])\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\ndef test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dtype):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Generate random sequence with 2 * d_model channels\n    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)\n    rc_x = torch.flip(x, dims=[-2, -1])\n\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    norm = RMSNorm(d_model, eps=1e-5, **factory_kwargs)\n\n    # Test RC equivariance of wrapper\n    rcps_module = RCPSAddNormWrapper(norm).to(device)\n    out = rcps_module(x, prenorm=prenorm)\n    if prenorm:  # returns tuple\n        rc_out = tuple([torch.flip(r, dims=[-2, -1])\n                        for r in rcps_module(rc_x, prenorm=prenorm)])\n        for f, r in zip(out, rc_out):\n            assert f.size() == x.size()\n            assert r.size() == x.size()\n            assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)\n    else:\n        rc_out = torch.flip(rcps_module(rc_x, prenorm=prenorm), dims=[-2, -1])\n        assert out.size() == x.size()\n        assert rc_out.size() == x.size()\n        assert torch.allclose(out.detach(), rc_out.detach(), rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [2])\n@pytest.mark.parametrize(\"seq_len\", [1024])\n@pytest.mark.parametrize(\"d_model\", [128])\n@pytest.mark.parametrize(\"bidirectional\", [True, False])\n@pytest.mark.parametrize(\"fused_add_norm\", [True, False])\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\ndef test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirectional, fused_add_norm, dtype):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Generate random sequence with 2 * d_model channels\n    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)\n    rc_x = torch.flip(x, dims=[-2, -1])\n\n    ssm_cfg = {\n        \"d_state\": 16, \"d_conv\": 4, \"expand\": 2, \"dt_rank\": \"auto\", \"dt_min\": 0.001, \"dt_max\": 0.1, \"dt_init\": \"random\",\n        \"dt_scale\": 1.0, \"dt_init_floor\": 1e-4, \"conv_bias\": True, \"bias\": False, \"use_fast_path\": True\n    }\n    factory_kwargs = {\"device\": device, 
\"dtype\": dtype}\n\n    mamba_block = create_block(\n        d_model,\n        ssm_cfg=ssm_cfg,\n        norm_epsilon=1e-5,\n        rms_norm=True,\n        residual_in_fp32=True,\n        fused_add_norm=fused_add_norm,\n        layer_idx=0,\n        bidirectional=bidirectional,\n        bidirectional_strategy=\"add\",\n        bidirectional_weight_tie=True,\n        rcps=True,\n        **factory_kwargs\n    )\n\n    # Test RC equivariance of wrapper\n    out = mamba_block(x, residual=None)\n    rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=None)])\n    for f, r in zip(out, rc_out):\n        assert f.size() == x.size()\n        assert r.size() == x.size()\n        assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)\n\n    out = mamba_block(x, residual=x)\n    rc_out = tuple([torch.flip(r, dims=[-2, -1]) for r in mamba_block(rc_x, residual=rc_x)])\n    for f, r in zip(out, rc_out):\n        assert f.size() == x.size()\n        assert r.size() == x.size()\n        assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [2, 4])\n@pytest.mark.parametrize(\"seq_len\", [1, 1024, 2048])\n@pytest.mark.parametrize(\"d_model\", [2, 128, 256])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16])\ndef test_rcps_lm_head(batch_size, seq_len, d_model, dtype):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Define complement map\n    str_to_id = {\"[CLS]\": 0, \"[MASK]\": 1, \"A\": 2, \"C\": 3, \"G\": 4, \"T\": 5, \"N\": 6}\n    complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n    complement_map = {\n        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v\n        for k, v in str_to_id.items()\n    }\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    vocab_size = 12\n    if vocab_size > len(complement_map):\n        for i in range(len(complement_map), vocab_size):\n            complement_map[i] = i\n\n    # Instantiate LM head\n    lm_head = RCPSLMHead(\n        complement_map=complement_map,\n        vocab_size=vocab_size,\n        true_dim=d_model,\n        **factory_kwargs\n    )\n\n    # Generate random sequence with 2 * d_model channels\n    x = torch.randn(batch_size, seq_len, d_model * 2, device=device, dtype=dtype)\n    rc_x = torch.flip(x, dims=[-2, -1])\n\n    # Test RC equivariance of LM head\n    out = lm_head(x)\n    rc_out = lm_head(rc_x)\n    assert tuple(out.size()) == (batch_size, seq_len, vocab_size)\n    assert tuple(rc_out.size()) == (batch_size, seq_len, vocab_size)\n    assert torch.allclose(\n        out.detach(),\n        torch.flip(rc_out.detach()[..., lm_head.complement_map], dims=[1]),\n        rtol=rtol,\n        atol=atol\n    )\n    assert torch.allclose(\n        F.softmax(out, dim=-1).detach(),\n        torch.flip(F.softmax(rc_out, dim=-1).detach()[..., lm_head.complement_map], dims=[1]),\n        rtol=rtol,\n        atol=atol\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [2, 4])\n@pytest.mark.parametrize(\"seq_len\", [1024, 2048])\n@pytest.mark.parametrize(\"n_layer\", [1, 2, 3])\n@pytest.mark.parametrize(\"d_model\", [128, 256])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\"fused_add_norm\", [True, 
False])\n@pytest.mark.parametrize(\"bidirectional\", [False, True])\n@pytest.mark.parametrize(\"bidirectional_weight_tie\", [False, True])\ndef test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fused_add_norm,\n                       bidirectional, bidirectional_weight_tie):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Define complement map\n    str_to_id = {\"[CLS]\": 0, \"[MASK]\": 1, \"A\": 2, \"C\": 3, \"G\": 4, \"T\": 5, \"N\": 6}\n    complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n    complement_map = {\n        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v\n        for k, v in str_to_id.items()\n    }\n\n    # Setup CaduceusConfig\n    initializer_cfg = {\"initializer_range\": 0.02, \"rescale_prenorm_residual\": True, \"n_residuals_per_layer\": 1}\n    ssm_cfg = {\n        \"d_state\": 16, \"d_conv\": 4, \"expand\": 2, \"dt_rank\": \"auto\", \"dt_min\": 0.001, \"dt_max\": 0.1, \"dt_init\": \"random\",\n        \"dt_scale\": 1.0, \"dt_init_floor\": 1e-4, \"conv_bias\": True, \"bias\": False, \"use_fast_path\": True\n    }\n    config = CaduceusConfig(\n        d_model=d_model,\n        n_layer=n_layer,\n        vocab_size=12,\n        ssm_cfg=ssm_cfg,\n        rms_norm=True,\n        residual_in_fp32=False,\n        fused_add_norm=fused_add_norm,\n        pad_vocab_size_multiple=8,\n        norm_epsilon=1e-5,\n        initializer_cfg=initializer_cfg,\n        bidirectional=bidirectional,\n        bidirectional_strategy=\"add\",\n        bidirectional_weight_tie=bidirectional_weight_tie,\n        rcps=True,\n        complement_map=complement_map,\n    )\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n    # Instantiate model\n    backbone = CaduceusMixerModel(\n        config,\n        **factory_kwargs,\n    ).to(device)\n\n    # Generate random sequences\n    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)\n    rc_input_ids = torch.flip(input_ids, dims=[-1]).to(\"cpu\").apply_(lambda t: complement_map[t]).to(device)\n\n    # Test RC equivariance of rc backbone\n    out = backbone(input_ids)[0]\n    rc_out = backbone(rc_input_ids)[0]\n    if isinstance(rc_out, tuple):\n        rc_out = tuple([torch.flip(r, dims=[1, 2]) for r in rc_out])\n        for f, r in zip(out, rc_out):\n            assert f.size() == (batch_size, seq_len, d_model * 2)\n            assert r.size() == (batch_size, seq_len, d_model * 2)\n            assert torch.allclose(f.detach(), r.detach(), rtol=rtol, atol=atol)\n    else:\n        # Hidden state size should double\n        assert tuple(out.size()) == (batch_size, seq_len, d_model * 2)\n        assert tuple(rc_out.size()) == (batch_size, seq_len, d_model * 2)\n        assert torch.allclose(out.detach(), torch.flip(rc_out.detach(), dims=[1, 2]), rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [2, 4])\n@pytest.mark.parametrize(\"seq_len\", [1024, 2048])\n@pytest.mark.parametrize(\"n_layer\", [1, 3, 4])\n@pytest.mark.parametrize(\"d_model\", [128, 256])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\"bidirectional\", [False, True])\n@pytest.mark.parametrize(\"bidirectional_weight_tie\", [False, True])\ndef test_rcps_mamba_lm(batch_size, seq_len, 
n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Define complement map\n    str_to_id = {\"[CLS]\": 0, \"[MASK]\": 1, \"A\": 2, \"C\": 3, \"G\": 4, \"T\": 5, \"N\": 6}\n    complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n    complement_map = {\n        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v\n        for k, v in str_to_id.items()\n    }\n\n    # Setup CaduceusConfig\n    initializer_cfg = {\"initializer_range\": 0.02, \"rescale_prenorm_residual\": True, \"n_residuals_per_layer\": 1}\n    ssm_cfg = {\n        \"d_state\": 16, \"d_conv\": 4, \"expand\": 2, \"dt_rank\": \"auto\", \"dt_min\": 0.001, \"dt_max\": 0.1, \"dt_init\": \"random\",\n        \"dt_scale\": 1.0, \"dt_init_floor\": 1e-4, \"conv_bias\": True, \"bias\": False, \"use_fast_path\": True\n    }\n    config = CaduceusConfig(\n        d_model=d_model,\n        n_layer=n_layer,\n        vocab_size=12,\n        ssm_cfg=ssm_cfg,\n        rms_norm=True,\n        residual_in_fp32=False,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=8,\n        norm_epsilon=1e-5,\n        initializer_cfg=initializer_cfg,\n        bidirectional=bidirectional,\n        bidirectional_strategy=\"add\",\n        bidirectional_weight_tie=bidirectional_weight_tie,\n        rcps=True,\n        complement_map=complement_map,\n    )\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n    # Instantiate model\n    mamba_lm = CaduceusForMaskedLM(\n        config=config,\n        **factory_kwargs,\n    ).to(device)\n\n    # Generate random sequences\n    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)\n    rc_input_ids = torch.flip(input_ids, dims=[-1]).to(\"cpu\").apply_(lambda t: complement_map[t]).to(device)\n\n    # Test RC equivariance of rc backbone\n    out = mamba_lm(input_ids)\n    rc_out = mamba_lm(rc_input_ids)\n    if config.vocab_size % config.pad_vocab_size_multiple != 0:\n        config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)\n    assert tuple(out.logits.size()) == (batch_size, seq_len, config.vocab_size)\n    assert tuple(rc_out.logits.size()) == (batch_size, seq_len, config.vocab_size)\n    assert torch.allclose(\n        out.logits.detach(),\n        torch.flip(rc_out.logits.detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),\n        rtol=rtol,\n        atol=atol\n    )\n    assert torch.allclose(\n        F.softmax(out.logits, dim=-1).detach(),\n        torch.flip(F.softmax(rc_out.logits, dim=-1).detach()[..., mamba_lm.lm_head.complement_map], dims=[1]),\n        rtol=rtol,\n        atol=atol\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [2])\n@pytest.mark.parametrize(\"seq_len\", [1024])\n@pytest.mark.parametrize(\"n_layer\", [2])\n@pytest.mark.parametrize(\"d_model\", [128])\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"bidirectional\", [True, False])\n@pytest.mark.parametrize(\"bidirectional_weight_tie\", [True])\ndef test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtype, bidirectional, bidirectional_weight_tie):\n    # Set tolerance\n    device = torch.device(\"cuda\")\n    rtol, atol = (6e-4, 2e-3) if 
dtype == torch.float32 else (3e-3, 5e-3)\n    if dtype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n\n    # Set seed\n    torch.random.manual_seed(0)\n\n    # Define complement map\n    str_to_id = {\"[CLS]\": 0, \"[MASK]\": 1, \"A\": 2, \"C\": 3, \"G\": 4, \"T\": 5, \"N\": 6}\n    complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n    complement_map = {\n        str_to_id[k]: str_to_id[complement_map[k]] if k in complement_map.keys() else v\n        for k, v in str_to_id.items()\n    }\n\n    # Setup CaduceusConfig\n    initializer_cfg = {\"initializer_range\": 0.02, \"rescale_prenorm_residual\": True, \"n_residuals_per_layer\": 1}\n    ssm_cfg = {\n        \"d_state\": 16, \"d_conv\": 4, \"expand\": 2, \"dt_rank\": \"auto\", \"dt_min\": 0.001, \"dt_max\": 0.1, \"dt_init\": \"random\",\n        \"dt_scale\": 1.0, \"dt_init_floor\": 1e-4, \"conv_bias\": True, \"bias\": False, \"use_fast_path\": True\n    }\n    config = CaduceusConfig(\n        d_model=d_model,\n        n_layer=n_layer,\n        vocab_size=12,\n        ssm_cfg=ssm_cfg,\n        rms_norm=True,\n        residual_in_fp32=False,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=8,\n        norm_epsilon=1e-5,\n        initializer_cfg=initializer_cfg,\n        bidirectional=bidirectional,\n        bidirectional_strategy=\"add\",\n        bidirectional_weight_tie=bidirectional_weight_tie,\n        rcps=True,\n        complement_map=complement_map,\n    )\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n    # Instantiate model\n    backbone = CaduceusMixerModel(\n        config,\n        **factory_kwargs,\n    ).to(device)\n\n    # Generate random sequences\n    input_ids = torch.randint(low=1, high=len(str_to_id), size=(batch_size, seq_len), device=device)\n    rc_input_ids = torch.flip(input_ids, dims=[-1]).to(\"cpu\").apply_(lambda t: complement_map[t]).to(device)\n\n    # Test RC Invariance when collapsing output of backbone\n    out = backbone(input_ids)[0]\n    out_collapse = (out[..., :d_model] + torch.flip(out[..., d_model:], dims=[1, 2])) / 2\n    rc_out = backbone(rc_input_ids)[0]\n    rc_out_collapse = (rc_out[..., :d_model] + torch.flip(rc_out[..., d_model:], dims=[1, 2])) / 2\n    # Hidden state size should be d_model\n    assert tuple(out_collapse.size()) == (batch_size, seq_len, d_model)\n    assert tuple(rc_out_collapse.size()) == (batch_size, seq_len, d_model)\n    assert torch.allclose(out_collapse.detach(), rc_out_collapse.detach(), rtol=rtol, atol=atol)\n"
  },
  {
    "path": "caduceus/tokenization_caduceus.py",
    "content": "\"\"\"Character tokenizer for Hugging Face.\n\n\"\"\"\n\nfrom typing import List, Optional, Dict, Sequence, Tuple\n\nfrom transformers import PreTrainedTokenizer\n\n\nclass CaduceusTokenizer(PreTrainedTokenizer):\n    model_input_names = [\"input_ids\"]\n\n    def __init__(self,\n                 model_max_length: int,\n                 characters: Sequence[str] = (\"A\", \"C\", \"G\", \"T\", \"N\"),\n                 complement_map=None,\n                 bos_token=\"[BOS]\",\n                 eos_token=\"[SEP]\",\n                 sep_token=\"[SEP]\",\n                 cls_token=\"[CLS]\",\n                 pad_token=\"[PAD]\",\n                 mask_token=\"[MASK]\",\n                 unk_token=\"[UNK]\",\n                 **kwargs):\n        \"\"\"Character tokenizer for Hugging Face transformers.\n\n        Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py\n        Args:\n            model_max_length (int): Model maximum sequence length.\n            characters (Sequence[str]): List of desired characters. Any character which\n                is not included in this list will be replaced by a special token called\n                [UNK] with id=6. Following is a list of the special tokens with\n                their corresponding ids:\n                    \"[CLS]\": 0\n                    \"[SEP]\": 1\n                    \"[BOS]\": 2\n                    \"[MASK]\": 3\n                    \"[PAD]\": 4\n                    \"[RESERVED]\": 5\n                    \"[UNK]\": 6\n                an id (starting at 7) will be assigned to each character.\n            complement_map (Optional[Dict[str, str]]): Dictionary with string complements for each character.\n        \"\"\"\n        if complement_map is None:\n            complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\", \"N\": \"N\"}\n        self.characters = characters\n        self.model_max_length = model_max_length\n\n        self._vocab_str_to_int = {\n            \"[CLS]\": 0,\n            \"[SEP]\": 1,\n            \"[BOS]\": 2,\n            \"[MASK]\": 3,\n            \"[PAD]\": 4,\n            \"[RESERVED]\": 5,\n            \"[UNK]\": 6,\n            **{ch: i + 7 for i, ch in enumerate(self.characters)},\n        }\n        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}\n        add_prefix_space = kwargs.pop(\"add_prefix_space\", False)\n        padding_side = kwargs.pop(\"padding_side\", \"left\")\n\n        self._complement_map = {}\n        for k, v in self._vocab_str_to_int.items():\n            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v\n            self._complement_map[self._vocab_str_to_int[k]] = complement_id\n\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            sep_token=sep_token,\n            cls_token=cls_token,\n            pad_token=pad_token,\n            mask_token=mask_token,\n            unk_token=unk_token,\n            add_prefix_space=add_prefix_space,\n            model_max_length=model_max_length,\n            padding_side=padding_side,\n            **kwargs,\n        )\n\n    @property\n    def vocab_size(self) -> int:\n        return len(self._vocab_str_to_int)\n\n    @property\n    def complement_map(self) -> Dict[int, int]:\n        return self._complement_map\n\n    def _tokenize(self, text: str, **kwargs) -> List[str]:\n        return list(text.upper())  # Convert all 
base pairs to uppercase\n\n    def _convert_token_to_id(self, token: str) -> int:\n        return self._vocab_str_to_int.get(token, self._vocab_str_to_int[\"[UNK]\"])\n\n    def _convert_id_to_token(self, index: int) -> str:\n        return self._vocab_int_to_str[index]\n\n    def convert_tokens_to_string(self, tokens):\n        return \"\".join(tokens)  # Note: this operation has lost info about which base pairs were originally lowercase\n\n    def get_special_tokens_mask(\n        self,\n        token_ids_0: List[int],\n        token_ids_1: Optional[List[int]] = None,\n        already_has_special_tokens: bool = False,\n    ) -> List[int]:\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0,\n                token_ids_1=token_ids_1,\n                already_has_special_tokens=True,\n            )\n\n        result = ([0] * len(token_ids_0)) + [1]\n        if token_ids_1 is not None:\n            result += ([0] * len(token_ids_1)) + [1]\n        return result\n\n    def build_inputs_with_special_tokens(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        sep = [self.sep_token_id]\n        # cls = [self.cls_token_id]\n        result = token_ids_0 + sep\n        if token_ids_1 is not None:\n            result += token_ids_1 + sep\n        return result\n\n    def get_vocab(self) -> Dict[str, int]:\n        return self._vocab_str_to_int\n\n    # Fixed vocabulary with no vocab file\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:\n        return ()\n"
  },
  {
    "path": "caduceus_env.yml",
    "content": "name: caduceus_env\nchannels:\n  - pytorch\n  - anaconda\n  - nvidia\n  - defaults\ndependencies:\n  - cuda-nvcc=11.7.99\n  - pip=23.3.1\n  - python=3.8\n  - pytorch=2.2.0\n  - torchaudio=2.2.0\n  - torchaudio=2.2.0\n  - torchdata=0.7.1\n  - torchmetrics=1.2.1\n  - torchtext=0.17.0\n  - torchvision=0.17.0\n  - pytorch-cuda=12.1\n  - pip:\n      - biopython==1.81\n      - datasets==2.15.0\n      - einops==0.7.0\n      - enformer-pytorch==0.8.8\n      - fsspec==2023.10.0\n      - genomic-benchmarks==0.0.9\n      - git-lfs==1.6\n      - h5py==3.10.0\n      - huggingface-hub==0.24.7\n      - hydra-core==1.3.2\n      - ipdb==0.13.13\n      - matplotlib==3.7.4\n      - notebook==7.1.1\n      - nvitop==1.3.2\n      - omegaconf==2.3.0\n      - pandas==2.0.3\n      - pyfaidx==0.8.1.1\n      - pysam==0.22.0\n      - pytest==8.0.2\n      - pytorch-lightning==1.8.6\n      - rich==13.7.0\n      - seaborn==0.13.2\n      - scikit-learn==1.3.2\n      - timm==0.9.16\n      - tqdm==4.66.1\n      - transformers==4.38.1\n      - triton==2.2.0\n      - wandb==0.13.5\n      - flash-attn==2.5.6\n      - causal-conv1d===1.2.0.post2\n      - mamba-ssm==1.2.0.post1\n"
  },
  {
    "path": "configs/callbacks/base.yaml",
    "content": "learning_rate_monitor:\n  # _target_: pytorch_lightning.callbacks.LearningRateMonitor\n  logging_interval: ${train.interval}\n\ntimer:\n  # _target_: callbacks.timer.Timer\n  step: True\n  inter_step: False\n  epoch: True\n  val: True\n\nparams:\n  # _target_: callbacks.params.ParamsLog\n  total: True\n  trainable: True\n  fixed: True\n"
  },
  {
    "path": "configs/callbacks/checkpoint.yaml",
    "content": "model_checkpoint:\n  monitor: ${train.monitor} # name of the logged metric which determines when model is improving\n  mode: ${train.mode} # can be \"max\" or \"min\"\n  save_top_k: 1 # save k best models (determined by above metric)\n  save_last: False # True = additionally always save model from last epoch\n  dirpath: \"checkpoints/\"\n  filename: ${train.monitor}\n  auto_insert_metric_name: False\n  verbose: True\n\nmodel_checkpoint_every_n_steps:\n  monitor: train/loss # name of the logged metric which determines when model is improving\n  mode: min # can be \"max\" or \"min\"\n  save_top_k: 0 # Do not save any \"best\" models; this callback is being used to save every n train steps\n  save_last: True # additionally always save model from last epoch\n  dirpath: \"checkpoints/\"\n  filename: train/loss\n  auto_insert_metric_name: False\n  verbose: True\n  every_n_train_steps: 100\n\n#model_checkpoint_every_epoch:\n#  monitor: trainer/epoch  # name of the logged metric which determines when model is improving\n#  mode: max # can be \"max\" or \"min\"\n#  save_top_k: 1 # Do not save any \"best\" models; this callback is being used to save every n train steps\n#  save_last: False # additionally always save model from last epoch\n#  dirpath: \"checkpoints/\"\n#  filename: null\n#  auto_insert_metric_name: False\n#  verbose: True\n#  every_n_epochs: 1\n"
  },
  {
    "path": "configs/callbacks/gpu_affinity.yaml",
    "content": "gpu_affinity:\n  _name_: gpu_affinity\n"
  },
  {
    "path": "configs/callbacks/rich.yaml",
    "content": "rich_model_summary:\n  max_depth: 2\n\nrich_progress_bar:\n  refresh_rate_per_second: 1.0\n"
  },
  {
    "path": "configs/callbacks/val_every_n_global_steps.yaml",
    "content": "val_every_n_global_steps:\n  every_n: 10000\n"
  },
  {
    "path": "configs/callbacks/wandb.yaml",
    "content": "defaults:\n  - default\n\nwatch_model:\n  _target_: src.callbacks.wandb_callbacks.WatchModel\n  log: \"all\"\n  log_freq: 100\n\nupload_code_as_artifact:\n  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact\n  code_dir: ${work_dir}/src\n\nupload_ckpts_as_artifact:\n  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n  ckpt_dir: \"checkpoints/\"\n  upload_best_only: True\n\nlog_f1_precision_recall_heatmap:\n  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap\n\nlog_confusion_matrix:\n  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix\n\nlog_image_predictions:\n  _target_: src.callbacks.wandb_callbacks.LogImagePredictions\n  num_samples: 8\n"
  },
  {
    "path": "configs/config.yaml",
    "content": "# @package _global_\ndefaults:\n  - _self_\n  - experiment: ???\n  # - model: ???  # Model backbone\n  # - pipeline: ???  # Specifies collection of configs, equivalent to next 5 lines\n  # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)\n  # # - loader: default # Dataloader (e.g. handles batches)\n  # # - dataset: cifar # Defines the data (x and y pairs)\n  # # - task: multiclass_classification # Defines loss and metrics\n  # # - encoder: null # Interface between data and model\n  # # - decoder: null # Interface between model and targets\n\n# Additional arguments used to configure the training loop\n# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer\ntrain:\n  seed: 0\n  # These three options are used by callbacks (checkpoint, monitor) and scheduler\n  # Most of them are task dependent and are set by the pipeline\n  interval: ??? # Should be specified by scheduler. Also used by LR monitor\n  monitor: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer\n  mode: ??? # Should be specified by pipeline. Used by scheduler (plateau) and checkpointer\n  ema: 0.0 # Moving average model for validation\n  test: True # Test after training\n  debug: False # Special settings to make debugging more convenient\n  ignore_warnings: False # Disable python warnings\n\n  optimizer_param_grouping:\n    bias_weight_decay: False\n    normalization_weight_decay: False\n\n  # These control state passing between batches\n  state:\n    mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]\n    n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context\n    n_context_eval: ${.n_context} # Context at evaluation time\n  # Convenience keys to allow grouping runs\n\n  ckpt: checkpoints/last.ckpt # Resume training\n\n  disable_dataset: False # Disable dataset loading\n  validate_at_start: false\n\n  pretrained_model_path: null # Path to pretrained model\n  pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible\n  pretrained_model_state_hook: # Hook called on the loaded model's state_dict\n    _name_: null\n  post_init_hook: # After initializing model, call method on model\n    _name_: null\n\n  layer_decay: # Used for ImageNet finetuning\n    _name_: null\n    decay: 0.7\n\n# We primarily use wandb so this is moved to top level in the config for convenience\n# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging\n# If other loggers are added, it would make sense to put this one level lower under train/ or logger/\nwandb:\n  project: dna\n  group: \"\"\n  job_type: training\n  mode: online # choices=['online', 'offline', 'disabled']\n  name: null\n  save_dir: \".\"\n  id: ${.name} # pass correct id to resume experiment!\n  # Below options should not need to be specified\n  # entity: \"\"  # set to name of your wandb team or just remove it\n  # log_model: False\n  # prefix: \"\"\n  # job_type: \"train\"\n  # tags: []\n\nhydra:\n  run:\n    dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}\n  job:\n    chdir: true\n"
  },
  {
    "path": "configs/dataset/genomic_benchmark.yaml",
    "content": "_name_: genomic_benchmark\ntrain_val_split_seed: ${train.seed}  # Used for train/validation splitting\ndataset_name: dummy_mouse_enhancers_ensembl\ndest_path: null\nmax_length: ${.${.dataset_name}.max_length}\nmax_length_val: ${.max_length}\nmax_length_test: ${.max_length}\nd_output: ${.${.dataset_name}.classes}\nuse_padding: True\npadding_side: 'left'\nadd_eos: False\nbatch_size: 128\ntrain_len: ${.${.dataset_name}.train_len}\n__l_max: ${.max_length}\nshuffle: true  # set this as default!\n# these are used to find the right attributes automatically for each dataset\ndummy_mouse_enhancers_ensembl:\n  train_len: 1210\n  classes: 2\n  max_length: 1024\ndemo_coding_vs_intergenomic_seqs:\n  train_len: 100_000\n  classes: 2\n  max_length: 200\ndemo_human_or_worm:\n  train_len: 100_000\n  classes: 2\n  max_length: 200\nhuman_enhancers_cohn:\n  train_len: 27791\n  classes: 2\n  max_length: 500\nhuman_enhancers_ensembl:\n  train_len: 154842\n  classes: 2\n  max_length: 512\nhuman_ensembl_regulatory:\n  train_len: 289061\n  classes: 3\n  max_length: 512\nhuman_nontata_promoters:\n  train_len: 36131\n  classes: 2\n  max_length: 251\nhuman_ocr_ensembl:\n  train_len: 174756\n  classes: 2\n  max_length: 512\n\n# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings\n# name                                num_seqs        num_classes     median len    std\n# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4  \n# demo_coding_vs_intergenomic_seqs    100_000         2               200           0\n# demo_human_or_worm                  100_000         2               200           0\n# human_enhancers_cohn                27791           2               500           0\n# human_enhancers_ensembl             154842          2               269           122.6\n# human_ensembl_regulatory            289061          3               401           184.3\n# human_nontata_promoters             36131           2               251           0\n# human_ocr_ensembl                   174756          2               315           108.1\n"
  },
  {
    "path": "configs/dataset/hg38.yaml",
    "content": "_name_: hg38\nbed_file: null\nfasta_file: null\ndataset_name: hg38\ntokenizer_name: null\ncache_dir: null\nmax_length: 1024\nadd_eos: True\nbatch_size: 8  # per GPU\nbatch_size_eval: ${eval:${.batch_size} * 2}\nnum_workers: 4  # For preprocessing only\nshuffle: True\n__train_len: 34021\n__l_max: ${.max_length}\n"
  },
  {
    "path": "configs/dataset/nucleotide_transformer.yaml",
    "content": "_name_: nucleotide_transformer  # this links to the overall SequenceDataset of all nucleotide transformer datasets\ntrain_val_split_seed: ${train.seed}  # Used for train/validation splitting\ndataset_name: enhancers  # this specifies which dataset in nuc trx\ndest_path: null  # path to overall nuc trx datasets\nmax_length: ${.${.dataset_name}.max_length}\nd_output: ${.${.dataset_name}.classes} \nuse_padding: True\npadding_side: left\nadd_eos: False\nbatch_size: 256\ntrain_len: ${.${.dataset_name}.train_len}\n__l_max: ${.max_length}\nshuffle: true  # set this as default!\nmetric: ${.${.dataset_name}.metric}\n# these are used to find the right attributes automatically for each dataset\nenhancers:\n  train_len: 14968\n  classes: 2\n  max_length: 200\n  metric: mcc\nenhancers_types:\n  train_len: 14968\n  classes: 3\n  max_length: 200\n  metric: mcc\nH3:\n  train_len: 13468\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K4me1:\n  train_len: 28509\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K4me2:\n  train_len: 27614\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K4me3:\n  train_len: 33119\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K9ac:\n  train_len: 25003\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K14ac:\n  train_len: 29743\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K36me3:\n  train_len: 31392\n  classes: 2\n  max_length: 500\n  metric: mcc\nH3K79me3:\n  train_len: 25953\n  classes: 2\n  max_length: 500\n  metric: mcc\nH4:\n  train_len: 13140\n  classes: 2\n  max_length: 500\n  metric: mcc\nH4ac:\n  train_len: 30685\n  classes: 2\n  max_length: 500\n  metric: mcc\npromoter_all:\n  train_len: 53276\n  classes: 2\n  max_length: 300\n  metric: f1_binary\npromoter_no_tata:\n  train_len: 47767\n  classes: 2\n  max_length: 300\n  metric: f1_binary\npromoter_tata:\n  train_len: 5517\n  classes: 2\n  max_length: 300\n  metric: f1_binary\nsplice_sites_acceptors:\n  train_len: 19961\n  classes: 2\n  max_length: 600\n  metric: f1_binary\nsplice_sites_all:\n  train_len: 27000\n  classes: 3\n  max_length: 400\n  metric: accuracy\nsplice_sites_donors:\n  train_len: 19775\n  classes: 2\n  max_length: 600\n  metric: f1_binary\n\n# name maxlen classes samples metric\n\n# enhancers 200   2  14968 MCC\n# enhancers_types 200   3  14968 MCC\n# H3 500   2  13468 MCC\n# H3K4me1  500   2  28509 MCC\n# H3K4me2  500   2  27614 MCC\n# H3K4me3  500   2  33119 MCC\n# H3K9ac   500   2  25003 MCC\n# H3K14ac  500   2  29743 MCC\n# H3K36me3 500   2  31392 MCC\n# H3K79me3 500   2  25953 MCC\n# H4 500   2  13140 MCC\n# H4ac  500   2  30685 MCC\n# promoter_all   300   2  53276 F1\n# promoter_no_tata 300   2  47759 F1\n# promoter_tata  300   2  5517  F1\n# splice_sites_acceptor   600   2  19961 F1\n# splice_sites_all   400   2  27000 F1\n# splice_sites_donor   600   2  19775 F1\n"
  },
  {
    "path": "configs/experiment/hg38/genomic_benchmark.yaml",
    "content": "# @package _global_\ndefaults:\n  - /pipeline: genomic_benchmark\n  - /model: ???\n  - override /scheduler: cosine_warmup_timm\n\n# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings\n# name                                num_seqs        num_classes     median len    std\n# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4  \n# demo_coding_vs_intergenomic_seqs    100_000         2               200           0\n# demo_human_or_worm                  100_000         2               200           0\n# human_enhancers_cohn                27791           2               500           0\n# human_enhancers_ensembl             154842          2               269           122.6\n# human_ensembl_regulatory            289061          3               401           184.3\n# human_nontata_promoters             36131           2               251           0\n# human_ocr_ensembl                   174756          2               315           108.1\n\ntask:\n  loss:\n    _name_: cross_entropy\n\ntrainer:\n  accelerator: gpu\n  devices: 1\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}\n  max_epochs: 100\n  precision: 16  # bf16 only a100\n  gradient_clip_val: 1.0\n\nmodel:\n  _name_: dna_embedding\n\ndataset:\n  # optional, default is max_length\n  tokenizer_name: char\n  rc_aug: false  # reverse complement augmentation\n\nscheduler:\n# COSINE TIMM\n  t_in_epochs: False\n  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}\n  warmup_lr_init: 1e-6\n  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}\n  lr_min: ${eval:0.1 * ${optimizer.lr}}\n\n\noptimizer:\n  lr: 6e-4\n  weight_decay: 0.1\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  seed: 2222\n  global_batch_size: ${dataset.batch_size}\n  cross_validation: true\n  remove_test_loader_in_eval: true  # test only at the end of training\n  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it\n  # for loading backbone and not head, requires both of these flags below\n  pretrained_model_path: ???\n  pretrained_model_state_hook:\n    _name_: load_backbone\n    freeze_backbone: false\n"
  },
  {
    "path": "configs/experiment/hg38/genomic_benchmark_cnn.yaml",
    "content": "# @package _global_\ndefaults:\n  - /model: genomics_benchmark_cnn\n  - /pipeline: genomic_benchmark\n  - override /scheduler: cosine_warmup_timm\n\n# there are 8 datasets in this suite, choose 1 at a time, with their corresponding settings\n# name                                num_seqs        num_classes     median len    std\n# dummy_mouse_enhancers_ensembl       1210            2               2381          984.4\n# demo_coding_vs_intergenomic_seqs    100_000         2               200           0\n# demo_human_or_worm                  100_000         2               200           0\n# human_enhancers_cohn                27791           2               500           0\n# human_enhancers_ensembl             154842          2               269           122.6\n# human_ensembl_regulatory            289061          3               401           184.3\n# human_nontata_promoters             36131           2               251           0\n# human_ocr_ensembl                   174756          2               315           108.1\n\ntask:\n  loss:\n    _name_: cross_entropy\n\ntrainer:\n  accelerator: gpu\n  devices: 1\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}\n  max_epochs: 100\n  precision: 16  # bf16 only a100\n  gradient_clip_val: 1.0\n\nencoder: id\ndecoder: id\n\ndataset:\n  tokenizer_name: char\n  rc_aug: false  # reverse complement augmentation\n\nscheduler:\n  t_in_epochs: False\n  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}\n  warmup_lr_init: 1e-6\n  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}\n  lr_min: ${eval:0.1 * ${optimizer.lr}}\n\n\noptimizer:\n  lr: 6e-4\n  weight_decay: 0.1\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  seed: 2222\n  global_batch_size: ${dataset.batch_size}\n  cross_validation: true\n  remove_test_loader_in_eval: true\n  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it\n"
  },
  {
    "path": "configs/experiment/hg38/hg38.yaml",
    "content": "# @package _global_\ndefaults:\n  - /pipeline: hg38\n  - /model: ???  # Specify a model, e.g. model=mamba or model=hyena\n  - override /scheduler: cosine_warmup_timm\n\ntask:\n  _name_: lm\n  loss:\n    _name_: cross_entropy\n    ignore_index: 4\n\ntrainer:\n  accelerator: gpu\n  devices: 1\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}\n  max_epochs: null\n  max_steps: 10000\n  precision: 16  # bf16 only a100\n  gradient_clip_val: 1.0\n  limit_val_batches: 0.125\n\ndataset:\n  batch_size: ${eval:1024//${trainer.devices}}\n  max_length: 1024\n  # optional, default is max_length\n  max_length_val: ${dataset.max_length}\n  max_length_test: ${dataset.max_length}\n  tokenizer_name: char\n  pad_max_length: null  # needed for bpe tokenizer\n  add_eos: true\n  rc_aug: false\n  num_workers: 12\n  use_fixed_len_val: false  # placing a fixed length val here, but it's really the test\n  mlm: false\n  mlm_probability: 0.0\n\nscheduler:\n  t_in_epochs: False\n  t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}\n  warmup_prefix: True\n  warmup_lr_init: 1e-6\n  warmup_t: ${eval:0.1*${trainer.max_steps}}\n  lr_min: 1e-4\n\noptimizer:\n  lr: 6e-4\n  weight_decay: 0.1\n  betas: [0.9, 0.95]\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  seed: 2222\n  global_batch_size: 256  # effects the scheduler, need to set properly\n"
  },
  {
    "path": "configs/experiment/hg38/nucleotide_transformer.yaml",
    "content": "# @package _global_\ndefaults:\n  - /pipeline: nucleotide_transformer\n  - /model: ???\n  - override /scheduler: cosine_warmup_timm\n\nmodel:\n  _name_: dna_embedding\n\ntrainer:\n  accelerator: gpu\n  devices: 1\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${dataset.batch_size} * ${trainer.num_nodes}}}\n  max_epochs: 100\n  precision: 16  # bf16 only a100\n  gradient_clip_val: 1.0\n\ndataset:\n  tokenizer_name: char\n  rc_aug: false  # reverse complement augmentation\n\nscheduler:\n  t_in_epochs: False\n  t_initial: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs}}\n  warmup_lr_init: 1e-6\n  warmup_t: ${eval:${div_up:${dataset.train_len}, ${train.global_batch_size}} * ${trainer.max_epochs} * 0.01}\n  lr_min: ${eval:0.1 * ${optimizer.lr}}\n\noptimizer:\n  lr: 1e-3\n  weight_decay: 0.1\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  seed: 2222\n  global_batch_size: ${dataset.batch_size}\n  cross_validation: true\n  remove_test_loader_in_eval: true  # test only at the end of training\n  pretrained_model_strict_load: false  # false allows encoder/decoder to be used if new model uses it\n  # for loading backbone and not head, requires both of these flags below\n  pretrained_model_path: ???\n  pretrained_model_state_hook:\n    _name_: load_backbone\n    freeze_backbone: false\n"
  },
  {
    "path": "configs/loader/default.yaml",
    "content": "num_workers: ${eval:\"len(__import__('os').sched_getaffinity(0))\"}\npin_memory: True\ndrop_last: True\n"
  },
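The `num_workers` expression resolves to the number of CPUs this process is actually allowed to use (its scheduling affinity), which plays nicely with SLURM CPU allocations; in plain Python (Linux-only API):

```python
import os

# What ${eval:"len(__import__('os').sched_getaffinity(0))"} evaluates to:
# the set of CPU ids the current process may run on.
num_workers = len(os.sched_getaffinity(0))
print(num_workers)
```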
  {
    "path": "configs/model/caduceus.yaml",
    "content": "# Use open-source version of Mamba\n_name_: caduceus_lm\nconfig:\n  _target_: caduceus.configuration_caduceus.CaduceusConfig\n  # From original MambaConfig\n  d_model: 128\n  n_layer: 2\n  vocab_size: 12\n  ssm_cfg:\n    d_state: 16\n    d_conv: 4\n    expand: 2\n    dt_rank: \"auto\"\n    dt_min: 0.001\n    dt_max: 0.1\n    dt_init: \"random\"\n    dt_scale: 1.0\n    dt_init_floor: 1e-4\n    conv_bias: true\n    bias: false\n    use_fast_path: true\n  rms_norm: true\n  fused_add_norm: true\n  residual_in_fp32: false\n  pad_vocab_size_multiple: 8\n  # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm\n  norm_epsilon: 1e-5\n\n  # Used in init_weights\n  initializer_cfg:\n    initializer_range: 0.02\n    rescale_prenorm_residual: true\n    n_residuals_per_layer: 1\n\n  # Caduceus-specific params\n  bidirectional: true,\n  bidirectional_strategy: \"add\"\n  bidirectional_weight_tie: true\n  rcps: false\n\n  # Used for RCPSEmbedding / RCPSLMHead (will be filled in during model instantiation using info from tokenizer)\n  complement_map: null\n"
  },
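This config is for pretraining Caduceus from scratch. For downstream use, the pretrained checkpoints referenced later in `slurm_scripts/dump_vep_embeddings.sh` can also be pulled from the Hugging Face Hub; a sketch of that route (the exact model-card usage may differ slightly, and `trust_remote_code=True` is assumed to be required because the architecture ships with the checkpoint):

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

input_ids = tokenizer("ACGTACGT", return_tensors="pt")["input_ids"]
logits = model(input_ids=input_ids).logits  # masked-LM logits over the DNA character vocabulary
print(logits.shape)
```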
  {
    "path": "configs/model/genomics_benchmark_cnn.yaml",
    "content": "# Use open-source version of Mamba\n_name_: genomics_benchmark_cnn\nnumber_of_classes: ${dataset.d_output}\nvocab_size: 12\nembedding_dim: 100  # See: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments\ninput_len: ${dataset.__l_max}\n"
  },
  {
    "path": "configs/model/hyena.yaml",
    "content": "_name_: hyena_lm\nd_model: 128\nn_layer: 2\nd_inner: ${eval:4 * ${.d_model}}\nvocab_size: 12\nresid_dropout: 0.0\nembed_dropout: 0.1\nfused_mlp: False\nfused_dropout_add_ln: False\ncheckpoint_mixer: False  # set true for memory reduction\ncheckpoint_mlp: False  # set true for memory reduction\nresidual_in_fp32: True\npad_vocab_size_multiple: 8\nlayer:\n  _name_: hyena\n  emb_dim: 5\n  filter_order: 64\n  local_order: 3\n  l_max: ${eval:${dataset.max_length}+2}\n  modulate: True\n  w: 10\n  lr: ${optimizer.lr}\n  wd: 0.0\n  lr_pos_emb: 0.0\n"
  },
  {
    "path": "configs/model/layer/hyena.yaml",
    "content": "_name_: hyena\nl_max: 1024\norder: 2\nfilter_order: 64\nnum_heads: 1\ninner_factor: 1\nnum_blocks: 1\nfused_bias_fc: false\nouter_mixing: false\ndropout: 0.0 \nfilter_dropout: 0.0\nfilter_cls: 'hyena-filter'\npost_order_ffn: false\njit_filter: false\nshort_filter_order: 3\nactivation: \"id\""
  },
  {
    "path": "configs/model/mamba.yaml",
    "content": "# Use open-source version of Mamba\n_name_: mamba_lm\nconfig:\n  _target_: mamba_ssm.models.config_mamba.MambaConfig\n  d_model: 128  # Will be overwritten by CL in the scaling exps\n  n_layer: 2  # Will be overwritten by CL in the scaling exps\n  vocab_size: 12\n  pad_vocab_size_multiple: 8\n  rms_norm: true\n  fused_add_norm: true\n  residual_in_fp32: false\n  ssm_cfg:\n    d_state: 16\n    d_conv: 4\n    expand: 2\n    dt_rank: \"auto\"\n    dt_min: 0.001\n    dt_max: 0.1\n    dt_init: \"random\"\n    dt_scale: 1.0\n    dt_init_floor: 1e-4\n    conv_bias: true\n    bias: false\n    use_fast_path: true\ninitializer_cfg:\n  initializer_range: 0.02\n  rescale_prenorm_residual: true\n  n_residuals_per_layer: 1\n#norm_epsilon: 1e-5  # Default arg in mamba create_block\n"
  },
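Since the `config` block carries a `_target_`, it can be instantiated directly with Hydra and handed to the open-source Mamba LM head. A rough sketch, assuming a CUDA device and the `mamba_ssm` package are available (the repo's `mamba_lm` wrapper may do this differently):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

cfg = OmegaConf.load("configs/model/mamba.yaml")
mamba_config = instantiate(cfg.config, _convert_="all")   # -> mamba_ssm MambaConfig
model = MambaLMHeadModel(mamba_config).cuda()

tokens = torch.randint(0, 12, (1, 1024), device="cuda")   # (batch, seq_len) of token ids
logits = model(tokens).logits                              # (1, 1024, vocab padded to a multiple of 8)
print(logits.shape)
```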
  {
    "path": "configs/optimizer/adam.yaml",
    "content": "# _target_: torch.optim.Adam\n_name_: adam\nlr: 0.001  # Initial learning rate\n# weight_decay: 0.0  # Weight decay for adam|lamb; should use AdamW instead if desired\nbetas: [0.9, 0.999]\n"
  },
  {
    "path": "configs/optimizer/adamw.yaml",
    "content": "# _target_: torch.optim.AdamW\n_name_: adamw\nlr: 0.001 # Initial learning rate\nweight_decay: 0.00 # Weight decay\nbetas: [0.9, 0.999]\n"
  },
  {
    "path": "configs/optimizer/sgd.yaml",
    "content": "# _target_: torch.optim.SGD\n_name_: sgd\nlr: 0.001  # Initial learning rate\nmomentum: 0.9\nweight_decay: 0.0  # Weight decay for adam|lamb\n"
  },
  {
    "path": "configs/pipeline/genomic_benchmark.yaml",
    "content": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: default\n  - /dataset: genomic_benchmark\n  - /task: multiclass_classification\n  - /optimizer: adamw\n  - /scheduler: plateau\n  - /callbacks: [base, checkpoint]\n\ntrain:\n  monitor: val/accuracy # Needed for plateau scheduler\n  mode: max\n\nencoder: id\n\n# we need this for classification!\ndecoder:\n  _name_: sequence\n  mode: pool\n"
  },
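The `sequence` decoder in `mode: pool` collapses per-token backbone features into one vector before classification. Conceptually it behaves like the sketch below (illustrative only; the repo's actual decoder lives in its task/decoder modules):

```python
import torch
import torch.nn as nn

class PooledClassifier(nn.Module):
    """Mean-pool hidden states over the length dimension, then project to class logits."""
    def __init__(self, d_model: int, n_classes: int):
        super().__init__()
        self.proj = nn.Linear(d_model, n_classes)

    def forward(self, hidden, mask=None):        # hidden: (batch, seq_len, d_model)
        if mask is not None:                     # mask: (batch, seq_len) booleans
            hidden = hidden * mask.unsqueeze(-1)
            pooled = hidden.sum(1) / mask.sum(1, keepdim=True)
        else:
            pooled = hidden.mean(1)
        return self.proj(pooled)                 # (batch, n_classes)

logits = PooledClassifier(d_model=128, n_classes=2)(torch.randn(4, 500, 128))
```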
  {
    "path": "configs/pipeline/hg38.yaml",
    "content": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: null\n  - /dataset: hg38\n  - /optimizer: adamw\n  - /scheduler: cosine_warmup\n  - /callbacks: [base, checkpoint]\n\ntrain:\n  monitor: test/loss\n  mode: min\n\ntask:\n  _name_: lm\n  loss:\n    _name_: cross_entropy\n    ignore_index: 4  # Bake in tokenizer value for padding / EOS tokens\n  torchmetrics: ['perplexity', 'num_tokens']\n\nencoder: null\ndecoder: null\n\nloader:\n  num_workers: ${eval:\"len(__import__('os').sched_getaffinity(0))\"}\n  pin_memory: True\n  drop_last: True  # There's enough data and epochs, ignore the edge case\n  # shuffle: True\n"
  },
  {
    "path": "configs/pipeline/nucleotide_transformer.yaml",
    "content": "# @package _global_\ndefaults:\n  - /trainer: default\n  - /loader: default\n  - /dataset: nucleotide_transformer\n  - /task: multiclass_classification\n  - /optimizer: adamw\n  - /scheduler: plateau\n  - /callbacks: [base, checkpoint]\n\ntask:\n  loss:\n    _name_: cross_entropy\n  metrics:\n    - ${dataset.metric}\n\ntrain:\n  monitor: val/${dataset.metric}\n  mode: max\n\nencoder: id\n\n# we need this for classification!\ndecoder:\n  _name_: sequence\n  mode: pool"
  },
  {
    "path": "configs/scheduler/constant.yaml",
    "content": "# @package _global_\ntrain:\n  interval: epoch\nscheduler:\n  # _target_: transformers.get_constant_schedule\n  _name_: constant\n"
  },
  {
    "path": "configs/scheduler/constant_warmup.yaml",
    "content": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_constant_schedule_with_warmup\n  _name_: constant_warmup\n  num_warmup_steps: 1000  # Number of iterations for LR warmup\n"
  },
  {
    "path": "configs/scheduler/cosine.yaml",
    "content": "# @package _global_\ntrain:\n  interval: epoch\nscheduler:\n  # _target_: torch.optim.lr_scheduler.CosineAnnealingLR\n  _name_: cosine\n  T_max: 100  # Max number of epochs steps for LR scheduler\n  eta_min: 1e-6  # Min learning rate for cosine scheduler\n"
  },
  {
    "path": "configs/scheduler/cosine_warmup.yaml",
    "content": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_cosine_schedule_with_warmup\n  _name_: cosine_warmup\n  num_warmup_steps: 1000\n  num_training_steps: 40000\n"
  },
  {
    "path": "configs/scheduler/cosine_warmup_timm.yaml",
    "content": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_cosine_schedule_with_warmup\n  _name_: cosine_warmup_timm\n  t_in_epochs: False\n  t_initial: 300\n  lr_min: 1e-5\n  warmup_lr_init: 1e-6\n  warmup_t: 10\n"
  },
  {
    "path": "configs/scheduler/linear_warmup.yaml",
    "content": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: transformers.get_linear_schedule_with_warmup\n  _name_: linear_warmup\n  num_warmup_steps: 1000\n  num_training_steps: 40000\n"
  },
  {
    "path": "configs/scheduler/multistep.yaml",
    "content": "# @package _global_\ntrain:\n  interval: epoch\n# _target_: torch.optim.lr_scheduler.MultiStepLR\nscheduler:\n  _name_: multistep\n  milestones: [80,140,180]\n  gamma: 0.2\n"
  },
  {
    "path": "configs/scheduler/plateau.yaml",
    "content": "# @package _global_\ntrain:\n  interval: epoch\n  monitor: ??? # must be specified\nscheduler:\n  # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau\n  _name_: plateau\n  mode: ${train.mode} # Which metric to monitor\n  factor: 0.2  # Decay factor when ReduceLROnPlateau is used\n  patience: 20\n  min_lr: 0.0  # Minimum learning rate during annealing\n"
  },
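These fields map onto `torch.optim.lr_scheduler.ReduceLROnPlateau`, stepped once per epoch with the monitored metric (`train.monitor`); a minimal stand-alone sketch:

```python
import torch

model = torch.nn.Linear(8, 2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="max", factor=0.2, patience=20, min_lr=0.0
)

for epoch in range(100):
    val_accuracy = 0.5        # placeholder for the monitored metric, e.g. val/accuracy
    sched.step(val_accuracy)  # LR multiplied by `factor` after `patience` epochs without improvement
```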
  {
    "path": "configs/scheduler/step.yaml",
    "content": "# @package _global_\ntrain:\n  interval: step\nscheduler:\n  # _target_: torch.optim.lr_scheduler.StepLR\n  _name_: step\n  step_size: 1\n  gamma: 0.99\n"
  },
  {
    "path": "configs/task/lm.yaml",
    "content": "_name_: lm\n# loss: cross_entropy # Handled by task: cross entropy loss\nmetrics: ppl\n"
  },
  {
    "path": "configs/task/multiclass_classification.yaml",
    "content": "# _target_: tasks.tasks.MultiClass\n_name_: multiclass\nloss: cross_entropy\nmetrics:\n  - accuracy\ntorchmetrics: null\n"
  },
  {
    "path": "configs/task/multilabel_classification.yaml",
    "content": "# _target_:\n_name_: base\nloss: binary_cross_entropy\nmetrics: null\ntorchmetrics:\n  - MultilabelAUROC # AUROC\n  - MultilabelAveragePrecision # Precision\n#  - Recall # not supported in torchmetrics\n#  - F1 # not supported in torchmetrics\n"
  },
  {
    "path": "configs/task/regression.yaml",
    "content": "# _target_: tasks.tasks.BaseTask\n_name_: base\nloss: mse\nmetrics: mse\ntorchmetrics: null\n"
  },
  {
    "path": "configs/trainer/debug.yaml",
    "content": "defaults:\n  - default\n\ngpus: 1\nmin_epochs: 1\nmax_epochs: 10\n\n# prints\nprogress_bar_refresh_rate: null\nweights_summary: full\nprofiler: null\n\n# debugs\nfast_dev_run: False\nnum_sanity_val_steps: 2\noverfit_batches: 0\nlimit_train_batches: 0.1\nlimit_val_batches: 0.1\nlimit_test_batches: 0.1\ntrack_grad_norm: -1\nterminate_on_nan: False\n"
  },
  {
    "path": "configs/trainer/default.yaml",
    "content": "_target_: pytorch_lightning.Trainer\n\ndevices: 1\naccelerator: gpu\naccumulate_grad_batches: 1 # Gradient accumulation every n batches\nmax_epochs: 200\n                           # accelerator: ddp # Automatically set if gpus > 1\ngradient_clip_val: 0.0\nlog_every_n_steps: 10\nlimit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run\nlimit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run\nnum_sanity_val_steps: 2    # default value: 2; override to 0 to skip sanity checking\n"
  },
  {
    "path": "configs/trainer/full.yaml",
    "content": "_target_: pytorch_lightning.Trainer\n\n# default values for all trainer parameters\ncheckpoint_callback: True\ndefault_root_dir: null\ngradient_clip_val: 0.0\nprocess_position: 0\nnum_nodes: 1\nnum_processes: 1\ngpus: null\nauto_select_gpus: False\ntpu_cores: null\nlog_gpu_memory: null\noverfit_batches: 0.0\ntrack_grad_norm: -1\ncheck_val_every_n_epoch: 1\nfast_dev_run: False\naccumulate_grad_batches: 1\nmax_epochs: 1\nmin_epochs: 1\nmax_steps: null\nmin_steps: null\nlimit_train_batches: 1.0\nlimit_val_batches: 1.0\nlimit_test_batches: 1.0\nval_check_interval: 1.0\nflush_logs_every_n_steps: 100\nlog_every_n_steps: 50\naccelerator: null\nsync_batchnorm: False\nprecision: 32\nweights_summary: \"top\"\nweights_save_path: null\nnum_sanity_val_steps: 2\ntruncated_bptt_steps: null\nresume_from_checkpoint: null\nprofiler: null\nbenchmark: False\ndeterministic: False\nreload_dataloaders_every_epoch: False\nauto_lr_find: False\nreplace_sampler_ddp: True\nterminate_on_nan: False\nauto_scale_batch_size: False\nprepare_data_per_node: True\nplugins: null\namp_backend: \"native\"\namp_level: \"O2\"\nmove_metrics_to_cpu: False\n"
  },
  {
    "path": "configs/trainer/lm.yaml",
    "content": "accumulate_grad_batches: 1\n# accelerator: null # set to 'ddp' for distributed\n# amp_backend: native # 'native' | 'apex'\ngpus: 8\nmax_epochs: 50\ngradient_clip_val: 0.0 # Gradient clipping\nlog_every_n_steps: 10\nprecision: 16\nprogress_bar_refresh_rate: 1\nweights_summary: top # Set to 'full' to see every layer\ntrack_grad_norm: -1 # Set to 2 to track norms of gradients\nlimit_train_batches: 1.0\nlimit_val_batches: 1.0\n# We use the dataloader from Transformer-XL to ensure adjacent minibatches\n# are from text that are next to each other.\n# So that dataloader has to deal with DDP, and we don't want PL to handle\n# that.\nreplace_sampler_ddp: False\n"
  },
  {
    "path": "setup_env.sh",
    "content": "#!/bin/bash\n\n# Shell script to set environment variables when running code in this repository.\n# Usage:\n#     source setup_env.sh\n\n# Activate conda env\n# shellcheck source=${HOME}/.bashrc disable=SC1091\nsource \"${CONDA_SHELL}\"\nif [ -z \"${CONDA_PREFIX}\" ]; then\n    conda activate caduceus_env\n elif [[ \"${CONDA_PREFIX}\" != *\"/caduceus_env\" ]]; then\n  conda deactivate\n  conda activate caduceus_env\nfi\n\n# Add root directory to PYTHONPATH to enable module imports\nexport PYTHONPATH=\"${PWD}\"\n"
  },
  {
    "path": "slurm_scripts/dump_vep_embeddings.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)\n#SBATCH --mem=100G                          # RAM\n#SBATCH --gres=gpu:8                        # Number of GPUs\n#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)\n##SBATCH --cpus-per-task=4                  # Number of CPU cores per task\n#SBATCH -N 1                                # Number of nodes\n#SBATCH --requeue                           # Requeue job if it fails\n#SBATCH --job-name=vep_embed                # Job name\n#SBATCH--output=../watch_folder/%x_%j.log   # Output file name\n#SBATCH --open-mode=append                  # Do not overwrite logs\n\nNUM_WORKERS=2\nNUM_DEVICES=8\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\nexport CUDA_LAUNCH_BLOCKING=1\nexport CUBLAS_WORKSPACE_CONFIG=:4096:8  # Needed for setting deterministic functions for reproducibility\n\n#####################################################################################\n# Choose from one of the following:\n\n## Enformer\n#seq_len=196608\n#bp_per_token=1\n#embed_dump_batch_size=1\n#model_name_or_path=\"EleutherAI/enformer-official-rough\"\n#name=\"enformer-seqlen=196k\"\n#rcps_flag=\"no-rcps\"\n\n## NTv2\n#seq_len=12288  # 2048 (seq len) * 6 (kmers)\n#bp_per_token=6\n#embed_dump_batch_size=1\n#model_name_or_path=\"InstaDeepAI/nucleotide-transformer-v2-500m-multi-species\"\n#name=\"NTv2_downstream-seqlen=12k\"\n#rcps_flag=\"no-rcps\"\n\n## Hyena\n#seq_len=131072\n#bp_per_token=1\n#embed_dump_batch_size=1\n#model_name_or_path=\"LongSafari/hyenadna-medium-160k-seqlen-hf\"\n#name=\"hyena_downstream-seqlen=131k\"\n#rcps_flag=\"no-rcps\"\n\n## Caduceus-Ph\n#seq_len=131072\n#bp_per_token=1\n#embed_dump_batch_size=1\n#model_name_or_path=\"kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16\"\n#name=\"caduceus-ph_downstream-seqlen=131k\"\n#rcps_flag=\"no-rcps\"\n\n## Caduceus-PS\n#seq_len=131072\n#bp_per_token=1\n#embed_dump_batch_size=1\n#model_name_or_path=\"kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16\"\n#name=\"caduceus-ps_downstream-seqlen=131k\"\n#rcps_flag=\"rcps\"\n#####################################################################################\n\ntorchrun \\\n    --standalone \\\n    --nnodes=1 \\\n    --nproc-per-node=${NUM_DEVICES} \\\n    vep_embeddings.py \\\n      --num_workers=${NUM_WORKERS} \\\n      --seq_len=${seq_len}  \\\n      --bp_per_token=${bp_per_token}  \\\n      --embed_dump_batch_size=${embed_dump_batch_size} \\\n      --name=\"${name}\"  \\\n      --model_name_or_path=\"${model_name_or_path}\" \\\n      --\"${rcps_flag}\"\n"
  },
  {
    "path": "slurm_scripts/run_genomics_benchmark.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 96:00:00                      # Time limit (hh:mm:ss)\n#SBATCH --mem=64000M                     # RAM\n#SBATCH --gres=gpu:1                     # Number of GPUs\n#SBATCH --ntasks-per-node=1\n#SBATCH --cpus-per-task=2\n#SBATCH -N 1                             # Number of nodes\n#SBATCH --requeue                        # Requeue job if it fails\n#SBATCH --open-mode=append               # Do not overwrite logs\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\n\n# Expected args:\n# - CONFIG_PATH\n# - PRETRAINED_PATH\n# - DISPLAY_NAME\n# - MODEL\n# - MODEL_NAME\n# - CONJOIN_TRAIN_DECODER\n# - CONJOIN_TEST\n# - TASK\n# - LR\n# - BATCH_SIZE\n# - RC_AUG\n\n\n# Run script\n# shellcheck disable=SC2154\nWANDB_NAME=\"${DISPLAY_NAME}_lr-${LR}_batch_size-${BATCH_SIZE}_rc_aug-${RC_AUG}\"\nfor seed in $(seq 1 5); do\n  # shellcheck disable=SC2154\n  HYDRA_RUN_DIR=\"./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}\"\n  mkdir -p \"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\n  echo \"Running GenomicsBenchmark model: ${DISPLAY_NAME}, task: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, rc_aug: ${RC_AUG}, SEED: ${seed}\"\n  # shellcheck disable=SC2086\n  python -m train \\\n    experiment=hg38/genomic_benchmark \\\n    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \\\n    dataset.dataset_name=\"${TASK}\" \\\n    dataset.train_val_split_seed=${seed} \\\n    dataset.batch_size=${BATCH_SIZE} \\\n    dataset.rc_aug=\"${RC_AUG}\" \\\n    +dataset.conjoin_train=false \\\n    +dataset.conjoin_test=\"${CONJOIN_TEST}\" \\\n    model=\"${MODEL}\" \\\n    model._name_=\"${MODEL_NAME}\" \\\n    +model.config_path=\"${CONFIG_PATH}\" \\\n    +model.conjoin_test=\"${CONJOIN_TEST}\" \\\n    +decoder.conjoin_train=\"${CONJOIN_TRAIN_DECODER}\" \\\n    +decoder.conjoin_test=\"${CONJOIN_TEST}\" \\\n    optimizer.lr=\"${LR}\" \\\n    trainer.max_epochs=10 \\\n    train.pretrained_model_path=\"${PRETRAINED_PATH}\" \\\n    wandb.group=\"downstream/gb_cv5\" \\\n    wandb.job_type=\"${TASK}\" \\\n    wandb.name=\"${WANDB_NAME}\" \\\n    wandb.id=\"gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}\" \\\n    +wandb.tags=\\[\"seed-${seed}\"\\] \\\n    hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\ndone\n"
  },
  {
    "path": "slurm_scripts/run_genomics_benchmark_cnn.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 48:00:00                      # Time limit (hh:mm:ss)\n#SBATCH --mem=64G                     # RAM\n#SBATCH --gres=gpu:1                     # Number of GPUs\n#SBATCH --ntasks-per-node=1\n#SBATCH --cpus-per-task=2\n#SBATCH -N 1                             # Number of nodes\n#SBATCH --requeue                        # Requeue job if it fails\n#SBATCH --open-mode=append               # Do not overwrite logs\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\n\n# Expected args:\n# - TASK\n# - RC_AUG\n\n\n# LR: 1e-3 -- in https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks, Adam optimizer is used with default lr=1e-3\nLR=\"1e-3\"\n# Batch size: 64 -- See https://arxiv.org/abs/2306.15794 and https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks\nBATCH_SIZE=64\n\n# Run script\nWANDB_NAME=\"CNN-LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}\"\nfor seed in $(seq 1 5); do\n  HYDRA_RUN_DIR=\"./outputs/downstream/gb_cv5/${TASK}/${WANDB_NAME}/seed-${seed}\"\n  mkdir -p \"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\n  echo \"Running GenomicsBenchmark TASK: ${TASK}, lr: ${LR}, batch_size: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}\"\n  python -m train \\\n    experiment=hg38/genomic_benchmark_cnn \\\n    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \\\n    dataset.dataset_name=\"${TASK}\" \\\n    dataset.train_val_split_seed=${seed} \\\n    dataset.batch_size=${BATCH_SIZE} \\\n    dataset.rc_aug=\"${RC_AUG}\" \\\n    optimizer.lr=\"${LR}\" \\\n    trainer.max_epochs=10 \\\n    wandb.group=\"downstream/gb_cv5\" \\\n    wandb.job_type=\"${TASK}\" \\\n    wandb.name=\"${WANDB_NAME}\" \\\n    wandb.id=\"gb_cv5_${TASK}_${WANDB_NAME}_seed-${seed}\" \\\n    +wandb.tags=\\[\"seed-${seed}\"\\] \\\n    hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\ndone\n"
  },
  {
    "path": "slurm_scripts/run_nucleotide_transformer.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                   # Retrieve the users login environment\n#SBATCH -t 96:00:00                      # Time limit (hh:mm:ss)\n#SBATCH --mem=64G                        # RAM\n#SBATCH --gres=gpu:2                     # Number of GPUs\n#SBATCH --ntasks-per-node=2\n#SBATCH --cpus-per-task=4\n#SBATCH -N 1                             # Number of nodes\n#SBATCH --requeue                        # Requeue job if it fails\n#SBATCH --open-mode=append               # Do not overwrite logs\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\nexport HYDRA_FULL_ERROR=1\n\n# Expected args:\n# - CONFIG_PATH\n# - PRETRAINED_PATH\n# - DISPLAY_NAME\n# - MODEL\n# - MODEL_NAME\n# - CONJOIN_TRAIN_DECODER\n# - CONJOIN_TEST\n# - TASK\n# - LR\n# - BATCH_SIZE\n# - RC_AUG\n\n# Run script\nWANDB_NAME=\"${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}\"\nfor seed in $(seq 1 10); do\n  HYDRA_RUN_DIR=\"./outputs/downstream/nt_cv10_ep20/${TASK}/${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}/seed-${seed}\"\n  mkdir -p \"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\n  echo \"Running NT model: ${DISPLAY_NAME}, TASK: ${TASK}, LR: ${LR}, BATCH_SIZE: ${BATCH_SIZE}, RC_AUG: ${RC_AUG}, SEED: ${seed}\"\n  python -m train \\\n    experiment=hg38/nucleotide_transformer \\\n    callbacks.model_checkpoint_every_n_steps.every_n_train_steps=5000 \\\n    dataset.dataset_name=\"${TASK}\" \\\n    dataset.train_val_split_seed=${seed} \\\n    dataset.batch_size=${BATCH_SIZE} \\\n    dataset.rc_aug=\"${RC_AUG}\" \\\n    +dataset.conjoin_test=\"${CONJOIN_TEST}\" \\\n    model=\"${MODEL}\" \\\n    model._name_=\"${MODEL_NAME}\" \\\n    +model.config_path=\"${CONFIG_PATH}\" \\\n    +model.conjoin_test=\"${CONJOIN_TEST}\" \\\n    +decoder.conjoin_train=\"${CONJOIN_TRAIN_DECODER}\" \\\n    +decoder.conjoin_test=\"${CONJOIN_TEST}\" \\\n    optimizer.lr=\"${LR}\" \\\n    train.pretrained_model_path=\"${PRETRAINED_PATH}\" \\\n    trainer.max_epochs=20 \\\n    wandb.group=\"downstream/nt_cv10_ep20\" \\\n    wandb.job_type=\"${TASK}\" \\\n    wandb.name=\"${WANDB_NAME}\" \\\n    wandb.id=\"nt_cv10_ep-20_${TASK}_${WANDB_NAME}_seed-${seed}\" \\\n    +wandb.tags=\\[\"seed-${seed}\"\\] \\\n    hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n  echo \"*****************************************************\"\ndone\n"
  },
  {
    "path": "slurm_scripts/run_pretrain_caduceus.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)\n#SBATCH --mem=100G                          # RAM\n#SBATCH --gres=gpu:8                        # Number of GPUs\n#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)\n##SBATCH --cpus-per-task=4                  # Number of CPU cores per task\n#SBATCH -N 1                                # Number of nodes\n#SBATCH --requeue                           # Requeue job if it fails\n#SBATCH --job-name=caduceus_ps              # Job name\n#SBATCH --output=../watch_folder/%x_%j.log  # Log file\n#SBATCH --open-mode=append                  # Do not overwrite logs\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\nexport HYDRA_FULL_ERROR=1\n\nNUM_DEVICES=8\n\n# Run script\nSEQLEN=131072\nMAX_STEPS=50000\nD_MODEL=256\nN_LAYER=8\nLR=\"8e-3\"\nBIDIRECTIONAL_STRATEGY=\"add\"\nBIDIRECTIONAL_WEIGHT_TIE=\"true\"\nRCPS=\"true\"\nRC_AUG=\"false\"\n\nBATCH_SIZE=$(( 1048576 / SEQLEN ))\nSEQLEN_DIS=\"$(echo \"scale=0; ${SEQLEN} / 1000\" | bc)k\"\nWANDB_NAME=\"caduceus-ps_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}\"\nHYDRA_RUN_DIR=\"./outputs/pretrain/hg38/${WANDB_NAME}\"\n\nmkdir -p \"${HYDRA_RUN_DIR}\"\nsrun python -m train \\\n  experiment=hg38/hg38 \\\n  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \\\n  dataset.max_length=${SEQLEN} \\\n  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \\\n  dataset.mlm=true \\\n  dataset.mlm_probability=0.15 \\\n  dataset.rc_aug=\"${RC_AUG}\" \\\n  model=\"caduceus\" \\\n  model.config.d_model=${D_MODEL} \\\n  model.config.n_layer=${N_LAYER} \\\n  model.config.bidirectional=true \\\n  model.config.bidirectional_strategy=${BIDIRECTIONAL_STRATEGY} \\\n  model.config.bidirectional_weight_tie=${BIDIRECTIONAL_WEIGHT_TIE} \\\n  model.config.rcps=${RCPS} \\\n  optimizer.lr=\"${LR}\" \\\n  train.global_batch_size=${BATCH_SIZE} \\\n  trainer.max_steps=${MAX_STEPS} \\\n  trainer.devices=${NUM_DEVICES} \\\n  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \\\n  wandb.group=pretrain_hg38 \\\n  wandb.name=\"${WANDB_NAME}\" \\\n  hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n"
  },
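The batch-size arithmetic in this script keeps roughly 2^20 (~1M) tokens per optimizer step regardless of sequence length; spelled out with the defaults above:

```python
SEQLEN = 131_072
NUM_DEVICES = 8
MAX_STEPS = 50_000

global_batch = 1_048_576 // SEQLEN               # 8 sequences per step, i.e. ~1M tokens/step
per_device_batch = global_batch // NUM_DEVICES   # 1 sequence per GPU (dataset.batch_size)
val_check_interval = MAX_STEPS // 5              # validate 5 times over the run -> every 10,000 steps
print(global_batch, per_device_batch, val_check_interval)   # 8 1 10000
```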
  {
    "path": "slurm_scripts/run_pretrain_hyena.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)\n#SBATCH --mem=100G                          # RAM\n#SBATCH --gres=gpu:8                        # Number of GPUs\n#SBATCH --ntasks-per-node=8\n#SBATCH --cpus-per-task=4                   # Number of CPU cores per task\n#SBATCH -N 1                                # Number of nodes\n#SBATCH --requeue                           # Requeue job if it fails\n#SBATCH --job-name=hyena                    # Job name\n#SBATCH --output=../watch_folder/%x_%j.log  # Log file\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\n\nNUM_DEVICES=8\n\n# Run script\nSEQLEN=1024\nMAX_STEPS=10000\nD_MODEL=256\nN_LAYER=4\nLR=\"6e-4\"\nRC_AUG=\"true\"\n\nBATCH_SIZE=$(( 1048576 / SEQLEN ))\nSEQLEN_DIS=\"$(echo \"scale=0; ${SEQLEN} / 1000\" | bc)k\"\nWANDB_NAME=\"hyena_rc_aug_seqlen-${SEQLEN_DIS}_dmodel-${D_MODEL}_nlayer-${N_LAYER}_lr-${LR}\"\nHYDRA_RUN_DIR=\"./outputs/pretrain/hg38/${WANDB_NAME}\"\n\nmkdir -p \"${HYDRA_RUN_DIR}\"\nsrun python -m train \\\n  experiment=hg38/hg38 \\\n  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \\\n  dataset.max_length=${SEQLEN} \\\n  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \\\n  dataset.mlm=false \\\n  dataset.mlm_probability=0.0 \\\n  dataset.rc_aug=\"${RC_AUG}\" \\\n  model=hyena \\\n  model.d_model=${D_MODEL} \\\n  model.n_layer=${N_LAYER} \\\n  optimizer.lr=\"${LR}\" \\\n  train.global_batch_size=${BATCH_SIZE} \\\n  trainer.max_steps=${MAX_STEPS} \\\n  trainer.devices=${NUM_DEVICES} \\\n  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \\\n  wandb.group=pretrain_hg38 \\\n  wandb.name=\"${WANDB_NAME}\" \\\n  hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n"
  },
  {
    "path": "slurm_scripts/run_pretrain_mamba.sh",
    "content": "#!/bin/bash\n#SBATCH --get-user-env                      # Retrieve the users login environment\n#SBATCH -t 96:00:00                         # Time limit (hh:mm:ss)\n#SBATCH --mem=100G                           # RAM\n#SBATCH --gres=gpu:8                        # Number of GPUs\n#SBATCH --ntasks-per-node=8                 # Should correspond to num devices (at least 1-1 task to GPU)\n#SBATCH --cpus-per-task=4                   # Number of CPU cores per task\n#SBATCH -N 1                                # Number of nodes\n#SBATCH --requeue                           # Requeue job if it fails\n#SBATCH --job-name=mamba_ntp                # Job name\n#SBATCH --output=../watch_folder/%x_%j.log  # Log file\n\n# Setup environment\ncd ../ || exit  # Go to the root directory of the repo\nsource setup_env.sh\n\nNUM_DEVICES=8\n\n# Run script\nSEQLEN=1024\nMAX_STEPS=10000\nD_MODEL=256\nN_LAYER=8\nLR=\"8e-3\"\nRC_AUG=\"true\"\n\nBATCH_SIZE=$(( 1048576 / SEQLEN ))\nSEQLEN_DIS=\"$(echo \"scale=0; ${SEQLEN} / 1000\" | bc)k\"\nWANDB_NAME=\"mamba_ntp_rc_aug_seqlen-${SEQLEN_DIS}_d_model-${D_MODEL}_n_layer-${N_LAYER}_lr-${LR}\"\nHYDRA_RUN_DIR=\"./outputs/pretrain/hg38/${WANDB_NAME}\"\n\nmkdir -p \"${HYDRA_RUN_DIR}\"\nsrun python -m train \\\n  experiment=hg38/hg38 \\\n  callbacks.model_checkpoint_every_n_steps.every_n_train_steps=500 \\\n  dataset.max_length=${SEQLEN} \\\n  dataset.batch_size=$(( BATCH_SIZE / NUM_DEVICES )) \\\n  dataset.mlm=false \\\n  dataset.mlm_probability=0.0 \\\n  dataset.rc_aug=\"${RC_AUG}\" \\\n  model=mamba \\\n  model.config.d_model=${D_MODEL} \\\n  model.config.n_layer=${N_LAYER} \\\n  optimizer.lr=\"${LR}\" \\\n  train.global_batch_size=${BATCH_SIZE} \\\n  trainer.max_steps=${MAX_STEPS} \\\n  trainer.devices=${NUM_DEVICES} \\\n  +trainer.val_check_interval=$(( MAX_STEPS / 5 )) \\\n  wandb.group=pretrain_hg38 \\\n  wandb.name=\"${WANDB_NAME}\" \\\n  hydra.run.dir=\"${HYDRA_RUN_DIR}\"\n"
  },
  {
    "path": "slurm_scripts/wrapper_run_genomics.sh",
    "content": "#!/bin/bash\n\n# Choose one from below\n\n## Hyena\n## TODO: Download HF model from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen to ../outputs/hyena_hf/hyenadna-tiny-1k-seqlen\n#LOG_DIR=\"../watch_folder/gb_cv5/hyena\"\n#CONFIG_PATH=$(realpath \"../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/hyena_hf/hyenadna-tiny-1k-seqlen/weights.ckpt\")\n#DISPLAY_NAME=\"hyena\"\n#MODEL=\"hyena\"\n#MODEL_NAME=\"dna_embedding\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"false\" \"true\" )\n#LRS=( \"6e-4\" )\n\n## Mamba NTP\n#LOG_DIR=\"../watch_folder/gb_cv5/mamba\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/mamba_ntp_rc_aug_seqlen-1k_d_model-128_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"mamba_uni\"\n#MODEL=\"mamba\"\n#MODEL_NAME=\"dna_embedding_mamba\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"true\" )\n#LRS=( \"1e-3\" \"2e-3\" )\n\n## Caduceus NO POST HOC\n#LOG_DIR=\"../watch_folder/gb_cv5/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_NO_PH\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"true\" )\n#LRS=( \"2e-3\")\n\n## Caduceus Post-Hoc\n#LOG_DIR=\"../watch_folder/gb_cv5/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_ph\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"true\"\n#RC_AUGS=( \"false\" )\n#LRS=( \"1e-3\" \"2e-3\" )\n\n## Caduceus Parameter Sharing\n#LOG_DIR=\"../watch_folder/gb_cv5/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_ps\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"true\"  # Use this in decoder to always combine forward and reverse complement channels\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"false\" )\n#LRS=( \"1e-3\" \"2e-3\" )\n\nmkdir -p \"${LOG_DIR}\"\nexport_str=\"ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}\"\nfor TASK in \"dummy_mouse_enhancers_ensembl\" \"demo_coding_vs_intergenomic_seqs\" \"demo_human_or_worm\" \"human_enhancers_cohn\" \"human_enhancers_ensembl\" \"human_ensembl_regulatory\" \"human_nontata_promoters\" \"human_ocr_ensembl\"; do\n  for LR in \"${LRS[@]}\"; do\n    for BATCH_SIZE in 128 256; do\n      for RC_AUG in \"${RC_AUGS[@]}\"; do\n        export_str=\"${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}\"\n        
job_name=\"gb_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}\"\n        sbatch \\\n          --job-name=\"${job_name}\" \\\n          --output=\"${LOG_DIR}/%x_%j.log\" \\\n          --export=\"${export_str}\" \\\n          \"run_genomics_benchmark.sh\"\n      done\n    done\n  done\ndone"
  },
  {
    "path": "slurm_scripts/wrapper_run_genomics_cnn.sh",
    "content": "#!/bin/bash\n\nLOG_DIR=\"../watch_folder/gb_cv5/cnn_baseline\"\nmkdir -p \"${LOG_DIR}\"\nexport_str=\"ALL\"\nfor TASK in \"dummy_mouse_enhancers_ensembl\" \"demo_coding_vs_intergenomic_seqs\" \"demo_human_or_worm\" \"human_enhancers_cohn\" \"human_enhancers_ensembl\" \"human_ensembl_regulatory\" \"human_nontata_promoters\" \"human_ocr_ensembl\"; do\n  for RC_AUG in \"false\"; do\n    export_str=\"${export_str},TASK=${TASK},RC_AUG=${RC_AUG}\"\n    job_name=\"gb_${TASK}_CNN_RC_AUG-${RC_AUG}\"\n    sbatch \\\n      --job-name=\"${job_name}\" \\\n      --output=\"${LOG_DIR}/%x_%j.log\" \\\n      --export=\"${export_str}\" \\\n      \"run_genomics_benchmark_cnn.sh\"\n  done\ndone\n"
  },
  {
    "path": "slurm_scripts/wrapper_run_nucleotide_transformer.sh",
    "content": "#!/bin/bash\n\n# Choose one from below\n\n## Caduceus NO POST HOC\n#LOG_DIR=\"../watch_folder/nt_cv10_ep20/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_NO_PH\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"true\" )\n#LRS=( \"1e-3\" \"2e-3\")\n\n## Caduceus Post-Hoc\n#LOG_DIR=\"../watch_folder/nt_cv10_ep20/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ph_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_ph\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"false\"\n#CONJOIN_TEST=\"true\"\n#RC_AUGS=( \"false\" )\n#LRS=( \"1e-3\" \"2e-3\" )\n\n## Caduceus Parameter Sharing\n#LOG_DIR=\"../watch_folder/nt_cv10_ep20/caduceus\"\n#CONFIG_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/model_config.json\")\n#PRETRAINED_PATH=$(realpath \"../outputs/pretrain/hg38/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3/checkpoints/last.ckpt\")\n#DISPLAY_NAME=\"caduceus_ps\"\n#MODEL=\"caduceus\"\n#MODEL_NAME=\"dna_embedding_caduceus\"\n#CONJOIN_TRAIN_DECODER=\"true\"  # Use this in decoder to always combine forward and reverse complement channels\n#CONJOIN_TEST=\"false\"\n#RC_AUGS=( \"false\" )\n#LRS=( \"1e-3\" \"2e-3\" )\n\nmkdir -p \"${LOG_DIR}\"\nexport_str=\"ALL,CONFIG_PATH=${CONFIG_PATH},PRETRAINED_PATH=${PRETRAINED_PATH},DISPLAY_NAME=${DISPLAY_NAME},MODEL=${MODEL},MODEL_NAME=${MODEL_NAME},CONJOIN_TRAIN_DECODER=${CONJOIN_TRAIN_DECODER},CONJOIN_TEST=${CONJOIN_TEST}\"\nfor TASK in \"enhancers\" \"enhancers_types\" \"H3\" \"H3K4me1\" \"H3K4me2\" \"H3K4me3\" \"H3K9ac\" \"H3K14ac\" \"H3K36me3\" \"H3K79me3\" \"H4\" \"H4ac\" \"promoter_all\" \"promoter_no_tata\" \"promoter_tata\" \"splice_sites_all\" \"splice_sites_acceptors\" \"splice_sites_donors\"; do\n  for LR in \"${LRS[@]}\"; do\n    for BATCH_SIZE in 128 512; do\n      for RC_AUG in \"${RC_AUGS[@]}\"; do\n        export_str=\"${export_str},TASK=${TASK},LR=${LR},BATCH_SIZE=${BATCH_SIZE},RC_AUG=${RC_AUG}\"\n        job_name=\"nt_${TASK}_${DISPLAY_NAME}_LR-${LR}_BATCH_SIZE-${BATCH_SIZE}_RC_AUG-${RC_AUG}\"\n        sbatch \\\n          --job-name=\"${job_name}\" \\\n          --output=\"${LOG_DIR}/%x_%j.log\" \\\n          --export=\"${export_str}\" \\\n          \"run_nucleotide_transformer.sh\"\n      done\n    done\n  done\ndone\n"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/callbacks/params.py",
    "content": "\"\"\"Callback to log the number of parameters of the model.\n\n\"\"\"\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\n\n\nclass ParamsLog(pl.Callback):\n    \"\"\" Log the number of parameters of the model \"\"\"\n    def __init__(\n        self,\n        total: bool = True,\n        trainable: bool = True,\n        fixed: bool = True,\n    ):\n        super().__init__()\n        self._log_stats = AttributeDict(\n            {\n                'total_params_log': total,\n                'trainable_params_log': trainable,\n                'non_trainable_params_log': fixed,\n            }\n        )\n\n    @rank_zero_only\n    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:\n        logs = {}\n        if self._log_stats.total_params_log:\n            logs[\"params/total\"] = sum(p.numel() for p in pl_module.parameters())\n        if self._log_stats.trainable_params_log:\n            logs[\"params/trainable\"] = sum(p.numel() for p in pl_module.parameters()\n                                             if p.requires_grad)\n        if self._log_stats.non_trainable_params_log:\n            logs[\"params/fixed\"] = sum(p.numel() for p in pl_module.parameters()\n                                                     if not p.requires_grad)\n        if trainer.logger:\n            trainer.logger.log_hyperparams(logs)\n"
  },
  {
    "path": "src/callbacks/timer.py",
    "content": "\"\"\"Callback to monitor the speed of each step and each epoch.\n\nhttps://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py\nAdapted from:\n    https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor\n\"\"\"\n\n# We only need the speed monitoring, not the GPU monitoring\nimport time\nfrom typing import Any\n\nfrom pytorch_lightning import Callback, Trainer, LightningModule\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\nfrom pytorch_lightning.utilities.types import STEP_OUTPUT\n\n\nclass Timer(Callback):\n    \"\"\"Monitor the speed of each step and each epoch.\n    \"\"\"\n    def __init__(\n        self,\n        step: bool = True,\n        inter_step: bool = True,\n        epoch: bool = True,\n        val: bool = True,\n    ):\n        super().__init__()\n        self._log_stats = AttributeDict( {\n            'step_time': step,\n            'inter_step_time': inter_step,\n            'epoch_time': epoch,\n            'val_time': val,\n        })\n\n    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        self._snap_epoch_time = None\n\n    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        self._snap_step_time = None\n        self._snap_inter_step_time = None\n        self._snap_epoch_time = time.time()\n\n    def on_train_batch_start(\n        self,\n        trainer: Trainer,\n        pl_module: LightningModule,\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        if self._log_stats.step_time:\n            self._snap_step_time = time.time()\n\n        if not self._should_log(trainer):\n            return\n\n        logs = {}\n        if self._log_stats.inter_step_time and self._snap_inter_step_time:\n            # First log at beginning of second step\n            logs[\"timer/inter_step\"] = (time.time() - self._snap_inter_step_time) # * 1000\n\n        if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n    @rank_zero_only\n    def on_train_batch_end(\n        self,\n        trainer: Trainer,\n        pl_module: LightningModule,\n        outputs: STEP_OUTPUT,\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        if self._log_stats.inter_step_time:\n            self._snap_inter_step_time = time.time()\n\n        if not self._should_log(trainer):\n            return\n\n        logs = {}\n        if self._log_stats.step_time and self._snap_step_time:\n            logs[\"timer/step\"] = (time.time() - self._snap_step_time) # * 1000\n\n        if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n    @rank_zero_only\n    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:\n        logs = {}\n        if self._log_stats.epoch_time and self._snap_epoch_time:\n            logs[\"timer/epoch\"] = time.time() - self._snap_epoch_time\n        if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        self._snap_val_time = time.time()\n\n    @rank_zero_only\n    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None:\n        logs = {}\n        if self._log_stats.val_time and self._snap_val_time:\n            logs[\"timer/validation\"] = time.time() - 
self._snap_val_time\n        if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step)\n\n    @staticmethod\n    def _should_log(trainer) -> bool:\n        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop\n"
  },
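A usage sketch for the callback above (in this repo callbacks are normally wired in through the Hydra `callbacks` configs; the direct `Trainer` construction here is just for illustration):

```python
import pytorch_lightning as pl
from src.callbacks.timer import Timer

trainer = pl.Trainer(
    max_epochs=1,
    log_every_n_steps=10,   # Timer only logs on steps where _should_log() returns True
    callbacks=[Timer(step=True, inter_step=True, epoch=True, val=True)],
)
# trainer.fit(model, datamodule) would then report timer/step, timer/inter_step,
# timer/epoch and timer/validation to the attached logger.
```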
  {
    "path": "src/callbacks/validation.py",
    "content": "\"\"\"Check validation every n **global** steps.\n\nPytorch Lightning has a `val_check_interval` parameter that checks validation every n batches, but does not support\nchecking every n **global** steps.\n\"\"\"\n\nfrom typing import Any\n\nfrom pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.trainer.states import RunningStage\n\n\nclass ValEveryNGlobalSteps(Callback):\n    \"\"\"Check validation every n **global** steps.\"\"\"\n    def __init__(self, every_n):\n        self.every_n = every_n\n        self.last_run = None\n\n    def on_train_batch_end(self, trainer, *_: Any):\n        \"\"\"Check if we should run validation.\n\n        Adapted from: https://github.com/Lightning-AI/pytorch-lightning/issues/2534#issuecomment-1085986529\n        \"\"\"\n        # Prevent Running validation many times in gradient accumulation\n        if trainer.global_step == self.last_run:\n            return\n        else:\n            self.last_run = None\n        if trainer.global_step % self.every_n == 0 and trainer.global_step != 0:\n            trainer.training = False\n            stage = trainer.state.stage\n            trainer.state.stage = RunningStage.VALIDATING\n            trainer._run_evaluate()\n            trainer.state.stage = stage\n            trainer.training = True\n            trainer._logger_connector._epoch_end_reached = False\n            self.last_run = trainer.global_step\n"
  },
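Usage sketch (illustrative): attaching the callback makes the trainer run its validation loop every `every_n` global (optimizer) steps, independent of `val_check_interval`, which Lightning counts in batches:

```python
import pytorch_lightning as pl
from src.callbacks.validation import ValEveryNGlobalSteps

trainer = pl.Trainer(
    max_steps=10_000,
    accumulate_grad_batches=4,                       # global step advances once per 4 batches
    limit_val_batches=0.125,
    callbacks=[ValEveryNGlobalSteps(every_n=500)],   # validate every 500 global steps
)
```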
  {
    "path": "src/dataloaders/__init__.py",
    "content": "from . import genomics\nfrom .base import SequenceDataset\n"
  },
  {
    "path": "src/dataloaders/base.py",
    "content": "\"\"\" Datasets for core experimental results.\n\n\"\"\"\n\nimport os\nfrom functools import partial\nfrom pathlib import Path\n\nimport torch\n\n\n# Default data path is environment variable or <repo_root_dir>/data\nif (default_data_path := os.getenv(\"DATA_PATH\")) is None:\n    default_data_path = Path(__file__).parent.parent.parent.absolute()\n    default_data_path = default_data_path / \"data\"\nelse:\n    default_data_path = Path(default_data_path).absolute()\n\n\nclass DefaultCollateMixin:\n    \"\"\"Controls collating in the DataLoader\n\n    The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader\n    arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args\n    list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments,\n    constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor.\n    \"\"\"\n\n    @classmethod\n    def _collate_callback(cls, x, *args, **kwargs):\n        \"\"\"\n        Modify the behavior of the default _collate method.\n        \"\"\"\n        return x\n\n    _collate_arg_names = []\n\n    @classmethod\n    def _return_callback(cls, return_value, *args, **kwargs):\n        \"\"\"\n        Modify the return value of the collate_fn.\n        Assign a name to each element of the returned tuple beyond the (x, y) pairs\n        See InformerSequenceDataset for an example of this being used\n        \"\"\"\n        x, y, *z = return_value\n        assert len(z) == len(cls._collate_arg_names), \"Specify a name for each auxiliary data item returned by dataset\"\n        return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)}\n\n    @classmethod\n    def _collate(cls, batch, *args, **kwargs):\n        # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py\n        elem = batch[0]\n        if isinstance(elem, torch.Tensor):\n            out = None\n            if torch.utils.data.get_worker_info() is not None:\n                # If we're in a background process, concatenate directly into a\n                # shared memory tensor to avoid an extra copy\n                numel = sum(x.numel() for x in batch)\n                storage = elem.storage()._new_shared(numel)\n                out = elem.new(storage)\n            x = torch.stack(batch, dim=0, out=out)\n\n            # Insert custom functionality into the collate_fn\n            x = cls._collate_callback(x, *args, **kwargs)\n\n            return x\n        else:\n            return torch.tensor(batch)\n\n    @classmethod\n    def _collate_fn(cls, batch, *args, **kwargs):\n        \"\"\"\n        Default collate function.\n        Generally accessed by the dataloader() methods to pass into torch DataLoader\n\n        Arguments:\n            batch: list of (x, y) pairs\n            args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback\n        \"\"\"\n        x, y, *z = zip(*batch)\n\n        x = cls._collate(x, *args, **kwargs)\n        y = cls._collate(y)\n        z = [cls._collate(z_) for z_ in z]\n\n        return_value = (x, y, *z)\n        return cls._return_callback(return_value, *args, **kwargs)\n\n    # List of loader arguments to pass into collate_fn\n    collate_args = []\n\n    def _dataloader(self, dataset, **loader_args):\n        collate_args = {k: loader_args[k] for k in loader_args if k in 
self.collate_args}\n        loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}\n        loader_cls = loader_registry[loader_args.pop(\"_name_\", None)]\n        return loader_cls(\n            dataset=dataset,\n            collate_fn=partial(self._collate_fn, **collate_args),\n            **loader_args,\n        )\n\n\n# class SequenceDataset(LightningDataModule):\n# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just\n# provide our own class with the same core methods as LightningDataModule (e.g. setup)\nclass SequenceDataset(DefaultCollateMixin):\n    registry = {}\n    _name_ = NotImplementedError(\"Dataset must have shorthand name\")\n\n    # Since subclasses do not specify __init__ which is instead handled by this class\n    # Subclasses can provide a list of default arguments which are automatically registered as attributes\n    # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features\n    #  of this class such as the _name_ and d_input/d_output\n    @property\n    def init_defaults(self):\n        return {}\n\n    # https://www.python.org/dev/peps/pep-0487/#subclass-registration\n    def __init_subclass__(cls, **kwargs):\n        super().__init_subclass__(**kwargs)\n        cls.registry[cls._name_] = cls\n\n    def __init__(self, _name_, data_dir=None, **dataset_cfg):\n        assert _name_ == self._name_\n        self.data_dir = Path(data_dir).absolute() if data_dir is not None else None\n\n        # Add all arguments to self\n        init_args = self.init_defaults.copy()\n        init_args.update(dataset_cfg)\n        for k, v in init_args.items():\n            setattr(self, k, v)\n\n        # The train, val, test datasets must be set by `setup()`\n        self.dataset_train = self.dataset_val = self.dataset_test = None\n\n        self.init()\n\n    def init(self):\n        \"\"\"Hook called at end of __init__, override this instead of __init__\"\"\"\n        pass\n\n    def setup(self):\n        \"\"\"This method should set self.dataset_train, self.dataset_val, and self.dataset_test.\"\"\"\n        raise NotImplementedError\n\n    def split_train_val(self, val_split):\n        \"\"\"\n        Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.\n        \"\"\"\n        train_len = int(len(self.dataset_train) * (1.0 - val_split))\n        self.dataset_train, self.dataset_val = torch.utils.data.random_split(\n            self.dataset_train,\n            (train_len, len(self.dataset_train) - train_len),\n            generator=torch.Generator().manual_seed(\n                getattr(self, \"seed\", 42)\n            ),  # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us\n        )\n\n    def train_dataloader(self, **kwargs):\n        \"\"\"Return a DataLoader for the training dataset.\"\"\"\n        return self._train_dataloader(self.dataset_train, **kwargs)\n\n    def _train_dataloader(self, dataset, **kwargs):\n        if dataset is None:\n            return\n        kwargs['shuffle'] = 'sampler' not in kwargs  # shuffle cant be True if we have custom sampler\n        return self._dataloader(dataset, **kwargs)\n\n    def val_dataloader(self, **kwargs):\n        \"\"\"Return a DataLoader for the validation dataset.\"\"\"\n        return self._eval_dataloader(self.dataset_val, **kwargs)\n\n    def test_dataloader(self, **kwargs):\n        \"\"\"Return a 
DataLoader for the test dataset.\"\"\"\n        return self._eval_dataloader(self.dataset_test, **kwargs)\n\n    def _eval_dataloader(self, dataset, **kwargs):\n        if dataset is None:\n            return\n        # Note that shuffle=False by default\n        return self._dataloader(dataset, **kwargs)\n\n    def __str__(self):\n        return self._name_\n\n\n# Registry for dataloader class\nloader_registry = {\n    None: torch.utils.data.DataLoader,  # default case\n}\n"
  },
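Because of the `__init_subclass__` registration above, a dataset only needs a `_name_`, a `setup()`, and optionally `init_defaults`. A minimal, hypothetical subclass (not one of the repo's datasets) showing the pattern:

```python
import torch
from src.dataloaders.base import SequenceDataset

class ToySequenceDataset(SequenceDataset):
    _name_ = "toy"   # registered into SequenceDataset.registry["toy"] by __init_subclass__

    @property
    def init_defaults(self):
        return {"n": 100, "length": 16, "val_split": 0.1}

    def setup(self):
        x = torch.randint(0, 4, (self.n, self.length))   # fake token ids
        y = torch.randint(0, 2, (self.n, 1))             # fake labels
        self.dataset_train = torch.utils.data.TensorDataset(x, y)
        self.dataset_test = self.dataset_train
        self.split_train_val(self.val_split)

ds = SequenceDataset.registry["toy"]("toy", n=200)       # kwargs override init_defaults
ds.setup()
loader = ds.train_dataloader(batch_size=32, num_workers=0)
x, y, _ = next(iter(loader))                             # _collate_fn returns (x, y, {})
```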
  {
    "path": "src/dataloaders/datasets/genomic_bench_dataset.py",
    "content": "\"\"\"Genomic Benchmarks Dataset.\n\nFrom: https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks\n\"\"\"\n\nfrom pathlib import Path\n\nimport torch\nfrom genomic_benchmarks.data_check import is_downloaded\nfrom genomic_benchmarks.loc2seq import download_dataset\n\nfrom src.dataloaders.utils.rc import coin_flip, string_reverse_complement\n\n\nclass GenomicBenchmarkDataset(torch.utils.data.Dataset):\n    \"\"\"\n    Loop through bed file, retrieve (chr, start, end), query fasta file for sequence.\n    Returns a generator that retrieves the sequence.\n    \"\"\"\n\n    def __init__(\n            self,\n            split,\n            max_length,\n            dataset_name=\"human_nontata_promoters\",\n            d_output=2,  # default binary classification\n            dest_path=None,\n            tokenizer=None,\n            tokenizer_name=None,\n            use_padding=None,\n            add_eos=False,\n            rc_aug=False,\n            conjoin_train=False,\n            conjoin_test=False,\n            return_augs=False,\n            return_mask=False,\n    ):\n\n        self.max_length = max_length\n        self.use_padding = use_padding\n        self.tokenizer_name = tokenizer_name\n        self.tokenizer = tokenizer\n        self.return_augs = return_augs\n        self.add_eos = add_eos\n        self.d_output = d_output  # needed for decoder to grab\n        assert not (conjoin_train and conjoin_test), \"conjoin_train and conjoin_test cannot both be True\"\n        if (conjoin_train or conjoin_test) and rc_aug:\n            print(\"When using conjoin, we turn off rc_aug.\")\n            rc_aug = False\n        self.rc_aug = rc_aug\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n        self.return_mask = return_mask\n\n        if not is_downloaded(dataset_name, cache_path=dest_path):\n            print(\"downloading {} to {}\".format(dataset_name, dest_path))\n            download_dataset(dataset_name, version=0, dest_path=dest_path)\n        else:\n            print(\"already downloaded {}-{}\".format(split, dataset_name))\n\n        self.split = split\n\n        # use Path object\n        base_path = Path(dest_path) / dataset_name / split\n\n        self.all_seqs = []\n        self.all_labels = []\n        label_mapper = {}\n\n        for i, x in enumerate(base_path.iterdir()):\n            label_mapper[x.stem] = i\n\n        for label_type in label_mapper.keys():\n            for path in (base_path / label_type).iterdir():\n                with open(path, \"r\") as f:\n                    content = f.read()\n                self.all_seqs.append(content)\n                self.all_labels.append(label_mapper[label_type])\n\n    def __len__(self):\n        return len(self.all_labels)\n\n    def __getitem__(self, idx):\n        x = self.all_seqs[idx]\n        y = self.all_labels[idx]\n\n        if (self.rc_aug or (self.conjoin_test and self.split == \"train\")) and coin_flip():\n            x = string_reverse_complement(x)\n\n        seq = self.tokenizer(\n            x,\n            add_special_tokens=False,\n            padding=\"max_length\" if self.use_padding else None,\n            max_length=self.max_length,\n            truncation=True,\n        )\n        seq_ids = seq[\"input_ids\"]  # get input_ids\n\n        # need to handle eos here\n        if self.add_eos:\n            # append list seems to be faster than append tensor\n            seq_ids.append(self.tokenizer.sep_token_id)\n\n        if self.conjoin_train or 
(self.conjoin_test and self.split != \"train\"):\n            x_rc = string_reverse_complement(x)\n            seq_rc = self.tokenizer(\n                x_rc,\n                add_special_tokens=False,\n                padding=\"max_length\" if self.use_padding else None,\n                max_length=self.max_length,\n                truncation=True,\n            )\n            seq_rc_ids = seq_rc[\"input_ids\"]  # get input_ids\n            # need to handle eos here\n            if self.add_eos:\n                # append list seems to be faster than append tensor\n                seq_rc_ids.append(self.tokenizer.sep_token_id)\n            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)\n\n        else:\n            # convert to tensor\n            seq_ids = torch.LongTensor(seq_ids)\n\n        # need to wrap in list\n        target = torch.LongTensor([y])\n\n        # `seq` has shape:\n        #     - (seq_len,) if not conjoining\n        #     - (seq_len, 2) for conjoining\n        if self.return_mask:\n            return seq_ids, target, {\"mask\": torch.BoolTensor(seq[\"attention_mask\"])}\n        else:\n            return seq_ids, target\n"
  },
  {
    "path": "src/dataloaders/datasets/hg38_char_tokenizer.py",
    "content": "\"\"\" \nFrom: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py\n\nCharacterTokenizer for Hugging Face Transformers.\nThis is heavily inspired from CanineTokenizer in transformers package.\n\"\"\"\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Sequence, Union\n\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\n\n\nclass CharacterTokenizer(PreTrainedTokenizer):\n    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs):\n        \"\"\"Character tokenizer for Hugging Face transformers.\n        Args:\n            characters (Sequence[str]): List of desired characters. Any character which\n                is not included in this list will be replaced by a special token called\n                [UNK] with id=6. Following is the list of all the special tokens with\n                their corresponding ids:\n                    \"[CLS]\": 0\n                    \"[SEP]\": 1\n                    \"[BOS]\": 2\n                    \"[MASK]\": 3\n                    \"[PAD]\": 4\n                    \"[RESERVED]\": 5\n                    \"[UNK]\": 6\n                an id (starting at 7) will be assigned to each character.\n            model_max_length (int): Model maximum sequence length.\n        \"\"\"\n        self.characters = characters\n        self.model_max_length = model_max_length\n        bos_token = AddedToken(\"[BOS]\", lstrip=False, rstrip=False)\n        eos_token = AddedToken(\"[EOS]\", lstrip=False, rstrip=False)\n        sep_token = AddedToken(\"[SEP]\", lstrip=False, rstrip=False)\n        cls_token = AddedToken(\"[CLS]\", lstrip=False, rstrip=False)\n        pad_token = AddedToken(\"[PAD]\", lstrip=False, rstrip=False)\n        unk_token = AddedToken(\"[UNK]\", lstrip=False, rstrip=False)\n\n        mask_token = AddedToken(\"[MASK]\", lstrip=True, rstrip=False)\n\n        self._vocab_str_to_int = {\n            \"[CLS]\": 0,\n            \"[SEP]\": 1,\n            \"[BOS]\": 2,\n            \"[MASK]\": 3,\n            \"[PAD]\": 4,\n            \"[RESERVED]\": 5,\n            \"[UNK]\": 6,\n            **{ch: i + 7 for i, ch in enumerate(characters)},\n        }\n        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}\n\n        # TODO: This should be a parameter passed to __init__\n        complement_map = {\"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\"}\n        self.complement_map = {}\n        for k, v in self._vocab_str_to_int.items():\n            complement_id = self._vocab_str_to_int[complement_map[k]] if k in complement_map.keys() else v\n            self.complement_map[self._vocab_str_to_int[k]] = complement_id\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=pad_token,\n            sep_token=sep_token,\n            cls_token=cls_token,\n            pad_token=pad_token,\n            mask_token=mask_token,\n            unk_token=unk_token,\n            add_prefix_space=False,\n            model_max_length=model_max_length,\n            padding_side=padding_side,\n            **kwargs,\n        )\n\n    @property\n    def vocab_size(self) -> int:\n        return len(self._vocab_str_to_int)\n\n    def _tokenize(self, text: str) -> List[str]:\n        return list(text)\n\n    def _convert_token_to_id(self, token: str) -> int:\n        return self._vocab_str_to_int.get(token, self._vocab_str_to_int[\"[UNK]\"])\n\n    def 
_convert_id_to_token(self, index: int) -> str:\n        return self._vocab_int_to_str[index]\n\n    def convert_tokens_to_string(self, tokens):\n        return \"\".join(tokens)\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        sep = [self.sep_token_id]\n        cls = [self.cls_token_id]\n        result = cls + token_ids_0 + sep\n        if token_ids_1 is not None:\n            result += token_ids_1 + sep\n        return result\n\n    def get_special_tokens_mask(\n            self,\n            token_ids_0: List[int],\n            token_ids_1: Optional[List[int]] = None,\n            already_has_special_tokens: bool = False,\n    ) -> List[int]:\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0,\n                token_ids_1=token_ids_1,\n                already_has_special_tokens=True,\n            )\n\n        result = [1] + ([0] * len(token_ids_0)) + [1]\n        if token_ids_1 is not None:\n            result += ([0] * len(token_ids_1)) + [1]\n        return result\n\n    def get_vocab(self) -> Dict[str, int]:\n        return self._vocab_str_to_int\n\n    def create_token_type_ids_from_sequences(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        sep = [self.sep_token_id]\n        cls = [self.cls_token_id]\n\n        result = len(cls + token_ids_0 + sep) * [0]\n        if token_ids_1 is not None:\n            result += len(token_ids_1 + sep) * [1]\n        return result\n\n    def get_config(self) -> Dict:\n        return {\n            \"char_ords\": [ord(ch) for ch in self.characters],\n            \"model_max_length\": self.model_max_length,\n        }\n\n    @classmethod\n    def from_config(cls, config: Dict) -> \"CharacterTokenizer\":\n        cfg = {}\n        cfg[\"characters\"] = [chr(i) for i in config[\"char_ords\"]]\n        cfg[\"model_max_length\"] = config[\"model_max_length\"]\n        return cls(**cfg)\n\n    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        cfg = self.get_config()\n        with open(cfg_file, \"w\") as f:\n            json.dump(cfg, f, indent=4)\n\n    @classmethod\n    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):\n        cfg_file = Path(save_directory) / \"tokenizer_config.json\"\n        with open(cfg_file) as f:\n            cfg = json.load(f)\n        return cls.from_config(cfg)\n"
  },
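  {
    "path": "src/dataloaders/datasets/hg38_char_tokenizer_example.py",
    "content": "\"\"\"Hypothetical usage sketch (illustrative addition, not an original repo file).\n\nShows how `CharacterTokenizer` maps DNA characters to ids with the transformers\nversion pinned by this repo: special tokens occupy ids 0-6 and the supplied characters\nget ids starting at 7, so A/C/G/T/N map to 7/8/9/10/11. Also peeks at `complement_map`,\nwhich pairs each base id with the id of its complement.\n\"\"\"\nfrom src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer\n\ntokenizer = CharacterTokenizer(characters=[\"A\", \"C\", \"G\", \"T\", \"N\"], model_max_length=512)\n\n# Character-level encoding without special tokens: one id per base.\nassert tokenizer(\"ACGTN\", add_special_tokens=False)[\"input_ids\"] == [7, 8, 9, 10, 11]\n\n# With special tokens, [CLS] (id 0) and [SEP] (id 1) bracket the sequence.\nassert tokenizer(\"ACGTN\")[\"input_ids\"] == [0, 7, 8, 9, 10, 11, 1]\n\n# `complement_map` pairs A<->T and C<->G at the id level; N maps to itself.\nassert tokenizer.complement_map[7] == 10 and tokenizer.complement_map[8] == 9\n"
  },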
  {
    "path": "src/dataloaders/datasets/hg38_dataset.py",
    "content": "\"\"\"Dataset for sampling arbitrary intervals from the human genome.\n\n\"\"\"\n\nimport math\nfrom pathlib import Path\n\nimport pandas as pd\nimport torch\nfrom pyfaidx import Fasta\n\nfrom src.dataloaders.utils.mlm import mlm_getitem\nfrom src.dataloaders.utils.rc import coin_flip, string_reverse_complement\n\nMAX_ALLOWED_LENGTH = 2 ** 20\n\n\nclass FastaInterval:\n    \"\"\"Retrieves sequences from a fasta file given a chromosome and start/end indices.\"\"\"\n    def __init__(\n            self,\n            *,\n            fasta_file,\n            return_seq_indices=False,\n            rc_aug=False,\n    ):\n        fasta_file = Path(fasta_file)\n        assert fasta_file.exists(), \"Path to fasta file must exist!\"\n\n        self.seqs = Fasta(str(fasta_file))\n        self.return_seq_indices = return_seq_indices\n        self.rc_aug = rc_aug\n\n        # calc len of each chromosome in fasta file, store in dict\n        self.chr_lens = {}\n\n        for chr_name in self.seqs.keys():\n            self.chr_lens[chr_name] = len(self.seqs[chr_name])\n\n    @staticmethod\n    def _compute_interval(start, end, max_length, i_shift):\n        if max_length == MAX_ALLOWED_LENGTH:\n            return start, end\n        if max_length < MAX_ALLOWED_LENGTH:\n            assert MAX_ALLOWED_LENGTH % max_length == 0\n            return start + i_shift * max_length, start + (i_shift + 1) * max_length\n        else:\n            raise ValueError(f\"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!\")\n\n    def __call__(\n            self,\n            chr_name,\n            start,\n            end,\n            max_length,\n            i_shift,\n            return_augs=False,\n    ):\n        \"\"\"\n        max_length passed from dataset, not from init\n        \"\"\"\n        chromosome = self.seqs[chr_name]\n        chromosome_length = self.chr_lens[chr_name]\n\n        start, end = self._compute_interval(start, end, max_length, i_shift)\n\n        if end > chromosome_length:\n            # Shift interval down\n            start = start - (end - chromosome_length)\n            end = chromosome_length\n            assert start == chromosome_length - max_length\n\n        if start < 0:\n            # Shift interval up\n            end = end - start\n            start = 0\n            assert end == max_length\n\n        if end > chromosome_length:\n            # This may occur if start + MAX_ALLOWED_LENGTH extends beyond the end of the chromosome\n            start = chromosome_length - max_length\n            end = chromosome_length\n\n        seq = str(chromosome[start:end])\n\n        if self.rc_aug and coin_flip():\n            seq = string_reverse_complement(seq)\n\n        return seq\n\n\nclass HG38Dataset(torch.utils.data.Dataset):\n    \"\"\"Loop through bed file, retrieve (chr, start, end), query fasta file for sequence.\"\"\"\n\n    def __init__(\n            self,\n            split,\n            bed_file,\n            fasta_file,\n            max_length,\n            mlm=False,\n            mlm_probability=0.15,\n            pad_max_length=None,\n            tokenizer=None,\n            tokenizer_name=None,\n            add_eos=False,\n            return_seq_indices=False,\n            rc_aug=False,\n            return_augs=False,\n    ):\n        self.mlm = mlm\n        self.mlm_probability = mlm_probability\n        if self.mlm and self.mlm_probability <= 0.0:\n            raise ValueError(f\"`mlm_probability` has to be > 0.0, got 
{self.mlm_probability}.\")\n        if self.mlm:\n            # TODO: see if this helps\n            # self.eligible_replacements = torch.tensor(\n            #     tokenizer(\"ACGT\", add_special_tokens=False)[\"input_ids\"], dtype=torch.long\n            # )\n            self.eligible_replacements = None\n        else:\n            self.eligible_replacements = None\n        self.max_length = max_length\n        self.pad_max_length = pad_max_length if pad_max_length is not None else max_length\n        self.tokenizer_name = tokenizer_name\n        self.tokenizer = tokenizer\n        self.return_augs = return_augs\n        self.add_eos = add_eos\n\n        if max_length <= MAX_ALLOWED_LENGTH:\n            assert MAX_ALLOWED_LENGTH % max_length == 0, f\"`max_length` must be a power of 2!\"\n            self.shifts = MAX_ALLOWED_LENGTH // max_length\n        else:\n            raise ValueError(f\"`max_length` {max_length} (> 2^{int(math.log(MAX_ALLOWED_LENGTH, 2))}) is too large!\")\n\n        bed_path = Path(bed_file)\n        assert bed_path.exists(), \"Path to .bed file must exist!\"\n\n        # read bed file\n        df_raw = pd.read_csv(str(bed_path), sep=\"\\t\", names=[\"chr_name\", \"start\", \"end\", \"split\"])\n        # select only split df\n        self.df = df_raw[df_raw[\"split\"] == split]\n        # Update end points so that sequences are all length == MAX_ALLOWED_LENGTH\n        self.df.loc[:, \"end\"] = self.df[\"start\"] + MAX_ALLOWED_LENGTH\n\n        self.fasta = FastaInterval(\n            fasta_file=fasta_file,\n            return_seq_indices=return_seq_indices,\n            rc_aug=rc_aug\n        )\n\n    @staticmethod\n    def replace_value(x, old_value, new_value):\n        \"\"\"Helper for replacing values in a tensor.\"\"\"\n        return torch.where(x == old_value, new_value, x)\n\n    def __len__(self):\n        return len(self.df) * self.shifts\n\n    def __getitem__(self, idx):\n        \"\"\"Returns a sequence of specified len\"\"\"\n        # sample a random row from df\n        row_idx, shift_idx = idx // self.shifts, idx % self.shifts\n        row = self.df.iloc[row_idx]\n        chr_name, start, end = (row.iloc[0], row.iloc[1], row.iloc[2])\n\n        seq = self.fasta(\n            chr_name,\n            start,\n            end,\n            max_length=self.max_length,\n            i_shift=shift_idx,\n            return_augs=self.return_augs,\n        )\n        if end - start != MAX_ALLOWED_LENGTH:\n            print(row, \"\\nLength: \", end - start)\n\n        if self.tokenizer_name == \"char\":\n            seq = self.tokenizer(\n                seq,\n                padding=\"max_length\",\n                max_length=self.pad_max_length,\n                truncation=True,\n                add_special_tokens=False\n            )\n\n            seq = seq[\"input_ids\"]  # get input_ids\n\n            # need to handle eos here\n            if self.add_eos:\n                # append list seems to be faster than append tensor\n                seq.append(self.tokenizer.sep_token_id)\n\n        elif self.tokenizer_name == \"bpe\":\n            seq = self.tokenizer(\n                seq,\n                # add_special_tokens=False,\n                padding=\"max_length\",\n                max_length=self.pad_max_length,\n                truncation=True,\n            )\n            # get input_ids\n            if self.add_eos:\n                seq = seq[\"input_ids\"][1:]  # remove the bos, keep the eos token\n            else:\n                seq = 
seq[\"input_ids\"][1:-1]  # remove both special tokens\n\n        # convert to tensor\n        seq = torch.LongTensor(seq)\n\n        # replace N token with a pad token, so we can ignore it in the loss\n        seq = self.replace_value(seq, self.tokenizer._vocab_str_to_int[\"N\"], self.tokenizer.pad_token_id)\n\n        if self.mlm:\n            data, target = mlm_getitem(\n                seq,\n                mlm_probability=self.mlm_probability,\n                contains_eos=self.add_eos,\n                tokenizer=self.tokenizer,\n                eligible_replacements=self.eligible_replacements,\n            )\n\n        else:\n            data = seq[:-1].clone()\n            target = seq[1:].clone()\n\n        return data, target\n"
  },
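  {
    "path": "src/dataloaders/datasets/hg38_dataset_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nIllustrates how `FastaInterval._compute_interval` tiles a 2^20 bp record from the\n.bed file into non-overlapping windows of `max_length` tokens. `HG38Dataset` relies on\nthe same arithmetic: each dataset index is split into a bed-file row and a shift index.\n\"\"\"\nfrom src.dataloaders.datasets.hg38_dataset import MAX_ALLOWED_LENGTH, FastaInterval\n\nstart, end = 0, MAX_ALLOWED_LENGTH  # one full-length interval from the bed file\nmax_length = 2 ** 17                # model context; must evenly divide 2^20\n\nfor i_shift in range(MAX_ALLOWED_LENGTH // max_length):\n    s, e = FastaInterval._compute_interval(start, end, max_length, i_shift)\n    # Window `i_shift` covers [start + i_shift * max_length, start + (i_shift + 1) * max_length).\n    assert (s, e) == (start + i_shift * max_length, start + (i_shift + 1) * max_length)\n"
  },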
  {
    "path": "src/dataloaders/datasets/nucleotide_transformer_dataset.py",
    "content": "\"\"\"Nucleotide Transformer Benchmarks Dataset.\n\nFrom: https://huggingface.co/datasets/InstaDeepAI/nucleotide_transformer_downstream_tasks\n\"\"\"\n\nimport torch\nfrom datasets import load_dataset\n\nfrom src.dataloaders.utils.rc import coin_flip, string_reverse_complement\n\n\nclass NucleotideTransformerDataset(torch.utils.data.Dataset):\n\n    \"\"\"\n    Loop through fasta file for sequence.\n    Returns a generator that retrieves the sequence.\n    \"\"\"\n\n    def __init__(\n        self,\n        split,\n        max_length,\n        dataset_name=None,\n        d_output=2,  # default binary classification\n        tokenizer=None,\n        tokenizer_name=None,\n        use_padding=None,\n        add_eos=False,\n        rc_aug=False,\n        conjoin_train=False,\n        conjoin_test=False,\n        return_augs=False\n    ):\n\n        self.max_length = max_length\n        self.use_padding = use_padding\n        self.tokenizer_name = tokenizer_name\n        self.tokenizer = tokenizer\n        self.return_augs = return_augs\n        self.add_eos = add_eos\n        self.d_output = d_output  # needed for decoder to grab\n        assert not (conjoin_train and conjoin_test), \"conjoin_train and conjoin_test cannot both be True\"\n        if (conjoin_train or conjoin_test) and rc_aug:\n            print(\"When using conjoin, we turn off rc_aug.\")\n            rc_aug = False\n        self.rc_aug = rc_aug\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n\n        self.split = split\n\n        # For NT tasks, we use data from InstaDeepAI/nucleotide_transformer_downstream_tasks\n        self.seqs = load_dataset(\n            \"InstaDeepAI/nucleotide_transformer_downstream_tasks\",\n            name=dataset_name,\n            split=split\n        )\n\n    def __len__(self):\n        return len(self.seqs)\n\n    def __getitem__(self, idx):\n        x = self.seqs[idx][\"sequence\"]  # only one sequence\n        y = self.seqs[idx][\"label\"]\n\n        if (self.rc_aug or (self.conjoin_test and self.split == \"train\")) and coin_flip():\n            x = string_reverse_complement(x)\n\n        seq = self.tokenizer(\n            x,\n            add_special_tokens=False,\n            padding=\"max_length\" if self.use_padding else None,\n            max_length=self.max_length,\n            truncation=True,\n        )\n        seq_ids = seq[\"input_ids\"]  # get input_ids\n\n        # need to handle eos here\n        if self.add_eos:\n            # append list seems to be faster than append tensor\n            seq_ids.append(self.tokenizer.sep_token_id)\n\n        if self.conjoin_train or (self.conjoin_test and self.split != \"train\"):\n            x_rc = string_reverse_complement(x)\n            seq_rc = self.tokenizer(\n                x_rc,\n                add_special_tokens=False,\n                padding=\"max_length\" if self.use_padding else None,\n                max_length=self.max_length,\n                truncation=True,\n            )\n            seq_rc_ids = seq_rc[\"input_ids\"]  # get input_ids\n            # need to handle eos here\n            if self.add_eos:\n                # append list seems to be faster than append tensor\n                seq_rc_ids.append(self.tokenizer.sep_token_id)\n            seq_ids = torch.stack((torch.LongTensor(seq_ids), torch.LongTensor(seq_rc_ids)), dim=1)\n\n        else:\n            # convert to tensor\n            seq_ids = torch.LongTensor(seq_ids)\n\n        # need to wrap in list\n     
   target = torch.LongTensor([y])\n\n        # `seq` has shape:\n        #     - (seq_len,) if not conjoining\n        #     - (seq_len, 2) for conjoining\n        return seq_ids, target\n"
  },
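  {
    "path": "src/dataloaders/datasets/nucleotide_transformer_dataset_example.py",
    "content": "\"\"\"Hypothetical usage sketch (illustrative addition, not an original repo file).\n\nInstantiates `NucleotideTransformerDataset` directly. The benchmark split is downloaded\nfrom the Hugging Face Hub on first use; \"enhancers\" below is just one example task\nname. With `conjoin_train=True`, each item stacks the forward and reverse-complement\nstrands along a trailing dimension.\n\"\"\"\nfrom src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer\nfrom src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset\n\ntokenizer = CharacterTokenizer(characters=[\"A\", \"C\", \"G\", \"T\", \"N\"], model_max_length=512)\ndataset = NucleotideTransformerDataset(\n    split=\"train\",\n    max_length=512,\n    dataset_name=\"enhancers\",\n    tokenizer=tokenizer,\n    tokenizer_name=\"char\",\n    use_padding=True,\n    conjoin_train=True,  # also tokenize the reverse-complement strand\n)\nseq_ids, target = dataset[0]\n# seq_ids: (max_length, 2) when conjoining, (max_length,) otherwise; target: (1,)\n"
  },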
  {
    "path": "src/dataloaders/fault_tolerant_sampler.py",
    "content": "# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397\nfrom typing import Iterator\nimport math\n\nimport torch\nfrom torch.utils.data import RandomSampler, DistributedSampler\n\n\nclass RandomFaultTolerantSampler(RandomSampler):\n\n    def __init__(self, *args, generator=None, **kwargs):\n        # generator = torch.Generator().manual_seed(seed)\n        # super().__init__(*args, generator=generator, **kwargs)\n        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,\n        # which should be reproducible if pl.seed_everything was called before hand.\n        # This means that changing the seed of the experiment will also change the\n        # sampling order.\n        if generator is None:\n            seed = int(torch.empty((), dtype=torch.int64).random_().item())\n            generator = torch.Generator().manual_seed(seed)\n        super().__init__(*args, generator=generator, **kwargs)\n        self.counter = 0\n        # self.start_counter = 0\n        self.restarting = False\n\n    def state_dict(self):\n        return {\"random_state\": self.state, \"counter\": self.counter}\n\n    def load_state_dict(self, state_dict):\n        self.generator.set_state(state_dict.get(\"random_state\"))\n        self.counter = state_dict[\"counter\"]\n        # self.start_counter = self.counter\n        self.restarting = True\n\n    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n    # epoch, and subsequent epoch will have very few batches.\n    # def __len__(self):\n    #     # We need a separate self.start_counter because PL seems to call len repeatedly.\n    #     # If we use len(self.data_source) - self.counter then PL will think the epoch ends\n    #     # when we're only half way through.\n    #     return len(self.data_source) - self.start_counter\n\n    def __iter__(self) -> Iterator[int]:\n        n = len(self.data_source)\n\n        self.state = self.generator.get_state()\n        indices = torch.randperm(n, generator=self.generator).tolist()\n\n        if not self.restarting:\n            self.counter = 0\n        else:\n            indices = indices[self.counter:]\n            self.restarting = False\n        # self.start_counter = self.counter\n\n        for index in indices:\n            self.counter += 1\n            yield index\n\n        self.counter = 0\n        # self.start_counter = self.counter\n\n\nclass FaultTolerantDistributedSampler(DistributedSampler):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.counter = 0\n        # self.start_counter = 0\n        self.restarting = False\n\n    def state_dict(self):\n        return {\"epoch\": self.epoch, \"counter\": self.counter}\n\n    def load_state_dict(self, state_dict):\n        self.epoch = state_dict[\"epoch\"]\n        self.counter = state_dict[\"counter\"]\n        # self.start_counter = self.counter\n        self.restarting = True\n\n    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n    # epoch, and subsequent epoch will have very few batches.\n    # def __len__(self) -> int:\n        # return self.num_samples - self.start_counter\n\n    def __iter__(self):\n        if self.shuffle:\n            # deterministically shuffle based on epoch and seed\n            g = torch.Generator()\n            g.manual_seed(self.seed + 
self.epoch)\n            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]\n        else:\n            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]\n\n        if not self.drop_last:\n            # add extra samples to make it evenly divisible\n            padding_size = self.total_size - len(indices)\n            if padding_size <= len(indices):\n                indices += indices[:padding_size]\n            else:\n                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]\n        else:\n            # remove tail of data to make it evenly divisible.\n            indices = indices[:self.total_size]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        if not self.restarting:\n            self.counter = 0\n        else:\n            indices = indices[self.counter:]\n            self.restarting = False\n        # self.start_counter = self.counter\n\n        for index in indices:\n            self.counter += 1\n            yield index\n\n        self.counter = 0\n        # self.start_counter = self.counter"
  },
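  {
    "path": "src/dataloaders/fault_tolerant_sampler_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nDemonstrates the resume contract of `RandomFaultTolerantSampler`: `load_state_dict`\nrestores the generator state and counter, so a freshly constructed sampler continues\nthe same permutation where the interrupted one stopped.\n\"\"\"\nfrom src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler\n\ndata_source = list(range(10))\n\nsampler = RandomFaultTolerantSampler(data_source)\nit = iter(sampler)\nconsumed = [next(it) for _ in range(3)]  # pretend training was interrupted after 3 samples\nstate = sampler.state_dict()             # {\"random_state\": ..., \"counter\": 3}\n\nresumed = RandomFaultTolerantSampler(data_source)\nresumed.load_state_dict(state)\nremaining = list(resumed)                # picks up the same permutation at index 3\n\nassert sorted(consumed + remaining) == data_source\n"
  },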
  {
    "path": "src/dataloaders/genomics.py",
    "content": "\"\"\"Dataloaders for genomics datasets, including pretraining and downstream tasks.\n\n    - Adapted from:\n        https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py\n    - Adapted from:\n        https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py\n\"\"\"\n\nimport copy\nfrom typing import Any, List, Union\n\nimport torch\nfrom datasets import Dataset\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom caduceus.tokenization_caduceus import CaduceusTokenizer\nimport src.utils.train\nfrom src.dataloaders.base import SequenceDataset, default_data_path\nfrom src.dataloaders.datasets.genomic_bench_dataset import GenomicBenchmarkDataset\nfrom src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer\nfrom src.dataloaders.datasets.hg38_dataset import HG38Dataset\nfrom src.dataloaders.datasets.nucleotide_transformer_dataset import NucleotideTransformerDataset\nfrom src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler\nfrom src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler\n\nlogger = src.utils.train.get_logger(__name__)\n\n\nclass HG38(SequenceDataset):\n    \"\"\"\n    Base class, other dataloaders can inherit from this class.\n\n    You must implement the following functions:\n        - __init__\n        - setup\n\n    You can then use (already have access to) the following functions:\n        - train_dataloader\n        - val_dataloader\n        - test_dataloader\n\n    \"\"\"\n    _name_ = \"hg38\"  # this name is how the dataset config finds the right dataloader\n\n    def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_config_name=None, max_length=1024, d_output=2,\n                 rc_aug=False,\n                 max_length_val=None, max_length_test=None, val_ratio=0.0005, val_split_seed=2357,\n                 add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, shuffle=False,\n                 num_workers=1,\n                 fault_tolerant=False, ddp=False,\n                 fast_forward_epochs=None, fast_forward_batches=None,\n                 mlm=False, mlm_probability=0.15,\n                 *args, **kwargs):\n        self.dataset_config_name = dataset_config_name\n        self.tokenizer_name = tokenizer_name\n        self.d_output = d_output\n        self.rc_aug = rc_aug  # reverse compliment augmentation\n        self.max_length = max_length\n        self.max_length_val = max_length_val if max_length_val is not None else max_length\n        self.max_length_test = max_length_test if max_length_test is not None else max_length\n        self.val_ratio = val_ratio\n        self.val_split_seed = val_split_seed\n        self.val_only = val_only\n        self.add_eos = add_eos\n        self.detokenize = detokenize\n        self.batch_size = batch_size\n        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size\n        self.shuffle = shuffle\n        self.num_workers = num_workers\n        self.bed_file = bed_file\n        self.fasta_file = fasta_file\n\n        # handle if file paths are None (default paths)\n        if self.bed_file is None:\n            self.bed_file = default_data_path / self._name_ / \"human-sequences.bed\"\n        if self.fasta_file is None:\n            self.fasta_file = default_data_path / self._name_ / \"hg38.ml.fa\"\n\n        if fault_tolerant:\n            assert self.shuffle\n      
  self.fault_tolerant = fault_tolerant\n        if ddp:\n            assert fault_tolerant\n        self.ddp = ddp\n        self.fast_forward_epochs = fast_forward_epochs\n        self.fast_forward_batches = fast_forward_batches\n        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:\n            assert ddp and fault_tolerant\n\n        self.mlm = mlm\n        self.mlm_probability = mlm_probability\n\n        # To be instantiated in `setup`\n        self.tokenizer = None\n        self.vocab_size = 0\n\n    def setup(self, stage=None):\n        \"\"\"Set up the tokenizer and init the datasets.\"\"\"\n        # TODO instantiate with registry\n\n        if self.tokenizer_name == \"char\":\n            logger.info(\"**Using Char-level tokenizer**\")\n            # self.tokenizer = CharacterTokenizer(\n            #     characters=[\"A\", \"C\", \"G\", \"T\", \"N\"],\n            #     model_max_length=self.max_length,\n            #     add_special_tokens=False,\n            # )\n            self.tokenizer = CaduceusTokenizer(\n                model_max_length=self.max_length,\n                add_special_tokens=False\n            )\n        else:\n            raise NotImplementedError(f\"Tokenizer {self.tokenizer_name} not implemented.\")\n\n        self.vocab_size = len(self.tokenizer)\n\n        self.init_datasets()  # creates the datasets.  You can also just create this inside the setup() here.\n\n    def init_datasets(self):\n        \"\"\"Init the datasets (separate from the tokenizer)\"\"\"\n\n        # delete old datasets to free memory\n        if hasattr(self, \"dataset_train\"):\n            self.dataset_train.fasta.seqs.close()\n            del self.dataset_train.fasta.seqs\n\n        # delete old datasets to free memory\n        if hasattr(self, \"dataset_test\"):\n            self.dataset_test.fasta.seqs.close()\n            del self.dataset_test.fasta.seqs\n\n        # Create all splits: torch datasets\n        self.dataset_train, self.dataset_val, self.dataset_test = [\n            HG38Dataset(split=split,\n                        bed_file=self.bed_file,\n                        fasta_file=self.fasta_file,\n                        max_length=max_len,\n                        tokenizer=self.tokenizer,  # pass the tokenize wrapper\n                        tokenizer_name=self.tokenizer_name,\n                        add_eos=self.add_eos,\n                        return_seq_indices=False,\n                        rc_aug=self.rc_aug,\n                        return_augs=False,\n                        mlm=self.mlm,\n                        mlm_probability=self.mlm_probability, )\n            for split, max_len in\n            zip([\"train\", \"valid\", \"test\"], [self.max_length, self.max_length_val, self.max_length_test])\n        ]\n\n        return\n\n    def train_dataloader(self, **kwargs: Any) -> DataLoader:\n        \"\"\" The train dataloader \"\"\"\n        if self.shuffle and self.fault_tolerant:\n            shuffle = False\n            # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel:\n            # In that case the number of replicas and the data parallel rank are more complicated.\n            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs\n            sampler = (FaultTolerantDistributedSampler(\n                self.dataset_train,\n                **distributed_sampler_kwargs\n            ) if self.ddp else RandomFaultTolerantSampler(self.dataset_train))\n            # TD 
[2022-08-06]: Only the DDP sampler supports fast-forwarding for now\n            # We assume that it's being resumed with the same number of GPUs\n            if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:\n                sampler.load_state_dict({\n                    \"epoch\": self.fast_forward_epochs,\n                    \"counter\": self.fast_forward_batches * self.batch_size\n                })\n        else:\n            shuffle = self.shuffle\n            sampler = None\n        loader = self._data_loader(self.dataset_train, batch_size=self.batch_size,\n                                   shuffle=shuffle, sampler=sampler, **kwargs)\n        return loader\n\n    def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The val dataloader \"\"\"\n        kwargs[\"drop_last\"] = False\n        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, **kwargs)\n\n    def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The test dataloader \"\"\"\n        kwargs[\"drop_last\"] = False\n        # TODO: Should have separate train and eval loaders\n        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, **kwargs)\n\n    @staticmethod\n    def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None, **kwargs) -> DataLoader:\n        return DataLoader(\n            dataset,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            sampler=sampler,\n            **kwargs,\n        )\n\n    def load_state_dict(self, checkpoint):\n        if self.fault_tolerant:\n            self.fast_forward_epochs = checkpoint[\"loops\"][\"fit_loop\"][\"epoch_progress\"][\"current\"][\"completed\"]\n            # TD [2022-08-07] [\"epoch_loop.batch_progress\"][\"total\"][\"completed\"] is 1 iteration\n            # behind, so we're using the optimizer\"s progress. 
This is set correctly in seq.py.\n            self.fast_forward_batches = checkpoint[\"loops\"][\"fit_loop\"][\"epoch_loop.batch_progress\"][\"current\"][\n                \"completed\"]\n        # At this point the train loader hasn't been constructed yet\n\n\nclass GenomicBenchmark(HG38):\n    _name_ = \"genomic_benchmark\"\n    l_output = 0  # need to set this for decoder to work correctly\n\n    def __init__(\n            self, dataset_name, train_val_split_seed,\n            dest_path=None, tokenizer_name=\"char\", d_output=None, rc_aug=False,\n            conjoin_train=False, conjoin_test=False,\n            max_length=1024, use_padding=True, max_length_val=None, max_length_test=None,\n            padding_side=\"left\", val_ratio=0.0005, val_split_seed=2357, add_eos=False,\n            detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,\n            shuffle=True, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,\n            fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs\n    ):\n\n        self.dataset_name = dataset_name\n        self.train_val_split_seed = train_val_split_seed\n        self.dest_path = dest_path\n        self.tokenizer_name = tokenizer_name\n        self.d_output = d_output\n        self.rc_aug = rc_aug\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n        self.max_length = max_length\n        self.use_padding = use_padding\n        self.max_length_val = max_length_val if max_length_val is not None else max_length\n        self.max_length_test = max_length_test if max_length_test is not None else max_length\n        self.padding_side = padding_side\n        self.val_ratio = val_ratio\n        self.val_split_seed = val_split_seed\n        self.val_only = val_only\n        self.add_eos = add_eos\n        self.detokenize = detokenize\n        self.batch_size = batch_size\n        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size\n        self.num_workers = num_workers\n        self.shuffle = shuffle\n        self.pin_memory = pin_memory\n        self.drop_last = drop_last\n\n        if self.dest_path is None:\n            self.dest_path = default_data_path / self._name_\n\n        if fault_tolerant:\n            assert self.shuffle\n        self.fault_tolerant = fault_tolerant\n        if ddp:\n            assert fault_tolerant\n        self.ddp = ddp\n        self.fast_forward_epochs = fast_forward_epochs\n        self.fast_forward_batches = fast_forward_batches\n        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:\n            assert ddp and fault_tolerant\n\n    def setup(self, stage=None):\n        # TODO instantiate with registry\n\n        if self.tokenizer_name == \"char\":\n            print(\"**Using Char-level tokenizer**\")\n            self.tokenizer = CharacterTokenizer(\n                characters=[\"A\", \"C\", \"G\", \"T\", \"N\"],\n                model_max_length=self.max_length + 2,  # add 2 since default adds eos/eos tokens, crop later\n                add_special_tokens=False,\n                padding_side=self.padding_side,\n            )\n\n        # Create all splits: torch datasets (only train/test in this benchmark, val created below)\n        self.dataset_train, self.dataset_test = [\n            GenomicBenchmarkDataset(\n                split=split,\n                max_length=max_len,\n                dataset_name=self.dataset_name,\n                
tokenizer=self.tokenizer,  # pass the tokenize wrapper\n                tokenizer_name=self.tokenizer_name,\n                use_padding=self.use_padding,\n                d_output=self.d_output,\n                add_eos=self.add_eos,\n                dest_path=self.dest_path,\n                rc_aug=self.rc_aug,\n                conjoin_train=self.conjoin_train,\n                conjoin_test=self.conjoin_test,\n                return_augs=False\n            )\n            for split, max_len in zip([\"train\", \"test\"], [self.max_length, self.max_length_val])\n        ]\n\n        val_data, train_data = torch.utils.data.random_split(\n            list(zip(self.dataset_train.all_seqs, self.dataset_train.all_labels)),\n            lengths=[0.1, 0.9],\n            generator=torch.Generator().manual_seed(self.train_val_split_seed)\n        )\n        self.dataset_val = copy.deepcopy(self.dataset_train)\n        self.dataset_train.all_seqs = [train_data[i][0] for i in range(len(train_data))]\n        self.dataset_train.all_labels = [train_data[i][1] for i in range(len(train_data))]\n\n        self.dataset_val.all_seqs = [val_data[i][0] for i in range(len(val_data))]\n        self.dataset_val.all_labels = [val_data[i][1] for i in range(len(val_data))]\n        self.dataset_val.split = \"val\"\n\n\nclass NucleotideTransformer(HG38):\n    _name_ = \"nucleotide_transformer\"\n    l_output = 0  # need to set this for decoder to work correctly\n\n    def __init__(self, dataset_name, train_val_split_seed,\n                 tokenizer_name=\"char\", d_output=None, rc_aug=False,\n                 conjoin_train=False, conjoin_test=False,\n                 max_length=1024, use_padding=True, max_length_val=None, max_length_test=None,\n                 padding_side=\"left\", val_ratio=0.0005, val_split_seed=2357, add_eos=False,\n                 detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,\n                 shuffle=True, shuffle_eval=None, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,\n                 fast_forward_epochs=None, fast_forward_batches=None, *args, **kwargs):\n\n        self.dataset_name = dataset_name\n        self.train_val_split_seed = train_val_split_seed\n        self.tokenizer_name = tokenizer_name\n        self.d_output = d_output\n        self.rc_aug = rc_aug\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n        self.max_length = max_length\n        self.use_padding = use_padding\n        self.max_length_val = max_length_val if max_length_val is not None else max_length\n        self.max_length_test = max_length_test if max_length_test is not None else max_length\n        self.padding_side = padding_side\n        self.val_ratio = val_ratio\n        self.val_split_seed = val_split_seed\n        self.val_only = val_only\n        self.add_eos = add_eos\n        self.detokenize = detokenize\n        self.batch_size = batch_size\n        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size\n        self.num_workers = num_workers\n        self.shuffle = shuffle\n        self.shuffle_eval = shuffle_eval if shuffle_eval is not None else shuffle\n        self.pin_memory = pin_memory\n        self.drop_last = drop_last\n\n        if fault_tolerant:\n            assert self.shuffle\n        self.fault_tolerant = fault_tolerant\n        if ddp:\n            assert fault_tolerant\n        self.ddp = ddp\n        self.fast_forward_epochs = fast_forward_epochs\n 
       self.fast_forward_batches = fast_forward_batches\n        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:\n            assert ddp and fault_tolerant\n\n    def setup(self, stage=None):\n        # TODO instantiate with registry\n\n        if self.tokenizer_name == \"char\":\n            print(\"**Using Char-level tokenizer**\")\n            self.tokenizer = CharacterTokenizer(\n                characters=[\"A\", \"C\", \"G\", \"T\", \"N\"],\n                model_max_length=self.max_length + 2,  # add 2 since default adds eos/eos tokens, crop later\n                add_special_tokens=False,\n                padding_side=self.padding_side,\n            )\n\n        # Create all splits: torch datasets (only train/test in this benchmark)\n        # self.dataset_train, self.dataset_val = [\n        self.dataset_train, self.dataset_test = [\n            NucleotideTransformerDataset(\n                split=split,\n                max_length=max_len,\n                tokenizer=self.tokenizer,  # pass the tokenize wrapper\n                dataset_name=self.dataset_name,\n                tokenizer_name=self.tokenizer_name,\n                use_padding=self.use_padding,\n                d_output=self.d_output,\n                add_eos=self.add_eos,\n                rc_aug=self.rc_aug,\n                conjoin_train=self.conjoin_train,\n                conjoin_test=self.conjoin_test,\n                return_augs=False\n            )\n            for split, max_len in zip([\"train\", \"test\"], [self.max_length, self.max_length_val])\n        ]\n\n        ds_train_val_split = self.dataset_train.seqs.train_test_split(\n            test_size=0.1,\n            seed=self.train_val_split_seed\n        )\n        self.dataset_val = copy.deepcopy(self.dataset_train)\n        self.dataset_train.seqs = ds_train_val_split[\"train\"]\n\n        self.dataset_val.split = \"val\"\n        self.dataset_val.seqs = ds_train_val_split[\"test\"]\n"
  },
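  {
    "path": "src/dataloaders/genomics_example.py",
    "content": "\"\"\"Hypothetical wiring sketch (illustrative addition, not an original repo file).\n\nInstantiates the HG38 pretraining datamodule directly, bypassing the Hydra configs.\nPassing None for the file paths falls back to `default_data_path / \"hg38\" / ...`, so\nthis assumes the fasta and bed files have already been downloaded there.\n\"\"\"\nfrom src.dataloaders.genomics import HG38\n\ndm = HG38(\n    bed_file=None,           # defaults to default_data_path / \"hg38\" / \"human-sequences.bed\"\n    fasta_file=None,         # defaults to default_data_path / \"hg38\" / \"hg38.ml.fa\"\n    tokenizer_name=\"char\",   # setup() instantiates the CaduceusTokenizer\n    max_length=1024,\n    batch_size=4,\n    mlm=False,               # next-token (causal) LM targets\n)\ndm.setup()\nx, y = next(iter(dm.train_dataloader()))\n# With mlm=False, x and y are shifted input/target pairs of shape (batch_size, max_length).\n"
  },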
  {
    "path": "src/dataloaders/utils/mlm.py",
    "content": "import torch\n\n\ndef mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer=None, eligible_replacements=None):\n    \"\"\"Helper method for creating MLM input / target.\n\n    Adapted from:\n    https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L751\n    \"\"\"\n    data = seq[:-1].clone() if contains_eos else seq.clone()  # remove eos, if applicable\n    target = data.clone()\n    # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)\n    probability_matrix = torch.full(target.shape, mlm_probability)\n    # TODO: Do we need to avoid \"masking\" special tokens as is done here?\n    #  https://github.com/huggingface/transformers/blob/14666775a296a76c88e1aa686a9547f393d322e2/src/transformers/data/data_collator.py#L760-L766\n    masked_indices = torch.bernoulli(probability_matrix).bool()\n    target[~masked_indices] = tokenizer.pad_token_id  # We only compute loss on masked tokens\n\n    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n    indices_replaced = torch.bernoulli(torch.full(target.shape, 0.8)).bool() & masked_indices\n    data[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)\n\n    # 10% of the time, we replace masked input tokens with random word\n    indices_random = torch.bernoulli(torch.full(target.shape, 0.5)).bool() & masked_indices & ~indices_replaced\n    if eligible_replacements is not None:\n        rand_choice = torch.randint(eligible_replacements.shape[0], size=target.shape)\n        random_words = eligible_replacements[rand_choice]\n    else:\n        random_words = torch.randint(len(tokenizer), size=target.shape, dtype=torch.long)\n    data[indices_random] = random_words[indices_random]\n    # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n    return data, target\n"
  },
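  {
    "path": "src/dataloaders/utils/mlm_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nRuns `mlm_getitem` on a short tokenized sequence. The target keeps the original ids\nonly at the ~15% of masked positions (everything else becomes the pad id, which the\nloss ignores); the input gets [MASK] 80% of the time, a random id 10% of the time,\nand is left unchanged otherwise.\n\"\"\"\nimport torch\n\nfrom src.dataloaders.datasets.hg38_char_tokenizer import CharacterTokenizer\nfrom src.dataloaders.utils.mlm import mlm_getitem\n\ntokenizer = CharacterTokenizer(characters=[\"A\", \"C\", \"G\", \"T\", \"N\"], model_max_length=64)\nseq = torch.LongTensor(tokenizer(\"ACGT\" * 8, add_special_tokens=False)[\"input_ids\"])\n\ndata, target = mlm_getitem(seq, mlm_probability=0.15, tokenizer=tokenizer)\n\nmasked = target != tokenizer.pad_token_id\n# Positions that were not selected for masking pass through the input untouched.\nassert torch.equal(data[~masked], seq[~masked])\n"
  },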
  {
    "path": "src/dataloaders/utils/rc.py",
    "content": "\"\"\"Utility functions for reverse complementing DNA sequences.\n\n\"\"\"\n\nfrom random import random\n\nSTRING_COMPLEMENT_MAP = {\n    \"A\": \"T\", \"C\": \"G\", \"G\": \"C\", \"T\": \"A\", \"a\": \"t\", \"c\": \"g\", \"g\": \"c\", \"t\": \"a\",\n    \"N\": \"N\", \"n\": \"n\",\n}\n\ndef coin_flip(p=0.5):\n    \"\"\"Flip a (potentially weighted) coin.\"\"\"\n    return random() > p\n\n\ndef string_reverse_complement(seq):\n    \"\"\"Reverse complement a DNA sequence.\"\"\"\n    rev_comp = \"\"\n    for base in seq[::-1]:\n        if base in STRING_COMPLEMENT_MAP:\n            rev_comp += STRING_COMPLEMENT_MAP[base]\n        # if bp not complement map, use the same bp\n        else:\n            rev_comp += base\n    return rev_comp\n"
  },
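  {
    "path": "src/dataloaders/utils/rc_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nQuick check of the reverse-complement helpers: the sequence is reversed and each base\nswapped with its complement, while characters outside the map pass through unchanged.\n\"\"\"\nfrom src.dataloaders.utils.rc import coin_flip, string_reverse_complement\n\nassert string_reverse_complement(\"AACGTN\") == \"NACGTT\"\nassert string_reverse_complement(\"acGT\") == \"ACgt\"  # case is preserved per character\n\n# coin_flip() decides whether to apply reverse-complement augmentation for a sample.\napply_rc_aug = coin_flip()\nassert isinstance(apply_rc_aug, bool)\n"
  },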
  {
    "path": "src/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/baseline/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/baseline/genomics_benchmark_cnn.py",
    "content": "\"\"\"Genomics Benchmark CNN model.\n\nAdapted from https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/src/genomic_benchmarks/models/torch.py\n\"\"\"\n\nimport torch\nfrom torch import nn\n\n\nclass GenomicsBenchmarkCNN(nn.Module):\n    def __init__(self, number_of_classes, vocab_size, input_len, embedding_dim=100):\n        \"\"\"Genomics Benchmark CNN model.\n\n        `embedding_dim` = 100 comes from:\n        https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments\n        \"\"\"\n        super(GenomicsBenchmarkCNN, self).__init__()\n\n        self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n        self.cnn_model = nn.Sequential(\n            nn.Conv1d(in_channels=embedding_dim, out_channels=16, kernel_size=8, bias=True),\n            nn.BatchNorm1d(16),\n            nn.ReLU(),\n            nn.MaxPool1d(2),\n\n            nn.Conv1d(in_channels=16, out_channels=8, kernel_size=8, bias=True),\n            nn.BatchNorm1d(8),\n            nn.MaxPool1d(2),\n\n            nn.Conv1d(in_channels=8, out_channels=4, kernel_size=8, bias=True),\n            nn.BatchNorm1d(4),\n            nn.MaxPool1d(2),\n\n            nn.Flatten()\n        )\n        self.dense_model = nn.Sequential(\n            nn.Linear(self.count_flatten_size(input_len), 512),\n            # To be consistent with SSM classifier decoders, we use num_classes (even when it's binary)\n            nn.Linear(512, number_of_classes)\n        )\n\n    def count_flatten_size(self, input_len):\n        zeros = torch.zeros([1, input_len], dtype=torch.long)\n        x = self.embeddings(zeros)\n        x = x.transpose(1, 2)\n        x = self.cnn_model(x)\n        return x.size()[1]\n\n    def forward(self, x, state=None):  # Adding `state` to be consistent with other models\n        x = self.embeddings(x)\n        x = x.transpose(1, 2)\n        x = self.cnn_model(x)\n        x = self.dense_model(x)\n        return x, state  # Returning tuple to be consistent with other models\n"
  },
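  {
    "path": "src/models/baseline/genomics_benchmark_cnn_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nShape check for the CNN baseline: batches of token ids go through an embedding, three\nConv1d/MaxPool1d blocks, and a small dense head. The vocab size of 12 below matches the\ncharacter tokenizer (7 special tokens + A/C/G/T/N) and is only an example value.\n\"\"\"\nimport torch\n\nfrom src.models.baseline.genomics_benchmark_cnn import GenomicsBenchmarkCNN\n\nmodel = GenomicsBenchmarkCNN(number_of_classes=2, vocab_size=12, input_len=512)\ntokens = torch.randint(0, 12, (8, 512))  # (batch, seq_len) of token ids\nlogits, _ = model(tokens)                # forward returns (output, state)\nassert logits.shape == (8, 2)\n"
  },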
  {
    "path": "src/models/nn/__init__.py",
    "content": "from .activation import Activation\n"
  },
  {
    "path": "src/models/nn/activation.py",
    "content": "\"\"\"Utilities for activation functions.\"\"\"\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef Activation(activation=None, size=None, dim=-1):\n    \"\"\"Returns a PyTorch activation module.\"\"\"\n    if activation in [None, 'id', 'identity', 'linear', 'none']:\n        return nn.Identity()\n    elif activation == 'tanh':\n        return nn.Tanh()\n    elif activation == 'relu':\n        return nn.ReLU()\n    elif activation == 'gelu':\n        return nn.GELU()\n    elif activation == 'elu':\n        return nn.ELU()\n    elif activation in ['swish', 'silu']:\n        return nn.SiLU()\n    elif activation == 'glu':\n        return nn.GLU(dim=dim)\n    elif activation.startswith('glu-'):\n        return GLU(dim=dim, activation=activation[4:])\n    elif activation == 'sigmoid':\n        return nn.Sigmoid()\n    elif activation == 'softplus':\n        return nn.Softplus()\n    elif activation == 'modrelu':\n        return ModReLU(size)\n    elif activation in ['sqrelu', 'relu2']:\n        return SquaredReLU()\n    elif activation == 'laplace':\n        return Laplace()\n    # Earlier experimentation with a LN in the middle of the block instead of activation\n    # IIRC ConvNext does something like this?\n    # elif activation == 'ln':\n    #     return TransposedLN(dim)\n    else:\n        raise NotImplementedError(\"hidden activation '{}' is not implemented\".format(activation))\n\n\nclass GLU(nn.Module):\n    def __init__(self, dim=-1, activation='sigmoid'):\n        super().__init__()\n        assert not activation.startswith('glu')\n        self.dim = dim\n        self.activation_fn = Activation(activation)\n\n    def forward(self, x):\n        x, g = torch.split(x, x.size(self.dim) // 2, dim=self.dim)\n        return x * self.activation_fn(g)\n\n\nclass ModReLU(nn.Module):\n    # Adapted from https://github.com/Lezcano/expRNN\n\n    def __init__(self, features):\n        # For now we just support square layers\n        super().__init__()\n        self.features = features\n        self.b = nn.Parameter(torch.Tensor(self.features))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.b.data.uniform_(-0.01, 0.01)\n\n    def forward(self, inputs):\n        norm = torch.abs(inputs)\n        biased_norm = norm + self.b\n        magnitude = F.relu(biased_norm)\n        phase = torch.sign(inputs)\n\n        return phase * magnitude\n\n\nclass SquaredReLU(nn.Module):\n    def forward(self, x):\n        # return F.relu(x)**2\n        return torch.square(F.relu(x))  # Could this be faster?\n\n\ndef laplace(x, mu=0.707107, sigma=0.282095):\n    x = (x - mu).div(sigma * math.sqrt(2.0))\n    return 0.5 * (1.0 + torch.erf(x))\n\n\nclass Laplace(nn.Module):\n    def __init__(self, mu=0.707107, sigma=0.282095):\n        super().__init__()\n        self.mu = mu\n        self.sigma = sigma\n\n    def forward(self, x):\n        return laplace(x, mu=self.mu, sigma=self.sigma)\n"
  },
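  {
    "path": "src/models/nn/activation_example.py",
    "content": "\"\"\"Hypothetical sketch (illustrative addition, not an original repo file).\n\nThe `Activation` factory maps config strings to modules. Plain names return the usual\nnonlinearities; the 'glu-<act>' form builds a GLU whose gate uses <act> instead of the\ndefault sigmoid and therefore halves the size of the chosen dimension.\n\"\"\"\nimport torch\n\nfrom src.models.nn import Activation\n\nx = torch.randn(2, 8)\n\ngelu = Activation(\"gelu\")\nassert gelu(x).shape == (2, 8)\n\nglu_relu = Activation(\"glu-relu\", dim=-1)  # GLU with a ReLU-activated gate\nassert glu_relu(x).shape == (2, 4)         # the gated dimension is halved\n"
  },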
  {
    "path": "src/models/nn/adaptive_softmax.py",
    "content": "# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional\nimport functools\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass OptionalParameterList(nn.ParameterList):\n    def extra_repr(self):\n        child_lines = []\n        for k, p in self._parameters.items():\n            if p is not None:\n                size_str = 'x'.join(str(size) for size in p.size())\n                device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())\n                parastr = 'Parameter containing: [{} of size {}{}]'.format(\n                    torch.typename(p), size_str, device_str)\n                child_lines.append('  (' + str(k) + '): ' + parastr)\n        tmpstr = '\\n'.join(child_lines)\n        return tmpstr\n\n\nclass ProjectedAdaptiveLogSoftmax(nn.Module):\n    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,\n                 tie_projs=None, out_layers_weights=None, out_projs=None,\n                 keep_order=False,\n                 bias_scale=0.0,\n                 dropout=0.0,\n                 ):\n        super().__init__()\n\n        self.n_token = n_token\n        self.d_embed = d_embed\n        self.d_proj = d_proj\n\n        self.cutoffs = list(cutoffs) + [n_token]\n        self.cutoff_ends = [0] + self.cutoffs\n        self.div_val = div_val\n\n        self.shortlist_size = self.cutoffs[0]\n        self.n_clusters = len(self.cutoffs) - 1\n        self.head_size = self.shortlist_size + self.n_clusters\n\n        # bake the first False into the definition, just as [0] is built into the cutoffs\n        if tie_projs is None: tie_projs = []\n        elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs)\n        else: tie_projs = list(tie_projs)\n        tie_projs = [False] + tie_projs\n        self.tie_projs = tie_projs\n\n        if self.n_clusters > 0:\n            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))\n            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))\n\n        if not out_layers_weights:\n            self.out_layers_weights = nn.ParameterList()\n        else:\n            self.out_layers_weights = out_layers_weights\n\n        self.out_layers_biases = nn.ParameterList()\n\n        self.shared_out_projs = out_projs\n        self.out_projs = OptionalParameterList()\n\n        self.dropout = dropout\n        self.drop = nn.Dropout(dropout)\n\n        if div_val == 1:\n            if d_proj != d_embed:\n                for i in range(len(self.cutoffs)):\n                    if tie_projs[i]:\n                        self.out_projs.append(None)\n                    else:\n                        self.out_projs.append(\n                            nn.Parameter(torch.zeros(d_proj, d_embed))\n                        )\n            else:\n                self.out_projs.append(None)\n\n            
self.out_layers_biases.append(\n                nn.Parameter(torch.zeros(n_token))\n                )\n\n            if not out_layers_weights:\n                self.out_layers_weights.append(\n                    nn.Parameter(torch.zeros(n_token, d_embed))\n                    )\n        else:\n            for i in range(len(self.cutoffs)):\n                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]\n                d_emb_i = d_embed // (div_val ** i)\n\n                if tie_projs[i]:\n                    self.out_projs.append(None)\n                else:\n                    self.out_projs.append(\n                        nn.Parameter(torch.zeros(d_proj, d_emb_i))\n                    )\n\n                self.out_layers_biases.append(\n                    nn.Parameter(torch.zeros(r_idx - l_idx))\n                    )\n                if not out_layers_weights:\n                    self.out_layers_weights.append(\n                        nn.Parameter(torch.zeros(r_idx - l_idx, d_emb_i))\n                        )\n        for bias in self.out_layers_biases:\n            bound = bias_scale * d_proj ** -.5\n            nn.init.uniform_(bias, -bound, bound)\n\n\n        self.keep_order = keep_order\n\n    def _compute_logit(self, hidden, weight, bias, proj):\n        if proj is None:\n            logit = F.linear(hidden, weight, bias=bias)\n        else:\n            if self.dropout > 0.0:\n                logit = hidden @ proj\n                logit = self.drop(logit)\n                logit = logit @ weight.t()\n            else:\n                logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))\n            if bias is not None:\n                logit = logit + bias\n        return logit\n\n    def get_out_proj(self, i):\n        if self.tie_projs[i]:\n            if len(self.shared_out_projs) == 0:\n                return None\n            elif len(self.shared_out_projs) == 1:\n                return self.shared_out_projs[0]\n            else:\n                return self.shared_out_projs[i]\n        else:\n            return self.out_projs[i]\n\n    def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs):\n        # [21-09-15 AG]: TODO may need to handle key_padding_mask\n        '''\n            hidden :: [len*bsz x d_proj]\n            target :: [len*bsz]\n        '''\n\n        hidden = hidden.reshape(-1, hidden.size(-1))\n        target = target.reshape(-1)\n        if hidden.size(0) != target.size(0):\n            print(hidden.shape, target.shape)\n            raise RuntimeError('Input and target should have the same size '\n                               'in the batch dimension.')\n\n        if self.n_clusters == 0:\n            logit = self._compute_logit(hidden, self.out_layers_weights[0],\n                                        self.out_layers_biases[0], self.get_out_proj(0))\n            nll = -F.log_softmax(logit, dim=-1) \\\n                    .gather(1, target.unsqueeze(1)).squeeze(1)\n        else:\n            # construct weights and biases\n            weights, biases = [], []\n            for i in range(len(self.cutoffs)):\n                if self.div_val == 1:\n                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n                    weight_i = self.out_layers_weights[0][l_idx:r_idx]\n                    bias_i = self.out_layers_biases[0][l_idx:r_idx]\n                else:\n                    weight_i = self.out_layers_weights[i]\n                    bias_i = 
self.out_layers_biases[i]\n\n                if i == 0:\n                    weight_i = torch.cat(\n                        [weight_i, self.cluster_weight], dim=0)\n                    bias_i = torch.cat(\n                        [bias_i, self.cluster_bias], dim=0)\n\n                weights.append(weight_i)\n                biases.append(bias_i)\n\n            head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0)\n\n            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)\n            head_logprob = F.log_softmax(head_logit, dim=1)\n\n            nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)\n\n            offset = 0\n            cutoff_values = [0] + self.cutoffs\n            for i in range(len(cutoff_values) - 1):\n                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]\n\n                mask_i = (target >= l_idx) & (target < r_idx)\n                indices_i = mask_i.nonzero(as_tuple=False).squeeze()\n\n                if indices_i.numel() == 0:\n                    continue\n\n                target_i = target.index_select(0, indices_i) - l_idx\n                head_logprob_i = head_logprob.index_select(0, indices_i)\n\n                if i == 0:\n                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)\n                else:\n                    weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i)\n\n                    hidden_i = hidden.index_select(0, indices_i)\n\n                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)\n                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)\n\n                    # First term accounts for cluster probabilities\n                    logprob_i = head_logprob_i[:, -i] \\\n                        + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)\n\n                if self.keep_order or keep_order:\n                    nll.index_copy_(0, indices_i, -logprob_i)\n                else:\n                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)\n\n                offset += logprob_i.size(0) # TODO This should be a bug in the original implementation; it should go into the continue case above as well\n\n        return nll.mean() # TODO maybe cases for length or padding_mask\n\n    def compute_logits(self, hidden):\n        \"\"\"Compute full vector of logits\n\n        Adapted from https://github.com/kimiyoung/transformer-xl/issues/88\n        \"\"\"\n        hidden = hidden.reshape(-1, hidden.size(-1))\n\n        if self.n_clusters == 0:\n            logits = self._compute_logit(hidden, self.out_layers_weights[0],\n                                        self.out_layers_biases[0], self.get_out_proj(0))\n            return logits\n        else:\n            # construct weights and biases\n            weights, biases = [], []\n            for i in range(len(self.cutoffs)):\n                if self.div_val == 1:\n                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n                    weight_i = self.out_layers_weights[0][l_idx:r_idx]\n                    bias_i = self.out_layers_biases[0][l_idx:r_idx]\n                else:\n                    weight_i = self.out_layers_weights[i]\n                    bias_i = self.out_layers_biases[i]\n\n                if i == 0:\n                    weight_i = torch.cat(\n                        [weight_i, self.cluster_weight], dim=0)\n                    bias_i = torch.cat(\n 
                       [bias_i, self.cluster_bias], dim=0)\n\n                weights.append(weight_i)\n                biases.append(bias_i)\n\n            head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0)\n\n            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)\n            head_logprob = F.log_softmax(head_logit, dim=1)\n\n            out_full_logps = [head_logprob[:, :self.cutoffs[0]]]\n            offset = 0\n            cutoff_values = [0] + self.cutoffs\n\n            for i in range(1, len(cutoff_values) - 1):\n                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]\n                head_logprob_i = head_logprob # .index_select(0, indices_i)\n\n                if i == 0:\n                    logprob_i = head_logprob_i\n                else:\n                    weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i)\n\n                    hidden_i = hidden # .index_select(0, indices_i)\n\n                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)\n                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)\n                    logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i\n\n                offset += logprob_i.size(0)\n                out_full_logps.append(logprob_i)\n            out_full_logps = torch.cat(out_full_logps, dim = 1)\n            # print(torch.sum(out_full_ps), out_full_ps.shape)\n            return out_full_logps\n\n\nclass AdaptiveEmbedding(nn.Module):\n    \"\"\" Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation\n\n    Initialization has been fixed for the case when d_proj = d_embed\n    \"\"\"\n    def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, init_scale=1.0, sample_softmax=False, dropout=0.0):\n        super().__init__()\n\n        self.n_token = n_token\n        self.d_embed = d_embed\n\n        self.cutoffs = list(cutoffs) + [n_token]\n        self.div_val = div_val\n        self.d_proj = d_proj\n        self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()\n\n        self.emb_scale = d_proj ** 0.5\n\n        self.cutoff_ends = [0] + self.cutoffs\n\n        self.emb_layers = nn.ModuleList()\n        self.emb_projs = nn.ParameterList()\n        if div_val == 1:\n            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))\n            _init_embed(self.emb_layers[-1].weight, d_embed, init_scale)\n            # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_embed ** -.5)\n            if d_proj != d_embed: # TODO\n                # self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))\n                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))\n                # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale)\n                _init_proj(self.emb_projs[-1], d_proj, init_scale)\n        else:\n            for i in range(len(self.cutoffs)):\n                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n                d_emb_i = d_embed // (div_val ** i)\n                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))\n                # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_emb_i ** -.5)\n                _init_embed(self.emb_layers[-1].weight, d_emb_i, init_scale)\n                
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))\n                # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale)\n                _init_proj(self.emb_projs[-1], d_proj, init_scale)\n\n    def forward(self, inp):\n        if self.div_val == 1:\n            embed = self.emb_layers[0](inp)\n            embed = self.drop(embed)\n            if self.d_proj != self.d_embed:\n                embed = F.linear(embed, self.emb_projs[0])\n        else:\n            param = next(self.parameters())\n            inp_flat = inp.reshape(-1)\n\n            # Changes from original impl\n            # emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)\n            embeddings = []\n            indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token\n            _total_tokens = 0\n\n            # emb_flat = inp.new_zeros(inp_flat.size(0), self.d_proj)\n            for i in range(len(self.cutoffs)):\n                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n\n                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)\n                indices_i = mask_i.nonzero().squeeze(-1) # shape (_tokens,)\n\n                _tokens = indices_i.numel()\n                if _tokens == 0:\n                    continue\n\n                inp_i = inp_flat.index_select(0, indices_i) - l_idx\n                emb_i = self.emb_layers[i](inp_i)\n                emb_i = self.drop(emb_i)\n                emb_i = F.linear(emb_i, self.emb_projs[i])\n\n                # Changes\n                embeddings.append(emb_i)\n                indices.index_put_(\n                    (indices_i,),\n                    torch.arange(_tokens, device=inp.device) + _total_tokens\n                )\n                _total_tokens += _tokens\n\n                # emb_flat.index_copy_(0, indices_i, emb_i)\n            embeddings = torch.cat(embeddings, dim=0)\n            emb_flat = embeddings[indices]\n\n            embed_shape = inp.size() + (self.d_proj,)\n            embed = emb_flat.view(embed_shape)\n\n        embed.mul_(self.emb_scale)\n        # embed.div_(self.emb_scale)\n\n        return embed\n\n\ndef _init_weight(weight, d : int, init_scale : Optional[float], default=None):\n    assert init_scale or default\n    if init_scale is None:\n        std = default\n    else:\n        std = init_scale * (d ** -0.5)\n    nn.init.normal_(weight, mean=0, std=std)\n\n_init_embed = functools.partial(_init_weight, default=0.02)\n_init_proj = functools.partial(_init_weight, default=0.01)\n"
  },
  {
    "path": "src/models/nn/utils.py",
    "content": "\"\"\" Utility wrappers around modules to let them handle Args and extra arguments \"\"\"\n\nimport inspect\nfrom functools import wraps\nimport torch\nfrom torch import nn\n\ndef wrap_kwargs(f):\n    \"\"\"\n    Given a callable f that can consume some named arguments,\n    wrap it with a kwargs that passes back any unused args\n\n    EXAMPLES\n    --------\n\n    Basic usage:\n    def foo(x, y=None):\n        return x\n\n    wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2})\n\n    --------\n\n    The wrapped function can return its own argument dictionary,\n    which gets merged with the new kwargs.\n    def foo(x, y=None):\n        return x, {}\n    wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2})\n\n    def foo(x, y=None):\n        return x, {\"y\": y, \"z\": None}\n    wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2})\n\n    --------\n\n    The wrapped function can have its own kwargs parameter:\n    def foo(x, y=None, **kw_args):\n        return x, {}\n    wrap_kwargs(foo)(0, y=1, z=2) == (0, {})\n\n    --------\n\n    Partial functions and modules work automatically:\n    class Module:\n        def forward(self, x, y=0):\n            return x, {\"y\": y+1}\n\n    m = Module()\n\n    wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2})\n\n    \"\"\"\n    sig = inspect.signature(f)\n    # Check if f already has kwargs\n    has_kwargs = any([\n        param.kind == inspect.Parameter.VAR_KEYWORD\n        for param in sig.parameters.values()\n    ])\n    if has_kwargs:\n        @wraps(f)\n        def f_kwargs(*args, **kwargs):\n            y = f(*args, **kwargs)\n            if isinstance(y, tuple) and isinstance(y[-1], dict):\n                return y\n            else:\n                return y, {}\n    else:\n        param_kwargs = inspect.Parameter(\"kwargs\", kind=inspect.Parameter.VAR_KEYWORD)\n        sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs])\n        @wraps(f)\n        def f_kwargs(*args, **kwargs):\n            bound = sig_kwargs.bind(*args, **kwargs)\n            if \"kwargs\" in bound.arguments:\n                kwargs = bound.arguments.pop(\"kwargs\")\n            else:\n                kwargs = {}\n            y = f(**bound.arguments)\n            if isinstance(y, tuple) and isinstance(y[-1], dict):\n                return *y[:-1], {**y[-1], **kwargs}\n            else:\n                return y, kwargs\n    return f_kwargs\n\ndef discard_kwargs(f):\n    if f is None: return None\n    f_kwargs = wrap_kwargs(f)\n    @wraps(f)\n    def f_(*args, **kwargs):\n        return f_kwargs(*args, **kwargs)[0]\n    return f_\n\ndef PassthroughSequential(*modules):\n    \"\"\"Special Sequential module that chains kwargs.\n\n    Semantics are the same as nn.Sequential, with extra convenience features:\n    - Discard None modules\n    - Flatten inner Sequential modules\n    - In case with 0 or 1 Module, rename the class for ease of inspection\n    \"\"\"\n    def flatten(module):\n        if isinstance(module, nn.Sequential):\n            return sum([flatten(m) for m in module], [])\n        else:\n            return [module]\n\n    modules = flatten(nn.Sequential(*modules))\n    modules = [module for module in modules if module if not None]\n\n    class Sequential(nn.Sequential):\n        def forward(self, x, **kwargs):\n            for layer in self:\n                x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs)\n            return x, kwargs\n\n        def step(self, x, **kwargs):\n            for layer in 
self:\n                fn = getattr(layer, \"step\", layer.forward)\n                x, kwargs = wrap_kwargs(fn)(x, **kwargs)\n            return x, kwargs\n\n    if len(modules) == 0:\n        Sequential.__name__ = \"Identity\"\n    elif len(modules) == 1:\n        Sequential.__name__ = type(modules[0]).__name__\n    return Sequential(*modules)\n"
  },
  {
    "path": "src/models/sequence/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/sequence/dna_embedding.py",
    "content": "\"\"\"DNA Embedding Model.\n\nBackbones from LM pre-training models, used for downstream tasks.\n\"\"\"\n\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nfrom flash_attn.utils.generation import GenerationMixin\nfrom mamba_ssm.models.config_mamba import MambaConfig\nfrom mamba_ssm.models.mixer_seq_simple import MixerModel\nfrom mamba_ssm.models.mixer_seq_simple import _init_weights as _init_weights_mamba\n\ntry:\n    from flash_attn.ops.fused_dense import ColumnParallelLinear\nexcept ImportError:\n    ColumnParallelLinear = None\n\n\nfrom caduceus.configuration_caduceus import CaduceusConfig\nfrom caduceus.modeling_caduceus import Caduceus\nfrom src.models.sequence.long_conv_lm import LMBackbone\nfrom src.models.sequence.long_conv_lm import _init_weights\n\n\nclass DNAEmbeddingModel(nn.Module, GenerationMixin):\n    \"\"\"DNA Embedding Model.\n\n    Same as ConvLMHeadModel (in long_conv_lm.py), except no decoder head, we just pass back the hidden states for\n    downstream tasks.\n    \"\"\"\n\n    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int,\n                 process_group=None, layer=None,\n                 attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0,\n                 resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout,\n                 norm_epsilon: float = 1e-5,\n                 rms_norm: bool = False,\n                 initializer_cfg=None,\n                 checkpoint_mlp=False,\n                 checkpoint_mixer=False,\n                 fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False,\n                 pad_vocab_size_multiple: int = 1, sequence_parallel=True,\n                 device=None, dtype=None, return_hidden_state=False, **kwargs) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super().__init__()\n        self.d_model = d_model  # for decoder\n        self.process_group = process_group\n        self.return_hidden_state = return_hidden_state\n        if vocab_size % pad_vocab_size_multiple != 0:\n            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n        self.backbone = LMBackbone(\n            d_model=d_model,\n            n_layer=n_layer,\n            d_inner=d_inner,\n            vocab_size=vocab_size,\n            process_group=process_group,\n            layer=layer,\n            attn_layer_idx=attn_layer_idx,\n            attn_cfg=attn_cfg,\n            max_position_embeddings=max_position_embeddings,\n            resid_dropout=resid_dropout,\n            embed_dropout=embed_dropout,\n            dropout_cls=dropout_cls,\n            norm_epsilon=norm_epsilon,\n            rms_norm=rms_norm,\n            initializer_cfg=initializer_cfg,\n            fused_mlp=fused_mlp,\n            fused_dropout_add_ln=fused_dropout_add_ln,\n            residual_in_fp32=residual_in_fp32,\n            sequence_parallel=sequence_parallel,\n            checkpoint_mlp=checkpoint_mlp,\n            checkpoint_mixer=checkpoint_mixer,\n            **factory_kwargs, **kwargs\n        )\n\n        # Initialize weights and apply final processing\n        self.apply(partial(_init_weights, n_layer=n_layer,\n                           **(initializer_cfg if initializer_cfg is not None else {})))\n\n    def forward(self, input_ids, position_ids=None, inference_params=None, state=None):  # state for the repo interface\n        \"\"\"DNA Embedding Model forward pass.\"\"\"\n        hidden_states = 
self.backbone(input_ids, position_ids=position_ids,\n                                      inference_params=inference_params)\n        # we only need the last hidden state for embeddings (decoder head will predict classification task)\n        return hidden_states, None\n\n    @property\n    def d_output(self):\n        \"\"\"Model /embedding dimension, used for decoder mapping.\n\n        \"\"\"\n        if getattr(self, \"d_model\", None) is None:\n            raise NotImplementedError(\"SequenceModule instantiation must set d_output\")\n        return self.d_model\n\n\nclass DNAEmbeddingModelMamba(DNAEmbeddingModel):\n    \"\"\"Custom DNA Embedding Model that is compatible with open-source Mamba repo.\"\"\"\n\n    def __init__(\n            self,\n            config: MambaConfig,\n            initializer_cfg=None,\n            conjoin_train=False,\n            conjoin_test=False,\n            device=None,\n            dtype=None,\n    ):\n        super(DNAEmbeddingModel, self).__init__()  # nn.Module.__init__()\n        self.config = config\n        d_model = config.d_model\n        self.d_model = d_model  # for decoder\n        n_layer = config.n_layer\n        vocab_size = config.vocab_size\n        ssm_cfg = config.ssm_cfg\n        rms_norm = config.rms_norm\n        residual_in_fp32 = config.residual_in_fp32\n        fused_add_norm = config.fused_add_norm\n        pad_vocab_size_multiple = config.pad_vocab_size_multiple\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n        if vocab_size % pad_vocab_size_multiple != 0:\n            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n        self.backbone = MixerModel(\n            d_model=d_model,\n            n_layer=n_layer,\n            vocab_size=vocab_size,\n            ssm_cfg=ssm_cfg,\n            rms_norm=rms_norm,\n            initializer_cfg=initializer_cfg,\n            fused_add_norm=fused_add_norm,\n            residual_in_fp32=residual_in_fp32,\n            **factory_kwargs,\n        )\n        # Initialize weights and apply final processing\n        self.apply(\n            partial(\n                _init_weights_mamba,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n            )\n        )\n\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n\n    def forward(self, input_ids, position_ids=None, inference_params=None, state=None):  # state for the repo interface\n        \"\"\"Mamba backbone-specific forward pass that does not use `position_ids`.\"\"\"\n        hidden_states = self.backbone(input_ids, inference_params=inference_params)\n        # we only need the last hidden state for embeddings (decoder head will predict classification task)\n        return hidden_states, None\n\n\nclass DNAEmbeddingModelCaduceus(DNAEmbeddingModel):\n    \"\"\"Custom DNA Embedding Model that is compatible with Caduceus models.\"\"\"\n\n    def __init__(\n            self,\n            config: CaduceusConfig,\n            device=None,\n            dtype=None,\n            conjoin_train=False,\n            conjoin_test=False,\n    ):\n        super(DNAEmbeddingModel, self).__init__()  # nn.Module.__init__()\n        self.config = config\n        self.d_model = config.d_model  # for decoder\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.caduceus = Caduceus(\n            config=config,\n            **factory_kwargs,\n        )\n\n        self.conjoin_train = 
conjoin_train\n        self.conjoin_test = conjoin_test\n\n    def forward(self, input_ids, position_ids=None, inference_params=None, state=None):  # state for the repo interface\n        \"\"\"Caduceus backbone-specific forward pass that does not use `position_ids`.\"\"\"\n        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS\n            hidden_states = self.caduceus(input_ids, return_dict=False)\n            num_chan = hidden_states.shape[-1]\n            return torch.stack(\n                [hidden_states[..., :num_chan // 2], torch.flip(hidden_states[..., num_chan // 2:], dims=[1, 2])],\n                dim=-1\n            ), None\n        if self.conjoin_train or (self.conjoin_test and not self.training):  # For conjoining / post-hoc conjoining\n            assert input_ids.ndim == 3, \"Input must be 3D tensor, where channels corresponds to forward and rc strands\"\n            hidden_states = self.caduceus(input_ids[..., 0], return_dict=False)\n            hidden_states_rc = self.caduceus(input_ids[..., 1], return_dict=False)\n            # Stack along channel dimension (dim=-1)\n            return torch.stack([hidden_states, hidden_states_rc], dim=-1), None\n\n        return self.caduceus(input_ids, return_dict=False), None\n\n\ndef load_backbone(model, state_dict, freeze_backbone=False, ignore_head=True):\n    \"\"\"\n\n    Modifies state dict loading with custom function.  This is necessary because the head of\n    a lm outputs logits for vocab, but we just need the embeddings for downstream tasks.\n\n    inputs:\n        model: nn.Module, the from 'scratch' model\n        state_dict: dict, from the pretrained weights\n        ignore_head: bool, whether to inflate weights in the head (or keep scratch weights).\n            If number of classes changes, then you need to use this.\n\n    return:\n        state_dict: dict, update with inflated weights\n    \"\"\"\n\n    # consumes prefix from pretrained model, if necessary\n    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(\n        state_dict, \"model.\"\n    )\n\n    model_new_params_dict = model.state_dict()\n    updated_model_state_dict = {}\n\n    # loop through scratch model keys (pretrained may have extra stuff)\n    for key in sorted(model_new_params_dict.keys()):\n\n        loaded_params = state_dict.get(key, None)\n        if loaded_params is None:\n            # This should never happen, it should be there!\n            print(\"Missing key in pretrained model!\", key)\n            raise Exception\n\n        elif ignore_head and 'head' in key:\n            # ignore head weights\n            print(\"found head key / parameter, load from scratch\", key)\n            # using scratch by default, nothing needed\n            used_params = model_new_params_dict[key]\n\n        elif \"decoder\" in key:\n            print(\"found decoder key / parameter, load from scratch\", key)\n            used_params = model_new_params_dict[key]\n        else:\n            print('key: shape MATCH, loading', key)  # load matched weights\n            used_params = loaded_params\n\n        # we need to pass back a state dict with the '.model' prefix!!!!!\n        key_with_prefix = 'model.' 
+ key\n        updated_model_state_dict[key_with_prefix] = used_params\n\n    if freeze_backbone:\n        print(\"freezing model backbone params!\")\n        # note, decoder not included in backbone\n        for name, param in model.named_parameters():\n            param.requires_grad = False\n\n    # we have updated the new model state dict with pretrained now\n    return updated_model_state_dict\n"
  },
  {
    "path": "src/models/sequence/hyena.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\ntry:\n    from src.ops.fftconv import fftconv_ref, fftconv_func, fftconv_heads_ref\n\nexcept ImportError:\n    fftconv_func = None\n\ntry:\n    from flash_attn.ops.fused_dense import FusedDense\nexcept ImportError:\n    FusedDense = None\n\nimport src.utils.registry as registry\nfrom src.utils.train import OptimModule\nfrom src.utils.config import instantiate, auto_assign_attrs\nfrom src.models.nn import Activation\n\n\nclass FFTConvFuncv2(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, u, k):\n        seqlen = u.shape[-1]\n        if len(u.shape) > 3:\n            k = k.unsqueeze(1)\n        fft_size = 2 * seqlen\n\n        k_f = torch.fft.rfft(k, n=fft_size) / fft_size\n        u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)\n        y = torch.fft.irfft(u_f * k_f, n=fft_size, norm=\"forward\")[..., :seqlen]\n        ctx.save_for_backward(u_f, k_f)\n        return y\n\n    @staticmethod\n    def backward(ctx, dout):\n        u_f, k_f = ctx.saved_tensors\n        seqlen = dout.shape[-1]\n        fft_size = 2 * seqlen\n\n        dout_f = torch.fft.rfft(dout, n=fft_size)\n        du = torch.fft.irfft(dout_f * k_f.conj(), n=fft_size, norm=\"forward\")[\n            ..., :seqlen\n        ]\n        dk = torch.fft.irfft(dout_f * u_f.conj(), n=fft_size, norm=\"forward\")[\n            ..., :seqlen\n        ]\n        return du, dk.squeeze()\n\n\ndef fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):\n    seqlen = u.shape[-1]\n    fft_size = 2 * seqlen\n    k_f = torch.fft.rfft(k, n=fft_size) / fft_size\n    if k_rev is not None:\n        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size\n        k_f = k_f + k_rev_f.conj()\n    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)\n\n    if len(u.shape) > 3:\n        k_f = k_f.unsqueeze(1)\n\n    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm=\"forward\")[..., :seqlen]\n\n    out = y + u * D.unsqueeze(-1)\n    if gelu:\n        out = F.gelu(out)\n    if dropout_mask is not None:\n        return (out * rearrange(dropout_mask, \"b H -> b H 1\")).to(dtype=u.dtype)\n    else:\n        return out.to(dtype=u.dtype)\n\n\n@torch.jit.script\ndef mul_sum(q, y):\n    return (q * y).sum(dim=1)\n\n\nclass Sin(nn.Module):\n    def __init__(self, dim, w=10, train_freq=True):\n        super().__init__()\n        self.freq = (\n            nn.Parameter(w * torch.ones(1, dim))\n            if train_freq\n            else w * torch.ones(1, dim)\n        )\n\n    def forward(self, x):\n        return torch.sin(self.freq * x)\n\n\nclass PositionalEmbedding(OptimModule):\n    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):\n        \"\"\"Complex exponential positional embeddings for Hyena filters.\"\"\"\n        super().__init__()\n\n        self.seq_len = seq_len\n        # The time embedding fed to the filteres is normalized so that t_f = 1\n        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1\n\n        if emb_dim > 1:\n            bands = (emb_dim - 1) // 2\n        # To compute the right embeddings we use the \"proper\" linspace\n        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]\n        w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1\n\n        f = torch.linspace(1e-4, bands - 1, bands)[None, None]\n        z = torch.exp(-1j * f * w)\n        z = torch.cat([t, z.real, z.imag], dim=-1)\n        
self.register(\"z\", z, lr=lr_pos_emb)\n        self.register(\"t\", t, lr=0.0)\n\n    def forward(self, L):\n        return self.z[:, :L], self.t[:, :L]\n\n\nclass ExponentialModulation(OptimModule):\n    def __init__(\n        self,\n        d_model,\n        fast_decay_pct=0.3,\n        slow_decay_pct=1.5,\n        target=1e-2,\n        modulation_lr=0.0,\n        shift: float = 0.0,\n        **kwargs,\n    ):\n        super().__init__()\n        self.shift = shift\n        max_decay = math.log(target) / fast_decay_pct\n        min_decay = math.log(target) / slow_decay_pct\n        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]\n        self.register(\"deltas\", deltas, lr=modulation_lr)\n\n    def forward(self, t, x):\n        decay = torch.exp(-t * self.deltas.abs())\n        x = x * (decay + self.shift)\n        return x\n\n\nclass HyenaFilter(OptimModule):\n    def __init__(\n        self,\n        d_model,\n        emb_dim=3,  # dim of input to MLP, augments with positional encoding\n        order=16,  # width of the implicit MLP\n        fused_fft_conv=False,\n        seq_len=1024,\n        lr=1e-3,\n        lr_pos_emb=1e-5,\n        dropout=0.0,\n        w=1,  # frequency of periodic activations\n        wd=0,  # weight decay of kernel parameters\n        bias=True,\n        num_inner_mlps=2,\n        linear_mixer=False,\n        modulate: bool = True,\n        normalized=False,\n        **kwargs,\n    ):\n        \"\"\"\n        Implicit long filter with modulation.\n        Args:\n            d_model: number of channels in the input\n            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands\n            order: width of the FFN\n            num_inner_mlps: number of inner linear layers inside filter MLP\n        Note:\n            filter_dropout is not implemented\n        \"\"\"\n        super().__init__()\n        auto_assign_attrs(\n            self, d_model=d_model, emb_dim=emb_dim, seq_len=seq_len, modulate=modulate\n        )\n        self.use_bias = bias\n        self.fused_fft_conv = fused_fft_conv\n        self.bias = nn.Parameter(torch.randn(self.d_model))\n        self.dropout = nn.Dropout(dropout)\n\n        act = Sin(dim=order, w=w)\n        assert (\n            emb_dim % 2 != 0 and emb_dim >= 3\n        ), \"emb_dim must be odd and greater or equal to 3 (time, sine and cosine)\"\n        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)\n\n        # uses a variable number of inner linear layers\n        if linear_mixer is False:\n            self.implicit_filter = nn.Sequential(\n                nn.Linear(emb_dim, order),\n                act,\n            )\n            for i in range(num_inner_mlps):\n                self.implicit_filter.append(nn.Linear(order, order))\n                self.implicit_filter.append(act)\n            # final linear layer\n            self.implicit_filter.append(nn.Linear(order, d_model, bias=False))\n        else:\n            self.implicit_filter = nn.Sequential(\n                nn.Linear(emb_dim, d_model, bias=False),\n            )\n\n        self.modulation = ExponentialModulation(d_model, **kwargs)\n\n        self.normalized = normalized\n        for c in self.implicit_filter.children():\n            for name, v in c.state_dict().items():\n                optim = {\"weight_decay\": wd, \"lr\": lr}\n                setattr(getattr(c, name), \"_optim\", optim)\n\n    def filter(self, L, *args, **kwargs):\n        z, t = self.pos_emb(L)\n        h = 
self.implicit_filter(z)\n        if self.modulate:\n            h = self.modulation(t, h)\n\n        if self.normalized:\n            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)\n\n        return h\n\n    def forward(self, x, L, k=None, bias=None, *args, **kwargs):\n        if k is None:\n            k = self.filter(L)\n\n        # Ensure compatibility with filters that return a tuple\n        k = k[0] if type(k) is tuple else k\n        if bias is None:\n            bias = self.bias\n        bias = bias if self.use_bias else 0 * bias\n\n        if self.fused_fft_conv:\n            bias = bias.to(dtype=torch.float32)\n            y = fftconv_func(\n                x,\n                k,\n                bias,\n                dropout_mask=None,\n                gelu=False,\n                force_fp16_output=torch.is_autocast_enabled(),\n            )\n        else:\n            y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False)\n            # y = (\n            #     FFTConvFuncv2.apply(x, k.to(dtype=torch.float32))\n            #     + bias.unsqueeze(-1) * x\n            # )\n\n        return y.to(dtype=x.dtype)\n\n\nclass HyenaOperator(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        l_max,\n        order=2,\n        filter_order=64,\n        num_heads=1,\n        inner_factor=1,\n        num_blocks=1,\n        fused_bias_fc=False,\n        outer_mixing=False,\n        dropout=0.0,\n        filter_dropout=0.0,\n        filter_cls=\"hyena-filter\",\n        post_order_ffn=False,\n        jit_filter=False,\n        short_filter_order=3,\n        activation=\"id\",\n        return_state=False,\n        **filter_args,\n    ):\n        r\"\"\"\n        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf\n        Args:\n            d_model (int): Dimension of the input and output embeddings (width of the layer)\n            l_max: (int): Maximum input sequence length. Defaults to None\n            order: (int): Depth of the Hyena recurrence. Defaults to 2\n            filter_order: (int): Width of the FFN parametrizing the implicit filter. Defaults to 64\n            num_heads: (int): Number of heads. Defaults to 1\n            inner_factor: (int): Width multiplier. Defaults to 1\n            num_blocks: (int): Number of blocks in sequence length. Defaults to 1\n            fused_bias_fc: (bool): Whether to use fused bias FC. Defaults to False\n            dropout: (float): Dropout probability. Defaults to 0.0\n            filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0\n            post_order_ffn: (bool): Apply a dense layer between steps of the recurrence. Defaults to False\n            jit_filter: (bool): Whether JIT the implicit filter function. Defaults to False\n            short_filter_order: (int): Length of the explicit input convolutional filter. 
Defaults to 3\n            activation: (str): type of act between kernel output and FF (default identity)\n            return_state: (bool): whether to return a state\n        \"\"\"\n        super().__init__()\n        assert (\n            d_model % num_heads == 0\n        ), f\"Model dimension {d_model} must be divisible by num heads {num_heads}\"\n        assert (\n            l_max % num_blocks == 0\n        ), f\"Maximum signal length {l_max} must be divisible by block dimension {num_blocks}\"\n        block_dim = l_max // num_blocks\n        head_dim = d_model // num_heads\n\n        auto_assign_attrs(\n            self,\n            d_model=d_model,\n            order=order,\n            l_max=l_max,\n            num_heads=num_heads,\n            inner_factor=inner_factor,\n            block_dim=block_dim,\n            head_dim=head_dim,\n            filter_order=filter_order,\n            post_order_ffn=post_order_ffn,\n            short_filter_order=short_filter_order,\n            num_blocks=num_blocks,\n            filter_dropout=filter_dropout,\n            jit_filter=jit_filter,\n            outer_mixing=outer_mixing,\n            activation=activation,\n            return_state=return_state,\n        )\n        self.activation = Activation(activation)\n        self.dropout = nn.Dropout(dropout)\n        self.setup_projections(fused_bias_fc, inner_factor)\n        self.setup_filters(filter_cls, filter_args)\n\n    def setup_projections(self, fused_bias_fc, inner_factor):\n        \"Initializes input and output projections (over the width dimension)\"\n        if fused_bias_fc and FusedDense is None:\n            raise ImportError(\"fused_dense is not installed\")\n        linear_cls = nn.Linear if not fused_bias_fc else FusedDense\n        self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model)\n        self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model)\n        if self.post_order_ffn:\n            self.ord_proj_w = nn.Parameter(\n                torch.randn(self.order, self.num_heads, self.num_heads)\n                / math.sqrt(self.head_dim)\n            )\n\n    def setup_filters(self, filter_cls, filter_args):\n        \"Initializes the explicit and implicit filters\"\n        assert self.order >= 2, f\"Order must be at least 2, (got {self.order})\"\n        total_width = self.d_model * self.inner_factor * (self.order + 1)\n\n        self.short_filter = nn.Conv1d(\n            in_channels=total_width,\n            out_channels=total_width,\n            kernel_size=self.short_filter_order,\n            groups=total_width,\n            padding=self.short_filter_order - 1,\n        )\n\n        filter_cls = instantiate(registry.layer, filter_cls, partial=True)\n\n        self.filter_fn = filter_cls(\n            self.head_dim * self.inner_factor * (self.order - 1),\n            order=self.filter_order,\n            seq_len=self.l_max,\n            channels=1,\n            dropout=self.filter_dropout,\n            **filter_args,\n        )\n        if self.jit_filter:\n            self.filter_fn = torch.jit.script(self.filter_fn, self.L)\n\n    def recurrence(self, u, state):\n        \"Fast inference mode via distilled recurrence\"\n        raise NotImplementedError(\"Working on it!\")\n\n    def forward(self, u, *args, **kwargs):\n        l = u.size(-2)\n        l_filter = min(l, self.l_max)\n        u = self.in_proj(u)\n        u = rearrange(u, \"b l d -> b d l\")\n\n        uc = self.short_filter(u)[..., :l_filter]\n\n        uc = 
rearrange(\n            uc,\n            \"b (ho v) (z l) -> b ho v z l\",\n            z=self.num_blocks,\n            ho=self.num_heads,\n            v=self.head_dim * (self.order + 1),\n        )\n\n        *x, v = uc.split(self.d_model, dim=2)\n        k = self.filter_fn.filter(l_filter)\n\n        # `c` is always 1 by default\n        k = rearrange(k, \"c l (v o) -> c o v l\", v=self.head_dim, o=self.order - 1)[0]\n\n        bias = rearrange(\n            self.filter_fn.bias, \"(v o) -> o v\", v=self.head_dim, o=self.order - 1\n        )\n\n        for o, x_i in enumerate(reversed(x[1:])):\n            if self.outer_mixing:\n                v = rearrange(v, \"b h v z l -> b h 1 v z l\")\n                v = self.dropout(v * rearrange(x_i, \"b h v z l -> b h v 1 z l\"))\n                v = v.sum(dim=2)\n            else:\n                v = self.dropout(v * x_i)\n\n            # the bias term is broadcasted. Last dimension (l) is handled by fftconv\n            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])\n\n            if self.post_order_ffn:\n                w = self.ord_proj_w[o]\n                v = mul_sum(\n                    rearrange(w, \"h1 h2 -> 1 h1 h2 1 1 1\"),\n                    rearrange(v, \"b h v z l -> b h 1 v z l\"),\n                )\n\n        y = self.activation(\n            rearrange(\n                v * x[0],\n                \"b h v z l -> b (z l) (h v)\",\n                z=self.num_blocks,\n                h=self.num_heads,\n            )\n        )\n        y = self.out_proj(y)\n\n        if self.return_state:\n            return y, None\n        return y\n\n    @property\n    def d_output(self):\n        return self.d_model\n\n"
  },
  {
    "path": "src/models/sequence/long_conv_lm.py",
    "content": "import copy\nimport math\nimport re\nfrom collections import namedtuple\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.modules.block import Block\nfrom flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings\nfrom flash_attn.modules.mha import MHA, ParallelMHA\nfrom flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP\nfrom flash_attn.utils.distributed import sync_shared_params, all_gather_raw\nfrom flash_attn.utils.generation import GenerationMixin\nfrom torch.utils.checkpoint import checkpoint\n\ntry:\n    from flash_attn.ops.fused_dense import ColumnParallelLinear\nexcept ImportError:\n    ColumnParallelLinear = None\n\ntry:\n    from flash_attn.ops.layer_norm import dropout_add_layer_norm\nexcept ImportError:\n    dropout_add_layer_norm = None\n\nfrom src.utils import instantiate\nimport src.utils.registry as registry\n\n\nclass CheckpointedModule(torch.nn.Module):\n    def __init__(self, layer):\n        super().__init__()\n        self.layer = layer\n\n    def forward(self, x):\n        return checkpoint(self.layer, x)\n\n\ndef create_mixer_cls(\n    layer=None,\n    process_group=None,\n    attn_layer_idx=None,\n    attn_cfg=None,\n    layer_idx=None,\n    sequence_parallel=True,\n    device=None,\n    dtype=None,\n):\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    parallel_kwargs = (\n        {\"process_group\": process_group, \"sequence_parallel\": sequence_parallel}\n        if process_group is not None\n        else {}\n    )\n    if attn_layer_idx is not None and layer_idx in attn_layer_idx:\n        causal = True if attn_cfg is None else attn_cfg.pop(\"causal\", True)\n        fused_bias_fc = (\n            False if attn_cfg is None else attn_cfg.get(\"fused_bias_fc\", False)\n        )\n        if not fused_bias_fc:\n            assert process_group is None, \"TensorParallel MHA requires fused_bias_fc\"\n        mha_cls = MHA if process_group is None else ParallelMHA\n        # ParallelMHA doesn't take 'fused_bias_fc', it is assumed that we fuse matmul + bias\n        if process_group is not None:\n            attn_cfg = copy.deepcopy(attn_cfg)  # Don't modify the original cfg\n            attn_cfg.pop(\"fused_bias_fc\", None)\n        mixer_cls = partial(\n            mha_cls,\n            causal=causal,\n            layer_idx=layer_idx,\n            **(attn_cfg if attn_cfg is not None else {}),\n            **parallel_kwargs,\n            **factory_kwargs,\n        )\n    else:\n        fused_bias_fc = False if layer is None else layer.get(\"fused_bias_fc\", False)\n        if process_group is not None:\n            assert fused_bias_fc, \"TensorParallel SSM requires fused_bias_fc\"\n        mixer_cls = instantiate(\n            registry.layer,\n            layer,\n            partial=True,\n            layer_idx=layer_idx,\n            **factory_kwargs,\n            **parallel_kwargs,\n        )\n    return mixer_cls\n\n\ndef create_mlp_cls(\n    d_model,\n    d_inner=None,\n    process_group=None,\n    fused_mlp=False,\n    sequence_parallel=True,\n    identity_mlp=False,\n    device=None,\n    dtype=None,\n):\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    inner_dim = d_inner if d_inner is not None else 4 * d_model\n    if process_group is not None:\n        assert fused_mlp, \"Tensor Parallel is only implemented for FusedMLP\"\n\n    if not fused_mlp and not identity_mlp:\n        mlp_cls = 
partial(\n            Mlp,\n            hidden_features=inner_dim,\n            activation=partial(F.gelu, approximate=\"tanh\"),\n            **factory_kwargs,\n        )\n    elif fused_mlp:\n        mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP\n        parallel_kwargs = (\n            {\"process_group\": process_group, \"sequence_parallel\": sequence_parallel}\n            if process_group is not None\n            else {}\n        )\n        mlp_cls = partial(\n            mlp_cls, hidden_features=inner_dim, **parallel_kwargs, **factory_kwargs\n        )\n    else:\n        mlp_cls = nn.Identity\n    return mlp_cls\n\n\ndef create_block(\n    d_model,\n    d_inner=None,\n    process_group=None,\n    layer=None,\n    attn_layer_idx=None,\n    attn_cfg=None,\n    layer_norm_epsilon=1e-5,\n    resid_dropout1=0.0,\n    resid_dropout2=0.0,\n    residual_in_fp32=False,\n    fused_mlp=False,\n    identity_mlp=False,\n    fused_dropout_add_ln=False,\n    layer_idx=None,\n    sequence_parallel=True,\n    checkpoint_mlp=False,\n    checkpoint_mixer=False,\n    device=None,\n    dtype=None,\n):\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    mixer_cls = create_mixer_cls(\n        layer=layer,\n        process_group=process_group,\n        attn_layer_idx=attn_layer_idx,\n        attn_cfg=attn_cfg,\n        layer_idx=layer_idx,\n        sequence_parallel=sequence_parallel,\n        **factory_kwargs,\n    )\n    mlp_cls = create_mlp_cls(\n        d_model,\n        d_inner=d_inner,\n        process_group=process_group,\n        fused_mlp=fused_mlp,\n        identity_mlp=identity_mlp,\n        sequence_parallel=sequence_parallel,\n        **factory_kwargs,\n    )\n    norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs)\n    block = Block(\n        d_model,\n        mixer_cls,\n        mlp_cls,\n        norm_cls=norm_cls,\n        prenorm=True,\n        resid_dropout1=resid_dropout1,\n        resid_dropout2=resid_dropout2,\n        fused_dropout_add_ln=fused_dropout_add_ln,\n        residual_in_fp32=residual_in_fp32,\n        sequence_parallel=sequence_parallel and process_group is not None,\n        mark_shared_params=process_group is not None,\n    )\n\n    block.layer_idx = layer_idx\n\n    if checkpoint_mlp:\n        block.mlp = CheckpointedModule(block.mlp)\n    if checkpoint_mixer:\n        block.mixer = CheckpointedModule(block.mixer)\n    return block\n\n\n# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454\ndef _init_weights(\n    module,\n    n_layer,\n    initializer_range=0.02,\n    rescale_prenorm_residual=True,\n    glu_act=False,\n):\n    if isinstance(module, nn.Linear):\n        nn.init.normal_(module.weight, std=initializer_range)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Embedding):\n        nn.init.normal_(module.weight, std=initializer_range)\n\n    if rescale_prenorm_residual:\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if name in [\"out_proj.weight\", \"fc2.weight\"]:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                nn.init.normal_(\n                    p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)\n                )\n            # If using GLU activation for now, we scale the std by 2\n            elif name in [\"output_linear.0.weight\"]:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                if not glu_act:\n                    nn.init.normal_(\n                        p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)\n                    )\n                else:\n                    out_features = p.shape[0]\n                    # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5\n                    # on average.\n                    nn.init.normal_(\n                        p[: out_features // 2],\n                        mean=0.0,\n                        std=initializer_range / math.sqrt(2 * n_layer) * 2,\n                    )\n\n\nclass LMBackbone(nn.Module):\n    def __init__(\n        self,\n        d_model: int,\n        n_layer: int,\n        d_inner: int,\n        vocab_size: int,\n        process_group=None,\n        layer=None,\n        attn_layer_idx=None,\n        attn_cfg=None,\n        max_position_embeddings=0,\n        resid_dropout: float = 0.0,\n        embed_dropout: float = 0.1,\n        dropout_cls=nn.Dropout,\n        layer_norm_epsilon: float = 1e-5,\n        initializer_cfg=None,\n        fused_mlp=False,\n        identity_mlp=False,\n        fused_dropout_add_ln=False,\n        residual_in_fp32=False,\n        sequence_parallel=True,\n        checkpoint_mlp=False,\n        checkpoint_mixer=False,\n        device=None,\n        dtype=None,\n        **kwargs,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n        self.residual_in_fp32 = residual_in_fp32\n\n        if process_group is None:\n            self.embeddings = GPT2Embeddings(\n                d_model, vocab_size, max_position_embeddings, **factory_kwargs\n            )\n        else:\n            self.embeddings = ParallelGPT2Embeddings(\n                d_model,\n                vocab_size,\n                max_position_embeddings,\n                process_group=process_group,\n                sequence_parallel=self.sequence_parallel,\n                **factory_kwargs,\n            )\n\n        # We change the order of dropout, residual and layer norm:\n        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:\n        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and\n        # the main branch (output of MLP). 
The model definition is unchanged, but the mapping of the\n        # nn.Dropout probabilities are changed.\n        # This is for performance reason: we can fuse dropout + add + layer_norm.\n        self.fused_dropout_add_ln = fused_dropout_add_ln\n        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:\n            raise ImportError(\"dropout_add_layer_norm is not installed\")\n\n        self.layers = nn.ModuleList(\n            [\n                create_block(\n                    d_model,\n                    d_inner=d_inner,\n                    process_group=process_group,\n                    layer=layer,\n                    attn_layer_idx=attn_layer_idx,\n                    attn_cfg=attn_cfg,\n                    layer_norm_epsilon=layer_norm_epsilon,\n                    resid_dropout1=embed_dropout if i == 0 else resid_dropout,\n                    resid_dropout2=resid_dropout,\n                    residual_in_fp32=residual_in_fp32,\n                    fused_mlp=fused_mlp,\n                    identity_mlp=identity_mlp,\n                    fused_dropout_add_ln=fused_dropout_add_ln,\n                    layer_idx=i,\n                    sequence_parallel=self.sequence_parallel,\n                    checkpoint_mlp=checkpoint_mlp,\n                    checkpoint_mixer=checkpoint_mixer,\n                    **factory_kwargs,\n                )\n                for i in range(n_layer)\n            ]\n        )\n\n        self.drop_f = nn.Dropout(resid_dropout)\n        self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)\n\n        if process_group is not None:\n            for p in self.ln_f.parameters():\n                # Mark the norm parameters as \"shared_params\" so that we sync their values at init.\n                p._shared_params = True\n                # Mark the norm params as \"sequence_parallel\" so we run all-reduce on their grads.\n                if self.sequence_parallel:\n                    p._sequence_parallel = True\n\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n            )\n        )\n        self.tie_weights()\n\n    def tie_weights(self):\n        if self.process_group is not None:\n            sync_shared_params(self, self.process_group)\n\n    def forward(self, input_ids, position_ids=None, inference_params=None):\n        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen\n        # dimensions so that we can split on it easily, in case of small batch size.\n        # Only the attention/SSM layers need to know the seqlen.\n        embedding_kwargs = (\n            {\"combine_batch_seqlen_dim\": True}\n            if self.process_group is not None and self.sequence_parallel\n            else {}\n        )\n        hidden_states = self.embeddings(\n            input_ids, position_ids=position_ids, **embedding_kwargs\n        )\n        residual = None\n        mixer_kwargs = (\n            {\"seqlen\": input_ids.shape[1]}\n            if self.process_group is not None and self.sequence_parallel\n            else {}\n        )\n        if inference_params is not None:\n            mixer_kwargs[\"inference_params\"] = inference_params\n        for layer in self.layers:\n            hidden_states, residual = layer(\n                hidden_states, residual, mixer_kwargs=mixer_kwargs\n            )\n        if not self.fused_dropout_add_ln:\n    
        dropped = self.drop_f(hidden_states)\n            residual = (dropped + residual) if residual is not None else dropped\n            hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))\n        else:\n            # Set prenorm=False here since we don't need the residual\n            hidden_states = dropout_add_layer_norm(\n                hidden_states,\n                residual,\n                self.ln_f.weight,\n                self.ln_f.bias,\n                self.drop_f.p if self.training else 0.0,\n                self.ln_f.eps,\n                prenorm=False,\n                residual_in_fp32=self.residual_in_fp32,\n            )\n        return hidden_states\n\n\nclass ConvLMHeadModel(nn.Module, GenerationMixin):\n    def __init__(\n        self,\n        d_model: int,\n        n_layer: int,\n        d_inner: int,\n        vocab_size: int,\n        process_group=None,\n        layer=None,\n        attn_layer_idx=None,\n        attn_cfg=None,\n        max_position_embeddings=0,\n        resid_dropout: float = 0.0,\n        embed_dropout: float = 0.1,\n        dropout_cls=nn.Dropout,\n        layer_norm_epsilon: float = 1e-5,\n        initializer_cfg=None,\n        fused_mlp=False,\n        fused_dropout_add_ln=False,\n        residual_in_fp32=False,\n        pad_vocab_size_multiple: int = 1,\n        sequence_parallel=True,\n        checkpoint_mlp=False,\n        checkpoint_mixer=False,\n        device=None,\n        dtype=None,\n        **kwargs,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.process_group = process_group\n        if vocab_size % pad_vocab_size_multiple != 0:\n            vocab_size += pad_vocab_size_multiple - (\n                vocab_size % pad_vocab_size_multiple\n            )\n        self.backbone = LMBackbone(\n            d_model=d_model,\n            n_layer=n_layer,\n            d_inner=d_inner,\n            vocab_size=vocab_size,\n            process_group=process_group,\n            layer=layer,\n            attn_layer_idx=attn_layer_idx,\n            attn_cfg=attn_cfg,\n            max_position_embeddings=max_position_embeddings,\n            resid_dropout=resid_dropout,\n            embed_dropout=embed_dropout,\n            dropout_cls=dropout_cls,\n            layer_norm_epsilon=layer_norm_epsilon,\n            initializer_cfg=initializer_cfg,\n            fused_mlp=fused_mlp,\n            fused_dropout_add_ln=fused_dropout_add_ln,\n            residual_in_fp32=residual_in_fp32,\n            sequence_parallel=sequence_parallel,\n            checkpoint_mlp=checkpoint_mlp,\n            checkpoint_mixer=checkpoint_mixer,\n            **factory_kwargs,\n            **kwargs,\n        )\n        if process_group is None:\n            self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n        else:\n            if ColumnParallelLinear is None:\n                raise ImportError(\"fused_dense_lib is not installed\")\n            self.lm_head = ColumnParallelLinear(\n                d_model,\n                vocab_size,\n                process_group,\n                bias=False,\n                sequence_parallel=sequence_parallel,\n                **factory_kwargs,\n            )\n        # Initialize weights and apply final processing\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n            
)\n        )\n        self.tie_weights()\n\n    def tie_weights(self):\n        self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight\n        if self.process_group is not None:\n            sync_shared_params(self, self.process_group)\n\n    def forward(\n        self, input_ids, position_ids=None, inference_params=None, state=None\n    ):  # state for the repo interface\n        hidden_states = self.backbone(\n            input_ids, position_ids=position_ids, inference_params=inference_params\n        )\n        lm_logits = self.lm_head(hidden_states)\n        # During inference, we want the full logit for sampling\n        if ColumnParallelLinear is not None and inference_params is not None:\n            if isinstance(self.lm_head, ColumnParallelLinear):\n                lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)\n                lm_logits = rearrange(\n                    lm_logits, \"(n b) s d -> b s (n d)\", b=hidden_states.shape[0]\n                )\n        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n        return CausalLMOutput(logits=lm_logits), None\n"
  },
  {
    "path": "src/ops/fftconv.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\n\nfrom fftconv import fftconv_fwd, fftconv_bwd\n\n@torch.jit.script\ndef _mul_sum(y, q):\n    return (y * q).sum(dim=1)\n\n# reference convolution with residual connection\ndef fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):\n    seqlen = u.shape[-1]\n    fft_size = 2 * seqlen\n    k_f = torch.fft.rfft(k, n=fft_size) / fft_size\n    if k_rev is not None:\n        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size\n        k_f = k_f + k_rev_f.conj()\n    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)\n    \n    if len(u.shape) > 3: k_f = k_f.unsqueeze(1)\n\n    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]\n\n    out = y + u * D.unsqueeze(-1)\n    if gelu:\n        out = F.gelu(out)\n    if dropout_mask is not None:\n        return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype)\n    else:\n        return out.to(dtype=u.dtype)\n\n    \n# reference H3 forward pass\ndef fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None):\n    seqlen = k.shape[-1]\n    fft_size = 2 * seqlen\n    kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=head_dim)\n            * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim))  # b d1 d2 h l\n    kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size\n    ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size)  # h L+1\n    if ssm_kernel_rev is not None:\n        ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size)  # h L+1\n        ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj()\n    y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen]  # b d1 d2 h l\n    out = y + kv * D.unsqueeze(-1)  # b d1 d2 h l\n    q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim)\n    if head_dim > 1:\n        out = _mul_sum(out, q)\n        return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype)\n    else:\n        return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype)\n\n\nclass FFTConvFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False,\n                output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None):\n        seqlen = u.shape[-1]\n        fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)\n        k_f = torch.fft.rfft(k, n=fft_size)\n        if k_rev is not None:\n            k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj()\n        if u.stride(-1) != 1:\n            u = u.contiguous()\n        k_f = k_f.contiguous()\n        D = D.contiguous()\n        if v is not None and v.stride(-1) != 1:\n            v = v.contiguous()\n        if q is not None and q.stride(-1) != 1:\n            q = q.contiguous()\n        if dropout_mask is not None:\n            dropout_mask = dropout_mask.contiguous()\n        ctx.save_for_backward(u, k_f, D, dropout_mask, v, q)\n        ctx.output_hbl_layout = output_hbl_layout\n        ctx.head_dim = head_dim\n        ctx.gelu = gelu\n        ctx.fftfp16 = fftfp16\n        ctx.has_k_rev = k_rev is not None\n        out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        if ctx.output_hbl_layout:\n            dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b 
h l')\n        else:\n            dout = dout.contiguous()\n        u, k_f, D, dropout_mask, v, q = ctx.saved_tensors\n        seqlen = u.shape[-1]\n        fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)\n        du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size,\n                                   ctx.output_hbl_layout, ctx.fftfp16)\n        dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen]\n        dk_rev = (None if not ctx.has_k_rev\n                  else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen])\n        if v is not None:\n            dv = dv.to(dtype=v.dtype)  # We do atomicAdd in fp32 so might need to convert to fp16\n        return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev\n\ndef fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False,\n                 output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None):\n    return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output,\n                             output_hbl_layout, v, head_dim, q, fftfp16, k_rev)\n"
  },
  {
    "path": "src/tasks/decoders.py",
    "content": "\"\"\"Decoder heads.\n\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport src.models.nn.utils as U\nimport src.utils as utils\nimport src.utils.train\n\nlog = src.utils.train.get_logger(__name__)\n\n\nclass Decoder(nn.Module):\n    \"\"\"This class doesn't do much but just signals the interface that Decoders are expected to adhere to\n    TODO: is there a way to enforce the signature of the forward method?\n    \"\"\"\n\n    def forward(self, x, **kwargs):\n        \"\"\"\n        x: (batch, length, dim) input tensor\n        state: additional state from the model backbone\n        *args, **kwargs: additional info from the dataset\n\n        Returns:\n        y: output tensor\n        *args: other arguments to pass into the loss function\n        \"\"\"\n        return x\n\n    def step(self, x):\n        \"\"\"\n        x: (batch, dim)\n        \"\"\"\n        return self.forward(x.unsqueeze(1)).squeeze(1)\n\n\nclass SequenceDecoder(Decoder):\n    def __init__(\n        self, d_model, d_output=None, l_output=None, use_lengths=False, mode=\"last\",\n            conjoin_train=False, conjoin_test=False\n    ):\n        super().__init__()\n\n        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)\n\n        if l_output is None:\n            self.l_output = None\n            self.squeeze = False\n        elif l_output == 0:\n            # Equivalent to getting an output of length 1 and then squeezing\n            self.l_output = 1\n            self.squeeze = True\n        else:\n            assert l_output > 0\n            self.l_output = l_output\n            self.squeeze = False\n\n        self.use_lengths = use_lengths\n        self.mode = mode\n\n        if mode == 'ragged':\n            assert not use_lengths\n\n        self.conjoin_train = conjoin_train\n        self.conjoin_test = conjoin_test\n\n    def forward(self, x, state=None, lengths=None, l_output=None):\n        \"\"\"\n        x: (n_batch, l_seq, d_model) or potentially (n_batch, l_seq, d_model, 2) if using rc_conjoin\n        Returns: (n_batch, l_output, d_output)\n        \"\"\"\n        if self.l_output is None:\n            if l_output is not None:\n                assert isinstance(l_output, int)  # Override by pass in\n            else:\n                # Grab entire output\n                l_output = x.size(1)\n            squeeze = False\n        else:\n            l_output = self.l_output\n            squeeze = self.squeeze\n\n        if self.mode == \"last\":\n            def restrict(x_seq):\n                \"\"\"Use last l_output elements of sequence.\"\"\"\n                return x_seq[..., -l_output:, :]\n\n        elif self.mode == \"first\":\n            def restrict(x_seq):\n                \"\"\"Use first l_output elements of sequence.\"\"\"\n                return x_seq[..., :l_output, :]\n\n        elif self.mode == \"pool\":\n            def restrict(x_seq):\n                \"\"\"Pool sequence over a certain range\"\"\"\n                L = x_seq.size(1)\n                s = x_seq.sum(dim=1, keepdim=True)\n                if l_output > 1:\n                    c = torch.cumsum(x_seq[..., -(l_output - 1):, ...].flip(1), dim=1)\n                    c = F.pad(c, (0, 0, 1, 0))\n                    s = s - c  # (B, l_output, D)\n                    s = s.flip(1)\n                denom = torch.arange(\n                    L - l_output + 1, L + 1, dtype=x_seq.dtype, device=x_seq.device\n                )\n          
      s = s / denom\n                return s\n\n        elif self.mode == \"sum\":\n            # TODO use same restrict function as pool case\n            def restrict(x_seq):\n                \"\"\"Cumulative sum last l_output elements of sequence.\"\"\"\n                return torch.cumsum(x_seq, dim=-2)[..., -l_output:, :]\n        elif self.mode == 'ragged':\n            assert lengths is not None, \"lengths must be provided for ragged mode\"\n\n            def restrict(x_seq):\n                \"\"\"Ragged aggregation.\"\"\"\n                # remove any additional padding (beyond max length of any sequence in the batch)\n                return x_seq[..., : max(lengths), :]\n        else:\n            raise NotImplementedError(\n                \"Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']\"\n            )\n\n        # Restrict to actual length of sequence\n        if self.use_lengths:\n            assert lengths is not None\n            x = torch.stack(\n                [\n                    restrict(out[..., :length, :])\n                    for out, length in zip(torch.unbind(x, dim=0), lengths)\n                ],\n                dim=0,\n            )\n        else:\n            x = restrict(x)\n\n        if squeeze:\n            assert x.size(1) == 1\n            x = x.squeeze(1)\n\n        if self.conjoin_train or (self.conjoin_test and not self.training):\n            x, x_rc = x.chunk(2, dim=-1)\n            x = self.output_transform(x.squeeze())\n            x_rc = self.output_transform(x_rc.squeeze())\n            x = (x + x_rc) / 2\n        else:\n            x = self.output_transform(x)\n\n        return x\n\n    def step(self, x, state=None):\n        # Ignore all length logic\n        x_fwd = self.output_transform(x.mean(dim=1))\n        x_rc = self.output_transform(x.flip(dims=[1, 2]).mean(dim=1)).flip(dims=[1])\n        x_out = (x_fwd + x_rc) / 2\n        return x_out\n\n\n# For every type of encoder/decoder, specify:\n# - constructor class\n# - list of attributes to grab from dataset\n# - list of attributes to grab from model\n\nregistry = {\n    \"stop\": Decoder,\n    \"id\": nn.Identity,\n    \"linear\": nn.Linear,\n    \"sequence\": SequenceDecoder,\n}\n\nmodel_attrs = {\n    \"linear\": [\"d_output\"],\n    \"sequence\": [\"d_output\"],\n    \"nd\": [\"d_output\"],\n    \"retrieval\": [\"d_output\"],\n    \"state\": [\"d_state\", \"state_to_tensor\"],\n    \"forecast\": [\"d_output\"],\n    \"token\": [\"d_output\"],\n}\n\ndataset_attrs = {\n    \"linear\": [\"d_output\"],\n    \"sequence\": [\"d_output\", \"l_output\"],\n    \"nd\": [\"d_output\"],\n    \"retrieval\": [\"d_output\"],\n    \"state\": [\"d_output\"],\n    \"forecast\": [\"d_output\", \"l_output\"],\n    \"token\": [\"d_output\"],\n}\n\n\ndef _instantiate(decoder, model=None, dataset=None):\n    \"\"\"Instantiate a single decoder\"\"\"\n    if decoder is None:\n        return None\n\n    if isinstance(decoder, str):\n        name = decoder\n    else:\n        name = decoder[\"_name_\"]\n\n    # Extract arguments from attribute names\n    dataset_args = utils.config.extract_attrs_from_obj(\n        dataset, *dataset_attrs.get(name, [])\n    )\n    model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))\n    # Instantiate decoder\n    obj = utils.instantiate(registry, decoder, *model_args, *dataset_args)\n    return obj\n\n\ndef instantiate(decoder, model=None, dataset=None):\n    \"\"\"Instantiate a full decoder config, e.g. 
handle list of configs\n    Note that arguments are added in reverse order compared to encoder (model first, then dataset)\n    \"\"\"\n    decoder = utils.to_list(decoder)\n    return U.PassthroughSequential(\n        *[_instantiate(d, model=model, dataset=dataset) for d in decoder]\n    )\n"
  },
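  {
    "path": "examples/sequence_decoder_pool_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows what SequenceDecoder(mode=\"pool\", l_output=0) computes: a mean over the length\ndimension followed by the linear output head, which is how the repo's downstream tasks\ntypically reduce a (batch, length, d_model) backbone output to (batch, d_output).\nAssumes the repo's dependencies (torch, etc.) are installed and the repo root is on PYTHONPATH.\n\"\"\"\n\nimport torch\n\nfrom src.tasks.decoders import SequenceDecoder\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    batch, length, d_model, d_output = 2, 8, 4, 3\n\n    # l_output=0 means \"reduce to a single position and squeeze it away\".\n    decoder = SequenceDecoder(d_model, d_output=d_output, l_output=0, mode=\"pool\")\n    x = torch.randn(batch, length, d_model)\n    y = decoder(x)  # (batch, d_output)\n\n    # The same computation spelled out by hand: mean over length, then the linear head.\n    manual = decoder.output_transform(x.mean(dim=1))\n    assert y.shape == (batch, d_output)\n    assert torch.allclose(y, manual, atol=1e-6)\n    print(\"pool decoder output shape:\", tuple(y.shape))\n"
  },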
  {
    "path": "src/tasks/encoders.py",
    "content": "from torch import nn\n\nimport src.models.nn.utils as U\nimport src.utils as utils\n\n\nclass Encoder(nn.Module):\n    \"\"\"Encoder abstraction\n\n    Accepts a tensor and optional kwargs. Other than the main tensor, all other arguments should be kwargs.\n    Returns a tensor and optional kwargs.\n    Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting\n    kwargs are accumulated and passed into the model backbone.\n    \"\"\"\n\n    def forward(self, x, **kwargs):\n        \"\"\"\n        x: input tensor\n        *args: additional info from the dataset (e.g. sequence lengths)\n\n        Returns:\n        y: output tensor\n        *args: other arguments to pass into the model backbone\n        \"\"\"\n        return x, {}\n\n\n# For every type of encoder/decoder, specify:\n# - constructor class\n# - list of attributes to grab from dataset\n# - list of attributes to grab from model\n\nregistry = {\n    \"stop\": Encoder,\n    \"id\": nn.Identity,\n    \"embedding\": nn.Embedding,\n    \"linear\": nn.Linear,\n}\n\ndataset_attrs = {\n    \"embedding\": [\"n_tokens\"],\n    \"linear\": [\"d_input\"],  # TODO make this d_data?\n    \"class\": [\"n_classes\"],\n    \"time\": [\"n_tokens_time\"],\n    \"onehot\": [\"n_tokens\"],\n    \"conv1d\": [\"d_input\"],\n    \"patch2d\": [\"d_input\"],\n}\n\nmodel_attrs = {\n    \"embedding\": [\"d_model\"],\n    \"linear\": [\"d_model\"],\n    \"position\": [\"d_model\"],\n    \"class\": [\"d_model\"],\n    \"time\": [\"d_model\"],\n    \"onehot\": [\"d_model\"],\n    \"conv1d\": [\"d_model\"],\n    \"patch2d\": [\"d_model\"],\n    \"timestamp_embedding\": [\"d_model\"],\n    \"layer\": [\"d_model\"],\n}\n\n\ndef _instantiate(encoder, dataset=None, model=None):\n    \"\"\"Instantiate a single encoder\"\"\"\n    if encoder is None:\n        return None\n    if isinstance(encoder, str):\n        name = encoder\n    else:\n        name = encoder[\"_name_\"]\n\n    # Extract dataset/model arguments from attribute names\n    dataset_args = utils.config.extract_attrs_from_obj(\n        dataset, *dataset_attrs.get(name, [])\n    )\n    model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))\n\n    # Instantiate encoder\n    obj = utils.instantiate(registry, encoder, *dataset_args, *model_args)\n    return obj\n\n\ndef instantiate(encoder, dataset=None, model=None):\n    encoder = utils.to_list(encoder)\n    return U.PassthroughSequential(\n        *[_instantiate(e, dataset=dataset, model=model) for e in encoder]\n    )\n"
  },
  {
    "path": "src/tasks/metrics.py",
    "content": "import math\nfrom functools import partial\n\nimport torch\nimport torch.nn.functional as F\nimport torchmetrics.functional as tm_f\nfrom sklearn.metrics import f1_score, roc_auc_score, matthews_corrcoef\nfrom torchmetrics.classification import MulticlassRecall, MulticlassPrecision\n\nfrom torchmetrics import Metric\n\n\nclass CorrectAggregatedMetric(Metric):\n    \"\"\"This is needed to calculate some metrics b/c small batch sizes cause aggregation via a simple\n        average to be off, as some classes might not be present in batch but will get penalized with a 0.\"\"\"\n    def __init__(self, class_idx: int, dist_sync_on_step=False):\n        # call `self.add_state`for every internal state that is needed for the metrics computations\n        # dist_reduce_fx indicates the function that should be used to reduce\n        # state from multiple processes\n        super().__init__(dist_sync_on_step=dist_sync_on_step)\n        self.class_idx = torch.tensor(class_idx)\n        self.add_state(\"numerator\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n        self.add_state(\"denominator\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n\n    def _update(self, numerator, denominator, preds, y) -> tuple:\n        raise NotImplemented\n\n    def update(self, logits: torch.Tensor, y: torch.Tensor):\n        # update metric states\n        preds = torch.argmax(logits, dim=-1)\n        logits = logits.view(-1, logits.shape[-1])\n        y = y.view(-1)\n        assert preds.shape == y.shape, f\"preds shape {preds.shape} != y shape {y.shape}\"\n        self.numerator, self.denominator = self._update(self.numerator, self.denominator, preds, y)\n\n    def compute(self):\n        # compute final result\n        value = self.numerator.float() / self.denominator if self.denominator > 0 else torch.tensor(0.0)\n        return value\n\n    def reset(self):\n        self.numerator = torch.tensor(0.0)\n        self.denominator = torch.tensor(0.0)\n\nclass AccuracyPerClass(CorrectAggregatedMetric):\n    \"\"\"Calculate per class accuracy, i.e. P(y_hat = class_idx AND y = class_idx OR y_hat != class_idx AND y != class_idx)\n    \"\"\"\n    def _update(self, numerator, denominator, preds, y) -> tuple:\n        # Filter down to the class of interest\n        class_idx = self.class_idx\n        relevant_idxs = (y == class_idx)\n        numerator += (preds[relevant_idxs] == class_idx).sum()\n        denominator += relevant_idxs.sum()\n        relevant_idxs = (y != class_idx)\n        numerator += (preds[relevant_idxs] != class_idx).sum()\n        denominator += relevant_idxs.sum()\n        return numerator, denominator\n\nclass PrecisionPerClass(CorrectAggregatedMetric):\n    \"\"\"Calculate per class precision, i.e. P(y_hat = y | y_hat = class_idx)\n    \"\"\"\n    def _update(self, numerator, denominator, preds, y) -> tuple:\n        # Filter down to the class of interest\n        class_idx = self.class_idx\n        relevant_idxs = (preds == class_idx)\n        numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum()\n        denominator += relevant_idxs.sum()\n        return numerator, denominator\n\n\nclass RecallPerClass(CorrectAggregatedMetric):\n    \"\"\"Calculate per class recall, i.e. 
P(y_hat = y | y = class_idx)\n    \"\"\"\n    def _update(self, numerator, denominator, preds, y) -> tuple:\n        # Filter down to the class of interest\n        class_idx = self.class_idx\n        relevant_idxs = (y == class_idx)\n        numerator += (preds[relevant_idxs] == y[relevant_idxs]).sum()\n        denominator += relevant_idxs.sum()\n        return numerator, denominator\n\n\ndef mcc(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    y_hat = torch.argmax(logits, dim=-1)\n    return matthews_corrcoef(y.cpu().numpy(), y_hat.cpu().numpy())\n\n\ndef last_k_ppl(logits, y, seq_len=1024, k=None):\n    '''\n    Calculate perplexity for last k tokens in a sequence.\n\n    logits: (batch_size * seq_len, vocab_size), note, already flattened\n    y: (batch_size * seq_len), note, already flattened\n    seq_len: int, length of each sequence in the batch\n    k: if None, use all tokens in sequence\n    \n    returns: (batch_size,)  ppl for each sequence in the batch\n    '''\n\n    if k is None:\n        k = 0  # use the entire sequence\n\n    # need to reshape logits and y to be (batch_size, seq_len, vocab_size) and (batch_size, seq_len)\n    # respectively\n    # breakpoint()\n    logits = logits.view(-1, seq_len, logits.shape[-1])\n    y = y.view(-1, seq_len)\n\n    # only use the last k values of seq dim in logits and y\n    logits = logits[:, -k:, :]\n    y = y[:, -k:]\n\n    # reshape to flatten the batch and seq_len dimensions\n    logits = logits.reshape(-1, logits.shape[-1])\n    y = y.reshape(-1)\n    # get avg and put on cpu\n    return F.cross_entropy(logits, y, reduction='none').view(y.shape[0], -1).mean().exp().cpu()\n\n\ndef _student_t_map(mu, sigma, nu):\n    sigma = F.softplus(sigma)\n    nu = 2.0 + F.softplus(nu)\n    return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1)\n\ndef student_t_loss(outs, y):\n    mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2]\n    mu, sigma, nu = _student_t_map(mu, sigma, nu)\n    y = y.squeeze(axis=-1)\n\n    nup1_half = (nu + 1.0) / 2.0\n    part1 = 1.0 / nu * torch.square((y - mu) / sigma)\n    Z = (\n        torch.lgamma(nup1_half)\n        - torch.lgamma(nu / 2.0)\n        - 0.5 * torch.log(math.pi * nu)\n        - torch.log(sigma)\n    )\n\n    ll = Z - nup1_half * torch.log1p(part1)\n    return -ll.mean()\n\ndef gaussian_ll_loss(outs, y):\n    mu, sigma = outs[..., 0], outs[..., 1]\n    y = y.squeeze(axis=-1)\n    sigma = F.softplus(sigma)\n    ll = -1.0 * (\n        torch.log(sigma)\n        + 0.5 * math.log(2 * math.pi)\n        + 0.5 * torch.square((y - mu) / sigma)\n    )\n    return -ll.mean()\n\ndef binary_cross_entropy(logits, y):\n    # BCE loss requires squeezing last dimension of logits so it has the same shape as y\n    # requires y to be float, since it's overloaded to represent a probability\n    return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float())\n\n\ndef binary_accuracy(logits, y):\n    return torch.eq(logits.squeeze(-1) >= 0, y).float().mean()\n\ndef padded_cross_entropy(logits, y, pad_mask, pad_value=-1):\n    \"\"\"Will ignore the pad value in label (eg, -1)\n    \n    logits: (batch_size, seq_len, vocab_size)\n    y: (batch_size, seq_len)\n    pad_mask: (batch_size, seq_len)\n    \n    \"\"\"\n\n    # need to apply pad mask to y\n    y_pad = y + pad_mask * pad_value\n\n    logits = logits.view(-1, logits.shape[-1])\n    y_pad = y_pad.view(-1)\n    return F.cross_entropy(logits, y_pad, ignore_index=pad_value)\n\n\ndef cross_entropy(logits, 
y, ignore_index=-100):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    return F.cross_entropy(logits, y, ignore_index=ignore_index)\n\n\ndef soft_cross_entropy(logits, y, label_smoothing=0.0):\n    logits = logits.view(-1, logits.shape[-1])\n    # target is now 2d (no target flattening)\n    return F.cross_entropy(logits, y, label_smoothing=label_smoothing)\n\n\ndef accuracy(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    preds = torch.argmax(logits, dim=-1)\n    if y.numel() > logits.shape[0]:\n        # Mixup leads to this case: use argmax class\n        y = y.argmax(dim=-1)\n    y = y.view(-1)\n    return torch.eq(preds, y).float().mean()\n\n\ndef accuracy_ignore_index(logits, y, ignore_index=-100):\n    num_classes = logits.shape[-1]\n    preds = torch.argmax(logits, dim=-1)\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    accuracy = tm_f.classification.accuracy(preds, y, 'multiclass', num_classes=num_classes, ignore_index=ignore_index, average='micro')\n    return accuracy\n\n\ndef accuracy_at_k(logits, y, k=1):\n    logits = logits.view(-1, logits.shape[-1])\n    if y.numel() > logits.shape[0]:\n        # Mixup leads to this case: use argmax class\n        y = y.argmax(dim=-1)\n    y = y.view(-1)\n    return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean()\n\n\ndef f1_binary(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    y_hat = torch.argmax(logits, dim=-1)\n    return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average=\"binary\")\n\n\ndef f1_macro(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    y_hat = torch.argmax(logits, dim=-1)\n    return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average=\"macro\")\n\n\ndef f1_micro(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    y_hat = torch.argmax(logits, dim=-1)\n    return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average=\"micro\")\n\n\ndef roc_auc_macro(logits, y):\n    logits = logits.view(\n        -1, logits.shape[-1]\n    ).detach()  # KS: had to add detach to eval while training\n    y = y.view(-1)\n    return roc_auc_score(\n        y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average=\"macro\"\n    )\n\n\ndef roc_auc_micro(logits, y):\n    logits = logits.view(-1, logits.shape[-1])\n    y = y.view(-1)\n    return roc_auc_score(\n        y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average=\"micro\"\n    )\n\n\ndef mse(outs, y, len_batch=None):\n    # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1\n    # outs = outs.squeeze(-1)\n    if len(y.shape) < len(outs.shape):\n        assert outs.shape[-1] == 1\n        outs = outs.squeeze(-1)\n    if len_batch is None:\n        return F.mse_loss(outs, y)\n    else:\n        # Computes the loss of the first `lens` items in the batches\n        # TODO document the use case of this\n        mask = torch.zeros_like(outs, dtype=torch.bool)\n        for i, l in enumerate(len_batch):\n            mask[i, :l, :] = 1\n        outs_masked = torch.masked_select(outs, mask)\n        y_masked = torch.masked_select(y, mask)\n        return F.mse_loss(outs_masked, y_masked)\n\ndef forecast_rmse(outs, y, len_batch=None):\n    # TODO: generalize, currently for Monash dataset\n    return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean()\n\ndef mae(outs, y, len_batch=None):\n    # assert outs.shape[:-1] == y.shape and 
outs.shape[-1] == 1\n    # outs = outs.squeeze(-1)\n    if len(y.shape) < len(outs.shape):\n        assert outs.shape[-1] == 1\n        outs = outs.squeeze(-1)\n    if len_batch is None:\n        return F.l1_loss(outs, y)\n    else:\n        # Computes the loss of the first `lens` items in the batches\n        mask = torch.zeros_like(outs, dtype=torch.bool)\n        for i, l in enumerate(len_batch):\n            mask[i, :l, :] = 1\n        outs_masked = torch.masked_select(outs, mask)\n        y_masked = torch.masked_select(y, mask)\n        return F.l1_loss(outs_masked, y_masked)\n\n\n# Metrics that can depend on the loss\ndef loss(x, y, loss_fn):\n    \"\"\" This metric may be useful because the training loss may add extra regularization (e.g. weight decay implemented as L2 penalty), while adding this as a metric skips the additional losses \"\"\"\n    return loss_fn(x, y)\n\n\ndef bpb(x, y, loss_fn):\n    \"\"\" bits per byte (image density estimation, speech generation, char LM) \"\"\"\n    return loss_fn(x, y) / math.log(2)\n\n\ndef ppl(x, y, loss_fn):\n    return torch.exp(loss_fn(x, y))\n\n\n# should have a better way to do this\noutput_metric_fns = {\n    \"binary_cross_entropy\": binary_cross_entropy,\n    \"cross_entropy\": cross_entropy,\n    \"padded_cross_entropy\": padded_cross_entropy,\n    \"binary_accuracy\": binary_accuracy,\n    # \"precision\": MulticlassPrecision,\n    # \"precision_species\": partial(MulticlassPrecision, task='multiclass', average=None),\n    \"precision_species\": partial(MulticlassPrecision, average=None),\n    # \"recall_species\": partial(MulticlassRecall, task='multiclass', average=None),\n    \"recall_species\": partial(MulticlassRecall, average=None),\n    # \"precision_class\": partial(MulticlassPrecision, average=None),\n    \"precision_per_class\": PrecisionPerClass,\n    \"recall\": MulticlassRecall,\n    \"recall_per_class\": RecallPerClass,\n    \"accuracy\": accuracy,\n    \"accuracy_per_class\": AccuracyPerClass,\n    \"accuracy_ignore_index\": accuracy_ignore_index,\n    'accuracy@3': partial(accuracy_at_k, k=3),\n    'accuracy@5': partial(accuracy_at_k, k=5),\n    'accuracy@10': partial(accuracy_at_k, k=10),\n    \"eval_loss\": loss,\n    \"mcc\": mcc,\n    \"mse\": mse,\n    \"mae\": mae,\n    \"forecast_rmse\": forecast_rmse,\n    \"f1_binary\": f1_binary,\n    \"f1_macro\": f1_macro,\n    \"f1_micro\": f1_micro,\n    \"roc_auc_macro\": roc_auc_macro,\n    \"roc_auc_micro\": roc_auc_micro,\n    \"soft_cross_entropy\": soft_cross_entropy,  # only for pytorch 1.10+\n    \"student_t\": student_t_loss,\n    \"gaussian_ll\": gaussian_ll_loss,\n}\n\nloss_metric_fns = {\n    \"loss\": loss,\n    \"bpb\": bpb,\n    \"ppl\": ppl,\n}\nmetric_fns = {**output_metric_fns, **loss_metric_fns}  # TODO py3.9\n\n"
  },
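  {
    "path": "examples/per_class_metric_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows why CorrectAggregatedMetric (src.tasks.metrics) accumulates a numerator and a\ndenominator across batches instead of averaging per-batch values: with small batches a\nclass can be rare or missing in any single batch, and a naive mean of per-batch recalls\nwould be skewed. Assumes the repo's dependencies (torch, torchmetrics) are installed.\n\"\"\"\n\nimport torch\n\nfrom src.tasks.metrics import RecallPerClass\n\nif __name__ == \"__main__\":\n    metric = RecallPerClass(class_idx=1)\n\n    # Batch 1: class 1 appears 3 times, 2 of them predicted correctly (recall 2/3).\n    logits1 = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.1, 0.9], [0.9, 0.1]])\n    y1 = torch.tensor([1, 1, 1, 0])\n    metric.update(logits1, y1)\n\n    # Batch 2: class 1 appears once and is predicted correctly (recall 1/1).\n    logits2 = torch.tensor([[0.2, 0.8], [0.7, 0.3]])\n    y2 = torch.tensor([1, 0])\n    metric.update(logits2, y2)\n\n    # Aggregated recall = (2 + 1) / (3 + 1) = 0.75, whereas averaging the per-batch\n    # recalls (2/3 and 1/1) would give ~0.83.\n    print(\"recall for class 1:\", metric.compute().item())\n"
  },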
  {
    "path": "src/tasks/tasks.py",
    "content": "import inspect\nfrom typing import List\n\nimport torch.nn as nn\nfrom einops import rearrange\n\nimport src.models.nn.utils as U\nimport src.tasks.metrics as M\nimport torchmetrics as tm\nfrom src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax\nfrom src.tasks.torchmetrics import torchmetric_fns as tm_mine\nfrom src.utils.config import to_list, instantiate\nfrom torchmetrics import MetricCollection\n\n\nclass BaseTask:\n    \"\"\" Abstract class that takes care of:\n    - loss function\n    - arbitrary metrics\n    - forward pass\n    - (optional) encoder module that interfaces with dataset (inputs) and model\n    - (optional) decoder module that interfaces with dataset (targets) and model\n    \"\"\"\n    encoder = None\n    decoder = None\n\n    def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None):\n        \"\"\" This class is allowed to grab attributes directly off a constructed dataset and model object \"\"\"\n        self.dataset = dataset\n        self.model = model\n        if metrics is None:\n            metrics = []\n        self.metric_names = to_list(metrics)\n\n        if torchmetrics is None:\n            torchmetrics = []\n        self.torchmetric_names = to_list(torchmetrics)\n        self._tracked_torchmetrics = {}\n\n        # The decoder might pass through arguments that the loss needs (e.g. sequence lengths)\n        # but might also pass through extraneous arguments (e.g. sampling rate)\n        # Wrap loss and metrics so that they accept kwargs and\n\n        # Create loss function\n        self.loss = instantiate(M.output_metric_fns, loss, partial=True)\n        self.loss = U.discard_kwargs(self.loss)\n        if loss_val is not None:\n            self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True)\n            self.loss_val = U.discard_kwargs(self.loss_val)\n        torchmetrics = MetricCollection(self._init_torchmetrics())\n        self.train_torchmetrics = torchmetrics.clone(prefix='train/')\n        self.val_torchmetrics = torchmetrics.clone(prefix='val/')\n        self.test_torchmetrics = torchmetrics.clone(prefix='test/')\n\n    def _init_torchmetrics(self):\n        \"\"\"\n        Instantiate torchmetrics.\n        \"\"\"\n        tracked_torchmetrics = {}\n\n        for name in self.torchmetric_names:\n            if name in tm_mine:\n                tracked_torchmetrics[name] = tm_mine[name]()\n            elif name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']:\n                tracked_torchmetrics[name] = getattr(tm, name)(\n                    average='macro', num_classes=self.dataset.d_output, compute_on_step=False\n                )\n            elif name in ['MultilabelAUROC', 'MultilabelAveragePrecision']:\n                tracked_torchmetrics[name] = getattr(tm, name)(\n                    average='macro', num_labels=self.dataset.d_output\n                )\n            elif '@' in name:\n                k = int(name.split('@')[1])\n                mname = name.split('@')[0]\n                tracked_torchmetrics[name] = getattr(tm, mname)(\n                    average='macro', num_classes=self.dataset.d_output, compute_on_step=False, top_k=k\n                )\n            else:\n                tracked_torchmetrics[name] = getattr(tm, name)(compute_on_step=False)\n\n        return tracked_torchmetrics\n\n    def _reset_torchmetrics(self, prefix=None):\n        \"\"\"\n        Reset torchmetrics for a 
prefix\n        associated with a particular dataloader (e.g. train, val, test).\n\n        Generally do this at the start of an epoch.\n        \"\"\"\n        all_prefixes = [prefix] if prefix is not None else self._tracked_torchmetrics\n\n        for prefix in all_prefixes:\n            if prefix in self._tracked_torchmetrics:\n                self._tracked_torchmetrics[prefix].reset()\n\n    def get_torchmetrics(self, prefix):\n        \"\"\"\n        Compute torchmetrics for a prefix associated with\n        a particular dataloader (e.g. train, val, test).\n\n        Generally do this at the end of an epoch.\n        \"\"\"\n        return {name: self._tracked_torchmetrics[prefix][name].compute() for name in self.torchmetric_names}\n\n    def torchmetrics(self, x, y, prefix, loss=None):\n        \"\"\"\n        Update torchmetrics with new x, y .\n        Prefix corresponds to a particular dataloader (e.g. train, val, test).\n\n        Generally call this every batch.\n        \"\"\"\n        if prefix not in self._tracked_torchmetrics:\n            self._init_torchmetrics(prefix)\n        self._tracked_torchmetrics[prefix](x, y, loss=loss)\n\n        # for name in self.torchmetric_names:\n        #     if name.startswith('Accuracy'):\n        #         if len(x.shape) > 2:\n        #             # Multi-dimensional, multi-class\n        #             self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze())\n        #             continue\n        #     self._tracked_torchmetrics[prefix][name].update(x, y)\n\n    def get_torchmetrics(self, prefix):\n        return self._tracked_torchmetrics[prefix]\n\n    def metrics(self, x, y, **kwargs):\n        \"\"\"\n        Metrics are just functions\n        output metrics are a function of output and target\n        loss metrics are a function of loss (e.g. 
perplexity)\n        \"\"\"\n        output_metrics = {\n            name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)\n            for name in self.metric_names if name in M.output_metric_fns\n        }\n        loss_metrics = {\n            name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)\n            for name in self.metric_names if name in M.loss_metric_fns\n        }\n        return {**output_metrics, **loss_metrics}\n\n    def forward(self, batch, encoder, model, decoder, _state):\n        \"\"\"Passes a batch through the encoder, backbone, and decoder\"\"\"\n        # z holds arguments such as sequence length\n        x, y, *z = batch  # z holds extra dataloader info such as resolution\n        if len(z) == 0:\n            z = {}\n        else:\n            assert len(z) == 1 and isinstance(z[0], dict), \"Dataloader must return dictionary of extra arguments\"\n            z = z[0]\n\n        # w can model-specific constructions, such as key_padding_mask for transformers or state for RNNs\n        x, w = encoder(x, **z)\n        x, state = model(x, **w, state=_state)\n        self._state = state\n        x, w = decoder(x, state=state, **z)\n        return x, y, w\n\n\nclass Scalar(nn.Module):\n    def __init__(self, c=1):\n        super().__init__()\n        self.c = c\n\n    def forward(self, x):\n        return x * self.c\n\n\nclass LMTask(BaseTask):\n    def forward(self, batch, encoder, model, decoder, _state):\n        \"\"\"Passes a batch through the encoder, backbone, and decoder\"\"\"\n        # z holds arguments such as sequence length\n        x, y, *z = batch  # z holds extra dataloader info such as resolution\n        if len(z) == 0:\n            z = {}\n        else:\n            assert len(z) == 1 and isinstance(z[0], dict), \"Dataloader must return dictionary of extra arguments\"\n            z = z[0]\n        # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs\n        x, w = encoder(x, **z)\n        # Needed for Mamba (open-source repo version)\n        if \"state\" in inspect.signature(model.forward).parameters.keys():\n            x, state = model(x, **w, state=_state)\n        else:\n            x = model(x, **w)\n            state = None\n        self._state = state\n        x, w = decoder(x, state=state, **z)\n\n        if hasattr(x, 'logits'):\n            x = x.logits\n        x = rearrange(x, '... C -> (...) C')\n        y = rearrange(y, '... 
-> (...)')\n\n        return x, y, w\n\n\nclass MultiClass(BaseTask):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.continual_metrics = {}\n        for name in self.metric_names:\n            if name.endswith('_per_class'):\n                for spec_idx, spec in enumerate(self.dataset.species):\n                    self.continual_metrics[name + '_' + spec] = M.output_metric_fns[name](spec_idx)\n            elif name in ['precision_species', 'recall_species']:\n                self.continual_metrics[name] = M.output_metric_fns[name](num_classes=len(self.dataset.species))\n\n    def metrics(self, x, y, **kwargs):\n        output_metrics = {}\n        for name in self.metric_names:\n            if name in M.output_metric_fns:\n                if name.endswith('_per_class'):\n                    for spec_idx, spec in enumerate(self.dataset.species):\n                        self.continual_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].to(\n                            x.device)\n                        self.continual_metrics[name + '_' + spec].update(x, y)\n                        output_metrics[name + '_' + spec] = self.continual_metrics[name + '_' + spec].compute()\n                elif name in ['precision_species', 'recall_species']:\n                    self.continual_metrics[name] = self.continual_metrics[name].to(x.device)\n                    metrics = self.continual_metrics[name](x, y)\n                    for spec_idx, spec in enumerate(self.dataset.species):\n                        output_metrics[name[:-7] + spec] = metrics[spec_idx]\n                else:\n                    output_metrics[name] = U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)\n\n        loss_metrics = {\n            name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)\n            for name in self.metric_names if name in M.loss_metric_fns\n        }\n\n        return {**output_metrics, **loss_metrics}\n\n    def _reset_torchmetrics(self, prefix=None):\n        super()._reset_torchmetrics(prefix)\n        for name in self.metric_names:\n            if name.endswith('_per_class'):\n                for spec_idx, spec in enumerate(self.dataset.species):\n                    self.continual_metrics[name + '_' + spec].reset()\n\n\nclass HG38Task(LMTask):\n\n    def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None,\n                 last_k_ppl=None, per_token_ppl=None):\n        \"\"\" Extending LMTask to add custom metrics for HG38 task \n        \n        last_k_ppl: config for custom ppl, with hparams to pass with it\n\n        per_token_ppl: config for per token ppl calc, with list of k (ppls) to track\n\n        \"\"\"\n        self.dataset = dataset\n        self.model = model\n        if metrics is None:\n            metrics = []\n        self.metric_names = to_list(metrics)\n        self.last_k_ppl = last_k_ppl\n        self.per_token_ppl = per_token_ppl\n\n        if torchmetrics is None:\n            torchmetrics = []\n        self.torchmetric_names = to_list(torchmetrics)\n        self._tracked_torchmetrics = {}\n\n        # The decoder might pass through arguments that the loss needs (e.g. sequence lengths)\n        # but might also pass through extraneous arguments (e.g. 
sampling rate)\n        # Wrap loss and metrics so that they accept kwargs and\n\n        # Create loss function\n        self.loss = instantiate(M.output_metric_fns, loss, partial=True)\n        self.loss = U.discard_kwargs(self.loss)\n        if loss_val is not None:\n            self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True)\n            self.loss_val = U.discard_kwargs(self.loss_val)\n        torchmetrics = MetricCollection(self._init_torchmetrics())\n        self.train_torchmetrics = torchmetrics.clone(prefix='train/')\n        self.val_torchmetrics = torchmetrics.clone(prefix='val/')\n        self.test_torchmetrics = torchmetrics.clone(prefix='test/')\n\n        # Create custom metrics for last k ppl\n        # last_k_ppl is a list of dicts (configs), so loop through them\n        if self.last_k_ppl is not None:\n            self.custom_ppl_dict = {}\n            for k in self.last_k_ppl:\n                key_name = \"last_\" + str(k) + \"_ppl\"\n                # create config\n                custom_ppl_config = {\"_name_\": \"last_k_ppl\", \"k\": k, \"seq_len\": self.dataset.max_length}\n                k_ppl_fn = instantiate(M.output_metric_fns, custom_ppl_config, partial=True)\n                k_ppl_fn = U.discard_kwargs(k_ppl_fn)\n                self.custom_ppl_dict[key_name] = k_ppl_fn\n\n        # Create custom metric for per token ppl\n        if self.per_token_ppl is not None:\n            per_token_ppl_config = {\"_name_\": \"per_token_ppl\", \"ks\": self.per_token_ppl[\"ks\"],\n                                    \"seq_len\": self.dataset.max_length}\n            per_token_fn = instantiate(M.output_metric_fns, per_token_ppl_config, partial=True)\n            per_token_fn = U.discard_kwargs(per_token_fn)\n            self.per_token_fn = per_token_fn\n\n    def metrics(self, x, y, **kwargs):\n        \"\"\"\n        Need to modify metrics to include custom metrics\n        \"\"\"\n\n        output_metrics = {\n            name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs)\n            for name in self.metric_names if name in M.output_metric_fns\n        }\n        loss_metrics = {\n            name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs)\n            for name in self.metric_names if name in M.loss_metric_fns\n        }\n\n        # loop through all custom ppls and add them to output_metrics\n        if self.last_k_ppl is not None:\n            for key_name, k_ppl_fn in self.custom_ppl_dict.items():\n                output_metrics[key_name] = k_ppl_fn(x, y, **kwargs)\n\n        # loop through all custom ppls and add them to output_metrics\n        if self.per_token_ppl is not None:\n            # returns k ppl values, (averaged over batch)\n            per_k_ppl = self.per_token_fn(x, y, **kwargs)\n\n            # loop over ks to log metric\n            for ind, k in enumerate(self.per_token_ppl[\"ks\"]):\n                key_name = \"ppl_at_{}\".format(k)\n                output_metrics[key_name] = per_k_ppl[ind]  # should be in order\n\n        return {**output_metrics, **loss_metrics}\n\n\nclass AdaptiveLMTask(BaseTask):\n    def __init__(\n            self,\n            div_val,\n            cutoffs: List[int],\n            tie_weights: bool,\n            tie_projs: List[bool],\n            init_scale=1.0,\n            bias_scale=0.0,\n            dropemb=0.0,\n            dropsoft=0.0,\n            **kwargs,\n    ):\n        super().__init__(**kwargs)\n        n_tokens = self.dataset.n_tokens\n      
  d_model = self.model.d_model\n        d_output = self.model.d_output\n\n        encoder = AdaptiveEmbedding(\n            n_tokens,\n            d_model,\n            d_model,\n            cutoffs=cutoffs,\n            div_val=div_val,\n            init_scale=init_scale,\n            dropout=dropemb,\n        )\n\n        if tie_weights:\n            assert d_model == d_output\n            emb_layers = [i.weight for i in encoder.emb_layers]\n        else:\n            emb_layers = None\n\n        # Construct decoder/loss\n        emb_projs = encoder.emb_projs\n        loss = ProjectedAdaptiveLogSoftmax(\n            n_tokens, d_output, d_output,\n            cutoffs, div_val=div_val,\n            tie_projs=tie_projs,\n            out_projs=emb_projs,\n            out_layers_weights=emb_layers,\n            bias_scale=bias_scale,\n            dropout=dropsoft,\n        )\n\n        self.encoder = encoder\n        self.loss = loss\n\n\nregistry = {\n    'base': BaseTask,\n    'multiclass': MultiClass,\n    'lm': LMTask,\n    'hg38': HG38Task,\n}\n"
  },
  {
    "path": "src/tasks/torchmetrics.py",
    "content": "# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py\n# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))\n# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py\n# But we pass in the loss to avoid recomputation\n\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torchmetrics import Metric\n\ntry:\n    from flash_attn.losses.cross_entropy import CrossEntropyLoss\nexcept ImportError:\n    CrossEntropyLoss = torch.nn.CrossEntropyLoss\n\ntry:\n    from apex.transformer import parallel_state\nexcept ImportError:\n    parallel_state = None\n\n\nclass Perplexity(Metric):\n    r\"\"\"\n    Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits\n    per word a model needs to represent the sample.\n    Args:\n        kwargs:\n            Additional keyword arguments, see :ref:`Metric kwargs` for more info.\n    Examples:\n        >>> import torch\n        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))\n        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))\n        >>> target[0, 6:] = -100\n        >>> metric = Perplexity(ignore_index=-100)\n        >>> metric(preds, target)\n        tensor(5.2545)\n    \"\"\"\n    is_differentiable = True\n    higher_is_better = False\n    full_state_update = False\n    total_log_probs: Tensor\n    count: Tensor\n\n    def __init__(self, **kwargs: Dict[str, Any]):\n        super().__init__(**kwargs)\n        self.add_state(\"total_log_probs\", default=torch.tensor(0.0, dtype=torch.float64),\n                       dist_reduce_fx=\"sum\")\n        self.add_state(\"count\", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx=\"sum\")\n\n        self.loss_fn = CrossEntropyLoss()\n\n    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore\n        \"\"\"Compute and store intermediate statistics for Perplexity.\n        Args:\n            preds:\n                Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].\n            target:\n                Ground truth values with a shape [batch_size, seq_len].\n        \"\"\"\n        count = target.numel()\n        if loss is None:\n            loss = self.loss_fn(preds, target)\n        self.total_log_probs += loss.double() * count\n        self.count += count\n\n    def compute(self) -> Tensor:\n        \"\"\"Compute the Perplexity.\n        Returns:\n           Perplexity\n        \"\"\"\n        return torch.exp(self.total_log_probs / self.count)\n\nclass NumTokens(Metric):\n    \"\"\"Keep track of how many tokens we've seen.\n    \"\"\"\n    # TODO: how do we prevent the reset between the epochs? 
The reset happens on the 1st batch\n    # of the next epoch.\n    # Right now the hack is that we override reset(), which would mess up the forward method.\n    # We then override forward to do the right thing.\n\n    is_differentiable = False\n    higher_is_better = False\n    full_state_update = False\n    count: Tensor\n\n    def __init__(self, **kwargs: Dict[str, Any]):\n        super().__init__(**kwargs)\n        self.add_state(\"count\", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx=\"sum\",\n                       persistent=True)  # We want the count to be saved to state-dict\n        if parallel_state is not None and not parallel_state.is_unitialized():\n            self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size()\n        else:\n            self.tensor_parallel_world_size = 1\n\n    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore\n        self.count += target.numel() // self.tensor_parallel_world_size\n\n    def compute(self) -> Tensor:\n        return self.count\n\n    def reset(self):\n        count = self.count\n        super().reset()\n        self.count = count\n\n    # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py\n    def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"forward computation using single call to `update` to calculate the metric value on the current batch and\n        accumulate global state.\n        This can be done when the global metric state is a sinple reduction of batch states.\n        \"\"\"\n        self.update(*args, **kwargs)\n        return self.compute()\n\ntorchmetric_fns = {\n    \"perplexity\": Perplexity,\n    \"num_tokens\": NumTokens,\n}\n"
  },
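  {
    "path": "examples/perplexity_metric_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows the aggregation used by Perplexity (src.tasks.torchmetrics): it accumulates\nsum(nll * num_tokens) and num_tokens across batches and reports\nexp(total_nll / total_tokens), i.e. exp(average(nll)) rather than average(exp(nll)).\nAssumes the repo's dependencies (torch, torchmetrics) are installed.\n\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\nfrom src.tasks.torchmetrics import Perplexity\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    metric = Perplexity()\n\n    total_nll, total_tokens = 0.0, 0\n    for seq_len in (4, 6):  # two batches with different numbers of tokens\n        preds = torch.randn(2, seq_len, 5)       # (batch, seq_len, vocab)\n        target = torch.randint(5, (2, seq_len))  # (batch, seq_len)\n        loss = F.cross_entropy(preds.transpose(1, 2), target)\n        metric.update(preds, target, loss=loss)  # pass the loss to avoid recomputation\n        total_nll += loss.item() * target.numel()\n        total_tokens += target.numel()\n\n    expected = math.exp(total_nll / total_tokens)\n    print(\"metric:\", metric.compute().item(), \"expected:\", expected)\n"
  },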
  {
    "path": "src/utils/__init__.py",
    "content": "from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate\n"
  },
  {
    "path": "src/utils/config.py",
    "content": "\"\"\"Utilities for dealing with collection objects (lists, dicts) and configs.\n\n\"\"\"\n\nimport functools\nfrom typing import Sequence, Mapping, Callable\n\nimport hydra\nfrom omegaconf import ListConfig, DictConfig\n\n\n# TODO this is usually used in a pattern where it's turned into a list, so can just do that here\ndef is_list(x):\n    return isinstance(x, Sequence) and not isinstance(x, str)\n\n\ndef is_dict(x):\n    return isinstance(x, Mapping)\n\n\ndef to_dict(x, recursive=True):\n    \"\"\"Convert Sequence or Mapping object to dict\n\n    lists get converted to {0: x[0], 1: x[1], ...}\n    \"\"\"\n    if is_list(x):\n        x = {i: v for i, v in enumerate(x)}\n    if is_dict(x):\n        if recursive:\n            return {k: to_dict(v, recursive=recursive) for k, v in x.items()}\n        else:\n            return dict(x)\n    else:\n        return x\n\n\ndef to_list(x, recursive=False):\n    \"\"\"Convert an object to list.\n\n    If Sequence (e.g. list, tuple, Listconfig): just return it\n\n    Special case: If non-recursive and not a list, wrap in list\n    \"\"\"\n    if is_list(x):\n        if recursive:\n            return [to_list(_x) for _x in x]\n        else:\n            return list(x)\n    else:\n        if recursive:\n            return x\n        else:\n            return [x]\n\n\ndef extract_attrs_from_obj(obj, *attrs):\n    if obj is None:\n        assert len(attrs) == 0\n        return []\n    return [getattr(obj, attr, None) for attr in attrs]\n\n\ndef auto_assign_attrs(cls, **kwargs):\n    for k, v in kwargs.items():\n        setattr(cls, k, v)\n        \n        \ndef instantiate(registry, config, *args, partial=False, wrap=None, **kwargs):\n    \"\"\"\n    registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'})\n    config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor\n    wrap: wrap the target class (e.g. ema optimizer or tasks.wrap)\n    *args, **kwargs: additional arguments to override the config to pass into the target constructor\n    \"\"\"\n\n    # Case 1: no config\n    if config is None:\n        return None\n    # Case 2a: string means _name_ was overloaded\n    if isinstance(config, str):\n        _name_ = None\n        _target_ = registry[config]\n        config = {}\n    # Case 2b: grab the desired callable from name\n    else:\n        _name_ = config.pop(\"_name_\")\n        _target_ = registry[_name_]\n\n    # Retrieve the right constructor automatically based on type\n    if isinstance(_target_, str):\n        fn = hydra.utils.get_method(path=_target_)\n    elif isinstance(_target_, Callable):\n        fn = _target_\n    else:\n        raise NotImplementedError(\"instantiate target must be string or callable\")\n\n    # Instantiate object\n    if wrap is not None:\n        fn = wrap(fn)\n    obj = functools.partial(fn, *args, **config, **kwargs)\n\n    # Restore _name_\n    if _name_ is not None:\n        config[\"_name_\"] = _name_\n\n    if partial:\n        return obj\n    else:\n        return obj()\n\n\ndef get_class(registry, _name_):\n    return hydra.utils.get_class(path=registry[_name_])\n\n\ndef omegaconf_filter_keys(d, fn=None):\n    \"\"\"Only keep keys where fn(key) is True. 
Support nested DictConfig.\n    # TODO can make this inplace?\n    \"\"\"\n    if fn is None:\n        fn = lambda _: True\n    if is_list(d):\n        return ListConfig([omegaconf_filter_keys(v, fn) for v in d])\n    elif is_dict(d):\n        return DictConfig(\n            {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)}\n        )\n    else:\n        return d\n"
  },
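  {
    "path": "examples/registry_instantiate_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows the registry pattern used throughout the repo: a config dict carries a \"_name_\"\nkey that selects an entry from a registry, and src.utils.instantiate builds the object\nwith the remaining keys (plus positional overrides) as constructor arguments. The toy\nregistry below is an assumption for illustration only; real registries live in\nsrc/utils/registry.py and src/tasks/. Assumes the repo's dependencies (torch, hydra,\nomegaconf) are installed.\n\"\"\"\n\nfrom torch import nn\n\nfrom src.utils import instantiate\n\n# Hypothetical registry mapping names to callables. String import paths such as\n# \"torch.nn.Linear\" also work; those are resolved through hydra.utils.get_method.\ntoy_registry = {\n    \"linear\": nn.Linear,\n    \"embedding\": nn.Embedding,\n}\n\nif __name__ == \"__main__\":\n    # \"_name_\" selects the registry entry; the remaining keys become kwargs.\n    cfg = {\"_name_\": \"linear\", \"out_features\": 8}\n\n    # Positional args (here: in_features=16) are prepended before the config kwargs,\n    # mirroring how encoders/decoders receive attributes pulled off dataset/model.\n    layer = instantiate(toy_registry, cfg, 16)\n    print(layer)  # Linear(in_features=16, out_features=8, bias=True)\n\n    # partial=True returns a constructor instead of an instance.\n    make_emb = instantiate(toy_registry, {\"_name_\": \"embedding\", \"embedding_dim\": 4}, partial=True)\n    print(make_emb(num_embeddings=10))  # Embedding(10, 4)\n"
  },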
  {
    "path": "src/utils/optim/schedulers.py",
    "content": "\"\"\"Custom learning rate schedulers\"\"\"\n\nimport math\nimport warnings\nimport torch\n\nfrom timm.scheduler import CosineLRScheduler\n\n\n# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html\nclass CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR):\n\n    def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs):\n        self.warmup_step = warmup_step\n        super().__init__(optimizer, T_max - warmup_step, eta_min, *kwargs)\n\n    # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to\n    # self.last_epoch - self.warmup_step.\n    def get_lr(self):\n        if not self._get_lr_called_within_step:\n            warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n                          \"please use `get_last_lr()`.\", UserWarning)\n\n        if self.last_epoch == self.warmup_step:  # also covers the case where both are 0\n            return self.base_lrs\n        elif self.last_epoch < self.warmup_step:\n            return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs]\n        elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0:\n            return [group['lr'] + (base_lr - self.eta_min) *\n                    (1 - math.cos(math.pi / self.T_max)) / 2\n                    for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)]\n        return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) /\n                (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) *\n                (group['lr'] - self.eta_min) + self.eta_min\n                for group in self.optimizer.param_groups]\n\n    _get_closed_form_lr = None\n\n\ndef InvSqrt(optimizer, warmup_step):\n    \"\"\" Originally used for Transformer (in Attention is all you need)\n    \"\"\"\n\n    def lr_lambda(step):\n        # return a multiplier instead of a learning rate\n        if step == warmup_step:  # also covers the case where both are 0\n            return 1.\n        else:\n            return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5)\n\n    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)\n\n\ndef Constant(optimizer, warmup_step):\n\n    def lr_lambda(step):\n        if step == warmup_step:  # also covers the case where both are 0\n            return 1.\n        else:\n            return 1. 
if step > warmup_step else (step + 1) / warmup_step\n\n    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)\n\n\nclass TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):\n    \"\"\" Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch.\n    It supports resuming as well.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._last_epoch = -1\n        self.step(epoch=0)\n\n    def step(self, epoch=None):\n        if epoch is None:\n            self._last_epoch += 1\n        else:\n            self._last_epoch = epoch\n        # We call either step or step_update, depending on whether we're using the scheduler every\n        # epoch or every step.\n        # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set\n        # scheduler interval to \"step\", then the learning rate update will be wrong.\n        if self.t_in_epochs:\n            super().step(epoch=self._last_epoch)\n        else:\n            super().step_update(num_updates=self._last_epoch)\n"
  },
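  {
    "path": "examples/lr_warmup_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows the warmup behaviour of the LambdaLR-based schedules in\nsrc.utils.optim.schedulers: the lambda returns a multiplier on the base learning rate,\nramping up for `warmup_step` steps and then decaying as 1/sqrt(step) for InvSqrt.\nAssumes the repo's dependencies (torch, timm) are installed.\n\"\"\"\n\nimport torch\n\nfrom src.utils.optim.schedulers import InvSqrt\n\nif __name__ == \"__main__\":\n    param = torch.nn.Parameter(torch.zeros(1))\n    optimizer = torch.optim.SGD([param], lr=1e-3)\n    scheduler = InvSqrt(optimizer, warmup_step=4)\n\n    lrs = []\n    for _ in range(10):\n        lrs.append(optimizer.param_groups[0][\"lr\"])\n        optimizer.step()\n        scheduler.step()\n\n    # Ramps up over the first 4 steps, then decays roughly as base_lr / sqrt(step).\n    print([f\"{lr:.2e}\" for lr in lrs])\n"
  },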
  {
    "path": "src/utils/optim_groups.py",
    "content": "\"\"\"Utilities for special optimizer hyperparameters.\n\ngroup_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused\nadd_optimizer_hooks is an improved version that uses this codebase's _optim dictionary\n\"\"\"\n\nimport inspect\n\nimport torch.nn as nn\n\nimport hydra\n\n\ndef add_optimizer_hooks(\n    model,\n    bias_weight_decay=False,\n    normalization_weight_decay=False,\n):\n    \"\"\"Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with\n    attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for\n    normalization parameters if normalization_weight_decay==False\n    \"\"\"\n\n    # Separate out all parameters to those that will and won't experience regularizing weight decay\n    blacklist_weight_modules = (nn.Embedding, )\n    if not normalization_weight_decay:\n        blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,\n                                     # Not compatible with Pytorch 1.8.1\n                                     # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,\n                                     nn.GroupNorm, nn.SyncBatchNorm,\n                                     nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,\n                                     nn.LayerNorm, nn.LocalResponseNorm)\n    for mn, m in model.named_modules():\n        for pn, p in m.named_parameters():\n            if (not bias_weight_decay and pn.endswith('bias')) \\\n                    or getattr(p, '_no_weight_decay', False) \\\n                    or isinstance(m, blacklist_weight_modules):\n                        setattr(p, \"_optim\", {\"weight_decay\": 0.0})\n\n\ndef group_parameters_for_optimizer(\n    model,\n    optimizer_cfg,\n    bias_weight_decay=False,\n    normalization_weight_decay=False,\n):\n    \"\"\"Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with\n    attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for\n    normalization parameters if normalization_weight_decay==False\n    \"\"\"\n    # Get the weight decay from the config, or from the default value of the optimizer constructor\n    # if it's not specified in the config.\n    if 'weight_decay' in optimizer_cfg:\n        weight_decay = optimizer_cfg.weight_decay\n    else:\n        # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value\n        signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_))\n        if 'weight_decay' in signature.parameters:\n            weight_decay = signature.parameters['weight_decay'].default\n            if weight_decay is inspect.Parameter.empty:\n                weight_decay = 0.0\n        else:\n            weight_decay = 0.0\n\n    # If none of the parameters have weight decay anyway, and there are no parameters with special\n    # optimization params\n    if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):\n        return model.parameters()\n\n    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()\n    skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')\n                     else set())\n\n    # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134\n    \"\"\"\n    This long function is unfortunately doing something very simple and is being very defensive:\n    We are 
separating out all parameters of the model into two buckets: those that will experience\n    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).\n    We are then returning the PyTorch optimizer object.\n    \"\"\"\n\n    # separate out all parameters to those that will and won't experience regularizing weight decay\n    decay = set()\n    no_decay = set()\n    special = set()\n    whitelist_weight_modules = (nn.Linear, )\n    blacklist_weight_modules = (nn.Embedding, )\n    if not normalization_weight_decay:\n        blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,\n                                     # Not compatible with Pytorch 1.8.1\n                                     # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,\n                                     nn.GroupNorm, nn.SyncBatchNorm,\n                                     nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,\n                                     nn.LayerNorm, nn.LocalResponseNorm)\n    for mn, m in model.named_modules():\n        for pn, p in m.named_parameters():\n            fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n            if not p.requires_grad:\n                continue  # frozen weights\n            if hasattr(p, '_optim'):\n                special.add(fpn)\n            elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):\n                no_decay.add(fpn)\n            elif getattr(p, '_no_weight_decay', False):\n                no_decay.add(fpn)\n            elif not bias_weight_decay and pn.endswith('bias'):\n                no_decay.add(fpn)\n            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):\n                # weights of whitelist modules will be weight decayed\n                decay.add(fpn)\n            elif isinstance(m, blacklist_weight_modules):\n                # weights of blacklist modules will NOT be weight decayed\n                no_decay.add(fpn)\n\n    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}\n    # special case the position embedding parameter in the root GPT module as not decayed\n    if 'pos_emb' in param_dict:\n        no_decay.add('pos_emb')\n\n    # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys()\n    decay &= param_dict.keys()\n    decay |= (param_dict.keys() - no_decay - special)\n    # validate that we considered every parameter\n    inter_params = decay & no_decay\n    union_params = decay | no_decay\n    assert len(inter_params) == 0, f\"Parameters {str(inter_params)} made it into both decay/no_decay sets!\"\n    assert len(param_dict.keys() - special - union_params) == 0, f\"parameters {str(param_dict.keys() - union_params)}  were not separated into either decay/no_decay set!\"\n\n    if weight_decay == 0.0 or not no_decay:\n        param_groups = [{\"params\": [param_dict[pn] for pn in sorted(list(no_decay | decay))],\n                         \"weight_decay\": weight_decay}]\n    else:\n        param_groups = [\n            {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": weight_decay},\n            {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n        ]\n    # Add parameters with special hyperparameters\n    # Unique dicts\n    hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]\n    for hp in hps:\n        params = 
[param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]\n        param_groups.append({\"params\": params, **hp})\n\n    return param_groups\n"
  },
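  {
    "path": "examples/optim_groups_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file, not part of the original codebase).\n\nShows what add_optimizer_hooks (src.utils.optim_groups) does: it tags bias,\nnormalization, and embedding parameters with an `_optim = {\"weight_decay\": 0.0}`\nattribute, which the training setup later turns into separate optimizer parameter\ngroups. Assumes the repo's dependencies (torch, hydra) are installed.\n\"\"\"\n\nimport torch.nn as nn\n\nfrom src.utils.optim_groups import add_optimizer_hooks\n\nif __name__ == \"__main__\":\n    model = nn.Sequential(\n        nn.Embedding(10, 8),\n        nn.Linear(8, 8),\n        nn.LayerNorm(8),\n    )\n    # Defaults: bias_weight_decay=False, normalization_weight_decay=False.\n    add_optimizer_hooks(model)\n\n    for name, p in model.named_parameters():\n        print(f\"{name:10s} _optim={getattr(p, '_optim', None)}\")\n    # Expected: the embedding weight, the linear bias, and both LayerNorm parameters\n    # carry {\"weight_decay\": 0.0}; only the linear weight is left untagged.\n"
  },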
  {
    "path": "src/utils/registry.py",
    "content": "\"\"\"Class registry for models, layers, optimizers, and schedulers.\n\n\"\"\"\n\noptimizer = {\n    \"adam\": \"torch.optim.Adam\",\n    \"adamw\": \"torch.optim.AdamW\",\n    \"rmsprop\": \"torch.optim.RMSprop\",\n    \"sgd\": \"torch.optim.SGD\",\n    \"lamb\": \"src.utils.optim.lamb.JITLamb\",\n}\n\nscheduler = {\n    \"constant\": \"transformers.get_constant_schedule\",\n    \"plateau\": \"torch.optim.lr_scheduler.ReduceLROnPlateau\",\n    \"step\": \"torch.optim.lr_scheduler.StepLR\",\n    \"multistep\": \"torch.optim.lr_scheduler.MultiStepLR\",\n    \"cosine\": \"torch.optim.lr_scheduler.CosineAnnealingLR\",\n    \"constant_warmup\": \"transformers.get_constant_schedule_with_warmup\",\n    \"linear_warmup\": \"transformers.get_linear_schedule_with_warmup\",\n    \"cosine_warmup\": \"transformers.get_cosine_schedule_with_warmup\",\n    \"cosine_warmup_timm\": \"src.utils.optim.schedulers.TimmCosineLRScheduler\",\n}\n\nmodel = {\n    # Pre-training LM head models\n    \"hyena_lm\": \"src.models.sequence.long_conv_lm.ConvLMHeadModel\",\n    \"mamba_lm\": \"mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel\",\n    \"caduceus_lm\": \"caduceus.modeling_caduceus.CaduceusForMaskedLM\",\n\n    # Downstream task embedding backbones\n    \"dna_embedding\": \"src.models.sequence.dna_embedding.DNAEmbeddingModel\",\n    \"dna_embedding_mamba\": \"src.models.sequence.dna_embedding.DNAEmbeddingModelMamba\",\n    \"dna_embedding_caduceus\": \"src.models.sequence.dna_embedding.DNAEmbeddingModelCaduceus\",\n\n    # Baseline for genomics benchmark\n    \"genomics_benchmark_cnn\": \"src.models.baseline.genomics_benchmark_cnn.GenomicsBenchmarkCNN\",\n}\n\nlayer = {\n    \"id\": \"src.models.sequence.base.SequenceIdentity\",\n    \"ff\": \"src.models.sequence.ff.FF\",\n    \"hyena\": \"src.models.sequence.hyena.HyenaOperator\",\n    \"hyena-filter\": \"src.models.sequence.hyena.HyenaFilter\",\n}\n\ncallbacks = {\n    \"learning_rate_monitor\": \"pytorch_lightning.callbacks.LearningRateMonitor\",\n    \"model_checkpoint\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n    \"model_checkpoint_every_n_steps\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n    \"model_checkpoint_every_epoch\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n    \"early_stopping\": \"pytorch_lightning.callbacks.EarlyStopping\",\n    \"swa\": \"pytorch_lightning.callbacks.StochasticWeightAveraging\",\n    \"rich_model_summary\": \"pytorch_lightning.callbacks.RichModelSummary\",\n    \"rich_progress_bar\": \"pytorch_lightning.callbacks.RichProgressBar\",\n    \"params\": \"src.callbacks.params.ParamsLog\",\n    \"timer\": \"src.callbacks.timer.Timer\",\n    \"val_every_n_global_steps\": \"src.callbacks.validation.ValEveryNGlobalSteps\",\n}\n\nmodel_state_hook = {\n    'load_backbone': 'src.models.sequence.dna_embedding.load_backbone',\n}\n"
  },
  {
    "path": "src/utils/train.py",
    "content": "\"\"\" Utils for the training loop.\n\nCopied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py\n\"\"\"\n\nimport json\nimport logging\nimport warnings\n\nimport rich.syntax\nimport rich.tree\nimport torch.nn as nn\nfrom omegaconf import DictConfig, OmegaConf\nfrom pytorch_lightning.utilities import rank_zero_only\n\nfrom src.utils.config import omegaconf_filter_keys\n\n\n# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging\nclass LoggingContext:\n    def __init__(self, logger, level=None, handler=None, close=True):\n        self.logger = logger\n        self.level = level\n        self.handler = handler\n        self.close = close\n\n    def __enter__(self):\n        if self.level is not None:\n            self.old_level = self.logger.level\n            self.logger.setLevel(self.level)\n        if self.handler:\n            self.logger.addHandler(self.handler)\n\n    def __exit__(self, et, ev, tb):\n        if self.level is not None:\n            self.logger.setLevel(self.old_level)\n        if self.handler:\n            self.logger.removeHandler(self.handler)\n        if self.handler and self.close:\n            self.handler.close()\n        # implicit return of None => don't swallow exceptions\n\n\ndef get_logger(name=__name__, level=logging.INFO) -> logging.Logger:\n    \"\"\"Initializes multi-GPU-friendly python logger.x\"\"\"\n\n    logger = logging.getLogger(name)\n    logger.setLevel(level)\n\n    # this ensures all logging levels get marked with the rank zero decorator\n    # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n    for level in (\"debug\", \"info\", \"warning\", \"error\", \"exception\", \"fatal\", \"critical\"):\n        setattr(logger, level, rank_zero_only(getattr(logger, level)))\n\n    return logger\n\n\ndef process_config(config: DictConfig) -> DictConfig:  # TODO because of filter_keys, this is no longer in place\n    \"\"\"A couple of optional utilities, controlled by main config file:\n    - disabling warnings\n    - easier access to debug mode\n    - forcing debug friendly configuration\n    Modifies DictConfig in place.\n    Args:\n        config (DictConfig): Configuration composed by Hydra.\n    \"\"\"\n    log = get_logger()\n\n    # Filter out keys that were used just for interpolation\n    config = omegaconf_filter_keys(config, lambda k: not k.startswith('__'))\n\n    # enable adding new keys to config\n    OmegaConf.set_struct(config, False)\n\n    # disable python warnings if <config.ignore_warnings=True>\n    if config.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <config.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    if config.get(\"debug\"):\n        log.info(\"Running in debug mode! <config.debug=True>\")\n        config.trainer.fast_dev_run = True\n\n        # force debugger friendly configuration\n        log.info(\"Forcing debugger friendly configuration! 
<config.trainer.fast_dev_run=True>\")\n        # Debuggers don't like GPUs or multiprocessing\n        if config.trainer.get(\"gpus\"):\n            config.trainer.gpus = 0\n        if config.loader.get(\"pin_memory\"):\n            config.loader.pin_memory = False\n        if config.loader.get(\"num_workers\"):\n            config.loader.num_workers = 0\n\n    # disable adding new keys to config\n    # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things\n\n    return config\n\n\n@rank_zero_only\ndef print_config(\n        config: DictConfig,\n        resolve: bool = True,\n        save_cfg=True,\n) -> None:\n    \"\"\"Prints content of DictConfig using Rich library and its tree structure.\n    Args:\n        config (DictConfig): Configuration composed by Hydra.\n        resolve (bool, optional): Whether to resolve reference fields of DictConfig.\n        save_cfg (bool, optional): Whether to save the config to a file.\n    \"\"\"\n\n    style = \"dim\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    fields = config.keys()\n    for field in fields:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_section = config.get(field)\n        branch_content = str(config_section)\n        if isinstance(config_section, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    rich.print(tree)\n\n    if save_cfg:\n        with open(\"config_tree.txt\", \"w\") as fp:\n            rich.print(tree, file=fp)\n        with open(\"model_config.json\", \"w\") as fp:  # Save config / model config for use in fine-tuning or testing\n            model_config = {\n                k: v\n                for k, v in OmegaConf.to_container(config.model, resolve=True).items()\n                if not k.startswith(\"_\") or k == \"config_path\"\n            }\n            json.dump(model_config, fp, indent=4)\n        with open(\"config.json\", \"w\") as fp:\n            json.dump(OmegaConf.to_container(config, resolve=True), fp, indent=4)\n\n\ndef log_optimizer(logger, optimizer, keys):\n    \"\"\" Log values of particular keys from the optimizers param groups \"\"\"\n    keys = sorted(keys)\n    for i, g in enumerate(optimizer.param_groups):\n        group_hps = {k: g.get(k, None) for k in keys}\n        logger.info(' | '.join([\n                                   f\"Optimizer group {i}\",\n                                   f\"{len(g['params'])} tensors\",\n                               ] + [f\"{k} {v}\" for k, v in group_hps.items()]))\n\n\nclass OptimModule(nn.Module):\n    \"\"\" Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters \"\"\"\n\n    def register(self, name, tensor, lr=None, wd=0.0):\n        \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n\n        if lr == 0.0:\n            self.register_buffer(name, tensor)\n        else:\n            self.register_parameter(name, nn.Parameter(tensor))\n\n            optim = {}\n            if lr is not None:\n                optim[\"lr\"] = lr\n            if wd is not None:\n                optim[\"weight_decay\"] = wd\n            setattr(getattr(self, name), \"_optim\", optim)\n"
  },
  {
    "path": "train.py",
    "content": "\"\"\"Main training entry point for pre-training and downstream fine-tuning.\n\n\"\"\"\n\nimport json\nimport os\nimport random\nimport time\nfrom functools import wraps\nfrom typing import Callable, List, Sequence\n\nimport fsspec\nimport hydra\nimport pytorch_lightning as pl\nimport torch\nimport wandb\nfrom omegaconf import OmegaConf\nfrom pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning.utilities import rank_zero_only, rank_zero_warn\n\nimport src.models.nn.utils as U\nimport src.utils as utils\nimport src.utils.train\nfrom src.dataloaders import SequenceDataset  # TODO make registry\nfrom src.tasks import decoders, encoders, tasks\nfrom src.utils import registry\nfrom src.utils.optim_groups import add_optimizer_hooks\n\nlog = src.utils.train.get_logger(__name__)\n\n# Turn on TensorFloat32 (speeds up large model training substantially)\nimport torch.backends\n\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\n\nOmegaConf.register_new_resolver('eval', eval)\nOmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)\nOmegaConf.register_new_resolver('min', lambda x, y: min([x, y]))\n\n\n# Lots of annoying hacks to get WandbLogger to continuously retry on failure\nclass DummyExperiment:\n    \"\"\"Dummy experiment.\"\"\"\n\n    def nop(self, *args, **kw):\n        pass\n\n    def __getattr__(self, _):\n        return self.nop\n\n    def __getitem__(self, idx) -> \"DummyExperiment\":\n        # enables self.logger.experiment[0].add_image(...)\n        return self\n\n    def __setitem__(self, *args, **kwargs) -> None:\n        pass\n\n\ndef rank_zero_experiment(fn: Callable) -> Callable:\n    \"\"\"Returns the real experiment on rank 0 and otherwise the DummyExperiment.\"\"\"\n\n    @wraps(fn)\n    def experiment(self):\n        @rank_zero_only\n        def get_experiment():\n            return fn(self)\n\n        return get_experiment() or DummyExperiment()\n\n    return experiment\n\n\nclass CustomWandbLogger(WandbLogger):\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"Modified logger that insists on a wandb.init() call and catches wandb's error if thrown.\"\"\"\n\n        super().__init__(*args, **kwargs)\n\n    @property\n    @rank_zero_experiment\n    def experiment(self):\n        r\"\"\"\n        Actual wandb object. To use wandb features in your\n        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.\n        Example::\n            code-block:: python\n            self.logger.experiment.some_wandb_function()\n        \"\"\"\n        if self._experiment is None:\n            if self._offline:\n                os.environ[\"WANDB_MODE\"] = \"dryrun\"\n\n            attach_id = getattr(self, \"_attach_id\", None)\n            if wandb.run is not None:\n                # wandb process already created in this instance\n                rank_zero_warn(\n                    \"There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse\"\n                    \" this run. 
If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\"\n                )\n                self._experiment = wandb.run\n            elif attach_id is not None and hasattr(wandb, \"_attach\"):\n                # attach to wandb process referenced\n                self._experiment = wandb._attach(attach_id)\n            else:\n                # create new wandb process\n                while True:\n                    try:\n                        self._experiment = wandb.init(**self._wandb_init)\n                        break\n                    except Exception as e:\n                        log.error(\"wandb Exception:\\n\", e)\n                        t = random.randint(30, 60)\n                        log.warning(f\"Sleeping for {t} seconds\")\n                        time.sleep(t)\n\n                # define default x-axis\n                if getattr(self._experiment, \"define_metric\", None):\n                    self._experiment.define_metric(\"trainer/global_step\")\n                    self._experiment.define_metric(\"*\", step_metric=\"trainer/global_step\", step_sync=True)\n\n        return self._experiment\n\n\nclass SequenceLightningModule(pl.LightningModule):\n    def __init__(self, config):\n        # Disable profiling executor. This reduces memory and increases speed.\n        try:\n            torch._C._jit_set_profiling_executor(False)\n            torch._C._jit_set_profiling_mode(False)\n        except AttributeError:\n            pass\n\n        super().__init__()\n        # Passing in config expands it one level: access by self.hparams.train instead of self.hparams.config.train\n        self.save_hyperparameters(config, logger=False)\n\n        # Dataset arguments\n        self.dataset = SequenceDataset.registry[self.hparams.dataset._name_](\n            **self.hparams.dataset\n        )\n\n        # Check hparams\n        self._check_config()\n\n        # PL has some bugs, so add hooks and make sure they're only called once\n        self._has_setup = False\n\n        # To be set in `setup`\n        self.encoder, self.decoder, self.model = None, None, None\n        self.task, self.loss, self.loss_val = None, None, None\n        self.metrics, self.train_torchmetrics, self.val_torchmetrics, self.test_torchmetrics = None, None, None, None\n        self.setup()\n\n        self._state = None\n        self.val_loader_names, self.test_loader_names = None, None\n\n    def setup(self, stage=None):\n        if not self.hparams.train.disable_dataset:\n            self.dataset.setup()\n\n        # We need to set up the model in setup() because for some reason when training with DDP, one GPU uses much more\n        # memory than the others.\n        # In order to not overwrite the model multiple times during different stages, we need this hack\n        # TODO PL 1.5 seems to have an option to skip hooks to avoid this\n        # https://github.com/PyTorchLightning/pytorch-lightning/issues/5410#issuecomment-762257024\n        if self._has_setup:\n            return\n        else:\n            self._has_setup = True\n\n        # Convenience feature: if model specifies encoder, combine it with main encoder\n        encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list(\n            self.hparams.model.pop(\"encoder\", None)\n        )\n        decoder_cfg = utils.to_list(\n            self.hparams.model.pop(\"decoder\", None)\n        ) + utils.to_list(self.hparams.decoder)\n\n        # Instantiate model\n        config_path = 
self.hparams.model.pop(\"config_path\", None)\n        if config_path is not None:\n            with open(config_path) as f:\n                model_config_from_file = json.load(f)\n            self.hparams.model.update(model_config_from_file)\n            # Check if dropout_layer_norm is compiled\n            try:\n                from flash_attn.ops.layer_norm import dropout_add_layer_norm\n            except ImportError:\n                if self.hparams.model.get(\"fused_dropout_add_ln\", None) is not None:\n                    self.hparams.model.update({\"fused_dropout_add_ln\": False})\n        # TODO: Hacky way to get complement_map for Caduceus models; need to find a more elegant implementation\n        if \"caduceus\" in self.hparams.model.get(\"_name_\"):\n            OmegaConf.update(\n                self.hparams.model.config, \"complement_map\", self.dataset.tokenizer.complement_map, force_add=True\n            )\n        # Instantiate the config class if using hydra's _target_ paradigm for the config\n        if self.hparams.model.get(\"config\", None) is not None and self.hparams.model.config.get(\"_target_\", None) is not None:\n            model_hparams = OmegaConf.to_container(self.hparams.model, resolve=True)\n            model_hparams[\"config\"] = hydra.utils.instantiate(model_hparams[\"config\"])\n            self.model = utils.instantiate(registry.model, model_hparams)\n        else:\n            self.model = utils.instantiate(registry.model, self.hparams.model)\n        if (name := self.hparams.train.post_init_hook['_name_']) is not None:\n            kwargs = self.hparams.train.post_init_hook.copy()\n            del kwargs['_name_']\n            for module in self.modules():\n                if hasattr(module, name):\n                    getattr(module, name)(**kwargs)\n\n        # if self.hparams.train.get(\"compile_model\", False):\n        #     self.model = torch.compile(self.model, dynamic=False)\n\n        # Instantiate the task\n        self.task = utils.instantiate(\n            tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model\n        )\n\n        # Create encoders and decoders\n        encoder = encoders.instantiate(\n            encoder_cfg, dataset=self.dataset, model=self.model\n        )\n        decoder = decoders.instantiate(\n            decoder_cfg, model=self.model, dataset=self.dataset\n        )\n\n        # Extract the modules, so they show up in the top level parameter count\n        self.encoder = U.PassthroughSequential(self.task.encoder, encoder)\n        self.decoder = U.PassthroughSequential(decoder, self.task.decoder)\n        self.loss = self.task.loss\n        self.loss_val = self.task.loss\n        if hasattr(self.task, 'loss_val'):\n            self.loss_val = self.task.loss_val\n        self.metrics = self.task.metrics\n        self.train_torchmetrics = self.task.train_torchmetrics\n        self.val_torchmetrics = self.task.val_torchmetrics\n        self.test_torchmetrics = self.task.test_torchmetrics\n\n    def load_state_dict(self, state_dict, strict=False):\n        if self.hparams.train.pretrained_model_state_hook['_name_'] is not None:\n            model_state_hook = utils.instantiate(\n                registry.model_state_hook,\n                self.hparams.train.pretrained_model_state_hook.copy(),\n                partial=True,\n            )\n            state_dict = model_state_hook(self.model, state_dict)\n\n        log.info(\"Custom load_state_dict function is running.\")\n\n        # strict==True 
will require all modules to match\n        # strict==False can allow encoder/decoder to be loaded from scratch too\n        return super().load_state_dict(state_dict, strict=strict)\n\n    def _check_config(self):\n        assert self.hparams.train.state.mode in [None, \"none\", \"null\", \"reset\", \"bptt\", \"tbptt\"]\n        assert (\n                (n := self.hparams.train.state.n_context) is None\n                or isinstance(n, int)\n                and n >= 0\n        )\n        assert (\n                (n := self.hparams.train.state.n_context_eval) is None\n                or isinstance(n, int)\n                and n >= 0\n        )\n\n    def _initialize_state(self):\n        \"\"\"Called at model setup and start of epoch to completely reset state\"\"\"\n        self._state = None\n        self._memory_chunks = []\n\n    def _reset_state(self, batch, device=None):\n        \"\"\"Called to construct default_state when necessary, e.g. during BPTT\"\"\"\n        device = device or batch[0].device\n        self._state = self.model.default_state(*batch[0].shape[:1], device=device)\n\n    def _detach_state(self, state):\n        if isinstance(state, torch.Tensor):\n            return state.detach()\n        elif isinstance(state, tuple):\n            return tuple(self._detach_state(s) for s in state)\n        elif isinstance(state, list):\n            return [self._detach_state(s) for s in state]\n        elif isinstance(state, dict):\n            return {k: self._detach_state(v) for k, v in state.items()}\n        elif state is None:\n            return None\n        else:\n            raise NotImplementedError\n\n    def _process_state(self, batch, batch_idx, training=True):\n        \"\"\"Handle logic for state context.\"\"\"\n        # Number of context steps\n        key = \"n_context\" if training else \"n_context_eval\"\n        n_context = self.hparams.train.state.get(key)\n\n        # Don't need to do anything if 0 context steps. 
Make sure there is no state\n        if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']:\n            self._initialize_state()\n            return\n\n        # Reset state if needed\n        if self.hparams.train.state.mode == \"reset\":\n            if batch_idx % (n_context + 1) == 0:\n                self._reset_state(batch)\n\n        # Pass through memory chunks\n        elif self.hparams.train.state.mode == \"bptt\":\n            self._reset_state(batch)\n            with torch.no_grad():  # should be unnecessary because individual modules should handle this\n                for _batch in self._memory_chunks:\n                    self.forward(_batch)\n            # Prepare for next step\n            self._memory_chunks.append(batch)\n            self._memory_chunks = self._memory_chunks[-n_context:]\n\n        elif self.hparams.train.state.mode == 'tbptt':\n            _, _, z = batch\n            reset = z[\"reset\"]\n            if reset:\n                self._reset_state(batch)\n            else:\n                self._state = self._detach_state(self._state)\n\n    def forward(self, batch):\n        return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state)\n\n    def step(self, x_t):\n        x_t, *_ = self.encoder(x_t)  # Potential edge case for encoders that expect (B, L, H)?\n        x_t, state = self.model.step(x_t, state=self._state)\n        self._state = state\n        x_t, *_ = self.decoder.step(x_t, state=state)\n        return x_t\n\n    def _shared_step(self, batch, batch_idx, prefix=\"train\"):\n        \"\"\"Shared step logic between training, validation, and test\"\"\"\n        self._process_state(batch, batch_idx, training=(prefix == \"train\"))\n        x, y, w = self.forward(batch)\n\n        # Loss\n        if prefix == 'train':\n            loss = self.loss(x, y, **w)\n        else:\n            loss = self.loss_val(x, y, **w)\n\n        # Metrics\n        metrics = self.metrics(x, y, **w)\n        metrics[\"loss\"] = loss\n        metrics = {f\"{prefix}/{k}\": v for k, v in metrics.items()}\n\n        # Calculate torchmetrics\n        torchmetrics = getattr(self, f'{prefix}_torchmetrics')\n        torchmetrics(x, y, loss=loss)\n\n        log_on_step = 'eval' in self.hparams and self.hparams.eval.get('log_on_step', False) and prefix == 'train'\n\n        self.log_dict(\n            metrics,\n            on_step=log_on_step,\n            on_epoch=True,\n            prog_bar=True,\n            add_dataloader_idx=False,\n            sync_dist=True,\n        )\n\n        # log the whole dict, otherwise lightning takes the mean to reduce it\n        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training\n        self.log_dict(\n            torchmetrics,\n            on_step=log_on_step,\n            on_epoch=True,\n            prog_bar=True,\n            add_dataloader_idx=False,\n            sync_dist=True,\n        )\n        return loss\n\n    def on_train_epoch_start(self):\n        # Reset training torchmetrics\n        self.task._reset_torchmetrics(\"train\")\n\n    def training_epoch_end(self, outputs):\n        # Log training torchmetrics\n        super().training_epoch_end(outputs)\n\n    def on_validation_epoch_start(self):\n        # Reset all validation torchmetrics\n        for name in self.val_loader_names:\n            self.task._reset_torchmetrics(name)\n\n    def validation_epoch_end(self, outputs):\n        # Log all validation 
torchmetrics\n        super().validation_epoch_end(outputs)\n\n    def on_test_epoch_start(self):\n        # Reset all test torchmetrics\n        for name in self.test_loader_names:\n            self.task._reset_torchmetrics(name)\n\n    def test_epoch_end(self, outputs):\n        # Log all test torchmetrics\n        super().test_epoch_end(outputs)\n\n    def training_step(self, batch, batch_idx, dataloader_idx=0):\n        loss = self._shared_step(batch, batch_idx, prefix=\"train\")\n\n        # Log the loss explicitly so that it shows up in WandB\n        # Note that this currently runs into a bug in the progress bar with ddp (as of 1.4.6)\n        # https://github.com/PyTorchLightning/pytorch-lightning/pull/9142\n        # We additionally log the epochs under 'trainer' to get a consistent prefix with 'global_step'\n        loss_epoch = {\"trainer/loss\": loss, \"trainer/epoch\": float(self.current_epoch)}\n        self.log_dict(\n            loss_epoch,\n            on_step=True,\n            on_epoch=False,\n            prog_bar=False,\n            add_dataloader_idx=False,\n            sync_dist=True,\n        )\n\n        # Log any extra info that the models want to expose (e.g. output norms)\n        metrics = {}\n        for module in list(self.modules())[1:]:\n            if hasattr(module, \"metrics\"):\n                metrics.update(module.metrics)\n\n        self.log_dict(\n            metrics,\n            on_step=True,\n            on_epoch=False,\n            prog_bar=False,\n            add_dataloader_idx=False,\n            sync_dist=True,\n        )\n        return loss\n\n    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n        # There's a bit of an annoying edge case with the first (0-th) epoch; it has to be excluded due to the initial\n        # sanity check\n        ema = (\n                self.val_loader_names[dataloader_idx].endswith(\"/ema\")\n                and self.optimizers().optimizer.stepped\n        )\n        if ema:\n            self.optimizers().swap_ema()\n        loss = self._shared_step(\n            batch, batch_idx, prefix=self.val_loader_names[dataloader_idx]\n        )\n        if ema:\n            self.optimizers().swap_ema()\n\n        return loss\n\n    def test_step(self, batch, batch_idx, dataloader_idx=0):\n        return self._shared_step(\n            batch, batch_idx, prefix=self.test_loader_names[dataloader_idx]\n        )\n\n    def configure_optimizers(self):\n        # Set zero weight decay for some params\n        if 'optimizer_param_grouping' in self.hparams.train:\n            add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping)\n\n        # Normal parameters\n        all_params = list(self.parameters())\n        params = [p for p in all_params if not hasattr(p, \"_optim\")]\n\n        optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params)\n\n        del self.hparams.optimizer._name_\n\n        # Add parameters with special hyperparameters\n        hps = [getattr(p, \"_optim\") for p in all_params if hasattr(p, \"_optim\")]\n        hps = [\n            # dict(s) for s in set(frozenset(hp.items()) for hp in hps)\n            dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))\n            # dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps)\n        ]  # Unique dicts\n        print(\"Hyperparameter groups:\", hps)  # TODO: log.info throws error because hps is list of dicts\n        for hp in hps:\n            
params = [p for p in all_params if getattr(p, \"_optim\", None) == hp]\n            optimizer.add_param_group(\n                {\"params\": params, **self.hparams.optimizer, **hp}\n            )\n\n        # Layer Decay\n        if self.hparams.train.layer_decay['_name_'] is not None:\n            get_num_layer = utils.instantiate(\n                registry.layer_decay,\n                self.hparams.train.layer_decay['_name_'],\n                partial=True,\n            )\n\n            # Go through all parameters and get num layer\n            layer_wise_groups = {}\n            num_max_layers = 0\n            for name, p in self.named_parameters():\n                # Get layer id for each parameter in the model\n                layer_id = get_num_layer(name)\n\n                # Add to layer wise group\n                if layer_id not in layer_wise_groups:\n                    layer_wise_groups[layer_id] = {\n                        'params': [],\n                        'lr': None,\n                        'weight_decay': self.hparams.optimizer.weight_decay\n                    }\n                layer_wise_groups[layer_id]['params'].append(p)\n\n                if layer_id > num_max_layers:\n                    num_max_layers = layer_id\n\n            # Update lr for each layer\n            for layer_id, group in layer_wise_groups.items():\n                group['lr'] = self.hparams.optimizer.lr * (\n                        self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id))\n\n            # Reset the torch optimizers param groups\n            optimizer.param_groups = []\n            for layer_id, group in layer_wise_groups.items():\n                optimizer.add_param_group(group)\n\n        # Print optimizer info for debugging\n        keys = set([k for hp in hps for k in hp.keys()])  # Special hparams\n        utils.train.log_optimizer(log, optimizer, keys)\n        # Configure scheduler\n        if \"scheduler\" not in self.hparams:\n            return optimizer\n        lr_scheduler = utils.instantiate(\n            registry.scheduler, self.hparams.scheduler, optimizer\n        )\n        scheduler = {\n            \"scheduler\": lr_scheduler,\n            \"interval\": self.hparams.train.interval,  # 'epoch' or 'step'\n            \"monitor\": self.hparams.train.monitor,\n            \"name\": \"trainer/lr\",  # default is e.g. 
'lr-AdamW'\n        }\n        # See documentation for how to configure the return\n        # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers\n        return [optimizer], [scheduler]\n\n    def train_dataloader(self):\n        return self.dataset.train_dataloader(**self.hparams.loader)\n\n    def _eval_dataloaders_names(self, loaders, prefix):\n        \"\"\"Process loaders into a list of names and loaders\"\"\"\n        if utils.is_dict(loaders):\n            return [\n                f\"{prefix}/{k}\" if k is not None else prefix for k in loaders.keys()\n            ], list(loaders.values())\n        elif utils.is_list(loaders):\n            return [f\"{prefix}/{i}\" for i in range(len(loaders))], loaders\n        else:\n            return [prefix], [loaders]\n\n    def _eval_dataloaders(self):\n        # Return all val + test loaders\n        val_loaders = self.dataset.val_dataloader(**self.hparams.loader)\n        test_loaders = self.dataset.test_dataloader(**self.hparams.loader)\n        val_loader_names, val_loaders = self._eval_dataloaders_names(val_loaders, \"val\")\n        test_loader_names, test_loaders = self._eval_dataloaders_names(\n            test_loaders, \"test\"\n        )\n\n        # Duplicate datasets for ema\n        if self.hparams.train.ema > 0.0:\n            val_loader_names += [name + \"/ema\" for name in val_loader_names]\n            val_loaders = val_loaders + val_loaders\n            test_loader_names += [name + \"/ema\" for name in test_loader_names]\n            test_loaders = test_loaders + test_loaders\n\n        # adding option to only have val loader at eval (e.g., if test is duplicate)\n        eval_loader_names = []\n        eval_loaders = []\n        if not self.hparams.train.get(\"remove_val_loader_in_eval\", False):\n            eval_loader_names += val_loader_names\n            eval_loaders += val_loaders\n        if not self.hparams.train.get(\"remove_test_loader_in_eval\", False):\n            eval_loader_names += test_loader_names\n            eval_loaders += test_loaders\n        return eval_loader_names, eval_loaders\n\n    def val_dataloader(self):\n        val_loader_names, val_loaders = self._eval_dataloaders()\n        self.val_loader_names = val_loader_names\n        return val_loaders\n\n    def test_dataloader(self):\n        test_loader_names, test_loaders = self._eval_dataloaders()\n        self.test_loader_names = [\"final/\" + name for name in test_loader_names]\n        return test_loaders\n\n\n# pytorch-lightning utils and entrypoint\ndef create_trainer(config, **kwargs):\n    callbacks: List[pl.Callback] = []\n    logger = None\n\n    # WandB Logging\n    if config.get(\"wandb\") is not None:\n        # Pass in wandb.init(config=) argument to get the nice 'x.y.0.z' hparams logged\n        # Can pass in config_exclude_keys='wandb' to remove certain groups\n        import wandb\n\n        logger = CustomWandbLogger(\n            config=utils.to_dict(config, recursive=True),\n            settings=wandb.Settings(start_method=\"fork\"),\n            **config.wandb,\n        )\n\n    # Lightning callbacks\n    if \"callbacks\" in config:\n        for _name_, callback in config.callbacks.items():\n            if config.get(\"wandb\") is None and _name_ in [\"learning_rate_monitor\"]:\n                continue\n            log.info(f\"Instantiating callback <{registry.callbacks[_name_]}>\")\n            callback._name_ = 
_name_\n            callbacks.append(utils.instantiate(registry.callbacks, callback))\n\n    # Add ProgressiveResizing callback\n    if config.callbacks.get(\"progressive_resizing\", None) is not None:\n        num_stages = len(config.callbacks.progressive_resizing.stage_params)\n        log.info(f\"Progressive Resizing: {num_stages} stages\")\n        for i, e in enumerate(config.callbacks.progressive_resizing.stage_params):\n            # Stage params are resolution and epochs, pretty print\n            log.info(f\"\\tStage {i}: {e['resolution']} @ {e['epochs']} epochs\")\n\n    # Configure ddp automatically\n    n_devices = config.trainer.get('devices', 1)\n    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example\n        n_devices = len(n_devices)\n    if n_devices > 1 and config.trainer.get('strategy', None) is None:\n        config.trainer.strategy = dict(\n            _target_='pytorch_lightning.strategies.DDPStrategy',\n            find_unused_parameters=False,\n            # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations\n            gradient_as_bucket_view=True,\n        )\n\n    # Init lightning trainer\n    log.info(f\"Instantiating trainer <{config.trainer._target_}>\")\n    # special processing for seqlen warmup scheduler (reload)\n    trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=logger)\n\n    return trainer\n\n\ndef fsspec_exists(filename):\n    fs, _ = fsspec.core.url_to_fs(filename)\n    return fs.exists(filename)\n\n\ndef train(config):\n    if config.train.seed is not None:\n        pl.seed_everything(config.train.seed, workers=True)\n    trainer = create_trainer(config)\n    model = SequenceLightningModule(config)\n\n    # Load pretrained_model if specified\n    if config.train.get(\"pretrained_model_path\", None) is not None:\n        # PTL style.  
Note, method returns a new model object, and need to pass config.\n        model = SequenceLightningModule.load_from_checkpoint(\n            config.train.pretrained_model_path,\n            config=config,\n            strict=config.train.pretrained_model_strict_load,\n        )\n\n    # Run initial validation epoch (useful for debugging, fine-tuning)\n    if config.train.validate_at_start:\n        log.info(\"Running validation before training\")\n        trainer.validate(model)\n\n    log.info(f'{config.train.ckpt=} {fsspec_exists(config.train.ckpt)=}')\n    # if config.train.get(\"compile_model\", False):\n    #     model = torch.compile(model, mode=\"reduce-overhead\")\n    if config.train.ckpt is not None and fsspec_exists(config.train.ckpt):\n        trainer.fit(model, ckpt_path=config.train.ckpt)\n    else:\n        trainer.fit(model)\n\n    if config.train.test:\n        if config.train.get(\"cross_validation\", False):  # First, load the best validation model\n            best_val_ckpt = os.path.join(\n                model.hparams.callbacks.model_checkpoint.dirpath,\n                f\"{model.hparams.callbacks.model_checkpoint.filename}.ckpt\",\n            )\n            # Update config so we do not load just the backbone\n            config.train.pretrained_model_state_hook.update({\"_name_\": None})\n            # Remove validation loader\n            config.train.update({\"remove_val_loader_in_eval\": True})\n            config.train.update({\"remove_test_loader_in_eval\": False})\n            ckpt = torch.load(best_val_ckpt)\n            log.info(f\"Loaded best validation checkpoint from epoch {ckpt['epoch']}\")\n            trainer.validate(model, ckpt_path=best_val_ckpt)\n        else:\n            trainer.validate(model)\n\n\n@hydra.main(config_path=\"configs\", config_name=\"config.yaml\")\ndef main(config: OmegaConf):\n    # Process config:\n    # - register evaluation resolver\n    # - filter out keys used only for interpolation\n    # - optional hooks, including disabling python warnings or debug friendly configuration\n    config = utils.train.process_config(config)\n    # if config.train.get(\"compile_model\", False):\n    #     # See: https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops\n    #     from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1\n    #     allow_ops_in_compiled_graph()\n\n    # Pretty print config using Rich library\n    utils.train.print_config(config, resolve=True)\n\n    train(config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "vep_embeddings.py",
    "content": "\"\"\"Dump model embeddings for VEP classification task.\n\n\"\"\"\n\nimport argparse\nimport os\nfrom functools import partial\nfrom os import path as osp\nfrom typing import Dict, Iterable, Optional\n\nimport enformer_pytorch\nimport fsspec\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom datasets import load_dataset, load_from_disk\nfrom sklearn import preprocessing\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom tqdm.auto import tqdm\nfrom transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, DefaultDataCollator\n\nfrom src.dataloaders.utils.rc import string_reverse_complement\nfrom src.utils.train import get_logger\n\nWINDOW_SIZE_BP = 1536\nlog = get_logger(__name__)\n\n\nclass DNAEmbeddingModel(nn.Module):\n    \"\"\"Wrapper around HF model.\n\n    Args:\n        model_name_or_path: str, path to HF model.\n    \"\"\"\n    def __init__(\n            self,\n            model_name_or_path: str,\n    ):\n        super().__init__()\n        self.model_name_or_path = model_name_or_path\n        # Enformer uses different library for loading\n        if \"enformer\" in model_name_or_path.lower():\n            self.backbone = enformer_pytorch.from_pretrained(\n                model_name_or_path,\n                use_tf_gamma=False,\n                use_checkpointing=True\n            )\n        # NT model is not compatible with AutoModel class\n        elif \"nucleotide-transformer\" in model_name_or_path.lower():\n            # NT LM `backbone` is under the `.esm` attribute\n            self.backbone = AutoModelForMaskedLM.from_pretrained(model_name_or_path, trust_remote_code=True).esm\n        else:\n            self.backbone = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)\n\n    def forward(self, input_ids):\n        \"\"\"Backbone forward pass to retrieve last_hidden_state.\"\"\"\n        if \"enformer\" in self.model_name_or_path.lower():\n            # Enformer forward pass has different signature\n            return self.backbone(input_ids, return_embeddings=True)[1]\n        return self.backbone(input_ids).last_hidden_state\n\nclass EnformerTokenizer:\n    \"\"\"Enformer tokenizer.\"\"\"\n    # Order is important here! 
(See: https://github.com/lucidrains/enformer-pytorch?tab=readme-ov-file#usage)\n    pad_token = \"P\"  # Padding token should be a character to avoid issues with tokenization\n    encode_map = {\"A\": 0, \"C\": 1, \"G\": 2, \"T\": 3, \"N\": 4, pad_token: -1}\n\n    @classmethod\n    def encode(\n            cls, seq: str, max_length: Optional[int] = None, truncation: Optional[bool] = False\n    ) -> Iterable[int]:\n        \"\"\"Convert bp to token ids.\"\"\"\n        if max_length is not None:\n            assert max_length >= 0, \"max_length should be a positive integer.\"\n            if len(seq) < max_length:\n                seq = seq + cls.pad_token * (max_length - len(seq))\n            elif truncation:\n                seq = seq[:max_length]\n        return [cls.encode_map[bp] for bp in seq.upper()]\n\n    @classmethod\n    def batch_encode_plus(\n            cls, seqs: Iterable[str], max_length: Optional[int] = None, truncation: Optional[bool] = False,\n            **kwargs,  # ensures compatibility with HF tokenizer-like API\n    ) -> Dict[str, Iterable[Iterable[int]]]:\n        \"\"\"Batch encode sequences using HF tokenizer-like API.\"\"\"\n        input_ids = [cls.encode(seq, max_length=max_length, truncation=truncation) for seq in seqs]\n        return {\"input_ids\": input_ids}\n\n\ndef setup_distributed():\n    \"\"\"Set environment variables for distributed runs.\"\"\"\n    dist.init_process_group(\"nccl\")\n\n\ndef cleanup_distributed():\n    \"\"\"Clean up processes from distributed runs.\"\"\"\n    dist.destroy_process_group()\n\n\ndef fsspec_exists(filename):\n    \"\"\"Check if file exists in manner compatible with fsspec.\"\"\"\n    fs, _ = fsspec.core.url_to_fs(filename)\n    return fs.exists(filename)\n\n\ndef fsspec_listdir(dirname):\n    \"\"\"Listdir in manner compatible with fsspec.\"\"\"\n    fs, _ = fsspec.core.url_to_fs(dirname)\n    return fs.ls(dirname)\n\n\n# Processing functions\ndef recast_chromosome_tissue_dist2TSS(examples):\n    \"\"\"Recast chromosome to int.\"\"\"\n    return {\n        \"chromosome\": -1 if examples[\"chromosome\"] == \"X\" else int(examples[\"chromosome\"]),\n        \"tissue\": examples[\"tissue\"],\n        \"distance_to_nearest_tss\": examples[\"distance_to_nearest_tss\"]\n    }\n\n\ndef tokenize_variants(examples, tokenizer, max_length: int):\n    \"\"\"Tokenize sequence.\n\n    Args:\n        examples: (batch of) items from the dataset.\n        tokenizer: AutoTokenizer.\n        max_length: int.\n    Returns:\n        dict with values as list of token ids.\n    \"\"\"\n\n    ref_tokenized = tokenizer.batch_encode_plus(\n        examples[\"ref_forward_sequence\"],\n        add_special_tokens=False,\n        return_attention_mask=False,\n        max_length=max_length,\n        truncation=True,\n    )\n    alt_tokenized = tokenizer.batch_encode_plus(\n        examples[\"alt_forward_sequence\"],\n        add_special_tokens=False,\n        return_attention_mask=False,\n        max_length=max_length,\n        truncation=True,\n    )\n    ref_rc_tokenized = tokenizer.batch_encode_plus(\n        [string_reverse_complement(seq) for seq in examples[\"ref_forward_sequence\"]],\n        add_special_tokens=False,\n        return_attention_mask=False,\n        max_length=max_length,\n        truncation=True,\n    )\n    alt_rc_tokenized = tokenizer.batch_encode_plus(\n        [string_reverse_complement(seq) for seq in examples[\"alt_forward_sequence\"]],\n        add_special_tokens=False,\n        return_attention_mask=False,\n        
max_length=max_length,\n        truncation=True,\n    )\n\n    return {\n        \"ref_input_ids\": ref_tokenized[\"input_ids\"],\n        \"alt_input_ids\": alt_tokenized[\"input_ids\"],\n        \"ref_rc_input_ids\": ref_rc_tokenized[\"input_ids\"],\n        \"alt_rc_input_ids\": alt_rc_tokenized[\"input_ids\"],\n    }\n\n\ndef find_variant_idx(examples):\n    \"\"\"Find token location that differs between reference and variant sequence.\n\n    Args:\n        examples: items from the dataset (not batched).\n    Returns:\n        dict with values index of difference.\n    \"\"\"\n    # Guess that variant is at halfway point\n    idx = len(examples[\"ref_input_ids\"]) // 2\n    if examples[\"ref_input_ids\"][idx] == examples[\"alt_input_ids\"][idx]:\n        # If no, loop through sequence and find variant location\n        idx = -1\n        for i, (ref, alt) in enumerate(zip(examples[\"ref_input_ids\"], examples[\"alt_input_ids\"])):\n            if ref != alt:\n                idx = i\n    # Same as above, but for reverse complement\n    rc_idx = len(examples[\"ref_rc_input_ids\"]) // 2 - 1\n    if examples[\"ref_rc_input_ids\"][rc_idx] == examples[\"alt_rc_input_ids\"][rc_idx]:\n        rc_idx = -1\n        for i, (ref, alt) in enumerate(zip(examples[\"ref_rc_input_ids\"], examples[\"alt_rc_input_ids\"])):\n            if ref != alt:\n                rc_idx = i\n    return {\"variant_idx\": idx, \"rc_variant_idx\": rc_idx}\n\n\ndef prepare_dataset(args, tokenizer):\n    \"\"\"Prepare or load the tokenized dataset.\"\"\"\n    # Data Preprocessing\n    num_tokens = args.seq_len // args.bp_per_token\n\n    # Load data\n    cache_dir = osp.join(\n        os.getenv(\"HF_HOME\"), \"datasets\", \"InstaDeepAI___genomics-long-range-benchmark\",\n        \"variant_effect_gene_expression\", f\"seqlen={args.seq_len}\"\n    )\n    if \"nucleotide-transformer\" in args.model_name_or_path.lower():  # NT uses 6-mers, so tokenization is different\n        preprocessed_cache_file = osp.join(cache_dir, \"6mer_token_preprocessed\")\n\n    elif \"enformer\" in args.model_name_or_path.lower():\n        # Enformer tokenization requires having vocab of just `A,C,G,T,N` (in that order)\n        preprocessed_cache_file = osp.join(cache_dir, \"enformer_char_token_preprocessed\")\n    else:\n        preprocessed_cache_file = osp.join(cache_dir, \"char_token_preprocessed\")\n    log.warning(f\"Cache dir: {cache_dir}\")\n    log.warning(f\"Cache dir preprocessed: {preprocessed_cache_file}\")\n\n    if not fsspec_exists(preprocessed_cache_file):\n        if dist.get_rank() == 0:\n            dataset = load_dataset(\n                \"InstaDeepAI/genomics-long-range-benchmark\",\n                task_name=\"variant_effect_gene_expression\",\n                sequence_length=args.seq_len,\n                load_from_cache=False,\n            )\n            log.warning(\"Dataset loaded. 
Cached to disk:\")\n            log.warning(osp.dirname(list(dataset.cache_files.values())[0][0][\"filename\"]))\n            try:\n                del dataset[\"validation\"]  # `validation` split is empty\n            except KeyError:\n                pass\n\n            # Process data\n            dataset = dataset.filter(\n                lambda example: example[\"ref_forward_sequence\"].count('N') < 0.005 * args.seq_len,\n                desc=\"Filter N's\"\n            )\n            dataset = dataset.map(\n                recast_chromosome_tissue_dist2TSS,\n                remove_columns=[\"chromosome\", \"tissue\", \"distance_to_nearest_tss\"],\n                desc=\"Recast chromosome\"\n            )\n            dataset = dataset.map(\n                partial(tokenize_variants, tokenizer=tokenizer, max_length=num_tokens),\n                batch_size=1000,\n                batched=True,\n                remove_columns=[\"ref_forward_sequence\", \"alt_forward_sequence\"],\n                desc=\"Tokenize\"\n            )\n            dataset = dataset.map(find_variant_idx, desc=\"Find variant idx\")\n            dataset.save_to_disk(preprocessed_cache_file)\n    dist.barrier()  # Processes need to wait for dataset to be saved to disk (if not already done)\n    dataset = load_from_disk(preprocessed_cache_file)\n    log.warning(f\"Loaded preprocessed dataset from {preprocessed_cache_file}\")\n    log.warning(dataset)\n    return dataset\n\n\ndef get_backbone_model(args, device):\n    \"\"\"Get the backbone model.\"\"\"\n\n    model = DNAEmbeddingModel(\n        model_name_or_path=args.model_name_or_path,\n    )\n    model.eval()\n    return DDP(model.to(device))\n\n\ndef concat_storage_dict_values(storage_dict):\n    \"\"\"Helper method that combines lists of tensors in storage_dict into a single torch.Tensor.\"\"\"\n    return {key: torch.cat(storage_dict[key], dim=0) for key in storage_dict.keys()}\n\n\ndef dump_embeddings(args, dataset, model, device):\n    \"\"\"Dump embeddings to disk.\"\"\"\n    def extract_embeddings(item_ref, item_alt, variant_idx):\n        \"\"\"Extract embedding representation from last layer outputs\n\n        Args:\n            item_ref: torch.Tensor, shape (batch_size, seq_len, hidden_size) Ref embedding\n            item_alt: torch.Tensor, shape (batch_size, seq_len, hidden_size) Alt embedding\n            variant_idx: torch.Tensor, shape (batch_size,) Index of variant\n        Returns:\n            layer_metrics: dict, with values to save to disk\n        \"\"\"\n        layer_metrics = {}\n\n        # Compute windowed statistics\n        if \"enformer\" in args.model_name_or_path.lower():\n            window_size = WINDOW_SIZE_BP // 128  # Enformer's receptive field is 128\n            # We also need to override variant_idx since Enformer model reduces to target_length of 896\n            variant_idx = torch.ones_like(variant_idx) * item_ref.size(1) // 2\n        else:\n            window_size = WINDOW_SIZE_BP // args.bp_per_token\n\n        # Add 1 so that window is: [window // 2 - SNP - window // 2]\n        start, end = -window_size // 2, window_size // 2 + 1\n        expanded_indices = torch.arange(start, end, device=item_ref.device).unsqueeze(0) + \\\n                           variant_idx.unsqueeze(1).to(item_ref.device)\n        expanded_indices = torch.clamp(expanded_indices, 0, item_ref.size(1) - 1)  # Handle boundary conditions\n        tokens_window_ref = torch.gather(\n            item_ref, 1,\n            
expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2))\n        ).mean(dim=1)\n        tokens_window_alt = torch.gather(\n            item_alt, 1,\n            expanded_indices.unsqueeze(-1).expand(-1, -1, item_ref.size(2))\n        ).mean(dim=1)\n        layer_metrics[\"concat_avg_ws\"] = torch.cat([tokens_window_ref, tokens_window_alt], dim=-1)\n        return layer_metrics\n\n    embeds_path = osp.join(args.downstream_save_dir, args.name)\n    os.makedirs(embeds_path, exist_ok=True)\n\n    dataloader_params = {\n        \"batch_size\": args.embed_dump_batch_size,\n        \"collate_fn\": DefaultDataCollator(return_tensors=\"pt\"),\n        \"num_workers\": args.num_workers,\n        \"pin_memory\": False,\n        \"shuffle\": False,\n        \"drop_last\": True\n    }\n\n    # Process label_encoder = preprocessing.LabelEncoder()\n    label_encoder = preprocessing.LabelEncoder()\n    label_encoder.fit(dataset[\"test\"][\"tissue\"])\n    train_tissue_embed = label_encoder.transform(dataset[\"train\"][\"tissue\"])\n    dataset[\"train\"] = dataset[\"train\"].add_column(\"tissue_embed\", train_tissue_embed)\n    test_tissue_embed = label_encoder.transform(dataset[\"test\"][\"tissue\"])\n    dataset[\"test\"] = dataset[\"test\"].add_column(\"tissue_embed\", test_tissue_embed)\n\n    if not all([\n        fsspec_exists(osp.join(embeds_path, f\"{split_name}_embeds_combined.pt\")) for split_name in dataset.keys()\n    ]):\n        for split_name, split in dataset.items():\n            sampler = DistributedSampler(\n                split,\n                shuffle=dataloader_params.get(\"shuffle\", False),\n                drop_last=dataloader_params.get(\"drop_last\", True),\n            )\n\n            dl = DataLoader(split, **dataloader_params, sampler=sampler)\n\n            storage_dict = {\n                \"concat_avg_ws\": [],\n                \"rc_concat_avg_ws\": [],\n                \"chromosome\": [],\n                \"labels\": [],\n                \"distance_to_nearest_tss\": [],\n                \"tissue_embed\": [],\n            }\n\n            with torch.no_grad():\n\n                for batch_idx, batch in tqdm(\n                        enumerate(dl), total=len(dl), desc=f\"[RANK {dist.get_rank()}] Embedding {split_name}\",\n                        disable=dist.get_rank() != 0  # Only rank 0 updates pbar\n                ):\n                    for key in [\"chromosome\", \"labels\", \"distance_to_nearest_tss\", \"tissue_embed\"]:\n                        storage_dict[key].append(batch[key].to(\"cpu\", non_blocking=True))\n                    with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n                        output_alt = model(batch[\"alt_input_ids\"].to(device))\n                        output_ref = model(batch[\"ref_input_ids\"].to(device))\n                        if args.rcps:\n                            num_channels = output_alt.size(-1)\n                            # Flip along length and channel dims to preserve RC equivariance\n                            # i.e. 
output_rc(RC(inputs)) = outputs(inputs)\n                            output_alt_rc = output_alt[..., num_channels // 2:].contiguous().flip(dims=[1, 2])\n                            output_ref_rc = output_ref[..., num_channels // 2:].contiguous().flip(dims=[1, 2])\n                            output_alt = output_alt[..., :num_channels // 2]\n                            output_ref = output_ref[..., :num_channels // 2]\n\n                        else:\n                            # Flip along length dim so variant_idx aligns\n                            output_alt_rc = model(batch[\"alt_rc_input_ids\"].to(device)).contiguous().flip(dims=[1])\n                            output_ref_rc = model(batch[\"ref_rc_input_ids\"].to(device)).contiguous().flip(dims=[1])\n\n                    metrics = extract_embeddings(\n                        item_ref=output_ref,\n                        item_alt=output_alt,\n                        variant_idx=batch[\"variant_idx\"],\n                    )\n                    for key, value in metrics.items():\n                        storage_dict[key].append(metrics[key].to(\"cpu\", non_blocking=True))\n\n                    metrics_rc = extract_embeddings(\n                        item_ref=output_ref_rc,\n                        item_alt=output_alt_rc,\n                        variant_idx=batch[\"variant_idx\"],\n                    )\n                    for key, value in metrics_rc.items():\n                        storage_dict[f\"rc_{key}\"].append(metrics_rc[key].to(\"cpu\", non_blocking=True))\n\n                    if batch_idx % 100 == 0:\n                        # Every machine should print progress updates\n                        print(f\"[RANK {dist.get_rank()}] Completed index: {batch_idx}/{len(dl)}\")\n\n                storage_dict_temp = concat_storage_dict_values(storage_dict)\n                with fsspec.open(osp.join(embeds_path, f\"{split_name}_embeds_{dist.get_rank()}.pt\"), \"wb\") as f:\n                    torch.save(storage_dict_temp, f)\n                print(f\"[RANK {dist.get_rank()}] Saved {split_name} to {osp.join(embeds_path, f'{split_name}_embeds_{dist.get_rank()}.pt')}\")\n    else:\n        log.warning(\"Embeddings already exist, skipping!\")\n\n\ndef combine_embeddings(embeds_path):\n    \"\"\"Combine embeddings from different files.\"\"\"\n    # Check if combined embeddings exist, and if not, aggregate them\n    for split in [\"train\", \"test\"]:\n        if not fsspec_exists(osp.join(embeds_path, f\"{split}_embeds_combined.pt\")):\n            storage_dict = {\n                \"concat_avg_ws\": [],\n                \"rc_concat_avg_ws\": [],\n                \"chromosome\": [],\n                \"labels\": [],\n                \"distance_to_nearest_tss\": [],\n                \"tissue_embed\": [],\n            }\n            for filename in fsspec_listdir(embeds_path):\n                if f\"{split}_embeds_\" in filename:\n                    log.warning(f\"Loading data from: {filename}\")\n                    with fsspec.open(filename, \"rb\") as f:\n                        tmp_data = torch.load(f)\n                    for key in storage_dict.keys():\n                        storage_dict[key].append(tmp_data[key])\n            storage_dict = concat_storage_dict_values(storage_dict)\n            log.warning(f\"Saving combined data to: {embeds_path}/{split}_embeds_combined.pt\")\n            with fsspec.open(osp.join(embeds_path, f\"{split}_embeds_combined.pt\"), \"wb\") as f:\n                torch.save(storage_dict, 
f)\n\n\ndef main(args):\n    \"\"\"Main entry point.\"\"\"\n    # Reproducibility\n    torch.use_deterministic_algorithms(True)\n    torch.backends.cudnn.benchmark = False\n\n    # Init distributed\n    log.warning(\"Initializing distributed...\")\n    dist.init_process_group(\"nccl\")\n    print(f\"[RANK {dist.get_rank()}] Distributed initialized: rank {dist.get_rank()}\")  # All processes print this\n    # Setup device\n    device = torch.device(f\"cuda:{dist.get_rank()}\")\n    print(f\"[RANK {dist.get_rank()}] Using device: {device}.\")  # All processes print this\n\n    # Init tokenizer\n    if \"enformer\" in args.model_name_or_path.lower():\n        # Enformer tokenization requires having vocab of just `A,C,G,T,N` (in that order)\n        tokenizer = EnformerTokenizer()\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)\n\n    # Get dataset\n    dist.barrier()\n    dataset = prepare_dataset(args, tokenizer)\n\n    # Get model\n    dist.barrier()\n    model = get_backbone_model(args, device)\n    log.warning(\"Model loaded.\")\n\n    # Dump embeddings\n    dist.barrier()\n    dump_embeddings(args, dataset, model, device)\n\n    # Combine embeddings into single file\n    dist.barrier()\n    cleanup_distributed()\n    combine_embeddings(osp.join(args.downstream_save_dir, args.name))\n\n\nif __name__ == \"__main__\":\n    torch.multiprocessing.set_sharing_strategy('file_system')\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--seq_len\", type=int, default=131072,\n                        help=\"Sequence length (in bp).\")\n    parser.add_argument(\"--bp_per_token\", type=int, default=1,\n                        help=\"Number of base pairs per token.\")\n    parser.add_argument(\"--model_name_or_path\", type=str, default=None)\n    parser.add_argument(\"--downstream_save_dir\", type=str, default=\"./outputs/downstream/vep_embeddings\",\n                        help=\"Directory to save downstream task.\")\n    parser.add_argument(\"--name\", type=str, default=None, help=\"Embeddings model name.\")\n    parser.add_argument(\"--rcps\", default=False, action=\"store_true\", help=\"Use RCPS.\")\n    parser.add_argument(\"--no-rcps\", dest=\"rcps\", action=\"store_false\", help=\"Do not use RCPS.\")\n    parser.add_argument(\"--embed_dump_batch_size\", type=int, default=1,\n                        help=\"Batch size for embedding dump.\")\n    parser.add_argument(\"--num_workers\", type=int, default=0, help=\"Number of workers.\")\n    opts, _ = parser.parse_known_args()\n    log.warning(\"*** Args ************************\")\n    for k, v in vars(opts).items():\n        log.warning(f\"  - {k}: {v}\")\n    log.warning(\"******************************\\n\")\n\n    main(opts)\n"
  },
  {
    "path": "vep_svm.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"db878bc1\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Imports and Setup\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c79a903b\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import random\\n\",\n    \"import time\\n\",\n    \"from os import path as osp\\n\",\n    \"\\n\",\n    \"import fsspec\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"import seaborn as sns\\n\",\n    \"import torch\\n\",\n    \"from sklearn.metrics import roc_auc_score\\n\",\n    \"from sklearn.pipeline import make_pipeline\\n\",\n    \"from sklearn.preprocessing import StandardScaler\\n\",\n    \"from sklearn.svm import SVC\\n\",\n    \"from tqdm.auto import tqdm\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4034f167\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"DIST_TO_TSS = [[0, 30_000], [30_000, 100_000], [100_000, np.infty]]\\n\",\n    \"USE_TISSUE = [True]  # used as another for loop for fitting SVM, whether to use tissue embed or not\\n\",\n    \"Cs = [1, 5, 10]  # for loop in fitting SVM, inverse of L2 penalty (sklearn hyperparam)\\n\",\n    \"PATH_TO_OUTPUTS = \\\"./outputs/downstream/vep_embeddings\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"55c58437\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"def fsspec_exists(filename: str) -> bool:\\n\",\n    \"    \\\"\\\"\\\"Check if file exists in manner compatible with fsspec.\\\"\\\"\\\"\\n\",\n    \"    fs, _ = fsspec.core.url_to_fs(filename)\\n\",\n    \"    return fs.exists(filename)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"18522e17\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def dataset_nan_filter(data: dict, data_key: str) -> dict:\\n\",\n    \"    \\\"\\\"\\\"Filter any items that have NaN in embedding within TSS bucket\\\"\\\"\\\"\\n\",\n    \"    mask_out = torch.logical_or(\\n\",\n    \"        torch.any(data[data_key].isnan(), dim=1),\\n\",\n    \"        torch.any(data[f\\\"rc_{data_key}\\\"].isnan(), dim=1)\\n\",\n    \"    )\\n\",\n    \"    \\n\",\n    \"    new_data = dict()\\n\",\n    \"    for data_key in data.keys():\\n\",\n    \"        new_data[data_key] = data[data_key][~mask_out]\\n\",\n    \"\\n\",\n    \"    return new_data\\n\",\n    \"\\n\",\n    \"def dataset_tss_filter(data: dict, min_distance: int, max_distance: int) -> dict:\\n\",\n    \"    \\\"\\\"\\\"Filter the data to items that fall within TSS bucket\\\"\\\"\\\"\\n\",\n    \"    distance_mask = ((data[\\\"distance_to_nearest_tss\\\"] >= min_distance) \\n\",\n    \"                     & (data[\\\"distance_to_nearest_tss\\\"] <= max_distance))\\n\",\n    \"    new_data = dict()\\n\",\n    \"    for data_key in data.keys():\\n\",\n    \"        new_data[data_key] = data[data_key][distance_mask]\\n\",\n    \"\\n\",\n    \"    return new_data\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ef3d1006\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Specify which models to test\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4629cb30\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": 
[\n    \"# Embeddings to test\\n\",\n    \"model_dict = {\\n\",\n    \"    \\\"HyenaDNA\\\": dict(\\n\",\n    \"        embed_path=\\\"hyena_downstream-seqlen=131k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=False,\\n\",\n    \"        conjoin_test=False,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"    \\\"Caduceus-Ph\\\": dict(\\n\",\n    \"        embed_path=\\\"caduceus-ph_downstream-seqlen=131k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=False,\\n\",\n    \"        conjoin_test=True,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"    \\\"Caduceus w/o Equiv.\\\": dict(\\n\",\n    \"        embed_path=\\\"caduceus-ph_downstream-seqlen=131k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=False,\\n\",\n    \"        conjoin_test=False,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"    \\\"Caduceus-PS\\\": dict(\\n\",\n    \"        embed_path=\\\"caduceus-ps_downstream-seqlen=131k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=True,\\n\",\n    \"        conjoin_test=False,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"    \\\"Enformer\\\": dict(\\n\",\n    \"        embed_path=\\\"enformer-seqlen=196k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=False,\\n\",\n    \"        conjoin_test=False,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"    \\\"NTv2\\\": dict(\\n\",\n    \"        embed_path=\\\"NTv2_downstream-seqlen=12k\\\",\\n\",\n    \"        rc_aug=False,\\n\",\n    \"        conjoin_train=False,\\n\",\n    \"        conjoin_test=False,\\n\",\n    \"        key=\\\"concat_avg_ws\\\",\\n\",\n    \"    ),\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"12e64367\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Fit and test SVM\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6eaeb519-5c35-4fba-a09b-2d47c122320d\",\n   \"metadata\": {\n    \"scrolled\": false,\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"metrics = {\\n\",\n    \"    \\\"model_name\\\": [],\\n\",\n    \"    \\\"bucket_id\\\": [],\\n\",\n    \"    \\\"use_tissue\\\": [],\\n\",\n    \"    \\\"C\\\": [],\\n\",\n    \"    \\\"seed\\\": [],\\n\",\n    \"    \\\"AUROC\\\": [],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"for model_name, downstream_kwargs in model_dict.items():\\n\",\n    \"    print(f\\\"********** Gathering results for: {model_name} **********\\\")\\n\",\n    \"    embed_path = downstream_kwargs[\\\"embed_path\\\"]\\n\",\n    \"    rc_aug = downstream_kwargs[\\\"rc_aug\\\"]\\n\",\n    \"    conjoin_train = downstream_kwargs[\\\"conjoin_train\\\"]\\n\",\n    \"    conjoin_test = downstream_kwargs[\\\"conjoin_test\\\"]\\n\",\n    \"    key = downstream_kwargs[\\\"key\\\"]\\n\",\n    \"    \\n\",\n    \"    if \\\"NT\\\" in model_name: assert (rc_aug == False) and (conjoin_train == False) and (conjoin_test == False)\\n\",\n    \"    \\n\",\n    \"    base_embeds_path = PATH_TO_OUTPUTS\\n\",\n    \"    embeds_path = osp.join(base_embeds_path, embed_path)\\n\",\n    \"    \\n\",\n    \"    print(f\\\"Embed Path: {embeds_path}\\\")\\n\",\n    \"    with fsspec.open(osp.join(embeds_path, \\\"train_embeds_combined.pt\\\"), \\\"rb\\\") as f:\\n\",\n    \"        train_val_ds_raw = torch.load(f, 
map_location=\\\"cpu\\\")\\n\",\n    \"        train_val_ds_raw = dataset_nan_filter(train_val_ds_raw, data_key=key)\\n\",\n    \"    with fsspec.open(osp.join(embeds_path, \\\"test_embeds_combined.pt\\\"), \\\"rb\\\") as f:\\n\",\n    \"        test_ds_raw = torch.load(f, map_location=\\\"cpu\\\")\\n\",\n    \"        test_ds_raw = dataset_nan_filter(test_ds_raw, data_key=key)\\n\",\n    \"    print(f\\\"Total Train size: {len(train_val_ds_raw[key])},\\\", end=\\\" \\\")\\n\",\n    \"    print(f\\\"Total Test size: {len(test_ds_raw[key])},\\\", end=\\\" \\\")\\n\",\n    \"    print(f\\\"Shape: {test_ds_raw[key].shape[1:]}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    for bucket_id, (min_dist, max_dist) in enumerate(DIST_TO_TSS):\\n\",\n    \"        # Filter data to desired TSS bucket\\n\",\n    \"        train_val_ds_filter = dataset_tss_filter(train_val_ds_raw, min_dist, max_dist)\\n\",\n    \"        test_ds_filter = dataset_tss_filter(test_ds_raw, min_dist, max_dist)\\n\",\n    \"        print(f\\\"- TSS bucket: [{min_dist}, {max_dist}],\\\", end=\\\" \\\")\\n\",\n    \"        print(f\\\"Train size: {len(train_val_ds_filter[key])},\\\", end=\\\" \\\")\\n\",\n    \"        print(f\\\"Test size: {len(test_ds_filter[key])}\\\")\\n\",\n    \"    \\n\",\n    \"        for use_tissue in USE_TISSUE:\\n\",\n    \"            for C in Cs:\\n\",\n    \"                for seed in range(1, 6):     \\n\",\n    \"                    # Re-seed for SVM fitting\\n\",\n    \"                    random.seed(seed)\\n\",\n    \"                    np.random.seed(seed)\\n\",\n    \"                    torch.manual_seed(seed)\\n\",\n    \"                    torch.cuda.manual_seed_all(seed)\\n\",\n    \"\\n\",\n    \"                    svm_clf = make_pipeline(\\n\",\n    \"                        StandardScaler(),\\n\",\n    \"                        SVC(C=C, random_state=seed),\\n\",\n    \"                    )\\n\",\n    \"\\n\",\n    \"                    # Setup Train/Test dataset\\n\",\n    \"                    if conjoin_train:\\n\",\n    \"                        X = np.array(train_val_ds_filter[key])\\n\",\n    \"                        X += np.array(train_val_ds_filter[f\\\"rc_{key}\\\"])\\n\",\n    \"                        X /= 2\\n\",\n    \"                    else:\\n\",\n    \"                        X = np.array(train_val_ds_filter[key])\\n\",\n    \"                    X_with_tissue = np.concatenate(\\n\",\n    \"                        [X, np.array(train_val_ds_filter[\\\"tissue_embed\\\"])[..., None]],\\n\",\n    \"                        axis=-1\\n\",\n    \"                    )\\n\",\n    \"                    y = train_val_ds_filter[\\\"labels\\\"]\\n\",\n    \"                    if conjoin_train or conjoin_test:\\n\",\n    \"                        X_test = np.array(test_ds_filter[key])\\n\",\n    \"                        X_test += np.array(test_ds_filter[f\\\"rc_{key}\\\"])\\n\",\n    \"                        X_test /= 2\\n\",\n    \"                    else:\\n\",\n    \"                        X_test = np.array(test_ds_filter[key])\\n\",\n    \"                    X_test_with_tissue = np.concatenate(\\n\",\n    \"                        [X_test, np.array(test_ds_filter[\\\"tissue_embed\\\"])[..., None]],\\n\",\n    \"                        axis=-1\\n\",\n    \"                    )\\n\",\n    \"                    y_test = test_ds_filter[\\\"labels\\\"]\\n\",\n    \"\\n\",\n    \"                    print(f\\\"\\\\tFitting SVM ({use_tissue=}, {C=}, {seed=})...\\\", 
end=\\\" \\\")\\n\",\n    \"                    \\n\",\n    \"                    # Subsample at most 5000 training examples; sample with replacement only if fewer are available\\n\",\n    \"                    mask = np.random.choice(len(X), size=5000, replace=5000 > len(X))\\n\",\n    \"                    if use_tissue:\\n\",\n    \"                        X_train = X_with_tissue[mask]\\n\",\n    \"                        X_test = X_test_with_tissue\\n\",\n    \"                    else:\\n\",\n    \"                        X_train = X[mask]\\n\",\n    \"                    y_train = y[mask]\\n\",\n    \"\\n\",\n    \"                    start = time.time()\\n\",\n    \"                    svm_clf.fit(X_train, y_train)\\n\",\n    \"                    svm_y_pred = svm_clf.predict(X_test)\\n\",\n    \"                    svm_aucroc = roc_auc_score(y_test, svm_y_pred)\\n\",\n    \"                    end = time.time()\\n\",\n    \"                    print(f\\\"Completed! ({end - start:0.3f} s) -\\\", end=\\\" \\\")\\n\",\n    \"                    print(f\\\"AUROC: {svm_aucroc}\\\")\\n\",\n    \"                     \\n\",\n    \"                    metrics[\\\"model_name\\\"] += [model_name]\\n\",\n    \"                    metrics[\\\"bucket_id\\\"] += [bucket_id]\\n\",\n    \"                    metrics[\\\"use_tissue\\\"] += [use_tissue]\\n\",\n    \"                    metrics[\\\"C\\\"] += [C]\\n\",\n    \"                    metrics[\\\"seed\\\"] += [seed]\\n\",\n    \"                    metrics[\\\"AUROC\\\"] += [svm_aucroc]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"597b0fe9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df_metrics = pd.DataFrame.from_dict(metrics)\\n\",\n    \"df_metrics.to_csv(osp.join(PATH_TO_OUTPUTS, \\\"SVM_results.csv\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"03e06a25\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Plot results\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a362d3fa\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name_replacement = {\\n\",\n    \"    \\\"Caduceus w/o Equiv.\\\": \\\"Caduceus w/o\\\\nEquiv. 
(7.7M)\\\",\\n\",\n    \"    \\\"Caduceus-Ph\\\": \\\"Caduceus-Ph\\\\n(7.7M)\\\",\\n\",\n    \"    \\\"Caduceus-PS\\\": \\\"Caduceus-PS\\\\n(7.7M)\\\",\\n\",\n    \"    \\\"HyenaDNA\\\": \\\"HyenaDNA\\\\n(6.6M)\\\",\\n\",\n    \"    \\\"NTv2\\\": \\\"NTv2\\\\n(500M)\\\",\\n\",\n    \"    \\\"Enformer\\\": \\\"Enformer\\\\n(252M)\\\",\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"85b1c4fb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Formatting changes to df\\n\",\n    \"df = pd.read_csv(osp.join(PATH_TO_OUTPUTS, \\\"SVM_results.csv\\\"), index_col=0)\\n\",\n    \"df_display = df.rename(columns={\\\"bucket_id\\\": \\\"Distance to TSS\\\"})\\n\",\n    \"df_display = df_display.replace({\\\"Distance to TSS\\\": {0: \\\"0 - 30k\\\", 1: \\\"30 - 100k\\\", 2: \\\"100k+\\\"}})\\n\",\n    \"df_display = df_display.replace({\\\"model_name\\\": model_name_replacement})\\n\",\n    \"\\n\",\n    \"# Take average over seeds\\n\",\n    \"df_display_selected = df_display.groupby(\\n\",\n    \"    [\\\"model_name\\\", \\\"Distance to TSS\\\", \\\"use_tissue\\\", \\\"C\\\"]\\n\",\n    \").agg(AUROC=(\\\"AUROC\\\", np.mean)).reset_index()\\n\",\n    \"\\n\",\n    \"# Select best hyperparam by model/bucket\\n\",\n    \"best_ids = df_display_selected.groupby([\\\"model_name\\\", \\\"Distance to TSS\\\"])[\\\"AUROC\\\"].idxmax()\\n\",\n    \"df_display_selected = df_display_selected.loc[best_ids.reset_index()[\\\"AUROC\\\"].values]\\n\",\n    \"display(\\n\",\n    \"    df_display_selected.pivot(\\n\",\n    \"        index=\\\"model_name\\\", columns=\\\"Distance to TSS\\\", values=\\\"AUROC\\\"\\n\",\n    \"    )[[\\\"0 - 30k\\\", \\\"30 - 100k\\\", \\\"100k+\\\"]]\\n\",\n    \")\\n\",\n    \"display(df_display_selected[[\\\"model_name\\\", \\\"Distance to TSS\\\", \\\"C\\\", \\\"use_tissue\\\"]])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"09a7f4a9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Filter results to selected hyperparams\\n\",\n    \"df_plot = pd.merge(\\n\",\n    \"    df_display, df_display_selected,\\n\",\n    \"    on=[\\\"model_name\\\", \\\"Distance to TSS\\\", \\\"use_tissue\\\", \\\"C\\\"]\\n\",\n    \").drop(columns=[\\\"AUROC_y\\\"]).rename(columns={\\\"AUROC_x\\\": \\\"AUROC\\\"})\\n\",\n    \"\\n\",\n    \"# Plot results by distance to TSS\\n\",\n    \"sns.set_style(\\\"whitegrid\\\")\\n\",\n    \"g = sns.catplot(\\n\",\n    \"    data=df_plot,\\n\",\n    \"    x=\\\"model_name\\\",\\n\",\n    \"    y=\\\"AUROC\\\",\\n\",\n    \"    col=\\\"Distance to TSS\\\",\\n\",\n    \"    hue=\\\"Distance to TSS\\\",\\n\",\n    \"    kind=\\\"bar\\\",\\n\",\n    \"    errorbar=\\\"sd\\\",\\n\",\n    \"    height=12,\\n\",\n    \"    aspect=1,\\n\",\n    \"    dodge=False,\\n\",\n    \"    order=list(model_name_replacement.values()),\\n\",\n    \")\\n\",\n    \"g.set_xticklabels(rotation=60, fontsize=30)\\n\",\n    \"g.set(xlabel=\\\"\\\")\\n\",\n    \"g.set(ylim=(0.4, 0.7))\\n\",\n    \"g.set_titles(template=\\\"Dist. 
to TSS: {col_name}\\\", fontsize=40)\\n\",\n    \"g.fig.suptitle(\\\"Predicting Effects of Variants on Gene Expression\\\", y=1.1, fontsize=40)\\n\",\n    \"g._legend.remove()\\n\",\n    \"# Display bar values\\n\",\n    \"# (See: https://stackoverflow.com/questions/55586912/seaborn-catplot-set-values-over-the-bars)\\n\",\n    \"for ax in tqdm(g.axes.ravel(), leave=False):\\n\",\n    \"    title = ax.title.get_text()\\n\",\n    \"    ax.set_title(title, fontsize=35)\\n\",\n    \"    for c in tqdm(ax.containers, leave=False):\\n\",\n    \"        labels = [f\\\"{v.get_height():0.3f}\\\" for v in c]\\n\",\n    \"        ax.bar_label(c, labels=labels, label_type=\\\"center\\\", color=\\\"white\\\", weight=\\\"bold\\\", fontsize=24)\\n\",\n    \"plt.show()\\n\",\n    \"g.savefig(osp.join(PATH_TO_OUTPUTS, \\\"SVM_results.png\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a7858241\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.18\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  }
]